nshtrainer 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/callbacks/__init__.py +17 -1
  3. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  4. nshtrainer/callbacks/base.py +7 -5
  5. nshtrainer/callbacks/ema.py +1 -1
  6. nshtrainer/callbacks/finite_checks.py +1 -1
  7. nshtrainer/callbacks/gradient_skipping.py +1 -1
  8. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  9. nshtrainer/callbacks/model_checkpoint.py +187 -0
  10. nshtrainer/callbacks/norm_logging.py +1 -1
  11. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  12. nshtrainer/callbacks/print_table.py +1 -1
  13. nshtrainer/callbacks/throughput_monitor.py +1 -1
  14. nshtrainer/callbacks/timer.py +1 -1
  15. nshtrainer/callbacks/wandb_watch.py +1 -1
  16. nshtrainer/ll/__init__.py +0 -1
  17. nshtrainer/ll/actsave.py +2 -1
  18. nshtrainer/metrics/__init__.py +1 -0
  19. nshtrainer/metrics/_config.py +37 -0
  20. nshtrainer/model/__init__.py +11 -11
  21. nshtrainer/model/_environment.py +777 -0
  22. nshtrainer/model/base.py +5 -114
  23. nshtrainer/model/config.py +49 -501
  24. nshtrainer/model/modules/logger.py +11 -6
  25. nshtrainer/runner.py +3 -6
  26. nshtrainer/trainer/_checkpoint_metadata.py +102 -0
  27. nshtrainer/trainer/_checkpoint_resolver.py +319 -0
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
@@ -1,7 +1,5 @@
  import copy
  import os
- import re
- import socket
  import string
  import time
  import warnings
@@ -36,10 +34,17 @@ from lightning.pytorch.strategies.strategy import Strategy
  from pydantic import DirectoryPath
  from typing_extensions import Self, TypedDict, TypeVar, override

- from ..callbacks import CallbackConfig
+ from ..callbacks import (
+     CallbackConfig,
+     LatestEpochCheckpointCallbackConfig,
+     ModelCheckpointCallbackConfig,
+     OnExceptionCheckpointCallbackConfig,
+     WandbWatchConfig,
+ )
  from ..callbacks.base import CallbackConfigBase
- from ..callbacks.wandb_watch import WandbWatchConfig
- from ..util.slurm import parse_slurm_node_list
+ from ..metrics import MetricConfig
+ from ..trainer._checkpoint_resolver import CheckpointLoadingConfig
+ from ._environment import EnvironmentConfig

  log = getLogger(__name__)

@@ -62,7 +67,7 @@ class BaseProfilerConfig(C.Config, ABC):
      """

      @abstractmethod
-     def construct_profiler(self, root_config: "BaseConfig") -> Profiler: ...
+     def create_profiler(self, root_config: "BaseConfig") -> Profiler: ...


  class SimpleProfilerConfig(BaseProfilerConfig):
@@ -75,7 +80,7 @@ class SimpleProfilerConfig(BaseProfilerConfig):
      """

      @override
-     def construct_profiler(self, root_config):
+     def create_profiler(self, root_config):
          from lightning.pytorch.profilers.simple import SimpleProfiler

          if (dirpath := self.dirpath) is None:
@@ -104,7 +109,7 @@ class AdvancedProfilerConfig(BaseProfilerConfig):
      """

      @override
-     def construct_profiler(self, root_config):
+     def create_profiler(self, root_config):
          from lightning.pytorch.profilers.advanced import AdvancedProfiler

          if (dirpath := self.dirpath) is None:
@@ -172,7 +177,7 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
      """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""

      @override
-     def construct_profiler(self, root_config):
+     def create_profiler(self, root_config):
          from lightning.pytorch.profilers.pytorch import PyTorchProfiler

          if (dirpath := self.dirpath) is None:
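Every profiler config in this file renames its factory hook from construct_profiler to create_profiler. A minimal sketch of what a downstream subclass looks like after the rename; the class name and the PassThroughProfiler choice are illustrative and assume BaseProfilerConfig from this module is in scope:

    from typing_extensions import override
    from lightning.pytorch.profilers import PassThroughProfiler

    class NoOpProfilerConfig(BaseProfilerConfig):  # hypothetical subclass
        @override
        def create_profiler(self, root_config):  # was `construct_profiler` in 0.9.1
            # Any Lightning Profiler instance works here; PassThroughProfiler does nothing.
            return PassThroughProfiler()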
@@ -203,190 +208,6 @@ ProfilerConfig: TypeAlias = Annotated[
  ]


- class EnvironmentClassInformationConfig(C.Config):
-     name: str
-     module: str
-     full_name: str
-
-     file_path: Path
-     source_file_path: Path | None = None
-
-
- class EnvironmentSLURMInformationConfig(C.Config):
-     hostname: str
-     hostnames: list[str]
-     job_id: str
-     raw_job_id: str
-     array_job_id: str | None
-     array_task_id: str | None
-     num_tasks: int
-     num_nodes: int
-     node: str | int | None
-     global_rank: int
-     local_rank: int
-
-     @classmethod
-     def from_current_environment(cls):
-         try:
-             from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
-
-             if not SLURMEnvironment.detect():
-                 return None
-
-             hostname = socket.gethostname()
-             hostnames = [hostname]
-             if node_list := os.environ.get("SLURM_JOB_NODELIST", ""):
-                 hostnames = parse_slurm_node_list(node_list)
-
-             raw_job_id = os.environ["SLURM_JOB_ID"]
-             job_id = raw_job_id
-             array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
-             array_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
-             if array_job_id and array_task_id:
-                 job_id = f"{array_job_id}_{array_task_id}"
-
-             num_tasks = int(os.environ["SLURM_NTASKS"])
-             num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
-
-             node_id = os.environ.get("SLURM_NODEID")
-
-             global_rank = int(os.environ["SLURM_PROCID"])
-             local_rank = int(os.environ["SLURM_LOCALID"])
-
-             return cls(
-                 hostname=hostname,
-                 hostnames=hostnames,
-                 job_id=job_id,
-                 raw_job_id=raw_job_id,
-                 array_job_id=array_job_id,
-                 array_task_id=array_task_id,
-                 num_tasks=num_tasks,
-                 num_nodes=num_nodes,
-                 node=node_id,
-                 global_rank=global_rank,
-                 local_rank=local_rank,
-             )
-         except (ImportError, RuntimeError, ValueError, KeyError):
-             return None
-
-
- class EnvironmentLSFInformationConfig(C.Config):
-     hostname: str
-     hostnames: list[str]
-     job_id: str
-     array_job_id: str | None
-     array_task_id: str | None
-     num_tasks: int
-     num_nodes: int
-     node: str | int | None
-     global_rank: int
-     local_rank: int
-
-     @classmethod
-     def from_current_environment(cls):
-         try:
-             import os
-             import socket
-
-             hostname = socket.gethostname()
-             hostnames = [hostname]
-             if node_list := os.environ.get("LSB_HOSTS", ""):
-                 hostnames = node_list.split()
-
-             job_id = os.environ["LSB_JOBID"]
-             array_job_id = os.environ.get("LSB_JOBINDEX")
-             array_task_id = os.environ.get("LSB_JOBINDEX")
-
-             num_tasks = int(os.environ.get("LSB_DJOB_NUMPROC", 1))
-             num_nodes = len(set(hostnames))
-
-             node_id = (
-                 os.environ.get("LSB_HOSTS", "").split().index(hostname)
-                 if "LSB_HOSTS" in os.environ
-                 else None
-             )
-
-             # LSF doesn't have direct equivalents for global_rank and local_rank
-             # You might need to calculate these based on your specific setup
-             global_rank = int(os.environ.get("PMI_RANK", 0))
-             local_rank = int(os.environ.get("LSB_RANK", 0))
-
-             return cls(
-                 hostname=hostname,
-                 hostnames=hostnames,
-                 job_id=job_id,
-                 array_job_id=array_job_id,
-                 array_task_id=array_task_id,
-                 num_tasks=num_tasks,
-                 num_nodes=num_nodes,
-                 node=node_id,
-                 global_rank=global_rank,
-                 local_rank=local_rank,
-             )
-         except (ImportError, RuntimeError, ValueError, KeyError):
-             return None
-
-
- class EnvironmentLinuxEnvironmentConfig(C.Config):
-     """
-     Information about the Linux environment (e.g., current user, hostname, etc.)
-     """
-
-     user: str | None = None
-     hostname: str | None = None
-     system: str | None = None
-     release: str | None = None
-     version: str | None = None
-     machine: str | None = None
-     processor: str | None = None
-     cpu_count: int | None = None
-     memory: int | None = None
-     uptime: timedelta | None = None
-     boot_time: float | None = None
-     load_avg: tuple[float, float, float] | None = None
-
-
- class EnvironmentSnapshotConfig(C.Config):
-     snapshot_dir: Path | None = None
-     modules: list[str] | None = None
-
-     @classmethod
-     def from_current_environment(cls):
-         draft = cls.draft()
-         if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
-             draft.snapshot_dir = Path(snapshot_dir)
-         if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
-             draft.modules = modules.split(",")
-         return draft.finalize()
-
-
- class EnvironmentConfig(C.Config):
-     cwd: Path | None = None
-
-     snapshot: EnvironmentSnapshotConfig | None = None
-
-     python_executable: Path | None = None
-     python_path: list[Path] | None = None
-     python_version: str | None = None
-
-     config: EnvironmentClassInformationConfig | None = None
-     model: EnvironmentClassInformationConfig | None = None
-     data: EnvironmentClassInformationConfig | None = None
-
-     linux: EnvironmentLinuxEnvironmentConfig | None = None
-
-     slurm: EnvironmentSLURMInformationConfig | None = None
-     lsf: EnvironmentLSFInformationConfig | None = None
-
-     base_dir: Path | None = None
-     log_dir: Path | None = None
-     checkpoint_dir: Path | None = None
-     stdio_dir: Path | None = None
-
-     seed: int | None = None
-     seed_workers: bool | None = None
-
-
  class BaseLoggerConfig(C.Config, ABC):
      enabled: bool = True
      """Enable this logger."""
@@ -398,7 +219,7 @@ class BaseLoggerConfig(C.Config, ABC):
      """Directory to save the logs to. If None, will use the default log directory for the trainer."""

      @abstractmethod
-     def construct_logger(self, root_config: "BaseConfig") -> Logger | None: ...
+     def create_logger(self, root_config: "BaseConfig") -> Logger | None: ...

      def disable_(self):
          self.enabled = False
@@ -466,18 +287,16 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
      """Whether to run WandB in offline mode."""

      @override
-     def construct_logger(self, root_config):
+     def create_logger(self, root_config):
          if not self.enabled:
              return None

          from lightning.pytorch.loggers.wandb import WandbLogger

-         save_dir = root_config.directory.resolve_log_directory_for_logger(
+         save_dir = root_config.directory._resolve_log_directory_for_logger(
              root_config.id,
              self,
          )
-         save_dir = save_dir / "wandb"
-         save_dir.mkdir(parents=True, exist_ok=True)
          return WandbLogger(
              save_dir=save_dir,
              project=self.project or _project_name(root_config),
@@ -494,9 +313,9 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
          )

      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          if self.watch:
-             yield from self.watch.construct_callbacks(root_config)
+             yield from self.watch.create_callbacks(root_config)


  class CSVLoggerConfig(BaseLoggerConfig):
@@ -515,18 +334,16 @@ class CSVLoggerConfig(BaseLoggerConfig):
      """How often to flush logs to disk."""

      @override
-     def construct_logger(self, root_config):
+     def create_logger(self, root_config):
          if not self.enabled:
              return None

          from lightning.pytorch.loggers.csv_logs import CSVLogger

-         save_dir = root_config.directory.resolve_log_directory_for_logger(
+         save_dir = root_config.directory._resolve_log_directory_for_logger(
              root_config.id,
              self,
          )
-         save_dir = save_dir / "csv"
-         save_dir.mkdir(parents=True, exist_ok=True)
          return CSVLogger(
              save_dir=save_dir,
              name=root_config.run_name,
@@ -581,18 +398,16 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
      """A string to put at the beginning of metric keys."""

      @override
-     def construct_logger(self, root_config):
+     def create_logger(self, root_config):
          if not self.enabled:
              return None

          from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

-         save_dir = root_config.directory.resolve_log_directory_for_logger(
+         save_dir = root_config.directory._resolve_log_directory_for_logger(
              root_config.id,
              self,
          )
-         save_dir = save_dir / "tensorboard"
-         save_dir.mkdir(parents=True, exist_ok=True)
          return TensorBoardLogger(
              save_dir=save_dir,
              name=root_config.run_name,
@@ -624,6 +439,9 @@ class LoggingConfig(CallbackConfigBase):
      log_epoch: bool = True
      """If enabled, will log the fractional epoch number to the logger."""

+     actsave_logged_metrics: bool = False
+     """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
+
      @property
      def wandb(self) -> WandbLoggerConfig | None:
          return next(
@@ -650,7 +468,7 @@ class LoggingConfig(CallbackConfigBase):
              ),
          )

-     def construct_loggers(self, root_config: "BaseConfig"):
+     def create_loggers(self, root_config: "BaseConfig"):
          """
          Constructs and returns a list of loggers based on the provided root configuration.

@@ -671,13 +489,13 @@ class LoggingConfig(CallbackConfigBase):
          ):
              if not logger_config.enabled:
                  continue
-             if (logger := logger_config.construct_logger(root_config)) is None:
+             if (logger := logger_config.create_logger(root_config)) is None:
                  continue
              loggers.append(logger)
          return loggers

      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          if self.log_lr:
              from lightning.pytorch.callbacks import LearningRateMonitor

@@ -696,7 +514,7 @@ class LoggingConfig(CallbackConfigBase):
            if not logger or not isinstance(logger, CallbackConfigBase):
                continue

-           yield from logger.construct_callbacks(root_config)
+           yield from logger.create_callbacks(root_config)


  class GradientClippingConfig(C.Config):
@@ -723,7 +541,7 @@ class OptimizationConfig(CallbackConfigBase):
      """Gradient clipping configuration, or None to disable gradient clipping."""

      @override
-     def construct_callbacks(self, root_config):
+     def create_callbacks(self, root_config):
          from ..callbacks.norm_logging import NormLoggingConfig

          yield from NormLoggingConfig(
@@ -731,7 +549,7 @@ class OptimizationConfig(CallbackConfigBase):
              log_grad_norm_per_param=self.log_grad_norm_per_param,
              log_param_norm=self.log_param_norm,
              log_param_norm_per_param=self.log_param_norm_per_param,
-         ).construct_callbacks(root_config)
+         ).create_callbacks(root_config)


  TPlugin = TypeVar(
@@ -746,17 +564,17 @@ TPlugin = TypeVar(

  @runtime_checkable
  class PluginConfigProtocol(Protocol[TPlugin]):
-     def construct_plugin(self) -> TPlugin: ...
+     def create_plugin(self) -> TPlugin: ...


  @runtime_checkable
  class AcceleratorConfigProtocol(Protocol):
-     def construct_accelerator(self) -> Accelerator: ...
+     def create_accelerator(self) -> Accelerator: ...


  @runtime_checkable
  class StrategyConfigProtocol(Protocol):
-     def construct_strategy(self) -> Strategy: ...
+     def create_strategy(self) -> Strategy: ...


  AcceleratorLiteral: TypeAlias = Literal[
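The plugin, accelerator, and strategy protocols follow the same construct_* to create_* rename. A minimal sketch of a config object that satisfies the renamed AcceleratorConfigProtocol; the class itself is hypothetical, and AcceleratorConfigProtocol is assumed to be in scope from this module:

    from lightning.pytorch.accelerators import CPUAccelerator

    class CPUAcceleratorConfig:  # hypothetical; any object with create_accelerator() qualifies
        def create_accelerator(self) -> CPUAccelerator:  # was `construct_accelerator` in 0.9.1
            return CPUAccelerator()

    # Runtime-checkable protocols only check for the presence of the method.
    assert isinstance(CPUAcceleratorConfig(), AcceleratorConfigProtocol)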
@@ -793,18 +611,6 @@ StrategyLiteral: TypeAlias = Literal[
  ]


- class CheckpointLoadingConfig(C.Config):
-     path: Literal["best", "last", "hpc"] | str | Path | None = None
-     """
-     Checkpoint path to use when loading a checkpoint.
-
-     - "best" will load the best checkpoint.
-     - "last" will load the last checkpoint.
-     - "hpc" will load the SLURM pre-empted checkpoint.
-     - Any other string or Path will load the checkpoint from the specified path.
-     """
-
-
  def _create_symlink_to_nshrunner(base_dir: Path):
      # Resolve the current nshrunner session directory
      if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
@@ -903,7 +709,7 @@ class DirectoryConfig(C.Config):
          dir.mkdir(exist_ok=True)
          return dir

-     def resolve_log_directory_for_logger(
+     def _resolve_log_directory_for_logger(
          self,
          run_id: str,
          logger: LoggerConfig,
@@ -911,9 +717,10 @@ class DirectoryConfig(C.Config):
          if (log_dir := logger.log_dir) is not None:
              return log_dir

-         # Save to nshtrainer/{id}/log/{logger kind}/{id}/
+         # Save to nshtrainer/{id}/log/{logger kind}
          log_dir = self.resolve_subdirectory(run_id, "log")
          log_dir = log_dir / logger.kind
+         log_dir.mkdir(exist_ok=True)

          return log_dir

@@ -927,208 +734,6 @@ class ReproducibilityConfig(C.Config):
      """


- class ModelCheckpointCallbackConfig(CallbackConfigBase):
-     """Arguments for the ModelCheckpoint callback."""
-
-     kind: Literal["model_checkpoint"] = "model_checkpoint"
-
-     dirpath: str | Path | None = None
-     """
-     Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
-     """
-
-     filename: str | None = None
-     """
-     Checkpoint filename.
-     If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
-     """
-
-     monitor: str | None = None
-     """
-     Quantity to monitor for saving checkpoints.
-     If None, no metric is monitored and checkpoints are saved at the end of every epoch.
-     """
-
-     verbose: bool = False
-     """Verbosity mode. If True, print additional information about checkpoints."""
-
-     save_last: Literal[True, False, "link"] | None = "link"
-     """
-     Whether to save the last checkpoint.
-     If True, saves a copy of the last checkpoint separately.
-     If "link", creates a symbolic link to the last checkpoint.
-     """
-
-     save_top_k: int = 1
-     """
-     Number of best models to save.
-     If -1, all models are saved.
-     If 0, no models are saved.
-     """
-
-     save_weights_only: bool = False
-     """Whether to save only the model's weights or the entire model object."""
-
-     mode: str = "min"
-     """
-     One of "min" or "max".
-     If "min", training will stop when the metric monitored has stopped decreasing.
-     If "max", training will stop when the metric monitored has stopped increasing.
-     """
-
-     auto_insert_metric_name: bool = True
-     """Whether to automatically insert the metric name in the checkpoint filename."""
-
-     every_n_train_steps: int | None = None
-     """
-     Number of training steps between checkpoints.
-     If None or 0, no checkpoints are saved during training.
-     """
-
-     train_time_interval: timedelta | None = None
-     """
-     Time interval between checkpoints during training.
-     If None, no checkpoints are saved during training based on time.
-     """
-
-     every_n_epochs: int | None = None
-     """
-     Number of epochs between checkpoints.
-     If None or 0, no checkpoints are saved at the end of epochs.
-     """
-
-     save_on_train_epoch_end: bool | None = None
-     """
-     Whether to run checkpointing at the end of the training epoch.
-     If False, checkpointing runs at the end of the validation.
-     """
-
-     enable_version_counter: bool = True
-     """Whether to append a version to the existing file name."""
-
-     auto_append_metric: bool = True
-     """If enabled, this will automatically add "-{monitor}" to the filename."""
-
-     @staticmethod
-     def _convert_string(input_string: str):
-         # Find all variables enclosed in curly braces
-         variables = re.findall(r"\{(.*?)\}", input_string)
-
-         # Replace each variable with its corresponding key-value pair
-         output_string = input_string
-         for variable in variables:
-             # If the name is something like {variable:format}, we shouldn't process the format.
-             key_name = variable
-             if ":" in variable:
-                 key_name, _ = variable.split(":", 1)
-                 continue
-
-             # Replace '/' with '_' in the key name
-             key_name = key_name.replace("/", "_")
-             output_string = output_string.replace(
-                 f"{{{variable}}}", f"{key_name}={{{variable}}}"
-             )
-
-         return output_string
-
-     @override
-     def construct_callbacks(self, root_config):
-         from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
-
-         dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-             root_config.id, "checkpoint"
-         )
-
-         # If `monitor` is not provided, we can use `config.primary_metric` if it is set.
-         monitor = self.monitor
-         mode = self.mode
-         if (
-             monitor is None
-             and (primary_metric := root_config.primary_metric) is not None
-         ):
-             monitor = primary_metric.validation_monitor
-             mode = primary_metric.mode
-
-         filename = self.filename
-         if self.auto_append_metric:
-             if not filename:
-                 filename = "{epoch}-{step}"
-             filename = f"{filename}-{{{monitor}}}"
-
-         if self.auto_insert_metric_name and filename:
-             new_filename = self._convert_string(filename)
-             log.critical(
-                 f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
-             )
-             filename = new_filename
-
-         yield ModelCheckpoint(
-             dirpath=dirpath,
-             filename=filename,
-             monitor=monitor,
-             mode=mode,
-             verbose=self.verbose,
-             save_last=self.save_last,
-             save_top_k=self.save_top_k,
-             save_weights_only=self.save_weights_only,
-             auto_insert_metric_name=False,
-             every_n_train_steps=self.every_n_train_steps,
-             train_time_interval=self.train_time_interval,
-             every_n_epochs=self.every_n_epochs,
-             save_on_train_epoch_end=self.save_on_train_epoch_end,
-             enable_version_counter=self.enable_version_counter,
-         )
-
-
- class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
-     kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
-
-     dirpath: str | Path | None = None
-     """Directory path to save the checkpoint file."""
-
-     filename: str | None = None
-     """Checkpoint filename. This must not include the extension. If `None`, `latest_epoch_{id}_{timestamp}` is used."""
-
-     save_weights_only: bool = False
-     """Whether to save only the model's weights or the entire model object."""
-
-     @override
-     def construct_callbacks(self, root_config):
-         from ..callbacks.latest_epoch_checkpoint import LatestEpochCheckpoint
-
-         dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-             root_config.id, "checkpoint"
-         )
-
-         yield LatestEpochCheckpoint(
-             dirpath=dirpath,
-             filename=self.filename,
-             save_weights_only=self.save_weights_only,
-         )
-
-
- class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
-     kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
-
-     dirpath: str | Path | None = None
-     """Directory path to save the checkpoint file."""
-
-     filename: str | None = None
-     """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
-
-     @override
-     def construct_callbacks(self, root_config):
-         from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
-
-         dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
-             root_config.id, "checkpoint"
-         )
-
-         if not (filename := self.filename):
-             filename = f"on_exception_{root_config.id}"
-         yield OnExceptionCheckpoint(dirpath=dirpath, filename=filename)
-
-
  CheckpointCallbackConfig: TypeAlias = Annotated[
      ModelCheckpointCallbackConfig
      | LatestEpochCheckpointCallbackConfig
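The three checkpoint callback configs removed here are not dropped: per the file list at the top (callbacks/model_checkpoint.py, callbacks/latest_epoch_checkpoint.py, callbacks/on_exception_checkpoint.py) and the new import block at the top of this file, they now live in the callbacks package. A sketch of the updated imports, following the re-exports shown in that import block:

    from nshtrainer.callbacks import (
        LatestEpochCheckpointCallbackConfig,
        ModelCheckpointCallbackConfig,
        OnExceptionCheckpointCallbackConfig,
    )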
@@ -1192,12 +797,12 @@ class CheckpointSavingConfig(CallbackConfigBase):
          )

      @override
-     def construct_callbacks(self, root_config: "BaseConfig"):
+     def create_callbacks(self, root_config: "BaseConfig"):
          if not self.should_save_checkpoints(root_config):
              return

          for callback_config in self.checkpoint_callbacks:
-             yield from callback_config.construct_callbacks(root_config)
+             yield from callback_config.create_callbacks(root_config)


  class LightningTrainerKwargs(TypedDict, total=False):
@@ -1474,7 +1079,7 @@ class EarlyStoppingConfig(CallbackConfigBase):
      """

      @override
-     def construct_callbacks(self, root_config: "BaseConfig"):
+     def create_callbacks(self, root_config: "BaseConfig"):
          from ..callbacks.early_stopping import EarlyStopping

          monitor = self.monitor
@@ -1505,32 +1110,6 @@ class EarlyStoppingConfig(CallbackConfigBase):
  ]


- class ActSaveConfig(CallbackConfigBase):
-     enabled: bool = True
-     """Enable activation saving."""
-
-     auto_save_logged_metrics: bool = False
-     """If enabled, will automatically save logged metrics (using `LightningModule.log`) as activations."""
-
-     save_dir: Path | None = None
-     """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""
-
-     def __bool__(self):
-         return self.enabled
-
-     def resolve_save_dir(self, root_config: "BaseConfig"):
-         if self.save_dir is not None:
-             return self.save_dir
-
-         return root_config.directory.resolve_subdirectory(root_config.id, "activation")
-
-     @override
-     def construct_callbacks(self, root_config):
-         from ..actsave import ActSaveCallback
-
-         return [ActSaveCallback()]
-
-
  class SanityCheckingConfig(C.Config):
      reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
      """
@@ -1542,7 +1121,7 @@ class SanityCheckingConfig(C.Config):


  class TrainerConfig(C.Config):
-     checkpoint_loading: CheckpointLoadingConfig = CheckpointLoadingConfig()
+     checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
      """Checkpoint loading configuration options."""

      checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
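CheckpointLoadingConfig itself moved to nshtrainer/trainer/_checkpoint_resolver.py (see the new import near the top of this file), and the field now defaults to the literal "auto" instead of a config instance. A hedged sketch, assuming a BaseConfig instance named config; the fields of the relocated class are defined in _checkpoint_resolver.py and are not shown in this diff:

    # New default: let the trainer resolve which checkpoint to load.
    config.trainer.checkpoint_loading = "auto"

    # Explicit configuration still goes through CheckpointLoadingConfig,
    # now imported from the trainer package.
    from nshtrainer.trainer._checkpoint_resolver import CheckpointLoadingConfig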
@@ -1560,9 +1139,6 @@ class TrainerConfig(C.Config):
      sanity_checking: SanityCheckingConfig = SanityCheckingConfig()
      """Sanity checking configuration options."""

-     actsave: ActSaveConfig | None = ActSaveConfig(enabled=False)
-     """Activation saving configuration options."""
-
      early_stopping: EarlyStoppingConfig | None = None
      """Early stopping configuration options."""

@@ -1731,12 +1307,12 @@ class TrainerConfig(C.Config):
      automatic selection based on the chosen accelerator. Default: ``"auto"``.
      """

-     auto_wrap_trainer: bool = True
-     """If enabled, will automatically wrap the `run` function with a `Trainer.context()` context manager. Should be `True` most of the time."""
      auto_set_default_root_dir: bool = True
      """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
      supports_shared_parameters: bool = True
      """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
+     save_checkpoint_metadata: bool = True
+     """If enabled, will save additional metadata whenever a checkpoint is saved."""

      lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
      """
@@ -1756,35 +1332,6 @@ class TrainerConfig(C.Config):
      """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""


- class MetricConfig(C.Config):
-     name: str
-     """The name of the primary metric."""
-
-     mode: Literal["min", "max"]
-     """
-     The mode of the primary metric:
-     - "min" for metrics that should be minimized (e.g., loss)
-     - "max" for metrics that should be maximized (e.g., accuracy)
-     """
-
-     @property
-     def validation_monitor(self) -> str:
-         return f"val/{self.name}"
-
-     def __post_init__(self):
-         for split in ("train", "val", "test", "predict"):
-             if self.name.startswith(f"{split}/"):
-                 raise ValueError(
-                     f"Primary metric name should not start with '{split}/'. "
-                     f"Just use '{self.name[len(split) + 1:]}' instead. "
-                     "The split name is automatically added depending on the context."
-                 )
-
-     @classmethod
-     def loss(cls, mode: Literal["min", "max"] = "min"):
-         return cls(name="loss", mode=mode)
-
-
  PrimaryMetricConfig: TypeAlias = MetricConfig


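MetricConfig moved out of this module into the new nshtrainer.metrics package (metrics/_config.py in the file list), and this file now imports it from there. A minimal usage sketch based on the removed 0.9.1 definition, assuming the relocated class keeps the same interface; the metric name is illustrative:

    from nshtrainer.metrics import MetricConfig

    primary = MetricConfig(name="loss", mode="min")  # equivalently: MetricConfig.loss()
    assert primary.validation_monitor == "val/loss"  # the "val/" prefix is added automatically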
@@ -1804,7 +1351,9 @@ class BaseConfig(C.Config):

      debug: bool = False
      """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
-     environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = EnvironmentConfig()
+     environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = (
+         EnvironmentConfig.empty()
+     )
      """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""

      directory: DirectoryConfig = DirectoryConfig()
@@ -1892,7 +1441,7 @@ class BaseConfig(C.Config):
              self.directory = DirectoryConfig()

          if environment:
-             self.environment = EnvironmentConfig()
+             self.environment = EnvironmentConfig.empty()

          if meta:
              self.meta = {}
@@ -1990,8 +1539,7 @@ class BaseConfig(C.Config):
          )
          return cls.model_validate(hparams)

-     def ll_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
-         yield self.trainer.actsave
+     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
          yield self.trainer.early_stopping
          yield self.trainer.checkpoint_saving
          yield self.trainer.logging
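The config-level callback hook is renamed from ll_all_callback_configs to _nshtrainer_all_callback_configs and no longer yields the removed actsave config. A hedged sketch of a downstream BaseConfig subclass updated for the rename; the subclass and its extra callback field are hypothetical:

    from typing_extensions import override

    class MyConfig(BaseConfig):  # hypothetical subclass
        my_callback: CallbackConfigBase | None = None  # hypothetical extra callback config

        @override
        def _nshtrainer_all_callback_configs(self):  # was `ll_all_callback_configs` in 0.9.1
            yield from super()._nshtrainer_all_callback_configs()
            yield self.my_callback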