@lgrammel/ds4-provider 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. package/README.md +96 -0
  2. package/binding.gyp +75 -0
  3. package/dist/ds4-language-model.d.ts +71 -0
  4. package/dist/ds4-language-model.d.ts.map +1 -0
  5. package/dist/ds4-language-model.js +888 -0
  6. package/dist/ds4-language-model.js.map +1 -0
  7. package/dist/ds4-provider.d.ts +13 -0
  8. package/dist/ds4-provider.d.ts.map +1 -0
  9. package/dist/ds4-provider.js +20 -0
  10. package/dist/ds4-provider.js.map +1 -0
  11. package/dist/index.d.ts +4 -0
  12. package/dist/index.d.ts.map +1 -0
  13. package/dist/index.js +4 -0
  14. package/dist/index.js.map +1 -0
  15. package/dist/native-binding.d.ts +42 -0
  16. package/dist/native-binding.d.ts.map +1 -0
  17. package/dist/native-binding.js +157 -0
  18. package/dist/native-binding.js.map +1 -0
  19. package/ds4/LICENSE +22 -0
  20. package/ds4/ds4.c +18268 -0
  21. package/ds4/ds4.h +196 -0
  22. package/ds4/ds4_gpu.h +804 -0
  23. package/ds4/ds4_metal.m +14657 -0
  24. package/ds4/metal/argsort.metal +266 -0
  25. package/ds4/metal/bin.metal +192 -0
  26. package/ds4/metal/concat.metal +62 -0
  27. package/ds4/metal/cpy.metal +57 -0
  28. package/ds4/metal/dense.metal +1121 -0
  29. package/ds4/metal/dsv4_hc.metal +861 -0
  30. package/ds4/metal/dsv4_kv.metal +227 -0
  31. package/ds4/metal/dsv4_misc.metal +1088 -0
  32. package/ds4/metal/dsv4_rope.metal +155 -0
  33. package/ds4/metal/flash_attn.metal +1426 -0
  34. package/ds4/metal/get_rows.metal +54 -0
  35. package/ds4/metal/glu.metal +36 -0
  36. package/ds4/metal/moe.metal +1737 -0
  37. package/ds4/metal/norm.metal +153 -0
  38. package/ds4/metal/repeat.metal +52 -0
  39. package/ds4/metal/set_rows.metal +55 -0
  40. package/ds4/metal/softmax.metal +241 -0
  41. package/ds4/metal/sum_rows.metal +102 -0
  42. package/ds4/metal/unary.metal +312 -0
  43. package/native/binding.cpp +621 -0
  44. package/package.json +66 -0
  45. package/scripts/postinstall.cjs +13 -0
  46. package/scripts/vendor-ds4.cjs +67 -0
@@ -0,0 +1,227 @@
1
// Multiplier selected by the 4-bit exponent field of an E4M3FN code.
// For e in 1..15 the entry is 2^(e - 7). Entry 0 is a placeholder:
// dsv4_e4m3fn_value() decodes exponent 0 (subnormals) without this table.
constant float dsv4_e4m3fn_exp_scale[16] = {
    0.0f,       // e =  0 : placeholder, not read for subnormals
    0.015625f,  // e =  1 : 2^-6
    0.03125f,   // e =  2 : 2^-5
    0.0625f,    // e =  3 : 2^-4
    0.125f,     // e =  4 : 2^-3
    0.25f,      // e =  5 : 2^-2
    0.5f,       // e =  6 : 2^-1
    1.0f,       // e =  7 : 2^0
    2.0f,       // e =  8 : 2^1
    4.0f,       // e =  9 : 2^2
    8.0f,       // e = 10 : 2^3
    16.0f,      // e = 11 : 2^4
    32.0f,      // e = 12 : 2^5
    64.0f,      // e = 13 : 2^6
    128.0f,     // e = 14 : 2^7
    256.0f,     // e = 15 : 2^8
};
7
+
8
// Arguments for kernel_dsv4_fp8_kv_quantize_f32.
// Field order and types form the argument-buffer layout and must not change.
struct ds4_metal_args_dsv4_fp8_kv_quantize {
    // Extents: ne00 is the number of elements in one row; ne01 * ne02 * ne03
    // is the number of rows the kernel walks.
    int64_t ne00;
    int64_t ne01;
    int64_t ne02;
    int64_t ne03;

    // Source strides in bytes: nb00 between consecutive elements of a row,
    // nb01 / nb02 / nb03 between consecutive indices of dimensions 1 / 2 / 3.
    ulong   nb00;
    ulong   nb01;
    ulong   nb02;
    ulong   nb03;

    // Destination strides in bytes, same meaning as the source strides.
    ulong   nb0;
    ulong   nb1;
    ulong   nb2;
    ulong   nb3;

    // Number of trailing elements of each row that are copied through
    // unquantized (the kernel quantizes the first ne00 - n_rot elements).
    int     n_rot;
};
23
+
24
// Arguments for kernel_dsv4_kv_fp8_store_f32.
// Field order and types form the argument-buffer layout and must not change.
struct ds4_metal_args_dsv4_kv_fp8_store {
    int32_t head_dim;  // floats in the KV row handled by the kernel
    int32_t n_rot;     // trailing floats of the row that skip the FP8 round-trip
    int32_t raw_row;   // destination row in raw_cache (row stride is head_dim floats)
};
29
+
30
// Arguments for kernel_dsv4_ratio4_shift_f32.
struct ds4_metal_args_dsv4_ratio4_shift {
    uint32_t width;  // floats per state row; the kernel moves 4 * width floats per tensor
};
33
+
34
// Arguments for kernel_dsv4_compressor_store_one.
// Field order and types form the argument-buffer layout and must not change.
struct ds4_metal_args_dsv4_compressor_store_one {
    uint32_t width;     // floats per row of kv / score / state
    uint32_t ratio;     // compression ratio; pos % ratio selects the row within a window
    uint32_t pos;       // position of the token being appended
    uint32_t ape_type;  // 1 = the ape buffer holds half values, anything else = float
};
40
+
41
// Decodes the magnitude encoded by the low 7 bits of an E4M3FN code.
// Bits 3..6 are the exponent field and bits 0..2 the mantissa field; any
// higher bits of `i` are ignored.
static inline float dsv4_e4m3fn_value(int i) {
    const int mantissa_bits = i & 0x07;
    const int exponent_bits = (i >> 3) & 0x0f;

    if (exponent_bits == 0) {
        // Subnormal range: no implicit leading one, fixed step of 2^-9.
        return float(mantissa_bits) * 0.001953125f;
    }

    // Normal range: (1 + mantissa / 8) scaled by the exponent's power of two.
    return (1.0f + float(mantissa_bits) * 0.125f) * dsv4_e4m3fn_exp_scale[exponent_bits];
}
48
+
49
// Rounds `x` to the nearest magnitude representable by codes 0..126 of
// dsv4_e4m3fn_value() (largest value 448) and returns it with the sign of `x`.
// Exact ties between two neighbouring codes go to the even code.
static inline float dsv4_e4m3fn_dequant(float x) {
    const float magnitude = min(abs(x), 448.0f);

    // Binary search for the largest code whose value does not exceed `magnitude`.
    int code = 0;
    for (int upper = 126; code < upper; ) {
        const int probe = (code + upper + 1) >> 1;
        if (dsv4_e4m3fn_value(probe) <= magnitude) {
            code = probe;
        } else {
            upper = probe - 1;
        }
    }

    // `code` is the round-down candidate; see whether its upper neighbour is
    // closer, or equally close while `code` itself is odd (so the neighbour is even).
    if (code < 126) {
        const float err_below = abs(magnitude - dsv4_e4m3fn_value(code));
        const float err_above = abs(magnitude - dsv4_e4m3fn_value(code + 1));
        const bool  tie_to_even = (err_above == err_below) && ((code & 1) != 0);
        if (err_above < err_below || tie_to_even) {
            code += 1;
        }
    }

    const float rounded = dsv4_e4m3fn_value(code);
    return x < 0.0f ? -rounded : rounded;
}
75
+
76
// Quantizes the non-RoPE part of a KV row through E4M3FN and writes the
// dequantized value back as float. DS4 uses this to match the FP8 KV-cache
// semantics while keeping the Metal graph's cache buffers float-addressable.
//
// Per row, the first (ne00 - n_rot) elements are processed in blocks of 64:
// each block gets its own power-of-two scale derived from the block's largest
// magnitude, is rounded to E4M3FN at that scale, and is written back as float.
// The trailing n_rot elements are copied unchanged.
//
// Launch expectations, as the code uses them:
//   - `row` is the threadgroup position, so one threadgroup handles one row.
//   - Only lanes 0..63 load, reduce and store; `scratch` must hold 64 floats.
//   - The reduction reads scratch[0..63], which are written only by lanes
//     0..63, so the threadgroup needs at least 64 threads.
//
// Lanes whose element index falls past the non-RoPE prefix contribute 0 to the
// block maximum and write nothing, so a prefix length that is not a multiple of
// 64 neither reads past the prefix nor collides with the tail copy below.
kernel void kernel_dsv4_fp8_kv_quantize_f32(
        constant ds4_metal_args_dsv4_fp8_kv_quantize & args,
        device const char * src0,
        device       char * dst,
        threadgroup float * scratch [[threadgroup(0)]],
        uint row [[threadgroup_position_in_grid]],
        uint tid [[thread_position_in_threadgroup]]) {
    const int64_t n_rows = args.ne01 * args.ne02 * args.ne03;
    if ((int64_t) row >= n_rows) {
        return;
    }

    // Argument sanity check. It depends only on `args`, so every thread of the
    // threadgroup takes the same path and no barrier below is left waiting.
    const int64_t n_nope = args.ne00 - args.n_rot;
    if (args.n_rot < 0 || n_nope < 0) {
        return;
    }

    // Split the flat row index into the three outer tensor coordinates.
    const int64_t i1 = row % args.ne01;
    const int64_t i2 = (row / args.ne01) % args.ne02;
    const int64_t i3 = row / (args.ne01 * args.ne02);

    device const char * src_base = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
    device       char * dst_base = dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3;

    const bool lane_active = tid < 64;

    for (int64_t off = 0; off < n_nope; off += 64) {
        const int64_t idx      = off + (int64_t) tid;
        const bool    in_block = lane_active && idx < n_nope;

        float v = 0.0f;
        if (in_block) {
            v = *((device const float *) (src_base + idx*args.nb00));
        }
        if (lane_active) {
            // Lanes past the prefix publish 0 so they cannot raise the maximum.
            scratch[tid] = abs(v);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Tree reduction of the 64 magnitudes down to scratch[0].
        for (uint stride = 32; stride > 0; stride >>= 1) {
            if (tid < stride) {
                scratch[tid] = max(scratch[tid], scratch[tid + stride]);
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }

        // Smallest power-of-two scale that maps the block maximum into the
        // E4M3FN range (<= 448); the 1e-4 floor keeps log2 away from zero.
        const float amax  = max(scratch[0], 1.0e-4f);
        const float scale = exp2(ceil(log2(amax / 448.0f)));
        if (in_block) {
            const float q = dsv4_e4m3fn_dequant(clamp(v / scale, -448.0f, 448.0f)) * scale;
            *((device float *) (dst_base + idx*args.nb0)) = q;
        }
        // Everyone must have read scratch[0] before the next block overwrites it.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // RoPE tail: straight copy, 64 elements per step.
    for (int64_t i = n_nope + tid; i < args.ne00; i += 64) {
        *((device float *) (dst_base + i*args.nb0)) = *((device const float *) (src_base + i*args.nb00));
    }
}
128
+
129
// Decode-side KV finalizer after RoPE. The normal RoPE kernel intentionally
// remains separate because tiny trigonometric codegen changes can flip later
// sampled tokens. This kernel only fuses the FP8 round-trip for the non-RoPE
// prefix with the F16-rounded raw-cache row used by FlashAttention.
//
// What it does to one row of `head_dim` floats:
//   - the first (head_dim - n_rot) elements are processed in blocks of 64,
//     each rounded to E4M3FN at a per-block power-of-two scale; the result is
//     written back into `kv` and, rounded through half, into row `raw_row` of
//     `raw_cache`;
//   - the trailing n_rot elements are left as they are in `kv` and only copied,
//     rounded through half, into the raw-cache row.
//
// Launch expectations, as the code uses them: the row is addressed by the
// thread index alone, lanes 0..63 do all the work, `scratch` must hold 64
// floats, and the reduction needs lanes 0..63 to exist. Presumably dispatched
// as a single 64-thread threadgroup -- confirm against the host encoder.
//
// Threads with tid >= 64 are masked off rather than returned early, so every
// thread of a larger threadgroup still reaches each threadgroup_barrier.
kernel void kernel_dsv4_kv_fp8_store_f32(
        constant ds4_metal_args_dsv4_kv_fp8_store & args,
        device float * kv,
        device float * raw_cache,
        threadgroup float * scratch [[threadgroup(0)]],
        uint tid [[thread_position_in_threadgroup]]) {
    const int head_dim = args.head_dim;
    const int n_rot    = args.n_rot;
    const int n_nope   = head_dim - n_rot;

    // Depends only on `args`: all threads agree, so this return cannot strand
    // part of the threadgroup at a barrier.
    if (head_dim <= 0 || n_rot < 0 || n_nope < 0) {
        return;
    }

    const bool lane_active = tid < 64;

    device float * raw = raw_cache + (int64_t)args.raw_row * head_dim;

    for (int off = 0; off < n_nope; off += 64) {
        const bool in_block = lane_active && off + (int)tid < n_nope;

        float v = 0.0f;
        if (in_block) {
            v = kv[off + tid];
        }
        if (lane_active) {
            // Lanes past the prefix publish 0 so they cannot raise the maximum.
            scratch[tid] = abs(v);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Tree reduction of the 64 magnitudes down to scratch[0].
        for (uint stride = 32; stride > 0; stride >>= 1) {
            if (tid < stride) {
                scratch[tid] = max(scratch[tid], scratch[tid + stride]);
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }

        // Smallest power-of-two scale that maps the block maximum into the
        // E4M3FN range (<= 448); the 1e-4 floor keeps log2 away from zero.
        const float amax      = max(scratch[0], 1.0e-4f);
        const float fp8_scale = exp2(ceil(log2(amax / 448.0f)));
        if (in_block) {
            const float q = dsv4_e4m3fn_dequant(clamp(v / fp8_scale, -448.0f, 448.0f)) * fp8_scale;
            kv[off + tid]  = q;
            raw[off + tid] = (float)((half)q);
        }
        // Everyone must have read scratch[0] before the next block overwrites it.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // RoPE tail: only the half-rounded copy into the raw cache.
    if (lane_active) {
        for (int i = n_nope + tid; i < head_dim; i += 64) {
            raw[i] = (float)((half)kv[i]);
        }
    }
}
179
+
180
// Ratio-4 compression keeps its recurrent state as two halves of four rows
// each. Once a compressed row has been emitted, the second half becomes the
// "previous" half of the next window. Instead of the four generic copies the
// old encoder issued, this DS4-specific kernel moves the KV half and the score
// half in one dispatch.
//
// One thread per float: thread `gid` copies element (4 * width + gid) onto
// element `gid` in both state tensors. Threads at or beyond 4 * width do nothing.
kernel void kernel_dsv4_ratio4_shift_f32(
        constant ds4_metal_args_dsv4_ratio4_shift & args,
        device float * state_kv,
        device float * state_score,
        uint gid [[thread_position_in_grid]]) {
    const uint half_elems = 4u * args.width;

    if (gid < half_elems) {
        const uint src = half_elems + gid;
        state_kv[gid]    = state_kv[src];
        state_score[gid] = state_score[src];
    }
}
195
+
196
// One-token compressor frontier update. During decode exactly one projected KV
// row and one score row are appended to a small recurrent state. The generic
// batch helper does this with an APE copy, a score add and two set_rows
// operations; this kernel writes both state tensors directly and keeps the
// same score + APE arithmetic.
//
// One thread per column `gid` of the row. The destination row is
// (pos % ratio), offset by `ratio` rows when ratio == 4. The APE row read is
// (pos % ratio); `ape` is interpreted as half when ape_type == 1 and as float
// otherwise.
kernel void kernel_dsv4_compressor_store_one(
        constant ds4_metal_args_dsv4_compressor_store_one & args,
        device const float * kv,
        device const float * score,
        device const char  * ape,
        device float * state_kv,
        device float * state_score,
        uint gid [[thread_position_in_grid]]) {
    const uint width = args.width;
    const uint ratio = args.ratio;

    // Nothing to do for degenerate arguments or threads past the row.
    if (width == 0 || ratio == 0 || gid >= width) {
        return;
    }

    const uint slot = args.pos % ratio;

    uint dst_row = slot;
    if (ratio == 4u) {
        // Ratio-4 state: new rows land in the second group of `ratio` rows.
        dst_row += ratio;
    }

    const uint dst_idx = dst_row * width + gid;
    const uint ape_idx = slot * width + gid;

    const float ape_v = (args.ape_type == 1u)
        ? (float)(((device const half *)ape)[ape_idx])
        : ((device const float *)ape)[ape_idx];

    state_kv[dst_idx]    = kv[gid];
    state_score[dst_idx] = score[gid] + ape_v;
}