mambacode.js 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +196 -0
- package/package.json +54 -0
- package/src/index.js +89 -0
- package/src/kernels/activations.js +88 -0
- package/src/kernels/conv1d.js +153 -0
- package/src/kernels/linear_projection.js +220 -0
- package/src/kernels/selective_scan.js +350 -0
- package/src/kernels/weight_update.js +120 -0
- package/src/model/mamba_block.js +443 -0
- package/src/model/mamba_model.js +335 -0
- package/src/tokenizer/bpe.js +256 -0
- package/src/training/autograd.js +221 -0
- package/src/training/trainer.js +394 -0
- package/src/utils/gpu_utils.js +217 -0
- package/src/utils/quantization.js +215 -0
|
@@ -0,0 +1,443 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* mamba_block.js – Mamba Mixer Block
|
|
3
|
+
*
|
|
4
|
+
* Implements one complete Mamba residual layer:
|
|
5
|
+
*
|
|
6
|
+
* x ──► Norm ──► Linear up (×2, for z-gate) ──► Conv1D ──► SiLU ──► Scan ──► × z ──► Linear down ──► + x
|
|
7
|
+
*
|
|
8
|
+
* Components (all dispatched as WebGPU compute passes):
|
|
9
|
+
* 1. RMSNorm
|
|
10
|
+
* 2. Linear up-projection: (D_model → 2 × D_inner)
|
|
11
|
+
* 3. 1D Causal Convolution (depthwise, kernel_size=4)
|
|
12
|
+
* 4. SiLU activation
|
|
13
|
+
* 5. Selective Scan (S6 core)
|
|
14
|
+
* 6. Gated multiplication: y * SiLU(z)
|
|
15
|
+
* 7. Linear down-projection: (D_inner → D_model)
|
|
16
|
+
* 8. Residual add
|
|
17
|
+
*/
|
|
18
|
+
|
|
19
|
+
import {
|
|
20
|
+
createComputePipeline,
|
|
21
|
+
createBindGroup,
|
|
22
|
+
createStorageBuffer,
|
|
23
|
+
createEmptyStorageBuffer,
|
|
24
|
+
createUniformBuffer,
|
|
25
|
+
dispatchKernel,
|
|
26
|
+
cdiv,
|
|
27
|
+
} from '../utils/gpu_utils.js';
|
|
28
|
+
|
|
29
|
+
import { SELECTIVE_SCAN_FORWARD_WGSL } from '../kernels/selective_scan.js';
|
|
30
|
+
import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
|
|
31
|
+
import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
|
|
32
|
+
import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
|
|
33
|
+
|
|
34
|
+
/**
|
|
35
|
+
* @typedef {Object} MambaBlockConfig
|
|
36
|
+
* @property {number} dModel – model dimension (embedding size)
|
|
37
|
+
* @property {number} dState – SSM state dimension (N, default 16)
|
|
38
|
+
* @property {number} dConv – 1D conv kernel size (default 4)
|
|
39
|
+
* @property {number} expand – expansion factor (default 2) → dInner = expand * dModel
|
|
40
|
+
* @property {number} dtRank – rank of Δ projection (default ceil(dModel/16))
|
|
41
|
+
* @property {boolean} [biasConv] – use bias in conv (default true)
|
|
42
|
+
*/
|
|
43
|
+
|
|
44
|
+
export class MambaBlock {
  /**
   * One Mamba residual layer, executed entirely as WebGPU compute passes.
   *
   * @param {GPUDevice} device – WebGPU device used for all buffers and pipelines
   * @param {MambaBlockConfig} config – dModel is required; dState (16), dConv (4),
   *        expand (2) and biasConv (true) are defaulted; dtRank defaults to
   *        ceil(dModel / 16) per the Mamba paper.
   */
  constructor(device, config) {
    this.device = device;
    this.config = {
      dState : 16,
      dConv  : 4,
      expand : 2,
      biasConv: true,
      ...config,
    };

    const { dModel, expand } = this.config;
    this.dInner = expand * dModel;
    // Low-rank width of the Δ (timestep) projection.
    this.dtRank = this.config.dtRank ?? Math.ceil(dModel / 16);

    // WSLA adaptation is off until setWSLAMode(true) is called.
    // (Previously left undefined and relied on undefined being falsy.)
    this._wslaMode = false;

    // ---- Initialise learnable parameters (CPU → GPU) ----
    this._initWeights();

    // ---- Compile GPU pipelines (once) ----
    this._buildPipelines();
  }

  // ─── Weight initialisation ────────────────────────────────────────────────

  /**
   * Allocate and initialise all learnable parameters on the CPU, then upload
   * them to GPU storage buffers via _uploadWeightsToGPU().
   */
  _initWeights() {
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const R = this.dtRank;

    // Gaussian init via Box–Muller; the +1e-12 guards log(0) when u1 === 0.
    const randn = (n, std = 0.02) => {
      const a = new Float32Array(n);
      for (let i = 0; i < n; i++) {
        const u1 = Math.random(), u2 = Math.random();
        a[i] = std * Math.sqrt(-2 * Math.log(u1 + 1e-12)) * Math.cos(2 * Math.PI * u2);
      }
      return a;
    };

    const zeros = (n) => new Float32Array(n);
    const ones = (n) => new Float32Array(n).fill(1.0);

    // in_proj: (2*D_inner, D_model) – up-projection producing x and the z gate
    this.wInProj = randn(2 * D * dModel);
    this.bInProj = zeros(2 * D);

    // conv1d: weight (D_inner, K), bias (D_inner,)
    this.wConv = randn(D * K, 0.01);
    this.bConv = zeros(D);

    // x_proj: (dt_rank + 2*N, D_inner) – projects x to Δ, B, C
    this.wXProj = randn((R + 2 * N) * D, 0.01);
    this.bXProj = zeros(R + 2 * N);

    // dt_proj: (D_inner, dt_rank) – widens Δ from rank R to D_inner
    this.wDtProj = randn(D * R, 0.02);
    this.bDtProj = zeros(D);

    // A: (D_inner, N) – log-space negative eigenvalues,
    // initialised to log(1..N) per HiPPO theory.
    this.A_log = new Float32Array(D * N);
    for (let d = 0; d < D; d++) {
      for (let n = 0; n < N; n++) {
        this.A_log[d * N + n] = Math.log(n + 1);
      }
    }

    // D: (D_inner,) – skip-connection scale, initialised to 1
    this.D_vec = ones(D);

    // out_proj: (D_model, D_inner) – down-projection
    this.wOutProj = randn(dModel * D, 0.02);
    this.bOutProj = zeros(dModel);

    // RMSNorm scale: (D_model,)
    this.normWeight = ones(dModel);

    // Upload all to GPU
    this._uploadWeightsToGPU();
  }

  /**
   * Copy every CPU-side Float32Array parameter into a GPU storage buffer.
   * The CPU copies are retained (used e.g. for numel in getTrainableParams).
   */
  _uploadWeightsToGPU() {
    const d = this.device;
    const mk = (arr, readable = true) => createStorageBuffer(d, arr, readable);

    this.gpuWeights = {
      wInProj   : mk(this.wInProj),
      bInProj   : mk(this.bInProj),
      wConv     : mk(this.wConv),
      bConv     : mk(this.bConv),
      wXProj    : mk(this.wXProj),
      bXProj    : mk(this.bXProj),
      wDtProj   : mk(this.wDtProj),
      bDtProj   : mk(this.bDtProj),
      A_log     : mk(this.A_log),
      D_vec     : mk(this.D_vec),
      wOutProj  : mk(this.wOutProj),
      bOutProj  : mk(this.bOutProj),
      normWeight: mk(this.normWeight),
    };
  }

  // ─── Pipeline compilation ─────────────────────────────────────────────────

  /**
   * Compile every compute pipeline exactly once.
   *
   * The element-wise multiply/add kernels used for the z-gate and the residual
   * add are compiled here too; previously they were recompiled from WGSL
   * source on every forward() call, paying shader-compilation cost per batch.
   */
  _buildPipelines() {
    const d = this.device;

    // Trivial element-wise kernel c[i] = a[i] <op> b[i], guarded by length n.
    const elementwiseWGSL = (op) => /* wgsl */`
      @group(0) @binding(0) var<storage, read> a : array<f32>;
      @group(0) @binding(1) var<storage, read> b : array<f32>;
      @group(0) @binding(2) var<storage, read_write> c : array<f32>;
      @group(0) @binding(3) var<uniform> n : u32;
      @compute @workgroup_size(256)
      fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
        let i = gid.x;
        if (i < n) { c[i] = a[i] ${op} b[i]; }
      }
    `;

    this.pipelines = {
      linear     : createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
      conv1d     : createComputePipeline(d, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
      silu       : createComputePipeline(d, ACTIVATIONS_WGSL, 'silu_forward'),
      rmsnorm    : createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
      scan_fwd   : createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_scan'),
      scan_reduce: createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_reduce'),
      ewMul      : createComputePipeline(d, elementwiseWGSL('*'), 'main'),
      ewAdd      : createComputePipeline(d, elementwiseWGSL('+'), 'main'),
    };
  }

  // ─── Forward pass ─────────────────────────────────────────────────────────

  /**
   * Run the Mamba block forward pass on GPU.
   *
   * @param {GPUBuffer} xBuf – input (batch * seqLen, dModel)
   * @param {number} batch
   * @param {number} seqLen
   * @returns {{ output: GPUBuffer, cache: Object }}
   *   output – (batch * seqLen, dModel)
   *   cache  – intermediate buffers needed for backward pass
   */
  forward(xBuf, batch, seqLen) {
    const d = this.device;
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const B = batch;
    const L = seqLen;
    const M = B * L;
    const R = this.dtRank;

    // Intermediate buffers (freed after backward or cached by the trainer)
    const cache = {};

    // 1. RMSNorm: (M, dModel); normInv holds per-row 1/rms for the backward pass.
    const normOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    const normInv = createEmptyStorageBuffer(d, M * 4, true);
    cache.normInv = normInv;
    cache.normIn = xBuf;

    {
      // Pack params as Uint32 (num_rows, dim) + f32 (eps) ← 12 bytes padded to 16
      const params = new ArrayBuffer(16);
      new Uint32Array(params, 0, 2).set([M, dModel]);
      new Float32Array(params, 8, 1).set([1e-6]);
      const pBuf = createUniformBuffer(d, params);

      const bg = createBindGroup(d, this.pipelines.rmsnorm,
        [pBuf, xBuf, this.gpuWeights.normWeight, normOut, normInv]);
      dispatchKernel(d, this.pipelines.rmsnorm, bg, [cdiv(M, 64), 1, 1]);
    }

    // 2. in_proj: (M, 2*D) = normOut @ wInProj^T + bInProj
    const inProjOut = createEmptyStorageBuffer(d, M * 2 * D * 4, true);
    cache.normOut = normOut;
    {
      const params = new Uint32Array([M, dModel, 2 * D]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.linear,
        [pBuf, normOut, this.gpuWeights.wInProj, this.gpuWeights.bInProj, inProjOut]);
      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(2 * D, 16), 1]);
    }

    // Split inProjOut into x (M, D) and z (M, D) – the z-gate.
    // NOTE(review): these contiguous copies take the first/second M*D *elements*,
    // which equal the first/second D *columns* only if the linear kernel writes
    // its output column-major (or as (out_features, M)). Confirm against
    // linear_projection.js; for a row-major (M, 2D) layout a strided
    // per-row copy would be required instead.
    const xConvIn = createEmptyStorageBuffer(d, M * D * 4, true);
    const zBuf = createEmptyStorageBuffer(d, M * D * 4, true);
    {
      const enc = d.createCommandEncoder();
      enc.copyBufferToBuffer(inProjOut, 0, xConvIn, 0, M * D * 4);
      enc.copyBufferToBuffer(inProjOut, M * D * 4, zBuf, 0, M * D * 4);
      d.queue.submit([enc.finish()]);
    }
    cache.zBuf = zBuf;

    // 3. Conv1D on xConvIn: (B, L, D) – depthwise causal conv
    const convOut = createEmptyStorageBuffer(d, M * D * 4, true);
    cache.xConvIn = xConvIn;
    {
      const params = new Uint32Array([L, D, dConv, B]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.conv1d,
        [pBuf, xConvIn, this.gpuWeights.wConv, this.gpuWeights.bConv, convOut]);
      dispatchKernel(d, this.pipelines.conv1d, bg, [cdiv(L, 16), cdiv(D, 16), B]);
    }

    // 4. SiLU(convOut) → siluOut
    const siluOut = createEmptyStorageBuffer(d, M * D * 4, true);
    cache.convOut = convOut;
    {
      const params = new Uint32Array([M * D]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.silu,
        [pBuf, convOut, siluOut]);
      dispatchKernel(d, this.pipelines.silu, bg, [cdiv(M * D, 256), 1, 1]);
    }

    // 5. x_proj: (M, R+2N) = siluOut @ wXProj^T + bXProj
    const xProjOut = createEmptyStorageBuffer(d, M * (R + 2 * N) * 4, true);
    {
      const params = new Uint32Array([M, D, R + 2 * N]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.linear,
        [pBuf, siluOut, this.gpuWeights.wXProj, this.gpuWeights.bXProj, xProjOut]);
      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(R + 2 * N, 16), 1]);
    }

    // Split xProjOut → dtRaw (M, R), B_raw (B, L, N), C_raw (B, L, N).
    // NOTE(review): same layout assumption as the in_proj split above —
    // contiguous copies are only column slices under a column-major output.
    const dtRaw = createEmptyStorageBuffer(d, M * R * 4, true);
    const B_raw = createEmptyStorageBuffer(d, B * L * N * 4, true);
    const C_raw = createEmptyStorageBuffer(d, B * L * N * 4, true);
    {
      const enc = d.createCommandEncoder();
      enc.copyBufferToBuffer(xProjOut, 0, dtRaw, 0, M * R * 4);
      enc.copyBufferToBuffer(xProjOut, M * R * 4, B_raw, 0, B * L * N * 4);
      enc.copyBufferToBuffer(xProjOut, M * (R + N) * 4, C_raw, 0, B * L * N * 4);
      d.queue.submit([enc.finish()]);
    }

    // 6. dt_proj: (M, D) = dtRaw @ wDtProj^T + bDtProj
    const deltaFull = createEmptyStorageBuffer(d, M * D * 4, true);
    {
      const params = new Uint32Array([M, R, D]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.linear,
        [pBuf, dtRaw, this.gpuWeights.wDtProj, this.gpuWeights.bDtProj, deltaFull]);
      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(D, 16), 1]);
    }

    // 7. Selective Scan (S6 core).
    // scanY is (B, L, D); hCache is 2 * B*L*D*N – first half for hidden states h,
    // second half for partial y, both consumed by the backward pass.
    const scanY = createEmptyStorageBuffer(d, B * L * D * 4, true);
    const hCache = createEmptyStorageBuffer(d, 2 * B * L * D * N * 4, true);
    cache.siluOut = siluOut;
    cache.deltaFull = deltaFull;
    cache.B_raw = B_raw;
    cache.C_raw = C_raw;
    cache.hCache = hCache;

    {
      const params = new Uint32Array([L, N, D, B]).buffer;
      const pBuf = createUniformBuffer(d, params);

      // forward_scan pass – recurrence over L
      const bg = createBindGroup(d, this.pipelines.scan_fwd,
        [pBuf, siluOut, deltaFull, this.gpuWeights.A_log, B_raw, C_raw,
         this.gpuWeights.D_vec, scanY, hCache]);
      dispatchKernel(d, this.pipelines.scan_fwd, bg,
        [cdiv(D, 8), cdiv(N, 8), B]);

      // forward_reduce pass – collapses the N dimension into y
      const bg2 = createBindGroup(d, this.pipelines.scan_reduce,
        [pBuf, siluOut, deltaFull, this.gpuWeights.A_log, B_raw, C_raw,
         this.gpuWeights.D_vec, scanY, hCache]);
      dispatchKernel(d, this.pipelines.scan_reduce, bg2,
        [cdiv(L, 64), D, B]);
    }

    // 8. Gate: gatedOut = scanY * SiLU(zBuf), element-wise
    const siluZ = createEmptyStorageBuffer(d, M * D * 4, true);
    const gatedOut = createEmptyStorageBuffer(d, M * D * 4, true);
    {
      // SiLU(z)
      const params = new Uint32Array([M * D]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.silu,
        [pBuf, zBuf, siluZ]);
      dispatchKernel(d, this.pipelines.silu, bg, [cdiv(M * D, 256), 1, 1]);

      // scanY ⊙ siluZ via the pre-compiled element-wise multiply pipeline
      const nBuf = createUniformBuffer(d, new Uint32Array([M * D]).buffer);
      const bgMul = createBindGroup(d, this.pipelines.ewMul,
        [scanY, siluZ, gatedOut, nBuf]);
      dispatchKernel(d, this.pipelines.ewMul, bgMul, [cdiv(M * D, 256), 1, 1]);
    }

    // 9. out_proj: (M, dModel) = gatedOut @ wOutProj^T + bOutProj
    const outProjOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    {
      const params = new Uint32Array([M, D, dModel]).buffer;
      const pBuf = createUniformBuffer(d, params);
      const bg = createBindGroup(d, this.pipelines.linear,
        [pBuf, gatedOut, this.gpuWeights.wOutProj, this.gpuWeights.bOutProj, outProjOut]);
      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(dModel, 16), 1]);
    }

    // 10. Residual add: output = outProjOut + x (pre-compiled add pipeline)
    const output = createEmptyStorageBuffer(d, M * dModel * 4, true);
    {
      const nBuf = createUniformBuffer(d, new Uint32Array([M * dModel]).buffer);
      const bgAdd = createBindGroup(d, this.pipelines.ewAdd,
        [outProjOut, xBuf, output, nBuf]);
      dispatchKernel(d, this.pipelines.ewAdd, bgAdd, [cdiv(M * dModel, 256), 1, 1]);
    }

    return { output, cache };
  }

  /**
   * Return a list of all parameter GPU buffers (for the optimizer).
   * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
   */
  parameters() {
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const R = this.dtRank;

    return [
      { buf: this.gpuWeights.wInProj,    numel: 2 * D * dModel, name: 'wInProj' },
      { buf: this.gpuWeights.bInProj,    numel: 2 * D,          name: 'bInProj' },
      { buf: this.gpuWeights.wConv,      numel: D * K,          name: 'wConv' },
      { buf: this.gpuWeights.bConv,      numel: D,              name: 'bConv' },
      { buf: this.gpuWeights.wXProj,     numel: (R + 2 * N) * D, name: 'wXProj' },
      { buf: this.gpuWeights.bXProj,     numel: R + 2 * N,      name: 'bXProj' },
      { buf: this.gpuWeights.wDtProj,    numel: D * R,          name: 'wDtProj' },
      { buf: this.gpuWeights.bDtProj,    numel: D,              name: 'bDtProj' },
      { buf: this.gpuWeights.A_log,      numel: D * N,          name: 'A_log' },
      { buf: this.gpuWeights.D_vec,      numel: D,              name: 'D_vec' },
      { buf: this.gpuWeights.wOutProj,   numel: dModel * D,     name: 'wOutProj' },
      { buf: this.gpuWeights.bOutProj,   numel: dModel,         name: 'bOutProj' },
      { buf: this.gpuWeights.normWeight, numel: dModel,         name: 'normWeight' },
    ];
  }

  /**
   * WSLA (Weight-Selective Local Adaptation) mode.
   * When enabled, getTrainableParams() restricts training to the x_proj
   * weights/bias (which contain the B and C projections).
   *
   * @param {boolean} enabled
   */
  setWSLAMode(enabled) {
    this._wslaMode = enabled;
    // The trainer consults getTrainableParams() during backward to decide
    // which parameters receive gradients.
  }

  /**
   * Returns only the trainable parameters under WSLA mode.
   * NOTE(review): despite the doc-comment's "only B and C portions", this
   * returns the *entire* wXProj/bXProj buffers (the Δ rows included), since
   * a GPUBuffer cannot express a row slice here.
   * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
   */
  getTrainableParams() {
    if (this._wslaMode) {
      return [
        { buf: this.gpuWeights.wXProj, numel: this.wXProj.length, name: 'wXProj' },
        { buf: this.gpuWeights.bXProj, numel: this.bXProj.length, name: 'bXProj' },
      ];
    }
    return this.parameters();
  }
}
|