aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordatdenkikniet <jcdra1@gmail.com>2025-03-16 12:46:23 +0100
committerEmil Fresk <emil.fresk@gmail.com>2025-03-24 07:36:23 +0000
commitb5db43550185c2acd62d3f27bc89f2f24b4fbb22 (patch)
treea7fe58d0c5a553ef94f8959f02a3d22c20058d87
parentd76252d767cb0990b2362c5fb15ac3ee88675f3e (diff)
rtic-sync: introduce loom compat layer and apply it to `channel`
-rw-r--r--rtic-sync/Cargo.toml12
-rw-r--r--rtic-sync/src/arbiter.rs1
-rw-r--r--rtic-sync/src/channel.rs250
-rw-r--r--rtic-sync/src/lib.rs7
-rw-r--r--rtic-sync/src/loom_cs.rs69
-rw-r--r--rtic-sync/src/signal.rs4
-rw-r--r--rtic-sync/src/unsafecell.rs43
7 files changed, 299 insertions, 87 deletions
diff --git a/rtic-sync/Cargo.toml b/rtic-sync/Cargo.toml
index 60d8be2..cb54eef 100644
--- a/rtic-sync/Cargo.toml
+++ b/rtic-sync/Cargo.toml
@@ -25,15 +25,23 @@ portable-atomic = { version = "1", default-features = false }
embedded-hal = { version = "1.0.0" }
embedded-hal-async = { version = "1.0.0" }
embedded-hal-bus = { version = "0.2.0", features = ["async"] }
-
defmt-03 = { package = "defmt", version = "0.3", optional = true }
[dev-dependencies]
cassette = "0.3.0"
static_cell = "2.1.0"
-tokio = { version = "1", features = ["rt", "macros", "time"] }
+
+[target.'cfg(not(loom))'.dev-dependencies]
+tokio = { version = "1", features = ["rt", "macros", "time"], default-features = false }
[features]
default = []
testing = ["critical-section/std", "rtic-common/testing"]
defmt-03 = ["dep:defmt-03", "embedded-hal/defmt-03", "embedded-hal-async/defmt-03", "embedded-hal-bus/defmt-03"]
+
+[lints.rust]
+unexpected_cfgs = { level = "allow", check-cfg = ['cfg(loom)'] }
+
+[target.'cfg(loom)'.dependencies]
+loom = { version = "0.7.2", features = [ "futures" ] }
+critical-section = { version = "1", features = [ "restore-state-bool" ] }
diff --git a/rtic-sync/src/arbiter.rs b/rtic-sync/src/arbiter.rs
index 768e200..60559df 100644
--- a/rtic-sync/src/arbiter.rs
+++ b/rtic-sync/src/arbiter.rs
@@ -381,6 +381,7 @@ pub mod i2c {
}
}
+#[cfg(not(loom))]
#[cfg(test)]
mod tests {
use super::*;
diff --git a/rtic-sync/src/channel.rs b/rtic-sync/src/channel.rs
index 0bd2cd2..9c2111f 100644
--- a/rtic-sync/src/channel.rs
+++ b/rtic-sync/src/channel.rs
@@ -1,7 +1,7 @@
//! An async aware MPSC channel that can be used on no-alloc systems.
+use crate::unsafecell::UnsafeCell;
use core::{
- cell::UnsafeCell,
future::poll_fn,
mem::MaybeUninit,
pin::Pin,
@@ -48,11 +48,21 @@ unsafe impl<T, const N: usize> Send for Channel<T, N> {}
unsafe impl<T, const N: usize> Sync for Channel<T, N> {}
-struct UnsafeAccess<'a, const N: usize> {
- freeq: &'a mut Deque<u8, N>,
- readyq: &'a mut Deque<u8, N>,
- receiver_dropped: &'a mut bool,
- num_senders: &'a mut usize,
+macro_rules! cs_access {
+ ($name:ident, $type:ty) => {
+ /// Access the value mutably.
+ ///
+ /// SAFETY: this function must not be called recursively within `f`.
+ unsafe fn $name<F, R>(&self, _cs: critical_section::CriticalSection, f: F) -> R
+ where
+ F: FnOnce(&mut $type) -> R,
+ {
+ self.$name.with_mut(|v| {
+ let v = unsafe { &mut *v };
+ f(v)
+ })
+ }
+ };
}
impl<T, const N: usize> Default for Channel<T, N> {
@@ -65,6 +75,7 @@ impl<T, const N: usize> Channel<T, N> {
const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries");
/// Create a new channel.
+ #[cfg(not(loom))]
pub const fn new() -> Self {
Self {
freeq: UnsafeCell::new(Deque::new()),
@@ -77,37 +88,49 @@ impl<T, const N: usize> Channel<T, N> {
}
}
+ /// Create a new channel.
+ #[cfg(loom)]
+ pub fn new() -> Self {
+ Self {
+ freeq: UnsafeCell::new(Deque::new()),
+ readyq: UnsafeCell::new(Deque::new()),
+ receiver_waker: WakerRegistration::new(),
+ slots: core::array::from_fn(|_| UnsafeCell::new(MaybeUninit::uninit())),
+ wait_queue: WaitQueue::new(),
+ receiver_dropped: UnsafeCell::new(false),
+ num_senders: UnsafeCell::new(0),
+ }
+ }
+
/// Split the queue into a `Sender`/`Receiver` pair.
pub fn split(&mut self) -> (Sender<'_, T, N>, Receiver<'_, T, N>) {
+ // SAFETY: we have exclusive access to `self`.
+ let freeq = self.freeq.get_mut();
+ let freeq = unsafe { freeq.deref() };
+
// Fill free queue
for idx in 0..N as u8 {
- assert!(!self.freeq.get_mut().is_full());
+ assert!(!freeq.is_full());
// SAFETY: This safe as the loop goes from 0 to the capacity of the underlying queue.
unsafe {
- self.freeq.get_mut().push_back_unchecked(idx);
+ freeq.push_back_unchecked(idx);
}
}
- assert!(self.freeq.get_mut().is_full());
+ assert!(freeq.is_full());
// There is now 1 sender
- *self.num_senders.get_mut() = 1;
+ // SAFETY: we have exclusive access to `self`.
+ unsafe { *self.num_senders.get_mut().deref() = 1 };
(Sender(self), Receiver(self))
}
- fn access<'a>(&'a self, _cs: critical_section::CriticalSection) -> UnsafeAccess<'a, N> {
- // SAFETY: This is safe as are in a critical section.
- unsafe {
- UnsafeAccess {
- freeq: &mut *self.freeq.get(),
- readyq: &mut *self.readyq.get(),
- receiver_dropped: &mut *self.receiver_dropped.get(),
- num_senders: &mut *self.num_senders.get(),
- }
- }
- }
+ cs_access!(freeq, Deque<u8, N>);
+ cs_access!(readyq, Deque<u8, N>);
+ cs_access!(receiver_dropped, bool);
+ cs_access!(num_senders, usize);
/// Return free slot `slot` to the channel.
///
@@ -127,8 +150,14 @@ impl<T, const N: usize> Channel<T, N> {
unsafe { freeq_slot.replace(Some(slot), cs) };
wait_head.wake();
} else {
- assert!(!self.access(cs).freeq.is_full());
- unsafe { self.access(cs).freeq.push_back_unchecked(slot) }
+ // SAFETY: `self.freeq` is not called recursively.
+ unsafe {
+ self.freeq(cs, |freeq| {
+ assert!(!freeq.is_full());
+ // SAFETY: `freeq` is not full.
+ freeq.push_back_unchecked(slot);
+ });
+ }
}
})
}
@@ -136,6 +165,7 @@ impl<T, const N: usize> Channel<T, N> {
/// Creates a split channel with `'static` lifetime.
#[macro_export]
+#[cfg(not(loom))]
macro_rules! make_channel {
($type:ty, $size:expr) => {{
static mut CHANNEL: $crate::channel::Channel<$type, $size> =
@@ -285,16 +315,21 @@ impl<T, const N: usize> Sender<'_, T, N> {
fn send_footer(&mut self, idx: u8, val: T) {
// Write the value to the slots, note; this memcpy is not under a critical section.
unsafe {
- ptr::write(
- self.0.slots.get_unchecked(idx as usize).get() as *mut T,
- val,
- )
+ let first_element = self.0.slots.get_unchecked(idx as usize).get_mut();
+ let ptr = first_element.deref().as_mut_ptr();
+ ptr::write(ptr, val)
}
// Write the value into the ready queue.
critical_section::with(|cs| {
- assert!(!self.0.access(cs).readyq.is_full());
- unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) }
+ // SAFETY: `self.0.readyq` is not called recursively.
+ unsafe {
+ self.0.readyq(cs, |readyq| {
+ assert!(!readyq.is_full());
+ // SAFETY: ready is not full.
+ readyq.push_back_unchecked(idx);
+ });
+ }
});
fence(Ordering::SeqCst);
@@ -315,12 +350,16 @@ impl<T, const N: usize> Sender<'_, T, N> {
return Err(TrySendError::NoReceiver(val));
}
- let idx =
- if let Some(idx) = critical_section::with(|cs| self.0.access(cs).freeq.pop_front()) {
- idx
- } else {
- return Err(TrySendError::Full(val));
- };
+ let free_slot = critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.freeq` is not called recursively.
+ self.0.freeq(cs, |q| q.pop_front())
+ });
+
+ let idx = if let Some(idx) = free_slot {
+ idx
+ } else {
+ return Err(TrySendError::Full(val));
+ };
self.send_footer(idx, val);
@@ -368,7 +407,8 @@ impl<T, const N: usize> Sender<'_, T, N> {
}
let wq_empty = self.0.wait_queue.is_empty();
- let freeq_empty = self.0.access(cs).freeq.is_empty();
+ // SAFETY: `self.0.freeq` is not called recursively.
+ let freeq_empty = unsafe { self.0.freeq(cs, |q| q.is_empty()) };
// SAFETY: This pointer is only dereferenced here and on drop of the future
// which happens outside this `poll_fn`'s stack frame.
@@ -416,9 +456,15 @@ impl<T, const N: usize> Sender<'_, T, N> {
}
// We are not in the wait queue, no one else is waiting, and there is a free slot available.
else {
- assert!(!self.0.access(cs).freeq.is_empty());
- let slot = unsafe { self.0.access(cs).freeq.pop_back_unchecked() };
- Poll::Ready(Ok(slot))
+ // SAFETY: `self.0.freeq` is not called recursively.
+ unsafe {
+ self.0.freeq(cs, |freeq| {
+ assert!(!freeq.is_empty());
+ // SAFETY: `freeq` is non-empty
+ let slot = freeq.pop_back_unchecked();
+ Poll::Ready(Ok(slot))
+ })
+ }
}
})
})
@@ -438,17 +484,26 @@ impl<T, const N: usize> Sender<'_, T, N> {
/// Returns true if there is no `Receiver`s.
pub fn is_closed(&self) -> bool {
- critical_section::with(|cs| *self.0.access(cs).receiver_dropped)
+ critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.receiver_dropped` is not called recursively.
+ self.0.receiver_dropped(cs, |v| *v)
+ })
}
/// Is the queue full.
pub fn is_full(&self) -> bool {
- critical_section::with(|cs| self.0.access(cs).freeq.is_empty())
+ critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.freeq` is not called recursively.
+ self.0.freeq(cs, |v| v.is_empty())
+ })
}
/// Is the queue empty.
pub fn is_empty(&self) -> bool {
- critical_section::with(|cs| self.0.access(cs).freeq.is_full())
+ critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.freeq` is not called recursively.
+ self.0.freeq(cs, |v| v.is_full())
+ })
}
}
@@ -456,9 +511,13 @@ impl<T, const N: usize> Drop for Sender<'_, T, N> {
fn drop(&mut self) {
// Count down the reference counter
let num_senders = critical_section::with(|cs| {
- *self.0.access(cs).num_senders -= 1;
-
- *self.0.access(cs).num_senders
+ unsafe {
+ // SAFETY: `self.0.num_senders` is not called recursively.
+ self.0.num_senders(cs, |s| {
+ *s -= 1;
+ *s
+ })
+ }
});
// If there are no senders, wake the receiver to do error handling.
@@ -471,7 +530,10 @@ impl<T, const N: usize> Drop for Sender<'_, T, N> {
impl<T, const N: usize> Clone for Sender<'_, T, N> {
fn clone(&self) -> Self {
// Count up the reference counter
- critical_section::with(|cs| *self.0.access(cs).num_senders += 1);
+ critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.num_senders` is not called recursively.
+ self.0.num_senders(cs, |v| *v += 1);
+ });
Self(self.0)
}
@@ -511,11 +573,18 @@ impl<T, const N: usize> Receiver<'_, T, N> {
/// Receives a value if there is one in the channel, non-blocking.
pub fn try_recv(&mut self) -> Result<T, ReceiveError> {
// Try to get a ready slot.
- let ready_slot = critical_section::with(|cs| self.0.access(cs).readyq.pop_front());
+ let ready_slot = critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.readyq` is not called recursively.
+ self.0.readyq(cs, |q| q.pop_front())
+ });
if let Some(rs) = ready_slot {
// Read the value from the slots, note; this memcpy is not under a critical section.
- let r = unsafe { ptr::read(self.0.slots.get_unchecked(rs as usize).get() as *const T) };
+ let r = unsafe {
+ let first_element = self.0.slots.get_unchecked(rs as usize).get_mut();
+ let ptr = first_element.deref().as_ptr();
+ ptr::read(ptr)
+ };
// Return the index to the free queue after we've read the value.
// SAFETY: `rs` comes directly from `readyq`.
@@ -556,24 +625,36 @@ impl<T, const N: usize> Receiver<'_, T, N> {
/// Returns true if there are no `Sender`s.
pub fn is_closed(&self) -> bool {
- critical_section::with(|cs| *self.0.access(cs).num_senders == 0)
+ critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.num_senders` is not called recursively.
+ self.0.num_senders(cs, |v| *v == 0)
+ })
}
/// Is the queue full.
pub fn is_full(&self) -> bool {
- critical_section::with(|cs| self.0.access(cs).readyq.is_full())
+ critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.readyq` is not called recursively.
+ self.0.readyq(cs, |v| v.is_full())
+ })
}
/// Is the queue empty.
pub fn is_empty(&self) -> bool {
- critical_section::with(|cs| self.0.access(cs).readyq.is_empty())
+ critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.readyq` is not called recursively.
+ self.0.readyq(cs, |v| v.is_empty())
+ })
}
}
impl<T, const N: usize> Drop for Receiver<'_, T, N> {
fn drop(&mut self) {
// Mark the receiver as dropped and wake all waiters
- critical_section::with(|cs| *self.0.access(cs).receiver_dropped = true);
+ critical_section::with(|cs| unsafe {
+ // SAFETY: `self.0.receiver_dropped` is not called recursively.
+ self.0.receiver_dropped(cs, |v| *v = true);
+ });
while let Some((waker, _)) = self.0.wait_queue.pop() {
waker.wake();
@@ -582,6 +663,7 @@ impl<T, const N: usize> Drop for Receiver<'_, T, N> {
}
#[cfg(test)]
+#[cfg(not(loom))]
mod tests {
use cassette::Cassette;
@@ -666,35 +748,6 @@ mod tests {
assert_eq!(s.try_send(11), Err(TrySendError::NoReceiver(11)));
}
- #[tokio::test]
- async fn stress_channel() {
- const NUM_RUNS: usize = 1_000;
- const QUEUE_SIZE: usize = 10;
-
- let (s, mut r) = make_channel!(u32, QUEUE_SIZE);
- let mut v = std::vec::Vec::new();
-
- for i in 0..NUM_RUNS {
- let mut s = s.clone();
-
- v.push(tokio::spawn(async move {
- s.send(i as _).await.unwrap();
- }));
- }
-
- let mut map = std::collections::BTreeSet::new();
-
- for _ in 0..NUM_RUNS {
- map.insert(r.recv().await.unwrap());
- }
-
- assert_eq!(map.len(), NUM_RUNS);
-
- for v in v {
- v.await.unwrap();
- }
- }
-
fn make() {
let _ = make_channel!(u32, 10);
}
@@ -715,7 +768,7 @@ mod tests {
where
F: FnOnce(&mut Deque<u8, N>) -> R,
{
- critical_section::with(|cs| f(channel.access(cs).freeq))
+ critical_section::with(|cs| unsafe { channel.freeq(cs, f) })
}
#[test]
@@ -750,3 +803,36 @@ mod tests {
drop((tx, rx));
}
}
+
+#[cfg(not(loom))]
+#[cfg(test)]
+mod tokio_tests {
+ #[tokio::test]
+ async fn stress_channel() {
+ const NUM_RUNS: usize = 1_000;
+ const QUEUE_SIZE: usize = 10;
+
+ let (s, mut r) = make_channel!(u32, QUEUE_SIZE);
+ let mut v = std::vec::Vec::new();
+
+ for i in 0..NUM_RUNS {
+ let mut s = s.clone();
+
+ v.push(tokio::spawn(async move {
+ s.send(i as _).await.unwrap();
+ }));
+ }
+
+ let mut map = std::collections::BTreeSet::new();
+
+ for _ in 0..NUM_RUNS {
+ map.insert(r.recv().await.unwrap());
+ }
+
+ assert_eq!(map.len(), NUM_RUNS);
+
+ for v in v {
+ v.await.unwrap();
+ }
+ }
+}
diff --git a/rtic-sync/src/lib.rs b/rtic-sync/src/lib.rs
index f884588..c2f323f 100644
--- a/rtic-sync/src/lib.rs
+++ b/rtic-sync/src/lib.rs
@@ -1,6 +1,6 @@
//! Synchronization primitives for asynchronous contexts.
-#![no_std]
+#![cfg_attr(not(loom), no_std)]
#![deny(missing_docs)]
#[cfg(feature = "defmt-03")]
@@ -11,6 +11,11 @@ pub mod channel;
pub use portable_atomic;
pub mod signal;
+mod unsafecell;
+
#[cfg(test)]
#[macro_use]
extern crate std;
+
+#[cfg(loom)]
+mod loom_cs;
diff --git a/rtic-sync/src/loom_cs.rs b/rtic-sync/src/loom_cs.rs
new file mode 100644
index 0000000..3291f52
--- /dev/null
+++ b/rtic-sync/src/loom_cs.rs
@@ -0,0 +1,69 @@
+//! A loom-based implementation of CriticalSection, effectively copied from the critical_section::std module.
+
+use core::cell::RefCell;
+use core::mem::MaybeUninit;
+
+use loom::cell::Cell;
+use loom::sync::{Mutex, MutexGuard};
+
+loom::lazy_static! {
+ static ref GLOBAL_MUTEX: Mutex<()> = Mutex::new(());
+ // This is initialized if a thread has acquired the CS, uninitialized otherwise.
+ static ref GLOBAL_GUARD: RefCell<MaybeUninit<MutexGuard<'static, ()>>> = RefCell::new(MaybeUninit::uninit());
+}
+
+loom::thread_local!(static IS_LOCKED: Cell<bool> = Cell::new(false));
+
+struct StdCriticalSection;
+critical_section::set_impl!(StdCriticalSection);
+
+unsafe impl critical_section::Impl for StdCriticalSection {
+ unsafe fn acquire() -> bool {
+ // Allow reentrancy by checking thread local state
+ IS_LOCKED.with(|l| {
+ if l.get() {
+ // CS already acquired in the current thread.
+ return true;
+ }
+
+ // Note: it is fine to set this flag *before* acquiring the mutex because it's thread local.
+ // No other thread can see its value, there's no potential for races.
+ // This way, we hold the mutex for slightly less time.
+ l.set(true);
+
+ // Not acquired in the current thread, acquire it.
+ let guard = match GLOBAL_MUTEX.lock() {
+ Ok(guard) => guard,
+ Err(err) => {
+ // Ignore poison on the global mutex in case a panic occurred
+ // while the mutex was held.
+ err.into_inner()
+ }
+ };
+ GLOBAL_GUARD.borrow_mut().write(guard);
+
+ false
+ })
+ }
+
+ unsafe fn release(nested_cs: bool) {
+ if !nested_cs {
+ // SAFETY: As per the acquire/release safety contract, release can only be called
+ // if the critical section is acquired in the current thread,
+ // in which case we know the GLOBAL_GUARD is initialized.
+ //
+ // We have to `assume_init_read` then drop instead of `assume_init_drop` because:
+ // - drop requires exclusive access (&mut) to the contents
+ // - mutex guard drop first unlocks the mutex, then returns. In between those, there's a brief
+ // moment where the mutex is unlocked but a `&mut` to the contents exists.
+ // - During this moment, another thread can go and use GLOBAL_GUARD, causing `&mut` aliasing.
+ #[allow(let_underscore_lock)]
+ let _ = GLOBAL_GUARD.borrow_mut().assume_init_read();
+
+ // Note: it is fine to clear this flag *after* releasing the mutex because it's thread local.
+ // No other thread can see its value, there's no potential for races.
+ // This way, we hold the mutex for slightly less time.
+ IS_LOCKED.with(|l| l.set(false));
+ }
+ }
+}
diff --git a/rtic-sync/src/signal.rs b/rtic-sync/src/signal.rs
index f3c8ceb..d43e9d5 100644
--- a/rtic-sync/src/signal.rs
+++ b/rtic-sync/src/signal.rs
@@ -168,10 +168,10 @@ macro_rules! make_signal {
}
#[cfg(test)]
+#[cfg(not(loom))]
mod tests {
- use static_cell::StaticCell;
-
use super::*;
+ use static_cell::StaticCell;
#[test]
fn empty() {
diff --git a/rtic-sync/src/unsafecell.rs b/rtic-sync/src/unsafecell.rs
new file mode 100644
index 0000000..e1774f8
--- /dev/null
+++ b/rtic-sync/src/unsafecell.rs
@@ -0,0 +1,43 @@
+//! Compat layer for [`core::cell::UnsafeCell`] and `loom::cell::UnsafeCell`.
+
+#[cfg(loom)]
+pub use loom::cell::UnsafeCell;
+
+#[cfg(not(loom))]
+pub use core::UnsafeCell;
+
+#[cfg(not(loom))]
+mod core {
+ /// An [`core::cell::UnsafeCell`] wrapper that provides compatibility with
+ /// loom's UnsafeCell.
+ #[derive(Debug)]
+ pub struct UnsafeCell<T>(core::cell::UnsafeCell<T>);
+
+ impl<T> UnsafeCell<T> {
+ /// Create a new `UnsafeCell`.
+ pub const fn new(data: T) -> UnsafeCell<T> {
+ UnsafeCell(core::cell::UnsafeCell::new(data))
+ }
+
+ /// Access the contents of the `UnsafeCell` through a mut pointer.
+ pub fn get_mut(&self) -> MutPtr<T> {
+ MutPtr(self.0.get())
+ }
+
+ pub unsafe fn with_mut<F, R>(&self, f: F) -> R
+ where
+ F: FnOnce(*mut T) -> R,
+ {
+ f(self.0.get())
+ }
+ }
+
+ pub struct MutPtr<T>(*mut T);
+
+ impl<T> MutPtr<T> {
+ #[allow(clippy::mut_from_ref)]
+ pub unsafe fn deref(&self) -> &mut T {
+ &mut *self.0
+ }
+ }
+}