nshtrainer 0.8.7__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/callbacks/__init__.py +17 -1
  3. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  4. nshtrainer/callbacks/base.py +7 -5
  5. nshtrainer/callbacks/ema.py +1 -1
  6. nshtrainer/callbacks/finite_checks.py +1 -1
  7. nshtrainer/callbacks/gradient_skipping.py +1 -1
  8. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  9. nshtrainer/callbacks/model_checkpoint.py +187 -0
  10. nshtrainer/callbacks/norm_logging.py +1 -1
  11. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  12. nshtrainer/callbacks/print_table.py +1 -1
  13. nshtrainer/callbacks/throughput_monitor.py +1 -1
  14. nshtrainer/callbacks/timer.py +1 -1
  15. nshtrainer/callbacks/wandb_watch.py +1 -1
  16. nshtrainer/ll/__init__.py +0 -1
  17. nshtrainer/ll/actsave.py +2 -1
  18. nshtrainer/metrics/__init__.py +1 -0
  19. nshtrainer/metrics/_config.py +37 -0
  20. nshtrainer/model/__init__.py +11 -11
  21. nshtrainer/model/_environment.py +777 -0
  22. nshtrainer/model/base.py +5 -114
  23. nshtrainer/model/config.py +92 -507
  24. nshtrainer/model/modules/logger.py +11 -6
  25. nshtrainer/runner.py +3 -6
  26. nshtrainer/trainer/_checkpoint_metadata.py +102 -0
  27. nshtrainer/trainer/_checkpoint_resolver.py +319 -0
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.8.7.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
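
The diff that follows (which, from the imports and the +92/−507 line counts, is `nshtrainer/model/config.py`) is dominated by two kinds of changes: the `construct_*` factory methods on the config classes are renamed to `create_*` (`create_callbacks`, `create_logger`, `create_profiler`, `create_plugin`, `create_accelerator`, `create_strategy`), and several inline classes move out to dedicated modules (`EnvironmentConfig` to `model/_environment.py`, the checkpoint callback configs to `callbacks/`, `MetricConfig` to `metrics/_config.py`, `CheckpointLoadingConfig` to `trainer/_checkpoint_resolver.py`). A rough sketch of what the rename means for downstream code that defines its own callback config; the subclass and the Timer callback below are hypothetical, assuming only that `CallbackConfigBase` keeps the generator-style hook shown in this diff:

    from typing import Literal

    from lightning.pytorch.callbacks import Timer

    from nshtrainer.callbacks.base import CallbackConfigBase


    class MyTimerCallbackConfig(CallbackConfigBase):
        # Hypothetical user-defined config, following the pattern in this diff.
        kind: Literal["my_timer"] = "my_timer"

        # 0.8.x named this hook `construct_callbacks`; 0.10.0 names it `create_callbacks`.
        def create_callbacks(self, root_config):
            # Yield zero or more Lightning callbacks built from this config.
            yield Timer()
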
@@ -1,7 +1,5 @@
 import copy
 import os
-import re
-import socket
 import string
 import time
 import warnings
@@ -36,10 +34,17 @@ from lightning.pytorch.strategies.strategy import Strategy
 from pydantic import DirectoryPath
 from typing_extensions import Self, TypedDict, TypeVar, override

-from ..callbacks import CallbackConfig
+from ..callbacks import (
+    CallbackConfig,
+    LatestEpochCheckpointCallbackConfig,
+    ModelCheckpointCallbackConfig,
+    OnExceptionCheckpointCallbackConfig,
+    WandbWatchConfig,
+)
 from ..callbacks.base import CallbackConfigBase
-from ..callbacks.wandb_watch import WandbWatchConfig
-from ..util.slurm import parse_slurm_node_list
+from ..metrics import MetricConfig
+from ..trainer._checkpoint_resolver import CheckpointLoadingConfig
+from ._environment import EnvironmentConfig

 log = getLogger(__name__)

@@ -62,7 +67,7 @@ class BaseProfilerConfig(C.Config, ABC):
     """

     @abstractmethod
-    def construct_profiler(self, root_config: "BaseConfig") -> Profiler: ...
+    def create_profiler(self, root_config: "BaseConfig") -> Profiler: ...


 class SimpleProfilerConfig(BaseProfilerConfig):
@@ -75,7 +80,7 @@ class SimpleProfilerConfig(BaseProfilerConfig):
     """

     @override
-    def construct_profiler(self, root_config):
+    def create_profiler(self, root_config):
         from lightning.pytorch.profilers.simple import SimpleProfiler

         if (dirpath := self.dirpath) is None:
@@ -104,7 +109,7 @@ class AdvancedProfilerConfig(BaseProfilerConfig):
     """

     @override
-    def construct_profiler(self, root_config):
+    def create_profiler(self, root_config):
         from lightning.pytorch.profilers.advanced import AdvancedProfiler

         if (dirpath := self.dirpath) is None:
@@ -172,7 +177,7 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
     """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""

     @override
-    def construct_profiler(self, root_config):
+    def create_profiler(self, root_config):
         from lightning.pytorch.profilers.pytorch import PyTorchProfiler

         if (dirpath := self.dirpath) is None:
@@ -203,190 +208,6 @@ ProfilerConfig: TypeAlias = Annotated[
 ]


-class EnvironmentClassInformationConfig(C.Config):
-    name: str
-    module: str
-    full_name: str
-
-    file_path: Path
-    source_file_path: Path | None = None
-
-
-class EnvironmentSLURMInformationConfig(C.Config):
-    hostname: str
-    hostnames: list[str]
-    job_id: str
-    raw_job_id: str
-    array_job_id: str | None
-    array_task_id: str | None
-    num_tasks: int
-    num_nodes: int
-    node: str | int | None
-    global_rank: int
-    local_rank: int
-
-    @classmethod
-    def from_current_environment(cls):
-        try:
-            from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
-
-            if not SLURMEnvironment.detect():
-                return None
-
-            hostname = socket.gethostname()
-            hostnames = [hostname]
-            if node_list := os.environ.get("SLURM_JOB_NODELIST", ""):
-                hostnames = parse_slurm_node_list(node_list)
-
-            raw_job_id = os.environ["SLURM_JOB_ID"]
-            job_id = raw_job_id
-            array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
-            array_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
-            if array_job_id and array_task_id:
-                job_id = f"{array_job_id}_{array_task_id}"
-
-            num_tasks = int(os.environ["SLURM_NTASKS"])
-            num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
-
-            node_id = os.environ.get("SLURM_NODEID")
-
-            global_rank = int(os.environ["SLURM_PROCID"])
-            local_rank = int(os.environ["SLURM_LOCALID"])
-
-            return cls(
-                hostname=hostname,
-                hostnames=hostnames,
-                job_id=job_id,
-                raw_job_id=raw_job_id,
-                array_job_id=array_job_id,
-                array_task_id=array_task_id,
-                num_tasks=num_tasks,
-                num_nodes=num_nodes,
-                node=node_id,
-                global_rank=global_rank,
-                local_rank=local_rank,
-            )
-        except (ImportError, RuntimeError, ValueError, KeyError):
-            return None
-
-
-class EnvironmentLSFInformationConfig(C.Config):
-    hostname: str
-    hostnames: list[str]
-    job_id: str
-    array_job_id: str | None
-    array_task_id: str | None
-    num_tasks: int
-    num_nodes: int
-    node: str | int | None
-    global_rank: int
-    local_rank: int
-
-    @classmethod
-    def from_current_environment(cls):
-        try:
-            import os
-            import socket
-
-            hostname = socket.gethostname()
-            hostnames = [hostname]
-            if node_list := os.environ.get("LSB_HOSTS", ""):
-                hostnames = node_list.split()
-
-            job_id = os.environ["LSB_JOBID"]
-            array_job_id = os.environ.get("LSB_JOBINDEX")
-            array_task_id = os.environ.get("LSB_JOBINDEX")
-
-            num_tasks = int(os.environ.get("LSB_DJOB_NUMPROC", 1))
-            num_nodes = len(set(hostnames))
-
-            node_id = (
-                os.environ.get("LSB_HOSTS", "").split().index(hostname)
-                if "LSB_HOSTS" in os.environ
-                else None
-            )
-
-            # LSF doesn't have direct equivalents for global_rank and local_rank
-            # You might need to calculate these based on your specific setup
-            global_rank = int(os.environ.get("PMI_RANK", 0))
-            local_rank = int(os.environ.get("LSB_RANK", 0))
-
-            return cls(
-                hostname=hostname,
-                hostnames=hostnames,
-                job_id=job_id,
-                array_job_id=array_job_id,
-                array_task_id=array_task_id,
-                num_tasks=num_tasks,
-                num_nodes=num_nodes,
-                node=node_id,
-                global_rank=global_rank,
-                local_rank=local_rank,
-            )
-        except (ImportError, RuntimeError, ValueError, KeyError):
-            return None
-
-
-class EnvironmentLinuxEnvironmentConfig(C.Config):
-    """
-    Information about the Linux environment (e.g., current user, hostname, etc.)
-    """
-
-    user: str | None = None
-    hostname: str | None = None
-    system: str | None = None
-    release: str | None = None
-    version: str | None = None
-    machine: str | None = None
-    processor: str | None = None
-    cpu_count: int | None = None
-    memory: int | None = None
-    uptime: timedelta | None = None
-    boot_time: float | None = None
-    load_avg: tuple[float, float, float] | None = None
-
-
-class EnvironmentSnapshotConfig(C.Config):
-    snapshot_dir: Path | None = None
-    modules: list[str] | None = None
-
-    @classmethod
-    def from_current_environment(cls):
-        draft = cls.draft()
-        if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
-            draft.snapshot_dir = Path(snapshot_dir)
-        if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
-            draft.modules = modules.split(",")
-        return draft.finalize()
-
-
-class EnvironmentConfig(C.Config):
-    cwd: Path | None = None
-
-    snapshot: EnvironmentSnapshotConfig | None = None
-
-    python_executable: Path | None = None
-    python_path: list[Path] | None = None
-    python_version: str | None = None
-
-    config: EnvironmentClassInformationConfig | None = None
-    model: EnvironmentClassInformationConfig | None = None
-    data: EnvironmentClassInformationConfig | None = None
-
-    linux: EnvironmentLinuxEnvironmentConfig | None = None
-
-    slurm: EnvironmentSLURMInformationConfig | None = None
-    lsf: EnvironmentLSFInformationConfig | None = None
-
-    base_dir: Path | None = None
-    log_dir: Path | None = None
-    checkpoint_dir: Path | None = None
-    stdio_dir: Path | None = None
-
-    seed: int | None = None
-    seed_workers: bool | None = None
-
-
 class BaseLoggerConfig(C.Config, ABC):
     enabled: bool = True
     """Enable this logger."""
@@ -398,7 +219,7 @@ class BaseLoggerConfig(C.Config, ABC):
     """Directory to save the logs to. If None, will use the default log directory for the trainer."""

     @abstractmethod
-    def construct_logger(self, root_config: "BaseConfig") -> Logger | None: ...
+    def create_logger(self, root_config: "BaseConfig") -> Logger | None: ...

     def disable_(self):
         self.enabled = False
@@ -466,18 +287,16 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     """Whether to run WandB in offline mode."""

     @override
-    def construct_logger(self, root_config):
+    def create_logger(self, root_config):
         if not self.enabled:
             return None

         from lightning.pytorch.loggers.wandb import WandbLogger

-        save_dir = root_config.directory.resolve_log_directory_for_logger(
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
             root_config.id,
             self,
         )
-        save_dir = save_dir / "wandb"
-        save_dir.mkdir(parents=True, exist_ok=True)
         return WandbLogger(
             save_dir=save_dir,
             project=self.project or _project_name(root_config),
@@ -494,9 +313,9 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
         )

     @override
-    def construct_callbacks(self, root_config):
+    def create_callbacks(self, root_config):
         if self.watch:
-            yield from self.watch.construct_callbacks(root_config)
+            yield from self.watch.create_callbacks(root_config)


 class CSVLoggerConfig(BaseLoggerConfig):
@@ -515,18 +334,16 @@ class CSVLoggerConfig(BaseLoggerConfig):
     """How often to flush logs to disk."""

     @override
-    def construct_logger(self, root_config):
+    def create_logger(self, root_config):
         if not self.enabled:
             return None

         from lightning.pytorch.loggers.csv_logs import CSVLogger

-        save_dir = root_config.directory.resolve_log_directory_for_logger(
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
             root_config.id,
             self,
         )
-        save_dir = save_dir / "csv"
-        save_dir.mkdir(parents=True, exist_ok=True)
         return CSVLogger(
             save_dir=save_dir,
             name=root_config.run_name,
@@ -581,18 +398,16 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
     """A string to put at the beginning of metric keys."""

     @override
-    def construct_logger(self, root_config):
+    def create_logger(self, root_config):
         if not self.enabled:
             return None

         from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

-        save_dir = root_config.directory.resolve_log_directory_for_logger(
+        save_dir = root_config.directory._resolve_log_directory_for_logger(
             root_config.id,
             self,
         )
-        save_dir = save_dir / "tensorboard"
-        save_dir.mkdir(parents=True, exist_ok=True)
         return TensorBoardLogger(
             save_dir=save_dir,
             name=root_config.run_name,
@@ -624,6 +439,9 @@ class LoggingConfig(CallbackConfigBase):
     log_epoch: bool = True
     """If enabled, will log the fractional epoch number to the logger."""

+    actsave_logged_metrics: bool = False
+    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
+
     @property
     def wandb(self) -> WandbLoggerConfig | None:
         return next(
@@ -650,7 +468,7 @@ class LoggingConfig(CallbackConfigBase):
             ),
         )

-    def construct_loggers(self, root_config: "BaseConfig"):
+    def create_loggers(self, root_config: "BaseConfig"):
         """
         Constructs and returns a list of loggers based on the provided root configuration.

@@ -671,13 +489,13 @@ class LoggingConfig(CallbackConfigBase):
         ):
             if not logger_config.enabled:
                 continue
-            if (logger := logger_config.construct_logger(root_config)) is None:
+            if (logger := logger_config.create_logger(root_config)) is None:
                 continue
             loggers.append(logger)
         return loggers

     @override
-    def construct_callbacks(self, root_config):
+    def create_callbacks(self, root_config):
         if self.log_lr:
             from lightning.pytorch.callbacks import LearningRateMonitor

@@ -696,7 +514,7 @@ class LoggingConfig(CallbackConfigBase):
             if not logger or not isinstance(logger, CallbackConfigBase):
                 continue

-            yield from logger.construct_callbacks(root_config)
+            yield from logger.create_callbacks(root_config)


 class GradientClippingConfig(C.Config):
@@ -723,7 +541,7 @@ class OptimizationConfig(CallbackConfigBase):
     """Gradient clipping configuration, or None to disable gradient clipping."""

     @override
-    def construct_callbacks(self, root_config):
+    def create_callbacks(self, root_config):
         from ..callbacks.norm_logging import NormLoggingConfig

         yield from NormLoggingConfig(
@@ -731,7 +549,7 @@ class OptimizationConfig(CallbackConfigBase):
             log_grad_norm_per_param=self.log_grad_norm_per_param,
             log_param_norm=self.log_param_norm,
             log_param_norm_per_param=self.log_param_norm_per_param,
-        ).construct_callbacks(root_config)
+        ).create_callbacks(root_config)


 TPlugin = TypeVar(
@@ -746,17 +564,17 @@ TPlugin = TypeVar(

 @runtime_checkable
 class PluginConfigProtocol(Protocol[TPlugin]):
-    def construct_plugin(self) -> TPlugin: ...
+    def create_plugin(self) -> TPlugin: ...


 @runtime_checkable
 class AcceleratorConfigProtocol(Protocol):
-    def construct_accelerator(self) -> Accelerator: ...
+    def create_accelerator(self) -> Accelerator: ...


 @runtime_checkable
 class StrategyConfigProtocol(Protocol):
-    def construct_strategy(self) -> Strategy: ...
+    def create_strategy(self) -> Strategy: ...


 AcceleratorLiteral: TypeAlias = Literal[
@@ -793,16 +611,34 @@ StrategyLiteral: TypeAlias = Literal[
 ]


-class CheckpointLoadingConfig(C.Config):
-    path: Literal["best", "last", "hpc"] | str | Path | None = None
-    """
-    Checkpoint path to use when loading a checkpoint.
+def _create_symlink_to_nshrunner(base_dir: Path):
+    # Resolve the current nshrunner session directory
+    if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
+        log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
+        return
+    session_dir = Path(session_dir)
+    if not session_dir.exists() or not session_dir.is_dir():
+        log.warning(
+            f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
+            "Skipping symlink creation."
+        )
+        return

-    - "best" will load the best checkpoint.
-    - "last" will load the last checkpoint.
-    - "hpc" will load the SLURM pre-empted checkpoint.
-    - Any other string or Path will load the checkpoint from the specified path.
-    """
+    # Create the symlink
+    symlink_path = base_dir / "nshrunner"
+    if symlink_path.exists():
+        # If it already points to the correct directory, we're done
+        if symlink_path.resolve() == session_dir.resolve():
+            return
+
+        # Otherwise, we should log a warning and remove the existing symlink
+        log.warning(
+            f"A symlink pointing to {symlink_path.resolve()} already exists at {symlink_path}. "
+            "Removing the existing symlink."
+        )
+        symlink_path.unlink()
+
+    symlink_path.symlink_to(session_dir)


 class DirectoryConfig(C.Config):
@@ -813,30 +649,33 @@ class DirectoryConfig(C.Config):
     This isn't specific to the run; it is the parent directory of all runs.
     """

+    create_symlink_to_nshrunner_root: bool = True
+    """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
+
     log: Path | None = None
-    """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use lltrainer/{id}/log/."""
+    """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""

     stdio: Path | None = None
-    """stdout/stderr log directory to use for the trainer. If None, will use lltrainer/{id}/stdio/."""
+    """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""

     checkpoint: Path | None = None
-    """Checkpoint directory to use for the trainer. If None, will use lltrainer/{id}/checkpoint/."""
+    """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""

     activation: Path | None = None
-    """Activation directory to use for the trainer. If None, will use lltrainer/{id}/activation/."""
+    """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""

     profile: Path | None = None
-    """Directory to save profiling information to. If None, will use lltrainer/{id}/profile/."""
+    """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""

     def resolve_run_root_directory(self, run_id: str) -> Path:
         if (project_root_dir := self.project_root) is None:
             project_root_dir = Path.cwd()

-        # The default base dir is $CWD/lltrainer/{id}/
-        base_dir = project_root_dir / "lltrainer"
+        # The default base dir is $CWD/nshtrainer/{id}/
+        base_dir = project_root_dir / "nshtrainer"
         base_dir.mkdir(exist_ok=True)

-        # Add a .gitignore file to the lltrainer directory
+        # Add a .gitignore file to the nshtrainer directory
         # which will ignore all files except for the .gitignore file itself
         gitignore_path = base_dir / ".gitignore"
         if not gitignore_path.exists():
@@ -846,6 +685,10 @@ class DirectoryConfig(C.Config):
         base_dir = base_dir / run_id
         base_dir.mkdir(exist_ok=True)

+        # Create a symlink to the root folder for the Runner
+        if self.create_symlink_to_nshrunner_root:
+            _create_symlink_to_nshrunner(base_dir)
+
         return base_dir

     def resolve_subdirectory(
@@ -854,7 +697,7 @@ class DirectoryConfig(C.Config):
         # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
         subdirectory: str,
     ) -> Path:
-        # The subdir will be $CWD/lltrainer/{id}/{log, stdio, checkpoint, activation}/
+        # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
         if (subdir := getattr(self, subdirectory, None)) is not None:
             assert isinstance(
@@ -866,7 +709,7 @@ class DirectoryConfig(C.Config):
         dir.mkdir(exist_ok=True)
         return dir

-    def resolve_log_directory_for_logger(
+    def _resolve_log_directory_for_logger(
         self,
         run_id: str,
         logger: LoggerConfig,
@@ -874,9 +717,10 @@ class DirectoryConfig(C.Config):
         if (log_dir := logger.log_dir) is not None:
             return log_dir

-        # Save to lltrainer/{id}/log/{logger kind}/{id}/
+        # Save to nshtrainer/{id}/log/{logger kind}
         log_dir = self.resolve_subdirectory(run_id, "log")
         log_dir = log_dir / logger.kind
+        log_dir.mkdir(exist_ok=True)

         return log_dir

@@ -890,208 +734,6 @@ class ReproducibilityConfig(C.Config):
     """


-class ModelCheckpointCallbackConfig(CallbackConfigBase):
-    """Arguments for the ModelCheckpoint callback."""
-
-    kind: Literal["model_checkpoint"] = "model_checkpoint"
-
-    dirpath: str | Path | None = None
-    """
-    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
-    """
-
-    filename: str | None = None
-    """
-    Checkpoint filename.
-    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
-    """
-
-    monitor: str | None = None
-    """
-    Quantity to monitor for saving checkpoints.
-    If None, no metric is monitored and checkpoints are saved at the end of every epoch.
-    """
-
-    verbose: bool = False
-    """Verbosity mode. If True, print additional information about checkpoints."""
-
-    save_last: Literal[True, False, "link"] | None = "link"
-    """
-    Whether to save the last checkpoint.
-    If True, saves a copy of the last checkpoint separately.
-    If "link", creates a symbolic link to the last checkpoint.
-    """
-
-    save_top_k: int = 1
-    """
-    Number of best models to save.
-    If -1, all models are saved.
-    If 0, no models are saved.
-    """
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    mode: str = "min"
-    """
-    One of "min" or "max".
-    If "min", training will stop when the metric monitored has stopped decreasing.
-    If "max", training will stop when the metric monitored has stopped increasing.
-    """
-
-    auto_insert_metric_name: bool = True
-    """Whether to automatically insert the metric name in the checkpoint filename."""
-
-    every_n_train_steps: int | None = None
-    """
-    Number of training steps between checkpoints.
-    If None or 0, no checkpoints are saved during training.
-    """
-
-    train_time_interval: timedelta | None = None
-    """
-    Time interval between checkpoints during training.
-    If None, no checkpoints are saved during training based on time.
-    """
-
-    every_n_epochs: int | None = None
-    """
-    Number of epochs between checkpoints.
-    If None or 0, no checkpoints are saved at the end of epochs.
-    """
-
-    save_on_train_epoch_end: bool | None = None
-    """
-    Whether to run checkpointing at the end of the training epoch.
-    If False, checkpointing runs at the end of the validation.
-    """
-
-    enable_version_counter: bool = True
-    """Whether to append a version to the existing file name."""
-
-    auto_append_metric: bool = True
-    """If enabled, this will automatically add "-{monitor}" to the filename."""
-
-    @staticmethod
-    def _convert_string(input_string: str):
-        # Find all variables enclosed in curly braces
-        variables = re.findall(r"\{(.*?)\}", input_string)
-
-        # Replace each variable with its corresponding key-value pair
-        output_string = input_string
-        for variable in variables:
-            # If the name is something like {variable:format}, we shouldn't process the format.
-            key_name = variable
-            if ":" in variable:
-                key_name, _ = variable.split(":", 1)
-                continue
-
-            # Replace '/' with '_' in the key name
-            key_name = key_name.replace("/", "_")
-            output_string = output_string.replace(
-                f"{{{variable}}}", f"{key_name}={{{variable}}}"
-            )
-
-        return output_string
-
-    @override
-    def construct_callbacks(self, root_config):
-        from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
-
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        # If `monitor` is not provided, we can use `config.primary_metric` if it is set.
-        monitor = self.monitor
-        mode = self.mode
-        if (
-            monitor is None
-            and (primary_metric := root_config.primary_metric) is not None
-        ):
-            monitor = primary_metric.validation_monitor
-            mode = primary_metric.mode
-
-        filename = self.filename
-        if self.auto_append_metric:
-            if not filename:
-                filename = "{epoch}-{step}"
-            filename = f"{filename}-{{{monitor}}}"
-
-        if self.auto_insert_metric_name and filename:
-            new_filename = self._convert_string(filename)
-            log.critical(
-                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
-            )
-            filename = new_filename
-
-        yield ModelCheckpoint(
-            dirpath=dirpath,
-            filename=filename,
-            monitor=monitor,
-            mode=mode,
-            verbose=self.verbose,
-            save_last=self.save_last,
-            save_top_k=self.save_top_k,
-            save_weights_only=self.save_weights_only,
-            auto_insert_metric_name=False,
-            every_n_train_steps=self.every_n_train_steps,
-            train_time_interval=self.train_time_interval,
-            every_n_epochs=self.every_n_epochs,
-            save_on_train_epoch_end=self.save_on_train_epoch_end,
-            enable_version_counter=self.enable_version_counter,
-        )
-
-
-class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
-    kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str | None = None
-    """Checkpoint filename. This must not include the extension. If `None`, `latest_epoch_{id}_{timestamp}` is used."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    @override
-    def construct_callbacks(self, root_config):
-        from ..callbacks.latest_epoch_checkpoint import LatestEpochCheckpoint
-
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        yield LatestEpochCheckpoint(
-            dirpath=dirpath,
-            filename=self.filename,
-            save_weights_only=self.save_weights_only,
-        )
-
-
-class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
-    kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
-
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str | None = None
-    """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
-
-    @override
-    def construct_callbacks(self, root_config):
-        from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
-
-        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-            root_config.id, "checkpoint"
-        )
-
-        if not (filename := self.filename):
-            filename = f"on_exception_{root_config.id}"
-        yield OnExceptionCheckpoint(dirpath=dirpath, filename=filename)
-
-
 CheckpointCallbackConfig: TypeAlias = Annotated[
     ModelCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
@@ -1155,12 +797,12 @@ class CheckpointSavingConfig(CallbackConfigBase):
         )

     @override
-    def construct_callbacks(self, root_config: "BaseConfig"):
+    def create_callbacks(self, root_config: "BaseConfig"):
         if not self.should_save_checkpoints(root_config):
             return

         for callback_config in self.checkpoint_callbacks:
-            yield from callback_config.construct_callbacks(root_config)
+            yield from callback_config.create_callbacks(root_config)


 class LightningTrainerKwargs(TypedDict, total=False):
@@ -1437,7 +1079,7 @@ class EarlyStoppingConfig(CallbackConfigBase):
     """

     @override
-    def construct_callbacks(self, root_config: "BaseConfig"):
+    def create_callbacks(self, root_config: "BaseConfig"):
         from ..callbacks.early_stopping import EarlyStopping

         monitor = self.monitor
@@ -1468,32 +1110,6 @@ class EarlyStoppingConfig(CallbackConfigBase):
 ]


-class ActSaveConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable activation saving."""
-
-    auto_save_logged_metrics: bool = False
-    """If enabled, will automatically save logged metrics (using `LightningModule.log`) as activations."""
-
-    save_dir: Path | None = None
-    """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""
-
-    def __bool__(self):
-        return self.enabled
-
-    def resolve_save_dir(self, root_config: "BaseConfig"):
-        if self.save_dir is not None:
-            return self.save_dir
-
-        return root_config.directory.resolve_subdirectory(root_config.id, "activation")
-
-    @override
-    def construct_callbacks(self, root_config):
-        from ..actsave import ActSaveCallback
-
-        return [ActSaveCallback()]
-
-
 class SanityCheckingConfig(C.Config):
     reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
     """
@@ -1505,7 +1121,7 @@ class SanityCheckingConfig(C.Config):


 class TrainerConfig(C.Config):
-    checkpoint_loading: CheckpointLoadingConfig = CheckpointLoadingConfig()
+    checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
     """Checkpoint loading configuration options."""

     checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
@@ -1523,9 +1139,6 @@ class TrainerConfig(C.Config):
     sanity_checking: SanityCheckingConfig = SanityCheckingConfig()
     """Sanity checking configuration options."""

-    actsave: ActSaveConfig | None = ActSaveConfig(enabled=False)
-    """Activation saving configuration options."""
-
     early_stopping: EarlyStoppingConfig | None = None
     """Early stopping configuration options."""

@@ -1694,12 +1307,12 @@ class TrainerConfig(C.Config):
     automatic selection based on the chosen accelerator. Default: ``"auto"``.
     """

-    auto_wrap_trainer: bool = True
-    """If enabled, will automatically wrap the `run` function with a `Trainer.context()` context manager. Should be `True` most of the time."""
     auto_set_default_root_dir: bool = True
     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
     supports_shared_parameters: bool = True
     """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
+    save_checkpoint_metadata: bool = True
+    """If enabled, will save additional metadata whenever a checkpoint is saved."""

     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
     """
@@ -1719,35 +1332,6 @@ class TrainerConfig(C.Config):
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""


-class MetricConfig(C.Config):
-    name: str
-    """The name of the primary metric."""
-
-    mode: Literal["min", "max"]
-    """
-    The mode of the primary metric:
-    - "min" for metrics that should be minimized (e.g., loss)
-    - "max" for metrics that should be maximized (e.g., accuracy)
-    """
-
-    @property
-    def validation_monitor(self) -> str:
-        return f"val/{self.name}"
-
-    def __post_init__(self):
-        for split in ("train", "val", "test", "predict"):
-            if self.name.startswith(f"{split}/"):
-                raise ValueError(
-                    f"Primary metric name should not start with '{split}/'. "
-                    f"Just use '{self.name[len(split) + 1:]}' instead. "
-                    "The split name is automatically added depending on the context."
-                )
-
-    @classmethod
-    def loss(cls, mode: Literal["min", "max"] = "min"):
-        return cls(name="loss", mode=mode)
-
-
 PrimaryMetricConfig: TypeAlias = MetricConfig


@@ -1767,7 +1351,9 @@ class BaseConfig(C.Config):

     debug: bool = False
     """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
-    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = EnvironmentConfig()
+    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = (
+        EnvironmentConfig.empty()
+    )
     """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""

     directory: DirectoryConfig = DirectoryConfig()
@@ -1855,7 +1441,7 @@ class BaseConfig(C.Config):
             self.directory = DirectoryConfig()

         if environment:
-            self.environment = EnvironmentConfig()
+            self.environment = EnvironmentConfig.empty()

         if meta:
             self.meta = {}
@@ -1953,8 +1539,7 @@ class BaseConfig(C.Config):
         )
         return cls.model_validate(hparams)

-    def ll_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
-        yield self.trainer.actsave
+    def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
         yield self.trainer.early_stopping
         yield self.trainer.checkpoint_saving
         yield self.trainer.logging
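
Note that `MetricConfig` is not removed outright; per the file list above it moves to `nshtrainer/metrics/_config.py` and is imported back into this module via `from ..metrics import MetricConfig`. Assuming the moved class keeps the interface shown in the removed block (the `name` and `mode` fields, `validation_monitor`, and the `loss()` constructor), only the import path changes for downstream code; a small unverified sketch:

    from nshtrainer.metrics import MetricConfig

    # "val/" is prepended automatically by `validation_monitor`.
    metric = MetricConfig(name="accuracy", mode="max")
    assert metric.validation_monitor == "val/accuracy"

    # Convenience constructor for a minimized loss metric.
    loss_metric = MetricConfig.loss()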