nshtrainer 0.39.0__py3-none-any.whl → 0.40.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/__init__.py CHANGED
@@ -8,6 +8,7 @@ from . import model as model
  from . import nn as nn
  from . import optimizer as optimizer
  from . import profiler as profiler
+ from .data import LightningDataModuleBase as LightningDataModuleBase
  from .metrics import MetricConfig as MetricConfig
  from .model import BaseConfig as BaseConfig
  from .model import LightningModuleBase as LightningModuleBase
nshtrainer/data/__init__.py CHANGED
@@ -1,4 +1,5 @@
  from . import transform as dataset_transform
  from .balanced_batch_sampler import BalancedBatchSampler as BalancedBatchSampler
+ from .datamodule import LightningDataModuleBase as LightningDataModuleBase

  _ = dataset_transform
nshtrainer/data/datamodule.py ADDED
@@ -0,0 +1,5 @@
+ from lightning.pytorch import LightningDataModule
+
+
+ class LightningDataModuleBase(LightningDataModule):
+     pass
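
Note: the new module is a one-line shim — `LightningDataModuleBase` is currently just `lightning.pytorch.LightningDataModule` under an nshtrainer name, and the `__init__.py` hunks above re-export it from both `nshtrainer.data` and the package root. A minimal usage sketch (the `MyDataModule` subclass, its toy dataset, and `batch_size` are illustrative names, not part of the package):

import torch
from torch.utils.data import DataLoader, TensorDataset

from nshtrainer import LightningDataModuleBase


class MyDataModule(LightningDataModuleBase):
    """Hypothetical datamodule built on the new base class."""

    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        # Toy in-memory dataset, purely for illustration.
        dataset = TensorDataset(torch.randn(128, 8), torch.randn(128, 1))
        return DataLoader(dataset, batch_size=self.batch_size)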
nshtrainer/loggers/wandb.py CHANGED
@@ -92,7 +92,7 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
      - "none" or False: Do not log any checkpoints
      """

-     log_code: WandbUploadCodeConfig | None = None
+     log_code: WandbUploadCodeConfig | None = WandbUploadCodeConfig()
      """WandB code upload configuration. Used to upload code to WandB."""

      watch: WandbWatchConfig | None = WandbWatchConfig()
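
Behavior note on the hunk above: `log_code` now defaults to an enabled `WandbUploadCodeConfig()` instead of `None`, so code upload to WandB is on by default in 0.40.1. A hedged sketch of opting back out — this assumes `WandbLoggerConfig` can be constructed with its remaining fields left at their defaults, which the diff does not show:

from nshtrainer.loggers.wandb import WandbLoggerConfig

# Restore the 0.39.0 behaviour: do not upload code to WandB.
logger_config = WandbLoggerConfig(log_code=None)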
nshtrainer/nn/module_list.py CHANGED
@@ -12,10 +12,10 @@ class TypedModuleList(nn.ModuleList, Generic[TModule]):
          super().__init__(modules)

      @overload
-     def __getitem__(self, idx: int) -> TModule: ...
+     def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ...

      @overload
-     def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ...
+     def __getitem__(self, idx: int) -> TModule: ...

      @override
      def __getitem__(self, idx: int | slice) -> TModule | "TypedModuleList[TModule]":
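
The change above only swaps the order of the two `__getitem__` overloads; the typing contract is unchanged: indexing with an `int` yields a single `TModule`, slicing yields another `TypedModuleList[TModule]`. A small sketch of what static checkers infer — the constructor call mirrors `nn.ModuleList` and is assumed from the `super().__init__(modules)` line above:

import torch.nn as nn

from nshtrainer.nn.module_list import TypedModuleList

layers = TypedModuleList([nn.Linear(4, 4) for _ in range(3)])
first = layers[0]   # inferred as nn.Linear (the TModule type parameter)
tail = layers[1:]   # inferred as TypedModuleList[nn.Linear]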
nshtrainer/runner.py CHANGED
@@ -1,12 +1,12 @@
  import copy
- import functools
+ import logging
  from collections.abc import Callable, Iterable, Mapping, Sequence
+ from pathlib import Path
  from typing import Generic

- from nshrunner import RunInfo, Snapshot
- from nshrunner import Runner as _Runner
+ import nshrunner as nr
  from nshrunner._submit import screen
- from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
+ from typing_extensions import TypeVar, TypeVarTuple, Unpack, deprecated, override

  from .model.config import BaseConfig

@@ -15,34 +15,27 @@ TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
  TReturn = TypeVar("TReturn", infer_variance=True)


+ @deprecated("Use nshrunner.Runner instead.")
  class Runner(
-     _Runner[TReturn, TConfig, Unpack[TArguments]],
+     nr.Runner[TReturn, TConfig, Unpack[TArguments]],
      Generic[TReturn, TConfig, Unpack[TArguments]],
  ):
      @override
-     @classmethod
-     def default_validate_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> None:
-         super().default_validate_fn(config, *args)
-
-     @override
-     @classmethod
-     def default_info_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> RunInfo:
-         run_info = super().default_info_fn(config, *args)
-         return {
-             **run_info,
-             "id": config.id,
-             "base_dir": config.directory.project_root,
-         }
-
-     def _fast_dev_run_transform(
+     def __init__(
          self,
-         config: TConfig,
-         *args: Unpack[TArguments],
-         n_batches: int,
+         run_fn: Callable[[TConfig, Unpack[TArguments]], TReturn],
+         config: nr.RunnerConfig | None = None,
      ):
-         config = copy.deepcopy(config)
-         config.trainer.fast_dev_run = n_batches
-         return (config, *args)
+         if config is None:
+             working_dir = Path.cwd() / "nshrunner"
+             working_dir.mkdir(exist_ok=True)
+
+             logging.warning(
+                 f"`config` is not provided. Using default working directory of {working_dir}."
+             )
+             config = nr.RunnerConfig(working_dir=working_dir)
+
+         super().__init__(run_fn, config)

      def fast_dev_run(
          self,
@@ -50,20 +43,15 @@ class Runner(
          n_batches: int = 1,
          *,
          env: Mapping[str, str] | None = None,
-         transforms: list[
-             Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
-         ]
-         | None = None,
      ):
-         transforms = transforms or []
-         transforms.append(
-             functools.partial(self._fast_dev_run_transform, n_batches=n_batches)
-         )
-         return self.local(
-             runs,
-             env=env,
-             transforms=transforms,
-         )
+         runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
+         for args in runs:
+             config = copy.deepcopy(args[0])
+             config.trainer.fast_dev_run = n_batches
+             runs_updated.append((config, *args[1:]))
+         del runs
+
+         return self.local(runs_updated, env=env)

      def fast_dev_run_generator(
          self,
@@ -71,20 +59,15 @@ class Runner(
          n_batches: int = 1,
          *,
          env: Mapping[str, str] | None = None,
-         transforms: list[
-             Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
-         ]
-         | None = None,
      ):
-         transforms = transforms or []
-         transforms.append(
-             functools.partial(self._fast_dev_run_transform, n_batches=n_batches)
-         )
-         return self.local_generator(
-             runs,
-             env=env,
-             transforms=transforms,
-         )
+         runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
+         for args in runs:
+             config = copy.deepcopy(args[0])
+             config.trainer.fast_dev_run = n_batches
+             runs_updated.append((config, *args[1:]))
+         del runs
+
+         return self.local_generator(runs_updated, env=env)

      def fast_dev_run_session(
          self,
@@ -92,27 +75,25 @@ class Runner(
          options: screen.ScreenJobKwargs = {},
          n_batches: int = 1,
          *,
-         snapshot: Snapshot,
+         snapshot: nr.Snapshot,
          setup_commands: Sequence[str] | None = None,
          env: Mapping[str, str] | None = None,
-         transforms: list[
-             Callable[[TConfig, Unpack[TArguments]], tuple[TConfig, Unpack[TArguments]]]
-         ]
-         | None = None,
          activate_venv: bool = True,
          print_command: bool = True,
      ):
-         transforms = transforms or []
-         transforms.append(
-             functools.partial(self._fast_dev_run_transform, n_batches=n_batches)
-         )
+         runs_updated: list[tuple[TConfig, Unpack[TArguments]]] = []
+         for args in runs:
+             config = copy.deepcopy(args[0])
+             config.trainer.fast_dev_run = n_batches
+             runs_updated.append((config, *args[1:]))
+         del runs
+
          return self.session(
-             runs,
+             runs_updated,
              options,
              snapshot=snapshot,
              setup_commands=setup_commands,
              env=env,
-             transforms=transforms,
              activate_venv=activate_venv,
              print_command=print_command,
          )
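
Taken together, the runner.py changes deprecate `nshtrainer.runner.Runner` in favour of `nshrunner.Runner`, give `__init__` a `run_fn` plus an optional `nr.RunnerConfig` (falling back to a `./nshrunner` working directory with a warning), and replace the transform-based `fast_dev_run*` helpers with an explicit deep-copy of each run's config that sets `config.trainer.fast_dev_run = n_batches`. A rough usage sketch under those assumptions — `MyConfig`, `train`, and the returned value are placeholders, and whether `MyConfig()` can be constructed without arguments depends on `BaseConfig`, which this diff does not show:

from pathlib import Path

import nshrunner as nr

from nshtrainer import BaseConfig
from nshtrainer.runner import Runner  # emits a DeprecationWarning in 0.40.1


class MyConfig(BaseConfig):
    """Hypothetical experiment config."""


def train(config: MyConfig) -> float:
    # Placeholder run function; a real one would build and fit a trainer.
    return 0.0


runner = Runner(train, nr.RunnerConfig(working_dir=Path.cwd() / "nshrunner"))

# fast_dev_run deep-copies each run's config, sets trainer.fast_dev_run, and
# dispatches the updated runs locally.
runner.fast_dev_run([(MyConfig(),)], n_batches=2)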
nshtrainer-0.39.0.dist-info/METADATA → nshtrainer-0.40.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.39.0
+ Version: 0.40.1
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
nshtrainer-0.39.0.dist-info/RECORD → nshtrainer-0.40.1.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
- nshtrainer/__init__.py,sha256=flMI50Hj1Ie8c1YMSUQ759AqtNBQLT_zHaV2J9EUmOs,573
+ nshtrainer/__init__.py,sha256=8hx3uqzroMYOQM7S4U3Eznw7MHF3YDqUkO1MDQeC4cM,642
  nshtrainer/_callback.py,sha256=A1zLsTy4b_wOYnInLLXGSRdHzT2yNa6mPEql-ozm0u0,1013
  nshtrainer/_checkpoint/loader.py,sha256=5vjg-OFChXJjgiOVv8vnV8nwTscfdDtEdxQRz6uPfDE,14158
  nshtrainer/_checkpoint/metadata.py,sha256=5D4PgKodzhLsmQvuF3xxkH49epKaegxi4wh_ImDTtns,4737
@@ -32,8 +32,9 @@ nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50
  nshtrainer/callbacks/wandb_upload_code.py,sha256=OWG4UkL2SfW6oj6AGRXeBJsZmgsqeHLW2Fj8Jm4ga3I,2298
  nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
  nshtrainer/config.py,sha256=pZyRZOkBRR7eBFRiHpHjQFNEFjaX9tYZIAqZtvKi6cA,8312
- nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
+ nshtrainer/data/__init__.py,sha256=fnigdEtNcB8DJqz4MHvcI6UwOZICFzWItskY78lLTMA,224
  nshtrainer/data/balanced_batch_sampler.py,sha256=ybMJF-CguaZ17fLEweZ5suaGOiHOMEm3Bn8rQfGTzGQ,5445
+ nshtrainer/data/datamodule.py,sha256=Dk06LdcbK13sCLQWacdCoXU_CHWZxoKEZ_xpAg7nZQg,113
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
  nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
  nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
@@ -56,7 +57,7 @@ nshtrainer/loggers/__init__.py,sha256=C_xk0A3_qKbNdTmzK85AgjRHFD3w-jPRS2ig-iPhfE
  nshtrainer/loggers/_base.py,sha256=xiZKEK0ALJkcqf4OpVNRY0QbZsamR_WR7x7m_68YHXQ,705
  nshtrainer/loggers/csv.py,sha256=D_lYyd94bZ8jAgnRo-ARtFgVcInaD9zktxtsUD9RWCI,1052
  nshtrainer/loggers/tensorboard.py,sha256=wL2amRSdP68zbslZvBeM0ZQBnjF3hIKsz-_lBbdomaM,2216
- nshtrainer/loggers/wandb.py,sha256=td8J2v8T1nvGQI7OYQ1El6k8FGsXZxbnuY97s8KzCiY,6643
+ nshtrainer/loggers/wandb.py,sha256=GMlEzkFesgS3qgVgYqATPHJ3rvs9etwFlk3HbidOyRQ,6662
  nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
  nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
  nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=E8LW78uuby7bIsoLPpcF1bmNK4lSko-r3qPL-vuHWXQ,5370
@@ -71,7 +72,7 @@ nshtrainer/model/mixins/logger.py,sha256=xOymSTofukEYZGkGojXsMEO__ZlBI5lIPZVmlot
  nshtrainer/nn/__init__.py,sha256=0QPFl02a71WZQjLMGOlFNMmsYP5aa1q3eABHmnWH58Q,1427
  nshtrainer/nn/mlp.py,sha256=V0FrScpIUdg_IgIO8GMtIsGEtmHjwF14i2IWxmZrsqg,5952
  nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
- nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
+ nshtrainer/nn/module_list.py,sha256=bbR62bnQM4No6HtmBllijOhLwDtWwmcPcO2DklPbsbc,1719
  nshtrainer/nn/nonlinearity.py,sha256=4sYE4MN5zojc-go1k0PYtqssVRuXrM7D4tbpIXp5K-E,6078
  nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
  nshtrainer/profiler/__init__.py,sha256=cJ_wAm8j3Bz6cKgNQ_9gQLZz9nddKW53VE81UKg8l8g,480
@@ -79,7 +80,7 @@ nshtrainer/profiler/_base.py,sha256=YF5lsJBIl9qts9GLW5Z62JuYeo4SnIArhlFwTGkfTb4,
  nshtrainer/profiler/advanced.py,sha256=44asloha0aGUW8YwjQt3lm3ve8H-N6mM4QgseUSLT30,1112
  nshtrainer/profiler/pytorch.py,sha256=tGeRvoPP5ulWX2RkfXrQvMBoki1T95dpz5p8mwyon1I,2709
  nshtrainer/profiler/simple.py,sha256=MbMfsJvligd0mtGiltxJ0T8MQVDP9T9BzQZFwswl66Y,957
- nshtrainer/runner.py,sha256=USAjrExHkN5oVNVunsoPnLxfQrEHSaa54S3RipOe544,3605
+ nshtrainer/runner.py,sha256=llW5PThXfzlu8ym1waFVzED1RGI91OHqN3j02sepoZo,3153
  nshtrainer/scripts/find_packages.py,sha256=ixYivZobumyyGsf2B9oYMLyLTRcBzY_vUv-u3bNW-hs,1424
  nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
  nshtrainer/trainer/_config.py,sha256=YqpGb4RodkUg87TVE5WBSc4CQkUF0z3qDRdil1HRxoM,29198
@@ -99,6 +100,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.39.0.dist-info/METADATA,sha256=bIWwvsGuePEZeT3Q8dYS8IK5Y6ZM4yqe9v41Ybs9OGM,916
- nshtrainer-0.39.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.39.0.dist-info/RECORD,,
+ nshtrainer-0.40.1.dist-info/METADATA,sha256=m9hT1zAMnexEhlbXEvTFJBbDsWffVWw11M9NxqqTrZg,916
+ nshtrainer-0.40.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.40.1.dist-info/RECORD,,