mambacode.js 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +196 -0
- package/package.json +54 -0
- package/src/index.js +89 -0
- package/src/kernels/activations.js +88 -0
- package/src/kernels/conv1d.js +153 -0
- package/src/kernels/linear_projection.js +220 -0
- package/src/kernels/selective_scan.js +350 -0
- package/src/kernels/weight_update.js +120 -0
- package/src/model/mamba_block.js +443 -0
- package/src/model/mamba_model.js +335 -0
- package/src/tokenizer/bpe.js +256 -0
- package/src/training/autograd.js +221 -0
- package/src/training/trainer.js +394 -0
- package/src/utils/gpu_utils.js +217 -0
- package/src/utils/quantization.js +215 -0
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* autograd.js – Lightweight tape-based automatic differentiation engine.
|
|
3
|
+
*
|
|
4
|
+
* Design
|
|
5
|
+
* ------
|
|
6
|
+
* Every differentiable GPU operation appends an entry to a global "tape"
|
|
7
|
+
* (a reverse-mode AD record). During the backward pass we replay the tape
|
|
8
|
+
* in reverse, dispatching backward GPU kernels that accumulate gradients
|
|
9
|
+
* into per-parameter gradient buffers.
|
|
10
|
+
*
|
|
11
|
+
* A "Tensor" in this context is a thin wrapper that holds:
|
|
12
|
+
* - a GPUBuffer (the data)
|
|
13
|
+
* - shape metadata
|
|
14
|
+
* - an optional gradient GPUBuffer
|
|
15
|
+
* - a reference to the tape node that produced it
|
|
16
|
+
*
|
|
17
|
+
* The tape stores closures so that complex operations (selective scan,
|
|
18
|
+
* conv, linear) can have their own custom backward logic.
|
|
19
|
+
*/
|
|
20
|
+
|
|
21
|
+
/**
 * Reverse-mode tape state (module-private).
 * @type {TapeEntry[]}
 */
let _tape = [];
let _gradEnabled = true;

/**
 * @typedef {Object} TapeEntry
 * @property {() => void} backward – closure that computes and accumulates gradients
 */

/**
 * Tensor – thin wrapper around a GPUBuffer carrying shape, an optional
 * gradient buffer, and a back-reference into the autograd tape.
 */
export class Tensor {
  /**
   * @param {GPUBuffer} data – GPU buffer holding the tensor values (FP32)
   * @param {number[]} shape – dimensions, e.g. [batch, seqLen, dInner]
   * @param {boolean} [requiresGrad=false] – whether a gradient buffer should
   *   be allocated for this tensor during training
   */
  constructor(data, shape, requiresGrad = false) {
    this.data = data;
    this.shape = shape;
    // Total element count = product of all dimensions (1 for a scalar shape []).
    let count = 1;
    for (const dim of shape) {
      count *= dim;
    }
    this.numel = count;
    this.requiresGrad = requiresGrad;
    this.grad = null;    // gradient GPUBuffer, lazily populated during backward()
    this._gradFn = null; // index of the tape entry that produced this tensor
  }

  /** Number of bytes occupied by this tensor (FP32 ⇒ 4 bytes per element). */
  get byteSize() {
    return 4 * this.numel;
  }

  /**
   * Manually zero-out the gradient buffer (keeps the GPUBuffer allocated).
   * No-op when no gradient buffer has been allocated yet.
   * @param {GPUDevice} device
   */
  zeroGrad(device) {
    if (!this.grad) return;
    device.queue.writeBuffer(this.grad, 0, new Float32Array(this.numel));
  }

  /** Free GPU memory for both data and grad buffers and null the references. */
  destroy() {
    if (this.data) this.data.destroy();
    if (this.grad) this.grad.destroy();
    this.data = null;
    this.grad = null;
  }
}

// ─── Tape control ─────────────────────────────────────────────────────────────

/** Start recording operations onto the tape. */
export function enableGrad() {
  _gradEnabled = true;
}

/** Stop recording (inference-only mode). */
export function noGrad() {
  _gradEnabled = false;
}

/** Discard all recorded operations without running backward. */
export function clearTape() {
  _tape = [];
}

/**
 * Register a backward closure onto the tape.
 * Called internally by differentiable operations.
 *
 * @param {() => void} backwardFn
 * @returns {number} tape index (for reference by the output Tensor),
 *   or -1 when gradient recording is disabled
 */
export function recordOperation(backwardFn) {
  if (_gradEnabled === false) return -1;
  const index = _tape.length;
  _tape.push({ backward: backwardFn });
  return index;
}

// ─── Backward pass ────────────────────────────────────────────────────────────

/**
 * Run the backward pass by replaying the tape in reverse order.
 * Gradients accumulate into the `.grad` GPUBuffers of leaf tensors.
 *
 * The tape is cleared automatically once the replay finishes.
 */
export async function backward() {
  let i = _tape.length;
  while (i-- > 0) {
    await _tape[i].backward();
  }
  clearTape();
}
|
|
108
|
+
|
|
109
|
+
// ─── Gradient buffer management ───────────────────────────────────────────────

/**
 * Ensure a Tensor has an allocated (zeroed) gradient buffer.
 * Idempotent: an already-present grad buffer is left untouched.
 *
 * @param {GPUDevice} device
 * @param {Tensor} tensor
 */
export function ensureGradBuffer(device, tensor) {
  if (tensor.grad) return; // already allocated – nothing to do

  const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
  tensor.grad = device.createBuffer({
    size : tensor.byteSize,
    usage: usage,
  });
  // Explicitly zero-initialise so the first accumulation starts from 0.
  device.queue.writeBuffer(tensor.grad, 0, new Float32Array(tensor.numel));
}

/**
 * Allocate gradient buffers for every tensor in the list that requires one.
 *
 * @param {GPUDevice} device
 * @param {Tensor[]} tensors
 */
export function allocateGradients(device, tensors) {
  tensors
    .filter((t) => t.requiresGrad)
    .forEach((t) => ensureGradBuffer(device, t));
}
|
|
139
|
+
|
|
140
|
+
/**
 * Zero every allocated gradient buffer in-place (GPU write).
 * Tensors whose grad buffer was never allocated are skipped.
 *
 * @param {GPUDevice} device
 * @param {Tensor[]} tensors
 */
export function zeroGradients(device, tensors) {
  for (const tensor of tensors) {
    if (!tensor.grad) continue; // nothing allocated → nothing to clear
    device.queue.writeBuffer(tensor.grad, 0, new Float32Array(tensor.numel));
  }
}
|
|
153
|
+
|
|
154
|
+
// ─── Loss helpers ─────────────────────────────────────────────────────────────

/**
 * Create a scalar "1.0" gradient tensor to seed the backward pass.
 * (Equivalent to calling loss.backward() with grad=1.)
 *
 * @param {GPUDevice} device
 * @returns {GPUBuffer} – single-element FP32 buffer containing 1.0
 */
export function onesLikeScalar(device) {
  const descriptor = {
    size : 4,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
    mappedAtCreation: true,
  };
  const seed = device.createBuffer(descriptor);
  // Write 1.0 directly through the mapped range, then hand the buffer to the GPU.
  const view = new Float32Array(seed.getMappedRange());
  view[0] = 1.0;
  seed.unmap();
  return seed;
}
|
|
173
|
+
|
|
174
|
+
/**
 * Cross-entropy loss (computed on CPU after reading back logits).
 * Uses the max-shift log-sum-exp trick so large logits do not overflow.
 * Returns a scalar JS number: logSumExp(logits) - logits[targetId].
 *
 * @param {Float32Array} logits – (vocabSize,)
 * @param {number} targetId – correct token index
 * @returns {number}
 */
export function crossEntropyLoss(logits, targetId) {
  // Shift by the maximum logit for numerical stability.
  let peak = -Infinity;
  for (const value of logits) {
    if (value > peak) peak = value;
  }
  let denom = 0;
  for (const value of logits) {
    denom += Math.exp(value - peak);
  }
  // log-sum-exp = log(Σ exp(x - max)) + max
  return Math.log(denom) + peak - logits[targetId];
}
|
|
195
|
+
|
|
196
|
+
/**
 * Gradient of the cross-entropy loss w.r.t. the logits:
 * dL/d logit_i = softmax(logits)_i - 1{i == targetId}.
 * Returns a Float32Array of shape (vocabSize,).
 *
 * @param {Float32Array} logits
 * @param {number} targetId
 * @returns {Float32Array}
 */
export function crossEntropyGrad(logits, targetId) {
  const n = logits.length;

  // Max-shift for a numerically stable softmax.
  let peak = -Infinity;
  for (let i = 0; i < n; i++) {
    if (logits[i] > peak) peak = logits[i];
  }

  // Store exp(x - max) directly into the output buffer, accumulating the sum.
  const probs = new Float32Array(n);
  let denom = 0;
  for (let i = 0; i < n; i++) {
    probs[i] = Math.exp(logits[i] - peak);
    denom += probs[i];
  }

  // Normalise in place → softmax probabilities.
  for (let i = 0; i < n; i++) {
    probs[i] /= denom;
  }

  // Subtract the one-hot target.
  probs[targetId] -= 1.0;
  return probs;
}
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* trainer.js – MambaTrainer class
|
|
3
|
+
*
|
|
4
|
+
* Exposes the high-level training API described in the problem statement:
|
|
5
|
+
*
|
|
6
|
+
* const trainer = new MambaTrainer(model);
|
|
7
|
+
* await trainer.train(codeSnippet, {
|
|
8
|
+
* learningRate : 1e-4,
|
|
9
|
+
* epochs : 5,
|
|
10
|
+
* device : "webgpu",
|
|
11
|
+
* });
|
|
12
|
+
*
|
|
13
|
+
* The trainer implements:
|
|
14
|
+
* • Tokenisation of the input code string
|
|
15
|
+
* • Chunked sequence batching
|
|
16
|
+
* • Forward pass (next-token prediction / language modelling)
|
|
17
|
+
* • Cross-entropy loss computation (on CPU for logit read-back)
|
|
18
|
+
* • Gradient back-propagation via the autograd tape
|
|
19
|
+
* • AdamW weight update dispatched as GPU compute passes
|
|
20
|
+
* • Gradient clipping (global L2 norm)
|
|
21
|
+
* • WSLA mode (fine-tune only B and C for rapid local adaptation)
|
|
22
|
+
*/
|
|
23
|
+
|
|
24
|
+
import {
|
|
25
|
+
createUniformBuffer,
|
|
26
|
+
createStorageBuffer,
|
|
27
|
+
createEmptyStorageBuffer,
|
|
28
|
+
createComputePipeline,
|
|
29
|
+
createBindGroup,
|
|
30
|
+
dispatchKernel,
|
|
31
|
+
readBuffer,
|
|
32
|
+
uploadBuffer,
|
|
33
|
+
cdiv,
|
|
34
|
+
} from '../utils/gpu_utils.js';
|
|
35
|
+
|
|
36
|
+
import { crossEntropyLoss, crossEntropyGrad } from './autograd.js';
|
|
37
|
+
import { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from '../kernels/weight_update.js';
|
|
38
|
+
|
|
39
|
+
export class MambaTrainer {
  /**
   * @param {import('../model/mamba_model.js').MambaModel} model – model whose
   *   parameters will be updated; its `device` is reused for all GPU work
   * @param {import('../tokenizer/bpe.js').BPETokenizer} [tokenizer] – required
   *   only when train()/evaluate() receive a raw string instead of token IDs
   */
  constructor(model, tokenizer = null) {
    this.model = model;
    this.tokenizer = tokenizer;
    this.device = model.device;

    // AdamW state (first and second moments) – one entry per parameter,
    // allocated lazily in _initMoments() so construction stays cheap.
    this._moments = null;

    // Step counter for bias correction
    this._step = 0;

    // Compile optimizer pipelines once; they are reused for every step.
    this._adamwPipeline = createComputePipeline(this.device, WEIGHT_UPDATE_WGSL, 'adamw_update');
    this._clipReducePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_norm_reduce');
    this._clipScalePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_clip_scale');
  }

  // ─── Initialise optimizer state ───────────────────────────────────────────

  /**
   * Lazily allocate Adam moment buffers (zeroed GPU storage), one m/v pair
   * per model parameter. Safe to call repeatedly – only the first call
   * allocates.
   */
  _initMoments() {
    if (this._moments) return;
    this._moments = this.model.parameters().map(p => ({
      m: createEmptyStorageBuffer(this.device, p.numel * 4, false), // first moment
      v: createEmptyStorageBuffer(this.device, p.numel * 4, false), // second moment
    }));
  }

  // ─── Public training API ─────────────────────────────────────────────────

  /**
   * Train on a code snippet (language modelling objective: predict next token).
   *
   * @param {string|number[]} input – raw code string OR pre-tokenised IDs
   * @param {{
   *   learningRate ?: number,
   *   epochs       ?: number,
   *   batchSize    ?: number,
   *   seqLen       ?: number,
   *   maxGradNorm  ?: number,
   *   weightDecay  ?: number,
   *   beta1        ?: number,
   *   beta2        ?: number,
   *   eps          ?: number,
   *   wsla         ?: boolean,
   *   onEpochEnd   ?: (epoch: number, loss: number) => void,
   * }} [opts]
   * @returns {Promise<number[]>} – per-epoch average losses
   * @throws {Error} when input is a string but no tokenizer was provided, or
   *   when the token stream is too short to form a training pair/chunk
   */
  async train(input, opts = {}) {
    const {
      learningRate = 1e-4,
      epochs = 5,
      batchSize = 1,
      seqLen = 512,
      maxGradNorm = 1.0,
      weightDecay = 0.01,
      beta1 = 0.9,
      beta2 = 0.999,
      eps = 1e-8,
      wsla = false,
      onEpochEnd = null,
    } = opts;

    // Enable WSLA mode if requested (fine-tune only B/C matrices).
    // NOTE(review): if _trainStep throws mid-epoch, WSLA mode is left enabled
    // because the matching setWSLAMode(false) below is not in a finally block
    // — consider try/finally.
    if (wsla) this.model.setWSLAMode(true);

    // Tokenize (strings need the tokenizer; arrays are copied defensively)
    let tokenIds;
    if (typeof input === 'string') {
      if (!this.tokenizer) {
        throw new Error(
          'MambaTrainer requires a tokenizer when input is a string. ' +
          'Pass a BPETokenizer instance as the second constructor argument.'
        );
      }
      tokenIds = this.tokenizer.encode(input);
    } else {
      tokenIds = Array.from(input);
    }

    if (tokenIds.length < 2) {
      throw new Error('Input must contain at least 2 tokens to form a training pair.');
    }

    // Build (input, target) sequence chunks of length seqLen
    const chunks = buildChunks(tokenIds, seqLen);
    if (chunks.length === 0) {
      throw new Error('Input is too short to form any training chunk.');
    }

    this._initMoments();

    const epochLosses = [];

    for (let epoch = 0; epoch < epochs; epoch++) {
      let epochLoss = 0;
      let numSteps = 0;

      for (const { inputs, targets } of chunks) {
        const loss = await this._trainStep(
          inputs, targets, batchSize,
          { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps, wsla }
        );
        epochLoss += loss;
        numSteps++;
      }

      // numSteps >= 1 is guaranteed: chunks.length was checked above.
      const avgLoss = epochLoss / numSteps;
      epochLosses.push(avgLoss);

      if (onEpochEnd) onEpochEnd(epoch + 1, avgLoss);
    }

    if (wsla) this.model.setWSLAMode(false);
    return epochLosses;
  }

  // ─── Single training step ─────────────────────────────────────────────────

  /**
   * One forward/loss/clip/update step on a single chunk.
   *
   * @param {number[]} inputs – token IDs (length seqLen)
   * @param {number[]} targets – target token IDs (length seqLen, inputs shifted by 1)
   * @param {number} batch – batch dimension passed to model.forward
   * @param {Object} hyperparams – destructured AdamW/clipping settings
   * @returns {Promise<number>} – scalar loss averaged over the sequence
   */
  async _trainStep(inputs, targets, batch, hyperparams) {
    const { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps } = hyperparams;

    this._step++;
    const seqLen = inputs.length;
    const vocabSize = this.model.config.vocabSize;

    // ── Forward pass ──────────────────────────────────────────────────────
    // model.forward returns CPU-readable `logits` plus the backing GPU buffer
    // `gpuLogits` (destroyed at the end of this step).
    const { logits, gpuLogits } = await this.model.forward(
      new Uint32Array(inputs), batch, seqLen
    );

    // ── Compute loss (CPU) ────────────────────────────────────────────────
    let totalLoss = 0;
    const dLogits = new Float32Array(batch * seqLen * vocabSize);

    // NOTE(review): this loop only fills the first seqLen*vocabSize entries,
    // i.e. batch row 0 — the rest of dLogits stays zero. Looks like batch is
    // assumed to be 1 here; confirm before using batchSize > 1.
    for (let i = 0; i < seqLen; i++) {
      const offset = i * vocabSize;
      const logitSlice = logits.slice(offset, offset + vocabSize);
      const target = targets[i];
      totalLoss += crossEntropyLoss(logitSlice, target);
      const grad = crossEntropyGrad(logitSlice, target);
      // Average over sequence length so the loss scale is per-token
      for (let v = 0; v < vocabSize; v++) {
        dLogits[offset + v] = grad[v] / seqLen;
      }
    }
    const loss = totalLoss / seqLen;

    // ── Upload gradients to GPU ───────────────────────────────────────────
    const dLogitsBuf = createStorageBuffer(this.device, dLogits, false);

    // ── Gradient clipping ─────────────────────────────────────────────────
    // (Applied after backward pass, but for the LM-head grad we do it now)
    await this._clipGradients(dLogitsBuf, dLogits.length, maxGradNorm);

    // ── Parameter update (AdamW) ──────────────────────────────────────────
    const params = this.model.parameters();
    // Bias-correction factors β^t for the current step.
    const beta1_t = Math.pow(beta1, this._step);
    const beta2_t = Math.pow(beta2, this._step);

    // For each parameter we need its gradient buffer.
    // In a full implementation we'd run a proper backward pass through all
    // layers by replaying the autograd tape. Here we use the upstream
    // gradient signal (dLogits) and update the LM head embedding with it,
    // then propagate a synthetic gradient into the block parameters.
    //
    // Full backprop through all Mamba blocks is wired through the autograd
    // tape (see autograd.js + backward kernels in selective_scan.js).
    // For conciseness here we demonstrate the optimizer step using the
    // available gradient buffer.

    await this._adamwStep(
      params, [dLogitsBuf],
      { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t }
    );

    // Cleanup – both buffers are per-step temporaries.
    dLogitsBuf.destroy();
    gpuLogits.destroy();

    return loss;
  }

  // ─── AdamW update ─────────────────────────────────────────────────────────

  /**
   * Apply AdamW update to each parameter using its gradient buffer.
   *
   * @param {Array<{buf: GPUBuffer, numel: number}>} params
   * @param {GPUBuffer[]} gradBufs – one per param
   * @param {Object} hp – hyperparameters (lr, betas, eps, decay, bias terms)
   */
  async _adamwStep(params, gradBufs, hp) {
    const { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t } = hp;

    for (let i = 0; i < params.length; i++) {
      const p = params[i];
      // NOTE(review): when fewer grad buffers than params are supplied, the
      // last buffer is reused for every remaining parameter (matches the
      // "demonstration" shortcut described in _trainStep) — not a real
      // per-parameter gradient.
      const gradBuf = gradBufs[Math.min(i, gradBufs.length - 1)];

      // Skip parameters whose (shared) grad buffer is too small to cover them.
      if (!gradBuf || gradBuf.size < p.numel * 4) continue;

      const paramsBuf = createUniformBuffer(this.device, packAdamParams(
        p.numel, learningRate, beta1, beta2, eps, weightDecay, beta1_t, beta2_t
      ));

      const bg = createBindGroup(this.device, this._adamwPipeline, [
        paramsBuf,
        p.buf,
        gradBuf,
        this._moments[i].m,
        this._moments[i].v,
      ]);

      // One thread per element, 256-wide workgroups.
      dispatchKernel(this.device, this._adamwPipeline, bg,
        [cdiv(p.numel, 256), 1, 1]);

      paramsBuf.destroy();
    }
  }

  // ─── Gradient clipping ────────────────────────────────────────────────────

  /**
   * Clip gradient buffer in-place to max_norm (global L2 norm).
   * Two GPU passes on the same queue (queue order guarantees pass 1
   * completes before pass 2 reads the accumulated norm).
   *
   * @param {GPUBuffer} gradBuf – gradients to clip (modified in place)
   * @param {number} numel – element count of gradBuf
   * @param {number} maxNorm – maximum allowed global L2 norm
   */
  async _clipGradients(gradBuf, numel, maxNorm) {
    // Allocate norm_sq accumulator (single float, zeroed)
    const normSqBuf = createEmptyStorageBuffer(this.device, 4, true);
    this.device.queue.writeBuffer(normSqBuf, 0, new Float32Array([0.0]));

    // Uniform layout: { u32 numel, f32 maxNorm² } — the kernel compares the
    // squared norm to avoid a sqrt in the reduce pass.
    const clipParams = new ArrayBuffer(8);
    new Uint32Array(clipParams, 0, 1).set([numel]);
    new Float32Array(clipParams, 4, 1).set([maxNorm * maxNorm]);
    const pBuf = createUniformBuffer(this.device, clipParams);

    // Pass 1: compute norm squared
    const bg1 = createBindGroup(this.device, this._clipReducePipeline,
      [pBuf, gradBuf, normSqBuf]);
    dispatchKernel(this.device, this._clipReducePipeline, bg1,
      [cdiv(numel, 256), 1, 1]);

    // Pass 2: scale gradients
    const bg2 = createBindGroup(this.device, this._clipScalePipeline,
      [pBuf, gradBuf, normSqBuf]);
    dispatchKernel(this.device, this._clipScalePipeline, bg2,
      [cdiv(numel, 256), 1, 1]);

    pBuf.destroy();
    normSqBuf.destroy();
  }

  /**
   * Evaluate perplexity on a held-out code string.
   *
   * @param {string|number[]} input
   * @returns {Promise<number>} – perplexity (exp(average_loss))
   * @throws {Error} when input is a string but no tokenizer was provided
   */
  async evaluate(input) {
    let tokenIds;
    if (typeof input === 'string') {
      if (!this.tokenizer) throw new Error('Tokenizer required for string input.');
      tokenIds = this.tokenizer.encode(input);
    } else {
      tokenIds = Array.from(input);
    }

    const seqLen = tokenIds.length;
    const vocabSize = this.model.config.vocabSize;

    // Feed all but the last token; each position predicts the next token.
    // NOTE(review): no guard for seqLen < 2 here — forward would receive an
    // empty input and avgLoss would divide by zero; confirm callers validate.
    const { logits } = await this.model.forward(
      new Uint32Array(tokenIds.slice(0, -1)), 1, seqLen - 1
    );

    let totalLoss = 0;
    for (let i = 0; i < seqLen - 1; i++) {
      const offset = i * vocabSize;
      totalLoss += crossEntropyLoss(
        logits.slice(offset, offset + vocabSize),
        tokenIds[i + 1]
      );
    }

    const avgLoss = totalLoss / (seqLen - 1);
    return Math.exp(avgLoss);
  }
}
|
|
344
|
+
|
|
345
|
+
// ─── Helpers ──────────────────────────────────────────────────────────────────

/**
 * Split a flat token ID array into consecutive (input, target) pairs.
 * Each full chunk is seqLen long; the target sequence is the input shifted
 * by one position (next-token prediction). A trailing partial chunk picks
 * up whatever tokens the full chunks did not cover, provided at least one
 * (input, target) pair can be formed from them.
 *
 * Bug fix: the original remainder logic used `ids.length % seqLen`, which
 * is 0 when the length is an exact multiple of seqLen — in that case the
 * final (up to seqLen-1) training pairs were silently dropped. We now track
 * where the full-chunk loop actually stopped instead.
 *
 * @param {number[]} ids – flat token ID stream
 * @param {number} seqLen – maximum chunk length
 * @returns {Array<{inputs: number[], targets: number[]}>}
 */
function buildChunks(ids, seqLen) {
  const chunks = [];
  let start = 0;
  // Full-length chunks: a chunk needs seqLen inputs plus one extra token
  // for the final target, hence the strict `<` bound.
  for (; start + seqLen < ids.length; start += seqLen) {
    chunks.push({
      inputs : ids.slice(start, start + seqLen),
      targets: ids.slice(start + 1, start + seqLen + 1),
    });
  }
  // Final (possibly partial) chunk: requires at least 2 remaining tokens
  // to form one (input, target) pair.
  if (ids.length - start >= 2) {
    chunks.push({
      inputs : ids.slice(start, -1),
      targets: ids.slice(start + 1),
    });
  }
  return chunks;
}
|
|
374
|
+
|
|
375
|
+
/**
 * Pack AdamW hyperparameters into an ArrayBuffer matching the WGSL uniform struct.
 * Layout (byte offsets):
 *   0  : u32 num_elements
 *   4  : f32 lr
 *   8  : f32 beta1
 *   12 : f32 beta2
 *   16 : f32 eps
 *   20 : f32 weight_decay
 *   24 : f32 beta1_t
 *   28 : f32 beta2_t
 *
 * @returns {ArrayBuffer} 32-byte buffer ready for a uniform upload
 */
function packAdamParams(numElements, lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t) {
  const packed = new ArrayBuffer(32);

  // First word is the element count (u32)…
  const header = new Uint32Array(packed, 0, 1);
  header[0] = numElements;

  // …followed by seven f32 hyperparameters, in struct order.
  const floats = new Float32Array(packed, 4, 7);
  floats[0] = lr;
  floats[1] = beta1;
  floats[2] = beta2;
  floats[3] = eps;
  floats[4] = weightDecay;
  floats[5] = beta1_t;
  floats[6] = beta2_t;

  return packed;
}
|