cactus-react-native 1.10.4 → 1.12.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +199 -40
- package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
- package/cpp/HybridCactus.cpp +131 -2
- package/cpp/HybridCactus.hpp +15 -0
- package/cpp/cactus_ffi.h +240 -2
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +240 -2
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +940 -109
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +175 -25
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +48 -21
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +79 -7
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +122 -9
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +191 -2
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +240 -2
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +940 -109
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +175 -25
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +48 -21
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +79 -7
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +122 -9
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +191 -2
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
- package/lib/module/classes/{CactusVAD.js → CactusAudio.js} +19 -6
- package/lib/module/classes/CactusAudio.js.map +1 -0
- package/lib/module/classes/CactusLM.js +25 -0
- package/lib/module/classes/CactusLM.js.map +1 -1
- package/lib/module/hooks/{useCactusVAD.js → useCactusAudio.js} +50 -20
- package/lib/module/hooks/useCactusAudio.js.map +1 -0
- package/lib/module/index.js +2 -2
- package/lib/module/index.js.map +1 -1
- package/lib/module/modelRegistry.js +1 -1
- package/lib/module/native/Cactus.js +81 -2
- package/lib/module/native/Cactus.js.map +1 -1
- package/lib/module/types/CactusAudio.js +4 -0
- package/lib/module/types/{CactusVAD.js.map → CactusAudio.js.map} +1 -1
- package/lib/typescript/src/classes/CactusAudio.d.ts +22 -0
- package/lib/typescript/src/classes/CactusAudio.d.ts.map +1 -0
- package/lib/typescript/src/classes/CactusLM.d.ts +2 -1
- package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusAudio.d.ts +17 -0
- package/lib/typescript/src/hooks/useCactusAudio.d.ts.map +1 -0
- package/lib/typescript/src/index.d.ts +4 -4
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/native/Cactus.d.ts +9 -3
- package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts +3 -0
- package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusAudio.d.ts +63 -0
- package/lib/typescript/src/types/CactusAudio.d.ts.map +1 -0
- package/lib/typescript/src/types/CactusLM.d.ts +15 -0
- package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusSTT.d.ts +1 -0
- package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
- package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +3 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +3 -0
- package/package.json +1 -1
- package/src/classes/{CactusVAD.ts → CactusAudio.ts} +32 -13
- package/src/classes/CactusLM.ts +36 -0
- package/src/hooks/{useCactusVAD.ts → useCactusAudio.ts} +65 -28
- package/src/index.tsx +16 -9
- package/src/modelRegistry.ts +1 -1
- package/src/native/Cactus.ts +118 -3
- package/src/specs/Cactus.nitro.ts +16 -0
- package/src/types/CactusAudio.ts +73 -0
- package/src/types/CactusLM.ts +17 -0
- package/src/types/CactusSTT.ts +1 -0
- package/lib/module/classes/CactusVAD.js.map +0 -1
- package/lib/module/hooks/useCactusVAD.js.map +0 -1
- package/lib/module/types/CactusVAD.js +0 -4
- package/lib/typescript/src/classes/CactusVAD.d.ts +0 -20
- package/lib/typescript/src/classes/CactusVAD.d.ts.map +0 -1
- package/lib/typescript/src/hooks/useCactusVAD.d.ts +0 -15
- package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +0 -1
- package/lib/typescript/src/types/CactusVAD.d.ts +0 -34
- package/lib/typescript/src/types/CactusVAD.d.ts.map +0 -1
- package/src/types/CactusVAD.ts +0 -39
|
@@ -75,6 +75,7 @@ struct Config {
|
|
|
75
75
|
bool use_pixel_shuffle = false;
|
|
76
76
|
uint32_t pixel_shuffle_factor = 1;
|
|
77
77
|
bool use_image_tokens = false;
|
|
78
|
+
uint32_t image_token_id = 0;
|
|
78
79
|
bool use_layout_tags = false;
|
|
79
80
|
uint32_t image_seq_len = 64;
|
|
80
81
|
|
|
@@ -107,6 +108,26 @@ struct Config {
|
|
|
107
108
|
uint32_t subsampling_factor = 0;
|
|
108
109
|
uint32_t num_mel_bins = 80;
|
|
109
110
|
std::string encoder_hidden_act = "silu";
|
|
111
|
+
uint32_t linear_num_key_heads = 0;
|
|
112
|
+
uint32_t linear_key_head_dim = 0;
|
|
113
|
+
uint32_t linear_num_value_heads = 0;
|
|
114
|
+
uint32_t linear_value_head_dim = 0;
|
|
115
|
+
uint32_t linear_q_proj_dim = 0;
|
|
116
|
+
uint32_t linear_k_proj_dim = 0;
|
|
117
|
+
uint32_t linear_v_proj_dim = 0;
|
|
118
|
+
|
|
119
|
+
uint32_t kv_lora_rank = 0;
|
|
120
|
+
uint32_t q_lora_rank = 0;
|
|
121
|
+
uint32_t qk_head_dim = 0;
|
|
122
|
+
uint32_t qk_nope_head_dim = 0;
|
|
123
|
+
uint32_t qk_rope_head_dim = 0;
|
|
124
|
+
uint32_t v_head_dim = 0;
|
|
125
|
+
uint32_t rope_interleave = 0;
|
|
126
|
+
bool attention_bias = false;
|
|
127
|
+
float rope_scaling_factor = 1.0f;
|
|
128
|
+
float rope_mscale_all_dim = 0.0f;
|
|
129
|
+
|
|
130
|
+
enum class ModelType {QWEN = 0, GEMMA = 1, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, MOONSHINE = 8, SILERO_VAD = 9, PARAKEET = 10, QWEN3P5 = 11, PARAKEET_TDT = 12, GEMMA3N = 13, YOUTU = 14, GEMMA4 = 15, PYANNOTE = 16, WESPEAKER = 17};
|
|
110
131
|
uint32_t predictor_hidden_dim = 0;
|
|
111
132
|
uint32_t predictor_num_layers = 0;
|
|
112
133
|
uint32_t tdt_joint_dim = 0;
|
|
@@ -114,7 +135,6 @@ struct Config {
|
|
|
114
135
|
uint32_t tdt_blank_id = 0;
|
|
115
136
|
std::vector<uint32_t> tdt_durations;
|
|
116
137
|
|
|
117
|
-
enum class ModelType {QWEN = 0, GEMMA = 1, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, MOONSHINE = 8, SILERO_VAD = 9, PARAKEET = 10, PARAKEET_TDT = 11};
|
|
118
138
|
ModelType model_type = ModelType::QWEN;
|
|
119
139
|
|
|
120
140
|
enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
|
|
@@ -138,6 +158,58 @@ struct Config {
|
|
|
138
158
|
std::vector<std::string> layer_types;
|
|
139
159
|
size_t conv_L_cache = 0;
|
|
140
160
|
|
|
161
|
+
uint32_t altup_num_inputs = 4;
|
|
162
|
+
uint32_t laurel_rank = 64;
|
|
163
|
+
uint32_t hidden_size_per_layer_input = 256;
|
|
164
|
+
uint32_t num_kv_shared_layers = 0;
|
|
165
|
+
uint32_t sliding_window = 512;
|
|
166
|
+
float rope_local_base_freq = 10000.0f;
|
|
167
|
+
float final_logit_softcapping = 0.0f;
|
|
168
|
+
float global_partial_rotary_factor = 1.0f;
|
|
169
|
+
uint32_t expert_intermediate_size = 0;
|
|
170
|
+
uint32_t global_head_dim = 0;
|
|
171
|
+
uint32_t num_global_kv_heads = 0;
|
|
172
|
+
bool attention_k_eq_v = false;
|
|
173
|
+
bool enable_moe_block = false;
|
|
174
|
+
std::vector<float> activation_sparsity_ppf;
|
|
175
|
+
|
|
176
|
+
uint32_t vision_head_dim = 64;
|
|
177
|
+
uint32_t vision_kv_heads = 12;
|
|
178
|
+
uint32_t vision_intermediate_size = 3072;
|
|
179
|
+
uint32_t vision_position_embedding_size = 10240;
|
|
180
|
+
uint32_t vision_pooling_kernel_size = 3;
|
|
181
|
+
uint32_t vision_default_output_length = 280;
|
|
182
|
+
float vision_rope_theta = 100.0f;
|
|
183
|
+
|
|
184
|
+
uint32_t audio_hidden_dim = 0;
|
|
185
|
+
uint32_t audio_num_layers = 0;
|
|
186
|
+
uint32_t audio_num_heads = 0;
|
|
187
|
+
uint32_t audio_head_dim = 0;
|
|
188
|
+
uint32_t audio_input_feat_size = 128;
|
|
189
|
+
uint32_t audio_conf_conv_kernel_size = 5;
|
|
190
|
+
uint32_t audio_chunk_size = 12;
|
|
191
|
+
uint32_t audio_context_left = 13;
|
|
192
|
+
uint32_t audio_context_right = 0;
|
|
193
|
+
float audio_logit_cap = 50.0f;
|
|
194
|
+
float audio_residual_weight = 0.5f;
|
|
195
|
+
uint32_t audio_output_proj_dims = 0;
|
|
196
|
+
uint32_t audio_vocab_size = 128;
|
|
197
|
+
uint32_t audio_vocab_offset = 0;
|
|
198
|
+
uint32_t audio_soft_tokens = 188;
|
|
199
|
+
uint32_t audio_sscp_conv0_channels = 128;
|
|
200
|
+
uint32_t audio_sscp_conv1_channels = 32;
|
|
201
|
+
float audio_sscp_conv_eps = 1e-3f;
|
|
202
|
+
float audio_rms_norm_eps = 1e-6f;
|
|
203
|
+
uint32_t audio_fft_length = 1024;
|
|
204
|
+
uint32_t audio_token_id = 0;
|
|
205
|
+
bool audio_fft_overdrive = false;
|
|
206
|
+
uint32_t channel_open_token_id = 100;
|
|
207
|
+
uint32_t channel_close_token_id = 101;
|
|
208
|
+
|
|
209
|
+
static bool is_gemma_family(ModelType t) {
|
|
210
|
+
return t == ModelType::GEMMA || t == ModelType::GEMMA3N || t == ModelType::GEMMA4;
|
|
211
|
+
}
|
|
212
|
+
|
|
141
213
|
bool from_json(const std::string& json_path);
|
|
142
214
|
std::string to_json() const;
|
|
143
215
|
};
|
|
@@ -155,14 +227,38 @@ struct MergeRule {
|
|
|
155
227
|
};
|
|
156
228
|
|
|
157
229
|
|
|
230
|
+
struct ToolCallInfo {
|
|
231
|
+
std::string name;
|
|
232
|
+
std::string arguments;
|
|
233
|
+
};
|
|
234
|
+
|
|
158
235
|
struct ChatMessage {
|
|
159
236
|
std::string role;
|
|
160
237
|
std::string content;
|
|
161
238
|
std::string name;
|
|
162
239
|
std::vector<std::string> images;
|
|
240
|
+
std::vector<std::string> audio;
|
|
241
|
+
size_t audio_soft_token_count = 0;
|
|
242
|
+
std::vector<ToolCallInfo> tool_calls;
|
|
163
243
|
};
|
|
164
244
|
|
|
245
|
+
struct TokenizerRuntimeConfig {
|
|
246
|
+
enum class TokenizerType { UNKNOWN, BPE, SENTENCEPIECE };
|
|
247
|
+
enum class VocabFormat { UNKNOWN, ID_TAB_TOKEN, LINE_TOKEN };
|
|
248
|
+
enum class Normalizer { NONE, METASPACE, BYTE_LEVEL };
|
|
249
|
+
enum class Decoder { NONE, REPLACE_METASPACE, BYTE_LEVEL };
|
|
250
|
+
|
|
251
|
+
TokenizerType tokenizer_type = TokenizerType::UNKNOWN;
|
|
252
|
+
VocabFormat vocab_format = VocabFormat::UNKNOWN;
|
|
253
|
+
Normalizer normalizer = Normalizer::NONE;
|
|
254
|
+
Decoder decoder = Decoder::NONE;
|
|
255
|
+
bool byte_fallback = false;
|
|
256
|
+
bool has_chat_template = false;
|
|
257
|
+
};
|
|
165
258
|
|
|
259
|
+
TokenizerRuntimeConfig load_tokenizer_runtime_config(const std::string& config_file);
|
|
260
|
+
void load_special_tokens_map(const std::string& config_file, std::unordered_map<std::string, uint32_t>& special_tokens);
|
|
261
|
+
std::vector<std::string> split_with_special_tokens(const std::string& text, const std::unordered_map<std::string, uint32_t>& special_tokens);
|
|
166
262
|
|
|
167
263
|
class Tokenizer {
|
|
168
264
|
public:
|
|
@@ -172,7 +268,7 @@ public:
|
|
|
172
268
|
virtual std::string decode(const std::vector<uint32_t>& tokens) const = 0;
|
|
173
269
|
|
|
174
270
|
virtual std::vector<uint32_t> apply_chat_template(const std::vector<ChatMessage>& messages, bool add_generation_prompt = true) const;
|
|
175
|
-
virtual std::string format_chat_prompt(const std::vector<ChatMessage>& messages, bool add_generation_prompt = true, const std::string& tools_json = "") const;
|
|
271
|
+
virtual std::string format_chat_prompt(const std::vector<ChatMessage>& messages, bool add_generation_prompt = true, const std::string& tools_json = "", bool enable_thinking_if_supported = true) const;
|
|
176
272
|
|
|
177
273
|
virtual uint32_t get_vocab_size() const = 0;
|
|
178
274
|
virtual uint32_t get_unk_token() const = 0;
|
|
@@ -188,7 +284,7 @@ public:
|
|
|
188
284
|
uint32_t get_global_img_token_id() const { return global_img_token_id_; }
|
|
189
285
|
|
|
190
286
|
protected:
|
|
191
|
-
enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, BERT, WHISPER, PARAKEET};
|
|
287
|
+
enum class ModelType { UNKNOWN, QWEN, QWEN3P5, GEMMA, GEMMA4, LFM2, BERT, WHISPER, PARAKEET, YOUTU};
|
|
192
288
|
ModelType model_type_ = ModelType::UNKNOWN;
|
|
193
289
|
enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
|
|
194
290
|
ModelVariant model_variant_ = ModelVariant::DEFAULT;
|
|
@@ -199,11 +295,21 @@ protected:
|
|
|
199
295
|
uint32_t fake_token_id_ = 49189;
|
|
200
296
|
uint32_t global_img_token_id_ = 49152;
|
|
201
297
|
|
|
298
|
+
|
|
299
|
+
uint32_t vision_patch_size_ = 16;
|
|
300
|
+
uint32_t vision_pooling_kernel_size_ = 3;
|
|
301
|
+
uint32_t vision_default_output_length_ = 280;
|
|
302
|
+
uint32_t vision_image_size_ = 768;
|
|
303
|
+
TokenizerRuntimeConfig runtime_config_;
|
|
304
|
+
|
|
202
305
|
void detect_model_type(const std::string& config_path);
|
|
203
|
-
|
|
306
|
+
void load_chat_template(const std::string& template_file);
|
|
307
|
+
std::string format_qwen_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json, bool enable_thinking_if_supported = true) const;
|
|
204
308
|
std::string format_gemma_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
309
|
+
std::string format_gemma4_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json, bool enable_thinking_if_supported = true) const;
|
|
205
310
|
std::string format_lfm2_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
206
311
|
std::string format_lfm2_vl_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
312
|
+
std::string format_youtu_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
207
313
|
};
|
|
208
314
|
|
|
209
315
|
class BPETokenizer : public Tokenizer {
|
|
@@ -245,6 +351,7 @@ private:
|
|
|
245
351
|
std::string bytes_to_unicode(const std::string& text) const;
|
|
246
352
|
std::string unicode_to_bytes(const std::string& text) const;
|
|
247
353
|
std::vector<std::string> byte_level_split(const std::string& text) const;
|
|
354
|
+
std::vector<std::string> utf8_split(const std::string& text) const;
|
|
248
355
|
|
|
249
356
|
void cleanup_mmap();
|
|
250
357
|
|
|
@@ -256,12 +363,6 @@ private:
|
|
|
256
363
|
std::unordered_map<std::string, uint32_t> special_tokens_;
|
|
257
364
|
std::vector<std::string> split_with_special_tokens(const std::string& text) const;
|
|
258
365
|
void load_special_tokens(const std::string& config_file);
|
|
259
|
-
|
|
260
|
-
void load_chat_template(const std::string& template_file);
|
|
261
|
-
|
|
262
|
-
std::unordered_map<std::string, uint32_t> tool_tokens_;
|
|
263
|
-
bool has_tool_support_;
|
|
264
|
-
void load_tokenizer_config(const std::string& config_file);
|
|
265
366
|
};
|
|
266
367
|
|
|
267
368
|
class SPTokenizer : public Tokenizer {
|
|
@@ -311,8 +412,6 @@ private:
|
|
|
311
412
|
std::unordered_map<std::string, uint32_t> special_tokens_;
|
|
312
413
|
std::vector<std::string> split_with_special_tokens(const std::string& text) const;
|
|
313
414
|
void load_special_tokens(const std::string& config_file);
|
|
314
|
-
|
|
315
|
-
void load_chat_template(const std::string& template_file);
|
|
316
415
|
};
|
|
317
416
|
|
|
318
417
|
class ConvCache {
|
|
@@ -355,8 +454,10 @@ struct KVCache {
|
|
|
355
454
|
struct LayerCache {
|
|
356
455
|
std::vector<uint8_t> keys;
|
|
357
456
|
std::vector<uint8_t> values;
|
|
358
|
-
std::vector<float> key_scales;
|
|
359
|
-
std::vector<float> value_scales;
|
|
457
|
+
std::vector<float> key_scales;
|
|
458
|
+
std::vector<float> value_scales;
|
|
459
|
+
size_t head_dim = 0;
|
|
460
|
+
size_t kv_heads = 0;
|
|
360
461
|
};
|
|
361
462
|
|
|
362
463
|
std::vector<LayerCache> layer_caches;
|
|
@@ -366,8 +467,6 @@ struct KVCache {
|
|
|
366
467
|
size_t current_seq_len = 0;
|
|
367
468
|
size_t total_seq_len = 0;
|
|
368
469
|
size_t max_seq_len = 2048;
|
|
369
|
-
size_t num_kv_heads = 0;
|
|
370
|
-
size_t head_dim = 0;
|
|
371
470
|
size_t num_layers = 0;
|
|
372
471
|
Precision precision;
|
|
373
472
|
size_t element_size = 4;
|
|
@@ -375,12 +474,14 @@ struct KVCache {
|
|
|
375
474
|
void set_window_size(size_t window, size_t sink = DEFAULT_SINK_SIZE);
|
|
376
475
|
size_t get_effective_seq_len() const { return current_seq_len; }
|
|
377
476
|
size_t get_total_seq_len() const { return total_seq_len; }
|
|
477
|
+
size_t get_layer_head_dim(size_t layer_idx) const { return layer_caches[layer_idx].head_dim; }
|
|
478
|
+
size_t get_layer_kv_heads(size_t layer_idx) const { return layer_caches[layer_idx].kv_heads; }
|
|
378
479
|
|
|
379
|
-
void init(size_t num_layers, size_t max_seq, size_t
|
|
480
|
+
void init(size_t num_layers, size_t max_seq, const std::vector<size_t>& layer_dims, const std::vector<size_t>& layer_kv_heads, Precision model_precision);
|
|
380
481
|
void reset();
|
|
381
482
|
void update_from_graph(CactusGraph* gb, const std::vector<size_t>& k_nodes,
|
|
382
483
|
const std::vector<size_t>& v_nodes, size_t seq_len,
|
|
383
|
-
size_t num_layers
|
|
484
|
+
size_t num_layers);
|
|
384
485
|
|
|
385
486
|
void update_from_npu(size_t layer_idx, const __fp16* k_data, const __fp16* v_data,
|
|
386
487
|
size_t num_tokens, size_t kv_heads, size_t head_dim);
|
|
@@ -404,6 +505,9 @@ struct KVCache {
|
|
|
404
505
|
const int8_t* get_values_int8(size_t layer) const;
|
|
405
506
|
const float* get_key_scales(size_t layer) const;
|
|
406
507
|
const float* get_value_scales(size_t layer) const;
|
|
508
|
+
|
|
509
|
+
void remove_token_range(size_t start, size_t count);
|
|
510
|
+
void compact_to_windows(const std::vector<size_t>& target_windows);
|
|
407
511
|
};
|
|
408
512
|
|
|
409
513
|
class ToolCallConstrainer {
|
|
@@ -421,7 +525,7 @@ public:
|
|
|
421
525
|
QWEN_EXPECT_ARGS_COLON,
|
|
422
526
|
QWEN_IN_ARGUMENTS,
|
|
423
527
|
QWEN_EXPECT_CLOSE_BRACE,
|
|
424
|
-
QWEN_EXPECT_END,
|
|
528
|
+
QWEN_EXPECT_END,
|
|
425
529
|
|
|
426
530
|
LFM_START,
|
|
427
531
|
LFM_EXPECT_BRACKET,
|
|
@@ -457,12 +561,17 @@ private:
|
|
|
457
561
|
Config::ModelType model_type_ = Config::ModelType::QWEN;
|
|
458
562
|
Tokenizer* tokenizer_ = nullptr;
|
|
459
563
|
|
|
564
|
+
bool is_gemma_family() const { return Config::is_gemma_family(model_type_); }
|
|
565
|
+
|
|
460
566
|
std::vector<std::string> function_names_;
|
|
461
567
|
std::string generated_text_;
|
|
462
|
-
int brace_depth_ = 0;
|
|
568
|
+
int brace_depth_ = 0;
|
|
569
|
+
|
|
570
|
+
std::string call_start_tag_;
|
|
571
|
+
std::string call_end_tag_;
|
|
463
572
|
|
|
464
|
-
std::unordered_set<uint32_t> qwen_tool_call_start_tokens_;
|
|
465
|
-
std::unordered_set<uint32_t> qwen_tool_call_end_tokens_;
|
|
573
|
+
std::unordered_set<uint32_t> qwen_tool_call_start_tokens_;
|
|
574
|
+
std::unordered_set<uint32_t> qwen_tool_call_end_tokens_;
|
|
466
575
|
std::unordered_set<uint32_t> open_brace_tokens_;
|
|
467
576
|
std::unordered_set<uint32_t> close_brace_tokens_;
|
|
468
577
|
std::unordered_set<uint32_t> colon_tokens_;
|
|
@@ -472,7 +581,7 @@ private:
|
|
|
472
581
|
std::unordered_set<uint32_t> quote_tokens_;
|
|
473
582
|
std::unordered_set<uint32_t> backtick_tokens_;
|
|
474
583
|
std::unordered_set<uint32_t> all_func_name_tokens_;
|
|
475
|
-
std::unordered_map<std::string, std::vector<uint32_t>> func_name_sequences_;
|
|
584
|
+
std::unordered_map<std::string, std::vector<uint32_t>> func_name_sequences_;
|
|
476
585
|
|
|
477
586
|
std::unordered_set<uint32_t> tool_start_tokens_;
|
|
478
587
|
std::unordered_set<uint32_t> tool_end_tokens_;
|
|
@@ -523,12 +632,16 @@ public:
|
|
|
523
632
|
|
|
524
633
|
virtual void prefill(const std::vector<uint32_t>& tokens, size_t chunk_size = 256, const std::string& profile_file = "");
|
|
525
634
|
|
|
635
|
+
virtual void prefill_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
|
|
636
|
+
const std::string& profile_file = "");
|
|
637
|
+
|
|
526
638
|
virtual uint32_t decode_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
|
|
527
639
|
float temperature = -1.0f, float top_p = -1.0f,
|
|
528
640
|
size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
|
|
529
641
|
|
|
530
642
|
virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& audio_features, float temperature = 0.0f, float top_p = 0.0f,
|
|
531
|
-
size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr
|
|
643
|
+
size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr,
|
|
644
|
+
float* out_token_time_start = nullptr, float* out_token_time_end = nullptr);
|
|
532
645
|
|
|
533
646
|
std::vector<float> get_embeddings(const std::vector<uint32_t>& tokens, bool pooled = true, bool normalize = false, const std::string& profile_file = "");
|
|
534
647
|
|
|
@@ -548,13 +661,37 @@ public:
|
|
|
548
661
|
bool has_npu_prefill() const;
|
|
549
662
|
size_t get_prefill_chunk_size() const;
|
|
550
663
|
|
|
664
|
+
virtual void remove_thinking_tokens(const std::vector<std::pair<size_t, size_t>>& ranges);
|
|
665
|
+
virtual void compact_kv_cache() {}
|
|
666
|
+
|
|
551
667
|
void set_tool_constraints(const std::vector<std::string>& function_names);
|
|
552
668
|
void clear_tool_constraints();
|
|
553
669
|
void update_tool_constraints(uint32_t token_id);
|
|
554
670
|
|
|
555
671
|
void* graph_handle_;
|
|
556
672
|
|
|
673
|
+
void set_vocab_bias(const std::unordered_map<uint32_t, float>& bias) {
|
|
674
|
+
vocab_bias_ = bias;
|
|
675
|
+
}
|
|
676
|
+
|
|
677
|
+
void clear_vocab_bias() {
|
|
678
|
+
vocab_bias_.clear();
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
bool has_vocab_bias() const {
|
|
682
|
+
return !vocab_bias_.empty();
|
|
683
|
+
}
|
|
684
|
+
|
|
685
|
+
const std::unordered_map<uint32_t, float>& get_vocab_bias() const {
|
|
686
|
+
return vocab_bias_;
|
|
687
|
+
}
|
|
688
|
+
|
|
557
689
|
protected:
|
|
690
|
+
size_t sample_token(CactusGraph* gb, size_t logits_node_id, float temperature, float top_p, size_t top_k,
|
|
691
|
+
const std::unordered_map<uint32_t, float>* extra_bias = nullptr) const;
|
|
692
|
+
|
|
693
|
+
static void compute_entropy(CactusGraph* gb, size_t logits_node_id, float* out_entropy);
|
|
694
|
+
|
|
558
695
|
virtual size_t forward(const std::vector<uint32_t>& tokens, bool use_cache = false) = 0;
|
|
559
696
|
|
|
560
697
|
virtual size_t forward(const std::vector<float>& audio_features, const std::vector<uint32_t>& tokens, bool use_cache = false);
|
|
@@ -569,6 +706,12 @@ protected:
|
|
|
569
706
|
virtual size_t build_transformer_block(CactusGraph* gb, size_t hidden, uint32_t layer_idx,
|
|
570
707
|
ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) = 0;
|
|
571
708
|
void update_kv_cache(CactusGraph* gb, size_t seq_len);
|
|
709
|
+
virtual std::vector<size_t> get_kv_layer_dims() const {
|
|
710
|
+
return std::vector<size_t>(config_.num_layers, config_.attention_head_dim);
|
|
711
|
+
}
|
|
712
|
+
virtual std::vector<size_t> get_kv_layer_heads() const {
|
|
713
|
+
return std::vector<size_t>(config_.num_layers, config_.attention_kv_heads);
|
|
714
|
+
}
|
|
572
715
|
virtual void post_init() {}
|
|
573
716
|
virtual void post_execute_updates(CactusGraph*, size_t) {}
|
|
574
717
|
Config config_;
|
|
@@ -601,6 +744,9 @@ protected:
|
|
|
601
744
|
virtual std::vector<__fp16> get_token_embeddings(const std::vector<uint32_t>& tokens);
|
|
602
745
|
|
|
603
746
|
ToolCallConstrainer tool_constrainer_;
|
|
747
|
+
|
|
748
|
+
private:
|
|
749
|
+
std::unordered_map<uint32_t, float> vocab_bias_;
|
|
604
750
|
};
|
|
605
751
|
|
|
606
752
|
std::unique_ptr<Model> create_model(const std::string& model_folder);
|
|
@@ -705,13 +851,17 @@ public:
|
|
|
705
851
|
bool remove_dc_offset = false;
|
|
706
852
|
float preemphasis = 0.0f;
|
|
707
853
|
bool hann_periodic = true;
|
|
854
|
+
float window_a0 = 0.5f;
|
|
855
|
+
size_t fft_override = 0;
|
|
856
|
+
bool mel_floor_additive = false;
|
|
708
857
|
};
|
|
709
858
|
|
|
710
859
|
AudioProcessor();
|
|
711
860
|
~AudioProcessor();
|
|
712
861
|
|
|
713
862
|
void init_mel_filters(size_t num_frequency_bins, size_t num_mel_filters,
|
|
714
|
-
float min_freq, float max_freq, size_t sampling_rate
|
|
863
|
+
float min_freq, float max_freq, size_t sampling_rate,
|
|
864
|
+
const char* norm = "slaney", const char* mel_scale = "slaney");
|
|
715
865
|
|
|
716
866
|
std::vector<float> compute_spectrogram(
|
|
717
867
|
const std::vector<float>& waveform,
|
|
@@ -53,6 +53,7 @@ inline std::string format_argument(const std::string& json, size_t& pos, bool es
|
|
|
53
53
|
char c = json[pos];
|
|
54
54
|
|
|
55
55
|
if (c == '"') {
|
|
56
|
+
pos++;
|
|
56
57
|
std::string value = extract_json_string(json, pos);
|
|
57
58
|
return escape(value);
|
|
58
59
|
} else if (c == '{') {
|
|
@@ -240,7 +241,7 @@ inline std::string format_parameters(const std::string& properties_json, const s
|
|
|
240
241
|
result += ",properties:{" + format_parameters(prop_obj["properties"], nested_required) + "}";
|
|
241
242
|
}
|
|
242
243
|
if (prop_obj.count("required")) {
|
|
243
|
-
|
|
244
|
+
std::string req_items;
|
|
244
245
|
size_t req_pos = 0;
|
|
245
246
|
skip_whitespace(prop_obj["required"], req_pos);
|
|
246
247
|
if (req_pos < prop_obj["required"].length() && prop_obj["required"][req_pos] == '[') {
|
|
@@ -253,13 +254,15 @@ inline std::string format_parameters(const std::string& properties_json, const s
|
|
|
253
254
|
if (prop_obj["required"][req_pos] == '"') {
|
|
254
255
|
req_pos++;
|
|
255
256
|
std::string req_item = extract_json_string(prop_obj["required"], req_pos);
|
|
256
|
-
if (!req_first)
|
|
257
|
+
if (!req_first) req_items += ",";
|
|
257
258
|
req_first = false;
|
|
258
|
-
|
|
259
|
+
req_items += escape(req_item);
|
|
259
260
|
}
|
|
260
261
|
}
|
|
261
262
|
}
|
|
262
|
-
|
|
263
|
+
if (!req_items.empty()) {
|
|
264
|
+
result += ",required:[" + req_items + "]";
|
|
265
|
+
}
|
|
263
266
|
}
|
|
264
267
|
} else if (to_upper(type_val) == "ARRAY") {
|
|
265
268
|
if (prop_obj.count("items")) {
|
|
@@ -342,7 +345,7 @@ inline std::string format_function_declaration(const std::string& name,
|
|
|
342
345
|
}
|
|
343
346
|
|
|
344
347
|
if (params.count("required")) {
|
|
345
|
-
|
|
348
|
+
std::string req_items;
|
|
346
349
|
size_t req_pos = 0;
|
|
347
350
|
skip_whitespace(params["required"], req_pos);
|
|
348
351
|
if (req_pos < params["required"].length() && params["required"][req_pos] == '[') {
|
|
@@ -355,13 +358,15 @@ inline std::string format_function_declaration(const std::string& name,
|
|
|
355
358
|
if (params["required"][req_pos] == '"') {
|
|
356
359
|
req_pos++;
|
|
357
360
|
std::string item = extract_json_string(params["required"], req_pos);
|
|
358
|
-
if (!first)
|
|
361
|
+
if (!first) req_items += ",";
|
|
359
362
|
first = false;
|
|
360
|
-
|
|
363
|
+
req_items += escape(item);
|
|
361
364
|
}
|
|
362
365
|
}
|
|
363
366
|
}
|
|
364
|
-
|
|
367
|
+
if (!req_items.empty()) {
|
|
368
|
+
result += ",required:[" + req_items + "]";
|
|
369
|
+
}
|
|
365
370
|
}
|
|
366
371
|
|
|
367
372
|
if (params.count("type")) {
|
|
@@ -377,12 +382,15 @@ inline std::string format_function_declaration(const std::string& name,
|
|
|
377
382
|
}
|
|
378
383
|
|
|
379
384
|
template<typename ToolFunction>
|
|
380
|
-
inline std::string format_tools(const std::vector<ToolFunction>& tools) {
|
|
385
|
+
inline std::string format_tools(const std::vector<ToolFunction>& tools, bool use_pipe_tags = false) {
|
|
381
386
|
if (tools.empty()) return "";
|
|
382
387
|
|
|
388
|
+
const char* decl_start = use_pipe_tags ? "<|tool>" : "<start_function_declaration>";
|
|
389
|
+
const char* decl_end = use_pipe_tags ? "<tool|>" : "<end_function_declaration>";
|
|
390
|
+
|
|
383
391
|
std::string result;
|
|
384
392
|
for (const auto& tool : tools) {
|
|
385
|
-
result +=
|
|
393
|
+
result += decl_start;
|
|
386
394
|
std::string params_json;
|
|
387
395
|
auto it = tool.parameters.find("schema");
|
|
388
396
|
if (it != tool.parameters.end()) {
|
|
@@ -390,12 +398,26 @@ inline std::string format_tools(const std::vector<ToolFunction>& tools) {
|
|
|
390
398
|
}
|
|
391
399
|
|
|
392
400
|
result += format_function_declaration(tool.name, tool.description, params_json);
|
|
393
|
-
result +=
|
|
401
|
+
result += decl_end;
|
|
394
402
|
}
|
|
395
403
|
return result;
|
|
396
404
|
}
|
|
397
405
|
|
|
398
406
|
|
|
407
|
+
inline size_t match_quote_tag(const std::string& s, size_t pos) {
|
|
408
|
+
if (s.compare(pos, 8, "<escape>") == 0) return 8;
|
|
409
|
+
if (s.compare(pos, 5, "<|\"|>") == 0) return 5;
|
|
410
|
+
return 0;
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
inline size_t find_quote_tag(const std::string& s, size_t pos) {
|
|
414
|
+
size_t e = s.find("<escape>", pos);
|
|
415
|
+
size_t t = s.find("<|\"|>", pos);
|
|
416
|
+
if (e == std::string::npos) return t;
|
|
417
|
+
if (t == std::string::npos) return e;
|
|
418
|
+
return std::min(e, t);
|
|
419
|
+
}
|
|
420
|
+
|
|
399
421
|
inline std::string unescape(const std::string& s) {
|
|
400
422
|
const std::string ESCAPE_TAG = "<escape>";
|
|
401
423
|
std::string result = s;
|
|
@@ -427,12 +449,13 @@ inline std::string args_to_json(const std::string& args_content) {
|
|
|
427
449
|
while (pos < args_content.length() && std::isspace(args_content[pos])) pos++;
|
|
428
450
|
|
|
429
451
|
if (pos < args_content.length()) {
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
452
|
+
size_t qtag_len = match_quote_tag(args_content, pos);
|
|
453
|
+
if (qtag_len > 0) {
|
|
454
|
+
pos += qtag_len;
|
|
455
|
+
size_t val_end = find_quote_tag(args_content, pos);
|
|
433
456
|
if (val_end != std::string::npos) {
|
|
434
457
|
value = "\"" + args_content.substr(pos, val_end - pos) + "\"";
|
|
435
|
-
pos = val_end +
|
|
458
|
+
pos = val_end + match_quote_tag(args_content, val_end);
|
|
436
459
|
}
|
|
437
460
|
} else if (args_content[pos] == '{') {
|
|
438
461
|
int depth = 1;
|
|
@@ -464,12 +487,13 @@ inline std::string args_to_json(const std::string& args_content) {
|
|
|
464
487
|
if (!first_item) value += ",";
|
|
465
488
|
first_item = false;
|
|
466
489
|
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
490
|
+
size_t aq_len = match_quote_tag(arr_content, arr_pos);
|
|
491
|
+
if (aq_len > 0) {
|
|
492
|
+
arr_pos += aq_len;
|
|
493
|
+
size_t end = find_quote_tag(arr_content, arr_pos);
|
|
470
494
|
if (end != std::string::npos) {
|
|
471
495
|
value += "\"" + arr_content.substr(arr_pos, end - arr_pos) + "\"";
|
|
472
|
-
arr_pos = end +
|
|
496
|
+
arr_pos = end + match_quote_tag(arr_content, end);
|
|
473
497
|
}
|
|
474
498
|
} else {
|
|
475
499
|
size_t end = arr_content.find_first_of(",]", arr_pos);
|
|
@@ -499,8 +523,11 @@ inline std::string args_to_json(const std::string& args_content) {
|
|
|
499
523
|
}
|
|
500
524
|
|
|
501
525
|
inline void parse_function_calls(std::string& response, std::vector<std::string>& function_calls) {
|
|
502
|
-
|
|
503
|
-
const std::string
|
|
526
|
+
|
|
527
|
+
const std::string CALL_START = (response.find("<|tool_call>") != std::string::npos)
|
|
528
|
+
? "<|tool_call>" : "<start_function_call>";
|
|
529
|
+
const std::string CALL_END = (CALL_START == "<|tool_call>")
|
|
530
|
+
? "<tool_call|>" : "<end_function_call>";
|
|
504
531
|
size_t pos = 0;
|
|
505
532
|
|
|
506
533
|
while ((pos = response.find(CALL_START, pos)) != std::string::npos) {
|