tensorgrad 0.0.15 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +154 -193
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +553 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -156
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -184
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -402
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/dist/index.d.ts
CHANGED
@@ -311,12 +311,62 @@ interface WritebackDecl {
  */
 declare function planBuffers(graph: Graph, paramGrads: Record<string, Tensor>, writebackDecls?: WritebackDecl[]): BufferPlan;

+/** Per-step learning-rate schedule. Either a fixed number or one of the
+ * serializable shape forms below. Functions/closures are not supported —
+ * the schedule needs to cross thread boundaries and survive serialization
+ * for the worker-internal runtime, and every realistic LR pattern (constant,
+ * linear decay, cosine, warmup-then-decay) maps to a finite set of shapes.
+ * Use the `lr` helper namespace to construct shapes ergonomically. */
+type LRSchedule = number | {
+    readonly kind: 'constant';
+    readonly value: number;
+} | {
+    readonly kind: 'linearDecay';
+    readonly peak: number;
+    readonly final: number;
+    readonly steps: number;
+} | {
+    readonly kind: 'cosineDecay';
+    readonly peak: number;
+    readonly final: number;
+    readonly steps: number;
+} | {
+    readonly kind: 'warmup';
+    readonly peakLr: number;
+    readonly warmupSteps: number;
+    readonly after: LRSchedule;
+};
+/** Ergonomic constructors for LRSchedule shapes. */
+declare const lr: {
+    constant: (value: number) => LRSchedule;
+    /** Linearly interpolate from `peak` at step 1 to `final` at step `steps`,
+     * then hold at `final`. Matches `peak + (final - peak) * min(step/steps, 1)`. */
+    linearDecay: (opts: {
+        peak: number;
+        final: number;
+        steps: number;
+    }) => LRSchedule;
+    /** Half-cosine from `peak` at step 1 down to `final` at step `steps`,
+     * then hold at `final`. */
+    cosineDecay: (opts: {
+        peak: number;
+        final: number;
+        steps: number;
+    }) => LRSchedule;
+    /** Linear ramp from 0 to `peakLr` over `warmupSteps` steps, then hand off
+     * to `after` (offset so step 1 of `after` = first post-warmup step). */
+    warmup: (opts: {
+        peakLr: number;
+        warmupSteps: number;
+        after: LRSchedule;
+    }) => LRSchedule;
+};
+/** Resolve a schedule to its scalar value at a given 1-based step. */
+declare function resolveLR(schedule: LRSchedule, step: number): number;
 interface AdamConfig {
-    /**
-     * `(
-
-     * per-step automatically when this is a function. */
-    lr: number | ((step: number) => number);
+    /** Learning rate schedule. Pass a number for fixed lr, or a shape from
+     * the `lr` helpers (e.g., `lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 })`). */
+    lr: LRSchedule;
     b1?: number;
     b2?: number;
     eps?: number;
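The schedule added in this hunk is plain data plus helper constructors, so it can serialize across the worker boundary. A minimal usage sketch, assuming the package root re-exports `lr`, `resolveLR`, and `LRSchedule` (the export hunk at the end of this file lists them); the commented values are approximate, per the documented semantics:

```ts
import { lr, resolveLR, type LRSchedule } from 'tensorgrad';

// 100 warmup steps ramping 0 -> 3e-4, then half-cosine decay 3e-4 -> 3e-5 over 1500 steps.
const schedule: LRSchedule = lr.warmup({
  peakLr: 3e-4,
  warmupSteps: 100,
  after: lr.cosineDecay({ peak: 3e-4, final: 3e-5, steps: 1500 }),
});

// resolveLR evaluates a schedule shape at a 1-based step.
resolveLR(schedule, 10);    // still ramping toward 3e-4
resolveLR(schedule, 100);   // ~3e-4, end of warmup
resolveLR(schedule, 1600);  // decaying toward 3e-5

// linearDecay follows the documented formula: peak + (final - peak) * min(step / steps, 1).
resolveLR(lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 }), 1500); // 0.0005
```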
@@ -330,16 +380,17 @@ interface AdamConfig {
      * Example: `(name) => name.includes('.W') || name.endsWith('_emb')`. */
     decayFilter?: (paramName: string) => boolean;
 }
-/** Resolved hyperparameters
+/** Resolved hyperparameters with all fields populated. `lr` stays as the
+ * shape (not pre-resolved) so the runtime can compute per-step values. */
 interface AdamResolvedConfig {
-    lr:
+    lr: LRSchedule;
     b1: number;
     b2: number;
     eps: number;
     weightDecay: number;
     decayFilter: (name: string) => boolean;
-    /** True iff the
-     * decayShrink is baked at compile time
+    /** True iff the lr shape varies with step (linearDecay, cosineDecay,
+     * warmup). When false, decayShrink is baked at compile time. */
     lrIsScheduled: boolean;
 }
 interface AdamResult {
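For callers, the practical change is that the `adam` block now takes a schedule shape instead of a closure. A hedged sketch of the new `AdamConfig` (the `weightDecay` field is assumed to be accepted on the config, since the resolved config above carries it):

```ts
import { lr, type AdamConfig } from 'tensorgrad';

const adam: AdamConfig = {
  // 0.0.15 accepted `lr: (step) => ...`; closures cannot cross the worker boundary,
  // so 0.0.16 takes a serializable shape instead.
  lr: lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 }),
  b1: 0.9,
  b2: 0.999,
  weightDecay: 0.01,
};
```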
@@ -431,124 +482,52 @@ interface RunOptions {
      * `.get` throws); when true, captures are read back and accessible. */
     withCaptures?: boolean;
 }
-interface StepOptions extends RunOptions {
-    /** If false, the training submit is queued but the JS thread does not
-     * await `mapAsync` of the loss buffer. Returns `void` immediately.
-     * Use `runtime.readLoss()` to read the latest loss explicitly when
-     * you want it (e.g., every Nth step for UI display).
-     *
-     * Why: each `mapAsync` round-trip is ~1 ms on desktop but 10–30 ms on
-     * Android Chrome. A training loop that awaits per step pays N × that
-     * on the main thread, which on mobile starves the OS compositor and
-     * causes visible UI sluggishness. With `readLoss: false` plus a
-     * `requestAnimationFrame` yield between steps, the main thread stays
-     * responsive while training runs at GPU speed.
-     *
-     * Implies `withCaptures: false`. Default: true. */
-    readLoss?: boolean;
-}
-/** Common surface for both training and forward-only compiled runtimes. */
-interface CompiledBase {
-    /** The GPUDevice this runtime is bound to. Pass to sibling compiles to
-     * share the device, or use directly for other GPU work. */
-    device: GPUDevice;
-    /** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
-     * `sharedParams` to share without copies. */
-    params: Map<string, GPUBuffer>;
-    /** Shape of the graph's output (loss scalar `[]` for training; the user's
-     * returned tensor for forward-only compiles). */
-    outputShape: number[];
-    /** Upload parameter Float32Arrays to their GPU buffers. By default, requires
-     * *all* params to be present; throws on any unknown or missing key. Pass
-     * `{ partial: true }` to skip the missing-key check. */
-    uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void;
-    /** Read all parameters back as Float32Arrays — used for UI panels. */
-    downloadParams(): Promise<Record<string, Float32Array>>;
-    /** Free GPU resources. */
-    destroy(): void;
-}
-/** Run a dispatch and read back the full output tensor. Default returns the
- * output as a `Float32Array`; with `{ withCaptures: true }` returns
- * `{ output, captures }`. Same shape as `step()`'s overloads. */
-interface RunFn {
-    (inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
-    (inputs: Record<string, Int32Array | Float32Array>, opts: {
-        withCaptures: true;
-    }): Promise<RunResult>;
-    (inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>;
-}
-interface CompiledRuntime extends CompiledBase {
-    /** Read all parameter gradients back. Mostly for verification / debugging. */
-    downloadParamGrads(): Promise<Record<string, Float32Array>>;
-    /**
-     * One full forward+backward step.
-     * 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
-     * 2. Dispatches every kernel in order.
-     * 3. Reads back the loss scalar (and any registered captures, if requested).
-     * Default returns the loss as a JS number; with `{ withCaptures: true }`
-     * returns `{ loss, captures }`.
-     */
-    step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
-    step(inputs: Record<string, Int32Array | Float32Array>, opts: {
-        withCaptures: true;
-    }): Promise<StepResult>;
-    step(inputs: Record<string, Int32Array | Float32Array>, opts: {
-        readLoss: false;
-    }): Promise<void>;
-    step(inputs: Record<string, Int32Array | Float32Array>, opts: StepOptions): Promise<number | StepResult | void>;
-    /** Same dispatch as step() but returns the full output Float32Array — for
-     * training graphs the output is a scalar loss, so step() is usually more
-     * convenient. Provided for parity with `compileForward`. */
-    run: RunFn;
-    /** Read the latest loss value from the GPU. Pair with `step({ readLoss: false })`
-     * fire-and-forget training: every Nth iteration, call `readLoss()` for the
-     * UI, but most iterations don't pay the `mapAsync` cost. */
-    readLoss(): Promise<number>;
-    /** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
-     * `uploadInitialParams()` for a full training reset without recompile. */
-    resetOptimizerState(): void;
-}
-/** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
- * no backward. Returns the output tensor (not just a scalar) per `run()` call. */
-interface CompiledForward extends CompiledBase {
-    run: RunFn;
-}
-interface RuntimeOpts {
-    /** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
-    device?: GPUDevice;
-    /** External param buffers to bind in place of allocating fresh ones, keyed
-     * by param name. Used to share params between a training compile and a
-     * sibling forward-only compile (e.g., a B=1 inference graph). When a name
-     * is in this map, the runtime reuses the provided GPUBuffer; otherwise it
-     * allocates as usual. */
-    sharedParams?: Map<string, GPUBuffer>;
-}
-declare function createRuntime(plan: BufferPlan, kernels: KernelSpec[], lossBufferId: number, opts?: RuntimeOpts): Promise<CompiledRuntime>;
-/** Same machinery as `createRuntime`, narrower public type: a forward-only
- * graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
- * loss readback). The full runtime object is built once and projected by
- * `compileForward` to the public shape. */
-declare function createForwardRuntime(plan: BufferPlan, kernels: KernelSpec[], outputBufferId: number, opts?: RuntimeOpts): Promise<CompiledForward>;

-/** How a parameter's initial values are produced.
- *
- *
- *
- *
- * -
- *
+/** How a parameter's initial values are produced. Serializable shape — no
+ * closures, since the initial values cross the worker boundary at compile
+ * time. Use the `init` helpers for ergonomic construction.
+ *
+ * String shorthands:
+ * - `'randn'` — Gaussian with std 0.02 (the common weight-matrix init).
+ * - `'zeros'` — fill with 0 (biases, LayerNorm beta).
+ * - `'ones'` — fill with 1 (LayerNorm gain).
+ *
+ * Object shapes:
+ * - `{ kind: 'randn', scale }` — randn with explicit std.
+ * - `{ kind: 'kaiming', gain? }` — `std = gain / sqrt(fan_in)`. Default
+ *   gain `sqrt(2)` (good for ReLU). `fan_in = shape[0]`.
+ * - `{ kind: 'literal', data }` — explicit Float32Array; length must
+ *   match the parameter's element count.
  */
-type InitSpec = 'randn' | 'zeros' | 'ones' |
+type InitSpec = 'randn' | 'zeros' | 'ones' | {
+    readonly kind: 'randn';
+    readonly scale: number;
+} | {
+    readonly kind: 'kaiming';
+    readonly gain?: number;
+} | {
+    readonly kind: 'literal';
+    readonly data: Float32Array;
+};
+/** Ergonomic constructors for InitSpec object shapes. */
+declare const init: {
+    randn: (opts?: {
+        scale?: number;
+    }) => InitSpec;
+    kaiming: (opts?: {
+        gain?: number;
+    }) => InitSpec;
+    literal: (data: Float32Array) => InitSpec;
+};
 interface ParamOptions {
     dtype?: Dtype;
-    /** Init
+    /** Init shape. Default: `'randn'` (std 0.02). */
     init?: InitSpec;
-    /** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
-    scale?: number;
     /** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
-     * decay to this param. Default: `true` for
-     * embeddings)
-     * to force or skip. Replaces `adam.decayFilter` for
+     * decay to this param. Default: `true` for randn/kaiming/literal init
+     * (weight matrices, embeddings); `false` for zeros/ones (biases, LN
+     * gains). Override to force or skip. Replaces `adam.decayFilter` for
+     * the common case. */
     decay?: boolean;
 }
 type InitFn = (size: number, shape: readonly number[]) => Float32Array;
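The init surface follows the same pattern: serializable shapes plus `init` helpers. A short sketch of the shorthands and helpers, with `ParamOptions` wiring shown for illustration (the `param(...)` call that consumes these options lives in `module.ts` and is not part of this hunk):

```ts
import { init, type InitSpec, type ParamOptions } from 'tensorgrad';

const lnGain: InitSpec = 'ones';                         // fill with 1 (LayerNorm gain)
const bias: InitSpec = 'zeros';                          // fill with 0
const wOut: InitSpec = init.randn({ scale: 0.01 });      // Gaussian with explicit std
const wIn: InitSpec = init.kaiming();                    // std = sqrt(2) / sqrt(fan_in)
const eye: InitSpec = init.literal(new Float32Array([1, 0, 0, 1])); // length must match element count

// The per-param `scale` option is gone in 0.0.16; fold it into `init.randn`.
// `decay` replaces `adam.decayFilter` for the common case.
const weightOpts: ParamOptions = { init: wIn, decay: true };
const biasOpts: ParamOptions = { init: bias, decay: false };
```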
@@ -590,21 +569,14 @@ interface InputDecl {
     shape: Shape;
     dtype?: Dtype;
 }
-/** Inputs declaration: a Record from input name to its shape/dtype.
- * doubles as the key the forward fn destructures and the key the runtime
- * expects in `step({...})` / `run({...})`. */
+/** Inputs declaration: a Record from input name to its shape/dtype. */
 type InputDecls = Record<string, InputDecl>;
 /** Maps an `InputDecls` Record to its forward-time tensor counterpart —
- * same keys, each value is a Tensor.
- * `inputs` argument from the declared shape Record. */
+ * same keys, each value is a Tensor. */
 type InputsTensors<I extends InputDecls> = {
     [K in keyof I]: Tensor;
 };
-/** Forward function shape
- * named input tensors (matching the declared `inputs:` keys), returns the
- * output tensor (loss for compileModule; logits/etc. for compileForward).
- * The second generic flows from the inputs declaration so destructuring
- * the input record stays typed. */
+/** Forward function shape. */
 type ForwardFn<M extends Module, I extends InputDecls = InputDecls> = (m: M, inputs: InputsTensors<I>) => Tensor;
 interface CompiledIR {
     graph: GradResult['graph'];
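A hedged sketch of how the slimmed-down `InputDecls` / `ForwardFn` pair fits together. `DemoModel` and `crossEntropy` are placeholders for a real Module subclass and loss helper, and the sketch assumes `Shape` accepts a plain number array, as the `compileModule` doc example below suggests:

```ts
import { Module } from 'tensorgrad';
import type { ForwardFn, InputDecls, Tensor } from 'tensorgrad';

// Placeholders for illustration only.
declare class DemoModel extends Module { forward(tokens: Tensor): Tensor; }
declare function crossEntropy(logits: Tensor, targets: Tensor): Tensor;

const B = 8, T = 64;

// Keys double as the names the forward fn destructures and the keys
// step()/run() expect for typed-array inputs.
const inputs = {
  tokens:  { shape: [B, T], dtype: 'i32' },
  targets: { shape: [B, T], dtype: 'i32' },
} satisfies InputDecls;

// The second generic flows from the declaration, so destructuring stays typed.
const forward: ForwardFn<DemoModel, typeof inputs> = (m, { tokens, targets }) =>
  crossEntropy(m.forward(tokens), targets);
```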
@@ -615,59 +587,67 @@ interface CompiledIR {
 }
 /** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
 declare function compileToIR(traceFn: () => Tensor): CompiledIR;
-
-declare function compile(traceFn: () => Tensor, opts?: RuntimeOpts): Promise<CompiledRuntime & {
-    ir: CompiledIR;
-}>;
-interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
-    /** Per-step data inputs to the forward function, keyed by name. The forward
-     * fn destructures these out of its second argument; runtime calls to
-     * `step()` / `run()` pass typed arrays under the same keys. */
+interface CompileModuleOptions<I extends InputDecls = InputDecls> {
     inputs?: I;
-    /** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
     adam?: AdamConfig;
 }
-interface CompileForwardOptions<I extends InputDecls = InputDecls>
-    /** Per-step data inputs to the forward function, keyed by name. */
+interface CompileForwardOptions<I extends InputDecls = InputDecls> {
     inputs?: I;
 }
-/** Forward-only compile options as taken by the `compileForward` *method* on
- * a training runtime — no `device` (inherited) and no `sharedParams`
- * (auto-supplied from the train graph's params). */
 interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
     inputs?: I;
 }
-/** Returned by `compileModule`.
- *
-interface CompiledModule<M extends Module>
-    ir: CompiledIR;
-
-
-    /**
-     *
-
-
-
-
-
-
+/** Returned by `compileModule`. Proxies all GPU work to a worker held
+ * internally; user code awaits Promises and never sees the worker. */
+interface CompiledModule<M extends Module> {
+    readonly ir: CompiledIR;
+    readonly kernelCount: number;
+    readonly outputShape: readonly number[];
+    /** Names of the model's parameters, in materialization order. The actual
+     * GPUBuffers live in the worker; use `downloadParams()` for values. */
+    readonly paramNames: readonly string[];
+    step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
+    step(inputs: Record<string, Int32Array | Float32Array>, opts: {
+        withCaptures: true;
+    }): Promise<StepResult>;
+    run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
+    run(inputs: Record<string, Int32Array | Float32Array>, opts: {
+        withCaptures: true;
+    }): Promise<RunResult>;
+    uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>;
+    downloadParams(): Promise<Record<string, Float32Array>>;
+    downloadParamGrads(): Promise<Record<string, Float32Array>>;
+    /** Re-initialize all params + zero optimizer state. */
+    reset(): Promise<void>;
+    resetOptimizerState(): Promise<void>;
+    /** Compile a sibling forward-only graph that shares this runtime's worker
+     * (and therefore its param GPUBuffers). */
     compileForward<I extends InputDecls>(forward: ForwardFn<M, I>, opts?: CompileForwardMethodOptions<I>): Promise<CompiledForwardModule>;
+    /** Free the runtime's GPU resources and terminate the worker. */
+    destroy(): void;
 }
 /** Returned by `compileForward` (and by the `compileForward` method). */
-interface CompiledForwardModule
-    ir: CompiledIR;
-
-
+interface CompiledForwardModule {
+    readonly ir: CompiledIR;
+    readonly kernelCount: number;
+    readonly outputShape: readonly number[];
+    readonly paramNames: readonly string[];
+    run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
+    run(inputs: Record<string, Int32Array | Float32Array>, opts: {
+        withCaptures: true;
+    }): Promise<RunResult>;
+    uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>;
+    downloadParams(): Promise<Record<string, Float32Array>>;
+    destroy(): void;
 }
 /**
  * Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
  * model instance itself: compilation mutates the tree (every `ParamSentinel`
  * field becomes a real `Tensor`), so the instance is consumed and shouldn't be
- * referenced afterwards.
+ * referenced afterwards.
  *
  * The forward function takes the materialized model and a Record of named
- * input tensors, returns the loss tensor
- * `inputs:` declaration:
+ * input tensors, returns the loss tensor:
  *
  *   inputs: {
 *     tokens: { shape: [B, T], dtype: 'i32' },
@@ -675,34 +655,15 @@ interface CompiledForwardModule extends CompiledForward {
 *   }
 *   forward: (m, { tokens, targets }) => …
 *
- *
- *
- *
- * call `reset()` later to re-randomize.
- *
- * If `opts.adam` is set, the runtime's `step()` automatically tracks an
- * internal step count and injects the bias-corrected `lrt` scalar each call;
- * users don't need to provide it themselves.
+ * Returns a `CompiledModule` proxy. All GPU work (createRuntime, step, run,
+ * mapAsync) happens in an internal worker; calls return Promises that resolve
+ * when the worker replies.
  */
 declare function compileModule<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileModuleOptions<I>): Promise<CompiledModule<M>>;
 /**
- *
- *
- *
- * `Float32Array`.
- *
- * **Prefer the `compileForward` method on a training runtime** when both
- * graphs use the same Module class — it auto-supplies `device` and
- * `sharedParams`. This standalone form is for forward-only models with no
- * training graph at all, or for sharing params across a different model.
- *
- * **Sharing params with a training compile.** Pass `opts.sharedParams =
- * trainCompiled.params` to bind this graph's param buffers to an existing
- * training runtime's GPU buffers — every train step is then immediately
- * visible to `run()` calls here, no copies.
- *
- * Initial param values are uploaded automatically for params *not* covered
- * by `sharedParams` (those are owned by the sibling compile).
+ * Forward-only compile. Spawns its own worker. For sibling graphs that share
+ * params with a training graph, prefer the `compileForward` method on the
+ * CompiledModule returned by `compileModule()`.
  */
 declare function compileForward<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileForwardOptions<I>): Promise<CompiledForwardModule>;

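A hedged end-to-end sketch of the new worker-backed flow. `GPT`, `nextBatch`, and the forward bodies are placeholders; only the `compileModule` / `step` / `compileForward` / `run` / `destroy` call shapes come from the declarations above:

```ts
import { Module, compileModule, lr } from 'tensorgrad';
import type { Tensor } from 'tensorgrad';

// Placeholders for illustration only; a real model comes from the Module/nn helpers.
declare class GPT extends Module { loss(t: Tensor, y: Tensor): Tensor; logits(t: Tensor): Tensor; }
declare function nextBatch(b: number, t: number): { tokens: Int32Array; targets: Int32Array };

const B = 8, T = 64;
const compiled = await compileModule(
  () => new GPT(),                                       // factory, not an instance
  (m, { tokens, targets }) => m.loss(tokens, targets),   // returns the loss Tensor
  {
    inputs: {
      tokens:  { shape: [B, T], dtype: 'i32' },
      targets: { shape: [B, T], dtype: 'i32' },
    },
    adam: { lr: lr.cosineDecay({ peak: 3e-4, final: 3e-5, steps: 2000 }) },
  },
);

for (let step = 1; step <= 2000; step++) {
  const { tokens, targets } = nextBatch(B, T);           // Int32Arrays
  const loss = await compiled.step({ tokens, targets }); // GPU work runs in the worker
  if (step % 100 === 0) console.log(`step ${step}: loss ${loss.toFixed(4)}`);
}

// Sibling forward-only graph in the same worker, sharing param GPUBuffers.
const sampler = await compiled.compileForward(
  (m, { tokens }) => m.logits(tokens),
  { inputs: { tokens: { shape: [1, T], dtype: 'i32' } } },
);
const logits = await sampler.run({ tokens: new Int32Array(T) });

const snapshot = await compiled.downloadParams();        // Record<string, Float32Array>
compiled.destroy();                                      // frees GPU resources, terminates the worker
```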
@@ -759,5 +720,5 @@ declare namespace nn_d {
     export type { nn_d_LinearOptions as LinearOptions };
 }

-export { Captures, Module, ShapeError, add, appendAdam, appendGrad, arange, capture,
-export type { AdamConfig, AdamResult, BufferPlan, BufferSpec, CallSite, CompileForwardMethodOptions, CompileForwardOptions, CompileModuleOptions,
+export { Captures, Module, ShapeError, add, appendAdam, appendGrad, arange, capture, compileForward, compileModule, compileToIR, div, embedding, emitKernels, exp, greater, init, less, log, logSoftmaxLast, lr, materializeParams, matmul, matmulBatched, meanLast, mul, nn_d as nn, oneHot, paramInput, planBuffers, relu, reshape, resolveLR, rsqrt, sliceLastRange, softmaxCausalLast, sqrt, stateInput, sub, sumAll, sumLast, swapAxes, tensorInput, trace, traceInto, transpose, where, whereCausal };
+export type { AdamConfig, AdamResult, BufferPlan, BufferSpec, CallSite, CompileForwardMethodOptions, CompileForwardOptions, CompileModuleOptions, CompiledForwardModule, CompiledIR, CompiledModule, Dtype, ForwardFn, GradResult, Graph, InitSpec, InputDecl, InputDecls, InputsTensors, KernelSpec, LRSchedule, MaterializedParams, OpNode, ParamOptions, RunOptions, RunResult, Shape, StepResult, Tensor, UploadParamsOptions, Writeback, WritebackDecl };