@lgrammel/ds4-provider 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +96 -0
- package/binding.gyp +75 -0
- package/dist/ds4-language-model.d.ts +71 -0
- package/dist/ds4-language-model.d.ts.map +1 -0
- package/dist/ds4-language-model.js +888 -0
- package/dist/ds4-language-model.js.map +1 -0
- package/dist/ds4-provider.d.ts +13 -0
- package/dist/ds4-provider.d.ts.map +1 -0
- package/dist/ds4-provider.js +20 -0
- package/dist/ds4-provider.js.map +1 -0
- package/dist/index.d.ts +4 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +4 -0
- package/dist/index.js.map +1 -0
- package/dist/native-binding.d.ts +42 -0
- package/dist/native-binding.d.ts.map +1 -0
- package/dist/native-binding.js +157 -0
- package/dist/native-binding.js.map +1 -0
- package/ds4/LICENSE +22 -0
- package/ds4/ds4.c +18268 -0
- package/ds4/ds4.h +196 -0
- package/ds4/ds4_gpu.h +804 -0
- package/ds4/ds4_metal.m +14657 -0
- package/ds4/metal/argsort.metal +266 -0
- package/ds4/metal/bin.metal +192 -0
- package/ds4/metal/concat.metal +62 -0
- package/ds4/metal/cpy.metal +57 -0
- package/ds4/metal/dense.metal +1121 -0
- package/ds4/metal/dsv4_hc.metal +861 -0
- package/ds4/metal/dsv4_kv.metal +227 -0
- package/ds4/metal/dsv4_misc.metal +1088 -0
- package/ds4/metal/dsv4_rope.metal +155 -0
- package/ds4/metal/flash_attn.metal +1426 -0
- package/ds4/metal/get_rows.metal +54 -0
- package/ds4/metal/glu.metal +36 -0
- package/ds4/metal/moe.metal +1737 -0
- package/ds4/metal/norm.metal +153 -0
- package/ds4/metal/repeat.metal +52 -0
- package/ds4/metal/set_rows.metal +55 -0
- package/ds4/metal/softmax.metal +241 -0
- package/ds4/metal/sum_rows.metal +102 -0
- package/ds4/metal/unary.metal +312 -0
- package/native/binding.cpp +621 -0
- package/package.json +66 -0
- package/scripts/postinstall.cjs +13 -0
- package/scripts/vendor-ds4.cjs +67 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
// Argument block for the single-pass argsort kernel (kernel_argsort_f32_i32).
// Layout must match the host-side struct that fills the argument buffer.
struct ds4_metal_args_argsort {
    int32_t  ne00;   // src0 extent along dim 0 (row length being sorted)
    int32_t  ne01;   // src0 extent along dim 1
    int32_t  ne02;   // src0 extent along dim 2
    int32_t  ne03;   // src0 extent along dim 3
    uint64_t nb00;   // src0 byte stride along dim 0
    uint64_t nb01;   // src0 byte stride along dim 1
    uint64_t nb02;   // src0 byte stride along dim 2
    uint64_t nb03;   // src0 byte stride along dim 3
    int32_t  ne0;    // dst extent along dim 0
    int32_t  ne1;    // dst extent along dim 1
    int32_t  ne2;    // dst extent along dim 2
    int32_t  ne3;    // dst extent along dim 3
    int32_t  top_k;  // number of leading sorted indices written per row
};
|
|
16
|
+
|
|
17
|
+
// Argument block for the argsort merge pass (kernel_argsort_merge_f32_i32).
// NOTE(review): ne00..ne03 are int64_t here but int32_t in
// ds4_metal_args_argsort — intentional-looking but worth confirming against
// the host-side struct definitions.
struct ds4_metal_args_argsort_merge {
    int64_t  ne00;   // src0 extent along dim 0
    int64_t  ne01;   // src0 extent along dim 1
    int64_t  ne02;   // src0 extent along dim 2
    int64_t  ne03;   // src0 extent along dim 3
    uint64_t nb00;   // src0 byte stride along dim 0
    uint64_t nb01;   // src0 byte stride along dim 1
    uint64_t nb02;   // src0 byte stride along dim 2
    uint64_t nb03;   // src0 byte stride along dim 3
    int32_t  ne0;    // total number of pre-sorted indices per row in tmp
    int32_t  ne1;    // dst extent along dim 1
    int32_t  ne2;    // dst extent along dim 2
    int32_t  ne3;    // dst extent along dim 3
    int32_t  top_k;  // number of merged indices kept per row
    int32_t  len;    // length of each sorted run being merged (pairs of runs of this length)
};
|
|
33
|
+
|
|
34
|
+
// Signature type for argsort kernel instantiations; used with [[host_name]]
// template exports so the host can look the pipelines up by name.
typedef void (argsort_t)(
        constant ds4_metal_args_argsort & args,
        device const char * src0,
        device int32_t * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]);
|
|
42
|
+
|
|
43
|
+
// Sort one float row into an index row. DS4 only exports the descending
// instance because router and indexer selection both need top-k order.
//
// One threadgroup sorts one padded segment of ntg.x candidate indices using an
// in-place bitonic network over threadgroup memory (shmem_i32 holds one int32
// per thread). A bitonic network assumes ntg.x is a power of two — presumably
// guaranteed by the host-side dispatch; confirm there.
// Indices >= args.ne00 act as padding and always sort to the end.
template<ds4_sort_order order>
kernel void kernel_argsort_f32_i32(
        constant ds4_metal_args_argsort & args,
        device const char * src0,
        device int32_t * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    // bitonic sort
    const int col = tpitg[0];              // this thread's slot in the segment
    const int ib  = tgpig[0] / args.ne01;  // which segment of the row

    const int i00 = ib*ntg.x;              // first source index of this segment
    const int i01 = tgpig[0] % args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);

    // initialize indices
    shmem_i32[col] = i00 + col;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Standard bitonic sorting network: outer loop grows the sorted run size k,
    // inner loop halves the compare-exchange distance j. The extra
    // `>= args.ne00` checks push out-of-range padding indices to the tail.
    for (int k = 2; k <= ntg.x; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    // ascending half of the bitonic pair
                    if (shmem_i32[col] >= args.ne00 ||
                        (shmem_i32[ixj] < args.ne00 && (order == DS4_SORT_ORDER_ASC ?
                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
                    ) {
                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                    }
                } else {
                    // descending half of the bitonic pair
                    if (shmem_i32[ixj] >= args.ne00 ||
                        (shmem_i32[col] < args.ne00 && (order == DS4_SORT_ORDER_ASC ?
                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
                    ) {
                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                    }
                }
            }

            // Barrier is outside the divergent `if (ixj > col)` body, so every
            // thread in the group reaches it each inner iteration.
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    const int64_t i0 = ib*args.top_k;  // where this segment's top-k lands in dst

    // copy the result to dst without the padding
    if (i0 + col < args.ne0 && col < args.top_k) {
        dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;

        dst[col] = shmem_i32[col];
    }
}

// Host-visible sort variant used by DS4 top-k selection.
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<DS4_SORT_ORDER_DESC>;
|
|
109
|
+
|
|
110
|
+
// Signature type for the argsort merge kernel instantiations; used with
// [[host_name]] exports so the host can look the pipelines up by name.
typedef void (argsort_merge_t)(
        constant ds4_metal_args_argsort_merge & args,
        device const char * src0,
        device const int32_t * tmp,
        device int32_t * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]);
|
|
118
|
+
|
|
119
|
+
// Merges sorted index runs produced by kernel_argsort_f32_i32. In the DS4 graph
// this finishes top-k over router or compressed-attention score rows.
//
// Each threadgroup merges one adjacent pair of sorted runs of length args.len
// (run pair `im` of row i01 of slice i02/i03). Threads split the merged output
// into `chunk`-sized ranges; each thread binary-searches its merge partition
// (co-rank) and then merges sequentially, so no cross-thread sync is needed.
// Only the first args.top_k merged elements are produced.
template<ds4_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
        constant ds4_metal_args_argsort_merge & args,
        device const char * src0,
        device const int32_t * tmp,
        device int32_t * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {

    const int im  = tgpig[0] / args.ne01;  // which run pair within the row
    const int i01 = tgpig[0] % args.ne01;
    const int i02 = tgpig[1];
    const int i03 = tgpig[2];

    const int start = im * (2 * args.len);  // row offset of the first run

    // Clamp run lengths at the row tail (the last pair may be short or empty).
    const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
    const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));

    const int total = len0 + len1;

    // tmp holds the pre-sorted index runs, laid out densely per row.
    device const int32_t * tmp0 = tmp + start
        + i01*args.ne0
        + i02*args.ne0*args.ne01
        + i03*args.ne0*args.ne01*args.ne02;

    device const int32_t * tmp1 = tmp0 + args.len;

    dst += start
        + i01*args.top_k
        + i02*args.top_k*args.ne01
        + i03*args.top_k*args.ne01*args.ne02;

    // Scores live in src0; the index runs are dereferenced through this row.
    device const float * src0_row = (device const float *)(src0
        + args.nb01*i01
        + args.nb02*i02
        + args.nb03*i03);

    if (total == 0) {
        return;
    }

    const int chunk = (total + ntg.x - 1) / ntg.x;  // output range per thread

    const int k0 = tpitg.x * chunk;
    const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);

    // Nothing to produce past top_k.
    if (k0 >= args.top_k) {
        return;
    }

    if (k0 >= total) {
        return;
    }

    // Co-rank bounds: i (elements taken from run 0) must satisfy
    // k0 - len1 <= i <= min(k0, len0).
    int low = k0 > len1 ? k0 - len1 : 0;
    int high = MIN(k0, len0);

    // binary-search partition (i, j) such that i + j = k
    while (low < high) {
        const int mid = (low + high) >> 1;

        const int32_t idx0 = tmp0[mid];
        const int32_t idx1 = tmp1[k0 - mid - 1];

        const float val0 = src0_row[idx0];
        const float val1 = src0_row[idx1];

        bool take_left;
        if (order == DS4_SORT_ORDER_ASC) {
            take_left = (val0 <= val1);
        } else {
            take_left = (val0 >= val1);
        }

        if (take_left) {
            low = mid + 1;
        } else {
            high = mid;
        }
    }

    int i = low;      // front of run 0
    int j = k0 - i;   // front of run 1

    // keep the merge fronts into registers
    int32_t idx0 = 0;
    float val0 = 0.0f;
    if (i < len0) {
        idx0 = tmp0[i];
        val0 = src0_row[idx0];
    }

    int32_t idx1 = 0;
    float val1 = 0.0f;
    if (j < len1) {
        idx1 = tmp1[j];
        val1 = src0_row[idx1];
    }

    // Sequential merge of this thread's output range [k0, k1).
    for (int k = k0; k < k1; ++k) {
        int32_t out_idx;

        if (i >= len0) {
            // run 0 exhausted: drain run 1
            while (k < k1) {
                dst[k++] = tmp1[j++];
            }
            break;
        } else if (j >= len1) {
            // run 1 exhausted: drain run 0
            while (k < k1) {
                dst[k++] = tmp0[i++];
            }
            break;
        } else {
            bool take_left;

            // Same tie-break direction as the partition search above, which is
            // what keeps per-thread output ranges mutually consistent.
            if (order == DS4_SORT_ORDER_ASC) {
                take_left = (val0 <= val1);
            } else {
                take_left = (val0 >= val1);
            }

            if (take_left) {
                out_idx = idx0;
                ++i;
                if (i < len0) {
                    idx0 = tmp0[i];
                    val0 = src0_row[idx0];
                }
            } else {
                out_idx = idx1;
                ++j;
                if (j < len1) {
                    idx1 = tmp1[j];
                    val1 = src0_row[idx1];
                }
            }
        }

        dst[k] = out_idx;
    }
}

// Host-visible merge variant used by DS4 top-k selection.
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<DS4_SORT_ORDER_DESC>;
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
// Argument block for the fused binary elementwise kernel
// (kernel_bin_fuse_impl). Layout must match the host-side struct.
struct ds4_metal_args_bin {
    int32_t  ne00;   // src0 extent along dim 0
    int32_t  ne01;   // src0 extent along dim 1
    int32_t  ne02;   // src0 extent along dim 2
    int32_t  ne03;   // src0 extent along dim 3
    uint64_t nb00;   // src0 byte stride along dim 0
    uint64_t nb01;   // src0 byte stride along dim 1
    uint64_t nb02;   // src0 byte stride along dim 2
    uint64_t nb03;   // src0 byte stride along dim 3
    int32_t  ne10;   // src1 extent along dim 0 (broadcast modulus when FC_bin_cb)
    int32_t  ne11;   // src1 extent along dim 1
    int32_t  ne12;   // src1 extent along dim 2
    int32_t  ne13;   // src1 extent along dim 3
    uint64_t nb10;   // src1 byte stride along dim 0
    uint64_t nb11;   // src1 byte stride along dim 1
    uint64_t nb12;   // src1 byte stride along dim 2
    uint64_t nb13;   // src1 byte stride along dim 3
    int32_t  ne0;    // dst extent along dim 0
    int32_t  ne1;    // dst extent along dim 1
    int32_t  ne2;    // dst extent along dim 2
    int32_t  ne3;    // dst extent along dim 3
    uint64_t nb0;    // dst byte stride along dim 0
    uint64_t nb1;    // dst byte stride along dim 1
    uint64_t nb2;    // dst byte stride along dim 2
    uint64_t nb3;    // dst byte stride along dim 3
    uint64_t offs;   // extra byte offset applied to both src0 and dst
    uint64_t o1[8];  // per-fused-operand byte offsets into src1 (up to 8)
};
|
|
29
|
+
|
|
30
|
+
// Function constants specializing kernel_bin_fuse_impl per pipeline:
//   FC_bin_op: operation selector (0 add, 1 sub, 2 mul, 3 div)
//   FC_bin_f : number of fused src1 operands (offsets in args.o1)
//   FC_bin_rb: take the flat "row" fast path instead of the strided path
//   FC_bin_cb: broadcast src1 along dim 0 via modulo args.ne10
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f  [[function_constant(FC_BIN + 1)]];
constant bool  FC_bin_rb [[function_constant(FC_BIN + 2)]];
constant bool  FC_bin_cb [[function_constant(FC_BIN + 3)]];
|
|
34
|
+
|
|
35
|
+
// Generic binary elementwise op with compile-time operation and broadcast
// modes. DS4 currently instantiates this as add, multiply, scalar multiply, and
// row division in the static graph.
//
// Two paths, chosen by the FC_bin_rb function constant:
//   - FC_RB: flat indexing computed purely from tgpig (no tpitg use) —
//     presumably dispatched with one thread per threadgroup; confirm against
//     the host-side dispatch.
//   - otherwise: one threadgroup per dst row, threads striding over dim 0,
//     with full per-dim strides and modulo broadcasting on dims 1..3.
// FC_F > 1 fuses several src1 operands (offsets args.o1[j]) into one pass.
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
        constant ds4_metal_args_bin & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_bin_op
#define FC_F  FC_bin_f
#define FC_RB FC_bin_rb
#define FC_CB FC_bin_cb

    if (FC_RB) {
        const uint i0 = tgpig.y*args.ne00 + tgpig.x;
        const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;

        device const T0 * src0_row = (device const T0 *) (src0);
        device T * dst_row = (device T *) (dst);

        if (FC_F == 1) {
            // Single operand: apply the selected op directly.
            device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);

            if (FC_OP == 0) {
                dst_row[i0] = src0_row[i0] + src1_row[i1];
            }

            if (FC_OP == 1) {
                dst_row[i0] = src0_row[i0] - src1_row[i1];
            }

            if (FC_OP == 2) {
                dst_row[i0] = src0_row[i0] * src1_row[i1];
            }

            if (FC_OP == 3) {
                dst_row[i0] = src0_row[i0] / src1_row[i1];
            }
        } else {
            // NOTE(review): accumulator is T0 here but T in the strided path
            // below; identical for the only exported f32/f32/f32 instantiation,
            // but worth unifying if mixed-type variants are ever added.
            T0 res = src0_row[i0];

            if (FC_OP == 0) {
                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                    res += ((device const T1 *) (src1 + args.o1[j]))[i1];
                }
            }

            if (FC_OP == 1) {
                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                    res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
                }
            }

            if (FC_OP == 2) {
                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                    res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
                }
            }

            if (FC_OP == 3) {
                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                    res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
                }
            }

            dst_row[i0] = res;
        }
    } else {
        const int i03 = tgpig.z;
        const int i02 = tgpig.y;
        const int i01 = tgpig.x;

        if (i01 >= args.ne01) {
            return;
        }

        // Modulo broadcasting of src1 over dims 1..3.
        const int i13 = i03%args.ne13;
        const int i12 = i02%args.ne12;
        const int i11 = i01%args.ne11;

        device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
        device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);

        if (FC_F == 1) {
            device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);

            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
                const int i10 = FC_CB ? i0%args.ne10 : i0;

                if (FC_OP == 0) {
                    dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
                }

                if (FC_OP == 1) {
                    dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
                }

                if (FC_OP == 2) {
                    dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
                }

                if (FC_OP == 3) {
                    dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
                }
            }
        } else {
            // Resolve all fused operand base pointers once, outside the loop.
            device const T1 * src1_ptr[8];
            FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
            }

            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
                const int i10 = FC_CB ? i0%args.ne10 : i0;

                T res = src0_ptr[i0];

                if (FC_OP == 0) {
                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                        res += src1_ptr[j][i10];
                    }
                }

                if (FC_OP == 1) {
                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                        res -= src1_ptr[j][i10];
                    }
                }

                if (FC_OP == 2) {
                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                        res *= src1_ptr[j][i10];
                    }
                }

                if (FC_OP == 3) {
                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
                        res /= src1_ptr[j][i10];
                    }
                }

                dst_ptr[i0] = res;
            }
        }
    }

#undef FC_OP
#undef FC_F
#undef FC_RB
#undef FC_CB
}

typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
// Host-visible F32 binary op; function constants specialize it per use site.
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
// DS4 Metal concat kernel used by the graph.
|
|
2
|
+
|
|
3
|
+
// Argument block for kernel_concat. Layout must match the host-side struct.
struct ds4_metal_args_concat {
    int32_t  ne00;   // src0 extent along dim 0
    int32_t  ne01;   // src0 extent along dim 1
    int32_t  ne02;   // src0 extent along dim 2
    int32_t  ne03;   // src0 extent along dim 3
    uint64_t nb00;   // src0 byte stride along dim 0
    uint64_t nb01;   // src0 byte stride along dim 1
    uint64_t nb02;   // src0 byte stride along dim 2
    uint64_t nb03;   // src0 byte stride along dim 3
    int32_t  ne10;   // src1 extent along dim 0
    int32_t  ne11;   // src1 extent along dim 1
    int32_t  ne12;   // src1 extent along dim 2
    int32_t  ne13;   // src1 extent along dim 3
    uint64_t nb10;   // src1 byte stride along dim 0
    uint64_t nb11;   // src1 byte stride along dim 1
    uint64_t nb12;   // src1 byte stride along dim 2
    uint64_t nb13;   // src1 byte stride along dim 3
    int32_t  ne0;    // dst extent along dim 0
    int32_t  ne1;    // dst extent along dim 1
    int32_t  ne2;    // dst extent along dim 2
    int32_t  ne3;    // dst extent along dim 3
    uint64_t nb0;    // dst byte stride along dim 0
    uint64_t nb1;    // dst byte stride along dim 1
    uint64_t nb2;    // dst byte stride along dim 2
    uint64_t nb3;    // dst byte stride along dim 3
    int32_t  dim;    // dimension (0..3) along which src0 and src1 are joined
};
|
|
30
|
+
|
|
31
|
+
// Concatenates two float tensors along one dimension. In DS4 this is a graph
// utility for assembling attention inputs with exactly the same tensor layout
// expected by the downstream kernels.
//
// One threadgroup handles one dst row (i1, i2, i3); threads stride over dim 0.
// Coordinates inside src0's box read from src0; everything else reads from
// src1, shifted back by src0's extent along the concat dimension.
kernel void kernel_concat(
        constant ds4_metal_args_concat & args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    // Coordinate shift applied when falling through to src1: only the concat
    // dimension is offset, by src0's extent along it.
    int shift[4] = {0, 0, 0, 0};
    switch (args.dim) {
        case 0:  shift[0] = args.ne00; break;
        case 1:  shift[1] = args.ne01; break;
        case 2:  shift[2] = args.ne02; break;
        default: shift[3] = args.ne03; break;
    }

    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
        const bool in_src0 = i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03;

        device const float * src_elem;
        if (in_src0) {
            src_elem = (device const float *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
        } else {
            src_elem = (device const float *)(src1 + (i3 - shift[3])*args.nb13 + (i2 - shift[2])*args.nb12 + (i1 - shift[1])*args.nb11 + (i0 - shift[0])*args.nb10);
        }

        device float * dst_elem = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

        *dst_elem = *src_elem;
    }
}
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
// Argument block for kernel_cpy_t_t. Layout must match the host-side struct.
struct ds4_metal_args_cpy {
    int64_t  nk0;    // NOTE(review): unused by the visible kernel body; presumably consumed host-side for dispatch sizing — confirm
    int64_t  ne00;   // src0 extent along dim 0
    int64_t  ne01;   // src0 extent along dim 1
    int64_t  ne02;   // src0 extent along dim 2
    int64_t  ne03;   // src0 extent along dim 3
    uint64_t nb00;   // src0 byte stride along dim 0
    uint64_t nb01;   // src0 byte stride along dim 1
    uint64_t nb02;   // src0 byte stride along dim 2
    uint64_t nb03;   // src0 byte stride along dim 3
    int64_t  ne0;    // dst extent along dim 0
    int64_t  ne1;    // dst extent along dim 1
    int64_t  ne2;    // dst extent along dim 2
    int64_t  ne3;    // dst extent along dim 3
    uint64_t nb0;    // dst byte stride along dim 0
    uint64_t nb1;    // dst byte stride along dim 1
    uint64_t nb2;    // dst byte stride along dim 2
    uint64_t nb3;    // dst byte stride along dim 3
};
|
|
20
|
+
|
|
21
|
+
// Typed copy/conversion between graph tensors. DS4 uses this for layout
// materialization and F32/F16 conversions at graph boundaries such as KV/cache
// packing and compressor pooling.
//
// Each thread converts exactly one element: the source row i01 is flattened to
// a linear element index, which is then unraveled into dst coordinates so src
// and dst may have different (but element-count-compatible) shapes.
template<typename T0, typename T1>
kernel void kernel_cpy_t_t(
        constant ds4_metal_args_cpy & args,
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    // Two dispatch shapes are supported: a "wide" 1D layout (ntg[1] == 1) where
    // tgpig[0] encodes both the row and a column window, and a 2D layout where
    // rows are packed ntg[1] per threadgroup.
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    // Linear element index of the start of source row (i01, i02, i03).
    const int64_t flat = i03*args.ne02*args.ne01*args.ne00
                       + i02*args.ne01*args.ne00
                       + i01*args.ne00;

    // Unravel `flat` into dst coordinates.
    const int64_t row_sz   = args.ne0;
    const int64_t plane_sz = args.ne1*args.ne0;
    const int64_t vol_sz   = args.ne2*args.ne1*args.ne0;

    const int64_t i3  = flat/vol_sz;
    int64_t       rem = flat - i3*vol_sz;
    const int64_t i2  = rem/plane_sz;
    rem              -= i2*plane_sz;
    const int64_t i1  = rem/row_sz;
    const int64_t i0  = rem - i1*row_sz;

    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    // One element per thread, guarded against the row tail.
    const int64_t i00 = iw0*ntg[0] + tiitg%ntg[0];
    if (i00 < args.ne00) {
        device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
        dst_data[i00] = (T1) src[0];
    }
}

typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
// Host-visible copy/conversion variants used by the DS4 graph.
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
|