tensorgrad 0.0.11 → 0.0.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +119 -119
- package/dist/buffers.js +1 -6
- package/dist/buffers.js.map +1 -1
- package/dist/codegen.js +30 -28
- package/dist/codegen.js.map +1 -1
- package/dist/compile.js +39 -68
- package/dist/compile.js.map +1 -1
- package/dist/grad.js +1 -14
- package/dist/grad.js.map +1 -1
- package/dist/index.d.ts +740 -14
- package/dist/runtime.js +9 -11
- package/dist/runtime.js.map +1 -1
- package/dist/trace.js +8 -13
- package/dist/trace.js.map +1 -1
- package/package.json +67 -61
- package/src/buffers.ts +1 -6
- package/src/codegen.ts +31 -28
- package/src/compile.ts +45 -91
- package/src/grad.ts +1 -11
- package/src/index.ts +47 -47
- package/src/runtime.ts +520 -515
- package/src/trace.ts +12 -9
- package/dist/adam.d.ts +0 -65
- package/dist/adam.d.ts.map +0 -1
- package/dist/buffers.d.ts +0 -57
- package/dist/buffers.d.ts.map +0 -1
- package/dist/capture.d.ts +0 -3
- package/dist/capture.d.ts.map +0 -1
- package/dist/codegen.d.ts +0 -23
- package/dist/codegen.d.ts.map +0 -1
- package/dist/compile.d.ts +0 -130
- package/dist/compile.d.ts.map +0 -1
- package/dist/grad.d.ts +0 -8
- package/dist/grad.d.ts.map +0 -1
- package/dist/index.d.ts.map +0 -1
- package/dist/ir.d.ts +0 -207
- package/dist/ir.d.ts.map +0 -1
- package/dist/module.d.ts +0 -55
- package/dist/module.d.ts.map +0 -1
- package/dist/nn.d.ts +0 -42
- package/dist/nn.d.ts.map +0 -1
- package/dist/ops.d.ts +0 -48
- package/dist/ops.d.ts.map +0 -1
- package/dist/runtime.d.ts +0 -108
- package/dist/runtime.d.ts.map +0 -1
- package/dist/shape.d.ts +0 -24
- package/dist/shape.d.ts.map +0 -1
- package/dist/trace.d.ts +0 -9
- package/dist/trace.d.ts.map +0 -1
package/dist/nn.d.ts
DELETED
|
@@ -1,42 +0,0 @@
|
|
|
1
|
-
import { Module } from './module.js';
|
|
2
|
-
import type { Tensor } from './ir.js';
|
|
3
|
-
import type { Captures } from './runtime.js';
|
|
4
|
-
export interface LinearOptions {
|
|
5
|
-
/** Include a bias term (default true). */
|
|
6
|
-
bias?: boolean;
|
|
7
|
-
}
|
|
8
|
-
export declare class Linear extends Module {
|
|
9
|
-
readonly inDim: number;
|
|
10
|
-
readonly outDim: number;
|
|
11
|
-
W: Tensor;
|
|
12
|
-
b: Tensor | null;
|
|
13
|
-
constructor(inDim: number, outDim: number, opts?: LinearOptions);
|
|
14
|
-
fwd(x: Tensor): Tensor;
|
|
15
|
-
}
|
|
16
|
-
export declare class LayerNorm extends Module {
|
|
17
|
-
readonly d: number;
|
|
18
|
-
readonly eps: number;
|
|
19
|
-
g: Tensor;
|
|
20
|
-
b: Tensor;
|
|
21
|
-
constructor(d: number, eps?: number);
|
|
22
|
-
fwd(x: Tensor): Tensor;
|
|
23
|
-
}
|
|
24
|
-
/** [..., T, D] → [..., H, T, D/H]. Folds the standard
|
|
25
|
-
* `transpose(reshape(x, [..., T, H, d]), [..., H, T, d])` pattern into one
|
|
26
|
-
* call. Last dim of `x` must divide evenly by `nHeads`. */
|
|
27
|
-
export declare function splitHeads(x: Tensor, nHeads: number): Tensor;
|
|
28
|
-
/** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
|
|
29
|
-
export declare function mergeHeads(x: Tensor): Tensor;
|
|
30
|
-
/** Slice a captured tensor named `name` into one Float32Array per head, using
|
|
31
|
-
* the static shape registered at compile time. The leading axis is treated as
|
|
32
|
-
* heads (matching `splitHeads` layout at B=1); a leading singleton batch is
|
|
33
|
-
* stripped if present so callers can pass capture names directly. Throws if
|
|
34
|
-
* the capture isn't registered or wasn't read back this call. */
|
|
35
|
-
export declare function unsplitHeads(captures: Captures, name: string): Float32Array[];
|
|
36
|
-
/** Per-position cross-entropy along the last (vocab) axis: returns
|
|
37
|
-
* `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
|
|
38
|
-
* `[...]` of i32; result is `[...]` (one rank less than logits). The user
|
|
39
|
-
* applies their own masking + reduction downstream — useful when only some
|
|
40
|
-
* positions contribute (e.g. result-digit masking) or for label smoothing. */
|
|
41
|
-
export declare function crossEntropyLast(logits: Tensor, targets: Tensor): Tensor;
|
|
42
|
-
//# sourceMappingURL=nn.d.ts.map
|
package/dist/nn.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAaA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AAIrC,OAAO,KAAK,EAAE,QAAQ,EAAE,MAAM,cAAc,CAAA;AAM5C,MAAM,WAAW,aAAa;IAC5B,0CAA0C;IAC1C,IAAI,CAAC,EAAE,OAAO,CAAA;CACf;AAED,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,IAAI,GAAE,aAAkB;IAKnG,GAAG,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM;CAIvB;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;IAKzE,GAAG,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM;CAOvB;AAOD;;4DAE4D;AAC5D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAa5D;AAED,+DAA+D;AAC/D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAW5C;AAED;;;;kEAIkE;AAClE,wBAAgB,YAAY,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,EAAE,MAAM,GAAG,YAAY,EAAE,CAiB7E;AAMD;;;;+EAI+E;AAC/E,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CASxE"}
|
package/dist/ops.d.ts
DELETED
|
@@ -1,48 +0,0 @@
|
|
|
1
|
-
import type { Tensor, Shape, Dtype } from './ir.js';
|
|
2
|
-
export declare function add(a: Tensor, b: Tensor | number): Tensor;
|
|
3
|
-
export declare function sub(a: Tensor, b: Tensor | number): Tensor;
|
|
4
|
-
export declare function mul(a: Tensor, b: Tensor | number): Tensor;
|
|
5
|
-
export declare function div(a: Tensor, b: Tensor | number): Tensor;
|
|
6
|
-
export declare function mulScalar(a: Tensor, scalar: number): Tensor;
|
|
7
|
-
export declare function addScalar(a: Tensor, scalar: number): Tensor;
|
|
8
|
-
export declare const sqrt: (a: Tensor) => Tensor;
|
|
9
|
-
export declare const rsqrt: (a: Tensor) => Tensor;
|
|
10
|
-
export declare const log: (a: Tensor) => Tensor;
|
|
11
|
-
export declare const exp: (a: Tensor) => Tensor;
|
|
12
|
-
export declare const relu: (a: Tensor) => Tensor;
|
|
13
|
-
export declare function meanLast(a: Tensor): Tensor;
|
|
14
|
-
export declare function sumLast(a: Tensor): Tensor;
|
|
15
|
-
/** Reduce all elements to a 0-d scalar. Composes `reshape` + `sumLast`. */
|
|
16
|
-
export declare function sumAll(a: Tensor): Tensor;
|
|
17
|
-
export declare function reshape(a: Tensor, newShape: Shape): Tensor;
|
|
18
|
-
export declare function transpose(a: Tensor, perm: readonly number[]): Tensor;
|
|
19
|
-
/** Swap two axes of a tensor. Negative indices count from the end (so
|
|
20
|
-
* `swapAxes(x, -1, -2)` swaps the last two — the common attention pattern).
|
|
21
|
-
* All other axes keep their position. Implemented as `transpose` with the
|
|
22
|
-
* permutation `[0, 1, ..., axis2, ..., axis1, ..., n-1]`. */
|
|
23
|
-
export declare function swapAxes(a: Tensor, axis1: number, axis2: number): Tensor;
|
|
24
|
-
export declare function matmul(a: Tensor, b: Tensor): Tensor;
|
|
25
|
-
export declare function matmulBatched(a: Tensor, b: Tensor): Tensor;
|
|
26
|
-
export declare function oneHot(indices: Tensor, depth: number, dtype?: Dtype): Tensor;
|
|
27
|
-
/** Embedding lookup: pull rows from `table` indexed by `indices`. Decomposes
|
|
28
|
-
* to `oneHot(indices, vocab) @ table` so autograd works without a dedicated
|
|
29
|
-
* scatter-with-atomic-add backward — the matmul transpose rule handles it.
|
|
30
|
-
* `table` is `[vocab, dim]`; `indices` is any shape `[...]` of i32; result
|
|
31
|
-
* is `[..., dim]`. The vocab size is taken from `table.shape[0]`. */
|
|
32
|
-
export declare function embedding(table: Tensor, indices: Tensor): Tensor;
|
|
33
|
-
export declare function arange(n: number, dtype?: Dtype): Tensor;
|
|
34
|
-
export declare function softmaxCausalLast(a: Tensor): Tensor;
|
|
35
|
-
export declare function logSoftmaxLast(a: Tensor): Tensor;
|
|
36
|
-
export declare function whereCausal(a: Tensor, fillValue: number): Tensor;
|
|
37
|
-
export declare function sliceLastRange(a: Tensor, start: number, end: number): Tensor;
|
|
38
|
-
export declare function broadcastTo(a: Tensor, targetShape: Shape): Tensor;
|
|
39
|
-
export declare function sumToShape(a: Tensor, targetShape: Shape): Tensor;
|
|
40
|
-
export declare function constScalar(value: number, dtype?: Dtype): Tensor;
|
|
41
|
-
export declare const less: (a: Tensor, b: Tensor) => Tensor;
|
|
42
|
-
export declare const greater: (a: Tensor, b: Tensor) => Tensor;
|
|
43
|
-
export declare function where(cond: Tensor, a: Tensor, b: Tensor): Tensor;
|
|
44
|
-
export declare function reluGrad(x: Tensor, dy: Tensor): Tensor;
|
|
45
|
-
export declare function adamUpdateM(m: Tensor, g: Tensor, b1: number): Tensor;
|
|
46
|
-
export declare function adamUpdateV(v: Tensor, g: Tensor, b2: number): Tensor;
|
|
47
|
-
export declare function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number, decayShrink?: number | Tensor): Tensor;
|
|
48
|
-
//# sourceMappingURL=ops.d.ts.map
|
package/dist/ops.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"ops.d.ts","sourceRoot":"","sources":["../src/ops.ts"],"names":[],"mappings":"AAWA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,SAAS,CAAA;AAmC3D,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAMzD;AAQD,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAYD,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,KAAK,GAAI,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAO7D,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAK1C;AAED,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKzC;AAED,2EAA2E;AAC3E,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAExC;AAMD,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,EAAE,QAAQ,EAAE,KAAK,GAAG,MAAM,CAI1D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,SAAS,MAAM,EAAE,GAAG,MAAM,CAIpE;AAED;;;8DAG8D;AAC9D,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,GAAG,MAAM,CAcxE;AAMD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAOnD;AAED,wBAAgB,aAAa,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAO1D;AAMD,wBAAgB,MAAM,CAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOnF;AAED;;;;sEAIsE;AACtE,wBAAgB,SAAS,CAAC,KAAK,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CAShE;AAGD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAM9D;AASD,wBAAgB,iBAAiB,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKnD;AAGD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAIhD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,SAAS,EAAE,MAAM,GAAG,MAAM,CAKhE;AAQD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,CAI5E;AAOD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIjE;AAED,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIhE;AAOD,wBAAgB,WAAW,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAGvE;AAWD,eAAO,MAAM,IAAI,GAAO,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AACpG,eAAO,MAAM,OAAO,GAAI,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AAGpG,wBAAgB,KAAK,CAAC,IAAI,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAMhE;AAID,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOtD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CACzB,CAAC,EAAE,MAAM,EACT,IAAI,EAAE,MAAM,EACZ,IAAI,EAAE,MAAM,EACZ,GAAG,EAAE,MAAM,EACX,GAAG,EAAE,MAAM,EACX,WAAW,GAAE,MAAM,GAAG,MAAU,GAC/B,MAAM,CA2BR"}
|
package/dist/runtime.d.ts
DELETED
|
@@ -1,108 +0,0 @@
|
|
|
1
|
-
import type { BufferPlan } from './buffers.js';
|
|
2
|
-
import type { KernelSpec } from './codegen.js';
|
|
3
|
-
export interface UploadParamsOptions {
|
|
4
|
-
/** Skip the "missing param" check, allowing the caller to update only some
|
|
5
|
-
* params and leave the rest at their current GPU values. Extra (unknown)
|
|
6
|
-
* keys are still rejected — that's always a typo. Default: false. */
|
|
7
|
-
partial?: boolean;
|
|
8
|
-
}
|
|
9
|
-
/**
|
|
10
|
-
* Activation readbacks for one `step()`/`run()` call. Keyed by the names
|
|
11
|
-
* passed to `capture(name, t)` during the trace. `get(name)` throws if the
|
|
12
|
-
* name isn't registered or wasn't read back this call (i.e., the call was
|
|
13
|
-
* made without `{ withCaptures: true }`); use `has(name)` if you need to
|
|
14
|
-
* branch. `shapeOf(name)` returns the static-after-compile shape and works
|
|
15
|
-
* regardless of whether captures were read back.
|
|
16
|
-
*/
|
|
17
|
-
export declare class Captures {
|
|
18
|
-
private readonly shapes;
|
|
19
|
-
private readonly data;
|
|
20
|
-
constructor(shapes: Record<string, readonly number[]>, data: Map<string, Float32Array>);
|
|
21
|
-
get(name: string): Float32Array;
|
|
22
|
-
shapeOf(name: string): readonly number[];
|
|
23
|
-
has(name: string): boolean;
|
|
24
|
-
names(): string[];
|
|
25
|
-
}
|
|
26
|
-
export interface RunResult {
|
|
27
|
-
output: Float32Array;
|
|
28
|
-
captures: Captures;
|
|
29
|
-
}
|
|
30
|
-
export interface StepResult {
|
|
31
|
-
loss: number;
|
|
32
|
-
captures: Captures;
|
|
33
|
-
}
|
|
34
|
-
export interface RunOptions {
|
|
35
|
-
/** Read back tensors registered via `capture(name, t)` during the trace.
|
|
36
|
-
* Default false. When false, the returned `captures` is empty (calling
|
|
37
|
-
* `.get` throws); when true, captures are read back and accessible. */
|
|
38
|
-
withCaptures?: boolean;
|
|
39
|
-
}
|
|
40
|
-
/** Common surface for both training and forward-only compiled runtimes. */
|
|
41
|
-
export interface CompiledBase {
|
|
42
|
-
/** The GPUDevice this runtime is bound to. Pass to sibling compiles to
|
|
43
|
-
* share the device, or use directly for other GPU work. */
|
|
44
|
-
device: GPUDevice;
|
|
45
|
-
/** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
|
|
46
|
-
* `sharedParams` to share without copies. */
|
|
47
|
-
params: Map<string, GPUBuffer>;
|
|
48
|
-
/** Shape of the graph's output (loss scalar `[]` for training; the user's
|
|
49
|
-
* returned tensor for forward-only compiles). */
|
|
50
|
-
outputShape: number[];
|
|
51
|
-
/** Upload parameter Float32Arrays to their GPU buffers. By default, requires
|
|
52
|
-
* *all* params to be present; throws on any unknown or missing key. Pass
|
|
53
|
-
* `{ partial: true }` to skip the missing-key check. */
|
|
54
|
-
uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void;
|
|
55
|
-
/** Read all parameters back as Float32Arrays — used for UI panels. */
|
|
56
|
-
downloadParams(): Promise<Record<string, Float32Array>>;
|
|
57
|
-
/** Free GPU resources. */
|
|
58
|
-
destroy(): void;
|
|
59
|
-
}
|
|
60
|
-
/** Run a dispatch and read back the full output tensor. Captures are always
|
|
61
|
-
* returned; their data is empty unless `{ withCaptures: true }` is passed. */
|
|
62
|
-
export type RunFn = (inputs: Record<string, Int32Array | Float32Array>, opts?: RunOptions) => Promise<RunResult>;
|
|
63
|
-
export interface CompiledRuntime extends CompiledBase {
|
|
64
|
-
/** Read all parameter gradients back. Mostly for verification / debugging. */
|
|
65
|
-
downloadParamGrads(): Promise<Record<string, Float32Array>>;
|
|
66
|
-
/**
|
|
67
|
-
* One full forward+backward step.
|
|
68
|
-
* 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
|
|
69
|
-
* 2. Dispatches every kernel in order.
|
|
70
|
-
* 3. Reads back the loss scalar (and any registered captures, if requested).
|
|
71
|
-
* Default returns the loss as a JS number; with `{ withCaptures: true }`
|
|
72
|
-
* returns `{ loss, captures }`.
|
|
73
|
-
*/
|
|
74
|
-
step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
|
|
75
|
-
step(inputs: Record<string, Int32Array | Float32Array>, opts: {
|
|
76
|
-
withCaptures: true;
|
|
77
|
-
}): Promise<StepResult>;
|
|
78
|
-
step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>;
|
|
79
|
-
/** Same dispatch as step() but returns the full output Float32Array — for
|
|
80
|
-
* training graphs the output is a scalar loss, so step() is usually more
|
|
81
|
-
* convenient. Provided for parity with `compileForward`. */
|
|
82
|
-
run: RunFn;
|
|
83
|
-
/** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
|
|
84
|
-
* `uploadInitialParams()` for a full training reset without recompile. */
|
|
85
|
-
resetOptimizerState(): void;
|
|
86
|
-
}
|
|
87
|
-
/** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
|
|
88
|
-
* no backward. Returns the output tensor (not just a scalar) per `run()` call. */
|
|
89
|
-
export interface CompiledForward extends CompiledBase {
|
|
90
|
-
run: RunFn;
|
|
91
|
-
}
|
|
92
|
-
export interface RuntimeOpts {
|
|
93
|
-
/** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
|
|
94
|
-
device?: GPUDevice;
|
|
95
|
-
/** External param buffers to bind in place of allocating fresh ones, keyed
|
|
96
|
-
* by param name. Used to share params between a training compile and a
|
|
97
|
-
* sibling forward-only compile (e.g., a B=1 inference graph). When a name
|
|
98
|
-
* is in this map, the runtime reuses the provided GPUBuffer; otherwise it
|
|
99
|
-
* allocates as usual. */
|
|
100
|
-
sharedParams?: Map<string, GPUBuffer>;
|
|
101
|
-
}
|
|
102
|
-
export declare function createRuntime(plan: BufferPlan, kernels: KernelSpec[], lossBufferId: number, opts?: RuntimeOpts): Promise<CompiledRuntime>;
|
|
103
|
-
/** Same machinery as `createRuntime`, narrower public type: a forward-only
|
|
104
|
-
* graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
|
|
105
|
-
* loss readback). The full runtime object is built once and projected by
|
|
106
|
-
* `compileForward` to the public shape. */
|
|
107
|
-
export declare function createForwardRuntime(plan: BufferPlan, kernels: KernelSpec[], outputBufferId: number, opts?: RuntimeOpts): Promise<CompiledForward>;
|
|
108
|
-
//# sourceMappingURL=runtime.d.ts.map
|
package/dist/runtime.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"runtime.d.ts","sourceRoot":"","sources":["../src/runtime.ts"],"names":[],"mappings":"AAMA,OAAO,KAAK,EAAE,UAAU,EAAE,MAAM,cAAc,CAAA;AAC9C,OAAO,KAAK,EAAE,UAAU,EAAE,MAAM,cAAc,CAAA;AAM9C,MAAM,WAAW,mBAAmB;IAClC;;0EAEsE;IACtE,OAAO,CAAC,EAAE,OAAO,CAAA;CAClB;AAED;;;;;;;GAOG;AACH,qBAAa,QAAQ;IAEjB,OAAO,CAAC,QAAQ,CAAC,MAAM;IACvB,OAAO,CAAC,QAAQ,CAAC,IAAI;gBADJ,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,SAAS,MAAM,EAAE,CAAC,EACzC,IAAI,EAAE,GAAG,CAAC,MAAM,EAAE,YAAY,CAAC;IAElD,GAAG,CAAC,IAAI,EAAE,MAAM,GAAG,YAAY;IAS/B,OAAO,CAAC,IAAI,EAAE,MAAM,GAAG,SAAS,MAAM,EAAE;IAQxC,GAAG,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO;IAC1B,KAAK,IAAI,MAAM,EAAE;CAClB;AAED,MAAM,WAAW,SAAS;IACxB,MAAM,EAAE,YAAY,CAAA;IACpB,QAAQ,EAAE,QAAQ,CAAA;CACnB;AAED,MAAM,WAAW,UAAU;IACzB,IAAI,EAAE,MAAM,CAAA;IACZ,QAAQ,EAAE,QAAQ,CAAA;CACnB;AAED,MAAM,WAAW,UAAU;IACzB;;4EAEwE;IACxE,YAAY,CAAC,EAAE,OAAO,CAAA;CACvB;AAED,2EAA2E;AAC3E,MAAM,WAAW,YAAY;IAC3B;gEAC4D;IAC5D,MAAM,EAAE,SAAS,CAAA;IACjB;kDAC8C;IAC9C,MAAM,EAAE,GAAG,CAAC,MAAM,EAAE,SAAS,CAAC,CAAA;IAC9B;sDACkD;IAClD,WAAW,EAAE,MAAM,EAAE,CAAA;IACrB;;6DAEyD;IACzD,YAAY,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,YAAY,CAAC,EAAE,IAAI,CAAC,EAAE,mBAAmB,GAAG,IAAI,CAAA;IACpF,sEAAsE;IACtE,cAAc,IAAI,OAAO,CAAC,MAAM,CAAC,MAAM,EAAE,YAAY,CAAC,CAAC,CAAA;IACvD,0BAA0B;IAC1B,OAAO,IAAI,IAAI,CAAA;CAChB;AAED;+EAC+E;AAC/E,MAAM,MAAM,KAAK,GAAG,CAClB,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,EACjD,IAAI,CAAC,EAAE,UAAU,KACd,OAAO,CAAC,SAAS,CAAC,CAAA;AAEvB,MAAM,WAAW,eAAgB,SAAQ,YAAY;IACnD,8EAA8E;IAC9E,kBAAkB,IAAI,OAAO,CAAC,MAAM,CAAC,MAAM,EAAE,YAAY,CAAC,CAAC,CAAA;IAC3D;;;;;;;OAOG;IACH,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,GAAG,OAAO,CAAC,MAAM,CAAC,CAAA;IACxE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,EAAE,IAAI,EAAE;QAAE,YAAY,EAAE,IAAI,CAAA;KAAE,GAAG,OAAO,CAAC,UAAU,CAAC,CAAA;IAC1G,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,EAAE,IAAI,EAAE,UAAU,GAAG,OAAO,CAAC,MAAM,GAAG,UAAU,CAAC,CAAA;IACvG;;iEAE6D;IAC7D,GAAG,EAAE,KAAK,CAAA;IACV;+EAC2E;IAC3E,mBAAmB,IAAI,IAAI,CAAA;CAC5B;AAED;mFACmF;AACnF,MAAM,WAAW,eAAgB,SAAQ,YAAY;IACnD,GAAG,EAAE,KAAK,CAAA;CACX;AAED,MAAM,WAAW,WAAW;IAC1B,oEAAoE;IACpE,MAAM,CAAC,EAAE,SAAS,CAAA;IAClB;;;;8BAI0B;IAC1B,YAAY,CAAC,EAAE,GAAG,CAAC,MAAM,EAAE,SAAS,CAAC,CAAA;CACtC;AAQD,wBAAsB,aAAa,CACjC,IAAI,EAAE,UAAU,EAChB,OAAO,EAAE,UAAU,EAAE,EACrB,YAAY,EAAE,MAAM,EACpB,IAAI,GAAE,WAAgB,GACrB,OAAO,CAAC,eAAe,CAAC,CAsV1B;AAED;;;4CAG4C;AAC5C,wBAAsB,oBAAoB,CACxC,IAAI,EAAE,UAAU,EAChB,OAAO,EAAE,UAAU,EAAE,EACrB,cAAc,EAAE,MAAM,EACtB,IAAI,GAAE,WAAgB,GACrB,OAAO,CAAC,eAAe,CAAC,CAE1B"}
|
package/dist/shape.d.ts
DELETED
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
import type { Shape, CallSite } from './ir.js';
|
|
2
|
-
export declare class ShapeError extends Error {
|
|
3
|
-
constructor(message: string, site: CallSite | null);
|
|
4
|
-
}
|
|
5
|
-
export declare function shapesEqual(a: Shape, b: Shape): boolean;
|
|
6
|
-
export declare function shapeSize(shape: Shape): number;
|
|
7
|
-
export declare function showShape(shape: Shape): string;
|
|
8
|
-
export declare function broadcastTrailing(a: Shape, b: Shape): Shape | null;
|
|
9
|
-
export declare function inferElementwiseBinop(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape;
|
|
10
|
-
export declare function inferUnary(_opName: string, aShape: Shape, _site: CallSite | null): Shape;
|
|
11
|
-
export declare function inferMeanLast(opName: string, aShape: Shape, site: CallSite | null): Shape;
|
|
12
|
-
export declare function inferSumLast(opName: string, aShape: Shape, site: CallSite | null): Shape;
|
|
13
|
-
export declare function inferReshape(opName: string, aShape: Shape, newShape: Shape, site: CallSite | null): Shape;
|
|
14
|
-
export declare function inferTranspose(opName: string, aShape: Shape, perm: readonly number[], site: CallSite | null): Shape;
|
|
15
|
-
export declare function inferMatmul(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape;
|
|
16
|
-
export declare function inferMatmulBatched(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape;
|
|
17
|
-
export declare function inferOneHot(opName: string, indicesShape: Shape, depth: number, site: CallSite | null): Shape;
|
|
18
|
-
export declare function inferWhereCausal(opName: string, aShape: Shape, site: CallSite | null): Shape;
|
|
19
|
-
export declare function inferSliceLastRange(opName: string, aShape: Shape, start: number, end: number, site: CallSite | null): Shape;
|
|
20
|
-
export declare function inferBroadcastTo(opName: string, aShape: Shape, targetShape: Shape, site: CallSite | null): Shape;
|
|
21
|
-
export declare function inferSumToShape(opName: string, aShape: Shape, targetShape: Shape, site: CallSite | null): Shape;
|
|
22
|
-
export declare function inferWhere(opName: string, condShape: Shape, aShape: Shape, bShape: Shape, site: CallSite | null): Shape;
|
|
23
|
-
export declare function inferReluGrad(opName: string, xShape: Shape, dyShape: Shape, site: CallSite | null): Shape;
|
|
24
|
-
//# sourceMappingURL=shape.d.ts.map
|
package/dist/shape.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"shape.d.ts","sourceRoot":"","sources":["../src/shape.ts"],"names":[],"mappings":"AAkBA,OAAO,KAAK,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAO9C,qBAAa,UAAW,SAAQ,KAAK;gBACvB,OAAO,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI;CAKnD;AAUD,wBAAgB,WAAW,CAAC,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,KAAK,GAAG,OAAO,CAIvD;AAED,wBAAgB,SAAS,CAAC,KAAK,EAAE,KAAK,GAAG,MAAM,CAI9C;AAED,wBAAgB,SAAS,CAAC,KAAK,EAAE,KAAK,GAAG,MAAM,CAE9C;AAKD,wBAAgB,iBAAiB,CAAC,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,KAAK,GAAG,KAAK,GAAG,IAAI,CAclE;AASD,wBAAgB,qBAAqB,CACnC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAClE,KAAK,CAWP;AAED,wBAAgB,UAAU,CAAC,OAAO,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAExF;AAED,wBAAgB,aAAa,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAIzF;AAED,wBAAgB,YAAY,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAIxF;AAED,wBAAgB,YAAY,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CA0BzG;AAED,wBAAgB,cAAc,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,SAAS,MAAM,EAAE,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAWnH;AAGD,wBAAgB,WAAW,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAStG;AAGD,wBAAgB,kBAAkB,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAoB7G;AAED,wBAAgB,WAAW,CAAC,MAAM,EAAE,MAAM,EAAE,YAAY,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAG5G;AAGD,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAM5F;AAED,wBAAgB,mBAAmB,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAO3H;AAID,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,WAAW,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAahH;AAID,wBAAgB,eAAe,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,WAAW,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAa/G;AAID,wBAAgB,UAAU,CAAC,MAAM,EAAE,MAAM,EAAE,SAAS,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAMvH;AAED,wBAAgB,aAAa,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAKzG"}
|
package/dist/trace.d.ts
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
import type { Graph, Tensor, Shape, Dtype } from './ir.js';
|
|
2
|
-
export declare function currentGraph(): Graph;
|
|
3
|
-
export declare function isCaptureEnabled(): boolean;
|
|
4
|
-
export declare function trace(fn: () => Tensor | Tensor[]): Graph;
|
|
5
|
-
export declare function traceInto<T>(g: Graph, fn: () => T): T;
|
|
6
|
-
export declare function paramInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
|
|
7
|
-
export declare function tensorInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
|
|
8
|
-
export declare function stateInput(name: string, shape: Shape, dtype?: Dtype, initValue?: number): Tensor;
|
|
9
|
-
//# sourceMappingURL=trace.d.ts.map
|
package/dist/trace.d.ts.map
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
{"version":3,"file":"trace.d.ts","sourceRoot":"","sources":["../src/trace.ts"],"names":[],"mappings":"AAgBA,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAU1D,wBAAgB,YAAY,IAAI,KAAK,CAQpC;AAED,wBAAgB,gBAAgB,IAAI,OAAO,CAE1C;AAID,wBAAgB,KAAK,CAAC,EAAE,EAAE,MAAM,MAAM,GAAG,MAAM,EAAE,GAAG,KAAK,CAkBxD;AAQD,wBAAgB,SAAS,CAAC,CAAC,EAAE,CAAC,EAAE,KAAK,EAAE,EAAE,EAAE,MAAM,CAAC,GAAG,CAAC,CAWrD;AAOD,wBAAgB,UAAU,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOnF;AAED,wBAAgB,WAAW,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOpF;AAID,wBAAgB,UAAU,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,GAAE,KAAa,EAAE,SAAS,SAAI,GAAG,MAAM,CAOlG"}
|