Refactor client and router code for improved readability, fix function naming, and enhance error handling; add initial test cases and configuration file

This commit is contained in:
2026-03-05 14:36:08 +01:00
parent 479270834d
commit dad49936be
7 changed files with 269 additions and 89 deletions

View File

@@ -0,0 +1,6 @@
[mode.Client]
server = "127.0.0.1:443"
interface_ip = "5.5.5.5/32"
interface_name = "xvpn1"
local_routes = ["6.6.6.0/24"]
mtu = 1400

1
client_test01/xvpn Symbolic link
View File

@@ -0,0 +1 @@
../target/release/xvpn

View File

@@ -13,7 +13,7 @@ use crate::{
CLIENT_REGISTER_TIMEOUT, CliRegMessages, ClientStream, RouterMessages, RoutesMap, SERVER_PACKET_SIZE, CLIENT_REGISTER_TIMEOUT, CliRegMessages, ClientStream, RouterMessages, RoutesMap, SERVER_PACKET_SIZE,
TCP_NODELAY, VpnPacket, TCP_NODELAY, VpnPacket,
}, },
tun::{add_route, del_route, inti_tun_interface}, tun::{add_route, del_route, init_tun_if},
}; };
pub struct ClientStaTistic { pub struct ClientStaTistic {
@@ -50,7 +50,7 @@ pub struct ClientCfg {
pub mtu: u16, pub mtu: u16,
} }
pub async fn start(config: ClientCfg) -> Result<()> { pub async fn start(mut config: ClientCfg) -> Result<()> {
println!("Starting client with config: {:?}", config); println!("Starting client with config: {:?}", config);
let stream = tokio::net::TcpStream::connect(&config.server).await?; let stream = tokio::net::TcpStream::connect(&config.server).await?;
@@ -61,9 +61,9 @@ pub async fn start(config: ClientCfg) -> Result<()> {
let mut vpn_buf = vec![0u8; SERVER_PACKET_SIZE]; let mut vpn_buf = vec![0u8; SERVER_PACKET_SIZE];
let mut tun_buf = vec![0u8; config.mtu as usize]; let mut tun_buf = vec![0u8; config.mtu as usize];
let self_uuid = register_client(&client_stream, &config, &mut vpn_buf).await?; let self_uuid = register_client(&client_stream, &mut config, &mut vpn_buf).await?;
let tun_device = inti_tun_interface(&config).await?; let tun_device = init_tun_if(&config).await?;
let mut route_map = RoutesMap::default(); let mut route_map = RoutesMap::default();
@@ -88,17 +88,14 @@ pub async fn start(config: ClientCfg) -> Result<()> {
RouterMessages::RouteUpdate(mut updated_routes) => { RouterMessages::RouteUpdate(mut updated_routes) => {
updated_routes.retain(|_ , u| *u != self_uuid); updated_routes.retain(|_ , u| *u != self_uuid);
println!("Received route update from router: {:?}", updated_routes);
let removed_routes = route_map.iter().filter(|(n,_)| !updated_routes.contains_key(n)); let removed_routes = route_map.iter().filter(|(n,_)| !updated_routes.contains_key(n));
println!("Removed routes: {:?}", removed_routes);
for removed_route in removed_routes { for removed_route in removed_routes {
println!("Route {:?} removed by router", removed_route); println!("Route {:?} removed by router", removed_route);
del_route(&tun_device, *removed_route.0).await?; del_route(&tun_device, *removed_route.0).await?;
} }
let new_routes = updated_routes.iter().filter(|(n,_)| !route_map.contains_key(n)); let new_routes = updated_routes.iter().filter(|(n,_)| !route_map.contains_key(n));
println!("New routes: {:?}", new_routes);
for new_route in new_routes { for new_route in new_routes {
println!("Route {} added by router", new_route.0); println!("Route {} added by router", new_route.0);
add_route(&tun_device, *new_route.0).await?; add_route(&tun_device, *new_route.0).await?;
@@ -108,21 +105,23 @@ pub async fn start(config: ClientCfg) -> Result<()> {
} }
RouterMessages::Data(packet_info) => { RouterMessages::Data(_packet_info) => {
if let Some(data) = data { if let Some(data) = data {
println!("Received data packet from router for client {}: size {} bytes", packet_info.src_uuid, packet_info.data_len);
tun_device.send(&data).await?; tun_device.send(&data).await?;
} }
}, },
_ => println!("Received message from router: {:?}", message), _ => {
let msg = format!("Unexpected message type received from router: {:?}.", message);
client_stream.close().await?;
bail!(msg);
}
} }
} }
Err(e) => { Err(e) => {
eprintln!("Error reading from router: {}", e);
client_stream.close().await?; client_stream.close().await?;
bail!(format!("Error reading from router: {}", e)); bail!("Error reading from router: {}", e);
} }
} }
} }
@@ -131,11 +130,9 @@ pub async fn start(config: ClientCfg) -> Result<()> {
match data { match data {
Ok(n) => { Ok(n) => {
let packet = etherparse::Ipv4HeaderSlice::from_slice(&tun_buf[..n])?; let packet = etherparse::Ipv4HeaderSlice::from_slice(&tun_buf[..n])?;
println!("Read packet from TUN interface: {} -> {}, size: {}", packet.source_addr(), packet.destination_addr(), n);
let dst = packet.destination_addr(); let dst = packet.destination_addr();
match ip_match_network(dst, &route_map).await { match ip_match_network(dst, &route_map).await {
Some(uuid) => { Some(uuid) => {
println!("Packet destination {} matches route for client {}, sending to router", dst, uuid);
let msg = VpnPacket{ let msg = VpnPacket{
dst_uuid: uuid, dst_uuid: uuid,
src_uuid: self_uuid, src_uuid: self_uuid,
@@ -150,9 +147,8 @@ pub async fn start(config: ClientCfg) -> Result<()> {
} }
Err(e) => { Err(e) => {
eprintln!("Error reading from TUN interface: {}", e);
client_stream.close().await?; client_stream.close().await?;
bail!(format!("Error reading from TUN interface: {}", e)); bail!("Error reading from TUN interface: {}", e);
} }
} }
} }
@@ -171,8 +167,11 @@ pub async fn start(config: ClientCfg) -> Result<()> {
// Ok(()) // Ok(())
} }
pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, buf: &mut [u8]) -> Result<Uuid> { pub async fn register_client(client_stream: &ClientStream, client_cfg: &mut ClientCfg, buf: &mut [u8]) -> Result<Uuid> {
let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(config.clone())); let local_peer_address = client_cfg.interface_ip;
client_cfg.local_routes.push(local_peer_address);
let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(client_cfg.clone()));
let mut client_registration_timeout = let mut client_registration_timeout =
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT); tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
@@ -192,20 +191,17 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
} }
RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => { RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => {
eprintln!("Client registration failed: {}", err_msg);
bail!("Client registration failed: {}", err_msg); bail!("Client registration failed: {}", err_msg);
} }
_ => { unexpected_msg => {
let msg = "Unexpected message type received during client registration."; let msg = format!("Unexpected message type received during client registration: {:?}.", unexpected_msg);
eprintln!("{}", msg);
client_stream.close().await?; client_stream.close().await?;
bail!(msg); bail!(msg);
} }
} }
} }
Err(e) => { Err(e) => {
eprintln!("Error reading from router during client registration: {}", e);
client_stream.close().await?; client_stream.close().await?;
bail!("Error reading from router: {}", e); bail!("Error reading from router: {}", e);
} }
@@ -213,8 +209,6 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
} }
_= client_registration_timeout.tick() => { _= client_registration_timeout.tick() => {
let msg = "Client registration timed out waiting for confirmation from router."; let msg = "Client registration timed out waiting for confirmation from router.";
eprintln!("{}", msg);
eprintln!("Closing connection with Server");
client_stream.close().await?; client_stream.close().await?;
bail!(msg); bail!(msg);
} }

View File

@@ -3,6 +3,8 @@ pub mod config;
pub mod network; pub mod network;
pub mod router; pub mod router;
pub mod tun; pub mod tun;
#[cfg(test)]
mod tests;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};

View File

@@ -108,8 +108,8 @@ pub struct ClientStream {
rx: Arc<Mutex<OwnedReadHalf>>, rx: Arc<Mutex<OwnedReadHalf>>,
} }
impl ClientStream { impl ClientStream {
pub fn new(tx: TcpStream) -> Self { pub fn new(tcp_stream: TcpStream) -> Self {
let (rx, tx) = tx.into_split(); let (rx, tx) = tcp_stream.into_split();
Self { Self {
tx: Arc::new(Mutex::new(tx)), tx: Arc::new(Mutex::new(tx)),
rx: Arc::new(Mutex::new(rx)), rx: Arc::new(Mutex::new(rx)),
@@ -153,10 +153,7 @@ impl ClientStream {
// unwrap is safe here because we already checked payload is some // unwrap is safe here because we already checked payload is some
packet.extend_from_slice(&payload.unwrap()); packet.extend_from_slice(&payload.unwrap());
} }
println!(
"Sending packet: total_len={}, msg_len={}, payload_len={}, message={:?}",
total_len, msg_len, payload_len, msg
);
if total_len as usize > SERVER_PACKET_SIZE { if total_len as usize > SERVER_PACKET_SIZE {
bail!( bail!(
"Packet too large to send, receive buffer too small: total_len={}, buf_len={}", "Packet too large to send, receive buffer too small: total_len={}, buf_len={}",
@@ -225,28 +222,61 @@ impl ClientStream {
tx.shutdown().await?; tx.shutdown().await?;
Ok(()) Ok(())
} }
pub async fn peer_addr(&self) -> Result<SocketAddr> {
let tx = self.tx.lock().await;
let peer_addr = tx.peer_addr()?;
Ok(peer_addr)
}
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct Router { pub struct Router {
clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>, clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>,
routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>, routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>,
registration_lock: Arc<Mutex<()>>,
} }
impl Router { impl Router {
pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> { pub async fn register_client(
let id = vpn_client.id(); &self,
client_stream: ClientStream,
client_cfg: &ClientCfg,
id: Uuid,
) -> Result<VPNClient> {
// Ensure only one registration process can run at a time to
// Drops lock immediately after checking for route conflicts and updating routing table
// to minimize contention and allow concurrent client registrations
// as much as possible while still preventing
let _reg_lock = self.registration_lock.lock().await;
// Append local interface IP to client's local routes for routing table construction println!(
let mut client_local_route = client.local_routes; "Registering client {} with local routes: {:?}",
client_local_route.push(client.interface_ip); id, client_cfg.local_routes
);
// Add Client local routes to global routing table // Build a map of the client's routes to their UUID for efficient registration and routing table updates
for net in client_local_route { let reg_routes: HashMap<Ipv4Net, Uuid> =
self.routing_table.write().await.insert(net, id); HashMap::from_iter(client_cfg.local_routes.iter().map(|route| (*route, id)));
for route in reg_routes.keys() {
if let Some(existing_client_id) = self.check_if_route_exists(&route).await {
let router_msg = format!(
"Route {} overlaps with existing route for client {}, cannot register client {}",
route, existing_client_id, id
);
let cli_reg_msg = format!("Failed to register, you have overlapping routes {} ", route);
client_stream
.send(RouterMessages::CliReg(CliRegMessages::RegFailed(cli_reg_msg)), None)
.await?;
bail!(router_msg);
}
} }
self.clients.write().await.insert(id, vpn_client); let vpn_client = VPNClient::new(id, client_stream);
self.update_all_clients().await self.clients.write().await.insert(id, vpn_client.clone());
self.routing_table.write().await.extend(reg_routes);
return Ok(vpn_client);
} }
pub async fn remove_client(&self, client_id: Uuid) -> Result<()> { pub async fn remove_client(&self, client_id: Uuid) -> Result<()> {
@@ -263,6 +293,15 @@ impl Router {
self.clients.read().await.get(&uuid).cloned() self.clients.read().await.get(&uuid).cloned()
} }
pub async fn check_if_route_exists(&self, net: &Ipv4Net) -> Option<Uuid> {
for (existing_net, existing_client_id) in self.routing_table.read().await.iter() {
if existing_net.contains(net) || net.contains(existing_net) {
return Some(*existing_client_id);
}
}
None
}
pub async fn update_all_clients(&self) -> Result<()> { pub async fn update_all_clients(&self) -> Result<()> {
// Lock only the routing table and clients list once to get the current state, // Lock only the routing table and clients list once to get the current state,
// then release locks before sending updates to avoid blocking other operations // then release locks before sending updates to avoid blocking other operations
@@ -325,9 +364,8 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> {
_ = handle_router_connections(router.clone(), socket) => {} _ = handle_router_connections(router.clone(), socket) => {}
_ = tokio::signal::ctrl_c() => { _ = tokio::signal::ctrl_c() => {
println!();
let msg = "Received shutdown signal, shutting down router..."; let msg = "Received shutdown signal, shutting down router...";
println!("{}", msg); println!("\n{}", msg);
router.close_all_client(msg).await? router.close_all_client(msg).await?
} }
} }
@@ -343,11 +381,19 @@ async fn handle_router_connections(router: Router, socket: tokio::net::TcpListen
println!("Accepted connection from {}", addr); println!("Accepted connection from {}", addr);
//Clone the router for the new task //Clone the router for the new task
tokio::spawn(async move { tokio::spawn(async move {
println!("Handling connection from {}", addr); let id = Uuid::new_v4();
match handle_client(router, tcp_stream).await { println!("Handling connection from {} :{}", addr, id);
Ok(_) => println!("Finished handling connection from {}", addr), match handle_client(router.clone(), tcp_stream, id).await {
Ok(_) => {
println!("Finished handling connection from {}", addr);
}
Err(e) => eprintln!("Error handling connection from {}: {}", addr, e), Err(e) => eprintln!("Error handling connection from {}: {}", addr, e),
} }
println!("Cleaning up client {} from connection {}", id, addr);
router
.remove_client(id)
.await
.unwrap_or_else(|e| eprintln!("Error removing client {}: {}", id, e));
}); });
} }
Err(e) => { Err(e) => {
@@ -357,48 +403,73 @@ async fn handle_router_connections(router: Router, socket: tokio::net::TcpListen
} }
} }
pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> { pub async fn handle_client(router: Router, tcp_stream: TcpStream, id: Uuid) -> Result<Uuid> {
stream.set_nodelay(TCP_NODELAY)?; tcp_stream.set_nodelay(TCP_NODELAY)?;
let vpn_client = VPNClient::new(Uuid::new_v4(), ClientStream::new(stream));
let mut keep_alive_tick = tokio::time::interval_at(Instant::now() + KEEP_ALIVE_INTERVAL, KEEP_ALIVE_INTERVAL); let mut keep_alive_tick = tokio::time::interval_at(Instant::now() + KEEP_ALIVE_INTERVAL, KEEP_ALIVE_INTERVAL);
let mut buf = vec![0u8; SERVER_PACKET_SIZE]; let mut buf = vec![0u8; SERVER_PACKET_SIZE];
match client_init(&vpn_client, &mut buf).await { let peer_address = tcp_stream.peer_addr()?;
Ok(client) => {
let (client_stream, client_config) = match client_init(tcp_stream, &mut buf, id).await {
Ok((client_stream, client_cfg)) => {
println!( println!(
"Client {} registered with routing table: {:?} and local endpoint {:?}", "Client {} registered with routing table: {:?} and local endpoint {:?}",
vpn_client.id(), id, client_cfg.local_routes, peer_address
client.local_routes,
client.interface_ip
); );
println!("Registering client {} with router...", vpn_client.id()); (client_stream, client_cfg)
router.register_client(client, vpn_client.clone()).await?;
} }
Err(e) => { Err(e) => {
eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e); let msg = format!(
return Err(e); "Failed to initialize client connection form peer {:?}:{} {}",
peer_address, id, e
);
bail!(msg);
} }
} };
let vpn_client = match router.register_client(client_stream, &client_config, id).await {
Ok(vpn_client) => {
println!(
"Successfully registered client {} with peer address {:?} and routing table: {:?} ",
vpn_client.id(),
peer_address,
client_config.local_routes
);
vpn_client
.send(RouterMessages::CliReg(CliRegMessages::RegOk(id)), None)
.await?;
vpn_client
}
Err(e) => {
let msg = format!(
"Failed to register client {} with peer address {:?}: {}",
id, peer_address, e
);
bail!(msg);
}
};
println!(
"Finished client initialization for client {} with peer address {:?}, entering main loop...",
vpn_client.id(),
peer_address
);
router.update_all_clients().await?;
loop { loop {
tokio::select! { tokio::select! {
msg = vpn_client.receive(&mut buf) => { msg = vpn_client.receive(&mut buf) => {
match msg { match msg {
Ok((message, data)) => { Ok((message, data)) => {
println!("Received message from client {}: {:?}", vpn_client.id(), message);
match message { match message {
RouterMessages::Quit(msg) =>{ RouterMessages::Quit(msg) =>{
println!("Received quit message from client {}: {}", vpn_client.id(), msg); println!("Received quit message from client {}: {}", vpn_client.id(), msg);
println!("Removing client {} from router...", vpn_client.id()); println!("Removing client {} from router...", vpn_client.id());
router.remove_client(vpn_client.id()).await?;
vpn_client.close().await?; vpn_client.close().await?;
return Ok(()); return Ok(id);
} }
RouterMessages::Data(packet_info) => { RouterMessages::Data(packet_info) => {
if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await { if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await {
println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id());
if let Err(e) = dst_client.send(RouterMessages::Data(packet_info), data).await { if let Err(e) = dst_client.send(RouterMessages::Data(packet_info), data).await {
eprintln!("Error forwarding packet to client {}: {}", dst_client.id(), e); eprintln!("Error forwarding packet to client {}: {}", dst_client.id(), e);
} }
@@ -406,13 +477,15 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
eprintln!("Destination client {} not found for packet from client {}", packet_info.dst_uuid, vpn_client.id()); eprintln!("Destination client {} not found for packet from client {}", packet_info.dst_uuid, vpn_client.id());
} }
} }
_ => println!("Received message from client {}: {:?}", vpn_client.id(), message) _ => {
let msg = format!("Received unexpected message from client {}: {:?}", vpn_client.id(), message);
vpn_client.close().await?;
bail!(msg);
}
} }
} }
Err(e) => { Err(e) => {
eprintln!("Error reading from client {}: {}", vpn_client.id(), e); eprintln!("Error removing client {} from router...", vpn_client.id());
println!("Removing client {} from router...", vpn_client.id());
router.remove_client(vpn_client.id()).await?;
vpn_client.close().await?; vpn_client.close().await?;
bail!(format!("Error reading from client: {}", e)); bail!(format!("Error reading from client: {}", e));
} }
@@ -420,9 +493,7 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
} }
// _= router.notify_changes(vpn_client.id()) => {
// println!("Routing table updated. Current routing table:");
// }
_= keep_alive_tick.tick() => { _= keep_alive_tick.tick() => {
// Send keep-alive message to the client // Send keep-alive message to the client
@@ -435,35 +506,33 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
//Ok(()) //Ok(())
} }
pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<ClientCfg> { pub async fn client_init(tcp_stream: TcpStream, buf: &mut [u8], id: Uuid) -> Result<(ClientStream, ClientCfg)> {
let client_stream = ClientStream::new(tcp_stream);
let mut client_registration_timeout = let mut client_registration_timeout =
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT); tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
loop { loop {
tokio::select! { tokio::select! {
msg = vpn_client.receive(buf) => { msg = client_stream.receive(buf) => {
match msg { match msg {
Ok((router_msg, _)) => { Ok((router_msg, _)) => {
match router_msg { match router_msg {
RouterMessages::CliReg(CliRegMessages::Reg(client))=> { RouterMessages::CliReg(CliRegMessages::Reg(client_cfg))=> {
println!("Received client registration with routing table: {:?}", client.local_routes); println!("Received client registration with routing table: {:?}", client_cfg.local_routes);
vpn_client.send(RouterMessages::CliReg(CliRegMessages::RegOk(vpn_client.id())),None).await?; return Ok((client_stream,client_cfg));
return Ok(client);
} }
router_msg => { router_msg => {
let msg = format!("Expected client registration message, but received: {:?}", router_msg); let msg = format!("Expected client registration message, but received: {:?}", router_msg);
eprintln!("{}", msg); client_stream.close().await?;
eprintln!("Closing connection with client {}", vpn_client.id());
vpn_client.close().await?;
bail!(msg); bail!(msg);
} }
} }
} }
Err(e) => { Err(e) => {
eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e); let msg = format!("Error reading from client {} during registration: {:?}", id, e);
vpn_client.close().await?; client_stream.close().await?;
bail!(format!("Error reading from client during registration: {}", e)); bail!(msg);
} }
} }
@@ -472,8 +541,8 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
_ = client_registration_timeout.tick() => { _ = client_registration_timeout.tick() => {
let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis())); let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis()));
vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?; client_stream.send(RouterMessages::Quit(msg.clone()),None).await?;
vpn_client.close().await?; client_stream.close().await?;
bail!(msg); bail!(msg);
} }

111
src/tests.rs Normal file
View File

@@ -0,0 +1,111 @@
use std::net::Ipv4Addr;
use ipnet::Ipv4Net;
use uuid::Uuid;
use crate::{
client::ClientCfg,
network::ip_match_network,
router::{CliRegMessages, RouterMessages, VpnPacket},
};
fn sample_cfg() -> ClientCfg {
ClientCfg {
server: "127.0.0.1:443".to_string(),
interface_ip: "10.8.0.2/32".parse().expect("cidr parse"),
interface_name: "xvpn0".to_string(),
local_routes: vec!["10.10.0.0/24".parse().expect("cidr parse")],
mtu: 1400,
}
}
#[tokio::test]
async fn router_message_keep_alive_roundtrip_json() {
let msg = RouterMessages::KeepAlive(123_456);
let bytes = msg.to_bytes().await;
let decoded = RouterMessages::from_slice(&bytes).await;
match decoded {
RouterMessages::KeepAlive(ts) => assert_eq!(ts, 123_456),
other => panic!("unexpected decoded message: {other:?}"),
}
}
#[tokio::test]
async fn router_message_data_roundtrip_json() {
let src = Uuid::new_v4();
let dst = Uuid::new_v4();
let msg = RouterMessages::Data(VpnPacket {
src_uuid: src,
dst_uuid: dst,
data_len: 512,
});
let bytes = msg.to_bytes().await;
let decoded = RouterMessages::from_slice(&bytes).await;
match decoded {
RouterMessages::Data(pkt) => {
assert_eq!(pkt.src_uuid, src);
assert_eq!(pkt.dst_uuid, dst);
assert_eq!(pkt.data_len, 512);
}
other => panic!("unexpected decoded message: {other:?}"),
}
}
#[tokio::test]
async fn router_message_from_invalid_slice_is_unknown() {
let decoded = RouterMessages::from_slice(b"{ not valid json").await;
match decoded {
RouterMessages::Unknown(s) => assert!(!s.is_empty(), "unknown payload should not be empty"),
other => panic!("expected Unknown, got {other:?}"),
}
}
#[test]
fn cli_reg_message_reg_roundtrip_json() {
let cfg = sample_cfg();
let msg = CliRegMessages::Reg(cfg.clone());
let bytes = msg.to_bytes();
let decoded = CliRegMessages::from_slice(&bytes);
match decoded {
CliRegMessages::Reg(decoded_cfg) => {
assert_eq!(decoded_cfg.server, cfg.server);
assert_eq!(decoded_cfg.interface_ip, cfg.interface_ip);
assert_eq!(decoded_cfg.interface_name, cfg.interface_name);
assert_eq!(decoded_cfg.local_routes, cfg.local_routes);
assert_eq!(decoded_cfg.mtu, cfg.mtu);
}
other => panic!("unexpected decoded message: {other:?}"),
}
}
#[test]
fn cli_reg_message_invalid_utf8_is_unknown() {
let decoded = CliRegMessages::from_slice(&[0xff, 0xfe, 0xfd]);
match decoded {
CliRegMessages::Unknown(s) => assert!(s.contains("Invalid UTF-8"), "unexpected content: {s}"),
other => panic!("expected Unknown, got {other:?}"),
}
}
#[tokio::test]
async fn ip_match_network_returns_expected_uuid() {
let mut routes = std::collections::HashMap::new();
let id = Uuid::new_v4();
routes.insert("192.168.1.0/24".parse::<Ipv4Net>().expect("cidr parse"), id);
let match_id = ip_match_network(Ipv4Addr::new(192, 168, 1, 44), &routes).await;
assert_eq!(match_id, Some(id));
}
#[tokio::test]
async fn ip_match_network_returns_none_when_no_route_matches() {
let mut routes = std::collections::HashMap::new();
routes.insert("192.168.1.0/24".parse::<Ipv4Net>().expect("cidr parse"), Uuid::new_v4());
let no_match = ip_match_network(Ipv4Addr::new(10, 10, 10, 10), &routes).await;
assert!(no_match.is_none(), "non-matching ip should return none");
}

View File

@@ -5,7 +5,7 @@ use tun_rs::{AsyncDevice, DeviceBuilder};
use crate::client::ClientCfg; use crate::client::ClientCfg;
pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> { pub async fn init_tun_if(config: &ClientCfg) -> Result<AsyncDevice> {
println!( println!(
"Initializing TUN interface with name: {}, IP: {}/{}, MTU: {}", "Initializing TUN interface with name: {}, IP: {}/{}, MTU: {}",
config.interface_name, config.interface_name,
@@ -22,7 +22,6 @@ pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> {
Ok(dev) => dev, Ok(dev) => dev,
Err(e) => { Err(e) => {
let msg = format!("Failed to create TUN interface: {:#}", e); let msg = format!("Failed to create TUN interface: {:#}", e);
eprintln!("{}", msg);
bail!(msg); bail!(msg);
} }
}; };
@@ -44,7 +43,6 @@ pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
if !output.status.success() { if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr); let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr);
eprintln!("{}", msg);
bail!(msg); bail!(msg);
} }
@@ -65,7 +63,6 @@ pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
if !output.status.success() { if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr); let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr);
eprintln!("{}", msg);
bail!(msg); bail!(msg);
} }