tensorgrad 0.0.15 → 0.0.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -311,12 +311,62 @@ interface WritebackDecl {
  */
  declare function planBuffers(graph: Graph, paramGrads: Record<string, Tensor>, writebackDecls?: WritebackDecl[]): BufferPlan;
 
+ /** Per-step learning-rate schedule. Either a fixed number or one of the
+ * serializable shape forms below. Functions/closures are not supported —
+ * the schedule needs to cross thread boundaries and survive serialization
+ * for the worker-internal runtime, and every realistic LR pattern (constant,
+ * linear decay, cosine, warmup-then-decay) maps to a finite set of shapes.
+ * Use the `lr` helper namespace to construct shapes ergonomically. */
+ type LRSchedule = number | {
+ readonly kind: 'constant';
+ readonly value: number;
+ } | {
+ readonly kind: 'linearDecay';
+ readonly peak: number;
+ readonly final: number;
+ readonly steps: number;
+ } | {
+ readonly kind: 'cosineDecay';
+ readonly peak: number;
+ readonly final: number;
+ readonly steps: number;
+ } | {
+ readonly kind: 'warmup';
+ readonly peakLr: number;
+ readonly warmupSteps: number;
+ readonly after: LRSchedule;
+ };
+ /** Ergonomic constructors for LRSchedule shapes. */
+ declare const lr: {
+ constant: (value: number) => LRSchedule;
+ /** Linearly interpolate from `peak` at step 1 to `final` at step `steps`,
+ * then hold at `final`. Matches `peak + (final - peak) * min(step/steps, 1)`. */
+ linearDecay: (opts: {
+ peak: number;
+ final: number;
+ steps: number;
+ }) => LRSchedule;
+ /** Half-cosine from `peak` at step 1 down to `final` at step `steps`,
+ * then hold at `final`. */
+ cosineDecay: (opts: {
+ peak: number;
+ final: number;
+ steps: number;
+ }) => LRSchedule;
+ /** Linear ramp from 0 to `peakLr` over `warmupSteps` steps, then hand off
+ * to `after` (offset so step 1 of `after` = first post-warmup step). */
+ warmup: (opts: {
+ peakLr: number;
+ warmupSteps: number;
+ after: LRSchedule;
+ }) => LRSchedule;
+ };
+ /** Resolve a schedule to its scalar value at a given 1-based step. */
+ declare function resolveLR(schedule: LRSchedule, step: number): number;
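A sketch of how these schedule shapes compose, assuming the semantics documented in the JSDoc above; the module specifier and all numeric values are illustrative, not defaults from the package.

```ts
import { lr, resolveLR, type LRSchedule } from 'tensorgrad';

// Warmup into a cosine decay: ramp 0 -> 3e-3 over 100 steps, then decay to 3e-4 by step 1600.
const schedule: LRSchedule = lr.warmup({
  peakLr: 3e-3,
  warmupSteps: 100,
  after: lr.cosineDecay({ peak: 3e-3, final: 3e-4, steps: 1500 }),
});

resolveLR(schedule, 50);   // mid-warmup: roughly half of peakLr
resolveLR(schedule, 100);  // end of warmup: about peakLr
resolveLR(schedule, 1600); // cosine finished: final value (3e-4)
```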
  interface AdamConfig {
- /** Constant scalar (e.g., `0.005`) or a per-step schedule function
- * `(step) => lr`. Schedule fn lets the user implement linear/cosine decay
- * or warmup; first call passes `step=1`. Decay-shrink (AdamW) updates
- * per-step automatically when this is a function. */
- lr: number | ((step: number) => number);
+ /** Learning rate schedule. Pass a number for fixed lr, or a shape from
+ * the `lr` helpers (e.g., `lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 })`). */
+ lr: LRSchedule;
  b1?: number;
  b2?: number;
  eps?: number;
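Against the fields above, a hypothetical AdamConfig might look like the following; every value is made up for illustration, and the `decayFilter` mirrors the example from the JSDoc.

```ts
import { lr, type AdamConfig } from 'tensorgrad';

// Illustrative hyperparameters only, not defaults shipped by the package.
const adam: AdamConfig = {
  lr: lr.linearDecay({ peak: 5e-3, final: 5e-4, steps: 1500 }),
  b1: 0.9,
  b2: 0.999,
  weightDecay: 0.01,
  decayFilter: (name) => name.includes('.W') || name.endsWith('_emb'),
};
```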
@@ -330,16 +380,17 @@ interface AdamConfig {
  * Example: `(name) => name.includes('.W') || name.endsWith('_emb')`. */
  decayFilter?: (paramName: string) => boolean;
  }
- /** Resolved hyperparameters: lr is the schedule fn (constants are wrapped). */
+ /** Resolved hyperparameters with all fields populated. `lr` stays as the
+ * shape (not pre-resolved) so the runtime can compute per-step values. */
  interface AdamResolvedConfig {
- lr: (step: number) => number;
+ lr: LRSchedule;
  b1: number;
  b2: number;
  eps: number;
  weightDecay: number;
  decayFilter: (name: string) => boolean;
- /** True iff the user supplied an lr function (vs a constant). When false,
- * decayShrink is baked at compile time and never updated. */
+ /** True iff the lr shape varies with step (linearDecay, cosineDecay,
+ * warmup). When false, decayShrink is baked at compile time. */
  lrIsScheduled: boolean;
  }
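The `lrIsScheduled` flag above can be read as a predicate over the schedule shape. A minimal sketch of that distinction, as an assumption about intent rather than the package's actual code:

```ts
import type { LRSchedule } from 'tensorgrad';

// True when the schedule's value changes across steps, i.e. anything other
// than a raw number or an explicit 'constant' shape.
function lrVariesWithStep(s: LRSchedule): boolean {
  return typeof s !== 'number' && s.kind !== 'constant';
}
```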
  interface AdamResult {
@@ -431,124 +482,52 @@ interface RunOptions {
  * `.get` throws); when true, captures are read back and accessible. */
  withCaptures?: boolean;
  }
- interface StepOptions extends RunOptions {
- /** If false, the training submit is queued but the JS thread does not
- * await `mapAsync` of the loss buffer. Returns `void` immediately.
- * Use `runtime.readLoss()` to read the latest loss explicitly when
- * you want it (e.g., every Nth step for UI display).
- *
- * Why: each `mapAsync` round-trip is ~1 ms on desktop but 10–30 ms on
- * Android Chrome. A training loop that awaits per step pays N × that
- * on the main thread, which on mobile starves the OS compositor and
- * causes visible UI sluggishness. With `readLoss: false` plus a
- * `requestAnimationFrame` yield between steps, the main thread stays
- * responsive while training runs at GPU speed.
- *
- * Implies `withCaptures: false`. Default: true. */
- readLoss?: boolean;
- }
- /** Common surface for both training and forward-only compiled runtimes. */
- interface CompiledBase {
- /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
- * share the device, or use directly for other GPU work. */
- device: GPUDevice;
- /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
- * `sharedParams` to share without copies. */
- params: Map<string, GPUBuffer>;
- /** Shape of the graph's output (loss scalar `[]` for training; the user's
- * returned tensor for forward-only compiles). */
- outputShape: number[];
- /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
- * *all* params to be present; throws on any unknown or missing key. Pass
- * `{ partial: true }` to skip the missing-key check. */
- uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void;
- /** Read all parameters back as Float32Arrays — used for UI panels. */
- downloadParams(): Promise<Record<string, Float32Array>>;
- /** Free GPU resources. */
- destroy(): void;
- }
- /** Run a dispatch and read back the full output tensor. Default returns the
- * output as a `Float32Array`; with `{ withCaptures: true }` returns
- * `{ output, captures }`. Same shape as `step()`'s overloads. */
- interface RunFn {
- (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
- (inputs: Record<string, Int32Array | Float32Array>, opts: {
- withCaptures: true;
- }): Promise<RunResult>;
- (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>;
- }
- interface CompiledRuntime extends CompiledBase {
- /** Read all parameter gradients back. Mostly for verification / debugging. */
- downloadParamGrads(): Promise<Record<string, Float32Array>>;
- /**
- * One full forward+backward step.
- * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
- * 2. Dispatches every kernel in order.
- * 3. Reads back the loss scalar (and any registered captures, if requested).
- * Default returns the loss as a JS number; with `{ withCaptures: true }`
- * returns `{ loss, captures }`.
- */
- step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
- step(inputs: Record<string, Int32Array | Float32Array>, opts: {
- withCaptures: true;
- }): Promise<StepResult>;
- step(inputs: Record<string, Int32Array | Float32Array>, opts: {
- readLoss: false;
- }): Promise<void>;
- step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepResult | void>;
- /** Same dispatch as step() but returns the full output Float32Array — for
- * training graphs the output is a scalar loss, so step() is usually more
- * convenient. Provided for parity with `compileForward`. */
- run: RunFn;
- /** Read the latest loss value from the GPU. Pair with `step({ readLoss: false })`
- * fire-and-forget training: every Nth iteration, call `readLoss()` for the
- * UI, but most iterations don't pay the `mapAsync` cost. */
- readLoss(): Promise<number>;
- /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
- * `uploadInitialParams()` for a full training reset without recompile. */
- resetOptimizerState(): void;
- }
- /** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
- * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
- interface CompiledForward extends CompiledBase {
- run: RunFn;
- }
- interface RuntimeOpts {
- /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
- device?: GPUDevice;
- /** External param buffers to bind in place of allocating fresh ones, keyed
- * by param name. Used to share params between a training compile and a
- * sibling forward-only compile (e.g., a B=1 inference graph). When a name
- * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
- * allocates as usual. */
- sharedParams?: Map<string, GPUBuffer>;
- }
- declare function createRuntime(plan: BufferPlan, kernels: KernelSpec[], lossBufferId: number, opts?: RuntimeOpts): Promise<CompiledRuntime>;
- /** Same machinery as `createRuntime`, narrower public type: a forward-only
- * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
- * loss readback). The full runtime object is built once and projected by
- * `compileForward` to the public shape. */
- declare function createForwardRuntime(plan: BufferPlan, kernels: KernelSpec[], outputBufferId: number, opts?: RuntimeOpts): Promise<CompiledForward>;
 
- /** How a parameter's initial values are produced.
- * - `'randn'` Gaussian, with `scale` (default 0.02). The common case for
- * weight matrices and embeddings.
- * - `'zeros'` — fill with 0. Common for biases and LayerNorm beta.
- * - `'ones'` — fill with 1. Common for LayerNorm gain.
- * - Custom function receives total element count and shape, returns the
- * Float32Array. Use for fan-in scaling or any non-standard scheme.
+ /** How a parameter's initial values are produced. Serializable shape — no
+ * closures, since the initial values cross the worker boundary at compile
+ * time. Use the `init` helpers for ergonomic construction.
+ *
+ * String shorthands:
+ * - `'randn'` — Gaussian with std 0.02 (the common weight-matrix init).
+ * - `'zeros'` — fill with 0 (biases, LayerNorm beta).
+ * - `'ones'` — fill with 1 (LayerNorm gain).
+ *
+ * Object shapes:
+ * - `{ kind: 'randn', scale }` — randn with explicit std.
+ * - `{ kind: 'kaiming', gain? }` — `std = gain / sqrt(fan_in)`. Default
+ * gain `sqrt(2)` (good for ReLU). `fan_in = shape[0]`.
+ * - `{ kind: 'literal', data }` — explicit Float32Array; length must
+ * match the parameter's element count.
  */
- type InitSpec = 'randn' | 'zeros' | 'ones' | ((size: number, shape: readonly number[]) => Float32Array);
+ type InitSpec = 'randn' | 'zeros' | 'ones' | {
+ readonly kind: 'randn';
+ readonly scale: number;
+ } | {
+ readonly kind: 'kaiming';
+ readonly gain?: number;
+ } | {
+ readonly kind: 'literal';
+ readonly data: Float32Array;
+ };
+ /** Ergonomic constructors for InitSpec object shapes. */
+ declare const init: {
+ randn: (opts?: {
+ scale?: number;
+ }) => InitSpec;
+ kaiming: (opts?: {
+ gain?: number;
+ }) => InitSpec;
+ literal: (data: Float32Array) => InitSpec;
+ };
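A few InitSpec values built with the `init` helpers and the string shorthands above; the variable names are illustrative, and the surrounding parameter-declaration API is not shown here.

```ts
import { init, type InitSpec } from 'tensorgrad';

const weightInit = init.kaiming();                    // std = sqrt(2) / sqrt(fan_in) by default
const embInit = init.randn({ scale: 0.01 });          // Gaussian with explicit std
const biasInit: InitSpec = 'zeros';                   // string shorthand
const lutInit = init.literal(new Float32Array(8));    // length must match the param's element count
```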
  interface ParamOptions {
  dtype?: Dtype;
- /** Init kind. Default: `'randn'`. */
+ /** Init shape. Default: `'randn'` (std 0.02). */
  init?: InitSpec;
- /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
- scale?: number;
  /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
- * decay to this param. Default: `true` for `'randn'` init (weight matrices,
- * embeddings), `false` for `'zeros'` / `'ones'` (biases, LN gains). Override
- * to force or skip. Replaces `adam.decayFilter` for the common case. */
+ * decay to this param. Default: `true` for randn/kaiming/literal init
+ * (weight matrices, embeddings); `false` for zeros/ones (biases, LN
+ * gains). Override to force or skip. Replaces `adam.decayFilter` for
+ * the common case. */
  decay?: boolean;
  }
  type InitFn = (size: number, shape: readonly number[]) => Float32Array;
@@ -590,21 +569,14 @@ interface InputDecl {
  shape: Shape;
  dtype?: Dtype;
  }
- /** Inputs declaration: a Record from input name to its shape/dtype. The name
- * doubles as the key the forward fn destructures and the key the runtime
- * expects in `step({...})` / `run({...})`. */
+ /** Inputs declaration: a Record from input name to its shape/dtype. */
  type InputDecls = Record<string, InputDecl>;
  /** Maps an `InputDecls` Record to its forward-time tensor counterpart —
- * same keys, each value is a Tensor. Used to type the forward function's
- * `inputs` argument from the declared shape Record. */
+ * same keys, each value is a Tensor. */
  type InputsTensors<I extends InputDecls> = {
  [K in keyof I]: Tensor;
  };
- /** Forward function shape: takes the materialized model and a Record of
- * named input tensors (matching the declared `inputs:` keys), returns the
- * output tensor (loss for compileModule; logits/etc. for compileForward).
- * The second generic flows from the inputs declaration so destructuring
- * the input record stays typed. */
+ /** Forward function shape. */
  type ForwardFn<M extends Module, I extends InputDecls = InputDecls> = (m: M, inputs: InputsTensors<I>) => Tensor;
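A minimal sketch of how an inputs declaration flows into the forward function's typing; `MyModel` and its `loss` method are placeholders, not part of the package.

```ts
import type { ForwardFn, InputDecls, Module, Tensor } from 'tensorgrad';

// Placeholder model type: your real model is a Module subclass with its own methods.
interface MyModel extends Module {
  loss(tokens: Tensor, targets: Tensor): Tensor;
}

const inputs = {
  tokens: { shape: [8, 64], dtype: 'i32' },
  targets: { shape: [8, 64], dtype: 'i32' },
} satisfies InputDecls;

// `tokens` / `targets` destructure under the same keys declared above.
const forward: ForwardFn<MyModel, typeof inputs> = (m, { tokens, targets }) =>
  m.loss(tokens, targets);
```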
  interface CompiledIR {
  graph: GradResult['graph'];
@@ -615,59 +587,67 @@ interface CompiledIR {
  }
  /** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
  declare function compileToIR(traceFn: () => Tensor): CompiledIR;
- /** Full compile pipeline. Browser-only because it creates a GPUDevice. */
- declare function compile(traceFn: () => Tensor, opts?: RuntimeOpts): Promise<CompiledRuntime & {
- ir: CompiledIR;
- }>;
- interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
- /** Per-step data inputs to the forward function, keyed by name. The forward
- * fn destructures these out of its second argument; runtime calls to
- * `step()` / `run()` pass typed arrays under the same keys. */
+ interface CompileModuleOptions<I extends InputDecls = InputDecls> {
  inputs?: I;
- /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
  adam?: AdamConfig;
  }
- interface CompileForwardOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
- /** Per-step data inputs to the forward function, keyed by name. */
+ interface CompileForwardOptions<I extends InputDecls = InputDecls> {
  inputs?: I;
  }
- /** Forward-only compile options as taken by the `compileForward` *method* on
- * a training runtime — no `device` (inherited) and no `sharedParams`
- * (auto-supplied from the train graph's params). */
  interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
  inputs?: I;
  }
- /** Returned by `compileModule`. Adds training-graph extras (auto-init, reset,
- * sibling-graph compile) on top of the base runtime. */
- interface CompiledModule<M extends Module> extends CompiledRuntime {
- ir: CompiledIR;
- /** Number of dispatchable kernels (excludes leaf no-ops). */
- kernelCount: number;
- /** Re-initialize all params from their declared init specs and zero the
- * optimizer state. Use to start training over without recompiling. */
- reset(): void;
- /** Compile a sibling forward-only graph (e.g., a B=1 inference graph or a
- * B=N held-out eval graph) that shares this runtime's device and param
- * buffers. Pass the forward fn (typically distinct from your loss fn —
- * it returns logits, not a scalar) and any shape changes via `inputs`.
- * Auto-initialization is a no-op since params are shared. */
+ /** Returned by `compileModule`. Proxies all GPU work to a worker held
+ * internally; user code awaits Promises and never sees the worker. */
+ interface CompiledModule<M extends Module> {
+ readonly ir: CompiledIR;
+ readonly kernelCount: number;
+ readonly outputShape: readonly number[];
+ /** Names of the model's parameters, in materialization order. The actual
+ * GPUBuffers live in the worker; use `downloadParams()` for values. */
+ readonly paramNames: readonly string[];
+ step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
+ step(inputs: Record<string, Int32Array | Float32Array>, opts: {
+ withCaptures: true;
+ }): Promise<StepResult>;
+ run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
+ run(inputs: Record<string, Int32Array | Float32Array>, opts: {
+ withCaptures: true;
+ }): Promise<RunResult>;
+ uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>;
+ downloadParams(): Promise<Record<string, Float32Array>>;
+ downloadParamGrads(): Promise<Record<string, Float32Array>>;
+ /** Re-initialize all params + zero optimizer state. */
+ reset(): Promise<void>;
+ resetOptimizerState(): Promise<void>;
+ /** Compile a sibling forward-only graph that shares this runtime's worker
+ * (and therefore its param GPUBuffers). */
  compileForward<I extends InputDecls>(forward: ForwardFn<M, I>, opts?: CompileForwardMethodOptions<I>): Promise<CompiledForwardModule>;
+ /** Free the runtime's GPU resources and terminate the worker. */
+ destroy(): void;
  }
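A sketch of the sibling compile from an existing CompiledModule; `model` and `genForward` are placeholders standing in for a compiled training runtime and a logits-returning forward function, and the B=1 shape is made up.

```ts
import type { CompiledModule, ForwardFn, Module } from 'tensorgrad';

// Placeholders: a trained module handle and a forward fn that returns logits.
declare const model: CompiledModule<Module>;
declare const genForward: ForwardFn<Module>;

// Shares the training worker's param buffers, so inference sees every train step without copies.
const infer = await model.compileForward(genForward, {
  inputs: { tokens: { shape: [1, 64], dtype: 'i32' } },
});
const logits = await infer.run({ tokens: new Int32Array(64) }); // Float32Array of infer.outputShape
```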
  /** Returned by `compileForward` (and by the `compileForward` method). */
- interface CompiledForwardModule extends CompiledForward {
- ir: CompiledIR;
- /** Number of dispatchable kernels (excludes leaf no-ops). */
- kernelCount: number;
+ interface CompiledForwardModule {
+ readonly ir: CompiledIR;
+ readonly kernelCount: number;
+ readonly outputShape: readonly number[];
+ readonly paramNames: readonly string[];
+ run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
+ run(inputs: Record<string, Int32Array | Float32Array>, opts: {
+ withCaptures: true;
+ }): Promise<RunResult>;
+ uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>;
+ downloadParams(): Promise<Record<string, Float32Array>>;
+ destroy(): void;
  }
  /**
  * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
  * model instance itself: compilation mutates the tree (every `ParamSentinel`
  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
- * referenced afterwards. Re-call the factory if you need a fresh tree.
+ * referenced afterwards.
  *
  * The forward function takes the materialized model and a Record of named
- * input tensors, returns the loss tensor. Inputs are matched by name with the
- * `inputs:` declaration:
+ * input tensors, returns the loss tensor:
  *
  * inputs: {
  * tokens: { shape: [B, T], dtype: 'i32' },
@@ -675,34 +655,15 @@ interface CompiledForwardModule extends CompiledForward {
  * }
  * forward: (m, { tokens, targets }) => …
  *
- * Walks the module tree to materialize params with auto-derived names, then
- * runs trace → grad → adam → buffer plan → codegen → runtime. Initial
- * parameter values are uploaded automatically before this function returns;
- * call `reset()` later to re-randomize.
- *
- * If `opts.adam` is set, the runtime's `step()` automatically tracks an
- * internal step count and injects the bias-corrected `lrt` scalar each call;
- * users don't need to provide it themselves.
+ * Returns a `CompiledModule` proxy. All GPU work (createRuntime, step, run,
+ * mapAsync) happens in an internal worker; calls return Promises that resolve
+ * when the worker replies.
  */
  declare function compileModule<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileModuleOptions<I>): Promise<CompiledModule<M>>;
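Putting the pieces together, a hedged end-to-end sketch of the factory/forward/options wiring; `TinyLM`, `nextBatch`, the shapes, and the hyperparameters are all made up for illustration.

```ts
import { compileModule, Module } from 'tensorgrad';
import type { Tensor } from 'tensorgrad';

// Placeholders for user code: a Module subclass and a batch source.
declare class TinyLM extends Module {
  constructor();
  loss(tokens: Tensor, targets: Tensor): Tensor;
}
declare function nextBatch(): { tokens: Int32Array; targets: Int32Array };

const model = await compileModule(
  () => new TinyLM(),                                  // factory: the instance is consumed by compilation
  (m, { tokens, targets }) => m.loss(tokens, targets), // returns the scalar loss Tensor
  {
    inputs: {
      tokens: { shape: [8, 64], dtype: 'i32' },
      targets: { shape: [8, 64], dtype: 'i32' },
    },
    adam: { lr: 3e-3, weightDecay: 0.01 },             // fixed lr shown here; lr.* schedule shapes also work
  },
);

for (let step = 1; step <= 1000; step++) {
  const loss = await model.step(nextBatch());          // resolves when the worker replies
  if (step % 100 === 0) console.log(step, loss);
}
model.destroy();                                       // frees GPU resources and terminates the worker
```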
  /**
- * Compile a Module-based model in forward-only mode (no autograd, no Adam).
- * The forward function returns the output tensor (e.g., logits) instead of a
- * scalar loss; runtime exposes `run(inputs)` returning the full output as a
- * `Float32Array`.
- *
- * **Prefer the `compileForward` method on a training runtime** when both
- * graphs use the same Module class — it auto-supplies `device` and
- * `sharedParams`. This standalone form is for forward-only models with no
- * training graph at all, or for sharing params across a different model.
- *
- * **Sharing params with a training compile.** Pass `opts.sharedParams =
- * trainCompiled.params` to bind this graph's param buffers to an existing
- * training runtime's GPU buffers — every train step is then immediately
- * visible to `run()` calls here, no copies.
- *
- * Initial param values are uploaded automatically for params *not* covered
- * by `sharedParams` (those are owned by the sibling compile).
+ * Forward-only compile. Spawns its own worker. For sibling graphs that share
+ * params with a training graph, prefer the `compileForward` method on the
+ * CompiledModule returned by `compileModule()`.
  */
  declare function compileForward<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileForwardOptions<I>): Promise<CompiledForwardModule>;
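And the standalone form, which owns its own worker and parameters; again the model class and shapes are placeholders, not part of the package.

```ts
import { compileForward, Module } from 'tensorgrad';
import type { Tensor } from 'tensorgrad';

// Placeholder forward-only model.
declare class Encoder extends Module {
  constructor();
  logits(tokens: Tensor): Tensor;
}

const fw = await compileForward(() => new Encoder(), (m, { tokens }) => m.logits(tokens), {
  inputs: { tokens: { shape: [1, 64], dtype: 'i32' } },
});
const out = await fw.run({ tokens: new Int32Array(64) }); // Float32Array of fw.outputShape
fw.destroy();
```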
 
@@ -759,5 +720,5 @@ declare namespace nn_d {
  export type { nn_d_LinearOptions as LinearOptions };
  }
 
- export { Captures, Module, ShapeError, add, appendAdam, appendGrad, arange, capture, compile, compileForward, compileModule, compileToIR, createForwardRuntime, createRuntime, div, embedding, emitKernels, exp, greater, less, log, logSoftmaxLast, materializeParams, matmul, matmulBatched, meanLast, mul, nn_d as nn, oneHot, paramInput, planBuffers, relu, reshape, rsqrt, sliceLastRange, softmaxCausalLast, sqrt, stateInput, sub, sumAll, sumLast, swapAxes, tensorInput, trace, traceInto, transpose, where, whereCausal };
- export type { AdamConfig, AdamResult, BufferPlan, BufferSpec, CallSite, CompileForwardMethodOptions, CompileForwardOptions, CompileModuleOptions, CompiledForward, CompiledForwardModule, CompiledIR, CompiledModule, CompiledRuntime, Dtype, ForwardFn, GradResult, Graph, InitSpec, InputDecl, InputDecls, InputsTensors, KernelSpec, MaterializedParams, OpNode, ParamOptions, RunOptions, RunResult, RuntimeOpts, Shape, StepResult, Tensor, Writeback, WritebackDecl };
+ export { Captures, Module, ShapeError, add, appendAdam, appendGrad, arange, capture, compileForward, compileModule, compileToIR, div, embedding, emitKernels, exp, greater, init, less, log, logSoftmaxLast, lr, materializeParams, matmul, matmulBatched, meanLast, mul, nn_d as nn, oneHot, paramInput, planBuffers, relu, reshape, resolveLR, rsqrt, sliceLastRange, softmaxCausalLast, sqrt, stateInput, sub, sumAll, sumLast, swapAxes, tensorInput, trace, traceInto, transpose, where, whereCausal };
+ export type { AdamConfig, AdamResult, BufferPlan, BufferSpec, CallSite, CompileForwardMethodOptions, CompileForwardOptions, CompileModuleOptions, CompiledForwardModule, CompiledIR, CompiledModule, Dtype, ForwardFn, GradResult, Graph, InitSpec, InputDecl, InputDecls, InputsTensors, KernelSpec, LRSchedule, MaterializedParams, OpNode, ParamOptions, RunOptions, RunResult, Shape, StepResult, Tensor, UploadParamsOptions, Writeback, WritebackDecl };