nshtrainer 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py CHANGED
@@ -359,19 +359,3 @@ class HFHubCallback(NTCallbackBase):
359
359
  # NOTE: This file is fairly small, so we can just upload it directly.
360
360
  # No need to copy.
361
361
  self._save_file(metadata_path)
362
-
363
- @override
364
- def state_dict(self):
365
- return {
366
- "repo_id": self._repo_id,
367
- "checksum_to_path_in_repo": {
368
- k: str(v) for k, v in self._checksum_to_path_in_repo.items()
369
- },
370
- }
371
-
372
- @override
373
- def load_state_dict(self, state_dict):
374
- self._repo_id = state_dict["repo_id"]
375
- self._checksum_to_path_in_repo = {
376
- k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
377
- }
@@ -152,9 +152,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
152
152
 
153
153
  # Save the new checkpoint
154
154
  filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
155
- trainer._nshtrainer_save_checkpoint(
156
- filepath, self.config.save_weights_only, use_checkpoint_cache=None
157
- )
155
+ trainer.save_checkpoint(filepath, self.config.save_weights_only)
158
156
 
159
157
  if trainer.is_global_zero:
160
158
  # Create the latest symlink
@@ -1012,8 +1012,6 @@ class TrainerConfig(C.Config):
1012
1012
  """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
1013
1013
  save_checkpoint_metadata: bool = True
1014
1014
  """If enabled, will save additional metadata whenever a checkpoint is saved."""
1015
- use_checkpoint_cache: bool = False
1016
- """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""
1017
1015
 
1018
1016
  lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
1019
1017
  """
@@ -1,7 +1,5 @@
1
1
  import logging
2
2
  import os
3
- import shutil
4
- from collections import defaultdict
5
3
  from collections.abc import Sequence
6
4
  from pathlib import Path
7
5
  from typing import TYPE_CHECKING, Any, cast
@@ -280,12 +278,6 @@ class Trainer(LightningTrainer):
280
278
  if TYPE_CHECKING:
281
279
  callbacks: list[Callback]
282
280
 
283
- def _nshtrainer_checkpoint_cache_get(self, key: tuple[int, int]):
284
- return next(
285
- (ckpt for ckpt in self._nshtrainer_checkpoint_cache[key] if ckpt.exists()),
286
- None,
287
- )
288
-
289
281
  @override
290
282
  def __init__(
291
283
  self,
@@ -293,10 +285,6 @@ class Trainer(LightningTrainer):
293
285
  /,
294
286
  **kwargs: Unpack[LightningTrainerKwargs],
295
287
  ):
296
- self._nshtrainer_checkpoint_cache = defaultdict[tuple[int, int], list[Path]](
297
- lambda: []
298
- )
299
-
300
288
  self._pre_init(config)
301
289
 
302
290
  kwargs = self._update_kwargs(config, kwargs)
@@ -419,50 +407,24 @@ class Trainer(LightningTrainer):
419
407
 
420
408
  return super()._run(model, ckpt_path)
421
409
 
422
- def _nshtrainer_save_checkpoint(
410
+ @override
411
+ def save_checkpoint(
423
412
  self,
424
413
  filepath: str | Path,
425
414
  weights_only: bool = False,
426
415
  storage_options: Any | None = None,
427
- use_checkpoint_cache: bool | None = None,
428
416
  ):
429
- lm = self._base_module
430
- root_config = cast(BaseConfig, lm.hparams)
431
- if use_checkpoint_cache is None:
432
- use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
433
-
434
417
  filepath = Path(filepath)
435
418
 
436
419
  # List of files that we should upload to HF
437
420
  written_files: list[Path] = [filepath]
438
421
 
439
- cached_path = None
440
- if (
441
- use_checkpoint_cache
442
- and (
443
- cached_path := self._nshtrainer_checkpoint_cache_get(
444
- (self.current_epoch, self.global_step)
445
- )
446
- )
447
- is not None
448
- ):
449
- # If we have a cached path, then we symlink it to the new path.
450
- log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
451
- if self.is_global_zero:
452
- shutil.copy(cached_path, filepath)
453
- self.strategy.barrier("Trainer.save_checkpoint")
454
- else:
455
- super().save_checkpoint(filepath, weights_only, storage_options)
456
-
457
- # If we are using the cache but we don't have a cached path, then we save the checkpoint to the cache.
458
- if use_checkpoint_cache and cached_path is None:
459
- self._nshtrainer_checkpoint_cache[
460
- (self.current_epoch, self.global_step)
461
- ].append(filepath)
462
- log.debug(f"Checkpoint saved to cache: {filepath}")
422
+ super().save_checkpoint(filepath, weights_only, storage_options)
463
423
 
464
424
  # Save the checkpoint metadata
465
425
  metadata_path = None
426
+ lm = self._base_module
427
+ root_config = cast(BaseConfig, lm.hparams)
466
428
  if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
467
429
  # Generate the metadata and write to disk
468
430
  if (
@@ -474,17 +436,3 @@ class Trainer(LightningTrainer):
474
436
  from .. import _callback
475
437
 
476
438
  _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
477
-
478
- @override
479
- def save_checkpoint(
480
- self,
481
- filepath: str | Path,
482
- weights_only: bool = False,
483
- storage_options: Any | None = None,
484
- ):
485
- return self._nshtrainer_save_checkpoint(
486
- filepath=filepath,
487
- weights_only=weights_only,
488
- storage_options=storage_options,
489
- use_checkpoint_cache=False,
490
- )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.27.0
3
+ Version: 0.28.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -4,13 +4,13 @@ nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uP
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=hxZwwsUKVbBtt4wjqcKZbObx0PuO-qCdF3BTdnyqaQo,4711
5
5
  nshtrainer/_checkpoint/saver.py,sha256=1loCDYDy_Cay37uKs_wvxnkwvr41WMmga85qefct80Q,1271
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
- nshtrainer/_hf_hub.py,sha256=v1DV3Vn6Pdwr2KYI7yL_Xv1dyJLaG7yZhKFWZFy3QFk,13087
7
+ nshtrainer/_hf_hub.py,sha256=42cR6viOiUInbBW4R7n9-AMnAn_ovN6YoLU3jD14PtI,12608
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
11
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=iDtQ8urdSfl0uUJMvsSvvYo-fsL3ZVXCv79FGS67Jx4,6666
13
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=dXCfyFyf2TjH9Pnc4FCqFUGrmy_n25MSCn8jdf3L-_8,6605
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
15
15
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -58,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
58
58
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
59
59
  nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
60
60
  nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
61
- nshtrainer/model/config.py,sha256=9fPNrMkpOUIWwAM1UTQQRTVnuV7m3CDTtivya2GCyQY,43487
61
+ nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
62
62
  nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
63
63
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
64
64
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -78,7 +78,7 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
78
78
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
79
79
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
80
80
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
81
- nshtrainer/trainer/trainer.py,sha256=Bm-kHOF_9W905SpOQcfEwHSUhF6KW8aAFdWNyIotyeM,19450
81
+ nshtrainer/trainer/trainer.py,sha256=L4nYXq6Gts2sS9CQGenwEcvMET4L5vO5c60KM5Hm8Do,17544
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
@@ -87,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.27.0.dist-info/METADATA,sha256=fn2L1iuf7uaxT8UYzZd1lhKSKGEldmA0_026HG5f6ek,916
91
- nshtrainer-0.27.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.27.0.dist-info/RECORD,,
90
+ nshtrainer-0.28.0.dist-info/METADATA,sha256=-4iVbXySZ9NGbwhqDWfBRdZmMNOEUx5ir1FKGrpdASg,916
91
+ nshtrainer-0.28.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.28.0.dist-info/RECORD,,