Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/tuic-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ quinn-congestions = { workspace = true }

# Tokio/Async
crossbeam-utils = { version = "0.8", default-features = false, features = ["std"] }
tokio = { version = "1", default-features = false, features = ["io-util", "macros", "net", "parking_lot", "rt-multi-thread", "time"] }
tokio-util = { version = "0.7", default-features = false, features = ["compat"] }
tokio = { version = "1", default-features = false, features = ["io-util", "macros", "net", "parking_lot", "rt-multi-thread", "signal", "time"] }
tokio-util = { version = "0.7", default-features = false, features = ["compat", "rt"] }

# TLS
rustls = { version = "0.23", default-features = false }
Expand Down
49 changes: 37 additions & 12 deletions crates/tuic-client/src/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ use std::{
use bytes::Bytes;
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use tokio::net::{TcpListener, UdpSocket};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, info, warn};
use wind_core::{
AbstractOutbound,
AbstractOutbound, AppContext,
types::TargetAddr,
udp::{UdpPacket, UdpStream},
};
Expand All @@ -36,16 +37,19 @@ fn next_assoc_id() -> u16 {
// hot path. The local `sessions: HashMap<SocketAddr, UdpForwardSession>`
// in `run_udp_forwarder` is the only routing table in use.

pub async fn start(tcp: Vec<TcpForward>, udp: Vec<UdpForward>) {
/// Spawn the configured TCP/UDP forwarders into `ctx.tasks`, each driven by a
/// child of `ctx.token` so shutdown stops the accept/recv loops and aborts
/// in-flight per-connection tasks.
pub async fn start(tcp: Vec<TcpForward>, udp: Vec<UdpForward>, ctx: &Arc<AppContext>) {
for entry in tcp {
tokio::spawn(run_tcp_forwarder(entry));
ctx.tasks.spawn(run_tcp_forwarder(entry, ctx.token.child_token()));
}
for entry in udp {
tokio::spawn(run_udp_forwarder(entry));
ctx.tasks.spawn(run_udp_forwarder(entry, ctx.token.child_token()));
}
}

async fn run_tcp_forwarder(entry: TcpForward) {
async fn run_tcp_forwarder(entry: TcpForward, cancel: CancellationToken) {
let listener = match create_tcp_listener(entry.listen) {
Ok(l) => l,
Err(err) => {
Expand All @@ -60,13 +64,28 @@ async fn run_tcp_forwarder(entry: TcpForward) {
remote = entry.remote
);
loop {
match listener.accept().await {
Ok((inbound, peer)) => {
let remote = entry.remote.clone();
let span = tracing::info_span!("forward_tcp", peer = %peer);
tokio::spawn(handle_tcp_conn(inbound, remote).instrument(span));
tokio::select! {
_ = cancel.cancelled() => {
info!("[forward-tcp] cancellation received, shutting down");
break;
}
res = listener.accept() => match res {
Ok((inbound, peer)) => {
let remote = entry.remote.clone();
let span = tracing::info_span!("forward_tcp", peer = %peer);
let conn_cancel = cancel.child_token();
tokio::spawn(
async move {
tokio::select! {
_ = conn_cancel.cancelled() => {}
_ = handle_tcp_conn(inbound, remote) => {}
}
}
.instrument(span),
);
}
Err(err) => warn!("[forward-tcp] accept error: {err}"),
}
Err(err) => warn!("[forward-tcp] accept error: {err}"),
}
}
}
Expand Down Expand Up @@ -121,7 +140,7 @@ struct UdpForwardSession {
last_seen: std::time::Instant,
}

async fn run_udp_forwarder(entry: UdpForward) {
async fn run_udp_forwarder(entry: UdpForward, cancel: CancellationToken) {
let socket = match UdpSocket::bind(entry.listen).await {
Ok(s) => s,
Err(err) => {
Expand Down Expand Up @@ -153,6 +172,12 @@ async fn run_udp_forwarder(entry: UdpForward) {

loop {
tokio::select! {
_ = cancel.cancelled() => {
info!("[forward-udp] cancellation received, shutting down");
// Dropping `sessions` drops every `tx_to_out`, which closes the
// per-session relay tasks' inbound channels so they exit cleanly.
break;
}
recv = socket.recv_from(&mut buf) => match recv {
Ok((n, src_addr)) => {
let pkt = Bytes::copy_from_slice(&buf[..n]);
Expand Down
40 changes: 33 additions & 7 deletions crates/tuic-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use std::sync::Arc;

use tokio_util::sync::CancellationToken;
use wind_core::AppContext;

pub mod config;
Expand All @@ -14,16 +15,35 @@ pub mod wind_adapter;

pub use config::Config;

/// Run the TUIC client with the given configuration (using wind-tuic)
/// Run the TUIC client with the given configuration (using wind-tuic).
///
/// Constructs its own [`CancellationToken`] internally; callers that want to
/// drive a graceful shutdown from outside should use [`run_with_cancel`].
pub async fn run(cfg: Config) -> eyre::Result<()> {
// Initialize wind-tuic connection
let ctx = Arc::new(AppContext::default());
wind_adapter::create_connection(ctx, cfg.relay).await?;
run_with_cancel(cfg, CancellationToken::new()).await
}

/// Run the TUIC client with a caller-owned cancel token.
///
/// Cancelling `cancel` stops the SOCKS5 accept loop and the TCP/UDP
/// forwarders, closes the TUIC connection (so the server sees the client go
/// away immediately instead of waiting out its idle timeout), and waits for
/// tracked background tasks to drain. Pair with `tokio::select!` on `ctrl_c()`
/// so signal-triggered shutdown is graceful instead of relying on runtime drop.
pub async fn run_with_cancel(cfg: Config, cancel: CancellationToken) -> eyre::Result<()> {
// The context token is the caller's token, so the outbound's heartbeat poll
// task (which closes the QUIC connection on cancellation) and every UDP
// session task wind down from the same `cancel()`.
let ctx = Arc::new(AppContext {
tasks: tokio_util::task::TaskTracker::new(),
token: cancel.clone(),
});
wind_adapter::create_connection(ctx.clone(), cfg.relay).await?;

tracing::info!("TUIC client initialized with wind-tuic backend");

// Start forwarders
forward::start(cfg.local.tcp_forward.clone(), cfg.local.udp_forward.clone()).await;
// Start forwarders (tracked in ctx.tasks, cancelled via ctx.token).
forward::start(cfg.local.tcp_forward.clone(), cfg.local.udp_forward.clone(), &ctx).await;

// Start SOCKS5 server
match socks5::Server::set_config(cfg.local) {
Expand All @@ -33,6 +53,12 @@ pub async fn run(cfg: Config) -> eyre::Result<()> {
}
}

socks5::Server::start().await;
socks5::Server::start(cancel.clone()).await;

// `start` only returns once cancelled; drain the tracked background tasks
// (heartbeat poll, forwarder loops, UDP sessions) before returning so the
// QUIC close frames flush while the runtime is still alive.
ctx.tasks.close();
ctx.tasks.wait().await;
Ok(())
}
44 changes: 42 additions & 2 deletions crates/tuic-client/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::{process, str::FromStr};
use std::{process, str::FromStr, time::Duration};

use chrono::{Offset, TimeZone};
use clap::Parser;
#[cfg(feature = "jemallocator")]
use tikv_jemallocator::Jemalloc;
use tokio_util::sync::CancellationToken;
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use tuic_client::config::{Cli, Config, EnvState};
Expand Down Expand Up @@ -51,5 +52,44 @@ async fn main() -> eyre::Result<()> {
)),
)
.try_init()?;
tuic_client::run(cfg).await
// Own the cancel token here so the ctrl-c branch can trigger a graceful
// shutdown: stop the SOCKS5/forwarder accept loops, close the TUIC
// connection (the server learns we left instead of waiting out its idle
// timeout), and drain background tasks — same structure as tuic-server.
let cancel = CancellationToken::new();
let mut client = tokio::spawn(tuic_client::run_with_cancel(cfg, cancel.clone()));

tokio::select! {
res = &mut client => {
match res {
Ok(Ok(())) => {}
Ok(Err(err)) => {
tracing::error!("Client exited with error: {err}");
return Err(err);
}
Err(join_err) => {
tracing::error!("Client task panicked or was cancelled: {join_err}");
return Err(eyre::eyre!("Client task panicked or was cancelled: {join_err}"));
}
}
}
res = tokio::signal::ctrl_c() => {
if let Err(err) = res {
tracing::error!("Failed to listen for Ctrl-C: {err}");
return Err(eyre::eyre!("Failed to listen for Ctrl-C: {err}"));
}
tracing::info!("Received Ctrl-C, shutting down.");
cancel.cancel();

// Give in-flight sessions up to 10 seconds to drain before dropping
// out of main and letting runtime teardown abort the rest.
match tokio::time::timeout(Duration::from_secs(10), client).await {
Ok(Ok(Ok(()))) => {}
Ok(Ok(Err(err))) => tracing::warn!("Client drained with error: {err}"),
Ok(Err(join_err)) => tracing::warn!("Client task drain join error: {join_err}"),
Err(_) => tracing::warn!("Client did not drain within 10s of Ctrl-C; aborting outstanding tasks"),
}
}
}
Ok(())
}
36 changes: 30 additions & 6 deletions crates/tuic-client/src/socks5/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use socks5_server::{
auth::{NoAuth, Password},
};
use tokio::{net::TcpListener, sync::RwLock as AsyncRwLock};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tracing::{Instrument, debug, info, warn};

use crate::{config::Local, error::Error};
Expand Down Expand Up @@ -111,20 +112,43 @@ impl Server {
})
}

pub async fn start() {
/// Accept SOCKS5 connections until `cancel` fires, then wait for in-flight
/// session tasks to wind down (each gets a child token, so cancellation
/// aborts handshakes and relays promptly).
pub async fn start(cancel: CancellationToken) {
let server = SERVER.get().unwrap();

warn!("[socks5] server started, listening on {}", server.inner.local_addr().unwrap());

let conn_tasks = TaskTracker::new();
loop {
match server.inner.accept().await {
Ok((conn, addr)) => {
let span = tracing::info_span!("socks5", peer = %addr);
tokio::spawn(Self::handle_socks5_conn(server, conn).instrument(span));
tokio::select! {
_ = cancel.cancelled() => {
info!("[socks5] cancellation received, shutting down");
break;
}
res = server.inner.accept() => match res {
Ok((conn, addr)) => {
let span = tracing::info_span!("socks5", peer = %addr);
let conn_cancel = cancel.child_token();
conn_tasks.spawn(
async move {
tokio::select! {
_ = conn_cancel.cancelled() => {
debug!("session aborted by shutdown");
}
_ = Self::handle_socks5_conn(server, conn) => {}
}
}
.instrument(span),
);
}
Err(err) => warn!("[socks5] failed to establish connection: {err}"),
}
Err(err) => warn!("[socks5] failed to establish connection: {err}"),
}
}
conn_tasks.close();
conn_tasks.wait().await;
}

async fn handle_socks5_conn(server: &Server, conn: socks5_server::IncomingConnection) {
Expand Down
36 changes: 26 additions & 10 deletions crates/tuic-server/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use rustls::{
};
use sha2::{Digest, Sha256};
use tokio::fs;
use tokio_util::sync::CancellationToken;
use tracing::warn;

#[derive(Debug)]
Expand All @@ -24,7 +25,7 @@ pub struct CertResolver {
hash: ArcSwap<[u8; 32]>,
}
impl CertResolver {
pub async fn new(cert_path: &Path, key_path: &Path, interval: Duration) -> Result<Arc<Self>> {
pub async fn new(cert_path: &Path, key_path: &Path, interval: Duration, cancel: CancellationToken) -> Result<Arc<Self>> {
let cert_key = load_cert_key(cert_path, key_path).await?;
let hash = Self::calc_hash(cert_path, key_path).await?;
let resolver = Arc::new(Self {
Expand All @@ -33,20 +34,24 @@ impl CertResolver {
cert_key: ArcSwap::new(cert_key),
hash: ArcSwap::new(Arc::new(hash)),
});
// Start file watcher in background
// Start file watcher in background; exits on `cancel` so the task does
// not outlive the server when used as a library.
let resolver_clone = resolver.clone();
tokio::spawn(async move {
if let Err(e) = resolver_clone.start_watch(interval).await {
if let Err(e) = resolver_clone.start_watch(interval, cancel).await {
warn!("Certificate watcher exited with error: {e}");
}
});
Ok(resolver)
}

async fn start_watch(&self, interval: Duration) -> Result<()> {
async fn start_watch(&self, interval: Duration, cancel: CancellationToken) -> Result<()> {
let mut interval = tokio::time::interval(interval);
loop {
interval.tick().await;
tokio::select! {
_ = cancel.cancelled() => return Ok(()),
_ = interval.tick() => {}
}

// Treat I/O errors here as transient (the cert file may be in the
// middle of an ACME-driven `rename`, the directory may be missing
Expand Down Expand Up @@ -322,9 +327,14 @@ mod tests {
let (cert_der, key_der) = generate_test_cert_der()?;
let (cert_file, key_file) = create_temp_cert_file(&cert_der, &key_der).await;

let resolver = CertResolver::new(cert_file.path(), key_file.path(), Duration::from_secs(10))
.await
.unwrap();
let resolver = CertResolver::new(
cert_file.path(),
key_file.path(),
Duration::from_secs(10),
CancellationToken::new(),
)
.await
.unwrap();

let certified_key = resolver.cert_key.load_full();
assert!(!certified_key.cert.is_empty());
Expand All @@ -341,7 +351,7 @@ mod tests {
tokio::fs::write(&cert_path, &cert_pem.as_bytes()).await.unwrap();
tokio::fs::write(&key_path, &key_pem.as_bytes()).await.unwrap();

let resolver = CertResolver::new(&cert_path, &key_path, Duration::from_micros(100))
let resolver = CertResolver::new(&cert_path, &key_path, Duration::from_micros(100), CancellationToken::new())
.await
.unwrap();

Expand Down Expand Up @@ -372,7 +382,13 @@ mod tests {
let load_result = load_cert_key(cert_file.path(), key_file.path()).await;
assert!(load_result.is_err());

let resolver_result = CertResolver::new(cert_file.path(), key_file.path(), Duration::from_secs(10)).await;
let resolver_result = CertResolver::new(
cert_file.path(),
key_file.path(),
Duration::from_secs(10),
CancellationToken::new(),
)
.await;
assert!(resolver_result.is_err());
}

Expand Down
6 changes: 5 additions & 1 deletion crates/tuic-server/src/wind_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,15 @@ async fn create_quiche_inbound(ctx: &Arc<TuicAppContext>) -> eyre::Result<Tuiche
enable_0rtt: quiche.zero_rtt,
};

// Wire the binary's cancel token into the inbound so ctrl-c actually stops
// the accept loop and closes live connections (mirrors the quinn backend,
// which derives its token from `ctx.cancel` via `wind_ctx`).
let mut builder = TuicheInboundBuilder::new()
.listen_addr(cfg.server)
.connection_opts(opts)
.certificate_path(cert.clone())
.private_key_path(key.clone());
.private_key_path(key.clone())
.cancel_token(ctx.cancel.child_token());
for (uuid, pwd) in &cfg.users {
builder = builder.user(*uuid, pwd.clone());
}
Expand Down
Loading
Loading