nshtrainer 0.9.1__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/_checkpoint/loader.py +319 -0
  3. nshtrainer/_checkpoint/metadata.py +102 -0
  4. nshtrainer/callbacks/__init__.py +17 -1
  5. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  6. nshtrainer/callbacks/base.py +7 -5
  7. nshtrainer/callbacks/ema.py +1 -1
  8. nshtrainer/callbacks/finite_checks.py +1 -1
  9. nshtrainer/callbacks/gradient_skipping.py +1 -1
  10. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  11. nshtrainer/callbacks/model_checkpoint.py +187 -0
  12. nshtrainer/callbacks/norm_logging.py +1 -1
  13. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  14. nshtrainer/callbacks/print_table.py +1 -1
  15. nshtrainer/callbacks/throughput_monitor.py +1 -1
  16. nshtrainer/callbacks/timer.py +1 -1
  17. nshtrainer/callbacks/wandb_watch.py +1 -1
  18. nshtrainer/ll/__init__.py +0 -1
  19. nshtrainer/ll/actsave.py +2 -1
  20. nshtrainer/metrics/__init__.py +1 -0
  21. nshtrainer/metrics/_config.py +37 -0
  22. nshtrainer/model/__init__.py +11 -11
  23. nshtrainer/model/_environment.py +777 -0
  24. nshtrainer/model/base.py +5 -114
  25. nshtrainer/model/config.py +49 -501
  26. nshtrainer/model/modules/logger.py +11 -6
  27. nshtrainer/runner.py +3 -6
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/WHEEL +0 -0
nshtrainer/model/modules/logger.py CHANGED
@@ -11,10 +11,14 @@ from lightning.pytorch.utilities.types import _METRIC
  from lightning_utilities.core.rank_zero import rank_zero_warn
  from typing_extensions import override

- from ...actsave import ActSave
  from ...util.typing_utils import mixin_base_type
  from ..config import BaseConfig

+ try:
+     from nshutils import ActSave # type: ignore
+ except ImportError:
+     ActSave = None
+

  @dataclass(frozen=True, kw_only=True)
  class _LogContext:
@@ -155,14 +159,15 @@ class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningMod

      def _logger_actsave(self, name: str, value: _METRIC) -> None:
          hparams = cast(BaseConfig, self.hparams)
-         if (
-             not hparams.trainer.actsave
-             or not hparams.trainer.actsave.auto_save_logged_metrics
-         ):
+         if not hparams.trainer.logging.actsave_logged_metrics:
+             return
+
+         if ActSave is None:
+             rank_zero_warn("ActSave is not available, skipping logging of metrics")
              return

          ActSave.save(
-             {
+             lambda: {
                  f"logger.{name}": lambda: value.compute()
                  if isinstance(value, torchmetrics.Metric)
                  else value
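Note: the hunk above makes ActSave an optional dependency (imported from nshutils with a None fallback) and wraps the saved payload in a lambda so the dict is only built when saving actually happens. A minimal standalone sketch of the same pattern; the log_metric helper below is hypothetical, and only the callable form of ActSave.save is taken from the hunk itself:

try:
    from nshutils import ActSave  # optional dependency; degrade gracefully if missing
except ImportError:
    ActSave = None


def log_metric(name: str, value: float) -> None:
    # Hypothetical helper illustrating the guard used in the hunk above.
    if ActSave is None:
        return  # nshutils is not installed; skip activation saving
    # Passing a callable defers building the payload until ActSave decides to save it.
    ActSave.save(lambda: {f"logger.{name}": value})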
nshtrainer/runner.py CHANGED
@@ -5,6 +5,7 @@ from typing import Generic

  from nshrunner import RunInfo
  from nshrunner import Runner as _Runner
+ from nshrunner._submit import screen
  from nshrunner.snapshot import SnapshotArgType
  from typing_extensions import TypeVar, TypeVarTuple, Unpack, override

@@ -89,6 +90,7 @@ class Runner(
      def fast_dev_run_session(
          self,
          runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
+         options: screen.ScreenJobKwargs = {},
          n_batches: int = 1,
          *,
          snapshot: SnapshotArgType,
@@ -99,10 +101,7 @@
          ]
          | None = None,
          activate_venv: bool = True,
-         session_name: str = "nshrunner",
-         attach: bool = True,
          print_command: bool = True,
-         pause_before_exit: bool = False,
      ):
          transforms = transforms or []
          transforms.append(
@@ -110,13 +109,11 @@
          )
          return self.session(
              runs,
+             options,
              snapshot=snapshot,
              setup_commands=setup_commands,
              env=env,
              transforms=transforms,
              activate_venv=activate_venv,
-             session_name=session_name,
-             attach=attach,
              print_command=print_command,
-             pause_before_exit=pause_before_exit,
          )
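Note: fast_dev_run_session now takes its screen-session settings as a single positional mapping typed as screen.ScreenJobKwargs, instead of the separate session_name, attach, and pause_before_exit parameters. A hedged call-site sketch; runner, runs, and the snapshot value are placeholders, and the option keys are assumed to correspond to the removed parameters:

# 0.9.1 style (removed):
#   runner.fast_dev_run_session(runs, snapshot=False, session_name="dev", attach=True)
# 0.10.1 style (assumed option keys):
runner.fast_dev_run_session(
    runs,
    {"session_name": "dev", "attach": True},
    snapshot=False,  # SnapshotArgType; placeholder value, passed through unchanged
)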
nshtrainer/trainer/_runtime_callback.py ADDED
@@ -0,0 +1,120 @@
+ import datetime
+ import logging
+ import time
+ from dataclasses import dataclass
+ from typing import Any, Literal, TypeAlias
+
+ from lightning.pytorch.callbacks.callback import Callback
+ from typing_extensions import override
+
+ log = logging.getLogger(__name__)
+
+ Stage: TypeAlias = Literal["train", "validate", "test", "predict"]
+ ALL_STAGES = ("train", "validate", "test", "predict")
+
+
+ @dataclass
+ class TimeInfo:
+     datetime: datetime.datetime
+     monotonic: float
+
+
+ class RuntimeTrackerCallback(Callback):
+     def __init__(self):
+         super().__init__()
+         self._start_time: dict[Stage, TimeInfo] = {}
+         self._end_time: dict[Stage, TimeInfo] = {}
+         self._offsets = {stage: datetime.timedelta() for stage in ALL_STAGES}
+
+     def start_time(self, stage: Stage) -> TimeInfo | None:
+         """Return the start time of a particular stage"""
+         return self._start_time.get(stage)
+
+     def end_time(self, stage: Stage) -> TimeInfo | None:
+         """Return the end time of a particular stage"""
+         return self._end_time.get(stage)
+
+     def time_elapsed(self, stage: Stage) -> datetime.timedelta:
+         """Return the time elapsed for a particular stage"""
+         start = self.start_time(stage)
+         end = self.end_time(stage)
+         offset = self._offsets[stage]
+         if start is None:
+             return offset
+         if end is None:
+             current = TimeInfo(datetime.datetime.now(), time.monotonic())
+             return (
+                 datetime.timedelta(seconds=current.monotonic - start.monotonic) + offset
+             )
+         return datetime.timedelta(seconds=end.monotonic - start.monotonic) + offset
+
+     def _record_time(self, stage: Stage, time_dict: dict[Stage, TimeInfo]):
+         time_dict[stage] = TimeInfo(datetime.datetime.now(), time.monotonic())
+
+     @override
+     def on_train_start(self, trainer, pl_module):
+         self._record_time("train", self._start_time)
+
+     @override
+     def on_train_end(self, trainer, pl_module):
+         self._record_time("train", self._end_time)
+
+     @override
+     def on_validation_start(self, trainer, pl_module):
+         self._record_time("validate", self._start_time)
+
+     @override
+     def on_validation_end(self, trainer, pl_module):
+         self._record_time("validate", self._end_time)
+
+     @override
+     def on_test_start(self, trainer, pl_module):
+         self._record_time("test", self._start_time)
+
+     @override
+     def on_test_end(self, trainer, pl_module):
+         self._record_time("test", self._end_time)
+
+     @override
+     def on_predict_start(self, trainer, pl_module):
+         self._record_time("predict", self._start_time)
+
+     @override
+     def on_predict_end(self, trainer, pl_module):
+         self._record_time("predict", self._end_time)
+
+     @override
+     def state_dict(self) -> dict[str, Any]:
+         return {
+             "time_elapsed": {
+                 stage: self.time_elapsed(stage).total_seconds() for stage in ALL_STAGES
+             },
+             "start_times": {
+                 stage: (info.datetime.isoformat(), info.monotonic)
+                 for stage, info in self._start_time.items()
+             },
+             "end_times": {
+                 stage: (info.datetime.isoformat(), info.monotonic)
+                 for stage, info in self._end_time.items()
+             },
+         }
+
+     @override
+     def load_state_dict(self, state_dict: dict[str, Any]):
+         time_elapsed: dict[Stage, float] = state_dict.get("time_elapsed", {})
+         for stage in ALL_STAGES:
+             self._offsets[stage] = datetime.timedelta(
+                 seconds=time_elapsed.get(stage, 0)
+             )
+
+         start_times: dict[Stage, tuple[str, float]] = state_dict.get("start_times", {})
+         for stage, (dt_str, monotonic) in start_times.items():
+             self._start_time[stage] = TimeInfo(
+                 datetime.datetime.fromisoformat(dt_str), monotonic
+             )
+
+         end_times: dict[Stage, tuple[str, float]] = state_dict.get("end_times", {})
+         for stage, (dt_str, monotonic) in end_times.items():
+             self._end_time[stage] = TimeInfo(
+                 datetime.datetime.fromisoformat(dt_str), monotonic
+             )
nshtrainer/trainer/checkpoint_connector.py ADDED
@@ -0,0 +1,63 @@
+ import logging
+ from pathlib import Path
+ from typing import TYPE_CHECKING, cast
+
+ from lightning.pytorch.trainer.connectors.checkpoint_connector import (
+     _CheckpointConnector,
+ )
+ from lightning.pytorch.trainer.states import TrainerFn
+ from typing_extensions import override
+
+ from .._checkpoint.loader import CheckpointLoadingConfig, _resolve_checkpoint
+
+ if TYPE_CHECKING:
+     from ..model.config import BaseConfig
+ log = logging.getLogger(__name__)
+
+
+ class CheckpointConnector(_CheckpointConnector):
+     def __resolve_auto_ckpt_path(
+         self,
+         ckpt_path: str | Path | None,
+         state_fn: TrainerFn,
+     ):
+         from .trainer import Trainer
+
+         # If this isn't an `nshtrainer` trainer (which I don't know why it wouldn't be),
+         # then we just default to the parent class's implementation of `_parse_ckpt_path`.
+         trainer = self.trainer
+         if not isinstance(trainer, Trainer):
+             return None
+
+         # Now, resolve the checkpoint loader config.
+         root_config = cast("BaseConfig", trainer._base_module.config)
+         if (ckpt_loader_config := root_config.trainer.checkpoint_loading) == "auto":
+             ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
+         log.debug(f"Checkpoint loader config: {ckpt_loader_config}")
+
+         # Use the config to resolve the checkpoint.
+         if (
+             ckpt_path := _resolve_checkpoint(ckpt_loader_config, root_config, trainer)
+         ) is None:
+             log.info(
+                 "No checkpoint found for the current trainer state. "
+                 "Training will start from scratch."
+             )
+
+         log.info(f"Loading checkpoint from: {ckpt_path}")
+         return ckpt_path
+
+     @override
+     def _parse_ckpt_path(
+         self,
+         state_fn: TrainerFn,
+         ckpt_path: str | Path | None,
+         model_provided: bool,
+         model_connected: bool,
+     ):
+         if (p := self.__resolve_auto_ckpt_path(ckpt_path, state_fn)) is not None:
+             return p
+
+         return super()._parse_ckpt_path(
+             state_fn, ckpt_path, model_provided, model_connected
+         )
nshtrainer/trainer/signal_connector.py CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
  from types import FrameType
  from typing import Any, TypeAlias

+ import nshrunner as nr
  import torch.utils.data
  from lightning.fabric.plugins.environments.lsf import LSFEnvironment
  from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
@@ -27,20 +28,22 @@ _HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handler


  def _resolve_requeue_signals():
-     signals: list[signal.Signals] = []
-
-     if timeout_signal_name := os.environ.get("NSHRUNNER_TIMEOUT_SIGNAL"):
-         signals.append(signal.Signals[timeout_signal_name])
-
-     if preempt_signal_name := os.environ.get("NSHRUNNER_PREEMPT_SIGNAL"):
-         signals.append(signal.Signals[preempt_signal_name])
+     if (session := nr.Session.from_current_session()) is None:
+         return None

+     signals: list[signal.Signals] = []
+     if session.submit_timeout_signal:
+         signals.append(session.submit_timeout_signal)
+     if session.submit_preempt_signal:
+         signals.append(session.submit_preempt_signal)
      return signals


  class _SignalConnector(_LightningSignalConnector):
-     def _auto_requeue_signals(self) -> list[signal.Signals]:
-         signals = _resolve_requeue_signals()
+     def _auto_requeue_signals(self) -> list[signal.Signals] | None:
+         if not (signals := _resolve_requeue_signals()):
+             return None
+
          signals_set = set(signals)
          valid_signals: set[signal.Signals] = signal.valid_signals()
          assert signals_set.issubset(
nshtrainer/trainer/trainer.py CHANGED
@@ -3,7 +3,7 @@ import logging
  import os
  from collections.abc import Sequence
  from pathlib import Path
- from typing import Any, cast
+ from typing import TYPE_CHECKING, Any, cast

  import torch
  from lightning.fabric.plugins.environments.lsf import LSFEnvironment
@@ -11,11 +11,13 @@ from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
  from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
  from lightning.pytorch import LightningModule
  from lightning.pytorch import Trainer as LightningTrainer
+ from lightning.pytorch.callbacks import Callback
  from lightning.pytorch.profilers import Profiler
+ from lightning.pytorch.trainer.states import TrainerFn
  from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
  from typing_extensions import Unpack, assert_never, override

- from ..actsave import ActSave
+ from .._checkpoint.metadata import _write_checkpoint_metadata
  from ..callbacks.base import resolve_all_callbacks
  from ..model.config import (
      AcceleratorConfigProtocol,
@@ -24,6 +26,7 @@ from ..model.config import (
      LightningTrainerKwargs,
      StrategyConfigProtocol,
  )
+ from ._runtime_callback import RuntimeTrackerCallback, Stage
  from .signal_connector import _SignalConnector

  log = logging.getLogger(__name__)
@@ -168,12 +171,12 @@ class Trainer(LightningTrainer):

          if (accelerator := config.trainer.accelerator) is not None:
              if isinstance(accelerator, AcceleratorConfigProtocol):
-                 accelerator = accelerator.construct_accelerator()
+                 accelerator = accelerator.create_accelerator()
              _update_kwargs(accelerator=accelerator)

          if (strategy := config.trainer.strategy) is not None:
              if isinstance(strategy, StrategyConfigProtocol):
-                 strategy = strategy.construct_strategy()
+                 strategy = strategy.create_strategy()
              _update_kwargs(strategy=strategy)

          if (precision := config.trainer.precision) is not None:
@@ -220,7 +223,7 @@
          if profiler := config.trainer.profiler:
              # If the profiler is an ProfilerConfig instance, then we instantiate it.
              if isinstance(profiler, BaseProfilerConfig):
-                 profiler = profiler.construct_profiler(config)
+                 profiler = profiler.create_profiler(config)
              # Make sure that the profiler is an instance of `Profiler`.
              if not isinstance(profiler, Profiler):
                  raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
@@ -236,7 +239,7 @@
          if plugin_configs := config.trainer.plugins:
              _update_kwargs(
                  plugins=[
-                     plugin_config.construct_plugin() for plugin_config in plugin_configs
+                     plugin_config.create_plugin() for plugin_config in plugin_configs
                  ]
              )

@@ -244,7 +247,7 @@
              log.critical(f"Disabling logger because {config.trainer.logging.enabled=}.")
              kwargs["logger"] = False
          else:
-             _update_kwargs(logger=config.trainer.logging.construct_loggers(config))
+             _update_kwargs(logger=config.trainer.logging.create_loggers(config))

          if config.trainer.auto_determine_num_nodes:
              # When num_nodes is auto, we need to detect the number of nodes.
@@ -275,6 +278,9 @@

          return kwargs

+     if TYPE_CHECKING:
+         callbacks: list[Callback]
+
      @override
      def __init__(
          self,
@@ -282,12 +288,14 @@
          /,
          **kwargs: Unpack[LightningTrainerKwargs],
      ):
-         self._ll_config = config
          kwargs = self._update_kwargs(config, kwargs)
          log.critical(f"LightningTrainer.__init__ with {kwargs=}.")

          super().__init__(**kwargs)

+         # Add our own start time callback to measure the start time.
+         self.callbacks.append(RuntimeTrackerCallback())
+
          # Replace the signal connector with our own.
          self._signal_connector = _SignalConnector(self)

@@ -296,34 +304,89 @@
          log_dir = str(Path(log_dir).resolve())
          log.critical(f"LightningTrainer log directory: {self.log_dir}.")

-         # Checkpoint loading
-         if (
-             ckpt_loading := self._ll_config.trainer.checkpoint_loading
-         ) and ckpt_loading.path:
-             self.ckpt_path = ckpt_loading.path
+     def __runtime_tracker(self):
+         return next(
+             (
+                 callback
+                 for callback in self.callbacks
+                 if isinstance(callback, RuntimeTrackerCallback)
+             ),
+             None,
+         )
+
+     def __current_stage(self) -> Stage:
+         match self.state.fn:
+             case None:
+                 raise ValueError(
+                     "Trainer state function is not set. "
+                     "You must call `fit`, `validate`, `test`, or `predict`, "
+                     "or explicitly provide a stage."
+                 )
+             case TrainerFn.FITTING:
+                 return "train"
+             case TrainerFn.VALIDATING:
+                 return "validate"
+             case TrainerFn.TESTING:
+                 return "test"
+             case TrainerFn.PREDICTING:
+                 return "predict"
+             case _:
+                 assert_never(self.state.fn)
+
+     def start_time(self, stage: Stage | None = None):
+         """Return the start time of the run"""
+         if (tracker := self.__runtime_tracker()) is None:
+             raise ValueError(
+                 "RuntimeTrackerCallback is not set. Cannot get start time."
+             )
+         if stage is None:
+             stage = self.__current_stage()

-     @contextlib.contextmanager
-     def _actsave_context(self, model: LightningModule):
-         hparams = cast(BaseConfig, model.hparams)
-         if not (actsave_config := hparams.trainer.actsave):
-             yield
-             return
+         return tracker.start_time(stage)
+
+     def end_time(self, stage: Stage | None = None):
+         """Return the end time of the run"""
+         if (tracker := self.__runtime_tracker()) is None:
+             raise ValueError(
+                 "RuntimeTrackerCallback is not set. Cannot get start time."
+             )
+         if stage is None:
+             stage = self.__current_stage()
+
+         return tracker.end_time(stage)
+
+     def time_elapsed(self, stage: Stage | None = None):
+         """Return the time elapsed for the run"""
+         if (tracker := self.__runtime_tracker()) is None:
+             raise ValueError(
+                 "RuntimeTrackerCallback is not set. Cannot get start time."
+             )
+         if stage is None:
+             stage = self.__current_stage()

-         # Enter actsave context
-         with ActSave.enabled(actsave_config.resolve_save_dir(hparams)):
-             yield
+         return tracker.time_elapsed(stage)
+
+     @property
+     def _base_module(self):
+         if self.lightning_module is None:
+             raise ValueError("LightningModule is not set.")
+
+         from ..model.base import LightningModuleBase
+
+         if not isinstance(self.lightning_module, LightningModuleBase):
+             raise ValueError(
+                 f"LightningModule is not an instance of {LightningModuleBase}."
+             )
+
+         return self.lightning_module

      @override
      def _run(
          self, model: LightningModule, ckpt_path: str | Path | None = None
      ) -> _EVALUATE_OUTPUT | _PREDICT_OUTPUT | None:
-         """
-         Two things done here:
-         1. Lightning doesn't support gradient clipping with manual optimization.
-         We patch the `Trainer._run` method to throw if gradient clipping is enabled
-         and `model.automatic_optimization` is False.
-
-         2. We actually set up actsave here.
+         """Lightning doesn't support gradient clipping with manual optimization.
+         We patch the `Trainer._run` method to throw if gradient clipping is enabled
+         and `model.automatic_optimization` is False.
          """

          if not model.automatic_optimization and (
@@ -336,5 +399,22 @@
                  "or disable automatic gradient clipping. "
              )

-         with self._actsave_context(model):
-             return super()._run(model, ckpt_path)
+         return super()._run(model, ckpt_path)
+
+     @override
+     def save_checkpoint(
+         self,
+         filepath: str | Path,
+         weights_only: bool = False,
+         storage_options: Any | None = None,
+     ):
+         filepath = Path(filepath)
+         ret_val = super().save_checkpoint(filepath, weights_only, storage_options)
+
+         # Save the checkpoint metadata
+         lm = self._base_module
+         if lm.config.trainer.save_checkpoint_metadata and self.is_global_zero:
+             # Generate the metadata and write to disk
+             _write_checkpoint_metadata(self, lm, filepath)
+
+         return ret_val
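Note: the timing accessors added to Trainer above delegate to the RuntimeTrackerCallback appended in __init__. A hedged usage sketch; config and model stand in for a real BaseConfig and LightningModuleBase subclass from application code:

trainer = Trainer(config)
trainer.fit(model)

print(trainer.time_elapsed("train"))     # total fit time as a datetime.timedelta
print(trainer.time_elapsed("validate"))  # validation time accrued during fit
if (start := trainer.start_time("train")) is not None:
    print(start.datetime, start.monotonic)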
{nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.9.1
+ Version: 0.10.1
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
@@ -9,11 +9,13 @@ Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
+ Requires-Dist: GitPython
  Requires-Dist: lightning
  Requires-Dist: nshconfig
  Requires-Dist: nshrunner
  Requires-Dist: nshutils
  Requires-Dist: numpy
+ Requires-Dist: psutil
  Requires-Dist: pytorch-lightning
  Requires-Dist: torch
  Requires-Dist: torchmetrics
{nshtrainer-0.9.1.dist-info → nshtrainer-0.10.1.dist-info}/RECORD RENAMED
@@ -1,32 +1,34 @@
- nshtrainer/__init__.py,sha256=nbZHdfTk0oWqsJgrSzdgk2DSf4CGhdZn79esoJGauO8,548
+ nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
+ nshtrainer/_checkpoint/loader.py,sha256=48flPr1XgQHOgIPaCrRqOEvRuG0SZuV3cQ1vgHLqFqI,11025
+ nshtrainer/_checkpoint/metadata.py,sha256=C7je_soYyEbZjiq7p2_pSVFkgcXnz2J2H5sMy8oskx0,3051
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
  nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
- nshtrainer/actsave/__init__.py,sha256=_ZuwgRtF1-ekouXNvtZCAS1g_IDYGB4NX8BFSGNGBT8,119
- nshtrainer/actsave/_callback.py,sha256=mnHOtuG9vtHEzz9q4vCvDNC6VvjZsgb4MSSuOoUDh3M,2778
- nshtrainer/callbacks/__init__.py,sha256=I6W33ityL9Ko8jjqHh3WH_8miV59SAe9LxInhoqX5XE,1665
+ nshtrainer/callbacks/__init__.py,sha256=ifXQRwtccznl4lMKwKLSuuAQC4bKFBgfzQ4rx9gOqjE,2345
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
- nshtrainer/callbacks/base.py,sha256=LrcRUV02bZEKXRIRvhHT9qsvw_kwoWiAdQkVMyKc5NU,3542
+ nshtrainer/callbacks/actsave.py,sha256=aY6T_NAzaFAVU8WMHOXnWL5wd2bi8eVxeU2S0iAs70c,4446
+ nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
  nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
- nshtrainer/callbacks/ema.py,sha256=zKCtvzZFo0ORlwNZHjaMk-sJoxrlTtFWOzR-yGy95W0,12134
- nshtrainer/callbacks/finite_checks.py,sha256=kX3TIJsxyqx0GuLJfYsqVgKU27zwjG9Z8324lyCFtwM,2087
- nshtrainer/callbacks/gradient_skipping.py,sha256=ModaIXpb69LbA8TpEXKRLdr4Sq7-l0CWnN6fvpaV188,3477
+ nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
+ nshtrainer/callbacks/finite_checks.py,sha256=AO5fa51uANAjAkeJfTquOjK6W_4RSU5Kky3f5jmAPlQ,2084
+ nshtrainer/callbacks/gradient_skipping.py,sha256=fSJpjgHbztFKz7w3qFuCHZpmbEt9BCLAy-sU0B4xJQI,3474
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
- nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=ZT0bn7X0BZbQXbk6fos47NsbbhD4Z9c9YmFqdcUEqus,1503
+ nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=p0zeDK3PLWWl485e9o08ywEEARCfuZ5it47tNCtR4ec,2838
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
- nshtrainer/callbacks/norm_logging.py,sha256=IMrK0WiVSDFyspwyPpwELMK4mmd5Jpx4enAW_GsWbi4,6284
- nshtrainer/callbacks/on_exception_checkpoint.py,sha256=eDyB7qkpPdAaKjAY2uFMMY8Nht6TGeuDnsgHuKtp8eA,1615
- nshtrainer/callbacks/print_table.py,sha256=FcA-CBWwMf9c1NNRinvYpZC400RNQxuP28bJfgniT3Q,2840
- nshtrainer/callbacks/throughput_monitor.py,sha256=YQLdpX3LGybIiD814yT9yCCVSEXRWf8WwsvVaN5aDBE,1848
- nshtrainer/callbacks/timer.py,sha256=sDXPPcdDKu5xnuK_bjr8plIq9MBuluNJ42Mt9LvPZzc,4610
- nshtrainer/callbacks/wandb_watch.py,sha256=pUpMsNxd03ex1rzOmFw2HzGOXjnQGaH84m8cc2dXo4g,2937
+ nshtrainer/callbacks/model_checkpoint.py,sha256=4zYycpXHGRyL4svWLP6GmG3WJs5m3B5PRCOzXC3m_qg,5955
+ nshtrainer/callbacks/norm_logging.py,sha256=EWyrfkp8iHjQi9iAAXHxb0xStw2RwkdpKG2_gLarQRA,6281
+ nshtrainer/callbacks/on_exception_checkpoint.py,sha256=zna_QF_x4HwD7Es5XxrHLDED43NU1GpcDNoL139HEOs,3355
+ nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
+ nshtrainer/callbacks/throughput_monitor.py,sha256=4EF3b79HdHiRgBGIFDyD4O1oywb5h1tV8nml7NuuDjU,1845
+ nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
+ nshtrainer/callbacks/wandb_watch.py,sha256=bicXS3nZfPGoN7Owu1XIBS-1bw7yeIJdYJTnRN0dp2E,2934
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
  nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
  nshtrainer/data/transform.py,sha256=JeGxvytQly8hougrsdMmKG8gJ6qvFPDglJCO4Tp6STk,1795
- nshtrainer/ll/__init__.py,sha256=nxYPtoFOFAvzkD6O3EIuwCiRi_LedYa_EH-RIfDG91s,2685
+ nshtrainer/ll/__init__.py,sha256=dD0ISxHJ2lg1HLSM0b3db7TBlsPpQCtChnuYO-c2oqI,2635
  nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
- nshtrainer/ll/actsave.py,sha256=QJ7yJIqvabpZzumX7PLPzkh6dfqY-zxiEdzv48VtZEY,123
+ nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
  nshtrainer/ll/callbacks.py,sha256=AxyUmc8aGRSjx6WwwgXYCmdJ73rwLuEAEH0AGRosojQ,49
  nshtrainer/ll/config.py,sha256=fKumJf42HY2FITX1QUM1OTXkYD6U2np2ciyd4PFRPZ8,145
  nshtrainer/ll/data.py,sha256=zRG0FRje-jtSHximVzkHIHzpwsyQxpHCoACFihNKLPM,44
@@ -45,13 +47,16 @@ nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzq
  nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
  nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9btIxMRWigUHUTlUYCSw,5221
  nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
- nshtrainer/model/__init__.py,sha256=y32Hla-5whpzLL2BtCJpBakSp8o-1nQbpO0j_-xq_Po,1864
- nshtrainer/model/base.py,sha256=YtqnjiMf0cLVjFEQuOLm5WwCkVnZftiHlIdCrxdax3s,21297
- nshtrainer/model/config.py,sha256=-I_HLTTwqWimnnoKJ64oBEq3x31CZj9rwrg9MnFzs38,68215
+ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
+ nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
+ nshtrainer/model/__init__.py,sha256=TbexTxiE20WHYg5q3L88Hysk4LlHeKk_isv33aSBREA,1918
+ nshtrainer/model/_environment.py,sha256=s3JFnigbssFRJTwH33K7DcAYVhLOFCC1OZgFNXJgjuw,22317
+ nshtrainer/model/base.py,sha256=Bmw-t70TydDbE9P0ee-lTibGoUhrCx5Qke-upa7FGVM,17512
+ nshtrainer/model/config.py,sha256=B1XkKYbhpAm6RmF4n4eR66hMh-kCXwIQB2pQuhR9TZE,53177
  nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
  nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
- nshtrainer/model/modules/logger.py,sha256=XEeo3QrplTNKZqfl6iWZf3fze3R4YOeOvs-RKVHFoQs,5527
+ nshtrainer/model/modules/logger.py,sha256=YYhehQysqTjuVFcd_EREYDh57CIlezidFBS2Ohp_xKo,5661
  nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
  nshtrainer/model/modules/rlp_sanity_checks.py,sha256=o6gUceFwsuDHmL8eLOYuT3JGXFzq_qc4awl2RWaBygU,8900
  nshtrainer/model/modules/shared_parameters.py,sha256=mD5wrlBE3c025vzVdTpnSyC8yxzuI-aUWMmPhqPT0a0,2694
@@ -61,17 +66,19 @@ nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,
  nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
  nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
  nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
- nshtrainer/runner.py,sha256=7EumpnBkdNWjSNT9Gm-pkxAJ3W6-iMC-yae-WNeZcLw,3771
+ nshtrainer/runner.py,sha256=6qfE5FBONzD79kVHuWYKEvK0J_Qi5dMBbHQhRMmnIhE,3649
  nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
  nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
  nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
- nshtrainer/trainer/signal_connector.py,sha256=JSP8W2PSdzwO3iWX1WOL1l8dufh2dKgUWeJ2gEWCppg,10626
- nshtrainer/trainer/trainer.py,sha256=eYEYfY9v70MuorHcSf8nqM7f2CkmUHhpPcjCk4FJD7k,14034
+ nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
+ nshtrainer/trainer/checkpoint_connector.py,sha256=xoqI2dcPnlNFPPLVIU6dBOvRPC9PtfX5qu__xV1lx0Y,2124
+ nshtrainer/trainer/signal_connector.py,sha256=llwc8pdKAWxREFpjdi14Bpy8rGVMEJsmJx_s2p4gI8E,10689
+ nshtrainer/trainer/trainer.py,sha256=n3T9Iz3eaDostxEdjapWImAsVMxyU9WBdhlPl0THX-g,16785
  nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
  nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.9.1.dist-info/METADATA,sha256=3s9luSztUNVhu3t_sSmOw3HhwuVVUoiLhQwlxBiaaSg,647
- nshtrainer-0.9.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.9.1.dist-info/RECORD,,
+ nshtrainer-0.10.1.dist-info/METADATA,sha256=O8wMPb0ksoZajyes8dsq4IIjsfP_jQaxGYpW3rYE9Ro,695
+ nshtrainer-0.10.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.10.1.dist-info/RECORD,,
nshtrainer/actsave/__init__.py DELETED
@@ -1,3 +0,0 @@
- from nshutils.actsave import * # type: ignore # noqa: F403
-
- from ._callback import ActSaveCallback as ActSaveCallback