nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/callbacks/__init__.py +17 -1
  3. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  4. nshtrainer/callbacks/base.py +7 -5
  5. nshtrainer/callbacks/ema.py +1 -1
  6. nshtrainer/callbacks/finite_checks.py +1 -1
  7. nshtrainer/callbacks/gradient_skipping.py +1 -1
  8. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  9. nshtrainer/callbacks/model_checkpoint.py +187 -0
  10. nshtrainer/callbacks/norm_logging.py +1 -1
  11. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  12. nshtrainer/callbacks/print_table.py +1 -1
  13. nshtrainer/callbacks/throughput_monitor.py +1 -1
  14. nshtrainer/callbacks/timer.py +1 -1
  15. nshtrainer/callbacks/wandb_watch.py +1 -1
  16. nshtrainer/ll/__init__.py +0 -1
  17. nshtrainer/ll/actsave.py +2 -1
  18. nshtrainer/metrics/__init__.py +1 -0
  19. nshtrainer/metrics/_config.py +37 -0
  20. nshtrainer/model/__init__.py +11 -11
  21. nshtrainer/model/_environment.py +777 -0
  22. nshtrainer/model/base.py +5 -114
  23. nshtrainer/model/config.py +92 -507
  24. nshtrainer/model/modules/logger.py +11 -6
  25. nshtrainer/runner.py +3 -6
  26. nshtrainer/trainer/_checkpoint_metadata.py +102 -0
  27. nshtrainer/trainer/_checkpoint_resolver.py +319 -0
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/model/base.py CHANGED
@@ -1,30 +1,18 @@
1
- import getpass
2
1
  import inspect
3
- import os
4
- import platform
5
- import sys
6
2
  from abc import ABC, abstractmethod
7
- from collections.abc import Callable, MutableMapping
8
- from datetime import timedelta
3
+ from collections.abc import MutableMapping
9
4
  from logging import getLogger
10
- from pathlib import Path
11
5
  from typing import IO, TYPE_CHECKING, Any, Generic, cast
12
6
 
13
7
  import torch
14
8
  from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
15
- from lightning.pytorch import LightningDataModule, LightningModule, Trainer
9
+ from lightning.pytorch import LightningModule, Trainer
16
10
  from lightning.pytorch.callbacks import Callback
17
11
  from lightning.pytorch.utilities.types import STEP_OUTPUT
18
12
  from typing_extensions import Self, TypeVar, override
19
13
 
20
- from .config import (
21
- BaseConfig,
22
- EnvironmentClassInformationConfig,
23
- EnvironmentLinuxEnvironmentConfig,
24
- EnvironmentLSFInformationConfig,
25
- EnvironmentSLURMInformationConfig,
26
- EnvironmentSnapshotConfig,
27
- )
14
+ from ._environment import EnvironmentConfig
15
+ from .config import BaseConfig
28
16
  from .modules.callback import CallbackModuleMixin
29
17
  from .modules.debug import DebugModuleMixin
30
18
  from .modules.distributed import DistributedMixin
@@ -102,39 +90,6 @@ class DebugFlagCallback(Callback):
102
90
  hparams.debug = self._debug
103
91
 
104
92
 
105
- def _cls_info(cls: type):
106
- name = cls.__name__
107
- module = cls.__module__
108
- full_name = f"{cls.__module__}.{cls.__qualname__}"
109
-
110
- file_path = inspect.getfile(cls)
111
- source_file_path = inspect.getsourcefile(cls)
112
- return EnvironmentClassInformationConfig(
113
- name=name,
114
- module=module,
115
- full_name=full_name,
116
- file_path=Path(file_path),
117
- source_file_path=Path(source_file_path) if source_file_path else None,
118
- )
119
-
120
-
121
- T = TypeVar("T")
122
-
123
-
124
- def _psutil():
125
- import psutil
126
-
127
- return psutil
128
-
129
-
130
- def _try_get(fn: Callable[[], T | None]) -> T | None:
131
- try:
132
- return fn()
133
- except Exception as e:
134
- log.warning(f"Failed to get value: {e}")
135
- return None
136
-
137
-
138
93
  class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
139
94
  ProfilerMixin,
140
95
  RLPSanityCheckModuleMixin,
@@ -212,58 +167,6 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
212
167
  **kwargs,
213
168
  )
214
169
 
215
- @classmethod
216
- def _update_environment(cls, hparams: THparams):
217
- hparams.environment.cwd = Path(os.getcwd())
218
- hparams.environment.python_executable = Path(sys.executable)
219
- hparams.environment.python_path = [Path(path) for path in sys.path]
220
- hparams.environment.python_version = sys.version
221
- hparams.environment.config = _cls_info(cls.config_cls())
222
- hparams.environment.model = _cls_info(cls)
223
- hparams.environment.slurm = (
224
- EnvironmentSLURMInformationConfig.from_current_environment()
225
- )
226
- hparams.environment.lsf = (
227
- EnvironmentLSFInformationConfig.from_current_environment()
228
- )
229
- hparams.environment.base_dir = hparams.directory.resolve_run_root_directory(
230
- hparams.id
231
- )
232
- hparams.environment.log_dir = hparams.directory.resolve_subdirectory(
233
- hparams.id, "log"
234
- )
235
- hparams.environment.checkpoint_dir = hparams.directory.resolve_subdirectory(
236
- hparams.id, "checkpoint"
237
- )
238
- hparams.environment.stdio_dir = hparams.directory.resolve_subdirectory(
239
- hparams.id, "stdio"
240
- )
241
- hparams.environment.seed = (
242
- int(seed_str) if (seed_str := os.environ.get("PL_GLOBAL_SEED")) else None
243
- )
244
- hparams.environment.seed_workers = (
245
- bool(int(seed_everything))
246
- if (seed_everything := os.environ.get("PL_SEED_WORKERS"))
247
- else None
248
- )
249
- hparams.environment.linux = EnvironmentLinuxEnvironmentConfig(
250
- user=_try_get(lambda: getpass.getuser()),
251
- hostname=_try_get(lambda: platform.node()),
252
- system=_try_get(lambda: platform.system()),
253
- release=_try_get(lambda: platform.release()),
254
- version=_try_get(lambda: platform.version()),
255
- machine=_try_get(lambda: platform.machine()),
256
- processor=_try_get(lambda: platform.processor()),
257
- cpu_count=_try_get(lambda: os.cpu_count()),
258
- memory=_try_get(lambda: _psutil().virtual_memory().total),
259
- uptime=_try_get(lambda: timedelta(seconds=_psutil().boot_time())),
260
- boot_time=_try_get(lambda: _psutil().boot_time()),
261
- load_avg=_try_get(lambda: os.getloadavg()),
262
- )
263
- hparams.environment.snapshot = (
264
- EnvironmentSnapshotConfig.from_current_environment()
265
- )
266
-
267
170
  def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
268
171
  """
269
172
  Override this method to update the hparams dictionary before it is used to create the hparams object.
@@ -287,12 +190,11 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
287
190
 
288
191
  hparams = self.pre_init_update_hparams_dict(hparams)
289
192
  hparams = self.config_cls().model_validate(hparams)
290
- self._update_environment(hparams)
193
+ hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
291
194
  hparams = self.pre_init_update_hparams(hparams)
292
195
  super().__init__(hparams)
293
196
 
294
197
  self.save_hyperparameters(hparams)
295
-
296
198
  self.register_callback(lambda: DebugFlagCallback())
297
199
 
298
200
  def zero_loss(self):
@@ -304,17 +206,6 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
304
206
  loss = cast(torch.Tensor, loss)
305
207
  return loss
306
208
 
307
- @property
308
- def datamodule(self):
309
- datamodule = getattr(self.trainer, "datamodule", None)
310
- if (datamodule := getattr(self.trainer, "datamodule", None)) is None:
311
- return None
312
- if not isinstance(datamodule, LightningDataModule):
313
- raise TypeError(
314
- f"datamodule must be a LightningDataModule: {type(datamodule)}"
315
- )
316
- return datamodule
317
-
318
209
  if TYPE_CHECKING:
319
210
 
320
211
  @override