nshtrainer 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/loader.py +7 -5
- nshtrainer/_checkpoint/metadata.py +14 -16
- nshtrainer/callbacks/checkpoint/_base.py +1 -0
- nshtrainer/trainer/trainer.py +2 -1
- nshtrainer/util/_environment_info.py +6 -6
- {nshtrainer-0.16.0.dist-info → nshtrainer-0.17.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.16.0.dist-info → nshtrainer-0.17.0.dist-info}/RECORD +8 -8
- {nshtrainer-0.16.0.dist-info → nshtrainer-0.17.0.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/loader.py
CHANGED
|
@@ -10,7 +10,7 @@ from lightning.pytorch.trainer.states import TrainerFn
|
|
|
10
10
|
from typing_extensions import assert_never
|
|
11
11
|
|
|
12
12
|
from ..metrics._config import MetricConfig
|
|
13
|
-
from .metadata import
|
|
13
|
+
from .metadata import CheckpointMetadata
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
16
|
from ..model.config import BaseConfig
|
|
@@ -263,13 +263,13 @@ def _checkpoint_candidates(
|
|
|
263
263
|
|
|
264
264
|
# Load all checkpoints in the directory.
|
|
265
265
|
# We can do this by looking for metadata files.
|
|
266
|
-
for path in ckpt_dir.glob(f"*{
|
|
266
|
+
for path in ckpt_dir.glob(f"*{CheckpointMetadata.PATH_SUFFIX}"):
|
|
267
267
|
if (meta := _load_ckpt_meta(path, root_config)) is not None:
|
|
268
268
|
yield meta
|
|
269
269
|
|
|
270
270
|
# If we have a pre-empted checkpoint, load it
|
|
271
271
|
if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
|
|
272
|
-
hpc_meta_path = Path(hpc_path).with_suffix(
|
|
272
|
+
hpc_meta_path = Path(hpc_path).with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
273
273
|
if (meta := _load_ckpt_meta(hpc_meta_path, root_config)) is not None:
|
|
274
274
|
yield meta
|
|
275
275
|
|
|
@@ -279,7 +279,9 @@ def _additional_candidates(
|
|
|
279
279
|
):
|
|
280
280
|
for path in additional_candidates:
|
|
281
281
|
if (
|
|
282
|
-
meta := _load_ckpt_meta(
|
|
282
|
+
meta := _load_ckpt_meta(
|
|
283
|
+
path.with_suffix(CheckpointMetadata.PATH_SUFFIX), root_config
|
|
284
|
+
)
|
|
283
285
|
) is None:
|
|
284
286
|
continue
|
|
285
287
|
yield meta
|
|
@@ -310,7 +312,7 @@ def _resolve_checkpoint(
|
|
|
310
312
|
match strategy:
|
|
311
313
|
case UserProvidedPathCheckpointStrategyConfig():
|
|
312
314
|
meta = _load_ckpt_meta(
|
|
313
|
-
strategy.path.with_suffix(
|
|
315
|
+
strategy.path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
|
|
314
316
|
root_config,
|
|
315
317
|
on_error=strategy.on_error,
|
|
316
318
|
)
|
|
nshtrainer/_checkpoint/metadata.py
CHANGED
|
@@ -40,7 +40,7 @@ class CheckpointMetadata(C.Config):
|
|
|
40
40
|
metrics: dict[str, Any]
|
|
41
41
|
environment: EnvironmentConfig
|
|
42
42
|
|
|
43
|
-
hparams:
|
|
43
|
+
hparams: Any
|
|
44
44
|
|
|
45
45
|
@classmethod
|
|
46
46
|
def from_file(cls, path: Path):
|
|
@@ -48,9 +48,7 @@ class CheckpointMetadata(C.Config):
|
|
|
48
48
|
|
|
49
49
|
@classmethod
|
|
50
50
|
def from_ckpt_path(cls, checkpoint_path: Path):
|
|
51
|
-
if not (
|
|
52
|
-
metadata_path := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
|
|
53
|
-
).exists():
|
|
51
|
+
if not (metadata_path := checkpoint_path.with_suffix(cls.PATH_SUFFIX)).exists():
|
|
54
52
|
raise FileNotFoundError(
|
|
55
53
|
f"Metadata file not found for checkpoint: {checkpoint_path}"
|
|
56
54
|
)
|
|
@@ -94,7 +92,7 @@ def _generate_checkpoint_metadata(
|
|
|
94
92
|
training_time=training_time,
|
|
95
93
|
metrics=metrics,
|
|
96
94
|
environment=config.environment,
|
|
97
|
-
hparams=config.model_dump(
|
|
95
|
+
hparams=config.model_dump(),
|
|
98
96
|
)
|
|
99
97
|
|
|
100
98
|
|
|
@@ -104,26 +102,26 @@ def _write_checkpoint_metadata(
|
|
|
104
102
|
checkpoint_path: Path,
|
|
105
103
|
):
|
|
106
104
|
config = cast("BaseConfig", model.config)
|
|
107
|
-
metadata_path = checkpoint_path.with_suffix(
|
|
105
|
+
metadata_path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
108
106
|
metadata = _generate_checkpoint_metadata(
|
|
109
107
|
config, trainer, checkpoint_path, metadata_path
|
|
110
108
|
)
|
|
111
109
|
|
|
112
110
|
# Write the metadata to the checkpoint directory
|
|
113
111
|
try:
|
|
114
|
-
metadata_path.write_text(metadata.model_dump_json(indent=4))
|
|
115
|
-
except Exception
|
|
116
|
-
log.
|
|
112
|
+
metadata_path.write_text(metadata.model_dump_json(indent=4), encoding="utf-8")
|
|
113
|
+
except Exception:
|
|
114
|
+
log.exception(f"Failed to write metadata to {checkpoint_path}")
|
|
117
115
|
else:
|
|
118
116
|
log.debug(f"Checkpoint metadata written to {checkpoint_path}")
|
|
119
117
|
|
|
120
118
|
|
|
121
119
|
def _remove_checkpoint_metadata(checkpoint_path: Path):
|
|
122
|
-
path = checkpoint_path.with_suffix(
|
|
120
|
+
path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
123
121
|
try:
|
|
124
122
|
path.unlink(missing_ok=True)
|
|
125
|
-
except Exception
|
|
126
|
-
log.
|
|
123
|
+
except Exception:
|
|
124
|
+
log.exception(f"Failed to remove {path}")
|
|
127
125
|
else:
|
|
128
126
|
log.debug(f"Removed {path}")
|
|
129
127
|
|
|
@@ -133,8 +131,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
|
|
|
133
131
|
_remove_checkpoint_metadata(linked_checkpoint_path)
|
|
134
132
|
|
|
135
133
|
# Link the metadata files to the new checkpoint
|
|
136
|
-
path = checkpoint_path.with_suffix(
|
|
137
|
-
linked_path = linked_checkpoint_path.with_suffix(
|
|
134
|
+
path = checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
135
|
+
linked_path = linked_checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
|
|
138
136
|
try:
|
|
139
137
|
try:
|
|
140
138
|
# linked_path.symlink_to(path)
|
|
@@ -146,8 +144,8 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
|
|
|
146
144
|
# on Windows, special permissions are required to create symbolic links as a regular user
|
|
147
145
|
# fall back to copying the file
|
|
148
146
|
shutil.copy(path, linked_path)
|
|
149
|
-
except Exception
|
|
150
|
-
log.
|
|
147
|
+
except Exception:
|
|
148
|
+
log.exception(f"Failed to link {path} to {linked_path}")
|
|
151
149
|
else:
|
|
152
150
|
log.debug(f"Linked {path} to {linked_path}")
|
|
153
151
|
|
|
nshtrainer/callbacks/checkpoint/_base.py
CHANGED
|
@@ -102,6 +102,7 @@ class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
|
|
|
102
102
|
metas = [
|
|
103
103
|
CheckpointMetadata.from_file(p)
|
|
104
104
|
for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
|
|
105
|
+
if p.is_file() and not p.is_symlink()
|
|
105
106
|
]
|
|
106
107
|
|
|
107
108
|
# Sort by the topk sort key
|
nshtrainer/trainer/trainer.py
CHANGED
|
@@ -419,7 +419,8 @@ class Trainer(LightningTrainer):
|
|
|
419
419
|
|
|
420
420
|
# Save the checkpoint metadata
|
|
421
421
|
lm = self._base_module
|
|
422
|
-
|
|
422
|
+
hparams = cast(BaseConfig, lm.hparams)
|
|
423
|
+
if hparams.trainer.save_checkpoint_metadata and self.is_global_zero:
|
|
423
424
|
# Generate the metadata and write to disk
|
|
424
425
|
_write_checkpoint_metadata(self, lm, filepath)
|
|
425
426
|
|
|
nshtrainer/util/_environment_info.py
CHANGED
|
@@ -429,12 +429,12 @@ class EnvironmentPackageConfig(C.Config):
|
|
|
429
429
|
version=clean_version,
|
|
430
430
|
path=Path(str(f)) if (f := dist.locate_file("")) else None,
|
|
431
431
|
summary=metadata["Summary"] if "Summary" in metadata else None,
|
|
432
|
-
author=metadata["Author"] if "
|
|
433
|
-
license=metadata["License"] if "
|
|
432
|
+
author=metadata["Author"] if "Author" in metadata else None,
|
|
433
|
+
license=metadata["License"] if "License" in metadata else None,
|
|
434
434
|
requires=requires,
|
|
435
435
|
)
|
|
436
|
-
except Exception
|
|
437
|
-
log.
|
|
436
|
+
except Exception:
|
|
437
|
+
log.exception(f"Error processing package {dist.name}")
|
|
438
438
|
|
|
439
439
|
except ImportError:
|
|
440
440
|
log.warning(
|
|
@@ -672,8 +672,8 @@ class GitRepositoryConfig(C.Config):
|
|
|
672
672
|
draft.is_dirty = repo.is_dirty()
|
|
673
673
|
except git.InvalidGitRepositoryError:
|
|
674
674
|
draft.is_git_repo = False
|
|
675
|
-
except Exception
|
|
676
|
-
log.
|
|
675
|
+
except Exception:
|
|
676
|
+
log.exception("Failed to get Git repository information")
|
|
677
677
|
draft.is_git_repo = None
|
|
678
678
|
|
|
679
679
|
return draft.finalize()
|
|
{nshtrainer-0.16.0.dist-info → nshtrainer-0.17.0.dist-info}/RECORD
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
|
|
2
|
-
nshtrainer/_checkpoint/loader.py,sha256=
|
|
3
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
|
2
|
+
nshtrainer/_checkpoint/loader.py,sha256=myFObRsPdb8jBncMK73vjr5FDJIfKhF86Ec_kSjXtwg,13837
|
|
3
|
+
nshtrainer/_checkpoint/metadata.py,sha256=ilq3Wz9QmILxwpy5qYm7kphhtqPUYAOh8D-OXc5SqSc,5131
|
|
4
4
|
nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
|
|
5
5
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
6
6
|
nshtrainer/callbacks/__init__.py,sha256=4qocBDzQbLLhhbIEfvbA3SQB_Dy9ZJH7keMwPay-ZS8,2359
|
|
@@ -8,7 +8,7 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
|
|
|
8
8
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
9
9
|
nshtrainer/callbacks/base.py,sha256=UnlYZAqSb8UwBJR-N5-XunxFx2yZjZ4lyGqUfhbCRlI,3555
|
|
10
10
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
11
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
|
11
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=YT_V-oihO9iB4ETl46CGYTCQjIYl-CpV7TMViTn07Lk,6144
|
|
12
12
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=DJiLo7NDzd-lp-O3v7Cv8WejyjXPV_6_RmfltKO9fvE,2165
|
|
13
13
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CqB_8Xv32rtpLCaEEPi6DbRZm4ph5TWS-LfqIHXUIUA,1097
|
|
14
14
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
@@ -76,14 +76,14 @@ nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3
|
|
|
76
76
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
77
77
|
nshtrainer/trainer/checkpoint_connector.py,sha256=F2tkHogbMAa5U7335sm77sZBkjEDa5v46XbJCH9Mg6c,2167
|
|
78
78
|
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
|
|
79
|
-
nshtrainer/trainer/trainer.py,sha256=
|
|
80
|
-
nshtrainer/util/_environment_info.py,sha256=
|
|
79
|
+
nshtrainer/trainer/trainer.py,sha256=jIqiNrq1I0f5pQP7lHshtgjCAYfpoWPoqwS74LHU9iM,17148
|
|
80
|
+
nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
|
|
81
81
|
nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
|
|
82
82
|
nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
|
|
83
83
|
nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
84
84
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
85
85
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
86
86
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
87
|
-
nshtrainer-0.
|
|
88
|
-
nshtrainer-0.
|
|
89
|
-
nshtrainer-0.
|
|
87
|
+
nshtrainer-0.17.0.dist-info/METADATA,sha256=6qrwuGiaXKIaFBpYwvuqBjyDhkLX4bNjU1ojJeTLsQE,885
|
|
88
|
+
nshtrainer-0.17.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
89
|
+
nshtrainer-0.17.0.dist-info/RECORD,,
|
|
File without changes
|