deepbox 0.1.0

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (173)
  1. package/LICENSE +21 -0
  2. package/README.md +344 -0
  3. package/dist/CSRMatrix-CwGwQRea.d.cts +219 -0
  4. package/dist/CSRMatrix-KzNt6QpS.d.ts +219 -0
  5. package/dist/Tensor-BQLk1ltW.d.cts +147 -0
  6. package/dist/Tensor-g8mUClel.d.ts +147 -0
  7. package/dist/chunk-4S73VUBD.js +677 -0
  8. package/dist/chunk-4S73VUBD.js.map +1 -0
  9. package/dist/chunk-5R4S63PF.js +2925 -0
  10. package/dist/chunk-5R4S63PF.js.map +1 -0
  11. package/dist/chunk-6AE5FKKQ.cjs +9264 -0
  12. package/dist/chunk-6AE5FKKQ.cjs.map +1 -0
  13. package/dist/chunk-AD436M45.js +3854 -0
  14. package/dist/chunk-AD436M45.js.map +1 -0
  15. package/dist/chunk-ALS7ETWZ.cjs +4263 -0
  16. package/dist/chunk-ALS7ETWZ.cjs.map +1 -0
  17. package/dist/chunk-AU7XHGKJ.js +2092 -0
  18. package/dist/chunk-AU7XHGKJ.js.map +1 -0
  19. package/dist/chunk-B5TNKUEY.js +1481 -0
  20. package/dist/chunk-B5TNKUEY.js.map +1 -0
  21. package/dist/chunk-BCR7G3A6.js +9136 -0
  22. package/dist/chunk-BCR7G3A6.js.map +1 -0
  23. package/dist/chunk-C4PKXY74.cjs +1917 -0
  24. package/dist/chunk-C4PKXY74.cjs.map +1 -0
  25. package/dist/chunk-DWZY6PIP.cjs +6400 -0
  26. package/dist/chunk-DWZY6PIP.cjs.map +1 -0
  27. package/dist/chunk-E3EU5FZO.cjs +2113 -0
  28. package/dist/chunk-E3EU5FZO.cjs.map +1 -0
  29. package/dist/chunk-F3JWBINJ.js +1054 -0
  30. package/dist/chunk-F3JWBINJ.js.map +1 -0
  31. package/dist/chunk-FJYLIGJX.js +1940 -0
  32. package/dist/chunk-FJYLIGJX.js.map +1 -0
  33. package/dist/chunk-JSCDE774.cjs +729 -0
  34. package/dist/chunk-JSCDE774.cjs.map +1 -0
  35. package/dist/chunk-LWECRCW2.cjs +2412 -0
  36. package/dist/chunk-LWECRCW2.cjs.map +1 -0
  37. package/dist/chunk-MLBMYKCG.js +6379 -0
  38. package/dist/chunk-MLBMYKCG.js.map +1 -0
  39. package/dist/chunk-OX6QXFMV.cjs +3874 -0
  40. package/dist/chunk-OX6QXFMV.cjs.map +1 -0
  41. package/dist/chunk-PHV2DKRS.cjs +1072 -0
  42. package/dist/chunk-PHV2DKRS.cjs.map +1 -0
  43. package/dist/chunk-PL7TAYKI.js +4056 -0
  44. package/dist/chunk-PL7TAYKI.js.map +1 -0
  45. package/dist/chunk-PR647I7R.js +1898 -0
  46. package/dist/chunk-PR647I7R.js.map +1 -0
  47. package/dist/chunk-QERHVCHC.cjs +2960 -0
  48. package/dist/chunk-QERHVCHC.cjs.map +1 -0
  49. package/dist/chunk-XEG44RF6.cjs +1514 -0
  50. package/dist/chunk-XEG44RF6.cjs.map +1 -0
  51. package/dist/chunk-XMWVME2W.js +2377 -0
  52. package/dist/chunk-XMWVME2W.js.map +1 -0
  53. package/dist/chunk-ZB75FESB.cjs +1979 -0
  54. package/dist/chunk-ZB75FESB.cjs.map +1 -0
  55. package/dist/chunk-ZLW62TJG.cjs +4061 -0
  56. package/dist/chunk-ZLW62TJG.cjs.map +1 -0
  57. package/dist/chunk-ZXKBDFP3.js +4235 -0
  58. package/dist/chunk-ZXKBDFP3.js.map +1 -0
  59. package/dist/core/index.cjs +204 -0
  60. package/dist/core/index.cjs.map +1 -0
  61. package/dist/core/index.d.cts +2 -0
  62. package/dist/core/index.d.ts +2 -0
  63. package/dist/core/index.js +3 -0
  64. package/dist/core/index.js.map +1 -0
  65. package/dist/dataframe/index.cjs +22 -0
  66. package/dist/dataframe/index.cjs.map +1 -0
  67. package/dist/dataframe/index.d.cts +3 -0
  68. package/dist/dataframe/index.d.ts +3 -0
  69. package/dist/dataframe/index.js +5 -0
  70. package/dist/dataframe/index.js.map +1 -0
  71. package/dist/datasets/index.cjs +134 -0
  72. package/dist/datasets/index.cjs.map +1 -0
  73. package/dist/datasets/index.d.cts +3 -0
  74. package/dist/datasets/index.d.ts +3 -0
  75. package/dist/datasets/index.js +5 -0
  76. package/dist/datasets/index.js.map +1 -0
  77. package/dist/index-74AB8Cyh.d.cts +1126 -0
  78. package/dist/index-9oQx1HgV.d.cts +1180 -0
  79. package/dist/index-BJY2SI4i.d.ts +483 -0
  80. package/dist/index-BWGhrDlr.d.ts +733 -0
  81. package/dist/index-B_DK4FKY.d.cts +242 -0
  82. package/dist/index-BbA2Gxfl.d.ts +456 -0
  83. package/dist/index-BgHYAoSS.d.cts +837 -0
  84. package/dist/index-BndMbqsM.d.ts +1439 -0
  85. package/dist/index-C1mfVYoo.d.ts +2517 -0
  86. package/dist/index-CCvlwAmL.d.cts +809 -0
  87. package/dist/index-CDw5CnOU.d.ts +785 -0
  88. package/dist/index-Cn3SdB0O.d.ts +1126 -0
  89. package/dist/index-CrqLlS-a.d.ts +776 -0
  90. package/dist/index-D61yaSMY.d.cts +483 -0
  91. package/dist/index-D9Loo1_A.d.cts +2517 -0
  92. package/dist/index-DIT_OO9C.d.cts +785 -0
  93. package/dist/index-DIp_RrRt.d.ts +242 -0
  94. package/dist/index-DbultU6X.d.cts +1427 -0
  95. package/dist/index-DmEg_LCm.d.cts +776 -0
  96. package/dist/index-DoPWVxPo.d.cts +1439 -0
  97. package/dist/index-DuCxd-8d.d.ts +837 -0
  98. package/dist/index-Dx42TZaY.d.ts +809 -0
  99. package/dist/index-DyZ4QQf5.d.cts +456 -0
  100. package/dist/index-GFAVyOWO.d.ts +1427 -0
  101. package/dist/index-WHQLn0e8.d.cts +733 -0
  102. package/dist/index-ZtI1Iy4L.d.ts +1180 -0
  103. package/dist/index-eJgeni9c.d.cts +1911 -0
  104. package/dist/index-tk4lSYod.d.ts +1911 -0
  105. package/dist/index.cjs +72 -0
  106. package/dist/index.cjs.map +1 -0
  107. package/dist/index.d.cts +17 -0
  108. package/dist/index.d.ts +17 -0
  109. package/dist/index.js +15 -0
  110. package/dist/index.js.map +1 -0
  111. package/dist/linalg/index.cjs +86 -0
  112. package/dist/linalg/index.cjs.map +1 -0
  113. package/dist/linalg/index.d.cts +3 -0
  114. package/dist/linalg/index.d.ts +3 -0
  115. package/dist/linalg/index.js +5 -0
  116. package/dist/linalg/index.js.map +1 -0
  117. package/dist/metrics/index.cjs +158 -0
  118. package/dist/metrics/index.cjs.map +1 -0
  119. package/dist/metrics/index.d.cts +3 -0
  120. package/dist/metrics/index.d.ts +3 -0
  121. package/dist/metrics/index.js +5 -0
  122. package/dist/metrics/index.js.map +1 -0
  123. package/dist/ml/index.cjs +87 -0
  124. package/dist/ml/index.cjs.map +1 -0
  125. package/dist/ml/index.d.cts +3 -0
  126. package/dist/ml/index.d.ts +3 -0
  127. package/dist/ml/index.js +6 -0
  128. package/dist/ml/index.js.map +1 -0
  129. package/dist/ndarray/index.cjs +501 -0
  130. package/dist/ndarray/index.cjs.map +1 -0
  131. package/dist/ndarray/index.d.cts +5 -0
  132. package/dist/ndarray/index.d.ts +5 -0
  133. package/dist/ndarray/index.js +4 -0
  134. package/dist/ndarray/index.js.map +1 -0
  135. package/dist/nn/index.cjs +142 -0
  136. package/dist/nn/index.cjs.map +1 -0
  137. package/dist/nn/index.d.cts +6 -0
  138. package/dist/nn/index.d.ts +6 -0
  139. package/dist/nn/index.js +5 -0
  140. package/dist/nn/index.js.map +1 -0
  141. package/dist/optim/index.cjs +77 -0
  142. package/dist/optim/index.cjs.map +1 -0
  143. package/dist/optim/index.d.cts +4 -0
  144. package/dist/optim/index.d.ts +4 -0
  145. package/dist/optim/index.js +4 -0
  146. package/dist/optim/index.js.map +1 -0
  147. package/dist/plot/index.cjs +114 -0
  148. package/dist/plot/index.cjs.map +1 -0
  149. package/dist/plot/index.d.cts +6 -0
  150. package/dist/plot/index.d.ts +6 -0
  151. package/dist/plot/index.js +5 -0
  152. package/dist/plot/index.js.map +1 -0
  153. package/dist/preprocess/index.cjs +82 -0
  154. package/dist/preprocess/index.cjs.map +1 -0
  155. package/dist/preprocess/index.d.cts +4 -0
  156. package/dist/preprocess/index.d.ts +4 -0
  157. package/dist/preprocess/index.js +5 -0
  158. package/dist/preprocess/index.js.map +1 -0
  159. package/dist/random/index.cjs +74 -0
  160. package/dist/random/index.cjs.map +1 -0
  161. package/dist/random/index.d.cts +3 -0
  162. package/dist/random/index.d.ts +3 -0
  163. package/dist/random/index.js +5 -0
  164. package/dist/random/index.js.map +1 -0
  165. package/dist/stats/index.cjs +142 -0
  166. package/dist/stats/index.cjs.map +1 -0
  167. package/dist/stats/index.d.cts +3 -0
  168. package/dist/stats/index.d.ts +3 -0
  169. package/dist/stats/index.js +5 -0
  170. package/dist/stats/index.js.map +1 -0
  171. package/dist/tensor-B96jjJLQ.d.cts +205 -0
  172. package/dist/tensor-B96jjJLQ.d.ts +205 -0
  173. package/package.json +226 -0
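The listing above is dominated by generated dist/ output: each area of the library (core, dataframe, datasets, linalg, metrics, ml, ndarray, nn, optim, plot, preprocess, random, stats) ships an ESM build (.js), a CommonJS build (.cjs), source maps, and paired .d.ts/.d.cts type declarations. As orientation for the optim bundle reproduced in the diff below, here is a minimal usage sketch; the `deepbox/optim` subpath and the `model` object are illustrative assumptions (the real export map lives in package/package.json, not expanded here), while `SGD`, `zeroGrad()`, and `step()` come from the bundled code that follows.

```ts
// Hedged sketch, mirroring the training-loop example embedded in the
// Optimizer JSDoc below. Assumes package.json maps dist/optim to the
// "deepbox/optim" subpath and that `model`, `criterion`, `input`, and
// `target` are placeholders supplied by the caller.
import { SGD } from "deepbox/optim";

const optimizer = new SGD(model.parameters(), { lr: 0.01, momentum: 0.9 });

optimizer.zeroGrad();                      // clear gradients from the previous step
const output = model.forward(input);       // placeholder forward pass
const loss = criterion(output, target);    // placeholder loss computation
loss.backward();                           // populate .grad on each parameter
optimizer.step();                          // apply the SGD update implemented below
```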
@@ -0,0 +1,1898 @@
+ import { __export, DataValidationError, InvalidParameterError, DeepboxError, NotFittedError, DTypeError, ShapeError, IndexError } from './chunk-4S73VUBD.js';
+
+ // src/optim/index.ts
+ var optim_exports = {};
+ __export(optim_exports, {
+ AdaDelta: () => AdaDelta,
+ Adagrad: () => Adagrad,
+ Adam: () => Adam,
+ AdamW: () => AdamW,
+ CosineAnnealingLR: () => CosineAnnealingLR,
+ ExponentialLR: () => ExponentialLR,
+ LRScheduler: () => LRScheduler,
+ LinearLR: () => LinearLR,
+ MultiStepLR: () => MultiStepLR,
+ Nadam: () => Nadam,
+ OneCycleLR: () => OneCycleLR,
+ Optimizer: () => Optimizer,
+ RMSprop: () => RMSprop,
+ ReduceLROnPlateau: () => ReduceLROnPlateau,
+ SGD: () => SGD,
+ StepLR: () => StepLR,
+ WarmupLR: () => WarmupLR
+ });
+
+ // src/optim/Optimizer.ts
+ function isRecord(value) {
+ return typeof value === "object" && value !== null;
+ }
+ function ensureRecord(value, context) {
+ if (!isRecord(value)) {
+ throw new DataValidationError(`${context} must be an object`);
+ }
+ return value;
+ }
+ function isStateRecord(value) {
+ return isRecord(value);
+ }
+ function ensureIntegerArray(value, context) {
+ if (!Array.isArray(value)) {
+ throw new DataValidationError(`${context} must be an array of integers`);
+ }
+ const output = [];
+ for (const entry of value) {
+ if (!Number.isInteger(entry)) {
+ throw new DataValidationError(`${context} must contain integers only`);
+ }
+ output.push(entry);
+ }
+ return output;
+ }
+ function isParamGroupArray(params) {
+ if (!Array.isArray(params)) return false;
+ if (params.length === 0) return true;
+ const first = params[0];
+ if (!first || typeof first !== "object") return false;
+ return "params" in first;
+ }
+ var Optimizer = class {
+ /**
+ * Create a new optimizer.
+ *
+ * Initializes the optimizer with either a simple list of parameters or
+ * multiple parameter groups with per-group hyperparameters.
+ *
+ * @param params - Either an iterable of parameters or array of parameter groups
+ * @param defaults - Default hyperparameters applied to all groups
+ */
+ constructor(params, defaults) {
+ this.defaults = defaults;
+ this.paramGroups = [];
+ if (!isParamGroupArray(params)) {
+ this.paramGroups.push({
+ params: Array.from(params),
+ // Convert iterable to array for efficient access
+ options: { ...defaults }
+ // Clone defaults to avoid mutation
+ });
+ } else {
+ for (const group of params) {
+ const { params: groupParams, ...groupOptions } = group;
+ this.paramGroups.push({
+ params: Array.from(groupParams),
+ // Convert to array
+ options: { ...defaults, ...groupOptions }
+ // Merge defaults with group options
+ });
+ }
+ }
+ }
+ /**
+ * Groups of parameters with their associated hyperparameters.
+ * Each group can have different options (e.g., learning rates).
+ * Exposed publicly to enable scheduler integrations.
+ */
+ paramGroups;
+ /**
+ * Per-parameter state storage.
+ * Maps each parameter to its optimizer-specific state (momentum, adaptive rates, etc.).
+ */
+ state = /* @__PURE__ */ new Map();
+ /**
+ * Zero out the gradients of all optimized parameters.
+ *
+ * This method should be called at the beginning of each training iteration,
+ * before computing new gradients. Without this call, gradients would accumulate
+ * across iterations, leading to incorrect updates.
+ *
+ * **Implementation Note:**
+ * For parameters wrapped in GradTensor, this calls zeroGrad() on each parameter,
+ * which either sets the gradient to zero or initializes it if not yet created.
+ *
+ * @example
+ * ```ts
+ * // Typical training loop
+ * optimizer.zeroGrad(); // Clear previous gradients
+ * const output = model.forward(input);
+ * const loss = criterion(output, target);
+ * loss.backward(); // Compute new gradients
+ * optimizer.step(); // Update parameters
+ * ```
+ */
+ zeroGrad() {
+ for (const group of this.paramGroups) {
+ for (const param of group.params) {
+ param.zeroGrad();
+ }
+ }
+ }
+ /**
+ * Add a parameter group to the optimizer.
+ *
+ * This method allows adding new parameters to optimize after the optimizer
+ * has been created. This is particularly useful for:
+ * - Fine-tuning: adding pre-trained layers with different learning rates
+ * - Progressive training: gradually unfreezing layers
+ * - Dynamic architectures: adding parameters while the model grows
+ *
+ * @param paramGroup - Parameter group to add with optional per-group options
+ *
+ * @example
+ * ```ts
+ * const optimizer = new SGD(model.backbone.parameters(), { lr: 0.001 });
+ * // Later, add classifier with higher learning rate
+ * optimizer.addParamGroup({
+ * params: model.classifier.parameters(),
+ * lr: 0.01
+ * });
+ * ```
+ */
+ addParamGroup(paramGroup) {
+ const { params, ...options } = paramGroup;
+ this.paramGroups.push({
+ params: Array.from(params),
+ // Convert iterable to array
+ options: { ...this.defaults, ...options }
+ // Merge with defaults
+ });
+ }
+ /**
+ * Get the current state of the optimizer.
+ *
+ * Returns a dictionary containing all optimizer state that needs to be
+ * saved for checkpointing. This includes per-parameter state (momentum buffers,
+ * adaptive learning rates, etc.) and parameter group configurations.
+ *
+ * **Note:** In a production implementation, parameters would be identified by
+ * unique IDs rather than object references for proper serialization.
+ *
+ * @returns Optimizer state dictionary containing state and parameter groups
+ *
+ * @example
+ * ```ts
+ * // Save checkpoint
+ * const checkpoint = {
+ * model: model.stateDict(),
+ * optimizer: optimizer.stateDict(),
+ * epoch: currentEpoch
+ * };
+ * ```
+ */
+ stateDict() {
+ const paramIdMap = /* @__PURE__ */ new Map();
+ const orderedParams = [];
+ const getParamId = (param) => {
+ const existing = paramIdMap.get(param);
+ if (existing !== void 0) return existing;
+ const id = orderedParams.length;
+ orderedParams.push(param);
+ paramIdMap.set(param, id);
+ return id;
+ };
+ return {
+ // Serialize per-parameter state
+ state: Array.from(this.state.entries()).map(([param, state]) => ({
+ paramId: getParamId(param),
+ param,
+ // Backward-compatible references
+ state
+ // Optimizer-specific state (momentum, etc.)
+ })),
+ // Serialize parameter groups and their options
+ paramGroups: this.paramGroups.map((group) => ({
+ params: group.params,
+ // Backward-compatible references
+ paramIds: group.params.map((param) => getParamId(param)),
+ options: group.options
+ // Hyperparameters for this group
+ }))
+ };
+ }
+ /**
+ * Load optimizer state from a state dictionary.
+ *
+ * Restores the optimizer to a previously saved state, including all
+ * per-parameter state and parameter group configurations. This is essential
+ * for resuming training from checkpoints.
+ *
+ * **Important:** The loaded state must be compatible with the current
+ * optimizer configuration (same parameters, same optimizer type).
+ *
+ * @param stateDict - State dictionary previously returned by stateDict()
+ *
+ * @example
+ * ```ts
+ * // Resume from checkpoint
+ * const checkpoint = loadCheckpoint('checkpoint.json');
+ * model.loadStateDict(checkpoint.model);
+ * optimizer.loadStateDict(checkpoint.optimizer);
+ * ```
+ */
+ loadStateDict(stateDict) {
+ const currentParams = this.paramGroups.flatMap((group) => group.params);
+ const currentParamCount = currentParams.length;
+ const paramLookup = /* @__PURE__ */ new Map();
+ for (let i = 0; i < currentParams.length; i++) {
+ paramLookup.set(currentParams[i], i);
+ }
+ if (Object.hasOwn(stateDict, "paramGroups")) {
+ const rawGroups = stateDict["paramGroups"];
+ if (!Array.isArray(rawGroups)) {
+ throw new DataValidationError("paramGroups must be an array");
+ }
+ const groupsArray = rawGroups;
+ if (groupsArray.length === 0) {
+ if (this.paramGroups.length !== 0) {
+ throw new DataValidationError("paramGroups cannot be empty");
+ }
+ this.paramGroups = [];
+ } else {
+ if (groupsArray.length !== this.paramGroups.length) {
+ throw new DataValidationError("paramGroups count mismatch");
+ }
+ const seenParamIds = /* @__PURE__ */ new Set();
+ let totalParamCount = 0;
+ let sawParamIds = false;
+ let sawNoParamIds = false;
+ const nextGroups = [];
+ groupsArray.forEach((rawGroup, index) => {
+ const groupRecord = ensureRecord(rawGroup, `paramGroups[${index}]`);
+ const optionsRaw = ensureRecord(groupRecord["options"], `paramGroups[${index}].options`);
+ const options = { ...this.defaults };
+ const optionsRecord = options;
+ const defaultsRecord = { ...this.defaults };
+ for (const [key, value] of Object.entries(optionsRaw)) {
+ if (Object.hasOwn(defaultsRecord, key)) {
+ const defaultVal = defaultsRecord[key];
+ const expectedType = typeof defaultVal;
+ const actualType = typeof value;
+ if (actualType !== expectedType) {
+ throw new DataValidationError(
+ `Type mismatch for option '${key}' in paramGroups[${index}]: expected ${expectedType}, got ${actualType}`
+ );
+ }
+ optionsRecord[key] = value;
+ }
+ }
+ const paramIdsRaw = groupRecord["paramIds"];
+ const paramsRaw = groupRecord["params"];
+ let paramIds;
+ if (paramIdsRaw !== void 0) {
+ paramIds = ensureIntegerArray(paramIdsRaw, `paramGroups[${index}].paramIds`);
+ sawParamIds = true;
+ } else {
+ sawNoParamIds = true;
+ }
+ let resolvedParams;
+ if (paramIds) {
+ for (const id of paramIds) {
+ if (id < 0 || id >= currentParamCount) {
+ throw new DataValidationError(`Invalid paramId ${id} in paramGroups`);
+ }
+ if (seenParamIds.has(id)) {
+ throw new DataValidationError(`Duplicate paramId ${id} in paramGroups`);
+ }
+ seenParamIds.add(id);
+ }
+ totalParamCount += paramIds.length;
+ resolvedParams = paramIds.map((id) => {
+ const param = currentParams[id];
+ if (!param) {
+ throw new DataValidationError(`Invalid paramId ${id} in paramGroups`);
+ }
+ return param;
+ });
+ }
+ if (paramsRaw !== void 0) {
+ if (!Array.isArray(paramsRaw)) {
+ throw new DataValidationError(`paramGroups[${index}].params must be an array`);
+ }
+ const resolvedFromParams = [];
+ let hasUnknown = false;
+ for (const paramRef of paramsRaw) {
+ const paramIndex = paramLookup.get(paramRef);
+ if (paramIndex === void 0) {
+ hasUnknown = true;
+ continue;
+ }
+ const param = currentParams[paramIndex];
+ if (!param) {
+ hasUnknown = true;
+ continue;
+ }
+ resolvedFromParams.push(param);
+ }
+ if (!hasUnknown) {
+ if (paramIds && paramIds.length !== resolvedFromParams.length) {
+ throw new DataValidationError("paramIds length does not match params length");
+ }
+ if (!resolvedParams) {
+ resolvedParams = resolvedFromParams;
+ }
+ }
+ }
+ if (!resolvedParams) {
+ throw new DataValidationError(`paramGroups[${index}] must include params or paramIds`);
+ }
+ if (paramIds === void 0) {
+ nextGroups.push({ params: resolvedParams, options });
+ } else {
+ nextGroups.push({ params: resolvedParams, options, paramIds });
+ }
+ });
+ if (sawParamIds && sawNoParamIds) {
+ throw new DataValidationError("paramIds must be provided for all parameter groups");
+ }
+ if (sawParamIds && totalParamCount !== currentParamCount) {
+ throw new DataValidationError(
+ `Parameter count mismatch: expected ${currentParamCount}, got ${totalParamCount}`
+ );
+ }
+ this.paramGroups = nextGroups.map((group) => ({
+ params: group.params,
+ options: group.options
+ }));
+ }
+ }
+ if (Object.hasOwn(stateDict, "state")) {
+ const rawState = stateDict["state"];
+ if (!Array.isArray(rawState)) {
+ throw new DataValidationError("state must be an array");
+ }
+ const stateArray = rawState;
+ this.state.clear();
+ stateArray.forEach((rawEntry, index) => {
+ const entryRecord = ensureRecord(rawEntry, `state[${index}]`);
+ if (!Object.hasOwn(entryRecord, "state")) {
+ throw new DataValidationError(`state[${index}].state is required`);
+ }
+ const entryStateValue = ensureRecord(entryRecord["state"], `state[${index}].state`);
+ if (!isStateRecord(entryStateValue)) {
+ throw new DataValidationError(`state[${index}].state must be an object`);
+ }
+ const paramIdRaw = entryRecord["paramId"];
+ const paramRaw = entryRecord["param"];
+ let resolvedParam;
+ if (paramIdRaw !== void 0) {
+ if (paramIdRaw === null || typeof paramIdRaw !== "number" || !Number.isInteger(paramIdRaw)) {
+ throw new DataValidationError(`Invalid paramId ${String(paramIdRaw)} in state`);
+ }
+ if (paramIdRaw < 0 || paramIdRaw >= currentParamCount) {
+ throw new DataValidationError(`Invalid paramId ${paramIdRaw} in state`);
+ }
+ const param = currentParams[paramIdRaw];
+ if (!param) {
+ throw new DataValidationError(`Invalid paramId ${paramIdRaw} in state`);
+ }
+ if (paramRaw !== void 0) {
+ const paramIndex = paramLookup.get(paramRaw);
+ if (paramIndex === void 0 || paramIndex !== paramIdRaw) {
+ throw new DataValidationError(`paramId ${paramIdRaw} does not match provided param`);
+ }
+ }
+ resolvedParam = param;
+ } else {
+ if (paramRaw === void 0) {
+ throw new DataValidationError("Missing param reference in state entry");
+ }
+ const paramIndex = paramLookup.get(paramRaw);
+ if (paramIndex === void 0) {
+ throw new DataValidationError("Unknown param reference in state entry");
+ }
+ const param = currentParams[paramIndex];
+ if (!param) {
+ throw new DataValidationError("Unknown param reference in state entry");
+ }
+ resolvedParam = param;
+ }
+ if (!resolvedParam) {
+ throw new DataValidationError(`Unable to resolve parameter for state[${index}]`);
+ }
+ if (!this.isState(entryStateValue)) {
+ throw new DataValidationError(`state[${index}].state has invalid structure`);
+ }
+ this.state.set(resolvedParam, entryStateValue);
+ });
+ }
+ }
+ };
+
+ // src/optim/_internal.ts
+ function isFloatTypedArray(value) {
+ return value instanceof Float32Array || value instanceof Float64Array;
+ }
+ function safeArrayAccess(array, index, context) {
+ if (index < 0 || index >= array.length) {
+ throw new IndexError(`Index ${index} out of bounds [0, ${array.length}) in ${context}`, {
+ index,
+ validRange: [0, array.length - 1]
+ });
+ }
+ const value = array[index];
+ if (value === void 0) {
+ throw new DeepboxError(`Unexpected undefined at index ${index} in ${context}`);
+ }
+ return value;
+ }
+ function assertFiniteNonNegative(name, value) {
+ if (!Number.isFinite(value) || value < 0) {
+ throw new InvalidParameterError(`Invalid ${name}: ${value}`, name, value);
+ }
+ }
+ function assertFinitePositive(name, value) {
+ if (!Number.isFinite(value) || value <= 0) {
+ throw new InvalidParameterError(`Invalid ${name}: ${value} (must be > 0)`, name, value);
+ }
+ }
+ function assertFinite(name, value) {
+ if (!Number.isFinite(value)) {
+ throw new InvalidParameterError(`Invalid ${name}: ${value}`, name, value);
+ }
+ }
+ function assertInRange(name, value, min, max) {
+ if (!Number.isFinite(value) || value < min || value >= max) {
+ throw new InvalidParameterError(
+ `Invalid ${name}: ${value} (must be in range [${min}, ${max}))`,
+ name,
+ value
+ );
+ }
+ }
+ function assertHasGradFloat(param, optimizerName) {
+ if (!param.requiresGrad) {
+ throw new InvalidParameterError(
+ "Cannot optimize a parameter with requiresGrad=false",
+ "requiresGrad",
+ false
+ );
+ }
+ const g = param.grad;
+ if (!g) {
+ throw new NotFittedError(
+ "Cannot optimize a parameter without a gradient. Did you forget backward()?"
+ );
+ }
+ const paramData = param.tensor.data;
+ const gradData = g.data;
+ if (!isFloatTypedArray(paramData) || !isFloatTypedArray(gradData)) {
+ throw new DTypeError(
+ `${optimizerName} optimizer supports float32 and float64 parameters and gradients only`
+ );
+ }
+ if (paramData.constructor !== gradData.constructor) {
+ throw new DTypeError(
+ `${optimizerName} optimizer requires parameter and gradient dtypes to match`
+ );
+ }
+ if (param.tensor.size !== g.size) {
+ throw new ShapeError(
+ `Gradient shape must match parameter shape (param: ${param.tensor.size}, grad: ${g.size})`
+ );
+ }
+ return {
+ grad: gradData,
+ gradOffset: g.offset,
+ param: paramData,
+ paramOffset: param.tensor.offset
+ };
+ }
+ function assertBufferSize(buffer, expectedSize, bufferName) {
+ if (buffer.length !== expectedSize) {
+ throw new DeepboxError(
+ `State buffer size mismatch for ${bufferName}: expected ${expectedSize}, got ${buffer.length}`
+ );
+ }
+ }
+
+ // src/optim/optimizers/adadelta.ts
+ var AdaDelta = class extends Optimizer {
+ _stepCount = 0;
+ get stepCount() {
+ return this._stepCount;
+ }
+ constructor(params, options = {}) {
+ const defaults = {
+ lr: options.lr ?? 1,
+ rho: options.rho ?? 0.9,
+ eps: options.eps ?? 1e-6,
+ weightDecay: options.weightDecay ?? 0
+ };
+ super(params, defaults);
+ assertFiniteNonNegative("learning rate", defaults.lr);
+ assertInRange("rho", defaults.rho, 0, 1);
+ assertFinitePositive("epsilon", defaults.eps);
+ assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
+ }
+ /**
+ * Get the current learning rate.
+ *
+ * @param groupIdx - Parameter group index (default: 0)
+ * @returns Current learning rate
+ */
+ getLearningRate(groupIdx = 0) {
+ const group = this.paramGroups[groupIdx];
+ if (!group) {
+ throw new InvalidParameterError(
+ `Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
+ "groupIdx",
+ groupIdx
+ );
+ }
+ return group.options.lr;
+ }
+ /**
+ * Set the learning rate for all parameter groups.
+ *
+ * @param lr - New learning rate
+ */
+ setLearningRate(lr) {
+ assertFiniteNonNegative("learning rate", lr);
+ for (const group of this.paramGroups) {
+ group.options.lr = lr;
+ }
+ }
+ isState(state) {
+ return state["squareAvg"] instanceof Float64Array && state["accDelta"] instanceof Float64Array;
+ }
+ step(closure) {
+ let loss;
+ if (closure) {
+ loss = closure();
+ }
+ this._stepCount++;
+ for (const group of this.paramGroups) {
+ const { lr, rho, eps, weightDecay } = group.options;
+ assertFiniteNonNegative("learning rate", lr);
+ assertInRange("rho", rho, 0, 1);
+ assertFinitePositive("epsilon", eps);
+ assertFiniteNonNegative("weight_decay value", weightDecay);
+ for (const param of group.params) {
+ const {
+ grad: gradData,
+ gradOffset: gOff,
+ param: pData,
+ paramOffset: pOff
+ } = assertHasGradFloat(param, "AdaDelta");
+ const size = param.tensor.size;
+ let state = this.state.get(param);
+ if (!state) {
+ state = {
+ squareAvg: new Float64Array(size),
+ accDelta: new Float64Array(size)
+ };
+ this.state.set(param, state);
+ }
+ assertBufferSize(state.squareAvg, size, "AdaDelta squareAvg");
+ assertBufferSize(state.accDelta, size, "AdaDelta accDelta");
+ for (let i = 0; i < size; i++) {
+ const gi0 = safeArrayAccess(gradData, gOff + i, "AdaDelta gradient");
+ const pi = safeArrayAccess(pData, pOff + i, "AdaDelta parameter");
+ assertFinite("gradient", gi0);
+ assertFinite("parameter", pi);
+ const gi = weightDecay !== 0 ? gi0 + weightDecay * pi : gi0;
+ const sq = safeArrayAccess(state.squareAvg, i, "AdaDelta squareAvg");
+ const sqNew = rho * sq + (1 - rho) * gi * gi;
+ state.squareAvg[i] = sqNew;
+ const std = Math.sqrt(sqNew + eps);
+ const accD = safeArrayAccess(state.accDelta, i, "AdaDelta accDelta");
+ const rmsUpdate = Math.sqrt(accD + eps);
+ const delta = rmsUpdate / std * gi;
+ state.accDelta[i] = rho * accD + (1 - rho) * delta * delta;
+ pData[pOff + i] = pi - lr * delta;
+ }
+ }
+ }
+ return loss;
+ }
+ };
+
+ // src/optim/optimizers/adagrad.ts
+ var Adagrad = class extends Optimizer {
+ _stepCount = 0;
+ get stepCount() {
+ return this._stepCount;
+ }
+ constructor(params, options = {}) {
+ const defaults = {
+ lr: options.lr ?? 0.01,
+ eps: options.eps ?? 1e-10,
+ weightDecay: options.weightDecay ?? 0,
+ lrDecay: options.lrDecay ?? 0
+ };
+ super(params, defaults);
+ assertFiniteNonNegative("learning rate", defaults.lr);
+ assertFinitePositive("epsilon", defaults.eps);
+ assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
+ assertFiniteNonNegative("lr_decay", defaults.lrDecay);
+ }
+ /**
+ * Get the current learning rate.
+ *
+ * @param groupIdx - Parameter group index (default: 0)
+ * @returns Current learning rate
+ */
+ getLearningRate(groupIdx = 0) {
+ const group = this.paramGroups[groupIdx];
+ if (!group) {
+ throw new InvalidParameterError(
+ `Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
+ "groupIdx",
+ groupIdx
+ );
+ }
+ return group.options.lr;
+ }
+ /**
+ * Set the learning rate for all parameter groups.
+ *
+ * @param lr - New learning rate
+ */
+ setLearningRate(lr) {
+ assertFiniteNonNegative("learning rate", lr);
+ for (const group of this.paramGroups) {
+ group.options.lr = lr;
+ }
+ }
+ isState(state) {
+ return typeof state["step"] === "number" && state["sum"] instanceof Float64Array;
+ }
+ step(closure) {
+ let loss;
+ if (closure) {
+ loss = closure();
+ }
+ this._stepCount++;
+ for (const group of this.paramGroups) {
+ const { lr, eps, weightDecay, lrDecay } = group.options;
+ assertFiniteNonNegative("learning rate", lr);
+ assertFinitePositive("epsilon", eps);
+ assertFiniteNonNegative("weight_decay value", weightDecay);
+ assertFiniteNonNegative("lr_decay", lrDecay);
+ for (const param of group.params) {
+ const {
+ grad: gradData,
+ gradOffset: gOff,
+ param: pData,
+ paramOffset: pOff
+ } = assertHasGradFloat(param, "Adagrad");
+ const size = param.tensor.size;
+ const existing = this.state.get(param);
+ const state = existing ?? (() => {
+ const next = {
+ step: 0,
+ sum: new Float64Array(size)
+ };
+ this.state.set(param, next);
+ return next;
+ })();
+ assertBufferSize(state.sum, size, "Adagrad sum");
+ state.step += 1;
+ const clr = lr / (1 + (state.step - 1) * lrDecay);
+ for (let i = 0; i < size; i++) {
+ const gi0 = safeArrayAccess(gradData, gOff + i, "Adagrad gradient");
+ const pi = safeArrayAccess(pData, pOff + i, "Adagrad parameter");
+ assertFinite("gradient", gi0);
+ assertFinite("parameter", pi);
+ const gi = weightDecay !== 0 ? gi0 + weightDecay * pi : gi0;
+ const sumVal = safeArrayAccess(state.sum, i, "Adagrad sum");
+ const sumNew = sumVal + gi * gi;
+ state.sum[i] = sumNew;
+ const std = Math.sqrt(sumNew) + eps;
+ pData[pOff + i] = pi - clr * (gi / std);
+ }
+ }
+ }
+ return loss;
+ }
+ };
+
+ // src/optim/optimizers/adam.ts
+ var Adam = class extends Optimizer {
+ _stepCount = 0;
+ get stepCount() {
+ return this._stepCount;
+ }
+ constructor(params, options = {}) {
+ const defaults = {
+ lr: options.lr ?? 1e-3,
+ beta1: options.beta1 ?? 0.9,
+ beta2: options.beta2 ?? 0.999,
+ eps: options.eps ?? 1e-8,
+ weightDecay: options.weightDecay ?? 0,
+ amsgrad: options.amsgrad ?? false
+ };
+ super(params, defaults);
+ assertFiniteNonNegative("learning rate", defaults.lr);
+ assertInRange("beta1", defaults.beta1, 0, 1);
+ assertInRange("beta2", defaults.beta2, 0, 1);
+ assertFinitePositive("epsilon", defaults.eps);
+ assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
+ }
+ /**
+ * Get the current learning rate.
+ *
+ * @param groupIdx - Parameter group index (default: 0)
+ * @returns Current learning rate
+ */
+ getLearningRate(groupIdx = 0) {
+ const group = this.paramGroups[groupIdx];
+ if (!group) {
+ throw new InvalidParameterError(
+ `Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
+ "groupIdx",
+ groupIdx
+ );
+ }
+ return group.options.lr;
+ }
+ /**
+ * Set the learning rate for all parameter groups.
+ *
+ * @param lr - New learning rate
+ */
+ setLearningRate(lr) {
+ assertFiniteNonNegative("learning rate", lr);
+ for (const group of this.paramGroups) {
+ group.options.lr = lr;
+ }
+ }
+ isState(state) {
+ const hasRequired = typeof state["step"] === "number" && state["expAvg"] instanceof Float64Array && state["expAvgSq"] instanceof Float64Array;
+ if (!hasRequired) return false;
+ if (state["maxExpAvgSq"] !== void 0 && !(state["maxExpAvgSq"] instanceof Float64Array)) {
+ return false;
+ }
+ return true;
+ }
+ step(closure) {
+ let loss;
+ if (closure) {
+ loss = closure();
+ }
+ this._stepCount++;
+ for (const group of this.paramGroups) {
+ const { lr, beta1, beta2, eps, weightDecay, amsgrad } = group.options;
+ assertFiniteNonNegative("learning rate", lr);
+ assertInRange("beta1", beta1, 0, 1);
+ assertInRange("beta2", beta2, 0, 1);
+ assertFinitePositive("epsilon", eps);
+ assertFiniteNonNegative("weight_decay value", weightDecay);
+ for (const param of group.params) {
+ const {
+ grad: gradData,
+ gradOffset,
+ param: paramData,
+ paramOffset
+ } = assertHasGradFloat(param, "Adam");
+ const size = param.tensor.size;
+ const existing = this.state.get(param);
+ const state = existing ?? (() => {
+ const next = {
+ step: 0,
+ expAvg: new Float64Array(size),
+ expAvgSq: new Float64Array(size),
+ ...amsgrad ? { maxExpAvgSq: new Float64Array(size) } : {}
+ };
+ this.state.set(param, next);
+ return next;
+ })();
+ assertBufferSize(state.expAvg, size, "Adam expAvg");
+ assertBufferSize(state.expAvgSq, size, "Adam expAvgSq");
+ if (amsgrad && state.maxExpAvgSq) {
+ assertBufferSize(state.maxExpAvgSq, size, "Adam maxExpAvgSq");
+ }
+ state.step += 1;
+ const biasCorrection1 = 1 - beta1 ** state.step;
+ const biasCorrection2 = 1 - beta2 ** state.step;
+ const stepSize = lr / biasCorrection1;
+ for (let i = 0; i < size; i++) {
+ const gi0 = safeArrayAccess(gradData, gradOffset + i, "Adam gradient");
+ const pi = safeArrayAccess(paramData, paramOffset + i, "Adam parameter");
+ assertFinite("gradient", gi0);
+ assertFinite("parameter", pi);
+ const gi = weightDecay !== 0 ? gi0 + weightDecay * pi : gi0;
+ const m = safeArrayAccess(state.expAvg, i, "Adam expAvg");
+ const v = safeArrayAccess(state.expAvgSq, i, "Adam expAvgSq");
+ const mNew = beta1 * m + (1 - beta1) * gi;
+ const vNew = beta2 * v + (1 - beta2) * gi * gi;
+ state.expAvg[i] = mNew;
+ state.expAvgSq[i] = vNew;
+ let denomSq = vNew;
+ if (amsgrad) {
+ const maxBuf = state.maxExpAvgSq;
+ if (!maxBuf) {
+ throw new DeepboxError("Internal error: AMSGrad enabled but maxExpAvgSq is missing");
+ }
+ const maxV = Math.max(safeArrayAccess(maxBuf, i, "Adam maxExpAvgSq"), vNew);
+ maxBuf[i] = maxV;
+ denomSq = maxV;
+ }
+ const denom = Math.sqrt(denomSq / biasCorrection2) + eps;
+ paramData[paramOffset + i] = pi - stepSize * (mNew / denom);
+ }
+ }
+ }
+ return loss;
+ }
+ };
+
+ // src/optim/optimizers/adamw.ts
+ var AdamW = class extends Optimizer {
+ /** Internal counter tracking total number of optimization steps */
+ _stepCount = 0;
+ /**
+ * Get the total number of optimization steps performed.
+ *
+ * @returns Number of steps taken
+ */
+ get stepCount() {
+ return this._stepCount;
+ }
+ /**
+ * Create a new AdamW optimizer.
+ *
+ * @param params - Iterable of parameters or parameter groups to optimize
+ * @param options - Optimization options
+ * @param options.lr - Learning rate (default: 0.001)
+ * @param options.beta1 - First moment decay rate (default: 0.9)
+ * @param options.beta2 - Second moment decay rate (default: 0.999)
+ * @param options.eps - Numerical stability constant (default: 1e-8)
+ * @param options.weightDecay - Weight decay coefficient (default: 0.01)
+ * @param options.amsgrad - Enable AMSGrad variant (default: false)
+ * @throws {InvalidParameterError} If a parameter is invalid
+ */
+ constructor(params, options = {}) {
+ const defaults = {
+ lr: options.lr ?? 1e-3,
+ beta1: options.beta1 ?? 0.9,
+ beta2: options.beta2 ?? 0.999,
+ eps: options.eps ?? 1e-8,
+ weightDecay: options.weightDecay ?? 0.01,
+ // Higher default than Adam
+ amsgrad: options.amsgrad ?? false
+ };
+ super(params, defaults);
+ assertFiniteNonNegative("learning rate", defaults.lr);
+ assertInRange("beta1", defaults.beta1, 0, 1);
+ assertInRange("beta2", defaults.beta2, 0, 1);
+ assertFinitePositive("epsilon", defaults.eps);
+ assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
+ }
+ /**
+ * Get the current learning rate.
+ *
+ * @param groupIdx - Parameter group index (default: 0)
+ * @returns Current learning rate
+ */
+ getLearningRate(groupIdx = 0) {
+ const group = this.paramGroups[groupIdx];
+ if (!group) {
+ throw new InvalidParameterError(
+ `Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
+ "groupIdx",
+ groupIdx
+ );
+ }
+ return group.options.lr;
+ }
+ /**
+ * Set the learning rate for all parameter groups.
+ *
+ * @param lr - New learning rate
+ */
+ setLearningRate(lr) {
+ assertFiniteNonNegative("learning rate", lr);
+ for (const group of this.paramGroups) {
+ group.options.lr = lr;
+ }
+ }
+ /**
+ * Perform a single optimization step (parameter update).
+ *
+ * Implements the AdamW update rule with decoupled weight decay.
+ *
+ * @param closure - Optional closure that reevaluates the model and returns the loss
+ * @returns Loss value if closure is provided, undefined otherwise
+ */
+ isState(state) {
+ const hasRequired = typeof state["step"] === "number" && state["expAvg"] instanceof Float64Array && state["expAvgSq"] instanceof Float64Array;
+ if (!hasRequired) return false;
+ if (state["maxExpAvgSq"] !== void 0 && !(state["maxExpAvgSq"] instanceof Float64Array)) {
+ return false;
+ }
+ return true;
+ }
+ step(closure) {
+ let loss;
+ if (closure) {
+ loss = closure();
+ }
+ this._stepCount++;
+ for (const group of this.paramGroups) {
+ const { lr, beta1, beta2, eps, weightDecay, amsgrad } = group.options;
+ assertFiniteNonNegative("learning rate", lr);
+ assertInRange("beta1", beta1, 0, 1);
+ assertInRange("beta2", beta2, 0, 1);
+ assertFinitePositive("epsilon", eps);
+ assertFiniteNonNegative("weight_decay value", weightDecay);
+ for (const param of group.params) {
+ const {
+ grad,
+ gradOffset,
+ param: pData,
+ paramOffset: pOff
+ } = assertHasGradFloat(param, "AdamW");
+ const size = param.tensor.size;
+ const existing = this.state.get(param);
+ const state = existing ?? (() => {
+ const next = {
+ step: 0,
+ expAvg: new Float64Array(size),
+ // First moment
+ expAvgSq: new Float64Array(size),
+ // Second moment
+ ...amsgrad ? { maxExpAvgSq: new Float64Array(size) } : {}
+ // AMSGrad buffer
+ };
+ this.state.set(param, next);
+ return next;
+ })();
+ assertBufferSize(state.expAvg, size, "AdamW expAvg");
+ assertBufferSize(state.expAvgSq, size, "AdamW expAvgSq");
+ if (amsgrad && state.maxExpAvgSq) {
+ assertBufferSize(state.maxExpAvgSq, size, "AdamW maxExpAvgSq");
+ }
+ state.step += 1;
+ const biasCorrection1 = 1 - beta1 ** state.step;
+ const biasCorrection2 = 1 - beta2 ** state.step;
+ const stepSize = lr / biasCorrection1;
+ for (let i = 0; i < size; i++) {
+ const gi = safeArrayAccess(grad, gradOffset + i, "AdamW gradient");
+ const pi = safeArrayAccess(pData, pOff + i, "AdamW parameter");
+ assertFinite("gradient", gi);
+ assertFinite("parameter", pi);
+ const m = safeArrayAccess(state.expAvg, i, "AdamW expAvg");
+ const v = safeArrayAccess(state.expAvgSq, i, "AdamW expAvgSq");
+ const mNew = beta1 * m + (1 - beta1) * gi;
+ const vNew = beta2 * v + (1 - beta2) * gi * gi;
+ state.expAvg[i] = mNew;
+ state.expAvgSq[i] = vNew;
+ let denomSq = vNew;
+ if (amsgrad) {
+ const maxBuf = state.maxExpAvgSq;
+ if (!maxBuf) {
+ throw new DeepboxError("Internal error: AMSGrad enabled but maxExpAvgSq is missing");
+ }
+ const maxV = Math.max(safeArrayAccess(maxBuf, i, "AdamW maxExpAvgSq"), vNew);
+ maxBuf[i] = maxV;
+ denomSq = maxV;
+ }
+ const denom = Math.sqrt(denomSq / biasCorrection2) + eps;
+ pData[pOff + i] = pi - stepSize * (mNew / denom) - lr * weightDecay * pi;
+ }
+ }
+ }
+ return loss;
+ }
+ };
+
+ // src/optim/optimizers/nadam.ts
+ var Nadam = class extends Optimizer {
+ _stepCount = 0;
+ get stepCount() {
+ return this._stepCount;
+ }
+ constructor(params, options = {}) {
+ const defaults = {
+ lr: options.lr ?? 2e-3,
+ beta1: options.beta1 ?? 0.9,
+ beta2: options.beta2 ?? 0.999,
+ eps: options.eps ?? 1e-8,
+ weightDecay: options.weightDecay ?? 0,
+ momentumDecay: options.momentumDecay ?? 4e-3
+ };
+ super(params, defaults);
+ assertFiniteNonNegative("learning rate", defaults.lr);
+ assertInRange("beta1", defaults.beta1, 0, 1);
+ assertInRange("beta2", defaults.beta2, 0, 1);
+ assertFinitePositive("epsilon", defaults.eps);
+ assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
+ assertFiniteNonNegative("momentum_decay", defaults.momentumDecay);
+ }
+ /**
+ * Get the current learning rate.
+ *
+ * @param groupIdx - Parameter group index (default: 0)
+ * @returns Current learning rate
+ */
+ getLearningRate(groupIdx = 0) {
+ const group = this.paramGroups[groupIdx];
+ if (!group) {
+ throw new InvalidParameterError(
+ `Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
+ "groupIdx",
+ groupIdx
+ );
+ }
+ return group.options.lr;
+ }
+ /**
+ * Set the learning rate for all parameter groups.
+ *
+ * @param lr - New learning rate
+ */
+ setLearningRate(lr) {
+ assertFiniteNonNegative("learning rate", lr);
+ for (const group of this.paramGroups) {
+ group.options.lr = lr;
+ }
+ }
+ isState(state) {
+ return typeof state["step"] === "number" && state["expAvg"] instanceof Float64Array && state["expAvgSq"] instanceof Float64Array && typeof state["muProduct"] === "number";
+ }
+ step(closure) {
+ let loss;
+ if (closure) {
+ loss = closure();
+ }
+ this._stepCount++;
+ for (const group of this.paramGroups) {
+ const { lr, beta1, beta2, eps, weightDecay, momentumDecay } = group.options;
+ assertFiniteNonNegative("learning rate", lr);
+ assertInRange("beta1", beta1, 0, 1);
+ assertInRange("beta2", beta2, 0, 1);
+ assertFinitePositive("epsilon", eps);
+ assertFiniteNonNegative("weight_decay value", weightDecay);
+ assertFiniteNonNegative("momentum_decay", momentumDecay);
+ for (const param of group.params) {
+ const {
+ grad: gradData,
+ gradOffset: gOff,
+ param: pData,
+ paramOffset: pOff
+ } = assertHasGradFloat(param, "Nadam");
+ const size = param.tensor.size;
+ let state = this.state.get(param);
+ if (!state) {
+ state = {
+ step: 0,
+ expAvg: new Float64Array(size),
+ expAvgSq: new Float64Array(size),
+ muProduct: 1
+ };
+ this.state.set(param, state);
+ }
+ assertBufferSize(state.expAvg, size, "Nadam expAvg");
+ assertBufferSize(state.expAvgSq, size, "Nadam expAvgSq");
+ state.step++;
+ const t = state.step;
+ const biasCorrection2 = 1 - beta2 ** t;
+ const mu = beta1 * (1 - 0.5 * 0.96 ** (t * momentumDecay));
+ const muNext = beta1 * (1 - 0.5 * 0.96 ** ((t + 1) * momentumDecay));
+ const muProduct = state.muProduct * mu;
+ const muProductNext = muProduct * muNext;
+ state.muProduct = muProduct;
+ for (let i = 0; i < size; i++) {
+ const gi0 = safeArrayAccess(gradData, gOff + i, "Nadam gradient");
+ const pi = safeArrayAccess(pData, pOff + i, "Nadam parameter");
+ assertFinite("gradient", gi0);
+ assertFinite("parameter", pi);
+ const gi = weightDecay !== 0 ? gi0 + weightDecay * pi : gi0;
+ const m = safeArrayAccess(state.expAvg, i, "Nadam expAvg");
+ const mNew = beta1 * m + (1 - beta1) * gi;
+ state.expAvg[i] = mNew;
+ const v = safeArrayAccess(state.expAvgSq, i, "Nadam expAvgSq");
+ const vNew = beta2 * v + (1 - beta2) * gi * gi;
+ state.expAvgSq[i] = vNew;
+ const denom = Math.sqrt(vNew / biasCorrection2) + eps;
+ const mHatNext = mNew / (1 - muProductNext);
+ const gHat = gi / (1 - muProduct);
+ const mNesterov = muNext * mHatNext + (1 - mu) * gHat;
+ pData[pOff + i] = pi - lr * mNesterov / denom;
+ }
+ }
+ }
+ return loss;
+ }
+ };
+
+ // src/optim/optimizers/rmsprop.ts
+ var RMSprop = class extends Optimizer {
+ /** Internal counter tracking total number of optimization steps */
+ _stepCount = 0;
+ /**
+ * Get the total number of optimization steps performed.
+ *
+ * @returns Number of steps taken
+ */
+ get stepCount() {
+ return this._stepCount;
+ }
+ /**
+ * Create a new RMSprop optimizer.
+ *
+ * @param params - Iterable of parameters or parameter groups to optimize
+ * @param options - Optimization options
+ * @param options.lr - Learning rate (default: 0.01)
+ * @param options.alpha - Smoothing constant (default: 0.99)
+ * @param options.eps - Numerical stability constant (default: 1e-8)
+ * @param options.weightDecay - Weight decay coefficient (default: 0)
+ * @param options.momentum - Momentum factor (default: 0)
+ * @param options.centered - Use centered variant (default: false)
+ * @throws {InvalidParameterError} If a parameter is invalid
+ */
+ constructor(params, options = {}) {
+ const defaults = {
+ lr: options.lr ?? 0.01,
+ alpha: options.alpha ?? 0.99,
+ eps: options.eps ?? 1e-8,
+ weightDecay: options.weightDecay ?? 0,
+ momentum: options.momentum ?? 0,
+ centered: options.centered ?? false
+ };
+ super(params, defaults);
+ assertFiniteNonNegative("learning rate", defaults.lr);
+ if (!Number.isFinite(defaults.alpha) || defaults.alpha < 0 || defaults.alpha > 1) {
+ throw new InvalidParameterError(
+ `Invalid alpha: ${defaults.alpha} (must be in range [0, 1])`,
+ "alpha",
+ defaults.alpha
+ );
+ }
+ assertFinitePositive("epsilon", defaults.eps);
+ assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
+ assertFiniteNonNegative("momentum value", defaults.momentum);
+ }
+ /**
+ * Get the current learning rate.
+ *
+ * @param groupIdx - Parameter group index (default: 0)
+ * @returns Current learning rate
+ */
+ getLearningRate(groupIdx = 0) {
+ const group = this.paramGroups[groupIdx];
+ if (!group) {
+ throw new InvalidParameterError(
+ `Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
+ "groupIdx",
+ groupIdx
+ );
+ }
+ return group.options.lr;
+ }
+ /**
+ * Set the learning rate for all parameter groups.
+ *
+ * @param lr - New learning rate
+ */
+ setLearningRate(lr) {
+ assertFiniteNonNegative("learning rate", lr);
+ for (const group of this.paramGroups) {
+ group.options.lr = lr;
+ }
+ }
+ isState(state) {
+ if (!(state["squareAvg"] instanceof Float64Array)) return false;
+ if (state["momentumBuffer"] !== void 0 && !(state["momentumBuffer"] instanceof Float64Array)) {
+ return false;
+ }
+ if (state["gradAvg"] !== void 0 && !(state["gradAvg"] instanceof Float64Array)) {
+ return false;
+ }
+ return true;
+ }
+ step(closure) {
+ let loss;
+ if (closure) {
+ loss = closure();
+ }
+ this._stepCount++;
+ for (const group of this.paramGroups) {
+ const { lr, alpha, eps, weightDecay, momentum, centered } = group.options;
+ assertFiniteNonNegative("learning rate", lr);
+ if (!Number.isFinite(alpha) || alpha < 0 || alpha > 1) {
+ throw new InvalidParameterError(
+ `Invalid alpha: ${alpha} (must be in range [0, 1])`,
+ "alpha",
+ alpha
+ );
+ }
+ assertFinitePositive("epsilon", eps);
+ assertFiniteNonNegative("weight_decay value", weightDecay);
+ assertFiniteNonNegative("momentum value", momentum);
+ for (const param of group.params) {
+ const {
+ grad: gradData,
+ gradOffset: gOff,
+ param: pData,
+ paramOffset: pOff
+ } = assertHasGradFloat(param, "RMSprop");
+ const size = param.tensor.size;
+ let state = this.state.get(param);
+ if (!state) {
+ state = {
+ squareAvg: new Float64Array(size)
+ };
+ this.state.set(param, state);
+ }
+ if (momentum > 0 && !state.momentumBuffer) {
+ state.momentumBuffer = new Float64Array(size);
+ }
+ if (centered && !state.gradAvg) {
+ state.gradAvg = new Float64Array(size);
+ }
+ assertBufferSize(state.squareAvg, size, "RMSprop squareAvg");
+ if (momentum > 0 && state.momentumBuffer) {
+ assertBufferSize(state.momentumBuffer, size, "RMSprop momentumBuffer");
+ }
+ if (centered && state.gradAvg) {
+ assertBufferSize(state.gradAvg, size, "RMSprop gradAvg");
+ }
+ for (let i = 0; i < size; i++) {
+ const gi = safeArrayAccess(gradData, gOff + i, "RMSprop gradient");
+ const pi = safeArrayAccess(pData, pOff + i, "RMSprop parameter");
+ assertFinite("gradient", gi);
+ assertFinite("parameter", pi);
+ let grad = gi;
+ if (weightDecay !== 0) {
+ grad = grad + weightDecay * pi;
+ }
+ const sqAvg = safeArrayAccess(state.squareAvg, i, "RMSprop squareAvg");
+ const sqAvgNew = alpha * sqAvg + (1 - alpha) * grad * grad;
+ state.squareAvg[i] = sqAvgNew;
+ let avg = sqAvgNew;
+ if (centered) {
+ const gAvg = state.gradAvg ? safeArrayAccess(state.gradAvg, i, "RMSprop gradAvg") : 0;
+ const gAvgNew = alpha * gAvg + (1 - alpha) * grad;
+ if (state.gradAvg) state.gradAvg[i] = gAvgNew;
+ avg = sqAvgNew - gAvgNew * gAvgNew;
+ }
+ const denom = centered ? Math.sqrt(Math.max(avg, 0) + eps) : Math.sqrt(avg) + eps;
+ const normalizedGrad = grad / denom;
+ if (momentum > 0) {
+ const buf = state.momentumBuffer ? safeArrayAccess(state.momentumBuffer, i, "RMSprop momentumBuffer") : 0;
+ const bufNew = momentum * buf + normalizedGrad;
+ if (state.momentumBuffer) state.momentumBuffer[i] = bufNew;
+ pData[pOff + i] = pi - lr * bufNew;
+ } else {
+ pData[pOff + i] = pi - lr * normalizedGrad;
+ }
+ }
+ }
+ }
+ return loss;
+ }
+ };
+
+ // src/optim/optimizers/sgd.ts
1288
+ var SGD = class extends Optimizer {
1289
+ /** Internal counter tracking total number of optimization steps */
1290
+ _stepCount = 0;
1291
+ get stepCount() {
1292
+ return this._stepCount;
1293
+ }
1294
+ /**
1295
+ * Create a new SGD optimizer.
1296
+ *
1297
+ * @param params - Iterable of parameters or parameter groups to optimize
1298
+ * @param options - Optimization options
1299
+ * @param options.lr - Learning rate (default: 0.01)
1300
+ * @param options.momentum - Momentum factor (default: 0)
1301
+ * @param options.dampening - Dampening for momentum (default: 0)
1302
+ * @param options.weightDecay - Weight decay (L2 penalty) (default: 0)
1303
+ * @param options.nesterov - Enable Nesterov momentum (default: false)
1304
+ */
1305
+ constructor(params, options = {}) {
1306
+ const defaults = {
1307
+ lr: options.lr ?? 0.01,
1308
+ momentum: options.momentum ?? 0,
1309
+ dampening: options.dampening ?? 0,
1310
+ weightDecay: options.weightDecay ?? 0,
1311
+ nesterov: options.nesterov ?? false
1312
+ };
1313
+ super(params, defaults);
1314
+ assertFiniteNonNegative("learning rate", defaults.lr);
1315
+ assertFiniteNonNegative("momentum value", defaults.momentum);
1316
+ assertFiniteNonNegative("dampening", defaults.dampening);
1317
+ assertFiniteNonNegative("weight_decay value", defaults.weightDecay);
1318
+ if (defaults.nesterov && (defaults.momentum <= 0 || defaults.dampening !== 0)) {
1319
+ throw new InvalidParameterError(
1320
+ "Nesterov momentum requires a momentum and zero dampening",
1321
+ "nesterov",
1322
+ {
1323
+ momentum: defaults.momentum,
1324
+ dampening: defaults.dampening,
1325
+ nesterov: defaults.nesterov
1326
+ }
1327
+ );
1328
+ }
1329
+ }
1330
+ /**
1331
+ * Perform a single optimization step.
1332
+ *
1333
+ * Implements the SGD update rule with optional momentum and weight decay.
1334
+ *
1335
+ * @param closure - Optional closure that reevaluates the model and returns the loss
1336
+ * @returns Loss value if closure is provided
1337
+ */
1338
+ isState(state) {
1339
+ if (state["momentumBuffer"] !== void 0 && !(state["momentumBuffer"] instanceof Float64Array)) {
1340
+ return false;
1341
+ }
1342
+ return true;
1343
+ }
1344
+ step(closure) {
1345
+ let loss;
+ if (closure) {
+ loss = closure();
+ }
+ this._stepCount++;
+ for (const group of this.paramGroups) {
+ const { lr, momentum, dampening, weightDecay, nesterov } = group.options;
+ assertFiniteNonNegative("learning rate", lr);
+ assertFiniteNonNegative("momentum value", momentum);
+ assertFiniteNonNegative("dampening", dampening);
+ assertFiniteNonNegative("weight_decay value", weightDecay);
+ if (nesterov && (momentum <= 0 || dampening !== 0)) {
+ throw new InvalidParameterError(
+ "Nesterov momentum requires a momentum and zero dampening",
+ "nesterov",
+ { momentum, dampening, nesterov }
+ );
+ }
+ for (const param of group.params) {
+ const {
+ grad: gradData,
+ gradOffset,
+ param: paramData,
+ paramOffset
+ } = assertHasGradFloat(param, "SGD");
+ const size = param.tensor.size;
+ let state = this.state.get(param);
+ if (!state) {
+ state = {};
+ this.state.set(param, state);
+ }
+ let momentumBuffer;
+ if (momentum !== 0) {
+ if (!state.momentumBuffer) {
+ state.momentumBuffer = new Float64Array(size);
+ }
+ momentumBuffer = state.momentumBuffer;
+ }
+ for (let i = 0; i < size; i++) {
+ const gi = safeArrayAccess(gradData, gradOffset + i, "SGD gradient");
+ const pi = safeArrayAccess(paramData, paramOffset + i, "SGD parameter");
+ assertFinite("gradient", gi);
+ assertFinite("parameter", pi);
+ let d = gi;
+ if (weightDecay !== 0) {
+ d = d + weightDecay * pi;
+ }
+ if (momentumBuffer) {
+ const bPrev = safeArrayAccess(momentumBuffer, i, "SGD momentum buffer");
+ const bNew = momentum * bPrev + (1 - dampening) * d;
+ momentumBuffer[i] = bNew;
+ d = nesterov ? d + momentum * bNew : bNew;
+ }
+ paramData[paramOffset + i] = pi - lr * d;
+ }
+ }
+ }
+ return loss;
+ }
+ /**
+ * Get the current learning rate.
+ *
+ * @param groupIdx - Parameter group index (default: 0)
+ * @returns Current learning rate
+ */
+ getLearningRate(groupIdx = 0) {
+ const group = this.paramGroups[groupIdx];
+ if (!group) {
+ throw new InvalidParameterError(
+ `Invalid group index: ${groupIdx} (valid range: [0, ${this.paramGroups.length}))`,
+ "groupIdx",
+ groupIdx
+ );
+ }
+ return group.options.lr;
+ }
+ /**
+ * Set the learning rate for all parameter groups.
+ *
+ * @param lr - New learning rate
+ */
+ setLearningRate(lr) {
+ assertFiniteNonNegative("learning rate", lr);
+ for (const group of this.paramGroups) {
+ group.options.lr = lr;
+ }
+ }
+ };
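A short usage sketch of the SGD class above. The `deepbox/optim` import path and the source of the parameter iterable are assumptions (not confirmed by this diff); the option names and defaults mirror the constructor's doc comment.

```ts
// Illustrative only: import path and parameter source are assumptions.
import { SGD } from "deepbox/optim"; // hypothetical subpath export

declare function getParameters(): Iterable<unknown>; // stand-in for model parameters

const opt = new SGD(getParameters(), {
  lr: 0.01,          // default 0.01
  momentum: 0.9,     // default 0
  weightDecay: 1e-4, // default 0
  nesterov: true,    // requires momentum > 0 and dampening === 0
});

// One training iteration: compute gradients elsewhere, then apply the update.
opt.step();
console.log(opt.getLearningRate()); // lr of parameter group 0
opt.setLearningRate(0.005);         // applies to all parameter groups
```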
1433
+
+ // src/optim/schedulers.ts
+ function isRecord2(value) {
+ return typeof value === "object" && value !== null;
+ }
+ function resolveGroupLr(group, index) {
+ const options = isRecord2(group.options) ? group.options : void 0;
+ const lrValue = group.lr ?? options?.["lr"];
+ if (typeof lrValue !== "number" || !Number.isFinite(lrValue) || lrValue < 0) {
+ throw new InvalidParameterError(
+ `optimizer.paramGroups[${index}].lr must be finite and >= 0`,
+ `optimizer.paramGroups[${index}].lr`,
+ lrValue
+ );
+ }
+ return lrValue;
+ }
+ function setGroupLr(group, lr) {
+ if (isRecord2(group.options)) {
+ group.options["lr"] = lr;
+ }
+ if ("lr" in group) {
+ group.lr = lr;
+ }
+ if (!("lr" in group) && !isRecord2(group.options)) {
+ group.lr = lr;
+ }
+ }
+ function validateLastEpoch(value) {
+ if (!Number.isInteger(value) || value < -1) {
+ throw new InvalidParameterError("lastEpoch must be an integer >= -1", "lastEpoch", value);
+ }
+ return value;
+ }
+ function validateFiniteNumber(value, name) {
+ if (!Number.isFinite(value)) {
+ throw new InvalidParameterError(`${name} must be finite`, name, value);
+ }
+ return value;
+ }
+ function validatePositiveNumber(value, name) {
+ if (!Number.isFinite(value) || value <= 0) {
+ throw new InvalidParameterError(`${name} must be > 0`, name, value);
+ }
+ return value;
+ }
+ function validatePositiveInteger(value, name) {
+ if (!Number.isInteger(value) || value <= 0) {
+ throw new InvalidParameterError(`${name} must be a positive integer`, name, value);
+ }
+ return value;
+ }
+ function validateNonNegativeNumber(value, name) {
+ if (!Number.isFinite(value) || value < 0) {
+ throw new InvalidParameterError(`${name} must be >= 0`, name, value);
+ }
+ return value;
+ }
+ function validateNonNegativeInteger(value, name) {
+ if (!Number.isInteger(value) || value < 0) {
+ throw new InvalidParameterError(`${name} must be a non-negative integer`, name, value);
+ }
+ return value;
+ }
+ function validateOptimizer(optimizer) {
+ if (!optimizer || typeof optimizer !== "object" || !Array.isArray(optimizer.paramGroups)) {
+ throw new InvalidParameterError(
+ "optimizer must expose paramGroups array",
+ "optimizer",
+ optimizer
+ );
+ }
+ if (optimizer.paramGroups.length === 0) {
+ throw new InvalidParameterError(
+ "optimizer.paramGroups must contain at least one group",
+ "optimizer.paramGroups",
+ optimizer.paramGroups
+ );
+ }
+ for (let i = 0; i < optimizer.paramGroups.length; i++) {
+ const group = optimizer.paramGroups[i];
+ if (!group || typeof group !== "object") {
+ throw new InvalidParameterError(
+ `optimizer.paramGroups[${i}] must be an object`,
+ "optimizer.paramGroups",
+ group
+ );
+ }
+ if (!Array.isArray(group.params)) {
+ throw new InvalidParameterError(
+ `optimizer.paramGroups[${i}].params must be an array`,
+ `optimizer.paramGroups[${i}].params`,
+ group.params
+ );
+ }
+ resolveGroupLr(group, i);
+ }
+ }
1531
+ function validateMilestones(milestones) {
+ if (!Array.isArray(milestones) || milestones.length === 0) {
+ throw new InvalidParameterError(
+ "milestones must be a non-empty array of non-negative integers",
+ "milestones",
+ milestones
+ );
+ }
+ const sorted = [...milestones].sort((a, b) => a - b);
+ for (let i = 0; i < sorted.length; i++) {
+ const value = sorted[i];
+ if (value === void 0 || !Number.isInteger(value) || value < 0) {
+ throw new InvalidParameterError(
+ "milestones must contain non-negative integers only",
+ "milestones",
+ milestones
+ );
+ }
+ if (i > 0) {
+ const prev = sorted[i - 1];
+ if (prev !== void 0 && value <= prev) {
+ throw new InvalidParameterError(
+ "milestones must be strictly increasing",
+ "milestones",
+ milestones
+ );
+ }
+ }
+ }
+ return sorted;
+ }
+ var LRScheduler = class {
+ optimizer;
+ lastEpoch;
+ baseLrs;
+ constructor(optimizer, lastEpoch = -1) {
+ validateOptimizer(optimizer);
+ this.lastEpoch = validateLastEpoch(lastEpoch);
+ this.optimizer = optimizer;
+ this.baseLrs = optimizer.paramGroups.map((group, index) => resolveGroupLr(group, index));
+ }
+ initializeFromLastEpoch(lastEpoch) {
+ const validated = validateLastEpoch(lastEpoch);
+ if (validated < 0) {
+ return;
+ }
+ this.lastEpoch = -1;
+ for (let i = 0; i <= validated; i++) {
+ this.step();
+ }
+ }
+ /**
+ * Perform a scheduler step, updating learning rates.
+ *
+ * Should be called once per epoch after the optimizer step.
+ */
+ step() {
+ this.lastEpoch++;
+ const newLrs = this.getLr();
+ for (let i = 0; i < this.optimizer.paramGroups.length; i++) {
+ const group = this.optimizer.paramGroups[i];
+ if (group) {
+ const next = newLrs[i];
+ if (next !== void 0) {
+ setGroupLr(group, next);
+ }
+ }
+ }
+ }
+ /**
+ * Get the current learning rates for all parameter groups.
+ */
+ getLastLr() {
+ return this.optimizer.paramGroups.map((group, index) => resolveGroupLr(group, index));
+ }
+ /**
+ * Get current epoch number.
+ */
+ get epoch() {
+ return this.lastEpoch;
+ }
+ };
+ var StepLR = class extends LRScheduler {
+ stepSize;
+ gamma;
+ constructor(optimizer, options) {
+ const stepSize = validatePositiveInteger(options.stepSize, "stepSize");
+ const gamma = validatePositiveNumber(options.gamma ?? 0.1, "gamma");
+ const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
+ super(optimizer, -1);
+ this.stepSize = stepSize;
+ this.gamma = gamma;
+ this.initializeFromLastEpoch(lastEpoch);
+ }
+ getLr() {
+ const factor = this.gamma ** Math.floor(this.lastEpoch / this.stepSize);
+ return this.baseLrs.map((lr) => lr * factor);
+ }
+ };
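A sketch of the per-epoch pattern described in LRScheduler.step's doc comment, using StepLR as the concrete scheduler. The import path is an assumption; the option names follow the constructor above.

```ts
// Illustrative only: import path and parameter source are assumptions.
import { SGD, StepLR } from "deepbox/optim";

declare const params: Iterable<unknown>; // stand-in for model parameters
const optimizer = new SGD(params, { lr: 0.1 });
const scheduler = new StepLR(optimizer, { stepSize: 30, gamma: 0.1 });

for (let epoch = 0; epoch < 90; epoch++) {
  // ... forward/backward passes and optimizer.step() for each batch ...
  scheduler.step(); // once per epoch, after the optimizer steps
}
// lr is 0.1 for epochs 0-29, 0.01 for 30-59, 0.001 for 60-89
console.log(scheduler.getLastLr());
```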
1630
+ var ExponentialLR = class extends LRScheduler {
+ gamma;
+ constructor(optimizer, options) {
+ const gamma = validatePositiveNumber(options.gamma, "gamma");
+ const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
+ super(optimizer, -1);
+ this.gamma = gamma;
+ this.initializeFromLastEpoch(lastEpoch);
+ }
+ getLr() {
+ return this.baseLrs.map((lr) => lr * this.gamma ** this.lastEpoch);
+ }
+ };
+ var CosineAnnealingLR = class extends LRScheduler {
+ T_max;
+ etaMin;
+ constructor(optimizer, options) {
+ const tMax = validatePositiveInteger(options.T_max, "T_max");
+ const etaMin = validateNonNegativeNumber(options.etaMin ?? 0, "etaMin");
+ const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
+ super(optimizer, -1);
+ this.T_max = tMax;
+ this.etaMin = etaMin;
+ this.initializeFromLastEpoch(lastEpoch);
+ }
+ getLr() {
+ return this.baseLrs.map((baseLr) => {
+ return this.etaMin + (baseLr - this.etaMin) * (1 + Math.cos(Math.PI * this.lastEpoch / this.T_max)) / 2;
+ });
+ }
+ };
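For reference, CosineAnnealingLR.getLr above is the standard cosine-annealing rule; with t the current lastEpoch it reads:

```latex
\eta_t \;=\; \eta_{\min} \;+\; \left(\eta_{\text{base}} - \eta_{\min}\right)\,\frac{1 + \cos\!\left(\pi\, t / T_{\max}\right)}{2}
```

so the learning rate starts at the base value (t = 0), reaches the midpoint at t = T_max / 2, and lands on etaMin at t = T_max.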
1661
+ var MultiStepLR = class extends LRScheduler {
+ sortedMilestones;
+ gamma;
+ constructor(optimizer, options) {
+ const milestones = validateMilestones(options.milestones);
+ const gamma = validatePositiveNumber(options.gamma ?? 0.1, "gamma");
+ const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
+ super(optimizer, -1);
+ this.sortedMilestones = milestones;
+ this.gamma = gamma;
+ this.initializeFromLastEpoch(lastEpoch);
+ }
+ getLr() {
+ let numDecays = 0;
+ for (const milestone of this.sortedMilestones) {
+ if (this.lastEpoch >= milestone) {
+ numDecays++;
+ }
+ }
+ const factor = this.gamma ** numDecays;
+ return this.baseLrs.map((lr) => lr * factor);
+ }
+ };
+ var LinearLR = class extends LRScheduler {
+ startFactor;
+ endFactor;
+ totalIters;
+ constructor(optimizer, options) {
+ const startFactor = validatePositiveNumber(options.startFactor ?? 1 / 3, "startFactor");
+ const endFactor = validatePositiveNumber(options.endFactor ?? 1, "endFactor");
+ const totalIters = validatePositiveInteger(options.totalIters, "totalIters");
+ const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
+ super(optimizer, -1);
+ this.startFactor = startFactor;
+ this.endFactor = endFactor;
+ this.totalIters = totalIters;
+ this.initializeFromLastEpoch(lastEpoch);
+ }
+ getLr() {
+ if (this.lastEpoch >= this.totalIters) {
+ return this.baseLrs.map((lr) => lr * this.endFactor);
+ }
+ const factor = this.startFactor + (this.endFactor - this.startFactor) * (this.lastEpoch / this.totalIters);
+ return this.baseLrs.map((lr) => lr * factor);
+ }
+ };
1707
+ var ReduceLROnPlateau = class {
+ optimizer;
+ mode;
+ factor;
+ patience;
+ threshold;
+ cooldown;
+ minLr;
+ best;
+ numBadEpochs;
+ cooldownCounter;
+ constructor(optimizer, options = {}) {
+ this.optimizer = optimizer;
+ validateOptimizer(optimizer);
+ this.mode = options.mode ?? "min";
+ if (this.mode !== "min" && this.mode !== "max") {
+ throw new InvalidParameterError("mode must be 'min' or 'max'", "mode", options.mode);
+ }
+ this.factor = validateFiniteNumber(options.factor ?? 0.1, "factor");
+ if (this.factor <= 0 || this.factor >= 1) {
+ throw new InvalidParameterError(
+ "factor must be in the interval (0, 1)",
+ "factor",
+ this.factor
+ );
+ }
+ this.patience = validateNonNegativeInteger(options.patience ?? 10, "patience");
+ this.threshold = validateNonNegativeNumber(options.threshold ?? 1e-4, "threshold");
+ this.cooldown = validateNonNegativeInteger(options.cooldown ?? 0, "cooldown");
+ this.minLr = validateNonNegativeNumber(options.minLr ?? 0, "minLr");
+ this.best = this.mode === "min" ? Infinity : -Infinity;
+ this.numBadEpochs = 0;
+ this.cooldownCounter = 0;
+ }
+ /**
+ * Check if metric improved.
+ */
+ isBetter(current) {
+ if (this.mode === "min") {
+ return current < this.best - this.threshold;
+ }
+ return current > this.best + this.threshold;
+ }
+ /**
+ * Perform a scheduler step based on the metric value.
+ *
+ * @param metric - Current value of the metric being monitored
+ */
+ step(metric) {
+ if (!Number.isFinite(metric)) {
+ throw new InvalidParameterError("metric must be finite", "metric", metric);
+ }
+ if (this.cooldownCounter > 0) {
+ this.cooldownCounter--;
+ this.numBadEpochs = 0;
+ }
+ if (this.isBetter(metric)) {
+ this.best = metric;
+ this.numBadEpochs = 0;
+ } else if (this.cooldownCounter === 0) {
+ this.numBadEpochs++;
+ }
+ if (this.numBadEpochs > this.patience) {
+ this.reduceLr();
+ this.cooldownCounter = this.cooldown;
+ this.numBadEpochs = 0;
+ }
+ }
+ /**
+ * Reduce learning rate for all parameter groups.
+ */
+ reduceLr() {
+ for (let i = 0; i < this.optimizer.paramGroups.length; i++) {
+ const group = this.optimizer.paramGroups[i];
+ if (!group) {
+ throw new InvalidParameterError(
+ `optimizer.paramGroups[${i}] is missing`,
+ "optimizer.paramGroups",
+ group
+ );
+ }
+ const currentLr = resolveGroupLr(group, i);
+ const newLr = Math.max(currentLr * this.factor, this.minLr);
+ setGroupLr(group, newLr);
+ }
+ }
+ /**
+ * Get the current learning rates for all parameter groups.
+ */
+ getLastLr() {
+ return this.optimizer.paramGroups.map((group, index) => resolveGroupLr(group, index));
+ }
+ };
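A sketch of metric-driven scheduling with ReduceLROnPlateau, which takes the monitored value (not a step count) in step(). The import path and the Adam option shown are assumptions; Adam is exported by this chunk, and the ReduceLROnPlateau options mirror the constructor above.

```ts
// Illustrative only: import path, Adam options, and parameter source are assumptions.
import { Adam, ReduceLROnPlateau } from "deepbox/optim";

declare const params: Iterable<unknown>;          // stand-in for model parameters
declare function validationLoss(): number;        // stand-in for an evaluation pass

const optimizer = new Adam(params, { lr: 1e-3 });
const plateau = new ReduceLROnPlateau(optimizer, {
  mode: "min",   // a lower metric is better
  factor: 0.5,   // halve the lr on plateau (must be in (0, 1))
  patience: 5,   // tolerate 5 epochs without improvement
  minLr: 1e-6,
});

for (let epoch = 0; epoch < 100; epoch++) {
  // ... train for one epoch ...
  plateau.step(validationLoss()); // pass the monitored metric itself
}
```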
1800
+ var WarmupLR = class extends LRScheduler {
+ warmupEpochs;
+ afterScheduler;
+ constructor(optimizer, afterScheduler, options) {
+ const warmupEpochs = validatePositiveInteger(options.warmupEpochs, "warmupEpochs");
+ const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
+ super(optimizer, -1);
+ this.warmupEpochs = warmupEpochs;
+ this.afterScheduler = afterScheduler;
+ this.initializeFromLastEpoch(lastEpoch);
+ }
+ getLr() {
+ if (this.lastEpoch < this.warmupEpochs) {
+ const factor = (this.lastEpoch + 1) / this.warmupEpochs;
+ return this.baseLrs.map((lr) => lr * factor);
+ }
+ if (this.afterScheduler) {
+ return this.afterScheduler.getLr();
+ }
+ return this.baseLrs;
+ }
+ step() {
+ super.step();
+ if (this.lastEpoch >= this.warmupEpochs && this.afterScheduler) {
+ this.afterScheduler.step();
+ }
+ }
+ };
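A sketch of warmup wrapping another scheduler, following the WarmupLR constructor above (optimizer, afterScheduler, options). The import path is assumed; after the warmup window, WarmupLR delegates to the wrapped scheduler's getLr()/step().

```ts
// Illustrative only: import path and parameter source are assumptions.
import { SGD, CosineAnnealingLR, WarmupLR } from "deepbox/optim";

declare const params: Iterable<unknown>;
const optimizer = new SGD(params, { lr: 0.1 });

const cosine = new CosineAnnealingLR(optimizer, { T_max: 95 });
const scheduler = new WarmupLR(optimizer, cosine, { warmupEpochs: 5 });

for (let epoch = 0; epoch < 100; epoch++) {
  // epochs 0-4: lr ramps linearly as (epoch + 1) / 5 of the base lr
  // epochs 5+:  the wrapped cosine schedule takes over
  scheduler.step();
}
```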
1828
+ var OneCycleLR = class extends LRScheduler {
+ maxLr;
+ totalSteps;
+ pctStart;
+ divFactor;
+ finalDivFactor;
+ annealStrategy;
+ constructor(optimizer, options) {
+ const maxLr = validatePositiveNumber(options.maxLr, "maxLr");
+ const totalSteps = validatePositiveInteger(options.totalSteps, "totalSteps");
+ const pctStart = validateFiniteNumber(options.pctStart ?? 0.3, "pctStart");
+ if (pctStart <= 0 || pctStart >= 1) {
+ throw new InvalidParameterError(
+ "pctStart must be in the interval (0, 1)",
+ "pctStart",
+ pctStart
+ );
+ }
+ const divFactor = validatePositiveNumber(options.divFactor ?? 25, "divFactor");
+ const finalDivFactor = validatePositiveNumber(options.finalDivFactor ?? 1e4, "finalDivFactor");
+ const annealStrategy = options.annealStrategy ?? "cos";
+ if (annealStrategy !== "cos" && annealStrategy !== "linear") {
+ throw new InvalidParameterError(
+ "annealStrategy must be 'cos' or 'linear'",
+ "annealStrategy",
+ annealStrategy
+ );
+ }
+ const lastEpoch = validateLastEpoch(options.lastEpoch ?? -1);
+ super(optimizer, -1);
+ this.maxLr = maxLr;
+ this.totalSteps = totalSteps;
+ this.pctStart = pctStart;
+ this.divFactor = divFactor;
+ this.finalDivFactor = finalDivFactor;
+ this.annealStrategy = annealStrategy;
+ this.initializeFromLastEpoch(lastEpoch);
+ }
+ getLr() {
+ const stepNum = this.lastEpoch;
+ const upSteps = Math.max(1, Math.floor(this.totalSteps * this.pctStart));
+ const downSteps = Math.max(1, this.totalSteps - upSteps);
+ const initialLr = this.maxLr / this.divFactor;
+ const minLr = this.maxLr / this.finalDivFactor;
+ let lr;
+ if (stepNum >= this.totalSteps) {
+ lr = minLr;
+ } else if (stepNum < upSteps) {
+ const pct = stepNum / upSteps;
+ lr = initialLr + (this.maxLr - initialLr) * pct;
+ } else {
+ const pct = (stepNum - upSteps) / downSteps;
+ if (this.annealStrategy === "cos") {
+ lr = minLr + (this.maxLr - minLr) * (1 + Math.cos(Math.PI * pct)) / 2;
+ } else {
+ lr = this.maxLr - (this.maxLr - minLr) * pct;
+ }
+ }
+ const baseRef = this.baseLrs[0] ?? 0;
+ return this.baseLrs.map((baseLr) => {
+ if (baseRef === 0) {
+ return baseLr === 0 ? 0 : lr;
+ }
+ return lr * (baseLr / baseRef);
+ });
+ }
+ };
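A sketch of OneCycleLR, which is typically advanced once per batch rather than per epoch, so totalSteps counts optimizer steps. The import path is assumed; the options mirror the constructor above, and group learning rates are scaled relative to group 0 as in getLr.

```ts
// Illustrative only: import path and parameter source are assumptions.
import { SGD, OneCycleLR } from "deepbox/optim";

declare const params: Iterable<unknown>;
const optimizer = new SGD(params, { lr: 0.04 });

const stepsPerEpoch = 100;
const epochs = 10;
const scheduler = new OneCycleLR(optimizer, {
  maxLr: 1.0,
  totalSteps: stepsPerEpoch * epochs,
  pctStart: 0.3,        // first 30% of steps ramp up, the rest anneal down
  divFactor: 25,        // initial lr = maxLr / 25
  finalDivFactor: 1e4,  // final lr = maxLr / 1e4
  annealStrategy: "cos",
});

for (let step = 0; step < stepsPerEpoch * epochs; step++) {
  // ... one batch: backward pass, then optimizer.step() ...
  scheduler.step(); // advance the one-cycle schedule per batch
}
```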
1895
+
+ export { AdaDelta, Adagrad, Adam, AdamW, CosineAnnealingLR, ExponentialLR, LRScheduler, LinearLR, MultiStepLR, Nadam, OneCycleLR, Optimizer, RMSprop, ReduceLROnPlateau, SGD, StepLR, WarmupLR, optim_exports };
+ //# sourceMappingURL=chunk-PR647I7R.js.map