@seanhogg/builderforce-memory-engine 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +393 -0
- package/dist/index.d.ts +32 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +5 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +171 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/attention.d.ts +19 -0
- package/dist/kernels/attention.d.ts.map +1 -0
- package/dist/kernels/attention.js +263 -0
- package/dist/kernels/attention.js.map +1 -0
- package/dist/kernels/complex_ssd.d.ts +33 -0
- package/dist/kernels/complex_ssd.d.ts.map +1 -0
- package/dist/kernels/complex_ssd.js +305 -0
- package/dist/kernels/complex_ssd.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +158 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/ssd.d.ts +29 -0
- package/dist/kernels/ssd.d.ts.map +1 -0
- package/dist/kernels/ssd.js +276 -0
- package/dist/kernels/ssd.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/attention_block.d.ts +48 -0
- package/dist/model/attention_block.d.ts.map +1 -0
- package/dist/model/attention_block.js +262 -0
- package/dist/model/attention_block.js.map +1 -0
- package/dist/model/mamba1_block.d.ts +70 -0
- package/dist/model/mamba1_block.d.ts.map +1 -0
- package/dist/model/mamba1_block.js +333 -0
- package/dist/model/mamba1_block.js.map +1 -0
- package/dist/model/mamba2_block.d.ts +44 -0
- package/dist/model/mamba2_block.d.ts.map +1 -0
- package/dist/model/mamba2_block.js +252 -0
- package/dist/model/mamba2_block.js.map +1 -0
- package/dist/model/mamba3_block.d.ts +51 -0
- package/dist/model/mamba3_block.d.ts.map +1 -0
- package/dist/model/mamba3_block.js +270 -0
- package/dist/model/mamba3_block.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +303 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +140 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +527 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/model/sequence_layer.d.ts +25 -0
- package/dist/model/sequence_layer.d.ts.map +1 -0
- package/dist/model/sequence_layer.js +8 -0
- package/dist/model/sequence_layer.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +36 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/dist/utils/rng.d.ts +36 -0
- package/dist/utils/rng.d.ts.map +1 -0
- package/dist/utils/rng.js +61 -0
- package/dist/utils/rng.js.map +1 -0
- package/package.json +99 -0
- package/src/index.ts +114 -0
- package/src/kernels/activations.ts +174 -0
- package/src/kernels/attention.ts +268 -0
- package/src/kernels/complex_ssd.ts +307 -0
- package/src/kernels/conv1d.ts +159 -0
- package/src/kernels/linear_projection.ts +220 -0
- package/src/kernels/selective_scan.ts +350 -0
- package/src/kernels/ssd.ts +278 -0
- package/src/kernels/weight_update.ts +120 -0
- package/src/model/attention_block.ts +344 -0
- package/src/model/mamba1_block.ts +437 -0
- package/src/model/mamba2_block.ts +319 -0
- package/src/model/mamba3_block.ts +335 -0
- package/src/model/mamba_block.ts +401 -0
- package/src/model/mamba_model.ts +678 -0
- package/src/model/sequence_layer.ts +29 -0
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +309 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/utils/rng.ts +65 -0
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* sequence_layer.ts – Common interface implemented by all block types.
|
|
3
|
+
*
|
|
4
|
+
* Mamba1Block, Mamba2Block, Mamba3Block, and AttentionBlock all implement
|
|
5
|
+
* SequenceLayer so that HybridMambaModel can iterate layers generically.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
/** Tag identifying which block implementation a layer is. */
export type LayerType = 'mamba1' | 'mamba2' | 'mamba3' | 'attention';

/** Value returned by `SequenceLayer.forward`. */
export interface LayerForwardResult {
  /** GPU buffer holding the layer's output activations. */
  output: GPUBuffer;
  /**
   * Auxiliary per-layer state. Its concrete shape differs between layer
   * variants, so it is typed `unknown` and callers must narrow it themselves.
   */
  cache: unknown;
}

/** A single GPU-resident parameter owned by a layer. */
export interface LayerParam {
  /** Backing GPU buffer. */
  buf: GPUBuffer;
  /**
   * Element count. The trainer sizes per-parameter buffers as `numel * 4`
   * bytes, i.e. it treats these as 32-bit elements.
   */
  numel: number;
  /** Human-readable identifier for the parameter. */
  name: string;
}

/**
 * Contract shared by every block type (Mamba1/2/3 and attention) so that a
 * hybrid model can drive its layers without knowing their concrete class.
 */
export interface SequenceLayer {
  /** Which block implementation this layer is. */
  readonly layerType: LayerType;

  /**
   * Runs the layer on `xBuf`.
   * NOTE(review): presumably `xBuf` holds `batch * seqLen` activations — confirm
   * against the implementing blocks.
   */
  forward(xBuf: GPUBuffer, batch: number, seqLen: number): LayerForwardResult;
  /** Every parameter owned by the layer. */
  parameters(): LayerParam[];
  /** The parameters meant to be updated by training (may be a subset of `parameters()`). */
  getTrainableParams(): LayerParam[];
  /** Switches the layer's WSLA mode on or off. */
  setWSLAMode(enabled: boolean): void;
  /** Releases the layer's resources; the layer must not be used afterwards. */
  destroy(): void;
}
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* bpe.ts – Browser-side Byte Pair Encoding (BPE) tokenizer.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
/** Controls which special tokens `BPETokenizer.encode` adds around the text. */
export interface BPEEncodeOptions {
  /** Prepend the BOS id (only when the loaded vocabulary contains the BOS token). */
  addBos?: boolean;
  /** Append the EOS id (only when the loaded vocabulary contains the EOS token). */
  addEos?: boolean;
}

/** End of the sequence on which `padOrTruncate` inserts padding. */
export type PadSide = 'left' | 'right';
|
|
11
|
+
|
|
12
|
+
/**
 * Builds the reversible byte → character table used by byte-level BPE.
 *
 * Bytes in 0x21–0x7E, 0xA1–0xAC and 0xAE–0xFF map to the character with the
 * same code point. The other 68 bytes are assigned U+0100, U+0101, … in
 * ascending byte order, so every byte gets its own distinct character.
 * The self-mapped bytes are inserted into the map first.
 */
function buildByteEncoder(): Map<number, string> {
  const keepsOwnChar = (byte: number): boolean =>
    (byte >= 0x21 && byte <= 0x7E) ||
    (byte >= 0xA1 && byte <= 0xAC) ||
    (byte >= 0xAE && byte <= 0xFF);

  const allBytes = Array.from({ length: 256 }, (_, byte) => byte);
  const table = new Map<number, string>();

  for (const byte of allBytes.filter(keepsOwnChar)) {
    table.set(byte, String.fromCodePoint(byte));
  }
  allBytes
    .filter((byte) => !keepsOwnChar(byte))
    .forEach((byte, offset) => {
      table.set(byte, String.fromCodePoint(256 + offset));
    });

  return table;
}
|
|
33
|
+
|
|
34
|
+
// Byte -> printable character table and its inverse, built once at module load.
const BYTE_ENCODER = buildByteEncoder();
const BYTE_DECODER = new Map<string, number>(
  Array.from(BYTE_ENCODER, ([byte, ch]): [string, number] => [ch, byte]),
);

// Pre-tokenisation splitter, tried left to right:
//   1. the lowercase English contractions 's 't 're 've 'm 'll 'd
//   2. a run of letters, optionally preceded by one char that is not a
//      letter, digit, CR or LF (typically a leading space)
//   3. one to three digits
//   4. an optional space + a run of punctuation/symbols + trailing CR/LF
//   5. whitespace ending in one or more CR/LF
//   6. whitespace not followed by a non-space char
//   7. any remaining whitespace
// `u` enables \p{...} classes; `g` makes String#match return every piece.
const PRE_TOKENIZE_RE =
  /(?:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/gu;
|
|
39
|
+
|
|
40
|
+
/**
 * Browser-side byte-level BPE tokenizer.
 *
 * Call `load` (URLs or in-memory values) or `loadFromObjects`, then use
 * `encode` / `decode`. `encode` splits text with PRE_TOKENIZE_RE, UTF-8
 * encodes each piece, maps every byte through BYTE_ENCODER, and then applies
 * merges lowest-rank first. The BOS/EOS/PAD/UNK ids are looked up from the
 * vocabulary at load time and are `null` when the token is absent.
 */
export class BPETokenizer {
  /** Shared UTF-8 encoder, hoisted out of the per-word / per-character loops. */
  private static readonly _utf8 = new TextEncoder();

  vocab: Map<string, number>;
  idToToken: Map<number, string>;
  /** Maps a "left right" symbol pair to its merge rank (lower ranks merge first). */
  merges: Map<string, number>;
  bosToken: string;
  eosToken: string;
  padToken: string;
  unkToken: string;
  bosId: number | null;
  eosId: number | null;
  padId: number | null;
  /** Id of `unkToken`, emitted for symbols missing from the vocabulary; null if absent. */
  unkId: number | null;

  constructor() {
    this.vocab = new Map();
    this.idToToken = new Map();
    this.merges = new Map();
    this.bosToken = '<|im_start|>';
    this.eosToken = '<|im_end|>';
    this.padToken = '<|endoftext|>';
    this.unkToken = '<unk>';
    this.bosId = null;
    this.eosId = null;
    this.padId = null;
    this.unkId = null;
  }

  /**
   * Loads the vocabulary and merges.
   *
   * Each argument is either a URL to fetch or the already-parsed value:
   * - vocab: a JSON object mapping token -> id.
   * - merges: merges.txt text, one "left right" pair per line.
   *
   * Both inputs are fetched and parsed before any state is replaced, so a
   * failed load leaves the tokenizer unchanged.
   *
   * @throws Error when a fetch responds with a non-2xx status.
   */
  async load(vocab: string | Record<string, number>, merges: string | string[]): Promise<void> {
    let vocabObj: Record<string, number>;
    if (typeof vocab === 'string') {
      const res = await BPETokenizer._fetchOk(vocab);
      vocabObj = await res.json() as Record<string, number>;
    } else {
      vocabObj = vocab;
    }

    let mergeLines: string[];
    if (typeof merges === 'string') {
      const res = await BPETokenizer._fetchOk(merges);
      mergeLines = BPETokenizer._parseMergesText(await res.text());
    } else {
      mergeLines = merges;
    }

    this._install(vocabObj, mergeLines);
  }

  /**
   * Synchronous variant of `load` for in-memory data. Merge entries are
   * trimmed exactly as in `load`; an entry's array index is its rank.
   */
  loadFromObjects(vocabObj: Record<string, number>, mergeArr: string[]): void {
    this._install(vocabObj, mergeArr);
  }

  /**
   * Encodes `text` to token ids, optionally wrapped in BOS/EOS.
   * If a BPE piece is not in the vocabulary, each of its symbols is looked up
   * individually. Symbols that are still unknown map to `unkId` when the
   * vocabulary defines `unkToken`; otherwise they are dropped.
   */
  encode(text: string, opts: BPEEncodeOptions = {}): number[] {
    const ids: number[] = [];
    if (opts.addBos && this.bosId !== null) ids.push(this.bosId);

    for (const word of text.match(PRE_TOKENIZE_RE) ?? []) {
      let byteStr = '';
      for (const b of BPETokenizer._utf8.encode(word)) {
        byteStr += BYTE_ENCODER.get(b) ?? '?';
      }

      for (const tok of this._bpe(byteStr)) {
        const id = this.vocab.get(tok);
        if (id !== undefined) {
          ids.push(id);
          continue;
        }
        for (const ch of tok) {
          const cid = this.vocab.get(ch) ?? this.unkId;
          if (cid !== null) ids.push(cid);
        }
      }
    }

    if (opts.addEos && this.eosId !== null) ids.push(this.eosId);
    return ids;
  }

  /**
   * Decodes token ids back to text; ids missing from the vocabulary are skipped.
   * Characters from the byte table become their original byte. Any other
   * character (e.g. in an added token) contributes its own UTF-8 bytes instead
   * of being truncated to a single byte. Invalid UTF-8 sequences decode to
   * U+FFFD, because a non-fatal TextDecoder never throws.
   */
  decode(ids: number[]): string {
    const bytes: number[] = [];
    for (const id of ids) {
      const tok = this.idToToken.get(id);
      if (tok === undefined) continue;
      for (const ch of tok) {
        const b = BYTE_DECODER.get(ch);
        if (b !== undefined) {
          bytes.push(b);
        } else {
          for (const u of BPETokenizer._utf8.encode(ch)) bytes.push(u);
        }
      }
    }
    return new TextDecoder('utf-8').decode(new Uint8Array(bytes));
  }

  /**
   * Applies BPE merges to one byte-mapped word.
   * Each round merges the leftmost occurrence of the lowest-ranked adjacent
   * pair, until no adjacent pair has a merge rule.
   */
  _bpe(word: string): string[] {
    if (this.vocab.has(word)) return [word];

    const symbols = [...word];
    while (symbols.length > 1) {
      let bestRank = Infinity;
      let bestIdx = -1;
      for (let i = 0; i < symbols.length - 1; i++) {
        const rank = this.merges.get(`${symbols[i]} ${symbols[i + 1]}`);
        if (rank !== undefined && rank < bestRank) {
          bestRank = rank;
          bestIdx = i;
        }
      }
      if (bestIdx === -1) break;
      symbols.splice(bestIdx, 2, symbols[bestIdx]! + symbols[bestIdx + 1]!);
    }
    return symbols;
  }

  /**
   * Truncates `ids` to `maxLen`, or pads it with `padId` (0 if the vocabulary
   * has no pad token) on the chosen `side`.
   * @throws RangeError if `maxLen` is not a non-negative integer.
   */
  padOrTruncate(ids: number[], maxLen: number, side: PadSide = 'right'): number[] {
    if (!Number.isInteger(maxLen) || maxLen < 0) {
      throw new RangeError(`padOrTruncate: maxLen must be a non-negative integer, got ${maxLen}`);
    }
    if (ids.length >= maxLen) return ids.slice(0, maxLen);
    const pad = new Array<number>(maxLen - ids.length).fill(this.padId ?? 0);
    return side === 'right' ? [...ids, ...pad] : [...pad, ...ids];
  }

  get vocabSize(): number { return this.vocab.size; }

  /** Replaces the vocabulary/merge state and refreshes the special-token ids. */
  private _install(vocabObj: Record<string, number>, mergeLines: readonly string[]): void {
    this.vocab = new Map(
      Object.entries(vocabObj).map(([k, v]): [string, number] => [k, Number(v)]),
    );
    this.idToToken = new Map(
      [...this.vocab].map(([k, v]): [number, string] => [v, k]),
    );
    this.merges = new Map(
      mergeLines.map((line, rank): [string, number] => [line.trim(), rank]),
    );
    this.bosId = this.vocab.get(this.bosToken) ?? null;
    this.eosId = this.vocab.get(this.eosToken) ?? null;
    this.padId = this.vocab.get(this.padToken) ?? null;
    this.unkId = this.vocab.get(this.unkToken) ?? null;
  }

  /**
   * Splits merges.txt content into merge rules. Both LF and CRLF line endings
   * are accepted. Blank lines and the `#version` header are skipped. Other
   * lines starting with '#' are kept, because they are real merges (e.g. "# #").
   */
  private static _parseMergesText(text: string): string[] {
    return text
      .split(/\r?\n/)
      .filter((line) => line.trim() !== '' && !line.startsWith('#version'));
  }

  /** Fetches `url` and rejects on a non-2xx response instead of parsing an error page. */
  private static async _fetchOk(url: string): Promise<Response> {
    const res = await fetch(url);
    if (!res.ok) {
      throw new Error(`BPETokenizer: failed to fetch ${url} (HTTP ${res.status})`);
    }
    return res;
  }
}
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* autograd.ts – Lightweight tape-based automatic differentiation engine.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
/* eslint-disable @typescript-eslint/no-explicit-any */
/**
 * Untyped view of the global object. It is used to read the WebGPU
 * `GPUBufferUsage` constants when the host defines them.
 */
const _gpu: any = globalThis;

/** One recorded operation: the closure that propagates its gradients. */
type TapeEntry = { backward: () => void | Promise<void> };

/** Backward closures recorded since the tape was last cleared, in recording order. */
let _tape: TapeEntry[] = [];

/** Global switch consulted by `recordOperation`; recording starts enabled. */
let _gradEnabled = true;
|
|
14
|
+
|
|
15
|
+
/**
 * A shaped float32 tensor backed by an optional GPU buffer, with an optional
 * gradient buffer of the same byte size.
 */
export class Tensor {
  data: GPUBuffer | null;
  shape: number[];
  /** Product of all dimensions (1 for a scalar, i.e. `shape = []`). */
  numel: number;
  requiresGrad: boolean;
  grad: GPUBuffer | null;
  /** NOTE(review): presumably a tape index as returned by `recordOperation` — confirm. */
  _gradFn: number | null;

  constructor(data: GPUBuffer | null, shape: number[], requiresGrad = false) {
    let elementCount = 1;
    for (const dim of shape) {
      elementCount *= dim;
    }

    this.data = data;
    this.shape = shape;
    this.numel = elementCount;
    this.requiresGrad = requiresGrad;
    this.grad = null;
    this._gradFn = null;
  }

  /** Size in bytes, assuming 4-byte (float32) elements. */
  get byteSize(): number {
    return Float32Array.BYTES_PER_ELEMENT * this.numel;
  }

  /** Overwrites the gradient buffer with zeros; a no-op when no gradient exists. */
  zeroGrad(device: GPUDevice): void {
    const gradBuffer = this.grad;
    if (!gradBuffer) return;
    device.queue.writeBuffer(gradBuffer, 0, new Float32Array(this.numel));
  }

  /** Destroys the data buffer and then the gradient buffer, and drops both references. */
  destroy(): void {
    for (const buffer of [this.data, this.grad]) {
      buffer?.destroy();
    }
    this.data = null;
    this.grad = null;
  }
}
|
|
47
|
+
|
|
48
|
+
/** Turns operation recording back on (recording starts enabled). */
export function enableGrad(): void { _gradEnabled = true; }

/**
 * Turns operation recording off. Subsequent `recordOperation` calls return -1
 * and add nothing to the tape. This is a global switch, not a scope; call
 * `enableGrad()` to restore recording.
 */
export function noGrad(): void { _gradEnabled = false; }

/** Discards every recorded backward closure without running it. */
export function clearTape(): void { _tape = []; }

/**
 * Appends a backward closure to the tape.
 * @returns The closure's index on the tape, or -1 when recording is disabled.
 */
export function recordOperation(backwardFn: () => void | Promise<void>): number {
  if (!_gradEnabled) return -1;
  _tape.push({ backward: backwardFn });
  return _tape.length - 1;
}

/**
 * Runs the recorded backward closures in reverse recording order, awaiting
 * each before starting the next, then clears the tape.
 *
 * The tape is cleared in a `finally` block. A closure that throws therefore
 * does not leave stale entries to be replayed by the next `backward()`; the
 * error still propagates to the caller.
 */
export async function backward(): Promise<void> {
  try {
    for (let i = _tape.length - 1; i >= 0; i--) {
      await _tape[i]!.backward();
    }
  } finally {
    clearTape();
  }
}
|
|
64
|
+
|
|
65
|
+
/**
 * Lazily allocates `tensor.grad` and fills it with zeros. Tensors that already
 * have a gradient buffer are left untouched.
 *
 * The buffer is created with STORAGE | COPY_DST | COPY_SRC usage. The numeric
 * fallbacks are used when the host does not expose a global `GPUBufferUsage`.
 */
export function ensureGradBuffer(device: GPUDevice, tensor: Tensor): void {
  if (tensor.grad) return;

  const usageFlags = _gpu.GPUBufferUsage;
  const usage: number =
    (usageFlags?.STORAGE ?? 0x80) |
    (usageFlags?.COPY_DST ?? 0x08) |
    (usageFlags?.COPY_SRC ?? 0x04);

  const gradBuffer = device.createBuffer({ size: tensor.byteSize, usage });
  tensor.grad = gradBuffer;
  device.queue.writeBuffer(gradBuffer, 0, new Float32Array(tensor.numel));
}
|
|
77
|
+
|
|
78
|
+
/** Ensures every tensor flagged `requiresGrad` has a gradient buffer; other tensors are skipped. */
export function allocateGradients(device: GPUDevice, tensors: Tensor[]): void {
  tensors.forEach((tensor) => {
    if (!tensor.requiresGrad) return;
    ensureGradBuffer(device, tensor);
  });
}
|
|
83
|
+
|
|
84
|
+
/** Zero-fills the gradient buffer of every tensor that has one; tensors without `grad` are skipped. */
export function zeroGradients(device: GPUDevice, tensors: Tensor[]): void {
  for (const { grad, numel } of tensors) {
    if (!grad) continue;
    device.queue.writeBuffer(grad, 0, new Float32Array(numel));
  }
}
|
|
91
|
+
|
|
92
|
+
/**
 * Creates a 4-byte STORAGE | COPY_DST buffer that holds a single float32 1.0.
 * The value is written through a mapping made at creation time.
 */
export function onesLikeScalar(device: GPUDevice): GPUBuffer {
  const usageFlags = _gpu.GPUBufferUsage;
  const scalar = device.createBuffer({
    size: Float32Array.BYTES_PER_ELEMENT,
    usage: ((usageFlags?.STORAGE ?? 0x80) | (usageFlags?.COPY_DST ?? 0x08)) as number,
    mappedAtCreation: true,
  });

  const view = new Float32Array(scalar.getMappedRange());
  view[0] = 1.0;
  scalar.unmap();

  return scalar;
}
|
|
104
|
+
|
|
105
|
+
/**
 * Numerically stable softmax cross-entropy for a single position:
 * `logsumexp(logits) - logits[targetId]`, accumulated in float64.
 *
 * @param logits   Unnormalised scores, one per class.
 * @param targetId Index of the correct class.
 * @returns The negative natural-log likelihood of `targetId`.
 * @throws RangeError if `targetId` is not an integer index into `logits`
 *         (including when `logits` is empty).
 */
export function crossEntropyLoss(logits: Float32Array, targetId: number): number {
  if (!Number.isInteger(targetId) || targetId < 0 || targetId >= logits.length) {
    throw new RangeError(
      `crossEntropyLoss: targetId ${targetId} is out of range for ${logits.length} logits`,
    );
  }

  // Subtract the max logit before exponentiating so exp() cannot overflow.
  let maxLogit = -Infinity;
  for (const x of logits) {
    if (x > maxLogit) maxLogit = x;
  }
  let sumExp = 0;
  for (const x of logits) {
    sumExp += Math.exp(x - maxLogit);
  }

  const logSumExp = Math.log(sumExp) + maxLogit;
  return logSumExp - logits[targetId]!;
}
|
|
117
|
+
|
|
118
|
+
/**
 * Gradient of softmax cross-entropy with respect to the logits:
 * `softmax(logits) - onehot(targetId)`.
 *
 * @param logits   Unnormalised scores, one per class.
 * @param targetId Index of the correct class.
 * @returns A new Float32Array of the same length as `logits`.
 * @throws RangeError if `targetId` is not an integer index into `logits`. The
 *         -1 term would otherwise be written out of bounds and silently lost.
 */
export function crossEntropyGrad(logits: Float32Array, targetId: number): Float32Array {
  if (!Number.isInteger(targetId) || targetId < 0 || targetId >= logits.length) {
    throw new RangeError(
      `crossEntropyGrad: targetId ${targetId} is out of range for ${logits.length} logits`,
    );
  }

  let maxLogit = -Infinity;
  for (const x of logits) {
    if (x > maxLogit) maxLogit = x;
  }

  // One buffer serves both passes: it first holds the max-shifted exponentials
  // (stored as float32, and summed from those stored values), then it is
  // normalised in place into probabilities.
  const probs = new Float32Array(logits.length);
  let sumExp = 0;
  for (let i = 0; i < logits.length; i++) {
    probs[i] = Math.exp(logits[i]! - maxLogit);
    sumExp += probs[i]!;
  }
  for (let i = 0; i < probs.length; i++) {
    probs[i] = probs[i]! / sumExp;
  }

  probs[targetId] = probs[targetId]! - 1.0;
  return probs;
}
|
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* trainer.ts – MambaTrainer class
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
import {
|
|
6
|
+
createUniformBuffer,
|
|
7
|
+
createStorageBuffer,
|
|
8
|
+
createEmptyStorageBuffer,
|
|
9
|
+
createComputePipeline,
|
|
10
|
+
createBindGroup,
|
|
11
|
+
dispatchKernel,
|
|
12
|
+
cdiv,
|
|
13
|
+
} from '../utils/gpu_utils.js';
|
|
14
|
+
|
|
15
|
+
import { crossEntropyLoss, crossEntropyGrad } from './autograd.js';
|
|
16
|
+
import { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from '../kernels/weight_update.js';
|
|
17
|
+
import { HybridMambaModel, MambaModel } from '../model/mamba_model.js';
|
|
18
|
+
import { BPETokenizer } from '../tokenizer/bpe.js';
|
|
19
|
+
import type { LayerParam as BlockParam } from '../model/sequence_layer.js';
|
|
20
|
+
|
|
21
|
+
/**
 * Options for `MambaTrainer.train`. Every field is optional; the default noted
 * on each field is what `train` uses when the field is omitted.
 */
export interface TrainOptions {
  /** AdamW step size (default 1e-4). */
  learningRate?: number;
  /** Number of passes over all training chunks (default 5). */
  epochs?: number;
  /** Batch size forwarded to the model on each step (default 1). */
  batchSize?: number;
  /** Maximum tokens per training chunk (default 512). */
  seqLen?: number;
  /** Gradient-norm clipping threshold (default 1.0). */
  maxGradNorm?: number;
  /** AdamW decoupled weight decay (default 0.01). */
  weightDecay?: number;
  /** AdamW first-moment decay rate (default 0.9). */
  beta1?: number;
  /** AdamW second-moment decay rate (default 0.999). */
  beta2?: number;
  /** AdamW denominator epsilon (default 1e-8). */
  eps?: number;
  /** When true, the model is put into WSLA mode for the duration of training (default false). */
  wsla?: boolean;
  /** Called after each epoch with the 1-based epoch number and that epoch's mean step loss. */
  onEpochEnd?: ((epoch: number, loss: number) => void) | null;
}
|
|
34
|
+
|
|
35
|
+
/** AdamW state buffers for one parameter. */
interface AdamMoments {
  /** First-moment (mean) estimate. */
  m: GPUBuffer;
  /** Second-moment (uncentred variance) estimate. */
  v: GPUBuffer;
}

/** Per-step AdamW settings handed to the update kernel. */
interface AdamHyperparams {
  learningRate: number;
  weightDecay: number;
  beta1: number;
  beta2: number;
  eps: number;
  /** `beta1` raised to the current step count, used for bias correction. */
  beta1_t: number;
  /** `beta2` raised to the current step count, used for bias correction. */
  beta2_t: number;
}
|
|
49
|
+
|
|
50
|
+
/**
 * MambaTrainer – next-token training and perplexity evaluation for a
 * HybridMambaModel (or MambaModel) on its GPU device.
 *
 * Each training step does the following:
 * 1. Runs `model.forward`.
 * 2. Computes softmax cross-entropy and its gradient with respect to the
 *    logits on the CPU.
 * 3. Clips that gradient on the GPU.
 * 4. Dispatches the `adamw_update` kernel for each model parameter.
 *
 * NOTE(review): `_trainStep` passes the single logits-gradient buffer to
 * `_adamwStep`. That method uses it as the gradient for every parameter whose
 * size fits within it, and skips larger parameters. No backward pass through
 * the model layers is visible in this class, so these updates are not true
 * per-parameter gradients. Confirm whether real backprop is wired up elsewhere.
 */
export class MambaTrainer {
  model: HybridMambaModel;
  tokenizer: BPETokenizer | null;
  device: GPUDevice;
  /** Per-parameter AdamW moment buffers, allocated on first use. */
  private _moments: AdamMoments[] | null;
  /** Optimizer steps taken so far; drives the bias-correction powers. */
  private _step: number;
  private _adamwPipeline: GPUComputePipeline;
  private _clipReducePipeline: GPUComputePipeline;
  private _clipScalePipeline: GPUComputePipeline;

  /**
   * @param model     Model whose parameter buffers are updated in place.
   * @param tokenizer Needed only when `train`/`evaluate` receive raw text.
   */
  constructor(model: HybridMambaModel | MambaModel, tokenizer: BPETokenizer | null = null) {
    this.model = model;
    this.tokenizer = tokenizer;
    this.device = model.device;

    this._moments = null;
    this._step = 0;

    this._adamwPipeline = createComputePipeline(this.device, WEIGHT_UPDATE_WGSL, 'adamw_update');
    this._clipReducePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_norm_reduce');
    this._clipScalePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_clip_scale');
  }

  /** Allocates one m/v moment-buffer pair per model parameter, once. */
  private _initMoments(): void {
    if (this._moments) return;
    this._moments = this.model.parameters().map(p => ({
      m: createEmptyStorageBuffer(this.device, p.numel * 4, false),
      v: createEmptyStorageBuffer(this.device, p.numel * 4, false),
    }));
  }

  /** Turns `input` into token ids, tokenizing strings with the configured tokenizer. */
  private _toTokenIds(input: string | number[], missingTokenizerMessage: string): number[] {
    if (typeof input !== 'string') return Array.from(input);
    if (!this.tokenizer) throw new Error(missingTokenizerMessage);
    return this.tokenizer.encode(input);
  }

  /**
   * Trains on `input` for `opts.epochs` epochs.
   *
   * @returns The mean per-step loss of each epoch, in order.
   * @throws Error if `input` is a string and no tokenizer was given, or if
   *         the input has fewer than 2 tokens.
   */
  async train(input: string | number[], opts: TrainOptions = {}): Promise<number[]> {
    const {
      learningRate = 1e-4,
      epochs = 5,
      batchSize = 1,
      seqLen = 512,
      maxGradNorm = 1.0,
      weightDecay = 0.01,
      beta1 = 0.9,
      beta2 = 0.999,
      eps = 1e-8,
      wsla = false,
      onEpochEnd = null,
    } = opts;

    const tokenIds = this._toTokenIds(
      input,
      'MambaTrainer requires a tokenizer when input is a string. ' +
      'Pass a BPETokenizer instance as the second constructor argument.',
    );
    if (tokenIds.length < 2) {
      throw new Error('Input must contain at least 2 tokens to form a training pair.');
    }

    const chunks = buildChunks(tokenIds, seqLen);
    if (chunks.length === 0) {
      throw new Error('Input is too short to form any training chunk.');
    }

    // WSLA mode is switched on only after the input is validated. The finally
    // block switches it off even if a step throws, so the model is never left
    // in training mode.
    if (wsla) this.model.setWSLAMode(true);
    try {
      this._initMoments();

      const epochLosses: number[] = [];
      for (let epoch = 0; epoch < epochs; epoch++) {
        let epochLoss = 0;
        for (const { inputs, targets } of chunks) {
          epochLoss += await this._trainStep(
            inputs, targets, batchSize,
            { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps, wsla },
          );
        }

        const avgLoss = epochLoss / chunks.length;
        epochLosses.push(avgLoss);
        onEpochEnd?.(epoch + 1, avgLoss);
      }
      return epochLosses;
    } finally {
      if (wsla) this.model.setWSLAMode(false);
    }
  }

  /**
   * One optimizer step on a single chunk.
   * @returns The mean cross-entropy over the chunk's positions.
   */
  private async _trainStep(
    inputs: number[],
    targets: number[],
    batch: number,
    hyperparams: TrainOptions & { learningRate: number; maxGradNorm: number; weightDecay: number; beta1: number; beta2: number; eps: number }
  ): Promise<number> {
    const { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps } = hyperparams;

    this._step++;
    const seqLen = inputs.length;
    const vocabSize = this.model.config.vocabSize;

    const { logits, gpuLogits } = await this.model.forward(
      new Uint32Array(inputs), batch, seqLen
    );

    let dLogitsBuf: GPUBuffer | null = null;
    try {
      let totalLoss = 0;
      const dLogits = new Float32Array(batch * seqLen * vocabSize);

      // NOTE(review): `inputs` holds one sequence of length seqLen but is
      // forwarded with `batch` rows, and only the first seqLen logit rows are
      // scored. With batchSize > 1 this looks inconsistent; confirm the layout
      // that model.forward expects.
      for (let i = 0; i < seqLen; i++) {
        const offset = i * vocabSize;
        const logitSlice = logits.subarray(offset, offset + vocabSize);
        const target = targets[i]!;
        totalLoss += crossEntropyLoss(logitSlice, target);
        const grad = crossEntropyGrad(logitSlice, target);
        for (let v = 0; v < vocabSize; v++) {
          dLogits[offset + v] = grad[v]! / seqLen;
        }
      }

      dLogitsBuf = createStorageBuffer(this.device, dLogits, false);
      await this._clipGradients(dLogitsBuf, dLogits.length, maxGradNorm);

      const params = this.model.parameters();
      const beta1_t = Math.pow(beta1, this._step);
      const beta2_t = Math.pow(beta2, this._step);

      await this._adamwStep(
        params, [dLogitsBuf],
        { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t }
      );

      return totalLoss / seqLen;
    } finally {
      // Release per-step GPU buffers even when the step fails part-way.
      dLogitsBuf?.destroy();
      gpuLogits.destroy();
    }
  }

  /**
   * Dispatches one AdamW update per parameter. Parameter i uses gradBufs[i],
   * or the last buffer when there are fewer gradient buffers than parameters
   * (see the class NOTE). A parameter is skipped when that buffer is smaller
   * than `numel * 4` bytes.
   */
  private async _adamwStep(
    params: BlockParam[],
    gradBufs: GPUBuffer[],
    hp: AdamHyperparams
  ): Promise<void> {
    const { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t } = hp;
    const moments = this._moments;
    if (!moments) {
      throw new Error('MambaTrainer: optimizer moments have not been initialised.');
    }

    for (let i = 0; i < params.length; i++) {
      const p = params[i]!;
      const gradBuf = gradBufs[Math.min(i, gradBufs.length - 1)];

      if (!gradBuf || gradBuf.size < p.numel * 4) continue;

      const moment = moments[i];
      if (!moment) {
        throw new Error(`MambaTrainer: no optimizer moments for parameter ${i} (${p.name}).`);
      }

      const paramsBuf = createUniformBuffer(this.device, packAdamParams(
        p.numel, learningRate, beta1, beta2, eps, weightDecay, beta1_t, beta2_t
      ));
      try {
        const bg = createBindGroup(this.device, this._adamwPipeline, [
          paramsBuf,
          p.buf,
          gradBuf,
          moment.m,
          moment.v,
        ]);
        dispatchKernel(this.device, this._adamwPipeline, bg,
          [cdiv(p.numel, 256), 1, 1]);
      } finally {
        paramsBuf.destroy();
      }
    }
  }

  /**
   * Rescales `gradBuf` in place on the GPU when its norm exceeds `maxNorm`.
   * Two kernels are dispatched: `grad_norm_reduce` accumulates into a scratch
   * buffer, then `grad_clip_scale` applies the scaling. The uniform block is
   * [u32 numel, f32 maxNorm²].
   */
  private async _clipGradients(gradBuf: GPUBuffer, numel: number, maxNorm: number): Promise<void> {
    const normSqBuf = createEmptyStorageBuffer(this.device, 4, true);
    this.device.queue.writeBuffer(normSqBuf, 0, new Float32Array([0.0]));

    const clipParams = new ArrayBuffer(8);
    new Uint32Array(clipParams, 0, 1).set([numel]);
    new Float32Array(clipParams, 4, 1).set([maxNorm * maxNorm]);
    const pBuf = createUniformBuffer(this.device, clipParams);

    try {
      const bg1 = createBindGroup(this.device, this._clipReducePipeline,
        [pBuf, gradBuf, normSqBuf]);
      dispatchKernel(this.device, this._clipReducePipeline, bg1,
        [cdiv(numel, 256), 1, 1]);

      const bg2 = createBindGroup(this.device, this._clipScalePipeline,
        [pBuf, gradBuf, normSqBuf]);
      dispatchKernel(this.device, this._clipScalePipeline, bg2,
        [cdiv(numel, 256), 1, 1]);
    } finally {
      pBuf.destroy();
      normSqBuf.destroy();
    }
  }

  /**
   * Perplexity of the model on `input`: exp of the mean next-token
   * cross-entropy over all positions.
   *
   * @throws Error if `input` is a string and no tokenizer was given, or if it
   *         has fewer than 2 tokens (no prediction to score).
   */
  async evaluate(input: string | number[]): Promise<number> {
    const tokenIds = this._toTokenIds(input, 'Tokenizer required for string input.');
    if (tokenIds.length < 2) {
      throw new Error('Input must contain at least 2 tokens to evaluate.');
    }

    const numPredictions = tokenIds.length - 1;
    const vocabSize = this.model.config.vocabSize;

    const { logits, gpuLogits } = await this.model.forward(
      new Uint32Array(tokenIds.slice(0, -1)), 1, numPredictions
    );

    let totalLoss = 0;
    try {
      for (let i = 0; i < numPredictions; i++) {
        const offset = i * vocabSize;
        totalLoss += crossEntropyLoss(
          logits.subarray(offset, offset + vocabSize),
          tokenIds[i + 1]!
        );
      }
    } finally {
      // The forward pass allocates a GPU logits buffer; release it here, as _trainStep does.
      gpuLogits.destroy();
    }

    return Math.exp(totalLoss / numPredictions);
  }
}
|
|
281
|
+
|
|
282
|
+
/**
 * Splits a token stream into next-token training pairs.
 *
 * Full chunks hold `seqLen` inputs, each paired with the following token as its
 * target. Whatever remains after the last full chunk (at least 2 tokens) forms
 * a shorter tail chunk. Every transition ids[k] -> ids[k+1] therefore appears
 * exactly once. This includes inputs whose length is an exact multiple of
 * `seqLen`, and `ids.length === seqLen`.
 *
 * @throws RangeError if `seqLen` is not a positive integer; a non-positive
 *         value would otherwise never advance the loop.
 */
function buildChunks(ids: number[], seqLen: number): Array<{inputs: number[], targets: number[]}> {
  if (!Number.isInteger(seqLen) || seqLen < 1) {
    throw new RangeError(`buildChunks: seqLen must be a positive integer, got ${seqLen}`);
  }

  const chunks: Array<{inputs: number[], targets: number[]}> = [];
  let start = 0;
  // A full chunk needs seqLen inputs plus one extra token for the last target.
  for (; start + seqLen < ids.length; start += seqLen) {
    chunks.push({
      inputs : ids.slice(start, start + seqLen),
      targets: ids.slice(start + 1, start + seqLen + 1),
    });
  }

  // The tail starts where the last full chunk's inputs ended; it needs at least
  // two tokens to form one pair.
  if (ids.length - start >= 2) {
    chunks.push({
      inputs : ids.slice(start, -1),
      targets: ids.slice(start + 1),
    });
  }
  return chunks;
}
|
|
300
|
+
|
|
301
|
+
/**
 * Serialises AdamW settings into a 32-byte uniform block:
 * - a u32 element count at offset 0;
 * - seven f32 values (lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t)
 *   at offsets 4..28.
 * Values are written in the platform's native byte order.
 * NOTE(review): this must match the uniform struct of `adamw_update` in
 * WEIGHT_UPDATE_WGSL, which is not visible here — confirm the field order.
 */
function packAdamParams(
  numElements: number, lr: number, beta1: number, beta2: number,
  eps: number, weightDecay: number, beta1_t: number, beta2_t: number
): ArrayBuffer {
  const block = new ArrayBuffer(8 * 4);
  const asU32 = new Uint32Array(block);
  const asF32 = new Float32Array(block);

  asU32[0] = numElements;
  [lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t].forEach((value, k) => {
    asF32[k + 1] = value;
  });

  return block;
}
|