nshtrainer 0.12.1__tar.gz → 0.13.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/PKG-INFO +1 -1
  2. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/pyproject.toml +1 -1
  3. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/early_stopping.py +23 -62
  4. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/README.md +0 -0
  5. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/__init__.py +0 -0
  6. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/_checkpoint/loader.py +0 -0
  7. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  8. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
  9. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  10. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  11. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  12. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  13. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/__init__.py +0 -0
  14. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  15. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/actsave.py +0 -0
  16. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/base.py +0 -0
  17. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  18. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  19. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  20. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  21. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  22. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/ema.py +0 -0
  23. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  24. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  25. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/interval.py +0 -0
  26. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  27. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  28. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  29. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  30. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/timer.py +0 -0
  31. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  32. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/data/__init__.py +0 -0
  33. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  34. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/data/transform.py +0 -0
  35. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/__init__.py +0 -0
  36. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/_experimental.py +0 -0
  37. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/actsave.py +0 -0
  38. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/callbacks.py +0 -0
  39. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/config.py +0 -0
  40. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/data.py +0 -0
  41. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/log.py +0 -0
  42. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  43. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/model.py +0 -0
  44. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/nn.py +0 -0
  45. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/optimizer.py +0 -0
  46. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/runner.py +0 -0
  47. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/snapshot.py +0 -0
  48. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/snoop.py +0 -0
  49. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/trainer.py +0 -0
  50. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/typecheck.py +0 -0
  51. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/ll/util.py +0 -0
  52. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  53. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  54. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  55. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  56. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/metrics/__init__.py +0 -0
  57. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/metrics/_config.py +0 -0
  58. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/__init__.py +0 -0
  59. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/base.py +0 -0
  60. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/config.py +0 -0
  61. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/modules/callback.py +0 -0
  62. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/modules/debug.py +0 -0
  63. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/modules/distributed.py +0 -0
  64. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/modules/logger.py +0 -0
  65. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/modules/profiler.py +0 -0
  66. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  67. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  68. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/nn/__init__.py +0 -0
  69. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/nn/mlp.py +0 -0
  70. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/nn/module_dict.py +0 -0
  71. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/nn/module_list.py +0 -0
  72. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  73. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/optimizer.py +0 -0
  74. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/runner.py +0 -0
  75. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  76. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/trainer/__init__.py +0 -0
  77. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  78. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  79. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  80. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/trainer/trainer.py +0 -0
  81. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/util/_environment_info.py +0 -0
  82. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/util/_useful_types.py +0 -0
  83. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/util/environment.py +0 -0
  84. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/util/seed.py +0 -0
  85. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/util/slurm.py +0 -0
  86. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/util/typed.py +0 -0
  87. {nshtrainer-0.12.1 → nshtrainer-0.13.1}/src/nshtrainer/util/typing_utils.py +0 -0
--- nshtrainer-0.12.1/PKG-INFO
+++ nshtrainer-0.13.1/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.12.1
+Version: 0.13.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
--- nshtrainer-0.12.1/pyproject.toml
+++ nshtrainer-0.13.1/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.12.1"
+version = "0.13.1"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
--- nshtrainer-0.12.1/src/nshtrainer/callbacks/early_stopping.py
+++ nshtrainer-0.13.1/src/nshtrainer/callbacks/early_stopping.py
@@ -8,6 +8,7 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
 from typing_extensions import override
 
+from ..metrics._config import MetricConfig
 from .base import CallbackConfigBase
 
 log = logging.getLogger(__name__)
@@ -16,18 +17,12 @@ log = logging.getLogger(__name__)
 class EarlyStoppingConfig(CallbackConfigBase):
     name: Literal["early_stopping"] = "early_stopping"
 
-    monitor: str | None = None
+    metric: MetricConfig | None = None
     """
     The metric to monitor for early stopping.
     If None, the primary metric will be used.
     """
 
-    mode: Literal["min", "max"] | None = None
-    """
-    The mode for the metric to monitor for early stopping.
-    If None, the primary metric mode will be used.
-    """
-
     patience: int
     """
     Number of epochs with no improvement after which training will be stopped.
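
In practice, this collapses the old `monitor`/`mode` pair into a single `MetricConfig` value. A minimal before/after sketch of configuring the callback; the `MetricConfig` constructor arguments shown here are assumptions, since its definition lives in src/nshtrainer/metrics/_config.py and is not part of this diff:

    # 0.12.1: monitor and mode were separate, independently optional fields
    config = EarlyStoppingConfig(monitor="val/loss", mode="min", patience=5)

    # 0.13.1: one MetricConfig carries both (argument names assumed)
    config = EarlyStoppingConfig(
        metric=MetricConfig(validation_monitor="val/loss", mode="min"),
        patience=5,
    )

Leaving `metric` as None still works when the root config defines a primary metric, as the `create_callbacks` hunk below shows.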
@@ -52,64 +47,30 @@ class EarlyStoppingConfig(CallbackConfigBase):
 
     @override
     def create_callbacks(self, root_config):
-        monitor = self.monitor
-        mode = self.mode
-        if monitor is None:
-            assert mode is None, "If `monitor` is not provided, `mode` must be None."
-
-            primary_metric = root_config.primary_metric
-            if primary_metric is None:
-                raise ValueError(
-                    "No primary metric is set, so `monitor` must be provided in `early_stopping`."
-                )
-            monitor = primary_metric.validation_monitor
-            mode = primary_metric.mode
-
-        if mode is None:
-            mode = "min"
-
-        yield EarlyStopping(
-            monitor=monitor,
-            mode=mode,
-            patience=self.patience,
-            min_delta=self.min_delta,
-            min_lr=self.min_lr,
-            strict=self.strict,
-        )
+        if (metric := self.metric) is None and (
+            metric := root_config.primary_metric
+        ) is None:
+            raise ValueError(
+                "Either `metric` or `root_config.primary_metric` must be set."
+            )
+
+        yield EarlyStopping(self, metric)
 
 
 class EarlyStopping(_EarlyStopping):
-    def __init__(
-        self,
-        monitor: str,
-        min_delta: float = 0,
-        min_lr: float | None = None,
-        patience: int = 3,
-        verbose: bool = True,
-        mode: str = "min",
-        strict: bool = True,
-        check_finite: bool = True,
-        stopping_threshold: float | None = None,
-        divergence_threshold: float | None = None,
-        check_on_train_epoch_end: bool | None = None,
-        log_rank_zero_only: bool = False,
-    ):
+    def __init__(self, config: EarlyStoppingConfig, metric: MetricConfig):
+        self.config = config
+        self.metric = metric
+        del config, metric
+
         super().__init__(
-            monitor,
-            min_delta,
-            patience,
-            verbose,
-            mode,
-            strict,
-            check_finite,
-            stopping_threshold,
-            divergence_threshold,
-            check_on_train_epoch_end,
-            log_rank_zero_only,
+            monitor=self.metric.validation_monitor,
+            mode=self.metric.mode,
+            patience=self.config.patience,
+            min_delta=self.config.min_delta,
+            strict=self.config.strict,
         )
 
-        self.min_lr = min_lr
-
     @override
     @staticmethod
     def _log_info(
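
Two things are worth noting about this hunk. First, the walrus-operator chain in `create_callbacks` is just a compact fallback; an equivalent sketch of the same resolution logic, spelled out without assignment expressions:

    # Equivalent to the new create_callbacks body, written without :=
    metric = self.metric
    if metric is None:
        metric = root_config.primary_metric
    if metric is None:
        raise ValueError(
            "Either `metric` or `root_config.primary_metric` must be set."
        )
    yield EarlyStopping(self, metric)

Second, the new `super().__init__` call forwards only `monitor`, `mode`, `patience`, `min_delta`, and `strict`, so `verbose`, `check_finite`, `stopping_threshold`, `divergence_threshold`, `check_on_train_epoch_end`, and `log_rank_zero_only` now fall back to Lightning's defaults rather than being configurable through this wrapper (note that the old wrapper defaulted `verbose` to True).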
@@ -152,7 +113,7 @@ class EarlyStopping(_EarlyStopping):
     def _evaluate_stopping_criteria_min_lr(
         self, trainer: Trainer
     ) -> tuple[bool, str | None]:
-        if self.min_lr is None:
+        if self.config.min_lr is None:
             return False, None
 
         # Get the maximum LR across all param groups in all optimizers
@@ -167,13 +128,13 @@ class EarlyStopping(_EarlyStopping):
             return False, None
 
         # If the maximum LR is less than the minimum LR, stop training
-        if model_max_lr >= self.min_lr:
+        if model_max_lr >= self.config.min_lr:
             return False, None
 
         return True, (
             "Stopping threshold reached: "
             f"The maximum LR of the model across all param groups is {model_max_lr:.2e} "
-            f"which is less than the minimum LR {self.min_lr:.2e}"
+            f"which is less than the minimum LR {self.config.min_lr:.2e}"
         )
 
     def on_early_stopping(self, trainer: Trainer):
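
The code that computes `model_max_lr` sits between these two hunks and is unchanged, so it does not appear in the diff. For orientation only, a hypothetical sketch of what the commented "Get the maximum LR across all param groups in all optimizers" step could look like; this is a guess under that description, not the package's actual source:

    # Hypothetical sketch: scan every param group of every optimizer
    # managed by the Lightning Trainer and take the largest learning rate.
    model_max_lr = max(
        float(group["lr"])
        for optimizer in trainer.optimizers
        for group in optimizer.param_groups
    )

The criterion itself is unchanged apart from reading `min_lr` from `self.config`: training stops once the largest learning rate across all param groups has decayed below `min_lr`, which typically matters when a reduce-on-plateau scheduler keeps shrinking the LR without improvement.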