nshtrainer 0.19.2__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_checkpoint/loader.py +14 -5
- nshtrainer/model/config.py +7 -4
- nshtrainer/trainer/checkpoint_connector.py +8 -2
- {nshtrainer-0.19.2.dist-info → nshtrainer-0.20.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.19.2.dist-info → nshtrainer-0.20.0.dist-info}/RECORD +6 -6
- {nshtrainer-0.19.2.dist-info → nshtrainer-0.20.0.dist-info}/WHEEL +0 -0
nshtrainer/_checkpoint/loader.py
CHANGED
|
@@ -76,7 +76,11 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
76
76
|
"""Whether to include checkpoints from HPC pre-emption."""
|
|
77
77
|
|
|
78
78
|
@classmethod
|
|
79
|
-
def
|
|
79
|
+
def none(cls, include_hpc: bool = False):
|
|
80
|
+
return cls(strategies=[], include_hpc=include_hpc)
|
|
81
|
+
|
|
82
|
+
@classmethod
|
|
83
|
+
def _auto_train(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
|
|
80
84
|
if ckpt is None:
|
|
81
85
|
ckpt = "last"
|
|
82
86
|
match ckpt:
|
|
@@ -90,6 +94,8 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
90
94
|
strategies=[LastCheckpointStrategyConfig()],
|
|
91
95
|
include_hpc=True,
|
|
92
96
|
)
|
|
97
|
+
case "none":
|
|
98
|
+
return cls.none()
|
|
93
99
|
case Path() | str():
|
|
94
100
|
ckpt = Path(ckpt)
|
|
95
101
|
return cls(
|
|
@@ -103,9 +109,10 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
103
109
|
assert_never(ckpt)
|
|
104
110
|
|
|
105
111
|
@classmethod
|
|
106
|
-
def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
|
|
112
|
+
def _auto_eval(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
|
|
107
113
|
if ckpt is None:
|
|
108
|
-
|
|
114
|
+
log.warn("No checkpoint specified for evaluation. Defaulting to `last`.")
|
|
115
|
+
ckpt = "last"
|
|
109
116
|
|
|
110
117
|
match ckpt:
|
|
111
118
|
case "best":
|
|
@@ -118,6 +125,8 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
118
125
|
strategies=[LastCheckpointStrategyConfig()],
|
|
119
126
|
include_hpc=False,
|
|
120
127
|
)
|
|
128
|
+
case "none":
|
|
129
|
+
return cls.none(include_hpc=False)
|
|
121
130
|
case Path() | str():
|
|
122
131
|
ckpt = Path(ckpt)
|
|
123
132
|
return cls(
|
|
@@ -130,7 +139,7 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
130
139
|
@classmethod
|
|
131
140
|
def auto(
|
|
132
141
|
cls,
|
|
133
|
-
ckpt: Literal["best", "last"] | str | Path | None,
|
|
142
|
+
ckpt: Literal["best", "last", "none"] | str | Path | None,
|
|
134
143
|
trainer_mode: TrainerFn,
|
|
135
144
|
):
|
|
136
145
|
"""
|
|
@@ -141,7 +150,7 @@ class CheckpointLoadingConfig(C.Config):
|
|
|
141
150
|
|
|
142
151
|
Parameters:
|
|
143
152
|
-----------
|
|
144
|
-
ckpt : Literal["best", "last"] | str | Path | None
|
|
153
|
+
ckpt : Literal["best", "last", "none"] | str | Path | None
|
|
145
154
|
Specifies the checkpoint loading preference:
|
|
146
155
|
- "best": Use the best checkpoint based on the primary metric.
|
|
147
156
|
- "last": Use the most recent checkpoint.
|
nshtrainer/model/config.py
CHANGED
|
@@ -811,11 +811,14 @@ class SanityCheckingConfig(C.Config):
|
|
|
811
811
|
|
|
812
812
|
|
|
813
813
|
class TrainerConfig(C.Config):
|
|
814
|
-
ckpt_path: str | Path | None = None
|
|
815
|
-
"""Path to a checkpoint to load and resume training from."""
|
|
814
|
+
ckpt_path: Literal["none"] | str | Path | None = None
|
|
815
|
+
"""Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
|
|
816
816
|
|
|
817
|
-
checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
|
|
818
|
-
"""Checkpoint loading configuration options.
|
|
817
|
+
checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
|
|
818
|
+
"""Checkpoint loading configuration options.
|
|
819
|
+
`"auto"` will automatically determine the best checkpoint loading strategy based on the provided `ckpt_path` and the current trainer mode.
|
|
820
|
+
`"none"` will disable checkpoint loading.
|
|
821
|
+
"""
|
|
819
822
|
|
|
820
823
|
checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
|
|
821
824
|
"""Checkpoint saving configuration options."""
|
|
@@ -31,8 +31,14 @@ class _CheckpointConnector(_LightningCheckpointConnector):
|
|
|
31
31
|
|
|
32
32
|
# Now, resolve the checkpoint loader config.
|
|
33
33
|
root_config = cast("BaseConfig", trainer._base_module.config)
|
|
34
|
-
|
|
35
|
-
|
|
34
|
+
ckpt_loader_config = root_config.trainer.checkpoint_loading
|
|
35
|
+
match ckpt_loader_config:
|
|
36
|
+
case "auto":
|
|
37
|
+
ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
|
|
38
|
+
case "none":
|
|
39
|
+
ckpt_loader_config = CheckpointLoadingConfig.none()
|
|
40
|
+
case _:
|
|
41
|
+
pass
|
|
36
42
|
log.debug(f"Checkpoint loader config: {ckpt_loader_config}")
|
|
37
43
|
|
|
38
44
|
# Use the config to resolve the checkpoint.
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
|
|
2
|
-
nshtrainer/_checkpoint/loader.py,sha256=
|
|
2
|
+
nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
|
|
3
3
|
nshtrainer/_checkpoint/metadata.py,sha256=p5e7dhVPpOGrXeuesq_7Y_RHi5lguzDAR_UXtMJXzWU,5175
|
|
4
4
|
nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
|
|
5
5
|
nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
|
|
@@ -57,7 +57,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
|
|
|
57
57
|
nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
|
|
58
58
|
nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
|
|
59
59
|
nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
|
|
60
|
-
nshtrainer/model/config.py,sha256=
|
|
60
|
+
nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
|
|
61
61
|
nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
|
|
62
62
|
nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
|
|
63
63
|
nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
|
|
@@ -75,7 +75,7 @@ nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
|
|
|
75
75
|
nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
|
|
76
76
|
nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
|
|
77
77
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
78
|
-
nshtrainer/trainer/checkpoint_connector.py,sha256=
|
|
78
|
+
nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
|
|
79
79
|
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
|
|
80
80
|
nshtrainer/trainer/trainer.py,sha256=TTtVkgSB_ekgDlHg24d58Vzddtkpp6ZHOTVprXdXMH0,17503
|
|
81
81
|
nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
|
|
@@ -85,6 +85,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
85
85
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
86
86
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
87
87
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
88
|
-
nshtrainer-0.
|
|
89
|
-
nshtrainer-0.
|
|
90
|
-
nshtrainer-0.
|
|
88
|
+
nshtrainer-0.20.0.dist-info/METADATA,sha256=BCzgQYVMH8_7VHpAcAEuJqlQ0oJOERSbBop4bOebYZ4,935
|
|
89
|
+
nshtrainer-0.20.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
90
|
+
nshtrainer-0.20.0.dist-info/RECORD,,
|
|
File without changes
|