cactus-react-native 1.4.0 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. package/README.md +212 -27
  2. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  3. package/cpp/HybridCactus.cpp +119 -0
  4. package/cpp/HybridCactus.hpp +13 -0
  5. package/cpp/cactus_ffi.h +24 -0
  6. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +24 -0
  7. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +41 -1
  8. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +66 -48
  9. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
  10. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +102 -21
  11. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +45 -195
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +399 -140
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  14. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +24 -0
  15. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +41 -1
  16. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +66 -48
  17. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
  18. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +102 -21
  19. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +45 -195
  20. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +399 -140
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  22. package/lib/module/api/Database.js +0 -92
  23. package/lib/module/api/Database.js.map +1 -1
  24. package/lib/module/classes/CactusLM.js +33 -15
  25. package/lib/module/classes/CactusLM.js.map +1 -1
  26. package/lib/module/classes/CactusSTT.js +90 -15
  27. package/lib/module/classes/CactusSTT.js.map +1 -1
  28. package/lib/module/hooks/useCactusLM.js +14 -5
  29. package/lib/module/hooks/useCactusLM.js.map +1 -1
  30. package/lib/module/hooks/useCactusSTT.js +100 -4
  31. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  32. package/lib/module/index.js.map +1 -1
  33. package/lib/module/models.js +336 -0
  34. package/lib/module/models.js.map +1 -0
  35. package/lib/module/native/Cactus.js +37 -0
  36. package/lib/module/native/Cactus.js.map +1 -1
  37. package/lib/module/types/CactusLM.js +2 -0
  38. package/lib/module/types/CactusSTT.js +2 -0
  39. package/lib/module/types/common.js +2 -0
  40. package/lib/module/types/{CactusModel.js.map → common.js.map} +1 -1
  41. package/lib/typescript/src/api/Database.d.ts +0 -6
  42. package/lib/typescript/src/api/Database.d.ts.map +1 -1
  43. package/lib/typescript/src/classes/CactusLM.d.ts +7 -3
  44. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  45. package/lib/typescript/src/classes/CactusSTT.d.ts +13 -4
  46. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  47. package/lib/typescript/src/hooks/useCactusLM.d.ts +2 -2
  48. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  49. package/lib/typescript/src/hooks/useCactusSTT.d.ts +12 -4
  50. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  51. package/lib/typescript/src/index.d.ts +2 -3
  52. package/lib/typescript/src/index.d.ts.map +1 -1
  53. package/lib/typescript/src/models.d.ts +6 -0
  54. package/lib/typescript/src/models.d.ts.map +1 -0
  55. package/lib/typescript/src/native/Cactus.d.ts +6 -1
  56. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  57. package/lib/typescript/src/specs/Cactus.nitro.d.ts +5 -0
  58. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  59. package/lib/typescript/src/types/CactusLM.d.ts +2 -0
  60. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  61. package/lib/typescript/src/types/CactusSTT.d.ts +20 -0
  62. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  63. package/lib/typescript/src/types/common.d.ts +28 -0
  64. package/lib/typescript/src/types/common.d.ts.map +1 -0
  65. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
  66. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +5 -0
  67. package/package.json +1 -1
  68. package/src/api/Database.ts +0 -133
  69. package/src/classes/CactusLM.ts +49 -17
  70. package/src/classes/CactusSTT.ts +118 -17
  71. package/src/hooks/useCactusLM.ts +25 -5
  72. package/src/hooks/useCactusSTT.ts +117 -5
  73. package/src/index.tsx +6 -2
  74. package/src/models.ts +344 -0
  75. package/src/native/Cactus.ts +55 -0
  76. package/src/specs/Cactus.nitro.ts +5 -0
  77. package/src/types/CactusLM.ts +3 -0
  78. package/src/types/CactusSTT.ts +26 -0
  79. package/src/types/common.ts +28 -0
  80. package/lib/module/types/CactusModel.js +0 -2
  81. package/lib/module/types/CactusSTTModel.js +0 -2
  82. package/lib/module/types/CactusSTTModel.js.map +0 -1
  83. package/lib/typescript/src/types/CactusModel.d.ts +0 -13
  84. package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
  85. package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
  86. package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
  87. package/src/types/CactusModel.ts +0 -15
  88. package/src/types/CactusSTTModel.ts +0 -10
@@ -11,6 +11,7 @@
11
11
  #include <mutex>
12
12
  #include <sstream>
13
13
  #include <iostream>
14
+ #include <arm_neon.h>
14
15
 
15
16
  namespace cactus {
16
17
 
@@ -96,9 +97,10 @@ namespace GraphFile {
96
97
  }
97
98
 
98
99
  enum class Precision {
99
- INT8,
100
+ INT8,
100
101
  FP16,
101
- FP32
102
+ FP32,
103
+ INT4
102
104
  };
103
105
 
104
106
  enum class ComputeBackend {
@@ -112,7 +114,7 @@ enum class OpType {
112
114
  MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
113
115
  BILINEAR_INTERPOLATION,
114
116
  SUM, MEAN, VARIANCE, MIN, MAX,
115
- RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
117
+ RMS_NORM, ROPE, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3,
116
118
  SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
117
119
  SILU, GELU, GELU_ERF,
118
120
  SAMPLE, CONCAT,
@@ -122,27 +124,38 @@ enum class OpType {
122
124
  };
123
125
 
124
126
  struct PrecisionTraits {
127
+ // Returns in-memory element size (INT4 unpacks to INT8, so returns 1)
125
128
  static constexpr size_t size_of(Precision prec) {
126
129
  switch (prec) {
127
130
  case Precision::INT8: return 1;
128
131
  case Precision::FP16: return 2;
129
132
  case Precision::FP32: return 4;
133
+ case Precision::INT4: return 1;
130
134
  }
131
135
  return 1;
132
136
  }
133
-
137
+
138
+ static constexpr size_t packed_size_of(Precision prec, size_t count) {
139
+ switch (prec) {
140
+ case Precision::INT4: return (count + 1) / 2;
141
+ default: return count * size_of(prec);
142
+ }
143
+ }
144
+
134
145
  static constexpr bool is_integer(Precision prec) {
135
146
  switch (prec) {
136
147
  case Precision::INT8: return true;
148
+ case Precision::INT4: return true;
137
149
  case Precision::FP16: return false;
138
150
  case Precision::FP32: return false;
139
151
  }
140
152
  return true;
141
153
  }
142
-
154
+
143
155
  static constexpr bool is_floating_point(Precision prec) {
144
156
  switch (prec) {
145
157
  case Precision::INT8: return false;
158
+ case Precision::INT4: return false;
146
159
  case Precision::FP16: return true;
147
160
  case Precision::FP32: return true;
148
161
  }
@@ -153,8 +166,6 @@ struct PrecisionTraits {
153
166
  namespace Quantization {
154
167
  void int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
155
168
  void fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
156
- void dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count,
157
- float* computed_scale);
158
169
  void fp16_to_fp32(const __fp16* src, float* dst, size_t count);
159
170
  void fp32_to_fp16(const float* src, __fp16* dst, size_t count);
160
171
  void int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
@@ -188,10 +199,17 @@ struct BufferDesc {
188
199
  void* external_data;
189
200
  char* pooled_data;
190
201
  Precision precision;
191
- float quantization_scale;
202
+
203
+ size_t group_size = 0;
204
+ size_t num_groups = 0;
205
+ void* scales_data = nullptr;
206
+ std::unique_ptr<char[]> owned_scales;
207
+
208
+ const void* packed_int4_data = nullptr;
209
+ size_t packed_int4_size = 0;
192
210
 
193
211
  BufferDesc();
194
- BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8, float scale = 1.0f);
212
+ BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8);
195
213
  ~BufferDesc();
196
214
 
197
215
  BufferDesc(BufferDesc&& other) noexcept;
@@ -209,6 +227,28 @@ struct BufferDesc {
209
227
  template<typename T>
210
228
  const T* data_as() const { return static_cast<const T*>(get_data()); }
211
229
 
230
+ const __fp16* scales_as_fp16() const {
231
+ return reinterpret_cast<const __fp16*>(scales_data);
232
+ }
233
+ bool is_grouped_int8() const {
234
+ return precision == Precision::INT8 && group_size > 0;
235
+ }
236
+ bool is_packed_int4() const {
237
+ return packed_int4_data != nullptr && packed_int4_size > 0;
238
+ }
239
+ const uint8_t* packed_int4_as_uint8() const {
240
+ return reinterpret_cast<const uint8_t*>(packed_int4_data);
241
+ }
242
+ void set_grouped_scales(size_t gs, size_t ng, void* scales_ptr) {
243
+ group_size = gs;
244
+ num_groups = ng;
245
+ scales_data = scales_ptr;
246
+ }
247
+ void set_packed_int4(const void* packed_data, size_t packed_size) {
248
+ packed_int4_data = packed_data;
249
+ packed_int4_size = packed_size;
250
+ }
251
+
212
252
  void allocate();
213
253
  void allocate_from_pool(BufferPool& pool);
214
254
  void release_to_pool(BufferPool& pool);
@@ -247,6 +287,14 @@ struct OpParams {
247
287
 
248
288
  std::vector<float> bias_values;
249
289
  std::vector<uint32_t> bias_indices;
290
+
291
+ const int8_t* cached_keys_int8 = nullptr;
292
+ const int8_t* cached_values_int8 = nullptr;
293
+ const float* cached_k_scales = nullptr;
294
+ const float* cached_v_scales = nullptr;
295
+ size_t cache_seq_len = 0;
296
+ size_t num_kv_heads = 0;
297
+ size_t head_dim = 0;
250
298
  };
251
299
 
252
300
  struct GraphNode {
@@ -326,7 +374,7 @@ public:
326
374
  size_t precision_cast(size_t input, Precision target_precision);
327
375
 
328
376
  size_t add(size_t input1, size_t input2);
329
- size_t add_clipped(size_t input1, size_t input2); // For FP16 residual connections (Gemma)
377
+ size_t add_clipped(size_t input1, size_t input2);
330
378
  size_t subtract(size_t input1, size_t input2);
331
379
  size_t multiply(size_t input1, size_t input2);
332
380
  size_t divide(size_t input1, size_t input2);
@@ -361,8 +409,12 @@ public:
361
409
  size_t gather(size_t embeddings, size_t indices);
362
410
  size_t mmap_embeddings(const std::string& filename);
363
411
  size_t mmap_weights(const std::string& filename);
364
- size_t load_weights(const std::string& filename);
365
- void set_quantization_scale(size_t node_id, float scale);
412
+ size_t load_weights(const std::string& filename);
413
+ void set_grouped_scales(size_t node_id, size_t group_size, size_t num_groups, void* scales_ptr);
414
+
415
+ void release_weight_pages(size_t node_id);
416
+ void prefetch_weight_pages(size_t node_id);
417
+ void release_all_weight_pages();
366
418
  size_t embedding(const std::string& filename, size_t indices);
367
419
  size_t embedding(size_t embedding_tensor, size_t indices);
368
420
  size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
@@ -376,6 +428,11 @@ public:
376
428
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, ComputeBackend backend = ComputeBackend::CPU);
377
429
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, size_t window_size, ComputeBackend backend = ComputeBackend::CPU);
378
430
 
431
+ size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
432
+ const int8_t* cached_keys, const int8_t* cached_values,
433
+ const float* k_scales, const float* v_scales,
434
+ size_t cache_len, size_t num_kv_heads, size_t head_dim);
435
+
379
436
  size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
380
437
  size_t conv1d_k3(size_t input, size_t weight, size_t stride);
381
438
 
@@ -392,6 +449,8 @@ public:
392
449
  void execute(const std::string& profile_file = "");
393
450
  void hard_reset();
394
451
  void soft_reset();
452
+ void soft_reset_keep_pool();
453
+ void set_prefill_mode(bool enabled) { prefill_mode_ = enabled; }
395
454
 
396
455
  void register_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
397
456
  void capture_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
@@ -410,8 +469,10 @@ private:
410
469
  size_t next_node_id_;
411
470
  std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
412
471
  std::unordered_map<std::string, size_t> weight_cache_;
472
+ std::unordered_map<size_t, size_t> node_to_mapped_file_;
413
473
  std::vector<DebugNodeEntry> debug_nodes_;
414
474
  BufferPool buffer_pool_;
475
+ bool prefill_mode_ = false;
415
476
  };
416
477
 
417
478
 
@@ -430,25 +491,36 @@ namespace GraphFile {
430
491
  public:
431
492
  MappedFile(const std::string& filename);
432
493
  ~MappedFile();
433
-
494
+
434
495
  MappedFile(const MappedFile&) = delete;
435
496
  MappedFile& operator=(const MappedFile&) = delete;
436
497
  MappedFile(MappedFile&& other) noexcept;
437
498
  MappedFile& operator=(MappedFile&& other) noexcept;
438
-
499
+
439
500
  const std::vector<size_t>& shape() const;
440
501
  Precision precision() const;
502
+ Precision effective_precision() const {
503
+ return is_int4_ ? Precision::INT8 : precision_;
504
+ }
441
505
  size_t byte_size() const;
442
- float quantization_scale() const;
443
-
506
+
507
+ size_t group_size() const { return group_size_; }
508
+ size_t num_groups() const { return num_groups_; }
509
+ const void* scales_data() const;
510
+ const void* raw_packed_data() const; // Get raw mmap'd data without unpacking (for INT4)
511
+ bool is_int4() const { return is_int4_; }
512
+
444
513
  void* data();
445
514
  const void* data() const;
446
-
515
+
447
516
  template<typename T>
448
517
  const T* typed_data() const;
449
-
518
+
450
519
  LoadedNode load_into_graph(CactusGraph& graph) const;
451
-
520
+
521
+ void release_pages();
522
+ void prefetch_pages();
523
+
452
524
  private:
453
525
  int fd_;
454
526
  void* mapped_data_;
@@ -456,10 +528,19 @@ namespace GraphFile {
456
528
  std::vector<size_t> shape_;
457
529
  Precision precision_;
458
530
  size_t byte_size_;
459
- float quantization_scale_;
531
+ size_t group_size_ = 0;
532
+ size_t num_groups_ = 0;
533
+ size_t scales_offset_ = 0;
534
+ size_t scales_bytes_ = 0;
535
+ uint32_t version_ = 1;
536
+ uint32_t alignment_ = 32;
537
+ bool is_int4_ = false;
538
+ mutable std::unique_ptr<int8_t[]> unpacked_int4_data_;
460
539
  void parse_header();
540
+ void apply_madvise_hints();
541
+ void unpack_int4_if_needed() const;
461
542
  };
462
-
543
+
463
544
  MappedFile mmap_load(const std::string& filename);
464
545
  }
465
546
 
@@ -15,12 +15,7 @@ enum class ScalarOpType {
15
15
  SIN
16
16
  };
17
17
 
18
-
19
- void cactus_add_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
20
- void cactus_subtract_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
21
- void cactus_multiply_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
22
- void cactus_divide_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
23
-
18
+ constexpr size_t KV_QUANT_GROUP_SIZE = 128;
24
19
 
25
20
  void cactus_add_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
26
21
  void cactus_add_f16_clipped(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
@@ -28,27 +23,6 @@ void cactus_subtract_f16(const __fp16* a, const __fp16* b, __fp16* output, size_
28
23
  void cactus_multiply_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
29
24
  void cactus_divide_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
30
25
 
31
-
32
- void cactus_add_f32(const float* a, const float* b, float* output, size_t num_elements);
33
- void cactus_subtract_f32(const float* a, const float* b, float* output, size_t num_elements);
34
- void cactus_multiply_f32(const float* a, const float* b, float* output, size_t num_elements);
35
- void cactus_divide_f32(const float* a, const float* b, float* output, size_t num_elements);
36
-
37
-
38
- void cactus_add_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
39
- const size_t* a_strides, const size_t* b_strides,
40
- const size_t* output_shape, size_t ndim);
41
- void cactus_subtract_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
42
- const size_t* a_strides, const size_t* b_strides,
43
- const size_t* output_shape, size_t ndim);
44
- void cactus_multiply_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
45
- const size_t* a_strides, const size_t* b_strides,
46
- const size_t* output_shape, size_t ndim);
47
- void cactus_divide_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
48
- const size_t* a_strides, const size_t* b_strides,
49
- const size_t* output_shape, size_t ndim);
50
-
51
-
52
26
  void cactus_add_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* output,
53
27
  const size_t* a_strides, const size_t* b_strides,
54
28
  const size_t* output_shape, size_t ndim);
@@ -62,159 +36,72 @@ void cactus_divide_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* outpu
62
36
  const size_t* a_strides, const size_t* b_strides,
63
37
  const size_t* output_shape, size_t ndim);
64
38
 
65
-
66
- void cactus_add_broadcast_f32(const float* a, const float* b, float* output,
67
- const size_t* a_strides, const size_t* b_strides,
68
- const size_t* output_shape, size_t ndim);
69
- void cactus_subtract_broadcast_f32(const float* a, const float* b, float* output,
70
- const size_t* a_strides, const size_t* b_strides,
71
- const size_t* output_shape, size_t ndim);
72
- void cactus_multiply_broadcast_f32(const float* a, const float* b, float* output,
73
- const size_t* a_strides, const size_t* b_strides,
74
- const size_t* output_shape, size_t ndim);
75
- void cactus_divide_broadcast_f32(const float* a, const float* b, float* output,
76
- const size_t* a_strides, const size_t* b_strides,
77
- const size_t* output_shape, size_t ndim);
78
-
79
-
80
- void cactus_scalar_op_int8(const int8_t* input, int8_t* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
81
39
  void cactus_scalar_op_f16(const __fp16* input, __fp16* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
82
- void cactus_scalar_op_f32(const float* input, float* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
83
-
84
40
 
85
- void cactus_matmul_int8(const int8_t* a, const int8_t* b_transposed, int8_t* c,
86
- size_t M, size_t K, size_t N,
87
- float a_scale, float b_scale, float c_scale);
41
+ void cactus_matmul_int8(const int8_t* A, const float* A_scales,
42
+ const int8_t* B, const __fp16* B_scales,
43
+ __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
88
44
 
89
- #if defined(__ARM_FEATURE_MATMUL_INT8)
90
- void cactus_matmul_int8_to_int32_i8mm(const int8_t* a, const int8_t* b_transposed, int32_t* c,
91
- size_t M, size_t K, size_t N);
92
- #define cactus_matmul_int8_to_int32 cactus_matmul_int8_to_int32_i8mm
93
- #else
94
- void cactus_matmul_int8_to_int32(const int8_t* a, const int8_t* b_transposed, int32_t* c,
95
- size_t M, size_t K, size_t N);
96
- #endif
45
+ void cactus_matmul_int4(const int8_t* A, const float* A_scales,
46
+ const uint8_t* B_packed, const __fp16* B_scales,
47
+ __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
97
48
 
98
49
  void cactus_matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c,
99
50
  size_t M, size_t K, size_t N);
100
51
 
101
- void cactus_matmul_f32(const float* a, const float* b_transposed, float* c,
102
- size_t M, size_t K, size_t N);
103
-
104
-
105
- void cactus_transpose_2d_int8(const int8_t* source, int8_t* destination,
106
- size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
107
52
  void cactus_transpose_2d_f16(const __fp16* source, __fp16* destination,
108
53
  size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
109
- void cactus_transpose_2d_f32(const float* source, float* destination,
110
- size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
111
-
112
- void cactus_transpose_int8(const int8_t* source, int8_t* destination, const size_t* shape,
113
- const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
114
54
  void cactus_transpose_f16(const __fp16* source, __fp16* destination, const size_t* shape,
115
55
  const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
116
- void cactus_transpose_f32(const float* source, float* destination, const size_t* shape,
117
- const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
118
56
 
119
- int64_t cactus_sum_all_int8(const int8_t* data, size_t num_elements);
120
- void cactus_sum_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
121
57
  double cactus_sum_all_f16(const __fp16* data, size_t num_elements);
122
- double cactus_sum_all_f32(const float* data, size_t num_elements);
123
- void cactus_sum_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
58
+ void cactus_sum_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
124
59
 
125
- double cactus_mean_all_int8(const int8_t* data, size_t num_elements);
126
- void cactus_mean_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
127
60
  double cactus_mean_all_f16(const __fp16* data, size_t num_elements);
128
61
  void cactus_mean_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
129
- double cactus_mean_all_f32(const float* data, size_t num_elements);
130
- void cactus_mean_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
131
62
 
132
- double cactus_variance_all_int8(const int8_t* data, size_t num_elements);
133
- void cactus_variance_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
134
- double cactus_variance_all_f32(const float* data, size_t num_elements);
135
- void cactus_variance_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
63
+ double cactus_variance_all_f16(const __fp16* data, size_t num_elements);
64
+ void cactus_variance_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
136
65
 
137
- int64_t cactus_min_all_int8(const int8_t* data, size_t num_elements);
138
- void cactus_min_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
139
- float cactus_min_all_f32(const float* data, size_t num_elements);
140
- void cactus_min_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
66
+ __fp16 cactus_min_all_f16(const __fp16* data, size_t num_elements);
67
+ void cactus_min_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
141
68
 
142
- int64_t cactus_max_all_int8(const int8_t* data, size_t num_elements);
143
- void cactus_max_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
144
- float cactus_max_all_f32(const float* data, size_t num_elements);
145
- void cactus_max_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
69
+ __fp16 cactus_max_all_f16(const __fp16* data, size_t num_elements);
70
+ void cactus_max_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
146
71
 
147
72
  void cactus_rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* output,
148
73
  size_t batch_size, size_t dims, float eps);
149
-
150
- void cactus_rms_norm_f32(const float* input, const float* weight, float* output,
151
- size_t batch_size, size_t dims, float eps);
152
-
153
- void cactus_rms_norm_i8_f32(const int8_t* input, const float* weight, float* output,
154
- size_t batch_size, size_t dims, float eps, float input_scale);
155
74
 
156
75
  void cactus_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
157
76
  size_t num_heads, size_t head_dim, size_t start_pos, float theta);
158
77
 
159
- void cactus_rope_f32(const float* input, float* output, size_t batch_size, size_t seq_len,
160
- size_t num_heads, size_t head_dim, size_t start_pos, float theta);
161
-
162
- void cactus_rope_i8_f32_i8(const int8_t* input, int8_t* output, size_t batch_size, size_t seq_len,
163
- size_t num_heads, size_t head_dim, size_t start_pos, float theta,
164
- float input_scale, float output_scale);
165
-
166
- void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
78
+ void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
167
79
  size_t seq_len, size_t vocab_size);
168
80
 
169
- void cactus_softmax_f32(const float* input, float* output, size_t batch_size,
170
- size_t seq_len, size_t vocab_size);
171
-
172
- void cactus_silu_f32(const float* input, float* output, size_t num_elements);
173
81
  void cactus_silu_f16(const __fp16* input, __fp16* output, size_t num_elements);
174
- void cactus_silu_int8(const int8_t* input, int8_t* output, size_t num_elements,
175
- float input_scale, float output_scale);
176
82
 
177
- void cactus_gelu_f32(const float* input, float* output, size_t num_elements);
178
83
  void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
179
- void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
180
- float input_scale, float output_scale);
181
84
 
182
- void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
183
85
  void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
184
- void cactus_gelu_int8_erf(
185
- const int8_t* input,
186
- int8_t* output,
187
- size_t num_elements,
188
- float scale_in,
189
- float scale_out);
190
-
191
-
192
- void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
193
- size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
194
- size_t head_dim, float scale, const int8_t* mask,
195
- float q_scale, float k_scale, float v_scale, float output_scale, size_t position_offset = 0, size_t window_size = 0,
196
- bool is_causal = true);
197
86
 
198
87
  void cactus_attention_f16(const __fp16* queries, const __fp16* keys, const __fp16* values, __fp16* output,
199
88
  size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
200
89
  size_t head_dim, float scale, const __fp16* mask, size_t position_offset = 0, size_t window_size = 0,
201
90
  bool is_causal = true);
202
91
 
203
- void cactus_attention_f32(const float* queries, const float* keys, const float* values, float* output,
204
- size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
205
- size_t head_dim, float scale, const float* mask, size_t position_offset = 0, size_t window_size = 0,
206
- bool is_causal = true);
207
-
208
-
209
- void cactus_conv1d_causal_depthwise_f32(
210
- const float* input,
211
- const float* weight,
212
- float* output,
213
- size_t N,
214
- size_t L,
215
- size_t C,
216
- size_t K,
217
- size_t dilation);
92
+ void cactus_attention_hybrid_int8_fp16(
93
+ const __fp16* queries,
94
+ const int8_t* keys_cached,
95
+ const int8_t* values_cached,
96
+ const float* k_scales,
97
+ const float* v_scales,
98
+ const __fp16* keys_new,
99
+ const __fp16* values_new,
100
+ __fp16* output,
101
+ size_t batch_size, size_t seq_len, size_t cache_len, size_t new_len,
102
+ size_t num_q_heads, size_t num_kv_heads, size_t head_dim,
103
+ float scale, size_t position_offset = 0, bool is_causal = true,
104
+ size_t group_size = KV_QUANT_GROUP_SIZE);
218
105
 
219
106
  void cactus_conv1d_causal_depthwise_f16(
220
107
  const __fp16* input,
@@ -226,30 +113,6 @@ void cactus_conv1d_causal_depthwise_f16(
226
113
  size_t K,
227
114
  size_t dilation);
228
115
 
229
- void cactus_conv1d_causal_depthwise_int8(
230
- const int8_t* input,
231
- const int8_t* weight,
232
- int8_t* output,
233
- size_t N,
234
- size_t L,
235
- size_t C,
236
- size_t K,
237
- size_t dilation,
238
- float input_scale,
239
- float weight_scale,
240
- float output_scale);
241
-
242
- void cactus_conv1d_f32_k3(
243
- const float* input,
244
- const float* weight,
245
- float* output,
246
- size_t N,
247
- size_t L,
248
- size_t C_in,
249
- size_t C_out,
250
- size_t stride
251
- );
252
-
253
116
  void cactus_conv1d_f16_k3(
254
117
  const __fp16* input,
255
118
  const __fp16* weight,
@@ -261,26 +124,8 @@ void cactus_conv1d_f16_k3(
261
124
  size_t stride
262
125
  );
263
126
 
264
- void cactus_conv1d_f32_k3(
265
- const float* input,
266
- const float* weight,
267
- float* output,
268
- size_t N, size_t L,
269
- size_t C_in, size_t C_out,
270
- size_t stride
271
- );
272
-
273
- void cactus_conv1d_f16_k3(
274
- const __fp16* input,
275
- const __fp16* weight,
276
- __fp16* output,
277
- size_t N, size_t L,
278
- size_t C_in, size_t C_out,
279
- size_t stride
280
- );
281
-
282
- void cactus_bilinear_interpolation_fp32(const float* input, float* output, size_t src_height, size_t src_width, size_t embed_dim,
283
- size_t dst_height, size_t dst_width);
127
+ void cactus_bilinear_interpolation_f16(const __fp16* input, __fp16* output, size_t src_height, size_t src_width, size_t embed_dim,
128
+ size_t dst_height, size_t dst_width);
284
129
 
285
130
  void cactus_sample_f32(const float* logits, uint32_t* output, size_t vocab_size,
286
131
  float temperature, float top_p, size_t top_k, size_t random_seed,
@@ -291,25 +136,30 @@ void cactus_sample_f16(const __fp16* logits, uint32_t* output, size_t vocab_size
291
136
  const float* bias_values = nullptr, const uint32_t* bias_indices = nullptr,
292
137
  size_t bias_count = 0);
293
138
 
294
-
295
- void cactus_concat_f32(const float* input1, const float* input2, float* output,
296
- const size_t* shape1, const size_t* shape2, const size_t* output_shape,
297
- size_t ndims, int axis);
298
139
  void cactus_concat_f16(const __fp16* input1, const __fp16* input2, __fp16* output,
299
140
  const size_t* shape1, const size_t* shape2, const size_t* output_shape,
300
141
  size_t ndims, int axis);
301
- void cactus_concat_int8(const int8_t* input1, const int8_t* input2, int8_t* output,
302
- const size_t* shape1, const size_t* shape2, const size_t* output_shape,
303
- size_t ndims, int axis);
304
142
 
305
143
  void cactus_int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
306
144
  void cactus_fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
307
- void cactus_dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count, float* computed_scale);
308
145
  void cactus_fp16_to_fp32(const __fp16* src, float* dst, size_t count);
309
146
  void cactus_fp32_to_fp16(const float* src, __fp16* dst, size_t count);
310
147
  void cactus_int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
311
148
  void cactus_fp16_to_int8(const __fp16* src, int8_t* dst, size_t count, float scale = 1.0f);
312
149
  float cactus_fp16_max_abs(const __fp16* src, size_t count);
313
- void cactus_int32_to_fp16_scaled(const int32_t* src, __fp16* dst, size_t count, float scale);
314
150
 
315
- #endif
151
+ void cactus_quantize_kv_fp16_to_int8(
152
+ const __fp16* src,
153
+ int8_t* dst,
154
+ float* scales,
155
+ size_t seq_len, size_t kv_heads, size_t head_dim,
156
+ size_t group_size = KV_QUANT_GROUP_SIZE);
157
+
158
+ inline size_t kv_scales_count(size_t seq_len, size_t kv_heads, size_t head_dim, size_t group_size = KV_QUANT_GROUP_SIZE) {
159
+ size_t num_groups = (head_dim + group_size - 1) / group_size;
160
+ return seq_len * kv_heads * num_groups;
161
+ }
162
+
163
+ void cactus_unpack_int4_to_int8(const uint8_t* packed, int8_t* unpacked, size_t unpacked_count);
164
+
165
+ #endif