@seanhogg/builderforce-memory-engine 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +393 -0
- package/dist/index.d.ts +32 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +5 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +171 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/attention.d.ts +19 -0
- package/dist/kernels/attention.d.ts.map +1 -0
- package/dist/kernels/attention.js +263 -0
- package/dist/kernels/attention.js.map +1 -0
- package/dist/kernels/complex_ssd.d.ts +33 -0
- package/dist/kernels/complex_ssd.d.ts.map +1 -0
- package/dist/kernels/complex_ssd.js +305 -0
- package/dist/kernels/complex_ssd.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +158 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/ssd.d.ts +29 -0
- package/dist/kernels/ssd.d.ts.map +1 -0
- package/dist/kernels/ssd.js +276 -0
- package/dist/kernels/ssd.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/attention_block.d.ts +48 -0
- package/dist/model/attention_block.d.ts.map +1 -0
- package/dist/model/attention_block.js +262 -0
- package/dist/model/attention_block.js.map +1 -0
- package/dist/model/mamba1_block.d.ts +70 -0
- package/dist/model/mamba1_block.d.ts.map +1 -0
- package/dist/model/mamba1_block.js +333 -0
- package/dist/model/mamba1_block.js.map +1 -0
- package/dist/model/mamba2_block.d.ts +44 -0
- package/dist/model/mamba2_block.d.ts.map +1 -0
- package/dist/model/mamba2_block.js +252 -0
- package/dist/model/mamba2_block.js.map +1 -0
- package/dist/model/mamba3_block.d.ts +51 -0
- package/dist/model/mamba3_block.d.ts.map +1 -0
- package/dist/model/mamba3_block.js +270 -0
- package/dist/model/mamba3_block.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +303 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +140 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +527 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/model/sequence_layer.d.ts +25 -0
- package/dist/model/sequence_layer.d.ts.map +1 -0
- package/dist/model/sequence_layer.js +8 -0
- package/dist/model/sequence_layer.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +36 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/dist/utils/rng.d.ts +36 -0
- package/dist/utils/rng.d.ts.map +1 -0
- package/dist/utils/rng.js +61 -0
- package/dist/utils/rng.js.map +1 -0
- package/package.json +99 -0
- package/src/index.ts +114 -0
- package/src/kernels/activations.ts +174 -0
- package/src/kernels/attention.ts +268 -0
- package/src/kernels/complex_ssd.ts +307 -0
- package/src/kernels/conv1d.ts +159 -0
- package/src/kernels/linear_projection.ts +220 -0
- package/src/kernels/selective_scan.ts +350 -0
- package/src/kernels/ssd.ts +278 -0
- package/src/kernels/weight_update.ts +120 -0
- package/src/model/attention_block.ts +344 -0
- package/src/model/mamba1_block.ts +437 -0
- package/src/model/mamba2_block.ts +319 -0
- package/src/model/mamba3_block.ts +335 -0
- package/src/model/mamba_block.ts +401 -0
- package/src/model/mamba_model.ts +678 -0
- package/src/model/sequence_layer.ts +29 -0
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +309 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/utils/rng.ts +65 -0
|
@@ -0,0 +1,401 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* mamba_block.ts – Mamba Mixer Block
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
import {
|
|
6
|
+
createComputePipeline,
|
|
7
|
+
createBindGroup,
|
|
8
|
+
createStorageBuffer,
|
|
9
|
+
createEmptyStorageBuffer,
|
|
10
|
+
createUniformBuffer,
|
|
11
|
+
dispatchKernel,
|
|
12
|
+
cdiv,
|
|
13
|
+
} from '../utils/gpu_utils.js';
|
|
14
|
+
|
|
15
|
+
import { SELECTIVE_SCAN_FORWARD_WGSL } from '../kernels/selective_scan.js';
|
|
16
|
+
import { gaussianArray } from '../utils/rng.js';
|
|
17
|
+
import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
|
|
18
|
+
import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
|
|
19
|
+
import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
|
|
20
|
+
|
|
21
|
+
/**
 * Construction options for {@link MambaBlock}.
 *
 * Only `dModel` is mandatory; the MambaBlock constructor fills in every
 * omitted optional field with the default listed on that field.
 */
export interface MambaBlockConfig {
  /** Width of the block's input/output rows (the residual stream). */
  dModel: number;

  /** SSM state size per inner channel. MambaBlock default: 16. */
  dState?: number;

  /** Convolution kernel width; conv weights are sized dInner × dConv. MambaBlock default: 4. */
  dConv?: number;

  /** Inner-width multiplier, giving dInner = expand * dModel. MambaBlock default: 2. */
  expand?: number;

  /** Rank of the low-rank Δ (dt) projection. MambaBlock default: ceil(dModel / 16). */
  dtRank?: number;

  /**
   * Conv-bias flag. MambaBlock default: true.
   * NOTE(review): MambaBlock does not currently read this flag; the conv
   * bias buffer is always bound.
   */
  biasConv?: boolean;
}
|
|
29
|
+
|
|
30
|
+
/**
 * One trainable tensor exposed by a block: its GPU storage plus the
 * bookkeeping an optimizer needs to address it.
 */
export interface BlockParam {
  /** Parameter name, e.g. `'wInProj'`. */
  name: string;
  /** Number of elements stored in `buf` (an element count, not a byte size). */
  numel: number;
  /** GPU storage buffer holding the parameter values. */
  buf: GPUBuffer;
}
|
|
35
|
+
|
|
36
|
+
/**
 * GPU buffers captured during a block's forward pass and returned to the
 * caller alongside the output (for example, for use by a backward pass).
 * With M = batch * seqLen rows, dInner inner channels and dState state size:
 */
export interface BlockCache {
  /** Block input, i.e. the RMSNorm input (the caller's buffer, not a copy). */
  normIn: GPUBuffer;
  /** M per-row values written by the RMSNorm kernel (presumably 1/rms — TODO confirm). */
  normInv: GPUBuffer;
  /** RMSNorm output, M × dModel values. */
  normOut: GPUBuffer;
  /** Gate half of the in_proj output, M × dInner values. */
  zBuf: GPUBuffer;
  /** Convolution input (x half of the in_proj output), M × dInner values. */
  xConvIn: GPUBuffer;
  /** Convolution output, M × dInner values. */
  convOut: GPUBuffer;
  /** SiLU(convOut), M × dInner values; also the scan's input sequence. */
  siluOut: GPUBuffer;
  /** dt_proj output fed to the scan as Δ, M × dInner values. */
  deltaFull: GPUBuffer;
  /** Input-dependent B from x_proj, M × dState values. */
  B_raw: GPUBuffer;
  /** Input-dependent C from x_proj, M × dState values. */
  C_raw: GPUBuffer;
  /** Scan state buffer of 2 × M × dInner × dState values; layout is defined by the scan kernel. */
  hCache: GPUBuffer;
}
|
|
49
|
+
|
|
50
|
+
/** What a block's `forward()` hands back: the output activations plus captured intermediates. */
export interface BlockForwardResult {
  /** Block output buffer, same shape as the block input (residual already added). */
  output: GPUBuffer;
  /** Intermediate buffers recorded during the forward pass. */
  cache: BlockCache;
}
|
|
54
|
+
|
|
55
|
+
/**
 * Mamba mixer block with a pre-norm residual connection, executed on WebGPU.
 *
 * Forward data flow (M = batch * seqLen rows, D = expand * dModel,
 * N = dState, R = dtRank):
 *
 *   x ─ RMSNorm ─ in_proj ─┬─ x-half ─ conv1d ─ SiLU ─┬─ x_proj ─┬─ dt ─ dt_proj ─ Δ
 *                          │                          │          ├─ B
 *                          │                          │          └─ C
 *                          │                          └─ selective scan(·, Δ, A_log, B, C, D_vec) ─ y
 *                          └─ z-half ─ SiLU ─────────────────────────────────── × y ─ out_proj ─ + x
 *
 * Every activation buffer is row-major `[M, features]`. The final residual add
 * (`out_proj(...) + x`) relies on the linear kernel producing that layout.
 *
 * The host-side Float32Array fields hold the initial weight values. They are
 * uploaded once into `gpuWeights` and are not synced back from the GPU by this
 * class.
 *
 * NOTE(review): intermediate buffers created in `forward()` (projection
 * outputs, uniform parameter buffers, scan output, gated output) are never
 * explicitly destroyed here. Consider destroying them once `dispatchKernel`'s
 * submission semantics are confirmed.
 */
export class MambaBlock {
  device: GPUDevice;
  config: Required<MambaBlockConfig>;
  /** Inner (expanded) channel width: `expand * dModel`. */
  dInner: number;
  /** Rank of the low-rank Δ projection (resolved from `config.dtRank`). */
  dtRank: number;

  wInProj: Float32Array;
  bInProj: Float32Array;
  wConv: Float32Array;
  bConv: Float32Array;
  wXProj: Float32Array;
  bXProj: Float32Array;
  wDtProj: Float32Array;
  bDtProj: Float32Array;
  A_log: Float32Array;
  D_vec: Float32Array;
  wOutProj: Float32Array;
  bOutProj: Float32Array;
  normWeight: Float32Array;

  /** GPU copies of the weights, keyed by the same names as the fields above. */
  gpuWeights: Record<string, GPUBuffer>;
  /** Compiled compute pipelines, built once in the constructor. */
  pipelines: Record<string, GPUComputePipeline>;

  /** When true, getTrainableParams() reports only the x_proj weight and bias. */
  private _wslaMode = false;

  /** Element-wise product `c[i] = a[i] * b[i]` for `i < n`. */
  private static readonly MUL_WGSL = /* wgsl */ `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@group(0) @binding(3) var<uniform> n : u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  if (i < n) { c[i] = a[i] * b[i]; }
}
`;

  /** Element-wise sum `c[i] = a[i] + b[i]` for `i < n`. */
  private static readonly ADD_WGSL = /* wgsl */ `
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@group(0) @binding(3) var<uniform> n : u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  if (i < n) { c[i] = a[i] + b[i]; }
}
`;

  /**
   * Copies columns `[colOffset, colOffset + cols)` of a row-major
   * `[rows, srcCols]` matrix into a dense row-major `[rows, cols]` matrix.
   * This is the strided split of a projection output along its feature axis.
   */
  private static readonly COLUMN_SLICE_WGSL = /* wgsl */ `
struct SliceParams {
  rows      : u32,
  srcCols   : u32,
  colOffset : u32,
  cols      : u32,
};
@group(0) @binding(0) var<uniform> p : SliceParams;
@group(0) @binding(1) var<storage, read> src : array<f32>;
@group(0) @binding(2) var<storage, read_write> dst : array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  if (i >= p.rows * p.cols) { return; }
  let r = i / p.cols;
  let c = i - r * p.cols;
  dst[i] = src[r * p.srcCols + p.colOffset + c];
}
`;

  /**
   * @param device WebGPU device that owns every buffer and pipeline of this block.
   * @param config Block hyper-parameters; omitted (or explicitly `undefined`)
   *   optional fields fall back to dState=16, dConv=4, expand=2,
   *   biasConv=true, dtRank=ceil(dModel/16).
   */
  constructor(device: GPUDevice, config: MambaBlockConfig) {
    this.device = device;

    // Resolve each default with `??` rather than `{ ...defaults, ...config }`.
    // With the spread form, an explicit `undefined` in `config` would
    // overwrite the default. Extra properties on `config` are still carried
    // over by the leading spread.
    const resolved: Required<MambaBlockConfig> = {
      ...config,
      dModel: config.dModel,
      dState: config.dState ?? 16,
      dConv: config.dConv ?? 4,
      expand: config.expand ?? 2,
      biasConv: config.biasConv ?? true,
      dtRank: config.dtRank ?? Math.ceil(config.dModel / 16),
    };
    this.config = resolved;
    this.dInner = resolved.expand * resolved.dModel;
    this.dtRank = resolved.dtRank;

    // Placeholders satisfy definite assignment; _initWeights() replaces them.
    this.wInProj = new Float32Array(0);
    this.bInProj = new Float32Array(0);
    this.wConv = new Float32Array(0);
    this.bConv = new Float32Array(0);
    this.wXProj = new Float32Array(0);
    this.bXProj = new Float32Array(0);
    this.wDtProj = new Float32Array(0);
    this.bDtProj = new Float32Array(0);
    this.A_log = new Float32Array(0);
    this.D_vec = new Float32Array(0);
    this.wOutProj = new Float32Array(0);
    this.bOutProj = new Float32Array(0);
    this.normWeight = new Float32Array(0);
    this.gpuWeights = {};
    this.pipelines = {};

    this._initWeights();
    this._buildPipelines();
  }

  /**
   * Draws initial weights on the host, then uploads them to the GPU.
   * The order of the gaussianArray() calls is kept stable so seeded runs
   * reproduce the same weights.
   */
  private _initWeights(): void {
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const R = this.dtRank;

    const randn = (n: number, std = 0.02): Float32Array => gaussianArray(n, std);
    const zeros = (n: number): Float32Array => new Float32Array(n);
    const ones = (n: number): Float32Array => new Float32Array(n).fill(1.0);

    this.wInProj = randn(2 * D * dModel);
    this.bInProj = zeros(2 * D);
    this.wConv = randn(D * K, 0.01);
    this.bConv = zeros(D);
    this.wXProj = randn((R + 2 * N) * D, 0.01);
    this.bXProj = zeros(R + 2 * N);
    this.wDtProj = randn(D * R, 0.02);
    this.bDtProj = zeros(D);

    // A_log[d, n] = log(n + 1) for every inner channel d.
    this.A_log = new Float32Array(D * N);
    for (let d = 0; d < D; d++) {
      for (let n = 0; n < N; n++) {
        this.A_log[d * N + n] = Math.log(n + 1);
      }
    }

    this.D_vec = ones(D);
    this.wOutProj = randn(dModel * D, 0.02);
    this.bOutProj = zeros(dModel);
    this.normWeight = ones(dModel);

    this._uploadWeightsToGPU();
  }

  /** Creates one storage buffer per weight array and stores them in `gpuWeights`. */
  private _uploadWeightsToGPU(): void {
    const d = this.device;
    const mk = (arr: Float32Array, readable = true): GPUBuffer => createStorageBuffer(d, arr, readable);

    this.gpuWeights = {
      wInProj: mk(this.wInProj),
      bInProj: mk(this.bInProj),
      wConv: mk(this.wConv),
      bConv: mk(this.bConv),
      wXProj: mk(this.wXProj),
      bXProj: mk(this.bXProj),
      wDtProj: mk(this.wDtProj),
      bDtProj: mk(this.bDtProj),
      A_log: mk(this.A_log),
      D_vec: mk(this.D_vec),
      wOutProj: mk(this.wOutProj),
      bOutProj: mk(this.bOutProj),
      normWeight: mk(this.normWeight),
    };
  }

  /**
   * Compiles every compute pipeline used by forward(), once.
   * The elementwise and slice kernels live here too, so forward() never
   * recompiles WGSL.
   */
  private _buildPipelines(): void {
    const d = this.device;

    this.pipelines = {
      linear: createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
      conv1d: createComputePipeline(d, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
      silu: createComputePipeline(d, ACTIVATIONS_WGSL, 'silu_forward'),
      rmsnorm: createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
      scan_fwd: createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_scan'),
      scan_reduce: createComputePipeline(d, SELECTIVE_SCAN_FORWARD_WGSL, 'forward_reduce'),
      mul: createComputePipeline(d, MambaBlock.MUL_WGSL, 'main'),
      add: createComputePipeline(d, MambaBlock.ADD_WGSL, 'main'),
      slice_cols: createComputePipeline(d, MambaBlock.COLUMN_SLICE_WGSL, 'main'),
    };
  }

  /**
   * Runs the block on `batch * seqLen` rows of `dModel` features.
   *
   * @param xBuf   Input activations, row-major `[batch * seqLen, dModel]` f32.
   * @param batch  Number of sequences (positive integer).
   * @param seqLen Tokens per sequence (positive integer).
   * @returns The residual output (same shape as `xBuf`) and the cached intermediates.
   * @throws RangeError if `batch` or `seqLen` is not a positive integer.
   */
  forward(xBuf: GPUBuffer, batch: number, seqLen: number): BlockForwardResult {
    if (!Number.isInteger(batch) || batch <= 0 || !Number.isInteger(seqLen) || seqLen <= 0) {
      throw new RangeError(
        `MambaBlock.forward: batch and seqLen must be positive integers (got batch=${batch}, seqLen=${seqLen})`,
      );
    }

    const d = this.device;
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const R = this.dtRank;
    const B = batch;
    const L = seqLen;
    const M = B * L;
    const xProjCols = R + 2 * N;

    // 1. RMSNorm over each of the M rows.
    const normOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    const normInv = createEmptyStorageBuffer(d, M * 4, true);
    {
      const params = new ArrayBuffer(16);
      new Uint32Array(params, 0, 2).set([M, dModel]);
      new Float32Array(params, 8, 1).set([1e-6]); // epsilon
      const pipeline = this._pipeline('rmsnorm');
      const bg = createBindGroup(d, pipeline,
        [createUniformBuffer(d, params), xBuf, this._weight('normWeight'), normOut, normInv]);
      dispatchKernel(d, pipeline, bg, [cdiv(M, 64), 1, 1]);
    }

    // 2. in_proj: [M, dModel] -> [M, 2D]. Each row is [x (D) | z (D)], so the
    //    halves are strided column slices, not the two contiguous halves of the buffer.
    const inProjOut = this._linear(normOut, 'wInProj', 'bInProj', M, dModel, 2 * D);
    const xConvIn = this._sliceColumns(inProjOut, M, 2 * D, 0, D);
    const zBuf = this._sliceColumns(inProjOut, M, 2 * D, D, D);

    // 3. conv1d over the sequence axis (weights D × dConv, per-channel bias).
    const convOut = createEmptyStorageBuffer(d, M * D * 4, true);
    {
      const pipeline = this._pipeline('conv1d');
      const pBuf = createUniformBuffer(d, new Uint32Array([L, D, dConv, B]).buffer);
      const bg = createBindGroup(d, pipeline,
        [pBuf, xConvIn, this._weight('wConv'), this._weight('bConv'), convOut]);
      dispatchKernel(d, pipeline, bg, [cdiv(L, 16), cdiv(D, 16), B]);
    }

    // 4. SiLU activation; the result is the scan's input sequence.
    const siluOut = this._silu(convOut, M * D);

    // 5. x_proj: [M, D] -> [M, R + 2N]; each row is [dt (R) | B (N) | C (N)].
    const xProjOut = this._linear(siluOut, 'wXProj', 'bXProj', M, D, xProjCols);
    const dtRaw = this._sliceColumns(xProjOut, M, xProjCols, 0, R);
    const B_raw = this._sliceColumns(xProjOut, M, xProjCols, R, N);
    const C_raw = this._sliceColumns(xProjOut, M, xProjCols, R + N, N);

    // 6. dt_proj: [M, R] -> [M, D] gives Δ for every inner channel.
    const deltaFull = this._linear(dtRaw, 'wDtProj', 'bDtProj', M, R, D);

    // 7. Selective scan: a state pass followed by a reduction pass, sharing
    //    one binding list.
    const scanY = createEmptyStorageBuffer(d, M * D * 4, true);
    const hCache = createEmptyStorageBuffer(d, 2 * M * D * N * 4, true);
    {
      const pBuf = createUniformBuffer(d, new Uint32Array([L, N, D, B]).buffer);
      const scanBindings = [pBuf, siluOut, deltaFull, this._weight('A_log'), B_raw, C_raw,
        this._weight('D_vec'), scanY, hCache];

      const scanPipeline = this._pipeline('scan_fwd');
      dispatchKernel(d, scanPipeline, createBindGroup(d, scanPipeline, scanBindings),
        [cdiv(D, 8), cdiv(N, 8), B]);

      const reducePipeline = this._pipeline('scan_reduce');
      dispatchKernel(d, reducePipeline, createBindGroup(d, reducePipeline, scanBindings),
        [cdiv(L, 64), D, B]);
    }

    // 8. Gate the scan output with SiLU(z).
    const siluZ = this._silu(zBuf, M * D);
    const gatedOut = this._elementwise('mul', scanY, siluZ, M * D);

    // 9. out_proj back to dModel, then the residual add with the block input.
    const outProjOut = this._linear(gatedOut, 'wOutProj', 'bOutProj', M, D, dModel);
    const output = this._elementwise('add', outProjOut, xBuf, M * dModel);

    const cache: BlockCache = {
      normInv,
      normIn: xBuf,
      normOut,
      zBuf,
      xConvIn,
      convOut,
      siluOut,
      deltaFull,
      B_raw,
      C_raw,
      hCache,
    };
    return { output, cache };
  }

  /** Looks up a compiled pipeline, failing loudly instead of binding `undefined`. */
  private _pipeline(name: string): GPUComputePipeline {
    const pipeline = this.pipelines[name];
    if (pipeline === undefined) {
      throw new Error(`MambaBlock: pipeline '${name}' has not been built`);
    }
    return pipeline;
  }

  /** Looks up a GPU weight buffer, failing loudly instead of binding `undefined`. */
  private _weight(name: string): GPUBuffer {
    const buf = this.gpuWeights[name];
    if (buf === undefined) {
      throw new Error(`MambaBlock: weight '${name}' has not been uploaded`);
    }
    return buf;
  }

  /** Dispatches the shared linear kernel on a `[rows, inDim]` input; returns the `[rows, outDim]` output. */
  private _linear(input: GPUBuffer, weightKey: string, biasKey: string,
                  rows: number, inDim: number, outDim: number): GPUBuffer {
    const d = this.device;
    const out = createEmptyStorageBuffer(d, rows * outDim * 4, true);
    const pipeline = this._pipeline('linear');
    const pBuf = createUniformBuffer(d, new Uint32Array([rows, inDim, outDim]).buffer);
    const bg = createBindGroup(d, pipeline,
      [pBuf, input, this._weight(weightKey), this._weight(biasKey), out]);
    dispatchKernel(d, pipeline, bg, [cdiv(rows, 16), cdiv(outDim, 16), 1]);
    return out;
  }

  /** Applies the SiLU kernel to the first `count` elements of `input`; returns a new buffer. */
  private _silu(input: GPUBuffer, count: number): GPUBuffer {
    const d = this.device;
    const out = createEmptyStorageBuffer(d, count * 4, true);
    const pipeline = this._pipeline('silu');
    const pBuf = createUniformBuffer(d, new Uint32Array([count]).buffer);
    dispatchKernel(d, pipeline, createBindGroup(d, pipeline, [pBuf, input, out]),
      [cdiv(count, 256), 1, 1]);
    return out;
  }

  /** Runs the cached element-wise `mul`/`add` kernel over `count` elements; returns a new buffer. */
  private _elementwise(op: 'mul' | 'add', a: GPUBuffer, b: GPUBuffer, count: number): GPUBuffer {
    const d = this.device;
    const out = createEmptyStorageBuffer(d, count * 4, true);
    const pipeline = this._pipeline(op);
    const nBuf = createUniformBuffer(d, new Uint32Array([count]).buffer);
    dispatchKernel(d, pipeline, createBindGroup(d, pipeline, [a, b, out, nBuf]),
      [cdiv(count, 256), 1, 1]);
    return out;
  }

  /** Extracts columns `[colOffset, colOffset + cols)` of a row-major `[rows, srcCols]` buffer into a dense `[rows, cols]` buffer. */
  private _sliceColumns(src: GPUBuffer, rows: number, srcCols: number,
                        colOffset: number, cols: number): GPUBuffer {
    const d = this.device;
    const count = rows * cols;
    const out = createEmptyStorageBuffer(d, count * 4, true);
    const pipeline = this._pipeline('slice_cols');
    const pBuf = createUniformBuffer(d, new Uint32Array([rows, srcCols, colOffset, cols]).buffer);
    dispatchKernel(d, pipeline, createBindGroup(d, pipeline, [pBuf, src, out]),
      [cdiv(count, 256), 1, 1]);
    return out;
  }

  /** Every parameter of the block, with its element count and name. */
  parameters(): BlockParam[] {
    const { dModel, dState, dConv } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const R = this.dtRank;

    return [
      { buf: this._weight('wInProj'), numel: 2 * D * dModel, name: 'wInProj' },
      { buf: this._weight('bInProj'), numel: 2 * D, name: 'bInProj' },
      { buf: this._weight('wConv'), numel: D * K, name: 'wConv' },
      { buf: this._weight('bConv'), numel: D, name: 'bConv' },
      { buf: this._weight('wXProj'), numel: (R + 2 * N) * D, name: 'wXProj' },
      { buf: this._weight('bXProj'), numel: R + 2 * N, name: 'bXProj' },
      { buf: this._weight('wDtProj'), numel: D * R, name: 'wDtProj' },
      { buf: this._weight('bDtProj'), numel: D, name: 'bDtProj' },
      { buf: this._weight('A_log'), numel: D * N, name: 'A_log' },
      { buf: this._weight('D_vec'), numel: D, name: 'D_vec' },
      { buf: this._weight('wOutProj'), numel: dModel * D, name: 'wOutProj' },
      { buf: this._weight('bOutProj'), numel: dModel, name: 'bOutProj' },
      { buf: this._weight('normWeight'), numel: dModel, name: 'normWeight' },
    ];
  }

  /** Toggles WSLA mode, which restricts getTrainableParams() to the x_proj weights. */
  setWSLAMode(enabled: boolean): void {
    this._wslaMode = enabled;
  }

  /** Parameters to optimize: only wXProj/bXProj in WSLA mode, otherwise all of parameters(). */
  getTrainableParams(): BlockParam[] {
    if (this._wslaMode) {
      return [
        { buf: this._weight('wXProj'), numel: this.wXProj.length, name: 'wXProj' },
        { buf: this._weight('bXProj'), numel: this.bXProj.length, name: 'bXProj' },
      ];
    }
    return this.parameters();
  }
}
|