mambacode.js 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. package/README.md +198 -76
  2. package/dist/index.d.ts +19 -0
  3. package/dist/index.d.ts.map +1 -0
  4. package/dist/index.js +18 -0
  5. package/dist/index.js.map +1 -0
  6. package/dist/kernels/activations.d.ts +3 -0
  7. package/dist/kernels/activations.d.ts.map +1 -0
  8. package/dist/kernels/activations.js +87 -0
  9. package/dist/kernels/activations.js.map +1 -0
  10. package/dist/kernels/conv1d.d.ts +3 -0
  11. package/dist/kernels/conv1d.d.ts.map +1 -0
  12. package/dist/kernels/conv1d.js +152 -0
  13. package/dist/kernels/conv1d.js.map +1 -0
  14. package/dist/kernels/linear_projection.d.ts +3 -0
  15. package/dist/kernels/linear_projection.d.ts.map +1 -0
  16. package/dist/kernels/linear_projection.js +219 -0
  17. package/dist/kernels/linear_projection.js.map +1 -0
  18. package/dist/kernels/selective_scan.d.ts +3 -0
  19. package/dist/kernels/selective_scan.d.ts.map +1 -0
  20. package/dist/kernels/selective_scan.js +348 -0
  21. package/dist/kernels/selective_scan.js.map +1 -0
  22. package/dist/kernels/weight_update.d.ts +3 -0
  23. package/dist/kernels/weight_update.d.ts.map +1 -0
  24. package/dist/kernels/weight_update.js +119 -0
  25. package/dist/kernels/weight_update.js.map +1 -0
  26. package/dist/model/mamba_block.d.ts +64 -0
  27. package/dist/model/mamba_block.d.ts.map +1 -0
  28. package/dist/model/mamba_block.js +309 -0
  29. package/dist/model/mamba_block.js.map +1 -0
  30. package/dist/model/mamba_model.d.ts +66 -0
  31. package/dist/model/mamba_model.d.ts.map +1 -0
  32. package/dist/model/mamba_model.js +289 -0
  33. package/dist/model/mamba_model.js.map +1 -0
  34. package/dist/tokenizer/bpe.d.ts +29 -0
  35. package/dist/tokenizer/bpe.d.ts.map +1 -0
  36. package/dist/tokenizer/bpe.js +164 -0
  37. package/dist/tokenizer/bpe.js.map +1 -0
  38. package/dist/training/autograd.d.ts +27 -0
  39. package/dist/training/autograd.d.ts.map +1 -0
  40. package/dist/training/autograd.js +120 -0
  41. package/dist/training/autograd.js.map +1 -0
  42. package/dist/training/trainer.d.ts +37 -0
  43. package/dist/training/trainer.d.ts.map +1 -0
  44. package/dist/training/trainer.js +183 -0
  45. package/dist/training/trainer.js.map +1 -0
  46. package/dist/utils/gpu_utils.d.ts +21 -0
  47. package/dist/utils/gpu_utils.d.ts.map +1 -0
  48. package/dist/utils/gpu_utils.js +111 -0
  49. package/dist/utils/gpu_utils.js.map +1 -0
  50. package/dist/utils/quantization.d.ts +26 -0
  51. package/dist/utils/quantization.d.ts.map +1 -0
  52. package/dist/utils/quantization.js +116 -0
  53. package/dist/utils/quantization.js.map +1 -0
  54. package/package.json +43 -18
  55. package/src/index.ts +61 -0
  56. package/src/kernels/{activations.js → activations.ts} +2 -2
  57. package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
  58. package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
  59. package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
  60. package/src/model/{mamba_block.js → mamba_block.ts} +134 -170
  61. package/src/model/{mamba_model.js → mamba_model.ts} +165 -121
  62. package/src/tokenizer/bpe.ts +186 -0
  63. package/src/training/autograd.ts +135 -0
  64. package/src/training/{trainer.js → trainer.ts} +79 -161
  65. package/src/utils/gpu_utils.ts +147 -0
  66. package/src/utils/quantization.ts +154 -0
  67. package/src/index.js +0 -89
  68. package/src/tokenizer/bpe.js +0 -256
  69. package/src/training/autograd.js +0 -221
  70. package/src/utils/gpu_utils.js +0 -217
  71. package/src/utils/quantization.js +0 -215
  72. package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
package/src/model/{mamba_block.js → mamba_block.ts}

@@ -1,19 +1,5 @@
 /**
- * mamba_block.js – Mamba Mixer Block
- *
- * Implements one complete Mamba residual layer:
- *
- * x ──► Norm ──► Linear up (×2, for z-gate) ──► Conv1D ──► SiLU ──► Scan ──► × z ──► Linear down ──► + x
- *
- * Components (all dispatched as WebGPU compute passes):
- *   1. RMSNorm
- *   2. Linear up-projection: (D_model → 2 × D_inner)
- *   3. 1D Causal Convolution (depthwise, kernel_size=4)
- *   4. SiLU activation
- *   5. Selective Scan (S6 core)
- *   6. Gated multiplication: y * SiLU(z)
- *   7. Linear down-projection: (D_inner → D_model)
- *   8. Residual add
+ * mamba_block.ts – Mamba Mixer Block
  */
 
 import {
@@ -31,87 +17,126 @@ import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
 import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
 import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
 
-/**
- * @typedef {Object} MambaBlockConfig
- * @property {number} dModel – model dimension (embedding size)
- * @property {number} dState – SSM state dimension (N, default 16)
- * @property {number} dConv – 1D conv kernel size (default 4)
- * @property {number} expand – expansion factor (default 2) → dInner = expand * dModel
- * @property {number} dtRank – rank of Δ projection (default ceil(dModel/16))
- * @property {boolean} [biasConv] – use bias in conv (default true)
- */
+export interface MambaBlockConfig {
+  dModel: number;
+  dState?: number;
+  dConv?: number;
+  expand?: number;
+  dtRank?: number;
+  biasConv?: boolean;
+}
+
+export interface BlockParam {
+  buf: GPUBuffer;
+  numel: number;
+  name: string;
+}
+
+export interface BlockCache {
+  normInv: GPUBuffer;
+  normIn: GPUBuffer;
+  normOut: GPUBuffer;
+  zBuf: GPUBuffer;
+  xConvIn: GPUBuffer;
+  convOut: GPUBuffer;
+  siluOut: GPUBuffer;
+  deltaFull: GPUBuffer;
+  B_raw: GPUBuffer;
+  C_raw: GPUBuffer;
+  hCache: GPUBuffer;
+}
+
+export interface BlockForwardResult {
+  output: GPUBuffer;
+  cache: BlockCache;
+}
 
 export class MambaBlock {
-  /**
-   * @param {GPUDevice} device
-   * @param {MambaBlockConfig} config
-   */
-  constructor(device, config) {
+  device: GPUDevice;
+  config: Required<MambaBlockConfig>;
+  dInner: number;
+  dtRank: number;
+  wInProj: Float32Array;
+  bInProj: Float32Array;
+  wConv: Float32Array;
+  bConv: Float32Array;
+  wXProj: Float32Array;
+  bXProj: Float32Array;
+  wDtProj: Float32Array;
+  bDtProj: Float32Array;
+  A_log: Float32Array;
+  D_vec: Float32Array;
+  wOutProj: Float32Array;
+  bOutProj: Float32Array;
+  normWeight: Float32Array;
+  gpuWeights: Record<string, GPUBuffer>;
+  pipelines: Record<string, GPUComputePipeline>;
+  private _wslaMode = false;
+
+  constructor(device: GPUDevice, config: MambaBlockConfig) {
     this.device = device;
     this.config = {
       dState  : 16,
       dConv   : 4,
       expand  : 2,
       biasConv: true,
+      dtRank  : Math.ceil(config.dModel / 16),
       ...config,
-    };
+    } as Required<MambaBlockConfig>;
 
-    const { dModel, dState, dConv, expand } = this.config;
+    const { dModel, expand } = this.config;
     this.dInner = expand * dModel;
-    this.dtRank = this.config.dtRank ?? Math.ceil(dModel / 16);
+    this.dtRank = config.dtRank ?? Math.ceil(dModel / 16);
+
+    // Initialize these before _initWeights so TypeScript is happy
+    this.wInProj = new Float32Array(0);
+    this.bInProj = new Float32Array(0);
+    this.wConv = new Float32Array(0);
+    this.bConv = new Float32Array(0);
+    this.wXProj = new Float32Array(0);
+    this.bXProj = new Float32Array(0);
+    this.wDtProj = new Float32Array(0);
+    this.bDtProj = new Float32Array(0);
+    this.A_log = new Float32Array(0);
+    this.D_vec = new Float32Array(0);
+    this.wOutProj = new Float32Array(0);
+    this.bOutProj = new Float32Array(0);
+    this.normWeight = new Float32Array(0);
+    this.gpuWeights = {};
+    this.pipelines = {};
 
-    // ---- Initialise learnable parameters (CPU → GPU) ----
     this._initWeights();
-
-    // ---- Compile GPU pipelines (once) ----
     this._buildPipelines();
   }
 
-  // ─── Weight initialisation ────────────────────────────────────────────────
-
-  _initWeights() {
+  private _initWeights(): void {
     const { dModel, dState, dConv } = this.config;
     const D = this.dInner;
     const N = dState;
     const K = dConv;
     const R = this.dtRank;
 
-    const randn = (n, std = 0.02) => {
+    const randn = (n: number, std = 0.02): Float32Array => {
       const a = new Float32Array(n);
       for (let i = 0; i < n; i++) {
-        // Box-Muller
        const u1 = Math.random(), u2 = Math.random();
        a[i] = std * Math.sqrt(-2 * Math.log(u1 + 1e-12)) * Math.cos(2 * Math.PI * u2);
      }
      return a;
    };
 
-    const zeros = (n) => new Float32Array(n);
-    const ones = (n) => new Float32Array(n).fill(1.0);
-    const linspace = (n) => {
-      const a = new Float32Array(n);
-      for (let i = 0; i < n; i++) a[i] = i;
-      return a;
-    };
+    const zeros = (n: number): Float32Array => new Float32Array(n);
+    const ones = (n: number): Float32Array => new Float32Array(n).fill(1.0);
 
-    // in_proj: (2*D_inner, D_model) – up-projection (and z gate)
     this.wInProj = randn(2 * D * dModel);
     this.bInProj = zeros(2 * D);
-
-    // conv1d: weight (D_inner, K), bias (D_inner,)
     this.wConv = randn(D * K, 0.01);
     this.bConv = zeros(D);
-
-    // x_proj: (dt_rank + 2*N, D_inner) – projects x to Δ, B, C
     this.wXProj = randn((R + 2 * N) * D, 0.01);
     this.bXProj = zeros(R + 2 * N);
-
-    // dt_proj: (D_inner, dt_rank) – projects Δ to full D_inner width
     this.wDtProj = randn(D * R, 0.02);
     this.bDtProj = zeros(D);
 
-    // A: (D_inner, N) – log-space negative eigenvalues
-    // Initialised to log(range(1, N+1)) per HiPPO theory
     this.A_log = new Float32Array(D * N);
     for (let d = 0; d < D; d++) {
       for (let n = 0; n < N; n++) {
@@ -119,23 +144,17 @@ export class MambaBlock {
       }
     }
 
-    // D: (D_inner,) – skip connection scale (initialised to 1)
     this.D_vec = ones(D);
-
-    // out_proj: (D_model, D_inner) – down-projection
     this.wOutProj = randn(dModel * D, 0.02);
     this.bOutProj = zeros(dModel);
-
-    // RMSNorm scale: (D_model,)
     this.normWeight = ones(dModel);
 
-    // Upload all to GPU
     this._uploadWeightsToGPU();
   }
 
-  _uploadWeightsToGPU() {
+  private _uploadWeightsToGPU(): void {
     const d = this.device;
-    const mk = (arr, readable = true) => createStorageBuffer(d, arr, readable);
+    const mk = (arr: Float32Array, readable = true): GPUBuffer => createStorageBuffer(d, arr, readable);
 
     this.gpuWeights = {
       wInProj : mk(this.wInProj),
@@ -154,9 +173,7 @@ export class MambaBlock {
     };
   }
 
-  // ─── Pipeline compilation ─────────────────────────────────────────────────
-
-  _buildPipelines() {
+  private _buildPipelines(): void {
     const d = this.device;
 
     this.pipelines = {
@@ -169,19 +186,7 @@ export class MambaBlock {
     };
   }
 
-  // ─── Forward pass ─────────────────────────────────────────────────────────
-
-  /**
-   * Run the Mamba block forward pass on GPU.
-   *
-   * @param {GPUBuffer} xBuf – input (batch * seqLen, dModel)
-   * @param {number} batch
-   * @param {number} seqLen
-   * @returns {{ output: GPUBuffer, cache: Object }}
-   *   output – (batch * seqLen, dModel)
-   *   cache – intermediate buffers needed for backward pass
-   */
-  forward(xBuf, batch, seqLen) {
+  forward(xBuf: GPUBuffer, batch: number, seqLen: number): BlockForwardResult {
     const d = this.device;
     const { dModel, dState, dConv } = this.config;
     const D = this.dInner;
@@ -191,45 +196,37 @@ export class MambaBlock {
     const M = B * L;
     const R = this.dtRank;
 
-    // Intermediate buffers (will be freed after backward or cached)
-    const cache = {};
+    const cache = {} as BlockCache;
 
-    // 1. RMSNorm: (M, dModel)
     const normOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
     const normInv = createEmptyStorageBuffer(d, M * 4, true);
     cache.normInv = normInv;
     cache.normIn = xBuf;
 
     {
-      // Pack params as Uint32 (num_rows, dim) + f32 (eps) ← 12 bytes padded to 16
       const params = new ArrayBuffer(16);
       new Uint32Array(params, 0, 2).set([M, dModel]);
       new Float32Array(params, 8, 1).set([1e-6]);
       const pBuf = createUniformBuffer(d, params);
 
-      const bg = createBindGroup(d, this.pipelines.rmsnorm,
-        [pBuf, xBuf, this.gpuWeights.normWeight, normOut, normInv]);
-      dispatchKernel(d, this.pipelines.rmsnorm, bg, [cdiv(M, 64), 1, 1]);
+      const bg = createBindGroup(d, this.pipelines['rmsnorm']!,
+        [pBuf, xBuf, this.gpuWeights['normWeight']!, normOut, normInv]);
+      dispatchKernel(d, this.pipelines['rmsnorm']!, bg, [cdiv(M, 64), 1, 1]);
     }
 
-    // 2. in_proj: (M, 2*D) = normOut @ wInProj^T + bInProj
     const inProjOut = createEmptyStorageBuffer(d, M * 2 * D * 4, true);
     cache.normOut = normOut;
     {
       const params = new Uint32Array([M, dModel, 2 * D]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.linear,
-        [pBuf, normOut, this.gpuWeights.wInProj, this.gpuWeights.bInProj, inProjOut]);
-      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(2 * D, 16), 1]);
+      const bg = createBindGroup(d, this.pipelines['linear']!,
+        [pBuf, normOut, this.gpuWeights['wInProj']!, this.gpuWeights['bInProj']!, inProjOut]);
+      dispatchKernel(d, this.pipelines['linear']!, bg, [cdiv(M, 16), cdiv(2 * D, 16), 1]);
     }
 
-    // Split inProjOut into x (M, D) and z (M, D) – the z-gate
-    // We reuse the same buffer with offsets since WGSL bindings can be offset.
-    // For simplicity, allocate two separate buffers and copy.
     const xConvIn = createEmptyStorageBuffer(d, M * D * 4, true);
     const zBuf = createEmptyStorageBuffer(d, M * D * 4, true);
     {
-      // Copy first D columns into xConvIn, last D columns into zBuf
       const enc = d.createCommandEncoder();
       enc.copyBufferToBuffer(inProjOut, 0, xConvIn, 0, M * D * 4);
       enc.copyBufferToBuffer(inProjOut, M * D * 4, zBuf, 0, M * D * 4);
@@ -237,39 +234,35 @@ export class MambaBlock {
     }
     cache.zBuf = zBuf;
 
-    // 3. Conv1D on xConvIn: (B, L, D) – depthwise causal conv
     const convOut = createEmptyStorageBuffer(d, M * D * 4, true);
     cache.xConvIn = xConvIn;
     {
       const params = new Uint32Array([L, D, dConv, B]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.conv1d,
-        [pBuf, xConvIn, this.gpuWeights.wConv, this.gpuWeights.bConv, convOut]);
-      dispatchKernel(d, this.pipelines.conv1d, bg, [cdiv(L, 16), cdiv(D, 16), B]);
+      const bg = createBindGroup(d, this.pipelines['conv1d']!,
+        [pBuf, xConvIn, this.gpuWeights['wConv']!, this.gpuWeights['bConv']!, convOut]);
+      dispatchKernel(d, this.pipelines['conv1d']!, bg, [cdiv(L, 16), cdiv(D, 16), B]);
     }
 
-    // 4. SiLU(convOut) in-place
     const siluOut = createEmptyStorageBuffer(d, M * D * 4, true);
     cache.convOut = convOut;
     {
       const params = new Uint32Array([M * D]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.silu,
+      const bg = createBindGroup(d, this.pipelines['silu']!,
         [pBuf, convOut, siluOut]);
-      dispatchKernel(d, this.pipelines.silu, bg, [cdiv(M * D, 256), 1, 1]);
+      dispatchKernel(d, this.pipelines['silu']!, bg, [cdiv(M * D, 256), 1, 1]);
     }
 
-    // 5. x_proj: (M, R+2N) = siluOut @ wXProj^T + bXProj
     const xProjOut = createEmptyStorageBuffer(d, M * (R + 2 * N) * 4, true);
     {
       const params = new Uint32Array([M, D, R + 2 * N]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.linear,
-        [pBuf, siluOut, this.gpuWeights.wXProj, this.gpuWeights.bXProj, xProjOut]);
-      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(R + 2 * N, 16), 1]);
+      const bg = createBindGroup(d, this.pipelines['linear']!,
+        [pBuf, siluOut, this.gpuWeights['wXProj']!, this.gpuWeights['bXProj']!, xProjOut]);
+      dispatchKernel(d, this.pipelines['linear']!, bg, [cdiv(M, 16), cdiv(R + 2 * N, 16), 1]);
    }
 
-    // Split xProjOut → dtRaw (M, R), B_raw (M*N flattened) = (B, L, N), C_raw (B, L, N)
     const dtRaw = createEmptyStorageBuffer(d, M * R * 4, true);
     const B_raw = createEmptyStorageBuffer(d, B * L * N * 4, true);
     const C_raw = createEmptyStorageBuffer(d, B * L * N * 4, true);
@@ -281,18 +274,15 @@ export class MambaBlock {
       d.queue.submit([enc.finish()]);
     }
 
-    // 6. dt_proj: (M, D) = dtRaw @ wDtProj^T + bDtProj
     const deltaFull = createEmptyStorageBuffer(d, M * D * 4, true);
     {
       const params = new Uint32Array([M, R, D]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.linear,
-        [pBuf, dtRaw, this.gpuWeights.wDtProj, this.gpuWeights.bDtProj, deltaFull]);
-      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(D, 16), 1]);
+      const bg = createBindGroup(d, this.pipelines['linear']!,
+        [pBuf, dtRaw, this.gpuWeights['wDtProj']!, this.gpuWeights['bDtProj']!, deltaFull]);
+      dispatchKernel(d, this.pipelines['linear']!, bg, [cdiv(M, 16), cdiv(D, 16), 1]);
     }
 
-    // 7. Selective Scan
-    // Allocate y (B, L, D) and h_cache (2 * B*L*D*N) – first half for h, second for y_partial
     const scanY = createEmptyStorageBuffer(d, B * L * D * 4, true);
     const hCache = createEmptyStorageBuffer(d, 2 * B * L * D * N * 4, true);
     cache.siluOut = siluOut;
@@ -305,34 +295,28 @@ export class MambaBlock {
       const params = new Uint32Array([L, N, D, B]).buffer;
       const pBuf = createUniformBuffer(d, params);
 
-      // forward_scan pass
-      const bg = createBindGroup(d, this.pipelines.scan_fwd,
-        [pBuf, siluOut, deltaFull, this.gpuWeights.A_log, B_raw, C_raw,
-         this.gpuWeights.D_vec, scanY, hCache]);
-      dispatchKernel(d, this.pipelines.scan_fwd, bg,
+      const bg = createBindGroup(d, this.pipelines['scan_fwd']!,
+        [pBuf, siluOut, deltaFull, this.gpuWeights['A_log']!, B_raw, C_raw,
+         this.gpuWeights['D_vec']!, scanY, hCache]);
+      dispatchKernel(d, this.pipelines['scan_fwd']!, bg,
         [cdiv(D, 8), cdiv(N, 8), B]);
 
-      // forward_reduce pass (collapses N dim → y)
-      const bg2 = createBindGroup(d, this.pipelines.scan_reduce,
-        [pBuf, siluOut, deltaFull, this.gpuWeights.A_log, B_raw, C_raw,
-         this.gpuWeights.D_vec, scanY, hCache]);
-      dispatchKernel(d, this.pipelines.scan_reduce, bg2,
+      const bg2 = createBindGroup(d, this.pipelines['scan_reduce']!,
+        [pBuf, siluOut, deltaFull, this.gpuWeights['A_log']!, B_raw, C_raw,
+         this.gpuWeights['D_vec']!, scanY, hCache]);
+      dispatchKernel(d, this.pipelines['scan_reduce']!, bg2,
        [cdiv(L, 64), D, B]);
     }
 
-    // 8. Gate: scanY *= SiLU(zBuf) – element-wise product
     const siluZ = createEmptyStorageBuffer(d, M * D * 4, true);
     const gatedOut = createEmptyStorageBuffer(d, M * D * 4, true);
     {
-      // SiLU(z)
       const params = new Uint32Array([M * D]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.silu,
+      const bg = createBindGroup(d, this.pipelines['silu']!,
        [pBuf, zBuf, siluZ]);
-      dispatchKernel(d, this.pipelines.silu, bg, [cdiv(M * D, 256), 1, 1]);
+      dispatchKernel(d, this.pipelines['silu']!, bg, [cdiv(M * D, 256), 1, 1]);
 
-      // Element-wise multiply scanY * siluZ → gatedOut
-      // We encode this as a trivial compute pass using a small inline shader.
      const mulShader = /* wgsl */`
        @group(0) @binding(0) var<storage, read> a : array<f32>;
        @group(0) @binding(1) var<storage, read> b : array<f32>;
@@ -351,17 +335,15 @@ export class MambaBlock {
      dispatchKernel(d, mulPipeline, bgMul, [cdiv(M * D, 256), 1, 1]);
    }
 
-    // 9. out_proj: (M, dModel) = gatedOut @ wOutProj^T + bOutProj
    const outProjOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    {
      const params = new Uint32Array([M, D, dModel]).buffer;
      const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.linear,
-        [pBuf, gatedOut, this.gpuWeights.wOutProj, this.gpuWeights.bOutProj, outProjOut]);
-      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(dModel, 16), 1]);
+      const bg = createBindGroup(d, this.pipelines['linear']!,
+        [pBuf, gatedOut, this.gpuWeights['wOutProj']!, this.gpuWeights['bOutProj']!, outProjOut]);
+      dispatchKernel(d, this.pipelines['linear']!, bg, [cdiv(M, 16), cdiv(dModel, 16), 1]);
    }
 
-    // 10. Residual add: output = outProjOut + x
    const output = createEmptyStorageBuffer(d, M * dModel * 4, true);
    {
      const addShader = /* wgsl */`
@@ -385,11 +367,7 @@ export class MambaBlock {
    return { output, cache };
  }
 
-  /**
-   * Return a list of all parameter GPU buffers (for the optimizer).
-   * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
-   */
-  parameters() {
+  parameters(): BlockParam[] {
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
@@ -397,45 +375,31 @@ export class MambaBlock {
    const R = this.dtRank;
 
    return [
-      { buf: this.gpuWeights.wInProj,  numel: 2 * D * dModel, name: 'wInProj' },
-      { buf: this.gpuWeights.bInProj,  numel: 2 * D,          name: 'bInProj' },
-      { buf: this.gpuWeights.wConv,    numel: D * K,          name: 'wConv' },
-      { buf: this.gpuWeights.bConv,    numel: D,              name: 'bConv' },
-      { buf: this.gpuWeights.wXProj,   numel: (R + 2*N) * D,  name: 'wXProj' },
-      { buf: this.gpuWeights.bXProj,   numel: R + 2 * N,      name: 'bXProj' },
-      { buf: this.gpuWeights.wDtProj,  numel: D * R,          name: 'wDtProj' },
-      { buf: this.gpuWeights.bDtProj,  numel: D,              name: 'bDtProj' },
-      { buf: this.gpuWeights.A_log,    numel: D * N,          name: 'A_log' },
-      { buf: this.gpuWeights.D_vec,    numel: D,              name: 'D_vec' },
-      { buf: this.gpuWeights.wOutProj, numel: dModel * D,     name: 'wOutProj' },
-      { buf: this.gpuWeights.bOutProj, numel: dModel,         name: 'bOutProj' },
-      { buf: this.gpuWeights.normWeight, numel: dModel,       name: 'normWeight'},
+      { buf: this.gpuWeights['wInProj']!,  numel: 2 * D * dModel, name: 'wInProj' },
+      { buf: this.gpuWeights['bInProj']!,  numel: 2 * D,          name: 'bInProj' },
+      { buf: this.gpuWeights['wConv']!,    numel: D * K,          name: 'wConv' },
+      { buf: this.gpuWeights['bConv']!,    numel: D,              name: 'bConv' },
+      { buf: this.gpuWeights['wXProj']!,   numel: (R + 2*N) * D,  name: 'wXProj' },
+      { buf: this.gpuWeights['bXProj']!,   numel: R + 2 * N,      name: 'bXProj' },
+      { buf: this.gpuWeights['wDtProj']!,  numel: D * R,          name: 'wDtProj' },
+      { buf: this.gpuWeights['bDtProj']!,  numel: D,              name: 'bDtProj' },
+      { buf: this.gpuWeights['A_log']!,    numel: D * N,          name: 'A_log' },
+      { buf: this.gpuWeights['D_vec']!,    numel: D,              name: 'D_vec' },
+      { buf: this.gpuWeights['wOutProj']!, numel: dModel * D,     name: 'wOutProj' },
+      { buf: this.gpuWeights['bOutProj']!, numel: dModel,         name: 'bOutProj' },
+      { buf: this.gpuWeights['normWeight']!, numel: dModel,       name: 'normWeight'},
    ];
  }
 
-  /**
-   * WSLA (Weight-Selective Local Adaptation) mode.
-   * Freezes all parameters except the B and C matrices (wXProj slice).
-   * This allows rapid local adaptation with minimal compute.
-   *
-   * @param {boolean} enabled
-   */
-  setWSLAMode(enabled) {
+  setWSLAMode(enabled: boolean): void {
    this._wslaMode = enabled;
-    // Mark which parameters receive gradients
-    // (The trainer checks this.getTrainableParams() during backward)
  }
 
-  /**
-   * Returns only the trainable parameters under WSLA mode.
-   * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
-   */
-  getTrainableParams() {
+  getTrainableParams(): BlockParam[] {
    if (this._wslaMode) {
-      // Only B and C portions of wXProj
      return [
-        { buf: this.gpuWeights.wXProj, numel: this.wXProj.length, name: 'wXProj' },
-        { buf: this.gpuWeights.bXProj, numel: this.bXProj.length, name: 'bXProj' },
+        { buf: this.gpuWeights['wXProj']!, numel: this.wXProj.length, name: 'wXProj' },
+        { buf: this.gpuWeights['bXProj']!, numel: this.bXProj.length, name: 'bXProj' },
      ];
    }
    return this.parameters();
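
For orientation: the scan_fwd / scan_reduce passes above compute the usual selective-SSM recurrence, roughly h_t = exp(Δ_t A) · h_{t-1} + Δ_t B_t x_t followed by y_t = C_t · h_t + D ⊙ x_t, with the SiLU(z) gate applied to y afterwards. Below is a minimal usage sketch of the typed API this diff introduces. It assumes mambacode.js re-exports MambaBlock from src/index.ts and that the host runtime exposes WebGPU via navigator.gpu; neither is confirmed by the diff itself, and the input-buffer setup is illustrative only.

// Usage sketch (not part of the package diff): one forward pass on GPU.
import { MambaBlock } from 'mambacode.js';

async function runBlockOnce(): Promise<GPUBuffer> {
  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) throw new Error('WebGPU unavailable');
  const device = await adapter.requestDevice();

  // Omitted config fields fall back to the defaults spread in the
  // constructor: dState 16, dConv 4, expand 2, biasConv true,
  // dtRank ceil(dModel / 16).
  const block = new MambaBlock(device, { dModel: 64 });

  const batch = 2, seqLen = 32, dModel = 64;
  // forward() expects a flat (batch * seqLen, dModel) f32 storage buffer.
  const xBuf = device.createBuffer({
    size: batch * seqLen * dModel * 4,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
  });
  device.queue.writeBuffer(xBuf, 0, new Float32Array(batch * seqLen * dModel));

  // output is a (batch * seqLen, dModel) GPUBuffer; cache keeps the
  // intermediates (normInv, zBuf, hCache, ...) for a later backward pass.
  const { output, cache } = block.forward(xBuf, batch, seqLen);
  void cache;
  return output;
}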