nshtrainer 1.0.0b12__py3-none-any.whl → 1.0.0b13__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.callbacks import LearningRateMonitor
+
+from .base import CallbackConfigBase
+
+
+class LearningRateMonitorConfig(CallbackConfigBase):
+    logging_interval: Literal["step", "epoch"] | None = None
+    """
+    Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.
+    """
+
+    log_momentum: bool = False
+    """
+    Option to also log the momentum values of the optimizer, if the optimizer has the 'momentum' or 'betas' attribute. Defaults to False.
+    """
+
+    log_weight_decay: bool = False
+    """
+    Option to also log the weight decay values of the optimizer. Defaults to False.
+    """
+
+    def create_callbacks(self, trainer_config):
+        yield LearningRateMonitor(
+            logging_interval=self.logging_interval,
+            log_momentum=self.log_momentum,
+            log_weight_decay=self.log_weight_decay,
+        )
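
For orientation, a minimal usage sketch of the new config on its own. Only the names visible in the hunk above come from this diff; everything else here is illustrative.

from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig

# Log the learning rate every optimizer step, and include momentum/betas
# when the optimizer exposes them.
lr_cfg = LearningRateMonitorConfig(logging_interval="step", log_momentum=True)

# create_callbacks() is a generator; given an existing trainer config it yields
# a configured lightning.pytorch.callbacks.LearningRateMonitor instance.
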
@@ -132,10 +132,8 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         GradientClippingConfig as GradientClippingConfig,
     )
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
     from nshtrainer.trainer._config import (
-        ReproducibilityConfig as ReproducibilityConfig,
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
     )
     from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
     from nshtrainer.util._environment_info import (
@@ -325,6 +323,10 @@ else:
         ).LastCheckpointStrategyConfig
         if name == "LeakyReLUNonlinearityConfig":
             return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LinearWarmupCosineDecayLRSchedulerConfig":
             return importlib.import_module(
                 "nshtrainer.lr_scheduler"
@@ -333,8 +335,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MLPConfig":
             return importlib.import_module("nshtrainer.nn").MLPConfig
         if name == "MetricConfig":
@@ -349,10 +349,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).OnExceptionCheckpointCallbackConfig
-        if name == "OptimizationConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).OptimizationConfig
         if name == "OptimizerConfigBase":
             return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
         if name == "PReLUConfig":
@@ -373,10 +369,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.lr_scheduler"
             ).ReduceLROnPlateauConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -62,6 +62,9 @@ if TYPE_CHECKING:
         CheckpointMetadata as CheckpointMetadata,
     )
     from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
 else:
 
     def __getattr__(name):
@@ -115,6 +118,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.callbacks"
@@ -167,6 +174,7 @@ from . import ema as ema
 from . import finite_checks as finite_checks
 from . import gradient_skipping as gradient_skipping
 from . import log_epoch as log_epoch
+from . import lr_monitor as lr_monitor
 from . import norm_logging as norm_logging
 from . import print_table as print_table
 from . import rlp_sanity_checks as rlp_sanity_checks
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfigBase
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "CallbackConfigBase":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).CallbackConfigBase
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+# Submodule exports
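
The generated module above follows the codegen pattern used throughout nshtrainer.configs: names are only imported for type checkers, while at runtime a module-level __getattr__ (PEP 562) resolves them lazily on first access. A hedged sketch of the observable effect, assuming the installed package:

# Importing the configs subpackage does not eagerly import the callback module;
# the attribute access below is what triggers the lazy import via __getattr__.
from nshtrainer.configs.callbacks import lr_monitor

config_cls = lr_monitor.LearningRateMonitorConfig  # resolved on first access
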
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from nshtrainer.trainer import TrainerConfig as TrainerConfig
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -39,20 +40,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-        ReproducibilityConfig as ReproducibilityConfig,
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "OnExceptionCheckpointCallbackConfig":
+        if name == "NormLoggingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).OnExceptionCheckpointCallbackConfig
-        if name == "OptimizationConfig":
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).OptimizationConfig
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -38,20 +39,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-        ReproducibilityConfig as ReproducibilityConfig,
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "OnExceptionCheckpointCallbackConfig":
+        if name == "NormLoggingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).OnExceptionCheckpointCallbackConfig
-        if name == "OptimizationConfig":
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).OptimizationConfig
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -8,8 +8,6 @@ from typing import Any, Generic, cast
 import nshconfig as C
 import torch
 from lightning.pytorch import LightningDataModule
-from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..model.mixins.callback import CallbackRegistrarModuleMixin
@@ -40,11 +40,13 @@ from ..callbacks import (
     CallbackConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
+    NormLoggingCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
 )
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
+from ..callbacks.lr_monitor import LearningRateMonitorConfig
 from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
 from ..callbacks.shared_parameters import SharedParametersCallbackConfig
 from ..loggers import (
@@ -53,6 +55,7 @@ from ..loggers import (
     TensorboardLoggerConfig,
     WandbLoggerConfig,
 )
+from ..loggers._base import BaseLoggerConfig
 from ..loggers.actsave import ActSaveLoggerConfig
 from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
@@ -61,103 +64,6 @@ from ..util._environment_info import EnvironmentConfig
 log = logging.getLogger(__name__)
 
 
-class LoggingConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable experiment tracking."""
-
-    loggers: Sequence[LoggerConfig] = [
-        WandbLoggerConfig(),
-        CSVLoggerConfig(),
-        TensorboardLoggerConfig(),
-    ]
-    """Loggers to use for experiment tracking."""
-
-    log_lr: bool | Literal["step", "epoch"] = True
-    """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
-    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
-    """If enabled, will log the fractional epoch number to the logger."""
-
-    actsave_logger: ActSaveLoggerConfig | None = None
-    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
-
-    @property
-    def wandb(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, WandbLoggerConfig)
-            ),
-            None,
-        )
-
-    @property
-    def csv(self):
-        return next(
-            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
-            None,
-        )
-
-    @property
-    def tensorboard(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, TensorboardLoggerConfig)
-            ),
-            None,
-        )
-
-    def create_loggers(self, trainer_config: TrainerConfig):
-        """
-        Constructs and returns a list of loggers based on the provided root configuration.
-
-        Args:
-            trainer_config (TrainerConfig): The root configuration object.
-
-        Returns:
-            list[Logger]: A list of constructed loggers.
-        """
-        if not self.enabled:
-            return
-
-        for logger_config in sorted(
-            self.loggers,
-            key=lambda x: x.priority,
-            reverse=True,
-        ):
-            if not logger_config.enabled:
-                continue
-            if (logger := logger_config.create_logger(trainer_config)) is None:
-                continue
-            yield logger
-
-        # If the actsave_metrics is enabled, add the ActSave logger
-        if self.actsave_logger:
-            yield self.actsave_logger.create_logger(trainer_config)
-
-    @override
-    def create_callbacks(self, trainer_config):
-        if self.log_lr:
-            from lightning.pytorch.callbacks import LearningRateMonitor
-
-            logging_interval: str | None = None
-            if isinstance(self.log_lr, str):
-                logging_interval = self.log_lr
-
-            yield LearningRateMonitor(logging_interval=logging_interval)
-
-        if self.log_epoch:
-            yield from self.log_epoch.create_callbacks(trainer_config)
-
-        for logger in self.loggers:
-            if not logger or not isinstance(logger, CallbackConfigBase):
-                continue
-
-            yield from logger.create_callbacks(trainer_config)
-
-
 class GradientClippingConfig(C.Config):
     enabled: bool = True
     """Enable gradient clipping."""
@@ -167,32 +73,6 @@ class GradientClippingConfig(C.Config):
     """Norm type to use for gradient clipping."""
 
 
-class OptimizationConfig(CallbackConfigBase):
-    log_grad_norm: bool | str | float = False
-    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
-    log_grad_norm_per_param: bool | str | float = False
-    """If enabled, will log the gradient norm for each model parameter to the logger."""
-
-    log_param_norm: bool | str | float = False
-    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
-    log_param_norm_per_param: bool | str | float = False
-    """If enabled, will log the parameter norm for each model parameter to the logger."""
-
-    gradient_clipping: GradientClippingConfig | None = None
-    """Gradient clipping configuration, or None to disable gradient clipping."""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        from ..callbacks.norm_logging import NormLoggingCallbackConfig
-
-        yield from NormLoggingCallbackConfig(
-            log_grad_norm=self.log_grad_norm,
-            log_grad_norm_per_param=self.log_grad_norm_per_param,
-            log_param_norm=self.log_param_norm,
-            log_param_norm_per_param=self.log_param_norm_per_param,
-        ).create_callbacks(trainer_config)
-
-
 TPlugin = TypeVar(
     "TPlugin",
     Precision,
@@ -252,15 +132,6 @@ StrategyLiteral: TypeAlias = Literal[
 ]
 
 
-class ReproducibilityConfig(C.Config):
-    deterministic: bool | Literal["warn"] | None = None
-    """
-    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
-    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
-    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
-    """
-
-
 CheckpointCallbackConfig: TypeAlias = Annotated[
     BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
@@ -634,14 +505,34 @@ class TrainerConfig(C.Config):
     hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
     """Hugging Face Hub configuration options."""
 
-    logging: LoggingConfig = LoggingConfig()
-    """Logging/experiment tracking (e.g., WandB) configuration options."""
+    loggers: Sequence[LoggerConfig] = [
+        WandbLoggerConfig(),
+        CSVLoggerConfig(),
+        TensorboardLoggerConfig(),
+    ]
+    """Loggers to use for experiment tracking."""
 
-    optimizer: OptimizationConfig = OptimizationConfig()
-    """Optimization configuration options."""
+    actsave_logger: ActSaveLoggerConfig | None = None
+    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
 
-    reproducibility: ReproducibilityConfig = ReproducibilityConfig()
-    """Reproducibility configuration options."""
+    lr_monitor: LearningRateMonitorConfig | None = LearningRateMonitorConfig()
+    """Learning rate monitoring configuration options."""
+
+    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
+    """If enabled, will log the fractional epoch number to the logger."""
+
+    gradient_clipping: GradientClippingConfig | None = None
+    """Gradient clipping configuration, or None to disable gradient clipping."""
+
+    log_norms: NormLoggingCallbackConfig | None = None
+    """Norm logging configuration options."""
+
+    deterministic: bool | Literal["warn"] | None = None
+    """
+    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
+    """
 
     reduce_lr_on_plateau_sanity_checking: RLPSanityChecksCallbackConfig | None = (
         RLPSanityChecksCallbackConfig()
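
Taken together with the removals above, the three nested sub-configs collapse into flat TrainerConfig fields. A minimal migration sketch; field names come from this diff, the `config` variable is an existing TrainerConfig instance, and nothing else is implied:

from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig
from nshtrainer.callbacks.norm_logging import NormLoggingCallbackConfig

# 1.0.0b12 (nested sub-configs, removed above)
config.logging.log_lr = "step"
config.optimizer.log_grad_norm = True
config.reproducibility.deterministic = "warn"

# 1.0.0b13 (flat fields introduced in this hunk)
config.lr_monitor = LearningRateMonitorConfig(logging_interval="step")
config.log_norms = NormLoggingCallbackConfig(log_grad_norm=True)
config.deterministic = "warn"
# Gradient clipping moves from config.optimizer.gradient_clipping to
# config.gradient_clipping with the same GradientClippingConfig type.
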
@@ -856,27 +747,87 @@ class TrainerConfig(C.Config):
     set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
 
+    @property
+    def wandb_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, WandbLoggerConfig)
+            ),
+            None,
+        )
+
+    @property
+    def csv_logger(self):
+        return next(
+            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
+            None,
+        )
+
+    @property
+    def tensorboard_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, TensorboardLoggerConfig)
+            ),
+            None,
+        )
+
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
         yield self.early_stopping
         yield self.checkpoint_saving
-        yield self.logging
-        yield self.optimizer
+        yield self.lr_monitor
+        yield from (
+            logger_config
+            for logger_config in self.loggers
+            if logger_config is not None
+            and isinstance(logger_config, CallbackConfigBase)
+        )
+        yield self.log_epoch
+        yield self.log_norms
         yield self.hf_hub
         yield self.shared_parameters
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield from self.callbacks
 
+    def _nshtrainer_all_logger_configs(self) -> Iterable[BaseLoggerConfig | None]:
+        yield from self.loggers
+        yield self.actsave_logger
+
     # region Helper Methods
+    def fast_dev_run_(self, value: int | bool = True, /):
+        """
+        Enables fast_dev_run mode for the trainer.
+        This will run the training loop for a specified number of batches,
+        if an integer is provided, or for a single batch if True is provided.
+        """
+        self.fast_dev_run = value
+        return self
+
     def with_fast_dev_run(self, value: int | bool = True, /):
         """
         Enables fast_dev_run mode for the trainer.
         This will run the training loop for a specified number of batches,
         if an integer is provided, or for a single batch if True is provided.
         """
-        config = copy.deepcopy(self)
-        config.fast_dev_run = value
-        return config
+        return copy.deepcopy(self).fast_dev_run_(value)
+
+    def project_root_(self, project_root: str | Path | os.PathLike):
+        """
+        Set the project root directory for the trainer.
+
+        Args:
+            project_root (Path): The base directory to use.
+
+        Returns:
+            self: The current instance of the class.
+        """
+        self.directory.project_root = Path(project_root)
+        return self
 
     def with_project_root(self, project_root: str | Path | os.PathLike):
         """
@@ -888,9 +839,7 @@ class TrainerConfig(C.Config):
         Returns:
             self: The current instance of the class.
         """
-        config = copy.deepcopy(self)
-        config.directory.project_root = Path(project_root)
-        return config
+        return copy.deepcopy(self).project_root_(project_root)
 
     def reset_run(
         self,
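
The helper refactor above splits each helper into an in-place variant (trailing underscore, mutates and returns self) and a copying variant (with_ prefix, deep-copies first). A small sketch of the resulting usage, assuming an existing TrainerConfig instance named config; the path is illustrative only:

# In-place: mutates `config` and returns it, so calls can be chained.
config.fast_dev_run_(8).project_root_("/tmp/nshtrainer-demo")

# Copying: leaves `config` untouched and returns a modified deep copy.
debug_config = config.with_fast_dev_run(True)
assert debug_config is not config
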
@@ -70,7 +70,7 @@ class Trainer(LightningTrainer):
         kwargs_ctor: LightningTrainerKwargs,
     ):
         kwargs: LightningTrainerKwargs = {
-            "deterministic": hparams.reproducibility.deterministic,
+            "deterministic": hparams.deterministic,
             "fast_dev_run": hparams.fast_dev_run,
             "max_epochs": hparams.max_epochs,
             "min_epochs": hparams.min_epochs,
@@ -209,7 +209,7 @@ class Trainer(LightningTrainer):
             _update_kwargs(detect_anomaly=detect_anomaly)
 
         if (
-            grad_clip_config := hparams.optimizer.gradient_clipping
+            grad_clip_config := hparams.gradient_clipping
         ) is not None and grad_clip_config.enabled:
             # kwargs["gradient_clip_algorithm"] = grad_clip_config.algorithm
             # kwargs["gradient_clip_val"] = grad_clip_config.value
@@ -239,17 +239,14 @@ class Trainer(LightningTrainer):
             ]
         )
 
-        if not hparams.logging.enabled:
-            log.critical(f"Disabling logger because {hparams.logging.enabled=}.")
-            kwargs["logger"] = False
-        else:
-            _update_kwargs(
-                logger=[
-                    logger
-                    for logger in hparams.logging.create_loggers(hparams)
-                    if logger is not None
-                ]
-            )
+        _update_kwargs(
+            logger=[
+                logger
+                for logger_config in hparams._nshtrainer_all_logger_configs()
+                if logger_config is not None
+                and (logger := logger_config.create_logger(hparams)) is not None
+            ]
+        )
 
         if hparams.auto_determine_num_nodes:
             # When num_nodes is auto, we need to detect the number of nodes.
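
With LoggingConfig.enabled gone, the trainer now always builds its logger list from _nshtrainer_all_logger_configs(). The diff shows no dedicated kill switch, so presumably (an assumption, not stated in this diff) disabling experiment tracking is done by emptying the new fields:

# Hedged sketch: turn off all experiment tracking under the flattened config.
config.loggers = []          # replaces the old `config.logging.enabled = False`
config.actsave_logger = None
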
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 1.0.0b12
+Version: 1.0.0b13
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -22,6 +22,7 @@ nshtrainer/callbacks/finite_checks.py,sha256=iCiKQ5i9RckkzcPeCHzC3hkg3AlW3ESuWtF
 nshtrainer/callbacks/gradient_skipping.py,sha256=k5qNaNeileZ_5YFad4ssfLplMxMKeKFhPcY8-QVmLek,3464
 nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
 nshtrainer/callbacks/log_epoch.py,sha256=Wr-Ksxsynsqu_zyB_zoiPLjnWv-ksC3xPekY6iyN-P8,1396
+nshtrainer/callbacks/lr_monitor.py,sha256=IyFZoXaxJoTBSkdLu1iEZ1qI8_UFNJwafR_xTVPZXXU,1050
 nshtrainer/callbacks/norm_logging.py,sha256=C44Mvt73gqQEpCFd0j3qYg6NY7sL2jm3X1qJVY_XLfI,6329
 nshtrainer/callbacks/print_table.py,sha256=WIgfzVSfAfS3_8kUuX-nWJOGWBEmtNlejypuoJQViPY,2884
 nshtrainer/callbacks/rlp_sanity_checks.py,sha256=kWl2dYOXn2L8k6ub_012jNkqOxtyea1yr1qWRNG6UW4,9990
@@ -29,13 +30,13 @@ nshtrainer/callbacks/shared_parameters.py,sha256=33eRzifNj6reKbvmGuam1hUofo3sD4J
 nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
 nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
 nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
-nshtrainer/configs/__init__.py,sha256=0TczWa5OFRKOGKHgabeB7VUMxPpD0RCgDR6AvdAD-tI,22721
+nshtrainer/configs/__init__.py,sha256=Vyf_gn7u3s9ET4Yszf6SILtqvpIGiJ4X5RJfmW-FK6I,22293
 nshtrainer/configs/_checkpoint/__init__.py,sha256=vuiBbd4VzCo7lRyhyTUArEQeWwJkewvNPKDxBJiUHoY,2719
 nshtrainer/configs/_checkpoint/loader/__init__.py,sha256=hdLpypoEkES1MTaTHAdGFJnSoZzgx_8NzAKbK143SyI,2399
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=X9KxpcoHQbJp6-MTGvp4pct-MYHaHcl82s9yqZ5KiSk,867
 nshtrainer/configs/_directory/__init__.py,sha256=mTUoSz-DSsvI2M98cqu2Z2x215oM0sLyljh_5rVexvQ,1029
 nshtrainer/configs/_hf_hub/__init__.py,sha256=3HGCGhRb7NhOuLeskGqbYNuS9c81oOUbX6ibyF3XiCY,1063
-nshtrainer/configs/callbacks/__init__.py,sha256=-dHN8NZdCaNUy_isnlh779FZ1w9_WkOkv6VSN_-86jM,7316
+nshtrainer/configs/callbacks/__init__.py,sha256=xgCa98EmqU8cHxlJa-64Cc4c_0fS0Cz2iVac4edL_yc,7657
 nshtrainer/configs/callbacks/actsave/__init__.py,sha256=AkVWS9vCcDJFpPUpyc7i9cjaFZU2kKxDyFDqakMZA-E,809
 nshtrainer/configs/callbacks/base/__init__.py,sha256=OdtHDMkYC_ioCEAkg7bSQi3o7e2t5WHPcFjavXdfdTA,602
 nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=DE-JTtY4NdaP3mgWduearFYMvy1tswWRBWMde06RzQc,2700
@@ -50,6 +51,7 @@ nshtrainer/configs/callbacks/ema/__init__.py,sha256=KlPGdJWjYTKLdpl-VnN4BYY2sA_L
 nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=V6Owp05XdIk3EO67AMVGdwbT4-D86QRuvqWM2gu5Xpw,949
 nshtrainer/configs/callbacks/gradient_skipping/__init__.py,sha256=RhIfJFq-x_sWYrWVGaVEBeT8uUFYjFgt0Ug8pPgpJSg,981
 nshtrainer/configs/callbacks/log_epoch/__init__.py,sha256=4jePzjE3bVxaI7hQrcWW5SrKT5MrFyplJZwK8bQHbGI,900
+nshtrainer/configs/callbacks/lr_monitor/__init__.py,sha256=iC8U0oWC75JzPUMRoGWkC8WkMuLbF9-zuN_yQlByycY,916
 nshtrainer/configs/callbacks/norm_logging/__init__.py,sha256=8M9hcGpEuQadfgTR4-YL4TWeyxZjg0s84x420B03-aE,941
 nshtrainer/configs/callbacks/print_table/__init__.py,sha256=Ni47iS2mIzwGu8XuHfUY5BJKawUO_2TyJMZ62QBpEW0,961
 nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py,sha256=OgPywk8Z9y_dnq_liH2PPWuQSpUlQ_Q2-q99HDN9Leg,977
@@ -78,8 +80,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=smDYCplrI5B38XJcNZ462ZeTo9l
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=VCCvbzhEeOcdZ0Unvk_anAcmbQuGogTQhK_bXs5RG9U,892
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=hJ90ym5ElI-BY_XS3VSLjcgQWfV0Pp1MdzTU6Qi8MFg,884
 nshtrainer/configs/profiler/simple/__init__.py,sha256=18V64kKYrJeSCrPmY3wYnshEISaf7xmrfw2Ny-6P3uE,859
-nshtrainer/configs/trainer/__init__.py,sha256=Gf5RizrVL84NWfHnagVCRHtXiD6x0UA4N7vhYypluTk,7916
-nshtrainer/configs/trainer/_config/__init__.py,sha256=BeQ5t_9d6rx6SbSu4ZqD9eitLCQRpTOVOnSxT0LCrlM,7806
+nshtrainer/configs/trainer/__init__.py,sha256=QLCDVxVg1Ig-wgUW5r8I1FdPdbYz9-gse17s3R69Fw0,8019
+nshtrainer/configs/trainer/_config/__init__.py,sha256=YEcliai8jLOoB53lxT5BZIR3NzfoLb4x3VGbMakFVo4,7909
 nshtrainer/configs/trainer/checkpoint_connector/__init__.py,sha256=pSu79zOFFWvqjI3SkHWl13H8ZNJFTc6a5J1r2KnfUKM,667
 nshtrainer/configs/trainer/trainer/__init__.py,sha256=P-Y2DOZZcJtEdPjGEKCxq5R3JSzKhUUoidkSvO_cfKI,797
 nshtrainer/configs/util/__init__.py,sha256=ZcmEqg2OWKKcPBqzDG1SnuaAMgR4_c0jog-Xg6QTUzc,4555
@@ -89,7 +91,7 @@ nshtrainer/configs/util/config/dtype/__init__.py,sha256=NCXMVO-EUz3JvPmlDci72O9Z
 nshtrainer/configs/util/config/duration/__init__.py,sha256=8llT1MCKQpsdNldN5h5Wo0GjUuRn28Sxw2FTXTNKBpM,1060
 nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
 nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
-nshtrainer/data/datamodule.py,sha256=5nh6LwjO70M5m0WgYegB6gtLnu61HOOn17SFJHEsMcE,4271
+nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
 nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
 nshtrainer/loggers/__init__.py,sha256=11X6D_lF0cHAkxkYsVZY3Q3r30Fq0FUi9heeb5RD870,570
 nshtrainer/loggers/_base.py,sha256=nw4AZzJP3Z-fljgQlgq7FkuMkPmYKTsXj7OfJJSmtXI,811
@@ -120,11 +122,11 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
-nshtrainer/trainer/_config.py,sha256=WnI-08MIwm-uo4ak-44427asEMU3Kx8rIqNiu3z0Eck,34213
+nshtrainer/trainer/_config.py,sha256=2S6Qhwn724n_jgGhWVI64Wi_pHKjU1ggoY4sxq-_SlA,32309
 nshtrainer/trainer/_runtime_callback.py,sha256=T3epaj1YeIN0R8CS2cg5HNJIB21TyaD_PVNNOPJ6nJs,4200
 nshtrainer/trainer/checkpoint_connector.py,sha256=pC1tTDcq0p6sAsoTmAbwINW49IfqupMMtnE9-AKdTUw,2824
 nshtrainer/trainer/signal_connector.py,sha256=YMJf6vTnW0JcnBkuYikm9x_9XscaokrCEzCn4THOGao,10776
-nshtrainer/trainer/trainer.py,sha256=Yp2RSdY7_4Sw_pDG5YRhgPwGMFwMuukUYWuYopVAJ3s,19890
+nshtrainer/trainer/trainer.py,sha256=kIXh_25jDJSGcwEyLjbvqWN0P5B35VBJLXOwXqUGqF4,19759
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
@@ -137,6 +139,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b12.dist-info/METADATA,sha256=h_Td_f7pRokAYJnVvJiTpWpdDwwoBDpHvYCHDy0A3bc,937
-nshtrainer-1.0.0b12.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-1.0.0b12.dist-info/RECORD,,
+nshtrainer-1.0.0b13.dist-info/METADATA,sha256=9PQNipTw68KmSV_7Kt4fK_KtlYKSaKBcvvkBZrwWFtY,937
+nshtrainer-1.0.0b13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-1.0.0b13.dist-info/RECORD,,