nshtrainer 0.37.0__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -41,7 +41,7 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
41
41
  self,
42
42
  root_config: "BaseConfig",
43
43
  dirpath: Path,
44
- ) -> "CheckpointBase": ...
44
+ ) -> "CheckpointBase | None": ...
45
45
 
46
46
  @override
47
47
  def create_callbacks(self, root_config):
@@ -50,7 +50,8 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
50
50
  or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
51
51
  )
52
52
 
53
- yield self.create_checkpoint(root_config, dirpath)
53
+ if (callback := self.create_checkpoint(root_config, dirpath)) is not None:
54
+ yield callback
54
55
 
55
56
 
56
57
  TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
@@ -20,15 +20,26 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
20
20
  metric: MetricConfig | None = None
21
21
  """Metric to monitor, or `None` to use the default metric."""
22
22
 
23
+ throw_on_no_metric: bool = True
24
+ """
25
+ Whether to throw an error if no metric is provided and no primary metric is found in the root config.
26
+ """
27
+
23
28
  @override
24
29
  def create_checkpoint(self, root_config, dirpath):
25
30
  # Resolve metric
26
31
  if (metric := self.metric) is None and (
27
32
  metric := root_config.primary_metric
28
33
  ) is None:
29
- raise ValueError(
30
- "No metric provided and no primary metric found in the root config"
34
+ error_msg = (
35
+ "No metric provided and no primary metric found in the root config. "
36
+ "Cannot create BestCheckpointCallback."
31
37
  )
38
+ if self.throw_on_no_metric:
39
+ raise ValueError(error_msg)
40
+ else:
41
+ log.warning(error_msg)
42
+ return None
32
43
 
33
44
  return BestCheckpoint(self, dirpath, metric)
34
45
 
@@ -51,7 +51,7 @@ class EarlyStoppingConfig(CallbackConfigBase):
51
51
  metric := root_config.primary_metric
52
52
  ) is None:
53
53
  raise ValueError(
54
- "Either `metric` or `root_config.primary_metric` must be set."
54
+ "Either `metric` or `root_config.primary_metric` must be set to use EarlyStopping."
55
55
  )
56
56
 
57
57
  yield EarlyStopping(self, metric)
nshtrainer/config.py CHANGED
@@ -65,13 +65,13 @@ from nshtrainer.callbacks.wandb_upload_code import (
65
65
  WandbUploadCodeConfig as WandbUploadCodeConfig,
66
66
  )
67
67
  from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
68
- from nshtrainer.config import LRSchedulerConfig as LRSchedulerConfig
69
68
  from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
70
69
  from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
71
70
  from nshtrainer.loggers.tensorboard import (
72
71
  TensorboardLoggerConfig as TensorboardLoggerConfig,
73
72
  )
74
73
  from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
74
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
75
75
  from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
76
76
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
77
77
  DurationConfig as DurationConfig,
@@ -263,7 +263,7 @@ class CheckpointSavingConfig(CallbackConfigBase):
263
263
  """Enable checkpoint saving."""
264
264
 
265
265
  checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
266
- BestCheckpointCallbackConfig(),
266
+ BestCheckpointCallbackConfig(throw_on_no_metric=False),
267
267
  LastCheckpointCallbackConfig(),
268
268
  OnExceptionCheckpointCallbackConfig(),
269
269
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.37.0
3
+ Version: 0.39.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -11,13 +11,13 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
11
11
  nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
12
12
  nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
13
13
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
14
- nshtrainer/callbacks/checkpoint/_base.py,sha256=vvlwuD-20NozYVIolGGShmUdkkNYeuwN6xCoFnK4GiU,6157
15
- nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=8BHgLAd3Tuzf5sup0guEAKF1jJiAwYsjdKBFYZw98ac,2171
14
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=X6844ceZ-EMz9UCVLJnA1d-ej2AkkRH90v2hy2zPDkc,6215
15
+ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=MOC5jLlQnnc61yuiVc2-O4NeYclnkXroOPDV9-2zh5w,2553
16
16
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
17
17
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
18
18
  nshtrainer/callbacks/debug_flag.py,sha256=T7rkY9hYQ_-PsPo2XiQ4eVZ9bBsTd2knpZWctCbjxXc,2011
19
19
  nshtrainer/callbacks/directory_setup.py,sha256=c0uY0oTqLcQ3egInHO7G6BeQQgk_xvOLoHH8FR-9U0U,2629
20
- nshtrainer/callbacks/early_stopping.py,sha256=VWuJz0oN87b6SwBeVc32YNpeJr1wts8K45k8JJJmG9I,4617
20
+ nshtrainer/callbacks/early_stopping.py,sha256=upCNtjkXmwGPwLJmDZacXdBzySG2i_sGIK9QjnaE6tU,4638
21
21
  nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
22
22
  nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
23
23
  nshtrainer/callbacks/gradient_skipping.py,sha256=EBNkANDnD3BTszWjnG-jwY8FEj-iRqhE3e1x5LQF6M8,3393
@@ -31,7 +31,7 @@ nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Zt
31
31
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
32
32
  nshtrainer/callbacks/wandb_upload_code.py,sha256=OWG4UkL2SfW6oj6AGRXeBJsZmgsqeHLW2Fj8Jm4ga3I,2298
33
33
  nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
34
- nshtrainer/config.py,sha256=6jApGtO9DVFoXKr9_Z7-MFG5R4WXjbpzZ6jkNI3yD-Y,8306
34
+ nshtrainer/config.py,sha256=pZyRZOkBRR7eBFRiHpHjQFNEFjaX9tYZIAqZtvKi6cA,8312
35
35
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
36
36
  nshtrainer/data/balanced_batch_sampler.py,sha256=ybMJF-CguaZ17fLEweZ5suaGOiHOMEm3Bn8rQfGTzGQ,5445
37
37
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
@@ -82,7 +82,7 @@ nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y
82
82
  nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
83
83
  nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
84
84
  nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
85
- nshtrainer/trainer/_config.py,sha256=ZIodM5Ek1lpkWFhQ_VfmKR7q1mZFFwtjfx8FH72H8WM,29174
85
+ nshtrainer/trainer/_config.py,sha256=YqpGb4RodkUg87TVE5WBSc4CQkUF0z3qDRdil1HRxoM,29198
86
86
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
87
87
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
88
88
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
@@ -99,6 +99,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
99
99
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
100
100
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
101
101
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
102
- nshtrainer-0.37.0.dist-info/METADATA,sha256=Rd7HeNaz5lBQa0AspjWyRgFLctqpSfna2R9VnMEUURU,916
103
- nshtrainer-0.37.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
104
- nshtrainer-0.37.0.dist-info/RECORD,,
102
+ nshtrainer-0.39.0.dist-info/METADATA,sha256=bIWwvsGuePEZeT3Q8dYS8IK5Y6ZM4yqe9v41Ybs9OGM,916
103
+ nshtrainer-0.39.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
104
+ nshtrainer-0.39.0.dist-info/RECORD,,