nshtrainer 1.0.0b28__py3-none-any.whl → 1.0.0b29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import builtins
4
- from typing import Literal
4
+ from typing import Any, Literal
5
5
 
6
6
  import nshconfig as C
7
7
 
8
- from ..util._useful_types import SupportsRichComparisonT
9
-
10
8
 
11
9
  class MetricConfig(C.Config):
12
10
  name: str
@@ -40,5 +38,5 @@ class MetricConfig(C.Config):
40
38
  def best(self):
41
39
  return builtins.min if self.mode == "min" else builtins.max
42
40
 
43
- def is_better(self, a: SupportsRichComparisonT, b: SupportsRichComparisonT) -> bool:
41
+ def is_better(self, a: Any, b: Any):
44
42
  return self.best(a, b) == a
@@ -82,6 +82,13 @@ class PluginConfigBase(C.Config, ABC):
82
82
 
83
83
 
84
84
  plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
85
+ PluginConfig = TypeAliasType(
86
+ "PluginConfig", Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]
87
+ )
88
+
89
+ AcceleratorLiteral = TypeAliasType(
90
+ "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
91
+ )
85
92
 
86
93
 
87
94
  class AcceleratorConfigBase(C.Config, ABC):
@@ -90,18 +97,9 @@ class AcceleratorConfigBase(C.Config, ABC):
90
97
 
91
98
 
92
99
  accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
93
-
94
-
95
- class StrategyConfigBase(C.Config, ABC):
96
- @abstractmethod
97
- def create_strategy(self) -> Strategy: ...
98
-
99
-
100
- strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
101
-
102
-
103
- AcceleratorLiteral = TypeAliasType(
104
- "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
100
+ AcceleratorConfig = TypeAliasType(
101
+ "AcceleratorConfig",
102
+ Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
105
103
  )
106
104
 
107
105
  StrategyLiteral = TypeAliasType(
@@ -137,6 +135,17 @@ StrategyLiteral = TypeAliasType(
137
135
  )
138
136
 
139
137
 
138
+ class StrategyConfigBase(C.Config, ABC):
139
+ @abstractmethod
140
+ def create_strategy(self) -> Strategy: ...
141
+
142
+
143
+ strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
144
+ StrategyConfig = TypeAliasType(
145
+ "StrategyConfig",
146
+ Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()],
147
+ )
148
+
140
149
  CheckpointCallbackConfig = TypeAliasType(
141
150
  "CheckpointCallbackConfig",
142
151
  Annotated[
@@ -578,9 +587,7 @@ class TrainerConfig(C.Config):
578
587
  Default: ``False``.
579
588
  """
580
589
 
581
- plugins: (
582
- list[Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]] | None
583
- ) = None
590
+ plugins: list[PluginConfig] | None = None
584
591
  """
585
592
  Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
586
593
  Default: ``None``.
@@ -740,21 +747,13 @@ class TrainerConfig(C.Config):
740
747
  Default: ``True``.
741
748
  """
742
749
 
743
- accelerator: (
744
- Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()]
745
- | AcceleratorLiteral
746
- | None
747
- ) = None
750
+ accelerator: AcceleratorConfig | AcceleratorLiteral | None = None
748
751
  """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
749
752
  as well as custom accelerator instances.
750
753
  Default: ``"auto"``.
751
754
  """
752
755
 
753
- strategy: (
754
- Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()]
755
- | StrategyLiteral
756
- | None
757
- ) = None
756
+ strategy: StrategyConfig | StrategyLiteral | None = None
758
757
  """Supports different training strategies with aliases as well custom strategies.
759
758
  Default: ``"auto"``.
760
759
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 1.0.0b28
3
+ Version: 1.0.0b29
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -104,7 +104,7 @@ nshtrainer/lr_scheduler/_base.py,sha256=EhA2f_WiZ79RcXL2nJbwCwNK620c8ugEVUmJ8CcV
104
104
  nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=gvUuv031lvWdXboDeH7iAF3ZgNPQK40bQwfmqb11TNk,5492
105
105
  nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5NGoZYkXeFjAOgttaUX8,2870
106
106
  nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
107
- nshtrainer/metrics/_config.py,sha256=kR9OJLZXIaZpG8UpHG3EQwOL36HB8yPk6afYjoIM0XM,1324
107
+ nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
108
108
  nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
109
109
  nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
110
110
  nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
@@ -122,12 +122,11 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
122
122
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
123
123
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
124
124
  nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
125
- nshtrainer/trainer/_config.py,sha256=Mz9J2ZFqxTlttnRA1eScGRgSAuf3-o3i9-xjN7eTm-k,35256
125
+ nshtrainer/trainer/_config.py,sha256=VD0DfdS-pyQ2nFG83c4u5AUkSAHODmXLX5s2qtvS_to,35400
126
126
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
127
127
  nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
128
128
  nshtrainer/trainer/trainer.py,sha256=HHqT83zWtYY9g5yD6X9aWrVh5VSpILW8PhoE6fp4snE,20734
129
129
  nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
130
- nshtrainer/util/_useful_types.py,sha256=7yd1ajSmjwfmZdBPlHVrIG3iXl1-T3n83JI53N8C7as,8080
131
130
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
132
131
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
133
132
  nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
@@ -138,6 +137,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
138
137
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
139
138
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
140
139
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
141
- nshtrainer-1.0.0b28.dist-info/METADATA,sha256=1MJi65pa7HEVmtDR64Y32SwDe_bv1AZHSgyo6gIBmzo,988
142
- nshtrainer-1.0.0b28.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
143
- nshtrainer-1.0.0b28.dist-info/RECORD,,
140
+ nshtrainer-1.0.0b29.dist-info/METADATA,sha256=YRehZvU9svmmfAmwFrdmu-Tzxgi_EHbFwrn-ewD8W9c,988
141
+ nshtrainer-1.0.0b29.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
142
+ nshtrainer-1.0.0b29.dist-info/RECORD,,
@@ -1,316 +0,0 @@
1
- """Credit to useful-types from https://github.com/hauntsaninja/useful_types"""
2
-
3
- from __future__ import annotations
4
-
5
- from collections.abc import Awaitable, Iterable, Iterator, Sequence, Sized
6
- from collections.abc import Set as AbstractSet
7
- from os import PathLike
8
- from typing import Any, TypeVar, overload
9
-
10
- from typing_extensions import (
11
- Buffer,
12
- Literal,
13
- Protocol,
14
- SupportsIndex,
15
- TypeAlias,
16
- TypeAliasType,
17
- )
18
-
19
- _KT = TypeVar("_KT")
20
- _KT_co = TypeVar("_KT_co", covariant=True)
21
- _KT_contra = TypeVar("_KT_contra", contravariant=True)
22
- _VT = TypeVar("_VT")
23
- _VT_co = TypeVar("_VT_co", covariant=True)
24
- _T = TypeVar("_T")
25
- _T_co = TypeVar("_T_co", covariant=True)
26
- _T_contra = TypeVar("_T_contra", contravariant=True)
27
-
28
- # For partially known annotations. Usually, fields where type annotations
29
- # haven't been added are left unannotated, but in some situations this
30
- # isn't possible or a type is already partially known. In cases like these,
31
- # use Incomplete instead of Any as a marker. For example, use
32
- # "Incomplete | None" instead of "Any | None".
33
- Incomplete: TypeAlias = Any
34
-
35
-
36
- class IdentityFunction(Protocol):
37
- def __call__(self, __x: _T) -> _T: ...
38
-
39
-
40
- # ====================
41
- # Comparison protocols
42
- # ====================
43
-
44
-
45
- class SupportsDunderLT(Protocol[_T_contra]):
46
- def __lt__(self, __other: _T_contra) -> bool: ...
47
-
48
-
49
- class SupportsDunderGT(Protocol[_T_contra]):
50
- def __gt__(self, __other: _T_contra) -> bool: ...
51
-
52
-
53
- class SupportsDunderLE(Protocol[_T_contra]):
54
- def __le__(self, __other: _T_contra) -> bool: ...
55
-
56
-
57
- class SupportsDunderGE(Protocol[_T_contra]):
58
- def __ge__(self, __other: _T_contra) -> bool: ...
59
-
60
-
61
- class SupportsAllComparisons(
62
- SupportsDunderLT[Any],
63
- SupportsDunderGT[Any],
64
- SupportsDunderLE[Any],
65
- SupportsDunderGE[Any],
66
- Protocol,
67
- ): ...
68
-
69
-
70
- SupportsRichComparison = TypeAliasType(
71
- "SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
72
- )
73
- SupportsRichComparisonT = TypeVar(
74
- "SupportsRichComparisonT", bound=SupportsRichComparison
75
- )
76
-
77
- # ====================
78
- # Dunder protocols
79
- # ====================
80
-
81
-
82
- class SupportsNext(Protocol[_T_co]):
83
- def __next__(self) -> _T_co: ...
84
-
85
-
86
- class SupportsAnext(Protocol[_T_co]):
87
- def __anext__(self) -> Awaitable[_T_co]: ...
88
-
89
-
90
- class SupportsAdd(Protocol[_T_contra, _T_co]):
91
- def __add__(self, __x: _T_contra) -> _T_co: ...
92
-
93
-
94
- class SupportsRAdd(Protocol[_T_contra, _T_co]):
95
- def __radd__(self, __x: _T_contra) -> _T_co: ...
96
-
97
-
98
- class SupportsSub(Protocol[_T_contra, _T_co]):
99
- def __sub__(self, __x: _T_contra) -> _T_co: ...
100
-
101
-
102
- class SupportsRSub(Protocol[_T_contra, _T_co]):
103
- def __rsub__(self, __x: _T_contra) -> _T_co: ...
104
-
105
-
106
- class SupportsDivMod(Protocol[_T_contra, _T_co]):
107
- def __divmod__(self, __other: _T_contra) -> _T_co: ...
108
-
109
-
110
- class SupportsRDivMod(Protocol[_T_contra, _T_co]):
111
- def __rdivmod__(self, __other: _T_contra) -> _T_co: ...
112
-
113
-
114
- # This protocol is generic over the iterator type, while Iterable is
115
- # generic over the type that is iterated over.
116
- class SupportsIter(Protocol[_T_co]):
117
- def __iter__(self) -> _T_co: ...
118
-
119
-
120
- # This protocol is generic over the iterator type, while AsyncIterable is
121
- # generic over the type that is iterated over.
122
- class SupportsAiter(Protocol[_T_co]):
123
- def __aiter__(self) -> _T_co: ...
124
-
125
-
126
- class SupportsLenAndGetItem(Protocol[_T_co]):
127
- def __len__(self) -> int: ...
128
- def __getitem__(self, __k: int) -> _T_co: ...
129
-
130
-
131
- class SupportsTrunc(Protocol):
132
- def __trunc__(self) -> int: ...
133
-
134
-
135
- # ====================
136
- # Mapping-like protocols
137
- # ====================
138
-
139
-
140
- class SupportsItems(Protocol[_KT_co, _VT_co]):
141
- def items(self) -> AbstractSet[tuple[_KT_co, _VT_co]]: ...
142
-
143
-
144
- class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
145
- def keys(self) -> Iterable[_KT]: ...
146
- def __getitem__(self, __key: _KT) -> _VT_co: ...
147
-
148
-
149
- class SupportsGetItem(Protocol[_KT_contra, _VT_co]):
150
- def __contains__(self, __x: Any) -> bool: ...
151
- def __getitem__(self, __key: _KT_contra) -> _VT_co: ...
152
-
153
-
154
- class SupportsItemAccess(SupportsGetItem[_KT_contra, _VT], Protocol[_KT_contra, _VT]):
155
- def __setitem__(self, __key: _KT_contra, __value: _VT) -> None: ...
156
- def __delitem__(self, __key: _KT_contra) -> None: ...
157
-
158
-
159
- # ====================
160
- # File handling
161
- # ====================
162
-
163
- StrPath: TypeAlias = str | PathLike[str]
164
- BytesPath: TypeAlias = bytes | PathLike[bytes]
165
- StrOrBytesPath: TypeAlias = str | bytes | PathLike[str] | PathLike[bytes]
166
-
167
- OpenTextModeUpdating: TypeAlias = Literal[
168
- "r+",
169
- "+r",
170
- "rt+",
171
- "r+t",
172
- "+rt",
173
- "tr+",
174
- "t+r",
175
- "+tr",
176
- "w+",
177
- "+w",
178
- "wt+",
179
- "w+t",
180
- "+wt",
181
- "tw+",
182
- "t+w",
183
- "+tw",
184
- "a+",
185
- "+a",
186
- "at+",
187
- "a+t",
188
- "+at",
189
- "ta+",
190
- "t+a",
191
- "+ta",
192
- "x+",
193
- "+x",
194
- "xt+",
195
- "x+t",
196
- "+xt",
197
- "tx+",
198
- "t+x",
199
- "+tx",
200
- ]
201
- OpenTextModeWriting: TypeAlias = Literal[
202
- "w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"
203
- ]
204
- OpenTextModeReading: TypeAlias = Literal[
205
- "r", "rt", "tr", "U", "rU", "Ur", "rtU", "rUt", "Urt", "trU", "tUr", "Utr"
206
- ]
207
- OpenTextMode: TypeAlias = (
208
- OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading
209
- )
210
- OpenBinaryModeUpdating: TypeAlias = Literal[
211
- "rb+",
212
- "r+b",
213
- "+rb",
214
- "br+",
215
- "b+r",
216
- "+br",
217
- "wb+",
218
- "w+b",
219
- "+wb",
220
- "bw+",
221
- "b+w",
222
- "+bw",
223
- "ab+",
224
- "a+b",
225
- "+ab",
226
- "ba+",
227
- "b+a",
228
- "+ba",
229
- "xb+",
230
- "x+b",
231
- "+xb",
232
- "bx+",
233
- "b+x",
234
- "+bx",
235
- ]
236
- OpenBinaryModeWriting: TypeAlias = Literal["wb", "bw", "ab", "ba", "xb", "bx"]
237
- OpenBinaryModeReading: TypeAlias = Literal[
238
- "rb", "br", "rbU", "rUb", "Urb", "brU", "bUr", "Ubr"
239
- ]
240
- OpenBinaryMode: TypeAlias = (
241
- OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting
242
- )
243
-
244
-
245
- class HasFileno(Protocol):
246
- def fileno(self) -> int: ...
247
-
248
-
249
- FileDescriptor: TypeAlias = int
250
- FileDescriptorLike: TypeAlias = int | HasFileno
251
- FileDescriptorOrPath: TypeAlias = int | StrOrBytesPath
252
-
253
-
254
- class SupportsRead(Protocol[_T_co]):
255
- def read(self, __length: int = ...) -> _T_co: ...
256
-
257
-
258
- class SupportsReadline(Protocol[_T_co]):
259
- def readline(self, __length: int = ...) -> _T_co: ...
260
-
261
-
262
- class SupportsNoArgReadline(Protocol[_T_co]):
263
- def readline(self) -> _T_co: ...
264
-
265
-
266
- class SupportsWrite(Protocol[_T_contra]):
267
- def write(self, __s: _T_contra) -> object: ...
268
-
269
-
270
- # ====================
271
- # Buffer protocols
272
- # ====================
273
-
274
- # Unfortunately PEP 688 does not allow us to distinguish read-only
275
- # from writable buffers. We use these aliases for readability for now.
276
- # Perhaps a future extension of the buffer protocol will allow us to
277
- # distinguish these cases in the type system.
278
- ReadOnlyBuffer: TypeAlias = Buffer
279
- # Anything that implements the read-write buffer interface.
280
- WriteableBuffer: TypeAlias = Buffer
281
- # Same as WriteableBuffer, but also includes read-only buffer types (like bytes).
282
- ReadableBuffer: TypeAlias = Buffer
283
-
284
-
285
- class SliceableBuffer(Buffer, Protocol):
286
- def __getitem__(self, __slice: slice) -> Sequence[int]: ...
287
-
288
-
289
- class IndexableBuffer(Buffer, Protocol):
290
- def __getitem__(self, __i: int) -> int: ...
291
-
292
-
293
- class SupportsGetItemBuffer(SliceableBuffer, IndexableBuffer, Protocol):
294
- def __contains__(self, __x: Any) -> bool: ...
295
- @overload
296
- def __getitem__(self, __slice: slice) -> Sequence[int]: ...
297
- @overload
298
- def __getitem__(self, __i: int) -> int: ...
299
-
300
-
301
- class SizedBuffer(Sized, Buffer, Protocol): ...
302
-
303
-
304
- # Source from https://github.com/python/typing/issues/256#issuecomment-1442633430
305
- # This works because str.__contains__ does not accept object (either in typeshed or at runtime)
306
- class SequenceNotStr(Protocol[_T_co]):
307
- @overload
308
- def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
309
- @overload
310
- def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
311
- def __contains__(self, value: object, /) -> bool: ...
312
- def __len__(self) -> int: ...
313
- def __iter__(self) -> Iterator[_T_co]: ...
314
- def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
315
- def count(self, value: Any, /) -> int: ...
316
- def __reversed__(self) -> Iterator[_T_co]: ...