cactus-react-native 1.5.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. package/Cactus.podspec +1 -1
  2. package/README.md +347 -241
  3. package/android/CMakeLists.txt +24 -5
  4. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
  9. package/cpp/HybridCactus.cpp +149 -117
  10. package/cpp/HybridCactus.hpp +14 -10
  11. package/cpp/cactus_ffi.h +54 -43
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +54 -43
  14. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +318 -123
  15. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +118 -15
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +77 -32
  17. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +68 -6
  18. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +21 -155
  19. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  20. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +54 -43
  22. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +318 -123
  23. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +118 -15
  24. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +77 -32
  25. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +68 -6
  26. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +21 -155
  27. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  28. package/lib/module/classes/CactusLM.js +16 -49
  29. package/lib/module/classes/CactusLM.js.map +1 -1
  30. package/lib/module/classes/CactusSTT.js +30 -79
  31. package/lib/module/classes/CactusSTT.js.map +1 -1
  32. package/lib/module/classes/CactusVAD.js +95 -0
  33. package/lib/module/classes/CactusVAD.js.map +1 -0
  34. package/lib/module/hooks/useCactusLM.js +10 -11
  35. package/lib/module/hooks/useCactusLM.js.map +1 -1
  36. package/lib/module/hooks/useCactusSTT.js +23 -62
  37. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  38. package/lib/module/hooks/useCactusVAD.js +171 -0
  39. package/lib/module/hooks/useCactusVAD.js.map +1 -0
  40. package/lib/module/index.js +2 -3
  41. package/lib/module/index.js.map +1 -1
  42. package/lib/module/modelRegistry.js +52 -0
  43. package/lib/module/modelRegistry.js.map +1 -0
  44. package/lib/module/native/Cactus.js +85 -23
  45. package/lib/module/native/Cactus.js.map +1 -1
  46. package/lib/module/native/CactusIndex.js.map +1 -1
  47. package/lib/module/native/index.js +0 -3
  48. package/lib/module/native/index.js.map +1 -1
  49. package/lib/module/types/CactusVAD.js +4 -0
  50. package/lib/module/{specs/CactusUtil.nitro.js.map → types/CactusVAD.js.map} +1 -1
  51. package/lib/typescript/src/classes/CactusLM.d.ts +5 -7
  52. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  53. package/lib/typescript/src/classes/CactusSTT.d.ts +8 -12
  54. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  55. package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
  56. package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
  57. package/lib/typescript/src/hooks/useCactusLM.d.ts +2 -2
  58. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  59. package/lib/typescript/src/hooks/useCactusSTT.d.ts +6 -8
  60. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  61. package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
  62. package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
  63. package/lib/typescript/src/index.d.ts +7 -5
  64. package/lib/typescript/src/index.d.ts.map +1 -1
  65. package/lib/typescript/src/modelRegistry.d.ts +5 -0
  66. package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
  67. package/lib/typescript/src/native/Cactus.d.ts +12 -11
  68. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  69. package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
  70. package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
  71. package/lib/typescript/src/native/index.d.ts +0 -3
  72. package/lib/typescript/src/native/index.d.ts.map +1 -1
  73. package/lib/typescript/src/specs/Cactus.nitro.d.ts +6 -6
  74. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  75. package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
  76. package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
  77. package/lib/typescript/src/types/CactusLM.d.ts +19 -11
  78. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  79. package/lib/typescript/src/types/CactusSTT.d.ts +33 -12
  80. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  81. package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
  82. package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
  83. package/lib/typescript/src/types/common.d.ts +1 -6
  84. package/lib/typescript/src/types/common.d.ts.map +1 -1
  85. package/nitro.json +0 -11
  86. package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
  87. package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
  88. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
  89. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
  90. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
  91. package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
  92. package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
  93. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +4 -4
  94. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +6 -6
  95. package/package.json +3 -3
  96. package/src/classes/CactusLM.ts +18 -65
  97. package/src/classes/CactusSTT.ts +39 -97
  98. package/src/classes/CactusVAD.ts +129 -0
  99. package/src/hooks/useCactusLM.ts +14 -17
  100. package/src/hooks/useCactusSTT.ts +47 -98
  101. package/src/hooks/useCactusVAD.ts +215 -0
  102. package/src/index.tsx +18 -12
  103. package/src/modelRegistry.ts +65 -0
  104. package/src/native/Cactus.ts +102 -41
  105. package/src/native/CactusIndex.ts +2 -2
  106. package/src/native/index.ts +0 -3
  107. package/src/specs/Cactus.nitro.ts +11 -7
  108. package/src/types/CactusIndex.ts +2 -2
  109. package/src/types/CactusLM.ts +19 -11
  110. package/src/types/CactusSTT.ts +33 -13
  111. package/src/types/CactusVAD.ts +39 -0
  112. package/src/types/common.ts +1 -6
  113. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
  114. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
  115. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
  116. package/cpp/HybridCactusUtil.cpp +0 -47
  117. package/cpp/HybridCactusUtil.hpp +0 -27
  118. package/cpp/cactus_util.h +0 -25
  119. package/ios/HybridCactusCrypto.swift +0 -37
  120. package/ios/HybridCactusDeviceInfo.swift +0 -32
  121. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
  122. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
  123. package/ios/cactus_util.xcframework/Info.plist +0 -39
  124. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
  125. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
  126. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
  127. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
  128. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
  129. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
  130. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
  131. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
  132. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
  133. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
  134. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
  135. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
  136. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
  137. package/lib/module/api/Database.js +0 -45
  138. package/lib/module/api/Database.js.map +0 -1
  139. package/lib/module/api/RemoteLM.js +0 -201
  140. package/lib/module/api/RemoteLM.js.map +0 -1
  141. package/lib/module/config/CactusConfig.js +0 -12
  142. package/lib/module/config/CactusConfig.js.map +0 -1
  143. package/lib/module/models.js +0 -336
  144. package/lib/module/models.js.map +0 -1
  145. package/lib/module/native/CactusCrypto.js +0 -10
  146. package/lib/module/native/CactusCrypto.js.map +0 -1
  147. package/lib/module/native/CactusDeviceInfo.js +0 -13
  148. package/lib/module/native/CactusDeviceInfo.js.map +0 -1
  149. package/lib/module/native/CactusUtil.js +0 -36
  150. package/lib/module/native/CactusUtil.js.map +0 -1
  151. package/lib/module/specs/CactusCrypto.nitro.js +0 -4
  152. package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
  153. package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
  154. package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
  155. package/lib/module/specs/CactusUtil.nitro.js +0 -4
  156. package/lib/module/telemetry/Telemetry.js +0 -154
  157. package/lib/module/telemetry/Telemetry.js.map +0 -1
  158. package/lib/typescript/src/api/Database.d.ts +0 -12
  159. package/lib/typescript/src/api/Database.d.ts.map +0 -1
  160. package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
  161. package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
  162. package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
  163. package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
  164. package/lib/typescript/src/models.d.ts +0 -6
  165. package/lib/typescript/src/models.d.ts.map +0 -1
  166. package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
  167. package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
  168. package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
  169. package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
  170. package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
  171. package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
  172. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
  173. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
  174. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
  175. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
  176. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
  177. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
  178. package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
  179. package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
  180. package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
  181. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
  182. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
  183. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
  184. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
  185. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
  186. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
  187. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
  188. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
  189. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
  190. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
  191. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
  192. package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
  193. package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
  194. package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
  195. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
  196. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
  197. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
  198. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
  199. package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
  200. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
  201. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
  202. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
  203. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
  204. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
  205. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
  206. package/src/api/Database.ts +0 -55
  207. package/src/api/RemoteLM.ts +0 -273
  208. package/src/config/CactusConfig.ts +0 -11
  209. package/src/models.ts +0 -344
  210. package/src/native/CactusCrypto.ts +0 -11
  211. package/src/native/CactusDeviceInfo.ts +0 -18
  212. package/src/native/CactusUtil.ts +0 -43
  213. package/src/specs/CactusCrypto.nitro.ts +0 -6
  214. package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
  215. package/src/specs/CactusUtil.nitro.ts +0 -8
  216. package/src/telemetry/Telemetry.ts +0 -236
@@ -84,12 +84,17 @@ struct Config {
84
84
  bool use_thumbnail = true;
85
85
  uint32_t min_image_tokens = 64;
86
86
  uint32_t max_image_tokens = 256;
87
- uint32_t max_num_patches = 1024;
87
+ uint32_t max_num_patches = 1024;
88
88
  uint32_t tile_size = 512;
89
89
  float max_pixels_tolerance = 2.0f;
90
90
  bool do_image_splitting = true;
91
+ bool encoder_act_gelu = false;
92
+ bool decoder_act_gelu = false;
93
+ uint32_t num_encoder_layers = 0;
94
+ uint32_t num_decoder_layers = 0;
95
+ float partial_rotary_factor = 0.0f;
91
96
 
92
- enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7};
97
+ enum class ModelType {QWEN = 0, GEMMA = 1, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, MOONSHINE = 8, SILERO_VAD = 9};
93
98
  ModelType model_type = ModelType::QWEN;
94
99
 
95
100
  enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
@@ -107,6 +112,8 @@ struct Config {
107
112
  float default_temperature = 0.6f;
108
113
  float default_top_p = 0.95f;
109
114
  size_t default_top_k = 20;
115
+ float default_max_tps = -1.0f;
116
+ float default_cloud_handoff_threshold = 0.0f;
110
117
 
111
118
  std::vector<std::string> layer_types;
112
119
  size_t conv_L_cache = 0;
@@ -152,6 +159,7 @@ public:
152
159
  virtual uint32_t get_bos_token() const = 0;
153
160
  virtual uint32_t get_eos_token() const = 0;
154
161
  virtual bool has_chat_template() const { return has_chat_template_; }
162
+ std::string get_default_stop_sequence() const;
155
163
 
156
164
  virtual bool load_vocabulary_with_config(const std::string& vocab_file, const std::string& merges_file, const std::string& config_file) = 0;
157
165
 
@@ -159,11 +167,8 @@ public:
159
167
  uint32_t get_fake_token_id() const { return fake_token_id_; }
160
168
  uint32_t get_global_img_token_id() const { return global_img_token_id_; }
161
169
 
162
-
163
- void set_corpus_dir(const std::string& dir) { corpus_dir_ = dir; }
164
-
165
170
  protected:
166
- enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT, WHISPER};
171
+ enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, BERT, WHISPER};
167
172
  ModelType model_type_ = ModelType::UNKNOWN;
168
173
  enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
169
174
  ModelVariant model_variant_ = ModelVariant::DEFAULT;
@@ -173,14 +178,12 @@ protected:
173
178
  uint32_t image_token_id_ = 396;
174
179
  uint32_t fake_token_id_ = 49189;
175
180
  uint32_t global_img_token_id_ = 49152;
176
- std::string corpus_dir_;
177
181
 
178
182
  void detect_model_type(const std::string& config_path);
179
183
  std::string format_qwen_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
180
184
  std::string format_gemma_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
181
185
  std::string format_lfm2_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
182
186
  std::string format_lfm2_vl_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
183
- std::string format_smol_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
184
187
  };
185
188
 
186
189
  class BPETokenizer : public Tokenizer {
@@ -471,6 +474,8 @@ private:
471
474
  void compute_bias();
472
475
  void tokenize_grammar_elements();
473
476
  void add_tokens_for_string(const std::string& str, std::unordered_set<uint32_t>& token_set);
477
+ void tokenize_function_names(bool quote_names);
478
+ void init_common_tokens();
474
479
  };
475
480
 
476
481
  class Model {
@@ -495,22 +500,22 @@ public:
495
500
  const std::string& system_prompt = "", bool do_warmup = true);
496
501
 
497
502
  virtual uint32_t decode(const std::vector<uint32_t>& tokens, float temperature = -1.0f, float top_p = -1.0f,
498
- size_t top_k = 0, const std::string& profile_file = "");
503
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
499
504
 
500
505
  virtual void prefill(const std::vector<uint32_t>& tokens, size_t chunk_size = 256, const std::string& profile_file = "");
501
506
 
502
507
  virtual uint32_t decode_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
503
508
  float temperature = -1.0f, float top_p = -1.0f,
504
- size_t top_k = 0, const std::string& profile_file = "");
509
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
505
510
 
506
- virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& mel_bins, float temperature = 0.0f, float top_p = 0.0f,
507
- size_t top_k = 0, const std::string& profile_file = "");
511
+ virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& audio_features, float temperature = 0.0f, float top_p = 0.0f,
512
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
508
513
 
509
514
  std::vector<float> get_embeddings(const std::vector<uint32_t>& tokens, bool pooled = true, bool normalize = false, const std::string& profile_file = "");
510
515
 
511
516
  virtual std::vector<float> get_image_embeddings(const std::string& image_path);
512
517
 
513
- virtual std::vector<float> get_audio_embeddings(const std::vector<float>& mel_bins);
518
+ virtual std::vector<float> get_audio_embeddings(const std::vector<float>& audio_features);
514
519
 
515
520
  virtual void reset_cache() { kv_cache_.reset(); }
516
521
 
@@ -533,7 +538,7 @@ public:
533
538
  protected:
534
539
  virtual size_t forward(const std::vector<uint32_t>& tokens, bool use_cache = false) = 0;
535
540
 
536
- virtual size_t forward(const std::vector<float>& mel_bins, const std::vector<uint32_t>& tokens, bool use_cache = false);
541
+ virtual size_t forward(const std::vector<float>& audio_features, const std::vector<uint32_t>& tokens, bool use_cache = false);
537
542
 
538
543
  virtual void load_weights_to_graph(CactusGraph* gb) = 0;
539
544
 
@@ -645,6 +650,7 @@ public:
645
650
  private:
646
651
  Config config_;
647
652
 
653
+ std::pair<int64_t, int64_t> compute_pixel_limits() const;
648
654
  std::vector<unsigned char> convert_to_rgb(const unsigned char* img_data, int width, int height, int channels);
649
655
  std::pair<int, int> smart_resize(int height, int width);
650
656
  bool is_image_too_large(int height, int width);
@@ -701,5 +707,102 @@ private:
701
707
  size_t num_mel_filters_;
702
708
  };
703
709
 
710
+ namespace index {
711
+ constexpr uint32_t MAGIC = 0x43414354;
712
+ constexpr uint32_t VERSION = 1;
713
+
714
+ struct Document {
715
+ int id;
716
+ std::vector<float> embedding;
717
+ std::string content;
718
+ std::string metadata;
719
+ };
720
+
721
+ struct QueryResult {
722
+ int doc_id;
723
+ float score;
724
+ };
725
+
726
+ struct QueryOptions {
727
+ size_t top_k = 10;
728
+ float score_threshold = -1.0f;
729
+ };
730
+
731
+ class Index {
732
+ public:
733
+ Index(const std::string& index_path, const std::string& data_path, size_t embedding_dim);
734
+ ~Index();
735
+
736
+ Index(const Index&) = delete;
737
+ Index& operator=(const Index&) = delete;
738
+ Index(Index&&) = delete;
739
+ Index& operator=(Index&&) = delete;
740
+
741
+ void add_documents(const std::vector<Document>& documents);
742
+ void delete_documents(const std::vector<int>& doc_ids);
743
+ std::vector<Document> get_documents(const std::vector<int>& doc_ids);
744
+ std::vector<std::vector<QueryResult>> query(const std::vector<std::vector<float>>& embeddings, const QueryOptions& options);
745
+ void compact();
746
+
747
+ private:
748
+ struct IndexHeader {
749
+ uint32_t magic;
750
+ uint32_t version;
751
+ uint32_t embedding_dim;
752
+ uint32_t num_documents;
753
+ };
754
+
755
+ struct IndexEntry {
756
+ int32_t doc_id;
757
+ uint64_t data_offset;
758
+ uint8_t flags; // bit 0: tombstone
759
+
760
+ const __fp16* embedding() const {
761
+ return reinterpret_cast<const __fp16*>(this + 1);
762
+ }
763
+
764
+ static size_t size(size_t embedding_dim) {
765
+ return sizeof(IndexEntry) + embedding_dim * sizeof(__fp16);
766
+ }
767
+ };
768
+
769
+ struct DataHeader {
770
+ uint32_t magic;
771
+ uint32_t version;
772
+ };
773
+
774
+ struct DataEntry {
775
+ uint16_t content_len;
776
+ uint16_t metadata_len;
777
+
778
+ const char* content() const {
779
+ return reinterpret_cast<const char*>(this + 1);
780
+ }
781
+
782
+ const char* metadata() const {
783
+ return content() + content_len;
784
+ }
785
+ };
786
+
787
+ void parse_index_header();
788
+ void parse_data_header();
789
+ void build_doc_id_map();
790
+ void validate_documents(const std::vector<Document>& documents);
791
+ void validate_doc_ids(const std::vector<int>& doc_ids);
792
+ ssize_t write_full(int fd, const void* buf, size_t count);
793
+
794
+ std::unordered_map<int, uint32_t> doc_id_map_;
795
+
796
+ std::string index_path_, data_path_;
797
+ size_t embedding_dim_;
798
+ size_t index_entry_size_;
799
+ uint32_t num_documents_;
800
+
801
+ int index_fd_, data_fd_;
802
+ void *mapped_index_, *mapped_data_;
803
+ size_t index_file_size_, data_file_size_;
804
+ };
805
+ } // namespace index
806
+
807
+ }
704
808
  }
705
- }
@@ -4,6 +4,7 @@
4
4
  #include <vector>
5
5
  #include <memory>
6
6
  #include <unordered_map>
7
+ #include <unordered_set>
7
8
  #include <functional>
8
9
  #include <cstring>
9
10
  #include <stdexcept>
@@ -114,17 +115,20 @@ enum class OpType {
114
115
  MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
115
116
  BILINEAR_INTERPOLATION,
116
117
  SUM, MEAN, VARIANCE, MIN, MAX,
117
- RMS_NORM, ROPE, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3,
118
+ RMS_NORM, ROPE, ROPE_GPTJ, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3, CONV1D_K7S3, CONV1D,
118
119
  SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
119
- SILU, GELU, GELU_ERF,
120
+ RELU, SILU, GELU, GELU_ERF, SIGMOID, TANH,
120
121
  SAMPLE, CONCAT,
121
122
  SCATTER_TOPK,
122
- TOPK, LAYERNORM,
123
+ TOPK, LAYERNORM, GROUPNORM,
123
124
  INDEX,
125
+ PERSISTENT,
126
+ QUANTIZE_ACTIVATIONS,
127
+ LSTM_CELL,
128
+ STFT_MAGNITUDE
124
129
  };
125
130
 
126
131
  struct PrecisionTraits {
127
- // Returns in-memory element size (INT4 unpacks to INT8, so returns 1)
128
132
  static constexpr size_t size_of(Precision prec) {
129
133
  switch (prec) {
130
134
  case Precision::INT8: return 1;
@@ -205,8 +209,12 @@ struct BufferDesc {
205
209
  void* scales_data = nullptr;
206
210
  std::unique_ptr<char[]> owned_scales;
207
211
 
208
- const void* packed_int4_data = nullptr;
209
- size_t packed_int4_size = 0;
212
+ bool is_interleaved = false;
213
+ size_t original_N = 0;
214
+
215
+ void* activation_scales_data = nullptr;
216
+ std::unique_ptr<char[]> owned_activation_scales;
217
+ size_t num_rows_for_activation_scales = 0;
210
218
 
211
219
  BufferDesc();
212
220
  BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8);
@@ -230,23 +238,39 @@ struct BufferDesc {
230
238
  const __fp16* scales_as_fp16() const {
231
239
  return reinterpret_cast<const __fp16*>(scales_data);
232
240
  }
241
+
233
242
  bool is_grouped_int8() const {
234
243
  return precision == Precision::INT8 && group_size > 0;
235
244
  }
236
- bool is_packed_int4() const {
237
- return packed_int4_data != nullptr && packed_int4_size > 0;
238
- }
239
- const uint8_t* packed_int4_as_uint8() const {
240
- return reinterpret_cast<const uint8_t*>(packed_int4_data);
241
- }
245
+
242
246
  void set_grouped_scales(size_t gs, size_t ng, void* scales_ptr) {
243
247
  group_size = gs;
244
248
  num_groups = ng;
245
249
  scales_data = scales_ptr;
246
250
  }
247
- void set_packed_int4(const void* packed_data, size_t packed_size) {
248
- packed_int4_data = packed_data;
249
- packed_int4_size = packed_size;
251
+
252
+ void set_interleaved(bool interleaved, size_t orig_n) {
253
+ is_interleaved = interleaved;
254
+ original_N = orig_n;
255
+ }
256
+
257
+ bool has_activation_scales() const {
258
+ return activation_scales_data != nullptr && num_rows_for_activation_scales > 0;
259
+ }
260
+ const float* activation_scales_as_float() const {
261
+ return reinterpret_cast<const float*>(activation_scales_data);
262
+ }
263
+ float* activation_scales_as_float() {
264
+ return reinterpret_cast<float*>(activation_scales_data);
265
+ }
266
+ void allocate_activation_scales(size_t num_rows) {
267
+ num_rows_for_activation_scales = num_rows;
268
+ owned_activation_scales = std::make_unique<char[]>(num_rows * sizeof(float));
269
+ activation_scales_data = owned_activation_scales.get();
270
+ }
271
+ void set_activation_scales(void* scales_ptr, size_t num_rows) {
272
+ activation_scales_data = scales_ptr;
273
+ num_rows_for_activation_scales = num_rows;
250
274
  }
251
275
 
252
276
  void allocate();
@@ -282,6 +306,7 @@ struct OpParams {
282
306
 
283
307
  size_t index_value = 0;
284
308
  size_t num_classes = 0;
309
+ size_t num_groups = 0;
285
310
  size_t dst_height = 0;
286
311
  size_t dst_width = 0;
287
312
 
@@ -295,6 +320,7 @@ struct OpParams {
295
320
  size_t cache_seq_len = 0;
296
321
  size_t num_kv_heads = 0;
297
322
  size_t head_dim = 0;
323
+ size_t num_fft_bins = 0;
298
324
  };
299
325
 
300
326
  struct GraphNode {
@@ -324,7 +350,10 @@ void compute_sample_node(GraphNode& node, const std::vector<std::unique_ptr<Grap
324
350
  void compute_scatter_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
325
351
  void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
326
352
  void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
353
+ void compute_groupnorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
354
+ void compute_persistent_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
327
355
  void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
356
+ void compute_lstm_cell_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
328
357
 
329
358
  void shrink_thread_local_buffers();
330
359
 
@@ -372,6 +401,7 @@ public:
372
401
 
373
402
  size_t input(const std::vector<size_t>& shape, Precision precision = Precision::INT8);
374
403
  size_t precision_cast(size_t input, Precision target_precision);
404
+ size_t quantize_activations(size_t input);
375
405
 
376
406
  size_t add(size_t input1, size_t input2);
377
407
  size_t add_clipped(size_t input1, size_t input2);
@@ -389,9 +419,12 @@ public:
389
419
  size_t scalar_cos(size_t input);
390
420
  size_t scalar_sin(size_t input);
391
421
 
422
+ size_t relu(size_t input);
392
423
  size_t silu(size_t input);
393
424
  size_t gelu(size_t input);
394
425
  size_t gelu_erf(size_t input);
426
+ size_t sigmoid(size_t input);
427
+ size_t tanh(size_t input);
395
428
 
396
429
  size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
397
430
  size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
@@ -409,8 +442,8 @@ public:
409
442
  size_t gather(size_t embeddings, size_t indices);
410
443
  size_t mmap_embeddings(const std::string& filename);
411
444
  size_t mmap_weights(const std::string& filename);
412
- size_t load_weights(const std::string& filename);
413
445
  void set_grouped_scales(size_t node_id, size_t group_size, size_t num_groups, void* scales_ptr);
446
+ void set_interleaved(size_t node_id, bool interleaved, size_t original_N);
414
447
 
415
448
  void release_weight_pages(size_t node_id);
416
449
  void prefetch_weight_pages(size_t node_id);
@@ -420,9 +453,12 @@ public:
420
453
  size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
421
454
 
422
455
  size_t layernorm(size_t input, size_t weight, size_t bias, float epsilon = 1e-5f);
456
+ size_t layernorm(size_t input, size_t weight, float epsilon = 1e-5f); // No bias version
457
+ size_t groupnorm(size_t input, size_t weight, size_t bias, size_t num_groups = 32, float epsilon = 1e-5f);
423
458
  size_t topk(size_t input, size_t k);
424
459
  size_t rms_norm(size_t input, size_t weight, float epsilon = 1e-5f);
425
460
  size_t rope(size_t input, float theta, size_t position_offset = 0, ComputeBackend backend = ComputeBackend::CPU);
461
+ size_t rope_gptj(size_t input, float theta, size_t position_offset = 0, size_t rot_dim = 0, ComputeBackend backend = ComputeBackend::CPU);
426
462
  size_t softmax(size_t input, int axis = -1);
427
463
  size_t attention(size_t query, size_t key, size_t value, float scale, bool is_causal = true, ComputeBackend backend = ComputeBackend::CPU);
428
464
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, ComputeBackend backend = ComputeBackend::CPU);
@@ -431,11 +467,17 @@ public:
431
467
  size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
432
468
  const int8_t* cached_keys, const int8_t* cached_values,
433
469
  const float* k_scales, const float* v_scales,
434
- size_t cache_len, size_t num_kv_heads, size_t head_dim);
470
+ size_t cache_len, size_t num_kv_heads, size_t head_dim, size_t window_size = 0);
435
471
 
436
472
  size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
437
473
  size_t conv1d_k3(size_t input, size_t weight, size_t stride);
438
-
474
+ size_t conv1d_k7s3(size_t input, size_t weight, size_t bias);
475
+ size_t conv1d(size_t input, size_t weight, size_t stride);
476
+ size_t conv1d(size_t input, size_t weight, size_t bias, size_t stride);
477
+
478
+ size_t lstm_cell(size_t input, size_t h_prev, size_t c_prev, size_t weight_ih, size_t weight_hh, size_t bias_ih, size_t bias_hh);
479
+ size_t stft_magnitude(size_t input, size_t weight, size_t stride, size_t num_fft_bins);
480
+
439
481
  size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20,
440
482
  const std::unordered_map<uint32_t, float>& logit_bias = {});
441
483
 
@@ -462,6 +504,10 @@ public:
462
504
  void allocate_buffers();
463
505
  size_t get_node_count() const;
464
506
 
507
+ size_t persistent(size_t source_node);
508
+ bool is_populated(size_t persistent_node_id) const;
509
+ void invalidate_persistent(size_t persistent_node_id);
510
+
465
511
  std::vector<std::unique_ptr<GraphNode>> nodes_;
466
512
  std::unordered_map<size_t, size_t> node_index_map_;
467
513
 
@@ -473,6 +519,9 @@ private:
473
519
  std::vector<DebugNodeEntry> debug_nodes_;
474
520
  BufferPool buffer_pool_;
475
521
  bool prefill_mode_ = false;
522
+
523
+ std::unordered_set<size_t> persistent_node_ids_;
524
+ std::unordered_set<size_t> populated_node_ids_;
476
525
  };
477
526
 
478
527
 
@@ -485,7 +534,6 @@ namespace GraphFile {
485
534
  };
486
535
 
487
536
  void save_node(CactusGraph& graph, size_t node_id, const std::string& filename);
488
- LoadedNode load_into_graph(CactusGraph& graph, const std::string& filename);
489
537
 
490
538
  class MappedFile {
491
539
  public:
@@ -499,16 +547,14 @@ namespace GraphFile {
499
547
 
500
548
  const std::vector<size_t>& shape() const;
501
549
  Precision precision() const;
502
- Precision effective_precision() const {
503
- return is_int4_ ? Precision::INT8 : precision_;
504
- }
505
550
  size_t byte_size() const;
506
551
 
507
552
  size_t group_size() const { return group_size_; }
508
553
  size_t num_groups() const { return num_groups_; }
509
554
  const void* scales_data() const;
510
- const void* raw_packed_data() const; // Get raw mmap'd data without unpacking (for INT4)
511
- bool is_int4() const { return is_int4_; }
555
+
556
+ bool is_interleaved() const { return is_interleaved_; }
557
+ size_t original_N() const { return original_N_; }
512
558
 
513
559
  void* data();
514
560
  const void* data() const;
@@ -516,8 +562,6 @@ namespace GraphFile {
516
562
  template<typename T>
517
563
  const T* typed_data() const;
518
564
 
519
- LoadedNode load_into_graph(CactusGraph& graph) const;
520
-
521
565
  void release_pages();
522
566
  void prefetch_pages();
523
567
 
@@ -532,16 +576,17 @@ namespace GraphFile {
532
576
  size_t num_groups_ = 0;
533
577
  size_t scales_offset_ = 0;
534
578
  size_t scales_bytes_ = 0;
535
- uint32_t version_ = 1;
536
579
  uint32_t alignment_ = 32;
537
- bool is_int4_ = false;
538
- mutable std::unique_ptr<int8_t[]> unpacked_int4_data_;
580
+
581
+ bool is_interleaved_ = false;
582
+ size_t original_N_ = 0;
583
+
584
+ std::unique_ptr<int8_t[]> unpacked_data_;
585
+
539
586
  void parse_header();
540
587
  void apply_madvise_hints();
541
- void unpack_int4_if_needed() const;
588
+ void unpack_int4_data();
542
589
  };
543
-
544
- MappedFile mmap_load(const std::string& filename);
545
590
  }
546
591
 
547
592
  #endif
@@ -15,7 +15,7 @@ enum class ScalarOpType {
15
15
  SIN
16
16
  };
17
17
 
18
- constexpr size_t KV_QUANT_GROUP_SIZE = 128;
18
+ constexpr size_t KV_QUANT_GROUP_SIZE = 32;
19
19
 
20
20
  void cactus_add_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
21
21
  void cactus_add_f16_clipped(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
@@ -38,14 +38,18 @@ void cactus_divide_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* outpu
38
38
 
39
39
  void cactus_scalar_op_f16(const __fp16* input, __fp16* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
40
40
 
41
+ void cactus_gemv_int8(const int8_t* A, float A_scale,
42
+ const int8_t* B, const __fp16* B_scales,
43
+ __fp16* C, size_t K, size_t N, size_t group_size);
44
+
45
+ void cactus_gemm_int8(const int8_t* A, const float* A_scales,
46
+ const int8_t* B, const __fp16* B_scales,
47
+ __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
48
+
41
49
  void cactus_matmul_int8(const int8_t* A, const float* A_scales,
42
50
  const int8_t* B, const __fp16* B_scales,
43
51
  __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
44
52
 
45
- void cactus_matmul_int4(const int8_t* A, const float* A_scales,
46
- const uint8_t* B_packed, const __fp16* B_scales,
47
- __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
48
-
49
53
  void cactus_matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c,
50
54
  size_t M, size_t K, size_t N);
51
55
 
@@ -75,15 +79,24 @@ void cactus_rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* outp
75
79
  void cactus_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
76
80
  size_t num_heads, size_t head_dim, size_t start_pos, float theta);
77
81
 
82
+ void cactus_gpt_j_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
83
+ size_t num_heads, size_t head_dim, size_t rot_dim, size_t start_pos, float theta);
84
+
78
85
  void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
79
86
  size_t seq_len, size_t vocab_size);
80
87
 
88
+ void cactus_relu_f16(const __fp16* input, __fp16* output, size_t num_elements);
89
+
81
90
  void cactus_silu_f16(const __fp16* input, __fp16* output, size_t num_elements);
82
91
 
83
92
  void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
84
93
 
85
94
  void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
86
95
 
96
+ void cactus_sigmoid_f16(const __fp16* input, __fp16* output, size_t num_elements);
97
+
98
+ void cactus_tanh_f16(const __fp16* input, __fp16* output, size_t num_elements);
99
+
87
100
  void cactus_attention_f16(const __fp16* queries, const __fp16* keys, const __fp16* values, __fp16* output,
88
101
  size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
89
102
  size_t head_dim, float scale, const __fp16* mask, size_t position_offset = 0, size_t window_size = 0,
@@ -100,7 +113,7 @@ void cactus_attention_hybrid_int8_fp16(
100
113
  __fp16* output,
101
114
  size_t batch_size, size_t seq_len, size_t cache_len, size_t new_len,
102
115
  size_t num_q_heads, size_t num_kv_heads, size_t head_dim,
103
- float scale, size_t position_offset = 0, bool is_causal = true,
116
+ float scale, size_t position_offset = 0, bool is_causal = true, size_t window_size = 0,
104
117
  size_t group_size = KV_QUANT_GROUP_SIZE);
105
118
 
106
119
  void cactus_conv1d_causal_depthwise_f16(
@@ -124,6 +137,40 @@ void cactus_conv1d_f16_k3(
124
137
  size_t stride
125
138
  );
126
139
 
140
+ void cactus_conv1d_f16(
141
+ const __fp16* input,
142
+ const __fp16* weight,
143
+ const __fp16* bias,
144
+ __fp16* output,
145
+ size_t N,
146
+ size_t L,
147
+ size_t C_in,
148
+ size_t C_out,
149
+ size_t K,
150
+ size_t stride
151
+ );
152
+
153
+ void cactus_stft_magnitude_f16(
154
+ const __fp16* input,
155
+ const __fp16* weight,
156
+ __fp16* output,
157
+ size_t N, size_t L,
158
+ size_t C_in, size_t C_out,
159
+ size_t K, size_t stride,
160
+ size_t num_fft_bins
161
+ );
162
+
163
+ void cactus_conv1d_f16_k7s3_oc8(
164
+ const __fp16* input,
165
+ const __fp16* Wpack,
166
+ const __fp16* bias,
167
+ __fp16* output,
168
+ size_t N,
169
+ size_t L,
170
+ size_t C_in,
171
+ size_t C_out
172
+ );
173
+
127
174
  void cactus_bilinear_interpolation_f16(const __fp16* input, __fp16* output, size_t src_height, size_t src_width, size_t embed_dim,
128
175
  size_t dst_height, size_t dst_width);
129
176
 
@@ -162,4 +209,19 @@ inline size_t kv_scales_count(size_t seq_len, size_t kv_heads, size_t head_dim,
162
209
 
163
210
  void cactus_unpack_int4_to_int8(const uint8_t* packed, int8_t* unpacked, size_t unpacked_count);
164
211
 
212
+ void cactus_lstm_cell_f16(
213
+ const __fp16* x_input,
214
+ const __fp16* h_prev,
215
+ const __fp16* c_prev,
216
+ const __fp16* weight_ih,
217
+ const __fp16* weight_hh,
218
+ const __fp16* bias_ih,
219
+ const __fp16* bias_hh,
220
+ __fp16* h_new,
221
+ __fp16* c_new,
222
+ size_t batch_size,
223
+ size_t input_size,
224
+ size_t hidden_size
225
+ );
226
+
165
227
  #endif