mambacode.js 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,443 @@
1
+ /**
2
+ * mamba_block.js – Mamba Mixer Block
3
+ *
4
+ * Implements one complete Mamba residual layer:
5
+ *
6
+ * x ──► Norm ──► Linear up (×2, for z-gate) ──► Conv1D ──► SiLU ──► Scan ──► × z ──► Linear down ──► + x
7
+ *
8
+ * Components (all dispatched as WebGPU compute passes):
9
+ * 1. RMSNorm
10
+ * 2. Linear up-projection: (D_model → 2 × D_inner)
11
+ * 3. 1D Causal Convolution (depthwise, kernel_size=4)
12
+ * 4. SiLU activation
13
+ * 5. Selective Scan (S6 core)
14
+ * 6. Gated multiplication: y * SiLU(z)
15
+ * 7. Linear down-projection: (D_inner → D_model)
16
+ * 8. Residual add
17
+ */
18
+
19
+ import {
20
+ createComputePipeline,
21
+ createBindGroup,
22
+ createStorageBuffer,
23
+ createEmptyStorageBuffer,
24
+ createUniformBuffer,
25
+ dispatchKernel,
26
+ cdiv,
27
+ } from '../utils/gpu_utils.js';
28
+
29
+ import { SELECTIVE_SCAN_FORWARD_WGSL } from '../kernels/selective_scan.js';
30
+ import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
31
+ import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
32
+ import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
33
+
34
+ /**
35
+ * @typedef {Object} MambaBlockConfig
36
+ * @property {number} dModel – model dimension (embedding size)
37
+ * @property {number} dState – SSM state dimension (N, default 16)
38
+ * @property {number} dConv – 1D conv kernel size (default 4)
39
+ * @property {number} expand – expansion factor (default 2) → dInner = expand * dModel
40
+ * @property {number} dtRank – rank of Δ projection (default ceil(dModel/16))
41
+ * @property {boolean} [biasConv] – use bias in conv (default true)
42
+ */
43
+
44
export class MambaBlock {
  /**
   * One residual Mamba layer. Compiles all compute pipelines once in the
   * constructor; `forward()` then only allocates activations and dispatches.
   *
   * @param {GPUDevice} device - WebGPU device used for every buffer/pipeline.
   * @param {MambaBlockConfig} config - layer hyper-parameters; `dModel` is
   *   required, the rest fall back to the defaults spread below.
   */
  constructor(device, config) {
    this.device = device;
    // User config wins over defaults (spread last).
    this.config = {
      dState : 16,
      dConv : 4,
      expand : 2,
      biasConv: true, // NOTE(review): this flag is never consulted below — bConv is always created and bound. Confirm intent.
      ...config,
    };

    const { dModel, dState, dConv, expand } = this.config;
    // dInner = expand * dModel is the width of the inner (SSM) residual stream.
    this.dInner = expand * dModel;
    // Low-rank Δ projection; default rank follows the reference ceil(dModel/16).
    this.dtRank = this.config.dtRank ?? Math.ceil(dModel / 16);

    // ---- Initialise learnable parameters (CPU → GPU) ----
    this._initWeights();

    // ---- Compile GPU pipelines (once) ----
    this._buildPipelines();
  }

  // ─── Weight initialisation ────────────────────────────────────────────────

  /**
   * Allocates and initialises all learnable tensors as Float32Arrays on the
   * CPU, then uploads them to GPU storage buffers via _uploadWeightsToGPU().
   * Shapes follow the reference Mamba layout (weight matrices stored as
   * (out_features, in_features), applied as x @ W^T + b by the linear kernel).
   */
  _initWeights() {
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const R = this.dtRank;

    // Gaussian init via Box-Muller; `u1 + 1e-12` guards log(0) when
    // Math.random() returns exactly 0.
    const randn = (n, std = 0.02) => {
      const a = new Float32Array(n);
      for (let i = 0; i < n; i++) {
        // Box-Muller
        const u1 = Math.random(), u2 = Math.random();
        a[i] = std * Math.sqrt(-2 * Math.log(u1 + 1e-12)) * Math.cos(2 * Math.PI * u2);
      }
      return a;
    };

    const zeros = (n) => new Float32Array(n);
    const ones = (n) => new Float32Array(n).fill(1.0);
    // NOTE(review): `linspace` is defined but never used in this method.
    const linspace = (n) => {
      const a = new Float32Array(n);
      for (let i = 0; i < n; i++) a[i] = i;
      return a;
    };

    // in_proj: (2*D_inner, D_model) – up-projection (and z gate)
    this.wInProj = randn(2 * D * dModel);
    this.bInProj = zeros(2 * D);

    // conv1d: weight (D_inner, K), bias (D_inner,)
    this.wConv = randn(D * K, 0.01);
    this.bConv = zeros(D);

    // x_proj: (dt_rank + 2*N, D_inner) – projects x to Δ, B, C
    this.wXProj = randn((R + 2 * N) * D, 0.01);
    this.bXProj = zeros(R + 2 * N);

    // dt_proj: (D_inner, dt_rank) – projects Δ to full D_inner width
    this.wDtProj = randn(D * R, 0.02);
    this.bDtProj = zeros(D);

    // A: (D_inner, N) – log-space negative eigenvalues
    // Initialised to log(range(1, N+1)) per HiPPO theory
    this.A_log = new Float32Array(D * N);
    for (let d = 0; d < D; d++) {
      for (let n = 0; n < N; n++) {
        this.A_log[d * N + n] = Math.log(n + 1);
      }
    }

    // D: (D_inner,) – skip connection scale (initialised to 1)
    this.D_vec = ones(D);

    // out_proj: (D_model, D_inner) – down-projection
    this.wOutProj = randn(dModel * D, 0.02);
    this.bOutProj = zeros(dModel);

    // RMSNorm scale: (D_model,)
    this.normWeight = ones(dModel);

    // Upload all to GPU
    this._uploadWeightsToGPU();
  }

  /**
   * Copies every CPU-side weight array into a GPU storage buffer and keeps the
   * handles in `this.gpuWeights` (same key names as the CPU fields).
   * The `readable = true` flag is forwarded to createStorageBuffer —
   * presumably it enables copy-back for optimizer/debug reads; confirm against
   * gpu_utils.js.
   */
  _uploadWeightsToGPU() {
    const d = this.device;
    const mk = (arr, readable = true) => createStorageBuffer(d, arr, readable);

    this.gpuWeights = {
      wInProj : mk(this.wInProj),
      bInProj : mk(this.bInProj),
      wConv : mk(this.wConv),
      bConv : mk(this.bConv),
      wXProj : mk(this.wXProj),
      bXProj : mk(this.bXProj),
      wDtProj : mk(this.wDtProj),
      bDtProj : mk(this.bDtProj),
      A_log : mk(this.A_log),
      D_vec : mk(this.D_vec),
      wOutProj : mk(this.wOutProj),
      bOutProj : mk(this.bOutProj),
      normWeight: mk(this.normWeight),
    };
  }

  // ─── Pipeline compilation ─────────────────────────────────────────────────

  /**
   * Compiles one compute pipeline per kernel entry point. Done once at
   * construction; `forward()` reuses these handles on every call.
   * NOTE(review): the element-wise mul/add shaders used in forward() are NOT
   * compiled here — they are recompiled on every forward() call (see below).
   */
  _buildPipelines() {
    const d = this.device;

    this.pipelines = {
      linear : createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
      conv1d : createComputePipeline(d, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
      silu : createComputePipeline(d, ACTIVATIONS_WGSL, 'silu_forward'),
      rmsnorm : createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
      scan_fwd : createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_scan'),
      scan_reduce: createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_reduce'),
    };
  }

  // ─── Forward pass ─────────────────────────────────────────────────────────

  /**
   * Run the Mamba block forward pass on GPU.
   *
   * Dispatch order (must not be reordered — each pass consumes the previous
   * pass's output buffer): RMSNorm → in_proj → split x/z → conv1d → SiLU →
   * x_proj → split Δ/B/C → dt_proj → selective scan (2 passes) → SiLU(z) gate
   * multiply → out_proj → residual add.
   *
   * NOTE(review): all intermediate buffers are allocated fresh per call and
   * never destroyed here; the caller (trainer) is presumably responsible for
   * releasing `cache` buffers after backward — confirm, or this leaks GPU
   * memory under repeated inference.
   *
   * @param {GPUBuffer} xBuf – input (batch * seqLen, dModel)
   * @param {number} batch
   * @param {number} seqLen
   * @returns {{ output: GPUBuffer, cache: Object }}
   * output – (batch * seqLen, dModel)
   * cache – intermediate buffers needed for backward pass
   */
  forward(xBuf, batch, seqLen) {
    const d = this.device;
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const B = batch;
    const L = seqLen;
    const M = B * L; // total token count: every (M, ·) buffer is token-major
    const R = this.dtRank;

    // Intermediate buffers (will be freed after backward or cached)
    const cache = {};

    // 1. RMSNorm: (M, dModel)
    const normOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    const normInv = createEmptyStorageBuffer(d, M * 4, true); // per-row 1/rms, saved for backward
    cache.normInv = normInv;
    cache.normIn = xBuf;

    {
      // Pack params as Uint32 (num_rows, dim) + f32 (eps) ← 12 bytes padded to 16
      const params = new ArrayBuffer(16);
      new Uint32Array(params, 0, 2).set([M, dModel]);
      new Float32Array(params, 8, 1).set([1e-6]);
      const pBuf = createUniformBuffer(d, params);

      const bg = createBindGroup(d, this.pipelines.rmsnorm,
        [pBuf, xBuf, this.gpuWeights.normWeight, normOut, normInv]);
      dispatchKernel(d, this.pipelines.rmsnorm, bg, [cdiv(M, 64), 1, 1]);
    }

    // 2. in_proj: (M, 2*D) = normOut @ wInProj^T + bInProj
    const inProjOut = createEmptyStorageBuffer(d, M * 2 * D * 4, true);
    cache.normOut = normOut;
    {
      // NOTE(review): this 12-byte uniform (and the similar ones below) relies
      // on createUniformBuffer handling any padding/alignment the WGSL param
      // struct needs — confirm against gpu_utils.js (the RMSNorm pass above
      // pads to 16 bytes explicitly).
      const params = new Uint32Array([M, dModel, 2 * D]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.linear,
        [pBuf, normOut, this.gpuWeights.wInProj, this.gpuWeights.bInProj, inProjOut]);
      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(2 * D, 16), 1]);
    }

    // Split inProjOut into x (M, D) and z (M, D) – the z-gate
    // We reuse the same buffer with offsets since WGSL bindings can be offset.
    // For simplicity, allocate two separate buffers and copy.
    const xConvIn = createEmptyStorageBuffer(d, M * D * 4, true);
    const zBuf = createEmptyStorageBuffer(d, M * D * 4, true);
    {
      // Copy first D columns into xConvIn, last D columns into zBuf
      // NOTE(review): these contiguous copies split the buffer at byte
      // M*D*4, i.e. halfway along the ROW axis if inProjOut is row-major
      // (M, 2*D) as documented above — not a column split. This is only a
      // column split if the linear kernel writes its output transposed
      // (out-feature-major). Verify LINEAR_FORWARD_WGSL's output layout; if it
      // is row-major, x and z are interleaved per row and this needs a strided
      // copy instead.
      const enc = d.createCommandEncoder();
      enc.copyBufferToBuffer(inProjOut, 0, xConvIn, 0, M * D * 4);
      enc.copyBufferToBuffer(inProjOut, M * D * 4, zBuf, 0, M * D * 4);
      d.queue.submit([enc.finish()]);
    }
    cache.zBuf = zBuf;

    // 3. Conv1D on xConvIn: (B, L, D) – depthwise causal conv
    const convOut = createEmptyStorageBuffer(d, M * D * 4, true);
    cache.xConvIn = xConvIn;
    {
      const params = new Uint32Array([L, D, dConv, B]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.conv1d,
        [pBuf, xConvIn, this.gpuWeights.wConv, this.gpuWeights.bConv, convOut]);
      dispatchKernel(d, this.pipelines.conv1d, bg, [cdiv(L, 16), cdiv(D, 16), B]);
    }

    // 4. SiLU(convOut) in-place
    // NOTE(review): despite the comment, this is NOT in-place — it writes to a
    // fresh siluOut buffer (convOut is kept in cache for backward).
    const siluOut = createEmptyStorageBuffer(d, M * D * 4, true);
    cache.convOut = convOut;
    {
      const params = new Uint32Array([M * D]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.silu,
        [pBuf, convOut, siluOut]);
      dispatchKernel(d, this.pipelines.silu, bg, [cdiv(M * D, 256), 1, 1]);
    }

    // 5. x_proj: (M, R+2N) = siluOut @ wXProj^T + bXProj
    const xProjOut = createEmptyStorageBuffer(d, M * (R + 2 * N) * 4, true);
    {
      const params = new Uint32Array([M, D, R + 2 * N]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.linear,
        [pBuf, siluOut, this.gpuWeights.wXProj, this.gpuWeights.bXProj, xProjOut]);
      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(R + 2 * N, 16), 1]);
    }

    // Split xProjOut → dtRaw (M, R), B_raw (M*N flattened) = (B, L, N), C_raw (B, L, N)
    // NOTE(review): same layout concern as the in_proj split above — these
    // contiguous copies are column slices only if xProjOut is stored
    // out-feature-major; row-major (M, R+2N) would interleave Δ/B/C per row.
    const dtRaw = createEmptyStorageBuffer(d, M * R * 4, true);
    const B_raw = createEmptyStorageBuffer(d, B * L * N * 4, true);
    const C_raw = createEmptyStorageBuffer(d, B * L * N * 4, true);
    {
      const enc = d.createCommandEncoder();
      enc.copyBufferToBuffer(xProjOut, 0, dtRaw, 0, M * R * 4);
      enc.copyBufferToBuffer(xProjOut, M * R * 4, B_raw, 0, B * L * N * 4);
      enc.copyBufferToBuffer(xProjOut, M * (R + N) * 4, C_raw, 0, B * L * N * 4);
      d.queue.submit([enc.finish()]);
    }

    // 6. dt_proj: (M, D) = dtRaw @ wDtProj^T + bDtProj
    const deltaFull = createEmptyStorageBuffer(d, M * D * 4, true);
    {
      const params = new Uint32Array([M, R, D]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.linear,
        [pBuf, dtRaw, this.gpuWeights.wDtProj, this.gpuWeights.bDtProj, deltaFull]);
      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(D, 16), 1]);
    }

    // 7. Selective Scan
    // Allocate y (B, L, D) and h_cache (2 * B*L*D*N) – first half for h, second for y_partial
    const scanY = createEmptyStorageBuffer(d, B * L * D * 4, true);
    const hCache = createEmptyStorageBuffer(d, 2 * B * L * D * N * 4, true);
    cache.siluOut = siluOut;
    cache.deltaFull = deltaFull;
    cache.B_raw = B_raw;
    cache.C_raw = C_raw;
    cache.hCache = hCache;

    {
      const params = new Uint32Array([L, N, D, B]).buffer;
      const pBuf = createUniformBuffer(d, params);

      // forward_scan pass
      // Both scan passes share the same bind layout; the second (reduce) pass
      // collapses the N state dimension into y. Their ordering matters:
      // forward_reduce reads the hCache that forward_scan fills.
      const bg = createBindGroup(d, this.pipelines.scan_fwd,
        [pBuf, siluOut, deltaFull, this.gpuWeights.A_log, B_raw, C_raw,
          this.gpuWeights.D_vec, scanY, hCache]);
      dispatchKernel(d, this.pipelines.scan_fwd, bg,
        [cdiv(D, 8), cdiv(N, 8), B]);

      // forward_reduce pass (collapses N dim → y)
      const bg2 = createBindGroup(d, this.pipelines.scan_reduce,
        [pBuf, siluOut, deltaFull, this.gpuWeights.A_log, B_raw, C_raw,
          this.gpuWeights.D_vec, scanY, hCache]);
      dispatchKernel(d, this.pipelines.scan_reduce, bg2,
        [cdiv(L, 64), D, B]);
    }

    // 8. Gate: scanY *= SiLU(zBuf) – element-wise product
    const siluZ = createEmptyStorageBuffer(d, M * D * 4, true);
    const gatedOut = createEmptyStorageBuffer(d, M * D * 4, true);
    {
      // SiLU(z)
      const params = new Uint32Array([M * D]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.silu,
        [pBuf, zBuf, siluZ]);
      dispatchKernel(d, this.pipelines.silu, bg, [cdiv(M * D, 256), 1, 1]);

      // Element-wise multiply scanY * siluZ → gatedOut
      // We encode this as a trivial compute pass using a small inline shader.
      // NOTE(review): this pipeline (and the add pipeline in step 10) is
      // recompiled on EVERY forward() call; hoisting both into
      // _buildPipelines() would avoid repeated shader compilation.
      const mulShader = /* wgsl */`
        @group(0) @binding(0) var<storage, read> a : array<f32>;
        @group(0) @binding(1) var<storage, read> b : array<f32>;
        @group(0) @binding(2) var<storage, read_write> c : array<f32>;
        @group(0) @binding(3) var<uniform> n : u32;
        @compute @workgroup_size(256)
        fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
          let i = gid.x;
          if (i < n) { c[i] = a[i] * b[i]; }
        }
      `;
      const mulPipeline = createComputePipeline(d, mulShader, 'main');
      const nBuf = createUniformBuffer(d, new Uint32Array([M * D]).buffer);
      const bgMul = createBindGroup(d, mulPipeline,
        [scanY, siluZ, gatedOut, nBuf]);
      dispatchKernel(d, mulPipeline, bgMul, [cdiv(M * D, 256), 1, 1]);
    }

    // 9. out_proj: (M, dModel) = gatedOut @ wOutProj^T + bOutProj
    const outProjOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    {
      const params = new Uint32Array([M, D, dModel]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.linear,
        [pBuf, gatedOut, this.gpuWeights.wOutProj, this.gpuWeights.bOutProj, outProjOut]);
      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(dModel, 16), 1]);
    }

    // 10. Residual add: output = outProjOut + x
    const output = createEmptyStorageBuffer(d, M * dModel * 4, true);
    {
      const addShader = /* wgsl */`
        @group(0) @binding(0) var<storage, read> a : array<f32>;
        @group(0) @binding(1) var<storage, read> b : array<f32>;
        @group(0) @binding(2) var<storage, read_write> c : array<f32>;
        @group(0) @binding(3) var<uniform> n : u32;
        @compute @workgroup_size(256)
        fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
          let i = gid.x;
          if (i < n) { c[i] = a[i] + b[i]; }
        }
      `;
      const addPipeline = createComputePipeline(d, addShader, 'main');
      const nBuf = createUniformBuffer(d, new Uint32Array([M * dModel]).buffer);
      const bgAdd = createBindGroup(d, addPipeline,
        [outProjOut, xBuf, output, nBuf]);
      dispatchKernel(d, addPipeline, bgAdd, [cdiv(M * dModel, 256), 1, 1]);
    }

    return { output, cache };
  }

  /**
   * Return a list of all parameter GPU buffers (for the optimizer).
   * Element counts are recomputed from the config so they stay consistent
   * with the allocation sizes in _initWeights().
   * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
   */
  parameters() {
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const R = this.dtRank;

    return [
      { buf: this.gpuWeights.wInProj, numel: 2 * D * dModel, name: 'wInProj' },
      { buf: this.gpuWeights.bInProj, numel: 2 * D, name: 'bInProj' },
      { buf: this.gpuWeights.wConv, numel: D * K, name: 'wConv' },
      { buf: this.gpuWeights.bConv, numel: D, name: 'bConv' },
      { buf: this.gpuWeights.wXProj, numel: (R + 2*N) * D, name: 'wXProj' },
      { buf: this.gpuWeights.bXProj, numel: R + 2 * N, name: 'bXProj' },
      { buf: this.gpuWeights.wDtProj, numel: D * R, name: 'wDtProj' },
      { buf: this.gpuWeights.bDtProj, numel: D, name: 'bDtProj' },
      { buf: this.gpuWeights.A_log, numel: D * N, name: 'A_log' },
      { buf: this.gpuWeights.D_vec, numel: D, name: 'D_vec' },
      { buf: this.gpuWeights.wOutProj, numel: dModel * D, name: 'wOutProj' },
      { buf: this.gpuWeights.bOutProj, numel: dModel, name: 'bOutProj' },
      { buf: this.gpuWeights.normWeight, numel: dModel, name: 'normWeight'},
    ];
  }

  /**
   * WSLA (Weight-Selective Local Adaptation) mode.
   * Freezes all parameters except the B and C matrices (wXProj slice).
   * This allows rapid local adaptation with minimal compute.
   *
   * Setting the flag here has no GPU-side effect by itself; it only changes
   * what getTrainableParams() reports to the trainer.
   *
   * @param {boolean} enabled
   */
  setWSLAMode(enabled) {
    this._wslaMode = enabled;
    // Mark which parameters receive gradients
    // (The trainer checks this.getTrainableParams() during backward)
  }

  /**
   * Returns only the trainable parameters under WSLA mode.
   * NOTE(review): the comment below says "only B and C portions", but the
   * ENTIRE wXProj/bXProj tensors are returned (numel = full length), which
   * also includes the Δ (dt_rank) rows — confirm whether sub-tensor
   * granularity was intended.
   * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
   */
  getTrainableParams() {
    if (this._wslaMode) {
      // Only B and C portions of wXProj
      return [
        { buf: this.gpuWeights.wXProj, numel: this.wXProj.length, name: 'wXProj' },
        { buf: this.gpuWeights.bXProj, numel: this.bXProj.length, name: 'bXProj' },
      ];
    }
    return this.parameters();
  }
}