cactus-react-native 1.2.1 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238) hide show
  1. package/README.md +765 -33
  2. package/android/CMakeLists.txt +4 -3
  3. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusFileSystem.kt +20 -1
  4. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
  6. package/cpp/HybridCactus.cpp +231 -19
  7. package/cpp/HybridCactus.hpp +25 -3
  8. package/cpp/HybridCactusIndex.cpp +325 -0
  9. package/cpp/HybridCactusIndex.hpp +43 -0
  10. package/cpp/HybridCactusUtil.cpp +3 -3
  11. package/cpp/HybridCactusUtil.hpp +2 -1
  12. package/cpp/cactus_ffi.h +107 -2
  13. package/cpp/cactus_util.h +1 -1
  14. package/ios/HybridCactusFileSystem.swift +23 -2
  15. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +2 -0
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +107 -2
  17. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +656 -0
  18. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/{ffi_utils.h → cactus_utils.h} +145 -18
  19. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +135 -7
  20. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
  21. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +193 -26
  22. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +54 -195
  23. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +399 -140
  24. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Info.plist +0 -0
  25. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  26. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +2 -0
  27. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +107 -2
  28. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +656 -0
  29. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/{ffi_utils.h → cactus_utils.h} +145 -18
  30. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +135 -7
  31. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
  32. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +193 -26
  33. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +54 -195
  34. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +399 -140
  35. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Info.plist +0 -0
  36. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
  37. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  38. package/ios/cactus_util.xcframework/Info.plist +4 -4
  39. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +1 -1
  40. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +27 -0
  41. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
  42. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
  43. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +1 -1
  44. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +27 -0
  45. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
  46. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +3 -3
  47. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
  48. package/lib/module/api/Database.js +12 -95
  49. package/lib/module/api/Database.js.map +1 -1
  50. package/lib/module/classes/CactusIndex.js +45 -0
  51. package/lib/module/classes/CactusIndex.js.map +1 -0
  52. package/lib/module/classes/CactusLM.js +65 -17
  53. package/lib/module/classes/CactusLM.js.map +1 -1
  54. package/lib/module/classes/CactusSTT.js +104 -17
  55. package/lib/module/classes/CactusSTT.js.map +1 -1
  56. package/lib/module/config/CactusConfig.js +2 -0
  57. package/lib/module/config/CactusConfig.js.map +1 -1
  58. package/lib/module/constants/packageVersion.js +1 -1
  59. package/lib/module/hooks/useCactusIndex.js +175 -0
  60. package/lib/module/hooks/useCactusIndex.js.map +1 -0
  61. package/lib/module/hooks/useCactusLM.js +68 -7
  62. package/lib/module/hooks/useCactusLM.js.map +1 -1
  63. package/lib/module/hooks/useCactusSTT.js +102 -6
  64. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  65. package/lib/module/index.js +2 -0
  66. package/lib/module/index.js.map +1 -1
  67. package/lib/module/models.js +336 -0
  68. package/lib/module/models.js.map +1 -0
  69. package/lib/module/native/Cactus.js +61 -13
  70. package/lib/module/native/Cactus.js.map +1 -1
  71. package/lib/module/native/CactusFileSystem.js +3 -0
  72. package/lib/module/native/CactusFileSystem.js.map +1 -1
  73. package/lib/module/native/CactusIndex.js +32 -0
  74. package/lib/module/native/CactusIndex.js.map +1 -0
  75. package/lib/module/native/CactusUtil.js +16 -3
  76. package/lib/module/native/CactusUtil.js.map +1 -1
  77. package/lib/module/native/index.js +1 -0
  78. package/lib/module/native/index.js.map +1 -1
  79. package/lib/module/specs/CactusIndex.nitro.js +4 -0
  80. package/lib/module/specs/CactusIndex.nitro.js.map +1 -0
  81. package/lib/module/telemetry/Telemetry.js +3 -1
  82. package/lib/module/telemetry/Telemetry.js.map +1 -1
  83. package/lib/module/types/CactusIndex.js +2 -0
  84. package/lib/module/types/{CactusModel.js.map → CactusIndex.js.map} +1 -1
  85. package/lib/module/types/CactusLM.js +2 -0
  86. package/lib/module/types/CactusSTT.js +2 -0
  87. package/lib/module/types/common.js +2 -0
  88. package/lib/module/types/{CactusSTTModel.js.map → common.js.map} +1 -1
  89. package/lib/typescript/src/api/Database.d.ts +4 -7
  90. package/lib/typescript/src/api/Database.d.ts.map +1 -1
  91. package/lib/typescript/src/classes/CactusIndex.d.ts +15 -0
  92. package/lib/typescript/src/classes/CactusIndex.d.ts.map +1 -0
  93. package/lib/typescript/src/classes/CactusLM.d.ts +12 -5
  94. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  95. package/lib/typescript/src/classes/CactusSTT.d.ts +15 -5
  96. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  97. package/lib/typescript/src/config/CactusConfig.d.ts +1 -0
  98. package/lib/typescript/src/config/CactusConfig.d.ts.map +1 -1
  99. package/lib/typescript/src/constants/packageVersion.d.ts +1 -1
  100. package/lib/typescript/src/hooks/useCactusIndex.d.ts +14 -0
  101. package/lib/typescript/src/hooks/useCactusIndex.d.ts.map +1 -0
  102. package/lib/typescript/src/hooks/useCactusLM.d.ts +6 -4
  103. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  104. package/lib/typescript/src/hooks/useCactusSTT.d.ts +13 -5
  105. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  106. package/lib/typescript/src/index.d.ts +6 -4
  107. package/lib/typescript/src/index.d.ts.map +1 -1
  108. package/lib/typescript/src/models.d.ts +6 -0
  109. package/lib/typescript/src/models.d.ts.map +1 -0
  110. package/lib/typescript/src/native/Cactus.d.ts +10 -3
  111. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  112. package/lib/typescript/src/native/CactusFileSystem.d.ts +1 -0
  113. package/lib/typescript/src/native/CactusFileSystem.d.ts.map +1 -1
  114. package/lib/typescript/src/native/CactusIndex.d.ts +12 -0
  115. package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -0
  116. package/lib/typescript/src/native/CactusUtil.d.ts.map +1 -1
  117. package/lib/typescript/src/native/index.d.ts +1 -0
  118. package/lib/typescript/src/native/index.d.ts.map +1 -1
  119. package/lib/typescript/src/specs/Cactus.nitro.d.ts +9 -2
  120. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  121. package/lib/typescript/src/specs/CactusFileSystem.nitro.d.ts +1 -0
  122. package/lib/typescript/src/specs/CactusFileSystem.nitro.d.ts.map +1 -1
  123. package/lib/typescript/src/specs/CactusIndex.nitro.d.ts +24 -0
  124. package/lib/typescript/src/specs/CactusIndex.nitro.d.ts.map +1 -0
  125. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +1 -1
  126. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +1 -1
  127. package/lib/typescript/src/types/CactusIndex.d.ts +34 -0
  128. package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -0
  129. package/lib/typescript/src/types/CactusLM.d.ts +19 -0
  130. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  131. package/lib/typescript/src/types/CactusSTT.d.ts +21 -1
  132. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  133. package/lib/typescript/src/types/common.d.ts +28 -0
  134. package/lib/typescript/src/types/common.d.ts.map +1 -0
  135. package/nitro.json +3 -0
  136. package/nitrogen/generated/android/c++/JDeviceInfo.hpp +1 -1
  137. package/nitrogen/generated/android/c++/JFunc_void_double.hpp +1 -1
  138. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +1 -1
  139. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +1 -1
  140. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +1 -1
  141. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +1 -1
  142. package/nitrogen/generated/android/c++/JHybridCactusFileSystemSpec.cpp +17 -1
  143. package/nitrogen/generated/android/c++/JHybridCactusFileSystemSpec.hpp +2 -1
  144. package/nitrogen/generated/android/c++/JHybridCactusImageSpec.cpp +1 -1
  145. package/nitrogen/generated/android/c++/JHybridCactusImageSpec.hpp +1 -1
  146. package/nitrogen/generated/android/cactus+autolinking.cmake +2 -1
  147. package/nitrogen/generated/android/cactus+autolinking.gradle +1 -1
  148. package/nitrogen/generated/android/cactusOnLoad.cpp +11 -1
  149. package/nitrogen/generated/android/cactusOnLoad.hpp +1 -1
  150. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +1 -1
  151. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/Func_void_double.kt +1 -1
  152. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +1 -1
  153. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +1 -1
  154. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusFileSystemSpec.kt +5 -1
  155. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusImageSpec.kt +1 -1
  156. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/cactusOnLoad.kt +1 -1
  157. package/nitrogen/generated/ios/Cactus+autolinking.rb +1 -1
  158. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +1 -1
  159. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +1 -1
  160. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +1 -1
  161. package/nitrogen/generated/ios/CactusAutolinking.mm +11 -1
  162. package/nitrogen/generated/ios/CactusAutolinking.swift +1 -1
  163. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +1 -1
  164. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +1 -1
  165. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +1 -1
  166. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +1 -1
  167. package/nitrogen/generated/ios/c++/HybridCactusFileSystemSpecSwift.cpp +1 -1
  168. package/nitrogen/generated/ios/c++/HybridCactusFileSystemSpecSwift.hpp +9 -1
  169. package/nitrogen/generated/ios/c++/HybridCactusImageSpecSwift.cpp +1 -1
  170. package/nitrogen/generated/ios/c++/HybridCactusImageSpecSwift.hpp +1 -1
  171. package/nitrogen/generated/ios/swift/DeviceInfo.swift +1 -1
  172. package/nitrogen/generated/ios/swift/Func_void.swift +1 -1
  173. package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +1 -1
  174. package/nitrogen/generated/ios/swift/Func_void_bool.swift +1 -1
  175. package/nitrogen/generated/ios/swift/Func_void_double.swift +1 -1
  176. package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +1 -1
  177. package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +1 -1
  178. package/nitrogen/generated/ios/swift/Func_void_std__string.swift +1 -1
  179. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +1 -1
  180. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +1 -1
  181. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +1 -1
  182. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +1 -1
  183. package/nitrogen/generated/ios/swift/HybridCactusFileSystemSpec.swift +2 -1
  184. package/nitrogen/generated/ios/swift/HybridCactusFileSystemSpec_cxx.swift +20 -1
  185. package/nitrogen/generated/ios/swift/HybridCactusImageSpec.swift +1 -1
  186. package/nitrogen/generated/ios/swift/HybridCactusImageSpec_cxx.swift +1 -1
  187. package/nitrogen/generated/shared/c++/CactusIndexGetResult.hpp +84 -0
  188. package/nitrogen/generated/shared/c++/CactusIndexQueryResult.hpp +79 -0
  189. package/nitrogen/generated/shared/c++/DeviceInfo.hpp +1 -1
  190. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +1 -1
  191. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +1 -1
  192. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +1 -1
  193. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +1 -1
  194. package/nitrogen/generated/shared/c++/HybridCactusFileSystemSpec.cpp +2 -1
  195. package/nitrogen/generated/shared/c++/HybridCactusFileSystemSpec.hpp +2 -1
  196. package/nitrogen/generated/shared/c++/HybridCactusImageSpec.cpp +1 -1
  197. package/nitrogen/generated/shared/c++/HybridCactusImageSpec.hpp +1 -1
  198. package/nitrogen/generated/shared/c++/HybridCactusIndexSpec.cpp +27 -0
  199. package/nitrogen/generated/shared/c++/HybridCactusIndexSpec.hpp +76 -0
  200. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +8 -1
  201. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +11 -3
  202. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +1 -1
  203. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +2 -2
  204. package/package.json +2 -2
  205. package/src/api/Database.ts +14 -135
  206. package/src/classes/CactusIndex.ts +58 -0
  207. package/src/classes/CactusLM.ts +87 -19
  208. package/src/classes/CactusSTT.ts +134 -20
  209. package/src/config/CactusConfig.ts +3 -0
  210. package/src/constants/packageVersion.ts +1 -1
  211. package/src/hooks/useCactusIndex.ts +195 -0
  212. package/src/hooks/useCactusLM.ts +88 -8
  213. package/src/hooks/useCactusSTT.ts +119 -7
  214. package/src/index.tsx +22 -2
  215. package/src/models.ts +344 -0
  216. package/src/native/Cactus.ts +95 -13
  217. package/src/native/CactusFileSystem.ts +4 -0
  218. package/src/native/CactusIndex.ts +54 -0
  219. package/src/native/CactusUtil.ts +19 -3
  220. package/src/native/index.ts +1 -0
  221. package/src/specs/Cactus.nitro.ts +18 -2
  222. package/src/specs/CactusFileSystem.nitro.ts +2 -0
  223. package/src/specs/CactusIndex.nitro.ts +31 -0
  224. package/src/specs/CactusUtil.nitro.ts +1 -1
  225. package/src/telemetry/Telemetry.ts +1 -1
  226. package/src/types/CactusIndex.ts +40 -0
  227. package/src/types/CactusLM.ts +24 -0
  228. package/src/types/CactusSTT.ts +27 -1
  229. package/src/types/common.ts +28 -0
  230. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.so +0 -0
  231. package/lib/module/types/CactusModel.js +0 -2
  232. package/lib/module/types/CactusSTTModel.js +0 -2
  233. package/lib/typescript/src/types/CactusModel.d.ts +0 -13
  234. package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
  235. package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
  236. package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
  237. package/src/types/CactusModel.ts +0 -15
  238. package/src/types/CactusSTTModel.ts +0 -10
@@ -15,12 +15,7 @@ enum class ScalarOpType {
15
15
  SIN
16
16
  };
17
17
 
18
-
19
- void cactus_add_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
20
- void cactus_subtract_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
21
- void cactus_multiply_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
22
- void cactus_divide_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
23
-
18
+ constexpr size_t KV_QUANT_GROUP_SIZE = 128;
24
19
 
25
20
  void cactus_add_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
26
21
  void cactus_add_f16_clipped(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
@@ -28,27 +23,6 @@ void cactus_subtract_f16(const __fp16* a, const __fp16* b, __fp16* output, size_
28
23
  void cactus_multiply_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
29
24
  void cactus_divide_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
30
25
 
31
-
32
- void cactus_add_f32(const float* a, const float* b, float* output, size_t num_elements);
33
- void cactus_subtract_f32(const float* a, const float* b, float* output, size_t num_elements);
34
- void cactus_multiply_f32(const float* a, const float* b, float* output, size_t num_elements);
35
- void cactus_divide_f32(const float* a, const float* b, float* output, size_t num_elements);
36
-
37
-
38
- void cactus_add_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
39
- const size_t* a_strides, const size_t* b_strides,
40
- const size_t* output_shape, size_t ndim);
41
- void cactus_subtract_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
42
- const size_t* a_strides, const size_t* b_strides,
43
- const size_t* output_shape, size_t ndim);
44
- void cactus_multiply_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
45
- const size_t* a_strides, const size_t* b_strides,
46
- const size_t* output_shape, size_t ndim);
47
- void cactus_divide_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
48
- const size_t* a_strides, const size_t* b_strides,
49
- const size_t* output_shape, size_t ndim);
50
-
51
-
52
26
  void cactus_add_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* output,
53
27
  const size_t* a_strides, const size_t* b_strides,
54
28
  const size_t* output_shape, size_t ndim);
@@ -62,154 +36,72 @@ void cactus_divide_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* outpu
62
36
  const size_t* a_strides, const size_t* b_strides,
63
37
  const size_t* output_shape, size_t ndim);
64
38
 
65
-
66
- void cactus_add_broadcast_f32(const float* a, const float* b, float* output,
67
- const size_t* a_strides, const size_t* b_strides,
68
- const size_t* output_shape, size_t ndim);
69
- void cactus_subtract_broadcast_f32(const float* a, const float* b, float* output,
70
- const size_t* a_strides, const size_t* b_strides,
71
- const size_t* output_shape, size_t ndim);
72
- void cactus_multiply_broadcast_f32(const float* a, const float* b, float* output,
73
- const size_t* a_strides, const size_t* b_strides,
74
- const size_t* output_shape, size_t ndim);
75
- void cactus_divide_broadcast_f32(const float* a, const float* b, float* output,
76
- const size_t* a_strides, const size_t* b_strides,
77
- const size_t* output_shape, size_t ndim);
78
-
79
-
80
- void cactus_scalar_op_int8(const int8_t* input, int8_t* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
81
39
  void cactus_scalar_op_f16(const __fp16* input, __fp16* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
82
- void cactus_scalar_op_f32(const float* input, float* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
83
40
 
41
+ void cactus_matmul_int8(const int8_t* A, const float* A_scales,
42
+ const int8_t* B, const __fp16* B_scales,
43
+ __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
84
44
 
85
- void cactus_matmul_int8(const int8_t* a, const int8_t* b_transposed, int8_t* c,
86
- size_t M, size_t K, size_t N,
87
- float a_scale, float b_scale, float c_scale);
88
-
89
- #if defined(__ARM_FEATURE_MATMUL_INT8)
90
- void cactus_matmul_int8_to_int32_i8mm(const int8_t* a, const int8_t* b_transposed, int32_t* c,
91
- size_t M, size_t K, size_t N);
92
- #define cactus_matmul_int8_to_int32 cactus_matmul_int8_to_int32_i8mm
93
- #else
94
- void cactus_matmul_int8_to_int32(const int8_t* a, const int8_t* b_transposed, int32_t* c,
95
- size_t M, size_t K, size_t N);
96
- #endif
45
+ void cactus_matmul_int4(const int8_t* A, const float* A_scales,
46
+ const uint8_t* B_packed, const __fp16* B_scales,
47
+ __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
97
48
 
98
49
  void cactus_matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c,
99
50
  size_t M, size_t K, size_t N);
100
51
 
101
- void cactus_matmul_f32(const float* a, const float* b_transposed, float* c,
102
- size_t M, size_t K, size_t N);
103
-
104
-
105
- void cactus_transpose_2d_int8(const int8_t* source, int8_t* destination,
106
- size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
107
- void cactus_transpose_2d_f32(const float* source, float* destination,
52
+ void cactus_transpose_2d_f16(const __fp16* source, __fp16* destination,
108
53
  size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
109
-
110
- void cactus_transpose_int8(const int8_t* source, int8_t* destination, const size_t* shape,
111
- const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
112
- void cactus_transpose_f32(const float* source, float* destination, const size_t* shape,
54
+ void cactus_transpose_f16(const __fp16* source, __fp16* destination, const size_t* shape,
113
55
  const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
114
56
 
115
- int64_t cactus_sum_all_int8(const int8_t* data, size_t num_elements);
116
- void cactus_sum_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
117
- double cactus_sum_all_f32(const float* data, size_t num_elements);
118
- void cactus_sum_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
57
+ double cactus_sum_all_f16(const __fp16* data, size_t num_elements);
58
+ void cactus_sum_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
119
59
 
120
- double cactus_mean_all_int8(const int8_t* data, size_t num_elements);
121
- void cactus_mean_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
122
60
  double cactus_mean_all_f16(const __fp16* data, size_t num_elements);
123
61
  void cactus_mean_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
124
- double cactus_mean_all_f32(const float* data, size_t num_elements);
125
- void cactus_mean_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
126
62
 
127
- double cactus_variance_all_int8(const int8_t* data, size_t num_elements);
128
- void cactus_variance_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
129
- double cactus_variance_all_f32(const float* data, size_t num_elements);
130
- void cactus_variance_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
63
+ double cactus_variance_all_f16(const __fp16* data, size_t num_elements);
64
+ void cactus_variance_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
131
65
 
132
- int64_t cactus_min_all_int8(const int8_t* data, size_t num_elements);
133
- void cactus_min_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
134
- float cactus_min_all_f32(const float* data, size_t num_elements);
135
- void cactus_min_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
66
+ __fp16 cactus_min_all_f16(const __fp16* data, size_t num_elements);
67
+ void cactus_min_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
136
68
 
137
- int64_t cactus_max_all_int8(const int8_t* data, size_t num_elements);
138
- void cactus_max_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
139
- float cactus_max_all_f32(const float* data, size_t num_elements);
140
- void cactus_max_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
69
+ __fp16 cactus_max_all_f16(const __fp16* data, size_t num_elements);
70
+ void cactus_max_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
141
71
 
142
72
  void cactus_rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* output,
143
73
  size_t batch_size, size_t dims, float eps);
144
-
145
- void cactus_rms_norm_f32(const float* input, const float* weight, float* output,
146
- size_t batch_size, size_t dims, float eps);
147
-
148
- void cactus_rms_norm_i8_f32(const int8_t* input, const float* weight, float* output,
149
- size_t batch_size, size_t dims, float eps, float input_scale);
150
74
 
151
75
  void cactus_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
152
76
  size_t num_heads, size_t head_dim, size_t start_pos, float theta);
153
77
 
154
- void cactus_rope_f32(const float* input, float* output, size_t batch_size, size_t seq_len,
155
- size_t num_heads, size_t head_dim, size_t start_pos, float theta);
156
-
157
- void cactus_rope_i8_f32_i8(const int8_t* input, int8_t* output, size_t batch_size, size_t seq_len,
158
- size_t num_heads, size_t head_dim, size_t start_pos, float theta,
159
- float input_scale, float output_scale);
160
-
161
- void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
78
+ void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
162
79
  size_t seq_len, size_t vocab_size);
163
80
 
164
- void cactus_softmax_f32(const float* input, float* output, size_t batch_size,
165
- size_t seq_len, size_t vocab_size);
166
-
167
- void cactus_silu_f32(const float* input, float* output, size_t num_elements);
168
81
  void cactus_silu_f16(const __fp16* input, __fp16* output, size_t num_elements);
169
- void cactus_silu_int8(const int8_t* input, int8_t* output, size_t num_elements,
170
- float input_scale, float output_scale);
171
82
 
172
- void cactus_gelu_f32(const float* input, float* output, size_t num_elements);
173
83
  void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
174
- void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
175
- float input_scale, float output_scale);
176
84
 
177
- void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
178
85
  void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
179
- void cactus_gelu_int8_erf(
180
- const int8_t* input,
181
- int8_t* output,
182
- size_t num_elements,
183
- float scale_in,
184
- float scale_out);
185
-
186
-
187
- void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
188
- size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
189
- size_t head_dim, float scale, const int8_t* mask,
190
- float q_scale, float k_scale, float v_scale, float output_scale, size_t position_offset = 0, size_t window_size = 0,
191
- bool is_causal = true);
192
86
 
193
87
  void cactus_attention_f16(const __fp16* queries, const __fp16* keys, const __fp16* values, __fp16* output,
194
88
  size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
195
89
  size_t head_dim, float scale, const __fp16* mask, size_t position_offset = 0, size_t window_size = 0,
196
90
  bool is_causal = true);
197
91
 
198
- void cactus_attention_f32(const float* queries, const float* keys, const float* values, float* output,
199
- size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
200
- size_t head_dim, float scale, const float* mask, size_t position_offset = 0, size_t window_size = 0,
201
- bool is_causal = true);
202
-
203
-
204
- void cactus_conv1d_causal_depthwise_f32(
205
- const float* input,
206
- const float* weight,
207
- float* output,
208
- size_t N,
209
- size_t L,
210
- size_t C,
211
- size_t K,
212
- size_t dilation);
92
+ void cactus_attention_hybrid_int8_fp16(
93
+ const __fp16* queries,
94
+ const int8_t* keys_cached,
95
+ const int8_t* values_cached,
96
+ const float* k_scales,
97
+ const float* v_scales,
98
+ const __fp16* keys_new,
99
+ const __fp16* values_new,
100
+ __fp16* output,
101
+ size_t batch_size, size_t seq_len, size_t cache_len, size_t new_len,
102
+ size_t num_q_heads, size_t num_kv_heads, size_t head_dim,
103
+ float scale, size_t position_offset = 0, bool is_causal = true,
104
+ size_t group_size = KV_QUANT_GROUP_SIZE);
213
105
 
214
106
  void cactus_conv1d_causal_depthwise_f16(
215
107
  const __fp16* input,
@@ -221,30 +113,6 @@ void cactus_conv1d_causal_depthwise_f16(
221
113
  size_t K,
222
114
  size_t dilation);
223
115
 
224
- void cactus_conv1d_causal_depthwise_int8(
225
- const int8_t* input,
226
- const int8_t* weight,
227
- int8_t* output,
228
- size_t N,
229
- size_t L,
230
- size_t C,
231
- size_t K,
232
- size_t dilation,
233
- float input_scale,
234
- float weight_scale,
235
- float output_scale);
236
-
237
- void cactus_conv1d_f32_k3(
238
- const float* input,
239
- const float* weight,
240
- float* output,
241
- size_t N,
242
- size_t L,
243
- size_t C_in,
244
- size_t C_out,
245
- size_t stride
246
- );
247
-
248
116
  void cactus_conv1d_f16_k3(
249
117
  const __fp16* input,
250
118
  const __fp16* weight,
@@ -256,51 +124,42 @@ void cactus_conv1d_f16_k3(
256
124
  size_t stride
257
125
  );
258
126
 
259
- void cactus_conv1d_f32_k3(
260
- const float* input,
261
- const float* weight,
262
- float* output,
263
- size_t N, size_t L,
264
- size_t C_in, size_t C_out,
265
- size_t stride
266
- );
267
-
268
- void cactus_conv1d_f16_k3(
269
- const __fp16* input,
270
- const __fp16* weight,
271
- __fp16* output,
272
- size_t N, size_t L,
273
- size_t C_in, size_t C_out,
274
- size_t stride
275
- );
276
-
277
- void cactus_bilinear_interpolation_fp32(const float* input, float* output, size_t src_height, size_t src_width, size_t embed_dim,
278
- size_t dst_height, size_t dst_width);
127
+ void cactus_bilinear_interpolation_f16(const __fp16* input, __fp16* output, size_t src_height, size_t src_width, size_t embed_dim,
128
+ size_t dst_height, size_t dst_width);
279
129
 
280
130
  void cactus_sample_f32(const float* logits, uint32_t* output, size_t vocab_size,
281
- float temperature, float top_p, size_t top_k, size_t random_seed);
131
+ float temperature, float top_p, size_t top_k, size_t random_seed,
132
+ const float* bias_values = nullptr, const uint32_t* bias_indices = nullptr,
133
+ size_t bias_count = 0);
282
134
  void cactus_sample_f16(const __fp16* logits, uint32_t* output, size_t vocab_size,
283
- float temperature, float top_p, size_t top_k, size_t random_seed);
284
-
135
+ float temperature, float top_p, size_t top_k, size_t random_seed,
136
+ const float* bias_values = nullptr, const uint32_t* bias_indices = nullptr,
137
+ size_t bias_count = 0);
285
138
 
286
- void cactus_concat_f32(const float* input1, const float* input2, float* output,
287
- const size_t* shape1, const size_t* shape2, const size_t* output_shape,
288
- size_t ndims, int axis);
289
139
  void cactus_concat_f16(const __fp16* input1, const __fp16* input2, __fp16* output,
290
140
  const size_t* shape1, const size_t* shape2, const size_t* output_shape,
291
141
  size_t ndims, int axis);
292
- void cactus_concat_int8(const int8_t* input1, const int8_t* input2, int8_t* output,
293
- const size_t* shape1, const size_t* shape2, const size_t* output_shape,
294
- size_t ndims, int axis);
295
142
 
296
143
  void cactus_int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
297
144
  void cactus_fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
298
- void cactus_dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count, float* computed_scale);
299
145
  void cactus_fp16_to_fp32(const __fp16* src, float* dst, size_t count);
300
146
  void cactus_fp32_to_fp16(const float* src, __fp16* dst, size_t count);
301
147
  void cactus_int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
302
148
  void cactus_fp16_to_int8(const __fp16* src, int8_t* dst, size_t count, float scale = 1.0f);
303
149
  float cactus_fp16_max_abs(const __fp16* src, size_t count);
304
- void cactus_int32_to_fp16_scaled(const int32_t* src, __fp16* dst, size_t count, float scale);
305
150
 
306
- #endif
151
+ void cactus_quantize_kv_fp16_to_int8(
152
+ const __fp16* src,
153
+ int8_t* dst,
154
+ float* scales,
155
+ size_t seq_len, size_t kv_heads, size_t head_dim,
156
+ size_t group_size = KV_QUANT_GROUP_SIZE);
157
+
158
+ inline size_t kv_scales_count(size_t seq_len, size_t kv_heads, size_t head_dim, size_t group_size = KV_QUANT_GROUP_SIZE) {
159
+ size_t num_groups = (head_dim + group_size - 1) / group_size;
160
+ return seq_len * kv_heads * num_groups;
161
+ }
162
+
163
+ void cactus_unpack_int4_to_int8(const uint8_t* packed, int8_t* unpacked, size_t unpacked_count);
164
+
165
+ #endif