explodethosebits-0.3.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. etb/__init__.py +351 -0
  2. etb/__init__.pyi +976 -0
  3. etb/_etb.cpython-39-x86_64-linux-gnu.so +0 -0
  4. etb/_version.py +34 -0
  5. etb/py.typed +2 -0
  6. explodethosebits-0.3.0.dist-info/METADATA +405 -0
  7. explodethosebits-0.3.0.dist-info/RECORD +88 -0
  8. explodethosebits-0.3.0.dist-info/WHEEL +6 -0
  9. explodethosebits-0.3.0.dist-info/licenses/LICENSE +21 -0
  10. explodethosebits-0.3.0.dist-info/sboms/auditwheel.cdx.json +1 -0
  11. explodethosebits.libs/libcudart-c3a75b33.so.12.8.90 +0 -0
  12. include/etb/bit_coordinate.hpp +45 -0
  13. include/etb/bit_extraction.hpp +79 -0
  14. include/etb/bit_pruning.hpp +122 -0
  15. include/etb/config.hpp +284 -0
  16. include/etb/cuda/arch_optimizations.cuh +358 -0
  17. include/etb/cuda/blackwell_optimizations.cuh +300 -0
  18. include/etb/cuda/cuda_common.cuh +265 -0
  19. include/etb/cuda/etb_cuda.cuh +200 -0
  20. include/etb/cuda/gpu_memory.cuh +406 -0
  21. include/etb/cuda/heuristics_kernel.cuh +315 -0
  22. include/etb/cuda/path_generator_kernel.cuh +272 -0
  23. include/etb/cuda/prefix_pruner_kernel.cuh +370 -0
  24. include/etb/cuda/signature_kernel.cuh +328 -0
  25. include/etb/early_stopping.hpp +246 -0
  26. include/etb/etb.hpp +20 -0
  27. include/etb/heuristics.hpp +165 -0
  28. include/etb/memoization.hpp +285 -0
  29. include/etb/path.hpp +86 -0
  30. include/etb/path_count.hpp +87 -0
  31. include/etb/path_generator.hpp +175 -0
  32. include/etb/prefix_trie.hpp +339 -0
  33. include/etb/reporting.hpp +437 -0
  34. include/etb/scoring.hpp +269 -0
  35. include/etb/signature.hpp +190 -0
  36. include/gmock/gmock-actions.h +2297 -0
  37. include/gmock/gmock-cardinalities.h +159 -0
  38. include/gmock/gmock-function-mocker.h +518 -0
  39. include/gmock/gmock-matchers.h +5623 -0
  40. include/gmock/gmock-more-actions.h +658 -0
  41. include/gmock/gmock-more-matchers.h +120 -0
  42. include/gmock/gmock-nice-strict.h +277 -0
  43. include/gmock/gmock-spec-builders.h +2148 -0
  44. include/gmock/gmock.h +96 -0
  45. include/gmock/internal/custom/README.md +18 -0
  46. include/gmock/internal/custom/gmock-generated-actions.h +7 -0
  47. include/gmock/internal/custom/gmock-matchers.h +37 -0
  48. include/gmock/internal/custom/gmock-port.h +40 -0
  49. include/gmock/internal/gmock-internal-utils.h +487 -0
  50. include/gmock/internal/gmock-port.h +139 -0
  51. include/gmock/internal/gmock-pp.h +279 -0
  52. include/gtest/gtest-assertion-result.h +237 -0
  53. include/gtest/gtest-death-test.h +345 -0
  54. include/gtest/gtest-matchers.h +923 -0
  55. include/gtest/gtest-message.h +252 -0
  56. include/gtest/gtest-param-test.h +546 -0
  57. include/gtest/gtest-printers.h +1161 -0
  58. include/gtest/gtest-spi.h +250 -0
  59. include/gtest/gtest-test-part.h +192 -0
  60. include/gtest/gtest-typed-test.h +331 -0
  61. include/gtest/gtest.h +2321 -0
  62. include/gtest/gtest_pred_impl.h +279 -0
  63. include/gtest/gtest_prod.h +60 -0
  64. include/gtest/internal/custom/README.md +44 -0
  65. include/gtest/internal/custom/gtest-port.h +37 -0
  66. include/gtest/internal/custom/gtest-printers.h +42 -0
  67. include/gtest/internal/custom/gtest.h +37 -0
  68. include/gtest/internal/gtest-death-test-internal.h +307 -0
  69. include/gtest/internal/gtest-filepath.h +227 -0
  70. include/gtest/internal/gtest-internal.h +1560 -0
  71. include/gtest/internal/gtest-param-util.h +1026 -0
  72. include/gtest/internal/gtest-port-arch.h +122 -0
  73. include/gtest/internal/gtest-port.h +2481 -0
  74. include/gtest/internal/gtest-string.h +178 -0
  75. include/gtest/internal/gtest-type-util.h +220 -0
  76. lib/libetb_core.a +0 -0
  77. lib64/cmake/GTest/GTestConfig.cmake +33 -0
  78. lib64/cmake/GTest/GTestConfigVersion.cmake +43 -0
  79. lib64/cmake/GTest/GTestTargets-release.cmake +49 -0
  80. lib64/cmake/GTest/GTestTargets.cmake +139 -0
  81. lib64/libgmock.a +0 -0
  82. lib64/libgmock_main.a +0 -0
  83. lib64/libgtest.a +0 -0
  84. lib64/libgtest_main.a +0 -0
  85. lib64/pkgconfig/gmock.pc +10 -0
  86. lib64/pkgconfig/gmock_main.pc +10 -0
  87. lib64/pkgconfig/gtest.pc +9 -0
  88. lib64/pkgconfig/gtest_main.pc +10 -0
include/etb/cuda/path_generator_kernel.cuh
@@ -0,0 +1,272 @@
+ #ifndef ETB_PATH_GENERATOR_KERNEL_CUH
+ #define ETB_PATH_GENERATOR_KERNEL_CUH
+
+ #include "cuda_common.cuh"
+ #include "gpu_memory.cuh"
+
+ namespace etb {
+ namespace cuda {
+
+ /**
+  * Work item for path generation.
+  * Represents a partial path that needs to be explored.
+  */
+ struct PathWorkItem {
+     uint32_t start_byte;        // Starting byte index for this work item
+     uint32_t current_depth;     // Current depth in the path
+     uint8_t prefix_bytes[16];   // Reconstructed bytes so far (max 16 for early stopping)
+     uint8_t prefix_length;      // Number of bytes in prefix
+     uint8_t bit_selections[16]; // Bit selections made so far
+     float current_score;        // Current heuristic score
+
+     __host__ __device__ PathWorkItem()
+         : start_byte(0), current_depth(0), prefix_length(0), current_score(0.0f) {
+         for (int i = 0; i < 16; ++i) {
+             prefix_bytes[i] = 0;
+             bit_selections[i] = 0;
+         }
+     }
+ };
+
+ /**
+  * Configuration for path generator kernel.
+  */
+ struct PathGeneratorConfig {
+     uint32_t input_length; // Length of input data
+     uint32_t max_depth;    // Maximum path depth
+     uint32_t batch_size;   // Number of paths to generate per kernel launch
+     DeviceBitPruningConfig bit_pruning;
+     DeviceEarlyStoppingConfig early_stopping;
+     DeviceHeuristicWeights heuristic_weights;
+     DeviceScoringWeights scoring_weights;
+
+     __host__ __device__ PathGeneratorConfig()
+         : input_length(0), max_depth(16), batch_size(65536) {}
+ };
+
+ /**
+  * Shared memory structure for path generation.
+  * Used for cooperative path exploration within a thread block.
+  */
+ struct PathGeneratorSharedMem {
+     // Prefix state shared across warp
+     uint8_t shared_prefix[32];
+     uint32_t shared_prefix_length;
+
+     // Work stealing queue (per-block)
+     uint32_t local_work_head;
+     uint32_t local_work_tail;
+     PathWorkItem local_work_items[32]; // Small local queue
+
+     // Reduction scratch space
+     float warp_scores[32];
+     uint32_t warp_votes[32];
+ };
+
+ /**
+  * OPTIMIZED Path generator CUDA kernel.
+  *
+  * Generates paths using work-stealing across thread blocks with
+  * warp-level cooperative path exploration.
+  *
+  * Optimizations applied:
+  * - Lock-free work stealing with proper CAS semantics (no increment-then-check)
+  * - Warp-cooperative early stopping checks using ballot_sync
+  * - Coalesced bit extraction (all lanes read same byte, extract different bits)
+  * - Block-level work coordination to reduce global queue contention
+  *
+  * Requirements: 9.3
+  *
+  * @param input_data Input byte array
+  * @param config Kernel configuration
+  * @param work_queue Global work queue
+  * @param work_queue_head Head pointer for work queue
+  * @param work_queue_tail Tail pointer for work queue
+  * @param prefix_trie Prefix trie for pruning
+  * @param candidates Output candidate queue
+  * @param candidate_count Number of candidates found
+  * @param min_score Minimum score threshold
+  */
+ __global__ void path_generator_kernel(
+     const uint8_t* input_data,
+     PathGeneratorConfig config,
+     PathWorkItem* work_queue,
+     uint32_t* work_queue_head,
+     uint32_t* work_queue_tail,
+     DevicePrefixTrieNode* prefix_trie,
+     DeviceCandidate* candidates,
+     uint32_t* candidate_count,
+     float* min_score
+ );
+
+ /**
+  * Initialize work queue with starting positions.
+  *
+  * @param work_queue Work queue to initialize
+  * @param work_queue_tail Tail pointer
+  * @param input_length Length of input data
+  * @param bit_mask Bit mask for allowed positions
+  */
+ __global__ void init_work_queue_kernel(
+     PathWorkItem* work_queue,
+     uint32_t* work_queue_tail,
+     uint32_t input_length,
+     uint8_t bit_mask
+ );
+
+ /**
+  * Host-side launcher for path generator kernel.
+  */
+ class PathGeneratorKernel {
+ public:
+     PathGeneratorKernel();
+     ~PathGeneratorKernel();
+
+     /**
+      * Configure the kernel for a specific device.
+      * @param device_id CUDA device ID
+      */
+     void configure(int device_id);
+
+     /**
+      * Launch the path generator kernel.
+      * @param mem GPU memory manager
+      * @param config Kernel configuration
+      * @param stream CUDA stream (nullptr for default)
+      */
+     void launch(GPUMemoryManager& mem, const PathGeneratorConfig& config,
+                 cudaStream_t stream = nullptr);
+
+     /**
+      * Initialize work queue with starting positions.
+      * @param mem GPU memory manager
+      * @param input_length Length of input data
+      * @param bit_mask Bit mask for allowed positions
+      * @param stream CUDA stream
+      */
+     void init_work_queue(GPUMemoryManager& mem, uint32_t input_length,
+                          uint8_t bit_mask, cudaStream_t stream = nullptr);
+
+     /**
+      * Get the kernel configuration.
+      */
+     const KernelConfig& get_config() const { return kernel_config_; }
+
+ private:
+     KernelConfig kernel_config_;
+     bool configured_;
+ };
+
+ // Device functions for path generation
+
+ /**
+  * Extract a bit from input data at the given coordinate.
+  */
+ __device__ inline uint8_t extract_bit(const uint8_t* data, uint32_t byte_idx, uint8_t bit_pos) {
+     return (data[byte_idx] >> bit_pos) & 1;
+ }
+
+ /**
+  * Reconstruct a byte from 8 bit selections.
+  */
+ __device__ inline uint8_t reconstruct_byte(const uint8_t* bits) {
+     uint8_t result = 0;
+     for (int i = 0; i < 8; ++i) {
+         result |= (bits[i] & 1) << i;
+     }
+     return result;
+ }
+
+ /**
+  * Check if a bit position is allowed by the mask.
+  */
+ __device__ inline bool is_bit_allowed(uint8_t bit_pos, uint8_t mask) {
+     return (mask >> bit_pos) & 1;
+ }
+
+ /**
+  * Count allowed bits in mask.
+  */
+ __device__ inline int count_allowed_bits(uint8_t mask) {
+     return __popc(static_cast<unsigned int>(mask));
+ }
+
+ /**
+  * Get the nth allowed bit position.
+  */
+ __device__ inline uint8_t get_nth_allowed_bit(uint8_t mask, int n) {
+     int count = 0;
+     for (uint8_t i = 0; i < 8; ++i) {
+         if ((mask >> i) & 1) {
+             if (count == n) return i;
+             ++count;
+         }
+     }
+     return 0;
+ }
+
+ /**
+  * Warp-level vote for early termination.
+  * Returns true if majority of warp votes to terminate.
+  */
+ __device__ inline bool warp_vote_terminate(bool should_terminate) {
+     unsigned int vote = __ballot_sync(0xFFFFFFFF, should_terminate);
+     return __popc(vote) > 16; // More than half the warp
+ }
+
+ /**
+  * Warp-level reduction for finding best score.
+  */
+ __device__ inline float warp_reduce_max(float val) {
+     for (int offset = 16; offset > 0; offset /= 2) {
+         float other = __shfl_down_sync(0xFFFFFFFF, val, offset);
+         val = fmaxf(val, other);
+     }
+     return __shfl_sync(0xFFFFFFFF, val, 0);
+ }
+
+ /**
+  * Atomic work stealing from global queue.
+  * Returns true if work was successfully stolen.
+  */
+ __device__ inline bool steal_work(PathWorkItem* work_queue,
+                                   uint32_t* head, uint32_t* tail,
+                                   uint32_t queue_capacity,
+                                   PathWorkItem& item) {
+     // Claim a slot with CAS rather than increment-then-undo: a failed
+     // atomicAdd/atomicSub pair transiently pushes head past tail, which
+     // can misdirect concurrent stealers. This matches the "proper CAS
+     // semantics (no increment-then-check)" noted in the kernel docs.
+     uint32_t old_head = *head;
+     while (old_head < *tail) {
+         uint32_t prev = atomicCAS(head, old_head, old_head + 1);
+         if (prev == old_head) {
+             item = work_queue[old_head % queue_capacity];
+             return true;
+         }
+         old_head = prev; // Lost the race; retry from the winner's value
+     }
+     return false; // No work available
+ }
+
+ /**
+  * Push work item to global queue.
+  * Returns true if successfully pushed.
+  */
+ __device__ inline bool push_work(PathWorkItem* work_queue,
+                                  uint32_t* tail,
+                                  uint32_t queue_capacity,
+                                  const PathWorkItem& item) {
+     // Same CAS discipline as steal_work: advance tail only once the
+     // target slot is known to exist, so a full queue never tears tail.
+     uint32_t old_tail = *tail;
+     while (old_tail < queue_capacity) {
+         uint32_t prev = atomicCAS(tail, old_tail, old_tail + 1);
+         if (prev == old_tail) {
+             work_queue[old_tail] = item;
+             return true;
+         }
+         old_tail = prev; // Lost the race; retry from the winner's value
+     }
+     return false; // Queue full
+ }
+
+ } // namespace cuda
+ } // namespace etb
+
+ #endif // ETB_PATH_GENERATOR_KERNEL_CUH
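
A note on the bit-explosion helpers above: extract_bit pulls out bit bit_pos of a byte, and reconstruct_byte folds eight selections back together with the same shift order, so an explode-then-reconstruct cycle is lossless. Below is a minimal host-side sketch of that round trip; the mirrored functions and the driver are illustrative copies for this page, not part of the package's API.

// Host-side mirror of extract_bit / reconstruct_byte from
// path_generator_kernel.cuh (illustrative copies, not the package API).
#include <cassert>
#include <cstdint>
#include <cstdio>

static uint8_t extract_bit(const uint8_t* data, uint32_t byte_idx, uint8_t bit_pos) {
    return (data[byte_idx] >> bit_pos) & 1;
}

static uint8_t reconstruct_byte(const uint8_t* bits) {
    uint8_t result = 0;
    for (int i = 0; i < 8; ++i) {
        result |= (bits[i] & 1) << i;
    }
    return result;
}

int main() {
    const uint8_t input[] = {0xA5}; // 1010 0101
    uint8_t bits[8];
    for (uint8_t pos = 0; pos < 8; ++pos) {
        bits[pos] = extract_bit(input, 0, pos); // explode into bit selections
    }
    assert(reconstruct_byte(bits) == 0xA5); // same shift order, so it round-trips
    printf("round trip ok: 0x%02X\n", reconstruct_byte(bits));
    return 0;
}

The consistent bit order (selection i maps to bit i) is also what get_nth_allowed_bit relies on when a mask restricts which positions may be chosen.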
include/etb/cuda/prefix_pruner_kernel.cuh
@@ -0,0 +1,370 @@
+ #ifndef ETB_PREFIX_PRUNER_KERNEL_CUH
+ #define ETB_PREFIX_PRUNER_KERNEL_CUH
+
+ #include "cuda_common.cuh"
+ #include "gpu_memory.cuh"
+
+ namespace etb {
+ namespace cuda {
+
+ /**
+  * Shared memory structure for prefix pruning operations.
+  */
+ struct PrefixPrunerSharedMem {
+     // Warp voting results
+     uint32_t warp_votes[8]; // 8 warps per block max
+
+     // Prefix being evaluated
+     uint8_t current_prefix[16];
+     uint32_t prefix_length;
+
+     // Trie navigation state
+     uint32_t current_node_idx;
+     DevicePrefixStatus current_status;
+     float current_score;
+ };
+
+ /**
+  * Prefix pruner CUDA kernel.
+  *
+  * OPTIMIZED: Now takes pre-computed prefix offsets to avoid O(n²) offset calculation.
+  * Host should compute prefix_offsets using exclusive_scan(prefix_lengths).
+  *
+  * Implements warp-level voting for termination decisions and
+  * atomic trie updates for prefix status.
+  *
+  * Requirements: 9.8
+  *
+  * @param prefix_trie Prefix trie nodes
+  * @param trie_size Number of nodes in trie
+  * @param prefixes Array of prefixes to evaluate
+  * @param prefix_lengths Array of prefix lengths
+  * @param prefix_offsets Pre-computed exclusive prefix sum of lengths (NEW)
+  * @param scores Array of heuristic scores for each prefix
+  * @param num_prefixes Number of prefixes to evaluate
+  * @param prune_threshold Score threshold for pruning
+  * @param prune_results Output array indicating if each prefix was pruned
+  */
+ __global__ void prefix_pruner_kernel(
+     DevicePrefixTrieNode* prefix_trie,
+     uint32_t trie_size,
+     const uint8_t* prefixes,
+     const uint32_t* prefix_lengths,
+     const uint32_t* prefix_offsets,
+     const float* scores,
+     uint32_t num_prefixes,
+     float prune_threshold,
+     bool* prune_results
+ );
+
+ /**
+  * Trie lookup kernel.
+  * OPTIMIZED: Uses pre-computed prefix offsets.
+  *
+  * @param prefix_trie Prefix trie nodes
+  * @param trie_size Number of nodes in trie
+  * @param prefixes Array of prefixes to look up
+  * @param prefix_lengths Array of prefix lengths
+  * @param prefix_offsets Pre-computed exclusive prefix sum of lengths (NEW)
+  * @param num_prefixes Number of prefixes
+  * @param statuses Output array of prefix statuses
+  * @param scores Output array of best scores
+  */
+ __global__ void trie_lookup_kernel(
+     const DevicePrefixTrieNode* prefix_trie,
+     uint32_t trie_size,
+     const uint8_t* prefixes,
+     const uint32_t* prefix_lengths,
+     const uint32_t* prefix_offsets,
+     uint32_t num_prefixes,
+     DevicePrefixStatus* statuses,
+     float* scores
+ );
+
+ /**
+  * Trie insert/update kernel.
+  * OPTIMIZED: Uses pre-computed prefix offsets and warp-cooperative allocation.
+  *
+  * @param prefix_trie Prefix trie nodes
+  * @param trie_size Current number of nodes
+  * @param max_trie_size Maximum trie capacity
+  * @param prefixes Array of prefixes to insert
+  * @param prefix_lengths Array of prefix lengths
+  * @param prefix_offsets Pre-computed exclusive prefix sum of lengths (NEW)
+  * @param statuses Array of statuses to set
+  * @param scores Array of scores to set
+  * @param num_prefixes Number of prefixes
+  * @param new_trie_size Output: new trie size after insertions
+  */
+ __global__ void trie_insert_kernel(
+     DevicePrefixTrieNode* prefix_trie,
+     uint32_t trie_size,
+     uint32_t max_trie_size,
+     const uint8_t* prefixes,
+     const uint32_t* prefix_lengths,
+     const uint32_t* prefix_offsets,
+     const DevicePrefixStatus* statuses,
+     const float* scores,
+     uint32_t num_prefixes,
+     uint32_t* new_trie_size
+ );
+
+ /**
+  * Batch prune check kernel.
+  * OPTIMIZED: Uses pre-computed prefix offsets.
+  *
+  * @param prefix_trie Prefix trie nodes
+  * @param trie_size Number of nodes
+  * @param prefixes Array of prefixes to check
+  * @param prefix_lengths Array of prefix lengths
+  * @param prefix_offsets Pre-computed exclusive prefix sum of lengths (NEW)
+  * @param num_prefixes Number of prefixes
+  * @param should_skip Output array indicating if prefix should be skipped
+  */
+ __global__ void batch_prune_check_kernel(
+     const DevicePrefixTrieNode* prefix_trie,
+     uint32_t trie_size,
+     const uint8_t* prefixes,
+     const uint32_t* prefix_lengths,
+     const uint32_t* prefix_offsets,
+     uint32_t num_prefixes,
+     bool* should_skip
+ );
+
+ /**
+  * Host-side launcher for prefix pruner kernels.
+  */
+ class PrefixPrunerKernel {
+ public:
+     PrefixPrunerKernel();
+     ~PrefixPrunerKernel();
+
+     /**
+      * Configure the kernel for a specific device.
+      * @param device_id CUDA device ID
+      */
+     void configure(int device_id);
+
+     /**
+      * Evaluate and prune prefixes.
+      * @param trie Device prefix trie
+      * @param trie_size Number of trie nodes
+      * @param prefixes Device array of prefixes (flattened)
+      * @param prefix_lengths Device array of prefix lengths
+      * @param scores Device array of heuristic scores
+      * @param num_prefixes Number of prefixes
+      * @param prune_threshold Score threshold for pruning
+      * @param prune_results Device array for results
+      * @param stream CUDA stream
+      */
+     void evaluate_and_prune(DevicePrefixTrieNode* trie, uint32_t trie_size,
+                             const uint8_t* prefixes, const uint32_t* prefix_lengths,
+                             const float* scores, uint32_t num_prefixes,
+                             float prune_threshold, bool* prune_results,
+                             cudaStream_t stream = nullptr);
+
+     /**
+      * Look up prefix statuses.
+      * @param trie Device prefix trie
+      * @param trie_size Number of trie nodes
+      * @param prefixes Device array of prefixes
+      * @param prefix_lengths Device array of prefix lengths
+      * @param num_prefixes Number of prefixes
+      * @param statuses Device array for output statuses
+      * @param scores Device array for output scores
+      * @param stream CUDA stream
+      */
+     void lookup(const DevicePrefixTrieNode* trie, uint32_t trie_size,
+                 const uint8_t* prefixes, const uint32_t* prefix_lengths,
+                 uint32_t num_prefixes, DevicePrefixStatus* statuses,
+                 float* scores, cudaStream_t stream = nullptr);
+
+     /**
+      * Check if prefixes should be skipped due to pruned ancestors.
+      * @param trie Device prefix trie
+      * @param trie_size Number of trie nodes
+      * @param prefixes Device array of prefixes
+      * @param prefix_lengths Device array of prefix lengths
+      * @param num_prefixes Number of prefixes
+      * @param should_skip Device array for output flags
+      * @param stream CUDA stream
+      */
+     void check_pruned(const DevicePrefixTrieNode* trie, uint32_t trie_size,
+                       const uint8_t* prefixes, const uint32_t* prefix_lengths,
+                       uint32_t num_prefixes, bool* should_skip,
+                       cudaStream_t stream = nullptr);
+
+     /**
+      * Get the kernel configuration.
+      */
+     const KernelConfig& get_config() const { return kernel_config_; }
+
+ private:
+     KernelConfig kernel_config_;
+     bool configured_;
+ };
+
+ // ============================================================================
+ // Device Functions
+ // ============================================================================
+
+ /**
+  * Navigate trie to find node for a prefix.
+  * Returns node index or UINT32_MAX if not found.
+  */
+ __device__ inline uint32_t find_prefix_node(
+     const DevicePrefixTrieNode* trie,
+     uint32_t trie_size,
+     const uint8_t* prefix,
+     uint32_t prefix_length
+ ) {
+     if (trie_size == 0) {
+         return UINT32_MAX; // Empty trie: there is no root node to return
+     }
+     if (prefix_length == 0) {
+         return 0; // Root node
+     }
+
+     uint32_t current = 0; // Start at root
+
+     for (uint32_t i = 0; i < prefix_length; ++i) {
+         uint8_t byte_val = prefix[i];
+         uint32_t children_offset = trie[current].children_offset;
+
+         if (children_offset == 0) {
+             return UINT32_MAX; // No children, prefix not found
+         }
+
+         // Look for child with matching byte value
+         bool found = false;
+         for (uint32_t c = 0; c < 256 && children_offset + c < trie_size; ++c) {
+             uint32_t child_idx = children_offset + c;
+             if (trie[child_idx].reconstructed_byte == byte_val) {
+                 current = child_idx;
+                 found = true;
+                 break;
+             }
+         }
+
+         if (!found) {
+             return UINT32_MAX;
+         }
+     }
+
+     return current;
+ }
+
+ /**
+  * Check if any ancestor of a prefix is pruned.
+  */
+ __device__ inline bool is_ancestor_pruned(
+     const DevicePrefixTrieNode* trie,
+     uint32_t trie_size,
+     const uint8_t* prefix,
+     uint32_t prefix_length
+ ) {
+     if (trie_size == 0) return false;
+
+     uint32_t current = 0;
+
+     for (uint32_t i = 0; i < prefix_length; ++i) {
+         // Check current node status
+         if (trie[current].status == DevicePrefixStatus::PRUNED) {
+             return true;
+         }
+
+         uint8_t byte_val = prefix[i];
+         uint32_t children_offset = trie[current].children_offset;
+
+         if (children_offset == 0) {
+             return false; // No more nodes to check
+         }
+
+         // Find child
+         bool found = false;
+         for (uint32_t c = 0; c < 256 && children_offset + c < trie_size; ++c) {
+             uint32_t child_idx = children_offset + c;
+             if (trie[child_idx].reconstructed_byte == byte_val) {
+                 current = child_idx;
+                 found = true;
+                 break;
+             }
+         }
+
+         if (!found) {
+             return false;
+         }
+     }
+
+     // Check final node
+     return trie[current].status == DevicePrefixStatus::PRUNED;
+ }
+
+ /**
+  * Atomically update node status.
+  *
+  * Only update if transitioning to a "more final" state:
+  * UNKNOWN -> VALID or PRUNED is allowed, VALID -> PRUNED is allowed,
+  * and PRUNED is final.
+  */
+ __device__ inline void atomic_update_status(
+     DevicePrefixTrieNode* node,
+     DevicePrefixStatus new_status
+ ) {
+     // atomicCAS is only defined on aligned 32-bit (and wider) words, so
+     // CAS the word containing the status byte and splice the new value in;
+     // a CAS through a raw uint8_t* would be misaligned and clobber
+     // neighboring fields.
+     uintptr_t addr = reinterpret_cast<uintptr_t>(&node->status);
+     unsigned int* word_ptr = reinterpret_cast<unsigned int*>(addr & ~uintptr_t(3));
+     unsigned int shift = static_cast<unsigned int>(addr & 3) * 8u;
+     unsigned int byte_mask = 0xFFu << shift;
+     unsigned int new_byte = static_cast<unsigned int>(new_status) << shift;
+
+     unsigned int old_word = *word_ptr;
+     unsigned int assumed;
+     do {
+         uint8_t old_val = static_cast<uint8_t>((old_word & byte_mask) >> shift);
+         if (old_val == static_cast<uint8_t>(DevicePrefixStatus::PRUNED)) {
+             return; // Already pruned, don't change
+         }
+         assumed = old_word;
+         old_word = atomicCAS(word_ptr, assumed, (assumed & ~byte_mask) | new_byte);
+     } while (old_word != assumed);
+ }
+
+ /**
+  * Atomically update best score (only if new score is higher).
+  */
+ __device__ inline void atomic_update_score(
+     DevicePrefixTrieNode* node,
+     float new_score
+ ) {
+     // Use atomicMax on the score.
+     // Since atomicMax doesn't work directly on floats, we use a CAS loop.
+     float* score_ptr = &node->best_score;
+     float old_score = *score_ptr;
+
+     while (new_score > old_score) {
+         float assumed = old_score;
+         old_score = __int_as_float(atomicCAS(
+             reinterpret_cast<int*>(score_ptr),
+             __float_as_int(assumed),
+             __float_as_int(new_score)
+         ));
+
+         if (old_score == assumed) {
+             break; // Successfully updated
+         }
+     }
+ }
+
+ /**
+  * Warp-level vote for pruning decision.
+  * Returns true if majority of warp votes to prune.
+  */
+ __device__ inline bool warp_vote_prune(bool should_prune) {
+     unsigned int vote = __ballot_sync(0xFFFFFFFF, should_prune);
+     return __popc(vote) > 16; // More than half
+ }
+
+ /**
+  * Increment visit count atomically.
+  */
+ __device__ inline void atomic_increment_visit(DevicePrefixTrieNode* node) {
+     atomicAdd(&node->visit_count, 1);
+ }
+
+ } // namespace cuda
+ } // namespace etb
+
+ #endif // ETB_PREFIX_PRUNER_KERNEL_CUH
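
Every pruner kernel above takes prefix_offsets, and the doc comments state that the host should compute it as exclusive_scan(prefix_lengths). Below is a minimal sketch of that pre-computation with Thrust, which ships with the CUDA toolkit; the buffer names and example lengths are illustrative.

// Pre-compute prefix_offsets = exclusive_scan(prefix_lengths), as the
// pruner kernels expect. Buffer names and example lengths are illustrative.
#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <cstdint>
#include <vector>

int main() {
    // Lengths of four flattened prefixes: 3, 1, 4, and 2 bytes.
    std::vector<uint32_t> h_lengths = {3u, 1u, 4u, 2u};
    thrust::device_vector<uint32_t> prefix_lengths(h_lengths.begin(), h_lengths.end());
    thrust::device_vector<uint32_t> prefix_offsets(prefix_lengths.size());

    // Exclusive scan: offsets become {0, 3, 4, 8}, so prefix i starts at
    // byte prefix_offsets[i] of the flattened prefixes array.
    thrust::exclusive_scan(prefix_lengths.begin(), prefix_lengths.end(),
                           prefix_offsets.begin());
    return 0;
}

Doing the scan once up front is what removes the O(n²) per-thread offset calculation the kernel comments mention.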
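atomic_update_score relies on the standard trick for a float atomicMax: a CAS loop over the value's bit pattern. Below is a self-contained sketch exercising the same pattern end to end; atomic_max_float, max_kernel, and d_best are hypothetical names for this illustration, not package symbols.

// Emulate atomicMax for floats with an atomicCAS loop, as
// atomic_update_score does. All names here are illustrative.
#include <cstdio>
#include <cuda_runtime.h>

__device__ void atomic_max_float(float* addr, float value) {
    float old_val = *addr;
    while (value > old_val) {
        float assumed = old_val;
        old_val = __int_as_float(atomicCAS(
            reinterpret_cast<int*>(addr),
            __float_as_int(assumed),
            __float_as_int(value)));
        if (old_val == assumed) break; // Our value won the race
    }
}

__global__ void max_kernel(float* best) {
    // Each thread proposes its own score; only the maximum survives.
    atomic_max_float(best, static_cast<float>(threadIdx.x));
}

int main() {
    float* d_best;
    cudaMalloc(&d_best, sizeof(float));
    cudaMemset(d_best, 0, sizeof(float)); // all-zero bits == 0.0f
    max_kernel<<<1, 256>>>(d_best);
    float h_best = 0.0f;
    cudaMemcpy(&h_best, d_best, sizeof(float), cudaMemcpyDeviceToHost);
    printf("best = %.1f (expect 255.0)\n", h_best);
    cudaFree(d_best);
    return 0;
}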