// SPDX-License-Identifier: GPL-2.0 /* Copyright (c) 2025 Meta Platforms, Inc. and affiliates. */ #include #include #include "bpf_misc.h" #include "bpf_experimental.h" struct node_data { struct bpf_refcount ref; struct bpf_rb_node r0; struct bpf_rb_node r1; int key0; int key1; }; #define private(name) SEC(".data." #name) __hidden __attribute__((aligned(8))) private(A) struct bpf_spin_lock glock0; private(A) struct bpf_rb_root groot0 __contains(node_data, r0); private(B) struct bpf_spin_lock glock1; private(B) struct bpf_rb_root groot1 __contains(node_data, r1); #define rb_entry(ptr, type, member) container_of(ptr, type, member) #define NR_NODES 16 int zero = 0; static bool less0(struct bpf_rb_node *a, const struct bpf_rb_node *b) { struct node_data *node_a; struct node_data *node_b; node_a = rb_entry(a, struct node_data, r0); node_b = rb_entry(b, struct node_data, r0); return node_a->key0 < node_b->key0; } static bool less1(struct bpf_rb_node *a, const struct bpf_rb_node *b) { struct node_data *node_a; struct node_data *node_b; node_a = rb_entry(a, struct node_data, r1); node_b = rb_entry(b, struct node_data, r1); return node_a->key1 < node_b->key1; } SEC("syscall") __retval(0) long rbtree_search(void *ctx) { struct bpf_rb_node *rb_n, *rb_m, *gc_ns[NR_NODES]; long lookup_key = NR_NODES / 2; struct node_data *n, *m; int i, nr_gc = 0; for (i = zero; i < NR_NODES && can_loop; i++) { n = bpf_obj_new(typeof(*n)); if (!n) return __LINE__; m = bpf_refcount_acquire(n); n->key0 = i; m->key1 = i; bpf_spin_lock(&glock0); bpf_rbtree_add(&groot0, &n->r0, less0); bpf_spin_unlock(&glock0); bpf_spin_lock(&glock1); bpf_rbtree_add(&groot1, &m->r1, less1); bpf_spin_unlock(&glock1); } n = NULL; bpf_spin_lock(&glock0); rb_n = bpf_rbtree_root(&groot0); while (can_loop) { if (!rb_n) { bpf_spin_unlock(&glock0); return __LINE__; } n = rb_entry(rb_n, struct node_data, r0); if (lookup_key == n->key0) break; if (nr_gc < NR_NODES) gc_ns[nr_gc++] = rb_n; if (lookup_key < n->key0) rb_n = bpf_rbtree_left(&groot0, rb_n); else rb_n = bpf_rbtree_right(&groot0, rb_n); } if (!n || lookup_key != n->key0) { bpf_spin_unlock(&glock0); return __LINE__; } for (i = 0; i < nr_gc; i++) { rb_n = gc_ns[i]; gc_ns[i] = bpf_rbtree_remove(&groot0, rb_n); } m = bpf_refcount_acquire(n); bpf_spin_unlock(&glock0); for (i = 0; i < nr_gc; i++) { rb_n = gc_ns[i]; if (rb_n) { n = rb_entry(rb_n, struct node_data, r0); bpf_obj_drop(n); } } if (!m) return __LINE__; bpf_spin_lock(&glock1); rb_m = bpf_rbtree_remove(&groot1, &m->r1); bpf_spin_unlock(&glock1); bpf_obj_drop(m); if (!rb_m) return __LINE__; bpf_obj_drop(rb_entry(rb_m, struct node_data, r1)); return 0; } #define TEST_ROOT(dolock) \ SEC("syscall") \ __failure __msg(MSG) \ long test_root_spinlock_##dolock(void *ctx) \ { \ struct bpf_rb_node *rb_n; \ __u64 jiffies = 0; \ \ if (dolock) \ bpf_spin_lock(&glock0); \ rb_n = bpf_rbtree_root(&groot0); \ if (rb_n) \ jiffies = bpf_jiffies64(); \ if (dolock) \ bpf_spin_unlock(&glock0); \ \ return !!jiffies; \ } #define TEST_LR(op, dolock) \ SEC("syscall") \ __failure __msg(MSG) \ long test_##op##_spinlock_##dolock(void *ctx) \ { \ struct bpf_rb_node *rb_n; \ struct node_data *n; \ __u64 jiffies = 0; \ \ bpf_spin_lock(&glock0); \ rb_n = bpf_rbtree_root(&groot0); \ if (!rb_n) { \ bpf_spin_unlock(&glock0); \ return 1; \ } \ n = rb_entry(rb_n, struct node_data, r0); \ n = bpf_refcount_acquire(n); \ bpf_spin_unlock(&glock0); \ if (!n) \ return 1; \ \ if (dolock) \ bpf_spin_lock(&glock0); \ rb_n = bpf_rbtree_##op(&groot0, &n->r0); \ if (rb_n) \ jiffies = bpf_jiffies64(); \ if (dolock) \ bpf_spin_unlock(&glock0); \ \ return !!jiffies; \ } /* * Use a spearate MSG macro instead of passing to TEST_XXX(..., MSG) * to ensure the message itself is not in the bpf prog lineinfo * which the verifier includes in its log. * Otherwise, the test_loader will incorrectly match the prog lineinfo * instead of the log generated by the verifier. */ #define MSG "call bpf_rbtree_root{{.+}}; R0{{(_w)?}}=rcu_ptr_or_null_node_data(id={{[0-9]+}},non_own_ref" TEST_ROOT(true) #undef MSG #define MSG "call bpf_rbtree_{{(left|right).+}}; R0{{(_w)?}}=rcu_ptr_or_null_node_data(id={{[0-9]+}},non_own_ref" TEST_LR(left, true) TEST_LR(right, true) #undef MSG #define MSG "bpf_spin_lock at off=0 must be held for bpf_rb_root" TEST_ROOT(false) TEST_LR(left, false) TEST_LR(right, false) #undef MSG char _license[] SEC("license") = "GPL";