diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index d5ad8ba..adbca90 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -133,15 +133,23 @@ impl PeerConn { let mut locked = self.recv.lock().await; let recv = locked.as_mut().unwrap(); - let Some(rsp) = recv.next().await else { - return Err(Error::WaitRespError( - "conn closed during wait handshake response".to_owned(), - )); + let rsp = match recv.next().await { + Some(Ok(rsp)) => rsp, + Some(Err(e)) => { + return Err(Error::WaitRespError(format!( + "conn recv error during wait handshake response, err: {:?}", + e + ))) + } + None => { + return Err(Error::WaitRespError( + "conn closed during wait handshake response".to_owned(), + )) + } }; *need_retry = true; - let rsp = rsp?; let Some(peer_mgr_hdr) = rsp.peer_manager_header() else { return Err(Error::WaitRespError(format!( "unexpected packet: {:?}, cannot decode peer manager hdr", @@ -214,6 +222,9 @@ impl PeerConn { Error::WaitRespError("send handshake request error".to_owned()) })?; + // yield to send the response packet + tokio::task::yield_now().await; + Ok(()) } diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index 7d66b2e..cc5ee29 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -75,18 +75,12 @@ pin_project! { #[pin] reader: R, buf: BytesMut, - state: FrameReaderState, max_packet_size: usize, associate_data: Option>, + error: Option, } } -// usize means the size remaining to read -enum FrameReaderState { - ReadingHeader(usize), - ReadingBody(usize), -} - impl FramedReader { pub fn new(reader: R, max_packet_size: usize) -> Self { Self::new_with_associate_data(reader, max_packet_size, None) @@ -100,9 +94,9 @@ impl FramedReader { FramedReader { reader, buf: BytesMut::with_capacity(max_packet_size), - state: FrameReaderState::ReadingHeader(4), max_packet_size, associate_data, + error: None, } } @@ -146,9 +140,19 @@ where let mut self_mut = self.project(); loop { + if let Some(e) = self_mut.error.as_ref() { + tracing::warn!("poll_next on a failed FramedReader, {:?}", e); + return Poll::Ready(None); + } + while let Some(packet) = Self::extract_one_packet(self_mut.buf, *self_mut.max_packet_size) { + if let Err(TunnelError::InvalidPacket(msg)) = packet.as_ref() { + self_mut + .error + .replace(TunnelError::InvalidPacket(msg.clone())); + } return Poll::Ready(Some(packet)); }