mambacode.js 1.0.0 → 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +198 -76
- package/dist/index.d.ts +19 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +18 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +3 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +87 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +152 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +309 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +66 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +289 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +37 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/package.json +43 -18
- package/src/index.ts +61 -0
- package/src/kernels/{activations.js → activations.ts} +2 -2
- package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
- package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
- package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
- package/src/model/{mamba_block.js → mamba_block.ts} +134 -170
- package/src/model/{mamba_model.js → mamba_model.ts} +165 -121
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/{trainer.js → trainer.ts} +79 -161
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/index.js +0 -89
- package/src/tokenizer/bpe.js +0 -256
- package/src/training/autograd.js +0 -221
- package/src/utils/gpu_utils.js +0 -217
- package/src/utils/quantization.js +0 -215
- package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
```diff
--- a/package/src/model/mamba_block.js
+++ b/package/src/model/mamba_block.ts
@@ -1,19 +1,5 @@
 /**
- * mamba_block.
- *
- * Implements one complete Mamba residual layer:
- *
- * x ──► Norm ──► Linear up (×2, for z-gate) ──► Conv1D ──► SiLU ──► Scan ──► × z ──► Linear down ──► + x
- *
- * Components (all dispatched as WebGPU compute passes):
- * 1. RMSNorm
- * 2. Linear up-projection: (D_model → 2 × D_inner)
- * 3. 1D Causal Convolution (depthwise, kernel_size=4)
- * 4. SiLU activation
- * 5. Selective Scan (S6 core)
- * 6. Gated multiplication: y * SiLU(z)
- * 7. Linear down-projection: (D_inner → D_model)
- * 8. Residual add
+ * mamba_block.ts – Mamba Mixer Block
  */
 
 import {
@@ -31,87 +17,126 @@ import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
 import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
 import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
 
-
-
-
-
-
-
-
-
-
+export interface MambaBlockConfig {
+  dModel: number;
+  dState?: number;
+  dConv?: number;
+  expand?: number;
+  dtRank?: number;
+  biasConv?: boolean;
+}
+
+export interface BlockParam {
+  buf: GPUBuffer;
+  numel: number;
+  name: string;
+}
+
+export interface BlockCache {
+  normInv: GPUBuffer;
+  normIn: GPUBuffer;
+  normOut: GPUBuffer;
+  zBuf: GPUBuffer;
+  xConvIn: GPUBuffer;
+  convOut: GPUBuffer;
+  siluOut: GPUBuffer;
+  deltaFull: GPUBuffer;
+  B_raw: GPUBuffer;
+  C_raw: GPUBuffer;
+  hCache: GPUBuffer;
+}
+
+export interface BlockForwardResult {
+  output: GPUBuffer;
+  cache: BlockCache;
+}
 
 export class MambaBlock {
-
-
-
-
-
+  device: GPUDevice;
+  config: Required<MambaBlockConfig>;
+  dInner: number;
+  dtRank: number;
+  wInProj: Float32Array;
+  bInProj: Float32Array;
+  wConv: Float32Array;
+  bConv: Float32Array;
+  wXProj: Float32Array;
+  bXProj: Float32Array;
+  wDtProj: Float32Array;
+  bDtProj: Float32Array;
+  A_log: Float32Array;
+  D_vec: Float32Array;
+  wOutProj: Float32Array;
+  bOutProj: Float32Array;
+  normWeight: Float32Array;
+  gpuWeights: Record<string, GPUBuffer>;
+  pipelines: Record<string, GPUComputePipeline>;
+  private _wslaMode = false;
+
+  constructor(device: GPUDevice, config: MambaBlockConfig) {
     this.device = device;
     this.config = {
       dState  : 16,
       dConv   : 4,
       expand  : 2,
       biasConv: true,
+      dtRank  : Math.ceil(config.dModel / 16),
       ...config,
-    }
+    } as Required<MambaBlockConfig>;
 
-    const { dModel,
+    const { dModel, expand } = this.config;
     this.dInner = expand * dModel;
-    this.dtRank =
+    this.dtRank = config.dtRank ?? Math.ceil(dModel / 16);
+
+    // Initialize these before _initWeights so TypeScript is happy
+    this.wInProj = new Float32Array(0);
+    this.bInProj = new Float32Array(0);
+    this.wConv = new Float32Array(0);
+    this.bConv = new Float32Array(0);
+    this.wXProj = new Float32Array(0);
+    this.bXProj = new Float32Array(0);
+    this.wDtProj = new Float32Array(0);
+    this.bDtProj = new Float32Array(0);
+    this.A_log = new Float32Array(0);
+    this.D_vec = new Float32Array(0);
+    this.wOutProj = new Float32Array(0);
+    this.bOutProj = new Float32Array(0);
+    this.normWeight = new Float32Array(0);
+    this.gpuWeights = {};
+    this.pipelines = {};
 
-    // ---- Initialise learnable parameters (CPU → GPU) ----
     this._initWeights();
-
-    // ---- Compile GPU pipelines (once) ----
     this._buildPipelines();
   }
 
-
-
-  _initWeights() {
+  private _initWeights(): void {
     const { dModel, dState, dConv } = this.config;
     const D = this.dInner;
     const N = dState;
     const K = dConv;
     const R = this.dtRank;
 
-    const randn = (n, std = 0.02) => {
+    const randn = (n: number, std = 0.02): Float32Array => {
       const a = new Float32Array(n);
       for (let i = 0; i < n; i++) {
-        // Box-Muller
        const u1 = Math.random(), u2 = Math.random();
         a[i] = std * Math.sqrt(-2 * Math.log(u1 + 1e-12)) * Math.cos(2 * Math.PI * u2);
       }
       return a;
     };
 
-    const zeros = (n) => new Float32Array(n);
-    const ones = (n) => new Float32Array(n).fill(1.0);
-    const linspace = (n) => {
-      const a = new Float32Array(n);
-      for (let i = 0; i < n; i++) a[i] = i;
-      return a;
-    };
+    const zeros = (n: number): Float32Array => new Float32Array(n);
+    const ones = (n: number): Float32Array => new Float32Array(n).fill(1.0);
 
-    // in_proj: (2*D_inner, D_model) – up-projection (and z gate)
     this.wInProj = randn(2 * D * dModel);
     this.bInProj = zeros(2 * D);
-
-    // conv1d: weight (D_inner, K), bias (D_inner,)
     this.wConv = randn(D * K, 0.01);
     this.bConv = zeros(D);
-
-    // x_proj: (dt_rank + 2*N, D_inner) – projects x to Δ, B, C
     this.wXProj = randn((R + 2 * N) * D, 0.01);
     this.bXProj = zeros(R + 2 * N);
-
-    // dt_proj: (D_inner, dt_rank) – projects Δ to full D_inner width
     this.wDtProj = randn(D * R, 0.02);
     this.bDtProj = zeros(D);
 
-    // A: (D_inner, N) – log-space negative eigenvalues
-    // Initialised to log(range(1, N+1)) per HiPPO theory
     this.A_log = new Float32Array(D * N);
     for (let d = 0; d < D; d++) {
      for (let n = 0; n < N; n++) {
```
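The headline change in this hunk is the typed `MambaBlockConfig` plus the constructor's default resolution. As a quick reference, the resolution logic can be sketched standalone; the interface shape is taken from the diff, while `resolveConfig` itself is a hypothetical helper, not a package export:

```typescript
// Hypothetical standalone sketch of the constructor's config resolution;
// mirrors the diff above, not an actual export of mambacode.js.
interface MambaBlockConfig {
  dModel: number;
  dState?: number;
  dConv?: number;
  expand?: number;
  dtRank?: number;
  biasConv?: boolean;
}

function resolveConfig(config: MambaBlockConfig): Required<MambaBlockConfig> {
  return {
    dState  : 16,
    dConv   : 4,
    expand  : 2,
    biasConv: true,
    dtRank  : Math.ceil(config.dModel / 16), // default when the caller omits it
    ...config, // caller-supplied fields override the defaults above
  } as Required<MambaBlockConfig>;
}

// dModel = 256 → dInner = expand * dModel = 512, dtRank = ceil(256 / 16) = 16
const cfg = resolveConfig({ dModel: 256 });
console.log(cfg.expand * cfg.dModel, cfg.dtRank); // 512 16
```

Note the constructor additionally re-derives `this.dtRank` with `config.dtRank ?? Math.ceil(dModel / 16)`, which guards against a caller spreading in an explicit `dtRank: undefined`.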
```diff
@@ -119,23 +144,17 @@ export class MambaBlock {
       }
     }
 
-    // D: (D_inner,) – skip connection scale (initialised to 1)
     this.D_vec = ones(D);
-
-    // out_proj: (D_model, D_inner) – down-projection
     this.wOutProj = randn(dModel * D, 0.02);
     this.bOutProj = zeros(dModel);
-
-    // RMSNorm scale: (D_model,)
     this.normWeight = ones(dModel);
 
-    // Upload all to GPU
     this._uploadWeightsToGPU();
   }
 
-  _uploadWeightsToGPU() {
+  private _uploadWeightsToGPU(): void {
     const d = this.device;
-    const mk = (arr, readable = true) => createStorageBuffer(d, arr, readable);
+    const mk = (arr: Float32Array, readable = true): GPUBuffer => createStorageBuffer(d, arr, readable);
 
     this.gpuWeights = {
       wInProj   : mk(this.wInProj),
```
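The body of the `A_log` fill loop falls between the two hunks above and is elided from the diff. Judging from the removed comment ("Initialised to log(range(1, N+1)) per HiPPO theory"), it plausibly matches the standard Mamba initialisation; a CPU sketch of that assumption:

```typescript
// Plausible sketch of the elided A_log loop body: the standard Mamba init
// stores A[d][n] = n + 1 in log space, so A = -exp(A_log) yields negative
// real eigenvalues at scan time. Inferred from the removed comment, not verbatim.
function initALog(D: number, N: number): Float32Array {
  const A_log = new Float32Array(D * N);
  for (let d = 0; d < D; d++) {
    for (let n = 0; n < N; n++) {
      A_log[d * N + n] = Math.log(n + 1);
    }
  }
  return A_log;
}
```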
```diff
@@ -154,9 +173,7 @@ export class MambaBlock {
     };
   }
 
-
-
-  _buildPipelines() {
+  private _buildPipelines(): void {
     const d = this.device;
 
     this.pipelines = {
@@ -169,19 +186,7 @@ export class MambaBlock {
     };
   }
 
-
-
-  /**
-   * Run the Mamba block forward pass on GPU.
-   *
-   * @param {GPUBuffer} xBuf – input (batch * seqLen, dModel)
-   * @param {number} batch
-   * @param {number} seqLen
-   * @returns {{ output: GPUBuffer, cache: Object }}
-   *   output – (batch * seqLen, dModel)
-   *   cache  – intermediate buffers needed for backward pass
-   */
-  forward(xBuf, batch, seqLen) {
+  forward(xBuf: GPUBuffer, batch: number, seqLen: number): BlockForwardResult {
     const d = this.device;
     const { dModel, dState, dConv } = this.config;
     const D = this.dInner;
@@ -191,45 +196,37 @@ export class MambaBlock {
     const M = B * L;
     const R = this.dtRank;
 
-
-    const cache = {};
+    const cache = {} as BlockCache;
 
-    // 1. RMSNorm: (M, dModel)
     const normOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
     const normInv = createEmptyStorageBuffer(d, M * 4, true);
     cache.normInv = normInv;
     cache.normIn = xBuf;
 
     {
-      // Pack params as Uint32 (num_rows, dim) + f32 (eps) ← 12 bytes padded to 16
       const params = new ArrayBuffer(16);
       new Uint32Array(params, 0, 2).set([M, dModel]);
       new Float32Array(params, 8, 1).set([1e-6]);
       const pBuf = createUniformBuffer(d, params);
 
-      const bg = createBindGroup(d, this.pipelines.rmsnorm,
-        [pBuf, xBuf, this.gpuWeights.normWeight, normOut, normInv]);
-      dispatchKernel(d, this.pipelines.rmsnorm, bg, [cdiv(M, 64), 1, 1]);
+      const bg = createBindGroup(d, this.pipelines['rmsnorm']!,
+        [pBuf, xBuf, this.gpuWeights['normWeight']!, normOut, normInv]);
+      dispatchKernel(d, this.pipelines['rmsnorm']!, bg, [cdiv(M, 64), 1, 1]);
     }
 
-    // 2. in_proj: (M, 2*D) = normOut @ wInProj^T + bInProj
     const inProjOut = createEmptyStorageBuffer(d, M * 2 * D * 4, true);
     cache.normOut = normOut;
     {
       const params = new Uint32Array([M, dModel, 2 * D]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.linear,
-        [pBuf, normOut, this.gpuWeights.wInProj, this.gpuWeights.bInProj, inProjOut]);
-      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(2 * D, 16), 1]);
+      const bg = createBindGroup(d, this.pipelines['linear']!,
+        [pBuf, normOut, this.gpuWeights['wInProj']!, this.gpuWeights['bInProj']!, inProjOut]);
+      dispatchKernel(d, this.pipelines['linear']!, bg, [cdiv(M, 16), cdiv(2 * D, 16), 1]);
     }
 
-    // Split inProjOut into x (M, D) and z (M, D) – the z-gate
-    // We reuse the same buffer with offsets since WGSL bindings can be offset.
-    // For simplicity, allocate two separate buffers and copy.
     const xConvIn = createEmptyStorageBuffer(d, M * D * 4, true);
     const zBuf = createEmptyStorageBuffer(d, M * D * 4, true);
     {
-      // Copy first D columns into xConvIn, last D columns into zBuf
       const enc = d.createCommandEncoder();
       enc.copyBufferToBuffer(inProjOut, 0, xConvIn, 0, M * D * 4);
       enc.copyBufferToBuffer(inProjOut, M * D * 4, zBuf, 0, M * D * 4);
```
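Step 1 of `forward` normalises with RMSNorm and caches a per-row inverse RMS (`normInv`) for the backward pass. A CPU reference of what the `rmsnorm` kernel computes, matching the `[M, dModel]` and `eps = 1e-6` parameters packed above (a sketch for intuition; the shipped computation is the WGSL kernel):

```typescript
// CPU reference for the RMSNorm step, shaped to match the buffers the
// kernel produces: normOut plus the cached per-row inverse RMS (normInv).
function rmsNorm(
  x: Float32Array,      // (M, dim) row-major input
  weight: Float32Array, // (dim,) learned scale
  M: number, dim: number, eps = 1e-6,
): { out: Float32Array; inv: Float32Array } {
  const out = new Float32Array(M * dim);
  const inv = new Float32Array(M);
  for (let r = 0; r < M; r++) {
    let ss = 0;
    for (let c = 0; c < dim; c++) ss += x[r * dim + c] ** 2;
    inv[r] = 1 / Math.sqrt(ss / dim + eps); // cached for the backward pass
    for (let c = 0; c < dim; c++) {
      out[r * dim + c] = x[r * dim + c] * inv[r] * weight[c];
    }
  }
  return { out, inv };
}
```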
```diff
@@ -237,39 +234,35 @@ export class MambaBlock {
     }
     cache.zBuf = zBuf;
 
-    // 3. Conv1D on xConvIn: (B, L, D) – depthwise causal conv
     const convOut = createEmptyStorageBuffer(d, M * D * 4, true);
     cache.xConvIn = xConvIn;
     {
       const params = new Uint32Array([L, D, dConv, B]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.conv1d,
-        [pBuf, xConvIn, this.gpuWeights.wConv, this.gpuWeights.bConv, convOut]);
-      dispatchKernel(d, this.pipelines.conv1d, bg, [cdiv(L, 16), cdiv(D, 16), B]);
+      const bg = createBindGroup(d, this.pipelines['conv1d']!,
+        [pBuf, xConvIn, this.gpuWeights['wConv']!, this.gpuWeights['bConv']!, convOut]);
+      dispatchKernel(d, this.pipelines['conv1d']!, bg, [cdiv(L, 16), cdiv(D, 16), B]);
     }
 
-    // 4. SiLU(convOut) in-place
     const siluOut = createEmptyStorageBuffer(d, M * D * 4, true);
     cache.convOut = convOut;
     {
       const params = new Uint32Array([M * D]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.silu,
+      const bg = createBindGroup(d, this.pipelines['silu']!,
         [pBuf, convOut, siluOut]);
-      dispatchKernel(d, this.pipelines.silu, bg, [cdiv(M * D, 256), 1, 1]);
+      dispatchKernel(d, this.pipelines['silu']!, bg, [cdiv(M * D, 256), 1, 1]);
     }
 
-    // 5. x_proj: (M, R+2N) = siluOut @ wXProj^T + bXProj
     const xProjOut = createEmptyStorageBuffer(d, M * (R + 2 * N) * 4, true);
     {
       const params = new Uint32Array([M, D, R + 2 * N]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.linear,
-        [pBuf, siluOut, this.gpuWeights.wXProj, this.gpuWeights.bXProj, xProjOut]);
-      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(R + 2 * N, 16), 1]);
+      const bg = createBindGroup(d, this.pipelines['linear']!,
+        [pBuf, siluOut, this.gpuWeights['wXProj']!, this.gpuWeights['bXProj']!, xProjOut]);
+      dispatchKernel(d, this.pipelines['linear']!, bg, [cdiv(M, 16), cdiv(R + 2 * N, 16), 1]);
     }
 
-    // Split xProjOut → dtRaw (M, R), B_raw (M*N flattened) = (B, L, N), C_raw (B, L, N)
     const dtRaw = createEmptyStorageBuffer(d, M * R * 4, true);
     const B_raw = createEmptyStorageBuffer(d, B * L * N * 4, true);
     const C_raw = createEmptyStorageBuffer(d, B * L * N * 4, true);
```
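Steps 3–4 run a depthwise causal Conv1D followed by SiLU over the `(B, L, D)` activations. A CPU reference sketch; the row-major layout and left-only padding are inferred from the dispatch parameters `[L, D, dConv, B]` rather than stated in the diff:

```typescript
// CPU reference for steps 3–4: depthwise causal Conv1D, then SiLU.
// Layout assumption: x is (B, L, D) row-major, one K-tap kernel per channel.
const silu = (v: number): number => v / (1 + Math.exp(-v));

function causalDepthwiseConv1d(
  x: Float32Array,    // (B, L, D)
  w: Float32Array,    // (D, K)
  bias: Float32Array, // (D,)
  B: number, L: number, D: number, K: number,
): Float32Array {
  const y = new Float32Array(B * L * D);
  for (let b = 0; b < B; b++)
    for (let t = 0; t < L; t++)
      for (let c = 0; c < D; c++) {
        let acc = bias[c];
        for (let k = 0; k < K; k++) {
          const src = t - (K - 1) + k; // causal: only past/current timesteps
          if (src >= 0) acc += w[c * K + k] * x[(b * L + src) * D + c];
        }
        y[(b * L + t) * D + c] = silu(acc); // step 4: SiLU activation
      }
  return y;
}
```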
```diff
@@ -281,18 +274,15 @@ export class MambaBlock {
       d.queue.submit([enc.finish()]);
     }
 
-    // 6. dt_proj: (M, D) = dtRaw @ wDtProj^T + bDtProj
     const deltaFull = createEmptyStorageBuffer(d, M * D * 4, true);
     {
       const params = new Uint32Array([M, R, D]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.linear,
-        [pBuf, dtRaw, this.gpuWeights.wDtProj, this.gpuWeights.bDtProj, deltaFull]);
-      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(D, 16), 1]);
+      const bg = createBindGroup(d, this.pipelines['linear']!,
+        [pBuf, dtRaw, this.gpuWeights['wDtProj']!, this.gpuWeights['bDtProj']!, deltaFull]);
+      dispatchKernel(d, this.pipelines['linear']!, bg, [cdiv(M, 16), cdiv(D, 16), 1]);
     }
 
-    // 7. Selective Scan
-    // Allocate y (B, L, D) and h_cache (2 * B*L*D*N) – first half for h, second for y_partial
     const scanY = createEmptyStorageBuffer(d, B * L * D * 4, true);
     const hCache = createEmptyStorageBuffer(d, 2 * B * L * D * N * 4, true);
     cache.siluOut = siluOut;
```
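The `scanY` and `hCache` buffers allocated here feed the two scan kernels dispatched in the next hunk. For intuition, the S6 recurrence those kernels implement can be written per `(batch, channel)` pair on the CPU; the softplus discretisation below is assumed from the Mamba paper, and the WGSL kernels may differ in detail:

```typescript
// CPU reference for the selective-scan (S6) recurrence, one (batch, channel)
// pair at a time. Assumes softplus(Δ) with zero-order-hold discretisation.
function selectiveScanChannel(
  x: Float32Array,     // (L,)  post-SiLU conv output for this channel
  delta: Float32Array, // (L,)  dt_proj output (pre-softplus)
  A: Float32Array,     // (N,)  A = -exp(A_log[d]) row for this channel
  Bm: Float32Array,    // (L, N) input-dependent B
  Cm: Float32Array,    // (L, N) input-dependent C
  Dskip: number,       // skip scale D_vec[d]
  L: number, N: number,
): Float32Array {
  const h = new Float32Array(N); // hidden state carried across timesteps
  const y = new Float32Array(L);
  for (let t = 0; t < L; t++) {
    const dt = Math.log(1 + Math.exp(delta[t])); // softplus(Δ)
    let yt = 0;
    for (let n = 0; n < N; n++) {
      // h_t = exp(Δ·A)·h_{t-1} + Δ·B_t·x_t ;  y_t += C_t·h_t
      h[n] = Math.exp(dt * A[n]) * h[n] + dt * Bm[t * N + n] * x[t];
      yt += Cm[t * N + n] * h[n];
    }
    y[t] = yt + Dskip * x[t]; // D skip connection
  }
  return y;
}
```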
```diff
@@ -305,34 +295,28 @@ export class MambaBlock {
       const params = new Uint32Array([L, N, D, B]).buffer;
       const pBuf = createUniformBuffer(d, params);
 
-
-
-
-
-      dispatchKernel(d, this.pipelines.scan_fwd, bg,
+      const bg = createBindGroup(d, this.pipelines['scan_fwd']!,
+        [pBuf, siluOut, deltaFull, this.gpuWeights['A_log']!, B_raw, C_raw,
+         this.gpuWeights['D_vec']!, scanY, hCache]);
+      dispatchKernel(d, this.pipelines['scan_fwd']!, bg,
         [cdiv(D, 8), cdiv(N, 8), B]);
 
-
-
-
-
-      dispatchKernel(d, this.pipelines.scan_reduce, bg2,
+      const bg2 = createBindGroup(d, this.pipelines['scan_reduce']!,
+        [pBuf, siluOut, deltaFull, this.gpuWeights['A_log']!, B_raw, C_raw,
+         this.gpuWeights['D_vec']!, scanY, hCache]);
+      dispatchKernel(d, this.pipelines['scan_reduce']!, bg2,
         [cdiv(L, 64), D, B]);
     }
 
-    // 8. Gate: scanY *= SiLU(zBuf) – element-wise product
     const siluZ = createEmptyStorageBuffer(d, M * D * 4, true);
     const gatedOut = createEmptyStorageBuffer(d, M * D * 4, true);
     {
-      // SiLU(z)
       const params = new Uint32Array([M * D]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.silu,
+      const bg = createBindGroup(d, this.pipelines['silu']!,
         [pBuf, zBuf, siluZ]);
-      dispatchKernel(d, this.pipelines.silu, bg, [cdiv(M * D, 256), 1, 1]);
+      dispatchKernel(d, this.pipelines['silu']!, bg, [cdiv(M * D, 256), 1, 1]);
 
-      // Element-wise multiply scanY * siluZ → gatedOut
-      // We encode this as a trivial compute pass using a small inline shader.
       const mulShader = /* wgsl */`
       @group(0) @binding(0) var<storage, read> a : array<f32>;
       @group(0) @binding(1) var<storage, read> b : array<f32>;
```
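Step 8's gate is just an element-wise `scanY ⊙ SiLU(z)`, which the block implements with the small inline WGSL shader above. The CPU equivalent, for reference:

```typescript
// CPU reference for step 8, the gate: output = scanY ⊙ SiLU(z).
// The package runs this as the inline WGSL multiply pass shown above.
function gate(scanY: Float32Array, z: Float32Array): Float32Array {
  const out = new Float32Array(scanY.length);
  for (let i = 0; i < out.length; i++) {
    out[i] = scanY[i] * (z[i] / (1 + Math.exp(-z[i]))); // SiLU(z) · y
  }
  return out;
}
```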
```diff
@@ -351,17 +335,15 @@ export class MambaBlock {
       dispatchKernel(d, mulPipeline, bgMul, [cdiv(M * D, 256), 1, 1]);
     }
 
-    // 9. out_proj: (M, dModel) = gatedOut @ wOutProj^T + bOutProj
     const outProjOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
     {
       const params = new Uint32Array([M, D, dModel]).buffer;
       const pBuf = createUniformBuffer(d, params);
-      const bg = createBindGroup(d, this.pipelines.linear,
-        [pBuf, gatedOut, this.gpuWeights.wOutProj, this.gpuWeights.bOutProj, outProjOut]);
-      dispatchKernel(d, this.pipelines.linear, bg, [cdiv(M, 16), cdiv(dModel, 16), 1]);
+      const bg = createBindGroup(d, this.pipelines['linear']!,
+        [pBuf, gatedOut, this.gpuWeights['wOutProj']!, this.gpuWeights['bOutProj']!, outProjOut]);
+      dispatchKernel(d, this.pipelines['linear']!, bg, [cdiv(M, 16), cdiv(dModel, 16), 1]);
     }
 
-    // 10. Residual add: output = outProjOut + x
     const output = createEmptyStorageBuffer(d, M * dModel * 4, true);
     {
       const addShader = /* wgsl */`
@@ -385,11 +367,7 @@ export class MambaBlock {
     return { output, cache };
   }
 
-  /**
-   * Return a list of all parameter GPU buffers (for the optimizer).
-   * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
-   */
-  parameters() {
+  parameters(): BlockParam[] {
     const { dModel, dState, dConv } = this.config;
     const D = this.dInner;
     const N = dState;
@@ -397,45 +375,31 @@ export class MambaBlock {
     const R = this.dtRank;
 
     return [
-      { buf: this.gpuWeights.wInProj, numel: 2 * D * dModel, name: 'wInProj' },
-      { buf: this.gpuWeights.bInProj, numel: 2 * D, name: 'bInProj' },
-      { buf: this.gpuWeights.wConv, numel: D * K, name: 'wConv' },
-      { buf: this.gpuWeights.bConv, numel: D, name: 'bConv' },
-      { buf: this.gpuWeights.wXProj, numel: (R + 2*N) * D, name: 'wXProj' },
-      { buf: this.gpuWeights.bXProj, numel: R + 2 * N, name: 'bXProj' },
-      { buf: this.gpuWeights.wDtProj, numel: D * R, name: 'wDtProj' },
-      { buf: this.gpuWeights.bDtProj, numel: D, name: 'bDtProj' },
-      { buf: this.gpuWeights.A_log, numel: D * N, name: 'A_log' },
-      { buf: this.gpuWeights.D_vec, numel: D, name: 'D_vec' },
-      { buf: this.gpuWeights.wOutProj, numel: dModel * D, name: 'wOutProj' },
-      { buf: this.gpuWeights.bOutProj, numel: dModel, name: 'bOutProj' },
-      { buf: this.gpuWeights.normWeight, numel: dModel, name: 'normWeight'},
+      { buf: this.gpuWeights['wInProj']!, numel: 2 * D * dModel, name: 'wInProj' },
+      { buf: this.gpuWeights['bInProj']!, numel: 2 * D, name: 'bInProj' },
+      { buf: this.gpuWeights['wConv']!, numel: D * K, name: 'wConv' },
+      { buf: this.gpuWeights['bConv']!, numel: D, name: 'bConv' },
+      { buf: this.gpuWeights['wXProj']!, numel: (R + 2*N) * D, name: 'wXProj' },
+      { buf: this.gpuWeights['bXProj']!, numel: R + 2 * N, name: 'bXProj' },
+      { buf: this.gpuWeights['wDtProj']!, numel: D * R, name: 'wDtProj' },
+      { buf: this.gpuWeights['bDtProj']!, numel: D, name: 'bDtProj' },
+      { buf: this.gpuWeights['A_log']!, numel: D * N, name: 'A_log' },
+      { buf: this.gpuWeights['D_vec']!, numel: D, name: 'D_vec' },
+      { buf: this.gpuWeights['wOutProj']!, numel: dModel * D, name: 'wOutProj' },
+      { buf: this.gpuWeights['bOutProj']!, numel: dModel, name: 'bOutProj' },
+      { buf: this.gpuWeights['normWeight']!, numel: dModel, name: 'normWeight'},
     ];
   }
 
-  /**
-   * WSLA (Weight-Selective Local Adaptation) mode.
-   * Freezes all parameters except the B and C matrices (wXProj slice).
-   * This allows rapid local adaptation with minimal compute.
-   *
-   * @param {boolean} enabled
-   */
-  setWSLAMode(enabled) {
+  setWSLAMode(enabled: boolean): void {
     this._wslaMode = enabled;
-    // Mark which parameters receive gradients
-    // (The trainer checks this.getTrainableParams() during backward)
   }
 
-  /**
-   * Returns only the trainable parameters under WSLA mode.
-   * @returns {Array<{buf: GPUBuffer, numel: number, name: string}>}
-   */
-  getTrainableParams() {
+  getTrainableParams(): BlockParam[] {
     if (this._wslaMode) {
-      // Only B and C portions of wXProj
       return [
-        { buf: this.gpuWeights.wXProj, numel: this.wXProj.length, name: 'wXProj' },
-        { buf: this.gpuWeights.bXProj, numel: this.bXProj.length, name: 'bXProj' },
+        { buf: this.gpuWeights['wXProj']!, numel: this.wXProj.length, name: 'wXProj' },
+        { buf: this.gpuWeights['bXProj']!, numel: this.bXProj.length, name: 'bXProj' },
       ];
     }
     return this.parameters();
```