Refactor code structure for improved readability and maintainability
This commit is contained in:
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -832,6 +832,16 @@ dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_bytes"
|
||||
version = "0.11.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_core"
|
||||
version = "1.0.228"
|
||||
@@ -1595,6 +1605,7 @@ dependencies = [
|
||||
"etherparse",
|
||||
"ipnet",
|
||||
"serde",
|
||||
"serde_bytes",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"toml",
|
||||
|
||||
12
Cargo.toml
12
Cargo.toml
@@ -8,16 +8,7 @@ clap = { version = "4.5.37", features = ["derive"] }
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
serde_json = "1.0.149"
|
||||
toml = "1.0.3"
|
||||
tokio = { version = "1.49.0", features = [
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
"time",
|
||||
"fs",
|
||||
"net",
|
||||
"io-util",
|
||||
"sync",
|
||||
"signal",
|
||||
] }
|
||||
tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread", "time", "fs", "net", "io-util", "sync", "signal", "process"] }
|
||||
anyhow = "1.0.102"
|
||||
uuid = { version = "1.21.0", features = ["v4", "serde"] }
|
||||
ipnet = { version = "2.11.0", features = ["serde"] }
|
||||
@@ -25,3 +16,4 @@ base64 = "0.22.1"
|
||||
tun-rs = { version = "2.8.2", features = ["async"] }
|
||||
chrono = "0.4.44"
|
||||
etherparse = "0.19.0"
|
||||
serde_bytes = "0.11.19"
|
||||
|
||||
6
client_test/.config/vpn_config.toml
Normal file
6
client_test/.config/vpn_config.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
[mode.Client]
|
||||
server = "127.0.0.1:443"
|
||||
interface_ip = "2.2.2.2/32"
|
||||
interface_name = "xvpn0"
|
||||
local_routes = ["4.4.4.0/24"]
|
||||
mtu = 1400
|
||||
1
client_test/xvpn
Symbolic link
1
client_test/xvpn
Symbolic link
@@ -0,0 +1 @@
|
||||
../target/release/xvpn
|
||||
127
src/client.rs
127
src/client.rs
@@ -4,16 +4,16 @@ use clap::Args;
|
||||
use ipnet::Ipv4Net;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
time::Instant,
|
||||
};
|
||||
use tokio::time::Instant;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
network::ip_match_network,
|
||||
router::{CLIENT_REGISTER_TIMEOUT, CliRegMessages, RouterMessages, SERVER_PACKET_SIZE},
|
||||
tun::inti_tun_interface,
|
||||
router::{
|
||||
CLIENT_REGISTER_TIMEOUT, CliRegMessages, ClientStream, RouterMessages, RoutesMap, SERVER_PACKET_SIZE,
|
||||
TCP_NODELAY, VpnPacket,
|
||||
},
|
||||
tun::{add_route, del_route, inti_tun_interface},
|
||||
};
|
||||
|
||||
pub struct ClientStaTistic {
|
||||
@@ -54,51 +54,97 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
||||
println!("Starting client with config: {:?}", config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(&config.server).await?;
|
||||
//stream.set_nodelay(true)?;
|
||||
let (mut rx, mut tx) = stream.into_split();
|
||||
// let client_stream = ClientStream::new(tx);
|
||||
stream.set_nodelay(TCP_NODELAY)?;
|
||||
|
||||
let client_stream = ClientStream::new(stream);
|
||||
|
||||
let mut vpn_buf = vec![0u8; SERVER_PACKET_SIZE];
|
||||
let mut tun_buf = vec![0u8; config.mtu as usize];
|
||||
register_client(&mut rx, &mut tx, &config, &mut vpn_buf).await?;
|
||||
|
||||
let self_uuid = register_client(&client_stream, &config, &mut vpn_buf).await?;
|
||||
|
||||
let tun_device = inti_tun_interface(&config).await?;
|
||||
|
||||
let mut route_map = RoutesMap::default();
|
||||
|
||||
println!("Client registration successful. Entering main loop to receive messages from router...");
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = rx.read(&mut vpn_buf) => {
|
||||
msg = client_stream.receive(&mut vpn_buf) => {
|
||||
match msg {
|
||||
Ok(0) => {
|
||||
println!("Connection to router closed by peer.");
|
||||
return Ok(());
|
||||
}
|
||||
Ok(n) => {
|
||||
match RouterMessages::from_slice(&vpn_buf[..n]){
|
||||
Ok((message, data)) => {
|
||||
match message {
|
||||
RouterMessages::KeepAlive(timestamp) => {
|
||||
println!("Received keep-alive message from router with timestamp: {}, delta {} ms", timestamp, (Utc::now().timestamp_micros() - timestamp).abs() as f64 / 1000.0);
|
||||
}
|
||||
|
||||
_ => println!("Received message from router: {:?}", RouterMessages::from_slice(&vpn_buf[..n]))
|
||||
};
|
||||
|
||||
RouterMessages::RouteUpdate(mut updated_routes) => {
|
||||
updated_routes.retain(|_ , u| *u != self_uuid);
|
||||
println!("Received route update from router: {:?}", updated_routes);
|
||||
|
||||
let removed_routes = route_map.iter().filter(|(n,_)| !updated_routes.contains_key(n));
|
||||
println!("Removed routes: {:?}", removed_routes);
|
||||
for removed_route in removed_routes {
|
||||
println!("Route {:?} removed by router", removed_route);
|
||||
del_route(&tun_device, *removed_route.0).await?;
|
||||
}
|
||||
|
||||
let new_routes = updated_routes.iter().filter(|(n,_)| !route_map.contains_key(n));
|
||||
println!("New routes: {:?}", new_routes);
|
||||
for new_route in new_routes {
|
||||
println!("Route {} added by router", new_route.0);
|
||||
add_route(&tun_device, *new_route.0).await?;
|
||||
}
|
||||
|
||||
route_map = updated_routes;
|
||||
|
||||
}
|
||||
|
||||
RouterMessages::Data(packet_info) => {
|
||||
if let Some(data) = data {
|
||||
println!("Received data packet from router for client {}: size {} bytes", packet_info.src_uuid, packet_info.data_len);
|
||||
tun_device.send(&data).await?;
|
||||
}
|
||||
},
|
||||
|
||||
_ => println!("Received message from router: {:?}", message),
|
||||
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from router: {}", e);
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from router: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data = tun_device.recv(&mut tun_buf) => {
|
||||
match data {
|
||||
Ok(n) => {
|
||||
let packet = etherparse::Ipv4HeaderSlice::from_slice(&tun_buf[..n])?;
|
||||
let src = packet.source_addr();
|
||||
match ip_match_network(src, &config.local_routes).await {
|
||||
Some(net) => println!("Source IP {} matches local route {}", src, net),
|
||||
println!("Read packet from TUN interface: {} -> {}, size: {}", packet.source_addr(), packet.destination_addr(), n);
|
||||
let dst = packet.destination_addr();
|
||||
match ip_match_network(dst, &route_map).await {
|
||||
Some(uuid) => {
|
||||
println!("Packet destination {} matches route for client {}, sending to router", dst, uuid);
|
||||
let msg = VpnPacket{
|
||||
dst_uuid: uuid,
|
||||
src_uuid: self_uuid,
|
||||
data_len: n,
|
||||
};
|
||||
let data = tun_buf[..n].to_vec();
|
||||
// tx.write(&RouterMessages::Data(msg), data.as_vec()).await?;
|
||||
client_stream.send(RouterMessages::Data(msg), Some(data)).await?;
|
||||
},
|
||||
None => {},
|
||||
}
|
||||
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from TUN interface: {}", e);
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from TUN interface: {}", e)));
|
||||
}
|
||||
}
|
||||
@@ -109,33 +155,24 @@ pub async fn start(config: ClientCfg) -> Result<()> {
|
||||
// Ok(())
|
||||
}
|
||||
|
||||
pub async fn register_client(
|
||||
rx: &mut OwnedReadHalf,
|
||||
tx: &mut OwnedWriteHalf,
|
||||
config: &ClientCfg,
|
||||
buf: &mut [u8],
|
||||
) -> Result<()> {
|
||||
pub async fn register_client(client_stream: &ClientStream, config: &ClientCfg, buf: &mut [u8]) -> Result<Uuid> {
|
||||
let register_msg = RouterMessages::CliReg(CliRegMessages::Reg(config.clone()));
|
||||
|
||||
let mut client_registration_timeout =
|
||||
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
|
||||
tx.write_all(®ister_msg.to_bytes()).await?;
|
||||
|
||||
client_stream.send(register_msg, None).await?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = rx.read(buf) => {
|
||||
msg = client_stream.receive(buf) => {
|
||||
match msg {
|
||||
Ok(0) => {
|
||||
let msg = "Connection closed by router while waiting for registration confirmation.";
|
||||
eprintln!("{}", msg);
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
}
|
||||
Ok(n) => {
|
||||
let response = RouterMessages::from_slice(&buf[..n]);
|
||||
println!("Received registration response from router: {:?}", response);
|
||||
match response {
|
||||
Ok((message, _data)) => {
|
||||
match message {
|
||||
RouterMessages::CliReg(CliRegMessages::RegOk(uuid)) => {
|
||||
println!("Received registration response from router: {:?}", message);
|
||||
println!("Client registration successful with UUID: {}", uuid);
|
||||
return Ok(());
|
||||
return Ok(uuid);
|
||||
}
|
||||
RouterMessages::CliReg(CliRegMessages::RegFailed(err_msg)) => {
|
||||
eprintln!("Client registration failed: {}", err_msg);
|
||||
@@ -144,28 +181,26 @@ pub async fn register_client(
|
||||
_ => {
|
||||
let msg = "Unexpected message type received during client registration.";
|
||||
eprintln!("{}", msg);
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from router during client registration: {}", e);
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from router: {}", e)));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
_= client_registration_timeout.tick() => {
|
||||
let msg = "Client registration timed out waiting for confirmation from router.";
|
||||
eprintln!("{}", msg);
|
||||
eprintln!("Closing connection with Server");
|
||||
tx.shutdown().await?;
|
||||
client_stream.close().await?;
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
// Ok(())
|
||||
|
||||
@@ -84,7 +84,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
match commandline.mode {
|
||||
OpModes::Client(client) => client::start(client).await?,
|
||||
OpModes::Router { bind_address } => {
|
||||
router::start(bind_address).await;
|
||||
router::start(bind_address).await?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
use ipnet::Ipv4Net;
|
||||
use std::net::Ipv4Addr;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn ip_match_network(ip: Ipv4Addr, networks: &[Ipv4Net]) -> Option<Ipv4Net> {
|
||||
for net in networks {
|
||||
use crate::router::RoutesMap;
|
||||
|
||||
pub async fn ip_match_network(ip: Ipv4Addr, networks: &RoutesMap) -> Option<Uuid> {
|
||||
for (net, uuid) in networks {
|
||||
if net.contains(&ip) {
|
||||
return Some(*net);
|
||||
return Some(*uuid);
|
||||
}
|
||||
}
|
||||
None
|
||||
|
||||
259
src/router.rs
259
src/router.rs
@@ -1,16 +1,18 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result, bail};
|
||||
use chrono::Utc;
|
||||
|
||||
use etherparse::err::packet;
|
||||
use ipnet::Ipv4Net;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
|
||||
use serde_bytes::Bytes;
|
||||
use std::{collections::HashMap, mem, net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{
|
||||
TcpStream,
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
},
|
||||
sync::{Mutex, RwLock},
|
||||
sync::{Mutex, Notify, RwLock},
|
||||
time::Instant,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
@@ -19,36 +21,42 @@ use crate::client::ClientCfg;
|
||||
|
||||
pub static KEEP_ALIVE_INTERVAL: Duration = tokio::time::Duration::from_secs(30);
|
||||
pub static CLIENT_REGISTER_TIMEOUT: Duration = tokio::time::Duration::from_millis(100);
|
||||
pub static TCP_NODELAY: bool = true;
|
||||
pub const SERVER_PACKET_SIZE: usize = 1024 * 9;
|
||||
|
||||
pub trait ReceiverTrait {}
|
||||
|
||||
pub type RoutesMap = HashMap<Ipv4Net, Uuid>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum RouterMessages {
|
||||
CliReg(CliRegMessages),
|
||||
KeepAlive(i64),
|
||||
Data(VpnPacket),
|
||||
Quit(String),
|
||||
RouteUpdate(RoutesMap),
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
impl RouterMessages {
|
||||
pub async fn to_bytes(&self) -> Vec<u8> {
|
||||
serde_json::to_vec(self).expect("Unable to serialize RouterMessages")
|
||||
}
|
||||
|
||||
pub async fn from_slice(slice: &[u8]) -> Self {
|
||||
serde_json::from_slice(slice).unwrap_or(RouterMessages::Unknown(String::from_utf8_lossy(slice).to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VpnPacket {
|
||||
pub src_uuid: Uuid,
|
||||
pub dst_uuid: Uuid,
|
||||
pub payload: Vec<u8>,
|
||||
}
|
||||
impl RouterMessages {
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
serde_json::to_vec(self).expect("Unable to serialize RouteMessages")
|
||||
}
|
||||
pub fn from_slice(slice: &[u8]) -> Self {
|
||||
serde_json::from_slice(slice).unwrap_or(RouterMessages::Unknown(
|
||||
String::from_utf8(slice.to_vec()).unwrap_or_else(|b| format!("Invalid UTF-8: {:?}", b.as_bytes())),
|
||||
))
|
||||
}
|
||||
pub data_len: usize,
|
||||
}
|
||||
|
||||
impl RouterMessages {}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum CliRegMessages {
|
||||
Reg(ClientCfg),
|
||||
@@ -82,8 +90,12 @@ impl VPNClient {
|
||||
pub fn id(&self) -> Uuid {
|
||||
self.id
|
||||
}
|
||||
pub async fn send(&self, msg: RouterMessages) -> Result<()> {
|
||||
self.stream.send(msg).await
|
||||
pub async fn send(&self, msg: RouterMessages, data: Option<Vec<u8>>) -> Result<usize> {
|
||||
self.stream.send(msg, data).await
|
||||
}
|
||||
|
||||
pub async fn receive(&self, buf: &mut [u8]) -> Result<(RouterMessages, Option<Vec<u8>>)> {
|
||||
self.stream.receive(buf).await
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<()> {
|
||||
@@ -94,20 +106,101 @@ impl VPNClient {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientStream {
|
||||
tx: Arc<Mutex<OwnedWriteHalf>>,
|
||||
rx: Arc<Mutex<OwnedReadHalf>>,
|
||||
}
|
||||
impl ClientStream {
|
||||
pub fn new(tx: OwnedWriteHalf) -> Self {
|
||||
pub fn new(tx: TcpStream) -> Self {
|
||||
let (rx, tx) = tx.into_split();
|
||||
Self {
|
||||
// write can be shared
|
||||
tx: Arc::new(Mutex::new(tx)),
|
||||
// read is done only by one task
|
||||
rx: Arc::new(Mutex::new(rx)),
|
||||
}
|
||||
}
|
||||
pub async fn send(&self, msg: RouterMessages) -> Result<()> {
|
||||
let bytes = msg.to_bytes();
|
||||
pub async fn send(&self, msg: RouterMessages, payload: Option<Vec<u8>>) -> Result<usize> {
|
||||
// [u16 total_len][u16 msg_len][u16 payload_len][msg bytes][payload bytes]
|
||||
|
||||
let msg_bytes = serde_json::to_vec(&msg)?;
|
||||
let msg_len: u16 = msg_bytes.len().try_into().map_err(|e| {
|
||||
anyhow::anyhow!(format!(
|
||||
"Failed to convert message length to u16: {} (message size: {} bytes)",
|
||||
e,
|
||||
msg_bytes.len()
|
||||
))
|
||||
})?;
|
||||
|
||||
let payload_len: u16 = payload.as_ref().map_or(0, |p| p.len()).try_into().map_err(|e| {
|
||||
anyhow::anyhow!(format!(
|
||||
"Failed to convert payload length to u16: {} (payload size: {} bytes)",
|
||||
e,
|
||||
payload.as_ref().map_or(0, |p| p.len())
|
||||
))
|
||||
})?;
|
||||
|
||||
let total_len_usize = 6usize + msg_bytes.len() + payload_len as usize;
|
||||
|
||||
if total_len_usize > u16::MAX as usize {
|
||||
bail!("Total packet size {} exceeds maximum of {}", total_len_usize, u16::MAX);
|
||||
}
|
||||
|
||||
let total_len = total_len_usize as u16;
|
||||
|
||||
let mut packet = Vec::with_capacity(total_len_usize);
|
||||
packet.extend_from_slice(&total_len.to_be_bytes());
|
||||
packet.extend_from_slice(&msg_len.to_be_bytes());
|
||||
packet.extend_from_slice(&payload_len.to_be_bytes());
|
||||
packet.extend_from_slice(&msg_bytes);
|
||||
|
||||
if payload.is_some() {
|
||||
// unwrap is safe here because we already checked payload is some
|
||||
packet.extend_from_slice(&payload.unwrap());
|
||||
}
|
||||
println!(
|
||||
"Sending packet: total_len={}, msg_len={}, payload_len={}, message={:?}",
|
||||
total_len, msg_len, payload_len, msg
|
||||
);
|
||||
let mut tx = self.tx.lock().await;
|
||||
tx.write_all(&bytes).await?;
|
||||
Ok(())
|
||||
tx.write_all(&packet).await?;
|
||||
Ok(packet.len())
|
||||
}
|
||||
|
||||
pub async fn receive(&self, buf: &mut [u8]) -> Result<(RouterMessages, Option<Vec<u8>>)> {
|
||||
// [u16 total_len][u16 msg_len][u16 payload_len][msg bytes][payload bytes]
|
||||
|
||||
let mut rx = self.rx.lock().await;
|
||||
let router_message: RouterMessages;
|
||||
let payload: Option<Vec<u8>>;
|
||||
|
||||
match rx.read_exact(&mut buf[..6]).await {
|
||||
Ok(_) => {
|
||||
let total_len = u16::from_be_bytes(buf[0..2].try_into().unwrap()) as usize;
|
||||
let msg_len = u16::from_be_bytes(buf[2..4].try_into().unwrap()) as usize;
|
||||
let payload_len = u16::from_be_bytes(buf[4..6].try_into().unwrap()) as usize;
|
||||
|
||||
if total_len != 6 + msg_len + payload_len {
|
||||
bail!(
|
||||
"Invalid packet length: total_len={}, but expected {}",
|
||||
total_len,
|
||||
6 + msg_len + payload_len
|
||||
);
|
||||
}
|
||||
|
||||
rx.read_exact(&mut buf[..msg_len]).await?;
|
||||
router_message = RouterMessages::from_slice(&buf[..msg_len]).await;
|
||||
|
||||
payload = if payload_len > 0 {
|
||||
rx.read_exact(&mut buf[..payload_len]).await?;
|
||||
Some(buf[..payload_len].to_vec())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from client stream: {}", e);
|
||||
return Err(anyhow::anyhow!(format!("Error reading from client stream: {}", e)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok((router_message, payload))
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<()> {
|
||||
@@ -121,37 +214,69 @@ impl ClientStream {
|
||||
pub struct Router {
|
||||
clients: Arc<RwLock<HashMap<Uuid, VPNClient>>>,
|
||||
routing_table: Arc<RwLock<HashMap<Ipv4Net, Uuid>>>,
|
||||
notify: Arc<Notify>,
|
||||
}
|
||||
impl Router {
|
||||
pub async fn register_client(&self, routing_table: &[Ipv4Net], vpn_client: VPNClient) -> Result<()> {
|
||||
let id = Uuid::new_v4();
|
||||
for net in routing_table {
|
||||
self.routing_table.write().await.insert(*net, id);
|
||||
pub async fn register_client(&self, client: ClientCfg, vpn_client: VPNClient) -> Result<()> {
|
||||
let id = vpn_client.id();
|
||||
let mut client_local_route = client.local_routes;
|
||||
client_local_route.push(client.interface_ip);
|
||||
self.routing_table.write().await.insert(client.interface_ip, id);
|
||||
for net in client_local_route {
|
||||
self.routing_table.write().await.insert(net, id);
|
||||
}
|
||||
|
||||
self.clients.write().await.insert(id, vpn_client);
|
||||
|
||||
self.notify.notify_waiters();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_client(&self, client_id: Uuid) -> Result<()> {
|
||||
self.clients.write().await.remove(&client_id);
|
||||
self.routing_table.write().await.retain(|_, &mut id| id != client_id);
|
||||
self.notify.notify_waiters();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_routing_table(&self) -> HashMap<Ipv4Net, Uuid> {
|
||||
self.routing_table.read().await.clone()
|
||||
}
|
||||
|
||||
pub async fn get_client(&self, uuid: Uuid) -> Option<VPNClient> {
|
||||
self.clients.read().await.get(&uuid).cloned()
|
||||
}
|
||||
|
||||
pub async fn notify_changes(&self) {
|
||||
self.notify.notified().await;
|
||||
for client in self.clients.read().await.values() {
|
||||
println!("Notifying client {} of routing table change", client.id());
|
||||
if let Err(e) = client
|
||||
.send(RouterMessages::RouteUpdate(self.get_routing_table().await), None)
|
||||
.await
|
||||
{
|
||||
eprintln!("Error notifying client {}: {}", client.id(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start(bind_address: SocketAddr) {
|
||||
pub async fn start(bind_address: SocketAddr) -> Result<()> {
|
||||
println!("Starting router on {}...", bind_address);
|
||||
|
||||
let router = Router::default();
|
||||
|
||||
let socket = tokio::net::TcpListener::bind(bind_address).await.unwrap();
|
||||
let socket = tokio::net::TcpListener::bind(bind_address).await?;
|
||||
println!("Router is listening on {}...", bind_address);
|
||||
|
||||
loop {
|
||||
let router = router.clone();
|
||||
match socket.accept().await {
|
||||
Ok((tcp_stream, addr)) => {
|
||||
println!("Accepted connection from {}", addr);
|
||||
//Clone the router for the new task
|
||||
let router = router.clone();
|
||||
tokio::spawn(async move {
|
||||
println!("Handling connection from {}", addr);
|
||||
match handle_client(router.clone(), tcp_stream).await {
|
||||
match handle_client(router, tcp_stream).await {
|
||||
Ok(_) => println!("Finished handling connection from {}", addr),
|
||||
Err(e) => eprintln!("Error handling connection from {}: {}", addr, e),
|
||||
}
|
||||
@@ -162,54 +287,72 @@ pub async fn start(bind_address: SocketAddr) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
||||
let (mut rx, tx) = stream.into_split();
|
||||
let vpn_client = VPNClient::new(Uuid::new_v4(), ClientStream::new(tx));
|
||||
stream.set_nodelay(TCP_NODELAY)?;
|
||||
|
||||
let vpn_client = VPNClient::new(Uuid::new_v4(), ClientStream::new(stream));
|
||||
|
||||
let mut keep_alive_tick = tokio::time::interval_at(Instant::now() + KEEP_ALIVE_INTERVAL, KEEP_ALIVE_INTERVAL);
|
||||
let mut buf = vec![0u8; SERVER_PACKET_SIZE];
|
||||
|
||||
match client_init(&mut rx, &vpn_client, &mut buf).await {
|
||||
match client_init(&vpn_client, &mut buf).await {
|
||||
Ok(client) => {
|
||||
println!(
|
||||
"Client {} registered with routing table: {:?}",
|
||||
"Client {} registered with routing table: {:?} and local endpoint {:?}",
|
||||
vpn_client.id(),
|
||||
client.local_routes
|
||||
client.local_routes,
|
||||
client.interface_ip
|
||||
);
|
||||
println!("Registering client {} with router...", vpn_client.id());
|
||||
router.register_client(&client.local_routes, vpn_client.clone()).await?;
|
||||
router.register_client(client, vpn_client.clone()).await?;
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to initialize client {}: {}", vpn_client.id(), e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = rx.read(&mut buf) => {
|
||||
msg = vpn_client.receive(&mut buf) => {
|
||||
match msg {
|
||||
Ok(0) => {
|
||||
println!("Client {} closed the connection", vpn_client.id());
|
||||
return Ok(());
|
||||
Ok((message, data)) => {
|
||||
println!("Received message from client {}: {:?}", vpn_client.id(), message);
|
||||
match message {
|
||||
RouterMessages::Data(packet_info) => {
|
||||
if let Some(dst_client) = router.get_client(packet_info.dst_uuid).await {
|
||||
println!("Forwarding packet from client {} to client {}", vpn_client.id(), dst_client.id());
|
||||
if let Err(e) = dst_client.send(RouterMessages::Data(packet_info), data).await {
|
||||
eprintln!("Error forwarding packet to client {}: {}", dst_client.id(), e);
|
||||
}
|
||||
} else {
|
||||
eprintln!("Destination client {} not found for packet from client {}", packet_info.dst_uuid, vpn_client.id());
|
||||
}
|
||||
}
|
||||
_ => println!("Received message from client {}: {:?}", vpn_client.id(), message)
|
||||
}
|
||||
Ok(n) => {
|
||||
let msg = RouterMessages::from_slice(&buf[..n]);
|
||||
println!("Received message from client {}: {:?}", vpn_client.id(), msg);
|
||||
// Here you would implement the logic to handle messages from the client, such as routing data to other clients based on the routing table
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from client {}: {}", vpn_client.id(), e);
|
||||
println!("Removing client {} from router...", vpn_client.id());
|
||||
router.remove_client(vpn_client.id()).await?;
|
||||
vpn_client.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from client: {}", e)));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
_= router.notify_changes() => {
|
||||
println!("Routing table updated. Current routing table:");
|
||||
}
|
||||
|
||||
_= keep_alive_tick.tick() => {
|
||||
// Send keep-alive message to the client
|
||||
vpn_client.send(RouterMessages::KeepAlive(Utc::now().timestamp_micros())).await?;
|
||||
vpn_client.send(RouterMessages::KeepAlive(Utc::now().timestamp_micros()), None).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -217,25 +360,19 @@ pub async fn handle_client(router: Router, stream: TcpStream) -> Result<()> {
|
||||
//Ok(())
|
||||
}
|
||||
|
||||
pub async fn client_init(rx: &mut OwnedReadHalf, vpn_client: &VPNClient, buf: &mut [u8]) -> Result<ClientCfg> {
|
||||
pub async fn client_init(vpn_client: &VPNClient, buf: &mut [u8]) -> Result<ClientCfg> {
|
||||
let mut client_registration_timeout =
|
||||
tokio::time::interval_at(Instant::now() + CLIENT_REGISTER_TIMEOUT, CLIENT_REGISTER_TIMEOUT);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = rx.read(buf) => {
|
||||
match msg {
|
||||
Ok(0) => {
|
||||
println!("Client {} closed the connection during registration", vpn_client.id());
|
||||
return Err(anyhow::anyhow!("Client closed the connection during registration"));
|
||||
}
|
||||
Ok(n) => {
|
||||
let msg = RouterMessages::from_slice(&buf[..n]);
|
||||
msg = vpn_client.receive(buf) => {
|
||||
match msg {
|
||||
Ok((router_msg, _)) => {
|
||||
match router_msg {
|
||||
RouterMessages::CliReg(CliRegMessages::Reg(client))=> {
|
||||
println!("Received client registration with routing table: {:?}", client.local_routes);
|
||||
let uuid = Uuid::new_v4();
|
||||
vpn_client.send(RouterMessages::CliReg(CliRegMessages::RegOk(uuid))).await?;
|
||||
vpn_client.send(RouterMessages::CliReg(CliRegMessages::RegOk(vpn_client.id())),None).await?;
|
||||
return Ok(client);
|
||||
}
|
||||
router_msg => {
|
||||
@@ -250,17 +387,17 @@ pub async fn client_init(rx: &mut OwnedReadHalf, vpn_client: &VPNClient, buf: &m
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from client {} during registration: {}", vpn_client.id(), e);
|
||||
vpn_client.close().await?;
|
||||
return Err(anyhow::anyhow!(format!("Error reading from client during registration: {}", e)));
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
_ = client_registration_timeout.tick() => {
|
||||
let msg = format!("Client registration timed out after {}ms", (CLIENT_REGISTER_TIMEOUT.as_millis()));
|
||||
vpn_client.send(RouterMessages::Quit(msg.clone())).await?;
|
||||
vpn_client.send(RouterMessages::Quit(msg.clone()),None).await?;
|
||||
vpn_client.close().await?;
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
|
||||
|
||||
46
src/tun.rs
46
src/tun.rs
@@ -1,4 +1,6 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use ipnet::Ipv4Net;
|
||||
use tokio::process::Command;
|
||||
use tun_rs::{AsyncDevice, DeviceBuilder};
|
||||
|
||||
use crate::client::ClientCfg;
|
||||
@@ -27,3 +29,45 @@ pub async fn inti_tun_interface(config: &ClientCfg) -> Result<AsyncDevice> {
|
||||
|
||||
Ok(device)
|
||||
}
|
||||
|
||||
pub async fn add_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
|
||||
let dev = tun_device.name().context("failed to get tun device name")?;
|
||||
|
||||
println!("Adding route {} dev {}", route, dev);
|
||||
|
||||
let output = Command::new("ip")
|
||||
.args(["route", "replace", &route.to_string(), "dev", &dev])
|
||||
.output()
|
||||
.await
|
||||
.context("failed to execute ip route")?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let msg = format!("Failed to add route {} via device {}: {:#}", route, dev, stderr);
|
||||
eprintln!("{}", msg);
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn del_route(tun_device: &AsyncDevice, route: Ipv4Net) -> Result<()> {
|
||||
let dev = tun_device.name().context("failed to get tun device name")?;
|
||||
|
||||
println!("Deleting route {} dev {}", route, dev);
|
||||
|
||||
let output = Command::new("ip")
|
||||
.args(["route", "del", &route.to_string(), "dev", &dev])
|
||||
.output()
|
||||
.await
|
||||
.context("failed to execute ip route")?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let msg = format!("Failed to delete route {} via device {}: {:#}", route, dev, stderr);
|
||||
eprintln!("{}", msg);
|
||||
return Err(anyhow::anyhow!(msg));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user