nshtrainer 0.38.0__tar.gz → 0.40.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/PKG-INFO +1 -1
  2. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/pyproject.toml +1 -1
  3. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/__init__.py +1 -0
  4. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/_base.py +3 -2
  5. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +13 -2
  6. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/early_stopping.py +1 -1
  7. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/data/__init__.py +1 -0
  8. nshtrainer-0.40.0/src/nshtrainer/data/datamodule.py +5 -0
  9. nshtrainer-0.40.0/src/nshtrainer/runner.py +99 -0
  10. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/_config.py +1 -1
  11. nshtrainer-0.38.0/src/nshtrainer/runner.py +0 -118
  12. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/README.md +0 -0
  13. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_callback.py +0 -0
  14. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
  15. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  16. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  17. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_directory.py +0 -0
  18. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  19. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/_hf_hub.py +0 -0
  20. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/__init__.py +0 -0
  21. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  22. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  23. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/base.py +0 -0
  24. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  25. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  26. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  27. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  28. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  29. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/ema.py +0 -0
  30. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  31. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  32. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/interval.py +0 -0
  33. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  34. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  35. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  36. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  37. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  38. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  39. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/timer.py +0 -0
  40. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
  41. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  42. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/config.py +6 -6
  43. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  44. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/data/transform.py +0 -0
  45. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/__init__.py +0 -0
  46. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/_experimental.py +0 -0
  47. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/actsave.py +0 -0
  48. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/callbacks.py +0 -0
  49. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/config.py +0 -0
  50. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/data.py +0 -0
  51. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/log.py +0 -0
  52. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  53. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/model.py +0 -0
  54. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/nn.py +0 -0
  55. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/optimizer.py +0 -0
  56. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/runner.py +0 -0
  57. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/snapshot.py +0 -0
  58. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/snoop.py +0 -0
  59. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/trainer.py +0 -0
  60. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/typecheck.py +0 -0
  61. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/ll/util.py +0 -0
  62. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/__init__.py +0 -0
  63. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/_base.py +0 -0
  64. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/csv.py +0 -0
  65. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  66. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/loggers/wandb.py +0 -0
  67. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  68. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  69. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  70. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  71. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/metrics/__init__.py +0 -0
  72. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/metrics/_config.py +0 -0
  73. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/__init__.py +0 -0
  74. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/base.py +0 -0
  75. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/config.py +0 -0
  76. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/mixins/callback.py +0 -0
  77. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/model/mixins/logger.py +0 -0
  78. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/__init__.py +0 -0
  79. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/mlp.py +0 -0
  80. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/module_dict.py +0 -0
  81. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/module_list.py +2 -2
  82. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  83. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/optimizer.py +0 -0
  84. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/__init__.py +0 -0
  85. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/_base.py +0 -0
  86. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/advanced.py +0 -0
  87. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/pytorch.py +0 -0
  88. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/profiler/simple.py +0 -0
  89. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  90. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/__init__.py +0 -0
  91. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  92. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  93. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  94. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/trainer/trainer.py +0 -0
  95. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/_environment_info.py +0 -0
  96. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/_useful_types.py +0 -0
  97. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/bf16.py +0 -0
  98. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/config/__init__.py +0 -0
  99. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/config/dtype.py +0 -0
  100. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/config/duration.py +0 -0
  101. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/environment.py +0 -0
  102. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/path.py +0 -0
  103. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/seed.py +0 -0
  104. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/slurm.py +0 -0
  105. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/typed.py +0 -0
  106. {nshtrainer-0.38.0 → nshtrainer-0.40.0}/src/nshtrainer/util/typing_utils.py +0 -0
--- nshtrainer-0.38.0/PKG-INFO
+++ nshtrainer-0.40.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.38.0
+Version: 0.40.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
--- nshtrainer-0.38.0/pyproject.toml
+++ nshtrainer-0.40.0/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.38.0"
+version = "0.40.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
--- nshtrainer-0.38.0/src/nshtrainer/__init__.py
+++ nshtrainer-0.40.0/src/nshtrainer/__init__.py
@@ -8,6 +8,7 @@ from . import model as model
 from . import nn as nn
 from . import optimizer as optimizer
 from . import profiler as profiler
+from .data import LightningDataModuleBase as LightningDataModuleBase
 from .metrics import MetricConfig as MetricConfig
 from .model import BaseConfig as BaseConfig
 from .model import LightningModuleBase as LightningModuleBase
--- nshtrainer-0.38.0/src/nshtrainer/callbacks/checkpoint/_base.py
+++ nshtrainer-0.40.0/src/nshtrainer/callbacks/checkpoint/_base.py
@@ -41,7 +41,7 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
         self,
         root_config: "BaseConfig",
         dirpath: Path,
-    ) -> "CheckpointBase": ...
+    ) -> "CheckpointBase | None": ...

     @override
     def create_callbacks(self, root_config):
@@ -50,7 +50,8 @@ class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
             or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
         )

-        yield self.create_checkpoint(root_config, dirpath)
+        if (callback := self.create_checkpoint(root_config, dirpath)) is not None:
+            yield callback


TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
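With this change, `create_checkpoint` may return `None`, and `create_callbacks` skips it instead of yielding `None` into the callback list. A minimal sketch of a subclass using the new contract; `SkipIfNoMetricConfig` and `MyCheckpoint` are illustrative names, not part of nshtrainer:

    from typing_extensions import override

    class SkipIfNoMetricConfig(BaseCheckpointCallbackConfig):
        @override
        def create_checkpoint(self, root_config, dirpath):
            if root_config.primary_metric is None:
                return None  # create_callbacks() yields nothing for this config
            return MyCheckpoint(self, dirpath)  # some CheckpointBase subclass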
--- nshtrainer-0.38.0/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
+++ nshtrainer-0.40.0/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
@@ -20,15 +20,26 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
     metric: MetricConfig | None = None
     """Metric to monitor, or `None` to use the default metric."""

+    throw_on_no_metric: bool = True
+    """
+    Whether to throw an error if no metric is provided and no primary metric is found in the root config.
+    """
+
     @override
     def create_checkpoint(self, root_config, dirpath):
         # Resolve metric
         if (metric := self.metric) is None and (
             metric := root_config.primary_metric
         ) is None:
-            raise ValueError(
-                "No metric provided and no primary metric found in the root config"
+            error_msg = (
+                "No metric provided and no primary metric found in the root config. "
+                "Cannot create BestCheckpointCallback."
             )
+            if self.throw_on_no_metric:
+                raise ValueError(error_msg)
+            else:
+                log.warning(error_msg)
+                return None

         return BestCheckpoint(self, dirpath, metric)

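A hedged usage sketch of the new flag; it assumes these config classes accept their fields as keyword arguments, which the `throw_on_no_metric=False` call in the trainer-config hunk below suggests:

    from nshtrainer.callbacks.checkpoint.best_checkpoint import BestCheckpointCallbackConfig

    # Default (and pre-0.40.0) behavior: raise ValueError when neither `metric`
    # nor `root_config.primary_metric` is set.
    strict = BestCheckpointCallbackConfig()

    # New in 0.40.0: log a warning and create no callback instead of raising.
    lenient = BestCheckpointCallbackConfig(throw_on_no_metric=False)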
--- nshtrainer-0.38.0/src/nshtrainer/callbacks/early_stopping.py
+++ nshtrainer-0.40.0/src/nshtrainer/callbacks/early_stopping.py
@@ -51,7 +51,7 @@ class EarlyStoppingConfig(CallbackConfigBase):
             metric := root_config.primary_metric
         ) is None:
             raise ValueError(
-                "Either `metric` or `root_config.primary_metric` must be set."
+                "Either `metric` or `root_config.primary_metric` must be set to use EarlyStopping."
             )

         yield EarlyStopping(self, metric)
--- nshtrainer-0.38.0/src/nshtrainer/data/__init__.py
+++ nshtrainer-0.40.0/src/nshtrainer/data/__init__.py
@@ -1,4 +1,5 @@
 from . import transform as dataset_transform
 from .balanced_batch_sampler import BalancedBatchSampler as BalancedBatchSampler
+from .datamodule import LightningDataModuleBase as LightningDataModuleBase

 _ = dataset_transform
--- /dev/null
+++ nshtrainer-0.40.0/src/nshtrainer/data/datamodule.py
@@ -0,0 +1,5 @@
+from lightning.pytorch import LightningDataModule
+
+
+class LightningDataModuleBase(LightningDataModule):
+    pass
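`LightningDataModuleBase` currently adds nothing beyond `lightning.pytorch.LightningDataModule`, so the standard Lightning hooks apply unchanged. A minimal sketch with an illustrative toy dataset:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from nshtrainer.data import LightningDataModuleBase


    class ToyDataModule(LightningDataModuleBase):
        def train_dataloader(self):
            # Illustrative random regression data; real datamodules would also
            # implement setup(), val_dataloader(), etc.
            dataset = TensorDataset(torch.randn(8, 4), torch.randn(8, 1))
            return DataLoader(dataset, batch_size=2)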
--- /dev/null
+++ nshtrainer-0.40.0/src/nshtrainer/runner.py
@@ -0,0 +1,99 @@
+import copy
+import logging
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from pathlib import Path
+from typing import Generic
+
+import nshrunner as nr
+from nshrunner._submit import screen
+from typing_extensions import TypeVar, TypeVarTuple, Unpack, deprecated, override
+
+from .model.config import BaseConfig
+
+TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
+TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
+TReturn = TypeVar("TReturn", infer_variance=True)
+
+
+@deprecated("Use nshrunner.Runner instead.")
+class Runner(
+    nr.Runner[TReturn, TConfig, Unpack[TArguments]],
+    Generic[TReturn, TConfig, Unpack[TArguments]],
+):
+    @override
+    def __init__(
+        self,
+        run_fn: Callable[[TConfig, Unpack[TArguments]], TReturn],
+        config: nr.RunnerConfig | None = None,
+    ):
+        if config is None:
+            working_dir = Path.cwd() / "nshrunner"
+            working_dir.mkdir(exist_ok=True)
+
+            logging.warning(
+                f"`config` is not provided. Using default working directory of {working_dir}."
+            )
+            config = nr.RunnerConfig(working_dir=working_dir)
+
+        super().__init__(run_fn, config)
+
+    def fast_dev_run(
+        self,
+        runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
+        n_batches: int = 1,
+        *,
+        env: Mapping[str, str] | None = None,
+    ):
+        runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
+        for args in runs:
+            config = copy.deepcopy(args[0])
+            config.trainer.fast_dev_run = n_batches
+            runs_updated.append((config, *args[1:]))
+        del runs
+
+        return self.local(runs_updated, env=env)
+
+    def fast_dev_run_generator(
+        self,
+        runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
+        n_batches: int = 1,
+        *,
+        env: Mapping[str, str] | None = None,
+    ):
+        runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
+        for args in runs:
+            config = copy.deepcopy(args[0])
+            config.trainer.fast_dev_run = n_batches
+            runs_updated.append((config, *args[1:]))
+        del runs
+
+        return self.local_generator(runs_updated, env=env)
+
+    def fast_dev_run_session(
+        self,
+        runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
+        options: screen.ScreenJobKwargs = {},
+        n_batches: int = 1,
+        *,
+        snapshot: nr.Snapshot,
+        setup_commands: Sequence[str] | None = None,
+        env: Mapping[str, str] | None = None,
+        activate_venv: bool = True,
+        print_command: bool = True,
+    ):
+        runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
+        for args in runs:
+            config = copy.deepcopy(args[0])
+            config.trainer.fast_dev_run = n_batches
+            runs_updated.append((config, *args[1:]))
+        del runs
+
+        return self.session(
+            runs_updated,
+            options,
+            snapshot=snapshot,
+            setup_commands=setup_commands,
+            env=env,
+            activate_venv=activate_venv,
+            print_command=print_command,
+        )
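The rewritten `Runner` drops the `transforms` machinery of the old implementation (deleted below): each `fast_dev_run*` method now deep-copies every config, sets `config.trainer.fast_dev_run = n_batches`, and delegates to the corresponding nshrunner method. A hedged sketch, assuming `configs` is a list of already-built `BaseConfig` instances and `train` is the user's run function:

    from nshtrainer.runner import Runner  # deprecated: prefer nshrunner.Runner

    def train(config):
        ...  # real training entry point

    runner = Runner(train)  # no RunnerConfig given: warns and uses ./nshrunner
    # Each run is a (config, *args) tuple; this run function takes only a config.
    results = runner.fast_dev_run([(c,) for c in configs], n_batches=2)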
--- nshtrainer-0.38.0/src/nshtrainer/trainer/_config.py
+++ nshtrainer-0.40.0/src/nshtrainer/trainer/_config.py
@@ -263,7 +263,7 @@ class CheckpointSavingConfig(CallbackConfigBase):
     """Enable checkpoint saving."""

     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
-        BestCheckpointCallbackConfig(),
+        BestCheckpointCallbackConfig(throw_on_no_metric=False),
         LastCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
     ]
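With the default switched to `throw_on_no_metric=False`, a run whose config lacks a primary metric no longer fails at callback creation; it simply runs without a best-checkpoint callback. A sketch of restoring the strict pre-0.40.0 behavior; the import location is an assumption, only the class names come from this diff:

    from nshtrainer.trainer._config import CheckpointSavingConfig

    saving = CheckpointSavingConfig(
        checkpoint_callbacks=[
            BestCheckpointCallbackConfig(),  # throw_on_no_metric defaults to True
            LastCheckpointCallbackConfig(),
            OnExceptionCheckpointCallbackConfig(),
        ]
    )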
--- nshtrainer-0.38.0/src/nshtrainer/runner.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import copy
-import functools
-from collections.abc import Callable, Iterable, Mapping, Sequence
-from typing import Generic
-
-from nshrunner import RunInfo, Snapshot
-from nshrunner import Runner as _Runner
-from nshrunner._submit import screen
-from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
-
-from .model.config import BaseConfig
-
-TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
-TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
-TReturn = TypeVar("TReturn", infer_variance=True)
-
-
-class Runner(
-    _Runner[TReturn, TConfig, Unpack[TArguments]],
-    Generic[TReturn, TConfig, Unpack[TArguments]],
-):
-    @override
-    @classmethod
-    def default_validate_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> None:
-        super().default_validate_fn(config, *args)
-
-    @override
-    @classmethod
-    def default_info_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> RunInfo:
-        run_info = super().default_info_fn(config, *args)
-        return {
-            **run_info,
-            "id": config.id,
-            "base_dir": config.directory.project_root,
-        }
-
-    def _fast_dev_run_transform(
-        self,
-        config: TConfig,
-        *args: Unpack[TArguments],
-        n_batches: int,
-    ):
-        config = copy.deepcopy(config)
-        config.trainer.fast_dev_run = n_batches
-        return (config, *args)
-
-    def fast_dev_run(
-        self,
-        runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
-        n_batches: int = 1,
-        *,
-        env: Mapping[str, str] | None = None,
-        transforms: list[
-            Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
-        ]
-        | None = None,
-    ):
-        transforms = transforms or []
-        transforms.append(
-            functools.partial(self._fast_dev_run_transform, n_batches=n_batches)
-        )
-        return self.local(
-            runs,
-            env=env,
-            transforms=transforms,
-        )
-
-    def fast_dev_run_generator(
-        self,
-        runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
-        n_batches: int = 1,
-        *,
-        env: Mapping[str, str] | None = None,
-        transforms: list[
-            Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
-        ]
-        | None = None,
-    ):
-        transforms = transforms or []
-        transforms.append(
-            functools.partial(self._fast_dev_run_transform, n_batches=n_batches)
-        )
-        return self.local_generator(
-            runs,
-            env=env,
-            transforms=transforms,
-        )
-
-    def fast_dev_run_session(
-        self,
-        runs: Iterable[tuple[TConfig, Unpack[TArguments]]],
-        options: screen.ScreenJobKwargs = {},
-        n_batches: int = 1,
-        *,
-        snapshot: Snapshot,
-        setup_commands: Sequence[str] | None = None,
-        env: Mapping[str, str] | None = None,
-        transforms: list[
-            Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
-        ]
-        | None = None,
-        activate_venv: bool = True,
-        print_command: bool = True,
-    ):
-        transforms = transforms or []
-        transforms.append(
-            functools.partial(self._fast_dev_run_transform, n_batches=n_batches)
-        )
-        return self.session(
-            runs,
-            options,
-            snapshot=snapshot,
-            setup_commands=setup_commands,
-            env=env,
-            transforms=transforms,
-            activate_venv=activate_venv,
-            print_command=print_command,
-        )
--- nshtrainer-0.38.0/src/nshtrainer/config.py
+++ nshtrainer-0.40.0/src/nshtrainer/config.py
@@ -95,10 +95,10 @@ from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearit
 from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
 from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
 from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
-from nshtrainer.nn.nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
 from nshtrainer.nn.nonlinearity import (
     SigmoidNonlinearityConfig as SigmoidNonlinearityConfig,
 )
+from nshtrainer.nn.nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
 from nshtrainer.nn.nonlinearity import (
     SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig,
 )
@@ -137,13 +137,13 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
 from nshtrainer.trainer._config import ReproducibilityConfig as ReproducibilityConfig
 from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
-from nshtrainer.util._environment_info import (
-    EnvironmentCUDAConfig as EnvironmentCUDAConfig,
-)
 from nshtrainer.util._environment_info import (
     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
 )
 from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
+from nshtrainer.util._environment_info import (
+    EnvironmentCUDAConfig as EnvironmentCUDAConfig,
+)
 from nshtrainer.util._environment_info import (
     EnvironmentGPUConfig as EnvironmentGPUConfig,
 )
@@ -151,10 +151,10 @@ from nshtrainer.util._environment_info import (
     EnvironmentHardwareConfig as EnvironmentHardwareConfig,
 )
 from nshtrainer.util._environment_info import (
-    EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
+    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
 )
 from nshtrainer.util._environment_info import (
-    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
+    EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
 )
 from nshtrainer.util._environment_info import (
     EnvironmentPackageConfig as EnvironmentPackageConfig,
--- nshtrainer-0.38.0/src/nshtrainer/nn/module_list.py
+++ nshtrainer-0.40.0/src/nshtrainer/nn/module_list.py
@@ -12,10 +12,10 @@ class TypedModuleList(nn.ModuleList, Generic[TModule]):
         super().__init__(modules)

     @overload
-    def __getitem__(self, idx: int) -> TModule: ...
+    def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ...

     @overload
-    def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ...
+    def __getitem__(self, idx: int) -> TModule: ...

     @override
     def __getitem__(self, idx: int | slice) -> TModule | "TypedModuleList[TModule]":
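The two overloads merely swap positions here, in line with the import re-sorting elsewhere in this release; the typed indexing contract is unchanged. A sketch of what the overloads promise to a type checker:

    import torch.nn as nn

    from nshtrainer.nn.module_list import TypedModuleList

    layers = TypedModuleList([nn.Linear(4, 4) for _ in range(3)])
    first = layers[0]   # int index -> TModule (here nn.Linear)
    tail = layers[1:]   # slice index -> TypedModuleList[TModule]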