diff --git a/easytier-web/src/main.rs b/easytier-web/src/main.rs index c4a17e1..9b0f955 100644 --- a/easytier-web/src/main.rs +++ b/easytier-web/src/main.rs @@ -10,8 +10,9 @@ use easytier::{ common::{ config::{ConfigLoader, ConsoleLoggerConfig, FileLoggerConfig, TomlConfigLoader}, constants::EASYTIER_VERSION, + error::Error, }, - tunnel::udp::UdpTunnelListener, + tunnel::{tcp::TcpTunnelListener, udp::UdpTunnelListener, TunnelListener}, utils::{init_logger, setup_panic_handler}, }; @@ -71,6 +72,18 @@ struct Cli { api_server_port: u16, } +pub fn get_listener_by_url( + l: &url::Url, +) -> Result, Error> { + Ok(match l.scheme() { + "tcp" => Box::new(TcpTunnelListener::new(l.clone())), + "udp" => Box::new(UdpTunnelListener::new(l.clone())), + _ => { + return Err(Error::InvalidUrl(l.to_string())); + } + }) +} + #[tokio::main] async fn main() { let locale = sys_locale::get_locale().unwrap_or_else(|| String::from("en-US")); @@ -92,14 +105,10 @@ async fn main() { // let db = db::Db::new(":memory:").await.unwrap(); let db = db::Db::new(cli.db).await.unwrap(); - let listener = UdpTunnelListener::new( - format!( - "{}://0.0.0.0:{}", - cli.config_server_protocol, cli.config_server_port - ) - .parse() - .unwrap(), - ); + let listener = get_listener_by_url( + &format!("{}://0.0.0.0:{}", cli.config_server_protocol, cli.config_server_port).parse().unwrap(), + ) + .unwrap(); let mut mgr = client_manager::ClientManager::new(db.clone()); mgr.serve(listener).await.unwrap(); let mgr = Arc::new(mgr);