nshtrainer 1.1.1b1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_directory.py +3 -3
- nshtrainer/callbacks/__init__.py +6 -0
- nshtrainer/callbacks/base.py +22 -3
- nshtrainer/callbacks/distributed_prediction_writer.py +166 -0
- nshtrainer/configs/__init__.py +28 -0
- nshtrainer/configs/callbacks/__init__.py +6 -0
- nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +19 -0
- nshtrainer/configs/optimizer/__init__.py +24 -0
- nshtrainer/configs/trainer/__init__.py +4 -0
- nshtrainer/configs/trainer/_config/__init__.py +4 -0
- nshtrainer/model/base.py +60 -2
- nshtrainer/optimizer.py +559 -1
- nshtrainer/trainer/_config.py +10 -4
- nshtrainer/trainer/trainer.py +21 -2
- {nshtrainer-1.1.1b1.dist-info → nshtrainer-1.2.0.dist-info}/METADATA +1 -1
- {nshtrainer-1.1.1b1.dist-info → nshtrainer-1.2.0.dist-info}/RECORD +17 -15
- {nshtrainer-1.1.1b1.dist-info → nshtrainer-1.2.0.dist-info}/WHEEL +1 -1
nshtrainer/_directory.py
CHANGED
@@ -65,9 +65,9 @@ class DirectoryConfig(C.Config):
     ) -> Path:
         # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
         if (subdir := getattr(self, subdirectory, None)) is not None:
-            assert isinstance(
-
-            )
+            assert isinstance(subdir, Path), (
+                f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            )
             return subdir
 
         dir = self.resolve_run_root_directory(run_id)
nshtrainer/callbacks/__init__.py
CHANGED
@@ -23,6 +23,12 @@ from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
 from .directory_setup import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from .distributed_prediction_writer import (
+    DistributedPredictionWriter as DistributedPredictionWriter,
+)
+from .distributed_prediction_writer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from .early_stopping import EarlyStoppingCallback as EarlyStoppingCallback
 from .early_stopping import EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig
 from .ema import EMACallback as EMACallback
nshtrainer/callbacks/base.py
CHANGED
@@ -23,6 +23,10 @@ class CallbackMetadataConfig(TypedDict, total=False):
     """Priority of the callback. Callbacks with higher priority will be loaded first.
     Default is `0`."""
 
+    enabled_for_barebones: bool
+    """Whether this callback is enabled for barebones mode.
+    Default is `False`."""
+
 
 @dataclass(frozen=True)
 class CallbackWithMetadata:
@@ -91,10 +95,20 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
 
 
 def _process_and_filter_callbacks(
+    trainer_config: TrainerConfig,
     callbacks: Iterable[CallbackWithMetadata],
 ) -> list[Callback]:
     callbacks = list(callbacks)
 
+    # If we're in barebones mode, used the callback metadata
+    # to decide to keep/remove the callback.
+    if trainer_config.barebones:
+        callbacks = [
+            callback
+            for callback in callbacks
+            if callback.metadata.get("enabled_for_barebones", False)
+        ]
+
     # Sort by priority (higher priority first)
     callbacks.sort(
         key=lambda callback: callback.metadata.get("priority", 0),
@@ -114,9 +128,14 @@ def resolve_all_callbacks(trainer_config: TrainerConfig):
         if config is not None
     ]
     callbacks = _process_and_filter_callbacks(
-
-
-
+        trainer_config,
+        (
+            callback
+            for callback_config in callback_configs
+            for callback in _create_callbacks_with_metadata(
+                callback_config, trainer_config
+            )
+        ),
     )
     return callbacks
 
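For context, the new `enabled_for_barebones` flag is read from a config class's `metadata` ClassVar when the trainer runs in barebones mode; configs without it are filtered out. A minimal sketch of how a custom callback config could opt in, following the pattern of the configs in this diff (the `MyCallbackConfig` name and the bare `Callback()` placeholder are illustrative, not part of the package):

from typing import ClassVar, Literal

from lightning.pytorch.callbacks import Callback

from nshtrainer.callbacks.base import (
    CallbackConfigBase,
    CallbackMetadataConfig,
    callback_registry,
)


@callback_registry.register
class MyCallbackConfig(CallbackConfigBase):
    # Opt this callback into barebones mode; configs without this flag are
    # dropped by _process_and_filter_callbacks when trainer_config.barebones is set.
    metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig(
        enabled_for_barebones=True
    )

    name: Literal["my_callback"] = "my_callback"

    def create_callbacks(self, trainer_config):
        yield Callback()  # replace with a real callback instance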
nshtrainer/callbacks/distributed_prediction_writer.py
ADDED
@@ -0,0 +1,166 @@
+from __future__ import annotations
+
+import functools
+import logging
+from collections.abc import Iterator, Sequence
+from pathlib import Path
+from typing import Any, ClassVar, Literal, overload
+
+import torch
+from lightning.fabric.utilities.apply_func import move_data_to_device
+from lightning.pytorch.callbacks import BasePredictionWriter
+from typing_extensions import final, override
+
+from .base import CallbackConfigBase, CallbackMetadataConfig, callback_registry
+
+log = logging.getLogger(__name__)
+
+
+@final
+@callback_registry.register
+class DistributedPredictionWriterConfig(CallbackConfigBase):
+    metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig(
+        enabled_for_barebones=True
+    )
+    """Metadata for the callback."""
+
+    name: Literal["distributed_prediction_writer"] = "distributed_prediction_writer"
+
+    dirpath: Path | None = None
+    """Directory to save the predictions to. If None, will use the default directory."""
+
+    move_to_cpu_on_save: bool = True
+    """Whether to move the predictions to CPU before saving. Default is True."""
+
+    save_raw: bool = True
+    """Whether to save the raw predictions."""
+
+    save_processed: bool = True
+    """Whether to process and save the predictions.
+
+    "Processing" means that the model's batched predictions are split into individual predictions
+    and saved as a list of tensors.
+    """
+
+    @override
+    def create_callbacks(self, trainer_config):
+        if (dirpath := self.dirpath) is None:
+            dirpath = trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "predictions"
+            )
+
+        yield DistributedPredictionWriter(self, dirpath)
+
+
+def _move_and_save(data, path: Path, move_to_cpu: bool):
+    if move_to_cpu:
+        data = move_data_to_device(data, "cpu")
+
+    # Save the data to the specified path
+    torch.save(data, path)
+
+
+class DistributedPredictionWriter(BasePredictionWriter):
+    def __init__(
+        self,
+        config: DistributedPredictionWriterConfig,
+        output_dir: Path,
+    ):
+        self.config = config
+
+        super().__init__(write_interval="batch")
+
+        self.output_dir = output_dir
+
+    @override
+    def write_on_batch_end(
+        self,
+        trainer,
+        pl_module,
+        prediction,
+        batch_indices,
+        batch,
+        batch_idx,
+        dataloader_idx,
+    ):
+        save = functools.partial(
+            _move_and_save,
+            move_to_cpu=self.config.move_to_cpu_on_save,
+        )
+
+        # Regular, unstructured writing.
+        if self.config.save_raw:
+            output_dir = (
+                self.output_dir
+                / "raw"
+                / f"dataloader_{dataloader_idx}"
+                / f"rank_{trainer.global_rank}"
+                / f"batch_{batch_idx}"
+            )
+            output_dir.mkdir(parents=True, exist_ok=True)
+            save(prediction, output_dir / "predictions.pt")
+            save(batch, output_dir / "batch.pt")
+            save(batch_indices, output_dir / "batch_indices.pt")
+
+        if self.config.save_processed:
+            # Processed writing.
+            from ..model.base import LightningModuleBase
+
+            if not isinstance(pl_module, LightningModuleBase):
+                raise ValueError(
+                    "The model must be a subclass of LightningModuleBase to use the distributed prediction writer."
+                )
+
+            output_dir = self.output_dir / "processed" / f"dataloader_{dataloader_idx}"
+            output_dir.mkdir(parents=True, exist_ok=True)
+
+            # Split into individual predictions
+            assert batch_indices is not None, (
+                "Batch indices must be provided for processed writing."
+            )
+            for sample in pl_module.split_batched_predictions(
+                batch, prediction, batch_indices
+            ):
+                sample = {
+                    **sample,
+                    "global_rank": trainer.global_rank,
+                    "world_size": trainer.world_size,
+                    "is_global_zero": trainer.is_global_zero,
+                }
+                save(sample, output_dir / f"{sample['index']}.pt")
+
+
+class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
+    def __init__(self, output_dir: Path):
+        self.output_dir = output_dir
+
+    @override
+    def __len__(self) -> int:
+        return len(list(self.output_dir.glob("*.pt")))
+
+    @overload
+    def __getitem__(self, index: int) -> tuple[Any, Any]: ...
+
+    @overload
+    def __getitem__(self, index: slice) -> list[tuple[Any, Any]]: ...
+
+    @override
+    def __getitem__(
+        self, index: int | slice
+    ) -> tuple[Any, Any] | list[tuple[Any, Any]]:
+        if isinstance(index, slice):
+            # Handle slice indexing
+            indices = range(*index.indices(len(self)))
+            return [self.__getitem__(i) for i in indices]
+
+        # Handle integer indexing
+        path = self.output_dir / f"{index}.pt"
+        if not path.exists():
+            raise FileNotFoundError(f"File {path} does not exist.")
+        sample = torch.load(path)
+        return sample["batch"], sample["prediction"]
+
+    @override
+    def __iter__(self) -> Iterator[tuple[Any, Any]]:
+        for i in range(len(self)):
+            yield self[i]
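A rough usage sketch of the reader added above (the directory path and variable names are illustrative; in practice the output directory is the "predictions" subdirectory resolved by the trainer's DirectoryConfig):

from pathlib import Path

from nshtrainer.callbacks.distributed_prediction_writer import (
    DistributedPredictionReader,
)

# After a distributed prediction run, each processed sample is a single
# "<index>.pt" file under .../predictions/processed/dataloader_<i>/.
processed_dir = Path("path/to/run/predictions/processed/dataloader_0")  # illustrative
reader = DistributedPredictionReader(processed_dir)

print(len(reader))             # number of saved samples
batch, prediction = reader[0]  # per-sample batch and prediction, as written by each rank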
nshtrainer/configs/__init__.py
CHANGED
@@ -21,6 +21,9 @@ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackCon
 from nshtrainer.callbacks import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.callbacks import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -95,9 +98,21 @@ from nshtrainer.nn.nonlinearity import (
     SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
 )
 from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
+from nshtrainer.optimizer import AdadeltaConfig as AdadeltaConfig
+from nshtrainer.optimizer import AdafactorConfig as AdafactorConfig
+from nshtrainer.optimizer import AdagradConfig as AdagradConfig
+from nshtrainer.optimizer import AdamaxConfig as AdamaxConfig
+from nshtrainer.optimizer import AdamConfig as AdamConfig
 from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+from nshtrainer.optimizer import ASGDConfig as ASGDConfig
+from nshtrainer.optimizer import NAdamConfig as NAdamConfig
 from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
 from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import RAdamConfig as RAdamConfig
+from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
+from nshtrainer.optimizer import RpropConfig as RpropConfig
+from nshtrainer.optimizer import SGDConfig as SGDConfig
+from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
 from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
@@ -225,11 +240,17 @@ from . import trainer as trainer
 from . import util as util
 
 __all__ = [
+    "ASGDConfig",
     "AcceleratorConfig",
     "AcceleratorConfigBase",
     "ActSaveConfig",
     "ActSaveLoggerConfig",
+    "AdadeltaConfig",
+    "AdafactorConfig",
+    "AdagradConfig",
+    "AdamConfig",
     "AdamWConfig",
+    "AdamaxConfig",
     "AdvancedProfilerConfig",
     "AsyncCheckpointIOPlugin",
     "BaseCheckpointCallbackConfig",
@@ -249,6 +270,7 @@ __all__ = [
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
     "DirectorySetupCallbackConfig",
+    "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "DurationConfig",
     "ELUNonlinearityConfig",
@@ -294,6 +316,7 @@ __all__ = [
     "MetricValidationCallbackConfig",
     "MishNonlinearityConfig",
     "MixedPrecisionPluginConfig",
+    "NAdamConfig",
     "NonlinearityConfig",
    "NonlinearityConfigBase",
     "NormLoggingCallbackConfig",
@@ -306,10 +329,14 @@ __all__ = [
     "PrintTableMetricsCallbackConfig",
     "ProfilerConfig",
     "PyTorchProfilerConfig",
+    "RAdamConfig",
     "RLPSanityChecksCallbackConfig",
+    "RMSpropConfig",
     "RNGConfig",
     "ReLUNonlinearityConfig",
     "ReduceLROnPlateauConfig",
+    "RpropConfig",
+    "SGDConfig",
     "SLURMEnvironmentPlugin",
     "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
@@ -331,6 +358,7 @@ __all__ = [
     "TorchSyncBatchNormPlugin",
     "TrainerConfig",
     "TransformerEnginePluginConfig",
+    "Union",
     "WandbLoggerConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",
nshtrainer/configs/callbacks/__init__.py
CHANGED
@@ -12,6 +12,9 @@ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackCon
 from nshtrainer.callbacks import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.callbacks import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -62,6 +65,7 @@ from . import base as base
 from . import checkpoint as checkpoint
 from . import debug_flag as debug_flag
 from . import directory_setup as directory_setup
+from . import distributed_prediction_writer as distributed_prediction_writer
 from . import early_stopping as early_stopping
 from . import ema as ema
 from . import finite_checks as finite_checks
@@ -86,6 +90,7 @@ __all__ = [
     "CheckpointMetadata",
     "DebugFlagCallbackConfig",
     "DirectorySetupCallbackConfig",
+    "DistributedPredictionWriterConfig",
     "EMACallbackConfig",
     "EarlyStoppingCallbackConfig",
     "EpochTimerCallbackConfig",
@@ -109,6 +114,7 @@ __all__ = [
     "checkpoint",
     "debug_flag",
     "directory_setup",
+    "distributed_prediction_writer",
     "early_stopping",
     "ema",
     "finite_checks",
nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py
ADDED
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    CallbackConfigBase as CallbackConfigBase,
+)
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    callback_registry as callback_registry,
+)
+
+__all__ = [
+    "CallbackConfigBase",
+    "DistributedPredictionWriterConfig",
+    "callback_registry",
+]
nshtrainer/configs/optimizer/__init__.py
CHANGED
@@ -2,14 +2,38 @@ from __future__ import annotations
 
 __codegen__ = True
 
+from nshtrainer.optimizer import AdadeltaConfig as AdadeltaConfig
+from nshtrainer.optimizer import AdafactorConfig as AdafactorConfig
+from nshtrainer.optimizer import AdagradConfig as AdagradConfig
+from nshtrainer.optimizer import AdamaxConfig as AdamaxConfig
+from nshtrainer.optimizer import AdamConfig as AdamConfig
 from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+from nshtrainer.optimizer import ASGDConfig as ASGDConfig
+from nshtrainer.optimizer import NAdamConfig as NAdamConfig
 from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
 from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import RAdamConfig as RAdamConfig
+from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
+from nshtrainer.optimizer import RpropConfig as RpropConfig
+from nshtrainer.optimizer import SGDConfig as SGDConfig
+from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 
 __all__ = [
+    "ASGDConfig",
+    "AdadeltaConfig",
+    "AdafactorConfig",
+    "AdagradConfig",
+    "AdamConfig",
     "AdamWConfig",
+    "AdamaxConfig",
+    "NAdamConfig",
     "OptimizerConfig",
     "OptimizerConfigBase",
+    "RAdamConfig",
+    "RMSpropConfig",
+    "RpropConfig",
+    "SGDConfig",
+    "Union",
     "optimizer_registry",
 ]
nshtrainer/configs/trainer/__init__.py
CHANGED
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -149,6 +152,7 @@ __all__ = [
     "DebugFlagCallbackConfig",
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
+    "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
nshtrainer/configs/trainer/_config/__init__.py
CHANGED
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -70,6 +73,7 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
+    "DistributedPredictionWriterConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
nshtrainer/model/base.py
CHANGED
@@ -2,9 +2,9 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Mapping
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from pathlib import Path
-from typing import Any, Generic, Literal, cast
+from typing import Any, Generic, Literal, TypedDict, cast
 
 import nshconfig as C
 import torch
@@ -53,6 +53,47 @@ VALID_REDUCE_OPS = (
 )
 
 
+class IndividualSample(TypedDict):
+    """
+    A dictionary that contains the individual sample.
+    This is used to split the batched predictions into individual predictions.
+    """
+
+    index: int
+    """The index of the sample in the batch."""
+
+    batch: Any
+    """The batch to split."""
+
+    prediction: Any
+    """The batched prediction to split."""
+
+
+def default_split_batched_predictions(
+    batch: Any,
+    prediction: Any,
+    batch_indices: Sequence[Any],
+) -> Iterable[IndividualSample]:
+    """
+    Splits the batched predictions into a list of individual predictions.
+    Args:
+        batch: The batch to split.
+        prediction: The batched prediction to split.
+        batch_indices: The indices of the batches.
+    Returns:
+        A tuple of two sequences: the corresponding batches and the individual predictions.
+    """
+    import torch.utils._pytree as tree
+
+    for sample_idx, batch_idx in enumerate(batch_indices):
+        # Create a dictionary for each sample
+        yield IndividualSample(
+            index=batch_idx,
+            batch=tree.tree_map(lambda x: x[sample_idx], batch),
+            prediction=tree.tree_map(lambda x: x[sample_idx], prediction),
+        )
+
+
 class LightningModuleBase(
     DebugModuleMixin,
     RLPSanityCheckModuleMixin,
@@ -171,6 +212,23 @@ class LightningModuleBase(
         loss = cast(torch.Tensor, loss)
         return loss
 
+    def split_batched_predictions(
+        self,
+        batch: Any,
+        prediction: Any,
+        batch_indices: Sequence[Any],
+    ) -> Iterable[IndividualSample]:
+        """
+        Splits the batched predictions into a list of individual predictions.
+        Args:
+            batch: The batch to split.
+            prediction: The batched prediction to split.
+            batch_indices: The indices of the batches.
+        Returns:
+            A tuple of two sequences: the corresponding batches and the individual predictions.
+        """
+        return default_split_batched_predictions(batch, prediction, batch_indices)
+
     @override
     @classmethod
     def load_from_checkpoint(cls, *args, **kwargs) -> Never:
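To illustrate the splitting behaviour added here, a small self-contained sketch that mirrors default_split_batched_predictions (the tensors and batch layout are made up):

import torch
import torch.utils._pytree as tree


def split(batch, prediction, batch_indices):
    # Slice every leaf tensor of both pytrees at the per-sample position,
    # exactly as the default splitter does.
    for sample_idx, batch_idx in enumerate(batch_indices):
        yield {
            "index": batch_idx,
            "batch": tree.tree_map(lambda x: x[sample_idx], batch),
            "prediction": tree.tree_map(lambda x: x[sample_idx], prediction),
        }


batch = {"x": torch.randn(4, 3), "y": torch.arange(4)}
prediction = torch.randn(4, 2)
samples = list(split(batch, prediction, batch_indices=[10, 11, 12, 13]))
assert samples[0]["batch"]["x"].shape == (3,)
assert samples[0]["prediction"].shape == (2,)
assert samples[0]["index"] == 10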
nshtrainer/optimizer.py
CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, Tuple, Union
 
 import nshconfig as C
 import torch.nn as nn
+from torch import Tensor
 from torch.optim import Optimizer
 from typing_extensions import TypeAliasType, final, override
 
@@ -45,6 +46,18 @@ class AdamWConfig(OptimizerConfigBase):
     amsgrad: bool = False
     """Whether to use the AMSGrad variant of this algorithm."""
 
+    maximize: bool = False
+    """Maximize the objective with respect to the params, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
     @override
     def create_optimizer(
         self,
@@ -59,6 +72,551 @@ class AdamWConfig(OptimizerConfigBase):
             betas=self.betas,
             eps=self.eps,
             amsgrad=self.amsgrad,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdafactorConfig(OptimizerConfigBase):
+    name: Literal["adafactor"] = "adafactor"
+    lr: float
+    """Learning rate for the optimizer. If None, uses relative step size."""
+
+    eps1: float | None = None
+    """Term added to the denominator to improve numerical stability (default: None)."""
+
+    eps2: float = 1e-3
+    """Term added to the denominator to improve numerical stability (default: 1e-3)."""
+
+    beta2_decay: float = -0.8
+    """Coefficient used for computing running averages of square gradient (default: -0.8)."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) (default: 0.0)."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adafactor
+
+        return Adafactor(
+            parameters,
+            lr=self.lr,
+            eps=(self.eps1, self.eps2),
+            beta2_decay=self.beta2_decay,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdadeltaConfig(OptimizerConfigBase):
+    name: Literal["adadelta"] = "adadelta"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    rho: float = 0.9
+    """Coefficient used for computing a running average of squared gradients."""
+
+    eps: float = 1e-6
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adadelta
+
+        return Adadelta(
+            parameters,
+            lr=self.lr,
+            rho=self.rho,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdagradConfig(OptimizerConfigBase):
+    name: Literal["adagrad"] = "adagrad"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    lr_decay: float = 0.0
+    """Learning rate decay."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    initial_accumulator_value: float = 0.0
+    """Initial value for the accumulator."""
+
+    eps: float = 1e-10
+    """Term added to the denominator to improve numerical stability."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    fused: bool | None = None
+    """Whether the fused implementation is used."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adagrad
+
+        return Adagrad(
+            parameters,
+            lr=self.lr,
+            lr_decay=self.lr_decay,
+            weight_decay=self.weight_decay,
+            initial_accumulator_value=self.initial_accumulator_value,
+            eps=self.eps,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            differentiable=self.differentiable,
+            fused=self.fused,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdamConfig(OptimizerConfigBase):
+    name: Literal["adam"] = "adam"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    amsgrad: bool = False
+    """Whether to use the AMSGrad variant of this algorithm."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    fused: bool | None = None
+    """Whether the fused implementation is used."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adam
+
+        return Adam(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            amsgrad=self.amsgrad,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+            fused=self.fused,
+        )
+
+
+@final
+@optimizer_registry.register
+class AdamaxConfig(OptimizerConfigBase):
+    name: Literal["adamax"] = "adamax"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Adamax
+
+        return Adamax(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class ASGDConfig(OptimizerConfigBase):
+    name: Literal["asgd"] = "asgd"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    lambd: float = 1e-4
+    """Decay term."""
+
+    alpha: float = 0.75
+    """Power for eta update."""
+
+    t0: float = 1e6
+    """Point at which to start averaging."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import ASGD
+
+        return ASGD(
+            parameters,
+            lr=self.lr,
+            lambd=self.lambd,
+            alpha=self.alpha,
+            t0=self.t0,
+            weight_decay=self.weight_decay,
+            maximize=self.maximize,
+        )
+
+
+@final
+@optimizer_registry.register
+class NAdamConfig(OptimizerConfigBase):
+    name: Literal["nadam"] = "nadam"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    momentum_decay: float = 4e-3
+    """Momentum decay."""
+
+    decoupled_weight_decay: bool = False
+    """Whether to use decoupled weight decay."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import NAdam
+
+        return NAdam(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            momentum_decay=self.momentum_decay,
+            decoupled_weight_decay=self.decoupled_weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class RAdamConfig(OptimizerConfigBase):
+    name: Literal["radam"] = "radam"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    betas: tuple[float, float] = (0.9, 0.999)
+    """Coefficients used for computing running averages of gradient and its square."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    decoupled_weight_decay: bool = False
+    """Whether to use decoupled weight decay."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import RAdam
+
+        return RAdam(
+            parameters,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            decoupled_weight_decay=self.decoupled_weight_decay,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class RMSpropConfig(OptimizerConfigBase):
+    name: Literal["rmsprop"] = "rmsprop"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    alpha: float = 0.99
+    """Smoothing constant."""
+
+    eps: float = 1e-8
+    """Term added to the denominator to improve numerical stability."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    momentum: float = 0.0
+    """Momentum factor."""
+
+    centered: bool = False
+    """If True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    capturable: bool = False
+    """Whether this instance is safe to capture in a CUDA graph."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import RMSprop
+
+        return RMSprop(
+            parameters,
+            lr=self.lr,
+            alpha=self.alpha,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            momentum=self.momentum,
+            centered=self.centered,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            capturable=self.capturable,
+            differentiable=self.differentiable,
+        )
+
+
+@final
+@optimizer_registry.register
+class RpropConfig(OptimizerConfigBase):
+    name: Literal["rprop"] = "rprop"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    etas: tuple[float, float] = (0.5, 1.2)
+    """Pair of (etaminus, etaplus), multiplicative increase and decrease factors."""
+
+    step_sizes: tuple[float, float] = (1e-6, 50.0)
+    """Pair of minimal and maximal allowed step sizes."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import Rprop
+
+        return Rprop(
+            parameters,
+            lr=self.lr,
+            etas=self.etas,
+            step_sizes=self.step_sizes,
+            maximize=self.maximize,
+        )
+
+
+@final
+@optimizer_registry.register
+class SGDConfig(OptimizerConfigBase):
+    name: Literal["sgd"] = "sgd"
+
+    lr: float
+    """Learning rate for the optimizer."""
+
+    momentum: float = 0.0
+    """Momentum factor."""
+
+    dampening: float = 0.0
+    """Dampening for momentum."""
+
+    weight_decay: float = 0.0
+    """Weight decay (L2 penalty) for the optimizer."""
+
+    nesterov: bool = False
+    """Enables Nesterov momentum."""
+
+    maximize: bool = False
+    """Maximize the params based on the objective, instead of minimizing."""
+
+    foreach: bool | None = None
+    """Whether foreach implementation of optimizer is used."""
+
+    differentiable: bool = False
+    """Whether autograd should occur through the optimizer step in training."""
+
+    fused: bool | None = None
+    """Whether the fused implementation is used."""
+
+    @override
+    def create_optimizer(
+        self,
+        parameters: Iterable[nn.Parameter] | Iterable[dict[str, Any]],
+    ):
+        from torch.optim import SGD
+
+        return SGD(
+            parameters,
+            lr=self.lr,
+            momentum=self.momentum,
+            dampening=self.dampening,
+            weight_decay=self.weight_decay,
+            nesterov=self.nesterov,
+            maximize=self.maximize,
+            foreach=self.foreach,
+            differentiable=self.differentiable,
+            fused=self.fused,
         )
 
 
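A quick sketch of how these registered configs are meant to be used (the toy module is illustrative; each config mirrors the keyword arguments of the corresponding torch.optim class and builds it via create_optimizer()):

import torch.nn as nn

from nshtrainer.optimizer import AdamConfig, SGDConfig

model = nn.Linear(16, 4)

optimizer = AdamConfig(lr=1e-3, weight_decay=0.01).create_optimizer(model.parameters())

# Configs carry a "name" discriminator and are registered with optimizer_registry,
# so they can also be selected from serialized settings such as
# {"name": "sgd", "lr": 0.1, "momentum": 0.9}.
sgd = SGDConfig(lr=0.1, momentum=0.9).create_optimizer(model.parameters())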
nshtrainer/trainer/_config.py
CHANGED
@@ -31,6 +31,7 @@ from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
+    DistributedPredictionWriterConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
     NormLoggingCallbackConfig,
@@ -701,6 +702,14 @@ class TrainerConfig(C.Config):
     auto_validate_metrics: MetricValidationCallbackConfig | None = None
     """If enabled, will automatically validate the metrics before starting the training routine."""
 
+    distributed_predict: DistributedPredictionWriterConfig | None = (
+        DistributedPredictionWriterConfig()
+    )
+    """If enabled, will use a custom BasePredictionWriter callback to automatically
+    handle distributed prediction. This is useful for running prediction on multiple GPUs
+    seamlessly.
+    """
+
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
@@ -752,10 +761,6 @@ class TrainerConfig(C.Config):
         )
 
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
-        # Disable all callbacks if barebones mode is enabled
-        if self.barebones:
-            return
-
         yield self.early_stopping
         yield self.checkpoint_saving
         yield self.lr_monitor
@@ -772,6 +777,7 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield self.auto_validate_metrics
+        yield self.distributed_predict
         yield from self.callbacks
 
     def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
nshtrainer/trainer/trainer.py
CHANGED
@@ -10,12 +10,16 @@ import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
-from lightning.pytorch import LightningModule
+from lightning.pytorch import LightningDataModule, LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.types import
+from lightning.pytorch.utilities.types import (
+    _EVALUATE_OUTPUT,
+    _PREDICT_OUTPUT,
+    EVAL_DATALOADERS,
+)
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
 from .._checkpoint.metadata import write_checkpoint_metadata
@@ -532,3 +536,18 @@ class Trainer(LightningTrainer):
             update_hparams_dict=update_hparams_dict,
         )
         return cls(hparams)
+
+    def distributed_predict(
+        self,
+        model: LightningModule | None = None,
+        dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None,
+        datamodule: LightningDataModule | None = None,
+        ckpt_path: str | Path | None = None,
+    ):
+        self.predict(
+            model,
+            dataloaders,
+            datamodule,
+            return_predictions=False,
+            ckpt_path=ckpt_path,
+        )
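A minimal sketch of the intended call pattern for the new helper; `model` and `dm` are placeholders for a user's own LightningModuleBase subclass and LightningDataModule, and the top-level re-exports are assumed:

from nshtrainer import Trainer, TrainerConfig  # assumed top-level re-exports

config = TrainerConfig()  # TrainerConfig.distributed_predict is on by default
trainer = Trainer(config)

# `model` and `dm` are placeholders here.
trainer.distributed_predict(model, datamodule=dm)
# Per-rank predictions are written under the run's "predictions" directory by
# DistributedPredictionWriter; nothing is returned (return_predictions=False).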
{nshtrainer-1.1.1b1.dist-info → nshtrainer-1.2.0.dist-info}/RECORD
CHANGED
@@ -3,12 +3,12 @@ nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
 nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
 nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
 nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
-nshtrainer/_directory.py,sha256=
+nshtrainer/_directory.py,sha256=SuXJe9xJXZkDXWWfeOS9rEDz6vZUA6mpnEdkAW0ZQnY,3193
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
-nshtrainer/callbacks/__init__.py,sha256=
+nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
 nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
-nshtrainer/callbacks/base.py,sha256=
+nshtrainer/callbacks/base.py,sha256=K9aom1WVVRYxl-tHWgtmDUQZ1o63NgznvLsjauTKcCc,4225
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
 nshtrainer/callbacks/checkpoint/_base.py,sha256=f7lpk8W4xqxk3PolBEU3AWt9VTIpoLW7wMUhC5DNm3c,6345
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7hb96BxHanmvn7PCCbqq0E,2648
@@ -16,6 +16,7 @@ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
 nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
 nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
+nshtrainer/callbacks/distributed_prediction_writer.py,sha256=OSh2C6XF7Nki4eFByNVhwlt69izkxnlmfPx54w4rvBo,5274
 nshtrainer/callbacks/early_stopping.py,sha256=rC_qYKCQWjRQJFo0ky46uG0aDJdYP8vsSlKunk0bUVI,4765
 nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
 nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
@@ -32,12 +33,12 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
 nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
 nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
 nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
-nshtrainer/configs/__init__.py,sha256=
+nshtrainer/configs/__init__.py,sha256=KD3uClMwnA4LfQ7rY5phDdUbp3j8NoZfaGbGPbpaJVs,15848
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
 nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
 nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
-nshtrainer/configs/callbacks/__init__.py,sha256=
+nshtrainer/configs/callbacks/__init__.py,sha256=tP9urR73NIanyxpbi4EERsxOnGNiptbQpmsj-v53a38,4774
 nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
 nshtrainer/configs/callbacks/base/__init__.py,sha256=wT3RhXttLyf6RFWCIvsoiXcPdfGx5W309WBI18AI5os,278
 nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=aGJ7vX14YamkMdwYAdPv6XrRnP0aZd5uZ5X0nSLc6IU,1475
@@ -47,6 +48,7 @@ nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py,sha256=SIRfz
 nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py,sha256=VSkO0TYCAYy_9mQuOBoAND7D3Cg6w6nMCpqivQZLPcE,551
 nshtrainer/configs/callbacks/debug_flag/__init__.py,sha256=s_ifB-DbZjar0w11pr2oVAlcMTWWMnK_tCNilfswL04,425
 nshtrainer/configs/callbacks/directory_setup/__init__.py,sha256=e8GCRy2Alds3AXLwp4ieSGtn8S0YjmKJ5khOaQ0zKGs,464
+nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py,sha256=npO97m5inRgAnGtGBwz_MNJz44B2cG4j9LZFCllQcrk,530
 nshtrainer/configs/callbacks/early_stopping/__init__.py,sha256=m8N6H11PjqcWqXP5ZxWC8L4PHMUI6avYyN5rUNprjuQ,546
 nshtrainer/configs/callbacks/ema/__init__.py,sha256=DUJrbDD8wWX_s0_4dwKpT_IWKSVpBmhe4-1aELq7G6w,377
 nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=e-vx9Kn-noqw4wPvZw7fDMfb9Tsa6Duk0TIa8ZIgIIE,443
@@ -77,14 +79,14 @@ nshtrainer/configs/nn/__init__.py,sha256=Ms2gIqbRxNVm6GHKCddCJTTqMwUPifjjHD_fCfJ
 nshtrainer/configs/nn/mlp/__init__.py,sha256=O6kQ6utZNJPG9Fax5pRdZcHa3J-XFKKdXcc_PQg0jk0,347
 nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=LCTbTyelCMABVw505CGQ4UpEGlAnIhflSLFqwAQXLQA,2155
 nshtrainer/configs/nn/rng/__init__.py,sha256=4iC6vwxbfNeXyvpwZ1Z5Kcy-he4cu7mg3UpLD-RLrHc,141
-nshtrainer/configs/optimizer/__init__.py,sha256=
+nshtrainer/configs/optimizer/__init__.py,sha256=8ztp5UD-edfzwF-qdJTeZwlv-YWJ5Sn230b9aWxJyQQ,1398
 nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
 nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=
-nshtrainer/configs/trainer/_config/__init__.py,sha256=
+nshtrainer/configs/trainer/__init__.py,sha256=PF9rYuVpk0IuhjcxS_hmBTT6A0oq7AWZDcx0Gfqi7MM,8040
+nshtrainer/configs/trainer/_config/__init__.py,sha256=5B8pjyNHfyFJ6p8dD5VSHD1tw2CcZ87Eq2C_Req3t60,3977
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
 nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -116,7 +118,7 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=irPyDjfUX843ze4bJM9sW8WSe
 nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
 nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
 nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
-nshtrainer/model/base.py,sha256=
+nshtrainer/model/base.py,sha256=Pv3M3QStWQp-DnfGFsLPAmp87HHrX1NrkAa4JcyBoDk,10255
 nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
 nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
 nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
@@ -126,14 +128,14 @@ nshtrainer/nn/module_dict.py,sha256=9plb8aQUx5TUEPhX5jI9u8LrpTeKe7jZAHi8iIqcN8w,
 nshtrainer/nn/module_list.py,sha256=UB43pcwD_3nUke_DyLQt-iXKhWdKM6Zjm84lRC1hPYA,1755
 nshtrainer/nn/nonlinearity.py,sha256=xmaL4QCRvCxqmaGIOwetJeKK-6IK4m2OV7D3SjxSwJQ,6322
 nshtrainer/nn/rng.py,sha256=IJGvX9v8qBkfgBrMlNU2aj-MbYTPoncFyJzvPkzCQpM,512
-nshtrainer/optimizer.py,sha256=
+nshtrainer/optimizer.py,sha256=8pjOny7NxIt04PXxn3zOyJ2soL7nmj8yBVV82r_tNsc,17522
 nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
 nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
 nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
-nshtrainer/trainer/_config.py,sha256=
+nshtrainer/trainer/_config.py,sha256=tdWAYh-KGXBpgdY8fwvOejjRZN-AS2Ze0f_9s2VEuZ0,33556
 nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
 nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
@@ -145,7 +147,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
 nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
 nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=smoN61iixWYDWGFvxrt8VwryZVy_NzqqjUcgOid0gRA,21696
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -157,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.
-nshtrainer-1.
-nshtrainer-1.
+nshtrainer-1.2.0.dist-info/METADATA,sha256=HkNLruaJJuf3ijnGe7NqNd9emBR6QHMRh2-taC5wTrU,960
+nshtrainer-1.2.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+nshtrainer-1.2.0.dist-info/RECORD,,