@lgrammel/ds4-provider 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. package/README.md +96 -0
  2. package/binding.gyp +75 -0
  3. package/dist/ds4-language-model.d.ts +71 -0
  4. package/dist/ds4-language-model.d.ts.map +1 -0
  5. package/dist/ds4-language-model.js +888 -0
  6. package/dist/ds4-language-model.js.map +1 -0
  7. package/dist/ds4-provider.d.ts +13 -0
  8. package/dist/ds4-provider.d.ts.map +1 -0
  9. package/dist/ds4-provider.js +20 -0
  10. package/dist/ds4-provider.js.map +1 -0
  11. package/dist/index.d.ts +4 -0
  12. package/dist/index.d.ts.map +1 -0
  13. package/dist/index.js +4 -0
  14. package/dist/index.js.map +1 -0
  15. package/dist/native-binding.d.ts +42 -0
  16. package/dist/native-binding.d.ts.map +1 -0
  17. package/dist/native-binding.js +157 -0
  18. package/dist/native-binding.js.map +1 -0
  19. package/ds4/LICENSE +22 -0
  20. package/ds4/ds4.c +18268 -0
  21. package/ds4/ds4.h +196 -0
  22. package/ds4/ds4_gpu.h +804 -0
  23. package/ds4/ds4_metal.m +14657 -0
  24. package/ds4/metal/argsort.metal +266 -0
  25. package/ds4/metal/bin.metal +192 -0
  26. package/ds4/metal/concat.metal +62 -0
  27. package/ds4/metal/cpy.metal +57 -0
  28. package/ds4/metal/dense.metal +1121 -0
  29. package/ds4/metal/dsv4_hc.metal +861 -0
  30. package/ds4/metal/dsv4_kv.metal +227 -0
  31. package/ds4/metal/dsv4_misc.metal +1088 -0
  32. package/ds4/metal/dsv4_rope.metal +155 -0
  33. package/ds4/metal/flash_attn.metal +1426 -0
  34. package/ds4/metal/get_rows.metal +54 -0
  35. package/ds4/metal/glu.metal +36 -0
  36. package/ds4/metal/moe.metal +1737 -0
  37. package/ds4/metal/norm.metal +153 -0
  38. package/ds4/metal/repeat.metal +52 -0
  39. package/ds4/metal/set_rows.metal +55 -0
  40. package/ds4/metal/softmax.metal +241 -0
  41. package/ds4/metal/sum_rows.metal +102 -0
  42. package/ds4/metal/unary.metal +312 -0
  43. package/native/binding.cpp +621 -0
  44. package/package.json +66 -0
  45. package/scripts/postinstall.cjs +13 -0
  46. package/scripts/vendor-ds4.cjs +67 -0
package/ds4/metal/norm.metal
@@ -0,0 +1,153 @@
+ struct ds4_metal_args_norm {
+     int32_t ne00;
+     int32_t ne00_t;
+     uint64_t nb1;
+     uint64_t nb2;
+     uint64_t nb3;
+     float eps;
+     int32_t nef1[3];
+     int32_t nef2[3];
+     int32_t nef3[3];
+     uint64_t nbf1[3];
+     uint64_t nbf2[3];
+     uint64_t nbf3[3];
+ };
+
+ // RMSNorm over one activation row, optionally fusing the learned weight
+ // multiply. DS4 calls this before attention, before the FFN, and for plain
+ // diagnostics that need normalized but unweighted rows.
+ template <typename T, short F>
+ kernel void kernel_rms_norm_fuse_impl(
+         constant ds4_metal_args_norm & args,
+         device const char * src0,
+         device const char * src1_0,
+         device const char * src1_1,
+         device char * dst,
+         threadgroup float * shmem_f32 [[threadgroup(0)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         ushort3 tpitg[[thread_position_in_threadgroup]],
+         ushort sgitg[[simdgroup_index_in_threadgroup]],
+         ushort tiisg[[thread_index_in_simdgroup]],
+         ushort3 ntg[[threads_per_threadgroup]]) {
+     if (sgitg == 0) {
+         shmem_f32[tiisg] = 0.0f;
+     }
+
+     const int i01 = tgpig.x;
+     const int i02 = tgpig.y;
+     const int i03 = tgpig.z;
+
+     device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+     device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+     device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
+
+     float sumf = 0.0f;
+
+     // parallel sum
+     for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
+         sumf += dot(x[i00], x[i00]);
+     }
+     sumf = simd_sum(sumf);
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     if (tiisg == 0) {
+         shmem_f32[sgitg] = sumf;
+     }
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     sumf = shmem_f32[tiisg];
+     sumf = simd_sum(sumf);
+
+     const float mean  = sumf/args.ne00;
+     const float scale = 1.0f/sqrt(mean + args.eps);
+
+     device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+     for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
+         if (F == 1) {
+             y[i00] = (x[i00]*scale);
+         }
+         if (F == 2) {
+             y[i00] = (x[i00]*scale)*f0[i00];
+         }
+         if (F == 3) {
+             y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
+         }
+     }
+ }
+
+ typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
+
+ // Host-visible RMSNorm variants: plain norm and norm multiplied by weight.
+ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
+ template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
+
+ struct ds4_metal_args_qkv_rms_norm {
+     int32_t q_n;
+     int32_t q_n4;
+     int32_t kv_n;
+     int32_t kv_n4;
+     uint64_t q_row_stride;
+     uint64_t kv_row_stride;
+     float eps;
+ };
+
+ // Normalizes DS4's q-lora row and KV row in one dispatch. The two reductions
+ // deliberately mirror kernel_rms_norm_mul_f32_4: Q uses the full 256-thread
+ // row shape for 1024 floats, while KV only has work in the first 128 lanes for
+ // its 512 floats. This keeps the q/kv normalization math aligned with the
+ // standalone kernels while removing one tiny launch from the attention setup.
+ kernel void kernel_dsv4_qkv_rms_norm_f32_4(
+         constant ds4_metal_args_qkv_rms_norm & args,
+         device const float4 * q_src,
+         device const float4 * q_weight,
+         device       float4 * q_dst,
+         device const float4 * kv_src,
+         device const float4 * kv_weight,
+         device       float4 * kv_dst,
+         threadgroup float * shmem_f32 [[threadgroup(0)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         ushort3 tpitg[[thread_position_in_threadgroup]],
+         ushort sgitg[[simdgroup_index_in_threadgroup]],
+         ushort tiisg[[thread_index_in_simdgroup]],
+         ushort3 ntg[[threads_per_threadgroup]]) {
+     if (sgitg == 0) {
+         shmem_f32[tiisg] = 0.0f;
+     }
+
+     const uint row     = tgpig.x;
+     const bool kv_task = tgpig.y != 0;
+     const int  n  = kv_task ? args.kv_n  : args.q_n;
+     const int  n4 = kv_task ? args.kv_n4 : args.q_n4;
+     const uint64_t row_stride4 = (kv_task ? args.kv_row_stride : args.q_row_stride) / sizeof(float4);
+
+     device const float4 * x = kv_task ? kv_src + row * row_stride4 : q_src + row * row_stride4;
+     device const float4 * w = kv_task ? kv_weight : q_weight;
+     device       float4 * y = kv_task ? kv_dst + row * row_stride4 : q_dst + row * row_stride4;
+
+     float sumf = 0.0f;
+     for (int i = tpitg.x; i < n4; i += ntg.x) {
+         const float4 v = x[i];
+         sumf += dot(v, v);
+     }
+     sumf = simd_sum(sumf);
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     if (tiisg == 0) {
+         shmem_f32[sgitg] = sumf;
+     }
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     sumf = shmem_f32[tiisg];
+     sumf = simd_sum(sumf);
+
+     const float scale = rsqrt(sumf / float(n) + args.eps);
+
+     for (int i = tpitg.x; i < n4; i += ntg.x) {
+         y[i] = (x[i] * scale) * w[i];
+     }
+ }
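To make the fusion modes concrete, here is a minimal CPU reference of what the `F` template parameter selects; the fused `kernel_dsv4_qkv_rms_norm_f32_4` applies the same per-row math (its `F == 2` form) to the Q and KV rows. This is an illustrative sketch, not package code: `rms_norm_fuse_ref` is a hypothetical name, and it assumes contiguous float rows instead of the strided device buffers above.

```cpp
#include <cmath>
#include <cstddef>

// CPU reference for the three fusion modes selected by F.
// f1 is only read when F == 3.
void rms_norm_fuse_ref(const float *x, const float *f0, const float *f1,
                       float *y, size_t n, float eps, int F) {
    float sumsq = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sumsq += x[i] * x[i];            // the kernel's "parallel sum", serialized
    }
    const float scale = 1.0f / std::sqrt(sumsq / n + eps);
    for (size_t i = 0; i < n; ++i) {
        if (F == 1) y[i] = x[i] * scale;                  // plain norm
        if (F == 2) y[i] = x[i] * scale * f0[i];          // norm * weight
        if (F == 3) y[i] = x[i] * scale * f0[i] + f1[i];  // norm * weight + bias
    }
}
```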
package/ds4/metal/repeat.metal
@@ -0,0 +1,52 @@
+ // DS4 Metal repeat kernel used for HC embedding expansion.
+
+ struct ds4_metal_args_repeat {
+     int32_t ne00;
+     int32_t ne01;
+     int32_t ne02;
+     int32_t ne03;
+     uint64_t nb00;
+     uint64_t nb01;
+     uint64_t nb02;
+     uint64_t nb03;
+     int32_t ne0;
+     int32_t ne1;
+     int32_t ne2;
+     int32_t ne3;
+     uint64_t nb0;
+     uint64_t nb1;
+     uint64_t nb2;
+     uint64_t nb3;
+ };
+
+ // Repeats a source row into the HC channel dimension. DS4 uses this when the
+ // token embedding has to become an HC activation block before layer 0.
+ template<typename T>
+ kernel void kernel_repeat(
+         constant ds4_metal_args_repeat & args,
+         device const char * src0,
+         device char * dst,
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         ushort3 tpitg[[thread_position_in_threadgroup]],
+         ushort3 ntg[[threads_per_threadgroup]]) {
+     const int i3 = tgpig.z;
+     const int i2 = tgpig.y;
+     const int i1 = tgpig.x;
+
+     const int i03 = i3%args.ne03;
+     const int i02 = i2%args.ne02;
+     const int i01 = i1%args.ne01;
+
+     device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+     device       char * dst_ptr  = dst  + i3*args.nb3   + i2*args.nb2   + i1*args.nb1;
+
+     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+         const int i00 = i0%args.ne00;
+         *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
+     }
+ }
+
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+ // Host-visible F32 repeat used for HC expansion of embeddings.
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
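The kernel's broadcast rule is plain modulo indexing on every axis: each destination coordinate wraps into the (smaller) source extent. A minimal 2-D CPU sketch of the same rule, under assumed dense row-major float buffers (`repeat_ref` is a hypothetical helper, not package code):

```cpp
#include <cstdint>

// CPU reference for the modulo-broadcast indexing of kernel_repeat,
// reduced to two dimensions for clarity.
void repeat_ref(const float *src, float *dst,
                int64_t ne00, int64_t ne01,   // source extents (dims 0, 1)
                int64_t ne0,  int64_t ne1) {  // destination extents
    for (int64_t i1 = 0; i1 < ne1; ++i1) {
        for (int64_t i0 = 0; i0 < ne0; ++i0) {
            // mirrors `i01 = i1 % args.ne01` and `i00 = i0 % args.ne00` above
            dst[i1*ne0 + i0] = src[(i1 % ne01)*ne00 + (i0 % ne00)];
        }
    }
}
```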
package/ds4/metal/set_rows.metal
@@ -0,0 +1,55 @@
+ // DS4 Metal set-rows kernel used for KV writes.
+
+ struct ds4_metal_args_set_rows {
+     int32_t nk0;
+     int32_t ne01;
+     uint64_t nb01;
+     uint64_t nb02;
+     uint64_t nb03;
+     int32_t ne11;
+     int32_t ne12;
+     uint64_t nb10;
+     uint64_t nb11;
+     uint64_t nb12;
+     uint64_t nb1;
+     uint64_t nb2;
+     uint64_t nb3;
+ };
+
+ // Scatters rows into the KV cache by token position. DS4 uses this after Q/K/V
+ // preparation so decode and later prefill chunks can attend to previous tokens.
+ template<typename T, typename TI>
+ kernel void kernel_set_rows_f(
+         constant ds4_metal_args_set_rows & args,
+         device const char * src0,
+         device const char * src1,
+         device float * dst,
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint tiitg[[thread_index_in_threadgroup]],
+         uint3 tptg[[threads_per_threadgroup]]) {
+     const int32_t i03 = tgpig.z;
+     const int32_t i02 = tgpig.y;
+
+     const int32_t i12 = i03%args.ne12;
+     const int32_t i11 = i02%args.ne11;
+
+     const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+     if (i01 >= args.ne01) {
+         return;
+     }
+
+     const int32_t i10 = i01;
+     const TI i1 = ((const device TI *) (src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+     device       T     * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+     const device float * src_row = (const device float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+     for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+         dst_row[ind] = (T) src_row[ind];
+     }
+ }
+
+ typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
+
+ // Host-visible F32/I32 scatter variant used by KV-cache writes.
+ template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
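In effect `src1` supplies one destination row id per source row, and each source row is copied (with a float-to-T cast) into that slot of the cache. A minimal CPU sketch of the same scatter, under assumed dense 2-D layouts and with `T = float`, `int32_t` indices as in the instantiated variant (`set_rows_ref` is an illustrative name, not package code):

```cpp
#include <cstdint>

// CPU reference for the row scatter in kernel_set_rows_f: row i01 of src
// lands at row row_idx[i01] of dst; both rows are nk0 floats long.
void set_rows_ref(const float *src, const int32_t *row_idx,
                  float *dst, int64_t n_rows, int64_t nk0) {
    for (int64_t i01 = 0; i01 < n_rows; ++i01) {
        const int32_t i1 = row_idx[i01];   // token position in the cache
        for (int64_t k = 0; k < nk0; ++k) {
            dst[i1*nk0 + k] = src[i01*nk0 + k];
        }
    }
}
```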
package/ds4/metal/softmax.metal
@@ -0,0 +1,241 @@
+ // DS4 Metal softmax kernel used by the compressor pooling compatibility path.
+ // The single-compressed-row path is intentionally left as soft_max -> mul ->
+ // sum_rows instead of using the fused dsv4_softmax_pool kernel.
+
+ struct ds4_metal_args_soft_max {
+     int32_t ne00;
+     int32_t ne01;
+     int32_t ne02;
+     uint64_t nb01;
+     uint64_t nb02;
+     uint64_t nb03;
+     int32_t ne11;
+     int32_t ne12;
+     int32_t ne13;
+     uint64_t nb11;
+     uint64_t nb12;
+     uint64_t nb13;
+     uint64_t nb1;
+     uint64_t nb2;
+     uint64_t nb3;
+     float scale;
+     float max_bias;
+     float m0;
+     float m1;
+     int32_t n_head_log2;
+ };
+
+ // Row softmax for score matrices. DS4 uses it in the literal one-compressor-row
+ // path where preserving the original graph operation boundary avoids drift.
+ template<typename T>
+ kernel void kernel_soft_max(
+         constant ds4_metal_args_soft_max & args,
+         device const char * src0,
+         device const char * src1,
+         device const char * src2,
+         device char * dst,
+         threadgroup float * buf [[threadgroup(0)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint3 tpitg[[thread_position_in_threadgroup]],
+         uint sgitg[[simdgroup_index_in_threadgroup]],
+         uint tiisg[[thread_index_in_simdgroup]],
+         uint3 tptg[[threads_per_threadgroup]]) {
+     const int32_t i03 = tgpig.z;
+     const int32_t i02 = tgpig.y;
+     const int32_t i01 = tgpig.x;
+
+     const int32_t i13 = i03%args.ne13;
+     const int32_t i12 = i02%args.ne12;
+     const int32_t i11 = i01;
+
+     device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+     device const T     * pmask = src1 != src0 ? (device const T *) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+     device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
+     device       float * pdst  = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
+
+     float slope = 1.0f;
+
+     if (args.max_bias > 0.0f) {
+         const int32_t h = i02;
+
+         const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+         const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
+
+         slope = pow(base, exp);
+     }
+
+     float lmax = psrc2 ? psrc2[i02] : -INFINITY;
+
+     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
+         lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
+     }
+
+     float max_val = simd_max(lmax);
+     if (tptg.x > N_SIMDWIDTH) {
+         if (sgitg == 0) {
+             buf[tiisg] = -INFINITY;
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         if (tiisg == 0) {
+             buf[sgitg] = max_val;
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         max_val = buf[tiisg];
+         max_val = simd_max(max_val);
+     }
+
+     float lsum = 0.0f;
+     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
+         const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+         lsum += exp_psrc0;
+         pdst[i00] = exp_psrc0;
+     }
+
+     threadgroup_barrier(mem_flags::mem_none);
+
+     float sum = simd_sum(lsum);
+
+     if (tptg.x > N_SIMDWIDTH) {
+         if (sgitg == 0) {
+             buf[tiisg] = 0.0f;
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         if (tiisg == 0) {
+             buf[sgitg] = sum;
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         sum = buf[tiisg];
+         sum = simd_sum(sum);
+     }
+
+     if (psrc2) {
+         sum += exp(psrc2[i02] - max_val);
+     }
+
+     const float inv_sum = 1.0f/sum;
+
+     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
+         pdst[i00] *= inv_sum;
+     }
+ }
+
+ // Vectorized float4 row softmax for contiguous score rows whose length is a
+ // multiple of four; used by the same DS4 compressor/indexer graph path.
+ template<typename T>
+ kernel void kernel_soft_max_4(
+         constant ds4_metal_args_soft_max & args,
+         device const char * src0,
+         device const char * src1,
+         device const char * src2,
+         device char * dst,
+         threadgroup float * buf [[threadgroup(0)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint3 tpitg[[thread_position_in_threadgroup]],
+         uint sgitg[[simdgroup_index_in_threadgroup]],
+         uint tiisg[[thread_index_in_simdgroup]],
+         uint3 tptg[[threads_per_threadgroup]]) {
+     const int32_t i03 = tgpig.z;
+     const int32_t i02 = tgpig.y;
+     const int32_t i01 = tgpig.x;
+
+     const int32_t i13 = i03%args.ne13;
+     const int32_t i12 = i02%args.ne12;
+     const int32_t i11 = i01;
+
+     device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+     device const T      * pmask = src1 != src0 ? (device const T *) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+     device const float  * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
+     device       float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
+
+     float slope = 1.0f;
+
+     if (args.max_bias > 0.0f) {
+         const int32_t h = i02;
+
+         const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+         const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
+
+         slope = pow(base, exp);
+     }
+
+     float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
+
+     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
+         lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+     }
+
+     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+     float max_val = simd_max(lmax);
+     if (tptg.x > N_SIMDWIDTH) {
+         if (sgitg == 0) {
+             buf[tiisg] = -INFINITY;
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         if (tiisg == 0) {
+             buf[sgitg] = max_val;
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         max_val = buf[tiisg];
+         max_val = simd_max(max_val);
+     }
+
+     float4 lsum4 = 0.0f;
+     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
+         const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+         lsum4 += exp_psrc4;
+         pdst4[i00] = exp_psrc4;
+     }
+
+     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+     threadgroup_barrier(mem_flags::mem_none);
+
+     float sum = simd_sum(lsum);
+
+     if (tptg.x > N_SIMDWIDTH) {
+         if (sgitg == 0) {
+             buf[tiisg] = 0.0f;
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         if (tiisg == 0) {
+             buf[sgitg] = sum;
+         }
+
+         threadgroup_barrier(mem_flags::mem_threadgroup);
+
+         sum = buf[tiisg];
+         sum = simd_sum(sum);
+     }
+
+     if (psrc2) {
+         sum += exp(psrc2[i02] - max_val);
+     }
+
+     const float inv_sum = 1.0f/sum;
+
+     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
+         pdst4[i00] *= inv_sum;
+     }
+ }
+
+ typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
+ typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+ // Host-visible F32 softmax variants used by compressor pooling.
+ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
+ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
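Both variants implement the standard numerically stable softmax: scale the logits, add the (optionally ALiBi-sloped) mask, subtract the row max, exponentiate, normalize. The `psrc2` operand, when present, appears to contribute an extra per-head logit that joins both the max and the denominator without being written out. A minimal scalar CPU sketch of the same row computation (hypothetical helper, not package code; `sink` stands in for `psrc2` under that reading):

```cpp
#include <cmath>
#include <cstddef>

// CPU reference for one softmax row: y = softmax(x*scale + slope*mask),
// with an optional extra logit ("sink") folded into max and denominator.
void soft_max_ref(const float *x, const float *mask, const float *sink,
                  float *y, size_t n, float scale, float slope) {
    float max_val = sink ? *sink : -INFINITY;
    for (size_t i = 0; i < n; ++i) {
        const float v = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
        max_val = std::fmax(max_val, v);
    }
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        const float e = std::exp(x[i]*scale + (mask ? slope*mask[i] : 0.0f) - max_val);
        y[i] = e;
        sum += e;
    }
    if (sink) {
        sum += std::exp(*sink - max_val);  // mirrors `sum += exp(psrc2[i02] - max_val)`
    }
    for (size_t i = 0; i < n; ++i) {
        y[i] /= sum;
    }
}
```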
package/ds4/metal/sum_rows.metal
@@ -0,0 +1,102 @@
+ // DS4 Metal row-sum kernel.
+
+ #define FC_SUM_ROWS 1400
+
+ #define OP_SUM_ROWS_NUM_SUM_ROWS 10
+ #define OP_SUM_ROWS_NUM_MEAN     11
+
+ struct ds4_metal_args_sum_rows {
+     int64_t ne00;
+     int64_t ne01;
+     int64_t ne02;
+     int64_t ne03;
+     uint64_t nb00;
+     uint64_t nb01;
+     uint64_t nb02;
+     uint64_t nb03;
+     int64_t ne0;
+     int64_t ne1;
+     int64_t ne2;
+     int64_t ne3;
+     uint64_t nb0;
+     uint64_t nb1;
+     uint64_t nb2;
+     uint64_t nb3;
+ };
+
+ static inline float sum(float x) {
+     return x;
+ }
+
+ static inline float sum(float4 x) {
+     return x[0] + x[1] + x[2] + x[3];
+ }
+
+ constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
+
+ // Reduces each row to a sum or mean. DS4 mainly uses the sum form to preserve
+ // the compressor-pooling graph boundary in the single-compressor-row case.
+ template <typename T0, typename T>
+ kernel void kernel_sum_rows_impl(
+         constant ds4_metal_args_sum_rows & args,
+         device const char * src0,
+         device char * dst,
+         threadgroup char * shmem [[threadgroup(0)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         ushort3 tpitg[[thread_position_in_threadgroup]],
+         ushort sgitg[[simdgroup_index_in_threadgroup]],
+         ushort tiisg[[thread_index_in_simdgroup]],
+         ushort3 ntg[[threads_per_threadgroup]]) {
+ #define FC_OP FC_sum_rows_op
+
+     const int i3 = tgpig.z;
+     const int i2 = tgpig.y;
+     const int i1 = tgpig.x;
+
+     threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
+
+     if (sgitg == 0) {
+         shmem_t[tiisg] = 0.0f;
+     }
+
+     device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+     device       T  * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
+
+     T0 sumf = T0(0.0f);
+
+     for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+         sumf += src_row[i0];
+     }
+
+     sumf = simd_sum(sumf);
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     if (tiisg == 0) {
+         shmem_t[sgitg] = sumf;
+     }
+
+     threadgroup_barrier(mem_flags::mem_threadgroup);
+
+     sumf = shmem_t[tiisg];
+     sumf = simd_sum(sumf);
+
+     if (tpitg.x == 0) {
+         if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
+             if (is_same<float4, T0>::value) {
+                 dst_row[0] = sum(sumf) / (4*args.ne00);
+             } else {
+                 dst_row[0] = sum(sumf) / args.ne00;
+             }
+         } else {
+             dst_row[0] = sum(sumf);
+         }
+     }
+
+ #undef FC_OP
+ }
+
+ typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
+
+ // Host-visible F32 row reduction used by compressor pooling.
+ template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
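The `FC_sum_rows_op` function constant specializes the same kernel body into a row sum or a row mean at pipeline-creation time (note the mean divides by `4*ne00` for the float4 specialization, since `ne00` then counts float4 elements). A minimal CPU sketch of the two ops over a dense float row (`sum_rows_ref` and `RowOp` are illustrative names, not package code):

```cpp
#include <cstddef>

// Stands in for the OP_SUM_ROWS_NUM_SUM_ROWS / OP_SUM_ROWS_NUM_MEAN ids.
enum class RowOp { Sum, Mean };

// CPU reference for kernel_sum_rows_impl on one row of n floats.
float sum_rows_ref(const float *row, size_t n, RowOp op) {
    float s = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        s += row[i];
    }
    return op == RowOp::Mean ? s / n : s;
}
```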