nshtrainer 1.0.0b43__py3-none-any.whl → 1.0.0b44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -40,6 +40,10 @@ from .log_epoch import LogEpochCallback as LogEpochCallback
40
40
  from .log_epoch import LogEpochCallbackConfig as LogEpochCallbackConfig
41
41
  from .lr_monitor import LearningRateMonitor as LearningRateMonitor
42
42
  from .lr_monitor import LearningRateMonitorConfig as LearningRateMonitorConfig
43
+ from .metric_validation import MetricValidationCallback as MetricValidationCallback
44
+ from .metric_validation import (
45
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
46
+ )
43
47
  from .norm_logging import NormLoggingCallback as NormLoggingCallback
44
48
  from .norm_logging import NormLoggingCallbackConfig as NormLoggingCallbackConfig
45
49
  from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
@@ -0,0 +1,75 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Literal
5
+
6
+ from lightning.pytorch.utilities.exceptions import MisconfigurationException
7
+ from typing_extensions import final, override, assert_never
8
+
9
+ from .._callback import NTCallbackBase
10
+ from ..metrics import MetricConfig
11
+ from .base import CallbackConfigBase, callback_registry
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
@final
@callback_registry.register
class MetricValidationCallbackConfig(CallbackConfigBase):
    """Configuration for ``MetricValidationCallback``.

    Chooses which metrics are checked against the trainer's logged metrics
    when the sanity check ends, and how a missing metric is reported.
    """

    name: Literal["metric_validation"] = "metric_validation"

    error_behavior: Literal["raise", "warn"] = "raise"
    """
    Behavior when an error occurs during validation:
    - "raise": Raise an error and stop the training.
    - "warn": Log a warning and continue the training.
    """

    validate_default_metric: bool = True
    """Whether to validate the default metric from the root config."""

    metrics: list[MetricConfig] = []
    """List of metrics to validate."""

    @override
    def create_callbacks(self, trainer_config):
        """Yield a single ``MetricValidationCallback`` covering every metric to check.

        The configured ``metrics`` are copied into a fresh list, so this config
        object is never mutated. When ``validate_default_metric`` is enabled and
        ``trainer_config.primary_metric`` is not ``None``, that metric is
        appended after the explicitly configured ones.
        """
        to_validate = list(self.metrics)

        # Only touch ``primary_metric`` when default-metric validation is on.
        if self.validate_default_metric:
            primary = trainer_config.primary_metric
            if primary is not None:
                to_validate.append(primary)

        yield MetricValidationCallback(self, to_validate)
44
+
45
+
46
+ class MetricValidationCallback(NTCallbackBase):
47
+ def __init__(
48
+ self, config: MetricValidationCallbackConfig, metrics: list[MetricConfig]
49
+ ):
50
+ super().__init__()
51
+
52
+ self.config = config
53
+ self.metrics = metrics
54
+
55
+ @override
56
+ def on_sanity_check_end(self, trainer, pl_module):
57
+ super().on_sanity_check_end(trainer, pl_module)
58
+
59
+ log.debug("Validating metrics...")
60
+ logged_metrics = set(trainer.logged_metrics.keys())
61
+ for metric in self.metrics:
62
+ if metric.validation_monitor in logged_metrics:
63
+ continue
64
+
65
+ match self.config.error_behavior:
66
+ case "raise":
67
+ raise MisconfigurationException(
68
+ f"Metric '{metric.validation_monitor}' not found in logged metrics."
69
+ )
70
+ case "warn":
71
+ log.warning(
72
+ f"Metric '{metric.validation_monitor}' not found in logged metrics."
73
+ )
74
+ case _:
75
+ assert_never(self.config.error_behavior)
@@ -39,6 +39,9 @@ from nshtrainer.callbacks import (
39
39
  )
40
40
  from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
41
41
  from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
42
+ from nshtrainer.callbacks import (
43
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
44
+ )
42
45
  from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
43
46
  from nshtrainer.callbacks import (
44
47
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -287,6 +290,7 @@ __all__ = [
287
290
  "MPIEnvironmentPlugin",
288
291
  "MPSAcceleratorConfig",
289
292
  "MetricConfig",
293
+ "MetricValidationCallbackConfig",
290
294
  "MishNonlinearityConfig",
291
295
  "MixedPrecisionPluginConfig",
292
296
  "NonlinearityConfig",
@@ -28,6 +28,9 @@ from nshtrainer.callbacks import (
28
28
  )
29
29
  from nshtrainer.callbacks import LearningRateMonitorConfig as LearningRateMonitorConfig
30
30
  from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
31
+ from nshtrainer.callbacks import (
32
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
33
+ )
31
34
  from nshtrainer.callbacks import NormLoggingCallbackConfig as NormLoggingCallbackConfig
32
35
  from nshtrainer.callbacks import (
33
36
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -65,6 +68,7 @@ from . import finite_checks as finite_checks
65
68
  from . import gradient_skipping as gradient_skipping
66
69
  from . import log_epoch as log_epoch
67
70
  from . import lr_monitor as lr_monitor
71
+ from . import metric_validation as metric_validation
68
72
  from . import norm_logging as norm_logging
69
73
  from . import print_table as print_table
70
74
  from . import rlp_sanity_checks as rlp_sanity_checks
@@ -91,6 +95,7 @@ __all__ = [
91
95
  "LearningRateMonitorConfig",
92
96
  "LogEpochCallbackConfig",
93
97
  "MetricConfig",
98
+ "MetricValidationCallbackConfig",
94
99
  "NormLoggingCallbackConfig",
95
100
  "OnExceptionCheckpointCallbackConfig",
96
101
  "PrintTableMetricsCallbackConfig",
@@ -110,6 +115,7 @@ __all__ = [
110
115
  "gradient_skipping",
111
116
  "log_epoch",
112
117
  "lr_monitor",
118
+ "metric_validation",
113
119
  "norm_logging",
114
120
  "print_table",
115
121
  "rlp_sanity_checks",
@@ -0,0 +1,21 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.callbacks.metric_validation import (
6
+ CallbackConfigBase as CallbackConfigBase,
7
+ )
8
+ from nshtrainer.callbacks.metric_validation import MetricConfig as MetricConfig
9
+ from nshtrainer.callbacks.metric_validation import (
10
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
11
+ )
12
+ from nshtrainer.callbacks.metric_validation import (
13
+ callback_registry as callback_registry,
14
+ )
15
+
16
+ __all__ = [
17
+ "CallbackConfigBase",
18
+ "MetricConfig",
19
+ "MetricValidationCallbackConfig",
20
+ "callback_registry",
21
+ ]
@@ -38,6 +38,9 @@ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbac
38
38
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
39
39
  from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
40
40
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
41
+ from nshtrainer.trainer._config import (
42
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
43
+ )
41
44
  from nshtrainer.trainer._config import (
42
45
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
43
46
  )
@@ -164,6 +167,7 @@ __all__ = [
164
167
  "MPIEnvironmentPlugin",
165
168
  "MPSAcceleratorConfig",
166
169
  "MetricConfig",
170
+ "MetricValidationCallbackConfig",
167
171
  "MixedPrecisionPluginConfig",
168
172
  "NormLoggingCallbackConfig",
169
173
  "OnExceptionCheckpointCallbackConfig",
@@ -34,6 +34,9 @@ from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbac
34
34
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
35
35
  from nshtrainer.trainer._config import LoggerConfigBase as LoggerConfigBase
36
36
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
37
+ from nshtrainer.trainer._config import (
38
+ MetricValidationCallbackConfig as MetricValidationCallbackConfig,
39
+ )
37
40
  from nshtrainer.trainer._config import (
38
41
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
39
42
  )
@@ -77,6 +80,7 @@ __all__ = [
77
80
  "LoggerConfig",
78
81
  "LoggerConfigBase",
79
82
  "MetricConfig",
83
+ "MetricValidationCallbackConfig",
80
84
  "NormLoggingCallbackConfig",
81
85
  "OnExceptionCheckpointCallbackConfig",
82
86
  "PluginConfig",
@@ -40,6 +40,7 @@ from ..callbacks.base import CallbackConfigBase
40
40
  from ..callbacks.debug_flag import DebugFlagCallbackConfig
41
41
  from ..callbacks.log_epoch import LogEpochCallbackConfig
42
42
  from ..callbacks.lr_monitor import LearningRateMonitorConfig
43
+ from ..callbacks.metric_validation import MetricValidationCallbackConfig
43
44
  from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
44
45
  from ..callbacks.shared_parameters import SharedParametersCallbackConfig
45
46
  from ..loggers import (
@@ -697,6 +698,10 @@ class TrainerConfig(C.Config):
697
698
  - The trainer is running in fast_dev_run mode.
698
699
  - The trainer is running a sanity check (which happens before starting the training routine).
699
700
  """
701
+ auto_validate_metrics: MetricValidationCallbackConfig | None = (
702
+ MetricValidationCallbackConfig()
703
+ )
704
+ """If enabled, will automatically validate the metrics before starting the training routine."""
700
705
 
701
706
  lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
702
707
  """
@@ -768,6 +773,7 @@ class TrainerConfig(C.Config):
768
773
  yield self.shared_parameters
769
774
  yield self.reduce_lr_on_plateau_sanity_checking
770
775
  yield self.auto_set_debug_flag
776
+ yield self.auto_validate_metrics
771
777
  yield from self.callbacks
772
778
 
773
779
  def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.0.0b43
3
+ Version: 1.0.0b44
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -6,7 +6,7 @@ nshtrainer/_checkpoint/saver.py,sha256=rWl4d2lCTMU4_wt8yZFL2pFQaP9hj5sPgqHMPQ4zu
6
6
  nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
7
7
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
8
8
  nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
9
- nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTwXJI,3577
9
+ nshtrainer/callbacks/__init__.py,sha256=w80d6PGNu3wjUj9NiRGMqCX9NnXD5ZlvbY-DIK4zjPE,3766
10
10
  nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
11
11
  nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
@@ -23,6 +23,7 @@ nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB
23
23
  nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
24
24
  nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
25
25
  nshtrainer/callbacks/lr_monitor.py,sha256=qy_C0R40J0hBAukzBwng5FI2jJUpWuXOi5N6FU6ym3I,1210
26
+ nshtrainer/callbacks/metric_validation.py,sha256=4bMMHVQ7rBbveDiowZS7Wwr77rE8HrerIbo3n9OddPA,2406
26
27
  nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
27
28
  nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
28
29
  nshtrainer/callbacks/rlp_sanity_checks.py,sha256=74BZvV2HLO__ucQXsLXb8eJLUZgRFUNJZ6TL9efMp74,10051
@@ -31,12 +32,12 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
31
32
  nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
32
33
  nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
33
34
  nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
34
- nshtrainer/configs/__init__.py,sha256=MZfcSKhnjtVObBvVv9lu8L2cFTLINP5zcTQvWnz8jdk,14505
35
+ nshtrainer/configs/__init__.py,sha256=0BzCgE1iEJ0Ywmy__mqJZipLQtwZVdz6XK-gHbkA7GY,14650
35
36
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
36
37
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
37
38
  nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
38
39
  nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
39
- nshtrainer/configs/callbacks/__init__.py,sha256=jSWkbsdiu9vdGWTzqkDf-Bo9dXr9RengeNZLzWUhi7Y,4283
40
+ nshtrainer/configs/callbacks/__init__.py,sha256=PB3Jg-8_vMhp-mCFw2_Tqt05drKwHK6Ovl9mb8NNiXs,4506
40
41
  nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
41
42
  nshtrainer/configs/callbacks/base/__init__.py,sha256=wT3RhXttLyf6RFWCIvsoiXcPdfGx5W309WBI18AI5os,278
42
43
  nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=aGJ7vX14YamkMdwYAdPv6XrRnP0aZd5uZ5X0nSLc6IU,1475
@@ -52,6 +53,7 @@ nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=e-vx9Kn-noqw4wPvZw
52
53
  nshtrainer/configs/callbacks/gradient_skipping/__init__.py,sha256=T3eVxxJfnYBrO9WfLiycn4TyWP4vaqJ57yp7Epkg7B4,485
53
54
  nshtrainer/configs/callbacks/log_epoch/__init__.py,sha256=IQ5owYYvyk7fiQP1QXYtncRRJrESuq3rRFhab-II2uE,419
54
55
  nshtrainer/configs/callbacks/lr_monitor/__init__.py,sha256=qejy1AnXNDHmsFuXRAXQQ5B0TcbKzvpaw-I4dv2AXIs,431
56
+ nshtrainer/configs/callbacks/metric_validation/__init__.py,sha256=_YV0EbISkforE_GDlTTVA6Nn2_l13zX3m1ggcbhnAvs,585
55
57
  nshtrainer/configs/callbacks/norm_logging/__init__.py,sha256=j2LrnYEbDGLwJR2lk-jmh-4J_iLEs2HNEoepvJSFLAg,437
56
58
  nshtrainer/configs/callbacks/print_table/__init__.py,sha256=t6fA_dBkUCszUXDJKEdnlBH4oEpfAQqcmAlatTFYIyQ,452
57
59
  nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py,sha256=dlP14Wh-w8zG_B4EtNmCIFzVMhf6bXCJ1O9cJWmEFnA,482
@@ -80,8 +82,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
80
82
  nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
81
83
  nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
82
84
  nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
83
- nshtrainer/configs/trainer/__init__.py,sha256=jYCp4Q9uvutA6NYqfthbREMg09-obD3gHtzEI2Ta-hU,7729
84
- nshtrainer/configs/trainer/_config/__init__.py,sha256=uof_oJfhwjB1pft7KsRdk_RvNj-tE8wcDBEM7X5qtNc,3666
85
+ nshtrainer/configs/trainer/__init__.py,sha256=a8pzGVid52abAVARPbgjaN566H1ZM44FH_x95bsBaGE,7880
86
+ nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
85
87
  nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
86
88
  nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
87
89
  nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
@@ -129,7 +131,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
129
131
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
130
132
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
131
133
  nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
132
- nshtrainer/trainer/_config.py,sha256=QDy6sINVDGEqfHfPTWXSN-06EoEuMSVscHn8fCRTvr0,32981
134
+ nshtrainer/trainer/_config.py,sha256=pCBRtqIC_BzNPqthsDhd7L5_7DG5y8_uVq19lj1mtOM,33311
133
135
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
134
136
  nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
135
137
  nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
@@ -152,6 +154,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
152
154
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
153
155
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
154
156
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
155
- nshtrainer-1.0.0b43.dist-info/METADATA,sha256=ZE3l6CN34ptFgx3SDPfKIgjdV2s3J8qdP729eb58vzo,988
156
- nshtrainer-1.0.0b43.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
157
- nshtrainer-1.0.0b43.dist-info/RECORD,,
157
+ nshtrainer-1.0.0b44.dist-info/METADATA,sha256=u_dApZgfGst9vUiKBgnFQhGB0pBeULPOeGlaQ5-CPnI,988
158
+ nshtrainer-1.0.0b44.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
159
+ nshtrainer-1.0.0b44.dist-info/RECORD,,