@seanhogg/builderforce-memory-engine 2026.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/LICENSE +21 -0
  2. package/README.md +393 -0
  3. package/dist/index.d.ts +32 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +40 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/kernels/activations.d.ts +5 -0
  8. package/dist/kernels/activations.d.ts.map +1 -0
  9. package/dist/kernels/activations.js +171 -0
  10. package/dist/kernels/activations.js.map +1 -0
  11. package/dist/kernels/attention.d.ts +19 -0
  12. package/dist/kernels/attention.d.ts.map +1 -0
  13. package/dist/kernels/attention.js +263 -0
  14. package/dist/kernels/attention.js.map +1 -0
  15. package/dist/kernels/complex_ssd.d.ts +33 -0
  16. package/dist/kernels/complex_ssd.d.ts.map +1 -0
  17. package/dist/kernels/complex_ssd.js +305 -0
  18. package/dist/kernels/complex_ssd.js.map +1 -0
  19. package/dist/kernels/conv1d.d.ts +3 -0
  20. package/dist/kernels/conv1d.d.ts.map +1 -0
  21. package/dist/kernels/conv1d.js +158 -0
  22. package/dist/kernels/conv1d.js.map +1 -0
  23. package/dist/kernels/linear_projection.d.ts +3 -0
  24. package/dist/kernels/linear_projection.d.ts.map +1 -0
  25. package/dist/kernels/linear_projection.js +219 -0
  26. package/dist/kernels/linear_projection.js.map +1 -0
  27. package/dist/kernels/selective_scan.d.ts +3 -0
  28. package/dist/kernels/selective_scan.d.ts.map +1 -0
  29. package/dist/kernels/selective_scan.js +348 -0
  30. package/dist/kernels/selective_scan.js.map +1 -0
  31. package/dist/kernels/ssd.d.ts +29 -0
  32. package/dist/kernels/ssd.d.ts.map +1 -0
  33. package/dist/kernels/ssd.js +276 -0
  34. package/dist/kernels/ssd.js.map +1 -0
  35. package/dist/kernels/weight_update.d.ts +3 -0
  36. package/dist/kernels/weight_update.d.ts.map +1 -0
  37. package/dist/kernels/weight_update.js +119 -0
  38. package/dist/kernels/weight_update.js.map +1 -0
  39. package/dist/model/attention_block.d.ts +48 -0
  40. package/dist/model/attention_block.d.ts.map +1 -0
  41. package/dist/model/attention_block.js +262 -0
  42. package/dist/model/attention_block.js.map +1 -0
  43. package/dist/model/mamba1_block.d.ts +70 -0
  44. package/dist/model/mamba1_block.d.ts.map +1 -0
  45. package/dist/model/mamba1_block.js +333 -0
  46. package/dist/model/mamba1_block.js.map +1 -0
  47. package/dist/model/mamba2_block.d.ts +44 -0
  48. package/dist/model/mamba2_block.d.ts.map +1 -0
  49. package/dist/model/mamba2_block.js +252 -0
  50. package/dist/model/mamba2_block.js.map +1 -0
  51. package/dist/model/mamba3_block.d.ts +51 -0
  52. package/dist/model/mamba3_block.d.ts.map +1 -0
  53. package/dist/model/mamba3_block.js +270 -0
  54. package/dist/model/mamba3_block.js.map +1 -0
  55. package/dist/model/mamba_block.d.ts +64 -0
  56. package/dist/model/mamba_block.d.ts.map +1 -0
  57. package/dist/model/mamba_block.js +303 -0
  58. package/dist/model/mamba_block.js.map +1 -0
  59. package/dist/model/mamba_model.d.ts +140 -0
  60. package/dist/model/mamba_model.d.ts.map +1 -0
  61. package/dist/model/mamba_model.js +527 -0
  62. package/dist/model/mamba_model.js.map +1 -0
  63. package/dist/model/sequence_layer.d.ts +25 -0
  64. package/dist/model/sequence_layer.d.ts.map +1 -0
  65. package/dist/model/sequence_layer.js +8 -0
  66. package/dist/model/sequence_layer.js.map +1 -0
  67. package/dist/tokenizer/bpe.d.ts +29 -0
  68. package/dist/tokenizer/bpe.d.ts.map +1 -0
  69. package/dist/tokenizer/bpe.js +164 -0
  70. package/dist/tokenizer/bpe.js.map +1 -0
  71. package/dist/training/autograd.d.ts +27 -0
  72. package/dist/training/autograd.d.ts.map +1 -0
  73. package/dist/training/autograd.js +120 -0
  74. package/dist/training/autograd.js.map +1 -0
  75. package/dist/training/trainer.d.ts +36 -0
  76. package/dist/training/trainer.d.ts.map +1 -0
  77. package/dist/training/trainer.js +183 -0
  78. package/dist/training/trainer.js.map +1 -0
  79. package/dist/utils/gpu_utils.d.ts +21 -0
  80. package/dist/utils/gpu_utils.d.ts.map +1 -0
  81. package/dist/utils/gpu_utils.js +111 -0
  82. package/dist/utils/gpu_utils.js.map +1 -0
  83. package/dist/utils/quantization.d.ts +26 -0
  84. package/dist/utils/quantization.d.ts.map +1 -0
  85. package/dist/utils/quantization.js +116 -0
  86. package/dist/utils/quantization.js.map +1 -0
  87. package/dist/utils/rng.d.ts +36 -0
  88. package/dist/utils/rng.d.ts.map +1 -0
  89. package/dist/utils/rng.js +61 -0
  90. package/dist/utils/rng.js.map +1 -0
  91. package/package.json +99 -0
  92. package/src/index.ts +114 -0
  93. package/src/kernels/activations.ts +174 -0
  94. package/src/kernels/attention.ts +268 -0
  95. package/src/kernels/complex_ssd.ts +307 -0
  96. package/src/kernels/conv1d.ts +159 -0
  97. package/src/kernels/linear_projection.ts +220 -0
  98. package/src/kernels/selective_scan.ts +350 -0
  99. package/src/kernels/ssd.ts +278 -0
  100. package/src/kernels/weight_update.ts +120 -0
  101. package/src/model/attention_block.ts +344 -0
  102. package/src/model/mamba1_block.ts +437 -0
  103. package/src/model/mamba2_block.ts +319 -0
  104. package/src/model/mamba3_block.ts +335 -0
  105. package/src/model/mamba_block.ts +401 -0
  106. package/src/model/mamba_model.ts +678 -0
  107. package/src/model/sequence_layer.ts +29 -0
  108. package/src/tokenizer/bpe.ts +186 -0
  109. package/src/training/autograd.ts +135 -0
  110. package/src/training/trainer.ts +309 -0
  111. package/src/utils/gpu_utils.ts +147 -0
  112. package/src/utils/quantization.ts +154 -0
  113. package/src/utils/rng.ts +65 -0
@@ -0,0 +1,348 @@
1
// Parallel Selective Scan WGSL Kernel
// Implements the S6 (Selective Scan) core of the Mamba architecture.
//
// For every (batch b, channel d, state n) triple the recurrence
//   h_t = A_bar_t * h_{t-1} + B_bar_t * u_t
//   y_t = sum_n C_t[n] * h_t[n] + D[d] * u_t
// is evaluated 64 time steps at a time. Each 64-step tile is solved with a
// Kogge-Stone inclusive prefix scan (log2(64) = 6 combine rounds), and tiles
// are chained through a carried hidden state. One sequence therefore costs
// O(ceil(L / 64) * log2(64)) parallel steps.
//
// A_bar_t = exp(delta_t * A) and B_bar_t = (A_bar_t - 1) / A * B_t are the
// zero-order-hold discretisations. delta_t, B_t and C_t are input-dependent
// (selective).
export const SELECTIVE_SCAN_FORWARD_WGSL = /* wgsl */ `

struct ScanParams {
  seq_len : u32, // L - sequence length
  d_state : u32, // N - state dimension
  d_inner : u32, // D - inner (expanded) channel dimension
  batch   : u32, // B - batch size
};

@group(0) @binding(0) var<uniform> params : ScanParams;
// u (B, L, D) - projected input after conv
@group(0) @binding(1) var<storage, read> u : array<f32>;
// delta (B, L, D) - raw (pre-softplus) time step; softplus is applied here
@group(0) @binding(2) var<storage, read> delta : array<f32>;
// A (D, N) - log-magnitude of the diagonal state matrix: A_cont = -exp(A[d, n])
@group(0) @binding(3) var<storage, read> A : array<f32>;
// B (B, L, N) - input projection (selective)
@group(0) @binding(4) var<storage, read> B : array<f32>;
// C (B, L, N) - output projection (selective)
@group(0) @binding(5) var<storage, read> C : array<f32>;
// D (D,) - skip-connection scale
@group(0) @binding(6) var<storage, read> D_vec : array<f32>;
// y (B, L, D) - output, written by forward_reduce
@group(0) @binding(7) var<storage, read_write> y : array<f32>;
// h_cache must hold 2 * B*L*D*N floats:
//   [0, B*L*D*N)          hidden states h_t, index (b, t, d, n), for backward
//   [B*L*D*N, 2*B*L*D*N)  partial outputs C_t[n] * h_t, same index layout,
//                         summed over n by forward_reduce
@group(0) @binding(8) var<storage, read_write> h_cache : array<f32>;

// Scan tile length; must equal the workgroup size of forward_scan.
const TILE : u32 = 64u;

// Per-tile scan elements. Element i is the affine map h -> wg_a[i] * h + wg_bu[i].
var<workgroup> wg_a  : array<f32, TILE>;
var<workgroup> wg_bu : array<f32, TILE>;

// Numerically stable softplus: log(1 + exp(x)) without overflow for large x.
fn softplus(x: f32) -> f32 {
  return max(x, 0.0) + log(1.0 + exp(-abs(x)));
}

// Zero-order-hold discretisation of A: A_bar = exp(delta * A_cont), with
// A_cont = -exp(a_log) < 0, so that 0 < A_bar < 1.
fn discretise_A(delta_val: f32, a_log: f32) -> f32 {
  let a_cont = -exp(a_log);
  return exp(delta_val * a_cont);
}

// Zero-order-hold discretisation of B: B_bar = (A_bar - 1) / A_cont * B.
fn discretise_B(delta_val: f32, a_log: f32, b_val: f32) -> f32 {
  let a_cont = -exp(a_log);
  let a_bar = exp(delta_val * a_cont);
  return (a_bar - 1.0) / a_cont * b_val;
}

// ---- Main kernel ----
// Dispatch: (D, N, B) workgroups of TILE invocations. Workgroup (d, n, b) scans
// the whole sequence for one (channel, state, batch) triple, one tile at a time.
@compute @workgroup_size(64, 1, 1)
fn forward_scan(
  @builtin(local_invocation_index) lid : u32,
  @builtin(workgroup_id) wgid : vec3<u32>,
) {
  let L = params.seq_len;
  let N = params.d_state;
  let D = params.d_inner;
  // Not named B: a local B would shadow the B storage buffer.
  let n_batch = params.batch;

  // Derived only from workgroup_id and uniforms, so this early return is
  // workgroup-uniform and the barriers below stay in uniform control flow.
  let d = wgid.x;
  let n = wgid.y;
  let b = wgid.z;
  if (d >= D || n >= N || b >= n_batch) { return; }

  let a_log = A[d * N + n];
  // Start of the partial-output half of h_cache.
  let partial_offset = n_batch * L * D * N;

  // Hidden state carried in from the previous tile.
  var h : f32 = 0.0;

  var tile_start : u32 = 0u;
  loop {
    if (tile_start >= L) { break; }

    // Absolute time step handled by this lane. Lanes past the end of the
    // sequence contribute the identity map (a = 1, bu = 0).
    let t = tile_start + lid;
    var a_bar : f32 = 1.0;
    var bu    : f32 = 0.0;
    if (t < L) {
      let bld = b * L * D + t * D + d;
      let dv = softplus(delta[bld]);
      a_bar = discretise_A(dv, a_log);
      bu = discretise_B(dv, a_log, B[b * L * N + t * N + n]) * u[bld];
    }
    wg_a[lid] = a_bar;
    wg_bu[lid] = bu;
    workgroupBarrier();

    // Kogge-Stone inclusive scan. Composing an earlier prefix (pa, pb) with
    // a later segment (ca, cb) gives h -> ca * (pa * h + pb) + cb, that is,
    // (ca * pa, ca * pb + cb).
    for (var stride : u32 = 1u; stride < TILE; stride = stride << 1u) {
      var cur_a = wg_a[lid];
      var cur_bu = wg_bu[lid];
      if (lid >= stride) {
        let pa = wg_a[lid - stride];
        let pb = wg_bu[lid - stride];
        cur_bu = cur_a * pb + cur_bu;
        cur_a = cur_a * pa;
      }
      workgroupBarrier();  // every lane has read before any lane writes
      wg_a[lid] = cur_a;
      wg_bu[lid] = cur_bu;
      workgroupBarrier();  // writes are visible before the next round
    }

    // Apply the inclusive prefix map to the carried-in state.
    let h_t = wg_a[lid] * h + wg_bu[lid];

    if (t < L) {
      let h_idx = b * L * D * N + t * D * N + d * N + n;
      h_cache[h_idx] = h_t;
      // Per-state partial output; forward_reduce sums these over n.
      h_cache[partial_offset + h_idx] = C[b * L * N + t * N + n] * h_t;
    }

    // The last valid lane's prefix maps the carry-in to the tile's final state.
    let last = min(TILE, L - tile_start) - 1u;
    h = wg_a[last] * h + wg_bu[last];

    workgroupBarrier();  // finish reading the tile before it is overwritten
    tile_start = tile_start + TILE;
  }
}

// ---- Reduction kernel ----
// Collapses the N (d_state) dimension of the partial outputs into y and adds
// the skip connection: y_t[d] = sum_n C_t[n] * h_t[d, n] + D_vec[d] * u_t[d].
// Dispatch: (ceil(L/64), D, B)

@compute @workgroup_size(64, 1, 1)
fn forward_reduce(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let N = params.d_state;
  let D = params.d_inner;
  let n_batch = params.batch;

  let t = gid.x;
  let d = gid.y;
  let b = gid.z;

  if (t >= L || d >= D || b >= n_batch) { return; }

  let base = n_batch * L * D * N + b * L * D * N + t * D * N + d * N;
  var sum : f32 = 0.0;
  for (var n : u32 = 0u; n < N; n = n + 1u) {
    sum = sum + h_cache[base + n];
  }

  let bld = b * L * D + t * D + d;
  y[bld] = sum + D_vec[d] * u[bld];
}
`;
225
// ---- Backward scan kernel (for autograd) ----
// Computes gradients w.r.t. delta, A, B, C and u using the hidden states that
// forward_scan cached in the first B*L*D*N entries of h_cache.
//
// NOTE(review): with dispatch (D, N, B), dC/dB[b, t, n] are accumulated by every
// d-invocation, du/dDelta[b, t, d] by every n-invocation, and dA[d, n] by every
// b-invocation, each via non-atomic read-modify-write. Concurrent updates can be
// lost whenever D, N or B > 1. Fixing this needs per-invocation partial buffers
// (a binding/host change). du also excludes the D_vec * dy skip term, since
// D_vec is not bound here.
export const SELECTIVE_SCAN_BACKWARD_WGSL = /* wgsl */ `

struct ScanParams {
  seq_len : u32,
  d_state : u32,
  d_inner : u32,
  batch   : u32,
};

@group(0) @binding(0)  var<uniform> params : ScanParams;
@group(0) @binding(1)  var<storage, read> u : array<f32>;
@group(0) @binding(2)  var<storage, read> delta : array<f32>;
@group(0) @binding(3)  var<storage, read> A : array<f32>;
@group(0) @binding(4)  var<storage, read> B : array<f32>;
@group(0) @binding(5)  var<storage, read> C : array<f32>;
@group(0) @binding(6)  var<storage, read> h_cache : array<f32>;
@group(0) @binding(7)  var<storage, read> dy : array<f32>; // upstream gradient
@group(0) @binding(8)  var<storage, read_write> dA : array<f32>;
@group(0) @binding(9)  var<storage, read_write> dB : array<f32>;
@group(0) @binding(10) var<storage, read_write> dC : array<f32>;
@group(0) @binding(11) var<storage, read_write> dDelta : array<f32>;
@group(0) @binding(12) var<storage, read_write> du : array<f32>;

// Numerically stable softplus (must match the forward kernel).
fn softplus(x: f32) -> f32 {
  return max(x, 0.0) + log(1.0 + exp(-abs(x)));
}

// d/dx softplus(x) = sigmoid(x)
fn softplus_grad(x: f32) -> f32 {
  return 1.0 / (1.0 + exp(-x));
}

// Reverse scan (backward pass): walks time from L-1 down to 0.
// Dispatch: (D, N, B), one invocation per (channel, state, batch) triple.
@compute @workgroup_size(1, 1, 1)
fn backward_scan(
  @builtin(global_invocation_id) gid : vec3<u32>,
) {
  let L = params.seq_len;
  let N = params.d_state;
  let D = params.d_inner;
  // Not named B: a local B would shadow the B storage buffer.
  let n_batch = params.batch;

  let d = gid.x;
  let n = gid.y;
  let b = gid.z;

  if (d >= D || n >= N || b >= n_batch) { return; }

  let A_idx = d * N + n;
  let a_log = A[A_idx];
  let a_cont = -exp(a_log);          // continuous-time A (< 0)

  // dLoss/dh_t, accumulated backwards through time.
  var dh : f32 = 0.0;

  var t : u32 = L;
  loop {
    if (t == 0u) { break; }
    t = t - 1u;

    let bld = b * L * D + t * D + d;                   // (B, L, D) index
    let bln = b * L * N + t * N + n;                   // (B, L, N) index
    let h_idx = b * L * D * N + t * D * N + d * N + n; // (B, L, D, N) index

    let delta_raw = delta[bld];
    let dv = softplus(delta_raw);
    let a_bar = exp(dv * a_cont);
    let b_val = B[bln];
    let c_val = C[bln];
    let u_val = u[bld];
    let h_t = h_cache[h_idx];
    // h_{t-1}; the initial state is zero. WGSL has no ?: operator.
    var h_prev : f32 = 0.0;
    if (t > 0u) {
      h_prev = h_cache[h_idx - D * N];
    }

    // y_t[d] = sum_n C[n] * h_t[n] + D * u  =>  dh_t[n] += C[n] * dy_t[d]
    let dy_val = dy[bld];
    dh = dh + c_val * dy_val;

    // dC[b, t, n] += dy_t[d] * h_t
    dC[bln] = dC[bln] + dy_val * h_t;

    // h_t = a_bar * h_{t-1} + b_bar * u_t,  b_bar = b_coef * b_val
    let b_coef = (a_bar - 1.0) / a_cont;
    let b_bar = b_coef * b_val;

    // dA[d, n]: A enters through both a_bar and b_bar (ZOH), with
    // d(a_cont)/d(a_log) = a_cont:
    //   d(a_bar)/d(a_log)  = a_bar * a_cont * dv
    //   d(b_coef)/d(a_log) = (dv * a_bar * a_cont - (a_bar - 1)) / a_cont
    let da_bar_da_log = a_bar * a_cont * dv;
    let db_coef_da_log = (dv * a_bar * a_cont - (a_bar - 1.0)) / a_cont;
    dA[A_idx] = dA[A_idx]
              + dh * (da_bar_da_log * h_prev + db_coef_da_log * b_val * u_val);

    // dB[b, t, n] += dh_t * b_coef * u_t   (b_bar is linear in b)
    dB[bln] = dB[bln] + dh * b_coef * u_val;

    // du[b, t, d] += dh_t * b_bar   (partial contribution of state n)
    du[bld] = du[bld] + dh * b_bar;

    // dDelta: d(a_bar)/d(dv) = a_bar * a_cont, d(b_bar)/d(dv) = a_bar * b_val,
    // then chain through softplus.
    let dloss_ddv = dh * (a_bar * a_cont * h_prev + a_bar * b_val * u_val);
    dDelta[bld] = dDelta[bld] + dloss_ddv * softplus_grad(delta_raw);

    // Propagate to h_{t-1}: dh_{t-1} (before adding its own C * dy term).
    dh = a_bar * dh;
  }
}
`;
348
+ //# sourceMappingURL=selective_scan.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"selective_scan.js","sourceRoot":"","sources":["../../src/kernels/selective_scan.ts"],"names":[],"mappings":"AAAA,sCAAsC;AACtC,qEAAqE;AACrE,4EAA4E;AAC5E,EAAE;AACF,2BAA2B;AAC3B,oCAAoC;AACpC,8BAA8B;AAC9B,EAAE;AACF,qEAAqE;AAErE,MAAM,CAAC,MAAM,2BAA2B,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAsN5D,CAAC;AAEF,gDAAgD;AAChD,uEAAuE;AAEvE,MAAM,CAAC,MAAM,4BAA4B,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAwH7D,CAAC"}
@@ -0,0 +1,29 @@
1
/**
 * ssd.ts - Structured State Space Duality (SSD) kernels for Mamba-2.
 *
 * Implements a chunked SSD recurrence:
 *   A_bar_t = exp(-softplus(A_log_h) * softplus(dt_t + dt_bias_h))   [scalar per head]
 *   h_t     = A_bar_t * h_{t-1} + B_t x_t^T                          [N x d_head per head]
 *   y_t     = C_t^T h_t + D_h * x_t
 *
 * The sequence is split into chunks of `chunk_len` time steps. The carry-over
 * state `h` is passed between chunks via the state_carry buffer.
 *
 * Dispatch for ssd_chunk_forward:  (num_chunks, H, B)
 * Dispatch for ssd_chunk_backward: (num_chunks, H, B)
 *
 * Buffer layout (all f32, row-major):
 *   x           : [B, L, D_inner]   where D_inner = H * d_head
 *   B_proj      : [B, L, n_groups, N]
 *   C_proj      : [B, L, n_groups, N]
 *   dt          : [B, L, H]         raw; softplus(dt + dt_bias) is applied in-kernel
 *   A_log       : [H]               unconstrained; the kernels use -A = softplus(A_log)
 *   dt_bias     : [H]
 *   D_vec       : [H]               skip connection per head
 *   out         : [B, L, D_inner]   scan output (written by kernel)
 *   state_carry : [num_chunks+1, B, H, N, d_head]  inter-chunk states (slot 0 is the initial state)
 */
export declare const SSD_FORWARD_WGSL: string;
export declare const SSD_BACKWARD_WGSL: string;
29
+ //# sourceMappingURL=ssd.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"ssd.d.ts","sourceRoot":"","sources":["../../src/kernels/ssd.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;GAyBG;AAEH,eAAO,MAAM,gBAAgB,EAAE,MA+H9B,CAAC;AAIF,eAAO,MAAM,iBAAiB,EAAE,MAuH/B,CAAC"}
@@ -0,0 +1,276 @@
1
/**
 * ssd.ts - Structured State Space Duality (SSD) kernels for Mamba-2.
 *
 * Implements a chunked SSD recurrence:
 *   A_bar_t = exp(-softplus(A_log_h) * softplus(dt_t + dt_bias_h))   [scalar per head]
 *   h_t     = A_bar_t * h_{t-1} + B_t x_t^T                          [N x d_head per head]
 *   y_t[i]  = sum_n C_t[n] * h_t[n, i] + D_h * x_t[i]
 *
 * The sequence is split into chunks of `chunk_len` time steps. Slot c+1 of
 * state_carry receives the state at the end of chunk c. Slot 0 is read as the
 * initial state and must be initialised by the caller (presumably zeros).
 *
 * Dispatch for ssd_chunk_forward:  (num_chunks, H, B). Chunk c depends on the
 *   final state of chunk c-1, so the invocation with x == 0 walks all chunks in
 *   order and the remaining x invocations exit immediately.
 * Dispatch for ssd_chunk_backward: (num_chunks, H, B)
 *
 * Buffer layout (all f32, row-major):
 *   x           : [B, L, D_inner]   where D_inner = H * d_head
 *   B_proj      : [B, L, n_groups, N]
 *   C_proj      : [B, L, n_groups, N]
 *   dt          : [B, L, H]         raw; softplus(dt + dt_bias) is applied in-kernel
 *   A_log       : [H]               unconstrained; the kernels use -A = softplus(A_log)
 *   dt_bias     : [H]
 *   D_vec       : [H]               skip connection per head
 *   out         : [B, L, D_inner]   scan output (written by kernel)
 *   state_carry : [num_chunks+1, B, H, N, d_head]  inter-chunk states
 */
export const SSD_FORWARD_WGSL = /* wgsl */ `
struct SsdParams {
  seq_len   : u32,
  d_inner   : u32,
  n_heads   : u32,
  d_head    : u32, // d_inner / n_heads
  n_groups  : u32,
  d_state   : u32, // N
  chunk_len : u32,
  n_chunks  : u32,
  batch     : u32,
};

@group(0) @binding(0) var<uniform> params : SsdParams;
@group(0) @binding(1) var<storage, read> x_in : array<f32>;                // [B,L,D_inner]
@group(0) @binding(2) var<storage, read> B_proj : array<f32>;              // [B,L,n_groups,N]
@group(0) @binding(3) var<storage, read> C_proj : array<f32>;              // [B,L,n_groups,N]
@group(0) @binding(4) var<storage, read> dt_in : array<f32>;               // [B,L,H]
@group(0) @binding(5) var<storage, read> A_log : array<f32>;               // [H]
@group(0) @binding(6) var<storage, read> dt_bias : array<f32>;             // [H]
@group(0) @binding(7) var<storage, read> D_vec : array<f32>;               // [H]
@group(0) @binding(8) var<storage, read_write> out_buf : array<f32>;       // [B,L,D_inner]
@group(0) @binding(9) var<storage, read_write> state_carry : array<f32>;   // [n_chunks+1,B,H,N,d_head]

// Numerically stable softplus: log(1 + exp(x)) without overflow for large x.
fn softplus(x: f32) -> f32 {
  return max(x, 0.0) + log(1.0 + exp(-abs(x)));
}

// One invocation per (head, batch item) walks every chunk in order, because
// chunk c reads the carry that chunk c-1 produces. Running chunks in parallel
// within a single dispatch would race on state_carry.
@compute @workgroup_size(1, 1, 1)
fn ssd_chunk_forward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let head_id = gid.y;
  let batch_id = gid.z;

  let L  = params.seq_len;
  let D  = params.d_inner;
  let H  = params.n_heads;
  let dh = params.d_head;
  let G  = params.n_groups;
  let N  = params.d_state;
  let CL = params.chunk_len;
  let NC = params.n_chunks;
  let B  = params.batch;

  if (gid.x != 0u || head_id >= H || batch_id >= B) { return; }

  // Heads are partitioned evenly across groups that share B_proj / C_proj.
  let group_id = head_id * G / H;

  let neg_A  = softplus(A_log[head_id]);   // -A > 0
  let db     = dt_bias[head_id];
  let d_skip = D_vec[head_id];

  // state_carry layout: [NC+1, B, H, N*dh]; state element (n, i) at n*dh + i.
  let head_state = N * dh;
  let chunk_stride = B * H * head_state;
  let head_offset = batch_id * H * head_state + head_id * head_state;

  for (var c: u32 = 0u; c < NC; c = c + 1u) {
    let base_in  = c * chunk_stride + head_offset;
    let base_out = base_in + chunk_stride;

    // Seed this chunk's working state (slot c+1) from its carry-in (slot c).
    for (var s: u32 = 0u; s < head_state; s = s + 1u) {
      state_carry[base_out + s] = state_carry[base_in + s];
    }

    let t_start = c * CL;
    let t_end = min(t_start + CL, L);
    for (var t: u32 = t_start; t < t_end; t = t + 1u) {
      let dt_val = softplus(dt_in[batch_id * L * H + t * H + head_id] + db);
      let a_bar = exp(-neg_A * dt_val);

      // x / out slice for this head: [batch, t, head*dh .. (head+1)*dh]
      let x_base = batch_id * L * D + t * D + head_id * dh;
      // B / C rows for this head's group: [batch, t, group_id, 0..N]
      let bc_base = batch_id * L * G * N + t * G * N + group_id * N;

      for (var i: u32 = 0u; i < dh; i = i + 1u) {
        let x_val = x_in[x_base + i];
        // y_t[i] = sum_n C[n] * h_t[n, i]  (per output channel i)
        var y_i: f32 = 0.0;
        for (var n: u32 = 0u; n < N; n = n + 1u) {
          let s_idx = base_out + n * dh + i;
          let h_new = a_bar * state_carry[s_idx] + B_proj[bc_base + n] * x_val;
          state_carry[s_idx] = h_new;
          y_i = y_i + C_proj[bc_base + n] * h_new;
        }
        out_buf[x_base + i] = y_i + d_skip * x_val;
      }
    }
  }
}
`;
155
// ── Backward ──────────────────────────────────────────────────────────────────
// Per-chunk backward pass for ssd_chunk_forward.
// Dispatch: (num_chunks, H, B), one invocation per (chunk, head, batch item).
//
// NOTE(review): this computes a local, truncated gradient, not exact BPTT:
//   * h_t is read from state slot chunk_id+1, which ssd_chunk_forward leaves
//     holding the state at the END of the chunk, and h_{t-1} from slot
//     chunk_id (the state at the START of the chunk). Both are exact only when
//     chunk_len == 1.
//   * dLoss/dh_t is taken as C_t * dy_t only; it is not propagated back through
//     A_bar to earlier time steps or chunks.
//   * D_vec is not bound here, so the skip path adds dy to dx (i.e. assumes D = 1).
//   * dB/dC (shared by all heads of a group when n_groups < n_heads) and
//     dD_vec/dA_log (shared by every chunk and batch item) are accumulated by
//     concurrent invocations with non-atomic read-modify-write, so updates can
//     be lost.
// Fixing these needs extra bindings/buffers, i.e. a host-side change.
export const SSD_BACKWARD_WGSL = /* wgsl */ `
struct SsdParams {
  seq_len   : u32,
  d_inner   : u32,
  n_heads   : u32,
  d_head    : u32,
  n_groups  : u32,
  d_state   : u32,
  chunk_len : u32,
  n_chunks  : u32,
  batch     : u32,
};

@group(0) @binding(0)  var<uniform> params : SsdParams;
@group(0) @binding(1)  var<storage, read> x_in : array<f32>;
@group(0) @binding(2)  var<storage, read> B_proj : array<f32>;
@group(0) @binding(3)  var<storage, read> C_proj : array<f32>;
@group(0) @binding(4)  var<storage, read> dt_in : array<f32>;
@group(0) @binding(5)  var<storage, read> A_log : array<f32>;
@group(0) @binding(6)  var<storage, read> dt_bias : array<f32>;
@group(0) @binding(7)  var<storage, read> state_carry : array<f32>; // forward states
@group(0) @binding(8)  var<storage, read> dy : array<f32>;          // upstream grad
@group(0) @binding(9)  var<storage, read_write> dx : array<f32>;
@group(0) @binding(10) var<storage, read_write> dB : array<f32>;
@group(0) @binding(11) var<storage, read_write> dC : array<f32>;
@group(0) @binding(12) var<storage, read_write> ddt : array<f32>;
@group(0) @binding(13) var<storage, read_write> dA_log : array<f32>;
@group(0) @binding(14) var<storage, read_write> dD_vec : array<f32>;

// Numerically stable softplus (must match the forward kernel).
fn softplus(x: f32) -> f32 {
  return max(x, 0.0) + log(1.0 + exp(-abs(x)));
}

// d/dx softplus(x) = sigmoid(x)
fn d_softplus(x: f32) -> f32 {
  return 1.0 / (1.0 + exp(-x));
}

@compute @workgroup_size(1, 1, 1)
fn ssd_chunk_backward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let chunk_id = gid.x;
  let head_id = gid.y;
  let batch_id = gid.z;

  let L  = params.seq_len;
  let D  = params.d_inner;
  let H  = params.n_heads;
  let dh = params.d_head;
  let G  = params.n_groups;
  let N  = params.d_state;
  let CL = params.chunk_len;
  let B  = params.batch;

  let t_start = chunk_id * CL;
  // Guard also prevents t_end - t_start from underflowing for trailing chunks.
  if (t_start >= L || head_id >= H || batch_id >= B) { return; }
  let t_end = min(t_start + CL, L);
  let group_id = head_id * G / H;

  let a_log = A_log[head_id];
  let neg_A = softplus(a_log);             // -A > 0
  let dneg_A_da_log = d_softplus(a_log);   // d(-A)/d(A_log)
  let db = dt_bias[head_id];

  let chunk_stride = B * H * N * dh;
  let head_offset = batch_id * H * N * dh + head_id * N * dh;
  let carry_in_base  = chunk_id * chunk_stride + head_offset;   // chunk start state
  let carry_out_base = carry_in_base + chunk_stride;            // chunk end state

  // Iterate time steps of this chunk in reverse.
  for (var t_rev: u32 = 0u; t_rev < t_end - t_start; t_rev = t_rev + 1u) {
    let t = t_end - 1u - t_rev;

    let dt_idx = batch_id * L * H + t * H + head_id;
    let dt_raw = dt_in[dt_idx] + db;
    let dt_val = softplus(dt_raw);
    let dsp_dt = d_softplus(dt_raw);
    let a_bar = exp(-neg_A * dt_val);

    let x_base = batch_id * L * D + t * D + head_id * dh;
    let bc_base = batch_id * L * G * N + t * G * N + group_id * N;

    for (var i: u32 = 0u; i < dh; i = i + 1u) {
      let dy_val = dy[x_base + i];
      let x_val = x_in[x_base + i];

      // Skip path: y += D * x
      dD_vec[head_id] = dD_vec[head_id] + dy_val * x_val;
      dx[x_base + i] = dx[x_base + i] + dy_val;   // assumes D = 1 (D_vec not bound)

      for (var n: u32 = 0u; n < N; n = n + 1u) {
        let h_val  = state_carry[carry_out_base + n * dh + i];
        let h_prev = state_carry[carry_in_base + n * dh + i];
        let c_val = C_proj[bc_base + n];
        let b_val = B_proj[bc_base + n];

        // dC += dy * h
        dC[bc_base + n] = dC[bc_base + n] + dy_val * h_val;

        // Local dLoss/dh_t = C * dy
        let dh_val = c_val * dy_val;

        // dB += dh * x ;  dx += dh * B
        dB[bc_base + n] = dB[bc_base + n] + dh_val * x_val;
        dx[x_base + i] = dx[x_base + i] + dh_val * b_val;

        // d h_t / d dt_raw = h_prev * d(a_bar)/d(dt_val) * d_softplus(dt_raw)
        //                  = h_prev * a_bar * (-neg_A) * d_softplus(dt_raw)
        ddt[dt_idx] = ddt[dt_idx] + dh_val * h_prev * a_bar * (-neg_A) * dsp_dt;

        // d h_t / d A_log = h_prev * a_bar * (-dt_val) * d_softplus(A_log)
        dA_log[head_id] = dA_log[head_id]
                        + dh_val * h_prev * a_bar * (-dt_val) * dneg_A_da_log;
      }
    }
  }
}
`;
276
+ //# sourceMappingURL=ssd.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"ssd.js","sourceRoot":"","sources":["../../src/kernels/ssd.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;;;;;;;;;;;GAyBG;AAEH,MAAM,CAAC,MAAM,gBAAgB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CA+HjD,CAAC;AAEF,iFAAiF;AAEjF,MAAM,CAAC,MAAM,iBAAiB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAuHlD,CAAC"}
@@ -0,0 +1,3 @@
1
/**
 * WGSL shader source for the weight-update kernel, exported as a plain string
 * (same convention as the other `*_WGSL` kernel exports in this package).
 * NOTE(review): kernel semantics live in weight_update.ts; confirm there.
 */
export declare const WEIGHT_UPDATE_WGSL: string;
/**
 * WGSL shader source for the gradient-clipping kernel, exported as a plain string.
 * NOTE(review): clipping rule (norm vs. value) is defined in weight_update.ts; confirm there.
 */
export declare const GRAD_CLIP_WGSL: string;
3
+ //# sourceMappingURL=weight_update.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"weight_update.d.ts","sourceRoot":"","sources":["../../src/kernels/weight_update.ts"],"names":[],"mappings":"AAUA,eAAO,MAAM,kBAAkB,EAAE,MAgDhC,CAAC;AAIF,eAAO,MAAM,cAAc,EAAE,MAyD5B,CAAC"}