deepbox 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. package/LICENSE +21 -0
  2. package/README.md +344 -0
  3. package/dist/CSRMatrix-CwGwQRea.d.cts +219 -0
  4. package/dist/CSRMatrix-KzNt6QpS.d.ts +219 -0
  5. package/dist/Tensor-BQLk1ltW.d.cts +147 -0
  6. package/dist/Tensor-g8mUClel.d.ts +147 -0
  7. package/dist/chunk-4S73VUBD.js +677 -0
  8. package/dist/chunk-4S73VUBD.js.map +1 -0
  9. package/dist/chunk-5R4S63PF.js +2925 -0
  10. package/dist/chunk-5R4S63PF.js.map +1 -0
  11. package/dist/chunk-6AE5FKKQ.cjs +9264 -0
  12. package/dist/chunk-6AE5FKKQ.cjs.map +1 -0
  13. package/dist/chunk-AD436M45.js +3854 -0
  14. package/dist/chunk-AD436M45.js.map +1 -0
  15. package/dist/chunk-ALS7ETWZ.cjs +4263 -0
  16. package/dist/chunk-ALS7ETWZ.cjs.map +1 -0
  17. package/dist/chunk-AU7XHGKJ.js +2092 -0
  18. package/dist/chunk-AU7XHGKJ.js.map +1 -0
  19. package/dist/chunk-B5TNKUEY.js +1481 -0
  20. package/dist/chunk-B5TNKUEY.js.map +1 -0
  21. package/dist/chunk-BCR7G3A6.js +9136 -0
  22. package/dist/chunk-BCR7G3A6.js.map +1 -0
  23. package/dist/chunk-C4PKXY74.cjs +1917 -0
  24. package/dist/chunk-C4PKXY74.cjs.map +1 -0
  25. package/dist/chunk-DWZY6PIP.cjs +6400 -0
  26. package/dist/chunk-DWZY6PIP.cjs.map +1 -0
  27. package/dist/chunk-E3EU5FZO.cjs +2113 -0
  28. package/dist/chunk-E3EU5FZO.cjs.map +1 -0
  29. package/dist/chunk-F3JWBINJ.js +1054 -0
  30. package/dist/chunk-F3JWBINJ.js.map +1 -0
  31. package/dist/chunk-FJYLIGJX.js +1940 -0
  32. package/dist/chunk-FJYLIGJX.js.map +1 -0
  33. package/dist/chunk-JSCDE774.cjs +729 -0
  34. package/dist/chunk-JSCDE774.cjs.map +1 -0
  35. package/dist/chunk-LWECRCW2.cjs +2412 -0
  36. package/dist/chunk-LWECRCW2.cjs.map +1 -0
  37. package/dist/chunk-MLBMYKCG.js +6379 -0
  38. package/dist/chunk-MLBMYKCG.js.map +1 -0
  39. package/dist/chunk-OX6QXFMV.cjs +3874 -0
  40. package/dist/chunk-OX6QXFMV.cjs.map +1 -0
  41. package/dist/chunk-PHV2DKRS.cjs +1072 -0
  42. package/dist/chunk-PHV2DKRS.cjs.map +1 -0
  43. package/dist/chunk-PL7TAYKI.js +4056 -0
  44. package/dist/chunk-PL7TAYKI.js.map +1 -0
  45. package/dist/chunk-PR647I7R.js +1898 -0
  46. package/dist/chunk-PR647I7R.js.map +1 -0
  47. package/dist/chunk-QERHVCHC.cjs +2960 -0
  48. package/dist/chunk-QERHVCHC.cjs.map +1 -0
  49. package/dist/chunk-XEG44RF6.cjs +1514 -0
  50. package/dist/chunk-XEG44RF6.cjs.map +1 -0
  51. package/dist/chunk-XMWVME2W.js +2377 -0
  52. package/dist/chunk-XMWVME2W.js.map +1 -0
  53. package/dist/chunk-ZB75FESB.cjs +1979 -0
  54. package/dist/chunk-ZB75FESB.cjs.map +1 -0
  55. package/dist/chunk-ZLW62TJG.cjs +4061 -0
  56. package/dist/chunk-ZLW62TJG.cjs.map +1 -0
  57. package/dist/chunk-ZXKBDFP3.js +4235 -0
  58. package/dist/chunk-ZXKBDFP3.js.map +1 -0
  59. package/dist/core/index.cjs +204 -0
  60. package/dist/core/index.cjs.map +1 -0
  61. package/dist/core/index.d.cts +2 -0
  62. package/dist/core/index.d.ts +2 -0
  63. package/dist/core/index.js +3 -0
  64. package/dist/core/index.js.map +1 -0
  65. package/dist/dataframe/index.cjs +22 -0
  66. package/dist/dataframe/index.cjs.map +1 -0
  67. package/dist/dataframe/index.d.cts +3 -0
  68. package/dist/dataframe/index.d.ts +3 -0
  69. package/dist/dataframe/index.js +5 -0
  70. package/dist/dataframe/index.js.map +1 -0
  71. package/dist/datasets/index.cjs +134 -0
  72. package/dist/datasets/index.cjs.map +1 -0
  73. package/dist/datasets/index.d.cts +3 -0
  74. package/dist/datasets/index.d.ts +3 -0
  75. package/dist/datasets/index.js +5 -0
  76. package/dist/datasets/index.js.map +1 -0
  77. package/dist/index-74AB8Cyh.d.cts +1126 -0
  78. package/dist/index-9oQx1HgV.d.cts +1180 -0
  79. package/dist/index-BJY2SI4i.d.ts +483 -0
  80. package/dist/index-BWGhrDlr.d.ts +733 -0
  81. package/dist/index-B_DK4FKY.d.cts +242 -0
  82. package/dist/index-BbA2Gxfl.d.ts +456 -0
  83. package/dist/index-BgHYAoSS.d.cts +837 -0
  84. package/dist/index-BndMbqsM.d.ts +1439 -0
  85. package/dist/index-C1mfVYoo.d.ts +2517 -0
  86. package/dist/index-CCvlwAmL.d.cts +809 -0
  87. package/dist/index-CDw5CnOU.d.ts +785 -0
  88. package/dist/index-Cn3SdB0O.d.ts +1126 -0
  89. package/dist/index-CrqLlS-a.d.ts +776 -0
  90. package/dist/index-D61yaSMY.d.cts +483 -0
  91. package/dist/index-D9Loo1_A.d.cts +2517 -0
  92. package/dist/index-DIT_OO9C.d.cts +785 -0
  93. package/dist/index-DIp_RrRt.d.ts +242 -0
  94. package/dist/index-DbultU6X.d.cts +1427 -0
  95. package/dist/index-DmEg_LCm.d.cts +776 -0
  96. package/dist/index-DoPWVxPo.d.cts +1439 -0
  97. package/dist/index-DuCxd-8d.d.ts +837 -0
  98. package/dist/index-Dx42TZaY.d.ts +809 -0
  99. package/dist/index-DyZ4QQf5.d.cts +456 -0
  100. package/dist/index-GFAVyOWO.d.ts +1427 -0
  101. package/dist/index-WHQLn0e8.d.cts +733 -0
  102. package/dist/index-ZtI1Iy4L.d.ts +1180 -0
  103. package/dist/index-eJgeni9c.d.cts +1911 -0
  104. package/dist/index-tk4lSYod.d.ts +1911 -0
  105. package/dist/index.cjs +72 -0
  106. package/dist/index.cjs.map +1 -0
  107. package/dist/index.d.cts +17 -0
  108. package/dist/index.d.ts +17 -0
  109. package/dist/index.js +15 -0
  110. package/dist/index.js.map +1 -0
  111. package/dist/linalg/index.cjs +86 -0
  112. package/dist/linalg/index.cjs.map +1 -0
  113. package/dist/linalg/index.d.cts +3 -0
  114. package/dist/linalg/index.d.ts +3 -0
  115. package/dist/linalg/index.js +5 -0
  116. package/dist/linalg/index.js.map +1 -0
  117. package/dist/metrics/index.cjs +158 -0
  118. package/dist/metrics/index.cjs.map +1 -0
  119. package/dist/metrics/index.d.cts +3 -0
  120. package/dist/metrics/index.d.ts +3 -0
  121. package/dist/metrics/index.js +5 -0
  122. package/dist/metrics/index.js.map +1 -0
  123. package/dist/ml/index.cjs +87 -0
  124. package/dist/ml/index.cjs.map +1 -0
  125. package/dist/ml/index.d.cts +3 -0
  126. package/dist/ml/index.d.ts +3 -0
  127. package/dist/ml/index.js +6 -0
  128. package/dist/ml/index.js.map +1 -0
  129. package/dist/ndarray/index.cjs +501 -0
  130. package/dist/ndarray/index.cjs.map +1 -0
  131. package/dist/ndarray/index.d.cts +5 -0
  132. package/dist/ndarray/index.d.ts +5 -0
  133. package/dist/ndarray/index.js +4 -0
  134. package/dist/ndarray/index.js.map +1 -0
  135. package/dist/nn/index.cjs +142 -0
  136. package/dist/nn/index.cjs.map +1 -0
  137. package/dist/nn/index.d.cts +6 -0
  138. package/dist/nn/index.d.ts +6 -0
  139. package/dist/nn/index.js +5 -0
  140. package/dist/nn/index.js.map +1 -0
  141. package/dist/optim/index.cjs +77 -0
  142. package/dist/optim/index.cjs.map +1 -0
  143. package/dist/optim/index.d.cts +4 -0
  144. package/dist/optim/index.d.ts +4 -0
  145. package/dist/optim/index.js +4 -0
  146. package/dist/optim/index.js.map +1 -0
  147. package/dist/plot/index.cjs +114 -0
  148. package/dist/plot/index.cjs.map +1 -0
  149. package/dist/plot/index.d.cts +6 -0
  150. package/dist/plot/index.d.ts +6 -0
  151. package/dist/plot/index.js +5 -0
  152. package/dist/plot/index.js.map +1 -0
  153. package/dist/preprocess/index.cjs +82 -0
  154. package/dist/preprocess/index.cjs.map +1 -0
  155. package/dist/preprocess/index.d.cts +4 -0
  156. package/dist/preprocess/index.d.ts +4 -0
  157. package/dist/preprocess/index.js +5 -0
  158. package/dist/preprocess/index.js.map +1 -0
  159. package/dist/random/index.cjs +74 -0
  160. package/dist/random/index.cjs.map +1 -0
  161. package/dist/random/index.d.cts +3 -0
  162. package/dist/random/index.d.ts +3 -0
  163. package/dist/random/index.js +5 -0
  164. package/dist/random/index.js.map +1 -0
  165. package/dist/stats/index.cjs +142 -0
  166. package/dist/stats/index.cjs.map +1 -0
  167. package/dist/stats/index.d.cts +3 -0
  168. package/dist/stats/index.d.ts +3 -0
  169. package/dist/stats/index.js +5 -0
  170. package/dist/stats/index.js.map +1 -0
  171. package/dist/tensor-B96jjJLQ.d.cts +205 -0
  172. package/dist/tensor-B96jjJLQ.d.ts +205 -0
  173. package/package.json +226 -0
@@ -0,0 +1,1126 @@
1
+ import { G as GradTensor } from './index-B_DK4FKY.cjs';
2
+
3
+ /**
4
+ * Base class for all optimizers.
5
+ *
6
+ * This abstract class provides the foundation for implementing optimization algorithms
7
+ * used in training machine learning models. All concrete optimizers (SGD, Adam, etc.)
8
+ * must extend this class and implement the abstract `step()` method.
9
+ *
10
+ * **Key Features:**
11
+ * - Parameter groups with per-group hyperparameters
12
+ * - State management for stateful optimizers (momentum, adaptive learning rates)
13
+ * - Gradient zeroing utilities
14
+ * - State serialization for checkpointing
15
+ *
16
+ * **Design Pattern:**
17
+ * The optimizer maintains a list of parameter groups, where each group can have
18
+ * different hyperparameters (e.g., different learning rates for different layers).
19
+ * This enables fine-grained control over the optimization process.
20
+ *
21
+ * @example
22
+ * ```ts
23
+ * import { SGD } from 'deepbox/optim';
24
+ *
25
+ * const optimizer = new SGD(model.parameters(), { lr: 0.01 });
26
+ *
27
+ * // Training loop
28
+ * for (let epoch = 0; epoch < 100; epoch++) {
29
+ * optimizer.zeroGrad();
30
+ * const loss = computeLoss();
31
+ * loss.backward();
32
+ * optimizer.step();
33
+ * }
34
+ * ```
35
+ *
36
+ * @example
37
+ * ```ts
38
+ * // Using parameter groups with different learning rates
39
+ * const optimizer = new SGD([
40
+ * { params: model.layer1.parameters(), lr: 0.01 },
41
+ * { params: model.layer2.parameters(), lr: 0.001 }
42
+ * ], { lr: 0.01 });
43
+ * ```
44
+ *
45
+ * References:
46
+ * - PyTorch Optimizer: https://pytorch.org/docs/stable/optim.html
47
+ *
48
+ * @category Optimization
49
+ */
50
+ /**
51
+ * Represents a group of parameters with optional per-group hyperparameters.
52
+ *
53
+ * @template Options - Type of optimizer-specific options
54
+ * @property params - Iterable of parameters to optimize in this group
55
+ */
56
+ type ParamGroup<Options extends Record<string, unknown>> = {
57
+ readonly params: Iterable<GradTensor>;
58
+ } & Partial<Options>;
59
+ /**
60
+ * Abstract base class for all optimization algorithms.
61
+ *
62
+ * @template Options - Type defining optimizer-specific hyperparameters
63
+ * @template State - Type defining per-parameter state (e.g., momentum buffers)
64
+ */
65
+ declare abstract class Optimizer<Options extends Record<string, unknown>, State extends Record<string, unknown>> {
66
+ protected readonly defaults: Readonly<Options>;
67
+ /**
68
+ * Groups of parameters with their associated hyperparameters.
69
+ * Each group can have different options (e.g., learning rates).
70
+ * Exposed publicly to enable scheduler integrations.
71
+ */
72
+ paramGroups: Array<{
73
+ params: GradTensor[];
74
+ options: Options;
75
+ }>;
76
+ /**
77
+ * Per-parameter state storage.
78
+ * Maps each parameter to its optimizer-specific state (momentum, adaptive rates, etc.).
79
+ */
80
+ protected state: Map<GradTensor, State>;
81
+ /**
82
+ * Create a new optimizer.
83
+ *
84
+ * Initializes the optimizer with either a simple list of parameters or
85
+ * multiple parameter groups with per-group hyperparameters.
86
+ *
87
+ * @param params - Either an iterable of parameters or array of parameter groups
88
+ * @param defaults - Default hyperparameters applied to all groups
89
+ */
90
+ constructor(params: Iterable<GradTensor> | ReadonlyArray<ParamGroup<Options>>, defaults: Readonly<Options>);
91
+ /**
92
+ * Perform a single optimization step (parameter update).
93
+ *
94
+ * This abstract method must be implemented by all optimizer subclasses.
95
+ * It applies the optimization algorithm to update all parameters based on
96
+ * their gradients.
97
+ *
98
+ * @param closure - Optional closure that reevaluates the model and returns the loss.
99
+ * Used by some optimizers (e.g., LBFGS) that require multiple
100
+ * function evaluations per step.
101
+ * @returns Loss value if closure is provided, undefined otherwise
102
+ */
103
+ abstract step(closure?: () => number): number | undefined;
104
+ /**
105
+ * Zero out the gradients of all optimized parameters.
106
+ *
107
+ * This method should be called at the beginning of each training iteration,
108
+ * before computing new gradients. Without this call, gradients would accumulate
109
+ * across iterations, leading to incorrect updates.
110
+ *
111
+ * **Implementation Note:**
112
+ * For parameters wrapped in GradTensor, this calls zeroGrad() on each parameter,
113
+ * which either sets the gradient to zero or initializes it if not yet created.
114
+ *
115
+ * @example
116
+ * ```ts
117
+ * // Typical training loop
118
+ * optimizer.zeroGrad(); // Clear previous gradients
119
+ * const output = model.forward(input);
120
+ * const loss = criterion(output, target);
121
+ * loss.backward(); // Compute new gradients
122
+ * optimizer.step(); // Update parameters
123
+ * ```
124
+ */
125
+ zeroGrad(): void;
126
+ /**
127
+ * Add a parameter group to the optimizer.
128
+ *
129
+ * This method allows adding new parameters to optimize after the optimizer
130
+ * has been created. This is particularly useful for:
131
+ * - Fine-tuning: adding pre-trained layers with different learning rates
132
+ * - Progressive training: gradually unfreezing layers
133
+ * - Dynamic architectures: adding parameters while the model grows
134
+ *
135
+ * @param paramGroup - Parameter group to add with optional per-group options
136
+ *
137
+ * @example
138
+ * ```ts
139
+ * const optimizer = new SGD(model.backbone.parameters(), { lr: 0.001 });
140
+ * // Later, add classifier with higher learning rate
141
+ * optimizer.addParamGroup({
142
+ * params: model.classifier.parameters(),
143
+ * lr: 0.01
144
+ * });
145
+ * ```
146
+ */
147
+ addParamGroup(paramGroup: ParamGroup<Options>): void;
148
+ /**
149
+ * Validate that a given state object matches the optimizer's state type.
150
+ *
151
+ * @param state - The state object to validate
152
+ * @returns True if the state object is valid, false otherwise
153
+ */
154
+ protected abstract isState(state: Record<string, unknown>): state is State;
155
+ /**
156
+ * Get the current state of the optimizer.
157
+ *
158
+ * Returns a dictionary containing all optimizer state that needs to be
159
+ * saved for checkpointing. This includes per-parameter state (momentum buffers,
160
+ * adaptive learning rates, etc.) and parameter group configurations.
161
+ *
162
+ * **Note:** In a production implementation, parameters would be identified by
163
+ * unique IDs rather than object references for proper serialization.
164
+ *
165
+ * @returns Optimizer state dictionary containing state and parameter groups
166
+ *
167
+ * @example
168
+ * ```ts
169
+ * // Save checkpoint
170
+ * const checkpoint = {
171
+ * model: model.stateDict(),
172
+ * optimizer: optimizer.stateDict(),
173
+ * epoch: currentEpoch
174
+ * };
175
+ * ```
176
+ */
177
+ stateDict(): {
178
+ state: {
179
+ paramId: number;
180
+ param: GradTensor;
181
+ state: State;
182
+ }[];
183
+ paramGroups: {
184
+ params: GradTensor[];
185
+ paramIds: number[];
186
+ options: Options;
187
+ }[];
188
+ };
189
+ /**
190
+ * Load optimizer state from a state dictionary.
191
+ *
192
+ * Restores the optimizer to a previously saved state, including all
193
+ * per-parameter state and parameter group configurations. This is essential
194
+ * for resuming training from checkpoints.
195
+ *
196
+ * **Important:** The loaded state must be compatible with the current
197
+ * optimizer configuration (same parameters, same optimizer type).
198
+ *
199
+ * @param stateDict - State dictionary previously returned by stateDict()
200
+ *
201
+ * @example
202
+ * ```ts
203
+ * // Resume from checkpoint
204
+ * const checkpoint = loadCheckpoint('checkpoint.json');
205
+ * model.loadStateDict(checkpoint.model);
206
+ * optimizer.loadStateDict(checkpoint.optimizer);
207
+ * ```
208
+ */
209
+ loadStateDict(stateDict: Record<string, unknown>): void;
210
+ }
211
+
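The abstract members above define the full contract a concrete optimizer has to satisfy: the constructor forwards parameters and defaults to the base class, `step()` applies the update across `paramGroups`, and `isState()` guards state restored by `loadStateDict()`. The sketch below illustrates that contract with a stateless, SGD-like subclass. It is a minimal example under assumptions, not package code: the `deepbox/optim` and `deepbox/core` import locations for `Optimizer` and `GradTensor` are assumed, and the actual tensor update is left as a comment because the tensor mutation API is not part of this declaration file.

```ts
import { Optimizer } from 'deepbox/optim';        // assumed re-export of the base class
import type { GradTensor } from 'deepbox/core';   // assumed public location of GradTensor

type PlainSGDOptions = { lr: number };
type PlainSGDState = Record<string, never>;       // this optimizer keeps no per-parameter state

class PlainSGD extends Optimizer<PlainSGDOptions, PlainSGDState> {
  constructor(params: Iterable<GradTensor>, options?: { readonly lr?: number }) {
    super(params, { lr: options?.lr ?? 0.01 });
  }

  // Stateless: only an empty record is a valid state entry.
  protected isState(state: Record<string, unknown>): state is PlainSGDState {
    return Object.keys(state).length === 0;
  }

  step(closure?: () => number): number | undefined {
    const loss = closure?.();
    for (const group of this.paramGroups) {
      for (const param of group.params) {
        // Real update: param <- param - group.options.lr * (param's gradient).
        // Left as a comment; the in-place tensor API is not declared here.
        void param;
      }
    }
    return loss;
  }
}
```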
212
+ type AdaDeltaOptions = {
213
+ lr: number;
214
+ readonly rho: number;
215
+ readonly eps: number;
216
+ readonly weightDecay: number;
217
+ };
218
+ type AdaDeltaState = {
219
+ squareAvg: Float64Array;
220
+ accDelta: Float64Array;
221
+ };
222
+ /**
223
+ * AdaDelta optimizer.
224
+ *
225
+ * Implements the AdaDelta algorithm - an extension of Adagrad that seeks to reduce
226
+ * its aggressive, monotonically decreasing learning rate. AdaDelta adapts learning
227
+ * rates based on a moving window of gradient updates, rather than accumulating all
228
+ * past gradients.
229
+ *
230
+ * @example
231
+ * ```ts
232
+ * import { AdaDelta } from 'deepbox/optim';
233
+ *
234
+ * const optimizer = new AdaDelta(model.parameters(), {
235
+ * lr: 1.0,
236
+ * rho: 0.9,
237
+ * eps: 1e-6
238
+ * });
239
+ *
240
+ * // Training loop
241
+ * for (let epoch = 0; epoch < numEpochs; epoch++) {
242
+ * optimizer.zeroGrad();
243
+ * // ...
244
+ * optimizer.step();
245
+ * }
246
+ * ```
247
+ *
248
+ * @category Optimizers
249
+ */
250
+ declare class AdaDelta extends Optimizer<AdaDeltaOptions, AdaDeltaState> {
251
+ private _stepCount;
252
+ get stepCount(): number;
253
+ constructor(params: Iterable<GradTensor> | ReadonlyArray<ParamGroup<AdaDeltaOptions>>, options?: {
254
+ readonly lr?: number;
255
+ readonly rho?: number;
256
+ readonly eps?: number;
257
+ readonly weightDecay?: number;
258
+ });
259
+ /**
260
+ * Get the current learning rate.
261
+ *
262
+ * @param groupIdx - Parameter group index (default: 0)
263
+ * @returns Current learning rate
264
+ */
265
+ getLearningRate(groupIdx?: number): number;
266
+ /**
267
+ * Set the learning rate for all parameter groups.
268
+ *
269
+ * @param lr - New learning rate
270
+ */
271
+ setLearningRate(lr: number): void;
272
+ protected isState(state: Record<string, unknown>): state is AdaDeltaState;
273
+ step(closure?: () => number): number | undefined;
274
+ }
275
+
276
+ type AdagradOptions = {
277
+ lr: number;
278
+ eps: number;
279
+ weightDecay: number;
280
+ lrDecay: number;
281
+ };
282
+ type AdagradState = {
283
+ step: number;
284
+ sum: Float64Array;
285
+ };
286
+ /**
287
+ * Adagrad (Adaptive Gradient Algorithm) optimizer.
288
+ *
289
+ * Adagrad adapts the learning rate for each parameter based on the historical
290
+ * sum of squared gradients. Parameters with larger gradients receive smaller
291
+ * effective learning rates, while parameters with smaller gradients receive
292
+ * larger effective learning rates.
293
+ *
294
+ * @example
295
+ * ```ts
296
+ * import { Adagrad } from 'deepbox/optim';
297
+ *
298
+ * const optimizer = new Adagrad(model.parameters(), {
299
+ * lr: 0.01,
300
+ * eps: 1e-10
301
+ * });
302
+ *
303
+ * // Training loop
304
+ * for (let epoch = 0; epoch < numEpochs; epoch++) {
305
+ * optimizer.zeroGrad();
306
+ * // ...
307
+ * optimizer.step();
308
+ * }
309
+ * ```
310
+ *
311
+ * @category Optimizers
312
+ */
313
+ declare class Adagrad extends Optimizer<AdagradOptions, AdagradState> {
314
+ private _stepCount;
315
+ get stepCount(): number;
316
+ constructor(params: Iterable<GradTensor> | ReadonlyArray<ParamGroup<AdagradOptions>>, options?: {
317
+ readonly lr?: number;
318
+ readonly eps?: number;
319
+ readonly weightDecay?: number;
320
+ readonly lrDecay?: number;
321
+ });
322
+ /**
323
+ * Get the current learning rate.
324
+ *
325
+ * @param groupIdx - Parameter group index (default: 0)
326
+ * @returns Current learning rate
327
+ */
328
+ getLearningRate(groupIdx?: number): number;
329
+ /**
330
+ * Set the learning rate for all parameter groups.
331
+ *
332
+ * @param lr - New learning rate
333
+ */
334
+ setLearningRate(lr: number): void;
335
+ protected isState(state: Record<string, unknown>): state is AdagradState;
336
+ step(closure?: () => number): number | undefined;
337
+ }
338
+
339
+ type AdamOptions = {
340
+ lr: number;
341
+ beta1: number;
342
+ beta2: number;
343
+ eps: number;
344
+ weightDecay: number;
345
+ amsgrad: boolean;
346
+ };
347
+ type AdamState = {
348
+ step: number;
349
+ expAvg: Float64Array;
350
+ expAvgSq: Float64Array;
351
+ maxExpAvgSq?: Float64Array;
352
+ };
353
+ /**
354
+ * Adam (Adaptive Moment Estimation) optimizer.
355
+ *
356
+ * Computes adaptive learning rates for each parameter by maintaining
357
+ * running averages of both the gradients and their squared values.
358
+ *
359
+ * @example
360
+ * ```ts
361
+ * import { Adam } from 'deepbox/optim';
362
+ *
363
+ * const optimizer = new Adam(model.parameters(), {
364
+ * lr: 0.001,
365
+ * beta1: 0.9,
366
+ * beta2: 0.999
367
+ * });
368
+ * ```
369
+ *
370
+ * @category Optimizers
371
+ */
372
+ declare class Adam extends Optimizer<AdamOptions, AdamState> {
373
+ private _stepCount;
374
+ get stepCount(): number;
375
+ constructor(params: Iterable<GradTensor> | ReadonlyArray<ParamGroup<AdamOptions>>, options?: {
376
+ readonly lr?: number;
377
+ readonly beta1?: number;
378
+ readonly beta2?: number;
379
+ readonly eps?: number;
380
+ readonly weightDecay?: number;
381
+ readonly amsgrad?: boolean;
382
+ });
383
+ /**
384
+ * Get the current learning rate.
385
+ *
386
+ * @param groupIdx - Parameter group index (default: 0)
387
+ * @returns Current learning rate
388
+ */
389
+ getLearningRate(groupIdx?: number): number;
390
+ /**
391
+ * Set the learning rate for all parameter groups.
392
+ *
393
+ * @param lr - New learning rate
394
+ */
395
+ setLearningRate(lr: number): void;
396
+ protected isState(state: Record<string, unknown>): state is AdamState;
397
+ step(closure?: () => number): number | undefined;
398
+ }
399
+
400
+ /**
401
+ * Options for the AdamW optimizer.
402
+ *
403
+ * @property lr - Learning rate (step size)
404
+ * @property beta1 - Exponential decay rate for first moment estimates
405
+ * @property beta2 - Exponential decay rate for second moment estimates
406
+ * @property eps - Small constant for numerical stability
407
+ * @property weightDecay - Weight decay coefficient (L2 penalty)
408
+ * @property amsgrad - Whether to use the AMSGrad variant
409
+ */
410
+ type AdamWOptions = {
411
+ lr: number;
412
+ beta1: number;
413
+ beta2: number;
414
+ eps: number;
415
+ weightDecay: number;
416
+ amsgrad: boolean;
417
+ };
418
+ /**
419
+ * State maintained per parameter by AdamW.
420
+ *
421
+ * @property step - Number of optimization steps taken
422
+ * @property expAvg - Exponentially weighted average of gradients (first moment)
423
+ * @property expAvgSq - Exponentially weighted average of squared gradients (second moment)
424
+ * @property maxExpAvgSq - Maximum of exponentially weighted average of squared gradients (AMSGrad only)
425
+ */
426
+ type AdamWState = {
427
+ step: number;
428
+ expAvg: Float64Array;
429
+ expAvgSq: Float64Array;
430
+ maxExpAvgSq?: Float64Array;
431
+ };
432
+ /**
433
+ * AdamW (Adam with decoupled Weight decay) optimizer.
434
+ *
435
+ * AdamW fixes the weight decay implementation in Adam by decoupling it from the
436
+ * gradient-based update. This leads to better generalization and is the recommended
437
+ * variant for most applications.
438
+ *
439
+ * @example
440
+ * ```ts
441
+ * import { AdamW } from 'deepbox/optim';
442
+ *
443
+ * const optimizer = new AdamW(model.parameters(), {
444
+ * lr: 0.001,
445
+ * weightDecay: 0.01, // Typical value for AdamW
446
+ * beta1: 0.9,
447
+ * beta2: 0.999
448
+ * });
449
+ *
450
+ * // Training loop
451
+ * for (let epoch = 0; epoch < numEpochs; epoch++) {
452
+ * optimizer.zeroGrad();
453
+ * // ...
454
+ * optimizer.step();
455
+ * }
456
+ * ```
457
+ *
458
+ * @category Optimizers
459
+ */
460
+ declare class AdamW extends Optimizer<AdamWOptions, AdamWState> {
461
+ /** Internal counter tracking total number of optimization steps */
462
+ private _stepCount;
463
+ /**
464
+ * Get the total number of optimization steps performed.
465
+ *
466
+ * @returns Number of steps taken
467
+ */
468
+ get stepCount(): number;
469
+ /**
470
+ * Create a new AdamW optimizer.
471
+ *
472
+ * @param params - Iterable of parameters or parameter groups to optimize
473
+ * @param options - Optimization options
474
+ * @param options.lr - Learning rate (default: 0.001)
475
+ * @param options.beta1 - First moment decay rate (default: 0.9)
476
+ * @param options.beta2 - Second moment decay rate (default: 0.999)
477
+ * @param options.eps - Numerical stability constant (default: 1e-8)
478
+ * @param options.weightDecay - Weight decay coefficient (default: 0.01)
479
+ * @param options.amsgrad - Enable AMSGrad variant (default: false)
480
+ * @throws {InvalidParameterError} If a parameter is invalid
481
+ */
482
+ constructor(params: Iterable<GradTensor> | ReadonlyArray<ParamGroup<AdamWOptions>>, options?: {
483
+ readonly lr?: number;
484
+ readonly beta1?: number;
485
+ readonly beta2?: number;
486
+ readonly eps?: number;
487
+ readonly weightDecay?: number;
488
+ readonly amsgrad?: boolean;
489
+ });
490
+ /**
491
+ * Get the current learning rate.
492
+ *
493
+ * @param groupIdx - Parameter group index (default: 0)
494
+ * @returns Current learning rate
495
+ */
496
+ getLearningRate(groupIdx?: number): number;
497
+ /**
498
+ * Set the learning rate for all parameter groups.
499
+ *
500
+ * @param lr - New learning rate
501
+ */
502
+ setLearningRate(lr: number): void;
503
+ /**
504
+ * Perform a single optimization step (parameter update).
505
+ *
506
+ * Implements the AdamW update rule with decoupled weight decay.
507
+ *
508
+ * @param closure - Optional closure that reevaluates the model and returns the loss
509
+ * @returns Loss value if closure is provided, undefined otherwise
510
+ */
511
+ protected isState(state: Record<string, unknown>): state is AdamWState;
512
+ step(closure?: () => number): number | undefined;
513
+ }
514
+
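To make "decoupled weight decay" concrete, here is a scalar walk-through of one AdamW update using the option and state names declared in `AdamWOptions`/`AdamWState`. It illustrates the published update rule rather than the package's internal code: the moment estimates are built from the raw gradient, and `weightDecay` acts directly on the parameter, scaled only by `lr`.

```ts
function adamWScalarStep(
  theta: number,
  grad: number,
  state: { step: number; expAvg: number; expAvgSq: number },
  opts = { lr: 0.001, beta1: 0.9, beta2: 0.999, eps: 1e-8, weightDecay: 0.01 }
): number {
  state.step += 1;
  // First and second moments use the raw gradient (no weight-decay term mixed in).
  state.expAvg = opts.beta1 * state.expAvg + (1 - opts.beta1) * grad;
  state.expAvgSq = opts.beta2 * state.expAvgSq + (1 - opts.beta2) * grad * grad;
  // Bias correction.
  const mHat = state.expAvg / (1 - opts.beta1 ** state.step);
  const vHat = state.expAvgSq / (1 - opts.beta2 ** state.step);
  // Decoupled decay: shrink the parameter directly, then apply the Adam step.
  return theta - opts.lr * opts.weightDecay * theta - (opts.lr * mHat) / (Math.sqrt(vHat) + opts.eps);
}
```

In plain Adam, weight decay is conventionally folded into the gradient before the moments are computed, which is the coupling that AdamW's formulation removes.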
515
+ type NadamOptions = {
516
+ lr: number;
517
+ readonly beta1: number;
518
+ readonly beta2: number;
519
+ readonly eps: number;
520
+ readonly weightDecay: number;
521
+ readonly momentumDecay: number;
522
+ };
523
+ type NadamState = {
524
+ step: number;
525
+ expAvg: Float64Array;
526
+ expAvgSq: Float64Array;
527
+ muProduct: number;
528
+ };
529
+ /**
530
+ * Nadam (Nesterov-accelerated Adam) optimizer.
531
+ *
532
+ * Implements the Nadam algorithm, which combines Adam's adaptive learning rates with
533
+ * Nesterov momentum for potentially faster convergence. Nadam applies Nesterov
534
+ * acceleration to the momentum term, providing a "look-ahead" gradient.
535
+ *
536
+ * @example
537
+ * ```ts
538
+ * import { Nadam } from 'deepbox/optim';
539
+ *
540
+ * const optimizer = new Nadam(model.parameters(), {
541
+ * lr: 0.002,
542
+ * beta1: 0.9,
543
+ * beta2: 0.999
544
+ * });
545
+ *
546
+ * // Training loop
547
+ * for (let epoch = 0; epoch < numEpochs; epoch++) {
548
+ * optimizer.zeroGrad();
549
+ * // ...
550
+ * optimizer.step();
551
+ * }
552
+ * ```
553
+ *
554
+ * @category Optimizers
555
+ */
556
+ declare class Nadam extends Optimizer<NadamOptions, NadamState> {
557
+ private _stepCount;
558
+ get stepCount(): number;
559
+ constructor(params: Iterable<GradTensor> | ReadonlyArray<ParamGroup<NadamOptions>>, options?: {
560
+ readonly lr?: number;
561
+ readonly beta1?: number;
562
+ readonly beta2?: number;
563
+ readonly eps?: number;
564
+ readonly weightDecay?: number;
565
+ readonly momentumDecay?: number;
566
+ });
567
+ /**
568
+ * Get the current learning rate.
569
+ *
570
+ * @param groupIdx - Parameter group index (default: 0)
571
+ * @returns Current learning rate
572
+ */
573
+ getLearningRate(groupIdx?: number): number;
574
+ /**
575
+ * Set the learning rate for all parameter groups.
576
+ *
577
+ * @param lr - New learning rate
578
+ */
579
+ setLearningRate(lr: number): void;
580
+ protected isState(state: Record<string, unknown>): state is NadamState;
581
+ step(closure?: () => number): number | undefined;
582
+ }
583
+
584
+ /**
585
+ * Options for the RMSprop optimizer.
586
+ *
587
+ * @property lr - Learning rate (step size)
588
+ * @property alpha - Smoothing constant for moving average of squared gradients
589
+ * @property eps - Small constant for numerical stability
590
+ * @property weightDecay - Weight decay coefficient (L2 penalty)
591
+ * @property momentum - Momentum factor
592
+ * @property centered - Whether to use centered RMSprop variant
593
+ */
594
+ type RMSpropOptions = {
595
+ lr: number;
596
+ alpha: number;
597
+ eps: number;
598
+ weightDecay: number;
599
+ momentum: number;
600
+ centered: boolean;
601
+ };
602
+ /**
603
+ * State maintained per parameter by RMSprop.
604
+ *
605
+ * @property squareAvg - Exponentially weighted average of squared gradients
606
+ * @property momentumBuffer - Momentum buffer (if momentum > 0)
607
+ * @property gradAvg - Exponentially weighted average of gradients (centered variant only)
608
+ */
609
+ type RMSpropState = {
610
+ squareAvg: Float64Array;
611
+ momentumBuffer?: Float64Array;
612
+ gradAvg?: Float64Array;
613
+ };
614
+ /**
615
+ * RMSprop (Root Mean Square Propagation) optimizer.
616
+ *
617
+ * RMSprop adapts the learning rate for each parameter by dividing by a running
618
+ * average of recent gradient magnitudes. This helps with non-stationary objectives
619
+ * and is particularly effective for RNNs.
620
+ *
621
+ * @example
622
+ * ```ts
623
+ * import { RMSprop } from 'deepbox/optim';
624
+ *
625
+ * const optimizer = new RMSprop(model.parameters(), {
626
+ * lr: 0.01,
627
+ * alpha: 0.99,
628
+ * momentum: 0.9,
629
+ * centered: true
630
+ * });
631
+ *
632
+ * // Training loop
633
+ * for (let epoch = 0; epoch < numEpochs; epoch++) {
634
+ * optimizer.zeroGrad();
635
+ * // ...
636
+ * optimizer.step();
637
+ * }
638
+ * ```
639
+ *
640
+ * @category Optimizers
641
+ */
642
+ declare class RMSprop extends Optimizer<RMSpropOptions, RMSpropState> {
643
+ /** Internal counter tracking total number of optimization steps */
644
+ private _stepCount;
645
+ /**
646
+ * Get the total number of optimization steps performed.
647
+ *
648
+ * @returns Number of steps taken
649
+ */
650
+ get stepCount(): number;
651
+ /**
652
+ * Create a new RMSprop optimizer.
653
+ *
654
+ * @param params - Iterable of parameters or parameter groups to optimize
655
+ * @param options - Optimization options
656
+ * @param options.lr - Learning rate (default: 0.01)
657
+ * @param options.alpha - Smoothing constant (default: 0.99)
658
+ * @param options.eps - Numerical stability constant (default: 1e-8)
659
+ * @param options.weightDecay - Weight decay coefficient (default: 0)
660
+ * @param options.momentum - Momentum factor (default: 0)
661
+ * @param options.centered - Use centered variant (default: false)
662
+ * @throws {InvalidParameterError} If a parameter is invalid
663
+ */
664
+ constructor(params: Iterable<GradTensor> | ReadonlyArray<ParamGroup<RMSpropOptions>>, options?: {
665
+ readonly lr?: number;
666
+ readonly alpha?: number;
667
+ readonly eps?: number;
668
+ readonly weightDecay?: number;
669
+ readonly momentum?: number;
670
+ readonly centered?: boolean;
671
+ });
672
+ /**
673
+ * Get the current learning rate.
674
+ *
675
+ * @param groupIdx - Parameter group index (default: 0)
676
+ * @returns Current learning rate
677
+ */
678
+ getLearningRate(groupIdx?: number): number;
679
+ /**
680
+ * Set the learning rate for all parameter groups.
681
+ *
682
+ * @param lr - New learning rate
683
+ */
684
+ setLearningRate(lr: number): void;
685
+ protected isState(state: Record<string, unknown>): state is RMSpropState;
686
+ step(closure?: () => number): number | undefined;
687
+ }
688
+
689
+ type SGDOptions = {
690
+ lr: number;
691
+ momentum: number;
692
+ dampening: number;
693
+ weightDecay: number;
694
+ nesterov: boolean;
695
+ };
696
+ type SGDState = {
697
+ momentumBuffer?: Float64Array;
698
+ };
699
+ /**
700
+ * Stochastic Gradient Descent (SGD) optimizer.
701
+ *
702
+ * Implements vanilla SGD with optional momentum, weight decay, and Nesterov acceleration.
703
+ *
704
+ * @example
705
+ * ```ts
706
+ * import { SGD } from 'deepbox/optim';
707
+ * import { Module } from 'deepbox/nn';
708
+ *
709
+ * const model: Module = ...;
710
+ * const optimizer = new SGD(model.parameters(), {
711
+ * lr: 0.01,
712
+ * momentum: 0.9,
713
+ * weightDecay: 5e-4,
714
+ * nesterov: true
715
+ * });
716
+ *
717
+ * // Training loop
718
+ * for (let epoch = 0; epoch < numEpochs; epoch++) {
719
+ * for (const [inputs, targets] of dataLoader) {
720
+ * optimizer.zeroGrad();
721
+ * const outputs = model.forward(inputs);
722
+ * const loss = criterion(outputs, targets);
723
+ * loss.backward();
724
+ * optimizer.step();
725
+ * }
726
+ * }
727
+ * ```
728
+ *
729
+ * @category Optimizers
730
+ */
731
+ declare class SGD extends Optimizer<SGDOptions, SGDState> {
732
+ /** Internal counter tracking total number of optimization steps */
733
+ private _stepCount;
734
+ get stepCount(): number;
735
+ /**
736
+ * Create a new SGD optimizer.
737
+ *
738
+ * @param params - Iterable of parameters or parameter groups to optimize
739
+ * @param options - Optimization options
740
+ * @param options.lr - Learning rate (default: 0.01)
741
+ * @param options.momentum - Momentum factor (default: 0)
742
+ * @param options.dampening - Dampening for momentum (default: 0)
743
+ * @param options.weightDecay - Weight decay (L2 penalty) (default: 0)
744
+ * @param options.nesterov - Enable Nesterov momentum (default: false)
745
+ */
746
+ constructor(params: Iterable<GradTensor> | ReadonlyArray<ParamGroup<SGDOptions>>, options?: {
747
+ readonly lr?: number;
748
+ readonly momentum?: number;
749
+ readonly dampening?: number;
750
+ readonly weightDecay?: number;
751
+ readonly nesterov?: boolean;
752
+ });
753
+ /**
754
+ * Perform a single optimization step.
755
+ *
756
+ * Implements the SGD update rule with optional momentum and weight decay.
757
+ *
758
+ * @param closure - Optional closure that reevaluates the model and returns the loss
759
+ * @returns Loss value if closure is provided
760
+ */
761
+ protected isState(state: Record<string, unknown>): state is SGDState;
762
+ step(closure?: () => number): number | undefined;
763
+ /**
764
+ * Get the current learning rate.
765
+ *
766
+ * @param groupIdx - Parameter group index (default: 0)
767
+ * @returns Current learning rate
768
+ */
769
+ getLearningRate(groupIdx?: number): number;
770
+ /**
771
+ * Set the learning rate for all parameter groups.
772
+ *
773
+ * @param lr - New learning rate
774
+ */
775
+ setLearningRate(lr: number): void;
776
+ }
777
+
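The `momentum`, `dampening`, `weightDecay` and `nesterov` options interact in a fixed order. The scalar sketch below follows the common PyTorch-style formulation (the base-class docs above reference the PyTorch optimizer documentation); the package's exact internals are not visible in this declaration file, so read it as an approximation of what each option does.

```ts
function sgdScalarStep(
  theta: number,
  grad: number,
  state: { momentumBuffer?: number },
  opts = { lr: 0.01, momentum: 0.9, dampening: 0, weightDecay: 0, nesterov: false }
): number {
  // Weight decay enters as an L2 penalty on the gradient.
  let g = grad + opts.weightDecay * theta;
  if (opts.momentum !== 0) {
    // First step seeds the buffer with the gradient; later steps blend with dampening.
    state.momentumBuffer =
      state.momentumBuffer === undefined
        ? g
        : opts.momentum * state.momentumBuffer + (1 - opts.dampening) * g;
    // Nesterov looks ahead along the momentum direction.
    g = opts.nesterov ? g + opts.momentum * state.momentumBuffer : state.momentumBuffer;
  }
  return theta - opts.lr * g;
}
```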
778
+ /**
779
+ * Interface for optimizer-like objects that schedulers can work with.
780
+ * This allows schedulers to work with different optimizer implementations.
781
+ * Parameter groups may expose `lr` directly or via `options.lr`.
782
+ */
783
+ interface SchedulerOptimizer {
784
+ paramGroups: SchedulerParamGroup[];
785
+ }
786
+ type SchedulerParamGroup = {
787
+ params: unknown[];
788
+ lr?: number;
789
+ options?: Record<string, unknown>;
790
+ };
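Because `SchedulerOptimizer` is purely structural (and is not part of the export list at the end of this file), any object exposing a `paramGroups` array can be scheduled. Each group may carry `lr` at the top level, as in the sketch below, or nested under `options.lr` as the built-in optimizers do. Importing `StepLR` from `deepbox/optim` follows the examples elsewhere in this file.

```ts
import { StepLR } from 'deepbox/optim';

// A duck-typed "optimizer": schedulers only read and write paramGroups.
const duckOptimizer = {
  paramGroups: [{ params: [], lr: 0.1 }],
};

const scheduler = new StepLR(duckOptimizer, { stepSize: 10, gamma: 0.5 });
scheduler.step();
console.log(scheduler.getLastLr()); // [0.1]; unchanged until the epoch count reaches stepSize
```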
791
+ /**
792
+ * Base class for learning rate schedulers.
793
+ *
794
+ * Learning rate schedulers adjust the learning rate during training according
795
+ * to a predefined schedule. This can help improve convergence and prevent
796
+ * overshooting optimal solutions.
797
+ *
798
+ * @example
799
+ * ```ts
800
+ * import { SGD, StepLR } from 'deepbox/optim';
801
+ *
802
+ * const optimizer = new SGD(model.parameters(), { lr: 0.1 });
803
+ * const scheduler = new StepLR(optimizer, { stepSize: 10, gamma: 0.1 });
804
+ *
805
+ * for (let epoch = 0; epoch < 100; epoch++) {
806
+ * train();
807
+ * scheduler.step();
808
+ * }
809
+ * ```
810
+ *
811
+ * @category Optimization
812
+ */
813
+ declare abstract class LRScheduler {
814
+ protected optimizer: SchedulerOptimizer;
815
+ protected lastEpoch: number;
816
+ protected baseLrs: number[];
817
+ constructor(optimizer: SchedulerOptimizer, lastEpoch?: number);
818
+ protected initializeFromLastEpoch(lastEpoch: number): void;
819
+ /**
820
+ * Compute the learning rate for the current epoch.
821
+ * Must be implemented by subclasses.
822
+ *
823
+ * @returns Array of learning rates for each parameter group
824
+ */
825
+ abstract getLr(): number[];
826
+ /**
827
+ * Perform a scheduler step, updating learning rates.
828
+ *
829
+ * Should be called once per epoch after the optimizer step.
830
+ */
831
+ step(): void;
832
+ /**
833
+ * Get the current learning rates for all parameter groups.
834
+ */
835
+ getLastLr(): number[];
836
+ /**
837
+ * Get current epoch number.
838
+ */
839
+ get epoch(): number;
840
+ }
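Since `getLr()` is the only abstract member, adding a new schedule amounts to mapping over `baseLrs` with the current `epoch`. A minimal sketch, assuming `LRScheduler` is re-exported from `deepbox/optim` as the export list at the end of this file suggests:

```ts
import { LRScheduler } from 'deepbox/optim';

// Inverse-square-root decay: lr = baseLr / sqrt(epoch + 1). The clamp keeps the
// very first evaluation finite if the internal epoch counter starts below zero.
class InverseSqrtLR extends LRScheduler {
  getLr(): number[] {
    const t = Math.max(1, this.epoch + 1);
    return this.baseLrs.map((baseLr) => baseLr / Math.sqrt(t));
  }
}
```

It is driven exactly like the built-in schedulers: construct it with an optimizer and call `step()` once per epoch.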
841
+ /**
842
+ * Step learning rate scheduler.
843
+ *
844
+ * Decays the learning rate by gamma every stepSize epochs.
845
+ * lr = baseLr * gamma^(epoch // stepSize)
846
+ *
847
+ * @example
848
+ * ```ts
849
+ * const scheduler = new StepLR(optimizer, { stepSize: 30, gamma: 0.1 });
850
+ * // lr = 0.1 for epochs 0-29
851
+ * // lr = 0.01 for epochs 30-59
852
+ * // lr = 0.001 for epochs 60-89
853
+ * ```
854
+ *
855
+ * @see {@link https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html | PyTorch StepLR}
856
+ */
857
+ declare class StepLR extends LRScheduler {
858
+ private stepSize;
859
+ private gamma;
860
+ constructor(optimizer: SchedulerOptimizer, options: {
861
+ stepSize: number;
862
+ gamma?: number;
863
+ lastEpoch?: number;
864
+ });
865
+ getLr(): number[];
866
+ }
867
+ /**
868
+ * Exponential learning rate scheduler.
869
+ *
870
+ * Decays the learning rate exponentially every epoch.
871
+ * lr = baseLr * gamma^epoch
872
+ *
873
+ * @example
874
+ * ```ts
875
+ * const scheduler = new ExponentialLR(optimizer, { gamma: 0.95 });
876
+ * // lr *= 0.95 each epoch
877
+ * ```
878
+ *
879
+ * @see {@link https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html | PyTorch ExponentialLR}
880
+ */
881
+ declare class ExponentialLR extends LRScheduler {
882
+ private gamma;
883
+ constructor(optimizer: SchedulerOptimizer, options: {
884
+ gamma: number;
885
+ lastEpoch?: number;
886
+ });
887
+ getLr(): number[];
888
+ }
889
+ /**
890
+ * Cosine annealing learning rate scheduler.
891
+ *
892
+ * Sets the learning rate using a cosine annealing schedule.
893
+ * lr = etaMin + (baseLr - etaMin) * (1 + cos(π * epoch / T_max)) / 2
894
+ *
895
+ * @example
896
+ * ```ts
897
+ * const scheduler = new CosineAnnealingLR(optimizer, { T_max: 100, etaMin: 0.001 });
898
+ * ```
899
+ *
900
+ * @see {@link https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html | PyTorch CosineAnnealingLR}
901
+ */
902
+ declare class CosineAnnealingLR extends LRScheduler {
903
+ private T_max;
904
+ private etaMin;
905
+ constructor(optimizer: SchedulerOptimizer, options: {
906
+ T_max: number;
907
+ etaMin?: number;
908
+ lastEpoch?: number;
909
+ });
910
+ getLr(): number[];
911
+ }
912
+ /**
913
+ * Multi-step learning rate scheduler.
914
+ *
915
+ * Decays the learning rate by gamma once the epoch reaches one of the milestones.
916
+ *
917
+ * @example
918
+ * ```ts
919
+ * const scheduler = new MultiStepLR(optimizer, { milestones: [30, 80], gamma: 0.1 });
920
+ * // lr = 0.1 for epochs 0-29
921
+ * // lr = 0.01 for epochs 30-79
922
+ * // lr = 0.001 for epochs 80+
923
+ * ```
924
+ *
925
+ * @see {@link https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html | PyTorch MultiStepLR}
926
+ */
927
+ declare class MultiStepLR extends LRScheduler {
928
+ private sortedMilestones;
929
+ private gamma;
930
+ constructor(optimizer: SchedulerOptimizer, options: {
931
+ milestones: number[];
932
+ gamma?: number;
933
+ lastEpoch?: number;
934
+ });
935
+ getLr(): number[];
936
+ }
937
+ /**
938
+ * Linear learning rate scheduler.
939
+ *
940
+ * Linearly interpolates the learning rate multiplicative factor from startFactor
941
+ * to endFactor over totalIters epochs. After totalIters, the factor remains at endFactor.
942
+ *
943
+ * lr = baseLr * (startFactor + (endFactor - startFactor) * epoch / totalIters)
944
+ *
945
+ * @example
946
+ * ```ts
947
+ * const scheduler = new LinearLR(optimizer, {
948
+ * startFactor: 0.1,
949
+ * endFactor: 0.01,
950
+ * totalIters: 100
951
+ * });
952
+ * ```
953
+ *
954
+ * @see {@link https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.LinearLR.html | PyTorch LinearLR}
955
+ */
956
+ declare class LinearLR extends LRScheduler {
957
+ private startFactor;
958
+ private endFactor;
959
+ private totalIters;
960
+ constructor(optimizer: SchedulerOptimizer, options: {
961
+ startFactor?: number;
962
+ endFactor?: number;
963
+ totalIters: number;
964
+ lastEpoch?: number;
965
+ });
966
+ getLr(): number[];
967
+ }
968
+ /**
969
+ * Reduce learning rate on plateau.
970
+ *
971
+ * Reduces learning rate when a metric has stopped improving.
972
+ * This scheduler reads a metric value and if no improvement is seen
973
+ * for 'patience' epochs, the learning rate is reduced.
974
+ *
975
+ * @example
976
+ * ```ts
977
+ * const scheduler = new ReduceLROnPlateau(optimizer, {
978
+ * mode: 'min',
979
+ * factor: 0.1,
980
+ * patience: 10
981
+ * });
982
+ *
983
+ * for (let epoch = 0; epoch < 100; epoch++) {
984
+ * const valLoss = validate();
985
+ * scheduler.step(valLoss);
986
+ * }
987
+ * ```
988
+ *
989
+ * @see {@link https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html | PyTorch ReduceLROnPlateau}
990
+ */
991
+ declare class ReduceLROnPlateau {
992
+ private optimizer;
993
+ private mode;
994
+ private factor;
995
+ private patience;
996
+ private threshold;
997
+ private cooldown;
998
+ private minLr;
999
+ private best;
1000
+ private numBadEpochs;
1001
+ private cooldownCounter;
1002
+ constructor(optimizer: SchedulerOptimizer, options?: {
1003
+ mode?: "min" | "max";
1004
+ factor?: number;
1005
+ patience?: number;
1006
+ threshold?: number;
1007
+ cooldown?: number;
1008
+ minLr?: number;
1009
+ });
1010
+ /**
1011
+ * Check if metric improved.
1012
+ */
1013
+ private isBetter;
1014
+ /**
1015
+ * Perform a scheduler step based on the metric value.
1016
+ *
1017
+ * @param metric - Current value of the metric being monitored
1018
+ */
1019
+ step(metric: number): void;
1020
+ /**
1021
+ * Reduce learning rate for all parameter groups.
1022
+ */
1023
+ private reduceLr;
1024
+ /**
1025
+ * Get the current learning rates for all parameter groups.
1026
+ */
1027
+ getLastLr(): number[];
1028
+ }
1029
+ /**
1030
+ * Warmup scheduler that wraps another scheduler.
1031
+ *
1032
+ * Linearly increases the learning rate from 0 to the base lr over warmupEpochs,
1033
+ * then delegates to the wrapped scheduler.
1034
+ *
1035
+ * @example
1036
+ * ```ts
1037
+ * const baseScheduler = new CosineAnnealingLR(optimizer, { T_max: 100 });
1038
+ * const scheduler = new WarmupLR(optimizer, baseScheduler, { warmupEpochs: 5 });
1039
+ * ```
1040
+ */
1041
+ declare class WarmupLR extends LRScheduler {
1042
+ private warmupEpochs;
1043
+ private afterScheduler;
1044
+ constructor(optimizer: SchedulerOptimizer, afterScheduler: LRScheduler | null, options: {
1045
+ warmupEpochs: number;
1046
+ lastEpoch?: number;
1047
+ });
1048
+ getLr(): number[];
1049
+ step(): void;
1050
+ }
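Putting the wrapper into a loop makes the hand-off visible: for the first `warmupEpochs` calls to `step()` the learning rate ramps up linearly, after which the wrapped cosine schedule takes over. As in the other examples in this file, `model` and `trainOneEpoch` stand in for the usual training code and are not defined here.

```ts
import { SGD, CosineAnnealingLR, WarmupLR } from 'deepbox/optim';

const optimizer = new SGD(model.parameters(), { lr: 0.1 });
const scheduler = new WarmupLR(
  optimizer,
  new CosineAnnealingLR(optimizer, { T_max: 100 }),
  { warmupEpochs: 5 }
);

for (let epoch = 0; epoch < 100; epoch++) {
  trainOneEpoch();                           // zeroGrad / forward / backward / optimizer.step()
  scheduler.step();                          // once per epoch, after the optimizer has stepped
  console.log(epoch, scheduler.getLastLr());
}
```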
1051
+ /**
1052
+ * One-cycle learning rate scheduler.
1053
+ *
1054
+ * Implements the 1cycle policy: lr starts at maxLr/divFactor, increases to maxLr
1055
+ * over pctStart of the training, then decreases to maxLr/finalDivFactor.
1056
+ *
1057
+ * @example
1058
+ * ```ts
1059
+ * const scheduler = new OneCycleLR(optimizer, {
1060
+ * maxLr: 0.1,
1061
+ * totalSteps: 1000,
1062
+ * pctStart: 0.3
1063
+ * });
1064
+ * ```
1065
+ *
1066
+ * @see {@link https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html | PyTorch OneCycleLR}
1067
+ */
1068
+ declare class OneCycleLR extends LRScheduler {
1069
+ private maxLr;
1070
+ private totalSteps;
1071
+ private pctStart;
1072
+ private divFactor;
1073
+ private finalDivFactor;
1074
+ private annealStrategy;
1075
+ constructor(optimizer: SchedulerOptimizer, options: {
1076
+ maxLr: number;
1077
+ totalSteps: number;
1078
+ pctStart?: number;
1079
+ divFactor?: number;
1080
+ finalDivFactor?: number;
1081
+ annealStrategy?: "cos" | "linear";
1082
+ lastEpoch?: number;
1083
+ });
1084
+ getLr(): number[];
1085
+ }
1086
+
1087
+ type index_AdaDelta = AdaDelta;
1088
+ declare const index_AdaDelta: typeof AdaDelta;
1089
+ type index_Adagrad = Adagrad;
1090
+ declare const index_Adagrad: typeof Adagrad;
1091
+ type index_Adam = Adam;
1092
+ declare const index_Adam: typeof Adam;
1093
+ type index_AdamW = AdamW;
1094
+ declare const index_AdamW: typeof AdamW;
1095
+ type index_CosineAnnealingLR = CosineAnnealingLR;
1096
+ declare const index_CosineAnnealingLR: typeof CosineAnnealingLR;
1097
+ type index_ExponentialLR = ExponentialLR;
1098
+ declare const index_ExponentialLR: typeof ExponentialLR;
1099
+ type index_LRScheduler = LRScheduler;
1100
+ declare const index_LRScheduler: typeof LRScheduler;
1101
+ type index_LinearLR = LinearLR;
1102
+ declare const index_LinearLR: typeof LinearLR;
1103
+ type index_MultiStepLR = MultiStepLR;
1104
+ declare const index_MultiStepLR: typeof MultiStepLR;
1105
+ type index_Nadam = Nadam;
1106
+ declare const index_Nadam: typeof Nadam;
1107
+ type index_OneCycleLR = OneCycleLR;
1108
+ declare const index_OneCycleLR: typeof OneCycleLR;
1109
+ type index_Optimizer<Options extends Record<string, unknown>, State extends Record<string, unknown>> = Optimizer<Options, State>;
1110
+ declare const index_Optimizer: typeof Optimizer;
1111
+ type index_ParamGroup<Options extends Record<string, unknown>> = ParamGroup<Options>;
1112
+ type index_RMSprop = RMSprop;
1113
+ declare const index_RMSprop: typeof RMSprop;
1114
+ type index_ReduceLROnPlateau = ReduceLROnPlateau;
1115
+ declare const index_ReduceLROnPlateau: typeof ReduceLROnPlateau;
1116
+ type index_SGD = SGD;
1117
+ declare const index_SGD: typeof SGD;
1118
+ type index_StepLR = StepLR;
1119
+ declare const index_StepLR: typeof StepLR;
1120
+ type index_WarmupLR = WarmupLR;
1121
+ declare const index_WarmupLR: typeof WarmupLR;
1122
+ declare namespace index {
1123
+ export { index_AdaDelta as AdaDelta, index_Adagrad as Adagrad, index_Adam as Adam, index_AdamW as AdamW, index_CosineAnnealingLR as CosineAnnealingLR, index_ExponentialLR as ExponentialLR, index_LRScheduler as LRScheduler, index_LinearLR as LinearLR, index_MultiStepLR as MultiStepLR, index_Nadam as Nadam, index_OneCycleLR as OneCycleLR, index_Optimizer as Optimizer, type index_ParamGroup as ParamGroup, index_RMSprop as RMSprop, index_ReduceLROnPlateau as ReduceLROnPlateau, index_SGD as SGD, index_StepLR as StepLR, index_WarmupLR as WarmupLR };
1124
+ }
1125
+
1126
+ export { AdaDelta as A, CosineAnnealingLR as C, ExponentialLR as E, LinearLR as L, MultiStepLR as M, Nadam as N, Optimizer as O, type ParamGroup as P, RMSprop as R, SGD as S, WarmupLR as W, Adagrad as a, Adam as b, AdamW as c, LRScheduler as d, OneCycleLR as e, ReduceLROnPlateau as f, StepLR as g, index as i };