cactus-react-native 1.4.0 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cactus.podspec +1 -1
- package/README.md +465 -174
- package/android/CMakeLists.txt +24 -5
- package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
- package/cpp/HybridCactus.cpp +157 -6
- package/cpp/HybridCactus.hpp +20 -3
- package/cpp/cactus_ffi.h +65 -30
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +65 -30
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +357 -122
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +184 -63
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +153 -27
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +90 -178
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +276 -151
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +65 -30
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +357 -122
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +184 -63
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +153 -27
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +90 -178
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +276 -151
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
- package/lib/module/classes/CactusLM.js +43 -58
- package/lib/module/classes/CactusLM.js.map +1 -1
- package/lib/module/classes/CactusSTT.js +64 -38
- package/lib/module/classes/CactusSTT.js.map +1 -1
- package/lib/module/classes/CactusVAD.js +95 -0
- package/lib/module/classes/CactusVAD.js.map +1 -0
- package/lib/module/hooks/useCactusLM.js +23 -15
- package/lib/module/hooks/useCactusLM.js.map +1 -1
- package/lib/module/hooks/useCactusSTT.js +85 -28
- package/lib/module/hooks/useCactusSTT.js.map +1 -1
- package/lib/module/hooks/useCactusVAD.js +171 -0
- package/lib/module/hooks/useCactusVAD.js.map +1 -0
- package/lib/module/index.js +2 -3
- package/lib/module/index.js.map +1 -1
- package/lib/module/modelRegistry.js +52 -0
- package/lib/module/modelRegistry.js.map +1 -0
- package/lib/module/native/Cactus.js +107 -8
- package/lib/module/native/Cactus.js.map +1 -1
- package/lib/module/native/CactusIndex.js.map +1 -1
- package/lib/module/native/index.js +0 -3
- package/lib/module/native/index.js.map +1 -1
- package/lib/module/types/CactusLM.js +2 -0
- package/lib/module/types/CactusSTT.js +2 -0
- package/lib/module/types/CactusVAD.js +4 -0
- package/lib/module/types/{CactusModel.js.map → CactusVAD.js.map} +1 -1
- package/lib/module/types/common.js +2 -0
- package/lib/module/types/{CactusSTTModel.js.map → common.js.map} +1 -1
- package/lib/typescript/src/classes/CactusLM.d.ts +8 -6
- package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusSTT.d.ts +11 -6
- package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
- package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/hooks/useCactusLM.d.ts +3 -3
- package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusSTT.d.ts +11 -5
- package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
- package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/index.d.ts +7 -6
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/modelRegistry.d.ts +5 -0
- package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
- package/lib/typescript/src/native/Cactus.d.ts +12 -6
- package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
- package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/native/index.d.ts +0 -3
- package/lib/typescript/src/native/index.d.ts.map +1 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts +6 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusLM.d.ts +19 -9
- package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusSTT.d.ts +45 -4
- package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
- package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/types/common.d.ts +23 -0
- package/lib/typescript/src/types/common.d.ts.map +1 -0
- package/nitro.json +0 -11
- package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
- package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
- package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
- package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
- package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +6 -1
- package/package.json +3 -3
- package/src/classes/CactusLM.ts +59 -74
- package/src/classes/CactusSTT.ts +92 -49
- package/src/classes/CactusVAD.ts +129 -0
- package/src/hooks/useCactusLM.ts +26 -9
- package/src/hooks/useCactusSTT.ts +105 -44
- package/src/hooks/useCactusVAD.ts +215 -0
- package/src/index.tsx +20 -10
- package/src/modelRegistry.ts +65 -0
- package/src/native/Cactus.ts +130 -14
- package/src/native/CactusIndex.ts +2 -2
- package/src/native/index.ts +0 -3
- package/src/specs/Cactus.nitro.ts +11 -2
- package/src/types/CactusIndex.ts +2 -2
- package/src/types/CactusLM.ts +20 -9
- package/src/types/CactusSTT.ts +50 -4
- package/src/types/CactusVAD.ts +39 -0
- package/src/types/common.ts +23 -0
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
- package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
- package/cpp/HybridCactusUtil.cpp +0 -47
- package/cpp/HybridCactusUtil.hpp +0 -27
- package/cpp/cactus_util.h +0 -25
- package/ios/HybridCactusCrypto.swift +0 -37
- package/ios/HybridCactusDeviceInfo.swift +0 -32
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus_util.xcframework/Info.plist +0 -39
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
- package/lib/module/api/Database.js +0 -137
- package/lib/module/api/Database.js.map +0 -1
- package/lib/module/api/RemoteLM.js +0 -201
- package/lib/module/api/RemoteLM.js.map +0 -1
- package/lib/module/config/CactusConfig.js +0 -12
- package/lib/module/config/CactusConfig.js.map +0 -1
- package/lib/module/native/CactusCrypto.js +0 -10
- package/lib/module/native/CactusCrypto.js.map +0 -1
- package/lib/module/native/CactusDeviceInfo.js +0 -13
- package/lib/module/native/CactusDeviceInfo.js.map +0 -1
- package/lib/module/native/CactusUtil.js +0 -36
- package/lib/module/native/CactusUtil.js.map +0 -1
- package/lib/module/specs/CactusCrypto.nitro.js +0 -4
- package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
- package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
- package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
- package/lib/module/specs/CactusUtil.nitro.js +0 -4
- package/lib/module/specs/CactusUtil.nitro.js.map +0 -1
- package/lib/module/telemetry/Telemetry.js +0 -154
- package/lib/module/telemetry/Telemetry.js.map +0 -1
- package/lib/module/types/CactusModel.js +0 -2
- package/lib/module/types/CactusSTTModel.js +0 -2
- package/lib/typescript/src/api/Database.d.ts +0 -18
- package/lib/typescript/src/api/Database.d.ts.map +0 -1
- package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
- package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
- package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
- package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
- package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
- package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
- package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
- package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
- package/lib/typescript/src/types/CactusModel.d.ts +0 -13
- package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
- package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
- package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
- package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
- package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
- package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
- package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
- package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
- package/src/api/Database.ts +0 -188
- package/src/api/RemoteLM.ts +0 -273
- package/src/config/CactusConfig.ts +0 -11
- package/src/native/CactusCrypto.ts +0 -11
- package/src/native/CactusDeviceInfo.ts +0 -18
- package/src/native/CactusUtil.ts +0 -43
- package/src/specs/CactusCrypto.nitro.ts +0 -6
- package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
- package/src/specs/CactusUtil.nitro.ts +0 -8
- package/src/telemetry/Telemetry.ts +0 -236
- package/src/types/CactusModel.ts +0 -15
- package/src/types/CactusSTTModel.ts +0 -10
|
@@ -2,6 +2,14 @@
|
|
|
2
2
|
#define KERNEL_UTILS_H
|
|
3
3
|
|
|
4
4
|
#include <arm_neon.h>
|
|
5
|
+
#if defined(__APPLE__)
|
|
6
|
+
#include <TargetConditionals.h>
|
|
7
|
+
#include <sys/sysctl.h>
|
|
8
|
+
#endif
|
|
9
|
+
#if defined(__ANDROID__)
|
|
10
|
+
#include <sys/auxv.h>
|
|
11
|
+
#include <asm/hwcap.h>
|
|
12
|
+
#endif
|
|
5
13
|
#include <algorithm>
|
|
6
14
|
#include <cmath>
|
|
7
15
|
#include <thread>
|
|
@@ -19,166 +27,304 @@
|
|
|
19
27
|
#include <cstdio>
|
|
20
28
|
|
|
21
29
|
// Width of one 128-bit NEON Q register, in bytes.
constexpr size_t NEON_VECTOR_SIZE{16};
// 32 KiB threshold for switching to streaming (non-temporal) stores.
// NOTE(review): units presumed to be bytes — confirm at the call sites.
constexpr size_t STREAMING_STORE_THRESHOLD{32 * 1024};
|
|
22
31
|
|
|
23
|
-
inline
|
|
24
|
-
|
|
25
|
-
|
|
32
|
+
// Stores the eight fp16 lanes of `val` to dst[0..7].
// On AArch64 this uses STNP (store pair, non-temporal hint) of the two 64-bit
// halves; on other targets it is a plain vst1q_f16. The "memory" clobber
// stops the compiler from keeping *dst cached in registers across the asm.
inline void stream_store_f16x8(__fp16* dst, float16x8_t val) {
#if !defined(__aarch64__)
    vst1q_f16(dst, val);
#else
    const float16x4_t first_half = vget_low_f16(val);    // lanes 0..3 -> dst[0..3]
    const float16x4_t second_half = vget_high_f16(val);  // lanes 4..7 -> dst[4..7]
    __asm__ __volatile__(
        "stnp %d0, %d1, [%2]"
        :
        : "w"(first_half), "w"(second_half), "r"(dst)
        : "memory"
    );
#endif
}
|
|
27
46
|
|
|
28
|
-
inline int8_t clamp_to_int8(int32_t value) {
|
|
29
|
-
return static_cast<int8_t>(std::max(-128, std::min(127, value)));
|
|
30
|
-
}
|
|
31
47
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
48
|
+
// Vectorized approximation of exp(x) for four float lanes.
//
// exp(x) = 2^z with z = x * log2(e). z is split into zi = floor(z) and
// zf = z - zi in [0, 1). 2^zf comes from a degree-4 polynomial and 2^zi is
// assembled directly in the float exponent field; the product is returned.
// Inputs are clamped to [-87, 87] so 2^zi stays a normal, finite float.
// Relative error is below 1e-3 (Taylor truncation, worst as zf -> 1).
inline float32x4_t fast_exp_f32x4(float32x4_t x) {
    const float32x4_t log2e = vdupq_n_f32(1.4426950408889634f);

    // c_k = ln(2)^k / k!, so p(f) = sum_k c_k * f^k is the degree-4 Taylor
    // expansion of 2^f = exp(f * ln 2). The ln(2) factors already live in the
    // coefficients, so the polynomial must be evaluated at f itself, not at
    // f * ln 2 (which would compute 2^(f * ln 2)).
    const float32x4_t c0 = vdupq_n_f32(1.0f);
    const float32x4_t c1 = vdupq_n_f32(0.6931471805599453f);
    const float32x4_t c2 = vdupq_n_f32(0.2402265069591007f);
    const float32x4_t c3 = vdupq_n_f32(0.05550410866482158f);
    const float32x4_t c4 = vdupq_n_f32(0.009618129842071803f);

    // Keep floor(x * log2e) + 127 inside the normal exponent range [1, 254].
    x = vmaxq_f32(x, vdupq_n_f32(-87.0f));
    x = vminq_f32(x, vdupq_n_f32(87.0f));

    float32x4_t z = vmulq_f32(x, log2e);

    // vcvtq_s32_f32 truncates toward zero; lanes whose remainder is negative
    // are shifted down by one so that zi = floor(z) and zf lands in [0, 1).
    int32x4_t zi = vcvtq_s32_f32(z);
    float32x4_t zf = vsubq_f32(z, vcvtq_f32_s32(zi));

    uint32x4_t neg_mask = vcltq_f32(zf, vdupq_n_f32(0.0f));
    zi = vsubq_s32(zi, vandq_s32(vreinterpretq_s32_u32(neg_mask), vdupq_n_s32(1)));
    zf = vaddq_f32(zf, vreinterpretq_f32_u32(vandq_u32(neg_mask, vreinterpretq_u32_f32(vdupq_n_f32(1.0f)))));

    // Horner evaluation of p(zf) ~= 2^zf.
    float32x4_t p = c4;
    p = vfmaq_f32(c3, p, zf);
    p = vfmaq_f32(c2, p, zf);
    p = vfmaq_f32(c1, p, zf);
    p = vfmaq_f32(c0, p, zf);

    // 2^zi: place the biased exponent (zi + 127) into bits 23..30.
    int32x4_t exp_bits = vshlq_n_s32(vaddq_s32(zi, vdupq_n_s32(127)), 23);
    float32x4_t scale = vreinterpretq_f32_s32(exp_bits);

    return vmulq_f32(p, scale);
}
|
|
49
82
|
|
|
50
|
-
inline float32x4_t
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
83
|
+
// Vectorized approximation of tanh(x) for four float lanes using the rational
// form r(x) = x * (27 + x^2) / (27 + 9 * x^2).
//
// Since r(x) - 1 = (x - 3)^3 / (27 + 9 * x^2), r reaches exactly 1 at |x| = 3
// and overshoots 1 beyond it. Clamping the input to [-3, 3] keeps the output
// within [-1, 1] (true tanh(3) ~= 0.995). The absolute error is roughly 2.4e-2
// at worst, near |x| ~= 1.6.
inline float32x4_t fast_tanh_f32x4(float32x4_t x) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    const float32x4_t neg_one = vdupq_n_f32(-1.0f);
    const float32x4_t limit = vdupq_n_f32(3.0f);

    const float32x4_t c27 = vdupq_n_f32(27.0f);
    const float32x4_t c9 = vdupq_n_f32(9.0f);

    // Saturate at the point where the approximant itself equals +/-1.
    x = vmaxq_f32(vminq_f32(x, limit), vnegq_f32(limit));

    float32x4_t x2 = vmulq_f32(x, x);
    float32x4_t num = vaddq_f32(c27, x2);
    float32x4_t den = vfmaq_f32(c27, c9, x2);

    float32x4_t result = vmulq_f32(x, vdivq_f32(num, den));

    // Guard against a rounding ulp past +/-1 at the clamp boundary.
    result = vminq_f32(result, one);
    result = vmaxq_f32(result, neg_one);

    return result;
}
|
|
55
104
|
|
|
56
105
|
namespace CactusThreading {
|
|
57
|
-
|
|
106
|
+
|
|
58
107
|
// Fixed-size pool of worker threads (at most MAX_WORKERS, at least one)
// consuming a FIFO task queue.
//
// Submission paths:
//   * enqueue(f)      - one callable; result delivered through a std::future.
//   * enqueue_batch / enqueue_n_threads - split [0, total_work) into
//     contiguous chunks and run task_func(start, end) once per chunk;
//     completion is observed with wait_all().
//
// pending_tasks counts every queued-or-running task from any submitter, so
// wait_all() returns only once the whole pool is idle. Consequently calling
// wait_all() from inside a pool task would wait on itself.
class ThreadPool {
private:
    static constexpr size_t MAX_WORKERS = 16;

    std::vector<std::thread> workers;
    std::deque<std::function<void()>> tasks;

    std::mutex mutex;                        // guards `tasks` and `stop`
    std::condition_variable work_available;  // tasks queued, or shutdown requested
    std::condition_variable work_done;       // pending_tasks dropped to zero

    bool stop{false};
    std::atomic<size_t> pending_tasks{0};
    size_t num_workers_;

    void worker_thread() {
        while (true) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(mutex);
                work_available.wait(lock, [this] {
                    return stop || !tasks.empty();
                });

                // Queued work is drained before a stop request is honoured.
                if (stop && tasks.empty()) {
                    return;
                }

                task = std::move(tasks.front());
                tasks.pop_front();
            }

            task();

            if (pending_tasks.fetch_sub(1, std::memory_order_acq_rel) == 1) {
                std::lock_guard<std::mutex> lock(mutex);
                // notify_all: more than one thread may be blocked in wait_all();
                // waking only one would leave the others stuck until the next
                // batch happens to finish.
                work_done.notify_all();
            }
        }
    }

    // Queues `num_chunks` tasks that together cover [0, total_work). Chunk c
    // receives task_func(start, end); the first (total_work % num_chunks)
    // chunks are one element larger. Requires 0 < num_chunks <= total_work.
    template<typename F>
    void enqueue_chunks(size_t total_work, size_t num_chunks, const F& task_func) {
        const size_t per_chunk = total_work / num_chunks;
        const size_t remainder = total_work % num_chunks;

        {
            std::lock_guard<std::mutex> lock(mutex);
            for (size_t c = 0; c < num_chunks; ++c) {
                const size_t start = c * per_chunk + std::min(c, remainder);
                const size_t end = start + per_chunk + (c < remainder ? 1 : 0);
                tasks.emplace_back([=]() { task_func(start, end); });
                // Count a task only once it is really queued, so a throwing
                // emplace_back cannot leave pending_tasks too high forever.
                pending_tasks.fetch_add(1, std::memory_order_relaxed);
            }
        }
        work_available.notify_all();
    }

public:
    explicit ThreadPool(size_t num_threads = std::thread::hardware_concurrency())
        : stop(false), pending_tasks(0) {
        // hardware_concurrency() may report 0; always keep at least one worker.
        num_workers_ = std::min(num_threads, MAX_WORKERS);
        if (num_workers_ == 0) num_workers_ = 1;
        workers.reserve(num_workers_);
        for (size_t i = 0; i < num_workers_; ++i) {
            workers.emplace_back(&ThreadPool::worker_thread, this);
        }
    }

    ~ThreadPool() {
        {
            std::lock_guard<std::mutex> lock(mutex);
            stop = true;
        }
        work_available.notify_all();
        for (auto& worker : workers) {
            if (worker.joinable()) {
                worker.join();
            }
        }
    }

    // Queues a single callable and returns a future for its result (or the
    // exception it throws, via std::packaged_task).
    template<typename F>
    auto enqueue(F&& f) -> std::future<decltype(f())> {
        using return_type = decltype(f());

        auto task = std::make_shared<std::packaged_task<return_type()>>(
            std::forward<F>(f)
        );

        std::future<return_type> res = task->get_future();

        {
            std::lock_guard<std::mutex> lock(mutex);
            pending_tasks.fetch_add(1, std::memory_order_relaxed);
            tasks.emplace_back([task](){ (*task)(); });
        }
        work_available.notify_one();

        return res;
    }

    // Splits [0, total_work) into min(num_workers, total_work) chunks.
    template<typename F>
    void enqueue_batch(size_t total_work, F task_func) {
        if (total_work == 0) return;
        enqueue_chunks(total_work, std::min(num_workers_, total_work), task_func);
    }

    // Blocks until every queued or running task in the pool has finished.
    void wait_all() {
        std::unique_lock<std::mutex> lock(mutex);
        work_done.wait(lock, [this] {
            return pending_tasks.load(std::memory_order_acquire) == 0;
        });
    }

    // Splits [0, total_work) into min(num_threads, num_workers, total_work) chunks.
    template<typename F>
    void enqueue_n_threads(size_t total_work, size_t num_threads, F task_func) {
        if (total_work == 0 || num_threads == 0) return;
        num_threads = std::min(num_threads, std::min(num_workers_, total_work));
        enqueue_chunks(total_work, num_threads, task_func);
    }

    size_t num_workers() const { return num_workers_; }
};
|
|
136
|
-
|
|
243
|
+
|
|
137
244
|
// Returns the process-wide ThreadPool shared by the parallel helpers.
// The pool is a function-local static: it is built on first use (C++11
// guarantees thread-safe initialization) using ThreadPool's default
// constructor argument, std::thread::hardware_concurrency().
inline ThreadPool& get_thread_pool() {
    static ThreadPool shared_pool{};
    return shared_pool;
}
|
|
141
248
|
|
|
142
|
-
|
|
143
|
-
|
|
249
|
+
// Tuning pair consumed by get_optimal_thread_count():
//   min_work_gate   - below this many work items, run on a single thread;
//   work_per_thread - approximate number of items to hand each thread.
struct ParallelConfig {
    size_t min_work_gate;
    size_t work_per_thread;

    constexpr ParallelConfig(size_t gate, size_t per_thread)
        : min_work_gate{gate}, work_per_thread{per_thread} {}
};
|
|
256
|
+
|
|
257
|
+
inline size_t get_optimal_thread_count(size_t total_work, ParallelConfig config) {
|
|
258
|
+
if (total_work < config.min_work_gate) return 1;
|
|
259
|
+
|
|
144
260
|
size_t pool_size = get_thread_pool().num_workers();
|
|
145
|
-
|
|
146
|
-
|
|
261
|
+
size_t num_threads = (total_work + config.work_per_thread - 1) / config.work_per_thread;
|
|
262
|
+
return std::min(pool_size, std::max(static_cast<size_t>(1), num_threads));
|
|
147
263
|
}
|
|
148
|
-
|
|
264
|
+
|
|
149
265
|
struct Thresholds {
|
|
266
|
+
#if defined(__ANDROID__)
|
|
267
|
+
static constexpr ParallelConfig ATTENTION{64, 32};
|
|
268
|
+
static constexpr ParallelConfig ELEMENT_WISE{5000, 2500};
|
|
269
|
+
static constexpr ParallelConfig AXIS_REDUCE{1000, 500};
|
|
270
|
+
static constexpr ParallelConfig ALL_REDUCE{10000, 5000};
|
|
271
|
+
static constexpr ParallelConfig SCALAR_BASIC{30000, 15000};
|
|
272
|
+
static constexpr ParallelConfig SCALAR_EXPENSIVE{10000, 5000};
|
|
273
|
+
#else // Apple
|
|
274
|
+
static constexpr ParallelConfig ATTENTION{32, 16};
|
|
275
|
+
static constexpr ParallelConfig ELEMENT_WISE{5000, 2500};
|
|
276
|
+
static constexpr ParallelConfig AXIS_REDUCE{1000, 500};
|
|
277
|
+
static constexpr ParallelConfig ALL_REDUCE{10000, 5000};
|
|
278
|
+
static constexpr ParallelConfig SCALAR_BASIC{5000, 2500};
|
|
279
|
+
static constexpr ParallelConfig SCALAR_EXPENSIVE{2500, 1250};
|
|
280
|
+
#endif
|
|
281
|
+
};
|
|
150
282
|
|
|
283
|
+
// Platform-specific thread-count policy for matrix kernels.
//   get_num_threads(M, pool_size)         - threads for a GEMM with M rows.
//   get_gemv_threads(N_blocks, pool_size) - threads for a GEMV over N_blocks.
// Android: single-row GEMMs and every GEMV run on one thread.
// iOS:     single-row GEMMs use at most 2 threads; GEMV uses up to 2 threads
//          once N_blocks reaches 512.
// Others:  single-row GEMMs use at most 4 threads; GEMV uses 1, 2 or 4
//          threads with breakpoints at 256 and 512 blocks.
struct GemmThreading {
#if defined(__ANDROID__)
    static size_t get_num_threads(size_t M, size_t pool_size) {
        return (M > 1) ? pool_size : static_cast<size_t>(1);
    }
    static size_t get_gemv_threads(size_t /*N_blocks*/, size_t /*pool_size*/) {
        return 1;
    }
#elif defined(__APPLE__) && TARGET_OS_IPHONE
    static constexpr size_t GEMV_MIN_N_BLOCKS = 512;
    static size_t get_num_threads(size_t M, size_t pool_size) {
        if (M > 1) return pool_size;
        return std::min(pool_size, static_cast<size_t>(2));
    }
    static size_t get_gemv_threads(size_t N_blocks, size_t pool_size) {
        const bool large_enough = N_blocks >= GEMV_MIN_N_BLOCKS;
        return large_enough ? std::min(pool_size, static_cast<size_t>(2)) : static_cast<size_t>(1);
    }
#else
    static constexpr size_t GEMV_MIN_N_BLOCKS = 256;
    static size_t get_num_threads(size_t M, size_t pool_size) {
        if (M > 1) return pool_size;
        return std::min(pool_size, static_cast<size_t>(4));
    }
    static size_t get_gemv_threads(size_t N_blocks, size_t pool_size) {
        if (N_blocks < GEMV_MIN_N_BLOCKS) return 1;
        const size_t cap = (N_blocks < 512) ? static_cast<size_t>(2) : static_cast<size_t>(4);
        return std::min(pool_size, cap);
    }
#endif
};
|
|
315
|
+
|
|
316
|
+
// Process-wide GEMM thread-count override. A value of 0 means "no override":
// parallel_gemm_tiles then falls back to GemmThreading's policy. It is
// returned by reference so callers can read or assign it.
// NOTE(review): this is a plain size_t, not an atomic — changing it while
// GEMMs run on other threads is a data race.
inline size_t& get_gemm_thread_override() {
    static size_t forced_threads = 0;
    return forced_threads;
}

// Forces subsequent GEMM dispatches to use `num_threads` threads (0 clears).
inline void set_gemm_threads(size_t num_threads) {
    size_t& slot = get_gemm_thread_override();
    slot = num_threads;
}

// Clears any override installed by set_gemm_threads().
inline void reset_gemm_threads() {
    set_gemm_threads(0);
}
|
|
182
328
|
|
|
183
329
|
class TaskHandle {
|
|
184
330
|
private:
|
|
@@ -225,10 +371,10 @@ namespace CactusThreading {
|
|
|
225
371
|
};
|
|
226
372
|
|
|
227
373
|
template<typename WorkFunc>
|
|
228
|
-
TaskHandle parallel_for(size_t total_work,
|
|
229
|
-
const size_t num_threads = get_optimal_thread_count(total_work,
|
|
230
|
-
TaskHandle handle(!wait);
|
|
231
|
-
|
|
374
|
+
TaskHandle parallel_for(size_t total_work, ParallelConfig config, WorkFunc work_func, bool wait = true) {
|
|
375
|
+
const size_t num_threads = get_optimal_thread_count(total_work, config);
|
|
376
|
+
TaskHandle handle(!wait);
|
|
377
|
+
|
|
232
378
|
if (num_threads == 1) {
|
|
233
379
|
if (wait) {
|
|
234
380
|
work_func(0, total_work);
|
|
@@ -240,10 +386,10 @@ namespace CactusThreading {
|
|
|
240
386
|
}));
|
|
241
387
|
return handle;
|
|
242
388
|
}
|
|
243
|
-
|
|
389
|
+
|
|
244
390
|
auto& pool = get_thread_pool();
|
|
245
391
|
const size_t work_per_thread = total_work / num_threads;
|
|
246
|
-
|
|
392
|
+
|
|
247
393
|
for (size_t t = 0; t < num_threads; ++t) {
|
|
248
394
|
handle.add_future(pool.enqueue([work_func, t, num_threads, work_per_thread, total_work]() {
|
|
249
395
|
const size_t start_idx = t * work_per_thread;
|
|
@@ -251,17 +397,17 @@ namespace CactusThreading {
|
|
|
251
397
|
work_func(start_idx, end_idx);
|
|
252
398
|
}));
|
|
253
399
|
}
|
|
254
|
-
|
|
400
|
+
|
|
255
401
|
if (wait) {
|
|
256
402
|
handle.wait();
|
|
257
403
|
}
|
|
258
404
|
return handle;
|
|
259
405
|
}
|
|
260
|
-
|
|
406
|
+
|
|
261
407
|
template<typename WorkFunc>
|
|
262
|
-
void parallel_for_2d(size_t outer_size, size_t inner_size,
|
|
408
|
+
void parallel_for_2d(size_t outer_size, size_t inner_size, ParallelConfig config, WorkFunc work_func) {
|
|
263
409
|
const size_t total_work = outer_size * inner_size;
|
|
264
|
-
parallel_for(total_work,
|
|
410
|
+
parallel_for(total_work, config, [&](size_t start_idx, size_t end_idx) {
|
|
265
411
|
for (size_t work_idx = start_idx; work_idx < end_idx; ++work_idx) {
|
|
266
412
|
const size_t outer = work_idx / inner_size;
|
|
267
413
|
const size_t inner = work_idx % inner_size;
|
|
@@ -269,11 +415,11 @@ namespace CactusThreading {
|
|
|
269
415
|
}
|
|
270
416
|
});
|
|
271
417
|
}
|
|
272
|
-
|
|
418
|
+
|
|
273
419
|
template<typename WorkFunc, typename ResultType, typename CombineFunc>
|
|
274
|
-
ResultType parallel_reduce(size_t total_work,
|
|
420
|
+
ResultType parallel_reduce(size_t total_work, ParallelConfig config,
|
|
275
421
|
WorkFunc work_func, ResultType init_value, CombineFunc combine_func) {
|
|
276
|
-
const size_t num_threads = get_optimal_thread_count(total_work,
|
|
422
|
+
const size_t num_threads = get_optimal_thread_count(total_work, config);
|
|
277
423
|
|
|
278
424
|
if (num_threads == 1) {
|
|
279
425
|
return work_func(0, total_work);
|
|
@@ -298,46 +444,25 @@ namespace CactusThreading {
|
|
|
298
444
|
}
|
|
299
445
|
return result;
|
|
300
446
|
}
|
|
301
|
-
|
|
302
|
-
inline size_t compute_gemm_parallelism(size_t M, size_t K, size_t N, size_t element_size) {
|
|
303
|
-
size_t total_ops = M * K * N;
|
|
304
|
-
|
|
305
|
-
if (total_ops < Thresholds::GEMM_SMALL) return 1;
|
|
306
|
-
|
|
307
|
-
if (total_ops < Thresholds::GEMM_MEDIUM) {
|
|
308
|
-
return std::min(static_cast<size_t>(2), get_thread_pool().num_workers());
|
|
309
|
-
}
|
|
310
|
-
|
|
311
|
-
size_t bytes_accessed = (M * K + K * N + M * N) * element_size;
|
|
312
|
-
size_t cache_tiles = (bytes_accessed + Thresholds::L2_CACHE_SIZE - 1) / Thresholds::L2_CACHE_SIZE;
|
|
313
|
-
|
|
314
|
-
size_t compute_threads = std::sqrt(static_cast<double>(total_ops) / Thresholds::GEMM_SMALL);
|
|
315
|
-
size_t memory_threads = cache_tiles;
|
|
316
|
-
|
|
317
|
-
size_t optimal = std::min(compute_threads, memory_threads);
|
|
318
|
-
return std::min(optimal, get_thread_pool().num_workers());
|
|
319
|
-
}
|
|
320
|
-
|
|
447
|
+
|
|
321
448
|
template<typename WorkFunc>
|
|
322
|
-
void
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
size_t
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
work_func(row_start, row_end, col_start, col_end);
|
|
338
|
-
}
|
|
339
|
-
});
|
|
449
|
+
void parallel_gemm_tiles(size_t M, size_t total_tiles, WorkFunc work_func) {
|
|
450
|
+
auto& pool = get_thread_pool();
|
|
451
|
+
|
|
452
|
+
size_t override = get_gemm_thread_override();
|
|
453
|
+
size_t num_threads = (override > 0) ? override : GemmThreading::get_num_threads(M, pool.num_workers());
|
|
454
|
+
num_threads = std::min(num_threads, total_tiles);
|
|
455
|
+
|
|
456
|
+
if (num_threads <= 1) {
|
|
457
|
+
work_func(0, total_tiles);
|
|
458
|
+
return;
|
|
459
|
+
}
|
|
460
|
+
|
|
461
|
+
pool.enqueue_n_threads(total_tiles, num_threads, work_func);
|
|
462
|
+
pool.wait_all();
|
|
340
463
|
}
|
|
464
|
+
|
|
341
465
|
}
|
|
342
466
|
|
|
467
|
+
|
|
343
468
|
#endif // KERNEL_UTILS_H
|
|
Binary file
|