@seanhogg/builderforce-memory-engine 2026.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. package/LICENSE +21 -0
  2. package/README.md +393 -0
  3. package/dist/index.d.ts +32 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +40 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/kernels/activations.d.ts +5 -0
  8. package/dist/kernels/activations.d.ts.map +1 -0
  9. package/dist/kernels/activations.js +171 -0
  10. package/dist/kernels/activations.js.map +1 -0
  11. package/dist/kernels/attention.d.ts +19 -0
  12. package/dist/kernels/attention.d.ts.map +1 -0
  13. package/dist/kernels/attention.js +263 -0
  14. package/dist/kernels/attention.js.map +1 -0
  15. package/dist/kernels/complex_ssd.d.ts +33 -0
  16. package/dist/kernels/complex_ssd.d.ts.map +1 -0
  17. package/dist/kernels/complex_ssd.js +305 -0
  18. package/dist/kernels/complex_ssd.js.map +1 -0
  19. package/dist/kernels/conv1d.d.ts +3 -0
  20. package/dist/kernels/conv1d.d.ts.map +1 -0
  21. package/dist/kernels/conv1d.js +158 -0
  22. package/dist/kernels/conv1d.js.map +1 -0
  23. package/dist/kernels/linear_projection.d.ts +3 -0
  24. package/dist/kernels/linear_projection.d.ts.map +1 -0
  25. package/dist/kernels/linear_projection.js +219 -0
  26. package/dist/kernels/linear_projection.js.map +1 -0
  27. package/dist/kernels/selective_scan.d.ts +3 -0
  28. package/dist/kernels/selective_scan.d.ts.map +1 -0
  29. package/dist/kernels/selective_scan.js +348 -0
  30. package/dist/kernels/selective_scan.js.map +1 -0
  31. package/dist/kernels/ssd.d.ts +29 -0
  32. package/dist/kernels/ssd.d.ts.map +1 -0
  33. package/dist/kernels/ssd.js +276 -0
  34. package/dist/kernels/ssd.js.map +1 -0
  35. package/dist/kernels/weight_update.d.ts +3 -0
  36. package/dist/kernels/weight_update.d.ts.map +1 -0
  37. package/dist/kernels/weight_update.js +119 -0
  38. package/dist/kernels/weight_update.js.map +1 -0
  39. package/dist/model/attention_block.d.ts +48 -0
  40. package/dist/model/attention_block.d.ts.map +1 -0
  41. package/dist/model/attention_block.js +262 -0
  42. package/dist/model/attention_block.js.map +1 -0
  43. package/dist/model/mamba1_block.d.ts +70 -0
  44. package/dist/model/mamba1_block.d.ts.map +1 -0
  45. package/dist/model/mamba1_block.js +333 -0
  46. package/dist/model/mamba1_block.js.map +1 -0
  47. package/dist/model/mamba2_block.d.ts +44 -0
  48. package/dist/model/mamba2_block.d.ts.map +1 -0
  49. package/dist/model/mamba2_block.js +252 -0
  50. package/dist/model/mamba2_block.js.map +1 -0
  51. package/dist/model/mamba3_block.d.ts +51 -0
  52. package/dist/model/mamba3_block.d.ts.map +1 -0
  53. package/dist/model/mamba3_block.js +270 -0
  54. package/dist/model/mamba3_block.js.map +1 -0
  55. package/dist/model/mamba_block.d.ts +64 -0
  56. package/dist/model/mamba_block.d.ts.map +1 -0
  57. package/dist/model/mamba_block.js +303 -0
  58. package/dist/model/mamba_block.js.map +1 -0
  59. package/dist/model/mamba_model.d.ts +140 -0
  60. package/dist/model/mamba_model.d.ts.map +1 -0
  61. package/dist/model/mamba_model.js +527 -0
  62. package/dist/model/mamba_model.js.map +1 -0
  63. package/dist/model/sequence_layer.d.ts +25 -0
  64. package/dist/model/sequence_layer.d.ts.map +1 -0
  65. package/dist/model/sequence_layer.js +8 -0
  66. package/dist/model/sequence_layer.js.map +1 -0
  67. package/dist/tokenizer/bpe.d.ts +29 -0
  68. package/dist/tokenizer/bpe.d.ts.map +1 -0
  69. package/dist/tokenizer/bpe.js +164 -0
  70. package/dist/tokenizer/bpe.js.map +1 -0
  71. package/dist/training/autograd.d.ts +27 -0
  72. package/dist/training/autograd.d.ts.map +1 -0
  73. package/dist/training/autograd.js +120 -0
  74. package/dist/training/autograd.js.map +1 -0
  75. package/dist/training/trainer.d.ts +36 -0
  76. package/dist/training/trainer.d.ts.map +1 -0
  77. package/dist/training/trainer.js +183 -0
  78. package/dist/training/trainer.js.map +1 -0
  79. package/dist/utils/gpu_utils.d.ts +21 -0
  80. package/dist/utils/gpu_utils.d.ts.map +1 -0
  81. package/dist/utils/gpu_utils.js +111 -0
  82. package/dist/utils/gpu_utils.js.map +1 -0
  83. package/dist/utils/quantization.d.ts +26 -0
  84. package/dist/utils/quantization.d.ts.map +1 -0
  85. package/dist/utils/quantization.js +116 -0
  86. package/dist/utils/quantization.js.map +1 -0
  87. package/dist/utils/rng.d.ts +36 -0
  88. package/dist/utils/rng.d.ts.map +1 -0
  89. package/dist/utils/rng.js +61 -0
  90. package/dist/utils/rng.js.map +1 -0
  91. package/package.json +99 -0
  92. package/src/index.ts +114 -0
  93. package/src/kernels/activations.ts +174 -0
  94. package/src/kernels/attention.ts +268 -0
  95. package/src/kernels/complex_ssd.ts +307 -0
  96. package/src/kernels/conv1d.ts +159 -0
  97. package/src/kernels/linear_projection.ts +220 -0
  98. package/src/kernels/selective_scan.ts +350 -0
  99. package/src/kernels/ssd.ts +278 -0
  100. package/src/kernels/weight_update.ts +120 -0
  101. package/src/model/attention_block.ts +344 -0
  102. package/src/model/mamba1_block.ts +437 -0
  103. package/src/model/mamba2_block.ts +319 -0
  104. package/src/model/mamba3_block.ts +335 -0
  105. package/src/model/mamba_block.ts +401 -0
  106. package/src/model/mamba_model.ts +678 -0
  107. package/src/model/sequence_layer.ts +29 -0
  108. package/src/tokenizer/bpe.ts +186 -0
  109. package/src/training/autograd.ts +135 -0
  110. package/src/training/trainer.ts +309 -0
  111. package/src/utils/gpu_utils.ts +147 -0
  112. package/src/utils/quantization.ts +154 -0
  113. package/src/utils/rng.ts +65 -0
@@ -0,0 +1,401 @@
1
+ /**
2
+ * mamba_block.ts – Mamba Mixer Block
3
+ */
4
+
5
+ import {
6
+ createComputePipeline,
7
+ createBindGroup,
8
+ createStorageBuffer,
9
+ createEmptyStorageBuffer,
10
+ createUniformBuffer,
11
+ dispatchKernel,
12
+ cdiv,
13
+ } from '../utils/gpu_utils.js';
14
+
15
+ import { SELECTIVE_SCAN_FORWARD_WGSL } from '../kernels/selective_scan.js';
16
+ import { gaussianArray } from '../utils/rng.js';
17
+ import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
18
+ import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
19
+ import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
20
+
21
/**
 * Construction options for a MambaBlock.
 *
 * Only `dModel` is mandatory; every optional field receives a default in the
 * MambaBlock constructor.
 */
export interface MambaBlockConfig {
  /** Width of the residual stream (feature size of the block's input and output). */
  dModel: number;

  /** SSM state size per inner channel. Defaults to 16. */
  dState?: number;

  /** Width of the convolution kernel along the sequence axis. Defaults to 4. */
  dConv?: number;

  /** Inner-width multiplier: dInner = expand * dModel. Defaults to 2. */
  expand?: number;

  /** Rank of the dt (delta) projection. Defaults to ceil(dModel / 16). */
  dtRank?: number;

  /**
   * Convolution-bias flag. Defaults to true.
   * NOTE(review): MambaBlock does not currently read this flag; the conv bias
   * buffer is always allocated.
   */
  biasConv?: boolean;
}
29
+
30
/** One trainable tensor exposed by MambaBlock's parameter listings. */
export interface BlockParam {
  /** GPU storage buffer holding the tensor's values. */
  buf: GPUBuffer;

  /** Number of scalar elements in the tensor. */
  numel: number;

  /** Parameter identifier (matches the key used in MambaBlock.gpuWeights). */
  name: string;
}
35
+
36
/**
 * Intermediate GPU buffers retained from a MambaBlock forward pass.
 *
 * Element counts below are for M = batch * seqLen rows, D = dInner and
 * N = dState, as allocated by MambaBlock.forward.
 */
export interface BlockCache {
  /** RMSNorm side output, M floats (presumably per-row inverse RMS — TODO confirm). */
  normInv: GPUBuffer;
  /** The block's input buffer itself (the `xBuf` passed to forward), M * dModel floats. */
  normIn: GPUBuffer;
  /** RMSNorm output, M * dModel floats. */
  normOut: GPUBuffer;

  /** Gate branch of the input projection, M * D floats. */
  zBuf: GPUBuffer;
  /** Convolution branch of the input projection, M * D floats. */
  xConvIn: GPUBuffer;
  /** Conv1d output, M * D floats. */
  convOut: GPUBuffer;
  /** SiLU(convOut), M * D floats; feeds the x-projection and the scan. */
  siluOut: GPUBuffer;

  /** dt-projection output consumed by the scan, M * D floats. */
  deltaFull: GPUBuffer;
  /** B slice of the x-projection output, M * N floats. */
  B_raw: GPUBuffer;
  /** C slice of the x-projection output, M * N floats. */
  C_raw: GPUBuffer;
  /** Scan-kernel state buffer, 2 * M * D * N floats (layout defined by the scan kernel). */
  hCache: GPUBuffer;
}
49
+
50
/** Value returned by MambaBlock.forward. */
export interface BlockForwardResult {
  /** Block output including the residual connection, M * dModel floats. */
  output: GPUBuffer;

  /** Intermediates kept from the forward pass. */
  cache: BlockCache;
}
54
+
55
/**
 * Mamba mixer block evaluated with WebGPU compute kernels.
 *
 * With M = batch * seqLen rows, D = dInner, N = dState and R = dtRank,
 * `forward` computes:
 *
 *   xn    = RMSNorm(x)                                  M x dModel
 *   [u|z] = linear(xn;  wInProj, bInProj)                M x 2D
 *   s     = SiLU(conv1d(u; wConv, bConv))                M x D
 *   [dt|Bm|Cm] = linear(s; wXProj, bXProj)               M x (R + 2N)
 *   delta = linear(dt; wDtProj, bDtProj)                 M x D
 *   y     = selectiveScan(s, delta, A_log, Bm, Cm, D_vec)
 *   out   = linear(y * SiLU(z); wOutProj, bOutProj) + x  M x dModel
 *
 * NOTE(review): the branch splits assume LINEAR_FORWARD_WGSL writes its
 * output row-major (out[row * outDim + col]), consistent with how its input
 * is consumed as [rows, inDim] — confirm against the kernel.
 */
export class MambaBlock {
  /** Elementwise product: c[i] = a[i] * b[i] for i < n. */
  private static readonly MUL_WGSL = /* wgsl */`
    @group(0) @binding(0) var<storage, read> a : array<f32>;
    @group(0) @binding(1) var<storage, read> b : array<f32>;
    @group(0) @binding(2) var<storage, read_write> c : array<f32>;
    @group(0) @binding(3) var<uniform> n : u32;
    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let i = gid.x;
      if (i < n) { c[i] = a[i] * b[i]; }
    }
  `;

  /** Elementwise sum: c[i] = a[i] + b[i] for i < n. */
  private static readonly ADD_WGSL = /* wgsl */`
    @group(0) @binding(0) var<storage, read> a : array<f32>;
    @group(0) @binding(1) var<storage, read> b : array<f32>;
    @group(0) @binding(2) var<storage, read_write> c : array<f32>;
    @group(0) @binding(3) var<uniform> n : u32;
    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let i = gid.x;
      if (i < n) { c[i] = a[i] + b[i]; }
    }
  `;

  /**
   * Column gather: copies columns [colOffset, colOffset + dstCols) of a
   * row-major [rows, srcCols] matrix into a dense row-major [rows, dstCols]
   * matrix.
   */
  private static readonly SLICE_COLS_WGSL = /* wgsl */`
    struct SliceParams {
      rows      : u32,
      srcCols   : u32,
      colOffset : u32,
      dstCols   : u32,
    }
    @group(0) @binding(0) var<uniform> p : SliceParams;
    @group(0) @binding(1) var<storage, read> src : array<f32>;
    @group(0) @binding(2) var<storage, read_write> dst : array<f32>;
    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      let i = gid.x;
      if (i >= p.rows * p.dstCols) { return; }
      let r = i / p.dstCols;
      let c = i % p.dstCols;
      dst[i] = src[r * p.srcCols + p.colOffset + c];
    }
  `;

  device: GPUDevice;
  config: Required<MambaBlockConfig>;
  dInner: number;
  dtRank: number;
  wInProj: Float32Array;
  bInProj: Float32Array;
  wConv: Float32Array;
  bConv: Float32Array;
  wXProj: Float32Array;
  bXProj: Float32Array;
  wDtProj: Float32Array;
  bDtProj: Float32Array;
  A_log: Float32Array;
  D_vec: Float32Array;
  wOutProj: Float32Array;
  bOutProj: Float32Array;
  normWeight: Float32Array;
  gpuWeights: Record<string, GPUBuffer>;
  pipelines: Record<string, GPUComputePipeline>;
  private _wslaMode = false;

  /**
   * @param device WebGPU device that owns every buffer and pipeline of the block.
   * @param config Block dimensions. Missing (or explicitly `undefined`) optional
   *               fields default to dState = 16, dConv = 4, expand = 2,
   *               dtRank = ceil(dModel / 16), biasConv = true.
   */
  constructor(device: GPUDevice, config: MambaBlockConfig) {
    this.device = device;
    // Resolve defaults with `??`: a `{ ...defaults, ...config }` spread would let
    // an explicitly-undefined optional field overwrite its default.
    this.config = {
      ...config,
      dState  : config.dState ?? 16,
      dConv   : config.dConv ?? 4,
      expand  : config.expand ?? 2,
      dtRank  : config.dtRank ?? Math.ceil(config.dModel / 16),
      biasConv: config.biasConv ?? true,
    };

    const { dModel, expand } = this.config;
    this.dInner = expand * dModel;
    this.dtRank = this.config.dtRank;

    // Placeholders satisfy strict property initialization; _initWeights
    // replaces all of them.
    this.wInProj = new Float32Array(0);
    this.bInProj = new Float32Array(0);
    this.wConv = new Float32Array(0);
    this.bConv = new Float32Array(0);
    this.wXProj = new Float32Array(0);
    this.bXProj = new Float32Array(0);
    this.wDtProj = new Float32Array(0);
    this.bDtProj = new Float32Array(0);
    this.A_log = new Float32Array(0);
    this.D_vec = new Float32Array(0);
    this.wOutProj = new Float32Array(0);
    this.bOutProj = new Float32Array(0);
    this.normWeight = new Float32Array(0);
    this.gpuWeights = {};
    this.pipelines = {};

    this._initWeights();
    this._buildPipelines();
  }

  /**
   * Initializes host-side weights and uploads them to the GPU.
   * The order of the gaussianArray draws is significant for reproducibility
   * with a seeded RNG, so it must not be reordered.
   */
  private _initWeights(): void {
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const R = this.dtRank;

    const randn = (n: number, std = 0.02): Float32Array => gaussianArray(n, std);
    const zeros = (n: number): Float32Array => new Float32Array(n);
    const ones = (n: number): Float32Array => new Float32Array(n).fill(1.0);

    this.wInProj = randn(2 * D * dModel);
    this.bInProj = zeros(2 * D);
    this.wConv = randn(D * K, 0.01);
    this.bConv = zeros(D);
    this.wXProj = randn((R + 2 * N) * D, 0.01);
    this.bXProj = zeros(R + 2 * N);
    this.wDtProj = randn(D * R, 0.02);
    this.bDtProj = zeros(D);

    // A_log[d, n] = log(n + 1): identical for every inner channel d.
    this.A_log = new Float32Array(D * N);
    for (let d = 0; d < D; d++) {
      for (let n = 0; n < N; n++) {
        this.A_log[d * N + n] = Math.log(n + 1);
      }
    }

    this.D_vec = ones(D);
    this.wOutProj = randn(dModel * D, 0.02);
    this.bOutProj = zeros(dModel);
    this.normWeight = ones(dModel);

    this._uploadWeightsToGPU();
  }

  /** Creates one GPU storage buffer per host-side weight array. */
  private _uploadWeightsToGPU(): void {
    const d = this.device;
    const mk = (arr: Float32Array, readable = true): GPUBuffer => createStorageBuffer(d, arr, readable);

    this.gpuWeights = {
      wInProj   : mk(this.wInProj),
      bInProj   : mk(this.bInProj),
      wConv     : mk(this.wConv),
      bConv     : mk(this.bConv),
      wXProj    : mk(this.wXProj),
      bXProj    : mk(this.bXProj),
      wDtProj   : mk(this.wDtProj),
      bDtProj   : mk(this.bDtProj),
      A_log     : mk(this.A_log),
      D_vec     : mk(this.D_vec),
      wOutProj  : mk(this.wOutProj),
      bOutProj  : mk(this.bOutProj),
      normWeight: mk(this.normWeight),
    };
  }

  /**
   * Compiles every compute pipeline used by `forward` once, up front, so that
   * no shader is compiled per call.
   */
  private _buildPipelines(): void {
    const d = this.device;

    this.pipelines = {
      linear     : createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
      conv1d     : createComputePipeline(d, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
      silu       : createComputePipeline(d, ACTIVATIONS_WGSL, 'silu_forward'),
      rmsnorm    : createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
      scan_fwd   : createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_scan'),
      scan_reduce: createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_reduce'),
      mul        : createComputePipeline(d, MambaBlock.MUL_WGSL, 'main'),
      add        : createComputePipeline(d, MambaBlock.ADD_WGSL, 'main'),
      slice_cols : createComputePipeline(d, MambaBlock.SLICE_COLS_WGSL, 'main'),
    };
  }

  /** Runs the shared linear kernel on a [rows, inDim] input; returns a new rows * outDim buffer. */
  private _linear(input: GPUBuffer, wKey: string, bKey: string,
                  rows: number, inDim: number, outDim: number): GPUBuffer {
    const d = this.device;
    const out = createEmptyStorageBuffer(d, rows * outDim * 4, true);
    const pBuf = createUniformBuffer(d, new Uint32Array([rows, inDim, outDim]).buffer);
    const pipeline = this.pipelines['linear']!;
    const bg = createBindGroup(d, pipeline,
      [pBuf, input, this.gpuWeights[wKey]!, this.gpuWeights[bKey]!, out]);
    dispatchKernel(d, pipeline, bg, [cdiv(rows, 16), cdiv(outDim, 16), 1]);
    return out;
  }

  /** Applies SiLU to the first `n` floats of `input`; returns a new n-float buffer. */
  private _silu(input: GPUBuffer, n: number): GPUBuffer {
    const d = this.device;
    const out = createEmptyStorageBuffer(d, n * 4, true);
    const pBuf = createUniformBuffer(d, new Uint32Array([n]).buffer);
    const pipeline = this.pipelines['silu']!;
    const bg = createBindGroup(d, pipeline, [pBuf, input, out]);
    dispatchKernel(d, pipeline, bg, [cdiv(n, 256), 1, 1]);
    return out;
  }

  /** Elementwise `a op b` over `n` floats using a precompiled pipeline; returns a new buffer. */
  private _elementwise(op: 'mul' | 'add', a: GPUBuffer, b: GPUBuffer, n: number): GPUBuffer {
    const d = this.device;
    const out = createEmptyStorageBuffer(d, n * 4, true);
    const nBuf = createUniformBuffer(d, new Uint32Array([n]).buffer);
    const pipeline = this.pipelines[op]!;
    const bg = createBindGroup(d, pipeline, [a, b, out, nBuf]);
    dispatchKernel(d, pipeline, bg, [cdiv(n, 256), 1, 1]);
    return out;
  }

  /**
   * Copies columns [colOffset, colOffset + numCols) of a row-major
   * [rows, srcCols] matrix into a new dense [rows, numCols] buffer.
   * A plain contiguous copy cannot be used here, because the column
   * groups are interleaved row by row.
   */
  private _sliceColumns(src: GPUBuffer, rows: number, srcCols: number,
                        colOffset: number, numCols: number): GPUBuffer {
    const d = this.device;
    const dst = createEmptyStorageBuffer(d, rows * numCols * 4, true);
    const pBuf = createUniformBuffer(d, new Uint32Array([rows, srcCols, colOffset, numCols]).buffer);
    const pipeline = this.pipelines['slice_cols']!;
    const bg = createBindGroup(d, pipeline, [pBuf, src, dst]);
    dispatchKernel(d, pipeline, bg, [cdiv(rows * numCols, 256), 1, 1]);
    return dst;
  }

  /**
   * Runs the block on `xBuf` (batch * seqLen rows of dModel floats).
   *
   * @param xBuf   Input activations; also used for the residual connection and
   *               stored as `cache.normIn`.
   * @param batch  Number of sequences.
   * @param seqLen Tokens per sequence.
   * @returns The residual output buffer plus the intermediates listed in BlockCache.
   */
  forward(xBuf: GPUBuffer, batch: number, seqLen: number): BlockForwardResult {
    const d = this.device;
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const B = batch;
    const L = seqLen;
    const M = B * L;
    const R = this.dtRank;
    const P = R + 2 * N; // x-projection row width: [dt | B | C]

    const cache = {} as BlockCache;
    cache.normIn = xBuf;

    // --- RMSNorm ------------------------------------------------------------
    const normOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    const normInv = createEmptyStorageBuffer(d, M * 4, true);
    {
      // Uniform: u32 rows, u32 cols, f32 eps (16-byte buffer).
      const params = new ArrayBuffer(16);
      new Uint32Array(params, 0, 2).set([M, dModel]);
      new Float32Array(params, 8, 1).set([1e-6]);
      const pBuf = createUniformBuffer(d, params);
      const pipeline = this.pipelines['rmsnorm']!;
      const bg = createBindGroup(d, pipeline,
        [pBuf, xBuf, this.gpuWeights['normWeight']!, normOut, normInv]);
      dispatchKernel(d, pipeline, bg, [cdiv(M, 64), 1, 1]);
    }
    cache.normOut = normOut;
    cache.normInv = normInv;

    // --- Input projection and x / z split -----------------------------------
    const inProjOut = this._linear(normOut, 'wInProj', 'bInProj', M, dModel, 2 * D);
    const xConvIn = this._sliceColumns(inProjOut, M, 2 * D, 0, D);
    const zBuf = this._sliceColumns(inProjOut, M, 2 * D, D, D);
    cache.xConvIn = xConvIn;
    cache.zBuf = zBuf;

    // --- Conv1d + SiLU ------------------------------------------------------
    const convOut = createEmptyStorageBuffer(d, M * D * 4, true);
    {
      const pBuf = createUniformBuffer(d, new Uint32Array([L, D, dConv, B]).buffer);
      const pipeline = this.pipelines['conv1d']!;
      const bg = createBindGroup(d, pipeline,
        [pBuf, xConvIn, this.gpuWeights['wConv']!, this.gpuWeights['bConv']!, convOut]);
      dispatchKernel(d, pipeline, bg, [cdiv(L, 16), cdiv(D, 16), B]);
    }
    cache.convOut = convOut;

    const siluOut = this._silu(convOut, M * D);
    cache.siluOut = siluOut;

    // --- x-projection, dt / B / C split, dt-projection ----------------------
    const xProjOut = this._linear(siluOut, 'wXProj', 'bXProj', M, D, P);
    const dtRaw = this._sliceColumns(xProjOut, M, P, 0, R);
    const B_raw = this._sliceColumns(xProjOut, M, P, R, N);
    const C_raw = this._sliceColumns(xProjOut, M, P, R + N, N);
    cache.B_raw = B_raw;
    cache.C_raw = C_raw;

    const deltaFull = this._linear(dtRaw, 'wDtProj', 'bDtProj', M, R, D);
    cache.deltaFull = deltaFull;

    // --- Selective scan (scan pass followed by reduce pass) -----------------
    const scanY = createEmptyStorageBuffer(d, M * D * 4, true);
    const hCache = createEmptyStorageBuffer(d, 2 * M * D * N * 4, true);
    cache.hCache = hCache;
    {
      const pBuf = createUniformBuffer(d, new Uint32Array([L, N, D, B]).buffer);
      const scanBindings = [pBuf, siluOut, deltaFull, this.gpuWeights['A_log']!, B_raw, C_raw,
                            this.gpuWeights['D_vec']!, scanY, hCache];

      const scanFwd = this.pipelines['scan_fwd']!;
      dispatchKernel(d, scanFwd, createBindGroup(d, scanFwd, scanBindings),
        [cdiv(D, 8), cdiv(N, 8), B]);

      const scanReduce = this.pipelines['scan_reduce']!;
      dispatchKernel(d, scanReduce, createBindGroup(d, scanReduce, scanBindings),
        [cdiv(L, 64), D, B]);
    }

    // --- Gating, output projection, residual --------------------------------
    const siluZ = this._silu(zBuf, M * D);
    const gatedOut = this._elementwise('mul', scanY, siluZ, M * D);
    const outProjOut = this._linear(gatedOut, 'wOutProj', 'bOutProj', M, D, dModel);
    const output = this._elementwise('add', outProjOut, xBuf, M * dModel);

    return { output, cache };
  }

  /** Lists every weight buffer of the block with its element count. */
  parameters(): BlockParam[] {
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const R = this.dtRank;

    return [
      { buf: this.gpuWeights['wInProj']!,    numel: 2 * D * dModel,   name: 'wInProj' },
      { buf: this.gpuWeights['bInProj']!,    numel: 2 * D,            name: 'bInProj' },
      { buf: this.gpuWeights['wConv']!,      numel: D * K,            name: 'wConv' },
      { buf: this.gpuWeights['bConv']!,      numel: D,                name: 'bConv' },
      { buf: this.gpuWeights['wXProj']!,     numel: (R + 2 * N) * D,  name: 'wXProj' },
      { buf: this.gpuWeights['bXProj']!,     numel: R + 2 * N,        name: 'bXProj' },
      { buf: this.gpuWeights['wDtProj']!,    numel: D * R,            name: 'wDtProj' },
      { buf: this.gpuWeights['bDtProj']!,    numel: D,                name: 'bDtProj' },
      { buf: this.gpuWeights['A_log']!,      numel: D * N,            name: 'A_log' },
      { buf: this.gpuWeights['D_vec']!,      numel: D,                name: 'D_vec' },
      { buf: this.gpuWeights['wOutProj']!,   numel: dModel * D,       name: 'wOutProj' },
      { buf: this.gpuWeights['bOutProj']!,   numel: dModel,           name: 'bOutProj' },
      { buf: this.gpuWeights['normWeight']!, numel: dModel,           name: 'normWeight' },
    ];
  }

  /** Enables or disables WSLA mode, which restricts getTrainableParams to the x-projection. */
  setWSLAMode(enabled: boolean): void {
    this._wslaMode = enabled;
  }

  /**
   * Parameters an optimizer should update: only wXProj / bXProj in WSLA mode,
   * otherwise everything from `parameters()`.
   */
  getTrainableParams(): BlockParam[] {
    if (this._wslaMode) {
      return [
        { buf: this.gpuWeights['wXProj']!, numel: this.wXProj.length, name: 'wXProj' },
        { buf: this.gpuWeights['bXProj']!, numel: this.bXProj.length, name: 'bXProj' },
      ];
    }
    return this.parameters();
  }
}