diff --git a/src/router.rs b/src/router.rs index 6fbb9c8..5b34b83 100644 --- a/src/router.rs +++ b/src/router.rs @@ -1,11 +1,10 @@ -use anyhow::{Context, Result, bail}; +use anyhow::{Result, bail}; use chrono::Utc; -use etherparse::err::packet; use ipnet::Ipv4Net; use serde::{Deserialize, Serialize}; -use serde_bytes::Bytes; -use std::{collections::HashMap, mem, net::SocketAddr, sync::Arc, time::Duration}; + +use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{ @@ -158,6 +157,13 @@ impl ClientStream { "Sending packet: total_len={}, msg_len={}, payload_len={}, message={:?}", total_len, msg_len, payload_len, msg ); + if total_len as usize > SERVER_PACKET_SIZE { + bail!( + "Packet too large to send, receive buffer too small: total_len={}, buf_len={}", + total_len, + SERVER_PACKET_SIZE + ); + } let mut tx = self.tx.lock().await; tx.write_all(&packet).await?; Ok(packet.len()) @@ -184,6 +190,14 @@ impl ClientStream { ); } + if total_len > buf.len() { + bail!( + "Packet too large for receive buffer: total_len={}, buf_len={}", + total_len, + buf.len() + ); + } + rx.read_exact(&mut buf[..msg_len]).await?; router_message = RouterMessages::from_slice(&buf[..msg_len]).await; @@ -246,17 +260,34 @@ impl Router { self.clients.read().await.get(&uuid).cloned() } - pub async fn notify_changes(&self) { + pub async fn notify_changes(&self, uuid: Uuid) { self.notify.notified().await; - for client in self.clients.read().await.values() { - println!("Notifying client {} of routing table change", client.id()); - if let Err(e) = client - .send(RouterMessages::RouteUpdate(self.get_routing_table().await), None) - .await - { - eprintln!("Error notifying client {}: {}", client.id(), e); + match self.get_client(uuid).await { + Some(client) => { + println!("Notifying client {} of routing table change", client.id()); + if let Err(e) = client + .send(RouterMessages::RouteUpdate(self.get_routing_table().await), None) + .await + { + eprintln!("Error notifying client {}: {}", client.id(), e); + } + } + None => { + eprintln!( + "Client {} not found while trying to notify of routing table change", + uuid + ); } } + // for client in self.clients.read().await.values() { + // println!("Notifying client {} of routing table change", client.id()); + // if let Err(e) = client + // .send(RouterMessages::RouteUpdate(self.get_routing_table().await), None) + // .await + // { + // eprintln!("Error notifying client {}: {}", client.id(), e); + // } + // } } } @@ -308,6 +339,9 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { ); println!("Registering client {} with router...", vpn_client.id()); router.register_client(client, vpn_client.clone()).await?; + vpn_client + .send(RouterMessages::RouteUpdate(router.get_routing_table().await), None) + .await?; } Err(e) => { eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e); @@ -346,7 +380,7 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { } - _= router.notify_changes() => { + _= router.notify_changes(vpn_client.id()) => { println!("Routing table updated. Current routing table:"); }