nshtrainer 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- nshtrainer/_directory.py +11 -28
- nshtrainer/callbacks/__init__.py +6 -0
- nshtrainer/callbacks/base.py +22 -3
- nshtrainer/callbacks/directory_setup.py +15 -8
- nshtrainer/callbacks/distributed_prediction_writer.py +166 -0
- nshtrainer/configs/__init__.py +28 -0
- nshtrainer/configs/callbacks/__init__.py +6 -0
- nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +19 -0
- nshtrainer/configs/optimizer/__init__.py +24 -0
- nshtrainer/configs/trainer/__init__.py +4 -0
- nshtrainer/configs/trainer/_config/__init__.py +4 -0
- nshtrainer/model/base.py +60 -2
- nshtrainer/optimizer.py +559 -1
- nshtrainer/trainer/_config.py +11 -4
- nshtrainer/trainer/trainer.py +21 -2
- {nshtrainer-1.1.2.dist-info → nshtrainer-1.2.1.dist-info}/METADATA +1 -1
- {nshtrainer-1.1.2.dist-info → nshtrainer-1.2.1.dist-info}/RECORD +18 -16
- {nshtrainer-1.1.2.dist-info → nshtrainer-1.2.1.dist-info}/WHEEL +1 -1
nshtrainer/optimizer.py
CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, Tuple, Union
 
 import nshconfig as C
 import torch.nn as nn
+from torch import Tensor
 from torch.optim import Optimizer
 from typing_extensions import TypeAliasType, final, override
 
@@ -45,6 +46,18 @@ class AdamWConfig(OptimizerConfigBase):
     amsgrad: bool = False
     """Whether to use the AMSGrad variant of this algorithm."""
 
+    maximize: bool = False
+    """Maximize the objective with respect to the params, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
     @override
     def create_optimizer(
         self,
@@ -59,6 +72,551 @@ class AdamWConfig(OptimizerConfigBase):
             betas=self.betas,
             eps=self.eps,
             amsgrad=self.amsgrad,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdafactorConfig(OptimizerConfigBase):
+    name: Literal["adafactor"] = "adafactor"
+    lr: float
+    """Learning rate for the optimizer. If None, uses relative step size."""
+
+    eps1: float | None = None
+    """Term added to the denominator to improve numerical stability (default: None)."""
+
+    eps2: float = 1e-3
+    """Term added to the denominator to improve numerical stability (default: 1e-3)."""
+
+    beta2_decay: float = -0.8
+    """Coefficient used for computing running averages of square gradient (default: -0.8)."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) (default: 0.0)."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adafactor
+
+        return Adafactor(
+            parameters,
+            lr=self.lr,
+            eps=(self.eps1, self.eps2),
+            beta2_decay=self.beta2_decay,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdadeltaConfig(OptimizerConfigBase):
+    name: Literal["adadelta"] = "adadelta"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    rho: float = 0.9
+    """Coefficient used for computing a running average of squared gradients."""
+
+    eps: float = 1e-6
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adadelta
+
+        return Adadelta(
+            parameters,
+            lr=self.lr,
+            rho=self.rho,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdagradConfig(OptimizerConfigBase):
+    name: Literal["adagrad"] = "adagrad"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    lr_decay: float = 0.0
+    """Learning rate decay."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    initial_accumulator_value: float = 0.0
+    """Initial value for the accumulator."""
+
+    eps: float = 1e-10
+    """Term added to the denominator to improve numerical stability."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    fused: bool | None = None
+    """Whether the fused implementation is used."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adagrad
+
+        return Adagrad(
+            parameters,
+            lr=self.lr,
+            lr_decay=self.lr_decay,
+            weight_decay=self.weight_decay,
+            initial_accumulator_value=self.initial_accumulator_value,
+            eps=self.eps,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            differentiable=self.differentiable,
+            fused=self.fused,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdamConfig(OptimizerConfigBase):
+    name: Literal["adam"] = "adam"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    amsgrad: bool = False
+    """Whether to use the AMSGrad variant of this algorithm."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    fused: bool | None = None
+    """Whether the fused implementation is used."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adam
+
+        return Adam(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            amsgrad=self.amsgrad,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+            fused=self.fused,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdamaxConfig(OptimizerConfigBase):
+    name: Literal["adamax"] = "adamax"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adamax
+
+        return Adamax(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class ASGDConfig(OptimizerConfigBase):
+    name: Literal["asgd"] = "asgd"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    lambd: float = 1e-4
+    """Decay term."""
+
+    alpha: float = 0.75
+    """Power for eta update."""
+
+    t0: float = 1e6
+    """Point at which to start averaging."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import ASGD
+
+        return ASGD(
+            parameters,
+            lr=self.lr,
+            lambd=self.lambd,
+            alpha=self.alpha,
+            t0=self.t0,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+        )
+
+
+@final
+@optimizer_registry.register
+class NAdamConfig(OptimizerConfigBase):
+    name: Literal["nadam"] = "nadam"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    momentum_decay: float = 4e-3
+    """Momentum decay."""
+
+    decoupled_weight_decay: bool = False
+    """Whether to use decoupled weight decay."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import NAdam
+
+        return NAdam(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            momentum_decay=self.momentum_decay,
+            decoupled_weight_decay=self.decoupled_weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class RAdamConfig(OptimizerConfigBase):
+    name: Literal["radam"] = "radam"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    decoupled_weight_decay: bool = False
+    """Whether to use decoupled weight decay."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import RAdam
+
+        return RAdam(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            decoupled_weight_decay=self.decoupled_weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class RMSpropConfig(OptimizerConfigBase):
+    name: Literal["rmsprop"] = "rmsprop"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    alpha: float = 0.99
+    """Smoothing constant."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    momentum: float = 0.0
+    """Momentum factor."""
+
+    centered: bool = False
+    """If True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import RMSprop
+
+        return RMSprop(
+            parameters,
+            lr=self.lr,
+            alpha=self.alpha,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            momentum=self.momentum,
+            centered=self.centered,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class RpropConfig(OptimizerConfigBase):
+    name: Literal["rprop"] = "rprop"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    etas: tuple[float, float] = (0.5, 1.2)
+    """Pair of (etaminus, etaplus), multiplicative increase and decrease factors."""
+
+    step_sizes: tuple[float, float] = (1e-6, 50.0)
+    """Pair of minimal and maximal allowed step sizes."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Rprop
+
+        return Rprop(
+            parameters,
+            lr=self.lr,
+            etas=self.etas,
+            step_sizes=self.step_sizes,
+            maximize=self.maximize,
+        )
+
+
+@final
+@optimizer_registry.register
+class SGDConfig(OptimizerConfigBase):
+    name: Literal["sgd"] = "sgd"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    momentum: float = 0.0
+    """Momentum factor."""
+
+    dampening: float = 0.0
+    """Dampening for momentum."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    nesterov: bool = False
+    """Enables Nesterov momentum."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    fused: bool | None = None
+    """Whether the fused implementation is used."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import SGD
+
+        return SGD(
+            parameters,
+            lr=self.lr,
+            momentum=self.momentum,
+            dampening=self.dampening,
+            weight_decay=self.weight_decay,
+            nesterov=self.nesterov,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            differentiable=self.differentiable,
+            fused=self.fused,
         )
 
 
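For orientation, here is a minimal usage sketch of one of the newly registered optimizer configs. It relies only on the `create_optimizer(parameters)` interface visible in the diff above; the import path is assumed from the file location (`nshtrainer/optimizer.py`), and the way nshtrainer itself wires these configs into a `LightningModule` may differ.

```python
# Hypothetical sketch, not taken from the package: build a torch optimizer
# from the AdamConfig class added in 1.2.1.
import torch.nn as nn

from nshtrainer.optimizer import AdamConfig  # import path assumed from the diff

model = nn.Linear(16, 1)

# Field names mirror torch.optim.Adam's keyword arguments.
config = AdamConfig(lr=1e-3, weight_decay=0.01, amsgrad=True)

# create_optimizer returns the configured torch.optim.Adam instance.
optimizer = config.create_optimizer(model.parameters())
```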
nshtrainer/trainer/_config.py
CHANGED
@@ -31,6 +31,7 @@ from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
+    DistributedPredictionWriterConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
     NormLoggingCallbackConfig,
@@ -701,6 +702,14 @@ class TrainerConfig(C.Config):
     auto_validate_metrics: MetricValidationCallbackConfig | None = None
     """If enabled, will automatically validate the metrics before starting the training routine."""
 
+    distributed_predict: DistributedPredictionWriterConfig | None = (
+        DistributedPredictionWriterConfig()
+    )
+    """If enabled, will use a custom BasePredictionWriter callback to automatically
+    handle distributed prediction. This is useful for running prediction on multiple GPUs
+    seamlessly.
+    """
+
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
@@ -752,10 +761,7 @@ class TrainerConfig(C.Config):
         )
 
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
-
-        if self.barebones:
-            return
-
+        yield self.directory.setup_callback
         yield self.early_stopping
         yield self.checkpoint_saving
         yield self.lr_monitor
@@ -772,6 +778,7 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield self.auto_validate_metrics
+        yield self.distributed_predict
         yield from self.callbacks
 
     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
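The new `distributed_predict` field defaults to an enabled `DistributedPredictionWriterConfig()`, so the writer callback is yielded unless the field is cleared. A hedged sketch of opting out, assuming the remaining `TrainerConfig` fields have usable defaults (not shown in this diff):

```python
# Hypothetical sketch: disabling the new distributed prediction writer.
# Assumes TrainerConfig can be constructed with defaults for illustration.
from nshtrainer.trainer._config import TrainerConfig

config = TrainerConfig()           # distributed_predict is an enabled writer config by default
config.distributed_predict = None  # clear it; no prediction-writer callback config is yielded
```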
nshtrainer/trainer/trainer.py
CHANGED
@@ -10,12 +10,16 @@ import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
-from lightning.pytorch import LightningModule
+from lightning.pytorch import LightningDataModule, LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.types import
+from lightning.pytorch.utilities.types import (
+    _EVALUATE_OUTPUT,
+    _PREDICT_OUTPUT,
+    EVAL_DATALOADERS,
+)
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
 from .._checkpoint.metadata import write_checkpoint_metadata
@@ -532,3 +536,18 @@ class Trainer(LightningTrainer):
             update_hparams_dict=update_hparams_dict,
         )
         return cls(hparams)
+
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+    ):
+        self.predict(
+            model,
+            dataloaders,
+            datamodule,
+            return_predictions=False,
+            ckpt_path=ckpt_path,
+        )
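Finally, a sketch of how the new `Trainer.distributed_predict` might be called. Per the diff it simply forwards to `Trainer.predict(..., return_predictions=False)`, so predictions are expected to be written out by the new distributed prediction writer callback rather than returned in memory. The module, datamodule, and checkpoint path below are placeholders, not names from the package.

```python
# Hypothetical usage sketch of the new helper on an already-constructed Trainer.
from pathlib import Path


def run_distributed_prediction(trainer, module, datamodule, ckpt: Path) -> None:
    # Returns nothing; each rank's outputs are expected to be handled by the
    # prediction-writer callback configured via TrainerConfig.distributed_predict.
    trainer.distributed_predict(module, datamodule=datamodule, ckpt_path=ckpt)
```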