@seanhogg/builderforce-memory-engine 2026.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +393 -0
  3. package/dist/index.d.ts +32 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +40 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/kernels/activations.d.ts +5 -0
  8. package/dist/kernels/activations.d.ts.map +1 -0
  9. package/dist/kernels/activations.js +171 -0
  10. package/dist/kernels/activations.js.map +1 -0
  11. package/dist/kernels/attention.d.ts +19 -0
  12. package/dist/kernels/attention.d.ts.map +1 -0
  13. package/dist/kernels/attention.js +263 -0
  14. package/dist/kernels/attention.js.map +1 -0
  15. package/dist/kernels/complex_ssd.d.ts +33 -0
  16. package/dist/kernels/complex_ssd.d.ts.map +1 -0
  17. package/dist/kernels/complex_ssd.js +305 -0
  18. package/dist/kernels/complex_ssd.js.map +1 -0
  19. package/dist/kernels/conv1d.d.ts +3 -0
  20. package/dist/kernels/conv1d.d.ts.map +1 -0
  21. package/dist/kernels/conv1d.js +158 -0
  22. package/dist/kernels/conv1d.js.map +1 -0
  23. package/dist/kernels/linear_projection.d.ts +3 -0
  24. package/dist/kernels/linear_projection.d.ts.map +1 -0
  25. package/dist/kernels/linear_projection.js +219 -0
  26. package/dist/kernels/linear_projection.js.map +1 -0
  27. package/dist/kernels/selective_scan.d.ts +3 -0
  28. package/dist/kernels/selective_scan.d.ts.map +1 -0
  29. package/dist/kernels/selective_scan.js +348 -0
  30. package/dist/kernels/selective_scan.js.map +1 -0
  31. package/dist/kernels/ssd.d.ts +29 -0
  32. package/dist/kernels/ssd.d.ts.map +1 -0
  33. package/dist/kernels/ssd.js +276 -0
  34. package/dist/kernels/ssd.js.map +1 -0
  35. package/dist/kernels/weight_update.d.ts +3 -0
  36. package/dist/kernels/weight_update.d.ts.map +1 -0
  37. package/dist/kernels/weight_update.js +119 -0
  38. package/dist/kernels/weight_update.js.map +1 -0
  39. package/dist/model/attention_block.d.ts +48 -0
  40. package/dist/model/attention_block.d.ts.map +1 -0
  41. package/dist/model/attention_block.js +262 -0
  42. package/dist/model/attention_block.js.map +1 -0
  43. package/dist/model/mamba1_block.d.ts +70 -0
  44. package/dist/model/mamba1_block.d.ts.map +1 -0
  45. package/dist/model/mamba1_block.js +333 -0
  46. package/dist/model/mamba1_block.js.map +1 -0
  47. package/dist/model/mamba2_block.d.ts +44 -0
  48. package/dist/model/mamba2_block.d.ts.map +1 -0
  49. package/dist/model/mamba2_block.js +252 -0
  50. package/dist/model/mamba2_block.js.map +1 -0
  51. package/dist/model/mamba3_block.d.ts +51 -0
  52. package/dist/model/mamba3_block.d.ts.map +1 -0
  53. package/dist/model/mamba3_block.js +270 -0
  54. package/dist/model/mamba3_block.js.map +1 -0
  55. package/dist/model/mamba_block.d.ts +64 -0
  56. package/dist/model/mamba_block.d.ts.map +1 -0
  57. package/dist/model/mamba_block.js +303 -0
  58. package/dist/model/mamba_block.js.map +1 -0
  59. package/dist/model/mamba_model.d.ts +140 -0
  60. package/dist/model/mamba_model.d.ts.map +1 -0
  61. package/dist/model/mamba_model.js +527 -0
  62. package/dist/model/mamba_model.js.map +1 -0
  63. package/dist/model/sequence_layer.d.ts +25 -0
  64. package/dist/model/sequence_layer.d.ts.map +1 -0
  65. package/dist/model/sequence_layer.js +8 -0
  66. package/dist/model/sequence_layer.js.map +1 -0
  67. package/dist/tokenizer/bpe.d.ts +29 -0
  68. package/dist/tokenizer/bpe.d.ts.map +1 -0
  69. package/dist/tokenizer/bpe.js +164 -0
  70. package/dist/tokenizer/bpe.js.map +1 -0
  71. package/dist/training/autograd.d.ts +27 -0
  72. package/dist/training/autograd.d.ts.map +1 -0
  73. package/dist/training/autograd.js +120 -0
  74. package/dist/training/autograd.js.map +1 -0
  75. package/dist/training/trainer.d.ts +36 -0
  76. package/dist/training/trainer.d.ts.map +1 -0
  77. package/dist/training/trainer.js +183 -0
  78. package/dist/training/trainer.js.map +1 -0
  79. package/dist/utils/gpu_utils.d.ts +21 -0
  80. package/dist/utils/gpu_utils.d.ts.map +1 -0
  81. package/dist/utils/gpu_utils.js +111 -0
  82. package/dist/utils/gpu_utils.js.map +1 -0
  83. package/dist/utils/quantization.d.ts +26 -0
  84. package/dist/utils/quantization.d.ts.map +1 -0
  85. package/dist/utils/quantization.js +116 -0
  86. package/dist/utils/quantization.js.map +1 -0
  87. package/dist/utils/rng.d.ts +36 -0
  88. package/dist/utils/rng.d.ts.map +1 -0
  89. package/dist/utils/rng.js +61 -0
  90. package/dist/utils/rng.js.map +1 -0
  91. package/package.json +99 -0
  92. package/src/index.ts +114 -0
  93. package/src/kernels/activations.ts +174 -0
  94. package/src/kernels/attention.ts +268 -0
  95. package/src/kernels/complex_ssd.ts +307 -0
  96. package/src/kernels/conv1d.ts +159 -0
  97. package/src/kernels/linear_projection.ts +220 -0
  98. package/src/kernels/selective_scan.ts +350 -0
  99. package/src/kernels/ssd.ts +278 -0
  100. package/src/kernels/weight_update.ts +120 -0
  101. package/src/model/attention_block.ts +344 -0
  102. package/src/model/mamba1_block.ts +437 -0
  103. package/src/model/mamba2_block.ts +319 -0
  104. package/src/model/mamba3_block.ts +335 -0
  105. package/src/model/mamba_block.ts +401 -0
  106. package/src/model/mamba_model.ts +678 -0
  107. package/src/model/sequence_layer.ts +29 -0
  108. package/src/tokenizer/bpe.ts +186 -0
  109. package/src/training/autograd.ts +135 -0
  110. package/src/training/trainer.ts +309 -0
  111. package/src/utils/gpu_utils.ts +147 -0
  112. package/src/utils/quantization.ts +154 -0
  113. package/src/utils/rng.ts +65 -0
@@ -0,0 +1,333 @@
1
+ /**
2
+ * mamba1_block.ts – Mamba-1 Mixer Block (S6 selective scan).
3
+ *
4
+ * Renamed from mamba_block.ts; MambaBlock is kept as a deprecated alias.
5
+ * Implements SequenceLayer so HybridMambaModel can iterate blocks generically.
6
+ */
7
+ import { createComputePipeline, createBindGroup, createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer, dispatchKernel, cdiv, } from '../utils/gpu_utils.js';
8
+ import { SELECTIVE_SCAN_FORWARD_WGSL } from '../kernels/selective_scan.js';
9
+ import { gaussianArray } from '../utils/rng.js';
10
+ import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
11
+ import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
12
+ import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
13
+ // ── Element-wise helper shaders (compiled once per pipeline) ─────────────────
14
// ── Element-wise helper shaders (compiled once per pipeline) ─────────────────
/**
 * Builds the WGSL source of a binary element-wise kernel `c[i] = a[i] <op> b[i]`.
 *
 * Both element-wise shaders share the same interface, so they are generated
 * from one template instead of being maintained as two copies:
 *   binding 0: `a` (read-only storage, f32 array)
 *   binding 1: `b` (read-only storage, f32 array)
 *   binding 2: `c` (read-write storage, f32 array)
 *   binding 3: `n` (uniform u32, number of elements to process)
 * The entry point is `main` with a workgroup size of 256; invocations whose
 * index is >= n do nothing.
 *
 * @param {string} op - WGSL binary operator placed between `a[i]` and `b[i]`.
 * @returns {string} WGSL shader source.
 */
function makeElementwiseShader(op) {
    return /* wgsl */ `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@group(0) @binding(3) var<uniform> n : u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i < n) { c[i] = a[i] ${op} b[i]; }
}
`;
}
/** Element-wise product kernel: c[i] = a[i] * b[i]. */
const MUL_SHADER = makeElementwiseShader('*');
/** Element-wise sum kernel: c[i] = a[i] + b[i]. */
const ADD_SHADER = makeElementwiseShader('+');
36
+ // ── Mamba1Block ───────────────────────────────────────────────────────────────
37
/**
 * Mamba-1 mixer block (S6 selective scan) executed through WebGPU compute
 * pipelines.
 *
 * The block owns CPU copies of its weights (Float32Array fields), the matching
 * GPU storage buffers (`gpuWeights`) and the compiled compute pipelines
 * (`pipelines`). `forward()` runs:
 *   RMSNorm → in_proj → split (x, z) → causal conv1d → SiLU → x_proj →
 *   split (Δ, B, C) → dt_proj → selective scan → gate with SiLU(z) →
 *   out_proj → residual add.
 *
 * Call `destroy()` to release the weight buffers when the block is no longer
 * needed.
 */
export class Mamba1Block {
    /** Layer discriminator string for this block type. */
    layerType = 'mamba1';
    /** Device handle passed to every gpu_utils helper. */
    device;
    /** Resolved configuration (caller options merged over the defaults). */
    config;
    /** Inner width: `expand * dModel`. */
    dInner;
    /** Rank of the Δ (dt) projection. */
    dtRank;
    // CPU-side weight copies; the GPU copies live in `gpuWeights` under the same names.
    wInProj;
    bInProj;
    wConv;
    bConv;
    wXProj;
    bXProj;
    wDtProj;
    bDtProj;
    A_log;
    D_vec;
    wOutProj;
    bOutProj;
    normWeight;
    /** Map from weight name to its GPU storage buffer. */
    gpuWeights;
    /** Map from pipeline key to compiled compute pipeline. */
    pipelines;
    /** When true, getTrainableParams() exposes only the x_proj weights. */
    _wslaMode = false;

    /**
     * @param {*} device - Device handle forwarded to the gpu_utils helpers; it
     *   must also provide `createCommandEncoder()` and `queue.submit()`.
     * @param {object} config - Block options. `dModel` is required. Optional
     *   keys and their defaults: `dState` (16), `dConv` (4), `expand` (2),
     *   `biasConv` (true), `dtRank` (ceil(dModel / 16)). Options passed as
     *   `undefined` or `null` fall back to the default.
     * @throws {TypeError} If `config.dModel` is not a positive integer.
     */
    constructor(device, config) {
        const dModel = config?.dModel;
        if (!Number.isInteger(dModel) || dModel <= 0) {
            throw new TypeError(`Mamba1Block: config.dModel must be a positive integer, got ${dModel}`);
        }
        this.device = device;
        const defaults = {
            dState: 16,
            dConv: 4,
            expand: 2,
            biasConv: true,
            dtRank: Math.ceil(dModel / 16),
        };
        const resolved = { ...defaults, ...config };
        // A plain spread lets an explicit `undefined`/`null` clobber a default
        // (e.g. `dState: undefined` would make every state-sized buffer NaN-long),
        // so nullish entries are restored from the defaults afterwards.
        for (const [key, fallback] of Object.entries(defaults)) {
            resolved[key] ??= fallback;
        }
        this.config = resolved;
        this.dInner = resolved.expand * dModel;
        // Single source of truth: `dtRank` and `config.dtRank` can no longer disagree.
        this.dtRank = resolved.dtRank;
        this._initWeights();
        this._buildPipelines();
    }

    /**
     * Allocates and initialises the CPU-side weights, then uploads them.
     *
     * Projection and convolution weights come from `gaussianArray(n, std)`,
     * biases are zero, `D_vec` and `normWeight` are one, and
     * `A_log[d * N + n] = log(n + 1)` for every channel `d`.
     * The order of the `gaussianArray` calls is kept fixed so that a seeded
     * generator (if the rng module is seeded — not visible here) yields the
     * same weights as before.
     */
    _initWeights() {
        const { dModel, dState: N, dConv: K } = this.config;
        const D = this.dInner;
        const R = this.dtRank;
        const zeros = (count) => new Float32Array(count);
        const ones = (count) => new Float32Array(count).fill(1.0);
        this.wInProj = gaussianArray(2 * D * dModel, 0.02);
        this.bInProj = zeros(2 * D);
        this.wConv = gaussianArray(D * K, 0.01);
        this.bConv = zeros(D);
        this.wXProj = gaussianArray((R + 2 * N) * D, 0.01);
        this.bXProj = zeros(R + 2 * N);
        this.wDtProj = gaussianArray(D * R, 0.02);
        this.bDtProj = zeros(D);
        // Every channel shares the same row log(1), log(2), …, log(N):
        // compute it once and copy it into each of the D rows.
        const logRow = Float32Array.from({ length: N }, (_, n) => Math.log(n + 1));
        this.A_log = new Float32Array(D * N);
        for (let channel = 0; channel < D; channel++) {
            this.A_log.set(logRow, channel * N);
        }
        this.D_vec = ones(D);
        this.wOutProj = gaussianArray(dModel * D, 0.02);
        this.bOutProj = zeros(dModel);
        this.normWeight = ones(dModel);
        this._uploadWeightsToGPU();
    }

    /**
     * Creates one GPU storage buffer per CPU weight array and stores them in
     * `gpuWeights` under the weight's own name. Any previous `gpuWeights`
     * entries are replaced without being destroyed.
     */
    _uploadWeightsToGPU() {
        const weightNames = [
            'wInProj', 'bInProj',
            'wConv', 'bConv',
            'wXProj', 'bXProj',
            'wDtProj', 'bDtProj',
            'A_log', 'D_vec',
            'wOutProj', 'bOutProj',
            'normWeight',
        ];
        this.gpuWeights = Object.fromEntries(
            weightNames.map((name) => [name, createStorageBuffer(this.device, this[name], true)]),
        );
    }

    /** Compiles every compute pipeline used by `forward()`. */
    _buildPipelines() {
        const device = this.device;
        this.pipelines = {
            linear: createComputePipeline(device, LINEAR_FORWARD_WGSL, 'linear_forward'),
            conv1d: createComputePipeline(device, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
            silu: createComputePipeline(device, ACTIVATIONS_WGSL, 'silu_forward'),
            rmsnorm: createComputePipeline(device, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
            scan_fwd: createComputePipeline(device, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_scan'),
            scan_reduce: createComputePipeline(device, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_reduce'),
            elMul: createComputePipeline(device, MUL_SHADER, 'main'),
            elAdd: createComputePipeline(device, ADD_SHADER, 'main'),
        };
    }

    /**
     * Creates a uniform buffer holding the given values as consecutive u32s.
     * The caller is responsible for destroying the returned buffer.
     */
    _u32Uniform(values) {
        return createUniformBuffer(this.device, new Uint32Array(values).buffer);
    }

    /**
     * Dispatches the `linear` pipeline with params `[rows, inDim, outDim]`,
     * reading `input` and the named weight/bias buffers and writing `output`.
     * The parameter uniform is released right after the dispatch.
     */
    _runLinear(input, weightName, biasName, output, rows, inDim, outDim) {
        const pipeline = this.pipelines['linear'];
        const paramBuf = this._u32Uniform([rows, inDim, outDim]);
        const bindGroup = createBindGroup(this.device, pipeline, [
            paramBuf, input, this.gpuWeights[weightName], this.gpuWeights[biasName], output,
        ]);
        dispatchKernel(this.device, pipeline, bindGroup, [cdiv(rows, 16), cdiv(outDim, 16), 1]);
        paramBuf.destroy();
    }

    /** Dispatches the `silu` pipeline over `count` elements of `input` into `output`. */
    _runSilu(input, output, count) {
        const pipeline = this.pipelines['silu'];
        const paramBuf = this._u32Uniform([count]);
        const bindGroup = createBindGroup(this.device, pipeline, [paramBuf, input, output]);
        dispatchKernel(this.device, pipeline, bindGroup, [cdiv(count, 256), 1, 1]);
        paramBuf.destroy();
    }

    /**
     * Dispatches one of the element-wise pipelines (`elMul` / `elAdd`) over
     * `count` elements: `out[i] = lhs[i] <op> rhs[i]`.
     */
    _runElementwise(pipelineKey, lhs, rhs, out, count) {
        const pipeline = this.pipelines[pipelineKey];
        const countBuf = this._u32Uniform([count]);
        const bindGroup = createBindGroup(this.device, pipeline, [lhs, rhs, out, countBuf]);
        dispatchKernel(this.device, pipeline, bindGroup, [cdiv(count, 256), 1, 1]);
        countBuf.destroy();
    }

    /**
     * Runs the block's forward pass.
     *
     * Temporary buffers, including the per-dispatch uniform parameter buffers,
     * are destroyed immediately after the dispatch that consumes them. This
     * follows the pattern already used for the temporary storage buffers and
     * relies on `dispatchKernel` having submitted its work by the time it
     * returns (NOTE(review): confirm against gpu_utils).
     *
     * @param {*} xBuf - Input buffer; bound as the RMSNorm input and as the
     *   residual operand. Sized for `batch * seqLen * dModel` f32 values
     *   (presumably laid out row-major — TODO confirm against the kernels).
     *   It is never destroyed here.
     * @param {number} batch - Batch size (positive integer).
     * @param {number} seqLen - Sequence length (positive integer).
     * @returns {{ output: *, cache: object }} `output` is a new buffer of
     *   `batch * seqLen * dModel` f32 values. `cache` holds the intermediate
     *   buffers `normInv`, `normIn` (= `xBuf`), `normOut`, `zBuf`, `xConvIn`,
     *   `convOut`, `siluOut`, `B_raw`, `C_raw`, `deltaFull`, `hCache`.
     *   This method does not destroy `output` or any cached buffer; the
     *   caller is responsible for them.
     * @throws {RangeError} If `batch` or `seqLen` is not a positive integer.
     */
    forward(xBuf, batch, seqLen) {
        if (!Number.isInteger(batch) || batch <= 0 || !Number.isInteger(seqLen) || seqLen <= 0) {
            throw new RangeError(`Mamba1Block.forward: batch and seqLen must be positive integers, got batch=${batch}, seqLen=${seqLen}`);
        }
        const device = this.device;
        const { dModel, dState: N, dConv } = this.config;
        const D = this.dInner;
        const B = batch;
        const L = seqLen;
        const M = B * L;
        const R = this.dtRank;
        const F32_BYTES = 4;
        const cache = {};

        // 1. Pre-block RMSNorm
        const normOut = createEmptyStorageBuffer(device, M * dModel * F32_BYTES, true);
        const normInv = createEmptyStorageBuffer(device, M * F32_BYTES, true);
        cache.normInv = normInv;
        cache.normIn = xBuf;
        {
            // Uniform layout: two u32 (M, dModel) followed by one f32 (epsilon) at byte 8.
            const params = new ArrayBuffer(16);
            new Uint32Array(params, 0, 2).set([M, dModel]);
            new Float32Array(params, 8, 1).set([1e-6]);
            const paramBuf = createUniformBuffer(device, params);
            const pipeline = this.pipelines['rmsnorm'];
            const bindGroup = createBindGroup(device, pipeline, [paramBuf, xBuf, this.gpuWeights['normWeight'], normOut, normInv]);
            dispatchKernel(device, pipeline, bindGroup, [cdiv(M, 64), 1, 1]);
            paramBuf.destroy();
        }

        // 2. Input projection → x and z
        const inProjOut = createEmptyStorageBuffer(device, M * 2 * D * F32_BYTES, true);
        cache.normOut = normOut;
        this._runLinear(normOut, 'wInProj', 'bInProj', inProjOut, M, dModel, 2 * D);

        // 3. Split into x and z
        // NOTE(review): this copies the first and second contiguous halves of
        // inProjOut. That is only a per-row split if the linear kernel stores
        // all x values before all z values; if it writes row-major (M, 2·D)
        // rows, a strided copy would be required — confirm against
        // linear_projection.js.
        const xConvIn = createEmptyStorageBuffer(device, M * D * F32_BYTES, true);
        const zBuf = createEmptyStorageBuffer(device, M * D * F32_BYTES, true);
        {
            const encoder = device.createCommandEncoder();
            encoder.copyBufferToBuffer(inProjOut, 0, xConvIn, 0, M * D * F32_BYTES);
            encoder.copyBufferToBuffer(inProjOut, M * D * F32_BYTES, zBuf, 0, M * D * F32_BYTES);
            device.queue.submit([encoder.finish()]);
        }
        inProjOut.destroy();
        cache.zBuf = zBuf;
        cache.xConvIn = xConvIn;

        // 4. Causal conv1d on x
        const convOut = createEmptyStorageBuffer(device, M * D * F32_BYTES, true);
        cache.convOut = convOut;
        {
            const paramBuf = this._u32Uniform([L, D, dConv, B]);
            const pipeline = this.pipelines['conv1d'];
            const bindGroup = createBindGroup(device, pipeline, [paramBuf, xConvIn, this.gpuWeights['wConv'], this.gpuWeights['bConv'], convOut]);
            dispatchKernel(device, pipeline, bindGroup, [cdiv(L, 16), cdiv(D, 16), B]);
            paramBuf.destroy();
        }

        // 5. SiLU activation
        const siluOut = createEmptyStorageBuffer(device, M * D * F32_BYTES, true);
        cache.siluOut = siluOut;
        this._runSilu(convOut, siluOut, M * D);

        // 6. x_proj → Δ (dtRaw), B, C
        const xProjOut = createEmptyStorageBuffer(device, M * (R + 2 * N) * F32_BYTES, true);
        this._runLinear(siluOut, 'wXProj', 'bXProj', xProjOut, M, D, R + 2 * N);
        const dtRaw = createEmptyStorageBuffer(device, M * R * F32_BYTES, true);
        const B_raw = createEmptyStorageBuffer(device, B * L * N * F32_BYTES, true);
        const C_raw = createEmptyStorageBuffer(device, B * L * N * F32_BYTES, true);
        {
            // NOTE(review): same contiguous-region assumption as step 3 — the
            // Δ, B and C parts are taken as three consecutive blocks of xProjOut.
            const encoder = device.createCommandEncoder();
            encoder.copyBufferToBuffer(xProjOut, 0, dtRaw, 0, M * R * F32_BYTES);
            encoder.copyBufferToBuffer(xProjOut, M * R * F32_BYTES, B_raw, 0, B * L * N * F32_BYTES);
            encoder.copyBufferToBuffer(xProjOut, M * (R + N) * F32_BYTES, C_raw, 0, B * L * N * F32_BYTES);
            device.queue.submit([encoder.finish()]);
        }
        xProjOut.destroy();
        cache.B_raw = B_raw;
        cache.C_raw = C_raw;

        // 7. dt_proj: expand Δ to full dim
        const deltaFull = createEmptyStorageBuffer(device, M * D * F32_BYTES, true);
        cache.deltaFull = deltaFull;
        this._runLinear(dtRaw, 'wDtProj', 'bDtProj', deltaFull, M, R, D);
        dtRaw.destroy();

        // 8. Selective scan (S6): two passes over the same set of bindings
        const scanY = createEmptyStorageBuffer(device, B * L * D * F32_BYTES, true);
        const hCache = createEmptyStorageBuffer(device, 2 * B * L * D * N * F32_BYTES, true);
        cache.hCache = hCache;
        {
            const paramBuf = this._u32Uniform([L, N, D, B]);
            const scanBindings = [
                paramBuf, siluOut, deltaFull, this.gpuWeights['A_log'], B_raw, C_raw,
                this.gpuWeights['D_vec'], scanY, hCache,
            ];
            const scanPipeline = this.pipelines['scan_fwd'];
            const scanGroup = createBindGroup(device, scanPipeline, scanBindings);
            dispatchKernel(device, scanPipeline, scanGroup, [cdiv(D, 8), cdiv(N, 8), B]);
            const reducePipeline = this.pipelines['scan_reduce'];
            const reduceGroup = createBindGroup(device, reducePipeline, scanBindings);
            dispatchKernel(device, reducePipeline, reduceGroup, [cdiv(L, 64), D, B]);
            // Shared by both passes, so it is released only after the second one.
            paramBuf.destroy();
        }

        // 9. Gate: y ⊗ SiLU(z)
        const siluZ = createEmptyStorageBuffer(device, M * D * F32_BYTES, true);
        const gatedOut = createEmptyStorageBuffer(device, M * D * F32_BYTES, true);
        this._runSilu(zBuf, siluZ, M * D);
        this._runElementwise('elMul', scanY, siluZ, gatedOut, M * D);
        siluZ.destroy();
        scanY.destroy();

        // 10. Output projection
        const outProjOut = createEmptyStorageBuffer(device, M * dModel * F32_BYTES, true);
        this._runLinear(gatedOut, 'wOutProj', 'bOutProj', outProjOut, M, D, dModel);
        gatedOut.destroy();

        // 11. Residual add
        const output = createEmptyStorageBuffer(device, M * dModel * F32_BYTES, true);
        this._runElementwise('elAdd', outProjOut, xBuf, output, M * dModel);
        outProjOut.destroy();

        return { output, cache };
    }

    /**
     * Lists every weight of the block.
     *
     * @returns {Array<{buf: *, numel: number, name: string}>} One entry per
     *   weight with its GPU buffer, element count and name.
     */
    parameters() {
        const { dModel, dState: N, dConv: K } = this.config;
        const D = this.dInner;
        const R = this.dtRank;
        const elementCounts = [
            ['wInProj', 2 * D * dModel],
            ['bInProj', 2 * D],
            ['wConv', D * K],
            ['bConv', D],
            ['wXProj', (R + 2 * N) * D],
            ['bXProj', R + 2 * N],
            ['wDtProj', D * R],
            ['bDtProj', D],
            ['A_log', D * N],
            ['D_vec', D],
            ['wOutProj', dModel * D],
            ['bOutProj', dModel],
            ['normWeight', dModel],
        ];
        return elementCounts.map(([name, numel]) => ({ buf: this.gpuWeights[name], numel, name }));
    }

    /**
     * Returns the parameters that should receive updates: only `wXProj` and
     * `bXProj` while WSLA mode is enabled, otherwise every parameter.
     *
     * @returns {Array<{buf: *, numel: number, name: string}>}
     */
    getTrainableParams() {
        if (!this._wslaMode) {
            return this.parameters();
        }
        return [
            { buf: this.gpuWeights['wXProj'], numel: this.wXProj.length, name: 'wXProj' },
            { buf: this.gpuWeights['bXProj'], numel: this.bXProj.length, name: 'bXProj' },
        ];
    }

    /**
     * Enables or disables WSLA mode (see getTrainableParams()).
     *
     * @param {boolean} enabled
     */
    setWSLAMode(enabled) {
        this._wslaMode = enabled;
    }

    /**
     * Destroys every GPU weight buffer and clears `gpuWeights`. Calling it a
     * second time is a no-op because the map is already empty.
     */
    destroy() {
        for (const buf of Object.values(this.gpuWeights)) {
            buf.destroy();
        }
        this.gpuWeights = {};
    }
}
331
+ // Deprecated alias — kept until mambacode.js 3.0.0
332
+ export { Mamba1Block as MambaBlock };
333
+ //# sourceMappingURL=mamba1_block.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"mamba1_block.js","sourceRoot":"","sources":["../../src/model/mamba1_block.ts"],"names":[],"mappings":"AAAA;;;;;GAKG;AAEH,OAAO,EACH,qBAAqB,EACrB,eAAe,EACf,mBAAmB,EACnB,wBAAwB,EACxB,mBAAmB,EACnB,cAAc,EACd,IAAI,GACP,MAAM,uBAAuB,CAAC;AAE/B,OAAO,EAAE,2BAA2B,EAAE,MAAO,8BAA8B,CAAC;AAC5E,OAAO,EAAE,aAAa,EAAE,MAAM,iBAAiB,CAAC;AAChD,OAAO,EAAE,mBAAmB,EAAE,MAAe,sBAAsB,CAAC;AACpE,OAAO,EAAE,mBAAmB,EAAE,MAAe,iCAAiC,CAAC;AAC/E,OAAO,EAAE,gBAAgB,EAAE,MAAkB,2BAA2B,CAAC;AAmCzE,gFAAgF;AAEhF,MAAM,UAAU,GAAG,UAAU,CAAA;;;;;;;;;;CAU5B,CAAC;AAEF,MAAM,UAAU,GAAG,UAAU,CAAA;;;;;;;;;;CAU5B,CAAC;AAEF,iFAAiF;AAEjF,MAAM,OAAO,WAAW;IACX,SAAS,GAAG,QAAiB,CAAC;IAEvC,MAAM,CAAc;IACpB,MAAM,CAAgC;IACtC,MAAM,CAAW;IACjB,MAAM,CAAW;IAEjB,OAAO,CAAkB;IACzB,OAAO,CAAkB;IACzB,KAAK,CAAoB;IACzB,KAAK,CAAoB;IACzB,MAAM,CAAmB;IACzB,MAAM,CAAmB;IACzB,OAAO,CAAkB;IACzB,OAAO,CAAkB;IACzB,KAAK,CAAoB;IACzB,KAAK,CAAoB;IACzB,QAAQ,CAAiB;IACzB,QAAQ,CAAiB;IACzB,UAAU,CAAe;IAEzB,UAAU,CAA6B;IACvC,SAAS,CAAuC;IAExC,SAAS,GAAG,KAAK,CAAC;IAE1B,YAAY,MAAiB,EAAE,MAAyB;QACpD,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QACrB,IAAI,CAAC,MAAM,GAAG;YACV,MAAM,EAAI,EAAE;YACZ,KAAK,EAAK,CAAC;YACX,MAAM,EAAI,CAAC;YACX,QAAQ,EAAE,IAAI;YACd,MAAM,EAAI,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,MAAM,GAAG,EAAE,CAAC;YACvC,GAAG,MAAM;SACmB,CAAC;QAEjC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QACvC,IAAI,CAAC,MAAM,GAAG,MAAM,GAAG,MAAM,CAAC;QAC9B,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC,MAAM,IAAI,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,EAAE,CAAC,CAAC;QAEtD,IAAI,CAAC,OAAO,GAAM,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,OAAO,GAAM,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,GAAQ,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,GAAQ,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,MAAM,GAAO,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,MAAM,GAAO,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,OAAO,GAAM,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,OAAO,GAAM,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,GAAQ,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,GAAQ,IAAI,YAAY,CAAC
,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,QAAQ,GAAK,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,QAAQ,GAAK,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,UAAU,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,UAAU,GAAG,EAAE,CAAC;QACrB,IAAI,CAAC,SAAS,GAAI,EAAE,CAAC;QAErB,IAAI,CAAC,YAAY,EAAE,CAAC;QACpB,IAAI,CAAC,eAAe,EAAE,CAAC;IAC3B,CAAC;IAEO,YAAY;QAChB,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QAC9C,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,MAAM,CAAC,GAAG,MAAM,CAAC;QACjB,MAAM,CAAC,GAAG,KAAK,CAAC;QAChB,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QAEtB,MAAM,KAAK,GAAG,CAAC,CAAS,EAAE,GAAG,GAAG,IAAI,EAAgB,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAE7E,MAAM,KAAK,GAAG,CAAC,CAAS,EAAgB,EAAE,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QAC/D,MAAM,IAAI,GAAI,CAAC,CAAS,EAAgB,EAAE,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAEzE,IAAI,CAAC,OAAO,GAAI,KAAK,CAAC,CAAC,GAAG,CAAC,GAAG,MAAM,CAAC,CAAC;QACtC,IAAI,CAAC,OAAO,GAAI,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QAC7B,IAAI,CAAC,KAAK,GAAM,KAAK,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACnC,IAAI,CAAC,KAAK,GAAM,KAAK,CAAC,CAAC,CAAC,CAAC;QACzB,IAAI,CAAC,MAAM,GAAK,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7C,IAAI,CAAC,MAAM,GAAK,KAAK,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;QACjC,IAAI,CAAC,OAAO,GAAI,KAAK,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACnC,IAAI,CAAC,OAAO,GAAI,KAAK,CAAC,CAAC,CAAC,CAAC;QAEzB,IAAI,CAAC,KAAK,GAAG,IAAI,YAAY,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QACrC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YACzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;gBACzB,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;YAC5C,CAAC;QACL,CAAC;QAED,IAAI,CAAC,KAAK,GAAO,IAAI,CAAC,CAAC,CAAC,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAI,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACzC,IAAI,CAAC,QAAQ,GAAI,KAAK,CAAC,MAAM,CAAC,CAAC;QAC/B,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC;QAE/B,IAAI,CAAC,mBAAmB,EAAE,CAAC;IAC/B,CAAC;IAEO
,mBAAmB;QACvB,MAAM,CAAC,GAAI,IAAI,CAAC,MAAM,CAAC;QACvB,MAAM,EAAE,GAAG,CAAC,GAAiB,EAAa,EAAE,CAAC,mBAAmB,CAAC,CAAC,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC;QAE/E,IAAI,CAAC,UAAU,GAAG;YACd,OAAO,EAAK,EAAE,CAAC,IAAI,CAAC,OAAO,CAAC;YAC5B,OAAO,EAAK,EAAE,CAAC,IAAI,CAAC,OAAO,CAAC;YAC5B,KAAK,EAAO,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC;YAC1B,KAAK,EAAO,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC;YAC1B,MAAM,EAAM,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC;YAC3B,MAAM,EAAM,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC;YAC3B,OAAO,EAAK,EAAE,CAAC,IAAI,CAAC,OAAO,CAAC;YAC5B,OAAO,EAAK,EAAE,CAAC,IAAI,CAAC,OAAO,CAAC;YAC5B,KAAK,EAAO,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC;YAC1B,KAAK,EAAO,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC;YAC1B,QAAQ,EAAI,EAAE,CAAC,IAAI,CAAC,QAAQ,CAAC;YAC7B,QAAQ,EAAI,EAAE,CAAC,IAAI,CAAC,QAAQ,CAAC;YAC7B,UAAU,EAAE,EAAE,CAAC,IAAI,CAAC,UAAU,CAAC;SAClC,CAAC;IACN,CAAC;IAEO,eAAe;QACnB,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,IAAI,CAAC,SAAS,GAAG;YACb,MAAM,EAAQ,qBAAqB,CAAC,CAAC,EAAE,mBAAmB,EAAW,gBAAgB,CAAC;YACtF,MAAM,EAAQ,qBAAqB,CAAC,CAAC,EAAE,mBAAmB,EAAW,gBAAgB,CAAC;YACtF,IAAI,EAAU,qBAAqB,CAAC,CAAC,EAAE,gBAAgB,EAAc,cAAc,CAAC;YACpF,OAAO,EAAO,qBAAqB,CAAC,CAAC,EAAE,gBAAgB,EAAc,iBAAiB,CAAC;YACvF,QAAQ,EAAM,qBAAqB,CAAC,CAAC,EAAE,2BAA2B,EAAG,cAAc,CAAC;YACpF,WAAW,EAAG,qBAAqB,CAAC,CAAC,EAAE,2BAA2B,EAAG,gBAAgB,CAAC;YACtF,KAAK,EAAS,qBAAqB,CAAC,CAAC,EAAE,UAAU,EAAE,MAAM,CAAC;YAC1D,KAAK,EAAS,qBAAqB,CAAC,CAAC,EAAE,UAAU,EAAE,MAAM,CAAC;SAC7D,CAAC;IACN,CAAC;IAED,OAAO,CAAC,IAAe,EAAE,KAAa,EAAE,MAAc;QAClD,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QAC9C,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,MAAM,CAAC,GAAG,MAAM,CAAC;QACjB,MAAM,CAAC,GAAG,KAAK,CAAC;QAChB,MAAM,CAAC,GAAG,MAAM,CAAC;QACjB,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;QAChB,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QAEtB,MAAM,KAAK,GAAG,EAAgB,CAAC;QAE/B,uBAAuB;QACvB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAClE,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACzD,KAAK,CAAC,OAAO,GAAG,OAAO,CAAC;QACxB,KAAK,CAAC,M
AAM,GAAI,IAAI,CAAC;QACrB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,EAAE,CAAC,CAAC;YACnC,IAAI,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC;YAC/C,IAAI,YAAY,CAAC,MAAM,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC;YAC3C,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC5C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,SAAS,CAAE,EACpD,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,CAAC,UAAU,CAAC,YAAY,CAAE,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;YACpE,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,SAAS,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;QAED,gCAAgC;QAChC,MAAM,SAAS,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACnE,KAAK,CAAC,OAAO,GAAG,OAAO,CAAC;QACxB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YAC1D,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAE,SAAS,CAAC,CAAC,CAAC;YAC1F,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC,GAAG,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACxF,CAAC;QAED,wBAAwB;QACxB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7D,MAAM,IAAI,GAAM,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7D,CAAC;YACG,MAAM,GAAG,GAAG,CAAC,CAAC,oBAAoB,EAAE,CAAC;YACrC,GAAG,CAAC,kBAAkB,CAAC,SAAS,EAAE,CAAC,EAAU,OAAO,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YACpE,GAAG,CAAC,kBAAkB,CAAC,SAAS,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,EAAK,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YACpE,CAAC,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;QACnC,CAAC;QACD,SAAS,CAAC,OAAO,EAAE,CAAC;QACpB,KAAK,CAAC,IAAI,GAAM,IAAI,CAAC;QACrB,KAAK,CAAC,OAAO,GAAG,OAAO,C
AAC;QAExB,wBAAwB;QACxB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7D,KAAK,CAAC,OAAO,GAAG,OAAO,CAAC;QACxB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YACxD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,OAAO,CAAC,CAAC,CAAC;YACpF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpF,CAAC;QAED,qBAAqB;QACrB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7D,KAAK,CAAC,OAAO,GAAG,OAAO,CAAC;QACxB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YAC/C,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EACjD,CAAC,IAAI,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;YAC9B,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC7E,CAAC;QAED,8BAA8B;QAC9B,MAAM,QAAQ,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACxE,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YACzD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAE,QAAQ,CAAC,CAAC,CAAC;YACvF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5F,CAAC;QAED,MAAM,KAAK,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CA
AC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC3D,MAAM,KAAK,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC/D,MAAM,KAAK,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC/D,CAAC;YACG,MAAM,GAAG,GAAG,CAAC,CAAC,oBAAoB,EAAE,CAAC;YACrC,GAAG,CAAC,kBAAkB,CAAC,QAAQ,EAAE,CAAC,EAAgB,KAAK,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YACvE,GAAG,CAAC,kBAAkB,CAAC,QAAQ,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAQ,KAAK,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YAC3E,GAAG,CAAC,kBAAkB,CAAC,QAAQ,EAAE,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YAC3E,CAAC,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;QACnC,CAAC;QACD,QAAQ,CAAC,OAAO,EAAE,CAAC;QACnB,KAAK,CAAC,KAAK,GAAG,KAAK,CAAC;QACpB,KAAK,CAAC,KAAK,GAAG,KAAK,CAAC;QAEpB,mCAAmC;QACnC,MAAM,SAAS,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC/D,KAAK,CAAC,SAAS,GAAG,SAAS,CAAC;QAC5B,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YACjD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,KAAK,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAE,SAAS,CAAC,CAAC,CAAC;YACxF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpF,CAAC;QACD,KAAK,CAAC,OAAO,EAAE,CAAC;QAEhB,yBAAyB;QACzB,MAAM,KAAK,GAAI,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAChE,MAAM,MAAM,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACxE,KAAK,CAAC,MAAM,GAAG,MAAM,CAAC;QACtB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YACpD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAE9C
,MAAM,GAAG,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAE,EACtD,CAAC,IAAI,EAAE,OAAO,EAAE,SAAS,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,KAAK;gBACjE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC;YAChD,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAE,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YAEjF,MAAM,GAAG,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,aAAa,CAAE,EACzD,CAAC,IAAI,EAAE,OAAO,EAAE,SAAS,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,KAAK;gBACjE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC;YAChD,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,aAAa,CAAE,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAChF,CAAC;QAED,uBAAuB;QACvB,MAAM,KAAK,GAAM,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC9D,MAAM,QAAQ,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC9D,CAAC;YACG,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;YACrE,MAAM,GAAG,GAAI,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EACnD,CAAC,IAAI,EAAE,IAAI,EAAE,KAAK,CAAC,CAAC,CAAC;YACzB,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YAE1E,MAAM,KAAK,GAAG,mBAAmB,CAAC,CAAC,EAAE,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;YACtE,MAAM,KAAK,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EACrD,CAAC,KAAK,EAAE,KAAK,EAAE,QAAQ,EAAE,KAAK,CAAC,CAAC,CAAC;YACrC,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjF,CAAC;QACD,KAAK,CAAC,OAAO,EAAE,CAAC;QAChB,KAAK,CAAC,OAAO,EAAE,CAAC;QAEhB,wBAAwB;QACxB,MAAM,UAAU,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACrE,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,
CAAC,CAAC,CAAC,MAAM,CAAC;YACtD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,QAAQ,EAAE,IAAI,CAAC,UAAU,CAAC,UAAU,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,UAAU,CAAE,EAAE,UAAU,CAAC,CAAC,CAAC;YAC9F,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACzF,CAAC;QACD,QAAQ,CAAC,OAAO,EAAE,CAAC;QAEnB,mBAAmB;QACnB,MAAM,MAAM,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACjE,CAAC;YACG,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;YAC1E,MAAM,EAAE,GAAK,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EACpD,CAAC,UAAU,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,CAAC,CAAC,CAAC;YACtC,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,MAAM,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnF,CAAC;QACD,UAAU,CAAC,OAAO,EAAE,CAAC;QAErB,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,CAAC;IAC7B,CAAC;IAED,UAAU;QACN,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QAC9C,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,MAAM,CAAC,GAAG,MAAM,CAAC;QACjB,MAAM,CAAC,GAAG,KAAK,CAAC;QAChB,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QAEtB,OAAO;YACH,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAK,KAAK,EAAE,CAAC,GAAG,CAAC,GAAG,MAAM,EAAI,IAAI,EAAE,SAAS,EAAK;YACpF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAK,KAAK,EAAE,CAAC,GAAG,CAAC,EAAa,IAAI,EAAE,SAAS,EAAK;YACpF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAO,KAAK,EAAE,CAAC,GAAG,CAAC,EAAa,IAAI,EAAE,OAAO,EAAO;YACpF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAO,KAAK,EAAE,CAAC,EAAiB,IAAI,EAAE,OAAO,EAAO;YACpF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAM,KAAK,EAAE,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,EAAE,QAAQ,EAAK;YAClF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAM,KAAK,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAQ,IAAI,EAAE,QAAQ,EAAK;YAClF,EAAE,G
AAG,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAK,KAAK,EAAE,CAAC,GAAG,CAAC,EAAa,IAAI,EAAE,SAAS,EAAI;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAK,KAAK,EAAE,CAAC,EAAiB,IAAI,EAAE,SAAS,EAAI;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAO,KAAK,EAAE,CAAC,GAAG,CAAC,EAAa,IAAI,EAAE,OAAO,EAAM;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAO,KAAK,EAAE,CAAC,EAAiB,IAAI,EAAE,OAAO,EAAM;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,UAAU,CAAE,EAAI,KAAK,EAAE,MAAM,GAAG,CAAC,EAAQ,IAAI,EAAE,UAAU,EAAG;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,UAAU,CAAE,EAAI,KAAK,EAAE,MAAM,EAAY,IAAI,EAAE,UAAU,EAAG;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,YAAY,CAAE,EAAE,KAAK,EAAE,MAAM,EAAY,IAAI,EAAE,YAAY,EAAC;SACtF,CAAC;IACN,CAAC;IAED,kBAAkB;QACd,IAAI,IAAI,CAAC,SAAS,EAAE,CAAC;YACjB,OAAO;gBACH,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAE,KAAK,EAAE,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE;gBAC9E,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAE,KAAK,EAAE,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE;aACjF,CAAC;QACN,CAAC;QACD,OAAO,IAAI,CAAC,UAAU,EAAE,CAAC;IAC7B,CAAC;IAED,WAAW,CAAC,OAAgB;QACxB,IAAI,CAAC,SAAS,GAAG,OAAO,CAAC;IAC7B,CAAC;IAED,OAAO;QACH,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE,CAAC;YAC/C,GAAG,CAAC,OAAO,EAAE,CAAC;QAClB,CAAC;QACD,IAAI,CAAC,UAAU,GAAG,EAAE,CAAC;IACzB,CAAC;CACJ;AAED,mDAAmD;AACnD,OAAO,EAAE,WAAW,IAAI,UAAU,EAAE,CAAC"}
@@ -0,0 +1,44 @@
1
+ /**
2
+ * mamba2_block.ts – Mamba-2 Mixer Block (Structured State Space Duality).
3
+ *
4
+ * Key differences from Mamba-1:
5
+ * - Multi-head SSM with scalar A per head
6
+ * - Single fused in_proj (no separate dt_proj expansion)
7
+ * - SSD (chunked) scan replaces S6 selective scan
8
+ * - Inner RMSNorm on scan output instead of SiLU gate
9
+ * - No separate z gate
10
+ *
11
+ * Implements SequenceLayer.
12
+ */
13
+ import type { SequenceLayer, LayerForwardResult, LayerParam } from './sequence_layer.js';
14
+ /** Construction options for Mamba2Block. The constructor merges these over its defaults (dState 16, dConv 4, expand 2, nGroups 1, chunkLen 256). */ export interface Mamba2BlockConfig {
15
+ dModel: number; // model width; the block's inner width is expand * dModel
16
+ dState: number; // SSM state size (N in the implementation); constructor default 16
17
+ dConv: number; // conv1d taps per channel (K in the implementation); constructor default 4
18
+ expand: number; // inner-width multiplier: dInner = expand * dModel; constructor default 2
19
+ nHeads: number; // number of SSM heads; expand * dModel must be divisible by it or the constructor throws
20
+ nGroups: number; // number of B/C projection groups (G in the implementation); constructor default 1
21
+ chunkLen: number; // sequence positions per SSD chunk: numChunks = ceil(seqLen / chunkLen); constructor default 256
22
+ }
23
+ /** Cache object returned by Mamba2Block.forward() alongside the output buffer. */ export interface Mamba2Cache {
24
+ stateCarry: GPUBuffer; // SSD state-carry buffer; forward() allocates (numChunks + 1) * batch * nHeads * dState * dHead f32 values and does not destroy it
25
+ }
26
+ /** Mamba-2 mixer block; forward() runs RMSNorm, in_proj, conv1d, SSD scan, inner RMSNorm, out_proj and a residual add on the GPU. */ export declare class Mamba2Block implements SequenceLayer {
27
+ readonly layerType: "mamba2"; // layer-type tag, always the literal 'mamba2'
28
+ device: GPUDevice; // device used for every buffer, pipeline and dispatch
29
+ config: Required<Mamba2BlockConfig>; // constructor defaults merged with the caller's config
30
+ dInner: number; // expand * dModel
31
+ dHead: number; // dInner / nHeads
32
+ gpuWeights: Record<string, GPUBuffer>; // weight buffers keyed by name (wInProj, wConv, bConv, A_log, dt_bias, D_vec, wOutProj, normWeight, preNormWeight)
33
+ pipelines: Record<string, GPUComputePipeline>; // compute pipelines keyed by name (linear, conv1d, rmsnorm, ssd_fwd, elAdd)
34
+ private _wslaMode; // when true, getTrainableParams() returns only the 'wInProj_BC' entry
35
+ constructor(device: GPUDevice, config: Mamba2BlockConfig); // throws Error when expand * dModel is not divisible by nHeads
36
+ private _initWeights; // allocates and initialises every entry of gpuWeights
37
+ private _buildPipelines; // compiles every entry of pipelines
38
+ forward(xBuf: GPUBuffer, batch: number, seqLen: number): LayerForwardResult; // returns { output, cache: { stateCarry } }; output holds batch * seqLen * dModel f32 values
39
+ parameters(): LayerParam[]; // one { buf, numel, name } entry per weight buffer
40
+ getTrainableParams(): LayerParam[]; // parameters(), or the single 'wInProj_BC' entry in WSLA mode
41
+ setWSLAMode(enabled: boolean): void; // toggles the WSLA mode used by getTrainableParams()
42
+ destroy(): void; // destroys every weight buffer and clears gpuWeights
43
+ }
44
+ //# sourceMappingURL=mamba2_block.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"mamba2_block.d.ts","sourceRoot":"","sources":["../../src/model/mamba2_block.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;GAWG;AAkBH,OAAO,KAAK,EAAE,aAAa,EAAE,kBAAkB,EAAE,UAAU,EAAE,MAAM,qBAAqB,CAAC;AAEzF,MAAM,WAAW,iBAAiB;IAC9B,MAAM,EAAK,MAAM,CAAC;IAClB,MAAM,EAAK,MAAM,CAAC;IAClB,KAAK,EAAM,MAAM,CAAC;IAClB,MAAM,EAAK,MAAM,CAAC;IAClB,MAAM,EAAK,MAAM,CAAC;IAClB,OAAO,EAAI,MAAM,CAAC;IAClB,QAAQ,EAAG,MAAM,CAAC;CACrB;AAED,MAAM,WAAW,WAAW;IACxB,UAAU,EAAG,SAAS,CAAC;CAC1B;AAcD,qBAAa,WAAY,YAAW,aAAa;IAC7C,QAAQ,CAAC,SAAS,EAAG,QAAQ,CAAU;IAEvC,MAAM,EAAG,SAAS,CAAC;IACnB,MAAM,EAAG,QAAQ,CAAC,iBAAiB,CAAC,CAAC;IACrC,MAAM,EAAG,MAAM,CAAC;IAChB,KAAK,EAAI,MAAM,CAAC;IAEhB,UAAU,EAAG,MAAM,CAAC,MAAM,EAAE,SAAS,CAAC,CAAC;IACvC,SAAS,EAAI,MAAM,CAAC,MAAM,EAAE,kBAAkB,CAAC,CAAC;IAEhD,OAAO,CAAC,SAAS,CAAS;gBAEd,MAAM,EAAE,SAAS,EAAE,MAAM,EAAE,iBAAiB;IAwBxD,OAAO,CAAC,YAAY;IA8BpB,OAAO,CAAC,eAAe;IAWvB,OAAO,CAAC,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,kBAAkB;IA6I3E,UAAU,IAAI,UAAU,EAAE;IAsB1B,kBAAkB,IAAI,UAAU,EAAE;IAYlC,WAAW,CAAC,OAAO,EAAE,OAAO,GAAG,IAAI;IAInC,OAAO,IAAI,IAAI;CAIlB"}
@@ -0,0 +1,252 @@
1
+ /**
2
+ * mamba2_block.ts – Mamba-2 Mixer Block (Structured State Space Duality).
3
+ *
4
+ * Key differences from Mamba-1:
5
+ * - Multi-head SSM with scalar A per head
6
+ * - Single fused in_proj (no separate dt_proj expansion)
7
+ * - SSD (chunked) scan replaces S6 selective scan
8
+ * - Inner RMSNorm on scan output instead of SiLU gate
9
+ * - No separate z gate
10
+ *
11
+ * Implements SequenceLayer.
12
+ */
13
+ import { createComputePipeline, createBindGroup, createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer, dispatchKernel, cdiv, } from '../utils/gpu_utils.js';
14
+ import { SSD_FORWARD_WGSL } from '../kernels/ssd.js';
15
+ import { gaussianArray } from '../utils/rng.js';
16
+ import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
17
+ import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
18
+ import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
19
+ /* Element-wise add kernel used for the residual connection: c[i] = a[i] + b[i] for every i < n, one invocation per element in workgroups of 256. Note the uniform length n is binding 3 (last), unlike the other kernels' bind groups built in this file, which put the parameter buffer first. */ const ADD_SHADER = /* wgsl */ `
20
+ @group(0) @binding(0) var<storage, read> a : array<f32>;
21
+ @group(0) @binding(1) var<storage, read> b : array<f32>;
22
+ @group(0) @binding(2) var<storage, read_write> c : array<f32>;
23
+ @group(0) @binding(3) var<uniform> n : u32;
24
+ @compute @workgroup_size(256)
25
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
26
+ let i = gid.x;
27
+ if (i < n) { c[i] = a[i] + b[i]; }
28
+ }
29
+ `;
30
/**
 * Mamba-2 mixer block (Structured State Space Duality) executed with WebGPU
 * compute kernels.
 *
 * forward() pipeline: pre-block RMSNorm -> fused in_proj -> causal conv1d ->
 * chunked SSD scan -> inner RMSNorm -> out_proj -> residual add.
 *
 * Config defaults applied by the constructor: dState 16, dConv 4, expand 2,
 * nGroups 1, chunkLen 256. dModel and nHeads have no default.
 */
export class Mamba2Block {
    layerType = 'mamba2';
    device;
    config;
    dInner;
    dHead;
    gpuWeights;
    pipelines;
    _wslaMode = false;

    /**
     * @param {GPUDevice} device Device used for all buffers, pipelines and dispatches.
     * @param {object} config Block options; see the class comment for defaults.
     * @throws {Error} When expand * dModel is not divisible by nHeads.
     */
    constructor(device, config) {
        this.device = device;
        // An explicitly-undefined option must not overwrite its default,
        // otherwise every derived buffer size becomes NaN.
        const supplied = Object.fromEntries(
            Object.entries(config ?? {}).filter(([, value]) => value !== undefined),
        );
        this.config = {
            dState: 16,
            dConv: 4,
            expand: 2,
            nGroups: 1,
            chunkLen: 256,
            ...supplied,
        };
        const { dModel, expand, nHeads } = this.config;
        this.dInner = expand * dModel;
        if (this.dInner % nHeads !== 0) {
            throw new Error(`Mamba2Block: dInner (${this.dInner}) must be divisible by nHeads (${nHeads}).`);
        }
        this.dHead = this.dInner / nHeads;
        this.gpuWeights = {};
        this.pipelines = {};
        this._initWeights();
        this._buildPipelines();
    }

    /** Allocates and initialises every weight buffer in `gpuWeights`. */
    _initWeights() {
        const { dModel, dState, dConv, nHeads, nGroups } = this.config;
        const D = this.dInner;
        const N = dState;
        const K = dConv;
        const H = nHeads;
        const G = nGroups;
        const randn = (n, std = 0.02) => gaussianArray(n, std);
        const zeros = (n) => new Float32Array(n);
        const ones = (n) => new Float32Array(n).fill(1.0);
        // wInProj: (D_inner + 2*n_groups*N + H, D_model) — no bias per Mamba-2 spec
        const inProjRows = D + 2 * G * N + H;
        const convD = D + 2 * G * N;
        const mk = (arr) => createStorageBuffer(this.device, arr, true);
        this.gpuWeights = {
            wInProj: mk(randn(inProjRows * dModel)),
            wConv: mk(randn(convD * K, 0.01)),
            bConv: mk(zeros(convD)),
            A_log: mk(new Float32Array(H).fill(Math.log(1.0))), // log(1.0), i.e. 0
            dt_bias: mk(zeros(H)),
            D_vec: mk(ones(H)),
            wOutProj: mk(randn(dModel * D, 0.02)),
            normWeight: mk(ones(D)), // inner RMSNorm
            preNormWeight: mk(ones(dModel)), // pre-block RMSNorm
        };
    }

    /** Compiles the compute pipelines used by forward(). */
    _buildPipelines() {
        const d = this.device;
        this.pipelines = {
            linear: createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
            conv1d: createComputePipeline(d, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
            rmsnorm: createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
            ssd_fwd: createComputePipeline(d, SSD_FORWARD_WGSL, 'ssd_chunk_forward'),
            elAdd: createComputePipeline(d, ADD_SHADER, 'main'),
        };
    }

    /**
     * Dispatches one kernel whose bind group is [params, ...storageBuffers].
     * The uniform parameter buffer is created here and always released after
     * the dispatch, so forward() no longer leaks one per kernel launch. This
     * relies on the same assumption the surrounding code already makes when it
     * destroys bound storage buffers directly after dispatchKernel().
     *
     * @param {string} name Key into `this.pipelines`.
     * @param {ArrayBuffer} paramData Uniform contents.
     * @param {GPUBuffer[]} storageBuffers Bindings 1..n, in order.
     * @param {number[]} workgroups Workgroup counts [x, y, z].
     */
    _runKernel(name, paramData, storageBuffers, workgroups) {
        const d = this.device;
        const pipeline = this.pipelines[name];
        const pBuf = createUniformBuffer(d, paramData);
        try {
            const bg = createBindGroup(d, pipeline, [pBuf, ...storageBuffers]);
            dispatchKernel(d, pipeline, bg, workgroups);
        } finally {
            pBuf.destroy();
        }
    }

    /**
     * Runs the rmsnorm kernel over a rows x cols input with eps = 1e-6.
     *
     * @returns {GPUBuffer} Newly allocated rows * cols f32 output; the caller owns it.
     */
    _rmsNorm(input, weight, rows, cols) {
        const d = this.device;
        const out = createEmptyStorageBuffer(d, rows * cols * 4, true);
        // Second kernel output (one f32 per row); unused by the forward pass.
        const normInv = createEmptyStorageBuffer(d, rows * 4, true);
        // Uniform layout: two u32 (rows, cols) followed by one f32 (eps) at byte 8.
        const params = new ArrayBuffer(16);
        new Uint32Array(params, 0, 2).set([rows, cols]);
        new Float32Array(params, 8, 1).set([1e-6]);
        this._runKernel('rmsnorm', params, [input, weight, out, normInv], [cdiv(rows, 64), 1, 1]);
        normInv.destroy();
        return out;
    }

    /**
     * Runs the linear kernel without a learned bias: the kernel requires a
     * bias binding, so a temporary zero-filled buffer is supplied.
     *
     * @returns {GPUBuffer} Newly allocated rows * outDim f32 output; the caller owns it.
     */
    _linearNoBias(input, weight, rows, inDim, outDim) {
        const d = this.device;
        const out = createEmptyStorageBuffer(d, rows * outDim * 4, true);
        const zeroBias = createStorageBuffer(d, new Float32Array(outDim), true);
        const params = new Uint32Array([rows, inDim, outDim]).buffer;
        this._runKernel('linear', params, [input, weight, zeroBias, out], [cdiv(rows, 16), cdiv(outDim, 16), 1]);
        zeroBias.destroy();
        return out;
    }

    /**
     * Copies consecutive byte ranges of `src` into freshly allocated buffers.
     *
     * NOTE(review): the ranges are contiguous, so this equals a per-row column
     * split only if the producing kernel writes its output feature-major —
     * TODO confirm against linear_forward / conv1d_forward.
     *
     * @param {GPUBuffer} src Source buffer (not destroyed here).
     * @param {number[]} byteSizes Size of each consecutive slice, in bytes.
     * @returns {GPUBuffer[]} One buffer per slice, in order.
     */
    _splitBuffer(src, byteSizes) {
        const d = this.device;
        const parts = byteSizes.map((size) => createEmptyStorageBuffer(d, size, true));
        const enc = d.createCommandEncoder();
        let offset = 0;
        parts.forEach((dst, i) => {
            enc.copyBufferToBuffer(src, offset, dst, 0, byteSizes[i]);
            offset += byteSizes[i];
        });
        d.queue.submit([enc.finish()]);
        return parts;
    }

    /**
     * Runs the block on `xBuf` and returns a new output buffer.
     *
     * @param {GPUBuffer} xBuf Input with batch * seqLen * dModel f32 values; not destroyed.
     * @param {number} batch Batch size.
     * @param {number} seqLen Sequence length.
     * @returns {{output: GPUBuffer, cache: {stateCarry: GPUBuffer}}} The caller
     *   owns both returned buffers; every other temporary is destroyed here.
     */
    forward(xBuf, batch, seqLen) {
        const { dModel, dState, dConv, nHeads, nGroups, chunkLen } = this.config;
        const D = this.dInner;
        const N = dState;
        const K = dConv;
        const H = nHeads;
        const G = nGroups;
        const dh = this.dHead;
        const B = batch;
        const L = seqLen;
        const M = B * L;
        const convD = D + 2 * G * N; // channels for conv (x, B_proj, C_proj)
        const numChunks = Math.ceil(L / chunkLen);
        const d = this.device;

        // 1. Pre-block RMSNorm
        const normOut = this._rmsNorm(xBuf, this.gpuWeights['preNormWeight'], M, dModel);

        // 2. Fused in_proj → [x (D), B_proj (G*N), C_proj (G*N), dt (H)]
        const inProjRows = convD + H;
        const inProjOut = this._linearNoBias(normOut, this.gpuWeights['wInProj'], M, dModel, inProjRows);
        normOut.destroy();

        // Split: xConv [D+2GN], dt [H]
        const [xConvBuf, dtBuf] = this._splitBuffer(inProjOut, [M * convD * 4, M * H * 4]);
        inProjOut.destroy();

        // 3. Causal conv1d over x + B_proj + C_proj (fused, convD channels)
        const convOut = createEmptyStorageBuffer(d, M * convD * 4, true);
        // Trailing 1 is a flag interpreted by conv1d_forward — meaning not visible here.
        this._runKernel('conv1d', new Uint32Array([L, convD, K, B, 1]).buffer, [xConvBuf, this.gpuWeights['wConv'], this.gpuWeights['bConv'], convOut], [cdiv(L, 16), cdiv(convD, 16), B]);
        xConvBuf.destroy();

        // Split conv output: x [D], B_proj [G*N], C_proj [G*N]
        const [xSsdBuf, bProjBuf, cProjBuf] = this._splitBuffer(convOut, [M * D * 4, M * G * N * 4, M * G * N * 4]);
        convOut.destroy();

        // 4. SSD scan
        // state_carry: [numChunks+1, B, H, N, dHead]
        const stateCarry = createEmptyStorageBuffer(d, (numChunks + 1) * B * H * N * dh * 4, true);
        const ssdOut = createEmptyStorageBuffer(d, M * D * 4, true);
        this._runKernel('ssd_fwd', new Uint32Array([L, D, H, dh, G, N, chunkLen, numChunks, B]).buffer, [
            xSsdBuf, bProjBuf, cProjBuf, dtBuf,
            this.gpuWeights['A_log'], this.gpuWeights['dt_bias'],
            this.gpuWeights['D_vec'], ssdOut, stateCarry,
        ], [numChunks, H, B]);
        xSsdBuf.destroy();
        bProjBuf.destroy();
        cProjBuf.destroy();
        dtBuf.destroy();

        // 5. Inner RMSNorm on scan output
        const innerNormOut = this._rmsNorm(ssdOut, this.gpuWeights['normWeight'], M, D);
        ssdOut.destroy();

        // 6. Output projection
        const outProjOut = this._linearNoBias(innerNormOut, this.gpuWeights['wOutProj'], M, D, dModel);
        innerNormOut.destroy();

        // 7. Residual add (ADD_SHADER binds its uniform last, so it is handled inline)
        const output = createEmptyStorageBuffer(d, M * dModel * 4, true);
        const nBuf = createUniformBuffer(d, new Uint32Array([M * dModel]).buffer);
        try {
            const bg = createBindGroup(d, this.pipelines['elAdd'], [outProjOut, xBuf, output, nBuf]);
            dispatchKernel(d, this.pipelines['elAdd'], bg, [cdiv(M * dModel, 256), 1, 1]);
        } finally {
            nBuf.destroy();
        }
        outProjOut.destroy();

        const cache = { stateCarry };
        return { output, cache };
    }

    /** @returns {Array<{buf: GPUBuffer, numel: number, name: string}>} Every weight buffer with its element count. */
    parameters() {
        const { dModel, dState, dConv, nHeads, nGroups } = this.config;
        const D = this.dInner;
        const N = dState;
        const K = dConv;
        const H = nHeads;
        const G = nGroups;
        const convD = D + 2 * G * N;
        return [
            { buf: this.gpuWeights['wInProj'], numel: (convD + H) * dModel, name: 'wInProj' },
            { buf: this.gpuWeights['wConv'], numel: convD * K, name: 'wConv' },
            { buf: this.gpuWeights['bConv'], numel: convD, name: 'bConv' },
            { buf: this.gpuWeights['A_log'], numel: H, name: 'A_log' },
            { buf: this.gpuWeights['dt_bias'], numel: H, name: 'dt_bias' },
            { buf: this.gpuWeights['D_vec'], numel: H, name: 'D_vec' },
            { buf: this.gpuWeights['wOutProj'], numel: dModel * D, name: 'wOutProj' },
            { buf: this.gpuWeights['normWeight'], numel: D, name: 'normWeight' },
            { buf: this.gpuWeights['preNormWeight'], numel: dModel, name: 'preNormWeight' },
        ];
    }

    /**
     * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>} All
     *   parameters, or only the 'wInProj_BC' entry while WSLA mode is enabled.
     */
    getTrainableParams() {
        if (this._wslaMode) {
            const { nGroups, dState, dModel } = this.config;
            // WSLA: train only B/C rows of wInProj (the selective projection part)
            // NOTE(review): `buf` is the whole wInProj buffer while `numel` counts
            // only the B/C rows; per the in_proj layout the x rows come first, so
            // a consumer reading `numel` elements from the start of `buf` would
            // cover x rows instead — confirm how the optimizer locates these rows.
            return [
                {
                    buf: this.gpuWeights['wInProj'],
                    numel: (nGroups * dState * 2) * dModel,
                    name: 'wInProj_BC',
                },
            ];
        }
        return this.parameters();
    }

    /** @param {boolean} enabled Whether getTrainableParams() should return only the B/C projection entry. */
    setWSLAMode(enabled) {
        this._wslaMode = enabled;
    }

    /** Destroys every weight buffer and clears `gpuWeights`. */
    destroy() {
        for (const buf of Object.values(this.gpuWeights)) {
            buf.destroy();
        }
        this.gpuWeights = {};
    }
}
252
+ //# sourceMappingURL=mamba2_block.js.map