nshtrainer 1.0.0b48__py3-none-any.whl → 1.0.0b50__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py

@@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
 from typing_extensions import final, override
 
 from ..metrics._config import MetricConfig
+from ..util.config import EpochsConfig
 from .base import LRSchedulerConfigBase, LRSchedulerMetadata, lr_scheduler_registry
 
 
@@ -21,13 +22,13 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     """Metric to monitor.
     If not provided, the primary metric of the runner will be used."""
 
-    patience: int
+    patience: int | EpochsConfig
     r"""Number of epochs with no improvement after which learning rate will be reduced."""
 
     factor: float
     r"""Factor by which the learning rate will be reduced. new_lr = lr * factor."""
 
-    cooldown: int = 0
+    cooldown: int | EpochsConfig = 0
    r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""
 
     min_lr: float | list[float] = 0.0
@@ -57,14 +58,20 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
                 "Primary metric must be provided if metric is not specified."
             )
 
+        if isinstance(patience := self.patience, EpochsConfig):
+            patience = int(patience.value)
+
+        if isinstance(cooldown := self.cooldown, EpochsConfig):
+            cooldown = int(cooldown.value)
+
         lr_scheduler = ReduceLROnPlateau(
             optimizer,
             mode=metric.mode,
             factor=self.factor,
-            patience=self.patience,
+            patience=patience,
             threshold=self.threshold,
             threshold_mode=self.threshold_mode,
-            cooldown=self.cooldown,
+            cooldown=cooldown,
             min_lr=self.min_lr,
             eps=self.eps,
         )
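
The practical effect of this hunk: `patience` and `cooldown` now accept either a plain epoch count or an `EpochsConfig`, and the builder normalizes either form to an `int` before constructing the PyTorch scheduler. A minimal sketch of that normalization, assuming `EpochsConfig(value=...)` is a valid constructor (the diff only shows that the object exposes a numeric `value` attribute):

    from nshtrainer.util.config import EpochsConfig

    # Either form is accepted by ReduceLROnPlateauConfig after this change.
    patience: int | EpochsConfig = EpochsConfig(value=10)  # assumed constructor
    cooldown: int | EpochsConfig = 2

    # Normalization mirrors the added isinstance checks above.
    if isinstance(patience, EpochsConfig):
        patience = int(patience.value)
    if isinstance(cooldown, EpochsConfig):
        cooldown = int(cooldown.value)

    assert (patience, cooldown) == (10, 2)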
nshtrainer/model/mixins/logger.py

@@ -1,16 +1,16 @@
 from __future__ import annotations
 
-import copy
 import dataclasses
 from collections import deque
-from collections.abc import Callable, Generator
+from collections.abc import Callable, Generator, Mapping
 from contextlib import contextmanager
 from typing import Any, ClassVar
 
+import torchmetrics
 from lightning.pytorch import LightningModule
 from lightning.pytorch.utilities.types import _METRIC
 from lightning_utilities.core.rank_zero import rank_zero_warn
-from typing_extensions import Self, override
+from typing_extensions import override
 
 from ...util.typing_utils import mixin_base_type
 
@@ -33,23 +33,6 @@ class _LogContextKwargs:
     batch_size: int | None = None
     rank_zero_only: bool | None = None
 
-    def copy_from(self, other: Self):
-        kwargs = copy.deepcopy(self)
-
-        # Copy over all the not-None values from the other object
-        updates = {}
-        for field in dataclasses.fields(self):
-            # Ignore disabled fields
-            if field.name in self.__ignore_fields__:
-                continue
-
-            if (value := getattr(other, field.name, None)) is None:
-                continue
-            # setattr(kwargs, field.name, value)
-            updates[field.name] = value
-
-        return dataclasses.replace(kwargs, **updates)
-
     def to_dict(self):
         d = dataclasses.asdict(self)
         for field in self.__ignore_fields__:
@@ -135,6 +118,16 @@ class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
         finally:
             _ = self._logger_prefix_stack.pop()
 
+    def _make_prefix_and_kwargs_dict(self, kwargs: _LogContextKwargs):
+        prefix = "".join(c.prefix for c in self._logger_prefix_stack if c.prefix)
+
+        fn_kwargs: dict[str, Any] = {}
+        for c in self._logger_prefix_stack:
+            fn_kwargs.update(c.to_dict())
+
+        fn_kwargs.update(kwargs.to_dict())
+        return prefix, fn_kwargs
+
     @override
     def log(
         self,
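
The new `_make_prefix_and_kwargs_dict` helper replaces the removed `copy_from` logic with plain dictionary merging: contexts deeper in the prefix stack override earlier ones, and the kwargs passed explicitly to `log`/`log_dict` are applied last and win. A small stand-alone sketch of that precedence, using plain dicts in place of the `_LogContextKwargs.to_dict()` output:

    stack = [
        {"prog_bar": True, "on_step": False},  # outer logging context
        {"on_step": True},                     # inner context overrides the outer one
    ]
    call_kwargs = {"prog_bar": False}          # explicit kwargs to log()/log_dict()

    fn_kwargs: dict[str, object] = {}
    for ctx in stack:
        fn_kwargs.update(ctx)
    fn_kwargs.update(call_kwargs)              # per-call kwargs take precedence

    assert fn_kwargs == {"prog_bar": False, "on_step": True}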
@@ -153,18 +146,117 @@ class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
         metric_attribute: str | None = None,
         rank_zero_only: bool | None = None,
     ) -> None:
+        """Log a key, value pair.
+
+        Example::
+
+            self.log('train_loss', loss)
+
+        The default behavior per hook is documented here: :ref:`extensions/logging:Automatic Logging`.
+
+        Args:
+            name: key to log. Must be identical across all processes if using DDP or any other distributed strategy.
+            value: value to log. Can be a ``float``, ``Tensor``, or a ``Metric``.
+            prog_bar: if ``True`` logs to the progress bar.
+            logger: if ``True`` logs to the logger.
+            on_step: if ``True`` logs at this step. The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            on_epoch: if ``True`` logs epoch accumulated metrics. The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
+            enable_graph: if ``True``, will not auto detach the graph.
+            sync_dist: if ``True``, reduces the metric across devices. Use with care as this may lead to a significant
+                communication overhead.
+            sync_dist_group: the DDP group to sync across.
+            add_dataloader_idx: if ``True``, appends the index of the current dataloader to
+                the name (when using multiple dataloaders). If False, user needs to give unique names for
+                each dataloader to not mix the values.
+            batch_size: Current batch_size. This will be directly inferred from the loaded batch,
+                but for some data structures you might need to explicitly provide it.
+            metric_attribute: To restore the metric state, Lightning requires the reference of the
+                :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
+            rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
+                rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
+                (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
+                :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.
+
+        """
         # If logging is disabled, then do nothing.
         if not self.logging_enabled:
             return
 
-        # join all prefixes
-        prefix = "".join(c.prefix for c in self._logger_prefix_stack if c.prefix)
+        prefix, fn_kwargs = self._make_prefix_and_kwargs_dict(
+            _LogContextKwargs(
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
+            )
+        )
         name = f"{prefix}{name}"
+        return super().log(name, value, metric_attribute=metric_attribute, **fn_kwargs)
 
-        fn_kwargs = _LogContextKwargs()
-        for c in self._logger_prefix_stack:
-            fn_kwargs = fn_kwargs.copy_from(c)
-        fn_kwargs = fn_kwargs.copy_from(
+    def log_dict(
+        self,
+        dictionary: Mapping[str, _METRIC] | torchmetrics.MetricCollection,
+        prog_bar: bool | None = None,
+        logger: bool | None = None,
+        on_step: bool | None = None,
+        on_epoch: bool | None = None,
+        reduce_fx: str | Callable | None = None,
+        enable_graph: bool | None = None,
+        sync_dist: bool | None = None,
+        sync_dist_group: Any | None = None,
+        add_dataloader_idx: bool | None = None,
+        batch_size: int | None = None,
+        rank_zero_only: bool | None = None,
+    ) -> None:
+        """Log a dictionary of values at once.
+
+        Example::
+
+            values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
+            self.log_dict(values)
+
+        Args:
+            dictionary: key value pairs.
+                Keys must be identical across all processes if using DDP or any other distributed strategy.
+                The values can be a ``float``, ``Tensor``, ``Metric``, or ``MetricCollection``.
+            prog_bar: if ``True`` logs to the progress base.
+            logger: if ``True`` logs to the logger.
+            on_step: if ``True`` logs at this step.
+                ``None`` auto-logs for training_step but not validation/test_step.
+                The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            on_epoch: if ``True`` logs epoch accumulated metrics.
+                ``None`` auto-logs for val/test step but not ``training_step``.
+                The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
+            enable_graph: if ``True``, will not auto-detach the graph
+            sync_dist: if ``True``, reduces the metric across GPUs/TPUs. Use with care as this may lead to a significant
+                communication overhead.
+            sync_dist_group: the ddp group to sync across.
+            add_dataloader_idx: if ``True``, appends the index of the current dataloader to
+                the name (when using multiple). If ``False``, user needs to give unique names for
+                each dataloader to not mix values.
+            batch_size: Current batch size. This will be directly inferred from the loaded batch,
+                but some data structures might need to explicitly provide it.
+            rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
+                rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
+                (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
+                :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.
+
+        """
+
+        _, fn_kwargs = self._make_prefix_and_kwargs_dict(
             _LogContextKwargs(
                 prog_bar=prog_bar,
                 logger=logger,
@@ -179,9 +271,5 @@ class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
                 rank_zero_only=rank_zero_only,
             )
         )
-        return super().log(
-            name,
-            value,
-            metric_attribute=metric_attribute,
-            **fn_kwargs.to_dict(),
-        )
+        # NOTE: Prefix will be handled by the individual log calls.
+        return super().log_dict(dictionary, **fn_kwargs)
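
Taken together, the overridden `log` applies the joined prefix itself, while the new `log_dict` forwards the merged context kwargs and leaves prefixing to the per-key `log` calls Lightning makes internally (the NOTE above). A hypothetical usage sketch, assuming `LoggerLightningModuleMixin` is mixed in ahead of `LightningModule` as in the class definition shown in this diff:

    import torch
    from lightning.pytorch import LightningModule

    from nshtrainer.model.mixins.logger import LoggerLightningModuleMixin


    class BoringModule(LoggerLightningModuleMixin, LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).mean()

            # Single value: the overridden log() prepends any active logger prefix.
            self.log("loss", loss, prog_bar=True)

            # Dictionary: log_dict() passes the merged context kwargs through and
            # lets the underlying per-key log() calls handle the prefix.
            self.log_dict({"loss_mean": loss.detach()})
            return loss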
nshtrainer-1.0.0b48.dist-info/METADATA → nshtrainer-1.0.0b50.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.0.0b48
+Version: 1.0.0b50
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
nshtrainer-1.0.0b48.dist-info/RECORD → nshtrainer-1.0.0b50.dist-info/RECORD

@@ -111,14 +111,14 @@ nshtrainer/loggers/wandb.py,sha256=KZXAUWrrmdX_L8rqej77oUHaM0JxZRM8y9z6JP9PISw,6
 nshtrainer/lr_scheduler/__init__.py,sha256=daMMK3erUcNXGGd_nZB8AWu3ZTYqfS1RSWeK4FV2udw,851
 nshtrainer/lr_scheduler/base.py,sha256=LE53JRBTuAlA1fqbMgCZ7m39D1z0rGj2TizhJ62CPvE,3756
 nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=MsoXgCcWTKsrkNZiGnKS6yC-slRuleuwFxeM_lmG_pQ,5560
-nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=zKO_4Cl28m3TopoNFmc5H6GSUuVUGYUoAlXpMh_EJIk,2931
+nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=irPyDjfUX843ze4bJM9sW8WSeEcU643QJ30JN2hz9Rc,3206
 nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
 nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
 nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
 nshtrainer/model/base.py,sha256=bZMNap0rkxRbAbu2BOHV_6YS2iZZnvy6wVSMOXGa_ZM,8680
 nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
 nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
-nshtrainer/model/mixins/logger.py,sha256=IYfyyW_1VAD_HiTsfX28P-XNgz_SMb07t5lwb5rjlZ0,6221
+nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
 nshtrainer/nn/__init__.py,sha256=5Gg3nieGSC5_dXaI9KUVUUbM13hHexH9831m4hcf6no,1475
 nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
 nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
@@ -154,6 +154,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.0b48.dist-info/METADATA,sha256=b26a0GYVQcEszYiodjGF34N7gvEKONBVuB1bXTv35U4,988
-nshtrainer-1.0.0b48.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
-nshtrainer-1.0.0b48.dist-info/RECORD,,
+nshtrainer-1.0.0b50.dist-info/METADATA,sha256=KgNg6AHzL9uCAc1tzfM0gbQl5Bu9QhQFFtecE75KIn0,988
+nshtrainer-1.0.0b50.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
+nshtrainer-1.0.0b50.dist-info/RECORD,,