nshtrainer 0.26.2__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -153,10 +153,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
153
153
  # Save the new checkpoint
154
154
  filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
155
155
  trainer._nshtrainer_save_checkpoint(
156
- filepath,
157
- self.config.save_weights_only,
158
- use_checkpoint_cache=None,
159
- ckpt_cache_use_symlink=False,
156
+ filepath, self.config.save_weights_only, use_checkpoint_cache=None
160
157
  )
161
158
 
162
159
  if trainer.is_global_zero:
@@ -1012,7 +1012,7 @@ class TrainerConfig(C.Config):
1012
1012
  """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
1013
1013
  save_checkpoint_metadata: bool = True
1014
1014
  """If enabled, will save additional metadata whenever a checkpoint is saved."""
1015
- use_checkpoint_cache: bool = True
1015
+ use_checkpoint_cache: bool = False
1016
1016
  """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""
1017
1017
 
1018
1018
  lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
@@ -1,6 +1,7 @@
1
1
  import logging
2
2
  import os
3
3
  import shutil
4
+ from collections import defaultdict
4
5
  from collections.abc import Sequence
5
6
  from pathlib import Path
6
7
  from typing import TYPE_CHECKING, Any, cast
@@ -18,7 +19,6 @@ from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
18
19
  from typing_extensions import Unpack, assert_never, override
19
20
 
20
21
  from .._checkpoint.metadata import _write_checkpoint_metadata
21
- from .._checkpoint.saver import _link_checkpoint
22
22
  from ..callbacks.base import resolve_all_callbacks
23
23
  from ..model.config import (
24
24
  AcceleratorConfigProtocol,
@@ -280,6 +280,12 @@ class Trainer(LightningTrainer):
280
280
  if TYPE_CHECKING:
281
281
  callbacks: list[Callback]
282
282
 
283
+ def _nshtrainer_checkpoint_cache_get(self, key: tuple[int, int]):
284
+ return next(
285
+ (ckpt for ckpt in self._nshtrainer_checkpoint_cache[key] if ckpt.exists()),
286
+ None,
287
+ )
288
+
283
289
  @override
284
290
  def __init__(
285
291
  self,
@@ -287,7 +293,9 @@ class Trainer(LightningTrainer):
287
293
  /,
288
294
  **kwargs: Unpack[LightningTrainerKwargs],
289
295
  ):
290
- self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
296
+ self._nshtrainer_checkpoint_cache = defaultdict[tuple[int, int], list[Path]](
297
+ lambda: []
298
+ )
291
299
 
292
300
  self._pre_init(config)
293
301
 
@@ -417,7 +425,6 @@ class Trainer(LightningTrainer):
417
425
  weights_only: bool = False,
418
426
  storage_options: Any | None = None,
419
427
  use_checkpoint_cache: bool | None = None,
420
- ckpt_cache_use_symlink: bool = False,
421
428
  ):
422
429
  lm = self._base_module
423
430
  root_config = cast(BaseConfig, lm.hparams)
@@ -433,7 +440,7 @@ class Trainer(LightningTrainer):
433
440
  if (
434
441
  use_checkpoint_cache
435
442
  and (
436
- cached_path := self._nshtrainer_checkpoint_cache.get(
443
+ cached_path := self._nshtrainer_checkpoint_cache_get(
437
444
  (self.current_epoch, self.global_step)
438
445
  )
439
446
  )
@@ -442,10 +449,7 @@ class Trainer(LightningTrainer):
442
449
  # If we have a cached path, then we symlink it to the new path.
443
450
  log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
444
451
  if self.is_global_zero:
445
- if ckpt_cache_use_symlink:
446
- _link_checkpoint(cached_path, filepath, metadata=False)
447
- else:
448
- shutil.copy(cached_path, filepath)
452
+ shutil.copy(cached_path, filepath)
449
453
  self.strategy.barrier("Trainer.save_checkpoint")
450
454
  else:
451
455
  super().save_checkpoint(filepath, weights_only, storage_options)
@@ -454,7 +458,7 @@ class Trainer(LightningTrainer):
454
458
  if use_checkpoint_cache and cached_path is None:
455
459
  self._nshtrainer_checkpoint_cache[
456
460
  (self.current_epoch, self.global_step)
457
- ] = filepath
461
+ ].append(filepath)
458
462
  log.debug(f"Checkpoint saved to cache: {filepath}")
459
463
 
460
464
  # Save the checkpoint metadata
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.26.2
3
+ Version: 0.27.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -10,7 +10,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
11
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=pLhLj7tvfY5czGY_vT0xRfWHzGJYC4iOBRLokFVq0mE,6733
13
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=iDtQ8urdSfl0uUJMvsSvvYo-fsL3ZVXCv79FGS67Jx4,6666
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
15
15
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -58,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
58
58
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
59
59
  nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
60
60
  nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
61
- nshtrainer/model/config.py,sha256=22_xIcdEO2pJzXgrFaqGFtk3PQEiwKiMZY1cjhoyWaA,43486
61
+ nshtrainer/model/config.py,sha256=9fPNrMkpOUIWwAM1UTQQRTVnuV7m3CDTtivya2GCyQY,43487
62
62
  nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
63
63
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
64
64
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -78,7 +78,7 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
78
78
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
79
79
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
80
80
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
81
- nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk,19397
81
+ nshtrainer/trainer/trainer.py,sha256=Bm-kHOF_9W905SpOQcfEwHSUhF6KW8aAFdWNyIotyeM,19450
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
@@ -87,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.26.2.dist-info/METADATA,sha256=d3vWjdB9FT6fbWJPkyBI-4M18ekg2WcJmJiKasExchM,916
91
- nshtrainer-0.26.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.26.2.dist-info/RECORD,,
90
+ nshtrainer-0.27.0.dist-info/METADATA,sha256=fn2L1iuf7uaxT8UYzZd1lhKSKGEldmA0_026HG5f6ek,916
91
+ nshtrainer-0.27.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.27.0.dist-info/RECORD,,