cactus-react-native 1.5.0 → 1.10.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.
Files changed (221)
  1. package/Cactus.podspec +1 -1
  2. package/README.md +347 -241
  3. package/android/CMakeLists.txt +24 -5
  4. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
  9. package/cpp/HybridCactus.cpp +197 -117
  10. package/cpp/HybridCactus.hpp +18 -9
  11. package/cpp/cactus_ffi.h +66 -42
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_cloud.h +48 -0
  14. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +66 -42
  15. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +568 -135
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +148 -17
  17. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +145 -36
  18. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +187 -6
  19. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +49 -149
  20. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Info.plist +0 -0
  21. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  22. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
  23. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_cloud.h +48 -0
  24. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +66 -42
  25. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +568 -135
  26. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +148 -17
  27. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +145 -36
  28. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +187 -6
  29. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +49 -149
  30. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Info.plist +0 -0
  31. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
  32. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  33. package/lib/module/classes/CactusLM.js +16 -49
  34. package/lib/module/classes/CactusLM.js.map +1 -1
  35. package/lib/module/classes/CactusSTT.js +41 -75
  36. package/lib/module/classes/CactusSTT.js.map +1 -1
  37. package/lib/module/classes/CactusVAD.js +95 -0
  38. package/lib/module/classes/CactusVAD.js.map +1 -0
  39. package/lib/module/hooks/useCactusLM.js +10 -11
  40. package/lib/module/hooks/useCactusLM.js.map +1 -1
  41. package/lib/module/hooks/useCactusSTT.js +23 -62
  42. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  43. package/lib/module/hooks/useCactusVAD.js +171 -0
  44. package/lib/module/hooks/useCactusVAD.js.map +1 -0
  45. package/lib/module/index.js +2 -3
  46. package/lib/module/index.js.map +1 -1
  47. package/lib/module/modelRegistry.js +52 -0
  48. package/lib/module/modelRegistry.js.map +1 -0
  49. package/lib/module/native/Cactus.js +103 -23
  50. package/lib/module/native/Cactus.js.map +1 -1
  51. package/lib/module/native/CactusIndex.js.map +1 -1
  52. package/lib/module/native/index.js +0 -3
  53. package/lib/module/native/index.js.map +1 -1
  54. package/lib/module/types/CactusVAD.js +4 -0
  55. package/lib/module/{specs/CactusUtil.nitro.js.map → types/CactusVAD.js.map} +1 -1
  56. package/lib/typescript/src/classes/CactusLM.d.ts +5 -7
  57. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  58. package/lib/typescript/src/classes/CactusSTT.d.ts +9 -12
  59. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  60. package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
  61. package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
  62. package/lib/typescript/src/hooks/useCactusLM.d.ts +2 -2
  63. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  64. package/lib/typescript/src/hooks/useCactusSTT.d.ts +6 -8
  65. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  66. package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
  67. package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
  68. package/lib/typescript/src/index.d.ts +7 -5
  69. package/lib/typescript/src/index.d.ts.map +1 -1
  70. package/lib/typescript/src/modelRegistry.d.ts +5 -0
  71. package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
  72. package/lib/typescript/src/native/Cactus.d.ts +13 -11
  73. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  74. package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
  75. package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
  76. package/lib/typescript/src/native/index.d.ts +0 -3
  77. package/lib/typescript/src/native/index.d.ts.map +1 -1
  78. package/lib/typescript/src/specs/Cactus.nitro.d.ts +7 -6
  79. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  80. package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
  81. package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
  82. package/lib/typescript/src/types/CactusLM.d.ts +19 -11
  83. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  84. package/lib/typescript/src/types/CactusSTT.d.ts +44 -12
  85. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  86. package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
  87. package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
  88. package/lib/typescript/src/types/common.d.ts +1 -6
  89. package/lib/typescript/src/types/common.d.ts.map +1 -1
  90. package/nitro.json +0 -11
  91. package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
  92. package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
  93. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
  94. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
  95. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
  96. package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
  97. package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
  98. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -4
  99. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +7 -6
  100. package/package.json +3 -3
  101. package/src/classes/CactusLM.ts +18 -65
  102. package/src/classes/CactusSTT.ts +52 -90
  103. package/src/classes/CactusVAD.ts +129 -0
  104. package/src/hooks/useCactusLM.ts +14 -17
  105. package/src/hooks/useCactusSTT.ts +47 -98
  106. package/src/hooks/useCactusVAD.ts +215 -0
  107. package/src/index.tsx +21 -12
  108. package/src/modelRegistry.ts +65 -0
  109. package/src/native/Cactus.ts +131 -38
  110. package/src/native/CactusIndex.ts +2 -2
  111. package/src/native/index.ts +0 -3
  112. package/src/specs/Cactus.nitro.ts +16 -7
  113. package/src/types/CactusIndex.ts +2 -2
  114. package/src/types/CactusLM.ts +19 -11
  115. package/src/types/CactusSTT.ts +47 -13
  116. package/src/types/CactusVAD.ts +39 -0
  117. package/src/types/common.ts +1 -6
  118. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
  119. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
  120. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
  121. package/cpp/HybridCactusUtil.cpp +0 -47
  122. package/cpp/HybridCactusUtil.hpp +0 -27
  123. package/cpp/cactus_util.h +0 -25
  124. package/ios/HybridCactusCrypto.swift +0 -37
  125. package/ios/HybridCactusDeviceInfo.swift +0 -32
  126. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
  127. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
  128. package/ios/cactus_util.xcframework/Info.plist +0 -39
  129. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
  130. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
  131. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
  132. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
  133. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
  134. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
  135. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
  136. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
  137. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
  138. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
  139. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
  140. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
  141. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
  142. package/lib/module/api/Database.js +0 -45
  143. package/lib/module/api/Database.js.map +0 -1
  144. package/lib/module/api/RemoteLM.js +0 -201
  145. package/lib/module/api/RemoteLM.js.map +0 -1
  146. package/lib/module/config/CactusConfig.js +0 -12
  147. package/lib/module/config/CactusConfig.js.map +0 -1
  148. package/lib/module/models.js +0 -336
  149. package/lib/module/models.js.map +0 -1
  150. package/lib/module/native/CactusCrypto.js +0 -10
  151. package/lib/module/native/CactusCrypto.js.map +0 -1
  152. package/lib/module/native/CactusDeviceInfo.js +0 -13
  153. package/lib/module/native/CactusDeviceInfo.js.map +0 -1
  154. package/lib/module/native/CactusUtil.js +0 -36
  155. package/lib/module/native/CactusUtil.js.map +0 -1
  156. package/lib/module/specs/CactusCrypto.nitro.js +0 -4
  157. package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
  158. package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
  159. package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
  160. package/lib/module/specs/CactusUtil.nitro.js +0 -4
  161. package/lib/module/telemetry/Telemetry.js +0 -154
  162. package/lib/module/telemetry/Telemetry.js.map +0 -1
  163. package/lib/typescript/src/api/Database.d.ts +0 -12
  164. package/lib/typescript/src/api/Database.d.ts.map +0 -1
  165. package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
  166. package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
  167. package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
  168. package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
  169. package/lib/typescript/src/models.d.ts +0 -6
  170. package/lib/typescript/src/models.d.ts.map +0 -1
  171. package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
  172. package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
  173. package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
  174. package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
  175. package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
  176. package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
  177. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
  178. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
  179. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
  180. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
  181. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
  182. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
  183. package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
  184. package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
  185. package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
  186. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
  187. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
  188. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
  189. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
  190. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
  191. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
  192. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
  193. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
  194. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
  195. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
  196. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
  197. package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
  198. package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
  199. package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
  200. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
  201. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
  202. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
  203. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
  204. package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
  205. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
  206. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
  207. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
  208. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
  209. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
  210. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
  211. package/src/api/Database.ts +0 -55
  212. package/src/api/RemoteLM.ts +0 -273
  213. package/src/config/CactusConfig.ts +0 -11
  214. package/src/models.ts +0 -344
  215. package/src/native/CactusCrypto.ts +0 -11
  216. package/src/native/CactusDeviceInfo.ts +0 -18
  217. package/src/native/CactusUtil.ts +0 -43
  218. package/src/specs/CactusCrypto.nitro.ts +0 -6
  219. package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
  220. package/src/specs/CactusUtil.nitro.ts +0 -8
  221. package/src/telemetry/Telemetry.ts +0 -236
@@ -56,6 +56,12 @@ struct Config {
56
56
  uint32_t num_shared_experts = 0;
57
57
  uint32_t num_top_experts = 0;
58
58
  uint32_t moe_every_n_layers = 0;
59
+ uint32_t moe_intermediate_dim = 0;
60
+ uint32_t num_dense_layers = 0;
61
+ uint32_t num_experts_per_tok = 0;
62
+ bool norm_topk_prob = false;
63
+ bool use_expert_bias = false;
64
+ float routed_scaling_factor = 1.0f;
59
65
  bool tie_word_embeddings = true;
60
66
 
61
67
  uint32_t vision_hidden_dim = 0;
@@ -84,12 +90,31 @@ struct Config {
84
90
  bool use_thumbnail = true;
85
91
  uint32_t min_image_tokens = 64;
86
92
  uint32_t max_image_tokens = 256;
87
- uint32_t max_num_patches = 1024;
93
+ uint32_t max_num_patches = 1024;
88
94
  uint32_t tile_size = 512;
89
95
  float max_pixels_tolerance = 2.0f;
90
96
  bool do_image_splitting = true;
91
-
92
- enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7};
97
+ bool encoder_act_gelu = false;
98
+ bool decoder_act_gelu = false;
99
+ uint32_t num_encoder_layers = 0;
100
+ uint32_t num_decoder_layers = 0;
101
+ float partial_rotary_factor = 0.0f;
102
+ uint32_t pad_token_id = 0;
103
+ uint32_t conv_kernel_size = 0;
104
+ uint32_t subsampling_conv_kernel_size = 0;
105
+ uint32_t subsampling_conv_stride = 0;
106
+ uint32_t subsampling_conv_channels = 0;
107
+ uint32_t subsampling_factor = 0;
108
+ uint32_t num_mel_bins = 80;
109
+ std::string encoder_hidden_act = "silu";
110
+ uint32_t predictor_hidden_dim = 0;
111
+ uint32_t predictor_num_layers = 0;
112
+ uint32_t tdt_joint_dim = 0;
113
+ uint32_t tdt_num_durations = 0;
114
+ uint32_t tdt_blank_id = 0;
115
+ std::vector<uint32_t> tdt_durations;
116
+
117
+ enum class ModelType {QWEN = 0, GEMMA = 1, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, MOONSHINE = 8, SILERO_VAD = 9, PARAKEET = 10, PARAKEET_TDT = 11};
93
118
  ModelType model_type = ModelType::QWEN;
94
119
 
95
120
  enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
@@ -107,6 +132,8 @@ struct Config {
107
132
  float default_temperature = 0.6f;
108
133
  float default_top_p = 0.95f;
109
134
  size_t default_top_k = 20;
135
+ float default_max_tps = -1.0f;
136
+ float default_cloud_handoff_threshold = 0.0f;
110
137
 
111
138
  std::vector<std::string> layer_types;
112
139
  size_t conv_L_cache = 0;
@@ -152,6 +179,7 @@ public:
152
179
  virtual uint32_t get_bos_token() const = 0;
153
180
  virtual uint32_t get_eos_token() const = 0;
154
181
  virtual bool has_chat_template() const { return has_chat_template_; }
182
+ std::string get_default_stop_sequence() const;
155
183
 
156
184
  virtual bool load_vocabulary_with_config(const std::string& vocab_file, const std::string& merges_file, const std::string& config_file) = 0;
157
185
 
@@ -159,11 +187,8 @@ public:
159
187
  uint32_t get_fake_token_id() const { return fake_token_id_; }
160
188
  uint32_t get_global_img_token_id() const { return global_img_token_id_; }
161
189
 
162
-
163
- void set_corpus_dir(const std::string& dir) { corpus_dir_ = dir; }
164
-
165
190
  protected:
166
- enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT, WHISPER};
191
+ enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, BERT, WHISPER, PARAKEET};
167
192
  ModelType model_type_ = ModelType::UNKNOWN;
168
193
  enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
169
194
  ModelVariant model_variant_ = ModelVariant::DEFAULT;
@@ -173,14 +198,12 @@ protected:
173
198
  uint32_t image_token_id_ = 396;
174
199
  uint32_t fake_token_id_ = 49189;
175
200
  uint32_t global_img_token_id_ = 49152;
176
- std::string corpus_dir_;
177
201
 
178
202
  void detect_model_type(const std::string& config_path);
179
203
  std::string format_qwen_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
180
204
  std::string format_gemma_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
181
205
  std::string format_lfm2_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
182
206
  std::string format_lfm2_vl_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
183
- std::string format_smol_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
184
207
  };
185
208
 
186
209
  class BPETokenizer : public Tokenizer {
@@ -363,7 +386,6 @@ struct KVCache {
363
386
  size_t num_tokens, size_t kv_heads, size_t head_dim);
364
387
 
365
388
  bool is_empty() const { return current_seq_len == 0; }
366
- bool is_int8() const { return precision == Precision::INT8; }
367
389
  void* get_key_ptr(size_t layer);
368
390
  void* get_value_ptr(size_t layer);
369
391
 
@@ -471,6 +493,8 @@ private:
471
493
  void compute_bias();
472
494
  void tokenize_grammar_elements();
473
495
  void add_tokens_for_string(const std::string& str, std::unordered_set<uint32_t>& token_set);
496
+ void tokenize_function_names(bool quote_names);
497
+ void init_common_tokens();
474
498
  };
475
499
 
476
500
  class Model {
@@ -495,22 +519,22 @@ public:
495
519
  const std::string& system_prompt = "", bool do_warmup = true);
496
520
 
497
521
  virtual uint32_t decode(const std::vector<uint32_t>& tokens, float temperature = -1.0f, float top_p = -1.0f,
498
- size_t top_k = 0, const std::string& profile_file = "");
522
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
499
523
 
500
524
  virtual void prefill(const std::vector<uint32_t>& tokens, size_t chunk_size = 256, const std::string& profile_file = "");
501
525
 
502
526
  virtual uint32_t decode_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
503
527
  float temperature = -1.0f, float top_p = -1.0f,
504
- size_t top_k = 0, const std::string& profile_file = "");
528
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
505
529
 
506
- virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& mel_bins, float temperature = 0.0f, float top_p = 0.0f,
507
- size_t top_k = 0, const std::string& profile_file = "");
530
+ virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& audio_features, float temperature = 0.0f, float top_p = 0.0f,
531
+ size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
508
532
 
509
533
  std::vector<float> get_embeddings(const std::vector<uint32_t>& tokens, bool pooled = true, bool normalize = false, const std::string& profile_file = "");
510
534
 
511
535
  virtual std::vector<float> get_image_embeddings(const std::string& image_path);
512
536
 
513
- virtual std::vector<float> get_audio_embeddings(const std::vector<float>& mel_bins);
537
+ virtual std::vector<float> get_audio_embeddings(const std::vector<float>& audio_features);
514
538
 
515
539
  virtual void reset_cache() { kv_cache_.reset(); }
516
540
 
@@ -533,7 +557,7 @@ public:
533
557
  protected:
534
558
  virtual size_t forward(const std::vector<uint32_t>& tokens, bool use_cache = false) = 0;
535
559
 
536
- virtual size_t forward(const std::vector<float>& mel_bins, const std::vector<uint32_t>& tokens, bool use_cache = false);
560
+ virtual size_t forward(const std::vector<float>& audio_features, const std::vector<uint32_t>& tokens, bool use_cache = false);
537
561
 
538
562
  virtual void load_weights_to_graph(CactusGraph* gb) = 0;
539
563
 
@@ -645,6 +669,7 @@ public:
645
669
  private:
646
670
  Config config_;
647
671
 
672
+ std::pair<int64_t, int64_t> compute_pixel_limits() const;
648
673
  std::vector<unsigned char> convert_to_rgb(const unsigned char* img_data, int width, int height, int channels);
649
674
  std::pair<int, int> smart_resize(int height, int width);
650
675
  bool is_image_too_large(int height, int width);
@@ -678,6 +703,8 @@ public:
678
703
  float reference = 1.0f;
679
704
  float min_value = 1e-10f;
680
705
  bool remove_dc_offset = false;
706
+ float preemphasis = 0.0f;
707
+ bool hann_periodic = true;
681
708
  };
682
709
 
683
710
  AudioProcessor();
@@ -690,6 +717,11 @@ public:
690
717
  const std::vector<float>& waveform,
691
718
  const SpectrogramConfig& config);
692
719
 
720
+ static std::vector<float> compute_irfft(
721
+ const std::vector<float>& complex_input,
722
+ size_t n,
723
+ const char* norm = "backward");
724
+
693
725
  const std::vector<float>& get_mel_filters() const { return mel_filters_; }
694
726
 
695
727
  size_t get_num_mel_filters() const { return num_mel_filters_; }
@@ -701,5 +733,104 @@ private:
701
733
  size_t num_mel_filters_;
702
734
  };
703
735
 
736
+ namespace index {
737
+ constexpr uint32_t MAGIC = 0x43414354;
738
+ constexpr uint32_t VERSION = 1;
739
+
740
+ struct Document {
741
+ int id;
742
+ std::vector<float> embedding;
743
+ std::string content;
744
+ std::string metadata;
745
+ };
746
+
747
+ struct QueryResult {
748
+ int doc_id;
749
+ float score;
750
+
751
+ QueryResult(int doc_id, float score) : doc_id(doc_id), score(score) {}
752
+ };
753
+
754
+ struct QueryOptions {
755
+ size_t top_k = 10;
756
+ float score_threshold = -1.0f;
757
+ };
758
+
759
+ class Index {
760
+ public:
761
+ Index(const std::string& index_path, const std::string& data_path, size_t embedding_dim);
762
+ ~Index();
763
+
764
+ Index(const Index&) = delete;
765
+ Index& operator=(const Index&) = delete;
766
+ Index(Index&&) = delete;
767
+ Index& operator=(Index&&) = delete;
768
+
769
+ void add_documents(const std::vector<Document>& documents);
770
+ void delete_documents(const std::vector<int>& doc_ids);
771
+ std::vector<Document> get_documents(const std::vector<int>& doc_ids);
772
+ std::vector<std::vector<QueryResult>> query(const std::vector<std::vector<float>>& embeddings, const QueryOptions& options);
773
+ void compact();
774
+
775
+ private:
776
+ struct IndexHeader {
777
+ uint32_t magic;
778
+ uint32_t version;
779
+ uint32_t embedding_dim;
780
+ uint32_t num_documents;
781
+ };
782
+
783
+ struct IndexEntry {
784
+ int32_t doc_id;
785
+ uint64_t data_offset;
786
+ uint8_t flags; // bit 0: tombstone
787
+
788
+ const __fp16* embedding() const {
789
+ return reinterpret_cast<const __fp16*>(this + 1);
790
+ }
791
+
792
+ static size_t size(size_t embedding_dim) {
793
+ return sizeof(IndexEntry) + embedding_dim * sizeof(__fp16);
794
+ }
795
+ };
796
+
797
+ struct DataHeader {
798
+ uint32_t magic;
799
+ uint32_t version;
800
+ };
801
+
802
+ struct DataEntry {
803
+ uint16_t content_len;
804
+ uint16_t metadata_len;
805
+
806
+ const char* content() const {
807
+ return reinterpret_cast<const char*>(this + 1);
808
+ }
809
+
810
+ const char* metadata() const {
811
+ return content() + content_len;
812
+ }
813
+ };
814
+
815
+ void parse_index_header();
816
+ void parse_data_header();
817
+ void build_doc_id_map();
818
+ void validate_documents(const std::vector<Document>& documents);
819
+ void validate_doc_ids(const std::vector<int>& doc_ids);
820
+ ssize_t write_full(int fd, const void* buf, size_t count);
821
+
822
+ std::unordered_map<int, uint32_t> doc_id_map_;
823
+
824
+ std::string index_path_, data_path_;
825
+ size_t embedding_dim_;
826
+ size_t index_entry_size_;
827
+ uint32_t num_documents_;
828
+
829
+ int index_fd_, data_fd_;
830
+ void *mapped_index_, *mapped_data_;
831
+ size_t index_file_size_, data_file_size_;
832
+ };
833
+ } // namespace index
834
+
835
+ }
704
836
  }
705
- }
@@ -4,7 +4,9 @@
4
4
  #include <vector>
5
5
  #include <memory>
6
6
  #include <unordered_map>
7
+ #include <unordered_set>
7
8
  #include <functional>
9
+ #include <cassert>
8
10
  #include <cstring>
9
11
  #include <stdexcept>
10
12
  #include <string>
@@ -108,23 +110,36 @@ enum class ComputeBackend {
108
110
  NPU
109
111
  };
110
112
 
113
+ enum class Activation {
114
+ SILU,
115
+ GELU,
116
+ GELU_ERF,
117
+ RELU,
118
+ SIGMOID,
119
+ TANH
120
+ };
121
+
111
122
  enum class OpType {
112
123
  INPUT, PRECISION_CAST,
113
124
  ADD, ADD_CLIPPED, SUBTRACT, MULTIPLY, DIVIDE,
114
125
  MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
115
126
  BILINEAR_INTERPOLATION,
116
127
  SUM, MEAN, VARIANCE, MIN, MAX,
117
- RMS_NORM, ROPE, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3,
118
- SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
119
- SILU, GELU, GELU_ERF,
128
+ RMS_NORM, ROPE, ROPE_GPTJ, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, REL_POS_BIAS, CONV1D_CAUSAL, CONV1D_K3, CONV1D_K7S3, CONV1D, CONV1D_SAME_DEPTHWISE_K9, CONV1D_POINTWISE, CONV2D_K3S2P1, CONV2D_DEPTHWISE_K3S2P1, CONV2D_POINTWISE_1X1, GLU, BATCHNORM,
129
+ SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN, SCALAR_LOG,
130
+ RELU, SILU, GELU, GELU_ERF, SIGMOID, TANH,
120
131
  SAMPLE, CONCAT,
121
132
  SCATTER_TOPK,
122
- TOPK, LAYERNORM,
133
+ TOPK, LAYERNORM, GROUPNORM,
134
+ MOE_LAYER,
123
135
  INDEX,
136
+ PERSISTENT,
137
+ QUANTIZE_ACTIVATIONS,
138
+ LSTM_CELL,
139
+ STFT
124
140
  };
125
141
 
126
142
  struct PrecisionTraits {
127
- // Returns in-memory element size (INT4 unpacks to INT8, so returns 1)
128
143
  static constexpr size_t size_of(Precision prec) {
129
144
  switch (prec) {
130
145
  case Precision::INT8: return 1;
@@ -137,11 +152,20 @@ struct PrecisionTraits {
137
152
 
138
153
  static constexpr size_t packed_size_of(Precision prec, size_t count) {
139
154
  switch (prec) {
140
- case Precision::INT4: return (count + 1) / 2;
155
+ case Precision::INT4: return (count + 1) / 2;
141
156
  default: return count * size_of(prec);
142
157
  }
143
158
  }
144
159
 
160
+ static size_t byte_offset_of(Precision prec, size_t element_offset) {
161
+ switch (prec) {
162
+ case Precision::INT4:
163
+ assert(element_offset % 32 == 0 && "INT4 byte offset must be group-aligned (multiple of 32)");
164
+ return element_offset / 2;
165
+ default: return element_offset * size_of(prec);
166
+ }
167
+ }
168
+
145
169
  static constexpr bool is_integer(Precision prec) {
146
170
  switch (prec) {
147
171
  case Precision::INT8: return true;
@@ -177,7 +201,6 @@ struct TensorConfig {
177
201
  Precision compute_precision = Precision::INT8;
178
202
  Precision output_precision = Precision::INT8;
179
203
  bool auto_mixed_precision = false;
180
- bool enable_int4_packing = true;
181
204
 
182
205
  static TensorConfig& global();
183
206
  };
@@ -205,8 +228,12 @@ struct BufferDesc {
205
228
  void* scales_data = nullptr;
206
229
  std::unique_ptr<char[]> owned_scales;
207
230
 
208
- const void* packed_int4_data = nullptr;
209
- size_t packed_int4_size = 0;
231
+ bool is_interleaved = false;
232
+ size_t original_N = 0;
233
+
234
+ void* activation_scales_data = nullptr;
235
+ std::unique_ptr<char[]> owned_activation_scales;
236
+ size_t num_rows_for_activation_scales = 0;
210
237
 
211
238
  BufferDesc();
212
239
  BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8);
@@ -230,23 +257,43 @@ struct BufferDesc {
230
257
  const __fp16* scales_as_fp16() const {
231
258
  return reinterpret_cast<const __fp16*>(scales_data);
232
259
  }
260
+
233
261
  bool is_grouped_int8() const {
234
262
  return precision == Precision::INT8 && group_size > 0;
235
263
  }
236
- bool is_packed_int4() const {
237
- return packed_int4_data != nullptr && packed_int4_size > 0;
238
- }
239
- const uint8_t* packed_int4_as_uint8() const {
240
- return reinterpret_cast<const uint8_t*>(packed_int4_data);
264
+
265
+ bool is_grouped_int4() const {
266
+ return precision == Precision::INT4 && group_size > 0;
241
267
  }
268
+
242
269
  void set_grouped_scales(size_t gs, size_t ng, void* scales_ptr) {
243
270
  group_size = gs;
244
271
  num_groups = ng;
245
272
  scales_data = scales_ptr;
246
273
  }
247
- void set_packed_int4(const void* packed_data, size_t packed_size) {
248
- packed_int4_data = packed_data;
249
- packed_int4_size = packed_size;
274
+
275
+ void set_interleaved(bool interleaved, size_t orig_n) {
276
+ is_interleaved = interleaved;
277
+ original_N = orig_n;
278
+ }
279
+
280
+ bool has_activation_scales() const {
281
+ return activation_scales_data != nullptr && num_rows_for_activation_scales > 0;
282
+ }
283
+ const float* activation_scales_as_float() const {
284
+ return reinterpret_cast<const float*>(activation_scales_data);
285
+ }
286
+ float* activation_scales_as_float() {
287
+ return reinterpret_cast<float*>(activation_scales_data);
288
+ }
289
+ void allocate_activation_scales(size_t num_rows) {
290
+ num_rows_for_activation_scales = num_rows;
291
+ owned_activation_scales = std::make_unique<char[]>(num_rows * sizeof(float));
292
+ activation_scales_data = owned_activation_scales.get();
293
+ }
294
+ void set_activation_scales(void* scales_ptr, size_t num_rows) {
295
+ activation_scales_data = scales_ptr;
296
+ num_rows_for_activation_scales = num_rows;
250
297
  }
251
298
 
252
299
  void allocate();
@@ -267,6 +314,7 @@ struct OpParams {
267
314
  size_t slice_length = 0;
268
315
  size_t window_size = 0;
269
316
  bool is_causal = true;
317
+ bool attention_mask_is_additive = false;
270
318
  std::vector<size_t> new_shape;
271
319
  std::vector<size_t> permutation;
272
320
  Precision output_precision = Precision::INT8;
@@ -282,8 +330,14 @@ struct OpParams {
282
330
 
283
331
  size_t index_value = 0;
284
332
  size_t num_classes = 0;
333
+ size_t num_groups = 0;
285
334
  size_t dst_height = 0;
286
335
  size_t dst_width = 0;
336
+ bool normalize_routing = false;
337
+ size_t num_experts = 0;
338
+ size_t num_experts_per_tok = 0;
339
+ bool moe_gated = true;
340
+ Activation activation = Activation::SILU;
287
341
 
288
342
  std::vector<float> bias_values;
289
343
  std::vector<uint32_t> bias_indices;
@@ -295,6 +349,7 @@ struct OpParams {
295
349
  size_t cache_seq_len = 0;
296
350
  size_t num_kv_heads = 0;
297
351
  size_t head_dim = 0;
352
+ size_t num_fft_bins = 0;
298
353
  };
299
354
 
300
355
  struct GraphNode {
@@ -324,10 +379,12 @@ void compute_sample_node(GraphNode& node, const std::vector<std::unique_ptr<Grap
324
379
  void compute_scatter_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
325
380
  void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
326
381
  void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
382
+ void compute_groupnorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
383
+ void compute_persistent_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
327
384
  void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
385
+ void compute_lstm_cell_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
328
386
 
329
387
  void shrink_thread_local_buffers();
330
-
331
388
  class BufferPool {
332
389
  public:
333
390
  BufferPool() = default;
@@ -372,6 +429,7 @@ public:
372
429
 
373
430
  size_t input(const std::vector<size_t>& shape, Precision precision = Precision::INT8);
374
431
  size_t precision_cast(size_t input, Precision target_precision);
432
+ size_t quantize_activations(size_t input);
375
433
 
376
434
  size_t add(size_t input1, size_t input2);
377
435
  size_t add_clipped(size_t input1, size_t input2);
@@ -388,10 +446,15 @@ public:
388
446
  size_t scalar_sqrt(size_t input);
389
447
  size_t scalar_cos(size_t input);
390
448
  size_t scalar_sin(size_t input);
449
+ size_t scalar_log(size_t input);
391
450
 
451
+ size_t relu(size_t input);
392
452
  size_t silu(size_t input);
393
453
  size_t gelu(size_t input);
394
454
  size_t gelu_erf(size_t input);
455
+ size_t sigmoid(size_t input);
456
+ size_t tanh(size_t input);
457
+ size_t glu(size_t input, int axis = -1);
395
458
 
396
459
  size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
397
460
  size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
@@ -409,8 +472,8 @@ public:
409
472
  size_t gather(size_t embeddings, size_t indices);
410
473
  size_t mmap_embeddings(const std::string& filename);
411
474
  size_t mmap_weights(const std::string& filename);
412
- size_t load_weights(const std::string& filename);
413
475
  void set_grouped_scales(size_t node_id, size_t group_size, size_t num_groups, void* scales_ptr);
476
+ void set_interleaved(size_t node_id, bool interleaved, size_t original_N);
414
477
 
415
478
  void release_weight_pages(size_t node_id);
416
479
  void prefetch_weight_pages(size_t node_id);
@@ -420,22 +483,68 @@ public:
420
483
  size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
421
484
 
422
485
  size_t layernorm(size_t input, size_t weight, size_t bias, float epsilon = 1e-5f);
486
+ size_t layernorm(size_t input, size_t weight, float epsilon = 1e-5f); // No bias version
487
+ size_t groupnorm(size_t input, size_t weight, size_t bias, size_t num_groups = 32, float epsilon = 1e-5f);
488
+ size_t batchnorm(size_t input, size_t weight, size_t bias, size_t running_mean, size_t running_var, int axis = 1, float epsilon = 1e-5f);
423
489
  size_t topk(size_t input, size_t k);
490
+ size_t moe_layer(size_t hidden,
491
+ size_t routing_probs,
492
+ size_t topk_indices,
493
+ const std::vector<size_t>& w1_weights,
494
+ const std::vector<size_t>& w3_weights,
495
+ const std::vector<size_t>& w2_weights,
496
+ size_t num_experts,
497
+ size_t num_experts_per_tok,
498
+ bool normalize_routing,
499
+ float epsilon,
500
+ float routed_scaling_factor);
501
+ size_t moe_layer(size_t hidden,
502
+ size_t routing_probs,
503
+ size_t topk_indices,
504
+ const std::vector<size_t>& w1_weights,
505
+ const std::vector<size_t>& w2_weights,
506
+ size_t num_experts,
507
+ size_t num_experts_per_tok,
508
+ bool normalize_routing,
509
+ float epsilon,
510
+ float routed_scaling_factor,
511
+ Activation activation);
424
512
  size_t rms_norm(size_t input, size_t weight, float epsilon = 1e-5f);
425
513
  size_t rope(size_t input, float theta, size_t position_offset = 0, ComputeBackend backend = ComputeBackend::CPU);
514
+ size_t rope_gptj(size_t input, float theta, size_t position_offset = 0, size_t rot_dim = 0, ComputeBackend backend = ComputeBackend::CPU);
426
515
  size_t softmax(size_t input, int axis = -1);
427
516
  size_t attention(size_t query, size_t key, size_t value, float scale, bool is_causal = true, ComputeBackend backend = ComputeBackend::CPU);
428
517
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, ComputeBackend backend = ComputeBackend::CPU);
429
518
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, size_t window_size, ComputeBackend backend = ComputeBackend::CPU);
519
+ size_t attention_masked(size_t query, size_t key, size_t value, size_t mask, float scale,
520
+ bool is_causal = true, ComputeBackend backend = ComputeBackend::CPU,
521
+ bool additive_mask = false, size_t position_offset = 0, size_t window_size = 0);
522
+ size_t rel_pos_bias(size_t query, size_t relative_key, float scale);
430
523
 
431
524
  size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
432
525
  const int8_t* cached_keys, const int8_t* cached_values,
433
526
  const float* k_scales, const float* v_scales,
434
- size_t cache_len, size_t num_kv_heads, size_t head_dim);
527
+ size_t cache_len, size_t num_kv_heads, size_t head_dim, size_t window_size = 0);
435
528
 
436
529
  size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
437
530
  size_t conv1d_k3(size_t input, size_t weight, size_t stride);
438
-
531
+ size_t conv1d_k7s3(size_t input, size_t weight, size_t bias);
532
+ size_t conv1d(size_t input, size_t weight, size_t stride);
533
+ size_t conv1d(size_t input, size_t weight, size_t bias, size_t stride);
534
+ size_t conv1d_same_depthwise_k9(size_t input, size_t weight);
535
+ size_t conv1d_same_depthwise_k9(size_t input, size_t weight, size_t bias);
536
+ size_t conv1d_pointwise(size_t input, size_t weight);
537
+ size_t conv1d_pointwise(size_t input, size_t weight, size_t bias);
538
+ size_t conv2d_k3s2p1(size_t input, size_t weight);
539
+ size_t conv2d_k3s2p1(size_t input, size_t weight, size_t bias);
540
+ size_t conv2d_depthwise_k3s2p1(size_t input, size_t weight);
541
+ size_t conv2d_depthwise_k3s2p1(size_t input, size_t weight, size_t bias);
542
+ size_t conv2d_pointwise_1x1(size_t input, size_t weight);
543
+ size_t conv2d_pointwise_1x1(size_t input, size_t weight, size_t bias);
544
+
545
+ size_t lstm_cell(size_t input, size_t h_prev, size_t c_prev, size_t weight_ih, size_t weight_hh, size_t bias_ih, size_t bias_hh);
546
+ size_t stft(size_t input, size_t weight, size_t stride, size_t num_fft_bins);
547
+
439
548
  size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20,
440
549
  const std::unordered_map<uint32_t, float>& logit_bias = {});
441
550
 
@@ -462,6 +571,10 @@ public:
462
571
  void allocate_buffers();
463
572
  size_t get_node_count() const;
464
573
 
574
+ size_t persistent(size_t source_node);
575
+ bool is_populated(size_t persistent_node_id) const;
576
+ void invalidate_persistent(size_t persistent_node_id);
577
+
465
578
  std::vector<std::unique_ptr<GraphNode>> nodes_;
466
579
  std::unordered_map<size_t, size_t> node_index_map_;
467
580
 
@@ -473,6 +586,9 @@ private:
473
586
  std::vector<DebugNodeEntry> debug_nodes_;
474
587
  BufferPool buffer_pool_;
475
588
  bool prefill_mode_ = false;
589
+
590
+ std::unordered_set<size_t> persistent_node_ids_;
591
+ std::unordered_set<size_t> populated_node_ids_;
476
592
  };
477
593
 
478
594
 
@@ -485,7 +601,6 @@ namespace GraphFile {
485
601
  };
486
602
 
487
603
  void save_node(CactusGraph& graph, size_t node_id, const std::string& filename);
488
- LoadedNode load_into_graph(CactusGraph& graph, const std::string& filename);
489
604
 
490
605
  class MappedFile {
491
606
  public:
@@ -499,16 +614,14 @@ namespace GraphFile {
499
614
 
500
615
  const std::vector<size_t>& shape() const;
501
616
  Precision precision() const;
502
- Precision effective_precision() const {
503
- return is_int4_ ? Precision::INT8 : precision_;
504
- }
505
617
  size_t byte_size() const;
506
618
 
507
619
  size_t group_size() const { return group_size_; }
508
620
  size_t num_groups() const { return num_groups_; }
509
621
  const void* scales_data() const;
510
- const void* raw_packed_data() const; // Get raw mmap'd data without unpacking (for INT4)
511
- bool is_int4() const { return is_int4_; }
622
+
623
+ bool is_interleaved() const { return is_interleaved_; }
624
+ size_t original_N() const { return original_N_; }
512
625
 
513
626
  void* data();
514
627
  const void* data() const;
@@ -516,8 +629,6 @@ namespace GraphFile {
516
629
  template<typename T>
517
630
  const T* typed_data() const;
518
631
 
519
- LoadedNode load_into_graph(CactusGraph& graph) const;
520
-
521
632
  void release_pages();
522
633
  void prefetch_pages();
523
634
 
@@ -532,16 +643,14 @@ namespace GraphFile {
532
643
  size_t num_groups_ = 0;
533
644
  size_t scales_offset_ = 0;
534
645
  size_t scales_bytes_ = 0;
535
- uint32_t version_ = 1;
536
646
  uint32_t alignment_ = 32;
537
- bool is_int4_ = false;
538
- mutable std::unique_ptr<int8_t[]> unpacked_int4_data_;
647
+
648
+ bool is_interleaved_ = false;
649
+ size_t original_N_ = 0;
650
+
539
651
  void parse_header();
540
652
  void apply_madvise_hints();
541
- void unpack_int4_if_needed() const;
542
653
  };
543
-
544
- MappedFile mmap_load(const std::string& filename);
545
654
  }
546
655
 
547
- #endif
656
+ #endif