tensorgrad 0.0.12 → 0.0.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. package/dist/buffers.js +1 -6
  2. package/dist/buffers.js.map +1 -1
  3. package/dist/codegen.js +30 -28
  4. package/dist/codegen.js.map +1 -1
  5. package/dist/compile.js +39 -68
  6. package/dist/compile.js.map +1 -1
  7. package/dist/grad.js +1 -14
  8. package/dist/grad.js.map +1 -1
  9. package/dist/index.d.ts +740 -14
  10. package/dist/runtime.js +6 -9
  11. package/dist/runtime.js.map +1 -1
  12. package/dist/trace.js +8 -13
  13. package/dist/trace.js.map +1 -1
  14. package/package.json +9 -3
  15. package/src/buffers.ts +1 -6
  16. package/src/codegen.ts +31 -28
  17. package/src/compile.ts +312 -358
  18. package/src/grad.ts +1 -11
  19. package/src/runtime.ts +6 -9
  20. package/src/trace.ts +12 -9
  21. package/dist/adam.d.ts +0 -65
  22. package/dist/adam.d.ts.map +0 -1
  23. package/dist/buffers.d.ts +0 -57
  24. package/dist/buffers.d.ts.map +0 -1
  25. package/dist/capture.d.ts +0 -3
  26. package/dist/capture.d.ts.map +0 -1
  27. package/dist/codegen.d.ts +0 -23
  28. package/dist/codegen.d.ts.map +0 -1
  29. package/dist/compile.d.ts +0 -130
  30. package/dist/compile.d.ts.map +0 -1
  31. package/dist/grad.d.ts +0 -8
  32. package/dist/grad.d.ts.map +0 -1
  33. package/dist/index.d.ts.map +0 -1
  34. package/dist/ir.d.ts +0 -207
  35. package/dist/ir.d.ts.map +0 -1
  36. package/dist/module.d.ts +0 -55
  37. package/dist/module.d.ts.map +0 -1
  38. package/dist/nn.d.ts +0 -42
  39. package/dist/nn.d.ts.map +0 -1
  40. package/dist/ops.d.ts +0 -48
  41. package/dist/ops.d.ts.map +0 -1
  42. package/dist/runtime.d.ts +0 -115
  43. package/dist/runtime.d.ts.map +0 -1
  44. package/dist/shape.d.ts +0 -24
  45. package/dist/shape.d.ts.map +0 -1
  46. package/dist/trace.d.ts +0 -9
  47. package/dist/trace.d.ts.map +0 -1
package/dist/index.d.ts CHANGED
@@ -1,14 +1,740 @@
- export type { Tensor, Shape, Dtype, OpNode, Graph, CallSite } from './ir.js';
- export { ShapeError } from './shape.js';
- export { trace, traceInto, paramInput, tensorInput, stateInput } from './trace.js';
- export { capture } from './capture.js';
- export { add, sub, mul, div, sqrt, rsqrt, log, exp, relu, less, greater, where, meanLast, sumLast, sumAll, reshape, transpose, swapAxes, matmul, matmulBatched, oneHot, arange, embedding, softmaxCausalLast, logSoftmaxLast, whereCausal, sliceLastRange, } from './ops.js';
- export { appendGrad, type GradResult } from './grad.js';
- export { appendAdam, type AdamConfig, type AdamResult } from './adam.js';
- export { planBuffers, type BufferPlan, type BufferSpec, type Writeback, type WritebackDecl } from './buffers.js';
- export { emitKernels, type KernelSpec } from './codegen.js';
- export { createRuntime, createForwardRuntime, Captures, type CompiledRuntime, type CompiledForward, type RuntimeOpts, type RunOptions, type StepResult, type RunResult } from './runtime.js';
- export { compile, compileToIR, compileModule, compileForward, type CompiledIR, type CompileModuleOptions, type CompileForwardOptions, type CompileForwardMethodOptions, type CompiledModule, type CompiledForwardModule, type InputDecl, type InputDecls, type InputsTensors, type ForwardFn, } from './compile.js';
- export { Module, materializeParams, type InitSpec, type ParamOptions, type MaterializedParams } from './module.js';
- export * as nn from './nn.js';
- //# sourceMappingURL=index.d.ts.map
+ type Dtype = 'f32' | 'i32' | 'bool';
+ type Shape = readonly number[];
+ interface Tensor {
+ readonly id: number;
+ readonly shape: Shape;
+ readonly dtype: Dtype;
+ readonly source: number | null;
+ readonly site: CallSite | null;
+ }
+ interface CallSite {
+ readonly opName: string;
+ readonly stack: string;
+ }
+ type OpNode = {
+ kind: 'param_input';
+ out: number;
+ name: string;
+ } | {
+ kind: 'tensor_input';
+ out: number;
+ name: string;
+ } | {
+ kind: 'state_input';
+ out: number;
+ name: string;
+ initValue: number;
+ } | {
+ kind: 'add';
+ out: number;
+ a: number;
+ b: number;
+ } | {
+ kind: 'sub';
+ out: number;
+ a: number;
+ b: number;
+ } | {
+ kind: 'mul';
+ out: number;
+ a: number;
+ b: number;
+ } | {
+ kind: 'div';
+ out: number;
+ a: number;
+ b: number;
+ } | {
+ kind: 'mul_scalar';
+ out: number;
+ a: number;
+ scalar: number;
+ } | {
+ kind: 'add_scalar';
+ out: number;
+ a: number;
+ scalar: number;
+ } | {
+ kind: 'sqrt';
+ out: number;
+ a: number;
+ } | {
+ kind: 'rsqrt';
+ out: number;
+ a: number;
+ } | {
+ kind: 'log';
+ out: number;
+ a: number;
+ } | {
+ kind: 'exp';
+ out: number;
+ a: number;
+ } | {
+ kind: 'relu';
+ out: number;
+ a: number;
+ } | {
+ kind: 'mean_last';
+ out: number;
+ a: number;
+ } | {
+ kind: 'sum_last';
+ out: number;
+ a: number;
+ } | {
+ kind: 'reshape';
+ out: number;
+ a: number;
+ newShape: Shape;
+ } | {
+ kind: 'transpose';
+ out: number;
+ a: number;
+ perm: readonly number[];
+ } | {
+ kind: 'matmul';
+ out: number;
+ a: number;
+ b: number;
+ } | {
+ kind: 'matmul_batched';
+ out: number;
+ a: number;
+ b: number;
+ } | {
+ kind: 'one_hot';
+ out: number;
+ indices: number;
+ depth: number;
+ dtype: Dtype;
+ } | {
+ kind: 'arange';
+ out: number;
+ n: number;
+ dtype: Dtype;
+ } | {
+ kind: 'softmax_causal_last';
+ out: number;
+ a: number;
+ } | {
+ kind: 'log_softmax_last';
+ out: number;
+ a: number;
+ } | {
+ kind: 'where_causal';
+ out: number;
+ a: number;
+ fillValue: number;
+ } | {
+ kind: 'less';
+ out: number;
+ a: number;
+ b: number;
+ } | {
+ kind: 'greater';
+ out: number;
+ a: number;
+ b: number;
+ } | {
+ kind: 'where';
+ out: number;
+ cond: number;
+ a: number;
+ b: number;
+ } | {
+ kind: 'adam_update_m';
+ out: number;
+ m: number;
+ g: number;
+ b1: number;
+ } | {
+ kind: 'adam_update_v';
+ out: number;
+ v: number;
+ g: number;
+ b2: number;
+ } | {
+ kind: 'adam_update_p';
+ out: number;
+ p: number;
+ mNew: number;
+ vNew: number;
+ lrt: number;
+ eps: number;
+ decayShrink: number;
+ decayShrinkTensor: number | null;
+ } | {
+ kind: 'slice_last_range';
+ out: number;
+ a: number;
+ start: number;
+ end: number;
+ } | {
+ kind: 'broadcast_to';
+ out: number;
+ a: number;
+ targetShape: Shape;
+ } | {
+ kind: 'sum_to_shape';
+ out: number;
+ a: number;
+ targetShape: Shape;
+ } | {
+ kind: 'const_scalar';
+ out: number;
+ value: number;
+ dtype: Dtype;
+ } | {
+ kind: 'relu_grad';
+ out: number;
+ x: number;
+ dy: number;
+ };
+ interface Graph {
+ readonly ops: OpNode[];
+ readonly tensors: Tensor[];
+ readonly outputs: number[];
+ readonly captures: Map<string, number>;
+ }
+
+ declare class ShapeError extends Error {
+ constructor(message: string, site: CallSite | null);
+ }
+
+ declare function trace(fn: () => Tensor | Tensor[]): Graph;
+ declare function traceInto<T>(g: Graph, fn: () => T): T;
+ declare function paramInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
+ declare function tensorInput(name: string, shape: Shape, dtype?: Dtype): Tensor;
+ declare function stateInput(name: string, shape: Shape, dtype?: Dtype, initValue?: number): Tensor;
+
+ declare function capture<T extends Tensor>(name: string, t: T): T;
+
+ declare function add(a: Tensor, b: Tensor | number): Tensor;
+ declare function sub(a: Tensor, b: Tensor | number): Tensor;
+ declare function mul(a: Tensor, b: Tensor | number): Tensor;
+ declare function div(a: Tensor, b: Tensor | number): Tensor;
+ declare const sqrt: (a: Tensor) => Tensor;
+ declare const rsqrt: (a: Tensor) => Tensor;
+ declare const log: (a: Tensor) => Tensor;
+ declare const exp: (a: Tensor) => Tensor;
+ declare const relu: (a: Tensor) => Tensor;
+ declare function meanLast(a: Tensor): Tensor;
+ declare function sumLast(a: Tensor): Tensor;
+ /** Reduce all elements to a 0-d scalar. Composes `reshape` + `sumLast`. */
+ declare function sumAll(a: Tensor): Tensor;
+ declare function reshape(a: Tensor, newShape: Shape): Tensor;
+ declare function transpose(a: Tensor, perm: readonly number[]): Tensor;
+ /** Swap two axes of a tensor. Negative indices count from the end (so
+ * `swapAxes(x, -1, -2)` swaps the last two — the common attention pattern).
+ * All other axes keep their position. Implemented as `transpose` with the
+ * permutation `[0, 1, ..., axis2, ..., axis1, ..., n-1]`. */
+ declare function swapAxes(a: Tensor, axis1: number, axis2: number): Tensor;
+ declare function matmul(a: Tensor, b: Tensor): Tensor;
+ declare function matmulBatched(a: Tensor, b: Tensor): Tensor;
+ declare function oneHot(indices: Tensor, depth: number, dtype?: Dtype): Tensor;
+ /** Embedding lookup: pull rows from `table` indexed by `indices`. Decomposes
+ * to `oneHot(indices, vocab) @ table` so autograd works without a dedicated
+ * scatter-with-atomic-add backward — the matmul transpose rule handles it.
+ * `table` is `[vocab, dim]`; `indices` is any shape `[...]` of i32; result
+ * is `[..., dim]`. The vocab size is taken from `table.shape[0]`. */
+ declare function embedding(table: Tensor, indices: Tensor): Tensor;
+ declare function arange(n: number, dtype?: Dtype): Tensor;
+ declare function softmaxCausalLast(a: Tensor): Tensor;
+ declare function logSoftmaxLast(a: Tensor): Tensor;
+ declare function whereCausal(a: Tensor, fillValue: number): Tensor;
+ declare function sliceLastRange(a: Tensor, start: number, end: number): Tensor;
+ declare const less: (a: Tensor, b: Tensor) => Tensor;
+ declare const greater: (a: Tensor, b: Tensor) => Tensor;
+ declare function where(cond: Tensor, a: Tensor, b: Tensor): Tensor;
+
+ interface GradResult {
+ readonly graph: Graph;
+ readonly paramGrads: Record<string, Tensor>;
+ readonly loss: Tensor;
+ }
+ declare function appendGrad(graph: Graph): GradResult;
+
+ interface BufferSpec {
+ /** Matches tensor.id. */
+ id: number;
+ byteSize: number;
+ dtype: Dtype;
+ shape: Shape;
+ kind: 'param' | 'param_grad' | 'tensor_input' | 'state' | 'intermediate' | 'output';
+ /** External name for param/param_grad/tensor_input/state bindings. null otherwise. */
+ name: string | null;
+ /** For state buffers: the value to fill on initial allocation. 0 by default. */
+ initValue?: number;
+ }
+ /**
+ * After step(), copy `source`'s buffer into `dest`'s buffer.
+ * Used to write back updated optimizer state and updated parameters into
+ * their persistent home buffers.
+ */
+ interface Writeback {
+ source: number;
+ dest: number;
+ bytes: number;
+ }
+ interface BufferPlan {
+ buffers: BufferSpec[];
+ /** Tensor id -> buffer id (currently 1:1 but kept opaque for future pooling). */
+ tensorToBuffer: Map<number, number>;
+ /** Easy lookup tables for the runtime. */
+ paramsByName: Map<string, number>;
+ inputsByName: Map<string, number>;
+ paramGradsByName: Map<string, number>;
+ statesByName: Map<string, number>;
+ capturesByName: Map<string, number>;
+ outputBufferIds: number[];
+ /** End-of-step writebacks (Adam updates for params, m, v, etc.) */
+ writebacks: Writeback[];
+ }
+ /**
+ * Caller-supplied writeback declarations: "after each step, copy this Tensor's
+ * buffer into the persistent home of this param/state."
+ */
+ interface WritebackDecl {
+ /** The Tensor (output of some op) holding the new value to write back. */
+ source: Tensor;
+ /** Either a param name (writes to that param's home buffer) or a state name. */
+ destName: string;
+ destKind: 'param' | 'state';
+ }
+ /**
+ * Build a BufferPlan from a graph + the param-grad map produced by appendGrad.
+ * @param graph the full graph (forward + backward + any optimizer ops)
+ * @param paramGrads map from param name -> the Tensor that holds its gradient
+ * @param writebackDecls list of end-of-step writebacks (e.g. from appendAdam).
+ * Empty when there's no optimizer in the graph.
+ */
+ declare function planBuffers(graph: Graph, paramGrads: Record<string, Tensor>, writebackDecls?: WritebackDecl[]): BufferPlan;
+
+ interface AdamConfig {
+ /** Constant scalar (e.g., `0.005`) or a per-step schedule function
+ * `(step) => lr`. Schedule fn lets the user implement linear/cosine decay
+ * or warmup; first call passes `step=1`. Decay-shrink (AdamW) updates
+ * per-step automatically when this is a function. */
+ lr: number | ((step: number) => number);
+ b1?: number;
+ b2?: number;
+ eps?: number;
+ /** AdamW: decoupled weight decay coefficient. Default 0 (plain Adam).
+ * When non-zero, every step shrinks each decayed param by a factor of
+ * `1 - lr * weightDecay` before the gradient update. */
+ weightDecay?: number;
+ /** Filter deciding which params get weight decay. Only consulted when
+ * weightDecay > 0. Default: decay every param. Override for the standard
+ * transformer convention (decay weights/embeddings, skip biases + LN gains).
+ * Example: `(name) => name.includes('.W') || name.endsWith('_emb')`. */
+ decayFilter?: (paramName: string) => boolean;
+ }
+ /** Resolved hyperparameters: lr is the schedule fn (constants are wrapped). */
+ interface AdamResolvedConfig {
+ lr: (step: number) => number;
+ b1: number;
+ b2: number;
+ eps: number;
+ weightDecay: number;
+ decayFilter: (name: string) => boolean;
+ /** True iff the user supplied an lr function (vs a constant). When false,
+ * decayShrink is baked at compile time and never updated. */
+ lrIsScheduled: boolean;
+ }
+ interface AdamResult {
+ /** Writebacks the buffer planner should wire into the runtime. */
+ writebacks: WritebackDecl[];
+ /** Name of the per-step scalar tensor_input. The runtime fills this each call
+ * with `lr * sqrt(1-b2^t)/(1-b1^t)` (Adam's bias-corrected effective LR). */
+ lrtInputName: string;
+ /** Name of the per-step decayShrink scalar tensor_input, or null when lr is
+ * static (decayShrink baked into the kernel) or no params are decayed. */
+ decayShrinkInputName: string | null;
+ /** Hyperparameters as captured (so the runtime can compute lrt and decayShrink). */
+ config: AdamResolvedConfig;
+ }
+ /**
+ * Append Adam update ops to `graph`. Must be called inside an active trace
+ * context (or after a trace, since traceInto re-enters the graph).
+ *
+ * @param graph the graph (already containing forward + backward)
+ * @param paramGrads param name -> gradient tensor (output of `appendGrad`)
+ * @param paramTensors param name -> the param's leaf Tensor (the param_input).
+ * Needed because the param_input lives in the graph but we
+ * don't have a direct map by name in `Graph` — caller passes it.
+ * @param config Adam hyperparameters. Set `weightDecay > 0` for AdamW; an
+ * optional `decayFilter` selects which params receive decay.
+ */
+ declare function appendAdam(graph: Graph, paramGrads: Record<string, Tensor>, paramTensors: Record<string, Tensor>, config: AdamConfig,
+ /** Per-param decay flags from `materializeParams`. When supplied, overrides
+ * `config.decayFilter` for any name in the map; falls back to `decayFilter`
+ * for names not present (e.g., for low-level callers using `compile()`
+ * directly without a Module). */
+ decayFlags?: Record<string, boolean>): AdamResult;
+
+ interface KernelSpec {
+ /** Index into graph.ops. */
+ opIndex: number;
+ /** Op kind (for debugging / pipeline cache key). */
+ opKind: OpNode['kind'];
+ /** Generated WGSL source. Empty string for "logical" ops with no kernel. */
+ wgsl: string;
+ /**
+ * Buffer ids in binding-index order. The runtime creates a bind group with
+ * these in @binding(0..N) on @group(0). Inputs come first (read), output last
+ * (read_write).
+ */
+ bindings: number[];
+ /** Number of threads to dispatch (1-D). 0 means "skip" (e.g. reshape no-op). */
+ threads: number;
+ /** Workgroup size; usually WG_SIZE. */
+ workgroupSize: number;
+ }
+ /** Generate a KernelSpec per compute op in graph.ops (in dispatch order). */
+ declare function emitKernels(graph: Graph, plan: BufferPlan): KernelSpec[];
+
+ interface UploadParamsOptions {
+ /** Skip the "missing param" check, allowing the caller to update only some
+ * params and leave the rest at their current GPU values. Extra (unknown)
+ * keys are still rejected — that's always a typo. Default: false. */
+ partial?: boolean;
+ }
+ /**
+ * Activation readbacks for one `step()`/`run()` call. Keyed by the names
+ * passed to `capture(name, t)` during the trace. `get(name)` throws if the
+ * name isn't registered or wasn't read back this call (i.e., the call was
+ * made without `{ withCaptures: true }`); use `has(name)` if you need to
+ * branch. `shapeOf(name)` returns the static-after-compile shape and works
+ * regardless of whether captures were read back.
+ */
+ declare class Captures {
+ private readonly shapes;
+ private readonly data;
+ constructor(shapes: Record<string, readonly number[]>, data: Map<string, Float32Array>);
+ get(name: string): Float32Array;
+ shapeOf(name: string): readonly number[];
+ has(name: string): boolean;
+ names(): string[];
+ }
+ interface RunResult {
+ output: Float32Array;
+ captures: Captures;
+ }
+ interface StepResult {
+ loss: number;
+ captures: Captures;
+ }
+ interface RunOptions {
+ /** Read back tensors registered via `capture(name, t)` during the trace.
+ * Default false. When false, the returned `captures` is empty (calling
+ * `.get` throws); when true, captures are read back and accessible. */
+ withCaptures?: boolean;
+ }
+ /** Common surface for both training and forward-only compiled runtimes. */
+ interface CompiledBase {
+ /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
+ * share the device, or use directly for other GPU work. */
+ device: GPUDevice;
+ /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
+ * `sharedParams` to share without copies. */
+ params: Map<string, GPUBuffer>;
+ /** Shape of the graph's output (loss scalar `[]` for training; the user's
+ * returned tensor for forward-only compiles). */
+ outputShape: number[];
+ /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
+ * *all* params to be present; throws on any unknown or missing key. Pass
+ * `{ partial: true }` to skip the missing-key check. */
+ uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void;
+ /** Read all parameters back as Float32Arrays — used for UI panels. */
+ downloadParams(): Promise<Record<string, Float32Array>>;
+ /** Free GPU resources. */
+ destroy(): void;
+ }
+ /** Run a dispatch and read back the full output tensor. Default returns the
+ * output as a `Float32Array`; with `{ withCaptures: true }` returns
+ * `{ output, captures }`. Same shape as `step()`'s overloads. */
+ interface RunFn {
+ (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
+ (inputs: Record<string, Int32Array | Float32Array>, opts: {
+ withCaptures: true;
+ }): Promise<RunResult>;
+ (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>;
+ }
+ interface CompiledRuntime extends CompiledBase {
+ /** Read all parameter gradients back. Mostly for verification / debugging. */
+ downloadParamGrads(): Promise<Record<string, Float32Array>>;
+ /**
+ * One full forward+backward step.
+ * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
+ * 2. Dispatches every kernel in order.
+ * 3. Reads back the loss scalar (and any registered captures, if requested).
+ * Default returns the loss as a JS number; with `{ withCaptures: true }`
+ * returns `{ loss, captures }`.
+ */
+ step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: {
+ withCaptures: true;
+ }): Promise<StepResult>;
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>;
+ /** Same dispatch as step() but returns the full output Float32Array — for
+ * training graphs the output is a scalar loss, so step() is usually more
+ * convenient. Provided for parity with `compileForward`. */
+ run: RunFn;
+ /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
+ * `uploadInitialParams()` for a full training reset without recompile. */
+ resetOptimizerState(): void;
+ }
+ /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
+ * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
+ interface CompiledForward extends CompiledBase {
+ run: RunFn;
+ }
+ interface RuntimeOpts {
+ /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
+ device?: GPUDevice;
+ /** External param buffers to bind in place of allocating fresh ones, keyed
+ * by param name. Used to share params between a training compile and a
+ * sibling forward-only compile (e.g., a B=1 inference graph). When a name
+ * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
+ * allocates as usual. */
+ sharedParams?: Map<string, GPUBuffer>;
+ }
+ declare function createRuntime(plan: BufferPlan, kernels: KernelSpec[], lossBufferId: number, opts?: RuntimeOpts): Promise<CompiledRuntime>;
+ /** Same machinery as `createRuntime`, narrower public type: a forward-only
+ * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
+ * loss readback). The full runtime object is built once and projected by
+ * `compileForward` to the public shape. */
+ declare function createForwardRuntime(plan: BufferPlan, kernels: KernelSpec[], outputBufferId: number, opts?: RuntimeOpts): Promise<CompiledForward>;
+
+ /** How a parameter's initial values are produced.
+ * - `'randn'` — Gaussian, with `scale` (default 0.02). The common case for
+ * weight matrices and embeddings.
+ * - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
+ * - `'ones'` — fill with 1. Common for LayerNorm gain.
+ * - Custom function — receives total element count and shape, returns the
+ * Float32Array. Use for fan-in scaling or any non-standard scheme.
+ */
+ type InitSpec = 'randn' | 'zeros' | 'ones' | ((size: number, shape: readonly number[]) => Float32Array);
+ interface ParamOptions {
+ dtype?: Dtype;
+ /** Init kind. Default: `'randn'`. */
+ init?: InitSpec;
+ /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
+ scale?: number;
+ /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
+ * decay to this param. Default: `true` for `'randn'` init (weight matrices,
+ * embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
+ * to force or skip. Replaces `adam.decayFilter` for the common case. */
+ decay?: boolean;
+ }
+ type InitFn = (size: number, shape: readonly number[]) => Float32Array;
+ declare abstract class Module {
+ /**
+ * Declare a learnable parameter at this module. Must be called from inside
+ * the constructor (typically as a field assignment). Returns a placeholder
+ * that gets replaced with a real Tensor at compile time.
+ *
+ * The parameter's name is auto-derived from its property path in the model
+ * tree (e.g. `layers.0.attn.W_q`). Init metadata travels with the param;
+ * call `compiled.uploadInitialParams()` to apply it after compile.
+ */
+ protected param(shape: Shape, opts?: ParamOptions): Tensor;
+ }
+ interface MaterializedParams {
+ /** Map from auto-derived path (e.g. `layers.0.attn.W_q`) to its Tensor. */
+ tensors: Record<string, Tensor>;
+ /** Init function per param path. Used by `uploadInitialParams`. */
+ initFns: Record<string, InitFn>;
+ /** Whether this param should receive AdamW weight decay. Resolved at
+ * `param()` time from `ParamOptions.decay` (with init-based default). */
+ decayFlags: Record<string, boolean>;
+ }
+ /**
+ * Walk the module tree and replace every ParamSentinel with a real Tensor
+ * created via `paramInput(autoName, ...)`. Must be called inside an active
+ * trace context (paramInput appends to the current graph).
+ *
+ * Returns the param tensors keyed by path, plus init functions for use by
+ * `uploadInitialParams`.
+ */
+ declare function materializeParams(root: Module): MaterializedParams;
+
+ /** Declares one input tensor of the model's forward function. The name is the
+ * key in the `inputs:` Record at compile time and the key on the `step()`/
+ * `run()` data object at runtime. */
+ interface InputDecl {
+ shape: Shape;
+ dtype?: Dtype;
+ }
+ /** Inputs declaration: a Record from input name to its shape/dtype. The name
+ * doubles as the key the forward fn destructures and the key the runtime
+ * expects in `step({...})` / `run({...})`. */
+ type InputDecls = Record<string, InputDecl>;
+ /** Maps an `InputDecls` Record to its forward-time tensor counterpart —
+ * same keys, each value is a Tensor. Used to type the forward function's
+ * `inputs` argument from the declared shape Record. */
+ type InputsTensors<I extends InputDecls> = {
+ [K in keyof I]: Tensor;
+ };
+ /** Forward function shape: takes the materialized model and a Record of
+ * named input tensors (matching the declared `inputs:` keys), returns the
+ * output tensor (loss for compileModule; logits/etc. for compileForward).
+ * The second generic flows from the inputs declaration so destructuring
+ * the input record stays typed. */
+ type ForwardFn<M extends Module, I extends InputDecls = InputDecls> = (m: M, inputs: InputsTensors<I>) => Tensor;
+ interface CompiledIR {
+ graph: GradResult['graph'];
+ paramGrads: GradResult['paramGrads'];
+ loss: Tensor;
+ plan: BufferPlan;
+ kernels: KernelSpec[];
+ }
+ /** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
+ declare function compileToIR(traceFn: () => Tensor): CompiledIR;
+ /** Full compile pipeline. Browser-only because it creates a GPUDevice. */
+ declare function compile(traceFn: () => Tensor, opts?: RuntimeOpts): Promise<CompiledRuntime & {
+ ir: CompiledIR;
+ }>;
+ interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
+ /** Per-step data inputs to the forward function, keyed by name. The forward
+ * fn destructures these out of its second argument; runtime calls to
+ * `step()` / `run()` pass typed arrays under the same keys. */
+ inputs?: I;
+ /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
+ adam?: AdamConfig;
+ }
+ interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
+ /** Per-step data inputs to the forward function, keyed by name. */
+ inputs?: I;
+ }
+ /** Forward-only compile options as taken by the `compileForward` *method* on
+ * a training runtime — no `device` (inherited) and no `sharedParams`
+ * (auto-supplied from the train graph's params). */
+ interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
+ inputs?: I;
+ }
+ /** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
+ * sibling-graph compile) on top of the base runtime. */
+ interface CompiledModule<M extends Module> extends CompiledRuntime {
+ ir: CompiledIR;
+ /** Number of dispatchable kernels (excludes leaf no-ops). */
+ kernelCount: number;
+ /** Re-initialize all params from their declared init specs and zero the
+ * optimizer state. Use to start training over without recompiling. */
+ reset(): void;
+ /** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
+ * B=N held-out eval graph) that shares this runtime's device and param
+ * buffers. Pass the forward fn (typically distinct from your loss fn —
+ * it returns logits, not a scalar) and any shape changes via `inputs`.
+ * Auto-initialization is a no-op since params are shared. */
+ compileForward<I extends InputDecls>(forward: ForwardFn<M, I>, opts?: CompileForwardMethodOptions<I>): Promise<CompiledForwardModule>;
+ }
+ /** Returned by `compileForward` (and by the `compileForward` method). */
+ interface CompiledForwardModule extends CompiledForward {
+ ir: CompiledIR;
+ /** Number of dispatchable kernels (excludes leaf no-ops). */
+ kernelCount: number;
+ }
+ /**
+ * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
+ * model instance itself: compilation mutates the tree (every `ParamSentinel`
+ * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
+ * referenced afterwards. Re-call the factory if you need a fresh tree.
+ *
+ * The forward function takes the materialized model and a Record of named
+ * input tensors, returns the loss tensor. Inputs are matched by name with the
+ * `inputs:` declaration:
+ *
+ * inputs: {
+ * tokens: { shape: [B, T], dtype: 'i32' },
+ * targets: { shape: [B, T], dtype: 'i32' },
+ * }
+ * forward: (m, { tokens, targets }) => …
+ *
+ * Walks the module tree to materialize params with auto-derived names, then
+ * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
+ * parameter values are uploaded automatically before this function returns;
+ * call `reset()` later to re-randomize.
+ *
+ * If `opts.adam` is set, the runtime's `step()` automatically tracks an
+ * internal step count and injects the bias-corrected `lrt` scalar each call;
+ * users don't need to provide it themselves.
+ */
+ declare function compileModule<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileModuleOptions<I>): Promise<CompiledModule<M>>;
+ /**
+ * Compile a Module-based model in forward-only mode (no autograd, no Adam).
+ * The forward function returns the output tensor (e.g., logits) instead of a
+ * scalar loss; runtime exposes `run(inputs)` returning the full output as a
+ * `Float32Array`.
+ *
+ * **Prefer the `compileForward` method on a training runtime** when both
+ * graphs use the same Module class — it auto-supplies `device` and
+ * `sharedParams`. This standalone form is for forward-only models with no
+ * training graph at all, or for sharing params across a different model.
+ *
+ * **Sharing params with a training compile.** Pass `opts.sharedParams =
+ * trainCompiled.params` to bind this graph's param buffers to an existing
+ * training runtime's GPU buffers — every train step is then immediately
+ * visible to `run()` calls here, no copies.
+ *
+ * Initial param values are uploaded automatically for params *not* covered
+ * by `sharedParams` (those are owned by the sibling compile).
+ */
+ declare function compileForward<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileForwardOptions<I>): Promise<CompiledForwardModule>;
+
+ interface LinearOptions {
+ /** Include a bias term (default true). */
+ bias?: boolean;
+ }
+ declare class Linear extends Module {
+ readonly inDim: number;
+ readonly outDim: number;
+ W: Tensor;
+ b: Tensor | null;
+ constructor(inDim: number, outDim: number, opts?: LinearOptions);
+ fwd(x: Tensor): Tensor;
+ }
+ declare class LayerNorm extends Module {
+ readonly d: number;
+ readonly eps: number;
+ g: Tensor;
+ b: Tensor;
+ constructor(d: number, eps?: number);
+ fwd(x: Tensor): Tensor;
+ }
+ /** [..., T, D] → [..., H, T, D/H]. Folds the standard
+ * `transpose(reshape(x, [..., T, H, d]), [..., H, T, d])` pattern into one
+ * call. Last dim of `x` must divide evenly by `nHeads`. */
+ declare function splitHeads(x: Tensor, nHeads: number): Tensor;
+ /** Inverse of `splitHeads`: [..., H, T, d] → [..., T, H*d]. */
+ declare function mergeHeads(x: Tensor): Tensor;
+ /** Slice a captured tensor named `name` into one Float32Array per head, using
+ * the static shape registered at compile time. The leading axis is treated as
+ * heads (matching `splitHeads` layout at B=1); a leading singleton batch is
+ * stripped if present so callers can pass capture names directly. Throws if
+ * the capture isn't registered or wasn't read back this call. */
+ declare function unsplitHeads(captures: Captures, name: string): Float32Array[];
+ /** Per-position cross-entropy along the last (vocab) axis: returns
+ * `-log p(target)` at each position. `logits` is `[..., V]`; `targets` is
+ * `[...]` of i32; result is `[...]` (one rank less than logits). The user
+ * applies their own masking + reduction downstream — useful when only some
+ * positions contribute (e.g. result-digit masking) or for label smoothing. */
+ declare function crossEntropyLast(logits: Tensor, targets: Tensor): Tensor;
+
+ type nn_d_LayerNorm = LayerNorm;
+ declare const nn_d_LayerNorm: typeof LayerNorm;
+ type nn_d_Linear = Linear;
+ declare const nn_d_Linear: typeof Linear;
+ type nn_d_LinearOptions = LinearOptions;
+ declare const nn_d_crossEntropyLast: typeof crossEntropyLast;
+ declare const nn_d_mergeHeads: typeof mergeHeads;
+ declare const nn_d_splitHeads: typeof splitHeads;
+ declare const nn_d_unsplitHeads: typeof unsplitHeads;
+ declare namespace nn_d {
+ export { nn_d_LayerNorm as LayerNorm, nn_d_Linear as Linear, nn_d_crossEntropyLast as crossEntropyLast, nn_d_mergeHeads as mergeHeads, nn_d_splitHeads as splitHeads, nn_d_unsplitHeads as unsplitHeads };
+ export type { nn_d_LinearOptions as LinearOptions };
+ }
+
+ export { Captures, Module, ShapeError, add, appendAdam, appendGrad, arange, capture, compile, compileForward, compileModule, compileToIR, createForwardRuntime, createRuntime, div, embedding, emitKernels, exp, greater, less, log, logSoftmaxLast, materializeParams, matmul, matmulBatched, meanLast, mul, nn_d as nn, oneHot, paramInput, planBuffers, relu, reshape, rsqrt, sliceLastRange, softmaxCausalLast, sqrt, stateInput, sub, sumAll, sumLast, swapAxes, tensorInput, trace, traceInto, transpose, where, whereCausal };
+ export type { AdamConfig, AdamResult, BufferPlan, BufferSpec, CallSite, CompileForwardMethodOptions, CompileForwardOptions, CompileModuleOptions, CompiledForward, CompiledForwardModule, CompiledIR, CompiledModule, CompiledRuntime, Dtype, ForwardFn, GradResult, Graph, InitSpec, InputDecl, InputDecls, InputsTensors, KernelSpec, MaterializedParams, OpNode, ParamOptions, RunOptions, RunResult, RuntimeOpts, Shape, StepResult, Tensor, Writeback, WritebackDecl };
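Usage sketch. The rolled-up `.d.ts` above documents `compileModule` only through its JSDoc; the following is a minimal, hedged sketch of how the declared surface fits together. The model (`TinyMLP`), shapes, and hyperparameters are illustrative and not from the package; only the imported names and signatures come from the declarations above.

```ts
// Minimal training sketch against the 0.0.13 flat exports (model and data are hypothetical).
import { Module, nn, relu, sub, mul, sumAll, capture, compileModule, type Tensor } from 'tensorgrad';

class TinyMLP extends Module {
  hidden = new nn.Linear(4, 16);   // nn.Linear declares its own params internally
  head = new nn.Linear(16, 1);
}

const compiled = await compileModule(
  () => new TinyMLP(),             // pass a factory, not an instance (the tree is consumed)
  (m, { x, y }) => {
    // Register an activation for optional readback via { withCaptures: true }.
    const h = capture('hidden', relu(m.hidden.fwd(x)));
    const err = sub(m.head.fwd(h), y);
    return sumAll(mul(err, err));  // training forward returns a scalar loss tensor
  },
  {
    inputs: {
      x: { shape: [8, 4] },        // dtype defaults to 'f32'
      y: { shape: [8, 1] },
    },
    // AdamW: weightDecay > 0; per-param decay flags default from each param's init kind.
    adam: { lr: 1e-3, weightDecay: 0.01 },
  },
);

// Initial params are uploaded automatically; step() injects the bias-corrected lrt scalar.
const loss = await compiled.step({ x: new Float32Array(8 * 4), y: new Float32Array(8 * 1) });
```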
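The capture readback and the sibling forward-only compile can then build on the same runtime. Again a hedged sketch under the assumptions above (the `compiled` object and the `'hidden'` capture come from the previous snippet; batch sizes are illustrative):

```ts
// Read back the captured activation on a training step.
const { loss: l2, captures } = await compiled.step(
  { x: new Float32Array(8 * 4), y: new Float32Array(8 * 1) },
  { withCaptures: true },
);
console.log(l2, captures.shapeOf('hidden'), captures.get('hidden').length);

// Sibling forward-only graph (e.g. batch-1 inference) sharing this runtime's
// device and param buffers; every train step is immediately visible to run().
const infer = await compiled.compileForward(
  (m, { x }) => m.head.fwd(relu(m.hidden.fwd(x))),
  { inputs: { x: { shape: [1, 4] } } },
);
const out = await infer.run({ x: new Float32Array(4) });  // full output as a Float32Array
```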