nshtrainer 1.0.0b11__py3-none-any.whl → 1.0.0b13__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.callbacks import LearningRateMonitor
+
+from .base import CallbackConfigBase
+
+
+class LearningRateMonitorConfig(CallbackConfigBase):
+    logging_interval: Literal["step", "epoch"] | None = None
+    """
+    Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.
+    """
+
+    log_momentum: bool = False
+    """
+    Option to also log the momentum values of the optimizer, if the optimizer has the 'momentum' or 'betas' attribute. Defaults to False.
+    """
+
+    log_weight_decay: bool = False
+    """
+    Option to also log the weight decay values of the optimizer. Defaults to False.
+    """
+
+    def create_callbacks(self, trainer_config):
+        yield LearningRateMonitor(
+            logging_interval=self.logging_interval,
+            log_momentum=self.log_momentum,
+            log_weight_decay=self.log_weight_decay,
+        )
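The hunk above adds the new LearningRateMonitor callback config (the RECORD section further down lists it as nshtrainer/callbacks/lr_monitor.py). A minimal usage sketch, assuming an existing TrainerConfig instance named trainer_config; everything else follows the class shown above:

    from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig

    # Log the learning rate on every optimizer step and also track momentum.
    lr_cfg = LearningRateMonitorConfig(logging_interval="step", log_momentum=True)

    # create_callbacks is a generator; it yields the underlying
    # lightning.pytorch LearningRateMonitor callback instance.
    callbacks = list(lr_cfg.create_callbacks(trainer_config))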
@@ -132,10 +132,8 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         GradientClippingConfig as GradientClippingConfig,
     )
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
     from nshtrainer.trainer._config import (
-        ReproducibilityConfig as ReproducibilityConfig,
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
     )
     from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
     from nshtrainer.util._environment_info import (
@@ -325,6 +323,10 @@ else:
         ).LastCheckpointStrategyConfig
     if name == "LeakyReLUNonlinearityConfig":
         return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
+    if name == "LearningRateMonitorConfig":
+        return importlib.import_module(
+            "nshtrainer.trainer._config"
+        ).LearningRateMonitorConfig
     if name == "LinearWarmupCosineDecayLRSchedulerConfig":
         return importlib.import_module(
             "nshtrainer.lr_scheduler"
@@ -333,8 +335,6 @@ else:
         return importlib.import_module(
             "nshtrainer.callbacks"
         ).LogEpochCallbackConfig
-    if name == "LoggingConfig":
-        return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
     if name == "MLPConfig":
         return importlib.import_module("nshtrainer.nn").MLPConfig
     if name == "MetricConfig":
@@ -349,10 +349,6 @@ else:
         return importlib.import_module(
             "nshtrainer.callbacks"
         ).OnExceptionCheckpointCallbackConfig
-    if name == "OptimizationConfig":
-        return importlib.import_module(
-            "nshtrainer.trainer._config"
-        ).OptimizationConfig
     if name == "OptimizerConfigBase":
         return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
     if name == "PReLUConfig":
@@ -373,10 +369,6 @@ else:
         return importlib.import_module(
             "nshtrainer.lr_scheduler"
         ).ReduceLROnPlateauConfig
-    if name == "ReproducibilityConfig":
-        return importlib.import_module(
-            "nshtrainer.trainer._config"
-        ).ReproducibilityConfig
     if name == "SanityCheckingConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
@@ -62,6 +62,9 @@ if TYPE_CHECKING:
         CheckpointMetadata as CheckpointMetadata,
     )
     from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
 else:
 
     def __getattr__(name):
@@ -115,6 +118,10 @@ else:
         return importlib.import_module(
             "nshtrainer.callbacks"
         ).LastCheckpointCallbackConfig
+    if name == "LearningRateMonitorConfig":
+        return importlib.import_module(
+            "nshtrainer.callbacks.lr_monitor"
+        ).LearningRateMonitorConfig
     if name == "LogEpochCallbackConfig":
         return importlib.import_module(
             "nshtrainer.callbacks"
@@ -167,6 +174,7 @@ from . import ema as ema
 from . import finite_checks as finite_checks
 from . import gradient_skipping as gradient_skipping
 from . import log_epoch as log_epoch
+from . import lr_monitor as lr_monitor
 from . import norm_logging as norm_logging
 from . import print_table as print_table
 from . import rlp_sanity_checks as rlp_sanity_checks
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfigBase
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "CallbackConfigBase":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).CallbackConfigBase
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+# Submodule exports
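Both new config modules rely on the same lazy re-export pattern (a module-level __getattr__, per PEP 562): type checkers see plain imports under TYPE_CHECKING, while at runtime the target module is imported only on first attribute access. A stripped-down sketch of the pattern as used above:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Static analyzers resolve the real symbol eagerly.
        from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig
    else:

        def __getattr__(name):
            # Resolved lazily, only when the attribute is first requested.
            import importlib

            if name == "LearningRateMonitorConfig":
                return importlib.import_module(
                    "nshtrainer.callbacks.lr_monitor"
                ).LearningRateMonitorConfig
            raise AttributeError(f"module {__name__!r} has no attribute {name!r}")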
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from nshtrainer.trainer import TrainerConfig as TrainerConfig
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -39,20 +40,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-        ReproducibilityConfig as ReproducibilityConfig,
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
         return importlib.import_module(
             "nshtrainer.trainer._config"
         ).ActSaveLoggerConfig
+    if name == "BaseLoggerConfig":
+        return importlib.import_module(
+            "nshtrainer.trainer._config"
+        ).BaseLoggerConfig
     if name == "BestCheckpointCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
         return importlib.import_module(
             "nshtrainer.trainer._config"
         ).LastCheckpointCallbackConfig
+    if name == "LearningRateMonitorConfig":
+        return importlib.import_module(
+            "nshtrainer.trainer._config"
+        ).LearningRateMonitorConfig
     if name == "LogEpochCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
         ).LogEpochCallbackConfig
-    if name == "LoggingConfig":
-        return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
     if name == "MetricConfig":
         return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-    if name == "OnExceptionCheckpointCallbackConfig":
+    if name == "NormLoggingCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
-        ).OnExceptionCheckpointCallbackConfig
-    if name == "OptimizationConfig":
+        ).NormLoggingCallbackConfig
+    if name == "OnExceptionCheckpointCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
-        ).OptimizationConfig
+        ).OnExceptionCheckpointCallbackConfig
     if name == "RLPSanityChecksCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
         ).RLPSanityChecksCallbackConfig
-    if name == "ReproducibilityConfig":
-        return importlib.import_module(
-            "nshtrainer.trainer._config"
-        ).ReproducibilityConfig
     if name == "SanityCheckingConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -38,20 +39,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
    )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-        ReproducibilityConfig as ReproducibilityConfig,
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
         return importlib.import_module(
             "nshtrainer.trainer._config"
         ).ActSaveLoggerConfig
+    if name == "BaseLoggerConfig":
+        return importlib.import_module(
+            "nshtrainer.trainer._config"
+        ).BaseLoggerConfig
     if name == "BestCheckpointCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
         return importlib.import_module(
             "nshtrainer.trainer._config"
         ).LastCheckpointCallbackConfig
+    if name == "LearningRateMonitorConfig":
+        return importlib.import_module(
+            "nshtrainer.trainer._config"
+        ).LearningRateMonitorConfig
     if name == "LogEpochCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
         ).LogEpochCallbackConfig
-    if name == "LoggingConfig":
-        return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
     if name == "MetricConfig":
         return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-    if name == "OnExceptionCheckpointCallbackConfig":
+    if name == "NormLoggingCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
-        ).OnExceptionCheckpointCallbackConfig
-    if name == "OptimizationConfig":
+        ).NormLoggingCallbackConfig
+    if name == "OnExceptionCheckpointCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
-        ).OptimizationConfig
+        ).OnExceptionCheckpointCallbackConfig
     if name == "RLPSanityChecksCallbackConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
         ).RLPSanityChecksCallbackConfig
-    if name == "ReproducibilityConfig":
-        return importlib.import_module(
-            "nshtrainer.trainer._config"
-        ).ReproducibilityConfig
     if name == "SanityCheckingConfig":
         return importlib.import_module(
             "nshtrainer.trainer._config"
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
+from pathlib import Path
 from typing import Any, Generic, cast
 
 import nshconfig as C
+import torch
 from lightning.pytorch import LightningDataModule
 from typing_extensions import Never, TypeVar, deprecated, override
 
@@ -55,3 +57,68 @@ class LightningDataModuleBase(
         )
         hparams = hparams.model_deep_validate()
         self.save_hyperparameters(hparams)
+
+    @override
+    @classmethod
+    def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+        raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+    @classmethod
+    def hparams_from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
+
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
+
+        return hparams
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        map_location: torch.serialization.MAP_LOCATION = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        # Load checkpoint
+        if isinstance(ckpt_or_path, Mapping):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+        # Load hyperparameters from checkpoint
+        hparams = cls.hparams_from_checkpoint(
+            ckpt,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+
+        # Load datamodule from checkpoint
+        datamodule = cls(hparams)
+        if datamodule.__class__.__qualname__ in ckpt:
+            datamodule.load_state_dict(ckpt[datamodule.__class__.__qualname__])
+
+        return datamodule
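A usage sketch for the new datamodule checkpoint helpers; MyDataModule, the checkpoint path, and the batch_size key are placeholders, the rest follows the methods added above:

    # Rebuild a datamodule from a Lightning checkpoint: hyperparameters are read
    # from the checkpoint, validated via hparams_cls(), and the datamodule's
    # state dict is restored when the checkpoint contains one.
    dm = MyDataModule.from_checkpoint("path/to/last.ckpt")

    # Or inspect/patch the stored hyperparameters without instantiating anything.
    hparams = MyDataModule.hparams_from_checkpoint(
        "path/to/last.ckpt",
        update_hparams_dict=lambda d: {**d, "batch_size": 32},  # hypothetical field
    )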
nshtrainer/model/base.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
+from pathlib import Path
 from typing import Any, Generic, Literal, cast
 
 import nshconfig as C
@@ -10,11 +11,13 @@ import torch
 import torch.distributed
 from lightning.pytorch import LightningModule
 from lightning.pytorch.profilers import PassThroughProfiler, Profiler
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from .mixins.callback import CallbackModuleMixin
-from .mixins.debug import _DebugModuleMixin, _trainer
+from .mixins.debug import _DebugModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
@@ -241,3 +244,98 @@ class LightningModuleBase(
         loss = sum((0.0 * v).sum() for v in self.parameters() if v.requires_grad)
         loss = cast(torch.Tensor, loss)
         return loss
+
+    @override
+    @classmethod
+    def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+        raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+    @classmethod
+    def hparams_from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
+
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
+
+        return hparams
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        map_location: torch.serialization.MAP_LOCATION = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        # Load checkpoint
+        if isinstance(ckpt_or_path, Mapping):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+        # Load hyperparameters from checkpoint
+        hparams = cls.hparams_from_checkpoint(
+            ckpt,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+
+        # Load model from checkpoint
+        model = cls(hparams)
+
+        # Load model state from checkpoint
+        if (
+            model._strict_loading is not None
+            and strict is not None
+            and strict != model.strict_loading
+        ):
+            raise ValueError(
+                f"You set `.load_from_checkpoint(..., strict={strict!r})` which is in conflict with"
+                f" `{cls.__name__}.strict_loading={model.strict_loading!r}. Please set the same value for both of them."
+            )
+        strict = model.strict_loading if strict is None else strict
+
+        if is_overridden("configure_model", model):
+            model.configure_model()
+
+        # give model a chance to load something
+        model.on_load_checkpoint(ckpt)
+
+        # load the state_dict on the model automatically
+
+        keys = model.load_state_dict(ckpt["state_dict"], strict=strict)
+
+        if not strict:
+            if keys.missing_keys:
+                rank_zero_warn(
+                    f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
+                )
+            if keys.unexpected_keys:
+                rank_zero_warn(
+                    f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
+                )
+
+        return model
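On the model side, load_from_checkpoint now raises and from_checkpoint becomes the supported entry point. A sketch, with MyModule and the path as placeholders:

    # strict=False downgrades missing/unexpected state_dict keys to rank-zero
    # warnings instead of errors, as in the implementation above.
    model = MyModule.from_checkpoint(
        "path/to/last.ckpt",
        strict=False,
        update_hparams=lambda hp: hp,  # hook to adjust the validated hparams before the model is built
    )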
@@ -9,7 +9,6 @@ from collections.abc import Iterable, Sequence
 from datetime import timedelta
 from pathlib import Path
 from typing import (
-    TYPE_CHECKING,
     Annotated,
     Any,
     ClassVar,
@@ -41,11 +40,13 @@ from ..callbacks import (
     CallbackConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
+    NormLoggingCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
 )
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
+from ..callbacks.lr_monitor import LearningRateMonitorConfig
 from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
 from ..callbacks.shared_parameters import SharedParametersCallbackConfig
 from ..loggers import (
@@ -54,6 +55,7 @@ from ..loggers import (
     TensorboardLoggerConfig,
     WandbLoggerConfig,
 )
+from ..loggers._base import BaseLoggerConfig
 from ..loggers.actsave import ActSaveLoggerConfig
 from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
@@ -62,103 +64,6 @@ from ..util._environment_info import EnvironmentConfig
 log = logging.getLogger(__name__)
 
 
-class LoggingConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable experiment tracking."""
-
-    loggers: Sequence[LoggerConfig] = [
-        WandbLoggerConfig(),
-        CSVLoggerConfig(),
-        TensorboardLoggerConfig(),
-    ]
-    """Loggers to use for experiment tracking."""
-
-    log_lr: bool | Literal["step", "epoch"] = True
-    """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
-    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
-    """If enabled, will log the fractional epoch number to the logger."""
-
-    actsave_logger: ActSaveLoggerConfig | None = None
-    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
-
-    @property
-    def wandb(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, WandbLoggerConfig)
-            ),
-            None,
-        )
-
-    @property
-    def csv(self):
-        return next(
-            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
-            None,
-        )
-
-    @property
-    def tensorboard(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, TensorboardLoggerConfig)
-            ),
-            None,
-        )
-
-    def create_loggers(self, trainer_config: TrainerConfig):
-        """
-        Constructs and returns a list of loggers based on the provided root configuration.
-
-        Args:
-            trainer_config (TrainerConfig): The root configuration object.
-
-        Returns:
-            list[Logger]: A list of constructed loggers.
-        """
-        if not self.enabled:
-            return
-
-        for logger_config in sorted(
-            self.loggers,
-            key=lambda x: x.priority,
-            reverse=True,
-        ):
-            if not logger_config.enabled:
-                continue
-            if (logger := logger_config.create_logger(trainer_config)) is None:
-                continue
-            yield logger
-
-        # If the actsave_metrics is enabled, add the ActSave logger
-        if self.actsave_logger:
-            yield self.actsave_logger.create_logger(trainer_config)
-
-    @override
-    def create_callbacks(self, trainer_config):
-        if self.log_lr:
-            from lightning.pytorch.callbacks import LearningRateMonitor
-
-            logging_interval: str | None = None
-            if isinstance(self.log_lr, str):
-                logging_interval = self.log_lr
-
-            yield LearningRateMonitor(logging_interval=logging_interval)
-
-        if self.log_epoch:
-            yield from self.log_epoch.create_callbacks(trainer_config)
-
-        for logger in self.loggers:
-            if not logger or not isinstance(logger, CallbackConfigBase):
-                continue
-
-            yield from logger.create_callbacks(trainer_config)
-
-
 class GradientClippingConfig(C.Config):
     enabled: bool = True
     """Enable gradient clipping."""
@@ -168,32 +73,6 @@ class GradientClippingConfig(C.Config):
     """Norm type to use for gradient clipping."""
 
 
-class OptimizationConfig(CallbackConfigBase):
-    log_grad_norm: bool | str | float = False
-    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
-    log_grad_norm_per_param: bool | str | float = False
-    """If enabled, will log the gradient norm for each model parameter to the logger."""
-
-    log_param_norm: bool | str | float = False
-    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
-    log_param_norm_per_param: bool | str | float = False
-    """If enabled, will log the parameter norm for each model parameter to the logger."""
-
-    gradient_clipping: GradientClippingConfig | None = None
-    """Gradient clipping configuration, or None to disable gradient clipping."""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        from ..callbacks.norm_logging import NormLoggingCallbackConfig
-
-        yield from NormLoggingCallbackConfig(
-            log_grad_norm=self.log_grad_norm,
-            log_grad_norm_per_param=self.log_grad_norm_per_param,
-            log_param_norm=self.log_param_norm,
-            log_param_norm_per_param=self.log_param_norm_per_param,
-        ).create_callbacks(trainer_config)
-
-
 TPlugin = TypeVar(
     "TPlugin",
     Precision,
@@ -253,15 +132,6 @@ StrategyLiteral: TypeAlias = Literal[
 ]
 
 
-class ReproducibilityConfig(C.Config):
-    deterministic: bool | Literal["warn"] | None = None
-    """
-    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
-    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
-    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
-    """
-
-
 CheckpointCallbackConfig: TypeAlias = Annotated[
     BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
@@ -635,14 +505,34 @@ class TrainerConfig(C.Config):
     hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
     """Hugging Face Hub configuration options."""
 
-    logging: LoggingConfig = LoggingConfig()
-    """Logging/experiment tracking (e.g., WandB) configuration options."""
+    loggers: Sequence[LoggerConfig] = [
+        WandbLoggerConfig(),
+        CSVLoggerConfig(),
+        TensorboardLoggerConfig(),
+    ]
+    """Loggers to use for experiment tracking."""
 
-    optimizer: OptimizationConfig = OptimizationConfig()
-    """Optimization configuration options."""
+    actsave_logger: ActSaveLoggerConfig | None = None
+    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
 
-    reproducibility: ReproducibilityConfig = ReproducibilityConfig()
-    """Reproducibility configuration options."""
+    lr_monitor: LearningRateMonitorConfig | None = LearningRateMonitorConfig()
+    """Learning rate monitoring configuration options."""
+
+    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
+    """If enabled, will log the fractional epoch number to the logger."""
+
+    gradient_clipping: GradientClippingConfig | None = None
+    """Gradient clipping configuration, or None to disable gradient clipping."""
+
+    log_norms: NormLoggingCallbackConfig | None = None
+    """Norm logging configuration options."""
+
+    deterministic: bool | Literal["warn"] | None = None
+    """
+    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
+    """
 
     reduce_lr_on_plateau_sanity_checking: RLPSanityChecksCallbackConfig | None = (
         RLPSanityChecksCallbackConfig()
@@ -857,27 +747,87 @@ class TrainerConfig(C.Config):
     set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
 
+    @property
+    def wandb_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, WandbLoggerConfig)
+            ),
+            None,
+        )
+
+    @property
+    def csv_logger(self):
+        return next(
+            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
+            None,
+        )
+
+    @property
+    def tensorboard_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, TensorboardLoggerConfig)
+            ),
+            None,
+        )
+
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
         yield self.early_stopping
         yield self.checkpoint_saving
-        yield self.logging
-        yield self.optimizer
+        yield self.lr_monitor
+        yield from (
+            logger_config
+            for logger_config in self.loggers
+            if logger_config is not None
+            and isinstance(logger_config, CallbackConfigBase)
+        )
+        yield self.log_epoch
+        yield self.log_norms
         yield self.hf_hub
         yield self.shared_parameters
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield from self.callbacks
 
+    def _nshtrainer_all_logger_configs(self) -> Iterable[BaseLoggerConfig | None]:
+        yield from self.loggers
+        yield self.actsave_logger
+
     # region Helper Methods
+    def fast_dev_run_(self, value: int | bool = True, /):
+        """
+        Enables fast_dev_run mode for the trainer.
+        This will run the training loop for a specified number of batches,
+        if an integer is provided, or for a single batch if True is provided.
+        """
+        self.fast_dev_run = value
+        return self
+
 
     def with_fast_dev_run(self, value: int | bool = True, /):
         """
         Enables fast_dev_run mode for the trainer.
         This will run the training loop for a specified number of batches,
        if an integer is provided, or for a single batch if True is provided.
         """
-        config = copy.deepcopy(self)
-        config.fast_dev_run = value
-        return config
+        return copy.deepcopy(self).fast_dev_run_(value)
+
+    def project_root_(self, project_root: str | Path | os.PathLike):
+        """
+        Set the project root directory for the trainer.
+
+        Args:
+            project_root (Path): The base directory to use.
+
+        Returns:
+            self: The current instance of the class.
+        """
+        self.directory.project_root = Path(project_root)
+        return self
 
     def with_project_root(self, project_root: str | Path | os.PathLike):
         """
@@ -889,9 +839,7 @@ class TrainerConfig(C.Config):
         Returns:
             self: The current instance of the class.
         """
-        config = copy.deepcopy(self)
-        config.directory.project_root = Path(project_root)
-        return config
+        return copy.deepcopy(self).project_root_(project_root)
 
     def reset_run(
         self,
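The net effect for user configs is a flattening: the logging, optimizer, and reproducibility sub-configs are removed, and their options now live directly on TrainerConfig. A before/after sketch; the specific values are illustrative only:

    # 1.0.0b11
    config = TrainerConfig(
        logging=LoggingConfig(log_lr="step"),
        optimizer=OptimizationConfig(log_grad_norm=True),
        reproducibility=ReproducibilityConfig(deterministic="warn"),
    )

    # 1.0.0b13
    config = TrainerConfig(
        lr_monitor=LearningRateMonitorConfig(logging_interval="step"),
        log_norms=NormLoggingCallbackConfig(log_grad_norm=True),
        deterministic="warn",
    )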
@@ -2,28 +2,19 @@ from __future__ import annotations
 
 import logging
 import os
-from collections.abc import Mapping, Sequence
+from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
-from lightning.fabric.utilities.cloud_io import _load as pl_load
-from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.pytorch import LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.core.saving import (
-    _default_map_location,
-    load_hparams_from_tags_csv,
-    load_hparams_from_yaml,
-)
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.migration import pl_legacy_patch
-from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
@@ -79,7 +70,7 @@ class Trainer(LightningTrainer):
        kwargs_ctor: LightningTrainerKwargs,
    ):
        kwargs: LightningTrainerKwargs = {
-            "deterministic": hparams.reproducibility.deterministic,
+            "deterministic": hparams.deterministic,
            "fast_dev_run": hparams.fast_dev_run,
            "max_epochs": hparams.max_epochs,
            "min_epochs": hparams.min_epochs,
@@ -218,7 +209,7 @@ class Trainer(LightningTrainer):
         _update_kwargs(detect_anomaly=detect_anomaly)
 
         if (
-            grad_clip_config := hparams.optimizer.gradient_clipping
+            grad_clip_config := hparams.gradient_clipping
         ) is not None and grad_clip_config.enabled:
             # kwargs["gradient_clip_algorithm"] = grad_clip_config.algorithm
             # kwargs["gradient_clip_val"] = grad_clip_config.value
@@ -248,17 +239,14 @@ class Trainer(LightningTrainer):
             ]
         )
 
-        if not hparams.logging.enabled:
-            log.critical(f"Disabling logger because {hparams.logging.enabled=}.")
-            kwargs["logger"] = False
-        else:
-            _update_kwargs(
-                logger=[
-                    logger
-                    for logger in hparams.logging.create_loggers(hparams)
-                    if logger is not None
-                ]
-            )
+        _update_kwargs(
+            logger=[
+                logger
+                for logger_config in hparams._nshtrainer_all_logger_configs()
+                if logger_config is not None
+                and (logger := logger_config.create_logger(hparams)) is not None
+            ]
+        )
 
         if hparams.auto_determine_num_nodes:
             # When num_nodes is auto, we need to detect the number of nodes.
@@ -473,62 +461,46 @@ class Trainer(LightningTrainer):
            _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
 
     @classmethod
-    def load_from_checkpoint(
+    def hparams_from_checkpoint(
         cls,
-        checkpoint_path: _PATH | IO,
-        map_location: _MAP_LOCATION_TYPE = None,
-        hparams_file: _PATH | None = None,
-        **kwargs: Any,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[TrainerConfig], TrainerConfig] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
     ):
-        loaded = _load_from_checkpoint(
-            checkpoint_path,
-            map_location=map_location,
-            hparams_file=hparams_file,
-            **kwargs,
-        )
-        return loaded
-
-
-def _load_from_checkpoint(
-    checkpoint_path: _PATH | IO,
-    map_location: _MAP_LOCATION_TYPE = None,
-    hparams_file: _PATH | None = None,
-    **kwargs: Any,
-):
-    map_location = map_location or _default_map_location
-    with pl_legacy_patch():
-        checkpoint = pl_load(checkpoint_path, map_location=map_location)
-
-    # convert legacy checkpoints to the new format
-    checkpoint = _pl_migrate_checkpoint(
-        checkpoint,
-        checkpoint_path=(
-            checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None
-        ),
-    )
-
-    if hparams_file is not None:
-        extension = str(hparams_file).split(".")[-1]
-        if extension.lower() == "csv":
-            hparams = load_hparams_from_tags_csv(hparams_file)
-        elif extension.lower() in ("yml", "yaml"):
-            hparams = load_hparams_from_yaml(hparams_file)
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
         else:
-            raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
 
-        # overwrite hparams by the given file
-        checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
 
-    # for past checkpoint need to add the new key
-    checkpoint.setdefault(Trainer.CHECKPOINT_HYPER_PARAMS_KEY, {})
-    # override the hparams with values that were passed in
-    checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
 
-    # load the hparams
-    hparams = Trainer.hparams_cls().model_validate(
-        checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY]
-    )
+        return hparams
 
-    # create the trainer
-    trainer = Trainer(hparams)
-    return trainer
+    @classmethod
+    def from_checkpoint(
+        cls,
+        path: str | Path,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[TrainerConfig], TrainerConfig] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        hparams = cls.hparams_from_checkpoint(
+            path,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+        return cls(hparams)
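Trainer restoration is reworked in the same style: the bespoke _load_from_checkpoint helper (legacy-checkpoint migration, hparams_file support) is deleted, and Trainer.from_checkpoint simply validates the TrainerConfig stored in the checkpoint. A sketch, assuming Trainer is importable from nshtrainer.trainer and the path is a placeholder:

    from nshtrainer.trainer import Trainer

    # Re-create a Trainer from the TrainerConfig saved in the checkpoint,
    # optionally patching the config before construction.
    trainer = Trainer.from_checkpoint(
        "path/to/last.ckpt",
        update_hparams=lambda cfg: cfg.with_fast_dev_run(True),
    )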
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 1.0.0b11
+Version: 1.0.0b13
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -22,6 +22,7 @@ nshtrainer/callbacks/finite_checks.py,sha256=iCiKQ5i9RckkzcPeCHzC3hkg3AlW3ESuWtF
 nshtrainer/callbacks/gradient_skipping.py,sha256=k5qNaNeileZ_5YFad4ssfLplMxMKeKFhPcY8-QVmLek,3464
 nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
 nshtrainer/callbacks/log_epoch.py,sha256=Wr-Ksxsynsqu_zyB_zoiPLjnWv-ksC3xPekY6iyN-P8,1396
+nshtrainer/callbacks/lr_monitor.py,sha256=IyFZoXaxJoTBSkdLu1iEZ1qI8_UFNJwafR_xTVPZXXU,1050
 nshtrainer/callbacks/norm_logging.py,sha256=C44Mvt73gqQEpCFd0j3qYg6NY7sL2jm3X1qJVY_XLfI,6329
 nshtrainer/callbacks/print_table.py,sha256=WIgfzVSfAfS3_8kUuX-nWJOGWBEmtNlejypuoJQViPY,2884
 nshtrainer/callbacks/rlp_sanity_checks.py,sha256=kWl2dYOXn2L8k6ub_012jNkqOxtyea1yr1qWRNG6UW4,9990
@@ -29,13 +30,13 @@ nshtrainer/callbacks/shared_parameters.py,sha256=33eRzifNj6reKbvmGuam1hUofo3sD4J
 nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
 nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
 nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
-nshtrainer/configs/__init__.py,sha256=0TczWa5OFRKOGKHgabeB7VUMxPpD0RCgDR6AvdAD-tI,22721
+nshtrainer/configs/__init__.py,sha256=Vyf_gn7u3s9ET4Yszf6SILtqvpIGiJ4X5RJfmW-FK6I,22293
 nshtrainer/configs/_checkpoint/__init__.py,sha256=vuiBbd4VzCo7lRyhyTUArEQeWwJkewvNPKDxBJiUHoY,2719
 nshtrainer/configs/_checkpoint/loader/__init__.py,sha256=hdLpypoEkES1MTaTHAdGFJnSoZzgx_8NzAKbK143SyI,2399
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=X9KxpcoHQbJp6-MTGvp4pct-MYHaHcl82s9yqZ5KiSk,867
 nshtrainer/configs/_directory/__init__.py,sha256=mTUoSz-DSsvI2M98cqu2Z2x215oM0sLyljh_5rVexvQ,1029
 nshtrainer/configs/_hf_hub/__init__.py,sha256=3HGCGhRb7NhOuLeskGqbYNuS9c81oOUbX6ibyF3XiCY,1063
-nshtrainer/configs/callbacks/__init__.py,sha256=-dHN8NZdCaNUy_isnlh779FZ1w9_WkOkv6VSN_-86jM,7316
+nshtrainer/configs/callbacks/__init__.py,sha256=xgCa98EmqU8cHxlJa-64Cc4c_0fS0Cz2iVac4edL_yc,7657
 nshtrainer/configs/callbacks/actsave/__init__.py,sha256=AkVWS9vCcDJFpPUpyc7i9cjaFZU2kKxDyFDqakMZA-E,809
 nshtrainer/configs/callbacks/base/__init__.py,sha256=OdtHDMkYC_ioCEAkg7bSQi3o7e2t5WHPcFjavXdfdTA,602
 nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=DE-JTtY4NdaP3mgWduearFYMvy1tswWRBWMde06RzQc,2700
@@ -50,6 +51,7 @@ nshtrainer/configs/callbacks/ema/__init__.py,sha256=KlPGdJWjYTKLdpl-VnN4BYY2sA_L
 nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=V6Owp05XdIk3EO67AMVGdwbT4-D86QRuvqWM2gu5Xpw,949
 nshtrainer/configs/callbacks/gradient_skipping/__init__.py,sha256=RhIfJFq-x_sWYrWVGaVEBeT8uUFYjFgt0Ug8pPgpJSg,981
 nshtrainer/configs/callbacks/log_epoch/__init__.py,sha256=4jePzjE3bVxaI7hQrcWW5SrKT5MrFyplJZwK8bQHbGI,900
+nshtrainer/configs/callbacks/lr_monitor/__init__.py,sha256=iC8U0oWC75JzPUMRoGWkC8WkMuLbF9-zuN_yQlByycY,916
 nshtrainer/configs/callbacks/norm_logging/__init__.py,sha256=8M9hcGpEuQadfgTR4-YL4TWeyxZjg0s84x420B03-aE,941
 nshtrainer/configs/callbacks/print_table/__init__.py,sha256=Ni47iS2mIzwGu8XuHfUY5BJKawUO_2TyJMZ62QBpEW0,961
 nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py,sha256=OgPywk8Z9y_dnq_liH2PPWuQSpUlQ_Q2-q99HDN9Leg,977
@@ -78,8 +80,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=smDYCplrI5B38XJcNZ462ZeTo9l
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=VCCvbzhEeOcdZ0Unvk_anAcmbQuGogTQhK_bXs5RG9U,892
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=hJ90ym5ElI-BY_XS3VSLjcgQWfV0Pp1MdzTU6Qi8MFg,884
 nshtrainer/configs/profiler/simple/__init__.py,sha256=18V64kKYrJeSCrPmY3wYnshEISaf7xmrfw2Ny-6P3uE,859
-nshtrainer/configs/trainer/__init__.py,sha256=Gf5RizrVL84NWfHnagVCRHtXiD6x0UA4N7vhYypluTk,7916
-nshtrainer/configs/trainer/_config/__init__.py,sha256=BeQ5t_9d6rx6SbSu4ZqD9eitLCQRpTOVOnSxT0LCrlM,7806
+nshtrainer/configs/trainer/__init__.py,sha256=QLCDVxVg1Ig-wgUW5r8I1FdPdbYz9-gse17s3R69Fw0,8019
+nshtrainer/configs/trainer/_config/__init__.py,sha256=YEcliai8jLOoB53lxT5BZIR3NzfoLb4x3VGbMakFVo4,7909
 nshtrainer/configs/trainer/checkpoint_connector/__init__.py,sha256=pSu79zOFFWvqjI3SkHWl13H8ZNJFTc6a5J1r2KnfUKM,667
 nshtrainer/configs/trainer/trainer/__init__.py,sha256=P-Y2DOZZcJtEdPjGEKCxq5R3JSzKhUUoidkSvO_cfKI,797
 nshtrainer/configs/util/__init__.py,sha256=ZcmEqg2OWKKcPBqzDG1SnuaAMgR4_c0jog-Xg6QTUzc,4555
@@ -89,7 +91,7 @@ nshtrainer/configs/util/config/dtype/__init__.py,sha256=NCXMVO-EUz3JvPmlDci72O9Z
 nshtrainer/configs/util/config/duration/__init__.py,sha256=8llT1MCKQpsdNldN5h5Wo0GjUuRn28Sxw2FTXTNKBpM,1060
 nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
 nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
-nshtrainer/data/datamodule.py,sha256=NC0y7JAOvYCei-yPGUOeIB7MkHVcWUQE-5dWC0Hvpxo,1824
+nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
 nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
 nshtrainer/loggers/__init__.py,sha256=11X6D_lF0cHAkxkYsVZY3Q3r30Fq0FUi9heeb5RD870,570
 nshtrainer/loggers/_base.py,sha256=nw4AZzJP3Z-fljgQlgq7FkuMkPmYKTsXj7OfJJSmtXI,811
@@ -104,7 +106,7 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5N
 nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
 nshtrainer/metrics/_config.py,sha256=kR9OJLZXIaZpG8UpHG3EQwOL36HB8yPk6afYjoIM0XM,1324
 nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
-nshtrainer/model/base.py,sha256=qJXzu1w4mqF1eVwRBahEuJwJTIVSAZGZaV1vigwrg0Y,6817
+nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
 nshtrainer/model/mixins/callback.py,sha256=Ea_legORzs0N078j0N9RJivDVeWH5KtXDpdJS75IwIo,3098
 nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
 nshtrainer/model/mixins/logger.py,sha256=LQDJJbiv30PlWX6rTT_EhjNBNfUFfcvGz5sX4MnOCzI,5330
@@ -119,13 +121,12 @@ nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,
 nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
-nshtrainer/scripts/find_packages.py,sha256=-oNnSNPp3pujCVgManW_LFlJcnvhrtHgvUJ4W88-6-o,1460
 nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
-nshtrainer/trainer/_config.py,sha256=5HC2Gs7YbJK9gMcKm4cQ78B3PxcOvwQ372Glhf4fzbo,34232
+nshtrainer/trainer/_config.py,sha256=2S6Qhwn724n_jgGhWVI64Wi_pHKjU1ggoY4sxq-_SlA,32309
 nshtrainer/trainer/_runtime_callback.py,sha256=T3epaj1YeIN0R8CS2cg5HNJIB21TyaD_PVNNOPJ6nJs,4200
 nshtrainer/trainer/checkpoint_connector.py,sha256=pC1tTDcq0p6sAsoTmAbwINW49IfqupMMtnE9-AKdTUw,2824
 nshtrainer/trainer/signal_connector.py,sha256=YMJf6vTnW0JcnBkuYikm9x_9XscaokrCEzCn4THOGao,10776
-nshtrainer/trainer/trainer.py,sha256=JLcfdPqlOaAmKJAE3w9xIEai7Xyyl3JhDyMPsZ8fyv8,20692
+nshtrainer/trainer/trainer.py,sha256=kIXh_25jDJSGcwEyLjbvqWN0P5B35VBJLXOwXqUGqF4,19759
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
@@ -138,6 +139,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b11.dist-info/METADATA,sha256=6IsguMlupAvNHaGUyo3m1-TzLvwycrpucABLwbzyj_Q,937
-nshtrainer-1.0.0b11.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-1.0.0b11.dist-info/RECORD,,
+nshtrainer-1.0.0b13.dist-info/METADATA,sha256=9PQNipTw68KmSV_7Kt4fK_KtlYKSaKBcvvkBZrwWFtY,937
+nshtrainer-1.0.0b13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-1.0.0b13.dist-info/RECORD,,
@@ -1,52 +0,0 @@
-from __future__ import annotations
-
-import argparse
-import ast
-import glob
-import sys
-from pathlib import Path
-
-
-def get_imports(file_path: Path):
-    with open(file_path, "r") as file:
-        try:
-            tree = ast.parse(file.read())
-        except SyntaxError:
-            print(f"Syntax error in file: {file_path}", file=sys.stderr)
-            return set()
-
-    imports = set()
-    for node in ast.walk(tree):
-        if isinstance(node, ast.Import):
-            for alias in node.names:
-                imports.add(alias.name.split(".")[0])
-        elif isinstance(node, ast.ImportFrom):
-            if node.level == 0 and node.module:  # Absolute import
-                imports.add(node.module.split(".")[0])
-    return imports
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Find unique Python packages used in files."
-    )
-    parser.add_argument("glob_pattern", help="Glob pattern to match files")
-    parser.add_argument(
-        "--exclude-std", action="store_true", help="Exclude Python standard libraries"
-    )
-    args = parser.parse_args()
-
-    all_imports = set()
-    for file_path in glob.glob(args.glob_pattern, recursive=True):
-        all_imports.update(get_imports(Path(file_path)))
-
-    if args.exclude_std:
-        std_libs = set(sys.stdlib_module_names)
-        all_imports = all_imports - std_libs
-
-    for package in sorted(all_imports):
-        print(package)
-
-
-if __name__ == "__main__":
-    main()