cactus-react-native 1.4.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226):
  1. package/Cactus.podspec +1 -1
  2. package/README.md +465 -174
  3. package/android/CMakeLists.txt +24 -5
  4. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
  9. package/cpp/HybridCactus.cpp +157 -6
  10. package/cpp/HybridCactus.hpp +20 -3
  11. package/cpp/cactus_ffi.h +65 -30
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +65 -30
  14. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +357 -122
  15. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +184 -63
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
  17. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +153 -27
  18. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +90 -178
  19. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +276 -151
  20. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
  22. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +65 -30
  23. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +357 -122
  24. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +184 -63
  25. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
  26. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +153 -27
  27. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +90 -178
  28. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +276 -151
  29. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  30. package/lib/module/classes/CactusLM.js +43 -58
  31. package/lib/module/classes/CactusLM.js.map +1 -1
  32. package/lib/module/classes/CactusSTT.js +64 -38
  33. package/lib/module/classes/CactusSTT.js.map +1 -1
  34. package/lib/module/classes/CactusVAD.js +95 -0
  35. package/lib/module/classes/CactusVAD.js.map +1 -0
  36. package/lib/module/hooks/useCactusLM.js +23 -15
  37. package/lib/module/hooks/useCactusLM.js.map +1 -1
  38. package/lib/module/hooks/useCactusSTT.js +85 -28
  39. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  40. package/lib/module/hooks/useCactusVAD.js +171 -0
  41. package/lib/module/hooks/useCactusVAD.js.map +1 -0
  42. package/lib/module/index.js +2 -3
  43. package/lib/module/index.js.map +1 -1
  44. package/lib/module/modelRegistry.js +52 -0
  45. package/lib/module/modelRegistry.js.map +1 -0
  46. package/lib/module/native/Cactus.js +107 -8
  47. package/lib/module/native/Cactus.js.map +1 -1
  48. package/lib/module/native/CactusIndex.js.map +1 -1
  49. package/lib/module/native/index.js +0 -3
  50. package/lib/module/native/index.js.map +1 -1
  51. package/lib/module/types/CactusLM.js +2 -0
  52. package/lib/module/types/CactusSTT.js +2 -0
  53. package/lib/module/types/CactusVAD.js +4 -0
  54. package/lib/module/types/{CactusModel.js.map → CactusVAD.js.map} +1 -1
  55. package/lib/module/types/common.js +2 -0
  56. package/lib/module/types/{CactusSTTModel.js.map → common.js.map} +1 -1
  57. package/lib/typescript/src/classes/CactusLM.d.ts +8 -6
  58. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  59. package/lib/typescript/src/classes/CactusSTT.d.ts +11 -6
  60. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  61. package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
  62. package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
  63. package/lib/typescript/src/hooks/useCactusLM.d.ts +3 -3
  64. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  65. package/lib/typescript/src/hooks/useCactusSTT.d.ts +11 -5
  66. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  67. package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
  68. package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
  69. package/lib/typescript/src/index.d.ts +7 -6
  70. package/lib/typescript/src/index.d.ts.map +1 -1
  71. package/lib/typescript/src/modelRegistry.d.ts +5 -0
  72. package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
  73. package/lib/typescript/src/native/Cactus.d.ts +12 -6
  74. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  75. package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
  76. package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
  77. package/lib/typescript/src/native/index.d.ts +0 -3
  78. package/lib/typescript/src/native/index.d.ts.map +1 -1
  79. package/lib/typescript/src/specs/Cactus.nitro.d.ts +6 -1
  80. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  81. package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
  82. package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
  83. package/lib/typescript/src/types/CactusLM.d.ts +19 -9
  84. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  85. package/lib/typescript/src/types/CactusSTT.d.ts +45 -4
  86. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  87. package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
  88. package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
  89. package/lib/typescript/src/types/common.d.ts +23 -0
  90. package/lib/typescript/src/types/common.d.ts.map +1 -0
  91. package/nitro.json +0 -11
  92. package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
  93. package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
  94. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
  95. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
  96. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
  97. package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
  98. package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
  99. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
  100. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +6 -1
  101. package/package.json +3 -3
  102. package/src/classes/CactusLM.ts +59 -74
  103. package/src/classes/CactusSTT.ts +92 -49
  104. package/src/classes/CactusVAD.ts +129 -0
  105. package/src/hooks/useCactusLM.ts +26 -9
  106. package/src/hooks/useCactusSTT.ts +105 -44
  107. package/src/hooks/useCactusVAD.ts +215 -0
  108. package/src/index.tsx +20 -10
  109. package/src/modelRegistry.ts +65 -0
  110. package/src/native/Cactus.ts +130 -14
  111. package/src/native/CactusIndex.ts +2 -2
  112. package/src/native/index.ts +0 -3
  113. package/src/specs/Cactus.nitro.ts +11 -2
  114. package/src/types/CactusIndex.ts +2 -2
  115. package/src/types/CactusLM.ts +20 -9
  116. package/src/types/CactusSTT.ts +50 -4
  117. package/src/types/CactusVAD.ts +39 -0
  118. package/src/types/common.ts +23 -0
  119. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
  120. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
  121. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
  122. package/cpp/HybridCactusUtil.cpp +0 -47
  123. package/cpp/HybridCactusUtil.hpp +0 -27
  124. package/cpp/cactus_util.h +0 -25
  125. package/ios/HybridCactusCrypto.swift +0 -37
  126. package/ios/HybridCactusDeviceInfo.swift +0 -32
  127. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
  128. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
  129. package/ios/cactus_util.xcframework/Info.plist +0 -39
  130. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
  131. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
  132. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
  133. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
  134. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
  135. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
  136. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
  137. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
  138. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
  139. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
  140. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
  141. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
  142. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
  143. package/lib/module/api/Database.js +0 -137
  144. package/lib/module/api/Database.js.map +0 -1
  145. package/lib/module/api/RemoteLM.js +0 -201
  146. package/lib/module/api/RemoteLM.js.map +0 -1
  147. package/lib/module/config/CactusConfig.js +0 -12
  148. package/lib/module/config/CactusConfig.js.map +0 -1
  149. package/lib/module/native/CactusCrypto.js +0 -10
  150. package/lib/module/native/CactusCrypto.js.map +0 -1
  151. package/lib/module/native/CactusDeviceInfo.js +0 -13
  152. package/lib/module/native/CactusDeviceInfo.js.map +0 -1
  153. package/lib/module/native/CactusUtil.js +0 -36
  154. package/lib/module/native/CactusUtil.js.map +0 -1
  155. package/lib/module/specs/CactusCrypto.nitro.js +0 -4
  156. package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
  157. package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
  158. package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
  159. package/lib/module/specs/CactusUtil.nitro.js +0 -4
  160. package/lib/module/specs/CactusUtil.nitro.js.map +0 -1
  161. package/lib/module/telemetry/Telemetry.js +0 -154
  162. package/lib/module/telemetry/Telemetry.js.map +0 -1
  163. package/lib/module/types/CactusModel.js +0 -2
  164. package/lib/module/types/CactusSTTModel.js +0 -2
  165. package/lib/typescript/src/api/Database.d.ts +0 -18
  166. package/lib/typescript/src/api/Database.d.ts.map +0 -1
  167. package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
  168. package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
  169. package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
  170. package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
  171. package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
  172. package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
  173. package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
  174. package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
  175. package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
  176. package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
  177. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
  178. package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
  179. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
  180. package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
  181. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
  182. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
  183. package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
  184. package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
  185. package/lib/typescript/src/types/CactusModel.d.ts +0 -13
  186. package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
  187. package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
  188. package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
  189. package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
  190. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
  191. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
  192. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
  193. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
  194. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
  195. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
  196. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
  197. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
  198. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
  199. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
  200. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
  201. package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
  202. package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
  203. package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
  204. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
  205. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
  206. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
  207. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
  208. package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
  209. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
  210. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
  211. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
  212. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
  213. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
  214. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
  215. package/src/api/Database.ts +0 -188
  216. package/src/api/RemoteLM.ts +0 -273
  217. package/src/config/CactusConfig.ts +0 -11
  218. package/src/native/CactusCrypto.ts +0 -11
  219. package/src/native/CactusDeviceInfo.ts +0 -18
  220. package/src/native/CactusUtil.ts +0 -43
  221. package/src/specs/CactusCrypto.nitro.ts +0 -6
  222. package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
  223. package/src/specs/CactusUtil.nitro.ts +0 -8
  224. package/src/telemetry/Telemetry.ts +0 -236
  225. package/src/types/CactusModel.ts +0 -15
  226. package/src/types/CactusSTTModel.ts +0 -10
@@ -2,6 +2,14 @@
2
2
  #define KERNEL_UTILS_H
3
3
 
4
4
  #include <arm_neon.h>
5
+ #if defined(__APPLE__)
6
+ #include <TargetConditionals.h>
7
+ #include <sys/sysctl.h>
8
+ #endif
9
+ #if defined(__ANDROID__)
10
+ #include <sys/auxv.h>
11
+ #include <asm/hwcap.h>
12
+ #endif
5
13
  #include <algorithm>
6
14
  #include <cmath>
7
15
  #include <thread>
@@ -19,166 +27,304 @@
19
27
  #include <cstdio>
20
28
 
21
29
  constexpr size_t NEON_VECTOR_SIZE = 16;
30
+ constexpr size_t STREAMING_STORE_THRESHOLD = 32768;
22
31
 
23
- inline int8_t clamp_to_int8(float value) {
24
- int32_t clamped = static_cast<int32_t>(roundf(value));
25
- return static_cast<int8_t>(std::max(-128, std::min(127, clamped)));
32
+ inline void stream_store_f16x8(__fp16* dst, float16x8_t val) {
33
+ #if defined(__aarch64__)
34
+ float16x4_t lo = vget_low_f16(val);
35
+ float16x4_t hi = vget_high_f16(val);
36
+ __asm__ __volatile__(
37
+ "stnp %d0, %d1, [%2]"
38
+ :
39
+ : "w"(lo), "w"(hi), "r"(dst)
40
+ : "memory"
41
+ );
42
+ #else
43
+ vst1q_f16(dst, val);
44
+ #endif
26
45
  }
27
46
 
28
- inline int8_t clamp_to_int8(int32_t value) {
29
- return static_cast<int8_t>(std::max(-128, std::min(127, value)));
30
- }
31
47
 
32
- #if defined(__ARM_FEATURE_DOTPROD)
33
- inline int32x4_t accum_i8mm(int32x4_t acc, int8x16_t a, int8x16_t b) {
34
- return vdotq_s32(acc, a, b);
35
- }
36
- #else
37
- inline int32x4_t accum_i8mm(int32x4_t acc, int8x16_t a, int8x16_t b) {
38
- int16x8_t prod_low = vmull_s8(vget_low_s8(a), vget_low_s8(b));
39
- int32x4_t acc_high = vpaddlq_s16(vmull_s8(vget_high_s8(a), vget_high_s8(b)));
40
- return vaddq_s32(vaddq_s32(acc, vpaddlq_s16(prod_low)), acc_high);
41
- }
42
- #endif
48
+ inline float32x4_t fast_exp_f32x4(float32x4_t x) {
49
+ const float32x4_t log2e = vdupq_n_f32(1.4426950408889634f);
50
+ const float32x4_t ln2 = vdupq_n_f32(0.6931471805599453f);
51
+
52
+ const float32x4_t c0 = vdupq_n_f32(1.0f);
53
+ const float32x4_t c1 = vdupq_n_f32(0.6931471805599453f);
54
+ const float32x4_t c2 = vdupq_n_f32(0.2402265069591007f);
55
+ const float32x4_t c3 = vdupq_n_f32(0.05550410866482158f);
56
+ const float32x4_t c4 = vdupq_n_f32(0.009618129842071803f);
57
+
58
+ x = vmaxq_f32(x, vdupq_n_f32(-87.0f));
59
+ x = vminq_f32(x, vdupq_n_f32(87.0f));
60
+
61
+ float32x4_t z = vmulq_f32(x, log2e);
62
+
63
+ int32x4_t zi = vcvtq_s32_f32(z);
64
+ float32x4_t zf = vsubq_f32(z, vcvtq_f32_s32(zi));
65
+
66
+ uint32x4_t neg_mask = vcltq_f32(zf, vdupq_n_f32(0.0f));
67
+ zi = vsubq_s32(zi, vandq_s32(vreinterpretq_s32_u32(neg_mask), vdupq_n_s32(1)));
68
+ zf = vaddq_f32(zf, vreinterpretq_f32_u32(vandq_u32(neg_mask, vreinterpretq_u32_f32(vdupq_n_f32(1.0f)))));
69
+
70
+ float32x4_t zf_ln2 = vmulq_f32(zf, ln2);
71
+ float32x4_t p = c4;
72
+ p = vfmaq_f32(c3, p, zf_ln2);
73
+ p = vfmaq_f32(c2, p, zf_ln2);
74
+ p = vfmaq_f32(c1, p, zf_ln2);
75
+ p = vfmaq_f32(c0, p, zf_ln2);
76
+
77
+ int32x4_t exp_bits = vshlq_n_s32(vaddq_s32(zi, vdupq_n_s32(127)), 23);
78
+ float32x4_t scale = vreinterpretq_f32_s32(exp_bits);
43
79
 
44
- inline float16x8_t accum_f16_dot(float16x8_t acc, float16x8_t a_low, float16x8_t a_high,
45
- float16x8_t b_low, float16x8_t b_high) {
46
- acc = vfmaq_f16(acc, a_low, b_low);
47
- return vfmaq_f16(acc, a_high, b_high);
80
+ return vmulq_f32(p, scale);
48
81
  }
49
82
 
50
- inline float32x4_t accum_f32_dot(float32x4_t acc, float32x4_t a_low, float32x4_t a_high,
51
- float32x4_t b_low, float32x4_t b_high) {
52
- acc = vfmaq_f32(acc, a_low, b_low);
53
- return vfmaq_f32(acc, a_high, b_high);
83
+ inline float32x4_t fast_tanh_f32x4(float32x4_t x) {
84
+ const float32x4_t one = vdupq_n_f32(1.0f);
85
+ const float32x4_t neg_one = vdupq_n_f32(-1.0f);
86
+
87
+ uint32x4_t pos_sat = vcgtq_f32(x, vdupq_n_f32(4.5f));
88
+ uint32x4_t neg_sat = vcltq_f32(x, vdupq_n_f32(-4.5f));
89
+
90
+ const float32x4_t c27 = vdupq_n_f32(27.0f);
91
+ const float32x4_t c9 = vdupq_n_f32(9.0f);
92
+
93
+ float32x4_t x2 = vmulq_f32(x, x);
94
+ float32x4_t num = vaddq_f32(c27, x2);
95
+ float32x4_t den = vfmaq_f32(c27, c9, x2);
96
+
97
+ float32x4_t result = vmulq_f32(x, vdivq_f32(num, den));
98
+
99
+ result = vbslq_f32(pos_sat, one, result);
100
+ result = vbslq_f32(neg_sat, neg_one, result);
101
+
102
+ return result;
54
103
  }
55
104
 
56
105
  namespace CactusThreading {
57
-
106
+
58
107
  class ThreadPool {
59
108
  private:
109
+ static constexpr size_t MAX_WORKERS = 16;
110
+
60
111
  std::vector<std::thread> workers;
61
- std::queue<std::function<void()>> tasks;
62
- std::mutex queue_mutex;
63
- std::condition_variable condition;
64
- std::atomic<bool> stop{false};
65
- std::atomic<size_t> active_workers{0};
66
- std::condition_variable finish_condition;
67
-
112
+ std::deque<std::function<void()>> tasks;
113
+
114
+ std::mutex mutex;
115
+ std::condition_variable work_available;
116
+ std::condition_variable work_done;
117
+
118
+ bool stop{false};
119
+ std::atomic<size_t> pending_tasks{0};
120
+ size_t num_workers_;
121
+
68
122
  void worker_thread() {
69
123
  while (true) {
70
124
  std::function<void()> task;
71
125
  {
72
- std::unique_lock<std::mutex> lock(queue_mutex);
73
- condition.wait(lock, [this] { return stop || !tasks.empty(); });
74
-
75
- if (stop && tasks.empty()) return;
76
-
126
+ std::unique_lock<std::mutex> lock(mutex);
127
+ work_available.wait(lock, [this] {
128
+ return stop || !tasks.empty();
129
+ });
130
+
131
+ if (stop && tasks.empty()) {
132
+ return;
133
+ }
134
+
77
135
  task = std::move(tasks.front());
78
- tasks.pop();
79
- active_workers++;
136
+ tasks.pop_front();
80
137
  }
81
-
138
+
82
139
  task();
83
-
84
- active_workers--;
85
- finish_condition.notify_all();
140
+
141
+ if (pending_tasks.fetch_sub(1, std::memory_order_acq_rel) == 1) {
142
+ std::lock_guard<std::mutex> lock(mutex);
143
+ work_done.notify_one();
144
+ }
86
145
  }
87
146
  }
88
-
147
+
89
148
  public:
90
- explicit ThreadPool(size_t num_threads = std::thread::hardware_concurrency()) {
91
- workers.reserve(num_threads);
92
- for (size_t i = 0; i < num_threads; ++i) {
149
+ explicit ThreadPool(size_t num_threads = std::thread::hardware_concurrency())
150
+ : stop(false), pending_tasks(0) {
151
+ num_workers_ = std::min(num_threads, MAX_WORKERS);
152
+ if (num_workers_ == 0) num_workers_ = 1;
153
+ workers.reserve(num_workers_);
154
+ for (size_t i = 0; i < num_workers_; ++i) {
93
155
  workers.emplace_back(&ThreadPool::worker_thread, this);
94
156
  }
95
157
  }
96
-
158
+
97
159
  ~ThreadPool() {
98
160
  {
99
- std::unique_lock<std::mutex> lock(queue_mutex);
161
+ std::lock_guard<std::mutex> lock(mutex);
100
162
  stop = true;
101
163
  }
102
- condition.notify_all();
164
+ work_available.notify_all();
103
165
  for (auto& worker : workers) {
104
- worker.join();
166
+ if (worker.joinable()) {
167
+ worker.join();
168
+ }
105
169
  }
106
170
  }
107
-
171
+
108
172
  template<typename F>
109
173
  auto enqueue(F&& f) -> std::future<decltype(f())> {
110
174
  using return_type = decltype(f());
111
-
175
+
112
176
  auto task = std::make_shared<std::packaged_task<return_type()>>(
113
177
  std::forward<F>(f)
114
178
  );
115
-
179
+
116
180
  std::future<return_type> res = task->get_future();
181
+
117
182
  {
118
- std::unique_lock<std::mutex> lock(queue_mutex);
119
- if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
120
-
121
- tasks.emplace([task](){ (*task)(); });
183
+ std::lock_guard<std::mutex> lock(mutex);
184
+ pending_tasks.fetch_add(1, std::memory_order_relaxed);
185
+ tasks.emplace_back([task](){ (*task)(); });
122
186
  }
123
- condition.notify_one();
187
+ work_available.notify_one();
188
+
124
189
  return res;
125
190
  }
126
-
191
+
192
+ template<typename F>
193
+ void enqueue_batch(size_t total_work, F task_func) {
194
+ if (total_work == 0) return;
195
+
196
+ const size_t num_tasks = std::min(num_workers_, total_work);
197
+ const size_t per_worker = total_work / num_tasks;
198
+ const size_t remainder = total_work % num_tasks;
199
+
200
+ {
201
+ std::lock_guard<std::mutex> lock(mutex);
202
+ pending_tasks.fetch_add(num_tasks, std::memory_order_relaxed);
203
+
204
+ for (size_t w = 0; w < num_tasks; ++w) {
205
+ size_t start = w * per_worker + std::min(w, remainder);
206
+ size_t end = start + per_worker + (w < remainder ? 1 : 0);
207
+ tasks.emplace_back([=]() { task_func(start, end); });
208
+ }
209
+ }
210
+ work_available.notify_all();
211
+ }
212
+
127
213
  void wait_all() {
128
- std::unique_lock<std::mutex> lock(queue_mutex);
129
- finish_condition.wait(lock, [this] {
130
- return tasks.empty() && active_workers == 0;
214
+ std::unique_lock<std::mutex> lock(mutex);
215
+ work_done.wait(lock, [this] {
216
+ return pending_tasks.load(std::memory_order_acquire) == 0;
131
217
  });
132
218
  }
133
-
134
- size_t num_workers() const { return workers.size(); }
219
+
220
+ template<typename F>
221
+ void enqueue_n_threads(size_t total_work, size_t num_threads, F task_func) {
222
+ if (total_work == 0 || num_threads == 0) return;
223
+
224
+ num_threads = std::min(num_threads, std::min(num_workers_, total_work));
225
+ const size_t per_thread = total_work / num_threads;
226
+ const size_t remainder = total_work % num_threads;
227
+
228
+ {
229
+ std::lock_guard<std::mutex> lock(mutex);
230
+ pending_tasks.fetch_add(num_threads, std::memory_order_relaxed);
231
+
232
+ for (size_t t = 0; t < num_threads; ++t) {
233
+ size_t start = t * per_thread + std::min(t, remainder);
234
+ size_t end = start + per_thread + (t < remainder ? 1 : 0);
235
+ tasks.emplace_back([=]() { task_func(start, end); });
236
+ }
237
+ }
238
+ work_available.notify_all();
239
+ }
240
+
241
+ size_t num_workers() const { return num_workers_; }
135
242
  };
136
-
243
+
137
244
  inline ThreadPool& get_thread_pool() {
138
245
  static ThreadPool pool;
139
246
  return pool;
140
247
  }
141
248
 
142
- inline size_t get_optimal_thread_count(size_t total_work, size_t min_work_per_thread) {
143
- if (total_work < min_work_per_thread) return 1;
249
+ struct ParallelConfig {
250
+ size_t min_work_gate;
251
+ size_t work_per_thread;
252
+
253
+ constexpr ParallelConfig(size_t gate, size_t per_thread)
254
+ : min_work_gate(gate), work_per_thread(per_thread) {}
255
+ };
256
+
257
+ inline size_t get_optimal_thread_count(size_t total_work, ParallelConfig config) {
258
+ if (total_work < config.min_work_gate) return 1;
259
+
144
260
  size_t pool_size = get_thread_pool().num_workers();
145
- return std::min(pool_size,
146
- std::max(static_cast<size_t>(1), total_work / min_work_per_thread));
261
+ size_t num_threads = (total_work + config.work_per_thread - 1) / config.work_per_thread;
262
+ return std::min(pool_size, std::max(static_cast<size_t>(1), num_threads));
147
263
  }
148
-
264
+
149
265
  struct Thresholds {
266
+ #if defined(__ANDROID__)
267
+ static constexpr ParallelConfig ATTENTION{64, 32};
268
+ static constexpr ParallelConfig ELEMENT_WISE{5000, 2500};
269
+ static constexpr ParallelConfig AXIS_REDUCE{1000, 500};
270
+ static constexpr ParallelConfig ALL_REDUCE{10000, 5000};
271
+ static constexpr ParallelConfig SCALAR_BASIC{30000, 15000};
272
+ static constexpr ParallelConfig SCALAR_EXPENSIVE{10000, 5000};
273
+ #else // Apple
274
+ static constexpr ParallelConfig ATTENTION{32, 16};
275
+ static constexpr ParallelConfig ELEMENT_WISE{5000, 2500};
276
+ static constexpr ParallelConfig AXIS_REDUCE{1000, 500};
277
+ static constexpr ParallelConfig ALL_REDUCE{10000, 5000};
278
+ static constexpr ParallelConfig SCALAR_BASIC{5000, 2500};
279
+ static constexpr ParallelConfig SCALAR_EXPENSIVE{2500, 1250};
280
+ #endif
281
+ };
150
282
 
283
+ struct GemmThreading {
151
284
  #if defined(__ANDROID__)
152
- static constexpr size_t ELEMENT_WISE = 5000;
153
- static constexpr size_t AXIS_REDUCE = 1000;
154
- static constexpr size_t ALL_REDUCE = 10000;
155
- static constexpr size_t SCALAR_BASIC = 30000;
156
- static constexpr size_t SCALAR_EXPENSIVE = 10000;
157
- static constexpr size_t ATTENTION = 512;
158
- static constexpr size_t GEMM_TILED = 20000;
159
- static constexpr size_t GEMM_SMALL = 64 * 64 * 64;
160
- static constexpr size_t GEMM_MEDIUM = 256 * 256 * 256;
161
- static constexpr size_t GEMM_TILE_M = 64;
162
- static constexpr size_t GEMM_TILE_N = 64;
163
- static constexpr size_t GEMM_TILE_M_SMALL = 32;
164
- static constexpr size_t GEMM_TILE_N_SMALL = 32;
165
- #else // iOS
166
- static constexpr size_t ELEMENT_WISE = 5000;
167
- static constexpr size_t AXIS_REDUCE = 1000;
168
- static constexpr size_t ALL_REDUCE = 10000;
169
- static constexpr size_t SCALAR_BASIC = 5000;
170
- static constexpr size_t SCALAR_EXPENSIVE = 2500;
171
- static constexpr size_t ATTENTION = 4;
172
- static constexpr size_t GEMM_TILED = 4;
173
- static constexpr size_t GEMM_SMALL = 64 * 64 * 64;
174
- static constexpr size_t GEMM_MEDIUM = 256 * 256 * 256;
175
- static constexpr size_t GEMM_TILE_M = 64;
176
- static constexpr size_t GEMM_TILE_N = 64;
177
- static constexpr size_t GEMM_TILE_M_SMALL = 32;
178
- static constexpr size_t GEMM_TILE_N_SMALL = 32;
285
+ static size_t get_num_threads(size_t M, size_t pool_size) {
286
+ if (M <= 1) return 1;
287
+ return pool_size;
288
+ }
289
+ static size_t get_gemv_threads(size_t /*N_blocks*/, size_t /*pool_size*/) {
290
+ return 1;
291
+ }
292
+ #elif defined(__APPLE__) && TARGET_OS_IPHONE
293
+ static constexpr size_t GEMV_MIN_N_BLOCKS = 512;
294
+ static size_t get_num_threads(size_t M, size_t pool_size) {
295
+ if (M <= 1) return std::min(pool_size, static_cast<size_t>(2));
296
+ return pool_size;
297
+ }
298
+ static size_t get_gemv_threads(size_t N_blocks, size_t pool_size) {
299
+ if (N_blocks < GEMV_MIN_N_BLOCKS) return 1;
300
+ return std::min(pool_size, static_cast<size_t>(2));
301
+ }
302
+ #else
303
+ static constexpr size_t GEMV_MIN_N_BLOCKS = 256;
304
+ static size_t get_num_threads(size_t M, size_t pool_size) {
305
+ if (M <= 1) return std::min(pool_size, static_cast<size_t>(4));
306
+ return pool_size;
307
+ }
308
+ static size_t get_gemv_threads(size_t N_blocks, size_t pool_size) {
309
+ if (N_blocks < GEMV_MIN_N_BLOCKS) return 1;
310
+ if (N_blocks < 512) return std::min(pool_size, static_cast<size_t>(2));
311
+ return std::min(pool_size, static_cast<size_t>(4));
312
+ }
179
313
  #endif
180
- static constexpr size_t L2_CACHE_SIZE = 256 * 1024;
181
314
  };
315
+
316
+ inline size_t& get_gemm_thread_override() {
317
+ static size_t override_threads = 0;
318
+ return override_threads;
319
+ }
320
+
321
+ inline void set_gemm_threads(size_t num_threads) {
322
+ get_gemm_thread_override() = num_threads;
323
+ }
324
+
325
+ inline void reset_gemm_threads() {
326
+ get_gemm_thread_override() = 0;
327
+ }
182
328
 
183
329
  class TaskHandle {
184
330
  private:
@@ -225,10 +371,10 @@ namespace CactusThreading {
225
371
  };
226
372
 
227
373
  template<typename WorkFunc>
228
- TaskHandle parallel_for(size_t total_work, size_t threshold, WorkFunc work_func, bool wait = true) {
229
- const size_t num_threads = get_optimal_thread_count(total_work, threshold);
230
- TaskHandle handle(!wait);
231
-
374
+ TaskHandle parallel_for(size_t total_work, ParallelConfig config, WorkFunc work_func, bool wait = true) {
375
+ const size_t num_threads = get_optimal_thread_count(total_work, config);
376
+ TaskHandle handle(!wait);
377
+
232
378
  if (num_threads == 1) {
233
379
  if (wait) {
234
380
  work_func(0, total_work);
@@ -240,10 +386,10 @@ namespace CactusThreading {
240
386
  }));
241
387
  return handle;
242
388
  }
243
-
389
+
244
390
  auto& pool = get_thread_pool();
245
391
  const size_t work_per_thread = total_work / num_threads;
246
-
392
+
247
393
  for (size_t t = 0; t < num_threads; ++t) {
248
394
  handle.add_future(pool.enqueue([work_func, t, num_threads, work_per_thread, total_work]() {
249
395
  const size_t start_idx = t * work_per_thread;
@@ -251,17 +397,17 @@ namespace CactusThreading {
251
397
  work_func(start_idx, end_idx);
252
398
  }));
253
399
  }
254
-
400
+
255
401
  if (wait) {
256
402
  handle.wait();
257
403
  }
258
404
  return handle;
259
405
  }
260
-
406
+
261
407
  template<typename WorkFunc>
262
- void parallel_for_2d(size_t outer_size, size_t inner_size, size_t threshold, WorkFunc work_func) {
408
+ void parallel_for_2d(size_t outer_size, size_t inner_size, ParallelConfig config, WorkFunc work_func) {
263
409
  const size_t total_work = outer_size * inner_size;
264
- parallel_for(total_work, threshold, [&](size_t start_idx, size_t end_idx) {
410
+ parallel_for(total_work, config, [&](size_t start_idx, size_t end_idx) {
265
411
  for (size_t work_idx = start_idx; work_idx < end_idx; ++work_idx) {
266
412
  const size_t outer = work_idx / inner_size;
267
413
  const size_t inner = work_idx % inner_size;
@@ -269,11 +415,11 @@ namespace CactusThreading {
269
415
  }
270
416
  });
271
417
  }
272
-
418
+
273
419
  template<typename WorkFunc, typename ResultType, typename CombineFunc>
274
- ResultType parallel_reduce(size_t total_work, size_t threshold,
420
+ ResultType parallel_reduce(size_t total_work, ParallelConfig config,
275
421
  WorkFunc work_func, ResultType init_value, CombineFunc combine_func) {
276
- const size_t num_threads = get_optimal_thread_count(total_work, threshold);
422
+ const size_t num_threads = get_optimal_thread_count(total_work, config);
277
423
 
278
424
  if (num_threads == 1) {
279
425
  return work_func(0, total_work);
@@ -298,46 +444,25 @@ namespace CactusThreading {
298
444
  }
299
445
  return result;
300
446
  }
301
-
302
- inline size_t compute_gemm_parallelism(size_t M, size_t K, size_t N, size_t element_size) {
303
- size_t total_ops = M * K * N;
304
-
305
- if (total_ops < Thresholds::GEMM_SMALL) return 1;
306
-
307
- if (total_ops < Thresholds::GEMM_MEDIUM) {
308
- return std::min(static_cast<size_t>(2), get_thread_pool().num_workers());
309
- }
310
-
311
- size_t bytes_accessed = (M * K + K * N + M * N) * element_size;
312
- size_t cache_tiles = (bytes_accessed + Thresholds::L2_CACHE_SIZE - 1) / Thresholds::L2_CACHE_SIZE;
313
-
314
- size_t compute_threads = std::sqrt(static_cast<double>(total_ops) / Thresholds::GEMM_SMALL);
315
- size_t memory_threads = cache_tiles;
316
-
317
- size_t optimal = std::min(compute_threads, memory_threads);
318
- return std::min(optimal, get_thread_pool().num_workers());
319
- }
320
-
447
+
321
448
  template<typename WorkFunc>
322
- void parallel_for_2d_tiled(size_t rows, size_t cols, size_t tile_rows, size_t tile_cols, WorkFunc work_func) {
323
- size_t num_row_tiles = (rows + tile_rows - 1) / tile_rows;
324
- size_t num_col_tiles = (cols + tile_cols - 1) / tile_cols;
325
- size_t total_tiles = num_row_tiles * num_col_tiles;
326
-
327
- parallel_for(total_tiles, Thresholds::GEMM_TILED, [=](size_t start_tile, size_t end_tile) {
328
- for (size_t tile_idx = start_tile; tile_idx < end_tile; ++tile_idx) {
329
- size_t tile_row = tile_idx / num_col_tiles;
330
- size_t tile_col = tile_idx % num_col_tiles;
331
-
332
- size_t row_start = tile_row * tile_rows;
333
- size_t row_end = std::min(row_start + tile_rows, rows);
334
- size_t col_start = tile_col * tile_cols;
335
- size_t col_end = std::min(col_start + tile_cols, cols);
336
-
337
- work_func(row_start, row_end, col_start, col_end);
338
- }
339
- });
449
+ void parallel_gemm_tiles(size_t M, size_t total_tiles, WorkFunc work_func) {
450
+ auto& pool = get_thread_pool();
451
+
452
+ size_t override = get_gemm_thread_override();
453
+ size_t num_threads = (override > 0) ? override : GemmThreading::get_num_threads(M, pool.num_workers());
454
+ num_threads = std::min(num_threads, total_tiles);
455
+
456
+ if (num_threads <= 1) {
457
+ work_func(0, total_tiles);
458
+ return;
459
+ }
460
+
461
+ pool.enqueue_n_threads(total_tiles, num_threads, work_func);
462
+ pool.wait_all();
340
463
  }
464
+
341
465
  }
342
466
 
467
+
343
468
  #endif // KERNEL_UTILS_H
@@ -7,7 +7,6 @@
7
7
  #include "engine/engine.h"
8
8
  #include "models/model.h"
9
9
  #include "ffi/cactus_ffi.h"
10
- #include "ffi/cactus_telemetry.h"
11
10
  #include "npu/npu.h"
12
11
 
13
12
  #endif // CACTUS_H