@seanhogg/builderforce-memory-engine 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +393 -0
- package/dist/index.d.ts +32 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +5 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +171 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/attention.d.ts +19 -0
- package/dist/kernels/attention.d.ts.map +1 -0
- package/dist/kernels/attention.js +263 -0
- package/dist/kernels/attention.js.map +1 -0
- package/dist/kernels/complex_ssd.d.ts +33 -0
- package/dist/kernels/complex_ssd.d.ts.map +1 -0
- package/dist/kernels/complex_ssd.js +305 -0
- package/dist/kernels/complex_ssd.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +158 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/ssd.d.ts +29 -0
- package/dist/kernels/ssd.d.ts.map +1 -0
- package/dist/kernels/ssd.js +276 -0
- package/dist/kernels/ssd.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/attention_block.d.ts +48 -0
- package/dist/model/attention_block.d.ts.map +1 -0
- package/dist/model/attention_block.js +262 -0
- package/dist/model/attention_block.js.map +1 -0
- package/dist/model/mamba1_block.d.ts +70 -0
- package/dist/model/mamba1_block.d.ts.map +1 -0
- package/dist/model/mamba1_block.js +333 -0
- package/dist/model/mamba1_block.js.map +1 -0
- package/dist/model/mamba2_block.d.ts +44 -0
- package/dist/model/mamba2_block.d.ts.map +1 -0
- package/dist/model/mamba2_block.js +252 -0
- package/dist/model/mamba2_block.js.map +1 -0
- package/dist/model/mamba3_block.d.ts +51 -0
- package/dist/model/mamba3_block.d.ts.map +1 -0
- package/dist/model/mamba3_block.js +270 -0
- package/dist/model/mamba3_block.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +303 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +140 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +527 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/model/sequence_layer.d.ts +25 -0
- package/dist/model/sequence_layer.d.ts.map +1 -0
- package/dist/model/sequence_layer.js +8 -0
- package/dist/model/sequence_layer.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +36 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/dist/utils/rng.d.ts +36 -0
- package/dist/utils/rng.d.ts.map +1 -0
- package/dist/utils/rng.js +61 -0
- package/dist/utils/rng.js.map +1 -0
- package/package.json +99 -0
- package/src/index.ts +114 -0
- package/src/kernels/activations.ts +174 -0
- package/src/kernels/attention.ts +268 -0
- package/src/kernels/complex_ssd.ts +307 -0
- package/src/kernels/conv1d.ts +159 -0
- package/src/kernels/linear_projection.ts +220 -0
- package/src/kernels/selective_scan.ts +350 -0
- package/src/kernels/ssd.ts +278 -0
- package/src/kernels/weight_update.ts +120 -0
- package/src/model/attention_block.ts +344 -0
- package/src/model/mamba1_block.ts +437 -0
- package/src/model/mamba2_block.ts +319 -0
- package/src/model/mamba3_block.ts +335 -0
- package/src/model/mamba_block.ts +401 -0
- package/src/model/mamba_model.ts +678 -0
- package/src/model/sequence_layer.ts +29 -0
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +309 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/utils/rng.ts +65 -0
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* bpe.ts – Browser-side Byte Pair Encoding (BPE) tokenizer.
|
|
3
|
+
*/
|
|
4
|
+
/**
 * Builds the byte -> character table used for byte-level BPE.
 *
 * Bytes in the ranges 0x21-0x7E, 0xA1-0xAC and 0xAE-0xFF map to the character
 * with the same code point. Every remaining byte is assigned, in ascending
 * byte order, the next unused code point starting at U+0100, so that all 256
 * bytes map to distinct characters.
 *
 * @returns A Map with exactly 256 entries (byte value -> one-character string).
 */
function buildByteEncoder() {
    const keepsOwnCodePoint = (byte) => (byte >= 0x21 && byte <= 0x7E) ||
        (byte >= 0xA1 && byte <= 0xAC) ||
        (byte >= 0xAE && byte <= 0xFF);
    const table = new Map();
    // Pass 1: self-mapped bytes go in first so the Map's insertion order
    // matches the three ascending ranges.
    for (let byte = 0; byte < 256; byte += 1) {
        if (keepsOwnCodePoint(byte)) {
            table.set(byte, String.fromCodePoint(byte));
        }
    }
    // Pass 2: every byte still missing gets the next code point at/after U+0100.
    let nextCodePoint = 256;
    for (let byte = 0; byte < 256; byte += 1) {
        if (table.has(byte)) {
            continue;
        }
        table.set(byte, String.fromCodePoint(nextCodePoint));
        nextCodePoint += 1;
    }
    return table;
}
|
|
25
|
+
// byte value -> printable stand-in character.
const BYTE_ENCODER = buildByteEncoder();
// Inverse table (stand-in character -> byte value), filled in the same
// iteration order as BYTE_ENCODER.
const BYTE_DECODER = new Map();
for (const [byteValue, standIn] of BYTE_ENCODER) {
    BYTE_DECODER.set(standIn, byteValue);
}
// Pre-tokenization pattern: splits text into contractions, letter runs,
// digit groups of up to three, punctuation runs and whitespace before BPE.
const PRE_TOKENIZE_RE = /(?:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/gu;
|
|
28
|
+
/**
 * Byte-level Byte Pair Encoding tokenizer.
 *
 * Text is split with PRE_TOKENIZE_RE, each piece is UTF-8 encoded and mapped
 * through BYTE_ENCODER, and adjacent symbols are merged according to the
 * ranked merge table. Decoding reverses the byte mapping and UTF-8 decodes.
 */
export class BPETokenizer {
    /** Token string -> token id. */
    vocab;
    /** Token id -> token string (inverse of `vocab`). */
    idToToken;
    /** Merge rule "left right" -> rank; the lowest rank is merged first. */
    merges;
    bosToken;
    eosToken;
    padToken;
    /** Stored for callers; encode()/decode() in this class never consult it. */
    unkToken;
    /** Special-token ids; null while the token is absent from `vocab`. */
    bosId;
    eosId;
    padId;
    constructor() {
        this.vocab = new Map();
        this.idToToken = new Map();
        this.merges = new Map();
        this.bosToken = '<|im_start|>';
        this.eosToken = '<|im_end|>';
        this.padToken = '<|endoftext|>';
        this.unkToken = '<unk>';
        this.bosId = null;
        this.eosId = null;
        this.padId = null;
    }
    /**
     * Fetches `url` and rejects on a non-2xx status instead of letting an
     * error page be parsed as tokenizer data.
     */
    static async _fetchOk(url) {
        const res = await fetch(url);
        if (!res.ok) {
            throw new Error(`BPETokenizer: failed to fetch ${url} (HTTP ${res.status})`);
        }
        return res;
    }
    /**
     * Loads the vocabulary and merge table.
     *
     * @param vocab  URL of a JSON vocabulary, or an already-parsed
     *               token -> id object.
     * @param merges URL of a merges text file, or an array of "left right"
     *               merge lines in rank order.
     * @throws Error when a fetch returns a non-2xx status.
     */
    async load(vocab, merges) {
        // Resolve both inputs before touching state so a failed download
        // cannot leave the tokenizer half-loaded.
        const vocabObj = typeof vocab === 'string'
            ? await (await BPETokenizer._fetchOk(vocab)).json()
            : vocab;
        let mergeLines;
        if (typeof merges === 'string') {
            const txt = await (await BPETokenizer._fetchOk(merges)).text();
            // Only the "#version" header is a comment. Real merge rules may
            // begin with '#' (e.g. "# #"), so they must not be filtered out.
            mergeLines = txt
                .split('\n')
                .map(line => line.trim())
                .filter(line => line !== '' && !line.startsWith('#version'));
        }
        else {
            mergeLines = merges.map(line => line.trim());
        }
        this.loadFromObjects(vocabObj, mergeLines);
    }
    /**
     * Installs an in-memory vocabulary and merge list.
     *
     * @param vocabObj token -> id object; ids are coerced with Number().
     * @param mergeArr merge rules ("left right"); the array index is the rank.
     */
    loadFromObjects(vocabObj, mergeArr) {
        this.vocab = new Map(Object.entries(vocabObj).map(([k, v]) => [k, Number(v)]));
        this.idToToken = new Map([...this.vocab].map(([k, v]) => [v, k]));
        this.merges = new Map(mergeArr.map((m, i) => [m, i]));
        this.bosId = this.vocab.get(this.bosToken) ?? null;
        this.eosId = this.vocab.get(this.eosToken) ?? null;
        this.padId = this.vocab.get(this.padToken) ?? null;
    }
    /**
     * Converts text to token ids.
     *
     * @param text Input string.
     * @param opts `addBos` / `addEos` prepend / append the special-token id
     *             when that token exists in the vocabulary.
     * @returns Array of token ids. A merged token missing from the vocabulary
     *          is emitted character by character; characters that are also
     *          missing are skipped.
     */
    encode(text, opts = {}) {
        const words = text.match(PRE_TOKENIZE_RE) ?? [];
        // One encoder per call rather than one per word.
        const utf8 = new TextEncoder();
        const ids = [];
        if (opts.addBos && this.bosId !== null)
            ids.push(this.bosId);
        for (const word of words) {
            let byteStr = '';
            for (const b of utf8.encode(word)) {
                byteStr += BYTE_ENCODER.get(b) ?? '?';
            }
            for (const tok of this._bpe(byteStr)) {
                const id = this.vocab.get(tok);
                if (id !== undefined) {
                    ids.push(id);
                    continue;
                }
                // Fallback: emit whatever single characters are known.
                for (const ch of tok) {
                    const cid = this.vocab.get(ch);
                    if (cid !== undefined)
                        ids.push(cid);
                }
            }
        }
        if (opts.addEos && this.eosId !== null)
            ids.push(this.eosId);
        return ids;
    }
    /**
     * Converts token ids back to text. Unknown ids are ignored.
     *
     * @param ids Iterable of token ids.
     * @returns The decoded string.
     */
    decode(ids) {
        let byteStr = '';
        for (const id of ids) {
            const tok = this.idToToken.get(id);
            if (tok !== undefined)
                byteStr += tok;
        }
        // Characters without a BYTE_DECODER entry fall back to their own
        // code point (stored modulo 256 by the Uint8Array).
        const bytes = new Uint8Array([...byteStr].map(ch => BYTE_DECODER.get(ch) ?? ch.codePointAt(0) ?? 0));
        try {
            return new TextDecoder('utf-8').decode(bytes);
        }
        catch {
            // Defensive fallback: return the undecoded symbol string.
            return byteStr;
        }
    }
    /**
     * Applies the merge table to one byte-encoded word.
     *
     * @param word Word already mapped through BYTE_ENCODER.
     * @returns The word split into BPE symbols.
     */
    _bpe(word) {
        // A word that is itself a vocabulary entry is returned whole.
        if (this.vocab.has(word))
            return [word];
        const symbols = [...word];
        while (symbols.length > 1) {
            // Find the leftmost adjacent pair with the lowest merge rank.
            let bestRank = Infinity;
            let bestIdx = -1;
            for (let i = 0; i < symbols.length - 1; i++) {
                const rank = this.merges.get(symbols[i] + ' ' + symbols[i + 1]);
                if (rank !== undefined && rank < bestRank) {
                    bestRank = rank;
                    bestIdx = i;
                }
            }
            if (bestIdx === -1)
                break;
            // Replace the pair in place instead of rebuilding the array.
            symbols.splice(bestIdx, 2, symbols[bestIdx] + symbols[bestIdx + 1]);
        }
        return symbols;
    }
    /**
     * Forces a sequence to exactly `maxLen` ids.
     *
     * @param ids    Token ids.
     * @param maxLen Target length.
     * @param side   'right' appends padding; any other value prepends it.
     * @returns A new array. Over-long input keeps its first `maxLen` ids
     *          regardless of `side`; padding uses `padId`, or 0 when unset.
     */
    padOrTruncate(ids, maxLen, side = 'right') {
        if (ids.length >= maxLen)
            return ids.slice(0, maxLen);
        const padId = this.padId ?? 0;
        const pad = new Array(maxLen - ids.length).fill(padId);
        return side === 'right' ? [...ids, ...pad] : [...pad, ...ids];
    }
    /** Number of entries in the vocabulary. */
    get vocabSize() { return this.vocab.size; }
}
|
|
164
|
+
//# sourceMappingURL=bpe.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"bpe.js","sourceRoot":"","sources":["../../src/tokenizer/bpe.ts"],"names":[],"mappings":"AAAA;;GAEG;AASH,SAAS,gBAAgB;IACrB,MAAM,GAAG,GAAG,IAAI,GAAG,EAAkB,CAAC;IACtC,MAAM,MAAM,GAAuB;QAC/B,CAAC,IAAI,EAAE,IAAI,CAAC;QACZ,CAAC,IAAI,EAAE,IAAI,CAAC;QACZ,CAAC,IAAI,EAAE,IAAI,CAAC;KACf,CAAC;IACF,IAAI,CAAC,GAAG,CAAC,CAAC;IACV,KAAK,MAAM,CAAC,EAAE,EAAE,EAAE,CAAC,IAAI,MAAM,EAAE,CAAC;QAC5B,KAAK,IAAI,CAAC,GAAG,EAAE,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC;YAC5B,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,aAAa,CAAC,CAAC,CAAC,CAAC,CAAC;QACxC,CAAC;IACL,CAAC;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;QAC3B,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC;YACd,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,aAAa,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC;YAC1C,CAAC,EAAE,CAAC;QACR,CAAC;IACL,CAAC;IACD,OAAO,GAAG,CAAC;AACf,CAAC;AAED,MAAM,YAAY,GAAG,gBAAgB,EAAE,CAAC;AACxC,MAAM,YAAY,GAAG,IAAI,GAAG,CAAC,CAAC,GAAG,YAAY,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;AAExE,MAAM,eAAe,GACjB,sHAAsH,CAAC;AAE3H,MAAM,OAAO,YAAY;IACrB,KAAK,CAAsB;IAC3B,SAAS,CAAsB;IAC/B,MAAM,CAAsB;IAC5B,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,KAAK,CAAgB;IACrB,KAAK,CAAgB;IACrB,KAAK,CAAgB;IAErB;QACI,IAAI,CAAC,KAAK,GAAQ,IAAI,GAAG,EAAE,CAAC;QAC5B,IAAI,CAAC,SAAS,GAAI,IAAI,GAAG,EAAE,CAAC;QAC5B,IAAI,CAAC,MAAM,GAAO,IAAI,GAAG,EAAE,CAAC;QAC5B,IAAI,CAAC,QAAQ,GAAK,cAAc,CAAC;QACjC,IAAI,CAAC,QAAQ,GAAK,YAAY,CAAC;QAC/B,IAAI,CAAC,QAAQ,GAAK,eAAe,CAAC;QAClC,IAAI,CAAC,QAAQ,GAAK,OAAO,CAAC;QAC1B,IAAI,CAAC,KAAK,GAAQ,IAAI,CAAC;QACvB,IAAI,CAAC,KAAK,GAAQ,IAAI,CAAC;QACvB,IAAI,CAAC,KAAK,GAAQ,IAAI,CAAC;IAC3B,CAAC;IAED,KAAK,CAAC,IAAI,CAAC,KAAsC,EAAE,MAAyB;QACxE,IAAI,QAAgC,CAAC;QACrC,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE,CAAC;YAC5B,MAAM,GAAG,GAAG,MAAM,KAAK,CAAC,KAAK,CAAC,CAAC;YAC/B,QAAQ,GAAG,MAAM,GAAG,CAAC,IAAI,EAA4B,CAAC;QAC1D,CAAC;aAAM,CAAC;YACJ,QAAQ,GAAG,KAAK,CAAC;QACrB,CAAC;QACD,IAAI,CAAC,KAAK,GAAO,IAAI,GAAG,CAAC,MAAM,CAAC,OAAO,CAAC,QAAQ,CAA
C,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QACnF,IAAI,CAAC,SAAS,GAAG,IAAI,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QAElE,IAAI,UAAoB,CAAC;QACzB,IAAI,OAAO,MAAM,KAAK,QAAQ,EAAE,CAAC;YAC7B,MAAM,GAAG,GAAG,MAAM,KAAK,CAAC,MAAM,CAAC,CAAC;YAChC,MAAM,GAAG,GAAG,MAAM,GAAG,CAAC,IAAI,EAAE,CAAC;YAC7B,UAAU,GAAG,GAAG,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC;QACtE,CAAC;aAAM,CAAC;YACJ,UAAU,GAAG,MAAM,CAAC;QACxB,CAAC;QACD,IAAI,CAAC,MAAM,GAAG,IAAI,GAAG,EAAE,CAAC;QACxB,UAAU,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,IAAI,EAAE,EAAE;YAC9B,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,EAAE,EAAE,IAAI,CAAC,CAAC;QACvC,CAAC,CAAC,CAAC;QAEH,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,IAAI,CAAC;QACnD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,IAAI,CAAC;QACnD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,IAAI,CAAC;IACvD,CAAC;IAED,eAAe,CAAC,QAAgC,EAAE,QAAkB;QAChE,IAAI,CAAC,KAAK,GAAO,IAAI,GAAG,CAAC,MAAM,CAAC,OAAO,CAAC,QAAQ,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QACnF,IAAI,CAAC,SAAS,GAAG,IAAI,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QAClE,IAAI,CAAC,MAAM,GAAM,IAAI,GAAG,CAAC,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACzD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,IAAI,CAAC;QACnD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,IAAI,CAAC;QACnD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,IAAI,IAAI,CAAC;IACvD,CAAC;IAED,MAAM,CAAC,IAAY,EAAE,OAAyB,EAAE;QAC5C,MAAM,KAAK,GAAG,IAAI,CA
AC,KAAK,CAAC,eAAe,CAAC,IAAI,EAAE,CAAC;QAChD,MAAM,GAAG,GAAe,EAAE,CAAC;QAE3B,IAAI,IAAI,CAAC,MAAM,IAAI,IAAI,CAAC,KAAK,KAAK,IAAI;YAAE,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QAE7D,KAAK,MAAM,IAAI,IAAI,KAAK,EAAE,CAAC;YACvB,MAAM,KAAK,GAAM,IAAI,WAAW,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC;YAChD,MAAM,OAAO,GAAI,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,IAAI,GAAG,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;YACjF,MAAM,SAAS,GAAG,IAAI,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAErC,KAAK,MAAM,GAAG,IAAI,SAAS,EAAE,CAAC;gBAC1B,MAAM,EAAE,GAAG,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;gBAC/B,IAAI,EAAE,KAAK,SAAS,EAAE,CAAC;oBACnB,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;gBACjB,CAAC;qBAAM,CAAC;oBACJ,KAAK,MAAM,EAAE,IAAI,GAAG,EAAE,CAAC;wBACnB,MAAM,GAAG,GAAG,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;wBAC/B,IAAI,GAAG,KAAK,SAAS;4BAAE,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;oBACzC,CAAC;gBACL,CAAC;YACL,CAAC;QACL,CAAC;QAED,IAAI,IAAI,CAAC,MAAM,IAAI,IAAI,CAAC,KAAK,KAAK,IAAI;YAAE,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QAC7D,OAAO,GAAG,CAAC;IACf,CAAC;IAED,MAAM,CAAC,GAAa;QAChB,IAAI,OAAO,GAAG,EAAE,CAAC;QACjB,KAAK,MAAM,EAAE,IAAI,GAAG,EAAE,CAAC;YACnB,MAAM,GAAG,GAAG,IAAI,CAAC,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;YACnC,IAAI,GAAG,KAAK,SAAS;gBAAE,OAAO,IAAI,GAAG,CAAC;QAC1C,CAAC;QACD,MAAM,KAAK,GAAG,IAAI,UAAU,CACxB,CAAC,GAAG,OAAO,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,YAAY,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CACzE,CAAC;QACF,IAAI,CAAC;YACD,OAAO,IAAI,WAAW,CAAC,OAAO,CAAC,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;QAClD,CAAC;QAAC,MAAM,CAAC;YACL,OAAO,OAAO,CAAC;QACnB,CAAC;IACL,CAAC;IAED,IAAI,CAAC,IAAY;QACb,IAAI,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC;YAAE,OAAO,CAAC,IAAI,CAAC,CAAC;QAExC,IAAI,OAAO,GAAG,CAAC,GAAG,IAAI,CAAC,CAAC;QAExB,OAAO,OAAO,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;YACxB,IAAI,QAAQ,GAAG,QAAQ,CAAC;YACxB,IAAI,OAAO,GAAI,CAAC,CAAC,CAAC;YAElB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC1C,MAAM,IAAI,GAAG,OAAO,CAAC
,CAAC,CAAC,GAAG,GAAG,GAAG,OAAO,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBAC/C,MAAM,IAAI,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;gBACnC,IAAI,IAAI,KAAK,SAAS,IAAI,IAAI,GAAG,QAAQ,EAAE,CAAC;oBACxC,QAAQ,GAAG,IAAI,CAAC;oBAChB,OAAO,GAAI,CAAC,CAAC;gBACjB,CAAC;YACL,CAAC;YAED,IAAI,OAAO,KAAK,CAAC,CAAC;gBAAE,MAAM;YAE1B,MAAM,MAAM,GAAG,OAAO,CAAC,OAAO,CAAE,GAAG,OAAO,CAAC,OAAO,GAAG,CAAC,CAAE,CAAC;YACzD,OAAO,GAAG;gBACN,GAAG,OAAO,CAAC,KAAK,CAAC,CAAC,EAAE,OAAO,CAAC;gBAC5B,MAAM;gBACN,GAAG,OAAO,CAAC,KAAK,CAAC,OAAO,GAAG,CAAC,CAAC;aAChC,CAAC;QACN,CAAC;QAED,OAAO,OAAO,CAAC;IACnB,CAAC;IAED,aAAa,CAAC,GAAa,EAAE,MAAc,EAAE,OAAgB,OAAO;QAChE,IAAI,GAAG,CAAC,MAAM,IAAI,MAAM;YAAE,OAAO,GAAG,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QACtD,MAAM,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,CAAC,CAAC;QAC9B,MAAM,GAAG,GAAK,IAAI,KAAK,CAAS,MAAM,GAAG,GAAG,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QACjE,OAAO,IAAI,KAAK,OAAO,CAAC,CAAC,CAAC,CAAC,GAAG,GAAG,EAAE,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,GAAG,EAAE,GAAG,GAAG,CAAC,CAAC;IAClE,CAAC;IAED,IAAI,SAAS,KAAa,OAAO,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,CAAC,CAAC;CACtD"}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* autograd.ts – Lightweight tape-based automatic differentiation engine.
|
|
3
|
+
*/
|
|
4
|
+
/**
 * Tensor handle pairing a GPU buffer with its shape and an optional
 * gradient buffer.
 */
export declare class Tensor {
|
|
5
|
+
/** Backing buffer passed to the constructor; set to null by destroy(). */
data: GPUBuffer | null;
|
|
6
|
+
/** Dimensions as passed to the constructor. */
shape: number[];
|
|
7
|
+
/** Product of all entries of `shape` (1 for an empty shape). */
numel: number;
|
|
8
|
+
/** Whether allocateGradients() should create a gradient buffer for this tensor. */
requiresGrad: boolean;
|
|
9
|
+
/** Gradient buffer; null until allocated and again after destroy(). */
grad: GPUBuffer | null;
|
|
10
|
+
/** Initialised to null by the constructor. */
_gradFn: number | null;
|
|
11
|
+
/** `requiresGrad` defaults to false. */
constructor(data: GPUBuffer | null, shape: number[], requiresGrad?: boolean);
|
|
12
|
+
/** `numel * 4`, i.e. four bytes per element. */
get byteSize(): number;
|
|
13
|
+
/** Overwrites `grad` with zeros through `device.queue.writeBuffer`; no-op when `grad` is null. */
zeroGrad(device: GPUDevice): void;
|
|
14
|
+
/** Destroys `data` and `grad` if present and sets both to null. */
destroy(): void;
|
|
15
|
+
}
|
|
16
|
+
/** Turns tape recording on for subsequent recordOperation() calls. */
export declare function enableGrad(): void;
|
|
17
|
+
/** Turns tape recording off until enableGrad() is called; this is a global switch, not a scoped block. */
export declare function noGrad(): void;
|
|
18
|
+
/** Discards every recorded operation. */
export declare function clearTape(): void;
|
|
19
|
+
/**
 * Appends a backward callback to the tape.
 * @returns The callback's index on the tape, or -1 when recording is disabled.
 */
export declare function recordOperation(backwardFn: () => void | Promise<void>): number;
|
|
20
|
+
/** Awaits the recorded callbacks one at a time in reverse recording order, then clears the tape. */
export declare function backward(): Promise<void>;
|
|
21
|
+
/** Creates and zero-fills `tensor.grad` (size `tensor.byteSize`) when the tensor has no gradient buffer yet. */
export declare function ensureGradBuffer(device: GPUDevice, tensor: Tensor): void;
|
|
22
|
+
/** Calls ensureGradBuffer() for every tensor whose `requiresGrad` is true. */
export declare function allocateGradients(device: GPUDevice, tensors: Tensor[]): void;
|
|
23
|
+
/** Overwrites each existing gradient buffer with zeros; tensors without one are skipped. */
export declare function zeroGradients(device: GPUDevice, tensors: Tensor[]): void;
|
|
24
|
+
/** Creates a 4-byte buffer (STORAGE | COPY_DST usage) holding the single float 1.0. */
export declare function onesLikeScalar(device: GPUDevice): GPUBuffer;
|
|
25
|
+
/**
 * Cross-entropy of one logit vector against one target index, computed as
 * log-sum-exp(logits) - logits[targetId] with the maximum logit subtracted
 * before exponentiation.
 */
export declare function crossEntropyLoss(logits: Float32Array, targetId: number): number;
|
|
26
|
+
/**
 * Gradient of crossEntropyLoss with respect to the logits: softmax(logits)
 * with 1 subtracted at `targetId`. Returns a new array of `logits.length`.
 */
export declare function crossEntropyGrad(logits: Float32Array, targetId: number): Float32Array;
|
|
27
|
+
//# sourceMappingURL=autograd.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"autograd.d.ts","sourceRoot":"","sources":["../../src/training/autograd.ts"],"names":[],"mappings":"AAAA;;GAEG;AAYH,qBAAa,MAAM;IACf,IAAI,EAAE,SAAS,GAAG,IAAI,CAAC;IACvB,KAAK,EAAE,MAAM,EAAE,CAAC;IAChB,KAAK,EAAE,MAAM,CAAC;IACd,YAAY,EAAE,OAAO,CAAC;IACtB,IAAI,EAAE,SAAS,GAAG,IAAI,CAAC;IACvB,OAAO,EAAE,MAAM,GAAG,IAAI,CAAC;gBAEX,IAAI,EAAE,SAAS,GAAG,IAAI,EAAE,KAAK,EAAE,MAAM,EAAE,EAAE,YAAY,UAAQ;IASzE,IAAI,QAAQ,IAAI,MAAM,CAA2B;IAEjD,QAAQ,CAAC,MAAM,EAAE,SAAS,GAAG,IAAI;IAMjC,OAAO,IAAI,IAAI;CAMlB;AAED,wBAAgB,UAAU,IAAI,IAAI,CAA2B;AAC7D,wBAAgB,MAAM,IAAI,IAAI,CAA+B;AAC7D,wBAAgB,SAAS,IAAI,IAAI,CAAkB;AAEnD,wBAAgB,eAAe,CAAC,UAAU,EAAE,MAAM,IAAI,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,MAAM,CAI9E;AAED,wBAAsB,QAAQ,IAAI,OAAO,CAAC,IAAI,CAAC,CAK9C;AAED,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,SAAS,EAAE,MAAM,EAAE,MAAM,GAAG,IAAI,CAWxE;AAED,wBAAgB,iBAAiB,CAAC,MAAM,EAAE,SAAS,EAAE,OAAO,EAAE,MAAM,EAAE,GAAG,IAAI,CAI5E;AAED,wBAAgB,aAAa,CAAC,MAAM,EAAE,SAAS,EAAE,OAAO,EAAE,MAAM,EAAE,GAAG,IAAI,CAMxE;AAED,wBAAgB,cAAc,CAAC,MAAM,EAAE,SAAS,GAAG,SAAS,CAW3D;AAED,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,YAAY,EAAE,QAAQ,EAAE,MAAM,GAAG,MAAM,CAW/E;AAED,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,YAAY,EAAE,QAAQ,EAAE,MAAM,GAAG,YAAY,CAiBrF"}
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* autograd.ts – Lightweight tape-based automatic differentiation engine.
|
|
3
|
+
*/
|
|
4
|
+
/* eslint-disable @typescript-eslint/no-explicit-any */
|
|
5
|
+
// Whether recordOperation() currently appends to the tape.
let _gradEnabled = true;
// Recorded operations in forward order; each entry is { backward: fn }.
let _tape = [];
// Alias for the global object so WebGPU globals such as GPUBufferUsage can be
// looked up with optional chaining where they may be undefined.
const _gpu = globalThis;
|
|
8
|
+
/**
 * Tensor handle pairing a GPU buffer with its shape and an optional
 * gradient buffer.
 */
export class Tensor {
    data;
    shape;
    numel;
    requiresGrad;
    grad;
    _gradFn;
    /**
     * @param data         Backing GPU buffer, or null.
     * @param shape        Dimensions; their product becomes `numel`.
     * @param requiresGrad Whether a gradient buffer should be allocated later.
     */
    constructor(data, shape, requiresGrad = false) {
        this.data = data;
        this.shape = shape;
        // Element count = product of all dimensions (1 for an empty shape).
        this.numel = shape.reduce((product, dim) => product * dim, 1);
        this.requiresGrad = requiresGrad;
        this.grad = null;
        this._gradFn = null;
    }
    /** Size in bytes at four bytes per element. */
    get byteSize() {
        return this.numel * 4;
    }
    /** Overwrites the gradient buffer with zeros; does nothing when there is none. */
    zeroGrad(device) {
        if (!this.grad) {
            return;
        }
        const zeros = new Float32Array(this.numel);
        device.queue.writeBuffer(this.grad, 0, zeros);
    }
    /** Destroys both buffers (when present) and drops the references. */
    destroy() {
        if (this.data !== null && this.data !== undefined) {
            this.data.destroy();
        }
        if (this.grad !== null && this.grad !== undefined) {
            this.grad.destroy();
        }
        this.data = null;
        this.grad = null;
    }
}
|
|
36
|
+
/** Turns tape recording on for subsequent recordOperation() calls. */
export function enableGrad() {
    _gradEnabled = true;
}
|
|
37
|
+
/** Turns tape recording off; it stays off until enableGrad() is called. */
export function noGrad() {
    _gradEnabled = false;
}
|
|
38
|
+
/** Discards every recorded operation by installing a fresh, empty tape. */
export function clearTape() {
    _tape = [];
}
|
|
39
|
+
/**
 * Appends a backward callback to the tape.
 *
 * @param backwardFn Callback invoked by backward(); may be async.
 * @returns The entry's index on the tape, or -1 when recording is disabled.
 */
export function recordOperation(backwardFn) {
    if (_gradEnabled) {
        // The new entry lands at the current length of the tape.
        const index = _tape.length;
        _tape.push({ backward: backwardFn });
        return index;
    }
    return -1;
}
|
|
45
|
+
/**
 * Runs the recorded backward callbacks in reverse recording order, awaiting
 * each one before starting the next, then clears the tape.
 *
 * The tape is cleared even when a callback throws or rejects, so a failed
 * pass cannot leave stale entries to be replayed by the next call. The error
 * itself still propagates to the caller.
 */
export async function backward() {
    // Iterate a fixed reference: a callback that calls clearTape() rebinds
    // `_tape`, which must not disturb the pass that is already running.
    const tape = _tape;
    try {
        for (let i = tape.length - 1; i >= 0; i--) {
            await tape[i].backward();
        }
    }
    finally {
        clearTape();
    }
}
|
|
51
|
+
/**
 * Gives `tensor` a zero-filled gradient buffer if it does not have one.
 *
 * @param device GPU device used to create and fill the buffer.
 * @param tensor Tensor whose `grad` is populated; left untouched if already set.
 */
export function ensureGradBuffer(device, tensor) {
    if (tensor.grad) {
        return;
    }
    // Use the numeric flag values when the GPUBufferUsage global is missing.
    const flags = _gpu.GPUBufferUsage;
    const usage = (flags?.STORAGE ?? 0x80) |
        (flags?.COPY_DST ?? 0x08) |
        (flags?.COPY_SRC ?? 0x04);
    const descriptor = { size: tensor.byteSize, usage };
    tensor.grad = device.createBuffer(descriptor);
    const zeros = new Float32Array(tensor.numel);
    device.queue.writeBuffer(tensor.grad, 0, zeros);
}
|
|
63
|
+
/**
 * Ensures a gradient buffer exists for every tensor flagged `requiresGrad`.
 *
 * @param device  GPU device used for allocation.
 * @param tensors Tensors to inspect.
 */
export function allocateGradients(device, tensors) {
    for (const tensor of tensors) {
        if (!tensor.requiresGrad) {
            continue;
        }
        ensureGradBuffer(device, tensor);
    }
}
|
|
69
|
+
/**
 * Overwrites each existing gradient buffer with zeros.
 *
 * @param device  GPU device whose queue performs the writes.
 * @param tensors Tensors to reset; those without a `grad` buffer are skipped.
 */
export function zeroGradients(device, tensors) {
    for (const tensor of tensors) {
        if (!tensor.grad) {
            continue;
        }
        const zeros = new Float32Array(tensor.numel);
        device.queue.writeBuffer(tensor.grad, 0, zeros);
    }
}
|
|
76
|
+
/**
 * Creates a 4-byte GPU buffer holding the single float 1.0.
 *
 * @param device GPU device used to create the buffer.
 * @returns The unmapped buffer, created with STORAGE | COPY_DST usage.
 */
export function onesLikeScalar(device) {
    // Use the numeric flag values when the GPUBufferUsage global is missing.
    const flags = _gpu.GPUBufferUsage;
    const usage = (flags?.STORAGE ?? 0x80) |
        (flags?.COPY_DST ?? 0x08);
    const descriptor = { size: 4, usage, mappedAtCreation: true };
    const scalar = device.createBuffer(descriptor);
    // The buffer starts out mapped, so the value is written directly.
    const view = new Float32Array(scalar.getMappedRange());
    view.set([1.0]);
    scalar.unmap();
    return scalar;
}
|
|
88
|
+
/**
 * Cross-entropy of one logit vector against one target index.
 *
 * Computed as log-sum-exp(logits) - logits[targetId]; the largest logit is
 * subtracted before exponentiating so the exponentials do not overflow.
 *
 * @param logits   Unnormalised scores.
 * @param targetId Index of the correct class.
 * @returns The loss value.
 */
export function crossEntropyLoss(logits, targetId) {
    let peak = -Infinity;
    for (const value of logits) {
        if (value > peak) {
            peak = value;
        }
    }
    let total = 0;
    for (const value of logits) {
        total += Math.exp(value - peak);
    }
    return (Math.log(total) + peak) - logits[targetId];
}
|
|
101
|
+
/**
 * Gradient of crossEntropyLoss with respect to the logits:
 * softmax(logits) with 1 subtracted at `targetId`.
 *
 * @param logits   Unnormalised scores.
 * @param targetId Index of the correct class.
 * @returns A new Float32Array of `logits.length` entries.
 */
export function crossEntropyGrad(logits, targetId) {
    const size = logits.length;
    let peak = -Infinity;
    for (const value of logits) {
        if (value > peak) {
            peak = value;
        }
    }
    // Shifted exponentials are stored as float32 and the normaliser is summed
    // from those stored (rounded) values.
    const shifted = new Float32Array(size);
    let normaliser = 0;
    for (let k = 0; k < size; k += 1) {
        shifted[k] = Math.exp(logits[k] - peak);
        normaliser += shifted[k];
    }
    const grad = Float32Array.from(shifted, (e) => e / normaliser);
    // An out-of-range targetId reads undefined and the write is ignored.
    grad[targetId] = (grad[targetId] ?? 0) - 1.0;
    return grad;
}
|
|
120
|
+
//# sourceMappingURL=autograd.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"autograd.js","sourceRoot":"","sources":["../../src/training/autograd.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,uDAAuD;AACvD,MAAM,IAAI,GAAG,UAAiB,CAAC;AAM/B,IAAI,KAAK,GAAgB,EAAE,CAAC;AAC5B,IAAI,YAAY,GAAG,IAAI,CAAC;AAExB,MAAM,OAAO,MAAM;IACf,IAAI,CAAmB;IACvB,KAAK,CAAW;IAChB,KAAK,CAAS;IACd,YAAY,CAAU;IACtB,IAAI,CAAmB;IACvB,OAAO,CAAgB;IAEvB,YAAY,IAAsB,EAAE,KAAe,EAAE,YAAY,GAAG,KAAK;QACrE,IAAI,CAAC,IAAI,GAAW,IAAI,CAAC;QACzB,IAAI,CAAC,KAAK,GAAU,KAAK,CAAC;QAC1B,IAAI,CAAC,KAAK,GAAU,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC;QACrD,IAAI,CAAC,YAAY,GAAG,YAAY,CAAC;QACjC,IAAI,CAAC,IAAI,GAAW,IAAI,CAAC;QACzB,IAAI,CAAC,OAAO,GAAQ,IAAI,CAAC;IAC7B,CAAC;IAED,IAAI,QAAQ,KAAa,OAAO,IAAI,CAAC,KAAK,GAAG,CAAC,CAAC,CAAC,CAAC;IAEjD,QAAQ,CAAC,MAAiB;QACtB,IAAI,IAAI,CAAC,IAAI,EAAE,CAAC;YACZ,MAAM,CAAC,KAAK,CAAC,WAAW,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,EAAE,IAAI,YAAY,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;QACzE,CAAC;IACL,CAAC;IAED,OAAO;QACH,IAAI,CAAC,IAAI,EAAE,OAAO,EAAE,CAAC;QACrB,IAAI,CAAC,IAAI,EAAE,OAAO,EAAE,CAAC;QACrB,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC;QACjB,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC;IACrB,CAAC;CACJ;AAED,MAAM,UAAU,UAAU,KAAY,YAAY,GAAG,IAAI,CAAC,CAAE,CAAC;AAC7D,MAAM,UAAU,MAAM,KAAgB,YAAY,GAAG,KAAK,CAAC,CAAC,CAAC;AAC7D,MAAM,UAAU,SAAS,KAAa,KAAK,GAAG,EAAE,CAAC,CAAC,CAAC;AAEnD,MAAM,UAAU,eAAe,CAAC,UAAsC;IAClE,IAAI,CAAC,YAAY;QAAE,OAAO,CAAC,CAAC,CAAC;IAC7B,KAAK,CAAC,IAAI,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;IACrC,OAAO,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;AAC5B,CAAC;AAED,MAAM,CAAC,KAAK,UAAU,QAAQ;IAC1B,KAAK,IAAI,CAAC,GAAG,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;QACzC,MAAM,KAAK,CAAC,CAAC,CAAE,CAAC,QAAQ,EAAE,CAAC;IAC/B,CAAC;IACD,SAAS,EAAE,CAAC;AAChB,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,MAAiB,EAAE,MAAc;IAC9D,IAAI,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC;QACf,MAAM,aAAa,GAAW,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,IAAI,IAAI,CAAC;YACtC,CAAC,IAAI,CAAC,cAAc,EAAE,QAAQ,IAAI,IAAI,CAAC;YACvC,CAAC,IAAI,CAAC,cAAc,EAAE,QAAQ,IAAI,IAAI,CAAC,CAAC;QACtE,MAAM,CAAC,IAAI,GAAG
,MAAM,CAAC,YAAY,CAAC;YAC9B,IAAI,EAAI,MAAM,CAAC,QAAQ;YACvB,KAAK,EAAG,aAAa;SACxB,CAAC,CAAC;QACH,MAAM,CAAC,KAAK,CAAC,WAAW,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC,EAAE,IAAI,YAAY,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC;IAC7E,CAAC;AACL,CAAC;AAED,MAAM,UAAU,iBAAiB,CAAC,MAAiB,EAAE,OAAiB;IAClE,KAAK,MAAM,CAAC,IAAI,OAAO,EAAE,CAAC;QACtB,IAAI,CAAC,CAAC,YAAY;YAAE,gBAAgB,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;IACpD,CAAC;AACL,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,MAAiB,EAAE,OAAiB;IAC9D,KAAK,MAAM,CAAC,IAAI,OAAO,EAAE,CAAC;QACtB,IAAI,CAAC,CAAC,IAAI,EAAE,CAAC;YACT,MAAM,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,EAAE,IAAI,YAAY,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;QACnE,CAAC;IACL,CAAC;AACL,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,MAAiB;IAC5C,MAAM,KAAK,GAAW,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,IAAI,IAAI,CAAC;QACtC,CAAC,IAAI,CAAC,cAAc,EAAE,QAAQ,IAAI,IAAI,CAAC,CAAC;IAC9D,MAAM,GAAG,GAAG,MAAM,CAAC,YAAY,CAAC;QAC5B,IAAI,EAAI,CAAC;QACT,KAAK,EAAG,KAAK;QACb,gBAAgB,EAAE,IAAI;KACzB,CAAC,CAAC;IACH,IAAI,YAAY,CAAC,GAAG,CAAC,cAAc,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;IAClD,GAAG,CAAC,KAAK,EAAE,CAAC;IACZ,OAAO,GAAG,CAAC;AACf,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,MAAoB,EAAE,QAAgB;IACnE,IAAI,QAAQ,GAAG,CAAC,QAAQ,CAAC;IACzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,IAAI,MAAM,CAAC,CAAC,CAAE,GAAG,QAAQ;YAAE,QAAQ,GAAG,MAAM,CAAC,CAAC,CAAE,CAAC;IACrD,CAAC;IACD,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,MAAM,IAAI,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAE,GAAG,QAAQ,CAAC,CAAC;IAC9C,CAAC;IACD,MAAM,SAAS,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,GAAG,QAAQ,CAAC;IAC9C,OAAO,SAAS,GAAG,MAAM,CAAC,QAAQ,CAAE,CAAC;AACzC,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,MAAoB,EAAE,QAAgB;IACnE,IAAI,QAAQ,GAAG,CAAC,QAAQ,CAAC;IACzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,IAAI,MAAM,CAAC,CAAC,CAAE,GAAG,QAAQ;YAAE,QAAQ,GAAG,MAAM,CAAC,CAAC,CAAE,CAAC;IACrD,CAAC;IACD,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,MAAM,WAAW,GAAG,IAAI,YAAY,CAAC,MAAM,CAAC,MAAM,CAAC,CA
AC;IACpD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,WAAW,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAE,GAAG,QAAQ,CAAC,CAAC;QACjD,MAAM,IAAI,WAAW,CAAC,CAAC,CAAE,CAAC;IAC9B,CAAC;IACD,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC;IAC9C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,KAAK,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAE,GAAG,MAAM,CAAC;IACxC,CAAC;IACD,KAAK,CAAC,QAAQ,CAAC,GAAG,CAAC,KAAK,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC,GAAG,GAAG,CAAC;IAC/C,OAAO,KAAK,CAAC;AACjB,CAAC"}
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* trainer.ts – MambaTrainer class
|
|
3
|
+
*/
|
|
4
|
+
import { HybridMambaModel, MambaModel } from '../model/mamba_model.js';
|
|
5
|
+
import { BPETokenizer } from '../tokenizer/bpe.js';
|
|
6
|
+
/**
 * Options accepted by MambaTrainer.train(). Every field is optional; the
 * defaults noted below are the ones train() applies when a field is omitted.
 */
export interface TrainOptions {
|
|
7
|
+
/** Default 1e-4. */
learningRate?: number;
|
|
8
|
+
/** Default 5. */
epochs?: number;
|
|
9
|
+
/** Default 1. */
batchSize?: number;
|
|
10
|
+
/** Default 512. */
seqLen?: number;
|
|
11
|
+
/** Default 1.0; presumably the gradient-clipping threshold — confirm in trainer.ts. */
maxGradNorm?: number;
|
|
12
|
+
/** Default 0.01. */
weightDecay?: number;
|
|
13
|
+
/** Default 0.9. */
beta1?: number;
|
|
14
|
+
/** Default 0.999. */
beta2?: number;
|
|
15
|
+
/** Default 1e-8. */
eps?: number;
|
|
16
|
+
/** Default false; when true, train() calls `model.setWSLAMode(true)` before training. */
wsla?: boolean;
|
|
17
|
+
/** Default null; judging by the name and signature, called after each epoch with the epoch's loss — confirm in trainer.ts. */
onEpochEnd?: ((epoch: number, loss: number) => void) | null;
|
|
18
|
+
}
|
|
19
|
+
/** Trainer that owns the compute pipelines and optimizer state for one model. */
export declare class MambaTrainer {
|
|
20
|
+
/** Model passed to the constructor. */
model: HybridMambaModel;
|
|
21
|
+
/** Tokenizer used for string input; train() throws on string input when this is null. */
tokenizer: BPETokenizer | null;
|
|
22
|
+
/** Taken from `model.device` at construction. */
device: GPUDevice;
|
|
23
|
+
/** Per-parameter `{ m, v }` buffers; null until _initMoments() runs. */
private _moments;
|
|
24
|
+
/** Initialised to 0 by the constructor. */
private _step;
|
|
25
|
+
/** Pipeline for the 'adamw_update' entry point of WEIGHT_UPDATE_WGSL. */
private _adamwPipeline;
|
|
26
|
+
/** Pipeline for the 'grad_norm_reduce' entry point of GRAD_CLIP_WGSL. */
private _clipReducePipeline;
|
|
27
|
+
/** Pipeline for the 'grad_clip_scale' entry point of GRAD_CLIP_WGSL. */
private _clipScalePipeline;
|
|
28
|
+
/** `tokenizer` defaults to null; the three compute pipelines are created here. */
constructor(model: HybridMambaModel | MambaModel, tokenizer?: BPETokenizer | null);
|
|
29
|
+
/** Allocates one `m` and one `v` buffer (`numel * 4` bytes each) per model parameter; no-op once allocated. */
private _initMoments;
|
|
30
|
+
/** Trains on text or token ids using `opts` (see TrainOptions for defaults). */
train(input: string | number[], opts?: TrainOptions): Promise<number[]>;
|
|
31
|
+
private _trainStep;
|
|
32
|
+
private _adamwStep;
|
|
33
|
+
private _clipGradients;
|
|
34
|
+
evaluate(input: string | number[]): Promise<number>;
|
|
35
|
+
}
|
|
36
|
+
//# sourceMappingURL=trainer.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"trainer.d.ts","sourceRoot":"","sources":["../../src/training/trainer.ts"],"names":[],"mappings":"AAAA;;GAEG;AAcH,OAAO,EAAE,gBAAgB,EAAE,UAAU,EAAE,MAAM,yBAAyB,CAAC;AACvE,OAAO,EAAE,YAAY,EAAE,MAAM,qBAAqB,CAAC;AAGnD,MAAM,WAAW,YAAY;IAC3B,YAAY,CAAC,EAAE,MAAM,CAAC;IACtB,MAAM,CAAC,EAAE,MAAM,CAAC;IAChB,SAAS,CAAC,EAAE,MAAM,CAAC;IACnB,MAAM,CAAC,EAAE,MAAM,CAAC;IAChB,WAAW,CAAC,EAAE,MAAM,CAAC;IACrB,WAAW,CAAC,EAAE,MAAM,CAAC;IACrB,KAAK,CAAC,EAAE,MAAM,CAAC;IACf,KAAK,CAAC,EAAE,MAAM,CAAC;IACf,GAAG,CAAC,EAAE,MAAM,CAAC;IACb,IAAI,CAAC,EAAE,OAAO,CAAC;IACf,UAAU,CAAC,EAAE,CAAC,CAAC,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,KAAK,IAAI,CAAC,GAAG,IAAI,CAAC;CAC7D;AAiBD,qBAAa,YAAY;IACrB,KAAK,EAAE,gBAAgB,CAAC;IACxB,SAAS,EAAE,YAAY,GAAG,IAAI,CAAC;IAC/B,MAAM,EAAE,SAAS,CAAC;IAClB,OAAO,CAAC,QAAQ,CAAuB;IACvC,OAAO,CAAC,KAAK,CAAS;IACtB,OAAO,CAAC,cAAc,CAAqB;IAC3C,OAAO,CAAC,mBAAmB,CAAqB;IAChD,OAAO,CAAC,kBAAkB,CAAqB;gBAEnC,KAAK,EAAE,gBAAgB,GAAG,UAAU,EAAE,SAAS,GAAE,YAAY,GAAG,IAAW;IAavF,OAAO,CAAC,YAAY;IAQd,KAAK,CAAC,KAAK,EAAE,MAAM,GAAG,MAAM,EAAE,EAAE,IAAI,GAAE,YAAiB,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC;YAkEnE,UAAU;YAkDV,UAAU;YAgCV,cAAc;IAuBtB,QAAQ,CAAC,KAAK,EAAE,MAAM,GAAG,MAAM,EAAE,GAAG,OAAO,CAAC,MAAM,CAAC;CA4B5D"}
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* trainer.ts – MambaTrainer class
|
|
3
|
+
*/
|
|
4
|
+
import { createUniformBuffer, createStorageBuffer, createEmptyStorageBuffer, createComputePipeline, createBindGroup, dispatchKernel, cdiv, } from '../utils/gpu_utils.js';
|
|
5
|
+
import { crossEntropyLoss, crossEntropyGrad } from './autograd.js';
|
|
6
|
+
import { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from '../kernels/weight_update.js';
|
|
7
|
+
export class MambaTrainer {
|
|
8
|
+
model;
|
|
9
|
+
tokenizer;
|
|
10
|
+
device;
|
|
11
|
+
_moments;
|
|
12
|
+
_step;
|
|
13
|
+
_adamwPipeline;
|
|
14
|
+
_clipReducePipeline;
|
|
15
|
+
_clipScalePipeline;
|
|
16
|
+
constructor(model, tokenizer = null) {
|
|
17
|
+
this.model = model;
|
|
18
|
+
this.tokenizer = tokenizer;
|
|
19
|
+
this.device = model.device;
|
|
20
|
+
this._moments = null;
|
|
21
|
+
this._step = 0;
|
|
22
|
+
this._adamwPipeline = createComputePipeline(this.device, WEIGHT_UPDATE_WGSL, 'adamw_update');
|
|
23
|
+
this._clipReducePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_norm_reduce');
|
|
24
|
+
this._clipScalePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_clip_scale');
|
|
25
|
+
}
|
|
26
|
+
_initMoments() {
|
|
27
|
+
if (this._moments)
|
|
28
|
+
return;
|
|
29
|
+
this._moments = this.model.parameters().map(p => ({
|
|
30
|
+
m: createEmptyStorageBuffer(this.device, p.numel * 4, false),
|
|
31
|
+
v: createEmptyStorageBuffer(this.device, p.numel * 4, false),
|
|
32
|
+
}));
|
|
33
|
+
}
|
|
34
|
+
async train(input, opts = {}) {
|
|
35
|
+
const { learningRate = 1e-4, epochs = 5, batchSize = 1, seqLen = 512, maxGradNorm = 1.0, weightDecay = 0.01, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, wsla = false, onEpochEnd = null, } = opts;
|
|
36
|
+
if (wsla)
|
|
37
|
+
this.model.setWSLAMode(true);
|
|
38
|
+
let tokenIds;
|
|
39
|
+
if (typeof input === 'string') {
|
|
40
|
+
if (!this.tokenizer) {
|
|
41
|
+
throw new Error('MambaTrainer requires a tokenizer when input is a string. ' +
|
|
42
|
+
'Pass a BPETokenizer instance as the second constructor argument.');
|
|
43
|
+
}
|
|
44
|
+
tokenIds = this.tokenizer.encode(input);
|
|
45
|
+
}
|
|
46
|
+
else {
|
|
47
|
+
tokenIds = Array.from(input);
|
|
48
|
+
}
|
|
49
|
+
if (tokenIds.length < 2) {
|
|
50
|
+
throw new Error('Input must contain at least 2 tokens to form a training pair.');
|
|
51
|
+
}
|
|
52
|
+
const chunks = buildChunks(tokenIds, seqLen);
|
|
53
|
+
if (chunks.length === 0) {
|
|
54
|
+
throw new Error('Input is too short to form any training chunk.');
|
|
55
|
+
}
|
|
56
|
+
this._initMoments();
|
|
57
|
+
const epochLosses = [];
|
|
58
|
+
for (let epoch = 0; epoch < epochs; epoch++) {
|
|
59
|
+
let epochLoss = 0;
|
|
60
|
+
let numSteps = 0;
|
|
61
|
+
for (const { inputs, targets } of chunks) {
|
|
62
|
+
const loss = await this._trainStep(inputs, targets, batchSize, { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps, wsla });
|
|
63
|
+
epochLoss += loss;
|
|
64
|
+
numSteps++;
|
|
65
|
+
}
|
|
66
|
+
const avgLoss = epochLoss / numSteps;
|
|
67
|
+
epochLosses.push(avgLoss);
|
|
68
|
+
if (onEpochEnd)
|
|
69
|
+
onEpochEnd(epoch + 1, avgLoss);
|
|
70
|
+
}
|
|
71
|
+
if (wsla)
|
|
72
|
+
this.model.setWSLAMode(false);
|
|
73
|
+
return epochLosses;
|
|
74
|
+
}
|
|
75
|
+
async _trainStep(inputs, targets, batch, hyperparams) {
|
|
76
|
+
const { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps } = hyperparams;
|
|
77
|
+
this._step++;
|
|
78
|
+
const seqLen = inputs.length;
|
|
79
|
+
const vocabSize = this.model.config.vocabSize;
|
|
80
|
+
const { logits, gpuLogits } = await this.model.forward(new Uint32Array(inputs), batch, seqLen);
|
|
81
|
+
let totalLoss = 0;
|
|
82
|
+
const dLogits = new Float32Array(batch * seqLen * vocabSize);
|
|
83
|
+
for (let i = 0; i < seqLen; i++) {
|
|
84
|
+
const offset = i * vocabSize;
|
|
85
|
+
const logitSlice = logits.slice(offset, offset + vocabSize);
|
|
86
|
+
const target = targets[i];
|
|
87
|
+
totalLoss += crossEntropyLoss(logitSlice, target);
|
|
88
|
+
const grad = crossEntropyGrad(logitSlice, target);
|
|
89
|
+
for (let v = 0; v < vocabSize; v++) {
|
|
90
|
+
dLogits[offset + v] = grad[v] / seqLen;
|
|
91
|
+
}
|
|
92
|
+
}
|
|
93
|
+
const loss = totalLoss / seqLen;
|
|
94
|
+
const dLogitsBuf = createStorageBuffer(this.device, dLogits, false);
|
|
95
|
+
await this._clipGradients(dLogitsBuf, dLogits.length, maxGradNorm);
|
|
96
|
+
const params = this.model.parameters();
|
|
97
|
+
const beta1_t = Math.pow(beta1, this._step);
|
|
98
|
+
const beta2_t = Math.pow(beta2, this._step);
|
|
99
|
+
await this._adamwStep(params, [dLogitsBuf], { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t });
|
|
100
|
+
dLogitsBuf.destroy();
|
|
101
|
+
gpuLogits.destroy();
|
|
102
|
+
return loss;
|
|
103
|
+
}
|
|
104
|
+
async _adamwStep(params, gradBufs, hp) {
|
|
105
|
+
const { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t } = hp;
|
|
106
|
+
for (let i = 0; i < params.length; i++) {
|
|
107
|
+
const p = params[i];
|
|
108
|
+
const gradBuf = gradBufs[Math.min(i, gradBufs.length - 1)];
|
|
109
|
+
if (!gradBuf || gradBuf.size < p.numel * 4)
|
|
110
|
+
continue;
|
|
111
|
+
const paramsBuf = createUniformBuffer(this.device, packAdamParams(p.numel, learningRate, beta1, beta2, eps, weightDecay, beta1_t, beta2_t));
|
|
112
|
+
const bg = createBindGroup(this.device, this._adamwPipeline, [
|
|
113
|
+
paramsBuf,
|
|
114
|
+
p.buf,
|
|
115
|
+
gradBuf,
|
|
116
|
+
this._moments[i].m,
|
|
117
|
+
this._moments[i].v,
|
|
118
|
+
]);
|
|
119
|
+
dispatchKernel(this.device, this._adamwPipeline, bg, [cdiv(p.numel, 256), 1, 1]);
|
|
120
|
+
paramsBuf.destroy();
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
async _clipGradients(gradBuf, numel, maxNorm) {
|
|
124
|
+
const normSqBuf = createEmptyStorageBuffer(this.device, 4, true);
|
|
125
|
+
this.device.queue.writeBuffer(normSqBuf, 0, new Float32Array([0.0]));
|
|
126
|
+
const clipParams = new ArrayBuffer(8);
|
|
127
|
+
new Uint32Array(clipParams, 0, 1).set([numel]);
|
|
128
|
+
new Float32Array(clipParams, 4, 1).set([maxNorm * maxNorm]);
|
|
129
|
+
const pBuf = createUniformBuffer(this.device, clipParams);
|
|
130
|
+
const bg1 = createBindGroup(this.device, this._clipReducePipeline, [pBuf, gradBuf, normSqBuf]);
|
|
131
|
+
dispatchKernel(this.device, this._clipReducePipeline, bg1, [cdiv(numel, 256), 1, 1]);
|
|
132
|
+
const bg2 = createBindGroup(this.device, this._clipScalePipeline, [pBuf, gradBuf, normSqBuf]);
|
|
133
|
+
dispatchKernel(this.device, this._clipScalePipeline, bg2, [cdiv(numel, 256), 1, 1]);
|
|
134
|
+
pBuf.destroy();
|
|
135
|
+
normSqBuf.destroy();
|
|
136
|
+
}
|
|
137
|
+
async evaluate(input) {
|
|
138
|
+
let tokenIds;
|
|
139
|
+
if (typeof input === 'string') {
|
|
140
|
+
if (!this.tokenizer)
|
|
141
|
+
throw new Error('Tokenizer required for string input.');
|
|
142
|
+
tokenIds = this.tokenizer.encode(input);
|
|
143
|
+
}
|
|
144
|
+
else {
|
|
145
|
+
tokenIds = Array.from(input);
|
|
146
|
+
}
|
|
147
|
+
const seqLen = tokenIds.length;
|
|
148
|
+
const vocabSize = this.model.config.vocabSize;
|
|
149
|
+
const { logits } = await this.model.forward(new Uint32Array(tokenIds.slice(0, -1)), 1, seqLen - 1);
|
|
150
|
+
let totalLoss = 0;
|
|
151
|
+
for (let i = 0; i < seqLen - 1; i++) {
|
|
152
|
+
const offset = i * vocabSize;
|
|
153
|
+
totalLoss += crossEntropyLoss(logits.slice(offset, offset + vocabSize), tokenIds[i + 1]);
|
|
154
|
+
}
|
|
155
|
+
const avgLoss = totalLoss / (seqLen - 1);
|
|
156
|
+
return Math.exp(avgLoss);
|
|
157
|
+
}
|
|
158
|
+
}
|
|
159
|
+
/**
 * Splits a token-id sequence into next-token training pairs.
 *
 * Each chunk holds `inputs` and `targets` of equal length where
 * `targets[i]` is the token that follows `inputs[i]` in `ids`. Windows start
 * every `seqLen` tokens; full windows have `seqLen` pairs and the final
 * window holds whatever pairs remain (at least one).
 *
 * @param ids     Token ids.
 * @param seqLen  Maximum number of pairs per chunk; must be a positive integer.
 * @returns Array of `{ inputs, targets }`; empty when `ids` has fewer than 2 tokens.
 * @throws RangeError if `seqLen` is not a positive integer (a zero or negative
 *         stride would never terminate).
 */
function buildChunks(ids, seqLen) {
    if (!Number.isInteger(seqLen) || seqLen < 1) {
        throw new RangeError(`seqLen must be a positive integer, got ${seqLen}`);
    }
    const chunks = [];
    let start = 0;
    // Full windows: need seqLen inputs plus one look-ahead token for the last target.
    for (; start + seqLen < ids.length; start += seqLen) {
        chunks.push({
            inputs: ids.slice(start, start + seqLen),
            targets: ids.slice(start + 1, start + seqLen + 1),
        });
    }
    // Tail window: everything from `start` on, as long as it holds at least one
    // pair. Measuring from `start` (rather than `length % seqLen`) keeps the
    // last window when the length is an exact multiple of seqLen.
    if (ids.length - start > 1) {
        chunks.push({
            inputs: ids.slice(start, -1),
            targets: ids.slice(start + 1),
        });
    }
    return chunks;
}
|
|
177
|
+
/**
 * Serialises the AdamW hyper-parameters into a 32-byte buffer suitable for
 * upload as a uniform.
 *
 * Layout (4 bytes per slot): u32 element count, then seven f32 values in the
 * order lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t.
 *
 * @returns A fresh 32-byte ArrayBuffer.
 */
function packAdamParams(numElements, lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t) {
    const scalars = [lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t];
    const packed = new ArrayBuffer(32);
    // Slot 0 holds the unsigned element count.
    const countSlot = new Uint32Array(packed, 0, 1);
    countSlot[0] = numElements;
    // Slots 1..7 hold the float hyper-parameters.
    const floatSlots = new Float32Array(packed, Float32Array.BYTES_PER_ELEMENT, scalars.length);
    scalars.forEach((value, slot) => {
        floatSlots[slot] = value;
    });
    return packed;
}
|
|
183
|
+
//# sourceMappingURL=trainer.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"trainer.js","sourceRoot":"","sources":["../../src/training/trainer.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,OAAO,EACH,mBAAmB,EACnB,mBAAmB,EACnB,wBAAwB,EACxB,qBAAqB,EACrB,eAAe,EACf,cAAc,EACd,IAAI,GACP,MAAM,uBAAuB,CAAC;AAE/B,OAAO,EAAE,gBAAgB,EAAE,gBAAgB,EAAE,MAAM,eAAe,CAAC;AACnE,OAAO,EAAE,kBAAkB,EAAE,cAAc,EAAE,MAAM,6BAA6B,CAAC;AAkCjF,MAAM,OAAO,YAAY;IACrB,KAAK,CAAmB;IACxB,SAAS,CAAsB;IAC/B,MAAM,CAAY;IACV,QAAQ,CAAuB;IAC/B,KAAK,CAAS;IACd,cAAc,CAAqB;IACnC,mBAAmB,CAAqB;IACxC,kBAAkB,CAAqB;IAE/C,YAAY,KAAoC,EAAE,YAAiC,IAAI;QACnF,IAAI,CAAC,KAAK,GAAO,KAAK,CAAC;QACvB,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAC3B,IAAI,CAAC,MAAM,GAAM,KAAK,CAAC,MAAM,CAAC;QAE9B,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC;QACrB,IAAI,CAAC,KAAK,GAAG,CAAC,CAAC;QAEf,IAAI,CAAC,cAAc,GAAK,qBAAqB,CAAC,IAAI,CAAC,MAAM,EAAE,kBAAkB,EAAE,cAAc,CAAC,CAAC;QAC/F,IAAI,CAAC,mBAAmB,GAAG,qBAAqB,CAAC,IAAI,CAAC,MAAM,EAAE,cAAc,EAAE,kBAAkB,CAAC,CAAC;QAClG,IAAI,CAAC,kBAAkB,GAAI,qBAAqB,CAAC,IAAI,CAAC,MAAM,EAAE,cAAc,EAAE,iBAAiB,CAAC,CAAC;IACrG,CAAC;IAEO,YAAY;QAChB,IAAI,IAAI,CAAC,QAAQ;YAAE,OAAO;QAC1B,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,KAAK,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;YAC9C,CAAC,EAAE,wBAAwB,CAAC,IAAI,CAAC,MAAM,EAAE,CAAC,CAAC,KAAK,GAAG,CAAC,EAAE,KAAK,CAAC;YAC5D,CAAC,EAAE,wBAAwB,CAAC,IAAI,CAAC,MAAM,EAAE,CAAC,CAAC,KAAK,GAAG,CAAC,EAAE,KAAK,CAAC;SAC/D,CAAC,CAAC,CAAC;IACR,CAAC;IAED,KAAK,CAAC,KAAK,CAAC,KAAwB,EAAE,OAAqB,EAAE;QACzD,MAAM,EACF,YAAY,GAAG,IAAI,EACnB,MAAM,GAAS,CAAC,EAChB,SAAS,GAAM,CAAC,EAChB,MAAM,GAAS,GAAG,EAClB,WAAW,GAAI,GAAG,EAClB,WAAW,GAAI,IAAI,EACnB,KAAK,GAAU,GAAG,EAClB,KAAK,GAAU,KAAK,EACpB,GAAG,GAAY,IAAI,EACnB,IAAI,GAAW,KAAK,EACpB,UAAU,GAAK,IAAI,GACtB,GAAG,IAAI,CAAC;QAET,IAAI,IAAI;YAAE,IAAI,CAAC,KAAK,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC;QAEvC,IAAI,QAAkB,CAAC;QACvB,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE,CAAC;YAC5B,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,CAAC;gBAClB,MAAM,IAAI,KAAK,CACX,4DAA4D;oBAC5D,kEAAkE,CACrE,CAAC;YACN,CAAC;YACD,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;QAC5C,CAAC;aAAM,CAAC;YACJ,QAAQ,GAAG,KAAK,CA
AC,IAAI,CAAC,KAAK,CAAC,CAAC;QACjC,CAAC;QAED,IAAI,QAAQ,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;YACtB,MAAM,IAAI,KAAK,CAAC,+DAA+D,CAAC,CAAC;QACrF,CAAC;QAED,MAAM,MAAM,GAAG,WAAW,CAAC,QAAQ,EAAE,MAAM,CAAC,CAAC;QAC7C,IAAI,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE,CAAC;YACtB,MAAM,IAAI,KAAK,CAAC,gDAAgD,CAAC,CAAC;QACtE,CAAC;QAED,IAAI,CAAC,YAAY,EAAE,CAAC;QAEpB,MAAM,WAAW,GAAa,EAAE,CAAC;QAEjC,KAAK,IAAI,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC;YAC1C,IAAI,SAAS,GAAG,CAAC,CAAC;YAClB,IAAI,QAAQ,GAAI,CAAC,CAAC;YAElB,KAAK,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,IAAI,MAAM,EAAE,CAAC;gBACvC,MAAM,IAAI,GAAG,MAAM,IAAI,CAAC,UAAU,CAC9B,MAAM,EAAE,OAAO,EAAE,SAAS,EAC1B,EAAE,YAAY,EAAE,WAAW,EAAE,WAAW,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,IAAI,EAAE,CACtE,CAAC;gBACF,SAAS,IAAI,IAAI,CAAC;gBAClB,QAAQ,EAAE,CAAC;YACf,CAAC;YAED,MAAM,OAAO,GAAG,SAAS,GAAG,QAAQ,CAAC;YACrC,WAAW,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAE1B,IAAI,UAAU;gBAAE,UAAU,CAAC,KAAK,GAAG,CAAC,EAAE,OAAO,CAAC,CAAC;QACnD,CAAC;QAED,IAAI,IAAI;YAAE,IAAI,CAAC,KAAK,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC;QACxC,OAAO,WAAW,CAAC;IACvB,CAAC;IAEO,KAAK,CAAC,UAAU,CACpB,MAAgB,EAChB,OAAiB,EACjB,KAAa,EACb,WAAyI;QAEzI,MAAM,EAAE,YAAY,EAAE,WAAW,EAAE,WAAW,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,WAAW,CAAC;QAElF,IAAI,CAAC,KAAK,EAAE,CAAC;QACb,MAAM,MAAM,GAAM,MAAM,CAAC,MAAM,CAAC;QAChC,MAAM,SAAS,GAAG,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,SAAS,CAAC;QAE9C,MAAM,EAAE,MAAM,EAAE,SAAS,EAAE,GAAG,MAAM,IAAI,CAAC,KAAK,CAAC,OAAO,CAClD,IAAI,WAAW,CAAC,MAAM,CAAC,EAAE,KAAK,EAAE,MAAM,CACzC,CAAC;QAEF,IAAI,SAAS,GAAG,CAAC,CAAC;QAClB,MAAM,OAAO,GAAG,IAAI,YAAY,CAAC,KAAK,GAAG,MAAM,GAAG,SAAS,CAAC,CAAC;QAE7D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YAC9B,MAAM,MAAM,GAAG,CAAC,GAAG,SAAS,CAAC;YAC7B,MAAM,UAAU,GAAG,MAAM,CAAC,KAAK,CAAC,MAAM,EAAE,MAAM,GAAG,SAAS,CAAC,CAAC;YAC5D,MAAM,MAAM,GAAG,OAAO,CAAC,CAAC,CAAE,CAAC;YAC3B,SAAS,IAAI,gBAAgB,CAAC,UAAU,EAAE,MAAM,CAAC,CAAC;YAClD,MAAM,IAAI,GAAI,gBAAgB,CAAC,UAAU,EAAE,MAAM,CAAC,CAAC;YACnD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;gBACjC,OAAO,CAAC,MA
AM,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC,CAAE,GAAG,MAAM,CAAC;YAC5C,CAAC;QACL,CAAC;QACD,MAAM,IAAI,GAAG,SAAS,GAAG,MAAM,CAAC;QAEhC,MAAM,UAAU,GAAG,mBAAmB,CAAC,IAAI,CAAC,MAAM,EAAE,OAAO,EAAE,KAAK,CAAC,CAAC;QAEpE,MAAM,IAAI,CAAC,cAAc,CAAC,UAAU,EAAE,OAAO,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnE,MAAM,MAAM,GAAI,IAAI,CAAC,KAAK,CAAC,UAAU,EAAE,CAAC;QACxC,MAAM,OAAO,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;QAC5C,MAAM,OAAO,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;QAE5C,MAAM,IAAI,CAAC,UAAU,CACjB,MAAM,EAAE,CAAC,UAAU,CAAC,EACpB,EAAE,YAAY,EAAE,WAAW,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,OAAO,EAAE,OAAO,EAAE,CACrE,CAAC;QAEF,UAAU,CAAC,OAAO,EAAE,CAAC;QACrB,SAAS,CAAC,OAAO,EAAE,CAAC;QAEpB,OAAO,IAAI,CAAC;IAChB,CAAC;IAEO,KAAK,CAAC,UAAU,CACpB,MAAoB,EACpB,QAAqB,EACrB,EAAmB;QAEnB,MAAM,EAAE,YAAY,EAAE,WAAW,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,OAAO,EAAE,OAAO,EAAE,GAAG,EAAE,CAAC;QAE9E,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACrC,MAAM,CAAC,GAAS,MAAM,CAAC,CAAC,CAAE,CAAC;YAC3B,MAAM,OAAO,GAAG,QAAQ,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,QAAQ,CAAC,MAAM,GAAG,CAAC,CAAC,CAAE,CAAC;YAE5D,IAAI,CAAC,OAAO,IAAI,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,KAAK,GAAG,CAAC;gBAAE,SAAS;YAErD,MAAM,SAAS,GAAG,mBAAmB,CAAC,IAAI,CAAC,MAAM,EAAE,cAAc,CAC7D,CAAC,CAAC,KAAK,EAAE,YAAY,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,WAAW,EAAE,OAAO,EAAE,OAAO,CAC1E,CAAC,CAAC;YAEH,MAAM,EAAE,GAAG,eAAe,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,cAAc,EAAE;gBACzD,SAAS;gBACT,CAAC,CAAC,GAAG;gBACL,OAAO;gBACP,IAAI,CAAC,QAAS,CAAC,CAAC,CAAE,CAAC,CAAC;gBACpB,IAAI,CAAC,QAAS,CAAC,CAAC,CAAE,CAAC,CAAC;aACvB,CAAC,CAAC;YAEH,cAAc,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,cAAc,EAAE,EAAE,EAC/C,CAAC,IAAI,CAAC,CAAC,CAAC,KAAK,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YAEhC,SAAS,CAAC,OAAO,EAAE,CAAC;QACxB,CAAC;IACL,CAAC;IAEO,KAAK,CAAC,cAAc,CAAC,OAAkB,EAAE,KAAa,EAAE,OAAe;QAC3E,MAAM,SAAS,GAAG,wBAAwB,CAAC,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC;QACjE,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,WAAW,CAAC,SAAS,EAAE,CAAC,EAAE,IAAI,YAAY,CAAC,CAAC,GAAG
,CAAC,CAAC,CAAC,CAAC;QAErE,MAAM,UAAU,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,WAAW,CAAC,UAAU,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;QAC/C,IAAI,YAAY,CAAC,UAAU,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,OAAO,GAAG,OAAO,CAAC,CAAC,CAAC;QAC5D,MAAM,IAAI,GAAG,mBAAmB,CAAC,IAAI,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAE1D,MAAM,GAAG,GAAG,eAAe,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,mBAAmB,EAC7D,CAAC,IAAI,EAAE,OAAO,EAAE,SAAS,CAAC,CAAC,CAAC;QAChC,cAAc,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,mBAAmB,EAAE,GAAG,EACrD,CAAC,IAAI,CAAC,KAAK,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE9B,MAAM,GAAG,GAAG,eAAe,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,kBAAkB,EAC5D,CAAC,IAAI,EAAE,OAAO,EAAE,SAAS,CAAC,CAAC,CAAC;QAChC,cAAc,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,kBAAkB,EAAE,GAAG,EACpD,CAAC,IAAI,CAAC,KAAK,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE9B,IAAI,CAAC,OAAO,EAAE,CAAC;QACf,SAAS,CAAC,OAAO,EAAE,CAAC;IACxB,CAAC;IAED,KAAK,CAAC,QAAQ,CAAC,KAAwB;QACnC,IAAI,QAAkB,CAAC;QACvB,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE,CAAC;YAC5B,IAAI,CAAC,IAAI,CAAC,SAAS;gBAAE,MAAM,IAAI,KAAK,CAAC,sCAAsC,CAAC,CAAC;YAC7E,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;QAC5C,CAAC;aAAM,CAAC;YACJ,QAAQ,GAAG,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QACjC,CAAC;QAED,MAAM,MAAM,GAAM,QAAQ,CAAC,MAAM,CAAC;QAClC,MAAM,SAAS,GAAG,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,SAAS,CAAC;QAE9C,MAAM,EAAE,MAAM,EAAE,GAAG,MAAM,IAAI,CAAC,KAAK,CAAC,OAAO,CACvC,IAAI,WAAW,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,GAAG,CAAC,CACxD,CAAC;QAEF,IAAI,SAAS,GAAG,CAAC,CAAC;QAClB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YAClC,MAAM,MAAM,GAAG,CAAC,GAAG,SAAS,CAAC;YAC7B,SAAS,IAAI,gBAAgB,CACzB,MAAM,CAAC,KAAK,CAAC,MAAM,EAAE,MAAM,GAAG,SAAS,CAAC,EACxC,QAAQ,CAAC,CAAC,GAAG,CAAC,CAAE,CACnB,CAAC;QACN,CAAC;QAED,MAAM,OAAO,GAAG,SAAS,GAAG,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACzC,OAAO,IAAI,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC;IAC7B,CAAC;CACJ;AAED,SAAS,WAAW,CAAC,GAAa,EAAE,MAAc;IAC9C,MAAM,MAAM,GAAiD,EAAE,CAAC;IAChE,KAAK,IA
AI,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,MAAM,GAAG,GAAG,CAAC,MAAM,EAAE,KAAK,IAAI,MAAM,EAAE,CAAC;QAC/D,MAAM,CAAC,IAAI,CAAC;YACR,MAAM,EAAG,GAAG,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,GAAG,MAAM,CAAC;YACzC,OAAO,EAAE,GAAG,CAAC,KAAK,CAAC,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,MAAM,GAAG,CAAC,CAAC;SACpD,CAAC,CAAC;IACP,CAAC;IACD,MAAM,GAAG,GAAG,GAAG,CAAC,MAAM,GAAG,MAAM,CAAC;IAChC,IAAI,GAAG,GAAG,CAAC,EAAE,CAAC;QACV,MAAM,KAAK,GAAG,GAAG,CAAC,MAAM,GAAG,GAAG,CAAC;QAC/B,MAAM,CAAC,IAAI,CAAC;YACR,MAAM,EAAG,GAAG,CAAC,KAAK,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC;YAC7B,OAAO,EAAE,GAAG,CAAC,KAAK,CAAC,KAAK,GAAG,CAAC,CAAC;SAChC,CAAC,CAAC;IACP,CAAC;IACD,OAAO,MAAM,CAAC;AAClB,CAAC;AAED,SAAS,cAAc,CACnB,WAAmB,EAAE,EAAU,EAAE,KAAa,EAAE,KAAa,EAC7D,GAAW,EAAE,WAAmB,EAAE,OAAe,EAAE,OAAe;IAElE,MAAM,GAAG,GAAG,IAAI,WAAW,CAAC,EAAE,CAAC,CAAC;IAChC,IAAI,WAAW,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC;IAC9C,IAAI,YAAY,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,EAAE,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,WAAW,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;IACxF,OAAO,GAAG,CAAC;AACf,CAAC"}
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* gpu_utils.ts – WebGPU device management and buffer helpers.
|
|
3
|
+
*/
|
|
4
|
+
/** Options accepted by `initWebGPU`. */
export interface InitWebGPUOptions {
|
|
5
|
+
/** Optional adapter preference; one of the two literal values below. */
powerPreference?: 'high-performance' | 'low-power';
|
|
6
|
+
}
|
|
7
|
+
/** Value that the promise returned by `initWebGPU` resolves to. */
export interface InitWebGPUResult {
|
|
8
|
+
/** The WebGPU device. */
device: GPUDevice;
|
|
9
|
+
/** The WebGPU adapter. */
adapter: GPUAdapter;
|
|
10
|
+
}
|
|
11
|
+
/**
 * Asynchronously obtains a WebGPU device/adapter pair.
 * @param opts  Optional settings (see `InitWebGPUOptions`).
 * @returns Promise resolving to `{ device, adapter }`.
 */
export declare function initWebGPU(opts?: InitWebGPUOptions): Promise<InitWebGPUResult>;
|
|
12
|
+
/**
 * Creates a GPU buffer from `data` (typed array or plain number array).
 * NOTE(review): presumably `readable` controls whether the buffer can be
 * copied back to the CPU — implementation not visible here, confirm.
 */
export declare function createStorageBuffer(device: GPUDevice, data: Float32Array | Uint32Array | number[], readable?: boolean): GPUBuffer;
|
|
13
|
+
/**
 * Creates a GPU buffer of `byteSize` bytes without initial data.
 * NOTE(review): initial contents and the exact meaning of `readable` are not
 * visible here — confirm against the implementation.
 */
export declare function createEmptyStorageBuffer(device: GPUDevice, byteSize: number, readable?: boolean): GPUBuffer;
|
|
14
|
+
/**
 * Creates a GPU buffer from raw bytes (`ArrayBuffer` or any view on one).
 * NOTE(review): presumably created with uniform usage, per the name — confirm.
 */
export declare function createUniformBuffer(device: GPUDevice, data: ArrayBuffer | ArrayBufferView): GPUBuffer;
|
|
15
|
+
/**
 * Asynchronously reads `byteSize` bytes of `srcBuffer` and resolves to them
 * as a `Float32Array`.
 * NOTE(review): presumably goes through a staging buffer — confirm.
 */
export declare function readBuffer(device: GPUDevice, srcBuffer: GPUBuffer, byteSize: number): Promise<Float32Array>;
|
|
16
|
+
/**
 * Writes `data` into an existing GPU buffer, optionally starting at
 * `byteOffset`. Returns nothing.
 */
export declare function uploadBuffer(device: GPUDevice, buffer: GPUBuffer, data: Float32Array, byteOffset?: number): void;
|
|
17
|
+
/**
 * Builds a compute pipeline from WGSL source text.
 * @param wgslSource  Shader source code.
 * @param entryPoint  Name of the entry-point function inside the shader.
 */
export declare function createComputePipeline(device: GPUDevice, wgslSource: string, entryPoint: string): GPUComputePipeline;
|
|
18
|
+
/**
 * Creates a bind group for `pipeline` from a list of buffers.
 * NOTE(review): presumably buffer `i` is bound at binding index `i` of group
 * `groupIndex` — confirm against the implementation.
 */
export declare function createBindGroup(device: GPUDevice, pipeline: GPUComputePipeline, buffers: GPUBuffer[], groupIndex?: number): GPUBindGroup;
|
|
19
|
+
/**
 * Dispatches `pipeline` with `bindGroup` over a 3-component workgroup count.
 * Synchronous from the caller's point of view (returns `void`, not a promise).
 */
export declare function dispatchKernel(device: GPUDevice, pipeline: GPUComputePipeline, bindGroup: GPUBindGroup, workgroups: [number, number, number]): void;
|
|
20
|
+
/**
 * Integer helper taking two numbers and returning one.
 * NOTE(review): presumably ceiling division `ceil(a / b)`; the trainer uses it
 * as `cdiv(numel, 256)` to size workgroup counts — confirm.
 */
export declare function cdiv(a: number, b: number): number;
|
|
21
|
+
//# sourceMappingURL=gpu_utils.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"gpu_utils.d.ts","sourceRoot":"","sources":["../../src/utils/gpu_utils.ts"],"names":[],"mappings":"AAAA;;GAEG;AAUH,MAAM,WAAW,iBAAiB;IAChC,eAAe,CAAC,EAAE,kBAAkB,GAAG,WAAW,CAAC;CACpD;AAED,MAAM,WAAW,gBAAgB;IAC/B,MAAM,EAAE,SAAS,CAAC;IAClB,OAAO,EAAE,UAAU,CAAC;CACrB;AAED,wBAAsB,UAAU,CAAC,IAAI,GAAE,iBAAsB,GAAG,OAAO,CAAC,gBAAgB,CAAC,CAwCxF;AAED,wBAAgB,mBAAmB,CAAC,MAAM,EAAE,SAAS,EAAE,IAAI,EAAE,YAAY,GAAG,WAAW,GAAG,MAAM,EAAE,EAAE,QAAQ,UAAQ,GAAG,SAAS,CAW/H;AAED,wBAAgB,wBAAwB,CAAC,MAAM,EAAE,SAAS,EAAE,QAAQ,EAAE,MAAM,EAAE,QAAQ,UAAQ,GAAG,SAAS,CAGzG;AAED,wBAAgB,mBAAmB,CAAC,MAAM,EAAE,SAAS,EAAE,IAAI,EAAE,WAAW,GAAG,eAAe,GAAG,SAAS,CAUrG;AAED,wBAAsB,UAAU,CAAC,MAAM,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,MAAM,GAAG,OAAO,CAAC,YAAY,CAAC,CAgBjH;AAED,wBAAgB,YAAY,CAAC,MAAM,EAAE,SAAS,EAAE,MAAM,EAAE,SAAS,EAAE,IAAI,EAAE,YAAY,EAAE,UAAU,SAAI,GAAG,IAAI,CAE3G;AAED,wBAAgB,qBAAqB,CAAC,MAAM,EAAE,SAAS,EAAE,UAAU,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,GAAG,kBAAkB,CAMnH;AAED,wBAAgB,eAAe,CAAC,MAAM,EAAE,SAAS,EAAE,QAAQ,EAAE,kBAAkB,EAAE,OAAO,EAAE,SAAS,EAAE,EAAE,UAAU,SAAI,GAAG,YAAY,CASnI;AAED,wBAAgB,cAAc,CAAC,MAAM,EAAE,SAAS,EAAE,QAAQ,EAAE,kBAAkB,EAAE,SAAS,EAAE,YAAY,EAAE,UAAU,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,CAAC,GAAG,IAAI,CAQnJ;AAED,wBAAgB,IAAI,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAEjD"}
|