cactus-react-native 1.10.3 → 1.12.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. package/README.md +199 -40
  2. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  3. package/cpp/HybridCactus.cpp +131 -2
  4. package/cpp/HybridCactus.hpp +15 -0
  5. package/cpp/cactus_ffi.h +240 -2
  6. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +240 -2
  7. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +940 -109
  8. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +175 -25
  9. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +48 -21
  10. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +79 -7
  11. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +122 -9
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +191 -2
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  14. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +240 -2
  15. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +940 -109
  16. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +175 -25
  17. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +48 -21
  18. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +79 -7
  19. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +122 -9
  20. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +191 -2
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  22. package/lib/module/classes/{CactusVAD.js → CactusAudio.js} +19 -6
  23. package/lib/module/classes/CactusAudio.js.map +1 -0
  24. package/lib/module/classes/CactusLM.js +25 -0
  25. package/lib/module/classes/CactusLM.js.map +1 -1
  26. package/lib/module/hooks/{useCactusVAD.js → useCactusAudio.js} +50 -20
  27. package/lib/module/hooks/useCactusAudio.js.map +1 -0
  28. package/lib/module/index.js +2 -2
  29. package/lib/module/index.js.map +1 -1
  30. package/lib/module/modelRegistry.js +5 -3
  31. package/lib/module/modelRegistry.js.map +1 -1
  32. package/lib/module/native/Cactus.js +81 -2
  33. package/lib/module/native/Cactus.js.map +1 -1
  34. package/lib/module/types/CactusAudio.js +4 -0
  35. package/lib/module/types/{CactusVAD.js.map → CactusAudio.js.map} +1 -1
  36. package/lib/typescript/src/classes/CactusAudio.d.ts +22 -0
  37. package/lib/typescript/src/classes/CactusAudio.d.ts.map +1 -0
  38. package/lib/typescript/src/classes/CactusLM.d.ts +2 -1
  39. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  40. package/lib/typescript/src/hooks/useCactusAudio.d.ts +17 -0
  41. package/lib/typescript/src/hooks/useCactusAudio.d.ts.map +1 -0
  42. package/lib/typescript/src/index.d.ts +4 -4
  43. package/lib/typescript/src/index.d.ts.map +1 -1
  44. package/lib/typescript/src/native/Cactus.d.ts +9 -3
  45. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  46. package/lib/typescript/src/specs/Cactus.nitro.d.ts +3 -0
  47. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  48. package/lib/typescript/src/types/CactusAudio.d.ts +63 -0
  49. package/lib/typescript/src/types/CactusAudio.d.ts.map +1 -0
  50. package/lib/typescript/src/types/CactusLM.d.ts +15 -0
  51. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  52. package/lib/typescript/src/types/CactusSTT.d.ts +1 -0
  53. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  54. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +3 -0
  55. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +3 -0
  56. package/package.json +1 -1
  57. package/src/classes/{CactusVAD.ts → CactusAudio.ts} +32 -13
  58. package/src/classes/CactusLM.ts +36 -0
  59. package/src/hooks/{useCactusVAD.ts → useCactusAudio.ts} +65 -28
  60. package/src/index.tsx +16 -9
  61. package/src/modelRegistry.ts +20 -6
  62. package/src/native/Cactus.ts +118 -3
  63. package/src/specs/Cactus.nitro.ts +16 -0
  64. package/src/types/CactusAudio.ts +73 -0
  65. package/src/types/CactusLM.ts +17 -0
  66. package/src/types/CactusSTT.ts +1 -0
  67. package/lib/module/classes/CactusVAD.js.map +0 -1
  68. package/lib/module/hooks/useCactusVAD.js.map +0 -1
  69. package/lib/module/types/CactusVAD.js +0 -4
  70. package/lib/typescript/src/classes/CactusVAD.d.ts +0 -20
  71. package/lib/typescript/src/classes/CactusVAD.d.ts.map +0 -1
  72. package/lib/typescript/src/hooks/useCactusVAD.d.ts +0 -15
  73. package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +0 -1
  74. package/lib/typescript/src/types/CactusVAD.d.ts +0 -34
  75. package/lib/typescript/src/types/CactusVAD.d.ts.map +0 -1
  76. package/src/types/CactusVAD.ts +0 -39
@@ -122,13 +122,14 @@ enum class Activation {
122
122
  enum class OpType {
123
123
  INPUT, PRECISION_CAST,
124
124
  ADD, ADD_CLIPPED, SUBTRACT, MULTIPLY, DIVIDE,
125
+ ABS, POW, FLATTEN, VIEW,
125
126
  MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
126
127
  BILINEAR_INTERPOLATION,
127
128
  SUM, MEAN, VARIANCE, MIN, MAX,
128
129
  RMS_NORM, ROPE, ROPE_GPTJ, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, REL_POS_BIAS, CONV1D_CAUSAL, CONV1D_K3, CONV1D_K7S3, CONV1D, CONV1D_SAME_DEPTHWISE_K9, CONV1D_POINTWISE, CONV2D_K3S2P1, CONV2D_DEPTHWISE_K3S2P1, CONV2D_POINTWISE_1X1, GLU, BATCHNORM,
129
130
  SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN, SCALAR_LOG,
130
131
  RELU, SILU, GELU, GELU_ERF, SIGMOID, TANH,
131
- SAMPLE, CONCAT,
132
+ SAMPLE, CONCAT, CAT,
132
133
  SCATTER_TOPK,
133
134
  TOPK, LAYERNORM, GROUPNORM,
134
135
  MOE_LAYER,
@@ -136,7 +137,17 @@ enum class OpType {
136
137
  PERSISTENT,
137
138
  QUANTIZE_ACTIVATIONS,
138
139
  LSTM_CELL,
139
- STFT
140
+ GATED_DELTANET_DECODE,
141
+ GATED_DELTANET_PREFILL,
142
+ STFT,
143
+ ALTUP_PREDICT,
144
+ ALTUP_CORRECT,
145
+ GAUSSIAN_TOPK,
146
+ MAXPOOL1D,
147
+ BILSTM_SEQUENCE,
148
+ LEAKY_RELU,
149
+ CONV2D_K3S1P1,
150
+ STATS_POOL
140
151
  };
141
152
 
142
153
  struct PrecisionTraits {
@@ -315,6 +326,7 @@ struct OpParams {
315
326
  size_t window_size = 0;
316
327
  bool is_causal = true;
317
328
  bool attention_mask_is_additive = false;
329
+ float logit_cap = 0.0f;
318
330
  std::vector<size_t> new_shape;
319
331
  std::vector<size_t> permutation;
320
332
  Precision output_precision = Precision::INT8;
@@ -350,6 +362,10 @@ struct OpParams {
350
362
  size_t num_kv_heads = 0;
351
363
  size_t head_dim = 0;
352
364
  size_t num_fft_bins = 0;
365
+ size_t chunk_size = 0;
366
+ size_t num_altup_inputs = 0;
367
+ size_t v_head_dim = 0;
368
+ size_t kernel_size = 0;
353
369
  };
354
370
 
355
371
  struct GraphNode {
@@ -362,6 +378,28 @@ struct GraphNode {
362
378
  GraphNode(size_t node_id, OpType type);
363
379
  };
364
380
 
381
+ using nodes_vector = std::vector<std::unique_ptr<GraphNode>>;
382
+ using node_index_map_t = std::unordered_map<size_t, size_t>;
383
+
384
+ inline const BufferDesc& get_input(const GraphNode& node, size_t idx,
385
+ const nodes_vector& nodes,
386
+ const node_index_map_t& node_index_map) {
387
+ return nodes[node_index_map.at(node.input_ids[idx])]->output_buffer;
388
+ }
389
+
390
// Decomposition of a shape around a single axis:
//   outer     - product of all extents before `axis` (1 when axis == 0)
//   axis_size - extent of `axis` itself
//   inner     - product of all extents after `axis` (1 when axis is last)
struct AxisDims {
    size_t outer, axis_size, inner;

    // Builds the decomposition of `shape` around `axis`.
    //
    // Throws std::out_of_range (via std::vector::at) when `axis` is not a
    // valid index into `shape`, instead of reading past the end of the vector.
    static AxisDims from_shape(const std::vector<size_t>& shape, size_t axis) {
        AxisDims d{};

        // Validate the axis first: every later read is then guaranteed in range.
        d.axis_size = shape.at(axis);

        d.outer = 1;
        for (size_t i = 0; i < axis; ++i) {
            d.outer *= shape[i];
        }

        d.inner = 1;
        for (size_t i = axis + 1; i < shape.size(); ++i) {
            d.inner *= shape[i];
        }

        return d;
    }
};
402
+
365
403
  template<typename T>
366
404
  void dispatch_binary_op(OpType op, const T* lhs, const T* rhs, T* output, size_t count);
367
405
 
@@ -383,6 +421,14 @@ void compute_groupnorm_node(GraphNode& node, const std::vector<std::unique_ptr<G
383
421
  void compute_persistent_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
384
422
  void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
385
423
  void compute_lstm_cell_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
424
+ void compute_gated_deltanet_decode_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
425
+ void compute_gated_deltanet_prefill_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
426
+ void compute_altup_predict_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
427
+ void compute_altup_correct_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
428
+ void compute_maxpool1d_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
429
+ void compute_bilstm_sequence_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
430
+ void compute_conv2d_k3s1p1_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
431
+ void compute_stats_pool_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
386
432
 
387
433
  void shrink_thread_local_buffers();
388
434
  class BufferPool {
@@ -437,7 +483,6 @@ public:
437
483
  size_t multiply(size_t input1, size_t input2);
438
484
  size_t divide(size_t input1, size_t input2);
439
485
 
440
-
441
486
  size_t scalar_add(size_t input, float value);
442
487
  size_t scalar_subtract(size_t input, float value);
443
488
  size_t scalar_multiply(size_t input, float value);
@@ -455,6 +500,11 @@ public:
455
500
  size_t sigmoid(size_t input);
456
501
  size_t tanh(size_t input);
457
502
  size_t glu(size_t input, int axis = -1);
503
+
504
+ size_t abs(size_t input);
505
+ size_t pow(size_t input, float exponent);
506
+ size_t view(size_t input, const std::vector<size_t>& new_shape);
507
+ size_t flatten(size_t input, int start_dim = 0, int end_dim = -1);
458
508
 
459
509
  size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
460
510
  size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
@@ -497,7 +547,9 @@ public:
497
547
  size_t num_experts_per_tok,
498
548
  bool normalize_routing,
499
549
  float epsilon,
500
- float routed_scaling_factor);
550
+ float routed_scaling_factor,
551
+ Activation activation = Activation::SILU,
552
+ size_t per_expert_scale = 0);
501
553
  size_t moe_layer(size_t hidden,
502
554
  size_t routing_probs,
503
555
  size_t topk_indices,
@@ -518,13 +570,15 @@ public:
518
570
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, size_t window_size, ComputeBackend backend = ComputeBackend::CPU);
519
571
  size_t attention_masked(size_t query, size_t key, size_t value, size_t mask, float scale,
520
572
  bool is_causal = true, ComputeBackend backend = ComputeBackend::CPU,
521
- bool additive_mask = false, size_t position_offset = 0, size_t window_size = 0);
573
+ bool additive_mask = false, size_t position_offset = 0, size_t window_size = 0,
574
+ float logit_cap = 0.0f);
522
575
  size_t rel_pos_bias(size_t query, size_t relative_key, float scale);
523
576
 
524
577
  size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
525
578
  const int8_t* cached_keys, const int8_t* cached_values,
526
579
  const float* k_scales, const float* v_scales,
527
- size_t cache_len, size_t num_kv_heads, size_t head_dim, size_t window_size = 0);
580
+ size_t cache_len, size_t num_kv_heads, size_t head_dim,
581
+ size_t window_size = 0, size_t v_head_dim = 0);
528
582
 
529
583
  size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
530
584
  size_t conv1d_k3(size_t input, size_t weight, size_t stride);
@@ -543,12 +597,30 @@ public:
543
597
  size_t conv2d_pointwise_1x1(size_t input, size_t weight, size_t bias);
544
598
 
545
599
  size_t lstm_cell(size_t input, size_t h_prev, size_t c_prev, size_t weight_ih, size_t weight_hh, size_t bias_ih, size_t bias_hh);
600
+ size_t gated_deltanet_decode(size_t query, size_t key, size_t value, size_t gate_log, size_t beta,
601
+ size_t initial_state, float scale = 0.0f);
602
+ size_t gated_deltanet_prefill(size_t query, size_t key, size_t value, size_t gate_log, size_t beta,
603
+ size_t initial_state, size_t chunk_size = 64, float scale = 0.0f);
546
604
  size_t stft(size_t input, size_t weight, size_t stride, size_t num_fft_bins);
547
605
 
606
+ size_t altup_predict(size_t coefs, const size_t* streams, size_t num_streams);
607
+ size_t altup_correct(size_t coefs, size_t innovation, const size_t* predictions, size_t num_predictions);
608
+
609
+ size_t gaussian_topk(size_t input, float ppf);
610
+
611
+ size_t maxpool1d(size_t input, size_t kernel_size, size_t stride);
612
+ size_t leaky_relu(size_t input, float negative_slope = 0.01f);
613
+ size_t bilstm_sequence(size_t input, size_t w_ih_fwd, size_t w_hh_fwd, size_t b_ih_fwd, size_t b_hh_fwd,
614
+ size_t w_ih_bwd, size_t w_hh_bwd, size_t b_ih_bwd, size_t b_hh_bwd);
615
+ size_t conv2d_k3s1p1(size_t input, size_t weight);
616
+ size_t conv2d_k3s1p1(size_t input, size_t weight, size_t bias);
617
+ size_t stats_pool(size_t input);
618
+
548
619
  size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20,
549
620
  const std::unordered_map<uint32_t, float>& logit_bias = {});
550
621
 
551
622
  size_t concat(size_t input1, size_t input2, int axis = 0);
623
+ size_t cat(const std::vector<size_t>& inputs, int axis);
552
624
  size_t scatter_topk(size_t indices, size_t values, size_t num_classes);
553
625
 
554
626
  void set_input(size_t node_id, const void* data, Precision precision);
@@ -653,4 +725,4 @@ namespace GraphFile {
653
725
  };
654
726
  }
655
727
 
656
- #endif
728
+ #endif
@@ -11,7 +11,9 @@ enum class ScalarOpType {
11
11
  SUBTRACT,
12
12
  MULTIPLY,
13
13
  DIVIDE,
14
+ ABS,
14
15
  EXP,
16
+ POW,
15
17
  SQRT,
16
18
  COS,
17
19
  SIN,
@@ -54,6 +56,14 @@ void cactus_matmul_int8(const int8_t* A, const float* A_scales,
54
56
  const int8_t* B, const __fp16* B_scales,
55
57
  __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
56
58
 
59
+ void cactus_gemv_int8_i8mm(const int8_t* A, float A_scale,
60
+ const int8_t* B, const __fp16* B_scales,
61
+ __fp16* C, size_t K, size_t N, size_t group_size);
62
+
63
+ void cactus_gemm_int8_i8mm(const int8_t* A, const float* A_scales,
64
+ const int8_t* B, const __fp16* B_scales,
65
+ __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
66
+
57
67
  void cactus_gemv_int4(const int8_t* A, float A_scale,
58
68
  const int8_t* B_packed, const __fp16* B_scales,
59
69
  __fp16* C, size_t K, size_t N, size_t group_size);
@@ -97,6 +107,9 @@ void cactus_max_axis_f16(const __fp16* input, __fp16* output, size_t outer_size,
97
107
  void cactus_rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* output,
98
108
  size_t batch_size, size_t dims, float eps);
99
109
 
110
+ void cactus_layer_norm_f16(const __fp16* input, const __fp16* weight, const __fp16* bias,
111
+ __fp16* output, size_t batch_size, size_t dims, float eps);
112
+
100
113
  void cactus_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
101
114
  size_t num_heads, size_t head_dim, size_t start_pos, float theta);
102
115
 
@@ -108,6 +121,8 @@ void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
108
121
 
109
122
  void cactus_relu_f16(const __fp16* input, __fp16* output, size_t num_elements);
110
123
 
124
+ void cactus_leaky_relu_f16(const __fp16* input, __fp16* output, size_t num_elements, float negative_slope);
125
+
111
126
  void cactus_silu_f16(const __fp16* input, __fp16* output, size_t num_elements);
112
127
 
113
128
  void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
@@ -163,21 +178,54 @@ void cactus_batchnorm_f32(
163
178
  void cactus_attention_f16(const __fp16* queries, const __fp16* keys, const __fp16* values, __fp16* output,
164
179
  size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
165
180
  size_t head_dim, float scale, const __fp16* mask, size_t position_offset = 0, size_t window_size = 0,
166
- bool is_causal = true, bool mask_is_additive = false, bool mask_per_head = false);
181
+ bool is_causal = true, bool mask_is_additive = false, bool mask_per_head = false,
182
+ size_t v_head_dim = 0, float logit_cap = 0.0f);
167
183
 
168
184
  void cactus_attention_hybrid_int8_fp16(
169
- const __fp16* queries,
170
- const int8_t* keys_cached,
171
- const int8_t* values_cached,
185
+ const __fp16* queries,
186
+ const int8_t* keys_cached,
187
+ const int8_t* values_cached,
172
188
  const float* k_scales,
173
- const float* v_scales,
174
- const __fp16* keys_new,
175
- const __fp16* values_new,
189
+ const float* v_scales,
190
+ const __fp16* keys_new,
191
+ const __fp16* values_new,
176
192
  __fp16* output,
177
193
  size_t batch_size, size_t seq_len, size_t cache_len, size_t new_len,
178
194
  size_t num_q_heads, size_t num_kv_heads, size_t head_dim,
179
195
  float scale, size_t position_offset = 0, bool is_causal = true, size_t window_size = 0,
180
- size_t group_size = KV_QUANT_GROUP_SIZE);
196
+ size_t group_size = KV_QUANT_GROUP_SIZE, size_t v_head_dim = 0);
197
+
198
+ void cactus_gated_deltanet_decode_f16(
199
+ const __fp16* q_data,
200
+ const __fp16* k_data,
201
+ const __fp16* v_data,
202
+ const __fp16* g_data,
203
+ const __fp16* b_data,
204
+ const __fp16* s_data,
205
+ __fp16* out,
206
+ size_t B,
207
+ size_t Hq,
208
+ size_t Hv,
209
+ size_t K,
210
+ size_t V,
211
+ float scale);
212
+
213
+ void cactus_gated_deltanet_prefill_f16(
214
+ const __fp16* q_data,
215
+ const __fp16* k_data,
216
+ const __fp16* v_data,
217
+ const __fp16* g_data,
218
+ const __fp16* b_data,
219
+ const __fp16* s_data,
220
+ __fp16* out,
221
+ size_t B,
222
+ size_t T,
223
+ size_t Hq,
224
+ size_t Hv,
225
+ size_t K,
226
+ size_t V,
227
+ size_t requested_chunk_size,
228
+ float scale);
181
229
 
182
230
  void cactus_conv1d_causal_depthwise_f16(
183
231
  const __fp16* input,
@@ -244,6 +292,18 @@ void cactus_conv1d_same_depthwise_f16_k9(
244
292
  size_t C
245
293
  );
246
294
 
295
+ void cactus_conv2d_f16_k3s1p1_nchw(
296
+ const __fp16* input,
297
+ const __fp16* weight,
298
+ const __fp16* bias,
299
+ __fp16* output,
300
+ size_t N,
301
+ size_t C_in,
302
+ size_t H,
303
+ size_t W,
304
+ size_t C_out
305
+ );
306
+
247
307
  void cactus_conv2d_f16_k3s2p1_nchw(
248
308
  const __fp16* input,
249
309
  const __fp16* weight,
@@ -305,6 +365,8 @@ void cactus_sample_f16(const __fp16* logits, uint32_t* output, size_t vocab_size
305
365
  void cactus_concat_f16(const __fp16* input1, const __fp16* input2, __fp16* output,
306
366
  const size_t* shape1, const size_t* shape2, const size_t* output_shape,
307
367
  size_t ndims, int axis);
368
+ void cactus_cat_f16(const __fp16** inputs, __fp16* output, const size_t** input_shapes,
369
+ const size_t* output_shape, size_t num_inputs, size_t rank, int axis);
308
370
 
309
371
  void cactus_int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
310
372
  void cactus_fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
@@ -328,6 +390,30 @@ inline size_t kv_scales_count(size_t seq_len, size_t kv_heads, size_t head_dim,
328
390
 
329
391
  void cactus_unpack_int4_to_int8(const uint8_t* packed, int8_t* unpacked, size_t unpacked_count);
330
392
 
393
+ void cactus_gaussian_topk_f16(
394
+ const __fp16* input,
395
+ __fp16* output,
396
+ size_t rows,
397
+ size_t cols,
398
+ float ppf);
399
+
400
+ void cactus_altup_predict_f16(
401
+ const __fp16* coefs,
402
+ const __fp16* const* streams,
403
+ __fp16* output,
404
+ size_t n,
405
+ size_t seq_len,
406
+ size_t hidden_dim);
407
+
408
+ void cactus_altup_correct_f16(
409
+ const __fp16* coefs,
410
+ const __fp16* innovation,
411
+ const __fp16* const* predictions,
412
+ __fp16* output,
413
+ size_t n,
414
+ size_t seq_len,
415
+ size_t hidden_dim);
416
+
331
417
  void cactus_lstm_cell_f16(
332
418
  const __fp16* x_input,
333
419
  const __fp16* h_prev,
@@ -343,4 +429,31 @@ void cactus_lstm_cell_f16(
343
429
  size_t hidden_size
344
430
  );
345
431
 
346
- #endif
432
+ void cactus_bilstm_sequence_f16(
433
+ const __fp16* input,
434
+ const __fp16* weight_ih_fwd,
435
+ const __fp16* weight_hh_fwd,
436
+ const __fp16* bias_ih_fwd,
437
+ const __fp16* bias_hh_fwd,
438
+ const __fp16* weight_ih_bwd,
439
+ const __fp16* weight_hh_bwd,
440
+ const __fp16* bias_ih_bwd,
441
+ const __fp16* bias_hh_bwd,
442
+ __fp16* output,
443
+ size_t batch_size,
444
+ size_t seq_len,
445
+ size_t input_size,
446
+ size_t hidden_size
447
+ );
448
+
449
+ void cactus_maxpool1d_f16(
450
+ const __fp16* input,
451
+ __fp16* output,
452
+ size_t batch_size,
453
+ size_t channels,
454
+ size_t input_length,
455
+ size_t kernel_size,
456
+ size_t stride
457
+ );
458
+
459
+ #endif
@@ -9,6 +9,8 @@
9
9
  #if defined(__ANDROID__)
10
10
  #include <sys/auxv.h>
11
11
  #include <asm/hwcap.h>
12
+ #include <sched.h>
13
+ #include <fstream>
12
14
  #endif
13
15
  #include <algorithm>
14
16
  #include <cmath>
@@ -44,6 +46,29 @@ inline void stream_store_f16x8(__fp16* dst, float16x8_t val) {
44
46
  #endif
45
47
  }
46
48
 
49
// Reports whether the int8 matrix-multiply (i8mm) extension may be used.
//
// Non-aarch64 builds always return false. On aarch64 the answer is computed
// once under std::call_once and cached:
//   - Apple:   unconditionally true.
//   - Android: bit HWCAP2_I8MM of getauxval(AT_HWCAP2); the macro is defined
//              here as (1 << 13) when the system headers do not provide it.
//   - other:   stays false.
//
// NOTE(review): the Apple branch assumes every arm64 Apple target implements
// i8mm - confirm against the oldest supported chip, or probe at runtime.
inline bool cpu_has_i8mm() {
#if !defined(__aarch64__)
    return false;
#else
    static std::once_flag probe_guard;
    static bool supported = false;

    std::call_once(probe_guard, []() {
#if defined(__APPLE__)
        supported = true;
#elif defined(__ANDROID__)
#ifndef HWCAP2_I8MM
#define HWCAP2_I8MM (1 << 13)
#endif
        const unsigned long aux_caps = getauxval(AT_HWCAP2);
        supported = (aux_caps & HWCAP2_I8MM) != 0;
#endif
    });

    return supported;
#endif
}
71
+
47
72
  inline bool cpu_has_sme2() {
48
73
  #if defined(__aarch64__)
49
74
  static std::once_flag once;
@@ -130,6 +155,33 @@ inline float32x4_t fast_tanh_f32x4(float32x4_t x) {
130
155
  return result;
131
156
  }
132
157
 
158
// Number of fp16 lanes processed per vector step by the helpers in this file.
constexpr size_t SIMD_F16_WIDTH = 8;

// Rounds `count` down to the nearest multiple of `width`, i.e. the number of
// elements that can be handled in whole `width`-sized steps.
// `width` must be non-zero.
inline size_t simd_align(size_t count, size_t width = SIMD_F16_WIDTH) {
    const size_t leftover = count % width;
    return count - leftover;
}
163
+
164
+ inline void f16x8_split_f32(float16x8_t v, float32x4_t& lo, float32x4_t& hi) {
165
+ lo = vcvt_f32_f16(vget_low_f16(v));
166
+ hi = vcvt_f32_f16(vget_high_f16(v));
167
+ }
168
+
169
+ inline float16x8_t f32_merge_f16(float32x4_t lo, float32x4_t hi) {
170
+ return vcombine_f16(vcvt_f16_f32(lo), vcvt_f16_f32(hi));
171
+ }
172
+
173
+ inline float32x4_t fast_sigmoid_f32x4(float32x4_t x) {
174
+ const float32x4_t one = vdupq_n_f32(1.0f);
175
+ return vdivq_f32(one, vaddq_f32(one, fast_exp_f32x4(vnegq_f32(x))));
176
+ }
177
+
178
+ template<typename F32x4Op>
179
+ inline float16x8_t apply_f32_op_on_f16x8(float16x8_t v, F32x4Op op) {
180
+ float32x4_t lo, hi;
181
+ f16x8_split_f32(v, lo, hi);
182
+ return f32_merge_f16(op(lo), op(hi));
183
+ }
184
+
133
185
  inline void unpack_int4_as_int8x16x2(const uint8_t* ptr, int8x16_t& high_decoded, int8x16_t& low_decoded) {
134
186
  int8x16_t packed = vreinterpretq_s8_u8(vld1q_u8(ptr));
135
187
  high_decoded = vshrq_n_s8(packed, 4);
@@ -138,6 +190,80 @@ inline void unpack_int4_as_int8x16x2(const uint8_t* ptr, int8x16_t& high_decoded
138
190
 
139
191
  namespace CactusThreading {
140
192
 
193
+ #if defined(__ANDROID__)
194
// Lazily detected description of the CPU cores, built from sysfs.
//
//   all_cores         - ids of every CPU (0..15) for which a positive capacity
//                       or maximum-frequency value could be read.
//   performance_cores - the subset whose score is at least 70% of the best
//                       score among the cores being ranked.
//
// Cores are ranked by `cpu_capacity` when any core exposes it, otherwise by
// `cpufreq/cpuinfo_max_freq`. The two metrics are never compared against each
// other: they use different units, so a single threshold across both would
// discard every capacity-reporting core whenever a frequency value is present.
struct CoreTopology {
    std::vector<int> performance_cores;
    std::vector<int> all_cores;

    // Returns the process-wide topology, detecting it on first use.
    static CoreTopology& get() {
        static CoreTopology topo = detect();
        return topo;
    }

private:
    // Reads a single integer from `path`; returns -1 when the file cannot be
    // opened. Callers treat any non-positive result as "unavailable".
    static int read_sysfs_int(const char* path) {
        std::ifstream f(path);
        if (!f.is_open()) return -1;
        int val = -1;
        f >> val;
        return val;
    }

    // Probes up to MAX_CPUS cores and classifies them.
    static CoreTopology detect() {
        CoreTopology topo;
        constexpr int MAX_CPUS = 16;
        std::vector<std::pair<int, int>> by_capacity;
        std::vector<std::pair<int, int>> by_frequency;

        for (int cpu = 0; cpu < MAX_CPUS; ++cpu) {
            char path[128];

            snprintf(path, sizeof(path),
                     "/sys/devices/system/cpu/cpu%d/cpu_capacity", cpu);
            const int capacity = read_sysfs_int(path);
            if (capacity > 0) {
                by_capacity.push_back({cpu, capacity});
                topo.all_cores.push_back(cpu);
                continue;
            }

            // Capacity unavailable for this core: fall back to its max frequency.
            snprintf(path, sizeof(path),
                     "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", cpu);
            const int freq = read_sysfs_int(path);
            if (freq > 0) {
                by_frequency.push_back({cpu, freq});
                topo.all_cores.push_back(cpu);
            }
        }

        // Rank within one metric only; capacity is preferred when present.
        const std::vector<std::pair<int, int>>& ranked =
            by_capacity.empty() ? by_frequency : by_capacity;
        if (ranked.empty()) return topo;

        int best_score = 0;
        for (const auto& entry : ranked) {
            best_score = std::max(best_score, entry.second);
        }

        const int threshold = static_cast<int>(best_score * 0.70);
        for (const auto& entry : ranked) {
            if (entry.second >= threshold) {
                topo.performance_cores.push_back(entry.first);
            }
        }

        return topo;
    }
};
255
+
256
// Restricts the calling thread's CPU affinity to the given core ids.
// Returns false without touching the affinity when `cores` is empty;
// otherwise returns whether sched_setaffinity reported success.
inline bool pin_current_thread_to_cores(const std::vector<int>& cores) {
    if (cores.empty()) {
        return false;
    }

    cpu_set_t affinity;
    CPU_ZERO(&affinity);
    for (size_t i = 0; i < cores.size(); ++i) {
        CPU_SET(cores[i], &affinity);
    }

    const int rc = sched_setaffinity(0, sizeof(affinity), &affinity);
    return rc == 0;
}
265
+ #endif
266
+
141
267
  class ThreadPool {
142
268
  private:
143
269
  static constexpr size_t MAX_WORKERS = 16;
@@ -184,9 +310,25 @@ namespace CactusThreading {
184
310
  : stop(false), pending_tasks(0) {
185
311
  num_workers_ = std::min(num_threads, MAX_WORKERS);
186
312
  if (num_workers_ == 0) num_workers_ = 1;
313
+
314
+ #if defined(__ANDROID__)
315
+ auto& topo = CoreTopology::get();
316
+ if (!topo.performance_cores.empty()) {
317
+ num_workers_ = std::min(num_workers_, topo.performance_cores.size());
318
+ }
319
+ #endif
320
+
187
321
  workers.reserve(num_workers_);
188
322
  for (size_t i = 0; i < num_workers_; ++i) {
189
- workers.emplace_back(&ThreadPool::worker_thread, this);
323
+ workers.emplace_back([this]() {
324
+ #if defined(__ANDROID__)
325
+ auto& perf = CoreTopology::get().performance_cores;
326
+ if (!perf.empty()) {
327
+ pin_current_thread_to_cores(perf);
328
+ }
329
+ #endif
330
+ worker_thread();
331
+ });
190
332
  }
191
333
  }
192
334
 
@@ -498,5 +640,52 @@ namespace CactusThreading {
498
640
 
499
641
  }
500
642
 
643
+ template<typename SimdOp, typename ScalarOp>
644
+ void elementwise_op_f16(const __fp16* input, __fp16* output, size_t num_elements,
645
+ bool use_streaming, CactusThreading::ParallelConfig config,
646
+ SimdOp simd_op, ScalarOp scalar_op, size_t unroll = 4) {
647
+ CactusThreading::parallel_for(num_elements, config,
648
+ [&](size_t start, size_t end) {
649
+ const size_t n = end - start;
650
+ const size_t vec_end = start + simd_align(n);
651
+
652
+ if (use_streaming && unroll >= 4) {
653
+ const size_t unrolled_end = start + simd_align(n, SIMD_F16_WIDTH * 4);
654
+ for (size_t i = start; i < unrolled_end; i += SIMD_F16_WIDTH * 4) {
655
+ __builtin_prefetch(&input[i + 256], 0, 0);
656
+ float16x8_t v0 = simd_op(vld1q_f16(&input[i]));
657
+ float16x8_t v1 = simd_op(vld1q_f16(&input[i + 8]));
658
+ float16x8_t v2 = simd_op(vld1q_f16(&input[i + 16]));
659
+ float16x8_t v3 = simd_op(vld1q_f16(&input[i + 24]));
660
+ stream_store_f16x8(&output[i], v0);
661
+ stream_store_f16x8(&output[i + 8], v1);
662
+ stream_store_f16x8(&output[i + 16], v2);
663
+ stream_store_f16x8(&output[i + 24], v3);
664
+ }
665
+ for (size_t i = unrolled_end; i < vec_end; i += SIMD_F16_WIDTH) {
666
+ stream_store_f16x8(&output[i], simd_op(vld1q_f16(&input[i])));
667
+ }
668
+ } else if (use_streaming && unroll >= 2) {
669
+ const size_t unrolled_end = start + simd_align(n, SIMD_F16_WIDTH * 2);
670
+ for (size_t i = start; i < unrolled_end; i += SIMD_F16_WIDTH * 2) {
671
+ __builtin_prefetch(&input[i + 128], 0, 0);
672
+ float16x8_t v0 = simd_op(vld1q_f16(&input[i]));
673
+ float16x8_t v1 = simd_op(vld1q_f16(&input[i + 8]));
674
+ stream_store_f16x8(&output[i], v0);
675
+ stream_store_f16x8(&output[i + 8], v1);
676
+ }
677
+ for (size_t i = unrolled_end; i < vec_end; i += SIMD_F16_WIDTH) {
678
+ stream_store_f16x8(&output[i], simd_op(vld1q_f16(&input[i])));
679
+ }
680
+ } else {
681
+ for (size_t i = start; i < vec_end; i += SIMD_F16_WIDTH) {
682
+ vst1q_f16(&output[i], simd_op(vld1q_f16(&input[i])));
683
+ }
684
+ }
685
+ for (size_t i = vec_end; i < end; ++i) {
686
+ output[i] = scalar_op(input[i]);
687
+ }
688
+ });
689
+ }
501
690
 
502
- #endif // KERNEL_UTILS_H
691
+ #endif // KERNEL_UTILS_H