nshtrainer 0.37.0__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/checkpoint/_base.py +3 -2
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +13 -2
- nshtrainer/callbacks/early_stopping.py +1 -1
- nshtrainer/config.py +1 -1
- nshtrainer/trainer/_config.py +1 -1
- {nshtrainer-0.37.0.dist-info → nshtrainer-0.39.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.37.0.dist-info → nshtrainer-0.39.0.dist-info}/RECORD +8 -8
- {nshtrainer-0.37.0.dist-info → nshtrainer-0.39.0.dist-info}/WHEEL +0 -0
|
@@ -41,7 +41,7 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
|
|
|
41
41
|
self,
|
|
42
42
|
root_config: "BaseConfig",
|
|
43
43
|
dirpath: Path,
|
|
44
|
-
) -> "CheckpointBase": ...
|
|
44
|
+
) -> "CheckpointBase | None": ...
|
|
45
45
|
|
|
46
46
|
@override
|
|
47
47
|
def create_callbacks(self, root_config):
|
|
@@ -50,7 +50,8 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
|
|
|
50
50
|
or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
|
|
51
51
|
)
|
|
52
52
|
|
|
53
|
-
|
|
53
|
+
if (callback := self.create_checkpoint(root_config, dirpath)) is not None:
|
|
54
|
+
yield callback
|
|
54
55
|
|
|
55
56
|
|
|
56
57
|
TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
|
|
@@ -20,15 +20,26 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
|
|
20
20
|
metric: MetricConfig | None = None
|
|
21
21
|
"""Metric to monitor, or `None` to use the default metric."""
|
|
22
22
|
|
|
23
|
+
throw_on_no_metric: bool = True
|
|
24
|
+
"""
|
|
25
|
+
Whether to throw an error if no metric is provided and no primary metric is found in the root config.
|
|
26
|
+
"""
|
|
27
|
+
|
|
23
28
|
@override
|
|
24
29
|
def create_checkpoint(self, root_config, dirpath):
|
|
25
30
|
# Resolve metric
|
|
26
31
|
if (metric := self.metric) is None and (
|
|
27
32
|
metric := root_config.primary_metric
|
|
28
33
|
) is None:
|
|
29
|
-
|
|
30
|
-
"No metric provided and no primary metric found in the root config"
|
|
34
|
+
error_msg = (
|
|
35
|
+
"No metric provided and no primary metric found in the root config. "
|
|
36
|
+
"Cannot create BestCheckpointCallback."
|
|
31
37
|
)
|
|
38
|
+
if self.throw_on_no_metric:
|
|
39
|
+
raise ValueError(error_msg)
|
|
40
|
+
else:
|
|
41
|
+
log.warning(error_msg)
|
|
42
|
+
return None
|
|
32
43
|
|
|
33
44
|
return BestCheckpoint(self, dirpath, metric)
|
|
34
45
|
|
|
@@ -51,7 +51,7 @@ class EarlyStoppingConfig(CallbackConfigBase):
|
|
|
51
51
|
metric := root_config.primary_metric
|
|
52
52
|
) is None:
|
|
53
53
|
raise ValueError(
|
|
54
|
-
"Either `metric` or `root_config.primary_metric` must be set."
|
|
54
|
+
"Either `metric` or `root_config.primary_metric` must be set to use EarlyStopping."
|
|
55
55
|
)
|
|
56
56
|
|
|
57
57
|
yield EarlyStopping(self, metric)
|
nshtrainer/config.py
CHANGED
|
@@ -65,13 +65,13 @@ from nshtrainer.callbacks.wandb_upload_code import (
|
|
|
65
65
|
WandbUploadCodeConfig as WandbUploadCodeConfig,
|
|
66
66
|
)
|
|
67
67
|
from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
|
|
68
|
-
from nshtrainer.config import LRSchedulerConfig as LRSchedulerConfig
|
|
69
68
|
from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
|
|
70
69
|
from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
|
|
71
70
|
from nshtrainer.loggers.tensorboard import (
|
|
72
71
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
|
73
72
|
)
|
|
74
73
|
from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
|
|
74
|
+
from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
|
75
75
|
from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
|
|
76
76
|
from nshtrainer.lr_scheduler.linear_warmup_cosine import (
|
|
77
77
|
DurationConfig as DurationConfig,
|
nshtrainer/trainer/_config.py
CHANGED
|
@@ -263,7 +263,7 @@ class CheckpointSavingConfig(CallbackConfigBase):
|
|
|
263
263
|
"""Enable checkpoint saving."""
|
|
264
264
|
|
|
265
265
|
checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
|
|
266
|
-
BestCheckpointCallbackConfig(),
|
|
266
|
+
BestCheckpointCallbackConfig(throw_on_no_metric=False),
|
|
267
267
|
LastCheckpointCallbackConfig(),
|
|
268
268
|
OnExceptionCheckpointCallbackConfig(),
|
|
269
269
|
]
|
|
@@ -11,13 +11,13 @@ nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHm
|
|
|
11
11
|
nshtrainer/callbacks/actsave.py,sha256=qbnaKts4_dvjPeAaPtv7Ds12_vEWzaHUfg_--49NB9I,4041
|
|
12
12
|
nshtrainer/callbacks/base.py,sha256=NpjeKmonJ1Kaz5_39XSn3LlDwvbGjk6WV8BpHSNCvI4,3508
|
|
13
13
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=g-3zIthupERKqWZQw-A_busQPaPRkto6iHBV-M7nK1Y,527
|
|
14
|
-
nshtrainer/callbacks/checkpoint/_base.py,sha256=
|
|
15
|
-
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
|
|
14
|
+
nshtrainer/callbacks/checkpoint/_base.py,sha256=X6844ceZ-EMz9UCVLJnA1d-ej2AkkRH90v2hy2zPDkc,6215
|
|
15
|
+
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=MOC5jLlQnnc61yuiVc2-O4NeYclnkXroOPDV9-2zh5w,2553
|
|
16
16
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=CWWv0cSwQ1VAX26N7hAyMxbNCk26Keh39oQguBEK5To,1102
|
|
17
17
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=ctT88EGT22_t_6tr5r7Sfo43cuve6XeroBnBYRMPOus,3372
|
|
18
18
|
nshtrainer/callbacks/debug_flag.py,sha256=T7rkY9hYQ_-PsPo2XiQ4eVZ9bBsTd2knpZWctCbjxXc,2011
|
|
19
19
|
nshtrainer/callbacks/directory_setup.py,sha256=c0uY0oTqLcQ3egInHO7G6BeQQgk_xvOLoHH8FR-9U0U,2629
|
|
20
|
-
nshtrainer/callbacks/early_stopping.py,sha256=
|
|
20
|
+
nshtrainer/callbacks/early_stopping.py,sha256=upCNtjkXmwGPwLJmDZacXdBzySG2i_sGIK9QjnaE6tU,4638
|
|
21
21
|
nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,12131
|
|
22
22
|
nshtrainer/callbacks/finite_checks.py,sha256=gJC_RUr3ais3FJI0uB6wUZnDdE3WRwCix3ppA3PwQXA,2077
|
|
23
23
|
nshtrainer/callbacks/gradient_skipping.py,sha256=EBNkANDnD3BTszWjnG-jwY8FEj-iRqhE3e1x5LQF6M8,3393
|
|
@@ -31,7 +31,7 @@ nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Zt
|
|
|
31
31
|
nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
|
|
32
32
|
nshtrainer/callbacks/wandb_upload_code.py,sha256=OWG4UkL2SfW6oj6AGRXeBJsZmgsqeHLW2Fj8Jm4ga3I,2298
|
|
33
33
|
nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
|
|
34
|
-
nshtrainer/config.py,sha256=
|
|
34
|
+
nshtrainer/config.py,sha256=pZyRZOkBRR7eBFRiHpHjQFNEFjaX9tYZIAqZtvKi6cA,8312
|
|
35
35
|
nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
|
|
36
36
|
nshtrainer/data/balanced_batch_sampler.py,sha256=ybMJF-CguaZ17fLEweZ5suaGOiHOMEm3Bn8rQfGTzGQ,5445
|
|
37
37
|
nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
|
|
@@ -82,7 +82,7 @@ nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y
|
|
|
82
82
|
nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
|
|
83
83
|
nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
|
|
84
84
|
nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
|
|
85
|
-
nshtrainer/trainer/_config.py,sha256=
|
|
85
|
+
nshtrainer/trainer/_config.py,sha256=YqpGb4RodkUg87TVE5WBSc4CQkUF0z3qDRdil1HRxoM,29198
|
|
86
86
|
nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
|
|
87
87
|
nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
|
|
88
88
|
nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
|
|
@@ -99,6 +99,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
|
|
|
99
99
|
nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
|
|
100
100
|
nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
|
|
101
101
|
nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
|
|
102
|
-
nshtrainer-0.
|
|
103
|
-
nshtrainer-0.
|
|
104
|
-
nshtrainer-0.
|
|
102
|
+
nshtrainer-0.39.0.dist-info/METADATA,sha256=bIWwvsGuePEZeT3Q8dYS8IK5Y6ZM4yqe9v41Ybs9OGM,916
|
|
103
|
+
nshtrainer-0.39.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
104
|
+
nshtrainer-0.39.0.dist-info/RECORD,,
|
|
File without changes
|