Skip to content
54 changes: 39 additions & 15 deletions src/rs/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -697,15 +697,15 @@ macro_rules! define_quic_handle_ctx_fn {
// This is used by Connection and Stream, but not Listener.
#[allow(dead_code)]
fn consume_callback_ctx(&self) {
let res = unsafe { self.get_callback_ctx() };
if res.is_some() {
unsafe { self.set_context(std::ptr::null_mut()) };
if let Some(ctx) = unsafe { self.take_callback_ctx() } {
std::mem::drop(ctx);
}
}

/// # Safety
/// Caller is responsible for clearing the context if needed.
/// This does not clear the ctx.
#[allow(dead_code)]
unsafe fn get_callback_ctx(&self) -> Option<Box<Box<$callback_type>>> {
let ctx = self.get_context();
if !ctx.is_null() {
Expand All @@ -714,6 +714,18 @@ macro_rules! define_quic_handle_ctx_fn {
None
}
}

/// # Safety
/// Removes the callback context from the handle and returns it.
unsafe fn take_callback_ctx(&self) -> Option<Box<Box<$callback_type>>> {
let ctx = self.get_context();
if ctx.is_null() {
None
} else {
unsafe { self.set_context(std::ptr::null_mut()) };
Some(unsafe { Box::from_raw(ctx as *mut Box<$callback_type>) })
}
}
}
};
}
Expand Down Expand Up @@ -861,17 +873,29 @@ extern "C" fn raw_conn_callback(
context: *mut c_void,
event: *mut ffi::QUIC_CONNECTION_EVENT,
) -> QUIC_STATUS {
let conn = unsafe { ConnectionRef::from_raw(connection) };
let f = unsafe {
(context as *mut Box<ConnectionCallback>)
.as_mut() // allow mutation
.expect("cannot get ConnectionCallback from ctx")
let event_ref = unsafe { event.as_ref().expect("cannot get connection event") };
let cleanup_ctx =
event_ref.Type == ffi::QUIC_CONNECTION_EVENT_TYPE_QUIC_CONNECTION_EVENT_SHUTDOWN_COMPLETE;

let status = match unsafe { (context as *mut Box<ConnectionCallback>).as_mut() } {
Some(f) => {
let event = ConnectionEvent::from(event_ref);
let conn = unsafe { ConnectionRef::from_raw(connection) };
match f(conn, event) {
Ok(_) => StatusCode::QUIC_STATUS_SUCCESS.into(),
Err(e) => e.0,
}
}
// Context already cleaned (e.g. after ShutdownComplete). Nothing to do.
None => StatusCode::QUIC_STATUS_SUCCESS.into(),
Copy link
Collaborator

@guhetier guhetier Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this scenario is possible assuming the logic is sound. ShutdownComplete is the last event delivered to an application by MsQuic (see QUIC_CONNECTION_EVENT.md), and the rust wrapper own the context so it should never be freed before ShutdownComplete.

We can expect it to be present, I think.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! I found a case where the context can be None:
When Connection::drop() is called from outside the callback (e.g., after transferring ownership via a channel), close_inner() does:

  1. take_callback_ctx() → context becomes null
  2. ConnectionClose() → MsQuic triggers ShutdownComplete synchronously before returning
  3. The callback is invoked with a null context

This is demonstrated in the test where we drop(server_conn) from the main thread. Should we:

  1. Keep handling None to support both patterns (close from callback vs close from outside)?
  2. Or change close_inner() to drop the context after the handle is fully closed?

I lean towards option 2, dropping the context after ConnectionClose() returns would ensure the callback always has a valid context, and expect() would be safe. This also aligns better with your expectation that ShutdownComplete always has a context. What do you think?

};
let event = ConnectionEvent::from(unsafe { event.as_ref().unwrap() });
match f(conn, event) {
Ok(_) => StatusCode::QUIC_STATUS_SUCCESS.into(),
Err(e) => e.0,

if cleanup_ctx {
let conn = unsafe { ConnectionRef::from_raw(connection) };
conn.consume_callback_ctx();
}

status
}

impl Connection {
Expand Down Expand Up @@ -920,7 +944,7 @@ impl Connection {
fn close_inner(&self) {
if !self.handle.is_null() {
// get the context and drop it after handle close.
let ctx = unsafe { self.get_callback_ctx() };
let ctx = unsafe { self.take_callback_ctx() };
unsafe {
Api::ffi_ref().ConnectionClose.unwrap()(self.handle);
}
Expand Down Expand Up @@ -1102,7 +1126,7 @@ impl Listener {
fn close_inner(&self) {
if !self.handle.is_null() {
// consume the context and drop it after handle close.
let ctx = unsafe { self.get_callback_ctx() };
let ctx = unsafe { self.take_callback_ctx() };
unsafe {
Api::ffi_ref().ListenerClose.unwrap()(self.handle);
}
Expand Down Expand Up @@ -1179,7 +1203,7 @@ impl Stream {
pub fn close_inner(&self) {
if !self.handle.is_null() {
// consume the context and drop it after handle close.
let ctx = unsafe { self.get_callback_ctx() };
let ctx = unsafe { self.take_callback_ctx() };
unsafe {
Api::ffi_ref().StreamClose.unwrap()(self.handle);
}
Expand Down
129 changes: 126 additions & 3 deletions src/rs/server_client_test.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use std::{
ffi::c_void,
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
sync::{
atomic::{AtomicUsize, Ordering},
mpsc, Arc,
},
time::Duration,
};

use crate::{
config::{Credential, CredentialFlags},
Addr, BufferRef, Configuration, Connection, ConnectionEvent, ConnectionRef, CredentialConfig,
Listener, Registration, RegistrationConfig, Settings, Status, Stream, StreamEvent, StreamRef,
Addr, BufferRef, Configuration, Connection, ConnectionEvent, ConnectionRef,
ConnectionShutdownFlags, CredentialConfig, Listener, Registration, RegistrationConfig,
Settings, Status, Stream, StreamEvent, StreamRef,
};

fn buffers_to_string(buffers: &[BufferRef]) -> String {
Expand Down Expand Up @@ -286,3 +291,121 @@ fn test_server_client() {
}
l.stop();
}

#[test]
fn connection_ref_callback_cleanup() {
struct DropGuard {
counter: Arc<AtomicUsize>,
}

impl Drop for DropGuard {
fn drop(&mut self) {
self.counter.fetch_add(1, Ordering::SeqCst);
}
}

let cred = get_test_cred();

let reg = Registration::new(&RegistrationConfig::default()).unwrap();
let alpn = [BufferRef::from("qcleanup")];
let settings = Settings::new()
.set_ServerResumptionLevel(crate::ServerResumptionLevel::ResumeAndZerortt)
.set_PeerBidiStreamCount(1);

let server_config = Configuration::open(&reg, &alpn, Some(&settings)).unwrap();

let cred_config = CredentialConfig::new()
.set_credential_flags(CredentialFlags::NO_CERTIFICATE_VALIDATION)
.set_credential(cred);
server_config.load_credential(&cred_config).unwrap();
let server_config = Arc::new(server_config);

let drop_counter = Arc::new(AtomicUsize::new(0));
let (server_handle_tx, server_handle_rx) = mpsc::channel::<crate::ffi::HQUIC>();

let listener = Listener::open(&reg, {
let server_config = server_config.clone();
let drop_counter = drop_counter.clone();
let server_handle_tx = server_handle_tx.clone();
move |_, ev| {
if let crate::ListenerEvent::NewConnection { connection, .. } = ev {
let callback_guard = DropGuard {
counter: drop_counter.clone(),
};
let server_handle_tx = server_handle_tx.clone();
connection.set_callback_handler(move |conn: ConnectionRef, ev: ConnectionEvent| {
let _guard_ref = &callback_guard;
match ev {
ConnectionEvent::Connected { .. } => {}
ConnectionEvent::ShutdownComplete { .. } => {
let _ = server_handle_tx.send(unsafe { conn.as_raw() });
}
_ => {}
};
Ok(())
});
connection.set_configuration(&server_config)?;
}
Ok(())
}
})
.unwrap();

let local_address = Addr::from(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0));
listener.start(&alpn, Some(&local_address)).unwrap();
let port = listener
.get_local_addr()
.unwrap()
.as_socket()
.unwrap()
.port();

let client_settings = Settings::new().set_IdleTimeoutMs(500);
let client_config = Configuration::open(&reg, &alpn, Some(&client_settings)).unwrap();
let cred_config = CredentialConfig::new_client()
.set_credential_flags(CredentialFlags::NO_CERTIFICATE_VALIDATION);
client_config.load_credential(&cred_config).unwrap();

let (client_done_tx, client_done_rx) = mpsc::channel();
let client_conn = Connection::open(&reg, {
let client_done_tx = client_done_tx.clone();
move |conn: ConnectionRef, ev: ConnectionEvent| {
match ev {
ConnectionEvent::Connected { .. } => {
conn.shutdown(ConnectionShutdownFlags::NONE, 0);
}
ConnectionEvent::ShutdownComplete { .. } => {
let _ = client_done_tx.send(());
}
_ => {}
}
Ok(())
}
})
.unwrap();

client_conn
.start(&client_config, "127.0.0.1", port)
.unwrap();

let raw_conn = server_handle_rx
.recv_timeout(Duration::from_secs(5))
.expect("Server did not receive shutdown event");
client_done_rx
.recv_timeout(Duration::from_secs(5))
.expect("Client did not complete shutdown");

let mut retries = 50;
while drop_counter.load(Ordering::SeqCst) == 0 && retries > 0 {
std::thread::sleep(Duration::from_millis(10));
retries -= 1;
}
assert_eq!(
drop_counter.load(Ordering::SeqCst),
1,
"ConnectionRef callback context was not cleaned up"
);

unsafe { Connection::from_raw(raw_conn) };
listener.stop();
}
Loading