Refactor error handling to use bail! for cleaner error propagation and add shutdown signal handling in client and router
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Result, bail};
|
||||
use chrono::Utc;
|
||||
use clap::Args;
|
||||
use ipnet::Ipv4Net;
|
||||
@@ -74,6 +74,13 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
||||
match msg {
|
||||
Ok((message, data)) => {
|
||||
match message {
|
||||
RouterMessages::Quit(msg) =>{
|
||||
println!("Received quit message from server {}", msg);
|
||||
println!("Closing connection with Server");
|
||||
client_stream.close().await?;
|
||||
println!("Client shutdown complete");
|
||||
std::process::exit(0);
|
||||
}
|
||||
RouterMessages::KeepAlive(timestamp) => {
|
||||
println!("Received keep-alive message from router with timestamp: {}, delta {} ms", timestamp, (Utc::now().timestamp_micros() - timestamp).abs() as f64 / 1000.0);
|
||||
}
|
||||
@@ -115,7 +122,7 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from router: {}", e);
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from router: {}", e)));
|
||||
bail!(format!("Error reading from router: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,10 +152,19 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from TUN interface: {}", e);
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from TUN interface: {}", e)));
|
||||
bail!(format!("Error reading from TUN interface: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
let msg = "Received shutdown signal, shutting down...";
|
||||
println!("\n{}\n", msg);
|
||||
client_stream.send(RouterMessages::Quit(msg.into()), None).await?;
|
||||
client_stream.close().await?;
|
||||
println!("Client shutdown complete");
|
||||
std::process::exit(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,27 +185,29 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
|
||||
match msg {
|
||||
Ok((message, _data)) => {
|
||||
match message {
|
||||
RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => {
|
||||
RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => {
|
||||
println!("Received registration response from router: {:?}", message);
|
||||
println!("Client registration successful with UUID: {}", uuid);
|
||||
return Ok(uuid);
|
||||
}
|
||||
|
||||
RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => {
|
||||
eprintln!("Client registration failed: {}", err_msg);
|
||||
return Err(anyhow::anyhow!(format!("Client registration failed: {}", err_msg)));
|
||||
bail!("Client registration failed: {}", err_msg);
|
||||
}
|
||||
|
||||
_ => {
|
||||
let msg = "Unexpected message type received during client registration.";
|
||||
eprintln!("{}", msg);
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
bail!(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from router during client registration: {}", e);
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from router: {}", e)));
|
||||
bail!("Error reading from router: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -198,9 +216,8 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
|
||||
eprintln!("{}", msg);
|
||||
eprintln!("Closing connection with Server");
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
bail!(msg);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
// Ok(())
|
||||
|
||||
135
src/router.rs
135
src/router.rs
@@ -11,7 +11,7 @@ use tokio::{
|
||||
TcpStream,
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
},
|
||||
sync::{Mutex, Notify, RwLock},
|
||||
sync::{Mutex, RwLock},
|
||||
time::Instant,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
@@ -19,7 +19,7 @@ use uuid::Uuid;
|
||||
use crate::client::ClientCfg;
|
||||
|
||||
pub static KEEP_ALIVE_INTERVAL: Duration = tokio::time::Duration::from_secs(30);
|
||||
pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(100);
|
||||
pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(900);
|
||||
pub static TCP_NODELAY: bool = true;
|
||||
pub const SERVER_PACKET_SIZE: usize = 1024 * 9;
|
||||
|
||||
@@ -177,6 +177,10 @@ impl ClientStream {
|
||||
let payload: Option<Vec<u8>>;
|
||||
|
||||
match rx.read_exact(&mut buf[..6]).await {
|
||||
Ok(0) => {
|
||||
bail!("Client closed the connection");
|
||||
}
|
||||
|
||||
Ok(_) => {
|
||||
let total_len = u16::from_be_bytes(buf[0..2].try_into().unwrap()) as usize;
|
||||
let msg_len = u16::from_be_bytes(buf[2..4].try_into().unwrap()) as usize;
|
||||
@@ -209,8 +213,7 @@ impl ClientStream {
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from client stream: {}", e);
|
||||
return Err(anyhow::anyhow!(format!("Error reading from client stream: {}", e)));
|
||||
bail!(format!("{:#}", e));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -228,28 +231,28 @@ impl ClientStream {
|
||||
pub struct Router {
|
||||
clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>,
|
||||
routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>,
|
||||
notify: Arc<Notify>,
|
||||
}
|
||||
impl Router {
|
||||
pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> {
|
||||
let id = vpn_client.id();
|
||||
|
||||
// Append local interface IP to client's local routes for routing table construction
|
||||
let mut client_local_route = client.local_routes;
|
||||
client_local_route.push(client.interface_ip);
|
||||
self.routing_table.write().await.insert(client.interface_ip, id);
|
||||
|
||||
// Add Client local routes to global routing table
|
||||
for net in client_local_route {
|
||||
self.routing_table.write().await.insert(net, id);
|
||||
}
|
||||
|
||||
self.clients.write().await.insert(id, vpn_client);
|
||||
self.notify.notify_waiters();
|
||||
Ok(())
|
||||
self.update_all_clients().await
|
||||
}
|
||||
|
||||
pub async fn remove_client(&self, client_id: Uuid) -> Result<()> {
|
||||
self.clients.write().await.remove(&client_id);
|
||||
self.routing_table.write().await.retain(|_, &mut id| id != client_id);
|
||||
self.notify.notify_waiters();
|
||||
Ok(())
|
||||
self.update_all_clients().await
|
||||
}
|
||||
|
||||
pub async fn get_routing_table(&self) -> HashMap<Ipv4Net, Uuid> {
|
||||
@@ -260,34 +263,53 @@ impl Router {
|
||||
self.clients.read().await.get(&uuid).cloned()
|
||||
}
|
||||
|
||||
pub async fn notify_changes(&self, uuid: Uuid) {
|
||||
self.notify.notified().await;
|
||||
match self.get_client(uuid).await {
|
||||
Some(client) => {
|
||||
println!("Notifying client {} of routing table change", client.id());
|
||||
if let Err(e) = client
|
||||
.send(RouterMessages::RouteUpdate(self.get_routing_table().await), None)
|
||||
.await
|
||||
{
|
||||
eprintln!("Error notifying client {}: {}", client.id(), e);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
eprintln!(
|
||||
"Client {} not found while trying to notify of routing table change",
|
||||
uuid
|
||||
);
|
||||
pub async fn update_all_clients(&self) -> Result<()> {
|
||||
// Lock only the routing table and clients list once to get the current state,
|
||||
// then release locks before sending updates to avoid blocking other operations
|
||||
// while waiting for network I/O
|
||||
let routing_table = self.get_routing_table().await;
|
||||
|
||||
// Clone the clients list to avoid holding the lock while sending updates
|
||||
let clients = self.clients.read().await.clone();
|
||||
|
||||
// Routes update task generator
|
||||
let mut update_tasks = Vec::new();
|
||||
for (id, client) in clients {
|
||||
let routing_table = routing_table.clone();
|
||||
let update_task = tokio::spawn(async move {
|
||||
let result = client.send(RouterMessages::RouteUpdate(routing_table), None).await;
|
||||
(id, result)
|
||||
});
|
||||
update_tasks.push(update_task);
|
||||
}
|
||||
|
||||
// Parallelize sending updates to all clients and wait for all tasks to complete
|
||||
for task in update_tasks {
|
||||
let (id, result) = task.await?;
|
||||
if let Err(e) = result {
|
||||
eprintln!("Error sending routing table update to client {}: {}", id, e);
|
||||
} else {
|
||||
println!("Successfully sent routing table update to client {}", id);
|
||||
}
|
||||
}
|
||||
// for client in self.clients.read().await.values() {
|
||||
// println!("Notifying client {} of routing table change", client.id());
|
||||
// if let Err(e) = client
|
||||
// .send(RouterMessages::RouteUpdate(self.get_routing_table().await), None)
|
||||
// .await
|
||||
// {
|
||||
// eprintln!("Error notifying client {}: {}", client.id(), e);
|
||||
// }
|
||||
// }
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn close_all_client(&self, message: impl AsRef<str>) -> Result<()> {
|
||||
let clients = self.clients.read().await.clone();
|
||||
for (id, client) in clients {
|
||||
println!("Closing connection with client {}...", id);
|
||||
client
|
||||
.send(RouterMessages::Quit(String::from(message.as_ref().to_string())), None)
|
||||
.await?;
|
||||
if let Err(e) = client.close().await {
|
||||
eprintln!("Error closing connection with client {}: {}", id, e);
|
||||
} else {
|
||||
println!("Successfully closed connection with client {}", id);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -299,6 +321,21 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> {
|
||||
let socket = tokio::net::TcpListener::bind(bind_address).await?;
|
||||
println!("Router is listening on {}...", bind_address);
|
||||
|
||||
tokio::select! {
|
||||
_ = handle_router_connections(router.clone(), socket) => {}
|
||||
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
println!();
|
||||
let msg = "Received shutdown signal, shutting down router...";
|
||||
println!("{}", msg);
|
||||
router.close_all_client(msg).await?
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_router_connections(router: Router, socket: tokio::net::TcpListener) -> ! {
|
||||
loop {
|
||||
let router = router.clone();
|
||||
match socket.accept().await {
|
||||
@@ -318,7 +355,6 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
||||
@@ -339,9 +375,6 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
||||
);
|
||||
println!("Registering client {} with router...", vpn_client.id());
|
||||
router.register_client(client, vpn_client.clone()).await?;
|
||||
vpn_client
|
||||
.send(RouterMessages::RouteUpdate(router.get_routing_table().await), None)
|
||||
.await?;
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e);
|
||||
@@ -356,6 +389,13 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
||||
Ok((message, data)) => {
|
||||
println!("Received message from client {}: {:?}", vpn_client.id(), message);
|
||||
match message {
|
||||
RouterMessages::Quit(msg) =>{
|
||||
println!("Received quit message from client {}: {}", vpn_client.id(), msg);
|
||||
println!("Removing client {} from router...", vpn_client.id());
|
||||
router.remove_client(vpn_client.id()).await?;
|
||||
vpn_client.close().await?;
|
||||
return Ok(());
|
||||
}
|
||||
RouterMessages::Data(packet_info) => {
|
||||
if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await {
|
||||
println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id());
|
||||
@@ -374,20 +414,21 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
||||
println!("Removing client {} from router...", vpn_client.id());
|
||||
router.remove_client(vpn_client.id()).await?;
|
||||
vpn_client.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from client: {}", e)));
|
||||
bail!(format!("Error reading from client: {}", e));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
_= router.notify_changes(vpn_client.id()) => {
|
||||
println!("Routing table updated. Current routing table:");
|
||||
}
|
||||
// _= router.notify_changes(vpn_client.id()) => {
|
||||
// println!("Routing table updated. Current routing table:");
|
||||
// }
|
||||
|
||||
_= keep_alive_tick.tick() => {
|
||||
// Send keep-alive message to the client
|
||||
vpn_client.send(RouterMessages::KeepAlive(Utc::now().timestamp_micros()), None).await?;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,7 +455,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
|
||||
eprintln!("{}", msg);
|
||||
eprintln!("Closing connection with client {}", vpn_client.id());
|
||||
vpn_client.close().await?;
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
bail!(msg);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -422,7 +463,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e);
|
||||
vpn_client.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from client during registration: {}", e)));
|
||||
bail!(format!("Error reading from client during registration: {}", e));
|
||||
|
||||
}
|
||||
}
|
||||
@@ -433,7 +474,7 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
|
||||
let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis()));
|
||||
vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?;
|
||||
vpn_client.close().await?;
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
bail!(msg);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use ipnet::Ipv4Net;
|
||||
use tokio::process::Command;
|
||||
use tun_rs::{AsyncDevice, DeviceBuilder};
|
||||
@@ -23,7 +23,7 @@ pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> {
|
||||
Err(e) => {
|
||||
let msg = format!("Failed to create TUN interface: {:#}", e);
|
||||
eprintln!("{}", msg);
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
bail!(msg);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -45,7 +45,7 @@ pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr);
|
||||
eprintln!("{}", msg);
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
bail!(msg);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -66,7 +66,7 @@ pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr);
|
||||
eprintln!("{}", msg);
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
bail!(msg);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user