nshtrainer-1.0.0b46-py3-none-any.whl → nshtrainer-1.0.0b47-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_callback.py +42 -42
- nshtrainer/_checkpoint/metadata.py +16 -4
- nshtrainer/lr_scheduler/base.py +5 -7
- nshtrainer/model/base.py +0 -67
- nshtrainer/trainer/trainer.py +0 -14
- {nshtrainer-1.0.0b46.dist-info → nshtrainer-1.0.0b47.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b46.dist-info → nshtrainer-1.0.0b47.dist-info}/RECORD +8 -8
- {nshtrainer-1.0.0b46.dist-info → nshtrainer-1.0.0b47.dist-info}/WHEEL +0 -0
nshtrainer/_callback.py
CHANGED
@@ -4,46 +4,46 @@ from pathlib import Path
|
|
4
4
|
from typing import TYPE_CHECKING, Any
|
5
5
|
|
6
6
|
import torch
|
7
|
+
from lightning.pytorch import LightningModule
|
7
8
|
from lightning.pytorch.callbacks import Callback as _LightningCallback
|
8
9
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
9
10
|
from torch.optim import Optimizer
|
10
11
|
|
11
12
|
if TYPE_CHECKING:
|
12
|
-
from .model import LightningModuleBase
|
13
13
|
from .trainer import Trainer
|
14
14
|
|
15
15
|
|
16
16
|
class NTCallbackBase(_LightningCallback):
|
17
17
|
def setup( # pyright: ignore[reportIncompatibleMethodOverride]
|
18
|
-
self, trainer: Trainer, pl_module:
|
18
|
+
self, trainer: Trainer, pl_module: LightningModule, stage: str
|
19
19
|
) -> None:
|
20
20
|
"""Called when fit, validate, test, predict, or tune begins."""
|
21
21
|
|
22
22
|
def teardown( # pyright: ignore[reportIncompatibleMethodOverride]
|
23
|
-
self, trainer: Trainer, pl_module:
|
23
|
+
self, trainer: Trainer, pl_module: LightningModule, stage: str
|
24
24
|
) -> None:
|
25
25
|
"""Called when fit, validate, test, predict, or tune ends."""
|
26
26
|
|
27
|
-
def on_fit_start(self, trainer: Trainer, pl_module:
|
27
|
+
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
|
28
28
|
"""Called when fit begins."""
|
29
29
|
|
30
|
-
def on_fit_end(self, trainer: Trainer, pl_module:
|
30
|
+
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
|
31
31
|
"""Called when fit ends."""
|
32
32
|
|
33
33
|
def on_sanity_check_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
34
|
-
self, trainer: Trainer, pl_module:
|
34
|
+
self, trainer: Trainer, pl_module: LightningModule
|
35
35
|
) -> None:
|
36
36
|
"""Called when the validation sanity check starts."""
|
37
37
|
|
38
38
|
def on_sanity_check_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
39
|
-
self, trainer: Trainer, pl_module:
|
39
|
+
self, trainer: Trainer, pl_module: LightningModule
|
40
40
|
) -> None:
|
41
41
|
"""Called when the validation sanity check ends."""
|
42
42
|
|
43
43
|
def on_train_batch_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
44
44
|
self,
|
45
45
|
trainer: Trainer,
|
46
|
-
pl_module:
|
46
|
+
pl_module: LightningModule,
|
47
47
|
batch: Any,
|
48
48
|
batch_idx: int,
|
49
49
|
) -> None:
|
@@ -52,7 +52,7 @@ class NTCallbackBase(_LightningCallback):
|
|
52
52
|
def on_train_batch_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
53
53
|
self,
|
54
54
|
trainer: Trainer,
|
55
|
-
pl_module:
|
55
|
+
pl_module: LightningModule,
|
56
56
|
outputs: STEP_OUTPUT,
|
57
57
|
batch: Any,
|
58
58
|
batch_idx: int,
|
@@ -66,12 +66,12 @@ class NTCallbackBase(_LightningCallback):
|
|
66
66
|
"""
|
67
67
|
|
68
68
|
def on_train_epoch_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
69
|
-
self, trainer: Trainer, pl_module:
|
69
|
+
self, trainer: Trainer, pl_module: LightningModule
|
70
70
|
) -> None:
|
71
71
|
"""Called when the train epoch begins."""
|
72
72
|
|
73
73
|
def on_train_epoch_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
74
|
-
self, trainer: Trainer, pl_module:
|
74
|
+
self, trainer: Trainer, pl_module: LightningModule
|
75
75
|
) -> None:
|
76
76
|
"""Called when the train epoch ends.
|
77
77
|
|
@@ -102,39 +102,39 @@ class NTCallbackBase(_LightningCallback):
|
|
102
102
|
"""
|
103
103
|
|
104
104
|
def on_validation_epoch_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
105
|
-
self, trainer: Trainer, pl_module:
|
105
|
+
self, trainer: Trainer, pl_module: LightningModule
|
106
106
|
) -> None:
|
107
107
|
"""Called when the val epoch begins."""
|
108
108
|
|
109
109
|
def on_validation_epoch_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
110
|
-
self, trainer: Trainer, pl_module:
|
110
|
+
self, trainer: Trainer, pl_module: LightningModule
|
111
111
|
) -> None:
|
112
112
|
"""Called when the val epoch ends."""
|
113
113
|
|
114
114
|
def on_test_epoch_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
115
|
-
self, trainer: Trainer, pl_module:
|
115
|
+
self, trainer: Trainer, pl_module: LightningModule
|
116
116
|
) -> None:
|
117
117
|
"""Called when the test epoch begins."""
|
118
118
|
|
119
119
|
def on_test_epoch_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
120
|
-
self, trainer: Trainer, pl_module:
|
120
|
+
self, trainer: Trainer, pl_module: LightningModule
|
121
121
|
) -> None:
|
122
122
|
"""Called when the test epoch ends."""
|
123
123
|
|
124
124
|
def on_predict_epoch_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
125
|
-
self, trainer: Trainer, pl_module:
|
125
|
+
self, trainer: Trainer, pl_module: LightningModule
|
126
126
|
) -> None:
|
127
127
|
"""Called when the predict epoch begins."""
|
128
128
|
|
129
129
|
def on_predict_epoch_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
130
|
-
self, trainer: Trainer, pl_module:
|
130
|
+
self, trainer: Trainer, pl_module: LightningModule
|
131
131
|
) -> None:
|
132
132
|
"""Called when the predict epoch ends."""
|
133
133
|
|
134
134
|
def on_validation_batch_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
135
135
|
self,
|
136
136
|
trainer: Trainer,
|
137
|
-
pl_module:
|
137
|
+
pl_module: LightningModule,
|
138
138
|
batch: Any,
|
139
139
|
batch_idx: int,
|
140
140
|
dataloader_idx: int = 0,
|
@@ -144,7 +144,7 @@ class NTCallbackBase(_LightningCallback):
|
|
144
144
|
def on_validation_batch_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
145
145
|
self,
|
146
146
|
trainer: Trainer,
|
147
|
-
pl_module:
|
147
|
+
pl_module: LightningModule,
|
148
148
|
outputs: STEP_OUTPUT,
|
149
149
|
batch: Any,
|
150
150
|
batch_idx: int,
|
@@ -155,7 +155,7 @@ class NTCallbackBase(_LightningCallback):
|
|
155
155
|
def on_test_batch_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
156
156
|
self,
|
157
157
|
trainer: Trainer,
|
158
|
-
pl_module:
|
158
|
+
pl_module: LightningModule,
|
159
159
|
batch: Any,
|
160
160
|
batch_idx: int,
|
161
161
|
dataloader_idx: int = 0,
|
@@ -165,7 +165,7 @@ class NTCallbackBase(_LightningCallback):
|
|
165
165
|
def on_test_batch_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
166
166
|
self,
|
167
167
|
trainer: Trainer,
|
168
|
-
pl_module:
|
168
|
+
pl_module: LightningModule,
|
169
169
|
outputs: STEP_OUTPUT,
|
170
170
|
batch: Any,
|
171
171
|
batch_idx: int,
|
@@ -176,7 +176,7 @@ class NTCallbackBase(_LightningCallback):
|
|
176
176
|
def on_predict_batch_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
177
177
|
self,
|
178
178
|
trainer: Trainer,
|
179
|
-
pl_module:
|
179
|
+
pl_module: LightningModule,
|
180
180
|
batch: Any,
|
181
181
|
batch_idx: int,
|
182
182
|
dataloader_idx: int = 0,
|
@@ -186,7 +186,7 @@ class NTCallbackBase(_LightningCallback):
|
|
186
186
|
def on_predict_batch_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
187
187
|
self,
|
188
188
|
trainer: Trainer,
|
189
|
-
pl_module:
|
189
|
+
pl_module: LightningModule,
|
190
190
|
outputs: Any,
|
191
191
|
batch: Any,
|
192
192
|
batch_idx: int,
|
@@ -194,40 +194,40 @@ class NTCallbackBase(_LightningCallback):
|
|
194
194
|
) -> None:
|
195
195
|
"""Called when the predict batch ends."""
|
196
196
|
|
197
|
-
def on_train_start(self, trainer: Trainer, pl_module:
|
197
|
+
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
|
198
198
|
"""Called when the train begins."""
|
199
199
|
|
200
|
-
def on_train_end(self, trainer: Trainer, pl_module:
|
200
|
+
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
|
201
201
|
"""Called when the train ends."""
|
202
202
|
|
203
203
|
def on_validation_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
204
|
-
self, trainer: Trainer, pl_module:
|
204
|
+
self, trainer: Trainer, pl_module: LightningModule
|
205
205
|
) -> None:
|
206
206
|
"""Called when the validation loop begins."""
|
207
207
|
|
208
208
|
def on_validation_end( # pyright: ignore[reportIncompatibleMethodOverride]
|
209
|
-
self, trainer: Trainer, pl_module:
|
209
|
+
self, trainer: Trainer, pl_module: LightningModule
|
210
210
|
) -> None:
|
211
211
|
"""Called when the validation loop ends."""
|
212
212
|
|
213
|
-
def on_test_start(self, trainer: Trainer, pl_module:
|
213
|
+
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
|
214
214
|
"""Called when the test begins."""
|
215
215
|
|
216
|
-
def on_test_end(self, trainer: Trainer, pl_module:
|
216
|
+
def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
|
217
217
|
"""Called when the test ends."""
|
218
218
|
|
219
219
|
def on_predict_start( # pyright: ignore[reportIncompatibleMethodOverride]
|
220
|
-
self, trainer: Trainer, pl_module:
|
220
|
+
self, trainer: Trainer, pl_module: LightningModule
|
221
221
|
) -> None:
|
222
222
|
"""Called when the predict begins."""
|
223
223
|
|
224
|
-
def on_predict_end(self, trainer: Trainer, pl_module:
|
224
|
+
def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # pyright: ignore[reportIncompatibleMethodOverride]
|
225
225
|
"""Called when predict ends."""
|
226
226
|
|
227
227
|
def on_exception( # pyright: ignore[reportIncompatibleMethodOverride]
|
228
228
|
self,
|
229
229
|
trainer: Trainer,
|
230
|
-
pl_module:
|
230
|
+
pl_module: LightningModule,
|
231
231
|
exception: BaseException,
|
232
232
|
) -> None:
|
233
233
|
"""Called when any trainer execution is interrupted by an exception."""
|
@@ -253,7 +253,7 @@ class NTCallbackBase(_LightningCallback):
|
|
253
253
|
def on_save_checkpoint( # pyright: ignore[reportIncompatibleMethodOverride]
|
254
254
|
self,
|
255
255
|
trainer: Trainer,
|
256
|
-
pl_module:
|
256
|
+
pl_module: LightningModule,
|
257
257
|
checkpoint: dict[str, Any],
|
258
258
|
) -> None:
|
259
259
|
r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save.
|
@@ -268,7 +268,7 @@ class NTCallbackBase(_LightningCallback):
|
|
268
268
|
def on_load_checkpoint( # pyright: ignore[reportIncompatibleMethodOverride]
|
269
269
|
self,
|
270
270
|
trainer: Trainer,
|
271
|
-
pl_module:
|
271
|
+
pl_module: LightningModule,
|
272
272
|
checkpoint: dict[str, Any],
|
273
273
|
) -> None:
|
274
274
|
r"""Called when loading a model checkpoint, use to reload state.
|
@@ -281,19 +281,19 @@ class NTCallbackBase(_LightningCallback):
|
|
281
281
|
"""
|
282
282
|
|
283
283
|
def on_before_backward( # pyright: ignore[reportIncompatibleMethodOverride]
|
284
|
-
self, trainer: Trainer, pl_module:
|
284
|
+
self, trainer: Trainer, pl_module: LightningModule, loss: torch.Tensor
|
285
285
|
) -> None:
|
286
286
|
"""Called before ``loss.backward()``."""
|
287
287
|
|
288
288
|
def on_after_backward( # pyright: ignore[reportIncompatibleMethodOverride]
|
289
|
-
self, trainer: Trainer, pl_module:
|
289
|
+
self, trainer: Trainer, pl_module: LightningModule
|
290
290
|
) -> None:
|
291
291
|
"""Called after ``loss.backward()`` and before optimizers are stepped."""
|
292
292
|
|
293
293
|
def on_before_optimizer_step( # pyright: ignore[reportIncompatibleMethodOverride]
|
294
294
|
self,
|
295
295
|
trainer: Trainer,
|
296
|
-
pl_module:
|
296
|
+
pl_module: LightningModule,
|
297
297
|
optimizer: Optimizer,
|
298
298
|
) -> None:
|
299
299
|
"""Called before ``optimizer.step()``."""
|
@@ -301,7 +301,7 @@ class NTCallbackBase(_LightningCallback):
|
|
301
301
|
def on_before_zero_grad( # pyright: ignore[reportIncompatibleMethodOverride]
|
302
302
|
self,
|
303
303
|
trainer: Trainer,
|
304
|
-
pl_module:
|
304
|
+
pl_module: LightningModule,
|
305
305
|
optimizer: Optimizer,
|
306
306
|
) -> None:
|
307
307
|
"""Called before ``optimizer.zero_grad()``."""
|
@@ -310,15 +310,15 @@ class NTCallbackBase(_LightningCallback):
|
|
310
310
|
self,
|
311
311
|
ckpt_path: Path,
|
312
312
|
metadata_path: Path | None,
|
313
|
-
trainer:
|
314
|
-
pl_module:
|
313
|
+
trainer: Trainer,
|
314
|
+
pl_module: LightningModule,
|
315
315
|
) -> None:
|
316
316
|
"""Called after a checkpoint is saved."""
|
317
317
|
pass
|
318
318
|
|
319
319
|
|
320
320
|
def _call_on_checkpoint_saved(
|
321
|
-
trainer:
|
321
|
+
trainer: Trainer,
|
322
322
|
ckpt_path: str | Path,
|
323
323
|
metadata_path: str | Path | None,
|
324
324
|
):
|
@@ -333,5 +333,5 @@ def _call_on_checkpoint_saved(
|
|
333
333
|
ckpt_path,
|
334
334
|
metadata_path,
|
335
335
|
trainer,
|
336
|
-
trainer.
|
336
|
+
trainer.lightning_module,
|
337
337
|
)
|
nshtrainer/_checkpoint/metadata.py
CHANGED
@@ -16,6 +16,18 @@ from ..util.path import compute_file_checksum, try_symlink_or_copy
|
|
16
16
|
if TYPE_CHECKING:
|
17
17
|
from ..trainer.trainer import Trainer
|
18
18
|
|
19
|
+
try:
|
20
|
+
from pydantic import BaseModel
|
21
|
+
|
22
|
+
_HAS_PYDANTIC = True
|
23
|
+
except ImportError:
|
24
|
+
if not TYPE_CHECKING:
|
25
|
+
BaseModel = object
|
26
|
+
else:
|
27
|
+
from pydantic import BaseModel
|
28
|
+
_HAS_PYDANTIC = False
|
29
|
+
|
30
|
+
|
19
31
|
log = logging.getLogger(__name__)
|
20
32
|
|
21
33
|
|
@@ -27,10 +39,10 @@ def _full_hparams_dict(trainer: Trainer):
|
|
27
39
|
hparams["trainer"] = trainer.hparams.model_dump(mode="json")
|
28
40
|
|
29
41
|
if trainer.lightning_module is not None:
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
42
|
+
model_hparams = trainer.lightning_module.hparams
|
43
|
+
if _HAS_PYDANTIC and isinstance(model_hparams, BaseModel):
|
44
|
+
model_hparams = model_hparams.model_dump(mode="json")
|
45
|
+
hparams["model"] = dict(model_hparams)
|
34
46
|
|
35
47
|
return hparams
|
36
48
|
|
nshtrainer/lr_scheduler/base.py
CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations
|
|
3
3
|
import math
|
4
4
|
from abc import ABC, abstractmethod
|
5
5
|
from collections.abc import Mapping
|
6
|
-
from typing import
|
6
|
+
from typing import Literal
|
7
7
|
|
8
8
|
import nshconfig as C
|
9
|
+
from lightning.pytorch import LightningModule
|
9
10
|
from lightning.pytorch.utilities.types import (
|
10
11
|
LRSchedulerConfigType,
|
11
12
|
LRSchedulerTypeUnion,
|
@@ -13,9 +14,6 @@ from lightning.pytorch.utilities.types import (
|
|
13
14
|
from torch.optim import Optimizer
|
14
15
|
from typing_extensions import Never, NotRequired, TypedDict
|
15
16
|
|
16
|
-
if TYPE_CHECKING:
|
17
|
-
from ..model.base import LightningModuleBase
|
18
|
-
|
19
17
|
|
20
18
|
class LRSchedulerMetadata(TypedDict):
|
21
19
|
interval: Literal["epoch", "step"]
|
@@ -44,13 +42,13 @@ class LRSchedulerConfigBase(C.Config, ABC):
|
|
44
42
|
|
45
43
|
@abstractmethod
|
46
44
|
def create_scheduler_impl(
|
47
|
-
self, optimizer: Optimizer, lightning_module:
|
45
|
+
self, optimizer: Optimizer, lightning_module: LightningModule
|
48
46
|
) -> LRSchedulerTypeUnion | LRSchedulerConfigType: ...
|
49
47
|
|
50
48
|
def create_scheduler(
|
51
49
|
self,
|
52
50
|
optimizer: Optimizer,
|
53
|
-
lightning_module:
|
51
|
+
lightning_module: LightningModule,
|
54
52
|
lr: Never
|
55
53
|
| None = None, # Backward compatibility, should be removed in the future
|
56
54
|
) -> LRSchedulerConfigType:
|
@@ -87,7 +85,7 @@ class LRSchedulerConfigBase(C.Config, ABC):
|
|
87
85
|
|
88
86
|
return scheduler
|
89
87
|
|
90
|
-
def compute_num_steps_per_epoch(self, lightning_module:
|
88
|
+
def compute_num_steps_per_epoch(self, lightning_module: LightningModule) -> int:
|
91
89
|
trainer = lightning_module.trainer
|
92
90
|
# Use the Lightning trainer to convert the epoch-based values to step-based values
|
93
91
|
_ = trainer.estimated_stepping_batches
|
nshtrainer/model/base.py
CHANGED
@@ -62,73 +62,6 @@ class LightningModuleBase(
|
|
62
62
|
ABC,
|
63
63
|
Generic[THparams],
|
64
64
|
):
|
65
|
-
# region Debug
|
66
|
-
@property
|
67
|
-
def debug(self) -> bool:
|
68
|
-
if torch.jit.is_scripting():
|
69
|
-
return False
|
70
|
-
|
71
|
-
if (trainer := self._trainer) is None:
|
72
|
-
return False
|
73
|
-
|
74
|
-
from ..trainer import Trainer
|
75
|
-
|
76
|
-
if not isinstance(trainer, Trainer):
|
77
|
-
return False
|
78
|
-
|
79
|
-
return trainer.debug
|
80
|
-
|
81
|
-
@debug.setter
|
82
|
-
def debug(self, value: bool):
|
83
|
-
if torch.jit.is_scripting():
|
84
|
-
return
|
85
|
-
|
86
|
-
if (trainer := self._trainer) is None:
|
87
|
-
return
|
88
|
-
|
89
|
-
from ..trainer import Trainer
|
90
|
-
|
91
|
-
if not isinstance(trainer, Trainer):
|
92
|
-
return
|
93
|
-
|
94
|
-
trainer.debug = value
|
95
|
-
|
96
|
-
@torch.jit.unused
|
97
|
-
def breakpoint(self, rank_zero_only: bool = True):
|
98
|
-
if (
|
99
|
-
not rank_zero_only
|
100
|
-
or not torch.distributed.is_initialized()
|
101
|
-
or torch.distributed.get_rank() == 0
|
102
|
-
):
|
103
|
-
breakpoint()
|
104
|
-
|
105
|
-
if rank_zero_only and torch.distributed.is_initialized():
|
106
|
-
_ = torch.distributed.barrier()
|
107
|
-
|
108
|
-
@torch.jit.unused
|
109
|
-
def ensure_finite(
|
110
|
-
self,
|
111
|
-
tensor: torch.Tensor,
|
112
|
-
name: str | None = None,
|
113
|
-
throw: bool = False,
|
114
|
-
):
|
115
|
-
name_parts: list[str] = ["Tensor"]
|
116
|
-
if name is not None:
|
117
|
-
name_parts.append(name)
|
118
|
-
name = " ".join(name_parts)
|
119
|
-
|
120
|
-
not_finite = ~torch.isfinite(tensor)
|
121
|
-
if not_finite.any():
|
122
|
-
msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
|
123
|
-
if throw:
|
124
|
-
raise RuntimeError(msg)
|
125
|
-
else:
|
126
|
-
log.warning(msg)
|
127
|
-
return False
|
128
|
-
return True
|
129
|
-
|
130
|
-
# endregion
|
131
|
-
|
132
65
|
# region Profiler
|
133
66
|
@property
|
134
67
|
def profiler(self) -> Profiler:
|
nshtrainer/trainer/trainer.py
CHANGED
@@ -418,20 +418,6 @@ class Trainer(LightningTrainer):
|
|
418
418
|
|
419
419
|
return tracker.time_elapsed(stage)
|
420
420
|
|
421
|
-
@property
|
422
|
-
def _base_module(self):
|
423
|
-
if self.lightning_module is None:
|
424
|
-
raise ValueError("LightningModule is not set.")
|
425
|
-
|
426
|
-
from ..model.base import LightningModuleBase
|
427
|
-
|
428
|
-
if not isinstance(self.lightning_module, LightningModuleBase):
|
429
|
-
raise ValueError(
|
430
|
-
f"LightningModule is not an instance of {LightningModuleBase}."
|
431
|
-
)
|
432
|
-
|
433
|
-
return self.lightning_module
|
434
|
-
|
435
421
|
@override
|
436
422
|
def _run(
|
437
423
|
self, model: LightningModule, ckpt_path: str | Path | None = None
|
{nshtrainer-1.0.0b46.dist-info → nshtrainer-1.0.0b47.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
|
2
2
|
nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
|
3
|
-
nshtrainer/_callback.py,sha256=
|
4
|
-
nshtrainer/_checkpoint/metadata.py,sha256=
|
3
|
+
nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
|
4
|
+
nshtrainer/_checkpoint/metadata.py,sha256=ojSEmq0udFwdzIC5vkbF0yEdhMaJ2iBrZCSFNDkeeGY,5578
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=65UDrz3KuhkgVfco-RkWuoa1wzTZoXxunlC769yJaMc,1639
|
6
6
|
nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
@@ -109,13 +109,13 @@ nshtrainer/loggers/csv.py,sha256=xJ8mSmw4vJwinIfqhF6t2HWmh_1dXEYyLfGuXwL7WHo,116
|
|
109
109
|
nshtrainer/loggers/tensorboard.py,sha256=E7iO_fDt9bfH02hBL430bXPLljOo5iGgq2QyPqmx2gQ,2324
|
110
110
|
nshtrainer/loggers/wandb.py,sha256=KZXAUWrrmdX_L8rqej77oUHaM0JxZRM8y9z6JP9PISw,6856
|
111
111
|
nshtrainer/lr_scheduler/__init__.py,sha256=daMMK3erUcNXGGd_nZB8AWu3ZTYqfS1RSWeK4FV2udw,851
|
112
|
-
nshtrainer/lr_scheduler/base.py,sha256=
|
112
|
+
nshtrainer/lr_scheduler/base.py,sha256=LE53JRBTuAlA1fqbMgCZ7m39D1z0rGj2TizhJ62CPvE,3756
|
113
113
|
nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=MsoXgCcWTKsrkNZiGnKS6yC-slRuleuwFxeM_lmG_pQ,5560
|
114
114
|
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=v9T0GpvOoHV30atFB0MwExHgHcTpMCYxbMRoPjPBjt8,2938
|
115
115
|
nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
|
116
116
|
nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
|
117
117
|
nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
|
118
|
-
nshtrainer/model/base.py,sha256=
|
118
|
+
nshtrainer/model/base.py,sha256=q1IMVG3lHvI84x-8hXmiLNJN_NplY_q9W5u6D2rrmVY,8684
|
119
119
|
nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
|
120
120
|
nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
|
121
121
|
nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
|
@@ -142,7 +142,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv
|
|
142
142
|
nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
|
143
143
|
nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
|
144
144
|
nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
|
145
|
-
nshtrainer/trainer/trainer.py,sha256=
|
145
|
+
nshtrainer/trainer/trainer.py,sha256=QEK-0bcw1y5Cconi99PYFXr0MElUGgGYMZ_SlcJUQ1k,20364
|
146
146
|
nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
|
147
147
|
nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
|
148
148
|
nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
|
@@ -154,6 +154,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
|
154
154
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
155
155
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
156
156
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
157
|
-
nshtrainer-1.0.
|
158
|
-
nshtrainer-1.0.
|
159
|
-
nshtrainer-1.0.
|
157
|
+
nshtrainer-1.0.0b47.dist-info/METADATA,sha256=E7d5EfVnqLgmFPuh_D_VWERKQrA5tjePktx1vujkSs8,988
|
158
|
+
nshtrainer-1.0.0b47.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
159
|
+
nshtrainer-1.0.0b47.dist-info/RECORD,,
|
{nshtrainer-1.0.0b46.dist-info → nshtrainer-1.0.0b47.dist-info}/WHEEL
File without changes
|