diff options
Diffstat (limited to 'rtic-sync/src/arbiter.rs')
| -rw-r--r-- | rtic-sync/src/arbiter.rs | 194 |
1 files changed, 194 insertions, 0 deletions
diff --git a/rtic-sync/src/arbiter.rs b/rtic-sync/src/arbiter.rs new file mode 100644 index 0000000..d50b1ea --- /dev/null +++ b/rtic-sync/src/arbiter.rs @@ -0,0 +1,194 @@ +//! Crate + +use core::cell::UnsafeCell; +use core::future::poll_fn; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::sync::atomic::{fence, AtomicBool, Ordering}; +use core::task::{Poll, Waker}; + +use rtic_common::dropper::OnDrop; +use rtic_common::wait_queue::{Link, WaitQueue}; + +/// This is needed to make the async closure in `send` accept that we "share" +/// the link possible between threads. +#[derive(Clone)] +struct LinkPtr(*mut Option<Link<Waker>>); + +impl LinkPtr { + /// This will dereference the pointer stored within and give out an `&mut`. + unsafe fn get(&mut self) -> &mut Option<Link<Waker>> { + &mut *self.0 + } +} + +unsafe impl Send for LinkPtr {} +unsafe impl Sync for LinkPtr {} + +/// An FIFO waitqueue for use in shared bus usecases. +pub struct Arbiter<T> { + wait_queue: WaitQueue, + inner: UnsafeCell<T>, + taken: AtomicBool, +} + +unsafe impl<T> Send for Arbiter<T> {} +unsafe impl<T> Sync for Arbiter<T> {} + +impl<T> Arbiter<T> { + /// Create a new arbiter. + pub const fn new(inner: T) -> Self { + Self { + wait_queue: WaitQueue::new(), + inner: UnsafeCell::new(inner), + taken: AtomicBool::new(false), + } + } + + /// Get access to the inner value in the `Arbiter`. This will wait until access is granted, + /// for non-blocking access use `try_access`. + pub async fn access(&self) -> ExclusiveAccess<'_, T> { + let mut link_ptr: Option<Link<Waker>> = None; + + // Make this future `Drop`-safe. + // SAFETY(link_ptr): Shadow the original definition of `link_ptr` so we can't abuse it. + let mut link_ptr = LinkPtr(&mut link_ptr as *mut Option<Link<Waker>>); + + let mut link_ptr2 = link_ptr.clone(); + let dropper = OnDrop::new(|| { + // SAFETY: We only run this closure and dereference the pointer if we have + // exited the `poll_fn` below in the `drop(dropper)` call. The other dereference + // of this pointer is in the `poll_fn`. + if let Some(link) = unsafe { link_ptr2.get() } { + link.remove_from_list(&self.wait_queue); + } + }); + + poll_fn(|cx| { + critical_section::with(|_| { + fence(Ordering::SeqCst); + + // The queue is empty and noone has taken the value. + if self.wait_queue.is_empty() && !self.taken.load(Ordering::Relaxed) { + self.taken.store(true, Ordering::Relaxed); + + return Poll::Ready(()); + } + + // SAFETY: This pointer is only dereferenced here and on drop of the future + // which happens outside this `poll_fn`'s stack frame. + let link = unsafe { link_ptr.get() }; + if let Some(link) = link { + if link.is_popped() { + return Poll::Ready(()); + } + } else { + // Place the link in the wait queue on first run. + let link_ref = link.insert(Link::new(cx.waker().clone())); + + // SAFETY(new_unchecked): The address to the link is stable as it is defined + // outside this stack frame. + // SAFETY(push): `link_ref` lifetime comes from `link_ptr` that is shadowed, + // and we make sure in `dropper` that the link is removed from the queue + // before dropping `link_ptr` AND `dropper` makes sure that the shadowed + // `link_ptr` lives until the end of the stack frame. + unsafe { self.wait_queue.push(Pin::new_unchecked(link_ref)) }; + } + + Poll::Pending + }) + }) + .await; + + // Make sure the link is removed from the queue. + drop(dropper); + + // SAFETY: One only gets here if there is exlusive access. + ExclusiveAccess { + arbiter: self, + inner: unsafe { &mut *self.inner.get() }, + } + } + + /// Non-blockingly tries to access the underlying value. + /// If someone is in queue to get it, this will return `None`. + pub fn try_access(&self) -> Option<ExclusiveAccess<'_, T>> { + critical_section::with(|_| { + fence(Ordering::SeqCst); + + // The queue is empty and noone has taken the value. + if self.wait_queue.is_empty() && !self.taken.load(Ordering::Relaxed) { + self.taken.store(true, Ordering::Relaxed); + + // SAFETY: One only gets here if there is exlusive access. + Some(ExclusiveAccess { + arbiter: self, + inner: unsafe { &mut *self.inner.get() }, + }) + } else { + None + } + }) + } +} + +/// This token represents exclusive access to the value protected by the `Arbiter`. +pub struct ExclusiveAccess<'a, T> { + arbiter: &'a Arbiter<T>, + inner: &'a mut T, +} + +impl<'a, T> Drop for ExclusiveAccess<'a, T> { + fn drop(&mut self) { + critical_section::with(|_| { + fence(Ordering::SeqCst); + + if self.arbiter.wait_queue.is_empty() { + // If noone is in queue and we release exclusive access, reset `taken`. + self.arbiter.taken.store(false, Ordering::Relaxed); + } else if let Some(next) = self.arbiter.wait_queue.pop() { + // Wake the next one in queue. + next.wake(); + } + }) + } +} + +impl<'a, T> Deref for ExclusiveAccess<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.inner + } +} + +impl<'a, T> DerefMut for ExclusiveAccess<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn stress_channel() { + const NUM_RUNS: usize = 100_000; + + static ARB: Arbiter<usize> = Arbiter::new(0); + let mut v = std::vec::Vec::new(); + + for _ in 0..NUM_RUNS { + v.push(tokio::spawn(async move { + *ARB.access().await += 1; + })); + } + + for v in v { + v.await.unwrap(); + } + + assert_eq!(*ARB.access().await, NUM_RUNS) + } +} |
