tensorgrad 0.0.15 → 0.0.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.d.ts +154 -193
- package/dist/index.js +2208 -39
- package/dist/index.js.map +7 -1
- package/dist/worker.debug.js +553 -0
- package/package.json +60 -58
- package/src/adam.ts +69 -15
- package/src/compile.ts +334 -156
- package/src/index.ts +8 -4
- package/src/module.ts +72 -34
- package/src/worker-protocol.ts +183 -0
- package/src/worker-proxy.ts +76 -0
- package/src/worker.ts +281 -0
- package/dist/adam.js +0 -111
- package/dist/adam.js.map +0 -1
- package/dist/buffers.js +0 -120
- package/dist/buffers.js.map +0 -1
- package/dist/capture.js +0 -33
- package/dist/capture.js.map +0 -1
- package/dist/codegen.js +0 -724
- package/dist/codegen.js.map +0 -1
- package/dist/compile.js +0 -184
- package/dist/compile.js.map +0 -1
- package/dist/grad.js +0 -380
- package/dist/grad.js.map +0 -1
- package/dist/ir.js +0 -60
- package/dist/ir.js.map +0 -1
- package/dist/module.js +0 -155
- package/dist/module.js.map +0 -1
- package/dist/nn.js +0 -135
- package/dist/nn.js.map +0 -1
- package/dist/ops.js +0 -326
- package/dist/ops.js.map +0 -1
- package/dist/runtime.js +0 -402
- package/dist/runtime.js.map +0 -1
- package/dist/shape.js +0 -259
- package/dist/shape.js.map +0 -1
- package/dist/trace.js +0 -100
- package/dist/trace.js.map +0 -1
package/src/adam.ts
CHANGED
|
@@ -32,12 +32,68 @@ import type { WritebackDecl } from './buffers.js'
|
|
|
32
32
|
import { traceInto, stateInput, tensorInput } from './trace.js'
|
|
33
33
|
import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js'
|
|
34
34
|
|
|
35
|
+
/** Per-step learning-rate schedule. Either a fixed number or one of the
|
|
36
|
+
* serializable shape forms below. Functions/closures are not supported —
|
|
37
|
+
* the schedule needs to cross thread boundaries and survive serialization
|
|
38
|
+
* for the worker-internal runtime, and every realistic LR pattern (constant,
|
|
39
|
+
* linear decay, cosine, warmup-then-decay) maps to a finite set of shapes.
|
|
40
|
+
* Use the `lr` helper namespace to construct shapes ergonomically. */
|
|
41
|
+
export type LRSchedule =
|
|
42
|
+
| number
|
|
43
|
+
| { readonly kind: 'constant'; readonly value: number }
|
|
44
|
+
| { readonly kind: 'linearDecay'; readonly peak: number; readonly final: number; readonly steps: number }
|
|
45
|
+
| { readonly kind: 'cosineDecay'; readonly peak: number; readonly final: number; readonly steps: number }
|
|
46
|
+
| { readonly kind: 'warmup'; readonly peakLr: number; readonly warmupSteps: number; readonly after: LRSchedule }
|
|
47
|
+
|
|
48
|
+
/** Ergonomic constructors for LRSchedule shapes. */
|
|
49
|
+
export const lr = {
|
|
50
|
+
constant: (value: number): LRSchedule => ({ kind: 'constant', value }),
|
|
51
|
+
/** Linearly interpolate from `peak` (the step-0 value) to `final` at step `steps`,
|
|
52
|
+
* then hold at `final`. Matches `peak + (final - peak) * min(step/steps, 1)`. */
|
|
53
|
+
linearDecay: (opts: { peak: number; final: number; steps: number }): LRSchedule =>
|
|
54
|
+
({ kind: 'linearDecay', ...opts }),
|
|
55
|
+
/** Half-cosine from `peak` (the step-0 value) down to `final` at step `steps`,
|
|
56
|
+
* then hold at `final`. */
|
|
57
|
+
cosineDecay: (opts: { peak: number; final: number; steps: number }): LRSchedule =>
|
|
58
|
+
({ kind: 'cosineDecay', ...opts }),
|
|
59
|
+
/** Linear ramp from 0 to `peakLr` over `warmupSteps` steps, then hand off
|
|
60
|
+
* to `after` (offset so step 1 of `after` = first post-warmup step). */
|
|
61
|
+
warmup: (opts: { peakLr: number; warmupSteps: number; after: LRSchedule }): LRSchedule =>
|
|
62
|
+
({ kind: 'warmup', ...opts }),
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
/** Resolve a schedule to its scalar value at a given 1-based step. */
|
|
66
|
+
export function resolveLR(schedule: LRSchedule, step: number): number {
|
|
67
|
+
if (typeof schedule === 'number') return schedule
|
|
68
|
+
switch (schedule.kind) {
|
|
69
|
+
case 'constant': return schedule.value
|
|
70
|
+
case 'linearDecay': {
|
|
71
|
+
const f = Math.min(step / schedule.steps, 1)
|
|
72
|
+
return schedule.peak + (schedule.final - schedule.peak) * f
|
|
73
|
+
}
|
|
74
|
+
case 'cosineDecay': {
|
|
75
|
+
const f = Math.min(step / schedule.steps, 1)
|
|
76
|
+
return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f))
|
|
77
|
+
}
|
|
78
|
+
case 'warmup': {
|
|
79
|
+
if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps)
|
|
80
|
+
return resolveLR(schedule.after, step - schedule.warmupSteps)
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
/** True for shapes that produce different values at different steps (so the
|
|
86
|
+
* AdamW decayShrink scalar must be a per-step input rather than baked).
|
|
87
|
+
* Numbers and `{kind:'constant'}` are static; everything else varies. */
|
|
88
|
+
export function isLRDynamic(schedule: LRSchedule): boolean {
|
|
89
|
+
if (typeof schedule === 'number') return false
|
|
90
|
+
return schedule.kind !== 'constant'
|
|
91
|
+
}
|
|
92
|
+
|
|
35
93
|
export interface AdamConfig {
|
|
36
|
-
/**
|
|
37
|
-
* `(
|
|
38
|
-
|
|
39
|
-
* per-step automatically when this is a function. */
|
|
40
|
-
lr: number | ((step: number) => number)
|
|
94
|
+
/** Learning rate schedule. Pass a number for fixed lr, or a shape from
|
|
95
|
+
* the `lr` helpers (e.g., `lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 })`). */
|
|
96
|
+
lr: LRSchedule
|
|
41
97
|
b1?: number // default 0.9
|
|
42
98
|
b2?: number // default 0.999
|
|
43
99
|
eps?: number // default 1e-8
|
|
@@ -52,16 +108,17 @@ export interface AdamConfig {
|
|
|
52
108
|
decayFilter?: (paramName: string) => boolean
|
|
53
109
|
}
|
|
54
110
|
|
|
55
|
-
/** Resolved hyperparameters
|
|
111
|
+
/** Resolved hyperparameters with all fields populated. `lr` stays as the
|
|
112
|
+
* shape (not pre-resolved) so the runtime can compute per-step values. */
|
|
56
113
|
export interface AdamResolvedConfig {
|
|
57
|
-
lr:
|
|
114
|
+
lr: LRSchedule
|
|
58
115
|
b1: number
|
|
59
116
|
b2: number
|
|
60
117
|
eps: number
|
|
61
118
|
weightDecay: number
|
|
62
119
|
decayFilter: (name: string) => boolean
|
|
63
|
-
/** True iff the
|
|
64
|
-
* decayShrink is baked at compile time
|
|
120
|
+
/** True iff the lr shape varies with step (linearDecay, cosineDecay,
|
|
121
|
+
* warmup). When false, decayShrink is baked at compile time. */
|
|
65
122
|
lrIsScheduled: boolean
|
|
66
123
|
}
|
|
67
124
|
|
|
@@ -101,13 +158,10 @@ export function appendAdam(
|
|
|
101
158
|
* directly without a Module). */
|
|
102
159
|
decayFlags?: Record<string, boolean>,
|
|
103
160
|
): AdamResult {
|
|
104
|
-
const lrIsScheduled =
|
|
105
|
-
const
|
|
106
|
-
? config.lr as (step: number) => number
|
|
107
|
-
: (() => config.lr as number)
|
|
108
|
-
const initialLr = lrFn(1)
|
|
161
|
+
const lrIsScheduled = isLRDynamic(config.lr)
|
|
162
|
+
const initialLr = resolveLR(config.lr, 1)
|
|
109
163
|
const fullConfig: AdamResolvedConfig = {
|
|
110
|
-
lr:
|
|
164
|
+
lr: config.lr,
|
|
111
165
|
b1: config.b1 ?? 0.9,
|
|
112
166
|
b2: config.b2 ?? 0.999,
|
|
113
167
|
eps: config.eps ?? 1e-8,
|