diff --git a/client_test01/.config/vpn_config.toml b/client_test01/.config/vpn_config.toml new file mode 100644 index 0000000..1da6312 --- /dev/null +++ b/client_test01/.config/vpn_config.toml @@ -0,0 +1,6 @@ +[mode.Client] +server = "127.0.0.1:443" +interface_ip = "5.5.5.5/32" +interface_name = "xvpn1" +local_routes = ["6.6.6.0/24"] +mtu = 1400 diff --git a/client_test01/xvpn b/client_test01/xvpn new file mode 120000 index 0000000..800e523 --- /dev/null +++ b/client_test01/xvpn @@ -0,0 +1 @@ +../target/release/xvpn \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index eb1a5f6..7d552c6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,7 +13,7 @@ use crate::{ CLIENT_REGISTER_TIMEOUT, CliRegMessages, ClientStream, RouterMessages, RoutesMap, SERVER_PACKET_SIZE, TCP_NODELAY, VpnPacket, }, - tun::{add_route, del_route, inti_tun_interface}, + tun::{add_route, del_route, init_tun_if}, }; pub struct ClientStaTistic { @@ -50,7 +50,7 @@ pub struct ClientCfg { pub mtu: u16, } -pub async fn start(config: ClientCfg) -> Result<()> { +pub async fn start(mut config: ClientCfg) -> Result<()> { println!("Starting client with config: {:?}", config); let stream = tokio::net::TcpStream::connect(&config.server).await?; @@ -61,9 +61,9 @@ pub async fn start(config: ClientCfg) -> Result<()> { let mut vpn_buf = vec![0u8; SERVER_PACKET_SIZE]; let mut tun_buf = vec![0u8; config.mtu as usize]; - let self_uuid = register_client(&client_stream, &config, &mut vpn_buf).await?; + let self_uuid = register_client(&client_stream, &mut config, &mut vpn_buf).await?; - let tun_device = inti_tun_interface(&config).await?; + let tun_device = init_tun_if(&config).await?; let mut route_map = RoutesMap::default(); @@ -88,17 +88,14 @@ pub async fn start(config: ClientCfg) -> Result<()> { RouterMessages::RouteUpdate(mut updated_routes) => { updated_routes.retain(|_ , u| *u != self_uuid); - println!("Received route update from router: {:?}", updated_routes); let removed_routes = route_map.iter().filter(|(n,_)| !updated_routes.contains_key(n)); - println!("Removed routes: {:?}", removed_routes); for removed_route in removed_routes { println!("Route {:?} removed by router", removed_route); del_route(&tun_device, *removed_route.0).await?; } let new_routes = updated_routes.iter().filter(|(n,_)| !route_map.contains_key(n)); - println!("New routes: {:?}", new_routes); for new_route in new_routes { println!("Route {} added by router", new_route.0); add_route(&tun_device, *new_route.0).await?; @@ -108,21 +105,23 @@ pub async fn start(config: ClientCfg) -> Result<()> { } - RouterMessages::Data(packet_info) => { + RouterMessages::Data(_packet_info) => { if let Some(data) = data { - println!("Received data packet from router for client {}: size {} bytes", packet_info.src_uuid, packet_info.data_len); tun_device.send(&data).await?; } }, - _ => println!("Received message from router: {:?}", message), + _ => { + let msg = format!("Unexpected message type received from router: {:?}.", message); + client_stream.close().await?; + bail!(msg); + } } } Err(e) => { - eprintln!("Error reading from router: {}", e); client_stream.close().await?; - bail!(format!("Error reading from router: {}", e)); + bail!("Error reading from router: {}", e); } } } @@ -131,11 +130,9 @@ pub async fn start(config: ClientCfg) -> Result<()> { match data { Ok(n) => { let packet = etherparse::Ipv4HeaderSlice::from_slice(&tun_buf[..n])?; - println!("Read packet from TUN interface: {} -> {}, size: {}", packet.source_addr(), packet.destination_addr(), n); let dst = packet.destination_addr(); match ip_match_network(dst, &route_map).await { Some(uuid) => { - println!("Packet destination {} matches route for client {}, sending to router", dst, uuid); let msg = VpnPacket{ dst_uuid: uuid, src_uuid: self_uuid, @@ -150,9 +147,8 @@ pub async fn start(config: ClientCfg) -> Result<()> { } Err(e) => { - eprintln!("Error reading from TUN interface: {}", e); client_stream.close().await?; - bail!(format!("Error reading from TUN interface: {}", e)); + bail!("Error reading from TUN interface: {}", e); } } } @@ -171,8 +167,11 @@ pub async fn start(config: ClientCfg) -> Result<()> { // Ok(()) } -pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, buf: &mut [u8]) -> Result { - let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(config.clone())); +pub async fn register_client(client_stream: &ClientStream, client_cfg: &mut ClientCfg, buf: &mut [u8]) -> Result { + let local_peer_address = client_cfg.interface_ip; + client_cfg.local_routes.push(local_peer_address); + + let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(client_cfg.clone())); let mut client_registration_timeout = tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT); @@ -192,20 +191,17 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b } RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => { - eprintln!("Client registration failed: {}", err_msg); bail!("Client registration failed: {}", err_msg); } - _ => { - let msg = "Unexpected message type received during client registration."; - eprintln!("{}", msg); + unexpected_msg => { + let msg = format!("Unexpected message type received during client registration: {:?}.", unexpected_msg); client_stream.close().await?; bail!(msg); } } } Err(e) => { - eprintln!("Error reading from router during client registration: {}", e); client_stream.close().await?; bail!("Error reading from router: {}", e); } @@ -213,8 +209,6 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b } _= client_registration_timeout.tick() => { let msg = "Client registration timed out waiting for confirmation from router."; - eprintln!("{}", msg); - eprintln!("Closing connection with Server"); client_stream.close().await?; bail!(msg); } diff --git a/src/main.rs b/src/main.rs index 314e2e2..88ba45c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,8 @@ pub mod config; pub mod network; pub mod router; pub mod tun; +#[cfg(test)] +mod tests; use clap::{Parser, Subcommand}; diff --git a/src/router.rs b/src/router.rs index a394152..c75aede 100644 --- a/src/router.rs +++ b/src/router.rs @@ -108,8 +108,8 @@ pub struct ClientStream { rx: Arc>, } impl ClientStream { - pub fn new(tx: TcpStream) -> Self { - let (rx, tx) = tx.into_split(); + pub fn new(tcp_stream: TcpStream) -> Self { + let (rx, tx) = tcp_stream.into_split(); Self { tx: Arc::new(Mutex::new(tx)), rx: Arc::new(Mutex::new(rx)), @@ -153,10 +153,7 @@ impl ClientStream { // unwrap is safe here because we already checked payload is some packet.extend_from_slice(&payload.unwrap()); } - println!( - "Sending packet: total_len={}, msg_len={}, payload_len={}, message={:?}", - total_len, msg_len, payload_len, msg - ); + if total_len as usize > SERVER_PACKET_SIZE { bail!( "Packet too large to send, receive buffer too small: total_len={}, buf_len={}", @@ -225,28 +222,61 @@ impl ClientStream { tx.shutdown().await?; Ok(()) } + + pub async fn peer_addr(&self) -> Result { + let tx = self.tx.lock().await; + let peer_addr = tx.peer_addr()?; + Ok(peer_addr) + } } #[derive(Debug, Clone, Default)] pub struct Router { clients: Arc>>, routing_table: Arc>>, + registration_lock: Arc>, } impl Router { - pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> { - let id = vpn_client.id(); + pub async fn register_client( + &self, + client_stream: ClientStream, + client_cfg: &ClientCfg, + id: Uuid, + ) -> Result { + // Ensure only one registration process can run at a time to + // Drops lock immediately after checking for route conflicts and updating routing table + // to minimize contention and allow concurrent client registrations + // as much as possible while still preventing + let _reg_lock = self.registration_lock.lock().await; - // Append local interface IP to client's local routes for routing table construction - let mut client_local_route = client.local_routes; - client_local_route.push(client.interface_ip); + println!( + "Registering client {} with local routes: {:?}", + id, client_cfg.local_routes + ); - // Add Client local routes to global routing table - for net in client_local_route { - self.routing_table.write().await.insert(net, id); + // Build a map of the client's routes to their UUID for efficient registration and routing table updates + let reg_routes: HashMap = + HashMap::from_iter(client_cfg.local_routes.iter().map(|route| (*route, id))); + + for route in reg_routes.keys() { + if let Some(existing_client_id) = self.check_if_route_exists(&route).await { + let router_msg = format!( + "Route {} overlaps with existing route for client {}, cannot register client {}", + route, existing_client_id, id + ); + let cli_reg_msg = format!("Failed to register, you have overlapping routes {} ", route); + client_stream + .send(RouterMessages::CliReg(CliRegMessages::RegFailed(cli_reg_msg)), None) + .await?; + bail!(router_msg); + } } - self.clients.write().await.insert(id, vpn_client); - self.update_all_clients().await + let vpn_client = VPNClient::new(id, client_stream); + self.clients.write().await.insert(id, vpn_client.clone()); + + self.routing_table.write().await.extend(reg_routes); + return Ok(vpn_client); } pub async fn remove_client(&self, client_id: Uuid) -> Result<()> { @@ -263,6 +293,15 @@ impl Router { self.clients.read().await.get(&uuid).cloned() } + pub async fn check_if_route_exists(&self, net: &Ipv4Net) -> Option { + for (existing_net, existing_client_id) in self.routing_table.read().await.iter() { + if existing_net.contains(net) || net.contains(existing_net) { + return Some(*existing_client_id); + } + } + None + } + pub async fn update_all_clients(&self) -> Result<()> { // Lock only the routing table and clients list once to get the current state, // then release locks before sending updates to avoid blocking other operations @@ -325,9 +364,8 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> { _ = handle_router_connections(router.clone(), socket) => {} _ = tokio::signal::ctrl_c() => { - println!(); let msg = "Received shutdown signal, shutting down router..."; - println!("{}", msg); + println!("\n{}", msg); router.close_all_client(msg).await? } } @@ -343,11 +381,19 @@ async fn handle_router_connections(router: Router, socket: tokio::net::TcpListen println!("Accepted connection from {}", addr); //Clone the router for the new task tokio::spawn(async move { - println!("Handling connection from {}", addr); - match handle_client(router, tcp_stream).await { - Ok(_) => println!("Finished handling connection from {}", addr), + let id = Uuid::new_v4(); + println!("Handling connection from {} :{}", addr, id); + match handle_client(router.clone(), tcp_stream, id).await { + Ok(_) => { + println!("Finished handling connection from {}", addr); + } Err(e) => eprintln!("Error handling connection from {}: {}", addr, e), } + println!("Cleaning up client {} from connection {}", id, addr); + router + .remove_client(id) + .await + .unwrap_or_else(|e| eprintln!("Error removing client {}: {}", id, e)); }); } Err(e) => { @@ -357,48 +403,73 @@ async fn handle_router_connections(router: Router, socket: tokio::net::TcpListen } } -pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { - stream.set_nodelay(TCP_NODELAY)?; - - let vpn_client = VPNClient::new(Uuid::new_v4(), ClientStream::new(stream)); +pub async fn handle_client(router: Router, tcp_stream: TcpStream, id: Uuid) -> Result { + tcp_stream.set_nodelay(TCP_NODELAY)?; let mut keep_alive_tick = tokio::time::interval_at(Instant::now() + KEEP_ALIVE_INTERVAL, KEEP_ALIVE_INTERVAL); let mut buf = vec![0u8; SERVER_PACKET_SIZE]; - match client_init(&vpn_client, &mut buf).await { - Ok(client) => { + let peer_address = tcp_stream.peer_addr()?; + + let (client_stream, client_config) = match client_init(tcp_stream, &mut buf, id).await { + Ok((client_stream, client_cfg)) => { println!( "Client {} registered with routing table: {:?} and local endpoint {:?}", - vpn_client.id(), - client.local_routes, - client.interface_ip + id, client_cfg.local_routes, peer_address ); - println!("Registering client {} with router...", vpn_client.id()); - router.register_client(client, vpn_client.clone()).await?; + (client_stream, client_cfg) } Err(e) => { - eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e); - return Err(e); + let msg = format!( + "Failed to initialize client connection form peer {:?}:{} {}", + peer_address, id, e + ); + bail!(msg); } - } + }; + let vpn_client = match router.register_client(client_stream, &client_config, id).await { + Ok(vpn_client) => { + println!( + "Successfully registered client {} with peer address {:?} and routing table: {:?} ", + vpn_client.id(), + peer_address, + client_config.local_routes + ); + vpn_client + .send(RouterMessages::CliReg(CliRegMessages::RegOk(id)), None) + .await?; + vpn_client + } + Err(e) => { + let msg = format!( + "Failed to register client {} with peer address {:?}: {}", + id, peer_address, e + ); + bail!(msg); + } + }; + println!( + "Finished client initialization for client {} with peer address {:?}, entering main loop...", + vpn_client.id(), + peer_address + ); + + router.update_all_clients().await?; loop { tokio::select! { msg = vpn_client.receive(&mut buf) => { match msg { Ok((message, data)) => { - println!("Received message from client {}: {:?}", vpn_client.id(), message); match message { RouterMessages::Quit(msg) =>{ println!("Received quit message from client {}: {}", vpn_client.id(), msg); println!("Removing client {} from router...", vpn_client.id()); - router.remove_client(vpn_client.id()).await?; vpn_client.close().await?; - return Ok(()); + return Ok(id); } RouterMessages::Data(packet_info) => { if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await { - println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id()); if let Err(e) = dst_client.send(RouterMessages::Data(packet_info), data).await { eprintln!("Error forwarding packet to client {}: {}", dst_client.id(), e); } @@ -406,13 +477,15 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { eprintln!("Destination client {} not found for packet from client {}", packet_info.dst_uuid, vpn_client.id()); } } - _ => println!("Received message from client {}: {:?}", vpn_client.id(), message) + _ => { + let msg = format!("Received unexpected message from client {}: {:?}", vpn_client.id(), message); + vpn_client.close().await?; + bail!(msg); + } } } Err(e) => { - eprintln!("Error reading from client {}: {}", vpn_client.id(), e); - println!("Removing client {} from router...", vpn_client.id()); - router.remove_client(vpn_client.id()).await?; + eprintln!("Error removing client {} from router...", vpn_client.id()); vpn_client.close().await?; bail!(format!("Error reading from client: {}", e)); } @@ -420,9 +493,7 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { } - // _= router.notify_changes(vpn_client.id()) => { - // println!("Routing table updated. Current routing table:"); - // } + _= keep_alive_tick.tick() => { // Send keep-alive message to the client @@ -435,35 +506,33 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { //Ok(()) } -pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result { +pub async fn client_init(tcp_stream: TcpStream, buf: &mut [u8], id: Uuid) -> Result<(ClientStream, ClientCfg)> { + let client_stream = ClientStream::new(tcp_stream); let mut client_registration_timeout = tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT); loop { tokio::select! { - msg = vpn_client.receive(buf) => { + msg = client_stream.receive(buf) => { match msg { Ok((router_msg, _)) => { match router_msg { - RouterMessages::CliReg(CliRegMessages::Reg(client))=> { - println!("Received client registration with routing table: {:?}", client.local_routes); - vpn_client.send(RouterMessages::CliReg(CliRegMessages::RegOk(vpn_client.id())),None).await?; - return Ok(client); + RouterMessages::CliReg(CliRegMessages::Reg(client_cfg))=> { + println!("Received client registration with routing table: {:?}", client_cfg.local_routes); + return Ok((client_stream,client_cfg)); } router_msg => { let msg = format!("Expected client registration message, but received: {:?}", router_msg); - eprintln!("{}", msg); - eprintln!("Closing connection with client {}", vpn_client.id()); - vpn_client.close().await?; + client_stream.close().await?; bail!(msg); } } } Err(e) => { - eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e); - vpn_client.close().await?; - bail!(format!("Error reading from client during registration: {}", e)); + let msg = format!("Error reading from client {} during registration: {:?}", id, e); + client_stream.close().await?; + bail!(msg); } } @@ -472,8 +541,8 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result { let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis())); - vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?; - vpn_client.close().await?; + client_stream.send(RouterMessages::Quit(msg.clone()),None).await?; + client_stream.close().await?; bail!(msg); } diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..bf18ebc --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,111 @@ +use std::net::Ipv4Addr; + +use ipnet::Ipv4Net; +use uuid::Uuid; + +use crate::{ + client::ClientCfg, + network::ip_match_network, + router::{CliRegMessages, RouterMessages, VpnPacket}, +}; + +fn sample_cfg() -> ClientCfg { + ClientCfg { + server: "127.0.0.1:443".to_string(), + interface_ip: "10.8.0.2/32".parse().expect("cidr parse"), + interface_name: "xvpn0".to_string(), + local_routes: vec!["10.10.0.0/24".parse().expect("cidr parse")], + mtu: 1400, + } +} + +#[tokio::test] +async fn router_message_keep_alive_roundtrip_json() { + let msg = RouterMessages::KeepAlive(123_456); + let bytes = msg.to_bytes().await; + let decoded = RouterMessages::from_slice(&bytes).await; + + match decoded { + RouterMessages::KeepAlive(ts) => assert_eq!(ts, 123_456), + other => panic!("unexpected decoded message: {other:?}"), + } +} + +#[tokio::test] +async fn router_message_data_roundtrip_json() { + let src = Uuid::new_v4(); + let dst = Uuid::new_v4(); + let msg = RouterMessages::Data(VpnPacket { + src_uuid: src, + dst_uuid: dst, + data_len: 512, + }); + + let bytes = msg.to_bytes().await; + let decoded = RouterMessages::from_slice(&bytes).await; + + match decoded { + RouterMessages::Data(pkt) => { + assert_eq!(pkt.src_uuid, src); + assert_eq!(pkt.dst_uuid, dst); + assert_eq!(pkt.data_len, 512); + } + other => panic!("unexpected decoded message: {other:?}"), + } +} + +#[tokio::test] +async fn router_message_from_invalid_slice_is_unknown() { + let decoded = RouterMessages::from_slice(b"{ not valid json").await; + match decoded { + RouterMessages::Unknown(s) => assert!(!s.is_empty(), "unknown payload should not be empty"), + other => panic!("expected Unknown, got {other:?}"), + } +} + +#[test] +fn cli_reg_message_reg_roundtrip_json() { + let cfg = sample_cfg(); + let msg = CliRegMessages::Reg(cfg.clone()); + let bytes = msg.to_bytes(); + let decoded = CliRegMessages::from_slice(&bytes); + + match decoded { + CliRegMessages::Reg(decoded_cfg) => { + assert_eq!(decoded_cfg.server, cfg.server); + assert_eq!(decoded_cfg.interface_ip, cfg.interface_ip); + assert_eq!(decoded_cfg.interface_name, cfg.interface_name); + assert_eq!(decoded_cfg.local_routes, cfg.local_routes); + assert_eq!(decoded_cfg.mtu, cfg.mtu); + } + other => panic!("unexpected decoded message: {other:?}"), + } +} + +#[test] +fn cli_reg_message_invalid_utf8_is_unknown() { + let decoded = CliRegMessages::from_slice(&[0xff, 0xfe, 0xfd]); + match decoded { + CliRegMessages::Unknown(s) => assert!(s.contains("Invalid UTF-8"), "unexpected content: {s}"), + other => panic!("expected Unknown, got {other:?}"), + } +} + +#[tokio::test] +async fn ip_match_network_returns_expected_uuid() { + let mut routes = std::collections::HashMap::new(); + let id = Uuid::new_v4(); + routes.insert("192.168.1.0/24".parse::().expect("cidr parse"), id); + + let match_id = ip_match_network(Ipv4Addr::new(192, 168, 1, 44), &routes).await; + assert_eq!(match_id, Some(id)); +} + +#[tokio::test] +async fn ip_match_network_returns_none_when_no_route_matches() { + let mut routes = std::collections::HashMap::new(); + routes.insert("192.168.1.0/24".parse::().expect("cidr parse"), Uuid::new_v4()); + + let no_match = ip_match_network(Ipv4Addr::new(10, 10, 10, 10), &routes).await; + assert!(no_match.is_none(), "non-matching ip should return none"); +} diff --git a/src/tun.rs b/src/tun.rs index 6b6b2ca..7d75c33 100644 --- a/src/tun.rs +++ b/src/tun.rs @@ -5,7 +5,7 @@ use tun_rs::{AsyncDevice, DeviceBuilder}; use crate::client::ClientCfg; -pub async fn inti_tun_interface(config: &ClientCfg) -> Result { +pub async fn init_tun_if(config: &ClientCfg) -> Result { println!( "Initializing TUN interface with name: {}, IP: {}/{}, MTU: {}", config.interface_name, @@ -22,7 +22,6 @@ pub async fn inti_tun_interface(config: &ClientCfg) -> Result { Ok(dev) => dev, Err(e) => { let msg = format!("Failed to create TUN interface: {:#}", e); - eprintln!("{}", msg); bail!(msg); } }; @@ -44,7 +43,6 @@ pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> { if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr); - eprintln!("{}", msg); bail!(msg); } @@ -65,7 +63,6 @@ pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> { if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr); - eprintln!("{}", msg); bail!(msg); }