nshtrainer 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
nshtrainer/optimizer.py CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, Tuple, Union
 
 import nshconfig as C
 import torch.nn as nn
+from torch import Tensor
 from torch.optim import Optimizer
 from typing_extensions import TypeAliasType, final, override
 
@@ -45,6 +46,18 @@ class AdamWConfig(OptimizerConfigBase):
     amsgrad: bool = False
     """Whether to use the AMSGrad variant of this algorithm."""
 
+    maximize: bool = False
+    """Maximize the objective with respect to the params, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
     @override
     def create_optimizer(
         self,
@@ -59,6 +72,551 @@ class AdamWConfig(OptimizerConfigBase):
             betas=self.betas,
             eps=self.eps,
             amsgrad=self.amsgrad,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdafactorConfig(OptimizerConfigBase):
+    name: Literal["adafactor"] = "adafactor"
+    lr: float
+    """Learning rate for the optimizer. If None, uses relative step size."""
+
+    eps1: float | None = None
+    """Term added to the denominator to improve numerical stability (default: None)."""
+
+    eps2: float = 1e-3
+    """Term added to the denominator to improve numerical stability (default: 1e-3)."""
+
+    beta2_decay: float = -0.8
+    """Coefficient used for computing running averages of square gradient (default: -0.8)."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) (default: 0.0)."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adafactor
+
+        return Adafactor(
+            parameters,
+            lr=self.lr,
+            eps=(self.eps1, self.eps2),
+            beta2_decay=self.beta2_decay,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdadeltaConfig(OptimizerConfigBase):
+    name: Literal["adadelta"] = "adadelta"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    rho: float = 0.9
+    """Coefficient used for computing a running average of squared gradients."""
+
+    eps: float = 1e-6
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adadelta
+
+        return Adadelta(
+            parameters,
+            lr=self.lr,
+            rho=self.rho,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdagradConfig(OptimizerConfigBase):
+    name: Literal["adagrad"] = "adagrad"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    lr_decay: float = 0.0
+    """Learning rate decay."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    initial_accumulator_value: float = 0.0
+    """Initial value for the accumulator."""
+
+    eps: float = 1e-10
+    """Term added to the denominator to improve numerical stability."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    fused: bool | None = None
+    """Whether the fused implementation is used."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adagrad
+
+        return Adagrad(
+            parameters,
+            lr=self.lr,
+            lr_decay=self.lr_decay,
+            weight_decay=self.weight_decay,
+            initial_accumulator_value=self.initial_accumulator_value,
+            eps=self.eps,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            differentiable=self.differentiable,
+            fused=self.fused,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdamConfig(OptimizerConfigBase):
+    name: Literal["adam"] = "adam"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    amsgrad: bool = False
+    """Whether to use the AMSGrad variant of this algorithm."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    fused: bool | None = None
+    """Whether the fused implementation is used."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adam
+
+        return Adam(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            amsgrad=self.amsgrad,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+            fused=self.fused,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdamaxConfig(OptimizerConfigBase):
+    name: Literal["adamax"] = "adamax"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adamax
+
+        return Adamax(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class ASGDConfig(OptimizerConfigBase):
+    name: Literal["asgd"] = "asgd"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    lambd: float = 1e-4
+    """Decay term."""
+
+    alpha: float = 0.75
+    """Power for eta update."""
+
+    t0: float = 1e6
+    """Point at which to start averaging."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import ASGD
+
+        return ASGD(
+            parameters,
+            lr=self.lr,
+            lambd=self.lambd,
+            alpha=self.alpha,
+            t0=self.t0,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+        )
+
+
+@final
+@optimizer_registry.register
+class NAdamConfig(OptimizerConfigBase):
+    name: Literal["nadam"] = "nadam"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    momentum_decay: float = 4e-3
+    """Momentum decay."""
+
+    decoupled_weight_decay: bool = False
+    """Whether to use decoupled weight decay."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import NAdam
+
+        return NAdam(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            momentum_decay=self.momentum_decay,
+            decoupled_weight_decay=self.decoupled_weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class RAdamConfig(OptimizerConfigBase):
+    name: Literal["radam"] = "radam"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    decoupled_weight_decay: bool = False
+    """Whether to use decoupled weight decay."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import RAdam
+
+        return RAdam(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            decoupled_weight_decay=self.decoupled_weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class RMSpropConfig(OptimizerConfigBase):
+    name: Literal["rmsprop"] = "rmsprop"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    alpha: float = 0.99
+    """Smoothing constant."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    momentum: float = 0.0
+    """Momentum factor."""
+
+    centered: bool = False
+    """If True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import RMSprop
+
+        return RMSprop(
+            parameters,
+            lr=self.lr,
+            alpha=self.alpha,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            momentum=self.momentum,
+            centered=self.centered,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class RpropConfig(OptimizerConfigBase):
+    name: Literal["rprop"] = "rprop"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    etas: tuple[float, float] = (0.5, 1.2)
+    """Pair of (etaminus, etaplus), multiplicative increase and decrease factors."""
+
+    step_sizes: tuple[float, float] = (1e-6, 50.0)
+    """Pair of minimal and maximal allowed step sizes."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Rprop
+
+        return Rprop(
+            parameters,
+            lr=self.lr,
+            etas=self.etas,
+            step_sizes=self.step_sizes,
+            maximize=self.maximize,
+        )
+
+
+@final
+@optimizer_registry.register
+class SGDConfig(OptimizerConfigBase):
+    name: Literal["sgd"] = "sgd"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    momentum: float = 0.0
+    """Momentum factor."""
+
+    dampening: float = 0.0
+    """Dampening for momentum."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    nesterov: bool = False
+    """Enables Nesterov momentum."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    fused: bool | None = None
+    """Whether the fused implementation is used."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import SGD
+
+        return SGD(
+            parameters,
+            lr=self.lr,
+            momentum=self.momentum,
+            dampening=self.dampening,
+            weight_decay=self.weight_decay,
+            nesterov=self.nesterov,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            differentiable=self.differentiable,
+            fused=self.fused,
         )
 
 
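The hunk above adds config classes for most of torch.optim's built-in optimizers (Adafactor, Adadelta, Adagrad, Adam, Adamax, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD), each deferring the torch.optim import until create_optimizer is called. A minimal usage sketch, assuming AdamConfig is importable from nshtrainer.optimizer and constructed with keywords like any nshconfig.Config (the sketch itself is not part of the wheel):

# Minimal sketch: build a torch optimizer from one of the new config classes.
# Assumed import path; only `lr` is required, the other fields keep their defaults.
import torch.nn as nn
from nshtrainer.optimizer import AdamConfig

model = nn.Linear(16, 4)
config = AdamConfig(lr=1e-3, weight_decay=0.01)
optimizer = config.create_optimizer(model.parameters())  # returns torch.optim.Adam

Keeping the torch.optim imports inside create_optimizer means importing the config module stays cheap even for optimizers that only exist in newer torch releases (e.g. Adafactor).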
@@ -31,6 +31,7 @@ from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
+    DistributedPredictionWriterConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
     NormLoggingCallbackConfig,
@@ -701,6 +702,14 @@ class TrainerConfig(C.Config):
     auto_validate_metrics: MetricValidationCallbackConfig | None = None
     """If enabled, will automatically validate the metrics before starting the training routine."""
 
+    distributed_predict: DistributedPredictionWriterConfig | None = (
+        DistributedPredictionWriterConfig()
+    )
+    """If enabled, will use a custom BasePredictionWriter callback to automatically
+    handle distributed prediction. This is useful for running prediction on multiple GPUs
+    seamlessly.
+    """
+
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
@@ -752,10 +761,7 @@ class TrainerConfig(C.Config):
         )
 
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
-        # Disable all callbacks if barebones mode is enabled
-        if self.barebones:
-            return
-
+        yield self.directory.setup_callback
         yield self.early_stopping
         yield self.checkpoint_saving
         yield self.lr_monitor
@@ -772,6 +778,7 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield self.auto_validate_metrics
+        yield self.distributed_predict
        yield from self.callbacks
 
     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
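With these hunks, TrainerConfig enables a distributed prediction writer by default, and _nshtrainer_all_callback_configs yields it (and the directory setup callback) unconditionally now that the barebones early-return has been removed. A minimal sketch of toggling the new field; the nshtrainer.callbacks import follows the diff, while the top-level TrainerConfig export and usable defaults for its other fields are assumptions, and nothing below ships in the wheel:

# Sketch only: enable (default) or disable the distributed prediction writer.
from nshtrainer import TrainerConfig  # assumed top-level export
from nshtrainer.callbacks import DistributedPredictionWriterConfig

config = TrainerConfig()  # distributed_predict defaults to DistributedPredictionWriterConfig()
config_explicit = TrainerConfig(distributed_predict=DistributedPredictionWriterConfig())
config_disabled = TrainerConfig(distributed_predict=None)  # no prediction-writer callback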
@@ -10,12 +10,16 @@ import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
-from lightning.pytorch import LightningModule
+from lightning.pytorch import LightningDataModule, LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
+from lightning.pytorch.utilities.types import (
+    _EVALUATE_OUTPUT,
+    _PREDICT_OUTPUT,
+    EVAL_DATALOADERS,
+)
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
 from .._checkpoint.metadata import write_checkpoint_metadata
@@ -532,3 +536,18 @@ class Trainer(LightningTrainer):
             update_hparams_dict=update_hparams_dict,
         )
         return cls(hparams)
+
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+    ):
+        self.predict(
+            model,
+            dataloaders,
+            datamodule,
+            return_predictions=False,
+            ckpt_path=ckpt_path,
+        )
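The new Trainer.distributed_predict is a thin wrapper around predict with return_predictions=False, so predictions are not accumulated in memory and output handling is left to the prediction-writer callback configured via TrainerConfig.distributed_predict. A rough sketch of how it might be called; the Trainer(config) construction is assumed from the `cls(hparams)` pattern above, and the model and datamodule are caller-supplied:

# Sketch only: run multi-GPU prediction through the new helper.
import lightning.pytorch as pl
from nshtrainer import Trainer, TrainerConfig  # assumed top-level exports

def run_distributed_prediction(
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    ckpt_path: str | None = None,
):
    config = TrainerConfig()   # distributed_predict is enabled by default
    trainer = Trainer(config)  # assumed constructor, mirroring `cls(hparams)` above
    # Returns nothing; each rank's predictions are written by the callback.
    trainer.distributed_predict(model, datamodule=datamodule, ckpt_path=ckpt_path)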
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.1.2
+Version: 1.2.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com