cactus-react-native 1.4.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. package/Cactus.podspec +1 -1
  2. package/README.md +465 -174
  3. package/android/CMakeLists.txt +24 -5
  4. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
  9. package/cpp/HybridCactus.cpp +157 -6
  10. package/cpp/HybridCactus.hpp +20 -3
  11. package/cpp/cactus_ffi.h +65 -30
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +65 -30
  14. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +357 -122
  15. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +184 -63
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
  17. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +153 -27
  18. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +90 -178
  19. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +276 -151
  20. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
  22. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +65 -30
  23. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +357 -122
  24. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +184 -63
  25. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
  26. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +153 -27
  27. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +90 -178
  28. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +276 -151
  29. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  30. package/lib/module/classes/CactusLM.js +43 -58
  31. package/lib/module/classes/CactusLM.js.map +1 -1
  32. package/lib/module/classes/CactusSTT.js +64 -38
  33. package/lib/module/classes/CactusSTT.js.map +1 -1
  34. package/lib/module/classes/CactusVAD.js +95 -0
  35. package/lib/module/classes/CactusVAD.js.map +1 -0
  36. package/lib/module/hooks/useCactusLM.js +23 -15
  37. package/lib/module/hooks/useCactusLM.js.map +1 -1
  38. package/lib/module/hooks/useCactusSTT.js +85 -28
  39. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  40. package/lib/module/hooks/useCactusVAD.js +171 -0
  41. package/lib/module/hooks/useCactusVAD.js.map +1 -0
  42. package/lib/module/index.js +2 -3
  43. package/lib/module/index.js.map +1 -1
  44. package/lib/module/modelRegistry.js +52 -0
  45. package/lib/module/modelRegistry.js.map +1 -0
  46. package/lib/module/native/Cactus.js +107 -8
  47. package/lib/module/native/Cactus.js.map +1 -1
  48. package/lib/module/native/CactusIndex.js.map +1 -1
  49. package/lib/module/native/index.js +0 -3
  50. package/lib/module/native/index.js.map +1 -1
  51. package/lib/module/types/CactusLM.js +2 -0
  52. package/lib/module/types/CactusSTT.js +2 -0
  53. package/lib/module/types/CactusVAD.js +4 -0
  54. package/lib/module/types/{CactusModel.js.map → CactusVAD.js.map} +1 -1
  55. package/lib/module/types/common.js +2 -0
  56. package/lib/module/types/{CactusSTTModel.js.map → common.js.map} +1 -1
  57. package/lib/typescript/src/classes/CactusLM.d.ts +8 -6
  58. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  59. package/lib/typescript/src/classes/CactusSTT.d.ts +11 -6
  60. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  61. package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
  62. package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
  63. package/lib/typescript/src/hooks/useCactusLM.d.ts +3 -3
  64. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  65. package/lib/typescript/src/hooks/useCactusSTT.d.ts +11 -5
  66. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  67. package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
  68. package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
  69. package/lib/typescript/src/index.d.ts +7 -6
  70. package/lib/typescript/src/index.d.ts.map +1 -1
  71. package/lib/typescript/src/modelRegistry.d.ts +5 -0
  72. package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
  73. package/lib/typescript/src/native/Cactus.d.ts +12 -6
  74. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  75. package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
  76. package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
  77. package/lib/typescript/src/native/index.d.ts +0 -3
  78. package/lib/typescript/src/native/index.d.ts.map +1 -1
  79. package/lib/typescript/src/specs/Cactus.nitro.d.ts +6 -1
  80. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  81. package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
  82. package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
  83. package/lib/typescript/src/types/CactusLM.d.ts +19 -9
  84. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  85. package/lib/typescript/src/types/CactusSTT.d.ts +45 -4
  86. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  87. package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
  88. package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
  89. package/lib/typescript/src/types/common.d.ts +23 -0
  90. package/lib/typescript/src/types/common.d.ts.map +1 -0
  91. package/nitro.json +0 -11
  92. package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
  93. package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
  94. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
  95. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
  96. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
  97. package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
  98. package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
  99. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
  100. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +6 -1
  101. package/package.json +3 -3
  102. package/src/classes/CactusLM.ts +59 -74
  103. package/src/classes/CactusSTT.ts +92 -49
  104. package/src/classes/CactusVAD.ts +129 -0
  105. package/src/hooks/useCactusLM.ts +26 -9
  106. package/src/hooks/useCactusSTT.ts +105 -44
  107. package/src/hooks/useCactusVAD.ts +215 -0
  108. package/src/index.tsx +20 -10
  109. package/src/modelRegistry.ts +65 -0
  110. package/src/native/Cactus.ts +130 -14
  111. package/src/native/CactusIndex.ts +2 -2
  112. package/src/native/index.ts +0 -3
  113. package/src/specs/Cactus.nitro.ts +11 -2
  114. package/src/types/CactusIndex.ts +2 -2
  115. package/src/types/CactusLM.ts +20 -9
  116. package/src/types/CactusSTT.ts +50 -4
  117. package/src/types/CactusVAD.ts +39 -0
  118. package/src/types/common.ts +23 -0
  119. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
  120. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
  121. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
  122. package/cpp/HybridCactusUtil.cpp +0 -47
  123. package/cpp/HybridCactusUtil.hpp +0 -27
  124. package/cpp/cactus_util.h +0 -25
  125. package/ios/HybridCactusCrypto.swift +0 -37
  126. package/ios/HybridCactusDeviceInfo.swift +0 -32
  127. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
  128. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
  129. package/ios/cactus_util.xcframework/Info.plist +0 -39
  130. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
  131. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
  132. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
  133. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
  134. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
  135. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
  136. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
  137. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
  138. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
  139. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
  140. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
  141. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
  142. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
  143. package/lib/module/api/Database.js +0 -137
  144. package/lib/module/api/Database.js.map +0 -1
  145. package/lib/module/api/RemoteLM.js +0 -201
  146. package/lib/module/api/RemoteLM.js.map +0 -1
  147. package/lib/module/config/CactusConfig.js +0 -12
  148. package/lib/module/config/CactusConfig.js.map +0 -1
  149. package/lib/module/native/CactusCrypto.js +0 -10
  150. package/lib/module/native/CactusCrypto.js.map +0 -1
  151. package/lib/module/native/CactusDeviceInfo.js +0 -13
  152. package/lib/module/native/CactusDeviceInfo.js.map +0 -1
  153. package/lib/module/native/CactusUtil.js +0 -36
  154. package/lib/module/native/CactusUtil.js.map +0 -1
  155. package/lib/module/specs/CactusCrypto.nitro.js +0 -4
  156. package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
  157. package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
  158. package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
  159. package/lib/module/specs/CactusUtil.nitro.js +0 -4
  160. package/lib/module/specs/CactusUtil.nitro.js.map +0 -1
  161. package/lib/module/telemetry/Telemetry.js +0 -154
  162. package/lib/module/telemetry/Telemetry.js.map +0 -1
  163. package/lib/module/types/CactusModel.js +0 -2
  164. package/lib/module/types/CactusSTTModel.js +0 -2
  165. package/lib/typescript/src/api/Database.d.ts +0 -18
  166. package/lib/typescript/src/api/Database.d.ts.map +0 -1
  167. package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
  168. package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
  169. package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
  170. package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
  171. package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
  172. package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
  173. package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
  174. package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
  175. package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
  176. package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
  177. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
  178. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
  179. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
  180. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
  181. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
  182. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
  183. package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
  184. package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
  185. package/lib/typescript/src/types/CactusModel.d.ts +0 -13
  186. package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
  187. package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
  188. package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
  189. package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
  190. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
  191. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
  192. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
  193. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
  194. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
  195. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
  196. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
  197. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
  198. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
  199. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
  200. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
  201. package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
  202. package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
  203. package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
  204. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
  205. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
  206. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
  207. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
  208. package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
  209. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
  210. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
  211. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
  212. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
  213. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
  214. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
  215. package/src/api/Database.ts +0 -188
  216. package/src/api/RemoteLM.ts +0 -273
  217. package/src/config/CactusConfig.ts +0 -11
  218. package/src/native/CactusCrypto.ts +0 -11
  219. package/src/native/CactusDeviceInfo.ts +0 -18
  220. package/src/native/CactusUtil.ts +0 -43
  221. package/src/specs/CactusCrypto.nitro.ts +0 -6
  222. package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
  223. package/src/specs/CactusUtil.nitro.ts +0 -8
  224. package/src/telemetry/Telemetry.ts +0 -236
  225. package/src/types/CactusModel.ts +0 -15
  226. package/src/types/CactusSTTModel.ts +0 -10
@@ -4,6 +4,7 @@
4
4
  #include <vector>
5
5
  #include <memory>
6
6
  #include <unordered_map>
7
+ #include <unordered_set>
7
8
  #include <functional>
8
9
  #include <cstring>
9
10
  #include <stdexcept>
@@ -11,6 +12,7 @@
11
12
  #include <mutex>
12
13
  #include <sstream>
13
14
  #include <iostream>
15
+ #include <arm_neon.h>
14
16
 
15
17
  namespace cactus {
16
18
 
@@ -96,9 +98,10 @@ namespace GraphFile {
96
98
  }
97
99
 
98
100
  enum class Precision {
99
- INT8,
101
+ INT8,
100
102
  FP16,
101
- FP32
103
+ FP32,
104
+ INT4
102
105
  };
103
106
 
104
107
  enum class ComputeBackend {
@@ -112,13 +115,17 @@ enum class OpType {
112
115
  MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
113
116
  BILINEAR_INTERPOLATION,
114
117
  SUM, MEAN, VARIANCE, MIN, MAX,
115
- RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
118
+ RMS_NORM, ROPE, ROPE_GPTJ, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3, CONV1D_K7S3, CONV1D,
116
119
  SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
117
- SILU, GELU, GELU_ERF,
120
+ RELU, SILU, GELU, GELU_ERF, SIGMOID, TANH,
118
121
  SAMPLE, CONCAT,
119
122
  SCATTER_TOPK,
120
- TOPK, LAYERNORM,
123
+ TOPK, LAYERNORM, GROUPNORM,
121
124
  INDEX,
125
+ PERSISTENT,
126
+ QUANTIZE_ACTIVATIONS,
127
+ LSTM_CELL,
128
+ STFT_MAGNITUDE
122
129
  };
123
130
 
124
131
  struct PrecisionTraits {
@@ -127,22 +134,32 @@ struct PrecisionTraits {
127
134
  case Precision::INT8: return 1;
128
135
  case Precision::FP16: return 2;
129
136
  case Precision::FP32: return 4;
137
+ case Precision::INT4: return 1;
130
138
  }
131
139
  return 1;
132
140
  }
133
-
141
+
142
+ static constexpr size_t packed_size_of(Precision prec, size_t count) {
143
+ switch (prec) {
144
+ case Precision::INT4: return (count + 1) / 2;
145
+ default: return count * size_of(prec);
146
+ }
147
+ }
148
+
134
149
  static constexpr bool is_integer(Precision prec) {
135
150
  switch (prec) {
136
151
  case Precision::INT8: return true;
152
+ case Precision::INT4: return true;
137
153
  case Precision::FP16: return false;
138
154
  case Precision::FP32: return false;
139
155
  }
140
156
  return true;
141
157
  }
142
-
158
+
143
159
  static constexpr bool is_floating_point(Precision prec) {
144
160
  switch (prec) {
145
161
  case Precision::INT8: return false;
162
+ case Precision::INT4: return false;
146
163
  case Precision::FP16: return true;
147
164
  case Precision::FP32: return true;
148
165
  }
@@ -153,8 +170,6 @@ struct PrecisionTraits {
153
170
  namespace Quantization {
154
171
  void int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
155
172
  void fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
156
- void dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count,
157
- float* computed_scale);
158
173
  void fp16_to_fp32(const __fp16* src, float* dst, size_t count);
159
174
  void fp32_to_fp16(const float* src, __fp16* dst, size_t count);
160
175
  void int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
@@ -188,10 +203,21 @@ struct BufferDesc {
188
203
  void* external_data;
189
204
  char* pooled_data;
190
205
  Precision precision;
191
- float quantization_scale;
206
+
207
+ size_t group_size = 0;
208
+ size_t num_groups = 0;
209
+ void* scales_data = nullptr;
210
+ std::unique_ptr<char[]> owned_scales;
211
+
212
+ bool is_interleaved = false;
213
+ size_t original_N = 0;
214
+
215
+ void* activation_scales_data = nullptr;
216
+ std::unique_ptr<char[]> owned_activation_scales;
217
+ size_t num_rows_for_activation_scales = 0;
192
218
 
193
219
  BufferDesc();
194
- BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8, float scale = 1.0f);
220
+ BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8);
195
221
  ~BufferDesc();
196
222
 
197
223
  BufferDesc(BufferDesc&& other) noexcept;
@@ -209,6 +235,44 @@ struct BufferDesc {
209
235
  template<typename T>
210
236
  const T* data_as() const { return static_cast<const T*>(get_data()); }
211
237
 
238
+ const __fp16* scales_as_fp16() const {
239
+ return reinterpret_cast<const __fp16*>(scales_data);
240
+ }
241
+
242
+ bool is_grouped_int8() const {
243
+ return precision == Precision::INT8 && group_size > 0;
244
+ }
245
+
246
+ void set_grouped_scales(size_t gs, size_t ng, void* scales_ptr) {
247
+ group_size = gs;
248
+ num_groups = ng;
249
+ scales_data = scales_ptr;
250
+ }
251
+
252
+ void set_interleaved(bool interleaved, size_t orig_n) {
253
+ is_interleaved = interleaved;
254
+ original_N = orig_n;
255
+ }
256
+
257
+ bool has_activation_scales() const {
258
+ return activation_scales_data != nullptr && num_rows_for_activation_scales > 0;
259
+ }
260
+ const float* activation_scales_as_float() const {
261
+ return reinterpret_cast<const float*>(activation_scales_data);
262
+ }
263
+ float* activation_scales_as_float() {
264
+ return reinterpret_cast<float*>(activation_scales_data);
265
+ }
266
+ void allocate_activation_scales(size_t num_rows) {
267
+ num_rows_for_activation_scales = num_rows;
268
+ owned_activation_scales = std::make_unique<char[]>(num_rows * sizeof(float));
269
+ activation_scales_data = owned_activation_scales.get();
270
+ }
271
+ void set_activation_scales(void* scales_ptr, size_t num_rows) {
272
+ activation_scales_data = scales_ptr;
273
+ num_rows_for_activation_scales = num_rows;
274
+ }
275
+
212
276
  void allocate();
213
277
  void allocate_from_pool(BufferPool& pool);
214
278
  void release_to_pool(BufferPool& pool);
@@ -242,11 +306,21 @@ struct OpParams {
242
306
 
243
307
  size_t index_value = 0;
244
308
  size_t num_classes = 0;
309
+ size_t num_groups = 0;
245
310
  size_t dst_height = 0;
246
311
  size_t dst_width = 0;
247
312
 
248
313
  std::vector<float> bias_values;
249
314
  std::vector<uint32_t> bias_indices;
315
+
316
+ const int8_t* cached_keys_int8 = nullptr;
317
+ const int8_t* cached_values_int8 = nullptr;
318
+ const float* cached_k_scales = nullptr;
319
+ const float* cached_v_scales = nullptr;
320
+ size_t cache_seq_len = 0;
321
+ size_t num_kv_heads = 0;
322
+ size_t head_dim = 0;
323
+ size_t num_fft_bins = 0;
250
324
  };
251
325
 
252
326
  struct GraphNode {
@@ -276,7 +350,10 @@ void compute_sample_node(GraphNode& node, const std::vector<std::unique_ptr<Grap
276
350
  void compute_scatter_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
277
351
  void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
278
352
  void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
353
+ void compute_groupnorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
354
+ void compute_persistent_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
279
355
  void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
356
+ void compute_lstm_cell_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
280
357
 
281
358
  void shrink_thread_local_buffers();
282
359
 
@@ -324,9 +401,10 @@ public:
324
401
 
325
402
  size_t input(const std::vector<size_t>& shape, Precision precision = Precision::INT8);
326
403
  size_t precision_cast(size_t input, Precision target_precision);
404
+ size_t quantize_activations(size_t input);
327
405
 
328
406
  size_t add(size_t input1, size_t input2);
329
- size_t add_clipped(size_t input1, size_t input2); // For FP16 residual connections (Gemma)
407
+ size_t add_clipped(size_t input1, size_t input2);
330
408
  size_t subtract(size_t input1, size_t input2);
331
409
  size_t multiply(size_t input1, size_t input2);
332
410
  size_t divide(size_t input1, size_t input2);
@@ -341,9 +419,12 @@ public:
341
419
  size_t scalar_cos(size_t input);
342
420
  size_t scalar_sin(size_t input);
343
421
 
422
+ size_t relu(size_t input);
344
423
  size_t silu(size_t input);
345
424
  size_t gelu(size_t input);
346
425
  size_t gelu_erf(size_t input);
426
+ size_t sigmoid(size_t input);
427
+ size_t tanh(size_t input);
347
428
 
348
429
  size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
349
430
  size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
@@ -361,24 +442,42 @@ public:
361
442
  size_t gather(size_t embeddings, size_t indices);
362
443
  size_t mmap_embeddings(const std::string& filename);
363
444
  size_t mmap_weights(const std::string& filename);
364
- size_t load_weights(const std::string& filename);
365
- void set_quantization_scale(size_t node_id, float scale);
445
+ void set_grouped_scales(size_t node_id, size_t group_size, size_t num_groups, void* scales_ptr);
446
+ void set_interleaved(size_t node_id, bool interleaved, size_t original_N);
447
+
448
+ void release_weight_pages(size_t node_id);
449
+ void prefetch_weight_pages(size_t node_id);
450
+ void release_all_weight_pages();
366
451
  size_t embedding(const std::string& filename, size_t indices);
367
452
  size_t embedding(size_t embedding_tensor, size_t indices);
368
453
  size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
369
454
 
370
455
  size_t layernorm(size_t input, size_t weight, size_t bias, float epsilon = 1e-5f);
456
+ size_t layernorm(size_t input, size_t weight, float epsilon = 1e-5f); // No bias version
457
+ size_t groupnorm(size_t input, size_t weight, size_t bias, size_t num_groups = 32, float epsilon = 1e-5f);
371
458
  size_t topk(size_t input, size_t k);
372
459
  size_t rms_norm(size_t input, size_t weight, float epsilon = 1e-5f);
373
460
  size_t rope(size_t input, float theta, size_t position_offset = 0, ComputeBackend backend = ComputeBackend::CPU);
461
+ size_t rope_gptj(size_t input, float theta, size_t position_offset = 0, size_t rot_dim = 0, ComputeBackend backend = ComputeBackend::CPU);
374
462
  size_t softmax(size_t input, int axis = -1);
375
463
  size_t attention(size_t query, size_t key, size_t value, float scale, bool is_causal = true, ComputeBackend backend = ComputeBackend::CPU);
376
464
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, ComputeBackend backend = ComputeBackend::CPU);
377
465
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, size_t window_size, ComputeBackend backend = ComputeBackend::CPU);
378
466
 
467
+ size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
468
+ const int8_t* cached_keys, const int8_t* cached_values,
469
+ const float* k_scales, const float* v_scales,
470
+ size_t cache_len, size_t num_kv_heads, size_t head_dim, size_t window_size = 0);
471
+
379
472
  size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
380
473
  size_t conv1d_k3(size_t input, size_t weight, size_t stride);
381
-
474
+ size_t conv1d_k7s3(size_t input, size_t weight, size_t bias);
475
+ size_t conv1d(size_t input, size_t weight, size_t stride);
476
+ size_t conv1d(size_t input, size_t weight, size_t bias, size_t stride);
477
+
478
+ size_t lstm_cell(size_t input, size_t h_prev, size_t c_prev, size_t weight_ih, size_t weight_hh, size_t bias_ih, size_t bias_hh);
479
+ size_t stft_magnitude(size_t input, size_t weight, size_t stride, size_t num_fft_bins);
480
+
382
481
  size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20,
383
482
  const std::unordered_map<uint32_t, float>& logit_bias = {});
384
483
 
@@ -392,6 +491,8 @@ public:
392
491
  void execute(const std::string& profile_file = "");
393
492
  void hard_reset();
394
493
  void soft_reset();
494
+ void soft_reset_keep_pool();
495
+ void set_prefill_mode(bool enabled) { prefill_mode_ = enabled; }
395
496
 
396
497
  void register_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
397
498
  void capture_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
@@ -403,6 +504,10 @@ public:
403
504
  void allocate_buffers();
404
505
  size_t get_node_count() const;
405
506
 
507
+ size_t persistent(size_t source_node);
508
+ bool is_populated(size_t persistent_node_id) const;
509
+ void invalidate_persistent(size_t persistent_node_id);
510
+
406
511
  std::vector<std::unique_ptr<GraphNode>> nodes_;
407
512
  std::unordered_map<size_t, size_t> node_index_map_;
408
513
 
@@ -410,8 +515,13 @@ private:
410
515
  size_t next_node_id_;
411
516
  std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
412
517
  std::unordered_map<std::string, size_t> weight_cache_;
518
+ std::unordered_map<size_t, size_t> node_to_mapped_file_;
413
519
  std::vector<DebugNodeEntry> debug_nodes_;
414
520
  BufferPool buffer_pool_;
521
+ bool prefill_mode_ = false;
522
+
523
+ std::unordered_set<size_t> persistent_node_ids_;
524
+ std::unordered_set<size_t> populated_node_ids_;
415
525
  };
416
526
 
417
527
 
@@ -424,31 +534,37 @@ namespace GraphFile {
424
534
  };
425
535
 
426
536
  void save_node(CactusGraph& graph, size_t node_id, const std::string& filename);
427
- LoadedNode load_into_graph(CactusGraph& graph, const std::string& filename);
428
537
 
429
538
  class MappedFile {
430
539
  public:
431
540
  MappedFile(const std::string& filename);
432
541
  ~MappedFile();
433
-
542
+
434
543
  MappedFile(const MappedFile&) = delete;
435
544
  MappedFile& operator=(const MappedFile&) = delete;
436
545
  MappedFile(MappedFile&& other) noexcept;
437
546
  MappedFile& operator=(MappedFile&& other) noexcept;
438
-
547
+
439
548
  const std::vector<size_t>& shape() const;
440
549
  Precision precision() const;
441
550
  size_t byte_size() const;
442
- float quantization_scale() const;
443
-
551
+
552
+ size_t group_size() const { return group_size_; }
553
+ size_t num_groups() const { return num_groups_; }
554
+ const void* scales_data() const;
555
+
556
+ bool is_interleaved() const { return is_interleaved_; }
557
+ size_t original_N() const { return original_N_; }
558
+
444
559
  void* data();
445
560
  const void* data() const;
446
-
561
+
447
562
  template<typename T>
448
563
  const T* typed_data() const;
449
-
450
- LoadedNode load_into_graph(CactusGraph& graph) const;
451
-
564
+
565
+ void release_pages();
566
+ void prefetch_pages();
567
+
452
568
  private:
453
569
  int fd_;
454
570
  void* mapped_data_;
@@ -456,11 +572,21 @@ namespace GraphFile {
456
572
  std::vector<size_t> shape_;
457
573
  Precision precision_;
458
574
  size_t byte_size_;
459
- float quantization_scale_;
575
+ size_t group_size_ = 0;
576
+ size_t num_groups_ = 0;
577
+ size_t scales_offset_ = 0;
578
+ size_t scales_bytes_ = 0;
579
+ uint32_t alignment_ = 32;
580
+
581
+ bool is_interleaved_ = false;
582
+ size_t original_N_ = 0;
583
+
584
+ std::unique_ptr<int8_t[]> unpacked_data_;
585
+
460
586
  void parse_header();
587
+ void apply_madvise_hints();
588
+ void unpack_int4_data();
461
589
  };
462
-
463
- MappedFile mmap_load(const std::string& filename);
464
590
  }
465
591
 
466
592
  #endif