nshtrainer 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_callback.py CHANGED
@@ -8,38 +8,46 @@ from lightning.pytorch import LightningModule
8
8
  from lightning.pytorch.callbacks import Callback as _LightningCallback
9
9
  from lightning.pytorch.utilities.types import STEP_OUTPUT
10
10
  from torch.optim import Optimizer
11
+ from typing_extensions import override
11
12
 
12
13
  if TYPE_CHECKING:
13
14
  from .trainer import Trainer
14
15
 
15
16
 
16
17
  class NTCallbackBase(_LightningCallback):
18
+ @override
17
19
  def setup( # pyright: ignore[reportIncompatibleMethodOverride]
18
20
  self, trainer: Trainer, pl_module: LightningModule, stage: str
19
21
  ) -> None:
20
22
  """Called when fit, validate, test, predict, or tune begins."""
21
23
 
24
+ @override
22
25
  def teardown( # pyright: ignore[reportIncompatibleMethodOverride]
23
26
  self, trainer: Trainer, pl_module: LightningModule, stage: str
24
27
  ) -> None:
25
28
  """Called when fit, validate, test, predict, or tune ends."""
26
29
 
30
+ @override
27
31
  def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
28
32
  """Called when fit begins."""
29
33
 
34
+ @override
30
35
  def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
31
36
  """Called when fit ends."""
32
37
 
38
+ @override
33
39
  def on_sanity_check_start( # pyright: ignore[reportIncompatibleMethodOverride]
34
40
  self, trainer: Trainer, pl_module: LightningModule
35
41
  ) -> None:
36
42
  """Called when the validation sanity check starts."""
37
43
 
44
+ @override
38
45
  def on_sanity_check_end( # pyright: ignore[reportIncompatibleMethodOverride]
39
46
  self, trainer: Trainer, pl_module: LightningModule
40
47
  ) -> None:
41
48
  """Called when the validation sanity check ends."""
42
49
 
50
+ @override
43
51
  def on_train_batch_start( # pyright: ignore[reportIncompatibleMethodOverride]
44
52
  self,
45
53
  trainer: Trainer,
@@ -49,6 +57,7 @@ class NTCallbackBase(_LightningCallback):
49
57
  ) -> None:
50
58
  """Called when the train batch begins."""
51
59
 
60
+ @override
52
61
  def on_train_batch_end( # pyright: ignore[reportIncompatibleMethodOverride]
53
62
  self,
54
63
  trainer: Trainer,
@@ -65,11 +74,13 @@ class NTCallbackBase(_LightningCallback):
65
74
 
66
75
  """
67
76
 
77
+ @override
68
78
  def on_train_epoch_start( # pyright: ignore[reportIncompatibleMethodOverride]
69
79
  self, trainer: Trainer, pl_module: LightningModule
70
80
  ) -> None:
71
81
  """Called when the train epoch begins."""
72
82
 
83
+ @override
73
84
  def on_train_epoch_end( # pyright: ignore[reportIncompatibleMethodOverride]
74
85
  self, trainer: Trainer, pl_module: LightningModule
75
86
  ) -> None:
@@ -81,10 +92,12 @@ class NTCallbackBase(_LightningCallback):
81
92
  .. code-block:: python
82
93
 
83
94
  class MyLightningModule(L.LightningModule):
95
+ @override
84
96
  def __init__(self):
85
97
  super().__init__() # pyright: ignore[reportIncompatibleMethodOverride]
86
98
  self.training_step_outputs = []
87
99
 
100
+ @override
88
101
  def training_step(self):
89
102
  loss = ... # pyright: ignore[reportIncompatibleMethodOverride]
90
103
  self.training_step_outputs.append(loss)
@@ -92,6 +105,7 @@ class NTCallbackBase(_LightningCallback):
92
105
 
93
106
 
94
107
  class MyCallback(L.Callback):
108
+ @override
95
109
  def on_train_epoch_end(self, trainer, pl_module):
96
110
  # do something with all training_step outputs, for example: # pyright: ignore[reportIncompatibleMethodOverride]
97
111
  epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
@@ -101,36 +115,43 @@ class NTCallbackBase(_LightningCallback):
101
115
 
102
116
  """
103
117
 
118
+ @override
104
119
  def on_validation_epoch_start( # pyright: ignore[reportIncompatibleMethodOverride]
105
120
  self, trainer: Trainer, pl_module: LightningModule
106
121
  ) -> None:
107
122
  """Called when the val epoch begins."""
108
123
 
124
+ @override
109
125
  def on_validation_epoch_end( # pyright: ignore[reportIncompatibleMethodOverride]
110
126
  self, trainer: Trainer, pl_module: LightningModule
111
127
  ) -> None:
112
128
  """Called when the val epoch ends."""
113
129
 
130
+ @override
114
131
  def on_test_epoch_start( # pyright: ignore[reportIncompatibleMethodOverride]
115
132
  self, trainer: Trainer, pl_module: LightningModule
116
133
  ) -> None:
117
134
  """Called when the test epoch begins."""
118
135
 
136
+ @override
119
137
  def on_test_epoch_end( # pyright: ignore[reportIncompatibleMethodOverride]
120
138
  self, trainer: Trainer, pl_module: LightningModule
121
139
  ) -> None:
122
140
  """Called when the test epoch ends."""
123
141
 
142
+ @override
124
143
  def on_predict_epoch_start( # pyright: ignore[reportIncompatibleMethodOverride]
125
144
  self, trainer: Trainer, pl_module: LightningModule
126
145
  ) -> None:
127
146
  """Called when the predict epoch begins."""
128
147
 
148
+ @override
129
149
  def on_predict_epoch_end( # pyright: ignore[reportIncompatibleMethodOverride]
130
150
  self, trainer: Trainer, pl_module: LightningModule
131
151
  ) -> None:
132
152
  """Called when the predict epoch ends."""
133
153
 
154
+ @override
134
155
  def on_validation_batch_start( # pyright: ignore[reportIncompatibleMethodOverride]
135
156
  self,
136
157
  trainer: Trainer,
@@ -141,6 +162,7 @@ class NTCallbackBase(_LightningCallback):
141
162
  ) -> None:
142
163
  """Called when the validation batch begins."""
143
164
 
165
+ @override
144
166
  def on_validation_batch_end( # pyright: ignore[reportIncompatibleMethodOverride]
145
167
  self,
146
168
  trainer: Trainer,
@@ -152,6 +174,7 @@ class NTCallbackBase(_LightningCallback):
152
174
  ) -> None:
153
175
  """Called when the validation batch ends."""
154
176
 
177
+ @override
155
178
  def on_test_batch_start( # pyright: ignore[reportIncompatibleMethodOverride]
156
179
  self,
157
180
  trainer: Trainer,
@@ -162,6 +185,7 @@ class NTCallbackBase(_LightningCallback):
162
185
  ) -> None:
163
186
  """Called when the test batch begins."""
164
187
 
188
+ @override
165
189
  def on_test_batch_end( # pyright: ignore[reportIncompatibleMethodOverride]
166
190
  self,
167
191
  trainer: Trainer,
@@ -173,6 +197,7 @@ class NTCallbackBase(_LightningCallback):
173
197
  ) -> None:
174
198
  """Called when the test batch ends."""
175
199
 
200
+ @override
176
201
  def on_predict_batch_start( # pyright: ignore[reportIncompatibleMethodOverride]
177
202
  self,
178
203
  trainer: Trainer,
@@ -183,6 +208,7 @@ class NTCallbackBase(_LightningCallback):
183
208
  ) -> None:
184
209
  """Called when the predict batch begins."""
185
210
 
211
+ @override
186
212
  def on_predict_batch_end( # pyright: ignore[reportIncompatibleMethodOverride]
187
213
  self,
188
214
  trainer: Trainer,
@@ -194,36 +220,45 @@ class NTCallbackBase(_LightningCallback):
194
220
  ) -> None:
195
221
  """Called when the predict batch ends."""
196
222
 
223
+ @override
197
224
  def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
198
225
  """Called when the train begins."""
199
226
 
227
+ @override
200
228
  def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
201
229
  """Called when the train ends."""
202
230
 
231
+ @override
203
232
  def on_validation_start( # pyright: ignore[reportIncompatibleMethodOverride]
204
233
  self, trainer: Trainer, pl_module: LightningModule
205
234
  ) -> None:
206
235
  """Called when the validation loop begins."""
207
236
 
237
+ @override
208
238
  def on_validation_end( # pyright: ignore[reportIncompatibleMethodOverride]
209
239
  self, trainer: Trainer, pl_module: LightningModule
210
240
  ) -> None:
211
241
  """Called when the validation loop ends."""
212
242
 
243
+ @override
213
244
  def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
214
245
  """Called when the test begins."""
215
246
 
247
+ @override
216
248
  def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
217
249
  """Called when the test ends."""
218
250
 
251
+ @override
219
252
  def on_predict_start( # pyright: ignore[reportIncompatibleMethodOverride]
220
253
  self, trainer: Trainer, pl_module: LightningModule
221
254
  ) -> None:
222
255
  """Called when the predict begins."""
223
256
 
257
+ @override
224
258
  def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
225
259
  """Called when predict ends."""
226
260
 
261
+ @override
227
262
  def on_exception( # pyright: ignore[reportIncompatibleMethodOverride]
228
263
  self,
229
264
  trainer: Trainer,
@@ -232,7 +267,8 @@ class NTCallbackBase(_LightningCallback):
232
267
  ) -> None:
233
268
  """Called when any trainer execution is interrupted by an exception."""
234
269
 
235
- def state_dict(self) -> dict[str, Any]: # pyright: ignore[reportIncompatibleMethodOverride]
270
+ @override
271
+ def state_dict(self) -> dict[str, Any]:
236
272
  """Called when saving a checkpoint, implement to generate callback's ``state_dict``.
237
273
 
238
274
  Returns:
@@ -241,7 +277,8 @@ class NTCallbackBase(_LightningCallback):
241
277
  """
242
278
  return {}
243
279
 
244
- def load_state_dict(self, state_dict: dict[str, Any]) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
280
+ @override
281
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
245
282
  """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.
246
283
 
247
284
  Args:
@@ -250,6 +287,7 @@ class NTCallbackBase(_LightningCallback):
250
287
  """
251
288
  pass
252
289
 
290
+ @override
253
291
  def on_save_checkpoint( # pyright: ignore[reportIncompatibleMethodOverride]
254
292
  self,
255
293
  trainer: Trainer,
@@ -265,6 +303,7 @@ class NTCallbackBase(_LightningCallback):
265
303
 
266
304
  """
267
305
 
306
+ @override
268
307
  def on_load_checkpoint( # pyright: ignore[reportIncompatibleMethodOverride]
269
308
  self,
270
309
  trainer: Trainer,
@@ -280,16 +319,19 @@ class NTCallbackBase(_LightningCallback):
280
319
 
281
320
  """
282
321
 
322
+ @override
283
323
  def on_before_backward( # pyright: ignore[reportIncompatibleMethodOverride]
284
324
  self, trainer: Trainer, pl_module: LightningModule, loss: torch.Tensor
285
325
  ) -> None:
286
326
  """Called before ``loss.backward()``."""
287
327
 
328
+ @override
288
329
  def on_after_backward( # pyright: ignore[reportIncompatibleMethodOverride]
289
330
  self, trainer: Trainer, pl_module: LightningModule
290
331
  ) -> None:
291
332
  """Called after ``loss.backward()`` and before optimizers are stepped."""
292
333
 
334
+ @override
293
335
  def on_before_optimizer_step( # pyright: ignore[reportIncompatibleMethodOverride]
294
336
  self,
295
337
  trainer: Trainer,
@@ -298,6 +340,7 @@ class NTCallbackBase(_LightningCallback):
298
340
  ) -> None:
299
341
  """Called before ``optimizer.step()``."""
300
342
 
343
+ @override
301
344
  def on_before_zero_grad( # pyright: ignore[reportIncompatibleMethodOverride]
302
345
  self,
303
346
  trainer: Trainer,
@@ -306,7 +349,10 @@ class NTCallbackBase(_LightningCallback):
306
349
  ) -> None:
307
350
  """Called before ``optimizer.zero_grad()``."""
308
351
 
309
- def on_checkpoint_saved( # pyright: ignore[reportIncompatibleMethodOverride]
352
+ # =================================================================
353
+ # Our own new callbacks
354
+ # =================================================================
355
+ def on_checkpoint_saved(
310
356
  self,
311
357
  ckpt_path: Path,
312
358
  metadata_path: Path | None,
@@ -317,6 +363,7 @@ class NTCallbackBase(_LightningCallback):
317
363
  pass
318
364
 
319
365
 
366
+ @override
320
367
  def _call_on_checkpoint_saved(
321
368
  trainer: Trainer,
322
369
  ckpt_path: str | Path,
@@ -75,5 +75,5 @@ from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig
75
75
 
76
76
  CallbackConfig = TypeAliasType(
77
77
  "CallbackConfig",
78
- Annotated[CallbackConfigBase, callback_registry.DynamicResolution()],
78
+ Annotated[CallbackConfigBase, callback_registry],
79
79
  )
@@ -5,13 +5,13 @@ import string
5
5
  from abc import ABC, abstractmethod
6
6
  from collections.abc import Callable
7
7
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
8
+ from typing import TYPE_CHECKING, Any, Generic, Literal
9
9
 
10
10
  import numpy as np
11
11
  import torch
12
12
  from lightning.pytorch import Trainer
13
13
  from lightning.pytorch.callbacks import Checkpoint
14
- from typing_extensions import override
14
+ from typing_extensions import TypeVar, override
15
15
 
16
16
  from ..._checkpoint.metadata import CheckpointMetadata, _generate_checkpoint_metadata
17
17
  from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
@@ -67,14 +67,14 @@ class PrintTableMetricsCallback(Callback):
67
67
  }
68
68
  self.metrics.append(metrics_dict)
69
69
 
70
- from rich.console import Console # type: ignore[reportMissingImports] # noqa
70
+ from rich.console import Console # pyright: ignore[reportMissingImports] # noqa
71
71
 
72
72
  console = Console()
73
73
  table = self.create_metrics_table()
74
74
  console.print(table)
75
75
 
76
76
  def create_metrics_table(self):
77
- from rich.table import Table # type: ignore[reportMissingImports] # noqa
77
+ from rich.table import Table # pyright: ignore[reportMissingImports] # noqa
78
78
 
79
79
  table = Table(show_header=True, header_style="bold magenta")
80
80
 
@@ -38,6 +38,7 @@ class RLPSanityChecksCallbackConfig(CallbackConfigBase):
38
38
  def __bool__(self):
39
39
  return self.enabled
40
40
 
41
+ @override
41
42
  def create_callbacks(self, trainer_config):
42
43
  if not self:
43
44
  return
@@ -111,7 +111,6 @@ from nshtrainer.optimizer import RAdamConfig as RAdamConfig
111
111
  from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
112
112
  from nshtrainer.optimizer import RpropConfig as RpropConfig
113
113
  from nshtrainer.optimizer import SGDConfig as SGDConfig
114
- from nshtrainer.optimizer import Union as Union
115
114
  from nshtrainer.optimizer import optimizer_registry as optimizer_registry
116
115
  from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
117
116
  from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
@@ -355,7 +354,6 @@ __all__ = [
355
354
  "TorchSyncBatchNormPlugin",
356
355
  "TrainerConfig",
357
356
  "TransformerEnginePluginConfig",
358
- "Union",
359
357
  "WandbLoggerConfig",
360
358
  "WandbUploadCodeCallbackConfig",
361
359
  "WandbWatchCallbackConfig",
@@ -16,7 +16,6 @@ from nshtrainer.optimizer import RAdamConfig as RAdamConfig
16
16
  from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
17
17
  from nshtrainer.optimizer import RpropConfig as RpropConfig
18
18
  from nshtrainer.optimizer import SGDConfig as SGDConfig
19
- from nshtrainer.optimizer import Union as Union
20
19
  from nshtrainer.optimizer import optimizer_registry as optimizer_registry
21
20
 
22
21
  __all__ = [
@@ -34,6 +33,5 @@ __all__ = [
34
33
  "RMSpropConfig",
35
34
  "RpropConfig",
36
35
  "SGDConfig",
37
- "Union",
38
36
  "optimizer_registry",
39
37
  ]
@@ -12,6 +12,5 @@ from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
12
12
  from .wandb import WandbLoggerConfig as WandbLoggerConfig
13
13
 
14
14
  LoggerConfig = TypeAliasType(
15
- "LoggerConfig",
16
- Annotated[LoggerConfigBase, logger_registry.DynamicResolution()],
15
+ "LoggerConfig", Annotated[LoggerConfigBase, logger_registry]
17
16
  )
@@ -5,7 +5,7 @@ from typing import Any, Literal
5
5
 
6
6
  import numpy as np
7
7
  from lightning.pytorch.loggers import Logger
8
- from typing_extensions import final
8
+ from typing_extensions import final, override
9
9
 
10
10
  from .base import LoggerConfigBase, logger_registry
11
11
 
@@ -15,6 +15,7 @@ from .base import LoggerConfigBase, logger_registry
15
15
  class ActSaveLoggerConfig(LoggerConfigBase):
16
16
  name: Literal["actsave"] = "actsave"
17
17
 
18
+ @override
18
19
  def create_logger(self, trainer_config):
19
20
  if not self.enabled:
20
21
  return None
@@ -24,10 +25,12 @@ class ActSaveLoggerConfig(LoggerConfigBase):
24
25
 
25
26
  class ActSaveLogger(Logger):
26
27
  @property
28
+ @override
27
29
  def name(self):
28
30
  return None
29
31
 
30
32
  @property
33
+ @override
31
34
  def version(self):
32
35
  from nshutils import ActSave
33
36
 
@@ -37,6 +40,7 @@ class ActSaveLogger(Logger):
37
40
  return ActSave._saver._id
38
41
 
39
42
  @property
43
+ @override
40
44
  def save_dir(self):
41
45
  from nshutils import ActSave
42
46
 
@@ -45,6 +49,7 @@ class ActSaveLogger(Logger):
45
49
 
46
50
  return str(ActSave._saver._save_dir)
47
51
 
52
+ @override
48
53
  def log_hyperparams(
49
54
  self,
50
55
  params: dict[str, Any] | Namespace,
@@ -56,6 +61,7 @@ class ActSaveLogger(Logger):
56
61
  # Wrap the hparams as a object-dtype np array
57
62
  return ActSave.save({"hyperparameters": np.array(params, dtype=object)})
58
63
 
64
+ @override
59
65
  def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:
60
66
  from nshutils import ActSave
61
67
 
@@ -63,7 +63,7 @@ class FinishWandbOnTeardownCallback(Callback):
63
63
  stage: str,
64
64
  ):
65
65
  try:
66
- import wandb # type: ignore
66
+ import wandb
67
67
  except ImportError:
68
68
  return
69
69
 
@@ -139,7 +139,7 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
139
139
  # If `wandb-core` is enabled, we should use the new backend.
140
140
  if self.use_wandb_core:
141
141
  try:
142
- import wandb # type: ignore
142
+ import wandb
143
143
 
144
144
  # The minimum version that supports the new backend is 0.17.5
145
145
  wandb_version = version.parse(importlib.metadata.version("wandb"))
@@ -151,7 +151,7 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
151
151
  )
152
152
  # W&B versions 0.18.0 use wandb-core by default
153
153
  elif wandb_version < version.parse("0.18.0"):
154
- wandb.require("core") # type: ignore
154
+ wandb.require("core")
155
155
  log.critical("Using the `wandb-core` backend for WandB.")
156
156
  except ImportError:
157
157
  pass
@@ -166,9 +166,9 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
166
166
  "If you want to use the new `wandb-core` backend, set `use_wandb_core=True`."
167
167
  )
168
168
  try:
169
- import wandb # type: ignore
169
+ import wandb
170
170
 
171
- wandb.require("legacy-service") # type: ignore
171
+ wandb.require("legacy-service")
172
172
  except ImportError:
173
173
  pass
174
174
 
@@ -81,7 +81,7 @@ class LRSchedulerConfigBase(C.Config, ABC):
81
81
  scheduler["monitor"] = metadata["monitor"]
82
82
  # - `strict`
83
83
  if scheduler.get("strict") is None and "strict" in metadata:
84
- scheduler["strict"] = metadata["strict"] # type: ignore
84
+ scheduler["strict"] = metadata["strict"]
85
85
 
86
86
  return scheduler
87
87
 
@@ -41,23 +41,6 @@ class CallbackModuleMixin(
41
41
  CallbackRegistrarModuleMixin,
42
42
  mixin_base_type(LightningModule),
43
43
  ):
44
- @property
45
- def _nshtrainer_callbacks(self) -> list[CallbackFn]:
46
- if not hasattr(self, "_private_nshtrainer_callbacks_list"):
47
- self._private_nshtrainer_callbacks_list = []
48
- return self._private_nshtrainer_callbacks_list
49
-
50
- def register_callback(
51
- self,
52
- callback: _Callback | Iterable[_Callback] | CallbackFn | None = None,
53
- ):
54
- if not callable(callback):
55
- callback_ = cast(CallbackFn, lambda: callback)
56
- else:
57
- callback_ = callback
58
-
59
- self._nshtrainer_callbacks.append(callback_)
60
-
61
44
  def _gather_all_callbacks(self):
62
45
  modules: list[Any] = []
63
46
  if isinstance(self, CallbackRegistrarModuleMixin):
@@ -203,6 +203,7 @@ class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
203
203
  name = f"{prefix}{name}"
204
204
  return super().log(name, value, metric_attribute=metric_attribute, **fn_kwargs)
205
205
 
206
+ @override
206
207
  def log_dict(
207
208
  self,
208
209
  dictionary: Mapping[str, _METRIC] | torchmetrics.MetricCollection,
@@ -28,9 +28,9 @@ class TypedModuleDict(nn.Module, Generic[TModule]):
28
28
  return f"{self.key_prefix}{key}"
29
29
 
30
30
  def _remove_prefix(self, key: str) -> str:
31
- assert key.startswith(
32
- self.key_prefix
33
- ), f"{key} does not start with {self.key_prefix}"
31
+ assert key.startswith(self.key_prefix), (
32
+ f"{key} does not start with {self.key_prefix}"
33
+ )
34
34
  return key[len(self.key_prefix) :]
35
35
 
36
36
  def __setitem__(self, key: str, module: TModule) -> None:
@@ -39,7 +39,7 @@ class TypedModuleDict(nn.Module, Generic[TModule]):
39
39
 
40
40
  def __getitem__(self, key: str) -> TModule:
41
41
  key = self._with_prefix(key)
42
- return self._module_dict.__getitem__(key) # type: ignore
42
+ return cast(TModule, self._module_dict.__getitem__(key))
43
43
 
44
44
  def update(self, modules: Mapping[str, TModule]) -> None:
45
45
  return self._module_dict.update(
@@ -1,12 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections.abc import Iterable, Iterator
4
- from typing import Generic, TypeVar, overload
4
+ from typing import Generic, cast, overload
5
5
 
6
6
  import torch.nn as nn
7
- from typing_extensions import override
7
+ from typing_extensions import TypeVar, override
8
8
 
9
- TModule = TypeVar("TModule", bound=nn.Module)
9
+ TModule = TypeVar("TModule", bound=nn.Module, infer_variance=True)
10
10
 
11
11
 
12
12
  class TypedModuleList(nn.ModuleList, Generic[TModule]):
@@ -14,39 +14,39 @@ class TypedModuleList(nn.ModuleList, Generic[TModule]):
14
14
  super().__init__(modules)
15
15
 
16
16
  @overload
17
- def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ...
17
+ def __getitem__(self, idx: slice) -> TypedModuleList[TModule]: ...
18
18
 
19
19
  @overload
20
20
  def __getitem__(self, idx: int) -> TModule: ...
21
21
 
22
22
  @override
23
- def __getitem__(self, idx: int | slice) -> TModule | "TypedModuleList[TModule]":
24
- return super().__getitem__(idx) # type: ignore
23
+ def __getitem__(self, idx: int | slice) -> TModule | TypedModuleList[TModule]:
24
+ return cast(TModule | TypedModuleList[TModule], super().__getitem__(idx))
25
25
 
26
26
  @override
27
- def __setitem__(self, idx: int, module: TModule) -> None: # type: ignore
27
+ def __setitem__(self, idx: int, module: TModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
28
28
  return super().__setitem__(idx, module)
29
29
 
30
30
  @override
31
31
  def __iter__(self) -> Iterator[TModule]:
32
- return super().__iter__() # type: ignore
32
+ return cast(Iterator[TModule], super().__iter__())
33
33
 
34
34
  @override
35
- def __iadd__(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]": # type: ignore
36
- return super().__iadd__(modules) # type: ignore
35
+ def __iadd__(self, modules: Iterable[TModule]) -> TypedModuleList[TModule]: # pyright: ignore[reportIncompatibleMethodOverride]
36
+ return cast(TypedModuleList[TModule], super().__iadd__(modules))
37
37
 
38
38
  @override
39
- def __add__(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]": # type: ignore
40
- return super().__add__(modules) # type: ignore
39
+ def __add__(self, modules: Iterable[TModule]) -> TypedModuleList[TModule]: # pyright: ignore[reportIncompatibleMethodOverride]
40
+ return cast(TypedModuleList[TModule], super().__add__(modules))
41
41
 
42
42
  @override
43
- def insert(self, idx: int, module: TModule) -> None: # type: ignore
43
+ def insert(self, idx: int, module: TModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
44
44
  return super().insert(idx, module)
45
45
 
46
46
  @override
47
- def append(self, module: TModule) -> "TypedModuleList[TModule]": # type: ignore
48
- return super().append(module) # type: ignore
47
+ def append(self, module: TModule) -> TypedModuleList[TModule]: # pyright: ignore[reportIncompatibleMethodOverride]
48
+ return cast(TypedModuleList[TModule], super().append(module))
49
49
 
50
50
  @override
51
- def extend(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]": # type: ignore
52
- return super().extend(modules) # type: ignore
51
+ def extend(self, modules: Iterable[TModule]) -> TypedModuleList[TModule]: # pyright: ignore[reportIncompatibleMethodOverride]
52
+ return cast(TypedModuleList[TModule], super().extend(modules))
@@ -30,6 +30,7 @@ class ReLUNonlinearityConfig(NonlinearityConfigBase):
30
30
  def create_module(self) -> nn.Module:
31
31
  return nn.ReLU()
32
32
 
33
+ @override
33
34
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
34
35
  return F.relu(x)
35
36
 
@@ -43,6 +44,7 @@ class SigmoidNonlinearityConfig(NonlinearityConfigBase):
43
44
  def create_module(self) -> nn.Module:
44
45
  return nn.Sigmoid()
45
46
 
47
+ @override
46
48
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
47
49
  return torch.sigmoid(x)
48
50
 
@@ -56,6 +58,7 @@ class TanhNonlinearityConfig(NonlinearityConfigBase):
56
58
  def create_module(self) -> nn.Module:
57
59
  return nn.Tanh()
58
60
 
61
+ @override
59
62
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
60
63
  return torch.tanh(x)
61
64
 
@@ -72,6 +75,7 @@ class SoftmaxNonlinearityConfig(NonlinearityConfigBase):
72
75
  def create_module(self) -> nn.Module:
73
76
  return nn.Softmax(dim=self.dim)
74
77
 
78
+ @override
75
79
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
76
80
  return torch.softmax(x, dim=self.dim)
77
81
 
@@ -91,6 +95,7 @@ class SoftplusNonlinearityConfig(NonlinearityConfigBase):
91
95
  def create_module(self) -> nn.Module:
92
96
  return nn.Softplus(beta=self.beta, threshold=self.threshold)
93
97
 
98
+ @override
94
99
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
95
100
  return F.softplus(x, beta=self.beta, threshold=self.threshold)
96
101
 
@@ -104,6 +109,7 @@ class SoftsignNonlinearityConfig(NonlinearityConfigBase):
104
109
  def create_module(self) -> nn.Module:
105
110
  return nn.Softsign()
106
111
 
112
+ @override
107
113
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
108
114
  return F.softsign(x)
109
115
 
@@ -120,6 +126,7 @@ class ELUNonlinearityConfig(NonlinearityConfigBase):
120
126
  def create_module(self) -> nn.Module:
121
127
  return nn.ELU()
122
128
 
129
+ @override
123
130
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
124
131
  return F.elu(x, alpha=self.alpha)
125
132
 
@@ -136,6 +143,7 @@ class LeakyReLUNonlinearityConfig(NonlinearityConfigBase):
136
143
  def create_module(self) -> nn.Module:
137
144
  return nn.LeakyReLU(negative_slope=self.negative_slope)
138
145
 
146
+ @override
139
147
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
140
148
  return F.leaky_relu(x, negative_slope=self.negative_slope)
141
149
 
@@ -157,6 +165,7 @@ class PReLUConfig(NonlinearityConfigBase):
157
165
  def create_module(self) -> nn.Module:
158
166
  return nn.PReLU(num_parameters=self.num_parameters, init=self.init)
159
167
 
168
+ @override
160
169
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
161
170
  raise NotImplementedError(
162
171
  "PReLU requires learnable parameters and cannot be called directly."
@@ -175,6 +184,7 @@ class GELUNonlinearityConfig(NonlinearityConfigBase):
175
184
  def create_module(self) -> nn.Module:
176
185
  return nn.GELU(approximate=self.approximate)
177
186
 
187
+ @override
178
188
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
179
189
  return F.gelu(x, approximate=self.approximate)
180
190
 
@@ -188,6 +198,7 @@ class SwishNonlinearityConfig(NonlinearityConfigBase):
188
198
  def create_module(self) -> nn.Module:
189
199
  return nn.SiLU()
190
200
 
201
+ @override
191
202
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
192
203
  return F.silu(x)
193
204
 
@@ -201,6 +212,7 @@ class SiLUNonlinearityConfig(NonlinearityConfigBase):
201
212
  def create_module(self) -> nn.Module:
202
213
  return nn.SiLU()
203
214
 
215
+ @override
204
216
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
205
217
  return F.silu(x)
206
218
 
@@ -214,6 +226,7 @@ class MishNonlinearityConfig(NonlinearityConfigBase):
214
226
  def create_module(self) -> nn.Module:
215
227
  return nn.Mish()
216
228
 
229
+ @override
217
230
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
218
231
  return F.mish(x)
219
232
 
@@ -234,12 +247,12 @@ class SwiGLUNonlinearityConfig(NonlinearityConfigBase):
234
247
  def create_module(self) -> nn.Module:
235
248
  return SwiGLU()
236
249
 
250
+ @override
237
251
  def __call__(self, x: torch.Tensor) -> torch.Tensor:
238
252
  input, gate = x.chunk(2, dim=-1)
239
253
  return input * F.silu(gate)
240
254
 
241
255
 
242
256
  NonlinearityConfig = TypeAliasType(
243
- "NonlinearityConfig",
244
- Annotated[NonlinearityConfigBase, nonlinearity_registry.DynamicResolution()],
257
+ "NonlinearityConfig", Annotated[NonlinearityConfigBase, nonlinearity_registry]
245
258
  )
nshtrainer/optimizer.py CHANGED
@@ -2,11 +2,10 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from collections.abc import Iterable
5
- from typing import Annotated, Any, Literal, Tuple, Union
5
+ from typing import Annotated, Any, Literal
6
6
 
7
7
  import nshconfig as C
8
8
  import torch.nn as nn
9
- from torch import Tensor
10
9
  from torch.optim import Optimizer
11
10
  from typing_extensions import TypeAliasType, final, override
12
11
 
@@ -621,6 +620,5 @@ class SGDConfig(OptimizerConfigBase):
621
620
 
622
621
 
623
622
  OptimizerConfig = TypeAliasType(
624
- "OptimizerConfig",
625
- Annotated[OptimizerConfigBase, optimizer_registry.DynamicResolution()],
623
+ "OptimizerConfig", Annotated[OptimizerConfigBase, optimizer_registry]
626
624
  )
@@ -23,8 +23,7 @@ class AcceleratorConfigBase(C.Config, ABC):
23
23
  accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
24
24
 
25
25
  AcceleratorConfig = TypeAliasType(
26
- "AcceleratorConfig",
27
- Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
26
+ "AcceleratorConfig", Annotated[AcceleratorConfigBase, accelerator_registry]
28
27
  )
29
28
 
30
29
 
@@ -13,6 +13,5 @@ from .base import PluginConfigBase as PluginConfigBase
13
13
  from .base import plugin_registry as plugin_registry
14
14
 
15
15
  PluginConfig = TypeAliasType(
16
- "PluginConfig",
17
- Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
16
+ "PluginConfig", Annotated[PluginConfigBase, plugin_registry]
18
17
  )
@@ -17,7 +17,7 @@ def get_code_dir() -> Path | None:
17
17
  # New versions of nshrunner will have the code_dir attribute
18
18
  # in the session object. We should use that. Otherwise, use snapshot_dir.
19
19
  try:
20
- code_dir = session.code_dir # type: ignore
20
+ code_dir = session.code_dir
21
21
  except AttributeError:
22
22
  code_dir = session.snapshot_dir
23
23
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.4.1
3
+ Version: 1.5.0
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,15 +1,15 @@
1
1
  nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
2
2
  nshtrainer/__init__.py,sha256=RI_2B_IUWa10B6H5TAuWtE5FWX1X4ue-J4dTDaF2-lQ,1035
3
- nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
3
+ nshtrainer/_callback.py,sha256=aBg9Za6hjteHcGjb8bIGzaN57A03cXrPv4rMWqaNsLU,13253
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=El9Ip8jGA7mAN5rAMpVfg1dfUe2dGoOOfvF1JfYJGHM,5676
5
5
  nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
6
6
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
7
7
  nshtrainer/_hf_hub.py,sha256=OB4252GJ6AbKNCRmHVvEglvjYVMUN822BFYECABxfZU,14037
8
- nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
8
+ nshtrainer/callbacks/__init__.py,sha256=6l2vrFywWftzKTlZMEkF-WgE5uLjLgX89BoUMq8_x-0,3980
9
9
  nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
10
10
  nshtrainer/callbacks/base.py,sha256=K9aom1WVVRYxl-tHWgtmDUQZ1o63NgznvLsjauTKcCc,4225
11
11
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
12
- nshtrainer/callbacks/checkpoint/_base.py,sha256=BjgfCXsf4Ihf1MNKkHBUwjHMLwc04PZO-2Bx-LdAazg,11010
12
+ nshtrainer/callbacks/checkpoint/_base.py,sha256=AsNt1bZ-yloPHqenRy4KAJK5DDmhBY1RprR2_xbvomc,11010
13
13
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7hb96BxHanmvn7PCCbqq0E,2648
14
14
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
15
15
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
@@ -25,14 +25,14 @@ nshtrainer/callbacks/log_epoch.py,sha256=C2yUww8lAuCX-dy06tsw95yCBOfFd2mfGs0VhrE
25
25
  nshtrainer/callbacks/lr_monitor.py,sha256=v45ehnwNO987087HfiOY5aIrVRbwdKMgPYRFHs1fyEE,1444
26
26
  nshtrainer/callbacks/metric_validation.py,sha256=4RDr1FuNKfro-6QEtmcFqT4iNf2twmJVNk9y-8nq9bg,2882
27
27
  nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
28
- nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
29
- nshtrainer/callbacks/rlp_sanity_checks.py,sha256=Df9Prq2QKXnaeMBIvMQBhDhJTDeru5UbiuXJOJR16Gk,10050
28
+ nshtrainer/callbacks/print_table.py,sha256=xdDvogpLFHdaHM4yDGENvJUX4Gz4hDq-QpsPcv-Oqi8,3041
29
+ nshtrainer/callbacks/rlp_sanity_checks.py,sha256=PRtcj9K9fa2Oh6nbKQJR6w2__on0Jln969qZXlnkv1Q,10064
30
30
  nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGClU4t5kLt8XrY,3076
31
31
  nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
32
32
  nshtrainer/callbacks/wandb_upload_code.py,sha256=4X-mpiX5ghj9vnEreK2i8Xyvimqt0K-PNWA2HtT-B6I,1940
33
33
  nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
34
34
  nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
35
- nshtrainer/configs/__init__.py,sha256=-yJ5Uk9VkANqfk-QnX2aynL0jSf7cJQuQNzT1GAE1x8,15684
35
+ nshtrainer/configs/__init__.py,sha256=ZHV_1zCZKUYBKzWiLPrF8eFKsb-gepAF4G7AmsCxkkA,15623
36
36
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
37
37
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
38
38
  nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
@@ -77,7 +77,7 @@ nshtrainer/configs/nn/__init__.py,sha256=Ms2gIqbRxNVm6GHKCddCJTTqMwUPifjjHD_fCfJ
77
77
  nshtrainer/configs/nn/mlp/__init__.py,sha256=O6kQ6utZNJPG9Fax5pRdZcHa3J-XFKKdXcc_PQg0jk0,347
78
78
  nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=LCTbTyelCMABVw505CGQ4UpEGlAnIhflSLFqwAQXLQA,2155
79
79
  nshtrainer/configs/nn/rng/__init__.py,sha256=4iC6vwxbfNeXyvpwZ1Z5Kcy-he4cu7mg3UpLD-RLrHc,141
80
- nshtrainer/configs/optimizer/__init__.py,sha256=8ztp5UD-edfzwF-qdJTeZwlv-YWJ5Sn230b9aWxJyQQ,1398
80
+ nshtrainer/configs/optimizer/__init__.py,sha256=Kq6ACztSQhwgE_tP4F1RI7nQMBgC1ebQmY3HaBYKbeg,1337
81
81
  nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
82
82
  nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
83
83
  nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
@@ -103,30 +103,30 @@ nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,2
103
103
  nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
104
104
  nshtrainer/data/datamodule.py,sha256=Rb4-mA8iXtjRlNUHcIqVPEvxA_VkiJXwN1EvHIsydJ0,4095
105
105
  nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
106
- nshtrainer/loggers/__init__.py,sha256=fI0OHEltHP4tZI-KFB3npdzoxm_M2QsEYKxY3um05_s,592
107
- nshtrainer/loggers/actsave.py,sha256=wgNrpBB6wQM7qff8iLDb_sQnbiAcYHRmH56pcEJPB3o,1409
106
+ nshtrainer/loggers/__init__.py,sha256=0fnclaEIgAUrRlYuSKfzni11dlJ6edllrs06NVmbtYc,567
107
+ nshtrainer/loggers/actsave.py,sha256=Xd21jaBVUmkxITKYfycWKZEgcHu1-dmi1H5EYEjvPDw,1503
108
108
  nshtrainer/loggers/base.py,sha256=ON92XbwTSgadQOSyw5PiRRFzyH6uJ-xLtE0nB3cbgPc,1205
109
109
  nshtrainer/loggers/csv.py,sha256=xJ8mSmw4vJwinIfqhF6t2HWmh_1dXEYyLfGuXwL7WHo,1160
110
110
  nshtrainer/loggers/tensorboard.py,sha256=E7iO_fDt9bfH02hBL430bXPLljOo5iGgq2QyPqmx2gQ,2324
111
- nshtrainer/loggers/wandb.py,sha256=KZXAUWrrmdX_L8rqej77oUHaM0JxZRM8y9z6JP9PISw,6856
111
+ nshtrainer/loggers/wandb.py,sha256=EK2rvJwmV-LxXIm21ZORNBI9nz-AXqo2mIN7xyjs8bc,6776
112
112
  nshtrainer/lr_scheduler/__init__.py,sha256=daMMK3erUcNXGGd_nZB8AWu3ZTYqfS1RSWeK4FV2udw,851
113
- nshtrainer/lr_scheduler/base.py,sha256=LE53JRBTuAlA1fqbMgCZ7m39D1z0rGj2TizhJ62CPvE,3756
113
+ nshtrainer/lr_scheduler/base.py,sha256=24I3PNlj2iYPoaHeD2_InMAptclCtMnZoD8nXWqxLYw,3740
114
114
  nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=MsoXgCcWTKsrkNZiGnKS6yC-slRuleuwFxeM_lmG_pQ,5560
115
115
  nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=irPyDjfUX843ze4bJM9sW8WSeEcU643QJ30JN2hz9Rc,3206
116
116
  nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
117
117
  nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
118
118
  nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
119
119
  nshtrainer/model/base.py,sha256=PvTmupfGahEZME0BWqbeErDPP1VOm2Nm9JxJkO8afcc,10815
120
- nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
120
+ nshtrainer/model/mixins/callback.py,sha256=-walDV3fxH4K-ezugvL__Tml9OP1WIlaaTT8j6mWxLI,2580
121
121
  nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
122
- nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
122
+ nshtrainer/model/mixins/logger.py,sha256=-D4YwSg0eTDtXj3N288FEo6rqsZ518u1aMBE4Dv4tmg,11708
123
123
  nshtrainer/nn/__init__.py,sha256=Vd246v2N9tBQ8XxmTquWzj5lAmeSnngrjpYOfp4LTXM,1499
124
124
  nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
125
- nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,2365
126
- nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
127
- nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
125
+ nshtrainer/nn/module_dict.py,sha256=FJrxUgQkY6O6tmA_7I_kRoPvxLtPU3cYZY-42InVG3A,2366
126
+ nshtrainer/nn/module_list.py,sha256=xvoF8F-pG-z3MnYc91anG9vUQFVes6niy8J8J0qVAlg,2091
127
+ nshtrainer/nn/nonlinearity.py,sha256=UhAsc8o_6AIsos6sktUWC9xLFCFgHJn5WiurKN1sf5U,6493
128
128
  nshtrainer/nn/rng.py,sha256=IJGvX9v8qBkfgBrMlNU2aj-MbYTPoncFyJzvPkzCQpM,512
129
- nshtrainer/optimizer.py,sha256=8pjOny7NxIt04PXxn3zOyJ2soL7nmj8yBVV82r_tNsc,17522
129
+ nshtrainer/optimizer.py,sha256=hvw_UNovYgLHhDvMr9BbUz3EPOIrGZDz9ir8lvCgiw0,17458
130
130
  nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
131
131
  nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
132
132
  nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
@@ -137,8 +137,8 @@ nshtrainer/trainer/_config.py,sha256=GL8DtuH-6x2aHcRlEcmzyhEBMRRldiSazNAeNmPw7gM
137
137
  nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
138
138
  nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
139
139
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
140
- nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
141
- nshtrainer/trainer/plugin/__init__.py,sha256=q_q98MYNaZ2VE_tqGqYlQjQnlaF4NE1FUqVVbj0EK7k,517
140
+ nshtrainer/trainer/accelerator.py,sha256=rWfSJ-pQsLREaRPF_rRXsqxvaQQf6XGT6zpNt829jk0,2390
141
+ nshtrainer/trainer/plugin/__init__.py,sha256=LSxEK0vnoN9WkU8MDIetrVrDPLCowGLoc9cvh6RG6gg,492
142
142
  nshtrainer/trainer/plugin/base.py,sha256=76ct2TTHLpPr5MO8B9CIkoCOo-dFImzqAll8cIdC0cg,736
143
143
  nshtrainer/trainer/plugin/environment.py,sha256=SSXRWHjyFUA6oFx3duD_ZwhM59pWUjR1_UzHz02NI2c,5440
144
144
  nshtrainer/trainer/plugin/io.py,sha256=OmFSKLloMypletjaUr_Ptg6LS0ljqTVIp2o4Hm3eZoE,1926
@@ -149,7 +149,7 @@ nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTu
149
149
  nshtrainer/trainer/trainer.py,sha256=G_tHqzZCHJazhROcoKeOI5rZ5A8F8XlghiIWkdMbPR0,24387
150
150
  nshtrainer/util/_environment_info.py,sha256=j-wyEHKirsu3rIXTtqC2kLmIIkRe6obWjxPVWaqg2ow,24887
151
151
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
152
- nshtrainer/util/code_upload.py,sha256=CpbZEBbA8EcBElUVoCPbP5zdwtNzJhS20RLaOB-q-2k,1257
152
+ nshtrainer/util/code_upload.py,sha256=o0GKWROL5EUvJ2F-eOr9ag6R588ZbgG8HX37fvEMfgY,1241
153
153
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
154
154
  nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
155
155
  nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
@@ -159,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
159
159
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
160
160
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
161
161
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
162
- nshtrainer-1.4.1.dist-info/METADATA,sha256=QL69Trcmw3NF3UOovpqVJbzBTtHJtnDDxAzxyj9EX24,980
163
- nshtrainer-1.4.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
164
- nshtrainer-1.4.1.dist-info/RECORD,,
162
+ nshtrainer-1.5.0.dist-info/METADATA,sha256=fbtia7kDnNxHx_8VE0I-zFtmlF-HMAxH5raSiPjtl7w,980
163
+ nshtrainer-1.5.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
164
+ nshtrainer-1.5.0.dist-info/RECORD,,