nshtrainer 1.1.1b1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
nshtrainer/_directory.py CHANGED
@@ -65,9 +65,9 @@ class DirectoryConfig(C.Config):
     ) -> Path:
         # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
         if (subdir := getattr(self, subdirectory, None)) is not None:
-            assert isinstance(
-                subdir, Path
-            ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            assert isinstance(subdir, Path), (
+                f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            )
             return subdir
 
         dir = self.resolve_run_root_directory(run_id)
@@ -23,6 +23,12 @@ from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
 from .directory_setup import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from .distributed_prediction_writer import (
+    DistributedPredictionWriter as DistributedPredictionWriter,
+)
+from .distributed_prediction_writer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from .early_stopping import EarlyStoppingCallback as EarlyStoppingCallback
 from .early_stopping import EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig
 from .ema import EMACallback as EMACallback
@@ -23,6 +23,10 @@ class CallbackMetadataConfig(TypedDict, total=False):
     """Priority of the callback. Callbacks with higher priority will be loaded first.
     Default is `0`."""
 
+    enabled_for_barebones: bool
+    """Whether this callback is enabled for barebones mode.
+    Default is `False`."""
+
 
 @dataclass(frozen=True)
 class CallbackWithMetadata:
@@ -91,10 +95,20 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
 
 
 def _process_and_filter_callbacks(
+    trainer_config: TrainerConfig,
     callbacks: Iterable[CallbackWithMetadata],
 ) -> list[Callback]:
     callbacks = list(callbacks)
 
+    # If we're in barebones mode, used the callback metadata
+    # to decide to keep/remove the callback.
+    if trainer_config.barebones:
+        callbacks = [
+            callback
+            for callback in callbacks
+            if callback.metadata.get("enabled_for_barebones", False)
+        ]
+
     # Sort by priority (higher priority first)
     callbacks.sort(
         key=lambda callback: callback.metadata.get("priority", 0),
@@ -114,9 +128,14 @@ def resolve_all_callbacks(trainer_config: TrainerConfig):
         if config is not None
     ]
     callbacks = _process_and_filter_callbacks(
-        callback
-        for callback_config in callback_configs
-        for callback in _create_callbacks_with_metadata(callback_config, trainer_config)
+        trainer_config,
+        (
+            callback
+            for callback_config in callback_configs
+            for callback in _create_callbacks_with_metadata(
+                callback_config, trainer_config
+            )
+        ),
     )
     return callbacks
 
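With this change, callbacks opt in to barebones mode through the `enabled_for_barebones` metadata key instead of being dropped wholesale. A minimal sketch of how a custom callback config might opt in, following the registration pattern in `nshtrainer/callbacks/base.py`; the `MyTimerCallbackConfig` name and the `Timer` callback are illustrative, not part of the package:

```python
# Hypothetical sketch: a callback config that stays active in barebones mode.
from typing import ClassVar, Literal

from nshtrainer.callbacks.base import (
    CallbackConfigBase,
    CallbackMetadataConfig,
    callback_registry,
)


@callback_registry.register
class MyTimerCallbackConfig(CallbackConfigBase):
    # Keep this callback even when trainer_config.barebones is True.
    metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig(
        enabled_for_barebones=True
    )

    name: Literal["my_timer"] = "my_timer"

    def create_callbacks(self, trainer_config):
        from lightning.pytorch.callbacks import Timer

        yield Timer()  # any Lightning Callback instance
```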
@@ -0,0 +1,166 @@
+from __future__ import annotations
+
+import functools
+import logging
+from collections.abc import Iterator, Sequence
+from pathlib import Path
+from typing import Any, ClassVar, Literal, overload
+
+import torch
+from lightning.fabric.utilities.apply_func import move_data_to_device
+from lightning.pytorch.callbacks import BasePredictionWriter
+from typing_extensions import final, override
+
+from .base import CallbackConfigBase, CallbackMetadataConfig, callback_registry
+
+log = logging.getLogger(__name__)
+
+
+@final
+@callback_registry.register
+class DistributedPredictionWriterConfig(CallbackConfigBase):
+    metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig(
+        enabled_for_barebones=True
+    )
+    """Metadata for the callback."""
+
+    name: Literal["distributed_prediction_writer"] = "distributed_prediction_writer"
+
+    dirpath: Path | None = None
+    """Directory to save the predictions to. If None, will use the default directory."""
+
+    move_to_cpu_on_save: bool = True
+    """Whether to move the predictions to CPU before saving. Default is True."""
+
+    save_raw: bool = True
+    """Whether to save the raw predictions."""
+
+    save_processed: bool = True
+    """Whether to process and save the predictions.
+
+    "Processing" means that the model's batched predictions are split into individual predictions
+    and saved as a list of tensors.
+    """
+
+    @override
+    def create_callbacks(self, trainer_config):
+        if (dirpath := self.dirpath) is None:
+            dirpath = trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "predictions"
+            )
+
+        yield DistributedPredictionWriter(self, dirpath)
+
+
+def _move_and_save(data, path: Path, move_to_cpu: bool):
+    if move_to_cpu:
+        data = move_data_to_device(data, "cpu")
+
+    # Save the data to the specified path
+    torch.save(data, path)
+
+
+class DistributedPredictionWriter(BasePredictionWriter):
+    def __init__(
+        self,
+        config: DistributedPredictionWriterConfig,
+        output_dir: Path,
+    ):
+        self.config = config
+
+        super().__init__(write_interval="batch")
+
+        self.output_dir = output_dir
+
+    @override
+    def write_on_batch_end(
+        self,
+        trainer,
+        pl_module,
+        prediction,
+        batch_indices,
+        batch,
+        batch_idx,
+        dataloader_idx,
+    ):
+        save = functools.partial(
+            _move_and_save,
+            move_to_cpu=self.config.move_to_cpu_on_save,
+        )
+
+        # Regular, unstructured writing.
+        if self.config.save_raw:
+            output_dir = (
+                self.output_dir
+                / "raw"
+                / f"dataloader_{dataloader_idx}"
+                / f"rank_{trainer.global_rank}"
+                / f"batch_{batch_idx}"
+            )
+            output_dir.mkdir(parents=True, exist_ok=True)
+            save(prediction, output_dir / "predictions.pt")
+            save(batch, output_dir / "batch.pt")
+            save(batch_indices, output_dir / "batch_indices.pt")
+
+        if self.config.save_processed:
+            # Processed writing.
+            from ..model.base import LightningModuleBase
+
+            if not isinstance(pl_module, LightningModuleBase):
+                raise ValueError(
+                    "The model must be a subclass of LightningModuleBase to use the distributed prediction writer."
+                )
+
+            output_dir = self.output_dir / "processed" / f"dataloader_{dataloader_idx}"
+            output_dir.mkdir(parents=True, exist_ok=True)
+
+            # Split into individual predictions
+            assert batch_indices is not None, (
+                "Batch indices must be provided for processed writing."
+            )
+            for sample in pl_module.split_batched_predictions(
+                batch, prediction, batch_indices
+            ):
+                sample = {
+                    **sample,
+                    "global_rank": trainer.global_rank,
+                    "world_size": trainer.world_size,
+                    "is_global_zero": trainer.is_global_zero,
+                }
+                save(sample, output_dir / f"{sample['index']}.pt")
+
+
+class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
+    def __init__(self, output_dir: Path):
+        self.output_dir = output_dir
+
+    @override
+    def __len__(self) -> int:
+        return len(list(self.output_dir.glob("*.pt")))
+
+    @overload
+    def __getitem__(self, index: int) -> tuple[Any, Any]: ...
+
+    @overload
+    def __getitem__(self, index: slice) -> list[tuple[Any, Any]]: ...
+
+    @override
+    def __getitem__(
+        self, index: int | slice
+    ) -> tuple[Any, Any] | list[tuple[Any, Any]]:
+        if isinstance(index, slice):
+            # Handle slice indexing
+            indices = range(*index.indices(len(self)))
+            return [self.__getitem__(i) for i in indices]
+
+        # Handle integer indexing
+        path = self.output_dir / f"{index}.pt"
+        if not path.exists():
+            raise FileNotFoundError(f"File {path} does not exist.")
+        sample = torch.load(path)
+        return sample["batch"], sample["prediction"]
+
+    @override
+    def __iter__(self) -> Iterator[tuple[Any, Any]]:
+        for i in range(len(self)):
+            yield self[i]
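The writer above lays raw output out as `raw/dataloader_*/rank_*/batch_*/` and, for processed output, writes one `<dataset_index>.pt` file per sample under `processed/dataloader_*/`. A minimal sketch of reading those per-sample files back with the `DistributedPredictionReader` defined in the same module; the concrete base path is an assumption that follows the default `nshtrainer/{id}/predictions/` directory layout:

```python
# Hypothetical sketch: consuming processed per-sample predictions after a run.
from pathlib import Path

from nshtrainer.callbacks.distributed_prediction_writer import (
    DistributedPredictionReader,
)

# Assumed location; depends on how the run's directories were configured.
predictions_dir = Path("nshtrainer/<run_id>/predictions/processed/dataloader_0")
reader = DistributedPredictionReader(predictions_dir)

for batch, prediction in reader:  # files are keyed by dataset index on disk
    ...  # post-process each (batch, prediction) pair
```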
@@ -21,6 +21,9 @@ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackCon
 from nshtrainer.callbacks import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.callbacks import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -95,9 +98,21 @@ from nshtrainer.nn.nonlinearity import (
     SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
 )
 from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
+from nshtrainer.optimizer import AdadeltaConfig as AdadeltaConfig
+from nshtrainer.optimizer import AdafactorConfig as AdafactorConfig
+from nshtrainer.optimizer import AdagradConfig as AdagradConfig
+from nshtrainer.optimizer import AdamaxConfig as AdamaxConfig
+from nshtrainer.optimizer import AdamConfig as AdamConfig
 from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+from nshtrainer.optimizer import ASGDConfig as ASGDConfig
+from nshtrainer.optimizer import NAdamConfig as NAdamConfig
 from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
 from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import RAdamConfig as RAdamConfig
+from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
+from nshtrainer.optimizer import RpropConfig as RpropConfig
+from nshtrainer.optimizer import SGDConfig as SGDConfig
+from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
 from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
@@ -225,11 +240,17 @@ from . import trainer as trainer
 from . import util as util
 
 __all__ = [
+    "ASGDConfig",
     "AcceleratorConfig",
    "AcceleratorConfigBase",
     "ActSaveConfig",
     "ActSaveLoggerConfig",
+    "AdadeltaConfig",
+    "AdafactorConfig",
+    "AdagradConfig",
+    "AdamConfig",
     "AdamWConfig",
+    "AdamaxConfig",
     "AdvancedProfilerConfig",
     "AsyncCheckpointIOPlugin",
     "BaseCheckpointCallbackConfig",
@@ -249,6 +270,7 @@ __all__ = [
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
     "DirectorySetupCallbackConfig",
+    "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "DurationConfig",
     "ELUNonlinearityConfig",
@@ -294,6 +316,7 @@ __all__ = [
     "MetricValidationCallbackConfig",
     "MishNonlinearityConfig",
     "MixedPrecisionPluginConfig",
+    "NAdamConfig",
     "NonlinearityConfig",
     "NonlinearityConfigBase",
     "NormLoggingCallbackConfig",
@@ -306,10 +329,14 @@ __all__ = [
     "PrintTableMetricsCallbackConfig",
     "ProfilerConfig",
     "PyTorchProfilerConfig",
+    "RAdamConfig",
     "RLPSanityChecksCallbackConfig",
+    "RMSpropConfig",
     "RNGConfig",
     "ReLUNonlinearityConfig",
     "ReduceLROnPlateauConfig",
+    "RpropConfig",
+    "SGDConfig",
     "SLURMEnvironmentPlugin",
     "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
@@ -331,6 +358,7 @@ __all__ = [
     "TorchSyncBatchNormPlugin",
     "TrainerConfig",
     "TransformerEnginePluginConfig",
+    "Union",
     "WandbLoggerConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",
@@ -12,6 +12,9 @@ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackCon
 from nshtrainer.callbacks import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.callbacks import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -62,6 +65,7 @@ from . import base as base
 from . import checkpoint as checkpoint
 from . import debug_flag as debug_flag
 from . import directory_setup as directory_setup
+from . import distributed_prediction_writer as distributed_prediction_writer
 from . import early_stopping as early_stopping
 from . import ema as ema
 from . import finite_checks as finite_checks
@@ -86,6 +90,7 @@ __all__ = [
     "CheckpointMetadata",
     "DebugFlagCallbackConfig",
     "DirectorySetupCallbackConfig",
+    "DistributedPredictionWriterConfig",
     "EMACallbackConfig",
     "EarlyStoppingCallbackConfig",
     "EpochTimerCallbackConfig",
@@ -109,6 +114,7 @@ __all__ = [
     "checkpoint",
     "debug_flag",
     "directory_setup",
+    "distributed_prediction_writer",
     "early_stopping",
     "ema",
     "finite_checks",
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    CallbackConfigBase as CallbackConfigBase,
+)
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    callback_registry as callback_registry,
+)
+
+__all__ = [
+    "CallbackConfigBase",
+    "DistributedPredictionWriterConfig",
+    "callback_registry",
+]
@@ -2,14 +2,38 @@ from __future__ import annotations
 
 __codegen__ = True
 
+from nshtrainer.optimizer import AdadeltaConfig as AdadeltaConfig
+from nshtrainer.optimizer import AdafactorConfig as AdafactorConfig
+from nshtrainer.optimizer import AdagradConfig as AdagradConfig
+from nshtrainer.optimizer import AdamaxConfig as AdamaxConfig
+from nshtrainer.optimizer import AdamConfig as AdamConfig
 from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+from nshtrainer.optimizer import ASGDConfig as ASGDConfig
+from nshtrainer.optimizer import NAdamConfig as NAdamConfig
 from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
 from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import RAdamConfig as RAdamConfig
+from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
+from nshtrainer.optimizer import RpropConfig as RpropConfig
+from nshtrainer.optimizer import SGDConfig as SGDConfig
+from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 
 __all__ = [
+    "ASGDConfig",
+    "AdadeltaConfig",
+    "AdafactorConfig",
+    "AdagradConfig",
+    "AdamConfig",
     "AdamWConfig",
+    "AdamaxConfig",
+    "NAdamConfig",
     "OptimizerConfig",
     "OptimizerConfigBase",
+    "RAdamConfig",
+    "RMSpropConfig",
+    "RpropConfig",
+    "SGDConfig",
+    "Union",
     "optimizer_registry",
 ]
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -149,6 +152,7 @@ __all__ = [
     "DebugFlagCallbackConfig",
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
+    "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -70,6 +73,7 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
+    "DistributedPredictionWriterConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
nshtrainer/model/base.py CHANGED
@@ -2,9 +2,9 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Mapping
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from pathlib import Path
-from typing import Any, Generic, Literal, cast
+from typing import Any, Generic, Literal, TypedDict, cast
 
 import nshconfig as C
 import torch
@@ -53,6 +53,47 @@ VALID_REDUCE_OPS = (
 )
 
 
+class IndividualSample(TypedDict):
+    """
+    A dictionary that contains the individual sample.
+    This is used to split the batched predictions into individual predictions.
+    """
+
+    index: int
+    """The index of the sample in the batch."""
+
+    batch: Any
+    """The batch to split."""
+
+    prediction: Any
+    """The batched prediction to split."""
+
+
+def default_split_batched_predictions(
+    batch: Any,
+    prediction: Any,
+    batch_indices: Sequence[Any],
+) -> Iterable[IndividualSample]:
+    """
+    Splits the batched predictions into a list of individual predictions.
+    Args:
+        batch: The batch to split.
+        prediction: The batched prediction to split.
+        batch_indices: The indices of the batches.
+    Returns:
+        A tuple of two sequences: the corresponding batches and the individual predictions.
+    """
+    import torch.utils._pytree as tree
+
+    for sample_idx, batch_idx in enumerate(batch_indices):
+        # Create a dictionary for each sample
+        yield IndividualSample(
+            index=batch_idx,
+            batch=tree.tree_map(lambda x: x[sample_idx], batch),
+            prediction=tree.tree_map(lambda x: x[sample_idx], prediction),
+        )
+
+
 class LightningModuleBase(
     DebugModuleMixin,
     RLPSanityCheckModuleMixin,
@@ -171,6 +212,23 @@ class LightningModuleBase(
         loss = cast(torch.Tensor, loss)
         return loss
 
+    def split_batched_predictions(
+        self,
+        batch: Any,
+        prediction: Any,
+        batch_indices: Sequence[Any],
+    ) -> Iterable[IndividualSample]:
+        """
+        Splits the batched predictions into a list of individual predictions.
+        Args:
+            batch: The batch to split.
+            prediction: The batched prediction to split.
+            batch_indices: The indices of the batches.
+        Returns:
+            A tuple of two sequences: the corresponding batches and the individual predictions.
+        """
+        return default_split_batched_predictions(batch, prediction, batch_indices)
+
     @override
     @classmethod
     def load_from_checkpoint(cls, *args, **kwargs) -> Never:
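A toy illustration (not from the package) of what `default_split_batched_predictions` yields for a small batched pytree; the tensor shapes and dictionary keys here are made up:

```python
# Illustrative sketch of the per-sample splitting added in model/base.py above.
import torch

from nshtrainer.model.base import default_split_batched_predictions

batch = {"x": torch.randn(4, 3), "y": torch.arange(4)}
prediction = torch.randn(4, 2)
batch_indices = [10, 11, 12, 13]  # dataset indices of the 4 samples in this batch

for sample in default_split_batched_predictions(batch, prediction, batch_indices):
    # Each sample is an IndividualSample dict: "index" is the dataset index, while
    # "batch"/"prediction" hold the sample's slice of every leaf tensor.
    print(sample["index"], sample["batch"]["x"].shape, sample["prediction"].shape)
    # -> 10 torch.Size([3]) torch.Size([2]), and so on for 11, 12, 13
```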
nshtrainer/optimizer.py CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from collections.abc import Iterable
5
- from typing import Annotated, Any, Literal
5
+ from typing import Annotated, Any, Literal, Tuple, Union
6
6
 
7
7
  import nshconfig as C
8
8
  import torch.nn as nn
9
+ from torch import Tensor
9
10
  from torch.optim import Optimizer
10
11
  from typing_extensions import TypeAliasType, final, override
11
12
 
@@ -45,6 +46,18 @@ class AdamWConfig(OptimizerConfigBase):
45
46
  amsgrad: bool = False
46
47
  """Whether to use the AMSGrad variant of this algorithm."""
47
48
 
49
+ maximize: bool = False
50
+ """Maximize the objective with respect to the params, instead of minimizing."""
51
+
52
+ foreach: bool | None = None
53
+ """Whether foreach implementation of optimizer is used."""
54
+
55
+ capturable: bool = False
56
+ """Whether this instance is safe to capture in a CUDA graph."""
57
+
58
+ differentiable: bool = False
59
+ """Whether autograd should occur through the optimizer step in training."""
60
+
48
61
  @override
49
62
  def create_optimizer(
50
63
  self,
@@ -59,6 +72,551 @@ class AdamWConfig(OptimizerConfigBase):
59
72
  betas=self.betas,
60
73
  eps=self.eps,
61
74
  amsgrad=self.amsgrad,
75
+ maximize=self.maximize,
76
+ foreach=self.foreach,
77
+ capturable=self.capturable,
78
+ differentiable=self.differentiable,
79
+ )
80
+
81
+
82
+ @final
83
+ @optimizer_registry.register
84
+ class AdafactorConfig(OptimizerConfigBase):
85
+ name: Literal["adafactor"] = "adafactor"
86
+ lr: float
87
+ """Learning rate for the optimizer. If None, uses relative step size."""
88
+
89
+ eps1: float | None = None
90
+ """Term added to the denominator to improve numerical stability (default: None)."""
91
+
92
+ eps2: float = 1e-3
93
+ """Term added to the denominator to improve numerical stability (default: 1e-3)."""
94
+
95
+ beta2_decay: float = -0.8
96
+ """Coefficient used for computing running averages of square gradient (default: -0.8)."""
97
+
98
+ weight_decay: float = 0.0
99
+ """Weight decay (L2 penalty) (default: 0.0)."""
100
+
101
+ maximize: bool = False
102
+ """Maximize the params based on the objective, instead of minimizing."""
103
+
104
+ @override
105
+ def create_optimizer(
106
+ self,
107
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
108
+ ):
109
+ from torch.optim import Adafactor
110
+
111
+ return Adafactor(
112
+ parameters,
113
+ lr=self.lr,
114
+ eps=(self.eps1, self.eps2),
115
+ beta2_decay=self.beta2_decay,
116
+ weight_decay=self.weight_decay,
117
+ maximize=self.maximize,
118
+ )
119
+
120
+
121
+ @final
122
+ @optimizer_registry.register
123
+ class AdadeltaConfig(OptimizerConfigBase):
124
+ name: Literal["adadelta"] = "adadelta"
125
+
126
+ lr: float
127
+ """Learning rate for the optimizer."""
128
+
129
+ rho: float = 0.9
130
+ """Coefficient used for computing a running average of squared gradients."""
131
+
132
+ eps: float = 1e-6
133
+ """Term added to the denominator to improve numerical stability."""
134
+
135
+ weight_decay: float = 0.0
136
+ """Weight decay (L2 penalty) for the optimizer."""
137
+
138
+ maximize: bool = False
139
+ """Maximize the params based on the objective, instead of minimizing."""
140
+
141
+ foreach: bool | None = None
142
+ """Whether foreach implementation of optimizer is used."""
143
+
144
+ capturable: bool = False
145
+ """Whether this instance is safe to capture in a CUDA graph."""
146
+
147
+ differentiable: bool = False
148
+ """Whether autograd should occur through the optimizer step in training."""
149
+
150
+ @override
151
+ def create_optimizer(
152
+ self,
153
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
154
+ ):
155
+ from torch.optim import Adadelta
156
+
157
+ return Adadelta(
158
+ parameters,
159
+ lr=self.lr,
160
+ rho=self.rho,
161
+ eps=self.eps,
162
+ weight_decay=self.weight_decay,
163
+ maximize=self.maximize,
164
+ foreach=self.foreach,
165
+ capturable=self.capturable,
166
+ differentiable=self.differentiable,
167
+ )
168
+
169
+
170
+ @final
171
+ @optimizer_registry.register
172
+ class AdagradConfig(OptimizerConfigBase):
173
+ name: Literal["adagrad"] = "adagrad"
174
+
175
+ lr: float
176
+ """Learning rate for the optimizer."""
177
+
178
+ lr_decay: float = 0.0
179
+ """Learning rate decay."""
180
+
181
+ weight_decay: float = 0.0
182
+ """Weight decay (L2 penalty) for the optimizer."""
183
+
184
+ initial_accumulator_value: float = 0.0
185
+ """Initial value for the accumulator."""
186
+
187
+ eps: float = 1e-10
188
+ """Term added to the denominator to improve numerical stability."""
189
+
190
+ maximize: bool = False
191
+ """Maximize the params based on the objective, instead of minimizing."""
192
+
193
+ foreach: bool | None = None
194
+ """Whether foreach implementation of optimizer is used."""
195
+
196
+ differentiable: bool = False
197
+ """Whether autograd should occur through the optimizer step in training."""
198
+
199
+ fused: bool | None = None
200
+ """Whether the fused implementation is used."""
201
+
202
+ @override
203
+ def create_optimizer(
204
+ self,
205
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
206
+ ):
207
+ from torch.optim import Adagrad
208
+
209
+ return Adagrad(
210
+ parameters,
211
+ lr=self.lr,
212
+ lr_decay=self.lr_decay,
213
+ weight_decay=self.weight_decay,
214
+ initial_accumulator_value=self.initial_accumulator_value,
215
+ eps=self.eps,
216
+ maximize=self.maximize,
217
+ foreach=self.foreach,
218
+ differentiable=self.differentiable,
219
+ fused=self.fused,
220
+ )
221
+
222
+
223
+ @final
224
+ @optimizer_registry.register
225
+ class AdamConfig(OptimizerConfigBase):
226
+ name: Literal["adam"] = "adam"
227
+
228
+ lr: float
229
+ """Learning rate for the optimizer."""
230
+
231
+ betas: tuple[float, float] = (0.9, 0.999)
232
+ """Coefficients used for computing running averages of gradient and its square."""
233
+
234
+ eps: float = 1e-8
235
+ """Term added to the denominator to improve numerical stability."""
236
+
237
+ weight_decay: float = 0.0
238
+ """Weight decay (L2 penalty) for the optimizer."""
239
+
240
+ amsgrad: bool = False
241
+ """Whether to use the AMSGrad variant of this algorithm."""
242
+
243
+ maximize: bool = False
244
+ """Maximize the params based on the objective, instead of minimizing."""
245
+
246
+ foreach: bool | None = None
247
+ """Whether foreach implementation of optimizer is used."""
248
+
249
+ capturable: bool = False
250
+ """Whether this instance is safe to capture in a CUDA graph."""
251
+
252
+ differentiable: bool = False
253
+ """Whether autograd should occur through the optimizer step in training."""
254
+
255
+ fused: bool | None = None
256
+ """Whether the fused implementation is used."""
257
+
258
+ @override
259
+ def create_optimizer(
260
+ self,
261
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
262
+ ):
263
+ from torch.optim import Adam
264
+
265
+ return Adam(
266
+ parameters,
267
+ lr=self.lr,
268
+ betas=self.betas,
269
+ eps=self.eps,
270
+ weight_decay=self.weight_decay,
271
+ amsgrad=self.amsgrad,
272
+ maximize=self.maximize,
273
+ foreach=self.foreach,
274
+ capturable=self.capturable,
275
+ differentiable=self.differentiable,
276
+ fused=self.fused,
277
+ )
278
+
279
+
280
+ @final
281
+ @optimizer_registry.register
282
+ class AdamaxConfig(OptimizerConfigBase):
283
+ name: Literal["adamax"] = "adamax"
284
+
285
+ lr: float
286
+ """Learning rate for the optimizer."""
287
+
288
+ betas: tuple[float, float] = (0.9, 0.999)
289
+ """Coefficients used for computing running averages of gradient and its square."""
290
+
291
+ eps: float = 1e-8
292
+ """Term added to the denominator to improve numerical stability."""
293
+
294
+ weight_decay: float = 0.0
295
+ """Weight decay (L2 penalty) for the optimizer."""
296
+
297
+ maximize: bool = False
298
+ """Maximize the params based on the objective, instead of minimizing."""
299
+
300
+ foreach: bool | None = None
301
+ """Whether foreach implementation of optimizer is used."""
302
+
303
+ capturable: bool = False
304
+ """Whether this instance is safe to capture in a CUDA graph."""
305
+
306
+ differentiable: bool = False
307
+ """Whether autograd should occur through the optimizer step in training."""
308
+
309
+ @override
310
+ def create_optimizer(
311
+ self,
312
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
313
+ ):
314
+ from torch.optim import Adamax
315
+
316
+ return Adamax(
317
+ parameters,
318
+ lr=self.lr,
319
+ betas=self.betas,
320
+ eps=self.eps,
321
+ weight_decay=self.weight_decay,
322
+ maximize=self.maximize,
323
+ foreach=self.foreach,
324
+ capturable=self.capturable,
325
+ differentiable=self.differentiable,
326
+ )
327
+
328
+
329
+ @final
330
+ @optimizer_registry.register
331
+ class ASGDConfig(OptimizerConfigBase):
332
+ name: Literal["asgd"] = "asgd"
333
+
334
+ lr: float
335
+ """Learning rate for the optimizer."""
336
+
337
+ lambd: float = 1e-4
338
+ """Decay term."""
339
+
340
+ alpha: float = 0.75
341
+ """Power for eta update."""
342
+
343
+ t0: float = 1e6
344
+ """Point at which to start averaging."""
345
+
346
+ weight_decay: float = 0.0
347
+ """Weight decay (L2 penalty) for the optimizer."""
348
+
349
+ maximize: bool = False
350
+ """Maximize the params based on the objective, instead of minimizing."""
351
+
352
+ @override
353
+ def create_optimizer(
354
+ self,
355
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
356
+ ):
357
+ from torch.optim import ASGD
358
+
359
+ return ASGD(
360
+ parameters,
361
+ lr=self.lr,
362
+ lambd=self.lambd,
363
+ alpha=self.alpha,
364
+ t0=self.t0,
365
+ weight_decay=self.weight_decay,
366
+ maximize=self.maximize,
367
+ )
368
+
369
+
370
+ @final
371
+ @optimizer_registry.register
372
+ class NAdamConfig(OptimizerConfigBase):
373
+ name: Literal["nadam"] = "nadam"
374
+
375
+ lr: float
376
+ """Learning rate for the optimizer."""
377
+
378
+ betas: tuple[float, float] = (0.9, 0.999)
379
+ """Coefficients used for computing running averages of gradient and its square."""
380
+
381
+ eps: float = 1e-8
382
+ """Term added to the denominator to improve numerical stability."""
383
+
384
+ weight_decay: float = 0.0
385
+ """Weight decay (L2 penalty) for the optimizer."""
386
+
387
+ momentum_decay: float = 4e-3
388
+ """Momentum decay."""
389
+
390
+ decoupled_weight_decay: bool = False
391
+ """Whether to use decoupled weight decay."""
392
+
393
+ maximize: bool = False
394
+ """Maximize the params based on the objective, instead of minimizing."""
395
+
396
+ foreach: bool | None = None
397
+ """Whether foreach implementation of optimizer is used."""
398
+
399
+ capturable: bool = False
400
+ """Whether this instance is safe to capture in a CUDA graph."""
401
+
402
+ differentiable: bool = False
403
+ """Whether autograd should occur through the optimizer step in training."""
404
+
405
+ @override
406
+ def create_optimizer(
407
+ self,
408
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
409
+ ):
410
+ from torch.optim import NAdam
411
+
412
+ return NAdam(
413
+ parameters,
414
+ lr=self.lr,
415
+ betas=self.betas,
416
+ eps=self.eps,
417
+ weight_decay=self.weight_decay,
418
+ momentum_decay=self.momentum_decay,
419
+ decoupled_weight_decay=self.decoupled_weight_decay,
420
+ maximize=self.maximize,
421
+ foreach=self.foreach,
422
+ capturable=self.capturable,
423
+ differentiable=self.differentiable,
424
+ )
425
+
426
+
427
+ @final
428
+ @optimizer_registry.register
429
+ class RAdamConfig(OptimizerConfigBase):
430
+ name: Literal["radam"] = "radam"
431
+
432
+ lr: float
433
+ """Learning rate for the optimizer."""
434
+
435
+ betas: tuple[float, float] = (0.9, 0.999)
436
+ """Coefficients used for computing running averages of gradient and its square."""
437
+
438
+ eps: float = 1e-8
439
+ """Term added to the denominator to improve numerical stability."""
440
+
441
+ weight_decay: float = 0.0
442
+ """Weight decay (L2 penalty) for the optimizer."""
443
+
444
+ decoupled_weight_decay: bool = False
445
+ """Whether to use decoupled weight decay."""
446
+
447
+ maximize: bool = False
448
+ """Maximize the params based on the objective, instead of minimizing."""
449
+
450
+ foreach: bool | None = None
451
+ """Whether foreach implementation of optimizer is used."""
452
+
453
+ capturable: bool = False
454
+ """Whether this instance is safe to capture in a CUDA graph."""
455
+
456
+ differentiable: bool = False
457
+ """Whether autograd should occur through the optimizer step in training."""
458
+
459
+ @override
460
+ def create_optimizer(
461
+ self,
462
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
463
+ ):
464
+ from torch.optim import RAdam
465
+
466
+ return RAdam(
467
+ parameters,
468
+ lr=self.lr,
469
+ betas=self.betas,
470
+ eps=self.eps,
471
+ weight_decay=self.weight_decay,
472
+ decoupled_weight_decay=self.decoupled_weight_decay,
473
+ maximize=self.maximize,
474
+ foreach=self.foreach,
475
+ capturable=self.capturable,
476
+ differentiable=self.differentiable,
477
+ )
478
+
479
+
480
+ @final
481
+ @optimizer_registry.register
482
+ class RMSpropConfig(OptimizerConfigBase):
483
+ name: Literal["rmsprop"] = "rmsprop"
484
+
485
+ lr: float
486
+ """Learning rate for the optimizer."""
487
+
488
+ alpha: float = 0.99
489
+ """Smoothing constant."""
490
+
491
+ eps: float = 1e-8
492
+ """Term added to the denominator to improve numerical stability."""
493
+
494
+ weight_decay: float = 0.0
495
+ """Weight decay (L2 penalty) for the optimizer."""
496
+
497
+ momentum: float = 0.0
498
+ """Momentum factor."""
499
+
500
+ centered: bool = False
501
+ """If True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance."""
502
+
503
+ maximize: bool = False
504
+ """Maximize the params based on the objective, instead of minimizing."""
505
+
506
+ foreach: bool | None = None
507
+ """Whether foreach implementation of optimizer is used."""
508
+
509
+ capturable: bool = False
510
+ """Whether this instance is safe to capture in a CUDA graph."""
511
+
512
+ differentiable: bool = False
513
+ """Whether autograd should occur through the optimizer step in training."""
514
+
515
+ @override
516
+ def create_optimizer(
517
+ self,
518
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
519
+ ):
520
+ from torch.optim import RMSprop
521
+
522
+ return RMSprop(
523
+ parameters,
524
+ lr=self.lr,
525
+ alpha=self.alpha,
526
+ eps=self.eps,
527
+ weight_decay=self.weight_decay,
528
+ momentum=self.momentum,
529
+ centered=self.centered,
530
+ maximize=self.maximize,
531
+ foreach=self.foreach,
532
+ capturable=self.capturable,
533
+ differentiable=self.differentiable,
534
+ )
535
+
536
+
537
+ @final
538
+ @optimizer_registry.register
539
+ class RpropConfig(OptimizerConfigBase):
540
+ name: Literal["rprop"] = "rprop"
541
+
542
+ lr: float
543
+ """Learning rate for the optimizer."""
544
+
545
+ etas: tuple[float, float] = (0.5, 1.2)
546
+ """Pair of (etaminus, etaplus), multiplicative increase and decrease factors."""
547
+
548
+ step_sizes: tuple[float, float] = (1e-6, 50.0)
549
+ """Pair of minimal and maximal allowed step sizes."""
550
+
551
+ maximize: bool = False
552
+ """Maximize the params based on the objective, instead of minimizing."""
553
+
554
+ @override
555
+ def create_optimizer(
556
+ self,
557
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
558
+ ):
559
+ from torch.optim import Rprop
560
+
561
+ return Rprop(
562
+ parameters,
563
+ lr=self.lr,
564
+ etas=self.etas,
565
+ step_sizes=self.step_sizes,
566
+ maximize=self.maximize,
567
+ )
568
+
569
+
570
+ @final
571
+ @optimizer_registry.register
572
+ class SGDConfig(OptimizerConfigBase):
573
+ name: Literal["sgd"] = "sgd"
574
+
575
+ lr: float
576
+ """Learning rate for the optimizer."""
577
+
578
+ momentum: float = 0.0
579
+ """Momentum factor."""
580
+
581
+ dampening: float = 0.0
582
+ """Dampening for momentum."""
583
+
584
+ weight_decay: float = 0.0
585
+ """Weight decay (L2 penalty) for the optimizer."""
586
+
587
+ nesterov: bool = False
588
+ """Enables Nesterov momentum."""
589
+
590
+ maximize: bool = False
591
+ """Maximize the params based on the objective, instead of minimizing."""
592
+
593
+ foreach: bool | None = None
594
+ """Whether foreach implementation of optimizer is used."""
595
+
596
+ differentiable: bool = False
597
+ """Whether autograd should occur through the optimizer step in training."""
598
+
599
+ fused: bool | None = None
600
+ """Whether the fused implementation is used."""
601
+
602
+ @override
603
+ def create_optimizer(
604
+ self,
605
+ parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
606
+ ):
607
+ from torch.optim import SGD
608
+
609
+ return SGD(
610
+ parameters,
611
+ lr=self.lr,
612
+ momentum=self.momentum,
613
+ dampening=self.dampening,
614
+ weight_decay=self.weight_decay,
615
+ nesterov=self.nesterov,
616
+ maximize=self.maximize,
617
+ foreach=self.foreach,
618
+ differentiable=self.differentiable,
619
+ fused=self.fused,
62
620
  )
63
621
 
64
622
 
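This release grows `nshtrainer/optimizer.py` from AdamW-only to config classes for most `torch.optim` optimizers, each registered with `optimizer_registry` and exposing `create_optimizer`. A minimal sketch of using two of them directly, with field names taken from the definitions above:

```python
# Minimal sketch of the new optimizer configs; lr is the only required field.
import torch.nn as nn

from nshtrainer.optimizer import AdamConfig, SGDConfig

model = nn.Linear(16, 1)

adam = AdamConfig(lr=1e-3, weight_decay=0.01).create_optimizer(model.parameters())
sgd = SGDConfig(lr=0.1, momentum=0.9, nesterov=True).create_optimizer(model.parameters())
```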
@@ -31,6 +31,7 @@ from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
+    DistributedPredictionWriterConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
     NormLoggingCallbackConfig,
@@ -701,6 +702,14 @@ class TrainerConfig(C.Config):
     auto_validate_metrics: MetricValidationCallbackConfig | None = None
     """If enabled, will automatically validate the metrics before starting the training routine."""
 
+    distributed_predict: DistributedPredictionWriterConfig | None = (
+        DistributedPredictionWriterConfig()
+    )
+    """If enabled, will use a custom BasePredictionWriter callback to automatically
+    handle distributed prediction. This is useful for running prediction on multiple GPUs
+    seamlessly.
+    """
+
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
@@ -752,10 +761,6 @@ class TrainerConfig(C.Config):
     )
 
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
-        # Disable all callbacks if barebones mode is enabled
-        if self.barebones:
-            return
-
         yield self.early_stopping
         yield self.checkpoint_saving
         yield self.lr_monitor
@@ -772,6 +777,7 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield self.auto_validate_metrics
+        yield self.distributed_predict
         yield from self.callbacks
 
     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
@@ -10,12 +10,16 @@ import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
-from lightning.pytorch import LightningModule
+from lightning.pytorch import LightningDataModule, LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
+from lightning.pytorch.utilities.types import (
+    _EVALUATE_OUTPUT,
+    _PREDICT_OUTPUT,
+    EVAL_DATALOADERS,
+)
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
 from .._checkpoint.metadata import write_checkpoint_metadata
@@ -532,3 +536,18 @@ class Trainer(LightningTrainer):
             update_hparams_dict=update_hparams_dict,
         )
         return cls(hparams)
+
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+    ):
+        self.predict(
+            model,
+            dataloaders,
+            datamodule,
+            return_predictions=False,
+            ckpt_path=ckpt_path,
+        )
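A hedged end-to-end sketch of the new distributed-prediction path: `Trainer.distributed_predict` wraps `predict(..., return_predictions=False)`, while the `distributed_predict` field on `TrainerConfig` wires in the writer callback. `MyModule` and `MyDataModule` are placeholders for user code, and constructing the trainer as `Trainer(config)` is an assumption, not something stated in this diff:

```python
# Hypothetical usage sketch for the new distributed prediction support.
from nshtrainer.callbacks import DistributedPredictionWriterConfig
from nshtrainer.trainer._config import TrainerConfig
from nshtrainer.trainer.trainer import Trainer

config = TrainerConfig(
    # Write only per-sample ("processed") outputs, moving them to CPU before saving.
    distributed_predict=DistributedPredictionWriterConfig(
        save_raw=False,
        move_to_cpu_on_save=True,
    ),
)

trainer = Trainer(config)             # assumed constructor form
module = MyModule(...)                # placeholder LightningModuleBase subclass
datamodule = MyDataModule(...)        # placeholder LightningDataModule

# Each rank's writer callback persists its share of the predictions to the run's
# "predictions" directory; nothing is gathered back into memory.
trainer.distributed_predict(module, datamodule=datamodule)
```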
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.1.1b1
+Version: 1.2.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -3,12 +3,12 @@ nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
3
3
  nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
5
5
  nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
6
- nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
6
+ nshtrainer/_directory.py,sha256=SuXJe9xJXZkDXWWfeOS9rEDz6vZUA6mpnEdkAW0ZQnY,3193
7
7
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
8
8
  nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
9
- nshtrainer/callbacks/__init__.py,sha256=w80d6PGNu3wjUj9NiRGMqCX9NnXD5ZlvbY-DIK4zjPE,3766
9
+ nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
10
10
  nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
11
- nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
11
+ nshtrainer/callbacks/base.py,sha256=K9aom1WVVRYxl-tHWgtmDUQZ1o63NgznvLsjauTKcCc,4225
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
13
13
  nshtrainer/callbacks/checkpoint/_base.py,sha256=f7lpk8W4xqxk3PolBEU3AWt9VTIpoLW7wMUhC5DNm3c,6345
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7hb96BxHanmvn7PCCbqq0E,2648
@@ -16,6 +16,7 @@ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
17
17
  nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
18
18
  nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
19
+ nshtrainer/callbacks/distributed_prediction_writer.py,sha256=OSh2C6XF7Nki4eFByNVhwlt69izkxnlmfPx54w4rvBo,5274
19
20
  nshtrainer/callbacks/early_stopping.py,sha256=rC_qYKCQWjRQJFo0ky46uG0aDJdYP8vsSlKunk0bUVI,4765
20
21
  nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
21
22
  nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
@@ -32,12 +33,12 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
32
33
  nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
33
34
  nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
34
35
  nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
35
- nshtrainer/configs/__init__.py,sha256=4WNs4Zv4PtHWD0KKH4X7j_zFt-COrEB0KhNIljsA6Rc,14740
36
+ nshtrainer/configs/__init__.py,sha256=KD3uClMwnA4LfQ7rY5phDdUbp3j8NoZfaGbGPbpaJVs,15848
36
37
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
37
38
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
38
39
  nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
39
40
  nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
40
- nshtrainer/configs/callbacks/__init__.py,sha256=PB3Jg-8_vMhp-mCFw2_Tqt05drKwHK6Ovl9mb8NNiXs,4506
41
+ nshtrainer/configs/callbacks/__init__.py,sha256=tP9urR73NIanyxpbi4EERsxOnGNiptbQpmsj-v53a38,4774
41
42
  nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
42
43
  nshtrainer/configs/callbacks/base/__init__.py,sha256=wT3RhXttLyf6RFWCIvsoiXcPdfGx5W309WBI18AI5os,278
43
44
  nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=aGJ7vX14YamkMdwYAdPv6XrRnP0aZd5uZ5X0nSLc6IU,1475
@@ -47,6 +48,7 @@ nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py,sha256=SIRfz
47
48
  nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py,sha256=VSkO0TYCAYy_9mQuOBoAND7D3Cg6w6nMCpqivQZLPcE,551
48
49
  nshtrainer/configs/callbacks/debug_flag/__init__.py,sha256=s_ifB-DbZjar0w11pr2oVAlcMTWWMnK_tCNilfswL04,425
49
50
  nshtrainer/configs/callbacks/directory_setup/__init__.py,sha256=e8GCRy2Alds3AXLwp4ieSGtn8S0YjmKJ5khOaQ0zKGs,464
51
+ nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py,sha256=npO97m5inRgAnGtGBwz_MNJz44B2cG4j9LZFCllQcrk,530
50
52
  nshtrainer/configs/callbacks/early_stopping/__init__.py,sha256=m8N6H11PjqcWqXP5ZxWC8L4PHMUI6avYyN5rUNprjuQ,546
51
53
  nshtrainer/configs/callbacks/ema/__init__.py,sha256=DUJrbDD8wWX_s0_4dwKpT_IWKSVpBmhe4-1aELq7G6w,377
52
54
  nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=e-vx9Kn-noqw4wPvZw7fDMfb9Tsa6Duk0TIa8ZIgIIE,443
@@ -77,14 +79,14 @@ nshtrainer/configs/nn/__init__.py,sha256=Ms2gIqbRxNVm6GHKCddCJTTqMwUPifjjHD_fCfJ
77
79
  nshtrainer/configs/nn/mlp/__init__.py,sha256=O6kQ6utZNJPG9Fax5pRdZcHa3J-XFKKdXcc_PQg0jk0,347
78
80
  nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=LCTbTyelCMABVw505CGQ4UpEGlAnIhflSLFqwAQXLQA,2155
79
81
  nshtrainer/configs/nn/rng/__init__.py,sha256=4iC6vwxbfNeXyvpwZ1Z5Kcy-he4cu7mg3UpLD-RLrHc,141
80
- nshtrainer/configs/optimizer/__init__.py,sha256=itIDIHQvGm50eZ7JLyNElahnNUMPJ__4PMmTjc0RQ6o,444
82
+ nshtrainer/configs/optimizer/__init__.py,sha256=8ztp5UD-edfzwF-qdJTeZwlv-YWJ5Sn230b9aWxJyQQ,1398
81
83
  nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
82
84
  nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
83
85
  nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
84
86
  nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
85
87
  nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
86
- nshtrainer/configs/trainer/__init__.py,sha256=a8pzGVid52abAVARPbgjaN566H1ZM44FH_x95bsBaGE,7880
87
- nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
88
+ nshtrainer/configs/trainer/__init__.py,sha256=PF9rYuVpk0IuhjcxS_hmBTT6A0oq7AWZDcx0Gfqi7MM,8040
89
+ nshtrainer/configs/trainer/_config/__init__.py,sha256=5B8pjyNHfyFJ6p8dD5VSHD1tw2CcZ87Eq2C_Req3t60,3977
88
90
  nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
89
91
  nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
90
92
  nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -116,7 +118,7 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=irPyDjfUX843ze4bJM9sW8WSe
116
118
  nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
117
119
  nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
118
120
  nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
119
- nshtrainer/model/base.py,sha256=LsOK5mMhYG5J0eSFKZKdd1fTvr38sgi8LLVSqoW6OCU,8386
121
+ nshtrainer/model/base.py,sha256=Pv3M3QStWQp-DnfGFsLPAmp87HHrX1NrkAa4JcyBoDk,10255
120
122
  nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
121
123
  nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
122
124
  nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
@@ -126,14 +128,14 @@ nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,
126
128
  nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
127
129
  nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
128
130
  nshtrainer/nn/rng.py,sha256=IJGvX9v8qBkfgBrMlNU2aj-MbYTPoncFyJzvPkzCQpM,512
129
- nshtrainer/optimizer.py,sha256=u968GRNPUNn3f_9BEY2RBNuJq5O3wJWams3NG0dkrOA,1738
131
+ nshtrainer/optimizer.py,sha256=8pjOny7NxIt04PXxn3zOyJ2soL7nmj8yBVV82r_tNsc,17522
130
132
  nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
131
133
  nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
132
134
  nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
133
135
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
134
136
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
135
137
  nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
136
- nshtrainer/trainer/_config.py,sha256=s-_XoLc9mbNAdroRJyOKd3dLTyrFLQkPyGJkKDmBYf8,33267
138
+ nshtrainer/trainer/_config.py,sha256=tdWAYh-KGXBpgdY8fwvOejjRZN-AS2Ze0f_9s2VEuZ0,33556
137
139
  nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
138
140
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
139
141
  nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
@@ -145,7 +147,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
145
147
  nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
146
148
  nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
147
149
  nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
148
- nshtrainer/trainer/trainer.py,sha256=BKRicDlLI7KstzuP0SmzJzp0U4GK5lhZcKHS1IuL5sA,21197
150
+ nshtrainer/trainer/trainer.py,sha256=smoN61iixWYDWGFvxrt8VwryZVy_NzqqjUcgOid0gRA,21696
149
151
  nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
150
152
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
151
153
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -157,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
157
159
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
158
160
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
159
161
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
160
- nshtrainer-1.1.1b1.dist-info/METADATA,sha256=wdOIQ91eUgWrIHfPLP06FD4uMkyyIfToR3VhBY-BXsE,962
161
- nshtrainer-1.1.1b1.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
162
- nshtrainer-1.1.1b1.dist-info/RECORD,,
162
+ nshtrainer-1.2.0.dist-info/METADATA,sha256=HkNLruaJJuf3ijnGe7NqNd9emBR6QHMRh2-taC5wTrU,960
163
+ nshtrainer-1.2.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
164
+ nshtrainer-1.2.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: poetry-core 2.1.1
+Generator: poetry-core 2.1.2
 Root-Is-Purelib: true
 Tag: py3-none-any