nshtrainer 0.26.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py CHANGED
@@ -359,19 +359,3 @@ class HFHubCallback(NTCallbackBase):
359
359
  # NOTE: This file is fairly small, so we can just upload it directly.
360
360
  # No need to copy.
361
361
  self._save_file(metadata_path)
362
-
363
- @override
364
- def state_dict(self):
365
- return {
366
- "repo_id": self._repo_id,
367
- "checksum_to_path_in_repo": {
368
- k: str(v) for k, v in self._checksum_to_path_in_repo.items()
369
- },
370
- }
371
-
372
- @override
373
- def load_state_dict(self, state_dict):
374
- self._repo_id = state_dict["repo_id"]
375
- self._checksum_to_path_in_repo = {
376
- k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
377
- }
@@ -152,12 +152,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
152
152
 
153
153
  # Save the new checkpoint
154
154
  filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
155
- trainer._nshtrainer_save_checkpoint(
156
- filepath,
157
- self.config.save_weights_only,
158
- use_checkpoint_cache=None,
159
- ckpt_cache_use_symlink=False,
160
- )
155
+ trainer.save_checkpoint(filepath, self.config.save_weights_only)
161
156
 
162
157
  if trainer.is_global_zero:
163
158
  # Create the latest symlink
@@ -1012,8 +1012,6 @@ class TrainerConfig(C.Config):
1012
1012
  """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
1013
1013
  save_checkpoint_metadata: bool = True
1014
1014
  """If enabled, will save additional metadata whenever a checkpoint is saved."""
1015
- use_checkpoint_cache: bool = True
1016
- """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""
1017
1015
 
1018
1016
  lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
1019
1017
  """
@@ -1,6 +1,5 @@
1
1
  import logging
2
2
  import os
3
- import shutil
4
3
  from collections.abc import Sequence
5
4
  from pathlib import Path
6
5
  from typing import TYPE_CHECKING, Any, cast
@@ -18,7 +17,6 @@ from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
18
17
  from typing_extensions import Unpack, assert_never, override
19
18
 
20
19
  from .._checkpoint.metadata import _write_checkpoint_metadata
21
- from .._checkpoint.saver import _link_checkpoint
22
20
  from ..callbacks.base import resolve_all_callbacks
23
21
  from ..model.config import (
24
22
  AcceleratorConfigProtocol,
@@ -287,8 +285,6 @@ class Trainer(LightningTrainer):
287
285
  /,
288
286
  **kwargs: Unpack[LightningTrainerKwargs],
289
287
  ):
290
- self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
291
-
292
288
  self._pre_init(config)
293
289
 
294
290
  kwargs = self._update_kwargs(config, kwargs)
@@ -411,54 +407,24 @@ class Trainer(LightningTrainer):
411
407
 
412
408
  return super()._run(model, ckpt_path)
413
409
 
414
- def _nshtrainer_save_checkpoint(
410
+ @override
411
+ def save_checkpoint(
415
412
  self,
416
413
  filepath: str | Path,
417
414
  weights_only: bool = False,
418
415
  storage_options: Any | None = None,
419
- use_checkpoint_cache: bool | None = None,
420
- ckpt_cache_use_symlink: bool = False,
421
416
  ):
422
- lm = self._base_module
423
- root_config = cast(BaseConfig, lm.hparams)
424
- if use_checkpoint_cache is None:
425
- use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
426
-
427
417
  filepath = Path(filepath)
428
418
 
429
419
  # List of files that we should upload to HF
430
420
  written_files: list[Path] = [filepath]
431
421
 
432
- cached_path = None
433
- if (
434
- use_checkpoint_cache
435
- and (
436
- cached_path := self._nshtrainer_checkpoint_cache.get(
437
- (self.current_epoch, self.global_step)
438
- )
439
- )
440
- is not None
441
- ):
442
- # If we have a cached path, then we symlink it to the new path.
443
- log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
444
- if self.is_global_zero:
445
- if ckpt_cache_use_symlink:
446
- _link_checkpoint(cached_path, filepath, metadata=False)
447
- else:
448
- shutil.copy(cached_path, filepath)
449
- self.strategy.barrier("Trainer.save_checkpoint")
450
- else:
451
- super().save_checkpoint(filepath, weights_only, storage_options)
452
-
453
- # If we are using the cache but we don't have a cached path, then we save the checkpoint to the cache.
454
- if use_checkpoint_cache and cached_path is None:
455
- self._nshtrainer_checkpoint_cache[
456
- (self.current_epoch, self.global_step)
457
- ] = filepath
458
- log.debug(f"Checkpoint saved to cache: {filepath}")
422
+ super().save_checkpoint(filepath, weights_only, storage_options)
459
423
 
460
424
  # Save the checkpoint metadata
461
425
  metadata_path = None
426
+ lm = self._base_module
427
+ root_config = cast(BaseConfig, lm.hparams)
462
428
  if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
463
429
  # Generate the metadata and write to disk
464
430
  if (
@@ -470,17 +436,3 @@ class Trainer(LightningTrainer):
470
436
  from .. import _callback
471
437
 
472
438
  _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
473
-
474
- @override
475
- def save_checkpoint(
476
- self,
477
- filepath: str | Path,
478
- weights_only: bool = False,
479
- storage_options: Any | None = None,
480
- ):
481
- return self._nshtrainer_save_checkpoint(
482
- filepath=filepath,
483
- weights_only=weights_only,
484
- storage_options=storage_options,
485
- use_checkpoint_cache=False,
486
- )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.26.2
3
+ Version: 0.28.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -4,13 +4,13 @@ nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uP
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=hxZwwsUKVbBtt4wjqcKZbObx0PuO-qCdF3BTdnyqaQo,4711
5
5
  nshtrainer/_checkpoint/saver.py,sha256=1loCDYDy_Cay37uKs_wvxnkwvr41WMmga85qefct80Q,1271
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
- nshtrainer/_hf_hub.py,sha256=v1DV3Vn6Pdwr2KYI7yL_Xv1dyJLaG7yZhKFWZFy3QFk,13087
7
+ nshtrainer/_hf_hub.py,sha256=42cR6viOiUInbBW4R7n9-AMnAn_ovN6YoLU3jD14PtI,12608
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
11
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=pLhLj7tvfY5czGY_vT0xRfWHzGJYC4iOBRLokFVq0mE,6733
13
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=dXCfyFyf2TjH9Pnc4FCqFUGrmy_n25MSCn8jdf3L-_8,6605
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
15
15
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -58,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
58
58
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
59
59
  nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
60
60
  nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
61
- nshtrainer/model/config.py,sha256=22_xIcdEO2pJzXgrFaqGFtk3PQEiwKiMZY1cjhoyWaA,43486
61
+ nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
62
62
  nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
63
63
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
64
64
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -78,7 +78,7 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
78
78
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
79
79
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
80
80
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
81
- nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk,19397
81
+ nshtrainer/trainer/trainer.py,sha256=L4nYXq6Gts2sS9CQGenwEcvMET4L5vO5c60KM5Hm8Do,17544
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
@@ -87,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.26.2.dist-info/METADATA,sha256=d3vWjdB9FT6fbWJPkyBI-4M18ekg2WcJmJiKasExchM,916
91
- nshtrainer-0.26.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.26.2.dist-info/RECORD,,
90
+ nshtrainer-0.28.0.dist-info/METADATA,sha256=-4iVbXySZ9NGbwhqDWfBRdZmMNOEUx5ir1FKGrpdASg,916
91
+ nshtrainer-0.28.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.28.0.dist-info/RECORD,,