Refactor client and router code for improved readability, fix function naming, and enhance error handling; add initial test cases and configuration file

This commit is contained in:
2026-03-05 14:36:08 +01:00
parent 479270834d
commit dad49936be
7 changed files with 269 additions and 89 deletions

View File

@@ -0,0 +1,6 @@
[mode.Client]
server = "127.0.0.1:443"
interface_ip = "5.5.5.5/32"
interface_name = "xvpn1"
local_routes = ["6.6.6.0/24"]
mtu = 1400

1
client_test01/xvpn Symbolic link
View File

@@ -0,0 +1 @@
../target/release/xvpn

View File

@@ -13,7 +13,7 @@ use crate::{
CLIENT_REGISTER_TIMEOUT, CliRegMessages, ClientStream, RouterMessages, RoutesMap, SERVER_PACKET_SIZE,
TCP_NODELAY, VpnPacket,
},
tun::{add_route, del_route, inti_tun_interface},
tun::{add_route, del_route, init_tun_if},
};
pub struct ClientStaTistic {
@@ -50,7 +50,7 @@ pub struct ClientCfg {
pub mtu: u16,
}
pub async fn start(config: ClientCfg) -> Result<()> {
pub async fn start(mut config: ClientCfg) -> Result<()> {
println!("Starting client with config: {:?}", config);
let stream = tokio::net::TcpStream::connect(&config.server).await?;
@@ -61,9 +61,9 @@ pub async fn start(config: ClientCfg) -> Result<()> {
let mut vpn_buf = vec![0u8; SERVER_PACKET_SIZE];
let mut tun_buf = vec![0u8; config.mtu as usize];
let self_uuid = register_client(&client_stream, &config, &mut vpn_buf).await?;
let self_uuid = register_client(&client_stream, &mut config, &mut vpn_buf).await?;
let tun_device = inti_tun_interface(&config).await?;
let tun_device = init_tun_if(&config).await?;
let mut route_map = RoutesMap::default();
@@ -88,17 +88,14 @@ pub async fn start(config: ClientCfg) -> Result<()> {
RouterMessages::RouteUpdate(mut updated_routes) => {
updated_routes.retain(|_ , u| *u != self_uuid);
println!("Received route update from router: {:?}", updated_routes);
let removed_routes = route_map.iter().filter(|(n,_)| !updated_routes.contains_key(n));
println!("Removed routes: {:?}", removed_routes);
for removed_route in removed_routes {
println!("Route {:?} removed by router", removed_route);
del_route(&tun_device, *removed_route.0).await?;
}
let new_routes = updated_routes.iter().filter(|(n,_)| !route_map.contains_key(n));
println!("New routes: {:?}", new_routes);
for new_route in new_routes {
println!("Route {} added by router", new_route.0);
add_route(&tun_device, *new_route.0).await?;
@@ -108,21 +105,23 @@ pub async fn start(config: ClientCfg) -> Result<()> {
}
RouterMessages::Data(packet_info) => {
RouterMessages::Data(_packet_info) => {
if let Some(data) = data {
println!("Received data packet from router for client {}: size {} bytes", packet_info.src_uuid, packet_info.data_len);
tun_device.send(&data).await?;
}
},
_ => println!("Received message from router: {:?}", message),
_ => {
let msg = format!("Unexpected message type received from router: {:?}.", message);
client_stream.close().await?;
bail!(msg);
}
}
}
Err(e) => {
eprintln!("Error reading from router: {}", e);
client_stream.close().await?;
bail!(format!("Error reading from router: {}", e));
bail!("Error reading from router: {}", e);
}
}
}
@@ -131,11 +130,9 @@ pub async fn start(config: ClientCfg) -> Result<()> {
match data {
Ok(n) => {
let packet = etherparse::Ipv4HeaderSlice::from_slice(&tun_buf[..n])?;
println!("Read packet from TUN interface: {} -> {}, size: {}", packet.source_addr(), packet.destination_addr(), n);
let dst = packet.destination_addr();
match ip_match_network(dst, &route_map).await {
Some(uuid) => {
println!("Packet destination {} matches route for client {}, sending to router", dst, uuid);
let msg = VpnPacket{
dst_uuid: uuid,
src_uuid: self_uuid,
@@ -150,9 +147,8 @@ pub async fn start(config: ClientCfg) -> Result<()> {
}
Err(e) => {
eprintln!("Error reading from TUN interface: {}", e);
client_stream.close().await?;
bail!(format!("Error reading from TUN interface: {}", e));
bail!("Error reading from TUN interface: {}", e);
}
}
}
@@ -171,8 +167,11 @@ pub async fn start(config: ClientCfg) -> Result<()> {
// Ok(())
}
pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, buf: &mut [u8]) -> Result<Uuid> {
let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(config.clone()));
pub async fn register_client(client_stream: &ClientStream, client_cfg: &mut ClientCfg, buf: &mut [u8]) -> Result<Uuid> {
let local_peer_address = client_cfg.interface_ip;
client_cfg.local_routes.push(local_peer_address);
let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(client_cfg.clone()));
let mut client_registration_timeout =
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
@@ -192,20 +191,17 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
}
RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => {
eprintln!("Client registration failed: {}", err_msg);
bail!("Client registration failed: {}", err_msg);
}
_ => {
let msg = "Unexpected message type received during client registration.";
eprintln!("{}", msg);
unexpected_msg => {
let msg = format!("Unexpected message type received during client registration: {:?}.", unexpected_msg);
client_stream.close().await?;
bail!(msg);
}
}
}
Err(e) => {
eprintln!("Error reading from router during client registration: {}", e);
client_stream.close().await?;
bail!("Error reading from router: {}", e);
}
@@ -213,8 +209,6 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
}
_= client_registration_timeout.tick() => {
let msg = "Client registration timed out waiting for confirmation from router.";
eprintln!("{}", msg);
eprintln!("Closing connection with Server");
client_stream.close().await?;
bail!(msg);
}

View File

@@ -3,6 +3,8 @@ pub mod config;
pub mod network;
pub mod router;
pub mod tun;
#[cfg(test)]
mod tests;
use clap::{Parser, Subcommand};

View File

@@ -108,8 +108,8 @@ pub struct ClientStream {
rx: Arc<Mutex<OwnedReadHalf>>,
}
impl ClientStream {
pub fn new(tx: TcpStream) -> Self {
let (rx, tx) = tx.into_split();
pub fn new(tcp_stream: TcpStream) -> Self {
let (rx, tx) = tcp_stream.into_split();
Self {
tx: Arc::new(Mutex::new(tx)),
rx: Arc::new(Mutex::new(rx)),
@@ -153,10 +153,7 @@ impl ClientStream {
// unwrap is safe here because we already checked payload is some
packet.extend_from_slice(&payload.unwrap());
}
println!(
"Sending packet: total_len={}, msg_len={}, payload_len={}, message={:?}",
total_len, msg_len, payload_len, msg
);
if total_len as usize > SERVER_PACKET_SIZE {
bail!(
"Packet too large to send, receive buffer too small: total_len={}, buf_len={}",
@@ -225,28 +222,61 @@ impl ClientStream {
tx.shutdown().await?;
Ok(())
}
pub async fn peer_addr(&self) -> Result<SocketAddr> {
let tx = self.tx.lock().await;
let peer_addr = tx.peer_addr()?;
Ok(peer_addr)
}
}
#[derive(Debug, Clone, Default)]
pub struct Router {
clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>,
routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>,
registration_lock: Arc<Mutex<()>>,
}
impl Router {
pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> {
let id = vpn_client.id();
pub async fn register_client(
&self,
client_stream: ClientStream,
client_cfg: &ClientCfg,
id: Uuid,
) -> Result<VPNClient> {
// Ensure only one registration process can run at a time to
// Drops lock immediately after checking for route conflicts and updating routing table
// to minimize contention and allow concurrent client registrations
// as much as possible while still preventing
let _reg_lock = self.registration_lock.lock().await;
// Append local interface IP to client's local routes for routing table construction
let mut client_local_route = client.local_routes;
client_local_route.push(client.interface_ip);
println!(
"Registering client {} with local routes: {:?}",
id, client_cfg.local_routes
);
// Add Client local routes to global routing table
for net in client_local_route {
self.routing_table.write().await.insert(net, id);
// Build a map of the client's routes to their UUID for efficient registration and routing table updates
let reg_routes: HashMap<Ipv4Net, Uuid> =
HashMap::from_iter(client_cfg.local_routes.iter().map(|route| (*route, id)));
for route in reg_routes.keys() {
if let Some(existing_client_id) = self.check_if_route_exists(&route).await {
let router_msg = format!(
"Route {} overlaps with existing route for client {}, cannot register client {}",
route, existing_client_id, id
);
let cli_reg_msg = format!("Failed to register, you have overlapping routes {} ", route);
client_stream
.send(RouterMessages::CliReg(CliRegMessages::RegFailed(cli_reg_msg)), None)
.await?;
bail!(router_msg);
}
}
self.clients.write().await.insert(id, vpn_client);
self.update_all_clients().await
let vpn_client = VPNClient::new(id, client_stream);
self.clients.write().await.insert(id, vpn_client.clone());
self.routing_table.write().await.extend(reg_routes);
return Ok(vpn_client);
}
pub async fn remove_client(&self, client_id: Uuid) -> Result<()> {
@@ -263,6 +293,15 @@ impl Router {
self.clients.read().await.get(&uuid).cloned()
}
pub async fn check_if_route_exists(&self, net: &Ipv4Net) -> Option<Uuid> {
for (existing_net, existing_client_id) in self.routing_table.read().await.iter() {
if existing_net.contains(net) || net.contains(existing_net) {
return Some(*existing_client_id);
}
}
None
}
pub async fn update_all_clients(&self) -> Result<()> {
// Lock only the routing table and clients list once to get the current state,
// then release locks before sending updates to avoid blocking other operations
@@ -325,9 +364,8 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> {
_ = handle_router_connections(router.clone(), socket) => {}
_ = tokio::signal::ctrl_c() => {
println!();
let msg = "Received shutdown signal, shutting down router...";
println!("{}", msg);
println!("\n{}", msg);
router.close_all_client(msg).await?
}
}
@@ -343,11 +381,19 @@ async fn handle_router_connections(router: Router, socket: tokio::net::TcpListen
println!("Accepted connection from {}", addr);
//Clone the router for the new task
tokio::spawn(async move {
println!("Handling connection from {}", addr);
match handle_client(router, tcp_stream).await {
Ok(_) => println!("Finished handling connection from {}", addr),
let id = Uuid::new_v4();
println!("Handling connection from {} :{}", addr, id);
match handle_client(router.clone(), tcp_stream, id).await {
Ok(_) => {
println!("Finished handling connection from {}", addr);
}
Err(e) => eprintln!("Error handling connection from {}: {}", addr, e),
}
println!("Cleaning up client {} from connection {}", id, addr);
router
.remove_client(id)
.await
.unwrap_or_else(|e| eprintln!("Error removing client {}: {}", id, e));
});
}
Err(e) => {
@@ -357,48 +403,73 @@ async fn handle_router_connections(router: Router, socket: tokio::net::TcpListen
}
}
pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
stream.set_nodelay(TCP_NODELAY)?;
let vpn_client = VPNClient::new(Uuid::new_v4(), ClientStream::new(stream));
pub async fn handle_client(router: Router, tcp_stream: TcpStream, id: Uuid) -> Result<Uuid> {
tcp_stream.set_nodelay(TCP_NODELAY)?;
let mut keep_alive_tick = tokio::time::interval_at(Instant::now() + KEEP_ALIVE_INTERVAL, KEEP_ALIVE_INTERVAL);
let mut buf = vec![0u8; SERVER_PACKET_SIZE];
match client_init(&vpn_client, &mut buf).await {
Ok(client) => {
let peer_address = tcp_stream.peer_addr()?;
let (client_stream, client_config) = match client_init(tcp_stream, &mut buf, id).await {
Ok((client_stream, client_cfg)) => {
println!(
"Client {} registered with routing table: {:?} and local endpoint {:?}",
vpn_client.id(),
client.local_routes,
client.interface_ip
id, client_cfg.local_routes, peer_address
);
println!("Registering client {} with router...", vpn_client.id());
router.register_client(client, vpn_client.clone()).await?;
(client_stream, client_cfg)
}
Err(e) => {
eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e);
return Err(e);
let msg = format!(
"Failed to initialize client connection form peer {:?}:{} {}",
peer_address, id, e
);
bail!(msg);
}
};
let vpn_client = match router.register_client(client_stream, &client_config, id).await {
Ok(vpn_client) => {
println!(
"Successfully registered client {} with peer address {:?} and routing table: {:?} ",
vpn_client.id(),
peer_address,
client_config.local_routes
);
vpn_client
.send(RouterMessages::CliReg(CliRegMessages::RegOk(id)), None)
.await?;
vpn_client
}
Err(e) => {
let msg = format!(
"Failed to register client {} with peer address {:?}: {}",
id, peer_address, e
);
bail!(msg);
}
};
println!(
"Finished client initialization for client {} with peer address {:?}, entering main loop...",
vpn_client.id(),
peer_address
);
router.update_all_clients().await?;
loop {
tokio::select! {
msg = vpn_client.receive(&mut buf) => {
match msg {
Ok((message, data)) => {
println!("Received message from client {}: {:?}", vpn_client.id(), message);
match message {
RouterMessages::Quit(msg) =>{
println!("Received quit message from client {}: {}", vpn_client.id(), msg);
println!("Removing client {} from router...", vpn_client.id());
router.remove_client(vpn_client.id()).await?;
vpn_client.close().await?;
return Ok(());
return Ok(id);
}
RouterMessages::Data(packet_info) => {
if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await {
println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id());
if let Err(e) = dst_client.send(RouterMessages::Data(packet_info), data).await {
eprintln!("Error forwarding packet to client {}: {}", dst_client.id(), e);
}
@@ -406,13 +477,15 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
eprintln!("Destination client {} not found for packet from client {}", packet_info.dst_uuid, vpn_client.id());
}
}
_ => println!("Received message from client {}: {:?}", vpn_client.id(), message)
_ => {
let msg = format!("Received unexpected message from client {}: {:?}", vpn_client.id(), message);
vpn_client.close().await?;
bail!(msg);
}
}
}
Err(e) => {
eprintln!("Error reading from client {}: {}", vpn_client.id(), e);
println!("Removing client {} from router...", vpn_client.id());
router.remove_client(vpn_client.id()).await?;
eprintln!("Error removing client {} from router...", vpn_client.id());
vpn_client.close().await?;
bail!(format!("Error reading from client: {}", e));
}
@@ -420,9 +493,7 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
}
// _= router.notify_changes(vpn_client.id()) => {
// println!("Routing table updated. Current routing table:");
// }
_= keep_alive_tick.tick() => {
// Send keep-alive message to the client
@@ -435,35 +506,33 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
//Ok(())
}
pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<ClientCfg> {
pub async fn client_init(tcp_stream: TcpStream, buf: &mut [u8], id: Uuid) -> Result<(ClientStream, ClientCfg)> {
let client_stream = ClientStream::new(tcp_stream);
let mut client_registration_timeout =
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
loop {
tokio::select! {
msg = vpn_client.receive(buf) => {
msg = client_stream.receive(buf) => {
match msg {
Ok((router_msg, _)) => {
match router_msg {
RouterMessages::CliReg(CliRegMessages::Reg(client))=> {
println!("Received client registration with routing table: {:?}", client.local_routes);
vpn_client.send(RouterMessages::CliReg(CliRegMessages::RegOk(vpn_client.id())),None).await?;
return Ok(client);
RouterMessages::CliReg(CliRegMessages::Reg(client_cfg))=> {
println!("Received client registration with routing table: {:?}", client_cfg.local_routes);
return Ok((client_stream,client_cfg));
}
router_msg => {
let msg = format!("Expected client registration message, but received: {:?}", router_msg);
eprintln!("{}", msg);
eprintln!("Closing connection with client {}", vpn_client.id());
vpn_client.close().await?;
client_stream.close().await?;
bail!(msg);
}
}
}
Err(e) => {
eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e);
vpn_client.close().await?;
bail!(format!("Error reading from client during registration: {}", e));
let msg = format!("Error reading from client {} during registration: {:?}", id, e);
client_stream.close().await?;
bail!(msg);
}
}
@@ -472,8 +541,8 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
_ = client_registration_timeout.tick() => {
let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis()));
vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?;
vpn_client.close().await?;
client_stream.send(RouterMessages::Quit(msg.clone()),None).await?;
client_stream.close().await?;
bail!(msg);
}

111
src/tests.rs Normal file
View File

@@ -0,0 +1,111 @@
use std::net::Ipv4Addr;
use ipnet::Ipv4Net;
use uuid::Uuid;
use crate::{
client::ClientCfg,
network::ip_match_network,
router::{CliRegMessages, RouterMessages, VpnPacket},
};
fn sample_cfg() -> ClientCfg {
ClientCfg {
server: "127.0.0.1:443".to_string(),
interface_ip: "10.8.0.2/32".parse().expect("cidr parse"),
interface_name: "xvpn0".to_string(),
local_routes: vec!["10.10.0.0/24".parse().expect("cidr parse")],
mtu: 1400,
}
}
#[tokio::test]
async fn router_message_keep_alive_roundtrip_json() {
let msg = RouterMessages::KeepAlive(123_456);
let bytes = msg.to_bytes().await;
let decoded = RouterMessages::from_slice(&bytes).await;
match decoded {
RouterMessages::KeepAlive(ts) => assert_eq!(ts, 123_456),
other => panic!("unexpected decoded message: {other:?}"),
}
}
#[tokio::test]
async fn router_message_data_roundtrip_json() {
let src = Uuid::new_v4();
let dst = Uuid::new_v4();
let msg = RouterMessages::Data(VpnPacket {
src_uuid: src,
dst_uuid: dst,
data_len: 512,
});
let bytes = msg.to_bytes().await;
let decoded = RouterMessages::from_slice(&bytes).await;
match decoded {
RouterMessages::Data(pkt) => {
assert_eq!(pkt.src_uuid, src);
assert_eq!(pkt.dst_uuid, dst);
assert_eq!(pkt.data_len, 512);
}
other => panic!("unexpected decoded message: {other:?}"),
}
}
#[tokio::test]
async fn router_message_from_invalid_slice_is_unknown() {
let decoded = RouterMessages::from_slice(b"{ not valid json").await;
match decoded {
RouterMessages::Unknown(s) => assert!(!s.is_empty(), "unknown payload should not be empty"),
other => panic!("expected Unknown, got {other:?}"),
}
}
#[test]
fn cli_reg_message_reg_roundtrip_json() {
let cfg = sample_cfg();
let msg = CliRegMessages::Reg(cfg.clone());
let bytes = msg.to_bytes();
let decoded = CliRegMessages::from_slice(&bytes);
match decoded {
CliRegMessages::Reg(decoded_cfg) => {
assert_eq!(decoded_cfg.server, cfg.server);
assert_eq!(decoded_cfg.interface_ip, cfg.interface_ip);
assert_eq!(decoded_cfg.interface_name, cfg.interface_name);
assert_eq!(decoded_cfg.local_routes, cfg.local_routes);
assert_eq!(decoded_cfg.mtu, cfg.mtu);
}
other => panic!("unexpected decoded message: {other:?}"),
}
}
#[test]
fn cli_reg_message_invalid_utf8_is_unknown() {
let decoded = CliRegMessages::from_slice(&[0xff, 0xfe, 0xfd]);
match decoded {
CliRegMessages::Unknown(s) => assert!(s.contains("Invalid UTF-8"), "unexpected content: {s}"),
other => panic!("expected Unknown, got {other:?}"),
}
}
#[tokio::test]
async fn ip_match_network_returns_expected_uuid() {
let mut routes = std::collections::HashMap::new();
let id = Uuid::new_v4();
routes.insert("192.168.1.0/24".parse::<Ipv4Net>().expect("cidr parse"), id);
let match_id = ip_match_network(Ipv4Addr::new(192, 168, 1, 44), &routes).await;
assert_eq!(match_id, Some(id));
}
#[tokio::test]
async fn ip_match_network_returns_none_when_no_route_matches() {
let mut routes = std::collections::HashMap::new();
routes.insert("192.168.1.0/24".parse::<Ipv4Net>().expect("cidr parse"), Uuid::new_v4());
let no_match = ip_match_network(Ipv4Addr::new(10, 10, 10, 10), &routes).await;
assert!(no_match.is_none(), "non-matching ip should return none");
}

View File

@@ -5,7 +5,7 @@ use tun_rs::{AsyncDevice, DeviceBuilder};
use crate::client::ClientCfg;
pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> {
pub async fn init_tun_if(config: &ClientCfg) -> Result<AsyncDevice> {
println!(
"Initializing TUN interface with name: {}, IP: {}/{}, MTU: {}",
config.interface_name,
@@ -22,7 +22,6 @@ pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> {
Ok(dev) => dev,
Err(e) => {
let msg = format!("Failed to create TUN interface: {:#}", e);
eprintln!("{}", msg);
bail!(msg);
}
};
@@ -44,7 +43,6 @@ pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr);
eprintln!("{}", msg);
bail!(msg);
}
@@ -65,7 +63,6 @@ pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr);
eprintln!("{}", msg);
bail!(msg);
}