From 9241f1fbacf39e76252a2f4d3d1ee7082e87427a Mon Sep 17 00:00:00 2001 From: icsboyx Date: Mon, 2 Mar 2026 14:29:25 +0100 Subject: [PATCH] Refactor code structure for improved readability and maintainability --- Cargo.lock | 11 ++ Cargo.toml | 12 +- client_test/.config/vpn_config.toml | 6 + client_test/xvpn | 1 + src/client.rs | 133 ++++++++------ src/main.rs | 2 +- src/network.rs | 10 +- src/router.rs | 261 +++++++++++++++++++++------- src/tun.rs | 46 ++++- 9 files changed, 355 insertions(+), 127 deletions(-) create mode 100644 client_test/.config/vpn_config.toml create mode 120000 client_test/xvpn diff --git a/Cargo.lock b/Cargo.lock index 762ce5d..33a6bd3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -832,6 +832,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_bytes" +version = "0.11.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" +dependencies = [ + "serde", + "serde_core", +] + [[package]] name = "serde_core" version = "1.0.228" @@ -1595,6 +1605,7 @@ dependencies = [ "etherparse", "ipnet", "serde", + "serde_bytes", "serde_json", "tokio", "toml", diff --git a/Cargo.toml b/Cargo.toml index 21f1254..134fbdd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,16 +8,7 @@ clap = { version = "4.5.37", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.149" toml = "1.0.3" -tokio = { version = "1.49.0", features = [ - "macros", - "rt-multi-thread", - "time", - "fs", - "net", - "io-util", - "sync", - "signal", -] } +tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread", "time", "fs", "net", "io-util", "sync", "signal", "process"] } anyhow = "1.0.102" uuid = { version = "1.21.0", features = ["v4", "serde"] } ipnet = { version = "2.11.0", features = ["serde"] } @@ -25,3 +16,4 @@ base64 = "0.22.1" tun-rs = { version = "2.8.2", features = ["async"] } chrono = "0.4.44" etherparse = "0.19.0" +serde_bytes = "0.11.19" diff --git a/client_test/.config/vpn_config.toml b/client_test/.config/vpn_config.toml new file mode 100644 index 0000000..d1db108 --- /dev/null +++ b/client_test/.config/vpn_config.toml @@ -0,0 +1,6 @@ +[mode.Client] +server = "127.0.0.1:443" +interface_ip = "2.2.2.2/32" +interface_name = "xvpn0" +local_routes = ["4.4.4.0/24"] +mtu = 1400 diff --git a/client_test/xvpn b/client_test/xvpn new file mode 120000 index 0000000..800e523 --- /dev/null +++ b/client_test/xvpn @@ -0,0 +1 @@ +../target/release/xvpn \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index 7409ba5..4875b62 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,16 +4,16 @@ use clap::Args; use ipnet::Ipv4Net; use serde::{Deserialize, Serialize}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::tcp::{OwnedReadHalf, OwnedWriteHalf}, - time::Instant, -}; +use tokio::time::Instant; +use uuid::Uuid; use crate::{ network::ip_match_network, - router::{CLIENT_REGISTER_TIMEOUT, CliRegMessages, RouterMessages, SERVER_PACKET_SIZE}, - tun::inti_tun_interface, + router::{ + CLIENT_REGISTER_TIMEOUT, CliRegMessages, ClientStream, RouterMessages, RoutesMap, SERVER_PACKET_SIZE, + TCP_NODELAY, VpnPacket, + }, + tun::{add_route, del_route, inti_tun_interface}, }; pub struct ClientStaTistic { @@ -54,51 +54,97 @@ pub async fn start(config: ClientCfg) -> Result<()> { println!("Starting client with config: {:?}", config); let stream = tokio::net::TcpStream::connect(&config.server).await?; - //stream.set_nodelay(true)?; - let (mut rx, mut tx) = stream.into_split(); - // let client_stream = ClientStream::new(tx); + stream.set_nodelay(TCP_NODELAY)?; + + let client_stream = ClientStream::new(stream); let mut vpn_buf = vec![0u8; SERVER_PACKET_SIZE]; let mut tun_buf = vec![0u8; config.mtu as usize]; - register_client(&mut rx, &mut tx, &config, &mut vpn_buf).await?; + + let self_uuid = register_client(&client_stream, &config, &mut vpn_buf).await?; + let tun_device = inti_tun_interface(&config).await?; + let mut route_map = RoutesMap::default(); + println!("Client registration successful. Entering main loop to receive messages from router..."); loop { tokio::select! { - msg = rx.read(&mut vpn_buf) => { + msg = client_stream.receive(&mut vpn_buf) => { match msg { - Ok(0) => { - println!("Connection to router closed by peer."); - return Ok(()); - } - Ok(n) => { - match RouterMessages::from_slice(&vpn_buf[..n]){ + Ok((message, data)) => { + match message { RouterMessages::KeepAlive(timestamp) => { println!("Received keep-alive message from router with timestamp: {}, delta {} ms", timestamp, (Utc::now().timestamp_micros() - timestamp).abs() as f64 / 1000.0); } - _ => println!("Received message from router: {:?}", RouterMessages::from_slice(&vpn_buf[..n])) - }; + + RouterMessages::RouteUpdate(mut updated_routes) => { + updated_routes.retain(|_ , u| *u != self_uuid); + println!("Received route update from router: {:?}", updated_routes); + + let removed_routes = route_map.iter().filter(|(n,_)| !updated_routes.contains_key(n)); + println!("Removed routes: {:?}", removed_routes); + for removed_route in removed_routes { + println!("Route {:?} removed by router", removed_route); + del_route(&tun_device, *removed_route.0).await?; + } + + let new_routes = updated_routes.iter().filter(|(n,_)| !route_map.contains_key(n)); + println!("New routes: {:?}", new_routes); + for new_route in new_routes { + println!("Route {} added by router", new_route.0); + add_route(&tun_device, *new_route.0).await?; + } + + route_map = updated_routes; + + } + + RouterMessages::Data(packet_info) => { + if let Some(data) = data { + println!("Received data packet from router for client {}: size {} bytes", packet_info.src_uuid, packet_info.data_len); + tun_device.send(&data).await?; + } + }, + + _ => println!("Received message from router: {:?}", message), + + } } Err(e) => { eprintln!("Error reading from router: {}", e); + client_stream.close().await?; return Err(anyhow::anyhow!(format!("Error reading from router: {}", e))); - } } } + } + data = tun_device.recv(&mut tun_buf) => { match data { Ok(n) => { let packet = etherparse::Ipv4HeaderSlice::from_slice(&tun_buf[..n])?; - let src = packet.source_addr(); - match ip_match_network(src, &config.local_routes).await { - Some(net) => println!("Source IP {} matches local route {}", src, net), + println!("Read packet from TUN interface: {} -> {}, size: {}", packet.source_addr(), packet.destination_addr(), n); + let dst = packet.destination_addr(); + match ip_match_network(dst, &route_map).await { + Some(uuid) => { + println!("Packet destination {} matches route for client {}, sending to router", dst, uuid); + let msg = VpnPacket{ + dst_uuid: uuid, + src_uuid: self_uuid, + data_len: n, + }; + let data = tun_buf[..n].to_vec(); + // tx.write(&RouterMessages::Data(msg), data.as_vec()).await?; + client_stream.send(RouterMessages::Data(msg), Some(data)).await?; + }, None => {}, } + } Err(e) => { eprintln!("Error reading from TUN interface: {}", e); + client_stream.close().await?; return Err(anyhow::anyhow!(format!("Error reading from TUN interface: {}", e))); } } @@ -109,33 +155,24 @@ pub async fn start(config: ClientCfg) -> Result<()> { // Ok(()) } -pub async fn register_client( - rx: &mut OwnedReadHalf, - tx: &mut OwnedWriteHalf, - config: &ClientCfg, - buf: &mut [u8], -) -> Result<()> { +pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, buf: &mut [u8]) -> Result { let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(config.clone())); + let mut client_registration_timeout = tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT); - tx.write_all(®ister_msg.to_bytes()).await?; + + client_stream.send(register_msg, None).await?; loop { tokio::select! { - msg = rx.read(buf) => { + msg = client_stream.receive(buf) => { match msg { - Ok(0) => { - let msg = "Connection closed by router while waiting for registration confirmation."; - eprintln!("{}", msg); - return Err(anyhow::anyhow!(msg)); - } - Ok(n) => { - let response = RouterMessages::from_slice(&buf[..n]); - println!("Received registration response from router: {:?}", response); - match response { - RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => { - println!("Client registration successful with UUID: {}", uuid); - return Ok(()); + Ok((message, _data)) => { + match message { + RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => { + println!("Received registration response from router: {:?}", message); + println!("Client registration successful with UUID: {}", uuid); + return Ok(uuid); } RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => { eprintln!("Client registration failed: {}", err_msg); @@ -144,28 +181,26 @@ pub async fn register_client( _ => { let msg = "Unexpected message type received during client registration."; eprintln!("{}", msg); + client_stream.close().await?; return Err(anyhow::anyhow!(msg)); } } } Err(e) => { eprintln!("Error reading from router during client registration: {}", e); + client_stream.close().await?; return Err(anyhow::anyhow!(format!("Error reading from router: {}", e))); } } - - } _= client_registration_timeout.tick() => { let msg = "Client registration timed out waiting for confirmation from router."; eprintln!("{}", msg); eprintln!("Closing connection with Server"); - tx.shutdown().await?; + client_stream.close().await?; return Err(anyhow::anyhow!(msg)); } - - } } // Ok(()) diff --git a/src/main.rs b/src/main.rs index fc0751b..314e2e2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -84,7 +84,7 @@ async fn main() -> anyhow::Result<()> { match commandline.mode { OpModes::Client(client) => client::start(client).await?, OpModes::Router { bind_address } => { - router::start(bind_address).await; + router::start(bind_address).await?; } } diff --git a/src/network.rs b/src/network.rs index 5554f0f..7bd292b 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,10 +1,12 @@ -use ipnet::Ipv4Net; use std::net::Ipv4Addr; +use uuid::Uuid; -pub async fn ip_match_network(ip: Ipv4Addr, networks: &[Ipv4Net]) -> Option { - for net in networks { +use crate::router::RoutesMap; + +pub async fn ip_match_network(ip: Ipv4Addr, networks: &RoutesMap) -> Option { + for (net, uuid) in networks { if net.contains(&ip) { - return Some(*net); + return Some(*uuid); } } None diff --git a/src/router.rs b/src/router.rs index e0924e1..6fbb9c8 100644 --- a/src/router.rs +++ b/src/router.rs @@ -1,16 +1,18 @@ -use anyhow::Result; +use anyhow::{Context, Result, bail}; use chrono::Utc; +use etherparse::err::packet; use ipnet::Ipv4Net; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; +use serde_bytes::Bytes; +use std::{collections::HashMap, mem, net::SocketAddr, sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{ TcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}, }, - sync::{Mutex, RwLock}, + sync::{Mutex, Notify, RwLock}, time::Instant, }; use uuid::Uuid; @@ -19,36 +21,42 @@ use crate::client::ClientCfg; pub static KEEP_ALIVE_INTERVAL: Duration = tokio::time::Duration::from_secs(30); pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(100); +pub static TCP_NODELAY: bool = true; pub const SERVER_PACKET_SIZE: usize = 1024 * 9; pub trait ReceiverTrait {} +pub type RoutesMap = HashMap; + #[derive(Debug, Clone, Serialize, Deserialize)] pub enum RouterMessages { CliReg(CliRegMessages), KeepAlive(i64), Data(VpnPacket), Quit(String), + RouteUpdate(RoutesMap), Unknown(String), } +impl RouterMessages { + pub async fn to_bytes(&self) -> Vec { + serde_json::to_vec(self).expect("Unable to serialize RouterMessages") + } + + pub async fn from_slice(slice: &[u8]) -> Self { + serde_json::from_slice(slice).unwrap_or(RouterMessages::Unknown(String::from_utf8_lossy(slice).to_string())) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct VpnPacket { pub src_uuid: Uuid, pub dst_uuid: Uuid, - pub payload: Vec, -} -impl RouterMessages { - pub fn to_bytes(&self) -> Vec { - serde_json::to_vec(self).expect("Unable to serialize RouteMessages") - } - pub fn from_slice(slice: &[u8]) -> Self { - serde_json::from_slice(slice).unwrap_or(RouterMessages::Unknown( - String::from_utf8(slice.to_vec()).unwrap_or_else(|b| format!("Invalid UTF-8: {:?}", b.as_bytes())), - )) - } + pub data_len: usize, } +impl RouterMessages {} + #[derive(Debug, Clone, Serialize, Deserialize)] pub enum CliRegMessages { Reg(ClientCfg), @@ -82,8 +90,12 @@ impl VPNClient { pub fn id(&self) -> Uuid { self.id } - pub async fn send(&self, msg: RouterMessages) -> Result<()> { - self.stream.send(msg).await + pub async fn send(&self, msg: RouterMessages, data: Option>) -> Result { + self.stream.send(msg, data).await + } + + pub async fn receive(&self, buf: &mut [u8]) -> Result<(RouterMessages, Option>)> { + self.stream.receive(buf).await } pub async fn close(&self) -> Result<()> { @@ -94,20 +106,101 @@ impl VPNClient { #[derive(Debug, Clone)] pub struct ClientStream { tx: Arc>, + rx: Arc>, } impl ClientStream { - pub fn new(tx: OwnedWriteHalf) -> Self { + pub fn new(tx: TcpStream) -> Self { + let (rx, tx) = tx.into_split(); Self { - // write can be shared tx: Arc::new(Mutex::new(tx)), - // read is done only by one task + rx: Arc::new(Mutex::new(rx)), } } - pub async fn send(&self, msg: RouterMessages) -> Result<()> { - let bytes = msg.to_bytes(); + pub async fn send(&self, msg: RouterMessages, payload: Option>) -> Result { + // [u16 total_len][u16 msg_len][u16 payload_len][msg bytes][payload bytes] + + let msg_bytes = serde_json::to_vec(&msg)?; + let msg_len: u16 = msg_bytes.len().try_into().map_err(|e| { + anyhow::anyhow!(format!( + "Failed to convert message length to u16: {} (message size: {} bytes)", + e, + msg_bytes.len() + )) + })?; + + let payload_len: u16 = payload.as_ref().map_or(0, |p| p.len()).try_into().map_err(|e| { + anyhow::anyhow!(format!( + "Failed to convert payload length to u16: {} (payload size: {} bytes)", + e, + payload.as_ref().map_or(0, |p| p.len()) + )) + })?; + + let total_len_usize = 6usize + msg_bytes.len() + payload_len as usize; + + if total_len_usize > u16::MAX as usize { + bail!("Total packet size {} exceeds maximum of {}", total_len_usize, u16::MAX); + } + + let total_len = total_len_usize as u16; + + let mut packet = Vec::with_capacity(total_len_usize); + packet.extend_from_slice(&total_len.to_be_bytes()); + packet.extend_from_slice(&msg_len.to_be_bytes()); + packet.extend_from_slice(&payload_len.to_be_bytes()); + packet.extend_from_slice(&msg_bytes); + + if payload.is_some() { + // unwrap is safe here because we already checked payload is some + packet.extend_from_slice(&payload.unwrap()); + } + println!( + "Sending packet: total_len={}, msg_len={}, payload_len={}, message={:?}", + total_len, msg_len, payload_len, msg + ); let mut tx = self.tx.lock().await; - tx.write_all(&bytes).await?; - Ok(()) + tx.write_all(&packet).await?; + Ok(packet.len()) + } + + pub async fn receive(&self, buf: &mut [u8]) -> Result<(RouterMessages, Option>)> { + // [u16 total_len][u16 msg_len][u16 payload_len][msg bytes][payload bytes] + + let mut rx = self.rx.lock().await; + let router_message: RouterMessages; + let payload: Option>; + + match rx.read_exact(&mut buf[..6]).await { + Ok(_) => { + let total_len = u16::from_be_bytes(buf[0..2].try_into().unwrap()) as usize; + let msg_len = u16::from_be_bytes(buf[2..4].try_into().unwrap()) as usize; + let payload_len = u16::from_be_bytes(buf[4..6].try_into().unwrap()) as usize; + + if total_len != 6 + msg_len + payload_len { + bail!( + "Invalid packet length: total_len={}, but expected {}", + total_len, + 6 + msg_len + payload_len + ); + } + + rx.read_exact(&mut buf[..msg_len]).await?; + router_message = RouterMessages::from_slice(&buf[..msg_len]).await; + + payload = if payload_len > 0 { + rx.read_exact(&mut buf[..payload_len]).await?; + Some(buf[..payload_len].to_vec()) + } else { + None + } + } + Err(e) => { + eprintln!("Error reading from client stream: {}", e); + return Err(anyhow::anyhow!(format!("Error reading from client stream: {}", e))); + } + }; + + Ok((router_message, payload)) } pub async fn close(&self) -> Result<()> { @@ -121,37 +214,69 @@ impl ClientStream { pub struct Router { clients: Arc>>, routing_table: Arc>>, + notify: Arc, } impl Router { - pub async fn register_client(&self, routing_table: &[Ipv4Net], vpn_client: VPNClient) -> Result<()> { - let id = Uuid::new_v4(); - for net in routing_table { - self.routing_table.write().await.insert(*net, id); + pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> { + let id = vpn_client.id(); + let mut client_local_route = client.local_routes; + client_local_route.push(client.interface_ip); + self.routing_table.write().await.insert(client.interface_ip, id); + for net in client_local_route { + self.routing_table.write().await.insert(net, id); } self.clients.write().await.insert(id, vpn_client); - + self.notify.notify_waiters(); Ok(()) } + + pub async fn remove_client(&self, client_id: Uuid) -> Result<()> { + self.clients.write().await.remove(&client_id); + self.routing_table.write().await.retain(|_, &mut id| id != client_id); + self.notify.notify_waiters(); + Ok(()) + } + + pub async fn get_routing_table(&self) -> HashMap { + self.routing_table.read().await.clone() + } + + pub async fn get_client(&self, uuid: Uuid) -> Option { + self.clients.read().await.get(&uuid).cloned() + } + + pub async fn notify_changes(&self) { + self.notify.notified().await; + for client in self.clients.read().await.values() { + println!("Notifying client {} of routing table change", client.id()); + if let Err(e) = client + .send(RouterMessages::RouteUpdate(self.get_routing_table().await), None) + .await + { + eprintln!("Error notifying client {}: {}", client.id(), e); + } + } + } } -pub async fn start(bind_address: SocketAddr) { +pub async fn start(bind_address: SocketAddr) -> Result<()> { println!("Starting router on {}...", bind_address); let router = Router::default(); - let socket = tokio::net::TcpListener::bind(bind_address).await.unwrap(); + let socket = tokio::net::TcpListener::bind(bind_address).await?; println!("Router is listening on {}...", bind_address); loop { + let router = router.clone(); match socket.accept().await { Ok((tcp_stream, addr)) => { println!("Accepted connection from {}", addr); //Clone the router for the new task - let router = router.clone(); tokio::spawn(async move { println!("Handling connection from {}", addr); - match handle_client(router.clone(), tcp_stream).await { + match handle_client(router, tcp_stream).await { Ok(_) => println!("Finished handling connection from {}", addr), Err(e) => eprintln!("Error handling connection from {}: {}", addr, e), } @@ -162,54 +287,72 @@ pub async fn start(bind_address: SocketAddr) { } } } + // Ok(()) } pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { - let (mut rx, tx) = stream.into_split(); - let vpn_client = VPNClient::new(Uuid::new_v4(), ClientStream::new(tx)); + stream.set_nodelay(TCP_NODELAY)?; + + let vpn_client = VPNClient::new(Uuid::new_v4(), ClientStream::new(stream)); let mut keep_alive_tick = tokio::time::interval_at(Instant::now() + KEEP_ALIVE_INTERVAL, KEEP_ALIVE_INTERVAL); let mut buf = vec![0u8; SERVER_PACKET_SIZE]; - match client_init(&mut rx, &vpn_client, &mut buf).await { + match client_init(&vpn_client, &mut buf).await { Ok(client) => { println!( - "Client {} registered with routing table: {:?}", + "Client {} registered with routing table: {:?} and local endpoint {:?}", vpn_client.id(), - client.local_routes + client.local_routes, + client.interface_ip ); println!("Registering client {} with router...", vpn_client.id()); - router.register_client(&client.local_routes, vpn_client.clone()).await?; + router.register_client(client, vpn_client.clone()).await?; } Err(e) => { eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e); return Err(e); } } + loop { tokio::select! { - msg = rx.read(&mut buf) => { + msg = vpn_client.receive(&mut buf) => { match msg { - Ok(0) => { - println!("Client {} closed the connection", vpn_client.id()); - return Ok(()); - } - Ok(n) => { - let msg = RouterMessages::from_slice(&buf[..n]); - println!("Received message from client {}: {:?}", vpn_client.id(), msg); - // Here you would implement the logic to handle messages from the client, such as routing data to other clients based on the routing table + Ok((message, data)) => { + println!("Received message from client {}: {:?}", vpn_client.id(), message); + match message { + RouterMessages::Data(packet_info) => { + if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await { + println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id()); + if let Err(e) = dst_client.send(RouterMessages::Data(packet_info), data).await { + eprintln!("Error forwarding packet to client {}: {}", dst_client.id(), e); + } + } else { + eprintln!("Destination client {} not found for packet from client {}", packet_info.dst_uuid, vpn_client.id()); + } + } + _ => println!("Received message from client {}: {:?}", vpn_client.id(), message) + } } Err(e) => { eprintln!("Error reading from client {}: {}", vpn_client.id(), e); + println!("Removing client {} from router...", vpn_client.id()); + router.remove_client(vpn_client.id()).await?; + vpn_client.close().await?; return Err(anyhow::anyhow!(format!("Error reading from client: {}", e))); } } } + _= router.notify_changes() => { + println!("Routing table updated. Current routing table:"); + } + _= keep_alive_tick.tick() => { // Send keep-alive message to the client - vpn_client.send(RouterMessages::KeepAlive(Utc::now().timestamp_micros())).await?; + vpn_client.send(RouterMessages::KeepAlive(Utc::now().timestamp_micros()), None).await?; } } } @@ -217,25 +360,19 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { //Ok(()) } -pub async fn client_init(rx: &mut OwnedReadHalf, vpn_client: &VPNClient, buf: &mut [u8]) -> Result { +pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result { let mut client_registration_timeout = tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT); loop { tokio::select! { - msg = rx.read(buf) => { + msg = vpn_client.receive(buf) => { match msg { - Ok(0) => { - println!("Client {} closed the connection during registration", vpn_client.id()); - return Err(anyhow::anyhow!("Client closed the connection during registration")); - } - Ok(n) => { - let msg = RouterMessages::from_slice(&buf[..n]); - match msg { + Ok((router_msg, _)) => { + match router_msg { RouterMessages::CliReg(CliRegMessages::Reg(client))=> { println!("Received client registration with routing table: {:?}", client.local_routes); - let uuid = Uuid::new_v4(); - vpn_client.send(RouterMessages::CliReg(CliRegMessages::RegOk(uuid))).await?; + vpn_client.send(RouterMessages::CliReg(CliRegMessages::RegOk(vpn_client.id())),None).await?; return Ok(client); } router_msg => { @@ -250,17 +387,17 @@ pub async fn client_init(rx: &mut OwnedReadHalf, vpn_client: &VPNClient, buf: &m } Err(e) => { eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e); + vpn_client.close().await?; return Err(anyhow::anyhow!(format!("Error reading from client during registration: {}", e))); + } } - - } _ = client_registration_timeout.tick() => { let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis())); - vpn_client.send(RouterMessages::Quit(msg.clone())).await?; + vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?; vpn_client.close().await?; return Err(anyhow::anyhow!(msg)); diff --git a/src/tun.rs b/src/tun.rs index 378382c..9b0076e 100644 --- a/src/tun.rs +++ b/src/tun.rs @@ -1,4 +1,6 @@ -use anyhow::Result; +use anyhow::{Context, Result}; +use ipnet::Ipv4Net; +use tokio::process::Command; use tun_rs::{AsyncDevice, DeviceBuilder}; use crate::client::ClientCfg; @@ -27,3 +29,45 @@ pub async fn inti_tun_interface(config: &ClientCfg) -> Result { Ok(device) } + +pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> { + let dev = tun_device.name().context("failed to get tun device name")?; + + println!("Adding route {} dev {}", route, dev); + + let output = Command::new("ip") + .args(["route", "replace", &route.to_string(), "dev", &dev]) + .output() + .await + .context("failed to execute ip route")?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr); + eprintln!("{}", msg); + return Err(anyhow::anyhow!(msg)); + } + + Ok(()) +} + +pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> { + let dev = tun_device.name().context("failed to get tun device name")?; + + println!("Deleting route {} dev {}", route, dev); + + let output = Command::new("ip") + .args(["route", "del", &route.to_string(), "dev", &dev]) + .output() + .await + .context("failed to execute ip route")?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr); + eprintln!("{}", msg); + return Err(anyhow::anyhow!(msg)); + } + + Ok(()) +}