nshtrainer 1.0.0b28__py3-none-any.whl → 1.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,163 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Literal
4
+
5
+ from lightning.pytorch.plugins.precision import Precision
6
+ from typing_extensions import override
7
+
8
+ from ...util.config.dtype import DTypeConfig
9
+ from .base import PluginConfigBase, plugin_registry
10
+
11
+
12
+ @plugin_registry.register
13
+ class MixedPrecisionPluginConfig(PluginConfigBase):
14
+ name: Literal["mixed_precision"] = "mixed_precision"
15
+
16
+ precision: Literal["16-mixed", "bf16-mixed"]
17
+ """Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``)."""
18
+
19
+ device: str
20
+ """The device for ``torch.autocast``."""
21
+
22
+ @override
23
+ def create_plugin(self, trainer_config) -> Precision:
24
+ from lightning.pytorch.plugins.precision.amp import MixedPrecision
25
+
26
+ return MixedPrecision(self.precision, self.device)
27
+
28
+
29
+ @plugin_registry.register
30
+ class BitsandbytesPluginConfig(PluginConfigBase):
31
+ name: Literal["bitsandbytes_precision"] = "bitsandbytes_precision"
32
+
33
+ mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"]
34
+ """The quantization mode to use."""
35
+
36
+ dtype: DTypeConfig | None = None
37
+ """The compute dtype to use."""
38
+
39
+ ignore_modules: set[str] | None = None
40
+ """The submodules whose Linear layers should not be replaced.
41
+
42
+ This might be desirable for numerical stability. The string will be checked
43
+ as a prefix, so a value like "transformer.blocks" will ignore all linear
44
+ layers in all of the transformer blocks.
45
+ """
46
+
47
+ @override
48
+ def create_plugin(self, trainer_config) -> Precision:
49
+ from lightning.pytorch.plugins.precision.bitsandbytes import (
50
+ BitsandbytesPrecision,
51
+ )
52
+
53
+ return BitsandbytesPrecision(
54
+ mode=self.mode,
55
+ dtype=self.dtype.torch_dtype if self.dtype is not None else None,
56
+ ignore_modules=self.ignore_modules,
57
+ )
58
+
59
+
60
+ @plugin_registry.register
61
+ class DeepSpeedPluginConfig(PluginConfigBase):
62
+ name: Literal["deepspeed_precision"] = "deepspeed_precision"
63
+
64
+ precision: Literal["16-true", "bf16-true", "16-mixed", "bf16-mixed", "32-true"]
65
+ """Full precision (32-true), half precision (16-true, bf16-true) or
66
+ mixed precision (16-mixed, bf16-mixed)."""
67
+
68
+ @override
69
+ def create_plugin(self, trainer_config) -> Precision:
70
+ from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision
71
+
72
+ return DeepSpeedPrecision(precision=self.precision)
73
+
74
+
75
+ @plugin_registry.register
76
+ class DoublePrecisionPluginConfig(PluginConfigBase):
77
+ name: Literal["double_precision"] = "double_precision"
78
+
79
+ precision: Literal["64-true"] = "64-true"
80
+ """Plugin for training with double (``torch.float64``) precision."""
81
+
82
+ @override
83
+ def create_plugin(self, trainer_config) -> Precision:
84
+ from lightning.pytorch.plugins.precision.double import DoublePrecision
85
+
86
+ return DoublePrecision()
87
+
88
+
89
+ @plugin_registry.register
90
+ class FSDPPrecisionPluginConfig(PluginConfigBase):
91
+ name: Literal["fsdp_precision"] = "fsdp_precision"
92
+
93
+ precision: Literal["16-true", "bf16-true", "16-mixed", "bf16-mixed", "32-true"]
94
+ """Full precision (32-true), half precision (16-true, bf16-true) or
95
+ mixed precision (16-mixed, bf16-mixed)."""
96
+
97
+ @override
98
+ def create_plugin(self, trainer_config) -> Precision:
99
+ from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
100
+
101
+ return FSDPPrecision(precision=self.precision)
102
+
103
+
104
+ @plugin_registry.register
105
+ class HalfPrecisionPluginConfig(PluginConfigBase):
106
+ name: Literal["half_precision"] = "half_precision"
107
+
108
+ precision: Literal["bf16-true", "16-true"]
109
+ """Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``)."""
110
+
111
+ @override
112
+ def create_plugin(self, trainer_config) -> Precision:
113
+ from lightning.pytorch.plugins.precision.half import HalfPrecision
114
+
115
+ return HalfPrecision(precision=self.precision)
116
+
117
+
118
+ @plugin_registry.register
119
+ class TransformerEnginePluginConfig(PluginConfigBase):
120
+ name: Literal["transformer_engine_precision"] = "transformer_engine_precision"
121
+
122
+ weights_dtype: DTypeConfig
123
+ """The weights dtype to use."""
124
+
125
+ recipe: dict[str, Any] | None = None
126
+ """Recipe for the DelayedScaling configuration in dict format."""
127
+
128
+ replace_layers: bool | None = None
129
+ """Whether to replace ``Linear`` and ``LayerNorm`` layers automatically with their
130
+ Transformer Engine alternatives."""
131
+
132
+ fallback_compute_dtype: DTypeConfig | None = None
133
+ """The compute dtype to use for operations that don't support fp8 autocast.
134
+ Defaults to the same as weights_dtype."""
135
+
136
+ @override
137
+ def create_plugin(self, trainer_config) -> Precision:
138
+ from lightning.pytorch.plugins.precision.transformer_engine import (
139
+ TransformerEnginePrecision,
140
+ )
141
+
142
+ return TransformerEnginePrecision(
143
+ weights_dtype=self.weights_dtype.torch_dtype,
144
+ recipe=self.recipe,
145
+ replace_layers=self.replace_layers,
146
+ fallback_compute_dtype=self.fallback_compute_dtype.torch_dtype
147
+ if self.fallback_compute_dtype
148
+ else None,
149
+ )
150
+
151
+
152
+ @plugin_registry.register
153
+ class XLAPluginConfig(PluginConfigBase):
154
+ name: Literal["xla_precision"] = "xla_precision"
155
+
156
+ precision: Literal["32-true", "16-true", "bf16-true"]
157
+ """Full precision (32-true) or half precision (16-true, bf16-true)."""
158
+
159
+ @override
160
+ def create_plugin(self, trainer_config) -> Precision:
161
+ from lightning.pytorch.plugins.precision.xla import XLAPrecision
162
+
163
+ return XLAPrecision(precision=self.precision)
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING, Literal
5
+
6
+ import nshconfig as C
7
+ from lightning.pytorch.strategies.strategy import Strategy
8
+ from typing_extensions import TypeAliasType
9
+
10
+ if TYPE_CHECKING:
11
+ from ._config import TrainerConfig
12
+
13
+ StrategyLiteral = TypeAliasType(
14
+ "StrategyLiteral",
15
+ Literal[
16
+ "auto",
17
+ "ddp",
18
+ "ddp_find_unused_parameters_false",
19
+ "ddp_find_unused_parameters_true",
20
+ "ddp_spawn",
21
+ "ddp_spawn_find_unused_parameters_false",
22
+ "ddp_spawn_find_unused_parameters_true",
23
+ "ddp_fork",
24
+ "ddp_fork_find_unused_parameters_false",
25
+ "ddp_fork_find_unused_parameters_true",
26
+ "ddp_notebook",
27
+ "dp",
28
+ "deepspeed",
29
+ "deepspeed_stage_1",
30
+ "deepspeed_stage_1_offload",
31
+ "deepspeed_stage_2",
32
+ "deepspeed_stage_2_offload",
33
+ "deepspeed_stage_3",
34
+ "deepspeed_stage_3_offload",
35
+ "deepspeed_stage_3_offload_nvme",
36
+ "fsdp",
37
+ "fsdp_cpu_offload",
38
+ "single_xla",
39
+ "xla_fsdp",
40
+ "xla",
41
+ "single_tpu",
42
+ ],
43
+ )
44
+
45
+
46
+ class StrategyConfigBase(C.Config, ABC):
47
+ @abstractmethod
48
+ def create_strategy(self, trainer_config: "TrainerConfig") -> Strategy: ...
49
+
50
+
51
+ StrategyConfig = TypeAliasType("StrategyConfig", StrategyConfigBase)
@@ -22,14 +22,12 @@ from .._checkpoint.metadata import _write_checkpoint_metadata
22
22
  from ..callbacks.base import resolve_all_callbacks
23
23
  from ..util._environment_info import EnvironmentConfig
24
24
  from ..util.bf16 import is_bf16_supported_no_emulation
25
- from ._config import (
26
- AcceleratorConfigBase,
27
- LightningTrainerKwargs,
28
- StrategyConfigBase,
29
- TrainerConfig,
30
- )
25
+ from ._config import LightningTrainerKwargs, TrainerConfig
31
26
  from ._runtime_callback import RuntimeTrackerCallback, Stage
27
+ from .accelerator import AcceleratorConfigBase
28
+ from .plugin import PluginConfigBase
32
29
  from .signal_connector import _SignalConnector
30
+ from .strategy import StrategyConfigBase
33
31
 
34
32
  log = logging.getLogger(__name__)
35
33
 
@@ -172,12 +170,12 @@ class Trainer(LightningTrainer):
172
170
 
173
171
  if (accelerator := hparams.accelerator) is not None:
174
172
  if isinstance(accelerator, AcceleratorConfigBase):
175
- accelerator = accelerator.create_accelerator()
173
+ accelerator = accelerator.create_accelerator(hparams)
176
174
  _update_kwargs(accelerator=accelerator)
177
175
 
178
176
  if (strategy := hparams.strategy) is not None:
179
177
  if isinstance(strategy, StrategyConfigBase):
180
- strategy = strategy.create_strategy()
178
+ strategy = strategy.create_strategy(hparams)
181
179
  _update_kwargs(strategy=strategy)
182
180
 
183
181
  if (precision := hparams.precision) is not None:
@@ -238,7 +236,8 @@ class Trainer(LightningTrainer):
238
236
  if plugin_configs := hparams.plugins:
239
237
  _update_kwargs(
240
238
  plugins=[
241
- plugin_config.create_plugin() for plugin_config in plugin_configs
239
+ plugin_config.create_plugin(hparams)
240
+ for plugin_config in plugin_configs
242
241
  ]
243
242
  )
244
243
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 1.0.0b28
3
+ Version: 1.0.0b30
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,5 +1,5 @@
1
1
  nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
2
- nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
2
+ nshtrainer/__init__.py,sha256=52OB7QRlhrTCIdDecpT7yEZyZM1XvYxywhuORn1eKoY,814
3
3
  nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
5
5
  nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
@@ -31,7 +31,7 @@ nshtrainer/callbacks/shared_parameters.py,sha256=ggMI1krkqN7sGOrjK_I96IsTMYMXHoV
31
31
  nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
32
32
  nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
33
33
  nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
34
- nshtrainer/configs/__init__.py,sha256=zyo4lV9ObB3T3_hhBhzWGNb6MRma4h7QHD3OrypxqEw,10582
34
+ nshtrainer/configs/__init__.py,sha256=eS3naq6EG1vCq28G2nAW1CqYFdsrh6ueBlzX_LazgUw,14159
35
35
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
36
36
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
37
37
  nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
@@ -81,9 +81,17 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
81
81
  nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
82
82
  nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
83
83
  nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
84
- nshtrainer/configs/trainer/__init__.py,sha256=KIDYjJsc-WYXKiH2RNzAZJD5MKOTdO9wdtu_vWDNPxU,3936
85
- nshtrainer/configs/trainer/_config/__init__.py,sha256=1_Ad5uTvXdVuHMJB3s8s-0EraDwNZssg3sXBmVouF9w,3847
86
- nshtrainer/configs/trainer/trainer/__init__.py,sha256=DDuBRx0kVNMW0z_sqKTUt8-Ql7bOpargi4KcHHvDu_c,486
84
+ nshtrainer/configs/trainer/__init__.py,sha256=hKMI_2ve5zcsQys2DDQDv7OmshYsIG0uJlCLreVHpF0,7779
85
+ nshtrainer/configs/trainer/_config/__init__.py,sha256=Xw6I_9tUemDbHncpjKHRqye_e1_OyubK_FJcvdcQ0yc,4020
86
+ nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
87
+ nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
88
+ nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
89
+ nshtrainer/configs/trainer/plugin/environment/__init__.py,sha256=3o16x4qRAOvkJH9Vg4-QwsEODDC6aP_OXRnPPkm_xSo,1376
90
+ nshtrainer/configs/trainer/plugin/io/__init__.py,sha256=W6G67JnigB6d3MiwLrbSKgtIZLUccXznp-IXwkK1J4U,743
91
+ nshtrainer/configs/trainer/plugin/layer_sync/__init__.py,sha256=SYDZk2M6sgpt4sEuoURuS8EKYmaqGcvYxETE9jvTrEE,431
92
+ nshtrainer/configs/trainer/plugin/precision/__init__.py,sha256=szlqSfK2XuWdkf72LQzQFv3SlWfKFdRUpBEYIxQ3TPs,1507
93
+ nshtrainer/configs/trainer/strategy/__init__.py,sha256=50whNloJVBq_bdbLaPQnPBTeS1Rcs8MwxTCYBj1kKa4,273
94
+ nshtrainer/configs/trainer/trainer/__init__.py,sha256=QnuhMQNAa1nSVN2o50_WeKAQG_qkNlkeoq9zTjjwmTI,586
87
95
  nshtrainer/configs/util/__init__.py,sha256=qXittS7f7MyaqJnjvFLKnKsyb6bXTD3dEV16jXVDaH4,2104
88
96
  nshtrainer/configs/util/_environment_info/__init__.py,sha256=eB4E0Ck7XCeSC5gbUdA5thd7TXnjGCL0t8GZIFj7uCI,1644
89
97
  nshtrainer/configs/util/config/__init__.py,sha256=nEFiDG3-dvvTytYn1tEkPFzp7fgaGRp2j7toSN7yRGs,501
@@ -104,7 +112,7 @@ nshtrainer/lr_scheduler/_base.py,sha256=EhA2f_WiZ79RcXL2nJbwCwNK620c8ugEVUmJ8CcV
104
112
  nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=gvUuv031lvWdXboDeH7iAF3ZgNPQK40bQwfmqb11TNk,5492
105
113
  nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5NGoZYkXeFjAOgttaUX8,2870
106
114
  nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
107
- nshtrainer/metrics/_config.py,sha256=kR9OJLZXIaZpG8UpHG3EQwOL36HB8yPk6afYjoIM0XM,1324
115
+ nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
108
116
  nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
109
117
  nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
110
118
  nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
@@ -121,13 +129,20 @@ nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,
121
129
  nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
122
130
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
123
131
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
124
- nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
125
- nshtrainer/trainer/_config.py,sha256=Mz9J2ZFqxTlttnRA1eScGRgSAuf3-o3i9-xjN7eTm-k,35256
132
+ nshtrainer/trainer/__init__.py,sha256=ggDHzIUbABezh4BjEwrxyWuXmuDBV-x4jv9gwXgVHU0,250
133
+ nshtrainer/trainer/_config.py,sha256=0GgofvaWf5Vo9REXNJpTvpVVRlFExGTOzcOt4jwJXNk,34129
126
134
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
135
+ nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
136
+ nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
137
+ nshtrainer/trainer/plugin/base.py,sha256=9-qUHXGpll_yCylun0899sbmJDpyhD9IQcBtVrJx38I,919
138
+ nshtrainer/trainer/plugin/environment.py,sha256=NW0qbsbvDPe59JGOMgPLq1fj7szLucIV1WRTxCrcjF4,4367
139
+ nshtrainer/trainer/plugin/io.py,sha256=nm6YDCVZAhmPvLaLnw6q4BrK2Gj2wvD5ZLDhj1xneEE,2030
140
+ nshtrainer/trainer/plugin/layer_sync.py,sha256=h-ydZwXepnsw5-paLgiDatqPyQ_8C0QEv6DrYFIaXOo,735
141
+ nshtrainer/trainer/plugin/precision.py,sha256=I0QsB1bVxmsFmBOkgrAfGONsuYae_lD9Bz0PfJEQvH4,5598
127
142
  nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
128
- nshtrainer/trainer/trainer.py,sha256=HHqT83zWtYY9g5yD6X9aWrVh5VSpILW8PhoE6fp4snE,20734
143
+ nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
144
+ nshtrainer/trainer/trainer.py,sha256=l2kJs27v4IHZnzxExr0zX0sVex0wukgiD2Wn_0wiGJg,20836
129
145
  nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
130
- nshtrainer/util/_useful_types.py,sha256=7yd1ajSmjwfmZdBPlHVrIG3iXl1-T3n83JI53N8C7as,8080
131
146
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
132
147
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
133
148
  nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
@@ -138,6 +153,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
138
153
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
139
154
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
140
155
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
141
- nshtrainer-1.0.0b28.dist-info/METADATA,sha256=1MJi65pa7HEVmtDR64Y32SwDe_bv1AZHSgyo6gIBmzo,988
142
- nshtrainer-1.0.0b28.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
143
- nshtrainer-1.0.0b28.dist-info/RECORD,,
156
+ nshtrainer-1.0.0b30.dist-info/METADATA,sha256=zxFm4X5APkZR6E4E8-jzVghTwYEYCJQzCHpCV_8hWzg,988
157
+ nshtrainer-1.0.0b30.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
158
+ nshtrainer-1.0.0b30.dist-info/RECORD,,
@@ -1,316 +0,0 @@
1
- """Credit to useful-types from https://github.com/hauntsaninja/useful_types"""
2
-
3
- from __future__ import annotations
4
-
5
- from collections.abc import Awaitable, Iterable, Iterator, Sequence, Sized
6
- from collections.abc import Set as AbstractSet
7
- from os import PathLike
8
- from typing import Any, TypeVar, overload
9
-
10
- from typing_extensions import (
11
- Buffer,
12
- Literal,
13
- Protocol,
14
- SupportsIndex,
15
- TypeAlias,
16
- TypeAliasType,
17
- )
18
-
19
- _KT = TypeVar("_KT")
20
- _KT_co = TypeVar("_KT_co", covariant=True)
21
- _KT_contra = TypeVar("_KT_contra", contravariant=True)
22
- _VT = TypeVar("_VT")
23
- _VT_co = TypeVar("_VT_co", covariant=True)
24
- _T = TypeVar("_T")
25
- _T_co = TypeVar("_T_co", covariant=True)
26
- _T_contra = TypeVar("_T_contra", contravariant=True)
27
-
28
- # For partially known annotations. Usually, fields where type annotations
29
- # haven't been added are left unannotated, but in some situations this
30
- # isn't possible or a type is already partially known. In cases like these,
31
- # use Incomplete instead of Any as a marker. For example, use
32
- # "Incomplete | None" instead of "Any | None".
33
- Incomplete: TypeAlias = Any
34
-
35
-
36
- class IdentityFunction(Protocol):
37
- def __call__(self, __x: _T) -> _T: ...
38
-
39
-
40
- # ====================
41
- # Comparison protocols
42
- # ====================
43
-
44
-
45
- class SupportsDunderLT(Protocol[_T_contra]):
46
- def __lt__(self, __other: _T_contra) -> bool: ...
47
-
48
-
49
- class SupportsDunderGT(Protocol[_T_contra]):
50
- def __gt__(self, __other: _T_contra) -> bool: ...
51
-
52
-
53
- class SupportsDunderLE(Protocol[_T_contra]):
54
- def __le__(self, __other: _T_contra) -> bool: ...
55
-
56
-
57
- class SupportsDunderGE(Protocol[_T_contra]):
58
- def __ge__(self, __other: _T_contra) -> bool: ...
59
-
60
-
61
- class SupportsAllComparisons(
62
- SupportsDunderLT[Any],
63
- SupportsDunderGT[Any],
64
- SupportsDunderLE[Any],
65
- SupportsDunderGE[Any],
66
- Protocol,
67
- ): ...
68
-
69
-
70
- SupportsRichComparison = TypeAliasType(
71
- "SupportsRichComparison", SupportsDunderLT[Any] | SupportsDunderGT[Any]
72
- )
73
- SupportsRichComparisonT = TypeVar(
74
- "SupportsRichComparisonT", bound=SupportsRichComparison
75
- )
76
-
77
- # ====================
78
- # Dunder protocols
79
- # ====================
80
-
81
-
82
- class SupportsNext(Protocol[_T_co]):
83
- def __next__(self) -> _T_co: ...
84
-
85
-
86
- class SupportsAnext(Protocol[_T_co]):
87
- def __anext__(self) -> Awaitable[_T_co]: ...
88
-
89
-
90
- class SupportsAdd(Protocol[_T_contra, _T_co]):
91
- def __add__(self, __x: _T_contra) -> _T_co: ...
92
-
93
-
94
- class SupportsRAdd(Protocol[_T_contra, _T_co]):
95
- def __radd__(self, __x: _T_contra) -> _T_co: ...
96
-
97
-
98
- class SupportsSub(Protocol[_T_contra, _T_co]):
99
- def __sub__(self, __x: _T_contra) -> _T_co: ...
100
-
101
-
102
- class SupportsRSub(Protocol[_T_contra, _T_co]):
103
- def __rsub__(self, __x: _T_contra) -> _T_co: ...
104
-
105
-
106
- class SupportsDivMod(Protocol[_T_contra, _T_co]):
107
- def __divmod__(self, __other: _T_contra) -> _T_co: ...
108
-
109
-
110
- class SupportsRDivMod(Protocol[_T_contra, _T_co]):
111
- def __rdivmod__(self, __other: _T_contra) -> _T_co: ...
112
-
113
-
114
- # This protocol is generic over the iterator type, while Iterable is
115
- # generic over the type that is iterated over.
116
- class SupportsIter(Protocol[_T_co]):
117
- def __iter__(self) -> _T_co: ...
118
-
119
-
120
- # This protocol is generic over the iterator type, while AsyncIterable is
121
- # generic over the type that is iterated over.
122
- class SupportsAiter(Protocol[_T_co]):
123
- def __aiter__(self) -> _T_co: ...
124
-
125
-
126
- class SupportsLenAndGetItem(Protocol[_T_co]):
127
- def __len__(self) -> int: ...
128
- def __getitem__(self, __k: int) -> _T_co: ...
129
-
130
-
131
- class SupportsTrunc(Protocol):
132
- def __trunc__(self) -> int: ...
133
-
134
-
135
- # ====================
136
- # Mapping-like protocols
137
- # ====================
138
-
139
-
140
- class SupportsItems(Protocol[_KT_co, _VT_co]):
141
- def items(self) -> AbstractSet[tuple[_KT_co, _VT_co]]: ...
142
-
143
-
144
- class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
145
- def keys(self) -> Iterable[_KT]: ...
146
- def __getitem__(self, __key: _KT) -> _VT_co: ...
147
-
148
-
149
- class SupportsGetItem(Protocol[_KT_contra, _VT_co]):
150
- def __contains__(self, __x: Any) -> bool: ...
151
- def __getitem__(self, __key: _KT_contra) -> _VT_co: ...
152
-
153
-
154
- class SupportsItemAccess(SupportsGetItem[_KT_contra, _VT], Protocol[_KT_contra, _VT]):
155
- def __setitem__(self, __key: _KT_contra, __value: _VT) -> None: ...
156
- def __delitem__(self, __key: _KT_contra) -> None: ...
157
-
158
-
159
- # ====================
160
- # File handling
161
- # ====================
162
-
163
- StrPath: TypeAlias = str | PathLike[str]
164
- BytesPath: TypeAlias = bytes | PathLike[bytes]
165
- StrOrBytesPath: TypeAlias = str | bytes | PathLike[str] | PathLike[bytes]
166
-
167
- OpenTextModeUpdating: TypeAlias = Literal[
168
- "r+",
169
- "+r",
170
- "rt+",
171
- "r+t",
172
- "+rt",
173
- "tr+",
174
- "t+r",
175
- "+tr",
176
- "w+",
177
- "+w",
178
- "wt+",
179
- "w+t",
180
- "+wt",
181
- "tw+",
182
- "t+w",
183
- "+tw",
184
- "a+",
185
- "+a",
186
- "at+",
187
- "a+t",
188
- "+at",
189
- "ta+",
190
- "t+a",
191
- "+ta",
192
- "x+",
193
- "+x",
194
- "xt+",
195
- "x+t",
196
- "+xt",
197
- "tx+",
198
- "t+x",
199
- "+tx",
200
- ]
201
- OpenTextModeWriting: TypeAlias = Literal[
202
- "w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"
203
- ]
204
- OpenTextModeReading: TypeAlias = Literal[
205
- "r", "rt", "tr", "U", "rU", "Ur", "rtU", "rUt", "Urt", "trU", "tUr", "Utr"
206
- ]
207
- OpenTextMode: TypeAlias = (
208
- OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading
209
- )
210
- OpenBinaryModeUpdating: TypeAlias = Literal[
211
- "rb+",
212
- "r+b",
213
- "+rb",
214
- "br+",
215
- "b+r",
216
- "+br",
217
- "wb+",
218
- "w+b",
219
- "+wb",
220
- "bw+",
221
- "b+w",
222
- "+bw",
223
- "ab+",
224
- "a+b",
225
- "+ab",
226
- "ba+",
227
- "b+a",
228
- "+ba",
229
- "xb+",
230
- "x+b",
231
- "+xb",
232
- "bx+",
233
- "b+x",
234
- "+bx",
235
- ]
236
- OpenBinaryModeWriting: TypeAlias = Literal["wb", "bw", "ab", "ba", "xb", "bx"]
237
- OpenBinaryModeReading: TypeAlias = Literal[
238
- "rb", "br", "rbU", "rUb", "Urb", "brU", "bUr", "Ubr"
239
- ]
240
- OpenBinaryMode: TypeAlias = (
241
- OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting
242
- )
243
-
244
-
245
- class HasFileno(Protocol):
246
- def fileno(self) -> int: ...
247
-
248
-
249
- FileDescriptor: TypeAlias = int
250
- FileDescriptorLike: TypeAlias = int | HasFileno
251
- FileDescriptorOrPath: TypeAlias = int | StrOrBytesPath
252
-
253
-
254
- class SupportsRead(Protocol[_T_co]):
255
- def read(self, __length: int = ...) -> _T_co: ...
256
-
257
-
258
- class SupportsReadline(Protocol[_T_co]):
259
- def readline(self, __length: int = ...) -> _T_co: ...
260
-
261
-
262
- class SupportsNoArgReadline(Protocol[_T_co]):
263
- def readline(self) -> _T_co: ...
264
-
265
-
266
- class SupportsWrite(Protocol[_T_contra]):
267
- def write(self, __s: _T_contra) -> object: ...
268
-
269
-
270
- # ====================
271
- # Buffer protocols
272
- # ====================
273
-
274
- # Unfortunately PEP 688 does not allow us to distinguish read-only
275
- # from writable buffers. We use these aliases for readability for now.
276
- # Perhaps a future extension of the buffer protocol will allow us to
277
- # distinguish these cases in the type system.
278
- ReadOnlyBuffer: TypeAlias = Buffer
279
- # Anything that implements the read-write buffer interface.
280
- WriteableBuffer: TypeAlias = Buffer
281
- # Same as WriteableBuffer, but also includes read-only buffer types (like bytes).
282
- ReadableBuffer: TypeAlias = Buffer
283
-
284
-
285
- class SliceableBuffer(Buffer, Protocol):
286
- def __getitem__(self, __slice: slice) -> Sequence[int]: ...
287
-
288
-
289
- class IndexableBuffer(Buffer, Protocol):
290
- def __getitem__(self, __i: int) -> int: ...
291
-
292
-
293
- class SupportsGetItemBuffer(SliceableBuffer, IndexableBuffer, Protocol):
294
- def __contains__(self, __x: Any) -> bool: ...
295
- @overload
296
- def __getitem__(self, __slice: slice) -> Sequence[int]: ...
297
- @overload
298
- def __getitem__(self, __i: int) -> int: ...
299
-
300
-
301
- class SizedBuffer(Sized, Buffer, Protocol): ...
302
-
303
-
304
- # Source from https://github.com/python/typing/issues/256#issuecomment-1442633430
305
- # This works because str.__contains__ does not accept object (either in typeshed or at runtime)
306
- class SequenceNotStr(Protocol[_T_co]):
307
- @overload
308
- def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
309
- @overload
310
- def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
311
- def __contains__(self, value: object, /) -> bool: ...
312
- def __len__(self) -> int: ...
313
- def __iter__(self) -> Iterator[_T_co]: ...
314
- def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
315
- def count(self, value: Any, /) -> int: ...
316
- def __reversed__(self) -> Iterator[_T_co]: ...