nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -1
- nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
- nshtrainer/callbacks/base.py +7 -5
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
- nshtrainer/callbacks/model_checkpoint.py +187 -0
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/throughput_monitor.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/ll/__init__.py +0 -1
- nshtrainer/ll/actsave.py +2 -1
- nshtrainer/metrics/__init__.py +1 -0
- nshtrainer/metrics/_config.py +37 -0
- nshtrainer/model/__init__.py +11 -11
- nshtrainer/model/_environment.py +777 -0
- nshtrainer/model/base.py +5 -114
- nshtrainer/model/config.py +92 -507
- nshtrainer/model/modules/logger.py +11 -6
- nshtrainer/runner.py +3 -6
- nshtrainer/trainer/_checkpoint_metadata.py +102 -0
- nshtrainer/trainer/_checkpoint_resolver.py +319 -0
- nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer/trainer/checkpoint_connector.py +63 -0
- nshtrainer/trainer/signal_connector.py +12 -9
- nshtrainer/trainer/trainer.py +111 -31
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
- nshtrainer/actsave/__init__.py +0 -3
- {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
nshtrainer/model/config.py
CHANGED
@@ -1,7 +1,5 @@
 import copy
 import os
-import re
-import socket
 import string
 import time
 import warnings
@@ -36,10 +34,17 @@ from lightning.pytorch.strategies.strategy import Strategy
 from pydantic import DirectoryPath
 from typing_extensions import Self, TypedDict, TypeVar, override
 
-from ..callbacks import
+from ..callbacks import (
+    CallbackConfig,
+    LatestEpochCheckpointCallbackConfig,
+    ModelCheckpointCallbackConfig,
+    OnExceptionCheckpointCallbackConfig,
+    WandbWatchConfig,
+)
 from ..callbacks.base import CallbackConfigBase
-from ..
-from ..
+from ..metrics import MetricConfig
+from ..trainer._checkpoint_resolver import CheckpointLoadingConfig
+from ._environment import EnvironmentConfig
 
 log = getLogger(__name__)
 
@@ -62,7 +67,7 @@ class BaseProfilerConfig(C.Config, ABC):
     """
 
     @abstractmethod
-    def
+    def create_profiler(self, root_config: "BaseConfig") -> Profiler: ...
 
 
 class SimpleProfilerConfig(BaseProfilerConfig):
@@ -75,7 +80,7 @@ class SimpleProfilerConfig(BaseProfilerConfig):
     """
 
     @override
-    def
+    def create_profiler(self, root_config):
         from lightning.pytorch.profilers.simple import SimpleProfiler
 
         if (dirpath := self.dirpath) is None:
@@ -104,7 +109,7 @@ class AdvancedProfilerConfig(BaseProfilerConfig):
     """
 
     @override
-    def
+    def create_profiler(self, root_config):
         from lightning.pytorch.profilers.advanced import AdvancedProfiler
 
         if (dirpath := self.dirpath) is None:
@@ -172,7 +177,7 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
     """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""
 
    @override
-    def
+    def create_profiler(self, root_config):
         from lightning.pytorch.profilers.pytorch import PyTorchProfiler
 
         if (dirpath := self.dirpath) is None:
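The profiler hunks above rename the config-level factory method to `create_profiler(self, root_config)` (the old, truncated method name is not shown in this diff). The sketch below is a minimal standalone illustration of this config-as-factory pattern, not nshtrainer's code: `BaseProfilerSketch`, `SimpleProfilerSketch`, and the dict-shaped `root_config` are made up for the example.

```python
# Standalone sketch of the config-as-factory pattern: each profiler config owns a
# create_profiler(root_config) method that builds the Lightning profiler lazily.
from abc import ABC, abstractmethod
from pathlib import Path


class BaseProfilerSketch(ABC):
    @abstractmethod
    def create_profiler(self, root_config): ...


class SimpleProfilerSketch(BaseProfilerSketch):
    def __init__(self, dirpath: Path | None = None):
        self.dirpath = dirpath

    def create_profiler(self, root_config):
        # Lazy import keeps the Lightning dependency out of config construction,
        # mirroring the method bodies in the hunks above.
        from lightning.pytorch.profilers.simple import SimpleProfiler

        dirpath = self.dirpath or Path(root_config["profile_dir"])
        return SimpleProfiler(dirpath=dirpath)


profiler = SimpleProfilerSketch().create_profiler({"profile_dir": "./profile"})
```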
@@ -203,190 +208,6 @@ ProfilerConfig: TypeAlias = Annotated[
 ]
 
 
-class EnvironmentClassInformationConfig(C.Config):
-    name: str
-    module: str
-    full_name: str
-
-    file_path: Path
-    source_file_path: Path | None = None
-
-
-class EnvironmentSLURMInformationConfig(C.Config):
-    hostname: str
-    hostnames: list[str]
-    job_id: str
-    raw_job_id: str
-    array_job_id: str | None
-    array_task_id: str | None
-    num_tasks: int
-    num_nodes: int
-    node: str | int | None
-    global_rank: int
-    local_rank: int
-
-    @classmethod
-    def from_current_environment(cls):
-        try:
-            from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
-
-            if not SLURMEnvironment.detect():
-                return None
-
-            hostname = socket.gethostname()
-            hostnames = [hostname]
-            if node_list := os.environ.get("SLURM_JOB_NODELIST", ""):
-                hostnames = parse_slurm_node_list(node_list)
-
-            raw_job_id = os.environ["SLURM_JOB_ID"]
-            job_id = raw_job_id
-            array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
-            array_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
-            if array_job_id and array_task_id:
-                job_id = f"{array_job_id}_{array_task_id}"
-
-            num_tasks = int(os.environ["SLURM_NTASKS"])
-            num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
-
-            node_id = os.environ.get("SLURM_NODEID")
-
-            global_rank = int(os.environ["SLURM_PROCID"])
-            local_rank = int(os.environ["SLURM_LOCALID"])
-
-            return cls(
-                hostname=hostname,
-                hostnames=hostnames,
-                job_id=job_id,
-                raw_job_id=raw_job_id,
-                array_job_id=array_job_id,
-                array_task_id=array_task_id,
-                num_tasks=num_tasks,
-                num_nodes=num_nodes,
-                node=node_id,
-                global_rank=global_rank,
-                local_rank=local_rank,
-            )
-        except (ImportError, RuntimeError, ValueError, KeyError):
-            return None
-
-
-class EnvironmentLSFInformationConfig(C.Config):
-    hostname: str
-    hostnames: list[str]
-    job_id: str
-    array_job_id: str | None
-    array_task_id: str | None
-    num_tasks: int
-    num_nodes: int
-    node: str | int | None
-    global_rank: int
-    local_rank: int
-
-    @classmethod
-    def from_current_environment(cls):
-        try:
-            import os
-            import socket
-
-            hostname = socket.gethostname()
-            hostnames = [hostname]
-            if node_list := os.environ.get("LSB_HOSTS", ""):
-                hostnames = node_list.split()
-
-            job_id = os.environ["LSB_JOBID"]
-            array_job_id = os.environ.get("LSB_JOBINDEX")
-            array_task_id = os.environ.get("LSB_JOBINDEX")
-
-            num_tasks = int(os.environ.get("LSB_DJOB_NUMPROC", 1))
-            num_nodes = len(set(hostnames))
-
-            node_id = (
-                os.environ.get("LSB_HOSTS", "").split().index(hostname)
-                if "LSB_HOSTS" in os.environ
-                else None
-            )
-
-            # LSF doesn't have direct equivalents for global_rank and local_rank
-            # You might need to calculate these based on your specific setup
-            global_rank = int(os.environ.get("PMI_RANK", 0))
-            local_rank = int(os.environ.get("LSB_RANK", 0))
-
-            return cls(
-                hostname=hostname,
-                hostnames=hostnames,
-                job_id=job_id,
-                array_job_id=array_job_id,
-                array_task_id=array_task_id,
-                num_tasks=num_tasks,
-                num_nodes=num_nodes,
-                node=node_id,
-                global_rank=global_rank,
-                local_rank=local_rank,
-            )
-        except (ImportError, RuntimeError, ValueError, KeyError):
-            return None
-
-
-class EnvironmentLinuxEnvironmentConfig(C.Config):
-    """
-    Information about the Linux environment (e.g., current user, hostname, etc.)
-    """
-
-    user: str | None = None
-    hostname: str | None = None
-    system: str | None = None
-    release: str | None = None
-    version: str | None = None
-    machine: str | None = None
-    processor: str | None = None
-    cpu_count: int | None = None
-    memory: int | None = None
-    uptime: timedelta | None = None
-    boot_time: float | None = None
-    load_avg: tuple[float, float, float] | None = None
-
-
-class EnvironmentSnapshotConfig(C.Config):
-    snapshot_dir: Path | None = None
-    modules: list[str] | None = None
-
-    @classmethod
-    def from_current_environment(cls):
-        draft = cls.draft()
-        if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
-            draft.snapshot_dir = Path(snapshot_dir)
-        if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
-            draft.modules = modules.split(",")
-        return draft.finalize()
-
-
-class EnvironmentConfig(C.Config):
-    cwd: Path | None = None
-
-    snapshot: EnvironmentSnapshotConfig | None = None
-
-    python_executable: Path | None = None
-    python_path: list[Path] | None = None
-    python_version: str | None = None
-
-    config: EnvironmentClassInformationConfig | None = None
-    model: EnvironmentClassInformationConfig | None = None
-    data: EnvironmentClassInformationConfig | None = None
-
-    linux: EnvironmentLinuxEnvironmentConfig | None = None
-
-    slurm: EnvironmentSLURMInformationConfig | None = None
-    lsf: EnvironmentLSFInformationConfig | None = None
-
-    base_dir: Path | None = None
-    log_dir: Path | None = None
-    checkpoint_dir: Path | None = None
-    stdio_dir: Path | None = None
-
-    seed: int | None = None
-    seed_workers: bool | None = None
-
-
 class BaseLoggerConfig(C.Config, ABC):
     enabled: bool = True
     """Enable this logger."""
@@ -398,7 +219,7 @@ class BaseLoggerConfig(C.Config, ABC):
     """Directory to save the logs to. If None, will use the default log directory for the trainer."""
 
     @abstractmethod
-    def
+    def create_logger(self, root_config: "BaseConfig") -> Logger | None: ...
 
     def disable_(self):
         self.enabled = False
@@ -466,18 +287,16 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     """Whether to run WandB in offline mode."""
 
     @override
-    def
+    def create_logger(self, root_config):
         if not self.enabled:
             return None
 
         from lightning.pytorch.loggers.wandb import WandbLogger
 
-        save_dir = root_config.directory.
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
             root_config.id,
             self,
         )
-        save_dir = save_dir / "wandb"
-        save_dir.mkdir(parents=True, exist_ok=True)
         return WandbLogger(
             save_dir=save_dir,
             project=self.project or _project_name(root_config),
@@ -494,9 +313,9 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
         )
 
     @override
-    def
+    def create_callbacks(self, root_config):
         if self.watch:
-            yield from self.watch.
+            yield from self.watch.create_callbacks(root_config)
 
 
 class CSVLoggerConfig(BaseLoggerConfig):
@@ -515,18 +334,16 @@ class CSVLoggerConfig(BaseLoggerConfig):
     """How often to flush logs to disk."""
 
     @override
-    def
+    def create_logger(self, root_config):
         if not self.enabled:
             return None
 
         from lightning.pytorch.loggers.csv_logs import CSVLogger
 
-        save_dir = root_config.directory.
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
             root_config.id,
             self,
         )
-        save_dir = save_dir / "csv"
-        save_dir.mkdir(parents=True, exist_ok=True)
         return CSVLogger(
             save_dir=save_dir,
             name=root_config.run_name,
@@ -581,18 +398,16 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
     """A string to put at the beginning of metric keys."""
 
     @override
-    def
+    def create_logger(self, root_config):
         if not self.enabled:
             return None
 
         from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
 
-        save_dir = root_config.directory.
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
             root_config.id,
             self,
        )
-        save_dir = save_dir / "tensorboard"
-        save_dir.mkdir(parents=True, exist_ok=True)
         return TensorBoardLogger(
             save_dir=save_dir,
             name=root_config.run_name,
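In the logger hunks above, the per-logger `save_dir / "wandb"` / `"csv"` / `"tensorboard"` blocks are replaced by a single shared helper, `_resolve_log_directory_for_logger`, defined in the DirectoryConfig hunks later in this diff. The sketch below reproduces that resolution logic in standalone form; the helper name and layout mirror the diff, but this is not the package's implementation.

```python
# Standalone sketch of the shared log-directory resolution used by the loggers
# above: nshtrainer/{run id}/log/{logger kind}, created on demand.
from pathlib import Path


def resolve_log_directory_for_logger(project_root: Path, run_id: str, kind: str) -> Path:
    log_dir = project_root / "nshtrainer" / run_id / "log" / kind
    log_dir.mkdir(parents=True, exist_ok=True)
    return log_dir


# Each logger config then only needs something like:
#   save_dir = resolve_log_directory_for_logger(Path.cwd(), run_id, self.kind)
wandb_dir = resolve_log_directory_for_logger(Path.cwd(), "run-001", "wandb")
```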
@@ -624,6 +439,9 @@ class LoggingConfig(CallbackConfigBase):
     log_epoch: bool = True
     """If enabled, will log the fractional epoch number to the logger."""
 
+    actsave_logged_metrics: bool = False
+    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
+
     @property
     def wandb(self) -> WandbLoggerConfig | None:
         return next(
@@ -650,7 +468,7 @@ class LoggingConfig(CallbackConfigBase):
            ),
        )
 
-    def
+    def create_loggers(self, root_config: "BaseConfig"):
         """
         Constructs and returns a list of loggers based on the provided root configuration.
 
@@ -671,13 +489,13 @@ class LoggingConfig(CallbackConfigBase):
         ):
             if not logger_config.enabled:
                 continue
-            if (logger := logger_config.
+            if (logger := logger_config.create_logger(root_config)) is None:
                 continue
             loggers.append(logger)
         return loggers
 
     @override
-    def
+    def create_callbacks(self, root_config):
         if self.log_lr:
             from lightning.pytorch.callbacks import LearningRateMonitor
 
@@ -696,7 +514,7 @@ class LoggingConfig(CallbackConfigBase):
         if not logger or not isinstance(logger, CallbackConfigBase):
             continue
 
-        yield from logger.
+        yield from logger.create_callbacks(root_config)
 
 
 class GradientClippingConfig(C.Config):
@@ -723,7 +541,7 @@ class OptimizationConfig(CallbackConfigBase):
     """Gradient clipping configuration, or None to disable gradient clipping."""
 
     @override
-    def
+    def create_callbacks(self, root_config):
         from ..callbacks.norm_logging import NormLoggingConfig
 
         yield from NormLoggingConfig(
@@ -731,7 +549,7 @@ class OptimizationConfig(CallbackConfigBase):
             log_grad_norm_per_param=self.log_grad_norm_per_param,
             log_param_norm=self.log_param_norm,
             log_param_norm_per_param=self.log_param_norm_per_param,
-        ).
+        ).create_callbacks(root_config)
 
 
 TPlugin = TypeVar(
@@ -746,17 +564,17 @@ TPlugin = TypeVar(
 
 @runtime_checkable
 class PluginConfigProtocol(Protocol[TPlugin]):
-    def
+    def create_plugin(self) -> TPlugin: ...
 
 
 @runtime_checkable
 class AcceleratorConfigProtocol(Protocol):
-    def
+    def create_accelerator(self) -> Accelerator: ...
 
 
 @runtime_checkable
 class StrategyConfigProtocol(Protocol):
-    def
+    def create_strategy(self) -> Strategy: ...
 
 
 AcceleratorLiteral: TypeAlias = Literal[
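The three protocols above are runtime-checkable, so any config object exposing the matching `create_*` factory satisfies them structurally. A self-contained sketch for the strategy case follows; the `DDPConfigSketch` class is illustrative only and is not a strategy config shipped by nshtrainer.

```python
# Any object with a create_strategy() method satisfies StrategyConfigProtocol.
from typing import Protocol, runtime_checkable

from lightning.pytorch.strategies import DDPStrategy, Strategy


@runtime_checkable
class StrategyConfigProtocol(Protocol):
    def create_strategy(self) -> Strategy: ...


class DDPConfigSketch:
    def __init__(self, find_unused_parameters: bool = False):
        self.find_unused_parameters = find_unused_parameters

    def create_strategy(self) -> Strategy:
        # Kwargs are forwarded to torch's DistributedDataParallel by Lightning.
        return DDPStrategy(find_unused_parameters=self.find_unused_parameters)


# Structural check: no inheritance needed, only the method's presence.
assert isinstance(DDPConfigSketch(), StrategyConfigProtocol)
```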
@@ -793,16 +611,34 @@ StrategyLiteral: TypeAlias = Literal[
 ]
 
 
-
-
-""
-
+def _create_symlink_to_nshrunner(base_dir: Path):
+    # Resolve the current nshrunner session directory
+    if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
+        log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
+        return
+    session_dir = Path(session_dir)
+    if not session_dir.exists() or not session_dir.is_dir():
+        log.warning(
+            f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
+            "Skipping symlink creation."
+        )
+        return
 
-
-
-
-
-
+    # Create the symlink
+    symlink_path = base_dir / "nshrunner"
+    if symlink_path.exists():
+        # If it already points to the correct directory, we're done
+        if symlink_path.resolve() == session_dir.resolve():
+            return
+
+        # Otherwise, we should log a warning and remove the existing symlink
+        log.warning(
+            f"A symlink pointing to {symlink_path.resolve()} already exists at {symlink_path}. "
+            "Removing the existing symlink."
+        )
+        symlink_path.unlink()
+
+    symlink_path.symlink_to(session_dir)
 
 
 class DirectoryConfig(C.Config):
@@ -813,30 +649,33 @@ class DirectoryConfig(C.Config):
     This isn't specific to the run; it is the parent directory of all runs.
     """
 
+    create_symlink_to_nshrunner_root: bool = True
+    """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
+
     log: Path | None = None
-    """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use
+    """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""
 
     stdio: Path | None = None
-    """stdout/stderr log directory to use for the trainer. If None, will use
+    """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
 
     checkpoint: Path | None = None
-    """Checkpoint directory to use for the trainer. If None, will use
+    """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
 
     activation: Path | None = None
-    """Activation directory to use for the trainer. If None, will use
+    """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
 
     profile: Path | None = None
-    """Directory to save profiling information to. If None, will use
+    """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
 
     def resolve_run_root_directory(self, run_id: str) -> Path:
         if (project_root_dir := self.project_root) is None:
             project_root_dir = Path.cwd()
 
-        # The default base dir is $CWD/
-        base_dir = project_root_dir / "
+        # The default base dir is $CWD/nshtrainer/{id}/
+        base_dir = project_root_dir / "nshtrainer"
         base_dir.mkdir(exist_ok=True)
 
-        # Add a .gitignore file to the
+        # Add a .gitignore file to the nshtrainer directory
         # which will ignore all files except for the .gitignore file itself
         gitignore_path = base_dir / ".gitignore"
         if not gitignore_path.exists():
@@ -846,6 +685,10 @@ class DirectoryConfig(C.Config):
         base_dir = base_dir / run_id
         base_dir.mkdir(exist_ok=True)
 
+        # Create a symlink to the root folder for the Runner
+        if self.create_symlink_to_nshrunner_root:
+            _create_symlink_to_nshrunner(base_dir)
+
         return base_dir
 
     def resolve_subdirectory(
@@ -854,7 +697,7 @@ class DirectoryConfig(C.Config):
         # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
         subdirectory: str,
     ) -> Path:
-        # The subdir will be $CWD/
+        # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
         if (subdir := getattr(self, subdirectory, None)) is not None:
             assert isinstance(
                 subdir, Path
@@ -866,7 +709,7 @@ class DirectoryConfig(C.Config):
         dir.mkdir(exist_ok=True)
         return dir
 
-    def
+    def _resolve_log_directory_for_logger(
         self,
         run_id: str,
         logger: LoggerConfig,
@@ -874,9 +717,10 @@ class DirectoryConfig(C.Config):
         if (log_dir := logger.log_dir) is not None:
             return log_dir
 
-        # Save to
+        # Save to nshtrainer/{id}/log/{logger kind}
         log_dir = self.resolve_subdirectory(run_id, "log")
         log_dir = log_dir / logger.kind
+        log_dir.mkdir(exist_ok=True)
 
         return log_dir
 
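Taken together, the DirectoryConfig hunks above resolve every run artifact under `project_root/nshtrainer/{id}/{log, stdio, checkpoint, activation, profile}/`, with an optional `nshrunner` symlink back to the active session. A small standalone sketch of that layout resolution, not the package's implementation (the function name here is invented for illustration):

```python
# Standalone sketch of the directory layout resolved by DirectoryConfig above.
from pathlib import Path


def resolve_subdirectory(project_root: Path, run_id: str, subdirectory: str) -> Path:
    # project_root/nshtrainer/{id}/{subdirectory}/ is created on first use.
    subdir = project_root / "nshtrainer" / run_id / subdirectory
    subdir.mkdir(parents=True, exist_ok=True)
    return subdir


checkpoint_dir = resolve_subdirectory(Path.cwd(), "run-001", "checkpoint")
```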
@@ -890,208 +734,6 @@ class ReproducibilityConfig(C.Config):
     """
 
 
-class ModelCheckpointCallbackConfig(CallbackConfigBase):
-    """Arguments for the ModelCheckpoint callback."""
-
-    kind: Literal["model_checkpoint"] = "model_checkpoint"
-
-    dirpath: str | Path | None = None
-    """
-    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
-    """
-
-    filename: str | None = None
-    """
-    Checkpoint filename.
-    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
-    """
-
-    monitor: str | None = None
-    """
-    Quantity to monitor for saving checkpoints.
-    If None, no metric is monitored and checkpoints are saved at the end of every epoch.
-    """
-
-    verbose: bool = False
-    """Verbosity mode. If True, print additional information about checkpoints."""
-
-    save_last: Literal[True, False, "link"] | None = "link"
-    """
-    Whether to save the last checkpoint.
-    If True, saves a copy of the last checkpoint separately.
-    If "link", creates a symbolic link to the last checkpoint.
-    """
-
-    save_top_k: int = 1
-    """
-    Number of best models to save.
-    If -1, all models are saved.
-    If 0, no models are saved.
-    """
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    mode: str = "min"
-    """
-    One of "min" or "max".
-    If "min", training will stop when the metric monitored has stopped decreasing.
-    If "max", training will stop when the metric monitored has stopped increasing.
-    """
-
-    auto_insert_metric_name: bool = True
-    """Whether to automatically insert the metric name in the checkpoint filename."""
-
-    every_n_train_steps: int | None = None
-    """
-    Number of training steps between checkpoints.
-    If None or 0, no checkpoints are saved during training.
-    """
-
-    train_time_interval: timedelta | None = None
-    """
-    Time interval between checkpoints during training.
-    If None, no checkpoints are saved during training based on time.
-    """
-
-    every_n_epochs: int | None = None
-    """
-    Number of epochs between checkpoints.
-    If None or 0, no checkpoints are saved at the end of epochs.
-    """
-
-    save_on_train_epoch_end: bool | None = None
-    """
-    Whether to run checkpointing at the end of the training epoch.
-    If False, checkpointing runs at the end of the validation.
-    """
-
-    enable_version_counter: bool = True
-    """Whether to append a version to the existing file name."""
-
-    auto_append_metric: bool = True
-    """If enabled, this will automatically add "-{monitor}" to the filename."""
-
-    @staticmethod
-    def _convert_string(input_string: str):
-        # Find all variables enclosed in curly braces
-        variables = re.findall(r"\{(.*?)\}", input_string)
-
-        # Replace each variable with its corresponding key-value pair
-        output_string = input_string
-        for variable in variables:
-            # If the name is something like {variable:format}, we shouldn't process the format.
-            key_name = variable
-            if ":" in variable:
-                key_name, _ = variable.split(":", 1)
-                continue
-
-            # Replace '/' with '_' in the key name
-            key_name = key_name.replace("/", "_")
-            output_string = output_string.replace(
-                f"{{{variable}}}", f"{key_name}={{{variable}}}"
-            )
-
-        return output_string
-
-    @override
-    def construct_callbacks(self, root_config):
-        from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
-
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        # If `monitor` is not provided, we can use `config.primary_metric` if it is set.
-        monitor = self.monitor
-        mode = self.mode
-        if (
-            monitor is None
-            and (primary_metric := root_config.primary_metric) is not None
-        ):
-            monitor = primary_metric.validation_monitor
-            mode = primary_metric.mode
-
-        filename = self.filename
-        if self.auto_append_metric:
-            if not filename:
-                filename = "{epoch}-{step}"
-            filename = f"{filename}-{{{monitor}}}"
-
-        if self.auto_insert_metric_name and filename:
-            new_filename = self._convert_string(filename)
-            log.critical(
-                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
-            )
-            filename = new_filename
-
-        yield ModelCheckpoint(
-            dirpath=dirpath,
-            filename=filename,
-            monitor=monitor,
-            mode=mode,
-            verbose=self.verbose,
-            save_last=self.save_last,
-            save_top_k=self.save_top_k,
-            save_weights_only=self.save_weights_only,
-            auto_insert_metric_name=False,
-            every_n_train_steps=self.every_n_train_steps,
-            train_time_interval=self.train_time_interval,
-            every_n_epochs=self.every_n_epochs,
-            save_on_train_epoch_end=self.save_on_train_epoch_end,
-            enable_version_counter=self.enable_version_counter,
-        )
-
-
-class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
-    kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str | None = None
-    """Checkpoint filename. This must not include the extension. If `None`, `latest_epoch_{id}_{timestamp}` is used."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    @override
-    def construct_callbacks(self, root_config):
-        from ..callbacks.latest_epoch_checkpoint import LatestEpochCheckpoint
-
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        yield LatestEpochCheckpoint(
-            dirpath=dirpath,
-            filename=self.filename,
-            save_weights_only=self.save_weights_only,
-        )
-
-
-class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
-    kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str | None = None
-    """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
-
-    @override
-    def construct_callbacks(self, root_config):
-        from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
-
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        if not (filename := self.filename):
-            filename = f"on_exception_{root_config.id}"
-        yield OnExceptionCheckpoint(dirpath=dirpath, filename=filename)
-
-
 CheckpointCallbackConfig: TypeAlias = Annotated[
     ModelCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
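The three checkpoint config classes deleted above are not removed from the package: per the new import block at the top of this file and the `callbacks/model_checkpoint.py` / `callbacks/__init__.py` entries in the file list, they now live in `nshtrainer.callbacks`. A hedged usage example follows; the field names are taken from the deleted definitions here and the moved classes are assumed to keep them.

```python
# Assumed import location after the move (see `from ..callbacks import ...` above).
from nshtrainer.callbacks import (
    LatestEpochCheckpointCallbackConfig,
    ModelCheckpointCallbackConfig,
    OnExceptionCheckpointCallbackConfig,
)

checkpoint_callbacks = [
    # Keep the single best checkpoint according to the monitored metric.
    ModelCheckpointCallbackConfig(monitor="val/loss", mode="min", save_top_k=1),
    # Also keep the most recent epoch, plus a checkpoint written on crashes.
    LatestEpochCheckpointCallbackConfig(),
    OnExceptionCheckpointCallbackConfig(),
]
```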
@@ -1155,12 +797,12 @@ class CheckpointSavingConfig(CallbackConfigBase):
         )
 
     @override
-    def
+    def create_callbacks(self, root_config: "BaseConfig"):
         if not self.should_save_checkpoints(root_config):
             return
 
         for callback_config in self.checkpoint_callbacks:
-            yield from callback_config.
+            yield from callback_config.create_callbacks(root_config)
 
 
 class LightningTrainerKwargs(TypedDict, total=False):
@@ -1437,7 +1079,7 @@ class EarlyStoppingConfig(CallbackConfigBase):
     """
 
     @override
-    def
+    def create_callbacks(self, root_config: "BaseConfig"):
         from ..callbacks.early_stopping import EarlyStopping
 
         monitor = self.monitor
@@ -1468,32 +1110,6 @@
 ]
 
 
-class ActSaveConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable activation saving."""
-
-    auto_save_logged_metrics: bool = False
-    """If enabled, will automatically save logged metrics (using `LightningModule.log`) as activations."""
-
-    save_dir: Path | None = None
-    """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""
-
-    def __bool__(self):
-        return self.enabled
-
-    def resolve_save_dir(self, root_config: "BaseConfig"):
-        if self.save_dir is not None:
-            return self.save_dir
-
-        return root_config.directory.resolve_subdirectory(root_config.id, "activation")
-
-    @override
-    def construct_callbacks(self, root_config):
-        from ..actsave import ActSaveCallback
-
-        return [ActSaveCallback()]
-
-
 class SanityCheckingConfig(C.Config):
     reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
     """
@@ -1505,7 +1121,7 @@ class SanityCheckingConfig(C.Config):
 
 
 class TrainerConfig(C.Config):
-    checkpoint_loading: CheckpointLoadingConfig =
+    checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
     """Checkpoint loading configuration options."""
 
     checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
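The `checkpoint_loading` field above now accepts the literal `"auto"` (its new default) in addition to an explicit `CheckpointLoadingConfig`, which is now imported from `nshtrainer.trainer._checkpoint_resolver` per the import hunk at the top of this file. Since that config's fields are not part of this diff, the sketch below only illustrates the `X | Literal["auto"]` resolution pattern with invented names; it is not nshtrainer's resolver.

```python
# Standalone sketch of the "auto"-or-explicit-config pattern.
from dataclasses import dataclass
from typing import Literal


@dataclass
class CheckpointLoadingSketch:
    path: str | None = None


def resolve_checkpoint_loading(
    value: CheckpointLoadingSketch | Literal["auto"],
) -> CheckpointLoadingSketch:
    if value == "auto":
        # Let the trainer pick sensible defaults.
        return CheckpointLoadingSketch()
    return value


assert resolve_checkpoint_loading("auto").path is None
```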
@@ -1523,9 +1139,6 @@
     sanity_checking: SanityCheckingConfig = SanityCheckingConfig()
     """Sanity checking configuration options."""
 
-    actsave: ActSaveConfig | None = ActSaveConfig(enabled=False)
-    """Activation saving configuration options."""
-
     early_stopping: EarlyStoppingConfig | None = None
     """Early stopping configuration options."""
 
@@ -1694,12 +1307,12 @@
     automatic selection based on the chosen accelerator. Default: ``"auto"``.
     """
 
-    auto_wrap_trainer: bool = True
-    """If enabled, will automatically wrap the `run` function with a `Trainer.context()` context manager. Should be `True` most of the time."""
     auto_set_default_root_dir: bool = True
     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
     supports_shared_parameters: bool = True
     """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
+    save_checkpoint_metadata: bool = True
+    """If enabled, will save additional metadata whenever a checkpoint is saved."""
 
     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
@@ -1719,35 +1332,6 @@
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
 
 
-class MetricConfig(C.Config):
-    name: str
-    """The name of the primary metric."""
-
-    mode: Literal["min", "max"]
-    """
-    The mode of the primary metric:
-    - "min" for metrics that should be minimized (e.g., loss)
-    - "max" for metrics that should be maximized (e.g., accuracy)
-    """
-
-    @property
-    def validation_monitor(self) -> str:
-        return f"val/{self.name}"
-
-    def __post_init__(self):
-        for split in ("train", "val", "test", "predict"):
-            if self.name.startswith(f"{split}/"):
-                raise ValueError(
-                    f"Primary metric name should not start with '{split}/'. "
-                    f"Just use '{self.name[len(split) + 1:]}' instead. "
-                    "The split name is automatically added depending on the context."
-                )
-
-    @classmethod
-    def loss(cls, mode: Literal["min", "max"] = "min"):
-        return cls(name="loss", mode=mode)
-
-
 PrimaryMetricConfig: TypeAlias = MetricConfig
 
 
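`MetricConfig` is likewise moved rather than removed: the `from ..metrics import MetricConfig` import above and the `metrics/_config.py` entry in the file list point to its new home in `nshtrainer.metrics`. A hedged example based on the deleted definition (fields and the `loss()` constructor are assumed unchanged after the move):

```python
# Assumed import location after the move (see `from ..metrics import MetricConfig`).
from nshtrainer.metrics import MetricConfig

# The split prefix is added automatically; names like "val/accuracy" are rejected.
primary_metric = MetricConfig(name="accuracy", mode="max")
assert primary_metric.validation_monitor == "val/accuracy"

# Convenience constructor for the common "minimize the loss" case.
loss_metric = MetricConfig.loss()
```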
@@ -1767,7 +1351,9 @@ class BaseConfig(C.Config):
 
     debug: bool = False
     """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
-    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] =
+    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = (
+        EnvironmentConfig.empty()
+    )
     """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""
 
     directory: DirectoryConfig = DirectoryConfig()
@@ -1855,7 +1441,7 @@
         self.directory = DirectoryConfig()
 
         if environment:
-            self.environment = EnvironmentConfig()
+            self.environment = EnvironmentConfig.empty()
 
         if meta:
             self.meta = {}
@@ -1953,8 +1539,7 @@
         )
         return cls.model_validate(hparams)
 
-    def
-        yield self.trainer.actsave
+    def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
         yield self.trainer.early_stopping
         yield self.trainer.checkpoint_saving
         yield self.trainer.logging