nshtrainer 0.19.3__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/loader.py +12 -4
- nshtrainer/model/config.py +7 -4
- nshtrainer/trainer/checkpoint_connector.py +8 -2
- {nshtrainer-0.19.3.dist-info → nshtrainer-0.20.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.19.3.dist-info → nshtrainer-0.20.0.dist-info}/RECORD +6 -6
- {nshtrainer-0.19.3.dist-info → nshtrainer-0.20.0.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/loader.py
CHANGED
|
@@ -76,7 +76,11 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
76
76
|
"""Whether to include checkpoints from HPC pre-emption."""
|
|
77
77
|
|
|
78
78
|
@classmethod
|
|
79
|
-
def _auto_train(cls, ckpt: Literal["best", "last"] | str | Path | None):
|
|
79
|
+
def none(cls, include_hpc: bool = False):
|
|
80
|
+
return cls(strategies=[], include_hpc=include_hpc)
|
|
81
|
+
|
|
82
|
+
@classmethod
|
|
83
|
+
def _auto_train(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
|
|
80
84
|
if ckpt is None:
|
|
81
85
|
ckpt = "last"
|
|
82
86
|
match ckpt:
|
|
@@ -90,6 +94,8 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
90
94
|
strategies=[LastCheckpointStrategyConfig()],
|
|
91
95
|
include_hpc=True,
|
|
92
96
|
)
|
|
97
|
+
case "none":
|
|
98
|
+
return cls.none()
|
|
93
99
|
case Path() | str():
|
|
94
100
|
ckpt = Path(ckpt)
|
|
95
101
|
return cls(
|
|
@@ -103,7 +109,7 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
103
109
|
assert_never(ckpt)
|
|
104
110
|
|
|
105
111
|
@classmethod
|
|
106
|
-
def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
|
|
112
|
+
def _auto_eval(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
|
|
107
113
|
if ckpt is None:
|
|
108
114
|
log.warn("No checkpoint specified for evaluation. Defaulting to `last`.")
|
|
109
115
|
ckpt = "last"
|
|
@@ -119,6 +125,8 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
119
125
|
strategies=[LastCheckpointStrategyConfig()],
|
|
120
126
|
include_hpc=False,
|
|
121
127
|
)
|
|
128
|
+
case "none":
|
|
129
|
+
return cls.none(include_hpc=False)
|
|
122
130
|
case Path() | str():
|
|
123
131
|
ckpt = Path(ckpt)
|
|
124
132
|
return cls(
|
|
@@ -131,7 +139,7 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
131
139
|
@classmethod
|
|
132
140
|
def auto(
|
|
133
141
|
cls,
|
|
134
|
-
ckpt: Literal["best", "last"] | str | Path | None,
|
|
142
|
+
ckpt: Literal["best", "last", "none"] | str | Path | None,
|
|
135
143
|
trainer_mode: TrainerFn,
|
|
136
144
|
):
|
|
137
145
|
"""
|
|
@@ -142,7 +150,7 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
142
150
|
|
|
143
151
|
Parameters:
|
|
144
152
|
-----------
|
|
145
|
-
ckpt : Literal["best", "last"] | str | Path | None
|
|
153
|
+
ckpt : Literal["best", "last", "none"] | str | Path | None
|
|
146
154
|
Specifies the checkpoint loading preference:
|
|
147
155
|
- "best": Use the best checkpoint based on the primary metric.
|
|
148
156
|
- "last": Use the most recent checkpoint.
|
nshtrainer/model/config.py
CHANGED
|
@@ -811,11 +811,14 @@ class SanityCheckingConfig(C.Config):
|
|
|
811
811
|
|
|
812
812
|
|
|
813
813
|
class TrainerConfig(C.Config):
|
|
814
|
-
ckpt_path: str | Path | None = None
|
|
815
|
-
"""Path to a checkpoint to load and resume training from."""
|
|
814
|
+
ckpt_path: Literal["none"] | str | Path | None = None
|
|
815
|
+
"""Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
|
|
816
816
|
|
|
817
|
-
checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
|
|
818
|
-
"""Checkpoint loading configuration options.
|
|
817
|
+
checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
|
|
818
|
+
"""Checkpoint loading configuration options.
|
|
819
|
+
`"auto"` will automatically determine the best checkpoint loading strategy based on the provided `ckpt_path` and the current trainer mode.
|
|
820
|
+
`"none"` will disable checkpoint loading.
|
|
821
|
+
"""
|
|
819
822
|
|
|
820
823
|
checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
|
|
821
824
|
"""Checkpoint saving configuration options."""
|
|
nshtrainer/trainer/checkpoint_connector.py
CHANGED
|
@@ -31,8 +31,14 @@ class _CheckpointConnector(_LightningCheckpointConnector):
|
|
|
31
31
|
|
|
32
32
|
# Now, resolve the checkpoint loader config.
|
|
33
33
|
root_config = cast("BaseConfig", trainer._base_module.config)
|
|
34
|
-
|
|
35
|
-
|
|
34
|
+
ckpt_loader_config = root_config.trainer.checkpoint_loading
|
|
35
|
+
match ckpt_loader_config:
|
|
36
|
+
case "auto":
|
|
37
|
+
ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
|
|
38
|
+
case "none":
|
|
39
|
+
ckpt_loader_config = CheckpointLoadingConfig.none()
|
|
40
|
+
case _:
|
|
41
|
+
pass
|
|
36
42
|
log.debug(f"Checkpoint loader config: {ckpt_loader_config}")
|
|
37
43
|
|
|
38
44
|
# Use the config to resolve the checkpoint.
|
|
{nshtrainer-0.19.3.dist-info → nshtrainer-0.20.0.dist-info}/RECORD
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
|
|
2
|
-
nshtrainer/_checkpoint/loader.py,sha256=
|
|
2
|
+
nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
|
|
3
3
|
nshtrainer/_checkpoint/metadata.py,sha256=p5e7dhVPpOGrXeuesq_7Y_RHi5lguzDAR_UXtMJXzWU,5175
|
|
4
4
|
nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
|
|
5
5
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
@@ -57,7 +57,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
|
|
|
57
57
|
nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
|
|
58
58
|
nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
|
|
59
59
|
nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
|
|
60
|
-
nshtrainer/model/config.py,sha256=
|
|
60
|
+
nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
|
|
61
61
|
nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
|
|
62
62
|
nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
|
|
63
63
|
nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
|
|
@@ -75,7 +75,7 @@ nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
|
|
|
75
75
|
nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
|
|
76
76
|
nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
|
|
77
77
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
78
|
-
nshtrainer/trainer/checkpoint_connector.py,sha256=
|
|
78
|
+
nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
|
|
79
79
|
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
|
|
80
80
|
nshtrainer/trainer/trainer.py,sha256=TTtVkgSB_ekgDlHg24d58Vzddtkpp6ZHOTVprXdXMH0,17503
|
|
81
81
|
nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
|
|
@@ -85,6 +85,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
85
85
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
86
86
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
87
87
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
88
|
-
nshtrainer-0.
|
|
89
|
-
nshtrainer-0.
|
|
90
|
-
nshtrainer-0.
|
|
88
|
+
nshtrainer-0.20.0.dist-info/METADATA,sha256=BCzgQYVMH8_7VHpAcAEuJqlQ0oJOERSbBop4bOebYZ4,935
|
|
89
|
+
nshtrainer-0.20.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
90
|
+
nshtrainer-0.20.0.dist-info/RECORD,,
|
|
File without changes
|