diff --git a/Cargo.lock b/Cargo.lock index 6f4ba29..18fb409 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -670,6 +670,12 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "beef" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1" + [[package]] name = "bigdecimal" version = "0.4.6" @@ -693,7 +699,7 @@ dependencies = [ "bitflags 2.8.0", "cexpr", "clang-sys", - "itertools 0.11.0", + "itertools 0.12.1", "proc-macro2", "quote", "regex", @@ -1917,6 +1923,8 @@ dependencies = [ "pnet", "prost", "prost-build", + "prost-reflect", + "prost-reflect-build", "prost-types", "quinn", "rand 0.8.5", @@ -3709,6 +3717,39 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "logos" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7251356ef8cb7aec833ddf598c6cb24d17b689d20b993f9d11a3d764e34e6458" +dependencies = [ + "logos-derive", +] + +[[package]] +name = "logos-codegen" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59f80069600c0d66734f5ff52cc42f2dabd6b29d205f333d61fd7832e9e9963f" +dependencies = [ + "beef", + "fnv", + "lazy_static", + "proc-macro2", + "quote", + "regex-syntax 0.8.4", + "syn 2.0.87", +] + +[[package]] +name = "logos-derive" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24fb722b06a9dc12adb0963ed585f19fc61dc5413e6a9be9422ef92c091e731d" +dependencies = [ + "logos-codegen", +] + [[package]] name = "loom" version = "0.5.6" @@ -4578,6 +4619,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + [[package]] name = "ordered-float" version = "3.9.2" @@ -5340,7 +5390,7 @@ checksum = "f8650aabb6c35b860610e9cff5dc1af886c9e25073b7b1712a68972af4281302" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.11.0", + "itertools 0.12.1", "log", "multimap", "once_cell", @@ -5360,7 +5410,44 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acf0c195eebb4af52c752bec4f52f645da98b6e92077a04110c7f349477ae5ac" dependencies = [ "anyhow", - "itertools 0.11.0", + "itertools 0.12.1", + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "prost-reflect" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e92b959d24e05a3e2da1d0beb55b48bc8a97059b8336ea617780bd6addbbfb5a" +dependencies = [ + "base64 0.22.1", + "logos", + "once_cell", + "prost", + "prost-reflect-derive", + "prost-types", + "serde", + "serde-value", +] + +[[package]] +name = "prost-reflect-build" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50e2537231d94dd2778920c2ada37dd9eb1ac0325bb3ee3ee651bd44c1134123" +dependencies = [ + "prost-build", + "prost-reflect", +] + +[[package]] +name = "prost-reflect-derive" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4fce6b22f15cc8d8d400a2b98ad29202b33bd56c7d9ddd815bc803a807ecb65" +dependencies = [ "proc-macro2", "quote", "syn 2.0.87", @@ -6311,7 +6398,7 @@ dependencies = [ "bigdecimal", "chrono", "inherent", - "ordered-float", + "ordered-float 3.9.2", "rust_decimal", "sea-query-derive", "serde_json", @@ -6451,6 +6538,16 @@ dependencies = [ "typeid", ] +[[package]] +name = "serde-value" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" +dependencies = [ + "ordered-float 2.10.1", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.207" diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 2af1259..0365e64 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -188,6 +188,12 @@ async-compression = { version = "0.4.17", default-features = false, features = [ kcp-sys = { git = "https://github.com/EasyTier/kcp-sys" } +prost-reflect = { version = "0.14.5", features = [ + "serde", + "derive", + "text-format" +] } + [target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies] machine-uid = "0.5.3" @@ -208,6 +214,7 @@ globwalk = "0.8.1" regex = "1" prost-build = "0.13.2" rpc_build = { package = "easytier-rpc-build", version = "0.1.0", features = ["internal-namespace"] } +prost-reflect-build = { version = "0.14.0" } [target.'cfg(windows)'.build-dependencies] reqwest = { version = "0.11", features = ["blocking"] } diff --git a/easytier/build.rs b/easytier/build.rs index fa0d30f..0a6a87a 100644 --- a/easytier/build.rs +++ b/easytier/build.rs @@ -141,7 +141,8 @@ fn main() -> Result<(), Box> { println!("cargo:rerun-if-changed={}", proto_file); } - prost_build::Config::new() + let mut config = prost_build::Config::new(); + config .protoc_arg("--experimental_allow_proto3_optional") .type_attribute(".common", "#[derive(serde::Serialize, serde::Deserialize)]") .type_attribute(".error", "#[derive(serde::Serialize, serde::Deserialize)]") @@ -156,9 +157,11 @@ fn main() -> Result<(), Box> { .type_attribute("peer_rpc.ForeignNetworkRouteInfoKey", "#[derive(Hash, Eq)]") .type_attribute("common.RpcDescriptor", "#[derive(Hash, Eq)]") .service_generator(Box::new(rpc_build::ServiceGenerator::new())) - .btree_map(["."]) - .compile_protos(&proto_files, &["src/proto/"]) - .unwrap(); + .btree_map(["."]); + + prost_reflect_build::Builder::new() + .file_descriptor_set_bytes("crate::proto::DESCRIPTOR_POOL_BYTES") + .compile_protos_with_config(config, &proto_files, &["src/proto/"])?; check_locale(); Ok(()) diff --git a/easytier/src/connector/udp_hole_punch/sym_to_cone.rs b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs index a913a58..bc449e6 100644 --- a/easytier/src/connector/udp_hole_punch/sym_to_cone.rs +++ b/easytier/src/connector/udp_hole_punch/sym_to_cone.rs @@ -284,6 +284,7 @@ impl PunchSymToConeHoleClient { BaseController { timeout_ms: 4000, trace_id: 0, + ..Default::default() }, req, ) @@ -314,6 +315,7 @@ impl PunchSymToConeHoleClient { BaseController { timeout_ms: 4000, trace_id: 0, + ..Default::default() }, req, ) diff --git a/easytier/src/peers/peer_conn_ping.rs b/easytier/src/peers/peer_conn_ping.rs index c0c2ead..0e5aeb1 100644 --- a/easytier/src/peers/peer_conn_ping.rs +++ b/easytier/src/peers/peer_conn_ping.rs @@ -1,5 +1,4 @@ use std::{ - fmt::Debug, sync::{ atomic::{AtomicU32, Ordering}, Arc, diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index 779ea21..87a92cb 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -16,6 +16,8 @@ use petgraph::{ graph::NodeIndex, Directed, Graph, }; +use prost::Message; +use prost_reflect::{DynamicMessage, ReflectMessage}; use serde::{Deserialize, Serialize}; use tokio::{ select, @@ -283,6 +285,8 @@ type Error = SyncRouteInfoError; #[derive(Debug)] struct SyncedRouteInfo { peer_infos: DashMap, + // prost doesn't support unknown fields, so we use DynamicMessage to store raw infos and progate them to other peers. + raw_peer_infos: DashMap, conn_map: DashMap, AtomicVersion)>, foreign_network: DashMap, } @@ -297,6 +301,7 @@ impl SyncedRouteInfo { fn remove_peer(&self, peer_id: PeerId) { tracing::warn!(?peer_id, "remove_peer from synced_route_info"); self.peer_infos.remove(&peer_id); + self.raw_peer_infos.remove(&peer_id); self.conn_map.remove(&peer_id); self.foreign_network.retain(|k, _| k.peer_id != peer_id); } @@ -369,8 +374,11 @@ impl SyncedRouteInfo { my_peer_route_id: u64, dst_peer_id: PeerId, peer_infos: &Vec, + raw_peer_infos: &Vec, ) -> Result<(), Error> { - for mut route_info in peer_infos.iter().map(Clone::clone) { + for (idx, route_info) in peer_infos.iter().enumerate() { + let mut route_info = route_info.clone(); + let raw_route_info = &raw_peer_infos[idx]; self.check_duplicate_peer_id( my_peer_id, my_peer_route_id, @@ -391,10 +399,16 @@ impl SyncedRouteInfo { .entry(route_info.peer_id) .and_modify(|old_entry| { if route_info.version > old_entry.version { + self.raw_peer_infos + .insert(route_info.peer_id, raw_route_info.clone()); *old_entry = route_info.clone(); } }) - .or_insert_with(|| route_info.clone()); + .or_insert_with(|| { + self.raw_peer_infos + .insert(route_info.peer_id, raw_route_info.clone()); + route_info.clone() + }); } Ok(()) } @@ -1047,6 +1061,7 @@ impl PeerRouteServiceImpl { synced_route_info: SyncedRouteInfo { peer_infos: DashMap::new(), + raw_peer_infos: DashMap::new(), conn_map: DashMap::new(), foreign_network: DashMap::new(), }, @@ -1381,6 +1396,39 @@ impl PeerRouteServiceImpl { } } + fn build_sync_route_raw_req( + req: &SyncRouteInfoRequest, + raw_peer_infos: &DashMap, + ) -> DynamicMessage { + use prost_reflect::Value; + + let mut req_dynamic_msg = DynamicMessage::new(SyncRouteInfoRequest::default().descriptor()); + req_dynamic_msg.transcode_from(req).unwrap(); + + let peer_infos = req.peer_infos.as_ref().map(|x| &x.items); + if let Some(peer_infos) = peer_infos { + let mut peer_info_raws = Vec::new(); + for peer_info in peer_infos.iter() { + if let Some(info) = raw_peer_infos.get(&peer_info.peer_id) { + peer_info_raws.push(Value::Message(info.clone())); + } else { + let mut p = DynamicMessage::new(RoutePeerInfo::default().descriptor()); + p.transcode_from(peer_info).unwrap(); + peer_info_raws.push(Value::Message(p)); + } + } + + let mut peer_infos = DynamicMessage::new(RoutePeerInfos::default().descriptor()); + peer_infos.set_field_by_name("items", Value::List(peer_info_raws)); + + req_dynamic_msg.set_field_by_name("peer_infos", Value::Message(peer_infos)); + } + + tracing::trace!(?req_dynamic_msg, "build_sync_route_raw_req"); + + req_dynamic_msg + } + async fn sync_route_with_peer( &self, dst_peer_id: PeerId, @@ -1419,20 +1467,27 @@ impl PeerRouteServiceImpl { self.global_ctx.get_network_name(), ); + let sync_route_info_req = SyncRouteInfoRequest { + my_peer_id, + my_session_id: session.my_session_id.load(Ordering::Relaxed), + is_initiator: session.we_are_initiator.load(Ordering::Relaxed), + peer_infos: peer_infos.clone().map(|x| RoutePeerInfos { items: x }), + conn_bitmap: conn_bitmap.clone().map(Into::into), + foreign_network_infos: foreign_network.clone(), + }; + let mut ctrl = BaseController::default(); ctrl.set_timeout_ms(3000); - let ret = rpc_stub - .sync_route_info( - ctrl, - SyncRouteInfoRequest { - my_peer_id, - my_session_id: session.my_session_id.load(Ordering::Relaxed), - is_initiator: session.we_are_initiator.load(Ordering::Relaxed), - peer_infos: peer_infos.clone().map(|x| RoutePeerInfos { items: x }), - conn_bitmap: conn_bitmap.clone().map(Into::into), - foreign_network_infos: foreign_network.clone(), - }, + ctrl.set_raw_input( + Self::build_sync_route_raw_req( + &sync_route_info_req, + &self.synced_route_info.raw_peer_infos, ) + .encode_to_vec() + .into(), + ); + let ret = rpc_stub + .sync_route_info(ctrl, SyncRouteInfoRequest::default()) .await; if let Err(e) = &ret { @@ -1508,12 +1563,30 @@ impl Debug for RouteSessionManager { } } +fn get_raw_peer_infos(req_raw_input: &mut bytes::Bytes) -> Option> { + let sync_req_dynamic_msg = + DynamicMessage::decode(SyncRouteInfoRequest::default().descriptor(), req_raw_input) + .unwrap(); + + let peer_infos = sync_req_dynamic_msg.get_field_by_name("peer_infos")?; + + let infos = peer_infos + .as_message()? + .get_field_by_name("items")? + .as_list()? + .iter() + .map(|x| x.as_message().unwrap().clone()) + .collect(); + + Some(infos) +} + #[async_trait::async_trait] impl OspfRouteRpc for RouteSessionManager { type Controller = BaseController; async fn sync_route_info( &self, - _ctrl: BaseController, + ctrl: BaseController, request: SyncRouteInfoRequest, ) -> Result { let from_peer_id = request.my_peer_id; @@ -1522,6 +1595,13 @@ impl OspfRouteRpc for RouteSessionManager { let peer_infos = request.peer_infos.map(|x| x.items); let conn_bitmap = request.conn_bitmap.map(Into::into); let foreign_network = request.foreign_network_infos; + let raw_peer_infos = if peer_infos.is_some() { + let r = get_raw_peer_infos(&mut ctrl.get_raw_input().unwrap()).unwrap(); + assert_eq!(r.len(), peer_infos.as_ref().unwrap().len()); + Some(r) + } else { + None + }; let ret = self .do_sync_route_info( @@ -1529,6 +1609,7 @@ impl OspfRouteRpc for RouteSessionManager { from_session_id, is_initiator, peer_infos, + raw_peer_infos, conn_bitmap, foreign_network, ) @@ -1783,6 +1864,7 @@ impl RouteSessionManager { from_session_id: SessionId, is_initiator: bool, peer_infos: Option>, + raw_peer_infos: Option>, conn_bitmap: Option, foreign_network: Option, ) -> Result { @@ -1805,6 +1887,7 @@ impl RouteSessionManager { service_impl.my_peer_route_id, from_peer_id, peer_infos, + raw_peer_infos.as_ref().unwrap(), )?; session.update_dst_saved_peer_info_version(peer_infos); need_update_route_table = true; @@ -2123,18 +2206,26 @@ mod tests { time::Duration, }; + use dashmap::DashMap; + use prost_reflect::{DynamicMessage, ReflectMessage}; + use crate::{ common::{global_ctx::tests::get_mock_global_ctx, PeerId}, connector::udp_hole_punch::tests::replace_stun_info_collector, peers::{ create_packet_recv_chan, peer_manager::{PeerManager, RouteAlgoType}, + peer_ospf_route::PeerRouteServiceImpl, route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface}, tests::connect_peer_manager, }, - proto::common::NatType, + proto::{ + common::NatType, + peer_rpc::{RoutePeerInfo, RoutePeerInfos, SyncRouteInfoRequest}, + }, tunnel::common::tests::wait_for_condition, }; + use prost::Message; use super::PeerRoute; @@ -2554,4 +2645,31 @@ mod tests { ) .await; } + + #[tokio::test] + async fn test_raw_peer_info() { + let mut req = SyncRouteInfoRequest::default(); + let raw_info_map: DashMap = DashMap::new(); + + req.peer_infos = Some(RoutePeerInfos { + items: vec![RoutePeerInfo { + peer_id: 1, + ..Default::default() + }], + }); + + let mut raw_req = DynamicMessage::new(RoutePeerInfo::default().descriptor()); + raw_req + .transcode_from(&req.peer_infos.as_ref().unwrap().items[0]) + .unwrap(); + raw_info_map.insert(1, raw_req); + + let out = PeerRouteServiceImpl::build_sync_route_raw_req(&req, &raw_info_map); + + let out_bytes = out.encode_to_vec(); + + let req2 = SyncRouteInfoRequest::decode(out_bytes.as_slice()).unwrap(); + + assert_eq!(req, req2); + } } diff --git a/easytier/src/proto/mod.rs b/easytier/src/proto/mod.rs index 0db7766..f1da8fd 100644 --- a/easytier/src/proto/mod.rs +++ b/easytier/src/proto/mod.rs @@ -9,3 +9,6 @@ pub mod web; #[cfg(test)] pub mod tests; + +const DESCRIPTOR_POOL_BYTES: &[u8] = + include_bytes!(concat!(env!("OUT_DIR"), "/file_descriptor_set.bin")); diff --git a/easytier/src/proto/rpc_impl/client.rs b/easytier/src/proto/rpc_impl/client.rs index 0094d99..d40de18 100644 --- a/easytier/src/proto/rpc_impl/client.rs +++ b/easytier/src/proto/rpc_impl/client.rs @@ -192,7 +192,7 @@ impl Client { async fn call( &self, - ctrl: Self::Controller, + mut ctrl: Self::Controller, method: ::Method, input: bytes::Bytes, ) -> Result { @@ -224,7 +224,11 @@ impl Client { }; let rpc_req = RpcRequest { - request: input.into(), + request: if let Some(raw_input) = ctrl.get_raw_input() { + raw_input.into() + } else { + input.into() + }, timeout_ms: ctrl.timeout_ms(), ..Default::default() }; @@ -280,7 +284,10 @@ impl Client { return Err(err.into()); } - Ok(bytes::Bytes::from(rpc_resp.response)) + let raw_output = Bytes::from(rpc_resp.response.clone()); + ctrl.set_raw_output(raw_output.clone()); + + Ok(raw_output) } } diff --git a/easytier/src/proto/rpc_impl/server.rs b/easytier/src/proto/rpc_impl/server.rs index 773e965..63a1b5e 100644 --- a/easytier/src/proto/rpc_impl/server.rs +++ b/easytier/src/proto/rpc_impl/server.rs @@ -13,7 +13,7 @@ use crate::{ common::{join_joinset_background, PeerId}, proto::{ common::{self, CompressionAlgoPb, RpcCompressionInfo, RpcPacket, RpcRequest, RpcResponse}, - rpc_types::error::Result, + rpc_types::{controller::Controller, error::Result}, }, tunnel::{ mpsc::{MpscTunnel, MpscTunnelSender}, @@ -155,16 +155,19 @@ impl Server { }; let rpc_request = RpcRequest::decode(Bytes::from(body))?; let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64); - let ctrl = RpcController::default(); - Ok(timeout( + let mut ctrl = RpcController::default(); + let raw_req = Bytes::from(rpc_request.request); + ctrl.set_raw_input(raw_req.clone()); + let ret = timeout( timeout_duration, - reg.call_method( - packet.descriptor.unwrap(), - ctrl, - Bytes::from(rpc_request.request), - ), + reg.call_method(packet.descriptor.unwrap(), ctrl.clone(), raw_req), ) - .await??) + .await??; + if let Some(raw_output) = ctrl.get_raw_output() { + Ok(raw_output) + } else { + Ok(ret) + } } async fn handle_rpc(sender: MpscTunnelSender, packet: RpcPacket, reg: Arc) { diff --git a/easytier/src/proto/rpc_types/controller.rs b/easytier/src/proto/rpc_types/controller.rs index 900fa2a..0259ad4 100644 --- a/easytier/src/proto/rpc_types/controller.rs +++ b/easytier/src/proto/rpc_types/controller.rs @@ -1,4 +1,9 @@ -pub trait Controller: Send + Sync + 'static { +use std::sync::{Arc, Mutex}; + +use bytes::Bytes; + +// Controller must impl clone and all cloned controllers share the same data +pub trait Controller: Send + Sync + Clone + 'static { fn timeout_ms(&self) -> i32 { 5000 } @@ -10,12 +15,29 @@ pub trait Controller: Send + Sync + 'static { fn trace_id(&self) -> i32 { 0 } + + fn set_raw_input(&mut self, _raw_input: Bytes) {} + fn get_raw_input(&self) -> Option { + None + } + + fn set_raw_output(&mut self, _raw_output: Bytes) {} + fn get_raw_output(&self) -> Option { + None + } } #[derive(Debug)] +pub struct BaseControllerRawData { + pub raw_input: Option, + pub raw_output: Option, +} + +#[derive(Debug, Clone)] pub struct BaseController { pub timeout_ms: i32, pub trace_id: i32, + pub raw_data: Arc>, } impl Controller for BaseController { @@ -34,6 +56,22 @@ impl Controller for BaseController { fn trace_id(&self) -> i32 { self.trace_id } + + fn set_raw_input(&mut self, raw_input: Bytes) { + self.raw_data.lock().unwrap().raw_input = Some(raw_input); + } + + fn get_raw_input(&self) -> Option { + self.raw_data.lock().unwrap().raw_input.clone() + } + + fn set_raw_output(&mut self, raw_output: Bytes) { + self.raw_data.lock().unwrap().raw_output = Some(raw_output); + } + + fn get_raw_output(&self) -> Option { + self.raw_data.lock().unwrap().raw_output.clone() + } } impl Default for BaseController { @@ -41,6 +79,10 @@ impl Default for BaseController { Self { timeout_ms: 5000, trace_id: 0, + raw_data: Arc::new(Mutex::new(BaseControllerRawData { + raw_input: None, + raw_output: None, + })), } } }