nshtrainer 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -1
- nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
- nshtrainer/callbacks/base.py +7 -5
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
- nshtrainer/callbacks/model_checkpoint.py +187 -0
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/throughput_monitor.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/ll/__init__.py +0 -1
- nshtrainer/ll/actsave.py +2 -1
- nshtrainer/metrics/__init__.py +1 -0
- nshtrainer/metrics/_config.py +37 -0
- nshtrainer/model/__init__.py +11 -11
- nshtrainer/model/_environment.py +777 -0
- nshtrainer/model/base.py +5 -114
- nshtrainer/model/config.py +49 -501
- nshtrainer/model/modules/logger.py +11 -6
- nshtrainer/runner.py +3 -6
- nshtrainer/trainer/_checkpoint_metadata.py +102 -0
- nshtrainer/trainer/_checkpoint_resolver.py +319 -0
- nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer/trainer/checkpoint_connector.py +63 -0
- nshtrainer/trainer/signal_connector.py +12 -9
- nshtrainer/trainer/trainer.py +111 -31
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
- nshtrainer/actsave/__init__.py +0 -3
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/model/config.py
CHANGED
@@ -1,7 +1,5 @@
 import copy
 import os
-import re
-import socket
 import string
 import time
 import warnings
@@ -36,10 +34,17 @@ from lightning.pytorch.strategies.strategy import Strategy
 from pydantic import DirectoryPath
 from typing_extensions import Self, TypedDict, TypeVar, override

-from ..callbacks import
+from ..callbacks import (
+    CallbackConfig,
+    LatestEpochCheckpointCallbackConfig,
+    ModelCheckpointCallbackConfig,
+    OnExceptionCheckpointCallbackConfig,
+    WandbWatchConfig,
+)
 from ..callbacks.base import CallbackConfigBase
-from ..
-from ..
+from ..metrics import MetricConfig
+from ..trainer._checkpoint_resolver import CheckpointLoadingConfig
+from ._environment import EnvironmentConfig

 log = getLogger(__name__)

@@ -62,7 +67,7 @@ class BaseProfilerConfig(C.Config, ABC):
     """

     @abstractmethod
-    def
+    def create_profiler(self, root_config: "BaseConfig") -> Profiler: ...


 class SimpleProfilerConfig(BaseProfilerConfig):
@@ -75,7 +80,7 @@ class SimpleProfilerConfig(BaseProfilerConfig):
     """

     @override
-    def
+    def create_profiler(self, root_config):
         from lightning.pytorch.profilers.simple import SimpleProfiler

         if (dirpath := self.dirpath) is None:
@@ -104,7 +109,7 @@ class AdvancedProfilerConfig(BaseProfilerConfig):
     """

     @override
-    def
+    def create_profiler(self, root_config):
         from lightning.pytorch.profilers.advanced import AdvancedProfiler

         if (dirpath := self.dirpath) is None:
@@ -172,7 +177,7 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
     """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""

     @override
-    def
+    def create_profiler(self, root_config):
         from lightning.pytorch.profilers.pytorch import PyTorchProfiler

         if (dirpath := self.dirpath) is None:
@@ -203,190 +208,6 @@ ProfilerConfig: TypeAlias = Annotated[
 ]


-class EnvironmentClassInformationConfig(C.Config):
-    name: str
-    module: str
-    full_name: str
-
-    file_path: Path
-    source_file_path: Path | None = None
-
-
-class EnvironmentSLURMInformationConfig(C.Config):
-    hostname: str
-    hostnames: list[str]
-    job_id: str
-    raw_job_id: str
-    array_job_id: str | None
-    array_task_id: str | None
-    num_tasks: int
-    num_nodes: int
-    node: str | int | None
-    global_rank: int
-    local_rank: int
-
-    @classmethod
-    def from_current_environment(cls):
-        try:
-            from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
-
-            if not SLURMEnvironment.detect():
-                return None
-
-            hostname = socket.gethostname()
-            hostnames = [hostname]
-            if node_list := os.environ.get("SLURM_JOB_NODELIST", ""):
-                hostnames = parse_slurm_node_list(node_list)
-
-            raw_job_id = os.environ["SLURM_JOB_ID"]
-            job_id = raw_job_id
-            array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
-            array_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
-            if array_job_id and array_task_id:
-                job_id = f"{array_job_id}_{array_task_id}"
-
-            num_tasks = int(os.environ["SLURM_NTASKS"])
-            num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
-
-            node_id = os.environ.get("SLURM_NODEID")
-
-            global_rank = int(os.environ["SLURM_PROCID"])
-            local_rank = int(os.environ["SLURM_LOCALID"])
-
-            return cls(
-                hostname=hostname,
-                hostnames=hostnames,
-                job_id=job_id,
-                raw_job_id=raw_job_id,
-                array_job_id=array_job_id,
-                array_task_id=array_task_id,
-                num_tasks=num_tasks,
-                num_nodes=num_nodes,
-                node=node_id,
-                global_rank=global_rank,
-                local_rank=local_rank,
-            )
-        except (ImportError, RuntimeError, ValueError, KeyError):
-            return None
-
-
-class EnvironmentLSFInformationConfig(C.Config):
-    hostname: str
-    hostnames: list[str]
-    job_id: str
-    array_job_id: str | None
-    array_task_id: str | None
-    num_tasks: int
-    num_nodes: int
-    node: str | int | None
-    global_rank: int
-    local_rank: int
-
-    @classmethod
-    def from_current_environment(cls):
-        try:
-            import os
-            import socket
-
-            hostname = socket.gethostname()
-            hostnames = [hostname]
-            if node_list := os.environ.get("LSB_HOSTS", ""):
-                hostnames = node_list.split()
-
-            job_id = os.environ["LSB_JOBID"]
-            array_job_id = os.environ.get("LSB_JOBINDEX")
-            array_task_id = os.environ.get("LSB_JOBINDEX")
-
-            num_tasks = int(os.environ.get("LSB_DJOB_NUMPROC", 1))
-            num_nodes = len(set(hostnames))
-
-            node_id = (
-                os.environ.get("LSB_HOSTS", "").split().index(hostname)
-                if "LSB_HOSTS" in os.environ
-                else None
-            )
-
-            # LSF doesn't have direct equivalents for global_rank and local_rank
-            # You might need to calculate these based on your specific setup
-            global_rank = int(os.environ.get("PMI_RANK", 0))
-            local_rank = int(os.environ.get("LSB_RANK", 0))
-
-            return cls(
-                hostname=hostname,
-                hostnames=hostnames,
-                job_id=job_id,
-                array_job_id=array_job_id,
-                array_task_id=array_task_id,
-                num_tasks=num_tasks,
-                num_nodes=num_nodes,
-                node=node_id,
-                global_rank=global_rank,
-                local_rank=local_rank,
-            )
-        except (ImportError, RuntimeError, ValueError, KeyError):
-            return None
-
-
-class EnvironmentLinuxEnvironmentConfig(C.Config):
-    """
-    Information about the Linux environment (e.g., current user, hostname, etc.)
-    """
-
-    user: str | None = None
-    hostname: str | None = None
-    system: str | None = None
-    release: str | None = None
-    version: str | None = None
-    machine: str | None = None
-    processor: str | None = None
-    cpu_count: int | None = None
-    memory: int | None = None
-    uptime: timedelta | None = None
-    boot_time: float | None = None
-    load_avg: tuple[float, float, float] | None = None
-
-
-class EnvironmentSnapshotConfig(C.Config):
-    snapshot_dir: Path | None = None
-    modules: list[str] | None = None
-
-    @classmethod
-    def from_current_environment(cls):
-        draft = cls.draft()
-        if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
-            draft.snapshot_dir = Path(snapshot_dir)
-        if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
-            draft.modules = modules.split(",")
-        return draft.finalize()
-
-
-class EnvironmentConfig(C.Config):
-    cwd: Path | None = None
-
-    snapshot: EnvironmentSnapshotConfig | None = None
-
-    python_executable: Path | None = None
-    python_path: list[Path] | None = None
-    python_version: str | None = None
-
-    config: EnvironmentClassInformationConfig | None = None
-    model: EnvironmentClassInformationConfig | None = None
-    data: EnvironmentClassInformationConfig | None = None
-
-    linux: EnvironmentLinuxEnvironmentConfig | None = None
-
-    slurm: EnvironmentSLURMInformationConfig | None = None
-    lsf: EnvironmentLSFInformationConfig | None = None
-
-    base_dir: Path | None = None
-    log_dir: Path | None = None
-    checkpoint_dir: Path | None = None
-    stdio_dir: Path | None = None
-
-    seed: int | None = None
-    seed_workers: bool | None = None
-
-
 class BaseLoggerConfig(C.Config, ABC):
     enabled: bool = True
     """Enable this logger."""
@@ -398,7 +219,7 @@ class BaseLoggerConfig(C.Config, ABC):
     """Directory to save the logs to. If None, will use the default log directory for the trainer."""

     @abstractmethod
-    def
+    def create_logger(self, root_config: "BaseConfig") -> Logger | None: ...

     def disable_(self):
         self.enabled = False
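Throughout this file, the hooks on these config classes are renamed to `create_*` (`create_profiler`, `create_logger`, `create_callbacks`, and so on); the removed 0.9.1 method names are truncated in this rendering. Below is a rough sketch of what a user-defined logger config could look like against the new `create_logger` hook. It assumes `BaseLoggerConfig` and `BaseConfig` remain importable from `nshtrainer.model.config` and that `create_logger` is the only abstract member; the `ExtraCSVLoggerConfig` class is hypothetical and not part of the package.

from typing import TYPE_CHECKING, Literal

from typing_extensions import override

from nshtrainer.model.config import BaseLoggerConfig

if TYPE_CHECKING:
    from nshtrainer.model.config import BaseConfig


class ExtraCSVLoggerConfig(BaseLoggerConfig):
    # Hypothetical subclass, for illustration only. The built-in configs carry a
    # `kind` discriminator, which DirectoryConfig uses to pick the per-logger
    # log subdirectory (see the DirectoryConfig hunks further down).
    kind: Literal["extra_csv"] = "extra_csv"

    @override
    def create_logger(self, root_config: "BaseConfig"):
        if not self.enabled:  # `enabled` is inherited from BaseLoggerConfig
            return None
        from lightning.pytorch.loggers.csv_logs import CSVLogger

        # Mirror the built-in logger configs in the hunks that follow: the
        # directory config now resolves (and creates) the log directory instead
        # of each logger mkdir-ing its own hard-coded subfolder.
        save_dir = root_config.directory._resolve_log_directory_for_logger(
            root_config.id, self
        )
        return CSVLogger(save_dir=save_dir, name=root_config.run_name)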
@@ -466,18 +287,16 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     """Whether to run WandB in offline mode."""

     @override
-    def
+    def create_logger(self, root_config):
         if not self.enabled:
             return None

         from lightning.pytorch.loggers.wandb import WandbLogger

-        save_dir = root_config.directory.
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
             root_config.id,
             self,
         )
-        save_dir = save_dir / "wandb"
-        save_dir.mkdir(parents=True, exist_ok=True)
         return WandbLogger(
             save_dir=save_dir,
             project=self.project or _project_name(root_config),
@@ -494,9 +313,9 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
         )

     @override
-    def
+    def create_callbacks(self, root_config):
         if self.watch:
-            yield from self.watch.
+            yield from self.watch.create_callbacks(root_config)


 class CSVLoggerConfig(BaseLoggerConfig):
@@ -515,18 +334,16 @@ class CSVLoggerConfig(BaseLoggerConfig):
     """How often to flush logs to disk."""

     @override
-    def
+    def create_logger(self, root_config):
         if not self.enabled:
             return None

         from lightning.pytorch.loggers.csv_logs import CSVLogger

-        save_dir = root_config.directory.
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
             root_config.id,
             self,
         )
-        save_dir = save_dir / "csv"
-        save_dir.mkdir(parents=True, exist_ok=True)
         return CSVLogger(
             save_dir=save_dir,
             name=root_config.run_name,
@@ -581,18 +398,16 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
     """A string to put at the beginning of metric keys."""

     @override
-    def
+    def create_logger(self, root_config):
         if not self.enabled:
             return None

         from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

-        save_dir = root_config.directory.
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
             root_config.id,
             self,
         )
-        save_dir = save_dir / "tensorboard"
-        save_dir.mkdir(parents=True, exist_ok=True)
         return TensorBoardLogger(
             save_dir=save_dir,
             name=root_config.run_name,
@@ -624,6 +439,9 @@ class LoggingConfig(CallbackConfigBase):
     log_epoch: bool = True
     """If enabled, will log the fractional epoch number to the logger."""

+    actsave_logged_metrics: bool = False
+    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
+
     @property
     def wandb(self) -> WandbLoggerConfig | None:
         return next(
@@ -650,7 +468,7 @@ class LoggingConfig(CallbackConfigBase):
             ),
         )

-    def
+    def create_loggers(self, root_config: "BaseConfig"):
         """
         Constructs and returns a list of loggers based on the provided root configuration.

@@ -671,13 +489,13 @@ class LoggingConfig(CallbackConfigBase):
         ):
             if not logger_config.enabled:
                 continue
-            if (logger := logger_config.
+            if (logger := logger_config.create_logger(root_config)) is None:
                 continue
             loggers.append(logger)
         return loggers

     @override
-    def
+    def create_callbacks(self, root_config):
         if self.log_lr:
             from lightning.pytorch.callbacks import LearningRateMonitor

@@ -696,7 +514,7 @@ class LoggingConfig(CallbackConfigBase):
             if not logger or not isinstance(logger, CallbackConfigBase):
                 continue

-            yield from logger.
+            yield from logger.create_callbacks(root_config)


 class GradientClippingConfig(C.Config):
@@ -723,7 +541,7 @@ class OptimizationConfig(CallbackConfigBase):
     """Gradient clipping configuration, or None to disable gradient clipping."""

     @override
-    def
+    def create_callbacks(self, root_config):
         from ..callbacks.norm_logging import NormLoggingConfig

         yield from NormLoggingConfig(
@@ -731,7 +549,7 @@ class OptimizationConfig(CallbackConfigBase):
             log_grad_norm_per_param=self.log_grad_norm_per_param,
             log_param_norm=self.log_param_norm,
             log_param_norm_per_param=self.log_param_norm_per_param,
-        ).
+        ).create_callbacks(root_config)


 TPlugin = TypeVar(
@@ -746,17 +564,17 @@ TPlugin = TypeVar(

 @runtime_checkable
 class PluginConfigProtocol(Protocol[TPlugin]):
-    def
+    def create_plugin(self) -> TPlugin: ...


 @runtime_checkable
 class AcceleratorConfigProtocol(Protocol):
-    def
+    def create_accelerator(self) -> Accelerator: ...


 @runtime_checkable
 class StrategyConfigProtocol(Protocol):
-    def
+    def create_strategy(self) -> Strategy: ...


 AcceleratorLiteral: TypeAlias = Literal[
@@ -793,18 +611,6 @@ StrategyLiteral: TypeAlias = Literal[
 ]


-class CheckpointLoadingConfig(C.Config):
-    path: Literal["best", "last", "hpc"] | str | Path | None = None
-    """
-    Checkpoint path to use when loading a checkpoint.
-
-    - "best" will load the best checkpoint.
-    - "last" will load the last checkpoint.
-    - "hpc" will load the SLURM pre-empted checkpoint.
-    - Any other string or Path will load the checkpoint from the specified path.
-    """
-
-
 def _create_symlink_to_nshrunner(base_dir: Path):
     # Resolve the current nshrunner session directory
     if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
@@ -903,7 +709,7 @@ class DirectoryConfig(C.Config):
         dir.mkdir(exist_ok=True)
         return dir

-    def
+    def _resolve_log_directory_for_logger(
         self,
         run_id: str,
         logger: LoggerConfig,
@@ -911,9 +717,10 @@ class DirectoryConfig(C.Config):
         if (log_dir := logger.log_dir) is not None:
             return log_dir

-        # Save to nshtrainer/{id}/log/{logger kind}
+        # Save to nshtrainer/{id}/log/{logger kind}
         log_dir = self.resolve_subdirectory(run_id, "log")
         log_dir = log_dir / logger.kind
+        log_dir.mkdir(exist_ok=True)

         return log_dir

@@ -927,208 +734,6 @@ class ReproducibilityConfig(C.Config):
     """


-class ModelCheckpointCallbackConfig(CallbackConfigBase):
-    """Arguments for the ModelCheckpoint callback."""
-
-    kind: Literal["model_checkpoint"] = "model_checkpoint"
-
-    dirpath: str | Path | None = None
-    """
-    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
-    """
-
-    filename: str | None = None
-    """
-    Checkpoint filename.
-    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
-    """
-
-    monitor: str | None = None
-    """
-    Quantity to monitor for saving checkpoints.
-    If None, no metric is monitored and checkpoints are saved at the end of every epoch.
-    """
-
-    verbose: bool = False
-    """Verbosity mode. If True, print additional information about checkpoints."""
-
-    save_last: Literal[True, False, "link"] | None = "link"
-    """
-    Whether to save the last checkpoint.
-    If True, saves a copy of the last checkpoint separately.
-    If "link", creates a symbolic link to the last checkpoint.
-    """
-
-    save_top_k: int = 1
-    """
-    Number of best models to save.
-    If -1, all models are saved.
-    If 0, no models are saved.
-    """
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    mode: str = "min"
-    """
-    One of "min" or "max".
-    If "min", training will stop when the metric monitored has stopped decreasing.
-    If "max", training will stop when the metric monitored has stopped increasing.
-    """
-
-    auto_insert_metric_name: bool = True
-    """Whether to automatically insert the metric name in the checkpoint filename."""
-
-    every_n_train_steps: int | None = None
-    """
-    Number of training steps between checkpoints.
-    If None or 0, no checkpoints are saved during training.
-    """
-
-    train_time_interval: timedelta | None = None
-    """
-    Time interval between checkpoints during training.
-    If None, no checkpoints are saved during training based on time.
-    """
-
-    every_n_epochs: int | None = None
-    """
-    Number of epochs between checkpoints.
-    If None or 0, no checkpoints are saved at the end of epochs.
-    """
-
-    save_on_train_epoch_end: bool | None = None
-    """
-    Whether to run checkpointing at the end of the training epoch.
-    If False, checkpointing runs at the end of the validation.
-    """
-
-    enable_version_counter: bool = True
-    """Whether to append a version to the existing file name."""
-
-    auto_append_metric: bool = True
-    """If enabled, this will automatically add "-{monitor}" to the filename."""
-
-    @staticmethod
-    def _convert_string(input_string: str):
-        # Find all variables enclosed in curly braces
-        variables = re.findall(r"\{(.*?)\}", input_string)
-
-        # Replace each variable with its corresponding key-value pair
-        output_string = input_string
-        for variable in variables:
-            # If the name is something like {variable:format}, we shouldn't process the format.
-            key_name = variable
-            if ":" in variable:
-                key_name, _ = variable.split(":", 1)
-                continue
-
-            # Replace '/' with '_' in the key name
-            key_name = key_name.replace("/", "_")
-            output_string = output_string.replace(
-                f"{{{variable}}}", f"{key_name}={{{variable}}}"
-            )
-
-        return output_string
-
-    @override
-    def construct_callbacks(self, root_config):
-        from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
-
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        # If `monitor` is not provided, we can use `config.primary_metric` if it is set.
-        monitor = self.monitor
-        mode = self.mode
-        if (
-            monitor is None
-            and (primary_metric := root_config.primary_metric) is not None
-        ):
-            monitor = primary_metric.validation_monitor
-            mode = primary_metric.mode
-
-        filename = self.filename
-        if self.auto_append_metric:
-            if not filename:
-                filename = "{epoch}-{step}"
-            filename = f"{filename}-{{{monitor}}}"
-
-        if self.auto_insert_metric_name and filename:
-            new_filename = self._convert_string(filename)
-            log.critical(
-                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
-            )
-            filename = new_filename
-
-        yield ModelCheckpoint(
-            dirpath=dirpath,
-            filename=filename,
-            monitor=monitor,
-            mode=mode,
-            verbose=self.verbose,
-            save_last=self.save_last,
-            save_top_k=self.save_top_k,
-            save_weights_only=self.save_weights_only,
-            auto_insert_metric_name=False,
-            every_n_train_steps=self.every_n_train_steps,
-            train_time_interval=self.train_time_interval,
-            every_n_epochs=self.every_n_epochs,
-            save_on_train_epoch_end=self.save_on_train_epoch_end,
-            enable_version_counter=self.enable_version_counter,
-        )
-
-
-class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
-    kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str | None = None
-    """Checkpoint filename. This must not include the extension. If `None`, `latest_epoch_{id}_{timestamp}` is used."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    @override
-    def construct_callbacks(self, root_config):
-        from ..callbacks.latest_epoch_checkpoint import LatestEpochCheckpoint
-
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        yield LatestEpochCheckpoint(
-            dirpath=dirpath,
-            filename=self.filename,
-            save_weights_only=self.save_weights_only,
-        )
-
-
-class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
-    kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str | None = None
-    """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
-
-    @override
-    def construct_callbacks(self, root_config):
-        from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
-
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        if not (filename := self.filename):
-            filename = f"on_exception_{root_config.id}"
-        yield OnExceptionCheckpoint(dirpath=dirpath, filename=filename)
-
-
 CheckpointCallbackConfig: TypeAlias = Annotated[
     ModelCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
@@ -1192,12 +797,12 @@ class CheckpointSavingConfig(CallbackConfigBase):
         )

     @override
-    def
+    def create_callbacks(self, root_config: "BaseConfig"):
         if not self.should_save_checkpoints(root_config):
             return

         for callback_config in self.checkpoint_callbacks:
-            yield from callback_config.
+            yield from callback_config.create_callbacks(root_config)


 class LightningTrainerKwargs(TypedDict, total=False):
@@ -1474,7 +1079,7 @@ class EarlyStoppingConfig(CallbackConfigBase):
     """

     @override
-    def
+    def create_callbacks(self, root_config: "BaseConfig"):
         from ..callbacks.early_stopping import EarlyStopping

         monitor = self.monitor
@@ -1505,32 +1110,6 @@ class EarlyStoppingConfig(CallbackConfigBase):
         ]


-class ActSaveConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable activation saving."""
-
-    auto_save_logged_metrics: bool = False
-    """If enabled, will automatically save logged metrics (using `LightningModule.log`) as activations."""
-
-    save_dir: Path | None = None
-    """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""
-
-    def __bool__(self):
-        return self.enabled
-
-    def resolve_save_dir(self, root_config: "BaseConfig"):
-        if self.save_dir is not None:
-            return self.save_dir
-
-        return root_config.directory.resolve_subdirectory(root_config.id, "activation")
-
-    @override
-    def construct_callbacks(self, root_config):
-        from ..actsave import ActSaveCallback
-
-        return [ActSaveCallback()]
-
-
 class SanityCheckingConfig(C.Config):
     reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
     """
@@ -1542,7 +1121,7 @@ class SanityCheckingConfig(C.Config):


 class TrainerConfig(C.Config):
-    checkpoint_loading: CheckpointLoadingConfig =
+    checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
     """Checkpoint loading configuration options."""

     checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
@@ -1560,9 +1139,6 @@ class TrainerConfig(C.Config):
     sanity_checking: SanityCheckingConfig = SanityCheckingConfig()
     """Sanity checking configuration options."""

-    actsave: ActSaveConfig | None = ActSaveConfig(enabled=False)
-    """Activation saving configuration options."""
-
     early_stopping: EarlyStoppingConfig | None = None
     """Early stopping configuration options."""

@@ -1731,12 +1307,12 @@ class TrainerConfig(C.Config):
     automatic selection based on the chosen accelerator. Default: ``"auto"``.
     """

-    auto_wrap_trainer: bool = True
-    """If enabled, will automatically wrap the `run` function with a `Trainer.context()` context manager. Should be `True` most of the time."""
     auto_set_default_root_dir: bool = True
     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
     supports_shared_parameters: bool = True
     """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
+    save_checkpoint_metadata: bool = True
+    """If enabled, will save additional metadata whenever a checkpoint is saved."""

     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
@@ -1756,35 +1332,6 @@
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""


-class MetricConfig(C.Config):
-    name: str
-    """The name of the primary metric."""
-
-    mode: Literal["min", "max"]
-    """
-    The mode of the primary metric:
-    - "min" for metrics that should be minimized (e.g., loss)
-    - "max" for metrics that should be maximized (e.g., accuracy)
-    """
-
-    @property
-    def validation_monitor(self) -> str:
-        return f"val/{self.name}"
-
-    def __post_init__(self):
-        for split in ("train", "val", "test", "predict"):
-            if self.name.startswith(f"{split}/"):
-                raise ValueError(
-                    f"Primary metric name should not start with '{split}/'. "
-                    f"Just use '{self.name[len(split) + 1:]}' instead. "
-                    "The split name is automatically added depending on the context."
-                )
-
-    @classmethod
-    def loss(cls, mode: Literal["min", "max"] = "min"):
-        return cls(name="loss", mode=mode)
-
-
 PrimaryMetricConfig: TypeAlias = MetricConfig

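The `MetricConfig` block removed here is a move rather than a deletion: the import hunk at the top of this diff adds `from ..metrics import MetricConfig`, and the file list shows the new `nshtrainer/metrics/_config.py`. A small usage sketch against the new location, assuming the class behaves as the removed code above:

from nshtrainer.metrics import MetricConfig

# Primary metric for a classification run; mode="max" because higher is better.
metric = MetricConfig(name="accuracy", mode="max")
print(metric.validation_monitor)  # -> "val/accuracy"

# Convenience constructor kept from the removed code above.
loss_metric = MetricConfig.loss()  # name="loss", mode="min"

# Names must not embed a split prefix; the removed __post_init__ raised for these.
# MetricConfig(name="val/accuracy", mode="max")  # would raise ValueError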
@@ -1804,7 +1351,9 @@ class BaseConfig(C.Config):

     debug: bool = False
     """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
-    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] =
+    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = (
+        EnvironmentConfig.empty()
+    )
     """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""

     directory: DirectoryConfig = DirectoryConfig()
@@ -1892,7 +1441,7 @@ class BaseConfig(C.Config):
             self.directory = DirectoryConfig()

         if environment:
-            self.environment = EnvironmentConfig()
+            self.environment = EnvironmentConfig.empty()

         if meta:
             self.meta = {}
@@ -1990,8 +1539,7 @@ class BaseConfig(C.Config):
         )
         return cls.model_validate(hparams)

-    def
-        yield self.trainer.actsave
+    def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
         yield self.trainer.early_stopping
         yield self.trainer.checkpoint_saving
         yield self.trainer.logging