nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
- nshtrainer/__init__.py +2 -1
- nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
- nshtrainer/callbacks/base.py +7 -5
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
- nshtrainer/callbacks/model_checkpoint.py +187 -0
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/throughput_monitor.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/ll/__init__.py +0 -1
- nshtrainer/ll/actsave.py +2 -1
- nshtrainer/metrics/__init__.py +1 -0
- nshtrainer/metrics/_config.py +37 -0
- nshtrainer/model/__init__.py +11 -11
- nshtrainer/model/_environment.py +777 -0
- nshtrainer/model/base.py +5 -114
- nshtrainer/model/config.py +92 -507
- nshtrainer/model/modules/logger.py +11 -6
- nshtrainer/runner.py +3 -6
- nshtrainer/trainer/_checkpoint_metadata.py +102 -0
- nshtrainer/trainer/_checkpoint_resolver.py +319 -0
- nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer/trainer/checkpoint_connector.py +63 -0
- nshtrainer/trainer/signal_connector.py +12 -9
- nshtrainer/trainer/trainer.py +111 -31
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
- nshtrainer/actsave/__init__.py +0 -3
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/model/base.py
CHANGED
|
@@ -1,30 +1,18 @@
|
|
|
1
|
-
import getpass
|
|
2
1
|
import inspect
|
|
3
|
-
import os
|
|
4
|
-
import platform
|
|
5
|
-
import sys
|
|
6
2
|
from abc import ABC, abstractmethod
|
|
7
|
-
from collections.abc import Callable, MutableMapping
|
|
8
|
-
from datetime import timedelta
|
|
3
|
+
from collections.abc import MutableMapping
|
|
9
4
|
from logging import getLogger
|
|
10
|
-
from pathlib import Path
|
|
11
5
|
from typing import IO, TYPE_CHECKING, Any, Generic, cast
|
|
12
6
|
|
|
13
7
|
import torch
|
|
14
8
|
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
|
|
15
|
-
from lightning.pytorch import LightningDataModule, LightningModule, Trainer
|
|
9
|
+
from lightning.pytorch import LightningModule, Trainer
|
|
16
10
|
from lightning.pytorch.callbacks import Callback
|
|
17
11
|
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
18
12
|
from typing_extensions import Self, TypeVar, override
|
|
19
13
|
|
|
20
|
-
from .config import (
|
|
21
|
-
|
|
22
|
-
EnvironmentClassInformationConfig,
|
|
23
|
-
EnvironmentLinuxEnvironmentConfig,
|
|
24
|
-
EnvironmentLSFInformationConfig,
|
|
25
|
-
EnvironmentSLURMInformationConfig,
|
|
26
|
-
EnvironmentSnapshotConfig,
|
|
27
|
-
)
|
|
14
|
+
from ._environment import EnvironmentConfig
|
|
15
|
+
from .config import BaseConfig
|
|
28
16
|
from .modules.callback import CallbackModuleMixin
|
|
29
17
|
from .modules.debug import DebugModuleMixin
|
|
30
18
|
from .modules.distributed import DistributedMixin
|
|
@@ -102,39 +90,6 @@ class DebugFlagCallback(Callback):
|
|
|
102
90
|
hparams.debug = self._debug
|
|
103
91
|
|
|
104
92
|
|
|
105
|
-
def _cls_info(cls: type):
|
|
106
|
-
name = cls.__name__
|
|
107
|
-
module = cls.__module__
|
|
108
|
-
full_name = f"{cls.__module__}.{cls.__qualname__}"
|
|
109
|
-
|
|
110
|
-
file_path = inspect.getfile(cls)
|
|
111
|
-
source_file_path = inspect.getsourcefile(cls)
|
|
112
|
-
return EnvironmentClassInformationConfig(
|
|
113
|
-
name=name,
|
|
114
|
-
module=module,
|
|
115
|
-
full_name=full_name,
|
|
116
|
-
file_path=Path(file_path),
|
|
117
|
-
source_file_path=Path(source_file_path) if source_file_path else None,
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
T = TypeVar("T")
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
def _psutil():
|
|
125
|
-
import psutil
|
|
126
|
-
|
|
127
|
-
return psutil
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def _try_get(fn: Callable[[], T | None]) -> T | None:
|
|
131
|
-
try:
|
|
132
|
-
return fn()
|
|
133
|
-
except Exception as e:
|
|
134
|
-
log.warning(f"Failed to get value: {e}")
|
|
135
|
-
return None
|
|
136
|
-
|
|
137
|
-
|
|
138
93
|
class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
|
|
139
94
|
ProfilerMixin,
|
|
140
95
|
RLPSanityCheckModuleMixin,
|
|
@@ -212,58 +167,6 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
|
|
|
212
167
|
**kwargs,
|
|
213
168
|
)
|
|
214
169
|
|
|
215
|
-
@classmethod
|
|
216
|
-
def _update_environment(cls, hparams: THparams):
|
|
217
|
-
hparams.environment.cwd = Path(os.getcwd())
|
|
218
|
-
hparams.environment.python_executable = Path(sys.executable)
|
|
219
|
-
hparams.environment.python_path = [Path(path) for path in sys.path]
|
|
220
|
-
hparams.environment.python_version = sys.version
|
|
221
|
-
hparams.environment.config = _cls_info(cls.config_cls())
|
|
222
|
-
hparams.environment.model = _cls_info(cls)
|
|
223
|
-
hparams.environment.slurm = (
|
|
224
|
-
EnvironmentSLURMInformationConfig.from_current_environment()
|
|
225
|
-
)
|
|
226
|
-
hparams.environment.lsf = (
|
|
227
|
-
EnvironmentLSFInformationConfig.from_current_environment()
|
|
228
|
-
)
|
|
229
|
-
hparams.environment.base_dir = hparams.directory.resolve_run_root_directory(
|
|
230
|
-
hparams.id
|
|
231
|
-
)
|
|
232
|
-
hparams.environment.log_dir = hparams.directory.resolve_subdirectory(
|
|
233
|
-
hparams.id, "log"
|
|
234
|
-
)
|
|
235
|
-
hparams.environment.checkpoint_dir = hparams.directory.resolve_subdirectory(
|
|
236
|
-
hparams.id, "checkpoint"
|
|
237
|
-
)
|
|
238
|
-
hparams.environment.stdio_dir = hparams.directory.resolve_subdirectory(
|
|
239
|
-
hparams.id, "stdio"
|
|
240
|
-
)
|
|
241
|
-
hparams.environment.seed = (
|
|
242
|
-
int(seed_str) if (seed_str := os.environ.get("PL_GLOBAL_SEED")) else None
|
|
243
|
-
)
|
|
244
|
-
hparams.environment.seed_workers = (
|
|
245
|
-
bool(int(seed_everything))
|
|
246
|
-
if (seed_everything := os.environ.get("PL_SEED_WORKERS"))
|
|
247
|
-
else None
|
|
248
|
-
)
|
|
249
|
-
hparams.environment.linux = EnvironmentLinuxEnvironmentConfig(
|
|
250
|
-
user=_try_get(lambda: getpass.getuser()),
|
|
251
|
-
hostname=_try_get(lambda: platform.node()),
|
|
252
|
-
system=_try_get(lambda: platform.system()),
|
|
253
|
-
release=_try_get(lambda: platform.release()),
|
|
254
|
-
version=_try_get(lambda: platform.version()),
|
|
255
|
-
machine=_try_get(lambda: platform.machine()),
|
|
256
|
-
processor=_try_get(lambda: platform.processor()),
|
|
257
|
-
cpu_count=_try_get(lambda: os.cpu_count()),
|
|
258
|
-
memory=_try_get(lambda: _psutil().virtual_memory().total),
|
|
259
|
-
uptime=_try_get(lambda: timedelta(seconds=_psutil().boot_time())),
|
|
260
|
-
boot_time=_try_get(lambda: _psutil().boot_time()),
|
|
261
|
-
load_avg=_try_get(lambda: os.getloadavg()),
|
|
262
|
-
)
|
|
263
|
-
hparams.environment.snapshot = (
|
|
264
|
-
EnvironmentSnapshotConfig.from_current_environment()
|
|
265
|
-
)
|
|
266
|
-
|
|
267
170
|
def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
|
|
268
171
|
"""
|
|
269
172
|
Override this method to update the hparams dictionary before it is used to create the hparams object.
|
|
@@ -287,12 +190,11 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
|
|
|
287
190
|
|
|
288
191
|
hparams = self.pre_init_update_hparams_dict(hparams)
|
|
289
192
|
hparams = self.config_cls().model_validate(hparams)
|
|
290
|
-
|
|
193
|
+
hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
|
|
291
194
|
hparams = self.pre_init_update_hparams(hparams)
|
|
292
195
|
super().__init__(hparams)
|
|
293
196
|
|
|
294
197
|
self.save_hyperparameters(hparams)
|
|
295
|
-
|
|
296
198
|
self.register_callback(lambda: DebugFlagCallback())
|
|
297
199
|
|
|
298
200
|
def zero_loss(self):
|
|
@@ -304,17 +206,6 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
|
|
|
304
206
|
loss = cast(torch.Tensor, loss)
|
|
305
207
|
return loss
|
|
306
208
|
|
|
307
|
-
@property
|
|
308
|
-
def datamodule(self):
|
|
309
|
-
datamodule = getattr(self.trainer, "datamodule", None)
|
|
310
|
-
if (datamodule := getattr(self.trainer, "datamodule", None)) is None:
|
|
311
|
-
return None
|
|
312
|
-
if not isinstance(datamodule, LightningDataModule):
|
|
313
|
-
raise TypeError(
|
|
314
|
-
f"datamodule must be a LightningDataModule: {type(datamodule)}"
|
|
315
|
-
)
|
|
316
|
-
return datamodule
|
|
317
|
-
|
|
318
209
|
if TYPE_CHECKING:
|
|
319
210
|
|
|
320
211
|
@override
|