nshtrainer 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/__init__.py CHANGED
@@ -10,16 +10,7 @@ from . import typecheck as typecheck
  from ._snoop import snoop as snoop
  from .actsave import ActLoad as ActLoad
  from .actsave import ActSave as ActSave
- from .config import MISSING as MISSING
- from .config import AllowMissing as AllowMissing
- from .config import Field as Field
- from .config import MissingField as MissingField
- from .config import PrivateAttr as PrivateAttr
- from .config import TypedConfig as TypedConfig
  from .data import dataset_transform as dataset_transform
- from .log import init_python_logging as init_python_logging
- from .log import lovely as lovely
- from .log import pretty as pretty
  from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
  from .model import ActSaveConfig as ActSaveConfig
  from .model import Base as Base
@@ -41,24 +32,17 @@ from .model import (
      EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
  )
  from .model import GradientClippingConfig as GradientClippingConfig
- from .model import LightningDataModuleBase as LightningDataModuleBase
  from .model import LightningModuleBase as LightningModuleBase
  from .model import LoggingConfig as LoggingConfig
  from .model import MetricConfig as MetricConfig
  from .model import OptimizationConfig as OptimizationConfig
  from .model import PrimaryMetricConfig as PrimaryMetricConfig
- from .model import PythonLogging as PythonLogging
  from .model import ReproducibilityConfig as ReproducibilityConfig
- from .model import RunnerConfig as RunnerConfig
  from .model import SanityCheckingConfig as SanityCheckingConfig
- from .model import SeedConfig as SeedConfig
  from .model import TrainerConfig as TrainerConfig
  from .model import WandbWatchConfig as WandbWatchConfig
  from .nn import TypedModuleDict as TypedModuleDict
  from .nn import TypedModuleList as TypedModuleList
  from .optimizer import OptimizerConfig as OptimizerConfig
  from .runner import Runner as Runner
- from .runner import SnapshotConfig as SnapshotConfig
  from .trainer import Trainer as Trainer
- from .util.singleton import Registry as Registry
- from .util.singleton import Singleton as Singleton
@@ -1,6 +1,7 @@
  from typing import Annotated

- from ..config import Field
+ import nshconfig as C
+
  from .base import CallbackConfigBase as CallbackConfigBase
  from .early_stopping import EarlyStopping as EarlyStopping
  from .ema import EMA as EMA
@@ -31,5 +32,5 @@ CallbackConfig = Annotated[
      | NormLoggingConfig
      | GradientSkippingConfig
      | EMAConfig,
-     Field(discriminator="name"),
+     C.Field(discriminator="name"),
  ]
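The pattern above recurs throughout this release: config classes that previously extended the in-tree `TypedConfig` now extend `nshconfig.Config`, and discriminated unions are tagged with `C.Field(discriminator=...)` instead of the removed `Field` re-export. A minimal sketch of the new style (the `FooConfig`/`BarConfig` names are hypothetical; it assumes nshconfig's `Config` and `Field` behave like their Pydantic counterparts, which is what the one-for-one substitutions in this diff suggest):

from typing import Annotated, Literal, TypeAlias

import nshconfig as C


class FooConfig(C.Config):
    name: Literal["foo"] = "foo"
    lr: float = 1e-3


class BarConfig(C.Config):
    name: Literal["bar"] = "bar"
    momentum: float = 0.9


# The `name` literal acts as the discriminator, mirroring `CallbackConfig` above:
# parsing a dict with name="bar" selects BarConfig.
ExampleConfig: TypeAlias = Annotated[
    FooConfig | BarConfig,
    C.Field(discriminator="name"),
]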
@@ -4,10 +4,9 @@ from collections.abc import Iterable
  from dataclasses import dataclass
  from typing import TYPE_CHECKING, TypeAlias, TypedDict

+ import nshconfig as C
  from lightning.pytorch import Callback

- from ..config import TypedConfig
-
  if TYPE_CHECKING:
      from ..model.config import BaseConfig

@@ -20,7 +19,7 @@ class CallbackMetadataDict(TypedDict, total=False):
      """Priority of the callback. Callbacks with higher priority will be loaded first."""


- class CallbackMetadataConfig(TypedConfig):
+ class CallbackMetadataConfig(C.Config):
      ignore_if_exists: bool = False
      """If `True`, the callback will not be added if another callback with the same class already exists."""

@@ -37,7 +36,7 @@ class CallbackWithMetadata:
  ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata


- class CallbackConfigBase(TypedConfig, ABC):
+ class CallbackConfigBase(C.Config, ABC):
      metadata: CallbackMetadataConfig = CallbackMetadataConfig()
      """Metadata for the callback."""

@@ -1,6 +1,7 @@
  from typing import Annotated, TypeAlias

- from ..config import Field
+ import nshconfig as C
+
  from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
  from ._base import LRSchedulerMetadata as LRSchedulerMetadata
  from .linear_warmup_cosine import (
@@ -14,5 +15,5 @@ from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauCo

  LRSchedulerConfig: TypeAlias = Annotated[
      LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
-     Field(discriminator="name"),
+     C.Field(discriminator="name"),
  ]
@@ -1,8 +1,9 @@
  import math
  from abc import ABC, abstractmethod
  from collections.abc import Mapping
- from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias
+ from typing import TYPE_CHECKING, Literal

+ import nshconfig as C
  from lightning.pytorch.utilities.types import (
      LRSchedulerConfigType,
      LRSchedulerTypeUnion,
@@ -10,8 +11,6 @@ from lightning.pytorch.utilities.types import (
  from torch.optim import Optimizer
  from typing_extensions import NotRequired, TypedDict

- from ..config import TypedConfig
-
  if TYPE_CHECKING:
      from ..model.base import LightningModuleBase

@@ -37,9 +36,7 @@ class LRSchedulerMetadata(TypedDict):
      """Whether to enforce that the monitor exists for reducing the learning rate on plateau. Default is `True`."""


- class LRSchedulerConfigBase(TypedConfig, ABC):
-     Metadata: ClassVar[TypeAlias] = LRSchedulerMetadata
-
+ class LRSchedulerConfigBase(C.Config, ABC):
      @abstractmethod
      def metadata(self) -> LRSchedulerMetadata: ...

@@ -2,12 +2,12 @@ import math
  import warnings
  from typing import Literal

+ import nshconfig as C
  from torch.optim import Optimizer
  from torch.optim.lr_scheduler import LRScheduler
  from typing_extensions import override

- from ..config import Field
- from ._base import LRSchedulerConfigBase
+ from ._base import LRSchedulerConfigBase, LRSchedulerMetadata


  class LinearWarmupCosineAnnealingLR(LRScheduler):
@@ -91,11 +91,11 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
  class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
      name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"

-     warmup_epochs: int = Field(ge=0)
+     warmup_epochs: int = C.Field(ge=0)
      r"""The number of epochs for the linear warmup phase.
      The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this number of epochs."""

-     max_epochs: int = Field(gt=0)
+     max_epochs: int = C.Field(gt=0)
      r"""The total number of epochs.
      The learning rate is decayed to `min_lr` over this number of epochs."""

@@ -113,7 +113,7 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
      If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""

      @override
-     def metadata(self) -> LRSchedulerConfigBase.Metadata:
+     def metadata(self) -> LRSchedulerMetadata:
          return {
              "interval": "step",
          }
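The `ge=0` / `gt=0` bounds carry over unchanged; assuming `C.Field` mirrors Pydantic's `Field` (which the drop-in substitution above suggests), out-of-range values are rejected when the config is constructed. A small sketch with a hypothetical `DemoLRConfig`:

import nshconfig as C


class DemoLRConfig(C.Config):
    # Mirrors the constraints above: warmup may be zero, total epochs must be positive.
    warmup_epochs: int = C.Field(ge=0)
    max_epochs: int = C.Field(gt=0)


DemoLRConfig(warmup_epochs=0, max_epochs=100)  # passes validation

try:
    DemoLRConfig(warmup_epochs=-1, max_epochs=100)
except Exception as err:  # expected: a validation error for warmup_epochs < 0
    print(err)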
@@ -1,12 +1,11 @@
  from typing import TYPE_CHECKING, Literal, cast

+ from lightning.pytorch.utilities.types import LRSchedulerConfigType
  from torch.optim.lr_scheduler import ReduceLROnPlateau
  from typing_extensions import override

- from ll.lr_scheduler._base import LRSchedulerMetadata
-
  from ..model.config import MetricConfig
- from ._base import LRSchedulerConfigBase
+ from ._base import LRSchedulerConfigBase, LRSchedulerMetadata

  if TYPE_CHECKING:
      from ..model.base import BaseConfig
@@ -43,7 +42,9 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
      r"""One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * (1 + threshold) in 'max' mode or best * (1 - threshold) in `min` mode. In `abs` mode, dynamic_threshold = best + threshold in `max` mode or best - threshold in `min` mode. Default: 'rel'."""

      @override
-     def create_scheduler_impl(self, optimizer, lightning_module, lr):
+     def create_scheduler_impl(
+         self, optimizer, lightning_module, lr
+     ) -> LRSchedulerConfigType:
          if (metric := self.metric) is None:
              lm_config = cast("BaseConfig", lightning_module.config)
              assert (
@@ -1,7 +1,6 @@
  from typing_extensions import TypeAlias

  from .base import Base as Base
- from .base import LightningDataModuleBase as LightningDataModuleBase
  from .base import LightningModuleBase as LightningModuleBase
  from .config import ActSaveConfig as ActSaveConfig
  from .config import BaseConfig as BaseConfig
@@ -33,11 +32,8 @@ from .config import (
  )
  from .config import OptimizationConfig as OptimizationConfig
  from .config import PrimaryMetricConfig as PrimaryMetricConfig
- from .config import PythonLogging as PythonLogging
  from .config import ReproducibilityConfig as ReproducibilityConfig
- from .config import RunnerConfig as RunnerConfig
  from .config import SanityCheckingConfig as SanityCheckingConfig
- from .config import SeedConfig as SeedConfig
  from .config import TrainerConfig as TrainerConfig
  from .config import WandbWatchConfig as WandbWatchConfig

nshtrainer/model/base.py CHANGED
@@ -23,11 +23,12 @@ from .config import (
      EnvironmentLinuxEnvironmentConfig,
      EnvironmentLSFInformationConfig,
      EnvironmentSLURMInformationConfig,
+     EnvironmentSnapshotConfig,
  )
- from .modules.callback import CallbackModuleMixin, CallbackRegistrarModuleMixin
+ from .modules.callback import CallbackModuleMixin
  from .modules.debug import DebugModuleMixin
  from .modules.distributed import DistributedMixin
- from .modules.logger import LoggerLightningModuleMixin, LoggerModuleMixin
+ from .modules.logger import LoggerLightningModuleMixin
  from .modules.profiler import ProfilerMixin
  from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
  from .modules.shared_parameters import SharedParametersModuleMixin
@@ -265,6 +266,9 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
              boot_time=_try_get(lambda: _psutil().boot_time()),
              load_avg=_try_get(lambda: os.getloadavg()),
          )
+         hparams.environment.snapshot = (
+             EnvironmentSnapshotConfig.from_current_environment()
+         )

      def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
          """
@@ -309,15 +313,12 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
      @property
      def datamodule(self):
          datamodule = getattr(self.trainer, "datamodule", None)
-         if datamodule is None:
+         if (datamodule := getattr(self.trainer, "datamodule", None)) is None:
              return None
-
-         if not isinstance(datamodule, LightningDataModuleBase):
+         if not isinstance(datamodule, LightningDataModule):
              raise TypeError(
-                 f"datamodule must be a LightningDataModuleBase: {type(datamodule)}"
+                 f"datamodule must be a LightningDataModule: {type(datamodule)}"
              )
-
-         datamodule = cast(LightningDataModuleBase[THparams], datamodule)
          return datamodule

      if TYPE_CHECKING:
@@ -576,66 +577,3 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
              "batch": batch,
              "batch_idx": batch_idx,
          }
-
-
- class LightningDataModuleBase(
-     LoggerModuleMixin,
-     CallbackRegistrarModuleMixin,
-     Base[THparams],
-     LightningDataModule,
-     ABC,
-     Generic[THparams],
- ):
-     hparams: THparams # pyright: ignore[reportIncompatibleMethodOverride]
-     hparams_initial: THparams # pyright: ignore[reportIncompatibleMethodOverride]
-
-     def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
-         """
-         Override this method to update the hparams dictionary before it is used to create the hparams object.
-         Mapping-based parameters are passed to the constructor of the hparams object when we're loading the model from a checkpoint.
-         """
-         return hparams
-
-     def pre_init_update_hparams(self, hparams: THparams):
-         """
-         Override this method to update the hparams object before it is used to create the hparams_initial object.
-         """
-         return hparams
-
-     @classmethod
-     def _update_environment(cls, hparams: THparams):
-         hparams.environment.data = _cls_info(cls)
-
-     @override
-     def __init__(self, hparams: THparams):
-         if not isinstance(hparams, BaseConfig):
-             if not isinstance(hparams, MutableMapping):
-                 raise TypeError(
-                     f"hparams must be a BaseConfig or a MutableMapping: {type(hparams)}"
-                 )
-
-             hparams = self.pre_init_update_hparams_dict(hparams)
-             hparams = self.config_cls().from_dict(hparams)
-         self._update_environment(hparams)
-         hparams = self.pre_init_update_hparams(hparams)
-         super().__init__(hparams)
-
-         self.save_hyperparameters(hparams)
-
-     @property
-     def lightning_module(self):
-         if not self.trainer:
-             raise ValueError("Trainer has not been set.")
-
-         module = self.trainer.lightning_module
-         if not isinstance(module, LightningModuleBase):
-             raise ValueError(
-                 f"Trainer's lightning_module is not a LightningModuleBase: {type(module)}"
-             )
-
-         module = cast(LightningModuleBase[THparams], module)
-         return module
-
-     @property
-     def device(self):
-         return self.lightning_module.device
1
1
  import copy
2
2
  import os
3
3
  import re
4
- import signal
5
4
  import socket
6
5
  import string
7
6
  import time
@@ -21,6 +20,7 @@ from typing import (
21
20
  runtime_checkable,
22
21
  )
23
22
 
23
+ import nshconfig as C
24
24
  import numpy as np
25
25
  import torch
26
26
  from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
@@ -39,7 +39,6 @@ from typing_extensions import Self, TypedDict, TypeVar, override
39
39
  from ..callbacks import CallbackConfig
40
40
  from ..callbacks.base import CallbackConfigBase
41
41
  from ..callbacks.wandb_watch import WandbWatchConfig
42
- from ..config import Field, TypedConfig
43
42
  from ..util.slurm import parse_slurm_node_list
44
43
 
45
44
  log = getLogger(__name__)
@@ -49,7 +48,7 @@ class IdSeedWarning(Warning):
49
48
  pass
50
49
 
51
50
 
52
- class BaseProfilerConfig(TypedConfig, ABC):
51
+ class BaseProfilerConfig(C.Config, ABC):
53
52
  dirpath: str | Path | None = None
54
53
  """
55
54
  Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
@@ -200,11 +199,11 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
200
199
 
201
200
  ProfilerConfig: TypeAlias = Annotated[
202
201
  SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
203
- Field(discriminator="kind"),
202
+ C.Field(discriminator="kind"),
204
203
  ]
205
204
 
206
205
 
207
- class EnvironmentClassInformationConfig(TypedConfig):
206
+ class EnvironmentClassInformationConfig(C.Config):
208
207
  name: str
209
208
  module: str
210
209
  full_name: str
@@ -213,7 +212,7 @@ class EnvironmentClassInformationConfig(TypedConfig):
213
212
  source_file_path: Path | None = None
214
213
 
215
214
 
216
- class EnvironmentSLURMInformationConfig(TypedConfig):
215
+ class EnvironmentSLURMInformationConfig(C.Config):
217
216
  hostname: str
218
217
  hostnames: list[str]
219
218
  job_id: str
@@ -271,7 +270,7 @@ class EnvironmentSLURMInformationConfig(TypedConfig):
271
270
  return None
272
271
 
273
272
 
274
- class EnvironmentLSFInformationConfig(TypedConfig):
273
+ class EnvironmentLSFInformationConfig(C.Config):
275
274
  hostname: str
276
275
  hostnames: list[str]
277
276
  job_id: str
@@ -328,7 +327,7 @@ class EnvironmentLSFInformationConfig(TypedConfig):
328
327
  return None
329
328
 
330
329
 
331
- class EnvironmentLinuxEnvironmentConfig(TypedConfig):
330
+ class EnvironmentLinuxEnvironmentConfig(C.Config):
332
331
  """
333
332
  Information about the Linux environment (e.g., current user, hostname, etc.)
334
333
  """
@@ -347,9 +346,25 @@ class EnvironmentLinuxEnvironmentConfig(TypedConfig):
347
346
  load_avg: tuple[float, float, float] | None = None
348
347
 
349
348
 
350
- class EnvironmentConfig(TypedConfig):
349
+ class EnvironmentSnapshotConfig(C.Config):
350
+ snapshot_dir: Path | None = None
351
+ modules: list[str] | None = None
352
+
353
+ @classmethod
354
+ def from_current_environment(cls):
355
+ draft = cls.draft()
356
+ if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
357
+ draft.snapshot_dir = Path(snapshot_dir)
358
+ if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
359
+ draft.modules = modules.split(",")
360
+ return draft.finalize()
361
+
362
+
363
+ class EnvironmentConfig(C.Config):
351
364
  cwd: Path | None = None
352
365
 
366
+ snapshot: EnvironmentSnapshotConfig | None = None
367
+
353
368
  python_executable: Path | None = None
354
369
  python_path: list[Path] | None = None
355
370
  python_version: str | None = None
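The new `EnvironmentSnapshotConfig.from_current_environment()` shows nshconfig's draft workflow: `cls.draft()` returns a mutable draft, fields are populated from the `NSHRUNNER_SNAPSHOT_DIR` and `NSHRUNNER_SNAPSHOT_MODULES` environment variables, and `draft.finalize()` produces the validated config. A usage sketch that mirrors that flow with a hypothetical `SnapshotDemoConfig` (the environment-variable values are illustrative):

import os
from pathlib import Path

import nshconfig as C


class SnapshotDemoConfig(C.Config):
    snapshot_dir: Path | None = None
    modules: list[str] | None = None


# Illustrative values; nshrunner would normally set these for the job.
os.environ["NSHRUNNER_SNAPSHOT_DIR"] = "/tmp/snapshot"
os.environ["NSHRUNNER_SNAPSHOT_MODULES"] = "mypkg,otherpkg"

# Same construction flow as from_current_environment() above: draft, populate, finalize.
draft = SnapshotDemoConfig.draft()
if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
    draft.snapshot_dir = Path(snapshot_dir)
if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
    draft.modules = modules.split(",")
config = draft.finalize()
print(config.snapshot_dir, config.modules)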
@@ -372,7 +387,7 @@ class EnvironmentConfig(TypedConfig):
      seed_workers: bool | None = None


- class BaseLoggerConfig(TypedConfig, ABC):
+ class BaseLoggerConfig(C.Config, ABC):
      enabled: bool = True
      """Enable this logger."""

@@ -426,7 +441,7 @@ def _wandb_available():
  class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
      kind: Literal["wandb"] = "wandb"

-     enabled: bool = Field(default_factory=lambda: _wandb_available())
+     enabled: bool = C.Field(default_factory=lambda: _wandb_available())
      """Enable WandB logging."""

      priority: int = 2
@@ -543,7 +558,7 @@ def _tensorboard_available():
  class TensorboardLoggerConfig(BaseLoggerConfig):
      kind: Literal["tensorboard"] = "tensorboard"

-     enabled: bool = Field(default_factory=lambda: _tensorboard_available())
+     enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
      """Enable TensorBoard logging."""

      priority: int = 2
@@ -589,7 +604,7 @@ class TensorboardLoggerConfig(BaseLoggerConfig):

  LoggerConfig: TypeAlias = Annotated[
      WandbLoggerConfig | CSVLoggerConfig | TensorboardLoggerConfig,
-     Field(discriminator="kind"),
+     C.Field(discriminator="kind"),
  ]


@@ -684,7 +699,7 @@ class LoggingConfig(CallbackConfigBase):
              yield from logger.construct_callbacks(root_config)


- class GradientClippingConfig(TypedConfig):
+ class GradientClippingConfig(C.Config):
      enabled: bool = True
      """Enable gradient clipping."""
      value: int | float
@@ -719,41 +734,6 @@ class OptimizationConfig(CallbackConfigBase):
          ).construct_callbacks(root_config)


- LogLevel: TypeAlias = Literal[
-     "CRITICAL", "FATAL", "ERROR", "WARN", "WARNING", "INFO", "DEBUG"
- ]
-
-
- class PythonLogging(TypedConfig):
-     log_level: LogLevel | None = None
-     """Log level to use for the Python logger (or None to use the default)."""
-
-     rich: bool = False
-     """If enabled, will use the rich library to format the Python logger output."""
-     rich_tracebacks: bool = True
-     """If enabled, will use the rich library to format the Python logger tracebacks."""
-
-     lovely_tensors: bool = False
-     """If enabled, will use the lovely-tensors library to format PyTorch tensors. False by default as it causes issues when used with `torch.vmap`."""
-     lovely_numpy: bool = False
-     """If enabled, will use the lovely-numpy library to format numpy arrays. False by default as it causes some issues with other libaries."""
-
-     def pretty_(
-         self,
-         *,
-         log_level: LogLevel | None = "INFO",
-         torch: bool = True,
-         numpy: bool = True,
-         rich: bool = True,
-         rich_tracebacks: bool = True,
-     ):
-         self.log_level = log_level
-         self.lovely_tensors = torch
-         self.lovely_numpy = numpy
-         self.rich = rich
-         self.rich_tracebacks = rich_tracebacks
-
-
  TPlugin = TypeVar(
      "TPlugin",
      Precision,
@@ -813,7 +793,7 @@ StrategyLiteral: TypeAlias = Literal[
  ]


- class CheckpointLoadingConfig(TypedConfig):
+ class CheckpointLoadingConfig(C.Config):
      path: Literal["best", "last", "hpc"] | str | Path | None = None
      """
      Checkpoint path to use when loading a checkpoint.
@@ -825,7 +805,7 @@ class CheckpointLoadingConfig(TypedConfig):
      """


- class DirectoryConfig(TypedConfig):
+ class DirectoryConfig(C.Config):
      project_root: Path | None = None
      """
      Root directory for this project.
@@ -901,7 +881,7 @@ class DirectoryConfig(TypedConfig):
          return log_dir


- class ReproducibilityConfig(TypedConfig):
+ class ReproducibilityConfig(C.Config):
      deterministic: bool | Literal["warn"] | None = None
      """
      If ``True``, sets whether PyTorch operations must use deterministic algorithms.
@@ -1116,7 +1096,7 @@ CheckpointCallbackConfig: TypeAlias = Annotated[
      ModelCheckpointCallbackConfig
      | LatestEpochCheckpointCallbackConfig
      | OnExceptionCheckpointCallbackConfig,
-     Field(discriminator="kind"),
+     C.Field(discriminator="kind"),
  ]


@@ -1514,7 +1494,7 @@ class ActSaveConfig(CallbackConfigBase):
          return [ActSaveCallback()]


- class SanityCheckingConfig(TypedConfig):
+ class SanityCheckingConfig(C.Config):
      reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
      """
      If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
@@ -1524,7 +1504,7 @@ class SanityCheckingConfig(TypedConfig):
      """


- class TrainerConfig(TypedConfig):
+ class TrainerConfig(C.Config):
      checkpoint_loading: CheckpointLoadingConfig = CheckpointLoadingConfig()
      """Checkpoint loading configuration options."""

@@ -1739,87 +1719,7 @@ class TrainerConfig(TypedConfig):
      """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""


- class SeedConfig(TypedConfig):
-     seed: int
-     """Seed for the random number generator."""
-
-     seed_workers: bool = False
-     """Whether to seed the workers of the dataloader."""
-
-
- Signal: TypeAlias = Literal[
-     "SIGHUP",
-     "SIGINT",
-     "SIGQUIT",
-     "SIGILL",
-     "SIGTRAP",
-     "SIGABRT",
-     "SIGBUS",
-     "SIGFPE",
-     "SIGKILL",
-     "SIGUSR1",
-     "SIGSEGV",
-     "SIGUSR2",
-     "SIGPIPE",
-     "SIGALRM",
-     "SIGTERM",
-     "SIGCHLD",
-     "SIGCONT",
-     "SIGSTOP",
-     "SIGTSTP",
-     "SIGTTIN",
-     "SIGTTOU",
-     "SIGURG",
-     "SIGXCPU",
-     "SIGXFSZ",
-     "SIGVTALRM",
-     "SIGPROF",
-     "SIGWINCH",
-     "SIGIO",
-     "SIGPWR",
-     "SIGSYS",
-     "SIGRTMIN",
-     "SIGRTMAX",
- ]
-
-
- class SubmitConfig(TypedConfig):
-     auto_requeue_signals: list[Signal] = [
-         # "SIGUSR1",
-         # On SIGURG:
-         # Important note from https://amrex-astro.github.io/workflow/olcf-workflow.html:
-         # We can also ask the job manager to send a warning signal some amount of time before the allocation expires by passing -wa 'signal' and -wt '[hour:]minute' to bsub. We can then have bash create a dump_and_stop file when it receives the signal, which will tell Castro to output a checkpoint file and exit cleanly after it finishes the current timestep. An important detail that I couldn't find documented anywhere is that the job manager sends the signal to all the processes in the job, not just the submission script, and we have to use a signal that is ignored by default so Castro doesn't immediately crash upon receiving it. SIGCHLD, SIGURG, and SIGWINCH are the only signals that fit this requirement and of these, SIGURG is the least likely to be triggered by other events.
-         "SIGURG"
-     ]
-     """Signals that will trigger an automatic requeue of the job."""
-
-     def _resolved_auto_requeue_signals(self) -> list[signal.Signals]:
-         return [getattr(signal.Signals, sig) for sig in self.auto_requeue_signals]
-
-
- class RunnerConfig(TypedConfig):
-     python_logging: PythonLogging = PythonLogging()
-     """Python logging configuration options."""
-
-     seed: SeedConfig = SeedConfig(seed=0)
-     """Seed everything configuration options."""
-
-     submit: SubmitConfig = SubmitConfig()
-     """Submit (e.g., SLURM or LSF) configuration options."""
-
-     dump_run_information: bool = True
-     """
-     If enabled, will dump different bits of run information to the output directory before starting the run.
-     This includes:
-     - Run config
-     - Full set of environment variables
-     """
-
-     additional_env_vars: dict[str, str] = {}
-     """Additional environment variables to set when running the script."""
-
-
- class MetricConfig(TypedConfig):
+ class MetricConfig(C.Config):
      name: str
      """The name of the primary metric."""

@@ -1851,8 +1751,8 @@ class MetricConfig(TypedConfig):
  PrimaryMetricConfig: TypeAlias = MetricConfig


- class BaseConfig(TypedConfig):
-     id: str = Field(default_factory=lambda: BaseConfig.generate_id())
+ class BaseConfig(C.Config):
+     id: str = C.Field(default_factory=lambda: BaseConfig.generate_id())
      """ID of the run."""
      name: str | None = None
      """Run name."""
@@ -1867,15 +1767,13 @@ class BaseConfig(TypedConfig):

      debug: bool = False
      """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
-     environment: Annotated[EnvironmentConfig, Field(repr=False)] = EnvironmentConfig()
+     environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = EnvironmentConfig()
      """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""

      directory: DirectoryConfig = DirectoryConfig()
      """Directory configuration options."""
      trainer: TrainerConfig = TrainerConfig()
      """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information."""
-     runner: RunnerConfig = RunnerConfig()
-     """`ll.Runner` configuration options."""

      primary_metric: PrimaryMetricConfig | None = None
      """Primary metric configuration options. This is used in the following ways: