nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (35)
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/callbacks/__init__.py +17 -1
  3. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  4. nshtrainer/callbacks/base.py +7 -5
  5. nshtrainer/callbacks/ema.py +1 -1
  6. nshtrainer/callbacks/finite_checks.py +1 -1
  7. nshtrainer/callbacks/gradient_skipping.py +1 -1
  8. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  9. nshtrainer/callbacks/model_checkpoint.py +187 -0
  10. nshtrainer/callbacks/norm_logging.py +1 -1
  11. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  12. nshtrainer/callbacks/print_table.py +1 -1
  13. nshtrainer/callbacks/throughput_monitor.py +1 -1
  14. nshtrainer/callbacks/timer.py +1 -1
  15. nshtrainer/callbacks/wandb_watch.py +1 -1
  16. nshtrainer/ll/__init__.py +0 -1
  17. nshtrainer/ll/actsave.py +2 -1
  18. nshtrainer/metrics/__init__.py +1 -0
  19. nshtrainer/metrics/_config.py +37 -0
  20. nshtrainer/model/__init__.py +11 -11
  21. nshtrainer/model/_environment.py +777 -0
  22. nshtrainer/model/base.py +5 -114
  23. nshtrainer/model/config.py +92 -507
  24. nshtrainer/model/modules/logger.py +11 -6
  25. nshtrainer/runner.py +3 -6
  26. nshtrainer/trainer/_checkpoint_metadata.py +102 -0
  27. nshtrainer/trainer/_checkpoint_resolver.py +319 -0
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/trainer/signal_connector.py

@@ -11,6 +11,7 @@ from pathlib import Path
 from types import FrameType
 from typing import Any, TypeAlias
 
+import nshrunner as nr
 import torch.utils.data
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
@@ -27,20 +28,22 @@ _HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handler
 
 
 def _resolve_requeue_signals():
-    signals: list[signal.Signals] = []
-
-    if timeout_signal_name := os.environ.get("NSHRUNNER_TIMEOUT_SIGNAL"):
-        signals.append(signal.Signals[timeout_signal_name])
-
-    if preempt_signal_name := os.environ.get("NSHRUNNER_PREEMPT_SIGNAL"):
-        signals.append(signal.Signals[preempt_signal_name])
+    if (session := nr.Session.from_current_session()) is None:
+        return None
 
+    signals: list[signal.Signals] = []
+    if session.submit_timeout_signal:
+        signals.append(session.submit_timeout_signal)
+    if session.submit_preempt_signal:
+        signals.append(session.submit_preempt_signal)
     return signals
 
 
 class _SignalConnector(_LightningSignalConnector):
-    def _auto_requeue_signals(self) -> list[signal.Signals]:
-        signals = _resolve_requeue_signals()
+    def _auto_requeue_signals(self) -> list[signal.Signals] | None:
+        if not (signals := _resolve_requeue_signals()):
+            return None
+
         signals_set = set(signals)
         valid_signals: set[signal.Signals] = signal.valid_signals()
         assert signals_set.issubset(
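
Requeue signals are now resolved from the active nshrunner session rather than from the NSHRUNNER_TIMEOUT_SIGNAL and NSHRUNNER_PREEMPT_SIGNAL environment variables. A minimal sketch of how a job could inspect which signals will be honored under the new scheme, assuming only the nr.Session attributes that appear in the hunk above (illustration, not part of the package):

    import nshrunner as nr

    # Outside an nshrunner session, _resolve_requeue_signals() returns None,
    # so _auto_requeue_signals() also returns None.
    if (session := nr.Session.from_current_session()) is None:
        print("not running inside an nshrunner session; no requeue signals")
    else:
        # Signals configured at submission time; falsy values mean unset.
        print("timeout signal:", session.submit_timeout_signal)
        print("preempt signal:", session.submit_preempt_signal)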

nshtrainer/trainer/trainer.py

@@ -3,7 +3,7 @@ import logging
 import os
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
@@ -11,11 +11,12 @@ from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
 from lightning.pytorch import LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
+from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.profilers import Profiler
+from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Unpack, assert_never, override
 
-from ..actsave import ActSave
 from ..callbacks.base import resolve_all_callbacks
 from ..model.config import (
     AcceleratorConfigProtocol,
@@ -24,6 +25,8 @@ from ..model.config import (
     LightningTrainerKwargs,
     StrategyConfigProtocol,
 )
+from ._checkpoint_metadata import _write_checkpoint_metadata
+from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .signal_connector import _SignalConnector
 
 log = logging.getLogger(__name__)
@@ -168,12 +171,12 @@ class Trainer(LightningTrainer):
 
         if (accelerator := config.trainer.accelerator) is not None:
             if isinstance(accelerator, AcceleratorConfigProtocol):
-                accelerator = accelerator.construct_accelerator()
+                accelerator = accelerator.create_accelerator()
             _update_kwargs(accelerator=accelerator)
 
         if (strategy := config.trainer.strategy) is not None:
             if isinstance(strategy, StrategyConfigProtocol):
-                strategy = strategy.construct_strategy()
+                strategy = strategy.create_strategy()
             _update_kwargs(strategy=strategy)
 
         if (precision := config.trainer.precision) is not None:
@@ -220,7 +223,7 @@ class Trainer(LightningTrainer):
         if profiler := config.trainer.profiler:
             # If the profiler is an ProfilerConfig instance, then we instantiate it.
             if isinstance(profiler, BaseProfilerConfig):
-                profiler = profiler.construct_profiler(config)
+                profiler = profiler.create_profiler(config)
             # Make sure that the profiler is an instance of `Profiler`.
             if not isinstance(profiler, Profiler):
                 raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
@@ -236,7 +239,7 @@ class Trainer(LightningTrainer):
         if plugin_configs := config.trainer.plugins:
             _update_kwargs(
                 plugins=[
-                    plugin_config.construct_plugin() for plugin_config in plugin_configs
+                    plugin_config.create_plugin() for plugin_config in plugin_configs
                 ]
             )
 
@@ -244,7 +247,7 @@ class Trainer(LightningTrainer):
             log.critical(f"Disabling logger because {config.trainer.logging.enabled=}.")
             kwargs["logger"] = False
         else:
-            _update_kwargs(logger=config.trainer.logging.construct_loggers(config))
+            _update_kwargs(logger=config.trainer.logging.create_loggers(config))
 
         if config.trainer.auto_determine_num_nodes:
             # When num_nodes is auto, we need to detect the number of nodes.
@@ -275,6 +278,9 @@ class Trainer(LightningTrainer):
 
         return kwargs
 
+    if TYPE_CHECKING:
+        callbacks: list[Callback]
+
     @override
     def __init__(
         self,
@@ -282,12 +288,14 @@ class Trainer(LightningTrainer):
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
-        self._ll_config = config
         kwargs = self._update_kwargs(config, kwargs)
         log.critical(f"LightningTrainer.__init__ with {kwargs=}.")
 
         super().__init__(**kwargs)
 
+        # Add our own start time callback to measure the start time.
+        self.callbacks.append(RuntimeTrackerCallback())
+
         # Replace the signal connector with our own.
         self._signal_connector = _SignalConnector(self)
 
@@ -296,34 +304,89 @@ class Trainer(LightningTrainer):
             log_dir = str(Path(log_dir).resolve())
         log.critical(f"LightningTrainer log directory: {self.log_dir}.")
 
-        # Checkpoint loading
-        if (
-            ckpt_loading := self._ll_config.trainer.checkpoint_loading
-        ) and ckpt_loading.path:
-            self.ckpt_path = ckpt_loading.path
+    def __runtime_tracker(self):
+        return next(
+            (
+                callback
+                for callback in self.callbacks
+                if isinstance(callback, RuntimeTrackerCallback)
+            ),
+            None,
+        )
+
+    def __current_stage(self) -> Stage:
+        match self.state.fn:
+            case None:
+                raise ValueError(
+                    "Trainer state function is not set. "
+                    "You must call `fit`, `validate`, `test`, or `predict`, "
+                    "or explicitly provide a stage."
+                )
+            case TrainerFn.FITTING:
+                return "train"
+            case TrainerFn.VALIDATING:
+                return "validate"
+            case TrainerFn.TESTING:
+                return "test"
+            case TrainerFn.PREDICTING:
+                return "predict"
+            case _:
+                assert_never(self.state.fn)
+
+    def start_time(self, stage: Stage | None = None):
+        """Return the start time of the run"""
+        if (tracker := self.__runtime_tracker()) is None:
+            raise ValueError(
+                "RuntimeTrackerCallback is not set. Cannot get start time."
+            )
+        if stage is None:
+            stage = self.__current_stage()
 
-    @contextlib.contextmanager
-    def _actsave_context(self, model: LightningModule):
-        hparams = cast(BaseConfig, model.hparams)
-        if not (actsave_config := hparams.trainer.actsave):
-            yield
-            return
+        return tracker.start_time(stage)
+
+    def end_time(self, stage: Stage | None = None):
+        """Return the end time of the run"""
+        if (tracker := self.__runtime_tracker()) is None:
+            raise ValueError(
+                "RuntimeTrackerCallback is not set. Cannot get start time."
+            )
+        if stage is None:
+            stage = self.__current_stage()
+
+        return tracker.end_time(stage)
+
+    def time_elapsed(self, stage: Stage | None = None):
+        """Return the time elapsed for the run"""
+        if (tracker := self.__runtime_tracker()) is None:
+            raise ValueError(
+                "RuntimeTrackerCallback is not set. Cannot get start time."
+            )
+        if stage is None:
+            stage = self.__current_stage()
 
-        # Enter actsave context
-        with ActSave.enabled(actsave_config.resolve_save_dir(hparams)):
-            yield
+        return tracker.time_elapsed(stage)
+
+    @property
+    def _base_module(self):
+        if self.lightning_module is None:
+            raise ValueError("LightningModule is not set.")
+
+        from ..model.base import LightningModuleBase
+
+        if not isinstance(self.lightning_module, LightningModuleBase):
+            raise ValueError(
+                f"LightningModule is not an instance of {LightningModuleBase}."
+            )
+
+        return self.lightning_module
 
     @override
     def _run(
         self, model: LightningModule, ckpt_path: str | Path | None = None
     ) -> _EVALUATE_OUTPUT | _PREDICT_OUTPUT | None:
-        """
-        Two things done here:
-        1. Lightning doesn't support gradient clipping with manual optimization.
-        We patch the `Trainer._run` method to throw if gradient clipping is enabled
-        and `model.automatic_optimization` is False.
-
-        2. We actually set up actsave here.
+        """Lightning doesn't support gradient clipping with manual optimization.
+        We patch the `Trainer._run` method to throw if gradient clipping is enabled
+        and `model.automatic_optimization` is False.
         """
 
         if not model.automatic_optimization and (
@@ -336,5 +399,22 @@ class Trainer(LightningTrainer):
                 "or disable automatic gradient clipping. "
             )
 
-        with self._actsave_context(model):
-            return super()._run(model, ckpt_path)
+        return super()._run(model, ckpt_path)
+
+    @override
+    def save_checkpoint(
+        self,
+        filepath: str | Path,
+        weights_only: bool = False,
+        storage_options: Any | None = None,
+    ):
+        filepath = Path(filepath)
+        ret_val = super().save_checkpoint(filepath, weights_only, storage_options)
+
+        # Save the checkpoint metadata
+        lm = self._base_module
+        if lm.config.trainer.save_checkpoint_metadata and self.is_global_zero:
+            # Generate the metadata and write to disk
+            _write_checkpoint_metadata(self, lm, filepath)
+
+        return ret_val
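
The RuntimeTrackerCallback registered in __init__ backs the new public start_time, end_time, and time_elapsed accessors, and save_checkpoint now writes per-checkpoint metadata when config.trainer.save_checkpoint_metadata is enabled. A minimal usage sketch, assuming a user-supplied config (a BaseConfig instance) and model (a LightningModuleBase subclass), and assuming Trainer is re-exported from nshtrainer.trainer (both assumptions, not shown in this diff):

    from nshtrainer.trainer import Trainer  # import path assumed

    trainer = Trainer(config)  # config: the user's BaseConfig instance
    trainer.fit(model)         # model: the user's LightningModuleBase

    # The stage may be given explicitly ("train", "validate", "test", "predict");
    # if omitted, it is inferred from the trainer's current state.
    print("started:", trainer.start_time("train"))
    print("finished:", trainer.end_time("train"))
    print("elapsed:", trainer.time_elapsed("train"))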

{nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.8.7
+Version: 0.10.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -9,11 +9,13 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Requires-Dist: GitPython
 Requires-Dist: lightning
 Requires-Dist: nshconfig
 Requires-Dist: nshrunner
 Requires-Dist: nshutils
 Requires-Dist: numpy
+Requires-Dist: psutil
 Requires-Dist: pytorch-lightning
 Requires-Dist: torch
 Requires-Dist: torchmetrics

{nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD

@@ -1,32 +1,32 @@
-nshtrainer/__init__.py,sha256=nbZHdfTk0oWqsJgrSzdgk2DSf4CGhdZn79esoJGauO8,548
+nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
 nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
 nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
-nshtrainer/actsave/__init__.py,sha256=_ZuwgRtF1-ekouXNvtZCAS1g_IDYGB4NX8BFSGNGBT8,119
-nshtrainer/actsave/_callback.py,sha256=mnHOtuG9vtHEzz9q4vCvDNC6VvjZsgb4MSSuOoUDh3M,2778
-nshtrainer/callbacks/__init__.py,sha256=I6W33ityL9Ko8jjqHh3WH_8miV59SAe9LxInhoqX5XE,1665
+nshtrainer/callbacks/__init__.py,sha256=ifXQRwtccznl4lMKwKLSuuAQC4bKFBgfzQ4rx9gOqjE,2345
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
-nshtrainer/callbacks/base.py,sha256=LrcRUV02bZEKXRIRvhHT9qsvw_kwoWiAdQkVMyKc5NU,3542
+nshtrainer/callbacks/actsave.py,sha256=aY6T_NAzaFAVU8WMHOXnWL5wd2bi8eVxeU2S0iAs70c,4446
+nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
 nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
-nshtrainer/callbacks/ema.py,sha256=zKCtvzZFo0ORlwNZHjaMk-sJoxrlTtFWOzR-yGy95W0,12134
-nshtrainer/callbacks/finite_checks.py,sha256=kX3TIJsxyqx0GuLJfYsqVgKU27zwjG9Z8324lyCFtwM,2087
-nshtrainer/callbacks/gradient_skipping.py,sha256=ModaIXpb69LbA8TpEXKRLdr4Sq7-l0CWnN6fvpaV188,3477
+nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
+nshtrainer/callbacks/finite_checks.py,sha256=AO5fa51uANAjAkeJfTquOjK6W_4RSU5Kky3f5jmAPlQ,2084
+nshtrainer/callbacks/gradient_skipping.py,sha256=fSJpjgHbztFKz7w3qFuCHZpmbEt9BCLAy-sU0B4xJQI,3474
 nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
-nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=ZT0bn7X0BZbQXbk6fos47NsbbhD4Z9c9YmFqdcUEqus,1503
+nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=p0zeDK3PLWWl485e9o08ywEEARCfuZ5it47tNCtR4ec,2838
 nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
-nshtrainer/callbacks/norm_logging.py,sha256=IMrK0WiVSDFyspwyPpwELMK4mmd5Jpx4enAW_GsWbi4,6284
-nshtrainer/callbacks/on_exception_checkpoint.py,sha256=eDyB7qkpPdAaKjAY2uFMMY8Nht6TGeuDnsgHuKtp8eA,1615
-nshtrainer/callbacks/print_table.py,sha256=FcA-CBWwMf9c1NNRinvYpZC400RNQxuP28bJfgniT3Q,2840
-nshtrainer/callbacks/throughput_monitor.py,sha256=YQLdpX3LGybIiD814yT9yCCVSEXRWf8WwsvVaN5aDBE,1848
-nshtrainer/callbacks/timer.py,sha256=sDXPPcdDKu5xnuK_bjr8plIq9MBuluNJ42Mt9LvPZzc,4610
-nshtrainer/callbacks/wandb_watch.py,sha256=pUpMsNxd03ex1rzOmFw2HzGOXjnQGaH84m8cc2dXo4g,2937
+nshtrainer/callbacks/model_checkpoint.py,sha256=4zYycpXHGRyL4svWLP6GmG3WJs5m3B5PRCOzXC3m_qg,5955
+nshtrainer/callbacks/norm_logging.py,sha256=EWyrfkp8iHjQi9iAAXHxb0xStw2RwkdpKG2_gLarQRA,6281
+nshtrainer/callbacks/on_exception_checkpoint.py,sha256=zna_QF_x4HwD7Es5XxrHLDED43NU1GpcDNoL139HEOs,3355
+nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
+nshtrainer/callbacks/throughput_monitor.py,sha256=4EF3b79HdHiRgBGIFDyD4O1oywb5h1tV8nml7NuuDjU,1845
+nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
+nshtrainer/callbacks/wandb_watch.py,sha256=bicXS3nZfPGoN7Owu1XIBS-1bw7yeIJdYJTnRN0dp2E,2934
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
 nshtrainer/data/transform.py,sha256=JeGxvytQly8hougrsdMmKG8gJ6qvFPDglJCO4Tp6STk,1795
-nshtrainer/ll/__init__.py,sha256=nxYPtoFOFAvzkD6O3EIuwCiRi_LedYa_EH-RIfDG91s,2685
+nshtrainer/ll/__init__.py,sha256=dD0ISxHJ2lg1HLSM0b3db7TBlsPpQCtChnuYO-c2oqI,2635
 nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
-nshtrainer/ll/actsave.py,sha256=QJ7yJIqvabpZzumX7PLPzkh6dfqY-zxiEdzv48VtZEY,123
+nshtrainer/ll/actsave.py,sha256=2lbiseSrjcwFT6AiyLNWarTWl1bnzliVWlu1iOfnP30,209
 nshtrainer/ll/callbacks.py,sha256=AxyUmc8aGRSjx6WwwgXYCmdJ73rwLuEAEH0AGRosojQ,49
 nshtrainer/ll/config.py,sha256=fKumJf42HY2FITX1QUM1OTXkYD6U2np2ciyd4PFRPZ8,145
 nshtrainer/ll/data.py,sha256=zRG0FRje-jtSHximVzkHIHzpwsyQxpHCoACFihNKLPM,44
@@ -45,13 +45,16 @@ nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzq
 nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
 nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9btIxMRWigUHUTlUYCSw,5221
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=h76oTHYpMxauV_l6lviya5DW-WKArwxxf7ZQizhmbCw,2782
-nshtrainer/model/__init__.py,sha256=y32Hla-5whpzLL2BtCJpBakSp8o-1nQbpO0j_-xq_Po,1864
-nshtrainer/model/base.py,sha256=YtqnjiMf0cLVjFEQuOLm5WwCkVnZftiHlIdCrxdax3s,21297
-nshtrainer/model/config.py,sha256=6lATW6-Z1SIDgQ1IWrGBVQKTr8DhL5b_rFbJHQz0d5o,66796
+nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy00,50
+nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0,1104
+nshtrainer/model/__init__.py,sha256=TbexTxiE20WHYg5q3L88Hysk4LlHeKk_isv33aSBREA,1918
+nshtrainer/model/_environment.py,sha256=s3JFnigbssFRJTwH33K7DcAYVhLOFCC1OZgFNXJgjuw,22317
+nshtrainer/model/base.py,sha256=Bmw-t70TydDbE9P0ee-lTibGoUhrCx5Qke-upa7FGVM,17512
+nshtrainer/model/config.py,sha256=f8gbTaIi02U8EyooC1vv2ElZfXPgMIAVtU0n-LnkNE4,53187
 nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
 nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
-nshtrainer/model/modules/logger.py,sha256=XEeo3QrplTNKZqfl6iWZf3fze3R4YOeOvs-RKVHFoQs,5527
+nshtrainer/model/modules/logger.py,sha256=YYhehQysqTjuVFcd_EREYDh57CIlezidFBS2Ohp_xKo,5661
 nshtrainer/model/modules/profiler.py,sha256=rQ_jRMcM1Z2AIROZlRnBRHM5rkTpq67afZPD6CIRfXs,825
 nshtrainer/model/modules/rlp_sanity_checks.py,sha256=o6gUceFwsuDHmL8eLOYuT3JGXFzq_qc4awl2RWaBygU,8900
 nshtrainer/model/modules/shared_parameters.py,sha256=mD5wrlBE3c025vzVdTpnSyC8yxzuI-aUWMmPhqPT0a0,2694
@@ -61,17 +64,21 @@ nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
 nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
 nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
-nshtrainer/runner.py,sha256=7EumpnBkdNWjSNT9Gm-pkxAJ3W6-iMC-yae-WNeZcLw,3771
+nshtrainer/runner.py,sha256=6qfE5FBONzD79kVHuWYKEvK0J_Qi5dMBbHQhRMmnIhE,3649
 nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
 nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
-nshtrainer/trainer/signal_connector.py,sha256=JSP8W2PSdzwO3iWX1WOL1l8dufh2dKgUWeJ2gEWCppg,10626
-nshtrainer/trainer/trainer.py,sha256=eYEYfY9v70MuorHcSf8nqM7f2CkmUHhpPcjCk4FJD7k,14034
+nshtrainer/trainer/_checkpoint_metadata.py,sha256=dj3g0rUZLWfohIRFAhhLqB4qh1fJsquQ5-EZ0Zbl5ZE,3042
+nshtrainer/trainer/_checkpoint_resolver.py,sha256=kfIccBLWAMwn-Bw1pbj3XTXXaCdO_taUEUp3RdwFuLY,11037
+nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
+nshtrainer/trainer/checkpoint_connector.py,sha256=9DrZliK95BIZfwVFxL06Uf7DbfbQ5UAWd0xckH-LU6U,2125
+nshtrainer/trainer/signal_connector.py,sha256=llwc8pdKAWxREFpjdi14Bpy8rGVMEJsmJx_s2p4gI8E,10689
+nshtrainer/trainer/trainer.py,sha256=qHwerfdQUCU21IkWf50d_qZAIHb2d8qOLfqTszBpzks,16784
 nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
 nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.8.7.dist-info/METADATA,sha256=O1kFYWXIuVK1EU0TpwbpbADX1lJQmPM1-9xLTuNaNB8,647
-nshtrainer-0.8.7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.8.7.dist-info/RECORD,,
+nshtrainer-0.10.0.dist-info/METADATA,sha256=GslAMAaEXDbMxDd4ijoqjQKYBjb0iAnEGkZ3pAF_sOQ,695
+nshtrainer-0.10.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.10.0.dist-info/RECORD,,

nshtrainer/actsave/__init__.py

@@ -1,3 +0,0 @@
-from nshutils.actsave import * # type: ignore # noqa: F403
-
-from ._callback import ActSaveCallback as ActSaveCallback