explodethosebits 0.3.0__cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- etb/__init__.py +351 -0
- etb/__init__.pyi +976 -0
- etb/_etb.cpython-39-x86_64-linux-gnu.so +0 -0
- etb/_version.py +34 -0
- etb/py.typed +2 -0
- explodethosebits-0.3.0.dist-info/METADATA +405 -0
- explodethosebits-0.3.0.dist-info/RECORD +88 -0
- explodethosebits-0.3.0.dist-info/WHEEL +6 -0
- explodethosebits-0.3.0.dist-info/licenses/LICENSE +21 -0
- explodethosebits-0.3.0.dist-info/sboms/auditwheel.cdx.json +1 -0
- explodethosebits.libs/libcudart-c3a75b33.so.12.8.90 +0 -0
- include/etb/bit_coordinate.hpp +45 -0
- include/etb/bit_extraction.hpp +79 -0
- include/etb/bit_pruning.hpp +122 -0
- include/etb/config.hpp +284 -0
- include/etb/cuda/arch_optimizations.cuh +358 -0
- include/etb/cuda/blackwell_optimizations.cuh +300 -0
- include/etb/cuda/cuda_common.cuh +265 -0
- include/etb/cuda/etb_cuda.cuh +200 -0
- include/etb/cuda/gpu_memory.cuh +406 -0
- include/etb/cuda/heuristics_kernel.cuh +315 -0
- include/etb/cuda/path_generator_kernel.cuh +272 -0
- include/etb/cuda/prefix_pruner_kernel.cuh +370 -0
- include/etb/cuda/signature_kernel.cuh +328 -0
- include/etb/early_stopping.hpp +246 -0
- include/etb/etb.hpp +20 -0
- include/etb/heuristics.hpp +165 -0
- include/etb/memoization.hpp +285 -0
- include/etb/path.hpp +86 -0
- include/etb/path_count.hpp +87 -0
- include/etb/path_generator.hpp +175 -0
- include/etb/prefix_trie.hpp +339 -0
- include/etb/reporting.hpp +437 -0
- include/etb/scoring.hpp +269 -0
- include/etb/signature.hpp +190 -0
- include/gmock/gmock-actions.h +2297 -0
- include/gmock/gmock-cardinalities.h +159 -0
- include/gmock/gmock-function-mocker.h +518 -0
- include/gmock/gmock-matchers.h +5623 -0
- include/gmock/gmock-more-actions.h +658 -0
- include/gmock/gmock-more-matchers.h +120 -0
- include/gmock/gmock-nice-strict.h +277 -0
- include/gmock/gmock-spec-builders.h +2148 -0
- include/gmock/gmock.h +96 -0
- include/gmock/internal/custom/README.md +18 -0
- include/gmock/internal/custom/gmock-generated-actions.h +7 -0
- include/gmock/internal/custom/gmock-matchers.h +37 -0
- include/gmock/internal/custom/gmock-port.h +40 -0
- include/gmock/internal/gmock-internal-utils.h +487 -0
- include/gmock/internal/gmock-port.h +139 -0
- include/gmock/internal/gmock-pp.h +279 -0
- include/gtest/gtest-assertion-result.h +237 -0
- include/gtest/gtest-death-test.h +345 -0
- include/gtest/gtest-matchers.h +923 -0
- include/gtest/gtest-message.h +252 -0
- include/gtest/gtest-param-test.h +546 -0
- include/gtest/gtest-printers.h +1161 -0
- include/gtest/gtest-spi.h +250 -0
- include/gtest/gtest-test-part.h +192 -0
- include/gtest/gtest-typed-test.h +331 -0
- include/gtest/gtest.h +2321 -0
- include/gtest/gtest_pred_impl.h +279 -0
- include/gtest/gtest_prod.h +60 -0
- include/gtest/internal/custom/README.md +44 -0
- include/gtest/internal/custom/gtest-port.h +37 -0
- include/gtest/internal/custom/gtest-printers.h +42 -0
- include/gtest/internal/custom/gtest.h +37 -0
- include/gtest/internal/gtest-death-test-internal.h +307 -0
- include/gtest/internal/gtest-filepath.h +227 -0
- include/gtest/internal/gtest-internal.h +1560 -0
- include/gtest/internal/gtest-param-util.h +1026 -0
- include/gtest/internal/gtest-port-arch.h +122 -0
- include/gtest/internal/gtest-port.h +2481 -0
- include/gtest/internal/gtest-string.h +178 -0
- include/gtest/internal/gtest-type-util.h +220 -0
- lib/libetb_core.a +0 -0
- lib64/cmake/GTest/GTestConfig.cmake +33 -0
- lib64/cmake/GTest/GTestConfigVersion.cmake +43 -0
- lib64/cmake/GTest/GTestTargets-release.cmake +49 -0
- lib64/cmake/GTest/GTestTargets.cmake +139 -0
- lib64/libgmock.a +0 -0
- lib64/libgmock_main.a +0 -0
- lib64/libgtest.a +0 -0
- lib64/libgtest_main.a +0 -0
- lib64/pkgconfig/gmock.pc +10 -0
- lib64/pkgconfig/gmock_main.pc +10 -0
- lib64/pkgconfig/gtest.pc +9 -0
- lib64/pkgconfig/gtest_main.pc +10 -0
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
#ifndef ETB_PATH_GENERATOR_KERNEL_CUH
|
|
2
|
+
#define ETB_PATH_GENERATOR_KERNEL_CUH
|
|
3
|
+
|
|
4
|
+
#include "cuda_common.cuh"
|
|
5
|
+
#include "gpu_memory.cuh"
|
|
6
|
+
|
|
7
|
+
namespace etb {
|
|
8
|
+
namespace cuda {
|
|
9
|
+
|
|
10
|
+
/**
 * Work item for path generation.
 * Represents a partial path that needs to be explored; sized so it can be
 * copied by value through the global work queue.
 */
struct PathWorkItem {
    uint32_t start_byte;        // Starting byte index for this work item
    uint32_t current_depth;     // Current depth in the path
    uint8_t prefix_bytes[16];   // Reconstructed bytes so far (max 16 for early stopping)
    uint8_t prefix_length;      // Number of valid bytes in prefix_bytes
    uint8_t bit_selections[16]; // Bit selections made so far, one per depth step
    float current_score;        // Current heuristic score

    // Zero-initializes every field, including both arrays, so a
    // default-constructed item is safe to enqueue before being filled in.
    __host__ __device__ PathWorkItem()
        : start_byte(0), current_depth(0), prefix_length(0), current_score(0.0f) {
        for (int i = 0; i < 16; ++i) {
            prefix_bytes[i] = 0;
            bit_selections[i] = 0;
        }
    }
};
|
|
30
|
+
|
|
31
|
+
/**
 * Configuration for the path generator kernel.
 * Passed to the kernel by value; all members are POD device-safe types.
 */
struct PathGeneratorConfig {
    uint32_t input_length; // Length of input data in bytes
    uint32_t max_depth;    // Maximum path depth (default 16, matching PathWorkItem capacity)
    uint32_t batch_size;   // Number of paths to generate per kernel launch
    DeviceBitPruningConfig bit_pruning;        // Bit-level pruning parameters
    DeviceEarlyStoppingConfig early_stopping;  // Early-stopping parameters
    DeviceHeuristicWeights heuristic_weights;  // Weights for heuristic scoring
    DeviceScoringWeights scoring_weights;      // Weights for final scoring

    // Defaults: empty input, depth capped at 16, 65536 paths per launch.
    __host__ __device__ PathGeneratorConfig()
        : input_length(0), max_depth(16), batch_size(65536) {}
};
|
|
46
|
+
|
|
47
|
+
/**
 * Shared memory structure for path generation.
 * Used for cooperative path exploration within a thread block.
 * NOTE(review): shared memory is uninitialized at block start — writers must
 * populate these fields before any thread reads them, with a barrier between.
 */
struct PathGeneratorSharedMem {
    // Prefix state shared across the warp
    uint8_t shared_prefix[32];
    uint32_t shared_prefix_length;

    // Work stealing queue (per-block), bounded by local_work_items capacity
    uint32_t local_work_head;
    uint32_t local_work_tail;
    PathWorkItem local_work_items[32]; // Small local queue

    // Reduction scratch space (one slot per warp lane)
    float warp_scores[32];
    uint32_t warp_votes[32];
};
|
|
65
|
+
|
|
66
|
+
/**
 * OPTIMIZED Path generator CUDA kernel.
 *
 * Generates paths using work-stealing across thread blocks with
 * warp-level cooperative path exploration.
 *
 * Optimizations applied:
 * - Lock-free work stealing with proper CAS semantics (no increment-then-check)
 * - Warp-cooperative early stopping checks using ballot_sync
 * - Coalesced bit extraction (all lanes read same byte, extract different bits)
 * - Block-level work coordination to reduce global queue contention
 *
 * Requirements: 9.3
 *
 * @param input_data Input byte array (device pointer)
 * @param config Kernel configuration (passed by value)
 * @param work_queue Global work queue (device pointer)
 * @param work_queue_head Head pointer for work queue (consumed index)
 * @param work_queue_tail Tail pointer for work queue (produced index)
 * @param prefix_trie Prefix trie for pruning
 * @param candidates Output candidate queue
 * @param candidate_count Number of candidates found (device counter)
 * @param min_score Minimum score threshold (device pointer)
 */
__global__ void path_generator_kernel(
    const uint8_t* input_data,
    PathGeneratorConfig config,
    PathWorkItem* work_queue,
    uint32_t* work_queue_head,
    uint32_t* work_queue_tail,
    DevicePrefixTrieNode* prefix_trie,
    DeviceCandidate* candidates,
    uint32_t* candidate_count,
    float* min_score
);
|
|
101
|
+
|
|
102
|
+
/**
 * Initialize the work queue with starting positions.
 * Intended to run before path_generator_kernel on the same stream.
 *
 * @param work_queue Work queue to initialize (device pointer)
 * @param work_queue_tail Tail pointer; advanced as items are produced
 * @param input_length Length of input data in bytes
 * @param bit_mask Bit mask for allowed positions
 */
__global__ void init_work_queue_kernel(
    PathWorkItem* work_queue,
    uint32_t* work_queue_tail,
    uint32_t input_length,
    uint8_t bit_mask
);
|
|
116
|
+
|
|
117
|
+
/**
 * Host-side launcher for the path generator kernel.
 * Call configure() before launch() / init_work_queue().
 */
class PathGeneratorKernel {
public:
    PathGeneratorKernel();
    ~PathGeneratorKernel();

    /**
     * Configure the kernel for a specific device.
     * @param device_id CUDA device ID
     */
    void configure(int device_id);

    /**
     * Launch the path generator kernel.
     * @param mem GPU memory manager supplying device buffers
     * @param config Kernel configuration
     * @param stream CUDA stream (nullptr for default)
     */
    void launch(GPUMemoryManager& mem, const PathGeneratorConfig& config,
                cudaStream_t stream = nullptr);

    /**
     * Initialize work queue with starting positions.
     * @param mem GPU memory manager
     * @param input_length Length of input data
     * @param bit_mask Bit mask for allowed positions
     * @param stream CUDA stream
     */
    void init_work_queue(GPUMemoryManager& mem, uint32_t input_length,
                         uint8_t bit_mask, cudaStream_t stream = nullptr);

    /**
     * Get the kernel configuration.
     * NOTE(review): presumably populated by configure(); verify before
     * relying on it prior to that call.
     */
    const KernelConfig& get_config() const { return kernel_config_; }

private:
    KernelConfig kernel_config_; // Launch parameters for this device
    bool configured_;            // Whether configure() has been called
};
|
|
159
|
+
|
|
160
|
+
// Device functions for path generation
|
|
161
|
+
|
|
162
|
+
/**
 * Read a single bit of the input: bit `bit_pos` (0 = least significant)
 * of data[byte_idx]. Returns 0 or 1.
 */
__device__ inline uint8_t extract_bit(const uint8_t* data, uint32_t byte_idx, uint8_t bit_pos) {
    const uint8_t byte_val = data[byte_idx];
    return static_cast<uint8_t>((byte_val >> bit_pos) & 0x1u);
}
|
|
168
|
+
|
|
169
|
+
/**
 * Assemble a byte from 8 bit selections: bits[0] becomes the LSB,
 * bits[7] the MSB. Only the low bit of each entry is used.
 */
__device__ inline uint8_t reconstruct_byte(const uint8_t* bits) {
    uint8_t assembled = 0;
    #pragma unroll
    for (int pos = 0; pos < 8; ++pos) {
        assembled |= static_cast<uint8_t>((bits[pos] & 1) << pos);
    }
    return assembled;
}
|
|
179
|
+
|
|
180
|
+
/**
 * Test whether bit position `bit_pos` is enabled in `mask`
 * (bit i of the mask gates position i).
 */
__device__ inline bool is_bit_allowed(uint8_t bit_pos, uint8_t mask) {
    return ((mask >> bit_pos) & 1) != 0;
}
|
|
186
|
+
|
|
187
|
+
/**
 * Population count of the mask: the number of allowed bit positions.
 */
__device__ inline int count_allowed_bits(uint8_t mask) {
    // Zero-extend the byte to the operand width __popc expects.
    const unsigned int widened = mask;
    return __popc(widened);
}
|
|
193
|
+
|
|
194
|
+
/**
 * Return the position of the n-th (0-based) set bit in `mask`,
 * scanning from the LSB. Falls back to 0 when fewer than n+1 bits
 * are set — callers should ensure n < count_allowed_bits(mask).
 */
__device__ inline uint8_t get_nth_allowed_bit(uint8_t mask, int n) {
    int seen = 0;
    for (uint8_t pos = 0; pos < 8; ++pos) {
        const bool allowed = ((mask >> pos) & 1) != 0;
        if (!allowed) continue;
        if (seen == n) {
            return pos;
        }
        ++seen;
    }
    return 0; // Not enough set bits: degenerate fallback
}
|
|
207
|
+
|
|
208
|
+
/**
 * Warp-level vote for early termination.
 * All 32 lanes of the warp must call this (full 0xFFFFFFFF mask).
 * Returns true if a strict majority (17 or more lanes) votes to terminate.
 */
__device__ inline bool warp_vote_terminate(bool should_terminate) {
    unsigned int vote = __ballot_sync(0xFFFFFFFF, should_terminate);
    return __popc(vote) > 16; // More than half the warp
}
|
|
216
|
+
|
|
217
|
+
/**
 * Warp-wide max reduction. All 32 lanes must participate; on return
 * every lane holds the maximum value any lane contributed.
 */
__device__ inline float warp_reduce_max(float val) {
    // Butterfly-style shuffle-down reduction into lane 0.
    #pragma unroll
    for (int delta = 16; delta > 0; delta >>= 1) {
        const float neighbor = __shfl_down_sync(0xFFFFFFFF, val, delta);
        val = fmaxf(val, neighbor);
    }
    // Lane 0 now holds the warp maximum; broadcast it to all lanes.
    return __shfl_sync(0xFFFFFFFF, val, 0);
}
|
|
227
|
+
|
|
228
|
+
/**
 * Atomic work stealing from the global queue.
 * Returns true if a work item was successfully claimed into `item`.
 *
 * Uses a CAS claim loop instead of increment-then-check: the original
 * atomicAdd/atomicSub rollback could race with concurrent push/steal,
 * letting threads observe a transiently over-advanced head and skip or
 * double-consume items. With CAS, head only ever advances past slots
 * that were actually claimed, matching the "proper CAS semantics"
 * documented on path_generator_kernel.
 *
 * @param work_queue Ring buffer of work items (device pointer)
 * @param head Consumed-index pointer (device)
 * @param tail Produced-index pointer (device)
 * @param queue_capacity Ring buffer capacity (used for index wrap)
 * @param item Output: the claimed work item
 */
__device__ inline bool steal_work(PathWorkItem* work_queue,
                                  uint32_t* head, uint32_t* tail,
                                  uint32_t queue_capacity,
                                  PathWorkItem& item) {
    uint32_t observed = *head;
    // Re-read tail each attempt so newly pushed work becomes visible.
    while (observed < *tail) {
        uint32_t prev = atomicCAS(head, observed, observed + 1);
        if (prev == observed) {
            // Slot `observed` is now exclusively ours.
            item = work_queue[observed % queue_capacity];
            return true;
        }
        // Lost the race; retry from the head value CAS reported.
        observed = prev;
    }
    return false; // Queue empty
}
|
|
248
|
+
|
|
249
|
+
/**
 * Push a work item onto the global queue.
 * Returns true if successfully pushed, false if the queue is full.
 *
 * Uses a CAS claim loop instead of increment-then-check: the original
 * atomicAdd/atomicSub rollback transiently advanced tail past capacity,
 * and a concurrent steal_work comparing head < tail could observe that
 * phantom slot before the rollback. With CAS, tail never moves past
 * queue_capacity.
 *
 * NOTE(review): as in the original, the item payload is written after the
 * tail bump, so a concurrent consumer may briefly see the slot before the
 * payload store completes — confirm callers tolerate this or add a
 * publish fence in the enclosing protocol.
 *
 * @param work_queue Work item buffer (device pointer)
 * @param tail Produced-index pointer (device)
 * @param queue_capacity Maximum number of items the buffer holds
 * @param item Work item to enqueue
 */
__device__ inline bool push_work(PathWorkItem* work_queue,
                                 uint32_t* tail,
                                 uint32_t queue_capacity,
                                 const PathWorkItem& item) {
    uint32_t observed = *tail;
    while (observed < queue_capacity) {
        uint32_t prev = atomicCAS(tail, observed, observed + 1);
        if (prev == observed) {
            // Slot `observed` is now exclusively ours.
            work_queue[observed] = item;
            return true;
        }
        // Lost the race; retry from the tail value CAS reported.
        observed = prev;
    }
    return false; // Queue full
}
|
|
268
|
+
|
|
269
|
+
} // namespace cuda
|
|
270
|
+
} // namespace etb
|
|
271
|
+
|
|
272
|
+
#endif // ETB_PATH_GENERATOR_KERNEL_CUH
|
|
@@ -0,0 +1,370 @@
|
|
|
1
|
+
#ifndef ETB_PREFIX_PRUNER_KERNEL_CUH
|
|
2
|
+
#define ETB_PREFIX_PRUNER_KERNEL_CUH
|
|
3
|
+
|
|
4
|
+
#include "cuda_common.cuh"
|
|
5
|
+
#include "gpu_memory.cuh"
|
|
6
|
+
|
|
7
|
+
namespace etb {
|
|
8
|
+
namespace cuda {
|
|
9
|
+
|
|
10
|
+
/**
 * Shared memory structure for prefix pruning operations.
 * NOTE(review): shared memory is uninitialized at block start — populate
 * before reading, with a barrier in between.
 */
struct PrefixPrunerSharedMem {
    // Warp voting results (8 warps per block max)
    uint32_t warp_votes[8];

    // Prefix currently being evaluated by the block
    uint8_t current_prefix[16];
    uint32_t prefix_length;

    // Trie navigation state shared by the block
    uint32_t current_node_idx;
    DevicePrefixStatus current_status;
    float current_score;
};
|
|
26
|
+
|
|
27
|
+
/**
 * Prefix pruner CUDA kernel.
 *
 * OPTIMIZED: Takes pre-computed prefix offsets to avoid O(n²) offset
 * calculation. The host should compute prefix_offsets via
 * exclusive_scan(prefix_lengths).
 *
 * Implements warp-level voting for termination decisions and
 * atomic trie updates for prefix status.
 *
 * Requirements: 9.8
 *
 * @param prefix_trie Prefix trie nodes (device pointer, mutated)
 * @param trie_size Number of nodes in trie
 * @param prefixes Flattened array of prefixes to evaluate
 * @param prefix_lengths Array of prefix lengths
 * @param prefix_offsets Pre-computed exclusive prefix sum of lengths
 * @param scores Array of heuristic scores for each prefix
 * @param num_prefixes Number of prefixes to evaluate
 * @param prune_threshold Score threshold for pruning
 * @param prune_results Output array indicating if each prefix was pruned
 */
__global__ void prefix_pruner_kernel(
    DevicePrefixTrieNode* prefix_trie,
    uint32_t trie_size,
    const uint8_t* prefixes,
    const uint32_t* prefix_lengths,
    const uint32_t* prefix_offsets,
    const float* scores,
    uint32_t num_prefixes,
    float prune_threshold,
    bool* prune_results
);
|
|
59
|
+
|
|
60
|
+
/**
 * Trie lookup kernel.
 * OPTIMIZED: Uses pre-computed prefix offsets (exclusive scan of lengths).
 *
 * @param prefix_trie Prefix trie nodes (read-only)
 * @param trie_size Number of nodes in trie
 * @param prefixes Flattened array of prefixes to look up
 * @param prefix_lengths Array of prefix lengths
 * @param prefix_offsets Pre-computed exclusive prefix sum of lengths
 * @param num_prefixes Number of prefixes
 * @param statuses Output array of prefix statuses
 * @param scores Output array of best scores
 */
__global__ void trie_lookup_kernel(
    const DevicePrefixTrieNode* prefix_trie,
    uint32_t trie_size,
    const uint8_t* prefixes,
    const uint32_t* prefix_lengths,
    const uint32_t* prefix_offsets,
    uint32_t num_prefixes,
    DevicePrefixStatus* statuses,
    float* scores
);
|
|
83
|
+
|
|
84
|
+
/**
 * Trie insert/update kernel.
 * OPTIMIZED: Uses pre-computed prefix offsets and warp-cooperative allocation.
 *
 * @param prefix_trie Prefix trie nodes (mutated; may grow up to max_trie_size)
 * @param trie_size Current number of nodes
 * @param max_trie_size Maximum trie capacity
 * @param prefixes Flattened array of prefixes to insert
 * @param prefix_lengths Array of prefix lengths
 * @param prefix_offsets Pre-computed exclusive prefix sum of lengths
 * @param statuses Array of statuses to set
 * @param scores Array of scores to set
 * @param num_prefixes Number of prefixes
 * @param new_trie_size Output: new trie size after insertions
 */
__global__ void trie_insert_kernel(
    DevicePrefixTrieNode* prefix_trie,
    uint32_t trie_size,
    uint32_t max_trie_size,
    const uint8_t* prefixes,
    const uint32_t* prefix_lengths,
    const uint32_t* prefix_offsets,
    const DevicePrefixStatus* statuses,
    const float* scores,
    uint32_t num_prefixes,
    uint32_t* new_trie_size
);
|
|
111
|
+
|
|
112
|
+
/**
 * Batch prune check kernel.
 * OPTIMIZED: Uses pre-computed prefix offsets (exclusive scan of lengths).
 *
 * @param prefix_trie Prefix trie nodes (read-only)
 * @param trie_size Number of nodes
 * @param prefixes Flattened array of prefixes to check
 * @param prefix_lengths Array of prefix lengths
 * @param prefix_offsets Pre-computed exclusive prefix sum of lengths
 * @param num_prefixes Number of prefixes
 * @param should_skip Output array: true when a prefix has a pruned ancestor
 */
__global__ void batch_prune_check_kernel(
    const DevicePrefixTrieNode* prefix_trie,
    uint32_t trie_size,
    const uint8_t* prefixes,
    const uint32_t* prefix_lengths,
    const uint32_t* prefix_offsets,
    uint32_t num_prefixes,
    bool* should_skip
);
|
|
133
|
+
|
|
134
|
+
/**
 * Host-side launcher for the prefix pruner kernels.
 * Call configure() before any of the launch methods.
 */
class PrefixPrunerKernel {
public:
    PrefixPrunerKernel();
    ~PrefixPrunerKernel();

    /**
     * Configure the kernel for a specific device.
     * @param device_id CUDA device ID
     */
    void configure(int device_id);

    /**
     * Evaluate and prune prefixes.
     * @param trie Device prefix trie (mutated)
     * @param trie_size Number of trie nodes
     * @param prefixes Device array of prefixes (flattened)
     * @param prefix_lengths Device array of prefix lengths
     * @param scores Device array of heuristic scores
     * @param num_prefixes Number of prefixes
     * @param prune_threshold Score threshold for pruning
     * @param prune_results Device array for results
     * @param stream CUDA stream (nullptr for default)
     */
    void evaluate_and_prune(DevicePrefixTrieNode* trie, uint32_t trie_size,
                            const uint8_t* prefixes, const uint32_t* prefix_lengths,
                            const float* scores, uint32_t num_prefixes,
                            float prune_threshold, bool* prune_results,
                            cudaStream_t stream = nullptr);

    /**
     * Look up prefix statuses.
     * @param trie Device prefix trie (read-only)
     * @param trie_size Number of trie nodes
     * @param prefixes Device array of prefixes
     * @param prefix_lengths Device array of prefix lengths
     * @param num_prefixes Number of prefixes
     * @param statuses Device array for output statuses
     * @param scores Device array for output scores
     * @param stream CUDA stream (nullptr for default)
     */
    void lookup(const DevicePrefixTrieNode* trie, uint32_t trie_size,
                const uint8_t* prefixes, const uint32_t* prefix_lengths,
                uint32_t num_prefixes, DevicePrefixStatus* statuses,
                float* scores, cudaStream_t stream = nullptr);

    /**
     * Check if prefixes should be skipped due to pruned ancestors.
     * @param trie Device prefix trie (read-only)
     * @param trie_size Number of trie nodes
     * @param prefixes Device array of prefixes
     * @param prefix_lengths Device array of prefix lengths
     * @param num_prefixes Number of prefixes
     * @param should_skip Device array for output flags
     * @param stream CUDA stream (nullptr for default)
     */
    void check_pruned(const DevicePrefixTrieNode* trie, uint32_t trie_size,
                      const uint8_t* prefixes, const uint32_t* prefix_lengths,
                      uint32_t num_prefixes, bool* should_skip,
                      cudaStream_t stream = nullptr);

    /**
     * Get the kernel configuration.
     * NOTE(review): presumably populated by configure(); verify before
     * relying on it prior to that call.
     */
    const KernelConfig& get_config() const { return kernel_config_; }

private:
    KernelConfig kernel_config_; // Launch parameters for this device
    bool configured_;            // Whether configure() has been called
};
|
|
206
|
+
|
|
207
|
+
// ============================================================================
|
|
208
|
+
// Device Functions
|
|
209
|
+
// ============================================================================
|
|
210
|
+
|
|
211
|
+
/**
 * Walk the trie from the root, following the bytes of `prefix`.
 * Returns the node index for the full prefix, or UINT32_MAX if the
 * prefix is not present. An empty prefix (or an empty trie) maps to
 * the root node (index 0).
 */
__device__ inline uint32_t find_prefix_node(
    const DevicePrefixTrieNode* trie,
    uint32_t trie_size,
    const uint8_t* prefix,
    uint32_t prefix_length
) {
    if (trie_size == 0 || prefix_length == 0) {
        return 0; // Root node
    }

    uint32_t node = 0; // Start at root
    for (uint32_t depth = 0; depth < prefix_length; ++depth) {
        const uint32_t first_child = trie[node].children_offset;
        if (first_child == 0) {
            return UINT32_MAX; // Leaf reached before the prefix was consumed
        }

        // Linear scan over the node's child slots for the matching byte.
        const uint8_t wanted = prefix[depth];
        uint32_t match = UINT32_MAX;
        for (uint32_t slot = 0; slot < 256; ++slot) {
            const uint32_t candidate = first_child + slot;
            if (candidate >= trie_size) {
                break; // Ran off the end of the node array
            }
            if (trie[candidate].reconstructed_byte == wanted) {
                match = candidate;
                break;
            }
        }

        if (match == UINT32_MAX) {
            return UINT32_MAX; // No child carries this byte value
        }
        node = match;
    }

    return node;
}
|
|
253
|
+
|
|
254
|
+
/**
 * Walk the trie along `prefix` and report whether any node on the
 * path — including the node for the full prefix — is PRUNED.
 * Returns false if the path leaves the trie before a pruned node
 * is encountered.
 */
__device__ inline bool is_ancestor_pruned(
    const DevicePrefixTrieNode* trie,
    uint32_t trie_size,
    const uint8_t* prefix,
    uint32_t prefix_length
) {
    if (trie_size == 0) {
        return false;
    }

    uint32_t node = 0; // Root
    for (uint32_t depth = 0; depth < prefix_length; ++depth) {
        // A pruned ancestor anywhere on the path dooms the prefix.
        if (trie[node].status == DevicePrefixStatus::PRUNED) {
            return true;
        }

        const uint32_t first_child = trie[node].children_offset;
        if (first_child == 0) {
            return false; // Path leaves the trie: nothing more recorded
        }

        // Linear scan over the node's child slots for the matching byte.
        const uint8_t wanted = prefix[depth];
        uint32_t match = UINT32_MAX;
        for (uint32_t slot = 0; slot < 256; ++slot) {
            const uint32_t candidate = first_child + slot;
            if (candidate >= trie_size) {
                break;
            }
            if (trie[candidate].reconstructed_byte == wanted) {
                match = candidate;
                break;
            }
        }

        if (match == UINT32_MAX) {
            return false; // No child carries this byte value
        }
        node = match;
    }

    // The node for the complete prefix counts as well.
    return trie[node].status == DevicePrefixStatus::PRUNED;
}
|
|
299
|
+
|
|
300
|
+
/**
 * Atomically update a node's status byte.
 *
 * Transition policy: UNKNOWN -> VALID or PRUNED, VALID -> PRUNED.
 * PRUNED is final and is never overwritten.
 *
 * CUDA has no 8-bit atomicCAS, so this emulates a byte-granular CAS on
 * the aligned 32-bit word containing the status byte. The original code
 * issued a 32-bit atomicCAS directly at the uint8_t* address, which is
 * undefined behaviour when the byte is not word-aligned, compares a
 * zero-extended byte against a full word (so the CAS essentially never
 * succeeds), and on "success" would have zeroed the three neighbouring
 * bytes of the node.
 */
__device__ inline void atomic_update_status(
    DevicePrefixTrieNode* node,
    DevicePrefixStatus new_status
) {
    uint8_t* status_ptr = reinterpret_cast<uint8_t*>(&node->status);
    const unsigned int new_val = static_cast<unsigned int>(static_cast<uint8_t>(new_status));
    const unsigned int pruned = static_cast<unsigned int>(
        static_cast<uint8_t>(DevicePrefixStatus::PRUNED));

    // Locate the aligned 32-bit word holding the status byte and the
    // byte's position within it.
    const size_t addr = reinterpret_cast<size_t>(status_ptr);
    unsigned int* word_ptr =
        reinterpret_cast<unsigned int*>(addr & ~static_cast<size_t>(3));
    const unsigned int shift = static_cast<unsigned int>(addr & 3u) * 8u;
    const unsigned int byte_mask = 0xFFu << shift;

    unsigned int old_word = *word_ptr;
    while (true) {
        const unsigned int old_val = (old_word & byte_mask) >> shift;
        if (old_val == pruned) {
            return; // Already pruned: final state, don't change
        }

        // Splice the new status into the word, leaving neighbours intact.
        const unsigned int new_word = (old_word & ~byte_mask) | (new_val << shift);
        const unsigned int seen = atomicCAS(word_ptr, old_word, new_word);
        if (seen == old_word) {
            return; // Our update landed
        }
        old_word = seen; // A neighbour field (or the status) changed; retry
    }
}
|
|
324
|
+
|
|
325
|
+
/**
 * Atomically raise a node's best_score to new_score (monotonic max).
 * No-op when new_score does not exceed the stored score.
 */
__device__ inline void atomic_update_score(
    DevicePrefixTrieNode* node,
    float new_score
) {
    // CUDA has no float atomicMax, so emulate one with a CAS loop over the
    // float's bit pattern (round-tripped via __float_as_int/__int_as_float).
    float* score_ptr = &node->best_score;
    float old_score = *score_ptr;

    // Keep trying only while our score would still be an improvement; if
    // another thread installs an even higher score, the condition fails
    // and we exit without writing.
    while (new_score > old_score) {
        float assumed = old_score;
        old_score = __int_as_float(atomicCAS(
            reinterpret_cast<int*>(score_ptr),
            __float_as_int(assumed),
            __float_as_int(new_score)
        ));

        if (old_score == assumed) {
            break; // Successfully updated
        }
    }
}
|
|
350
|
+
|
|
351
|
+
/**
 * Warp-level vote for a pruning decision.
 * All 32 lanes of the warp must call this (full 0xFFFFFFFF mask).
 * Returns true if a strict majority (17 or more lanes) votes to prune.
 */
__device__ inline bool warp_vote_prune(bool should_prune) {
    unsigned int vote = __ballot_sync(0xFFFFFFFF, should_prune);
    return __popc(vote) > 16; // More than half
}
|
|
359
|
+
|
|
360
|
+
/**
 * Atomically bump a trie node's visit counter by one.
 */
__device__ inline void atomic_increment_visit(DevicePrefixTrieNode* node) {
    atomicAdd(&node->visit_count, 1);
}
|
|
366
|
+
|
|
367
|
+
} // namespace cuda
|
|
368
|
+
} // namespace etb
|
|
369
|
+
|
|
370
|
+
#endif // ETB_PREFIX_PRUNER_KERNEL_CUH
|