diff --git a/Cargo.lock b/Cargo.lock index 770e548..785346c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -325,7 +325,7 @@ dependencies = [ "futures-lite", "parking", "polling", - "rustix", + "rustix 0.38.34", "slab", "tracing", "windows-sys 0.59.0", @@ -357,7 +357,7 @@ dependencies = [ "cfg-if", "event-listener", "futures-lite", - "rustix", + "rustix 0.38.34", "tracing", "windows-sys 0.59.0", ] @@ -395,7 +395,7 @@ dependencies = [ "cfg-if", "futures-core", "futures-io", - "rustix", + "rustix 0.38.34", "signal-hook-registry", "slab", "windows-sys 0.59.0", @@ -1663,6 +1663,17 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010" +[[package]] +name = "dbus" +version = "0.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bb21987b9fb1613058ba3843121dd18b163b254d8a6e797e144cbac14d96d1b" +dependencies = [ + "libc", + "libdbus-sys", + "winapi", +] + [[package]] name = "defguard_wireguard_rs" version = "0.4.2" @@ -1736,6 +1747,37 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.87", +] + [[package]] name = "derive_more" version = "0.99.18" @@ -1919,7 +1961,9 @@ dependencies = [ "clap", "crossbeam", "dashmap", + "dbus", "defguard_wireguard_rs", + "derive_builder", "easytier-rpc-build 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "encoding", "futures", @@ -1927,17 +1971,22 @@ dependencies = [ "gethostname 0.5.0", "git-version", "globwalk", + "hickory-client", "hickory-proto", "hickory-resolver", + "hickory-server", "http", "http_req", "humansize", + "humantime-serde", "jemalloc-ctl", "jemalloc-sys", "jemallocator", "kcp-sys", "machine-uid", + "maplit", "mimalloc-rust", + "multimap", "netlink-packet-core", "netlink-packet-route 0.21.0", "netlink-packet-utils", @@ -1960,6 +2009,7 @@ dependencies = [ "rcgen", "regex", "reqwest", + "resolv-conf", "ring", "ringbuf", "rstest", @@ -1993,6 +2043,8 @@ dependencies = [ "tun-easytier", "url", "uuid", + "version-compare", + "which 7.0.3", "wildmatch", "windows 0.52.0", "windows-service", @@ -2216,6 +2268,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3d8a32ae18130a3c84dd492d4215c3d913c3b07c6b63c2eb3eb7ff1101ab7bf" +[[package]] +name = "endian-type" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" + [[package]] name = "enum-as-inner" version = "0.6.1" @@ -2249,6 +2307,12 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "env_home" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" + [[package]] name = "equivalent" version = "1.0.1" @@ -2267,12 +2331,12 @@ dependencies = [ [[package]] name = "errno" -version = "0.3.9" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2735,7 +2799,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc3655aa6818d65bc620d6911f05aa7b6aeb596291e1e9f79e52df85583d1e30" dependencies = [ - "rustix", + "rustix 0.38.34", "windows-targets 0.52.6", ] @@ -2765,9 +2829,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.3.3" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0" dependencies = [ "cfg-if", "libc", @@ -3090,6 +3154,25 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hickory-client" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c466cd63a4217d5b2b8e32f23f58312741ce96e3c84bf7438677d2baff0fc555" +dependencies = [ + "cfg-if", + "data-encoding", + "futures-channel", + "futures-util", + "hickory-proto", + "once_cell", + "radix_trie", + "rand 0.9.1", + "thiserror 2.0.11", + "tokio", + "tracing", +] + [[package]] name = "hickory-proto" version = "0.25.2" @@ -3108,6 +3191,7 @@ dependencies = [ "once_cell", "rand 0.9.1", "ring", + "serde", "thiserror 2.0.11", "tinyvec", "tokio", @@ -3130,12 +3214,37 @@ dependencies = [ "parking_lot", "rand 0.9.1", "resolv-conf", + "serde", "smallvec", "thiserror 2.0.11", "tokio", "tracing", ] +[[package]] +name = "hickory-server" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53e5fe811b941c74ee46b8818228bfd2bc2688ba276a0eaeb0f2c95ea3b2585" +dependencies = [ + "async-trait", + "bytes", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-util", + "hickory-proto", + "hickory-resolver", + "ipnet", + "prefix-trie", + "serde", + "thiserror 2.0.11", + "time", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hkdf" version = "0.12.4" @@ -3163,17 +3272,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "hostname" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" -dependencies = [ - "libc", - "match_cfg", - "winapi", -] - [[package]] name = "html5ever" version = "0.26.0" @@ -3259,6 +3357,22 @@ dependencies = [ "libm", ] +[[package]] +name = "humantime" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f" + +[[package]] +name = "humantime-serde" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57a3db5ea5923d99402c94e9feb261dc5ee9b4efa158b0315f788cf549cc200c" +dependencies = [ + "humantime", + "serde", +] + [[package]] name = "hyper" version = "1.4.1" @@ -3666,9 +3780,12 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.9.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +dependencies = [ + "serde", +] [[package]] name = "ipnetwork" @@ -3964,6 +4081,16 @@ version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "libdbus-sys" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06085512b750d640299b79be4bad3d2fa90a9c00b1fd9e1b46364f66f0485c72" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "libloading" version = "0.7.4" @@ -4023,6 +4150,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "litemap" version = "0.7.5" @@ -4103,6 +4236,12 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d" +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "markup5ever" version = "0.11.0" @@ -4117,12 +4256,6 @@ dependencies = [ "tendril", ] -[[package]] -name = "match_cfg" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" - [[package]] name = "matchers" version = "0.1.0" @@ -4188,8 +4321,7 @@ dependencies = [ [[package]] name = "mimalloc-rust" version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eb726c8298efb4010b2c46d8050e4be36cf807b9d9e98cb112f830914fc9bbe" +source = "git+https://github.com/EasyTier/mimalloc-rust#eb61c4d50fef4eb5fbd5db83e4ea83153646b482" dependencies = [ "cty", "mimalloc-rust-sys", @@ -4197,9 +4329,8 @@ dependencies = [ [[package]] name = "mimalloc-rust-sys" -version = "1.7.9-source" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6413e13241a9809f291568133eca6694572cf528c1a6175502d090adce5dd5db" +version = "2.1.2-source" +source = "git+https://github.com/EasyTier/mimalloc-rust#eb61c4d50fef4eb5fbd5db83e4ea83153646b482" dependencies = [ "cc", "cty", @@ -4293,6 +4424,9 @@ name = "multimap" version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +dependencies = [ + "serde", +] [[package]] name = "nalgebra" @@ -4463,6 +4597,15 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "nibble_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] + [[package]] name = "nix" version = "0.25.1" @@ -5545,7 +5688,7 @@ dependencies = [ "concurrent-queue", "hermit-abi 0.4.0", "pin-project-lite", - "rustix", + "rustix 0.38.34", "tracing", "windows-sys 0.59.0", ] @@ -5600,6 +5743,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "prefix-trie" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85cf4c7c25f1dd66c76b451e9041a8cfce26e4ca754934fa7aed8d5a59a01d20" +dependencies = [ + "ipnet", + "num-traits", +] + [[package]] name = "prettyplease" version = "0.2.20" @@ -5618,7 +5771,7 @@ checksum = "765ec92721e112ffe07f5c06fb0654da0b708990888981d05cf12a7c9909df30" dependencies = [ "libc", "security-framework-sys", - "which", + "which 4.4.2", "windows-sys 0.48.0", ] @@ -5830,12 +5983,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "quick-error" -version = "1.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" - [[package]] name = "quick-xml" version = "0.32.0" @@ -5915,6 +6062,16 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +[[package]] +name = "radix_trie" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" +dependencies = [ + "endian-type", + "nibble_vec", +] + [[package]] name = "rand" version = "0.7.3" @@ -6004,7 +6161,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.2", ] [[package]] @@ -6206,13 +6363,9 @@ dependencies = [ [[package]] name = "resolv-conf" -version = "0.7.0" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" -dependencies = [ - "hostname", - "quick-error", -] +checksum = "fc7c8f7f733062b66dc1c63f9db168ac0b97a9210e247fa90fdc9ad08f51b302" [[package]] name = "ring" @@ -6472,10 +6625,23 @@ dependencies = [ "bitflags 2.8.0", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.4.14", "windows-sys 0.52.0", ] +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags 2.8.0", + "errno", + "libc", + "linux-raw-sys 0.9.4", + "windows-sys 0.59.0", +] + [[package]] name = "rustls" version = "0.23.12" @@ -7077,7 +7243,7 @@ dependencies = [ "encoding-utils", "encoding_rs", "plist", - "which", + "which 4.4.2", "xml-rs", ] @@ -8164,7 +8330,7 @@ dependencies = [ "cfg-if", "fastrand", "once_cell", - "rustix", + "rustix 0.38.34", "windows-sys 0.59.0", ] @@ -8185,7 +8351,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5352447f921fda68cf61b4101566c0bdb5104eff6804d0678e5227580ab6a4e9" dependencies = [ - "rustix", + "rustix 0.38.34", "windows-sys 0.59.0", ] @@ -9349,7 +9515,19 @@ dependencies = [ "either", "home", "once_cell", - "rustix", + "rustix 0.38.34", +] + +[[package]] +name = "which" +version = "7.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" +dependencies = [ + "either", + "env_home", + "rustix 1.0.7", + "winsafe", ] [[package]] @@ -9812,6 +9990,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "winsafe" +version = "0.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" + [[package]] name = "wintun" version = "0.5.0" @@ -9924,7 +10108,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d91ffca73ee7f68ce055750bf9f6eca0780b8c85eff9bc046a3b0da41755e12" dependencies = [ "gethostname 0.4.3", - "rustix", + "rustix 0.38.34", "x11rb-protocol", ] diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index d890505..6076f71 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -153,7 +153,7 @@ humansize = "2.1.3" base64 = "0.22" -mimalloc-rust = { version = "0.2.1", optional = true } +mimalloc-rust = { git = "https://github.com/EasyTier/mimalloc-rust", optional = true } # mips atomic-shim = "0.2.0" @@ -205,6 +205,14 @@ http_req = { git = "https://github.com/EasyTier/http_req.git", default-features hickory-resolver = "0.25.2" hickory-proto = "0.25.2" +# for magic dns +hickory-client = "0.25.2" +hickory-server = { version = "0.25.2", features = ["resolver"] } +derive_builder = "0.20.2" +humantime-serde = "1.1.1" +multimap = "0.10.0" +version-compare = "0.2.0" + bounded_join_set = "0.3.0" jemallocator = { version = "0.5.4", optional = true } @@ -223,6 +231,10 @@ netlink-sys = "0.8.7" netlink-packet-route = "0.21.0" netlink-packet-core = { version = "0.7.0" } netlink-packet-utils = "0.5.2" +# for magic dns +resolv-conf = "0.7.3" +dbus = { version = "0.9.7", features = ["vendored"] } +which = "7.0.3" [target.'cfg(windows)'.dependencies] windows = { version = "0.52.0", features = [ @@ -264,6 +276,7 @@ thunk-rs = { git = "https://github.com/easytier/thunk.git", default-features = f serial_test = "3.0.0" rstest = "0.18.2" futures-util = "0.3.30" +maplit = "1.0.2" [target.'cfg(target_os = "linux")'.dev-dependencies] defguard_wireguard_rs = "0.4.2" diff --git a/easytier/build.rs b/easytier/build.rs index 271a17d..2cb725f 100644 --- a/easytier/build.rs +++ b/easytier/build.rs @@ -129,7 +129,10 @@ fn check_locale() { fn main() -> Result<(), Box> { // enable thunk-rs when target os is windows and arch is x86_64 or i686 #[cfg(target_os = "windows")] - if !std::env::var("TARGET").unwrap_or_default().contains("aarch64"){ + if !std::env::var("TARGET") + .unwrap_or_default() + .contains("aarch64") + { thunk::thunk(); } @@ -143,6 +146,7 @@ fn main() -> Result<(), Box> { "src/proto/tests.proto", "src/proto/cli.proto", "src/proto/web.proto", + "src/proto/magic_dns.proto", ]; for proto_file in proto_files.iter().chain(proto_files_reflect.iter()) { diff --git a/easytier/locales/app.yml b/easytier/locales/app.yml index c27d021..a9a7693 100644 --- a/easytier/locales/app.yml +++ b/easytier/locales/app.yml @@ -152,6 +152,9 @@ core_clap: port_forward: en: "forward local port to remote port in virtual network. e.g.: udp://0.0.0.0:12345/10.126.126.1:23456, means forward local udp port 12345 to 10.126.126.1:23456 in the virtual network. can specify multiple." zh-CN: "将本地端口转发到虚拟网络中的远程端口。例如:udp://0.0.0.0:12345/10.126.126.1:23456,表示将本地UDP端口12345转发到虚拟网络中的10.126.126.1:23456。可以指定多个。" + accept_dns: + en: "if true, enable magic dns. with magic dns, you can access other nodes with a domain name, e.g.: .et.net. magic dns will modify your system dns settings, enable it carefully." + zh-CN: "如果为true,则启用魔法DNS。使用魔法DNS,您可以使用域名访问其他节点,例如:.et.net。魔法DNS将修改您的系统DNS设置,请谨慎启用。" core_app: panic_backtrace_save: diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index 5738c5e..8295976 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -36,6 +36,7 @@ pub fn gen_default_flags() -> Flags { enable_kcp_proxy: false, disable_kcp_input: false, disable_relay_kcp: true, + accept_dns: false, } } diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index bae73ab..6da3a7f 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -63,7 +63,7 @@ pub struct GlobalCtx { ip_collector: Arc, - hostname: String, + hostname: Mutex, stun_info_collection: Box, @@ -122,7 +122,7 @@ impl GlobalCtx { ip_collector: Arc::new(IPCollector::new(net_ns, stun_info_collection.clone())), - hostname, + hostname: Mutex::new(hostname), stun_info_collection: Box::new(stun_info_collection), @@ -219,7 +219,11 @@ impl GlobalCtx { } pub fn get_hostname(&self) -> String { - return self.hostname.clone(); + return self.hostname.lock().unwrap().clone(); + } + + pub fn set_hostname(&self, hostname: String) { + *self.hostname.lock().unwrap() = hostname; } pub fn get_stun_info_collector(&self) -> impl StunInfoCollectorTrait + '_ { @@ -300,7 +304,10 @@ impl GlobalCtx { #[cfg(test)] pub mod tests { - use crate::common::{config::TomlConfigLoader, new_peer_id}; + use crate::{ + common::{config::TomlConfigLoader, new_peer_id, stun::MockStunInfoCollector}, + proto::common::NatType, + }; use super::*; @@ -340,7 +347,12 @@ pub mod tests { let config_fs = TomlConfigLoader::default(); config_fs.set_inst_name(format!("test_{}", config_fs.get_id())); config_fs.set_network_identity(network_identy.unwrap_or(NetworkIdentity::default())); - std::sync::Arc::new(GlobalCtx::new(config_fs)) + + let ctx = Arc::new(GlobalCtx::new(config_fs)); + ctx.replace_stun_info_collector(Box::new(MockStunInfoCollector { + udp_nat_type: NatType::Unknown, + })); + ctx } pub fn get_mock_global_ctx() -> ArcGlobalCtx { diff --git a/easytier/src/common/ifcfg/darwin.rs b/easytier/src/common/ifcfg/darwin.rs index a496726..2cf13be 100644 --- a/easytier/src/common/ifcfg/darwin.rs +++ b/easytier/src/common/ifcfg/darwin.rs @@ -12,13 +12,15 @@ impl IfConfiguerTrait for MacIfConfiger { name: &str, address: Ipv4Addr, cidr_prefix: u8, + cost: Option, ) -> Result<(), Error> { run_shell_cmd( format!( - "route -n add {} -netmask {} -interface {} -hopcount 7", + "route -n add {} -netmask {} -interface {} -hopcount {}", address, cidr_to_subnet_mask(cidr_prefix), - name + name, + cost.unwrap_or(7) ) .as_str(), ) diff --git a/easytier/src/common/ifcfg/mod.rs b/easytier/src/common/ifcfg/mod.rs index 79a8ea2..e779cac 100644 --- a/easytier/src/common/ifcfg/mod.rs +++ b/easytier/src/common/ifcfg/mod.rs @@ -21,6 +21,7 @@ pub trait IfConfiguerTrait: Send + Sync { _name: &str, _address: Ipv4Addr, _cidr_prefix: u8, + _cost: Option, ) -> Result<(), Error> { Ok(()) } @@ -125,3 +126,6 @@ pub type IfConfiger = windows::WindowsIfConfiger; target_os = "freebsd", )))] pub type IfConfiger = DummyIfConfiger; + +#[cfg(target_os = "windows")] +pub use windows::RegistryManager; diff --git a/easytier/src/common/ifcfg/netlink.rs b/easytier/src/common/ifcfg/netlink.rs index 0ddcb2a..80ecbd2 100644 --- a/easytier/src/common/ifcfg/netlink.rs +++ b/easytier/src/common/ifcfg/netlink.rs @@ -350,6 +350,7 @@ impl IfConfiguerTrait for NetlinkIfConfiger { name: &str, address: Ipv4Addr, cidr_prefix: u8, + cost: Option, ) -> Result<(), Error> { let mut message = RouteMessage::default(); @@ -359,7 +360,9 @@ impl IfConfiguerTrait for NetlinkIfConfiger { message.header.kind = RouteType::Unicast; message.header.address_family = AddressFamily::Inet; // metric - message.attributes.push(RouteAttribute::Priority(65535)); + message + .attributes + .push(RouteAttribute::Priority(cost.unwrap_or(65535) as u32)); // output interface message .attributes @@ -550,7 +553,7 @@ mod tests { ifcfg.set_link_status(DUMMY_IFACE_NAME, true).await.unwrap(); ifcfg - .add_ipv4_route(DUMMY_IFACE_NAME, "10.5.5.0".parse().unwrap(), 24) + .add_ipv4_route(DUMMY_IFACE_NAME, "10.5.5.0".parse().unwrap(), 24, None) .await .unwrap(); diff --git a/easytier/src/common/ifcfg/windows.rs b/easytier/src/common/ifcfg/windows.rs index 3abdc2c..5699104 100644 --- a/easytier/src/common/ifcfg/windows.rs +++ b/easytier/src/common/ifcfg/windows.rs @@ -1,6 +1,10 @@ -use std::net::Ipv4Addr; +use std::{io, net::Ipv4Addr}; use async_trait::async_trait; +use winreg::{ + enums::{HKEY_LOCAL_MACHINE, KEY_READ, KEY_WRITE}, + RegKey, +}; use super::{cidr_to_subnet_mask, run_shell_cmd, Error, IfConfiguerTrait}; @@ -59,16 +63,18 @@ impl IfConfiguerTrait for WindowsIfConfiger { name: &str, address: Ipv4Addr, cidr_prefix: u8, + cost: Option, ) -> Result<(), Error> { let Some(idx) = Self::get_interface_index(name) else { return Err(Error::NotFound); }; run_shell_cmd( format!( - "route ADD {} MASK {} 10.1.1.1 IF {} METRIC 9000", + "route ADD {} MASK {} 10.1.1.1 IF {} METRIC {}", address, cidr_to_subnet_mask(cidr_prefix), - idx + idx, + cost.unwrap_or(9000) ) .as_str(), ) @@ -164,3 +170,220 @@ impl IfConfiguerTrait for WindowsIfConfiger { .await } } + +pub struct RegistryManager; + +impl RegistryManager { + pub const IPV4_TCPIP_INTERFACE_PREFIX: &str = + r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\"; + pub const IPV6_TCPIP_INTERFACE_PREFIX: &str = + r"SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\"; + pub const NETBT_INTERFACE_PREFIX: &str = + r"SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces\Tcpip_"; + + pub fn reg_delete_obsoleted_items(dev_name: &str) -> io::Result<()> { + use winreg::{enums::HKEY_LOCAL_MACHINE, enums::KEY_ALL_ACCESS, RegKey}; + let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); + let profiles_key = hklm.open_subkey_with_flags( + "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Profiles", + KEY_ALL_ACCESS, + )?; + let unmanaged_key = hklm.open_subkey_with_flags( + "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Signatures\\Unmanaged", + KEY_ALL_ACCESS, + )?; + // collect subkeys to delete + let mut keys_to_delete = Vec::new(); + let mut keys_to_delete_unmanaged = Vec::new(); + for subkey_name in profiles_key.enum_keys().filter_map(Result::ok) { + let subkey = profiles_key.open_subkey(&subkey_name)?; + // check if ProfileName contains "et" + match subkey.get_value::("ProfileName") { + Ok(profile_name) => { + if profile_name.contains("et_") + || (!dev_name.is_empty() && dev_name == profile_name) + { + keys_to_delete.push(subkey_name); + } + } + Err(e) => { + tracing::error!( + "Failed to read ProfileName for subkey {}: {}", + subkey_name, + e + ); + } + } + } + for subkey_name in unmanaged_key.enum_keys().filter_map(Result::ok) { + let subkey = unmanaged_key.open_subkey(&subkey_name)?; + // check if ProfileName contains "et" + match subkey.get_value::("Description") { + Ok(profile_name) => { + if profile_name.contains("et_") + || (!dev_name.is_empty() && dev_name == profile_name) + { + keys_to_delete_unmanaged.push(subkey_name); + } + } + Err(e) => { + tracing::error!( + "Failed to read ProfileName for subkey {}: {}", + subkey_name, + e + ); + } + } + } + // delete collected subkeys + if !keys_to_delete.is_empty() { + for subkey_name in keys_to_delete { + match profiles_key.delete_subkey_all(&subkey_name) { + Ok(_) => tracing::trace!("Successfully deleted subkey: {}", subkey_name), + Err(e) => tracing::error!("Failed to delete subkey {}: {}", subkey_name, e), + } + } + } + if !keys_to_delete_unmanaged.is_empty() { + for subkey_name in keys_to_delete_unmanaged { + match unmanaged_key.delete_subkey_all(&subkey_name) { + Ok(_) => tracing::trace!("Successfully deleted subkey: {}", subkey_name), + Err(e) => tracing::error!("Failed to delete subkey {}: {}", subkey_name, e), + } + } + } + Ok(()) + } + + pub fn reg_change_catrgory_in_profile(dev_name: &str) -> io::Result<()> { + use winreg::{enums::HKEY_LOCAL_MACHINE, enums::KEY_ALL_ACCESS, RegKey}; + let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); + let profiles_key = hklm.open_subkey_with_flags( + "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Profiles", + KEY_ALL_ACCESS, + )?; + + for subkey_name in profiles_key.enum_keys().filter_map(Result::ok) { + let subkey = profiles_key.open_subkey_with_flags(&subkey_name, KEY_ALL_ACCESS)?; + match subkey.get_value::("ProfileName") { + Ok(profile_name) => { + if !dev_name.is_empty() && dev_name == profile_name { + match subkey.set_value("Category", &1u32) { + Ok(_) => tracing::trace!("Successfully set Category in registry"), + Err(e) => tracing::error!("Failed to set Category in registry: {}", e), + } + } + } + Err(e) => { + tracing::error!( + "Failed to read ProfileName for subkey {}: {}", + subkey_name, + e + ); + } + } + } + Ok(()) + } + + // 根据接口名称查找 GUID + pub fn find_interface_guid(interface_name: &str) -> io::Result { + // 注册表路径:所有网络接口的根目录 + let network_key_path = + r"SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}"; + + let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); + let network_key = hklm.open_subkey_with_flags(network_key_path, KEY_READ)?; + + // 遍历该路径下的所有 GUID 子键 + for guid in network_key.enum_keys().map_while(Result::ok) { + if let Ok(guid_key) = network_key.open_subkey_with_flags(&guid, KEY_READ) { + // 检查 Connection/Name 是否匹配目标接口名 + if let Ok(conn_key) = guid_key.open_subkey_with_flags("Connection", KEY_READ) { + if let Ok(name) = conn_key.get_value::("Name") { + if name == interface_name { + return Ok(guid); + } + } + } + } + } + + // 如果没有找到对应的接口 + Err(io::Error::new( + io::ErrorKind::NotFound, + "Interface not found", + )) + } + + // 打开注册表键 + pub fn open_interface_key(interface_guid: &str, prefix: &str) -> io::Result { + let path = format!(r"{}{}", prefix, interface_guid); + let hkey_local_machine = RegKey::predef(HKEY_LOCAL_MACHINE); + hkey_local_machine.open_subkey_with_flags(&path, KEY_WRITE) + } + + // 禁用动态 DNS 更新 + // disableDynamicUpdates sets the appropriate registry values to prevent the + // Windows DHCP client from sending dynamic DNS updates for our interface to + // AD domain controllers. + pub fn disable_dynamic_updates(interface_guid: &str) -> io::Result<()> { + let prefixes = [ + Self::IPV4_TCPIP_INTERFACE_PREFIX, + Self::IPV6_TCPIP_INTERFACE_PREFIX, + ]; + + for prefix in &prefixes { + let key = match Self::open_interface_key(interface_guid, prefix) { + Ok(k) => k, + Err(e) => { + // 模拟 mute-key-not-found-if-closing 行为 + if matches!(e.kind(), io::ErrorKind::NotFound) { + continue; + } else { + return Err(e); + } + } + }; + + key.set_value("RegistrationEnabled", &0u32)?; + key.set_value("DisableDynamicUpdate", &1u32)?; + key.set_value("MaxNumberOfAddressesToRegister", &0u32)?; + } + + Ok(()) + } + + // 设置单个 DWORD 值到指定的注册表路径下 + fn set_single_dword( + interface_guid: &str, + prefix: &str, + value_name: &str, + data: u32, + ) -> io::Result<()> { + let key = match Self::open_interface_key(interface_guid, prefix) { + Ok(k) => k, + Err(e) => { + // 模拟 muteKeyNotFoundIfClosing 行为:忽略 Key Not Found 错误 + return if matches!(e.kind(), io::ErrorKind::NotFound) { + Ok(()) + } else { + Err(e) + }; + } + }; + + key.set_value(value_name, &data)?; + Ok(()) + } + + // 禁用 NetBIOS 名称解析请求 + pub fn disable_netbios(interface_guid: &str) -> io::Result<()> { + Self::set_single_dword( + interface_guid, + Self::NETBT_INTERFACE_PREFIX, + "NetbiosOptions", + 2, + ) + } +} diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index 1d6fa84..4f86a8c 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -445,6 +445,13 @@ struct Cli { num_args = 1.. )] port_forward: Vec, + + #[arg( + long, + env = "ET_ACCEPT_DNS", + help = t!("core_clap.accept_dns").to_string(), + )] + accept_dns: Option, } rust_i18n::i18n!("locales", fallback = "en"); @@ -762,6 +769,7 @@ impl TryFrom<&Cli> for TomlConfigLoader { f.bind_device = cli.bind_device.unwrap_or(f.bind_device); f.enable_kcp_proxy = cli.enable_kcp_proxy.unwrap_or(f.enable_kcp_proxy); f.disable_kcp_input = cli.disable_kcp_input.unwrap_or(f.disable_kcp_input); + f.accept_dns = cli.accept_dns.unwrap_or(f.accept_dns); cfg.set_flags(f); if !cli.exit_nodes.is_empty() { diff --git a/easytier/src/gateway/tcp_proxy.rs b/easytier/src/gateway/tcp_proxy.rs index 05d617b..520318e 100644 --- a/easytier/src/gateway/tcp_proxy.rs +++ b/easytier/src/gateway/tcp_proxy.rs @@ -351,9 +351,10 @@ impl PeerPacketFilter for TcpProxy { #[async_trait::async_trait] impl NicPacketFilter for TcpProxy { async fn try_process_packet_from_nic(&self, zc_packet: &mut ZCPacket) -> bool { - let Some(my_ipv4) = self.get_local_ip() else { + let Some(my_ipv4_inet) = self.get_local_inet() else { return false; }; + let my_ipv4 = my_ipv4_inet.address(); let data = zc_packet.payload(); let ip_packet = Ipv4Packet::new(data).unwrap(); @@ -377,7 +378,7 @@ impl NicPacketFilter for TcpProxy { // for kcp proxy, the src ip of nat entry will be converted from my ip to fake ip // here we need to convert it back - if !self.is_smoltcp_enabled() && dst_addr.ip() == Self::get_fake_local_ipv4(my_ipv4) { + if !self.is_smoltcp_enabled() && dst_addr.ip() == Self::get_fake_local_ipv4(&my_ipv4_inet) { dst_addr.set_ip(IpAddr::V4(my_ipv4)); need_transform_dst = true; } @@ -620,13 +621,15 @@ impl TcpProxy { continue; }; - let my_ip = global_ctx - .get_ipv4() + let my_ip_inet = global_ctx.get_ipv4(); + let my_ip = my_ip_inet .as_ref() .map(Ipv4Inet::address) .unwrap_or(Ipv4Addr::UNSPECIFIED); - if socket_addr.ip() == Self::get_fake_local_ipv4(my_ip) { + if my_ip_inet.is_some() + && socket_addr.ip() == Self::get_fake_local_ipv4(&my_ip_inet.unwrap()) + { socket_addr.set_ip(IpAddr::V4(my_ip)); } @@ -768,13 +771,14 @@ impl TcpProxy { } pub fn get_local_ip(&self) -> Option { + self.get_local_inet().map(|inet| inet.address()) + } + + pub fn get_local_inet(&self) -> Option { if self.is_smoltcp_enabled() { - Some(Ipv4Addr::new(192, 88, 99, 254)) + Some(Ipv4Inet::new(Ipv4Addr::new(192, 88, 99, 254), 24).unwrap()) } else { - self.global_ctx - .get_ipv4() - .as_ref() - .map(cidr::Ipv4Inet::address) + self.global_ctx.get_ipv4().as_ref().cloned() } } @@ -787,9 +791,8 @@ impl TcpProxy { .load(std::sync::atomic::Ordering::Relaxed) } - pub fn get_fake_local_ipv4(local_ip: Ipv4Addr) -> Ipv4Addr { - let octets = local_ip.octets(); - Ipv4Addr::new(octets[0], octets[1], octets[2], 0) + pub fn get_fake_local_ipv4(local_ip: &Ipv4Inet) -> Ipv4Addr { + local_ip.first_address() } async fn try_handle_peer_packet(&self, packet: &mut ZCPacket) -> Option<()> { @@ -800,7 +803,8 @@ impl TcpProxy { return None; } - let ipv4_addr = self.get_local_ip()?; + let ipv4_inet = self.get_local_inet()?; + let ipv4_addr = ipv4_inet.address(); let hdr = packet.peer_manager_header().unwrap().clone(); if hdr.packet_type != PacketType::Data as u8 || hdr.is_no_proxy() { @@ -849,7 +853,7 @@ impl TcpProxy { let mut ip_packet = MutableIpv4Packet::new(payload_bytes).unwrap(); if !self.is_smoltcp_enabled() && source_ip == ipv4_addr { // modify the source so the response packet can be handled by tun device - ip_packet.set_source(Self::get_fake_local_ipv4(ipv4_addr)); + ip_packet.set_source(Self::get_fake_local_ipv4(&ipv4_inet)); } ip_packet.set_destination(ipv4_addr); let source = ip_packet.get_source(); diff --git a/easytier/src/instance/dns_server/client_instance.rs b/easytier/src/instance/dns_server/client_instance.rs new file mode 100644 index 0000000..5139cfc --- /dev/null +++ b/easytier/src/instance/dns_server/client_instance.rs @@ -0,0 +1,104 @@ +use std::{sync::Arc, time::Duration}; + +use tokio::task::JoinSet; + +use crate::{ + peers::peer_manager::PeerManager, + proto::{ + cli::Route, + common::Void, + magic_dns::{ + HandshakeRequest, MagicDnsServerRpc, MagicDnsServerRpcClientFactory, + UpdateDnsRecordRequest, + }, + rpc_impl::standalone::StandAloneClient, + rpc_types::controller::BaseController, + }, + tunnel::tcp::TcpTunnelConnector, +}; + +use super::{DEFAULT_ET_DNS_ZONE, MAGIC_DNS_INSTANCE_ADDR}; + +pub struct MagicDnsClientInstance { + rpc_client: StandAloneClient, + rpc_stub: Option + Send>>, + peer_mgr: Arc, + tasks: JoinSet<()>, +} + +impl MagicDnsClientInstance { + pub async fn new(peer_mgr: Arc) -> Result { + let tcp_connector = TcpTunnelConnector::new(MAGIC_DNS_INSTANCE_ADDR.parse().unwrap()); + let mut rpc_client = StandAloneClient::new(tcp_connector); + let rpc_stub = rpc_client + .scoped_client::>("".to_string()) + .await?; + Ok(MagicDnsClientInstance { + rpc_client, + rpc_stub: Some(rpc_stub), + peer_mgr, + tasks: JoinSet::new(), + }) + } + + async fn update_dns_task( + peer_mgr: Arc, + rpc_stub: Box + Send>, + ) -> Result<(), anyhow::Error> { + let mut prev_last_update = None; + rpc_stub + .handshake(BaseController::default(), HandshakeRequest::default()) + .await?; + loop { + rpc_stub + .heartbeat(BaseController::default(), Void::default()) + .await?; + + let last_update = peer_mgr.get_route_peer_info_last_update_time().await; + if Some(last_update) == prev_last_update { + tokio::time::sleep(Duration::from_millis(500)).await; + continue; + } + prev_last_update = Some(last_update); + let mut routes = peer_mgr.list_routes().await; + // add self as a route + let ctx = peer_mgr.get_global_ctx(); + routes.push(Route { + hostname: ctx.get_hostname(), + ipv4_addr: ctx.get_ipv4().map(Into::into), + ..Default::default() + }); + let req = UpdateDnsRecordRequest { + routes, + zone: DEFAULT_ET_DNS_ZONE.to_string(), + }; + tracing::debug!( + "MagicDnsClientInstance::update_dns_task: update dns records: {:?}", + req + ); + rpc_stub + .update_dns_record(BaseController::default(), req) + .await?; + } + } + + pub async fn run_and_wait(&mut self) { + let rpc_stub = self.rpc_stub.take().unwrap(); + let peer_mgr = self.peer_mgr.clone(); + self.tasks.spawn(async move { + let ret = Self::update_dns_task(peer_mgr, rpc_stub).await; + if let Err(e) = ret { + tracing::error!("MagicDnsServerInstanceData::run_and_wait: {:?}", e); + } + }); + + tokio::select! { + _ = self.tasks.join_next() => { + tracing::warn!("MagicDnsServerInstanceData::run_and_wait: dns record update task exited"); + } + _ = self.rpc_client.wait() => { + tracing::warn!("MagicDnsServerInstanceData::run_and_wait: rpc client exited"); + } + } + } +} diff --git a/easytier/src/instance/dns_server/config.rs b/easytier/src/instance/dns_server/config.rs new file mode 100644 index 0000000..e55fadc --- /dev/null +++ b/easytier/src/instance/dns_server/config.rs @@ -0,0 +1,193 @@ +use hickory_proto::rr; +use hickory_proto::rr::RData; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr}; +use std::str::FromStr; +use std::time::Duration; + +#[derive(Serialize, Deserialize, Debug, Clone, derive_builder::Builder)] +pub struct RunConfig { + general: GeneralConfig, + + #[builder(default = HashMap::new())] + zones: Zone, + + #[builder(default = Vec::new())] + #[serde(default)] + excluded_forward_nameservers: Vec, +} + +impl RunConfig { + pub fn general(&self) -> &GeneralConfig { + &self.general + } + + pub fn zones(&self) -> &Zone { + &self.zones + } + + pub fn excluded_forward_nameservers(&self) -> &Vec { + &self.excluded_forward_nameservers + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, derive_builder::Builder)] +pub struct GeneralConfig { + #[builder(setter(into, strip_option), default = None)] + listen_tcp: Option, + + #[builder(setter(into, strip_option), default = None)] + listen_udp: Option, +} + +impl GeneralConfig { + pub fn listen_tcp(&self) -> &Option { + &self.listen_tcp + } + + pub fn listen_udp(&self) -> &Option { + &self.listen_udp + } +} + +pub type Zone = HashMap>; // domain -> records + +pub type RecordType = rr::RecordType; + +#[derive(Serialize, Deserialize, Debug, Clone, derive_builder::Builder)] +pub struct Record { + #[serde(rename = "type")] + rr_type: RecordType, + + name: String, + value: String, + + #[serde(with = "humantime_serde")] + ttl: Duration, +} + +impl Record { + fn name(&self) -> anyhow::Result { + let name = rr::Name::from_str(self.name.as_str())?; + Ok(name) + } + + fn rr_type(&self) -> rr::RecordType { + self.rr_type.clone().into() + } +} + +impl TryFrom for rr::Record { + type Error = anyhow::Error; + + fn try_from(value: Record) -> Result { + let r: rr::Record = (&value).try_into()?; + Ok(r) + } +} + +impl TryFrom<&Record> for rr::Record { + type Error = anyhow::Error; + + fn try_from(value: &Record) -> Result { + let name = value.name()?; + let mut record = Self::update0(name, value.ttl.as_secs() as u32, value.rr_type()); + record.set_dns_class(rr::DNSClass::IN); + match value.rr_type { + RecordType::A => { + let addr: Ipv4Addr = value.value.parse()?; + record.set_data(RData::A(rr::rdata::a::A(addr))); + } + RecordType::SOA => { + let soa = value.value.split_whitespace().collect::>(); + if soa.len() != 7 { + return Err(anyhow::anyhow!("invalid SOA record")); + } + let mname = rr::Name::from_str(soa[0])?; + let rname = rr::Name::from_str(soa[1])?; + let serial: u32 = soa[2].parse()?; + let refresh: u32 = soa[3].parse()?; + let retry: u32 = soa[4].parse()?; + let expire: u32 = soa[5].parse()?; + let minimum: u32 = soa[6].parse()?; + record.set_data(RData::SOA(rr::rdata::soa::SOA::new( + mname, + rname, + serial, + refresh.try_into().unwrap(), + retry.try_into().unwrap(), + expire.try_into().unwrap(), + minimum, + ))); + } + _ => todo!(), + } + Ok(record) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::anyhow; + + #[tokio::test] + async fn it_works() -> anyhow::Result<()> { + let text = r#" +[general] +listen_tcp = "127.0.0.1:5300" +listen_udp = "127.0.0.1:5353" + +[[zones."et.internal"]] +type = "A" +name = "www" +value = "123.123.123.123" +ttl = "60s" + +[[zones."et.top"]] +type = "A" +name = "@" +value = "100.100.100.100" +ttl = "61s" + +"#; + + let config = toml::from_str::(text)?; + assert_eq!( + config.general.listen_tcp().clone().unwrap(), + "127.0.0.1:5300" + ); + assert_eq!( + config.general.listen_udp().clone().unwrap(), + "127.0.0.1:5353" + ); + assert_eq!(config.zones.len(), 2); + + let (domain, records) = config + .zones + .get_key_value("et.internal") + .map_or(Err(anyhow!("parse error")), |x| Ok(x))?; + assert_eq!(domain, "et.internal"); + assert_eq!(records.len(), 1); + let record = &records[0]; + assert_eq!(record.rr_type, RecordType::A); + assert_eq!(record.name, "www"); + assert_eq!(record.value, "123.123.123.123"); + assert_eq!(record.ttl.as_secs(), 60); + + let (domain, records) = config + .zones + .get_key_value("et.top") + .map_or(Err(anyhow!("parse error")), |x| Ok(x))?; + assert_eq!(domain, "et.top"); + assert_eq!(records.len(), 1); + let record = &records[0]; + assert_eq!(record.rr_type, RecordType::A); + assert_eq!(record.name, "@"); + assert_eq!(record.value, "100.100.100.100"); + assert_eq!(record.ttl.as_secs(), 61); + + Ok(()) + } +} diff --git a/easytier/src/instance/dns_server/mod.rs b/easytier/src/instance/dns_server/mod.rs new file mode 100644 index 0000000..3c9cdab --- /dev/null +++ b/easytier/src/instance/dns_server/mod.rs @@ -0,0 +1,15 @@ +// This module is copy and modified from https://github.com/fanyang89/libdns +pub(crate) mod config; +pub(crate) mod server; + +pub mod client_instance; +pub mod runner; +pub mod server_instance; +pub mod system_config; + +#[cfg(test)] +mod tests; + +pub static MAGIC_DNS_INSTANCE_ADDR: &str = "tcp://127.0.0.1:49813"; +pub static MAGIC_DNS_FAKE_IP: &str = "100.100.100.101"; +pub static DEFAULT_ET_DNS_ZONE: &str = "et.net."; diff --git a/easytier/src/instance/dns_server/runner.rs b/easytier/src/instance/dns_server/runner.rs new file mode 100644 index 0000000..524a253 --- /dev/null +++ b/easytier/src/instance/dns_server/runner.rs @@ -0,0 +1,93 @@ +use cidr::Ipv4Inet; +use tokio_util::sync::CancellationToken; + +use crate::peers::peer_manager::PeerManager; +use std::{net::Ipv4Addr, sync::Arc, time::Duration}; + +use super::{client_instance::MagicDnsClientInstance, server_instance::MagicDnsServerInstance}; + +static DEFAULT_ET_DNS_ZONE: &str = "et.net."; + +pub struct DnsRunner { + client: Option, + server: Option, + peer_mgr: Arc, + tun_dev: Option, + tun_inet: Ipv4Inet, + fake_ip: Ipv4Addr, +} + +impl DnsRunner { + pub fn new( + peer_mgr: Arc, + tun_dev: Option, + tun_inet: Ipv4Inet, + fake_ip: Ipv4Addr, + ) -> Self { + Self { + client: None, + server: None, + peer_mgr, + tun_dev, + tun_inet, + fake_ip, + } + } + + async fn clean_env(&mut self) { + if let Some(server) = self.server.take() { + server.clean_env().await; + } + self.client.take(); + } + + async fn run_once(&mut self) -> anyhow::Result<()> { + // try server first + match MagicDnsServerInstance::new( + self.peer_mgr.clone(), + self.tun_dev.clone(), + self.tun_inet, + self.fake_ip, + ) + .await + { + Ok(server) => { + self.server = Some(server); + tracing::info!("DnsRunner::run_once: server started"); + } + Err(e) => { + tracing::error!("DnsRunner::run_once: {:?}", e); + } + } + + // every runner must run a client + let client = MagicDnsClientInstance::new(self.peer_mgr.clone()).await?; + self.client = Some(client); + self.client.as_mut().unwrap().run_and_wait().await; + + return Err(anyhow::anyhow!("Client instance exit")); + } + + pub async fn run(&mut self, canel_token: CancellationToken) { + loop { + tracing::info!("DnsRunner::run: start"); + tokio::select! { + _ = canel_token.cancelled() => { + self.clean_env().await; + tracing::info!("DnsRunner::run: cancelled"); + return; + } + + ret = self.run_once() => { + self.clean_env().await; + if let Err(e) = ret { + tracing::error!("DnsRunner::run: {:?}", e); + } else { + tracing::info!("DnsRunner::run: unexpected exit, server may be down"); + } + tokio::time::sleep(Duration::from_millis(500)).await; + } + } + } + } +} diff --git a/easytier/src/instance/dns_server/server.rs b/easytier/src/instance/dns_server/server.rs new file mode 100644 index 0000000..3057346 --- /dev/null +++ b/easytier/src/instance/dns_server/server.rs @@ -0,0 +1,338 @@ +use anyhow::{Context, Result}; +use hickory_proto::op::Edns; +use hickory_proto::rr; +use hickory_proto::rr::LowerName; +use hickory_resolver::config::ResolverOpts; +use hickory_resolver::name_server::TokioConnectionProvider; +use hickory_resolver::system_conf::read_system_conf; +use hickory_server::authority::{AuthorityObject, Catalog, ZoneType}; +use hickory_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo}; +use hickory_server::store::forwarder::ForwardConfig; +use hickory_server::store::{forwarder::ForwardAuthority, in_memory::InMemoryAuthority}; +use hickory_server::ServerFuture; +use std::io; +use std::net::SocketAddr; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::{TcpListener, UdpSocket}; +use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use tokio::task::JoinSet; + +use crate::common::stun::get_default_resolver_config; + +use super::config::{GeneralConfig, Record, RunConfig}; + +pub struct Server { + server: ServerFuture, + catalog: Arc>, + general_config: GeneralConfig, + udp_local_addr: Option, + tcp_local_addr: Option, + tasks: JoinSet<()>, +} + +struct CatalogRequestHandler { + catalog: Arc>, +} + +impl CatalogRequestHandler { + fn new(catalog: Arc>) -> CatalogRequestHandler { + // let system_conf = read_system_conf(); + // let recursor = match system_conf { + // Ok((conf, _)) => RecursorBuilder::default().build(conf), + // Err(_) => RecursorBuilder::default().build(get_default_resolver_config()), + // } + // // policy is security unware, this will never return an error + // .unwrap(); + + Self { catalog } + } +} + +#[async_trait::async_trait] +impl RequestHandler for CatalogRequestHandler { + async fn handle_request( + &self, + request: &Request, + response_handle: R, + ) -> ResponseInfo { + self.catalog + .read() + .await + .handle_request(request, response_handle) + .await + } +} + +pub fn build_authority(domain: &str, records: &[Record]) -> Result { + let zone = rr::Name::from_str(domain)?; + let mut authority = InMemoryAuthority::empty(zone.clone(), ZoneType::Primary, false); + for record in records.iter() { + let r = record.try_into()?; + authority.upsert_mut(r, 0); + } + Ok(authority) +} + +impl Server { + pub fn new(config: RunConfig) -> Self { + Self::try_new(config).unwrap() + } + + fn try_new(config: RunConfig) -> Result { + let mut catalog = Catalog::new(); + for (domain, records) in config.zones().iter() { + let zone = rr::Name::from_str(domain.as_str())?; + let authroty = build_authority(domain, records)?; + catalog.upsert(zone.clone().into(), vec![Arc::new(authroty)]); + } + + // use forwarder authority for the root zone + let system_conf = + read_system_conf().unwrap_or((get_default_resolver_config(), ResolverOpts::default())); + let forward_config = ForwardConfig { + name_servers: system_conf + .0 + .name_servers() + .iter() + .cloned() + .filter(|x| { + !config + .excluded_forward_nameservers() + .contains(&x.socket_addr.ip()) + }) + .collect::>() + .into(), + options: Some(system_conf.1), + }; + let auth = ForwardAuthority::builder_with_config( + forward_config, + TokioConnectionProvider::default(), + ) + .build() + .unwrap(); + + catalog.upsert(rr::Name::from_str(".")?.into(), vec![Arc::new(auth)]); + + let catalog = Arc::new(RwLock::new(catalog)); + let handler = CatalogRequestHandler::new(catalog.clone()); + let server = ServerFuture::new(handler); + + Ok(Self { + server, + catalog, + general_config: config.general().clone(), + udp_local_addr: None, + tcp_local_addr: None, + tasks: JoinSet::new(), + }) + } + + pub fn udp_local_addr(&self) -> Option { + self.udp_local_addr + } + + pub fn tcp_local_addr(&self) -> Option { + self.tcp_local_addr + } + + pub async fn register_udp_socket(&mut self, address: String) -> Result { + let bind_addr = SocketAddr::from_str(&address) + .with_context(|| format!("DNS Server failed to parse address {}", address))?; + let socket = socket2::Socket::new( + socket2::Domain::IPV4, + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + ) + .with_context(|| { + format!( + "DNS Server failed to create UDP socket for address {}", + address.to_string() + ) + })?; + socket2::SockRef::from(&socket) + .set_reuse_address(true) + .with_context(|| { + format!( + "DNS Server failed to set reuse address on socket {}", + address.to_string() + ) + })?; + socket.bind(&bind_addr.into()).with_context(|| { + format!("DNS Server failed to bind socket to address {}", bind_addr) + })?; + socket + .set_nonblocking(true) + .with_context(|| format!("DNS Server failed to set socket to non-blocking"))?; + let socket = UdpSocket::from_std(socket.into()).with_context(|| { + format!( + "DNS Server failed to convert socket to UdpSocket for address {}", + address.to_string() + ) + })?; + + let local_addr = socket + .local_addr() + .with_context(|| format!("DNS Server failed to get local address"))?; + self.server.register_socket(socket); + + Ok(local_addr) + } + + pub async fn run(&mut self) -> Result<()> { + if let Some(address) = self.general_config.listen_tcp() { + let tcp_listener = TcpListener::bind(address.clone()) + .await + .with_context(|| format!("DNS Server failed to bind TCP address {}", address))?; + self.tcp_local_addr = Some(tcp_listener.local_addr()?); + self.server + .register_listener(tcp_listener, Duration::from_secs(5)); + } + + if let Some(address) = self.general_config.listen_udp() { + let local_addr = self.register_udp_socket(address.clone()).await?; + self.udp_local_addr = Some(local_addr); + }; + + Ok(()) + } + + pub async fn shutdown(&mut self) -> Result<()> { + self.server.shutdown_gracefully().await?; + Ok(()) + } + + pub async fn upsert(&self, name: LowerName, authority: Arc) { + self.catalog.write().await.upsert(name, vec![authority]); + } + + pub async fn remove(&self, name: &LowerName) -> Option>> { + self.catalog.write().await.remove(name) + } + + pub async fn update( + &self, + update: &Request, + response_edns: Option, + response_handle: R, + ) -> io::Result { + self.catalog + .write() + .await + .update(update, response_edns, response_handle) + .await + } + + pub async fn contains(&self, name: &LowerName) -> bool { + self.catalog.read().await.contains(name) + } + + pub async fn lookup( + &self, + request: &Request, + response_edns: Option, + response_handle: R, + ) -> ResponseInfo { + self.catalog + .read() + .await + .lookup(request, response_edns, response_handle) + .await + } + + pub async fn read_catalog(&self) -> RwLockReadGuard<'_, Catalog> { + self.catalog.read().await + } + + pub async fn write_catalog(&self) -> RwLockWriteGuard<'_, Catalog> { + self.catalog.write().await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::instance::dns_server::config::{ + GeneralConfigBuilder, RecordBuilder, RecordType, RunConfigBuilder, + }; + use anyhow::Result; + use hickory_client::client::{Client, ClientHandle}; + use hickory_proto::rr; + use hickory_proto::runtime::TokioRuntimeProvider; + use hickory_proto::udp::UdpClientStream; + use maplit::hashmap; + use std::time::Duration; + + #[tokio::test] + async fn it_works() -> Result<()> { + let mut server = Server::new( + RunConfigBuilder::default() + .general(GeneralConfigBuilder::default().build()?) + .build()?, + ); + server.run().await?; + server.shutdown().await?; + Ok(()) + } + + #[tokio::test] + async fn can_resolve_records() -> Result<()> { + let configured_record = RecordBuilder::default() + .rr_type(RecordType::A) + .name("www.et.internal.".to_string()) + .value("123.123.123.123".to_string()) + .ttl(Duration::from_secs(60)) + .build()?; + let configured_record2 = RecordBuilder::default() + .rr_type(RecordType::A) + .name("中文.et.internal.".to_string()) + .value("123.123.123.123".to_string()) + .ttl(Duration::from_secs(60)) + .build()?; + let soa_record = RecordBuilder::default() + .rr_type(RecordType::SOA) + .name("et.internal.".to_string()) + .value( + "ns.et.internal. hostmaster.et.internal. 2023101001 7200 3600 1209600 86400" + .to_string(), + ) + .ttl(Duration::from_secs(60)) + .build()?; + let config = RunConfigBuilder::default() + .general( + GeneralConfigBuilder::default() + .listen_udp("127.0.0.1:0") + .build()?, + ) + .zones(hashmap! { + "et.internal.".to_string() => vec![configured_record.clone(), soa_record.clone(), configured_record2.clone()], + }) + .build()?; + + let mut server = Server::new(config); + server.run().await?; + + let local_addr = server.udp_local_addr().unwrap(); + let stream = UdpClientStream::builder(local_addr, TokioRuntimeProvider::default()).build(); + let (mut client, background) = Client::connect(stream).await?; + let background_task = tokio::spawn(background); + let response = client + .query( + rr::Name::from_str("www.et.internal")?, + rr::DNSClass::IN, + rr::RecordType::A, + ) + .await?; + drop(background_task); + + println!("Response: {:?}", response); + + assert_eq!(response.answers().len(), 1); + let expected_record: rr::Record = configured_record.try_into()?; + assert_eq!(response.answers().first().unwrap(), &expected_record); + + server.shutdown().await?; + Ok(()) + } +} diff --git a/easytier/src/instance/dns_server/server_instance.rs b/easytier/src/instance/dns_server/server_instance.rs new file mode 100644 index 0000000..36ab916 --- /dev/null +++ b/easytier/src/instance/dns_server/server_instance.rs @@ -0,0 +1,414 @@ +// single-instance server in one machine, every easytier instance that has ip address and tun device will try create a server instance. + +// magic dns client will connect to this server to update the dns records. +// magic dns server will add the dns server ip address to the tun device, and forward the dns request to the dns server + +// magic dns client will establish a long live tcp connection to the magic dns server, and when the server stops or crashes, +// all the clients will exit and let the easytier instance to launch a new server instance. + +use std::{collections::BTreeMap, net::Ipv4Addr, str::FromStr, sync::Arc, time::Duration}; + +use anyhow::Context; +use cidr::Ipv4Inet; +use dashmap::DashMap; +use hickory_proto::rr::LowerName; +use multimap::MultiMap; +use pnet::packet::{ + icmp::{self, IcmpTypes, MutableIcmpPacket}, + ip::IpNextHeaderProtocols, + ipv4::{self, MutableIpv4Packet}, + tcp::{self, MutableTcpPacket}, + udp::{self, MutableUdpPacket}, + MutablePacket, +}; + +use crate::{ + common::{ + ifcfg::{IfConfiger, IfConfiguerTrait}, + PeerId, + }, + instance::dns_server::{ + config::{Record, RecordBuilder, RecordType}, + server::build_authority, + DEFAULT_ET_DNS_ZONE, + }, + peers::{peer_manager::PeerManager, NicPacketFilter}, + proto::{ + cli::Route, + common::{TunnelInfo, Void}, + magic_dns::{ + dns_record::{self}, + DnsRecord, DnsRecordA, DnsRecordList, GetDnsRecordResponse, HandshakeRequest, + HandshakeResponse, MagicDnsServerRpc, MagicDnsServerRpcServer, UpdateDnsRecordRequest, + }, + rpc_impl::standalone::{RpcServerHook, StandAloneServer}, + rpc_types::controller::{BaseController, Controller}, + }, + tunnel::{packet_def::ZCPacket, tcp::TcpTunnelListener}, +}; + +use super::{ + config::{GeneralConfigBuilder, RunConfigBuilder}, + server::Server, + MAGIC_DNS_INSTANCE_ADDR, +}; + +static NIC_PIPELINE_NAME: &str = "magic_dns_server"; + +pub(super) struct MagicDnsServerInstanceData { + dns_server: Server, + tun_dev: Option, + tun_ip: Ipv4Addr, + fake_ip: Ipv4Addr, + my_peer_id: PeerId, + + // zone -> (tunnel remote addr -> route) + route_infos: DashMap>, +} + +impl MagicDnsServerInstanceData { + pub async fn update_dns_records<'a, T: Iterator>( + &self, + routes: T, + zone: &str, + ) -> Result<(), anyhow::Error> { + let mut records: Vec = vec![]; + for route in routes { + if route.hostname.is_empty() { + continue; + } + + let Some(ipv4_addr) = route.ipv4_addr.unwrap_or_default().address else { + continue; + }; + + let record = RecordBuilder::default() + .rr_type(RecordType::A) + .name(format!("{}.{}", route.hostname, zone)) + .value(ipv4_addr.to_string()) + .ttl(Duration::from_secs(1)) + .build()?; + + records.push(record); + } + + let soa_record = RecordBuilder::default() + .rr_type(RecordType::SOA) + .name(zone.to_string()) + .value(format!( + "ns.{} hostmaster.{} 2023101001 7200 3600 1209600 86400", + zone, zone + )) + .ttl(Duration::from_secs(60)) + .build()?; + records.push(soa_record); + + let authority = build_authority(zone, &records)?; + + self.dns_server + .upsert( + LowerName::from_str(zone) + .with_context(|| "Invalid zone name, expect fomat like \"et.net.\"")?, + Arc::new(authority), + ) + .await; + + tracing::debug!("Updated DNS records for zone {}: {:?}", zone, records); + + Ok(()) + } + + pub async fn update(&self) { + for item in self.route_infos.iter() { + let zone = item.key(); + let route_iter = item.value().flat_iter().map(|x| x.1); + if let Err(e) = self.update_dns_records(route_iter, zone).await { + tracing::error!("Failed to update DNS records for zone {}: {:?}", zone, e); + } + } + } + + fn do_system_config(&self, _zone: &str) -> Result<(), anyhow::Error> { + #[cfg(target_os = "windows")] + { + use super::system_config::windows::WindowsDNSManager; + let cfg = WindowsDNSManager::new(self.tun_dev.as_ref().unwrap())?; + cfg.set_primary_dns(&[self.fake_ip.clone().into()], &[_zone.to_string()])?; + } + + Ok(()) + } +} + +#[async_trait::async_trait] +impl MagicDnsServerRpc for MagicDnsServerInstanceData { + type Controller = BaseController; + async fn handshake( + &self, + _ctrl: Self::Controller, + _input: HandshakeRequest, + ) -> crate::proto::rpc_types::error::Result { + Ok(Default::default()) + } + + async fn update_dns_record( + &self, + ctrl: Self::Controller, + input: UpdateDnsRecordRequest, + ) -> crate::proto::rpc_types::error::Result { + let Some(tunnel_info) = ctrl.get_tunnel_info() else { + return Err(anyhow::anyhow!("No tunnel info").into()); + }; + let Some(remote_addr) = &tunnel_info.remote_addr else { + return Err(anyhow::anyhow!("No remote addr").into()); + }; + let zone = input.zone.clone(); + self.route_infos + .entry(zone.clone()) + .or_default() + .insert_many(remote_addr.clone().into(), input.routes); + + self.update().await; + Ok(Default::default()) + } + + async fn get_dns_record( + &self, + _ctrl: Self::Controller, + _input: Void, + ) -> crate::proto::rpc_types::error::Result { + let mut ret = BTreeMap::new(); + for item in self.route_infos.iter() { + let zone = item.key(); + let routes = item.value(); + let mut dns_records = DnsRecordList::default(); + for route in routes.iter().map(|x| x.1) { + dns_records.records.push(DnsRecord { + record: Some(dns_record::Record::A(DnsRecordA { + name: format!("{}.{}", route.hostname, zone), + value: route.ipv4_addr.unwrap_or_default().address, + ttl: 1, + })), + }); + } + ret.insert(zone.clone(), dns_records); + } + Ok(GetDnsRecordResponse { records: ret }) + } + + async fn heartbeat( + &self, + _ctrl: Self::Controller, + _input: Void, + ) -> crate::proto::rpc_types::error::Result { + Ok(Default::default()) + } +} + +#[async_trait::async_trait] +impl NicPacketFilter for MagicDnsServerInstanceData { + async fn try_process_packet_from_nic(&self, zc_packet: &mut ZCPacket) -> bool { + let data = zc_packet.mut_payload(); + let mut ip_packet = MutableIpv4Packet::new(data).unwrap(); + if ip_packet.get_version() != 4 || ip_packet.get_destination() != self.fake_ip { + return false; + } + + match ip_packet.get_next_level_protocol() { + IpNextHeaderProtocols::Udp => { + let Some(dns_udp_addr) = self.dns_server.udp_local_addr() else { + return false; + }; + + let Some(mut udp_packet) = MutableUdpPacket::new(ip_packet.payload_mut()) else { + return false; + }; + if udp_packet.get_destination() == 53 { + // for dns request + udp_packet.set_destination(dns_udp_addr.port()); + } else if udp_packet.get_source() == dns_udp_addr.port() { + // for dns response + udp_packet.set_source(53); + } else { + return false; + } + udp_packet.set_checksum(udp::ipv4_checksum( + &udp_packet.to_immutable(), + &self.fake_ip, + &self.tun_ip, + )); + } + + IpNextHeaderProtocols::Tcp => { + let Some(dns_tcp_addr) = self.dns_server.tcp_local_addr() else { + return false; + }; + + let Some(mut tcp_packet) = MutableTcpPacket::new(ip_packet.payload_mut()) else { + return false; + }; + if tcp_packet.get_destination() == 53 { + // for dns request + tcp_packet.set_destination(dns_tcp_addr.port()); + } else if tcp_packet.get_source() == dns_tcp_addr.port() { + // for dns response + tcp_packet.set_source(53); + } else { + return false; + } + tcp_packet.set_checksum(tcp::ipv4_checksum( + &tcp_packet.to_immutable(), + &self.fake_ip, + &self.tun_ip, + )); + } + + IpNextHeaderProtocols::Icmp => { + let Some(mut icmp_packet) = MutableIcmpPacket::new(ip_packet.payload_mut()) else { + return false; + }; + if icmp_packet.get_icmp_type() != IcmpTypes::EchoRequest { + return false; + } + icmp_packet.set_icmp_type(IcmpTypes::EchoReply); + icmp_packet.set_checksum(icmp::checksum(&icmp_packet.to_immutable())); + } + + _ => { + return false; + } + } + + ip_packet.set_source(self.fake_ip); + ip_packet.set_destination(self.tun_ip); + + ip_packet.set_checksum(ipv4::checksum(&ip_packet.to_immutable())); + zc_packet.mut_peer_manager_header().unwrap().to_peer_id = self.my_peer_id.into(); + + true + } + + fn id(&self) -> String { + NIC_PIPELINE_NAME.to_string() + } +} + +#[async_trait::async_trait] +impl RpcServerHook for MagicDnsServerInstanceData { + async fn on_new_client(&self, tunnel_info: Option) { + println!("New client connected: {:?}", tunnel_info); + } + + async fn on_client_disconnected(&self, tunnel_info: Option) { + println!("Client disconnected: {:?}", tunnel_info); + let Some(tunnel_info) = tunnel_info else { + return; + }; + let Some(remote_addr) = tunnel_info.remote_addr else { + return; + }; + let remote_addr = remote_addr.into(); + for mut item in self.route_infos.iter_mut() { + item.value_mut().remove(&remote_addr); + } + self.route_infos.retain(|_, v| !v.is_empty()); + self.update().await; + } +} + +pub struct MagicDnsServerInstance { + rpc_server: StandAloneServer, + pub(super) data: Arc, + peer_mgr: Arc, + tun_inet: Ipv4Inet, +} + +impl MagicDnsServerInstance { + pub async fn new( + peer_mgr: Arc, + tun_dev: Option, + tun_inet: Ipv4Inet, + fake_ip: Ipv4Addr, + ) -> Result { + let tcp_listener = TcpTunnelListener::new(MAGIC_DNS_INSTANCE_ADDR.parse().unwrap()); + let mut rpc_server = StandAloneServer::new(tcp_listener); + rpc_server.serve().await?; + + let bind_addr = tun_inet.address(); + + let dns_config = RunConfigBuilder::default() + .general( + GeneralConfigBuilder::default() + .listen_udp(format!("{}:0", bind_addr)) + .listen_tcp(format!("{}:0", bind_addr)) + .build() + .unwrap(), + ) + .excluded_forward_nameservers(vec![fake_ip.into()]) + .build() + .unwrap(); + let mut dns_server = Server::new(dns_config); + dns_server.run().await?; + + if !tun_inet.contains(&fake_ip) && tun_dev.is_some() { + let cost = if cfg!(target_os = "windows") { + Some(4) + } else { + None + }; + let ifcfg = IfConfiger {}; + ifcfg + .add_ipv4_route(tun_dev.as_ref().unwrap(), fake_ip, 32, cost) + .await?; + } + + let data = Arc::new(MagicDnsServerInstanceData { + dns_server, + tun_dev, + tun_ip: tun_inet.address(), + fake_ip, + my_peer_id: peer_mgr.my_peer_id(), + route_infos: DashMap::new(), + }); + rpc_server + .registry() + .register(MagicDnsServerRpcServer::new(data.clone()), ""); + rpc_server.set_hook(data.clone()); + + peer_mgr + .add_nic_packet_process_pipeline(Box::new(data.clone())) + .await; + + let data_clone = data.clone(); + tokio::task::spawn_blocking(move || data_clone.do_system_config(DEFAULT_ET_DNS_ZONE)) + .await + .context("Failed to configure system")??; + + Ok(Self { + rpc_server, + data, + peer_mgr, + tun_inet, + }) + } + + pub async fn clean_env(&self) { + if !self.tun_inet.contains(&self.data.fake_ip) && self.data.tun_dev.is_some() { + let ifcfg = IfConfiger {}; + let _ = ifcfg + .remove_ipv4_route(&self.data.tun_dev.as_ref().unwrap(), self.data.fake_ip, 32) + .await; + } + + let _ = self + .peer_mgr + .remove_nic_packet_process_pipeline(NIC_PIPELINE_NAME.to_string()) + .await; + } +} + +impl Drop for MagicDnsServerInstance { + fn drop(&mut self) { + println!("MagicDnsServerInstance dropped"); + } +} diff --git a/easytier/src/instance/dns_server/system_config/linux.rs b/easytier/src/instance/dns_server/system_config/linux.rs new file mode 100644 index 0000000..f2bfc61 --- /dev/null +++ b/easytier/src/instance/dns_server/system_config/linux.rs @@ -0,0 +1,357 @@ +// translated from tailscale #32ce1bdb48078ec4cedaeeb5b1b2ff9c0ef61a49 + +use crate::defer; +use anyhow::{Context, Result}; +use dbus::blocking::stdintf::org_freedesktop_dbus::Properties as _; +use std::fs; +use std::net::Ipv4Addr; +use std::path::Path; +use std::process::Command; +use std::time::Duration; +use version_compare::Cmp; + +// 声明依赖项(需要添加到Cargo.toml) +// use dbus::blocking::Connection; +// use nix::unistd::AccessFlags; +// use resolv_conf::Resolver; + +// 常量定义 +const RESOLV_CONF: &str = "/etc/resolv.conf"; +const PING_TIMEOUT: Duration = Duration::from_secs(1); + +// 错误类型定义 +#[derive(Debug)] +struct DNSConfigError { + message: String, + source: Option, +} + +// 配置环境结构体 +struct OSConfigEnv { + fs: Box, + dbus_ping: Box Result<()>>, + dbus_read_string: Box Result>, + nm_is_using_resolved: Box Result<()>>, + nm_version_between: Box Result>, + resolvconf_style: Box String>, +} + +// DNS管理器trait +trait OSConfigurator: Send + Sync { + // 实现相关方法 +} + +// 文件系统操作trait +trait FileSystem { + fn read_file(&self, path: &str) -> Result>; + fn exists(&self, path: &str) -> bool; +} + +// 直接文件系统实现 +struct DirectFS; + +impl FileSystem for DirectFS { + fn read_file(&self, path: &str) -> Result> { + fs::read(path).context("Failed to read file") + } + + fn exists(&self, path: &str) -> bool { + Path::new(path).exists() + } +} + +/// 检查 NetworkManager 是否使用 systemd-resolved 作为 DNS 管理器 +pub fn nm_is_using_resolved() -> Result<()> { + // 连接系统 D-Bus + let conn = dbus::blocking::Connection::new_system().context("Failed to connect to D-Bus")?; + + // 创建 NetworkManager DnsManager 对象代理 + let proxy = conn.with_proxy( + "org.freedesktop.NetworkManager", + "/org/freedesktop/NetworkManager/DnsManager", + std::time::Duration::from_secs(1), + ); + + // 获取 Mode 属性 + let (value,): (dbus::arg::Variant>,) = proxy + .method_call( + "org.freedesktop.DBus.Properties", + "Get", + ("org.freedesktop.NetworkManager.DnsManager", "Mode"), + ) + .context("Failed to get NM mode property")?; + + // 检查 Mode 是否为 "systemd-resolved" + if value.0.as_str() != Some("systemd-resolved") { + return Err(anyhow::anyhow!( + "NetworkManager is not using systemd-resolved, found: {:?}", + value + ) + .into()); + } + + Ok(()) +} + +/// 返回系统中使用的 resolvconf 实现类型("debian" 或 "openresolv") +pub fn resolvconf_style() -> String { + // 检查 resolvconf 命令是否存在 + if which::which("resolvconf").is_err() { + return String::new(); + } + + // 执行 resolvconf --version 命令 + let output = match Command::new("resolvconf").arg("--version").output() { + Ok(output) => output, + Err(e) => { + // 处理命令执行错误 + if let Some(code) = e.raw_os_error() { + // Debian 版本的 resolvconf 不支持 --version,返回特定错误码 99 + if code == 99 { + return "debian".to_string(); + } + } + return String::new(); // 其他错误返回空字符串 + } + }; + + // 检查输出是否以 "Debian resolvconf" 开头 + if output.stdout.starts_with(b"Debian resolvconf") { + return "debian".to_string(); + } + + // 默认视为 openresolv + "openresolv".to_string() +} + +// 构建配置环境 +fn new_os_config_env() -> OSConfigEnv { + OSConfigEnv { + fs: Box::new(DirectFS), + dbus_ping: Box::new(dbus_ping), + dbus_read_string: Box::new(dbus_read_string), + nm_is_using_resolved: Box::new(nm_is_using_resolved), + nm_version_between: Box::new(nm_version_between), + resolvconf_style: Box::new(resolvconf_style), + } +} + +// 创建DNS配置器 +fn new_os_configurator(_interface_name: String) -> Result<()> { + let env = new_os_config_env(); + + let mode = dns_mode(&env).context("Failed to detect DNS mode")?; + + tracing::info!("dns: using {} mode", mode); + + // match mode.as_str() { + // "direct" => Ok(Box::new(DirectManager::new(env.fs)?)), + // // "systemd-resolved" => Ok(Box::new(ResolvedManager::new( + // // &logf, + // // health, + // // interface_name, + // // )?)), + // // "network-manager" => Ok(Box::new(NMManager::new(interface_name)?)), + // // "debian-resolvconf" => Ok(Box::new(DebianResolvconfManager::new(&logf)?)), + // // "openresolv" => Ok(Box::new(OpenresolvManager::new(&logf)?)), + // _ => { + // tracing::warn!("Unexpected DNS mode {}, using direct manager", mode); + // Ok(Box::new(DirectManager::new(env.fs)?)) + // } + // } + Ok(()) +} + +use std::io::{self, BufRead, Cursor}; + +/// 返回 `resolv.conf` 内容的拥有者("systemd-resolved"、"NetworkManager"、"resolvconf" 或空字符串) +pub fn resolv_owner(bs: &[u8]) -> String { + let mut likely = String::new(); + let cursor = Cursor::new(bs); + let reader = io::BufReader::new(cursor); + + for line_result in reader.lines() { + match line_result { + Ok(line) => { + let line = line.trim(); + if line.is_empty() { + continue; + } + + if !line.starts_with('#') { + // 第一个非注释且非空的行,直接返回当前结果 + return likely; + } + + // 检查注释行中的关键字 + if line.contains("systemd-resolved") { + likely = "systemd-resolved".to_string(); + } else if line.contains("NetworkManager") { + likely = "NetworkManager".to_string(); + } else if line.contains("resolvconf") { + likely = "resolvconf".to_string(); + } + } + Err(_) => { + // 读取错误(如无效 UTF-8),直接返回当前结果 + return likely; + } + } + } + + likely +} + +// 检测DNS模式 +fn dns_mode(env: &OSConfigEnv) -> Result { + let debug = std::cell::RefCell::new(Vec::new()); + let dbg = |k: &str, v: &str| debug.borrow_mut().push((k.to_string(), v.to_string())); + + // defer 日志记录 + defer! { + if !debug.borrow().is_empty() { + let log_entries: Vec = + debug.borrow().iter().map(|(k, v)| format!("{}={}", k, v)).collect(); + tracing::info!("dns: [{}]", log_entries.join(" ")); + } + }; + + // 检查systemd-resolved状态 + let resolved_up = + (env.dbus_ping)("org.freedesktop.resolve1", "/org/freedesktop/resolve1").is_ok(); + if resolved_up { + dbg("resolved-ping", "yes"); + } + + // 读取resolv.conf + let content = match env.fs.read_file(RESOLV_CONF) { + Ok(content) => content, + Err(e) if e.to_string().contains("NotFound") => { + dbg("rc", "missing"); + return Ok("direct".to_string()); + } + Err(e) => return Err(e).context("reading /etc/resolv.conf"), + }; + + // 检查resolv.conf所有者 + match resolv_owner(&content).as_str() { + "systemd-resolved" => { + dbg("rc", "resolved"); + // 检查是否实际使用resolved + if let Err(e) = resolved_is_actually_resolver(env, &dbg, &content) { + tracing::warn!("resolvedIsActuallyResolver error: {}", e); + dbg("resolved", "not-in-use"); + return Ok("direct".to_string()); + } + + // NetworkManager检查逻辑... + + Ok("systemd-resolved".to_string()) + } + "resolvconf" => { + // resolvconf处理逻辑... + Ok("debian-resolvconf".to_string()) + } + "NetworkManager" => { + // NetworkManager处理逻辑... + Ok("systemd-resolved".to_string()) + } + _ => Ok("direct".to_string()), + } +} + +// D-Bus ping实现 +fn dbus_ping(name: &str, object_path: &str) -> Result<()> { + let conn = dbus::blocking::Connection::new_system()?; + let proxy = conn.with_proxy(name, object_path, PING_TIMEOUT); + let _: () = proxy.method_call("org.freedesktop.DBus.Peer", "Ping", ())?; + Ok(()) +} + +// D-Bus读取字符串实现 +fn dbus_read_string(name: &str, object_path: &str, iface: &str, member: &str) -> Result { + let conn = dbus::blocking::Connection::new_system()?; + let proxy = conn.with_proxy(name, object_path, PING_TIMEOUT); + let (value,): (String,) = + proxy.method_call("org.freedesktop.DBus.Properties", "Get", (iface, member))?; + Ok(value) +} + +// NetworkManager版本检查 +fn nm_version_between(first: &str, last: &str) -> Result { + let conn = dbus::blocking::Connection::new_system()?; + let proxy = conn.with_proxy( + "org.freedesktop.NetworkManager", + "/org/freedesktop/NetworkManager", + PING_TIMEOUT, + ); + + let version: String = proxy.get("org.freedesktop.NetworkManager", "Version")?; + let cmp_first = version_compare::compare(&version, first).unwrap_or(Cmp::Lt); + let cmp_last = version_compare::compare(&version, last).unwrap_or(Cmp::Gt); + Ok(cmp_first == Cmp::Ge && cmp_last == Cmp::Le) +} + +// 检查是否实际使用systemd-resolved +fn resolved_is_actually_resolver( + env: &OSConfigEnv, + dbg: &dyn Fn(&str, &str), + content: &[u8], +) -> Result<()> { + if is_libnss_resolve_used(env).is_ok() { + dbg("resolved", "nss"); + return Ok(()); + } + + // 解析resolv.conf内容 + let resolver = resolv_conf::Config::parse(content)?; + + // 检查nameserver配置 + if resolver.nameservers.is_empty() { + return Err(anyhow::anyhow!("resolv.conf has no nameservers")); + } + + for ns in resolver.nameservers { + if ns != Ipv4Addr::new(127, 0, 0, 53).into() { + return Err(anyhow::anyhow!( + "resolv.conf doesn't point to systemd-resolved" + )); + } + } + + dbg("resolved", "file"); + Ok(()) +} + +// 检查是否使用libnss_resolve +fn is_libnss_resolve_used(env: &OSConfigEnv) -> Result<()> { + let content = env.fs.read_file("/etc/nsswitch.conf")?; + + for line in String::from_utf8_lossy(&content).lines() { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.first() == Some(&"hosts:") { + for module in parts.iter().skip(1) { + if *module == "dns" { + return Err(anyhow::anyhow!("dns module has higher priority")); + } + if *module == "resolve" { + return Ok(()); + } + } + } + } + + Err(anyhow::anyhow!("libnss_resolve not used")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dns_mode_test() { + let env = new_os_config_env(); + let mode = dns_mode(&env).unwrap(); + println!("Detected DNS mode: {}", mode); + } +} diff --git a/easytier/src/instance/dns_server/system_config/mod.rs b/easytier/src/instance/dns_server/system_config/mod.rs new file mode 100644 index 0000000..51b8bc8 --- /dev/null +++ b/easytier/src/instance/dns_server/system_config/mod.rs @@ -0,0 +1,5 @@ +#[cfg(target_os = "linux")] +pub mod linux; + +#[cfg(target_os = "windows")] +pub mod windows; diff --git a/easytier/src/instance/dns_server/system_config/windows.rs b/easytier/src/instance/dns_server/system_config/windows.rs new file mode 100644 index 0000000..b8fc251 --- /dev/null +++ b/easytier/src/instance/dns_server/system_config/windows.rs @@ -0,0 +1,233 @@ +use std::net::IpAddr; +use std::process::Command; + +use std::io; +use winreg::RegKey; + +use crate::common::ifcfg::RegistryManager; + +pub fn is_windows_10_or_better() -> io::Result { + let hklm = winreg::enums::HKEY_LOCAL_MACHINE; + let key_path = "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion"; + let key = winreg::RegKey::predef(hklm).open_subkey(key_path)?; + + // check CurrentMajorVersionNumber, which only exists on Windows 10 and later + let value_name = "CurrentMajorVersionNumber"; + key.get_raw_value(value_name).map(|_| true) +} + +// 假设 interface_guid 是你的网络接口 GUID +pub struct InterfaceControl { + interface_guid: String, +} + +impl InterfaceControl { + // 构造函数 + pub fn new(interface_guid: &str) -> Self { + InterfaceControl { + interface_guid: interface_guid.to_string(), + } + } + + // 删除注册表值(模拟 delValue) + fn delete_value(key: &RegKey, value_name: &str) -> io::Result<()> { + match key.delete_value(value_name) { + Ok(_) => Ok(()), + Err(e) => { + if matches!(e.kind(), io::ErrorKind::NotFound) { + Ok(()) // 忽略不存在的值 + } else { + Err(e.into()) + } + } + } + } + + pub fn set_primary_dns(&self, resolvers: &[IpAddr], domains: &[String]) -> io::Result<()> { + let (ipsv4, ipsv6): (Vec, Vec) = resolvers + .iter() + .map(|ip| ip.to_string()) + .partition(|ip| ip.contains('.')); + + let dom_strs: Vec = domains + .iter() + .map(|d| d.trim_end_matches('.').to_string()) + .collect(); + + // IPv4 处理 + if let Ok(key4) = RegistryManager::open_interface_key( + &self.interface_guid, + RegistryManager::IPV4_TCPIP_INTERFACE_PREFIX, + ) { + if ipsv4.is_empty() { + Self::delete_value(&key4, "NameServer")?; + } else { + key4.set_value("NameServer", &ipsv4.join(","))?; + } + + if dom_strs.is_empty() { + Self::delete_value(&key4, "SearchList")?; + } else { + key4.set_value("SearchList", &dom_strs.join(","))?; + } + + // 禁用 LLMNR(通过 DisableMulticast) + key4.set_value("EnableMulticast", &0u32)?; + } + + // IPv6 处理 + if let Ok(key6) = RegistryManager::open_interface_key( + &self.interface_guid, + RegistryManager::IPV6_TCPIP_INTERFACE_PREFIX, + ) { + if ipsv6.is_empty() { + Self::delete_value(&key6, "NameServer")?; + } else { + key6.set_value("NameServer", &ipsv6.join(","))?; + } + + if dom_strs.is_empty() { + Self::delete_value(&key6, "SearchList")?; + } else { + key6.set_value("SearchList", &dom_strs.join(","))?; + } + key6.set_value("EnableMulticast", &0u32)?; + } + + Ok(()) + } + + fn flush_dns(&self) -> io::Result<()> { + // 刷新 DNS 缓存 + let output = Command::new("ipconfig") + .arg("/flushdns") + .output() + .expect("failed to execute process"); + if !output.status.success() { + return Err(io::Error::new( + io::ErrorKind::Other, + "Failed to flush DNS cache", + )); + } + Ok(()) + } + + // re-register DNS + pub fn re_register_dns(&self) -> io::Result<()> { + // ipconfig /registerdns + let output = Command::new("ipconfig") + .arg("/registerdns") + .output() + .expect("failed to execute process"); + if !output.status.success() { + return Err(io::Error::new( + io::ErrorKind::Other, + "Failed to register DNS", + )); + } + Ok(()) + } +} + +pub struct WindowsDNSManager { + tun_dev_name: String, + interface_control: InterfaceControl, +} + +impl WindowsDNSManager { + pub fn new(tun_dev_name: &str) -> io::Result { + let interface_guid = RegistryManager::find_interface_guid(tun_dev_name)?; + Ok(WindowsDNSManager { + tun_dev_name: tun_dev_name.to_string(), + interface_control: InterfaceControl::new(&interface_guid), + }) + } + + pub fn set_primary_dns(&self, resolvers: &[IpAddr], domains: &[String]) -> io::Result<()> { + self.interface_control.set_primary_dns(resolvers, domains)?; + self.interface_control.flush_dns()?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use cidr::Ipv4Inet; + + #[cfg(target_os = "windows")] + #[tokio::test] + async fn test_windows_set_primary_server() { + use std::{net::Ipv4Addr, str::FromStr as _, time::Duration}; + + use tokio_util::sync::CancellationToken; + + use crate::instance::dns_server::{ + runner::DnsRunner, + tests::{check_dns_record, prepare_env}, + }; + + let tun_ip = Ipv4Inet::from_str("10.144.144.10/24").unwrap(); + let (peer_mgr, virtual_nic) = prepare_env("test1", tun_ip).await; + let tun_name = virtual_nic.ifname().await.unwrap(); + + println!("dev_name: {}", tun_name); + let fake_ip = Ipv4Addr::from_str("100.100.100.101").unwrap(); + let mut dns_runner = DnsRunner::new(peer_mgr, Some(tun_name.clone()), tun_ip, fake_ip); + + let cancel_token = CancellationToken::new(); + let cancel_token_clone = cancel_token.clone(); + let t = tokio::spawn(async move { + dns_runner.run(cancel_token_clone).await; + }); + + // windows is slow to add a ip address, wait for a longer time for dns server ready ,with ping + let now = std::time::Instant::now(); + while now.elapsed() < Duration::from_secs(15) { + tokio::time::sleep(Duration::from_secs(1)).await; + if let Ok(o) = tokio::process::Command::new("ping") + .arg("-n") + .arg("1") + .arg("-w") + .arg("100") + .arg(&fake_ip.to_string()) + .output() + .await + { + if o.status.success() { + break; + } + } + } + + check_dns_record(&fake_ip, "test1.et.net", "10.144.144.10").await; + + let dns_mgr = super::WindowsDNSManager::new(&tun_name).unwrap(); + println!("dev_name: {}", tun_name); + println!("guid: {}", dns_mgr.interface_control.interface_guid); + + dns_mgr + .interface_control + .set_primary_dns( + &["100.100.100.101".parse().unwrap()], + &[".et.net.".to_string()], + ) + .unwrap(); + dns_mgr.interface_control.flush_dns().unwrap(); + + tracing::info!("check dns record with nslookup"); + + // nslookup should return 10.144.144.10 + let ret = tokio::process::Command::new("nslookup") + .arg("test1.et.net") + .output() + .await + .expect("failed to execute process"); + assert!(ret.status.success()); + let output = String::from_utf8_lossy(&ret.stdout); + println!("nslookup output: {}", output); + assert!(output.contains("10.144.144.10")); + + cancel_token.cancel(); + let _ = t.await; + } +} diff --git a/easytier/src/instance/dns_server/tests.rs b/easytier/src/instance/dns_server/tests.rs new file mode 100644 index 0000000..6035cb4 --- /dev/null +++ b/easytier/src/instance/dns_server/tests.rs @@ -0,0 +1,134 @@ +use std::net::{Ipv4Addr, SocketAddr}; +use std::str::FromStr as _; +use std::sync::Arc; +use std::time::Duration; + +use cidr::Ipv4Inet; +use hickory_client::client::{Client, ClientHandle as _}; +use hickory_proto::rr; +use hickory_proto::runtime::TokioRuntimeProvider; +use hickory_proto::udp::UdpClientStream; +use tokio_util::sync::CancellationToken; + +use crate::common::global_ctx::tests::get_mock_global_ctx; +use crate::connector::udp_hole_punch::tests::replace_stun_info_collector; + +use crate::instance::dns_server::runner::DnsRunner; +use crate::instance::dns_server::server_instance::MagicDnsServerInstance; +use crate::instance::dns_server::DEFAULT_ET_DNS_ZONE; +use crate::instance::virtual_nic::NicCtx; +use crate::peers::peer_manager::{PeerManager, RouteAlgoType}; + +use crate::peers::create_packet_recv_chan; +use crate::proto::cli::Route; +use crate::proto::common::NatType; + +pub async fn prepare_env(dns_name: &str, tun_ip: Ipv4Inet) -> (Arc, NicCtx) { + let ctx = get_mock_global_ctx(); + ctx.set_hostname(dns_name.to_owned()); + ctx.set_ipv4(Some(tun_ip)); + let (s, r) = create_packet_recv_chan(); + let peer_mgr = Arc::new(PeerManager::new(RouteAlgoType::Ospf, ctx, s)); + peer_mgr.run().await.unwrap(); + replace_stun_info_collector(peer_mgr.clone(), NatType::PortRestricted); + + let r = Arc::new(tokio::sync::Mutex::new(r)); + let mut virtual_nic = NicCtx::new(peer_mgr.get_global_ctx(), &peer_mgr, r); + virtual_nic.run(tun_ip).await.unwrap(); + + (peer_mgr, virtual_nic) +} + +pub async fn check_dns_record(fake_ip: &Ipv4Addr, domain: &str, expected_ip: &str) { + let stream = UdpClientStream::builder( + SocketAddr::new(fake_ip.clone().into(), 53), + TokioRuntimeProvider::default(), + ) + .build(); + let (mut client, background) = Client::connect(stream).await.unwrap(); + let background_task = tokio::spawn(background); + let response = client + .query( + rr::Name::from_str(domain).unwrap(), + rr::DNSClass::IN, + rr::RecordType::A, + ) + .await + .unwrap(); + drop(background_task); + + println!("Response: {:?}", response); + + assert_eq!(response.answers().len(), 1, "{:?}", response.answers()); + let resp = response.answers().first().unwrap(); + assert_eq!( + resp.clone().into_parts().rdata.into_a().unwrap().0, + expected_ip.parse::().unwrap() + ); +} + +#[tokio::test] +async fn test_magic_dns_server_instance() { + let tun_ip = Ipv4Inet::from_str("10.144.144.10/24").unwrap(); + let (peer_mgr, virtual_nic) = prepare_env("test1", tun_ip).await; + let tun_name = virtual_nic.ifname().await.unwrap(); + let fake_ip = Ipv4Addr::from_str("100.100.100.101").unwrap(); + let dns_server_inst = + MagicDnsServerInstance::new(peer_mgr.clone(), Some(tun_name), tun_ip, fake_ip) + .await + .unwrap(); + + let routes = vec![Route { + hostname: "test1".to_string(), + ipv4_addr: Some(Ipv4Inet::from_str("8.8.8.8/24").unwrap().into()), + ..Default::default() + }]; + dns_server_inst + .data + .update_dns_records(routes.iter(), DEFAULT_ET_DNS_ZONE) + .await + .unwrap(); + + check_dns_record(&fake_ip, "test1.et.net", "8.8.8.8").await; +} + +#[tokio::test] +async fn test_magic_dns_runner() { + let tun_ip = Ipv4Inet::from_str("10.144.144.10/24").unwrap(); + let (peer_mgr, virtual_nic) = prepare_env("test1", tun_ip).await; + let tun_name = virtual_nic.ifname().await.unwrap(); + let fake_ip = Ipv4Addr::from_str("100.100.100.101").unwrap(); + let mut dns_runner = DnsRunner::new(peer_mgr, Some(tun_name), tun_ip, fake_ip); + + let cancel_token = CancellationToken::new(); + let cancel_token_clone = cancel_token.clone(); + let t = tokio::spawn(async move { + dns_runner.run(cancel_token_clone).await; + }); + tokio::time::sleep(Duration::from_secs(3)).await; + check_dns_record(&fake_ip, "test1.et.net", "10.144.144.10").await; + + // add a new dns runner + let tun_ip2 = Ipv4Inet::from_str("10.144.144.20/24").unwrap(); + let (peer_mgr, virtual_nic) = prepare_env("test2", tun_ip2).await; + let tun_name2 = virtual_nic.ifname().await.unwrap(); + let mut dns_runner2 = DnsRunner::new(peer_mgr, Some(tun_name2), tun_ip2, fake_ip); + let cancel_token2 = CancellationToken::new(); + let cancel_token2_clone = cancel_token2.clone(); + let t2 = tokio::spawn(async move { + dns_runner2.run(cancel_token2_clone).await; + }); + tokio::time::sleep(Duration::from_secs(3)).await; + check_dns_record(&fake_ip, "test1.et.net", "10.144.144.10").await; + check_dns_record(&fake_ip, "test2.et.net", "10.144.144.20").await; + + // stop runner 1, runner 2 will take over the dns server + cancel_token.cancel(); + t.await.unwrap(); + + tokio::time::sleep(Duration::from_secs(3)).await; + check_dns_record(&fake_ip, "test2.et.net", "10.144.144.20").await; + + cancel_token2.cancel(); + t2.await.unwrap(); +} diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index ebdb2f1..3e6bc50 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -7,7 +7,9 @@ use std::sync::{Arc, Weak}; use anyhow::Context; use cidr::Ipv4Inet; +use tokio::task::JoinHandle; use tokio::{sync::Mutex, task::JoinSet}; +use tokio_util::sync::CancellationToken; use crate::common::config::ConfigLoader; use crate::common::error::Error; @@ -34,6 +36,8 @@ use crate::proto::rpc_types::controller::BaseController; use crate::tunnel::tcp::TcpTunnelListener; use crate::vpn_portal::{self, VpnPortal}; +use super::dns_server::runner::DnsRunner; +use super::dns_server::MAGIC_DNS_FAKE_IP; use super::listeners::ListenerManager; #[cfg(feature = "socks5")] @@ -101,7 +105,49 @@ impl NicCtx { } } -type ArcNicCtx = Arc>>>; +struct MagicDnsContainer { + dns_runner_task: JoinHandle<()>, + dns_runner_cancel_token: CancellationToken, +} + +// nic container will be cleared when dhcp ip changed +pub(crate) struct NicCtxContainer { + nic_ctx: Option>, + magic_dns: Option, +} + +impl NicCtxContainer { + fn new(nic_ctx: NicCtx, dns_runner: Option) -> Self { + if let Some(mut dns_runner) = dns_runner { + let token = CancellationToken::new(); + let token_clone = token.clone(); + let task = tokio::spawn(async move { + let _ = dns_runner.run(token_clone).await; + }); + Self { + nic_ctx: Some(Box::new(nic_ctx)), + magic_dns: Some(MagicDnsContainer { + dns_runner_task: task, + dns_runner_cancel_token: token, + }), + } + } else { + Self { + nic_ctx: Some(Box::new(nic_ctx)), + magic_dns: None, + } + } + } + + fn new_with_any(ctx: T) -> Self { + Self { + nic_ctx: Some(Box::new(ctx)), + magic_dns: None, + } + } +} + +type ArcNicCtx = Arc>>; pub struct Instance { inst_name: String, @@ -233,7 +279,14 @@ impl Instance { arc_nic_ctx: ArcNicCtx, packet_recv: Arc>, ) { - let _ = arc_nic_ctx.lock().await.take(); + if let Some(old_ctx) = arc_nic_ctx.lock().await.take() { + if let Some(dns_runner) = old_ctx.magic_dns { + dns_runner.dns_runner_cancel_token.cancel(); + tracing::debug!("cancelling dns runner task"); + let ret = dns_runner.dns_runner_task.await; + tracing::debug!("dns runner task cancelled, ret: {:?}", ret); + } + }; let mut tasks = JoinSet::new(); tasks.spawn(async move { @@ -242,14 +295,40 @@ impl Instance { tracing::trace!("packet consumed by mock nic ctx: {:?}", packet); } }); - arc_nic_ctx.lock().await.replace(Box::new(tasks)); + arc_nic_ctx + .lock() + .await + .replace(NicCtxContainer::new_with_any(tasks)); tracing::debug!("nic ctx cleared."); } - async fn use_new_nic_ctx(arc_nic_ctx: ArcNicCtx, nic_ctx: NicCtx) { + fn create_magic_dns_runner( + peer_mgr: Arc, + tun_dev: Option, + tun_ip: Ipv4Inet, + ) -> Option { + let ctx = peer_mgr.get_global_ctx(); + if !ctx.config.get_flags().accept_dns { + return None; + } + + let runner = DnsRunner::new( + peer_mgr, + tun_dev, + tun_ip, + MAGIC_DNS_FAKE_IP.parse().unwrap(), + ); + Some(runner) + } + + async fn use_new_nic_ctx( + arc_nic_ctx: ArcNicCtx, + nic_ctx: NicCtx, + magic_dns: Option, + ) { let mut g = arc_nic_ctx.lock().await; - *g = Some(Box::new(nic_ctx)); + *g = Some(NicCtxContainer::new(nic_ctx, magic_dns)); tracing::debug!("nic ctx updated."); } @@ -339,7 +418,17 @@ impl Instance { global_ctx_c.set_ipv4(None); continue; } - Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx).await; + let ifname = new_nic_ctx.ifname().await; + Self::use_new_nic_ctx( + nic_ctx.clone(), + new_nic_ctx, + Self::create_magic_dns_runner( + peer_manager_c.clone(), + ifname, + ip.clone(), + ), + ) + .await; } current_dhcp_ip = Some(ip); @@ -374,7 +463,17 @@ impl Instance { self.peer_packet_receiver.clone(), ); new_nic_ctx.run(ipv4_addr).await?; - Self::use_new_nic_ctx(self.nic_ctx.clone(), new_nic_ctx).await; + let ifname = new_nic_ctx.ifname().await; + Self::use_new_nic_ctx( + self.nic_ctx.clone(), + new_nic_ctx, + Self::create_magic_dns_runner( + self.peer_manager.clone(), + ifname, + ipv4_addr.clone(), + ), + ) + .await; } } @@ -611,7 +710,13 @@ impl Instance { .run_for_android(fd) .await .with_context(|| "add ip failed")?; - Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx).await; + + let magic_dns_runner = if let Some(ipv4) = global_ctx.get_ipv4() { + Self::create_magic_dns_runner(peer_manager.clone(), None, ipv4) + } else { + None + }; + Self::use_new_nic_ctx(nic_ctx.clone(), new_nic_ctx, magic_dns_runner).await; Ok(()) } } diff --git a/easytier/src/instance/mod.rs b/easytier/src/instance/mod.rs index 17f7f32..1df957d 100644 --- a/easytier/src/instance/mod.rs +++ b/easytier/src/instance/mod.rs @@ -1,3 +1,4 @@ +pub mod dns_server; pub mod instance; pub mod listeners; diff --git a/easytier/src/instance/virtual_nic.rs b/easytier/src/instance/virtual_nic.rs index 7a2625c..029a6b6 100644 --- a/easytier/src/instance/virtual_nic.rs +++ b/easytier/src/instance/virtual_nic.rs @@ -34,6 +34,9 @@ use tokio_util::bytes::Bytes; use tun::{AbstractDevice, AsyncDevice, Configuration, Layer}; use zerocopy::{NativeEndian, NetworkEndian}; +#[cfg(target_os = "windows")] +use crate::common::ifcfg::RegistryManager; + pin_project! { pub struct TunStream { #[pin] @@ -243,81 +246,6 @@ pub struct VirtualNic { ifcfg: Box, } -#[cfg(target_os = "windows")] -pub fn checkreg(dev_name: &str) -> io::Result<()> { - use winreg::{enums::HKEY_LOCAL_MACHINE, enums::KEY_ALL_ACCESS, RegKey}; - let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); - let profiles_key = hklm.open_subkey_with_flags( - "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Profiles", - KEY_ALL_ACCESS, - )?; - let unmanaged_key = hklm.open_subkey_with_flags( - "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Signatures\\Unmanaged", - KEY_ALL_ACCESS, - )?; - // collect subkeys to delete - let mut keys_to_delete = Vec::new(); - let mut keys_to_delete_unmanaged = Vec::new(); - for subkey_name in profiles_key.enum_keys().filter_map(Result::ok) { - let subkey = profiles_key.open_subkey(&subkey_name)?; - // check if ProfileName contains "et" - match subkey.get_value::("ProfileName") { - Ok(profile_name) => { - if profile_name.contains("et_") - || (!dev_name.is_empty() && dev_name == profile_name) - { - keys_to_delete.push(subkey_name); - } - } - Err(e) => { - tracing::error!( - "Failed to read ProfileName for subkey {}: {}", - subkey_name, - e - ); - } - } - } - for subkey_name in unmanaged_key.enum_keys().filter_map(Result::ok) { - let subkey = unmanaged_key.open_subkey(&subkey_name)?; - // check if ProfileName contains "et" - match subkey.get_value::("Description") { - Ok(profile_name) => { - if profile_name.contains("et_") - || (!dev_name.is_empty() && dev_name == profile_name) - { - keys_to_delete_unmanaged.push(subkey_name); - } - } - Err(e) => { - tracing::error!( - "Failed to read ProfileName for subkey {}: {}", - subkey_name, - e - ); - } - } - } - // delete collected subkeys - if !keys_to_delete.is_empty() { - for subkey_name in keys_to_delete { - match profiles_key.delete_subkey_all(&subkey_name) { - Ok(_) => tracing::trace!("Successfully deleted subkey: {}", subkey_name), - Err(e) => tracing::error!("Failed to delete subkey {}: {}", subkey_name, e), - } - } - } - if !keys_to_delete_unmanaged.is_empty() { - for subkey_name in keys_to_delete_unmanaged { - match unmanaged_key.delete_subkey_all(&subkey_name) { - Ok(_) => tracing::trace!("Successfully deleted subkey: {}", subkey_name), - Err(e) => tracing::error!("Failed to delete subkey {}: {}", subkey_name, e), - } - } - } - Ok(()) -} - impl VirtualNic { pub fn new(global_ctx: ArcGlobalCtx) -> Self { Self { @@ -358,7 +286,7 @@ impl VirtualNic { } } - match checkreg(&dev_name) { + match RegistryManager::reg_delete_obsoleted_items(&dev_name) { Ok(_) => tracing::trace!("delete successful!"), Err(e) => tracing::error!("An error occurred: {}", e), } @@ -433,6 +361,30 @@ impl VirtualNic { let ifname = dev.tun_name()?; self.ifcfg.wait_interface_show(ifname.as_str()).await?; + #[cfg(target_os = "windows")] + { + if let Ok(guid) = RegistryManager::find_interface_guid(&ifname) { + if let Err(e) = RegistryManager::disable_dynamic_updates(&guid) { + tracing::error!( + "Failed to disable dhcp for interface {} {}: {}", + ifname, + guid, + e + ); + } + + // Disable NetBIOS over TCP/IP + if let Err(e) = RegistryManager::disable_netbios(&guid) { + tracing::error!( + "Failed to disable netbios for interface {} {}: {}", + ifname, + guid, + e + ); + } + } + } + let dev = AsyncDevice::new(dev)?; let flags = self.global_ctx.config.get_flags(); @@ -476,7 +428,7 @@ impl VirtualNic { pub async fn add_route(&self, address: Ipv4Addr, cidr: u8) -> Result<(), Error> { let _g = self.global_ctx.net_ns.guard(); self.ifcfg - .add_ipv4_route(self.ifname(), address, cidr) + .add_ipv4_route(self.ifname(), address, cidr, None) .await?; Ok(()) } @@ -500,38 +452,6 @@ impl VirtualNic { } } -#[cfg(target_os = "windows")] -pub fn reg_change_catrgory_in_profile(dev_name: &str) -> io::Result<()> { - use winreg::{enums::HKEY_LOCAL_MACHINE, enums::KEY_ALL_ACCESS, RegKey}; - let hklm = RegKey::predef(HKEY_LOCAL_MACHINE); - let profiles_key = hklm.open_subkey_with_flags( - "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\NetworkList\\Profiles", - KEY_ALL_ACCESS, - )?; - - for subkey_name in profiles_key.enum_keys().filter_map(Result::ok) { - let subkey = profiles_key.open_subkey_with_flags(&subkey_name, KEY_ALL_ACCESS)?; - match subkey.get_value::("ProfileName") { - Ok(profile_name) => { - if !dev_name.is_empty() && dev_name == profile_name { - match subkey.set_value("Category", &1u32) { - Ok(_) => tracing::trace!("Successfully set Category in registry"), - Err(e) => tracing::error!("Failed to set Category in registry: {}", e), - } - } - } - Err(e) => { - tracing::error!( - "Failed to read ProfileName for subkey {}: {}", - subkey_name, - e - ); - } - } - } - Ok(()) -} - pub struct NicCtx { global_ctx: ArcGlobalCtx, peer_mgr: Weak, @@ -556,7 +476,12 @@ impl NicCtx { } } - async fn assign_ipv4_to_tun_device(&self, ipv4_addr: cidr::Ipv4Inet) -> Result<(), Error> { + pub async fn ifname(&self) -> Option { + let nic = self.nic.lock().await; + nic.ifname.as_ref().map(|s| s.to_owned()) + } + + pub async fn assign_ipv4_to_tun_device(&self, ipv4_addr: cidr::Ipv4Inet) -> Result<(), Error> { let nic = self.nic.lock().await; nic.link_up().await?; nic.remove_ip(None).await?; @@ -700,6 +625,7 @@ impl NicCtx { ifname.as_str(), cidr.first_address(), cidr.network_length(), + None, ) .await; @@ -728,7 +654,7 @@ impl NicCtx { #[cfg(target_os = "windows")] { let dev_name = self.global_ctx.get_flags().dev_name; - let _ = reg_change_catrgory_in_profile(&dev_name); + let _ = RegistryManager::reg_change_catrgory_in_profile(&dev_name); } self.global_ctx diff --git a/easytier/src/peer_center/instance.rs b/easytier/src/peer_center/instance.rs index 0537c7b..82499c2 100644 --- a/easytier/src/peer_center/instance.rs +++ b/easytier/src/peer_center/instance.rs @@ -32,7 +32,7 @@ use super::{server::PeerCenterServer, Digest, Error}; struct PeerCenterBase { peer_mgr: Arc, - tasks: Arc>>, + tasks: Mutex>, lock: Arc>, } @@ -139,7 +139,7 @@ impl PeerCenterBase { pub fn new(peer_mgr: Arc) -> Self { PeerCenterBase { peer_mgr, - tasks: Arc::new(Mutex::new(JoinSet::new())), + tasks: Mutex::new(JoinSet::new()), lock: Arc::new(Mutex::new(())), } } diff --git a/easytier/src/peers/mod.rs b/easytier/src/peers/mod.rs index bb9fd19..0fbd63f 100644 --- a/easytier/src/peers/mod.rs +++ b/easytier/src/peers/mod.rs @@ -34,6 +34,10 @@ pub trait PeerPacketFilter { #[auto_impl::auto_impl(Arc)] pub trait NicPacketFilter { async fn try_process_packet_from_nic(&self, data: &mut ZCPacket) -> bool; + + fn id(&self) -> String { + format!("{:p}", self) + } } type BoxPeerPacketFilter = Box; diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index b06bec5..6abad6f 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -2,7 +2,7 @@ use std::{ fmt::Debug, net::Ipv4Addr, sync::{Arc, Weak}, - time::SystemTime, + time::{Instant, SystemTime}, }; use anyhow::Context; @@ -120,7 +120,7 @@ pub struct PeerManager { global_ctx: ArcGlobalCtx, nic_channel: PacketRecvChan, - tasks: Arc>>, + tasks: Mutex>, packet_recv: Arc>>, @@ -249,7 +249,7 @@ impl PeerManager { global_ctx, nic_channel, - tasks: Arc::new(Mutex::new(JoinSet::new())), + tasks: Mutex::new(JoinSet::new()), packet_recv: Arc::new(Mutex::new(Some(packet_recv))), @@ -735,6 +735,10 @@ impl PeerManager { self.get_route().list_routes().await } + pub async fn get_route_peer_info_last_update_time(&self) -> Instant { + self.get_route().get_peer_info_last_update_time().await + } + pub async fn dump_route(&self) -> String { self.get_route().dump().await } @@ -767,6 +771,16 @@ impl PeerManager { } } + pub async fn remove_nic_packet_process_pipeline(&self, id: String) -> Result<(), Error> { + let mut pipelines = self.nic_packet_process_pipeline.write().await; + if let Some(pos) = pipelines.iter().position(|x| x.id() == id) { + pipelines.remove(pos); + Ok(()) + } else { + Err(Error::NotFound) + } + } + fn get_next_hop_policy(is_first_latency: bool) -> NextHopPolicy { if is_first_latency { NextHopPolicy::LeastCost diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index ada7cd7..11fcb58 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -1030,6 +1030,8 @@ struct PeerRouteServiceImpl { cached_local_conn_map: std::sync::Mutex, last_update_my_foreign_network: AtomicCell>, + + peer_info_last_update: AtomicCell, } impl Debug for PeerRouteServiceImpl { @@ -1076,6 +1078,8 @@ impl PeerRouteServiceImpl { cached_local_conn_map: std::sync::Mutex::new(RouteConnBitmap::new()), last_update_my_foreign_network: AtomicCell::new(None), + + peer_info_last_update: AtomicCell::new(std::time::Instant::now()), } } @@ -1225,6 +1229,8 @@ impl PeerRouteServiceImpl { } fn update_route_table_and_cached_local_conn_bitmap(&self) { + self.update_peer_info_last_update(); + // update route table first because we want to filter out unreachable peers. self.update_route_table(); @@ -1347,6 +1353,9 @@ impl PeerRouteServiceImpl { if my_conn_info_updated || my_peer_info_updated { self.update_foreign_network_owner_map(); } + if my_peer_info_updated { + self.update_peer_info_last_update(); + } my_peer_info_updated || my_conn_info_updated || my_foreign_network_updated } @@ -1547,6 +1556,15 @@ impl PeerRouteServiceImpl { } return false; } + + fn update_peer_info_last_update(&self) { + tracing::debug!(?self, "update_peer_info_last_update"); + self.peer_info_last_update.store(std::time::Instant::now()); + } + + fn get_peer_info_last_update(&self) -> std::time::Instant { + self.peer_info_last_update.load() + } } impl Drop for PeerRouteServiceImpl { @@ -2195,6 +2213,10 @@ impl Route for PeerRoute { .get(&peer_id) .and_then(|x| x.feature_flag.clone()) } + + async fn get_peer_info_last_update_time(&self) -> Instant { + self.service_impl.get_peer_info_last_update() + } } impl PeerPacketFilter for Arc {} diff --git a/easytier/src/peers/peer_rpc.rs b/easytier/src/peers/peer_rpc.rs index 739bdd8..eaad3f6 100644 --- a/easytier/src/peers/peer_rpc.rs +++ b/easytier/src/peers/peer_rpc.rs @@ -26,7 +26,7 @@ pub trait PeerRpcManagerTransport: Send + Sync + 'static { pub struct PeerRpcManager { tspt: Arc>, bidirect_rpc: BidirectRpcManager, - tasks: Arc>>, + tasks: Mutex>, } impl std::fmt::Debug for PeerRpcManager { @@ -43,7 +43,7 @@ impl PeerRpcManager { tspt: Arc::new(Box::new(tspt)), bidirect_rpc: BidirectRpcManager::new(), - tasks: Arc::new(Mutex::new(JoinSet::new())), + tasks: Mutex::new(JoinSet::new()), } } diff --git a/easytier/src/peers/route_trait.rs b/easytier/src/peers/route_trait.rs index 76fd1fb..80c6533 100644 --- a/easytier/src/peers/route_trait.rs +++ b/easytier/src/peers/route_trait.rs @@ -99,6 +99,8 @@ pub trait Route { async fn get_feature_flag(&self, peer_id: PeerId) -> Option; + async fn get_peer_info_last_update_time(&self) -> std::time::Instant; + async fn dump(&self) -> String { "this route implementation does not support dump".to_string() } diff --git a/easytier/src/proto/common.proto b/easytier/src/proto/common.proto index 6d21183..405b028 100644 --- a/easytier/src/proto/common.proto +++ b/easytier/src/proto/common.proto @@ -30,6 +30,9 @@ message FlagsInConfig { // allow relay kcp packets (for public server, this can reduce the throughput) bool disable_relay_kcp = 20; bool proxy_forward_by_system = 21; + + // enable magic dns or not + bool accept_dns = 22; } message RpcDescriptor { diff --git a/easytier/src/proto/magic_dns.proto b/easytier/src/proto/magic_dns.proto new file mode 100644 index 0000000..54790aa --- /dev/null +++ b/easytier/src/proto/magic_dns.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; +import "common.proto"; +import "cli.proto"; + +package magic_dns; + +message DnsRecordA { + string name = 1; + common.Ipv4Addr value = 2; + int32 ttl = 3; +} + +message DnsRecordSOA { + string name = 1; + string value = 2; +} + +message DnsRecord { + oneof record { + DnsRecordA a = 1; + DnsRecordSOA soa = 2; + } +} + +message DnsRecordList { + repeated DnsRecord records = 1; +} + +message UpdateDnsRecordRequest { + string zone = 1; + repeated cli.Route routes = 2; +} + +message GetDnsRecordResponse { + map records = 1; +} + +message HandshakeRequest {} + +message HandshakeResponse {} + +service MagicDnsServerRpc { + rpc Handshake(HandshakeRequest) returns (HandshakeResponse) {} + rpc Heartbeat(common.Void) returns (common.Void) {} + rpc UpdateDnsRecord(UpdateDnsRecordRequest) returns (common.Void) {} + rpc GetDnsRecord(common.Void) returns (GetDnsRecordResponse) {} +} diff --git a/easytier/src/proto/magic_dns.rs b/easytier/src/proto/magic_dns.rs new file mode 100644 index 0000000..f87812d --- /dev/null +++ b/easytier/src/proto/magic_dns.rs @@ -0,0 +1 @@ +include!(concat!(env!("OUT_DIR"), "/magic_dns.rs")); diff --git a/easytier/src/proto/mod.rs b/easytier/src/proto/mod.rs index f1da8fd..51fe99a 100644 --- a/easytier/src/proto/mod.rs +++ b/easytier/src/proto/mod.rs @@ -4,6 +4,7 @@ pub mod rpc_types; pub mod cli; pub mod common; pub mod error; +pub mod magic_dns; pub mod peer_rpc; pub mod web; diff --git a/easytier/src/proto/rpc_impl/client.rs b/easytier/src/proto/rpc_impl/client.rs index d40de18..3e971a2 100644 --- a/easytier/src/proto/rpc_impl/client.rs +++ b/easytier/src/proto/rpc_impl/client.rs @@ -65,7 +65,7 @@ pub struct Client { transport: Mutex, inflight_requests: InflightRequestTable, peer_info: PeerInfoTable, - tasks: Arc>>, + tasks: Mutex>, } impl Client { @@ -76,7 +76,7 @@ impl Client { transport: Mutex::new(MpscTunnel::new(ring_b, None)), inflight_requests: Arc::new(DashMap::new()), peer_info: Arc::new(DashMap::new()), - tasks: Arc::new(Mutex::new(JoinSet::new())), + tasks: Mutex::new(JoinSet::new()), } } diff --git a/easytier/src/proto/rpc_impl/server.rs b/easytier/src/proto/rpc_impl/server.rs index 63a1b5e..d81fbf2 100644 --- a/easytier/src/proto/rpc_impl/server.rs +++ b/easytier/src/proto/rpc_impl/server.rs @@ -12,7 +12,10 @@ use tokio_stream::StreamExt; use crate::{ common::{join_joinset_background, PeerId}, proto::{ - common::{self, CompressionAlgoPb, RpcCompressionInfo, RpcPacket, RpcRequest, RpcResponse}, + common::{ + self, CompressionAlgoPb, RpcCompressionInfo, RpcPacket, RpcRequest, RpcResponse, + TunnelInfo, + }, rpc_types::{controller::Controller, error::Result}, }, tunnel::{ @@ -82,7 +85,8 @@ impl Server { let packet_merges = self.packet_mergers.clone(); let reg = self.registry.clone(); - let t = tasks.clone(); + let t = Arc::downgrade(&tasks); + let tunnel_info = mpsc.tunnel_info(); tasks.lock().unwrap().spawn(async move { let mut mpsc = mpsc; let mut rx = mpsc.get_stream(); @@ -120,10 +124,15 @@ impl Server { match ret { Ok(Some(packet)) => { packet_merges.remove(&key); + let Some(t) = t.upgrade() else { + tracing::error!("tasks is dropped"); + return; + }; t.lock().unwrap().spawn(Self::handle_rpc( mpsc.get_sink(), packet, reg.clone(), + tunnel_info.clone(), )); } Ok(None) => {} @@ -143,7 +152,11 @@ impl Server { }); } - async fn handle_rpc_request(packet: RpcPacket, reg: Arc) -> Result { + async fn handle_rpc_request( + packet: RpcPacket, + reg: Arc, + tunnel_info: Option, + ) -> Result { let body = if let Some(compression_info) = packet.compression_info { decompress_packet( compression_info.algo.try_into().unwrap_or_default(), @@ -158,6 +171,7 @@ impl Server { let mut ctrl = RpcController::default(); let raw_req = Bytes::from(rpc_request.request); ctrl.set_raw_input(raw_req.clone()); + ctrl.set_tunnel_info(tunnel_info); let ret = timeout( timeout_duration, reg.call_method(packet.descriptor.unwrap(), ctrl.clone(), raw_req), @@ -170,7 +184,12 @@ impl Server { } } - async fn handle_rpc(sender: MpscTunnelSender, packet: RpcPacket, reg: Arc) { + async fn handle_rpc( + sender: MpscTunnelSender, + packet: RpcPacket, + reg: Arc, + tunnel_info: Option, + ) { let from_peer = packet.from_peer; let to_peer = packet.to_peer; let transaction_id = packet.transaction_id; @@ -181,7 +200,7 @@ impl Server { let now = std::time::Instant::now(); let compression_info = packet.compression_info.clone(); - let resp_bytes = Self::handle_rpc_request(packet, reg).await; + let resp_bytes = Self::handle_rpc_request(packet, reg, tunnel_info).await; match &resp_bytes { Ok(r) => { diff --git a/easytier/src/proto/rpc_impl/standalone.rs b/easytier/src/proto/rpc_impl/standalone.rs index c86c5b0..54ee94f 100644 --- a/easytier/src/proto/rpc_impl/standalone.rs +++ b/easytier/src/proto/rpc_impl/standalone.rs @@ -9,6 +9,7 @@ use tokio::task::JoinSet; use crate::{ common::join_joinset_background, proto::{ + common::TunnelInfo, rpc_impl::bidirect::BidirectRpcManager, rpc_types::{__rt::RpcClientFactory, error::Error}, }, @@ -17,11 +18,22 @@ use crate::{ use super::service_registry::ServiceRegistry; +#[async_trait::async_trait] +#[auto_impl::auto_impl(Arc, Box)] +pub trait RpcServerHook: Send + Sync { + async fn on_new_client(&self, _tunnel_info: Option) {} + async fn on_client_disconnected(&self, _tunnel_info: Option) {} +} + +struct DefaultHook; +impl RpcServerHook for DefaultHook {} + pub struct StandAloneServer { registry: Arc, listener: Option, inflight_server: Arc, - tasks: Arc>>, + tasks: JoinSet<()>, + hook: Option>, } impl StandAloneServer { @@ -30,10 +42,16 @@ impl StandAloneServer { registry: Arc::new(ServiceRegistry::new()), listener: Some(listener), inflight_server: Arc::new(AtomicU32::new(0)), - tasks: Arc::new(Mutex::new(JoinSet::new())), + tasks: JoinSet::new(), + + hook: None, } } + pub fn set_hook(&mut self, hook: Arc) { + self.hook = Some(hook); + } + pub fn registry(&self) -> &ServiceRegistry { &self.registry } @@ -42,17 +60,20 @@ impl StandAloneServer { listener: &mut L, inflight: Arc, registry: Arc, - tasks: Arc>>, + hook: Arc, ) -> Result<(), Error> { - listener - .listen() - .await - .with_context(|| "failed to listen")?; + let tasks = Arc::new(Mutex::new(JoinSet::new())); + join_joinset_background(tasks.clone(), "standalone serve_loop".to_string()); loop { let tunnel = listener.accept().await?; + let tunnel_info = tunnel.info(); let registry = registry.clone(); let inflight_server = inflight.clone(); + let hook = hook.clone(); + + hook.on_new_client(tunnel_info.clone()).await; + inflight_server.fetch_add(1, std::sync::atomic::Ordering::Relaxed); tasks.lock().unwrap().spawn(async move { let server = @@ -60,27 +81,32 @@ impl StandAloneServer { server.rpc_server().registry().replace_registry(®istry); server.run_with_tunnel(tunnel); server.wait().await; + hook.on_client_disconnected(tunnel_info.clone()).await; inflight_server.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); }); } } pub async fn serve(&mut self) -> Result<(), Error> { - let tasks = self.tasks.clone(); let mut listener = self.listener.take().unwrap(); - let registry = self.registry.clone(); + let hook = self.hook.take().unwrap_or_else(|| Arc::new(DefaultHook)); - join_joinset_background(tasks.clone(), "standalone server tasks".to_string()); + listener + .listen() + .await + .with_context(|| "failed to listen")?; + + let registry = self.registry.clone(); let inflight_server = self.inflight_server.clone(); - self.tasks.lock().unwrap().spawn(async move { + self.tasks.spawn(async move { loop { let ret = Self::serve_loop( &mut listener, inflight_server.clone(), registry.clone(), - tasks.clone(), + hook.clone(), ) .await; if let Err(e) = ret { @@ -146,4 +172,34 @@ impl StandAloneClient { .rpc_client() .scoped_client::(1, 1, domain_name)) } + + pub async fn wait(&mut self) { + if let Some(client) = self.client.take() { + client.wait().await; + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + proto::rpc_impl::standalone::StandAloneServer, + tunnel::{ + tcp::{TcpTunnelConnector, TcpTunnelListener}, + TunnelConnector as _, + }, + }; + + #[tokio::test] + async fn standalone_exit_on_drop() { + let addr: url::Url = "tcp://0.0.0.0:53884".parse().unwrap(); + let tunnel = TcpTunnelListener::new(addr.clone()); + let mut server = StandAloneServer::new(tunnel); + server.serve().await.unwrap(); + drop(server); + + // tcp should closed + let mut connector = TcpTunnelConnector::new(addr); + connector.connect().await.unwrap_err(); + } } diff --git a/easytier/src/proto/rpc_types/controller.rs b/easytier/src/proto/rpc_types/controller.rs index 0259ad4..fe04b8e 100644 --- a/easytier/src/proto/rpc_types/controller.rs +++ b/easytier/src/proto/rpc_types/controller.rs @@ -2,6 +2,8 @@ use std::sync::{Arc, Mutex}; use bytes::Bytes; +use crate::proto::common::TunnelInfo; + // Controller must impl clone and all cloned controllers share the same data pub trait Controller: Send + Sync + Clone + 'static { fn timeout_ms(&self) -> i32 { @@ -21,6 +23,11 @@ pub trait Controller: Send + Sync + Clone + 'static { None } + fn set_tunnel_info(&mut self, _tunnel_info: Option) {} + fn get_tunnel_info(&self) -> Option<&TunnelInfo> { + None + } + fn set_raw_output(&mut self, _raw_output: Bytes) {} fn get_raw_output(&self) -> Option { None @@ -38,6 +45,7 @@ pub struct BaseController { pub timeout_ms: i32, pub trace_id: i32, pub raw_data: Arc>, + pub tunnel_info: Option, } impl Controller for BaseController { @@ -72,6 +80,14 @@ impl Controller for BaseController { fn get_raw_output(&self) -> Option { self.raw_data.lock().unwrap().raw_output.clone() } + + fn get_tunnel_info(&self) -> Option<&TunnelInfo> { + self.tunnel_info.as_ref() + } + + fn set_tunnel_info(&mut self, tunnel_info: Option) { + self.tunnel_info = tunnel_info; + } } impl Default for BaseController { @@ -83,6 +99,7 @@ impl Default for BaseController { raw_input: None, raw_output: None, })), + tunnel_info: None, } } } diff --git a/easytier/src/tunnel/mpsc.rs b/easytier/src/tunnel/mpsc.rs index bc39a1b..dec7379 100644 --- a/easytier/src/tunnel/mpsc.rs +++ b/easytier/src/tunnel/mpsc.rs @@ -5,7 +5,7 @@ use std::{pin::Pin, time::Duration}; use anyhow::Context; use tokio::time::timeout; -use crate::common::scoped_task::ScopedTask; +use crate::{common::scoped_task::ScopedTask, proto::common::TunnelInfo}; use super::{packet_def::ZCPacket, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream}; @@ -133,6 +133,10 @@ impl MpscTunnel { self.tx.take(); self.task.abort(); } + + pub fn tunnel_info(&self) -> Option { + self.tunnel.info() + } } #[cfg(test)]