diff --git a/src/next.rs b/src/next.rs index 09877e2..6f18644 100644 --- a/src/next.rs +++ b/src/next.rs @@ -2,11 +2,11 @@ use std::{ time::{Instant, Duration}, collections::VecDeque, marker::PhantomData, - sync::{Mutex, Condvar, Arc, atomic::AtomicUsize}, fmt, error, + sync::{Mutex, Condvar, Arc, MutexGuard}, + fmt, + error, }; -// TODO: shut down debouncer on drop - /// Creates a new debouncer for deduplicating groups of "raw" events which occur at a similar time. /// The debouncer is comprised of two halves; a [`DebouncerTx`](DebouncerTx) for sending raw events /// to the debouncer, and a [`DebouncerRx`](DebouncerRx) for receiving grouped (debounced) events @@ -52,14 +52,7 @@ pub fn debouncer(debounce_time: Duration, fold: F) where F: Fn(Option, R) -> D, { - let shared_state = Arc::new(Debouncer { - state: Mutex::new(DebouncerState::new()), - debounce_time, - queue_wait_cvar: Condvar::new(), - event_ready_wait_cvar: Condvar::new(), - tx_count: AtomicUsize::new(1), - rx_count: AtomicUsize::new(1), - }); + let shared_state = Arc::new(Debouncer::new(debounce_time, 1, 1)); let tx = DebouncerTx { debouncer: shared_state.clone(), @@ -89,12 +82,16 @@ where } } -// FIXME: increment some sort of tx count impl Clone for DebouncerTx where F: Clone, { fn clone(&self) -> Self { + self.debouncer + .lock_state() + .add_tx() + .expect("debouncer tx count should not overflow"); + Self { debouncer: self.debouncer.clone(), fold: self.fold.clone(), @@ -103,10 +100,20 @@ where } } -// FIXME: check tx count and shut down if zero impl Drop for DebouncerTx { fn drop(&mut self) { - todo!() + let remaining_tx = { + let mut state_guard = self.debouncer.lock_state(); + state_guard.remove_tx() + }; + + if remaining_tx == 0 { + // There may be some rx threads waiting on the condvars, so notify them to stop them + // from waiting. Upon waking up, the rx threads will see that the tx count is 0 and + // switch to their "shutdown" behaviour. + self.debouncer.event_ready_wait_cvar.notify_all(); + self.debouncer.queue_wait_cvar.notify_all(); + } } } @@ -124,19 +131,25 @@ impl DebouncerRx { } } -// FIXME: increment some sort of rx count impl Clone for DebouncerRx { fn clone(&self) -> Self { + self.debouncer + .lock_state() + .add_rx() + .expect("debouncer rx count should not overflow"); + Self { debouncer: self.debouncer.clone(), } } } -// FIXME: check rx count and shut down if zero impl Drop for DebouncerRx { fn drop(&mut self) { - todo!() + // Decrement the rx count. We don't need to notify any threads waiting on condvars if the + // count reaches 0, because only rx threads wait on the condvars and we know there are no + // more rx threads (because the count just reached 0!) + self.debouncer.lock_state().remove_rx(); } } @@ -172,11 +185,25 @@ struct Debouncer { debounce_time: Duration, queue_wait_cvar: Condvar, event_ready_wait_cvar: Condvar, - tx_count: AtomicUsize, - rx_count: AtomicUsize, } impl Debouncer { + fn new(debounce_time: Duration, tx_count: usize, rx_count: usize) -> Self { + Self { + state: Mutex::new(DebouncerState::new(tx_count, rx_count)), + debounce_time, + queue_wait_cvar: Condvar::new(), + event_ready_wait_cvar: Condvar::new(), + } + } + + fn lock_state(&self) -> MutexGuard> { + match self.state.lock() { + Ok(guard) => guard, + Err(err) => err.into_inner(), + } + } + fn push(&self, raw_event: R, fold: F) -> Result<(), SendError> where F: Fn(Option, R) -> T, @@ -184,11 +211,11 @@ impl Debouncer { let now = Instant::now(); let push_outcome = { - let mut state_guard = self.state.lock().unwrap(); + let mut state_guard = self.lock_state(); - // Return an error if the debouncer is closed. Include the raw event in the error so - // that it isn't lost. - if state_guard.shutdown { + // Return an error if there are no rxs left to send to. Include the raw event in the + // error so that it isn't lost. + if state_guard.has_no_rxs() { return Err(SendError(raw_event)); } @@ -212,10 +239,12 @@ impl Debouncer { Shutdown, } - let mut state_guard = self.state.lock().unwrap(); + let mut state_guard = self.lock_state(); let result = 'result: { - if state_guard.shutdown { + // Check that there are any txs left so we don't get stuck waiting forever if the queue + // is empty (as a tx is the only thing that can wake us up). + if state_guard.has_no_txs() { break 'result PopOutcome::Shutdown; } @@ -225,17 +254,23 @@ impl Debouncer { // wait for it to be ready. // 3. Before we can wake up from waiting, some other thread pops the now-ready `x`, and // the queue is now empty. - // 4. We wake up and must wait for the queue to become non-empty again. + // 4. We wake up and must wait for the queue to become non-empty again. + // + // Additionally, the loop also handles spurious returns from the condvar wait. 'pop_event_outer_loop: loop { // If there are no accumulators in the queue, wait for one to be pushed. if state_guard.acc_queue.is_empty() { // Park the thread. We will be woken up again either by a new accumulator being // pushed to the queue, or by the debouncer being shut down. + // + // This may return spuriously so the queue may still be empty when we wake up, + // but if this happens we will retry the wait when we attempt and fail to peek + // the queue later. state_guard = self.queue_wait_cvar.wait(state_guard).unwrap(); // We may have been unparked because someone wants to shut down the debouncer, // so check the shutdown flag. - if state_guard.shutdown { + if state_guard.has_no_txs() { break PopOutcome::Shutdown; } } @@ -253,7 +288,9 @@ impl Debouncer { // popped. let Some(peeked_acc) = state_guard.acc_queue.peek_oldest() else { // If there is no accumulator for us to pop from the queue, go back to - // waiting for the queue to be non-empty. + // waiting for the queue to be non-empty. This could happen because the + // previous condvar wait returned spuriously, or because another thread + // popped the accumulator before us. continue 'pop_event_outer_loop; }; @@ -270,13 +307,19 @@ impl Debouncer { // Wait the amount of time between now and the `ready_time` of the // accumulator. This is done using a condvar so the sleep can be // interrupted if someone wants to shut down the debouncer. + // + // This may return spuriously so we may wake up before the time has + // elapsed and without being notified by anyone. However, since we go + // back so the start of the loop after this, we will see that there is + // still time remaining before the accumulator becomes ready, and + // therefore we will resume waiting. (state_guard, _) = self.event_ready_wait_cvar .wait_timeout(state_guard, wait_time) .unwrap(); // Again, we may have been unparked because someone wants to shut down the // debouncer, so check the shutdown flag and return if it is set. - if state_guard.shutdown { + if state_guard.has_no_txs() { break PopOutcome::Shutdown; } @@ -317,11 +360,11 @@ impl Debouncer { } fn try_pop(&self) -> Result, ReceiveError> { - let mut state_guard = self.state.lock().unwrap(); + let mut state_guard = self.lock_state(); // If the debouncer has been shut down, return any remaining accumulators then start // returning errors once the accumulators have been depleted. - if state_guard.shutdown { + if state_guard.has_no_txs() { return state_guard .acc_queue .pop_oldest_acc_discard_none() @@ -356,16 +399,53 @@ impl Debouncer { struct DebouncerState { acc_queue: EventAccQueue, - shutdown: bool, + // These could alternatively be `AtomicUsize`s which live outside the mutex, but we opt to have + // them inside the mutex instead to avoid a confusing mixture of mutexes / condvars and + // atomics. + tx_count: usize, + rx_count: usize, } impl DebouncerState { - fn new() -> Self { + fn new(tx_count: usize, rx_count: usize) -> Self { Self { acc_queue: EventAccQueue::new(), - shutdown: false, + tx_count, + rx_count, } } + + fn has_no_txs(&self) -> bool { + self.tx_count == 0 + } + + fn add_tx(&mut self) -> Result<(), CountOverflowError> { + self.tx_count = self.tx_count.checked_add(1) + .ok_or(CountOverflowError)?; + + Ok(()) + } + + fn remove_tx(&mut self) -> usize { + self.tx_count = self.tx_count.saturating_sub(1); + self.tx_count + } + + fn has_no_rxs(&self) -> bool { + self.rx_count == 0 + } + + fn add_rx(&mut self) -> Result<(), CountOverflowError> { + self.rx_count = self.rx_count.checked_add(1) + .ok_or(CountOverflowError)?; + + Ok(()) + } + + fn remove_rx(&mut self) -> usize { + self.rx_count = self.rx_count.saturating_sub(1); + self.rx_count + } } struct EventAccQueue { @@ -468,3 +548,5 @@ impl EventAcc { } } +#[derive(Debug)] +struct CountOverflowError;