Refactor client and router code for improved readability, fix function naming, and enhance error handling; add initial test cases and configuration file
This commit is contained in:
6
client_test01/.config/vpn_config.toml
Normal file
6
client_test01/.config/vpn_config.toml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
[mode.Client]
|
||||||
|
server = "127.0.0.1:443"
|
||||||
|
interface_ip = "5.5.5.5/32"
|
||||||
|
interface_name = "xvpn1"
|
||||||
|
local_routes = ["6.6.6.0/24"]
|
||||||
|
mtu = 1400
|
||||||
1
client_test01/xvpn
Symbolic link
1
client_test01/xvpn
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../target/release/xvpn
|
||||||
@@ -13,7 +13,7 @@ use crate::{
|
|||||||
CLIENT_REGISTER_TIMEOUT, CliRegMessages, ClientStream, RouterMessages, RoutesMap, SERVER_PACKET_SIZE,
|
CLIENT_REGISTER_TIMEOUT, CliRegMessages, ClientStream, RouterMessages, RoutesMap, SERVER_PACKET_SIZE,
|
||||||
TCP_NODELAY, VpnPacket,
|
TCP_NODELAY, VpnPacket,
|
||||||
},
|
},
|
||||||
tun::{add_route, del_route, inti_tun_interface},
|
tun::{add_route, del_route, init_tun_if},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct ClientStaTistic {
|
pub struct ClientStaTistic {
|
||||||
@@ -50,7 +50,7 @@ pub struct ClientCfg {
|
|||||||
pub mtu: u16,
|
pub mtu: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn start(config: ClientCfg) -> Result<()> {
|
pub async fn start(mut config: ClientCfg) -> Result<()> {
|
||||||
println!("Starting client with config: {:?}", config);
|
println!("Starting client with config: {:?}", config);
|
||||||
|
|
||||||
let stream = tokio::net::TcpStream::connect(&config.server).await?;
|
let stream = tokio::net::TcpStream::connect(&config.server).await?;
|
||||||
@@ -61,9 +61,9 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
|||||||
let mut vpn_buf = vec![0u8; SERVER_PACKET_SIZE];
|
let mut vpn_buf = vec![0u8; SERVER_PACKET_SIZE];
|
||||||
let mut tun_buf = vec![0u8; config.mtu as usize];
|
let mut tun_buf = vec![0u8; config.mtu as usize];
|
||||||
|
|
||||||
let self_uuid = register_client(&client_stream, &config, &mut vpn_buf).await?;
|
let self_uuid = register_client(&client_stream, &mut config, &mut vpn_buf).await?;
|
||||||
|
|
||||||
let tun_device = inti_tun_interface(&config).await?;
|
let tun_device = init_tun_if(&config).await?;
|
||||||
|
|
||||||
let mut route_map = RoutesMap::default();
|
let mut route_map = RoutesMap::default();
|
||||||
|
|
||||||
@@ -88,17 +88,14 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
|||||||
|
|
||||||
RouterMessages::RouteUpdate(mut updated_routes) => {
|
RouterMessages::RouteUpdate(mut updated_routes) => {
|
||||||
updated_routes.retain(|_ , u| *u != self_uuid);
|
updated_routes.retain(|_ , u| *u != self_uuid);
|
||||||
println!("Received route update from router: {:?}", updated_routes);
|
|
||||||
|
|
||||||
let removed_routes = route_map.iter().filter(|(n,_)| !updated_routes.contains_key(n));
|
let removed_routes = route_map.iter().filter(|(n,_)| !updated_routes.contains_key(n));
|
||||||
println!("Removed routes: {:?}", removed_routes);
|
|
||||||
for removed_route in removed_routes {
|
for removed_route in removed_routes {
|
||||||
println!("Route {:?} removed by router", removed_route);
|
println!("Route {:?} removed by router", removed_route);
|
||||||
del_route(&tun_device, *removed_route.0).await?;
|
del_route(&tun_device, *removed_route.0).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_routes = updated_routes.iter().filter(|(n,_)| !route_map.contains_key(n));
|
let new_routes = updated_routes.iter().filter(|(n,_)| !route_map.contains_key(n));
|
||||||
println!("New routes: {:?}", new_routes);
|
|
||||||
for new_route in new_routes {
|
for new_route in new_routes {
|
||||||
println!("Route {} added by router", new_route.0);
|
println!("Route {} added by router", new_route.0);
|
||||||
add_route(&tun_device, *new_route.0).await?;
|
add_route(&tun_device, *new_route.0).await?;
|
||||||
@@ -108,21 +105,23 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RouterMessages::Data(packet_info) => {
|
RouterMessages::Data(_packet_info) => {
|
||||||
if let Some(data) = data {
|
if let Some(data) = data {
|
||||||
println!("Received data packet from router for client {}: size {} bytes", packet_info.src_uuid, packet_info.data_len);
|
|
||||||
tun_device.send(&data).await?;
|
tun_device.send(&data).await?;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
_ => println!("Received message from router: {:?}", message),
|
_ => {
|
||||||
|
let msg = format!("Unexpected message type received from router: {:?}.", message);
|
||||||
|
client_stream.close().await?;
|
||||||
|
bail!(msg);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Error reading from router: {}", e);
|
|
||||||
client_stream.close().await?;
|
client_stream.close().await?;
|
||||||
bail!(format!("Error reading from router: {}", e));
|
bail!("Error reading from router: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -131,11 +130,9 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
|||||||
match data {
|
match data {
|
||||||
Ok(n) => {
|
Ok(n) => {
|
||||||
let packet = etherparse::Ipv4HeaderSlice::from_slice(&tun_buf[..n])?;
|
let packet = etherparse::Ipv4HeaderSlice::from_slice(&tun_buf[..n])?;
|
||||||
println!("Read packet from TUN interface: {} -> {}, size: {}", packet.source_addr(), packet.destination_addr(), n);
|
|
||||||
let dst = packet.destination_addr();
|
let dst = packet.destination_addr();
|
||||||
match ip_match_network(dst, &route_map).await {
|
match ip_match_network(dst, &route_map).await {
|
||||||
Some(uuid) => {
|
Some(uuid) => {
|
||||||
println!("Packet destination {} matches route for client {}, sending to router", dst, uuid);
|
|
||||||
let msg = VpnPacket{
|
let msg = VpnPacket{
|
||||||
dst_uuid: uuid,
|
dst_uuid: uuid,
|
||||||
src_uuid: self_uuid,
|
src_uuid: self_uuid,
|
||||||
@@ -150,9 +147,8 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Error reading from TUN interface: {}", e);
|
|
||||||
client_stream.close().await?;
|
client_stream.close().await?;
|
||||||
bail!(format!("Error reading from TUN interface: {}", e));
|
bail!("Error reading from TUN interface: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -171,8 +167,11 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
|||||||
// Ok(())
|
// Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, buf: &mut [u8]) -> Result<Uuid> {
|
pub async fn register_client(client_stream: &ClientStream, client_cfg: &mut ClientCfg, buf: &mut [u8]) -> Result<Uuid> {
|
||||||
let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(config.clone()));
|
let local_peer_address = client_cfg.interface_ip;
|
||||||
|
client_cfg.local_routes.push(local_peer_address);
|
||||||
|
|
||||||
|
let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(client_cfg.clone()));
|
||||||
|
|
||||||
let mut client_registration_timeout =
|
let mut client_registration_timeout =
|
||||||
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
|
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
|
||||||
@@ -192,20 +191,17 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
|
|||||||
}
|
}
|
||||||
|
|
||||||
RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => {
|
RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => {
|
||||||
eprintln!("Client registration failed: {}", err_msg);
|
|
||||||
bail!("Client registration failed: {}", err_msg);
|
bail!("Client registration failed: {}", err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => {
|
unexpected_msg => {
|
||||||
let msg = "Unexpected message type received during client registration.";
|
let msg = format!("Unexpected message type received during client registration: {:?}.", unexpected_msg);
|
||||||
eprintln!("{}", msg);
|
|
||||||
client_stream.close().await?;
|
client_stream.close().await?;
|
||||||
bail!(msg);
|
bail!(msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Error reading from router during client registration: {}", e);
|
|
||||||
client_stream.close().await?;
|
client_stream.close().await?;
|
||||||
bail!("Error reading from router: {}", e);
|
bail!("Error reading from router: {}", e);
|
||||||
}
|
}
|
||||||
@@ -213,8 +209,6 @@ pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, b
|
|||||||
}
|
}
|
||||||
_= client_registration_timeout.tick() => {
|
_= client_registration_timeout.tick() => {
|
||||||
let msg = "Client registration timed out waiting for confirmation from router.";
|
let msg = "Client registration timed out waiting for confirmation from router.";
|
||||||
eprintln!("{}", msg);
|
|
||||||
eprintln!("Closing connection with Server");
|
|
||||||
client_stream.close().await?;
|
client_stream.close().await?;
|
||||||
bail!(msg);
|
bail!(msg);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ pub mod config;
|
|||||||
pub mod network;
|
pub mod network;
|
||||||
pub mod router;
|
pub mod router;
|
||||||
pub mod tun;
|
pub mod tun;
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
|
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
|
|
||||||
|
|||||||
189
src/router.rs
189
src/router.rs
@@ -108,8 +108,8 @@ pub struct ClientStream {
|
|||||||
rx: Arc<Mutex<OwnedReadHalf>>,
|
rx: Arc<Mutex<OwnedReadHalf>>,
|
||||||
}
|
}
|
||||||
impl ClientStream {
|
impl ClientStream {
|
||||||
pub fn new(tx: TcpStream) -> Self {
|
pub fn new(tcp_stream: TcpStream) -> Self {
|
||||||
let (rx, tx) = tx.into_split();
|
let (rx, tx) = tcp_stream.into_split();
|
||||||
Self {
|
Self {
|
||||||
tx: Arc::new(Mutex::new(tx)),
|
tx: Arc::new(Mutex::new(tx)),
|
||||||
rx: Arc::new(Mutex::new(rx)),
|
rx: Arc::new(Mutex::new(rx)),
|
||||||
@@ -153,10 +153,7 @@ impl ClientStream {
|
|||||||
// unwrap is safe here because we already checked payload is some
|
// unwrap is safe here because we already checked payload is some
|
||||||
packet.extend_from_slice(&payload.unwrap());
|
packet.extend_from_slice(&payload.unwrap());
|
||||||
}
|
}
|
||||||
println!(
|
|
||||||
"Sending packet: total_len={}, msg_len={}, payload_len={}, message={:?}",
|
|
||||||
total_len, msg_len, payload_len, msg
|
|
||||||
);
|
|
||||||
if total_len as usize > SERVER_PACKET_SIZE {
|
if total_len as usize > SERVER_PACKET_SIZE {
|
||||||
bail!(
|
bail!(
|
||||||
"Packet too large to send, receive buffer too small: total_len={}, buf_len={}",
|
"Packet too large to send, receive buffer too small: total_len={}, buf_len={}",
|
||||||
@@ -225,28 +222,61 @@ impl ClientStream {
|
|||||||
tx.shutdown().await?;
|
tx.shutdown().await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn peer_addr(&self) -> Result<SocketAddr> {
|
||||||
|
let tx = self.tx.lock().await;
|
||||||
|
let peer_addr = tx.peer_addr()?;
|
||||||
|
Ok(peer_addr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct Router {
|
pub struct Router {
|
||||||
clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>,
|
clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>,
|
||||||
routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>,
|
routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>,
|
||||||
|
registration_lock: Arc<Mutex<()>>,
|
||||||
}
|
}
|
||||||
impl Router {
|
impl Router {
|
||||||
pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> {
|
pub async fn register_client(
|
||||||
let id = vpn_client.id();
|
&self,
|
||||||
|
client_stream: ClientStream,
|
||||||
|
client_cfg: &ClientCfg,
|
||||||
|
id: Uuid,
|
||||||
|
) -> Result<VPNClient> {
|
||||||
|
// Ensure only one registration process can run at a time to
|
||||||
|
// Drops lock immediately after checking for route conflicts and updating routing table
|
||||||
|
// to minimize contention and allow concurrent client registrations
|
||||||
|
// as much as possible while still preventing
|
||||||
|
let _reg_lock = self.registration_lock.lock().await;
|
||||||
|
|
||||||
// Append local interface IP to client's local routes for routing table construction
|
println!(
|
||||||
let mut client_local_route = client.local_routes;
|
"Registering client {} with local routes: {:?}",
|
||||||
client_local_route.push(client.interface_ip);
|
id, client_cfg.local_routes
|
||||||
|
);
|
||||||
|
|
||||||
// Add Client local routes to global routing table
|
// Build a map of the client's routes to their UUID for efficient registration and routing table updates
|
||||||
for net in client_local_route {
|
let reg_routes: HashMap<Ipv4Net, Uuid> =
|
||||||
self.routing_table.write().await.insert(net, id);
|
HashMap::from_iter(client_cfg.local_routes.iter().map(|route| (*route, id)));
|
||||||
|
|
||||||
|
for route in reg_routes.keys() {
|
||||||
|
if let Some(existing_client_id) = self.check_if_route_exists(&route).await {
|
||||||
|
let router_msg = format!(
|
||||||
|
"Route {} overlaps with existing route for client {}, cannot register client {}",
|
||||||
|
route, existing_client_id, id
|
||||||
|
);
|
||||||
|
let cli_reg_msg = format!("Failed to register, you have overlapping routes {} ", route);
|
||||||
|
client_stream
|
||||||
|
.send(RouterMessages::CliReg(CliRegMessages::RegFailed(cli_reg_msg)), None)
|
||||||
|
.await?;
|
||||||
|
bail!(router_msg);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.clients.write().await.insert(id, vpn_client);
|
let vpn_client = VPNClient::new(id, client_stream);
|
||||||
self.update_all_clients().await
|
self.clients.write().await.insert(id, vpn_client.clone());
|
||||||
|
|
||||||
|
self.routing_table.write().await.extend(reg_routes);
|
||||||
|
return Ok(vpn_client);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn remove_client(&self, client_id: Uuid) -> Result<()> {
|
pub async fn remove_client(&self, client_id: Uuid) -> Result<()> {
|
||||||
@@ -263,6 +293,15 @@ impl Router {
|
|||||||
self.clients.read().await.get(&uuid).cloned()
|
self.clients.read().await.get(&uuid).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn check_if_route_exists(&self, net: &Ipv4Net) -> Option<Uuid> {
|
||||||
|
for (existing_net, existing_client_id) in self.routing_table.read().await.iter() {
|
||||||
|
if existing_net.contains(net) || net.contains(existing_net) {
|
||||||
|
return Some(*existing_client_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn update_all_clients(&self) -> Result<()> {
|
pub async fn update_all_clients(&self) -> Result<()> {
|
||||||
// Lock only the routing table and clients list once to get the current state,
|
// Lock only the routing table and clients list once to get the current state,
|
||||||
// then release locks before sending updates to avoid blocking other operations
|
// then release locks before sending updates to avoid blocking other operations
|
||||||
@@ -325,9 +364,8 @@ pub async fn start(bind_address: SocketAddr) -> Result<()> {
|
|||||||
_ = handle_router_connections(router.clone(), socket) => {}
|
_ = handle_router_connections(router.clone(), socket) => {}
|
||||||
|
|
||||||
_ = tokio::signal::ctrl_c() => {
|
_ = tokio::signal::ctrl_c() => {
|
||||||
println!();
|
|
||||||
let msg = "Received shutdown signal, shutting down router...";
|
let msg = "Received shutdown signal, shutting down router...";
|
||||||
println!("{}", msg);
|
println!("\n{}", msg);
|
||||||
router.close_all_client(msg).await?
|
router.close_all_client(msg).await?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -343,11 +381,19 @@ async fn handle_router_connections(router: Router, socket: tokio::net::TcpListen
|
|||||||
println!("Accepted connection from {}", addr);
|
println!("Accepted connection from {}", addr);
|
||||||
//Clone the router for the new task
|
//Clone the router for the new task
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
println!("Handling connection from {}", addr);
|
let id = Uuid::new_v4();
|
||||||
match handle_client(router, tcp_stream).await {
|
println!("Handling connection from {} :{}", addr, id);
|
||||||
Ok(_) => println!("Finished handling connection from {}", addr),
|
match handle_client(router.clone(), tcp_stream, id).await {
|
||||||
|
Ok(_) => {
|
||||||
|
println!("Finished handling connection from {}", addr);
|
||||||
|
}
|
||||||
Err(e) => eprintln!("Error handling connection from {}: {}", addr, e),
|
Err(e) => eprintln!("Error handling connection from {}: {}", addr, e),
|
||||||
}
|
}
|
||||||
|
println!("Cleaning up client {} from connection {}", id, addr);
|
||||||
|
router
|
||||||
|
.remove_client(id)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|e| eprintln!("Error removing client {}: {}", id, e));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -357,48 +403,73 @@ async fn handle_router_connections(router: Router, socket: tokio::net::TcpListen
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
pub async fn handle_client(router: Router, tcp_stream: TcpStream, id: Uuid) -> Result<Uuid> {
|
||||||
stream.set_nodelay(TCP_NODELAY)?;
|
tcp_stream.set_nodelay(TCP_NODELAY)?;
|
||||||
|
|
||||||
let vpn_client = VPNClient::new(Uuid::new_v4(), ClientStream::new(stream));
|
|
||||||
|
|
||||||
let mut keep_alive_tick = tokio::time::interval_at(Instant::now() + KEEP_ALIVE_INTERVAL, KEEP_ALIVE_INTERVAL);
|
let mut keep_alive_tick = tokio::time::interval_at(Instant::now() + KEEP_ALIVE_INTERVAL, KEEP_ALIVE_INTERVAL);
|
||||||
let mut buf = vec![0u8; SERVER_PACKET_SIZE];
|
let mut buf = vec![0u8; SERVER_PACKET_SIZE];
|
||||||
|
|
||||||
match client_init(&vpn_client, &mut buf).await {
|
let peer_address = tcp_stream.peer_addr()?;
|
||||||
Ok(client) => {
|
|
||||||
|
let (client_stream, client_config) = match client_init(tcp_stream, &mut buf, id).await {
|
||||||
|
Ok((client_stream, client_cfg)) => {
|
||||||
println!(
|
println!(
|
||||||
"Client {} registered with routing table: {:?} and local endpoint {:?}",
|
"Client {} registered with routing table: {:?} and local endpoint {:?}",
|
||||||
vpn_client.id(),
|
id, client_cfg.local_routes, peer_address
|
||||||
client.local_routes,
|
|
||||||
client.interface_ip
|
|
||||||
);
|
);
|
||||||
println!("Registering client {} with router...", vpn_client.id());
|
(client_stream, client_cfg)
|
||||||
router.register_client(client, vpn_client.clone()).await?;
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e);
|
let msg = format!(
|
||||||
return Err(e);
|
"Failed to initialize client connection form peer {:?}:{} {}",
|
||||||
|
peer_address, id, e
|
||||||
|
);
|
||||||
|
bail!(msg);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
let vpn_client = match router.register_client(client_stream, &client_config, id).await {
|
||||||
|
Ok(vpn_client) => {
|
||||||
|
println!(
|
||||||
|
"Successfully registered client {} with peer address {:?} and routing table: {:?} ",
|
||||||
|
vpn_client.id(),
|
||||||
|
peer_address,
|
||||||
|
client_config.local_routes
|
||||||
|
);
|
||||||
|
vpn_client
|
||||||
|
.send(RouterMessages::CliReg(CliRegMessages::RegOk(id)), None)
|
||||||
|
.await?;
|
||||||
|
vpn_client
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let msg = format!(
|
||||||
|
"Failed to register client {} with peer address {:?}: {}",
|
||||||
|
id, peer_address, e
|
||||||
|
);
|
||||||
|
bail!(msg);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"Finished client initialization for client {} with peer address {:?}, entering main loop...",
|
||||||
|
vpn_client.id(),
|
||||||
|
peer_address
|
||||||
|
);
|
||||||
|
|
||||||
|
router.update_all_clients().await?;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
msg = vpn_client.receive(&mut buf) => {
|
msg = vpn_client.receive(&mut buf) => {
|
||||||
match msg {
|
match msg {
|
||||||
Ok((message, data)) => {
|
Ok((message, data)) => {
|
||||||
println!("Received message from client {}: {:?}", vpn_client.id(), message);
|
|
||||||
match message {
|
match message {
|
||||||
RouterMessages::Quit(msg) =>{
|
RouterMessages::Quit(msg) =>{
|
||||||
println!("Received quit message from client {}: {}", vpn_client.id(), msg);
|
println!("Received quit message from client {}: {}", vpn_client.id(), msg);
|
||||||
println!("Removing client {} from router...", vpn_client.id());
|
println!("Removing client {} from router...", vpn_client.id());
|
||||||
router.remove_client(vpn_client.id()).await?;
|
|
||||||
vpn_client.close().await?;
|
vpn_client.close().await?;
|
||||||
return Ok(());
|
return Ok(id);
|
||||||
}
|
}
|
||||||
RouterMessages::Data(packet_info) => {
|
RouterMessages::Data(packet_info) => {
|
||||||
if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await {
|
if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await {
|
||||||
println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id());
|
|
||||||
if let Err(e) = dst_client.send(RouterMessages::Data(packet_info), data).await {
|
if let Err(e) = dst_client.send(RouterMessages::Data(packet_info), data).await {
|
||||||
eprintln!("Error forwarding packet to client {}: {}", dst_client.id(), e);
|
eprintln!("Error forwarding packet to client {}: {}", dst_client.id(), e);
|
||||||
}
|
}
|
||||||
@@ -406,13 +477,15 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
|||||||
eprintln!("Destination client {} not found for packet from client {}", packet_info.dst_uuid, vpn_client.id());
|
eprintln!("Destination client {} not found for packet from client {}", packet_info.dst_uuid, vpn_client.id());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => println!("Received message from client {}: {:?}", vpn_client.id(), message)
|
_ => {
|
||||||
|
let msg = format!("Received unexpected message from client {}: {:?}", vpn_client.id(), message);
|
||||||
|
vpn_client.close().await?;
|
||||||
|
bail!(msg);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Error reading from client {}: {}", vpn_client.id(), e);
|
eprintln!("Error removing client {} from router...", vpn_client.id());
|
||||||
println!("Removing client {} from router...", vpn_client.id());
|
|
||||||
router.remove_client(vpn_client.id()).await?;
|
|
||||||
vpn_client.close().await?;
|
vpn_client.close().await?;
|
||||||
bail!(format!("Error reading from client: {}", e));
|
bail!(format!("Error reading from client: {}", e));
|
||||||
}
|
}
|
||||||
@@ -420,9 +493,7 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// _= router.notify_changes(vpn_client.id()) => {
|
|
||||||
// println!("Routing table updated. Current routing table:");
|
|
||||||
// }
|
|
||||||
|
|
||||||
_= keep_alive_tick.tick() => {
|
_= keep_alive_tick.tick() => {
|
||||||
// Send keep-alive message to the client
|
// Send keep-alive message to the client
|
||||||
@@ -435,35 +506,33 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
|||||||
//Ok(())
|
//Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<ClientCfg> {
|
pub async fn client_init(tcp_stream: TcpStream, buf: &mut [u8], id: Uuid) -> Result<(ClientStream, ClientCfg)> {
|
||||||
|
let client_stream = ClientStream::new(tcp_stream);
|
||||||
let mut client_registration_timeout =
|
let mut client_registration_timeout =
|
||||||
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
|
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
msg = vpn_client.receive(buf) => {
|
msg = client_stream.receive(buf) => {
|
||||||
match msg {
|
match msg {
|
||||||
Ok((router_msg, _)) => {
|
Ok((router_msg, _)) => {
|
||||||
match router_msg {
|
match router_msg {
|
||||||
RouterMessages::CliReg(CliRegMessages::Reg(client))=> {
|
RouterMessages::CliReg(CliRegMessages::Reg(client_cfg))=> {
|
||||||
println!("Received client registration with routing table: {:?}", client.local_routes);
|
println!("Received client registration with routing table: {:?}", client_cfg.local_routes);
|
||||||
vpn_client.send(RouterMessages::CliReg(CliRegMessages::RegOk(vpn_client.id())),None).await?;
|
return Ok((client_stream,client_cfg));
|
||||||
return Ok(client);
|
|
||||||
}
|
}
|
||||||
router_msg => {
|
router_msg => {
|
||||||
let msg = format!("Expected client registration message, but received: {:?}", router_msg);
|
let msg = format!("Expected client registration message, but received: {:?}", router_msg);
|
||||||
eprintln!("{}", msg);
|
client_stream.close().await?;
|
||||||
eprintln!("Closing connection with client {}", vpn_client.id());
|
|
||||||
vpn_client.close().await?;
|
|
||||||
bail!(msg);
|
bail!(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e);
|
let msg = format!("Error reading from client {} during registration: {:?}", id, e);
|
||||||
vpn_client.close().await?;
|
client_stream.close().await?;
|
||||||
bail!(format!("Error reading from client during registration: {}", e));
|
bail!(msg);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -472,8 +541,8 @@ pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<Clien
|
|||||||
|
|
||||||
_ = client_registration_timeout.tick() => {
|
_ = client_registration_timeout.tick() => {
|
||||||
let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis()));
|
let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis()));
|
||||||
vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?;
|
client_stream.send(RouterMessages::Quit(msg.clone()),None).await?;
|
||||||
vpn_client.close().await?;
|
client_stream.close().await?;
|
||||||
bail!(msg);
|
bail!(msg);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
111
src/tests.rs
Normal file
111
src/tests.rs
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
|
use ipnet::Ipv4Net;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
client::ClientCfg,
|
||||||
|
network::ip_match_network,
|
||||||
|
router::{CliRegMessages, RouterMessages, VpnPacket},
|
||||||
|
};
|
||||||
|
|
||||||
|
fn sample_cfg() -> ClientCfg {
|
||||||
|
ClientCfg {
|
||||||
|
server: "127.0.0.1:443".to_string(),
|
||||||
|
interface_ip: "10.8.0.2/32".parse().expect("cidr parse"),
|
||||||
|
interface_name: "xvpn0".to_string(),
|
||||||
|
local_routes: vec!["10.10.0.0/24".parse().expect("cidr parse")],
|
||||||
|
mtu: 1400,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn router_message_keep_alive_roundtrip_json() {
|
||||||
|
let msg = RouterMessages::KeepAlive(123_456);
|
||||||
|
let bytes = msg.to_bytes().await;
|
||||||
|
let decoded = RouterMessages::from_slice(&bytes).await;
|
||||||
|
|
||||||
|
match decoded {
|
||||||
|
RouterMessages::KeepAlive(ts) => assert_eq!(ts, 123_456),
|
||||||
|
other => panic!("unexpected decoded message: {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn router_message_data_roundtrip_json() {
|
||||||
|
let src = Uuid::new_v4();
|
||||||
|
let dst = Uuid::new_v4();
|
||||||
|
let msg = RouterMessages::Data(VpnPacket {
|
||||||
|
src_uuid: src,
|
||||||
|
dst_uuid: dst,
|
||||||
|
data_len: 512,
|
||||||
|
});
|
||||||
|
|
||||||
|
let bytes = msg.to_bytes().await;
|
||||||
|
let decoded = RouterMessages::from_slice(&bytes).await;
|
||||||
|
|
||||||
|
match decoded {
|
||||||
|
RouterMessages::Data(pkt) => {
|
||||||
|
assert_eq!(pkt.src_uuid, src);
|
||||||
|
assert_eq!(pkt.dst_uuid, dst);
|
||||||
|
assert_eq!(pkt.data_len, 512);
|
||||||
|
}
|
||||||
|
other => panic!("unexpected decoded message: {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn router_message_from_invalid_slice_is_unknown() {
|
||||||
|
let decoded = RouterMessages::from_slice(b"{ not valid json").await;
|
||||||
|
match decoded {
|
||||||
|
RouterMessages::Unknown(s) => assert!(!s.is_empty(), "unknown payload should not be empty"),
|
||||||
|
other => panic!("expected Unknown, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cli_reg_message_reg_roundtrip_json() {
|
||||||
|
let cfg = sample_cfg();
|
||||||
|
let msg = CliRegMessages::Reg(cfg.clone());
|
||||||
|
let bytes = msg.to_bytes();
|
||||||
|
let decoded = CliRegMessages::from_slice(&bytes);
|
||||||
|
|
||||||
|
match decoded {
|
||||||
|
CliRegMessages::Reg(decoded_cfg) => {
|
||||||
|
assert_eq!(decoded_cfg.server, cfg.server);
|
||||||
|
assert_eq!(decoded_cfg.interface_ip, cfg.interface_ip);
|
||||||
|
assert_eq!(decoded_cfg.interface_name, cfg.interface_name);
|
||||||
|
assert_eq!(decoded_cfg.local_routes, cfg.local_routes);
|
||||||
|
assert_eq!(decoded_cfg.mtu, cfg.mtu);
|
||||||
|
}
|
||||||
|
other => panic!("unexpected decoded message: {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cli_reg_message_invalid_utf8_is_unknown() {
|
||||||
|
let decoded = CliRegMessages::from_slice(&[0xff, 0xfe, 0xfd]);
|
||||||
|
match decoded {
|
||||||
|
CliRegMessages::Unknown(s) => assert!(s.contains("Invalid UTF-8"), "unexpected content: {s}"),
|
||||||
|
other => panic!("expected Unknown, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ip_match_network_returns_expected_uuid() {
|
||||||
|
let mut routes = std::collections::HashMap::new();
|
||||||
|
let id = Uuid::new_v4();
|
||||||
|
routes.insert("192.168.1.0/24".parse::<Ipv4Net>().expect("cidr parse"), id);
|
||||||
|
|
||||||
|
let match_id = ip_match_network(Ipv4Addr::new(192, 168, 1, 44), &routes).await;
|
||||||
|
assert_eq!(match_id, Some(id));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ip_match_network_returns_none_when_no_route_matches() {
|
||||||
|
let mut routes = std::collections::HashMap::new();
|
||||||
|
routes.insert("192.168.1.0/24".parse::<Ipv4Net>().expect("cidr parse"), Uuid::new_v4());
|
||||||
|
|
||||||
|
let no_match = ip_match_network(Ipv4Addr::new(10, 10, 10, 10), &routes).await;
|
||||||
|
assert!(no_match.is_none(), "non-matching ip should return none");
|
||||||
|
}
|
||||||
@@ -5,7 +5,7 @@ use tun_rs::{AsyncDevice, DeviceBuilder};
|
|||||||
|
|
||||||
use crate::client::ClientCfg;
|
use crate::client::ClientCfg;
|
||||||
|
|
||||||
pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> {
|
pub async fn init_tun_if(config: &ClientCfg) -> Result<AsyncDevice> {
|
||||||
println!(
|
println!(
|
||||||
"Initializing TUN interface with name: {}, IP: {}/{}, MTU: {}",
|
"Initializing TUN interface with name: {}, IP: {}/{}, MTU: {}",
|
||||||
config.interface_name,
|
config.interface_name,
|
||||||
@@ -22,7 +22,6 @@ pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> {
|
|||||||
Ok(dev) => dev,
|
Ok(dev) => dev,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let msg = format!("Failed to create TUN interface: {:#}", e);
|
let msg = format!("Failed to create TUN interface: {:#}", e);
|
||||||
eprintln!("{}", msg);
|
|
||||||
bail!(msg);
|
bail!(msg);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -44,7 +43,6 @@ pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
|
|||||||
if !output.status.success() {
|
if !output.status.success() {
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr);
|
let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr);
|
||||||
eprintln!("{}", msg);
|
|
||||||
bail!(msg);
|
bail!(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,7 +63,6 @@ pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
|
|||||||
if !output.status.success() {
|
if !output.status.success() {
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr);
|
let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr);
|
||||||
eprintln!("{}", msg);
|
|
||||||
bail!(msg);
|
bail!(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user