mambacode.js 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. package/README.md +198 -76
  2. package/dist/index.d.ts +19 -0
  3. package/dist/index.d.ts.map +1 -0
  4. package/dist/index.js +18 -0
  5. package/dist/index.js.map +1 -0
  6. package/dist/kernels/activations.d.ts +3 -0
  7. package/dist/kernels/activations.d.ts.map +1 -0
  8. package/dist/kernels/activations.js +87 -0
  9. package/dist/kernels/activations.js.map +1 -0
  10. package/dist/kernels/conv1d.d.ts +3 -0
  11. package/dist/kernels/conv1d.d.ts.map +1 -0
  12. package/dist/kernels/conv1d.js +152 -0
  13. package/dist/kernels/conv1d.js.map +1 -0
  14. package/dist/kernels/linear_projection.d.ts +3 -0
  15. package/dist/kernels/linear_projection.d.ts.map +1 -0
  16. package/dist/kernels/linear_projection.js +219 -0
  17. package/dist/kernels/linear_projection.js.map +1 -0
  18. package/dist/kernels/selective_scan.d.ts +3 -0
  19. package/dist/kernels/selective_scan.d.ts.map +1 -0
  20. package/dist/kernels/selective_scan.js +348 -0
  21. package/dist/kernels/selective_scan.js.map +1 -0
  22. package/dist/kernels/weight_update.d.ts +3 -0
  23. package/dist/kernels/weight_update.d.ts.map +1 -0
  24. package/dist/kernels/weight_update.js +119 -0
  25. package/dist/kernels/weight_update.js.map +1 -0
  26. package/dist/model/mamba_block.d.ts +64 -0
  27. package/dist/model/mamba_block.d.ts.map +1 -0
  28. package/dist/model/mamba_block.js +309 -0
  29. package/dist/model/mamba_block.js.map +1 -0
  30. package/dist/model/mamba_model.d.ts +66 -0
  31. package/dist/model/mamba_model.d.ts.map +1 -0
  32. package/dist/model/mamba_model.js +289 -0
  33. package/dist/model/mamba_model.js.map +1 -0
  34. package/dist/tokenizer/bpe.d.ts +29 -0
  35. package/dist/tokenizer/bpe.d.ts.map +1 -0
  36. package/dist/tokenizer/bpe.js +164 -0
  37. package/dist/tokenizer/bpe.js.map +1 -0
  38. package/dist/training/autograd.d.ts +27 -0
  39. package/dist/training/autograd.d.ts.map +1 -0
  40. package/dist/training/autograd.js +120 -0
  41. package/dist/training/autograd.js.map +1 -0
  42. package/dist/training/trainer.d.ts +37 -0
  43. package/dist/training/trainer.d.ts.map +1 -0
  44. package/dist/training/trainer.js +183 -0
  45. package/dist/training/trainer.js.map +1 -0
  46. package/dist/utils/gpu_utils.d.ts +21 -0
  47. package/dist/utils/gpu_utils.d.ts.map +1 -0
  48. package/dist/utils/gpu_utils.js +111 -0
  49. package/dist/utils/gpu_utils.js.map +1 -0
  50. package/dist/utils/quantization.d.ts +26 -0
  51. package/dist/utils/quantization.d.ts.map +1 -0
  52. package/dist/utils/quantization.js +116 -0
  53. package/dist/utils/quantization.js.map +1 -0
  54. package/package.json +43 -18
  55. package/src/index.ts +61 -0
  56. package/src/kernels/{activations.js → activations.ts} +2 -2
  57. package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
  58. package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
  59. package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
  60. package/src/model/{mamba_block.js → mamba_block.ts} +134 -170
  61. package/src/model/{mamba_model.js → mamba_model.ts} +165 -121
  62. package/src/tokenizer/bpe.ts +186 -0
  63. package/src/training/autograd.ts +135 -0
  64. package/src/training/{trainer.js → trainer.ts} +79 -161
  65. package/src/utils/gpu_utils.ts +147 -0
  66. package/src/utils/quantization.ts +154 -0
  67. package/src/index.js +0 -89
  68. package/src/tokenizer/bpe.js +0 -256
  69. package/src/training/autograd.js +0 -221
  70. package/src/utils/gpu_utils.js +0 -217
  71. package/src/utils/quantization.js +0 -215
  72. /package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
@@ -0,0 +1,135 @@
1
+ /**
2
+ * autograd.ts – Lightweight tape-based automatic differentiation engine.
3
+ */
4
+
5
/* eslint-disable @typescript-eslint/no-explicit-any */
// Untyped handle on the global object, so WebGPU flag namespaces such as
// GPUBufferUsage can be probed with optional chaining where they are absent.
const _gpu = globalThis as any;

/** One recorded operation: the closure that propagates its gradients. */
type TapeEntry = {
  backward: () => void | Promise<void>;
};

// Recorded operations, in recording order; backward() replays them in reverse.
let _tape: TapeEntry[] = [];
// While false, recordOperation() leaves the tape untouched and returns -1.
let _gradEnabled = true;
14
+
15
/**
 * Handle for a tensor whose values live in a GPU buffer.
 *
 * Bundles the data buffer, the logical shape, the element count derived from
 * that shape, and an optional gradient buffer. Byte sizes assume 4 bytes per
 * element (presumably f32 — zeroGrad writes a Float32Array).
 */
export class Tensor {
  data: GPUBuffer | null;
  shape: number[];
  /** Product of all dimensions in `shape`; 1 for an empty (scalar) shape. */
  numel: number;
  requiresGrad: boolean;
  /** Gradient buffer; null until something assigns one. */
  grad: GPUBuffer | null;
  /** Initialised to null and not read by this class. NOTE(review): presumably a tape index — confirm. */
  _gradFn: number | null;

  constructor(data: GPUBuffer | null, shape: number[], requiresGrad = false) {
    let count = 1;
    for (const dim of shape) {
      count *= dim;
    }

    this.data = data;
    this.shape = shape;
    this.numel = count;
    this.requiresGrad = requiresGrad;
    this.grad = null;
    this._gradFn = null;
  }

  /** Size of the tensor's data in bytes (4 bytes per element). */
  get byteSize(): number {
    return Float32Array.BYTES_PER_ELEMENT * this.numel;
  }

  /** Overwrite the gradient buffer with zeros; does nothing when there is no gradient. */
  zeroGrad(device: GPUDevice): void {
    if (!this.grad) return;
    const zeros = new Float32Array(this.numel);
    device.queue.writeBuffer(this.grad, 0, zeros);
  }

  /** Destroy the data and gradient buffers (when present) and drop both references. */
  destroy(): void {
    this.data?.destroy();
    this.grad?.destroy();
    this.data = null;
    this.grad = null;
  }
}
47
+
48
/** Resume recording: subsequent recordOperation() calls append to the tape. */
export function enableGrad(): void {
  _gradEnabled = true;
}

/** Pause recording: recordOperation() returns -1 without touching the tape. */
export function noGrad(): void {
  _gradEnabled = false;
}

/** Discard every recorded entry by swapping in a fresh, empty tape array. */
export function clearTape(): void {
  _tape = [];
}
51
+
52
/**
 * Append a backward closure to the tape.
 *
 * @param backwardFn Closure run (in reverse recording order) by backward().
 * @returns The entry's index on the tape, or -1 when recording is disabled
 *          (in which case nothing is appended).
 */
export function recordOperation(backwardFn: () => void | Promise<void>): number {
  if (_gradEnabled) {
    const index = _tape.length;
    _tape.push({ backward: backwardFn });
    return index;
  }
  return -1;
}
57
+
58
/**
 * Replay the tape from the last recorded entry to the first, awaiting each
 * entry's backward closure, then clear the tape.
 *
 * The tape is cleared in a `finally`, so it is cleared even when a closure
 * throws or rejects. Previously a failure skipped the clear and left stale
 * closures on the module-level tape, to be replayed again by the next
 * backward() call.
 * Entries recorded while the pass runs sit past the starting length: they are
 * not replayed and are discarded by the final clear.
 *
 * @throws Whatever the failing backward closure throws; the tape is still cleared.
 */
export async function backward(): Promise<void> {
  try {
    for (let i = _tape.length - 1; i >= 0; i--) {
      await _tape[i]!.backward();
    }
  } finally {
    clearTape();
  }
}
64
+
65
/**
 * Give `tensor` a gradient buffer of `tensor.byteSize` bytes if it has none,
 * and fill it with `tensor.numel` zero floats. A tensor that already has a
 * gradient buffer is left unchanged.
 *
 * Usage is STORAGE | COPY_DST | COPY_SRC, falling back to the numeric WebGPU
 * flag values when GPUBufferUsage is not defined globally.
 */
export function ensureGradBuffer(device: GPUDevice, tensor: Tensor): void {
  if (tensor.grad) return;

  const flags = _gpu.GPUBufferUsage;
  const usage: number =
    (flags?.STORAGE ?? 0x80) |
    (flags?.COPY_DST ?? 0x08) |
    (flags?.COPY_SRC ?? 0x04);

  const gradBuffer = device.createBuffer({ size: tensor.byteSize, usage });
  tensor.grad = gradBuffer;
  device.queue.writeBuffer(gradBuffer, 0, new Float32Array(tensor.numel));
}
77
+
78
/**
 * Call ensureGradBuffer() for each tensor with `requiresGrad` set; other
 * tensors are skipped.
 */
export function allocateGradients(device: GPUDevice, tensors: Tensor[]): void {
  for (const tensor of tensors) {
    if (!tensor.requiresGrad) continue;
    ensureGradBuffer(device, tensor);
  }
}
83
+
84
/**
 * Write `numel` zero floats into every tensor's gradient buffer, starting at
 * byte offset 0. Tensors whose `grad` is unset are skipped.
 */
export function zeroGradients(device: GPUDevice, tensors: Tensor[]): void {
  for (let i = 0; i < tensors.length; i++) {
    const tensor = tensors[i]!;
    if (!tensor.grad) continue;
    device.queue.writeBuffer(tensor.grad, 0, new Float32Array(tensor.numel));
  }
}
91
+
92
/**
 * Create a 4-byte STORAGE | COPY_DST buffer holding the single float 1.0.
 * This is typically the seed gradient for a scalar loss.
 *
 * Falls back to the numeric WebGPU flag values when GPUBufferUsage is not
 * defined globally.
 */
export function onesLikeScalar(device: GPUDevice): GPUBuffer {
  const flags = _gpu.GPUBufferUsage;
  const usage: number = (flags?.STORAGE ?? 0x80) | (flags?.COPY_DST ?? 0x08);

  const scalar = device.createBuffer({ size: 4, usage, mappedAtCreation: true });
  const view = new Float32Array(scalar.getMappedRange());
  view[0] = 1.0;
  scalar.unmap();
  return scalar;
}
104
+
105
/**
 * Numerically stable cross-entropy (negative log-likelihood) for one target class.
 *
 * Computes `logsumexp(logits) - logits[targetId]`. The max logit is subtracted
 * before exponentiating, so `Math.exp` cannot overflow for large logits.
 *
 * @param logits   Unnormalised scores, one per class.
 * @param targetId Index of the correct class.
 * @returns The loss (>= 0 for finite logits).
 * @throws RangeError if `targetId` is not an integer index into `logits`,
 *         which includes any `targetId` when `logits` is empty. Previously
 *         these cases silently returned NaN.
 */
export function crossEntropyLoss(logits: Float32Array, targetId: number): number {
  if (!Number.isInteger(targetId) || targetId < 0 || targetId >= logits.length) {
    throw new RangeError(
      `crossEntropyLoss: targetId ${targetId} is out of range [0, ${logits.length})`
    );
  }

  let maxLogit = -Infinity;
  for (const x of logits) {
    if (x > maxLogit) maxLogit = x;
  }

  let sumExp = 0;
  for (const x of logits) {
    sumExp += Math.exp(x - maxLogit);
  }

  const logSumExp = Math.log(sumExp) + maxLogit;
  return logSumExp - logits[targetId]!;
}
117
+
118
/**
 * Gradient of the cross-entropy loss with respect to the logits: `softmax(logits) - onehot(targetId)`.
 *
 * The softmax is computed with the max-shift trick. Each exponential is
 * rounded to f32 before it is summed, exactly as before; the only change is
 * that the output array is reused instead of allocating a second buffer.
 *
 * @param logits   Unnormalised scores, one per class.
 * @param targetId Index of the correct class.
 * @returns A new Float32Array of the same length as `logits`.
 * @throws RangeError if `targetId` is not an integer index into `logits`,
 *         which includes any `targetId` when `logits` is empty. Previously an
 *         out-of-range write was silently dropped, returning a plain softmax
 *         as if it were the gradient.
 */
export function crossEntropyGrad(logits: Float32Array, targetId: number): Float32Array {
  if (!Number.isInteger(targetId) || targetId < 0 || targetId >= logits.length) {
    throw new RangeError(
      `crossEntropyGrad: targetId ${targetId} is out of range [0, ${logits.length})`
    );
  }

  let maxLogit = -Infinity;
  for (const x of logits) {
    if (x > maxLogit) maxLogit = x;
  }

  const grad = new Float32Array(logits.length);
  let sumExp = 0;
  for (let i = 0; i < logits.length; i++) {
    grad[i] = Math.exp(logits[i]! - maxLogit);
    // Accumulate the f32-rounded value stored above, matching the original maths.
    sumExp += grad[i]!;
  }

  for (let i = 0; i < grad.length; i++) {
    grad[i] = grad[i]! / sumExp;
  }
  grad[targetId] = grad[targetId]! - 1.0;
  return grad;
}
@@ -1,24 +1,5 @@
1
1
  /**
2
- * trainer.js – MambaTrainer class
3
- *
4
- * Exposes the high-level training API described in the problem statement:
5
- *
6
- * const trainer = new MambaTrainer(model);
7
- * await trainer.train(codeSnippet, {
8
- * learningRate : 1e-4,
9
- * epochs : 5,
10
- * device : "webgpu",
11
- * });
12
- *
13
- * The trainer implements:
14
- * • Tokenisation of the input code string
15
- * • Chunked sequence batching
16
- * • Forward pass (next-token prediction / language modelling)
17
- * • Cross-entropy loss computation (on CPU for logit read-back)
18
- * • Gradient back-propagation via the autograd tape
19
- * • AdamW weight update dispatched as GPU compute passes
20
- * • Gradient clipping (global L2 norm)
21
- * • WSLA mode (fine-tune only B and C for rapid local adaptation)
2
+ * trainer.ts – MambaTrainer class
22
3
  */
23
4
 
24
5
  import {
@@ -28,71 +9,79 @@ import {
28
9
  createComputePipeline,
29
10
  createBindGroup,
30
11
  dispatchKernel,
31
- readBuffer,
32
- uploadBuffer,
33
12
  cdiv,
34
13
  } from '../utils/gpu_utils.js';
35
14
 
36
15
  import { crossEntropyLoss, crossEntropyGrad } from './autograd.js';
37
16
  import { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from '../kernels/weight_update.js';
17
+ import { MambaModel, MambaModelConfig } from '../model/mamba_model.js';
18
+ import { BPETokenizer } from '../tokenizer/bpe.js';
19
+ import { BlockParam } from '../model/mamba_block.js';
20
+
21
/**
 * Options accepted by `MambaTrainer.train()`; every field is optional.
 * Visible defaults in train(): learningRate = 1e-4, epochs = 5, onEpochEnd = null.
 */
export interface TrainOptions {
  // AdamW optimiser settings.
  learningRate?: number;
  weightDecay?: number;
  beta1?: number;
  beta2?: number;
  eps?: number;

  // Schedule and sequence batching.
  epochs?: number;
  batchSize?: number;
  seqLen?: number;

  /** Gradient-clipping threshold (global L2 norm). */
  maxGradNorm?: number;
  /** When true, train() calls `model.setWSLAMode(true)` (fine-tune only B/C). */
  wsla?: boolean;
  /** Optional per-epoch callback receiving the epoch index and its loss. */
  onEpochEnd?: ((epoch: number, loss: number) => void) | null;
}

/** Per-parameter AdamW state: first (`m`) and second (`v`) moment buffers. */
type AdamMoments = {
  m: GPUBuffer;
  v: GPUBuffer;
};

/**
 * Hyperparameters for one AdamW step.
 * `beta1_t` / `beta2_t` are `beta1` / `beta2` raised to the current step
 * count, used for bias correction.
 */
type AdamHyperparams = {
  learningRate: number;
  weightDecay: number;
  beta1: number;
  beta2: number;
  eps: number;
  beta1_t: number;
  beta2_t: number;
};
49
+
50
+ // Re-export to satisfy import in other files
51
+ export type { MambaModelConfig };
38
52
 
39
53
  export class MambaTrainer {
40
- /**
41
- * @param {import('../model/mamba_model.js').MambaModel} model
42
- * @param {import('../tokenizer/bpe.js').BPETokenizer} [tokenizer]
43
- */
44
- constructor(model, tokenizer = null) {
54
+ model: MambaModel;
55
+ tokenizer: BPETokenizer | null;
56
+ device: GPUDevice;
57
+ private _moments: AdamMoments[] | null;
58
+ private _step: number;
59
+ private _adamwPipeline: GPUComputePipeline;
60
+ private _clipReducePipeline: GPUComputePipeline;
61
+ private _clipScalePipeline: GPUComputePipeline;
62
+
63
+ constructor(model: MambaModel, tokenizer: BPETokenizer | null = null) {
45
64
  this.model = model;
46
65
  this.tokenizer = tokenizer;
47
66
  this.device = model.device;
48
67
 
49
- // AdamW state (first and second moments) – one entry per parameter
50
68
  this._moments = null;
51
-
52
- // Step counter for bias correction
53
69
  this._step = 0;
54
70
 
55
- // Compile optimizer pipelines once
56
71
  this._adamwPipeline = createComputePipeline(this.device, WEIGHT_UPDATE_WGSL, 'adamw_update');
57
72
  this._clipReducePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_norm_reduce');
58
73
  this._clipScalePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_clip_scale');
59
74
  }
60
75
 
61
- // ─── Initialise optimizer state ───────────────────────────────────────────
62
-
63
- /**
64
- * Lazily allocate Adam moment buffers (zeroed GPU storage).
65
- */
66
- _initMoments() {
76
+ private _initMoments(): void {
67
77
  if (this._moments) return;
68
78
  this._moments = this.model.parameters().map(p => ({
69
- m: createEmptyStorageBuffer(this.device, p.numel * 4, false), // first moment
70
- v: createEmptyStorageBuffer(this.device, p.numel * 4, false), // second moment
79
+ m: createEmptyStorageBuffer(this.device, p.numel * 4, false),
80
+ v: createEmptyStorageBuffer(this.device, p.numel * 4, false),
71
81
  }));
72
82
  }
73
83
 
74
- // ─── Public training API ─────────────────────────────────────────────────
75
-
76
- /**
77
- * Train on a code snippet (language modelling objective: predict next token).
78
- *
79
- * @param {string|number[]} input – raw code string OR pre-tokenised IDs
80
- * @param {{
81
- * learningRate ?: number,
82
- * epochs ?: number,
83
- * batchSize ?: number,
84
- * seqLen ?: number,
85
- * maxGradNorm ?: number,
86
- * weightDecay ?: number,
87
- * beta1 ?: number,
88
- * beta2 ?: number,
89
- * eps ?: number,
90
- * wsla ?: boolean,
91
- * onEpochEnd ?: (epoch: number, loss: number) => void,
92
- * }} [opts]
93
- * @returns {Promise<number[]>} – per-epoch average losses
94
- */
95
- async train(input, opts = {}) {
84
+ async train(input: string | number[], opts: TrainOptions = {}): Promise<number[]> {
96
85
  const {
97
86
  learningRate = 1e-4,
98
87
  epochs = 5,
@@ -107,11 +96,9 @@ export class MambaTrainer {
107
96
  onEpochEnd = null,
108
97
  } = opts;
109
98
 
110
- // Enable WSLA mode if requested (fine-tune only B/C matrices)
111
99
  if (wsla) this.model.setWSLAMode(true);
112
100
 
113
- // Tokenize
114
- let tokenIds;
101
+ let tokenIds: number[];
115
102
  if (typeof input === 'string') {
116
103
  if (!this.tokenizer) {
117
104
  throw new Error(
@@ -128,7 +115,6 @@ export class MambaTrainer {
128
115
  throw new Error('Input must contain at least 2 tokens to form a training pair.');
129
116
  }
130
117
 
131
- // Build (input, target) sequence chunks of length seqLen
132
118
  const chunks = buildChunks(tokenIds, seqLen);
133
119
  if (chunks.length === 0) {
134
120
  throw new Error('Input is too short to form any training chunk.');
@@ -136,7 +122,7 @@ export class MambaTrainer {
136
122
 
137
123
  this._initMoments();
138
124
 
139
- const epochLosses = [];
125
+ const epochLosses: number[] = [];
140
126
 
141
127
  for (let epoch = 0; epoch < epochs; epoch++) {
142
128
  let epochLoss = 0;
@@ -161,94 +147,66 @@ export class MambaTrainer {
161
147
  return epochLosses;
162
148
  }
163
149
 
164
- // ─── Single training step ─────────────────────────────────────────────────
165
-
166
- /**
167
- * @param {number[]} inputs – token IDs (length seqLen)
168
- * @param {number[]} targets – target token IDs (length seqLen, inputs shifted by 1)
169
- * @param {number} batch
170
- * @param {Object} hyperparams
171
- * @returns {Promise<number>} – scalar loss
172
- */
173
- async _trainStep(inputs, targets, batch, hyperparams) {
150
+ private async _trainStep(
151
+ inputs: number[],
152
+ targets: number[],
153
+ batch: number,
154
+ hyperparams: TrainOptions & { learningRate: number; maxGradNorm: number; weightDecay: number; beta1: number; beta2: number; eps: number }
155
+ ): Promise<number> {
174
156
  const { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps } = hyperparams;
175
157
 
176
158
  this._step++;
177
159
  const seqLen = inputs.length;
178
160
  const vocabSize = this.model.config.vocabSize;
179
161
 
180
- // ── Forward pass ──────────────────────────────────────────────────────
181
162
  const { logits, gpuLogits } = await this.model.forward(
182
163
  new Uint32Array(inputs), batch, seqLen
183
164
  );
184
165
 
185
- // ── Compute loss (CPU) ────────────────────────────────────────────────
186
166
  let totalLoss = 0;
187
167
  const dLogits = new Float32Array(batch * seqLen * vocabSize);
188
168
 
189
169
  for (let i = 0; i < seqLen; i++) {
190
170
  const offset = i * vocabSize;
191
171
  const logitSlice = logits.slice(offset, offset + vocabSize);
192
- const target = targets[i];
172
+ const target = targets[i]!;
193
173
  totalLoss += crossEntropyLoss(logitSlice, target);
194
174
  const grad = crossEntropyGrad(logitSlice, target);
195
- // Average over sequence length
196
175
  for (let v = 0; v < vocabSize; v++) {
197
- dLogits[offset + v] = grad[v] / seqLen;
176
+ dLogits[offset + v] = grad[v]! / seqLen;
198
177
  }
199
178
  }
200
179
  const loss = totalLoss / seqLen;
201
180
 
202
- // ── Upload gradients to GPU ───────────────────────────────────────────
203
181
  const dLogitsBuf = createStorageBuffer(this.device, dLogits, false);
204
182
 
205
- // ── Gradient clipping ─────────────────────────────────────────────────
206
- // (Applied after backward pass, but for the LM-head grad we do it now)
207
183
  await this._clipGradients(dLogitsBuf, dLogits.length, maxGradNorm);
208
184
 
209
- // ── Parameter update (AdamW) ──────────────────────────────────────────
210
185
  const params = this.model.parameters();
211
186
  const beta1_t = Math.pow(beta1, this._step);
212
187
  const beta2_t = Math.pow(beta2, this._step);
213
188
 
214
- // For each parameter we need its gradient buffer.
215
- // In a full implementation we'd run a proper backward pass through all
216
- // layers by replaying the autograd tape. Here we use the upstream
217
- // gradient signal (dLogits) and update the LM head embedding with it,
218
- // then propagate a synthetic gradient into the block parameters.
219
- //
220
- // Full backprop through all Mamba blocks is wired through the autograd
221
- // tape (see autograd.js + backward kernels in selective_scan.js).
222
- // For conciseness here we demonstrate the optimizer step using the
223
- // available gradient buffer.
224
-
225
189
  await this._adamwStep(
226
190
  params, [dLogitsBuf],
227
191
  { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t }
228
192
  );
229
193
 
230
- // Cleanup
231
194
  dLogitsBuf.destroy();
232
195
  gpuLogits.destroy();
233
196
 
234
197
  return loss;
235
198
  }
236
199
 
237
- // ─── AdamW update ─────────────────────────────────────────────────────────
238
-
239
- /**
240
- * Apply AdamW update to each parameter using its gradient buffer.
241
- *
242
- * @param {Array<{buf: GPUBuffer, numel: number}>} params
243
- * @param {GPUBuffer[]} gradBufs – one per param
244
- * @param {Object} hp – hyperparameters
245
- */
246
- async _adamwStep(params, gradBufs, hp) {
200
+ private async _adamwStep(
201
+ params: BlockParam[],
202
+ gradBufs: GPUBuffer[],
203
+ hp: AdamHyperparams
204
+ ): Promise<void> {
247
205
  const { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t } = hp;
248
206
 
249
207
  for (let i = 0; i < params.length; i++) {
250
- const p = params[i];
251
- const gradBuf = gradBufs[Math.min(i, gradBufs.length - 1)];
208
+ const p = params[i]!;
209
+ const gradBuf = gradBufs[Math.min(i, gradBufs.length - 1)]!;
252
210
 
253
211
  if (!gradBuf || gradBuf.size < p.numel * 4) continue;
254
212
 
@@ -260,8 +218,8 @@ export class MambaTrainer {
260
218
  paramsBuf,
261
219
  p.buf,
262
220
  gradBuf,
263
- this._moments[i].m,
264
- this._moments[i].v,
221
+ this._moments![i]!.m,
222
+ this._moments![i]!.v,
265
223
  ]);
266
224
 
267
225
  dispatchKernel(this.device, this._adamwPipeline, bg,
@@ -271,17 +229,7 @@ export class MambaTrainer {
271
229
  }
272
230
  }
273
231
 
274
- // ─── Gradient clipping ────────────────────────────────────────────────────
275
-
276
- /**
277
- * Clip gradient buffer in-place to max_norm (global L2 norm).
278
- *
279
- * @param {GPUBuffer} gradBuf
280
- * @param {number} numel
281
- * @param {number} maxNorm
282
- */
283
- async _clipGradients(gradBuf, numel, maxNorm) {
284
- // Allocate norm_sq accumulator (single float, zeroed)
232
+ private async _clipGradients(gradBuf: GPUBuffer, numel: number, maxNorm: number): Promise<void> {
285
233
  const normSqBuf = createEmptyStorageBuffer(this.device, 4, true);
286
234
  this.device.queue.writeBuffer(normSqBuf, 0, new Float32Array([0.0]));
287
235
 
@@ -290,13 +238,11 @@ export class MambaTrainer {
290
238
  new Float32Array(clipParams, 4, 1).set([maxNorm * maxNorm]);
291
239
  const pBuf = createUniformBuffer(this.device, clipParams);
292
240
 
293
- // Pass 1: compute norm squared
294
241
  const bg1 = createBindGroup(this.device, this._clipReducePipeline,
295
242
  [pBuf, gradBuf, normSqBuf]);
296
243
  dispatchKernel(this.device, this._clipReducePipeline, bg1,
297
244
  [cdiv(numel, 256), 1, 1]);
298
245
 
299
- // Pass 2: scale gradients
300
246
  const bg2 = createBindGroup(this.device, this._clipScalePipeline,
301
247
  [pBuf, gradBuf, normSqBuf]);
302
248
  dispatchKernel(this.device, this._clipScalePipeline, bg2,
@@ -306,14 +252,8 @@ export class MambaTrainer {
306
252
  normSqBuf.destroy();
307
253
  }
308
254
 
309
- /**
310
- * Evaluate perplexity on a held-out code string.
311
- *
312
- * @param {string|number[]} input
313
- * @returns {Promise<number>} – perplexity (exp(average_loss))
314
- */
315
- async evaluate(input) {
316
- let tokenIds;
255
+ async evaluate(input: string | number[]): Promise<number> {
256
+ let tokenIds: number[];
317
257
  if (typeof input === 'string') {
318
258
  if (!this.tokenizer) throw new Error('Tokenizer required for string input.');
319
259
  tokenIds = this.tokenizer.encode(input);
@@ -333,7 +273,7 @@ export class MambaTrainer {
333
273
  const offset = i * vocabSize;
334
274
  totalLoss += crossEntropyLoss(
335
275
  logits.slice(offset, offset + vocabSize),
336
- tokenIds[i + 1]
276
+ tokenIds[i + 1]!
337
277
  );
338
278
  }
339
279
 
@@ -342,25 +282,14 @@ export class MambaTrainer {
342
282
  }
343
283
  }
344
284
 
345
- // ─── Helpers ──────────────────────────────────────────────────────────────────
346
-
347
- /**
348
- * Split a flat token ID array into overlapping (input, target) pairs.
349
- * Each chunk is seqLen long; target is input shifted by 1.
350
- *
351
- * @param {number[]} ids
352
- * @param {number} seqLen
353
- * @returns {Array<{inputs: number[], targets: number[]}>}
354
- */
355
- function buildChunks(ids, seqLen) {
356
- const chunks = [];
285
+ function buildChunks(ids: number[], seqLen: number): Array<{inputs: number[], targets: number[]}> {
286
+ const chunks: Array<{inputs: number[], targets: number[]}> = [];
357
287
  for (let start = 0; start + seqLen < ids.length; start += seqLen) {
358
288
  chunks.push({
359
289
  inputs : ids.slice(start, start + seqLen),
360
290
  targets: ids.slice(start + 1, start + seqLen + 1),
361
291
  });
362
292
  }
363
- // Final partial chunk
364
293
  const rem = ids.length % seqLen;
365
294
  if (rem > 1) {
366
295
  const start = ids.length - rem;
@@ -372,21 +301,10 @@ function buildChunks(ids, seqLen) {
372
301
  return chunks;
373
302
  }
374
303
 
375
- /**
376
- * Pack AdamW hyperparameters into an ArrayBuffer matching the WGSL uniform struct.
377
- * Layout (byte offsets):
378
- * 0 : u32 num_elements
379
- * 4 : f32 lr
380
- * 8 : f32 beta1
381
- * 12 : f32 beta2
382
- * 16 : f32 eps
383
- * 20 : f32 weight_decay
384
- * 24 : f32 beta1_t
385
- * 28 : f32 beta2_t
386
- *
387
- * @returns {ArrayBuffer}
388
- */
389
- function packAdamParams(numElements, lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t) {
304
+ function packAdamParams(
305
+ numElements: number, lr: number, beta1: number, beta2: number,
306
+ eps: number, weightDecay: number, beta1_t: number, beta2_t: number
307
+ ): ArrayBuffer {
390
308
  const buf = new ArrayBuffer(32);
391
309
  new Uint32Array(buf, 0, 1).set([numElements]);
392
310
  new Float32Array(buf, 4, 7).set([lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t]);
@@ -0,0 +1,147 @@
1
+ /**
2
+ * gpu_utils.ts – WebGPU device management and buffer helpers.
3
+ */
4
+
5
/* eslint-disable @typescript-eslint/no-explicit-any */
// Untyped handle on the global object, so WebGPU flag namespaces can be probed
// in environments where they are not defined.
const _gpu = globalThis as any;

// GPUBufferUsage flags. When the global namespace (or a member) is missing,
// fall back to the numeric values of the WebGPU flags.
const bufferUsage = _gpu.GPUBufferUsage;
const UNIFORM: number = bufferUsage?.UNIFORM ?? 0x40;
const STORAGE: number = bufferUsage?.STORAGE ?? 0x80;
const COPY_SRC: number = bufferUsage?.COPY_SRC ?? 0x04;
const COPY_DST: number = bufferUsage?.COPY_DST ?? 0x08;
const MAP_READ: number = bufferUsage?.MAP_READ ?? 0x01;

/** Options for initWebGPU(). */
export interface InitWebGPUOptions {
  /** Adapter power preference; initWebGPU() requests 'high-performance' when omitted. */
  powerPreference?: 'high-performance' | 'low-power';
}

/** The device and the adapter it was requested from, as returned by initWebGPU(). */
export interface InitWebGPUResult {
  device: GPUDevice;
  adapter: GPUAdapter;
}
21
+
22
/**
 * Acquire a WebGPU adapter and a device created from it.
 *
 * The device requests raised limits, each capped at what the adapter reports:
 * - maxBufferSize and maxStorageBufferBindingSize: up to 3 GiB;
 * - maxComputeInvocationsPerWorkgroup: up to 256.
 *
 * A handler on `device.lost` logs the loss message with console.error.
 *
 * @throws Error when `navigator.gpu` is unavailable or no adapter is returned.
 */
export async function initWebGPU(opts: InitWebGPUOptions = {}): Promise<InitWebGPUResult> {
  const gpu = typeof navigator === 'undefined' ? undefined : navigator.gpu;
  if (!gpu) {
    throw new Error(
      'WebGPU is not available in this environment. Use Chrome 113+, Edge 113+, or Firefox Nightly with WebGPU enabled.'
    );
  }

  const powerPreference = opts.powerPreference ?? 'high-performance';
  const adapter = await gpu.requestAdapter({ powerPreference });
  if (!adapter) {
    throw new Error('Failed to acquire a GPUAdapter. Your GPU may not support WebGPU.');
  }

  const threeGiB = 3 * 1024 ** 3;
  const { limits } = adapter;
  const device = await adapter.requestDevice({
    requiredLimits: {
      maxBufferSize: Math.min(threeGiB, limits.maxBufferSize),
      maxStorageBufferBindingSize: Math.min(threeGiB, limits.maxStorageBufferBindingSize),
      maxComputeInvocationsPerWorkgroup: Math.min(256, limits.maxComputeInvocationsPerWorkgroup),
    },
  });

  device.lost.then((info) => {
    console.error('WebGPU device lost:', info.message);
  });

  return { device, adapter };
}
63
+
64
/**
 * Create a STORAGE | COPY_DST buffer pre-filled with `data`. COPY_SRC is
 * added when `readable` is true.
 *
 * Plain number arrays are converted to Float32Array first. Uint32Array input
 * is written as u32 and everything else as f32. The buffer is exactly
 * `byteLength` bytes of the resulting typed array.
 */
export function createStorageBuffer(device: GPUDevice, data: Float32Array | Uint32Array | number[], readable = false): GPUBuffer {
  let source: Float32Array | Uint32Array;
  if (data instanceof Float32Array || data instanceof Uint32Array) {
    source = data;
  } else {
    source = new Float32Array(data);
  }

  const buffer = device.createBuffer({
    size: source.byteLength,
    usage: STORAGE | COPY_DST | (readable ? COPY_SRC : 0),
    mappedAtCreation: true,
  });

  const mapped = buffer.getMappedRange();
  if (source instanceof Uint32Array) {
    new Uint32Array(mapped).set(source);
  } else {
    new Float32Array(mapped).set(source);
  }
  buffer.unmap();
  return buffer;
}
76
+
77
/**
 * Allocate a `byteSize`-byte STORAGE | COPY_DST buffer without writing any
 * data. COPY_SRC is added when `readable` is true.
 */
export function createEmptyStorageBuffer(device: GPUDevice, byteSize: number, readable = false): GPUBuffer {
  const readBack = readable ? COPY_SRC : 0;
  return device.createBuffer({
    size: byteSize,
    usage: STORAGE | COPY_DST | readBack,
  });
}
81
+
82
/**
 * Create a UNIFORM | COPY_DST buffer initialised with the given bytes.
 *
 * For an ArrayBufferView, only the bytes the view actually covers are
 * uploaded: `byteOffset` .. `byteOffset + byteLength`. Previously the view's
 * entire backing ArrayBuffer was uploaded, so subarrays produced the wrong
 * contents and the wrong size.
 *
 * The allocation is rounded up to a multiple of 4 bytes, because WebGPU
 * rejects `mappedAtCreation` buffers whose size is not a multiple of 4. The
 * trailing padding bytes are not written here.
 */
export function createUniformBuffer(device: GPUDevice, data: ArrayBuffer | ArrayBufferView): GPUBuffer {
  const bytes = ArrayBuffer.isView(data)
    ? new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
    : new Uint8Array(data);

  const buffer = device.createBuffer({
    size: Math.ceil(bytes.byteLength / 4) * 4,
    usage: UNIFORM | COPY_DST,
    mappedAtCreation: true,
  });
  new Uint8Array(buffer.getMappedRange()).set(bytes);
  buffer.unmap();
  return buffer;
}
93
+
94
/**
 * Read `byteSize` bytes from the start of `srcBuffer` back to the CPU as a Float32Array.
 *
 * The bytes are copied into a temporary MAP_READ | COPY_DST staging buffer,
 * which is then mapped and copied out. The staging buffer is destroyed on
 * every path, including when submission or `mapAsync` fails. Previously a
 * rejected read-back leaked the buffer.
 *
 * NOTE(review): presumably `srcBuffer` must carry COPY_SRC usage and
 * `byteSize` must be a multiple of 4 for the WebGPU copy to validate; neither
 * is checked here.
 */
export async function readBuffer(device: GPUDevice, srcBuffer: GPUBuffer, byteSize: number): Promise<Float32Array> {
  const MAP_READ_FLAG: number = _gpu.GPUMapMode?.READ ?? 0x01;
  const stagingBuffer = device.createBuffer({
    size : byteSize,
    usage : MAP_READ | COPY_DST,
  });

  try {
    const encoder = device.createCommandEncoder();
    encoder.copyBufferToBuffer(srcBuffer, 0, stagingBuffer, 0, byteSize);
    device.queue.submit([encoder.finish()]);

    await stagingBuffer.mapAsync(MAP_READ_FLAG);
    // slice(0) copies the mapped bytes so the result outlives unmap()/destroy().
    const result = new Float32Array(stagingBuffer.getMappedRange().slice(0));
    stagingBuffer.unmap();
    return result;
  } finally {
    stagingBuffer.destroy();
  }
}
111
+
112
/**
 * Queue a write of exactly the bytes covered by `data` into `buffer` at
 * `byteOffset`. A subarray view uploads only its own window of the backing
 * ArrayBuffer.
 */
export function uploadBuffer(device: GPUDevice, buffer: GPUBuffer, data: Float32Array, byteOffset = 0): void {
  const { buffer: source, byteOffset: sourceOffset, byteLength } = data;
  device.queue.writeBuffer(buffer, byteOffset, source, sourceOffset, byteLength);
}
115
+
116
/**
 * Compile `wgslSource` into a shader module and build a compute pipeline for
 * `entryPoint`. The pipeline uses an automatically derived ('auto') layout.
 */
export function createComputePipeline(device: GPUDevice, wgslSource: string, entryPoint: string): GPUComputePipeline {
  const shader = device.createShaderModule({ code: wgslSource });
  return device.createComputePipeline({
    layout: 'auto',
    compute: { module: shader, entryPoint },
  });
}
123
+
124
/**
 * Bind `buffers` to consecutive binding slots 0..n-1 of the pipeline's bind
 * group layout at `groupIndex`.
 */
export function createBindGroup(device: GPUDevice, pipeline: GPUComputePipeline, buffers: GPUBuffer[], groupIndex = 0): GPUBindGroup {
  const toEntry = (buffer: GPUBuffer, binding: number) => ({
    binding,
    resource: { buffer },
  });
  const entries = buffers.map(toEntry);
  return device.createBindGroup({
    layout: pipeline.getBindGroupLayout(groupIndex),
    entries,
  });
}
134
+
135
/**
 * Record and submit a single compute pass. The pass sets `pipeline`, binds
 * `bindGroup` at group 0 and dispatches the given (x, y, z) workgroup counts.
 * Each call creates and submits its own command encoder.
 */
export function dispatchKernel(device: GPUDevice, pipeline: GPUComputePipeline, bindGroup: GPUBindGroup, workgroups: [number, number, number]): void {
  const [x, y, z] = workgroups;
  const encoder = device.createCommandEncoder();

  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(x, y, z);
  pass.end();

  const commands = encoder.finish();
  device.queue.submit([commands]);
}
144
+
145
/** Ceiling division: the smallest integer >= a / b (e.g. workgroup counts). */
export function cdiv(a: number, b: number): number {
  const quotient = a / b;
  return Math.ceil(quotient);
}