nshtrainer 0.27.0__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_checkpoint/metadata.py CHANGED
@@ -117,7 +117,7 @@ def _write_checkpoint_metadata(
     try:
         metadata_path.write_text(metadata.model_dump_json(indent=4), encoding="utf-8")
     except Exception:
-        log.exception(f"Failed to write metadata to {metadata_path}")
+        log.warning(f"Failed to write metadata to {metadata_path}", exc_info=True)
         return None

     log.debug(f"Checkpoint metadata written to {metadata_path}")
@@ -129,7 +129,7 @@ def _remove_checkpoint_metadata(checkpoint_path: Path):
     try:
         path.unlink(missing_ok=True)
     except Exception:
-        log.exception(f"Failed to remove {path}")
+        log.warning(f"Failed to remove {path}", exc_info=True)
     else:
         log.debug(f"Removed {path}")

nshtrainer/_checkpoint/saver.py CHANGED
@@ -25,11 +25,11 @@ def _link_checkpoint(
     try:
         if linkpath.exists():
             if linkpath.is_dir():
-                shutil.rmtree(linkpath, ignore_errors=True)
+                shutil.rmtree(linkpath)
             else:
                 linkpath.unlink(missing_ok=True)
     except Exception:
-        log.exception(f"Failed to remove {linkpath}")
+        log.warning(f"Failed to remove {linkpath}", exc_info=True)

    if metadata:
        _remove_checkpoint_metadata(linkpath)
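Dropping `ignore_errors=True` means `shutil.rmtree` now raises on failure, so removal errors reach the surrounding `except Exception` and get logged instead of being silently swallowed. An illustrative sketch of the resulting behavior (names are placeholders):

```python
import logging
import shutil
from pathlib import Path

log = logging.getLogger(__name__)

def remove_link(linkpath: Path) -> None:
    try:
        if linkpath.is_dir():
            # Without ignore_errors=True, failures propagate from here...
            shutil.rmtree(linkpath)
        else:
            linkpath.unlink(missing_ok=True)
    except Exception:
        # ...so they are reported with a traceback instead of ignored.
        log.warning(f"Failed to remove {linkpath}", exc_info=True)
```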
nshtrainer/_hf_hub.py CHANGED
@@ -179,7 +179,9 @@ class HFHubCallback(NTCallbackBase):
         try:
             yield
         except Exception:
-            log.exception(f"Failed to {opeartion}, repo_id={self._repo_id}")
+            log.warning(
+                f"Failed to {opeartion}, repo_id={self._repo_id}", exc_info=True
+            )
         else:
             log.debug(f"Successfully {opeartion}, repo_id={self._repo_id}")

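The `yield` in this hunk indicates that `_with_error_handling` is a generator-based context manager. A minimal sketch of the pattern, as a standalone function rather than nshtrainer's exact method (names are illustrative):

```python
import logging
from contextlib import contextmanager

log = logging.getLogger(__name__)

@contextmanager
def with_error_handling(operation: str, repo_id: str):
    try:
        # Code inside the `with` block runs at this yield point.
        yield
    except Exception:
        # Swallow the exception and log it with its traceback.
        log.warning(f"Failed to {operation}, repo_id={repo_id}", exc_info=True)
    else:
        log.debug(f"Successfully {operation}, repo_id={repo_id}")

# Usage: the exception raised in the block is logged, not re-raised.
with with_error_handling("upload config", "user/repo"):
    raise RuntimeError("network error")
```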
@@ -261,9 +263,13 @@ class HFHubCallback(NTCallbackBase):
                 )
                 log.info(f"Created new repository '{self.repo_id}'.")
             except Exception:
-                log.exception(f"Failed to create repository '{self.repo_id}'")
+                log.warning(
+                    f"Failed to create repository '{self.repo_id}'", exc_info=True
+                )
         except Exception:
-            log.exception(f"Error checking repository '{self.repo_id}'")
+            log.warning(
+                f"Error checking repository '{self.repo_id}'", exc_info=True
+            )

    def _save_config(self, root_config: "BaseConfig"):
        with self._with_error_handling("upload config"):
@@ -300,9 +306,15 @@ class HFHubCallback(NTCallbackBase):

     def _save_file(self, p: _Upload):
         with self._with_error_handling("save file"):
+            # First, read the file into memory.
+            # We do this to avoid issues with
+            # the file being moved or deleted.
+            with p.local_path.open("rb") as f:
+                data = f.read()
+
             # Upload the checkpoint files to the repository
             self.api.upload_file(
-                path_or_fileobj=p.local_path,
+                path_or_fileobj=data,
                 path_in_repo=str(p.path_in_repo),
                 repo_id=self.repo_id,
                 repo_type="model",
@@ -359,19 +371,3 @@ class HFHubCallback(NTCallbackBase):
         # NOTE: This file is fairly small, so we can just upload it directly.
         # No need to copy.
         self._save_file(metadata_path)
-
-    @override
-    def state_dict(self):
-        return {
-            "repo_id": self._repo_id,
-            "checksum_to_path_in_repo": {
-                k: str(v) for k, v in self._checksum_to_path_in_repo.items()
-            },
-        }
-
-    @override
-    def load_state_dict(self, state_dict):
-        self._repo_id = state_dict["repo_id"]
-        self._checksum_to_path_in_repo = {
-            k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
-        }
nshtrainer/callbacks/checkpoint/_base.py CHANGED
@@ -9,7 +9,7 @@ from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import TypeVar, override

-from ..._checkpoint.metadata import CheckpointMetadata, _metadata_path
+from ..._checkpoint.metadata import CheckpointMetadata
 from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
 from ..base import CallbackConfigBase

@@ -152,9 +152,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):

         # Save the new checkpoint
         filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
-        trainer._nshtrainer_save_checkpoint(
-            filepath, self.config.save_weights_only, use_checkpoint_cache=None
-        )
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)

         if trainer.is_global_zero:
             # Create the latest symlink
@@ -170,16 +168,6 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
         # deleted the old checkpoints, and created the symlink before continuing
         trainer.strategy.barrier()

-        # Call the on save checkpoint callback for the symlink (if it exists)
-        if (symlink_filename := self.symlink_path()) is not None:
-            from ... import _callback
-
-            symlink_path = self.dirpath / symlink_filename
-            symlink_metadata_path = _metadata_path(symlink_path)
-            _callback._call_on_checkpoint_saved(
-                trainer, symlink_path, symlink_metadata_path
-            )
-
     def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
         from lightning.pytorch.trainer.states import TrainerFn

nshtrainer/model/config.py CHANGED
@@ -1012,8 +1012,6 @@ class TrainerConfig(C.Config):
     """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
     save_checkpoint_metadata: bool = True
     """If enabled, will save additional metadata whenever a checkpoint is saved."""
-    use_checkpoint_cache: bool = False
-    """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""

     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
nshtrainer/trainer/trainer.py CHANGED
@@ -1,7 +1,5 @@
 import logging
 import os
-import shutil
-from collections import defaultdict
 from collections.abc import Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
@@ -280,12 +278,6 @@ class Trainer(LightningTrainer):
     if TYPE_CHECKING:
         callbacks: list[Callback]

-    def _nshtrainer_checkpoint_cache_get(self, key: tuple[int, int]):
-        return next(
-            (ckpt for ckpt in self._nshtrainer_checkpoint_cache[key] if ckpt.exists()),
-            None,
-        )
-
     @override
     def __init__(
         self,
@@ -293,10 +285,6 @@ class Trainer(LightningTrainer):
         /,
         **kwargs: Unpack[LightningTrainerKwargs],
     ):
-        self._nshtrainer_checkpoint_cache = defaultdict[tuple[int, int], list[Path]](
-            lambda: []
-        )
-
         self._pre_init(config)

         kwargs = self._update_kwargs(config, kwargs)
@@ -419,50 +407,24 @@ class Trainer(LightningTrainer):

         return super()._run(model, ckpt_path)

-    def _nshtrainer_save_checkpoint(
+    @override
+    def save_checkpoint(
         self,
         filepath: str | Path,
         weights_only: bool = False,
         storage_options: Any | None = None,
-        use_checkpoint_cache: bool | None = None,
     ):
-        lm = self._base_module
-        root_config = cast(BaseConfig, lm.hparams)
-        if use_checkpoint_cache is None:
-            use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
-
         filepath = Path(filepath)

         # List of files that we should upload to HF
         written_files: list[Path] = [filepath]

-        cached_path = None
-        if (
-            use_checkpoint_cache
-            and (
-                cached_path := self._nshtrainer_checkpoint_cache_get(
-                    (self.current_epoch, self.global_step)
-                )
-            )
-            is not None
-        ):
-            # If we have a cached path, then we symlink it to the new path.
-            log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
-            if self.is_global_zero:
-                shutil.copy(cached_path, filepath)
-            self.strategy.barrier("Trainer.save_checkpoint")
-        else:
-            super().save_checkpoint(filepath, weights_only, storage_options)
-
-        # If we are using the cache but we don't have a cached path, then we save the checkpoint to the cache.
-        if use_checkpoint_cache and cached_path is None:
-            self._nshtrainer_checkpoint_cache[
-                (self.current_epoch, self.global_step)
-            ].append(filepath)
-            log.debug(f"Checkpoint saved to cache: {filepath}")
+        super().save_checkpoint(filepath, weights_only, storage_options)

         # Save the checkpoint metadata
         metadata_path = None
+        lm = self._base_module
+        root_config = cast(BaseConfig, lm.hparams)
         if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
             if (
@@ -474,17 +436,3 @@ class Trainer(LightningTrainer):
         from .. import _callback

         _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
-
-    @override
-    def save_checkpoint(
-        self,
-        filepath: str | Path,
-        weights_only: bool = False,
-        storage_options: Any | None = None,
-    ):
-        return self._nshtrainer_save_checkpoint(
-            filepath=filepath,
-            weights_only=weights_only,
-            storage_options=storage_options,
-            use_checkpoint_cache=False,
-        )
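With the checkpoint cache removed, the `_nshtrainer_save_checkpoint` indirection collapses into a plain `@override` of Lightning's `Trainer.save_checkpoint`, keeping the post-save metadata and callback logic on every save path. A reduced sketch of the shape of this override; the `_write_metadata` helper is a stand-in for nshtrainer's metadata/callback logic, not its real API:

```python
from pathlib import Path
from typing import Any

from lightning.pytorch import Trainer as LightningTrainer
from typing_extensions import override


class Trainer(LightningTrainer):
    @override
    def save_checkpoint(
        self,
        filepath: str | Path,
        weights_only: bool = False,
        storage_options: Any | None = None,
    ):
        # Delegate the actual checkpoint write to Lightning.
        super().save_checkpoint(filepath, weights_only, storage_options)

        # Then run any post-save work (metadata, callbacks) once, on rank zero.
        if self.is_global_zero:
            self._write_metadata(Path(filepath))

    def _write_metadata(self, filepath: Path) -> None:
        ...  # illustrative placeholder
```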
nshtrainer/util/_environment_info.py CHANGED
@@ -434,7 +434,7 @@ class EnvironmentPackageConfig(C.Config):
                    requires=requires,
                )
            except Exception:
-                log.exception(f"Error processing package {dist.name}")
+                log.warning(f"Error processing package {dist.name}", exc_info=True)

        except ImportError:
            log.warning(
@@ -673,7 +673,7 @@ class GitRepositoryConfig(C.Config):
        except git.InvalidGitRepositoryError:
            draft.is_git_repo = False
        except Exception:
-            log.exception("Failed to get Git repository information")
+            log.warning("Failed to get Git repository information", exc_info=True)
            draft.is_git_repo = None

        return draft.finalize()
nshtrainer/util/path.py CHANGED
@@ -97,7 +97,10 @@ def try_symlink_or_copy(
            symlink_target, target_is_directory=target_is_directory
        )
    except Exception:
-        log.exception(f"Failed to create symlink or copy {file_path} to {link_path}")
+        log.warning(
+            f"Failed to create symlink or copy {file_path} to {link_path}",
+            exc_info=True,
+        )
        return False
    else:
        log.debug(f"Created symlink or copied {file_path} to {link_path}")
{nshtrainer-0.27.0.dist-info → nshtrainer-0.29.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.27.0
+Version: 0.29.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.27.0.dist-info → nshtrainer-0.29.0.dist-info}/RECORD RENAMED
@@ -1,16 +1,16 @@
 nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
 nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
 nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
-nshtrainer/_checkpoint/metadata.py,sha256=hxZwwsUKVbBtt4wjqcKZbObx0PuO-qCdF3BTdnyqaQo,4711
-nshtrainer/_checkpoint/saver.py,sha256=1loCDYDy_Cay37uKs_wvxnkwvr41WMmga85qefct80Q,1271
+nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
+nshtrainer/_checkpoint/saver.py,sha256=MbX_WjkDtHHAf9Ms-KXDlknkjiPXVoGIe2ciO28AdZ0,1264
 nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
-nshtrainer/_hf_hub.py,sha256=v1DV3Vn6Pdwr2KYI7yL_Xv1dyJLaG7yZhKFWZFy3QFk,13087
+nshtrainer/_hf_hub.py,sha256=0bkXkqhve5D1onMW-fCfuvVKlTn0i6jv_6uMNgZ7OHQ,12974
 nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
 nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
 nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
-nshtrainer/callbacks/checkpoint/_base.py,sha256=iDtQ8urdSfl0uUJMvsSvvYo-fsL3ZVXCv79FGS67Jx4,6666
+nshtrainer/callbacks/checkpoint/_base.py,sha256=MzMF7JtvR3A_7DAM2r4NGQSBDisA7krv6WlVk5rKABQ,6157
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -58,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
 nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
 nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
 nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
-nshtrainer/model/config.py,sha256=9fPNrMkpOUIWwAM1UTQQRTVnuV7m3CDTtivya2GCyQY,43487
+nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
 nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
 nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -78,15 +78,15 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
 nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
 nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
 nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
-nshtrainer/trainer/trainer.py,sha256=Bm-kHOF_9W905SpOQcfEwHSUhF6KW8aAFdWNyIotyeM,19450
-nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
+nshtrainer/trainer/trainer.py,sha256=L4nYXq6Gts2sS9CQGenwEcvMET4L5vO5c60KM5Hm8Do,17544
+nshtrainer/util/_environment_info.py,sha256=CFUUZYjXhBLWGc0jtPNOaZgYMueUDEHpEaWFA1f3GoY,24213
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
 nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
-nshtrainer/util/path.py,sha256=jAEjF1qp8Aii32L5lWG4UFgVyQAFkHOMYEc_TC2hDx8,2947
+nshtrainer/util/path.py,sha256=VkpuhR4GaZtSFBVqbGAvfjcrU-PR8xwiGzzwFNOWP9c,2995
 nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.27.0.dist-info/METADATA,sha256=fn2L1iuf7uaxT8UYzZd1lhKSKGEldmA0_026HG5f6ek,916
-nshtrainer-0.27.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.27.0.dist-info/RECORD,,
+nshtrainer-0.29.0.dist-info/METADATA,sha256=EP3cdORGt4w_H0pX-whQJ5ULsO5HQXo3VlHp5bkfqfk,916
+nshtrainer-0.29.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.29.0.dist-info/RECORD,,
+ nshtrainer-0.29.0.dist-info/RECORD,,