tensorgrad 0.0.1 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -9
- package/dist/adam.d.ts +14 -2
- package/dist/adam.d.ts.map +1 -1
- package/dist/adam.js +19 -8
- package/dist/adam.js.map +1 -1
- package/dist/buffers.d.ts +1 -0
- package/dist/buffers.d.ts.map +1 -1
- package/dist/buffers.js +12 -1
- package/dist/buffers.js.map +1 -1
- package/dist/capture.d.ts +3 -0
- package/dist/capture.d.ts.map +1 -0
- package/dist/capture.js +33 -0
- package/dist/capture.js.map +1 -0
- package/dist/codegen.js +4 -2
- package/dist/codegen.js.map +1 -1
- package/dist/compile.d.ts +33 -5
- package/dist/compile.d.ts.map +1 -1
- package/dist/compile.js +96 -11
- package/dist/compile.js.map +1 -1
- package/dist/index.d.ts +5 -3
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +4 -2
- package/dist/index.js.map +1 -1
- package/dist/ir.d.ts +2 -0
- package/dist/ir.d.ts.map +1 -1
- package/dist/ir.js +1 -1
- package/dist/ir.js.map +1 -1
- package/dist/module.d.ts +30 -4
- package/dist/module.d.ts.map +1 -1
- package/dist/module.js +39 -13
- package/dist/module.js.map +1 -1
- package/dist/nn.d.ts +19 -0
- package/dist/nn.d.ts.map +1 -0
- package/dist/nn.js +60 -0
- package/dist/nn.js.map +1 -0
- package/dist/ops.d.ts +1 -1
- package/dist/ops.d.ts.map +1 -1
- package/dist/ops.js +2 -2
- package/dist/ops.js.map +1 -1
- package/dist/runtime.d.ts +79 -4
- package/dist/runtime.d.ts.map +1 -1
- package/dist/runtime.js +153 -19
- package/dist/runtime.js.map +1 -1
- package/dist/trace.d.ts +1 -0
- package/dist/trace.d.ts.map +1 -1
- package/dist/trace.js +12 -0
- package/dist/trace.js.map +1 -1
- package/package.json +1 -2
- package/src/adam.ts +31 -10
- package/src/buffers.ts +14 -1
- package/src/capture.ts +36 -0
- package/src/codegen.ts +4 -2
- package/src/compile.ts +112 -13
- package/src/index.ts +5 -3
- package/src/ir.ts +10 -4
- package/src/module.ts +75 -11
- package/src/nn.ts +59 -0
- package/src/ops.ts +2 -2
- package/src/runtime.ts +260 -22
- package/src/trace.ts +13 -0
- package/SPEC.md +0 -293
package/dist/compile.js
CHANGED
|
@@ -12,7 +12,7 @@ import { appendGrad } from './grad.js';
|
|
|
12
12
|
import { appendAdam } from './adam.js';
|
|
13
13
|
import { planBuffers } from './buffers.js';
|
|
14
14
|
import { emitKernels } from './codegen.js';
|
|
15
|
-
import { createRuntime } from './runtime.js';
|
|
15
|
+
import { createRuntime, createForwardRuntime } from './runtime.js';
|
|
16
16
|
import { Module, materializeParams } from './module.js';
|
|
17
17
|
/** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
|
|
18
18
|
export function compileToIR(traceFn) {
|
|
@@ -30,9 +30,13 @@ export async function compile(traceFn, opts = {}) {
|
|
|
30
30
|
return Object.assign(runtime, { ir });
|
|
31
31
|
}
|
|
32
32
|
/**
|
|
33
|
-
* Compile a Module-based model.
|
|
34
|
-
* model
|
|
35
|
-
*
|
|
33
|
+
* Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
|
|
34
|
+
* model instance itself: compilation mutates the tree (every `ParamSentinel`
|
|
35
|
+
* field becomes a real `Tensor`), so the instance is consumed and shouldn't be
|
|
36
|
+
* referenced afterwards. Re-call the factory if you need a fresh tree.
|
|
37
|
+
*
|
|
38
|
+
* The forward function takes the materialized model and returns the loss
|
|
39
|
+
* tensor.
|
|
36
40
|
*
|
|
37
41
|
* Walks the module tree to materialize params with auto-derived names, then
|
|
38
42
|
* runs trace → grad → adam → buffer plan → codegen → runtime.
|
|
@@ -41,36 +45,117 @@ export async function compile(traceFn, opts = {}) {
|
|
|
41
45
|
* internal step count and injects the bias-corrected `lrt` scalar each call;
|
|
42
46
|
* users don't need to provide it themselves.
|
|
43
47
|
*/
|
|
44
|
-
export async function compileModule(
|
|
48
|
+
export async function compileModule(modelFactory, forward, opts = {}) {
|
|
45
49
|
const inputDecls = opts.inputs ?? [];
|
|
46
|
-
|
|
50
|
+
const model = modelFactory();
|
|
51
|
+
let materialized = { tensors: {}, initFns: {} };
|
|
47
52
|
const graph = trace(() => {
|
|
48
|
-
|
|
53
|
+
materialized = materializeParams(model);
|
|
49
54
|
const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'));
|
|
50
55
|
return forward(model, ...inputTensors);
|
|
51
56
|
});
|
|
52
57
|
const { paramGrads, loss } = appendGrad(graph);
|
|
53
58
|
let adamResult;
|
|
54
59
|
if (opts.adam) {
|
|
55
|
-
adamResult = appendAdam(graph, paramGrads,
|
|
60
|
+
adamResult = appendAdam(graph, paramGrads, materialized.tensors, opts.adam);
|
|
56
61
|
}
|
|
57
62
|
const plan = planBuffers(graph, paramGrads, adamResult?.writebacks ?? []);
|
|
58
63
|
const kernels = emitKernels(graph, plan);
|
|
59
64
|
const lossBufferId = plan.tensorToBuffer.get(loss.id);
|
|
60
65
|
const runtime = await createRuntime(plan, kernels, lossBufferId, opts);
|
|
61
66
|
// If Adam is enabled, wrap step() to track the step count and supply lrt.
|
|
67
|
+
// Wrap resetOptimizerState() too, so a reset zeros m/v *and* the bias-correction
|
|
68
|
+
// counter — otherwise the next step would skip Adam's warmup phase.
|
|
62
69
|
if (adamResult) {
|
|
63
70
|
const { lrtInputName, config } = adamResult;
|
|
64
71
|
let t = 0;
|
|
65
72
|
const lrtBuf = new Float32Array(1);
|
|
66
73
|
const innerStep = runtime.step.bind(runtime);
|
|
67
|
-
|
|
74
|
+
const innerReset = runtime.resetOptimizerState.bind(runtime);
|
|
75
|
+
const wrappedStep = (inputs, opts) => {
|
|
68
76
|
t++;
|
|
69
77
|
lrtBuf[0] = config.lr * Math.sqrt(1 - Math.pow(config.b2, t)) / (1 - Math.pow(config.b1, t));
|
|
70
|
-
|
|
78
|
+
const merged = { ...inputs, [lrtInputName]: lrtBuf };
|
|
79
|
+
return opts?.withCaptures ? innerStep(merged, { withCaptures: true }) : innerStep(merged);
|
|
80
|
+
};
|
|
81
|
+
runtime.step = wrappedStep;
|
|
82
|
+
runtime.resetOptimizerState = () => {
|
|
83
|
+
t = 0;
|
|
84
|
+
innerReset();
|
|
71
85
|
};
|
|
72
86
|
}
|
|
87
|
+
const { initFns } = materialized;
|
|
88
|
+
const uploadInitialParams = () => {
|
|
89
|
+
const out = {};
|
|
90
|
+
for (const [name, bufId] of plan.paramsByName) {
|
|
91
|
+
const shape = plan.buffers[bufId].shape;
|
|
92
|
+
const size = shape.reduce((a, b) => a * b, 1);
|
|
93
|
+
const initFn = initFns[name];
|
|
94
|
+
if (!initFn)
|
|
95
|
+
throw new Error(`uploadInitialParams: no init for param '${name}'`);
|
|
96
|
+
out[name] = initFn(size, shape);
|
|
97
|
+
}
|
|
98
|
+
runtime.uploadParams(out);
|
|
99
|
+
};
|
|
73
100
|
const ir = { graph, paramGrads, loss, plan, kernels };
|
|
74
|
-
return Object.assign(runtime, { ir });
|
|
101
|
+
return Object.assign(runtime, { ir, uploadInitialParams });
|
|
102
|
+
}
|
|
103
|
+
// ============================================================================
|
|
104
|
+
// Forward-only compile
|
|
105
|
+
// ============================================================================
|
|
106
|
+
/**
|
|
107
|
+
* Compile a Module-based model in forward-only mode (no autograd, no Adam).
|
|
108
|
+
* The forward function returns the output tensor (e.g., logits) instead of a
|
|
109
|
+
* scalar loss; runtime exposes `run(inputs)` returning the full output as a
|
|
110
|
+
* `Float32Array`.
|
|
111
|
+
*
|
|
112
|
+
* **Sharing params with a training compile.** Pass `opts.sharedParams =
|
|
113
|
+
* trainCompiled.params` to bind this graph's param buffers to an existing
|
|
114
|
+
* training runtime's GPU buffers — every train step is then immediately
|
|
115
|
+
* visible to `run()` calls here, no copies. The forward graph's
|
|
116
|
+
* `uploadInitialParams()` skips any param covered by `sharedParams`.
|
|
117
|
+
*
|
|
118
|
+
* Typical use: a B=1 inference graph alongside a B=512 training graph,
|
|
119
|
+
* built from the same `Module` factory.
|
|
120
|
+
*/
|
|
121
|
+
export async function compileForward(modelFactory, forward, opts = {}) {
|
|
122
|
+
const inputDecls = opts.inputs ?? [];
|
|
123
|
+
const model = modelFactory();
|
|
124
|
+
let materialized = { tensors: {}, initFns: {} };
|
|
125
|
+
const graph = trace(() => {
|
|
126
|
+
materialized = materializeParams(model);
|
|
127
|
+
const inputTensors = inputDecls.map(d => tensorInput(d.name, d.shape, d.dtype ?? 'f32'));
|
|
128
|
+
return forward(model, ...inputTensors);
|
|
129
|
+
});
|
|
130
|
+
const plan = planBuffers(graph, /* paramGrads */ {});
|
|
131
|
+
const kernels = emitKernels(graph, plan);
|
|
132
|
+
const outputTensor = graph.tensors[graph.outputs[0]];
|
|
133
|
+
const outputBufferId = plan.tensorToBuffer.get(outputTensor.id);
|
|
134
|
+
const runtime = await createForwardRuntime(plan, kernels, outputBufferId, opts);
|
|
135
|
+
const sharedParams = opts.sharedParams;
|
|
136
|
+
const { initFns } = materialized;
|
|
137
|
+
const uploadInitialParams = () => {
|
|
138
|
+
const out = {};
|
|
139
|
+
let needsUpload = false;
|
|
140
|
+
for (const [name, bufId] of plan.paramsByName) {
|
|
141
|
+
// Skip params covered by sharedParams — those are owned by the providing
|
|
142
|
+
// compile and already initialized there.
|
|
143
|
+
if (sharedParams?.has(name))
|
|
144
|
+
continue;
|
|
145
|
+
const shape = plan.buffers[bufId].shape;
|
|
146
|
+
const size = shape.reduce((a, b) => a * b, 1);
|
|
147
|
+
const initFn = initFns[name];
|
|
148
|
+
if (!initFn)
|
|
149
|
+
throw new Error(`uploadInitialParams: no init for param '${name}'`);
|
|
150
|
+
out[name] = initFn(size, shape);
|
|
151
|
+
needsUpload = true;
|
|
152
|
+
}
|
|
153
|
+
if (needsUpload)
|
|
154
|
+
runtime.uploadParams(out, { partial: !!sharedParams });
|
|
155
|
+
};
|
|
156
|
+
// CompiledIR.loss is the field name; for forward-only, it carries the user's
|
|
157
|
+
// returned tensor (e.g., logits). Same shape conceptually; just no autograd.
|
|
158
|
+
const ir = { graph, paramGrads: {}, loss: outputTensor, plan, kernels };
|
|
159
|
+
return Object.assign(runtime, { ir, uploadInitialParams });
|
|
75
160
|
}
|
|
76
161
|
//# sourceMappingURL=compile.js.map
|
package/dist/compile.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"compile.js","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAAA,2EAA2E;AAC3E,EAAE;AACF,oBAAoB;AACpB,sEAAsE;AACtE,iEAAiE;AACjE,0EAA0E;AAC1E,0EAA0E;AAC1E,2EAA2E;AAC3E,mEAAmE;AAGnE,OAAO,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/C,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,
|
|
1
|
+
{"version":3,"file":"compile.js","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAAA,2EAA2E;AAC3E,EAAE;AACF,oBAAoB;AACpB,sEAAsE;AACtE,iEAAiE;AACjE,0EAA0E;AAC1E,0EAA0E;AAC1E,2EAA2E;AAC3E,mEAAmE;AAGnE,OAAO,EAAE,KAAK,EAAE,WAAW,EAAE,MAAM,YAAY,CAAA;AAC/C,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAgE,MAAM,cAAc,CAAA;AAChI,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,MAAM,aAAa,CAAA;AAmBvD,yEAAyE;AACzE,MAAM,UAAU,WAAW,CAAC,OAAqB;IAC/C,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,CAAA;IAC5B,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAC9C,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,CAAC,CAAA;IAC3C,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,OAAO,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;AACnD,CAAC;AAED,0EAA0E;AAC1E,MAAM,CAAC,KAAK,UAAU,OAAO,CAAC,OAAqB,EAAE,OAAoB,EAAE;IACzE,MAAM,EAAE,GAAG,WAAW,CAAC,OAAO,CAAC,CAAA;IAC/B,MAAM,YAAY,GAAG,EAAE,CAAC,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IAC5D,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAC5E,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,CAAC,CAAA;AACvC,CAAC;AAqBD;;;;;;;;;;;;;;;GAeG;AACH,MAAM,CAAC,KAAK,UAAU,aAAa,CACjC,YAAqB,EACrB,OAA8C,EAC9C,OAA6B,EAAE;IAE/B,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,IAAI,EAAE,CAAA;IACpC,MAAM,KAAK,GAAG,YAAY,EAAE,CAAA;IAC5B,IAAI,YAAY,GAAyC,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,CAAA;IACrF,MAAM,KAAK,GAAG,KAAK,CAAC,GAAG,EAAE;QACvB,YAAY,GAAG,iBAAiB,CAAC,KAAK,CAAC,CAAA;QACvC,MAAM,YAAY,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,KAAK,CAAC,CAAC,CAAA;QACxF,OAAO,OAAO,CAAC,KAAK,EAAE,GAAG,YAAY,CAAC,CAAA;IACxC,CAAC,CAAC,CAAA;IAEF,MAAM,EAAE,UAAU,EAAE,IAAI,EAAE,GAAG,UAAU,CAAC,KAAK,CAAC,CAAA;IAE9C,IAAI,UAAqD,CAAA;IACzD,IAAI,IAAI,CAAC,IAAI,EAAE,CAAC;QACd,UAAU,GAAG,UAAU,CAAC,KAAK,EAAE,UAAU,EAAE,YAAY,CAAC,OAAO,EAAE,IAAI,CAAC,IAAI,CAAC,CAAA;IAC7E,CAAC;IAED,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU,IAAI,EAAE,CAAC,CAAA;IACzE,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,MAAM,YAAY,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAE,CAAA;IACtD,MAAM,OAAO,GAAG,MAAM,aAAa,CAAC,IAAI,EAAE,OAAO,EAAE,YAAY,EAAE,IAAI,CAAC,CAAA;IAEtE,0EAA0E;IAC1E,iFAAiF;IACjF,oEAAoE;IACpE,IAAI,UAAU,EAAE,CAAC;QACf,MAAM,EAAE,YAAY,EAAE,MAAM,EAAE,GAAG,UAAU,CAAA;QAC3C,IAAI,CAAC,GAAG,CAAC,CAAA;QACT,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,CAAA;QAClC,MAAM,SAAS,GAAG,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,OAAO,CAA4B,CAAA;QACvE,MAAM,UAAU,GAAG,OAAO,CAAC,mBAAmB,CAAC,IAAI,CAAC,OAAO,CAAC,CAAA;QAC5D,MAAM,WAAW,GAAG,CAClB,MAAiD,EACjD,IAAiC,EAC2C,EAAE;YAC9E,CAAC,EAAE,CAAA;YACH,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,EAAE,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAA;YAC5F,MAAM,MAAM,GAAG,EAAE,GAAG,MAAM,EAAE,CAAC,YAAY,CAAC,EAAE,MAAM,EAAE,CAAA;YACpD,OAAO,IAAI,EAAE,YAAY,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,EAAE,EAAE,YAAY,EAAE,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,MAAM,CAAC,CAAA;QAC3F,CAAC,CAAA;QACD,OAAO,CAAC,IAAI,GAAG,WAAsC,CAAA;QACrD,OAAO,CAAC,mBAAmB,GAAG,GAAG,EAAE;YACjC,CAAC,GAAG,CAAC,CAAA;YACL,UAAU,EAAE,CAAA;QACd,CAAC,CAAA;IACH,CAAC;IAED,MAAM,EAAE,OAAO,EAAE,GAAG,YAAY,CAAA;IAChC,MAAM,mBAAmB,GAAG,GAAG,EAAE;QAC/B,MAAM,GAAG,GAAiC,EAAE,CAAA;QAC5C,KAAK,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;YAC9C,MAAM,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAE,CAAC,KAAK,CAAA;YACxC,MAAM,IAAI,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAA;YAC7C,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAAA;YAC5B,IAAI,CAAC,MAAM;gBAAE,MAAM,IAAI,KAAK,CAAC,2CAA2C,IAAI,GAAG,CAAC,CAAA;YAChF,GAAG,CAAC,IAAI,CAAC,GAAG,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;QACjC,CAAC;QACD,OAAO,CAAC,YAAY,CAAC,GAAG,CAAC,CAAA;IAC3B,CAAC,CAAA;IAED,MAAM,EAAE,GAAe,EAAE,KAAK,EAAE,UAAU,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;IACjE,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,mBAAmB,EAAE,CAAC,CAAA;AAC5D,CAAC;AAED,+EAA+E;AAC/E,uBAAuB;AACvB,+EAA+E;AAE/E;;;;;;;;;;;;;;GAcG;AACH,MAAM,CAAC,KAAK,UAAU,cAAc,CAClC,YAAqB,EACrB,OAA8C,EAC9C,OAA8B,EAAE;IAEhC,MAAM,UAAU,GAAG,IAAI,CAAC,MAAM,IAAI,EAAE,CAAA;IACpC,MAAM,KAAK,GAAG,YAAY,EAAE,CAAA;IAC5B,IAAI,YAAY,GAAyC,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,CAAA;IACrF,MAAM,KAAK,GAAG,KAAK,CAAC,GAAG,EAAE;QACvB,YAAY,GAAG,iBAAiB,CAAC,KAAK,CAAC,CAAA;QACvC,MAAM,YAAY,GAAG,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,KAAK,CAAC,CAAC,CAAA;QACxF,OAAO,OAAO,CAAC,KAAK,EAAE,GAAG,YAAY,CAAC,CAAA;IACxC,CAAC,CAAC,CAAA;IAEF,MAAM,IAAI,GAAG,WAAW,CAAC,KAAK,EAAE,gBAAgB,CAAC,EAAE,CAAC,CAAA;IACpD,MAAM,OAAO,GAAG,WAAW,CAAC,KAAK,EAAE,IAAI,CAAC,CAAA;IACxC,MAAM,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAE,CAAE,CAAA;IACtD,MAAM,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,YAAY,CAAC,EAAE,CAAE,CAAA;IAChE,MAAM,OAAO,GAAG,MAAM,oBAAoB,CAAC,IAAI,EAAE,OAAO,EAAE,cAAc,EAAE,IAAI,CAAC,CAAA;IAE/E,MAAM,YAAY,GAAG,IAAI,CAAC,YAAY,CAAA;IACtC,MAAM,EAAE,OAAO,EAAE,GAAG,YAAY,CAAA;IAChC,MAAM,mBAAmB,GAAG,GAAG,EAAE;QAC/B,MAAM,GAAG,GAAiC,EAAE,CAAA;QAC5C,IAAI,WAAW,GAAG,KAAK,CAAA;QACvB,KAAK,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;YAC9C,yEAAyE;YACzE,yCAAyC;YACzC,IAAI,YAAY,EAAE,GAAG,CAAC,IAAI,CAAC;gBAAE,SAAQ;YACrC,MAAM,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAE,CAAC,KAAK,CAAA;YACxC,MAAM,IAAI,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAA;YAC7C,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,CAAA;YAC5B,IAAI,CAAC,MAAM;gBAAE,MAAM,IAAI,KAAK,CAAC,2CAA2C,IAAI,GAAG,CAAC,CAAA;YAChF,GAAG,CAAC,IAAI,CAAC,GAAG,MAAM,CAAC,IAAI,EAAE,KAAK,CAAC,CAAA;YAC/B,WAAW,GAAG,IAAI,CAAA;QACpB,CAAC;QACD,IAAI,WAAW;YAAE,OAAO,CAAC,YAAY,CAAC,GAAG,EAAE,EAAE,OAAO,EAAE,CAAC,CAAC,YAAY,EAAE,CAAC,CAAA;IACzE,CAAC,CAAA;IAED,6EAA6E;IAC7E,6EAA6E;IAC7E,MAAM,EAAE,GAAe,EAAE,KAAK,EAAE,UAAU,EAAE,EAAE,EAAE,IAAI,EAAE,YAAY,EAAE,IAAI,EAAE,OAAO,EAAE,CAAA;IACnF,OAAO,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,mBAAmB,EAAE,CAAC,CAAA;AAC5D,CAAC"}
|
package/dist/index.d.ts
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js';
|
|
2
2
|
export { ShapeError } from './shape.js';
|
|
3
3
|
export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js';
|
|
4
|
+
export { capture } from './capture.js';
|
|
4
5
|
export { add, sub, mul, div, sqrt, rsqrt, log, exp, relu, less, greater, where, meanLast, sumLast, reshape, transpose, matmul, matmulBatched, oneHot, arange, softmaxCausalLast, logSoftmaxLast, whereCausal, sliceLastRange, } from './ops.js';
|
|
5
6
|
export { appendGrad, type GradResult } from './grad.js';
|
|
6
7
|
export { appendAdam, type AdamConfig, type AdamResult } from './adam.js';
|
|
7
8
|
export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js';
|
|
8
9
|
export { emitKernels, type KernelSpec } from './codegen.js';
|
|
9
|
-
export { createRuntime, type CompiledRuntime, type RuntimeOpts } from './runtime.js';
|
|
10
|
-
export { compile, compileToIR, compileModule, type CompiledIR, type CompileModuleOptions, type InputDecl } from './compile.js';
|
|
11
|
-
export { Module, materializeParams } from './module.js';
|
|
10
|
+
export { createRuntime, createForwardRuntime, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type StepOptions, type StepWithCaptures, type RunOptions, type RunWithCaptures } from './runtime.js';
|
|
11
|
+
export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type InputDecl } from './compile.js';
|
|
12
|
+
export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js';
|
|
13
|
+
export * as nn from './nn.js';
|
|
12
14
|
//# sourceMappingURL=index.d.ts.map
|
package/dist/index.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAEjB,OAAO,EAAE,SAAS,EAElB,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAEd,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,MAAM,cAAc,CAAA;
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAEjB,OAAO,EAAE,SAAS,EAElB,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAEd,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAE,KAAK,eAAe,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,KAAK,WAAW,EAAE,KAAK,gBAAgB,EAAE,KAAK,UAAU,EAAE,KAAK,eAAe,EAAE,MAAM,cAAc,CAAA;AAChN,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EAAE,KAAK,UAAU,EAAE,KAAK,oBAAoB,EAAE,KAAK,qBAAqB,EAAE,KAAK,SAAS,EAAE,MAAM,cAAc,CAAA;AAC1K,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,KAAK,QAAQ,EAAE,KAAK,YAAY,EAAE,KAAK,kBAAkB,EAAE,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
|
package/dist/index.js
CHANGED
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
// codegen / compile() (Phase 3+) come later.
|
|
5
5
|
export { ShapeError } from './shape.js';
|
|
6
6
|
export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js';
|
|
7
|
+
export { capture } from './capture.js';
|
|
7
8
|
export {
|
|
8
9
|
// Element-wise arithmetic. The binops accept Tensor or JS-number for the second arg.
|
|
9
10
|
add, sub, mul, div,
|
|
@@ -31,7 +32,8 @@ export { appendGrad } from './grad.js';
|
|
|
31
32
|
export { appendAdam } from './adam.js';
|
|
32
33
|
export { planBuffers } from './buffers.js';
|
|
33
34
|
export { emitKernels } from './codegen.js';
|
|
34
|
-
export { createRuntime } from './runtime.js';
|
|
35
|
-
export { compile, compileToIR, compileModule } from './compile.js';
|
|
35
|
+
export { createRuntime, createForwardRuntime } from './runtime.js';
|
|
36
|
+
export { compile, compileToIR, compileModule, compileForward } from './compile.js';
|
|
36
37
|
export { Module, materializeParams } from './module.js';
|
|
38
|
+
export * as nn from './nn.js';
|
|
37
39
|
//# sourceMappingURL=index.js.map
|
package/dist/index.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+CAA+C;AAC/C,EAAE;AACF,8EAA8E;AAC9E,6CAA6C;AAG7C,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO;AACL,qFAAqF;AACrF,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG;AAClB,qBAAqB;AACrB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI;AAC3B,uBAAuB;AACvB,IAAI,EAAE,OAAO,EAAE,KAAK;AACpB,yEAAyE;AACzE,QAAQ,EAAE,OAAO;AACjB,YAAY;AACZ,OAAO,EAAE,SAAS;AAClB,iBAAiB;AACjB,MAAM,EAAE,aAAa;AACrB,qBAAqB;AACrB,MAAM,EAAE,MAAM;AACd,4CAA4C;AAC5C,iBAAiB,EAAE,cAAc,EAAE,WAAW;AAC9C,UAAU;AACV,cAAc,GACf,MAAM,UAAU,CAAA;AAEjB,sFAAsF;AACtF,8EAA8E;AAC9E,2EAA2E;AAC3E,qDAAqD;AACrD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAoC,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAwE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+CAA+C;AAC/C,EAAE;AACF,8EAA8E;AAC9E,6CAA6C;AAG7C,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO;AACL,qFAAqF;AACrF,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG;AAClB,qBAAqB;AACrB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI;AAC3B,uBAAuB;AACvB,IAAI,EAAE,OAAO,EAAE,KAAK;AACpB,yEAAyE;AACzE,QAAQ,EAAE,OAAO;AACjB,YAAY;AACZ,OAAO,EAAE,SAAS;AAClB,iBAAiB;AACjB,MAAM,EAAE,aAAa;AACrB,qBAAqB;AACrB,MAAM,EAAE,MAAM;AACd,4CAA4C;AAC5C,iBAAiB,EAAE,cAAc,EAAE,WAAW;AAC9C,UAAU;AACV,cAAc,GACf,MAAM,UAAU,CAAA;AAEjB,sFAAsF;AACtF,8EAA8E;AAC9E,2EAA2E;AAC3E,qDAAqD;AACrD,OAAO,EAAE,UAAU,EAAmB,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAoC,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAwE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAmB,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAgJ,MAAM,cAAc,CAAA;AAChN,OAAO,EAAE,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EAA0F,MAAM,cAAc,CAAA;AAC1K,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAA6D,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
|
package/dist/ir.d.ts
CHANGED
|
@@ -162,6 +162,7 @@ export type OpNode = {
|
|
|
162
162
|
vNew: number;
|
|
163
163
|
lrt: number;
|
|
164
164
|
eps: number;
|
|
165
|
+
decayShrink: number;
|
|
165
166
|
} | {
|
|
166
167
|
kind: 'slice_last_range';
|
|
167
168
|
out: number;
|
|
@@ -193,6 +194,7 @@ export interface Graph {
|
|
|
193
194
|
readonly ops: OpNode[];
|
|
194
195
|
readonly tensors: Tensor[];
|
|
195
196
|
readonly outputs: number[];
|
|
197
|
+
readonly captures: Map<string, number>;
|
|
196
198
|
}
|
|
197
199
|
export declare function makeGraph(): Graph;
|
|
198
200
|
export declare function addTensor(g: Graph, shape: Shape, dtype: Dtype, source: number | null, site: CallSite | null): Tensor;
|
package/dist/ir.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"ir.d.ts","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAcA,MAAM,MAAM,KAAK,GAAG,KAAK,GAAG,KAAK,GAAG,MAAM,CAAA;AAC1C,MAAM,MAAM,KAAK,GAAG,SAAS,MAAM,EAAE,CAAA;AAIrC,MAAM,WAAW,MAAM;IACrB,QAAQ,CAAC,EAAE,EAAE,MAAM,CAAA;IACnB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IACrB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IAErB,QAAQ,CAAC,MAAM,EAAE,MAAM,GAAG,IAAI,CAAA;IAG9B,QAAQ,CAAC,IAAI,EAAE,QAAQ,GAAG,IAAI,CAAA;CAC/B;AAED,MAAM,WAAW,QAAQ;IACvB,QAAQ,CAAC,MAAM,EAAE,MAAM,CAAA;IAEvB,QAAQ,CAAC,KAAK,EAAE,MAAM,CAAA;CACvB;AAQD,MAAM,MAAM,MAAM,GAGd;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAElD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAInD;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAGrE;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAC9D;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAG9D;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACxC;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACzC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGxC;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAC7C;IAAE,IAAI,EAAE,UAAU,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG5C;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,QAAQ,EAAE,KAAK,CAAA;CAAE,GAC5D;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,SAAS,MAAM,EAAE,CAAA;CAAE,GAMtE;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAErD;IAAE,IAAI,EAAE,gBAAgB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG7D;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,OAAO,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAC9E;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAGxD;IAAE,IAAI,EAAE,qBAAqB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvD;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAIpD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAKnE;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACnD;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGtD;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAMlE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GACxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,
|
|
1
|
+
{"version":3,"file":"ir.d.ts","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAcA,MAAM,MAAM,KAAK,GAAG,KAAK,GAAG,KAAK,GAAG,MAAM,CAAA;AAC1C,MAAM,MAAM,KAAK,GAAG,SAAS,MAAM,EAAE,CAAA;AAIrC,MAAM,WAAW,MAAM;IACrB,QAAQ,CAAC,EAAE,EAAE,MAAM,CAAA;IACnB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IACrB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IAErB,QAAQ,CAAC,MAAM,EAAE,MAAM,GAAG,IAAI,CAAA;IAG9B,QAAQ,CAAC,IAAI,EAAE,QAAQ,GAAG,IAAI,CAAA;CAC/B;AAED,MAAM,WAAW,QAAQ;IACvB,QAAQ,CAAC,MAAM,EAAE,MAAM,CAAA;IAEvB,QAAQ,CAAC,KAAK,EAAE,MAAM,CAAA;CACvB;AAQD,MAAM,MAAM,MAAM,GAGd;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAElD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAInD;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAGrE;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAC9D;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAG9D;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACxC;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACzC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGxC;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAC7C;IAAE,IAAI,EAAE,UAAU,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG5C;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,QAAQ,EAAE,KAAK,CAAA;CAAE,GAC5D;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,SAAS,MAAM,EAAE,CAAA;CAAE,GAMtE;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAErD;IAAE,IAAI,EAAE,gBAAgB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG7D;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,OAAO,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAC9E;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAGxD;IAAE,IAAI,EAAE,qBAAqB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvD;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAIpD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAKnE;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACnD;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGtD;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAMlE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GACxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GAOxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,MAAM,CAAA;CAAE,GAM5H;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAA;CAAE,GAGhF;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAGpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAEpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAElE;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,CAAA;AAI7D,MAAM,WAAW,KAAK;IACpB,QAAQ,CAAC,GAAG,EAAE,MAAM,EAAE,CAAA;IACtB,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;IAG1B,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;IAI1B,QAAQ,CAAC,QAAQ,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;CACvC;AAED,wBAAgB,SAAS,IAAI,KAAK,CAEjC;AAGD,wBAAgB,SAAS,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,GAAG,IAAI,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,MAAM,CAKpH;AAMD,wBAAgB,KAAK,CAAC,CAAC,SAAS,MAAM,CAAC,MAAM,CAAC,EAC5C,CAAC,EAAE,KAAK,EACR,IAAI,EAAE,CAAC,EACP,KAAK,EAAE,KAAK,EACZ,KAAK,EAAE,KAAK,EACZ,IAAI,EAAE,QAAQ,GAAG,IAAI,EACrB,MAAM,EAAE,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE;IAAE,IAAI,EAAE,CAAC,CAAA;CAAE,CAAC,EAAE,MAAM,GAAG,KAAK,CAAC,GACzD,MAAM,CAMR;AAID,wBAAgB,WAAW,CAAC,MAAM,EAAE,MAAM,GAAG,QAAQ,CAIpD;AAID,wBAAgB,UAAU,CAAC,IAAI,EAAE,QAAQ,GAAG,MAAM,CAYjD"}
|
package/dist/ir.js
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
// Design intent: keep this file boring. No tracing logic, no shape inference,
|
|
13
13
|
// no codegen — those live in their own modules and consume `Graph` / `OpNode`.
|
|
14
14
|
export function makeGraph() {
|
|
15
|
-
return { ops: [], tensors: [], outputs: [] };
|
|
15
|
+
return { ops: [], tensors: [], outputs: [], captures: new Map() };
|
|
16
16
|
}
|
|
17
17
|
// Internal: register a fresh tensor in the graph and return its id.
|
|
18
18
|
export function addTensor(g, shape, dtype, source, site) {
|
package/dist/ir.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"ir.js","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAAA,uDAAuD;AACvD,EAAE;AACF,gFAAgF;AAChF,+EAA+E;AAC/E,8EAA8E;AAC9E,EAAE;AACF,0DAA0D;AAC1D,uCAAuC;AACvC,8EAA8E;AAC9E,wFAAwF;AACxF,EAAE;AACF,8EAA8E;AAC9E,+EAA+E;
|
|
1
|
+
{"version":3,"file":"ir.js","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAAA,uDAAuD;AACvD,EAAE;AACF,gFAAgF;AAChF,+EAA+E;AAC/E,8EAA8E;AAC9E,EAAE;AACF,0DAA0D;AAC1D,uCAAuC;AACvC,8EAA8E;AAC9E,wFAAwF;AACxF,EAAE;AACF,8EAA8E;AAC9E,+EAA+E;AAyI/E,MAAM,UAAU,SAAS;IACvB,OAAO,EAAE,GAAG,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,OAAO,EAAE,EAAE,EAAE,QAAQ,EAAE,IAAI,GAAG,EAAE,EAAE,CAAA;AACnE,CAAC;AAED,oEAAoE;AACpE,MAAM,UAAU,SAAS,CAAC,CAAQ,EAAE,KAAY,EAAE,KAAY,EAAE,MAAqB,EAAE,IAAqB;IAC1G,MAAM,EAAE,GAAG,CAAC,CAAC,OAAO,CAAC,MAAM,CAAA;IAC3B,MAAM,CAAC,GAAW,EAAE,EAAE,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,CAAA;IACpD,CAAC,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAA;IACjB,OAAO,CAAC,CAAA;AACV,CAAC;AAED,kFAAkF;AAClF,0EAA0E;AAC1E,+EAA+E;AAC/E,kFAAkF;AAClF,MAAM,UAAU,KAAK,CACnB,CAAQ,EACR,IAAO,EACP,KAAY,EACZ,KAAY,EACZ,IAAqB,EACrB,MAA0D;IAE1D,MAAM,OAAO,GAAG,CAAC,CAAC,GAAG,CAAC,MAAM,CAAA;IAC5B,MAAM,GAAG,GAAG,SAAS,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,OAAO,EAAE,IAAI,CAAC,CAAA;IACrD,MAAM,IAAI,GAAG,EAAE,IAAI,EAAE,GAAG,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,MAAM,EAAkC,CAAA;IAC7E,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,CAAA;IAChB,OAAO,GAAG,CAAA;AACZ,CAAC;AAED,0EAA0E;AAC1E,iFAAiF;AACjF,MAAM,UAAU,WAAW,CAAC,MAAc;IACxC,+EAA+E;IAC/E,MAAM,KAAK,GAAG,CAAC,IAAI,KAAK,EAAE,CAAC,CAAC,KAAK,IAAI,EAAE,CAAA;IACvC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,CAAA;AAC1B,CAAC;AAED,8EAA8E;AAC9E,2DAA2D;AAC3D,MAAM,UAAU,UAAU,CAAC,IAAc;IACvC,MAAM,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,IAAI,CAAC,CAAA;IACpC,2EAA2E;IAC3E,iEAAiE;IACjE,MAAM,UAAU,GAAa,EAAE,CAAA;IAC/B,KAAK,MAAM,IAAI,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC;QAClC,IAAI,IAAI,CAAC,QAAQ,CAAC,kBAAkB,CAAC,IAAI,IAAI,CAAC,QAAQ,CAAC,qBAAqB,CAAC;YAAE,SAAQ;QACvF,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAA;QAC5B,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC;YAAE,MAAK;IACnC,CAAC;IACD,IAAI,UAAU,CAAC,MAAM,KAAK,CAAC;QAAE,OAAO,IAAI,IAAI,CAAC,MAAM,yBAAyB,CAAA;IAC5E,OAAO,IAAI,IAAI,CAAC,MAAM,QAAQ,UAAU,CAAC,IAAI,CAAC,MAAM,CAAC,EAAE,CAAA;AACzD,CAAC"}
|
package/dist/module.d.ts
CHANGED
|
@@ -1,4 +1,21 @@
|
|
|
1
1
|
import type { Tensor, Shape, Dtype } from './ir.js';
|
|
2
|
+
/** How a parameter's initial values are produced.
|
|
3
|
+
* - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
|
|
4
|
+
* weight matrices and embeddings.
|
|
5
|
+
* - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
|
|
6
|
+
* - `'ones'` — fill with 1. Common for LayerNorm gain.
|
|
7
|
+
* - Custom function — receives total element count and shape, returns the
|
|
8
|
+
* Float32Array. Use for fan-in scaling or any non-standard scheme.
|
|
9
|
+
*/
|
|
10
|
+
export type InitSpec = 'randn' | 'zeros' | 'ones' | ((size: number, shape: readonly number[]) => Float32Array);
|
|
11
|
+
export interface ParamOptions {
|
|
12
|
+
dtype?: Dtype;
|
|
13
|
+
/** Init kind. Default: `'randn'`. */
|
|
14
|
+
init?: InitSpec;
|
|
15
|
+
/** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
|
|
16
|
+
scale?: number;
|
|
17
|
+
}
|
|
18
|
+
type InitFn = (size: number, shape: readonly number[]) => Float32Array;
|
|
2
19
|
export declare abstract class Module {
|
|
3
20
|
/**
|
|
4
21
|
* Declare a learnable parameter at this module. Must be called from inside
|
|
@@ -6,16 +23,25 @@ export declare abstract class Module {
|
|
|
6
23
|
* that gets replaced with a real Tensor at compile time.
|
|
7
24
|
*
|
|
8
25
|
* The parameter's name is auto-derived from its property path in the model
|
|
9
|
-
* tree (e.g. `layers.0.attn.W_q`).
|
|
26
|
+
* tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
|
|
27
|
+
* call `compiled.uploadInitialParams()` to apply it after compile.
|
|
10
28
|
*/
|
|
11
|
-
protected param(shape: Shape,
|
|
29
|
+
protected param(shape: Shape, opts?: ParamOptions): Tensor;
|
|
30
|
+
}
|
|
31
|
+
export interface MaterializedParams {
|
|
32
|
+
/** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
|
|
33
|
+
tensors: Record<string, Tensor>;
|
|
34
|
+
/** Init function per param path. Used by `uploadInitialParams`. */
|
|
35
|
+
initFns: Record<string, InitFn>;
|
|
12
36
|
}
|
|
13
37
|
/**
|
|
14
38
|
* Walk the module tree and replace every ParamSentinel with a real Tensor
|
|
15
39
|
* created via `paramInput(autoName, ...)`. Must be called inside an active
|
|
16
40
|
* trace context (paramInput appends to the current graph).
|
|
17
41
|
*
|
|
18
|
-
* Returns
|
|
42
|
+
* Returns the param tensors keyed by path, plus init functions for use by
|
|
43
|
+
* `uploadInitialParams`.
|
|
19
44
|
*/
|
|
20
|
-
export declare function materializeParams(root: Module):
|
|
45
|
+
export declare function materializeParams(root: Module): MaterializedParams;
|
|
46
|
+
export {};
|
|
21
47
|
//# sourceMappingURL=module.d.ts.map
|
package/dist/module.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"module.d.ts","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AA2BA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;
|
|
1
|
+
{"version":3,"file":"module.d.ts","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AA2BA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAOnD;;;;;;;GAOG;AACH,MAAM,MAAM,QAAQ,GAChB,OAAO,GACP,OAAO,GACP,MAAM,GACN,CAAC,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,KAAK,YAAY,CAAC,CAAA;AAE9D,MAAM,WAAW,YAAY;IAC3B,KAAK,CAAC,EAAE,KAAK,CAAA;IACb,qCAAqC;IACrC,IAAI,CAAC,EAAE,QAAQ,CAAA;IACf,uEAAuE;IACvE,KAAK,CAAC,EAAE,MAAM,CAAA;CACf;AAED,KAAK,MAAM,GAAG,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,KAAK,YAAY,CAAA;AA2CtE,8BAAsB,MAAM;IAC1B;;;;;;;;OAQG;IACH,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,CAAC,EAAE,YAAY,GAAG,MAAM;CAK3D;AAMD,MAAM,WAAW,kBAAkB;IACjC,2EAA2E;IAC3E,OAAO,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IAC/B,mEAAmE;IACnE,OAAO,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;CAChC;AAED;;;;;;;GAOG;AACH,wBAAgB,iBAAiB,CAAC,IAAI,EAAE,MAAM,GAAG,kBAAkB,CAYlE"}
|
package/dist/module.js
CHANGED
|
@@ -6,8 +6,8 @@
|
|
|
6
6
|
// W: Tensor; b: Tensor
|
|
7
7
|
// constructor(inDim: number, outDim: number) {
|
|
8
8
|
// super()
|
|
9
|
-
// this.W = this.param([inDim, outDim])
|
|
10
|
-
// this.b = this.param([outDim])
|
|
9
|
+
// this.W = this.param([inDim, outDim]) // randn, scale 0.02
|
|
10
|
+
// this.b = this.param([outDim], { init: 'zeros' })
|
|
11
11
|
// }
|
|
12
12
|
// }
|
|
13
13
|
// class Block extends Module {
|
|
@@ -25,6 +25,28 @@
|
|
|
25
25
|
// and writeback wiring. Forward functions are pure and stateless — they
|
|
26
26
|
// take the materialized model and inputs, return a Tensor.
|
|
27
27
|
import { paramInput } from './trace.js';
|
|
28
|
+
function boxMuller() {
|
|
29
|
+
return Math.sqrt(-2 * Math.log(Math.max(1e-10, Math.random()))) * Math.cos(2 * Math.PI * Math.random());
|
|
30
|
+
}
|
|
31
|
+
function resolveInit(opts) {
|
|
32
|
+
const init = opts?.init ?? 'randn';
|
|
33
|
+
if (init === 'randn') {
|
|
34
|
+
const scale = opts?.scale ?? 0.02;
|
|
35
|
+
return (size) => {
|
|
36
|
+
const arr = new Float32Array(size);
|
|
37
|
+
for (let i = 0; i < size; i++)
|
|
38
|
+
arr[i] = boxMuller() * scale;
|
|
39
|
+
return arr;
|
|
40
|
+
};
|
|
41
|
+
}
|
|
42
|
+
if (init === 'zeros')
|
|
43
|
+
return (size) => new Float32Array(size);
|
|
44
|
+
if (init === 'ones')
|
|
45
|
+
return (size) => { const a = new Float32Array(size); a.fill(1); return a; };
|
|
46
|
+
if (typeof init === 'function')
|
|
47
|
+
return init;
|
|
48
|
+
throw new Error(`Unknown init: ${String(init)}`);
|
|
49
|
+
}
|
|
28
50
|
// ============================================================================
|
|
29
51
|
// Internals: param sentinel
|
|
30
52
|
// ============================================================================
|
|
@@ -36,9 +58,11 @@ import { paramInput } from './trace.js';
|
|
|
36
58
|
class ParamSentinel {
|
|
37
59
|
shape;
|
|
38
60
|
dtype;
|
|
39
|
-
|
|
61
|
+
initFn;
|
|
62
|
+
constructor(shape, dtype, initFn) {
|
|
40
63
|
this.shape = shape;
|
|
41
64
|
this.dtype = dtype;
|
|
65
|
+
this.initFn = initFn;
|
|
42
66
|
}
|
|
43
67
|
}
|
|
44
68
|
// ============================================================================
|
|
@@ -51,33 +75,35 @@ export class Module {
|
|
|
51
75
|
* that gets replaced with a real Tensor at compile time.
|
|
52
76
|
*
|
|
53
77
|
* The parameter's name is auto-derived from its property path in the model
|
|
54
|
-
* tree (e.g. `layers.0.attn.W_q`).
|
|
78
|
+
* tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
|
|
79
|
+
* call `compiled.uploadInitialParams()` to apply it after compile.
|
|
55
80
|
*/
|
|
56
|
-
param(shape,
|
|
81
|
+
param(shape, opts) {
|
|
82
|
+
const dtype = opts?.dtype ?? 'f32';
|
|
57
83
|
// Lie to TypeScript: the sentinel becomes a Tensor at materialize time.
|
|
58
|
-
return new ParamSentinel(shape, dtype);
|
|
84
|
+
return new ParamSentinel(shape, dtype, resolveInit(opts));
|
|
59
85
|
}
|
|
60
86
|
}
|
|
61
|
-
// ============================================================================
|
|
62
|
-
// Tree walking
|
|
63
|
-
// ============================================================================
|
|
64
87
|
/**
|
|
65
88
|
* Walk the module tree and replace every ParamSentinel with a real Tensor
|
|
66
89
|
* created via `paramInput(autoName, ...)`. Must be called inside an active
|
|
67
90
|
* trace context (paramInput appends to the current graph).
|
|
68
91
|
*
|
|
69
|
-
* Returns
|
|
92
|
+
* Returns the param tensors keyed by path, plus init functions for use by
|
|
93
|
+
* `uploadInitialParams`.
|
|
70
94
|
*/
|
|
71
95
|
export function materializeParams(root) {
|
|
72
|
-
const
|
|
96
|
+
const tensors = {};
|
|
97
|
+
const initFns = {};
|
|
73
98
|
visit(root, '', (path, val, owner, key) => {
|
|
74
99
|
if (val instanceof ParamSentinel) {
|
|
75
100
|
const t = paramInput(path, val.shape, val.dtype);
|
|
76
101
|
owner[key] = t;
|
|
77
|
-
|
|
102
|
+
tensors[path] = t;
|
|
103
|
+
initFns[path] = val.initFn;
|
|
78
104
|
}
|
|
79
105
|
});
|
|
80
|
-
return
|
|
106
|
+
return { tensors, initFns };
|
|
81
107
|
}
|
|
82
108
|
function visit(node, path, visitor) {
|
|
83
109
|
if (node === null || node === undefined)
|
package/dist/module.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"module.js","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AAAA,6EAA6E;AAC7E,EAAE;AACF,+CAA+C;AAC/C,EAAE;AACF,kCAAkC;AAClC,2BAA2B;AAC3B,mDAAmD;AACnD,gBAAgB;AAChB,
|
|
1
|
+
{"version":3,"file":"module.js","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AAAA,6EAA6E;AAC7E,EAAE;AACF,+CAA+C;AAC/C,EAAE;AACF,kCAAkC;AAClC,2BAA2B;AAC3B,mDAAmD;AACnD,gBAAgB;AAChB,gFAAgF;AAChF,yDAAyD;AACzD,QAAQ;AACR,MAAM;AACN,iCAAiC;AACjC,8BAA8B;AAC9B,+BAA+B;AAC/B,MAAM;AACN,iCAAiC;AACjC,mCAAmC;AACnC,+CAA+C;AAC/C,MAAM;AACN,EAAE;AACF,wEAAwE;AACxE,0EAA0E;AAC1E,0EAA0E;AAC1E,wEAAwE;AACxE,2DAA2D;AAG3D,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AA8BvC,SAAS,SAAS;IAChB,OAAO,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,KAAK,EAAE,IAAI,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,EAAE,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,CAAA;AACzG,CAAC;AAED,SAAS,WAAW,CAAC,IAA8B;IACjD,MAAM,IAAI,GAAG,IAAI,EAAE,IAAI,IAAI,OAAO,CAAA;IAClC,IAAI,IAAI,KAAK,OAAO,EAAE,CAAC;QACrB,MAAM,KAAK,GAAG,IAAI,EAAE,KAAK,IAAI,IAAI,CAAA;QACjC,OAAO,CAAC,IAAI,EAAE,EAAE;YACd,MAAM,GAAG,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,CAAA;YAClC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE;gBAAE,GAAG,CAAC,CAAC,CAAC,GAAG,SAAS,EAAE,GAAG,KAAK,CAAA;YAC3D,OAAO,GAAG,CAAA;QACZ,CAAC,CAAA;IACH,CAAC;IACD,IAAI,IAAI,KAAK,OAAO;QAAE,OAAO,CAAC,IAAI,EAAE,EAAE,CAAC,IAAI,YAAY,CAAC,IAAI,CAAC,CAAA;IAC7D,IAAI,IAAI,KAAK,MAAM;QAAE,OAAO,CAAC,IAAI,EAAE,EAAE,GAAG,MAAM,CAAC,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAA,CAAC,CAAC,CAAA;IAC/F,IAAI,OAAO,IAAI,KAAK,UAAU;QAAE,OAAO,IAAI,CAAA;IAC3C,MAAM,IAAI,KAAK,CAAC,iBAAiB,MAAM,CAAC,IAAI,CAAC,EAAE,CAAC,CAAA;AAClD,CAAC;AAED,+EAA+E;AAC/E,4BAA4B;AAC5B,+EAA+E;AAC/E,EAAE;AACF,6EAA6E;AAC7E,4EAA4E;AAC5E,2EAA2E;AAC3E,yEAAyE;AAEzE,MAAM,aAAa;IAEC;IACA;IACA;IAHlB,YACkB,KAAY,EACZ,KAAY,EACZ,MAAc;QAFd,UAAK,GAAL,KAAK,CAAO;QACZ,UAAK,GAAL,KAAK,CAAO;QACZ,WAAM,GAAN,MAAM,CAAQ;IAC7B,CAAC;CACL;AAED,+EAA+E;AAC/E,oBAAoB;AACpB,+EAA+E;AAE/E,MAAM,OAAgB,MAAM;IAC1B;;;;;;;;OAQG;IACO,KAAK,CAAC,KAAY,EAAE,IAAmB;QAC/C,MAAM,KAAK,GAAG,IAAI,EAAE,KAAK,IAAI,KAAK,CAAA;QAClC,wEAAwE;QACxE,OAAO,IAAI,aAAa,CAAC,KAAK,EAAE,KAAK,EAAE,WAAW,CAAC,IAAI,CAAC,CAAsB,CAAA;IAChF,CAAC;CACF;AAaD;;;;;;;GAOG;AACH,MAAM,UAAU,iBAAiB,CAAC,IAAY;IAC5C,MAAM,OAAO,GAA2B,EAAE,CAAA;IAC1C,MAAM,OAAO,GAA2B,EAAE,CAAA;IAC1C,KAAK,CAAC,IAAI,EAAE,EAAE,EAAE,CAAC,IAAI,EAAE,GAAG,EAAE,KAAK,EAAE,GAAG,EAAE,EAAE;QACxC,IAAI,GAAG,YAAY,aAAa,EAAE,CAAC;YACjC,MAAM,CAAC,GAAG,UAAU,CAAC,IAAI,EAAE,GAAG,CAAC,KAAK,EAAE,GAAG,CAAC,KAAK,CAAC,CAC/C;YAAC,KAAa,CAAC,GAAG,CAAC,GAAG,CAAC,CAAA;YACxB,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,CAAA;YACjB,OAAO,CAAC,IAAI,CAAC,GAAG,GAAG,CAAC,MAAM,CAAA;QAC5B,CAAC;IACH,CAAC,CAAC,CAAA;IACF,OAAO,EAAE,OAAO,EAAE,OAAO,EAAE,CAAA;AAC7B,CAAC;AAaD,SAAS,KAAK,CAAC,IAAa,EAAE,IAAY,EAAE,OAAgB;IAC1D,IAAI,IAAI,KAAK,IAAI,IAAI,IAAI,KAAK,SAAS;QAAE,OAAM;IAC/C,IAAI,OAAO,IAAI,KAAK,QAAQ;QAAE,OAAM;IAEpC,IAAI,IAAI,YAAY,MAAM,EAAE,CAAC;QAC3B,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,IAAc,CAAC,EAAE,CAAC;YAC9C,MAAM,KAAK,GAAI,IAAY,CAAC,GAAG,CAAC,CAAA;YAChC,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,IAAI,GAAG,EAAE,CAAC,CAAC,CAAC,GAAG,CAAA;YAC/C,UAAU,CAAC,KAAK,EAAE,SAAS,EAAE,IAAI,EAAE,GAAG,EAAE,OAAO,CAAC,CAAA;QAClD,CAAC;QACD,OAAM;IACR,CAAC;IACD,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC;QACxB,IAAI,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,CAAC,EAAE,EAAE;YACvB,MAAM,SAAS,GAAG,IAAI,CAAC,CAAC,CAAC,GAAG,IAAI,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAA;YACnD,UAAU,CAAC,IAAI,EAAE,SAAS,EAAE,IAAyB,EAAE,CAAC,EAAE,OAAO,CAAC,CAAA;QACpE,CAAC,CAAC,CAAA;QACF,OAAM;IACR,CAAC;IACD,2EAA2E;IAC3E,uBAAuB;AACzB,CAAC;AAED,SAAS,UAAU,CAAC,KAAc,EAAE,IAAY,EAAE,KAAa,EAAE,GAAoB,EAAE,OAAgB;IACrG,IAAI,KAAK,YAAY,MAAM,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC;QACpD,KAAK,CAAC,KAAK,EAAE,IAAI,EAAE,OAAO,CAAC,CAAA;IAC7B,CAAC;SAAM,CAAC;QACN,OAAO,CAAC,IAAI,EAAE,KAAK,EAAE,KAAK,EAAE,GAAG,CAAC,CAAA;IAClC,CAAC;AACH,CAAC"}
|
package/dist/nn.d.ts
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import { Module } from './module.js';
|
|
2
|
+
import type { Tensor } from './ir.js';
|
|
3
|
+
export declare class Linear extends Module {
|
|
4
|
+
readonly inDim: number;
|
|
5
|
+
readonly outDim: number;
|
|
6
|
+
W: Tensor;
|
|
7
|
+
b: Tensor | null;
|
|
8
|
+
constructor(inDim: number, outDim: number, withBias?: boolean);
|
|
9
|
+
}
|
|
10
|
+
export declare function linearFwd(p: Linear, x: Tensor): Tensor;
|
|
11
|
+
export declare class LayerNorm extends Module {
|
|
12
|
+
readonly d: number;
|
|
13
|
+
readonly eps: number;
|
|
14
|
+
g: Tensor;
|
|
15
|
+
b: Tensor;
|
|
16
|
+
constructor(d: number, eps?: number);
|
|
17
|
+
}
|
|
18
|
+
export declare function layerNormFwd(p: LayerNorm, x: Tensor): Tensor;
|
|
19
|
+
//# sourceMappingURL=nn.d.ts.map
|
package/dist/nn.d.ts.map
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAeA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AAOrC,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,QAAQ,UAAO;CAK3F;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAGtD;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;CAK1E;AAED,wBAAgB,YAAY,CAAC,CAAC,EAAE,SAAS,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAM5D"}
|
package/dist/nn.js
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
// Standard "batteries-included" Module subclasses for the most common layers.
|
|
2
|
+
//
|
|
3
|
+
// JAX-style: each class declares its params (and their init); the forward is a
|
|
4
|
+
// plain function the user calls with `(module, x)`. No subclassing, no method
|
|
5
|
+
// dispatch — keeps the autograd-traced computation visible at the call site.
|
|
6
|
+
//
|
|
7
|
+
// Import as a namespace:
|
|
8
|
+
//
|
|
9
|
+
// import { nn } from 'tensorgrad'
|
|
10
|
+
// class Block extends Module {
|
|
11
|
+
// ln = new nn.LayerNorm(D)
|
|
12
|
+
// ffn = new nn.Linear(D, 4 * D)
|
|
13
|
+
// }
|
|
14
|
+
// const y = nn.linearFwd(p.ffn, nn.layerNormFwd(p.ln, x))
|
|
15
|
+
import { Module } from './module.js';
|
|
16
|
+
import { add, matmul, sub, mul, div, sqrt, meanLast } from './ops.js';
|
|
17
|
+
// ----------------------------------------------------------------------------
|
|
18
|
+
// Linear: y = x @ W (+ b)
|
|
19
|
+
// ----------------------------------------------------------------------------
|
|
20
|
+
export class Linear extends Module {
|
|
21
|
+
inDim;
|
|
22
|
+
outDim;
|
|
23
|
+
W;
|
|
24
|
+
b;
|
|
25
|
+
constructor(inDim, outDim, withBias = true) {
|
|
26
|
+
super();
|
|
27
|
+
this.inDim = inDim;
|
|
28
|
+
this.outDim = outDim;
|
|
29
|
+
this.W = this.param([inDim, outDim]); // randn, scale 0.02
|
|
30
|
+
this.b = withBias ? this.param([outDim], { init: 'zeros' }) : null;
|
|
31
|
+
}
|
|
32
|
+
}
|
|
33
|
+
export function linearFwd(p, x) {
|
|
34
|
+
const out = matmul(x, p.W);
|
|
35
|
+
return p.b ? add(out, p.b) : out;
|
|
36
|
+
}
|
|
37
|
+
// ----------------------------------------------------------------------------
|
|
38
|
+
// LayerNorm — normalizes over the last axis. eps defaults to 1e-5.
|
|
39
|
+
// ----------------------------------------------------------------------------
|
|
40
|
+
export class LayerNorm extends Module {
|
|
41
|
+
d;
|
|
42
|
+
eps;
|
|
43
|
+
g;
|
|
44
|
+
b;
|
|
45
|
+
constructor(d, eps = 1e-5) {
|
|
46
|
+
super();
|
|
47
|
+
this.d = d;
|
|
48
|
+
this.eps = eps;
|
|
49
|
+
this.g = this.param([d], { init: 'ones' });
|
|
50
|
+
this.b = this.param([d], { init: 'zeros' });
|
|
51
|
+
}
|
|
52
|
+
}
|
|
53
|
+
export function layerNormFwd(p, x) {
|
|
54
|
+
const m = meanLast(x);
|
|
55
|
+
const c = sub(x, m);
|
|
56
|
+
const v = meanLast(mul(c, c));
|
|
57
|
+
const stdev = sqrt(add(v, p.eps));
|
|
58
|
+
return add(mul(div(c, stdev), p.g), p.b);
|
|
59
|
+
}
|
|
60
|
+
//# sourceMappingURL=nn.js.map
|
package/dist/nn.js.map
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"nn.js","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAAA,8EAA8E;AAC9E,EAAE;AACF,+EAA+E;AAC/E,8EAA8E;AAC9E,6EAA6E;AAC7E,EAAE;AACF,yBAAyB;AACzB,EAAE;AACF,oCAAoC;AACpC,iCAAiC;AACjC,gCAAgC;AAChC,oCAAoC;AACpC,MAAM;AACN,4DAA4D;AAE5D,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AAEpC,OAAO,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAAE,QAAQ,EAAE,MAAM,UAAU,CAAA;AAErE,+EAA+E;AAC/E,0BAA0B;AAC1B,+EAA+E;AAE/E,MAAM,OAAO,MAAO,SAAQ,MAAM;IAGJ;IAA+B;IAF3D,CAAC,CAAQ;IACT,CAAC,CAAe;IAChB,YAA4B,KAAa,EAAkB,MAAc,EAAE,QAAQ,GAAG,IAAI;QACxF,KAAK,EAAE,CAAA;QADmB,UAAK,GAAL,KAAK,CAAQ;QAAkB,WAAM,GAAN,MAAM,CAAQ;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAA,CAAsB,oBAAoB;QAC9E,IAAI,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,CAAA;IACpE,CAAC;CACF;AAED,MAAM,UAAU,SAAS,CAAC,CAAS,EAAE,CAAS;IAC5C,MAAM,GAAG,GAAG,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;IAC1B,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAA;AAClC,CAAC;AAED,+EAA+E;AAC/E,mEAAmE;AACnE,+EAA+E;AAE/E,MAAM,OAAO,SAAU,SAAQ,MAAM;IAGP;IAA2B;IAFvD,CAAC,CAAQ;IACT,CAAC,CAAQ;IACT,YAA4B,CAAS,EAAkB,MAAc,IAAI;QACvE,KAAK,EAAE,CAAA;QADmB,MAAC,GAAD,CAAC,CAAQ;QAAkB,QAAG,GAAH,GAAG,CAAe;QAEvE,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,MAAM,EAAE,CAAC,CAAA;QAC1C,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,IAAI,EAAE,OAAO,EAAE,CAAC,CAAA;IAC7C,CAAC;CACF;AAED,MAAM,UAAU,YAAY,CAAC,CAAY,EAAE,CAAS;IAClD,MAAM,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAA;IACrB,MAAM,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAA;IACnB,MAAM,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAA;IAC7B,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,CAAA;IACjC,OAAO,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAA;AAC1C,CAAC"}
|
package/dist/ops.d.ts
CHANGED
|
@@ -31,5 +31,5 @@ export declare function where(cond: Tensor, a: Tensor, b: Tensor): Tensor;
|
|
|
31
31
|
export declare function reluGrad(x: Tensor, dy: Tensor): Tensor;
|
|
32
32
|
export declare function adamUpdateM(m: Tensor, g: Tensor, b1: number): Tensor;
|
|
33
33
|
export declare function adamUpdateV(v: Tensor, g: Tensor, b2: number): Tensor;
|
|
34
|
-
export declare function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number): Tensor;
|
|
34
|
+
export declare function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number, decayShrink?: number): Tensor;
|
|
35
35
|
//# sourceMappingURL=ops.d.ts.map
|
package/dist/ops.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"ops.d.ts","sourceRoot":"","sources":["../src/ops.ts"],"names":[],"mappings":"AAWA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,SAAS,CAAA;AAmC3D,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAMzD;AAQD,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAYD,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,KAAK,GAAI,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAO7D,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAK1C;AAED,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKzC;AAMD,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,EAAE,QAAQ,EAAE,KAAK,GAAG,MAAM,CAI1D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,SAAS,MAAM,EAAE,GAAG,MAAM,CAIpE;AAMD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAOnD;AAED,wBAAgB,aAAa,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAO1D;AAMD,wBAAgB,MAAM,CAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOnF;AAGD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAM9D;AASD,wBAAgB,iBAAiB,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKnD;AAGD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAIhD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,SAAS,EAAE,MAAM,GAAG,MAAM,CAKhE;AAQD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,CAI5E;AAOD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIjE;AAED,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIhE;AAOD,wBAAgB,WAAW,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAGvE;AAWD,eAAO,MAAM,IAAI,GAAO,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AACpG,eAAO,MAAM,OAAO,GAAI,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AAGpG,wBAAgB,KAAK,CAAC,IAAI,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAMhE;AAID,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOtD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,
|
|
1
|
+
{"version":3,"file":"ops.d.ts","sourceRoot":"","sources":["../src/ops.ts"],"names":[],"mappings":"AAWA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,SAAS,CAAA;AAmC3D,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAMzD;AAQD,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAYD,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,KAAK,GAAI,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAO7D,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAK1C;AAED,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKzC;AAMD,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,EAAE,QAAQ,EAAE,KAAK,GAAG,MAAM,CAI1D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,SAAS,MAAM,EAAE,GAAG,MAAM,CAIpE;AAMD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAOnD;AAED,wBAAgB,aAAa,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAO1D;AAMD,wBAAgB,MAAM,CAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOnF;AAGD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAM9D;AASD,wBAAgB,iBAAiB,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKnD;AAGD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAIhD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,SAAS,EAAE,MAAM,GAAG,MAAM,CAKhE;AAQD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,CAI5E;AAOD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIjE;AAED,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIhE;AAOD,wBAAgB,WAAW,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAGvE;AAWD,eAAO,MAAM,IAAI,GAAO,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AACpG,eAAO,MAAM,OAAO,GAAI,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AAGpG,wBAAgB,KAAK,CAAC,IAAI,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAMhE;AAID,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOtD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,WAAW,GAAE,MAAU,GAAG,MAAM,CAW5H"}
|
package/dist/ops.js
CHANGED
|
@@ -255,7 +255,7 @@ export function adamUpdateV(v, g, b2) {
|
|
|
255
255
|
}
|
|
256
256
|
return addOp(currentGraph(), 'adam_update_v', v.shape, 'f32', site, { v: v.id, g: g.id, b2 });
|
|
257
257
|
}
|
|
258
|
-
export function adamUpdateP(p, mNew, vNew, lrt, eps) {
|
|
258
|
+
export function adamUpdateP(p, mNew, vNew, lrt, eps, decayShrink = 1) {
|
|
259
259
|
const site = captureSite('adamUpdateP');
|
|
260
260
|
if (p.dtype !== 'f32')
|
|
261
261
|
throw new ShapeError(`adamUpdateP: requires f32`, site);
|
|
@@ -265,6 +265,6 @@ export function adamUpdateP(p, mNew, vNew, lrt, eps) {
|
|
|
265
265
|
if (p.shape.length !== mNew.shape.length || p.shape.some((d, i) => d !== mNew.shape[i])) {
|
|
266
266
|
throw new ShapeError(`adamUpdateP: p/mNew shape mismatch`, site);
|
|
267
267
|
}
|
|
268
|
-
return addOp(currentGraph(), 'adam_update_p', p.shape, 'f32', site, { p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps });
|
|
268
|
+
return addOp(currentGraph(), 'adam_update_p', p.shape, 'f32', site, { p: p.id, mNew: mNew.id, vNew: vNew.id, lrt: lrt.id, eps, decayShrink });
|
|
269
269
|
}
|
|
270
270
|
//# sourceMappingURL=ops.js.map
|