nshtrainer 1.0.0b27__py3-none-any.whl → 1.0.0b29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,6 +14,7 @@ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
14
14
  from nshtrainer.callbacks import (
15
15
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
16
16
  )
17
+ from nshtrainer.callbacks import CallbackConfig as CallbackConfig
17
18
  from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
18
19
  from nshtrainer.callbacks import (
19
20
  DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
@@ -60,11 +61,13 @@ from nshtrainer.callbacks.checkpoint._base import (
60
61
  from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
61
62
  from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
62
63
  from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
64
+ from nshtrainer.loggers import LoggerConfig as LoggerConfig
63
65
  from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
64
66
  from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
65
67
  from nshtrainer.lr_scheduler import (
66
68
  LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
67
69
  )
70
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
68
71
  from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
69
72
  from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
70
73
  from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
@@ -73,6 +76,7 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
73
76
  from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
74
77
  from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
75
78
  from nshtrainer.nn import MLPConfig as MLPConfig
79
+ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
76
80
  from nshtrainer.nn import PReLUConfig as PReLUConfig
77
81
  from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
78
82
  from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
@@ -86,12 +90,17 @@ from nshtrainer.nn.nonlinearity import (
86
90
  SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
87
91
  )
88
92
  from nshtrainer.optimizer import AdamWConfig as AdamWConfig
93
+ from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
89
94
  from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
90
95
  from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
91
96
  from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
97
+ from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
92
98
  from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
93
99
  from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
94
100
  from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
101
+ from nshtrainer.trainer._config import (
102
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
103
+ )
95
104
  from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
96
105
  from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
97
106
  from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
@@ -130,6 +139,7 @@ from nshtrainer.util._environment_info import (
130
139
  )
131
140
  from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
132
141
  from nshtrainer.util.config import DTypeConfig as DTypeConfig
142
+ from nshtrainer.util.config import DurationConfig as DurationConfig
133
143
  from nshtrainer.util.config import EpochsConfig as EpochsConfig
134
144
  from nshtrainer.util.config import StepsConfig as StepsConfig
135
145
 
@@ -158,13 +168,16 @@ __all__ = [
158
168
  "BaseProfilerConfig",
159
169
  "BestCheckpointCallbackConfig",
160
170
  "CSVLoggerConfig",
171
+ "CallbackConfig",
161
172
  "CallbackConfigBase",
173
+ "CheckpointCallbackConfig",
162
174
  "CheckpointMetadata",
163
175
  "CheckpointSavingConfig",
164
176
  "DTypeConfig",
165
177
  "DebugFlagCallbackConfig",
166
178
  "DirectoryConfig",
167
179
  "DirectorySetupCallbackConfig",
180
+ "DurationConfig",
168
181
  "ELUNonlinearityConfig",
169
182
  "EMACallbackConfig",
170
183
  "EarlyStoppingCallbackConfig",
@@ -187,21 +200,26 @@ __all__ = [
187
200
  "GradientSkippingCallbackConfig",
188
201
  "HuggingFaceHubAutoCreateConfig",
189
202
  "HuggingFaceHubConfig",
203
+ "LRSchedulerConfig",
190
204
  "LRSchedulerConfigBase",
191
205
  "LastCheckpointCallbackConfig",
192
206
  "LeakyReLUNonlinearityConfig",
193
207
  "LearningRateMonitorConfig",
194
208
  "LinearWarmupCosineDecayLRSchedulerConfig",
195
209
  "LogEpochCallbackConfig",
210
+ "LoggerConfig",
196
211
  "MLPConfig",
197
212
  "MetricConfig",
198
213
  "MishNonlinearityConfig",
214
+ "NonlinearityConfig",
199
215
  "NormLoggingCallbackConfig",
200
216
  "OnExceptionCheckpointCallbackConfig",
217
+ "OptimizerConfig",
201
218
  "OptimizerConfigBase",
202
219
  "PReLUConfig",
203
220
  "PluginConfigBase",
204
221
  "PrintTableMetricsCallbackConfig",
222
+ "ProfilerConfig",
205
223
  "PyTorchProfilerConfig",
206
224
  "RLPSanityChecksCallbackConfig",
207
225
  "ReLUNonlinearityConfig",
@@ -6,8 +6,10 @@ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
6
6
  from nshtrainer._directory import (
7
7
  DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
8
8
  )
9
+ from nshtrainer._directory import LoggerConfig as LoggerConfig
9
10
 
10
11
  __all__ = [
11
12
  "DirectoryConfig",
12
13
  "DirectorySetupCallbackConfig",
14
+ "LoggerConfig",
13
15
  ]
@@ -5,6 +5,7 @@ __codegen__ = True
5
5
  from nshtrainer.callbacks import (
6
6
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
7
7
  )
8
+ from nshtrainer.callbacks import CallbackConfig as CallbackConfig
8
9
  from nshtrainer.callbacks import CallbackConfigBase as CallbackConfigBase
9
10
  from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
10
11
  from nshtrainer.callbacks import (
@@ -80,6 +81,7 @@ __all__ = [
80
81
  "ActSaveConfig",
81
82
  "BaseCheckpointCallbackConfig",
82
83
  "BestCheckpointCallbackConfig",
84
+ "CallbackConfig",
83
85
  "CallbackConfigBase",
84
86
  "CheckpointMetadata",
85
87
  "DebugFlagCallbackConfig",
@@ -5,6 +5,7 @@ __codegen__ = True
5
5
  from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
6
6
  from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
7
7
  from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
8
+ from nshtrainer.loggers import LoggerConfig as LoggerConfig
8
9
  from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
9
10
  from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
10
11
  from nshtrainer.loggers.wandb import CallbackConfigBase as CallbackConfigBase
@@ -26,6 +27,7 @@ __all__ = [
26
27
  "BaseLoggerConfig",
27
28
  "CSVLoggerConfig",
28
29
  "CallbackConfigBase",
30
+ "LoggerConfig",
29
31
  "TensorboardLoggerConfig",
30
32
  "WandbLoggerConfig",
31
33
  "WandbUploadCodeCallbackConfig",
@@ -5,8 +5,12 @@ __codegen__ = True
5
5
  from nshtrainer.lr_scheduler import (
6
6
  LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
7
7
  )
8
+ from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
8
9
  from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
9
10
  from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
11
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
12
+ DurationConfig as DurationConfig,
13
+ )
10
14
  from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
11
15
 
12
16
  from . import _base as _base
@@ -14,6 +18,8 @@ from . import linear_warmup_cosine as linear_warmup_cosine
14
18
  from . import reduce_lr_on_plateau as reduce_lr_on_plateau
15
19
 
16
20
  __all__ = [
21
+ "DurationConfig",
22
+ "LRSchedulerConfig",
17
23
  "LRSchedulerConfigBase",
18
24
  "LinearWarmupCosineDecayLRSchedulerConfig",
19
25
  "MetricConfig",
@@ -2,6 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
6
+ DurationConfig as DurationConfig,
7
+ )
5
8
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
6
9
  LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
7
10
  )
@@ -10,6 +13,7 @@ from nshtrainer.lr_scheduler.linear_warmup_cosine import (
10
13
  )
11
14
 
12
15
  __all__ = [
16
+ "DurationConfig",
13
17
  "LRSchedulerConfigBase",
14
18
  "LinearWarmupCosineDecayLRSchedulerConfig",
15
19
  ]
@@ -8,6 +8,7 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
8
8
  from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
9
9
  from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
10
10
  from nshtrainer.nn import MLPConfig as MLPConfig
11
+ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
11
12
  from nshtrainer.nn import PReLUConfig as PReLUConfig
12
13
  from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
13
14
  from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
@@ -31,6 +32,7 @@ __all__ = [
31
32
  "LeakyReLUNonlinearityConfig",
32
33
  "MLPConfig",
33
34
  "MishNonlinearityConfig",
35
+ "NonlinearityConfig",
34
36
  "PReLUConfig",
35
37
  "ReLUNonlinearityConfig",
36
38
  "SiLUNonlinearityConfig",
@@ -4,8 +4,10 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer.nn.mlp import BaseNonlinearityConfig as BaseNonlinearityConfig
6
6
  from nshtrainer.nn.mlp import MLPConfig as MLPConfig
7
+ from nshtrainer.nn.mlp import NonlinearityConfig as NonlinearityConfig
7
8
 
8
9
  __all__ = [
9
10
  "BaseNonlinearityConfig",
10
11
  "MLPConfig",
12
+ "NonlinearityConfig",
11
13
  ]
@@ -9,6 +9,7 @@ from nshtrainer.nn.nonlinearity import (
9
9
  LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig,
10
10
  )
11
11
  from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
12
+ from nshtrainer.nn.nonlinearity import NonlinearityConfig as NonlinearityConfig
12
13
  from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
13
14
  from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
14
15
  from nshtrainer.nn.nonlinearity import (
@@ -38,6 +39,7 @@ __all__ = [
38
39
  "GELUNonlinearityConfig",
39
40
  "LeakyReLUNonlinearityConfig",
40
41
  "MishNonlinearityConfig",
42
+ "NonlinearityConfig",
41
43
  "PReLUConfig",
42
44
  "ReLUNonlinearityConfig",
43
45
  "SiLUNonlinearityConfig",
@@ -3,9 +3,11 @@ from __future__ import annotations
3
3
  __codegen__ = True
4
4
 
5
5
  from nshtrainer.optimizer import AdamWConfig as AdamWConfig
6
+ from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
6
7
  from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
7
8
 
8
9
  __all__ = [
9
10
  "AdamWConfig",
11
+ "OptimizerConfig",
10
12
  "OptimizerConfigBase",
11
13
  ]
@@ -4,6 +4,7 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
6
6
  from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
7
+ from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
7
8
  from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
8
9
  from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
9
10
 
@@ -15,6 +16,7 @@ from . import simple as simple
15
16
  __all__ = [
16
17
  "AdvancedProfilerConfig",
17
18
  "BaseProfilerConfig",
19
+ "ProfilerConfig",
18
20
  "PyTorchProfilerConfig",
19
21
  "SimpleProfilerConfig",
20
22
  "_base",
@@ -9,7 +9,11 @@ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
9
9
  from nshtrainer.trainer._config import (
10
10
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
11
11
  )
12
+ from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
12
13
  from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
14
+ from nshtrainer.trainer._config import (
15
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
16
+ )
13
17
  from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
14
18
  from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
15
19
  from nshtrainer.trainer._config import (
@@ -29,6 +33,7 @@ from nshtrainer.trainer._config import (
29
33
  LearningRateMonitorConfig as LearningRateMonitorConfig,
30
34
  )
31
35
  from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
36
+ from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
32
37
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
33
38
  from nshtrainer.trainer._config import (
34
39
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
@@ -37,6 +42,7 @@ from nshtrainer.trainer._config import (
37
42
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
38
43
  )
39
44
  from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
45
+ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
40
46
  from nshtrainer.trainer._config import (
41
47
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
42
48
  )
@@ -62,7 +68,9 @@ __all__ = [
62
68
  "BaseLoggerConfig",
63
69
  "BestCheckpointCallbackConfig",
64
70
  "CSVLoggerConfig",
71
+ "CallbackConfig",
65
72
  "CallbackConfigBase",
73
+ "CheckpointCallbackConfig",
66
74
  "CheckpointSavingConfig",
67
75
  "DebugFlagCallbackConfig",
68
76
  "DirectoryConfig",
@@ -73,10 +81,12 @@ __all__ = [
73
81
  "LastCheckpointCallbackConfig",
74
82
  "LearningRateMonitorConfig",
75
83
  "LogEpochCallbackConfig",
84
+ "LoggerConfig",
76
85
  "MetricConfig",
77
86
  "NormLoggingCallbackConfig",
78
87
  "OnExceptionCheckpointCallbackConfig",
79
88
  "PluginConfigBase",
89
+ "ProfilerConfig",
80
90
  "RLPSanityChecksCallbackConfig",
81
91
  "SanityCheckingConfig",
82
92
  "SharedParametersCallbackConfig",
@@ -8,7 +8,11 @@ from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
8
8
  from nshtrainer.trainer._config import (
9
9
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
10
10
  )
11
+ from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
11
12
  from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
13
+ from nshtrainer.trainer._config import (
14
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
15
+ )
12
16
  from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
13
17
  from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
14
18
  from nshtrainer.trainer._config import (
@@ -28,6 +32,7 @@ from nshtrainer.trainer._config import (
28
32
  LearningRateMonitorConfig as LearningRateMonitorConfig,
29
33
  )
30
34
  from nshtrainer.trainer._config import LogEpochCallbackConfig as LogEpochCallbackConfig
35
+ from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
31
36
  from nshtrainer.trainer._config import MetricConfig as MetricConfig
32
37
  from nshtrainer.trainer._config import (
33
38
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
@@ -36,6 +41,7 @@ from nshtrainer.trainer._config import (
36
41
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
37
42
  )
38
43
  from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
44
+ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
39
45
  from nshtrainer.trainer._config import (
40
46
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
41
47
  )
@@ -59,7 +65,9 @@ __all__ = [
59
65
  "BaseLoggerConfig",
60
66
  "BestCheckpointCallbackConfig",
61
67
  "CSVLoggerConfig",
68
+ "CallbackConfig",
62
69
  "CallbackConfigBase",
70
+ "CheckpointCallbackConfig",
63
71
  "CheckpointSavingConfig",
64
72
  "DebugFlagCallbackConfig",
65
73
  "DirectoryConfig",
@@ -70,10 +78,12 @@ __all__ = [
70
78
  "LastCheckpointCallbackConfig",
71
79
  "LearningRateMonitorConfig",
72
80
  "LogEpochCallbackConfig",
81
+ "LoggerConfig",
73
82
  "MetricConfig",
74
83
  "NormLoggingCallbackConfig",
75
84
  "OnExceptionCheckpointCallbackConfig",
76
85
  "PluginConfigBase",
86
+ "ProfilerConfig",
77
87
  "RLPSanityChecksCallbackConfig",
78
88
  "SanityCheckingConfig",
79
89
  "SharedParametersCallbackConfig",
@@ -32,6 +32,7 @@ from nshtrainer.util._environment_info import (
32
32
  )
33
33
  from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
34
34
  from nshtrainer.util.config import DTypeConfig as DTypeConfig
35
+ from nshtrainer.util.config import DurationConfig as DurationConfig
35
36
  from nshtrainer.util.config import EpochsConfig as EpochsConfig
36
37
  from nshtrainer.util.config import StepsConfig as StepsConfig
37
38
 
@@ -40,6 +41,7 @@ from . import config as config
40
41
 
41
42
  __all__ = [
42
43
  "DTypeConfig",
44
+ "DurationConfig",
43
45
  "EnvironmentCUDAConfig",
44
46
  "EnvironmentClassInformationConfig",
45
47
  "EnvironmentConfig",
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  __codegen__ = True
4
4
 
5
5
  from nshtrainer.util.config import DTypeConfig as DTypeConfig
6
+ from nshtrainer.util.config import DurationConfig as DurationConfig
6
7
  from nshtrainer.util.config import EpochsConfig as EpochsConfig
7
8
  from nshtrainer.util.config import StepsConfig as StepsConfig
8
9
 
@@ -11,6 +12,7 @@ from . import duration as duration
11
12
 
12
13
  __all__ = [
13
14
  "DTypeConfig",
15
+ "DurationConfig",
14
16
  "EpochsConfig",
15
17
  "StepsConfig",
16
18
  "dtype",
@@ -2,10 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  __codegen__ = True
4
4
 
5
+ from nshtrainer.util.config.duration import DurationConfig as DurationConfig
5
6
  from nshtrainer.util.config.duration import EpochsConfig as EpochsConfig
6
7
  from nshtrainer.util.config.duration import StepsConfig as StepsConfig
7
8
 
8
9
  __all__ = [
10
+ "DurationConfig",
9
11
  "EpochsConfig",
10
12
  "StepsConfig",
11
13
  ]
@@ -1,12 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import builtins
4
- from typing import Literal
4
+ from typing import Any, Literal
5
5
 
6
6
  import nshconfig as C
7
7
 
8
- from ..util._useful_types import SupportsRichComparisonT
9
-
10
8
 
11
9
  class MetricConfig(C.Config):
12
10
  name: str
@@ -40,5 +38,5 @@ class MetricConfig(C.Config):
40
38
  def best(self):
41
39
  return builtins.min if self.mode == "min" else builtins.max
42
40
 
43
- def is_better(self, a: SupportsRichComparisonT, b: SupportsRichComparisonT) -> bool:
41
+ def is_better(self, a: Any, b: Any):
44
42
  return self.best(a, b) == a
@@ -82,6 +82,13 @@ class PluginConfigBase(C.Config, ABC):
82
82
 
83
83
 
84
84
  plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
85
+ PluginConfig = TypeAliasType(
86
+ "PluginConfig", Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]
87
+ )
88
+
89
+ AcceleratorLiteral = TypeAliasType(
90
+ "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
91
+ )
85
92
 
86
93
 
87
94
  class AcceleratorConfigBase(C.Config, ABC):
@@ -90,18 +97,9 @@ class AcceleratorConfigBase(C.Config, ABC):
90
97
 
91
98
 
92
99
  accelerator_registry = C.Registry(AcceleratorConfigBase, discriminator="name")
93
-
94
-
95
- class StrategyConfigBase(C.Config, ABC):
96
- @abstractmethod
97
- def create_strategy(self) -> Strategy: ...
98
-
99
-
100
- strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
101
-
102
-
103
- AcceleratorLiteral = TypeAliasType(
104
- "AcceleratorLiteral", Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"]
100
+ AcceleratorConfig = TypeAliasType(
101
+ "AcceleratorConfig",
102
+ Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()],
105
103
  )
106
104
 
107
105
  StrategyLiteral = TypeAliasType(
@@ -137,6 +135,17 @@ StrategyLiteral = TypeAliasType(
137
135
  )
138
136
 
139
137
 
138
+ class StrategyConfigBase(C.Config, ABC):
139
+ @abstractmethod
140
+ def create_strategy(self) -> Strategy: ...
141
+
142
+
143
+ strategy_registry = C.Registry(StrategyConfigBase, discriminator="name")
144
+ StrategyConfig = TypeAliasType(
145
+ "StrategyConfig",
146
+ Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()],
147
+ )
148
+
140
149
  CheckpointCallbackConfig = TypeAliasType(
141
150
  "CheckpointCallbackConfig",
142
151
  Annotated[
@@ -578,9 +587,7 @@ class TrainerConfig(C.Config):
578
587
  Default: ``False``.
579
588
  """
580
589
 
581
- plugins: (
582
- list[Annotated[PluginConfigBase, plugin_registry.DynamicResolution()]] | None
583
- ) = None
590
+ plugins: list[PluginConfig] | None = None
584
591
  """
585
592
  Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
586
593
  Default: ``None``.
@@ -740,21 +747,13 @@ class TrainerConfig(C.Config):
740
747
  Default: ``True``.
741
748
  """
742
749
 
743
- accelerator: (
744
- Annotated[AcceleratorConfigBase, accelerator_registry.DynamicResolution()]
745
- | AcceleratorLiteral
746
- | None
747
- ) = None
750
+ accelerator: AcceleratorConfig | AcceleratorLiteral | None = None
748
751
  """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
749
752
  as well as custom accelerator instances.
750
753
  Default: ``"auto"``.
751
754
  """
752
755
 
753
- strategy: (
754
- Annotated[StrategyConfigBase, strategy_registry.DynamicResolution()]
755
- | StrategyLiteral
756
- | None
757
- ) = None
756
+ strategy: StrategyConfig | StrategyLiteral | None = None
758
757
  """Supports different training strategies with aliases as well custom strategies.
759
758
  Default: ``"auto"``.
760
759
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 1.0.0b27
3
+ Version: 1.0.0b29
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -31,12 +31,12 @@ nshtrainer/callbacks/shared_parameters.py,sha256=ggMI1krkqN7sGOrjK_I96IsTMYMXHoV
31
31
  nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
32
32
  nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
33
33
  nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
34
- nshtrainer/configs/__init__.py,sha256=crQ0IcNkmXx-dUEyKMg9mhzxWiBYppba6UOKzEgbWzo,9820
34
+ nshtrainer/configs/__init__.py,sha256=zyo4lV9ObB3T3_hhBhzWGNb6MRma4h7QHD3OrypxqEw,10582
35
35
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
36
36
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
37
- nshtrainer/configs/_directory/__init__.py,sha256=7H3fIh9c31ce0r8JpuzEY8bZptI7tiVLNwVtj729HAY,303
37
+ nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
38
38
  nshtrainer/configs/_hf_hub/__init__.py,sha256=VUgQnyEI2ekBxBIV15L09tKdrfGt7eWxnf30DiCLaso,416
39
- nshtrainer/configs/callbacks/__init__.py,sha256=TwXL-GkDa1j3m1GEfIJ-YaBqazm9wm1uQpzUd6135cA,4265
39
+ nshtrainer/configs/callbacks/__init__.py,sha256=APpF2jmafqbS4CoMmDvFADi0wdmXJ_BvFw4QnnQpok0,4353
40
40
  nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JJg9d8iNGpO-9M1LsK4h1cu3NYWniyIyLQ4SauFCzOs,272
41
41
  nshtrainer/configs/callbacks/base/__init__.py,sha256=V694hzF_ubnA-hwTps30PeFbgDSm3I_UIMTnljM3_OI,176
42
42
  nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=zPUItxoYWrMT9i1TxOvVhIeTa0NEFg-nDE5FjHfkP-A,1564
@@ -60,35 +60,35 @@ nshtrainer/configs/callbacks/shared_parameters/__init__.py,sha256=Ivef5jk3RMYQDe
60
60
  nshtrainer/configs/callbacks/timer/__init__.py,sha256=RHOQoREp4NxS_AvKNdc0UuUlS0PnqCxxsuOz5D8h7iM,310
61
61
  nshtrainer/configs/callbacks/wandb_upload_code/__init__.py,sha256=WM9hCGFl2LXDUOgkIGaV3tkdnXnVBasrhIILjbIeFUo,358
62
62
  nshtrainer/configs/callbacks/wandb_watch/__init__.py,sha256=MW-ANrF529DxBhopovPjYEQ7nANX9ttd1K4_bJnKXks,322
63
- nshtrainer/configs/loggers/__init__.py,sha256=WcP-g5dgP_cgLHtv7G2gG72nwq_7KzExSiz9ioQ1SJw,1171
63
+ nshtrainer/configs/loggers/__init__.py,sha256=5wTekL79mQxit8f1K3AMllvb0mKertTzOKfC3gpE2Zk,1251
64
64
  nshtrainer/configs/loggers/_base/__init__.py,sha256=HxPPPePsEjlNuhnjsMgYIl0rwj_iqNKKOBTEk_zIOsM,169
65
65
  nshtrainer/configs/loggers/actsave/__init__.py,sha256=2lZQ4bpbjwd4MuUE_Z_PGbmQjjGtWCZUCtXqKO4dTSc,280
66
66
  nshtrainer/configs/loggers/csv/__init__.py,sha256=M3QGF5GKiRGENy3re6LJKpa4A4RThy1FlmaFuR4cPyo,260
67
67
  nshtrainer/configs/loggers/tensorboard/__init__.py,sha256=FbkYXnSohIX6JN5XyI-9y91IJv_T3VB3IwmpagXAnM4,309
68
68
  nshtrainer/configs/loggers/wandb/__init__.py,sha256=76qb0HhWojf0Ub1x9OkMjtzeXxE67KysBGa-MBbJyC4,651
69
- nshtrainer/configs/lr_scheduler/__init__.py,sha256=19rkQMI8j31SVp-LAmmV3w9J2Lpv0XdyEzoeQT2dPfE,802
69
+ nshtrainer/configs/lr_scheduler/__init__.py,sha256=8ORO-QC12SjZ2F_reMoDgr8-O8nxZxX0IKU4fl-cC3A,1023
70
70
  nshtrainer/configs/lr_scheduler/_base/__init__.py,sha256=fvGjkUJ1K2RVXjXror22QOtEa-xWFJz2Cz3HrBC5XfA,189
71
- nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=8-KVX8cBJspiELZAbEdJpp8zTkXVik6mn-LNo_Qv27I,412
71
+ nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py,sha256=i8LeZh0c4wqtZ1ehZb2LCq7kwOL0OyswMMOnwyI6R04,533
72
72
  nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py,sha256=lpXEFZY4cM3znZqYG9IZ1xNNtzttt8VVspSuOz0fb-k,467
73
73
  nshtrainer/configs/metrics/__init__.py,sha256=mK_xgXJDAyGY4K_x_Zo4aj36kjT45d850keuUe3U1rY,200
74
74
  nshtrainer/configs/metrics/_config/__init__.py,sha256=XDDvDPWULd_vd3lrgF2KGAVR2LVDuhdQvy-fF2ImarI,159
75
- nshtrainer/configs/nn/__init__.py,sha256=2nSgm07ym2D4yiMQ15pyyb26laM6WWt5KM44PnCDK5A,1864
76
- nshtrainer/configs/nn/mlp/__init__.py,sha256=ZHQX44z6A_DCGYGYd0N_AzaHguiVBMwexMBJD6TjppQ,250
77
- nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=GnJmCX0CwH6GLbs1RgXYrtmxfcmbD0GlujrxYIRr4ms,1934
78
- nshtrainer/configs/optimizer/__init__.py,sha256=2N9LK2Dm4w45Ff0tlSG4Z52Ne4OAujNT5s6_K-bE-qA,253
79
- nshtrainer/configs/profiler/__init__.py,sha256=AGkoWizikpXOKE3YQ5wSBrHLHAV3ZNf5rSl8YuEb84s,681
75
+ nshtrainer/configs/nn/__init__.py,sha256=3hVc81Gs9AJYVkrwJkQ_ye7tLU2HOLdBj-mMkXx2c_I,1957
76
+ nshtrainer/configs/nn/mlp/__init__.py,sha256=eMECrgz-My9mFS7lpWVI3dj1ApB-E7xwfmNc37hUsPI,347
77
+ nshtrainer/configs/nn/nonlinearity/__init__.py,sha256=Gjr2HCx8jJTcfu7sLgn54o2ucGKaBea4encm4AWpKNY,2040
78
+ nshtrainer/configs/optimizer/__init__.py,sha256=IMEsEbiVFXSkj6WmDjNjmKQuRspphs5xZnYZ2gYE39Y,344
79
+ nshtrainer/configs/profiler/__init__.py,sha256=2ssaIpfVnvcbfNvZ-JeKp1Cx4NO1LknkVqTm1hu7Lvw,768
80
80
  nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRDm1CKqjwUOQNbQjD4,176
81
81
  nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
82
82
  nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
83
83
  nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
84
- nshtrainer/configs/trainer/__init__.py,sha256=jRbJylnfPa483iHzN-ZYDObInUfMxuql47gFMsBlJKU,3527
85
- nshtrainer/configs/trainer/_config/__init__.py,sha256=tAWlUTtn2EeQ8xnKft4CA4gVwXn2nf9Yrs57em8jz70,3438
84
+ nshtrainer/configs/trainer/__init__.py,sha256=KIDYjJsc-WYXKiH2RNzAZJD5MKOTdO9wdtu_vWDNPxU,3936
85
+ nshtrainer/configs/trainer/_config/__init__.py,sha256=1_Ad5uTvXdVuHMJB3s8s-0EraDwNZssg3sXBmVouF9w,3847
86
86
  nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
87
- nshtrainer/configs/util/__init__.py,sha256=gtYtZ4VGwEvF9_hByZl8CWOSeDpEOIkkcLtUwvNbSEQ,2014
87
+ nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
88
88
  nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
89
- nshtrainer/configs/util/config/__init__.py,sha256=GHQZR-M0LwL7Qow2oCgmaWwz9h16NkfiWpxIT9cF52Y,411
89
+ nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
90
90
  nshtrainer/configs/util/config/dtype/__init__.py,sha256=PmGF-O4r6SXqEaagVsQ5YxEqhdVdcU0dgJW1Ljzpp6k,158
91
- nshtrainer/configs/util/config/duration/__init__.py,sha256=rja-dB-WC2criHrSBC7gkl5GnWeXQ3bHD48zy9kEPbo,254
91
+ nshtrainer/configs/util/config/duration/__init__.py,sha256=44lS2irOIPVfgshMTfnZM2jC6l0Pjst9w2M_lJoS_MU,353
92
92
  nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
93
93
  nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
94
94
  nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
@@ -104,7 +104,7 @@ nshtrainer/lr_scheduler/_base.py,sha256=EhA2f_WiZ79RcXL2nJbwCwNK620c8ugEVUmJ8CcV
104
104
  nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=gvUuv031lvWdXboDeH7iAF3ZgNPQK40bQwfmqb11TNk,5492
105
105
  nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5NGoZYkXeFjAOgttaUX8,2870
106
106
  nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
107
- nshtrainer/metrics/_config.py,sha256=kR9OJLZXIaZpG8UpHG3EQwOL36HB8yPk6afYjoIM0XM,1324
107
+ nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
108
108
  nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
109
109
  nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
110
110
  nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
@@ -122,12 +122,11 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
122
122
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
123
123
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
124
124
  nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
125
- nshtrainer/trainer/_config.py,sha256=Mz9J2ZFqxTlttnRA1eScGRgSAuf3-o3i9-xjN7eTm-k,35256
125
+ nshtrainer/trainer/_config.py,sha256=VD0DfdS-pyQ2nFG83c4u5AUkSAHODmXLX5s2qtvS_to,35400
126
126
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
127
127
  nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
128
128
  nshtrainer/trainer/trainer.py,sha256=HHqT83zWtYY9g5yD6X9aWrVh5VSpILW8PhoE6fp4snE,20734
129
129
  nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
130
- nshtrainer/util/_useful_types.py,sha256=7yd1ajSmjwfmZdBPlHVrIG3iXl1-T3n83JI53N8C7as,8080
131
130
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
132
131
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
133
132
  nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
@@ -138,6 +137,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
138
137
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
139
138
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
140
139
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
141
- nshtrainer-1.0.0b27.dist-info/METADATA,sha256=-eLqorpTOufpf0XwVyHmk2_nsgI3NdETpNPYa3uhHy0,988
142
- nshtrainer-1.0.0b27.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
143
- nshtrainer-1.0.0b27.dist-info/RECORD,,
140
+ nshtrainer-1.0.0b29.dist-info/METADATA,sha256=YRehZvU9svmmfAmwFrdmu-Tzxgi_EHbFwrn-ewD8W9c,988
141
+ nshtrainer-1.0.0b29.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
142
+ nshtrainer-1.0.0b29.dist-info/RECORD,,
@@ -1,316 +0,0 @@
1
- """Credit to useful-types from https://github.com/hauntsaninja/useful_types"""
2
-
3
- from __future__ import annotations
4
-
5
- from collections.abc import Awaitable, Iterable, Iterator, Sequence, Sized
6
- from collections.abc import Set as AbstractSet
7
- from os import PathLike
8
- from typing import Any, TypeVar, overload
9
-
10
- from typing_extensions import (
11
- Buffer,
12
- Literal,
13
- Protocol,
14
- SupportsIndex,
15
- TypeAlias,
16
- TypeAliasType,
17
- )
18
-
19
- _KT = TypeVar("_KT")
20
- _KT_co = TypeVar("_KT_co", covariant=True)
21
- _KT_contra = TypeVar("_KT_contra", contravariant=True)
22
- _VT = TypeVar("_VT")
23
- _VT_co = TypeVar("_VT_co", covariant=True)
24
- _T = TypeVar("_T")
25
- _T_co = TypeVar("_T_co", covariant=True)
26
- _T_contra = TypeVar("_T_contra", contravariant=True)
27
-
28
- # For partially known annotations. Usually, fields where type annotations
29
- # haven't been added are left unannotated, but in some situations this
30
- # isn't possible or a type is already partially known. In cases like these,
31
- # use Incomplete instead of Any as a marker. For example, use
32
- # "Incomplete | None" instead of "Any | None".
33
- Incomplete: TypeAlias = Any
34
-
35
-
36
- class IdentityFunction(Protocol):
37
- def __call__(self, __x: _T) -> _T: ...
38
-
39
-
40
- # ====================
41
- # Comparison protocols
42
- # ====================
43
-
44
-
45
- class SupportsDunderLT(Protocol[_T_contra]):
46
- def __lt__(self, __other: _T_contra) -> bool: ...
47
-
48
-
49
- class SupportsDunderGT(Protocol[_T_contra]):
50
- def __gt__(self, __other: _T_contra) -> bool: ...
51
-
52
-
53
- class SupportsDunderLE(Protocol[_T_contra]):
54
- def __le__(self, __other: _T_contra) -> bool: ...
55
-
56
-
57
- class SupportsDunderGE(Protocol[_T_contra]):
58
- def __ge__(self, __other: _T_contra) -> bool: ...
59
-
60
-
61
- class SupportsAllComparisons(
62
- SupportsDunderLT[Any],
63
- SupportsDunderGT[Any],
64
- SupportsDunderLE[Any],
65
- SupportsDunderGE[Any],
66
- Protocol,
67
- ): ...
68
-
69
-
70
- SupportsRichComparison = TypeAliasType(
71
- "SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
72
- )
73
- SupportsRichComparisonT = TypeVar(
74
- "SupportsRichComparisonT", bound=SupportsRichComparison
75
- )
76
-
77
- # ====================
78
- # Dunder protocols
79
- # ====================
80
-
81
-
82
- class SupportsNext(Protocol[_T_co]):
83
- def __next__(self) -> _T_co: ...
84
-
85
-
86
- class SupportsAnext(Protocol[_T_co]):
87
- def __anext__(self) -> Awaitable[_T_co]: ...
88
-
89
-
90
- class SupportsAdd(Protocol[_T_contra, _T_co]):
91
- def __add__(self, __x: _T_contra) -> _T_co: ...
92
-
93
-
94
- class SupportsRAdd(Protocol[_T_contra, _T_co]):
95
- def __radd__(self, __x: _T_contra) -> _T_co: ...
96
-
97
-
98
- class SupportsSub(Protocol[_T_contra, _T_co]):
99
- def __sub__(self, __x: _T_contra) -> _T_co: ...
100
-
101
-
102
- class SupportsRSub(Protocol[_T_contra, _T_co]):
103
- def __rsub__(self, __x: _T_contra) -> _T_co: ...
104
-
105
-
106
- class SupportsDivMod(Protocol[_T_contra, _T_co]):
107
- def __divmod__(self, __other: _T_contra) -> _T_co: ...
108
-
109
-
110
- class SupportsRDivMod(Protocol[_T_contra, _T_co]):
111
- def __rdivmod__(self, __other: _T_contra) -> _T_co: ...
112
-
113
-
114
- # This protocol is generic over the iterator type, while Iterable is
115
- # generic over the type that is iterated over.
116
- class SupportsIter(Protocol[_T_co]):
117
- def __iter__(self) -> _T_co: ...
118
-
119
-
120
- # This protocol is generic over the iterator type, while AsyncIterable is
121
- # generic over the type that is iterated over.
122
- class SupportsAiter(Protocol[_T_co]):
123
- def __aiter__(self) -> _T_co: ...
124
-
125
-
126
- class SupportsLenAndGetItem(Protocol[_T_co]):
127
- def __len__(self) -> int: ...
128
- def __getitem__(self, __k: int) -> _T_co: ...
129
-
130
-
131
- class SupportsTrunc(Protocol):
132
- def __trunc__(self) -> int: ...
133
-
134
-
135
- # ====================
136
- # Mapping-like protocols
137
- # ====================
138
-
139
-
140
- class SupportsItems(Protocol[_KT_co, _VT_co]):
141
- def items(self) -> AbstractSet[tuple[_KT_co, _VT_co]]: ...
142
-
143
-
144
- class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
145
- def keys(self) -> Iterable[_KT]: ...
146
- def __getitem__(self, __key: _KT) -> _VT_co: ...
147
-
148
-
149
- class SupportsGetItem(Protocol[_KT_contra, _VT_co]):
150
- def __contains__(self, __x: Any) -> bool: ...
151
- def __getitem__(self, __key: _KT_contra) -> _VT_co: ...
152
-
153
-
154
- class SupportsItemAccess(SupportsGetItem[_KT_contra, _VT], Protocol[_KT_contra, _VT]):
155
- def __setitem__(self, __key: _KT_contra, __value: _VT) -> None: ...
156
- def __delitem__(self, __key: _KT_contra) -> None: ...
157
-
158
-
159
- # ====================
160
- # File handling
161
- # ====================
162
-
163
- StrPath: TypeAlias = str | PathLike[str]
164
- BytesPath: TypeAlias = bytes | PathLike[bytes]
165
- StrOrBytesPath: TypeAlias = str | bytes | PathLike[str] | PathLike[bytes]
166
-
167
- OpenTextModeUpdating: TypeAlias = Literal[
168
- "r+",
169
- "+r",
170
- "rt+",
171
- "r+t",
172
- "+rt",
173
- "tr+",
174
- "t+r",
175
- "+tr",
176
- "w+",
177
- "+w",
178
- "wt+",
179
- "w+t",
180
- "+wt",
181
- "tw+",
182
- "t+w",
183
- "+tw",
184
- "a+",
185
- "+a",
186
- "at+",
187
- "a+t",
188
- "+at",
189
- "ta+",
190
- "t+a",
191
- "+ta",
192
- "x+",
193
- "+x",
194
- "xt+",
195
- "x+t",
196
- "+xt",
197
- "tx+",
198
- "t+x",
199
- "+tx",
200
- ]
201
- OpenTextModeWriting: TypeAlias = Literal[
202
- "w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"
203
- ]
204
- OpenTextModeReading: TypeAlias = Literal[
205
- "r", "rt", "tr", "U", "rU", "Ur", "rtU", "rUt", "Urt", "trU", "tUr", "Utr"
206
- ]
207
- OpenTextMode: TypeAlias = (
208
- OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading
209
- )
210
- OpenBinaryModeUpdating: TypeAlias = Literal[
211
- "rb+",
212
- "r+b",
213
- "+rb",
214
- "br+",
215
- "b+r",
216
- "+br",
217
- "wb+",
218
- "w+b",
219
- "+wb",
220
- "bw+",
221
- "b+w",
222
- "+bw",
223
- "ab+",
224
- "a+b",
225
- "+ab",
226
- "ba+",
227
- "b+a",
228
- "+ba",
229
- "xb+",
230
- "x+b",
231
- "+xb",
232
- "bx+",
233
- "b+x",
234
- "+bx",
235
- ]
236
- OpenBinaryModeWriting: TypeAlias = Literal["wb", "bw", "ab", "ba", "xb", "bx"]
237
- OpenBinaryModeReading: TypeAlias = Literal[
238
- "rb", "br", "rbU", "rUb", "Urb", "brU", "bUr", "Ubr"
239
- ]
240
- OpenBinaryMode: TypeAlias = (
241
- OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting
242
- )
243
-
244
-
245
- class HasFileno(Protocol):
246
- def fileno(self) -> int: ...
247
-
248
-
249
- FileDescriptor: TypeAlias = int
250
- FileDescriptorLike: TypeAlias = int | HasFileno
251
- FileDescriptorOrPath: TypeAlias = int | StrOrBytesPath
252
-
253
-
254
- class SupportsRead(Protocol[_T_co]):
255
- def read(self, __length: int = ...) -> _T_co: ...
256
-
257
-
258
- class SupportsReadline(Protocol[_T_co]):
259
- def readline(self, __length: int = ...) -> _T_co: ...
260
-
261
-
262
- class SupportsNoArgReadline(Protocol[_T_co]):
263
- def readline(self) -> _T_co: ...
264
-
265
-
266
- class SupportsWrite(Protocol[_T_contra]):
267
- def write(self, __s: _T_contra) -> object: ...
268
-
269
-
270
- # ====================
271
- # Buffer protocols
272
- # ====================
273
-
274
- # Unfortunately PEP 688 does not allow us to distinguish read-only
275
- # from writable buffers. We use these aliases for readability for now.
276
- # Perhaps a future extension of the buffer protocol will allow us to
277
- # distinguish these cases in the type system.
278
- ReadOnlyBuffer: TypeAlias = Buffer
279
- # Anything that implements the read-write buffer interface.
280
- WriteableBuffer: TypeAlias = Buffer
281
- # Same as WriteableBuffer, but also includes read-only buffer types (like bytes).
282
- ReadableBuffer: TypeAlias = Buffer
283
-
284
-
285
- class SliceableBuffer(Buffer, Protocol):
286
- def __getitem__(self, __slice: slice) -> Sequence[int]: ...
287
-
288
-
289
- class IndexableBuffer(Buffer, Protocol):
290
- def __getitem__(self, __i: int) -> int: ...
291
-
292
-
293
- class SupportsGetItemBuffer(SliceableBuffer, IndexableBuffer, Protocol):
294
- def __contains__(self, __x: Any) -> bool: ...
295
- @overload
296
- def __getitem__(self, __slice: slice) -> Sequence[int]: ...
297
- @overload
298
- def __getitem__(self, __i: int) -> int: ...
299
-
300
-
301
- class SizedBuffer(Sized, Buffer, Protocol): ...
302
-
303
-
304
- # Source from https://github.com/python/typing/issues/256#issuecomment-1442633430
305
- # This works because str.__contains__ does not accept object (either in typeshed or at runtime)
306
- class SequenceNotStr(Protocol[_T_co]):
307
- @overload
308
- def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
309
- @overload
310
- def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
311
- def __contains__(self, value: object, /) -> bool: ...
312
- def __len__(self) -> int: ...
313
- def __iter__(self) -> Iterator[_T_co]: ...
314
- def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
315
- def count(self, value: Any, /) -> int: ...
316
- def __reversed__(self) -> Iterator[_T_co]: ...