tensorgrad 0.0.15 → 0.0.16

package/src/adam.ts CHANGED
@@ -32,12 +32,68 @@ import type { WritebackDecl } from './buffers.js'
 import { traceInto, stateInput, tensorInput } from './trace.js'
 import { adamUpdateM, adamUpdateV, adamUpdateP } from './ops.js'
 
+/** Per-step learning-rate schedule: either a fixed number or one of the
+ * serializable shape forms below. Functions/closures are not supported,
+ * because the schedule must cross thread boundaries and survive
+ * serialization for the worker-internal runtime; every realistic LR pattern
+ * (constant, linear decay, cosine, warmup-then-decay) maps to a finite set
+ * of shapes. Use the `lr` helper namespace to construct them ergonomically. */
+export type LRSchedule =
+  | number
+  | { readonly kind: 'constant'; readonly value: number }
+  | { readonly kind: 'linearDecay'; readonly peak: number; readonly final: number; readonly steps: number }
+  | { readonly kind: 'cosineDecay'; readonly peak: number; readonly final: number; readonly steps: number }
+  | { readonly kind: 'warmup'; readonly peakLr: number; readonly warmupSteps: number; readonly after: LRSchedule }
+
+/** Ergonomic constructors for LRSchedule shapes. */
+export const lr = {
+  constant: (value: number): LRSchedule => ({ kind: 'constant', value }),
+  /** Linear ramp from `peak` down to `final`, reaching `final` at step `steps`
+   * and holding there. Computes `peak + (final - peak) * min(step / steps, 1)`. */
+  linearDecay: (opts: { peak: number; final: number; steps: number }): LRSchedule =>
+    ({ kind: 'linearDecay', ...opts }),
+  /** Half-cosine from `peak` down to `final`, reaching `final` at step `steps`
+   * and holding there. */
+  cosineDecay: (opts: { peak: number; final: number; steps: number }): LRSchedule =>
+    ({ kind: 'cosineDecay', ...opts }),
+  /** Linear ramp up to `peakLr` over `warmupSteps` steps, then hand off to
+   * `after` (offset so that step 1 of `after` is the first post-warmup step). */
+  warmup: (opts: { peakLr: number; warmupSteps: number; after: LRSchedule }): LRSchedule =>
+    ({ kind: 'warmup', ...opts }),
+}
+
+/** Resolve a schedule to its scalar value at a given 1-based step. */
+export function resolveLR(schedule: LRSchedule, step: number): number {
+  if (typeof schedule === 'number') return schedule
+  switch (schedule.kind) {
+    case 'constant': return schedule.value
+    case 'linearDecay': {
+      const f = Math.min(step / schedule.steps, 1)
+      return schedule.peak + (schedule.final - schedule.peak) * f
+    }
+    case 'cosineDecay': {
+      const f = Math.min(step / schedule.steps, 1)
+      return schedule.final + 0.5 * (schedule.peak - schedule.final) * (1 + Math.cos(Math.PI * f))
+    }
+    case 'warmup': {
+      if (step <= schedule.warmupSteps) return schedule.peakLr * (step / schedule.warmupSteps)
+      return resolveLR(schedule.after, step - schedule.warmupSteps)
+    }
+  }
+}
+
+/** True for shapes that produce different values at different steps (so the
+ * AdamW decayShrink scalar must be a per-step input rather than baked).
+ * Numbers and `{ kind: 'constant' }` are static; everything else varies. */
+export function isLRDynamic(schedule: LRSchedule): boolean {
+  if (typeof schedule === 'number') return false
+  return schedule.kind !== 'constant'
+}
+
 export interface AdamConfig {
-  /** Constant scalar (e.g., `0.005`) or a per-step schedule function
-   * `(step) => lr`. Schedule fn lets the user implement linear/cosine decay
-   * or warmup; first call passes `step=1`. Decay-shrink (AdamW) updates
-   * per-step automatically when this is a function. */
-  lr: number | ((step: number) => number)
+  /** Learning-rate schedule. Pass a number for a fixed lr, or a shape from
+   * the `lr` helpers (e.g., `lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 })`). */
+  lr: LRSchedule
   b1?: number // default 0.9
   b2?: number // default 0.999
   eps?: number // default 1e-8
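
Taken together, the new shapes compose declaratively: `warmup` wraps any other
schedule, and `resolveLR` walks the structure per step. A sketch of the
resulting values (illustrative only; the root `tensorgrad` import path is an
assumption, since this diff only shows the exports in `src/adam.ts`):

  import { lr, resolveLR, isLRDynamic, type LRSchedule } from 'tensorgrad'

  // A 100-step linear warmup handing off to a 1400-step cosine tail.
  const schedule: LRSchedule = lr.warmup({
    peakLr: 0.005,
    warmupSteps: 100,
    after: lr.cosineDecay({ peak: 0.005, final: 0.0005, steps: 1400 }),
  })

  resolveLR(schedule, 1)    // 0.00005: first ramp step, 0.005 * (1 / 100)
  resolveLR(schedule, 100)  // 0.005:   ramp reaches peakLr
  resolveLR(schedule, 101)  // ~0.005:  step 1 of the cosine tail
  resolveLR(schedule, 1500) // 0.0005:  tail reaches `final` (step 1400 of `after`)
  resolveLR(schedule, 9999) // 0.0005:  held at `final` from then on

  isLRDynamic(schedule) // true:  decayShrink becomes a per-step input
  isLRDynamic(0.005)    // false: decayShrink can be baked at compile time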
@@ -52,16 +108,17 @@ export interface AdamConfig {
   decayFilter?: (paramName: string) => boolean
 }
 
-/** Resolved hyperparameters: lr is the schedule fn (constants are wrapped). */
+/** Resolved hyperparameters with all fields populated. `lr` stays as the
+ * shape (not pre-resolved) so the runtime can compute per-step values. */
 export interface AdamResolvedConfig {
-  lr: (step: number) => number
+  lr: LRSchedule
   b1: number
   b2: number
   eps: number
   weightDecay: number
   decayFilter: (name: string) => boolean
-  /** True iff the user supplied an lr function (vs a constant). When false,
-   * decayShrink is baked at compile time and never updated. */
+  /** True iff the lr shape varies with step (linearDecay, cosineDecay,
+   * warmup). When false, decayShrink is baked at compile time. */
   lrIsScheduled: boolean
 }
 
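The serializability rationale in the `LRSchedule` doc comment is easy to
verify: a shape is plain data, so it survives the structured-clone algorithm
that `postMessage` uses at worker boundaries, while the 0.0.15 closure form
throws. A standalone sketch (Node 17+ for the `structuredClone` global):

  // Shapes are plain JSON-able objects.
  const shape = { kind: 'linearDecay', peak: 0.005, final: 0.0005, steps: 1500 }
  structuredClone(shape)            // ok: crosses worker boundaries intact
  JSON.parse(JSON.stringify(shape)) // ok: lossless round-trip

  // A schedule function cannot make the same trip.
  try {
    structuredClone((step: number) => 0.005 / step)
  } catch (e) {
    console.log((e as Error).name) // 'DataCloneError'
  }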
@@ -101,13 +158,10 @@ export function appendAdam(
    * directly without a Module). */
   decayFlags?: Record<string, boolean>,
 ): AdamResult {
-  const lrIsScheduled = typeof config.lr === 'function'
-  const lrFn = lrIsScheduled
-    ? config.lr as (step: number) => number
-    : (() => config.lr as number)
-  const initialLr = lrFn(1)
+  const lrIsScheduled = isLRDynamic(config.lr)
+  const initialLr = resolveLR(config.lr, 1)
   const fullConfig: AdamResolvedConfig = {
-    lr: lrFn,
+    lr: config.lr,
     b1: config.b1 ?? 0.9,
     b2: config.b2 ?? 0.999,
     eps: config.eps ?? 1e-8,
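
For callers upgrading from 0.0.15, the change is mechanical: replace the
closure with the equivalent shape. A migration sketch (the closure mirrors the
removed doc comment's `(step) => lr` form; the root import path is again an
assumption):

  import { lr } from 'tensorgrad'

  // 0.0.15: a per-step closure, now rejected by AdamConfig's types.
  const oldLr = (step: number) =>
    0.005 + (0.0005 - 0.005) * Math.min(step / 1500, 1)

  // 0.0.16: the same linear decay as a serializable shape.
  // oldLr(s) === resolveLR(newLr, s) at every step s.
  const newLr = lr.linearDecay({ peak: 0.005, final: 0.0005, steps: 1500 })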