cactus-react-native 1.4.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. package/Cactus.podspec +1 -1
  2. package/README.md +465 -174
  3. package/android/CMakeLists.txt +24 -5
  4. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
  9. package/cpp/HybridCactus.cpp +157 -6
  10. package/cpp/HybridCactus.hpp +20 -3
  11. package/cpp/cactus_ffi.h +65 -30
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +65 -30
  14. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +357 -122
  15. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +184 -63
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
  17. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +153 -27
  18. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +90 -178
  19. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +276 -151
  20. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
  22. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +65 -30
  23. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +357 -122
  24. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +184 -63
  25. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
  26. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +153 -27
  27. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +90 -178
  28. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +276 -151
  29. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  30. package/lib/module/classes/CactusLM.js +43 -58
  31. package/lib/module/classes/CactusLM.js.map +1 -1
  32. package/lib/module/classes/CactusSTT.js +64 -38
  33. package/lib/module/classes/CactusSTT.js.map +1 -1
  34. package/lib/module/classes/CactusVAD.js +95 -0
  35. package/lib/module/classes/CactusVAD.js.map +1 -0
  36. package/lib/module/hooks/useCactusLM.js +23 -15
  37. package/lib/module/hooks/useCactusLM.js.map +1 -1
  38. package/lib/module/hooks/useCactusSTT.js +85 -28
  39. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  40. package/lib/module/hooks/useCactusVAD.js +171 -0
  41. package/lib/module/hooks/useCactusVAD.js.map +1 -0
  42. package/lib/module/index.js +2 -3
  43. package/lib/module/index.js.map +1 -1
  44. package/lib/module/modelRegistry.js +52 -0
  45. package/lib/module/modelRegistry.js.map +1 -0
  46. package/lib/module/native/Cactus.js +107 -8
  47. package/lib/module/native/Cactus.js.map +1 -1
  48. package/lib/module/native/CactusIndex.js.map +1 -1
  49. package/lib/module/native/index.js +0 -3
  50. package/lib/module/native/index.js.map +1 -1
  51. package/lib/module/types/CactusLM.js +2 -0
  52. package/lib/module/types/CactusSTT.js +2 -0
  53. package/lib/module/types/CactusVAD.js +4 -0
  54. package/lib/module/types/{CactusModel.js.map → CactusVAD.js.map} +1 -1
  55. package/lib/module/types/common.js +2 -0
  56. package/lib/module/types/{CactusSTTModel.js.map → common.js.map} +1 -1
  57. package/lib/typescript/src/classes/CactusLM.d.ts +8 -6
  58. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  59. package/lib/typescript/src/classes/CactusSTT.d.ts +11 -6
  60. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  61. package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
  62. package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
  63. package/lib/typescript/src/hooks/useCactusLM.d.ts +3 -3
  64. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  65. package/lib/typescript/src/hooks/useCactusSTT.d.ts +11 -5
  66. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  67. package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
  68. package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
  69. package/lib/typescript/src/index.d.ts +7 -6
  70. package/lib/typescript/src/index.d.ts.map +1 -1
  71. package/lib/typescript/src/modelRegistry.d.ts +5 -0
  72. package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
  73. package/lib/typescript/src/native/Cactus.d.ts +12 -6
  74. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  75. package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
  76. package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
  77. package/lib/typescript/src/native/index.d.ts +0 -3
  78. package/lib/typescript/src/native/index.d.ts.map +1 -1
  79. package/lib/typescript/src/specs/Cactus.nitro.d.ts +6 -1
  80. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  81. package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
  82. package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
  83. package/lib/typescript/src/types/CactusLM.d.ts +19 -9
  84. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  85. package/lib/typescript/src/types/CactusSTT.d.ts +45 -4
  86. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  87. package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
  88. package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
  89. package/lib/typescript/src/types/common.d.ts +23 -0
  90. package/lib/typescript/src/types/common.d.ts.map +1 -0
  91. package/nitro.json +0 -11
  92. package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
  93. package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
  94. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
  95. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
  96. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
  97. package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
  98. package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
  99. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
  100. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +6 -1
  101. package/package.json +3 -3
  102. package/src/classes/CactusLM.ts +59 -74
  103. package/src/classes/CactusSTT.ts +92 -49
  104. package/src/classes/CactusVAD.ts +129 -0
  105. package/src/hooks/useCactusLM.ts +26 -9
  106. package/src/hooks/useCactusSTT.ts +105 -44
  107. package/src/hooks/useCactusVAD.ts +215 -0
  108. package/src/index.tsx +20 -10
  109. package/src/modelRegistry.ts +65 -0
  110. package/src/native/Cactus.ts +130 -14
  111. package/src/native/CactusIndex.ts +2 -2
  112. package/src/native/index.ts +0 -3
  113. package/src/specs/Cactus.nitro.ts +11 -2
  114. package/src/types/CactusIndex.ts +2 -2
  115. package/src/types/CactusLM.ts +20 -9
  116. package/src/types/CactusSTT.ts +50 -4
  117. package/src/types/CactusVAD.ts +39 -0
  118. package/src/types/common.ts +23 -0
  119. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
  120. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
  121. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
  122. package/cpp/HybridCactusUtil.cpp +0 -47
  123. package/cpp/HybridCactusUtil.hpp +0 -27
  124. package/cpp/cactus_util.h +0 -25
  125. package/ios/HybridCactusCrypto.swift +0 -37
  126. package/ios/HybridCactusDeviceInfo.swift +0 -32
  127. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
  128. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
  129. package/ios/cactus_util.xcframework/Info.plist +0 -39
  130. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
  131. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
  132. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
  133. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
  134. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
  135. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
  136. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
  137. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
  138. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
  139. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
  140. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
  141. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
  142. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
  143. package/lib/module/api/Database.js +0 -137
  144. package/lib/module/api/Database.js.map +0 -1
  145. package/lib/module/api/RemoteLM.js +0 -201
  146. package/lib/module/api/RemoteLM.js.map +0 -1
  147. package/lib/module/config/CactusConfig.js +0 -12
  148. package/lib/module/config/CactusConfig.js.map +0 -1
  149. package/lib/module/native/CactusCrypto.js +0 -10
  150. package/lib/module/native/CactusCrypto.js.map +0 -1
  151. package/lib/module/native/CactusDeviceInfo.js +0 -13
  152. package/lib/module/native/CactusDeviceInfo.js.map +0 -1
  153. package/lib/module/native/CactusUtil.js +0 -36
  154. package/lib/module/native/CactusUtil.js.map +0 -1
  155. package/lib/module/specs/CactusCrypto.nitro.js +0 -4
  156. package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
  157. package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
  158. package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
  159. package/lib/module/specs/CactusUtil.nitro.js +0 -4
  160. package/lib/module/specs/CactusUtil.nitro.js.map +0 -1
  161. package/lib/module/telemetry/Telemetry.js +0 -154
  162. package/lib/module/telemetry/Telemetry.js.map +0 -1
  163. package/lib/module/types/CactusModel.js +0 -2
  164. package/lib/module/types/CactusSTTModel.js +0 -2
  165. package/lib/typescript/src/api/Database.d.ts +0 -18
  166. package/lib/typescript/src/api/Database.d.ts.map +0 -1
  167. package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
  168. package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
  169. package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
  170. package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
  171. package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
  172. package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
  173. package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
  174. package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
  175. package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
  176. package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
  177. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
  178. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
  179. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
  180. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
  181. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
  182. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
  183. package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
  184. package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
  185. package/lib/typescript/src/types/CactusModel.d.ts +0 -13
  186. package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
  187. package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
  188. package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
  189. package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
  190. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
  191. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
  192. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
  193. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
  194. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
  195. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
  196. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
  197. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
  198. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
  199. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
  200. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
  201. package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
  202. package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
  203. package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
  204. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
  205. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
  206. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
  207. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
  208. package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
  209. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
  210. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
  211. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
  212. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
  213. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
  214. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
  215. package/src/api/Database.ts +0 -188
  216. package/src/api/RemoteLM.ts +0 -273
  217. package/src/config/CactusConfig.ts +0 -11
  218. package/src/native/CactusCrypto.ts +0 -11
  219. package/src/native/CactusDeviceInfo.ts +0 -18
  220. package/src/native/CactusUtil.ts +0 -43
  221. package/src/specs/CactusCrypto.nitro.ts +0 -6
  222. package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
  223. package/src/specs/CactusUtil.nitro.ts +0 -8
  224. package/src/telemetry/Telemetry.ts +0 -236
  225. package/src/types/CactusModel.ts +0 -15
  226. package/src/types/CactusSTTModel.ts +0 -10
@@ -15,12 +15,7 @@ enum class ScalarOpType {
15
15
  SIN
16
16
  };
17
17
 
18
-
19
- void cactus_add_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
20
- void cactus_subtract_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
21
- void cactus_multiply_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
22
- void cactus_divide_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
23
-
18
+ constexpr size_t KV_QUANT_GROUP_SIZE = 32;
24
19
 
25
20
  void cactus_add_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
26
21
  void cactus_add_f16_clipped(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
@@ -28,27 +23,6 @@ void cactus_subtract_f16(const __fp16* a, const __fp16* b, __fp16* output, size_
28
23
  void cactus_multiply_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
29
24
  void cactus_divide_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
30
25
 
31
-
32
- void cactus_add_f32(const float* a, const float* b, float* output, size_t num_elements);
33
- void cactus_subtract_f32(const float* a, const float* b, float* output, size_t num_elements);
34
- void cactus_multiply_f32(const float* a, const float* b, float* output, size_t num_elements);
35
- void cactus_divide_f32(const float* a, const float* b, float* output, size_t num_elements);
36
-
37
-
38
- void cactus_add_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
39
- const size_t* a_strides, const size_t* b_strides,
40
- const size_t* output_shape, size_t ndim);
41
- void cactus_subtract_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
42
- const size_t* a_strides, const size_t* b_strides,
43
- const size_t* output_shape, size_t ndim);
44
- void cactus_multiply_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
45
- const size_t* a_strides, const size_t* b_strides,
46
- const size_t* output_shape, size_t ndim);
47
- void cactus_divide_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
48
- const size_t* a_strides, const size_t* b_strides,
49
- const size_t* output_shape, size_t ndim);
50
-
51
-
52
26
  void cactus_add_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* output,
53
27
  const size_t* a_strides, const size_t* b_strides,
54
28
  const size_t* output_shape, size_t ndim);
@@ -62,159 +36,85 @@ void cactus_divide_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* outpu
62
36
  const size_t* a_strides, const size_t* b_strides,
63
37
  const size_t* output_shape, size_t ndim);
64
38
 
65
-
66
- void cactus_add_broadcast_f32(const float* a, const float* b, float* output,
67
- const size_t* a_strides, const size_t* b_strides,
68
- const size_t* output_shape, size_t ndim);
69
- void cactus_subtract_broadcast_f32(const float* a, const float* b, float* output,
70
- const size_t* a_strides, const size_t* b_strides,
71
- const size_t* output_shape, size_t ndim);
72
- void cactus_multiply_broadcast_f32(const float* a, const float* b, float* output,
73
- const size_t* a_strides, const size_t* b_strides,
74
- const size_t* output_shape, size_t ndim);
75
- void cactus_divide_broadcast_f32(const float* a, const float* b, float* output,
76
- const size_t* a_strides, const size_t* b_strides,
77
- const size_t* output_shape, size_t ndim);
78
-
79
-
80
- void cactus_scalar_op_int8(const int8_t* input, int8_t* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
81
39
  void cactus_scalar_op_f16(const __fp16* input, __fp16* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
82
- void cactus_scalar_op_f32(const float* input, float* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
83
40
 
41
+ void cactus_gemv_int8(const int8_t* A, float A_scale,
42
+ const int8_t* B, const __fp16* B_scales,
43
+ __fp16* C, size_t K, size_t N, size_t group_size);
84
44
 
85
- void cactus_matmul_int8(const int8_t* a, const int8_t* b_transposed, int8_t* c,
86
- size_t M, size_t K, size_t N,
87
- float a_scale, float b_scale, float c_scale);
45
+ void cactus_gemm_int8(const int8_t* A, const float* A_scales,
46
+ const int8_t* B, const __fp16* B_scales,
47
+ __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
88
48
 
89
- #if defined(__ARM_FEATURE_MATMUL_INT8)
90
- void cactus_matmul_int8_to_int32_i8mm(const int8_t* a, const int8_t* b_transposed, int32_t* c,
91
- size_t M, size_t K, size_t N);
92
- #define cactus_matmul_int8_to_int32 cactus_matmul_int8_to_int32_i8mm
93
- #else
94
- void cactus_matmul_int8_to_int32(const int8_t* a, const int8_t* b_transposed, int32_t* c,
95
- size_t M, size_t K, size_t N);
96
- #endif
49
+ void cactus_matmul_int8(const int8_t* A, const float* A_scales,
50
+ const int8_t* B, const __fp16* B_scales,
51
+ __fp16* C, size_t M, size_t K, size_t N, size_t group_size);
97
52
 
98
53
  void cactus_matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c,
99
54
  size_t M, size_t K, size_t N);
100
55
 
101
- void cactus_matmul_f32(const float* a, const float* b_transposed, float* c,
102
- size_t M, size_t K, size_t N);
103
-
104
-
105
- void cactus_transpose_2d_int8(const int8_t* source, int8_t* destination,
106
- size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
107
56
  void cactus_transpose_2d_f16(const __fp16* source, __fp16* destination,
108
57
  size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
109
- void cactus_transpose_2d_f32(const float* source, float* destination,
110
- size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
111
-
112
- void cactus_transpose_int8(const int8_t* source, int8_t* destination, const size_t* shape,
113
- const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
114
58
  void cactus_transpose_f16(const __fp16* source, __fp16* destination, const size_t* shape,
115
59
  const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
116
- void cactus_transpose_f32(const float* source, float* destination, const size_t* shape,
117
- const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
118
60
 
119
- int64_t cactus_sum_all_int8(const int8_t* data, size_t num_elements);
120
- void cactus_sum_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
121
61
  double cactus_sum_all_f16(const __fp16* data, size_t num_elements);
122
- double cactus_sum_all_f32(const float* data, size_t num_elements);
123
- void cactus_sum_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
62
+ void cactus_sum_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
124
63
 
125
- double cactus_mean_all_int8(const int8_t* data, size_t num_elements);
126
- void cactus_mean_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
127
64
  double cactus_mean_all_f16(const __fp16* data, size_t num_elements);
128
65
  void cactus_mean_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
129
- double cactus_mean_all_f32(const float* data, size_t num_elements);
130
- void cactus_mean_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
131
66
 
132
- double cactus_variance_all_int8(const int8_t* data, size_t num_elements);
133
- void cactus_variance_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
134
- double cactus_variance_all_f32(const float* data, size_t num_elements);
135
- void cactus_variance_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
67
+ double cactus_variance_all_f16(const __fp16* data, size_t num_elements);
68
+ void cactus_variance_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
136
69
 
137
- int64_t cactus_min_all_int8(const int8_t* data, size_t num_elements);
138
- void cactus_min_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
139
- float cactus_min_all_f32(const float* data, size_t num_elements);
140
- void cactus_min_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
70
+ __fp16 cactus_min_all_f16(const __fp16* data, size_t num_elements);
71
+ void cactus_min_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
141
72
 
142
- int64_t cactus_max_all_int8(const int8_t* data, size_t num_elements);
143
- void cactus_max_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
144
- float cactus_max_all_f32(const float* data, size_t num_elements);
145
- void cactus_max_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
73
+ __fp16 cactus_max_all_f16(const __fp16* data, size_t num_elements);
74
+ void cactus_max_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
146
75
 
147
76
  void cactus_rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* output,
148
77
  size_t batch_size, size_t dims, float eps);
149
-
150
- void cactus_rms_norm_f32(const float* input, const float* weight, float* output,
151
- size_t batch_size, size_t dims, float eps);
152
-
153
- void cactus_rms_norm_i8_f32(const int8_t* input, const float* weight, float* output,
154
- size_t batch_size, size_t dims, float eps, float input_scale);
155
78
 
156
79
  void cactus_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
157
80
  size_t num_heads, size_t head_dim, size_t start_pos, float theta);
158
81
 
159
- void cactus_rope_f32(const float* input, float* output, size_t batch_size, size_t seq_len,
160
- size_t num_heads, size_t head_dim, size_t start_pos, float theta);
82
+ void cactus_gpt_j_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
83
+ size_t num_heads, size_t head_dim, size_t rot_dim, size_t start_pos, float theta);
161
84
 
162
- void cactus_rope_i8_f32_i8(const int8_t* input, int8_t* output, size_t batch_size, size_t seq_len,
163
- size_t num_heads, size_t head_dim, size_t start_pos, float theta,
164
- float input_scale, float output_scale);
165
-
166
- void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
85
+ void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
167
86
  size_t seq_len, size_t vocab_size);
168
87
 
169
- void cactus_softmax_f32(const float* input, float* output, size_t batch_size,
170
- size_t seq_len, size_t vocab_size);
88
+ void cactus_relu_f16(const __fp16* input, __fp16* output, size_t num_elements);
171
89
 
172
- void cactus_silu_f32(const float* input, float* output, size_t num_elements);
173
90
  void cactus_silu_f16(const __fp16* input, __fp16* output, size_t num_elements);
174
- void cactus_silu_int8(const int8_t* input, int8_t* output, size_t num_elements,
175
- float input_scale, float output_scale);
176
91
 
177
- void cactus_gelu_f32(const float* input, float* output, size_t num_elements);
178
92
  void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
179
- void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
180
- float input_scale, float output_scale);
181
93
 
182
- void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
183
94
  void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
184
- void cactus_gelu_int8_erf(
185
- const int8_t* input,
186
- int8_t* output,
187
- size_t num_elements,
188
- float scale_in,
189
- float scale_out);
190
-
191
-
192
- void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
193
- size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
194
- size_t head_dim, float scale, const int8_t* mask,
195
- float q_scale, float k_scale, float v_scale, float output_scale, size_t position_offset = 0, size_t window_size = 0,
196
- bool is_causal = true);
95
+
96
+ void cactus_sigmoid_f16(const __fp16* input, __fp16* output, size_t num_elements);
97
+
98
+ void cactus_tanh_f16(const __fp16* input, __fp16* output, size_t num_elements);
197
99
 
198
100
  void cactus_attention_f16(const __fp16* queries, const __fp16* keys, const __fp16* values, __fp16* output,
199
101
  size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
200
102
  size_t head_dim, float scale, const __fp16* mask, size_t position_offset = 0, size_t window_size = 0,
201
103
  bool is_causal = true);
202
104
 
203
- void cactus_attention_f32(const float* queries, const float* keys, const float* values, float* output,
204
- size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
205
- size_t head_dim, float scale, const float* mask, size_t position_offset = 0, size_t window_size = 0,
206
- bool is_causal = true);
207
-
208
-
209
- void cactus_conv1d_causal_depthwise_f32(
210
- const float* input,
211
- const float* weight,
212
- float* output,
213
- size_t N,
214
- size_t L,
215
- size_t C,
216
- size_t K,
217
- size_t dilation);
105
+ void cactus_attention_hybrid_int8_fp16(
106
+ const __fp16* queries,
107
+ const int8_t* keys_cached,
108
+ const int8_t* values_cached,
109
+ const float* k_scales,
110
+ const float* v_scales,
111
+ const __fp16* keys_new,
112
+ const __fp16* values_new,
113
+ __fp16* output,
114
+ size_t batch_size, size_t seq_len, size_t cache_len, size_t new_len,
115
+ size_t num_q_heads, size_t num_kv_heads, size_t head_dim,
116
+ float scale, size_t position_offset = 0, bool is_causal = true, size_t window_size = 0,
117
+ size_t group_size = KV_QUANT_GROUP_SIZE);
218
118
 
219
119
  void cactus_conv1d_causal_depthwise_f16(
220
120
  const __fp16* input,
@@ -226,23 +126,10 @@ void cactus_conv1d_causal_depthwise_f16(
226
126
  size_t K,
227
127
  size_t dilation);
228
128
 
229
- void cactus_conv1d_causal_depthwise_int8(
230
- const int8_t* input,
231
- const int8_t* weight,
232
- int8_t* output,
233
- size_t N,
234
- size_t L,
235
- size_t C,
236
- size_t K,
237
- size_t dilation,
238
- float input_scale,
239
- float weight_scale,
240
- float output_scale);
241
-
242
- void cactus_conv1d_f32_k3(
243
- const float* input,
244
- const float* weight,
245
- float* output,
129
+ void cactus_conv1d_f16_k3(
130
+ const __fp16* input,
131
+ const __fp16* weight,
132
+ __fp16* output,
246
133
  size_t N,
247
134
  size_t L,
248
135
  size_t C_in,
@@ -250,37 +137,42 @@ void cactus_conv1d_f32_k3(
250
137
  size_t stride
251
138
  );
252
139
 
253
- void cactus_conv1d_f16_k3(
140
+ void cactus_conv1d_f16(
254
141
  const __fp16* input,
255
142
  const __fp16* weight,
143
+ const __fp16* bias,
256
144
  __fp16* output,
257
145
  size_t N,
258
146
  size_t L,
259
147
  size_t C_in,
260
148
  size_t C_out,
149
+ size_t K,
261
150
  size_t stride
262
151
  );
263
152
 
264
- void cactus_conv1d_f32_k3(
265
- const float* input,
266
- const float* weight,
267
- float* output,
153
+ void cactus_stft_magnitude_f16(
154
+ const __fp16* input,
155
+ const __fp16* weight,
156
+ __fp16* output,
268
157
  size_t N, size_t L,
269
158
  size_t C_in, size_t C_out,
270
- size_t stride
159
+ size_t K, size_t stride,
160
+ size_t num_fft_bins
271
161
  );
272
162
 
273
- void cactus_conv1d_f16_k3(
163
+ void cactus_conv1d_f16_k7s3_oc8(
274
164
  const __fp16* input,
275
- const __fp16* weight,
165
+ const __fp16* Wpack,
166
+ const __fp16* bias,
276
167
  __fp16* output,
277
- size_t N, size_t L,
278
- size_t C_in, size_t C_out,
279
- size_t stride
168
+ size_t N,
169
+ size_t L,
170
+ size_t C_in,
171
+ size_t C_out
280
172
  );
281
173
 
282
- void cactus_bilinear_interpolation_fp32(const float* input, float* output, size_t src_height, size_t src_width, size_t embed_dim,
283
- size_t dst_height, size_t dst_width);
174
+ void cactus_bilinear_interpolation_f16(const __fp16* input, __fp16* output, size_t src_height, size_t src_width, size_t embed_dim,
175
+ size_t dst_height, size_t dst_width);
284
176
 
285
177
  void cactus_sample_f32(const float* logits, uint32_t* output, size_t vocab_size,
286
178
  float temperature, float top_p, size_t top_k, size_t random_seed,
@@ -291,25 +183,45 @@ void cactus_sample_f16(const __fp16* logits, uint32_t* output, size_t vocab_size
291
183
  const float* bias_values = nullptr, const uint32_t* bias_indices = nullptr,
292
184
  size_t bias_count = 0);
293
185
 
294
-
295
- void cactus_concat_f32(const float* input1, const float* input2, float* output,
296
- const size_t* shape1, const size_t* shape2, const size_t* output_shape,
297
- size_t ndims, int axis);
298
186
  void cactus_concat_f16(const __fp16* input1, const __fp16* input2, __fp16* output,
299
187
  const size_t* shape1, const size_t* shape2, const size_t* output_shape,
300
188
  size_t ndims, int axis);
301
- void cactus_concat_int8(const int8_t* input1, const int8_t* input2, int8_t* output,
302
- const size_t* shape1, const size_t* shape2, const size_t* output_shape,
303
- size_t ndims, int axis);
304
189
 
305
190
  void cactus_int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
306
191
  void cactus_fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
307
- void cactus_dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count, float* computed_scale);
308
192
  void cactus_fp16_to_fp32(const __fp16* src, float* dst, size_t count);
309
193
  void cactus_fp32_to_fp16(const float* src, __fp16* dst, size_t count);
310
194
  void cactus_int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
311
195
  void cactus_fp16_to_int8(const __fp16* src, int8_t* dst, size_t count, float scale = 1.0f);
312
196
  float cactus_fp16_max_abs(const __fp16* src, size_t count);
313
- void cactus_int32_to_fp16_scaled(const int32_t* src, __fp16* dst, size_t count, float scale);
314
197
 
315
- #endif
198
+ void cactus_quantize_kv_fp16_to_int8(
199
+ const __fp16* src,
200
+ int8_t* dst,
201
+ float* scales,
202
+ size_t seq_len, size_t kv_heads, size_t head_dim,
203
+ size_t group_size = KV_QUANT_GROUP_SIZE);
204
+
205
+ inline size_t kv_scales_count(size_t seq_len, size_t kv_heads, size_t head_dim, size_t group_size = KV_QUANT_GROUP_SIZE) {
206
+ size_t num_groups = (head_dim + group_size - 1) / group_size;
207
+ return seq_len * kv_heads * num_groups;
208
+ }
209
+
210
+ void cactus_unpack_int4_to_int8(const uint8_t* packed, int8_t* unpacked, size_t unpacked_count);
211
+
212
+ void cactus_lstm_cell_f16(
213
+ const __fp16* x_input,
214
+ const __fp16* h_prev,
215
+ const __fp16* c_prev,
216
+ const __fp16* weight_ih,
217
+ const __fp16* weight_hh,
218
+ const __fp16* bias_ih,
219
+ const __fp16* bias_hh,
220
+ __fp16* h_new,
221
+ __fp16* c_new,
222
+ size_t batch_size,
223
+ size_t input_size,
224
+ size_t hidden_size
225
+ );
226
+
227
+ #endif