@seanhogg/builderforce-memory-engine 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +393 -0
- package/dist/index.d.ts +32 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +5 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +171 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/attention.d.ts +19 -0
- package/dist/kernels/attention.d.ts.map +1 -0
- package/dist/kernels/attention.js +263 -0
- package/dist/kernels/attention.js.map +1 -0
- package/dist/kernels/complex_ssd.d.ts +33 -0
- package/dist/kernels/complex_ssd.d.ts.map +1 -0
- package/dist/kernels/complex_ssd.js +305 -0
- package/dist/kernels/complex_ssd.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +158 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/ssd.d.ts +29 -0
- package/dist/kernels/ssd.d.ts.map +1 -0
- package/dist/kernels/ssd.js +276 -0
- package/dist/kernels/ssd.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/attention_block.d.ts +48 -0
- package/dist/model/attention_block.d.ts.map +1 -0
- package/dist/model/attention_block.js +262 -0
- package/dist/model/attention_block.js.map +1 -0
- package/dist/model/mamba1_block.d.ts +70 -0
- package/dist/model/mamba1_block.d.ts.map +1 -0
- package/dist/model/mamba1_block.js +333 -0
- package/dist/model/mamba1_block.js.map +1 -0
- package/dist/model/mamba2_block.d.ts +44 -0
- package/dist/model/mamba2_block.d.ts.map +1 -0
- package/dist/model/mamba2_block.js +252 -0
- package/dist/model/mamba2_block.js.map +1 -0
- package/dist/model/mamba3_block.d.ts +51 -0
- package/dist/model/mamba3_block.d.ts.map +1 -0
- package/dist/model/mamba3_block.js +270 -0
- package/dist/model/mamba3_block.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +303 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +140 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +527 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/model/sequence_layer.d.ts +25 -0
- package/dist/model/sequence_layer.d.ts.map +1 -0
- package/dist/model/sequence_layer.js +8 -0
- package/dist/model/sequence_layer.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +36 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/dist/utils/rng.d.ts +36 -0
- package/dist/utils/rng.d.ts.map +1 -0
- package/dist/utils/rng.js +61 -0
- package/dist/utils/rng.js.map +1 -0
- package/package.json +99 -0
- package/src/index.ts +114 -0
- package/src/kernels/activations.ts +174 -0
- package/src/kernels/attention.ts +268 -0
- package/src/kernels/complex_ssd.ts +307 -0
- package/src/kernels/conv1d.ts +159 -0
- package/src/kernels/linear_projection.ts +220 -0
- package/src/kernels/selective_scan.ts +350 -0
- package/src/kernels/ssd.ts +278 -0
- package/src/kernels/weight_update.ts +120 -0
- package/src/model/attention_block.ts +344 -0
- package/src/model/mamba1_block.ts +437 -0
- package/src/model/mamba2_block.ts +319 -0
- package/src/model/mamba3_block.ts +335 -0
- package/src/model/mamba_block.ts +401 -0
- package/src/model/mamba_model.ts +678 -0
- package/src/model/sequence_layer.ts +29 -0
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +309 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/utils/rng.ts +65 -0
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* mamba2_block.ts – Mamba-2 Mixer Block (Structured State Space Duality).
|
|
3
|
+
*
|
|
4
|
+
* Key differences from Mamba-1:
|
|
5
|
+
* - Multi-head SSM with scalar A per head
|
|
6
|
+
* - Single fused in_proj (no separate dt_proj expansion)
|
|
7
|
+
* - SSD (chunked) scan replaces S6 selective scan
|
|
8
|
+
* - Inner RMSNorm on scan output instead of SiLU gate
|
|
9
|
+
* - No separate z gate
|
|
10
|
+
*
|
|
11
|
+
* Implements SequenceLayer.
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
import {
|
|
15
|
+
createComputePipeline,
|
|
16
|
+
createBindGroup,
|
|
17
|
+
createStorageBuffer,
|
|
18
|
+
createEmptyStorageBuffer,
|
|
19
|
+
createUniformBuffer,
|
|
20
|
+
dispatchKernel,
|
|
21
|
+
cdiv,
|
|
22
|
+
} from '../utils/gpu_utils.js';
|
|
23
|
+
|
|
24
|
+
import { SSD_FORWARD_WGSL } from '../kernels/ssd.js';
|
|
25
|
+
import { gaussianArray } from '../utils/rng.js';
|
|
26
|
+
import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
|
|
27
|
+
import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
|
|
28
|
+
import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
|
|
29
|
+
|
|
30
|
+
import type { SequenceLayer, LayerForwardResult, LayerParam } from './sequence_layer.js';
|
|
31
|
+
|
|
32
|
+
/**
 * Hyper-parameters for a Mamba-2 mixer block.
 *
 * The Mamba2Block constructor fills dState, dConv, expand, nGroups and
 * chunkLen with defaults (16, 4, 2, 1, 256) when they are not provided.
 */
export interface Mamba2BlockConfig {
  /** Model (residual stream) width. */
  dModel: number;
  /** N — SSM state size per B/C group (constructor default 16). */
  dState: number;
  /** K — causal conv1d kernel width (constructor default 4). */
  dConv: number;
  /** Expansion factor: dInner = expand * dModel (constructor default 2). */
  expand: number;
  /** H — number of SSM heads; dInner must be divisible by it. */
  nHeads: number;
  /** Number of B/C groups (constructor default 1). */
  nGroups: number;
  /** SSD chunk length (constructor default 256). */
  chunkLen: number;
}
|
|
41
|
+
|
|
42
|
+
/**
 * Cache returned by each Mamba2Block.forward() call.
 *
 * forward() allocates `stateCarry` and hands it back without destroying it,
 * so the caller owns the buffer.
 */
export interface Mamba2Cache {
  stateCarry : GPUBuffer; // inter-chunk SSD states, laid out as [numChunks+1, B, H, N, dHead] f32
}
|
|
45
|
+
|
|
46
|
+
/**
 * WGSL element-wise add used for the block's residual connection:
 * c[i] = a[i] + b[i] for every i < n.
 *
 * Bindings: 0 = a (read), 1 = b (read), 2 = c (read_write output),
 * 3 = n (uniform u32 element count). One thread per element with
 * @workgroup_size(256); threads with i >= n do nothing.
 */
const ADD_SHADER = /* wgsl */`
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@group(0) @binding(3) var<uniform> n : u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i < n) { c[i] = a[i] + b[i]; }
}
`;
|
|
57
|
+
|
|
58
|
+
/**
 * Mamba-2 mixer block (Structured State Space Duality).
 *
 * Forward pipeline:
 *   pre-RMSNorm → fused in_proj → causal conv1d over (x, B, C)
 *   → chunked SSD scan → inner RMSNorm → out_proj → residual add.
 *
 * Each forward() call allocates its own intermediate GPU buffers, including
 * per-stage uniform buffers. Each one is destroyed as soon as the stage that
 * consumes it has been dispatched. Only `output` and the cache's
 * `stateCarry` outlive the call, and both are owned by the caller.
 */
export class Mamba2Block implements SequenceLayer {
  readonly layerType = 'mamba2' as const;

  device: GPUDevice;
  /** Configuration with constructor defaults applied. */
  config: Required<Mamba2BlockConfig>;
  /** Inner width: expand * dModel. */
  dInner: number;
  /** Per-head width: dInner / nHeads. */
  dHead: number;

  /** Weight buffers keyed by parameter name (see parameters()). */
  gpuWeights: Record<string, GPUBuffer>;
  /** Compute pipelines keyed by stage name (see _buildPipelines()). */
  pipelines: Record<string, GPUComputePipeline>;

  /** When true, getTrainableParams() reports only the B/C part of wInProj. */
  private _wslaMode = false;

  /**
   * @param device GPU device that owns every buffer and pipeline of this block.
   * @param config Block hyper-parameters. dState, dConv, expand, nGroups and
   *   chunkLen default to 16, 4, 2, 1 and 256 when undefined.
   * @throws Error if nHeads or chunkLen is not a positive integer, or if
   *   dInner (expand * dModel) is not divisible by nHeads.
   */
  constructor(device: GPUDevice, config: Mamba2BlockConfig) {
    this.device = device;
    // `??` (instead of spreading a defaults object first) also fills fields
    // that are present on `config` but explicitly `undefined`.
    this.config = {
      ...config,
      dState: config.dState ?? 16,
      dConv: config.dConv ?? 4,
      expand: config.expand ?? 2,
      nGroups: config.nGroups ?? 1,
      chunkLen: config.chunkLen ?? 256,
    };

    const { dModel, expand, nHeads, chunkLen } = this.config;
    if (!Number.isInteger(nHeads) || nHeads <= 0) {
      throw new Error(`Mamba2Block: nHeads must be a positive integer (got ${nHeads}).`);
    }
    if (!Number.isInteger(chunkLen) || chunkLen <= 0) {
      throw new Error(`Mamba2Block: chunkLen must be a positive integer (got ${chunkLen}).`);
    }

    this.dInner = expand * dModel;
    this.dHead = this.dInner / nHeads;

    if (this.dInner % nHeads !== 0) {
      throw new Error(
        `Mamba2Block: dInner (${this.dInner}) must be divisible by nHeads (${nHeads}).`
      );
    }

    this.gpuWeights = {};
    this.pipelines = {};

    this._initWeights();
    this._buildPipelines();
  }

  /**
   * Allocates and uploads the initial weights.
   *
   * Projections and conv weights are drawn with gaussianArray (std 0.02,
   * or 0.01 for wConv). Biases start at zero. D_vec and both norm weights
   * start at one. A_log is log(1) = 0 for every head.
   */
  private _initWeights(): void {
    const { dModel, dState, dConv, nHeads, nGroups } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const H = nHeads;
    const G = nGroups;

    const randn = (n: number, std = 0.02): Float32Array => gaussianArray(n, std);
    const zeros = (n: number) => new Float32Array(n);
    const ones = (n: number) => new Float32Array(n).fill(1.0);
    const mk = (arr: Float32Array) => createStorageBuffer(this.device, arr, true);

    // wInProj: (D_inner + 2*n_groups*N + H, D_model) — no bias per Mamba-2 spec
    const inProjRows = D + 2 * G * N + H;

    this.gpuWeights = {
      wInProj      : mk(randn(inProjRows * dModel)),
      wConv        : mk(randn((D + 2 * G * N) * K, 0.01)),
      bConv        : mk(zeros(D + 2 * G * N)),
      A_log        : mk(new Float32Array(H).fill(Math.log(1.0))),
      dt_bias      : mk(zeros(H)),
      D_vec        : mk(ones(H)),
      wOutProj     : mk(randn(dModel * D, 0.02)),
      normWeight   : mk(ones(D)),        // inner RMSNorm
      preNormWeight: mk(ones(dModel)),   // pre-block RMSNorm
    };
  }

  /** Compiles the compute pipelines used by forward(). */
  private _buildPipelines(): void {
    const d = this.device;
    this.pipelines = {
      linear : createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
      conv1d : createComputePipeline(d, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
      rmsnorm: createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
      ssd_fwd: createComputePipeline(d, SSD_FORWARD_WGSL, 'ssd_chunk_forward'),
      elAdd  : createComputePipeline(d, ADD_SHADER, 'main'),
    };
  }

  /**
   * Packs the rmsnorm uniform into 16 bytes: two u32 values (row count,
   * row width) followed by an f32 epsilon (1e-6) at byte offset 8.
   */
  private static _rmsNormParams(rows: number, width: number): ArrayBuffer {
    const params = new ArrayBuffer(16);
    new Uint32Array(params, 0, 2).set([rows, width]);
    new Float32Array(params, 8, 1).set([1e-6]);
    return params;
  }

  /** Returns a weight buffer, failing clearly if it is missing (e.g. after destroy()). */
  private _weight(name: string): GPUBuffer {
    const buf = this.gpuWeights[name];
    if (buf === undefined) {
      throw new Error(`Mamba2Block: weight '${name}' is not available (was destroy() called?).`);
    }
    return buf;
  }

  /**
   * Uploads `params` as a uniform buffer, binds the buffers returned by
   * `bindings` (which decides where the uniform goes), dispatches the named
   * pipeline, and always destroys the uniform afterwards.
   *
   * Destroying right after dispatchKernel follows the same lifetime rule
   * this block already uses for its other temporaries.
   */
  private _dispatch(
    pipelineName: string,
    params: Parameters<typeof createUniformBuffer>[1],
    bindings: (uniform: ReturnType<typeof createUniformBuffer>) => Parameters<typeof createBindGroup>[2],
    workgroups: Parameters<typeof dispatchKernel>[3],
  ): void {
    const pipeline = this.pipelines[pipelineName];
    if (pipeline === undefined) {
      throw new Error(`Mamba2Block: unknown pipeline '${pipelineName}'.`);
    }
    const uniform = createUniformBuffer(this.device, params);
    try {
      const bg = createBindGroup(this.device, pipeline, bindings(uniform));
      dispatchKernel(this.device, pipeline, bg, workgroups);
    } finally {
      uniform.destroy();
    }
  }

  /**
   * Runs the block on batch * seqLen tokens of dModel f32 features.
   *
   * @param xBuf   Input activations, also used as the residual. Assumed to
   *               hold batch * seqLen * dModel f32 values (layout: TODO confirm).
   * @param batch  Batch size B (positive integer).
   * @param seqLen Sequence length L (positive integer).
   * @returns `output` (batch * seqLen * dModel f32 values, residual included)
   *          and a Mamba2Cache whose `stateCarry` holds
   *          (numChunks + 1) * B * H * dState * dHead f32 values. The caller
   *          owns both buffers.
   * @throws Error if batch or seqLen is not a positive integer.
   */
  forward(xBuf: GPUBuffer, batch: number, seqLen: number): LayerForwardResult {
    if (!Number.isInteger(batch) || batch <= 0 || !Number.isInteger(seqLen) || seqLen <= 0) {
      throw new Error(
        `Mamba2Block: batch (${batch}) and seqLen (${seqLen}) must be positive integers.`
      );
    }

    const d = this.device;
    const { dModel, dState, dConv, nHeads, nGroups, chunkLen } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const H = nHeads;
    const G = nGroups;
    const dh = this.dHead;
    const B = batch;
    const L = seqLen;
    const M = B * L;                   // total tokens
    const bcWidth = G * N;             // f32 values per token for B_proj (and for C_proj)
    const convD = D + 2 * bcWidth;     // conv channels: x, B_proj, C_proj
    const inProjRows = convD + H;      // in_proj outputs: x, B_proj, C_proj, dt
    const numChunks = Math.ceil(L / chunkLen);

    // 1. Pre-block RMSNorm. The per-row inverse-RMS output (normInv) is not
    //    read by this forward pass.
    const normOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    const normInv = createEmptyStorageBuffer(d, M * 4, true);
    this._dispatch(
      'rmsnorm',
      Mamba2Block._rmsNormParams(M, dModel),
      (u) => [u, xBuf, this._weight('preNormWeight'), normOut, normInv],
      [cdiv(M, 64), 1, 1],
    );
    normInv.destroy();

    // 2. Fused in_proj → [x (D), B_proj (G*N), C_proj (G*N), dt (H)]
    const inProjOut = createEmptyStorageBuffer(d, M * inProjRows * 4, true);
    {
      // wInProj has no bias, but the shared linear kernel binds one: pass zeros.
      const zeroBias = createStorageBuffer(d, new Float32Array(inProjRows), true);
      this._dispatch(
        'linear',
        new Uint32Array([M, dModel, inProjRows]).buffer,
        (u) => [u, normOut, this._weight('wInProj'), zeroBias, inProjOut],
        [cdiv(M, 16), cdiv(inProjRows, 16), 1],
      );
      zeroBias.destroy();
    }
    normOut.destroy();

    // Split: xConv [D+2GN], dt [H].
    // NOTE(review): these contiguous copies assume inProjOut stores all M rows
    // of the first convD features first, then all M rows of dt. If the linear
    // kernel writes row-major (M, inProjRows), a strided per-row split is
    // needed instead — confirm against linear_projection.ts.
    const xConvBuf = createEmptyStorageBuffer(d, M * convD * 4, true);
    const dtBuf = createEmptyStorageBuffer(d, M * H * 4, true);
    {
      const enc = d.createCommandEncoder();
      enc.copyBufferToBuffer(inProjOut, 0, xConvBuf, 0, M * convD * 4);
      enc.copyBufferToBuffer(inProjOut, M * convD * 4, dtBuf, 0, M * H * 4);
      d.queue.submit([enc.finish()]);
    }
    inProjOut.destroy();

    // 3. Causal conv1d over x + B_proj + C_proj (fused, convD channels).
    //    Uniform is [L, convD, K, B, 1]; the meaning of the trailing 1 is
    //    defined by conv1d.ts (TODO confirm).
    const convOut = createEmptyStorageBuffer(d, M * convD * 4, true);
    this._dispatch(
      'conv1d',
      new Uint32Array([L, convD, K, B, 1]).buffer,
      (u) => [u, xConvBuf, this._weight('wConv'), this._weight('bConv'), convOut],
      [cdiv(L, 16), cdiv(convD, 16), B],
    );
    xConvBuf.destroy();

    // Split conv output: x [D], B_proj [G*N], C_proj [G*N] (same layout caveat as above).
    const xSsdBuf = createEmptyStorageBuffer(d, M * D * 4, true);
    const bProjBuf = createEmptyStorageBuffer(d, M * bcWidth * 4, true);
    const cProjBuf = createEmptyStorageBuffer(d, M * bcWidth * 4, true);
    {
      const enc = d.createCommandEncoder();
      enc.copyBufferToBuffer(convOut, 0, xSsdBuf, 0, M * D * 4);
      enc.copyBufferToBuffer(convOut, M * D * 4, bProjBuf, 0, M * bcWidth * 4);
      enc.copyBufferToBuffer(convOut, M * (D + bcWidth) * 4, cProjBuf, 0, M * bcWidth * 4);
      d.queue.submit([enc.finish()]);
    }
    convOut.destroy();

    // 4. SSD scan — one workgroup per (chunk, head, batch).
    //    state_carry: [numChunks+1, B, H, N, dHead]
    const stateCarry = createEmptyStorageBuffer(
      d, (numChunks + 1) * B * H * N * dh * 4, true);
    const ssdOut = createEmptyStorageBuffer(d, M * D * 4, true);
    this._dispatch(
      'ssd_fwd',
      new Uint32Array([L, D, H, dh, G, N, chunkLen, numChunks, B]).buffer,
      (u) => [
        u, xSsdBuf, bProjBuf, cProjBuf, dtBuf,
        this._weight('A_log'), this._weight('dt_bias'),
        this._weight('D_vec'), ssdOut, stateCarry,
      ],
      [numChunks, H, B],
    );
    xSsdBuf.destroy();
    bProjBuf.destroy();
    cProjBuf.destroy();
    dtBuf.destroy();

    // 5. Inner RMSNorm on scan output.
    const innerNormOut = createEmptyStorageBuffer(d, M * D * 4, true);
    const innerNormInv = createEmptyStorageBuffer(d, M * 4, true);
    this._dispatch(
      'rmsnorm',
      Mamba2Block._rmsNormParams(M, D),
      (u) => [u, ssdOut, this._weight('normWeight'), innerNormOut, innerNormInv],
      [cdiv(M, 64), 1, 1],
    );
    ssdOut.destroy();
    innerNormInv.destroy();

    // 6. Output projection (no bias; zeros bound for the shared kernel).
    const outProjOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    {
      const zeroBias = createStorageBuffer(d, new Float32Array(dModel), true);
      this._dispatch(
        'linear',
        new Uint32Array([M, D, dModel]).buffer,
        (u) => [u, innerNormOut, this._weight('wOutProj'), zeroBias, outProjOut],
        [cdiv(M, 16), cdiv(dModel, 16), 1],
      );
      zeroBias.destroy();
    }
    innerNormOut.destroy();

    // 7. Residual add: output = outProjOut + xBuf. The add shader takes its
    //    element-count uniform as the LAST binding.
    const output = createEmptyStorageBuffer(d, M * dModel * 4, true);
    this._dispatch(
      'elAdd',
      new Uint32Array([M * dModel]).buffer,
      (u) => [outProjOut, xBuf, output, u],
      [cdiv(M * dModel, 256), 1, 1],
    );
    outProjOut.destroy();

    const cache: Mamba2Cache = { stateCarry };
    return { output, cache };
  }

  /** All weight tensors with their element counts, in a fixed order. */
  parameters(): LayerParam[] {
    const { dModel, dState, dConv, nHeads, nGroups } = this.config;
    const D = this.dInner;
    const N = dState;
    const K = dConv;
    const H = nHeads;
    const G = nGroups;
    const convD = D + 2 * G * N;

    return [
      { buf: this.gpuWeights['wInProj']!,       numel: (D + 2 * G * N + H) * dModel, name: 'wInProj' },
      { buf: this.gpuWeights['wConv']!,         numel: convD * K,                    name: 'wConv' },
      { buf: this.gpuWeights['bConv']!,         numel: convD,                        name: 'bConv' },
      { buf: this.gpuWeights['A_log']!,         numel: H,                            name: 'A_log' },
      { buf: this.gpuWeights['dt_bias']!,       numel: H,                            name: 'dt_bias' },
      { buf: this.gpuWeights['D_vec']!,         numel: H,                            name: 'D_vec' },
      { buf: this.gpuWeights['wOutProj']!,      numel: dModel * D,                   name: 'wOutProj' },
      { buf: this.gpuWeights['normWeight']!,    numel: D,                            name: 'normWeight' },
      { buf: this.gpuWeights['preNormWeight']!, numel: dModel,                       name: 'preNormWeight' },
    ];
  }

  /**
   * Parameters the trainer should update. In WSLA mode only the B/C
   * projection part of wInProj is reported (2 * nGroups * dState * dModel
   * elements); otherwise every parameter is returned.
   *
   * NOTE(review): in WSLA mode `buf` is the whole wInProj buffer and no
   * offset is reported, while the B/C outputs come after the first dInner x
   * outputs. Confirm how the trainer interprets `numel` here; if it starts
   * at byte 0, it updates x rows instead of B/C rows.
   */
  getTrainableParams(): LayerParam[] {
    if (this._wslaMode) {
      // WSLA: train only B/C rows of wInProj (the selective projection part)
      return [
        { buf: this.gpuWeights['wInProj']!,
          numel: (this.config.nGroups * this.config.dState * 2) * this.config.dModel,
          name: 'wInProj_BC' },
      ];
    }
    return this.parameters();
  }

  /** Toggles WSLA mode, which only affects getTrainableParams(). */
  setWSLAMode(enabled: boolean): void {
    this._wslaMode = enabled;
  }

  /** Destroys all weight buffers and clears the map; pipelines are kept. */
  destroy(): void {
    for (const buf of Object.values(this.gpuWeights)) buf.destroy();
    this.gpuWeights = {};
  }
}
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* mamba3_block.ts – Mamba-3 Mixer Block (Complex-valued MIMO SSM, inference-first).
|
|
3
|
+
*
|
|
4
|
+
* Three improvements over Mamba-2:
|
|
5
|
+
* 1. Complex-valued states — h ∈ ℂ^(N_c) with N_c = dState, stored as interleaved f32 pairs (2·N_c real values)
|
|
6
|
+
* 2. MIMO recurrence — mimoGroup×mimoGroup block recurrence per head (default mimoGroup = 1, i.e. SISO)
|
|
7
|
+
* 3. ET discretisation — B_bar = (A_bar − 1)·A⁻¹·B (exact, not approx)
|
|
8
|
+
*
|
|
9
|
+
* Weight shapes vs Mamba-2 (same 9 tensors; A_log becomes (H, 2) and the B/C widths double to hold complex pairs; G = nGroups):
|
|
10
|
+
* wInProj : (D + 2*G*N_c*2 + H, dModel) where N_c = dState (complex count)
|
|
11
|
+
* wConv : (D + 2*G*N_c*2, K)
|
|
12
|
+
* bConv : (D + 2*G*N_c*2,)
|
|
13
|
+
* A_log : (H, 2) ← [log|A|, arg(A)] per head
|
|
14
|
+
* dt_bias : (H,)
|
|
15
|
+
* D_vec : (H,)
|
|
16
|
+
* wOutProj : (dModel, D)
|
|
17
|
+
* normWeight : (D,)
|
|
18
|
+
* preNormWeight: (dModel,)
|
|
19
|
+
*
|
|
20
|
+
* Implements SequenceLayer.
|
|
21
|
+
*/
|
|
22
|
+
|
|
23
|
+
import {
|
|
24
|
+
createComputePipeline,
|
|
25
|
+
createBindGroup,
|
|
26
|
+
createStorageBuffer,
|
|
27
|
+
createEmptyStorageBuffer,
|
|
28
|
+
createUniformBuffer,
|
|
29
|
+
dispatchKernel,
|
|
30
|
+
cdiv,
|
|
31
|
+
} from '../utils/gpu_utils.js';
|
|
32
|
+
|
|
33
|
+
import { COMPLEX_SSD_FORWARD_WGSL } from '../kernels/complex_ssd.js';
|
|
34
|
+
import { gaussianArray } from '../utils/rng.js';
|
|
35
|
+
import { CONV1D_FORWARD_WGSL } from '../kernels/conv1d.js';
|
|
36
|
+
import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
|
|
37
|
+
import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
|
|
38
|
+
|
|
39
|
+
import type { Mamba2BlockConfig } from './mamba2_block.js';
|
|
40
|
+
import type { SequenceLayer, LayerForwardResult, LayerParam } from './sequence_layer.js';
|
|
41
|
+
|
|
42
|
+
/**
 * Hyper-parameters for a Mamba-3 mixer block.
 *
 * Same fields as Mamba2BlockConfig, except that dState counts complex states
 * and an optional MIMO group size is added.
 */
export interface Mamba3BlockConfig extends Mamba2BlockConfig {
  /**
   * N_c — number of complex states. Each one is stored as two f32 values, so
   * the real state count is 2 * N_c.
   */
  dState: number;
  /**
   * MIMO group size G. Default 1 = SISO (same as Mamba-2).
   * Mamba3Block currently resolves this value but does not pass it to its
   * complex SSD kernel.
   */
  mimoGroup?: number;
}
|
|
47
|
+
|
|
48
|
+
/**
 * Cache returned by each Mamba3Block.forward() call.
 *
 * forward() allocates `stateCarry` and hands it back without destroying it,
 * so the caller owns the buffer.
 */
export interface Mamba3Cache {
  stateCarry: GPUBuffer; // inter-chunk complex states, laid out as [numChunks+1, B, H, Nc*2, dHead] f32
}
|
|
51
|
+
|
|
52
|
+
/**
 * WGSL element-wise add used for the block's residual connection:
 * c[i] = a[i] + b[i] for every i < n.
 *
 * Bindings: 0 = a (read), 1 = b (read), 2 = c (read_write output),
 * 3 = n (uniform u32 element count). One thread per element with
 * @workgroup_size(256); threads with i >= n do nothing.
 */
const ADD_SHADER = /* wgsl */`
@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read_write> c : array<f32>;
@group(0) @binding(3) var<uniform> n : u32;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i < n) { c[i] = a[i] + b[i]; }
}
`;
|
|
63
|
+
|
|
64
|
+
/**
 * Mamba-3 mixer block (complex-valued SSM, inference-first).
 *
 * Forward pipeline:
 *   pre-RMSNorm → fused in_proj → causal conv1d over (x, B, C)
 *   → complex chunked SSD scan → inner RMSNorm → out_proj → residual add.
 *
 * B/C projections carry nComplex complex values per group, each stored as two
 * f32 values. config.mimoGroup is resolved (default 1) but is not currently
 * passed to the complex SSD kernel.
 *
 * Each forward() call allocates its own intermediate GPU buffers, including
 * per-stage uniform buffers. Each one is destroyed as soon as the stage that
 * consumes it has been dispatched. Only `output` and the cache's
 * `stateCarry` outlive the call, and both are owned by the caller.
 */
export class Mamba3Block implements SequenceLayer {
  readonly layerType = 'mamba3' as const;

  device: GPUDevice;
  /** Configuration with constructor defaults applied. */
  config: Required<Mamba3BlockConfig>;
  /** Inner width: expand * dModel. */
  dInner: number;
  /** Per-head width: dInner / nHeads. */
  dHead: number;
  /** Complex state count per head (N_c = dState in config). */
  nComplex: number;

  /** Weight buffers keyed by parameter name (see parameters()). */
  gpuWeights: Record<string, GPUBuffer>;
  /** Compute pipelines keyed by stage name (see _buildPipelines()). */
  pipelines: Record<string, GPUComputePipeline>;

  /** When true, getTrainableParams() reports only the B/C part of wInProj. */
  private _wslaMode = false;

  /**
   * @param device GPU device that owns every buffer and pipeline of this block.
   * @param config Block hyper-parameters. dState, dConv, expand, nGroups,
   *   chunkLen and mimoGroup default to 16, 4, 2, 1, 256 and 1 when undefined.
   * @throws Error if nHeads or chunkLen is not a positive integer, or if
   *   dInner (expand * dModel) is not divisible by nHeads.
   */
  constructor(device: GPUDevice, config: Mamba3BlockConfig) {
    this.device = device;
    // `??` (instead of spreading a defaults object first) also fills fields
    // that are present on `config` but explicitly `undefined`.
    this.config = {
      ...config,
      dState: config.dState ?? 16,
      dConv: config.dConv ?? 4,
      expand: config.expand ?? 2,
      nGroups: config.nGroups ?? 1,
      chunkLen: config.chunkLen ?? 256,
      mimoGroup: config.mimoGroup ?? 1,
    };

    const { dModel, expand, nHeads, chunkLen } = this.config;
    if (!Number.isInteger(nHeads) || nHeads <= 0) {
      throw new Error(`Mamba3Block: nHeads must be a positive integer (got ${nHeads}).`);
    }
    if (!Number.isInteger(chunkLen) || chunkLen <= 0) {
      throw new Error(`Mamba3Block: chunkLen must be a positive integer (got ${chunkLen}).`);
    }

    this.dInner = expand * dModel;
    this.dHead = this.dInner / nHeads;
    this.nComplex = this.config.dState; // N_c

    if (this.dInner % nHeads !== 0) {
      throw new Error(
        `Mamba3Block: dInner (${this.dInner}) must be divisible by nHeads (${nHeads}).`
      );
    }

    this.gpuWeights = {};
    this.pipelines = {};

    this._initWeights();
    this._buildPipelines();
  }

  /**
   * Allocates and uploads the initial weights.
   *
   * Projections and conv weights are drawn with gaussianArray (std 0.02,
   * or 0.01 for wConv). Biases start at zero. D_vec and both norm weights
   * start at one. A_log is initialised per head as described below.
   */
  private _initWeights(): void {
    const { dModel, dConv, nHeads, nGroups } = this.config;
    const D = this.dInner;
    const Nc = this.nComplex;
    const K = dConv;
    const H = nHeads;
    const G = nGroups;
    // Each complex state = 2 f32 values
    const convD = D + 2 * G * Nc * 2; // x-channels + complex B/C
    const inProjRows = convD + H;     // = D + 2*G*Nc*2 + H

    const randn = (n: number, std = 0.02): Float32Array => gaussianArray(n, std);
    const zeros = (n: number) => new Float32Array(n);
    const ones = (n: number) => new Float32Array(n).fill(1.0);
    const mk = (arr: Float32Array) => createStorageBuffer(this.device, arr, true);

    // A_log: (H, 2) = [log|A|, arg(A)] per head.
    // Every head starts at log|A| = 0 (|A| = 1). Phases are spread evenly over
    // [0, 2π) across heads: head h gets 2πh/H, so only head 0 has phase 0.
    // How these map to the discrete-time decay is defined by complex_ssd.ts
    // (TODO confirm).
    const A_log = new Float32Array(H * 2);
    for (let h = 0; h < H; h++) {
      A_log[h * 2 + 0] = 0.0;                    // log|A| = 0 → |A| = 1
      A_log[h * 2 + 1] = (2 * Math.PI * h) / H;  // evenly spaced phases
    }

    this.gpuWeights = {
      wInProj      : mk(randn(inProjRows * dModel)),
      wConv        : mk(randn(convD * K, 0.01)),
      bConv        : mk(zeros(convD)),
      A_log        : mk(A_log),
      dt_bias      : mk(zeros(H)),
      D_vec        : mk(ones(H)),
      wOutProj     : mk(randn(dModel * D, 0.02)),
      normWeight   : mk(ones(D)),
      preNormWeight: mk(ones(dModel)),
    };
  }

  /** Compiles the compute pipelines used by forward(). */
  private _buildPipelines(): void {
    const d = this.device;
    this.pipelines = {
      linear  : createComputePipeline(d, LINEAR_FORWARD_WGSL, 'linear_forward'),
      conv1d  : createComputePipeline(d, CONV1D_FORWARD_WGSL, 'conv1d_forward'),
      rmsnorm : createComputePipeline(d, ACTIVATIONS_WGSL, 'rmsnorm_forward'),
      cssd_fwd: createComputePipeline(d, COMPLEX_SSD_FORWARD_WGSL, 'complex_ssd_forward'),
      elAdd   : createComputePipeline(d, ADD_SHADER, 'main'),
    };
  }

  /**
   * Packs the rmsnorm uniform into 16 bytes: two u32 values (row count,
   * row width) followed by an f32 epsilon (1e-6) at byte offset 8.
   */
  private static _rmsNormParams(rows: number, width: number): ArrayBuffer {
    const params = new ArrayBuffer(16);
    new Uint32Array(params, 0, 2).set([rows, width]);
    new Float32Array(params, 8, 1).set([1e-6]);
    return params;
  }

  /** Returns a weight buffer, failing clearly if it is missing (e.g. after destroy()). */
  private _weight(name: string): GPUBuffer {
    const buf = this.gpuWeights[name];
    if (buf === undefined) {
      throw new Error(`Mamba3Block: weight '${name}' is not available (was destroy() called?).`);
    }
    return buf;
  }

  /**
   * Uploads `params` as a uniform buffer, binds the buffers returned by
   * `bindings` (which decides where the uniform goes), dispatches the named
   * pipeline, and always destroys the uniform afterwards.
   *
   * Destroying right after dispatchKernel follows the same lifetime rule
   * this block already uses for its other temporaries.
   */
  private _dispatch(
    pipelineName: string,
    params: Parameters<typeof createUniformBuffer>[1],
    bindings: (uniform: ReturnType<typeof createUniformBuffer>) => Parameters<typeof createBindGroup>[2],
    workgroups: Parameters<typeof dispatchKernel>[3],
  ): void {
    const pipeline = this.pipelines[pipelineName];
    if (pipeline === undefined) {
      throw new Error(`Mamba3Block: unknown pipeline '${pipelineName}'.`);
    }
    const uniform = createUniformBuffer(this.device, params);
    try {
      const bg = createBindGroup(this.device, pipeline, bindings(uniform));
      dispatchKernel(this.device, pipeline, bg, workgroups);
    } finally {
      uniform.destroy();
    }
  }

  /**
   * Runs the block on batch * seqLen tokens of dModel f32 features.
   *
   * @param xBuf   Input activations, also used as the residual. Assumed to
   *               hold batch * seqLen * dModel f32 values (layout: TODO confirm).
   * @param batch  Batch size B (positive integer).
   * @param seqLen Sequence length L (positive integer).
   * @returns `output` (batch * seqLen * dModel f32 values, residual included)
   *          and a Mamba3Cache whose `stateCarry` holds
   *          (numChunks + 1) * B * H * nComplex * 2 * dHead f32 values. The
   *          caller owns both buffers.
   * @throws Error if batch or seqLen is not a positive integer.
   */
  forward(xBuf: GPUBuffer, batch: number, seqLen: number): LayerForwardResult {
    if (!Number.isInteger(batch) || batch <= 0 || !Number.isInteger(seqLen) || seqLen <= 0) {
      throw new Error(
        `Mamba3Block: batch (${batch}) and seqLen (${seqLen}) must be positive integers.`
      );
    }

    const d = this.device;
    const { dModel, dConv, nHeads, nGroups, chunkLen } = this.config;
    const D = this.dInner;
    const Nc = this.nComplex;
    const K = dConv;
    const H = nHeads;
    const G = nGroups;
    const dh = this.dHead;
    const B = batch;
    const L = seqLen;
    const M = B * L;                   // total tokens
    const bcWidth = G * Nc * 2;        // f32 values per token for B_proj (and for C_proj)
    const convD = D + 2 * bcWidth;     // conv channels: x, complex B_proj, complex C_proj
    const inProjRows = convD + H;      // in_proj outputs: x, B_proj, C_proj, dt
    const numChunks = Math.ceil(L / chunkLen);

    // 1. Pre-block RMSNorm. The per-row inverse-RMS output (normInv) is not
    //    read by this forward pass.
    const normOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    const normInv = createEmptyStorageBuffer(d, M * 4, true);
    this._dispatch(
      'rmsnorm',
      Mamba3Block._rmsNormParams(M, dModel),
      (u) => [u, xBuf, this._weight('preNormWeight'), normOut, normInv],
      [cdiv(M, 64), 1, 1],
    );
    normInv.destroy();

    // 2. Fused in_proj → [x (D), B_proj (G*Nc*2), C_proj (G*Nc*2), dt (H)]
    const inProjOut = createEmptyStorageBuffer(d, M * inProjRows * 4, true);
    {
      // wInProj has no bias, but the shared linear kernel binds one: pass zeros.
      const zeroBias = createStorageBuffer(d, new Float32Array(inProjRows), true);
      this._dispatch(
        'linear',
        new Uint32Array([M, dModel, inProjRows]).buffer,
        (u) => [u, normOut, this._weight('wInProj'), zeroBias, inProjOut],
        [cdiv(M, 16), cdiv(inProjRows, 16), 1],
      );
      zeroBias.destroy();
    }
    normOut.destroy();

    // Split: xConv [convD], dt [H].
    // NOTE(review): these contiguous copies assume inProjOut stores all M rows
    // of the first convD features first, then all M rows of dt. If the linear
    // kernel writes row-major (M, inProjRows), a strided per-row split is
    // needed instead — confirm against linear_projection.ts.
    const xConvBuf = createEmptyStorageBuffer(d, M * convD * 4, true);
    const dtBuf = createEmptyStorageBuffer(d, M * H * 4, true);
    {
      const enc = d.createCommandEncoder();
      enc.copyBufferToBuffer(inProjOut, 0, xConvBuf, 0, M * convD * 4);
      enc.copyBufferToBuffer(inProjOut, M * convD * 4, dtBuf, 0, M * H * 4);
      d.queue.submit([enc.finish()]);
    }
    inProjOut.destroy();

    // 3. Causal conv1d (fused convD channels). Uniform is [L, convD, K, B, 1];
    //    the meaning of the trailing 1 is defined by conv1d.ts (TODO confirm).
    const convOut = createEmptyStorageBuffer(d, M * convD * 4, true);
    this._dispatch(
      'conv1d',
      new Uint32Array([L, convD, K, B, 1]).buffer,
      (u) => [u, xConvBuf, this._weight('wConv'), this._weight('bConv'), convOut],
      [cdiv(L, 16), cdiv(convD, 16), B],
    );
    xConvBuf.destroy();

    // Split: xSsd [D], B_proj_complex [G*Nc*2], C_proj_complex [G*Nc*2]
    // (same layout caveat as above).
    const xSsdBuf = createEmptyStorageBuffer(d, M * D * 4, true);
    const bProjBuf = createEmptyStorageBuffer(d, M * bcWidth * 4, true);
    const cProjBuf = createEmptyStorageBuffer(d, M * bcWidth * 4, true);
    {
      const enc = d.createCommandEncoder();
      enc.copyBufferToBuffer(convOut, 0, xSsdBuf, 0, M * D * 4);
      enc.copyBufferToBuffer(convOut, M * D * 4, bProjBuf, 0, M * bcWidth * 4);
      enc.copyBufferToBuffer(convOut, M * (D + bcWidth) * 4, cProjBuf, 0, M * bcWidth * 4);
      d.queue.submit([enc.finish()]);
    }
    convOut.destroy();

    // 4. Complex SSD scan — one workgroup per (chunk, head, batch).
    //    state_carry: [numChunks+1, B, H, Nc*2, dHead]
    const stateCarry = createEmptyStorageBuffer(
      d, (numChunks + 1) * B * H * Nc * 2 * dh * 4, true);
    const cssdOut = createEmptyStorageBuffer(d, M * D * 4, true);
    this._dispatch(
      'cssd_fwd',
      new Uint32Array([L, D, H, dh, G, Nc, chunkLen, numChunks, B]).buffer,
      (u) => [
        u, xSsdBuf, bProjBuf, cProjBuf, dtBuf,
        this._weight('A_log'), this._weight('dt_bias'),
        this._weight('D_vec'), cssdOut, stateCarry,
      ],
      [numChunks, H, B],
    );
    xSsdBuf.destroy();
    bProjBuf.destroy();
    cProjBuf.destroy();
    dtBuf.destroy();

    // 5. Inner RMSNorm on scan output.
    const innerNormOut = createEmptyStorageBuffer(d, M * D * 4, true);
    const innerNormInv = createEmptyStorageBuffer(d, M * 4, true);
    this._dispatch(
      'rmsnorm',
      Mamba3Block._rmsNormParams(M, D),
      (u) => [u, cssdOut, this._weight('normWeight'), innerNormOut, innerNormInv],
      [cdiv(M, 64), 1, 1],
    );
    cssdOut.destroy();
    innerNormInv.destroy();

    // 6. Output projection (no bias; zeros bound for the shared kernel).
    const outProjOut = createEmptyStorageBuffer(d, M * dModel * 4, true);
    {
      const zeroBias = createStorageBuffer(d, new Float32Array(dModel), true);
      this._dispatch(
        'linear',
        new Uint32Array([M, D, dModel]).buffer,
        (u) => [u, innerNormOut, this._weight('wOutProj'), zeroBias, outProjOut],
        [cdiv(M, 16), cdiv(dModel, 16), 1],
      );
      zeroBias.destroy();
    }
    innerNormOut.destroy();

    // 7. Residual add: output = outProjOut + xBuf. The add shader takes its
    //    element-count uniform as the LAST binding.
    const output = createEmptyStorageBuffer(d, M * dModel * 4, true);
    this._dispatch(
      'elAdd',
      new Uint32Array([M * dModel]).buffer,
      (u) => [outProjOut, xBuf, output, u],
      [cdiv(M * dModel, 256), 1, 1],
    );
    outProjOut.destroy();

    const cache: Mamba3Cache = { stateCarry };
    return { output, cache };
  }

  /** All weight tensors with their element counts, in a fixed order. */
  parameters(): LayerParam[] {
    const { dModel, dConv, nHeads, nGroups } = this.config;
    const D = this.dInner;
    const Nc = this.nComplex;
    const K = dConv;
    const H = nHeads;
    const G = nGroups;
    const convD = D + 2 * G * Nc * 2;

    return [
      { buf: this.gpuWeights['wInProj']!,       numel: (D + 2 * G * Nc * 2 + H) * dModel, name: 'wInProj' },
      { buf: this.gpuWeights['wConv']!,         numel: convD * K,                         name: 'wConv' },
      { buf: this.gpuWeights['bConv']!,         numel: convD,                             name: 'bConv' },
      { buf: this.gpuWeights['A_log']!,         numel: H * 2,                             name: 'A_log' },
      { buf: this.gpuWeights['dt_bias']!,       numel: H,                                 name: 'dt_bias' },
      { buf: this.gpuWeights['D_vec']!,         numel: H,                                 name: 'D_vec' },
      { buf: this.gpuWeights['wOutProj']!,      numel: dModel * D,                        name: 'wOutProj' },
      { buf: this.gpuWeights['normWeight']!,    numel: D,                                 name: 'normWeight' },
      { buf: this.gpuWeights['preNormWeight']!, numel: dModel,                            name: 'preNormWeight' },
    ];
  }

  /**
   * Parameters the trainer should update. In WSLA mode only the complex B/C
   * projection part of wInProj is reported (nGroups * nComplex * 2 * 2 * dModel
   * elements); otherwise every parameter is returned.
   *
   * NOTE(review): in WSLA mode `buf` is the whole wInProj buffer and no
   * offset is reported, while the B/C outputs come after the first dInner x
   * outputs. Confirm how the trainer interprets `numel` here; if it starts
   * at byte 0, it updates x rows instead of B/C rows.
   */
  getTrainableParams(): LayerParam[] {
    if (this._wslaMode) {
      return [
        { buf: this.gpuWeights['wInProj']!,
          numel: (this.config.nGroups * this.nComplex * 2 * 2) * this.config.dModel,
          name: 'wInProj_BC' },
      ];
    }
    return this.parameters();
  }

  /** Toggles WSLA mode, which only affects getTrainableParams(). */
  setWSLAMode(enabled: boolean): void {
    this._wslaMode = enabled;
  }

  /** Destroys all weight buffers and clears the map; pipelines are kept. */
  destroy(): void {
    for (const buf of Object.values(this.gpuWeights)) buf.destroy();
    this.gpuWeights = {};
  }
}
|