tensorgrad 0.0.12 → 0.0.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. package/dist/buffers.js +1 -6
  2. package/dist/buffers.js.map +1 -1
  3. package/dist/codegen.js +30 -28
  4. package/dist/codegen.js.map +1 -1
  5. package/dist/compile.js +39 -68
  6. package/dist/compile.js.map +1 -1
  7. package/dist/grad.js +1 -14
  8. package/dist/grad.js.map +1 -1
  9. package/dist/index.d.ts +740 -14
  10. package/dist/runtime.js +6 -9
  11. package/dist/runtime.js.map +1 -1
  12. package/dist/trace.js +8 -13
  13. package/dist/trace.js.map +1 -1
  14. package/package.json +9 -3
  15. package/src/buffers.ts +1 -6
  16. package/src/codegen.ts +31 -28
  17. package/src/compile.ts +312 -358
  18. package/src/grad.ts +1 -11
  19. package/src/runtime.ts +6 -9
  20. package/src/trace.ts +12 -9
  21. package/dist/adam.d.ts +0 -65
  22. package/dist/adam.d.ts.map +0 -1
  23. package/dist/buffers.d.ts +0 -57
  24. package/dist/buffers.d.ts.map +0 -1
  25. package/dist/capture.d.ts +0 -3
  26. package/dist/capture.d.ts.map +0 -1
  27. package/dist/codegen.d.ts +0 -23
  28. package/dist/codegen.d.ts.map +0 -1
  29. package/dist/compile.d.ts +0 -130
  30. package/dist/compile.d.ts.map +0 -1
  31. package/dist/grad.d.ts +0 -8
  32. package/dist/grad.d.ts.map +0 -1
  33. package/dist/index.d.ts.map +0 -1
  34. package/dist/ir.d.ts +0 -207
  35. package/dist/ir.d.ts.map +0 -1
  36. package/dist/module.d.ts +0 -55
  37. package/dist/module.d.ts.map +0 -1
  38. package/dist/nn.d.ts +0 -42
  39. package/dist/nn.d.ts.map +0 -1
  40. package/dist/ops.d.ts +0 -48
  41. package/dist/ops.d.ts.map +0 -1
  42. package/dist/runtime.d.ts +0 -115
  43. package/dist/runtime.d.ts.map +0 -1
  44. package/dist/shape.d.ts +0 -24
  45. package/dist/shape.d.ts.map +0 -1
  46. package/dist/trace.d.ts +0 -9
  47. package/dist/trace.d.ts.map +0 -1
package/dist/module.d.ts DELETED
@@ -1,55 +0,0 @@
1
- import type { Tensor, Shape, Dtype } from './ir.js';
2
- /** How a parameter's initial values are produced.
3
- * - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
4
- * weight matrices and embeddings.
5
- * - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
6
- * - `'ones'` — fill with 1. Common for LayerNorm gain.
7
- * - Custom function — receives total element count and shape, returns the
8
- * Float32Array. Use for fan-in scaling or any non-standard scheme.
9
- */
10
- export type InitSpec = 'randn' | 'zeros' | 'ones' | ((size: number, shape: readonly number[]) => Float32Array);
11
- export interface ParamOptions {
12
- dtype?: Dtype;
13
- /** Init kind. Default: `'randn'`. */
14
- init?: InitSpec;
15
- /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
16
- scale?: number;
17
- /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
18
- * decay to this param. Default: `true` for `'randn'` init (weight matrices,
19
- * embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
20
- * to force or skip. Replaces `adam.decayFilter` for the common case. */
21
- decay?: boolean;
22
- }
23
- type InitFn = (size: number, shape: readonly number[]) => Float32Array;
24
- export declare abstract class Module {
25
- /**
26
- * Declare a learnable parameter at this module. Must be called from inside
27
- * the constructor (typically as a field assignment). Returns a placeholder
28
- * that gets replaced with a real Tensor at compile time.
29
- *
30
- * The parameter's name is auto-derived from its property path in the model
31
- * tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
32
- * call `compiled.uploadInitialParams()` to apply it after compile.
33
- */
34
- protected param(shape: Shape, opts?: ParamOptions): Tensor;
35
- }
36
- export interface MaterializedParams {
37
- /** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
38
- tensors: Record<string, Tensor>;
39
- /** Init function per param path. Used by `uploadInitialParams`. */
40
- initFns: Record<string, InitFn>;
41
- /** Whether this param should receive AdamW weight decay. Resolved at
42
- * `param()` time from `ParamOptions.decay` (with init-based default). */
43
- decayFlags: Record<string, boolean>;
44
- }
45
- /**
46
- * Walk the module tree and replace every ParamSentinel with a real Tensor
47
- * created via `paramInput(autoName, ...)`. Must be called inside an active
48
- * trace context (paramInput appends to the current graph).
49
- *
50
- * Returns the param tensors keyed by path, plus init functions for use by
51
- * `uploadInitialParams`.
52
- */
53
- export declare function materializeParams(root: Module): MaterializedParams;
54
- export {};
55
- //# sourceMappingURL=module.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"module.d.ts","sourceRoot":"","sources":["../src/module.ts"],"names":[],"mappings":"AA2BA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAOnD;;;;;;;GAOG;AACH,MAAM,MAAM,QAAQ,GAChB,OAAO,GACP,OAAO,GACP,MAAM,GACN,CAAC,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,KAAK,YAAY,CAAC,CAAA;AAE9D,MAAM,WAAW,YAAY;IAC3B,KAAK,CAAC,EAAE,KAAK,CAAA;IACb,qCAAqC;IACrC,IAAI,CAAC,EAAE,QAAQ,CAAA;IACf,uEAAuE;IACvE,KAAK,CAAC,EAAE,MAAM,CAAA;IACd;;;6EAGyE;IACzE,KAAK,CAAC,EAAE,OAAO,CAAA;CAChB;AAED,KAAK,MAAM,GAAG,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,SAAS,MAAM,EAAE,KAAK,YAAY,CAAA;AAuDtE,8BAAsB,MAAM;IAC1B;;;;;;;;OAQG;IACH,SAAS,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,EAAE,IAAI,CAAC,EAAE,YAAY,GAAG,MAAM;CAK3D;AAMD,MAAM,WAAW,kBAAkB;IACjC,2EAA2E;IAC3E,OAAO,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IAC/B,mEAAmE;IACnE,OAAO,EAAE,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,CAAA;IAC/B;8EAC0E;IAC1E,UAAU,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,CAAA;CACpC;AAED;;;;;;;GAOG;AACH,wBAAgB,iBAAiB,CAAC,IAAI,EAAE,MAAM,GAAG,kBAAkB,CAclE"}
package/dist/nn.d.ts DELETED
@@ -1,42 +0,0 @@
1
- import { Module } from './module.js';
2
- import type { Tensor } from './ir.js';
3
- import type { Captures } from './runtime.js';
4
- export interface LinearOptions {
5
- /** Include a bias term (default true). */
6
- bias?: boolean;
7
- }
8
- export declare class Linear extends Module {
9
- readonly inDim: number;
10
- readonly outDim: number;
11
- W: Tensor;
12
- b: Tensor | null;
13
- constructor(inDim: number, outDim: number, opts?: LinearOptions);
14
- fwd(x: Tensor): Tensor;
15
- }
16
- export declare class LayerNorm extends Module {
17
- readonly d: number;
18
- readonly eps: number;
19
- g: Tensor;
20
- b: Tensor;
21
- constructor(d: number, eps?: number);
22
- fwd(x: Tensor): Tensor;
23
- }
24
- /** [..., T, D] → [..., H, T, D/H]. Folds the standard
25
- * `transpose(reshape(x, [..., T, H, d]), [..., H, T, d])` pattern into one
26
- * call. Last dim of `x` must divide evenly by `nHeads`. */
27
- export declare function splitHeads(x: Tensor, nHeads: number): Tensor;
28
- /** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
29
- export declare function mergeHeads(x: Tensor): Tensor;
30
- /** Slice a captured tensor named `name` into one Float32Array per head, using
31
- * the static shape registered at compile time. The leading axis is treated as
32
- * heads (matching `splitHeads` layout at B=1); a leading singleton batch is
33
- * stripped if present so callers can pass capture names directly. Throws if
34
- * the capture isn't registered or wasn't read back this call. */
35
- export declare function unsplitHeads(captures: Captures, name: string): Float32Array[];
36
- /** Per-position cross-entropy along the last (vocab) axis: returns
37
- * `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
38
- * `[...]` of i32; result is `[...]` (one rank less than logits). The user
39
- * applies their own masking + reduction downstream — useful when only some
40
- * positions contribute (e.g. result-digit masking) or for label smoothing. */
41
- export declare function crossEntropyLast(logits: Tensor, targets: Tensor): Tensor;
42
- //# sourceMappingURL=nn.d.ts.map
package/dist/nn.d.ts.map DELETED
@@ -1 +0,0 @@
1
- {"version":3,"file":"nn.d.ts","sourceRoot":"","sources":["../src/nn.ts"],"names":[],"mappings":"AAaA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAA;AACpC,OAAO,KAAK,EAAE,MAAM,EAAE,MAAM,SAAS,CAAA;AAIrC,OAAO,KAAK,EAAE,QAAQ,EAAE,MAAM,cAAc,CAAA;AAM5C,MAAM,WAAW,aAAa;IAC5B,0CAA0C;IAC1C,IAAI,CAAC,EAAE,OAAO,CAAA;CACf;AAED,qBAAa,MAAO,SAAQ,MAAM;aAGJ,KAAK,EAAE,MAAM;aAAkB,MAAM,EAAE,MAAM;IAFzE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,GAAG,IAAI,CAAA;gBACY,KAAK,EAAE,MAAM,EAAkB,MAAM,EAAE,MAAM,EAAE,IAAI,GAAE,aAAkB;IAKnG,GAAG,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM;CAIvB;AAMD,qBAAa,SAAU,SAAQ,MAAM;aAGP,CAAC,EAAE,MAAM;aAAkB,GAAG,EAAE,MAAM;IAFlE,CAAC,EAAE,MAAM,CAAA;IACT,CAAC,EAAE,MAAM,CAAA;gBACmB,CAAC,EAAE,MAAM,EAAkB,GAAG,GAAE,MAAa;IAKzE,GAAG,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM;CAOvB;AAOD;;4DAE4D;AAC5D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAa5D;AAED,+DAA+D;AAC/D,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAW5C;AAED;;;;kEAIkE;AAClE,wBAAgB,YAAY,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,EAAE,MAAM,GAAG,YAAY,EAAE,CAiB7E;AAMD;;;;+EAI+E;AAC/E,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CASxE"}
package/dist/ops.d.ts DELETED
@@ -1,48 +0,0 @@
1
- import type { Tensor, Shape, Dtype } from './ir.js';
2
- export declare function add(a: Tensor, b: Tensor | number): Tensor;
3
- export declare function sub(a: Tensor, b: Tensor | number): Tensor;
4
- export declare function mul(a: Tensor, b: Tensor | number): Tensor;
5
- export declare function div(a: Tensor, b: Tensor | number): Tensor;
6
- export declare function mulScalar(a: Tensor, scalar: number): Tensor;
7
- export declare function addScalar(a: Tensor, scalar: number): Tensor;
8
- export declare const sqrt: (a: Tensor) => Tensor;
9
- export declare const rsqrt: (a: Tensor) => Tensor;
10
- export declare const log: (a: Tensor) => Tensor;
11
- export declare const exp: (a: Tensor) => Tensor;
12
- export declare const relu: (a: Tensor) => Tensor;
13
- export declare function meanLast(a: Tensor): Tensor;
14
- export declare function sumLast(a: Tensor): Tensor;
15
- /** Reduce all elements to a 0-d scalar. Composes `reshape` + `sumLast`. */
16
- export declare function sumAll(a: Tensor): Tensor;
17
- export declare function reshape(a: Tensor, newShape: Shape): Tensor;
18
- export declare function transpose(a: Tensor, perm: readonly number[]): Tensor;
19
- /** Swap two axes of a tensor. Negative indices count from the end (so
20
- * `swapAxes(x, -1, -2)` swaps the last two — the common attention pattern).
21
- * All other axes keep their position. Implemented as `transpose` with the
22
- * permutation `[0, 1, ..., axis2, ..., axis1, ..., n-1]`. */
23
- export declare function swapAxes(a: Tensor, axis1: number, axis2: number): Tensor;
24
- export declare function matmul(a: Tensor, b: Tensor): Tensor;
25
- export declare function matmulBatched(a: Tensor, b: Tensor): Tensor;
26
- export declare function oneHot(indices: Tensor, depth: number, dtype?: Dtype): Tensor;
27
- /** Embedding lookup: pull rows from `table` indexed by `indices`. Decomposes
28
- * to `oneHot(indices, vocab) @ table` so autograd works without a dedicated
29
- * scatter-with-atomic-add backward — the matmul transpose rule handles it.
30
- * `table` is `[vocab, dim]`; `indices` is any shape `[...]` of i32; result
31
- * is `[..., dim]`. The vocab size is taken from `table.shape[0]`. */
32
- export declare function embedding(table: Tensor, indices: Tensor): Tensor;
33
- export declare function arange(n: number, dtype?: Dtype): Tensor;
34
- export declare function softmaxCausalLast(a: Tensor): Tensor;
35
- export declare function logSoftmaxLast(a: Tensor): Tensor;
36
- export declare function whereCausal(a: Tensor, fillValue: number): Tensor;
37
- export declare function sliceLastRange(a: Tensor, start: number, end: number): Tensor;
38
- export declare function broadcastTo(a: Tensor, targetShape: Shape): Tensor;
39
- export declare function sumToShape(a: Tensor, targetShape: Shape): Tensor;
40
- export declare function constScalar(value: number, dtype?: Dtype): Tensor;
41
- export declare const less: (a: Tensor, b: Tensor) => Tensor;
42
- export declare const greater: (a: Tensor, b: Tensor) => Tensor;
43
- export declare function where(cond: Tensor, a: Tensor, b: Tensor): Tensor;
44
- export declare function reluGrad(x: Tensor, dy: Tensor): Tensor;
45
- export declare function adamUpdateM(m: Tensor, g: Tensor, b1: number): Tensor;
46
- export declare function adamUpdateV(v: Tensor, g: Tensor, b2: number): Tensor;
47
- export declare function adamUpdateP(p: Tensor, mNew: Tensor, vNew: Tensor, lrt: Tensor, eps: number, decayShrink?: number | Tensor): Tensor;
48
- //# sourceMappingURL=ops.d.ts.map
package/dist/ops.d.ts.map DELETED
@@ -1 +0,0 @@
1
- {"version":3,"file":"ops.d.ts","sourceRoot":"","sources":["../src/ops.ts"],"names":[],"mappings":"AAWA,OAAO,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,SAAS,CAAA;AAmC3D,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAEzD;AACD,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,GAAG,MAAM,CAMzD;AAQD,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,GAAG,MAAM,CAG3D;AAYD,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,KAAK,GAAI,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,GAAG,GAAM,GAAG,MAAM,KAAG,MAA2B,CAAA;AAC7D,eAAO,MAAM,IAAI,GAAK,GAAG,MAAM,KAAG,MAA2B,CAAA;AAO7D,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAK1C;AAED,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKzC;AAED,2EAA2E;AAC3E,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAExC;AAMD,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,EAAE,QAAQ,EAAE,KAAK,GAAG,MAAM,CAI1D;AAED,wBAAgB,SAAS,CAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,SAAS,MAAM,EAAE,GAAG,MAAM,CAIpE;AAED;;;8DAG8D;AAC9D,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,GAAG,MAAM,CAcxE;AAMD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAOnD;AAED,wBAAgB,aAAa,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAO1D;AAMD,wBAAgB,MAAM,CAAC,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOnF;AAED;;;;sEAIsE;AACtE,wBAAgB,SAAS,CAAC,KAAK,EAAE,MAAM,EAAE,OAAO,EAAE,MAAM,GAAG,MAAM,CAShE;AAGD,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAM9D;AASD,wBAAgB,iBAAiB,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAKnD;AAGD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,GAAG,MAAM,CAIhD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,SAAS,EAAE,MAAM,GAAG,MAAM,CAKhE;AAQD,wBAAgB,cAAc,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,MAAM,CAI5E;AAOD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIjE;AAED,wBAAgB,UAAU,CAAC,CAAC,EAAE,MAAM,EAAE,WAAW,EAAE,KAAK,GAAG,MAAM,CAIhE;AAOD,wBAAgB,WAAW,CAAC,KAAK,EAAE,MAAM,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAGvE;AAWD,eAAO,MAAM,IAAI,GAAO,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AACpG,eAAO,MAAM,OAAO,GAAI,GAAG,MAAM,EAAE,GAAG,MAAM,KAAG,MAAqD,CAAA;AAGpG,wBAAgB,KAAK,CAAC,IAAI,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,GAAG,MAAM,CAMhE;AAID,wBAAgB,QAAQ,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOtD;AAMD,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,EAAE,EAAE,EAAE,MAAM,GAAG,MAAM,CAOpE;AAED,wBAAgB,WAAW,CACzB,CAAC,EAAE,MAAM,EACT,IAAI,EAAE,MAAM,EACZ,IAAI,EAAE,MAAM,EACZ,GAAG,EAAE,MAAM,EACX,GAAG,EAAE,MAAM,EACX,WAAW,GAAE,MAAM,GAAG,MAAU,GAC/B,MAAM,CA2BR"}
package/dist/runtime.d.ts DELETED
@@ -1,115 +0,0 @@
1
- import type { BufferPlan } from './buffers.js';
2
- import type { KernelSpec } from './codegen.js';
3
- export interface UploadParamsOptions {
4
- /** Skip the "missing param" check, allowing the caller to update only some
5
- * params and leave the rest at their current GPU values. Extra (unknown)
6
- * keys are still rejected — that's always a typo. Default: false. */
7
- partial?: boolean;
8
- }
9
- /**
10
- * Activation readbacks for one `step()`/`run()` call. Keyed by the names
11
- * passed to `capture(name, t)` during the trace. `get(name)` throws if the
12
- * name isn't registered or wasn't read back this call (i.e., the call was
13
- * made without `{ withCaptures: true }`); use `has(name)` if you need to
14
- * branch. `shapeOf(name)` returns the static-after-compile shape and works
15
- * regardless of whether captures were read back.
16
- */
17
- export declare class Captures {
18
- private readonly shapes;
19
- private readonly data;
20
- constructor(shapes: Record<string, readonly number[]>, data: Map<string, Float32Array>);
21
- get(name: string): Float32Array;
22
- shapeOf(name: string): readonly number[];
23
- has(name: string): boolean;
24
- names(): string[];
25
- }
26
- export interface RunResult {
27
- output: Float32Array;
28
- captures: Captures;
29
- }
30
- export interface StepResult {
31
- loss: number;
32
- captures: Captures;
33
- }
34
- export interface RunOptions {
35
- /** Read back tensors registered via `capture(name, t)` during the trace.
36
- * Default false. When false, the returned `captures` is empty (calling
37
- * `.get` throws); when true, captures are read back and accessible. */
38
- withCaptures?: boolean;
39
- }
40
- /** Common surface for both training and forward-only compiled runtimes. */
41
- export interface CompiledBase {
42
- /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
43
- * share the device, or use directly for other GPU work. */
44
- device: GPUDevice;
45
- /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
46
- * `sharedParams` to share without copies. */
47
- params: Map<string, GPUBuffer>;
48
- /** Shape of the graph's output (loss scalar `[]` for training; the user's
49
- * returned tensor for forward-only compiles). */
50
- outputShape: number[];
51
- /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
52
- * *all* params to be present; throws on any unknown or missing key. Pass
53
- * `{ partial: true }` to skip the missing-key check. */
54
- uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void;
55
- /** Read all parameters back as Float32Arrays — used for UI panels. */
56
- downloadParams(): Promise<Record<string, Float32Array>>;
57
- /** Free GPU resources. */
58
- destroy(): void;
59
- }
60
- /** Run a dispatch and read back the full output tensor. Default returns the
61
- * output as a `Float32Array`; with `{ withCaptures: true }` returns
62
- * `{ output, captures }`. Same shape as `step()`'s overloads. */
63
- export interface RunFn {
64
- (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
65
- (inputs: Record<string, Int32Array | Float32Array>, opts: {
66
- withCaptures: true;
67
- }): Promise<RunResult>;
68
- (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>;
69
- }
70
- export interface CompiledRuntime extends CompiledBase {
71
- /** Read all parameter gradients back. Mostly for verification / debugging. */
72
- downloadParamGrads(): Promise<Record<string, Float32Array>>;
73
- /**
74
- * One full forward+backward step.
75
- * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
76
- * 2. Dispatches every kernel in order.
77
- * 3. Reads back the loss scalar (and any registered captures, if requested).
78
- * Default returns the loss as a JS number; with `{ withCaptures: true }`
79
- * returns `{ loss, captures }`.
80
- */
81
- step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
82
- step(inputs: Record<string, Int32Array | Float32Array>, opts: {
83
- withCaptures: true;
84
- }): Promise<StepResult>;
85
- step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>;
86
- /** Same dispatch as step() but returns the full output Float32Array — for
87
- * training graphs the output is a scalar loss, so step() is usually more
88
- * convenient. Provided for parity with `compileForward`. */
89
- run: RunFn;
90
- /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
91
- * `uploadInitialParams()` for a full training reset without recompile. */
92
- resetOptimizerState(): void;
93
- }
94
- /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
95
- * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
96
- export interface CompiledForward extends CompiledBase {
97
- run: RunFn;
98
- }
99
- export interface RuntimeOpts {
100
- /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
101
- device?: GPUDevice;
102
- /** External param buffers to bind in place of allocating fresh ones, keyed
103
- * by param name. Used to share params between a training compile and a
104
- * sibling forward-only compile (e.g., a B=1 inference graph). When a name
105
- * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
106
- * allocates as usual. */
107
- sharedParams?: Map<string, GPUBuffer>;
108
- }
109
- export declare function createRuntime(plan: BufferPlan, kernels: KernelSpec[], lossBufferId: number, opts?: RuntimeOpts): Promise<CompiledRuntime>;
110
- /** Same machinery as `createRuntime`, narrower public type: a forward-only
111
- * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
112
- * loss readback). The full runtime object is built once and projected by
113
- * `compileForward` to the public shape. */
114
- export declare function createForwardRuntime(plan: BufferPlan, kernels: KernelSpec[], outputBufferId: number, opts?: RuntimeOpts): Promise<CompiledForward>;
115
- //# sourceMappingURL=runtime.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"runtime.d.ts","sourceRoot":"","sources":["../src/runtime.ts"],"names":[],"mappings":"AAMA,OAAO,KAAK,EAAE,UAAU,EAAE,MAAM,cAAc,CAAA;AAC9C,OAAO,KAAK,EAAE,UAAU,EAAE,MAAM,cAAc,CAAA;AAM9C,MAAM,WAAW,mBAAmB;IAClC;;0EAEsE;IACtE,OAAO,CAAC,EAAE,OAAO,CAAA;CAClB;AAED;;;;;;;GAOG;AACH,qBAAa,QAAQ;IAEjB,OAAO,CAAC,QAAQ,CAAC,MAAM;IACvB,OAAO,CAAC,QAAQ,CAAC,IAAI;gBADJ,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,SAAS,MAAM,EAAE,CAAC,EACzC,IAAI,EAAE,GAAG,CAAC,MAAM,EAAE,YAAY,CAAC;IAElD,GAAG,CAAC,IAAI,EAAE,MAAM,GAAG,YAAY;IAS/B,OAAO,CAAC,IAAI,EAAE,MAAM,GAAG,SAAS,MAAM,EAAE;IAQxC,GAAG,CAAC,IAAI,EAAE,MAAM,GAAG,OAAO;IAC1B,KAAK,IAAI,MAAM,EAAE;CAClB;AAED,MAAM,WAAW,SAAS;IACxB,MAAM,EAAE,YAAY,CAAA;IACpB,QAAQ,EAAE,QAAQ,CAAA;CACnB;AAED,MAAM,WAAW,UAAU;IACzB,IAAI,EAAE,MAAM,CAAA;IACZ,QAAQ,EAAE,QAAQ,CAAA;CACnB;AAED,MAAM,WAAW,UAAU;IACzB;;4EAEwE;IACxE,YAAY,CAAC,EAAE,OAAO,CAAA;CACvB;AAED,2EAA2E;AAC3E,MAAM,WAAW,YAAY;IAC3B;gEAC4D;IAC5D,MAAM,EAAE,SAAS,CAAA;IACjB;kDAC8C;IAC9C,MAAM,EAAE,GAAG,CAAC,MAAM,EAAE,SAAS,CAAC,CAAA;IAC9B;sDACkD;IAClD,WAAW,EAAE,MAAM,EAAE,CAAA;IACrB;;6DAEyD;IACzD,YAAY,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,YAAY,CAAC,EAAE,IAAI,CAAC,EAAE,mBAAmB,GAAG,IAAI,CAAA;IACpF,sEAAsE;IACtE,cAAc,IAAI,OAAO,CAAC,MAAM,CAAC,MAAM,EAAE,YAAY,CAAC,CAAC,CAAA;IACvD,0BAA0B;IAC1B,OAAO,IAAI,IAAI,CAAA;CAChB;AAED;;kEAEkE;AAClE,MAAM,WAAW,KAAK;IACpB,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,GAAG,OAAO,CAAC,YAAY,CAAC,CAAA;IAC1E,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,EAAE,IAAI,EAAE;QAAE,YAAY,EAAE,IAAI,CAAA;KAAE,GAAG,OAAO,CAAC,SAAS,CAAC,CAAA;IACrG,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,EAAE,IAAI,EAAE,UAAU,GAAG,OAAO,CAAC,YAAY,GAAG,SAAS,CAAC,CAAA;CACzG;AAED,MAAM,WAAW,eAAgB,SAAQ,YAAY;IACnD,8EAA8E;IAC9E,kBAAkB,IAAI,OAAO,CAAC,MAAM,CAAC,MAAM,EAAE,YAAY,CAAC,CAAC,CAAA;IAC3D;;;;;;;OAOG;IACH,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,GAAG,OAAO,CAAC,MAAM,CAAC,CAAA;IACxE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,EAAE,IAAI,EAAE;QAAE,YAAY,EAAE,IAAI,CAAA;KAAE,GAAG,OAAO,CAAC,UAAU,CAAC,CAAA;IAC1G,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,MAAM,EAAE,UAAU,GAAG,YAAY,CAAC,EAAE,IAAI,EAAE,UAAU,GAAG,OAAO,CAAC,MAAM,GAAG,UAAU,CAAC,CAAA;IACvG;;iEAE6D;IAC7D,GAAG,EAAE,KAAK,CAAA;IACV;+EAC2E;IAC3E,mBAAmB,IAAI,IAAI,CAAA;CAC5B;AAED;mFACmF;AACnF,MAAM,WAAW,eAAgB,SAAQ,YAAY;IACnD,GAAG,EAAE,KAAK,CAAA;CACX;AAED,MAAM,WAAW,WAAW;IAC1B,oEAAoE;IACpE,MAAM,CAAC,EAAE,SAAS,CAAA;IAClB;;;;8BAI0B;IAC1B,YAAY,CAAC,EAAE,GAAG,CAAC,MAAM,EAAE,SAAS,CAAC,CAAA;CACtC;AAQD,wBAAsB,aAAa,CACjC,IAAI,EAAE,UAAU,EAChB,OAAO,EAAE,UAAU,EAAE,EACrB,YAAY,EAAE,MAAM,EACpB,IAAI,GAAE,WAAgB,GACrB,OAAO,CAAC,eAAe,CAAC,CA4V1B;AAED;;;4CAG4C;AAC5C,wBAAsB,oBAAoB,CACxC,IAAI,EAAE,UAAU,EAChB,OAAO,EAAE,UAAU,EAAE,EACrB,cAAc,EAAE,MAAM,EACtB,IAAI,GAAE,WAAgB,GACrB,OAAO,CAAC,eAAe,CAAC,CAE1B"}
package/dist/shape.d.ts DELETED
@@ -1,24 +0,0 @@
1
- import type { Shape, CallSite } from './ir.js';
2
- export declare class ShapeError extends Error {
3
- constructor(message: string, site: CallSite | null);
4
- }
5
- export declare function shapesEqual(a: Shape, b: Shape): boolean;
6
- export declare function shapeSize(shape: Shape): number;
7
- export declare function showShape(shape: Shape): string;
8
- export declare function broadcastTrailing(a: Shape, b: Shape): Shape | null;
9
- export declare function inferElementwiseBinop(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape;
10
- export declare function inferUnary(_opName: string, aShape: Shape, _site: CallSite | null): Shape;
11
- export declare function inferMeanLast(opName: string, aShape: Shape, site: CallSite | null): Shape;
12
- export declare function inferSumLast(opName: string, aShape: Shape, site: CallSite | null): Shape;
13
- export declare function inferReshape(opName: string, aShape: Shape, newShape: Shape, site: CallSite | null): Shape;
14
- export declare function inferTranspose(opName: string, aShape: Shape, perm: readonly number[], site: CallSite | null): Shape;
15
- export declare function inferMatmul(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape;
16
- export declare function inferMatmulBatched(opName: string, aShape: Shape, bShape: Shape, site: CallSite | null): Shape;
17
- export declare function inferOneHot(opName: string, indicesShape: Shape, depth: number, site: CallSite | null): Shape;
18
- export declare function inferWhereCausal(opName: string, aShape: Shape, site: CallSite | null): Shape;
19
- export declare function inferSliceLastRange(opName: string, aShape: Shape, start: number, end: number, site: CallSite | null): Shape;
20
- export declare function inferBroadcastTo(opName: string, aShape: Shape, targetShape: Shape, site: CallSite | null): Shape;
21
- export declare function inferSumToShape(opName: string, aShape: Shape, targetShape: Shape, site: CallSite | null): Shape;
22
- export declare function inferWhere(opName: string, condShape: Shape, aShape: Shape, bShape: Shape, site: CallSite | null): Shape;
23
- export declare function inferReluGrad(opName: string, xShape: Shape, dyShape: Shape, site: CallSite | null): Shape;
24
- //# sourceMappingURL=shape.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"shape.d.ts","sourceRoot":"","sources":["../src/shape.ts"],"names":[],"mappings":"AAkBA,OAAO,KAAK,EAAE,KAAK,EAAE,QAAQ,EAAE,MAAM,SAAS,CAAA;AAO9C,qBAAa,UAAW,SAAQ,KAAK;gBACvB,OAAO,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI;CAKnD;AAUD,wBAAgB,WAAW,CAAC,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,KAAK,GAAG,OAAO,CAIvD;AAED,wBAAgB,SAAS,CAAC,KAAK,EAAE,KAAK,GAAG,MAAM,CAI9C;AAED,wBAAgB,SAAS,CAAC,KAAK,EAAE,KAAK,GAAG,MAAM,CAE9C;AAKD,wBAAgB,iBAAiB,CAAC,CAAC,EAAE,KAAK,EAAE,CAAC,EAAE,KAAK,GAAG,KAAK,GAAG,IAAI,CAclE;AASD,wBAAgB,qBAAqB,CACnC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAClE,KAAK,CAWP;AAED,wBAAgB,UAAU,CAAC,OAAO,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAExF;AAED,wBAAgB,aAAa,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAIzF;AAED,wBAAgB,YAAY,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAIxF;AAED,wBAAgB,YAAY,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,QAAQ,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CA0BzG;AAED,wBAAgB,cAAc,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,SAAS,MAAM,EAAE,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAWnH;AAGD,wBAAgB,WAAW,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAStG;AAGD,wBAAgB,kBAAkB,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAoB7G;AAED,wBAAgB,WAAW,CAAC,MAAM,EAAE,MAAM,EAAE,YAAY,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAG5G;AAGD,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAM5F;AAED,wBAAgB,mBAAmB,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAO3H;AAID,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,WAAW,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAahH;AAID,wBAAgB,eAAe,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,WAAW,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAa/G;AAID,wBAAgB,UAAU,CAAC,MAAM,EAAE,MAAM,EAAE,SAAS,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAMvH;AAED,wBAAgB,aAAa,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,KAAK,EAAE,IAAI,EAAE,QAAQ,GAAG,IAAI,GAAG,KAAK,CAKzG"}
package/dist/trace.d.ts DELETED
@@ -1,9 +0,0 @@
1
- import type { Graph, Tensor, Shape, Dtype } from './ir.js';
2
- export declare function currentGraph(): Graph;
3
- export declare function isCaptureEnabled(): boolean;
4
- export declare function trace(fn: () => Tensor | Tensor[]): Graph;
5
- export declare function traceInto<T>(g: Graph, fn: () => T): T;
6
- export declare function paramInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
7
- export declare function tensorInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
8
- export declare function stateInput(name: string, shape: Shape, dtype?: Dtype, initValue?: number): Tensor;
9
- //# sourceMappingURL=trace.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"trace.d.ts","sourceRoot":"","sources":["../src/trace.ts"],"names":[],"mappings":"AAgBA,OAAO,KAAK,EAAE,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,MAAM,SAAS,CAAA;AAU1D,wBAAgB,YAAY,IAAI,KAAK,CAQpC;AAED,wBAAgB,gBAAgB,IAAI,OAAO,CAE1C;AAID,wBAAgB,KAAK,CAAC,EAAE,EAAE,MAAM,MAAM,GAAG,MAAM,EAAE,GAAG,KAAK,CAkBxD;AAQD,wBAAgB,SAAS,CAAC,CAAC,EAAE,CAAC,EAAE,KAAK,EAAE,EAAE,EAAE,MAAM,CAAC,GAAG,CAAC,CAWrD;AAOD,wBAAgB,UAAU,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOnF;AAED,wBAAgB,WAAW,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,GAAE,KAAa,GAAG,MAAM,CAOpF;AAID,wBAAgB,UAAU,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAE,KAAK,GAAE,KAAa,EAAE,SAAS,SAAI,GAAG,MAAM,CAOlG"}