@lgrammel/ds4-provider 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +96 -0
- package/binding.gyp +75 -0
- package/dist/ds4-language-model.d.ts +71 -0
- package/dist/ds4-language-model.d.ts.map +1 -0
- package/dist/ds4-language-model.js +888 -0
- package/dist/ds4-language-model.js.map +1 -0
- package/dist/ds4-provider.d.ts +13 -0
- package/dist/ds4-provider.d.ts.map +1 -0
- package/dist/ds4-provider.js +20 -0
- package/dist/ds4-provider.js.map +1 -0
- package/dist/index.d.ts +4 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +4 -0
- package/dist/index.js.map +1 -0
- package/dist/native-binding.d.ts +42 -0
- package/dist/native-binding.d.ts.map +1 -0
- package/dist/native-binding.js +157 -0
- package/dist/native-binding.js.map +1 -0
- package/ds4/LICENSE +22 -0
- package/ds4/ds4.c +18268 -0
- package/ds4/ds4.h +196 -0
- package/ds4/ds4_gpu.h +804 -0
- package/ds4/ds4_metal.m +14657 -0
- package/ds4/metal/argsort.metal +266 -0
- package/ds4/metal/bin.metal +192 -0
- package/ds4/metal/concat.metal +62 -0
- package/ds4/metal/cpy.metal +57 -0
- package/ds4/metal/dense.metal +1121 -0
- package/ds4/metal/dsv4_hc.metal +861 -0
- package/ds4/metal/dsv4_kv.metal +227 -0
- package/ds4/metal/dsv4_misc.metal +1088 -0
- package/ds4/metal/dsv4_rope.metal +155 -0
- package/ds4/metal/flash_attn.metal +1426 -0
- package/ds4/metal/get_rows.metal +54 -0
- package/ds4/metal/glu.metal +36 -0
- package/ds4/metal/moe.metal +1737 -0
- package/ds4/metal/norm.metal +153 -0
- package/ds4/metal/repeat.metal +52 -0
- package/ds4/metal/set_rows.metal +55 -0
- package/ds4/metal/softmax.metal +241 -0
- package/ds4/metal/sum_rows.metal +102 -0
- package/ds4/metal/unary.metal +312 -0
- package/native/binding.cpp +621 -0
- package/package.json +66 -0
- package/scripts/postinstall.cjs +13 -0
- package/scripts/vendor-ds4.cjs +67 -0
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
// DS4 Metal get-rows kernel.
|
|
2
|
+
|
|
3
|
+
// Argument block for kernel_get_rows_f, passed in the constant address space.
// Member order and types define the buffer layout; presumably a host-side
// struct mirrors it — keep both in sync (TODO confirm where the host fills it).
// ne* members are element counts, nb* members are byte strides (the kernel adds
// them to char pointers).
struct ds4_metal_args_get_rows {
    int32_t  ne00t, ne00;       // ne00t: elements copied per gathered row; ne00: not read by kernel_get_rows_f
    uint64_t nb01, nb02, nb03;  // src0 (table) byte strides: per row id, per dim-2 index, per dim-3 index
    int32_t  ne10;              // ids per src1 row (used to split tgpig.x into chunk/id indices)
    uint64_t nb10, nb11, nb12;  // src1 (int32 ids) byte strides for i10, i11, i12
    uint64_t nb1, nb2, nb3;     // dst byte strides for i10, i11, i12
};
|
|
17
|
+
|
|
18
|
+
// Gathers embedding/table rows by integer ids. DS4 uses this for token
// embeddings and small indexed tables such as router/hash lookup outputs.
//
// Dispatch layout, as implied by the indexing below:
//   tgpig.x = iw0*ne10 + i10  - i10 picks the id within a src1 row, iw0 picks
//                               which ntg.x-wide chunk of the output row this
//                               threadgroup copies
//   tgpig.y = i11, tgpig.z = i12 - outer src1/dst indices, reused as the src0
//                               dim-2 / dim-3 indices
// Each thread copies at most one element. The host is presumably expected to
// dispatch ceil(ne00t / ntg.x) * ne10 threadgroups along x so that every
// element of a row is covered — TODO confirm against the host dispatch.
// NOTE(review): the row id read from src1 is not range-checked here; the caller
// must guarantee it indexes a valid src0 row.
template<typename T0, typename T>
kernel void kernel_get_rows_f(
        constant ds4_metal_args_get_rows & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg  [[threads_per_threadgroup]]) {
    const int32_t iw0 = tgpig.x/args.ne10;
    const int32_t i10 = tgpig.x%args.ne10;
    const int32_t i11 = tgpig.y;
    const int32_t i12 = tgpig.z;

    // Table row to fetch, read from the int32 id tensor.
    const int32_t r = ((const device int32_t *) (src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];

    const int32_t i02 = i11;
    const int32_t i03 = i12;

    auto psrc = (const device T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
    auto pdst = (      device T  *) (dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);

    // Single bounds-checked copy per thread: the check drops threads of the
    // last chunk that fall past the end of the row.
    const int ind = iw0*ntg.x + tiitg;
    if (ind < args.ne00t) {
        pdst[ind] = psrc[ind];  // implicit T0 -> T conversion (e.g. half -> float)
    }
}
|
|
48
|
+
|
|
49
|
+
// Function type shared by every kernel_get_rows_f instantiation below (the
// template parameters only change the element types behind the char pointers,
// not the kernel signature).
using get_rows_f_t = decltype(kernel_get_rows_f<float, float>);

// Host-visible gather variants. host_name sets the function name exposed in the
// compiled Metal library.
//   kernel_get_rows_f32: float   table -> float   output
//   kernel_get_rows_f16: half    table -> float   output (widened on copy)
//   kernel_get_rows_i32: int32_t table -> int32_t output
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float,   float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half,    float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
// Argument block for the GLU kernels, passed in the constant address space.
// Member order and types define the buffer layout; presumably a host-side
// struct mirrors it — keep both in sync (TODO confirm where the host fills it).
// ne* are element counts, nb* are byte strides, i* are element offsets.
struct ds4_metal_args_glu {
    int32_t  ne00;          // not read by kernel_swiglu_f32
    uint64_t nb01;          // src0 (gate) row stride in bytes
    int32_t  ne10;          // not read by kernel_swiglu_f32
    uint64_t nb11;          // src1 (up) row stride in bytes
    int32_t  ne0;           // output elements per row
    uint64_t nb1;           // dst row stride in bytes
    int32_t  i00, i10;      // element offsets applied to each src0 / src1 row
    float    alpha, limit;  // not read by kernel_swiglu_f32
};
|
|
13
|
+
|
|
14
|
+
// SwiGLU activation for the FFN inner state: silu(gate) * up. The DS4 graph
// uses it between the gate/up expert matmuls and the down projection.
//
// Layout: one threadgroup per row (tgpig is the row index); the threads of the
// group walk the row's ne0 columns with a stride of ntg.
//   src0: gate rows (f32), row stride nb01 bytes, row data starts i00 elements in
//   src1: up   rows (f32), row stride nb11 bytes, row data starts i10 elements in
//   dst : output rows (f32), row stride nb1 bytes
kernel void kernel_swiglu_f32(
        constant ds4_metal_args_glu & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint ntg  [[threads_per_threadgroup]]) {
    device const float * gate = ((device const float *) (src0 + tgpig*args.nb01)) + args.i00;
    device const float * up   = ((device const float *) (src1 + tgpig*args.nb11)) + args.i10;
    device       float * out  =  (device       float *) (dst  + tgpig*args.nb1);

    int col = tpitg;
    while (col < args.ne0) {
        const float g = gate[col];
        const float u = up[col];

        // silu(g) = g * sigmoid(g) = g / (1 + e^-g)
        const float activated = g / (1.0f + exp(-g));
        out[col] = activated*u;

        col += ntg;
    }
}
|