diff options
Diffstat (limited to 'rust/kernel/rbtree.rs')
| -rw-r--r-- | rust/kernel/rbtree.rs | 244 |
1 files changed, 197 insertions, 47 deletions
diff --git a/rust/kernel/rbtree.rs b/rust/kernel/rbtree.rs index b8fe6be6fcc4..4729eb56827a 100644 --- a/rust/kernel/rbtree.rs +++ b/rust/kernel/rbtree.rs @@ -243,34 +243,64 @@ impl<K, V> RBTree<K, V> { } /// Returns a cursor over the tree nodes, starting with the smallest key. - pub fn cursor_front(&mut self) -> Option<Cursor<'_, K, V>> { + pub fn cursor_front_mut(&mut self) -> Option<CursorMut<'_, K, V>> { let root = addr_of_mut!(self.root); - // SAFETY: `self.root` is always a valid root node + // SAFETY: `self.root` is always a valid root node. let current = unsafe { bindings::rb_first(root) }; NonNull::new(current).map(|current| { // INVARIANT: // - `current` is a valid node in the [`RBTree`] pointed to by `self`. - Cursor { + CursorMut { current, tree: self, } }) } + /// Returns an immutable cursor over the tree nodes, starting with the smallest key. + pub fn cursor_front(&self) -> Option<Cursor<'_, K, V>> { + let root = &raw const self.root; + // SAFETY: `self.root` is always a valid root node. + let current = unsafe { bindings::rb_first(root) }; + NonNull::new(current).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + _tree: PhantomData, + } + }) + } + /// Returns a cursor over the tree nodes, starting with the largest key. - pub fn cursor_back(&mut self) -> Option<Cursor<'_, K, V>> { + pub fn cursor_back_mut(&mut self) -> Option<CursorMut<'_, K, V>> { let root = addr_of_mut!(self.root); - // SAFETY: `self.root` is always a valid root node + // SAFETY: `self.root` is always a valid root node. let current = unsafe { bindings::rb_last(root) }; NonNull::new(current).map(|current| { // INVARIANT: // - `current` is a valid node in the [`RBTree`] pointed to by `self`. - Cursor { + CursorMut { current, tree: self, } }) } + + /// Returns a cursor over the tree nodes, starting with the largest key. + pub fn cursor_back(&self) -> Option<Cursor<'_, K, V>> { + let root = &raw const self.root; + // SAFETY: `self.root` is always a valid root node. + let current = unsafe { bindings::rb_last(root) }; + NonNull::new(current).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + _tree: PhantomData, + } + }) + } } impl<K, V> RBTree<K, V> @@ -421,12 +451,47 @@ where /// If the given key exists, the cursor starts there. /// Otherwise it starts with the first larger key in sort order. /// If there is no larger key, it returns [`None`]. - pub fn cursor_lower_bound(&mut self, key: &K) -> Option<Cursor<'_, K, V>> + pub fn cursor_lower_bound_mut(&mut self, key: &K) -> Option<CursorMut<'_, K, V>> + where + K: Ord, + { + let best = self.find_best_match(key)?; + + NonNull::new(best.as_ptr()).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + CursorMut { + current, + tree: self, + } + }) + } + + /// Returns a cursor over the tree nodes based on the given key. + /// + /// If the given key exists, the cursor starts there. + /// Otherwise it starts with the first larger key in sort order. + /// If there is no larger key, it returns [`None`]. + pub fn cursor_lower_bound(&self, key: &K) -> Option<Cursor<'_, K, V>> where K: Ord, { + let best = self.find_best_match(key)?; + + NonNull::new(best.as_ptr()).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + _tree: PhantomData, + } + }) + } + + fn find_best_match(&self, key: &K) -> Option<NonNull<bindings::rb_node>> { let mut node = self.root.rb_node; - let mut best_match: Option<NonNull<Node<K, V>>> = None; + let mut best_key: Option<&K> = None; + let mut best_links: Option<NonNull<bindings::rb_node>> = None; while !node.is_null() { // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` // point to the links field of `Node<K, V>` objects. @@ -439,42 +504,28 @@ where let right_child = unsafe { (*node).rb_right }; match key.cmp(this_key) { Ordering::Equal => { - best_match = NonNull::new(this); + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + best_links = Some(unsafe { NonNull::new_unchecked(&mut (*this).links) }); break; } Ordering::Greater => { node = right_child; } Ordering::Less => { - let is_better_match = match best_match { + let is_better_match = match best_key { None => true, - Some(best) => { - // SAFETY: `best` is a non-null node so it is valid by the type invariants. - let best_key = unsafe { &(*best.as_ptr()).key }; - best_key > this_key - } + Some(best) => best > this_key, }; if is_better_match { - best_match = NonNull::new(this); + best_key = Some(this_key); + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + best_links = Some(unsafe { NonNull::new_unchecked(&mut (*this).links) }); } node = left_child; } }; } - - let best = best_match?; - - // SAFETY: `best` is a non-null node so it is valid by the type invariants. - let links = unsafe { addr_of_mut!((*best.as_ptr()).links) }; - - NonNull::new(links).map(|current| { - // INVARIANT: - // - `current` is a valid node in the [`RBTree`] pointed to by `self`. - Cursor { - current, - tree: self, - } - }) + best_links } } @@ -507,7 +558,7 @@ impl<K, V> Drop for RBTree<K, V> { } } -/// A bidirectional cursor over the tree nodes, sorted by key. +/// A bidirectional mutable cursor over the tree nodes, sorted by key. /// /// # Examples /// @@ -526,7 +577,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// /// // Get a cursor to the first element. -/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut cursor = tree.cursor_front_mut().unwrap(); /// let mut current = cursor.current(); /// assert_eq!(current, (&10, &100)); /// @@ -564,7 +615,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// -/// let mut cursor = tree.cursor_back().unwrap(); +/// let mut cursor = tree.cursor_back_mut().unwrap(); /// let current = cursor.current(); /// assert_eq!(current, (&30, &300)); /// @@ -577,7 +628,7 @@ impl<K, V> Drop for RBTree<K, V> { /// use kernel::rbtree::RBTree; /// /// let mut tree: RBTree<u16, u16> = RBTree::new(); -/// assert!(tree.cursor_front().is_none()); +/// assert!(tree.cursor_front_mut().is_none()); /// /// # Ok::<(), Error>(()) /// ``` @@ -628,7 +679,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// /// // Retrieve a cursor. -/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut cursor = tree.cursor_front_mut().unwrap(); /// /// // Get a mutable reference to the current value. /// let (k, v) = cursor.current_mut(); @@ -655,7 +706,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// /// // Remove the first element. -/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut cursor = tree.cursor_front_mut().unwrap(); /// let mut current = cursor.current(); /// assert_eq!(current, (&10, &100)); /// cursor = cursor.remove_current().0.unwrap(); @@ -665,7 +716,7 @@ impl<K, V> Drop for RBTree<K, V> { /// assert_eq!(current, (&20, &200)); /// /// // Get a cursor to the last element, and remove it. -/// cursor = tree.cursor_back().unwrap(); +/// cursor = tree.cursor_back_mut().unwrap(); /// current = cursor.current(); /// assert_eq!(current, (&30, &300)); /// @@ -694,7 +745,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// /// // Get a cursor to the first element. -/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut cursor = tree.cursor_front_mut().unwrap(); /// let mut current = cursor.current(); /// assert_eq!(current, (&10, &100)); /// @@ -702,7 +753,7 @@ impl<K, V> Drop for RBTree<K, V> { /// assert!(cursor.remove_prev().is_none()); /// /// // Get a cursor to the last element. -/// cursor = tree.cursor_back().unwrap(); +/// cursor = tree.cursor_back_mut().unwrap(); /// current = cursor.current(); /// assert_eq!(current, (&30, &300)); /// @@ -726,18 +777,48 @@ impl<K, V> Drop for RBTree<K, V> { /// /// # Invariants /// - `current` points to a node that is in the same [`RBTree`] as `tree`. -pub struct Cursor<'a, K, V> { +pub struct CursorMut<'a, K, V> { tree: &'a mut RBTree<K, V>, current: NonNull<bindings::rb_node>, } -// SAFETY: The [`Cursor`] has exclusive access to both `K` and `V`, so it is sufficient to require them to be `Send`. -// The cursor only gives out immutable references to the keys, but since it has excusive access to those same -// keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the user. -unsafe impl<'a, K: Send, V: Send> Send for Cursor<'a, K, V> {} +/// A bidirectional immutable cursor over the tree nodes, sorted by key. This is a simpler +/// variant of [`CursorMut`] that is basically providing read only access. +/// +/// # Examples +/// +/// In the following example, we obtain a cursor to the first element in the tree. +/// The cursor allows us to iterate bidirectionally over key/value pairs in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Get a cursor to the first element. +/// let cursor = tree.cursor_front().unwrap(); +/// let current = cursor.current(); +/// assert_eq!(current, (&10, &100)); +/// +/// # Ok::<(), Error>(()) +/// ``` +pub struct Cursor<'a, K, V> { + _tree: PhantomData<&'a RBTree<K, V>>, + current: NonNull<bindings::rb_node>, +} -// SAFETY: The [`Cursor`] gives out immutable references to K and mutable references to V, -// so it has the same thread safety requirements as mutable references. +// SAFETY: The immutable cursor gives out shared access to `K` and `V` so if `K` and `V` can be +// shared across threads, then it's safe to share the cursor. +unsafe impl<'a, K: Sync, V: Sync> Send for Cursor<'a, K, V> {} + +// SAFETY: The immutable cursor gives out shared access to `K` and `V` so if `K` and `V` can be +// shared across threads, then it's safe to share the cursor. unsafe impl<'a, K: Sync, V: Sync> Sync for Cursor<'a, K, V> {} impl<'a, K, V> Cursor<'a, K, V> { @@ -749,6 +830,75 @@ impl<'a, K, V> Cursor<'a, K, V> { unsafe { Self::to_key_value(self.current) } } + /// # Safety + /// + /// - `node` must be a valid pointer to a node in an [`RBTree`]. + /// - The caller has immutable access to `node` for the duration of `'b`. + unsafe fn to_key_value<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, &'b V) { + // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` + // point to the links field of `Node<K, V>` objects. + let this = unsafe { container_of!(node.as_ptr(), Node<K, V>, links) }; + // SAFETY: The passed `node` is the current node or a non-null neighbor, + // thus `this` is valid by the type invariants. + let k = unsafe { &(*this).key }; + // SAFETY: The passed `node` is the current node or a non-null neighbor, + // thus `this` is valid by the type invariants. + let v = unsafe { &(*this).value }; + (k, v) + } + + /// Access the previous node without moving the cursor. + pub fn peek_prev(&self) -> Option<(&K, &V)> { + self.peek(Direction::Prev) + } + + /// Access the next node without moving the cursor. + pub fn peek_next(&self) -> Option<(&K, &V)> { + self.peek(Direction::Next) + } + + fn peek(&self, direction: Direction) -> Option<(&K, &V)> { + self.get_neighbor_raw(direction).map(|neighbor| { + // SAFETY: + // - `neighbor` is a valid tree node. + // - By the function signature, we have an immutable reference to `self`. + unsafe { Self::to_key_value(neighbor) } + }) + } + + fn get_neighbor_raw(&self, direction: Direction) -> Option<NonNull<bindings::rb_node>> { + // SAFETY: `self.current` is valid by the type invariants. + let neighbor = unsafe { + match direction { + Direction::Prev => bindings::rb_prev(self.current.as_ptr()), + Direction::Next => bindings::rb_next(self.current.as_ptr()), + } + }; + + NonNull::new(neighbor) + } +} + +// SAFETY: The [`CursorMut`] has exclusive access to both `K` and `V`, so it is sufficient to +// require them to be `Send`. +// The cursor only gives out immutable references to the keys, but since it has exclusive access to +// those same keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the +// user. +unsafe impl<'a, K: Send, V: Send> Send for CursorMut<'a, K, V> {} + +// SAFETY: The [`CursorMut`] gives out immutable references to `K` and mutable references to `V`, +// so it has the same thread safety requirements as mutable references. +unsafe impl<'a, K: Sync, V: Sync> Sync for CursorMut<'a, K, V> {} + +impl<'a, K, V> CursorMut<'a, K, V> { + /// The current node. + pub fn current(&self) -> (&K, &V) { + // SAFETY: + // - `self.current` is a valid node by the type invariants. + // - We have an immutable reference by the function signature. + unsafe { Self::to_key_value(self.current) } + } + /// The current node, with a mutable value pub fn current_mut(&mut self) -> (&K, &mut V) { // SAFETY: @@ -920,7 +1070,7 @@ impl<'a, K, V> Cursor<'a, K, V> { } } -/// Direction for [`Cursor`] operations. +/// Direction for [`Cursor`] and [`CursorMut`] operations. enum Direction { /// the node immediately before, in sort order Prev, |
