nshtrainer 1.0.0b46__py3-none-any.whl → 1.0.0b48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_callback.py CHANGED
@@ -4,46 +4,46 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
 import torch
+from lightning.pytorch import LightningModule
 from lightning.pytorch.callbacks import Callback as _LightningCallback
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch.optim import Optimizer
 
 if TYPE_CHECKING:
-    from .model import LightningModuleBase
     from .trainer import Trainer
 
 
 class NTCallbackBase(_LightningCallback):
     def setup(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase, stage: str
+        self, trainer: Trainer, pl_module: LightningModule, stage: str
     ) -> None:
         """Called when fit, validate, test, predict, or tune begins."""
 
     def teardown(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase, stage: str
+        self, trainer: Trainer, pl_module: LightningModule, stage: str
     ) -> None:
         """Called when fit, validate, test, predict, or tune ends."""
 
-    def on_fit_start(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when fit begins."""
 
-    def on_fit_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when fit ends."""
 
     def on_sanity_check_start(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation sanity check starts."""
 
     def on_sanity_check_end(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation sanity check ends."""
 
     def on_train_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         batch: Any,
         batch_idx: int,
     ) -> None:
@@ -52,7 +52,7 @@ class NTCallbackBase(_LightningCallback):
     def on_train_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
@@ -66,12 +66,12 @@ class NTCallbackBase(_LightningCallback):
         """
 
     def on_train_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the train epoch begins."""
 
     def on_train_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the train epoch ends.
 
@@ -102,39 +102,39 @@ class NTCallbackBase(_LightningCallback):
         """
 
     def on_validation_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the val epoch begins."""
 
     def on_validation_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the val epoch ends."""
 
     def on_test_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the test epoch begins."""
 
     def on_test_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the test epoch ends."""
 
     def on_predict_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict epoch begins."""
 
     def on_predict_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict epoch ends."""
 
     def on_validation_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         batch: Any,
         batch_idx: int,
         dataloader_idx: int = 0,
@@ -144,7 +144,7 @@ class NTCallbackBase(_LightningCallback):
     def on_validation_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
@@ -155,7 +155,7 @@ class NTCallbackBase(_LightningCallback):
     def on_test_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         batch: Any,
         batch_idx: int,
         dataloader_idx: int = 0,
@@ -165,7 +165,7 @@ class NTCallbackBase(_LightningCallback):
     def on_test_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
@@ -176,7 +176,7 @@ class NTCallbackBase(_LightningCallback):
     def on_predict_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         batch: Any,
         batch_idx: int,
         dataloader_idx: int = 0,
@@ -186,7 +186,7 @@ class NTCallbackBase(_LightningCallback):
     def on_predict_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         outputs: Any,
         batch: Any,
         batch_idx: int,
@@ -194,40 +194,40 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the predict batch ends."""
 
-    def on_train_start(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the train begins."""
 
-    def on_train_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the train ends."""
 
     def on_validation_start(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation loop begins."""
 
     def on_validation_end(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation loop ends."""
 
-    def on_test_start(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the test begins."""
 
-    def on_test_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the test ends."""
 
     def on_predict_start(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict begins."""
 
-    def on_predict_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when predict ends."""
 
     def on_exception(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         exception: BaseException,
     ) -> None:
         """Called when any trainer execution is interrupted by an exception."""
@@ -253,7 +253,7 @@ class NTCallbackBase(_LightningCallback):
     def on_save_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         checkpoint: dict[str, Any],
     ) -> None:
         r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save.
@@ -268,7 +268,7 @@ class NTCallbackBase(_LightningCallback):
     def on_load_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         checkpoint: dict[str, Any],
     ) -> None:
         r"""Called when loading a model checkpoint, use to reload state.
@@ -281,19 +281,19 @@ class NTCallbackBase(_LightningCallback):
         """
 
     def on_before_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase, loss: torch.Tensor
+        self, trainer: Trainer, pl_module: LightningModule, loss: torch.Tensor
     ) -> None:
         """Called before ``loss.backward()``."""
 
     def on_after_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
-        self, trainer: Trainer, pl_module: LightningModuleBase
+        self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""
 
     def on_before_optimizer_step(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         optimizer: Optimizer,
     ) -> None:
         """Called before ``optimizer.step()``."""
@@ -301,7 +301,7 @@ class NTCallbackBase(_LightningCallback):
     def on_before_zero_grad(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
-        pl_module: LightningModuleBase,
+        pl_module: LightningModule,
         optimizer: Optimizer,
     ) -> None:
         """Called before ``optimizer.zero_grad()``."""
@@ -310,15 +310,15 @@ class NTCallbackBase(_LightningCallback):
         self,
         ckpt_path: Path,
         metadata_path: Path | None,
-        trainer: "Trainer",
-        pl_module: "LightningModuleBase",
+        trainer: Trainer,
+        pl_module: LightningModule,
     ) -> None:
         """Called after a checkpoint is saved."""
         pass
 
 
 def _call_on_checkpoint_saved(
-    trainer: "Trainer",
+    trainer: Trainer,
     ckpt_path: str | Path,
     metadata_path: str | Path | None,
 ):
@@ -333,5 +333,5 @@ def _call_on_checkpoint_saved(
             ckpt_path,
             metadata_path,
             trainer,
-            trainer._base_module,
+            trainer.lightning_module,
         )
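Taken together, these hunks loosen `NTCallbackBase`: every hook is now typed against `lightning.pytorch.LightningModule` instead of nshtrainer's own `LightningModuleBase`, so the callbacks can run against modules that never subclass nshtrainer's base class. A minimal sketch of what that enables; only `NTCallbackBase` and the hook signature come from the diff above, `MyCallback` is illustrative:

```python
# Hypothetical sketch: a custom callback written against the b48 signatures.
from lightning.pytorch import LightningModule

from nshtrainer._callback import NTCallbackBase


class MyCallback(NTCallbackBase):
    def on_fit_start(self, trainer, pl_module: LightningModule) -> None:
        # pl_module may now be any LightningModule, not just LightningModuleBase.
        print(f"Starting fit for {type(pl_module).__name__}")
```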
nshtrainer/_checkpoint/metadata.py CHANGED
@@ -16,6 +16,18 @@ from ..util.path import compute_file_checksum, try_symlink_or_copy
 if TYPE_CHECKING:
     from ..trainer.trainer import Trainer
 
+try:
+    from pydantic import BaseModel
+
+    _HAS_PYDANTIC = True
+except ImportError:
+    if not TYPE_CHECKING:
+        BaseModel = object
+    else:
+        from pydantic import BaseModel
+    _HAS_PYDANTIC = False
+
+
 log = logging.getLogger(__name__)
 
 
@@ -27,10 +39,10 @@ def _full_hparams_dict(trainer: Trainer):
     hparams["trainer"] = trainer.hparams.model_dump(mode="json")
 
     if trainer.lightning_module is not None:
-        from ..model import LightningModuleBase
-
-        if isinstance(trainer.lightning_module, LightningModuleBase):
-            hparams["model"] = trainer.lightning_module.hparams.model_dump(mode="json")
+        model_hparams = trainer.lightning_module.hparams
+        if _HAS_PYDANTIC and isinstance(model_hparams, BaseModel):
+            model_hparams = model_hparams.model_dump(mode="json")
+        hparams["model"] = dict(model_hparams)
 
     return hparams
 
nshtrainer/callbacks/checkpoint/best_checkpoint.py CHANGED
@@ -51,7 +51,7 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
 class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):
     @property
     def _metric_name_normalized(self):
-        return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")
+        return self.metric.monitor.replace("/", "_").replace(" ", "_").replace(".", "_")
 
     @override
     def __init__(
@@ -69,12 +69,12 @@ class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):
 
     @override
     def default_filename(self):
-        return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
+        return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.monitor}}}"
 
     @override
     def topk_sort_key(self, metadata: CheckpointMetadata):
         return metadata.metrics.get(
-            self.metric.validation_monitor,
+            self.metric.monitor,
             float("-inf" if self.metric.mode == "max" else "inf"),
         )
 
nshtrainer/callbacks/early_stopping.py CHANGED
@@ -68,7 +68,7 @@ class EarlyStoppingCallback(_EarlyStopping):
         del config, metric
 
         super().__init__(
-            monitor=self.metric.validation_monitor,
+            monitor=self.metric.monitor,
             mode=self.metric.mode,
             patience=self.config.patience,
             min_delta=self.config.min_delta,
nshtrainer/callbacks/metric_validation.py CHANGED
@@ -55,14 +55,14 @@ class MetricValidationCallback(Callback):
         self.metrics = metrics
 
     def _check_metrics(self, trainer: Trainer):
-        metric_names = ", ".join(metric.validation_monitor for metric in self.metrics)
+        metric_names = ", ".join(metric.monitor for metric in self.metrics)
         log.info(f"Validating metrics: {metric_names}...")
         logged_metrics = set(trainer.logged_metrics.keys())
 
         invalid_metrics: list[str] = []
         for metric in self.metrics:
-            if metric.validation_monitor not in logged_metrics:
-                invalid_metrics.append(metric.validation_monitor)
+            if metric.monitor not in logged_metrics:
+                invalid_metrics.append(metric.monitor)
 
         if invalid_metrics:
             msg = (
nshtrainer/callbacks/rlp_sanity_checks.py CHANGED
@@ -171,7 +171,7 @@ class CustomRLPImplementation(Protocol):
     __reduce_lr_on_plateau__: bool
 
 
-class _RLPSanityCheckModuleMixin(LightningModule):
+class RLPSanityCheckModuleMixin(LightningModule):
     def reduce_lr_on_plateau_config(
         self,
         lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
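Dropping the leading underscore promotes the mixin to public API (the `model/base.py` section below picks up the rename). A sketch of composing it directly, assuming a user-defined module:

```python
# Hypothetical sketch: the mixin can now be imported without touching a
# private name. MyModule is illustrative.
from lightning.pytorch import LightningModule

from nshtrainer.callbacks.rlp_sanity_checks import RLPSanityCheckModuleMixin


class MyModule(RLPSanityCheckModuleMixin, LightningModule):
    """Gains the reduce_lr_on_plateau_config() sanity-check hook shown above."""
```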
nshtrainer/data/datamodule.py CHANGED
@@ -11,13 +11,13 @@ from lightning.pytorch import LightningDataModule
 from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..model.mixins.callback import CallbackRegistrarModuleMixin
-from ..model.mixins.debug import _DebugModuleMixin
+from ..model.mixins.debug import DebugModuleMixin
 
 THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
 
 
 class LightningDataModuleBase(
-    _DebugModuleMixin,
+    DebugModuleMixin,
     CallbackRegistrarModuleMixin,
     LightningDataModule,
     ABC,
nshtrainer/loggers/__init__.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from typing import Annotated
 
-import nshconfig as C
 from typing_extensions import TypeAliasType
 
 from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
nshtrainer/lr_scheduler/base.py CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations
 import math
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
-from typing import TYPE_CHECKING, Literal
+from typing import Literal
 
 import nshconfig as C
+from lightning.pytorch import LightningModule
 from lightning.pytorch.utilities.types import (
     LRSchedulerConfigType,
     LRSchedulerTypeUnion,
@@ -13,9 +14,6 @@ from lightning.pytorch.utilities.types import (
 from torch.optim import Optimizer
 from typing_extensions import Never, NotRequired, TypedDict
 
-if TYPE_CHECKING:
-    from ..model.base import LightningModuleBase
-
 
 class LRSchedulerMetadata(TypedDict):
     interval: Literal["epoch", "step"]
@@ -44,13 +42,13 @@ class LRSchedulerConfigBase(C.Config, ABC):
 
     @abstractmethod
     def create_scheduler_impl(
-        self, optimizer: Optimizer, lightning_module: LightningModuleBase
+        self, optimizer: Optimizer, lightning_module: LightningModule
    ) -> LRSchedulerTypeUnion | LRSchedulerConfigType: ...
 
     def create_scheduler(
         self,
         optimizer: Optimizer,
-        lightning_module: LightningModuleBase,
+        lightning_module: LightningModule,
         lr: Never
         | None = None,  # Backward compatibility, should be removed in the future
     ) -> LRSchedulerConfigType:
@@ -87,7 +85,7 @@ class LRSchedulerConfigBase(C.Config, ABC):
 
         return scheduler
 
-    def compute_num_steps_per_epoch(self, lightning_module: LightningModuleBase) -> int:
+    def compute_num_steps_per_epoch(self, lightning_module: LightningModule) -> int:
         trainer = lightning_module.trainer
         # Use the Lightning trainer to convert the epoch-based values to step-based values
         _ = trainer.estimated_stepping_batches
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py CHANGED
@@ -49,13 +49,13 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
         if (metric := self.metric) is None:
             from ..trainer import Trainer
 
-            assert isinstance(
-                trainer := lightning_module.trainer, Trainer
-            ), "The trainer must be a `nshtrainer.Trainer` instance."
+            assert isinstance(trainer := lightning_module.trainer, Trainer), (
+                "The trainer must be a `nshtrainer.Trainer` instance."
+            )
 
-            assert (
-                metric := trainer.hparams.primary_metric
-            ) is not None, "Primary metric must be provided if metric is not specified."
+            assert (metric := trainer.hparams.primary_metric) is not None, (
+                "Primary metric must be provided if metric is not specified."
+            )
 
         lr_scheduler = ReduceLROnPlateau(
             optimizer,
@@ -70,7 +70,7 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
         )
         return {
             "scheduler": lr_scheduler,
-            "monitor": metric.validation_monitor,
+            "monitor": metric.monitor,
         }
 
     @override
nshtrainer/metrics/_config.py CHANGED
@@ -7,8 +7,8 @@ import nshconfig as C
 
 
 class MetricConfig(C.Config):
-    name: str
-    """The name of the primary metric."""
+    monitor: str
+    """The name of the metric to monitor."""
 
     mode: Literal["min", "max"]
     """
@@ -17,23 +17,6 @@ class MetricConfig(C.Config):
     - "max" for metrics that should be maximized (e.g., accuracy)
     """
 
-    @property
-    def validation_monitor(self) -> str:
-        return f"val/{self.name}"
-
-    def __post_init__(self):
-        for split in ("train", "val", "test", "predict"):
-            if self.name.startswith(f"{split}/"):
-                raise ValueError(
-                    f"Primary metric name should not start with '{split}/'. "
-                    f"Just use '{self.name[len(split) + 1:]}' instead. "
-                    "The split name is automatically added depending on the context."
-                )
-
-    @classmethod
-    def loss(cls, mode: Literal["min", "max"] = "min"):
-        return cls(name="loss", mode=mode)
-
     @property
     def best(self):
         return builtins.min if self.mode == "min" else builtins.max
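This is the schema change behind the `validation_monitor` → `monitor` edits in the callback and scheduler files above: `MetricConfig.name` becomes `MetricConfig.monitor`, and the `validation_monitor` property (which prepended `val/`), the `__post_init__` prefix check, and the `loss()` constructor are removed. Callers now spell out the fully qualified logged-metric name themselves. A hedged migration sketch, assuming `MetricConfig` remains re-exported from `nshtrainer.metrics` as in b46:

```python
from nshtrainer.metrics import MetricConfig

# 1.0.0b46: the "val/" prefix was implied via validation_monitor.
#   metric = MetricConfig(name="loss", mode="min")  # monitored "val/loss"

# 1.0.0b48: pass the full logged name explicitly.
metric = MetricConfig(monitor="val/loss", mode="min")
```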
nshtrainer/model/base.py CHANGED
@@ -15,9 +15,9 @@ from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override
 
-from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
+from ..callbacks.rlp_sanity_checks import RLPSanityCheckModuleMixin
 from .mixins.callback import CallbackModuleMixin
-from .mixins.debug import _DebugModuleMixin
+from .mixins.debug import DebugModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
@@ -54,81 +54,14 @@ VALID_REDUCE_OPS = (
 
 
 class LightningModuleBase(
-    _DebugModuleMixin,
-    _RLPSanityCheckModuleMixin,
+    DebugModuleMixin,
+    RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
     CallbackModuleMixin,
     LightningModule,
     ABC,
     Generic[THparams],
 ):
-    # region Debug
-    @property
-    def debug(self) -> bool:
-        if torch.jit.is_scripting():
-            return False
-
-        if (trainer := self._trainer) is None:
-            return False
-
-        from ..trainer import Trainer
-
-        if not isinstance(trainer, Trainer):
-            return False
-
-        return trainer.debug
-
-    @debug.setter
-    def debug(self, value: bool):
-        if torch.jit.is_scripting():
-            return
-
-        if (trainer := self._trainer) is None:
-            return
-
-        from ..trainer import Trainer
-
-        if not isinstance(trainer, Trainer):
-            return
-
-        trainer.debug = value
-
-    @torch.jit.unused
-    def breakpoint(self, rank_zero_only: bool = True):
-        if (
-            not rank_zero_only
-            or not torch.distributed.is_initialized()
-            or torch.distributed.get_rank() == 0
-        ):
-            breakpoint()
-
-        if rank_zero_only and torch.distributed.is_initialized():
-            _ = torch.distributed.barrier()
-
-    @torch.jit.unused
-    def ensure_finite(
-        self,
-        tensor: torch.Tensor,
-        name: str | None = None,
-        throw: bool = False,
-    ):
-        name_parts: list[str] = ["Tensor"]
-        if name is not None:
-            name_parts.append(name)
-        name = " ".join(name_parts)
-
-        not_finite = ~torch.isfinite(tensor)
-        if not_finite.any():
-            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
-            if throw:
-                raise RuntimeError(msg)
-            else:
-                log.warning(msg)
-            return False
-        return True
-
-    # endregion
-
     # region Profiler
     @property
     def profiler(self) -> Profiler:
nshtrainer/model/mixins/debug.py CHANGED
@@ -28,7 +28,7 @@ def _trainer(module: Any):
     return trainer
 
 
-class _DebugModuleMixin:
+class DebugModuleMixin:
     @property
     def nshtrainer_or_none(self):
         return _trainer(self)
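Read together with the `model/base.py` hunks above, this is a deduplication plus a visibility change: `LightningModuleBase` drops its inline `debug` / `breakpoint` / `ensure_finite` region (those helpers appear to overlap with what `DebugModuleMixin` already provides; debug.py's size changes by a single byte in the RECORD below, consistent with only the underscore being dropped), and the mixin itself goes public. Downstream code should only need the import path updated:

```python
# Before (1.0.0b46):
#   from nshtrainer.model.mixins.debug import _DebugModuleMixin
# After (1.0.0b48):
from nshtrainer.model.mixins.debug import DebugModuleMixin
```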
nshtrainer/model/mixins/logger.py CHANGED
@@ -54,6 +54,12 @@ class _LogContextKwargs:
         d = dataclasses.asdict(self)
         for field in self.__ignore_fields__:
             d.pop(field, None)
+
+        # Pop all None values
+        for k in list(d.keys()):
+            if d[k] is None:
+                d.pop(k)
+
         return d
 
 
@@ -134,18 +140,18 @@ class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
         self,
         name: str,
         value: _METRIC,
-        prog_bar: bool = False,
+        prog_bar: bool | None = None,
         logger: bool | None = None,
         on_step: bool | None = None,
         on_epoch: bool | None = None,
-        reduce_fx: str | Callable = "mean",
-        enable_graph: bool = False,
-        sync_dist: bool = False,
+        reduce_fx: str | Callable | None = None,
+        enable_graph: bool | None = None,
+        sync_dist: bool | None = None,
         sync_dist_group: Any | None = None,
-        add_dataloader_idx: bool = True,
+        add_dataloader_idx: bool | None = None,
         batch_size: int | None = None,
         metric_attribute: str | None = None,
-        rank_zero_only: bool = False,
+        rank_zero_only: bool | None = None,
     ) -> None:
         # If logging is disabled, then do nothing.
         if not self.logging_enabled:
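Two coordinated changes here: the `log()` override's keyword defaults move from Lightning's concrete defaults (`prog_bar=False`, `reduce_fx="mean"`, and so on) to `None`, and `_LogContextKwargs.to_dict()` now strips `None` entries. The net effect is that only explicitly-set options get forwarded, so defaults supplied elsewhere (Lightning's own, or an enclosing log context) are no longer clobbered by stale copies. A minimal sketch of the "None means unset" idea; `forward_log_kwargs` is a hypothetical stand-in, not part of the package:

```python
# Minimal sketch of the forwarding pattern used by _LogContextKwargs.to_dict().
from typing import Any


def forward_log_kwargs(**kwargs: Any) -> dict[str, Any]:
    # Drop unset (None) options so the callee's own defaults win.
    return {key: value for key, value in kwargs.items() if value is not None}


print(forward_log_kwargs(prog_bar=True, sync_dist=None))  # {'prog_bar': True}
```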
nshtrainer/trainer/trainer.py CHANGED
@@ -418,20 +418,6 @@ class Trainer(LightningTrainer):
 
         return tracker.time_elapsed(stage)
 
-    @property
-    def _base_module(self):
-        if self.lightning_module is None:
-            raise ValueError("LightningModule is not set.")
-
-        from ..model.base import LightningModuleBase
-
-        if not isinstance(self.lightning_module, LightningModuleBase):
-            raise ValueError(
-                f"LightningModule is not an instance of {LightningModuleBase}."
-            )
-
-        return self.lightning_module
-
     @override
     def _run(
         self, model: LightningModule, ckpt_path: str | Path | None = None
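With the callback base now typed against `LightningModule`, the `_base_module` guard (which raised unless the attached module subclassed `LightningModuleBase`) has no remaining callers; as the `_callback.py` section shows, `_call_on_checkpoint_saved` passes `trainer.lightning_module` straight through, so `on_checkpoint_saved` fires for any attached module. A hypothetical hook matching the signature from the diff; only `NTCallbackBase` and the parameter list come from the source:

```python
# Hypothetical sketch: an on_checkpoint_saved override on the b48 signature.
from pathlib import Path

from lightning.pytorch import LightningModule

from nshtrainer._callback import NTCallbackBase


class LogCheckpointPath(NTCallbackBase):
    def on_checkpoint_saved(
        self,
        ckpt_path: Path,
        metadata_path: Path | None,
        trainer,
        pl_module: LightningModule,
    ) -> None:
        print(f"Checkpoint saved to {ckpt_path}")
```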
nshtrainer-1.0.0b46.dist-info/METADATA → nshtrainer-1.0.0b48.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b46
+Version: 1.0.0b48
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
nshtrainer-1.0.0b46.dist-info/RECORD → nshtrainer-1.0.0b48.dist-info/RECORD RENAMED
@@ -1,7 +1,7 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
 nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
-nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
-nshtrainer/_checkpoint/metadata.py,sha256=XoKqY3eR95CYuc_Kk9ck-p4iM2Q1OXU3vSXNrzohHz0,5332
+nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
+nshtrainer/_checkpoint/metadata.py,sha256=ojSEmq0udFwdzIC5vkbF0yEdhMaJ2iBrZCSFNDkeeGY,5578
 nshtrainer/_checkpoint/saver.py,sha256=65UDrz3KuhkgVfco-RkWuoa1wzTZoXxunlC769yJaMc,1639
 nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
@@ -11,22 +11,22 @@ nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpK
 nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
 nshtrainer/callbacks/checkpoint/_base.py,sha256=f7lpk8W4xqxk3PolBEU3AWt9VTIpoLW7wMUhC5DNm3c,6345
-nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
+nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7hb96BxHanmvn7PCCbqq0E,2648
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
 nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
 nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
-nshtrainer/callbacks/early_stopping.py,sha256=EjzN-gD_Xd4YHZLkXsbi00g_4ti3RTMJEdHJ8GMeaFM,4776
+nshtrainer/callbacks/early_stopping.py,sha256=rC_qYKCQWjRQJFo0ky46uG0aDJdYP8vsSlKunk0bUVI,4765
 nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
 nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
 nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB0QG3JdHkQ8ek,3523
 nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
 nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
 nshtrainer/callbacks/lr_monitor.py,sha256=qy_C0R40J0hBAukzBwng5FI2jJUpWuXOi5N6FU6ym3I,1210
-nshtrainer/callbacks/metric_validation.py,sha256=tqUVS2n9QRT3v1_8jAGlYBFhLpA6Bm9pxOsfWhD3yZQ,2915
+nshtrainer/callbacks/metric_validation.py,sha256=4RDr1FuNKfro-6QEtmcFqT4iNf2twmJVNk9y-8nq9bg,2882
 nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
 nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
-nshtrainer/callbacks/rlp_sanity_checks.py,sha256=74BZvV2HLO__ucQXsLXb8eJLUZgRFUNJZ6TL9efMp74,10051
+nshtrainer/callbacks/rlp_sanity_checks.py,sha256=Df9Prq2QKXnaeMBIvMQBhDhJTDeru5UbiuXJOJR16Gk,10050
 nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGClU4t5kLt8XrY,3076
 nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
 nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
@@ -100,25 +100,25 @@ nshtrainer/configs/util/config/dtype/__init__.py,sha256=PmGF-O4r6SXqEaagVsQ5YxEq
 nshtrainer/configs/util/config/duration/__init__.py,sha256=44lS2irOIPVfgshMTfnZM2jC6l0Pjst9w2M_lJoS_MU,353
 nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
 nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
-nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
+nshtrainer/data/datamodule.py,sha256=0M-HjGZQkLG77HXn4ZgLSypnbSjkjTq6GEJwGWe_gbM,4136
 nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
-nshtrainer/loggers/__init__.py,sha256=Ddd3JJXVzew_ZpwHA9kGnGmvq4OwhItwghDL5PzNhDc,614
+nshtrainer/loggers/__init__.py,sha256=fI0OHEltHP4tZI-KFB3npdzoxm_M2QsEYKxY3um05_s,592
 nshtrainer/loggers/actsave.py,sha256=wgNrpBB6wQM7qff8iLDb_sQnbiAcYHRmH56pcEJPB3o,1409
 nshtrainer/loggers/base.py,sha256=ON92XbwTSgadQOSyw5PiRRFzyH6uJ-xLtE0nB3cbgPc,1205
 nshtrainer/loggers/csv.py,sha256=xJ8mSmw4vJwinIfqhF6t2HWmh_1dXEYyLfGuXwL7WHo,1160
 nshtrainer/loggers/tensorboard.py,sha256=E7iO_fDt9bfH02hBL430bXPLljOo5iGgq2QyPqmx2gQ,2324
 nshtrainer/loggers/wandb.py,sha256=KZXAUWrrmdX_L8rqej77oUHaM0JxZRM8y9z6JP9PISw,6856
 nshtrainer/lr_scheduler/__init__.py,sha256=daMMK3erUcNXGGd_nZB8AWu3ZTYqfS1RSWeK4FV2udw,851
-nshtrainer/lr_scheduler/base.py,sha256=062fGcH5sYeEKwoY55RydCTvfPwTnyZHCi049a3nMbM,3805
+nshtrainer/lr_scheduler/base.py,sha256=LE53JRBTuAlA1fqbMgCZ7m39D1z0rGj2TizhJ62CPvE,3756
 nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=MsoXgCcWTKsrkNZiGnKS6yC-slRuleuwFxeM_lmG_pQ,5560
-nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=v9T0GpvOoHV30atFB0MwExHgHcTpMCYxbMRoPjPBjt8,2938
+nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=zKO_4Cl28m3TopoNFmc5H6GSUuVUGYUoAlXpMh_EJIk,2931
 nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
-nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
+nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
 nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
-nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
+nshtrainer/model/base.py,sha256=bZMNap0rkxRbAbu2BOHV_6YS2iZZnvy6wVSMOXGa_ZM,8680
 nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
-nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
-nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
+nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
+nshtrainer/model/mixins/logger.py,sha256=IYfyyW_1VAD_HiTsfX28P-XNgz_SMb07t5lwb5rjlZ0,6221
 nshtrainer/nn/__init__.py,sha256=5Gg3nieGSC5_dXaI9KUVUUbM13hHexH9831m4hcf6no,1475
 nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
@@ -142,7 +142,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
 nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=8wMe0qArbDfStS4UdmuKSC2aiAImR3mhj14_kCJiNSM,20797
+nshtrainer/trainer/trainer.py,sha256=QEK-0bcw1y5Cconi99PYFXr0MElUGgGYMZ_SlcJUQ1k,20364
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -154,6 +154,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b46.dist-info/METADATA,sha256=L6-5RyLlIcoFyURkoCuHsAgItT0gSl6Ip0l4iDKvs4o,988
-nshtrainer-1.0.0b46.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
-nshtrainer-1.0.0b46.dist-info/RECORD,,
+nshtrainer-1.0.0b48.dist-info/METADATA,sha256=b26a0GYVQcEszYiodjGF34N7gvEKONBVuB1bXTv35U4,988
+nshtrainer-1.0.0b48.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b48.dist-info/RECORD,,