nshtrainer 0.12.0__tar.gz → 0.13.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/PKG-INFO +1 -1
  2. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/pyproject.toml +1 -1
  3. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/__init__.py +3 -1
  4. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/early_stopping.py +60 -31
  5. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/config.py +1 -67
  6. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/README.md +0 -0
  7. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/__init__.py +0 -0
  8. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
  9. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  10. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  11. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  12. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  13. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  14. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  15. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  16. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  17. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/base.py +0 -0
  18. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  19. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  20. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  21. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  22. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  23. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/ema.py +0 -0
  24. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  25. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  26. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/interval.py +0 -0
  27. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  28. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  29. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  30. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  31. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/timer.py +0 -0
  32. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  33. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/data/__init__.py +0 -0
  34. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  35. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/data/transform.py +0 -0
  36. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/__init__.py +0 -0
  37. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/_experimental.py +0 -0
  38. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/actsave.py +0 -0
  39. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/callbacks.py +0 -0
  40. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/config.py +0 -0
  41. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/data.py +0 -0
  42. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/log.py +0 -0
  43. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  44. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/model.py +0 -0
  45. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/nn.py +0 -0
  46. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/optimizer.py +0 -0
  47. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/runner.py +0 -0
  48. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/snapshot.py +0 -0
  49. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/snoop.py +0 -0
  50. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/trainer.py +0 -0
  51. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/typecheck.py +0 -0
  52. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/ll/util.py +0 -0
  53. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  54. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  55. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  56. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  57. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/metrics/__init__.py +0 -0
  58. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/metrics/_config.py +0 -0
  59. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/__init__.py +0 -0
  60. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/base.py +0 -0
  61. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/modules/callback.py +0 -0
  62. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/modules/debug.py +0 -0
  63. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/modules/distributed.py +0 -0
  64. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/modules/logger.py +0 -0
  65. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/modules/profiler.py +0 -0
  66. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  67. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  68. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/nn/__init__.py +0 -0
  69. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/nn/mlp.py +0 -0
  70. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/nn/module_dict.py +0 -0
  71. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/nn/module_list.py +0 -0
  72. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  73. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/optimizer.py +0 -0
  74. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/runner.py +0 -0
  75. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  76. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/trainer/__init__.py +0 -0
  77. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  78. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  79. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  80. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/trainer/trainer.py +0 -0
  81. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/util/_environment_info.py +0 -0
  82. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/util/_useful_types.py +0 -0
  83. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/util/environment.py +0 -0
  84. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/util/seed.py +0 -0
  85. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/util/slurm.py +0 -0
  86. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/util/typed.py +0 -0
  87. {nshtrainer-0.12.0 → nshtrainer-0.13.0}/src/nshtrainer/util/typing_utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.12.0
+Version: 0.13.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.12.0"
+version = "0.13.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
src/nshtrainer/callbacks/__init__.py
@@ -13,6 +13,7 @@ from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
 from .early_stopping import EarlyStopping as EarlyStopping
+from .early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
 from .ema import EMA as EMA
 from .ema import EMAConfig as EMAConfig
 from .finite_checks import FiniteChecksCallback as FiniteChecksCallback
@@ -34,7 +35,8 @@ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
 from .wandb_watch import WandbWatchConfig as WandbWatchConfig

 CallbackConfig = Annotated[
-    ThroughputMonitorConfig
+    EarlyStoppingConfig
+    | ThroughputMonitorConfig
     | EpochTimerConfig
     | PrintTableMetricsConfig
     | FiniteChecksConfig
src/nshtrainer/callbacks/early_stopping.py
@@ -1,5 +1,6 @@
 import logging
 import math
+from typing import Literal

 from lightning.fabric.utilities.rank_zero import _get_rank
 from lightning.pytorch import Trainer
@@ -7,41 +8,69 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
 from typing_extensions import override

+from ..metrics._config import MetricConfig
+from .base import CallbackConfigBase
+
 log = logging.getLogger(__name__)


+class EarlyStoppingConfig(CallbackConfigBase):
+    name: Literal["early_stopping"] = "early_stopping"
+
+    metric: MetricConfig | None = None
+    """
+    The metric to monitor for early stopping.
+    If None, the primary metric will be used.
+    """
+
+    patience: int
+    """
+    Number of epochs with no improvement after which training will be stopped.
+    """
+
+    min_delta: float = 1.0e-8
+    """
+    Minimum change in the monitored quantity to qualify as an improvement.
+    """
+
+    min_lr: float | None = None
+    """
+    Minimum learning rate. If the learning rate of the model is less than this value,
+    the training will be stopped.
+    """
+
+    strict: bool = True
+    """
+    Whether to enforce that the monitored quantity must improve by at least `min_delta`
+    to qualify as an improvement.
+    """
+
+    @override
+    def create_callbacks(self, root_config):
+        if (metric := self.metric) is None and (
+            metric := root_config.primary_metric
+        ) is None:
+            raise ValueError(
+                "Either `metric` or `root_config.primary_metric` must be set."
+            )
+
+        yield EarlyStopping(self, metric)
+
+
 class EarlyStopping(_EarlyStopping):
-    def __init__(
-        self,
-        monitor: str,
-        min_delta: float = 0,
-        min_lr: float | None = None,
-        patience: int = 3,
-        verbose: bool = True,
-        mode: str = "min",
-        strict: bool = True,
-        check_finite: bool = True,
-        stopping_threshold: float | None = None,
-        divergence_threshold: float | None = None,
-        check_on_train_epoch_end: bool | None = None,
-        log_rank_zero_only: bool = False,
-    ):
+    def __init__(self, config: EarlyStoppingConfig, metric: MetricConfig):
+        self.config = config
+        self.metric = metric
+        del config, metric
+
         super().__init__(
-            monitor,
-            min_delta,
-            patience,
-            verbose,
-            mode,
-            strict,
-            check_finite,
-            stopping_threshold,
-            divergence_threshold,
-            check_on_train_epoch_end,
-            log_rank_zero_only,
+            monitor=self.metric.validation_monitor,
+            mode=self.metric.mode,
+            patience=self.patience,
+            min_delta=self.min_delta,
+            strict=self.strict,
         )

-        self.min_lr = min_lr
-
     @override
     @staticmethod
     def _log_info(
@@ -84,7 +113,7 @@ class EarlyStopping(_EarlyStopping):
     def _evaluate_stopping_criteria_min_lr(
         self, trainer: Trainer
     ) -> tuple[bool, str | None]:
-        if self.min_lr is None:
+        if self.config.min_lr is None:
             return False, None

         # Get the maximum LR across all param groups in all optimizers
@@ -99,13 +128,13 @@ class EarlyStopping(_EarlyStopping):
             return False, None

         # If the maximum LR is less than the minimum LR, stop training
-        if model_max_lr >= self.min_lr:
+        if model_max_lr >= self.config.min_lr:
             return False, None

         return True, (
             "Stopping threshold reached: "
             f"The maximum LR of the model across all param groups is {model_max_lr:.2e} "
-            f"which is less than the minimum LR {self.min_lr:.2e}"
+            f"which is less than the minimum LR {self.config.min_lr:.2e}"
         )

     def on_early_stopping(self, trainer: Trainer):
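For readers of this diff, here is a minimal usage sketch of the new config-driven early-stopping API introduced above. The MetricConfig constructor arguments and the `root_config` object are illustrative assumptions, not taken from this diff; the diff only shows that MetricConfig exposes `validation_monitor` and `mode`.

from nshtrainer.callbacks import EarlyStoppingConfig
from nshtrainer.metrics._config import MetricConfig

# Sketch only: the MetricConfig field names used here are assumed.
es_config = EarlyStoppingConfig(
    metric=MetricConfig(name="val/loss", mode="min"),
    patience=10,
)

# In 0.13.0, create_callbacks is a generator that yields the EarlyStopping
# callback. When `metric` is None, it falls back to root_config.primary_metric
# (assuming `root_config` is the trainer's BaseConfig instance).
callbacks = list(es_config.create_callbacks(root_config))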
src/nshtrainer/model/config.py
@@ -38,6 +38,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
+    EarlyStoppingConfig,
     LastCheckpointCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
     WandbWatchConfig,
@@ -1043,73 +1044,6 @@ class LightningTrainerKwargs(TypedDict, total=False):
     """


-class EarlyStoppingConfig(CallbackConfigBase):
-    monitor: str | None = None
-    """
-    The metric to monitor for early stopping.
-    If None, the primary metric will be used.
-    """
-
-    mode: Literal["min", "max"] | None = None
-    """
-    The mode for the metric to monitor for early stopping.
-    If None, the primary metric mode will be used.
-    """
-
-    patience: int
-    """
-    Number of epochs with no improvement after which training will be stopped.
-    """
-
-    min_delta: float = 1.0e-8
-    """
-    Minimum change in the monitored quantity to qualify as an improvement.
-    """
-
-    min_lr: float | None = None
-    """
-    Minimum learning rate. If the learning rate of the model is less than this value,
-    the training will be stopped.
-    """
-
-    strict: bool = True
-    """
-    Whether to enforce that the monitored quantity must improve by at least `min_delta`
-    to qualify as an improvement.
-    """
-
-    @override
-    def create_callbacks(self, root_config: "BaseConfig"):
-        from ..callbacks.early_stopping import EarlyStopping
-
-        monitor = self.monitor
-        mode = self.mode
-        if monitor is None:
-            assert mode is None, "If `monitor` is not provided, `mode` must be None."
-
-            primary_metric = root_config.primary_metric
-            if primary_metric is None:
-                raise ValueError(
-                    "No primary metric is set, so `monitor` must be provided in `early_stopping`."
-                )
-            monitor = primary_metric.validation_monitor
-            mode = primary_metric.mode
-
-        if mode is None:
-            mode = "min"
-
-        return [
-            EarlyStopping(
-                monitor=monitor,
-                mode=mode,
-                patience=self.patience,
-                min_delta=self.min_delta,
-                min_lr=self.min_lr,
-                strict=self.strict,
-            )
-        ]
-
-
 class SanityCheckingConfig(C.Config):
     reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
     """