From fec885c427d257874f5ce151b9c847af44efafc4 Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Sat, 24 May 2025 00:36:00 +0800 Subject: [PATCH] fix token mismatch when using web (#871) --- easytier-web/src/client_manager/mod.rs | 16 +++-- easytier-web/src/client_manager/session.rs | 57 ++++++++++++----- easytier-web/src/client_manager/storage.rs | 72 +++++++++++----------- easytier-web/src/db/mod.rs | 2 +- easytier-web/src/restful/mod.rs | 18 +----- easytier-web/src/restful/network.rs | 33 +++++----- 6 files changed, 108 insertions(+), 90 deletions(-) diff --git a/easytier-web/src/client_manager/mod.rs b/easytier-web/src/client_manager/mod.rs index aa998b1..2251178 100644 --- a/easytier-web/src/client_manager/mod.rs +++ b/easytier-web/src/client_manager/mod.rs @@ -10,7 +10,7 @@ use easytier::{ use session::Session; use storage::{Storage, StorageToken}; -use crate::db::Db; +use crate::db::{Db, UserIdInDb}; #[derive(Debug)] pub struct ClientManager { @@ -86,15 +86,21 @@ impl ClientManager { ret } - pub fn get_session_by_machine_id(&self, machine_id: &uuid::Uuid) -> Option> { - let c_url = self.storage.get_client_url_by_machine_id(machine_id)?; + pub fn get_session_by_machine_id( + &self, + user_id: UserIdInDb, + machine_id: &uuid::Uuid, + ) -> Option> { + let c_url = self + .storage + .get_client_url_by_machine_id(user_id, machine_id)?; self.client_sessions .get(&c_url) .map(|item| item.value().clone()) } - pub async fn list_machine_by_token(&self, token: String) -> Vec { - self.storage.list_token_clients(&token) + pub async fn list_machine_by_user_id(&self, user_id: UserIdInDb) -> Vec { + self.storage.list_user_clients(user_id) } pub async fn get_heartbeat_requests(&self, client_url: &url::Url) -> Option { diff --git a/easytier-web/src/client_manager/session.rs b/easytier-web/src/client_manager/session.rs index 5d17500..2b26df2 100644 --- a/easytier-web/src/client_manager/session.rs +++ b/easytier-web/src/client_manager/session.rs @@ -1,5 +1,6 @@ use std::{fmt::Debug, str::FromStr as _, sync::Arc}; +use anyhow::Context; use easytier::{ common::scoped_task::ScopedTask, proto::{ @@ -78,30 +79,54 @@ impl WebServerService for SessionRpcService { req: HeartbeatRequest, ) -> rpc_types::error::Result { let mut data = self.data.write().await; + + let Ok(storage) = Storage::try_from(data.storage.clone()) else { + tracing::error!("Failed to get storage"); + return Ok(HeartbeatResponse {}); + }; + + let machine_id: uuid::Uuid = + req.machine_id + .clone() + .map(Into::into) + .ok_or(anyhow::anyhow!( + "Machine id is not set correctly, expect uuid but got: {:?}", + req.machine_id + ))?; + + let user_id = storage + .db() + .get_user_id_by_token(req.user_token.clone()) + .await + .with_context(|| { + format!( + "Failed to get user id by token from db: {:?}", + req.user_token + ) + })? + .ok_or(anyhow::anyhow!( + "User not found by token: {:?}", + req.user_token + ))?; + if data.req.replace(req.clone()).is_none() { assert!(data.storage_token.is_none()); data.storage_token = Some(StorageToken { token: req.user_token.clone().into(), client_url: data.client_url.clone(), - machine_id: req - .machine_id - .clone() - .map(Into::into) - .unwrap_or(uuid::Uuid::new_v4()), + machine_id, + user_id, }); } - if let Ok(storage) = Storage::try_from(data.storage.clone()) { - let Ok(report_time) = chrono::DateTime::::from_str(&req.report_time) - else { - tracing::error!("Failed to parse report time: {:?}", req.report_time); - return Ok(HeartbeatResponse {}); - }; - storage.update_client( - data.storage_token.as_ref().unwrap().clone(), - report_time.timestamp(), - ); - } + let Ok(report_time) = chrono::DateTime::::from_str(&req.report_time) else { + tracing::error!("Failed to parse report time: {:?}", req.report_time); + return Ok(HeartbeatResponse {}); + }; + storage.update_client( + data.storage_token.as_ref().unwrap().clone(), + report_time.timestamp(), + ); let _ = data.notifier.send(req); Ok(HeartbeatResponse {}) diff --git a/easytier-web/src/client_manager/storage.rs b/easytier-web/src/client_manager/storage.rs index a9be575..9acf6f6 100644 --- a/easytier-web/src/client_manager/storage.rs +++ b/easytier-web/src/client_manager/storage.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, Weak}; use dashmap::DashMap; -use crate::db::Db; +use crate::db::{Db, UserIdInDb}; // use this to maintain Storage #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -10,21 +10,19 @@ pub struct StorageToken { pub token: String, pub client_url: url::Url, pub machine_id: uuid::Uuid, + pub user_id: UserIdInDb, } #[derive(Debug, Clone)] struct ClientInfo { - client_url: url::Url, - machine_id: uuid::Uuid, - token: String, + storage_token: StorageToken, report_time: i64, } #[derive(Debug)] pub struct StorageInner { // some map for indexing - token_clients_map: DashMap>, - machine_client_url_map: DashMap, + user_clients_map: DashMap>, pub db: Db, } @@ -43,8 +41,7 @@ impl TryFrom for Storage { impl Storage { pub fn new(db: Db) -> Self { Storage(Arc::new(StorageInner { - token_clients_map: DashMap::new(), - machine_client_url_map: DashMap::new(), + user_clients_map: DashMap::new(), db, })) } @@ -54,17 +51,22 @@ impl Storage { machine_id: &uuid::Uuid, client_url: &url::Url, ) { - map.remove_if(&machine_id, |_, v| v.client_url == *client_url); + map.remove_if(&machine_id, |_, v| { + v.storage_token.client_url == *client_url + }); } fn update_mid_to_client_info_map( map: &DashMap, client_info: &ClientInfo, ) { - map.entry(client_info.machine_id) + map.entry(client_info.storage_token.machine_id) .and_modify(|e| { if e.report_time < client_info.report_time { - assert_eq!(e.machine_id, client_info.machine_id); + assert_eq!( + e.storage_token.machine_id, + client_info.storage_token.machine_id + ); *e = client_info.clone(); } }) @@ -74,53 +76,51 @@ impl Storage { pub fn update_client(&self, stoken: StorageToken, report_time: i64) { let inner = self .0 - .token_clients_map - .entry(stoken.token.clone()) + .user_clients_map + .entry(stoken.user_id) .or_insert_with(DashMap::new); let client_info = ClientInfo { - client_url: stoken.client_url.clone(), - machine_id: stoken.machine_id, - token: stoken.token.clone(), + storage_token: stoken.clone(), report_time, }; Self::update_mid_to_client_info_map(&inner, &client_info); - Self::update_mid_to_client_info_map(&self.0.machine_client_url_map, &client_info); } pub fn remove_client(&self, stoken: &StorageToken) { - self.0.token_clients_map.remove_if(&stoken.token, |_, set| { - Self::remove_mid_to_client_info_map(set, &stoken.machine_id, &stoken.client_url); - set.is_empty() - }); - - Self::remove_mid_to_client_info_map( - &self.0.machine_client_url_map, - &stoken.machine_id, - &stoken.client_url, - ); + self.0 + .user_clients_map + .remove_if(&stoken.user_id, |_, set| { + Self::remove_mid_to_client_info_map(set, &stoken.machine_id, &stoken.client_url); + set.is_empty() + }); } pub fn weak_ref(&self) -> WeakRefStorage { Arc::downgrade(&self.0) } - pub fn get_client_url_by_machine_id(&self, machine_id: &uuid::Uuid) -> Option { - self.0 - .machine_client_url_map - .get(&machine_id) - .map(|info| info.client_url.clone()) + pub fn get_client_url_by_machine_id( + &self, + user_id: UserIdInDb, + machine_id: &uuid::Uuid, + ) -> Option { + self.0.user_clients_map.get(&user_id).and_then(|info_map| { + info_map + .get(machine_id) + .map(|info| info.storage_token.client_url.clone()) + }) } - pub fn list_token_clients(&self, token: &str) -> Vec { + pub fn list_user_clients(&self, user_id: UserIdInDb) -> Vec { self.0 - .token_clients_map - .get(token) + .user_clients_map + .get(&user_id) .map(|info_map| { info_map .iter() - .map(|info| info.value().client_url.clone()) + .map(|info| info.value().storage_token.client_url.clone()) .collect() }) .unwrap_or_default() diff --git a/easytier-web/src/db/mod.rs b/easytier-web/src/db/mod.rs index bd286a2..377019d 100644 --- a/easytier-web/src/db/mod.rs +++ b/easytier-web/src/db/mod.rs @@ -12,7 +12,7 @@ use sqlx::{migrate::MigrateDatabase as _, types::chrono, Sqlite, SqlitePool}; use crate::migrator; -type UserIdInDb = i32; +pub type UserIdInDb = i32; pub enum ListNetworkProps { All, diff --git a/easytier-web/src/restful/mod.rs b/easytier-web/src/restful/mod.rs index 98543d4..3838e03 100644 --- a/easytier-web/src/restful/mod.rs +++ b/easytier-web/src/restful/mod.rs @@ -9,7 +9,7 @@ use axum::http::StatusCode; use axum::routing::post; use axum::{extract::State, routing::get, Json, Router}; use axum_login::tower_sessions::{ExpiredDeletion, SessionManagerLayer}; -use axum_login::{login_required, AuthManagerLayerBuilder, AuthzBackend}; +use axum_login::{login_required, AuthManagerLayerBuilder, AuthUser, AuthzBackend}; use axum_messages::MessagesManagerLayer; use easytier::common::config::ConfigLoader; use easytier::common::scoped_task::ScopedTask; @@ -24,7 +24,6 @@ use tower_sessions::Expiry; use tower_sessions_sqlx_store::SqliteStore; use users::{AuthSession, Backend}; -use crate::client_manager::session::Session; use crate::client_manager::storage::StorageToken; use crate::client_manager::ClientManager; use crate::db::Db; @@ -112,17 +111,6 @@ impl RestfulServer { }) } - async fn get_session_by_machine_id( - client_mgr: &ClientManager, - machine_id: &uuid::Uuid, - ) -> Result, HttpHandleError> { - let Some(result) = client_mgr.get_session_by_machine_id(machine_id) else { - return Err((StatusCode::NOT_FOUND, other_error("No such session").into())); - }; - - Ok(result) - } - async fn handle_list_all_sessions( auth_session: AuthSession, State(client_mgr): AppState, @@ -145,9 +133,7 @@ impl RestfulServer { return Err((StatusCode::UNAUTHORIZED, other_error("No such user").into())); }; - let machines = client_mgr - .list_machine_by_token(user.tokens[0].clone()) - .await; + let machines = client_mgr.list_machine_by_user_id(user.id().clone()).await; Ok(GetSummaryJsonResp { device_count: machines.len() as u32, diff --git a/easytier-web/src/restful/network.rs b/easytier-web/src/restful/network.rs index 6636b49..5a6f622 100644 --- a/easytier-web/src/restful/network.rs +++ b/easytier-web/src/restful/network.rs @@ -5,7 +5,6 @@ use axum::http::StatusCode; use axum::routing::{delete, post}; use axum::{extract::State, routing::get, Json, Router}; use axum_login::AuthUser; -use dashmap::DashSet; use easytier::launcher::NetworkConfig; use easytier::proto::common::Void; use easytier::proto::rpc_types::controller::BaseController; @@ -13,7 +12,7 @@ use easytier::proto::web::*; use crate::client_manager::session::Session; use crate::client_manager::ClientManager; -use crate::db::ListNetworkProps; +use crate::db::{ListNetworkProps, UserIdInDb}; use super::users::AuthSession; use super::{ @@ -81,12 +80,24 @@ impl NetworkApi { Self {} } + fn get_user_id(auth_session: &AuthSession) -> Result)> { + let Some(user_id) = auth_session.user.as_ref().map(|x| x.id()) else { + return Err(( + StatusCode::UNAUTHORIZED, + other_error(format!("No user id found")).into(), + )); + }; + Ok(user_id) + } + async fn get_session_by_machine_id( auth_session: &AuthSession, client_mgr: &ClientManager, machine_id: &uuid::Uuid, ) -> Result, HttpHandleError> { - let Some(result) = client_mgr.get_session_by_machine_id(machine_id) else { + let user_id = Self::get_user_id(auth_session)?; + + let Some(result) = client_mgr.get_session_by_machine_id(user_id, machine_id) else { return Err(( StatusCode::NOT_FOUND, other_error(format!("No such session: {}", machine_id)).into(), @@ -289,23 +300,13 @@ impl NetworkApi { auth_session: AuthSession, State(client_mgr): AppState, ) -> Result, HttpHandleError> { - let tokens = auth_session - .user - .as_ref() - .map(|x| x.tokens.clone()) - .unwrap_or_default(); + let user_id = Self::get_user_id(&auth_session)?; - let client_urls = DashSet::new(); - for token in tokens { - let urls = client_mgr.list_machine_by_token(token).await; - for url in urls { - client_urls.insert(url); - } - } + let client_urls = client_mgr.list_machine_by_user_id(user_id).await; let mut machines = vec![]; for item in client_urls.iter() { - let client_url = item.key().clone(); + let client_url = item.clone(); let session = client_mgr.get_heartbeat_requests(&client_url).await; machines.push(ListMachineItem { client_url: Some(client_url),