tensorgrad 0.0.11 → 0.0.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +119 -119
- package/dist/buffers.js +1 -6
- package/dist/buffers.js.map +1 -1
- package/dist/codegen.js +30 -28
- package/dist/codegen.js.map +1 -1
- package/dist/compile.js +39 -68
- package/dist/compile.js.map +1 -1
- package/dist/grad.js +1 -14
- package/dist/grad.js.map +1 -1
- package/dist/index.d.ts +740 -14
- package/dist/runtime.js +9 -11
- package/dist/runtime.js.map +1 -1
- package/dist/trace.js +8 -13
- package/dist/trace.js.map +1 -1
- package/package.json +67 -61
- package/src/buffers.ts +1 -6
- package/src/codegen.ts +31 -28
- package/src/compile.ts +45 -91
- package/src/grad.ts +1 -11
- package/src/index.ts +47 -47
- package/src/runtime.ts +520 -515
- package/src/trace.ts +12 -9
- package/dist/adam.d.ts +0 -65
- package/dist/adam.d.ts.map +0 -1
- package/dist/buffers.d.ts +0 -57
- package/dist/buffers.d.ts.map +0 -1
- package/dist/capture.d.ts +0 -3
- package/dist/capture.d.ts.map +0 -1
- package/dist/codegen.d.ts +0 -23
- package/dist/codegen.d.ts.map +0 -1
- package/dist/compile.d.ts +0 -130
- package/dist/compile.d.ts.map +0 -1
- package/dist/grad.d.ts +0 -8
- package/dist/grad.d.ts.map +0 -1
- package/dist/index.d.ts.map +0 -1
- package/dist/ir.d.ts +0 -207
- package/dist/ir.d.ts.map +0 -1
- package/dist/module.d.ts +0 -55
- package/dist/module.d.ts.map +0 -1
- package/dist/nn.d.ts +0 -42
- package/dist/nn.d.ts.map +0 -1
- package/dist/ops.d.ts +0 -48
- package/dist/ops.d.ts.map +0 -1
- package/dist/runtime.d.ts +0 -108
- package/dist/runtime.d.ts.map +0 -1
- package/dist/shape.d.ts +0 -24
- package/dist/shape.d.ts.map +0 -1
- package/dist/trace.d.ts +0 -9
- package/dist/trace.d.ts.map +0 -1
package/src/trace.ts
CHANGED
|
@@ -84,20 +84,25 @@ export function traceInto<T>(g: Graph, fn: () => T): T {
|
|
|
84
84
|
// Their .source on the Tensor points at that node so codegen knows where to
|
|
85
85
|
// bind external data.
|
|
86
86
|
|
|
87
|
+
// Param/tensor inputs share a namespace (a step() call passes both as keys in
|
|
88
|
+
// the same dispatch object); state inputs have their own namespace.
|
|
89
|
+
type NamedInputKind = 'param_input' | 'tensor_input' | 'state_input'
|
|
90
|
+
function assertNameUnused(g: Graph, name: string, kinds: NamedInputKind[], label: string): void {
|
|
91
|
+
if (g.ops.some(op => kinds.includes(op.kind as NamedInputKind) && (op as { name?: string }).name === name)) {
|
|
92
|
+
throw new Error(`tensorgrad: ${label} name '${name}' already used in this trace`)
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
|
|
87
96
|
export function paramInput(name: string, shape: Shape, dtype: Dtype = 'f32'): Tensor {
|
|
88
97
|
const g = currentGraph()
|
|
89
|
-
|
|
90
|
-
throw new Error(`tensorgrad: input name '${name}' already used in this trace`)
|
|
91
|
-
}
|
|
98
|
+
assertNameUnused(g, name, ['param_input', 'tensor_input'], 'input')
|
|
92
99
|
const site = captureSite('paramInput')
|
|
93
100
|
return addOp(g, 'param_input', shape, dtype, site, { name } as any)
|
|
94
101
|
}
|
|
95
102
|
|
|
96
103
|
export function tensorInput(name: string, shape: Shape, dtype: Dtype = 'f32'): Tensor {
|
|
97
104
|
const g = currentGraph()
|
|
98
|
-
|
|
99
|
-
throw new Error(`tensorgrad: input name '${name}' already used in this trace`)
|
|
100
|
-
}
|
|
105
|
+
assertNameUnused(g, name, ['param_input', 'tensor_input'], 'input')
|
|
101
106
|
const site = captureSite('tensorInput')
|
|
102
107
|
return addOp(g, 'tensor_input', shape, dtype, site, { name } as any)
|
|
103
108
|
}
|
|
@@ -106,9 +111,7 @@ export function tensorInput(name: string, shape: Shape, dtype: Dtype = 'f32'): T
|
|
|
106
111
|
// and updated across step() calls via writebacks declared by the optimizer helper.
|
|
107
112
|
export function stateInput(name: string, shape: Shape, dtype: Dtype = 'f32', initValue = 0): Tensor {
|
|
108
113
|
const g = currentGraph()
|
|
109
|
-
|
|
110
|
-
throw new Error(`tensorgrad: state name '${name}' already used in this trace`)
|
|
111
|
-
}
|
|
114
|
+
assertNameUnused(g, name, ['state_input'], 'state')
|
|
112
115
|
const site = captureSite('stateInput')
|
|
113
116
|
return addOp(g, 'state_input', shape, dtype, site, { name, initValue } as any)
|
|
114
117
|
}
|
package/dist/adam.d.ts
DELETED
|
@@ -1,65 +0,0 @@
|
|
|
1
|
-
import type { Tensor } from './ir.js';
|
|
2
|
-
import type { Graph } from './ir.js';
|
|
3
|
-
import type { WritebackDecl } from './buffers.js';
|
|
4
|
-
export interface AdamConfig {
|
|
5
|
-
/** Constant scalar (e.g., `0.005`) or a per-step schedule function
|
|
6
|
-
* `(step) => lr`. Schedule fn lets the user implement linear/cosine decay
|
|
7
|
-
* or warmup; first call passes `step=1`. Decay-shrink (AdamW) updates
|
|
8
|
-
* per-step automatically when this is a function. */
|
|
9
|
-
lr: number | ((step: number) => number);
|
|
10
|
-
b1?: number;
|
|
11
|
-
b2?: number;
|
|
12
|
-
eps?: number;
|
|
13
|
-
/** AdamW: decoupled weight decay coefficient. Default 0 (plain Adam).
|
|
14
|
-
* When non-zero, every step shrinks each decayed param by a factor of
|
|
15
|
-
* `1 - lr * weightDecay` before the gradient update. */
|
|
16
|
-
weightDecay?: number;
|
|
17
|
-
/** Filter deciding which params get weight decay. Only consulted when
|
|
18
|
-
* weightDecay > 0. Default: decay every param. Override for the standard
|
|
19
|
-
* transformer convention (decay weights/embeddings, skip biases + LN gains).
|
|
20
|
-
* Example: `(name) => name.includes('.W') || name.endsWith('_emb')`. */
|
|
21
|
-
decayFilter?: (paramName: string) => boolean;
|
|
22
|
-
}
|
|
23
|
-
/** Resolved hyperparameters: lr is the schedule fn (constants are wrapped). */
|
|
24
|
-
export interface AdamResolvedConfig {
|
|
25
|
-
lr: (step: number) => number;
|
|
26
|
-
b1: number;
|
|
27
|
-
b2: number;
|
|
28
|
-
eps: number;
|
|
29
|
-
weightDecay: number;
|
|
30
|
-
decayFilter: (name: string) => boolean;
|
|
31
|
-
/** True iff the user supplied an lr function (vs a constant). When false,
|
|
32
|
-
* decayShrink is baked at compile time and never updated. */
|
|
33
|
-
lrIsScheduled: boolean;
|
|
34
|
-
}
|
|
35
|
-
export interface AdamResult {
|
|
36
|
-
/** Writebacks the buffer planner should wire into the runtime. */
|
|
37
|
-
writebacks: WritebackDecl[];
|
|
38
|
-
/** Name of the per-step scalar tensor_input. The runtime fills this each call
|
|
39
|
-
* with `lr * sqrt(1-b2^t)/(1-b1^t)` (Adam's bias-corrected effective LR). */
|
|
40
|
-
lrtInputName: string;
|
|
41
|
-
/** Name of the per-step decayShrink scalar tensor_input, or null when lr is
|
|
42
|
-
* static (decayShrink baked into the kernel) or no params are decayed. */
|
|
43
|
-
decayShrinkInputName: string | null;
|
|
44
|
-
/** Hyperparameters as captured (so the runtime can compute lrt and decayShrink). */
|
|
45
|
-
config: AdamResolvedConfig;
|
|
46
|
-
}
|
|
47
|
-
/**
|
|
48
|
-
* Append Adam update ops to `graph`. Must be called inside an active trace
|
|
49
|
-
* context (or after a trace, since traceInto re-enters the graph).
|
|
50
|
-
*
|
|
51
|
-
* @param graph the graph (already containing forward + backward)
|
|
52
|
-
* @param paramGrads param name -> gradient tensor (output of `appendGrad`)
|
|
53
|
-
* @param paramTensors param name -> the param's leaf Tensor (the param_input).
|
|
54
|
-
* Needed because the param_input lives in the graph but we
|
|
55
|
-
* don't have a direct map by name in `Graph` — caller passes it.
|
|
56
|
-
* @param config Adam hyperparameters. Set `weightDecay > 0` for AdamW; an
|
|
57
|
-
* optional `decayFilter` selects which params receive decay.
|
|
58
|
-
*/
|
|
59
|
-
export declare function appendAdam(graph: Graph, paramGrads: Record<string, Tensor>, paramTensors: Record<string, Tensor>, config: AdamConfig,
|
|
60
|
-
/** Per-param decay flags from `materializeParams`. When supplied, overrides
|
|
61
|
-
* `config.decayFilter` for any name in the map; falls back to `decayFilter`
|
|
62
|
-
* for names not present (e.g., for low-level callers using `compile()`
|
|
63
|
-
* directly without a Module). */
|
|
64
|
-
decayFlags?: Record<string, boolean>): AdamResult;
|
|
65
|
-
//# sourceMappingURL=adam.d.ts.map
|
package/dist/adam.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"adam.d.ts","sourceRoot":"","sources":["../src/adam.ts"],"names":[],"mappings":"AA4BA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AACrC,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AACpC,OAAO,KAAK,EAAE,aAAa,EAAE,MAAM,cAAc,CAAA;AAIjD,MAAM,WAAW,UAAU;IACzB;;;0DAGsD;IACtD,EAAE,EAAE,MAAM,GAAG,CAAC,CAAC,IAAI,EAAE,MAAM,KAAK,MAAM,CAAC,CAAA;IACvC,EAAE,CAAC,EAAE,MAAM,CAAA;IACX,EAAE,CAAC,EAAE,MAAM,CAAA;IACX,GAAG,CAAC,EAAE,MAAM,CAAA;IACZ;;6DAEyD;IACzD,WAAW,CAAC,EAAE,MAAM,CAAA;IACpB;;;6EAGyE;IACzE,WAAW,CAAC,EAAE,CAAC,SAAS,EAAE,MAAM,KAAK,OAAO,CAAA;CAC7C;AAED,+EAA+E;AAC/E,MAAM,WAAW,kBAAkB;IACjC,EAAE,EAAE,CAAC,IAAI,EAAE,MAAM,KAAK,MAAM,CAAA;IAC5B,EAAE,EAAE,MAAM,CAAA;IACV,EAAE,EAAE,MAAM,CAAA;IACV,GAAG,EAAE,MAAM,CAAA;IACX,WAAW,EAAE,MAAM,CAAA;IACnB,WAAW,EAAE,CAAC,IAAI,EAAE,MAAM,KAAK,OAAO,CAAA;IACtC;kEAC8D;IAC9D,aAAa,EAAE,OAAO,CAAA;CACvB;AAED,MAAM,WAAW,UAAU;IACzB,kEAAkE;IAClE,UAAU,EAAE,aAAa,EAAE,CAAA;IAC3B;iFAC6E;IAC7E,YAAY,EAAE,MAAM,CAAA;IACpB;+EAC2E;IAC3E,oBAAoB,EAAE,MAAM,GAAG,IAAI,CAAA;IACnC,oFAAoF;IACpF,MAAM,EAAE,kBAAkB,CAAA;CAC3B;AAED;;;;;;;;;;;GAWG;AACH,wBAAgB,UAAU,CACxB,KAAK,EAAE,KAAK,EACZ,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,EAClC,YAAY,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,EACpC,MAAM,EAAE,UAAU;AAClB;;;kCAGkC;AAClC,UAAU,CAAC,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,GACnC,UAAU,CAyEZ"}
|
package/dist/buffers.d.ts
DELETED
|
@@ -1,57 +0,0 @@
|
|
|
1
|
-
import type { Graph, Tensor, Dtype, Shape } from './ir.js';
|
|
2
|
-
export interface BufferSpec {
|
|
3
|
-
/** Matches tensor.id. */
|
|
4
|
-
id: number;
|
|
5
|
-
byteSize: number;
|
|
6
|
-
dtype: Dtype;
|
|
7
|
-
shape: Shape;
|
|
8
|
-
kind: 'param' | 'param_grad' | 'tensor_input' | 'state' | 'intermediate' | 'output';
|
|
9
|
-
/** External name for param/param_grad/tensor_input/state bindings. null otherwise. */
|
|
10
|
-
name: string | null;
|
|
11
|
-
/** For state buffers: the value to fill on initial allocation. 0 by default. */
|
|
12
|
-
initValue?: number;
|
|
13
|
-
}
|
|
14
|
-
/**
|
|
15
|
-
* After step(), copy `source`'s buffer into `dest`'s buffer.
|
|
16
|
-
* Used to write back updated optimizer state and updated parameters into
|
|
17
|
-
* their persistent home buffers.
|
|
18
|
-
*/
|
|
19
|
-
export interface Writeback {
|
|
20
|
-
source: number;
|
|
21
|
-
dest: number;
|
|
22
|
-
bytes: number;
|
|
23
|
-
}
|
|
24
|
-
export interface BufferPlan {
|
|
25
|
-
buffers: BufferSpec[];
|
|
26
|
-
/** Tensor id -> buffer id (currently 1:1 but kept opaque for future pooling). */
|
|
27
|
-
tensorToBuffer: Map<number, number>;
|
|
28
|
-
/** Easy lookup tables for the runtime. */
|
|
29
|
-
paramsByName: Map<string, number>;
|
|
30
|
-
inputsByName: Map<string, number>;
|
|
31
|
-
paramGradsByName: Map<string, number>;
|
|
32
|
-
statesByName: Map<string, number>;
|
|
33
|
-
capturesByName: Map<string, number>;
|
|
34
|
-
outputBufferIds: number[];
|
|
35
|
-
/** End-of-step writebacks (Adam updates for params, m, v, etc.) */
|
|
36
|
-
writebacks: Writeback[];
|
|
37
|
-
}
|
|
38
|
-
/**
|
|
39
|
-
* Caller-supplied writeback declarations: "after each step, copy this Tensor's
|
|
40
|
-
* buffer into the persistent home of this param/state."
|
|
41
|
-
*/
|
|
42
|
-
export interface WritebackDecl {
|
|
43
|
-
/** The Tensor (output of some op) holding the new value to write back. */
|
|
44
|
-
source: Tensor;
|
|
45
|
-
/** Either a param name (writes to that param's home buffer) or a state name. */
|
|
46
|
-
destName: string;
|
|
47
|
-
destKind: 'param' | 'state';
|
|
48
|
-
}
|
|
49
|
-
/**
|
|
50
|
-
* Build a BufferPlan from a graph + the param-grad map produced by appendGrad.
|
|
51
|
-
* @param graph the full graph (forward + backward + any optimizer ops)
|
|
52
|
-
* @param paramGrads map from param name -> the Tensor that holds its gradient
|
|
53
|
-
* @param writebackDecls list of end-of-step writebacks (e.g. from appendAdam).
|
|
54
|
-
* Empty when there's no optimizer in the graph.
|
|
55
|
-
*/
|
|
56
|
-
export declare function planBuffers(graph: Graph, paramGrads: Record<string, Tensor>, writebackDecls?: WritebackDecl[]): BufferPlan;
|
|
57
|
-
//# sourceMappingURL=buffers.d.ts.map
|
package/dist/buffers.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"buffers.d.ts","sourceRoot":"","sources":["../src/buffers.ts"],"names":[],"mappings":"AAcA,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,SAAS,CAAA;AAElE,MAAM,WAAW,UAAU;IACzB,yBAAyB;IACzB,EAAE,EAAE,MAAM,CAAA;IACV,QAAQ,EAAE,MAAM,CAAA;IAChB,KAAK,EAAE,KAAK,CAAA;IACZ,KAAK,EAAE,KAAK,CAAA;IACZ,IAAI,EAAE,OAAO,GAAG,YAAY,GAAG,cAAc,GAAG,OAAO,GAAG,cAAc,GAAG,QAAQ,CAAA;IACnF,sFAAsF;IACtF,IAAI,EAAE,MAAM,GAAG,IAAI,CAAA;IACnB,gFAAgF;IAChF,SAAS,CAAC,EAAE,MAAM,CAAA;CACnB;AAED;;;;GAIG;AACH,MAAM,WAAW,SAAS;IACxB,MAAM,EAAE,MAAM,CAAA;IACd,IAAI,EAAE,MAAM,CAAA;IACZ,KAAK,EAAE,MAAM,CAAA;CACd;AAED,MAAM,WAAW,UAAU;IACzB,OAAO,EAAE,UAAU,EAAE,CAAA;IACrB,iFAAiF;IACjF,cAAc,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACnC,0CAA0C;IAC1C,YAAY,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACjC,YAAY,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACjC,gBAAgB,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACrC,YAAY,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACjC,cAAc,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IACnC,eAAe,EAAE,MAAM,EAAE,CAAA;IACzB,mEAAmE;IACnE,UAAU,EAAE,SAAS,EAAE,CAAA;CACxB;AAUD;;;GAGG;AACH,MAAM,WAAW,aAAa;IAC5B,0EAA0E;IAC1E,MAAM,EAAE,MAAM,CAAA;IACd,gFAAgF;IAChF,QAAQ,EAAE,MAAM,CAAA;IAChB,QAAQ,EAAE,OAAO,GAAG,OAAO,CAAA;CAC5B;AAED;;;;;;GAMG;AACH,wBAAgB,WAAW,CACzB,KAAK,EAAE,KAAK,EACZ,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,EAClC,cAAc,GAAE,aAAa,EAAO,GACnC,UAAU,CAmGZ"}
|
package/dist/capture.d.ts
DELETED
package/dist/capture.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"capture.d.ts","sourceRoot":"","sources":["../src/capture.ts"],"names":[],"mappings":"AAqBA,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AAGrC,wBAAgB,OAAO,CAAC,CAAC,SAAS,MAAM,EAAE,IAAI,EAAE,MAAM,EAAE,CAAC,EAAE,CAAC,GAAG,CAAC,CAW/D"}
|
package/dist/codegen.d.ts
DELETED
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
import type { Graph, OpNode } from './ir.js';
|
|
2
|
-
import type { BufferPlan } from './buffers.js';
|
|
3
|
-
export interface KernelSpec {
|
|
4
|
-
/** Index into graph.ops. */
|
|
5
|
-
opIndex: number;
|
|
6
|
-
/** Op kind (for debugging / pipeline cache key). */
|
|
7
|
-
opKind: OpNode['kind'];
|
|
8
|
-
/** Generated WGSL source. Empty string for "logical" ops with no kernel. */
|
|
9
|
-
wgsl: string;
|
|
10
|
-
/**
|
|
11
|
-
* Buffer ids in binding-index order. The runtime creates a bind group with
|
|
12
|
-
* these in @binding(0..N) on @group(0). Inputs come first (read), output last
|
|
13
|
-
* (read_write).
|
|
14
|
-
*/
|
|
15
|
-
bindings: number[];
|
|
16
|
-
/** Number of threads to dispatch (1-D). 0 means "skip" (e.g. reshape no-op). */
|
|
17
|
-
threads: number;
|
|
18
|
-
/** Workgroup size; usually WG_SIZE. */
|
|
19
|
-
workgroupSize: number;
|
|
20
|
-
}
|
|
21
|
-
/** Generate a KernelSpec per compute op in graph.ops (in dispatch order). */
|
|
22
|
-
export declare function emitKernels(graph: Graph, plan: BufferPlan): KernelSpec[];
|
|
23
|
-
//# sourceMappingURL=codegen.d.ts.map
|
package/dist/codegen.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"codegen.d.ts","sourceRoot":"","sources":["../src/codegen.ts"],"names":[],"mappings":"AAYA,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,EAAiB,MAAM,SAAS,CAAA;AAC3D,OAAO,KAAK,EAAE,UAAU,EAAE,MAAM,cAAc,CAAA;AAO9C,MAAM,WAAW,UAAU;IACzB,4BAA4B;IAC5B,OAAO,EAAE,MAAM,CAAA;IACf,oDAAoD;IACpD,MAAM,EAAE,MAAM,CAAC,MAAM,CAAC,CAAA;IACtB,4EAA4E;IAC5E,IAAI,EAAE,MAAM,CAAA;IACZ;;;;OAIG;IACH,QAAQ,EAAE,MAAM,EAAE,CAAA;IAClB,gFAAgF;IAChF,OAAO,EAAE,MAAM,CAAA;IACf,uCAAuC;IACvC,aAAa,EAAE,MAAM,CAAA;CACtB;AAMD,6EAA6E;AAC7E,wBAAgB,WAAW,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,EAAE,UAAU,GAAG,UAAU,EAAE,CAQxE"}
|
package/dist/compile.d.ts
DELETED
|
@@ -1,130 +0,0 @@
|
|
|
1
|
-
import type { Tensor, Shape, Dtype } from './ir.js';
|
|
2
|
-
import { type GradResult } from './grad.js';
|
|
3
|
-
import { type AdamConfig } from './adam.js';
|
|
4
|
-
import { type BufferPlan } from './buffers.js';
|
|
5
|
-
import { type KernelSpec } from './codegen.js';
|
|
6
|
-
import { type CompiledRuntime, type CompiledForward, type RuntimeOpts } from './runtime.js';
|
|
7
|
-
import { Module } from './module.js';
|
|
8
|
-
/** Declares one input tensor of the model's forward function. The name is the
|
|
9
|
-
* key in the `inputs:` Record at compile time and the key on the `step()`/
|
|
10
|
-
* `run()` data object at runtime. */
|
|
11
|
-
export interface InputDecl {
|
|
12
|
-
shape: Shape;
|
|
13
|
-
dtype?: Dtype;
|
|
14
|
-
}
|
|
15
|
-
/** Inputs declaration: a Record from input name to its shape/dtype. The name
|
|
16
|
-
* doubles as the key the forward fn destructures and the key the runtime
|
|
17
|
-
* expects in `step({...})` / `run({...})`. */
|
|
18
|
-
export type InputDecls = Record<string, InputDecl>;
|
|
19
|
-
/** Maps an `InputDecls` Record to its forward-time tensor counterpart —
|
|
20
|
-
* same keys, each value is a Tensor. Used to type the forward function's
|
|
21
|
-
* `inputs` argument from the declared shape Record. */
|
|
22
|
-
export type InputsTensors<I extends InputDecls> = {
|
|
23
|
-
[K in keyof I]: Tensor;
|
|
24
|
-
};
|
|
25
|
-
/** Forward function shape: takes the materialized model and a Record of
|
|
26
|
-
* named input tensors (matching the declared `inputs:` keys), returns the
|
|
27
|
-
* output tensor (loss for compileModule; logits/etc. for compileForward).
|
|
28
|
-
* The second generic flows from the inputs declaration so destructuring
|
|
29
|
-
* the input record stays typed. */
|
|
30
|
-
export type ForwardFn<M extends Module, I extends InputDecls = InputDecls> = (m: M, inputs: InputsTensors<I>) => Tensor;
|
|
31
|
-
export interface CompiledIR {
|
|
32
|
-
graph: GradResult['graph'];
|
|
33
|
-
paramGrads: GradResult['paramGrads'];
|
|
34
|
-
loss: Tensor;
|
|
35
|
-
plan: BufferPlan;
|
|
36
|
-
kernels: KernelSpec[];
|
|
37
|
-
}
|
|
38
|
-
/** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
|
|
39
|
-
export declare function compileToIR(traceFn: () => Tensor): CompiledIR;
|
|
40
|
-
/** Full compile pipeline. Browser-only because it creates a GPUDevice. */
|
|
41
|
-
export declare function compile(traceFn: () => Tensor, opts?: RuntimeOpts): Promise<CompiledRuntime & {
|
|
42
|
-
ir: CompiledIR;
|
|
43
|
-
}>;
|
|
44
|
-
export interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
|
|
45
|
-
/** Per-step data inputs to the forward function, keyed by name. The forward
|
|
46
|
-
* fn destructures these out of its second argument; runtime calls to
|
|
47
|
-
* `step()` / `run()` pass typed arrays under the same keys. */
|
|
48
|
-
inputs?: I;
|
|
49
|
-
/** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
|
|
50
|
-
adam?: AdamConfig;
|
|
51
|
-
}
|
|
52
|
-
export interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
|
|
53
|
-
/** Per-step data inputs to the forward function, keyed by name. */
|
|
54
|
-
inputs?: I;
|
|
55
|
-
}
|
|
56
|
-
/** Forward-only compile options as taken by the `compileForward` *method* on
|
|
57
|
-
* a training runtime — no `device` (inherited) and no `sharedParams`
|
|
58
|
-
* (auto-supplied from the train graph's params). */
|
|
59
|
-
export interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
|
|
60
|
-
inputs?: I;
|
|
61
|
-
}
|
|
62
|
-
/** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
|
|
63
|
-
* sibling-graph compile) on top of the base runtime. */
|
|
64
|
-
export interface CompiledModule<M extends Module> extends CompiledRuntime {
|
|
65
|
-
ir: CompiledIR;
|
|
66
|
-
/** Number of dispatchable kernels (excludes leaf no-ops). */
|
|
67
|
-
kernelCount: number;
|
|
68
|
-
/** Re-initialize all params from their declared init specs and zero the
|
|
69
|
-
* optimizer state. Use to start training over without recompiling. */
|
|
70
|
-
reset(): void;
|
|
71
|
-
/** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
|
|
72
|
-
* B=N held-out eval graph) that shares this runtime's device and param
|
|
73
|
-
* buffers. Pass the forward fn (typically distinct from your loss fn —
|
|
74
|
-
* it returns logits, not a scalar) and any shape changes via `inputs`.
|
|
75
|
-
* Auto-initialization is a no-op since params are shared. */
|
|
76
|
-
compileForward<I extends InputDecls>(forward: ForwardFn<M, I>, opts?: CompileForwardMethodOptions<I>): Promise<CompiledForwardModule>;
|
|
77
|
-
}
|
|
78
|
-
/** Returned by `compileForward` (and by the `compileForward` method). */
|
|
79
|
-
export interface CompiledForwardModule extends CompiledForward {
|
|
80
|
-
ir: CompiledIR;
|
|
81
|
-
/** Number of dispatchable kernels (excludes leaf no-ops). */
|
|
82
|
-
kernelCount: number;
|
|
83
|
-
}
|
|
84
|
-
/**
|
|
85
|
-
* Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
|
|
86
|
-
* model instance itself: compilation mutates the tree (every `ParamSentinel`
|
|
87
|
-
* field becomes a real `Tensor`), so the instance is consumed and shouldn't be
|
|
88
|
-
* referenced afterwards. Re-call the factory if you need a fresh tree.
|
|
89
|
-
*
|
|
90
|
-
* The forward function takes the materialized model and a Record of named
|
|
91
|
-
* input tensors, returns the loss tensor. Inputs are matched by name with the
|
|
92
|
-
* `inputs:` declaration:
|
|
93
|
-
*
|
|
94
|
-
* inputs: {
|
|
95
|
-
* tokens: { shape: [B, T], dtype: 'i32' },
|
|
96
|
-
* targets: { shape: [B, T], dtype: 'i32' },
|
|
97
|
-
* }
|
|
98
|
-
* forward: (m, { tokens, targets }) => …
|
|
99
|
-
*
|
|
100
|
-
* Walks the module tree to materialize params with auto-derived names, then
|
|
101
|
-
* runs trace → grad → adam → buffer plan → codegen → runtime. Initial
|
|
102
|
-
* parameter values are uploaded automatically before this function returns;
|
|
103
|
-
* call `reset()` later to re-randomize.
|
|
104
|
-
*
|
|
105
|
-
* If `opts.adam` is set, the runtime's `step()` automatically tracks an
|
|
106
|
-
* internal step count and injects the bias-corrected `lrt` scalar each call;
|
|
107
|
-
* users don't need to provide it themselves.
|
|
108
|
-
*/
|
|
109
|
-
export declare function compileModule<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileModuleOptions<I>): Promise<CompiledModule<M>>;
|
|
110
|
-
/**
|
|
111
|
-
* Compile a Module-based model in forward-only mode (no autograd, no Adam).
|
|
112
|
-
* The forward function returns the output tensor (e.g., logits) instead of a
|
|
113
|
-
* scalar loss; runtime exposes `run(inputs)` returning the full output as a
|
|
114
|
-
* `Float32Array`.
|
|
115
|
-
*
|
|
116
|
-
* **Prefer the `compileForward` method on a training runtime** when both
|
|
117
|
-
* graphs use the same Module class — it auto-supplies `device` and
|
|
118
|
-
* `sharedParams`. This standalone form is for forward-only models with no
|
|
119
|
-
* training graph at all, or for sharing params across a different model.
|
|
120
|
-
*
|
|
121
|
-
* **Sharing params with a training compile.** Pass `opts.sharedParams =
|
|
122
|
-
* trainCompiled.params` to bind this graph's param buffers to an existing
|
|
123
|
-
* training runtime's GPU buffers — every train step is then immediately
|
|
124
|
-
* visible to `run()` calls here, no copies.
|
|
125
|
-
*
|
|
126
|
-
* Initial param values are uploaded automatically for params *not* covered
|
|
127
|
-
* by `sharedParams` (those are owned by the sibling compile).
|
|
128
|
-
*/
|
|
129
|
-
export declare function compileForward<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileForwardOptions<I>): Promise<CompiledForwardModule>;
|
|
130
|
-
//# sourceMappingURL=compile.d.ts.map
|
package/dist/compile.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"compile.d.ts","sourceRoot":"","sources":["../src/compile.ts"],"names":[],"mappings":"AAUA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAEnD,OAAO,EAAc,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAc,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAe,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAe,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAuC,KAAK,eAAe,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,MAAM,cAAc,CAAA;AAChI,OAAO,EAAE,MAAM,EAAqB,MAAM,aAAa,CAAA;AAEvD;;sCAEsC;AACtC,MAAM,WAAW,SAAS;IACxB,KAAK,EAAE,KAAK,CAAA;IACZ,KAAK,CAAC,EAAE,KAAK,CAAA;CACd;AAED;;+CAE+C;AAC/C,MAAM,MAAM,UAAU,GAAG,MAAM,CAAC,MAAM,EAAE,SAAS,CAAC,CAAA;AAElD;;wDAEwD;AACxD,MAAM,MAAM,aAAa,CAAC,CAAC,SAAS,UAAU,IAAI;KAAG,CAAC,IAAI,MAAM,CAAC,GAAG,MAAM;CAAE,CAAA;AAE5E;;;;oCAIoC;AACpC,MAAM,MAAM,SAAS,CAAC,CAAC,SAAS,MAAM,EAAE,CAAC,SAAS,UAAU,GAAG,UAAU,IACvE,CAAC,CAAC,EAAE,CAAC,EAAE,MAAM,EAAE,aAAa,CAAC,CAAC,CAAC,KAAK,MAAM,CAAA;AAE5C,MAAM,WAAW,UAAU;IACzB,KAAK,EAAE,UAAU,CAAC,OAAO,CAAC,CAAA;IAC1B,UAAU,EAAE,UAAU,CAAC,YAAY,CAAC,CAAA;IACpC,IAAI,EAAE,MAAM,CAAA;IACZ,IAAI,EAAE,UAAU,CAAA;IAChB,OAAO,EAAE,UAAU,EAAE,CAAA;CACtB;AAED,yEAAyE;AACzE,wBAAgB,WAAW,CAAC,OAAO,EAAE,MAAM,MAAM,GAAG,UAAU,CAM7D;AAED,0EAA0E;AAC1E,wBAAsB,OAAO,CAAC,OAAO,EAAE,MAAM,MAAM,EAAE,IAAI,GAAE,WAAgB,GAAG,OAAO,CAAC,eAAe,GAAG;IAAE,EAAE,EAAE,UAAU,CAAA;CAAE,CAAC,CAK1H;AAMD,MAAM,WAAW,oBAAoB,CAAC,CAAC,SAAS,UAAU,GAAG,UAAU,CAAE,SAAQ,WAAW;IAC1F;;oEAEgE;IAChE,MAAM,CAAC,EAAE,CAAC,CAAA;IACV,iFAAiF;IACjF,IAAI,CAAC,EAAE,UAAU,CAAA;CAClB;AAED,MAAM,WAAW,qBAAqB,CAAC,CAAC,SAAS,UAAU,GAAG,UAAU,CAAE,SAAQ,WAAW;IAC3F,mEAAmE;IACnE,MAAM,CAAC,EAAE,CAAC,CAAA;CACX;AAED;;qDAEqD;AACrD,MAAM,WAAW,2BAA2B,CAAC,CAAC,SAAS,UAAU,GAAG,UAAU;IAC5E,MAAM,CAAC,EAAE,CAAC,CAAA;CACX;AAED;yDACyD;AACzD,MAAM,WAAW,cAAc,CAAC,CAAC,SAAS,MAAM,CAAE,SAAQ,eAAe;IACvE,EAAE,EAAE,UAAU,CAAA;IACd,6DAA6D;IAC7D,WAAW,EAAE,MAAM,CAAA;IACnB;2EACuE;IACvE,KAAK,IAAI,IAAI,CAAA;IACb;;;;kEAI8D;IAC9D,cAAc,CAAC,CAAC,SAAS,UAAU,EACjC,OAAO,EAAE,SAAS,CAAC,CAAC,EAAE,CAAC,CAAC,EACxB,IAAI,CAAC,EAAE,2BAA2B,CAAC,CAAC,CAAC,GACpC,OAAO,CAAC,qBAAqB,CAAC,CAAA;CAClC;AAED,yEAAyE;AACzE,MAAM,WAAW,qBAAsB,SAAQ,eAAe;IAC5D,EAAE,EAAE,UAAU,CAAA;IACd,6DAA6D;IAC7D,WAAW,EAAE,MAAM,CAAA;CACpB;AAED;;;;;;;;;;;;;;;;;;;;;;;;GAwBG;AACH,wBAAsB,aAAa,CAAC,CAAC,SAAS,MAAM,EAAE,CAAC,SAAS,UAAU,GAAG,UAAU,EACrF,YAAY,EAAE,MAAM,CAAC,EACrB,OAAO,EAAE,SAAS,CAAC,CAAC,EAAE,CAAC,CAAC,EACxB,IAAI,GAAE,oBAAoB,CAAC,CAAC,CAAM,GACjC,OAAO,CAAC,cAAc,CAAC,CAAC,CAAC,CAAC,CAyC5B;AAMD;;;;;;;;;;;;;;;;;;GAkBG;AACH,wBAAsB,cAAc,CAAC,CAAC,SAAS,MAAM,EAAE,CAAC,SAAS,UAAU,GAAG,UAAU,EACtF,YAAY,EAAE,MAAM,CAAC,EACrB,OAAO,EAAE,SAAS,CAAC,CAAC,EAAE,CAAC,CAAC,EACxB,IAAI,GAAE,qBAAqB,CAAC,CAAC,CAAM,GAClC,OAAO,CAAC,qBAAqB,CAAC,CAYhC"}
|
package/dist/grad.d.ts
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
import type { Graph, Tensor } from './ir.js';
|
|
2
|
-
export interface GradResult {
|
|
3
|
-
readonly graph: Graph;
|
|
4
|
-
readonly paramGrads: Record<string, Tensor>;
|
|
5
|
-
readonly loss: Tensor;
|
|
6
|
-
}
|
|
7
|
-
export declare function appendGrad(graph: Graph): GradResult;
|
|
8
|
-
//# sourceMappingURL=grad.d.ts.map
|
package/dist/grad.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"grad.d.ts","sourceRoot":"","sources":["../src/grad.ts"],"names":[],"mappings":"AAiBA,OAAO,KAAK,EAAE,KAAK,EAAU,MAAM,EAAS,MAAM,SAAS,CAAA;AAe3D,MAAM,WAAW,UAAU;IAEzB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IAErB,QAAQ,CAAC,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IAE3C,QAAQ,CAAC,IAAI,EAAE,MAAM,CAAA;CACtB;AASD,wBAAgB,UAAU,CAAC,KAAK,EAAE,KAAK,GAAG,UAAU,CAmDnD"}
|
package/dist/index.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAKA,YAAY,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAC5E,OAAO,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AACvC,OAAO,EAAE,KAAK,EAAE,SAAS,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,MAAM,YAAY,CAAA;AAClF,OAAO,EAAE,OAAO,EAAE,MAAM,cAAc,CAAA;AACtC,OAAO,EAEL,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAElB,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,GAAG,EAAE,IAAI,EAE3B,IAAI,EAAE,OAAO,EAAE,KAAK,EAEpB,QAAQ,EAAE,OAAO,EAAE,MAAM,EAEzB,OAAO,EAAE,SAAS,EAAE,QAAQ,EAE5B,MAAM,EAAE,aAAa,EAErB,MAAM,EAAE,MAAM,EAAE,SAAS,EAEzB,iBAAiB,EAAE,cAAc,EAAE,WAAW,EAE9C,cAAc,GACf,MAAM,UAAU,CAAA;AAMjB,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACvD,OAAO,EAAE,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,MAAM,WAAW,CAAA;AACxE,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,KAAK,aAAa,EAAE,MAAM,cAAc,CAAA;AAChH,OAAO,EAAE,WAAW,EAAE,KAAK,UAAU,EAAE,MAAM,cAAc,CAAA;AAC3D,OAAO,EAAE,aAAa,EAAE,oBAAoB,EAAE,QAAQ,EAAE,KAAK,eAAe,EAAE,KAAK,eAAe,EAAE,KAAK,WAAW,EAAE,KAAK,UAAU,EAAE,KAAK,UAAU,EAAE,KAAK,SAAS,EAAE,MAAM,cAAc,CAAA;AAC5L,OAAO,EACL,OAAO,EAAE,WAAW,EAAE,aAAa,EAAE,cAAc,EACnD,KAAK,UAAU,EAAE,KAAK,oBAAoB,EAAE,KAAK,qBAAqB,EAAE,KAAK,2BAA2B,EACxG,KAAK,cAAc,EAAE,KAAK,qBAAqB,EAC/C,KAAK,SAAS,EAAE,KAAK,UAAU,EAAE,KAAK,aAAa,EAAE,KAAK,SAAS,GACpE,MAAM,cAAc,CAAA;AACrB,OAAO,EAAE,MAAM,EAAE,iBAAiB,EAAE,KAAK,QAAQ,EAAE,KAAK,YAAY,EAAE,KAAK,kBAAkB,EAAE,MAAM,aAAa,CAAA;AAClH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAA"}
|
package/dist/ir.d.ts
DELETED
|
@@ -1,207 +0,0 @@
|
|
|
1
|
-
export type Dtype = 'f32' | 'i32' | 'bool';
|
|
2
|
-
export type Shape = readonly number[];
|
|
3
|
-
export interface Tensor {
|
|
4
|
-
readonly id: number;
|
|
5
|
-
readonly shape: Shape;
|
|
6
|
-
readonly dtype: Dtype;
|
|
7
|
-
readonly source: number | null;
|
|
8
|
-
readonly site: CallSite | null;
|
|
9
|
-
}
|
|
10
|
-
export interface CallSite {
|
|
11
|
-
readonly opName: string;
|
|
12
|
-
readonly stack: string;
|
|
13
|
-
}
|
|
14
|
-
export type OpNode = {
|
|
15
|
-
kind: 'param_input';
|
|
16
|
-
out: number;
|
|
17
|
-
name: string;
|
|
18
|
-
} | {
|
|
19
|
-
kind: 'tensor_input';
|
|
20
|
-
out: number;
|
|
21
|
-
name: string;
|
|
22
|
-
} | {
|
|
23
|
-
kind: 'state_input';
|
|
24
|
-
out: number;
|
|
25
|
-
name: string;
|
|
26
|
-
initValue: number;
|
|
27
|
-
} | {
|
|
28
|
-
kind: 'add';
|
|
29
|
-
out: number;
|
|
30
|
-
a: number;
|
|
31
|
-
b: number;
|
|
32
|
-
} | {
|
|
33
|
-
kind: 'sub';
|
|
34
|
-
out: number;
|
|
35
|
-
a: number;
|
|
36
|
-
b: number;
|
|
37
|
-
} | {
|
|
38
|
-
kind: 'mul';
|
|
39
|
-
out: number;
|
|
40
|
-
a: number;
|
|
41
|
-
b: number;
|
|
42
|
-
} | {
|
|
43
|
-
kind: 'div';
|
|
44
|
-
out: number;
|
|
45
|
-
a: number;
|
|
46
|
-
b: number;
|
|
47
|
-
} | {
|
|
48
|
-
kind: 'mul_scalar';
|
|
49
|
-
out: number;
|
|
50
|
-
a: number;
|
|
51
|
-
scalar: number;
|
|
52
|
-
} | {
|
|
53
|
-
kind: 'add_scalar';
|
|
54
|
-
out: number;
|
|
55
|
-
a: number;
|
|
56
|
-
scalar: number;
|
|
57
|
-
} | {
|
|
58
|
-
kind: 'sqrt';
|
|
59
|
-
out: number;
|
|
60
|
-
a: number;
|
|
61
|
-
} | {
|
|
62
|
-
kind: 'rsqrt';
|
|
63
|
-
out: number;
|
|
64
|
-
a: number;
|
|
65
|
-
} | {
|
|
66
|
-
kind: 'log';
|
|
67
|
-
out: number;
|
|
68
|
-
a: number;
|
|
69
|
-
} | {
|
|
70
|
-
kind: 'exp';
|
|
71
|
-
out: number;
|
|
72
|
-
a: number;
|
|
73
|
-
} | {
|
|
74
|
-
kind: 'relu';
|
|
75
|
-
out: number;
|
|
76
|
-
a: number;
|
|
77
|
-
} | {
|
|
78
|
-
kind: 'mean_last';
|
|
79
|
-
out: number;
|
|
80
|
-
a: number;
|
|
81
|
-
} | {
|
|
82
|
-
kind: 'sum_last';
|
|
83
|
-
out: number;
|
|
84
|
-
a: number;
|
|
85
|
-
} | {
|
|
86
|
-
kind: 'reshape';
|
|
87
|
-
out: number;
|
|
88
|
-
a: number;
|
|
89
|
-
newShape: Shape;
|
|
90
|
-
} | {
|
|
91
|
-
kind: 'transpose';
|
|
92
|
-
out: number;
|
|
93
|
-
a: number;
|
|
94
|
-
perm: readonly number[];
|
|
95
|
-
} | {
|
|
96
|
-
kind: 'matmul';
|
|
97
|
-
out: number;
|
|
98
|
-
a: number;
|
|
99
|
-
b: number;
|
|
100
|
-
} | {
|
|
101
|
-
kind: 'matmul_batched';
|
|
102
|
-
out: number;
|
|
103
|
-
a: number;
|
|
104
|
-
b: number;
|
|
105
|
-
} | {
|
|
106
|
-
kind: 'one_hot';
|
|
107
|
-
out: number;
|
|
108
|
-
indices: number;
|
|
109
|
-
depth: number;
|
|
110
|
-
dtype: Dtype;
|
|
111
|
-
} | {
|
|
112
|
-
kind: 'arange';
|
|
113
|
-
out: number;
|
|
114
|
-
n: number;
|
|
115
|
-
dtype: Dtype;
|
|
116
|
-
} | {
|
|
117
|
-
kind: 'softmax_causal_last';
|
|
118
|
-
out: number;
|
|
119
|
-
a: number;
|
|
120
|
-
} | {
|
|
121
|
-
kind: 'log_softmax_last';
|
|
122
|
-
out: number;
|
|
123
|
-
a: number;
|
|
124
|
-
} | {
|
|
125
|
-
kind: 'where_causal';
|
|
126
|
-
out: number;
|
|
127
|
-
a: number;
|
|
128
|
-
fillValue: number;
|
|
129
|
-
} | {
|
|
130
|
-
kind: 'less';
|
|
131
|
-
out: number;
|
|
132
|
-
a: number;
|
|
133
|
-
b: number;
|
|
134
|
-
} | {
|
|
135
|
-
kind: 'greater';
|
|
136
|
-
out: number;
|
|
137
|
-
a: number;
|
|
138
|
-
b: number;
|
|
139
|
-
} | {
|
|
140
|
-
kind: 'where';
|
|
141
|
-
out: number;
|
|
142
|
-
cond: number;
|
|
143
|
-
a: number;
|
|
144
|
-
b: number;
|
|
145
|
-
} | {
|
|
146
|
-
kind: 'adam_update_m';
|
|
147
|
-
out: number;
|
|
148
|
-
m: number;
|
|
149
|
-
g: number;
|
|
150
|
-
b1: number;
|
|
151
|
-
} | {
|
|
152
|
-
kind: 'adam_update_v';
|
|
153
|
-
out: number;
|
|
154
|
-
v: number;
|
|
155
|
-
g: number;
|
|
156
|
-
b2: number;
|
|
157
|
-
} | {
|
|
158
|
-
kind: 'adam_update_p';
|
|
159
|
-
out: number;
|
|
160
|
-
p: number;
|
|
161
|
-
mNew: number;
|
|
162
|
-
vNew: number;
|
|
163
|
-
lrt: number;
|
|
164
|
-
eps: number;
|
|
165
|
-
decayShrink: number;
|
|
166
|
-
decayShrinkTensor: number | null;
|
|
167
|
-
} | {
|
|
168
|
-
kind: 'slice_last_range';
|
|
169
|
-
out: number;
|
|
170
|
-
a: number;
|
|
171
|
-
start: number;
|
|
172
|
-
end: number;
|
|
173
|
-
} | {
|
|
174
|
-
kind: 'broadcast_to';
|
|
175
|
-
out: number;
|
|
176
|
-
a: number;
|
|
177
|
-
targetShape: Shape;
|
|
178
|
-
} | {
|
|
179
|
-
kind: 'sum_to_shape';
|
|
180
|
-
out: number;
|
|
181
|
-
a: number;
|
|
182
|
-
targetShape: Shape;
|
|
183
|
-
} | {
|
|
184
|
-
kind: 'const_scalar';
|
|
185
|
-
out: number;
|
|
186
|
-
value: number;
|
|
187
|
-
dtype: Dtype;
|
|
188
|
-
} | {
|
|
189
|
-
kind: 'relu_grad';
|
|
190
|
-
out: number;
|
|
191
|
-
x: number;
|
|
192
|
-
dy: number;
|
|
193
|
-
};
|
|
194
|
-
export interface Graph {
|
|
195
|
-
readonly ops: OpNode[];
|
|
196
|
-
readonly tensors: Tensor[];
|
|
197
|
-
readonly outputs: number[];
|
|
198
|
-
readonly captures: Map<string, number>;
|
|
199
|
-
}
|
|
200
|
-
export declare function makeGraph(): Graph;
|
|
201
|
-
export declare function addTensor(g: Graph, shape: Shape, dtype: Dtype, source: number | null, site: CallSite | null): Tensor;
|
|
202
|
-
export declare function addOp<K extends OpNode['kind']>(g: Graph, kind: K, shape: Shape, dtype: Dtype, site: CallSite | null, fields: Omit<Extract<OpNode, {
|
|
203
|
-
kind: K;
|
|
204
|
-
}>, 'kind' | 'out'>): Tensor;
|
|
205
|
-
export declare function captureSite(opName: string): CallSite;
|
|
206
|
-
export declare function formatSite(site: CallSite): string;
|
|
207
|
-
//# sourceMappingURL=ir.d.ts.map
|
package/dist/ir.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"ir.d.ts","sourceRoot":"","sources":["../src/ir.ts"],"names":[],"mappings":"AAcA,MAAM,MAAM,KAAK,GAAG,KAAK,GAAG,KAAK,GAAG,MAAM,CAAA;AAC1C,MAAM,MAAM,KAAK,GAAG,SAAS,MAAM,EAAE,CAAA;AAIrC,MAAM,WAAW,MAAM;IACrB,QAAQ,CAAC,EAAE,EAAE,MAAM,CAAA;IACnB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IACrB,QAAQ,CAAC,KAAK,EAAE,KAAK,CAAA;IAErB,QAAQ,CAAC,MAAM,EAAE,MAAM,GAAG,IAAI,CAAA;IAG9B,QAAQ,CAAC,IAAI,EAAE,QAAQ,GAAG,IAAI,CAAA;CAC/B;AAED,MAAM,WAAW,QAAQ;IACvB,QAAQ,CAAC,MAAM,EAAE,MAAM,CAAA;IAEvB,QAAQ,CAAC,KAAK,EAAE,MAAM,CAAA;CACvB;AAQD,MAAM,MAAM,MAAM,GAGd;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAElD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAA;CAAE,GAInD;IAAE,IAAI,EAAE,aAAa,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAGrE;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAClD;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAC9D;IAAE,IAAI,EAAE,YAAY,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,MAAM,EAAE,MAAM,CAAA;CAAE,GAG9D;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACxC;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACzC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,KAAK,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvC;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGxC;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAC7C;IAAE,IAAI,EAAE,UAAU,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG5C;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,QAAQ,EAAE,KAAK,CAAA;CAAE,GAC5D;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,SAAS,MAAM,EAAE,CAAA;CAAE,GAMtE;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAErD;IAAE,IAAI,EAAE,gBAAgB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAG7D;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,OAAO,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAC9E;IAAE,IAAI,EAAE,QAAQ,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAGxD;IAAE,IAAI,EAAE,qBAAqB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACvD;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAIpD;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,SAAS,EAAE,MAAM,CAAA;CAAE,GAKnE;IAAE,IAAI,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GACnD;IAAE,IAAI,EAAE,SAAS,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAGtD;IAAE,IAAI,EAAE,OAAO,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,IAAI,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAA;CAAE,GAMlE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GACxE;IAAE,IAAI,EAAE,eAAe,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,GASxE;IACE,IAAI,EAAE,eAAe,CAAA;IACrB,GAAG,EAAE,MAAM,CAAA;IACX,CAAC,EAAE,MAAM,CAAA;IACT,IAAI,EAAE,MAAM,CAAA;IACZ,IAAI,EAAE,MAAM,CAAA;IACZ,GAAG,EAAE,MAAM,CAAA;IACX,GAAG,EAAE,MAAM,CAAA;IACX,WAAW,EAAE,MAAM,CAAA;IACnB,iBAAiB,EAAE,MAAM,GAAG,IAAI,CAAA;CACjC,GAMD;IAAE,IAAI,EAAE,kBAAkB,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,GAAG,EAAE,MAAM,CAAA;CAAE,GAGhF;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAGpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,WAAW,EAAE,KAAK,CAAA;CAAE,GAEpE;IAAE,IAAI,EAAE,cAAc,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,MAAM,CAAC;IAAC,KAAK,EAAE,KAAK,CAAA;CAAE,GAElE;IAAE,IAAI,EAAE,WAAW,CAAC;IAAC,GAAG,EAAE,MAAM,CAAC;IAAC,CAAC,EAAE,MAAM,CAAC;IAAC,EAAE,EAAE,MAAM,CAAA;CAAE,CAAA;AAI7D,MAAM,WAAW,KAAK;IACpB,QAAQ,CAAC,GAAG,EAAE,MAAM,EAAE,CAAA;IACtB,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;IAG1B,QAAQ,CAAC,OAAO,EAAE,MAAM,EAAE,CAAA;IAI1B,QAAQ,CAAC,QAAQ,EAAE,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;CACvC;AAED,wBAAgB,SAAS,IAAI,KAAK,CAEjC;AAGD,wBAAgB,SAAS,CAAC,CAAC,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,MAAM,GAAG,IAAI,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,MAAM,CAKpH;AAMD,wBAAgB,KAAK,CAAC,CAAC,SAAS,MAAM,CAAC,MAAM,CAAC,EAC5C,CAAC,EAAE,KAAK,EACR,IAAI,EAAE,CAAC,EACP,KAAK,EAAE,KAAK,EACZ,KAAK,EAAE,KAAK,EACZ,IAAI,EAAE,QAAQ,GAAG,IAAI,EACrB,MAAM,EAAE,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE;IAAE,IAAI,EAAE,CAAC,CAAA;CAAE,CAAC,EAAE,MAAM,GAAG,KAAK,CAAC,GACzD,MAAM,CAMR;AAID,wBAAgB,WAAW,CAAC,MAAM,EAAE,MAAM,GAAG,QAAQ,CAIpD;AAID,wBAAgB,UAAU,CAAC,IAAI,EAAE,QAAQ,GAAG,MAAM,CAYjD"}
|
package/dist/module.d.ts
DELETED
|
@@ -1,55 +0,0 @@
|
|
|
1
|
-
import type { Tensor, Shape, Dtype } from './ir.js';
|
|
2
|
-
/** How a parameter's initial values are produced.
|
|
3
|
-
* - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
|
|
4
|
-
* weight matrices and embeddings.
|
|
5
|
-
* - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
|
|
6
|
-
* - `'ones'` — fill with 1. Common for LayerNorm gain.
|
|
7
|
-
* - Custom function — receives total element count and shape, returns the
|
|
8
|
-
* Float32Array. Use for fan-in scaling or any non-standard scheme.
|
|
9
|
-
*/
|
|
10
|
-
export type InitSpec = 'randn' | 'zeros' | 'ones' | ((size: number, shape: readonly number[]) => Float32Array);
|
|
11
|
-
export interface ParamOptions {
|
|
12
|
-
dtype?: Dtype;
|
|
13
|
-
/** Init kind. Default: `'randn'`. */
|
|
14
|
-
init?: InitSpec;
|
|
15
|
-
/** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
|
|
16
|
-
scale?: number;
|
|
17
|
-
/** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
|
|
18
|
-
* decay to this param. Default: `true` for `'randn'` init (weight matrices,
|
|
19
|
-
* embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
|
|
20
|
-
* to force or skip. Replaces `adam.decayFilter` for the common case. */
|
|
21
|
-
decay?: boolean;
|
|
22
|
-
}
|
|
23
|
-
type InitFn = (size: number, shape: readonly number[]) => Float32Array;
|
|
24
|
-
export declare abstract class Module {
|
|
25
|
-
/**
|
|
26
|
-
* Declare a learnable parameter at this module. Must be called from inside
|
|
27
|
-
* the constructor (typically as a field assignment). Returns a placeholder
|
|
28
|
-
* that gets replaced with a real Tensor at compile time.
|
|
29
|
-
*
|
|
30
|
-
* The parameter's name is auto-derived from its property path in the model
|
|
31
|
-
* tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
|
|
32
|
-
* call `compiled.uploadInitialParams()` to apply it after compile.
|
|
33
|
-
*/
|
|
34
|
-
protected param(shape: Shape, opts?: ParamOptions): Tensor;
|
|
35
|
-
}
|
|
36
|
-
export interface MaterializedParams {
|
|
37
|
-
/** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
|
|
38
|
-
tensors: Record<string, Tensor>;
|
|
39
|
-
/** Init function per param path. Used by `uploadInitialParams`. */
|
|
40
|
-
initFns: Record<string, InitFn>;
|
|
41
|
-
/** Whether this param should receive AdamW weight decay. Resolved at
|
|
42
|
-
* `param()` time from `ParamOptions.decay` (with init-based default). */
|
|
43
|
-
decayFlags: Record<string, boolean>;
|
|
44
|
-
}
|
|
45
|
-
/**
|
|
46
|
-
* Walk the module tree and replace every ParamSentinel with a real Tensor
|
|
47
|
-
* created via `paramInput(autoName, ...)`. Must be called inside an active
|
|
48
|
-
* trace context (paramInput appends to the current graph).
|
|
49
|
-
*
|
|
50
|
-
* Returns the param tensors keyed by path, plus init functions for use by
|
|
51
|
-
* `uploadInitialParams`.
|
|
52
|
-
*/
|
|
53
|
-
export declare function materializeParams(root: Module): MaterializedParams;
|
|
54
|
-
export {};
|
|
55
|
-
//# sourceMappingURL=module.d.ts.map
|
package/dist/module.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"module.d.ts","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AA2BA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAOnD;;;;;;;GAOG;AACH,MAAM,MAAM,QAAQ,GAChB,OAAO,GACP,OAAO,GACP,MAAM,GACN,CAAC,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,KAAK,YAAY,CAAC,CAAA;AAE9D,MAAM,WAAW,YAAY;IAC3B,KAAK,CAAC,EAAE,KAAK,CAAA;IACb,qCAAqC;IACrC,IAAI,CAAC,EAAE,QAAQ,CAAA;IACf,uEAAuE;IACvE,KAAK,CAAC,EAAE,MAAM,CAAA;IACd;;;6EAGyE;IACzE,KAAK,CAAC,EAAE,OAAO,CAAA;CAChB;AAED,KAAK,MAAM,GAAG,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,KAAK,YAAY,CAAA;AAuDtE,8BAAsB,MAAM;IAC1B;;;;;;;;OAQG;IACH,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,CAAC,EAAE,YAAY,GAAG,MAAM;CAK3D;AAMD,MAAM,WAAW,kBAAkB;IACjC,2EAA2E;IAC3E,OAAO,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IAC/B,mEAAmE;IACnE,OAAO,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IAC/B;8EAC0E;IAC1E,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAA;CACpC;AAED;;;;;;;GAOG;AACH,wBAAgB,iBAAiB,CAAC,IAAI,EAAE,MAAM,GAAG,kBAAkB,CAclE"}
|