cactus-react-native 1.4.0 → 1.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +212 -27
- package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
- package/cpp/HybridCactus.cpp +119 -0
- package/cpp/HybridCactus.hpp +13 -0
- package/cpp/cactus_ffi.h +24 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +24 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +41 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +66 -48
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +102 -21
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +45 -195
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +399 -140
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +24 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +41 -1
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +66 -48
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +102 -21
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +45 -195
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +399 -140
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
- package/lib/module/api/Database.js +0 -92
- package/lib/module/api/Database.js.map +1 -1
- package/lib/module/classes/CactusLM.js +33 -15
- package/lib/module/classes/CactusLM.js.map +1 -1
- package/lib/module/classes/CactusSTT.js +90 -15
- package/lib/module/classes/CactusSTT.js.map +1 -1
- package/lib/module/hooks/useCactusLM.js +14 -5
- package/lib/module/hooks/useCactusLM.js.map +1 -1
- package/lib/module/hooks/useCactusSTT.js +100 -4
- package/lib/module/hooks/useCactusSTT.js.map +1 -1
- package/lib/module/index.js.map +1 -1
- package/lib/module/models.js +336 -0
- package/lib/module/models.js.map +1 -0
- package/lib/module/native/Cactus.js +37 -0
- package/lib/module/native/Cactus.js.map +1 -1
- package/lib/module/types/CactusLM.js +2 -0
- package/lib/module/types/CactusSTT.js +2 -0
- package/lib/module/types/common.js +2 -0
- package/lib/module/types/{CactusModel.js.map → common.js.map} +1 -1
- package/lib/typescript/src/api/Database.d.ts +0 -6
- package/lib/typescript/src/api/Database.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusLM.d.ts +7 -3
- package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusSTT.d.ts +13 -4
- package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusLM.d.ts +2 -2
- package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusSTT.d.ts +12 -4
- package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/index.d.ts +2 -3
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/models.d.ts +6 -0
- package/lib/typescript/src/models.d.ts.map +1 -0
- package/lib/typescript/src/native/Cactus.d.ts +6 -1
- package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts +5 -0
- package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusLM.d.ts +2 -0
- package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusSTT.d.ts +20 -0
- package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/types/common.d.ts +28 -0
- package/lib/typescript/src/types/common.d.ts.map +1 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +5 -0
- package/package.json +1 -1
- package/src/api/Database.ts +0 -133
- package/src/classes/CactusLM.ts +49 -17
- package/src/classes/CactusSTT.ts +118 -17
- package/src/hooks/useCactusLM.ts +25 -5
- package/src/hooks/useCactusSTT.ts +117 -5
- package/src/index.tsx +6 -2
- package/src/models.ts +344 -0
- package/src/native/Cactus.ts +55 -0
- package/src/specs/Cactus.nitro.ts +5 -0
- package/src/types/CactusLM.ts +3 -0
- package/src/types/CactusSTT.ts +26 -0
- package/src/types/common.ts +28 -0
- package/lib/module/types/CactusModel.js +0 -2
- package/lib/module/types/CactusSTTModel.js +0 -2
- package/lib/module/types/CactusSTTModel.js.map +0 -1
- package/lib/typescript/src/types/CactusModel.d.ts +0 -13
- package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
- package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
- package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
- package/src/types/CactusModel.ts +0 -15
- package/src/types/CactusSTTModel.ts +0 -10
|
@@ -11,6 +11,7 @@
|
|
|
11
11
|
#include <mutex>
|
|
12
12
|
#include <sstream>
|
|
13
13
|
#include <iostream>
|
|
14
|
+
#include <arm_neon.h>
|
|
14
15
|
|
|
15
16
|
namespace cactus {
|
|
16
17
|
|
|
@@ -96,9 +97,10 @@ namespace GraphFile {
|
|
|
96
97
|
}
|
|
97
98
|
|
|
98
99
|
enum class Precision {
|
|
99
|
-
INT8,
|
|
100
|
+
INT8,
|
|
100
101
|
FP16,
|
|
101
|
-
FP32
|
|
102
|
+
FP32,
|
|
103
|
+
INT4
|
|
102
104
|
};
|
|
103
105
|
|
|
104
106
|
enum class ComputeBackend {
|
|
@@ -112,7 +114,7 @@ enum class OpType {
|
|
|
112
114
|
MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
|
|
113
115
|
BILINEAR_INTERPOLATION,
|
|
114
116
|
SUM, MEAN, VARIANCE, MIN, MAX,
|
|
115
|
-
RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
|
|
117
|
+
RMS_NORM, ROPE, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3,
|
|
116
118
|
SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
|
|
117
119
|
SILU, GELU, GELU_ERF,
|
|
118
120
|
SAMPLE, CONCAT,
|
|
@@ -122,27 +124,38 @@ enum class OpType {
|
|
|
122
124
|
};
|
|
123
125
|
|
|
124
126
|
struct PrecisionTraits {
|
|
127
|
+
// Returns in-memory element size (INT4 unpacks to INT8, so returns 1)
|
|
125
128
|
static constexpr size_t size_of(Precision prec) {
|
|
126
129
|
switch (prec) {
|
|
127
130
|
case Precision::INT8: return 1;
|
|
128
131
|
case Precision::FP16: return 2;
|
|
129
132
|
case Precision::FP32: return 4;
|
|
133
|
+
case Precision::INT4: return 1;
|
|
130
134
|
}
|
|
131
135
|
return 1;
|
|
132
136
|
}
|
|
133
|
-
|
|
137
|
+
|
|
138
|
+
static constexpr size_t packed_size_of(Precision prec, size_t count) {
|
|
139
|
+
switch (prec) {
|
|
140
|
+
case Precision::INT4: return (count + 1) / 2;
|
|
141
|
+
default: return count * size_of(prec);
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
|
|
134
145
|
static constexpr bool is_integer(Precision prec) {
|
|
135
146
|
switch (prec) {
|
|
136
147
|
case Precision::INT8: return true;
|
|
148
|
+
case Precision::INT4: return true;
|
|
137
149
|
case Precision::FP16: return false;
|
|
138
150
|
case Precision::FP32: return false;
|
|
139
151
|
}
|
|
140
152
|
return true;
|
|
141
153
|
}
|
|
142
|
-
|
|
154
|
+
|
|
143
155
|
static constexpr bool is_floating_point(Precision prec) {
|
|
144
156
|
switch (prec) {
|
|
145
157
|
case Precision::INT8: return false;
|
|
158
|
+
case Precision::INT4: return false;
|
|
146
159
|
case Precision::FP16: return true;
|
|
147
160
|
case Precision::FP32: return true;
|
|
148
161
|
}
|
|
@@ -153,8 +166,6 @@ struct PrecisionTraits {
|
|
|
153
166
|
namespace Quantization {
|
|
154
167
|
void int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
|
|
155
168
|
void fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
|
|
156
|
-
void dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count,
|
|
157
|
-
float* computed_scale);
|
|
158
169
|
void fp16_to_fp32(const __fp16* src, float* dst, size_t count);
|
|
159
170
|
void fp32_to_fp16(const float* src, __fp16* dst, size_t count);
|
|
160
171
|
void int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
|
|
@@ -188,10 +199,17 @@ struct BufferDesc {
|
|
|
188
199
|
void* external_data;
|
|
189
200
|
char* pooled_data;
|
|
190
201
|
Precision precision;
|
|
191
|
-
|
|
202
|
+
|
|
203
|
+
size_t group_size = 0;
|
|
204
|
+
size_t num_groups = 0;
|
|
205
|
+
void* scales_data = nullptr;
|
|
206
|
+
std::unique_ptr<char[]> owned_scales;
|
|
207
|
+
|
|
208
|
+
const void* packed_int4_data = nullptr;
|
|
209
|
+
size_t packed_int4_size = 0;
|
|
192
210
|
|
|
193
211
|
BufferDesc();
|
|
194
|
-
BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8
|
|
212
|
+
BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8);
|
|
195
213
|
~BufferDesc();
|
|
196
214
|
|
|
197
215
|
BufferDesc(BufferDesc&& other) noexcept;
|
|
@@ -209,6 +227,28 @@ struct BufferDesc {
|
|
|
209
227
|
template<typename T>
|
|
210
228
|
const T* data_as() const { return static_cast<const T*>(get_data()); }
|
|
211
229
|
|
|
230
|
+
const __fp16* scales_as_fp16() const {
|
|
231
|
+
return reinterpret_cast<const __fp16*>(scales_data);
|
|
232
|
+
}
|
|
233
|
+
bool is_grouped_int8() const {
|
|
234
|
+
return precision == Precision::INT8 && group_size > 0;
|
|
235
|
+
}
|
|
236
|
+
bool is_packed_int4() const {
|
|
237
|
+
return packed_int4_data != nullptr && packed_int4_size > 0;
|
|
238
|
+
}
|
|
239
|
+
const uint8_t* packed_int4_as_uint8() const {
|
|
240
|
+
return reinterpret_cast<const uint8_t*>(packed_int4_data);
|
|
241
|
+
}
|
|
242
|
+
void set_grouped_scales(size_t gs, size_t ng, void* scales_ptr) {
|
|
243
|
+
group_size = gs;
|
|
244
|
+
num_groups = ng;
|
|
245
|
+
scales_data = scales_ptr;
|
|
246
|
+
}
|
|
247
|
+
void set_packed_int4(const void* packed_data, size_t packed_size) {
|
|
248
|
+
packed_int4_data = packed_data;
|
|
249
|
+
packed_int4_size = packed_size;
|
|
250
|
+
}
|
|
251
|
+
|
|
212
252
|
void allocate();
|
|
213
253
|
void allocate_from_pool(BufferPool& pool);
|
|
214
254
|
void release_to_pool(BufferPool& pool);
|
|
@@ -247,6 +287,14 @@ struct OpParams {
|
|
|
247
287
|
|
|
248
288
|
std::vector<float> bias_values;
|
|
249
289
|
std::vector<uint32_t> bias_indices;
|
|
290
|
+
|
|
291
|
+
const int8_t* cached_keys_int8 = nullptr;
|
|
292
|
+
const int8_t* cached_values_int8 = nullptr;
|
|
293
|
+
const float* cached_k_scales = nullptr;
|
|
294
|
+
const float* cached_v_scales = nullptr;
|
|
295
|
+
size_t cache_seq_len = 0;
|
|
296
|
+
size_t num_kv_heads = 0;
|
|
297
|
+
size_t head_dim = 0;
|
|
250
298
|
};
|
|
251
299
|
|
|
252
300
|
struct GraphNode {
|
|
@@ -326,7 +374,7 @@ public:
|
|
|
326
374
|
size_t precision_cast(size_t input, Precision target_precision);
|
|
327
375
|
|
|
328
376
|
size_t add(size_t input1, size_t input2);
|
|
329
|
-
size_t add_clipped(size_t input1, size_t input2);
|
|
377
|
+
size_t add_clipped(size_t input1, size_t input2);
|
|
330
378
|
size_t subtract(size_t input1, size_t input2);
|
|
331
379
|
size_t multiply(size_t input1, size_t input2);
|
|
332
380
|
size_t divide(size_t input1, size_t input2);
|
|
@@ -361,8 +409,12 @@ public:
|
|
|
361
409
|
size_t gather(size_t embeddings, size_t indices);
|
|
362
410
|
size_t mmap_embeddings(const std::string& filename);
|
|
363
411
|
size_t mmap_weights(const std::string& filename);
|
|
364
|
-
size_t load_weights(const std::string& filename);
|
|
365
|
-
void
|
|
412
|
+
size_t load_weights(const std::string& filename);
|
|
413
|
+
void set_grouped_scales(size_t node_id, size_t group_size, size_t num_groups, void* scales_ptr);
|
|
414
|
+
|
|
415
|
+
void release_weight_pages(size_t node_id);
|
|
416
|
+
void prefetch_weight_pages(size_t node_id);
|
|
417
|
+
void release_all_weight_pages();
|
|
366
418
|
size_t embedding(const std::string& filename, size_t indices);
|
|
367
419
|
size_t embedding(size_t embedding_tensor, size_t indices);
|
|
368
420
|
size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
|
|
@@ -376,6 +428,11 @@ public:
|
|
|
376
428
|
size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, ComputeBackend backend = ComputeBackend::CPU);
|
|
377
429
|
size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, size_t window_size, ComputeBackend backend = ComputeBackend::CPU);
|
|
378
430
|
|
|
431
|
+
size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
|
|
432
|
+
const int8_t* cached_keys, const int8_t* cached_values,
|
|
433
|
+
const float* k_scales, const float* v_scales,
|
|
434
|
+
size_t cache_len, size_t num_kv_heads, size_t head_dim);
|
|
435
|
+
|
|
379
436
|
size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
|
|
380
437
|
size_t conv1d_k3(size_t input, size_t weight, size_t stride);
|
|
381
438
|
|
|
@@ -392,6 +449,8 @@ public:
|
|
|
392
449
|
void execute(const std::string& profile_file = "");
|
|
393
450
|
void hard_reset();
|
|
394
451
|
void soft_reset();
|
|
452
|
+
void soft_reset_keep_pool();
|
|
453
|
+
void set_prefill_mode(bool enabled) { prefill_mode_ = enabled; }
|
|
395
454
|
|
|
396
455
|
void register_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
|
|
397
456
|
void capture_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
|
|
@@ -410,8 +469,10 @@ private:
|
|
|
410
469
|
size_t next_node_id_;
|
|
411
470
|
std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
|
|
412
471
|
std::unordered_map<std::string, size_t> weight_cache_;
|
|
472
|
+
std::unordered_map<size_t, size_t> node_to_mapped_file_;
|
|
413
473
|
std::vector<DebugNodeEntry> debug_nodes_;
|
|
414
474
|
BufferPool buffer_pool_;
|
|
475
|
+
bool prefill_mode_ = false;
|
|
415
476
|
};
|
|
416
477
|
|
|
417
478
|
|
|
@@ -430,25 +491,36 @@ namespace GraphFile {
|
|
|
430
491
|
public:
|
|
431
492
|
MappedFile(const std::string& filename);
|
|
432
493
|
~MappedFile();
|
|
433
|
-
|
|
494
|
+
|
|
434
495
|
MappedFile(const MappedFile&) = delete;
|
|
435
496
|
MappedFile& operator=(const MappedFile&) = delete;
|
|
436
497
|
MappedFile(MappedFile&& other) noexcept;
|
|
437
498
|
MappedFile& operator=(MappedFile&& other) noexcept;
|
|
438
|
-
|
|
499
|
+
|
|
439
500
|
const std::vector<size_t>& shape() const;
|
|
440
501
|
Precision precision() const;
|
|
502
|
+
Precision effective_precision() const {
|
|
503
|
+
return is_int4_ ? Precision::INT8 : precision_;
|
|
504
|
+
}
|
|
441
505
|
size_t byte_size() const;
|
|
442
|
-
|
|
443
|
-
|
|
506
|
+
|
|
507
|
+
size_t group_size() const { return group_size_; }
|
|
508
|
+
size_t num_groups() const { return num_groups_; }
|
|
509
|
+
const void* scales_data() const;
|
|
510
|
+
const void* raw_packed_data() const; // Get raw mmap'd data without unpacking (for INT4)
|
|
511
|
+
bool is_int4() const { return is_int4_; }
|
|
512
|
+
|
|
444
513
|
void* data();
|
|
445
514
|
const void* data() const;
|
|
446
|
-
|
|
515
|
+
|
|
447
516
|
template<typename T>
|
|
448
517
|
const T* typed_data() const;
|
|
449
|
-
|
|
518
|
+
|
|
450
519
|
LoadedNode load_into_graph(CactusGraph& graph) const;
|
|
451
|
-
|
|
520
|
+
|
|
521
|
+
void release_pages();
|
|
522
|
+
void prefetch_pages();
|
|
523
|
+
|
|
452
524
|
private:
|
|
453
525
|
int fd_;
|
|
454
526
|
void* mapped_data_;
|
|
@@ -456,10 +528,19 @@ namespace GraphFile {
|
|
|
456
528
|
std::vector<size_t> shape_;
|
|
457
529
|
Precision precision_;
|
|
458
530
|
size_t byte_size_;
|
|
459
|
-
|
|
531
|
+
size_t group_size_ = 0;
|
|
532
|
+
size_t num_groups_ = 0;
|
|
533
|
+
size_t scales_offset_ = 0;
|
|
534
|
+
size_t scales_bytes_ = 0;
|
|
535
|
+
uint32_t version_ = 1;
|
|
536
|
+
uint32_t alignment_ = 32;
|
|
537
|
+
bool is_int4_ = false;
|
|
538
|
+
mutable std::unique_ptr<int8_t[]> unpacked_int4_data_;
|
|
460
539
|
void parse_header();
|
|
540
|
+
void apply_madvise_hints();
|
|
541
|
+
void unpack_int4_if_needed() const;
|
|
461
542
|
};
|
|
462
|
-
|
|
543
|
+
|
|
463
544
|
MappedFile mmap_load(const std::string& filename);
|
|
464
545
|
}
|
|
465
546
|
|
|
@@ -15,12 +15,7 @@ enum class ScalarOpType {
|
|
|
15
15
|
SIN
|
|
16
16
|
};
|
|
17
17
|
|
|
18
|
-
|
|
19
|
-
void cactus_add_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
|
|
20
|
-
void cactus_subtract_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
|
|
21
|
-
void cactus_multiply_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
|
|
22
|
-
void cactus_divide_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
|
|
23
|
-
|
|
18
|
+
constexpr size_t KV_QUANT_GROUP_SIZE = 128;
|
|
24
19
|
|
|
25
20
|
void cactus_add_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
26
21
|
void cactus_add_f16_clipped(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
@@ -28,27 +23,6 @@ void cactus_subtract_f16(const __fp16* a, const __fp16* b, __fp16* output, size_
|
|
|
28
23
|
void cactus_multiply_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
29
24
|
void cactus_divide_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
30
25
|
|
|
31
|
-
|
|
32
|
-
void cactus_add_f32(const float* a, const float* b, float* output, size_t num_elements);
|
|
33
|
-
void cactus_subtract_f32(const float* a, const float* b, float* output, size_t num_elements);
|
|
34
|
-
void cactus_multiply_f32(const float* a, const float* b, float* output, size_t num_elements);
|
|
35
|
-
void cactus_divide_f32(const float* a, const float* b, float* output, size_t num_elements);
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
void cactus_add_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
|
|
39
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
40
|
-
const size_t* output_shape, size_t ndim);
|
|
41
|
-
void cactus_subtract_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
|
|
42
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
43
|
-
const size_t* output_shape, size_t ndim);
|
|
44
|
-
void cactus_multiply_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
|
|
45
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
46
|
-
const size_t* output_shape, size_t ndim);
|
|
47
|
-
void cactus_divide_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
|
|
48
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
49
|
-
const size_t* output_shape, size_t ndim);
|
|
50
|
-
|
|
51
|
-
|
|
52
26
|
void cactus_add_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* output,
|
|
53
27
|
const size_t* a_strides, const size_t* b_strides,
|
|
54
28
|
const size_t* output_shape, size_t ndim);
|
|
@@ -62,159 +36,72 @@ void cactus_divide_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* outpu
|
|
|
62
36
|
const size_t* a_strides, const size_t* b_strides,
|
|
63
37
|
const size_t* output_shape, size_t ndim);
|
|
64
38
|
|
|
65
|
-
|
|
66
|
-
void cactus_add_broadcast_f32(const float* a, const float* b, float* output,
|
|
67
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
68
|
-
const size_t* output_shape, size_t ndim);
|
|
69
|
-
void cactus_subtract_broadcast_f32(const float* a, const float* b, float* output,
|
|
70
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
71
|
-
const size_t* output_shape, size_t ndim);
|
|
72
|
-
void cactus_multiply_broadcast_f32(const float* a, const float* b, float* output,
|
|
73
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
74
|
-
const size_t* output_shape, size_t ndim);
|
|
75
|
-
void cactus_divide_broadcast_f32(const float* a, const float* b, float* output,
|
|
76
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
77
|
-
const size_t* output_shape, size_t ndim);
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
void cactus_scalar_op_int8(const int8_t* input, int8_t* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
|
|
81
39
|
void cactus_scalar_op_f16(const __fp16* input, __fp16* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
|
|
82
|
-
void cactus_scalar_op_f32(const float* input, float* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
|
|
83
|
-
|
|
84
40
|
|
|
85
|
-
void cactus_matmul_int8(const int8_t*
|
|
86
|
-
|
|
87
|
-
|
|
41
|
+
void cactus_matmul_int8(const int8_t* A, const float* A_scales,
|
|
42
|
+
const int8_t* B, const __fp16* B_scales,
|
|
43
|
+
__fp16* C, size_t M, size_t K, size_t N, size_t group_size);
|
|
88
44
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
#define cactus_matmul_int8_to_int32 cactus_matmul_int8_to_int32_i8mm
|
|
93
|
-
#else
|
|
94
|
-
void cactus_matmul_int8_to_int32(const int8_t* a, const int8_t* b_transposed, int32_t* c,
|
|
95
|
-
size_t M, size_t K, size_t N);
|
|
96
|
-
#endif
|
|
45
|
+
void cactus_matmul_int4(const int8_t* A, const float* A_scales,
|
|
46
|
+
const uint8_t* B_packed, const __fp16* B_scales,
|
|
47
|
+
__fp16* C, size_t M, size_t K, size_t N, size_t group_size);
|
|
97
48
|
|
|
98
49
|
void cactus_matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c,
|
|
99
50
|
size_t M, size_t K, size_t N);
|
|
100
51
|
|
|
101
|
-
void cactus_matmul_f32(const float* a, const float* b_transposed, float* c,
|
|
102
|
-
size_t M, size_t K, size_t N);
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
void cactus_transpose_2d_int8(const int8_t* source, int8_t* destination,
|
|
106
|
-
size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
|
|
107
52
|
void cactus_transpose_2d_f16(const __fp16* source, __fp16* destination,
|
|
108
53
|
size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
|
|
109
|
-
void cactus_transpose_2d_f32(const float* source, float* destination,
|
|
110
|
-
size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
|
|
111
|
-
|
|
112
|
-
void cactus_transpose_int8(const int8_t* source, int8_t* destination, const size_t* shape,
|
|
113
|
-
const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
|
|
114
54
|
void cactus_transpose_f16(const __fp16* source, __fp16* destination, const size_t* shape,
|
|
115
55
|
const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
|
|
116
|
-
void cactus_transpose_f32(const float* source, float* destination, const size_t* shape,
|
|
117
|
-
const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
|
|
118
56
|
|
|
119
|
-
int64_t cactus_sum_all_int8(const int8_t* data, size_t num_elements);
|
|
120
|
-
void cactus_sum_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
121
57
|
double cactus_sum_all_f16(const __fp16* data, size_t num_elements);
|
|
122
|
-
|
|
123
|
-
void cactus_sum_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
58
|
+
void cactus_sum_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
124
59
|
|
|
125
|
-
double cactus_mean_all_int8(const int8_t* data, size_t num_elements);
|
|
126
|
-
void cactus_mean_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
127
60
|
double cactus_mean_all_f16(const __fp16* data, size_t num_elements);
|
|
128
61
|
void cactus_mean_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
129
|
-
double cactus_mean_all_f32(const float* data, size_t num_elements);
|
|
130
|
-
void cactus_mean_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
131
62
|
|
|
132
|
-
double
|
|
133
|
-
void
|
|
134
|
-
double cactus_variance_all_f32(const float* data, size_t num_elements);
|
|
135
|
-
void cactus_variance_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
63
|
+
double cactus_variance_all_f16(const __fp16* data, size_t num_elements);
|
|
64
|
+
void cactus_variance_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
136
65
|
|
|
137
|
-
|
|
138
|
-
void
|
|
139
|
-
float cactus_min_all_f32(const float* data, size_t num_elements);
|
|
140
|
-
void cactus_min_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
66
|
+
__fp16 cactus_min_all_f16(const __fp16* data, size_t num_elements);
|
|
67
|
+
void cactus_min_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
141
68
|
|
|
142
|
-
|
|
143
|
-
void
|
|
144
|
-
float cactus_max_all_f32(const float* data, size_t num_elements);
|
|
145
|
-
void cactus_max_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
69
|
+
__fp16 cactus_max_all_f16(const __fp16* data, size_t num_elements);
|
|
70
|
+
void cactus_max_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
146
71
|
|
|
147
72
|
void cactus_rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* output,
|
|
148
73
|
size_t batch_size, size_t dims, float eps);
|
|
149
|
-
|
|
150
|
-
void cactus_rms_norm_f32(const float* input, const float* weight, float* output,
|
|
151
|
-
size_t batch_size, size_t dims, float eps);
|
|
152
|
-
|
|
153
|
-
void cactus_rms_norm_i8_f32(const int8_t* input, const float* weight, float* output,
|
|
154
|
-
size_t batch_size, size_t dims, float eps, float input_scale);
|
|
155
74
|
|
|
156
75
|
void cactus_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
|
|
157
76
|
size_t num_heads, size_t head_dim, size_t start_pos, float theta);
|
|
158
77
|
|
|
159
|
-
void
|
|
160
|
-
size_t num_heads, size_t head_dim, size_t start_pos, float theta);
|
|
161
|
-
|
|
162
|
-
void cactus_rope_i8_f32_i8(const int8_t* input, int8_t* output, size_t batch_size, size_t seq_len,
|
|
163
|
-
size_t num_heads, size_t head_dim, size_t start_pos, float theta,
|
|
164
|
-
float input_scale, float output_scale);
|
|
165
|
-
|
|
166
|
-
void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
|
|
78
|
+
void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
|
|
167
79
|
size_t seq_len, size_t vocab_size);
|
|
168
80
|
|
|
169
|
-
void cactus_softmax_f32(const float* input, float* output, size_t batch_size,
|
|
170
|
-
size_t seq_len, size_t vocab_size);
|
|
171
|
-
|
|
172
|
-
void cactus_silu_f32(const float* input, float* output, size_t num_elements);
|
|
173
81
|
void cactus_silu_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
174
|
-
void cactus_silu_int8(const int8_t* input, int8_t* output, size_t num_elements,
|
|
175
|
-
float input_scale, float output_scale);
|
|
176
82
|
|
|
177
|
-
void cactus_gelu_f32(const float* input, float* output, size_t num_elements);
|
|
178
83
|
void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
179
|
-
void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
|
|
180
|
-
float input_scale, float output_scale);
|
|
181
84
|
|
|
182
|
-
void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
|
|
183
85
|
void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
|
|
184
|
-
void cactus_gelu_int8_erf(
|
|
185
|
-
const int8_t* input,
|
|
186
|
-
int8_t* output,
|
|
187
|
-
size_t num_elements,
|
|
188
|
-
float scale_in,
|
|
189
|
-
float scale_out);
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
|
|
193
|
-
size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
|
|
194
|
-
size_t head_dim, float scale, const int8_t* mask,
|
|
195
|
-
float q_scale, float k_scale, float v_scale, float output_scale, size_t position_offset = 0, size_t window_size = 0,
|
|
196
|
-
bool is_causal = true);
|
|
197
86
|
|
|
198
87
|
void cactus_attention_f16(const __fp16* queries, const __fp16* keys, const __fp16* values, __fp16* output,
|
|
199
88
|
size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
|
|
200
89
|
size_t head_dim, float scale, const __fp16* mask, size_t position_offset = 0, size_t window_size = 0,
|
|
201
90
|
bool is_causal = true);
|
|
202
91
|
|
|
203
|
-
void
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
const
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
size_t
|
|
214
|
-
size_t
|
|
215
|
-
size_t
|
|
216
|
-
size_t K,
|
|
217
|
-
size_t dilation);
|
|
92
|
+
void cactus_attention_hybrid_int8_fp16(
|
|
93
|
+
const __fp16* queries,
|
|
94
|
+
const int8_t* keys_cached,
|
|
95
|
+
const int8_t* values_cached,
|
|
96
|
+
const float* k_scales,
|
|
97
|
+
const float* v_scales,
|
|
98
|
+
const __fp16* keys_new,
|
|
99
|
+
const __fp16* values_new,
|
|
100
|
+
__fp16* output,
|
|
101
|
+
size_t batch_size, size_t seq_len, size_t cache_len, size_t new_len,
|
|
102
|
+
size_t num_q_heads, size_t num_kv_heads, size_t head_dim,
|
|
103
|
+
float scale, size_t position_offset = 0, bool is_causal = true,
|
|
104
|
+
size_t group_size = KV_QUANT_GROUP_SIZE);
|
|
218
105
|
|
|
219
106
|
void cactus_conv1d_causal_depthwise_f16(
|
|
220
107
|
const __fp16* input,
|
|
@@ -226,30 +113,6 @@ void cactus_conv1d_causal_depthwise_f16(
|
|
|
226
113
|
size_t K,
|
|
227
114
|
size_t dilation);
|
|
228
115
|
|
|
229
|
-
void cactus_conv1d_causal_depthwise_int8(
|
|
230
|
-
const int8_t* input,
|
|
231
|
-
const int8_t* weight,
|
|
232
|
-
int8_t* output,
|
|
233
|
-
size_t N,
|
|
234
|
-
size_t L,
|
|
235
|
-
size_t C,
|
|
236
|
-
size_t K,
|
|
237
|
-
size_t dilation,
|
|
238
|
-
float input_scale,
|
|
239
|
-
float weight_scale,
|
|
240
|
-
float output_scale);
|
|
241
|
-
|
|
242
|
-
void cactus_conv1d_f32_k3(
|
|
243
|
-
const float* input,
|
|
244
|
-
const float* weight,
|
|
245
|
-
float* output,
|
|
246
|
-
size_t N,
|
|
247
|
-
size_t L,
|
|
248
|
-
size_t C_in,
|
|
249
|
-
size_t C_out,
|
|
250
|
-
size_t stride
|
|
251
|
-
);
|
|
252
|
-
|
|
253
116
|
void cactus_conv1d_f16_k3(
|
|
254
117
|
const __fp16* input,
|
|
255
118
|
const __fp16* weight,
|
|
@@ -261,26 +124,8 @@ void cactus_conv1d_f16_k3(
|
|
|
261
124
|
size_t stride
|
|
262
125
|
);
|
|
263
126
|
|
|
264
|
-
void
|
|
265
|
-
|
|
266
|
-
const float* weight,
|
|
267
|
-
float* output,
|
|
268
|
-
size_t N, size_t L,
|
|
269
|
-
size_t C_in, size_t C_out,
|
|
270
|
-
size_t stride
|
|
271
|
-
);
|
|
272
|
-
|
|
273
|
-
void cactus_conv1d_f16_k3(
|
|
274
|
-
const __fp16* input,
|
|
275
|
-
const __fp16* weight,
|
|
276
|
-
__fp16* output,
|
|
277
|
-
size_t N, size_t L,
|
|
278
|
-
size_t C_in, size_t C_out,
|
|
279
|
-
size_t stride
|
|
280
|
-
);
|
|
281
|
-
|
|
282
|
-
void cactus_bilinear_interpolation_fp32(const float* input, float* output, size_t src_height, size_t src_width, size_t embed_dim,
|
|
283
|
-
size_t dst_height, size_t dst_width);
|
|
127
|
+
void cactus_bilinear_interpolation_f16(const __fp16* input, __fp16* output, size_t src_height, size_t src_width, size_t embed_dim,
|
|
128
|
+
size_t dst_height, size_t dst_width);
|
|
284
129
|
|
|
285
130
|
void cactus_sample_f32(const float* logits, uint32_t* output, size_t vocab_size,
|
|
286
131
|
float temperature, float top_p, size_t top_k, size_t random_seed,
|
|
@@ -291,25 +136,30 @@ void cactus_sample_f16(const __fp16* logits, uint32_t* output, size_t vocab_size
|
|
|
291
136
|
const float* bias_values = nullptr, const uint32_t* bias_indices = nullptr,
|
|
292
137
|
size_t bias_count = 0);
|
|
293
138
|
|
|
294
|
-
|
|
295
|
-
void cactus_concat_f32(const float* input1, const float* input2, float* output,
|
|
296
|
-
const size_t* shape1, const size_t* shape2, const size_t* output_shape,
|
|
297
|
-
size_t ndims, int axis);
|
|
298
139
|
void cactus_concat_f16(const __fp16* input1, const __fp16* input2, __fp16* output,
|
|
299
140
|
const size_t* shape1, const size_t* shape2, const size_t* output_shape,
|
|
300
141
|
size_t ndims, int axis);
|
|
301
|
-
void cactus_concat_int8(const int8_t* input1, const int8_t* input2, int8_t* output,
|
|
302
|
-
const size_t* shape1, const size_t* shape2, const size_t* output_shape,
|
|
303
|
-
size_t ndims, int axis);
|
|
304
142
|
|
|
305
143
|
void cactus_int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
|
|
306
144
|
void cactus_fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
|
|
307
|
-
void cactus_dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count, float* computed_scale);
|
|
308
145
|
void cactus_fp16_to_fp32(const __fp16* src, float* dst, size_t count);
|
|
309
146
|
void cactus_fp32_to_fp16(const float* src, __fp16* dst, size_t count);
|
|
310
147
|
void cactus_int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
|
|
311
148
|
void cactus_fp16_to_int8(const __fp16* src, int8_t* dst, size_t count, float scale = 1.0f);
|
|
312
149
|
float cactus_fp16_max_abs(const __fp16* src, size_t count);
|
|
313
|
-
void cactus_int32_to_fp16_scaled(const int32_t* src, __fp16* dst, size_t count, float scale);
|
|
314
150
|
|
|
315
|
-
|
|
151
|
+
void cactus_quantize_kv_fp16_to_int8(
|
|
152
|
+
const __fp16* src,
|
|
153
|
+
int8_t* dst,
|
|
154
|
+
float* scales,
|
|
155
|
+
size_t seq_len, size_t kv_heads, size_t head_dim,
|
|
156
|
+
size_t group_size = KV_QUANT_GROUP_SIZE);
|
|
157
|
+
|
|
158
|
+
inline size_t kv_scales_count(size_t seq_len, size_t kv_heads, size_t head_dim, size_t group_size = KV_QUANT_GROUP_SIZE) {
|
|
159
|
+
size_t num_groups = (head_dim + group_size - 1) / group_size;
|
|
160
|
+
return seq_len * kv_heads * num_groups;
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
void cactus_unpack_int4_to_int8(const uint8_t* packed, int8_t* unpacked, size_t unpacked_count);
|
|
164
|
+
|
|
165
|
+
#endif
|