@lgrammel/ds4-provider 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. package/README.md +96 -0
  2. package/binding.gyp +75 -0
  3. package/dist/ds4-language-model.d.ts +71 -0
  4. package/dist/ds4-language-model.d.ts.map +1 -0
  5. package/dist/ds4-language-model.js +888 -0
  6. package/dist/ds4-language-model.js.map +1 -0
  7. package/dist/ds4-provider.d.ts +13 -0
  8. package/dist/ds4-provider.d.ts.map +1 -0
  9. package/dist/ds4-provider.js +20 -0
  10. package/dist/ds4-provider.js.map +1 -0
  11. package/dist/index.d.ts +4 -0
  12. package/dist/index.d.ts.map +1 -0
  13. package/dist/index.js +4 -0
  14. package/dist/index.js.map +1 -0
  15. package/dist/native-binding.d.ts +42 -0
  16. package/dist/native-binding.d.ts.map +1 -0
  17. package/dist/native-binding.js +157 -0
  18. package/dist/native-binding.js.map +1 -0
  19. package/ds4/LICENSE +22 -0
  20. package/ds4/ds4.c +18268 -0
  21. package/ds4/ds4.h +196 -0
  22. package/ds4/ds4_gpu.h +804 -0
  23. package/ds4/ds4_metal.m +14657 -0
  24. package/ds4/metal/argsort.metal +266 -0
  25. package/ds4/metal/bin.metal +192 -0
  26. package/ds4/metal/concat.metal +62 -0
  27. package/ds4/metal/cpy.metal +57 -0
  28. package/ds4/metal/dense.metal +1121 -0
  29. package/ds4/metal/dsv4_hc.metal +861 -0
  30. package/ds4/metal/dsv4_kv.metal +227 -0
  31. package/ds4/metal/dsv4_misc.metal +1088 -0
  32. package/ds4/metal/dsv4_rope.metal +155 -0
  33. package/ds4/metal/flash_attn.metal +1426 -0
  34. package/ds4/metal/get_rows.metal +54 -0
  35. package/ds4/metal/glu.metal +36 -0
  36. package/ds4/metal/moe.metal +1737 -0
  37. package/ds4/metal/norm.metal +153 -0
  38. package/ds4/metal/repeat.metal +52 -0
  39. package/ds4/metal/set_rows.metal +55 -0
  40. package/ds4/metal/softmax.metal +241 -0
  41. package/ds4/metal/sum_rows.metal +102 -0
  42. package/ds4/metal/unary.metal +312 -0
  43. package/native/binding.cpp +621 -0
  44. package/package.json +66 -0
  45. package/scripts/postinstall.cjs +13 -0
  46. package/scripts/vendor-ds4.cjs +67 -0
@@ -0,0 +1,312 @@
1
+ #define FC_UNARY 1200 // base index for this file's function constants (used below as FC_UNARY + 0 and FC_UNARY + 1)
2
+ // Op selector values: kernel_unary_impl compares the FC_unary_op function constant against these.
3
+ #define OP_UNARY_NUM_SCALE 10 // args.scale * x + args.bias
4
+ #define OP_UNARY_NUM_FILL 11 // args.val (the loaded input is ignored)
5
+ #define OP_UNARY_NUM_CLAMP 12 // clamp(x, args.min, args.max)
6
+ #define OP_UNARY_NUM_SQR 13 // x * x
7
+ #define OP_UNARY_NUM_SQRT 14 // sqrt(x)
8
+ #define OP_UNARY_NUM_SIN 15 // sin(x)
9
+ #define OP_UNARY_NUM_COS 16 // cos(x)
10
+ #define OP_UNARY_NUM_LOG 17 // log(x)
11
+ #define OP_UNARY_NUM_LEAKY_RELU 18 // x where x > 0, otherwise x * args.slope
12
+
13
+ #define OP_UNARY_NUM_TANH 100 // precise::tanh(x)
14
+ #define OP_UNARY_NUM_RELU 101 // fmax(0, x)
15
+ #define OP_UNARY_NUM_SIGMOID 102 // 1 / (1 + exp(-x))
16
+ #define OP_UNARY_NUM_GELU 103 // tanh-based GELU: 0.5*x*(1 + tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))
17
+ #define OP_UNARY_NUM_GELU_ERF 104 // erf-based GELU: 0.5*x*(1 + erf_approx(SQRT_2_INV*x))
18
+ #define OP_UNARY_NUM_GELU_QUICK 105 // x / (1 + exp(GELU_QUICK_COEF*x))
19
+ #define OP_UNARY_NUM_SILU 106 // x / (1 + exp(-x))
20
+ #define OP_UNARY_NUM_ELU 107 // elu_approx(x)
21
+ #define OP_UNARY_NUM_NEG 108 // -x
22
+ #define OP_UNARY_NUM_ABS 109 // fabs(x)
23
+ #define OP_UNARY_NUM_SGN 110 // (x > 0) - (x < 0)
24
+ #define OP_UNARY_NUM_STEP 111 // (x > 0)
25
+ #define OP_UNARY_NUM_HARDSWISH 112 // x * fmax(0, fmin(1, x/6 + 0.5))
26
+ #define OP_UNARY_NUM_HARDSIGMOID 113 // fmax(0, fmin(1, x/6 + 0.5))
27
+ #define OP_UNARY_NUM_EXP 114 // exp(x)
28
+ #define OP_UNARY_NUM_SOFTPLUS 115 // log(1 + exp(x)), or x itself where x > 20
29
+ #define OP_UNARY_NUM_EXPM1 116 // exp(x) - 1 (computed without expm1())
30
+ #define OP_UNARY_NUM_FLOOR 117 // floor(x)
31
+ #define OP_UNARY_NUM_CEIL 118 // ceil(x)
32
+ #define OP_UNARY_NUM_ROUND 119 // round(x)
33
+ #define OP_UNARY_NUM_TRUNC 120 // trunc(x)
34
+ #define OP_UNARY_NUM_XIELU 121 // piecewise op parameterized by args.scale, args.bias, args.slope and args.val
35
+ // Argument block passed by constant reference to the unary kernels in this file. NOTE(review): the layout presumably has to match a host-side struct - verify before reordering fields.
36
+ struct ds4_metal_args_unary {
37
+ int32_t ne00; // not read by the kernels visible in this file
38
+ int32_t ne01; // divisor used to split tgpig.x into a row index (i01) and a column-chunk index (k0)
39
+ int32_t ne02; // not read by the kernels visible in this file
40
+ int32_t ne03; // not read by the kernels visible in this file
41
+ uint64_t nb00; // not read by the kernels visible in this file
42
+ uint64_t nb01; // byte offset added to src0 per unit of i01
43
+ uint64_t nb02; // byte offset added to src0 per unit of i02 (tgpig.y)
44
+ uint64_t nb03; // byte offset added to src0 per unit of i03 (tgpig.z)
45
+ int32_t ne0; // exclusive upper bound for the element index i0 on the strided path
46
+ int32_t ne1; // not read by the kernels visible in this file
47
+ int32_t ne2; // not read by the kernels visible in this file
48
+ int32_t ne3; // not read by the kernels visible in this file
49
+ uint64_t nb0; // not read by the kernels visible in this file
50
+ uint64_t nb1; // byte offset added to dst per unit of i01
51
+ uint64_t nb2; // byte offset added to dst per unit of i02 (tgpig.y)
52
+ uint64_t nb3; // byte offset added to dst per unit of i03 (tgpig.z)
53
+ float slope; // LEAKY_RELU: multiplier for x <= 0; XIELU: multiplier of the (exp - 1 - x) term
54
+ float scale; // SCALE: multiplier; XIELU: coefficient of x*x on the positive branch
55
+ float bias; // SCALE: addend; XIELU: coefficient of the linear x term on both branches
56
+ float val; // FILL: value written; XIELU: upper bound applied to x before exp()
57
+ float min; // CLAMP: lower bound
58
+ float max; // CLAMP: upper bound
59
+ };
60
+ // Scalar constants used by the GELU variants and by erf_approx.
61
+ constant float GELU_COEF_A = 0.044715f; // coefficient of the x*x term inside the tanh-based GELU
62
+ constant float GELU_QUICK_COEF = -1.702f; // multiplies x inside exp() in the quick GELU
63
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; // sqrt(2/pi)
64
+ constant float SQRT_2_INV = 0.70710678118654752440084436210484f; // 1/sqrt(2)
65
+
66
+ // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
67
+ // ref: https://www.johndcook.com/blog/python_erf/
68
+ constant float p_erf = 0.3275911f; // p in t = 1 / (1 + p*|x|)
69
+ constant float a1_erf = 0.254829592f; // polynomial coefficients a1..a5, applied in Horner form by erf_approx
70
+ constant float a2_erf = -0.284496736f;
71
+ constant float a3_erf = 1.421413741f;
72
+ constant float a4_erf = -1.453152027f;
73
+ constant float a5_erf = 1.061405429f;
74
+ // Approximates erf(x) as sign(x) * (1 - P(t) * exp(-x*x)), where t = 1 / (1 + p_erf*|x|) and P is a degree-5 polynomial in t with no constant term.
75
+ template<typename T>
76
+ inline T erf_approx(T x) {
77
+ T sign_x = sign(x); // remember the sign; the polynomial below is evaluated on |x| only
78
+ x = fabs(x);
79
+ T t = 1.0f / (1.0f + p_erf * x);
80
+ T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); // Horner evaluation of a1*t + ... + a5*t^5
81
+ return sign_x * y; // reapply the sign so the result is odd in x
82
+ }
83
+ // ELU helper: x for x > 0, exp(x) - 1 otherwise. Only the float and float4 specializations below are defined in this file.
84
+ template<typename T> T elu_approx(T x);
85
+
86
+ template<> inline float elu_approx<float>(float x) {
87
+ return (x > 0.f) ? x : (exp(x) - 1);
88
+ }
89
+ // float4 variant: the same scalar formula applied to each of the four components independently.
90
+ template<> inline float4 elu_approx<float4>(float4 x) {
91
+ float4 res;
92
+
93
+ res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
94
+ res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
95
+ res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
96
+ res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
97
+
98
+ return res;
99
+ }
100
+ // Function constants read by kernel_unary_impl, bound at indices FC_UNARY + 0 and FC_UNARY + 1.
101
+ constant short FC_unary_op [[function_constant(FC_UNARY + 0)]]; // compared against the OP_UNARY_NUM_* values to pick the op
102
+ constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]]; // true selects the flat path: no stride offsets and no ne0 bounds check
103
+
104
+ // Generic unary elementwise op selected by function constant. DS4 only uses a
105
+ // small subset in inference, mainly sigmoid, SiLU, softplus, sqrt, clamp,
106
+ // scale, and fill. T0 is the source element type, T the destination element type, and TC the type the math is evaluated in.
107
+ template <typename T0, typename T, typename TC>
108
+ kernel void kernel_unary_impl(
109
+ constant ds4_metal_args_unary & args,
110
+ device const char * src0,
111
+ device char * dst,
112
+ uint3 tgpig[[threadgroup_position_in_grid]],
113
+ ushort3 tpitg[[thread_position_in_threadgroup]],
114
+ ushort3 ntg[[threads_per_threadgroup]]) {
115
+ #define FC_OP FC_unary_op
116
+ #define FC_CNT FC_unary_cnt
117
+ // src0 and dst arrive as byte pointers; they are offset in bytes first and only then reinterpreted as T0 / T.
118
+ device const T0 * src0_ptr;
119
+ device T * dst_ptr;
120
+
121
+ int i0;
122
+
123
+ if (FC_CNT) {
124
+ i0 = tgpig.x; // flat path: the element index is the threadgroup x position alone (tpitg is not used) - presumably dispatched one thread per threadgroup, confirm against the host
125
+
126
+ src0_ptr = (device const T0 *) (src0);
127
+ dst_ptr = (device T *) (dst);
128
+ } else {
129
+ const int i03 = tgpig.z;
130
+ const int i02 = tgpig.y;
131
+ const int k0 = tgpig.x/args.ne01;
132
+ const int i01 = tgpig.x - k0*args.ne01;
133
+ // tgpig.x packs two indices: i01 = tgpig.x % ne01 selects the row, k0 = tgpig.x / ne01 selects a chunk of ntg.x elements within it.
134
+ i0 = k0*ntg.x + tpitg.x;
135
+
136
+ src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
137
+ dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
138
+ }
139
+
140
+ {
141
+ if (!FC_CNT) {
142
+ if (i0 >= args.ne0) {
143
+ return;
144
+ }
145
+ }
146
+ // The bounds check above applies to the strided path only. The input is loaded for every op, including FILL, which does not use it.
147
+ const TC x = (TC) src0_ptr[i0];
148
+ // The OP_UNARY_NUM_* values are distinct, so at most one of the independent if-blocks below matches a given FC_OP.
149
+ if (FC_OP == OP_UNARY_NUM_SCALE) {
150
+ dst_ptr[i0] = (T) (args.scale * x + args.bias);
151
+ }
152
+
153
+ if (FC_OP == OP_UNARY_NUM_FILL) {
154
+ dst_ptr[i0] = (T) args.val;
155
+ }
156
+
157
+ if (FC_OP == OP_UNARY_NUM_CLAMP) {
158
+ dst_ptr[i0] = (T) clamp(x, args.min, args.max);
159
+ }
160
+
161
+ if (FC_OP == OP_UNARY_NUM_SQR) {
162
+ dst_ptr[i0] = (T) (x * x);
163
+ }
164
+
165
+ if (FC_OP == OP_UNARY_NUM_SQRT) {
166
+ dst_ptr[i0] = (T) sqrt(x);
167
+ }
168
+
169
+ if (FC_OP == OP_UNARY_NUM_SIN) {
170
+ dst_ptr[i0] = (T) sin(x);
171
+ }
172
+
173
+ if (FC_OP == OP_UNARY_NUM_COS) {
174
+ dst_ptr[i0] = (T) cos(x);
175
+ }
176
+
177
+ if (FC_OP == OP_UNARY_NUM_LOG) {
178
+ dst_ptr[i0] = (T) log(x);
179
+ }
180
+ // Leaky ReLU is written branch-free: the two comparison results converted to TC act as 0/1 masks.
181
+ if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
182
+ dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
183
+ }
184
+
185
+ if (FC_OP == OP_UNARY_NUM_TANH) {
186
+ dst_ptr[i0] = (T) precise::tanh(x);
187
+ }
188
+
189
+ if (FC_OP == OP_UNARY_NUM_RELU) {
190
+ dst_ptr[i0] = (T) fmax(0, x);
191
+ }
192
+
193
+ if (FC_OP == OP_UNARY_NUM_SIGMOID) {
194
+ dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
195
+ }
196
+
197
+ if (FC_OP == OP_UNARY_NUM_GELU) {
198
+ dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
199
+ }
200
+
201
+ if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
202
+ dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
203
+ }
204
+
205
+ if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
206
+ dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
207
+ }
208
+
209
+ if (FC_OP == OP_UNARY_NUM_SILU) {
210
+ dst_ptr[i0] = (T) (x / (1 + exp(-x)));
211
+ }
212
+
213
+ if (FC_OP == OP_UNARY_NUM_ELU) {
214
+ dst_ptr[i0] = (T) elu_approx(x);
215
+ }
216
+
217
+ if (FC_OP == OP_UNARY_NUM_NEG) {
218
+ dst_ptr[i0] = (T) -x;
219
+ }
220
+
221
+ if (FC_OP == OP_UNARY_NUM_ABS) {
222
+ dst_ptr[i0] = (T) fabs(x);
223
+ }
224
+
225
+ if (FC_OP == OP_UNARY_NUM_SGN) {
226
+ dst_ptr[i0] = T(x > 0) - T(x < 0);
227
+ }
228
+
229
+ if (FC_OP == OP_UNARY_NUM_STEP) {
230
+ dst_ptr[i0] = T(x > 0);
231
+ }
232
+
233
+ if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
234
+ dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
235
+ }
236
+
237
+ if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
238
+ dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
239
+ }
240
+
241
+ if (FC_OP == OP_UNARY_NUM_EXP) {
242
+ dst_ptr[i0] = (T) exp(x);
243
+ }
244
+ // Softplus passes x through unchanged where x > 20 and uses log(1 + exp(x)) elsewhere.
245
+ if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
246
+ dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
247
+ }
248
+
249
+ if (FC_OP == OP_UNARY_NUM_EXPM1) {
250
+ // Metal target profiles used here do not all expose expm1(); this
251
+ // generic unary branch is not used by the DS4 inference graph.
252
+ dst_ptr[i0] = (T) (exp(x) - 1);
253
+ }
254
+
255
+ if (FC_OP == OP_UNARY_NUM_FLOOR) {
256
+ dst_ptr[i0] = (T) floor(x);
257
+ }
258
+
259
+ if (FC_OP == OP_UNARY_NUM_CEIL) {
260
+ dst_ptr[i0] = (T) ceil(x);
261
+ }
262
+
263
+ if (FC_OP == OP_UNARY_NUM_ROUND) {
264
+ dst_ptr[i0] = (T) round(x);
265
+ }
266
+
267
+ if (FC_OP == OP_UNARY_NUM_TRUNC) {
268
+ dst_ptr[i0] = (T) trunc(x);
269
+ }
270
+ // xIELU: x > 0 gives scale*x*x + bias*x; otherwise slope*(exp(min(x, val)) - 1 - x) + bias*x. The two branches are blended with a 0/1 gate.
271
+ if (FC_OP == OP_UNARY_NUM_XIELU) {
272
+ const TC xi = x;
273
+ const TC gate = TC(xi > TC(0.0f));
274
+ const TC clamped = fmin(xi, TC(args.val));
275
+ const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
276
+ const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
277
+ dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
278
+ }
279
+ }
280
+ // The local aliases are undefined here so FC_OP / FC_CNT do not leak past this kernel.
281
+ #undef FC_OP
282
+ #undef FC_CNT
283
+ }
284
+ // Function type of the unary kernel, taken from the <float, float, float> instantiation and reused by every explicit instantiation in this file.
285
+ typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
286
+
287
+ // Decode router probability transform. The generic path applies softplus and
288
+ // sqrt as two elementwise kernels; DS4 decode always transforms one 256-wide
289
+ // expert-logit row, so this vectorized kernel does both in one pass. Only the row index derived from tgpig.x is used for addressing; tgpig.y and tgpig.z are ignored.
290
+ kernel void kernel_dsv4_softplus_sqrt_f32_4(
291
+ constant ds4_metal_args_unary & args,
292
+ device const char *src,
293
+ device char *dst,
294
+ uint3 tgpig [[threadgroup_position_in_grid]],
295
+ ushort3 tpitg [[thread_position_in_threadgroup]],
296
+ ushort3 ntg [[threads_per_threadgroup]]) {
297
+ const int k0 = tgpig.x/args.ne01; // chunk of ntg.x elements within the row
298
+ const int i01 = tgpig.x - k0*args.ne01; // row index, i.e. tgpig.x % ne01
299
+ const int i0 = k0*ntg.x + tpitg.x; // index in float4 units (s and d below are float4 pointers)
300
+ if (i0 >= args.ne0) return; // NOTE(review): i0 counts float4 elements, so ne0 is presumably passed as a float4 count - verify against the host
301
+
302
+ device const float4 *s = (device const float4 *)(src + i01*args.nb01);
303
+ device float4 *d = (device float4 *)(dst + i01*args.nb1);
304
+ const float4 x = s[i0];
305
+ const float4 sp = select(log(1.0f + exp(x)), x, x > 20.0f); // softplus per component; x itself where x > 20
306
+ d[i0] = sqrt(sp);
307
+ }
308
+
309
+ // Host-visible unary variants. Function constants select the actual DS4 op.
310
+ template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>; // scalar float in, float out, float math
311
+ template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>; // float4 in, float4 out, float4 math
312
+ template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>; // half in, half out, math evaluated in float