cactus-react-native 1.10.3 → 1.12.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76):
  1. package/README.md +199 -40
  2. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  3. package/cpp/HybridCactus.cpp +131 -2
  4. package/cpp/HybridCactus.hpp +15 -0
  5. package/cpp/cactus_ffi.h +240 -2
  6. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +240 -2
  7. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +940 -109
  8. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +175 -25
  9. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +48 -21
  10. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +79 -7
  11. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +122 -9
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +191 -2
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  14. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +240 -2
  15. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +940 -109
  16. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +175 -25
  17. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +48 -21
  18. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +79 -7
  19. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +122 -9
  20. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +191 -2
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  22. package/lib/module/classes/{CactusVAD.js → CactusAudio.js} +19 -6
  23. package/lib/module/classes/CactusAudio.js.map +1 -0
  24. package/lib/module/classes/CactusLM.js +25 -0
  25. package/lib/module/classes/CactusLM.js.map +1 -1
  26. package/lib/module/hooks/{useCactusVAD.js → useCactusAudio.js} +50 -20
  27. package/lib/module/hooks/useCactusAudio.js.map +1 -0
  28. package/lib/module/index.js +2 -2
  29. package/lib/module/index.js.map +1 -1
  30. package/lib/module/modelRegistry.js +5 -3
  31. package/lib/module/modelRegistry.js.map +1 -1
  32. package/lib/module/native/Cactus.js +81 -2
  33. package/lib/module/native/Cactus.js.map +1 -1
  34. package/lib/module/types/CactusAudio.js +4 -0
  35. package/lib/module/types/{CactusVAD.js.map → CactusAudio.js.map} +1 -1
  36. package/lib/typescript/src/classes/CactusAudio.d.ts +22 -0
  37. package/lib/typescript/src/classes/CactusAudio.d.ts.map +1 -0
  38. package/lib/typescript/src/classes/CactusLM.d.ts +2 -1
  39. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  40. package/lib/typescript/src/hooks/useCactusAudio.d.ts +17 -0
  41. package/lib/typescript/src/hooks/useCactusAudio.d.ts.map +1 -0
  42. package/lib/typescript/src/index.d.ts +4 -4
  43. package/lib/typescript/src/index.d.ts.map +1 -1
  44. package/lib/typescript/src/native/Cactus.d.ts +9 -3
  45. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  46. package/lib/typescript/src/specs/Cactus.nitro.d.ts +3 -0
  47. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  48. package/lib/typescript/src/types/CactusAudio.d.ts +63 -0
  49. package/lib/typescript/src/types/CactusAudio.d.ts.map +1 -0
  50. package/lib/typescript/src/types/CactusLM.d.ts +15 -0
  51. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  52. package/lib/typescript/src/types/CactusSTT.d.ts +1 -0
  53. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  54. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +3 -0
  55. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +3 -0
  56. package/package.json +1 -1
  57. package/src/classes/{CactusVAD.ts → CactusAudio.ts} +32 -13
  58. package/src/classes/CactusLM.ts +36 -0
  59. package/src/hooks/{useCactusVAD.ts → useCactusAudio.ts} +65 -28
  60. package/src/index.tsx +16 -9
  61. package/src/modelRegistry.ts +20 -6
  62. package/src/native/Cactus.ts +118 -3
  63. package/src/specs/Cactus.nitro.ts +16 -0
  64. package/src/types/CactusAudio.ts +73 -0
  65. package/src/types/CactusLM.ts +17 -0
  66. package/src/types/CactusSTT.ts +1 -0
  67. package/lib/module/classes/CactusVAD.js.map +0 -1
  68. package/lib/module/hooks/useCactusVAD.js.map +0 -1
  69. package/lib/module/types/CactusVAD.js +0 -4
  70. package/lib/typescript/src/classes/CactusVAD.d.ts +0 -20
  71. package/lib/typescript/src/classes/CactusVAD.d.ts.map +0 -1
  72. package/lib/typescript/src/hooks/useCactusVAD.d.ts +0 -15
  73. package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +0 -1
  74. package/lib/typescript/src/types/CactusVAD.d.ts +0 -34
  75. package/lib/typescript/src/types/CactusVAD.d.ts.map +0 -1
  76. package/src/types/CactusVAD.ts +0 -39
@@ -75,6 +75,7 @@ struct Config {
75
75
  bool use_pixel_shuffle = false;
76
76
  uint32_t pixel_shuffle_factor = 1;
77
77
  bool use_image_tokens = false;
78
+ uint32_t image_token_id = 0;
78
79
  bool use_layout_tags = false;
79
80
  uint32_t image_seq_len = 64;
80
81
 
@@ -107,6 +108,26 @@ struct Config {
107
108
  uint32_t subsampling_factor = 0;
108
109
  uint32_t num_mel_bins = 80;
109
110
  std::string encoder_hidden_act = "silu";
111
+ uint32_t linear_num_key_heads = 0;
112
+ uint32_t linear_key_head_dim = 0;
113
+ uint32_t linear_num_value_heads = 0;
114
+ uint32_t linear_value_head_dim = 0;
115
+ uint32_t linear_q_proj_dim = 0;
116
+ uint32_t linear_k_proj_dim = 0;
117
+ uint32_t linear_v_proj_dim = 0;
118
+
119
+ uint32_t kv_lora_rank = 0;
120
+ uint32_t q_lora_rank = 0;
121
+ uint32_t qk_head_dim = 0;
122
+ uint32_t qk_nope_head_dim = 0;
123
+ uint32_t qk_rope_head_dim = 0;
124
+ uint32_t v_head_dim = 0;
125
+ uint32_t rope_interleave = 0;
126
+ bool attention_bias = false;
127
+ float rope_scaling_factor = 1.0f;
128
+ float rope_mscale_all_dim = 0.0f;
129
+
130
+ enum class ModelType {QWEN = 0, GEMMA = 1, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, MOONSHINE = 8, SILERO_VAD = 9, PARAKEET = 10, QWEN3P5 = 11, PARAKEET_TDT = 12, GEMMA3N = 13, YOUTU = 14, GEMMA4 = 15, PYANNOTE = 16, WESPEAKER = 17};
110
131
  uint32_t predictor_hidden_dim = 0;
111
132
  uint32_t predictor_num_layers = 0;
112
133
  uint32_t tdt_joint_dim = 0;
@@ -114,7 +135,6 @@ struct Config {
114
135
  uint32_t tdt_blank_id = 0;
115
136
  std::vector<uint32_t> tdt_durations;
116
137
 
117
- enum class ModelType {QWEN = 0, GEMMA = 1, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, MOONSHINE = 8, SILERO_VAD = 9, PARAKEET = 10, PARAKEET_TDT = 11};
118
138
  ModelType model_type = ModelType::QWEN;
119
139
 
120
140
  enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
@@ -138,6 +158,58 @@ struct Config {
138
158
  std::vector<std::string> layer_types;
139
159
  size_t conv_L_cache = 0;
140
160
 
161
+ uint32_t altup_num_inputs = 4;
162
+ uint32_t laurel_rank = 64;
163
+ uint32_t hidden_size_per_layer_input = 256;
164
+ uint32_t num_kv_shared_layers = 0;
165
+ uint32_t sliding_window = 512;
166
+ float rope_local_base_freq = 10000.0f;
167
+ float final_logit_softcapping = 0.0f;
168
+ float global_partial_rotary_factor = 1.0f;
169
+ uint32_t expert_intermediate_size = 0;
170
+ uint32_t global_head_dim = 0;
171
+ uint32_t num_global_kv_heads = 0;
172
+ bool attention_k_eq_v = false;
173
+ bool enable_moe_block = false;
174
+ std::vector<float> activation_sparsity_ppf;
175
+
176
+ uint32_t vision_head_dim = 64;
177
+ uint32_t vision_kv_heads = 12;
178
+ uint32_t vision_intermediate_size = 3072;
179
+ uint32_t vision_position_embedding_size = 10240;
180
+ uint32_t vision_pooling_kernel_size = 3;
181
+ uint32_t vision_default_output_length = 280;
182
+ float vision_rope_theta = 100.0f;
183
+
184
+ uint32_t audio_hidden_dim = 0;
185
+ uint32_t audio_num_layers = 0;
186
+ uint32_t audio_num_heads = 0;
187
+ uint32_t audio_head_dim = 0;
188
+ uint32_t audio_input_feat_size = 128;
189
+ uint32_t audio_conf_conv_kernel_size = 5;
190
+ uint32_t audio_chunk_size = 12;
191
+ uint32_t audio_context_left = 13;
192
+ uint32_t audio_context_right = 0;
193
+ float audio_logit_cap = 50.0f;
194
+ float audio_residual_weight = 0.5f;
195
+ uint32_t audio_output_proj_dims = 0;
196
+ uint32_t audio_vocab_size = 128;
197
+ uint32_t audio_vocab_offset = 0;
198
+ uint32_t audio_soft_tokens = 188;
199
+ uint32_t audio_sscp_conv0_channels = 128;
200
+ uint32_t audio_sscp_conv1_channels = 32;
201
+ float audio_sscp_conv_eps = 1e-3f;
202
+ float audio_rms_norm_eps = 1e-6f;
203
+ uint32_t audio_fft_length = 1024;
204
+ uint32_t audio_token_id = 0;
205
+ bool audio_fft_overdrive = false;
206
+ uint32_t channel_open_token_id = 100;
207
+ uint32_t channel_close_token_id = 101;
208
+
209
+ static bool is_gemma_family(ModelType t) {
210
+ return t == ModelType::GEMMA || t == ModelType::GEMMA3N || t == ModelType::GEMMA4;
211
+ }
212
+
141
213
  bool from_json(const std::string& json_path);
142
214
  std::string to_json() const;
143
215
  };
@@ -155,14 +227,38 @@ struct MergeRule {
155
227
  };
156
228
 
157
229
 
230
/// A single tool call: a name plus its arguments, both held as strings.
/// NOTE(review): `arguments` is presumably a serialized (JSON-like) payload —
/// confirm against the code that fills it in.
struct ToolCallInfo {
    /// Name of the tool/function being called.
    std::string name;

    /// Arguments of the call, stored as text.
    std::string arguments;
};
234
+
158
235
  struct ChatMessage {
159
236
  std::string role;
160
237
  std::string content;
161
238
  std::string name;
162
239
  std::vector<std::string> images;
240
+ std::vector<std::string> audio;
241
+ size_t audio_soft_token_count = 0;
242
+ std::vector<ToolCallInfo> tool_calls;
163
243
  };
164
244
 
245
/// Bundle of tokenizer settings. A default-constructed instance has every
/// enum field at its first enumerator (UNKNOWN / NONE) and both flags false.
/// NOTE(review): presumably filled in by load_tokenizer_runtime_config() — confirm.
struct TokenizerRuntimeConfig {
    /// Tokenizer algorithm family.
    enum class TokenizerType {
        UNKNOWN,
        BPE,
        SENTENCEPIECE
    };

    /// Layout of the vocabulary file.
    enum class VocabFormat {
        UNKNOWN,
        ID_TAB_TOKEN,
        LINE_TOKEN
    };

    /// Text normalization step selection.
    enum class Normalizer {
        NONE,
        METASPACE,
        BYTE_LEVEL
    };

    /// Decoding step selection.
    enum class Decoder {
        NONE,
        REPLACE_METASPACE,
        BYTE_LEVEL
    };

    TokenizerType tokenizer_type{TokenizerType::UNKNOWN};
    VocabFormat vocab_format{VocabFormat::UNKNOWN};
    Normalizer normalizer{Normalizer::NONE};
    Decoder decoder{Decoder::NONE};
    bool byte_fallback{false};
    bool has_chat_template{false};
};
165
258
 
259
+ TokenizerRuntimeConfig load_tokenizer_runtime_config(const std::string& config_file);
260
+ void load_special_tokens_map(const std::string& config_file, std::unordered_map<std::string, uint32_t>& special_tokens);
261
+ std::vector<std::string> split_with_special_tokens(const std::string& text, const std::unordered_map<std::string, uint32_t>& special_tokens);
166
262
 
167
263
  class Tokenizer {
168
264
  public:
@@ -172,7 +268,7 @@ public:
172
268
  virtual std::string decode(const std::vector<uint32_t>& tokens) const = 0;
173
269
 
174
270
  virtual std::vector<uint32_t> apply_chat_template(const std::vector<ChatMessage>& messages, bool add_generation_prompt = true) const;
175
- virtual std::string format_chat_prompt(const std::vector<ChatMessage>& messages, bool add_generation_prompt = true, const std::string& tools_json = "") const;
271
+ virtual std::string format_chat_prompt(const std::vector<ChatMessage>& messages, bool add_generation_prompt = true, const std::string& tools_json = "", bool enable_thinking_if_supported = true) const;
176
272
 
177
273
  virtual uint32_t get_vocab_size() const = 0;
178
274
  virtual uint32_t get_unk_token() const = 0;
@@ -188,7 +284,7 @@ public:
188
284
  uint32_t get_global_img_token_id() const { return global_img_token_id_; }
189
285
 
190
286
  protected:
191
- enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, BERT, WHISPER, PARAKEET};
287
+ enum class ModelType { UNKNOWN, QWEN, QWEN3P5, GEMMA, GEMMA4, LFM2, BERT, WHISPER, PARAKEET, YOUTU};
192
288
  ModelType model_type_ = ModelType::UNKNOWN;
193
289
  enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
194
290
  ModelVariant model_variant_ = ModelVariant::DEFAULT;
@@ -199,11 +295,21 @@ protected:
199
295
  uint32_t fake_token_id_ = 49189;
200
296
  uint32_t global_img_token_id_ = 49152;
201
297
 
298
+
299
+ uint32_t vision_patch_size_ = 16;
300
+ uint32_t vision_pooling_kernel_size_ = 3;
301
+ uint32_t vision_default_output_length_ = 280;
302
+ uint32_t vision_image_size_ = 768;
303
+ TokenizerRuntimeConfig runtime_config_;
304
+
202
305
  void detect_model_type(const std::string& config_path);
203
- std::string format_qwen_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
306
+ void load_chat_template(const std::string& template_file);
307
+ std::string format_qwen_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json, bool enable_thinking_if_supported = true) const;
204
308
  std::string format_gemma_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
309
+ std::string format_gemma4_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json, bool enable_thinking_if_supported = true) const;
205
310
  std::string format_lfm2_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
206
311
  std::string format_lfm2_vl_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
312
+ std::string format_youtu_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
207
313
  };
208
314
 
209
315
  class BPETokenizer : public Tokenizer {
@@ -245,6 +351,7 @@ private:
245
351
  std::string bytes_to_unicode(const std::string& text) const;
246
352
  std::string unicode_to_bytes(const std::string& text) const;
247
353
  std::vector<std::string> byte_level_split(const std::string& text) const;
354
+ std::vector<std::string> utf8_split(const std::string& text) const;
248
355
 
249
356
  void cleanup_mmap();
250
357
 
@@ -256,12 +363,6 @@ private:
256
363
  std::unordered_map<std::string, uint32_t> special_tokens_;
257
364
  std::vector<std::string> split_with_special_tokens(const std::string& text) const;
258
365
  void load_special_tokens(const std::string& config_file);
259
-
260
- void load_chat_template(const std::string& template_file);
261
-
262
- std::unordered_map<std::string, uint32_t> tool_tokens_;
263
- bool has_tool_support_;
264
- void load_tokenizer_config(const std::string& config_file);
265
366
  };
266
367
 
267
368
  class SPTokenizer : public Tokenizer {
@@ -311,8 +412,6 @@ private:
311
412
  std::unordered_map<std::string, uint32_t> special_tokens_;
312
413
  std::vector<std::string> split_with_special_tokens(const std::string& text) const;
313
414
  void load_special_tokens(const std::string& config_file);
314
-
315
- void load_chat_template(const std::string& template_file);
316
415
  };
317
416
 
318
417
  class ConvCache {
@@ -355,8 +454,10 @@ struct KVCache {
355
454
  struct LayerCache {
356
455
  std::vector<uint8_t> keys;
357
456
  std::vector<uint8_t> values;
358
- std::vector<float> key_scales;
359
- std::vector<float> value_scales;
457
+ std::vector<float> key_scales;
458
+ std::vector<float> value_scales;
459
+ size_t head_dim = 0;
460
+ size_t kv_heads = 0;
360
461
  };
361
462
 
362
463
  std::vector<LayerCache> layer_caches;
@@ -366,8 +467,6 @@ struct KVCache {
366
467
  size_t current_seq_len = 0;
367
468
  size_t total_seq_len = 0;
368
469
  size_t max_seq_len = 2048;
369
- size_t num_kv_heads = 0;
370
- size_t head_dim = 0;
371
470
  size_t num_layers = 0;
372
471
  Precision precision;
373
472
  size_t element_size = 4;
@@ -375,12 +474,14 @@ struct KVCache {
375
474
  void set_window_size(size_t window, size_t sink = DEFAULT_SINK_SIZE);
376
475
  size_t get_effective_seq_len() const { return current_seq_len; }
377
476
  size_t get_total_seq_len() const { return total_seq_len; }
477
+ size_t get_layer_head_dim(size_t layer_idx) const { return layer_caches[layer_idx].head_dim; }
478
+ size_t get_layer_kv_heads(size_t layer_idx) const { return layer_caches[layer_idx].kv_heads; }
378
479
 
379
- void init(size_t num_layers, size_t max_seq, size_t num_kv_heads, size_t head_dim, Precision model_precision);
480
+ void init(size_t num_layers, size_t max_seq, const std::vector<size_t>& layer_dims, const std::vector<size_t>& layer_kv_heads, Precision model_precision);
380
481
  void reset();
381
482
  void update_from_graph(CactusGraph* gb, const std::vector<size_t>& k_nodes,
382
483
  const std::vector<size_t>& v_nodes, size_t seq_len,
383
- size_t num_layers, size_t kv_heads, size_t head_dim);
484
+ size_t num_layers);
384
485
 
385
486
  void update_from_npu(size_t layer_idx, const __fp16* k_data, const __fp16* v_data,
386
487
  size_t num_tokens, size_t kv_heads, size_t head_dim);
@@ -404,6 +505,9 @@ struct KVCache {
404
505
  const int8_t* get_values_int8(size_t layer) const;
405
506
  const float* get_key_scales(size_t layer) const;
406
507
  const float* get_value_scales(size_t layer) const;
508
+
509
+ void remove_token_range(size_t start, size_t count);
510
+ void compact_to_windows(const std::vector<size_t>& target_windows);
407
511
  };
408
512
 
409
513
  class ToolCallConstrainer {
@@ -421,7 +525,7 @@ public:
421
525
  QWEN_EXPECT_ARGS_COLON,
422
526
  QWEN_IN_ARGUMENTS,
423
527
  QWEN_EXPECT_CLOSE_BRACE,
424
- QWEN_EXPECT_END,
528
+ QWEN_EXPECT_END,
425
529
 
426
530
  LFM_START,
427
531
  LFM_EXPECT_BRACKET,
@@ -457,12 +561,17 @@ private:
457
561
  Config::ModelType model_type_ = Config::ModelType::QWEN;
458
562
  Tokenizer* tokenizer_ = nullptr;
459
563
 
564
+ bool is_gemma_family() const { return Config::is_gemma_family(model_type_); }
565
+
460
566
  std::vector<std::string> function_names_;
461
567
  std::string generated_text_;
462
- int brace_depth_ = 0;
568
+ int brace_depth_ = 0;
569
+
570
+ std::string call_start_tag_;
571
+ std::string call_end_tag_;
463
572
 
464
- std::unordered_set<uint32_t> qwen_tool_call_start_tokens_;
465
- std::unordered_set<uint32_t> qwen_tool_call_end_tokens_;
573
+ std::unordered_set<uint32_t> qwen_tool_call_start_tokens_;
574
+ std::unordered_set<uint32_t> qwen_tool_call_end_tokens_;
466
575
  std::unordered_set<uint32_t> open_brace_tokens_;
467
576
  std::unordered_set<uint32_t> close_brace_tokens_;
468
577
  std::unordered_set<uint32_t> colon_tokens_;
@@ -472,7 +581,7 @@ private:
472
581
  std::unordered_set<uint32_t> quote_tokens_;
473
582
  std::unordered_set<uint32_t> backtick_tokens_;
474
583
  std::unordered_set<uint32_t> all_func_name_tokens_;
475
- std::unordered_map<std::string, std::vector<uint32_t>> func_name_sequences_;
584
+ std::unordered_map<std::string, std::vector<uint32_t>> func_name_sequences_;
476
585
 
477
586
  std::unordered_set<uint32_t> tool_start_tokens_;
478
587
  std::unordered_set<uint32_t> tool_end_tokens_;
@@ -523,12 +632,16 @@ public:
523
632
 
524
633
  virtual void prefill(const std::vector<uint32_t>& tokens, size_t chunk_size = 256, const std::string& profile_file = "");
525
634
 
635
+ virtual void prefill_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
636
+ const std::string& profile_file = "");
637
+
526
638
  virtual uint32_t decode_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
527
639
  float temperature = -1.0f, float top_p = -1.0f,
528
640
  size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
529
641
 
530
642
  virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& audio_features, float temperature = 0.0f, float top_p = 0.0f,
531
- size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
643
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr,
644
+ float* out_token_time_start = nullptr, float* out_token_time_end = nullptr);
532
645
 
533
646
  std::vector<float> get_embeddings(const std::vector<uint32_t>& tokens, bool pooled = true, bool normalize = false, const std::string& profile_file = "");
534
647
 
@@ -548,13 +661,37 @@ public:
548
661
  bool has_npu_prefill() const;
549
662
  size_t get_prefill_chunk_size() const;
550
663
 
664
+ virtual void remove_thinking_tokens(const std::vector<std::pair<size_t, size_t>>& ranges);
665
+ virtual void compact_kv_cache() {}
666
+
551
667
  void set_tool_constraints(const std::vector<std::string>& function_names);
552
668
  void clear_tool_constraints();
553
669
  void update_tool_constraints(uint32_t token_id);
554
670
 
555
671
  void* graph_handle_;
556
672
 
673
+ void set_vocab_bias(const std::unordered_map<uint32_t, float>& bias) {
674
+ vocab_bias_ = bias;
675
+ }
676
+
677
+ void clear_vocab_bias() {
678
+ vocab_bias_.clear();
679
+ }
680
+
681
+ bool has_vocab_bias() const {
682
+ return !vocab_bias_.empty();
683
+ }
684
+
685
+ const std::unordered_map<uint32_t, float>& get_vocab_bias() const {
686
+ return vocab_bias_;
687
+ }
688
+
557
689
  protected:
690
+ size_t sample_token(CactusGraph* gb, size_t logits_node_id, float temperature, float top_p, size_t top_k,
691
+ const std::unordered_map<uint32_t, float>* extra_bias = nullptr) const;
692
+
693
+ static void compute_entropy(CactusGraph* gb, size_t logits_node_id, float* out_entropy);
694
+
558
695
  virtual size_t forward(const std::vector<uint32_t>& tokens, bool use_cache = false) = 0;
559
696
 
560
697
  virtual size_t forward(const std::vector<float>& audio_features, const std::vector<uint32_t>& tokens, bool use_cache = false);
@@ -569,6 +706,12 @@ protected:
569
706
  virtual size_t build_transformer_block(CactusGraph* gb, size_t hidden, uint32_t layer_idx,
570
707
  ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) = 0;
571
708
  void update_kv_cache(CactusGraph* gb, size_t seq_len);
709
+ virtual std::vector<size_t> get_kv_layer_dims() const {
710
+ return std::vector<size_t>(config_.num_layers, config_.attention_head_dim);
711
+ }
712
+ virtual std::vector<size_t> get_kv_layer_heads() const {
713
+ return std::vector<size_t>(config_.num_layers, config_.attention_kv_heads);
714
+ }
572
715
  virtual void post_init() {}
573
716
  virtual void post_execute_updates(CactusGraph*, size_t) {}
574
717
  Config config_;
@@ -601,6 +744,9 @@ protected:
601
744
  virtual std::vector<__fp16> get_token_embeddings(const std::vector<uint32_t>& tokens);
602
745
 
603
746
  ToolCallConstrainer tool_constrainer_;
747
+
748
+ private:
749
+ std::unordered_map<uint32_t, float> vocab_bias_;
604
750
  };
605
751
 
606
752
  std::unique_ptr<Model> create_model(const std::string& model_folder);
@@ -705,13 +851,17 @@ public:
705
851
  bool remove_dc_offset = false;
706
852
  float preemphasis = 0.0f;
707
853
  bool hann_periodic = true;
854
+ float window_a0 = 0.5f;
855
+ size_t fft_override = 0;
856
+ bool mel_floor_additive = false;
708
857
  };
709
858
 
710
859
  AudioProcessor();
711
860
  ~AudioProcessor();
712
861
 
713
862
  void init_mel_filters(size_t num_frequency_bins, size_t num_mel_filters,
714
- float min_freq, float max_freq, size_t sampling_rate);
863
+ float min_freq, float max_freq, size_t sampling_rate,
864
+ const char* norm = "slaney", const char* mel_scale = "slaney");
715
865
 
716
866
  std::vector<float> compute_spectrogram(
717
867
  const std::vector<float>& waveform,
@@ -53,6 +53,7 @@ inline std::string format_argument(const std::string& json, size_t& pos, bool es
53
53
  char c = json[pos];
54
54
 
55
55
  if (c == '"') {
56
+ pos++;
56
57
  std::string value = extract_json_string(json, pos);
57
58
  return escape(value);
58
59
  } else if (c == '{') {
@@ -240,7 +241,7 @@ inline std::string format_parameters(const std::string& properties_json, const s
240
241
  result += ",properties:{" + format_parameters(prop_obj["properties"], nested_required) + "}";
241
242
  }
242
243
  if (prop_obj.count("required")) {
243
- result += ",required:[";
244
+ std::string req_items;
244
245
  size_t req_pos = 0;
245
246
  skip_whitespace(prop_obj["required"], req_pos);
246
247
  if (req_pos < prop_obj["required"].length() && prop_obj["required"][req_pos] == '[') {
@@ -253,13 +254,15 @@ inline std::string format_parameters(const std::string& properties_json, const s
253
254
  if (prop_obj["required"][req_pos] == '"') {
254
255
  req_pos++;
255
256
  std::string req_item = extract_json_string(prop_obj["required"], req_pos);
256
- if (!req_first) result += ",";
257
+ if (!req_first) req_items += ",";
257
258
  req_first = false;
258
- result += escape(req_item);
259
+ req_items += escape(req_item);
259
260
  }
260
261
  }
261
262
  }
262
- result += "]";
263
+ if (!req_items.empty()) {
264
+ result += ",required:[" + req_items + "]";
265
+ }
263
266
  }
264
267
  } else if (to_upper(type_val) == "ARRAY") {
265
268
  if (prop_obj.count("items")) {
@@ -342,7 +345,7 @@ inline std::string format_function_declaration(const std::string& name,
342
345
  }
343
346
 
344
347
  if (params.count("required")) {
345
- result += ",required:[";
348
+ std::string req_items;
346
349
  size_t req_pos = 0;
347
350
  skip_whitespace(params["required"], req_pos);
348
351
  if (req_pos < params["required"].length() && params["required"][req_pos] == '[') {
@@ -355,13 +358,15 @@ inline std::string format_function_declaration(const std::string& name,
355
358
  if (params["required"][req_pos] == '"') {
356
359
  req_pos++;
357
360
  std::string item = extract_json_string(params["required"], req_pos);
358
- if (!first) result += ",";
361
+ if (!first) req_items += ",";
359
362
  first = false;
360
- result += escape(item);
363
+ req_items += escape(item);
361
364
  }
362
365
  }
363
366
  }
364
- result += "]";
367
+ if (!req_items.empty()) {
368
+ result += ",required:[" + req_items + "]";
369
+ }
365
370
  }
366
371
 
367
372
  if (params.count("type")) {
@@ -377,12 +382,15 @@ inline std::string format_function_declaration(const std::string& name,
377
382
  }
378
383
 
379
384
  template<typename ToolFunction>
380
- inline std::string format_tools(const std::vector<ToolFunction>& tools) {
385
+ inline std::string format_tools(const std::vector<ToolFunction>& tools, bool use_pipe_tags = false) {
381
386
  if (tools.empty()) return "";
382
387
 
388
+ const char* decl_start = use_pipe_tags ? "<|tool>" : "<start_function_declaration>";
389
+ const char* decl_end = use_pipe_tags ? "<tool|>" : "<end_function_declaration>";
390
+
383
391
  std::string result;
384
392
  for (const auto& tool : tools) {
385
- result += "<start_function_declaration>";
393
+ result += decl_start;
386
394
  std::string params_json;
387
395
  auto it = tool.parameters.find("schema");
388
396
  if (it != tool.parameters.end()) {
@@ -390,12 +398,26 @@ inline std::string format_tools(const std::vector<ToolFunction>& tools) {
390
398
  }
391
399
 
392
400
  result += format_function_declaration(tool.name, tool.description, params_json);
393
- result += "<end_function_declaration>";
401
+ result += decl_end;
394
402
  }
395
403
  return result;
396
404
  }
397
405
 
398
406
 
407
/// Returns the length of the quote tag that starts exactly at `pos` in `s`.
///
/// Two delimiters are recognised: "<escape>" (8 characters) and "<|\"|>"
/// (5 characters).
///
/// @param s   Text to inspect.
/// @param pos Offset at which a tag is expected to begin.
/// @return 8 or 5 when the corresponding tag begins at `pos`, otherwise 0.
///         A `pos` at or past the end of `s` also yields 0.
inline size_t match_quote_tag(const std::string& s, size_t pos) {
    // std::string::compare(pos, ...) throws std::out_of_range when
    // pos > s.size(); report "no tag here" instead of throwing.
    if (pos >= s.size()) return 0;
    if (s.compare(pos, 8, "<escape>") == 0) return 8;
    if (s.compare(pos, 5, "<|\"|>") == 0) return 5;
    return 0;
}
412
+
413
/// Finds the first quote tag ("<escape>" or "<|\"|>") at or after `pos`.
///
/// @param s   Text to search.
/// @param pos Offset at which the search begins.
/// @return Offset of whichever tag occurs first, or std::string::npos when
///         neither tag is present.
inline size_t find_quote_tag(const std::string& s, size_t pos) {
    const size_t escape_at = s.find("<escape>", pos);
    const size_t pipe_at = s.find("<|\"|>", pos);
    // npos is the largest size_t value, so the minimum is the earliest real
    // hit, and is npos only when both searches failed.
    return std::min(escape_at, pipe_at);
}
420
+
399
421
  inline std::string unescape(const std::string& s) {
400
422
  const std::string ESCAPE_TAG = "<escape>";
401
423
  std::string result = s;
@@ -427,12 +449,13 @@ inline std::string args_to_json(const std::string& args_content) {
427
449
  while (pos < args_content.length() && std::isspace(args_content[pos])) pos++;
428
450
 
429
451
  if (pos < args_content.length()) {
430
- if (args_content.compare(pos, 8, "<escape>") == 0) {
431
- pos += 8;
432
- size_t val_end = args_content.find("<escape>", pos);
452
+ size_t qtag_len = match_quote_tag(args_content, pos);
453
+ if (qtag_len > 0) {
454
+ pos += qtag_len;
455
+ size_t val_end = find_quote_tag(args_content, pos);
433
456
  if (val_end != std::string::npos) {
434
457
  value = "\"" + args_content.substr(pos, val_end - pos) + "\"";
435
- pos = val_end + 8;
458
+ pos = val_end + match_quote_tag(args_content, val_end);
436
459
  }
437
460
  } else if (args_content[pos] == '{') {
438
461
  int depth = 1;
@@ -464,12 +487,13 @@ inline std::string args_to_json(const std::string& args_content) {
464
487
  if (!first_item) value += ",";
465
488
  first_item = false;
466
489
 
467
- if (arr_content.compare(arr_pos, 8, "<escape>") == 0) {
468
- arr_pos += 8;
469
- size_t end = arr_content.find("<escape>", arr_pos);
490
+ size_t aq_len = match_quote_tag(arr_content, arr_pos);
491
+ if (aq_len > 0) {
492
+ arr_pos += aq_len;
493
+ size_t end = find_quote_tag(arr_content, arr_pos);
470
494
  if (end != std::string::npos) {
471
495
  value += "\"" + arr_content.substr(arr_pos, end - arr_pos) + "\"";
472
- arr_pos = end + 8;
496
+ arr_pos = end + match_quote_tag(arr_content, end);
473
497
  }
474
498
  } else {
475
499
  size_t end = arr_content.find_first_of(",]", arr_pos);
@@ -499,8 +523,11 @@ inline std::string args_to_json(const std::string& args_content) {
499
523
  }
500
524
 
501
525
  inline void parse_function_calls(std::string& response, std::vector<std::string>& function_calls) {
502
- const std::string CALL_START = "<start_function_call>";
503
- const std::string CALL_END = "<end_function_call>";
526
+
527
+ const std::string CALL_START = (response.find("<|tool_call>") != std::string::npos)
528
+ ? "<|tool_call>" : "<start_function_call>";
529
+ const std::string CALL_END = (CALL_START == "<|tool_call>")
530
+ ? "<tool_call|>" : "<end_function_call>";
504
531
  size_t pos = 0;
505
532
 
506
533
  while ((pos = response.find(CALL_START, pos)) != std::string::npos) {