nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +52 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.1.dist-info/RECORD +0 -162
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
nshtrainer/config/model/config/__init__.py
DELETED
@@ -1,37 +0,0 @@
-from __future__ import annotations
-
-__codegen__ = True
-
-from typing import TYPE_CHECKING
-
-# Config/alias imports
-
-if TYPE_CHECKING:
-    from nshtrainer.model.config import BaseConfig as BaseConfig
-    from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
-    from nshtrainer.model.config import DirectoryConfig as DirectoryConfig
-    from nshtrainer.model.config import EnvironmentConfig as EnvironmentConfig
-    from nshtrainer.model.config import MetricConfig as MetricConfig
-    from nshtrainer.model.config import TrainerConfig as TrainerConfig
-else:
-
-    def __getattr__(name):
-        import importlib
-
-        if name in globals():
-            return globals()[name]
-        if name == "MetricConfig":
-            return importlib.import_module("nshtrainer.model.config").MetricConfig
-        if name == "TrainerConfig":
-            return importlib.import_module("nshtrainer.model.config").TrainerConfig
-        if name == "BaseConfig":
-            return importlib.import_module("nshtrainer.model.config").BaseConfig
-        if name == "EnvironmentConfig":
-            return importlib.import_module("nshtrainer.model.config").EnvironmentConfig
-        if name == "DirectoryConfig":
-            return importlib.import_module("nshtrainer.model.config").DirectoryConfig
-        if name == "CallbackConfigBase":
-            return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
-        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-# Submodule exports
nshtrainer/config/model/mixins/logger/__init__.py
DELETED
@@ -1,22 +0,0 @@
-from __future__ import annotations
-
-__codegen__ = True
-
-from typing import TYPE_CHECKING
-
-# Config/alias imports
-
-if TYPE_CHECKING:
-    from nshtrainer.model.mixins.logger import BaseConfig as BaseConfig
-else:
-
-    def __getattr__(name):
-        import importlib
-
-        if name in globals():
-            return globals()[name]
-        if name == "BaseConfig":
-            return importlib.import_module("nshtrainer.model.mixins.logger").BaseConfig
-        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-# Submodule exports
nshtrainer/config/runner/__init__.py
DELETED
@@ -1,22 +0,0 @@
-from __future__ import annotations
-
-__codegen__ = True
-
-from typing import TYPE_CHECKING
-
-# Config/alias imports
-
-if TYPE_CHECKING:
-    from nshtrainer.runner import BaseConfig as BaseConfig
-else:
-
-    def __getattr__(name):
-        import importlib
-
-        if name in globals():
-            return globals()[name]
-        if name == "BaseConfig":
-            return importlib.import_module("nshtrainer.runner").BaseConfig
-        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-# Submodule exports
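All three deleted `nshtrainer/config/**` shims above are instances of the same codegen pattern: re-exports are declared under `TYPE_CHECKING` for static analyzers, while at runtime a module-level `__getattr__` (PEP 562) performs the import lazily on first attribute access. Here is a minimal self-contained sketch of the pattern; the `json`/`JSONDecoder` names are placeholders, not nshtrainer APIs:

```python
# lazy_exports.py -- sketch of the PEP 562 lazy re-export pattern used above.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never executed at runtime.
    from json import JSONDecoder as JSONDecoder
else:

    def __getattr__(name):
        # Module-level __getattr__ runs only when normal lookup fails, so
        # the target module is imported at most once, on first access.
        import importlib

        if name in globals():
            return globals()[name]
        if name == "JSONDecoder":
            return importlib.import_module("json").JSONDecoder
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
```

Importing such a shim is nearly free; the real import cost is paid only when an attribute like `JSONDecoder` is first touched.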
nshtrainer/ll/__init__.py
DELETED
@@ -1,59 +0,0 @@
-from __future__ import annotations
-
-from typing import TypeAlias
-
-from . import _experimental as _experimental
-from . import actsave as actsave
-from . import callbacks as callbacks
-from . import data as data
-from . import lr_scheduler as lr_scheduler
-from . import model as model
-from . import nn as nn
-from . import optimizer as optimizer
-from . import snapshot as snapshot
-from . import typecheck as typecheck
-from .actsave import ActLoad as ActLoad
-from .actsave import ActSave as ActSave
-from .config import MISSING as MISSING
-from .config import AllowMissing as AllowMissing
-from .config import Field as Field
-from .config import MissingField as MissingField
-from .config import PrivateAttr as PrivateAttr
-from .config import TypedConfig as TypedConfig
-from .data import dataset_transform as dataset_transform
-from .log import init_python_logging as init_python_logging
-from .log import lovely as lovely
-from .log import pretty as pretty
-from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
-from .model import BaseConfig as BaseConfig
-from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
-from .model import CheckpointSavingConfig as CheckpointSavingConfig
-from .model import DirectoryConfig as DirectoryConfig
-from .model import (
-    EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
-)
-from .model import EnvironmentConfig as EnvironmentConfig
-from .model import (
-    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
-)
-from .model import (
-    EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
-)
-from .model import GradientClippingConfig as GradientClippingConfig
-from .model import LightningModuleBase as LightningModuleBase
-from .model import LoggingConfig as LoggingConfig
-from .model import MetricConfig as MetricConfig
-from .model import OptimizationConfig as OptimizationConfig
-from .model import ReproducibilityConfig as ReproducibilityConfig
-from .model import SanityCheckingConfig as SanityCheckingConfig
-from .model import TrainerConfig as TrainerConfig
-from .nn import TypedModuleDict as TypedModuleDict
-from .nn import TypedModuleList as TypedModuleList
-from .optimizer import OptimizerConfig as OptimizerConfig
-from .runner import Runner as Runner
-from .runner import SnapshotConfig as SnapshotConfig
-from .snoop import snoop as snoop
-from .trainer import Trainer as Trainer
-
-PrimaryMetricConfig: TypeAlias = MetricConfig
-ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
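Everything under `nshtrainer/ll/` is deleted in 1.0.0b10, and as the hunk above shows, `ll` was purely an alias layer of re-exports. Code importing through it will now raise `ImportError`; the likely migration is to the top-level package (an assumption inferred from these re-exports, not documented guidance):

```python
# Before (0.44.x), via the now-deleted alias package:
#   from nshtrainer.ll import Trainer, MetricConfig
# After (1.0.0b10), assuming these names remain top-level exports:
from nshtrainer import Trainer, MetricConfig
```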
nshtrainer/ll/_experimental.py
DELETED
nshtrainer/ll/actsave.py
DELETED
nshtrainer/ll/callbacks.py
DELETED
nshtrainer/ll/config.py
DELETED
nshtrainer/ll/data.py
DELETED
nshtrainer/ll/log.py
DELETED
nshtrainer/ll/lr_scheduler.py
DELETED
nshtrainer/ll/model.py
DELETED
@@ -1,21 +0,0 @@
-from __future__ import annotations
-
-from nshtrainer.model import *  # noqa: F403
-
-from ..trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
-from ..trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
-from ..trainer._config import GradientClippingConfig as GradientClippingConfig
-from ..trainer._config import LoggingConfig as LoggingConfig
-from ..trainer._config import OptimizationConfig as OptimizationConfig
-from ..trainer._config import ReproducibilityConfig as ReproducibilityConfig
-from ..trainer._config import SanityCheckingConfig as SanityCheckingConfig
-from ..util._environment_info import (
-    EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
-)
-from ..util._environment_info import EnvironmentConfig as EnvironmentConfig
-from ..util._environment_info import (
-    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
-)
-from ..util._environment_info import (
-    EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
-)
nshtrainer/ll/nn.py
DELETED
nshtrainer/ll/optimizer.py
DELETED
nshtrainer/ll/runner.py
DELETED
nshtrainer/ll/snapshot.py
DELETED
nshtrainer/ll/snoop.py
DELETED
nshtrainer/ll/trainer.py
DELETED
nshtrainer/ll/typecheck.py
DELETED
nshtrainer/ll/util.py
DELETED
nshtrainer/model/config.py
DELETED
@@ -1,218 +0,0 @@
-from __future__ import annotations
-
-import copy
-import logging
-import os
-import string
-import time
-from collections.abc import Iterable
-from pathlib import Path
-from typing import Annotated, Any, ClassVar
-
-import nshconfig as C
-import numpy as np
-import torch
-from typing_extensions import Self
-
-from .._directory import DirectoryConfig
-from ..callbacks.base import CallbackConfigBase
-from ..metrics import MetricConfig
-from ..trainer._config import TrainerConfig
-from ..util._environment_info import EnvironmentConfig
-
-log = logging.getLogger(__name__)
-
-
-class BaseConfig(C.Config):
-    id: str = C.Field(default_factory=lambda: BaseConfig.generate_id())
-    """ID of the run."""
-    name: str | None = None
-    """Run name."""
-    name_parts: list[str] = []
-    """A list of parts used to construct the run name. This is useful for constructing the run name dynamically."""
-    project: str | None = None
-    """Project name."""
-    tags: list[str] = []
-    """Tags for the run."""
-    notes: list[str] = []
-    """Human readable notes for the run."""
-
-    debug: bool = False
-    """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
-    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = (
-        EnvironmentConfig.empty()
-    )
-    """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""
-
-    directory: DirectoryConfig = DirectoryConfig()
-    """Directory configuration options."""
-    trainer: TrainerConfig = TrainerConfig()
-    """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information."""
-
-    primary_metric: MetricConfig | None = None
-    """Primary metric configuration options. This is used in the following ways:
-    - To determine the best model checkpoint to save with the ModelCheckpoint callback.
-    - To monitor the primary metric during training and stop training based on the `early_stopping` configuration.
-    - For the ReduceLROnPlateau scheduler.
-    """
-
-    meta: dict[str, Any] = {}
-    """Additional metadata for this run. This can be used to store arbitrary data that is not part of the config schema."""
-
-    @property
-    def run_name(self) -> str:
-        parts = self.name_parts.copy()
-        if self.name is not None:
-            parts = [self.name] + parts
-        name = "-".join(parts)
-        if not name:
-            name = self.id
-        return name
-
-    def clone(self, with_new_id: bool = True) -> Self:
-        c = copy.deepcopy(self)
-        if with_new_id:
-            c.id = BaseConfig.generate_id()
-        return c
-
-    def subdirectory(self, subdirectory: str) -> Path:
-        return self.directory.resolve_subdirectory(self.id, subdirectory)
-
-    # region Helper methods
-    def fast_dev_run(self, value: int | bool = True, /):
-        """
-        Enables fast_dev_run mode for the trainer.
-        This will run the training loop for a specified number of batches,
-        if an integer is provided, or for a single batch if True is provided.
-        """
-        config = copy.deepcopy(self)
-        config.trainer.fast_dev_run = value
-        return config
-
-    def with_project_root_(self, project_root: str | Path | os.PathLike) -> Self:
-        """
-        Set the project root directory for the trainer.
-
-        Args:
-            project_root (Path): The base directory to use.
-
-        Returns:
-            self: The current instance of the class.
-        """
-        self.directory.project_root = Path(project_root)
-        return self
-
-    def reset_(
-        self,
-        *,
-        id: bool = True,
-        basic: bool = True,
-        project_root: bool = True,
-        environment: bool = True,
-        meta: bool = True,
-    ):
-        """
-        Reset the configuration object to its initial state.
-
-        Parameters:
-        - id (bool): If True, generate a new ID for the configuration object.
-        - basic (bool): If True, reset basic attributes like name, project, tags, and notes.
-        - project_root (bool): If True, reset the directory configuration to its initial state.
-        - environment (bool): If True, reset the environment configuration to its initial state.
-        - meta (bool): If True, reset the meta dictionary to an empty dictionary.
-
-        Returns:
-        - self: The updated configuration object.
-
-        """
-        if id:
-            self.id = self.generate_id()
-
-        if basic:
-            self.name = None
-            self.name_parts = []
-            self.project = None
-            self.tags = []
-            self.notes = []
-
-        if project_root:
-            self.directory = DirectoryConfig()
-
-        if environment:
-            self.environment = EnvironmentConfig.empty()
-
-        if meta:
-            self.meta = {}
-
-        return self
-
-    def concise_repr(self) -> str:
-        """Get a concise representation of the configuration object."""
-
-        def _truncate(s: str, max_len: int = 50):
-            return s if len(s) <= max_len else f"{s[:max_len - 3]}..."
-
-        cls_name = self.__class__.__name__
-
-        parts: list[str] = []
-        parts.append(f"name={self.run_name}")
-        if self.project:
-            parts.append(f"project={_truncate(self.project)}")
-
-        return f"{cls_name}({', '.join(parts)})"
-
-    # endregion
-
-    # region Seeding
-
-    _rng: ClassVar[np.random.Generator | None] = None
-
-    @staticmethod
-    def generate_id(*, length: int = 8) -> str:
-        """
-        Generate a random ID of specified length.
-
-        """
-        if (rng := BaseConfig._rng) is None:
-            rng = np.random.default_rng()
-
-        alphabet = list(string.ascii_lowercase + string.digits)
-
-        id = "".join(rng.choice(alphabet) for _ in range(length))
-        return id
-
-    @staticmethod
-    def set_seed(seed: int | None = None) -> None:
-        """
-        Set the seed for the random number generator.
-
-        Args:
-            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
-
-        Returns:
-            None
-        """
-        if seed is None:
-            seed = int(time.time() * 1000)
-        log.critical(f"Seeding BaseConfig with seed {seed}")
-        BaseConfig._rng = np.random.default_rng(seed)
-
-    # endregion
-
-    @classmethod
-    def from_checkpoint(
-        cls,
-        path: str | Path,
-        *,
-        hparams_key: str = "hyper_parameters",
-    ):
-        ckpt = torch.load(path)
-        if (hparams := ckpt.get(hparams_key)) is None:
-            raise ValueError(
-                f"The checkpoint does not contain the `{hparams_key}` attribute. "
-                "Are you sure this is a valid Lightning checkpoint?"
-            )
-        return cls.model_validate(hparams)
-
-    def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
-        yield from self.trainer._nshtrainer_all_callback_configs()
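For context, the deleted `BaseConfig` was the top-level run container in 0.44.x. A rough usage sketch, derived only from the deleted code above (field values are illustrative):

```python
# Sketch of 0.44.x BaseConfig behavior, per the deleted code above.
from nshtrainer.model.config import BaseConfig

config = BaseConfig(name="baseline", name_parts=["lr1e-4", "seed0"], project="demo")
print(config.id)        # 8 random chars drawn from [a-z0-9], e.g. "k3v9q2xw"
print(config.run_name)  # "baseline-lr1e-4-seed0"; falls back to `id` when empty

clone = config.clone()        # deep copy with a freshly generated `id`
fdr = config.fast_dev_run(2)  # copy whose trainer.fast_dev_run = 2
BaseConfig.set_seed(42)       # seeds the class-level RNG behind generate_id()
```

Judging by the file list above, its responsibilities appear to move toward the trainer config side in 1.0.0b10 (`nshtrainer/trainer/_config.py +164 -17`).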
nshtrainer/runner.py
DELETED
@@ -1,101 +0,0 @@
-from __future__ import annotations
-
-import copy
-import logging
-from collections.abc import Callable, Iterable, Mapping, Sequence
-from pathlib import Path
-from typing import Generic
-
-import nshrunner as nr
-from nshrunner._submit import screen
-from typing_extensions import TypeVar, TypeVarTuple, Unpack, deprecated, override
-
-from .model.config import BaseConfig
-
-TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
-TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
-TReturn = TypeVar("TReturn", infer_variance=True)
-
-
-@deprecated("Use nshrunner.Runner instead.")
-class Runner(
-    nr.Runner[TReturn, TConfig, Unpack[TArguments]],
-    Generic[TReturn, TConfig, Unpack[TArguments]],
-):
-    @override
-    def __init__(
-        self,
-        run_fn: Callable[[TConfig, Unpack[TArguments]], TReturn],
-        config: nr.RunnerConfig | None = None,
-    ):
-        if config is None:
-            working_dir = Path.cwd() / "nshrunner"
-            working_dir.mkdir(exist_ok=True)
-
-            logging.warning(
-                f"`config` is not provided. Using default working directory of {working_dir}."
-            )
-            config = nr.RunnerConfig(working_dir=working_dir)
-
-        super().__init__(run_fn, config)
-
-    def fast_dev_run(
-        self,
-        runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
-        n_batches: int = 1,
-        *,
-        env: Mapping[str, str] | None = None,
-    ):
-        runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
-        for args in runs:
-            config = copy.deepcopy(args[0])
-            config.trainer.fast_dev_run = n_batches
-            runs_updated.append((config, *args[1:]))
-        del runs
-
-        return self.local(runs_updated, env=env)
-
-    def fast_dev_run_generator(
-        self,
-        runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
-        n_batches: int = 1,
-        *,
-        env: Mapping[str, str] | None = None,
-    ):
-        runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
-        for args in runs:
-            config = copy.deepcopy(args[0])
-            config.trainer.fast_dev_run = n_batches
-            runs_updated.append((config, *args[1:]))
-        del runs
-
-        return self.local_generator(runs_updated, env=env)
-
-    def fast_dev_run_session(
-        self,
-        runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
-        options: screen.ScreenJobKwargs = {},
-        n_batches: int = 1,
-        *,
-        snapshot: nr.Snapshot,
-        setup_commands: Sequence[str] | None = None,
-        env: Mapping[str, str] | None = None,
-        activate_venv: bool = True,
-        print_command: bool = True,
-    ):
-        runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
-        for args in runs:
-            config = copy.deepcopy(args[0])
-            config.trainer.fast_dev_run = n_batches
-            runs_updated.append((config, *args[1:]))
-        del runs
-
-        return self.session(
-            runs_updated,
-            options,
-            snapshot=snapshot,
-            setup_commands=setup_commands,
-            env=env,
-            activate_venv=activate_venv,
-            print_command=print_command,
-        )
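The deleted `Runner` was already deprecated in favor of `nshrunner.Runner`; each `fast_dev_run*` helper deep-copied the run configs, forced `trainer.fast_dev_run = n_batches`, and dispatched to the matching `nshrunner` entry point (`local`, `local_generator`, or `session`). A sketch of how it was called in 0.44.x, based only on the code above:

```python
# Sketch of the deprecated Runner API removed in 1.0.0b10 (0.44.x only).
from nshtrainer.model.config import BaseConfig
from nshtrainer.runner import Runner


def train(config: BaseConfig) -> None:
    ...  # stand-in run function; a real one would build and fit a Trainer


runner = Runner(train)  # no RunnerConfig given: warns, defaults to ./nshrunner
# Deep-copies each config, sets trainer.fast_dev_run = 2, then runs locally:
runner.fast_dev_run([(BaseConfig(),)], n_batches=2)
```

On 1.0.0b10 the same flow goes through `nshrunner.Runner` directly, as the deprecation message instructed.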