nshtrainer 0.19.2__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -76,7 +76,11 @@ class CheckpointLoadingConfig(C.Config):
76
76
  """Whether to include checkpoints from HPC pre-emption."""
77
77
 
78
78
  @classmethod
79
- def _auto_train(cls, ckpt: Literal["best", "last"] | str | Path | None):
79
+ def none(cls, include_hpc: bool = False):
80
+ return cls(strategies=[], include_hpc=include_hpc)
81
+
82
+ @classmethod
83
+ def _auto_train(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
80
84
  if ckpt is None:
81
85
  ckpt = "last"
82
86
  match ckpt:
@@ -90,6 +94,8 @@ class CheckpointLoadingConfig(C.Config):
90
94
  strategies=[LastCheckpointStrategyConfig()],
91
95
  include_hpc=True,
92
96
  )
97
+ case "none":
98
+ return cls.none()
93
99
  case Path() | str():
94
100
  ckpt = Path(ckpt)
95
101
  return cls(
@@ -103,9 +109,10 @@ class CheckpointLoadingConfig(C.Config):
103
109
  assert_never(ckpt)
104
110
 
105
111
  @classmethod
106
- def _auto_eval(cls, ckpt: Literal["best", "last"] | str | Path | None):
112
+ def _auto_eval(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
107
113
  if ckpt is None:
108
- raise ValueError("Checkpoint path must be provided for evaluation.")
114
+ log.warn("No checkpoint specified for evaluation. Defaulting to `last`.")
115
+ ckpt = "last"
109
116
 
110
117
  match ckpt:
111
118
  case "best":
@@ -118,6 +125,8 @@ class CheckpointLoadingConfig(C.Config):
118
125
  strategies=[LastCheckpointStrategyConfig()],
119
126
  include_hpc=False,
120
127
  )
128
+ case "none":
129
+ return cls.none(include_hpc=False)
121
130
  case Path() | str():
122
131
  ckpt = Path(ckpt)
123
132
  return cls(
@@ -130,7 +139,7 @@ class CheckpointLoadingConfig(C.Config):
130
139
  @classmethod
131
140
  def auto(
132
141
  cls,
133
- ckpt: Literal["best", "last"] | str | Path | None,
142
+ ckpt: Literal["best", "last", "none"] | str | Path | None,
134
143
  trainer_mode: TrainerFn,
135
144
  ):
136
145
  """
@@ -141,7 +150,7 @@ class CheckpointLoadingConfig(C.Config):
141
150
 
142
151
  Parameters:
143
152
  -----------
144
- ckpt : Literal["best", "last"] | str | Path | None
153
+ ckpt : Literal["best", "last", "none"] | str | Path | None
145
154
  Specifies the checkpoint loading preference:
146
155
  - "best": Use the best checkpoint based on the primary metric.
147
156
  - "last": Use the most recent checkpoint.
@@ -811,11 +811,14 @@ class SanityCheckingConfig(C.Config):
811
811
 
812
812
 
813
813
  class TrainerConfig(C.Config):
814
- ckpt_path: str | Path | None = None
815
- """Path to a checkpoint to load and resume training from."""
814
+ ckpt_path: Literal["none"] | str | Path | None = None
815
+ """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
816
816
 
817
- checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
818
- """Checkpoint loading configuration options."""
817
+ checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
818
+ """Checkpoint loading configuration options.
819
+ `"auto"` will automatically determine the best checkpoint loading strategy based on the provided.
820
+ `"none"` will disable checkpoint loading.
821
+ """
819
822
 
820
823
  checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
821
824
  """Checkpoint saving configuration options."""
@@ -31,8 +31,14 @@ class _CheckpointConnector(_LightningCheckpointConnector):
31
31
 
32
32
  # Now, resolve the checkpoint loader config.
33
33
  root_config = cast("BaseConfig", trainer._base_module.config)
34
- if (ckpt_loader_config := root_config.trainer.checkpoint_loading) == "auto":
35
- ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
34
+ ckpt_loader_config = root_config.trainer.checkpoint_loading
35
+ match ckpt_loader_config:
36
+ case "auto":
37
+ ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
38
+ case "none":
39
+ ckpt_loader_config = CheckpointLoadingConfig.none()
40
+ case _:
41
+ pass
36
42
  log.debug(f"Checkpoint loader config: {ckpt_loader_config}")
37
43
 
38
44
  # Use the config to resolve the checkpoint.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.19.2
3
+ Version: 0.20.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,5 +1,5 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
- nshtrainer/_checkpoint/loader.py,sha256=myFObRsPdb8jBncMK73vjr5FDJIfKhF86Ec_kSjXtwg,13837
2
+ nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
3
3
  nshtrainer/_checkpoint/metadata.py,sha256=p5e7dhVPpOGrXeuesq_7Y_RHi5lguzDAR_UXtMJXzWU,5175
4
4
  nshtrainer/_checkpoint/saver.py,sha256=DkbCH0YeOJ71m32vAARiQdGBf0hvwwdoAV8LOFGy-0Y,1428
5
5
  nshtrainer/_experimental/__init__.py,sha256=pEXPyI184UuDHvfh4p9Kg9nQZQZI41e4_HvNd4BK-yg,81
@@ -57,7 +57,7 @@ nshtrainer/metrics/__init__.py,sha256=ObLIELGguIEcUpRsUkqh1ltrvZii6vglTpJGrPvoy0
57
57
  nshtrainer/metrics/_config.py,sha256=jgRBfDAQLFTW7AiUY7CRtdfts6CR6keeuqm0FFMWCzQ,1288
58
58
  nshtrainer/model/__init__.py,sha256=VyRziPT3YilP6xjLi_StsSqtlvn7N4LOMzgukRsOnF8,1380
59
59
  nshtrainer/model/base.py,sha256=oQVolDk81acy4OlckwQEBHuX2gCaVSYiIA0JaDIfhQ4,17517
60
- nshtrainer/model/config.py,sha256=147uV7IukvuYE4G_ZuQNxVjnlog1BdCrAVbcj_sx9Vs,43104
60
+ nshtrainer/model/config.py,sha256=zcCLcqvg4u7Zg6SLtCnqdIfiW8I0eART47lf1LCYl-A,43326
61
61
  nshtrainer/model/modules/callback.py,sha256=1z6gUDBd35KG3phGzRekgZM6SIk-wj5Uo6APN4YhRR0,8549
62
62
  nshtrainer/model/modules/debug.py,sha256=Yy7XEdPou9BkCsD5hJchwJGmCVGrfUru5g9VjPM4uAw,1120
63
63
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -75,7 +75,7 @@ nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
75
75
  nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
76
76
  nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
77
77
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
78
- nshtrainer/trainer/checkpoint_connector.py,sha256=F2tkHogbMAa5U7335sm77sZBkjEDa5v46XbJCH9Mg6c,2167
78
+ nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
79
79
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
80
80
  nshtrainer/trainer/trainer.py,sha256=TTtVkgSB_ekgDlHg24d58Vzddtkpp6ZHOTVprXdXMH0,17503
81
81
  nshtrainer/util/_environment_info.py,sha256=gIdq9TJgzGCdcVzZxjHcwYasJ_HmEGVHbvE-KJVVtWs,24187
@@ -85,6 +85,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
85
85
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
86
86
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
87
87
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
88
- nshtrainer-0.19.2.dist-info/METADATA,sha256=InNVoRQEPpPRCFbBje-ekgQzFFycxC9VzQsmEqUJK1c,935
89
- nshtrainer-0.19.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
90
- nshtrainer-0.19.2.dist-info/RECORD,,
88
+ nshtrainer-0.20.0.dist-info/METADATA,sha256=BCzgQYVMH8_7VHpAcAEuJqlQ0oJOERSbBop4bOebYZ4,935
89
+ nshtrainer-0.20.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
90
+ nshtrainer-0.20.0.dist-info/RECORD,,