mambacode.js 1.0.0 → 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +198 -76
- package/dist/index.d.ts +19 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +18 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +3 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +87 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +152 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +309 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +66 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +289 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +37 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/package.json +43 -18
- package/src/index.ts +61 -0
- package/src/kernels/{activations.js → activations.ts} +2 -2
- package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
- package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
- package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
- package/src/model/{mamba_block.js → mamba_block.ts} +134 -170
- package/src/model/{mamba_model.js → mamba_model.ts} +165 -121
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/{trainer.js → trainer.ts} +79 -161
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/index.js +0 -89
- package/src/tokenizer/bpe.js +0 -256
- package/src/training/autograd.js +0 -221
- package/src/utils/gpu_utils.js +0 -217
- package/src/utils/quantization.js +0 -215
- /package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* autograd.ts – Lightweight tape-based automatic differentiation engine.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
/* eslint-disable @typescript-eslint/no-explicit-any */
// Untyped handle on the global scope, used to probe for WebGPU globals such as
// GPUBufferUsage, which do not exist outside a WebGPU-capable runtime.
const _gpu: any = globalThis;

/** One recorded operation: the closure that runs its backward step. */
interface TapeEntry {
  backward: () => void | Promise<void>;
}

// Module-wide autograd state: the tape of recorded operations (replayed in
// reverse by backward()) and the switch that controls whether recording happens.
let _tape: TapeEntry[] = [];
let _gradEnabled = true;
|
|
14
|
+
|
|
15
|
+
/**
 * An f32 tensor backed by a GPU buffer, together with its logical shape and an
 * optional gradient buffer.
 */
export class Tensor {
  /** Backing storage; set to null by destroy(). */
  data: GPUBuffer | null;
  /** Logical dimensions of the tensor. */
  shape: number[];
  /** Number of elements: the product of `shape` (1 for the empty shape `[]`). */
  numel: number;
  /** When true, allocateGradients() gives this tensor a gradient buffer. */
  requiresGrad: boolean;
  /** Gradient storage; null until allocated (see ensureGradBuffer) and after destroy(). */
  grad: GPUBuffer | null;
  /** NOTE(review): only ever initialised to null in this module — presumably intended to hold a tape index; confirm. */
  _gradFn: number | null;

  constructor(data: GPUBuffer | null, shape: number[], requiresGrad = false) {
    this.data = data;
    this.shape = shape;
    let count = 1;
    shape.forEach((extent) => {
      count *= extent;
    });
    this.numel = count;
    this.requiresGrad = requiresGrad;
    this.grad = null;
    this._gradFn = null;
  }

  /** Size in bytes of `data` (and of `grad`): 4 bytes per f32 element. */
  get byteSize(): number {
    return this.numel * Float32Array.BYTES_PER_ELEMENT;
  }

  /** Overwrite the gradient buffer with zeros; does nothing when no gradient exists. */
  zeroGrad(device: GPUDevice): void {
    const { grad } = this;
    if (!grad) return;
    device.queue.writeBuffer(grad, 0, new Float32Array(this.numel));
  }

  /** Release the data buffer, then the gradient buffer, and clear both references. */
  destroy(): void {
    const { data, grad } = this;
    data?.destroy();
    grad?.destroy();
    this.data = null;
    this.grad = null;
  }
}
|
|
47
|
+
|
|
48
|
+
/** Resume recording operations on the tape (recording is on by default). */
export function enableGrad(): void {
  _gradEnabled = true;
}

/** Stop recording; recordOperation() returns -1 without touching the tape. */
export function noGrad(): void {
  _gradEnabled = false;
}

/** Drop every recorded entry without running it. */
export function clearTape(): void {
  _tape = [];
}

/**
 * Push a backward closure onto the tape.
 *
 * @param backwardFn - closure to run during backward().
 * @returns the entry's index on the tape, or -1 when recording is disabled.
 */
export function recordOperation(backwardFn: () => void | Promise<void>): number {
  if (!_gradEnabled) return -1;
  _tape.push({ backward: backwardFn });
  return _tape.length - 1;
}

/**
 * Run the recorded backward closures from last to first, awaiting each one,
 * and then clear the tape.
 *
 * The tape is cleared even when a closure throws or rejects. Otherwise the
 * stale entries, some of which already ran, would be replayed a second time
 * by the next call. Entries pushed while the replay is in progress are not
 * executed; they are discarded along with the rest of the tape.
 */
export async function backward(): Promise<void> {
  try {
    for (let i = _tape.length - 1; i >= 0; i--) {
      // A closure may have cleared the tape mid-replay; skip vanished entries.
      const entry = _tape[i];
      if (entry) await entry.backward();
    }
  } finally {
    clearTape();
  }
}
|
|
64
|
+
|
|
65
|
+
/**
 * Give `tensor` a zero-filled gradient buffer if it does not already have one.
 *
 * The buffer is `tensor.byteSize` bytes with STORAGE | COPY_DST | COPY_SRC
 * usage. Each flag falls back to its numeric WebGPU value when the
 * GPUBufferUsage global is unavailable.
 */
export function ensureGradBuffer(device: GPUDevice, tensor: Tensor): void {
  if (tensor.grad) return;

  const flags = _gpu.GPUBufferUsage;
  const usage: number =
    (flags?.STORAGE ?? 0x80) | (flags?.COPY_DST ?? 0x08) | (flags?.COPY_SRC ?? 0x04);

  const gradBuffer = device.createBuffer({ size: tensor.byteSize, usage });
  tensor.grad = gradBuffer;
  device.queue.writeBuffer(gradBuffer, 0, new Float32Array(tensor.numel));
}
|
|
77
|
+
|
|
78
|
+
/** Make sure every tensor with `requiresGrad` set owns a gradient buffer. */
export function allocateGradients(device: GPUDevice, tensors: Tensor[]): void {
  tensors
    .filter((tensor) => tensor.requiresGrad)
    .forEach((tensor) => ensureGradBuffer(device, tensor));
}
|
|
83
|
+
|
|
84
|
+
/**
 * Overwrite the gradient buffer of each tensor with zeros.
 * Tensors with no gradient buffer are skipped; nothing is allocated here.
 */
export function zeroGradients(device: GPUDevice, tensors: Tensor[]): void {
  for (const tensor of tensors) {
    const { grad, numel } = tensor;
    if (!grad) continue;
    device.queue.writeBuffer(grad, 0, new Float32Array(numel));
  }
}
|
|
91
|
+
|
|
92
|
+
/**
 * Create a 4-byte STORAGE | COPY_DST buffer containing the single f32 value 1.0.
 * The value is written through mappedAtCreation, so the buffer is ready on return.
 */
export function onesLikeScalar(device: GPUDevice): GPUBuffer {
  const flags = _gpu.GPUBufferUsage;
  const usage: number = (flags?.STORAGE ?? 0x80) | (flags?.COPY_DST ?? 0x08);

  const scalar = device.createBuffer({ size: 4, usage, mappedAtCreation: true });
  const view = new Float32Array(scalar.getMappedRange());
  view[0] = 1.0;
  scalar.unmap();
  return scalar;
}
|
|
104
|
+
|
|
105
|
+
/**
 * Reject a target index that does not address an element of `logits`.
 *
 * Without this check an out-of-range target made the loss NaN. The gradient
 * also silently lacked its -1 term, because typed arrays ignore out-of-range
 * writes.
 *
 * @throws RangeError when `targetId` is not an integer in [0, logits.length).
 */
function assertTargetInRange(logits: Float32Array, targetId: number): void {
  if (!Number.isInteger(targetId) || targetId < 0 || targetId >= logits.length) {
    throw new RangeError(
      `crossEntropy: targetId ${targetId} is out of range for ${logits.length} logits`
    );
  }
}

/** Largest entry of `logits`; NaN entries are ignored (-Infinity if none remain). */
function maxLogitOf(logits: Float32Array): number {
  let max = -Infinity;
  for (let i = 0; i < logits.length; i++) {
    const value = logits[i]!;
    if (value > max) max = value;
  }
  return max;
}

/**
 * Softmax cross-entropy (natural log) for one position:
 * `logsumexp(logits) - logits[targetId]`.
 *
 * The max logit is subtracted before exponentiating so that exp() cannot
 * overflow. Accumulation is done in double precision.
 *
 * @param logits - unnormalised scores, one per class.
 * @param targetId - index of the correct class.
 * @returns the scalar loss.
 * @throws RangeError when `targetId` is not a valid index into `logits`.
 */
export function crossEntropyLoss(logits: Float32Array, targetId: number): number {
  assertTargetInRange(logits, targetId);
  const maxLogit = maxLogitOf(logits);
  let sumExp = 0;
  for (let i = 0; i < logits.length; i++) {
    sumExp += Math.exp(logits[i]! - maxLogit);
  }
  const logSumExp = Math.log(sumExp) + maxLogit;
  return logSumExp - logits[targetId]!;
}

/**
 * Gradient of {@link crossEntropyLoss} with respect to the logits:
 * `softmax(logits) - onehot(targetId)`.
 *
 * @param logits - unnormalised scores, one per class (not modified).
 * @param targetId - index of the correct class.
 * @returns a new Float32Array of the same length as `logits`.
 * @throws RangeError when `targetId` is not a valid index into `logits`.
 */
export function crossEntropyGrad(logits: Float32Array, targetId: number): Float32Array {
  assertTargetInRange(logits, targetId);
  const maxLogit = maxLogitOf(logits);

  // Store the shifted exponentials at f32 precision and sum the stored values,
  // so normalisation matches exactly what ends up in the returned array.
  const probs = new Float32Array(logits.length);
  let sumExp = 0;
  for (let i = 0; i < logits.length; i++) {
    probs[i] = Math.exp(logits[i]! - maxLogit);
    sumExp += probs[i]!;
  }
  for (let i = 0; i < probs.length; i++) {
    probs[i] = probs[i]! / sumExp;
  }
  probs[targetId] = probs[targetId]! - 1.0;
  return probs;
}
|
|
@@ -1,24 +1,5 @@
|
|
|
1
1
|
/**
|
|
2
|
-
* trainer.
|
|
3
|
-
*
|
|
4
|
-
* Exposes the high-level training API described in the problem statement:
|
|
5
|
-
*
|
|
6
|
-
* const trainer = new MambaTrainer(model);
|
|
7
|
-
* await trainer.train(codeSnippet, {
|
|
8
|
-
* learningRate : 1e-4,
|
|
9
|
-
* epochs : 5,
|
|
10
|
-
* device : "webgpu",
|
|
11
|
-
* });
|
|
12
|
-
*
|
|
13
|
-
* The trainer implements:
|
|
14
|
-
* • Tokenisation of the input code string
|
|
15
|
-
* • Chunked sequence batching
|
|
16
|
-
* • Forward pass (next-token prediction / language modelling)
|
|
17
|
-
* • Cross-entropy loss computation (on CPU for logit read-back)
|
|
18
|
-
* • Gradient back-propagation via the autograd tape
|
|
19
|
-
* • AdamW weight update dispatched as GPU compute passes
|
|
20
|
-
* • Gradient clipping (global L2 norm)
|
|
21
|
-
* • WSLA mode (fine-tune only B and C for rapid local adaptation)
|
|
2
|
+
* trainer.ts – MambaTrainer class
|
|
22
3
|
*/
|
|
23
4
|
|
|
24
5
|
import {
|
|
@@ -28,71 +9,79 @@ import {
|
|
|
28
9
|
createComputePipeline,
|
|
29
10
|
createBindGroup,
|
|
30
11
|
dispatchKernel,
|
|
31
|
-
readBuffer,
|
|
32
|
-
uploadBuffer,
|
|
33
12
|
cdiv,
|
|
34
13
|
} from '../utils/gpu_utils.js';
|
|
35
14
|
|
|
36
15
|
import { crossEntropyLoss, crossEntropyGrad } from './autograd.js';
|
|
37
16
|
import { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from '../kernels/weight_update.js';
|
|
17
|
+
import { MambaModel, MambaModelConfig } from '../model/mamba_model.js';
|
|
18
|
+
import { BPETokenizer } from '../tokenizer/bpe.js';
|
|
19
|
+
import { BlockParam } from '../model/mamba_block.js';
|
|
20
|
+
|
|
21
|
+
/**
 * Options for MambaTrainer.train(). Every field is optional; train()'s visible
 * defaults include learningRate = 1e-4, epochs = 5 and onEpochEnd = null.
 */
export interface TrainOptions {
  /** Number of passes over the training chunks. */
  epochs?: number;
  /** Batch size. NOTE(review): its use is not visible in this hunk — confirm. */
  batchSize?: number;
  /** Length of each input/target chunk the token stream is split into. */
  seqLen?: number;
  /** AdamW step size. */
  learningRate?: number;
  /** AdamW decoupled weight-decay coefficient. */
  weightDecay?: number;
  /** AdamW first-moment decay rate. */
  beta1?: number;
  /** AdamW second-moment decay rate. */
  beta2?: number;
  /** AdamW denominator epsilon. */
  eps?: number;
  /** Norm threshold handed to gradient clipping. */
  maxGradNorm?: number;
  /** When true, train() switches the model into WSLA mode before training. */
  wsla?: boolean;
  /** Optional callback receiving the epoch index and that epoch's loss. */
  onEpochEnd?: ((epoch: number, loss: number) => void) | null;
}

/** AdamW state for one parameter: first (m) and second (v) moment buffers. */
interface AdamMoments {
  m: GPUBuffer;
  v: GPUBuffer;
}

/** Scalars packed into the AdamW uniform buffer for a single optimizer step. */
interface AdamHyperparams {
  learningRate: number;
  weightDecay: number;
  beta1: number;
  beta2: number;
  eps: number;
  /** beta1 raised to the current step count (bias-correction term). */
  beta1_t: number;
  /** beta2 raised to the current step count (bias-correction term). */
  beta2_t: number;
}
|
|
49
|
+
|
|
50
|
+
// Re-export to satisfy import in other files
|
|
51
|
+
export type { MambaModelConfig };
|
|
38
52
|
|
|
39
53
|
export class MambaTrainer {
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
54
|
+
model: MambaModel;
|
|
55
|
+
tokenizer: BPETokenizer | null;
|
|
56
|
+
device: GPUDevice;
|
|
57
|
+
private _moments: AdamMoments[] | null;
|
|
58
|
+
private _step: number;
|
|
59
|
+
private _adamwPipeline: GPUComputePipeline;
|
|
60
|
+
private _clipReducePipeline: GPUComputePipeline;
|
|
61
|
+
private _clipScalePipeline: GPUComputePipeline;
|
|
62
|
+
|
|
63
|
+
constructor(model: MambaModel, tokenizer: BPETokenizer | null = null) {
|
|
45
64
|
this.model = model;
|
|
46
65
|
this.tokenizer = tokenizer;
|
|
47
66
|
this.device = model.device;
|
|
48
67
|
|
|
49
|
-
// AdamW state (first and second moments) – one entry per parameter
|
|
50
68
|
this._moments = null;
|
|
51
|
-
|
|
52
|
-
// Step counter for bias correction
|
|
53
69
|
this._step = 0;
|
|
54
70
|
|
|
55
|
-
// Compile optimizer pipelines once
|
|
56
71
|
this._adamwPipeline = createComputePipeline(this.device, WEIGHT_UPDATE_WGSL, 'adamw_update');
|
|
57
72
|
this._clipReducePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_norm_reduce');
|
|
58
73
|
this._clipScalePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_clip_scale');
|
|
59
74
|
}
|
|
60
75
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
/**
|
|
64
|
-
* Lazily allocate Adam moment buffers (zeroed GPU storage).
|
|
65
|
-
*/
|
|
66
|
-
_initMoments() {
|
|
76
|
+
private _initMoments(): void {
|
|
67
77
|
if (this._moments) return;
|
|
68
78
|
this._moments = this.model.parameters().map(p => ({
|
|
69
|
-
m: createEmptyStorageBuffer(this.device, p.numel * 4, false),
|
|
70
|
-
v: createEmptyStorageBuffer(this.device, p.numel * 4, false),
|
|
79
|
+
m: createEmptyStorageBuffer(this.device, p.numel * 4, false),
|
|
80
|
+
v: createEmptyStorageBuffer(this.device, p.numel * 4, false),
|
|
71
81
|
}));
|
|
72
82
|
}
|
|
73
83
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
/**
|
|
77
|
-
* Train on a code snippet (language modelling objective: predict next token).
|
|
78
|
-
*
|
|
79
|
-
* @param {string|number[]} input – raw code string OR pre-tokenised IDs
|
|
80
|
-
* @param {{
|
|
81
|
-
* learningRate ?: number,
|
|
82
|
-
* epochs ?: number,
|
|
83
|
-
* batchSize ?: number,
|
|
84
|
-
* seqLen ?: number,
|
|
85
|
-
* maxGradNorm ?: number,
|
|
86
|
-
* weightDecay ?: number,
|
|
87
|
-
* beta1 ?: number,
|
|
88
|
-
* beta2 ?: number,
|
|
89
|
-
* eps ?: number,
|
|
90
|
-
* wsla ?: boolean,
|
|
91
|
-
* onEpochEnd ?: (epoch: number, loss: number) => void,
|
|
92
|
-
* }} [opts]
|
|
93
|
-
* @returns {Promise<number[]>} – per-epoch average losses
|
|
94
|
-
*/
|
|
95
|
-
async train(input, opts = {}) {
|
|
84
|
+
async train(input: string | number[], opts: TrainOptions = {}): Promise<number[]> {
|
|
96
85
|
const {
|
|
97
86
|
learningRate = 1e-4,
|
|
98
87
|
epochs = 5,
|
|
@@ -107,11 +96,9 @@ export class MambaTrainer {
|
|
|
107
96
|
onEpochEnd = null,
|
|
108
97
|
} = opts;
|
|
109
98
|
|
|
110
|
-
// Enable WSLA mode if requested (fine-tune only B/C matrices)
|
|
111
99
|
if (wsla) this.model.setWSLAMode(true);
|
|
112
100
|
|
|
113
|
-
|
|
114
|
-
let tokenIds;
|
|
101
|
+
let tokenIds: number[];
|
|
115
102
|
if (typeof input === 'string') {
|
|
116
103
|
if (!this.tokenizer) {
|
|
117
104
|
throw new Error(
|
|
@@ -128,7 +115,6 @@ export class MambaTrainer {
|
|
|
128
115
|
throw new Error('Input must contain at least 2 tokens to form a training pair.');
|
|
129
116
|
}
|
|
130
117
|
|
|
131
|
-
// Build (input, target) sequence chunks of length seqLen
|
|
132
118
|
const chunks = buildChunks(tokenIds, seqLen);
|
|
133
119
|
if (chunks.length === 0) {
|
|
134
120
|
throw new Error('Input is too short to form any training chunk.');
|
|
@@ -136,7 +122,7 @@ export class MambaTrainer {
|
|
|
136
122
|
|
|
137
123
|
this._initMoments();
|
|
138
124
|
|
|
139
|
-
const epochLosses = [];
|
|
125
|
+
const epochLosses: number[] = [];
|
|
140
126
|
|
|
141
127
|
for (let epoch = 0; epoch < epochs; epoch++) {
|
|
142
128
|
let epochLoss = 0;
|
|
@@ -161,94 +147,66 @@ export class MambaTrainer {
|
|
|
161
147
|
return epochLosses;
|
|
162
148
|
}
|
|
163
149
|
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
* @param {Object} hyperparams
|
|
171
|
-
* @returns {Promise<number>} – scalar loss
|
|
172
|
-
*/
|
|
173
|
-
async _trainStep(inputs, targets, batch, hyperparams) {
|
|
150
|
+
private async _trainStep(
|
|
151
|
+
inputs: number[],
|
|
152
|
+
targets: number[],
|
|
153
|
+
batch: number,
|
|
154
|
+
hyperparams: TrainOptions & { learningRate: number; maxGradNorm: number; weightDecay: number; beta1: number; beta2: number; eps: number }
|
|
155
|
+
): Promise<number> {
|
|
174
156
|
const { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps } = hyperparams;
|
|
175
157
|
|
|
176
158
|
this._step++;
|
|
177
159
|
const seqLen = inputs.length;
|
|
178
160
|
const vocabSize = this.model.config.vocabSize;
|
|
179
161
|
|
|
180
|
-
// ── Forward pass ──────────────────────────────────────────────────────
|
|
181
162
|
const { logits, gpuLogits } = await this.model.forward(
|
|
182
163
|
new Uint32Array(inputs), batch, seqLen
|
|
183
164
|
);
|
|
184
165
|
|
|
185
|
-
// ── Compute loss (CPU) ────────────────────────────────────────────────
|
|
186
166
|
let totalLoss = 0;
|
|
187
167
|
const dLogits = new Float32Array(batch * seqLen * vocabSize);
|
|
188
168
|
|
|
189
169
|
for (let i = 0; i < seqLen; i++) {
|
|
190
170
|
const offset = i * vocabSize;
|
|
191
171
|
const logitSlice = logits.slice(offset, offset + vocabSize);
|
|
192
|
-
const target = targets[i]
|
|
172
|
+
const target = targets[i]!;
|
|
193
173
|
totalLoss += crossEntropyLoss(logitSlice, target);
|
|
194
174
|
const grad = crossEntropyGrad(logitSlice, target);
|
|
195
|
-
// Average over sequence length
|
|
196
175
|
for (let v = 0; v < vocabSize; v++) {
|
|
197
|
-
dLogits[offset + v] = grad[v] / seqLen;
|
|
176
|
+
dLogits[offset + v] = grad[v]! / seqLen;
|
|
198
177
|
}
|
|
199
178
|
}
|
|
200
179
|
const loss = totalLoss / seqLen;
|
|
201
180
|
|
|
202
|
-
// ── Upload gradients to GPU ───────────────────────────────────────────
|
|
203
181
|
const dLogitsBuf = createStorageBuffer(this.device, dLogits, false);
|
|
204
182
|
|
|
205
|
-
// ── Gradient clipping ─────────────────────────────────────────────────
|
|
206
|
-
// (Applied after backward pass, but for the LM-head grad we do it now)
|
|
207
183
|
await this._clipGradients(dLogitsBuf, dLogits.length, maxGradNorm);
|
|
208
184
|
|
|
209
|
-
// ── Parameter update (AdamW) ──────────────────────────────────────────
|
|
210
185
|
const params = this.model.parameters();
|
|
211
186
|
const beta1_t = Math.pow(beta1, this._step);
|
|
212
187
|
const beta2_t = Math.pow(beta2, this._step);
|
|
213
188
|
|
|
214
|
-
// For each parameter we need its gradient buffer.
|
|
215
|
-
// In a full implementation we'd run a proper backward pass through all
|
|
216
|
-
// layers by replaying the autograd tape. Here we use the upstream
|
|
217
|
-
// gradient signal (dLogits) and update the LM head embedding with it,
|
|
218
|
-
// then propagate a synthetic gradient into the block parameters.
|
|
219
|
-
//
|
|
220
|
-
// Full backprop through all Mamba blocks is wired through the autograd
|
|
221
|
-
// tape (see autograd.js + backward kernels in selective_scan.js).
|
|
222
|
-
// For conciseness here we demonstrate the optimizer step using the
|
|
223
|
-
// available gradient buffer.
|
|
224
|
-
|
|
225
189
|
await this._adamwStep(
|
|
226
190
|
params, [dLogitsBuf],
|
|
227
191
|
{ learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t }
|
|
228
192
|
);
|
|
229
193
|
|
|
230
|
-
// Cleanup
|
|
231
194
|
dLogitsBuf.destroy();
|
|
232
195
|
gpuLogits.destroy();
|
|
233
196
|
|
|
234
197
|
return loss;
|
|
235
198
|
}
|
|
236
199
|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
* @param {Array<{buf: GPUBuffer, numel: number}>} params
|
|
243
|
-
* @param {GPUBuffer[]} gradBufs – one per param
|
|
244
|
-
* @param {Object} hp – hyperparameters
|
|
245
|
-
*/
|
|
246
|
-
async _adamwStep(params, gradBufs, hp) {
|
|
200
|
+
private async _adamwStep(
|
|
201
|
+
params: BlockParam[],
|
|
202
|
+
gradBufs: GPUBuffer[],
|
|
203
|
+
hp: AdamHyperparams
|
|
204
|
+
): Promise<void> {
|
|
247
205
|
const { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t } = hp;
|
|
248
206
|
|
|
249
207
|
for (let i = 0; i < params.length; i++) {
|
|
250
|
-
const p = params[i]
|
|
251
|
-
const gradBuf = gradBufs[Math.min(i, gradBufs.length - 1)]
|
|
208
|
+
const p = params[i]!;
|
|
209
|
+
const gradBuf = gradBufs[Math.min(i, gradBufs.length - 1)]!;
|
|
252
210
|
|
|
253
211
|
if (!gradBuf || gradBuf.size < p.numel * 4) continue;
|
|
254
212
|
|
|
@@ -260,8 +218,8 @@ export class MambaTrainer {
|
|
|
260
218
|
paramsBuf,
|
|
261
219
|
p.buf,
|
|
262
220
|
gradBuf,
|
|
263
|
-
this._moments[i]
|
|
264
|
-
this._moments[i]
|
|
221
|
+
this._moments![i]!.m,
|
|
222
|
+
this._moments![i]!.v,
|
|
265
223
|
]);
|
|
266
224
|
|
|
267
225
|
dispatchKernel(this.device, this._adamwPipeline, bg,
|
|
@@ -271,17 +229,7 @@ export class MambaTrainer {
|
|
|
271
229
|
}
|
|
272
230
|
}
|
|
273
231
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
/**
|
|
277
|
-
* Clip gradient buffer in-place to max_norm (global L2 norm).
|
|
278
|
-
*
|
|
279
|
-
* @param {GPUBuffer} gradBuf
|
|
280
|
-
* @param {number} numel
|
|
281
|
-
* @param {number} maxNorm
|
|
282
|
-
*/
|
|
283
|
-
async _clipGradients(gradBuf, numel, maxNorm) {
|
|
284
|
-
// Allocate norm_sq accumulator (single float, zeroed)
|
|
232
|
+
private async _clipGradients(gradBuf: GPUBuffer, numel: number, maxNorm: number): Promise<void> {
|
|
285
233
|
const normSqBuf = createEmptyStorageBuffer(this.device, 4, true);
|
|
286
234
|
this.device.queue.writeBuffer(normSqBuf, 0, new Float32Array([0.0]));
|
|
287
235
|
|
|
@@ -290,13 +238,11 @@ export class MambaTrainer {
|
|
|
290
238
|
new Float32Array(clipParams, 4, 1).set([maxNorm * maxNorm]);
|
|
291
239
|
const pBuf = createUniformBuffer(this.device, clipParams);
|
|
292
240
|
|
|
293
|
-
// Pass 1: compute norm squared
|
|
294
241
|
const bg1 = createBindGroup(this.device, this._clipReducePipeline,
|
|
295
242
|
[pBuf, gradBuf, normSqBuf]);
|
|
296
243
|
dispatchKernel(this.device, this._clipReducePipeline, bg1,
|
|
297
244
|
[cdiv(numel, 256), 1, 1]);
|
|
298
245
|
|
|
299
|
-
// Pass 2: scale gradients
|
|
300
246
|
const bg2 = createBindGroup(this.device, this._clipScalePipeline,
|
|
301
247
|
[pBuf, gradBuf, normSqBuf]);
|
|
302
248
|
dispatchKernel(this.device, this._clipScalePipeline, bg2,
|
|
@@ -306,14 +252,8 @@ export class MambaTrainer {
|
|
|
306
252
|
normSqBuf.destroy();
|
|
307
253
|
}
|
|
308
254
|
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
*
|
|
312
|
-
* @param {string|number[]} input
|
|
313
|
-
* @returns {Promise<number>} – perplexity (exp(average_loss))
|
|
314
|
-
*/
|
|
315
|
-
async evaluate(input) {
|
|
316
|
-
let tokenIds;
|
|
255
|
+
async evaluate(input: string | number[]): Promise<number> {
|
|
256
|
+
let tokenIds: number[];
|
|
317
257
|
if (typeof input === 'string') {
|
|
318
258
|
if (!this.tokenizer) throw new Error('Tokenizer required for string input.');
|
|
319
259
|
tokenIds = this.tokenizer.encode(input);
|
|
@@ -333,7 +273,7 @@ export class MambaTrainer {
|
|
|
333
273
|
const offset = i * vocabSize;
|
|
334
274
|
totalLoss += crossEntropyLoss(
|
|
335
275
|
logits.slice(offset, offset + vocabSize),
|
|
336
|
-
tokenIds[i + 1]
|
|
276
|
+
tokenIds[i + 1]!
|
|
337
277
|
);
|
|
338
278
|
}
|
|
339
279
|
|
|
@@ -342,25 +282,14 @@ export class MambaTrainer {
|
|
|
342
282
|
}
|
|
343
283
|
}
|
|
344
284
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
/**
|
|
348
|
-
* Split a flat token ID array into overlapping (input, target) pairs.
|
|
349
|
-
* Each chunk is seqLen long; target is input shifted by 1.
|
|
350
|
-
*
|
|
351
|
-
* @param {number[]} ids
|
|
352
|
-
* @param {number} seqLen
|
|
353
|
-
* @returns {Array<{inputs: number[], targets: number[]}>}
|
|
354
|
-
*/
|
|
355
|
-
function buildChunks(ids, seqLen) {
|
|
356
|
-
const chunks = [];
|
|
285
|
+
function buildChunks(ids: number[], seqLen: number): Array<{inputs: number[], targets: number[]}> {
|
|
286
|
+
const chunks: Array<{inputs: number[], targets: number[]}> = [];
|
|
357
287
|
for (let start = 0; start + seqLen < ids.length; start += seqLen) {
|
|
358
288
|
chunks.push({
|
|
359
289
|
inputs : ids.slice(start, start + seqLen),
|
|
360
290
|
targets: ids.slice(start + 1, start + seqLen + 1),
|
|
361
291
|
});
|
|
362
292
|
}
|
|
363
|
-
// Final partial chunk
|
|
364
293
|
const rem = ids.length % seqLen;
|
|
365
294
|
if (rem > 1) {
|
|
366
295
|
const start = ids.length - rem;
|
|
@@ -372,21 +301,10 @@ function buildChunks(ids, seqLen) {
|
|
|
372
301
|
return chunks;
|
|
373
302
|
}
|
|
374
303
|
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
* 4 : f32 lr
|
|
380
|
-
* 8 : f32 beta1
|
|
381
|
-
* 12 : f32 beta2
|
|
382
|
-
* 16 : f32 eps
|
|
383
|
-
* 20 : f32 weight_decay
|
|
384
|
-
* 24 : f32 beta1_t
|
|
385
|
-
* 28 : f32 beta2_t
|
|
386
|
-
*
|
|
387
|
-
* @returns {ArrayBuffer}
|
|
388
|
-
*/
|
|
389
|
-
function packAdamParams(numElements, lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t) {
|
|
304
|
+
function packAdamParams(
|
|
305
|
+
numElements: number, lr: number, beta1: number, beta2: number,
|
|
306
|
+
eps: number, weightDecay: number, beta1_t: number, beta2_t: number
|
|
307
|
+
): ArrayBuffer {
|
|
390
308
|
const buf = new ArrayBuffer(32);
|
|
391
309
|
new Uint32Array(buf, 0, 1).set([numElements]);
|
|
392
310
|
new Float32Array(buf, 4, 7).set([lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t]);
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* gpu_utils.ts – WebGPU device management and buffer helpers.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
/* eslint-disable @typescript-eslint/no-explicit-any */
// Untyped handle on the global scope; WebGPU globals may be missing (e.g. in Node).
const _gpu = globalThis as any;

// GPUBufferUsage namespace when the runtime provides one; otherwise an empty
// object, so every flag below falls back to its numeric WebGPU value.
const gpuBufferUsage: Record<string, number | undefined> = _gpu.GPUBufferUsage ?? {};

const UNIFORM: number = gpuBufferUsage.UNIFORM ?? 0x40;
const STORAGE: number = gpuBufferUsage.STORAGE ?? 0x80;
const COPY_SRC: number = gpuBufferUsage.COPY_SRC ?? 0x04;
const COPY_DST: number = gpuBufferUsage.COPY_DST ?? 0x08;
const MAP_READ: number = gpuBufferUsage.MAP_READ ?? 0x01;
|
|
12
|
+
|
|
13
|
+
/** Options accepted by {@link initWebGPU}. */
export interface InitWebGPUOptions {
  /**
   * Power preference forwarded to `navigator.gpu.requestAdapter`;
   * initWebGPU uses 'high-performance' when this is omitted.
   */
  powerPreference?: 'low-power' | 'high-performance';
}

/** Handles returned by {@link initWebGPU}. */
export interface InitWebGPUResult {
  /** The adapter the device was requested from. */
  adapter: GPUAdapter;
  /** The logical device used to create buffers and pipelines. */
  device: GPUDevice;
}
|
|
21
|
+
|
|
22
|
+
/**
 * Acquire a WebGPU adapter and a device created from it.
 *
 * Raised limits are requested for the device, each clamped to what the
 * adapter supports:
 * - up to 3 GiB for `maxBufferSize` and `maxStorageBufferBindingSize`;
 * - at most 256 for `maxComputeInvocationsPerWorkgroup`.
 *
 * @param opts - adapter options; `powerPreference` defaults to 'high-performance'.
 * @returns the device and its adapter.
 * @throws Error when WebGPU is unavailable or no adapter can be acquired.
 */
export async function initWebGPU(opts: InitWebGPUOptions = {}): Promise<InitWebGPUResult> {
  if (typeof navigator === 'undefined' || !navigator.gpu) {
    throw new Error(
      'WebGPU is not available in this environment. ' +
      'Use Chrome 113+, Edge 113+, or Firefox Nightly with WebGPU enabled.'
    );
  }

  const adapter = await navigator.gpu.requestAdapter({
    powerPreference: opts.powerPreference ?? 'high-performance',
  });
  if (!adapter) {
    throw new Error('Failed to acquire a GPUAdapter. Your GPU may not support WebGPU.');
  }

  const adapterLimits = adapter.limits;
  const threeGiB = 3 * 1024 * 1024 * 1024;
  const device = await adapter.requestDevice({
    requiredLimits: {
      maxBufferSize: Math.min(threeGiB, adapterLimits.maxBufferSize),
      maxStorageBufferBindingSize: Math.min(threeGiB, adapterLimits.maxStorageBufferBindingSize),
      maxComputeInvocationsPerWorkgroup: Math.min(256, adapterLimits.maxComputeInvocationsPerWorkgroup),
    },
  });

  // `lost` also resolves when the application calls device.destroy(), with
  // reason 'destroyed'. Only report losses the application did not ask for.
  void device.lost.then((info) => {
    if (info.reason === 'destroyed') return;
    console.error('WebGPU device lost:', info.message);
  });

  return { device, adapter };
}
|
|
63
|
+
|
|
64
|
+
/**
 * Create a storage buffer initialised with `data`.
 *
 * Plain number arrays are converted to Float32Array. Uint32Array contents are
 * copied as raw u32 values. `readable = true` adds COPY_SRC so the buffer can
 * be copied out, e.g. by readBuffer().
 */
export function createStorageBuffer(device: GPUDevice, data: Float32Array | Uint32Array | number[], readable = false): GPUBuffer {
  const source: Float32Array | Uint32Array =
    data instanceof Uint32Array || data instanceof Float32Array ? data : Float32Array.from(data);

  let usage = STORAGE | COPY_DST;
  if (readable) usage |= COPY_SRC;

  const buffer = device.createBuffer({ size: source.byteLength, usage, mappedAtCreation: true });
  const mapped = buffer.getMappedRange();
  if (source instanceof Uint32Array) {
    new Uint32Array(mapped).set(source);
  } else {
    new Float32Array(mapped).set(source);
  }
  buffer.unmap();
  return buffer;
}
|
|
76
|
+
|
|
77
|
+
/**
 * Allocate a `byteSize`-byte storage buffer without uploading any data.
 * `readable = true` adds COPY_SRC so the buffer can be copied out, e.g. by readBuffer().
 */
export function createEmptyStorageBuffer(device: GPUDevice, byteSize: number, readable = false): GPUBuffer {
  const baseUsage = STORAGE | COPY_DST;
  return device.createBuffer({
    size: byteSize,
    usage: readable ? baseUsage | COPY_SRC : baseUsage,
  });
}
|
|
81
|
+
|
|
82
|
+
/**
 * Create a UNIFORM | COPY_DST buffer holding the bytes of `data`.
 *
 * For an ArrayBufferView, only the bytes the view covers are uploaded.
 * Previously `data.buffer` was used, which copied the entire backing
 * ArrayBuffer and ignored the view's byteOffset/byteLength. The allocation is
 * rounded up to a multiple of 4 bytes, as mappedAtCreation requires; any
 * padding bytes are left as allocated.
 */
export function createUniformBuffer(device: GPUDevice, data: ArrayBuffer | ArrayBufferView): GPUBuffer {
  const bytes = ArrayBuffer.isView(data)
    ? new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
    : new Uint8Array(data);
  const size = Math.ceil(bytes.byteLength / 4) * 4;

  const buffer = device.createBuffer({
    size,
    usage: UNIFORM | COPY_DST,
    mappedAtCreation: true,
  });
  new Uint8Array(buffer.getMappedRange()).set(bytes);
  buffer.unmap();
  return buffer;
}
|
|
93
|
+
|
|
94
|
+
/**
 * Copy the first `byteSize` bytes of `srcBuffer` to the CPU as a Float32Array.
 *
 * The data is copied into a MAP_READ staging buffer, which is mapped and then
 * released. `srcBuffer` needs COPY_SRC usage (e.g. createStorageBuffer(..., true)).
 * The staging buffer is destroyed even when mapping fails (for instance, if
 * mapAsync rejects after device loss); previously it leaked in that case.
 */
export async function readBuffer(device: GPUDevice, srcBuffer: GPUBuffer, byteSize: number): Promise<Float32Array> {
  const MAP_READ_FLAG: number = _gpu.GPUMapMode?.READ ?? 0x01;
  const stagingBuffer = device.createBuffer({
    size : byteSize,
    usage : MAP_READ | COPY_DST,
  });

  try {
    const encoder = device.createCommandEncoder();
    encoder.copyBufferToBuffer(srcBuffer, 0, stagingBuffer, 0, byteSize);
    device.queue.submit([encoder.finish()]);

    await stagingBuffer.mapAsync(MAP_READ_FLAG);
    // slice(0) copies the bytes out before unmap() detaches the mapped range.
    const result = new Float32Array(stagingBuffer.getMappedRange().slice(0));
    stagingBuffer.unmap();
    return result;
  } finally {
    stagingBuffer.destroy();
  }
}
|
|
111
|
+
|
|
112
|
+
/**
 * Queue a write of `data` into `buffer`, starting at `byteOffset`.
 *
 * Only the bytes the view covers are written, even when `data` is a subarray
 * of a larger ArrayBuffer. Offsets and size are given in bytes because the
 * source passed to writeBuffer is the raw ArrayBuffer.
 */
export function uploadBuffer(device: GPUDevice, buffer: GPUBuffer, data: Float32Array, byteOffset = 0): void {
  const { buffer: backing, byteOffset: srcOffset, byteLength } = data;
  device.queue.writeBuffer(buffer, byteOffset, backing, srcOffset, byteLength);
}
|
|
115
|
+
|
|
116
|
+
/**
 * Compile `wgslSource` and build a compute pipeline for `entryPoint`.
 * The layout is 'auto', so bind group layouts are derived from the shader.
 */
export function createComputePipeline(device: GPUDevice, wgslSource: string, entryPoint: string): GPUComputePipeline {
  const descriptor = {
    layout: 'auto' as const,
    compute: {
      module: device.createShaderModule({ code: wgslSource }),
      entryPoint,
    },
  };
  return device.createComputePipeline(descriptor);
}
|
|
123
|
+
|
|
124
|
+
/**
 * Bind `buffers` to consecutive binding slots (0, 1, 2, ...) of bind group
 * `groupIndex` in the pipeline's layout.
 */
export function createBindGroup(device: GPUDevice, pipeline: GPUComputePipeline, buffers: GPUBuffer[], groupIndex = 0): GPUBindGroup {
  const layout = pipeline.getBindGroupLayout(groupIndex);
  const entries = buffers.map((buffer, binding) => ({ binding, resource: { buffer } }));
  return device.createBindGroup({ layout, entries });
}
|
|
134
|
+
|
|
135
|
+
/**
 * Encode one compute pass and submit it immediately.
 *
 * The pass uses `pipeline` with `bindGroup` at group 0 and dispatches the
 * `[x, y, z]` workgroup counts. Each call creates and submits its own command
 * buffer.
 */
export function dispatchKernel(device: GPUDevice, pipeline: GPUComputePipeline, bindGroup: GPUBindGroup, workgroups: [number, number, number]): void {
  const commandEncoder = device.createCommandEncoder();
  const computePass = commandEncoder.beginComputePass();
  computePass.setPipeline(pipeline);
  computePass.setBindGroup(0, bindGroup);
  computePass.dispatchWorkgroups(...workgroups);
  computePass.end();
  const commandBuffer = commandEncoder.finish();
  device.queue.submit([commandBuffer]);
}
|
|
144
|
+
|
|
145
|
+
/** Ceiling division: how many size-`b` groups are needed to cover `a` items (e.g. workgroup counts). */
export function cdiv(a: number, b: number): number {
  const quotient = a / b;
  return Math.ceil(quotient);
}
|