nshtrainer 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_hf_hub.py +0 -16
- nshtrainer/callbacks/checkpoint/_base.py +1 -3
- nshtrainer/model/config.py +0 -2
- nshtrainer/trainer/trainer.py +5 -57
- {nshtrainer-0.27.0.dist-info → nshtrainer-0.28.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.27.0.dist-info → nshtrainer-0.28.0.dist-info}/RECORD +7 -7
- {nshtrainer-0.27.0.dist-info → nshtrainer-0.28.0.dist-info}/WHEEL +0 -0
nshtrainer/_hf_hub.py
CHANGED
|
@@ -359,19 +359,3 @@ class HFHubCallback(NTCallbackBase):
|
|
|
359
359
|
# NOTE: This file is fairly small, so we can just upload it directly.
|
|
360
360
|
# No need to copy.
|
|
361
361
|
self._save_file(metadata_path)
|
|
362
|
-
|
|
363
|
-
@override
|
|
364
|
-
def state_dict(self):
|
|
365
|
-
return {
|
|
366
|
-
"repo_id": self._repo_id,
|
|
367
|
-
"checksum_to_path_in_repo": {
|
|
368
|
-
k: str(v) for k, v in self._checksum_to_path_in_repo.items()
|
|
369
|
-
},
|
|
370
|
-
}
|
|
371
|
-
|
|
372
|
-
@override
|
|
373
|
-
def load_state_dict(self, state_dict):
|
|
374
|
-
self._repo_id = state_dict["repo_id"]
|
|
375
|
-
self._checksum_to_path_in_repo = {
|
|
376
|
-
k: Path(v) for k, v in state_dict["checksum_to_path_in_repo"].items()
|
|
377
|
-
}
|
|
nshtrainer/callbacks/checkpoint/_base.py
CHANGED
|
@@ -152,9 +152,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
152
152
|
|
|
153
153
|
# Save the new checkpoint
|
|
154
154
|
filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
|
|
155
|
-
trainer._nshtrainer_save_checkpoint(
|
|
156
|
-
filepath, self.config.save_weights_only, use_checkpoint_cache=None
|
|
157
|
-
)
|
|
155
|
+
trainer.save_checkpoint(filepath, self.config.save_weights_only)
|
|
158
156
|
|
|
159
157
|
if trainer.is_global_zero:
|
|
160
158
|
# Create the latest symlink
|
nshtrainer/model/config.py
CHANGED
|
@@ -1012,8 +1012,6 @@ class TrainerConfig(C.Config):
|
|
|
1012
1012
|
"""If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
|
|
1013
1013
|
save_checkpoint_metadata: bool = True
|
|
1014
1014
|
"""If enabled, will save additional metadata whenever a checkpoint is saved."""
|
|
1015
|
-
use_checkpoint_cache: bool = False
|
|
1016
|
-
"""If enabled, will optimize the saving of duplicate checkpoints by creating symlinks instead of copying the file."""
|
|
1017
1015
|
|
|
1018
1016
|
lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
|
|
1019
1017
|
"""
|
nshtrainer/trainer/trainer.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
-
import shutil
|
|
4
|
-
from collections import defaultdict
|
|
5
3
|
from collections.abc import Sequence
|
|
6
4
|
from pathlib import Path
|
|
7
5
|
from typing import TYPE_CHECKING, Any, cast
|
|
@@ -280,12 +278,6 @@ class Trainer(LightningTrainer):
|
|
|
280
278
|
if TYPE_CHECKING:
|
|
281
279
|
callbacks: list[Callback]
|
|
282
280
|
|
|
283
|
-
def _nshtrainer_checkpoint_cache_get(self, key: tuple[int, int]):
|
|
284
|
-
return next(
|
|
285
|
-
(ckpt for ckpt in self._nshtrainer_checkpoint_cache[key] if ckpt.exists()),
|
|
286
|
-
None,
|
|
287
|
-
)
|
|
288
|
-
|
|
289
281
|
@override
|
|
290
282
|
def __init__(
|
|
291
283
|
self,
|
|
@@ -293,10 +285,6 @@ class Trainer(LightningTrainer):
|
|
|
293
285
|
/,
|
|
294
286
|
**kwargs: Unpack[LightningTrainerKwargs],
|
|
295
287
|
):
|
|
296
|
-
self._nshtrainer_checkpoint_cache = defaultdict[tuple[int, int], list[Path]](
|
|
297
|
-
lambda: []
|
|
298
|
-
)
|
|
299
|
-
|
|
300
288
|
self._pre_init(config)
|
|
301
289
|
|
|
302
290
|
kwargs = self._update_kwargs(config, kwargs)
|
|
@@ -419,50 +407,24 @@ class Trainer(LightningTrainer):
|
|
|
419
407
|
|
|
420
408
|
return super()._run(model, ckpt_path)
|
|
421
409
|
|
|
422
|
-
def _nshtrainer_save_checkpoint(
|
410
|
+
@override
|
|
411
|
+
def save_checkpoint(
|
|
423
412
|
self,
|
|
424
413
|
filepath: str | Path,
|
|
425
414
|
weights_only: bool = False,
|
|
426
415
|
storage_options: Any | None = None,
|
|
427
|
-
use_checkpoint_cache: bool | None = None,
|
|
428
416
|
):
|
|
429
|
-
lm = self._base_module
|
|
430
|
-
root_config = cast(BaseConfig, lm.hparams)
|
|
431
|
-
if use_checkpoint_cache is None:
|
|
432
|
-
use_checkpoint_cache = root_config.trainer.use_checkpoint_cache
|
|
433
|
-
|
|
434
417
|
filepath = Path(filepath)
|
|
435
418
|
|
|
436
419
|
# List of files that we should upload to HF
|
|
437
420
|
written_files: list[Path] = [filepath]
|
|
438
421
|
|
|
439
|
-
|
|
440
|
-
if (
|
|
441
|
-
use_checkpoint_cache
|
|
442
|
-
and (
|
|
443
|
-
cached_path := self._nshtrainer_checkpoint_cache_get(
|
|
444
|
-
(self.current_epoch, self.global_step)
|
|
445
|
-
)
|
|
446
|
-
)
|
|
447
|
-
is not None
|
|
448
|
-
):
|
|
449
|
-
# If we have a cached path, then we symlink it to the new path.
|
|
450
|
-
log.info(f"Re-using cached checkpoint {cached_path} for {filepath}.")
|
|
451
|
-
if self.is_global_zero:
|
|
452
|
-
shutil.copy(cached_path, filepath)
|
|
453
|
-
self.strategy.barrier("Trainer.save_checkpoint")
|
|
454
|
-
else:
|
|
455
|
-
super().save_checkpoint(filepath, weights_only, storage_options)
|
|
456
|
-
|
|
457
|
-
# If we are using the cache but we don't have a cached path, then we save the checkpoint to the cache.
|
|
458
|
-
if use_checkpoint_cache and cached_path is None:
|
|
459
|
-
self._nshtrainer_checkpoint_cache[
|
|
460
|
-
(self.current_epoch, self.global_step)
|
|
461
|
-
].append(filepath)
|
|
462
|
-
log.debug(f"Checkpoint saved to cache: {filepath}")
|
|
422
|
+
super().save_checkpoint(filepath, weights_only, storage_options)
|
|
463
423
|
|
|
464
424
|
# Save the checkpoint metadata
|
|
465
425
|
metadata_path = None
|
|
426
|
+
lm = self._base_module
|
|
427
|
+
root_config = cast(BaseConfig, lm.hparams)
|
|
466
428
|
if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
|
|
467
429
|
# Generate the metadata and write to disk
|
|
468
430
|
if (
|
|
@@ -474,17 +436,3 @@ class Trainer(LightningTrainer):
|
|
|
474
436
|
from .. import _callback
|
|
475
437
|
|
|
476
438
|
_callback._call_on_checkpoint_saved(self, filepath, metadata_path)
|
|
477
|
-
|
|
478
|
-
@override
|
|
479
|
-
def save_checkpoint(
|
|
480
|
-
self,
|
|
481
|
-
filepath: str | Path,
|
|
482
|
-
weights_only: bool = False,
|
|
483
|
-
storage_options: Any | None = None,
|
|
484
|
-
):
|
|
485
|
-
return self._nshtrainer_save_checkpoint(
|
|
486
|
-
filepath=filepath,
|
|
487
|
-
weights_only=weights_only,
|
|
488
|
-
storage_options=storage_options,
|
|
489
|
-
use_checkpoint_cache=False,
|
|
490
|
-
)
|
|
{nshtrainer-0.27.0.dist-info → nshtrainer-0.28.0.dist-info}/RECORD
CHANGED
|
@@ -4,13 +4,13 @@ nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uP
|
|
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=hxZwwsUKVbBtt4wjqcKZbObx0PuO-qCdF3BTdnyqaQo,4711
|
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=1loCDYDy_Cay37uKs_wvxnkwvr41WMmga85qefct80Q,1271
|
|
6
6
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
7
|
-
nshtrainer/_hf_hub.py,sha256=
|
|
7
|
+
nshtrainer/_hf_hub.py,sha256=42cR6viOiUInbBW4R7n9-AMnAn_ovN6YoLU3jD14PtI,12608
|
|
8
8
|
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
9
9
|
nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
|
|
10
10
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
11
11
|
nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
|
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
13
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
|
13
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=dXCfyFyf2TjH9Pnc4FCqFUGrmy_n25MSCn8jdf3L-_8,6605
|
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
|
|
15
15
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
|
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
@@ -58,7 +58,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
|
|
|
58
58
|
nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
|
|
59
59
|
nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
|
|
60
60
|
nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
|
|
61
|
-
nshtrainer/model/config.py,sha256=
|
|
61
|
+
nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
|
|
62
62
|
nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
|
|
63
63
|
nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
|
|
64
64
|
nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
|
|
@@ -78,7 +78,7 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
|
|
|
78
78
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
79
79
|
nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
|
|
80
80
|
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
|
|
81
|
-
nshtrainer/trainer/trainer.py,sha256=
|
|
81
|
+
nshtrainer/trainer/trainer.py,sha256=L4nYXq6Gts2sS9CQGenwEcvMET4L5vO5c60KM5Hm8Do,17544
|
|
82
82
|
nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
|
|
83
83
|
nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
|
|
84
84
|
nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
|
|
@@ -87,6 +87,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
87
87
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
88
88
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
89
89
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
90
|
-
nshtrainer-0.
|
|
91
|
-
nshtrainer-0.27.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
92
|
-
nshtrainer-0.27.0.dist-info/RECORD,,
|
|
90
|
+
nshtrainer-0.28.0.dist-info/METADATA,sha256=-4iVbXySZ9NGbwhqDWfBRdZmMNOEUx5ir1FKGrpdASg,916
|
|
91
|
+
nshtrainer-0.28.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
92
|
+
nshtrainer-0.28.0.dist-info/RECORD,,
|
|
File without changes
|