mambacode.js 1.0.0 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +198 -76
- package/dist/index.d.ts +18 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +18 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +3 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +87 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +152 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +309 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +66 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +289 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +37 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/package.json +43 -18
- package/src/index.ts +59 -0
- package/src/kernels/{activations.js → activations.ts} +2 -2
- package/src/kernels/{linear_projection.js → linear_projection.ts} +2 -2
- package/src/kernels/{selective_scan.js → selective_scan.ts} +2 -2
- package/src/kernels/{weight_update.js → weight_update.ts} +2 -2
- package/src/model/{mamba_block.js → mamba_block.ts} +139 -175
- package/src/model/{mamba_model.js → mamba_model.ts} +168 -124
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +312 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/index.js +0 -89
- package/src/tokenizer/bpe.js +0 -256
- package/src/training/autograd.js +0 -221
- package/src/training/trainer.js +0 -394
- package/src/utils/gpu_utils.js +0 -217
- package/src/utils/quantization.js +0 -215
- /package/src/kernels/{conv1d.js → conv1d.ts} +0 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* autograd.ts – Lightweight tape-based automatic differentiation engine.
|
|
3
|
+
*/
|
|
4
|
+
/* eslint-disable @typescript-eslint/no-explicit-any */
|
|
5
|
+
// Module-level autograd state shared by every function in this file.

// Global object, used only to look up the optional `GPUBufferUsage` constants.
const _gpu = globalThis;
// When false, recordOperation() ignores new entries.
let _gradEnabled = true;
// Recorded backward closures, in the order the forward pass registered them.
let _tape = [];
|
|
8
|
+
export class Tensor {
|
|
9
|
+
data;
|
|
10
|
+
shape;
|
|
11
|
+
numel;
|
|
12
|
+
requiresGrad;
|
|
13
|
+
grad;
|
|
14
|
+
_gradFn;
|
|
15
|
+
constructor(data, shape, requiresGrad = false) {
|
|
16
|
+
this.data = data;
|
|
17
|
+
this.shape = shape;
|
|
18
|
+
this.numel = shape.reduce((a, b) => a * b, 1);
|
|
19
|
+
this.requiresGrad = requiresGrad;
|
|
20
|
+
this.grad = null;
|
|
21
|
+
this._gradFn = null;
|
|
22
|
+
}
|
|
23
|
+
get byteSize() { return this.numel * 4; }
|
|
24
|
+
zeroGrad(device) {
|
|
25
|
+
if (this.grad) {
|
|
26
|
+
device.queue.writeBuffer(this.grad, 0, new Float32Array(this.numel));
|
|
27
|
+
}
|
|
28
|
+
}
|
|
29
|
+
destroy() {
|
|
30
|
+
this.data?.destroy();
|
|
31
|
+
this.grad?.destroy();
|
|
32
|
+
this.data = null;
|
|
33
|
+
this.grad = null;
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
/** Turns tape recording back on for subsequent recordOperation() calls. */
export function enableGrad() {
    _gradEnabled = true;
}
|
|
37
|
+
/**
 * Turns tape recording off. The flag stays off until enableGrad() is called;
 * this is a plain switch, not a scoped block.
 */
export function noGrad() {
    _gradEnabled = false;
}
|
|
38
|
+
/** Discards every recorded operation by swapping in a fresh, empty tape. */
export function clearTape() {
    _tape = [];
}
|
|
39
|
+
/**
 * Appends a backward closure to the tape.
 *
 * @param backwardFn Callback invoked (and awaited) by backward().
 * @returns The tape index of the new entry, or -1 when recording is disabled.
 */
export function recordOperation(backwardFn) {
    if (_gradEnabled) {
        // The new entry lands at the current end of the tape.
        const slot = _tape.length;
        _tape.push({ backward: backwardFn });
        return slot;
    }
    return -1;
}
|
|
45
|
+
/**
 * Replays the tape in reverse recording order, awaiting each backward
 * closure before moving to the previous one, then clears the tape.
 *
 * The tape is cleared even when a closure throws or rejects, so a failed
 * pass cannot leave stale entries that would be replayed by the next call.
 * The original error is still propagated to the caller.
 */
export async function backward() {
    try {
        for (let i = _tape.length - 1; i >= 0; i--) {
            await _tape[i].backward();
        }
    }
    finally {
        clearTape();
    }
}
|
|
51
|
+
/**
 * Attaches a zero-filled gradient buffer to `tensor` if it has none yet.
 *
 * The buffer is `tensor.byteSize` bytes and is created with
 * STORAGE | COPY_DST | COPY_SRC usage. When the global `GPUBufferUsage`
 * object is unavailable the numeric literals 0x80, 0x08 and 0x04 are used
 * instead.
 *
 * @param device Object exposing `createBuffer` and `queue.writeBuffer`.
 * @param tensor Object exposing `grad`, `byteSize` and `numel`.
 */
export function ensureGradBuffer(device, tensor) {
    if (tensor.grad) {
        return; // already allocated
    }
    const storageBit = _gpu.GPUBufferUsage?.STORAGE ?? 0x80;
    const copyDstBit = _gpu.GPUBufferUsage?.COPY_DST ?? 0x08;
    const copySrcBit = _gpu.GPUBufferUsage?.COPY_SRC ?? 0x04;
    const descriptor = {
        size: tensor.byteSize,
        usage: storageBit | copyDstBit | copySrcBit,
    };
    tensor.grad = device.createBuffer(descriptor);
    const zeros = new Float32Array(tensor.numel);
    device.queue.writeBuffer(tensor.grad, 0, zeros);
}
|
|
63
|
+
/**
 * Ensures a gradient buffer exists for every tensor flagged `requiresGrad`.
 * Tensors without the flag are left untouched.
 *
 * @param device  Device forwarded to ensureGradBuffer().
 * @param tensors Iterable of tensors to inspect.
 */
export function allocateGradients(device, tensors) {
    for (const tensor of tensors) {
        if (!tensor.requiresGrad) {
            continue;
        }
        ensureGradBuffer(device, tensor);
    }
}
|
|
69
|
+
export function zeroGradients(device, tensors) {
|
|
70
|
+
for (const t of tensors) {
|
|
71
|
+
if (t.grad) {
|
|
72
|
+
device.queue.writeBuffer(t.grad, 0, new Float32Array(t.numel));
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
/**
 * Creates a 4-byte buffer holding the single float 1.0.
 *
 * The buffer is created mapped, filled through a Float32Array view and
 * unmapped before being returned. Usage is STORAGE | COPY_DST, falling back
 * to the literals 0x80 and 0x08 when the global `GPUBufferUsage` is missing.
 *
 * @param device Object exposing `createBuffer`.
 * @returns The unmapped buffer; the caller owns it.
 */
export function onesLikeScalar(device) {
    const storageBit = _gpu.GPUBufferUsage?.STORAGE ?? 0x80;
    const copyDstBit = _gpu.GPUBufferUsage?.COPY_DST ?? 0x08;
    const scalar = device.createBuffer({
        size: 4,
        usage: storageBit | copyDstBit,
        mappedAtCreation: true,
    });
    const view = new Float32Array(scalar.getMappedRange());
    view.set([1.0]);
    scalar.unmap();
    return scalar;
}
|
|
88
|
+
/**
 * Cross-entropy of a single position: log-sum-exp of the logits minus the
 * logit of the target class.
 *
 * The largest logit is subtracted before exponentiation so that large
 * values do not overflow `Math.exp`.
 *
 * @param logits   Indexable sequence of numbers with a `length`.
 * @param targetId Index of the correct class. No range check is performed;
 *                 an index with no element makes the result NaN.
 * @returns The loss as a number.
 */
export function crossEntropyLoss(logits, targetId) {
    let peak = -Infinity;
    for (let k = 0; k < logits.length; k++) {
        const candidate = logits[k];
        if (candidate > peak) {
            peak = candidate;
        }
    }
    let expTotal = 0;
    for (let k = 0; k < logits.length; k++) {
        expTotal += Math.exp(logits[k] - peak);
    }
    // log(sum(exp(x))) == log(sum(exp(x - peak))) + peak
    return Math.log(expTotal) + peak - logits[targetId];
}
|
|
101
|
+
export function crossEntropyGrad(logits, targetId) {
|
|
102
|
+
let maxLogit = -Infinity;
|
|
103
|
+
for (let i = 0; i < logits.length; i++) {
|
|
104
|
+
if (logits[i] > maxLogit)
|
|
105
|
+
maxLogit = logits[i];
|
|
106
|
+
}
|
|
107
|
+
let sumExp = 0;
|
|
108
|
+
const exp_shifted = new Float32Array(logits.length);
|
|
109
|
+
for (let i = 0; i < logits.length; i++) {
|
|
110
|
+
exp_shifted[i] = Math.exp(logits[i] - maxLogit);
|
|
111
|
+
sumExp += exp_shifted[i];
|
|
112
|
+
}
|
|
113
|
+
const probs = new Float32Array(logits.length);
|
|
114
|
+
for (let i = 0; i < logits.length; i++) {
|
|
115
|
+
probs[i] = exp_shifted[i] / sumExp;
|
|
116
|
+
}
|
|
117
|
+
probs[targetId] = (probs[targetId] ?? 0) - 1.0;
|
|
118
|
+
return probs;
|
|
119
|
+
}
|
|
120
|
+
//# sourceMappingURL=autograd.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"autograd.js","sourceRoot":"","sources":["../../src/training/autograd.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,uDAAuD;AACvD,MAAM,IAAI,GAAG,UAAiB,CAAC;AAM/B,IAAI,KAAK,GAAgB,EAAE,CAAC;AAC5B,IAAI,YAAY,GAAG,IAAI,CAAC;AAExB,MAAM,OAAO,MAAM;IACf,IAAI,CAAmB;IACvB,KAAK,CAAW;IAChB,KAAK,CAAS;IACd,YAAY,CAAU;IACtB,IAAI,CAAmB;IACvB,OAAO,CAAgB;IAEvB,YAAY,IAAsB,EAAE,KAAe,EAAE,YAAY,GAAG,KAAK;QACrE,IAAI,CAAC,IAAI,GAAW,IAAI,CAAC;QACzB,IAAI,CAAC,KAAK,GAAU,KAAK,CAAC;QAC1B,IAAI,CAAC,KAAK,GAAU,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC;QACrD,IAAI,CAAC,YAAY,GAAG,YAAY,CAAC;QACjC,IAAI,CAAC,IAAI,GAAW,IAAI,CAAC;QACzB,IAAI,CAAC,OAAO,GAAQ,IAAI,CAAC;IAC7B,CAAC;IAED,IAAI,QAAQ,KAAa,OAAO,IAAI,CAAC,KAAK,GAAG,CAAC,CAAC,CAAC,CAAC;IAEjD,QAAQ,CAAC,MAAiB;QACtB,IAAI,IAAI,CAAC,IAAI,EAAE,CAAC;YACZ,MAAM,CAAC,KAAK,CAAC,WAAW,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,EAAE,IAAI,YAAY,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;QACzE,CAAC;IACL,CAAC;IAED,OAAO;QACH,IAAI,CAAC,IAAI,EAAE,OAAO,EAAE,CAAC;QACrB,IAAI,CAAC,IAAI,EAAE,OAAO,EAAE,CAAC;QACrB,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC;QACjB,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC;IACrB,CAAC;CACJ;AAED,MAAM,UAAU,UAAU,KAAY,YAAY,GAAG,IAAI,CAAC,CAAE,CAAC;AAC7D,MAAM,UAAU,MAAM,KAAgB,YAAY,GAAG,KAAK,CAAC,CAAC,CAAC;AAC7D,MAAM,UAAU,SAAS,KAAa,KAAK,GAAG,EAAE,CAAC,CAAC,CAAC;AAEnD,MAAM,UAAU,eAAe,CAAC,UAAsC;IAClE,IAAI,CAAC,YAAY;QAAE,OAAO,CAAC,CAAC,CAAC;IAC7B,KAAK,CAAC,IAAI,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;IACrC,OAAO,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;AAC5B,CAAC;AAED,MAAM,CAAC,KAAK,UAAU,QAAQ;IAC1B,KAAK,IAAI,CAAC,GAAG,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;QACzC,MAAM,KAAK,CAAC,CAAC,CAAE,CAAC,QAAQ,EAAE,CAAC;IAC/B,CAAC;IACD,SAAS,EAAE,CAAC;AAChB,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,MAAiB,EAAE,MAAc;IAC9D,IAAI,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC;QACf,MAAM,aAAa,GAAW,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,IAAI,IAAI,CAAC;YACtC,CAAC,IAAI,CAAC,cAAc,EAAE,QAAQ,IAAI,IAAI,CAAC;YACvC,CAAC,IAAI,CAAC,cAAc,EAAE,QAAQ,IAAI,IAAI,CAAC,CAAC;QACtE,MAAM,CAAC,IAAI,GAAG
,MAAM,CAAC,YAAY,CAAC;YAC9B,IAAI,EAAI,MAAM,CAAC,QAAQ;YACvB,KAAK,EAAG,aAAa;SACxB,CAAC,CAAC;QACH,MAAM,CAAC,KAAK,CAAC,WAAW,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC,EAAE,IAAI,YAAY,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC;IAC7E,CAAC;AACL,CAAC;AAED,MAAM,UAAU,iBAAiB,CAAC,MAAiB,EAAE,OAAiB;IAClE,KAAK,MAAM,CAAC,IAAI,OAAO,EAAE,CAAC;QACtB,IAAI,CAAC,CAAC,YAAY;YAAE,gBAAgB,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;IACpD,CAAC;AACL,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,MAAiB,EAAE,OAAiB;IAC9D,KAAK,MAAM,CAAC,IAAI,OAAO,EAAE,CAAC;QACtB,IAAI,CAAC,CAAC,IAAI,EAAE,CAAC;YACT,MAAM,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,EAAE,IAAI,YAAY,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;QACnE,CAAC;IACL,CAAC;AACL,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,MAAiB;IAC5C,MAAM,KAAK,GAAW,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,IAAI,IAAI,CAAC;QACtC,CAAC,IAAI,CAAC,cAAc,EAAE,QAAQ,IAAI,IAAI,CAAC,CAAC;IAC9D,MAAM,GAAG,GAAG,MAAM,CAAC,YAAY,CAAC;QAC5B,IAAI,EAAI,CAAC;QACT,KAAK,EAAG,KAAK;QACb,gBAAgB,EAAE,IAAI;KACzB,CAAC,CAAC;IACH,IAAI,YAAY,CAAC,GAAG,CAAC,cAAc,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;IAClD,GAAG,CAAC,KAAK,EAAE,CAAC;IACZ,OAAO,GAAG,CAAC;AACf,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,MAAoB,EAAE,QAAgB;IACnE,IAAI,QAAQ,GAAG,CAAC,QAAQ,CAAC;IACzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,IAAI,MAAM,CAAC,CAAC,CAAE,GAAG,QAAQ;YAAE,QAAQ,GAAG,MAAM,CAAC,CAAC,CAAE,CAAC;IACrD,CAAC;IACD,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,MAAM,IAAI,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAE,GAAG,QAAQ,CAAC,CAAC;IAC9C,CAAC;IACD,MAAM,SAAS,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,GAAG,QAAQ,CAAC;IAC9C,OAAO,SAAS,GAAG,MAAM,CAAC,QAAQ,CAAE,CAAC;AACzC,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,MAAoB,EAAE,QAAgB;IACnE,IAAI,QAAQ,GAAG,CAAC,QAAQ,CAAC;IACzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,IAAI,MAAM,CAAC,CAAC,CAAE,GAAG,QAAQ;YAAE,QAAQ,GAAG,MAAM,CAAC,CAAC,CAAE,CAAC;IACrD,CAAC;IACD,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,MAAM,WAAW,GAAG,IAAI,YAAY,CAAC,MAAM,CAAC,MAAM,CAAC,CA
AC;IACpD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,WAAW,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAE,GAAG,QAAQ,CAAC,CAAC;QACjD,MAAM,IAAI,WAAW,CAAC,CAAC,CAAE,CAAC;IAC9B,CAAC;IACD,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC;IAC9C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACrC,KAAK,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAE,GAAG,MAAM,CAAC;IACxC,CAAC;IACD,KAAK,CAAC,QAAQ,CAAC,GAAG,CAAC,KAAK,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC,GAAG,GAAG,CAAC;IAC/C,OAAO,KAAK,CAAC;AACjB,CAAC"}
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* trainer.ts – MambaTrainer class
|
|
3
|
+
*/
|
|
4
|
+
import { MambaModel, MambaModelConfig } from '../model/mamba_model';
|
|
5
|
+
import { BPETokenizer } from '../tokenizer/bpe';
|
|
6
|
+
/**
 * Optional settings accepted by `MambaTrainer.train()`.
 * Every field may be omitted; the defaults quoted below are the ones
 * `train()` applies when destructuring the options object.
 */
export interface TrainOptions {
    /** Learning rate handed to the AdamW update. Default `1e-4`. */
    learningRate?: number;
    /** Number of passes over all chunks. Default `5`. */
    epochs?: number;
    /** Batch size forwarded to the model's forward pass. Default `1`. */
    batchSize?: number;
    /** Maximum number of input tokens per training chunk. Default `512`. */
    seqLen?: number;
    /** Gradient-clipping threshold. Default `1.0`. */
    maxGradNorm?: number;
    /** AdamW weight-decay coefficient. Default `0.01`. */
    weightDecay?: number;
    /** AdamW first-moment decay rate. Default `0.9`. */
    beta1?: number;
    /** AdamW second-moment decay rate. Default `0.999`. */
    beta2?: number;
    /** AdamW epsilon. Default `1e-8`. */
    eps?: number;
    /** When true, the model is switched into WSLA mode for the run. Default `false`. */
    wsla?: boolean;
    /**
     * Called after each epoch with the 1-based epoch number and that epoch's
     * mean loss. Default `null` (no callback).
     */
    onEpochEnd?: ((epoch: number, loss: number) => void) | null;
}
|
|
19
|
+
export type { MambaModelConfig };
|
|
20
|
+
/**
 * Training and evaluation driver for a `MambaModel`.
 * Accepts either raw text (tokenizer required) or token ids.
 */
export declare class MambaTrainer {
    /** Model being trained. */
    model: MambaModel;
    /** Tokenizer used for string input, or `null` when none was supplied. */
    tokenizer: BPETokenizer | null;
    /** GPU device, taken from the model. */
    device: GPUDevice;
    /** Per-parameter optimiser moment buffers, created on first use. */
    private _moments;
    /** Count of training steps taken so far. */
    private _step;
    /** Compute pipeline for the `adamw_update` kernel. */
    private _adamwPipeline;
    /** Compute pipeline for the `grad_norm_reduce` kernel. */
    private _clipReducePipeline;
    /** Compute pipeline for the `grad_clip_scale` kernel. */
    private _clipScalePipeline;
    /**
     * @param model     Model to train.
     * @param tokenizer Optional tokenizer; only needed for string input.
     */
    constructor(model: MambaModel, tokenizer?: BPETokenizer | null);
    private _initMoments;
    /**
     * Trains on `input` and resolves to the mean loss of each epoch.
     *
     * @param input Text to tokenize, or token ids.
     * @param opts  Hyper-parameters and callbacks.
     * @throws Error when a string is passed without a tokenizer, or when the
     *         input holds fewer than 2 tokens.
     */
    train(input: string | number[], opts?: TrainOptions): Promise<number[]>;
    private _trainStep;
    private _adamwStep;
    private _clipGradients;
    /**
     * Resolves to `exp(mean cross-entropy)` of next-token prediction over
     * `input`.
     *
     * @param input Text to tokenize, or token ids.
     * @throws Error when a string is passed without a tokenizer.
     */
    evaluate(input: string | number[]): Promise<number>;
}
|
|
37
|
+
//# sourceMappingURL=trainer.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"trainer.d.ts","sourceRoot":"","sources":["../../src/training/trainer.ts"],"names":[],"mappings":"AAAA;;GAEG;AAcH,OAAO,EAAE,UAAU,EAAE,gBAAgB,EAAE,MAAM,sBAAsB,CAAC;AACpE,OAAO,EAAE,YAAY,EAAE,MAAM,kBAAkB,CAAC;AAGhD,MAAM,WAAW,YAAY;IAC3B,YAAY,CAAC,EAAE,MAAM,CAAC;IACtB,MAAM,CAAC,EAAE,MAAM,CAAC;IAChB,SAAS,CAAC,EAAE,MAAM,CAAC;IACnB,MAAM,CAAC,EAAE,MAAM,CAAC;IAChB,WAAW,CAAC,EAAE,MAAM,CAAC;IACrB,WAAW,CAAC,EAAE,MAAM,CAAC;IACrB,KAAK,CAAC,EAAE,MAAM,CAAC;IACf,KAAK,CAAC,EAAE,MAAM,CAAC;IACf,GAAG,CAAC,EAAE,MAAM,CAAC;IACb,IAAI,CAAC,EAAE,OAAO,CAAC;IACf,UAAU,CAAC,EAAE,CAAC,CAAC,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,KAAK,IAAI,CAAC,GAAG,IAAI,CAAC;CAC7D;AAkBD,YAAY,EAAE,gBAAgB,EAAE,CAAC;AAEjC,qBAAa,YAAY;IACrB,KAAK,EAAE,UAAU,CAAC;IAClB,SAAS,EAAE,YAAY,GAAG,IAAI,CAAC;IAC/B,MAAM,EAAE,SAAS,CAAC;IAClB,OAAO,CAAC,QAAQ,CAAuB;IACvC,OAAO,CAAC,KAAK,CAAS;IACtB,OAAO,CAAC,cAAc,CAAqB;IAC3C,OAAO,CAAC,mBAAmB,CAAqB;IAChD,OAAO,CAAC,kBAAkB,CAAqB;gBAEnC,KAAK,EAAE,UAAU,EAAE,SAAS,GAAE,YAAY,GAAG,IAAW;IAapE,OAAO,CAAC,YAAY;IAQd,KAAK,CAAC,KAAK,EAAE,MAAM,GAAG,MAAM,EAAE,EAAE,IAAI,GAAE,YAAiB,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC;YAkEnE,UAAU;YAkDV,UAAU;YAgCV,cAAc;IAuBtB,QAAQ,CAAC,KAAK,EAAE,MAAM,GAAG,MAAM,EAAE,GAAG,OAAO,CAAC,MAAM,CAAC;CA4B5D"}
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* trainer.ts – MambaTrainer class
|
|
3
|
+
*/
|
|
4
|
+
import { createUniformBuffer, createStorageBuffer, createEmptyStorageBuffer, createComputePipeline, createBindGroup, dispatchKernel, cdiv, } from '../utils/gpu_utils';
|
|
5
|
+
import { crossEntropyLoss, crossEntropyGrad } from './autograd';
|
|
6
|
+
import { WEIGHT_UPDATE_WGSL, GRAD_CLIP_WGSL } from '../kernels/weight_update';
|
|
7
|
+
/**
 * Training and evaluation driver for a Mamba model.
 *
 * Runs the model's forward pass, computes cross-entropy loss and its
 * gradient with respect to the logits on the CPU, clips that gradient on the
 * GPU and dispatches the AdamW update kernel for each parameter.
 */
export class MambaTrainer {
    model;
    tokenizer;
    device;
    _moments;
    _step;
    _adamwPipeline;
    _clipReducePipeline;
    _clipScalePipeline;
    /**
     * @param model     Model to train; its `device` is used for all GPU work.
     * @param tokenizer Optional tokenizer, only needed for string input.
     */
    constructor(model, tokenizer = null) {
        this.model = model;
        this.tokenizer = tokenizer;
        this.device = model.device;
        // Optimiser moments are allocated lazily by _initMoments().
        this._moments = null;
        // Number of optimiser steps taken; drives the beta^t terms.
        this._step = 0;
        this._adamwPipeline = createComputePipeline(this.device, WEIGHT_UPDATE_WGSL, 'adamw_update');
        this._clipReducePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_norm_reduce');
        this._clipScalePipeline = createComputePipeline(this.device, GRAD_CLIP_WGSL, 'grad_clip_scale');
    }
    /**
     * Allocates one `m` and one `v` buffer (4 bytes per element) for every
     * model parameter. Subsequent calls are no-ops.
     */
    _initMoments() {
        if (this._moments)
            return;
        this._moments = this.model.parameters().map(p => ({
            m: createEmptyStorageBuffer(this.device, p.numel * 4, false),
            v: createEmptyStorageBuffer(this.device, p.numel * 4, false),
        }));
    }
    /**
     * Trains on `input` for `epochs` passes and returns the mean loss of each
     * epoch.
     *
     * When `wsla` is set the model is put into WSLA mode for the duration of
     * the call. The mode is switched back off on every exit path, including
     * validation failures and errors thrown by a training step or by
     * `onEpochEnd`.
     *
     * @param input Text (requires a tokenizer) or an iterable of token ids.
     * @param opts  Optional hyper-parameters; see the defaults below.
     * @returns Array with one mean loss per epoch.
     * @throws Error if a string is given without a tokenizer, if fewer than
     *         2 tokens are available, or if no chunk can be formed.
     */
    async train(input, opts = {}) {
        const { learningRate = 1e-4, epochs = 5, batchSize = 1, seqLen = 512, maxGradNorm = 1.0, weightDecay = 0.01, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, wsla = false, onEpochEnd = null, } = opts;
        if (wsla)
            this.model.setWSLAMode(true);
        try {
            let tokenIds;
            if (typeof input === 'string') {
                if (!this.tokenizer) {
                    throw new Error('MambaTrainer requires a tokenizer when input is a string. ' +
                        'Pass a BPETokenizer instance as the second constructor argument.');
                }
                tokenIds = this.tokenizer.encode(input);
            }
            else {
                tokenIds = Array.from(input);
            }
            if (tokenIds.length < 2) {
                throw new Error('Input must contain at least 2 tokens to form a training pair.');
            }
            const chunks = buildChunks(tokenIds, seqLen);
            if (chunks.length === 0) {
                throw new Error('Input is too short to form any training chunk.');
            }
            this._initMoments();
            const epochLosses = [];
            for (let epoch = 0; epoch < epochs; epoch++) {
                let epochLoss = 0;
                let numSteps = 0;
                for (const { inputs, targets } of chunks) {
                    const loss = await this._trainStep(inputs, targets, batchSize, { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps, wsla });
                    epochLoss += loss;
                    numSteps++;
                }
                const avgLoss = epochLoss / numSteps;
                epochLosses.push(avgLoss);
                // The callback receives a 1-based epoch number.
                if (onEpochEnd)
                    onEpochEnd(epoch + 1, avgLoss);
            }
            return epochLosses;
        }
        finally {
            // Never leave the model stuck in WSLA mode after a failure.
            if (wsla)
                this.model.setWSLAMode(false);
        }
    }
    /**
     * Runs one optimisation step on a single chunk and returns its mean
     * per-token loss.
     *
     * The temporary gradient buffer and the logits buffer returned by the
     * forward pass are destroyed on every exit path.
     *
     * @param inputs      Token ids fed to the model.
     * @param targets     Expected next token for each input position.
     * @param batch       Batch size forwarded to `model.forward`.
     * @param hyperparams Optimiser settings for this step.
     */
    async _trainStep(inputs, targets, batch, hyperparams) {
        const { learningRate, maxGradNorm, weightDecay, beta1, beta2, eps } = hyperparams;
        this._step++;
        const seqLen = inputs.length;
        const vocabSize = this.model.config.vocabSize;
        const { logits, gpuLogits } = await this.model.forward(new Uint32Array(inputs), batch, seqLen);
        let dLogitsBuf = null;
        try {
            let totalLoss = 0;
            // NOTE(review): the buffer is sized for `batch` sequences but only
            // the first `seqLen` rows are filled below; the rest stay zero.
            const dLogits = new Float32Array(batch * seqLen * vocabSize);
            for (let i = 0; i < seqLen; i++) {
                const offset = i * vocabSize;
                const logitSlice = logits.slice(offset, offset + vocabSize);
                const target = targets[i];
                totalLoss += crossEntropyLoss(logitSlice, target);
                const grad = crossEntropyGrad(logitSlice, target);
                // Scale by 1/seqLen so the gradient matches the mean loss.
                for (let v = 0; v < vocabSize; v++) {
                    dLogits[offset + v] = grad[v] / seqLen;
                }
            }
            const loss = totalLoss / seqLen;
            dLogitsBuf = createStorageBuffer(this.device, dLogits, false);
            await this._clipGradients(dLogitsBuf, dLogits.length, maxGradNorm);
            const params = this.model.parameters();
            // Bias-correction terms beta^t for the current step.
            const beta1_t = Math.pow(beta1, this._step);
            const beta2_t = Math.pow(beta2, this._step);
            await this._adamwStep(params, [dLogitsBuf], { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t });
            return loss;
        }
        finally {
            dLogitsBuf?.destroy();
            gpuLogits.destroy();
        }
    }
    /**
     * Dispatches the AdamW kernel once per parameter.
     *
     * Parameter `i` uses `gradBufs[i]`, or the last entry of `gradBufs` when
     * there are fewer gradient buffers than parameters. A parameter is
     * skipped when that buffer is missing or smaller than the parameter.
     *
     * @param params   Model parameters (each exposes `numel` and `buf`).
     * @param gradBufs Gradient buffers.
     * @param hp       Optimiser settings including the beta^t terms.
     */
    async _adamwStep(params, gradBufs, hp) {
        const { learningRate, weightDecay, beta1, beta2, eps, beta1_t, beta2_t } = hp;
        for (let i = 0; i < params.length; i++) {
            const p = params[i];
            const gradBuf = gradBufs[Math.min(i, gradBufs.length - 1)];
            if (!gradBuf || gradBuf.size < p.numel * 4)
                continue;
            const paramsBuf = createUniformBuffer(this.device, packAdamParams(p.numel, learningRate, beta1, beta2, eps, weightDecay, beta1_t, beta2_t));
            const bg = createBindGroup(this.device, this._adamwPipeline, [
                paramsBuf,
                p.buf,
                gradBuf,
                this._moments[i].m,
                this._moments[i].v,
            ]);
            // One workgroup per 256 elements.
            dispatchKernel(this.device, this._adamwPipeline, bg, [cdiv(p.numel, 256), 1, 1]);
            paramsBuf.destroy();
        }
    }
    /**
     * Runs the two gradient-clipping kernels (`grad_norm_reduce`, then
     * `grad_clip_scale`) over `gradBuf`.
     *
     * @param gradBuf Gradient buffer, modified by the kernels.
     * @param numel   Number of float elements in `gradBuf`.
     * @param maxNorm Clipping threshold; its square is what the kernels receive.
     */
    async _clipGradients(gradBuf, numel, maxNorm) {
        // Accumulator for the squared norm, reset to zero before the reduction.
        const normSqBuf = createEmptyStorageBuffer(this.device, 4, true);
        this.device.queue.writeBuffer(normSqBuf, 0, new Float32Array([0.0]));
        // Uniform layout: u32 element count followed by f32 maxNorm^2.
        const clipParams = new ArrayBuffer(8);
        new Uint32Array(clipParams, 0, 1).set([numel]);
        new Float32Array(clipParams, 4, 1).set([maxNorm * maxNorm]);
        const pBuf = createUniformBuffer(this.device, clipParams);
        const bg1 = createBindGroup(this.device, this._clipReducePipeline, [pBuf, gradBuf, normSqBuf]);
        dispatchKernel(this.device, this._clipReducePipeline, bg1, [cdiv(numel, 256), 1, 1]);
        const bg2 = createBindGroup(this.device, this._clipScalePipeline, [pBuf, gradBuf, normSqBuf]);
        dispatchKernel(this.device, this._clipScalePipeline, bg2, [cdiv(numel, 256), 1, 1]);
        pBuf.destroy();
        normSqBuf.destroy();
    }
    /**
     * Computes `exp(mean cross-entropy)` of next-token prediction over
     * `input`: every token except the last is fed to the model and each
     * position is scored against the token that follows it.
     *
     * @param input Text (requires a tokenizer) or an iterable of token ids.
     * @returns The exponentiated mean loss.
     * @throws Error if a string is given without a tokenizer, or if fewer
     *         than 2 tokens are available (no prediction pair exists).
     */
    async evaluate(input) {
        let tokenIds;
        if (typeof input === 'string') {
            if (!this.tokenizer)
                throw new Error('Tokenizer required for string input.');
            tokenIds = this.tokenizer.encode(input);
        }
        else {
            tokenIds = Array.from(input);
        }
        if (tokenIds.length < 2) {
            // Without a pair the mean below would divide by zero.
            throw new Error('Input must contain at least 2 tokens to form an evaluation pair.');
        }
        const seqLen = tokenIds.length;
        const vocabSize = this.model.config.vocabSize;
        const { logits, gpuLogits } = await this.model.forward(new Uint32Array(tokenIds.slice(0, -1)), 1, seqLen - 1);
        try {
            let totalLoss = 0;
            for (let i = 0; i < seqLen - 1; i++) {
                const offset = i * vocabSize;
                totalLoss += crossEntropyLoss(logits.slice(offset, offset + vocabSize), tokenIds[i + 1]);
            }
            const avgLoss = totalLoss / (seqLen - 1);
            return Math.exp(avgLoss);
        }
        finally {
            // Release the GPU logits buffer returned by the forward pass.
            gpuLogits?.destroy();
        }
    }
}
|
|
159
|
+
/**
 * Splits a token sequence into next-token training pairs.
 *
 * Each chunk holds `inputs` and `targets` of equal length, where
 * `targets[k]` is the token that follows `inputs[k]` in `ids`. Full chunks
 * contain `seqLen` inputs; whatever is left after the last full chunk forms
 * one shorter chunk as long as it contains at least one pair (two tokens).
 * This includes the case where `ids.length` is an exact multiple of
 * `seqLen`, so the final `seqLen` tokens are no longer dropped.
 *
 * @param ids    Array of token ids.
 * @param seqLen Maximum number of inputs per chunk; must be a positive integer.
 * @returns Array of `{ inputs, targets }` objects (empty for fewer than 2 tokens).
 * @throws RangeError if `seqLen` is not a positive integer (a value of 0 or
 *         below would otherwise never advance the loop).
 */
function buildChunks(ids, seqLen) {
    if (!Number.isInteger(seqLen) || seqLen < 1) {
        throw new RangeError(`seqLen must be a positive integer, got ${seqLen}`);
    }
    const chunks = [];
    let start = 0;
    // Full chunks need one extra token after the window to act as the last target.
    for (; start + seqLen < ids.length; start += seqLen) {
        chunks.push({
            inputs: ids.slice(start, start + seqLen),
            targets: ids.slice(start + 1, start + seqLen + 1),
        });
    }
    // Tail: everything from `start` on, provided it still holds a pair.
    if (ids.length - start >= 2) {
        chunks.push({
            inputs: ids.slice(start, -1),
            targets: ids.slice(start + 1),
        });
    }
    return chunks;
}
|
|
177
|
+
function packAdamParams(numElements, lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t) {
|
|
178
|
+
const buf = new ArrayBuffer(32);
|
|
179
|
+
new Uint32Array(buf, 0, 1).set([numElements]);
|
|
180
|
+
new Float32Array(buf, 4, 7).set([lr, beta1, beta2, eps, weightDecay, beta1_t, beta2_t]);
|
|
181
|
+
return buf;
|
|
182
|
+
}
|
|
183
|
+
//# sourceMappingURL=trainer.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"trainer.js","sourceRoot":"","sources":["../../src/training/trainer.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,OAAO,EACH,mBAAmB,EACnB,mBAAmB,EACnB,wBAAwB,EACxB,qBAAqB,EACrB,eAAe,EACf,cAAc,EACd,IAAI,GACP,MAAM,oBAAoB,CAAC;AAE5B,OAAO,EAAE,gBAAgB,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAChE,OAAO,EAAE,kBAAkB,EAAE,cAAc,EAAE,MAAM,0BAA0B,CAAC;AAqC9E,MAAM,OAAO,YAAY;IACrB,KAAK,CAAa;IAClB,SAAS,CAAsB;IAC/B,MAAM,CAAY;IACV,QAAQ,CAAuB;IAC/B,KAAK,CAAS;IACd,cAAc,CAAqB;IACnC,mBAAmB,CAAqB;IACxC,kBAAkB,CAAqB;IAE/C,YAAY,KAAiB,EAAE,YAAiC,IAAI;QAChE,IAAI,CAAC,KAAK,GAAO,KAAK,CAAC;QACvB,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAC3B,IAAI,CAAC,MAAM,GAAM,KAAK,CAAC,MAAM,CAAC;QAE9B,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC;QACrB,IAAI,CAAC,KAAK,GAAG,CAAC,CAAC;QAEf,IAAI,CAAC,cAAc,GAAK,qBAAqB,CAAC,IAAI,CAAC,MAAM,EAAE,kBAAkB,EAAE,cAAc,CAAC,CAAC;QAC/F,IAAI,CAAC,mBAAmB,GAAG,qBAAqB,CAAC,IAAI,CAAC,MAAM,EAAE,cAAc,EAAE,kBAAkB,CAAC,CAAC;QAClG,IAAI,CAAC,kBAAkB,GAAI,qBAAqB,CAAC,IAAI,CAAC,MAAM,EAAE,cAAc,EAAE,iBAAiB,CAAC,CAAC;IACrG,CAAC;IAEO,YAAY;QAChB,IAAI,IAAI,CAAC,QAAQ;YAAE,OAAO;QAC1B,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,KAAK,CAAC,UAAU,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;YAC9C,CAAC,EAAE,wBAAwB,CAAC,IAAI,CAAC,MAAM,EAAE,CAAC,CAAC,KAAK,GAAG,CAAC,EAAE,KAAK,CAAC;YAC5D,CAAC,EAAE,wBAAwB,CAAC,IAAI,CAAC,MAAM,EAAE,CAAC,CAAC,KAAK,GAAG,CAAC,EAAE,KAAK,CAAC;SAC/D,CAAC,CAAC,CAAC;IACR,CAAC;IAED,KAAK,CAAC,KAAK,CAAC,KAAwB,EAAE,OAAqB,EAAE;QACzD,MAAM,EACF,YAAY,GAAG,IAAI,EACnB,MAAM,GAAS,CAAC,EAChB,SAAS,GAAM,CAAC,EAChB,MAAM,GAAS,GAAG,EAClB,WAAW,GAAI,GAAG,EAClB,WAAW,GAAI,IAAI,EACnB,KAAK,GAAU,GAAG,EAClB,KAAK,GAAU,KAAK,EACpB,GAAG,GAAY,IAAI,EACnB,IAAI,GAAW,KAAK,EACpB,UAAU,GAAK,IAAI,GACtB,GAAG,IAAI,CAAC;QAET,IAAI,IAAI;YAAE,IAAI,CAAC,KAAK,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC;QAEvC,IAAI,QAAkB,CAAC;QACvB,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE,CAAC;YAC5B,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,CAAC;gBAClB,MAAM,IAAI,KAAK,CACX,4DAA4D;oBAC5D,kEAAkE,CACrE,CAAC;YACN,CAAC;YACD,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;QAC5C,CAAC;aAAM,CAAC;YACJ,QAAQ,GAAG,KAAK,CAA
C,IAAI,CAAC,KAAK,CAAC,CAAC;QACjC,CAAC;QAED,IAAI,QAAQ,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC;YACtB,MAAM,IAAI,KAAK,CAAC,+DAA+D,CAAC,CAAC;QACrF,CAAC;QAED,MAAM,MAAM,GAAG,WAAW,CAAC,QAAQ,EAAE,MAAM,CAAC,CAAC;QAC7C,IAAI,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE,CAAC;YACtB,MAAM,IAAI,KAAK,CAAC,gDAAgD,CAAC,CAAC;QACtE,CAAC;QAED,IAAI,CAAC,YAAY,EAAE,CAAC;QAEpB,MAAM,WAAW,GAAa,EAAE,CAAC;QAEjC,KAAK,IAAI,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC;YAC1C,IAAI,SAAS,GAAG,CAAC,CAAC;YAClB,IAAI,QAAQ,GAAI,CAAC,CAAC;YAElB,KAAK,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,IAAI,MAAM,EAAE,CAAC;gBACvC,MAAM,IAAI,GAAG,MAAM,IAAI,CAAC,UAAU,CAC9B,MAAM,EAAE,OAAO,EAAE,SAAS,EAC1B,EAAE,YAAY,EAAE,WAAW,EAAE,WAAW,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,IAAI,EAAE,CACtE,CAAC;gBACF,SAAS,IAAI,IAAI,CAAC;gBAClB,QAAQ,EAAE,CAAC;YACf,CAAC;YAED,MAAM,OAAO,GAAG,SAAS,GAAG,QAAQ,CAAC;YACrC,WAAW,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAE1B,IAAI,UAAU;gBAAE,UAAU,CAAC,KAAK,GAAG,CAAC,EAAE,OAAO,CAAC,CAAC;QACnD,CAAC;QAED,IAAI,IAAI;YAAE,IAAI,CAAC,KAAK,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC;QACxC,OAAO,WAAW,CAAC;IACvB,CAAC;IAEO,KAAK,CAAC,UAAU,CACpB,MAAgB,EAChB,OAAiB,EACjB,KAAa,EACb,WAAyI;QAEzI,MAAM,EAAE,YAAY,EAAE,WAAW,EAAE,WAAW,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,WAAW,CAAC;QAElF,IAAI,CAAC,KAAK,EAAE,CAAC;QACb,MAAM,MAAM,GAAM,MAAM,CAAC,MAAM,CAAC;QAChC,MAAM,SAAS,GAAG,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,SAAS,CAAC;QAE9C,MAAM,EAAE,MAAM,EAAE,SAAS,EAAE,GAAG,MAAM,IAAI,CAAC,KAAK,CAAC,OAAO,CAClD,IAAI,WAAW,CAAC,MAAM,CAAC,EAAE,KAAK,EAAE,MAAM,CACzC,CAAC;QAEF,IAAI,SAAS,GAAG,CAAC,CAAC;QAClB,MAAM,OAAO,GAAG,IAAI,YAAY,CAAC,KAAK,GAAG,MAAM,GAAG,SAAS,CAAC,CAAC;QAE7D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YAC9B,MAAM,MAAM,GAAG,CAAC,GAAG,SAAS,CAAC;YAC7B,MAAM,UAAU,GAAG,MAAM,CAAC,KAAK,CAAC,MAAM,EAAE,MAAM,GAAG,SAAS,CAAC,CAAC;YAC5D,MAAM,MAAM,GAAG,OAAO,CAAC,CAAC,CAAE,CAAC;YAC3B,SAAS,IAAI,gBAAgB,CAAC,UAAU,EAAE,MAAM,CAAC,CAAC;YAClD,MAAM,IAAI,GAAI,gBAAgB,CAAC,UAAU,EAAE,MAAM,CAAC,CAAC;YACnD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE,CAAC;gBACjC,OAAO,CAAC,MAA
M,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC,CAAE,GAAG,MAAM,CAAC;YAC5C,CAAC;QACL,CAAC;QACD,MAAM,IAAI,GAAG,SAAS,GAAG,MAAM,CAAC;QAEhC,MAAM,UAAU,GAAG,mBAAmB,CAAC,IAAI,CAAC,MAAM,EAAE,OAAO,EAAE,KAAK,CAAC,CAAC;QAEpE,MAAM,IAAI,CAAC,cAAc,CAAC,UAAU,EAAE,OAAO,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnE,MAAM,MAAM,GAAI,IAAI,CAAC,KAAK,CAAC,UAAU,EAAE,CAAC;QACxC,MAAM,OAAO,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;QAC5C,MAAM,OAAO,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;QAE5C,MAAM,IAAI,CAAC,UAAU,CACjB,MAAM,EAAE,CAAC,UAAU,CAAC,EACpB,EAAE,YAAY,EAAE,WAAW,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,OAAO,EAAE,OAAO,EAAE,CACrE,CAAC;QAEF,UAAU,CAAC,OAAO,EAAE,CAAC;QACrB,SAAS,CAAC,OAAO,EAAE,CAAC;QAEpB,OAAO,IAAI,CAAC;IAChB,CAAC;IAEO,KAAK,CAAC,UAAU,CACpB,MAAoB,EACpB,QAAqB,EACrB,EAAmB;QAEnB,MAAM,EAAE,YAAY,EAAE,WAAW,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,OAAO,EAAE,OAAO,EAAE,GAAG,EAAE,CAAC;QAE9E,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACrC,MAAM,CAAC,GAAS,MAAM,CAAC,CAAC,CAAE,CAAC;YAC3B,MAAM,OAAO,GAAG,QAAQ,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,QAAQ,CAAC,MAAM,GAAG,CAAC,CAAC,CAAE,CAAC;YAE5D,IAAI,CAAC,OAAO,IAAI,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC,KAAK,GAAG,CAAC;gBAAE,SAAS;YAErD,MAAM,SAAS,GAAG,mBAAmB,CAAC,IAAI,CAAC,MAAM,EAAE,cAAc,CAC7D,CAAC,CAAC,KAAK,EAAE,YAAY,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,WAAW,EAAE,OAAO,EAAE,OAAO,CAC1E,CAAC,CAAC;YAEH,MAAM,EAAE,GAAG,eAAe,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,cAAc,EAAE;gBACzD,SAAS;gBACT,CAAC,CAAC,GAAG;gBACL,OAAO;gBACP,IAAI,CAAC,QAAS,CAAC,CAAC,CAAE,CAAC,CAAC;gBACpB,IAAI,CAAC,QAAS,CAAC,CAAC,CAAE,CAAC,CAAC;aACvB,CAAC,CAAC;YAEH,cAAc,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,cAAc,EAAE,EAAE,EAC/C,CAAC,IAAI,CAAC,CAAC,CAAC,KAAK,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;YAEhC,SAAS,CAAC,OAAO,EAAE,CAAC;QACxB,CAAC;IACL,CAAC;IAEO,KAAK,CAAC,cAAc,CAAC,OAAkB,EAAE,KAAa,EAAE,OAAe;QAC3E,MAAM,SAAS,GAAG,wBAAwB,CAAC,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,IAAI,CAAC,CAAC;QACjE,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,WAAW,CAAC,SAAS,EAAE,CAAC,EAAE,IAAI,YAAY,CAAC,CAAC,GAAG,
CAAC,CAAC,CAAC,CAAC;QAErE,MAAM,UAAU,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,WAAW,CAAC,UAAU,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;QAC/C,IAAI,YAAY,CAAC,UAAU,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,OAAO,GAAG,OAAO,CAAC,CAAC,CAAC;QAC5D,MAAM,IAAI,GAAG,mBAAmB,CAAC,IAAI,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAE1D,MAAM,GAAG,GAAG,eAAe,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,mBAAmB,EAC7D,CAAC,IAAI,EAAE,OAAO,EAAE,SAAS,CAAC,CAAC,CAAC;QAChC,cAAc,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,mBAAmB,EAAE,GAAG,EACrD,CAAC,IAAI,CAAC,KAAK,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE9B,MAAM,GAAG,GAAG,eAAe,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,kBAAkB,EAC5D,CAAC,IAAI,EAAE,OAAO,EAAE,SAAS,CAAC,CAAC,CAAC;QAChC,cAAc,CAAC,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,kBAAkB,EAAE,GAAG,EACpD,CAAC,IAAI,CAAC,KAAK,EAAE,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE9B,IAAI,CAAC,OAAO,EAAE,CAAC;QACf,SAAS,CAAC,OAAO,EAAE,CAAC;IACxB,CAAC;IAED,KAAK,CAAC,QAAQ,CAAC,KAAwB;QACnC,IAAI,QAAkB,CAAC;QACvB,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE,CAAC;YAC5B,IAAI,CAAC,IAAI,CAAC,SAAS;gBAAE,MAAM,IAAI,KAAK,CAAC,sCAAsC,CAAC,CAAC;YAC7E,QAAQ,GAAG,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;QAC5C,CAAC;aAAM,CAAC;YACJ,QAAQ,GAAG,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QACjC,CAAC;QAED,MAAM,MAAM,GAAM,QAAQ,CAAC,MAAM,CAAC;QAClC,MAAM,SAAS,GAAG,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,SAAS,CAAC;QAE9C,MAAM,EAAE,MAAM,EAAE,GAAG,MAAM,IAAI,CAAC,KAAK,CAAC,OAAO,CACvC,IAAI,WAAW,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,GAAG,CAAC,CACxD,CAAC;QAEF,IAAI,SAAS,GAAG,CAAC,CAAC;QAClB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;YAClC,MAAM,MAAM,GAAG,CAAC,GAAG,SAAS,CAAC;YAC7B,SAAS,IAAI,gBAAgB,CACzB,MAAM,CAAC,KAAK,CAAC,MAAM,EAAE,MAAM,GAAG,SAAS,CAAC,EACxC,QAAQ,CAAC,CAAC,GAAG,CAAC,CAAE,CACnB,CAAC;QACN,CAAC;QAED,MAAM,OAAO,GAAG,SAAS,GAAG,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACzC,OAAO,IAAI,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC;IAC7B,CAAC;CACJ;AAED,SAAS,WAAW,CAAC,GAAa,EAAE,MAAc;IAC9C,MAAM,MAAM,GAAiD,EAAE,CAAC;IAChE,KAAK,IAA
I,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,MAAM,GAAG,GAAG,CAAC,MAAM,EAAE,KAAK,IAAI,MAAM,EAAE,CAAC;QAC/D,MAAM,CAAC,IAAI,CAAC;YACR,MAAM,EAAG,GAAG,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,GAAG,MAAM,CAAC;YACzC,OAAO,EAAE,GAAG,CAAC,KAAK,CAAC,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,MAAM,GAAG,CAAC,CAAC;SACpD,CAAC,CAAC;IACP,CAAC;IACD,MAAM,GAAG,GAAG,GAAG,CAAC,MAAM,GAAG,MAAM,CAAC;IAChC,IAAI,GAAG,GAAG,CAAC,EAAE,CAAC;QACV,MAAM,KAAK,GAAG,GAAG,CAAC,MAAM,GAAG,GAAG,CAAC;QAC/B,MAAM,CAAC,IAAI,CAAC;YACR,MAAM,EAAG,GAAG,CAAC,KAAK,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC;YAC7B,OAAO,EAAE,GAAG,CAAC,KAAK,CAAC,KAAK,GAAG,CAAC,CAAC;SAChC,CAAC,CAAC;IACP,CAAC;IACD,OAAO,MAAM,CAAC;AAClB,CAAC;AAED,SAAS,cAAc,CACnB,WAAmB,EAAE,EAAU,EAAE,KAAa,EAAE,KAAa,EAC7D,GAAW,EAAE,WAAmB,EAAE,OAAe,EAAE,OAAe;IAElE,MAAM,GAAG,GAAG,IAAI,WAAW,CAAC,EAAE,CAAC,CAAC;IAChC,IAAI,WAAW,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC;IAC9C,IAAI,YAAY,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,EAAE,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,WAAW,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC,CAAC;IACxF,OAAO,GAAG,CAAC;AACf,CAAC"}
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* gpu_utils.ts – WebGPU device management and buffer helpers.
|
|
3
|
+
*/
|
|
4
|
+
/** Options accepted by `initWebGPU`; `powerPreference` is forwarded to the adapter request. */
export interface InitWebGPUOptions {
|
|
5
|
+
powerPreference?: 'high-performance' | 'low-power';
|
|
6
|
+
}
|
|
7
|
+
/** Value resolved by `initWebGPU`: the acquired device together with the adapter it was requested from. */
export interface InitWebGPUResult {
|
|
8
|
+
device: GPUDevice;
|
|
9
|
+
adapter: GPUAdapter;
|
|
10
|
+
}
|
|
11
|
+
/**
 * Requests a GPU adapter and device through `navigator.gpu`.
 * `opts.powerPreference` defaults to 'high-performance'.
 * Rejects with an Error when `navigator.gpu` is missing or no adapter is returned.
 */
export declare function initWebGPU(opts?: InitWebGPUOptions): Promise<InitWebGPUResult>;
|
|
12
|
+
/**
 * Creates a storage buffer (STORAGE | COPY_DST usage) initialised with `data`.
 * A plain `number[]` is converted to a Float32Array first.
 * When `readable` is true (default false) COPY_SRC is added to the usage flags.
 */
export declare function createStorageBuffer(device: GPUDevice, data: Float32Array | Uint32Array | number[], readable?: boolean): GPUBuffer;
|
|
13
|
+
/**
 * Creates a storage buffer of `byteSize` bytes (STORAGE | COPY_DST usage) without uploading any data.
 * When `readable` is true (default false) COPY_SRC is added to the usage flags.
 */
export declare function createEmptyStorageBuffer(device: GPUDevice, byteSize: number, readable?: boolean): GPUBuffer;
|
|
14
|
+
/** Creates a buffer with UNIFORM | COPY_DST usage and fills it from `data` while it is mapped at creation. */
export declare function createUniformBuffer(device: GPUDevice, data: ArrayBuffer | ArrayBufferView): GPUBuffer;
|
|
15
|
+
/**
 * Copies the first `byteSize` bytes of `srcBuffer` into a temporary MAP_READ staging buffer,
 * maps it, and resolves with a Float32Array holding a copy of those bytes.
 */
export declare function readBuffer(device: GPUDevice, srcBuffer: GPUBuffer, byteSize: number): Promise<Float32Array>;
|
|
16
|
+
/** Writes the contents of `data` into `buffer` at `byteOffset` (default 0) via `device.queue.writeBuffer`. */
export declare function uploadBuffer(device: GPUDevice, buffer: GPUBuffer, data: Float32Array, byteOffset?: number): void;
|
|
17
|
+
/** Compiles `wgslSource` into a shader module and builds a compute pipeline for `entryPoint` with `layout: 'auto'`. */
export declare function createComputePipeline(device: GPUDevice, wgslSource: string, entryPoint: string): GPUComputePipeline;
|
|
18
|
+
/**
 * Builds a bind group whose entry `i` binds `buffers[i]` at binding index `i`.
 * The layout is taken from `pipeline.getBindGroupLayout(groupIndex)`; `groupIndex` defaults to 0.
 */
export declare function createBindGroup(device: GPUDevice, pipeline: GPUComputePipeline, buffers: GPUBuffer[], groupIndex?: number): GPUBindGroup;
|
|
19
|
+
/**
 * Encodes one compute pass (pipeline set, `bindGroup` bound at index 0, `workgroups` dispatched)
 * and submits it to the device queue in its own command buffer.
 */
export declare function dispatchKernel(device: GPUDevice, pipeline: GPUComputePipeline, bindGroup: GPUBindGroup, workgroups: [number, number, number]): void;
|
|
20
|
+
/** Ceiling division: returns `Math.ceil(a / b)`. */
export declare function cdiv(a: number, b: number): number;
|
|
21
|
+
//# sourceMappingURL=gpu_utils.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"gpu_utils.d.ts","sourceRoot":"","sources":["../../src/utils/gpu_utils.ts"],"names":[],"mappings":"AAAA;;GAEG;AAUH,MAAM,WAAW,iBAAiB;IAChC,eAAe,CAAC,EAAE,kBAAkB,GAAG,WAAW,CAAC;CACpD;AAED,MAAM,WAAW,gBAAgB;IAC/B,MAAM,EAAE,SAAS,CAAC;IAClB,OAAO,EAAE,UAAU,CAAC;CACrB;AAED,wBAAsB,UAAU,CAAC,IAAI,GAAE,iBAAsB,GAAG,OAAO,CAAC,gBAAgB,CAAC,CAwCxF;AAED,wBAAgB,mBAAmB,CAAC,MAAM,EAAE,SAAS,EAAE,IAAI,EAAE,YAAY,GAAG,WAAW,GAAG,MAAM,EAAE,EAAE,QAAQ,UAAQ,GAAG,SAAS,CAW/H;AAED,wBAAgB,wBAAwB,CAAC,MAAM,EAAE,SAAS,EAAE,QAAQ,EAAE,MAAM,EAAE,QAAQ,UAAQ,GAAG,SAAS,CAGzG;AAED,wBAAgB,mBAAmB,CAAC,MAAM,EAAE,SAAS,EAAE,IAAI,EAAE,WAAW,GAAG,eAAe,GAAG,SAAS,CAUrG;AAED,wBAAsB,UAAU,CAAC,MAAM,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,MAAM,GAAG,OAAO,CAAC,YAAY,CAAC,CAgBjH;AAED,wBAAgB,YAAY,CAAC,MAAM,EAAE,SAAS,EAAE,MAAM,EAAE,SAAS,EAAE,IAAI,EAAE,YAAY,EAAE,UAAU,SAAI,GAAG,IAAI,CAE3G;AAED,wBAAgB,qBAAqB,CAAC,MAAM,EAAE,SAAS,EAAE,UAAU,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,GAAG,kBAAkB,CAMnH;AAED,wBAAgB,eAAe,CAAC,MAAM,EAAE,SAAS,EAAE,QAAQ,EAAE,kBAAkB,EAAE,OAAO,EAAE,SAAS,EAAE,EAAE,UAAU,SAAI,GAAG,YAAY,CASnI;AAED,wBAAgB,cAAc,CAAC,MAAM,EAAE,SAAS,EAAE,QAAQ,EAAE,kBAAkB,EAAE,SAAS,EAAE,YAAY,EAAE,UAAU,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,CAAC,GAAG,IAAI,CAQnJ;AAED,wBAAgB,IAAI,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAEjD"}
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* gpu_utils.ts – WebGPU device management and buffer helpers.
|
|
3
|
+
*/
|
|
4
|
+
/* eslint-disable @typescript-eslint/no-explicit-any */
|
|
5
|
+
// Buffer-usage bit flags shared by the helpers in this module. Each flag is
// read from the runtime's GPUBufferUsage namespace when that global exists;
// otherwise the hex literal on the same line is used as the fallback value.
const _gpu = globalThis;
const bufferUsage = _gpu.GPUBufferUsage;
const UNIFORM = bufferUsage?.UNIFORM ?? 0x40;
const STORAGE = bufferUsage?.STORAGE ?? 0x80;
const COPY_SRC = bufferUsage?.COPY_SRC ?? 0x04;
const COPY_DST = bufferUsage?.COPY_DST ?? 0x08;
const MAP_READ = bufferUsage?.MAP_READ ?? 0x01;
|
|
11
|
+
/**
 * Acquire a WebGPU adapter and device.
 *
 * The device is requested with raised limits: up to 3 GiB for both
 * `maxBufferSize` and `maxStorageBufferBindingSize`, and up to 256 for
 * `maxComputeInvocationsPerWorkgroup`, each capped at what the adapter reports.
 * A device-loss handler that logs to `console.error` is attached before returning.
 *
 * @param opts Optional settings; `powerPreference` defaults to 'high-performance'.
 * @returns An object holding the `device` and the `adapter` it came from.
 * @throws Error when `navigator.gpu` is unavailable or no adapter is returned.
 */
export async function initWebGPU(opts = {}) {
    const gpu = typeof navigator === 'undefined' ? undefined : navigator.gpu;
    if (!gpu) {
        throw new Error('WebGPU is not available in this environment. ' +
            'Use Chrome 113+, Edge 113+, or Firefox Nightly with WebGPU enabled.');
    }

    const powerPreference = opts.powerPreference ?? 'high-performance';
    const adapter = await gpu.requestAdapter({ powerPreference });
    if (!adapter) {
        throw new Error('Failed to acquire a GPUAdapter. Your GPU may not support WebGPU.');
    }

    // Never ask for more than the adapter advertises, or requestDevice would reject.
    const supported = adapter.limits;
    const threeGiB = 3 * 1024 * 1024 * 1024;
    const requiredLimits = {
        maxBufferSize: Math.min(threeGiB, supported.maxBufferSize),
        maxStorageBufferBindingSize: Math.min(threeGiB, supported.maxStorageBufferBindingSize),
        maxComputeInvocationsPerWorkgroup: Math.min(256, supported.maxComputeInvocationsPerWorkgroup),
    };
    const device = await adapter.requestDevice({ requiredLimits });

    device.lost.then((info) => {
        console.error('WebGPU device lost:', info.message);
    });

    return { device, adapter };
}
|
|
36
|
+
export function createStorageBuffer(device, data, readable = false) {
|
|
37
|
+
const arr = data instanceof Float32Array || data instanceof Uint32Array ? data : new Float32Array(data);
|
|
38
|
+
const usage = STORAGE | COPY_DST | (readable ? COPY_SRC : 0);
|
|
39
|
+
const buffer = device.createBuffer({ size: arr.byteLength, usage, mappedAtCreation: true });
|
|
40
|
+
if (arr instanceof Uint32Array) {
|
|
41
|
+
new Uint32Array(buffer.getMappedRange()).set(arr);
|
|
42
|
+
}
|
|
43
|
+
else {
|
|
44
|
+
new Float32Array(buffer.getMappedRange()).set(arr);
|
|
45
|
+
}
|
|
46
|
+
buffer.unmap();
|
|
47
|
+
return buffer;
|
|
48
|
+
}
|
|
49
|
+
export function createEmptyStorageBuffer(device, byteSize, readable = false) {
|
|
50
|
+
const usage = STORAGE | COPY_DST | (readable ? COPY_SRC : 0);
|
|
51
|
+
return device.createBuffer({ size: byteSize, usage });
|
|
52
|
+
}
|
|
53
|
+
export function createUniformBuffer(device, data) {
|
|
54
|
+
const bytes = ArrayBuffer.isView(data) ? data.buffer : data;
|
|
55
|
+
const buffer = device.createBuffer({
|
|
56
|
+
size: bytes.byteLength,
|
|
57
|
+
usage: UNIFORM | COPY_DST,
|
|
58
|
+
mappedAtCreation: true,
|
|
59
|
+
});
|
|
60
|
+
new Uint8Array(buffer.getMappedRange()).set(new Uint8Array(bytes));
|
|
61
|
+
buffer.unmap();
|
|
62
|
+
return buffer;
|
|
63
|
+
}
|
|
64
|
+
export async function readBuffer(device, srcBuffer, byteSize) {
|
|
65
|
+
const MAP_READ_FLAG = _gpu.GPUMapMode?.READ ?? 0x01;
|
|
66
|
+
const stagingBuffer = device.createBuffer({
|
|
67
|
+
size: byteSize,
|
|
68
|
+
usage: MAP_READ | COPY_DST,
|
|
69
|
+
});
|
|
70
|
+
const encoder = device.createCommandEncoder();
|
|
71
|
+
encoder.copyBufferToBuffer(srcBuffer, 0, stagingBuffer, 0, byteSize);
|
|
72
|
+
device.queue.submit([encoder.finish()]);
|
|
73
|
+
await stagingBuffer.mapAsync(MAP_READ_FLAG);
|
|
74
|
+
const result = new Float32Array(stagingBuffer.getMappedRange().slice(0));
|
|
75
|
+
stagingBuffer.unmap();
|
|
76
|
+
stagingBuffer.destroy();
|
|
77
|
+
return result;
|
|
78
|
+
}
|
|
79
|
+
export function uploadBuffer(device, buffer, data, byteOffset = 0) {
|
|
80
|
+
device.queue.writeBuffer(buffer, byteOffset, data.buffer, data.byteOffset, data.byteLength);
|
|
81
|
+
}
|
|
82
|
+
export function createComputePipeline(device, wgslSource, entryPoint) {
|
|
83
|
+
const shaderModule = device.createShaderModule({ code: wgslSource });
|
|
84
|
+
return device.createComputePipeline({
|
|
85
|
+
layout: 'auto',
|
|
86
|
+
compute: { module: shaderModule, entryPoint },
|
|
87
|
+
});
|
|
88
|
+
}
|
|
89
|
+
export function createBindGroup(device, pipeline, buffers, groupIndex = 0) {
|
|
90
|
+
const entries = buffers.map((buf, i) => ({
|
|
91
|
+
binding: i,
|
|
92
|
+
resource: { buffer: buf },
|
|
93
|
+
}));
|
|
94
|
+
return device.createBindGroup({
|
|
95
|
+
layout: pipeline.getBindGroupLayout(groupIndex),
|
|
96
|
+
entries,
|
|
97
|
+
});
|
|
98
|
+
}
|
|
99
|
+
export function dispatchKernel(device, pipeline, bindGroup, workgroups) {
|
|
100
|
+
const encoder = device.createCommandEncoder();
|
|
101
|
+
const pass = encoder.beginComputePass();
|
|
102
|
+
pass.setPipeline(pipeline);
|
|
103
|
+
pass.setBindGroup(0, bindGroup);
|
|
104
|
+
pass.dispatchWorkgroups(...workgroups);
|
|
105
|
+
pass.end();
|
|
106
|
+
device.queue.submit([encoder.finish()]);
|
|
107
|
+
}
|
|
108
|
+
/**
 * Ceiling division.
 *
 * @param a Dividend.
 * @param b Divisor.
 * @returns `a / b` rounded up to the nearest integer via `Math.ceil`.
 */
export function cdiv(a, b) {
    const quotient = a / b;
    return Math.ceil(quotient);
}
|
|
111
|
+
//# sourceMappingURL=gpu_utils.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"gpu_utils.js","sourceRoot":"","sources":["../../src/utils/gpu_utils.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,uDAAuD;AACvD,MAAM,IAAI,GAAG,UAAiB,CAAC;AAC/B,MAAM,OAAO,GAAY,IAAI,CAAC,cAAc,EAAE,OAAO,IAAK,IAAI,CAAC;AAC/D,MAAM,OAAO,GAAY,IAAI,CAAC,cAAc,EAAE,OAAO,IAAK,IAAI,CAAC;AAC/D,MAAM,QAAQ,GAAW,IAAI,CAAC,cAAc,EAAE,QAAQ,IAAI,IAAI,CAAC;AAC/D,MAAM,QAAQ,GAAW,IAAI,CAAC,cAAc,EAAE,QAAQ,IAAI,IAAI,CAAC;AAC/D,MAAM,QAAQ,GAAW,IAAI,CAAC,cAAc,EAAE,QAAQ,IAAI,IAAI,CAAC;AAW/D,MAAM,CAAC,KAAK,UAAU,UAAU,CAAC,OAA0B,EAAE;IACzD,IAAI,OAAO,SAAS,KAAK,WAAW,IAAI,CAAC,SAAS,CAAC,GAAG,EAAE,CAAC;QACrD,MAAM,IAAI,KAAK,CACX,+CAA+C;YAC/C,qEAAqE,CACxE,CAAC;IACN,CAAC;IAED,MAAM,OAAO,GAAG,MAAM,SAAS,CAAC,GAAG,CAAC,cAAc,CAAC;QAC/C,eAAe,EAAE,IAAI,CAAC,eAAe,IAAI,kBAAkB;KAC9D,CAAC,CAAC;IAEH,IAAI,CAAC,OAAO,EAAE,CAAC;QACX,MAAM,IAAI,KAAK,CAAC,kEAAkE,CAAC,CAAC;IACxF,CAAC;IAED,MAAM,aAAa,GAAG,OAAO,CAAC,MAAM,CAAC;IACrC,MAAM,YAAY,GAAI,CAAC,GAAG,IAAI,GAAG,IAAI,GAAG,IAAI,CAAC;IAC7C,MAAM,MAAM,GAAG,MAAM,OAAO,CAAC,aAAa,CAAC;QACvC,cAAc,EAAE;YACZ,aAAa,EAAE,IAAI,CAAC,GAAG,CACnB,YAAY,EACZ,aAAa,CAAC,aAAa,CAC9B;YACD,2BAA2B,EAAE,IAAI,CAAC,GAAG,CACjC,YAAY,EACZ,aAAa,CAAC,2BAA2B,CAC5C;YACD,iCAAiC,EAAE,IAAI,CAAC,GAAG,CACvC,GAAG,EACH,aAAa,CAAC,iCAAiC,CAClD;SACJ;KACJ,CAAC,CAAC;IAEH,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,IAAI,EAAE,EAAE;QACtB,OAAO,CAAC,KAAK,CAAC,qBAAqB,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;IACvD,CAAC,CAAC,CAAC;IAEH,OAAO,EAAE,MAAM,EAAE,OAAO,EAAE,CAAC;AAC/B,CAAC;AAED,MAAM,UAAU,mBAAmB,CAAC,MAAiB,EAAE,IAA2C,EAAE,QAAQ,GAAG,KAAK;IAChH,MAAM,GAAG,GAAM,IAAI,YAAY,YAAY,IAAI,IAAI,YAAY,WAAW,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,IAAI,YAAY,CAAC,IAAI,CAAC,CAAC;IAC3G,MAAM,KAAK,GAAI,OAAO,GAAG,QAAQ,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC9D,MAAM,MAAM,GAAG,MAAM,CAAC,YAAY,CAAC,EAAE,IAAI,EAAE,GAAG,CAAC,UAAU,EAAE,KAAK,EAAE,gBAAgB,EAAE,IAAI,EAAE,CAAC,CAAC;IAC5F,IAAI,GAAG,YAAY,WAAW,EAAE,CAAC;QAC7B,IAAI,WAAW,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;IACtD,CAAC;SAAM,CAAC;QACJ,IAAI,YAAY,CAAC,MAAM,CA
AC,cAAc,EAAE,CAAC,CAAC,GAAG,CAAC,GAAmB,CAAC,CAAC;IACvE,CAAC;IACD,MAAM,CAAC,KAAK,EAAE,CAAC;IACf,OAAO,MAAM,CAAC;AAClB,CAAC;AAED,MAAM,UAAU,wBAAwB,CAAC,MAAiB,EAAE,QAAgB,EAAE,QAAQ,GAAG,KAAK;IAC1F,MAAM,KAAK,GAAG,OAAO,GAAG,QAAQ,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC7D,OAAO,MAAM,CAAC,YAAY,CAAC,EAAE,IAAI,EAAE,QAAQ,EAAE,KAAK,EAAE,CAAC,CAAC;AAC1D,CAAC;AAED,MAAM,UAAU,mBAAmB,CAAC,MAAiB,EAAE,IAAmC;IACtF,MAAM,KAAK,GAAI,WAAW,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,IAAI,CAAC;IAC7D,MAAM,MAAM,GAAG,MAAM,CAAC,YAAY,CAAC;QAC/B,IAAI,EAAI,KAAK,CAAC,UAAU;QACxB,KAAK,EAAG,OAAO,GAAG,QAAQ;QAC1B,gBAAgB,EAAE,IAAI;KACzB,CAAC,CAAC;IACH,IAAI,UAAU,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,CAAC,GAAG,CAAC,IAAI,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC;IACnE,MAAM,CAAC,KAAK,EAAE,CAAC;IACf,OAAO,MAAM,CAAC;AAClB,CAAC;AAED,MAAM,CAAC,KAAK,UAAU,UAAU,CAAC,MAAiB,EAAE,SAAoB,EAAE,QAAgB;IACtF,MAAM,aAAa,GAAW,IAAI,CAAC,UAAU,EAAE,IAAI,IAAI,IAAI,CAAC;IAC5D,MAAM,aAAa,GAAG,MAAM,CAAC,YAAY,CAAC;QACtC,IAAI,EAAI,QAAQ;QAChB,KAAK,EAAG,QAAQ,GAAG,QAAQ;KAC9B,CAAC,CAAC;IAEH,MAAM,OAAO,GAAG,MAAM,CAAC,oBAAoB,EAAE,CAAC;IAC9C,OAAO,CAAC,kBAAkB,CAAC,SAAS,EAAE,CAAC,EAAE,aAAa,EAAE,CAAC,EAAE,QAAQ,CAAC,CAAC;IACrE,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;IAExC,MAAM,aAAa,CAAC,QAAQ,CAAC,aAAa,CAAC,CAAC;IAC5C,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,aAAa,CAAC,cAAc,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;IACzE,aAAa,CAAC,KAAK,EAAE,CAAC;IACtB,aAAa,CAAC,OAAO,EAAE,CAAC;IACxB,OAAO,MAAM,CAAC;AAClB,CAAC;AAED,MAAM,UAAU,YAAY,CAAC,MAAiB,EAAE,MAAiB,EAAE,IAAkB,EAAE,UAAU,GAAG,CAAC;IACjG,MAAM,CAAC,KAAK,CAAC,WAAW,CAAC,MAAM,EAAE,UAAU,EAAE,IAAI,CAAC,MAAM,EAAE,IAAI,CAAC,UAAU,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC;AAChG,CAAC;AAED,MAAM,UAAU,qBAAqB,CAAC,MAAiB,EAAE,UAAkB,EAAE,UAAkB;IAC3F,MAAM,YAAY,GAAG,MAAM,CAAC,kBAAkB,CAAC,EAAE,IAAI,EAAE,UAAU,EAAE,CAAC,CAAC;IACrE,OAAO,MAAM,CAAC,qBAAqB,CAAC;QAChC,MAAM,EAAG,MAAM;QACf,OAAO,EAAE,EAAE,MAAM,EAAE,YAAY,EAAE,UAAU,EAAE;KAChD,CAAC,CAAC;AACP,CAAC;AAED,MAAM,UAAU,eAAe,CAAC,MAAiB,EAAE,QAA4
B,EAAE,OAAoB,EAAE,UAAU,GAAG,CAAC;IACjH,MAAM,OAAO,GAAG,OAAO,CAAC,GAAG,CAAC,CAAC,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC;QACrC,OAAO,EAAG,CAAC;QACX,QAAQ,EAAE,EAAE,MAAM,EAAE,GAAG,EAAE;KAC5B,CAAC,CAAC,CAAC;IACJ,OAAO,MAAM,CAAC,eAAe,CAAC;QAC1B,MAAM,EAAG,QAAQ,CAAC,kBAAkB,CAAC,UAAU,CAAC;QAChD,OAAO;KACV,CAAC,CAAC;AACP,CAAC;AAED,MAAM,UAAU,cAAc,CAAC,MAAiB,EAAE,QAA4B,EAAE,SAAuB,EAAE,UAAoC;IACzI,MAAM,OAAO,GAAG,MAAM,CAAC,oBAAoB,EAAE,CAAC;IAC9C,MAAM,IAAI,GAAM,OAAO,CAAC,gBAAgB,EAAE,CAAC;IAC3C,IAAI,CAAC,WAAW,CAAC,QAAQ,CAAC,CAAC;IAC3B,IAAI,CAAC,YAAY,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC;IAChC,IAAI,CAAC,kBAAkB,CAAC,GAAG,UAAU,CAAC,CAAC;IACvC,IAAI,CAAC,GAAG,EAAE,CAAC;IACX,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC;AAC5C,CAAC;AAED,MAAM,UAAU,IAAI,CAAC,CAAS,EAAE,CAAS;IACrC,OAAO,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;AAC5B,CAAC"}
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* quantization.ts – FP16 and Int8 quantization utilities.
|
|
3
|
+
*/
|
|
4
|
+
/**
 * Return type of `quantizeInt8`: the Int8 payload plus a single numeric `scale`.
 * NOTE(review): presumably `scale` is the factor `dequantizeInt8` applies to recover floats; confirm against the implementation.
 */
export interface QuantizeInt8Result {
|
|
5
|
+
data: Int8Array;
|
|
6
|
+
scale: number;
|
|
7
|
+
}
|
|
8
|
+
/**
 * Return type of `quantizeInt8PerChannel`: the Int8 payload plus a Float32Array of `scales`.
 * NOTE(review): presumably one scale per channel; the layout is not visible from this declaration, so confirm.
 */
export interface QuantizeInt8PerChannelResult {
|
|
9
|
+
data: Int8Array;
|
|
10
|
+
scales: Float32Array;
|
|
11
|
+
}
|
|
12
|
+
/**
 * Return type of `estimateMemory`: one number each for the fp32, fp16 and int8 representations.
 * NOTE(review): units are not stated in this declaration (presumably bytes); confirm.
 */
export interface MemoryEstimate {
|
|
13
|
+
fp32: number;
|
|
14
|
+
fp16: number;
|
|
15
|
+
int8: number;
|
|
16
|
+
}
|
|
17
|
+
/** Maps a number to a number. NOTE(review): presumably returns the 16-bit half-precision bit pattern of `val`; implementation not visible here, confirm. */
export declare function floatToFp16(val: number): number;
|
|
18
|
+
/** Maps a number to a number. NOTE(review): presumably decodes a 16-bit half-precision bit pattern back to a float (inverse of `floatToFp16`); confirm. */
export declare function fp16ToFloat(val: number): number;
|
|
19
|
+
/** Converts a Float32Array into a Uint16Array. NOTE(review): presumably one fp16 bit pattern per input element; confirm. */
export declare function quantizeFp16(f32: Float32Array): Uint16Array;
|
|
20
|
+
/** Converts a Uint16Array back into a Float32Array. NOTE(review): presumably the inverse of `quantizeFp16`; confirm. */
export declare function dequantizeFp16(fp16: Uint16Array): Float32Array;
|
|
21
|
+
/** Quantizes a Float32Array, returning Int8 `data` and a single `scale`. NOTE(review): the scaling scheme is not visible from this declaration; confirm. */
export declare function quantizeInt8(f32: Float32Array): QuantizeInt8Result;
|
|
22
|
+
/** Converts an Int8Array and a `scale` into a Float32Array. NOTE(review): presumably multiplies each element by `scale`; confirm. */
export declare function dequantizeInt8(int8: Int8Array, scale: number): Float32Array;
|
|
23
|
+
/** Quantizes a Float32Array given `numChannels`, returning Int8 `data` and a Float32Array of `scales`. NOTE(review): channel layout of `f32` is not visible here; confirm. */
export declare function quantizeInt8PerChannel(f32: Float32Array, numChannels: number): QuantizeInt8PerChannelResult;
|
|
24
|
+
/** Converts Int8 data, `scales` and `numChannels` into a Float32Array. NOTE(review): presumably the inverse of `quantizeInt8PerChannel`; confirm. */
export declare function dequantizeInt8PerChannel(int8: Int8Array, scales: Float32Array, numChannels: number): Float32Array;
|
|
25
|
+
/** Returns a `MemoryEstimate` (fp32 / fp16 / int8 figures) for `numElements`. NOTE(review): presumably byte counts at 4, 2 and 1 bytes per element; confirm. */
export declare function estimateMemory(numElements: number): MemoryEstimate;
|
|
26
|
+
//# sourceMappingURL=quantization.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"quantization.d.ts","sourceRoot":"","sources":["../../src/utils/quantization.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,MAAM,WAAW,kBAAkB;IACjC,IAAI,EAAE,SAAS,CAAC;IAChB,KAAK,EAAE,MAAM,CAAC;CACf;AAED,MAAM,WAAW,4BAA4B;IAC3C,IAAI,EAAE,SAAS,CAAC;IAChB,MAAM,EAAE,YAAY,CAAC;CACtB;AAED,MAAM,WAAW,cAAc;IAC7B,IAAI,EAAE,MAAM,CAAC;IACb,IAAI,EAAE,MAAM,CAAC;IACb,IAAI,EAAE,MAAM,CAAC;CACd;AAED,wBAAgB,WAAW,CAAC,GAAG,EAAE,MAAM,GAAG,MAAM,CA4B/C;AAED,wBAAgB,WAAW,CAAC,GAAG,EAAE,MAAM,GAAG,MAAM,CAiB/C;AAED,wBAAgB,YAAY,CAAC,GAAG,EAAE,YAAY,GAAG,WAAW,CAM3D;AAED,wBAAgB,cAAc,CAAC,IAAI,EAAE,WAAW,GAAG,YAAY,CAM9D;AAED,wBAAgB,YAAY,CAAC,GAAG,EAAE,YAAY,GAAG,kBAAkB,CAelE;AAED,wBAAgB,cAAc,CAAC,IAAI,EAAE,SAAS,EAAE,KAAK,EAAE,MAAM,GAAG,YAAY,CAM3E;AAED,wBAAgB,sBAAsB,CAAC,GAAG,EAAE,YAAY,EAAE,WAAW,EAAE,MAAM,GAAG,4BAA4B,CAqB3G;AAED,wBAAgB,wBAAwB,CAAC,IAAI,EAAE,SAAS,EAAE,MAAM,EAAE,YAAY,EAAE,WAAW,EAAE,MAAM,GAAG,YAAY,CAYjH;AAED,wBAAgB,cAAc,CAAC,WAAW,EAAE,MAAM,GAAG,cAAc,CAMlE"}
|