Refactor error handling to use bail! for cleaner error propagation and add shutdown signal handling in client and router

This commit is contained in:
2026-03-03 10:19:37 +01:00
parent 3113b04294
commit 479270834d
3 changed files with 118 additions and 60 deletions

View File

@@ -1,4 +1,4 @@
use anyhow::Result; use anyhow::{Result, bail};
use chrono::Utc; use chrono::Utc;
use clap::Args; use clap::Args;
use ipnet::Ipv4Net; use ipnet::Ipv4Net;
@@ -74,6 +74,13 @@ pub async fn start(config: ClientCfg) -> Result<()> {
match msg { match msg {
Ok((message, data)) => { Ok((message, data)) => {
match message { match message {
RouterMessages::Quit(msg) =>{
println!("Received quit message from server {}", msg);
println!("Closing connection with Server");
client_stream.close().await?;
println!("Client shutdown complete");
std::process::exit(0);
}
RouterMessages::KeepAlive(timestamp) => { RouterMessages::KeepAlive(timestamp) => {
println!("Received keep-alive message from router with timestamp: {}, delta {} ms", timestamp, (Utc::now().timestamp_micros() - timestamp).abs() as f64 / 1000.0); println!("Received keep-alive message from router with timestamp: {}, delta {} ms", timestamp, (Utc::now().timestamp_micros() - timestamp).abs() as f64 / 1000.0);
} }
@@ -115,7 +122,7 @@ pub async fn start(config: ClientCfg) -> Result<()> {
Err(e) => { Err(e) => {
eprintln!("Error reading from router: {}", e); eprintln!("Error reading from router: {}", e);
client_stream.close().await?; client_stream.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from router: {}", e))); bail!(format!("Error reading from router: {}", e));
} }
} }
} }
@@ -145,10 +152,19 @@ pub async fn start(config: ClientCfg) -> Result<()> {
Err(e) => { Err(e) => {
eprintln!("Error reading from TUN interface: {}", e); eprintln!("Error reading from TUN interface: {}", e);
client_stream.close().await?; client_stream.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from TUN interface: {}", e))); bail!(format!("Error reading from TUN interface: {}", e));
} }
} }
} }
_ = tokio::signal::ctrl_c() => {
let msg = "Received shutdown signal, shutting down...";
println!("\n{}\n", msg);
client_stream.send(RouterMessages::Quit(msg.into()), None).await?;
client_stream.close().await?;
println!("Client shutdown complete");
std::process::exit(0);
}
} }
} }
@@ -169,27 +185,29 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
match msg { match msg {
Ok((message, _data)) => { Ok((message, _data)) => {
match message { match message {
RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => { RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => {
println!("Received registration response from router: {:?}", message); println!("Received registration response from router: {:?}", message);
println!("Client registration successful with UUID: {}", uuid); println!("Client registration successful with UUID: {}", uuid);
return Ok(uuid); return Ok(uuid);
} }
RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => { RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => {
eprintln!("Client registration failed: {}", err_msg); eprintln!("Client registration failed: {}", err_msg);
return Err(anyhow::anyhow!(format!("Client registration failed: {}", err_msg))); bail!("Client registration failed: {}", err_msg);
} }
_ => { _ => {
let msg = "Unexpected message type received during client registration."; let msg = "Unexpected message type received during client registration.";
eprintln!("{}", msg); eprintln!("{}", msg);
client_stream.close().await?; client_stream.close().await?;
return Err(anyhow::anyhow!(msg)); bail!(msg);
} }
} }
} }
Err(e) => { Err(e) => {
eprintln!("Error reading from router during client registration: {}", e); eprintln!("Error reading from router during client registration: {}", e);
client_stream.close().await?; client_stream.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from router: {}", e))); bail!("Error reading from router: {}", e);
} }
} }
} }
@@ -198,9 +216,8 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
eprintln!("{}", msg); eprintln!("{}", msg);
eprintln!("Closing connection with Server"); eprintln!("Closing connection with Server");
client_stream.close().await?; client_stream.close().await?;
return Err(anyhow::anyhow!(msg)); bail!(msg);
} }
} }
} }
// Ok(()) // Ok(())

View File

@@ -11,7 +11,7 @@ use tokio::{
TcpStream, TcpStream,
tcp::{OwnedReadHalf, OwnedWriteHalf}, tcp::{OwnedReadHalf, OwnedWriteHalf},
}, },
sync::{Mutex, Notify, RwLock}, sync::{Mutex, RwLock},
time::Instant, time::Instant,
}; };
use uuid::Uuid; use uuid::Uuid;
@@ -19,7 +19,7 @@ use uuid::Uuid;
use crate::client::ClientCfg; use crate::client::ClientCfg;
pub static KEEP_ALIVE_INTERVAL: Duration = tokio::time::Duration::from_secs(30); pub static KEEP_ALIVE_INTERVAL: Duration = tokio::time::Duration::from_secs(30);
pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(100); pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(900);
pub static TCP_NODELAY: bool = true; pub static TCP_NODELAY: bool = true;
pub const SERVER_PACKET_SIZE: usize = 1024 * 9; pub const SERVER_PACKET_SIZE: usize = 1024 * 9;
@@ -177,6 +177,10 @@ impl ClientStream {
let payload: Option<Vec<u8>>; let payload: Option<Vec<u8>>;
match rx.read_exact(&mut buf[..6]).await { match rx.read_exact(&mut buf[..6]).await {
Ok(0) => {
bail!("Client closed the connection");
}
Ok(_) => { Ok(_) => {
let total_len = u16::from_be_bytes(buf[0..2].try_into().unwrap()) as usize; let total_len = u16::from_be_bytes(buf[0..2].try_into().unwrap()) as usize;
let msg_len = u16::from_be_bytes(buf[2..4].try_into().unwrap()) as usize; let msg_len = u16::from_be_bytes(buf[2..4].try_into().unwrap()) as usize;
@@ -209,8 +213,7 @@ impl ClientStream {
} }
} }
Err(e) => { Err(e) => {
eprintln!("Error reading from client stream: {}", e); bail!(format!("{:#}", e));
return Err(anyhow::anyhow!(format!("Error reading from client stream: {}", e)));
} }
}; };
@@ -228,28 +231,28 @@ impl ClientStream {
pub struct Router { pub struct Router {
clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>, clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>,
routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>, routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>,
notify: Arc<Notify>,
} }
impl Router { impl Router {
pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> { pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> {
let id = vpn_client.id(); let id = vpn_client.id();
// Append local interface IP to client's local routes for routing table construction
let mut client_local_route = client.local_routes; let mut client_local_route = client.local_routes;
client_local_route.push(client.interface_ip); client_local_route.push(client.interface_ip);
self.routing_table.write().await.insert(client.interface_ip, id);
// Add Client local routes to global routing table
for net in client_local_route { for net in client_local_route {
self.routing_table.write().await.insert(net, id); self.routing_table.write().await.insert(net, id);
} }
self.clients.write().await.insert(id, vpn_client); self.clients.write().await.insert(id, vpn_client);
self.notify.notify_waiters(); self.update_all_clients().await
Ok(())
} }
pub async fn remove_client(&self, client_id: Uuid) -> Result<()> { pub async fn remove_client(&self, client_id: Uuid) -> Result<()> {
self.clients.write().await.remove(&client_id); self.clients.write().await.remove(&client_id);
self.routing_table.write().await.retain(|_, &mut id| id != client_id); self.routing_table.write().await.retain(|_, &mut id| id != client_id);
self.notify.notify_waiters(); self.update_all_clients().await
Ok(())
} }
pub async fn get_routing_table(&self) -> HashMap<Ipv4Net, Uuid> { pub async fn get_routing_table(&self) -> HashMap<Ipv4Net, Uuid> {
@@ -260,34 +263,53 @@ impl Router {
self.clients.read().await.get(&uuid).cloned() self.clients.read().await.get(&uuid).cloned()
} }
pub async fn notify_changes(&self, uuid: Uuid) { pub async fn update_all_clients(&self) -> Result<()> {
self.notify.notified().await; // Lock only the routing table and clients list once to get the current state,
match self.get_client(uuid).await { // then release locks before sending updates to avoid blocking other operations
Some(client) => { // while waiting for network I/O
println!("Notifying client {} of routing table change", client.id()); let routing_table = self.get_routing_table().await;
if let Err(e) = client
.send(RouterMessages::RouteUpdate(self.get_routing_table().await), None) // Clone the clients list to avoid holding the lock while sending updates
.await let clients = self.clients.read().await.clone();
{
eprintln!("Error notifying client {}: {}", client.id(), e); // Routes update task generator
} let mut update_tasks = Vec::new();
} for (id, client) in clients {
None => { let routing_table = routing_table.clone();
eprintln!( let update_task = tokio::spawn(async move {
"Client {} not found while trying to notify of routing table change", let result = client.send(RouterMessages::RouteUpdate(routing_table), None).await;
uuid (id, result)
); });
update_tasks.push(update_task);
}
// Parallelize sending updates to all clients and wait for all tasks to complete
for task in update_tasks {
let (id, result) = task.await?;
if let Err(e) = result {
eprintln!("Error sending routing table update to client {}: {}", id, e);
} else {
println!("Successfully sent routing table update to client {}", id);
} }
} }
// for client in self.clients.read().await.values() {
// println!("Notifying client {} of routing table change", client.id()); Ok(())
// if let Err(e) = client }
// .send(RouterMessages::RouteUpdate(self.get_routing_table().await), None)
// .await pub async fn close_all_client(&self, message: impl AsRef<str>) -> Result<()> {
// { let clients = self.clients.read().await.clone();
// eprintln!("Error notifying client {}: {}", client.id(), e); for (id, client) in clients {
// } println!("Closing connection with client {}...", id);
// } client
.send(RouterMessages::Quit(String::from(message.as_ref().to_string())), None)
.await?;
if let Err(e) = client.close().await {
eprintln!("Error closing connection with client {}: {}", id, e);
} else {
println!("Successfully closed connection with client {}", id);
}
}
Ok(())
} }
} }
@@ -299,6 +321,21 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> {
let socket = tokio::net::TcpListener::bind(bind_address).await?; let socket = tokio::net::TcpListener::bind(bind_address).await?;
println!("Router is listening on {}...", bind_address); println!("Router is listening on {}...", bind_address);
tokio::select! {
_ = handle_router_connections(router.clone(), socket) => {}
_ = tokio::signal::ctrl_c() => {
println!();
let msg = "Received shutdown signal, shutting down router...";
println!("{}", msg);
router.close_all_client(msg).await?
}
}
Ok(())
}
async fn handle_router_connections(router: Router, socket: tokio::net::TcpListener) -> ! {
loop { loop {
let router = router.clone(); let router = router.clone();
match socket.accept().await { match socket.accept().await {
@@ -318,7 +355,6 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> {
} }
} }
} }
// Ok(())
} }
pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
@@ -339,9 +375,6 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
); );
println!("Registering client {} with router...", vpn_client.id()); println!("Registering client {} with router...", vpn_client.id());
router.register_client(client, vpn_client.clone()).await?; router.register_client(client, vpn_client.clone()).await?;
vpn_client
.send(RouterMessages::RouteUpdate(router.get_routing_table().await), None)
.await?;
} }
Err(e) => { Err(e) => {
eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e); eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e);
@@ -356,6 +389,13 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
Ok((message, data)) => { Ok((message, data)) => {
println!("Received message from client {}: {:?}", vpn_client.id(), message); println!("Received message from client {}: {:?}", vpn_client.id(), message);
match message { match message {
RouterMessages::Quit(msg) =>{
println!("Received quit message from client {}: {}", vpn_client.id(), msg);
println!("Removing client {} from router...", vpn_client.id());
router.remove_client(vpn_client.id()).await?;
vpn_client.close().await?;
return Ok(());
}
RouterMessages::Data(packet_info) => { RouterMessages::Data(packet_info) => {
if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await { if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await {
println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id()); println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id());
@@ -374,20 +414,21 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
println!("Removing client {} from router...", vpn_client.id()); println!("Removing client {} from router...", vpn_client.id());
router.remove_client(vpn_client.id()).await?; router.remove_client(vpn_client.id()).await?;
vpn_client.close().await?; vpn_client.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from client: {}", e))); bail!(format!("Error reading from client: {}", e));
} }
} }
} }
_= router.notify_changes(vpn_client.id()) => { // _= router.notify_changes(vpn_client.id()) => {
println!("Routing table updated. Current routing table:"); // println!("Routing table updated. Current routing table:");
} // }
_= keep_alive_tick.tick() => { _= keep_alive_tick.tick() => {
// Send keep-alive message to the client // Send keep-alive message to the client
vpn_client.send(RouterMessages::KeepAlive(Utc::now().timestamp_micros()), None).await?; vpn_client.send(RouterMessages::KeepAlive(Utc::now().timestamp_micros()), None).await?;
} }
} }
} }
@@ -414,7 +455,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
eprintln!("{}", msg); eprintln!("{}", msg);
eprintln!("Closing connection with client {}", vpn_client.id()); eprintln!("Closing connection with client {}", vpn_client.id());
vpn_client.close().await?; vpn_client.close().await?;
return Err(anyhow::anyhow!(msg)); bail!(msg);
} }
} }
@@ -422,7 +463,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
Err(e) => { Err(e) => {
eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e); eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e);
vpn_client.close().await?; vpn_client.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from client during registration: {}", e))); bail!(format!("Error reading from client during registration: {}", e));
} }
} }
@@ -433,7 +474,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis())); let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis()));
vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?; vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?;
vpn_client.close().await?; vpn_client.close().await?;
return Err(anyhow::anyhow!(msg)); bail!(msg);
} }
} }

View File

@@ -1,4 +1,4 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result, bail};
use ipnet::Ipv4Net; use ipnet::Ipv4Net;
use tokio::process::Command; use tokio::process::Command;
use tun_rs::{AsyncDevice, DeviceBuilder}; use tun_rs::{AsyncDevice, DeviceBuilder};
@@ -23,7 +23,7 @@ pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> {
Err(e) => { Err(e) => {
let msg = format!("Failed to create TUN interface: {:#}", e); let msg = format!("Failed to create TUN interface: {:#}", e);
eprintln!("{}", msg); eprintln!("{}", msg);
return Err(anyhow::anyhow!(msg)); bail!(msg);
} }
}; };
@@ -45,7 +45,7 @@ pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr); let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr);
eprintln!("{}", msg); eprintln!("{}", msg);
return Err(anyhow::anyhow!(msg)); bail!(msg);
} }
Ok(()) Ok(())
@@ -66,7 +66,7 @@ pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr); let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr);
eprintln!("{}", msg); eprintln!("{}", msg);
return Err(anyhow::anyhow!(msg)); bail!(msg);
} }
Ok(()) Ok(())