nshtrainer 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_callback.py +50 -3
- nshtrainer/callbacks/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +2 -2
- nshtrainer/callbacks/log_epoch.py +55 -7
- nshtrainer/callbacks/print_table.py +2 -2
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -0
- nshtrainer/configs/__init__.py +0 -2
- nshtrainer/configs/optimizer/__init__.py +0 -2
- nshtrainer/loggers/__init__.py +1 -2
- nshtrainer/loggers/actsave.py +7 -1
- nshtrainer/loggers/wandb.py +5 -5
- nshtrainer/lr_scheduler/base.py +1 -1
- nshtrainer/model/mixins/callback.py +0 -17
- nshtrainer/model/mixins/logger.py +1 -0
- nshtrainer/nn/module_dict.py +4 -4
- nshtrainer/nn/module_list.py +17 -17
- nshtrainer/nn/nonlinearity.py +15 -2
- nshtrainer/optimizer.py +2 -4
- nshtrainer/trainer/accelerator.py +1 -2
- nshtrainer/trainer/plugin/__init__.py +1 -2
- nshtrainer/util/code_upload.py +1 -1
- {nshtrainer-1.4.1.dist-info → nshtrainer-1.5.1.dist-info}/METADATA +1 -1
- {nshtrainer-1.4.1.dist-info → nshtrainer-1.5.1.dist-info}/RECORD +24 -24
- {nshtrainer-1.4.1.dist-info → nshtrainer-1.5.1.dist-info}/WHEEL +0 -0
nshtrainer/_callback.py
CHANGED
@@ -8,38 +8,46 @@ from lightning.pytorch import LightningModule
 from lightning.pytorch.callbacks import Callback as _LightningCallback
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch.optim import Optimizer
+from typing_extensions import override

 if TYPE_CHECKING:
     from .trainer import Trainer


 class NTCallbackBase(_LightningCallback):
+    @override
     def setup(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule, stage: str
     ) -> None:
         """Called when fit, validate, test, predict, or tune begins."""

+    @override
     def teardown(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule, stage: str
     ) -> None:
         """Called when fit, validate, test, predict, or tune ends."""

+    @override
     def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when fit begins."""

+    @override
     def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when fit ends."""

+    @override
     def on_sanity_check_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation sanity check starts."""

+    @override
     def on_sanity_check_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation sanity check ends."""

+    @override
     def on_train_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -49,6 +57,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the train batch begins."""

+    @override
     def on_train_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -65,11 +74,13 @@ class NTCallbackBase(_LightningCallback):

         """

+    @override
     def on_train_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the train epoch begins."""

+    @override
     def on_train_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
@@ -81,10 +92,12 @@ class NTCallbackBase(_LightningCallback):
         .. code-block:: python

             class MyLightningModule(L.LightningModule):
+                @override
                 def __init__(self):
                     super().__init__()
                     self.training_step_outputs = []

+                @override
                 def training_step(self):
                     loss = ...
                     self.training_step_outputs.append(loss)
@@ -92,6 +105,7 @@ class NTCallbackBase(_LightningCallback):


             class MyCallback(L.Callback):
+                @override
                 def on_train_epoch_end(self, trainer, pl_module):
                     # do something with all training_step outputs, for example:
                     epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
@@ -101,36 +115,43 @@ class NTCallbackBase(_LightningCallback):

         """

+    @override
     def on_validation_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the val epoch begins."""

+    @override
     def on_validation_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the val epoch ends."""

+    @override
     def on_test_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the test epoch begins."""

+    @override
     def on_test_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the test epoch ends."""

+    @override
     def on_predict_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict epoch begins."""

+    @override
     def on_predict_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict epoch ends."""

+    @override
     def on_validation_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -141,6 +162,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the validation batch begins."""

+    @override
     def on_validation_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -152,6 +174,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the validation batch ends."""

+    @override
     def on_test_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -162,6 +185,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the test batch begins."""

+    @override
     def on_test_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -173,6 +197,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the test batch ends."""

+    @override
     def on_predict_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -183,6 +208,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the predict batch begins."""

+    @override
     def on_predict_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -194,36 +220,45 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the predict batch ends."""

+    @override
     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the train begins."""

+    @override
     def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the train ends."""

+    @override
     def on_validation_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation loop begins."""

+    @override
     def on_validation_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation loop ends."""

+    @override
     def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the test begins."""

+    @override
     def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the test ends."""

+    @override
     def on_predict_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict begins."""

+    @override
     def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when predict ends."""

+    @override
     def on_exception(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -232,7 +267,8 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when any trainer execution is interrupted by an exception."""

-    def state_dict(self) -> dict[str, Any]:
+    @override
+    def state_dict(self) -> dict[str, Any]:
         """Called when saving a checkpoint, implement to generate callback's ``state_dict``.

         Returns:
@@ -241,7 +277,8 @@ class NTCallbackBase(_LightningCallback):
         """
         return {}

-    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.

         Args:
@@ -250,6 +287,7 @@ class NTCallbackBase(_LightningCallback):
         """
         pass

+    @override
     def on_save_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -265,6 +303,7 @@ class NTCallbackBase(_LightningCallback):

         """

+    @override
     def on_load_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -280,16 +319,19 @@ class NTCallbackBase(_LightningCallback):

         """

+    @override
     def on_before_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule, loss: torch.Tensor
     ) -> None:
         """Called before ``loss.backward()``."""

+    @override
     def on_after_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""

+    @override
     def on_before_optimizer_step(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -298,6 +340,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called before ``optimizer.step()``."""

+    @override
     def on_before_zero_grad(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -306,7 +349,10 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called before ``optimizer.zero_grad()``."""

-    def on_checkpoint_saved(
+    # =================================================================
+    # Our own new callbacks
+    # =================================================================
+    def on_checkpoint_saved(
         self,
         ckpt_path: Path,
         metadata_path: Path | None,
@@ -317,6 +363,7 @@ class NTCallbackBase(_LightningCallback):
         pass


+@override
 def _call_on_checkpoint_saved(
     trainer: Trainer,
     ckpt_path: str | Path,
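Two changes dominate this file: every inherited Lightning hook is now tagged with typing_extensions.override, and the checkpoint-related hooks are grouped under an explicit "our own new callbacks" section. The practical payoff of @override is that a misspelled hook in a subclass becomes a type-checker error instead of a callback that silently never fires. A minimal sketch (the subclass below is hypothetical, not part of the package):

    from typing_extensions import override

    from nshtrainer._callback import NTCallbackBase


    class MyCallback(NTCallbackBase):
        @override
        def on_train_epoch_end(self, trainer, pl_module) -> None:
            # Runs after each training epoch.
            print(f"finished epoch {pl_module.current_epoch}")

        # @override
        # def on_train_epoc_end(self, trainer, pl_module) -> None:
        #     ...  # typo: rejected by pyright/mypy, since the base class
        #          # has nothing named `on_train_epoc_end` to override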
nshtrainer/callbacks/__init__.py
CHANGED
@@ -75,5 +75,5 @@ from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig

 CallbackConfig = TypeAliasType(
     "CallbackConfig",
-    Annotated[CallbackConfigBase, callback_registry.DynamicResolution()],
+    Annotated[CallbackConfigBase, callback_registry],
 )
nshtrainer/callbacks/checkpoint/_base.py
CHANGED
@@ -5,13 +5,13 @@ import string
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Literal

 import numpy as np
 import torch
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import override
+from typing_extensions import TypeVar, override

 from ..._checkpoint.metadata import CheckpointMetadata, _generate_checkpoint_metadata
 from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
nshtrainer/callbacks/log_epoch.py
CHANGED
@@ -35,7 +35,7 @@ class LogEpochCallbackConfig(CallbackConfigBase):
         yield LogEpochCallback(self)


-def _worker_fn(
+def _log_on_step(
     trainer: Trainer,
     pl_module: LightningModule,
     num_batches_prop: str,
@@ -75,6 +75,19 @@ def _worker_fn(
     pl_module.log(metric_name, epoch, on_step=True, on_epoch=False)


+def _log_on_epoch(
+    trainer: Trainer,
+    pl_module: LightningModule,
+    *,
+    metric_name: str,
+):
+    if trainer.logger is None:
+        return
+
+    epoch = pl_module.current_epoch + 1
+    pl_module.log(metric_name, epoch, on_step=False, on_epoch=True)
+
+
 class LogEpochCallback(Callback):
     def __init__(self, config: LogEpochCallbackConfig):
         super().__init__()
@@ -85,16 +98,27 @@ class LogEpochCallback(Callback):
     def on_train_batch_start(
         self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
     ):
-        if
+        if not self.config.train:
             return

-        _worker_fn(
+        _log_on_step(
             trainer,
             pl_module,
             "num_training_batches",
             metric_name=self.config.metric_name,
         )

+    @override
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        if not self.config.train:
+            return
+
+        _log_on_epoch(
+            trainer,
+            pl_module,
+            metric_name=self.config.metric_name,
+        )
+
     @override
     def on_validation_batch_start(
         self,
@@ -104,10 +128,10 @@ class LogEpochCallback(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        if
+        if not self.config.val:
             return

-        _worker_fn(
+        _log_on_step(
             trainer,
             pl_module,
             "num_val_batches",
@@ -115,6 +139,19 @@ class LogEpochCallback(Callback):
             metric_name=self.config.metric_name,
         )

+    @override
+    def on_validation_epoch_end(
+        self, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        if not self.config.val:
+            return
+
+        _log_on_epoch(
+            trainer,
+            pl_module,
+            metric_name=self.config.metric_name,
+        )
+
     @override
     def on_test_batch_start(
         self,
@@ -124,13 +161,24 @@ class LogEpochCallback(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        if
+        if not self.config.test:
             return

-        _worker_fn(
+        _log_on_step(
             trainer,
             pl_module,
             "num_test_batches",
             dataloader_idx=dataloader_idx,
             metric_name=self.config.metric_name,
         )
+
+    @override
+    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        if not self.config.test:
+            return
+
+        _log_on_epoch(
+            trainer,
+            pl_module,
+            metric_name=self.config.metric_name,
+        )
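The new _log_on_epoch helper is the epoch-granularity twin of _log_on_step: it bails out when no logger is attached, then records the 1-based epoch counter with on_step=False, on_epoch=True so the value lands once per epoch instead of once per batch. The same pattern in a standalone module (the metric name here is illustrative; the callback's actual name comes from LogEpochCallbackConfig.metric_name):

    import lightning.pytorch as L


    class EpochStamp(L.LightningModule):
        def on_train_epoch_end(self) -> None:
            if self.logger is None:  # nothing to log to
                return
            # 1-based epoch counter, written once per epoch.
            self.log("computed_epoch", self.current_epoch + 1,
                     on_step=False, on_epoch=True)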
nshtrainer/callbacks/print_table.py
CHANGED
@@ -67,14 +67,14 @@ class PrintTableMetricsCallback(Callback):
         }
         self.metrics.append(metrics_dict)

-        from rich.console import Console  #
+        from rich.console import Console  # pyright: ignore[reportMissingImports] # noqa

         console = Console()
         table = self.create_metrics_table()
         console.print(table)

     def create_metrics_table(self):
-        from rich.table import Table  #
+        from rich.table import Table  # pyright: ignore[reportMissingImports] # noqa

         table = Table(show_header=True, header_style="bold magenta")

nshtrainer/configs/__init__.py
CHANGED
@@ -111,7 +111,6 @@ from nshtrainer.optimizer import RAdamConfig as RAdamConfig
 from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
 from nshtrainer.optimizer import RpropConfig as RpropConfig
 from nshtrainer.optimizer import SGDConfig as SGDConfig
-from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
 from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
@@ -355,7 +354,6 @@ __all__ = [
     "TorchSyncBatchNormPlugin",
     "TrainerConfig",
     "TransformerEnginePluginConfig",
-    "Union",
     "WandbLoggerConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",

nshtrainer/configs/optimizer/__init__.py
CHANGED
@@ -16,7 +16,6 @@ from nshtrainer.optimizer import RAdamConfig as RAdamConfig
 from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
 from nshtrainer.optimizer import RpropConfig as RpropConfig
 from nshtrainer.optimizer import SGDConfig as SGDConfig
-from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry

 __all__ = [
@@ -34,6 +33,5 @@ __all__ = [
     "RMSpropConfig",
     "RpropConfig",
     "SGDConfig",
-    "Union",
     "optimizer_registry",
 ]
nshtrainer/loggers/__init__.py
CHANGED
@@ -12,6 +12,5 @@ from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
 from .wandb import WandbLoggerConfig as WandbLoggerConfig

 LoggerConfig = TypeAliasType(
-    "LoggerConfig",
-    Annotated[LoggerConfigBase, logger_registry.DynamicResolution()],
+    "LoggerConfig", Annotated[LoggerConfigBase, logger_registry]
 )
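The same Annotated[..., registry] rewrite recurs across callbacks, loggers, nonlinearities, optimizers, accelerators, and plugins: the nshconfig registry object itself is now attached as the Annotated metadata instead of a registry.DynamicResolution() wrapper. A rough sketch of the pattern under those assumptions (the resolution semantics belong to nshconfig and are not spelled out in this diff):

    from typing import Annotated

    import nshconfig as C
    from typing_extensions import TypeAliasType


    class LoggerConfigBase(C.Config):
        pass


    # Discriminated registry keyed on each subclass's `name` field, mirroring
    # `C.Registry(AcceleratorConfigBase, discriminator="name")` seen in
    # trainer/accelerator.py below.
    logger_registry = C.Registry(LoggerConfigBase, discriminator="name")

    # 1.5.x style: the registry object is the Annotated metadata; nshconfig
    # resolves registered subclasses when validating a LoggerConfig.
    LoggerConfig = TypeAliasType(
        "LoggerConfig", Annotated[LoggerConfigBase, logger_registry]
    )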
nshtrainer/loggers/actsave.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any, Literal

 import numpy as np
 from lightning.pytorch.loggers import Logger
-from typing_extensions import final
+from typing_extensions import final, override

 from .base import LoggerConfigBase, logger_registry

@@ -15,6 +15,7 @@ from .base import LoggerConfigBase, logger_registry
 class ActSaveLoggerConfig(LoggerConfigBase):
     name: Literal["actsave"] = "actsave"

+    @override
     def create_logger(self, trainer_config):
         if not self.enabled:
             return None
@@ -24,10 +25,12 @@ class ActSaveLoggerConfig(LoggerConfigBase):

 class ActSaveLogger(Logger):
     @property
+    @override
     def name(self):
         return None

     @property
+    @override
     def version(self):
         from nshutils import ActSave

@@ -37,6 +40,7 @@ class ActSaveLogger(Logger):
         return ActSave._saver._id

     @property
+    @override
     def save_dir(self):
         from nshutils import ActSave

@@ -45,6 +49,7 @@ class ActSaveLogger(Logger):

         return str(ActSave._saver._save_dir)

+    @override
     def log_hyperparams(
         self,
         params: dict[str, Any] | Namespace,
@@ -56,6 +61,7 @@ class ActSaveLogger(Logger):
         # Wrap the hparams as a object-dtype np array
         return ActSave.save({"hyperparameters": np.array(params, dtype=object)})

+    @override
     def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:
         from nshutils import ActSave

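Note the decorator order on the property hooks: @override sits under @property, so it marks the getter function itself before property wraps it into a descriptor, which is the placement type checkers expect. A self-contained illustration (these classes are illustrative, not from the package):

    from __future__ import annotations

    from typing_extensions import override


    class Base:
        @property
        def save_dir(self) -> str | None:
            return None


    class Child(Base):
        @property
        @override  # applied to the getter, then wrapped by @property
        def save_dir(self) -> str | None:
            return "./activations"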
nshtrainer/loggers/wandb.py
CHANGED
@@ -63,7 +63,7 @@ class FinishWandbOnTeardownCallback(Callback):
         stage: str,
     ):
         try:
-            import wandb
+            import wandb
         except ImportError:
             return

@@ -139,7 +139,7 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
         # If `wandb-core` is enabled, we should use the new backend.
         if self.use_wandb_core:
             try:
-                import wandb
+                import wandb

                 # The minimum version that supports the new backend is 0.17.5
                 wandb_version = version.parse(importlib.metadata.version("wandb"))
@@ -151,7 +151,7 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
                 )
             # W&B versions 0.18.0 use wandb-core by default
             elif wandb_version < version.parse("0.18.0"):
-                wandb.require("core")
+                wandb.require("core")
                 log.critical("Using the `wandb-core` backend for WandB.")
             except ImportError:
                 pass
@@ -166,9 +166,9 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
                 "If you want to use the new `wandb-core` backend, set `use_wandb_core=True`."
             )
             try:
-                import wandb
+                import wandb

-                wandb.require("legacy-service")
+                wandb.require("legacy-service")
             except ImportError:
                 pass

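These hunks sit inside the backend-selection logic, which picks between wandb-core and the legacy service based on the installed wandb version. A self-contained sketch of that gating under the same version thresholds, using only documented APIs (wandb.require feature flags, packaging's version parser):

    import importlib.metadata

    from packaging import version


    def prefer_wandb_core() -> None:
        try:
            import wandb
        except ImportError:
            return

        wandb_version = version.parse(importlib.metadata.version("wandb"))
        if wandb_version < version.parse("0.17.5"):
            return  # too old: the core backend is not available yet
        if wandb_version < version.parse("0.18.0"):
            # 0.17.5 <= v < 0.18.0: opt in explicitly;
            # 0.18+ already uses wandb-core by default.
            wandb.require("core")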
nshtrainer/lr_scheduler/base.py
CHANGED
@@ -81,7 +81,7 @@ class LRSchedulerConfigBase(C.Config, ABC):
             scheduler["monitor"] = metadata["monitor"]
         # - `strict`
         if scheduler.get("strict") is None and "strict" in metadata:
-            scheduler["strict"] = metadata["strict"]
+            scheduler["strict"] = metadata["strict"]

         return scheduler

nshtrainer/model/mixins/callback.py
CHANGED
@@ -41,23 +41,6 @@ class CallbackModuleMixin(
     CallbackRegistrarModuleMixin,
     mixin_base_type(LightningModule),
 ):
-    @property
-    def _nshtrainer_callbacks(self) -> list[CallbackFn]:
-        if not hasattr(self, "_private_nshtrainer_callbacks_list"):
-            self._private_nshtrainer_callbacks_list = []
-        return self._private_nshtrainer_callbacks_list
-
-    def register_callback(
-        self,
-        callback: _Callback | Iterable[_Callback] | CallbackFn | None = None,
-    ):
-        if not callable(callback):
-            callback_ = cast(CallbackFn, lambda: callback)
-        else:
-            callback_ = callback
-
-        self._nshtrainer_callbacks.append(callback_)
-
     def _gather_all_callbacks(self):
         modules: list[Any] = []
         if isinstance(self, CallbackRegistrarModuleMixin):

nshtrainer/model/mixins/logger.py
CHANGED
@@ -203,6 +203,7 @@ class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
             name = f"{prefix}{name}"
         return super().log(name, value, metric_attribute=metric_attribute, **fn_kwargs)

+    @override
     def log_dict(
         self,
         dictionary: Mapping[str, _METRIC] | torchmetrics.MetricCollection,
nshtrainer/nn/module_dict.py
CHANGED
@@ -28,9 +28,9 @@ class TypedModuleDict(nn.Module, Generic[TModule]):
         return f"{self.key_prefix}{key}"

     def _remove_prefix(self, key: str) -> str:
-        assert key.startswith(
-            self.key_prefix
-        )
+        assert key.startswith(self.key_prefix), (
+            f"{key} does not start with {self.key_prefix}"
+        )
         return key[len(self.key_prefix) :]

     def __setitem__(self, key: str, module: TModule) -> None:
@@ -39,7 +39,7 @@ class TypedModuleDict(nn.Module, Generic[TModule]):

     def __getitem__(self, key: str) -> TModule:
         key = self._with_prefix(key)
-        return self._module_dict.__getitem__(key)
+        return cast(TModule, self._module_dict.__getitem__(key))

     def update(self, modules: Mapping[str, TModule]) -> None:
         return self._module_dict.update(
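The cast is needed because nn.ModuleDict.__getitem__ is annotated as returning a bare nn.Module; TypedModuleDict narrows it back to the stored module type, while the key prefix presumably keeps arbitrary user keys from colliding with reserved nn.Module attribute names. Hypothetical usage (the no-argument constructor is an assumption; only __setitem__/__getitem__ appear in this diff):

    import torch.nn as nn

    from nshtrainer.nn.module_dict import TypedModuleDict

    heads = TypedModuleDict[nn.Linear]()  # assumed constructor signature
    heads["energy"] = nn.Linear(128, 1)

    head = heads["energy"]  # typed as nn.Linear, not just nn.Module
    print(head.out_features)  # attribute access type-checks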
nshtrainer/nn/module_list.py
CHANGED
@@ -1,12 +1,12 @@
 from __future__ import annotations

 from collections.abc import Iterable, Iterator
-from typing import Generic, TypeVar, overload
+from typing import Generic, cast, overload

 import torch.nn as nn
-from typing_extensions import override
+from typing_extensions import TypeVar, override

-TModule = TypeVar("TModule", bound=nn.Module)
+TModule = TypeVar("TModule", bound=nn.Module, infer_variance=True)


 class TypedModuleList(nn.ModuleList, Generic[TModule]):
@@ -14,39 +14,39 @@ class TypedModuleList(nn.ModuleList, Generic[TModule]):
         super().__init__(modules)

     @overload
-    def __getitem__(self, idx: slice) ->
+    def __getitem__(self, idx: slice) -> TypedModuleList[TModule]: ...

     @overload
     def __getitem__(self, idx: int) -> TModule: ...

     @override
-    def __getitem__(self, idx: int | slice) -> TModule |
-        return super().__getitem__(idx)
+    def __getitem__(self, idx: int | slice) -> TModule | TypedModuleList[TModule]:
+        return cast(TModule | TypedModuleList[TModule], super().__getitem__(idx))

     @override
-    def __setitem__(self, idx: int, module: TModule) -> None:  #
+    def __setitem__(self, idx: int, module: TModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         return super().__setitem__(idx, module)

     @override
     def __iter__(self) -> Iterator[TModule]:
-        return super().__iter__()
+        return cast(Iterator[TModule], super().__iter__())

     @override
-    def __iadd__(self, modules: Iterable[TModule]) ->
-        return super().__iadd__(modules)
+    def __iadd__(self, modules: Iterable[TModule]) -> TypedModuleList[TModule]:  # pyright: ignore[reportIncompatibleMethodOverride]
+        return cast(TypedModuleList[TModule], super().__iadd__(modules))

     @override
-    def __add__(self, modules: Iterable[TModule]) ->
-        return super().__add__(modules)
+    def __add__(self, modules: Iterable[TModule]) -> TypedModuleList[TModule]:  # pyright: ignore[reportIncompatibleMethodOverride]
+        return cast(TypedModuleList[TModule], super().__add__(modules))

     @override
-    def insert(self, idx: int, module: TModule) -> None:  #
+    def insert(self, idx: int, module: TModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         return super().insert(idx, module)

     @override
-    def append(self, module: TModule) ->
-        return super().append(module)
+    def append(self, module: TModule) -> TypedModuleList[TModule]:  # pyright: ignore[reportIncompatibleMethodOverride]
        return cast(TypedModuleList[TModule], super().append(module))

     @override
-    def extend(self, modules: Iterable[TModule]) ->
-        return super().extend(modules)
+    def extend(self, modules: Iterable[TModule]) -> TypedModuleList[TModule]:  # pyright: ignore[reportIncompatibleMethodOverride]
+        return cast(TypedModuleList[TModule], super().extend(modules))
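TModule now comes from typing_extensions.TypeVar with infer_variance=True, so the checker infers how TypedModuleList[TModule] may vary instead of treating it as invariant by default, and each passthrough to nn.ModuleList is cast back to the typed view (the stubs return plain nn.Module). Usage sketch:

    import torch
    import torch.nn as nn

    from nshtrainer.nn.module_list import TypedModuleList

    blocks = TypedModuleList([nn.Linear(16, 16) for _ in range(4)])

    x = torch.randn(2, 16)
    for block in blocks:  # iterates as nn.Linear, so block.weight type-checks
        x = block(x)

    first_two = blocks[0:2]  # slicing preserves the typed container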
nshtrainer/nn/nonlinearity.py
CHANGED
@@ -30,6 +30,7 @@ class ReLUNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.ReLU()

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return F.relu(x)

@@ -43,6 +44,7 @@ class SigmoidNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.Sigmoid()

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return torch.sigmoid(x)

@@ -56,6 +58,7 @@ class TanhNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.Tanh()

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return torch.tanh(x)

@@ -72,6 +75,7 @@ class SoftmaxNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.Softmax(dim=self.dim)

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return torch.softmax(x, dim=self.dim)

@@ -91,6 +95,7 @@ class SoftplusNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.Softplus(beta=self.beta, threshold=self.threshold)

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return F.softplus(x, beta=self.beta, threshold=self.threshold)

@@ -104,6 +109,7 @@ class SoftsignNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.Softsign()

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return F.softsign(x)

@@ -120,6 +126,7 @@ class ELUNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.ELU()

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return F.elu(x, alpha=self.alpha)

@@ -136,6 +143,7 @@ class LeakyReLUNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.LeakyReLU(negative_slope=self.negative_slope)

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return F.leaky_relu(x, negative_slope=self.negative_slope)

@@ -157,6 +165,7 @@ class PReLUConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.PReLU(num_parameters=self.num_parameters, init=self.init)

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError(
             "PReLU requires learnable parameters and cannot be called directly."
@@ -175,6 +184,7 @@ class GELUNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.GELU(approximate=self.approximate)

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return F.gelu(x, approximate=self.approximate)

@@ -188,6 +198,7 @@ class SwishNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.SiLU()

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return F.silu(x)

@@ -201,6 +212,7 @@ class SiLUNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.SiLU()

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return F.silu(x)

@@ -214,6 +226,7 @@ class MishNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return nn.Mish()

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return F.mish(x)

@@ -234,12 +247,12 @@ class SwiGLUNonlinearityConfig(NonlinearityConfigBase):
     def create_module(self) -> nn.Module:
         return SwiGLU()

+    @override
     def __call__(self, x: torch.Tensor) -> torch.Tensor:
         input, gate = x.chunk(2, dim=-1)
         return input * F.silu(gate)


 NonlinearityConfig = TypeAliasType(
-    "NonlinearityConfig",
-    Annotated[NonlinearityConfigBase, nonlinearity_registry.DynamicResolution()],
+    "NonlinearityConfig", Annotated[NonlinearityConfigBase, nonlinearity_registry]
 )
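Every nonlinearity config is both a factory (create_module) and directly callable on tensors; the added @override markers pin each __call__ to the base-class signature. Hypothetical usage (keyword construction follows the nshconfig/pydantic convention):

    import torch

    from nshtrainer.nn.nonlinearity import LeakyReLUNonlinearityConfig

    cfg = LeakyReLUNonlinearityConfig(negative_slope=0.1)

    x = torch.randn(8)
    y_functional = cfg(x)              # functional path: F.leaky_relu(x, 0.1)
    y_module = cfg.create_module()(x)  # module path: nn.LeakyReLU(0.1)(x)
    assert torch.allclose(y_functional, y_module)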
nshtrainer/optimizer.py
CHANGED
@@ -2,11 +2,10 @@ from __future__ import annotations

 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal

 import nshconfig as C
 import torch.nn as nn
-from torch import Tensor
 from torch.optim import Optimizer
 from typing_extensions import TypeAliasType, final, override

@@ -621,6 +620,5 @@ class SGDConfig(OptimizerConfigBase):


 OptimizerConfig = TypeAliasType(
-    "OptimizerConfig",
-    Annotated[OptimizerConfigBase, optimizer_registry.DynamicResolution()],
+    "OptimizerConfig", Annotated[OptimizerConfigBase, optimizer_registry]
 )

nshtrainer/trainer/accelerator.py
CHANGED
@@ -23,8 +23,7 @@ class AcceleratorConfigBase(C.Config, ABC):
 accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")

 AcceleratorConfig = TypeAliasType(
-    "AcceleratorConfig",
-    Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
+    "AcceleratorConfig", Annotated[AcceleratorConfigBase, accelerator_registry]
 )


nshtrainer/trainer/plugin/__init__.py
CHANGED
@@ -13,6 +13,5 @@ from .base import PluginConfigBase as PluginConfigBase
 from .base import plugin_registry as plugin_registry

 PluginConfig = TypeAliasType(
-    "PluginConfig",
-    Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
+    "PluginConfig", Annotated[PluginConfigBase, plugin_registry]
 )
nshtrainer/util/code_upload.py
CHANGED
@@ -17,7 +17,7 @@ def get_code_dir() -> Path | None:
     # New versions of nshrunner will have the code_dir attribute
     # in the session object. We should use that. Otherwise, use snapshot_dir.
     try:
-        code_dir = session.code_dir
+        code_dir = session.code_dir
     except AttributeError:
         code_dir = session.snapshot_dir

{nshtrainer-1.4.1.dist-info → nshtrainer-1.5.1.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
 nshtrainer/__init__.py,sha256=RI_2B_IUWa10B6H5TAuWtE5FWX1X4ue-J4dTDaF2-lQ,1035
-nshtrainer/_callback.py,sha256=
+nshtrainer/_callback.py,sha256=aBg9Za6hjteHcGjb8bIGzaN57A03cXrPv4rMWqaNsLU,13253
 nshtrainer/_checkpoint/metadata.py,sha256=El9Ip8jGA7mAN5rAMpVfg1dfUe2dGoOOfvF1JfYJGHM,5676
 nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=OB4252GJ6AbKNCRmHVvEglvjYVMUN822BFYECABxfZU,14037
-nshtrainer/callbacks/__init__.py,sha256=
+nshtrainer/callbacks/__init__.py,sha256=6l2vrFywWftzKTlZMEkF-WgE5uLjLgX89BoUMq8_x-0,3980
 nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
 nshtrainer/callbacks/base.py,sha256=K9aom1WVVRYxl-tHWgtmDUQZ1o63NgznvLsjauTKcCc,4225
 nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
-nshtrainer/callbacks/checkpoint/_base.py,sha256=
+nshtrainer/callbacks/checkpoint/_base.py,sha256=AsNt1bZ-yloPHqenRy4KAJK5DDmhBY1RprR2_xbvomc,11010
 nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7hb96BxHanmvn7PCCbqq0E,2648
 nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
 nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
@@ -21,18 +21,18 @@ nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,1
 nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
 nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB0QG3JdHkQ8ek,3523
 nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
-nshtrainer/callbacks/log_epoch.py,sha256
+nshtrainer/callbacks/log_epoch.py,sha256=-uC5ss9p_ngXUCrSIUwViFcaaVX6ALUzIAKxoDgZrac,4823
 nshtrainer/callbacks/lr_monitor.py,sha256=v45ehnwNO987087HfiOY5aIrVRbwdKMgPYRFHs1fyEE,1444
 nshtrainer/callbacks/metric_validation.py,sha256=4RDr1FuNKfro-6QEtmcFqT4iNf2twmJVNk9y-8nq9bg,2882
 nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
-nshtrainer/callbacks/print_table.py,sha256=
-nshtrainer/callbacks/rlp_sanity_checks.py,sha256=
+nshtrainer/callbacks/print_table.py,sha256=xdDvogpLFHdaHM4yDGENvJUX4Gz4hDq-QpsPcv-Oqi8,3041
+nshtrainer/callbacks/rlp_sanity_checks.py,sha256=PRtcj9K9fa2Oh6nbKQJR6w2__on0Jln969qZXlnkv1Q,10064
 nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGClU4t5kLt8XrY,3076
 nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
 nshtrainer/callbacks/wandb_upload_code.py,sha256=4X-mpiX5ghj9vnEreK2i8Xyvimqt0K-PNWA2HtT-B6I,1940
 nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
 nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
-nshtrainer/configs/__init__.py,sha256
+nshtrainer/configs/__init__.py,sha256=ZHV_1zCZKUYBKzWiLPrF8eFKsb-gepAF4G7AmsCxkkA,15623
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
 nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
@@ -77,7 +77,7 @@ nshtrainer/configs/nn/__init__.py,sha256=Ms2gIqbRxNVm6GHKCddCJTTqMwUPifjjHD_fCfJ
 nshtrainer/configs/nn/mlp/__init__.py,sha256=O6kQ6utZNJPG9Fax5pRdZcHa3J-XFKKdXcc_PQg0jk0,347
 nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=LCTbTyelCMABVw505CGQ4UpEGlAnIhflSLFqwAQXLQA,2155
 nshtrainer/configs/nn/rng/__init__.py,sha256=4iC6vwxbfNeXyvpwZ1Z5Kcy-he4cu7mg3UpLD-RLrHc,141
-nshtrainer/configs/optimizer/__init__.py,sha256=
+nshtrainer/configs/optimizer/__init__.py,sha256=Kq6ACztSQhwgE_tP4F1RI7nQMBgC1ebQmY3HaBYKbeg,1337
 nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
 nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
@@ -103,30 +103,30 @@ nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,2
 nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
 nshtrainer/data/datamodule.py,sha256=Rb4-mA8iXtjRlNUHcIqVPEvxA_VkiJXwN1EvHIsydJ0,4095
 nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
-nshtrainer/loggers/__init__.py,sha256=
-nshtrainer/loggers/actsave.py,sha256=
+nshtrainer/loggers/__init__.py,sha256=0fnclaEIgAUrRlYuSKfzni11dlJ6edllrs06NVmbtYc,567
+nshtrainer/loggers/actsave.py,sha256=Xd21jaBVUmkxITKYfycWKZEgcHu1-dmi1H5EYEjvPDw,1503
 nshtrainer/loggers/base.py,sha256=ON92XbwTSgadQOSyw5PiRRFzyH6uJ-xLtE0nB3cbgPc,1205
 nshtrainer/loggers/csv.py,sha256=xJ8mSmw4vJwinIfqhF6t2HWmh_1dXEYyLfGuXwL7WHo,1160
 nshtrainer/loggers/tensorboard.py,sha256=E7iO_fDt9bfH02hBL430bXPLljOo5iGgq2QyPqmx2gQ,2324
-nshtrainer/loggers/wandb.py,sha256=
+nshtrainer/loggers/wandb.py,sha256=EK2rvJwmV-LxXIm21ZORNBI9nz-AXqo2mIN7xyjs8bc,6776
 nshtrainer/lr_scheduler/__init__.py,sha256=daMMK3erUcNXGGd_nZB8AWu3ZTYqfS1RSWeK4FV2udw,851
-nshtrainer/lr_scheduler/base.py,sha256=
+nshtrainer/lr_scheduler/base.py,sha256=24I3PNlj2iYPoaHeD2_InMAptclCtMnZoD8nXWqxLYw,3740
 nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=MsoXgCcWTKsrkNZiGnKS6yC-slRuleuwFxeM_lmG_pQ,5560
 nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=irPyDjfUX843ze4bJM9sW8WSeEcU643QJ30JN2hz9Rc,3206
 nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
 nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
 nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
 nshtrainer/model/base.py,sha256=PvTmupfGahEZME0BWqbeErDPP1VOm2Nm9JxJkO8afcc,10815
-nshtrainer/model/mixins/callback.py,sha256
+nshtrainer/model/mixins/callback.py,sha256=-walDV3fxH4K-ezugvL__Tml9OP1WIlaaTT8j6mWxLI,2580
 nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
-nshtrainer/model/mixins/logger.py,sha256
+nshtrainer/model/mixins/logger.py,sha256=-D4YwSg0eTDtXj3N288FEo6rqsZ518u1aMBE4Dv4tmg,11708
 nshtrainer/nn/__init__.py,sha256=Vd246v2N9tBQ8XxmTquWzj5lAmeSnngrjpYOfp4LTXM,1499
 nshtrainer/nn/mlp.py,sha256=nYUgAISzuhC8sav6PloAdyz0PdEoikwppiXIuToEVdE,7550
-nshtrainer/nn/module_dict.py,sha256=
-nshtrainer/nn/module_list.py,sha256=
-nshtrainer/nn/nonlinearity.py,sha256=
+nshtrainer/nn/module_dict.py,sha256=FJrxUgQkY6O6tmA_7I_kRoPvxLtPU3cYZY-42InVG3A,2366
+nshtrainer/nn/module_list.py,sha256=xvoF8F-pG-z3MnYc91anG9vUQFVes6niy8J8J0qVAlg,2091
+nshtrainer/nn/nonlinearity.py,sha256=UhAsc8o_6AIsos6sktUWC9xLFCFgHJn5WiurKN1sf5U,6493
 nshtrainer/nn/rng.py,sha256=IJGvX9v8qBkfgBrMlNU2aj-MbYTPoncFyJzvPkzCQpM,512
-nshtrainer/optimizer.py,sha256=
+nshtrainer/optimizer.py,sha256=hvw_UNovYgLHhDvMr9BbUz3EPOIrGZDz9ir8lvCgiw0,17458
 nshtrainer/profiler/__init__.py,sha256=RjaNBoVcTFu8lF0dNlFp-2LaPYdonoIbDy2_KhgF0Ek,594
 nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,950
 nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
@@ -137,8 +137,8 @@ nshtrainer/trainer/_config.py,sha256=GL8DtuH-6x2aHcRlEcmzyhEBMRRldiSazNAeNmPw7gM
 nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
 nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
-nshtrainer/trainer/accelerator.py,sha256=
-nshtrainer/trainer/plugin/__init__.py,sha256=
+nshtrainer/trainer/accelerator.py,sha256=rWfSJ-pQsLREaRPF_rRXsqxvaQQf6XGT6zpNt829jk0,2390
+nshtrainer/trainer/plugin/__init__.py,sha256=LSxEK0vnoN9WkU8MDIetrVrDPLCowGLoc9cvh6RG6gg,492
 nshtrainer/trainer/plugin/base.py,sha256=76ct2TTHLpPr5MO8B9CIkoCOo-dFImzqAll8cIdC0cg,736
 nshtrainer/trainer/plugin/environment.py,sha256=SSXRWHjyFUA6oFx3duD_ZwhM59pWUjR1_UzHz02NI2c,5440
 nshtrainer/trainer/plugin/io.py,sha256=OmFSKLloMypletjaUr_Ptg6LS0ljqTVIp2o4Hm3eZoE,1926
@@ -149,7 +149,7 @@ nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTu
 nshtrainer/trainer/trainer.py,sha256=G_tHqzZCHJazhROcoKeOI5rZ5A8F8XlghiIWkdMbPR0,24387
 nshtrainer/util/_environment_info.py,sha256=j-wyEHKirsu3rIXTtqC2kLmIIkRe6obWjxPVWaqg2ow,24887
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
-nshtrainer/util/code_upload.py,sha256=
+nshtrainer/util/code_upload.py,sha256=o0GKWROL5EUvJ2F-eOr9ag6R588ZbgG8HX37fvEMfgY,1241
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
 nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
 nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
@@ -159,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.4.1.dist-info/METADATA,sha256=
-nshtrainer-1.4.1.dist-info/WHEEL,sha256=
-nshtrainer-1.4.1.dist-info/RECORD,,
+nshtrainer-1.5.1.dist-info/METADATA,sha256=ct7S8c2O-oHJ2yw3-pApipwmP_r07z8lmFu40FhQY-k,980
+nshtrainer-1.5.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+nshtrainer-1.5.1.dist-info/RECORD,,

{nshtrainer-1.4.1.dist-info → nshtrainer-1.5.1.dist-info}/WHEEL
File without changes