nshtrainer 0.26.1__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_hf_hub.py CHANGED
@@ -188,27 +188,19 @@ class HFHubCallback(NTCallbackBase):
188
188
 
189
189
  self.config = config
190
190
 
191
- self._repo_id = None
191
+ self._repo_id: str | None = None
192
192
  self._checksum_to_path_in_repo: dict[str, Path] = {}
193
193
 
194
194
  @override
195
195
  def setup(self, trainer, pl_module, stage):
196
- from .trainer.trainer import Trainer
197
-
198
- if not isinstance(trainer, Trainer):
199
- raise ValueError(
200
- f"HFHubCallback requires a `nshtrainer.Trainer` instance, got {type(trainer)}."
201
- )
202
-
203
196
  root_config = cast("BaseConfig", pl_module.hparams)
197
+ self._repo_id = _repo_name(self.api, root_config)
198
+
199
+ if not self.config or not trainer.is_global_zero:
200
+ return
204
201
 
205
202
  # Create the repository, if it doesn't exist
206
- self._repo_id = self.api.create_repo(
207
- repo_id=_repo_name(self.api, root_config),
208
- repo_type="model",
209
- private=self.config.auto_create.private,
210
- exist_ok=True,
211
- )
203
+ self._create_repo_if_not_exists()
212
204
 
213
205
  # Upload the config and code
214
206
  self._save_config(root_config)
@@ -216,17 +208,22 @@ class HFHubCallback(NTCallbackBase):
216
208
 
217
209
  @override
218
210
  def on_checkpoint_saved(self, ckpt_path, metadata_path, trainer, pl_module):
219
- root_config = cast("BaseConfig", pl_module.hparams)
220
-
221
211
  # If HF Hub is enabled, then we upload
222
- if self.config and trainer.is_global_zero:
223
- with self._with_error_handling("save checkpoints"):
224
- self._save_checkpoint(
225
- _Upload.from_local_path(ckpt_path, root_config),
226
- _Upload.from_local_path(metadata_path, root_config)
227
- if metadata_path is not None
228
- else None,
229
- )
212
+ if (
213
+ not self.config
214
+ or not self.config.save_checkpoints
215
+ or not trainer.is_global_zero
216
+ ):
217
+ return
218
+
219
+ with self._with_error_handling("save checkpoints"):
220
+ root_config = cast("BaseConfig", pl_module.hparams)
221
+ self._save_checkpoint(
222
+ _Upload.from_local_path(ckpt_path, root_config),
223
+ _Upload.from_local_path(metadata_path, root_config)
224
+ if metadata_path is not None
225
+ else None,
226
+ )
230
227
 
231
228
  @cached_property
232
229
  def api(self):
@@ -241,6 +238,33 @@ class HFHubCallback(NTCallbackBase):
241
238
  raise ValueError("Repository id has not been initialized.")
242
239
  return self._repo_id
243
240
 
241
+ def _create_repo_if_not_exists(self):
242
+ if not self.config or not self.config.auto_create:
243
+ return
244
+
245
+ # Create the repository, if it doesn't exist
246
+ with self._with_error_handling("create repository"):
247
+ from huggingface_hub.utils import RepositoryNotFoundError
248
+
249
+ try:
250
+ # Check if the repository exists
251
+ self.api.repo_info(repo_id=self.repo_id, repo_type="model")
252
+ log.info(f"Repository '{self.repo_id}' already exists.")
253
+ except RepositoryNotFoundError:
254
+ # Repository doesn't exist, so create it
255
+ try:
256
+ self.api.create_repo(
257
+ repo_id=self.repo_id,
258
+ repo_type="model",
259
+ private=self.config.auto_create.private,
260
+ exist_ok=True,
261
+ )
262
+ log.info(f"Created new repository '{self.repo_id}'.")
263
+ except Exception:
264
+ log.exception(f"Failed to create repository '{self.repo_id}'")
265
+ except Exception:
266
+ log.exception(f"Error checking repository '{self.repo_id}'")
267
+
244
268
  def _save_config(self, root_config: "BaseConfig"):
245
269
  with self._with_error_handling("upload config"):
246
270
  self.api.upload_file(
@@ -153,10 +153,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
153
153
  # Save the new checkpoint
154
154
  filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
155
155
  trainer._nshtrainer_save_checkpoint(
156
- filepath,
157
- self.config.save_weights_only,
158
- use_checkpoint_cache=None,
159
- ckpt_cache_use_symlink=False,
156
+ filepath, self.config.save_weights_only, use_checkpoint_cache=None
160
157
  )
161
158
 
162
159
  if trainer.is_global_zero:
@@ -1012,7 +1012,7 @@ class TrainerConfig(C.Config):
1012
1012
  """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
1013
1013
  save_checkpoint_metadata: bool = True
1014
1014
  """If enabled, will save additional metadata whenever a checkpoint is saved."""
1015
- use_checkpoint_cache: bool = True
1015
+ use_checkpoint_cache: bool = False
1016
1016
  """If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""
1017
1017
 
1018
1018
  lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
@@ -1,6 +1,7 @@
1
1
  import logging
2
2
  import os
3
3
  import shutil
4
+ from collections import defaultdict
4
5
  from collections.abc import Sequence
5
6
  from pathlib import Path
6
7
  from typing import TYPE_CHECKING, Any, cast
@@ -18,7 +19,6 @@ from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
18
19
  from typing_extensions import Unpack, assert_never, override
19
20
 
20
21
  from .._checkpoint.metadata import _write_checkpoint_metadata
21
- from .._checkpoint.saver import _link_checkpoint
22
22
  from ..callbacks.base import resolve_all_callbacks
23
23
  from ..model.config import (
24
24
  AcceleratorConfigProtocol,
@@ -280,6 +280,12 @@ class Trainer(LightningTrainer):
280
280
  if TYPE_CHECKING:
281
281
  callbacks: list[Callback]
282
282
 
283
+ def _nshtrainer_checkpoint_cache_get(self, key: tuple[int, int]):
284
+ return next(
285
+ (ckpt for ckpt in self._nshtrainer_checkpoint_cache[key] if ckpt.exists()),
286
+ None,
287
+ )
288
+
283
289
  @override
284
290
  def __init__(
285
291
  self,
@@ -287,7 +293,9 @@ class Trainer(LightningTrainer):
287
293
  /,
288
294
  **kwargs: Unpack[LightningTrainerKwargs],
289
295
  ):
290
- self._nshtrainer_checkpoint_cache: dict[tuple[int, int], Path] = {}
296
+ self._nshtrainer_checkpoint_cache = defaultdict[tuple[int, int], list[Path]](
297
+ lambda: []
298
+ )
291
299
 
292
300
  self._pre_init(config)
293
301
 
@@ -417,7 +425,6 @@ class Trainer(LightningTrainer):
417
425
  weights_only: bool = False,
418
426
  storage_options: Any | None = None,
419
427
  use_checkpoint_cache: bool | None = None,
420
- ckpt_cache_use_symlink: bool = False,
421
428
  ):
422
429
  lm = self._base_module
423
430
  root_config = cast(BaseConfig, lm.hparams)
@@ -433,7 +440,7 @@ class Trainer(LightningTrainer):
433
440
  if (
434
441
  use_checkpoint_cache
435
442
  and (
436
- cached_path := self._nshtrainer_checkpoint_cache.get(
443
+ cached_path := self._nshtrainer_checkpoint_cache_get(
437
444
  (self.current_epoch, self.global_step)
438
445
  )
439
446
  )
@@ -442,10 +449,7 @@ class Trainer(LightningTrainer):
442
449
  # If we have a cached path, then we symlink it to the new path.
443
450
  log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
444
451
  if self.is_global_zero:
445
- if ckpt_cache_use_symlink:
446
- _link_checkpoint(cached_path, filepath, metadata=False)
447
- else:
448
- shutil.copy(cached_path, filepath)
452
+ shutil.copy(cached_path, filepath)
449
453
  self.strategy.barrier("Trainer.save_checkpoint")
450
454
  else:
451
455
  super().save_checkpoint(filepath, weights_only, storage_options)
@@ -454,7 +458,7 @@ class Trainer(LightningTrainer):
454
458
  if use_checkpoint_cache and cached_path is None:
455
459
  self._nshtrainer_checkpoint_cache[
456
460
  (self.current_epoch, self.global_step)
457
- ] = filepath
461
+ ].append(filepath)
458
462
  log.debug(f"Checkpoint saved to cache: {filepath}")
459
463
 
460
464
  # Save the checkpoint metadata
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.26.1
3
+ Version: 0.27.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -4,13 +4,13 @@ nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uP
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=hxZwwsUKVbBtt4wjqcKZbObx0PuO-qCdF3BTdnyqaQo,4711
5
5
  nshtrainer/_checkpoint/saver.py,sha256=1loCDYDy_Cay37uKs_wvxnkwvr41WMmga85qefct80Q,1271
6
6
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
7
- nshtrainer/_hf_hub.py,sha256=0K3uWa8hd2KyGuUYM7OXARcA7vuUiWWGSlP2USysY7o,12066
7
+ nshtrainer/_hf_hub.py,sha256=v1DV3Vn6Pdwr2KYI7yL_Xv1dyJLaG7yZhKFWZFy3QFk,13087
8
8
  nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
9
9
  nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
10
10
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
11
11
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
13
- nshtrainer/callbacks/checkpoint/_base.py,sha256=pLhLj7tvfY5czGY_vT0xRfWHzGJYC4iOBRLokFVq0mE,6733
13
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=iDtQ8urdSfl0uUJMvsSvvYo-fsL3ZVXCv79FGS67Jx4,6666
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
15
15
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
@@ -58,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
58
58
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
59
59
  nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
60
60
  nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
61
- nshtrainer/model/config.py,sha256=22_xIcdEO2pJzXgrFaqGFtk3PQEiwKiMZY1cjhoyWaA,43486
61
+ nshtrainer/model/config.py,sha256=9fPNrMkpOUIWwAM1UTQQRTVnuV7m3CDTtivya2GCyQY,43487
62
62
  nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
63
63
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
64
64
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -78,7 +78,7 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
78
78
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
79
79
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
80
80
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
81
- nshtrainer/trainer/trainer.py,sha256=Zwdcqfmrr7yuonsp4VrNOget8wkaZY9lf-_yeJ94lkk,19397
81
+ nshtrainer/trainer/trainer.py,sha256=Bm-kHOF_9W905SpOQcfEwHSUhF6KW8aAFdWNyIotyeM,19450
82
82
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
83
83
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
84
84
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
@@ -87,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
87
87
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
88
88
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
89
89
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
90
- nshtrainer-0.26.1.dist-info/METADATA,sha256=tMMpyg1BTKec5d69ziW6XBxDXaI0gSK5tDMPCmj7VCQ,916
91
- nshtrainer-0.26.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
- nshtrainer-0.26.1.dist-info/RECORD,,
90
+ nshtrainer-0.27.0.dist-info/METADATA,sha256=fn2L1iuf7uaxT8UYzZd1lhKSKGEldmA0_026HG5f6ek,916
91
+ nshtrainer-0.27.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
92
+ nshtrainer-0.27.0.dist-info/RECORD,,