nshtrainer 0.44.0__py3-none-any.whl → 1.0.0b9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
Files changed (125)
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +51 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/_base.py +2 -1
  79. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  80. nshtrainer/model/__init__.py +0 -4
  81. nshtrainer/model/base.py +64 -347
  82. nshtrainer/model/mixins/callback.py +24 -5
  83. nshtrainer/model/mixins/debug.py +86 -0
  84. nshtrainer/model/mixins/logger.py +142 -145
  85. nshtrainer/profiler/_base.py +2 -2
  86. nshtrainer/profiler/advanced.py +4 -4
  87. nshtrainer/profiler/pytorch.py +4 -4
  88. nshtrainer/profiler/simple.py +4 -4
  89. nshtrainer/trainer/__init__.py +1 -0
  90. nshtrainer/trainer/_config.py +164 -17
  91. nshtrainer/trainer/checkpoint_connector.py +23 -8
  92. nshtrainer/trainer/trainer.py +194 -76
  93. nshtrainer/util/_environment_info.py +21 -13
  94. nshtrainer/util/config/dtype.py +4 -4
  95. nshtrainer/util/typing_utils.py +1 -1
  96. {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/METADATA +2 -2
  97. nshtrainer-1.0.0b9.dist-info/RECORD +143 -0
  98. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  99. nshtrainer/callbacks/throughput_monitor.py +0 -58
  100. nshtrainer/config/model/__init__.py +0 -41
  101. nshtrainer/config/model/base/__init__.py +0 -25
  102. nshtrainer/config/model/config/__init__.py +0 -37
  103. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  104. nshtrainer/config/runner/__init__.py +0 -22
  105. nshtrainer/ll/__init__.py +0 -59
  106. nshtrainer/ll/_experimental.py +0 -3
  107. nshtrainer/ll/actsave.py +0 -6
  108. nshtrainer/ll/callbacks.py +0 -3
  109. nshtrainer/ll/config.py +0 -6
  110. nshtrainer/ll/data.py +0 -3
  111. nshtrainer/ll/log.py +0 -5
  112. nshtrainer/ll/lr_scheduler.py +0 -3
  113. nshtrainer/ll/model.py +0 -21
  114. nshtrainer/ll/nn.py +0 -3
  115. nshtrainer/ll/optimizer.py +0 -3
  116. nshtrainer/ll/runner.py +0 -5
  117. nshtrainer/ll/snapshot.py +0 -3
  118. nshtrainer/ll/snoop.py +0 -3
  119. nshtrainer/ll/trainer.py +0 -3
  120. nshtrainer/ll/typecheck.py +0 -3
  121. nshtrainer/ll/util.py +0 -3
  122. nshtrainer/model/config.py +0 -218
  123. nshtrainer/runner.py +0 -101
  124. nshtrainer-0.44.0.dist-info/RECORD +0 -162
  125. {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/WHEEL +0 -0
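Most of the removals below are the `nshtrainer.ll` compatibility shims, which were thin re-exports of other modules, plus the old `BaseConfig` (nshtrainer/model/config.py) and the already-deprecated `Runner` (nshtrainer/runner.py). As a hedged illustration only (the replacement import paths are inferred from the re-exports deleted below, not confirmed 1.0 entry points):

    # Illustrative only: the nshtrainer.ll shims deleted below were thin re-exports.
    # Old (0.44.0):
    #   from nshtrainer.ll import Trainer, ActSave
    # The deleted shims forwarded to these modules, which this diff suggests are the
    # direct import targets going forward (confirm against the 1.0 release notes):
    from nshtrainer.trainer import Trainer
    from nshutils.actsave import ActSave

`BaseConfig` and `Runner` are removed outright; no replacement for them is shown in this diff.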
nshtrainer/config/model/config/__init__.py DELETED
@@ -1,37 +0,0 @@
- from __future__ import annotations
-
- __codegen__ = True
-
- from typing import TYPE_CHECKING
-
- # Config/alias imports
-
- if TYPE_CHECKING:
-     from nshtrainer.model.config import BaseConfig as BaseConfig
-     from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
-     from nshtrainer.model.config import DirectoryConfig as DirectoryConfig
-     from nshtrainer.model.config import EnvironmentConfig as EnvironmentConfig
-     from nshtrainer.model.config import MetricConfig as MetricConfig
-     from nshtrainer.model.config import TrainerConfig as TrainerConfig
- else:
-
-     def __getattr__(name):
-         import importlib
-
-         if name in globals():
-             return globals()[name]
-         if name == "MetricConfig":
-             return importlib.import_module("nshtrainer.model.config").MetricConfig
-         if name == "TrainerConfig":
-             return importlib.import_module("nshtrainer.model.config").TrainerConfig
-         if name == "BaseConfig":
-             return importlib.import_module("nshtrainer.model.config").BaseConfig
-         if name == "EnvironmentConfig":
-             return importlib.import_module("nshtrainer.model.config").EnvironmentConfig
-         if name == "DirectoryConfig":
-             return importlib.import_module("nshtrainer.model.config").DirectoryConfig
-         if name == "CallbackConfigBase":
-             return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
-         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
- # Submodule exports
nshtrainer/config/model/mixins/logger/__init__.py DELETED
@@ -1,22 +0,0 @@
- from __future__ import annotations
-
- __codegen__ = True
-
- from typing import TYPE_CHECKING
-
- # Config/alias imports
-
- if TYPE_CHECKING:
-     from nshtrainer.model.mixins.logger import BaseConfig as BaseConfig
- else:
-
-     def __getattr__(name):
-         import importlib
-
-         if name in globals():
-             return globals()[name]
-         if name == "BaseConfig":
-             return importlib.import_module("nshtrainer.model.mixins.logger").BaseConfig
-         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
- # Submodule exports
nshtrainer/config/runner/__init__.py DELETED
@@ -1,22 +0,0 @@
- from __future__ import annotations
-
- __codegen__ = True
-
- from typing import TYPE_CHECKING
-
- # Config/alias imports
-
- if TYPE_CHECKING:
-     from nshtrainer.runner import BaseConfig as BaseConfig
- else:
-
-     def __getattr__(name):
-         import importlib
-
-         if name in globals():
-             return globals()[name]
-         if name == "BaseConfig":
-             return importlib.import_module("nshtrainer.runner").BaseConfig
-         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
- # Submodule exports
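The three deleted config shims above share one codegen pattern: real imports only under `TYPE_CHECKING` for static analysis, plus a module-level `__getattr__` (PEP 562) that imports the target module lazily on first attribute access. A minimal generic sketch of that pattern, with placeholder names (`somepkg.real_module` and `SomeConfig` are not nshtrainer APIs):

    # Minimal sketch of the lazy re-export pattern used by the deleted shims above.
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from somepkg.real_module import SomeConfig as SomeConfig  # placeholder target
    else:

        def __getattr__(name):
            import importlib

            # Defer the real import until the attribute is first accessed.
            if name == "SomeConfig":
                return importlib.import_module("somepkg.real_module").SomeConfig
            raise AttributeError(f"module '{__name__}' has no attribute '{name}'")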
nshtrainer/ll/__init__.py DELETED
@@ -1,59 +0,0 @@
- from __future__ import annotations
-
- from typing import TypeAlias
-
- from . import _experimental as _experimental
- from . import actsave as actsave
- from . import callbacks as callbacks
- from . import data as data
- from . import lr_scheduler as lr_scheduler
- from . import model as model
- from . import nn as nn
- from . import optimizer as optimizer
- from . import snapshot as snapshot
- from . import typecheck as typecheck
- from .actsave import ActLoad as ActLoad
- from .actsave import ActSave as ActSave
- from .config import MISSING as MISSING
- from .config import AllowMissing as AllowMissing
- from .config import Field as Field
- from .config import MissingField as MissingField
- from .config import PrivateAttr as PrivateAttr
- from .config import TypedConfig as TypedConfig
- from .data import dataset_transform as dataset_transform
- from .log import init_python_logging as init_python_logging
- from .log import lovely as lovely
- from .log import pretty as pretty
- from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
- from .model import BaseConfig as BaseConfig
- from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
- from .model import CheckpointSavingConfig as CheckpointSavingConfig
- from .model import DirectoryConfig as DirectoryConfig
- from .model import (
-     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
- )
- from .model import EnvironmentConfig as EnvironmentConfig
- from .model import (
-     EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
- )
- from .model import (
-     EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
- )
- from .model import GradientClippingConfig as GradientClippingConfig
- from .model import LightningModuleBase as LightningModuleBase
- from .model import LoggingConfig as LoggingConfig
- from .model import MetricConfig as MetricConfig
- from .model import OptimizationConfig as OptimizationConfig
- from .model import ReproducibilityConfig as ReproducibilityConfig
- from .model import SanityCheckingConfig as SanityCheckingConfig
- from .model import TrainerConfig as TrainerConfig
- from .nn import TypedModuleDict as TypedModuleDict
- from .nn import TypedModuleList as TypedModuleList
- from .optimizer import OptimizerConfig as OptimizerConfig
- from .runner import Runner as Runner
- from .runner import SnapshotConfig as SnapshotConfig
- from .snoop import snoop as snoop
- from .trainer import Trainer as Trainer
-
- PrimaryMetricConfig: TypeAlias = MetricConfig
- ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
nshtrainer/ll/_experimental.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshtrainer._experimental import * # noqa: F403
nshtrainer/ll/actsave.py DELETED
@@ -1,6 +0,0 @@
- from __future__ import annotations
-
- from nshutils.actsave import * # type: ignore # noqa: F403
-
- from nshtrainer.callbacks.actsave import ActSaveCallback as ActSaveCallback
- from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
nshtrainer/ll/callbacks.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshtrainer.callbacks import * # noqa: F403
nshtrainer/ll/config.py DELETED
@@ -1,6 +0,0 @@
- from __future__ import annotations
-
- from nshconfig import * # type: ignore # noqa: F403
- from nshconfig import Config as TypedConfig # type: ignore # noqa: F401
-
- _ = TypedConfig
nshtrainer/ll/data.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshtrainer.data import * # noqa: F403
nshtrainer/ll/log.py DELETED
@@ -1,5 +0,0 @@
- from __future__ import annotations
-
- from nshutils import init_python_logging as init_python_logging
- from nshutils import lovely as lovely
- from nshutils import pretty as pretty
nshtrainer/ll/lr_scheduler.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshtrainer.lr_scheduler import * # noqa: F403
nshtrainer/ll/model.py DELETED
@@ -1,21 +0,0 @@
- from __future__ import annotations
-
- from nshtrainer.model import * # noqa: F403
-
- from ..trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
- from ..trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
- from ..trainer._config import GradientClippingConfig as GradientClippingConfig
- from ..trainer._config import LoggingConfig as LoggingConfig
- from ..trainer._config import OptimizationConfig as OptimizationConfig
- from ..trainer._config import ReproducibilityConfig as ReproducibilityConfig
- from ..trainer._config import SanityCheckingConfig as SanityCheckingConfig
- from ..util._environment_info import (
-     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
- )
- from ..util._environment_info import EnvironmentConfig as EnvironmentConfig
- from ..util._environment_info import (
-     EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
- )
- from ..util._environment_info import (
-     EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
- )
nshtrainer/ll/nn.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshtrainer.nn import * # noqa: F403
nshtrainer/ll/optimizer.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshtrainer.optimizer import * # noqa: F403
nshtrainer/ll/runner.py DELETED
@@ -1,5 +0,0 @@
- from __future__ import annotations
-
- from nshrunner import SnapshotConfig as SnapshotConfig
-
- from nshtrainer.runner import * # type: ignore # noqa: F403
nshtrainer/ll/snapshot.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshsnap import * # pyright: ignore[reportWildcardImportFromLibrary] # noqa: F403
nshtrainer/ll/snoop.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshutils import snoop as snoop
nshtrainer/ll/trainer.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshtrainer.trainer import * # noqa: F403
nshtrainer/ll/typecheck.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshutils.typecheck import * # type: ignore # noqa: F403
nshtrainer/ll/util.py DELETED
@@ -1,3 +0,0 @@
- from __future__ import annotations
-
- from nshtrainer.util import * # noqa: F403
nshtrainer/model/config.py DELETED
@@ -1,218 +0,0 @@
- from __future__ import annotations
-
- import copy
- import logging
- import os
- import string
- import time
- from collections.abc import Iterable
- from pathlib import Path
- from typing import Annotated, Any, ClassVar
-
- import nshconfig as C
- import numpy as np
- import torch
- from typing_extensions import Self
-
- from .._directory import DirectoryConfig
- from ..callbacks.base import CallbackConfigBase
- from ..metrics import MetricConfig
- from ..trainer._config import TrainerConfig
- from ..util._environment_info import EnvironmentConfig
-
- log = logging.getLogger(__name__)
-
-
- class BaseConfig(C.Config):
-     id: str = C.Field(default_factory=lambda: BaseConfig.generate_id())
-     """ID of the run."""
-     name: str | None = None
-     """Run name."""
-     name_parts: list[str] = []
-     """A list of parts used to construct the run name. This is useful for constructing the run name dynamically."""
-     project: str | None = None
-     """Project name."""
-     tags: list[str] = []
-     """Tags for the run."""
-     notes: list[str] = []
-     """Human readable notes for the run."""
-
-     debug: bool = False
-     """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
-     environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = (
-         EnvironmentConfig.empty()
-     )
-     """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""
-
-     directory: DirectoryConfig = DirectoryConfig()
-     """Directory configuration options."""
-     trainer: TrainerConfig = TrainerConfig()
-     """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information."""
-
-     primary_metric: MetricConfig | None = None
-     """Primary metric configuration options. This is used in the following ways:
-     - To determine the best model checkpoint to save with the ModelCheckpoint callback.
-     - To monitor the primary metric during training and stop training based on the `early_stopping` configuration.
-     - For the ReduceLROnPlateau scheduler.
-     """
-
-     meta: dict[str, Any] = {}
-     """Additional metadata for this run. This can be used to store arbitrary data that is not part of the config schema."""
-
-     @property
-     def run_name(self) -> str:
-         parts = self.name_parts.copy()
-         if self.name is not None:
-             parts = [self.name] + parts
-         name = "-".join(parts)
-         if not name:
-             name = self.id
-         return name
-
-     def clone(self, with_new_id: bool = True) -> Self:
-         c = copy.deepcopy(self)
-         if with_new_id:
-             c.id = BaseConfig.generate_id()
-         return c
-
-     def subdirectory(self, subdirectory: str) -> Path:
-         return self.directory.resolve_subdirectory(self.id, subdirectory)
-
-     # region Helper methods
-     def fast_dev_run(self, value: int | bool = True, /):
-         """
-         Enables fast_dev_run mode for the trainer.
-         This will run the training loop for a specified number of batches,
-         if an integer is provided, or for a single batch if True is provided.
-         """
-         config = copy.deepcopy(self)
-         config.trainer.fast_dev_run = value
-         return config
-
-     def with_project_root_(self, project_root: str | Path | os.PathLike) -> Self:
-         """
-         Set the project root directory for the trainer.
-
-         Args:
-             project_root (Path): The base directory to use.
-
-         Returns:
-             self: The current instance of the class.
-         """
-         self.directory.project_root = Path(project_root)
-         return self
-
-     def reset_(
-         self,
-         *,
-         id: bool = True,
-         basic: bool = True,
-         project_root: bool = True,
-         environment: bool = True,
-         meta: bool = True,
-     ):
-         """
-         Reset the configuration object to its initial state.
-
-         Parameters:
-         - id (bool): If True, generate a new ID for the configuration object.
-         - basic (bool): If True, reset basic attributes like name, project, tags, and notes.
-         - project_root (bool): If True, reset the directory configuration to its initial state.
-         - environment (bool): If True, reset the environment configuration to its initial state.
-         - meta (bool): If True, reset the meta dictionary to an empty dictionary.
-
-         Returns:
-         - self: The updated configuration object.
-
-         """
-         if id:
-             self.id = self.generate_id()
-
-         if basic:
-             self.name = None
-             self.name_parts = []
-             self.project = None
-             self.tags = []
-             self.notes = []
-
-         if project_root:
-             self.directory = DirectoryConfig()
-
-         if environment:
-             self.environment = EnvironmentConfig.empty()
-
-         if meta:
-             self.meta = {}
-
-         return self
-
-     def concise_repr(self) -> str:
-         """Get a concise representation of the configuration object."""
-
-         def _truncate(s: str, max_len: int = 50):
-             return s if len(s) <= max_len else f"{s[:max_len - 3]}..."
-
-         cls_name = self.__class__.__name__
-
-         parts: list[str] = []
-         parts.append(f"name={self.run_name}")
-         if self.project:
-             parts.append(f"project={_truncate(self.project)}")
-
-         return f"{cls_name}({', '.join(parts)})"
-
-     # endregion
-
-     # region Seeding
-
-     _rng: ClassVar[np.random.Generator | None] = None
-
-     @staticmethod
-     def generate_id(*, length: int = 8) -> str:
-         """
-         Generate a random ID of specified length.
-
-         """
-         if (rng := BaseConfig._rng) is None:
-             rng = np.random.default_rng()
-
-         alphabet = list(string.ascii_lowercase + string.digits)
-
-         id = "".join(rng.choice(alphabet) for _ in range(length))
-         return id
-
-     @staticmethod
-     def set_seed(seed: int | None = None) -> None:
-         """
-         Set the seed for the random number generator.
-
-         Args:
-             seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
-
-         Returns:
-             None
-         """
-         if seed is None:
-             seed = int(time.time() * 1000)
-         log.critical(f"Seeding BaseConfig with seed {seed}")
-         BaseConfig._rng = np.random.default_rng(seed)
-
-     # endregion
-
-     @classmethod
-     def from_checkpoint(
-         cls,
-         path: str | Path,
-         *,
-         hparams_key: str = "hyper_parameters",
-     ):
-         ckpt = torch.load(path)
-         if (hparams := ckpt.get(hparams_key)) is None:
-             raise ValueError(
-                 f"The checkpoint does not contain the `{hparams_key}` attribute. "
-                 "Are you sure this is a valid Lightning checkpoint?"
-             )
-         return cls.model_validate(hparams)
-
-     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
-         yield from self.trainer._nshtrainer_all_callback_configs()
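For context, a minimal sketch of how the removed helpers above were typically used (illustrative only; `BaseConfig` is gone after this change, and the field values below are made up):

    # Illustrative usage of the deleted BaseConfig helpers (0.44.0 API).
    from nshtrainer.model.config import BaseConfig  # removed in 1.0.0b9

    config = BaseConfig(name="demo", project="example")
    print(config.run_name)           # "demo"
    debug = config.fast_dev_run(2)   # deep copy with trainer.fast_dev_run = 2
    fresh = config.clone()           # copy with a freshly generated id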
nshtrainer/runner.py DELETED
@@ -1,101 +0,0 @@
- from __future__ import annotations
-
- import copy
- import logging
- from collections.abc import Callable, Iterable, Mapping, Sequence
- from pathlib import Path
- from typing import Generic
-
- import nshrunner as nr
- from nshrunner._submit import screen
- from typing_extensions import TypeVar, TypeVarTuple, Unpack, deprecated, override
-
- from .model.config import BaseConfig
-
- TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
- TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
- TReturn = TypeVar("TReturn", infer_variance=True)
-
-
- @deprecated("Use nshrunner.Runner instead.")
- class Runner(
-     nr.Runner[TReturn, TConfig, Unpack[TArguments]],
-     Generic[TReturn, TConfig, Unpack[TArguments]],
- ):
-     @override
-     def __init__(
-         self,
-         run_fn: Callable[[TConfig, Unpack[TArguments]], TReturn],
-         config: nr.RunnerConfig | None = None,
-     ):
-         if config is None:
-             working_dir = Path.cwd() / "nshrunner"
-             working_dir.mkdir(exist_ok=True)
-
-             logging.warning(
-                 f"`config` is not provided. Using default working directory of {working_dir}."
-             )
-             config = nr.RunnerConfig(working_dir=working_dir)
-
-         super().__init__(run_fn, config)
-
-     def fast_dev_run(
-         self,
-         runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
-         n_batches: int = 1,
-         *,
-         env: Mapping[str, str] | None = None,
-     ):
-         runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
-         for args in runs:
-             config = copy.deepcopy(args[0])
-             config.trainer.fast_dev_run = n_batches
-             runs_updated.append((config, *args[1:]))
-         del runs
-
-         return self.local(runs_updated, env=env)
-
-     def fast_dev_run_generator(
-         self,
-         runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
-         n_batches: int = 1,
-         *,
-         env: Mapping[str, str] | None = None,
-     ):
-         runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
-         for args in runs:
-             config = copy.deepcopy(args[0])
-             config.trainer.fast_dev_run = n_batches
-             runs_updated.append((config, *args[1:]))
-         del runs
-
-         return self.local_generator(runs_updated, env=env)
-
-     def fast_dev_run_session(
-         self,
-         runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
-         options: screen.ScreenJobKwargs = {},
-         n_batches: int = 1,
-         *,
-         snapshot: nr.Snapshot,
-         setup_commands: Sequence[str] | None = None,
-         env: Mapping[str, str] | None = None,
-         activate_venv: bool = True,
-         print_command: bool = True,
-     ):
-         runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
-         for args in runs:
-             config = copy.deepcopy(args[0])
-             config.trainer.fast_dev_run = n_batches
-             runs_updated.append((config, *args[1:]))
-         del runs
-
-         return self.session(
-             runs_updated,
-             options,
-             snapshot=snapshot,
-             setup_commands=setup_commands,
-             env=env,
-             activate_venv=activate_venv,
-             print_command=print_command,
-         )
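Likewise, a minimal sketch of the deprecated `Runner.fast_dev_run` flow removed above (illustrative only; `train` and `my_config` are hypothetical):

    # Illustrative usage of the deleted, already-deprecated Runner wrapper (0.44.0).
    from pathlib import Path

    import nshrunner as nr

    from nshtrainer.model.config import BaseConfig  # removed in 1.0.0b9
    from nshtrainer.runner import Runner            # removed in 1.0.0b9

    def train(config: BaseConfig) -> None:  # hypothetical run function
        ...

    my_config = BaseConfig(name="demo")
    runner = Runner(train, nr.RunnerConfig(working_dir=Path.cwd() / "nshrunner"))
    # Each config is deep-copied with trainer.fast_dev_run = 2, then run locally.
    runner.fast_dev_run([(my_config,)], n_batches=2)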