author      Emil Fresk <emil.fresk@gmail.com>      2023-01-28 20:47:21 +0100
committer   Henrik Tjäder <henrik@tjaders.com>     2023-03-01 00:33:37 +0100
commit      94b00df2c7e82efb9003945bb90675e73ed0f5e0 (patch)
tree        5cf235406625936ec54282266d3234aa43b1c899 /rtic-channel
parent      48ac310036bb0053d36d9cfce191351028808651 (diff)
rtic-channel: Add testing, fix bugs
Diffstat (limited to 'rtic-channel')
-rw-r--r--  rtic-channel/Cargo.toml          3
-rw-r--r--  rtic-channel/src/lib.rs        170
-rw-r--r--  rtic-channel/src/wait_queue.rs  14
3 files changed, 172 insertions, 15 deletions
diff --git a/rtic-channel/Cargo.toml b/rtic-channel/Cargo.toml
index 8962352..5d4cbd0 100644
--- a/rtic-channel/Cargo.toml
+++ b/rtic-channel/Cargo.toml
@@ -9,6 +9,9 @@ edition = "2021"
heapless = "0.7"
critical-section = "1"
+[dev-dependencies]
+tokio = { version = "1", features = ["rt", "macros", "time"] }
+
[features]
default = []
diff --git a/rtic-channel/src/lib.rs b/rtic-channel/src/lib.rs
index b6a317f..1077b5a 100644
--- a/rtic-channel/src/lib.rs
+++ b/rtic-channel/src/lib.rs
@@ -10,6 +10,7 @@ use core::{
mem::MaybeUninit,
pin::Pin,
ptr,
+ sync::atomic::{fence, Ordering},
task::{Poll, Waker},
};
use heapless::Deque;
@@ -40,6 +41,9 @@ pub struct Channel<T, const N: usize> {
num_senders: UnsafeCell<usize>,
}
+unsafe impl<T, const N: usize> Send for Channel<T, N> {}
+unsafe impl<T, const N: usize> Sync for Channel<T, N> {}
+
struct UnsafeAccess<'a, const N: usize> {
freeq: &'a mut Deque<u8, N>,
readyq: &'a mut Deque<u8, N>,
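
The two `unsafe impl`s assert that `Channel` may be shared across execution contexts even though it is built on `UnsafeCell`s. The justification (not spelled out in the diff) is that every access to the interior state goes through `critical_section::with`, which serializes access. A minimal sketch of the same pattern, with hypothetical names not taken from this crate:

    use core::cell::UnsafeCell;

    struct CsGuarded(UnsafeCell<usize>);

    // SAFETY (sketch): the cell is only ever touched inside a critical
    // section, so no two contexts can hold the `&mut` at the same time.
    unsafe impl Sync for CsGuarded {}

    impl CsGuarded {
        fn bump(&self) -> usize {
            critical_section::with(|_cs| unsafe {
                let v = &mut *self.0.get();
                *v += 1;
                *v
            })
        }
    }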
@@ -129,6 +133,21 @@ pub struct Sender<'a, T, const N: usize>(&'a Channel<T, N>);
unsafe impl<'a, T, const N: usize> Send for Sender<'a, T, N> {}
+/// This is needed to make the async closure in `send` accept that we "share"
+/// the link, possibly between threads.
+#[derive(Clone)]
+struct LinkPtr(*mut Option<wait_queue::Link<Waker>>);
+
+impl LinkPtr {
+ /// This will dereference the pointer stored within and give out an `&mut`.
+ unsafe fn get(&self) -> &mut Option<wait_queue::Link<Waker>> {
+ &mut *self.0
+ }
+}
+
+unsafe impl Send for LinkPtr {}
+unsafe impl Sync for LinkPtr {}
+
impl<'a, T, const N: usize> core::fmt::Debug for Sender<'a, T, N> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "Sender")
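
The doc comment on `LinkPtr` is terse; the underlying problem is that `send` needs `&mut` access to one stack-allocated `Option<wait_queue::Link<Waker>>` from two places (the `OnDrop` closure and the `poll_fn` closure), and a bare raw pointer, while it does sidestep the borrow checker, is neither `Send` nor `Sync`, which would make the whole `send` future non-`Send`. A hypothetical reduction of the trick, not taken from the crate:

    fn demo() {
        // Two closures need `&mut` access to one stack slot at
        // different times. Raw pointers are `Copy`, so each closure
        // can capture its own copy.
        let mut slot: Option<u32> = None;
        let ptr = &mut slot as *mut Option<u32>;
        let ptr2 = ptr;

        // SAFETY: the two dereferences are never live at once.
        let fill = move || unsafe { *ptr = Some(1) };
        let take = move || unsafe { (*ptr2).take() };

        fill();
        assert_eq!(take(), Some(1));
    }

`LinkPtr` is exactly this pointer wrapped in a newtype, so the `unsafe impl Send`/`Sync` assertions can be attached to it.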
@@ -147,7 +166,12 @@ impl<'a, T, const N: usize> Sender<'a, T, N> {
}
// Write the value into the ready queue.
- critical_section::with(|cs| unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) });
+ critical_section::with(|cs| {
+ debug_assert!(!self.0.access(cs).readyq.is_full());
+ unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) }
+ });
+
+ fence(Ordering::SeqCst);
// If there is a receiver waker, wake it.
self.0.receiver_waker.wake();
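
The `fence(Ordering::SeqCst)` is new in this patch and sits between the queue update and the wake, presumably so the pushed index is visible to the woken task even when the waker crosses cores. The classic fence-pairing shape, as a standalone sketch unrelated to the crate's types:

    use core::sync::atomic::{fence, AtomicBool, Ordering};

    static READY: AtomicBool = AtomicBool::new(false);
    static mut DATA: u32 = 0;

    fn producer() {
        unsafe { DATA = 42 };        // plain write
        fence(Ordering::SeqCst);     // ordered before the flag store
        READY.store(true, Ordering::Relaxed);
    }

    fn consumer() -> Option<u32> {
        if READY.load(Ordering::Relaxed) {
            fence(Ordering::SeqCst); // pairs with the producer fence
            return Some(unsafe { DATA });
        }
        None
    }

The matching fence on the receive side is added in the `Receiver` hunk further down.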
@@ -176,18 +200,17 @@ impl<'a, T, const N: usize> Sender<'a, T, N> {
/// Send a value. If there is no place left in the queue this will wait until there is.
/// If the receiver does not exist this will return an error.
pub async fn send(&mut self, val: T) -> Result<(), NoReceiver<T>> {
- if self.is_closed() {}
-
let mut link_ptr: Option<wait_queue::Link<Waker>> = None;
// Make this future `Drop`-safe, also shadow the original definition so we can't abuse it.
- let link_ptr = &mut link_ptr as *mut Option<wait_queue::Link<Waker>>;
+ let link_ptr = LinkPtr(&mut link_ptr as *mut Option<wait_queue::Link<Waker>>);
+ let link_ptr2 = link_ptr.clone();
let dropper = OnDrop::new(|| {
// SAFETY: We only run this closure and dereference the pointer if we have
// exited the `poll_fn` below in the `drop(dropper)` call. The other dereference
// of this pointer is in the `poll_fn`.
- if let Some(link) = unsafe { &mut *link_ptr } {
+ if let Some(link) = unsafe { link_ptr2.get() } {
link.remove_from_list(&self.0.wait_queue);
}
});
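
`OnDrop` itself is defined elsewhere in the crate and does not appear in this diff; it is a conventional drop guard that runs a closure when it goes out of scope. A minimal sketch of such a guard:

    use core::mem::ManuallyDrop;

    struct OnDrop<F: FnOnce()> {
        f: ManuallyDrop<F>,
    }

    impl<F: FnOnce()> OnDrop<F> {
        fn new(f: F) -> Self {
            Self { f: ManuallyDrop::new(f) }
        }
    }

    impl<F: FnOnce()> Drop for OnDrop<F> {
        fn drop(&mut self) {
            // SAFETY: `f` is taken exactly once, here in `drop`.
            unsafe { ManuallyDrop::take(&mut self.f)() }
        }
    }

Here the guard guarantees that if the `send` future is dropped while parked, its link is unhooked from the intrusive wait queue before the stack frame that owns the link disappears.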
@@ -199,11 +222,19 @@ impl<'a, T, const N: usize> Sender<'a, T, N> {
// Do all this in one critical section, else there can be race conditions
let queue_idx = critical_section::with(|cs| {
- if !self.0.wait_queue.is_empty() || self.0.access(cs).freeq.is_empty() {
+ let wq_empty = self.0.wait_queue.is_empty();
+ let fq_empty = self.0.access(cs).freeq.is_empty();
+ if !wq_empty || fq_empty {
// SAFETY: This pointer is only dereferenced here and on drop of the future
// which happens outside this `poll_fn`'s stack frame.
- let link = unsafe { &mut *link_ptr };
- if link.is_none() {
+ let link = unsafe { link_ptr.get() };
+ if let Some(link) = link {
+     // We already have a link from an earlier poll; only proceed
+     // once the receiver has popped it off the wait queue.
+     if !link.is_poped() {
+         return None;
+     }
+ } else {
// Place the link in the wait queue on first run.
let link_ref = link.insert(wait_queue::Link::new(cx.waker().clone()));
@@ -212,11 +243,12 @@ impl<'a, T, const N: usize> Sender<'a, T, N> {
self.0
.wait_queue
.push(unsafe { Pin::new_unchecked(link_ref) });
- }
- return None;
+ return None;
+ }
}
+ debug_assert!(!self.0.access(cs).freeq.is_empty());
// Get index as the queue is guaranteed not empty and the wait queue is empty
let idx = unsafe { self.0.access(cs).freeq.pop_front_unchecked() };
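
This is the bug the new `is_poped` flag fixes: previously the poll stayed pending whenever a link had already been created, so with several parked senders, one woken because the receiver popped its link off the wait queue would simply re-park, and nothing would ever wake it again. The poll now distinguishes three cases: no link yet (create one, enqueue it, stay pending), link still in the queue (spurious poll, stay pending), and link popped by the receiver (fall through and claim an index from `freeq`, which the `debug_assert!` documents must be non-empty at that point).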
@@ -319,7 +351,12 @@ impl<'a, T, const N: usize> Receiver<'a, T, N> {
let r = unsafe { ptr::read(self.0.slots.get_unchecked(rs as usize).get() as *const T) };
// Return the index to the free queue after we've read the value.
- critical_section::with(|cs| unsafe { self.0.access(cs).freeq.push_back_unchecked(rs) });
+ critical_section::with(|cs| {
+ debug_assert!(!self.0.access(cs).freeq.is_full());
+ unsafe { self.0.access(cs).freeq.push_back_unchecked(rs) }
+ });
+
+ fence(Ordering::SeqCst);
// If someone is waiting in the WaiterQueue, wake the first one up.
if let Some(wait_head) = self.0.wait_queue.pop() {
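
This fence is the receive-side twin of the one added in `send`: the freed index must be published before the woken sender runs and tries to pop it from `freeq`.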
@@ -363,7 +400,7 @@ impl<'a, T, const N: usize> Receiver<'a, T, N> {
/// Is the queue full.
pub fn is_full(&self) -> bool {
- critical_section::with(|cs| self.0.access(cs).readyq.is_empty())
+ critical_section::with(|cs| self.0.access(cs).readyq.is_full())
}
/// Is the queue empty.
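
A small but real bug fix: `Receiver::is_full` was delegating to `readyq.is_empty()`, so it reported `true` exactly when the channel was empty. The new `full` test below exercises the corrected behavior in both directions.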
@@ -412,6 +449,113 @@ extern crate std;
#[cfg(test)]
mod tests {
+ use super::*;
+
+ #[test]
+ fn empty() {
+ let (mut s, mut r) = make_channel!(u32, 10);
+
+ assert!(s.is_empty());
+ assert!(r.is_empty());
+
+ s.try_send(1).unwrap();
+
+ assert!(!s.is_empty());
+ assert!(!r.is_empty());
+
+ r.try_recv().unwrap();
+
+ assert!(s.is_empty());
+ assert!(r.is_empty());
+ }
+
+ #[test]
+ fn full() {
+ let (mut s, mut r) = make_channel!(u32, 3);
+
+ for _ in 0..3 {
+ assert!(!s.is_full());
+ assert!(!r.is_full());
+
+ s.try_send(1).unwrap();
+ }
+
+ assert!(s.is_full());
+ assert!(r.is_full());
+
+ for _ in 0..3 {
+ r.try_recv().unwrap();
+
+ assert!(!s.is_full());
+ assert!(!r.is_full());
+ }
+ }
+
+ #[test]
+ fn send_receive() {
+ let (mut s, mut r) = make_channel!(u32, 10);
+
+ for i in 0..10 {
+ s.try_send(i).unwrap();
+ }
+
+ assert_eq!(s.try_send(11), Err(11));
+
+ for i in 0..10 {
+ assert_eq!(r.try_recv().unwrap(), i);
+ }
+
+ assert_eq!(r.try_recv(), None);
+ }
+
+ #[test]
+ fn closed_recv() {
+ let (s, mut r) = make_channel!(u32, 10);
+
+ drop(s);
+
+ assert!(r.is_closed());
+
+ assert_eq!(r.try_recv(), None);
+ }
+
#[test]
- fn channel() {}
+ fn closed_sender() {
+ let (mut s, r) = make_channel!(u32, 10);
+
+ drop(r);
+
+ assert!(s.is_closed());
+
+ assert_eq!(s.try_send(11), Ok(()));
+ }
+
+ #[tokio::test]
+ async fn stress_channel() {
+ const NUM_RUNS: usize = 1_000;
+ const QUEUE_SIZE: usize = 10;
+
+ let (s, mut r) = make_channel!(u32, QUEUE_SIZE);
+ let mut v = std::vec::Vec::new();
+
+ for i in 0..NUM_RUNS {
+ let mut s = s.clone();
+
+ v.push(tokio::spawn(async move {
+ s.send(i as _).await.unwrap();
+ }));
+ }
+
+ let mut set = std::collections::BTreeSet::new();
+
+ for _ in 0..NUM_RUNS {
+     set.insert(r.recv().await.unwrap());
+ }
+
+ assert_eq!(set.len(), NUM_RUNS);
+
+ for v in v {
+ v.await.unwrap();
+ }
+ }
}
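
The `#[tokio::test]` attribute is what the new `rt` and `macros` dev-dependency features in `Cargo.toml` are for; it runs the test on tokio's default current-thread runtime. `stress_channel` pushes 1000 sends through a 10-slot queue, so many of the spawned tasks exercise the wait-queue path, including the `is_poped` fast path above.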
diff --git a/rtic-channel/src/wait_queue.rs b/rtic-channel/src/wait_queue.rs
index ba05e6b..e6d5a8b 100644
--- a/rtic-channel/src/wait_queue.rs
+++ b/rtic-channel/src/wait_queue.rs
@@ -3,7 +3,7 @@
use core::marker::PhantomPinned;
use core::pin::Pin;
use core::ptr::null_mut;
-use core::sync::atomic::{AtomicPtr, Ordering};
+use core::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use core::task::Waker;
use critical_section as cs;
@@ -57,6 +57,7 @@ impl<T: Clone> LinkedList<T> {
// Clear the pointers in the node.
head_ref.next.store(null_mut(), Self::R);
head_ref.prev.store(null_mut(), Self::R);
+ head_ref.is_poped.store(true, Self::R);
return Some(head_val);
}
@@ -100,9 +101,12 @@ pub struct Link<T> {
pub(crate) val: T,
next: AtomicPtr<Link<T>>,
prev: AtomicPtr<Link<T>>,
+ is_poped: AtomicBool,
_up: PhantomPinned,
}
+unsafe impl<T> Send for Link<T> {}
+
impl<T: Clone> Link<T> {
const R: Ordering = Ordering::Relaxed;
@@ -112,10 +116,15 @@ impl<T: Clone> Link<T> {
val,
next: AtomicPtr::new(null_mut()),
prev: AtomicPtr::new(null_mut()),
+ is_poped: AtomicBool::new(false),
_up: PhantomPinned,
}
}
+ pub fn is_poped(&self) -> bool {
+ self.is_poped.load(Self::R)
+ }
+
pub fn remove_from_list(&mut self, list: &LinkedList<T>) {
cs::with(|_| {
// Make sure all previous writes are visible
@@ -123,6 +132,7 @@ impl<T: Clone> Link<T> {
let prev = self.prev.load(Self::R);
let next = self.next.load(Self::R);
+ self.is_poped.store(true, Self::R);
match unsafe { (prev.as_ref(), next.as_ref()) } {
(None, None) => {
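
`is_poped` is now set at both exits from the list: in `pop` (taken by the receiver) and here in `remove_from_list` (the sender future's drop path). That lets a parked `send` reliably tell whether its link is still enqueued. Relaxed ordering appears sufficient because both the stores and the `is_poped()` check in `send` happen inside critical sections.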
@@ -217,7 +227,7 @@ mod tests {
#[test]
fn linked_list() {
- let mut wq = LinkedList::<u32>::new();
+ let wq = LinkedList::<u32>::new();
let mut i1 = Link::new(10);
let mut i2 = Link::new(11);
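
The `mut` on the list binding could be dropped because the list is operated on entirely through shared references: `push` and `pop` take `&self` (they are called through the shared `&Channel` in lib.rs), with interior mutability provided by the `AtomicPtr` links.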