From 479270834d235923733ef8c843f07203a67b1985 Mon Sep 17 00:00:00 2001 From: icsboyx Date: Tue, 3 Mar 2026 10:19:37 +0100 Subject: [PATCH] Refactor error handling to use `bail!` for cleaner error propagation and add shutdown signal handling in client and router --- src/client.rs | 35 +++++++++---- src/router.rs | 135 ++++++++++++++++++++++++++++++++------------------ src/tun.rs | 8 +-- 3 files changed, 118 insertions(+), 60 deletions(-) diff --git a/src/client.rs b/src/client.rs index 4875b62..eb1a5f6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{Result, bail}; use chrono::Utc; use clap::Args; use ipnet::Ipv4Net; @@ -74,6 +74,13 @@ pub async fn start(config: ClientCfg) -> Result<()> { match msg { Ok((message, data)) => { match message { + RouterMessages::Quit(msg) =>{ + println!("Received quit message from server {}", msg); + println!("Closing connection with Server"); + client_stream.close().await?; + println!("Client shutdown complete"); + std::process::exit(0); + } RouterMessages::KeepAlive(timestamp) => { println!("Received keep-alive message from router with timestamp: {}, delta {} ms", timestamp, (Utc::now().timestamp_micros() - timestamp).abs() as f64 / 1000.0); } @@ -115,7 +122,7 @@ pub async fn start(config: ClientCfg) -> Result<()> { Err(e) => { eprintln!("Error reading from router: {}", e); client_stream.close().await?; - return Err(anyhow::anyhow!(format!("Error reading from router: {}", e))); + bail!(format!("Error reading from router: {}", e)); } } } @@ -145,10 +152,19 @@ pub async fn start(config: ClientCfg) -> Result<()> { Err(e) => { eprintln!("Error reading from TUN interface: {}", e); client_stream.close().await?; - return Err(anyhow::anyhow!(format!("Error reading from TUN interface: {}", e))); + bail!(format!("Error reading from TUN interface: {}", e)); } } } + + _ = tokio::signal::ctrl_c() => { + let msg = "Received shutdown signal, shutting down..."; + println!("\n{}\n", msg); + client_stream.send(RouterMessages::Quit(msg.into()), None).await?; + client_stream.close().await?; + println!("Client shutdown complete"); + std::process::exit(0); + } } } @@ -169,27 +185,29 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b match msg { Ok((message, _data)) => { match message { - RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => { + RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => { println!("Received registration response from router: {:?}", message); println!("Client registration successful with UUID: {}", uuid); return Ok(uuid); } + RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => { eprintln!("Client registration failed: {}", err_msg); - return Err(anyhow::anyhow!(format!("Client registration failed: {}", err_msg))); + bail!("Client registration failed: {}", err_msg); } + _ => { let msg = "Unexpected message type received during client registration."; eprintln!("{}", msg); client_stream.close().await?; - return Err(anyhow::anyhow!(msg)); + bail!(msg); } } } Err(e) => { eprintln!("Error reading from router during client registration: {}", e); client_stream.close().await?; - return Err(anyhow::anyhow!(format!("Error reading from router: {}", e))); + bail!("Error reading from router: {}", e); } } } @@ -198,9 +216,8 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b eprintln!("{}", msg); eprintln!("Closing connection with Server"); client_stream.close().await?; - return Err(anyhow::anyhow!(msg)); + bail!(msg); } - } } // Ok(()) diff --git a/src/router.rs b/src/router.rs index 5b34b83..a394152 100644 --- a/src/router.rs +++ b/src/router.rs @@ -11,7 +11,7 @@ use tokio::{ TcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}, }, - sync::{Mutex, Notify, RwLock}, + sync::{Mutex, RwLock}, time::Instant, }; use uuid::Uuid; @@ -19,7 +19,7 @@ use uuid::Uuid; use crate::client::ClientCfg; pub static KEEP_ALIVE_INTERVAL: Duration = tokio::time::Duration::from_secs(30); -pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(100); +pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(900); pub static TCP_NODELAY: bool = true; pub const SERVER_PACKET_SIZE: usize = 1024 * 9; @@ -177,6 +177,10 @@ impl ClientStream { let payload: Option>; match rx.read_exact(&mut buf[..6]).await { + Ok(0) => { + bail!("Client closed the connection"); + } + Ok(_) => { let total_len = u16::from_be_bytes(buf[0..2].try_into().unwrap()) as usize; let msg_len = u16::from_be_bytes(buf[2..4].try_into().unwrap()) as usize; @@ -209,8 +213,7 @@ impl ClientStream { } } Err(e) => { - eprintln!("Error reading from client stream: {}", e); - return Err(anyhow::anyhow!(format!("Error reading from client stream: {}", e))); + bail!(format!("{:#}", e)); } }; @@ -228,28 +231,28 @@ impl ClientStream { pub struct Router { clients: Arc>>, routing_table: Arc>>, - notify: Arc, } impl Router { pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> { let id = vpn_client.id(); + + // Append local interface IP to client's local routes for routing table construction let mut client_local_route = client.local_routes; client_local_route.push(client.interface_ip); - self.routing_table.write().await.insert(client.interface_ip, id); + + // Add Client local routes to global routing table for net in client_local_route { self.routing_table.write().await.insert(net, id); } self.clients.write().await.insert(id, vpn_client); - self.notify.notify_waiters(); - Ok(()) + self.update_all_clients().await } pub async fn remove_client(&self, client_id: Uuid) -> Result<()> { self.clients.write().await.remove(&client_id); self.routing_table.write().await.retain(|_, &mut id| id != client_id); - self.notify.notify_waiters(); - Ok(()) + self.update_all_clients().await } pub async fn get_routing_table(&self) -> HashMap { @@ -260,34 +263,53 @@ impl Router { self.clients.read().await.get(&uuid).cloned() } - pub async fn notify_changes(&self, uuid: Uuid) { - self.notify.notified().await; - match self.get_client(uuid).await { - Some(client) => { - println!("Notifying client {} of routing table change", client.id()); - if let Err(e) = client - .send(RouterMessages::RouteUpdate(self.get_routing_table().await), None) - .await - { - eprintln!("Error notifying client {}: {}", client.id(), e); - } - } - None => { - eprintln!( - "Client {} not found while trying to notify of routing table change", - uuid - ); + pub async fn update_all_clients(&self) -> Result<()> { + // Lock only the routing table and clients list once to get the current state, + // then release locks before sending updates to avoid blocking other operations + // while waiting for network I/O + let routing_table = self.get_routing_table().await; + + // Clone the clients list to avoid holding the lock while sending updates + let clients = self.clients.read().await.clone(); + + // Routes update task generator + let mut update_tasks = Vec::new(); + for (id, client) in clients { + let routing_table = routing_table.clone(); + let update_task = tokio::spawn(async move { + let result = client.send(RouterMessages::RouteUpdate(routing_table), None).await; + (id, result) + }); + update_tasks.push(update_task); + } + + // Parallelize sending updates to all clients and wait for all tasks to complete + for task in update_tasks { + let (id, result) = task.await?; + if let Err(e) = result { + eprintln!("Error sending routing table update to client {}: {}", id, e); + } else { + println!("Successfully sent routing table update to client {}", id); } } - // for client in self.clients.read().await.values() { - // println!("Notifying client {} of routing table change", client.id()); - // if let Err(e) = client - // .send(RouterMessages::RouteUpdate(self.get_routing_table().await), None) - // .await - // { - // eprintln!("Error notifying client {}: {}", client.id(), e); - // } - // } + + Ok(()) + } + + pub async fn close_all_client(&self, message: impl AsRef) -> Result<()> { + let clients = self.clients.read().await.clone(); + for (id, client) in clients { + println!("Closing connection with client {}...", id); + client + .send(RouterMessages::Quit(String::from(message.as_ref().to_string())), None) + .await?; + if let Err(e) = client.close().await { + eprintln!("Error closing connection with client {}: {}", id, e); + } else { + println!("Successfully closed connection with client {}", id); + } + } + Ok(()) } } @@ -299,6 +321,21 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> { let socket = tokio::net::TcpListener::bind(bind_address).await?; println!("Router is listening on {}...", bind_address); + tokio::select! { + _ = handle_router_connections(router.clone(), socket) => {} + + _ = tokio::signal::ctrl_c() => { + println!(); + let msg = "Received shutdown signal, shutting down router..."; + println!("{}", msg); + router.close_all_client(msg).await? + } + } + + Ok(()) +} + +async fn handle_router_connections(router: Router, socket: tokio::net::TcpListener) -> ! { loop { let router = router.clone(); match socket.accept().await { @@ -318,7 +355,6 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> { } } } - // Ok(()) } pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { @@ -339,9 +375,6 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { ); println!("Registering client {} with router...", vpn_client.id()); router.register_client(client, vpn_client.clone()).await?; - vpn_client - .send(RouterMessages::RouteUpdate(router.get_routing_table().await), None) - .await?; } Err(e) => { eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e); @@ -356,6 +389,13 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { Ok((message, data)) => { println!("Received message from client {}: {:?}", vpn_client.id(), message); match message { + RouterMessages::Quit(msg) =>{ + println!("Received quit message from client {}: {}", vpn_client.id(), msg); + println!("Removing client {} from router...", vpn_client.id()); + router.remove_client(vpn_client.id()).await?; + vpn_client.close().await?; + return Ok(()); + } RouterMessages::Data(packet_info) => { if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await { println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id()); @@ -374,20 +414,21 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { println!("Removing client {} from router...", vpn_client.id()); router.remove_client(vpn_client.id()).await?; vpn_client.close().await?; - return Err(anyhow::anyhow!(format!("Error reading from client: {}", e))); + bail!(format!("Error reading from client: {}", e)); } } } - _= router.notify_changes(vpn_client.id()) => { - println!("Routing table updated. Current routing table:"); - } + // _= router.notify_changes(vpn_client.id()) => { + // println!("Routing table updated. Current routing table:"); + // } _= keep_alive_tick.tick() => { // Send keep-alive message to the client vpn_client.send(RouterMessages::KeepAlive(Utc::now().timestamp_micros()), None).await?; } + } } @@ -414,7 +455,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result Result { eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e); vpn_client.close().await?; - return Err(anyhow::anyhow!(format!("Error reading from client during registration: {}", e))); + bail!(format!("Error reading from client during registration: {}", e)); } } @@ -433,7 +474,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result Result { Err(e) => { let msg = format!("Failed to create TUN interface: {:#}", e); eprintln!("{}", msg); - return Err(anyhow::anyhow!(msg)); + bail!(msg); } }; @@ -45,7 +45,7 @@ pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> { let stderr = String::from_utf8_lossy(&output.stderr); let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr); eprintln!("{}", msg); - return Err(anyhow::anyhow!(msg)); + bail!(msg); } Ok(()) @@ -66,7 +66,7 @@ pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> { let stderr = String::from_utf8_lossy(&output.stderr); let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr); eprintln!("{}", msg); - return Err(anyhow::anyhow!(msg)); + bail!(msg); } Ok(())