mambacode.js 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. package/README.md +198 -76
  2. package/dist/index.d.ts +19 -0
  3. package/dist/index.d.ts.map +1 -0
  4. package/dist/index.js +18 -0
  5. package/dist/index.js.map +1 -0
  6. package/dist/kernels/activations.d.ts +3 -0
  7. package/dist/kernels/activations.d.ts.map +1 -0
  8. package/dist/kernels/activations.js +87 -0
  9. package/dist/kernels/activations.js.map +1 -0
  10. package/dist/kernels/conv1d.d.ts +3 -0
  11. package/dist/kernels/conv1d.d.ts.map +1 -0
  12. package/dist/kernels/conv1d.js +152 -0
  13. package/dist/kernels/conv1d.js.map +1 -0
  14. package/dist/kernels/linear_projection.d.ts +3 -0
  15. package/dist/kernels/linear_projection.d.ts.map +1 -0
  16. package/dist/kernels/linear_projection.js +219 -0
  17. package/dist/kernels/linear_projection.js.map +1 -0
  18. package/dist/kernels/selective_scan.d.ts +3 -0
  19. package/dist/kernels/selective_scan.d.ts.map +1 -0
  20. package/dist/kernels/selective_scan.js +348 -0
  21. package/dist/kernels/selective_scan.js.map +1 -0
  22. package/dist/kernels/weight_update.d.ts +3 -0
  23. package/dist/kernels/weight_update.d.ts.map +1 -0
  24. package/dist/kernels/weight_update.js +119 -0
  25. package/dist/kernels/weight_update.js.map +1 -0
  26. package/dist/model/mamba_block.d.ts +64 -0
  27. package/dist/model/mamba_block.d.ts.map +1 -0
  28. package/dist/model/mamba_block.js +309 -0
  29. package/dist/model/mamba_block.js.map +1 -0
  30. package/dist/model/mamba_model.d.ts +66 -0
  31. package/dist/model/mamba_model.d.ts.map +1 -0
  32. package/dist/model/mamba_model.js +289 -0
  33. package/dist/model/mamba_model.js.map +1 -0
  34. package/dist/tokenizer/bpe.d.ts +29 -0
  35. package/dist/tokenizer/bpe.d.ts.map +1 -0
  36. package/dist/tokenizer/bpe.js +164 -0
  37. package/dist/tokenizer/bpe.js.map +1 -0
  38. package/dist/training/autograd.d.ts +27 -0
  39. package/dist/training/autograd.d.ts.map +1 -0
  40. package/dist/training/autograd.js +120 -0
  41. package/dist/training/autograd.js.map +1 -0
  42. package/dist/training/trainer.d.ts +37 -0
  43. package/dist/training/trainer.d.ts.map +1 -0
  44. package/dist/training/trainer.js +183 -0
  45. package/dist/training/trainer.js.map +1 -0
  46. package/dist/utils/gpu_utils.d.ts +21 -0
  47. package/dist/utils/gpu_utils.d.ts.map +1 -0
  48. package/dist/utils/gpu_utils.js +111 -0
  49. package/dist/utils/gpu_utils.js.map +1 -0
  50. package/dist/utils/quantization.d.ts +26 -0
  51. package/dist/utils/quantization.d.ts.map +1 -0
  52. package/dist/utils/quantization.js +116 -0
  53. package/dist/utils/quantization.js.map +1 -0
  54. package/package.json +43 -18
  55. package/src/index.ts +61 -0
  56. package/src/kernels/{activations.js → activations.ts} +2 -2
  57. package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
  58. package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
  59. package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
  60. package/src/model/{mamba_block.js → mamba_block.ts} +134 -170
  61. package/src/model/{mamba_model.js → mamba_model.ts} +165 -121
  62. package/src/tokenizer/bpe.ts +186 -0
  63. package/src/training/autograd.ts +135 -0
  64. package/src/training/{trainer.js → trainer.ts} +79 -161
  65. package/src/utils/gpu_utils.ts +147 -0
  66. package/src/utils/quantization.ts +154 -0
  67. package/src/index.js +0 -89
  68. package/src/tokenizer/bpe.js +0 -256
  69. package/src/training/autograd.js +0 -221
  70. package/src/utils/gpu_utils.js +0 -217
  71. package/src/utils/quantization.js +0 -215
  72. /package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
@@ -0,0 +1,348 @@
1
+ // Parallel Selective Scan WGSL Kernel
2
+ // Implements the S6 (Selective Scan) core of the Mamba architecture.
3
+ // Uses a Kogge-Stone parallel prefix-sum approach for O(log N) time on GPU.
4
+ //
5
+ // Forward pass recurrence:
6
+ // h_t = A_t * h_{t-1} + B_t * x_t
7
+ // y_t = C_t * h_t + D * x_t
8
+ //
9
+ // where A_t, B_t, C_t are input-dependent (selective) gate matrices.
10
+ export const SELECTIVE_SCAN_FORWARD_WGSL = /* wgsl */ `
11
+
12
+ // ---- Binding layout ----
13
+ // group 0: sequence data
14
+ // group 1: SSM parameters
15
+
16
+ struct ScanParams {
17
+ seq_len : u32, // L – sequence length
18
+ d_state : u32, // N – state dimension
19
+ d_inner : u32, // D – inner (expanded) channel dimension
20
+ batch : u32, // B – batch size
21
+ };
22
+
23
+ @group(0) @binding(0) var<uniform> params : ScanParams;
24
+ // u (B, L, D) – projected input after conv
25
+ @group(0) @binding(1) var<storage, read> u : array<f32>;
26
+ // delta (B, L, D) – time-step (Δ) after softplus
27
+ @group(0) @binding(2) var<storage, read> delta : array<f32>;
28
+ // A (D, N) – log-space diagonal state matrix (fixed, learned)
29
+ @group(0) @binding(3) var<storage, read> A : array<f32>;
30
+ // B (B, L, N) – input projection (selective)
31
+ @group(0) @binding(4) var<storage, read> B : array<f32>;
32
+ // C (B, L, N) – output projection (selective)
33
+ @group(0) @binding(5) var<storage, read> C : array<f32>;
34
+ // D (D,) – skip-connection scale
35
+ @group(0) @binding(6) var<storage, read> D_vec : array<f32>;
36
+ // y (B, L, D) – output (written by this kernel)
37
+ @group(0) @binding(7) var<storage, read_write> y : array<f32>;
38
+ // h_cache (B, L, D*N) – hidden states cache (for backward pass)
39
+ @group(0) @binding(8) var<storage, read_write> h_cache : array<f32>;
40
+
41
+ // ---- Workgroup shared memory ----
42
+ // Each workgroup processes one (batch, channel) slice across all time steps.
43
+ // We store the associative pair (a_bar, bu_bar) per time step so we can run
44
+ // a Kogge-Stone scan across the workgroup tile.
45
+ var<workgroup> wg_a : array<f32, 256>; // discretised A values
46
+ var<workgroup> wg_bu : array<f32, 256>; // B*u values
47
+
48
+ // ---- Helpers ----
49
+
50
+ // Softplus: numerically stable log(1 + exp(x))
51
+ fn softplus(x: f32) -> f32 {
52
+ return log(1.0 + exp(x));
53
+ }
54
+
55
+ // ZerO-Order Hold discretisation of continuous A, Δ:
56
+ // A_bar = exp(Δ * A)
57
+ // B_bar = (A_bar - 1) / A * B ≈ Δ * B (first-order for simplicity)
58
+ fn discretise_A(delta_val: f32, a_log: f32) -> f32 {
59
+ // A is stored as -exp(a_log) to ensure A_bar < 1 (stable)
60
+ let a_cont = -exp(a_log);
61
+ return exp(delta_val * a_cont);
62
+ }
63
+
64
+ fn discretise_B(delta_val: f32, a_log: f32, b_val: f32) -> f32 {
65
+ let a_cont = -exp(a_log);
66
+ let a_bar = exp(delta_val * a_cont);
67
+ // (A_bar - 1) / A_cont * B
68
+ let b_bar = (a_bar - 1.0) / a_cont * b_val;
69
+ return b_bar;
70
+ }
71
+
72
+ // ---- Main kernel ----
73
+ // Dispatch: (ceil(D/8), ceil(N/8), B)
74
+ // Each invocation is responsible for one (d, n, batch) triplet and scans
75
+ // the entire sequence using a two-pass Kogge-Stone scan within workgroup tiles.
76
+
77
+ @compute @workgroup_size(64, 1, 1)
78
+ fn forward_scan(
79
+ @builtin(global_invocation_id) gid : vec3<u32>,
80
+ @builtin(local_invocation_index) lid : u32,
81
+ @builtin(workgroup_id) wgid : vec3<u32>,
82
+ ) {
83
+ let L = params.seq_len;
84
+ let N = params.d_state;
85
+ let D = params.d_inner;
86
+ let B = params.batch;
87
+
88
+ // Each workgroup handles one (batch b, channel d, state n) combination.
89
+ // We pack d and n into the x dimension: global d = wgid.x, global n = wgid.y
90
+ let d = wgid.x;
91
+ let n = wgid.y;
92
+ let b = gid.z;
93
+
94
+ if (d >= D || n >= N || b >= B) { return; }
95
+
96
+ // Tile size equals workgroup size (64). We process TILE_SIZE steps at once.
97
+ let TILE: u32 = 64u;
98
+
99
+ // Running state h for this (b, d, n)
100
+ var h: f32 = 0.0;
101
+
102
+ var tile_start: u32 = 0u;
103
+ loop {
104
+ if (tile_start >= L) { break; }
105
+
106
+ let t = tile_start + lid; // absolute time step handled by this lane
107
+ var a_bar: f32 = 1.0;
108
+ var bu: f32 = 0.0;
109
+
110
+ if (t < L) {
111
+ // Indices
112
+ let delta_idx = b * L * D + t * D + d;
113
+ let u_idx = b * L * D + t * D + d;
114
+ let A_idx = d * N + n;
115
+ let B_idx = b * L * N + t * N + n;
116
+
117
+ let dv = softplus(delta[delta_idx]);
118
+ a_bar = discretise_A(dv, A[A_idx]);
119
+ bu = discretise_B(dv, A[A_idx], B[B_idx]) * u[u_idx];
120
+ }
121
+
122
+ wg_a[lid] = a_bar;
123
+ wg_bu[lid] = bu;
124
+ workgroupBarrier();
125
+
126
+ // ---- Kogge-Stone inclusive prefix scan within tile ----
127
+ // Associative operator: (a1, b1) ∘ (a2, b2) = (a1*a2, a1*b2 + b1)
128
+ // This computes cumulative state recurrence in log2(TILE) steps.
129
+ var stride: u32 = 1u;
130
+ loop {
131
+ if (stride >= TILE) { break; }
132
+ if (lid >= stride) {
133
+ let prev_a = wg_a[lid - stride];
134
+ let prev_bu = wg_bu[lid - stride];
135
+ // Combine: new_state = prev_a * cur_a (product of A_bars)
136
+ // new_bu = prev_a * cur_bu + prev_bu
137
+ let new_a = prev_a * wg_a[lid];
138
+ let new_bu = prev_a * wg_bu[lid] + prev_bu;
139
+ workgroupBarrier();
140
+ wg_a[lid] = new_a;
141
+ wg_bu[lid] = new_bu;
142
+ }
143
+ workgroupBarrier();
144
+ stride = stride << 1u;
145
+ }
146
+
147
+ // Incorporate the carry-in state from the previous tile.
148
+ // After the scan wg_bu[lid] holds the intra-tile inclusive sum.
149
+ // The actual h at position t = h_carry * wg_a[lid] + wg_bu[lid]
150
+ let h_t = h * wg_a[lid] + wg_bu[lid];
151
+
152
+ if (t < L) {
153
+ // Cache hidden state for backward pass
154
+ let h_idx = b * L * D * N + t * D * N + d * N + n;
155
+ h_cache[h_idx] = h_t;
156
+
157
+ // Accumulate y contribution: y_t += C_t[n] * h_t (over all n)
158
+ // We use an atomic-style accumulation: each (d, n) lane adds its
159
+ // contribution to the same y[b, t, d]. This races without atomics,
160
+ // so we instead write to a full h_cache and reduce in a second pass.
161
+ // Here we perform direct accumulation using atomicAdd approximation:
162
+ // (safe because each lane writes a unique n, which is stride 1 in mem)
163
+ let C_idx = b * L * N + t * N + n;
164
+ let y_idx = b * L * D + t * D + d;
165
+
166
+ // Direct write for n == 0 (first state dim), add for the rest.
167
+ // Since all workgroups for the same (b,d) run concurrently we must
168
+ // accumulate safely: we write each partial into h_cache and reduce
169
+ // in a subsequent lightweight kernel (forward_reduce).
170
+ // (For simplicity and correctness here we directly atomically add via
171
+ // f32 emulation – real deployment uses atomicAdd on f32 with spirv ext.)
172
+ // We store C*h contribution separately so forward_reduce can sum them.
173
+ // Layout: y_partial (B, L, D, N) – one slot per state dim
174
+ // y reused as y_partial in this kernel; forward_reduce collapses N dim.
175
+ let y_partial_idx = b * L * D * N + t * D * N + d * N + n;
176
+ // Reuse h_cache second half as y_partial (offset by B*L*D*N)
177
+ let offset = B * L * D * N;
178
+ h_cache[offset + y_partial_idx] = C[C_idx] * h_t;
179
+ }
180
+
181
+ // Update carry: last lane's h_t is the tile's final state
182
+ let last = min(TILE, L - tile_start) - 1u;
183
+ h = wg_a[last] * h + wg_bu[last]; // recombine carry
184
+
185
+ workgroupBarrier();
186
+ tile_start = tile_start + TILE;
187
+ }
188
+ }
189
+
190
+ // ---- Reduction kernel ----
191
+ // Collapses the N (d_state) dimension of y_partial into y.
192
+ // Adds the D (skip connection) term: y_t[d] += D_vec[d] * u_t[d]
193
+ // Dispatch: (ceil(L/64), D, B)
194
+
195
+ @compute @workgroup_size(64, 1, 1)
196
+ fn forward_reduce(
197
+ @builtin(global_invocation_id) gid : vec3<u32>,
198
+ ) {
199
+ let L = params.seq_len;
200
+ let N = params.d_state;
201
+ let D = params.d_inner;
202
+ let B = params.batch;
203
+
204
+ let t = gid.x;
205
+ let d = gid.y;
206
+ let b = gid.z;
207
+
208
+ if (t >= L || d >= D || b >= B) { return; }
209
+
210
+ let offset = B * L * D * N;
211
+ var sum: f32 = 0.0;
212
+ for (var n: u32 = 0u; n < N; n = n + 1u) {
213
+ let idx = offset + b * L * D * N + t * D * N + d * N + n;
214
+ sum = sum + h_cache[idx];
215
+ }
216
+
217
+ // Add skip connection
218
+ let u_idx = b * L * D + t * D + d;
219
+ sum = sum + D_vec[d] * u[u_idx];
220
+
221
+ let y_idx = b * L * D + t * D + d;
222
+ y[y_idx] = sum;
223
+ }
224
+ `;
225
+ // ---- Backward scan kernel (for autograd) ----
226
+ // Computes gradients w.r.t. Δ, A, B, C using the cached hidden states.
227
+ export const SELECTIVE_SCAN_BACKWARD_WGSL = /* wgsl */ `
228
+
229
+ struct ScanParams {
230
+ seq_len : u32,
231
+ d_state : u32,
232
+ d_inner : u32,
233
+ batch : u32,
234
+ };
235
+
236
+ @group(0) @binding(0) var<uniform> params : ScanParams;
237
+ @group(0) @binding(1) var<storage, read> u : array<f32>;
238
+ @group(0) @binding(2) var<storage, read> delta : array<f32>;
239
+ @group(0) @binding(3) var<storage, read> A : array<f32>;
240
+ @group(0) @binding(4) var<storage, read> B : array<f32>;
241
+ @group(0) @binding(5) var<storage, read> C : array<f32>;
242
+ @group(0) @binding(6) var<storage, read> h_cache : array<f32>;
243
+ @group(0) @binding(7) var<storage, read> dy : array<f32>; // upstream gradient
244
+ @group(0) @binding(8) var<storage, read_write> dA : array<f32>;
245
+ @group(0) @binding(9) var<storage, read_write> dB : array<f32>;
246
+ @group(0) @binding(10) var<storage, read_write> dC : array<f32>;
247
+ @group(0) @binding(11) var<storage, read_write> dDelta : array<f32>;
248
+ @group(0) @binding(12) var<storage, read_write> du : array<f32>;
249
+
250
+ fn softplus(x: f32) -> f32 {
251
+ return log(1.0 + exp(x));
252
+ }
253
+
254
+ fn softplus_grad(x: f32) -> f32 {
255
+ // d/dx softplus(x) = sigmoid(x)
256
+ return 1.0 / (1.0 + exp(-x));
257
+ }
258
+
259
+ fn discretise_A(delta_val: f32, a_log: f32) -> f32 {
260
+ let a_cont = -exp(a_log);
261
+ return exp(delta_val * a_cont);
262
+ }
263
+
264
+ // Reverse scan (backward pass) – processes time from T-1 down to 0.
265
+ // Dispatch: (D, N, B)
266
+ @compute @workgroup_size(1, 1, 1)
267
+ fn backward_scan(
268
+ @builtin(global_invocation_id) gid : vec3<u32>,
269
+ ) {
270
+ let L = params.seq_len;
271
+ let N = params.d_state;
272
+ let D = params.d_inner;
273
+ let B = params.batch;
274
+
275
+ let d = gid.x;
276
+ let n = gid.y;
277
+ let b = gid.z;
278
+
279
+ if (d >= D || n >= N || b >= B) { return; }
280
+
281
+ var dh: f32 = 0.0; // gradient of loss w.r.t. h_t, accumulated backwards
282
+
283
+ var t: u32 = L;
284
+ loop {
285
+ if (t == 0u) { break; }
286
+ t = t - 1u;
287
+
288
+ let delta_raw_idx = b * L * D + t * D + d;
289
+ let A_idx = d * N + n;
290
+ let B_idx = b * L * N + t * N + n;
291
+ let C_idx = b * L * N + t * N + n;
292
+ let u_idx = b * L * D + t * D + d;
293
+ let h_idx = b * L * D * N + t * D * N + d * N + n;
294
+
295
+ let delta_raw = delta[delta_raw_idx];
296
+ let dv = softplus(delta_raw);
297
+ let a_log = A[A_idx];
298
+ let a_cont = -exp(a_log);
299
+ let a_bar = exp(dv * a_cont);
300
+ let b_val = B[B_idx];
301
+ let c_val = C[C_idx];
302
+ let u_val = u[u_idx];
303
+ let h_t = h_cache[h_idx];
304
+
305
+ // dy_t contribution to dh (from C * h_t in the output)
306
+ // y_t[d] = sum_n C[n] * h_t[n] + D * u => dh_t[n] += C[n] * dy_t[d]
307
+ let dy_val = dy[b * L * D + t * D + d];
308
+ dh = dh + c_val * dy_val;
309
+
310
+ // dC[b, t, n] += dy_t[d] * h_t
311
+ dC[C_idx] = dC[C_idx] + dy_val * h_t;
312
+
313
+ // h_t = a_bar * h_{t-1} + b_bar * u_t
314
+ // b_bar = (a_bar - 1) / a_cont * b_val
315
+ let b_bar = (a_bar - 1.0) / a_cont * b_val;
316
+ let h_prev = (t > 0u) ? h_cache[b * L * D * N + (t - 1u) * D * N + d * N + n] : 0.0;
317
+
318
+ // dh_{t-1} += a_bar * dh_t
319
+ // (accumulated in next iteration; here dh already contains upstream)
320
+ let dh_cur = dh;
321
+
322
+ // dA[d,n] += dh_t * (d a_bar/d a_cont) * (d a_cont/d a_log) * h_{t-1}
323
+ // + dh_t * (d b_bar/d a_cont) * ... * b_val * u_val
324
+ // d(a_bar)/d(a_log) = a_bar * (-exp(a_log)) * dv = a_bar * a_cont * dv
325
+ let da_bar_da_log = a_bar * a_cont * dv;
326
+ dA[A_idx] = dA[A_idx] + dh_cur * (da_bar_da_log * h_prev);
327
+
328
+ // dB[b,t,n] += dh_t * b_bar / b_val * u_val (since b_bar is linear in b)
329
+ dB[B_idx] = dB[B_idx] + dh_cur * ((a_bar - 1.0) / a_cont) * u_val;
330
+
331
+ // du[b,t,d] += dh_t * b_bar (accumulate over n in separate kernel)
332
+ du[u_idx] = du[u_idx] + dh_cur * b_bar;
333
+
334
+ // dDelta[b,t,d]: chain rule through softplus and discretisation
335
+ // d(b_bar)/d(dv) = d/d(dv)[(a_bar-1)/a_cont * b] = a_bar * b / (a_cont ... )
336
+ // actually: d(a_bar)/d(dv) = a_bar * a_cont, d(b_bar)/d(dv) = a_bar * b_val
337
+ let da_bar_ddv = a_bar * a_cont;
338
+ let db_bar_ddv = a_bar * b_val;
339
+ let dLoss_ddv = dh_cur * (da_bar_ddv * h_prev + db_bar_ddv * u_val);
340
+ let ddv_ddelta = softplus_grad(delta_raw);
341
+ dDelta[delta_raw_idx] = dDelta[delta_raw_idx] + dLoss_ddv * ddv_ddelta;
342
+
343
+ // Propagate dh to previous timestep
344
+ dh = a_bar * dh_cur;
345
+ }
346
+ }
347
+ `;
348
+ //# sourceMappingURL=selective_scan.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"selective_scan.js","sourceRoot":"","sources":["../../src/kernels/selective_scan.ts"],"names":[],"mappings":"AAAA,sCAAsC;AACtC,qEAAqE;AACrE,4EAA4E;AAC5E,EAAE;AACF,2BAA2B;AAC3B,oCAAoC;AACpC,8BAA8B;AAC9B,EAAE;AACF,qEAAqE;AAErE,MAAM,CAAC,MAAM,2BAA2B,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAsN5D,CAAC;AAEF,gDAAgD;AAChD,uEAAuE;AAEvE,MAAM,CAAC,MAAM,4BAA4B,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAwH7D,CAAC"}
@@ -0,0 +1,3 @@
1
+ export declare const WEIGHT_UPDATE_WGSL: string;
2
+ export declare const GRAD_CLIP_WGSL: string;
3
+ //# sourceMappingURL=weight_update.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"weight_update.d.ts","sourceRoot":"","sources":["../../src/kernels/weight_update.ts"],"names":[],"mappings":"AAUA,eAAO,MAAM,kBAAkB,EAAE,MAgDhC,CAAC;AAIF,eAAO,MAAM,cAAc,EAAE,MAyD5B,CAAC"}
@@ -0,0 +1,119 @@
1
+ // Weight Update WGSL Kernel (AdamW Optimizer)
2
+ // Implements fused AdamW parameter update on the GPU.
3
+ //
4
+ // AdamW update rule:
5
+ // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
6
+ // v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
7
+ // m_hat = m_t / (1 - beta1^t)
8
+ // v_hat = v_t / (1 - beta2^t)
9
+ // theta_t = theta_{t-1} * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)
10
+ export const WEIGHT_UPDATE_WGSL = /* wgsl */ `
11
+
12
+ struct AdamParams {
13
+ num_elements : u32,
14
+ lr : f32, // learning rate
15
+ beta1 : f32, // default 0.9
16
+ beta2 : f32, // default 0.999
17
+ eps : f32, // default 1e-8
18
+ weight_decay : f32, // default 0.01
19
+ beta1_t : f32, // beta1^t (precomputed bias correction term)
20
+ beta2_t : f32, // beta2^t
21
+ };
22
+
23
+ @group(0) @binding(0) var<uniform> adam : AdamParams;
24
+ // param (N,) – weight tensor (read-write: updated in-place)
25
+ @group(0) @binding(1) var<storage, read_write> param : array<f32>;
26
+ // grad (N,) – gradient
27
+ @group(0) @binding(2) var<storage, read> grad : array<f32>;
28
+ // m (N,) – first moment
29
+ @group(0) @binding(3) var<storage, read_write> m_state : array<f32>;
30
+ // v (N,) – second moment
31
+ @group(0) @binding(4) var<storage, read_write> v_state : array<f32>;
32
+
33
+ // Dispatch: (ceil(N / 256), 1, 1)
34
+ @compute @workgroup_size(256, 1, 1)
35
+ fn adamw_update(
36
+ @builtin(global_invocation_id) gid : vec3<u32>,
37
+ ) {
38
+ let i = gid.x;
39
+ if (i >= adam.num_elements) { return; }
40
+
41
+ let g = grad[i];
42
+ let p = param[i];
43
+
44
+ // Moment updates
45
+ let m_new = adam.beta1 * m_state[i] + (1.0 - adam.beta1) * g;
46
+ let v_new = adam.beta2 * v_state[i] + (1.0 - adam.beta2) * g * g;
47
+ m_state[i] = m_new;
48
+ v_state[i] = v_new;
49
+
50
+ // Bias-corrected estimates
51
+ let m_hat = m_new / (1.0 - adam.beta1_t);
52
+ let v_hat = v_new / (1.0 - adam.beta2_t);
53
+
54
+ // Weight decay (decoupled) + gradient step
55
+ param[i] = p * (1.0 - adam.lr * adam.weight_decay) -
56
+ adam.lr * m_hat / (sqrt(v_hat) + adam.eps);
57
+ }
58
+ `;
59
+ // Gradient clipping kernel – clips global gradient norm to max_norm.
60
+ // Run before weight updates. Two-pass: first compute squared norm, then scale.
61
+ export const GRAD_CLIP_WGSL = /* wgsl */ `
62
+
63
+ struct ClipParams {
64
+ num_elements : u32,
65
+ max_norm_sq : f32, // max_norm^2
66
+ };
67
+
68
+ @group(0) @binding(0) var<uniform> clip_p : ClipParams;
69
+ @group(0) @binding(1) var<storage, read_write> grad : array<f32>;
70
+ @group(0) @binding(2) var<storage, read_write> norm_sq : array<f32>; // size 1, atomic accumulator
71
+
72
+ var<workgroup> local_sq : array<f32, 256>;
73
+
74
+ // Pass 1: reduce sum of squares into norm_sq[0]
75
+ @compute @workgroup_size(256, 1, 1)
76
+ fn grad_norm_reduce(
77
+ @builtin(global_invocation_id) gid : vec3<u32>,
78
+ @builtin(local_invocation_index) lid : u32,
79
+ ) {
80
+ let i = gid.x;
81
+ local_sq[lid] = 0.0;
82
+ if (i < clip_p.num_elements) {
83
+ local_sq[lid] = grad[i] * grad[i];
84
+ }
85
+ workgroupBarrier();
86
+
87
+ // Parallel reduction within workgroup
88
+ var s: u32 = 128u;
89
+ loop {
90
+ if (s == 0u) { break; }
91
+ if (lid < s) {
92
+ local_sq[lid] = local_sq[lid] + local_sq[lid + s];
93
+ }
94
+ workgroupBarrier();
95
+ s = s >> 1u;
96
+ }
97
+
98
+ if (lid == 0u) {
99
+ // Non-atomic accumulation (single workgroup assumption for small models)
100
+ norm_sq[0] = norm_sq[0] + local_sq[0];
101
+ }
102
+ }
103
+
104
+ // Pass 2: scale gradients if norm exceeds max_norm
105
+ @compute @workgroup_size(256, 1, 1)
106
+ fn grad_clip_scale(
107
+ @builtin(global_invocation_id) gid : vec3<u32>,
108
+ ) {
109
+ let i = gid.x;
110
+ if (i >= clip_p.num_elements) { return; }
111
+
112
+ let ns = norm_sq[0];
113
+ if (ns > clip_p.max_norm_sq) {
114
+ let scale = sqrt(clip_p.max_norm_sq / ns);
115
+ grad[i] = grad[i] * scale;
116
+ }
117
+ }
118
+ `;
119
+ //# sourceMappingURL=weight_update.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"weight_update.js","sourceRoot":"","sources":["../../src/kernels/weight_update.ts"],"names":[],"mappings":"AAAA,8CAA8C;AAC9C,sDAAsD;AACtD,EAAE;AACF,qBAAqB;AACrB,8CAA8C;AAC9C,gDAAgD;AAChD,gCAAgC;AAChC,gCAAgC;AAChC,uFAAuF;AAEvF,MAAM,CAAC,MAAM,kBAAkB,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAgDnD,CAAC;AAEF,qEAAqE;AACrE,gFAAgF;AAChF,MAAM,CAAC,MAAM,cAAc,GAAW,UAAU,CAAA;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAyD/C,CAAC"}
@@ -0,0 +1,64 @@
1
+ /**
2
+ * mamba_block.ts – Mamba Mixer Block
3
+ */
4
+ export interface MambaBlockConfig {
5
+ dModel: number;
6
+ dState?: number;
7
+ dConv?: number;
8
+ expand?: number;
9
+ dtRank?: number;
10
+ biasConv?: boolean;
11
+ }
12
+ export interface BlockParam {
13
+ buf: GPUBuffer;
14
+ numel: number;
15
+ name: string;
16
+ }
17
+ export interface BlockCache {
18
+ normInv: GPUBuffer;
19
+ normIn: GPUBuffer;
20
+ normOut: GPUBuffer;
21
+ zBuf: GPUBuffer;
22
+ xConvIn: GPUBuffer;
23
+ convOut: GPUBuffer;
24
+ siluOut: GPUBuffer;
25
+ deltaFull: GPUBuffer;
26
+ B_raw: GPUBuffer;
27
+ C_raw: GPUBuffer;
28
+ hCache: GPUBuffer;
29
+ }
30
+ export interface BlockForwardResult {
31
+ output: GPUBuffer;
32
+ cache: BlockCache;
33
+ }
34
+ export declare class MambaBlock {
35
+ device: GPUDevice;
36
+ config: Required<MambaBlockConfig>;
37
+ dInner: number;
38
+ dtRank: number;
39
+ wInProj: Float32Array;
40
+ bInProj: Float32Array;
41
+ wConv: Float32Array;
42
+ bConv: Float32Array;
43
+ wXProj: Float32Array;
44
+ bXProj: Float32Array;
45
+ wDtProj: Float32Array;
46
+ bDtProj: Float32Array;
47
+ A_log: Float32Array;
48
+ D_vec: Float32Array;
49
+ wOutProj: Float32Array;
50
+ bOutProj: Float32Array;
51
+ normWeight: Float32Array;
52
+ gpuWeights: Record<string, GPUBuffer>;
53
+ pipelines: Record<string, GPUComputePipeline>;
54
+ private _wslaMode;
55
+ constructor(device: GPUDevice, config: MambaBlockConfig);
56
+ private _initWeights;
57
+ private _uploadWeightsToGPU;
58
+ private _buildPipelines;
59
+ forward(xBuf: GPUBuffer, batch: number, seqLen: number): BlockForwardResult;
60
+ parameters(): BlockParam[];
61
+ setWSLAMode(enabled: boolean): void;
62
+ getTrainableParams(): BlockParam[];
63
+ }
64
+ //# sourceMappingURL=mamba_block.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"mamba_block.d.ts","sourceRoot":"","sources":["../../src/model/mamba_block.ts"],"names":[],"mappings":"AAAA;;GAEG;AAiBH,MAAM,WAAW,gBAAgB;IAC/B,MAAM,EAAE,MAAM,CAAC;IACf,MAAM,CAAC,EAAE,MAAM,CAAC;IAChB,KAAK,CAAC,EAAE,MAAM,CAAC;IACf,MAAM,CAAC,EAAE,MAAM,CAAC;IAChB,MAAM,CAAC,EAAE,MAAM,CAAC;IAChB,QAAQ,CAAC,EAAE,OAAO,CAAC;CACpB;AAED,MAAM,WAAW,UAAU;IACzB,GAAG,EAAE,SAAS,CAAC;IACf,KAAK,EAAE,MAAM,CAAC;IACd,IAAI,EAAE,MAAM,CAAC;CACd;AAED,MAAM,WAAW,UAAU;IACzB,OAAO,EAAE,SAAS,CAAC;IACnB,MAAM,EAAE,SAAS,CAAC;IAClB,OAAO,EAAE,SAAS,CAAC;IACnB,IAAI,EAAE,SAAS,CAAC;IAChB,OAAO,EAAE,SAAS,CAAC;IACnB,OAAO,EAAE,SAAS,CAAC;IACnB,OAAO,EAAE,SAAS,CAAC;IACnB,SAAS,EAAE,SAAS,CAAC;IACrB,KAAK,EAAE,SAAS,CAAC;IACjB,KAAK,EAAE,SAAS,CAAC;IACjB,MAAM,EAAE,SAAS,CAAC;CACnB;AAED,MAAM,WAAW,kBAAkB;IACjC,MAAM,EAAE,SAAS,CAAC;IAClB,KAAK,EAAE,UAAU,CAAC;CACnB;AAED,qBAAa,UAAU;IACnB,MAAM,EAAE,SAAS,CAAC;IAClB,MAAM,EAAE,QAAQ,CAAC,gBAAgB,CAAC,CAAC;IACnC,MAAM,EAAE,MAAM,CAAC;IACf,MAAM,EAAE,MAAM,CAAC;IACf,OAAO,EAAE,YAAY,CAAC;IACtB,OAAO,EAAE,YAAY,CAAC;IACtB,KAAK,EAAE,YAAY,CAAC;IACpB,KAAK,EAAE,YAAY,CAAC;IACpB,MAAM,EAAE,YAAY,CAAC;IACrB,MAAM,EAAE,YAAY,CAAC;IACrB,OAAO,EAAE,YAAY,CAAC;IACtB,OAAO,EAAE,YAAY,CAAC;IACtB,KAAK,EAAE,YAAY,CAAC;IACpB,KAAK,EAAE,YAAY,CAAC;IACpB,QAAQ,EAAE,YAAY,CAAC;IACvB,QAAQ,EAAE,YAAY,CAAC;IACvB,UAAU,EAAE,YAAY,CAAC;IACzB,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,SAAS,CAAC,CAAC;IACtC,SAAS,EAAE,MAAM,CAAC,MAAM,EAAE,kBAAkB,CAAC,CAAC;IAC9C,OAAO,CAAC,SAAS,CAAS;gBAEd,MAAM,EAAE,SAAS,EAAE,MAAM,EAAE,gBAAgB;IAoCvD,OAAO,CAAC,YAAY;IA2CpB,OAAO,CAAC,mBAAmB;IAqB3B,OAAO,CAAC,eAAe;IAavB,OAAO,CAAC,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,kBAAkB;IAqL3E,UAAU,IAAI,UAAU,EAAE;IAwB1B,WAAW,CAAC,OAAO,EAAE,OAAO,GAAG,IAAI;IAInC,kBAAkB,IAAI,UAAU,EAAE;CASrC"}