Refactor error handling to use bail! for cleaner error propagation and add shutdown signal handling in client and router

This commit is contained in:
2026-03-03 10:19:37 +01:00
parent 3113b04294
commit 479270834d
3 changed files with 118 additions and 60 deletions

View File

@@ -1,4 +1,4 @@
use anyhow::Result;
use anyhow::{Result, bail};
use chrono::Utc;
use clap::Args;
use ipnet::Ipv4Net;
@@ -74,6 +74,13 @@ pub async fn start(config: ClientCfg) -> Result<()> {
match msg {
Ok((message, data)) => {
match message {
RouterMessages::Quit(msg) =>{
println!("Received quit message from server {}", msg);
println!("Closing connection with Server");
client_stream.close().await?;
println!("Client shutdown complete");
std::process::exit(0);
}
RouterMessages::KeepAlive(timestamp) => {
println!("Received keep-alive message from router with timestamp: {}, delta {} ms", timestamp, (Utc::now().timestamp_micros() - timestamp).abs() as f64 / 1000.0);
}
@@ -115,7 +122,7 @@ pub async fn start(config: ClientCfg) -> Result<()> {
Err(e) => {
eprintln!("Error reading from router: {}", e);
client_stream.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from router: {}", e)));
bail!(format!("Error reading from router: {}", e));
}
}
}
@@ -145,10 +152,19 @@ pub async fn start(config: ClientCfg) -> Result<()> {
Err(e) => {
eprintln!("Error reading from TUN interface: {}", e);
client_stream.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from TUN interface: {}", e)));
bail!(format!("Error reading from TUN interface: {}", e));
}
}
}
_ = tokio::signal::ctrl_c() => {
let msg = "Received shutdown signal, shutting down...";
println!("\n{}\n", msg);
client_stream.send(RouterMessages::Quit(msg.into()), None).await?;
client_stream.close().await?;
println!("Client shutdown complete");
std::process::exit(0);
}
}
}
@@ -169,27 +185,29 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
match msg {
Ok((message, _data)) => {
match message {
RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => {
RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => {
println!("Received registration response from router: {:?}", message);
println!("Client registration successful with UUID: {}", uuid);
return Ok(uuid);
}
RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => {
eprintln!("Client registration failed: {}", err_msg);
return Err(anyhow::anyhow!(format!("Client registration failed: {}", err_msg)));
bail!("Client registration failed: {}", err_msg);
}
_ => {
let msg = "Unexpected message type received during client registration.";
eprintln!("{}", msg);
client_stream.close().await?;
return Err(anyhow::anyhow!(msg));
bail!(msg);
}
}
}
Err(e) => {
eprintln!("Error reading from router during client registration: {}", e);
client_stream.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from router: {}", e)));
bail!("Error reading from router: {}", e);
}
}
}
@@ -198,9 +216,8 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
eprintln!("{}", msg);
eprintln!("Closing connection with Server");
client_stream.close().await?;
return Err(anyhow::anyhow!(msg));
bail!(msg);
}
}
}
// Ok(())

View File

@@ -11,7 +11,7 @@ use tokio::{
TcpStream,
tcp::{OwnedReadHalf, OwnedWriteHalf},
},
sync::{Mutex, Notify, RwLock},
sync::{Mutex, RwLock},
time::Instant,
};
use uuid::Uuid;
@@ -19,7 +19,7 @@ use uuid::Uuid;
use crate::client::ClientCfg;
pub static KEEP_ALIVE_INTERVAL: Duration = tokio::time::Duration::from_secs(30);
pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(100);
pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(900);
pub static TCP_NODELAY: bool = true;
pub const SERVER_PACKET_SIZE: usize = 1024 * 9;
@@ -177,6 +177,10 @@ impl ClientStream {
let payload: Option<Vec<u8>>;
match rx.read_exact(&mut buf[..6]).await {
Ok(0) => {
bail!("Client closed the connection");
}
Ok(_) => {
let total_len = u16::from_be_bytes(buf[0..2].try_into().unwrap()) as usize;
let msg_len = u16::from_be_bytes(buf[2..4].try_into().unwrap()) as usize;
@@ -209,8 +213,7 @@ impl ClientStream {
}
}
Err(e) => {
eprintln!("Error reading from client stream: {}", e);
return Err(anyhow::anyhow!(format!("Error reading from client stream: {}", e)));
bail!(format!("{:#}", e));
}
};
@@ -228,28 +231,28 @@ impl ClientStream {
pub struct Router {
clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>,
routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>,
notify: Arc<Notify>,
}
impl Router {
pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> {
let id = vpn_client.id();
// Append local interface IP to client's local routes for routing table construction
let mut client_local_route = client.local_routes;
client_local_route.push(client.interface_ip);
self.routing_table.write().await.insert(client.interface_ip, id);
// Add Client local routes to global routing table
for net in client_local_route {
self.routing_table.write().await.insert(net, id);
}
self.clients.write().await.insert(id, vpn_client);
self.notify.notify_waiters();
Ok(())
self.update_all_clients().await
}
pub async fn remove_client(&self, client_id: Uuid) -> Result<()> {
self.clients.write().await.remove(&client_id);
self.routing_table.write().await.retain(|_, &mut id| id != client_id);
self.notify.notify_waiters();
Ok(())
self.update_all_clients().await
}
pub async fn get_routing_table(&self) -> HashMap<Ipv4Net, Uuid> {
@@ -260,34 +263,53 @@ impl Router {
self.clients.read().await.get(&uuid).cloned()
}
pub async fn notify_changes(&self, uuid: Uuid) {
self.notify.notified().await;
match self.get_client(uuid).await {
Some(client) => {
println!("Notifying client {} of routing table change", client.id());
if let Err(e) = client
.send(RouterMessages::RouteUpdate(self.get_routing_table().await), None)
.await
{
eprintln!("Error notifying client {}: {}", client.id(), e);
}
}
None => {
eprintln!(
"Client {} not found while trying to notify of routing table change",
uuid
);
pub async fn update_all_clients(&self) -> Result<()> {
// Lock only the routing table and clients list once to get the current state,
// then release locks before sending updates to avoid blocking other operations
// while waiting for network I/O
let routing_table = self.get_routing_table().await;
// Clone the clients list to avoid holding the lock while sending updates
let clients = self.clients.read().await.clone();
// Routes update task generator
let mut update_tasks = Vec::new();
for (id, client) in clients {
let routing_table = routing_table.clone();
let update_task = tokio::spawn(async move {
let result = client.send(RouterMessages::RouteUpdate(routing_table), None).await;
(id, result)
});
update_tasks.push(update_task);
}
// Parallelize sending updates to all clients and wait for all tasks to complete
for task in update_tasks {
let (id, result) = task.await?;
if let Err(e) = result {
eprintln!("Error sending routing table update to client {}: {}", id, e);
} else {
println!("Successfully sent routing table update to client {}", id);
}
}
// for client in self.clients.read().await.values() {
// println!("Notifying client {} of routing table change", client.id());
// if let Err(e) = client
// .send(RouterMessages::RouteUpdate(self.get_routing_table().await), None)
// .await
// {
// eprintln!("Error notifying client {}: {}", client.id(), e);
// }
// }
Ok(())
}
pub async fn close_all_client(&self, message: impl AsRef<str>) -> Result<()> {
let clients = self.clients.read().await.clone();
for (id, client) in clients {
println!("Closing connection with client {}...", id);
client
.send(RouterMessages::Quit(String::from(message.as_ref().to_string())), None)
.await?;
if let Err(e) = client.close().await {
eprintln!("Error closing connection with client {}: {}", id, e);
} else {
println!("Successfully closed connection with client {}", id);
}
}
Ok(())
}
}
@@ -299,6 +321,21 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> {
let socket = tokio::net::TcpListener::bind(bind_address).await?;
println!("Router is listening on {}...", bind_address);
tokio::select! {
_ = handle_router_connections(router.clone(), socket) => {}
_ = tokio::signal::ctrl_c() => {
println!();
let msg = "Received shutdown signal, shutting down router...";
println!("{}", msg);
router.close_all_client(msg).await?
}
}
Ok(())
}
async fn handle_router_connections(router: Router, socket: tokio::net::TcpListener) -> ! {
loop {
let router = router.clone();
match socket.accept().await {
@@ -318,7 +355,6 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> {
}
}
}
// Ok(())
}
pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
@@ -339,9 +375,6 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
);
println!("Registering client {} with router...", vpn_client.id());
router.register_client(client, vpn_client.clone()).await?;
vpn_client
.send(RouterMessages::RouteUpdate(router.get_routing_table().await), None)
.await?;
}
Err(e) => {
eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e);
@@ -356,6 +389,13 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
Ok((message, data)) => {
println!("Received message from client {}: {:?}", vpn_client.id(), message);
match message {
RouterMessages::Quit(msg) =>{
println!("Received quit message from client {}: {}", vpn_client.id(), msg);
println!("Removing client {} from router...", vpn_client.id());
router.remove_client(vpn_client.id()).await?;
vpn_client.close().await?;
return Ok(());
}
RouterMessages::Data(packet_info) => {
if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await {
println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id());
@@ -374,20 +414,21 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
println!("Removing client {} from router...", vpn_client.id());
router.remove_client(vpn_client.id()).await?;
vpn_client.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from client: {}", e)));
bail!(format!("Error reading from client: {}", e));
}
}
}
_= router.notify_changes(vpn_client.id()) => {
println!("Routing table updated. Current routing table:");
}
// _= router.notify_changes(vpn_client.id()) => {
// println!("Routing table updated. Current routing table:");
// }
_= keep_alive_tick.tick() => {
// Send keep-alive message to the client
vpn_client.send(RouterMessages::KeepAlive(Utc::now().timestamp_micros()), None).await?;
}
}
}
@@ -414,7 +455,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
eprintln!("{}", msg);
eprintln!("Closing connection with client {}", vpn_client.id());
vpn_client.close().await?;
return Err(anyhow::anyhow!(msg));
bail!(msg);
}
}
@@ -422,7 +463,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
Err(e) => {
eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e);
vpn_client.close().await?;
return Err(anyhow::anyhow!(format!("Error reading from client during registration: {}", e)));
bail!(format!("Error reading from client during registration: {}", e));
}
}
@@ -433,7 +474,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis()));
vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?;
vpn_client.close().await?;
return Err(anyhow::anyhow!(msg));
bail!(msg);
}
}

View File

@@ -1,4 +1,4 @@
use anyhow::{Context, Result};
use anyhow::{Context, Result, bail};
use ipnet::Ipv4Net;
use tokio::process::Command;
use tun_rs::{AsyncDevice, DeviceBuilder};
@@ -23,7 +23,7 @@ pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> {
Err(e) => {
let msg = format!("Failed to create TUN interface: {:#}", e);
eprintln!("{}", msg);
return Err(anyhow::anyhow!(msg));
bail!(msg);
}
};
@@ -45,7 +45,7 @@ pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
let stderr = String::from_utf8_lossy(&output.stderr);
let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr);
eprintln!("{}", msg);
return Err(anyhow::anyhow!(msg));
bail!(msg);
}
Ok(())
@@ -66,7 +66,7 @@ pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
let stderr = String::from_utf8_lossy(&output.stderr);
let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr);
eprintln!("{}", msg);
return Err(anyhow::anyhow!(msg));
bail!(msg);
}
Ok(())