cactus-react-native 1.4.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. package/Cactus.podspec +1 -1
  2. package/README.md +465 -174
  3. package/android/CMakeLists.txt +24 -5
  4. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
  9. package/cpp/HybridCactus.cpp +157 -6
  10. package/cpp/HybridCactus.hpp +20 -3
  11. package/cpp/cactus_ffi.h +65 -30
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +65 -30
  14. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +357 -122
  15. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +184 -63
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
  17. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +153 -27
  18. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +90 -178
  19. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +276 -151
  20. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
  22. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +65 -30
  23. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +357 -122
  24. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +184 -63
  25. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
  26. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +153 -27
  27. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +90 -178
  28. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +276 -151
  29. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  30. package/lib/module/classes/CactusLM.js +43 -58
  31. package/lib/module/classes/CactusLM.js.map +1 -1
  32. package/lib/module/classes/CactusSTT.js +64 -38
  33. package/lib/module/classes/CactusSTT.js.map +1 -1
  34. package/lib/module/classes/CactusVAD.js +95 -0
  35. package/lib/module/classes/CactusVAD.js.map +1 -0
  36. package/lib/module/hooks/useCactusLM.js +23 -15
  37. package/lib/module/hooks/useCactusLM.js.map +1 -1
  38. package/lib/module/hooks/useCactusSTT.js +85 -28
  39. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  40. package/lib/module/hooks/useCactusVAD.js +171 -0
  41. package/lib/module/hooks/useCactusVAD.js.map +1 -0
  42. package/lib/module/index.js +2 -3
  43. package/lib/module/index.js.map +1 -1
  44. package/lib/module/modelRegistry.js +52 -0
  45. package/lib/module/modelRegistry.js.map +1 -0
  46. package/lib/module/native/Cactus.js +107 -8
  47. package/lib/module/native/Cactus.js.map +1 -1
  48. package/lib/module/native/CactusIndex.js.map +1 -1
  49. package/lib/module/native/index.js +0 -3
  50. package/lib/module/native/index.js.map +1 -1
  51. package/lib/module/types/CactusLM.js +2 -0
  52. package/lib/module/types/CactusSTT.js +2 -0
  53. package/lib/module/types/CactusVAD.js +4 -0
  54. package/lib/module/types/{CactusModel.js.map → CactusVAD.js.map} +1 -1
  55. package/lib/module/types/common.js +2 -0
  56. package/lib/module/types/{CactusSTTModel.js.map → common.js.map} +1 -1
  57. package/lib/typescript/src/classes/CactusLM.d.ts +8 -6
  58. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  59. package/lib/typescript/src/classes/CactusSTT.d.ts +11 -6
  60. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  61. package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
  62. package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
  63. package/lib/typescript/src/hooks/useCactusLM.d.ts +3 -3
  64. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  65. package/lib/typescript/src/hooks/useCactusSTT.d.ts +11 -5
  66. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  67. package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
  68. package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
  69. package/lib/typescript/src/index.d.ts +7 -6
  70. package/lib/typescript/src/index.d.ts.map +1 -1
  71. package/lib/typescript/src/modelRegistry.d.ts +5 -0
  72. package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
  73. package/lib/typescript/src/native/Cactus.d.ts +12 -6
  74. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  75. package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
  76. package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
  77. package/lib/typescript/src/native/index.d.ts +0 -3
  78. package/lib/typescript/src/native/index.d.ts.map +1 -1
  79. package/lib/typescript/src/specs/Cactus.nitro.d.ts +6 -1
  80. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  81. package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
  82. package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
  83. package/lib/typescript/src/types/CactusLM.d.ts +19 -9
  84. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  85. package/lib/typescript/src/types/CactusSTT.d.ts +45 -4
  86. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  87. package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
  88. package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
  89. package/lib/typescript/src/types/common.d.ts +23 -0
  90. package/lib/typescript/src/types/common.d.ts.map +1 -0
  91. package/nitro.json +0 -11
  92. package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
  93. package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
  94. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
  95. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
  96. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
  97. package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
  98. package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
  99. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
  100. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +6 -1
  101. package/package.json +3 -3
  102. package/src/classes/CactusLM.ts +59 -74
  103. package/src/classes/CactusSTT.ts +92 -49
  104. package/src/classes/CactusVAD.ts +129 -0
  105. package/src/hooks/useCactusLM.ts +26 -9
  106. package/src/hooks/useCactusSTT.ts +105 -44
  107. package/src/hooks/useCactusVAD.ts +215 -0
  108. package/src/index.tsx +20 -10
  109. package/src/modelRegistry.ts +65 -0
  110. package/src/native/Cactus.ts +130 -14
  111. package/src/native/CactusIndex.ts +2 -2
  112. package/src/native/index.ts +0 -3
  113. package/src/specs/Cactus.nitro.ts +11 -2
  114. package/src/types/CactusIndex.ts +2 -2
  115. package/src/types/CactusLM.ts +20 -9
  116. package/src/types/CactusSTT.ts +50 -4
  117. package/src/types/CactusVAD.ts +39 -0
  118. package/src/types/common.ts +23 -0
  119. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
  120. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
  121. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
  122. package/cpp/HybridCactusUtil.cpp +0 -47
  123. package/cpp/HybridCactusUtil.hpp +0 -27
  124. package/cpp/cactus_util.h +0 -25
  125. package/ios/HybridCactusCrypto.swift +0 -37
  126. package/ios/HybridCactusDeviceInfo.swift +0 -32
  127. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
  128. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
  129. package/ios/cactus_util.xcframework/Info.plist +0 -39
  130. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
  131. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
  132. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
  133. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
  134. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
  135. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
  136. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
  137. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
  138. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
  139. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
  140. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
  141. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
  142. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
  143. package/lib/module/api/Database.js +0 -137
  144. package/lib/module/api/Database.js.map +0 -1
  145. package/lib/module/api/RemoteLM.js +0 -201
  146. package/lib/module/api/RemoteLM.js.map +0 -1
  147. package/lib/module/config/CactusConfig.js +0 -12
  148. package/lib/module/config/CactusConfig.js.map +0 -1
  149. package/lib/module/native/CactusCrypto.js +0 -10
  150. package/lib/module/native/CactusCrypto.js.map +0 -1
  151. package/lib/module/native/CactusDeviceInfo.js +0 -13
  152. package/lib/module/native/CactusDeviceInfo.js.map +0 -1
  153. package/lib/module/native/CactusUtil.js +0 -36
  154. package/lib/module/native/CactusUtil.js.map +0 -1
  155. package/lib/module/specs/CactusCrypto.nitro.js +0 -4
  156. package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
  157. package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
  158. package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
  159. package/lib/module/specs/CactusUtil.nitro.js +0 -4
  160. package/lib/module/specs/CactusUtil.nitro.js.map +0 -1
  161. package/lib/module/telemetry/Telemetry.js +0 -154
  162. package/lib/module/telemetry/Telemetry.js.map +0 -1
  163. package/lib/module/types/CactusModel.js +0 -2
  164. package/lib/module/types/CactusSTTModel.js +0 -2
  165. package/lib/typescript/src/api/Database.d.ts +0 -18
  166. package/lib/typescript/src/api/Database.d.ts.map +0 -1
  167. package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
  168. package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
  169. package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
  170. package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
  171. package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
  172. package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
  173. package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
  174. package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
  175. package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
  176. package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
  177. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
  178. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
  179. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
  180. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
  181. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
  182. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
  183. package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
  184. package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
  185. package/lib/typescript/src/types/CactusModel.d.ts +0 -13
  186. package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
  187. package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
  188. package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
  189. package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
  190. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
  191. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
  192. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
  193. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
  194. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
  195. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
  196. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
  197. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
  198. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
  199. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
  200. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
  201. package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
  202. package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
  203. package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
  204. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
  205. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
  206. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
  207. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
  208. package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
  209. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
  210. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
  211. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
  212. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
  213. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
  214. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
  215. package/src/api/Database.ts +0 -188
  216. package/src/api/RemoteLM.ts +0 -273
  217. package/src/config/CactusConfig.ts +0 -11
  218. package/src/native/CactusCrypto.ts +0 -11
  219. package/src/native/CactusDeviceInfo.ts +0 -18
  220. package/src/native/CactusUtil.ts +0 -43
  221. package/src/specs/CactusCrypto.nitro.ts +0 -6
  222. package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
  223. package/src/specs/CactusUtil.nitro.ts +0 -8
  224. package/src/telemetry/Telemetry.ts +0 -236
  225. package/src/types/CactusModel.ts +0 -15
  226. package/src/types/CactusSTTModel.ts +0 -10
@@ -84,12 +84,17 @@ struct Config {
84
84
  bool use_thumbnail = true;
85
85
  uint32_t min_image_tokens = 64;
86
86
  uint32_t max_image_tokens = 256;
87
- uint32_t max_num_patches = 1024;
87
+ uint32_t max_num_patches = 1024;
88
88
  uint32_t tile_size = 512;
89
89
  float max_pixels_tolerance = 2.0f;
90
90
  bool do_image_splitting = true;
91
+ bool encoder_act_gelu = false;
92
+ bool decoder_act_gelu = false;
93
+ uint32_t num_encoder_layers = 0;
94
+ uint32_t num_decoder_layers = 0;
95
+ float partial_rotary_factor = 0.0f;
91
96
 
92
- enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7};
97
+ enum class ModelType {QWEN = 0, GEMMA = 1, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, MOONSHINE = 8, SILERO_VAD = 9};
93
98
  ModelType model_type = ModelType::QWEN;
94
99
 
95
100
  enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
@@ -107,6 +112,8 @@ struct Config {
107
112
  float default_temperature = 0.6f;
108
113
  float default_top_p = 0.95f;
109
114
  size_t default_top_k = 20;
115
+ float default_max_tps = -1.0f;
116
+ float default_cloud_handoff_threshold = 0.0f;
110
117
 
111
118
  std::vector<std::string> layer_types;
112
119
  size_t conv_L_cache = 0;
@@ -131,9 +138,12 @@ struct MergeRule {
131
138
  struct ChatMessage {
132
139
  std::string role;
133
140
  std::string content;
141
+ std::string name;
134
142
  std::vector<std::string> images;
135
143
  };
136
144
 
145
+
146
+
137
147
  class Tokenizer {
138
148
  public:
139
149
  virtual ~Tokenizer() = default;
@@ -149,6 +159,7 @@ public:
149
159
  virtual uint32_t get_bos_token() const = 0;
150
160
  virtual uint32_t get_eos_token() const = 0;
151
161
  virtual bool has_chat_template() const { return has_chat_template_; }
162
+ std::string get_default_stop_sequence() const;
152
163
 
153
164
  virtual bool load_vocabulary_with_config(const std::string& vocab_file, const std::string& merges_file, const std::string& config_file) = 0;
154
165
 
@@ -156,11 +167,8 @@ public:
156
167
  uint32_t get_fake_token_id() const { return fake_token_id_; }
157
168
  uint32_t get_global_img_token_id() const { return global_img_token_id_; }
158
169
 
159
-
160
- void set_corpus_dir(const std::string& dir) { corpus_dir_ = dir; }
161
-
162
170
  protected:
163
- enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT, WHISPER};
171
+ enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, BERT, WHISPER};
164
172
  ModelType model_type_ = ModelType::UNKNOWN;
165
173
  enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
166
174
  ModelVariant model_variant_ = ModelVariant::DEFAULT;
@@ -170,14 +178,12 @@ protected:
170
178
  uint32_t image_token_id_ = 396;
171
179
  uint32_t fake_token_id_ = 49189;
172
180
  uint32_t global_img_token_id_ = 49152;
173
- std::string corpus_dir_;
174
181
 
175
182
  void detect_model_type(const std::string& config_path);
176
183
  std::string format_qwen_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
177
184
  std::string format_gemma_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
178
185
  std::string format_lfm2_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
179
186
  std::string format_lfm2_vl_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
180
- std::string format_smol_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
181
187
  };
182
188
 
183
189
  class BPETokenizer : public Tokenizer {
@@ -329,6 +335,8 @@ struct KVCache {
329
335
  struct LayerCache {
330
336
  std::vector<uint8_t> keys;
331
337
  std::vector<uint8_t> values;
338
+ std::vector<float> key_scales;
339
+ std::vector<float> value_scales;
332
340
  };
333
341
 
334
342
  std::vector<LayerCache> layer_caches;
@@ -354,13 +362,11 @@ struct KVCache {
354
362
  const std::vector<size_t>& v_nodes, size_t seq_len,
355
363
  size_t num_layers, size_t kv_heads, size_t head_dim);
356
364
 
357
- // Update KV cache from NPU prefill outputs
358
- // NPU outputs are in shape [num_tokens, num_kv_heads, head_dim]
359
- // This handles transposition to cache format and sliding window
360
365
  void update_from_npu(size_t layer_idx, const __fp16* k_data, const __fp16* v_data,
361
366
  size_t num_tokens, size_t kv_heads, size_t head_dim);
362
367
 
363
368
  bool is_empty() const { return current_seq_len == 0; }
369
+ bool is_int8() const { return precision == Precision::INT8; }
364
370
  void* get_key_ptr(size_t layer);
365
371
  void* get_value_ptr(size_t layer);
366
372
 
@@ -374,33 +380,44 @@ struct KVCache {
374
380
 
375
381
  CircularView get_key_view(size_t layer);
376
382
  CircularView get_value_view(size_t layer);
383
+
384
+ const int8_t* get_keys_int8(size_t layer) const;
385
+ const int8_t* get_values_int8(size_t layer) const;
386
+ const float* get_key_scales(size_t layer) const;
387
+ const float* get_value_scales(size_t layer) const;
377
388
  };
378
389
 
379
390
  class ToolCallConstrainer {
380
391
  public:
381
392
  enum class State {
382
- START, // -> expect {
383
- EXPECT_FC_KEY, // -> expect "function_call"
384
- EXPECT_FC_COLON, // -> expect :
385
- EXPECT_FC_OPEN_BRACE, // -> expect {
386
- EXPECT_NAME_KEY, // -> expect "name"
387
- EXPECT_NAME_COLON, // -> expect :
388
- EXPECT_NAME_VALUE, // -> expect "<function_name>"
389
- EXPECT_COMMA, // -> expect ,
390
- EXPECT_ARGS_KEY, // -> expect "arguments"
391
- EXPECT_ARGS_COLON, // -> expect :
392
- IN_ARGUMENTS, // -> free JSON, track brace depth
393
- EXPECT_INNER_CLOSE, // -> expect } to close inner object
394
- EXPECT_OUTER_CLOSE, // -> expect } to close outer object
395
- DONE, // complete
396
-
397
- LFM_START, // -> expect <|tool_call_start|>
398
- LFM_EXPECT_BRACKET, // -> expect [
399
- LFM_IN_FUNC_NAME, // -> expect function name
400
- LFM_EXPECT_PAREN, // -> expect (
401
- LFM_IN_ARGUMENTS, // -> arguments until )
402
- LFM_EXPECT_BRACKET_CLOSE, // -> expect ]
403
- LFM_EXPECT_END // -> expect <|tool_call_end|>
393
+ DONE,
394
+
395
+ QWEN_START,
396
+ QWEN_EXPECT_OPEN_BRACE,
397
+ QWEN_EXPECT_NAME_KEY,
398
+ QWEN_EXPECT_NAME_COLON,
399
+ QWEN_EXPECT_NAME_VALUE,
400
+ QWEN_EXPECT_COMMA,
401
+ QWEN_EXPECT_ARGS_KEY,
402
+ QWEN_EXPECT_ARGS_COLON,
403
+ QWEN_IN_ARGUMENTS,
404
+ QWEN_EXPECT_CLOSE_BRACE,
405
+ QWEN_EXPECT_END,
406
+
407
+ LFM_START,
408
+ LFM_EXPECT_BRACKET,
409
+ LFM_IN_FUNC_NAME,
410
+ LFM_EXPECT_PAREN,
411
+ LFM_IN_ARGUMENTS,
412
+ LFM_EXPECT_BRACKET_CLOSE,
413
+ LFM_EXPECT_END,
414
+
415
+ GEMMA_START,
416
+ GEMMA_EXPECT_CALL,
417
+ GEMMA_IN_FUNC_NAME,
418
+ GEMMA_EXPECT_BRACE,
419
+ GEMMA_IN_ARGUMENTS,
420
+ GEMMA_EXPECT_END
404
421
  };
405
422
 
406
423
  void init(Config::ModelType model_type,
@@ -417,42 +434,48 @@ public:
417
434
 
418
435
  private:
419
436
  bool active_ = false;
420
- State state_ = State::START;
437
+ State state_ = State::QWEN_START;
421
438
  Config::ModelType model_type_ = Config::ModelType::QWEN;
422
439
  Tokenizer* tokenizer_ = nullptr;
423
440
 
424
441
  std::vector<std::string> function_names_;
425
442
  std::string generated_text_;
426
- int brace_depth_ = 0; // Track nested braces in arguments
427
-
428
- // Pre-tokenized token sets for each grammar element
429
- std::unordered_set<uint32_t> open_brace_tokens_; // {
430
- std::unordered_set<uint32_t> close_brace_tokens_; // }
431
- std::unordered_set<uint32_t> colon_tokens_; // :
432
- std::unordered_set<uint32_t> comma_tokens_; // ,
433
- std::unordered_set<uint32_t> fc_key_tokens_; // "function_call"
434
- std::unordered_set<uint32_t> name_key_tokens_; // "name"
435
- std::unordered_set<uint32_t> args_key_tokens_; // "arguments"
436
- std::unordered_set<uint32_t> quote_tokens_; // "
437
- std::unordered_set<uint32_t> backtick_tokens_; // ` (to block markdown code fences)
438
- std::unordered_set<uint32_t> response_starter_tokens_; // Common response starters to block (I, I'm, Sorry, etc.)
439
- std::unordered_set<uint32_t> all_func_name_tokens_; // All function name tokens combined
440
- std::unordered_map<std::string, std::vector<uint32_t>> func_name_sequences_; // Full token sequence per function
441
-
442
- // LFM2-specific tokens
443
+ int brace_depth_ = 0;
444
+
445
+ std::unordered_set<uint32_t> qwen_tool_call_start_tokens_;
446
+ std::unordered_set<uint32_t> qwen_tool_call_end_tokens_;
447
+ std::unordered_set<uint32_t> open_brace_tokens_;
448
+ std::unordered_set<uint32_t> close_brace_tokens_;
449
+ std::unordered_set<uint32_t> colon_tokens_;
450
+ std::unordered_set<uint32_t> comma_tokens_;
451
+ std::unordered_set<uint32_t> name_key_tokens_;
452
+ std::unordered_set<uint32_t> args_key_tokens_;
453
+ std::unordered_set<uint32_t> quote_tokens_;
454
+ std::unordered_set<uint32_t> backtick_tokens_;
455
+ std::unordered_set<uint32_t> all_func_name_tokens_;
456
+ std::unordered_map<std::string, std::vector<uint32_t>> func_name_sequences_;
457
+
443
458
  std::unordered_set<uint32_t> tool_start_tokens_;
444
459
  std::unordered_set<uint32_t> tool_end_tokens_;
445
- std::unordered_set<uint32_t> bracket_open_tokens_; // [
446
- std::unordered_set<uint32_t> bracket_close_tokens_; // ]
447
- std::unordered_set<uint32_t> paren_open_tokens_; // (
448
- std::unordered_set<uint32_t> paren_close_tokens_; // )
449
- std::unordered_set<uint32_t> equals_tokens_; // =
460
+ std::unordered_set<uint32_t> bracket_open_tokens_;
461
+ std::unordered_set<uint32_t> bracket_close_tokens_;
462
+ std::unordered_set<uint32_t> paren_open_tokens_;
463
+ std::unordered_set<uint32_t> paren_close_tokens_;
464
+ std::unordered_set<uint32_t> equals_tokens_;
465
+
466
+ std::unordered_set<uint32_t> gemma_call_start_tokens_;
467
+ std::unordered_set<uint32_t> gemma_call_end_tokens_;
468
+ std::unordered_set<uint32_t> gemma_response_start_tokens_;
469
+ std::unordered_set<uint32_t> gemma_call_prefix_tokens_;
470
+ std::unordered_set<uint32_t> escape_tokens_;
450
471
 
451
472
  std::unordered_map<uint32_t, float> current_bias_;
452
473
 
453
474
  void compute_bias();
454
475
  void tokenize_grammar_elements();
455
476
  void add_tokens_for_string(const std::string& str, std::unordered_set<uint32_t>& token_set);
477
+ void tokenize_function_names(bool quote_names);
478
+ void init_common_tokens();
456
479
  };
457
480
 
458
481
  class Model {
@@ -477,22 +500,22 @@ public:
477
500
  const std::string& system_prompt = "", bool do_warmup = true);
478
501
 
479
502
  virtual uint32_t decode(const std::vector<uint32_t>& tokens, float temperature = -1.0f, float top_p = -1.0f,
480
- size_t top_k = 0, const std::string& profile_file = "");
503
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
481
504
 
482
505
  virtual void prefill(const std::vector<uint32_t>& tokens, size_t chunk_size = 256, const std::string& profile_file = "");
483
506
 
484
507
  virtual uint32_t decode_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
485
508
  float temperature = -1.0f, float top_p = -1.0f,
486
- size_t top_k = 0, const std::string& profile_file = "");
509
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
487
510
 
488
- virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& mel_bins, float temperature = 0.0f, float top_p = 0.0f,
489
- size_t top_k = 0, const std::string& profile_file = "");
511
+ virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& audio_features, float temperature = 0.0f, float top_p = 0.0f,
512
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
490
513
 
491
514
  std::vector<float> get_embeddings(const std::vector<uint32_t>& tokens, bool pooled = true, bool normalize = false, const std::string& profile_file = "");
492
515
 
493
516
  virtual std::vector<float> get_image_embeddings(const std::string& image_path);
494
517
 
495
- virtual std::vector<float> get_audio_embeddings(const std::vector<float>& mel_bins);
518
+ virtual std::vector<float> get_audio_embeddings(const std::vector<float>& audio_features);
496
519
 
497
520
  virtual void reset_cache() { kv_cache_.reset(); }
498
521
 
@@ -515,7 +538,7 @@ public:
515
538
  protected:
516
539
  virtual size_t forward(const std::vector<uint32_t>& tokens, bool use_cache = false) = 0;
517
540
 
518
- virtual size_t forward(const std::vector<float>& mel_bins, const std::vector<uint32_t>& tokens, bool use_cache = false);
541
+ virtual size_t forward(const std::vector<float>& audio_features, const std::vector<uint32_t>& tokens, bool use_cache = false);
519
542
 
520
543
  virtual void load_weights_to_graph(CactusGraph* gb) = 0;
521
544
 
@@ -627,6 +650,7 @@ public:
627
650
  private:
628
651
  Config config_;
629
652
 
653
+ std::pair<int64_t, int64_t> compute_pixel_limits() const;
630
654
  std::vector<unsigned char> convert_to_rgb(const unsigned char* img_data, int width, int height, int channels);
631
655
  std::pair<int, int> smart_resize(int height, int width);
632
656
  bool is_image_too_large(int height, int width);
@@ -683,5 +707,102 @@ private:
683
707
  size_t num_mel_filters_;
684
708
  };
685
709
 
710
+ namespace index {
711
+ constexpr uint32_t MAGIC = 0x43414354;
712
+ constexpr uint32_t VERSION = 1;
713
+
714
+ struct Document {
715
+ int id;
716
+ std::vector<float> embedding;
717
+ std::string content;
718
+ std::string metadata;
719
+ };
720
+
721
+ struct QueryResult {
722
+ int doc_id;
723
+ float score;
724
+ };
725
+
726
+ struct QueryOptions {
727
+ size_t top_k = 10;
728
+ float score_threshold = -1.0f;
729
+ };
730
+
731
+ class Index {
732
+ public:
733
+ Index(const std::string& index_path, const std::string& data_path, size_t embedding_dim);
734
+ ~Index();
735
+
736
+ Index(const Index&) = delete;
737
+ Index& operator=(const Index&) = delete;
738
+ Index(Index&&) = delete;
739
+ Index& operator=(Index&&) = delete;
740
+
741
+ void add_documents(const std::vector<Document>& documents);
742
+ void delete_documents(const std::vector<int>& doc_ids);
743
+ std::vector<Document> get_documents(const std::vector<int>& doc_ids);
744
+ std::vector<std::vector<QueryResult>> query(const std::vector<std::vector<float>>& embeddings, const QueryOptions& options);
745
+ void compact();
746
+
747
+ private:
748
+ struct IndexHeader {
749
+ uint32_t magic;
750
+ uint32_t version;
751
+ uint32_t embedding_dim;
752
+ uint32_t num_documents;
753
+ };
754
+
755
+ struct IndexEntry {
756
+ int32_t doc_id;
757
+ uint64_t data_offset;
758
+ uint8_t flags; // bit 0: tombstone
759
+
760
+ const __fp16* embedding() const {
761
+ return reinterpret_cast<const __fp16*>(this + 1);
762
+ }
763
+
764
+ static size_t size(size_t embedding_dim) {
765
+ return sizeof(IndexEntry) + embedding_dim * sizeof(__fp16);
766
+ }
767
+ };
768
+
769
+ struct DataHeader {
770
+ uint32_t magic;
771
+ uint32_t version;
772
+ };
773
+
774
+ struct DataEntry {
775
+ uint16_t content_len;
776
+ uint16_t metadata_len;
777
+
778
+ const char* content() const {
779
+ return reinterpret_cast<const char*>(this + 1);
780
+ }
781
+
782
+ const char* metadata() const {
783
+ return content() + content_len;
784
+ }
785
+ };
786
+
787
+ void parse_index_header();
788
+ void parse_data_header();
789
+ void build_doc_id_map();
790
+ void validate_documents(const std::vector<Document>& documents);
791
+ void validate_doc_ids(const std::vector<int>& doc_ids);
792
+ ssize_t write_full(int fd, const void* buf, size_t count);
793
+
794
+ std::unordered_map<int, uint32_t> doc_id_map_;
795
+
796
+ std::string index_path_, data_path_;
797
+ size_t embedding_dim_;
798
+ size_t index_entry_size_;
799
+ uint32_t num_documents_;
800
+
801
+ int index_fd_, data_fd_;
802
+ void *mapped_index_, *mapped_data_;
803
+ size_t index_file_size_, data_file_size_;
804
+ };
805
+ } // namespace index
806
+
807
+ }
686
808
  }
687
- }