nshtrainer 0.38.0__tar.gz → 0.40.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/PKG-INFO +1 -1
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/pyproject.toml +1 -1
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/__init__.py +1 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/_base.py +3 -2
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +13 -2
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/early_stopping.py +1 -1
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/data/__init__.py +1 -0
- nshtrainer-0.40.0/src/nshtrainer/data/datamodule.py +5 -0
- nshtrainer-0.40.0/src/nshtrainer/runner.py +99 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/_config.py +1 -1
- nshtrainer-0.38.0/src/nshtrainer/runner.py +0 -118
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/README.md +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/config.py +6 -6
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/module_list.py +2 -2
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/typing_utils.py +0 -0
|
@@ -8,6 +8,7 @@ from . import model as model
|
|
|
8
8
|
from . import nn as nn
|
|
9
9
|
from . import optimizer as optimizer
|
|
10
10
|
from . import profiler as profiler
|
|
11
|
+
from .data import LightningDataModuleBase as LightningDataModuleBase
|
|
11
12
|
from .metrics import MetricConfig as MetricConfig
|
|
12
13
|
from .model import BaseConfig as BaseConfig
|
|
13
14
|
from .model import LightningModuleBase as LightningModuleBase
|
|
@@ -41,7 +41,7 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
|
|
|
41
41
|
self,
|
|
42
42
|
root_config: "BaseConfig",
|
|
43
43
|
dirpath: Path,
|
|
44
|
-
) -> "CheckpointBase": ...
|
|
44
|
+
) -> "CheckpointBase | None": ...
|
|
45
45
|
|
|
46
46
|
@override
|
|
47
47
|
def create_callbacks(self, root_config):
|
|
@@ -50,7 +50,8 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
|
|
|
50
50
|
or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
|
|
51
51
|
)
|
|
52
52
|
|
|
53
|
-
|
|
53
|
+
if (callback := self.create_checkpoint(root_config, dirpath)) is not None:
|
|
54
|
+
yield callback
|
|
54
55
|
|
|
55
56
|
|
|
56
57
|
TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
|
{nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
|
@@ -20,15 +20,26 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
|
|
20
20
|
metric: MetricConfig | None = None
|
|
21
21
|
"""Metric to monitor, or `None` to use the default metric."""
|
|
22
22
|
|
|
23
|
+
throw_on_no_metric: bool = True
|
|
24
|
+
"""
|
|
25
|
+
Whether to throw an error if no metric is provided and no primary metric is found in the root config.
|
|
26
|
+
"""
|
|
27
|
+
|
|
23
28
|
@override
|
|
24
29
|
def create_checkpoint(self, root_config, dirpath):
|
|
25
30
|
# Resolve metric
|
|
26
31
|
if (metric := self.metric) is None and (
|
|
27
32
|
metric := root_config.primary_metric
|
|
28
33
|
) is None:
|
|
29
|
-
|
|
30
|
-
"No metric provided and no primary metric found in the root config"
|
|
34
|
+
error_msg = (
|
|
35
|
+
"No metric provided and no primary metric found in the root config. "
|
|
36
|
+
"Cannot create BestCheckpointCallback."
|
|
31
37
|
)
|
|
38
|
+
if self.throw_on_no_metric:
|
|
39
|
+
raise ValueError(error_msg)
|
|
40
|
+
else:
|
|
41
|
+
log.warning(error_msg)
|
|
42
|
+
return None
|
|
32
43
|
|
|
33
44
|
return BestCheckpoint(self, dirpath, metric)
|
|
34
45
|
|
|
@@ -51,7 +51,7 @@ class EarlyStoppingConfig(CallbackConfigBase):
|
|
|
51
51
|
metric := root_config.primary_metric
|
|
52
52
|
) is None:
|
|
53
53
|
raise ValueError(
|
|
54
|
-
"Either `metric` or `root_config.primary_metric` must be set."
|
|
54
|
+
"Either `metric` or `root_config.primary_metric` must be set to use EarlyStopping."
|
|
55
55
|
)
|
|
56
56
|
|
|
57
57
|
yield EarlyStopping(self, metric)
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
from collections.abc import Callable, Iterable, Mapping, Sequence
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Generic
|
|
6
|
+
|
|
7
|
+
import nshrunner as nr
|
|
8
|
+
from nshrunner._submit import screen
|
|
9
|
+
from typing_extensions import TypeVar, TypeVarTuple, Unpack, deprecated, override
|
|
10
|
+
|
|
11
|
+
from .model.config import BaseConfig
|
|
12
|
+
|
|
13
|
+
TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
|
|
14
|
+
TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
|
|
15
|
+
TReturn = TypeVar("TReturn", infer_variance=True)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@deprecated("Use nshrunner.Runner instead.")
|
|
19
|
+
class Runner(
|
|
20
|
+
nr.Runner[TReturn, TConfig, Unpack[TArguments]],
|
|
21
|
+
Generic[TReturn, TConfig, Unpack[TArguments]],
|
|
22
|
+
):
|
|
23
|
+
@override
|
|
24
|
+
def __init__(
|
|
25
|
+
self,
|
|
26
|
+
run_fn: Callable[[TConfig, Unpack[TArguments]], TReturn],
|
|
27
|
+
config: nr.RunnerConfig | None = None,
|
|
28
|
+
):
|
|
29
|
+
if config is None:
|
|
30
|
+
working_dir = Path.cwd() / "nshrunner"
|
|
31
|
+
working_dir.mkdir(exist_ok=True)
|
|
32
|
+
|
|
33
|
+
logging.warning(
|
|
34
|
+
f"`config` is not provided. Using default working directory of {working_dir}."
|
|
35
|
+
)
|
|
36
|
+
config = nr.RunnerConfig(working_dir=working_dir)
|
|
37
|
+
|
|
38
|
+
super().__init__(run_fn, config)
|
|
39
|
+
|
|
40
|
+
def fast_dev_run(
|
|
41
|
+
self,
|
|
42
|
+
runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
|
|
43
|
+
n_batches: int = 1,
|
|
44
|
+
*,
|
|
45
|
+
env: Mapping[str, str] | None = None,
|
|
46
|
+
):
|
|
47
|
+
runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
|
|
48
|
+
for args in runs:
|
|
49
|
+
config = copy.deepcopy(args[0])
|
|
50
|
+
config.trainer.fast_dev_run = n_batches
|
|
51
|
+
runs_updated.append((config, *args[1:]))
|
|
52
|
+
del runs
|
|
53
|
+
|
|
54
|
+
return self.local(runs_updated, env=env)
|
|
55
|
+
|
|
56
|
+
def fast_dev_run_generator(
|
|
57
|
+
self,
|
|
58
|
+
runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
|
|
59
|
+
n_batches: int = 1,
|
|
60
|
+
*,
|
|
61
|
+
env: Mapping[str, str] | None = None,
|
|
62
|
+
):
|
|
63
|
+
runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
|
|
64
|
+
for args in runs:
|
|
65
|
+
config = copy.deepcopy(args[0])
|
|
66
|
+
config.trainer.fast_dev_run = n_batches
|
|
67
|
+
runs_updated.append((config, *args[1:]))
|
|
68
|
+
del runs
|
|
69
|
+
|
|
70
|
+
return self.local_generator(runs_updated, env=env)
|
|
71
|
+
|
|
72
|
+
def fast_dev_run_session(
|
|
73
|
+
self,
|
|
74
|
+
runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
|
|
75
|
+
options: screen.ScreenJobKwargs = {},
|
|
76
|
+
n_batches: int = 1,
|
|
77
|
+
*,
|
|
78
|
+
snapshot: nr.Snapshot,
|
|
79
|
+
setup_commands: Sequence[str] | None = None,
|
|
80
|
+
env: Mapping[str, str] | None = None,
|
|
81
|
+
activate_venv: bool = True,
|
|
82
|
+
print_command: bool = True,
|
|
83
|
+
):
|
|
84
|
+
runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
|
|
85
|
+
for args in runs:
|
|
86
|
+
config = copy.deepcopy(args[0])
|
|
87
|
+
config.trainer.fast_dev_run = n_batches
|
|
88
|
+
runs_updated.append((config, *args[1:]))
|
|
89
|
+
del runs
|
|
90
|
+
|
|
91
|
+
return self.session(
|
|
92
|
+
runs_updated,
|
|
93
|
+
options,
|
|
94
|
+
snapshot=snapshot,
|
|
95
|
+
setup_commands=setup_commands,
|
|
96
|
+
env=env,
|
|
97
|
+
activate_venv=activate_venv,
|
|
98
|
+
print_command=print_command,
|
|
99
|
+
)
|
|
@@ -263,7 +263,7 @@ class CheckpointSavingConfig(CallbackConfigBase):
|
|
|
263
263
|
"""Enable checkpoint saving."""
|
|
264
264
|
|
|
265
265
|
checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
|
|
266
|
-
BestCheckpointCallbackConfig(),
|
|
266
|
+
BestCheckpointCallbackConfig(throw_on_no_metric=False),
|
|
267
267
|
LastCheckpointCallbackConfig(),
|
|
268
268
|
OnExceptionCheckpointCallbackConfig(),
|
|
269
269
|
]
|
|
@@ -1,118 +0,0 @@
|
|
|
1
|
-
import copy
|
|
2
|
-
import functools
|
|
3
|
-
from collections.abc import Callable, Iterable, Mapping, Sequence
|
|
4
|
-
from typing import Generic
|
|
5
|
-
|
|
6
|
-
from nshrunner import RunInfo, Snapshot
|
|
7
|
-
from nshrunner import Runner as _Runner
|
|
8
|
-
from nshrunner._submit import screen
|
|
9
|
-
from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
|
|
10
|
-
|
|
11
|
-
from .model.config import BaseConfig
|
|
12
|
-
|
|
13
|
-
TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
|
|
14
|
-
TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
|
|
15
|
-
TReturn = TypeVar("TReturn", infer_variance=True)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class Runner(
|
|
19
|
-
_Runner[TReturn, TConfig, Unpack[TArguments]],
|
|
20
|
-
Generic[TReturn, TConfig, Unpack[TArguments]],
|
|
21
|
-
):
|
|
22
|
-
@override
|
|
23
|
-
@classmethod
|
|
24
|
-
def default_validate_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> None:
|
|
25
|
-
super().default_validate_fn(config, *args)
|
|
26
|
-
|
|
27
|
-
@override
|
|
28
|
-
@classmethod
|
|
29
|
-
def default_info_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> RunInfo:
|
|
30
|
-
run_info = super().default_info_fn(config, *args)
|
|
31
|
-
return {
|
|
32
|
-
**run_info,
|
|
33
|
-
"id": config.id,
|
|
34
|
-
"base_dir": config.directory.project_root,
|
|
35
|
-
}
|
|
36
|
-
|
|
37
|
-
def _fast_dev_run_transform(
|
|
38
|
-
self,
|
|
39
|
-
config: TConfig,
|
|
40
|
-
*args: Unpack[TArguments],
|
|
41
|
-
n_batches: int,
|
|
42
|
-
):
|
|
43
|
-
config = copy.deepcopy(config)
|
|
44
|
-
config.trainer.fast_dev_run = n_batches
|
|
45
|
-
return (config, *args)
|
|
46
|
-
|
|
47
|
-
def fast_dev_run(
|
|
48
|
-
self,
|
|
49
|
-
runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
|
|
50
|
-
n_batches: int = 1,
|
|
51
|
-
*,
|
|
52
|
-
env: Mapping[str, str] | None = None,
|
|
53
|
-
transforms: list[
|
|
54
|
-
Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
|
|
55
|
-
]
|
|
56
|
-
| None = None,
|
|
57
|
-
):
|
|
58
|
-
transforms = transforms or []
|
|
59
|
-
transforms.append(
|
|
60
|
-
functools.partial(self._fast_dev_run_transform, n_batches=n_batches)
|
|
61
|
-
)
|
|
62
|
-
return self.local(
|
|
63
|
-
runs,
|
|
64
|
-
env=env,
|
|
65
|
-
transforms=transforms,
|
|
66
|
-
)
|
|
67
|
-
|
|
68
|
-
def fast_dev_run_generator(
|
|
69
|
-
self,
|
|
70
|
-
runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
|
|
71
|
-
n_batches: int = 1,
|
|
72
|
-
*,
|
|
73
|
-
env: Mapping[str, str] | None = None,
|
|
74
|
-
transforms: list[
|
|
75
|
-
Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
|
|
76
|
-
]
|
|
77
|
-
| None = None,
|
|
78
|
-
):
|
|
79
|
-
transforms = transforms or []
|
|
80
|
-
transforms.append(
|
|
81
|
-
functools.partial(self._fast_dev_run_transform, n_batches=n_batches)
|
|
82
|
-
)
|
|
83
|
-
return self.local_generator(
|
|
84
|
-
runs,
|
|
85
|
-
env=env,
|
|
86
|
-
transforms=transforms,
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
def fast_dev_run_session(
|
|
90
|
-
self,
|
|
91
|
-
runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
|
|
92
|
-
options: screen.ScreenJobKwargs = {},
|
|
93
|
-
n_batches: int = 1,
|
|
94
|
-
*,
|
|
95
|
-
snapshot: Snapshot,
|
|
96
|
-
setup_commands: Sequence[str] | None = None,
|
|
97
|
-
env: Mapping[str, str] | None = None,
|
|
98
|
-
transforms: list[
|
|
99
|
-
Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
|
|
100
|
-
]
|
|
101
|
-
| None = None,
|
|
102
|
-
activate_venv: bool = True,
|
|
103
|
-
print_command: bool = True,
|
|
104
|
-
):
|
|
105
|
-
transforms = transforms or []
|
|
106
|
-
transforms.append(
|
|
107
|
-
functools.partial(self._fast_dev_run_transform, n_batches=n_batches)
|
|
108
|
-
)
|
|
109
|
-
return self.session(
|
|
110
|
-
runs,
|
|
111
|
-
options,
|
|
112
|
-
snapshot=snapshot,
|
|
113
|
-
setup_commands=setup_commands,
|
|
114
|
-
env=env,
|
|
115
|
-
transforms=transforms,
|
|
116
|
-
activate_venv=activate_venv,
|
|
117
|
-
print_command=print_command,
|
|
118
|
-
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -95,10 +95,10 @@ from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearit
|
|
|
95
95
|
from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
|
|
96
96
|
from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
|
|
97
97
|
from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
|
98
|
-
from nshtrainer.nn.nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
|
|
99
98
|
from nshtrainer.nn.nonlinearity import (
|
|
100
99
|
SigmoidNonlinearityConfig as SigmoidNonlinearityConfig,
|
|
101
100
|
)
|
|
101
|
+
from nshtrainer.nn.nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
|
|
102
102
|
from nshtrainer.nn.nonlinearity import (
|
|
103
103
|
SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig,
|
|
104
104
|
)
|
|
@@ -137,13 +137,13 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
|
|
|
137
137
|
from nshtrainer.trainer._config import ReproducibilityConfig as ReproducibilityConfig
|
|
138
138
|
from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
|
|
139
139
|
from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
|
|
140
|
-
from nshtrainer.util._environment_info import (
|
|
141
|
-
EnvironmentCUDAConfig as EnvironmentCUDAConfig,
|
|
142
|
-
)
|
|
143
140
|
from nshtrainer.util._environment_info import (
|
|
144
141
|
EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
|
|
145
142
|
)
|
|
146
143
|
from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
|
|
144
|
+
from nshtrainer.util._environment_info import (
|
|
145
|
+
EnvironmentCUDAConfig as EnvironmentCUDAConfig,
|
|
146
|
+
)
|
|
147
147
|
from nshtrainer.util._environment_info import (
|
|
148
148
|
EnvironmentGPUConfig as EnvironmentGPUConfig,
|
|
149
149
|
)
|
|
@@ -151,10 +151,10 @@ from nshtrainer.util._environment_info import (
|
|
|
151
151
|
EnvironmentHardwareConfig as EnvironmentHardwareConfig,
|
|
152
152
|
)
|
|
153
153
|
from nshtrainer.util._environment_info import (
|
|
154
|
-
|
|
154
|
+
EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
|
|
155
155
|
)
|
|
156
156
|
from nshtrainer.util._environment_info import (
|
|
157
|
-
|
|
157
|
+
EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
|
|
158
158
|
)
|
|
159
159
|
from nshtrainer.util._environment_info import (
|
|
160
160
|
EnvironmentPackageConfig as EnvironmentPackageConfig,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -12,10 +12,10 @@ class TypedModuleList(nn.ModuleList, Generic[TModule]):
|
|
|
12
12
|
super().__init__(modules)
|
|
13
13
|
|
|
14
14
|
@overload
|
|
15
|
-
def __getitem__(self, idx:
|
|
15
|
+
def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ...
|
|
16
16
|
|
|
17
17
|
@overload
|
|
18
|
-
def __getitem__(self, idx:
|
|
18
|
+
def __getitem__(self, idx: int) -> TModule: ...
|
|
19
19
|
|
|
20
20
|
@override
|
|
21
21
|
def __getitem__(self, idx: int | slice) -> TModule | "TypedModuleList[TModule]":
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|