cactus-react-native 1.4.0 → 1.5.0
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- package/README.md +212 -27
- package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
- package/cpp/HybridCactus.cpp +119 -0
- package/cpp/HybridCactus.hpp +13 -0
- package/cpp/cactus_ffi.h +24 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +24 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +41 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +66 -48
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +102 -21
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +45 -195
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +399 -140
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +24 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +41 -1
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +66 -48
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +102 -21
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +45 -195
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +399 -140
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
- package/lib/module/api/Database.js +0 -92
- package/lib/module/api/Database.js.map +1 -1
- package/lib/module/classes/CactusLM.js +33 -15
- package/lib/module/classes/CactusLM.js.map +1 -1
- package/lib/module/classes/CactusSTT.js +90 -15
- package/lib/module/classes/CactusSTT.js.map +1 -1
- package/lib/module/hooks/useCactusLM.js +14 -5
- package/lib/module/hooks/useCactusLM.js.map +1 -1
- package/lib/module/hooks/useCactusSTT.js +100 -4
- package/lib/module/hooks/useCactusSTT.js.map +1 -1
- package/lib/module/index.js.map +1 -1
- package/lib/module/models.js +336 -0
- package/lib/module/models.js.map +1 -0
- package/lib/module/native/Cactus.js +37 -0
- package/lib/module/native/Cactus.js.map +1 -1
- package/lib/module/types/CactusLM.js +2 -0
- package/lib/module/types/CactusSTT.js +2 -0
- package/lib/module/types/common.js +2 -0
- package/lib/module/types/{CactusModel.js.map → common.js.map} +1 -1
- package/lib/typescript/src/api/Database.d.ts +0 -6
- package/lib/typescript/src/api/Database.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusLM.d.ts +7 -3
- package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusSTT.d.ts +13 -4
- package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusLM.d.ts +2 -2
- package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusSTT.d.ts +12 -4
- package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/index.d.ts +2 -3
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/models.d.ts +6 -0
- package/lib/typescript/src/models.d.ts.map +1 -0
- package/lib/typescript/src/native/Cactus.d.ts +6 -1
- package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts +5 -0
- package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusLM.d.ts +2 -0
- package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusSTT.d.ts +20 -0
- package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/types/common.d.ts +28 -0
- package/lib/typescript/src/types/common.d.ts.map +1 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +5 -0
- package/package.json +1 -1
- package/src/api/Database.ts +0 -133
- package/src/classes/CactusLM.ts +49 -17
- package/src/classes/CactusSTT.ts +118 -17
- package/src/hooks/useCactusLM.ts +25 -5
- package/src/hooks/useCactusSTT.ts +117 -5
- package/src/index.tsx +6 -2
- package/src/models.ts +344 -0
- package/src/native/Cactus.ts +55 -0
- package/src/specs/Cactus.nitro.ts +5 -0
- package/src/types/CactusLM.ts +3 -0
- package/src/types/CactusSTT.ts +26 -0
- package/src/types/common.ts +28 -0
- package/lib/module/types/CactusModel.js +0 -2
- package/lib/module/types/CactusSTTModel.js +0 -2
- package/lib/module/types/CactusSTTModel.js.map +0 -1
- package/lib/typescript/src/types/CactusModel.d.ts +0 -13
- package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
- package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
- package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
- package/src/types/CactusModel.ts +0 -15
- package/src/types/CactusSTTModel.ts +0 -10
cactus.framework/Headers/kernel_utils.h

@@ -2,6 +2,13 @@
 #define KERNEL_UTILS_H
 
 #include <arm_neon.h>
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#endif
+#if defined(__ANDROID__)
+#include <sys/auxv.h>
+#include <asm/hwcap.h>
+#endif
 #include <algorithm>
 #include <cmath>
 #include <thread>
@@ -19,166 +26,439 @@
 #include <cstdio>
 
 constexpr size_t NEON_VECTOR_SIZE = 16;
+constexpr size_t STREAMING_STORE_THRESHOLD = 32768;
 
-inline
-
-
-
-
-
+inline void stream_store_f16x8(__fp16* dst, float16x8_t val) {
+#if defined(__aarch64__)
+    float16x4_t lo = vget_low_f16(val);
+    float16x4_t hi = vget_high_f16(val);
+    __asm__ __volatile__(
+        "stnp %d0, %d1, [%2]"
+        :
+        : "w"(lo), "w"(hi), "r"(dst)
+        : "memory"
+    );
+#else
+    vst1q_f16(dst, val);
+#endif
 }
 
 #if defined(__ARM_FEATURE_DOTPROD)
-inline int32x4_t
+inline int32x4_t accum_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
     return vdotq_s32(acc, a, b);
 }
 #else
-inline int32x4_t
+inline int32x4_t accum_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
     int16x8_t prod_low = vmull_s8(vget_low_s8(a), vget_low_s8(b));
     int32x4_t acc_high = vpaddlq_s16(vmull_s8(vget_high_s8(a), vget_high_s8(b)));
     return vaddq_s32(vaddq_s32(acc, vpaddlq_s16(prod_low)), acc_high);
 }
 #endif
 
-
+// I8MM support: runtime detection on Android, compile-time on Apple
+#if defined(__ANDROID__) && defined(__aarch64__)
+
+inline bool cactus_has_i8mm() {
+    static int8_t supported = -1;
+    if (supported == -1) {
+        unsigned long hwcaps = getauxval(AT_HWCAP2);
+        supported = (hwcaps & HWCAP2_I8MM) ? 1 : 0;
+    }
+    return supported;
+}
+
+__attribute__((target("arch=armv8.2-a+i8mm")))
+inline int32x4_t accum_matmul(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    return vmmlaq_s32(acc, a, b);
+}
+
+#elif defined(__APPLE__) && defined(__aarch64__)
+
+inline bool cactus_has_i8mm() {
+    return true;
+}
+
+__attribute__((target("i8mm")))
+inline int32x4_t accum_matmul(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    return vmmlaq_s32(acc, a, b);
+}
+
+#else
+
+inline bool cactus_has_i8mm() {
+    return false;
+}
+
+#endif
+
+inline float16x8_t accum_f16_dot(float16x8_t acc, float16x8_t a_low, float16x8_t a_high,
                                  float16x8_t b_low, float16x8_t b_high) {
     acc = vfmaq_f16(acc, a_low, b_low);
     return vfmaq_f16(acc, a_high, b_high);
 }
 
-inline float32x4_t
-
-
-
+inline float32x4_t fast_exp_f32x4(float32x4_t x) {
+    const float32x4_t log2e = vdupq_n_f32(1.4426950408889634f);
+    const float32x4_t ln2 = vdupq_n_f32(0.6931471805599453f);
+
+    const float32x4_t c0 = vdupq_n_f32(1.0f);
+    const float32x4_t c1 = vdupq_n_f32(0.6931471805599453f);
+    const float32x4_t c2 = vdupq_n_f32(0.2402265069591007f);
+    const float32x4_t c3 = vdupq_n_f32(0.05550410866482158f);
+    const float32x4_t c4 = vdupq_n_f32(0.009618129842071803f);
+
+    x = vmaxq_f32(x, vdupq_n_f32(-87.0f));
+    x = vminq_f32(x, vdupq_n_f32(87.0f));
+
+    float32x4_t z = vmulq_f32(x, log2e);
+
+    int32x4_t zi = vcvtq_s32_f32(z);
+    float32x4_t zf = vsubq_f32(z, vcvtq_f32_s32(zi));
+
+    uint32x4_t neg_mask = vcltq_f32(zf, vdupq_n_f32(0.0f));
+    zi = vsubq_s32(zi, vandq_s32(vreinterpretq_s32_u32(neg_mask), vdupq_n_s32(1)));
+    zf = vaddq_f32(zf, vreinterpretq_f32_u32(vandq_u32(neg_mask, vreinterpretq_u32_f32(vdupq_n_f32(1.0f)))));
+
+    float32x4_t zf_ln2 = vmulq_f32(zf, ln2);
+    float32x4_t p = c4;
+    p = vfmaq_f32(c3, p, zf_ln2);
+    p = vfmaq_f32(c2, p, zf_ln2);
+    p = vfmaq_f32(c1, p, zf_ln2);
+    p = vfmaq_f32(c0, p, zf_ln2);
+
+    int32x4_t exp_bits = vshlq_n_s32(vaddq_s32(zi, vdupq_n_s32(127)), 23);
+    float32x4_t scale = vreinterpretq_f32_s32(exp_bits);
+
+    return vmulq_f32(p, scale);
+}
+
+inline float32x4_t fast_tanh_f32x4(float32x4_t x) {
+    const float32x4_t one = vdupq_n_f32(1.0f);
+    const float32x4_t neg_one = vdupq_n_f32(-1.0f);
+
+    uint32x4_t pos_sat = vcgtq_f32(x, vdupq_n_f32(4.5f));
+    uint32x4_t neg_sat = vcltq_f32(x, vdupq_n_f32(-4.5f));
+
+    const float32x4_t c27 = vdupq_n_f32(27.0f);
+    const float32x4_t c9 = vdupq_n_f32(9.0f);
+
+    float32x4_t x2 = vmulq_f32(x, x);
+    float32x4_t num = vaddq_f32(c27, x2);
+    float32x4_t den = vfmaq_f32(c27, c9, x2);
+
+    float32x4_t result = vmulq_f32(x, vdivq_f32(num, den));
+
+    result = vbslq_f32(pos_sat, one, result);
+    result = vbslq_f32(neg_sat, neg_one, result);
+
+    return result;
+}
+
+inline int8x16_t unpack_int4_lo(uint8x16_t packed) {
+    uint8x16_t lo = vandq_u8(packed, vdupq_n_u8(0x0F));
+    uint8x16_t sign_mask = vcgtq_u8(lo, vdupq_n_u8(7));
+    uint8x16_t correction = vandq_u8(sign_mask, vdupq_n_u8(16));
+    return vreinterpretq_s8_u8(vsubq_u8(lo, correction));
+}
+
+inline int8x16_t unpack_int4_hi(uint8x16_t packed) {
+    uint8x16_t hi = vshrq_n_u8(packed, 4);
+    uint8x16_t sign_mask = vcgtq_u8(hi, vdupq_n_u8(7));
+    uint8x16_t correction = vandq_u8(sign_mask, vdupq_n_u8(16));
+    return vreinterpretq_s8_u8(vsubq_u8(hi, correction));
+}
+
+inline void unpack_int4_to_int8x32(uint8x16_t packed, int8x16_t& out_lo, int8x16_t& out_hi) {
+    int8x16_t lo_nibbles = unpack_int4_lo(packed);
+    int8x16_t hi_nibbles = unpack_int4_hi(packed);
+    int8x16x2_t interleaved = vzipq_s8(lo_nibbles, hi_nibbles);
+    out_lo = interleaved.val[0];
+    out_hi = interleaved.val[1];
+}
+
+inline int32x4_t int4_dot_asm(int32x4_t acc, uint8x16_t packed, int8x16_t a_lo, int8x16_t a_hi) {
+#if defined(__aarch64__)
+    int8x16_t b_lo, b_hi;
+
+    __asm__ __volatile__ (
+        "movi v16.16b, #0x0F \n"   // low nibble mask
+        "movi v17.16b, #7 \n"      // sign threshold
+        "movi v18.16b, #16 \n"     // sign correction
+
+        "and %[b_lo].16b, %[packed].16b, v16.16b \n"
+
+        "ushr %[b_hi].16b, %[packed].16b, #4 \n"
+
+        "cmgt v19.16b, %[b_lo].16b, v17.16b \n"
+        "and v19.16b, v19.16b, v18.16b \n"
+        "sub %[b_lo].16b, %[b_lo].16b, v19.16b \n"
+
+        "cmgt v20.16b, %[b_hi].16b, v17.16b \n"
+        "and v20.16b, v20.16b, v18.16b \n"
+        "sub %[b_hi].16b, %[b_hi].16b, v20.16b \n"
+
+        "zip1 v21.16b, %[b_lo].16b, %[b_hi].16b \n"
+        "zip2 v22.16b, %[b_lo].16b, %[b_hi].16b \n"
+
+        ".arch armv8.2-a+dotprod \n"
+        "sdot %[acc].4s, %[a_lo].16b, v21.16b \n"
+        "sdot %[acc].4s, %[a_hi].16b, v22.16b \n"
+
+        : [acc] "+w"(acc), [b_lo] "=w"(b_lo), [b_hi] "=w"(b_hi)
+        : [packed] "w"(packed), [a_lo] "w"(a_lo), [a_hi] "w"(a_hi)
+        : "v16", "v17", "v18", "v19", "v20", "v21", "v22"
+    );
+
+    return acc;
+#else
+    int8x16_t b_lo, b_hi;
+    unpack_int4_to_int8x32(packed, b_lo, b_hi);
+    acc = accum_dot(acc, a_lo, b_lo);
+    acc = accum_dot(acc, a_hi, b_hi);
+    return acc;
+#endif
+}
+
+inline int32_t int4_dot_m1_asm(const int8_t* a_ptr, const uint8_t* b_packed, size_t group_size) {
+#if defined(__aarch64__)
+    int32x4_t acc = vdupq_n_s32(0);
+
+    for (size_t k = 0; k < group_size; k += 64) {
+        uint8x16_t p0 = vld1q_u8(b_packed + k/2);
+        uint8x16_t p1 = vld1q_u8(b_packed + k/2 + 16);
+
+        int8x16_t a0 = vld1q_s8(a_ptr + k);
+        int8x16_t a1 = vld1q_s8(a_ptr + k + 16);
+        int8x16_t a2 = vld1q_s8(a_ptr + k + 32);
+        int8x16_t a3 = vld1q_s8(a_ptr + k + 48);
+
+        acc = int4_dot_asm(acc, p0, a0, a1);
+        acc = int4_dot_asm(acc, p1, a2, a3);
+    }
+
+    return vaddvq_s32(acc);
+#else
+    int32x4_t acc = vdupq_n_s32(0);
+    for (size_t k = 0; k < group_size; k += 32) {
+        uint8x16_t packed = vld1q_u8(b_packed + k/2);
+        int8x16_t b_lo, b_hi;
+        unpack_int4_to_int8x32(packed, b_lo, b_hi);
+        acc = accum_dot(acc, vld1q_s8(a_ptr + k), b_lo);
+        acc = accum_dot(acc, vld1q_s8(a_ptr + k + 16), b_hi);
+    }
+    return vaddvq_s32(acc);
+#endif
 }
 
 namespace CactusThreading {
-
+
 class ThreadPool {
 private:
+    static constexpr size_t MAX_WORKERS = 16;
+
     std::vector<std::thread> workers;
-    std::
-
-    std::
-    std::
-    std::
-
-
+    std::deque<std::function<void()>> tasks;
+
+    std::mutex mutex;
+    std::condition_variable work_available;
+    std::condition_variable work_done;
+
+    bool stop{false};
+    std::atomic<size_t> pending_tasks{0};
+    size_t num_workers_;
+
     void worker_thread() {
         while (true) {
             std::function<void()> task;
             {
-                std::unique_lock<std::mutex> lock(
-
-
-
-
+                std::unique_lock<std::mutex> lock(mutex);
+                work_available.wait(lock, [this] {
+                    return stop || !tasks.empty();
+                });
+
+                if (stop && tasks.empty()) {
+                    return;
+                }
+
                 task = std::move(tasks.front());
-                tasks.
-                active_workers++;
+                tasks.pop_front();
             }
-
+
             task();
-
-
-
+
+            if (pending_tasks.fetch_sub(1, std::memory_order_acq_rel) == 1) {
+                std::lock_guard<std::mutex> lock(mutex);
+                work_done.notify_one();
+            }
         }
     }
-
+
 public:
-    explicit ThreadPool(size_t num_threads = std::thread::hardware_concurrency())
-
-
+    explicit ThreadPool(size_t num_threads = std::thread::hardware_concurrency())
+        : stop(false), pending_tasks(0) {
+        num_workers_ = std::min(num_threads, MAX_WORKERS);
+        if (num_workers_ == 0) num_workers_ = 1;
+        workers.reserve(num_workers_);
+        for (size_t i = 0; i < num_workers_; ++i) {
            workers.emplace_back(&ThreadPool::worker_thread, this);
        }
    }
-
+
    ~ThreadPool() {
        {
-            std::
+            std::lock_guard<std::mutex> lock(mutex);
            stop = true;
        }
-
+        work_available.notify_all();
        for (auto& worker : workers) {
-            worker.
+            if (worker.joinable()) {
+                worker.join();
+            }
        }
    }
-
+
    template<typename F>
    auto enqueue(F&& f) -> std::future<decltype(f())> {
        using return_type = decltype(f());
-
+
        auto task = std::make_shared<std::packaged_task<return_type()>>(
            std::forward<F>(f)
        );
-
+
        std::future<return_type> res = task->get_future();
+
        {
-            std::
-
-
-            tasks.emplace([task](){ (*task)(); });
+            std::lock_guard<std::mutex> lock(mutex);
+            pending_tasks.fetch_add(1, std::memory_order_relaxed);
+            tasks.emplace_back([task](){ (*task)(); });
        }
-
+        work_available.notify_one();
+
        return res;
    }
-
+
+    template<typename F>
+    void enqueue_batch(size_t total_work, F task_func) {
+        if (total_work == 0) return;
+
+        const size_t num_tasks = std::min(num_workers_, total_work);
+        const size_t per_worker = total_work / num_tasks;
+        const size_t remainder = total_work % num_tasks;
+
+        {
+            std::lock_guard<std::mutex> lock(mutex);
+            pending_tasks.fetch_add(num_tasks, std::memory_order_relaxed);
+
+            for (size_t w = 0; w < num_tasks; ++w) {
+                size_t start = w * per_worker + std::min(w, remainder);
+                size_t end = start + per_worker + (w < remainder ? 1 : 0);
+                tasks.emplace_back([=]() { task_func(start, end); });
+            }
+        }
+        work_available.notify_all();
+    }
+
    void wait_all() {
-        std::unique_lock<std::mutex> lock(
-
-        return
+        std::unique_lock<std::mutex> lock(mutex);
+        work_done.wait(lock, [this] {
+            return pending_tasks.load(std::memory_order_acquire) == 0;
        });
    }
-
-
+
+    template<typename F>
+    void enqueue_n_threads(size_t total_work, size_t num_threads, F task_func) {
+        if (total_work == 0 || num_threads == 0) return;
+
+        num_threads = std::min(num_threads, std::min(num_workers_, total_work));
+        const size_t per_thread = total_work / num_threads;
+        const size_t remainder = total_work % num_threads;
+
+        {
+            std::lock_guard<std::mutex> lock(mutex);
+            pending_tasks.fetch_add(num_threads, std::memory_order_relaxed);
+
+            for (size_t t = 0; t < num_threads; ++t) {
+                size_t start = t * per_thread + std::min(t, remainder);
+                size_t end = start + per_thread + (t < remainder ? 1 : 0);
+                tasks.emplace_back([=]() { task_func(start, end); });
+            }
+        }
+        work_available.notify_all();
+    }
+
+    size_t num_workers() const { return num_workers_; }
 };
-
+
 inline ThreadPool& get_thread_pool() {
     static ThreadPool pool;
     return pool;
 }
 
-
-
+struct ParallelConfig {
+    size_t min_work_gate;
+    size_t work_per_thread;
+
+    constexpr ParallelConfig(size_t gate, size_t per_thread)
+        : min_work_gate(gate), work_per_thread(per_thread) {}
+};
+
+inline size_t get_optimal_thread_count(size_t total_work, ParallelConfig config) {
+    if (total_work < config.min_work_gate) return 1;
+
     size_t pool_size = get_thread_pool().num_workers();
-
-
+    size_t num_threads = (total_work + config.work_per_thread - 1) / config.work_per_thread;
+    return std::min(pool_size, std::max(static_cast<size_t>(1), num_threads));
 }
-
+
 struct Thresholds {
+#if defined(__ANDROID__)
+    static constexpr ParallelConfig ATTENTION{64, 32};
+    static constexpr ParallelConfig ELEMENT_WISE{5000, 2500};
+    static constexpr ParallelConfig AXIS_REDUCE{1000, 500};
+    static constexpr ParallelConfig ALL_REDUCE{10000, 5000};
+    static constexpr ParallelConfig SCALAR_BASIC{30000, 15000};
+    static constexpr ParallelConfig SCALAR_EXPENSIVE{10000, 5000};
+#else // Apple
+    static constexpr ParallelConfig ATTENTION{32, 16};
+    static constexpr ParallelConfig ELEMENT_WISE{5000, 2500};
+    static constexpr ParallelConfig AXIS_REDUCE{1000, 500};
+    static constexpr ParallelConfig ALL_REDUCE{10000, 5000};
+    static constexpr ParallelConfig SCALAR_BASIC{5000, 2500};
+    static constexpr ParallelConfig SCALAR_EXPENSIVE{2500, 1250};
+#endif
+};
 
+struct GemmThreading {
 #if defined(__ANDROID__)
-    static
-
-
-
-
-    static
-
-
-
-
-    static
-
-
-
-    static constexpr size_t ELEMENT_WISE = 5000;
-    static constexpr size_t AXIS_REDUCE = 1000;
-    static constexpr size_t ALL_REDUCE = 10000;
-    static constexpr size_t SCALAR_BASIC = 5000;
-    static constexpr size_t SCALAR_EXPENSIVE = 2500;
-    static constexpr size_t ATTENTION = 4;
-    static constexpr size_t GEMM_TILED = 4;
-    static constexpr size_t GEMM_SMALL = 64 * 64 * 64;
-    static constexpr size_t GEMM_MEDIUM = 256 * 256 * 256;
-    static constexpr size_t GEMM_TILE_M = 64;
-    static constexpr size_t GEMM_TILE_N = 64;
-    static constexpr size_t GEMM_TILE_M_SMALL = 32;
-    static constexpr size_t GEMM_TILE_N_SMALL = 32;
+    static size_t get_num_threads(size_t M, size_t pool_size) {
+        if (M <= 1) return 1;
+        return pool_size;
+    }
+#elif defined(__APPLE__) && TARGET_OS_IPHONE
+    static size_t get_num_threads(size_t M, size_t pool_size) {
+        if (M <= 1) return std::min(pool_size, static_cast<size_t>(2));
+        return pool_size;
+    }
+#else // Mac
+    static size_t get_num_threads(size_t M, size_t pool_size) {
+        if (M <= 1) return std::min(pool_size, static_cast<size_t>(4));
+        return pool_size;
+    }
 #endif
-    static constexpr size_t L2_CACHE_SIZE = 256 * 1024;
 };
+
+inline size_t& get_gemm_thread_override() {
+    static size_t override_threads = 0;
+    return override_threads;
+}
+
+inline void set_gemm_threads(size_t num_threads) {
+    get_gemm_thread_override() = num_threads;
+}
+
+inline void reset_gemm_threads() {
+    get_gemm_thread_override() = 0;
+}
 
 class TaskHandle {
 private:
@@ -225,10 +505,10 @@ namespace CactusThreading {
 };
 
 template<typename WorkFunc>
-TaskHandle parallel_for(size_t total_work,
-    const size_t num_threads = get_optimal_thread_count(total_work,
-    TaskHandle handle(!wait);
-
+TaskHandle parallel_for(size_t total_work, ParallelConfig config, WorkFunc work_func, bool wait = true) {
+    const size_t num_threads = get_optimal_thread_count(total_work, config);
+    TaskHandle handle(!wait);
+
     if (num_threads == 1) {
         if (wait) {
             work_func(0, total_work);
@@ -240,10 +520,10 @@ namespace CactusThreading {
         }));
         return handle;
     }
-
+
     auto& pool = get_thread_pool();
     const size_t work_per_thread = total_work / num_threads;
-
+
     for (size_t t = 0; t < num_threads; ++t) {
         handle.add_future(pool.enqueue([work_func, t, num_threads, work_per_thread, total_work]() {
             const size_t start_idx = t * work_per_thread;
@@ -251,17 +531,17 @@ namespace CactusThreading {
             work_func(start_idx, end_idx);
         }));
     }
-
+
     if (wait) {
         handle.wait();
     }
     return handle;
 }
-
+
 template<typename WorkFunc>
-void parallel_for_2d(size_t outer_size, size_t inner_size,
+void parallel_for_2d(size_t outer_size, size_t inner_size, ParallelConfig config, WorkFunc work_func) {
     const size_t total_work = outer_size * inner_size;
-    parallel_for(total_work,
+    parallel_for(total_work, config, [&](size_t start_idx, size_t end_idx) {
         for (size_t work_idx = start_idx; work_idx < end_idx; ++work_idx) {
             const size_t outer = work_idx / inner_size;
             const size_t inner = work_idx % inner_size;
@@ -269,11 +549,11 @@ namespace CactusThreading {
         }
     });
 }
-
+
 template<typename WorkFunc, typename ResultType, typename CombineFunc>
-ResultType parallel_reduce(size_t total_work,
+ResultType parallel_reduce(size_t total_work, ParallelConfig config,
                            WorkFunc work_func, ResultType init_value, CombineFunc combine_func) {
-    const size_t num_threads = get_optimal_thread_count(total_work,
+    const size_t num_threads = get_optimal_thread_count(total_work, config);
 
     if (num_threads == 1) {
         return work_func(0, total_work);
@@ -298,46 +578,25 @@ namespace CactusThreading {
     }
     return result;
 }
-
-inline size_t compute_gemm_parallelism(size_t M, size_t K, size_t N, size_t element_size) {
-    size_t total_ops = M * K * N;
-
-    if (total_ops < Thresholds::GEMM_SMALL) return 1;
-
-    if (total_ops < Thresholds::GEMM_MEDIUM) {
-        return std::min(static_cast<size_t>(2), get_thread_pool().num_workers());
-    }
-
-    size_t bytes_accessed = (M * K + K * N + M * N) * element_size;
-    size_t cache_tiles = (bytes_accessed + Thresholds::L2_CACHE_SIZE - 1) / Thresholds::L2_CACHE_SIZE;
-
-    size_t compute_threads = std::sqrt(static_cast<double>(total_ops) / Thresholds::GEMM_SMALL);
-    size_t memory_threads = cache_tiles;
-
-    size_t optimal = std::min(compute_threads, memory_threads);
-    return std::min(optimal, get_thread_pool().num_workers());
-}
-
+
 template<typename WorkFunc>
-void
-
-
-size_t
-
-
-
-
-
-
-
-
-
-
-    work_func(row_start, row_end, col_start, col_end);
-    }
-    });
+void parallel_gemm_tiles(size_t M, size_t total_tiles, WorkFunc work_func) {
+    auto& pool = get_thread_pool();
+
+    size_t override = get_gemm_thread_override();
+    size_t num_threads = (override > 0) ? override : GemmThreading::get_num_threads(M, pool.num_workers());
+    num_threads = std::min(num_threads, total_tiles);
+
+    if (num_threads <= 1) {
+        work_func(0, total_tiles);
+        return;
+    }
+
+    pool.enqueue_n_threads(total_tiles, num_threads, work_func);
+    pool.wait_all();
 }
+
 }
 
+
 #endif // KERNEL_UTILS_H
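For context on the threading changes in this diff: the old scalar thresholds and the `compute_gemm_parallelism` heuristic are replaced by `ParallelConfig` (a minimum-work gate plus a target amount of work per thread), which callers now pass into `parallel_for`, `parallel_for_2d`, and `parallel_reduce`. The sketch below shows how a caller might drive the new API; it is illustrative only — the include path and the `scale_buffer` kernel are assumptions, not part of the package.

```cpp
#include <cstddef>
#include "kernel_utils.h"  // illustrative include path; the header ships inside cactus.framework

// Hypothetical element-wise kernel. Below ELEMENT_WISE.min_work_gate (5000
// elements) the lambda runs inline on the calling thread; above it, the work
// is split across the shared ThreadPool in roughly work_per_thread-sized chunks.
void scale_buffer(float* data, size_t n, float factor) {
    CactusThreading::parallel_for(n, CactusThreading::Thresholds::ELEMENT_WISE,
        [&](size_t start, size_t end) {
            for (size_t i = start; i < end; ++i) {
                data[i] *= factor;
            }
        });
}
```

GEMM dispatch sits outside this path: `GemmThreading::get_num_threads` picks a per-platform default keyed on `M`, and `set_gemm_threads` / `reset_gemm_threads` let a caller pin or restore the worker count used by `parallel_gemm_tiles`.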
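The new int4 helpers (`unpack_int4_lo`/`unpack_int4_hi`, `int4_dot_asm`, and the group-wise `int4_dot_m1_asm`) give kernels a packed-nibble dot product with an SDOT fast path on aarch64. Below is a hedged sketch of how a dequantizing row dot product could be layered on top; the wrapper name, the per-group scale layout, and the assumption that `group_size` is a multiple of 64 (matching the unrolled aarch64 loop) are illustrative and not taken from the package's kernels.

```cpp
#include <cstddef>
#include <cstdint>
#include "kernel_utils.h"  // illustrative include path

// Illustrative wrapper (not from the package): dot product of an int8 row with
// a packed-int4 row, accumulated per quantization group and rescaled to float.
// b_packed stores two 4-bit values per byte, so each group spans group_size/2 bytes.
inline float int4_row_dot(const int8_t* a, const uint8_t* b_packed, size_t k,
                          const float* a_scales, const float* b_scales,
                          size_t group_size /* assumed multiple of 64 */) {
    float sum = 0.0f;
    for (size_t g = 0; g < k / group_size; ++g) {
        int32_t acc = int4_dot_m1_asm(a + g * group_size,
                                      b_packed + (g * group_size) / 2,
                                      group_size);
        sum += static_cast<float>(acc) * a_scales[g] * b_scales[g];
    }
    return sum;
}
```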