nshtrainer 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/loader.py +7 -5
- nshtrainer/_checkpoint/metadata.py +7 -11
- nshtrainer/callbacks/checkpoint/_base.py +1 -0
- nshtrainer/ll/snapshot.py +1 -1
- nshtrainer/trainer/trainer.py +2 -1
- {nshtrainer-0.15.1.dist-info → nshtrainer-0.16.1.dist-info}/METADATA +1 -1
- {nshtrainer-0.15.1.dist-info → nshtrainer-0.16.1.dist-info}/RECORD +8 -8
- {nshtrainer-0.15.1.dist-info → nshtrainer-0.16.1.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/loader.py
CHANGED
|
@@ -10,7 +10,7 @@ from lightning.pytorch.trainer.states import TrainerFn
|
|
|
10
10
|
from typing_extensions import assert_never
|
|
11
11
|
|
|
12
12
|
from ..metrics._config import MetricConfig
|
|
13
|
-
from .metadata import
|
|
13
|
+
from .metadata import CheckpointMetadata
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
16
|
from ..model.config import BaseConfig
|
|
@@ -263,13 +263,13 @@ def _checkpoint_candidates(
|
|
|
263
263
|
|
|
264
264
|
# Load all checkpoints in the directory.
|
|
265
265
|
# We can do this by looking for metadata files.
|
|
266
|
-
for path in ckpt_dir.glob(f"*{
|
|
266
|
+
for path in ckpt_dir.glob(f"*{CheckpointMetadata.PATH_SUFFIX}"):
|
|
267
267
|
if (meta := _load_ckpt_meta(path, root_config)) is not None:
|
|
268
268
|
yield meta
|
|
269
269
|
|
|
270
270
|
# If we have a pre-empted checkpoint, load it
|
|
271
271
|
if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
|
|
272
|
-
hpc_meta_path = Path(hpc_path).with_suffix(
|
|
272
|
+
hpc_meta_path = Path(hpc_path).with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
273
273
|
if (meta := _load_ckpt_meta(hpc_meta_path, root_config)) is not None:
|
|
274
274
|
yield meta
|
|
275
275
|
|
|
@@ -279,7 +279,9 @@ def _additional_candidates(
|
|
|
279
279
|
):
|
|
280
280
|
for path in additional_candidates:
|
|
281
281
|
if (
|
|
282
|
-
meta := _load_ckpt_meta(
|
|
282
|
+
meta := _load_ckpt_meta(
|
|
283
|
+
path.with_suffix(CheckpointMetadata.PATH_SUFFIX), root_config
|
|
284
|
+
)
|
|
283
285
|
) is None:
|
|
284
286
|
continue
|
|
285
287
|
yield meta
|
|
@@ -310,7 +312,7 @@ def _resolve_checkpoint(
|
|
|
310
312
|
match strategy:
|
|
311
313
|
case UserProvidedPathCheckpointStrategyConfig():
|
|
312
314
|
meta = _load_ckpt_meta(
|
|
313
|
-
strategy.path.with_suffix(
|
|
315
|
+
strategy.path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
|
|
314
316
|
root_config,
|
|
315
317
|
on_error=strategy.on_error,
|
|
316
318
|
)
|
|
@@ -10,8 +10,6 @@ import nshconfig as C
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
import torch
|
|
12
12
|
|
|
13
|
-
from ..util._environment_info import EnvironmentConfig
|
|
14
|
-
|
|
15
13
|
if TYPE_CHECKING:
|
|
16
14
|
from ..model import BaseConfig, LightningModuleBase
|
|
17
15
|
from ..trainer.trainer import Trainer
|
|
@@ -38,7 +36,7 @@ class CheckpointMetadata(C.Config):
|
|
|
38
36
|
global_step: int
|
|
39
37
|
training_time: datetime.timedelta
|
|
40
38
|
metrics: dict[str, Any]
|
|
41
|
-
environment:
|
|
39
|
+
environment: dict[str, Any]
|
|
42
40
|
|
|
43
41
|
hparams: dict[str, Any] | None
|
|
44
42
|
|
|
@@ -48,9 +46,7 @@ class CheckpointMetadata(C.Config):
|
|
|
48
46
|
|
|
49
47
|
@classmethod
|
|
50
48
|
def from_ckpt_path(cls, checkpoint_path: Path):
|
|
51
|
-
if not (
|
|
52
|
-
metadata_path := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
|
|
53
|
-
).exists():
|
|
49
|
+
if not (metadata_path := checkpoint_path.with_suffix(cls.PATH_SUFFIX)).exists():
|
|
54
50
|
raise FileNotFoundError(
|
|
55
51
|
f"Metadata file not found for checkpoint: {checkpoint_path}"
|
|
56
52
|
)
|
|
@@ -93,7 +89,7 @@ def _generate_checkpoint_metadata(
|
|
|
93
89
|
global_step=trainer.global_step,
|
|
94
90
|
training_time=training_time,
|
|
95
91
|
metrics=metrics,
|
|
96
|
-
environment=config.environment,
|
|
92
|
+
environment=config.environment.model_dump(mode="json"),
|
|
97
93
|
hparams=config.model_dump(mode="json"),
|
|
98
94
|
)
|
|
99
95
|
|
|
@@ -104,7 +100,7 @@ def _write_checkpoint_metadata(
|
|
|
104
100
|
checkpoint_path: Path,
|
|
105
101
|
):
|
|
106
102
|
config = cast("BaseConfig", model.config)
|
|
107
|
-
metadata_path = checkpoint_path.with_suffix(
|
|
103
|
+
metadata_path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
108
104
|
metadata = _generate_checkpoint_metadata(
|
|
109
105
|
config, trainer, checkpoint_path, metadata_path
|
|
110
106
|
)
|
|
@@ -119,7 +115,7 @@ def _write_checkpoint_metadata(
|
|
|
119
115
|
|
|
120
116
|
|
|
121
117
|
def _remove_checkpoint_metadata(checkpoint_path: Path):
|
|
122
|
-
path = checkpoint_path.with_suffix(
|
|
118
|
+
path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
123
119
|
try:
|
|
124
120
|
path.unlink(missing_ok=True)
|
|
125
121
|
except Exception as e:
|
|
@@ -133,8 +129,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
|
|
|
133
129
|
_remove_checkpoint_metadata(linked_checkpoint_path)
|
|
134
130
|
|
|
135
131
|
# Link the metadata files to the new checkpoint
|
|
136
|
-
path = checkpoint_path.with_suffix(
|
|
137
|
-
linked_path = linked_checkpoint_path.with_suffix(
|
|
132
|
+
path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
133
|
+
linked_path = linked_checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
138
134
|
try:
|
|
139
135
|
try:
|
|
140
136
|
# linked_path.symlink_to(path)
|
|
@@ -102,6 +102,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
102
102
|
metas = [
|
|
103
103
|
CheckpointMetadata.from_file(p)
|
|
104
104
|
for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
|
|
105
|
+
if p.is_file() and not p.is_symlink()
|
|
105
106
|
]
|
|
106
107
|
|
|
107
108
|
# Sort by the topk sort key
|
nshtrainer/ll/snapshot.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
from
|
|
1
|
+
from nshsnap import * # pyright: ignore[reportWildcardImportFromLibrary] # noqa: F403
|
nshtrainer/trainer/trainer.py
CHANGED
|
@@ -419,7 +419,8 @@ class Trainer(LightningTrainer):
|
|
|
419
419
|
|
|
420
420
|
# Save the checkpoint metadata
|
|
421
421
|
lm = self._base_module
|
|
422
|
-
|
|
422
|
+
hparams = cast(BaseConfig, lm.hparams)
|
|
423
|
+
if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
|
|
423
424
|
# Generate the metadata and write to disk
|
|
424
425
|
_write_checkpoint_metadata(self, lm, filepath)
|
|
425
426
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
|
|
2
|
-
nshtrainer/_checkpoint/loader.py,sha256=
|
|
3
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
|
2
|
+
nshtrainer/_checkpoint/loader.py,sha256=myFObRsPdb8jBncMK73vjr5FDJIfKhF86Ec_kSjXtwg,13837
|
|
3
|
+
nshtrainer/_checkpoint/metadata.py,sha256=z3gi_J_YDNNtBw1OZ08LBECkcoc9rIydnTKvFOOoG4c,5131
|
|
4
4
|
nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
|
|
5
5
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
6
6
|
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
@@ -8,7 +8,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
|
|
|
8
8
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
9
9
|
nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
|
|
10
10
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
11
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
|
11
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=YT_V-oihO9iB4ETl46CGYTCQjIYl-CpV7TMViTn07Lk,6144
|
|
12
12
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
|
|
13
13
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
|
|
14
14
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
@@ -38,7 +38,7 @@ nshtrainer/ll/model.py,sha256=cxFQfFc-2mAYBGwDpP8m5tjQBs7M47cZ6JoPXksPaoI,473
|
|
|
38
38
|
nshtrainer/ll/nn.py,sha256=8qiRDFwojIxkB7-LtNWk4mLL2tJbaskHYofDsOIHiNg,42
|
|
39
39
|
nshtrainer/ll/optimizer.py,sha256=3T-VZtT73jVvwCNJGDjgGEbzs-1LFTzMQH-SB_58mSo,49
|
|
40
40
|
nshtrainer/ll/runner.py,sha256=B0m5VEhNKIjF1aFmqPkonkQxDoRL2jeHZGsV3zwhSVE,117
|
|
41
|
-
nshtrainer/ll/snapshot.py,sha256=
|
|
41
|
+
nshtrainer/ll/snapshot.py,sha256=EdzvbE6IyHcNpIvkF1lWj3pBigfRF9jQXgOlbgBavYs,87
|
|
42
42
|
nshtrainer/ll/snoop.py,sha256=hG9VCdm8mIZytHLZgUKRSoWqg55rVvUpBAH-OSiwgvI,36
|
|
43
43
|
nshtrainer/ll/trainer.py,sha256=hkn2xPtrSPQ7LqQhbyAKuMfNyHdhqB9bDPvgRCK8oJM,47
|
|
44
44
|
nshtrainer/ll/typecheck.py,sha256=ryV1Tzcf7hJ4I19H1oQVkikU9spmRk8jyIKQZ5UF7pQ,62
|
|
@@ -76,7 +76,7 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
|
|
|
76
76
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
77
77
|
nshtrainer/trainer/checkpoint_connector.py,sha256=F2tkHogbMAa5U7335sm77sZBkjEDa5v46XbJCH9Mg6c,2167
|
|
78
78
|
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
|
|
79
|
-
nshtrainer/trainer/trainer.py,sha256=
|
|
79
|
+
nshtrainer/trainer/trainer.py,sha256=jIqiNrq1I0f5pQP7lHshtgjCAYfpoWPoqwS74LHU9iM,17148
|
|
80
80
|
nshtrainer/util/_environment_info.py,sha256=Nmhls6u5rtMWbeDLGCjEk58efUutc_ONSUg3fs59TSI,24210
|
|
81
81
|
nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
|
|
82
82
|
nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
|
|
@@ -84,6 +84,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
84
84
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
85
85
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
86
86
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
87
|
-
nshtrainer-0.
|
|
88
|
-
nshtrainer-0.
|
|
89
|
-
nshtrainer-0.
|
|
87
|
+
nshtrainer-0.16.1.dist-info/METADATA,sha256=SQ2uzBMs4f-pYgabVgI9EErqge_LtQ9-RTz2_mTskDU,885
|
|
88
|
+
nshtrainer-0.16.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
89
|
+
nshtrainer-0.16.1.dist-info/RECORD,,
|
|
File without changes
|