@seanhogg/builderforce-memory-engine 2026.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +393 -0
  3. package/dist/index.d.ts +32 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +40 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/kernels/activations.d.ts +5 -0
  8. package/dist/kernels/activations.d.ts.map +1 -0
  9. package/dist/kernels/activations.js +171 -0
  10. package/dist/kernels/activations.js.map +1 -0
  11. package/dist/kernels/attention.d.ts +19 -0
  12. package/dist/kernels/attention.d.ts.map +1 -0
  13. package/dist/kernels/attention.js +263 -0
  14. package/dist/kernels/attention.js.map +1 -0
  15. package/dist/kernels/complex_ssd.d.ts +33 -0
  16. package/dist/kernels/complex_ssd.d.ts.map +1 -0
  17. package/dist/kernels/complex_ssd.js +305 -0
  18. package/dist/kernels/complex_ssd.js.map +1 -0
  19. package/dist/kernels/conv1d.d.ts +3 -0
  20. package/dist/kernels/conv1d.d.ts.map +1 -0
  21. package/dist/kernels/conv1d.js +158 -0
  22. package/dist/kernels/conv1d.js.map +1 -0
  23. package/dist/kernels/linear_projection.d.ts +3 -0
  24. package/dist/kernels/linear_projection.d.ts.map +1 -0
  25. package/dist/kernels/linear_projection.js +219 -0
  26. package/dist/kernels/linear_projection.js.map +1 -0
  27. package/dist/kernels/selective_scan.d.ts +3 -0
  28. package/dist/kernels/selective_scan.d.ts.map +1 -0
  29. package/dist/kernels/selective_scan.js +348 -0
  30. package/dist/kernels/selective_scan.js.map +1 -0
  31. package/dist/kernels/ssd.d.ts +29 -0
  32. package/dist/kernels/ssd.d.ts.map +1 -0
  33. package/dist/kernels/ssd.js +276 -0
  34. package/dist/kernels/ssd.js.map +1 -0
  35. package/dist/kernels/weight_update.d.ts +3 -0
  36. package/dist/kernels/weight_update.d.ts.map +1 -0
  37. package/dist/kernels/weight_update.js +119 -0
  38. package/dist/kernels/weight_update.js.map +1 -0
  39. package/dist/model/attention_block.d.ts +48 -0
  40. package/dist/model/attention_block.d.ts.map +1 -0
  41. package/dist/model/attention_block.js +262 -0
  42. package/dist/model/attention_block.js.map +1 -0
  43. package/dist/model/mamba1_block.d.ts +70 -0
  44. package/dist/model/mamba1_block.d.ts.map +1 -0
  45. package/dist/model/mamba1_block.js +333 -0
  46. package/dist/model/mamba1_block.js.map +1 -0
  47. package/dist/model/mamba2_block.d.ts +44 -0
  48. package/dist/model/mamba2_block.d.ts.map +1 -0
  49. package/dist/model/mamba2_block.js +252 -0
  50. package/dist/model/mamba2_block.js.map +1 -0
  51. package/dist/model/mamba3_block.d.ts +51 -0
  52. package/dist/model/mamba3_block.d.ts.map +1 -0
  53. package/dist/model/mamba3_block.js +270 -0
  54. package/dist/model/mamba3_block.js.map +1 -0
  55. package/dist/model/mamba_block.d.ts +64 -0
  56. package/dist/model/mamba_block.d.ts.map +1 -0
  57. package/dist/model/mamba_block.js +303 -0
  58. package/dist/model/mamba_block.js.map +1 -0
  59. package/dist/model/mamba_model.d.ts +140 -0
  60. package/dist/model/mamba_model.d.ts.map +1 -0
  61. package/dist/model/mamba_model.js +527 -0
  62. package/dist/model/mamba_model.js.map +1 -0
  63. package/dist/model/sequence_layer.d.ts +25 -0
  64. package/dist/model/sequence_layer.d.ts.map +1 -0
  65. package/dist/model/sequence_layer.js +8 -0
  66. package/dist/model/sequence_layer.js.map +1 -0
  67. package/dist/tokenizer/bpe.d.ts +29 -0
  68. package/dist/tokenizer/bpe.d.ts.map +1 -0
  69. package/dist/tokenizer/bpe.js +164 -0
  70. package/dist/tokenizer/bpe.js.map +1 -0
  71. package/dist/training/autograd.d.ts +27 -0
  72. package/dist/training/autograd.d.ts.map +1 -0
  73. package/dist/training/autograd.js +120 -0
  74. package/dist/training/autograd.js.map +1 -0
  75. package/dist/training/trainer.d.ts +36 -0
  76. package/dist/training/trainer.d.ts.map +1 -0
  77. package/dist/training/trainer.js +183 -0
  78. package/dist/training/trainer.js.map +1 -0
  79. package/dist/utils/gpu_utils.d.ts +21 -0
  80. package/dist/utils/gpu_utils.d.ts.map +1 -0
  81. package/dist/utils/gpu_utils.js +111 -0
  82. package/dist/utils/gpu_utils.js.map +1 -0
  83. package/dist/utils/quantization.d.ts +26 -0
  84. package/dist/utils/quantization.d.ts.map +1 -0
  85. package/dist/utils/quantization.js +116 -0
  86. package/dist/utils/quantization.js.map +1 -0
  87. package/dist/utils/rng.d.ts +36 -0
  88. package/dist/utils/rng.d.ts.map +1 -0
  89. package/dist/utils/rng.js +61 -0
  90. package/dist/utils/rng.js.map +1 -0
  91. package/package.json +99 -0
  92. package/src/index.ts +114 -0
  93. package/src/kernels/activations.ts +174 -0
  94. package/src/kernels/attention.ts +268 -0
  95. package/src/kernels/complex_ssd.ts +307 -0
  96. package/src/kernels/conv1d.ts +159 -0
  97. package/src/kernels/linear_projection.ts +220 -0
  98. package/src/kernels/selective_scan.ts +350 -0
  99. package/src/kernels/ssd.ts +278 -0
  100. package/src/kernels/weight_update.ts +120 -0
  101. package/src/model/attention_block.ts +344 -0
  102. package/src/model/mamba1_block.ts +437 -0
  103. package/src/model/mamba2_block.ts +319 -0
  104. package/src/model/mamba3_block.ts +335 -0
  105. package/src/model/mamba_block.ts +401 -0
  106. package/src/model/mamba_model.ts +678 -0
  107. package/src/model/sequence_layer.ts +29 -0
  108. package/src/tokenizer/bpe.ts +186 -0
  109. package/src/training/autograd.ts +135 -0
  110. package/src/training/trainer.ts +309 -0
  111. package/src/utils/gpu_utils.ts +147 -0
  112. package/src/utils/quantization.ts +154 -0
  113. package/src/utils/rng.ts +65 -0
@@ -0,0 +1,678 @@
1
+ /**
2
+ * mamba_model.ts – HybridMambaModel: Mamba-1/2/3 and Attention layer scheduling.
3
+ *
4
+ * Replaces the fixed MambaBlock[] array with a SequenceLayer[] built from a
5
+ * per-layer type schedule. MambaModel is kept as a backward-compatible alias
6
+ * (all-mamba1 schedule).
7
+ *
8
+ * MBJS binary format:
9
+ * Version 1 (legacy): [magic][v=1][nParams][numel[]][ f32 data ]
10
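+ * Version 2 (new): [magic][v=2][nLayers][layerType[]][padding][nParams][numel[]][ f32 data ]
+ * Version 3 (fp16): same header layout as version 2 with [v=3]; the data section is f16 instead of f32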
11
+ * layerType: 0=mamba1, 1=mamba2, 2=mamba3, 3=attention
12
+ */
13
+
14
+ import { Mamba1Block } from './mamba1_block.js';
15
+ import { Mamba2Block } from './mamba2_block.js';
16
+ import { Mamba3Block } from './mamba3_block.js';
17
+ import { AttentionBlock } from './attention_block.js';
18
+ import type { SequenceLayer, LayerParam, LayerType } from './sequence_layer.js';
19
+ import type { Mamba1BlockConfig } from './mamba1_block.js';
20
+ import type { Mamba2BlockConfig } from './mamba2_block.js';
21
+ import type { Mamba3BlockConfig } from './mamba3_block.js';
22
+ import type { AttentionBlockConfig } from './attention_block.js';
23
+
24
+ import {
25
+ createStorageBuffer,
26
+ createEmptyStorageBuffer,
27
+ createUniformBuffer,
28
+ createComputePipeline,
29
+ createBindGroup,
30
+ dispatchKernel,
31
+ readBuffer,
32
+ uploadBuffer,
33
+ cdiv,
34
+ } from '../utils/gpu_utils.js';
35
+ import { LINEAR_FORWARD_WGSL } from '../kernels/linear_projection.js';
36
+ import { ACTIVATIONS_WGSL } from '../kernels/activations.js';
37
+ import { gaussianArray, setInitSeed } from '../utils/rng.js';
38
+ import { quantizeFp16, dequantizeFp16 } from '../utils/quantization.js';
39
+
40
+ // ── Public types ──────────────────────────────────────────────────────────────
41
+
42
/**
 * One entry of the per-layer schedule: which block variant to build and,
 * optionally, overrides for that single layer.
 */
export interface LayerSpec {
  /** Block variant instantiated for this layer. */
  type: LayerType;

  /**
   * Per-layer overrides. They are spread last when the block config is
   * assembled, so they win over both the model-wide shorthand fields and the
   * `default*` objects of the matching variant.
   */
  config?: Partial<Mamba1BlockConfig | Mamba2BlockConfig | Mamba3BlockConfig | AttentionBlockConfig>;
}
46
+
47
/**
 * Construction options for `HybridMambaModel`.
 *
 * Only `vocabSize`, `dModel` and `numLayers` are required; every other field
 * falls back to a default chosen by the model constructor.
 */
export interface HybridMambaModelConfig {
  /** Number of rows in the embedding table (also the width of the logits). */
  vocabSize: number;
  /** Width of the hidden state carried between layers. */
  dModel: number;
  /** Number of sequence layers in the stack. */
  numLayers: number;

  /**
   * Per-layer type schedule. Length must equal numLayers.
   * Defaults to all 'mamba1' (backward-compatible).
   */
  layers?: LayerSpec[];

  // Shared defaults per variant type. They override the shorthand fields below
  // and are themselves overridden by an individual LayerSpec.config.
  defaultMamba1?: Partial<Mamba1BlockConfig>;
  defaultMamba2?: Partial<Mamba2BlockConfig>;
  defaultMamba3?: Partial<Mamba3BlockConfig>;
  defaultAttention?: Partial<AttentionBlockConfig>;

  // Shorthand fields used as the base config of every mamba1, mamba2 and
  // mamba3 layer (constructor defaults: 16, 4 and 2 respectively).
  dState?: number;
  dConv?: number;
  expand?: number;

  /** Base head count for mamba2, mamba3 and attention layers (default 4). */
  nHeads?: number;
  /** Base group count for mamba2 and mamba3 layers (default 1). */
  nGroups?: number;
  /** Base chunk length for mamba2 and mamba3 layers (default 256). */
  chunkLen?: number;
  /** Base MIMO group setting for mamba3 layers (default 1). */
  mimoGroup?: number;

  /** Token id that stops `generate()` once sampled (default -1). */
  eosId?: number;

  /**
   * Optional deterministic seed for weight initialisation. When set, the
   * embedding table and all block weights are initialised reproducibly — the
   * same seed yields byte-identical initial weights on any machine. When
   * omitted, weights use `Math.random` (non-reproducible) as before.
   */
  seed?: number;
}
85
+
86
/**
 * Legacy Mamba-1-only config (fully backward-compatible).
 *
 * Accepted by `MambaModel`, which builds every layer as a 'mamba1' block.
 */
export interface MambaModelConfig {
  /** Number of rows in the embedding table. */
  vocabSize: number;
  /** Width of the hidden state carried between layers. */
  dModel: number;
  /** Number of Mamba-1 layers in the stack. */
  numLayers: number;
  /** Forwarded unchanged to the underlying hybrid model config. */
  dState?: number;
  /** Forwarded unchanged to the underlying hybrid model config. */
  dConv?: number;
  /** Forwarded unchanged to the underlying hybrid model config. */
  expand?: number;
  /** Token id that stops generation once sampled. */
  eosId?: number;
}
96
+
97
/** Value resolved by `HybridMambaModel.forward()`. */
export interface ModelForwardResult {
  /** CPU copy of the logits read back from `gpuLogits`. */
  logits: Float32Array;
  /** GPU buffer holding the same logits; forward() does not destroy it. */
  gpuLogits: GPUBuffer;
  /** One entry per layer, in layer order, as returned by each layer's forward pass. */
  caches: unknown[];
}
102
+
103
/** Knobs for token sampling in `generate()`; every field is optional. */
export interface SamplingOptions {
  /** Divisor applied to the logits before the softmax (generate() defaults to 1.0). */
  temperature?: number;
  /** Keep only the K most probable tokens (generate() defaults to 50). */
  topK?: number;
  /** Cumulative-probability cut-off for nucleus sampling (generate() defaults to 0.9). */
  topP?: number;
}
108
+
109
+ // ── MBJS format constants ─────────────────────────────────────────────────────
110
+
111
/** File signature: the ASCII codes of 'MBJS' packed most-significant byte first. */
const MBJS_MAGIC = 0x4D424A53;

/** uint8 code stored in the MBJS v2/v3 header for each layer type. */
const LAYER_TYPE_ID: Record<LayerType, number> = {
  mamba1: 0,
  mamba2: 1,
  mamba3: 2,
  attention: 3,
};

/** Inverse of LAYER_TYPE_ID: index = on-disk code, value = layer type name. */
const ID_TO_LAYER_TYPE: LayerType[] = [];
for (const [typeName, typeId] of Object.entries(LAYER_TYPE_ID)) {
  ID_TO_LAYER_TYPE[typeId] = typeName as LayerType;
}
119
+
120
+ // ── HybridMambaModel ──────────────────────────────────────────────────────────
121
+
122
/**
 * Language model assembled from a per-layer schedule of Mamba-1/2/3 and
 * attention blocks, with a tied-embedding LM head and MBJS (v1/v2/v3)
 * checkpoint import/export.
 *
 * The instance owns the embedding table, the final-norm weights, the LM-head
 * bias and every layer; call `destroy()` to release them.
 */
export class HybridMambaModel {
  device        : GPUDevice;
  config        : Required<HybridMambaModelConfig>;
  gpuEmbedding  : GPUBuffer;
  layers        : SequenceLayer[];
  layerSpecs    : LayerSpec[];
  gpuFinalNorm  : GPUBuffer;
  tiedEmbedding : boolean;
  gpuLMHeadBias : GPUBuffer;

  private _lmHeadPipeline  : GPUComputePipeline;
  private _rmsnormPipeline : GPUComputePipeline;
  private _embedPipeline   : GPUComputePipeline;
  private _wslaMode = false;

  /**
   * @param device GPU device on which all buffers and pipelines are created.
   * @param config Model hyper-parameters; see `HybridMambaModelConfig`.
   * @throws Error when `config.layers` is given and its length differs from
   *         `config.numLayers`.
   */
  constructor(device: GPUDevice, config: HybridMambaModelConfig) {
    this.device = device;

    // Resolve and validate the schedule first, so a bad config is rejected
    // before the process-wide init seed is touched.
    const layerSchedule: LayerSpec[] = config.layers
      ?? Array.from({ length: config.numLayers }, () => ({ type: 'mamba1' as LayerType }));

    if (layerSchedule.length !== config.numLayers) {
      throw new Error(
        `HybridMambaModel: layers schedule length (${layerSchedule.length}) must equal numLayers (${config.numLayers}).`
      );
    }
    this.layerSpecs = layerSchedule;

    this.config = {
      dState          : 16,
      dConv           : 4,
      expand          : 2,
      nHeads          : 4,
      nGroups         : 1,
      chunkLen        : 256,
      mimoGroup       : 1,
      eosId           : -1,
      defaultMamba1   : {},
      defaultMamba2   : {},
      defaultMamba3   : {},
      defaultAttention: {},
      seed            : undefined,
      ...config,
      // Store the resolved schedule so `config.layers` is never undefined.
      layers          : layerSchedule,
    } as Required<HybridMambaModelConfig>;

    const { vocabSize, dModel } = this.config;

    // Install the deterministic init seed (if any) only while weights are
    // being drawn. The `finally` guarantees the default Math.random source is
    // restored even when a block constructor throws.
    setInitSeed(this.config.seed);
    try {
      // Embedding table is drawn first, then the blocks in schedule order.
      const embedData = gaussianArray(vocabSize * dModel, 1.0 / Math.sqrt(dModel));
      this.gpuEmbedding = createStorageBuffer(device, embedData, true);
      this.layers = layerSchedule.map(spec => this._buildLayer(spec));
    } finally {
      setInitSeed(undefined);
    }

    // Final RMSNorm weights start at 1.
    this.gpuFinalNorm = createStorageBuffer(device, new Float32Array(dModel).fill(1.0), true);

    this.tiedEmbedding = true;
    this.gpuLMHeadBias = createStorageBuffer(device, new Float32Array(vocabSize), true);

    this._lmHeadPipeline  = createComputePipeline(device, LINEAR_FORWARD_WGSL, 'linear_forward');
    this._rmsnormPipeline = createComputePipeline(device, ACTIVATIONS_WGSL, 'rmsnorm_forward');
    this._embedPipeline   = createComputePipeline(device, EMBED_LOOKUP_WGSL, 'embed_lookup');
  }

  /**
   * Instantiates one layer. Precedence, lowest to highest: model-wide
   * shorthand fields, the variant's `default*` object, then `spec.config`.
   *
   * @throws Error when `spec.type` is not a known layer type.
   */
  private _buildLayer(spec: LayerSpec): SequenceLayer {
    const c = this.config;
    switch (spec.type) {
      case 'mamba1': {
        const base: Mamba1BlockConfig = {
          dModel : c.dModel,
          dState : c.dState,
          dConv  : c.dConv,
          expand : c.expand,
          ...c.defaultMamba1,
        };
        return new Mamba1Block(this.device, { ...base, ...(spec.config ?? {}) } as Mamba1BlockConfig);
      }
      case 'mamba2': {
        const base: Mamba2BlockConfig = {
          dModel  : c.dModel,
          dState  : c.dState,
          dConv   : c.dConv,
          expand  : c.expand,
          nHeads  : c.nHeads,
          nGroups : c.nGroups,
          chunkLen: c.chunkLen,
          ...c.defaultMamba2,
        };
        return new Mamba2Block(this.device, { ...base, ...(spec.config ?? {}) } as Mamba2BlockConfig);
      }
      case 'mamba3': {
        const base: Mamba3BlockConfig = {
          dModel   : c.dModel,
          dState   : c.dState,
          dConv    : c.dConv,
          expand   : c.expand,
          nHeads   : c.nHeads,
          nGroups  : c.nGroups,
          chunkLen : c.chunkLen,
          mimoGroup: c.mimoGroup,
          ...c.defaultMamba3,
        };
        return new Mamba3Block(this.device, { ...base, ...(spec.config ?? {}) } as Mamba3BlockConfig);
      }
      case 'attention': {
        const base: AttentionBlockConfig = {
          dModel : c.dModel,
          nHeads : c.nHeads,
          ...c.defaultAttention,
        };
        return new AttentionBlock(this.device, { ...base, ...(spec.config ?? {}) } as AttentionBlockConfig);
      }
      default: {
        // Reached only when an untyped caller passes a type outside LayerType.
        const unknownType: never = spec.type;
        throw new Error(`HybridMambaModel: unknown layer type "${String(unknownType)}".`);
      }
    }
  }

  /**
   * Looks up the embedding row of each token on the GPU.
   *
   * @param tokenIds At least `batch * seqLen` token ids.
   * @returns A new buffer of `batch * seqLen * dModel` floats; the caller owns
   *          it and must destroy it.
   * @throws RangeError when fewer than `batch * seqLen` ids are supplied or an
   *         id is not below `vocabSize`.
   */
  embedTokens(tokenIds: number[] | Uint32Array, batch: number, seqLen: number): GPUBuffer {
    const { dModel, vocabSize } = this.config;
    const M = batch * seqLen;
    const ids = tokenIds instanceof Uint32Array ? tokenIds : new Uint32Array(tokenIds);

    // The shader indexes the table with the raw id and performs no bounds
    // check, so reject bad input here instead of reading outside the table.
    if (ids.length < M) {
      throw new RangeError(
        `HybridMambaModel.embedTokens: expected at least ${M} token ids (batch*seqLen) but got ${ids.length}.`
      );
    }
    for (let i = 0; i < M; i++) {
      if (ids[i]! >= vocabSize) {
        throw new RangeError(
          `HybridMambaModel.embedTokens: token id ${ids[i]} at position ${i} is outside the vocabulary (size ${vocabSize}).`
        );
      }
    }

    const idsBuf = createStorageBuffer(this.device, ids, false);
    const outBuf = createEmptyStorageBuffer(this.device, M * dModel * 4, true);

    const pBuf = createUniformBuffer(this.device, new Uint32Array([M, dModel]).buffer);
    const bg   = createBindGroup(this.device, this._embedPipeline,
      [pBuf, idsBuf, this.gpuEmbedding, outBuf]);
    dispatchKernel(this.device, this._embedPipeline, bg, [cdiv(M, 64), 1, 1]);

    idsBuf.destroy();
    pBuf.destroy();
    return outBuf;
  }

  /**
   * Applies the final RMSNorm to `hidden` and returns the normalised buffer
   * (`M * dModel` floats). The caller keeps ownership of `hidden` and takes
   * ownership of the returned buffer.
   */
  private _applyFinalNorm(hidden: GPUBuffer, M: number): GPUBuffer {
    const { dModel } = this.config;
    const normOut = createEmptyStorageBuffer(this.device, M * dModel * 4, true);
    const normInv = createEmptyStorageBuffer(this.device, M * 4, false);

    // Uniform layout: [M: u32][dModel: u32][eps: f32], padded to 16 bytes.
    const params = new ArrayBuffer(16);
    new Uint32Array(params, 0, 2).set([M, dModel]);
    new Float32Array(params, 8, 1).set([1e-6]);
    const pBuf = createUniformBuffer(this.device, params);
    const bg   = createBindGroup(this.device, this._rmsnormPipeline,
      [pBuf, hidden, this.gpuFinalNorm, normOut, normInv]);
    dispatchKernel(this.device, this._rmsnormPipeline, bg, [cdiv(M, 64), 1, 1]);

    // Scratch buffers are released right after dispatch, mirroring what
    // embedTokens() does with its own uniform buffer.
    pBuf.destroy();
    normInv.destroy();
    return normOut;
  }

  /**
   * Runs embedding → all layers → final RMSNorm → tied-embedding LM head.
   *
   * @returns CPU logits (`batch * seqLen * vocabSize` floats), the GPU buffer
   *          they were read from (caller must destroy it) and one cache entry
   *          per layer.
   */
  async forward(tokenIds: number[] | Uint32Array, batch: number, seqLen: number): Promise<ModelForwardResult> {
    const { dModel, vocabSize } = this.config;
    const M = batch * seqLen;

    let hidden = this.embedTokens(tokenIds, batch, seqLen);

    const caches: unknown[] = [];
    for (const layer of this.layers) {
      const { output, cache } = layer.forward(hidden, batch, seqLen);
      caches.push(cache);
      hidden.destroy();
      hidden = output;
    }

    const normOut = this._applyFinalNorm(hidden, M);
    hidden.destroy();

    // LM head (tied embedding): the embedding table doubles as the weight.
    const gpuLogits = createEmptyStorageBuffer(this.device, M * vocabSize * 4, true);
    const pBuf = createUniformBuffer(this.device, new Uint32Array([M, dModel, vocabSize]).buffer);
    const bg   = createBindGroup(this.device, this._lmHeadPipeline,
      [pBuf, normOut, this.gpuEmbedding, this.gpuLMHeadBias, gpuLogits]);
    dispatchKernel(this.device, this._lmHeadPipeline, bg,
      [cdiv(M, 16), cdiv(vocabSize, 16), 1]);
    pBuf.destroy();
    normOut.destroy();

    let logits: Float32Array;
    try {
      logits = await readBuffer(this.device, gpuLogits, M * vocabSize * 4);
    } catch (err) {
      // Nobody can receive the buffer if the read-back fails, so free it here.
      gpuLogits.destroy();
      throw err;
    }
    return { logits, gpuLogits, caches };
  }

  /**
   * Produces a single fixed-length embedding vector for a token sequence.
   *
   * Runs the full layer stack plus the final RMSNorm — i.e. the same hidden
   * state the LM head consumes — then mean-pools across sequence positions and
   * L2-normalises the result. The returned vector has length `dModel` and is
   * suitable for cosine-similarity semantic search.
   *
   * Unlike `forward()`, this skips the LM-head projection: it only needs the
   * `dModel`-wide hidden state, not `vocabSize` logits.
   *
   * @returns A unit-length vector, or an all-zero vector for an empty input.
   */
  async embed(tokenIds: number[] | Uint32Array): Promise<Float32Array> {
    const { dModel } = this.config;
    const seqLen = tokenIds.length;
    const batch  = 1;
    const M      = batch * seqLen;
    if (M === 0) return new Float32Array(dModel);

    let hidden = this.embedTokens(tokenIds, batch, seqLen);
    for (const layer of this.layers) {
      // NOTE(review): the per-layer cache is dropped here; confirm it holds no
      // GPU resources that would need an explicit destroy.
      const { output } = layer.forward(hidden, batch, seqLen);
      hidden.destroy();
      hidden = output;
    }

    const normOut = this._applyFinalNorm(hidden, M);
    hidden.destroy();

    let normed: Float32Array;
    try {
      normed = await readBuffer(this.device, normOut, M * dModel * 4);
    } finally {
      normOut.destroy();
    }

    // Mean-pool across sequence positions → dModel vector.
    const out = new Float32Array(dModel);
    for (let t = 0; t < seqLen; t++) {
      const base = t * dModel;
      for (let d = 0; d < dModel; d++) out[d] = out[d]! + normed[base + d]!;
    }
    for (let d = 0; d < dModel; d++) out[d] = out[d]! / seqLen;

    // L2-normalise so cosine similarity reduces to a dot product.
    let norm = 0;
    for (let d = 0; d < dModel; d++) norm += out[d]! * out[d]!;
    norm = Math.sqrt(norm) || 1; // an all-zero vector stays all-zero
    for (let d = 0; d < dModel; d++) out[d] = out[d]! / norm;

    return out;
  }

  /**
   * Autoregressively extends `promptIds`, re-running the whole sequence
   * through `forward()` for every new token.
   *
   * @returns The prompt followed by the generated ids. Generation stops after
   *          `maxNewTokens` tokens or once `config.eosId` has been appended.
   * @throws Error when tokens are requested from an empty prompt.
   */
  async generate(promptIds: number[], maxNewTokens = 200, samplingOpts: SamplingOptions = {}): Promise<number[]> {
    const { temperature = 1.0, topK = 50, topP = 0.9 } = samplingOpts;
    const { vocabSize } = this.config;

    // With no prompt there is no "last position" to sample from.
    if (maxNewTokens > 0 && promptIds.length === 0) {
      throw new Error('HybridMambaModel.generate: promptIds must contain at least one token.');
    }

    const ids = [...promptIds];

    for (let step = 0; step < maxNewTokens; step++) {
      const { logits, gpuLogits } = await this.forward(new Uint32Array(ids), 1, ids.length);
      // Only the CPU copy is used; free the GPU copy so it does not leak on
      // every step.
      // NOTE(review): the returned caches are dropped as before; confirm they
      // hold no GPU resources.
      gpuLogits.destroy();

      const lastLogits = logits.subarray((ids.length - 1) * vocabSize, ids.length * vocabSize);
      const nextId     = sampleToken(lastLogits, { temperature, topK, topP });
      ids.push(nextId);
      if (nextId === this.config.eosId) break;
    }

    return ids;
  }

  /**
   * Lists every serialised parameter in checkpoint order: the embedding table,
   * each layer's parameters (names prefixed with `layer<i>.`), then the
   * final-norm weights.
   */
  parameters(): LayerParam[] {
    const params: LayerParam[] = [];

    params.push({
      buf  : this.gpuEmbedding,
      numel: this.config.vocabSize * this.config.dModel,
      name : 'embedding',
    });

    for (let i = 0; i < this.layers.length; i++) {
      for (const p of this.layers[i]!.parameters()) {
        params.push({ ...p, name: `layer${i}.${p.name}` });
      }
    }

    params.push({
      buf  : this.gpuFinalNorm,
      numel: this.config.dModel,
      name : 'final_norm',
    });

    return params;
  }

  /** Forwards the WSLA flag to every layer and records it on the model. */
  setWSLAMode(enabled: boolean): void {
    for (const layer of this.layers) layer.setWSLAMode(enabled);
    this._wslaMode = enabled;
  }

  // ── Serialisation (MBJS v2 / v3) ──────────────────────────────────────────

  /**
   * Export all parameters to an ArrayBuffer.
   *
   * MBJS v2/v3 format (identical header; only the data encoding differs):
   *   [0..3]   magic   : uint32 = 0x4D424A53
   *   [4..7]   version : uint32 = 2 (fp32 data) | 3 (fp16 data)
   *   [8..11]  nLayers : uint32
   *   [12 .. 12+nLayers-1]  layerType[i]: uint8  (0=m1, 1=m2, 2=m3, 3=attn)
   *   aligned to 4 bytes: padding
   *   [next 4]            nParams : uint32
   *   [next 4*nParams]    numel[i]: uint32
   *   [data]              float32 values (v2) | float16 values (v3)
   *
   * All multi-byte header fields are little-endian.
   *
   * @param opts Pass `{ fp16: true }` to emit a v3 checkpoint whose data
   *             section is half the size, at reduced precision.
   */
  async exportWeights(opts: { fp16?: boolean } = {}): Promise<ArrayBuffer> {
    const fp16    = opts.fp16 ?? false;
    const params  = this.parameters();
    const nParams = params.length;
    const nLayers = this.layers.length;

    const arrays: Float32Array[] = await Promise.all(
      params.map(p => readBuffer(this.device, p.buf, p.numel * 4))
    );

    // Header: magic(4) + version(4) + nLayers(4) + layerTypes(padded to 4) + nParams(4) + numels(4*nParams)
    const layerTypeBytes = Math.ceil(nLayers / 4) * 4;
    const headerBytes    = 4 + 4 + 4 + layerTypeBytes + 4 + nParams * 4;
    const bytesPerEl     = fp16 ? 2 : 4;
    const totalEls       = arrays.reduce((a, arr) => a + arr.length, 0);
    const out            = new ArrayBuffer(headerBytes + totalEls * bytesPerEl);
    const view           = new DataView(out);

    let off = 0;
    view.setUint32(off, MBJS_MAGIC,   true); off += 4;
    view.setUint32(off, fp16 ? 3 : 2, true); off += 4;
    view.setUint32(off, nLayers,      true); off += 4;

    for (let i = 0; i < nLayers; i++) {
      view.setUint8(off + i, LAYER_TYPE_ID[this.layers[i]!.layerType]);
    }
    off += layerTypeBytes; // padding bytes stay zero

    view.setUint32(off, nParams, true); off += 4;
    for (const p of params) {
      view.setUint32(off, p.numel, true);
      off += 4;
    }

    // Header bytes are a multiple of 4, so both Float32Array and Uint16Array
    // views below are correctly aligned at `off`.
    if (fp16) {
      for (const arr of arrays) {
        const half = quantizeFp16(arr);
        new Uint16Array(out, off, half.length).set(half);
        off += half.length * 2;
      }
    } else {
      for (const arr of arrays) {
        new Float32Array(out, off, arr.length).set(arr);
        off += arr.byteLength;
      }
    }

    return out;
  }

  /**
   * Load parameters from an MBJS v1, v2, or v3 ArrayBuffer.
   *
   * v1: no layer metadata; only parameter count and sizes are validated.
   * v2: additionally validates the per-layer type array (fp32 data).
   * v3: identical layout to v2 but the data section is fp16 (dequantised on load).
   *
   * The whole header and the data length are validated before the first
   * upload, so a rejected file leaves the model's weights untouched.
   *
   * @throws Error on a bad magic number, unsupported version, truncated file,
   *         unknown or mismatching layer type, or parameter count/size mismatch.
   */
  async loadWeights(buffer: ArrayBuffer): Promise<void> {
    if (buffer.byteLength < 8) {
      throw new Error('Invalid weight file: too short to contain an MBJS header.');
    }
    const view = new DataView(buffer);
    let off = 0;

    const magic = view.getUint32(off, true); off += 4;
    if (magic !== MBJS_MAGIC) {
      throw new Error('Invalid weight file: bad magic number. Expected MBJS file.');
    }

    const version = view.getUint32(off, true); off += 4;
    if (version !== 1 && version !== 2 && version !== 3) {
      throw new Error(`Unsupported MBJS version: ${version}. Expected 1, 2, or 3.`);
    }

    // v2 and v3 carry a layer-type table between the version and nParams.
    if (version !== 1) {
      off = this._checkLayerTypes(view, off);
    }
    this._uploadParams(buffer, view, off, version === 3);
  }

  /**
   * Validates the v2/v3 layer-type table that starts at `start`.
   *
   * @returns The byte offset just past the (4-byte padded) table.
   */
  private _checkLayerTypes(view: DataView, start: number): number {
    if (start + 4 > view.byteLength) {
      throw new Error('Invalid weight file: truncated before the layer count.');
    }
    const nLayers = view.getUint32(start, true);
    const tableOff = start + 4;

    if (nLayers !== this.layers.length) {
      throw new Error(`Weight file has ${nLayers} layers but this model has ${this.layers.length}.`);
    }

    const layerTypeBytes = Math.ceil(nLayers / 4) * 4;
    if (tableOff + layerTypeBytes > view.byteLength) {
      throw new Error('Invalid weight file: truncated inside the layer type table.');
    }

    for (let i = 0; i < nLayers; i++) {
      const typeId       = view.getUint8(tableOff + i);
      const fileType     = ID_TO_LAYER_TYPE[typeId];
      const expectedType = this.layers[i]!.layerType;
      // An unrecognised id must not be mistaken for any known type.
      if (fileType === undefined) {
        throw new Error(`Layer ${i} has unknown type id ${typeId} in weight file.`);
      }
      if (fileType !== expectedType) {
        throw new Error(
          `Layer ${i} type mismatch: file="${fileType}", model="${expectedType}".`
        );
      }
    }
    return tableOff + layerTypeBytes;
  }

  /**
   * Reads the `[nParams][numel[]][data]` tail that starts at `start`,
   * validates it against `parameters()`, then uploads every parameter.
   *
   * @param fp16 True when the data section holds float16 values.
   */
  private _uploadParams(buffer: ArrayBuffer, view: DataView, start: number, fp16: boolean): void {
    let off = start;
    if (off + 4 > view.byteLength) {
      throw new Error('Invalid weight file: truncated before the parameter count.');
    }
    const nParams = view.getUint32(off, true); off += 4;
    const params  = this.parameters();

    if (nParams !== params.length) {
      throw new Error(
        `Weight file has ${nParams} parameters but this model has ${params.length}.`
      );
    }
    if (off + nParams * 4 > view.byteLength) {
      throw new Error('Invalid weight file: truncated inside the parameter size table.');
    }

    // Pass 1: validate every size before any GPU buffer is modified.
    const bytesPerEl = fp16 ? 2 : 4;
    let dataBytes = 0;
    for (let i = 0; i < nParams; i++) {
      const p     = params[i]!;
      const numel = view.getUint32(off + i * 4, true);
      if (numel !== p.numel) {
        throw new Error(`Parameter ${i} ("${p.name}") size mismatch: file=${numel}, model=${p.numel}.`);
      }
      dataBytes += numel * bytesPerEl;
    }
    off += nParams * 4;

    if (off + dataBytes > buffer.byteLength) {
      throw new Error(
        `Invalid weight file: data section truncated (need ${dataBytes} bytes, have ${buffer.byteLength - off}).`
      );
    }

    // Pass 2: upload. `off` stays 4-byte aligned for fp32 and 2-byte aligned
    // for fp16, as the typed-array views require.
    for (const p of params) {
      if (fp16) {
        const half = new Uint16Array(buffer, off, p.numel);
        uploadBuffer(this.device, p.buf, dequantizeFp16(half));
      } else {
        uploadBuffer(this.device, p.buf, new Float32Array(buffer, off, p.numel));
      }
      off += p.numel * bytesPerEl;
    }
  }

  /** Releases every GPU buffer owned by the model and its layers. */
  destroy(): void {
    this.gpuEmbedding.destroy();
    for (const layer of this.layers) layer.destroy();
    this.gpuFinalNorm.destroy();
    this.gpuLMHeadBias.destroy();
  }
}
603
+
604
+ // ── MambaModel – backward-compatible alias ────────────────────────────────────
605
+
606
/**
 * Backward-compatible alias: a `HybridMambaModel` whose schedule is
 * `numLayers` entries of type 'mamba1'.
 */
export class MambaModel extends HybridMambaModel {
  constructor(device: GPUDevice, config: MambaModelConfig) {
    // Build the all-mamba1 schedule explicitly, one spec per layer.
    const allMamba1 = Array.from(
      { length: config.numLayers },
      (): LayerSpec => ({ type: 'mamba1' as LayerType }),
    );
    super(device, { ...config, layers: allMamba1 });
  }
}
614
+
615
+ // ── Embed lookup WGSL ─────────────────────────────────────────────────────────
616
+
617
// WGSL source of the `embed_lookup` compute entry point.
//
// Bindings (group 0): 0 = uniform EmbedParams { num_tokens, d_model },
// 1 = token ids (read), 2 = embedding table (read), 3 = output (read_write).
//
// Each invocation handles one token: it takes `gid.x` as the token index,
// returns early when that index is >= num_tokens, then copies d_model
// consecutive floats from table offset `tok * d_model` to output offset
// `token_idx * d_model`. The workgroup size is 64 along x.
//
// NOTE(review): the shader does not compare `tok` against the table length,
// so the host has to supply ids that fall inside the table.
+ const EMBED_LOOKUP_WGSL: string = /* wgsl */`
618
+ struct EmbedParams {
619
+ num_tokens : u32,
620
+ d_model : u32,
621
+ };
622
+
623
+ @group(0) @binding(0) var<uniform> params : EmbedParams;
624
+ @group(0) @binding(1) var<storage, read> ids : array<u32>;
625
+ @group(0) @binding(2) var<storage, read> table : array<f32>;
626
+ @group(0) @binding(3) var<storage, read_write> out : array<f32>;
627
+
628
+ @compute @workgroup_size(64, 1, 1)
629
+ fn embed_lookup(@builtin(global_invocation_id) gid: vec3<u32>) {
630
+ let token_idx = gid.x;
631
+ if (token_idx >= params.num_tokens) { return; }
632
+
633
+ let D = params.d_model;
634
+ let tok = ids[token_idx];
635
+ let src = tok * D;
636
+ let dst = token_idx * D;
637
+
638
+ for (var i: u32 = 0u; i < D; i = i + 1u) {
639
+ out[dst + i] = table[src + i];
640
+ }
641
+ }
642
+ `;
643
+
644
+ // ── Token sampling ────────────────────────────────────────────────────────────
645
+
646
/**
 * Samples one token index from raw logits using temperature scaling, a top-K
 * cut and a top-P (nucleus) cut, in that order.
 *
 * @param logits      Unnormalised scores, one per vocabulary entry.
 * @param temperature Divisor applied to the logits; values below 1e-7 are
 *                    clamped to 1e-7.
 * @param topK        Number of most-probable tokens kept as candidates. A
 *                    value that is below 1 or not finite disables the cut
 *                    (all tokens stay candidates); fractions are floored.
 * @param topP        Candidates are accumulated in descending probability
 *                    until their cumulative probability reaches this value.
 *                    The most probable token is always kept.
 * @returns The sampled index into `logits`.
 * @throws Error when `logits` is empty.
 */
function sampleToken(logits: Float32Array, { temperature = 1.0, topK = 50, topP = 0.9 } = {}): number {
  const n = logits.length;
  if (n === 0) {
    throw new Error('sampleToken: logits must not be empty.');
  }

  // Temperature scaling (denominator hoisted out of the loop).
  const denom  = Math.max(temperature, 1e-7);
  const scaled = new Float32Array(n);
  for (let i = 0; i < n; i++) scaled[i] = logits[i]! / denom;

  // Softmax numerators, shifted by the maximum for numerical stability.
  let maxL = -Infinity;
  for (let i = 0; i < n; i++) if (scaled[i]! > maxL) maxL = scaled[i]!;
  let sumE = 0;
  const exps = new Float32Array(n);
  for (let i = 0; i < n; i++) { exps[i] = Math.exp(scaled[i]! - maxL); sumE += exps[i]!; }

  // Top-K: always keep at least one candidate, so the function can never
  // fall through with an empty candidate list.
  const k = Number.isFinite(topK) && topK >= 1 ? Math.min(Math.floor(topK), n) : n;
  const indices = Array.from({ length: n }, (_, i) => i).sort((a, b) => exps[b]! - exps[a]!);
  const topKIdx = indices.slice(0, k);

  // Top-P: take candidates in descending probability until the cumulative
  // mass reaches topP. The push happens before the check, so the first
  // candidate is always included.
  let cumSum = 0;
  const nucleus: number[] = [];
  for (const idx of topKIdx) {
    cumSum += exps[idx]! / sumE;
    nucleus.push(idx);
    if (cumSum >= topP) break;
  }

  // Draw from the nucleus proportionally to the unnormalised weights.
  let nucleusSum = 0;
  for (const idx of nucleus) nucleusSum += exps[idx]!;
  const threshold = Math.random() * nucleusSum;
  let acc = 0;
  for (const idx of nucleus) {
    acc += exps[idx]!;
    if (acc >= threshold) return idx;
  }
  // Floating-point round-off fallback: return the last candidate.
  return nucleus[nucleus.length - 1]!;
}