nshtrainer 0.19.3__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -76,7 +76,11 @@ class CheckpointLoadingConfig(C.Config):
76
76
  """Whether to include checkpoints from HPC pre-emption."""
77
77
 
78
78
  @classmethod
79
- def _auto_train(cls, ckpt: Literal["best", "last"] | str | Path | None):
79
+ def none(cls, include_hpc: bool = False):
80
+ return cls(strategies=[], include_hpc=include_hpc)
81
+
82
+ @classmethod
83
+ def _auto_train(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
80
84
  if ckpt is None:
81
85
  ckpt = "last"
82
86
  match ckpt:
@@ -90,6 +94,8 @@ class CheckpointLoadingConfig(C.Config):
90
94
  strategies=[LastCheckpointStrategyConfig()],
91
95
  include_hpc=True,
92
96
  )
97
+ case "none":
98
+ return cls.none()
93
99
  case Path() | str():
94
100
  ckpt = Path(ckpt)
95
101
  return cls(
@@ -103,7 +109,7 @@ class CheckpointLoadingConfig(C.Config):
103
109
  assert_never(ckpt)
104
110
 
105
111
  @classmethod
106
- def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
112
+ def _auto_eval(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
107
113
  if ckpt is None:
108
114
  log.warn("No checkpoint specified for evaluation. Defaulting to `last`.")
109
115
  ckpt = "last"
@@ -119,6 +125,8 @@ class CheckpointLoadingConfig(C.Config):
119
125
  strategies=[LastCheckpointStrategyConfig()],
120
126
  include_hpc=False,
121
127
  )
128
+ case "none":
129
+ return cls.none(include_hpc=False)
122
130
  case Path() | str():
123
131
  ckpt = Path(ckpt)
124
132
  return cls(
@@ -131,7 +139,7 @@ class CheckpointLoadingConfig(C.Config):
131
139
  @classmethod
132
140
  def auto(
133
141
  cls,
134
- ckpt: Literal["best", "last"] | str | Path | None,
142
+ ckpt: Literal["best", "last", "none"] | str | Path | None,
135
143
  trainer_mode: TrainerFn,
136
144
  ):
137
145
  """
@@ -142,7 +150,7 @@ class CheckpointLoadingConfig(C.Config):
142
150
 
143
151
  Parameters:
144
152
  -----------
145
- ckpt : Literal["best", "last"] | str | Path | None
153
+ ckpt : Literal["best", "last", "none"] | str | Path | None
146
154
  Specifies the checkpoint loading preference:
147
155
  - "best": Use the best checkpoint based on the primary metric.
148
156
  - "last": Use the most recent checkpoint.
@@ -811,11 +811,14 @@ class SanityCheckingConfig(C.Config):
811
811
 
812
812
 
813
813
  class TrainerConfig(C.Config):
814
- ckpt_path: str | Path | None = None
815
- """Path to a checkpoint to load and resume training from."""
814
+ ckpt_path: Literal["none"] | str | Path | None = None
815
+ """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
816
816
 
817
- checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
818
- """Checkpoint loading configuration options."""
817
+ checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
818
+ """Checkpoint loading configuration options.
819
+ `"auto"` will automatically determine the best checkpoint loading strategy based on the provided `ckpt_path` and the trainer mode.
820
+ `"none"` will disable checkpoint loading.
821
+ """
819
822
 
820
823
  checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
821
824
  """Checkpoint saving configuration options."""
@@ -31,8 +31,14 @@ class _CheckpointConnector(_LightningCheckpointConnector):
31
31
 
32
32
  # Now, resolve the checkpoint loader config.
33
33
  root_config = cast("BaseConfig", trainer._base_module.config)
34
- if (ckpt_loader_config := root_config.trainer.checkpoint_loading) == "auto":
35
- ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
34
+ ckpt_loader_config = root_config.trainer.checkpoint_loading
35
+ match ckpt_loader_config:
36
+ case "auto":
37
+ ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
38
+ case "none":
39
+ ckpt_loader_config = CheckpointLoadingConfig.none()
40
+ case _:
41
+ pass
36
42
  log.debug(f"Checkpoint loader config: {ckpt_loader_config}")
37
43
 
38
44
  # Use the config to resolve the checkpoint.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.19.3
3
+ Version: 0.20.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,5 +1,5 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
- nshtrainer/_checkpoint/loader.py,sha256=vc9f22qEDVw-y-Clpy71jVeI2EPxWNqRy-cAslMTb8c,13868
2
+ nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
3
3
  nshtrainer/_checkpoint/metadata.py,sha256=p5e7dhVPpOGrXeuesq_7Y_RHi5lguzDAR_UXtMJXzWU,5175
4
4
  nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
5
5
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
@@ -57,7 +57,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
57
57
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
58
58
  nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
59
59
  nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
60
- nshtrainer/model/config.py,sha256=147uV7IukvuYE4G_ZuQNxVjnlog1BdCrAVbcj_sx9Vs,43104
60
+ nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
61
61
  nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
62
62
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
63
63
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -75,7 +75,7 @@ nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
75
75
  nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
76
76
  nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
77
77
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
78
- nshtrainer/trainer/checkpoint_connector.py,sha256=F2tkHogbMAa5U7335sm77sZBkjEDa5v46XbJCH9Mg6c,2167
78
+ nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
79
79
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
80
80
  nshtrainer/trainer/trainer.py,sha256=TTtVkgSB_ekgDlHg24d58Vzddtkpp6ZHOTVprXdXMH0,17503
81
81
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
@@ -85,6 +85,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
85
85
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
86
86
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
87
87
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
88
- nshtrainer-0.19.3.dist-info/METADATA,sha256=0LXA6hNdn7QjIJEGK-tvPQHiTEuxwsDzifvEyBMCYmo,935
89
- nshtrainer-0.19.3.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
90
- nshtrainer-0.19.3.dist-info/RECORD,,
88
+ nshtrainer-0.20.0.dist-info/METADATA,sha256=BCzgQYVMH8_7VHpAcAEuJqlQ0oJOERSbBop4bOebYZ4,935
89
+ nshtrainer-0.20.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
90
+ nshtrainer-0.20.0.dist-info/RECORD,,