cactus-react-native 1.4.0 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. package/README.md +212 -27
  2. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  3. package/cpp/HybridCactus.cpp +119 -0
  4. package/cpp/HybridCactus.hpp +13 -0
  5. package/cpp/cactus_ffi.h +24 -0
  6. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +24 -0
  7. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +41 -1
  8. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +66 -48
  9. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
  10. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +102 -21
  11. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +45 -195
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +399 -140
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  14. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +24 -0
  15. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +41 -1
  16. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +66 -48
  17. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
  18. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +102 -21
  19. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +45 -195
  20. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +399 -140
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  22. package/lib/module/api/Database.js +0 -92
  23. package/lib/module/api/Database.js.map +1 -1
  24. package/lib/module/classes/CactusLM.js +33 -15
  25. package/lib/module/classes/CactusLM.js.map +1 -1
  26. package/lib/module/classes/CactusSTT.js +90 -15
  27. package/lib/module/classes/CactusSTT.js.map +1 -1
  28. package/lib/module/hooks/useCactusLM.js +14 -5
  29. package/lib/module/hooks/useCactusLM.js.map +1 -1
  30. package/lib/module/hooks/useCactusSTT.js +100 -4
  31. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  32. package/lib/module/index.js.map +1 -1
  33. package/lib/module/models.js +336 -0
  34. package/lib/module/models.js.map +1 -0
  35. package/lib/module/native/Cactus.js +37 -0
  36. package/lib/module/native/Cactus.js.map +1 -1
  37. package/lib/module/types/CactusLM.js +2 -0
  38. package/lib/module/types/CactusSTT.js +2 -0
  39. package/lib/module/types/common.js +2 -0
  40. package/lib/module/types/{CactusModel.js.map → common.js.map} +1 -1
  41. package/lib/typescript/src/api/Database.d.ts +0 -6
  42. package/lib/typescript/src/api/Database.d.ts.map +1 -1
  43. package/lib/typescript/src/classes/CactusLM.d.ts +7 -3
  44. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  45. package/lib/typescript/src/classes/CactusSTT.d.ts +13 -4
  46. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  47. package/lib/typescript/src/hooks/useCactusLM.d.ts +2 -2
  48. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  49. package/lib/typescript/src/hooks/useCactusSTT.d.ts +12 -4
  50. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  51. package/lib/typescript/src/index.d.ts +2 -3
  52. package/lib/typescript/src/index.d.ts.map +1 -1
  53. package/lib/typescript/src/models.d.ts +6 -0
  54. package/lib/typescript/src/models.d.ts.map +1 -0
  55. package/lib/typescript/src/native/Cactus.d.ts +6 -1
  56. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  57. package/lib/typescript/src/specs/Cactus.nitro.d.ts +5 -0
  58. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  59. package/lib/typescript/src/types/CactusLM.d.ts +2 -0
  60. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  61. package/lib/typescript/src/types/CactusSTT.d.ts +20 -0
  62. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  63. package/lib/typescript/src/types/common.d.ts +28 -0
  64. package/lib/typescript/src/types/common.d.ts.map +1 -0
  65. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
  66. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +5 -0
  67. package/package.json +1 -1
  68. package/src/api/Database.ts +0 -133
  69. package/src/classes/CactusLM.ts +49 -17
  70. package/src/classes/CactusSTT.ts +118 -17
  71. package/src/hooks/useCactusLM.ts +25 -5
  72. package/src/hooks/useCactusSTT.ts +117 -5
  73. package/src/index.tsx +6 -2
  74. package/src/models.ts +344 -0
  75. package/src/native/Cactus.ts +55 -0
  76. package/src/specs/Cactus.nitro.ts +5 -0
  77. package/src/types/CactusLM.ts +3 -0
  78. package/src/types/CactusSTT.ts +26 -0
  79. package/src/types/common.ts +28 -0
  80. package/lib/module/types/CactusModel.js +0 -2
  81. package/lib/module/types/CactusSTTModel.js +0 -2
  82. package/lib/module/types/CactusSTTModel.js.map +0 -1
  83. package/lib/typescript/src/types/CactusModel.d.ts +0 -13
  84. package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
  85. package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
  86. package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
  87. package/src/types/CactusModel.ts +0 -15
  88. package/src/types/CactusSTTModel.ts +0 -10
@@ -2,6 +2,13 @@
2
2
  #define KERNEL_UTILS_H
3
3
 
4
4
  #include <arm_neon.h>
5
+ #if defined(__APPLE__)
6
+ #include <TargetConditionals.h>
7
+ #endif
8
+ #if defined(__ANDROID__)
9
+ #include <sys/auxv.h>
10
+ #include <asm/hwcap.h>
11
+ #endif
5
12
  #include <algorithm>
6
13
  #include <cmath>
7
14
  #include <thread>
@@ -19,166 +26,439 @@
19
26
  #include <cstdio>
20
27
 
21
28
  constexpr size_t NEON_VECTOR_SIZE = 16;
29
+ constexpr size_t STREAMING_STORE_THRESHOLD = 32768;
22
30
 
23
- inline int8_t clamp_to_int8(float value) {
24
- int32_t clamped = static_cast<int32_t>(roundf(value));
25
- return static_cast<int8_t>(std::max(-128, std::min(127, clamped)));
26
- }
27
-
28
- inline int8_t clamp_to_int8(int32_t value) {
29
- return static_cast<int8_t>(std::max(-128, std::min(127, value)));
31
+ inline void stream_store_f16x8(__fp16* dst, float16x8_t val) {
32
+ #if defined(__aarch64__)
33
+ float16x4_t lo = vget_low_f16(val);
34
+ float16x4_t hi = vget_high_f16(val);
35
+ __asm__ __volatile__(
36
+ "stnp %d0, %d1, [%2]"
37
+ :
38
+ : "w"(lo), "w"(hi), "r"(dst)
39
+ : "memory"
40
+ );
41
+ #else
42
+ vst1q_f16(dst, val);
43
+ #endif
30
44
  }
31
45
 
32
46
  #if defined(__ARM_FEATURE_DOTPROD)
33
- inline int32x4_t accum_i8mm(int32x4_t acc, int8x16_t a, int8x16_t b) {
47
+ inline int32x4_t accum_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
34
48
  return vdotq_s32(acc, a, b);
35
49
  }
36
50
  #else
37
- inline int32x4_t accum_i8mm(int32x4_t acc, int8x16_t a, int8x16_t b) {
51
+ inline int32x4_t accum_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
38
52
  int16x8_t prod_low = vmull_s8(vget_low_s8(a), vget_low_s8(b));
39
53
  int32x4_t acc_high = vpaddlq_s16(vmull_s8(vget_high_s8(a), vget_high_s8(b)));
40
54
  return vaddq_s32(vaddq_s32(acc, vpaddlq_s16(prod_low)), acc_high);
41
55
  }
42
56
  #endif
43
57
 
44
- inline float16x8_t accum_f16_dot(float16x8_t acc, float16x8_t a_low, float16x8_t a_high,
58
+ // I8MM support: runtime detection on Android, compile-time on Apple
59
+ #if defined(__ANDROID__) && defined(__aarch64__)
60
+
61
+ inline bool cactus_has_i8mm() {
62
+ static int8_t supported = -1;
63
+ if (supported == -1) {
64
+ unsigned long hwcaps = getauxval(AT_HWCAP2);
65
+ supported = (hwcaps & HWCAP2_I8MM) ? 1 : 0;
66
+ }
67
+ return supported;
68
+ }
69
+
70
+ __attribute__((target("arch=armv8.2-a+i8mm")))
71
+ inline int32x4_t accum_matmul(int32x4_t acc, int8x16_t a, int8x16_t b) {
72
+ return vmmlaq_s32(acc, a, b);
73
+ }
74
+
75
+ #elif defined(__APPLE__) && defined(__aarch64__)
76
+
77
+ inline bool cactus_has_i8mm() {
78
+ return true;
79
+ }
80
+
81
+ __attribute__((target("i8mm")))
82
+ inline int32x4_t accum_matmul(int32x4_t acc, int8x16_t a, int8x16_t b) {
83
+ return vmmlaq_s32(acc, a, b);
84
+ }
85
+
86
+ #else
87
+
88
+ inline bool cactus_has_i8mm() {
89
+ return false;
90
+ }
91
+
92
+ #endif
93
+
94
+ inline float16x8_t accum_f16_dot(float16x8_t acc, float16x8_t a_low, float16x8_t a_high,
45
95
  float16x8_t b_low, float16x8_t b_high) {
46
96
  acc = vfmaq_f16(acc, a_low, b_low);
47
97
  return vfmaq_f16(acc, a_high, b_high);
48
98
  }
49
99
 
50
- inline float32x4_t accum_f32_dot(float32x4_t acc, float32x4_t a_low, float32x4_t a_high,
51
- float32x4_t b_low, float32x4_t b_high) {
52
- acc = vfmaq_f32(acc, a_low, b_low);
53
- return vfmaq_f32(acc, a_high, b_high);
100
+ inline float32x4_t fast_exp_f32x4(float32x4_t x) {
101
+ const float32x4_t log2e = vdupq_n_f32(1.4426950408889634f);
102
+ const float32x4_t ln2 = vdupq_n_f32(0.6931471805599453f);
103
+
104
+ const float32x4_t c0 = vdupq_n_f32(1.0f);
105
+ const float32x4_t c1 = vdupq_n_f32(0.6931471805599453f);
106
+ const float32x4_t c2 = vdupq_n_f32(0.2402265069591007f);
107
+ const float32x4_t c3 = vdupq_n_f32(0.05550410866482158f);
108
+ const float32x4_t c4 = vdupq_n_f32(0.009618129842071803f);
109
+
110
+ x = vmaxq_f32(x, vdupq_n_f32(-87.0f));
111
+ x = vminq_f32(x, vdupq_n_f32(87.0f));
112
+
113
+ float32x4_t z = vmulq_f32(x, log2e);
114
+
115
+ int32x4_t zi = vcvtq_s32_f32(z);
116
+ float32x4_t zf = vsubq_f32(z, vcvtq_f32_s32(zi));
117
+
118
+ uint32x4_t neg_mask = vcltq_f32(zf, vdupq_n_f32(0.0f));
119
+ zi = vsubq_s32(zi, vandq_s32(vreinterpretq_s32_u32(neg_mask), vdupq_n_s32(1)));
120
+ zf = vaddq_f32(zf, vreinterpretq_f32_u32(vandq_u32(neg_mask, vreinterpretq_u32_f32(vdupq_n_f32(1.0f)))));
121
+
122
+ float32x4_t zf_ln2 = vmulq_f32(zf, ln2);
123
+ float32x4_t p = c4;
124
+ p = vfmaq_f32(c3, p, zf_ln2);
125
+ p = vfmaq_f32(c2, p, zf_ln2);
126
+ p = vfmaq_f32(c1, p, zf_ln2);
127
+ p = vfmaq_f32(c0, p, zf_ln2);
128
+
129
+ int32x4_t exp_bits = vshlq_n_s32(vaddq_s32(zi, vdupq_n_s32(127)), 23);
130
+ float32x4_t scale = vreinterpretq_f32_s32(exp_bits);
131
+
132
+ return vmulq_f32(p, scale);
133
+ }
134
+
135
+ inline float32x4_t fast_tanh_f32x4(float32x4_t x) {
136
+ const float32x4_t one = vdupq_n_f32(1.0f);
137
+ const float32x4_t neg_one = vdupq_n_f32(-1.0f);
138
+
139
+ uint32x4_t pos_sat = vcgtq_f32(x, vdupq_n_f32(4.5f));
140
+ uint32x4_t neg_sat = vcltq_f32(x, vdupq_n_f32(-4.5f));
141
+
142
+ const float32x4_t c27 = vdupq_n_f32(27.0f);
143
+ const float32x4_t c9 = vdupq_n_f32(9.0f);
144
+
145
+ float32x4_t x2 = vmulq_f32(x, x);
146
+ float32x4_t num = vaddq_f32(c27, x2);
147
+ float32x4_t den = vfmaq_f32(c27, c9, x2);
148
+
149
+ float32x4_t result = vmulq_f32(x, vdivq_f32(num, den));
150
+
151
+ result = vbslq_f32(pos_sat, one, result);
152
+ result = vbslq_f32(neg_sat, neg_one, result);
153
+
154
+ return result;
155
+ }
156
+
157
+ inline int8x16_t unpack_int4_lo(uint8x16_t packed) {
158
+ uint8x16_t lo = vandq_u8(packed, vdupq_n_u8(0x0F));
159
+ uint8x16_t sign_mask = vcgtq_u8(lo, vdupq_n_u8(7));
160
+ uint8x16_t correction = vandq_u8(sign_mask, vdupq_n_u8(16));
161
+ return vreinterpretq_s8_u8(vsubq_u8(lo, correction));
162
+ }
163
+
164
+ inline int8x16_t unpack_int4_hi(uint8x16_t packed) {
165
+ uint8x16_t hi = vshrq_n_u8(packed, 4);
166
+ uint8x16_t sign_mask = vcgtq_u8(hi, vdupq_n_u8(7));
167
+ uint8x16_t correction = vandq_u8(sign_mask, vdupq_n_u8(16));
168
+ return vreinterpretq_s8_u8(vsubq_u8(hi, correction));
169
+ }
170
+
171
+ inline void unpack_int4_to_int8x32(uint8x16_t packed, int8x16_t& out_lo, int8x16_t& out_hi) {
172
+ int8x16_t lo_nibbles = unpack_int4_lo(packed);
173
+ int8x16_t hi_nibbles = unpack_int4_hi(packed);
174
+ int8x16x2_t interleaved = vzipq_s8(lo_nibbles, hi_nibbles);
175
+ out_lo = interleaved.val[0];
176
+ out_hi = interleaved.val[1];
177
+ }
178
+
179
+ inline int32x4_t int4_dot_asm(int32x4_t acc, uint8x16_t packed, int8x16_t a_lo, int8x16_t a_hi) {
180
+ #if defined(__aarch64__)
181
+ int8x16_t b_lo, b_hi;
182
+
183
+ __asm__ __volatile__ (
184
+ "movi v16.16b, #0x0F \n" // low nibble mask
185
+ "movi v17.16b, #7 \n" // sign threshold
186
+ "movi v18.16b, #16 \n" // sign correction
187
+
188
+ "and %[b_lo].16b, %[packed].16b, v16.16b \n"
189
+
190
+ "ushr %[b_hi].16b, %[packed].16b, #4 \n"
191
+
192
+ "cmgt v19.16b, %[b_lo].16b, v17.16b \n"
193
+ "and v19.16b, v19.16b, v18.16b \n"
194
+ "sub %[b_lo].16b, %[b_lo].16b, v19.16b \n"
195
+
196
+ "cmgt v20.16b, %[b_hi].16b, v17.16b \n"
197
+ "and v20.16b, v20.16b, v18.16b \n"
198
+ "sub %[b_hi].16b, %[b_hi].16b, v20.16b \n"
199
+
200
+ "zip1 v21.16b, %[b_lo].16b, %[b_hi].16b \n"
201
+ "zip2 v22.16b, %[b_lo].16b, %[b_hi].16b \n"
202
+
203
+ ".arch armv8.2-a+dotprod \n"
204
+ "sdot %[acc].4s, %[a_lo].16b, v21.16b \n"
205
+ "sdot %[acc].4s, %[a_hi].16b, v22.16b \n"
206
+
207
+ : [acc] "+w"(acc), [b_lo] "=w"(b_lo), [b_hi] "=w"(b_hi)
208
+ : [packed] "w"(packed), [a_lo] "w"(a_lo), [a_hi] "w"(a_hi)
209
+ : "v16", "v17", "v18", "v19", "v20", "v21", "v22"
210
+ );
211
+
212
+ return acc;
213
+ #else
214
+ int8x16_t b_lo, b_hi;
215
+ unpack_int4_to_int8x32(packed, b_lo, b_hi);
216
+ acc = accum_dot(acc, a_lo, b_lo);
217
+ acc = accum_dot(acc, a_hi, b_hi);
218
+ return acc;
219
+ #endif
220
+ }
221
+
222
+ inline int32_t int4_dot_m1_asm(const int8_t* a_ptr, const uint8_t* b_packed, size_t group_size) {
223
+ #if defined(__aarch64__)
224
+ int32x4_t acc = vdupq_n_s32(0);
225
+
226
+ for (size_t k = 0; k < group_size; k += 64) {
227
+ uint8x16_t p0 = vld1q_u8(b_packed + k/2);
228
+ uint8x16_t p1 = vld1q_u8(b_packed + k/2 + 16);
229
+
230
+ int8x16_t a0 = vld1q_s8(a_ptr + k);
231
+ int8x16_t a1 = vld1q_s8(a_ptr + k + 16);
232
+ int8x16_t a2 = vld1q_s8(a_ptr + k + 32);
233
+ int8x16_t a3 = vld1q_s8(a_ptr + k + 48);
234
+
235
+ acc = int4_dot_asm(acc, p0, a0, a1);
236
+ acc = int4_dot_asm(acc, p1, a2, a3);
237
+ }
238
+
239
+ return vaddvq_s32(acc);
240
+ #else
241
+ int32x4_t acc = vdupq_n_s32(0);
242
+ for (size_t k = 0; k < group_size; k += 32) {
243
+ uint8x16_t packed = vld1q_u8(b_packed + k/2);
244
+ int8x16_t b_lo, b_hi;
245
+ unpack_int4_to_int8x32(packed, b_lo, b_hi);
246
+ acc = accum_dot(acc, vld1q_s8(a_ptr + k), b_lo);
247
+ acc = accum_dot(acc, vld1q_s8(a_ptr + k + 16), b_hi);
248
+ }
249
+ return vaddvq_s32(acc);
250
+ #endif
54
251
  }
55
252
 
56
253
  namespace CactusThreading {
57
-
254
+
58
255
  class ThreadPool {
59
256
  private:
257
+ static constexpr size_t MAX_WORKERS = 16;
258
+
60
259
  std::vector<std::thread> workers;
61
- std::queue<std::function<void()>> tasks;
62
- std::mutex queue_mutex;
63
- std::condition_variable condition;
64
- std::atomic<bool> stop{false};
65
- std::atomic<size_t> active_workers{0};
66
- std::condition_variable finish_condition;
67
-
260
+ std::deque<std::function<void()>> tasks;
261
+
262
+ std::mutex mutex;
263
+ std::condition_variable work_available;
264
+ std::condition_variable work_done;
265
+
266
+ bool stop{false};
267
+ std::atomic<size_t> pending_tasks{0};
268
+ size_t num_workers_;
269
+
68
270
  void worker_thread() {
69
271
  while (true) {
70
272
  std::function<void()> task;
71
273
  {
72
- std::unique_lock<std::mutex> lock(queue_mutex);
73
- condition.wait(lock, [this] { return stop || !tasks.empty(); });
74
-
75
- if (stop && tasks.empty()) return;
76
-
274
+ std::unique_lock<std::mutex> lock(mutex);
275
+ work_available.wait(lock, [this] {
276
+ return stop || !tasks.empty();
277
+ });
278
+
279
+ if (stop && tasks.empty()) {
280
+ return;
281
+ }
282
+
77
283
  task = std::move(tasks.front());
78
- tasks.pop();
79
- active_workers++;
284
+ tasks.pop_front();
80
285
  }
81
-
286
+
82
287
  task();
83
-
84
- active_workers--;
85
- finish_condition.notify_all();
288
+
289
+ if (pending_tasks.fetch_sub(1, std::memory_order_acq_rel) == 1) {
290
+ std::lock_guard<std::mutex> lock(mutex);
291
+ work_done.notify_one();
292
+ }
86
293
  }
87
294
  }
88
-
295
+
89
296
  public:
90
- explicit ThreadPool(size_t num_threads = std::thread::hardware_concurrency()) {
91
- workers.reserve(num_threads);
92
- for (size_t i = 0; i < num_threads; ++i) {
297
+ explicit ThreadPool(size_t num_threads = std::thread::hardware_concurrency())
298
+ : stop(false), pending_tasks(0) {
299
+ num_workers_ = std::min(num_threads, MAX_WORKERS);
300
+ if (num_workers_ == 0) num_workers_ = 1;
301
+ workers.reserve(num_workers_);
302
+ for (size_t i = 0; i < num_workers_; ++i) {
93
303
  workers.emplace_back(&ThreadPool::worker_thread, this);
94
304
  }
95
305
  }
96
-
306
+
97
307
  ~ThreadPool() {
98
308
  {
99
- std::unique_lock<std::mutex> lock(queue_mutex);
309
+ std::lock_guard<std::mutex> lock(mutex);
100
310
  stop = true;
101
311
  }
102
- condition.notify_all();
312
+ work_available.notify_all();
103
313
  for (auto& worker : workers) {
104
- worker.join();
314
+ if (worker.joinable()) {
315
+ worker.join();
316
+ }
105
317
  }
106
318
  }
107
-
319
+
108
320
  template<typename F>
109
321
  auto enqueue(F&& f) -> std::future<decltype(f())> {
110
322
  using return_type = decltype(f());
111
-
323
+
112
324
  auto task = std::make_shared<std::packaged_task<return_type()>>(
113
325
  std::forward<F>(f)
114
326
  );
115
-
327
+
116
328
  std::future<return_type> res = task->get_future();
329
+
117
330
  {
118
- std::unique_lock<std::mutex> lock(queue_mutex);
119
- if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
120
-
121
- tasks.emplace([task](){ (*task)(); });
331
+ std::lock_guard<std::mutex> lock(mutex);
332
+ pending_tasks.fetch_add(1, std::memory_order_relaxed);
333
+ tasks.emplace_back([task](){ (*task)(); });
122
334
  }
123
- condition.notify_one();
335
+ work_available.notify_one();
336
+
124
337
  return res;
125
338
  }
126
-
339
+
340
+ template<typename F>
341
+ void enqueue_batch(size_t total_work, F task_func) {
342
+ if (total_work == 0) return;
343
+
344
+ const size_t num_tasks = std::min(num_workers_, total_work);
345
+ const size_t per_worker = total_work / num_tasks;
346
+ const size_t remainder = total_work % num_tasks;
347
+
348
+ {
349
+ std::lock_guard<std::mutex> lock(mutex);
350
+ pending_tasks.fetch_add(num_tasks, std::memory_order_relaxed);
351
+
352
+ for (size_t w = 0; w < num_tasks; ++w) {
353
+ size_t start = w * per_worker + std::min(w, remainder);
354
+ size_t end = start + per_worker + (w < remainder ? 1 : 0);
355
+ tasks.emplace_back([=]() { task_func(start, end); });
356
+ }
357
+ }
358
+ work_available.notify_all();
359
+ }
360
+
127
361
  void wait_all() {
128
- std::unique_lock<std::mutex> lock(queue_mutex);
129
- finish_condition.wait(lock, [this] {
130
- return tasks.empty() && active_workers == 0;
362
+ std::unique_lock<std::mutex> lock(mutex);
363
+ work_done.wait(lock, [this] {
364
+ return pending_tasks.load(std::memory_order_acquire) == 0;
131
365
  });
132
366
  }
133
-
134
- size_t num_workers() const { return workers.size(); }
367
+
368
+ template<typename F>
369
+ void enqueue_n_threads(size_t total_work, size_t num_threads, F task_func) {
370
+ if (total_work == 0 || num_threads == 0) return;
371
+
372
+ num_threads = std::min(num_threads, std::min(num_workers_, total_work));
373
+ const size_t per_thread = total_work / num_threads;
374
+ const size_t remainder = total_work % num_threads;
375
+
376
+ {
377
+ std::lock_guard<std::mutex> lock(mutex);
378
+ pending_tasks.fetch_add(num_threads, std::memory_order_relaxed);
379
+
380
+ for (size_t t = 0; t < num_threads; ++t) {
381
+ size_t start = t * per_thread + std::min(t, remainder);
382
+ size_t end = start + per_thread + (t < remainder ? 1 : 0);
383
+ tasks.emplace_back([=]() { task_func(start, end); });
384
+ }
385
+ }
386
+ work_available.notify_all();
387
+ }
388
+
389
+ size_t num_workers() const { return num_workers_; }
135
390
  };
136
-
391
+
137
392
  inline ThreadPool& get_thread_pool() {
138
393
  static ThreadPool pool;
139
394
  return pool;
140
395
  }
141
396
 
142
- inline size_t get_optimal_thread_count(size_t total_work, size_t min_work_per_thread) {
143
- if (total_work < min_work_per_thread) return 1;
397
+ struct ParallelConfig {
398
+ size_t min_work_gate;
399
+ size_t work_per_thread;
400
+
401
+ constexpr ParallelConfig(size_t gate, size_t per_thread)
402
+ : min_work_gate(gate), work_per_thread(per_thread) {}
403
+ };
404
+
405
+ inline size_t get_optimal_thread_count(size_t total_work, ParallelConfig config) {
406
+ if (total_work < config.min_work_gate) return 1;
407
+
144
408
  size_t pool_size = get_thread_pool().num_workers();
145
- return std::min(pool_size,
146
- std::max(static_cast<size_t>(1), total_work / min_work_per_thread));
409
+ size_t num_threads = (total_work + config.work_per_thread - 1) / config.work_per_thread;
410
+ return std::min(pool_size, std::max(static_cast<size_t>(1), num_threads));
147
411
  }
148
-
412
+
149
413
  struct Thresholds {
414
+ #if defined(__ANDROID__)
415
+ static constexpr ParallelConfig ATTENTION{64, 32};
416
+ static constexpr ParallelConfig ELEMENT_WISE{5000, 2500};
417
+ static constexpr ParallelConfig AXIS_REDUCE{1000, 500};
418
+ static constexpr ParallelConfig ALL_REDUCE{10000, 5000};
419
+ static constexpr ParallelConfig SCALAR_BASIC{30000, 15000};
420
+ static constexpr ParallelConfig SCALAR_EXPENSIVE{10000, 5000};
421
+ #else // Apple
422
+ static constexpr ParallelConfig ATTENTION{32, 16};
423
+ static constexpr ParallelConfig ELEMENT_WISE{5000, 2500};
424
+ static constexpr ParallelConfig AXIS_REDUCE{1000, 500};
425
+ static constexpr ParallelConfig ALL_REDUCE{10000, 5000};
426
+ static constexpr ParallelConfig SCALAR_BASIC{5000, 2500};
427
+ static constexpr ParallelConfig SCALAR_EXPENSIVE{2500, 1250};
428
+ #endif
429
+ };
150
430
 
431
+ struct GemmThreading {
151
432
  #if defined(__ANDROID__)
152
- static constexpr size_t ELEMENT_WISE = 5000;
153
- static constexpr size_t AXIS_REDUCE = 1000;
154
- static constexpr size_t ALL_REDUCE = 10000;
155
- static constexpr size_t SCALAR_BASIC = 30000;
156
- static constexpr size_t SCALAR_EXPENSIVE = 10000;
157
- static constexpr size_t ATTENTION = 512;
158
- static constexpr size_t GEMM_TILED = 20000;
159
- static constexpr size_t GEMM_SMALL = 64 * 64 * 64;
160
- static constexpr size_t GEMM_MEDIUM = 256 * 256 * 256;
161
- static constexpr size_t GEMM_TILE_M = 64;
162
- static constexpr size_t GEMM_TILE_N = 64;
163
- static constexpr size_t GEMM_TILE_M_SMALL = 32;
164
- static constexpr size_t GEMM_TILE_N_SMALL = 32;
165
- #else // iOS
166
- static constexpr size_t ELEMENT_WISE = 5000;
167
- static constexpr size_t AXIS_REDUCE = 1000;
168
- static constexpr size_t ALL_REDUCE = 10000;
169
- static constexpr size_t SCALAR_BASIC = 5000;
170
- static constexpr size_t SCALAR_EXPENSIVE = 2500;
171
- static constexpr size_t ATTENTION = 4;
172
- static constexpr size_t GEMM_TILED = 4;
173
- static constexpr size_t GEMM_SMALL = 64 * 64 * 64;
174
- static constexpr size_t GEMM_MEDIUM = 256 * 256 * 256;
175
- static constexpr size_t GEMM_TILE_M = 64;
176
- static constexpr size_t GEMM_TILE_N = 64;
177
- static constexpr size_t GEMM_TILE_M_SMALL = 32;
178
- static constexpr size_t GEMM_TILE_N_SMALL = 32;
433
+ static size_t get_num_threads(size_t M, size_t pool_size) {
434
+ if (M <= 1) return 1;
435
+ return pool_size;
436
+ }
437
+ #elif defined(__APPLE__) && TARGET_OS_IPHONE
438
+ static size_t get_num_threads(size_t M, size_t pool_size) {
439
+ if (M <= 1) return std::min(pool_size, static_cast<size_t>(2));
440
+ return pool_size;
441
+ }
442
+ #else // Mac
443
+ static size_t get_num_threads(size_t M, size_t pool_size) {
444
+ if (M <= 1) return std::min(pool_size, static_cast<size_t>(4));
445
+ return pool_size;
446
+ }
179
447
  #endif
180
- static constexpr size_t L2_CACHE_SIZE = 256 * 1024;
181
448
  };
449
+
450
+ inline size_t& get_gemm_thread_override() {
451
+ static size_t override_threads = 0;
452
+ return override_threads;
453
+ }
454
+
455
+ inline void set_gemm_threads(size_t num_threads) {
456
+ get_gemm_thread_override() = num_threads;
457
+ }
458
+
459
+ inline void reset_gemm_threads() {
460
+ get_gemm_thread_override() = 0;
461
+ }
182
462
 
183
463
  class TaskHandle {
184
464
  private:
@@ -225,10 +505,10 @@ namespace CactusThreading {
225
505
  };
226
506
 
227
507
  template<typename WorkFunc>
228
- TaskHandle parallel_for(size_t total_work, size_t threshold, WorkFunc work_func, bool wait = true) {
229
- const size_t num_threads = get_optimal_thread_count(total_work, threshold);
230
- TaskHandle handle(!wait);
231
-
508
+ TaskHandle parallel_for(size_t total_work, ParallelConfig config, WorkFunc work_func, bool wait = true) {
509
+ const size_t num_threads = get_optimal_thread_count(total_work, config);
510
+ TaskHandle handle(!wait);
511
+
232
512
  if (num_threads == 1) {
233
513
  if (wait) {
234
514
  work_func(0, total_work);
@@ -240,10 +520,10 @@ namespace CactusThreading {
240
520
  }));
241
521
  return handle;
242
522
  }
243
-
523
+
244
524
  auto& pool = get_thread_pool();
245
525
  const size_t work_per_thread = total_work / num_threads;
246
-
526
+
247
527
  for (size_t t = 0; t < num_threads; ++t) {
248
528
  handle.add_future(pool.enqueue([work_func, t, num_threads, work_per_thread, total_work]() {
249
529
  const size_t start_idx = t * work_per_thread;
@@ -251,17 +531,17 @@ namespace CactusThreading {
251
531
  work_func(start_idx, end_idx);
252
532
  }));
253
533
  }
254
-
534
+
255
535
  if (wait) {
256
536
  handle.wait();
257
537
  }
258
538
  return handle;
259
539
  }
260
-
540
+
261
541
  template<typename WorkFunc>
262
- void parallel_for_2d(size_t outer_size, size_t inner_size, size_t threshold, WorkFunc work_func) {
542
+ void parallel_for_2d(size_t outer_size, size_t inner_size, ParallelConfig config, WorkFunc work_func) {
263
543
  const size_t total_work = outer_size * inner_size;
264
- parallel_for(total_work, threshold, [&](size_t start_idx, size_t end_idx) {
544
+ parallel_for(total_work, config, [&](size_t start_idx, size_t end_idx) {
265
545
  for (size_t work_idx = start_idx; work_idx < end_idx; ++work_idx) {
266
546
  const size_t outer = work_idx / inner_size;
267
547
  const size_t inner = work_idx % inner_size;
@@ -269,11 +549,11 @@ namespace CactusThreading {
269
549
  }
270
550
  });
271
551
  }
272
-
552
+
273
553
  template<typename WorkFunc, typename ResultType, typename CombineFunc>
274
- ResultType parallel_reduce(size_t total_work, size_t threshold,
554
+ ResultType parallel_reduce(size_t total_work, ParallelConfig config,
275
555
  WorkFunc work_func, ResultType init_value, CombineFunc combine_func) {
276
- const size_t num_threads = get_optimal_thread_count(total_work, threshold);
556
+ const size_t num_threads = get_optimal_thread_count(total_work, config);
277
557
 
278
558
  if (num_threads == 1) {
279
559
  return work_func(0, total_work);
@@ -298,46 +578,25 @@ namespace CactusThreading {
298
578
  }
299
579
  return result;
300
580
  }
301
-
302
- inline size_t compute_gemm_parallelism(size_t M, size_t K, size_t N, size_t element_size) {
303
- size_t total_ops = M * K * N;
304
-
305
- if (total_ops < Thresholds::GEMM_SMALL) return 1;
306
-
307
- if (total_ops < Thresholds::GEMM_MEDIUM) {
308
- return std::min(static_cast<size_t>(2), get_thread_pool().num_workers());
309
- }
310
-
311
- size_t bytes_accessed = (M * K + K * N + M * N) * element_size;
312
- size_t cache_tiles = (bytes_accessed + Thresholds::L2_CACHE_SIZE - 1) / Thresholds::L2_CACHE_SIZE;
313
-
314
- size_t compute_threads = std::sqrt(static_cast<double>(total_ops) / Thresholds::GEMM_SMALL);
315
- size_t memory_threads = cache_tiles;
316
-
317
- size_t optimal = std::min(compute_threads, memory_threads);
318
- return std::min(optimal, get_thread_pool().num_workers());
319
- }
320
-
581
+
321
582
  template<typename WorkFunc>
322
- void parallel_for_2d_tiled(size_t rows, size_t cols, size_t tile_rows, size_t tile_cols, WorkFunc work_func) {
323
- size_t num_row_tiles = (rows + tile_rows - 1) / tile_rows;
324
- size_t num_col_tiles = (cols + tile_cols - 1) / tile_cols;
325
- size_t total_tiles = num_row_tiles * num_col_tiles;
326
-
327
- parallel_for(total_tiles, Thresholds::GEMM_TILED, [=](size_t start_tile, size_t end_tile) {
328
- for (size_t tile_idx = start_tile; tile_idx < end_tile; ++tile_idx) {
329
- size_t tile_row = tile_idx / num_col_tiles;
330
- size_t tile_col = tile_idx % num_col_tiles;
331
-
332
- size_t row_start = tile_row * tile_rows;
333
- size_t row_end = std::min(row_start + tile_rows, rows);
334
- size_t col_start = tile_col * tile_cols;
335
- size_t col_end = std::min(col_start + tile_cols, cols);
336
-
337
- work_func(row_start, row_end, col_start, col_end);
338
- }
339
- });
583
+ void parallel_gemm_tiles(size_t M, size_t total_tiles, WorkFunc work_func) {
584
+ auto& pool = get_thread_pool();
585
+
586
+ size_t override = get_gemm_thread_override();
587
+ size_t num_threads = (override > 0) ? override : GemmThreading::get_num_threads(M, pool.num_workers());
588
+ num_threads = std::min(num_threads, total_tiles);
589
+
590
+ if (num_threads <= 1) {
591
+ work_func(0, total_tiles);
592
+ return;
593
+ }
594
+
595
+ pool.enqueue_n_threads(total_tiles, num_threads, work_func);
596
+ pool.wait_all();
340
597
  }
598
+
341
599
  }
342
600
 
601
+
343
602
  #endif // KERNEL_UTILS_H