deepbox 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +344 -0
- package/dist/CSRMatrix-CwGwQRea.d.cts +219 -0
- package/dist/CSRMatrix-KzNt6QpS.d.ts +219 -0
- package/dist/Tensor-BQLk1ltW.d.cts +147 -0
- package/dist/Tensor-g8mUClel.d.ts +147 -0
- package/dist/chunk-4S73VUBD.js +677 -0
- package/dist/chunk-4S73VUBD.js.map +1 -0
- package/dist/chunk-5R4S63PF.js +2925 -0
- package/dist/chunk-5R4S63PF.js.map +1 -0
- package/dist/chunk-6AE5FKKQ.cjs +9264 -0
- package/dist/chunk-6AE5FKKQ.cjs.map +1 -0
- package/dist/chunk-AD436M45.js +3854 -0
- package/dist/chunk-AD436M45.js.map +1 -0
- package/dist/chunk-ALS7ETWZ.cjs +4263 -0
- package/dist/chunk-ALS7ETWZ.cjs.map +1 -0
- package/dist/chunk-AU7XHGKJ.js +2092 -0
- package/dist/chunk-AU7XHGKJ.js.map +1 -0
- package/dist/chunk-B5TNKUEY.js +1481 -0
- package/dist/chunk-B5TNKUEY.js.map +1 -0
- package/dist/chunk-BCR7G3A6.js +9136 -0
- package/dist/chunk-BCR7G3A6.js.map +1 -0
- package/dist/chunk-C4PKXY74.cjs +1917 -0
- package/dist/chunk-C4PKXY74.cjs.map +1 -0
- package/dist/chunk-DWZY6PIP.cjs +6400 -0
- package/dist/chunk-DWZY6PIP.cjs.map +1 -0
- package/dist/chunk-E3EU5FZO.cjs +2113 -0
- package/dist/chunk-E3EU5FZO.cjs.map +1 -0
- package/dist/chunk-F3JWBINJ.js +1054 -0
- package/dist/chunk-F3JWBINJ.js.map +1 -0
- package/dist/chunk-FJYLIGJX.js +1940 -0
- package/dist/chunk-FJYLIGJX.js.map +1 -0
- package/dist/chunk-JSCDE774.cjs +729 -0
- package/dist/chunk-JSCDE774.cjs.map +1 -0
- package/dist/chunk-LWECRCW2.cjs +2412 -0
- package/dist/chunk-LWECRCW2.cjs.map +1 -0
- package/dist/chunk-MLBMYKCG.js +6379 -0
- package/dist/chunk-MLBMYKCG.js.map +1 -0
- package/dist/chunk-OX6QXFMV.cjs +3874 -0
- package/dist/chunk-OX6QXFMV.cjs.map +1 -0
- package/dist/chunk-PHV2DKRS.cjs +1072 -0
- package/dist/chunk-PHV2DKRS.cjs.map +1 -0
- package/dist/chunk-PL7TAYKI.js +4056 -0
- package/dist/chunk-PL7TAYKI.js.map +1 -0
- package/dist/chunk-PR647I7R.js +1898 -0
- package/dist/chunk-PR647I7R.js.map +1 -0
- package/dist/chunk-QERHVCHC.cjs +2960 -0
- package/dist/chunk-QERHVCHC.cjs.map +1 -0
- package/dist/chunk-XEG44RF6.cjs +1514 -0
- package/dist/chunk-XEG44RF6.cjs.map +1 -0
- package/dist/chunk-XMWVME2W.js +2377 -0
- package/dist/chunk-XMWVME2W.js.map +1 -0
- package/dist/chunk-ZB75FESB.cjs +1979 -0
- package/dist/chunk-ZB75FESB.cjs.map +1 -0
- package/dist/chunk-ZLW62TJG.cjs +4061 -0
- package/dist/chunk-ZLW62TJG.cjs.map +1 -0
- package/dist/chunk-ZXKBDFP3.js +4235 -0
- package/dist/chunk-ZXKBDFP3.js.map +1 -0
- package/dist/core/index.cjs +204 -0
- package/dist/core/index.cjs.map +1 -0
- package/dist/core/index.d.cts +2 -0
- package/dist/core/index.d.ts +2 -0
- package/dist/core/index.js +3 -0
- package/dist/core/index.js.map +1 -0
- package/dist/dataframe/index.cjs +22 -0
- package/dist/dataframe/index.cjs.map +1 -0
- package/dist/dataframe/index.d.cts +3 -0
- package/dist/dataframe/index.d.ts +3 -0
- package/dist/dataframe/index.js +5 -0
- package/dist/dataframe/index.js.map +1 -0
- package/dist/datasets/index.cjs +134 -0
- package/dist/datasets/index.cjs.map +1 -0
- package/dist/datasets/index.d.cts +3 -0
- package/dist/datasets/index.d.ts +3 -0
- package/dist/datasets/index.js +5 -0
- package/dist/datasets/index.js.map +1 -0
- package/dist/index-74AB8Cyh.d.cts +1126 -0
- package/dist/index-9oQx1HgV.d.cts +1180 -0
- package/dist/index-BJY2SI4i.d.ts +483 -0
- package/dist/index-BWGhrDlr.d.ts +733 -0
- package/dist/index-B_DK4FKY.d.cts +242 -0
- package/dist/index-BbA2Gxfl.d.ts +456 -0
- package/dist/index-BgHYAoSS.d.cts +837 -0
- package/dist/index-BndMbqsM.d.ts +1439 -0
- package/dist/index-C1mfVYoo.d.ts +2517 -0
- package/dist/index-CCvlwAmL.d.cts +809 -0
- package/dist/index-CDw5CnOU.d.ts +785 -0
- package/dist/index-Cn3SdB0O.d.ts +1126 -0
- package/dist/index-CrqLlS-a.d.ts +776 -0
- package/dist/index-D61yaSMY.d.cts +483 -0
- package/dist/index-D9Loo1_A.d.cts +2517 -0
- package/dist/index-DIT_OO9C.d.cts +785 -0
- package/dist/index-DIp_RrRt.d.ts +242 -0
- package/dist/index-DbultU6X.d.cts +1427 -0
- package/dist/index-DmEg_LCm.d.cts +776 -0
- package/dist/index-DoPWVxPo.d.cts +1439 -0
- package/dist/index-DuCxd-8d.d.ts +837 -0
- package/dist/index-Dx42TZaY.d.ts +809 -0
- package/dist/index-DyZ4QQf5.d.cts +456 -0
- package/dist/index-GFAVyOWO.d.ts +1427 -0
- package/dist/index-WHQLn0e8.d.cts +733 -0
- package/dist/index-ZtI1Iy4L.d.ts +1180 -0
- package/dist/index-eJgeni9c.d.cts +1911 -0
- package/dist/index-tk4lSYod.d.ts +1911 -0
- package/dist/index.cjs +72 -0
- package/dist/index.cjs.map +1 -0
- package/dist/index.d.cts +17 -0
- package/dist/index.d.ts +17 -0
- package/dist/index.js +15 -0
- package/dist/index.js.map +1 -0
- package/dist/linalg/index.cjs +86 -0
- package/dist/linalg/index.cjs.map +1 -0
- package/dist/linalg/index.d.cts +3 -0
- package/dist/linalg/index.d.ts +3 -0
- package/dist/linalg/index.js +5 -0
- package/dist/linalg/index.js.map +1 -0
- package/dist/metrics/index.cjs +158 -0
- package/dist/metrics/index.cjs.map +1 -0
- package/dist/metrics/index.d.cts +3 -0
- package/dist/metrics/index.d.ts +3 -0
- package/dist/metrics/index.js +5 -0
- package/dist/metrics/index.js.map +1 -0
- package/dist/ml/index.cjs +87 -0
- package/dist/ml/index.cjs.map +1 -0
- package/dist/ml/index.d.cts +3 -0
- package/dist/ml/index.d.ts +3 -0
- package/dist/ml/index.js +6 -0
- package/dist/ml/index.js.map +1 -0
- package/dist/ndarray/index.cjs +501 -0
- package/dist/ndarray/index.cjs.map +1 -0
- package/dist/ndarray/index.d.cts +5 -0
- package/dist/ndarray/index.d.ts +5 -0
- package/dist/ndarray/index.js +4 -0
- package/dist/ndarray/index.js.map +1 -0
- package/dist/nn/index.cjs +142 -0
- package/dist/nn/index.cjs.map +1 -0
- package/dist/nn/index.d.cts +6 -0
- package/dist/nn/index.d.ts +6 -0
- package/dist/nn/index.js +5 -0
- package/dist/nn/index.js.map +1 -0
- package/dist/optim/index.cjs +77 -0
- package/dist/optim/index.cjs.map +1 -0
- package/dist/optim/index.d.cts +4 -0
- package/dist/optim/index.d.ts +4 -0
- package/dist/optim/index.js +4 -0
- package/dist/optim/index.js.map +1 -0
- package/dist/plot/index.cjs +114 -0
- package/dist/plot/index.cjs.map +1 -0
- package/dist/plot/index.d.cts +6 -0
- package/dist/plot/index.d.ts +6 -0
- package/dist/plot/index.js +5 -0
- package/dist/plot/index.js.map +1 -0
- package/dist/preprocess/index.cjs +82 -0
- package/dist/preprocess/index.cjs.map +1 -0
- package/dist/preprocess/index.d.cts +4 -0
- package/dist/preprocess/index.d.ts +4 -0
- package/dist/preprocess/index.js +5 -0
- package/dist/preprocess/index.js.map +1 -0
- package/dist/random/index.cjs +74 -0
- package/dist/random/index.cjs.map +1 -0
- package/dist/random/index.d.cts +3 -0
- package/dist/random/index.d.ts +3 -0
- package/dist/random/index.js +5 -0
- package/dist/random/index.js.map +1 -0
- package/dist/stats/index.cjs +142 -0
- package/dist/stats/index.cjs.map +1 -0
- package/dist/stats/index.d.cts +3 -0
- package/dist/stats/index.d.ts +3 -0
- package/dist/stats/index.js +5 -0
- package/dist/stats/index.js.map +1 -0
- package/dist/tensor-B96jjJLQ.d.cts +205 -0
- package/dist/tensor-B96jjJLQ.d.ts +205 -0
- package/package.json +226 -0
|
@@ -0,0 +1,1898 @@
|
|
|
1
|
+
import { __export, DataValidationError, InvalidParameterError, DeepboxError, NotFittedError, DTypeError, ShapeError, IndexError } from './chunk-4S73VUBD.js';
|
|
2
|
+
|
|
3
|
+
// src/optim/index.ts
|
|
4
|
+
var optim_exports = {};
|
|
5
|
+
__export(optim_exports, {
|
|
6
|
+
AdaDelta: () => AdaDelta,
|
|
7
|
+
Adagrad: () => Adagrad,
|
|
8
|
+
Adam: () => Adam,
|
|
9
|
+
AdamW: () => AdamW,
|
|
10
|
+
CosineAnnealingLR: () => CosineAnnealingLR,
|
|
11
|
+
ExponentialLR: () => ExponentialLR,
|
|
12
|
+
LRScheduler: () => LRScheduler,
|
|
13
|
+
LinearLR: () => LinearLR,
|
|
14
|
+
MultiStepLR: () => MultiStepLR,
|
|
15
|
+
Nadam: () => Nadam,
|
|
16
|
+
OneCycleLR: () => OneCycleLR,
|
|
17
|
+
Optimizer: () => Optimizer,
|
|
18
|
+
RMSprop: () => RMSprop,
|
|
19
|
+
ReduceLROnPlateau: () => ReduceLROnPlateau,
|
|
20
|
+
SGD: () => SGD,
|
|
21
|
+
StepLR: () => StepLR,
|
|
22
|
+
WarmupLR: () => WarmupLR
|
|
23
|
+
});
|
|
24
|
+
|
|
25
|
+
// src/optim/Optimizer.ts
|
|
26
|
+
function isRecord(value) {
|
|
27
|
+
return typeof value === "object" && value !== null;
|
|
28
|
+
}
|
|
29
|
+
function ensureRecord(value, context) {
|
|
30
|
+
if (!isRecord(value)) {
|
|
31
|
+
throw new DataValidationError(`${context} must be an object`);
|
|
32
|
+
}
|
|
33
|
+
return value;
|
|
34
|
+
}
|
|
35
|
+
function isStateRecord(value) {
|
|
36
|
+
return isRecord(value);
|
|
37
|
+
}
|
|
38
|
+
function ensureIntegerArray(value, context) {
|
|
39
|
+
if (!Array.isArray(value)) {
|
|
40
|
+
throw new DataValidationError(`${context} must be an array of integers`);
|
|
41
|
+
}
|
|
42
|
+
const output = [];
|
|
43
|
+
for (const entry of value) {
|
|
44
|
+
if (!Number.isInteger(entry)) {
|
|
45
|
+
throw new DataValidationError(`${context} must contain integers only`);
|
|
46
|
+
}
|
|
47
|
+
output.push(entry);
|
|
48
|
+
}
|
|
49
|
+
return output;
|
|
50
|
+
}
|
|
51
|
+
function isParamGroupArray(params) {
|
|
52
|
+
if (!Array.isArray(params)) return false;
|
|
53
|
+
if (params.length === 0) return true;
|
|
54
|
+
const first = params[0];
|
|
55
|
+
if (!first || typeof first !== "object") return false;
|
|
56
|
+
return "params" in first;
|
|
57
|
+
}
|
|
58
|
+
var Optimizer = class {
|
|
59
|
+
/**
|
|
60
|
+
* Create a new optimizer.
|
|
61
|
+
*
|
|
62
|
+
* Initializes the optimizer with either a simple list of parameters or
|
|
63
|
+
* multiple parameter groups with per-group hyperparameters.
|
|
64
|
+
*
|
|
65
|
+
* @param params - Either an iterable of parameters or array of parameter groups
|
|
66
|
+
* @param defaults - Default hyperparameters applied to all groups
|
|
67
|
+
*/
|
|
68
|
+
constructor(params, defaults) {
|
|
69
|
+
this.defaults = defaults;
|
|
70
|
+
this.paramGroups = [];
|
|
71
|
+
if (!isParamGroupArray(params)) {
|
|
72
|
+
this.paramGroups.push({
|
|
73
|
+
params: Array.from(params),
|
|
74
|
+
// Convert iterable to array for efficient access
|
|
75
|
+
options: { ...defaults }
|
|
76
|
+
// Clone defaults to avoid mutation
|
|
77
|
+
});
|
|
78
|
+
} else {
|
|
79
|
+
for (const group of params) {
|
|
80
|
+
const { params: groupParams, ...groupOptions } = group;
|
|
81
|
+
this.paramGroups.push({
|
|
82
|
+
params: Array.from(groupParams),
|
|
83
|
+
// Convert to array
|
|
84
|
+
options: { ...defaults, ...groupOptions }
|
|
85
|
+
// Merge defaults with group options
|
|
86
|
+
});
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
/**
|
|
91
|
+
* Groups of parameters with their associated hyperparameters.
|
|
92
|
+
* Each group can have different options (e.g., learning rates).
|
|
93
|
+
* Exposed publicly to enable scheduler integrations.
|
|
94
|
+
*/
|
|
95
|
+
paramGroups;
|
|
96
|
+
/**
|
|
97
|
+
* Per-parameter state storage.
|
|
98
|
+
* Maps each parameter to its optimizer-specific state (momentum, adaptive rates, etc.).
|
|
99
|
+
*/
|
|
100
|
+
state = /* @__PURE__ */ new Map();
|
|
101
|
+
/**
|
|
102
|
+
* Zero out the gradients of all optimized parameters.
|
|
103
|
+
*
|
|
104
|
+
* This method should be called at the beginning of each training iteration,
|
|
105
|
+
* before computing new gradients. Without this call, gradients would accumulate
|
|
106
|
+
* across iterations, leading to incorrect updates.
|
|
107
|
+
*
|
|
108
|
+
* **Implementation Note:**
|
|
109
|
+
* For parameters wrapped in GradTensor, this calls zeroGrad() on each parameter,
|
|
110
|
+
* which either sets the gradient to zero or initializes it if not yet created.
|
|
111
|
+
*
|
|
112
|
+
* @example
|
|
113
|
+
* ```ts
|
|
114
|
+
* // Typical training loop
|
|
115
|
+
* optimizer.zeroGrad(); // Clear previous gradients
|
|
116
|
+
* const output = model.forward(input);
|
|
117
|
+
* const loss = criterion(output, target);
|
|
118
|
+
* loss.backward(); // Compute new gradients
|
|
119
|
+
* optimizer.step(); // Update parameters
|
|
120
|
+
* ```
|
|
121
|
+
*/
|
|
122
|
+
zeroGrad() {
|
|
123
|
+
for (const group of this.paramGroups) {
|
|
124
|
+
for (const param of group.params) {
|
|
125
|
+
param.zeroGrad();
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
}
|
|
129
|
+
/**
|
|
130
|
+
* Add a parameter group to the optimizer.
|
|
131
|
+
*
|
|
132
|
+
* This method allows adding new parameters to optimize after the optimizer
|
|
133
|
+
* has been created. This is particularly useful for:
|
|
134
|
+
* - Fine-tuning: adding pre-trained layers with different learning rates
|
|
135
|
+
* - Progressive training: gradually unfreezing layers
|
|
136
|
+
* - Dynamic architectures: adding parameters while the model grows
|
|
137
|
+
*
|
|
138
|
+
* @param paramGroup - Parameter group to add with optional per-group options
|
|
139
|
+
*
|
|
140
|
+
* @example
|
|
141
|
+
* ```ts
|
|
142
|
+
* const optimizer = new SGD(model.backbone.parameters(), { lr: 0.001 });
|
|
143
|
+
* // Later, add classifier with higher learning rate
|
|
144
|
+
* optimizer.addParamGroup({
|
|
145
|
+
* params: model.classifier.parameters(),
|
|
146
|
+
* lr: 0.01
|
|
147
|
+
* });
|
|
148
|
+
* ```
|
|
149
|
+
*/
|
|
150
|
+
addParamGroup(paramGroup) {
|
|
151
|
+
const { params, ...options } = paramGroup;
|
|
152
|
+
this.paramGroups.push({
|
|
153
|
+
params: Array.from(params),
|
|
154
|
+
// Convert iterable to array
|
|
155
|
+
options: { ...this.defaults, ...options }
|
|
156
|
+
// Merge with defaults
|
|
157
|
+
});
|
|
158
|
+
}
|
|
159
|
+
/**
|
|
160
|
+
* Get the current state of the optimizer.
|
|
161
|
+
*
|
|
162
|
+
* Returns a dictionary containing all optimizer state that needs to be
|
|
163
|
+
* saved for checkpointing. This includes per-parameter state (momentum buffers,
|
|
164
|
+
* adaptive learning rates, etc.) and parameter group configurations.
|
|
165
|
+
*
|
|
166
|
+
* **Note:** In a production implementation, parameters would be identified by
|
|
167
|
+
* unique IDs rather than object references for proper serialization.
|
|
168
|
+
*
|
|
169
|
+
* @returns Optimizer state dictionary containing state and parameter groups
|
|
170
|
+
*
|
|
171
|
+
* @example
|
|
172
|
+
* ```ts
|
|
173
|
+
* // Save checkpoint
|
|
174
|
+
* const checkpoint = {
|
|
175
|
+
* model: model.stateDict(),
|
|
176
|
+
* optimizer: optimizer.stateDict(),
|
|
177
|
+
* epoch: currentEpoch
|
|
178
|
+
* };
|
|
179
|
+
* ```
|
|
180
|
+
*/
|
|
181
|
+
stateDict() {
|
|
182
|
+
const paramIdMap = /* @__PURE__ */ new Map();
|
|
183
|
+
const orderedParams = [];
|
|
184
|
+
const getParamId = (param) => {
|
|
185
|
+
const existing = paramIdMap.get(param);
|
|
186
|
+
if (existing !== void 0) return existing;
|
|
187
|
+
const id = orderedParams.length;
|
|
188
|
+
orderedParams.push(param);
|
|
189
|
+
paramIdMap.set(param, id);
|
|
190
|
+
return id;
|
|
191
|
+
};
|
|
192
|
+
return {
|
|
193
|
+
// Serialize per-parameter state
|
|
194
|
+
state: Array.from(this.state.entries()).map(([param, state]) => ({
|
|
195
|
+
paramId: getParamId(param),
|
|
196
|
+
param,
|
|
197
|
+
// Backward-compatible references
|
|
198
|
+
state
|
|
199
|
+
// Optimizer-specific state (momentum, etc.)
|
|
200
|
+
})),
|
|
201
|
+
// Serialize parameter groups and their options
|
|
202
|
+
paramGroups: this.paramGroups.map((group) => ({
|
|
203
|
+
params: group.params,
|
|
204
|
+
// Backward-compatible references
|
|
205
|
+
paramIds: group.params.map((param) => getParamId(param)),
|
|
206
|
+
options: group.options
|
|
207
|
+
// Hyperparameters for this group
|
|
208
|
+
}))
|
|
209
|
+
};
|
|
210
|
+
}
|
|
211
|
+
/**
|
|
212
|
+
* Load optimizer state from a state dictionary.
|
|
213
|
+
*
|
|
214
|
+
* Restores the optimizer to a previously saved state, including all
|
|
215
|
+
* per-parameter state and parameter group configurations. This is essential
|
|
216
|
+
* for resuming training from checkpoints.
|
|
217
|
+
*
|
|
218
|
+
* **Important:** The loaded state must be compatible with the current
|
|
219
|
+
* optimizer configuration (same parameters, same optimizer type).
|
|
220
|
+
*
|
|
221
|
+
* @param stateDict - State dictionary previously returned by stateDict()
|
|
222
|
+
*
|
|
223
|
+
* @example
|
|
224
|
+
* ```ts
|
|
225
|
+
* // Resume from checkpoint
|
|
226
|
+
* const checkpoint = loadCheckpoint('checkpoint.json');
|
|
227
|
+
* model.loadStateDict(checkpoint.model);
|
|
228
|
+
* optimizer.loadStateDict(checkpoint.optimizer);
|
|
229
|
+
* ```
|
|
230
|
+
*/
|
|
231
|
+
loadStateDict(stateDict) {
|
|
232
|
+
const currentParams = this.paramGroups.flatMap((group) => group.params);
|
|
233
|
+
const currentParamCount = currentParams.length;
|
|
234
|
+
const paramLookup = /* @__PURE__ */ new Map();
|
|
235
|
+
for (let i = 0; i < currentParams.length; i++) {
|
|
236
|
+
paramLookup.set(currentParams[i], i);
|
|
237
|
+
}
|
|
238
|
+
if (Object.hasOwn(stateDict, "paramGroups")) {
|
|
239
|
+
const rawGroups = stateDict["paramGroups"];
|
|
240
|
+
if (!Array.isArray(rawGroups)) {
|
|
241
|
+
throw new DataValidationError("paramGroups must be an array");
|
|
242
|
+
}
|
|
243
|
+
const groupsArray = rawGroups;
|
|
244
|
+
if (groupsArray.length === 0) {
|
|
245
|
+
if (this.paramGroups.length !== 0) {
|
|
246
|
+
throw new DataValidationError("paramGroups cannot be empty");
|
|
247
|
+
}
|
|
248
|
+
this.paramGroups = [];
|
|
249
|
+
} else {
|
|
250
|
+
if (groupsArray.length !== this.paramGroups.length) {
|
|
251
|
+
throw new DataValidationError("paramGroups count mismatch");
|
|
252
|
+
}
|
|
253
|
+
const seenParamIds = /* @__PURE__ */ new Set();
|
|
254
|
+
let totalParamCount = 0;
|
|
255
|
+
let sawParamIds = false;
|
|
256
|
+
let sawNoParamIds = false;
|
|
257
|
+
const nextGroups = [];
|
|
258
|
+
groupsArray.forEach((rawGroup, index) => {
|
|
259
|
+
const groupRecord = ensureRecord(rawGroup, `paramGroups[${index}]`);
|
|
260
|
+
const optionsRaw = ensureRecord(groupRecord["options"], `paramGroups[${index}].options`);
|
|
261
|
+
const options = { ...this.defaults };
|
|
262
|
+
const optionsRecord = options;
|
|
263
|
+
const defaultsRecord = { ...this.defaults };
|
|
264
|
+
for (const [key, value] of Object.entries(optionsRaw)) {
|
|
265
|
+
if (Object.hasOwn(defaultsRecord, key)) {
|
|
266
|
+
const defaultVal = defaultsRecord[key];
|
|
267
|
+
const expectedType = typeof defaultVal;
|
|
268
|
+
const actualType = typeof value;
|
|
269
|
+
if (actualType !== expectedType) {
|
|
270
|
+
throw new DataValidationError(
|
|
271
|
+
`Type mismatch for option '${key}' in paramGroups[${index}]: expected ${expectedType}, got ${actualType}`
|
|
272
|
+
);
|
|
273
|
+
}
|
|
274
|
+
optionsRecord[key] = value;
|
|
275
|
+
}
|
|
276
|
+
}
|
|
277
|
+
const paramIdsRaw = groupRecord["paramIds"];
|
|
278
|
+
const paramsRaw = groupRecord["params"];
|
|
279
|
+
let paramIds;
|
|
280
|
+
if (paramIdsRaw !== void 0) {
|
|
281
|
+
paramIds = ensureIntegerArray(paramIdsRaw, `paramGroups[${index}].paramIds`);
|
|
282
|
+
sawParamIds = true;
|
|
283
|
+
} else {
|
|
284
|
+
sawNoParamIds = true;
|
|
285
|
+
}
|
|
286
|
+
let resolvedParams;
|
|
287
|
+
if (paramIds) {
|
|
288
|
+
for (const id of paramIds) {
|
|
289
|
+
if (id < 0 || id >= currentParamCount) {
|
|
290
|
+
throw new DataValidationError(`Invalid paramId ${id} in paramGroups`);
|
|
291
|
+
}
|
|
292
|
+
if (seenParamIds.has(id)) {
|
|
293
|
+
throw new DataValidationError(`Duplicate paramId ${id} in paramGroups`);
|
|
294
|
+
}
|
|
295
|
+
seenParamIds.add(id);
|
|
296
|
+
}
|
|
297
|
+
totalParamCount += paramIds.length;
|
|
298
|
+
resolvedParams = paramIds.map((id) => {
|
|
299
|
+
const param = currentParams[id];
|
|
300
|
+
if (!param) {
|
|
301
|
+
throw new DataValidationError(`Invalid paramId ${id} in paramGroups`);
|
|
302
|
+
}
|
|
303
|
+
return param;
|
|
304
|
+
});
|
|
305
|
+
}
|
|
306
|
+
if (paramsRaw !== void 0) {
|
|
307
|
+
if (!Array.isArray(paramsRaw)) {
|
|
308
|
+
throw new DataValidationError(`paramGroups[${index}].params must be an array`);
|
|
309
|
+
}
|
|
310
|
+
const resolvedFromParams = [];
|
|
311
|
+
let hasUnknown = false;
|
|
312
|
+
for (const paramRef of paramsRaw) {
|
|
313
|
+
const paramIndex = paramLookup.get(paramRef);
|
|
314
|
+
if (paramIndex === void 0) {
|
|
315
|
+
hasUnknown = true;
|
|
316
|
+
continue;
|
|
317
|
+
}
|
|
318
|
+
const param = currentParams[paramIndex];
|
|
319
|
+
if (!param) {
|
|
320
|
+
hasUnknown = true;
|
|
321
|
+
continue;
|
|
322
|
+
}
|
|
323
|
+
resolvedFromParams.push(param);
|
|
324
|
+
}
|
|
325
|
+
if (!hasUnknown) {
|
|
326
|
+
if (paramIds && paramIds.length !== resolvedFromParams.length) {
|
|
327
|
+
throw new DataValidationError("paramIds length does not match params length");
|
|
328
|
+
}
|
|
329
|
+
if (!resolvedParams) {
|
|
330
|
+
resolvedParams = resolvedFromParams;
|
|
331
|
+
}
|
|
332
|
+
}
|
|
333
|
+
}
|
|
334
|
+
if (!resolvedParams) {
|
|
335
|
+
throw new DataValidationError(`paramGroups[${index}] must include params or paramIds`);
|
|
336
|
+
}
|
|
337
|
+
if (paramIds === void 0) {
|
|
338
|
+
nextGroups.push({ params: resolvedParams, options });
|
|
339
|
+
} else {
|
|
340
|
+
nextGroups.push({ params: resolvedParams, options, paramIds });
|
|
341
|
+
}
|
|
342
|
+
});
|
|
343
|
+
if (sawParamIds && sawNoParamIds) {
|
|
344
|
+
throw new DataValidationError("paramIds must be provided for all parameter groups");
|
|
345
|
+
}
|
|
346
|
+
if (sawParamIds && totalParamCount !== currentParamCount) {
|
|
347
|
+
throw new DataValidationError(
|
|
348
|
+
`Parameter count mismatch: expected ${currentParamCount}, got ${totalParamCount}`
|
|
349
|
+
);
|
|
350
|
+
}
|
|
351
|
+
this.paramGroups = nextGroups.map((group) => ({
|
|
352
|
+
params: group.params,
|
|
353
|
+
options: group.options
|
|
354
|
+
}));
|
|
355
|
+
}
|
|
356
|
+
}
|
|
357
|
+
if (Object.hasOwn(stateDict, "state")) {
|
|
358
|
+
const rawState = stateDict["state"];
|
|
359
|
+
if (!Array.isArray(rawState)) {
|
|
360
|
+
throw new DataValidationError("state must be an array");
|
|
361
|
+
}
|
|
362
|
+
const stateArray = rawState;
|
|
363
|
+
this.state.clear();
|
|
364
|
+
stateArray.forEach((rawEntry, index) => {
|
|
365
|
+
const entryRecord = ensureRecord(rawEntry, `state[${index}]`);
|
|
366
|
+
if (!Object.hasOwn(entryRecord, "state")) {
|
|
367
|
+
throw new DataValidationError(`state[${index}].state is required`);
|
|
368
|
+
}
|
|
369
|
+
const entryStateValue = ensureRecord(entryRecord["state"], `state[${index}].state`);
|
|
370
|
+
if (!isStateRecord(entryStateValue)) {
|
|
371
|
+
throw new DataValidationError(`state[${index}].state must be an object`);
|
|
372
|
+
}
|
|
373
|
+
const paramIdRaw = entryRecord["paramId"];
|
|
374
|
+
const paramRaw = entryRecord["param"];
|
|
375
|
+
let resolvedParam;
|
|
376
|
+
if (paramIdRaw !== void 0) {
|
|
377
|
+
if (paramIdRaw === null || typeof paramIdRaw !== "number" || !Number.isInteger(paramIdRaw)) {
|
|
378
|
+
throw new DataValidationError(`Invalid paramId ${String(paramIdRaw)} in state`);
|
|
379
|
+
}
|
|
380
|
+
if (paramIdRaw < 0 || paramIdRaw >= currentParamCount) {
|
|
381
|
+
throw new DataValidationError(`Invalid paramId ${paramIdRaw} in state`);
|
|
382
|
+
}
|
|
383
|
+
const param = currentParams[paramIdRaw];
|
|
384
|
+
if (!param) {
|
|
385
|
+
throw new DataValidationError(`Invalid paramId ${paramIdRaw} in state`);
|
|
386
|
+
}
|
|
387
|
+
if (paramRaw !== void 0) {
|
|
388
|
+
const paramIndex = paramLookup.get(paramRaw);
|
|
389
|
+
if (paramIndex === void 0 || paramIndex !== paramIdRaw) {
|
|
390
|
+
throw new DataValidationError(`paramId ${paramIdRaw} does not match provided param`);
|
|
391
|
+
}
|
|
392
|
+
}
|
|
393
|
+
resolvedParam = param;
|
|
394
|
+
} else {
|
|
395
|
+
if (paramRaw === void 0) {
|
|
396
|
+
throw new DataValidationError("Missing param reference in state entry");
|
|
397
|
+
}
|
|
398
|
+
const paramIndex = paramLookup.get(paramRaw);
|
|
399
|
+
if (paramIndex === void 0) {
|
|
400
|
+
throw new DataValidationError("Unknown param reference in state entry");
|
|
401
|
+
}
|
|
402
|
+
const param = currentParams[paramIndex];
|
|
403
|
+
if (!param) {
|
|
404
|
+
throw new DataValidationError("Unknown param reference in state entry");
|
|
405
|
+
}
|
|
406
|
+
resolvedParam = param;
|
|
407
|
+
}
|
|
408
|
+
if (!resolvedParam) {
|
|
409
|
+
throw new DataValidationError(`Unable to resolve parameter for state[${index}]`);
|
|
410
|
+
}
|
|
411
|
+
if (!this.isState(entryStateValue)) {
|
|
412
|
+
throw new DataValidationError(`state[${index}].state has invalid structure`);
|
|
413
|
+
}
|
|
414
|
+
this.state.set(resolvedParam, entryStateValue);
|
|
415
|
+
});
|
|
416
|
+
}
|
|
417
|
+
}
|
|
418
|
+
};
|
|
419
|
+
|
|
420
|
+
// src/optim/_internal.ts
|
|
421
|
+
function isFloatTypedArray(value) {
|
|
422
|
+
return value instanceof Float32Array || value instanceof Float64Array;
|
|
423
|
+
}
|
|
424
|
+
function safeArrayAccess(array, index, context) {
|
|
425
|
+
if (index < 0 || index >= array.length) {
|
|
426
|
+
throw new IndexError(`Index ${index} out of bounds [0, ${array.length}) in ${context}`, {
|
|
427
|
+
index,
|
|
428
|
+
validRange: [0, array.length - 1]
|
|
429
|
+
});
|
|
430
|
+
}
|
|
431
|
+
const value = array[index];
|
|
432
|
+
if (value === void 0) {
|
|
433
|
+
throw new DeepboxError(`Unexpected undefined at index ${index} in ${context}`);
|
|
434
|
+
}
|
|
435
|
+
return value;
|
|
436
|
+
}
|
|
437
|
+
function assertFiniteNonNegative(name, value) {
|
|
438
|
+
if (!Number.isFinite(value) || value < 0) {
|
|
439
|
+
throw new InvalidParameterError(`Invalid ${name}: ${value}`, name, value);
|
|
440
|
+
}
|
|
441
|
+
}
|
|
442
|
+
function assertFinitePositive(name, value) {
|
|
443
|
+
if (!Number.isFinite(value) || value <= 0) {
|
|
444
|
+
throw new InvalidParameterError(`Invalid ${name}: ${value} (must be > 0)`, name, value);
|
|
445
|
+
}
|
|
446
|
+
}
|
|
447
|
+
function assertFinite(name, value) {
|
|
448
|
+
if (!Number.isFinite(value)) {
|
|
449
|
+
throw new InvalidParameterError(`Invalid ${name}: ${value}`, name, value);
|
|
450
|
+
}
|
|
451
|
+
}
|
|
452
|
+
function assertInRange(name, value, min, max) {
|
|
453
|
+
if (!Number.isFinite(value) || value < min || value >= max) {
|
|
454
|
+
throw new InvalidParameterError(
|
|
455
|
+
`Invalid ${name}: ${value} (must be in range [${min}, ${max}))`,
|
|
456
|
+
name,
|
|
457
|
+
value
|
|
458
|
+
);
|
|
459
|
+
}
|
|
460
|
+
}
|
|
461
|
+
function assertHasGradFloat(param, optimizerName) {
|
|
462
|
+
if (!param.requiresGrad) {
|
|
463
|
+
throw new InvalidParameterError(
|
|
464
|
+
"Cannot optimize a parameter with requiresGrad=false",
|
|
465
|
+
"requiresGrad",
|
|
466
|
+
false
|
|
467
|
+
);
|
|
468
|
+
}
|
|
469
|
+
const g = param.grad;
|
|
470
|
+
if (!g) {
|
|
471
|
+
throw new NotFittedError(
|
|
472
|
+
"Cannot optimize a parameter without a gradient. Did you forget backward()?"
|
|
473
|
+
);
|
|
474
|
+
}
|
|
475
|
+
const paramData = param.tensor.data;
|
|
476
|
+
const gradData = g.data;
|
|
477
|
+
if (!isFloatTypedArray(paramData) || !isFloatTypedArray(gradData)) {
|
|
478
|
+
throw new DTypeError(
|
|
479
|
+
`${optimizerName} optimizer supports float32 and float64 parameters and gradients only`
|
|
480
|
+
);
|
|
481
|
+
}
|
|
482
|
+
if (paramData.constructor !== gradData.constructor) {
|
|
483
|
+
throw new DTypeError(
|
|
484
|
+
`${optimizerName} optimizer requires parameter and gradient dtypes to match`
|
|
485
|
+
);
|
|
486
|
+
}
|
|
487
|
+
if (param.tensor.size !== g.size) {
|
|
488
|
+
throw new ShapeError(
|
|
489
|
+
`Gradient shape must match parameter shape (param: ${param.tensor.size}, grad: ${g.size})`
|
|
490
|
+
);
|
|
491
|
+
}
|
|
492
|
+
return {
|
|
493
|
+
grad: gradData,
|
|
494
|
+
gradOffset: g.offset,
|
|
495
|
+
param: paramData,
|
|
496
|
+
paramOffset: param.tensor.offset
|
|
497
|
+
};
|
|
498
|
+
}
|
|
499
|
+
function assertBufferSize(buffer, expectedSize, bufferName) {
|
|
500
|
+
if (buffer.length !== expectedSize) {
|
|
501
|
+
throw new DeepboxError(
|
|
502
|
+
`State buffer size mismatch for ${bufferName}: expected ${expectedSize}, got ${buffer.length}`
|
|
503
|
+
);
|
|
504
|
+
}
|
|
505
|
+
}
|
|
506
|
+
|
|
507
|
+
// src/optim/optimizers/adadelta.ts
|
|
508
|
+
var AdaDelta = class extends Optimizer {
|
|
509
|
+
_stepCount = 0;
|
|
510
|
+
get stepCount() {
|
|
511
|
+
return this._stepCount;
|
|
512
|
+
}
|
|
513
|
+
constructor(params, options = {}) {
|
|
514
|
+
const defaults = {
|
|
515
|
+
lr: options.lr ?? 1,
|
|
516
|
+
rho: options.rho ?? 0.9,
|
|
517
|
+
eps: options.eps ?? 1e-6,
|
|
518
|
+
weightDecay: options.weightDecay ?? 0
|
|
519
|
+
};
|
|
520
|
+
super(params, defaults);
|
|
521
|
+
assertFiniteNonNegative("learning rate", defaults.lr);
|
|
522
|
+
assertInRange("rho", defaults.rho, 0, 1);
|
|
523
|
+
assertFinitePositive("epsilon", defaults.eps);
|
|
524
|
+
assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
|
|
525
|
+
}
|
|
526
|
+
/**
|
|
527
|
+
* Get the current learning rate.
|
|
528
|
+
*
|
|
529
|
+
* @param groupIdx - Parameter group index (default: 0)
|
|
530
|
+
* @returns Current learning rate
|
|
531
|
+
*/
|
|
532
|
+
getLearningRate(groupIdx = 0) {
|
|
533
|
+
const group = this.paramGroups[groupIdx];
|
|
534
|
+
if (!group) {
|
|
535
|
+
throw new InvalidParameterError(
|
|
536
|
+
`Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
|
|
537
|
+
"groupIdx",
|
|
538
|
+
groupIdx
|
|
539
|
+
);
|
|
540
|
+
}
|
|
541
|
+
return group.options.lr;
|
|
542
|
+
}
|
|
543
|
+
/**
|
|
544
|
+
* Set the learning rate for all parameter groups.
|
|
545
|
+
*
|
|
546
|
+
* @param lr - New learning rate
|
|
547
|
+
*/
|
|
548
|
+
setLearningRate(lr) {
|
|
549
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
550
|
+
for (const group of this.paramGroups) {
|
|
551
|
+
group.options.lr = lr;
|
|
552
|
+
}
|
|
553
|
+
}
|
|
554
|
+
isState(state) {
|
|
555
|
+
return state["squareAvg"] instanceof Float64Array && state["accDelta"] instanceof Float64Array;
|
|
556
|
+
}
|
|
557
|
+
step(closure) {
|
|
558
|
+
let loss;
|
|
559
|
+
if (closure) {
|
|
560
|
+
loss = closure();
|
|
561
|
+
}
|
|
562
|
+
this._stepCount++;
|
|
563
|
+
for (const group of this.paramGroups) {
|
|
564
|
+
const { lr, rho, eps, weightDecay } = group.options;
|
|
565
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
566
|
+
assertInRange("rho", rho, 0, 1);
|
|
567
|
+
assertFinitePositive("epsilon", eps);
|
|
568
|
+
assertFiniteNonNegative("weight_decay value", weightDecay);
|
|
569
|
+
for (const param of group.params) {
|
|
570
|
+
const {
|
|
571
|
+
grad: gradData,
|
|
572
|
+
gradOffset: gOff,
|
|
573
|
+
param: pData,
|
|
574
|
+
paramOffset: pOff
|
|
575
|
+
} = assertHasGradFloat(param, "AdaDelta");
|
|
576
|
+
const size = param.tensor.size;
|
|
577
|
+
let state = this.state.get(param);
|
|
578
|
+
if (!state) {
|
|
579
|
+
state = {
|
|
580
|
+
squareAvg: new Float64Array(size),
|
|
581
|
+
accDelta: new Float64Array(size)
|
|
582
|
+
};
|
|
583
|
+
this.state.set(param, state);
|
|
584
|
+
}
|
|
585
|
+
assertBufferSize(state.squareAvg, size, "AdaDelta squareAvg");
|
|
586
|
+
assertBufferSize(state.accDelta, size, "AdaDelta accDelta");
|
|
587
|
+
for (let i = 0; i < size; i++) {
|
|
588
|
+
const gi0 = safeArrayAccess(gradData, gOff + i, "AdaDelta gradient");
|
|
589
|
+
const pi = safeArrayAccess(pData, pOff + i, "AdaDelta parameter");
|
|
590
|
+
assertFinite("gradient", gi0);
|
|
591
|
+
assertFinite("parameter", pi);
|
|
592
|
+
const gi = weightDecay !== 0 ? gi0 + weightDecay * pi : gi0;
|
|
593
|
+
const sq = safeArrayAccess(state.squareAvg, i, "AdaDelta squareAvg");
|
|
594
|
+
const sqNew = rho * sq + (1 - rho) * gi * gi;
|
|
595
|
+
state.squareAvg[i] = sqNew;
|
|
596
|
+
const std = Math.sqrt(sqNew + eps);
|
|
597
|
+
const accD = safeArrayAccess(state.accDelta, i, "AdaDelta accDelta");
|
|
598
|
+
const rmsUpdate = Math.sqrt(accD + eps);
|
|
599
|
+
const delta = rmsUpdate / std * gi;
|
|
600
|
+
state.accDelta[i] = rho * accD + (1 - rho) * delta * delta;
|
|
601
|
+
pData[pOff + i] = pi - lr * delta;
|
|
602
|
+
}
|
|
603
|
+
}
|
|
604
|
+
}
|
|
605
|
+
return loss;
|
|
606
|
+
}
|
|
607
|
+
};
|
|
608
|
+
|
|
609
|
+
// src/optim/optimizers/adagrad.ts
|
|
610
|
+
var Adagrad = class extends Optimizer {
|
|
611
|
+
_stepCount = 0;
|
|
612
|
+
get stepCount() {
|
|
613
|
+
return this._stepCount;
|
|
614
|
+
}
|
|
615
|
+
constructor(params, options = {}) {
|
|
616
|
+
const defaults = {
|
|
617
|
+
lr: options.lr ?? 0.01,
|
|
618
|
+
eps: options.eps ?? 1e-10,
|
|
619
|
+
weightDecay: options.weightDecay ?? 0,
|
|
620
|
+
lrDecay: options.lrDecay ?? 0
|
|
621
|
+
};
|
|
622
|
+
super(params, defaults);
|
|
623
|
+
assertFiniteNonNegative("learning rate", defaults.lr);
|
|
624
|
+
assertFinitePositive("epsilon", defaults.eps);
|
|
625
|
+
assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
|
|
626
|
+
assertFiniteNonNegative("lr_decay", defaults.lrDecay);
|
|
627
|
+
}
|
|
628
|
+
/**
|
|
629
|
+
* Get the current learning rate.
|
|
630
|
+
*
|
|
631
|
+
* @param groupIdx - Parameter group index (default: 0)
|
|
632
|
+
* @returns Current learning rate
|
|
633
|
+
*/
|
|
634
|
+
getLearningRate(groupIdx = 0) {
|
|
635
|
+
const group = this.paramGroups[groupIdx];
|
|
636
|
+
if (!group) {
|
|
637
|
+
throw new InvalidParameterError(
|
|
638
|
+
`Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
|
|
639
|
+
"groupIdx",
|
|
640
|
+
groupIdx
|
|
641
|
+
);
|
|
642
|
+
}
|
|
643
|
+
return group.options.lr;
|
|
644
|
+
}
|
|
645
|
+
/**
|
|
646
|
+
* Set the learning rate for all parameter groups.
|
|
647
|
+
*
|
|
648
|
+
* @param lr - New learning rate
|
|
649
|
+
*/
|
|
650
|
+
setLearningRate(lr) {
|
|
651
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
652
|
+
for (const group of this.paramGroups) {
|
|
653
|
+
group.options.lr = lr;
|
|
654
|
+
}
|
|
655
|
+
}
|
|
656
|
+
isState(state) {
|
|
657
|
+
return typeof state["step"] === "number" && state["sum"] instanceof Float64Array;
|
|
658
|
+
}
|
|
659
|
+
step(closure) {
|
|
660
|
+
let loss;
|
|
661
|
+
if (closure) {
|
|
662
|
+
loss = closure();
|
|
663
|
+
}
|
|
664
|
+
this._stepCount++;
|
|
665
|
+
for (const group of this.paramGroups) {
|
|
666
|
+
const { lr, eps, weightDecay, lrDecay } = group.options;
|
|
667
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
668
|
+
assertFinitePositive("epsilon", eps);
|
|
669
|
+
assertFiniteNonNegative("weight_decay value", weightDecay);
|
|
670
|
+
assertFiniteNonNegative("lr_decay", lrDecay);
|
|
671
|
+
for (const param of group.params) {
|
|
672
|
+
const {
|
|
673
|
+
grad: gradData,
|
|
674
|
+
gradOffset: gOff,
|
|
675
|
+
param: pData,
|
|
676
|
+
paramOffset: pOff
|
|
677
|
+
} = assertHasGradFloat(param, "Adagrad");
|
|
678
|
+
const size = param.tensor.size;
|
|
679
|
+
const existing = this.state.get(param);
|
|
680
|
+
const state = existing ?? (() => {
|
|
681
|
+
const next = {
|
|
682
|
+
step: 0,
|
|
683
|
+
sum: new Float64Array(size)
|
|
684
|
+
};
|
|
685
|
+
this.state.set(param, next);
|
|
686
|
+
return next;
|
|
687
|
+
})();
|
|
688
|
+
assertBufferSize(state.sum, size, "Adagrad sum");
|
|
689
|
+
state.step += 1;
|
|
690
|
+
const clr = lr / (1 + (state.step - 1) * lrDecay);
|
|
691
|
+
for (let i = 0; i < size; i++) {
|
|
692
|
+
const gi0 = safeArrayAccess(gradData, gOff + i, "Adagrad gradient");
|
|
693
|
+
const pi = safeArrayAccess(pData, pOff + i, "Adagrad parameter");
|
|
694
|
+
assertFinite("gradient", gi0);
|
|
695
|
+
assertFinite("parameter", pi);
|
|
696
|
+
const gi = weightDecay !== 0 ? gi0 + weightDecay * pi : gi0;
|
|
697
|
+
const sumVal = safeArrayAccess(state.sum, i, "Adagrad sum");
|
|
698
|
+
const sumNew = sumVal + gi * gi;
|
|
699
|
+
state.sum[i] = sumNew;
|
|
700
|
+
const std = Math.sqrt(sumNew) + eps;
|
|
701
|
+
pData[pOff + i] = pi - clr * (gi / std);
|
|
702
|
+
}
|
|
703
|
+
}
|
|
704
|
+
}
|
|
705
|
+
return loss;
|
|
706
|
+
}
|
|
707
|
+
};
|
|
708
|
+
|
|
709
|
+
// src/optim/optimizers/adam.ts
|
|
710
|
+
var Adam = class extends Optimizer {
|
|
711
|
+
_stepCount = 0;
|
|
712
|
+
get stepCount() {
|
|
713
|
+
return this._stepCount;
|
|
714
|
+
}
|
|
715
|
+
constructor(params, options = {}) {
|
|
716
|
+
const defaults = {
|
|
717
|
+
lr: options.lr ?? 1e-3,
|
|
718
|
+
beta1: options.beta1 ?? 0.9,
|
|
719
|
+
beta2: options.beta2 ?? 0.999,
|
|
720
|
+
eps: options.eps ?? 1e-8,
|
|
721
|
+
weightDecay: options.weightDecay ?? 0,
|
|
722
|
+
amsgrad: options.amsgrad ?? false
|
|
723
|
+
};
|
|
724
|
+
super(params, defaults);
|
|
725
|
+
assertFiniteNonNegative("learning rate", defaults.lr);
|
|
726
|
+
assertInRange("beta1", defaults.beta1, 0, 1);
|
|
727
|
+
assertInRange("beta2", defaults.beta2, 0, 1);
|
|
728
|
+
assertFinitePositive("epsilon", defaults.eps);
|
|
729
|
+
assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
|
|
730
|
+
}
|
|
731
|
+
/**
|
|
732
|
+
* Get the current learning rate.
|
|
733
|
+
*
|
|
734
|
+
* @param groupIdx - Parameter group index (default: 0)
|
|
735
|
+
* @returns Current learning rate
|
|
736
|
+
*/
|
|
737
|
+
getLearningRate(groupIdx = 0) {
|
|
738
|
+
const group = this.paramGroups[groupIdx];
|
|
739
|
+
if (!group) {
|
|
740
|
+
throw new InvalidParameterError(
|
|
741
|
+
`Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
|
|
742
|
+
"groupIdx",
|
|
743
|
+
groupIdx
|
|
744
|
+
);
|
|
745
|
+
}
|
|
746
|
+
return group.options.lr;
|
|
747
|
+
}
|
|
748
|
+
/**
|
|
749
|
+
* Set the learning rate for all parameter groups.
|
|
750
|
+
*
|
|
751
|
+
* @param lr - New learning rate
|
|
752
|
+
*/
|
|
753
|
+
setLearningRate(lr) {
|
|
754
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
755
|
+
for (const group of this.paramGroups) {
|
|
756
|
+
group.options.lr = lr;
|
|
757
|
+
}
|
|
758
|
+
}
|
|
759
|
+
isState(state) {
|
|
760
|
+
const hasRequired = typeof state["step"] === "number" && state["expAvg"] instanceof Float64Array && state["expAvgSq"] instanceof Float64Array;
|
|
761
|
+
if (!hasRequired) return false;
|
|
762
|
+
if (state["maxExpAvgSq"] !== void 0 && !(state["maxExpAvgSq"] instanceof Float64Array)) {
|
|
763
|
+
return false;
|
|
764
|
+
}
|
|
765
|
+
return true;
|
|
766
|
+
}
|
|
767
|
+
step(closure) {
|
|
768
|
+
let loss;
|
|
769
|
+
if (closure) {
|
|
770
|
+
loss = closure();
|
|
771
|
+
}
|
|
772
|
+
this._stepCount++;
|
|
773
|
+
for (const group of this.paramGroups) {
|
|
774
|
+
const { lr, beta1, beta2, eps, weightDecay, amsgrad } = group.options;
|
|
775
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
776
|
+
assertInRange("beta1", beta1, 0, 1);
|
|
777
|
+
assertInRange("beta2", beta2, 0, 1);
|
|
778
|
+
assertFinitePositive("epsilon", eps);
|
|
779
|
+
assertFiniteNonNegative("weight_decay value", weightDecay);
|
|
780
|
+
for (const param of group.params) {
|
|
781
|
+
const {
|
|
782
|
+
grad: gradData,
|
|
783
|
+
gradOffset,
|
|
784
|
+
param: paramData,
|
|
785
|
+
paramOffset
|
|
786
|
+
} = assertHasGradFloat(param, "Adam");
|
|
787
|
+
const size = param.tensor.size;
|
|
788
|
+
const existing = this.state.get(param);
|
|
789
|
+
const state = existing ?? (() => {
|
|
790
|
+
const next = {
|
|
791
|
+
step: 0,
|
|
792
|
+
expAvg: new Float64Array(size),
|
|
793
|
+
expAvgSq: new Float64Array(size),
|
|
794
|
+
...amsgrad ? { maxExpAvgSq: new Float64Array(size) } : {}
|
|
795
|
+
};
|
|
796
|
+
this.state.set(param, next);
|
|
797
|
+
return next;
|
|
798
|
+
})();
|
|
799
|
+
assertBufferSize(state.expAvg, size, "Adam expAvg");
|
|
800
|
+
assertBufferSize(state.expAvgSq, size, "Adam expAvgSq");
|
|
801
|
+
if (amsgrad && state.maxExpAvgSq) {
|
|
802
|
+
assertBufferSize(state.maxExpAvgSq, size, "Adam maxExpAvgSq");
|
|
803
|
+
}
|
|
804
|
+
state.step += 1;
|
|
805
|
+
const biasCorrection1 = 1 - beta1 ** state.step;
|
|
806
|
+
const biasCorrection2 = 1 - beta2 ** state.step;
|
|
807
|
+
const stepSize = lr / biasCorrection1;
|
|
808
|
+
for (let i = 0; i < size; i++) {
|
|
809
|
+
const gi0 = safeArrayAccess(gradData, gradOffset + i, "Adam gradient");
|
|
810
|
+
const pi = safeArrayAccess(paramData, paramOffset + i, "Adam parameter");
|
|
811
|
+
assertFinite("gradient", gi0);
|
|
812
|
+
assertFinite("parameter", pi);
|
|
813
|
+
const gi = weightDecay !== 0 ? gi0 + weightDecay * pi : gi0;
|
|
814
|
+
const m = safeArrayAccess(state.expAvg, i, "Adam expAvg");
|
|
815
|
+
const v = safeArrayAccess(state.expAvgSq, i, "Adam expAvgSq");
|
|
816
|
+
const mNew = beta1 * m + (1 - beta1) * gi;
|
|
817
|
+
const vNew = beta2 * v + (1 - beta2) * gi * gi;
|
|
818
|
+
state.expAvg[i] = mNew;
|
|
819
|
+
state.expAvgSq[i] = vNew;
|
|
820
|
+
let denomSq = vNew;
|
|
821
|
+
if (amsgrad) {
|
|
822
|
+
const maxBuf = state.maxExpAvgSq;
|
|
823
|
+
if (!maxBuf) {
|
|
824
|
+
throw new DeepboxError("Internal error: AMSGrad enabled but maxExpAvgSq is missing");
|
|
825
|
+
}
|
|
826
|
+
const maxV = Math.max(safeArrayAccess(maxBuf, i, "Adam maxExpAvgSq"), vNew);
|
|
827
|
+
maxBuf[i] = maxV;
|
|
828
|
+
denomSq = maxV;
|
|
829
|
+
}
|
|
830
|
+
const denom = Math.sqrt(denomSq / biasCorrection2) + eps;
|
|
831
|
+
paramData[paramOffset + i] = pi - stepSize * (mNew / denom);
|
|
832
|
+
}
|
|
833
|
+
}
|
|
834
|
+
}
|
|
835
|
+
return loss;
|
|
836
|
+
}
|
|
837
|
+
};
|
|
838
|
+
|
|
839
|
+
// src/optim/optimizers/adamw.ts
|
|
840
|
+
var AdamW = class extends Optimizer {
|
|
841
|
+
/** Internal counter tracking total number of optimization steps */
|
|
842
|
+
_stepCount = 0;
|
|
843
|
+
/**
|
|
844
|
+
* Get the total number of optimization steps performed.
|
|
845
|
+
*
|
|
846
|
+
* @returns Number of steps taken
|
|
847
|
+
*/
|
|
848
|
+
get stepCount() {
|
|
849
|
+
return this._stepCount;
|
|
850
|
+
}
|
|
851
|
+
/**
|
|
852
|
+
* Create a new AdamW optimizer.
|
|
853
|
+
*
|
|
854
|
+
* @param params - Iterable of parameters or parameter groups to optimize
|
|
855
|
+
* @param options - Optimization options
|
|
856
|
+
* @param options.lr - Learning rate (default: 0.001)
|
|
857
|
+
* @param options.beta1 - First moment decay rate (default: 0.9)
|
|
858
|
+
* @param options.beta2 - Second moment decay rate (default: 0.999)
|
|
859
|
+
* @param options.eps - Numerical stability constant (default: 1e-8)
|
|
860
|
+
* @param options.weightDecay - Weight decay coefficient (default: 0.01)
|
|
861
|
+
* @param options.amsgrad - Enable AMSGrad variant (default: false)
|
|
862
|
+
* @throws {InvalidParameterError} If a parameter is invalid
|
|
863
|
+
*/
|
|
864
|
+
constructor(params, options = {}) {
|
|
865
|
+
const defaults = {
|
|
866
|
+
lr: options.lr ?? 1e-3,
|
|
867
|
+
beta1: options.beta1 ?? 0.9,
|
|
868
|
+
beta2: options.beta2 ?? 0.999,
|
|
869
|
+
eps: options.eps ?? 1e-8,
|
|
870
|
+
weightDecay: options.weightDecay ?? 0.01,
|
|
871
|
+
// Higher default than Adam
|
|
872
|
+
amsgrad: options.amsgrad ?? false
|
|
873
|
+
};
|
|
874
|
+
super(params, defaults);
|
|
875
|
+
assertFiniteNonNegative("learning rate", defaults.lr);
|
|
876
|
+
assertInRange("beta1", defaults.beta1, 0, 1);
|
|
877
|
+
assertInRange("beta2", defaults.beta2, 0, 1);
|
|
878
|
+
assertFinitePositive("epsilon", defaults.eps);
|
|
879
|
+
assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
|
|
880
|
+
}
|
|
881
|
+
/**
|
|
882
|
+
* Get the current learning rate.
|
|
883
|
+
*
|
|
884
|
+
* @param groupIdx - Parameter group index (default: 0)
|
|
885
|
+
* @returns Current learning rate
|
|
886
|
+
*/
|
|
887
|
+
getLearningRate(groupIdx = 0) {
|
|
888
|
+
const group = this.paramGroups[groupIdx];
|
|
889
|
+
if (!group) {
|
|
890
|
+
throw new InvalidParameterError(
|
|
891
|
+
`Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
|
|
892
|
+
"groupIdx",
|
|
893
|
+
groupIdx
|
|
894
|
+
);
|
|
895
|
+
}
|
|
896
|
+
return group.options.lr;
|
|
897
|
+
}
|
|
898
|
+
/**
|
|
899
|
+
* Set the learning rate for all parameter groups.
|
|
900
|
+
*
|
|
901
|
+
* @param lr - New learning rate
|
|
902
|
+
*/
|
|
903
|
+
setLearningRate(lr) {
|
|
904
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
905
|
+
for (const group of this.paramGroups) {
|
|
906
|
+
group.options.lr = lr;
|
|
907
|
+
}
|
|
908
|
+
}
|
|
909
|
+
/**
|
|
910
|
+
* Perform a single optimization step (parameter update).
|
|
911
|
+
*
|
|
912
|
+
* Implements the AdamW update rule with decoupled weight decay.
|
|
913
|
+
*
|
|
914
|
+
* @param closure - Optional closure that reevaluates the model and returns the loss
|
|
915
|
+
* @returns Loss value if closure is provided, undefined otherwise
|
|
916
|
+
*/
|
|
917
|
+
isState(state) {
|
|
918
|
+
const hasRequired = typeof state["step"] === "number" && state["expAvg"] instanceof Float64Array && state["expAvgSq"] instanceof Float64Array;
|
|
919
|
+
if (!hasRequired) return false;
|
|
920
|
+
if (state["maxExpAvgSq"] !== void 0 && !(state["maxExpAvgSq"] instanceof Float64Array)) {
|
|
921
|
+
return false;
|
|
922
|
+
}
|
|
923
|
+
return true;
|
|
924
|
+
}
|
|
925
|
+
step(closure) {
|
|
926
|
+
let loss;
|
|
927
|
+
if (closure) {
|
|
928
|
+
loss = closure();
|
|
929
|
+
}
|
|
930
|
+
this._stepCount++;
|
|
931
|
+
for (const group of this.paramGroups) {
|
|
932
|
+
const { lr, beta1, beta2, eps, weightDecay, amsgrad } = group.options;
|
|
933
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
934
|
+
assertInRange("beta1", beta1, 0, 1);
|
|
935
|
+
assertInRange("beta2", beta2, 0, 1);
|
|
936
|
+
assertFinitePositive("epsilon", eps);
|
|
937
|
+
assertFiniteNonNegative("weight_decay value", weightDecay);
|
|
938
|
+
for (const param of group.params) {
|
|
939
|
+
const {
|
|
940
|
+
grad,
|
|
941
|
+
gradOffset,
|
|
942
|
+
param: pData,
|
|
943
|
+
paramOffset: pOff
|
|
944
|
+
} = assertHasGradFloat(param, "AdamW");
|
|
945
|
+
const size = param.tensor.size;
|
|
946
|
+
const existing = this.state.get(param);
|
|
947
|
+
const state = existing ?? (() => {
|
|
948
|
+
const next = {
|
|
949
|
+
step: 0,
|
|
950
|
+
expAvg: new Float64Array(size),
|
|
951
|
+
// First moment
|
|
952
|
+
expAvgSq: new Float64Array(size),
|
|
953
|
+
// Second moment
|
|
954
|
+
...amsgrad ? { maxExpAvgSq: new Float64Array(size) } : {}
|
|
955
|
+
// AMSGrad buffer
|
|
956
|
+
};
|
|
957
|
+
this.state.set(param, next);
|
|
958
|
+
return next;
|
|
959
|
+
})();
|
|
960
|
+
assertBufferSize(state.expAvg, size, "AdamW expAvg");
|
|
961
|
+
assertBufferSize(state.expAvgSq, size, "AdamW expAvgSq");
|
|
962
|
+
if (amsgrad && state.maxExpAvgSq) {
|
|
963
|
+
assertBufferSize(state.maxExpAvgSq, size, "AdamW maxExpAvgSq");
|
|
964
|
+
}
|
|
965
|
+
state.step += 1;
|
|
966
|
+
const biasCorrection1 = 1 - beta1 ** state.step;
|
|
967
|
+
const biasCorrection2 = 1 - beta2 ** state.step;
|
|
968
|
+
const stepSize = lr / biasCorrection1;
|
|
969
|
+
for (let i = 0; i < size; i++) {
|
|
970
|
+
const gi = safeArrayAccess(grad, gradOffset + i, "AdamW gradient");
|
|
971
|
+
const pi = safeArrayAccess(pData, pOff + i, "AdamW parameter");
|
|
972
|
+
assertFinite("gradient", gi);
|
|
973
|
+
assertFinite("parameter", pi);
|
|
974
|
+
const m = safeArrayAccess(state.expAvg, i, "AdamW expAvg");
|
|
975
|
+
const v = safeArrayAccess(state.expAvgSq, i, "AdamW expAvgSq");
|
|
976
|
+
const mNew = beta1 * m + (1 - beta1) * gi;
|
|
977
|
+
const vNew = beta2 * v + (1 - beta2) * gi * gi;
|
|
978
|
+
state.expAvg[i] = mNew;
|
|
979
|
+
state.expAvgSq[i] = vNew;
|
|
980
|
+
let denomSq = vNew;
|
|
981
|
+
if (amsgrad) {
|
|
982
|
+
const maxBuf = state.maxExpAvgSq;
|
|
983
|
+
if (!maxBuf) {
|
|
984
|
+
throw new DeepboxError("Internal error: AMSGrad enabled but maxExpAvgSq is missing");
|
|
985
|
+
}
|
|
986
|
+
const maxV = Math.max(safeArrayAccess(maxBuf, i, "AdamW maxExpAvgSq"), vNew);
|
|
987
|
+
maxBuf[i] = maxV;
|
|
988
|
+
denomSq = maxV;
|
|
989
|
+
}
|
|
990
|
+
const denom = Math.sqrt(denomSq / biasCorrection2) + eps;
|
|
991
|
+
pData[pOff + i] = pi - stepSize * (mNew / denom) - lr * weightDecay * pi;
|
|
992
|
+
}
|
|
993
|
+
}
|
|
994
|
+
}
|
|
995
|
+
return loss;
|
|
996
|
+
}
|
|
997
|
+
};
|
|
998
|
+
|
|
999
|
+
// src/optim/optimizers/nadam.ts
|
|
1000
|
+
var Nadam = class extends Optimizer {
|
|
1001
|
+
_stepCount = 0;
|
|
1002
|
+
get stepCount() {
|
|
1003
|
+
return this._stepCount;
|
|
1004
|
+
}
|
|
1005
|
+
constructor(params, options = {}) {
|
|
1006
|
+
const defaults = {
|
|
1007
|
+
lr: options.lr ?? 2e-3,
|
|
1008
|
+
beta1: options.beta1 ?? 0.9,
|
|
1009
|
+
beta2: options.beta2 ?? 0.999,
|
|
1010
|
+
eps: options.eps ?? 1e-8,
|
|
1011
|
+
weightDecay: options.weightDecay ?? 0,
|
|
1012
|
+
momentumDecay: options.momentumDecay ?? 4e-3
|
|
1013
|
+
};
|
|
1014
|
+
super(params, defaults);
|
|
1015
|
+
assertFiniteNonNegative("learning rate", defaults.lr);
|
|
1016
|
+
assertInRange("beta1", defaults.beta1, 0, 1);
|
|
1017
|
+
assertInRange("beta2", defaults.beta2, 0, 1);
|
|
1018
|
+
assertFinitePositive("epsilon", defaults.eps);
|
|
1019
|
+
assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
|
|
1020
|
+
assertFiniteNonNegative("momentum_decay", defaults.momentumDecay);
|
|
1021
|
+
}
|
|
1022
|
+
/**
|
|
1023
|
+
* Get the current learning rate.
|
|
1024
|
+
*
|
|
1025
|
+
* @param groupIdx - Parameter group index (default: 0)
|
|
1026
|
+
* @returns Current learning rate
|
|
1027
|
+
*/
|
|
1028
|
+
getLearningRate(groupIdx = 0) {
|
|
1029
|
+
const group = this.paramGroups[groupIdx];
|
|
1030
|
+
if (!group) {
|
|
1031
|
+
throw new InvalidParameterError(
|
|
1032
|
+
`Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
|
|
1033
|
+
"groupIdx",
|
|
1034
|
+
groupIdx
|
|
1035
|
+
);
|
|
1036
|
+
}
|
|
1037
|
+
return group.options.lr;
|
|
1038
|
+
}
|
|
1039
|
+
/**
|
|
1040
|
+
* Set the learning rate for all parameter groups.
|
|
1041
|
+
*
|
|
1042
|
+
* @param lr - New learning rate
|
|
1043
|
+
*/
|
|
1044
|
+
setLearningRate(lr) {
|
|
1045
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
1046
|
+
for (const group of this.paramGroups) {
|
|
1047
|
+
group.options.lr = lr;
|
|
1048
|
+
}
|
|
1049
|
+
}
|
|
1050
|
+
isState(state) {
|
|
1051
|
+
return typeof state["step"] === "number" && state["expAvg"] instanceof Float64Array && state["expAvgSq"] instanceof Float64Array && typeof state["muProduct"] === "number";
|
|
1052
|
+
}
|
|
1053
|
+
step(closure) {
|
|
1054
|
+
let loss;
|
|
1055
|
+
if (closure) {
|
|
1056
|
+
loss = closure();
|
|
1057
|
+
}
|
|
1058
|
+
this._stepCount++;
|
|
1059
|
+
for (const group of this.paramGroups) {
|
|
1060
|
+
const { lr, beta1, beta2, eps, weightDecay, momentumDecay } = group.options;
|
|
1061
|
+
assertFiniteNonNegative("learning rate", lr);
|
|
1062
|
+
assertInRange("beta1", beta1, 0, 1);
|
|
1063
|
+
assertInRange("beta2", beta2, 0, 1);
|
|
1064
|
+
assertFinitePositive("epsilon", eps);
|
|
1065
|
+
assertFiniteNonNegative("weight_decay value", weightDecay);
|
|
1066
|
+
assertFiniteNonNegative("momentum_decay", momentumDecay);
|
|
1067
|
+
for (const param of group.params) {
|
|
1068
|
+
const {
|
|
1069
|
+
grad: gradData,
|
|
1070
|
+
gradOffset: gOff,
|
|
1071
|
+
param: pData,
|
|
1072
|
+
paramOffset: pOff
|
|
1073
|
+
} = assertHasGradFloat(param, "Nadam");
|
|
1074
|
+
const size = param.tensor.size;
|
|
1075
|
+
let state = this.state.get(param);
|
|
1076
|
+
if (!state) {
|
|
1077
|
+
state = {
|
|
1078
|
+
step: 0,
|
|
1079
|
+
expAvg: new Float64Array(size),
|
|
1080
|
+
expAvgSq: new Float64Array(size),
|
|
1081
|
+
muProduct: 1
|
|
1082
|
+
};
|
|
1083
|
+
this.state.set(param, state);
|
|
1084
|
+
}
|
|
1085
|
+
assertBufferSize(state.expAvg, size, "Nadam expAvg");
|
|
1086
|
+
assertBufferSize(state.expAvgSq, size, "Nadam expAvgSq");
|
|
1087
|
+
state.step++;
|
|
1088
|
+
const t = state.step;
|
|
1089
|
+
const biasCorrection2 = 1 - beta2 ** t;
|
|
1090
|
+
const mu = beta1 * (1 - 0.5 * 0.96 ** (t * momentumDecay));
|
|
1091
|
+
const muNext = beta1 * (1 - 0.5 * 0.96 ** ((t + 1) * momentumDecay));
|
|
1092
|
+
const muProduct = state.muProduct * mu;
|
|
1093
|
+
const muProductNext = muProduct * muNext;
|
|
1094
|
+
state.muProduct = muProduct;
|
|
1095
|
+
for (let i = 0; i < size; i++) {
|
|
1096
|
+
const gi0 = safeArrayAccess(gradData, gOff + i, "Nadam gradient");
|
|
1097
|
+
const pi = safeArrayAccess(pData, pOff + i, "Nadam parameter");
|
|
1098
|
+
assertFinite("gradient", gi0);
|
|
1099
|
+
assertFinite("parameter", pi);
|
|
1100
|
+
const gi = weightDecay !== 0 ? gi0 + weightDecay * pi : gi0;
|
|
1101
|
+
const m = safeArrayAccess(state.expAvg, i, "Nadam expAvg");
|
|
1102
|
+
const mNew = beta1 * m + (1 - beta1) * gi;
|
|
1103
|
+
state.expAvg[i] = mNew;
|
|
1104
|
+
const v = safeArrayAccess(state.expAvgSq, i, "Nadam expAvgSq");
|
|
1105
|
+
const vNew = beta2 * v + (1 - beta2) * gi * gi;
|
|
1106
|
+
state.expAvgSq[i] = vNew;
|
|
1107
|
+
const denom = Math.sqrt(vNew / biasCorrection2) + eps;
|
|
1108
|
+
const mHatNext = mNew / (1 - muProductNext);
|
|
1109
|
+
const gHat = gi / (1 - muProduct);
|
|
1110
|
+
const mNesterov = muNext * mHatNext + (1 - mu) * gHat;
|
|
1111
|
+
pData[pOff + i] = pi - lr * mNesterov / denom;
|
|
1112
|
+
}
|
|
1113
|
+
}
|
|
1114
|
+
}
|
|
1115
|
+
return loss;
|
|
1116
|
+
}
|
|
1117
|
+
};
|
|
1118
|
+
|
|
1119
|
+
// src/optim/optimizers/rmsprop.ts
var RMSprop = class extends Optimizer {
  /** Internal counter tracking total number of optimization steps */
  _stepCount = 0;
  /**
   * Get the total number of optimization steps performed.
   *
   * @returns Number of steps taken
   */
  get stepCount() {
    return this._stepCount;
  }
  /**
   * Create a new RMSprop optimizer.
   *
   * @param params - Iterable of parameters or parameter groups to optimize
   * @param options - Optimization options
   * @param options.lr - Learning rate (default: 0.01)
   * @param options.alpha - Smoothing constant (default: 0.99)
   * @param options.eps - Numerical stability constant (default: 1e-8)
   * @param options.weightDecay - Weight decay coefficient (default: 0)
   * @param options.momentum - Momentum factor (default: 0)
   * @param options.centered - Use centered variant (default: false)
   * @throws {InvalidParameterError} If a parameter is invalid
   */
  constructor(params, options = {}) {
    const defaults = {
      lr: options.lr ?? 0.01,
      alpha: options.alpha ?? 0.99,
      eps: options.eps ?? 1e-8,
      weightDecay: options.weightDecay ?? 0,
      momentum: options.momentum ?? 0,
      centered: options.centered ?? false
    };
    super(params, defaults);
    assertFiniteNonNegative("learning rate", defaults.lr);
    if (!Number.isFinite(defaults.alpha) || defaults.alpha < 0 || defaults.alpha > 1) {
      throw new InvalidParameterError(
        `Invalid alpha: ${defaults.alpha} (must be in range [0, 1])`,
        "alpha",
        defaults.alpha
      );
    }
    assertFinitePositive("epsilon", defaults.eps);
    assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
    assertFiniteNonNegative("momentum value", defaults.momentum);
  }
  /**
   * Get the current learning rate.
   *
   * @param groupIdx - Parameter group index (default: 0)
   * @returns Current learning rate
   */
  getLearningRate(groupIdx = 0) {
    const group = this.paramGroups[groupIdx];
    if (!group) {
      throw new InvalidParameterError(
        `Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
        "groupIdx",
        groupIdx
      );
    }
    return group.options.lr;
  }
  /**
   * Set the learning rate for all parameter groups.
   *
   * @param lr - New learning rate
   */
  setLearningRate(lr) {
    assertFiniteNonNegative("learning rate", lr);
    for (const group of this.paramGroups) {
      group.options.lr = lr;
    }
  }
  isState(state) {
    if (!(state["squareAvg"] instanceof Float64Array)) return false;
    if (state["momentumBuffer"] !== void 0 && !(state["momentumBuffer"] instanceof Float64Array)) {
      return false;
    }
    if (state["gradAvg"] !== void 0 && !(state["gradAvg"] instanceof Float64Array)) {
      return false;
    }
    return true;
  }
  step(closure) {
    let loss;
    if (closure) {
      loss = closure();
    }
    this._stepCount++;
    for (const group of this.paramGroups) {
      const { lr, alpha, eps, weightDecay, momentum, centered } = group.options;
      assertFiniteNonNegative("learning rate", lr);
      if (!Number.isFinite(alpha) || alpha < 0 || alpha > 1) {
        throw new InvalidParameterError(
          `Invalid alpha: ${alpha} (must be in range [0, 1])`,
          "alpha",
          alpha
        );
      }
      assertFinitePositive("epsilon", eps);
      assertFiniteNonNegative("weight_decay value", weightDecay);
      assertFiniteNonNegative("momentum value", momentum);
      for (const param of group.params) {
        const {
          grad: gradData,
          gradOffset: gOff,
          param: pData,
          paramOffset: pOff
        } = assertHasGradFloat(param, "RMSprop");
        const size = param.tensor.size;
        let state = this.state.get(param);
        if (!state) {
          state = {
            squareAvg: new Float64Array(size)
          };
          this.state.set(param, state);
        }
        if (momentum > 0 && !state.momentumBuffer) {
          state.momentumBuffer = new Float64Array(size);
        }
        if (centered && !state.gradAvg) {
          state.gradAvg = new Float64Array(size);
        }
        assertBufferSize(state.squareAvg, size, "RMSprop squareAvg");
        if (momentum > 0 && state.momentumBuffer) {
          assertBufferSize(state.momentumBuffer, size, "RMSprop momentumBuffer");
        }
        if (centered && state.gradAvg) {
          assertBufferSize(state.gradAvg, size, "RMSprop gradAvg");
        }
        for (let i = 0; i < size; i++) {
          const gi = safeArrayAccess(gradData, gOff + i, "RMSprop gradient");
          const pi = safeArrayAccess(pData, pOff + i, "RMSprop parameter");
          assertFinite("gradient", gi);
          assertFinite("parameter", pi);
          let grad = gi;
          if (weightDecay !== 0) {
            grad = grad + weightDecay * pi;
          }
          const sqAvg = safeArrayAccess(state.squareAvg, i, "RMSprop squareAvg");
          const sqAvgNew = alpha * sqAvg + (1 - alpha) * grad * grad;
          state.squareAvg[i] = sqAvgNew;
          let avg = sqAvgNew;
          if (centered) {
            const gAvg = state.gradAvg ? safeArrayAccess(state.gradAvg, i, "RMSprop gradAvg") : 0;
            const gAvgNew = alpha * gAvg + (1 - alpha) * grad;
            if (state.gradAvg) state.gradAvg[i] = gAvgNew;
            avg = sqAvgNew - gAvgNew * gAvgNew;
          }
          const denom = centered ? Math.sqrt(Math.max(avg, 0) + eps) : Math.sqrt(avg) + eps;
          const normalizedGrad = grad / denom;
          if (momentum > 0) {
            const buf = state.momentumBuffer ? safeArrayAccess(state.momentumBuffer, i, "RMSprop momentumBuffer") : 0;
            const bufNew = momentum * buf + normalizedGrad;
            if (state.momentumBuffer) state.momentumBuffer[i] = bufNew;
            pData[pOff + i] = pi - lr * bufNew;
          } else {
            pData[pOff + i] = pi - lr * normalizedGrad;
          }
        }
      }
    }
    return loss;
  }
};
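
// Usage sketch (illustrative, not part of the bundled output). It assumes a `model.parameters()`
// call that yields the Parameter objects expected by the Optimizer base class; gradients must
// already be populated before step() unless a closure is supplied.
// const opt = new RMSprop(model.parameters(), { lr: 0.001, alpha: 0.9, momentum: 0.9, centered: true });
// const loss = opt.step(() => computeLoss()); // closure is optional and its return value is passed through
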
// src/optim/optimizers/sgd.ts
var SGD = class extends Optimizer {
  /** Internal counter tracking total number of optimization steps */
  _stepCount = 0;
  get stepCount() {
    return this._stepCount;
  }
  /**
   * Create a new SGD optimizer.
   *
   * @param params - Iterable of parameters or parameter groups to optimize
   * @param options - Optimization options
   * @param options.lr - Learning rate (default: 0.01)
   * @param options.momentum - Momentum factor (default: 0)
   * @param options.dampening - Dampening for momentum (default: 0)
   * @param options.weightDecay - Weight decay (L2 penalty) (default: 0)
   * @param options.nesterov - Enable Nesterov momentum (default: false)
   */
  constructor(params, options = {}) {
    const defaults = {
      lr: options.lr ?? 0.01,
      momentum: options.momentum ?? 0,
      dampening: options.dampening ?? 0,
      weightDecay: options.weightDecay ?? 0,
      nesterov: options.nesterov ?? false
    };
    super(params, defaults);
    assertFiniteNonNegative("learning rate", defaults.lr);
    assertFiniteNonNegative("momentum value", defaults.momentum);
    assertFiniteNonNegative("dampening", defaults.dampening);
    assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
    if (defaults.nesterov && (defaults.momentum <= 0 || defaults.dampening !== 0)) {
      throw new InvalidParameterError(
        "Nesterov momentum requires a positive momentum and zero dampening",
        "nesterov",
        {
          momentum: defaults.momentum,
          dampening: defaults.dampening,
          nesterov: defaults.nesterov
        }
      );
    }
  }
  isState(state) {
    if (state["momentumBuffer"] !== void 0 && !(state["momentumBuffer"] instanceof Float64Array)) {
      return false;
    }
    return true;
  }
  /**
   * Perform a single optimization step.
   *
   * Implements the SGD update rule with optional momentum and weight decay.
   *
   * @param closure - Optional closure that reevaluates the model and returns the loss
   * @returns Loss value if closure is provided
   */
  step(closure) {
    let loss;
    if (closure) {
      loss = closure();
    }
    this._stepCount++;
    for (const group of this.paramGroups) {
      const { lr, momentum, dampening, weightDecay, nesterov } = group.options;
      assertFiniteNonNegative("learning rate", lr);
      assertFiniteNonNegative("momentum value", momentum);
      assertFiniteNonNegative("dampening", dampening);
      assertFiniteNonNegative("weight_decay value", weightDecay);
      if (nesterov && (momentum <= 0 || dampening !== 0)) {
        throw new InvalidParameterError(
          "Nesterov momentum requires a positive momentum and zero dampening",
          "nesterov",
          { momentum, dampening, nesterov }
        );
      }
      for (const param of group.params) {
        const {
          grad: gradData,
          gradOffset,
          param: paramData,
          paramOffset
        } = assertHasGradFloat(param, "SGD");
        const size = param.tensor.size;
        let state = this.state.get(param);
        if (!state) {
          state = {};
          this.state.set(param, state);
        }
        let momentumBuffer;
        if (momentum !== 0) {
          if (!state.momentumBuffer) {
            state.momentumBuffer = new Float64Array(size);
          }
          momentumBuffer = state.momentumBuffer;
        }
        for (let i = 0; i < size; i++) {
          const gi = safeArrayAccess(gradData, gradOffset + i, "SGD gradient");
          const pi = safeArrayAccess(paramData, paramOffset + i, "SGD parameter");
          assertFinite("gradient", gi);
          assertFinite("parameter", pi);
          let d = gi;
          if (weightDecay !== 0) {
            d = d + weightDecay * pi;
          }
          if (momentumBuffer) {
            const bPrev = safeArrayAccess(momentumBuffer, i, "SGD momentum buffer");
            const bNew = momentum * bPrev + (1 - dampening) * d;
            momentumBuffer[i] = bNew;
            d = nesterov ? d + momentum * bNew : bNew;
          }
          paramData[paramOffset + i] = pi - lr * d;
        }
      }
    }
    return loss;
  }
  /**
   * Get the current learning rate.
   *
   * @param groupIdx - Parameter group index (default: 0)
   * @returns Current learning rate
   */
  getLearningRate(groupIdx = 0) {
    const group = this.paramGroups[groupIdx];
    if (!group) {
      throw new InvalidParameterError(
        `Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
        "groupIdx",
        groupIdx
      );
    }
    return group.options.lr;
  }
  /**
   * Set the learning rate for all parameter groups.
   *
   * @param lr - New learning rate
   */
  setLearningRate(lr) {
    assertFiniteNonNegative("learning rate", lr);
    for (const group of this.paramGroups) {
      group.options.lr = lr;
    }
  }
};
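
// Usage sketch (illustrative, not part of the bundled output). With momentum > 0 and the default
// dampening of 0, `nesterov: true` is accepted; otherwise the constructor and step() throw
// InvalidParameterError, as validated above. `model.parameters()` is an assumed source of Parameters.
// const sgd = new SGD(model.parameters(), { lr: 0.1, momentum: 0.9, nesterov: true, weightDecay: 1e-4 });
// sgd.step();
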
// src/optim/schedulers.ts
function isRecord2(value) {
  return typeof value === "object" && value !== null;
}
function resolveGroupLr(group, index) {
  const options = isRecord2(group.options) ? group.options : void 0;
  const lrValue = group.lr ?? options?.["lr"];
  if (typeof lrValue !== "number" || !Number.isFinite(lrValue) || lrValue < 0) {
    throw new InvalidParameterError(
      `optimizer.paramGroups[${index}].lr must be finite and >= 0`,
      `optimizer.paramGroups[${index}].lr`,
      lrValue
    );
  }
  return lrValue;
}
function setGroupLr(group, lr) {
  if (isRecord2(group.options)) {
    group.options["lr"] = lr;
  }
  if ("lr" in group) {
    group.lr = lr;
  }
  if (!("lr" in group) && !isRecord2(group.options)) {
    group.lr = lr;
  }
}
function validateLastEpoch(value) {
  if (!Number.isInteger(value) || value < -1) {
    throw new InvalidParameterError("lastEpoch must be an integer >= -1", "lastEpoch", value);
  }
  return value;
}
function validateFiniteNumber(value, name) {
  if (!Number.isFinite(value)) {
    throw new InvalidParameterError(`${name} must be finite`, name, value);
  }
  return value;
}
function validatePositiveNumber(value, name) {
  if (!Number.isFinite(value) || value <= 0) {
    throw new InvalidParameterError(`${name} must be > 0`, name, value);
  }
  return value;
}
function validatePositiveInteger(value, name) {
  if (!Number.isInteger(value) || value <= 0) {
    throw new InvalidParameterError(`${name} must be a positive integer`, name, value);
  }
  return value;
}
function validateNonNegativeNumber(value, name) {
  if (!Number.isFinite(value) || value < 0) {
    throw new InvalidParameterError(`${name} must be >= 0`, name, value);
  }
  return value;
}
function validateNonNegativeInteger(value, name) {
  if (!Number.isInteger(value) || value < 0) {
    throw new InvalidParameterError(`${name} must be a non-negative integer`, name, value);
  }
  return value;
}
function validateOptimizer(optimizer) {
  if (!optimizer || typeof optimizer !== "object" || !Array.isArray(optimizer.paramGroups)) {
    throw new InvalidParameterError(
      "optimizer must expose paramGroups array",
      "optimizer",
      optimizer
    );
  }
  if (optimizer.paramGroups.length === 0) {
    throw new InvalidParameterError(
      "optimizer.paramGroups must contain at least one group",
      "optimizer.paramGroups",
      optimizer.paramGroups
    );
  }
  for (let i = 0; i < optimizer.paramGroups.length; i++) {
    const group = optimizer.paramGroups[i];
    if (!group || typeof group !== "object") {
      throw new InvalidParameterError(
        `optimizer.paramGroups[${i}] must be an object`,
        "optimizer.paramGroups",
        group
      );
    }
    if (!Array.isArray(group.params)) {
      throw new InvalidParameterError(
        `optimizer.paramGroups[${i}].params must be an array`,
        `optimizer.paramGroups[${i}].params`,
        group.params
      );
    }
    resolveGroupLr(group, i);
  }
}
function validateMilestones(milestones) {
  if (!Array.isArray(milestones) || milestones.length === 0) {
    throw new InvalidParameterError(
      "milestones must be a non-empty array of non-negative integers",
      "milestones",
      milestones
    );
  }
  const sorted = [...milestones].sort((a, b) => a - b);
  for (let i = 0; i < sorted.length; i++) {
    const value = sorted[i];
    if (value === void 0 || !Number.isInteger(value) || value < 0) {
      throw new InvalidParameterError(
        "milestones must contain non-negative integers only",
        "milestones",
        milestones
      );
    }
    if (i > 0) {
      const prev = sorted[i - 1];
      if (prev !== void 0 && value <= prev) {
        throw new InvalidParameterError(
          "milestones must be strictly increasing",
          "milestones",
          milestones
        );
      }
    }
  }
  return sorted;
}
var LRScheduler = class {
  optimizer;
  lastEpoch;
  baseLrs;
  constructor(optimizer, lastEpoch = -1) {
    validateOptimizer(optimizer);
    this.lastEpoch = validateLastEpoch(lastEpoch);
    this.optimizer = optimizer;
    this.baseLrs = optimizer.paramGroups.map((group, index) => resolveGroupLr(group, index));
  }
  initializeFromLastEpoch(lastEpoch) {
    const validated = validateLastEpoch(lastEpoch);
    if (validated < 0) {
      return;
    }
    this.lastEpoch = -1;
    for (let i = 0; i <= validated; i++) {
      this.step();
    }
  }
  /**
   * Perform a scheduler step, updating learning rates.
   *
   * Should be called once per epoch after the optimizer step.
   */
  step() {
    this.lastEpoch++;
    const newLrs = this.getLr();
    for (let i = 0; i < this.optimizer.paramGroups.length; i++) {
      const group = this.optimizer.paramGroups[i];
      if (group) {
        const next = newLrs[i];
        if (next !== void 0) {
          setGroupLr(group, next);
        }
      }
    }
  }
  /**
   * Get the current learning rates for all parameter groups.
   */
  getLastLr() {
    return this.optimizer.paramGroups.map((group, index) => resolveGroupLr(group, index));
  }
  /**
   * Get current epoch number.
   */
  get epoch() {
    return this.lastEpoch;
  }
};
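
// Sketch of the subclass contract (an assumption read off the concrete schedulers below, not part
// of the bundled output): a scheduler extends LRScheduler and overrides getLr() to map this.baseLrs
// and this.lastEpoch to the new per-group learning rates; step() then writes them back to the groups.
// class HalveEveryEpochLR extends LRScheduler {
//   getLr() { return this.baseLrs.map((lr) => lr * 0.5 ** this.lastEpoch); }
// }
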
var StepLR = class extends LRScheduler {
  stepSize;
  gamma;
  constructor(optimizer, options) {
    const stepSize = validatePositiveInteger(options.stepSize, "stepSize");
    const gamma = validatePositiveNumber(options.gamma ?? 0.1, "gamma");
    const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
    super(optimizer, -1);
    this.stepSize = stepSize;
    this.gamma = gamma;
    this.initializeFromLastEpoch(lastEpoch);
  }
  getLr() {
    const factor = this.gamma ** Math.floor(this.lastEpoch / this.stepSize);
    return this.baseLrs.map((lr) => lr * factor);
  }
};
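
// Worked example (illustrative): with a base lr of 0.1, stepSize = 30 and the default gamma = 0.1,
// getLr() yields 0.1 for epochs 0-29, 0.01 for epochs 30-59, 0.001 for epochs 60-89, and so on.
// const scheduler = new StepLR(optimizer, { stepSize: 30 });
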
var ExponentialLR = class extends LRScheduler {
  gamma;
  constructor(optimizer, options) {
    const gamma = validatePositiveNumber(options.gamma, "gamma");
    const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
    super(optimizer, -1);
    this.gamma = gamma;
    this.initializeFromLastEpoch(lastEpoch);
  }
  getLr() {
    return this.baseLrs.map((lr) => lr * this.gamma ** this.lastEpoch);
  }
};
var CosineAnnealingLR = class extends LRScheduler {
  T_max;
  etaMin;
  constructor(optimizer, options) {
    const tMax = validatePositiveInteger(options.T_max, "T_max");
    const etaMin = validateNonNegativeNumber(options.etaMin ?? 0, "etaMin");
    const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
    super(optimizer, -1);
    this.T_max = tMax;
    this.etaMin = etaMin;
    this.initializeFromLastEpoch(lastEpoch);
  }
  getLr() {
    return this.baseLrs.map((baseLr) => {
      return this.etaMin + (baseLr - this.etaMin) * (1 + Math.cos(Math.PI * this.lastEpoch / this.T_max)) / 2;
    });
  }
};
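
// Usage sketch (illustrative, not part of the bundled output): epoch-based schedulers such as
// CosineAnnealingLR are stepped once per epoch, after the optimizer step, per the LRScheduler.step()
// doc above. `trainOneEpoch()` is a hypothetical helper standing in for forward/backward + optimizer.step().
// const scheduler = new CosineAnnealingLR(optimizer, { T_max: 50, etaMin: 1e-5 });
// for (let epoch = 0; epoch < 50; epoch++) {
//   trainOneEpoch();
//   scheduler.step();
// }
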
var MultiStepLR = class extends LRScheduler {
  sortedMilestones;
  gamma;
  constructor(optimizer, options) {
    const milestones = validateMilestones(options.milestones);
    const gamma = validatePositiveNumber(options.gamma ?? 0.1, "gamma");
    const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
    super(optimizer, -1);
    this.sortedMilestones = milestones;
    this.gamma = gamma;
    this.initializeFromLastEpoch(lastEpoch);
  }
  getLr() {
    let numDecays = 0;
    for (const milestone of this.sortedMilestones) {
      if (this.lastEpoch >= milestone) {
        numDecays++;
      }
    }
    const factor = this.gamma ** numDecays;
    return this.baseLrs.map((lr) => lr * factor);
  }
};
var LinearLR = class extends LRScheduler {
  startFactor;
  endFactor;
  totalIters;
  constructor(optimizer, options) {
    const startFactor = validatePositiveNumber(options.startFactor ?? 1 / 3, "startFactor");
    const endFactor = validatePositiveNumber(options.endFactor ?? 1, "endFactor");
    const totalIters = validatePositiveInteger(options.totalIters, "totalIters");
    const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
    super(optimizer, -1);
    this.startFactor = startFactor;
    this.endFactor = endFactor;
    this.totalIters = totalIters;
    this.initializeFromLastEpoch(lastEpoch);
  }
  getLr() {
    if (this.lastEpoch >= this.totalIters) {
      return this.baseLrs.map((lr) => lr * this.endFactor);
    }
    const factor = this.startFactor + (this.endFactor - this.startFactor) * (this.lastEpoch / this.totalIters);
    return this.baseLrs.map((lr) => lr * factor);
  }
};
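
// Worked example (illustrative): MultiStepLR with milestones [30, 80] and the default gamma = 0.1
// turns a base lr of 0.1 into 0.01 from epoch 30 and 0.001 from epoch 80, while LinearLR instead
// ramps the base lr from startFactor to endFactor linearly over totalIters epochs and then holds it.
// const multi = new MultiStepLR(optimizer, { milestones: [30, 80] });
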
var ReduceLROnPlateau = class {
  optimizer;
  mode;
  factor;
  patience;
  threshold;
  cooldown;
  minLr;
  best;
  numBadEpochs;
  cooldownCounter;
  constructor(optimizer, options = {}) {
    this.optimizer = optimizer;
    validateOptimizer(optimizer);
    this.mode = options.mode ?? "min";
    if (this.mode !== "min" && this.mode !== "max") {
      throw new InvalidParameterError("mode must be 'min' or 'max'", "mode", options.mode);
    }
    this.factor = validateFiniteNumber(options.factor ?? 0.1, "factor");
    if (this.factor <= 0 || this.factor >= 1) {
      throw new InvalidParameterError(
        "factor must be in the interval (0, 1)",
        "factor",
        this.factor
      );
    }
    this.patience = validateNonNegativeInteger(options.patience ?? 10, "patience");
    this.threshold = validateNonNegativeNumber(options.threshold ?? 1e-4, "threshold");
    this.cooldown = validateNonNegativeInteger(options.cooldown ?? 0, "cooldown");
    this.minLr = validateNonNegativeNumber(options.minLr ?? 0, "minLr");
    this.best = this.mode === "min" ? Infinity : -Infinity;
    this.numBadEpochs = 0;
    this.cooldownCounter = 0;
  }
  /**
   * Check if metric improved.
   */
  isBetter(current) {
    if (this.mode === "min") {
      return current < this.best - this.threshold;
    }
    return current > this.best + this.threshold;
  }
  /**
   * Perform a scheduler step based on the metric value.
   *
   * @param metric - Current value of the metric being monitored
   */
  step(metric) {
    if (!Number.isFinite(metric)) {
      throw new InvalidParameterError("metric must be finite", "metric", metric);
    }
    if (this.cooldownCounter > 0) {
      this.cooldownCounter--;
      this.numBadEpochs = 0;
    }
    if (this.isBetter(metric)) {
      this.best = metric;
      this.numBadEpochs = 0;
    } else if (this.cooldownCounter === 0) {
      this.numBadEpochs++;
    }
    if (this.numBadEpochs > this.patience) {
      this.reduceLr();
      this.cooldownCounter = this.cooldown;
      this.numBadEpochs = 0;
    }
  }
  /**
   * Reduce learning rate for all parameter groups.
   */
  reduceLr() {
    for (let i = 0; i < this.optimizer.paramGroups.length; i++) {
      const group = this.optimizer.paramGroups[i];
      if (!group) {
        throw new InvalidParameterError(
          `optimizer.paramGroups[${i}] is missing`,
          "optimizer.paramGroups",
          group
        );
      }
      const currentLr = resolveGroupLr(group, i);
      const newLr = Math.max(currentLr * this.factor, this.minLr);
      setGroupLr(group, newLr);
    }
  }
  /**
   * Get the current learning rates for all parameter groups.
   */
  getLastLr() {
    return this.optimizer.paramGroups.map((group, index) => resolveGroupLr(group, index));
  }
};
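
// Usage sketch (illustrative, not part of the bundled output): step() takes the monitored metric,
// e.g. a validation loss in "min" mode, and multiplies each group lr by `factor` (floored at minLr)
// once the metric has failed to improve for more than `patience` consecutive steps.
// const plateau = new ReduceLROnPlateau(optimizer, { mode: "min", factor: 0.5, patience: 5 });
// plateau.step(validationLoss); // call once per evaluation
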
var WarmupLR = class extends LRScheduler {
  warmupEpochs;
  afterScheduler;
  constructor(optimizer, afterScheduler, options) {
    const warmupEpochs = validatePositiveInteger(options.warmupEpochs, "warmupEpochs");
    const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
    super(optimizer, -1);
    this.warmupEpochs = warmupEpochs;
    this.afterScheduler = afterScheduler;
    this.initializeFromLastEpoch(lastEpoch);
  }
  getLr() {
    if (this.lastEpoch < this.warmupEpochs) {
      const factor = (this.lastEpoch + 1) / this.warmupEpochs;
      return this.baseLrs.map((lr) => lr * factor);
    }
    if (this.afterScheduler) {
      return this.afterScheduler.getLr();
    }
    return this.baseLrs;
  }
  step() {
    super.step();
    if (this.lastEpoch >= this.warmupEpochs && this.afterScheduler) {
      this.afterScheduler.step();
    }
  }
};
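
// Usage sketch (illustrative, not part of the bundled output): the learning rate is ramped up
// linearly for the first warmupEpochs epochs, after which step() defers to the wrapped scheduler
// (a CosineAnnealingLR here, purely as an example).
// const warmup = new WarmupLR(optimizer, new CosineAnnealingLR(optimizer, { T_max: 100 }), { warmupEpochs: 5 });
// warmup.step(); // call once per epoch
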
var OneCycleLR = class extends LRScheduler {
  maxLr;
  totalSteps;
  pctStart;
  divFactor;
  finalDivFactor;
  annealStrategy;
  constructor(optimizer, options) {
    const maxLr = validatePositiveNumber(options.maxLr, "maxLr");
    const totalSteps = validatePositiveInteger(options.totalSteps, "totalSteps");
    const pctStart = validateFiniteNumber(options.pctStart ?? 0.3, "pctStart");
    if (pctStart <= 0 || pctStart >= 1) {
      throw new InvalidParameterError(
        "pctStart must be in the interval (0, 1)",
        "pctStart",
        pctStart
      );
    }
    const divFactor = validatePositiveNumber(options.divFactor ?? 25, "divFactor");
    const finalDivFactor = validatePositiveNumber(options.finalDivFactor ?? 1e4, "finalDivFactor");
    const annealStrategy = options.annealStrategy ?? "cos";
    if (annealStrategy !== "cos" && annealStrategy !== "linear") {
      throw new InvalidParameterError(
        "annealStrategy must be 'cos' or 'linear'",
        "annealStrategy",
        annealStrategy
      );
    }
    const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
    super(optimizer, -1);
    this.maxLr = maxLr;
    this.totalSteps = totalSteps;
    this.pctStart = pctStart;
    this.divFactor = divFactor;
    this.finalDivFactor = finalDivFactor;
    this.annealStrategy = annealStrategy;
    this.initializeFromLastEpoch(lastEpoch);
  }
  getLr() {
    const stepNum = this.lastEpoch;
    const upSteps = Math.max(1, Math.floor(this.totalSteps * this.pctStart));
    const downSteps = Math.max(1, this.totalSteps - upSteps);
    const initialLr = this.maxLr / this.divFactor;
    const minLr = this.maxLr / this.finalDivFactor;
    let lr;
    if (stepNum >= this.totalSteps) {
      lr = minLr;
    } else if (stepNum < upSteps) {
      const pct = stepNum / upSteps;
      lr = initialLr + (this.maxLr - initialLr) * pct;
    } else {
      const pct = (stepNum - upSteps) / downSteps;
      if (this.annealStrategy === "cos") {
        lr = minLr + (this.maxLr - minLr) * (1 + Math.cos(Math.PI * pct)) / 2;
      } else {
        lr = this.maxLr - (this.maxLr - minLr) * pct;
      }
    }
    const baseRef = this.baseLrs[0] ?? 0;
    return this.baseLrs.map((baseLr) => {
      if (baseRef === 0) {
        return baseLr === 0 ? 0 : lr;
      }
      return lr * (baseLr / baseRef);
    });
  }
};
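
// Usage sketch (illustrative, not part of the bundled output): one-cycle schedules are typically
// stepped once per optimizer step rather than per epoch, so totalSteps should cover the whole run;
// `epochs * batchesPerEpoch` below is an assumed way of computing it.
// const oneCycle = new OneCycleLR(optimizer, { maxLr: 0.01, totalSteps: epochs * batchesPerEpoch });
// oneCycle.step(); // after each optimizer.step()
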
export { AdaDelta, Adagrad, Adam, AdamW, CosineAnnealingLR, ExponentialLR, LRScheduler, LinearLR, MultiStepLR, Nadam, OneCycleLR, Optimizer, RMSprop, ReduceLROnPlateau, SGD, StepLR, WarmupLR, optim_exports };
//# sourceMappingURL=chunk-PR647I7R.js.map