tensorgrad 0.0.14 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +154 -170
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +553 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -154
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/runtime.ts +64 -11
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -180
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -375
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/dist/index.d.ts
CHANGED
|
@@ -311,12 +311,62 @@ interface WritebackDecl {
|
|
|
311
311
|
*/
|
|
312
312
|
declare function planBuffers(graph: Graph, paramGrads: Record<string, Tensor>, writebackDecls?: WritebackDecl[]): BufferPlan;
|
|
313
313
|
|
|
314
|
+
/** Per-step learning-rate schedule. Either a fixed number or one of the
|
|
315
|
+
* serializable shape forms below. Functions/closures are not supported —
|
|
316
|
+
* the schedule needs to cross thread boundaries and survive serialization
|
|
317
|
+
* for the worker-internal runtime, and every realistic LR pattern (constant,
|
|
318
|
+
* linear decay, cosine, warmup-then-decay) maps to a finite set of shapes.
|
|
319
|
+
* Use the `lr` helper namespace to construct shapes ergonomically. */
|
|
320
|
+
type LRSchedule = number | {
|
|
321
|
+
readonly kind: 'constant';
|
|
322
|
+
readonly value: number;
|
|
323
|
+
} | {
|
|
324
|
+
readonly kind: 'linearDecay';
|
|
325
|
+
readonly peak: number;
|
|
326
|
+
readonly final: number;
|
|
327
|
+
readonly steps: number;
|
|
328
|
+
} | {
|
|
329
|
+
readonly kind: 'cosineDecay';
|
|
330
|
+
readonly peak: number;
|
|
331
|
+
readonly final: number;
|
|
332
|
+
readonly steps: number;
|
|
333
|
+
} | {
|
|
334
|
+
readonly kind: 'warmup';
|
|
335
|
+
readonly peakLr: number;
|
|
336
|
+
readonly warmupSteps: number;
|
|
337
|
+
readonly after: LRSchedule;
|
|
338
|
+
};
|
|
339
|
+
/** Ergonomic constructors for LRSchedule shapes. */
|
|
340
|
+
declare const lr: {
|
|
341
|
+
constant: (value: number) => LRSchedule;
|
|
342
|
+
/** Linearly interpolate from `peak` at step 1 to `final` at step `steps`,
|
|
343
|
+
* then hold at `final`. Matches `peak + (final - peak) * min(step/steps, 1)`. */
|
|
344
|
+
linearDecay: (opts: {
|
|
345
|
+
peak: number;
|
|
346
|
+
final: number;
|
|
347
|
+
steps: number;
|
|
348
|
+
}) => LRSchedule;
|
|
349
|
+
/** Half-cosine from `peak` at step 1 down to `final` at step `steps`,
|
|
350
|
+
* then hold at `final`. */
|
|
351
|
+
cosineDecay: (opts: {
|
|
352
|
+
peak: number;
|
|
353
|
+
final: number;
|
|
354
|
+
steps: number;
|
|
355
|
+
}) => LRSchedule;
|
|
356
|
+
/** Linear ramp from 0 to `peakLr` over `warmupSteps` steps, then hand off
|
|
357
|
+
* to `after` (offset so step 1 of `after` = first post-warmup step). */
|
|
358
|
+
warmup: (opts: {
|
|
359
|
+
peakLr: number;
|
|
360
|
+
warmupSteps: number;
|
|
361
|
+
after: LRSchedule;
|
|
362
|
+
}) => LRSchedule;
|
|
363
|
+
};
|
|
364
|
+
/** Resolve a schedule to its scalar value at a given 1-based step. */
|
|
365
|
+
declare function resolveLR(schedule: LRSchedule, step: number): number;
|
|
314
366
|
interface AdamConfig {
|
|
315
|
-
/**
|
|
316
|
-
* `(
|
|
317
|
-
|
|
318
|
-
* per-step automatically when this is a function. */
|
|
319
|
-
lr: number | ((step: number) => number);
|
|
367
|
+
/** Learning rate schedule. Pass a number for fixed lr, or a shape from
|
|
368
|
+
* the `lr` helpers (e.g., `lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 })`). */
|
|
369
|
+
lr: LRSchedule;
|
|
320
370
|
b1?: number;
|
|
321
371
|
b2?: number;
|
|
322
372
|
eps?: number;
|
|
@@ -330,16 +380,17 @@ interface AdamConfig {
|
|
|
330
380
|
* Example: `(name) => name.includes('.W') || name.endsWith('_emb')`. */
|
|
331
381
|
decayFilter?: (paramName: string) => boolean;
|
|
332
382
|
}
|
|
333
|
-
/** Resolved hyperparameters
|
|
383
|
+
/** Resolved hyperparameters with all fields populated. `lr` stays as the
|
|
384
|
+
* shape (not pre-resolved) so the runtime can compute per-step values. */
|
|
334
385
|
interface AdamResolvedConfig {
|
|
335
|
-
lr:
|
|
386
|
+
lr: LRSchedule;
|
|
336
387
|
b1: number;
|
|
337
388
|
b2: number;
|
|
338
389
|
eps: number;
|
|
339
390
|
weightDecay: number;
|
|
340
391
|
decayFilter: (name: string) => boolean;
|
|
341
|
-
/** True iff the
|
|
342
|
-
* decayShrink is baked at compile time
|
|
392
|
+
/** True iff the lr shape varies with step (linearDecay, cosineDecay,
|
|
393
|
+
* warmup). When false, decayShrink is baked at compile time. */
|
|
343
394
|
lrIsScheduled: boolean;
|
|
344
395
|
}
|
|
345
396
|
interface AdamResult {
|
|
@@ -431,101 +482,52 @@ interface RunOptions {
|
|
|
431
482
|
* `.get` throws); when true, captures are read back and accessible. */
|
|
432
483
|
withCaptures?: boolean;
|
|
433
484
|
}
|
|
434
|
-
/** Common surface for both training and forward-only compiled runtimes. */
|
|
435
|
-
interface CompiledBase {
|
|
436
|
-
/** The GPUDevice this runtime is bound to. Pass to sibling compiles to
|
|
437
|
-
* share the device, or use directly for other GPU work. */
|
|
438
|
-
device: GPUDevice;
|
|
439
|
-
/** Param name -> the underlying GPUBuffer. Pass to a sibling compile via
|
|
440
|
-
* `sharedParams` to share without copies. */
|
|
441
|
-
params: Map<string, GPUBuffer>;
|
|
442
|
-
/** Shape of the graph's output (loss scalar `[]` for training; the user's
|
|
443
|
-
* returned tensor for forward-only compiles). */
|
|
444
|
-
outputShape: number[];
|
|
445
|
-
/** Upload parameter Float32Arrays to their GPU buffers. By default, requires
|
|
446
|
-
* *all* params to be present; throws on any unknown or missing key. Pass
|
|
447
|
-
* `{ partial: true }` to skip the missing-key check. */
|
|
448
|
-
uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): void;
|
|
449
|
-
/** Read all parameters back as Float32Arrays — used for UI panels. */
|
|
450
|
-
downloadParams(): Promise<Record<string, Float32Array>>;
|
|
451
|
-
/** Free GPU resources. */
|
|
452
|
-
destroy(): void;
|
|
453
|
-
}
|
|
454
|
-
/** Run a dispatch and read back the full output tensor. Default returns the
|
|
455
|
-
* output as a `Float32Array`; with `{ withCaptures: true }` returns
|
|
456
|
-
* `{ output, captures }`. Same shape as `step()`'s overloads. */
|
|
457
|
-
interface RunFn {
|
|
458
|
-
(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
|
|
459
|
-
(inputs: Record<string, Int32Array | Float32Array>, opts: {
|
|
460
|
-
withCaptures: true;
|
|
461
|
-
}): Promise<RunResult>;
|
|
462
|
-
(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<Float32Array | RunResult>;
|
|
463
|
-
}
|
|
464
|
-
interface CompiledRuntime extends CompiledBase {
|
|
465
|
-
/** Read all parameter gradients back. Mostly for verification / debugging. */
|
|
466
|
-
downloadParamGrads(): Promise<Record<string, Float32Array>>;
|
|
467
|
-
/**
|
|
468
|
-
* One full forward+backward step.
|
|
469
|
-
* 1. Uploads `inputs` (tokens, targets, masks) to input buffers.
|
|
470
|
-
* 2. Dispatches every kernel in order.
|
|
471
|
-
* 3. Reads back the loss scalar (and any registered captures, if requested).
|
|
472
|
-
* Default returns the loss as a JS number; with `{ withCaptures: true }`
|
|
473
|
-
* returns `{ loss, captures }`.
|
|
474
|
-
*/
|
|
475
|
-
step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
|
|
476
|
-
step(inputs: Record<string, Int32Array | Float32Array>, opts: {
|
|
477
|
-
withCaptures: true;
|
|
478
|
-
}): Promise<StepResult>;
|
|
479
|
-
step(inputs: Record<string, Int32Array | Float32Array>, opts: RunOptions): Promise<number | StepResult>;
|
|
480
|
-
/** Same dispatch as step() but returns the full output Float32Array — for
|
|
481
|
-
* training graphs the output is a scalar loss, so step() is usually more
|
|
482
|
-
* convenient. Provided for parity with `compileForward`. */
|
|
483
|
-
run: RunFn;
|
|
484
|
-
/** Re-zero all optimizer state buffers (Adam's m/v) in place. Pair with
|
|
485
|
-
* `uploadInitialParams()` for a full training reset without recompile. */
|
|
486
|
-
resetOptimizerState(): void;
|
|
487
|
-
}
|
|
488
|
-
/** Forward-only compiled runtime — produced by `compileForward`. No optimizer,
|
|
489
|
-
* no backward. Returns the output tensor (not just a scalar) per `run()` call. */
|
|
490
|
-
interface CompiledForward extends CompiledBase {
|
|
491
|
-
run: RunFn;
|
|
492
|
-
}
|
|
493
|
-
interface RuntimeOpts {
|
|
494
|
-
/** Pre-acquired GPUDevice. If omitted, runtime requests its own. */
|
|
495
|
-
device?: GPUDevice;
|
|
496
|
-
/** External param buffers to bind in place of allocating fresh ones, keyed
|
|
497
|
-
* by param name. Used to share params between a training compile and a
|
|
498
|
-
* sibling forward-only compile (e.g., a B=1 inference graph). When a name
|
|
499
|
-
* is in this map, the runtime reuses the provided GPUBuffer; otherwise it
|
|
500
|
-
* allocates as usual. */
|
|
501
|
-
sharedParams?: Map<string, GPUBuffer>;
|
|
502
|
-
}
|
|
503
|
-
declare function createRuntime(plan: BufferPlan, kernels: KernelSpec[], lossBufferId: number, opts?: RuntimeOpts): Promise<CompiledRuntime>;
|
|
504
|
-
/** Same machinery as `createRuntime`, narrower public type: a forward-only
|
|
505
|
-
* graph exposes `run()` instead of `step()` (no optimizer state, no scalar-
|
|
506
|
-
* loss readback). The full runtime object is built once and projected by
|
|
507
|
-
* `compileForward` to the public shape. */
|
|
508
|
-
declare function createForwardRuntime(plan: BufferPlan, kernels: KernelSpec[], outputBufferId: number, opts?: RuntimeOpts): Promise<CompiledForward>;
|
|
509
485
|
|
|
510
|
-
/** How a parameter's initial values are produced.
|
|
511
|
-
*
|
|
512
|
-
*
|
|
513
|
-
*
|
|
514
|
-
*
|
|
515
|
-
* -
|
|
516
|
-
*
|
|
486
|
+
/** How a parameter's initial values are produced. Serializable shape — no
|
|
487
|
+
* closures, since the initial values cross the worker boundary at compile
|
|
488
|
+
* time. Use the `init` helpers for ergonomic construction.
|
|
489
|
+
*
|
|
490
|
+
* String shorthands:
|
|
491
|
+
* - `'randn'` — Gaussian with std 0.02 (the common weight-matrix init).
|
|
492
|
+
* - `'zeros'` — fill with 0 (biases, LayerNorm beta).
|
|
493
|
+
* - `'ones'` — fill with 1 (LayerNorm gain).
|
|
494
|
+
*
|
|
495
|
+
* Object shapes:
|
|
496
|
+
* - `{ kind: 'randn', scale }` — randn with explicit std.
|
|
497
|
+
* - `{ kind: 'kaiming', gain? }` — `std = gain / sqrt(fan_in)`. Default
|
|
498
|
+
* gain `sqrt(2)` (good for ReLU). `fan_in = shape[0]`.
|
|
499
|
+
* - `{ kind: 'literal', data }` — explicit Float32Array; length must
|
|
500
|
+
* match the parameter's element count.
|
|
517
501
|
*/
|
|
518
|
-
type InitSpec = 'randn' | 'zeros' | 'ones' |
|
|
502
|
+
type InitSpec = 'randn' | 'zeros' | 'ones' | {
|
|
503
|
+
readonly kind: 'randn';
|
|
504
|
+
readonly scale: number;
|
|
505
|
+
} | {
|
|
506
|
+
readonly kind: 'kaiming';
|
|
507
|
+
readonly gain?: number;
|
|
508
|
+
} | {
|
|
509
|
+
readonly kind: 'literal';
|
|
510
|
+
readonly data: Float32Array;
|
|
511
|
+
};
|
|
512
|
+
/** Ergonomic constructors for InitSpec object shapes. */
|
|
513
|
+
declare const init: {
|
|
514
|
+
randn: (opts?: {
|
|
515
|
+
scale?: number;
|
|
516
|
+
}) => InitSpec;
|
|
517
|
+
kaiming: (opts?: {
|
|
518
|
+
gain?: number;
|
|
519
|
+
}) => InitSpec;
|
|
520
|
+
literal: (data: Float32Array) => InitSpec;
|
|
521
|
+
};
|
|
519
522
|
interface ParamOptions {
|
|
520
523
|
dtype?: Dtype;
|
|
521
|
-
/** Init
|
|
524
|
+
/** Init shape. Default: `'randn'` (std 0.02). */
|
|
522
525
|
init?: InitSpec;
|
|
523
|
-
/** Std dev for `'randn'`. Default 0.02. Ignored for non-randn init. */
|
|
524
|
-
scale?: number;
|
|
525
526
|
/** Whether AdamW (when `weightDecay > 0`) should apply decoupled weight
|
|
526
|
-
* decay to this param. Default: `true` for
|
|
527
|
-
* embeddings)
|
|
528
|
-
* to force or skip. Replaces `adam.decayFilter` for
|
|
527
|
+
* decay to this param. Default: `true` for randn/kaiming/literal init
|
|
528
|
+
* (weight matrices, embeddings); `false` for zeros/ones (biases, LN
|
|
529
|
+
* gains). Override to force or skip. Replaces `adam.decayFilter` for
|
|
530
|
+
* the common case. */
|
|
529
531
|
decay?: boolean;
|
|
530
532
|
}
|
|
531
533
|
type InitFn = (size: number, shape: readonly number[]) => Float32Array;
|
|
@@ -567,21 +569,14 @@ interface InputDecl {
|
|
|
567
569
|
shape: Shape;
|
|
568
570
|
dtype?: Dtype;
|
|
569
571
|
}
|
|
570
|
-
/** Inputs declaration: a Record from input name to its shape/dtype.
|
|
571
|
-
* doubles as the key the forward fn destructures and the key the runtime
|
|
572
|
-
* expects in `step({...})` / `run({...})`. */
|
|
572
|
+
/** Inputs declaration: a Record from input name to its shape/dtype. */
|
|
573
573
|
type InputDecls = Record<string, InputDecl>;
|
|
574
574
|
/** Maps an `InputDecls` Record to its forward-time tensor counterpart —
|
|
575
|
-
* same keys, each value is a Tensor.
|
|
576
|
-
* `inputs` argument from the declared shape Record. */
|
|
575
|
+
* same keys, each value is a Tensor. */
|
|
577
576
|
type InputsTensors<I extends InputDecls> = {
|
|
578
577
|
[K in keyof I]: Tensor;
|
|
579
578
|
};
|
|
580
|
-
/** Forward function shape
|
|
581
|
-
* named input tensors (matching the declared `inputs:` keys), returns the
|
|
582
|
-
* output tensor (loss for compileModule; logits/etc. for compileForward).
|
|
583
|
-
* The second generic flows from the inputs declaration so destructuring
|
|
584
|
-
* the input record stays typed. */
|
|
579
|
+
/** Forward function shape. */
|
|
585
580
|
type ForwardFn<M extends Module, I extends InputDecls = InputDecls> = (m: M, inputs: InputsTensors<I>) => Tensor;
|
|
586
581
|
interface CompiledIR {
|
|
587
582
|
graph: GradResult['graph'];
|
|
@@ -592,59 +587,67 @@ interface CompiledIR {
|
|
|
592
587
|
}
|
|
593
588
|
/** Trace + autograd + buffer-plan + codegen, without touching WebGPU. */
|
|
594
589
|
declare function compileToIR(traceFn: () => Tensor): CompiledIR;
|
|
595
|
-
|
|
596
|
-
declare function compile(traceFn: () => Tensor, opts?: RuntimeOpts): Promise<CompiledRuntime & {
|
|
597
|
-
ir: CompiledIR;
|
|
598
|
-
}>;
|
|
599
|
-
interface CompileModuleOptions<I extends InputDecls = InputDecls> extends RuntimeOpts {
|
|
600
|
-
/** Per-step data inputs to the forward function, keyed by name. The forward
|
|
601
|
-
* fn destructures these out of its second argument; runtime calls to
|
|
602
|
-
* `step()` / `run()` pass typed arrays under the same keys. */
|
|
590
|
+
interface CompileModuleOptions<I extends InputDecls = InputDecls> {
|
|
603
591
|
inputs?: I;
|
|
604
|
-
/** Adam hyperparameters. If omitted, no optimizer is appended (forward-only). */
|
|
605
592
|
adam?: AdamConfig;
|
|
606
593
|
}
|
|
607
|
-
interface CompileForwardOptions<I extends InputDecls = InputDecls>
|
|
608
|
-
/** Per-step data inputs to the forward function, keyed by name. */
|
|
594
|
+
interface CompileForwardOptions<I extends InputDecls = InputDecls> {
|
|
609
595
|
inputs?: I;
|
|
610
596
|
}
|
|
611
|
-
/** Forward-only compile options as taken by the `compileForward` *method* on
|
|
612
|
-
* a training runtime — no `device` (inherited) and no `sharedParams`
|
|
613
|
-
* (auto-supplied from the train graph's params). */
|
|
614
597
|
interface CompileForwardMethodOptions<I extends InputDecls = InputDecls> {
|
|
615
598
|
inputs?: I;
|
|
616
599
|
}
|
|
617
|
-
/** Returned by `compileModule`.
|
|
618
|
-
*
|
|
619
|
-
interface CompiledModule<M extends Module>
|
|
620
|
-
ir: CompiledIR;
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
/**
|
|
624
|
-
*
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
600
|
+
/** Returned by `compileModule`. Proxies all GPU work to a worker held
|
|
601
|
+
* internally; user code awaits Promises and never sees the worker. */
|
|
602
|
+
interface CompiledModule<M extends Module> {
|
|
603
|
+
readonly ir: CompiledIR;
|
|
604
|
+
readonly kernelCount: number;
|
|
605
|
+
readonly outputShape: readonly number[];
|
|
606
|
+
/** Names of the model's parameters, in materialization order. The actual
|
|
607
|
+
* GPUBuffers live in the worker; use `downloadParams()` for values. */
|
|
608
|
+
readonly paramNames: readonly string[];
|
|
609
|
+
step(inputs: Record<string, Int32Array | Float32Array>): Promise<number>;
|
|
610
|
+
step(inputs: Record<string, Int32Array | Float32Array>, opts: {
|
|
611
|
+
withCaptures: true;
|
|
612
|
+
}): Promise<StepResult>;
|
|
613
|
+
run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
|
|
614
|
+
run(inputs: Record<string, Int32Array | Float32Array>, opts: {
|
|
615
|
+
withCaptures: true;
|
|
616
|
+
}): Promise<RunResult>;
|
|
617
|
+
uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>;
|
|
618
|
+
downloadParams(): Promise<Record<string, Float32Array>>;
|
|
619
|
+
downloadParamGrads(): Promise<Record<string, Float32Array>>;
|
|
620
|
+
/** Re-initialize all params + zero optimizer state. */
|
|
621
|
+
reset(): Promise<void>;
|
|
622
|
+
resetOptimizerState(): Promise<void>;
|
|
623
|
+
/** Compile a sibling forward-only graph that shares this runtime's worker
|
|
624
|
+
* (and therefore its param GPUBuffers). */
|
|
631
625
|
compileForward<I extends InputDecls>(forward: ForwardFn<M, I>, opts?: CompileForwardMethodOptions<I>): Promise<CompiledForwardModule>;
|
|
626
|
+
/** Free the runtime's GPU resources and terminate the worker. */
|
|
627
|
+
destroy(): void;
|
|
632
628
|
}
|
|
633
629
|
/** Returned by `compileForward` (and by the `compileForward` method). */
|
|
634
|
-
interface CompiledForwardModule extends CompiledForward {
|
|
635
|
-
ir: CompiledIR;
|
|
636
|
-
|
|
637
|
-
|
|
630
|
+
interface CompiledForwardModule {
|
|
631
|
+
readonly ir: CompiledIR;
|
|
632
|
+
readonly kernelCount: number;
|
|
633
|
+
readonly outputShape: readonly number[];
|
|
634
|
+
readonly paramNames: readonly string[];
|
|
635
|
+
run(inputs: Record<string, Int32Array | Float32Array>): Promise<Float32Array>;
|
|
636
|
+
run(inputs: Record<string, Int32Array | Float32Array>, opts: {
|
|
637
|
+
withCaptures: true;
|
|
638
|
+
}): Promise<RunResult>;
|
|
639
|
+
uploadParams(params: Record<string, Float32Array>, opts?: UploadParamsOptions): Promise<void>;
|
|
640
|
+
downloadParams(): Promise<Record<string, Float32Array>>;
|
|
641
|
+
destroy(): void;
|
|
638
642
|
}
|
|
639
643
|
/**
|
|
640
644
|
* Compile a Module-based model. Pass a *factory* `() => new Model()`, not the
|
|
641
645
|
* model instance itself: compilation mutates the tree (every `ParamSentinel`
|
|
642
646
|
* field becomes a real `Tensor`), so the instance is consumed and shouldn't be
|
|
643
|
-
* referenced afterwards.
|
|
647
|
+
* referenced afterwards.
|
|
644
648
|
*
|
|
645
649
|
* The forward function takes the materialized model and a Record of named
|
|
646
|
-
* input tensors, returns the loss tensor
|
|
647
|
-
* `inputs:` declaration:
|
|
650
|
+
* input tensors, returns the loss tensor:
|
|
648
651
|
*
|
|
649
652
|
* inputs: {
|
|
650
653
|
* tokens: { shape: [B, T], dtype: 'i32' },
|
|
@@ -652,34 +655,15 @@ interface CompiledForwardModule extends CompiledForward {
|
|
|
652
655
|
* }
|
|
653
656
|
* forward: (m, { tokens, targets }) => …
|
|
654
657
|
*
|
|
655
|
-
*
|
|
656
|
-
*
|
|
657
|
-
*
|
|
658
|
-
* call `reset()` later to re-randomize.
|
|
659
|
-
*
|
|
660
|
-
* If `opts.adam` is set, the runtime's `step()` automatically tracks an
|
|
661
|
-
* internal step count and injects the bias-corrected `lrt` scalar each call;
|
|
662
|
-
* users don't need to provide it themselves.
|
|
658
|
+
* Returns a `CompiledModule` proxy. All GPU work (createRuntime, step, run,
|
|
659
|
+
* mapAsync) happens in an internal worker; calls return Promises that resolve
|
|
660
|
+
* when the worker replies.
|
|
663
661
|
*/
|
|
664
662
|
declare function compileModule<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileModuleOptions<I>): Promise<CompiledModule<M>>;
|
|
665
663
|
/**
|
|
666
|
-
*
|
|
667
|
-
*
|
|
668
|
-
*
|
|
669
|
-
* `Float32Array`.
|
|
670
|
-
*
|
|
671
|
-
* **Prefer the `compileForward` method on a training runtime** when both
|
|
672
|
-
* graphs use the same Module class — it auto-supplies `device` and
|
|
673
|
-
* `sharedParams`. This standalone form is for forward-only models with no
|
|
674
|
-
* training graph at all, or for sharing params across a different model.
|
|
675
|
-
*
|
|
676
|
-
* **Sharing params with a training compile.** Pass `opts.sharedParams =
|
|
677
|
-
* trainCompiled.params` to bind this graph's param buffers to an existing
|
|
678
|
-
* training runtime's GPU buffers — every train step is then immediately
|
|
679
|
-
* visible to `run()` calls here, no copies.
|
|
680
|
-
*
|
|
681
|
-
* Initial param values are uploaded automatically for params *not* covered
|
|
682
|
-
* by `sharedParams` (those are owned by the sibling compile).
|
|
664
|
+
* Forward-only compile. Spawns its own worker. For sibling graphs that share
|
|
665
|
+
* params with a training graph, prefer the `compileForward` method on the
|
|
666
|
+
* CompiledModule returned by `compileModule()`.
|
|
683
667
|
*/
|
|
684
668
|
declare function compileForward<M extends Module, I extends InputDecls = InputDecls>(modelFactory: () => M, forward: ForwardFn<M, I>, opts?: CompileForwardOptions<I>): Promise<CompiledForwardModule>;
|
|
685
669
|
|
|
@@ -736,5 +720,5 @@ declare namespace nn_d {
|
|
|
736
720
|
export type { nn_d_LinearOptions as LinearOptions };
|
|
737
721
|
}
|
|
738
722
|
|
|
739
|
-
export { Captures, Module, ShapeError, add, appendAdam, appendGrad, arange, capture,
|
|
740
|
-
export type { AdamConfig, AdamResult, BufferPlan, BufferSpec, CallSite, CompileForwardMethodOptions, CompileForwardOptions, CompileModuleOptions,
|
|
723
|
+
export { Captures, Module, ShapeError, add, appendAdam, appendGrad, arange, capture, compileForward, compileModule, compileToIR, div, embedding, emitKernels, exp, greater, init, less, log, logSoftmaxLast, lr, materializeParams, matmul, matmulBatched, meanLast, mul, nn_d as nn, oneHot, paramInput, planBuffers, relu, reshape, resolveLR, rsqrt, sliceLastRange, softmaxCausalLast, sqrt, stateInput, sub, sumAll, sumLast, swapAxes, tensorInput, trace, traceInto, transpose, where, whereCausal };
|
|
724
|
+
export type { AdamConfig, AdamResult, BufferPlan, BufferSpec, CallSite, CompileForwardMethodOptions, CompileForwardOptions, CompileModuleOptions, CompiledForwardModule, CompiledIR, CompiledModule, Dtype, ForwardFn, GradResult, Graph, InitSpec, InputDecl, InputDecls, InputsTensors, KernelSpec, LRSchedule, MaterializedParams, OpNode, ParamOptions, RunOptions, RunResult, Shape, StepResult, Tensor, UploadParamsOptions, Writeback, WritebackDecl };
|