nshtrainer 0.44.0__py3-none-any.whl → 1.0.0b9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +51 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/_base.py +2 -1
  79. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  80. nshtrainer/model/__init__.py +0 -4
  81. nshtrainer/model/base.py +64 -347
  82. nshtrainer/model/mixins/callback.py +24 -5
  83. nshtrainer/model/mixins/debug.py +86 -0
  84. nshtrainer/model/mixins/logger.py +142 -145
  85. nshtrainer/profiler/_base.py +2 -2
  86. nshtrainer/profiler/advanced.py +4 -4
  87. nshtrainer/profiler/pytorch.py +4 -4
  88. nshtrainer/profiler/simple.py +4 -4
  89. nshtrainer/trainer/__init__.py +1 -0
  90. nshtrainer/trainer/_config.py +164 -17
  91. nshtrainer/trainer/checkpoint_connector.py +23 -8
  92. nshtrainer/trainer/trainer.py +194 -76
  93. nshtrainer/util/_environment_info.py +21 -13
  94. nshtrainer/util/config/dtype.py +4 -4
  95. nshtrainer/util/typing_utils.py +1 -1
  96. {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/METADATA +2 -2
  97. nshtrainer-1.0.0b9.dist-info/RECORD +143 -0
  98. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  99. nshtrainer/callbacks/throughput_monitor.py +0 -58
  100. nshtrainer/config/model/__init__.py +0 -41
  101. nshtrainer/config/model/base/__init__.py +0 -25
  102. nshtrainer/config/model/config/__init__.py +0 -37
  103. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  104. nshtrainer/config/runner/__init__.py +0 -22
  105. nshtrainer/ll/__init__.py +0 -59
  106. nshtrainer/ll/_experimental.py +0 -3
  107. nshtrainer/ll/actsave.py +0 -6
  108. nshtrainer/ll/callbacks.py +0 -3
  109. nshtrainer/ll/config.py +0 -6
  110. nshtrainer/ll/data.py +0 -3
  111. nshtrainer/ll/log.py +0 -5
  112. nshtrainer/ll/lr_scheduler.py +0 -3
  113. nshtrainer/ll/model.py +0 -21
  114. nshtrainer/ll/nn.py +0 -3
  115. nshtrainer/ll/optimizer.py +0 -3
  116. nshtrainer/ll/runner.py +0 -5
  117. nshtrainer/ll/snapshot.py +0 -3
  118. nshtrainer/ll/snoop.py +0 -3
  119. nshtrainer/ll/trainer.py +0 -3
  120. nshtrainer/ll/typecheck.py +0 -3
  121. nshtrainer/ll/util.py +0 -3
  122. nshtrainer/model/config.py +0 -218
  123. nshtrainer/runner.py +0 -101
  124. nshtrainer-0.44.0.dist-info/RECORD +0 -162
  125. {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/WHEEL +0 -0
nshtrainer/trainer/trainer.py

@@ -2,90 +2,120 @@ from __future__ import annotations
 
 import logging
 import os
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import IO, TYPE_CHECKING, Any, cast
 
 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
+from lightning.fabric.utilities.cloud_io import _load as pl_load
+from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.pytorch import LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.core.saving import (
+    _default_map_location,
+    load_hparams_from_tags_csv,
+    load_hparams_from_yaml,
+)
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.utilities.migration import pl_legacy_patch
+from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
-from typing_extensions import Unpack, assert_never, override
+from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
 from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
+from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import (
     AcceleratorConfigProtocol,
     LightningTrainerKwargs,
     StrategyConfigProtocol,
+    TrainerConfig,
 )
 from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector
 
-if TYPE_CHECKING:
-    from ..model.config import BaseConfig
-
 log = logging.getLogger(__name__)
 
 
 class Trainer(LightningTrainer):
+    CHECKPOINT_HYPER_PARAMS_KEY = "trainer_hyper_parameters"
+
+    @property
+    def hparams(self) -> TrainerConfig:
+        """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For
+        the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
+
+        Returns:
+            Mutable hyperparameters dictionary
+
+        """
+        return self._hparams
+
+    @property
+    @deprecated("Use `hparams` instead")
+    def config(self):
+        return cast(Never, self.hparams)
+
+    @classmethod
+    def hparams_cls(cls):
+        return TrainerConfig
+
     @classmethod
-    def _pre_init(cls, config: "BaseConfig"):
-        if (precision := config.trainer.set_float32_matmul_precision) is not None:
+    def _pre_init(cls, hparams: TrainerConfig):
+        if (precision := hparams.set_float32_matmul_precision) is not None:
             torch.set_float32_matmul_precision(precision)
 
     @classmethod
     def _update_kwargs(
         cls,
-        config: "BaseConfig",
+        hparams: TrainerConfig,
         kwargs_ctor: LightningTrainerKwargs,
     ):
         kwargs: LightningTrainerKwargs = {
-            "deterministic": config.trainer.reproducibility.deterministic,
-            "fast_dev_run": config.trainer.fast_dev_run,
-            "max_epochs": config.trainer.max_epochs,
-            "min_epochs": config.trainer.min_epochs,
-            "max_steps": config.trainer.max_steps,
-            "min_steps": config.trainer.min_steps,
-            "max_time": config.trainer.max_time,
-            "limit_train_batches": config.trainer.limit_train_batches,
-            "limit_val_batches": config.trainer.limit_val_batches,
-            "limit_test_batches": config.trainer.limit_test_batches,
-            "limit_predict_batches": config.trainer.limit_predict_batches,
-            "overfit_batches": config.trainer.overfit_batches,
-            "val_check_interval": config.trainer.val_check_interval,
-            "num_sanity_val_steps": config.trainer.num_sanity_val_steps,
-            "log_every_n_steps": config.trainer.log_every_n_steps,
-            "inference_mode": config.trainer.inference_mode,
+            "deterministic": hparams.reproducibility.deterministic,
+            "fast_dev_run": hparams.fast_dev_run,
+            "max_epochs": hparams.max_epochs,
+            "min_epochs": hparams.min_epochs,
+            "max_steps": hparams.max_steps,
+            "min_steps": hparams.min_steps,
+            "max_time": hparams.max_time,
+            "limit_train_batches": hparams.limit_train_batches,
+            "limit_val_batches": hparams.limit_val_batches,
+            "limit_test_batches": hparams.limit_test_batches,
+            "limit_predict_batches": hparams.limit_predict_batches,
+            "overfit_batches": hparams.overfit_batches,
+            "val_check_interval": hparams.val_check_interval,
+            "num_sanity_val_steps": hparams.num_sanity_val_steps,
+            "log_every_n_steps": hparams.log_every_n_steps,
+            "inference_mode": hparams.inference_mode,
             "callbacks": [],
             "plugins": [],
             "logger": [],
             # Moved to `lightning_kwargs`:
-            # "enable_checkpointing": config.trainer.enable_checkpointing,
-            # "accelerator": config.trainer.accelerator,
-            # "strategy": config.trainer.strategy,
-            # "num_nodes": config.trainer.num_nodes,
-            # "precision": config.trainer.precision,
-            # "logger": config.trainer.logging.enabled,
-            # "log_every_n_steps": config.trainer.log_every_n_steps,
-            # "enable_progress_bar": config.trainer.enable_progress_bar,
-            # "enable_model_summary": config.trainer.enable_model_summary,
-            # "accumulate_grad_batches": config.trainer.accumulate_grad_batches,
-            # "benchmark": config.trainer.benchmark,
-            # "use_distributed_sampler": config.trainer.use_distributed_sampler,
-            # "detect_anomaly": config.trainer.detect_anomaly,
-            # "barebones": config.trainer.barebones,
-            # "plugins": config.trainer.plugins,
-            # "sync_batchnorm": config.trainer.sync_batchnorm,
-            # "reload_dataloaders_every_n_epochs": config.trainer.reload_dataloaders_every_n_epochs,
+            # "enable_checkpointing": hparams.enable_checkpointing,
+            # "accelerator": hparams.accelerator,
+            # "strategy": hparams.strategy,
+            # "num_nodes": hparams.num_nodes,
+            # "precision": hparams.precision,
+            # "logger": hparams.logging.enabled,
+            # "log_every_n_steps": hparams.log_every_n_steps,
+            # "enable_progress_bar": hparams.enable_progress_bar,
+            # "enable_model_summary": hparams.enable_model_summary,
+            # "accumulate_grad_batches": hparams.accumulate_grad_batches,
+            # "benchmark": hparams.benchmark,
+            # "use_distributed_sampler": hparams.use_distributed_sampler,
+            # "detect_anomaly": hparams.detect_anomaly,
+            # "barebones": hparams.barebones,
+            # "plugins": hparams.plugins,
+            # "sync_batchnorm": hparams.sync_batchnorm,
+            # "reload_dataloaders_every_n_epochs": hparams.reload_dataloaders_every_n_epochs,
         }
 
         def _update_key(key: str, new_value: Any):
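A minimal sketch (not part of the diff) of how the relocated hyperparameters read from user code, assuming an already-constructed nshtrainer Trainer; `precision` and `max_epochs` are fields referenced elsewhere in this diff:

def describe(trainer) -> None:
    # trainer.hparams is now the TrainerConfig itself; fields are read directly,
    # without the old `config.trainer.*` indirection.
    print(trainer.hparams.precision, trainer.hparams.max_epochs)

    # Deprecated alias kept for the transition (flagged with @deprecated).
    legacy = trainer.config

    # The config class used for validation and checkpoint round-tripping.
    assert isinstance(trainer.hparams, trainer.hparams_cls())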
@@ -115,20 +145,22 @@ class Trainer(LightningTrainer):
                 _update_key(key, value)
 
         # Set `default_root_dir` if `auto_set_default_root_dir` is enabled.
-        if config.trainer.auto_set_default_root_dir:
+        if hparams.auto_set_default_root_dir:
             if kwargs.get("default_root_dir"):
                 raise ValueError(
-                    "You have set `config.trainer.default_root_dir`. "
+                    "You have set `hparams.default_root_dir`. "
                     "But we are trying to set it automatically. "
-                    "Please use `config.directory.base` rather than `config.trainer.default_root_dir`. "
-                    "If you want to set it manually, please set `config.trainer.auto_set_default_root_dir=False`."
+                    "Please use `hparams.directory.base` rather than `hparams.default_root_dir`. "
+                    "If you want to set it manually, please set `hparams.auto_set_default_root_dir=False`."
                 )
 
             _update_kwargs(
-                default_root_dir=config.directory.resolve_run_root_directory(config.id)
+                default_root_dir=hparams.directory.resolve_run_root_directory(
+                    hparams.id
+                )
             )
 
-        if (devices_input := config.trainer.devices) is not None:
+        if (devices_input := hparams.devices) is not None:
             match devices_input:
                 case "all":
                     devices = -1
@@ -141,22 +173,20 @@ class Trainer(LightningTrainer):
 
             _update_kwargs(devices=devices)
 
-        if (
-            use_distributed_sampler := config.trainer.use_distributed_sampler
-        ) is not None:
+        if (use_distributed_sampler := hparams.use_distributed_sampler) is not None:
             _update_kwargs(use_distributed_sampler=use_distributed_sampler)
 
-        if (accelerator := config.trainer.accelerator) is not None:
+        if (accelerator := hparams.accelerator) is not None:
             if isinstance(accelerator, AcceleratorConfigProtocol):
                 accelerator = accelerator.create_accelerator()
             _update_kwargs(accelerator=accelerator)
 
-        if (strategy := config.trainer.strategy) is not None:
+        if (strategy := hparams.strategy) is not None:
             if isinstance(strategy, StrategyConfigProtocol):
                 strategy = strategy.create_strategy()
             _update_kwargs(strategy=strategy)
 
-        if (precision := config.trainer.precision) is not None:
+        if (precision := hparams.precision) is not None:
             resolved_precision: _PRECISION_INPUT
             match precision:
                 case "64-true" | "32-true" | "bf16-mixed":
@@ -184,11 +214,11 @@ class Trainer(LightningTrainer):
 
             _update_kwargs(precision=resolved_precision)
 
-        if (detect_anomaly := config.trainer.detect_anomaly) is not None:
+        if (detect_anomaly := hparams.detect_anomaly) is not None:
             _update_kwargs(detect_anomaly=detect_anomaly)
 
         if (
-            grad_clip_config := config.trainer.optimizer.gradient_clipping
+            grad_clip_config := hparams.optimizer.gradient_clipping
         ) is not None and grad_clip_config.enabled:
             # kwargs["gradient_clip_algorithm"] = grad_clip_config.algorithm
             # kwargs["gradient_clip_val"] = grad_clip_config.value
@@ -197,9 +227,9 @@ class Trainer(LightningTrainer):
                 gradient_clip_val=grad_clip_config.value,
             )
 
-        if profiler_config := config.trainer.profiler:
-            if (profiler := profiler_config.create_profiler(config)) is None:
-                log.warning(f"Profiler config {profiler_config=} returned None.")
+        if profiler_config := hparams.profiler:
+            if (profiler := profiler_config.create_profiler(hparams)) is None:
+                log.warning(f"Profiler hparams {profiler_config=} returned None.")
             # Make sure that the profiler is an instance of `Profiler`.
             elif not isinstance(profiler, Profiler):
                 raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
@@ -208,23 +238,29 @@ class Trainer(LightningTrainer):
             else:
                 _update_kwargs(profiler=profiler)
 
-        if callbacks := resolve_all_callbacks(config):
+        if callbacks := resolve_all_callbacks(hparams):
             _update_kwargs(callbacks=callbacks)
 
-        if plugin_configs := config.trainer.plugins:
+        if plugin_configs := hparams.plugins:
             _update_kwargs(
                 plugins=[
                     plugin_config.create_plugin() for plugin_config in plugin_configs
                 ]
             )
 
-        if not config.trainer.logging.enabled:
-            log.critical(f"Disabling logger because {config.trainer.logging.enabled=}.")
+        if not hparams.logging.enabled:
+            log.critical(f"Disabling logger because {hparams.logging.enabled=}.")
             kwargs["logger"] = False
         else:
-            _update_kwargs(logger=list(config.trainer.logging.create_loggers(config)))
+            _update_kwargs(
+                logger=[
+                    logger
+                    for logger in hparams.logging.create_loggers(hparams)
+                    if logger is not None
+                ]
+            )
 
-        if config.trainer.auto_determine_num_nodes:
+        if hparams.auto_determine_num_nodes:
             # When num_nodes is auto, we need to detect the number of nodes.
             if SLURMEnvironment.detect():
                 if (num_nodes := os.environ.get("SLURM_NNODES")) is not None:
@@ -243,12 +279,12 @@ class Trainer(LightningTrainer):
                     _update_kwargs(num_nodes=num_nodes)
             else:
                 log.info(
-                    "config.trainer.auto_determine_num_nodes ignored because no SLURM or LSF detected."
+                    "hparams.auto_determine_num_nodes ignored because no SLURM or LSF detected."
                 )
 
         # Update the kwargs with the additional trainer kwargs
-        _update_kwargs(**cast(Any, config.trainer.additional_lightning_kwargs))
-        _update_kwargs(**config.trainer.lightning_kwargs)
+        _update_kwargs(**cast(Any, hparams.additional_lightning_kwargs))
+        _update_kwargs(**hparams.lightning_kwargs)
         _update_kwargs(**kwargs_ctor)
 
         return kwargs
@@ -259,15 +295,29 @@ class Trainer(LightningTrainer):
     @override
     def __init__(
         self,
-        config: "BaseConfig",
+        hparams: TrainerConfig | Mapping[str, Any],
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
-        self._pre_init(config)
+        # Validate the hparams.
+        hparams_cls = Trainer.hparams_cls()
+        if isinstance(hparams, Mapping):
+            hparams = hparams_cls.model_validate(hparams)
+        elif not isinstance(hparams, hparams_cls):
+            raise ValueError(
+                f"Trainer hparams must either be an instance of {hparams_cls} or a mapping. "
+                f"Got {type(hparams)=} instead."
+            )
+        hparams = hparams.model_deep_validate()
+
+        self._pre_init(hparams)
 
-        kwargs = self._update_kwargs(config, kwargs)
+        kwargs = self._update_kwargs(hparams, kwargs)
         log.critical(f"LightningTrainer.__init__ with {kwargs=}.")
 
+        self._hparams = hparams
+        self.debug = self.hparams.debug
+
         super().__init__(**kwargs)
 
         # Add our own start time callback to measure the start time.
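A hedged usage sketch of the new constructor (not from the package itself): the trainer now takes its own TrainerConfig, or a plain mapping that it validates via `hparams_cls().model_validate(...)`. Field values below are illustrative assumptions; the set of required fields may differ.

from nshtrainer.trainer import Trainer            # import path assumed
from nshtrainer.trainer._config import TrainerConfig

# From a mapping: validated internally with TrainerConfig.model_validate(...)
trainer = Trainer({"max_epochs": 10, "precision": "bf16-mixed"})

# From a config instance: additionally passed through model_deep_validate()
hparams = TrainerConfig.model_validate({"max_epochs": 10, "precision": "bf16-mixed"})
trainer = Trainer(hparams)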
@@ -285,7 +335,7 @@ class Trainer(LightningTrainer):
         log.critical(f"LightningTrainer log directory: {self.log_dir}.")
 
         # Set the checkpoint
-        if (ckpt_path := config.trainer.ckpt_path) is not None:
+        if (ckpt_path := hparams.ckpt_path) is not None:
             self.ckpt_path = str(Path(ckpt_path).resolve().absolute())
 
     def __runtime_tracker(self):
@@ -372,7 +422,16 @@ class Trainer(LightningTrainer):
         We patch the `Trainer._run` method to throw if gradient clipping is enabled
         and `model.automatic_optimization` is False.
         """
+        # Save the current environment information
+        datamodule = getattr(self, "datamodule", None)
+        self.hparams.environment = EnvironmentConfig.from_current_environment(
+            self.hparams, model, datamodule
+        )
 
+        # If gradient clipping is enabled, then we need to make sure that
+        # `model.automatic_optimization` is enabled. Otherwise, gradient clipping
+        # is not actually going to do anything, as we expect the user to manually
+        # call `optimizer.step()` and `optimizer.zero_grad()`.
         if not model.automatic_optimization and (
             self.gradient_clip_val is not None
             or self.gradient_clip_algorithm is not None
@@ -401,12 +460,10 @@ class Trainer(LightningTrainer):
 
         # Save the checkpoint metadata
         metadata_path = None
-        lm = self._base_module
-        root_config = cast("BaseConfig", lm.hparams)
-        if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
+        if self.hparams.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
             if (
-                metadata_path := _write_checkpoint_metadata(self, lm, filepath)
+                metadata_path := _write_checkpoint_metadata(self, filepath)
             ) is not None:
                 written_files.append(metadata_path)
 
@@ -414,3 +471,64 @@ class Trainer(LightningTrainer):
         from .. import _callback
 
         _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
+
+    @classmethod
+    def load_from_checkpoint(
+        cls,
+        checkpoint_path: _PATH | IO,
+        map_location: _MAP_LOCATION_TYPE = None,
+        hparams_file: _PATH | None = None,
+        **kwargs: Any,
+    ):
+        loaded = _load_from_checkpoint(
+            checkpoint_path,
+            map_location=map_location,
+            hparams_file=hparams_file,
+            **kwargs,
+        )
+        return loaded
+
+
+def _load_from_checkpoint(
+    checkpoint_path: _PATH | IO,
+    map_location: _MAP_LOCATION_TYPE = None,
+    hparams_file: _PATH | None = None,
+    **kwargs: Any,
+):
+    map_location = map_location or _default_map_location
+    with pl_legacy_patch():
+        checkpoint = pl_load(checkpoint_path, map_location=map_location)
+
+    # convert legacy checkpoints to the new format
+    checkpoint = _pl_migrate_checkpoint(
+        checkpoint,
+        checkpoint_path=(
+            checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None
+        ),
+    )
+
+    if hparams_file is not None:
+        extension = str(hparams_file).split(".")[-1]
+        if extension.lower() == "csv":
+            hparams = load_hparams_from_tags_csv(hparams_file)
+        elif extension.lower() in ("yml", "yaml"):
+            hparams = load_hparams_from_yaml(hparams_file)
+        else:
+            raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")
+
+        # overwrite hparams by the given file
+        checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
+
+    # for past checkpoint need to add the new key
+    checkpoint.setdefault(Trainer.CHECKPOINT_HYPER_PARAMS_KEY, {})
+    # override the hparams with values that were passed in
+    checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
+
+    # load the hparams
+    hparams = Trainer.hparams_cls().model_validate(
+        checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY]
+    )
+
+    # create the trainer
+    trainer = Trainer(hparams)
+    return trainer
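A hedged sketch of the new checkpoint round-trip (path and override field are illustrative assumptions): trainer hyperparameters are stored under `CHECKPOINT_HYPER_PARAMS_KEY` ("trainer_hyper_parameters"), and `load_from_checkpoint` rebuilds a Trainer from them, merging any keyword overrides before validation.

from nshtrainer.trainer import Trainer  # import path assumed

trainer = Trainer.load_from_checkpoint(
    "runs/example/last.ckpt",   # hypothetical checkpoint path
    map_location="cpu",
    # Keyword overrides are merged into the stored hparams before validation:
    precision="32-true",        # assumed to be a valid TrainerConfig field
)
print(type(trainer.hparams))    # -> TrainerConfig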
nshtrainer/util/_environment_info.py

@@ -15,14 +15,14 @@ from typing import TYPE_CHECKING, Any, cast
 import nshconfig as C
 import psutil
 import torch
+from lightning.pytorch import LightningDataModule, LightningModule
 from packaging import version
 from typing_extensions import Self
 
 from .slurm import parse_slurm_node_list
 
 if TYPE_CHECKING:
-    from ..model.base import LightningModuleBase
-    from ..model.config import BaseConfig
+    from ..trainer._config import TrainerConfig
 
 
 log = logging.getLogger(__name__)
@@ -708,6 +708,9 @@ class EnvironmentConfig(C.Config):
     model: EnvironmentClassInformationConfig | None = None
     """The Lightning module class information."""
 
+    datamodule: EnvironmentClassInformationConfig | None = None
+    """The Lightning data module class information."""
+
     linux: EnvironmentLinuxEnvironmentConfig | None = None
     """The Linux environment information."""
 
@@ -768,8 +771,9 @@ class EnvironmentConfig(C.Config):
     @classmethod
     def from_current_environment(
         cls,
-        root_config: "BaseConfig",
-        model: "LightningModuleBase",
+        trainer_config: TrainerConfig,
+        model: LightningModule,
+        datamodule: LightningDataModule | None = None,
     ):
         draft = cls.draft()
         draft.cwd = Path(os.getcwd())
@@ -777,23 +781,27 @@ class EnvironmentConfig(C.Config):
         draft.python_path = [Path(path) for path in sys.path]
         draft.python_version = sys.version
         draft.python_packages = EnvironmentPackageConfig.from_current_environment()
-        draft.config = EnvironmentClassInformationConfig.from_instance(root_config)
+        draft.config = EnvironmentClassInformationConfig.from_instance(trainer_config)
         draft.model = EnvironmentClassInformationConfig.from_instance(model)
+        if datamodule is not None:
+            draft.datamodule = EnvironmentClassInformationConfig.from_instance(
+                datamodule
+            )
         draft.linux = EnvironmentLinuxEnvironmentConfig.from_current_environment()
         draft.hardware = EnvironmentHardwareConfig.from_current_environment()
         draft.slurm = EnvironmentSLURMInformationConfig.from_current_environment()
         draft.lsf = EnvironmentLSFInformationConfig.from_current_environment()
-        draft.base_dir = root_config.directory.resolve_run_root_directory(
-            root_config.id
+        draft.base_dir = trainer_config.directory.resolve_run_root_directory(
+            trainer_config.id
         )
-        draft.log_dir = root_config.directory.resolve_subdirectory(
-            root_config.id, "log"
+        draft.log_dir = trainer_config.directory.resolve_subdirectory(
+            trainer_config.id, "log"
         )
-        draft.checkpoint_dir = root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
+        draft.checkpoint_dir = trainer_config.directory.resolve_subdirectory(
+            trainer_config.id, "checkpoint"
         )
-        draft.stdio_dir = root_config.directory.resolve_subdirectory(
-            root_config.id, "stdio"
+        draft.stdio_dir = trainer_config.directory.resolve_subdirectory(
+            trainer_config.id, "stdio"
         )
         draft.seed = (
             int(seed_str) if (seed_str := os.environ.get("PL_GLOBAL_SEED")) else None
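A sketch of the updated call, with placeholder arguments (this mirrors how the trainer now invokes it in the trainer.py hunks above); only the signature shown in this diff is assumed:

from nshtrainer.util._environment_info import EnvironmentConfig

def snapshot_environment(trainer_config, model, datamodule=None):
    # trainer_config: TrainerConfig; model: LightningModule; datamodule is optional
    # and, when given, its class information is recorded under `env.datamodule`.
    env = EnvironmentConfig.from_current_environment(trainer_config, model, datamodule)
    return env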
nshtrainer/util/config/dtype.py

@@ -9,7 +9,7 @@ from typing_extensions import assert_never
 from ..bf16 import is_bf16_supported_no_emulation
 
 if TYPE_CHECKING:
-    from ...model.base import BaseConfig
+    from ...trainer._config import TrainerConfig
 
 DTypeName: TypeAlias = Literal[
     "float32",
@@ -59,8 +59,8 @@ class DTypeConfig(C.Config):
     """The name of the dtype."""
 
     @classmethod
-    def from_base_config(cls, config: "BaseConfig"):
-        if (precision := config.trainer.precision) is None:
+    def from_trainer_config(cls, trainer_config: TrainerConfig):
+        if (precision := trainer_config.precision) is None:
             precision = "32-true"
 
         match precision:
@@ -79,7 +79,7 @@ class DTypeConfig(C.Config):
             case "64-true":
                 return cls(name="float64")
             case _:
-                assert_never(config.trainer.precision)
+                assert_never(trainer_config.precision)
 
     @property
     def torch_dtype(self):
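A hedged sketch of the renamed constructor; only the precision handling visible in these hunks is assumed (other precision values are matched analogously):

from nshtrainer.util.config.dtype import DTypeConfig

def dtype_for(trainer_config):
    # `precision` is read straight off the TrainerConfig; None falls back to "32-true".
    cfg = DTypeConfig.from_trainer_config(trainer_config)
    return cfg.torch_dtype  # e.g. torch.float64 when precision is "64-true"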
nshtrainer/util/typing_utils.py

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
 
 from typing_extensions import TypeVar
 
-TBase = TypeVar("TBase")
+TBase = TypeVar("TBase", infer_variance=True)
 
 
 def mixin_base_type(base_class: type[TBase]) -> type[TBase]:
{nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.44.0
+Version: 1.0.0b9
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -15,7 +15,7 @@ Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
 Requires-Dist: nshconfig
 Requires-Dist: nshrunner
-Requires-Dist: nshutils
+Requires-Dist: nshutils ; extra == "extra"
 Requires-Dist: numpy
 Requires-Dist: packaging
 Requires-Dist: psutil