@seanhogg/builderforce-memory-engine 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +393 -0
- package/dist/index.d.ts +32 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +5 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +171 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/attention.d.ts +19 -0
- package/dist/kernels/attention.d.ts.map +1 -0
- package/dist/kernels/attention.js +263 -0
- package/dist/kernels/attention.js.map +1 -0
- package/dist/kernels/complex_ssd.d.ts +33 -0
- package/dist/kernels/complex_ssd.d.ts.map +1 -0
- package/dist/kernels/complex_ssd.js +305 -0
- package/dist/kernels/complex_ssd.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +158 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/ssd.d.ts +29 -0
- package/dist/kernels/ssd.d.ts.map +1 -0
- package/dist/kernels/ssd.js +276 -0
- package/dist/kernels/ssd.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/attention_block.d.ts +48 -0
- package/dist/model/attention_block.d.ts.map +1 -0
- package/dist/model/attention_block.js +262 -0
- package/dist/model/attention_block.js.map +1 -0
- package/dist/model/mamba1_block.d.ts +70 -0
- package/dist/model/mamba1_block.d.ts.map +1 -0
- package/dist/model/mamba1_block.js +333 -0
- package/dist/model/mamba1_block.js.map +1 -0
- package/dist/model/mamba2_block.d.ts +44 -0
- package/dist/model/mamba2_block.d.ts.map +1 -0
- package/dist/model/mamba2_block.js +252 -0
- package/dist/model/mamba2_block.js.map +1 -0
- package/dist/model/mamba3_block.d.ts +51 -0
- package/dist/model/mamba3_block.d.ts.map +1 -0
- package/dist/model/mamba3_block.js +270 -0
- package/dist/model/mamba3_block.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +303 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +140 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +527 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/model/sequence_layer.d.ts +25 -0
- package/dist/model/sequence_layer.d.ts.map +1 -0
- package/dist/model/sequence_layer.js +8 -0
- package/dist/model/sequence_layer.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +36 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/dist/utils/rng.d.ts +36 -0
- package/dist/utils/rng.d.ts.map +1 -0
- package/dist/utils/rng.js +61 -0
- package/dist/utils/rng.js.map +1 -0
- package/package.json +99 -0
- package/src/index.ts +114 -0
- package/src/kernels/activations.ts +174 -0
- package/src/kernels/attention.ts +268 -0
- package/src/kernels/complex_ssd.ts +307 -0
- package/src/kernels/conv1d.ts +159 -0
- package/src/kernels/linear_projection.ts +220 -0
- package/src/kernels/selective_scan.ts +350 -0
- package/src/kernels/ssd.ts +278 -0
- package/src/kernels/weight_update.ts +120 -0
- package/src/model/attention_block.ts +344 -0
- package/src/model/mamba1_block.ts +437 -0
- package/src/model/mamba2_block.ts +319 -0
- package/src/model/mamba3_block.ts +335 -0
- package/src/model/mamba_block.ts +401 -0
- package/src/model/mamba_model.ts +678 -0
- package/src/model/sequence_layer.ts +29 -0
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +309 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/utils/rng.ts +65 -0
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* mamba1_block.ts – Mamba-1 Mixer Block (S6 selective scan).
|
|
3
|
+
*
|
|
4
|
+
* Renamed from mamba_block.ts; MambaBlock is kept as a deprecated alias.
|
|
5
|
+
* Implements SequenceLayer so HybridMambaModel can iterate blocks generically.
|
|
6
|
+
*/
|
|
7
|
+
import { createComputePipeline, createBindGroup, createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer, dispatchKernel, cdiv, } from '../utils/gpu_utils.js';
|
|
8
|
+
import { SELECTIVE_SCAN_FORWARD_WGSL } from '../kernels/selective_scan.js';
|
|
9
|
+
import { gaussianArray } from '../utils/rng.js';
|
|
10
|
+
import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
|
|
11
|
+
import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
|
|
12
|
+
import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
|
|
13
|
+
// ── Element-wise helper shaders (compiled once per pipeline) ─────────────────
/**
 * Build a WGSL compute shader computing `c[i] = a[i] <op> b[i]` for every `i < n`.
 *
 * Bindings (group 0): 0 = `a` (storage, read), 1 = `b` (storage, read),
 * 2 = `c` (storage, read_write), 3 = `n` (uniform u32 element count).
 * Entry point is `main` with `@workgroup_size(256)`, so callers dispatch
 * `cdiv(n, 256)` workgroups along x.
 *
 * @param {string} op - A WGSL binary operator valid for f32 operands (e.g. '*', '+').
 * @returns {string} WGSL shader source.
 */
const elementwiseBinaryShader = (op) => /* wgsl */ `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@group(0) @binding(3) var<uniform> n : u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i < n) { c[i] = a[i] ${op} b[i]; }
}
`;
/** Element-wise product `c = a * b` (see `elementwiseBinaryShader` for bindings). */
const MUL_SHADER = elementwiseBinaryShader('*');
/** Element-wise sum `c = a + b` (see `elementwiseBinaryShader` for bindings). */
const ADD_SHADER = elementwiseBinaryShader('+');
|
|
36
|
+
// ── Mamba1Block ───────────────────────────────────────────────────────────────
|
|
37
|
+
export class Mamba1Block {
    layerType = 'mamba1';
    device;
    config;
    /** Inner (expanded) channel count: `expand * dModel`. */
    dInner;
    /** Rank of the Δ (time-step) projection. */
    dtRank;
    // Host-side weight copies (Float32Array); uploaded into `gpuWeights`.
    wInProj;
    bInProj;
    wConv;
    bConv;
    wXProj;
    bXProj;
    wDtProj;
    bDtProj;
    A_log;
    D_vec;
    wOutProj;
    bOutProj;
    normWeight;
    /** GPU storage buffers keyed by weight name. */
    gpuWeights;
    /** Compute pipelines keyed by short kernel name. */
    pipelines;
    _wslaMode = false;
    /**
     * Mamba-1 mixer block (S6 selective scan) whose forward pass runs as a
     * sequence of WebGPU compute dispatches.
     *
     * @param {GPUDevice} device - WebGPU device used for buffers, pipelines and submissions.
     * @param {object} config - Block configuration. `dModel` is required; `dState` (16),
     *   `dConv` (4), `expand` (2), `biasConv` (true) and `dtRank` (ceil(dModel / 16))
     *   fall back to these defaults when missing, `null` or `undefined`. Extra keys
     *   are preserved on `this.config`.
     */
    constructor(device, config) {
        this.device = device;
        // Drop nullish entries so e.g. `{ dState: undefined }` cannot overwrite a
        // default through the spread below (which would produce NaN buffer sizes).
        const provided = Object.fromEntries(Object.entries(config).filter(([, value]) => value != null));
        this.config = {
            dState: 16,
            dConv: 4,
            expand: 2,
            biasConv: true,
            dtRank: Math.ceil(config.dModel / 16),
            ...provided,
        };
        const { dModel, expand, dtRank } = this.config;
        this.dInner = expand * dModel;
        this.dtRank = dtRank;
        this.wInProj = new Float32Array(0);
        this.bInProj = new Float32Array(0);
        this.wConv = new Float32Array(0);
        this.bConv = new Float32Array(0);
        this.wXProj = new Float32Array(0);
        this.bXProj = new Float32Array(0);
        this.wDtProj = new Float32Array(0);
        this.bDtProj = new Float32Array(0);
        this.A_log = new Float32Array(0);
        this.D_vec = new Float32Array(0);
        this.wOutProj = new Float32Array(0);
        this.bOutProj = new Float32Array(0);
        this.normWeight = new Float32Array(0);
        this.gpuWeights = {};
        this.pipelines = {};
        this._initWeights();
        this._buildPipelines();
    }
    /**
     * Initialise every host-side weight, then upload them to the GPU.
     * Gaussian draws happen in a fixed order (see `gaussianArray`), so the
     * statement order below must not change if reproducible seeding matters.
     */
    _initWeights() {
        const { dModel, dState, dConv } = this.config;
        const D = this.dInner;
        const N = dState;
        const K = dConv;
        const R = this.dtRank;
        const randn = (n, std = 0.02) => gaussianArray(n, std);
        const zeros = (n) => new Float32Array(n);
        const ones = (n) => new Float32Array(n).fill(1.0);
        this.wInProj = randn(2 * D * dModel);
        this.bInProj = zeros(2 * D);
        this.wConv = randn(D * K, 0.01);
        this.bConv = zeros(D);
        this.wXProj = randn((R + 2 * N) * D, 0.01);
        this.bXProj = zeros(R + 2 * N);
        this.wDtProj = randn(D * R, 0.02);
        this.bDtProj = zeros(D);
        // A_log[d * N + n] = log(n + 1): every channel starts with the same ladder log(1..N).
        this.A_log = new Float32Array(D * N);
        for (let d = 0; d < D; d++) {
            for (let n = 0; n < N; n++) {
                this.A_log[d * N + n] = Math.log(n + 1);
            }
        }
        this.D_vec = ones(D);
        this.wOutProj = randn(dModel * D, 0.02);
        this.bOutProj = zeros(dModel);
        this.normWeight = ones(dModel);
        this._uploadWeightsToGPU();
    }
    /** Create one GPU storage buffer per host weight array and store them in `gpuWeights`. */
    _uploadWeightsToGPU() {
        const d = this.device;
        const mk = (arr) => createStorageBuffer(d, arr, true);
        this.gpuWeights = {
            wInProj: mk(this.wInProj),
            bInProj: mk(this.bInProj),
            wConv: mk(this.wConv),
            bConv: mk(this.bConv),
            wXProj: mk(this.wXProj),
            bXProj: mk(this.bXProj),
            wDtProj: mk(this.wDtProj),
            bDtProj: mk(this.bDtProj),
            A_log: mk(this.A_log),
            D_vec: mk(this.D_vec),
            wOutProj: mk(this.wOutProj),
            bOutProj: mk(this.bOutProj),
            normWeight: mk(this.normWeight),
        };
    }
    /** Compile every compute pipeline used by `forward`. */
    _buildPipelines() {
        const d = this.device;
        this.pipelines = {
            linear: createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
            conv1d: createComputePipeline(d, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
            silu: createComputePipeline(d, ACTIVATIONS_WGSL, 'silu_forward'),
            rmsnorm: createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
            scan_fwd: createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_scan'),
            scan_reduce: createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_reduce'),
            elMul: createComputePipeline(d, MUL_SHADER, 'main'),
            elAdd: createComputePipeline(d, ADD_SHADER, 'main'),
        };
    }
    /**
     * Create a uniform buffer from `params`, bind it via `bindings(uniform)`,
     * dispatch `grid` workgroups of pipeline `name`, then destroy the uniform.
     * Destroying right after dispatch mirrors how this block already releases
     * intermediate storage buffers (e.g. `dtRaw`) once their kernel is dispatched.
     *
     * @param {string} name - Key into `this.pipelines`.
     * @param {ArrayBuffer} params - Uniform payload.
     * @param {(uniform: GPUBuffer) => GPUBuffer[]} bindings - Bind-group buffers, in binding order.
     * @param {number[]} grid - Workgroup counts `[x, y, z]`.
     */
    _dispatchWithUniform(name, params, bindings, grid) {
        const d = this.device;
        const pipeline = this.pipelines[name];
        const uniform = createUniformBuffer(d, params);
        try {
            const bg = createBindGroup(d, pipeline, bindings(uniform));
            dispatchKernel(d, pipeline, bg, grid);
        }
        finally {
            uniform.destroy();
        }
    }
    /**
     * Copy byte ranges of `src` to offset 0 of each destination in one submission.
     *
     * @param {GPUBuffer} src - Source buffer.
     * @param {Array<[number, GPUBuffer, number]>} regions - `[srcOffsetBytes, dst, sizeBytes]` triples.
     */
    _copyRegions(src, regions) {
        const enc = this.device.createCommandEncoder();
        for (const [srcOffset, dst, size] of regions) {
            enc.copyBufferToBuffer(src, srcOffset, dst, 0, size);
        }
        this.device.queue.submit([enc.finish()]);
    }
    /**
     * Run the block: RMSNorm → in_proj → split(x, z) → causal conv1d → SiLU →
     * x_proj → split(Δ, B, C) → dt_proj → selective scan → y · SiLU(z) →
     * out_proj → + residual.
     *
     * @param {GPUBuffer} xBuf - Input activations holding `batch * seqLen * dModel` f32 values.
     * @param {number} batch - Batch size.
     * @param {number} seqLen - Sequence length.
     * @returns {{output: GPUBuffer, cache: object}} `output` holds `batch * seqLen * dModel`
     *   f32 values. `cache` holds intermediate buffers (plus `normIn === xBuf`); they are
     *   not destroyed here, so the caller owns them.
     */
    forward(xBuf, batch, seqLen) {
        const d = this.device;
        const { dModel, dState, dConv } = this.config;
        const D = this.dInner;
        const N = dState;
        const B = batch;
        const L = seqLen;
        const M = B * L;
        const R = this.dtRank;
        const F32 = 4; // bytes per f32
        const w = this.gpuWeights;
        const cache = {};
        // 1. Pre-block RMSNorm. Uniform layout: u32 rows, u32 cols, f32 eps (padded to 16 bytes).
        const normOut = createEmptyStorageBuffer(d, M * dModel * F32, true);
        const normInv = createEmptyStorageBuffer(d, M * F32, true);
        cache.normInv = normInv;
        cache.normIn = xBuf;
        const normParams = new ArrayBuffer(16);
        new Uint32Array(normParams, 0, 2).set([M, dModel]);
        new Float32Array(normParams, 8, 1).set([1e-6]);
        this._dispatchWithUniform('rmsnorm', normParams, (u) => [u, xBuf, w['normWeight'], normOut, normInv], [cdiv(M, 64), 1, 1]);
        // 2. Input projection → x and z (2 * D outputs per token).
        const inProjOut = createEmptyStorageBuffer(d, M * 2 * D * F32, true);
        cache.normOut = normOut;
        this._dispatchWithUniform('linear', new Uint32Array([M, dModel, 2 * D]).buffer, (u) => [u, normOut, w['wInProj'], w['bInProj'], inProjOut], [cdiv(M, 16), cdiv(2 * D, 16), 1]);
        // 3. Split into x and z.
        // NOTE(review): these contiguous copies treat inProjOut as [all x | all z].
        // If LINEAR_FORWARD_WGSL writes a row-major (M, 2 * D) output, each row holds
        // x then z and this split would mix tokens — confirm against linear_projection.ts.
        const xConvIn = createEmptyStorageBuffer(d, M * D * F32, true);
        const zBuf = createEmptyStorageBuffer(d, M * D * F32, true);
        this._copyRegions(inProjOut, [
            [0, xConvIn, M * D * F32],
            [M * D * F32, zBuf, M * D * F32],
        ]);
        inProjOut.destroy();
        cache.zBuf = zBuf;
        cache.xConvIn = xConvIn;
        // 4. Causal conv1d on x.
        const convOut = createEmptyStorageBuffer(d, M * D * F32, true);
        cache.convOut = convOut;
        this._dispatchWithUniform('conv1d', new Uint32Array([L, D, dConv, B]).buffer, (u) => [u, xConvIn, w['wConv'], w['bConv'], convOut], [cdiv(L, 16), cdiv(D, 16), B]);
        // 5. SiLU activation.
        const siluOut = createEmptyStorageBuffer(d, M * D * F32, true);
        cache.siluOut = siluOut;
        this._dispatchWithUniform('silu', new Uint32Array([M * D]).buffer, (u) => [u, convOut, siluOut], [cdiv(M * D, 256), 1, 1]);
        // 6. x_proj → Δ (dtRaw), B, C.
        const xProjDim = R + 2 * N;
        const xProjOut = createEmptyStorageBuffer(d, M * xProjDim * F32, true);
        this._dispatchWithUniform('linear', new Uint32Array([M, D, xProjDim]).buffer, (u) => [u, siluOut, w['wXProj'], w['bXProj'], xProjOut], [cdiv(M, 16), cdiv(xProjDim, 16), 1]);
        // NOTE(review): same layout assumption as step 3 — contiguous [Δ | B | C] blocks.
        const dtRaw = createEmptyStorageBuffer(d, M * R * F32, true);
        const B_raw = createEmptyStorageBuffer(d, M * N * F32, true);
        const C_raw = createEmptyStorageBuffer(d, M * N * F32, true);
        this._copyRegions(xProjOut, [
            [0, dtRaw, M * R * F32],
            [M * R * F32, B_raw, M * N * F32],
            [M * (R + N) * F32, C_raw, M * N * F32],
        ]);
        xProjOut.destroy();
        cache.B_raw = B_raw;
        cache.C_raw = C_raw;
        // 7. dt_proj: expand Δ from rank R to D channels.
        const deltaFull = createEmptyStorageBuffer(d, M * D * F32, true);
        cache.deltaFull = deltaFull;
        this._dispatchWithUniform('linear', new Uint32Array([M, R, D]).buffer, (u) => [u, dtRaw, w['wDtProj'], w['bDtProj'], deltaFull], [cdiv(M, 16), cdiv(D, 16), 1]);
        dtRaw.destroy();
        // 8. Selective scan (S6): two passes sharing one uniform and the same bindings.
        const scanY = createEmptyStorageBuffer(d, M * D * F32, true);
        const hCache = createEmptyStorageBuffer(d, 2 * M * D * N * F32, true);
        cache.hCache = hCache;
        {
            const scanParams = createUniformBuffer(d, new Uint32Array([L, N, D, B]).buffer);
            try {
                const scanBindings = [scanParams, siluOut, deltaFull, w['A_log'], B_raw, C_raw, w['D_vec'], scanY, hCache];
                const fwdBg = createBindGroup(d, this.pipelines['scan_fwd'], scanBindings);
                dispatchKernel(d, this.pipelines['scan_fwd'], fwdBg, [cdiv(D, 8), cdiv(N, 8), B]);
                const reduceBg = createBindGroup(d, this.pipelines['scan_reduce'], scanBindings);
                dispatchKernel(d, this.pipelines['scan_reduce'], reduceBg, [cdiv(L, 64), D, B]);
            }
            finally {
                scanParams.destroy();
            }
        }
        // 9. Gate: y ⊗ SiLU(z).
        const siluZ = createEmptyStorageBuffer(d, M * D * F32, true);
        const gatedOut = createEmptyStorageBuffer(d, M * D * F32, true);
        this._dispatchWithUniform('silu', new Uint32Array([M * D]).buffer, (u) => [u, zBuf, siluZ], [cdiv(M * D, 256), 1, 1]);
        // The element-wise shaders take their length uniform at binding 3.
        this._dispatchWithUniform('elMul', new Uint32Array([M * D]).buffer, (u) => [scanY, siluZ, gatedOut, u], [cdiv(M * D, 256), 1, 1]);
        siluZ.destroy();
        scanY.destroy();
        // 10. Output projection back to dModel.
        const outProjOut = createEmptyStorageBuffer(d, M * dModel * F32, true);
        this._dispatchWithUniform('linear', new Uint32Array([M, D, dModel]).buffer, (u) => [u, gatedOut, w['wOutProj'], w['bOutProj'], outProjOut], [cdiv(M, 16), cdiv(dModel, 16), 1]);
        gatedOut.destroy();
        // 11. Residual add with the block input.
        const output = createEmptyStorageBuffer(d, M * dModel * F32, true);
        this._dispatchWithUniform('elAdd', new Uint32Array([M * dModel]).buffer, (u) => [outProjOut, xBuf, output, u], [cdiv(M * dModel, 256), 1, 1]);
        outProjOut.destroy();
        return { output, cache };
    }
    /**
     * All weight buffers with their element counts.
     * @returns {{buf: GPUBuffer, numel: number, name: string}[]}
     */
    parameters() {
        const { dModel, dState, dConv } = this.config;
        const D = this.dInner;
        const N = dState;
        const K = dConv;
        const R = this.dtRank;
        return [
            { buf: this.gpuWeights['wInProj'], numel: 2 * D * dModel, name: 'wInProj' },
            { buf: this.gpuWeights['bInProj'], numel: 2 * D, name: 'bInProj' },
            { buf: this.gpuWeights['wConv'], numel: D * K, name: 'wConv' },
            { buf: this.gpuWeights['bConv'], numel: D, name: 'bConv' },
            { buf: this.gpuWeights['wXProj'], numel: (R + 2 * N) * D, name: 'wXProj' },
            { buf: this.gpuWeights['bXProj'], numel: R + 2 * N, name: 'bXProj' },
            { buf: this.gpuWeights['wDtProj'], numel: D * R, name: 'wDtProj' },
            { buf: this.gpuWeights['bDtProj'], numel: D, name: 'bDtProj' },
            { buf: this.gpuWeights['A_log'], numel: D * N, name: 'A_log' },
            { buf: this.gpuWeights['D_vec'], numel: D, name: 'D_vec' },
            { buf: this.gpuWeights['wOutProj'], numel: dModel * D, name: 'wOutProj' },
            { buf: this.gpuWeights['bOutProj'], numel: dModel, name: 'bOutProj' },
            { buf: this.gpuWeights['normWeight'], numel: dModel, name: 'normWeight' },
        ];
    }
    /**
     * Parameters the optimiser should update: every parameter normally, or only
     * the x_proj weight and bias while WSLA mode is enabled. Derived from
     * `parameters()` so element counts never diverge between the two lists.
     */
    getTrainableParams() {
        const all = this.parameters();
        if (!this._wslaMode) {
            return all;
        }
        return all.filter((p) => p.name === 'wXProj' || p.name === 'bXProj');
    }
    /** Toggle WSLA mode (restricts `getTrainableParams` to x_proj weight and bias). */
    setWSLAMode(enabled) {
        this._wslaMode = enabled;
    }
    /** Destroy every GPU weight buffer and clear `gpuWeights`; safe to call twice. */
    destroy() {
        for (const buf of Object.values(this.gpuWeights)) {
            buf.destroy();
        }
        this.gpuWeights = {};
    }
}
|
|
331
|
+
/**
 * Backward-compatible name for {@link Mamba1Block}.
 * @deprecated Use `Mamba1Block`; this alias is kept until mambacode.js 3.0.0.
 */
export const MambaBlock = Mamba1Block;
|
|
333
|
+
//# sourceMappingURL=mamba1_block.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"mamba1_block.js","sourceRoot":"","sources":["../../src/model/mamba1_block.ts"],"names":[],"mappings":"AAAA;;;;;GAKG;AAEH,OAAO,EACH,qBAAqB,EACrB,eAAe,EACf,mBAAmB,EACnB,wBAAwB,EACxB,mBAAmB,EACnB,cAAc,EACd,IAAI,GACP,MAAM,uBAAuB,CAAC;AAE/B,OAAO,EAAE,2BAA2B,EAAE,MAAO,8BAA8B,CAAC;AAC5E,OAAO,EAAE,aAAa,EAAE,MAAM,iBAAiB,CAAC;AAChD,OAAO,EAAE,mBAAmB,EAAE,MAAe,sBAAsB,CAAC;AACpE,OAAO,EAAE,mBAAmB,EAAE,MAAe,iCAAiC,CAAC;AAC/E,OAAO,EAAE,gBAAgB,EAAE,MAAkB,2BAA2B,CAAC;AAmCzE,gFAAgF;AAEhF,MAAM,UAAU,GAAG,UAAU,CAAA;;;;;;;;;;CAU5B,CAAC;AAEF,MAAM,UAAU,GAAG,UAAU,CAAA;;;;;;;;;;CAU5B,CAAC;AAEF,iFAAiF;AAEjF,MAAM,OAAO,WAAW;IACX,SAAS,GAAG,QAAiB,CAAC;IAEvC,MAAM,CAAc;IACpB,MAAM,CAAgC;IACtC,MAAM,CAAW;IACjB,MAAM,CAAW;IAEjB,OAAO,CAAkB;IACzB,OAAO,CAAkB;IACzB,KAAK,CAAoB;IACzB,KAAK,CAAoB;IACzB,MAAM,CAAmB;IACzB,MAAM,CAAmB;IACzB,OAAO,CAAkB;IACzB,OAAO,CAAkB;IACzB,KAAK,CAAoB;IACzB,KAAK,CAAoB;IACzB,QAAQ,CAAiB;IACzB,QAAQ,CAAiB;IACzB,UAAU,CAAe;IAEzB,UAAU,CAA6B;IACvC,SAAS,CAAuC;IAExC,SAAS,GAAG,KAAK,CAAC;IAE1B,YAAY,MAAiB,EAAE,MAAyB;QACpD,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;QACrB,IAAI,CAAC,MAAM,GAAG;YACV,MAAM,EAAI,EAAE;YACZ,KAAK,EAAK,CAAC;YACX,MAAM,EAAI,CAAC;YACX,QAAQ,EAAE,IAAI;YACd,MAAM,EAAI,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,MAAM,GAAG,EAAE,CAAC;YACvC,GAAG,MAAM;SACmB,CAAC;QAEjC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QACvC,IAAI,CAAC,MAAM,GAAG,MAAM,GAAG,MAAM,CAAC;QAC9B,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC,MAAM,IAAI,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,EAAE,CAAC,CAAC;QAEtD,IAAI,CAAC,OAAO,GAAM,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,OAAO,GAAM,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,GAAQ,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,GAAQ,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,MAAM,GAAO,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,MAAM,GAAO,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,OAAO,GAAM,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,OAAO,GAAM,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,GAAQ,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,KAAK,GAAQ,IAAI,YAAY,CAAC,C
AAC,CAAC,CAAC;QACtC,IAAI,CAAC,QAAQ,GAAK,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,QAAQ,GAAK,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,UAAU,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,UAAU,GAAG,EAAE,CAAC;QACrB,IAAI,CAAC,SAAS,GAAI,EAAE,CAAC;QAErB,IAAI,CAAC,YAAY,EAAE,CAAC;QACpB,IAAI,CAAC,eAAe,EAAE,CAAC;IAC3B,CAAC;IAEO,YAAY;QAChB,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QAC9C,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,MAAM,CAAC,GAAG,MAAM,CAAC;QACjB,MAAM,CAAC,GAAG,KAAK,CAAC;QAChB,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QAEtB,MAAM,KAAK,GAAG,CAAC,CAAS,EAAE,GAAG,GAAG,IAAI,EAAgB,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAE7E,MAAM,KAAK,GAAG,CAAC,CAAS,EAAgB,EAAE,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC;QAC/D,MAAM,IAAI,GAAI,CAAC,CAAS,EAAgB,EAAE,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAEzE,IAAI,CAAC,OAAO,GAAI,KAAK,CAAC,CAAC,GAAG,CAAC,GAAG,MAAM,CAAC,CAAC;QACtC,IAAI,CAAC,OAAO,GAAI,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QAC7B,IAAI,CAAC,KAAK,GAAM,KAAK,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACnC,IAAI,CAAC,KAAK,GAAM,KAAK,CAAC,CAAC,CAAC,CAAC;QACzB,IAAI,CAAC,MAAM,GAAK,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7C,IAAI,CAAC,MAAM,GAAK,KAAK,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;QACjC,IAAI,CAAC,OAAO,GAAI,KAAK,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACnC,IAAI,CAAC,OAAO,GAAI,KAAK,CAAC,CAAC,CAAC,CAAC;QAEzB,IAAI,CAAC,KAAK,GAAG,IAAI,YAAY,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QACrC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YACzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;gBACzB,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;YAC5C,CAAC;QACL,CAAC;QAED,IAAI,CAAC,KAAK,GAAO,IAAI,CAAC,CAAC,CAAC,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAI,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACzC,IAAI,CAAC,QAAQ,GAAI,KAAK,CAAC,MAAM,CAAC,CAAC;QAC/B,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC;QAE/B,IAAI,CAAC,mBAAmB,EAAE,CAAC;IAC/B,CAAC;IAEO,m
BAAmB;QACvB,MAAM,CAAC,GAAI,IAAI,CAAC,MAAM,CAAC;QACvB,MAAM,EAAE,GAAG,CAAC,GAAiB,EAAa,EAAE,CAAC,mBAAmB,CAAC,CAAC,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC;QAE/E,IAAI,CAAC,UAAU,GAAG;YACd,OAAO,EAAK,EAAE,CAAC,IAAI,CAAC,OAAO,CAAC;YAC5B,OAAO,EAAK,EAAE,CAAC,IAAI,CAAC,OAAO,CAAC;YAC5B,KAAK,EAAO,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC;YAC1B,KAAK,EAAO,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC;YAC1B,MAAM,EAAM,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC;YAC3B,MAAM,EAAM,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC;YAC3B,OAAO,EAAK,EAAE,CAAC,IAAI,CAAC,OAAO,CAAC;YAC5B,OAAO,EAAK,EAAE,CAAC,IAAI,CAAC,OAAO,CAAC;YAC5B,KAAK,EAAO,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC;YAC1B,KAAK,EAAO,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC;YAC1B,QAAQ,EAAI,EAAE,CAAC,IAAI,CAAC,QAAQ,CAAC;YAC7B,QAAQ,EAAI,EAAE,CAAC,IAAI,CAAC,QAAQ,CAAC;YAC7B,UAAU,EAAE,EAAE,CAAC,IAAI,CAAC,UAAU,CAAC;SAClC,CAAC;IACN,CAAC;IAEO,eAAe;QACnB,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,IAAI,CAAC,SAAS,GAAG;YACb,MAAM,EAAQ,qBAAqB,CAAC,CAAC,EAAE,mBAAmB,EAAW,gBAAgB,CAAC;YACtF,MAAM,EAAQ,qBAAqB,CAAC,CAAC,EAAE,mBAAmB,EAAW,gBAAgB,CAAC;YACtF,IAAI,EAAU,qBAAqB,CAAC,CAAC,EAAE,gBAAgB,EAAc,cAAc,CAAC;YACpF,OAAO,EAAO,qBAAqB,CAAC,CAAC,EAAE,gBAAgB,EAAc,iBAAiB,CAAC;YACvF,QAAQ,EAAM,qBAAqB,CAAC,CAAC,EAAE,2BAA2B,EAAG,cAAc,CAAC;YACpF,WAAW,EAAG,qBAAqB,CAAC,CAAC,EAAE,2BAA2B,EAAG,gBAAgB,CAAC;YACtF,KAAK,EAAS,qBAAqB,CAAC,CAAC,EAAE,UAAU,EAAE,MAAM,CAAC;YAC1D,KAAK,EAAS,qBAAqB,CAAC,CAAC,EAAE,UAAU,EAAE,MAAM,CAAC;SAC7D,CAAC;IACN,CAAC;IAED,OAAO,CAAC,IAAe,EAAE,KAAa,EAAE,MAAc;QAClD,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QAC9C,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,MAAM,CAAC,GAAG,MAAM,CAAC;QACjB,MAAM,CAAC,GAAG,KAAK,CAAC;QAChB,MAAM,CAAC,GAAG,MAAM,CAAC;QACjB,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;QAChB,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QAEtB,MAAM,KAAK,GAAG,EAAgB,CAAC;QAE/B,uBAAuB;QACvB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAClE,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACzD,KAAK,CAAC,OAAO,GAAG,OAAO,CAAC;QACxB,KAAK,CAAC,MAA
M,GAAI,IAAI,CAAC;QACrB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,EAAE,CAAC,CAAC;YACnC,IAAI,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC;YAC/C,IAAI,YAAY,CAAC,MAAM,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC;YAC3C,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC5C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,SAAS,CAAE,EACpD,CAAC,IAAI,EAAE,IAAI,EAAE,IAAI,CAAC,UAAU,CAAC,YAAY,CAAE,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;YACpE,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,SAAS,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;QAED,gCAAgC;QAChC,MAAM,SAAS,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACnE,KAAK,CAAC,OAAO,GAAG,OAAO,CAAC;QACxB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YAC1D,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAE,SAAS,CAAC,CAAC,CAAC;YAC1F,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC,GAAG,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACxF,CAAC;QAED,wBAAwB;QACxB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7D,MAAM,IAAI,GAAM,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7D,CAAC;YACG,MAAM,GAAG,GAAG,CAAC,CAAC,oBAAoB,EAAE,CAAC;YACrC,GAAG,CAAC,kBAAkB,CAAC,SAAS,EAAE,CAAC,EAAU,OAAO,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YACpE,GAAG,CAAC,kBAAkB,CAAC,SAAS,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,EAAK,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YACpE,CAAC,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;QACnC,CAAC;QACD,SAAS,CAAC,OAAO,EAAE,CAAC;QACpB,KAAK,CAAC,IAAI,GAAM,IAAI,CAAC;QACrB,KAAK,CAAC,OAAO,GAAG,OAAO,CAA
C;QAExB,wBAAwB;QACxB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7D,KAAK,CAAC,OAAO,GAAG,OAAO,CAAC;QACxB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YACxD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,OAAO,CAAC,CAAC,CAAC;YACpF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpF,CAAC;QAED,qBAAqB;QACrB,MAAM,OAAO,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC7D,KAAK,CAAC,OAAO,GAAG,OAAO,CAAC;QACxB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YAC/C,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EACjD,CAAC,IAAI,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;YAC9B,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC7E,CAAC;QAED,8BAA8B;QAC9B,MAAM,QAAQ,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACxE,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YACzD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,OAAO,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAE,QAAQ,CAAC,CAAC,CAAC;YACvF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5F,CAAC;QAED,MAAM,KAAK,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC
,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC3D,MAAM,KAAK,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC/D,MAAM,KAAK,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC/D,CAAC;YACG,MAAM,GAAG,GAAG,CAAC,CAAC,oBAAoB,EAAE,CAAC;YACrC,GAAG,CAAC,kBAAkB,CAAC,QAAQ,EAAE,CAAC,EAAgB,KAAK,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YACvE,GAAG,CAAC,kBAAkB,CAAC,QAAQ,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAQ,KAAK,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YAC3E,GAAG,CAAC,kBAAkB,CAAC,QAAQ,EAAE,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;YAC3E,CAAC,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,GAAG,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;QACnC,CAAC;QACD,QAAQ,CAAC,OAAO,EAAE,CAAC;QACnB,KAAK,CAAC,KAAK,GAAG,KAAK,CAAC;QACpB,KAAK,CAAC,KAAK,GAAG,KAAK,CAAC;QAEpB,mCAAmC;QACnC,MAAM,SAAS,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC/D,KAAK,CAAC,SAAS,GAAG,SAAS,CAAC;QAC5B,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YACjD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,KAAK,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAE,SAAS,CAAC,CAAC,CAAC;YACxF,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpF,CAAC;QACD,KAAK,CAAC,OAAO,EAAE,CAAC;QAEhB,yBAAyB;QACzB,MAAM,KAAK,GAAI,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAChE,MAAM,MAAM,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACxE,KAAK,CAAC,MAAM,GAAG,MAAM,CAAC;QACtB,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YACpD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAE9C,M
AAM,GAAG,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAE,EACtD,CAAC,IAAI,EAAE,OAAO,EAAE,SAAS,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,KAAK;gBACjE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC;YAChD,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAE,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YAEjF,MAAM,GAAG,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,aAAa,CAAE,EACzD,CAAC,IAAI,EAAE,OAAO,EAAE,SAAS,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,KAAK;gBACjE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC;YAChD,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,aAAa,CAAE,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAChF,CAAC;QAED,uBAAuB;QACvB,MAAM,KAAK,GAAM,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC9D,MAAM,QAAQ,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QAC9D,CAAC;YACG,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;YACrE,MAAM,GAAG,GAAI,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EACnD,CAAC,IAAI,EAAE,IAAI,EAAE,KAAK,CAAC,CAAC,CAAC;YACzB,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAE,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YAE1E,MAAM,KAAK,GAAG,mBAAmB,CAAC,CAAC,EAAE,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;YACtE,MAAM,KAAK,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EACrD,CAAC,KAAK,EAAE,KAAK,EAAE,QAAQ,EAAE,KAAK,CAAC,CAAC,CAAC;YACrC,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EAAE,KAAK,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjF,CAAC;QACD,KAAK,CAAC,OAAO,EAAE,CAAC;QAChB,KAAK,CAAC,OAAO,EAAE,CAAC;QAEhB,wBAAwB;QACxB,MAAM,UAAU,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACrE,CAAC;YACG,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,CA
AC,CAAC,CAAC,MAAM,CAAC;YACtD,MAAM,IAAI,GAAK,mBAAmB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;YAC9C,MAAM,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EACnD,CAAC,IAAI,EAAE,QAAQ,EAAE,IAAI,CAAC,UAAU,CAAC,UAAU,CAAE,EAAE,IAAI,CAAC,UAAU,CAAC,UAAU,CAAE,EAAE,UAAU,CAAC,CAAC,CAAC;YAC9F,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACzF,CAAC;QACD,QAAQ,CAAC,OAAO,EAAE,CAAC;QAEnB,mBAAmB;QACnB,MAAM,MAAM,GAAG,wBAAwB,CAAC,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;QACjE,CAAC;YACG,MAAM,IAAI,GAAG,mBAAmB,CAAC,CAAC,EAAE,IAAI,WAAW,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;YAC1E,MAAM,EAAE,GAAK,eAAe,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EACpD,CAAC,UAAU,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,CAAC,CAAC,CAAC;YACtC,cAAc,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,OAAO,CAAE,EAAE,EAAE,EAAE,CAAC,IAAI,CAAC,CAAC,GAAG,MAAM,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnF,CAAC;QACD,UAAU,CAAC,OAAO,EAAE,CAAC;QAErB,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,CAAC;IAC7B,CAAC;IAED,UAAU;QACN,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,GAAG,IAAI,CAAC,MAAM,CAAC;QAC9C,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QACtB,MAAM,CAAC,GAAG,MAAM,CAAC;QACjB,MAAM,CAAC,GAAG,KAAK,CAAC;QAChB,MAAM,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC;QAEtB,OAAO;YACH,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAK,KAAK,EAAE,CAAC,GAAG,CAAC,GAAG,MAAM,EAAI,IAAI,EAAE,SAAS,EAAK;YACpF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAK,KAAK,EAAE,CAAC,GAAG,CAAC,EAAa,IAAI,EAAE,SAAS,EAAK;YACpF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAO,KAAK,EAAE,CAAC,GAAG,CAAC,EAAa,IAAI,EAAE,OAAO,EAAO;YACpF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAO,KAAK,EAAE,CAAC,EAAiB,IAAI,EAAE,OAAO,EAAO;YACpF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAM,KAAK,EAAE,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,IAAI,EAAE,QAAQ,EAAK;YAClF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAM,KAAK,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,EAAQ,IAAI,EAAE,QAAQ,EAAK;YAClF,EAAE,GAA
G,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAK,KAAK,EAAE,CAAC,GAAG,CAAC,EAAa,IAAI,EAAE,SAAS,EAAI;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,SAAS,CAAE,EAAK,KAAK,EAAE,CAAC,EAAiB,IAAI,EAAE,SAAS,EAAI;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAO,KAAK,EAAE,CAAC,GAAG,CAAC,EAAa,IAAI,EAAE,OAAO,EAAM;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,OAAO,CAAE,EAAO,KAAK,EAAE,CAAC,EAAiB,IAAI,EAAE,OAAO,EAAM;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,UAAU,CAAE,EAAI,KAAK,EAAE,MAAM,GAAG,CAAC,EAAQ,IAAI,EAAE,UAAU,EAAG;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,UAAU,CAAE,EAAI,KAAK,EAAE,MAAM,EAAY,IAAI,EAAE,UAAU,EAAG;YACnF,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,YAAY,CAAE,EAAE,KAAK,EAAE,MAAM,EAAY,IAAI,EAAE,YAAY,EAAC;SACtF,CAAC;IACN,CAAC;IAED,kBAAkB;QACd,IAAI,IAAI,CAAC,SAAS,EAAE,CAAC;YACjB,OAAO;gBACH,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAE,KAAK,EAAE,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE;gBAC9E,EAAE,GAAG,EAAE,IAAI,CAAC,UAAU,CAAC,QAAQ,CAAE,EAAE,KAAK,EAAE,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,IAAI,EAAE,QAAQ,EAAE;aACjF,CAAC;QACN,CAAC;QACD,OAAO,IAAI,CAAC,UAAU,EAAE,CAAC;IAC7B,CAAC;IAED,WAAW,CAAC,OAAgB;QACxB,IAAI,CAAC,SAAS,GAAG,OAAO,CAAC;IAC7B,CAAC;IAED,OAAO;QACH,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE,CAAC;YAC/C,GAAG,CAAC,OAAO,EAAE,CAAC;QAClB,CAAC;QACD,IAAI,CAAC,UAAU,GAAG,EAAE,CAAC;IACzB,CAAC;CACJ;AAED,mDAAmD;AACnD,OAAO,EAAE,WAAW,IAAI,UAAU,EAAE,CAAC"}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* mamba2_block.ts – Mamba-2 Mixer Block (Structured State Space Duality).
|
|
3
|
+
*
|
|
4
|
+
* Key differences from Mamba-1:
|
|
5
|
+
* - Multi-head SSM with scalar A per head
|
|
6
|
+
* - Single fused in_proj (no separate dt_proj expansion)
|
|
7
|
+
* - SSD (chunked) scan replaces S6 selective scan
|
|
8
|
+
* - Inner RMSNorm on scan output instead of SiLU gate
|
|
9
|
+
* - No separate z gate
|
|
10
|
+
*
|
|
11
|
+
* Implements SequenceLayer.
|
|
12
|
+
*/
|
|
13
|
+
import type { SequenceLayer, LayerForwardResult, LayerParam } from './sequence_layer.js';
|
|
14
|
+
/**
 * Constructor options for Mamba2Block.
 *
 * Although every field is declared required here, the JS constructor merges the
 * given config over defaults (dState=16, dConv=4, expand=2, nGroups=1,
 * chunkLen=256), so those five may be omitted at runtime.
 */
export interface Mamba2BlockConfig {
    /** Width of the block's input/output (and of the residual added back to the output). */
    dModel: number;
    /** SSM state size N (per group). */
    dState: number;
    /** Causal conv1d kernel size (K). */
    dConv: number;
    /** Inner width multiplier: dInner = expand * dModel. */
    expand: number;
    /** Number of SSM heads; must evenly divide dInner (dHead = dInner / nHeads). */
    nHeads: number;
    /** Number of B/C projection groups (each contributes dState channels to B and to C). */
    nGroups: number;
    /** Sequence chunk length used by the chunked SSD scan. */
    chunkLen: number;
}
|
|
23
|
+
/**
 * Cache returned alongside the output of Mamba2Block.forward().
 */
export interface Mamba2Cache {
    /**
     * State-carry buffer written by the SSD scan kernel, allocated as
     * (numChunks + 1) * batch * nHeads * dState * dHead f32 values
     * (documented in the JS as [numChunks+1, B, H, N, dHead]).
     * The block never destroys it; the caller owns it.
     */
    stateCarry: GPUBuffer;
}
|
|
26
|
+
/**
 * Mamba-2 mixer block exposed as a SequenceLayer:
 * pre-block RMSNorm -> fused in_proj -> causal conv1d -> SSD scan ->
 * inner RMSNorm -> out_proj -> residual add.
 */
export declare class Mamba2Block implements SequenceLayer {
    /** Layer discriminator; always "mamba2". */
    readonly layerType: "mamba2";
    /** WebGPU device used for every buffer and pipeline of this block. */
    device: GPUDevice;
    /** Resolved configuration: the user config merged over the constructor defaults. */
    config: Required<Mamba2BlockConfig>;
    /** Inner width: expand * dModel. */
    dInner: number;
    /** Per-head width: dInner / nHeads. */
    dHead: number;
    /** Weight buffers keyed by name (wInProj, wConv, bConv, A_log, dt_bias, D_vec, wOutProj, normWeight, preNormWeight). */
    gpuWeights: Record<string, GPUBuffer>;
    /** Compute pipelines keyed by role (linear, conv1d, rmsnorm, ssd_fwd, elAdd). */
    pipelines: Record<string, GPUComputePipeline>;
    /** When true, getTrainableParams() returns only the wInProj B/C entry. */
    private _wslaMode;
    /**
     * Allocates weights and builds pipelines.
     * @throws Error when dInner (expand * dModel) is not divisible by nHeads.
     */
    constructor(device: GPUDevice, config: Mamba2BlockConfig);
    private _initWeights;
    private _buildPipelines;
    /**
     * Runs the block on xBuf (batch * seqLen * dModel f32 values) and returns a
     * newly allocated output buffer of the same size plus the SSD state carry.
     * xBuf itself is not destroyed.
     */
    forward(xBuf: GPUBuffer, batch: number, seqLen: number): LayerForwardResult;
    /** Every weight buffer with its element count and name. */
    parameters(): LayerParam[];
    /** parameters(), or only the wInProj B/C entry while WSLA mode is enabled. */
    getTrainableParams(): LayerParam[];
    /** Enables or disables WSLA mode (see getTrainableParams). */
    setWSLAMode(enabled: boolean): void;
    /** Destroys all weight buffers and resets gpuWeights to an empty record. */
    destroy(): void;
}
|
|
44
|
+
//# sourceMappingURL=mamba2_block.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"mamba2_block.d.ts","sourceRoot":"","sources":["../../src/model/mamba2_block.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;GAWG;AAkBH,OAAO,KAAK,EAAE,aAAa,EAAE,kBAAkB,EAAE,UAAU,EAAE,MAAM,qBAAqB,CAAC;AAEzF,MAAM,WAAW,iBAAiB;IAC9B,MAAM,EAAK,MAAM,CAAC;IAClB,MAAM,EAAK,MAAM,CAAC;IAClB,KAAK,EAAM,MAAM,CAAC;IAClB,MAAM,EAAK,MAAM,CAAC;IAClB,MAAM,EAAK,MAAM,CAAC;IAClB,OAAO,EAAI,MAAM,CAAC;IAClB,QAAQ,EAAG,MAAM,CAAC;CACrB;AAED,MAAM,WAAW,WAAW;IACxB,UAAU,EAAG,SAAS,CAAC;CAC1B;AAcD,qBAAa,WAAY,YAAW,aAAa;IAC7C,QAAQ,CAAC,SAAS,EAAG,QAAQ,CAAU;IAEvC,MAAM,EAAG,SAAS,CAAC;IACnB,MAAM,EAAG,QAAQ,CAAC,iBAAiB,CAAC,CAAC;IACrC,MAAM,EAAG,MAAM,CAAC;IAChB,KAAK,EAAI,MAAM,CAAC;IAEhB,UAAU,EAAG,MAAM,CAAC,MAAM,EAAE,SAAS,CAAC,CAAC;IACvC,SAAS,EAAI,MAAM,CAAC,MAAM,EAAE,kBAAkB,CAAC,CAAC;IAEhD,OAAO,CAAC,SAAS,CAAS;gBAEd,MAAM,EAAE,SAAS,EAAE,MAAM,EAAE,iBAAiB;IAwBxD,OAAO,CAAC,YAAY;IA8BpB,OAAO,CAAC,eAAe;IAWvB,OAAO,CAAC,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,kBAAkB;IA6I3E,UAAU,IAAI,UAAU,EAAE;IAsB1B,kBAAkB,IAAI,UAAU,EAAE;IAYlC,WAAW,CAAC,OAAO,EAAE,OAAO,GAAG,IAAI;IAInC,OAAO,IAAI,IAAI;CAIlB"}
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* mamba2_block.ts – Mamba-2 Mixer Block (Structured State Space Duality).
|
|
3
|
+
*
|
|
4
|
+
* Key differences from Mamba-1:
|
|
5
|
+
* - Multi-head SSM with scalar A per head
|
|
6
|
+
* - Single fused in_proj (no separate dt_proj expansion)
|
|
7
|
+
* - SSD (chunked) scan replaces S6 selective scan
|
|
8
|
+
* - Inner RMSNorm on scan output instead of SiLU gate
|
|
9
|
+
* - No separate z gate
|
|
10
|
+
*
|
|
11
|
+
* Implements SequenceLayer.
|
|
12
|
+
*/
|
|
13
|
+
import { createComputePipeline, createBindGroup, createStorageBuffer, createEmptyStorageBuffer, createUniformBuffer, dispatchKernel, cdiv, } from '../utils/gpu_utils.js';
|
|
14
|
+
import { SSD_FORWARD_WGSL } from '../kernels/ssd.js';
|
|
15
|
+
import { gaussianArray } from '../utils/rng.js';
|
|
16
|
+
import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
|
|
17
|
+
import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
|
|
18
|
+
import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
|
|
19
|
+
/**
 * Element-wise add kernel: c[i] = a[i] + b[i] for every i < n.
 *
 * Bindings: 0 = a (read), 1 = b (read), 2 = c (read_write), 3 = n (uniform u32).
 * Workgroup size is 256. The flat index also folds in gid.y, so a caller can
 * dispatch a 2-D grid when cdiv(n, 256) exceeds the per-dimension workgroup
 * limit (65535 by default). A plain 1-D dispatch (gid.y == 0) indexes exactly
 * as gid.x, so existing 1-D callers are unaffected.
 */
const ADD_SHADER = /* wgsl */ `
@group(0) @binding(0) var<storage, read>       a : array<f32>;
@group(0) @binding(1) var<storage, read>       b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@group(0) @binding(3) var<uniform>             n : u32;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
        @builtin(num_workgroups) nwg: vec3<u32>) {
  let i = gid.x + gid.y * nwg.x * 256u;
  if (i < n) { c[i] = a[i] + b[i]; }
}
`;
|
|
30
|
+
export class Mamba2Block {
    /** Layer discriminator used by SequenceLayer consumers. */
    layerType = 'mamba2';
    device;
    config;
    dInner;
    dHead;
    gpuWeights;
    pipelines;
    _wslaMode = false;

    /**
     * Builds a Mamba-2 mixer block on `device`.
     *
     * @param {GPUDevice} device - WebGPU device that owns all buffers/pipelines.
     * @param {object} config - Mamba2BlockConfig. dState, dConv, expand, nGroups
     *   and chunkLen default to 16, 4, 2, 1 and 256 when omitted.
     * @throws {Error} if nHeads is not a positive integer, or if
     *   dInner (= expand * dModel) is not divisible by nHeads.
     */
    constructor(device, config) {
        this.device = device;
        this.config = {
            ...{ dState: 16, dConv: 4, expand: 2, nGroups: 1, chunkLen: 256 },
            ...config,
        };
        const { dModel, expand, nHeads } = this.config;
        // Validate before deriving shapes: a negative nHeads would otherwise
        // pass the divisibility check and produce a negative head width.
        if (!Number.isInteger(nHeads) || nHeads <= 0) {
            throw new Error(`Mamba2Block: nHeads must be a positive integer (got ${nHeads}).`);
        }
        this.dInner = expand * dModel;
        if (this.dInner % nHeads !== 0) {
            throw new Error(`Mamba2Block: dInner (${this.dInner}) must be divisible by nHeads (${nHeads}).`);
        }
        this.dHead = this.dInner / nHeads;
        this.gpuWeights = {};
        this.pipelines = {};
        this._initWeights();
        this._buildPipelines();
    }

    /** Allocates and initializes every weight buffer of the block. */
    _initWeights() {
        const { dModel, dState, dConv, nHeads, nGroups } = this.config;
        const D = this.dInner;
        const N = dState;
        const K = dConv;
        const H = nHeads;
        const G = nGroups;
        const randn = (n, std = 0.02) => gaussianArray(n, std);
        const zeros = (n) => new Float32Array(n);
        const ones = (n) => new Float32Array(n).fill(1.0);
        // wInProj: (D_inner + 2*n_groups*N + H, D_model) — no bias per Mamba-2 spec
        const inProjRows = D + 2 * G * N + H;
        const mk = (arr) => createStorageBuffer(this.device, arr, true);
        this.gpuWeights = {
            wInProj: mk(randn(inProjRows * dModel)),
            wConv: mk(randn((D + 2 * G * N) * K, 0.01)),
            bConv: mk(zeros(D + 2 * G * N)),
            A_log: mk(new Float32Array(H).fill(Math.log(1.0))),
            dt_bias: mk(zeros(H)),
            D_vec: mk(ones(H)),
            wOutProj: mk(randn(dModel * D, 0.02)),
            normWeight: mk(ones(D)), // inner RMSNorm
            preNormWeight: mk(ones(dModel)), // pre-block RMSNorm
        };
    }

    /** Compiles the compute pipelines used by forward(). */
    _buildPipelines() {
        const d = this.device;
        this.pipelines = {
            linear: createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
            conv1d: createComputePipeline(d, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
            rmsnorm: createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
            ssd_fwd: createComputePipeline(d, SSD_FORWARD_WGSL, 'ssd_chunk_forward'),
            elAdd: createComputePipeline(d, ADD_SHADER, 'main'),
        };
    }

    /**
     * Runs the rmsnorm_forward pipeline over `rows` rows of `cols` values
     * (epsilon 1e-6) and returns the newly allocated output buffer.
     * The kernel's auxiliary per-row output buffer and the uniform params
     * buffer are destroyed here because forward() never reads them.
     */
    _rmsNorm(input, weight, rows, cols) {
        const d = this.device;
        const out = createEmptyStorageBuffer(d, rows * cols * 4, true);
        const aux = createEmptyStorageBuffer(d, rows * 4, true);
        const params = new ArrayBuffer(16);
        new Uint32Array(params, 0, 2).set([rows, cols]);
        new Float32Array(params, 8, 1).set([1e-6]);
        const pBuf = createUniformBuffer(d, params);
        const bg = createBindGroup(d, this.pipelines['rmsnorm'], [pBuf, input, weight, out, aux]);
        dispatchKernel(d, this.pipelines['rmsnorm'], bg, [cdiv(rows, 64), 1, 1]);
        pBuf.destroy();
        aux.destroy();
        return out;
    }

    /**
     * Runs the linear_forward pipeline (params [rows, inDim, outDim]) with an
     * all-zero bias and returns the newly allocated rows * outDim output buffer.
     * Temporary bias and uniform buffers are destroyed after dispatch.
     */
    _linear(input, weight, rows, inDim, outDim) {
        const d = this.device;
        const out = createEmptyStorageBuffer(d, rows * outDim * 4, true);
        const pBuf = createUniformBuffer(d, new Uint32Array([rows, inDim, outDim]).buffer);
        // The Mamba-2 projections have no bias; the kernel still expects one.
        const zeroBias = createStorageBuffer(d, new Float32Array(outDim), true);
        const bg = createBindGroup(d, this.pipelines['linear'], [pBuf, input, weight, zeroBias, out]);
        dispatchKernel(d, this.pipelines['linear'], bg, [cdiv(rows, 16), cdiv(outDim, 16), 1]);
        zeroBias.destroy();
        pBuf.destroy();
        return out;
    }

    /**
     * Copies consecutive ranges of `src` (lengths given in f32 elements) into
     * new buffers, in one submitted command encoder, and returns them in order.
     *
     * NOTE(review): this is a plain contiguous split, i.e. it assumes the
     * producing kernel writes channel-major blocks ([all x][all B][all C]...).
     * If linear_forward / conv1d_forward write row-major [M, channels] (per-token
     * [x|B|C|dt], as the in_proj comment in forward() describes), each part needs
     * a strided per-row gather instead — confirm against the kernel sources.
     */
    _splitContiguous(src, lengths) {
        const d = this.device;
        const enc = d.createCommandEncoder();
        const parts = [];
        let offset = 0;
        for (const len of lengths) {
            const dst = createEmptyStorageBuffer(d, len * 4, true);
            enc.copyBufferToBuffer(src, offset * 4, dst, 0, len * 4);
            parts.push(dst);
            offset += len;
        }
        d.queue.submit([enc.finish()]);
        return parts;
    }

    /**
     * Forward pass: pre-RMSNorm -> fused in_proj -> causal conv1d -> SSD scan
     * -> inner RMSNorm -> out_proj -> residual add with xBuf.
     *
     * @param {GPUBuffer} xBuf - batch * seqLen * dModel f32 values; not destroyed.
     * @param {number} batch
     * @param {number} seqLen
     * @returns {{output: GPUBuffer, cache: {stateCarry: GPUBuffer}}} newly
     *   allocated buffers owned by the caller.
     */
    forward(xBuf, batch, seqLen) {
        const d = this.device;
        const { dModel, dState, dConv, nHeads, nGroups, chunkLen } = this.config;
        const D = this.dInner;
        const N = dState;
        const K = dConv;
        const H = nHeads;
        const G = nGroups;
        const dh = this.dHead;
        const B = batch;
        const L = seqLen;
        const M = B * L;
        const gn = G * N;
        const convD = D + 2 * gn; // channels for conv (x, B_proj, C_proj)
        const inProjRows = convD + H;
        const numChunks = Math.ceil(L / chunkLen);

        // 1. Pre-block RMSNorm
        const normOut = this._rmsNorm(xBuf, this.gpuWeights['preNormWeight'], M, dModel);

        // 2. Fused in_proj → [x (D), B_proj (G*N), C_proj (G*N), dt (H)]
        const inProjOut = this._linear(normOut, this.gpuWeights['wInProj'], M, dModel, inProjRows);
        normOut.destroy();
        // Split: xConv [D+2GN], dt [H] (see layout NOTE on _splitContiguous).
        const [xConvBuf, dtBuf] = this._splitContiguous(inProjOut, [M * convD, M * H]);
        inProjOut.destroy();

        // 3. Causal conv1d over x + B_proj + C_proj (fused, convD channels)
        const convOut = createEmptyStorageBuffer(d, M * convD * 4, true);
        {
            const pBuf = createUniformBuffer(d, new Uint32Array([L, convD, K, B, 1]).buffer);
            const bg = createBindGroup(d, this.pipelines['conv1d'], [pBuf, xConvBuf, this.gpuWeights['wConv'], this.gpuWeights['bConv'], convOut]);
            dispatchKernel(d, this.pipelines['conv1d'], bg, [cdiv(L, 16), cdiv(convD, 16), B]);
            pBuf.destroy();
        }
        xConvBuf.destroy();
        // Split conv output: x [D], B_proj [G*N], C_proj [G*N]
        const [xSsdBuf, bProjBuf, cProjBuf] = this._splitContiguous(convOut, [M * D, M * gn, M * gn]);
        convOut.destroy();

        // 4. SSD scan; state_carry: [numChunks+1, B, H, N, dHead]
        const stateCarry = createEmptyStorageBuffer(d, (numChunks + 1) * B * H * N * dh * 4, true);
        const ssdOut = createEmptyStorageBuffer(d, M * D * 4, true);
        {
            const pBuf = createUniformBuffer(d, new Uint32Array([L, D, H, dh, G, N, chunkLen, numChunks, B]).buffer);
            const bg = createBindGroup(d, this.pipelines['ssd_fwd'], [pBuf, xSsdBuf, bProjBuf, cProjBuf, dtBuf,
                this.gpuWeights['A_log'], this.gpuWeights['dt_bias'],
                this.gpuWeights['D_vec'], ssdOut, stateCarry]);
            dispatchKernel(d, this.pipelines['ssd_fwd'], bg, [numChunks, H, B]);
            pBuf.destroy();
        }
        for (const buf of [xSsdBuf, bProjBuf, cProjBuf, dtBuf]) {
            buf.destroy();
        }

        // 5. Inner RMSNorm on scan output
        const innerNormOut = this._rmsNorm(ssdOut, this.gpuWeights['normWeight'], M, D);
        ssdOut.destroy();

        // 6. Output projection
        const outProjOut = this._linear(innerNormOut, this.gpuWeights['wOutProj'], M, D, dModel);
        innerNormOut.destroy();

        // 7. Residual add
        const output = createEmptyStorageBuffer(d, M * dModel * 4, true);
        {
            const nBuf = createUniformBuffer(d, new Uint32Array([M * dModel]).buffer);
            const bg = createBindGroup(d, this.pipelines['elAdd'], [outProjOut, xBuf, output, nBuf]);
            dispatchKernel(d, this.pipelines['elAdd'], bg, [cdiv(M * dModel, 256), 1, 1]);
            nBuf.destroy();
        }
        outProjOut.destroy();

        return { output, cache: { stateCarry } };
    }

    /** Returns every weight buffer with its element count and name. */
    parameters() {
        const { dModel, dState, dConv, nHeads, nGroups } = this.config;
        const D = this.dInner;
        const N = dState;
        const K = dConv;
        const H = nHeads;
        const G = nGroups;
        const convD = D + 2 * G * N;
        return [
            { buf: this.gpuWeights['wInProj'], numel: (convD + H) * dModel, name: 'wInProj' },
            { buf: this.gpuWeights['wConv'], numel: convD * K, name: 'wConv' },
            { buf: this.gpuWeights['bConv'], numel: convD, name: 'bConv' },
            { buf: this.gpuWeights['A_log'], numel: H, name: 'A_log' },
            { buf: this.gpuWeights['dt_bias'], numel: H, name: 'dt_bias' },
            { buf: this.gpuWeights['D_vec'], numel: H, name: 'D_vec' },
            { buf: this.gpuWeights['wOutProj'], numel: dModel * D, name: 'wOutProj' },
            { buf: this.gpuWeights['normWeight'], numel: D, name: 'normWeight' },
            { buf: this.gpuWeights['preNormWeight'], numel: dModel, name: 'preNormWeight' },
        ];
    }

    /**
     * Returns parameters(), or, in WSLA mode, a single entry for the B/C rows
     * of wInProj (2 * nGroups * dState rows of dModel values).
     */
    getTrainableParams() {
        if (this._wslaMode) {
            // WSLA: train only B/C rows of wInProj (the selective projection part)
            // NOTE(review): the entry carries no row offset, but in wInProj the
            // B/C rows start after the first dInner rows. Confirm how the
            // optimizer locates this slice inside the shared buffer.
            return [
                { buf: this.gpuWeights['wInProj'],
                    numel: (this.config.nGroups * this.config.dState * 2) * this.config.dModel,
                    name: 'wInProj_BC' },
            ];
        }
        return this.parameters();
    }

    /** Enables or disables WSLA mode (see getTrainableParams). */
    setWSLAMode(enabled) {
        this._wslaMode = enabled;
    }

    /** Destroys every weight buffer and clears gpuWeights. */
    destroy() {
        for (const buf of Object.values(this.gpuWeights)) {
            buf.destroy();
        }
        this.gpuWeights = {};
    }
}
|
|
252
|
+
//# sourceMappingURL=mamba2_block.js.map
|