betterproto2-compiler 0.2.1__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/PKG-INFO +1 -1
  2. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/pyproject.toml +1 -1
  3. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/casing.py +12 -30
  4. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/compile/importing.py +8 -3
  5. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/plugin/models.py +3 -26
  6. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/plugin/parser.py +0 -26
  7. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/settings.py +0 -3
  8. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/templates/header.py.j2 +2 -9
  9. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/templates/template.py.j2 +19 -19
  10. betterproto2_compiler-0.2.1/src/betterproto2_compiler/plugin/typing_compiler.py +0 -163
  11. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/LICENSE.md +0 -0
  12. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/README.md +0 -0
  13. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/__init__.py +0 -0
  14. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/compile/__init__.py +0 -0
  15. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/compile/naming.py +0 -0
  16. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/known_types/__init__.py +0 -0
  17. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/known_types/any.py +0 -0
  18. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/known_types/duration.py +0 -0
  19. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/known_types/timestamp.py +0 -0
  20. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/lib/__init__.py +0 -0
  21. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/lib/google/__init__.py +0 -0
  22. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/lib/google/protobuf/__init__.py +0 -0
  23. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/lib/google/protobuf/compiler/__init__.py +0 -0
  24. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/lib/message_pool.py +0 -0
  25. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/plugin/__init__.py +0 -0
  26. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/plugin/__main__.py +0 -0
  27. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/plugin/compiler.py +0 -0
  28. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/plugin/main.py +0 -0
  29. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/plugin/module_validation.py +0 -0
  30. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/plugin/plugin.bat +0 -0
  31. {betterproto2_compiler-0.2.1 → betterproto2_compiler-0.2.3}/src/betterproto2_compiler/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: betterproto2_compiler
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: Compiler for betterproto2
5
5
  License: MIT
6
6
  Keywords: protobuf,gRPC,compiler
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "betterproto2_compiler"
3
- version = "0.2.1"
3
+ version = "0.2.3"
4
4
  description = "Compiler for betterproto2"
5
5
  authors = ["Adrien Vannson <adrien.vannson@protonmail.com>", "Daniel G. Taylor <danielgtaylor@gmail.com>"]
6
6
  readme = "README.md"
@@ -21,43 +21,25 @@ def safe_snake_case(value: str) -> str:
21
21
  return value
22
22
 
23
23
 
24
- def snake_case(value: str, strict: bool = True) -> str:
24
+ def snake_case(name: str) -> str:
25
25
  """
26
26
  Join words with an underscore into lowercase and remove symbols.
27
+ """
27
28
 
28
- Parameters
29
- -----------
30
- value: :class:`str`
31
- The value to convert.
32
- strict: :class:`bool`
33
- Whether or not to force single underscores.
29
+ # If there are already underscores in the name, don't break it
30
+ if "_" in name or not any([c.isupper() for c in name]):
31
+ return name
34
32
 
35
- Returns
36
- --------
37
- :class:`str`
38
- The value in snake_case.
39
- """
33
+ # Add an underscore before capital letters
34
+ name = re.sub(r"(?<=[a-z0-9])([A-Z])", r"_\1", name)
40
35
 
41
- def substitute_word(symbols: str, word: str, is_start: bool) -> str:
42
- if not word:
43
- return ""
44
- if strict:
45
- delimiter_count = 0 if is_start else 1 # Single underscore if strict.
46
- elif is_start:
47
- delimiter_count = len(symbols)
48
- elif word.isupper() or word.islower():
49
- delimiter_count = max(1, len(symbols)) # Preserve all delimiters if not strict.
50
- else:
51
- delimiter_count = len(symbols) + 1 # Extra underscore for leading capital.
36
+ # Add an underscore before capital letters following an acronym
37
+ name = re.sub(r"(?<=[A-Z])([A-Z])(?=[a-z])", r"_\1", name)
52
38
 
53
- return ("_" * delimiter_count) + word.lower()
39
+ # Add an underscore before digits
40
+ name = re.sub(r"(?<=[a-zA-Z])([0-9])", r"_\1", name)
54
41
 
55
- snake = re.sub(
56
- f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})",
57
- lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None),
58
- value,
59
- )
60
- return snake
42
+ return name.lower()
61
43
 
62
44
 
63
45
  def pascal_case(value: str, strict: bool = True) -> str:
@@ -83,7 +83,7 @@ def get_type_reference(
83
83
  if unwrap:
84
84
  if source_type in WRAPPER_TYPES:
85
85
  wrapped_type = type(WRAPPER_TYPES[source_type]().value)
86
- return settings.typing_compiler.optional(wrapped_type.__name__)
86
+ return f"{wrapped_type.__name__} | None"
87
87
 
88
88
  if source_type == ".google.protobuf.Duration":
89
89
  return "datetime.timedelta"
@@ -114,7 +114,7 @@ def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -
114
114
  Returns a reference to a python type located in the root, i.e. sys.path.
115
115
  """
116
116
  string_import = ".".join(py_package)
117
- string_alias = safe_snake_case(string_import)
117
+ string_alias = "__".join([safe_snake_case(name) for name in py_package])
118
118
  imports.add(f"import {string_import} as {string_alias}")
119
119
  return f"{string_alias}.{py_type}"
120
120
 
@@ -175,6 +175,11 @@ def reference_cousin(current_package: list[str], imports: set[str], py_package:
175
175
  string_from = f".{'.' * distance_up}" + ".".join(py_package[len(shared_ancestry) : -1])
176
176
  string_import = py_package[-1]
177
177
  # Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34)
178
- string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__"
178
+ # string_alias = f"{'_' * distance_up}" + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + "__"
179
+ string_alias = (
180
+ f"{'_' * distance_up}"
181
+ + "__".join([safe_snake_case(name) for name in py_package[len(shared_ancestry) :]])
182
+ + "__"
183
+ )
179
184
  imports.add(f"from {string_from} import {string_import} as {string_alias}")
180
185
  return f"{string_alias}.{py_type}"
@@ -57,7 +57,6 @@ from betterproto2_compiler.lib.google.protobuf import (
57
57
  ServiceDescriptorProto,
58
58
  )
59
59
  from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest
60
- from betterproto2_compiler.plugin.typing_compiler import TypingCompiler
61
60
  from betterproto2_compiler.settings import Settings
62
61
 
63
62
  # Organize proto types into categories
@@ -251,14 +250,6 @@ class MessageCompiler(ProtoContentBase):
251
250
  def has_oneof_fields(self) -> bool:
252
251
  return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
253
252
 
254
- @property
255
- def has_message_field(self) -> bool:
256
- return any(
257
- field.proto_obj.type in PROTO_MESSAGE_TYPES
258
- for field in self.fields
259
- if isinstance(field.proto_obj, FieldDescriptorProto)
260
- )
261
-
262
253
  @property
263
254
  def custom_methods(self) -> list[str]:
264
255
  """
@@ -298,7 +289,6 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
298
289
 
299
290
  @dataclass(kw_only=True)
300
291
  class FieldCompiler(ProtoContentBase):
301
- typing_compiler: TypingCompiler
302
292
  builtins_types: set[str] = field(default_factory=set)
303
293
 
304
294
  message: MessageCompiler
@@ -413,9 +403,9 @@ class FieldCompiler(ProtoContentBase):
413
403
  if self.use_builtins:
414
404
  py_type = f"builtins.{py_type}"
415
405
  if self.repeated:
416
- return self.typing_compiler.list(py_type)
406
+ return f"list[{py_type}]"
417
407
  if self.optional:
418
- return self.typing_compiler.optional(py_type)
408
+ return f"{py_type} | None"
419
409
  return py_type
420
410
 
421
411
 
@@ -449,14 +439,12 @@ class MapEntryCompiler(FieldCompiler):
449
439
  self.py_k_type = FieldCompiler(
450
440
  source_file=self.source_file,
451
441
  proto_obj=nested.field[0], # key
452
- typing_compiler=self.typing_compiler,
453
442
  path=[],
454
443
  message=self.message,
455
444
  ).py_type
456
445
  self.py_v_type = FieldCompiler(
457
446
  source_file=self.source_file,
458
447
  proto_obj=nested.field[1], # value
459
- typing_compiler=self.typing_compiler,
460
448
  path=[],
461
449
  message=self.message,
462
450
  ).py_type
@@ -482,7 +470,7 @@ class MapEntryCompiler(FieldCompiler):
482
470
 
483
471
  @property
484
472
  def annotation(self) -> str:
485
- return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
473
+ return f"dict[{self.py_k_type}, {self.py_v_type}]"
486
474
 
487
475
  @property
488
476
  def repeated(self) -> bool:
@@ -600,17 +588,6 @@ class ServiceMethodCompiler(ProtoContentBase):
600
588
 
601
589
  return not bool(msg.fields)
602
590
 
603
- @property
604
- def py_input_message_param(self) -> str:
605
- """Param name corresponding to py_input_message_type.
606
-
607
- Returns
608
- -------
609
- str
610
- Param name corresponding to py_input_message_type.
611
- """
612
- return pythonize_field_name(self.py_input_message_type)
613
-
614
591
  @property
615
592
  def py_output_message_type(self) -> str:
616
593
  """String representation of the Python type corresponding to the
@@ -31,11 +31,6 @@ from .models import (
31
31
  is_map,
32
32
  is_oneof,
33
33
  )
34
- from .typing_compiler import (
35
- DirectImportTypingCompiler,
36
- NoTyping310TypingCompiler,
37
- TypingImportTypingCompiler,
38
- )
39
34
 
40
35
 
41
36
  def traverse(
@@ -65,25 +60,7 @@ def traverse(
65
60
 
66
61
 
67
62
  def get_settings(plugin_options: list[str]) -> Settings:
68
- # Gather any typing generation options.
69
- typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")]
70
-
71
- if len(typing_opts) > 1:
72
- raise ValueError("Multiple typing options provided")
73
-
74
- # Set the compiler type.
75
- typing_opt = typing_opts[0] if typing_opts else "direct"
76
- if typing_opt == "direct":
77
- typing_compiler = DirectImportTypingCompiler()
78
- elif typing_opt == "root":
79
- typing_compiler = TypingImportTypingCompiler()
80
- elif typing_opt == "310":
81
- typing_compiler = NoTyping310TypingCompiler()
82
- else:
83
- raise ValueError("Invalid typing option provided")
84
-
85
63
  return Settings(
86
- typing_compiler=typing_compiler,
87
64
  pydantic_dataclasses="pydantic_dataclasses" in plugin_options,
88
65
  )
89
66
 
@@ -203,7 +180,6 @@ def read_protobuf_type(
203
180
  message=message_data,
204
181
  proto_obj=field,
205
182
  path=path + [2, index],
206
- typing_compiler=output_package.settings.typing_compiler,
207
183
  )
208
184
  )
209
185
  elif is_oneof(field):
@@ -213,7 +189,6 @@ def read_protobuf_type(
213
189
  message=message_data,
214
190
  proto_obj=field,
215
191
  path=path + [2, index],
216
- typing_compiler=output_package.settings.typing_compiler,
217
192
  )
218
193
  )
219
194
  else:
@@ -223,7 +198,6 @@ def read_protobuf_type(
223
198
  message=message_data,
224
199
  proto_obj=field,
225
200
  path=path + [2, index],
226
- typing_compiler=output_package.settings.typing_compiler,
227
201
  )
228
202
  )
229
203
 
@@ -1,9 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
 
3
- from .plugin.typing_compiler import TypingCompiler
4
-
5
3
 
6
4
  @dataclass
7
5
  class Settings:
8
6
  pydantic_dataclasses: bool
9
- typing_compiler: TypingCompiler
@@ -21,6 +21,8 @@ __all__ = (
21
21
  import builtins
22
22
  import datetime
23
23
  import warnings
24
+ from collections.abc import AsyncIterable, AsyncIterator, Iterable
25
+ from typing import TYPE_CHECKING
24
26
 
25
27
  {% if output_file.settings.pydantic_dataclasses %}
26
28
  from pydantic.dataclasses import dataclass
@@ -29,21 +31,12 @@ from pydantic import model_validator
29
31
  from dataclasses import dataclass
30
32
  {% endif %}
31
33
 
32
- {% set typing_imports = output_file.settings.typing_compiler.imports() %}
33
- {% if typing_imports %}
34
- {% for line in output_file.settings.typing_compiler.import_lines() %}
35
- {{ line }}
36
- {% endfor %}
37
- {% endif %}
38
-
39
34
  import betterproto2
40
35
  {% if output_file.services %}
41
36
  from betterproto2.grpc.grpclib_server import ServiceBase
42
37
  import grpclib
43
38
  {% endif %}
44
39
 
45
- from typing import TYPE_CHECKING
46
-
47
40
  {# Import the message pool of the generated code. #}
48
41
  {% if output_file.package %}
49
42
  from {{ "." * output_file.package.count(".") }}..message_pool import default_message_pool
@@ -98,22 +98,22 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
98
98
  {% for method in service.methods %}
99
99
  async def {{ method.py_name }}(self
100
100
  {%- if not method.client_streaming -%}
101
- , {{ method.py_input_message_param }}:
101
+ , message:
102
102
  {%- if method.is_input_msg_empty -%}
103
- "{{ output_file.settings.typing_compiler.optional(method.py_input_message_type) }}" = None
103
+ "{{ method.py_input_message_type }} | None" = None
104
104
  {%- else -%}
105
105
  "{{ method.py_input_message_type }}"
106
106
  {%- endif -%}
107
107
  {%- else -%}
108
108
  {# Client streaming: need a request iterator instead #}
109
- , {{ method.py_input_message_param }}_iterator: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}"
109
+ , messages: "AsyncIterable[{{ method.py_input_message_type }}] | Iterable[{{ method.py_input_message_type }}]"
110
110
  {%- endif -%}
111
111
  ,
112
112
  *
113
- , timeout: {{ output_file.settings.typing_compiler.optional("float") }} = None
114
- , deadline: "{{ output_file.settings.typing_compiler.optional("Deadline") }}" = None
115
- , metadata: "{{ output_file.settings.typing_compiler.optional("MetadataLike") }}" = None
116
- ) -> "{% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
113
+ , timeout: "float | None" = None
114
+ , deadline: "Deadline | None" = None
115
+ , metadata: "MetadataLike | None" = None
116
+ ) -> "{% if method.server_streaming %}AsyncIterator[{{ method.py_output_message_type }}]{% else %}{{ method.py_output_message_type }}{% endif %}":
117
117
  {% if method.comment %}
118
118
  """
119
119
  {{ method.comment | indent(8) }}
@@ -128,7 +128,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
128
128
  {% if method.client_streaming %}
129
129
  async for response in self._stream_stream(
130
130
  "{{ method.route }}",
131
- {{ method.py_input_message_param }}_iterator,
131
+ messages,
132
132
  {{ method.py_input_message_type }},
133
133
  {{ method.py_output_message_type }},
134
134
  timeout=timeout,
@@ -138,13 +138,13 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
138
138
  yield response
139
139
  {% else %}{# i.e. not client streaming #}
140
140
  {% if method.is_input_msg_empty %}
141
- if {{ method.py_input_message_param }} is None:
142
- {{ method.py_input_message_param }} = {{ method.py_input_message_type }}()
141
+ if message is None:
142
+ message = {{ method.py_input_message_type }}()
143
143
 
144
144
  {% endif %}
145
145
  async for response in self._unary_stream(
146
146
  "{{ method.route }}",
147
- {{ method.py_input_message_param }},
147
+ message,
148
148
  {{ method.py_output_message_type }},
149
149
  timeout=timeout,
150
150
  deadline=deadline,
@@ -157,7 +157,7 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
157
157
  {% if method.client_streaming %}
158
158
  return await self._stream_unary(
159
159
  "{{ method.route }}",
160
- {{ method.py_input_message_param }}_iterator,
160
+ messages,
161
161
  {{ method.py_input_message_type }},
162
162
  {{ method.py_output_message_type }},
163
163
  timeout=timeout,
@@ -166,13 +166,13 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
166
166
  )
167
167
  {% else %}{# i.e. not client streaming #}
168
168
  {% if method.is_input_msg_empty %}
169
- if {{ method.py_input_message_param }} is None:
170
- {{ method.py_input_message_param }} = {{ method.py_input_message_type }}()
169
+ if message is None:
170
+ message = {{ method.py_input_message_type }}()
171
171
 
172
172
  {% endif %}
173
173
  return await self._unary_unary(
174
174
  "{{ method.route }}",
175
- {{ method.py_input_message_param }},
175
+ message,
176
176
  {{ method.py_output_message_type }},
177
177
  timeout=timeout,
178
178
  deadline=deadline,
@@ -199,12 +199,12 @@ class {{ service.py_name }}Base(ServiceBase):
199
199
  {% for method in service.methods %}
200
200
  async def {{ method.py_name }}(self
201
201
  {%- if not method.client_streaming -%}
202
- , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
202
+ , message: "{{ method.py_input_message_type }}"
203
203
  {%- else -%}
204
204
  {# Client streaming: need a request iterator instead #}
205
- , {{ method.py_input_message_param }}_iterator: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }}
205
+ , messages: "AsyncIterator[{{ method.py_input_message_type }}]"
206
206
  {%- endif -%}
207
- ) -> {% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
207
+ ) -> {% if method.server_streaming %}"AsyncIterator[{{ method.py_output_message_type }}]"{% else %}"{{ method.py_output_message_type }}"{% endif %}:
208
208
  {% if method.comment %}
209
209
  """
210
210
  {{ method.comment | indent(8) }}
@@ -235,7 +235,7 @@ class {{ service.py_name }}Base(ServiceBase):
235
235
 
236
236
  {% endfor %}
237
237
 
238
- def __mapping__(self) -> {{ output_file.settings.typing_compiler.dict("str", "grpclib.const.Handler") }}:
238
+ def __mapping__(self) -> "dict[str, grpclib.const.Handler":
239
239
  return {
240
240
  {% for method in service.methods %}
241
241
  "{{ method.route }}": grpclib.const.Handler(
@@ -1,163 +0,0 @@
1
- import abc
2
- import builtins
3
- from collections import defaultdict
4
- from collections.abc import Iterator
5
- from dataclasses import (
6
- dataclass,
7
- field,
8
- )
9
-
10
-
11
- class TypingCompiler(metaclass=abc.ABCMeta):
12
- @abc.abstractmethod
13
- def optional(self, type_: str) -> str:
14
- raise NotImplementedError
15
-
16
- @abc.abstractmethod
17
- def list(self, type_: str) -> str:
18
- raise NotImplementedError
19
-
20
- @abc.abstractmethod
21
- def dict(self, key: str, value: str) -> str:
22
- raise NotImplementedError
23
-
24
- @abc.abstractmethod
25
- def union(self, *types: str) -> str:
26
- raise NotImplementedError
27
-
28
- @abc.abstractmethod
29
- def iterable(self, type_: str) -> str:
30
- raise NotImplementedError
31
-
32
- @abc.abstractmethod
33
- def async_iterable(self, type_: str) -> str:
34
- raise NotImplementedError
35
-
36
- @abc.abstractmethod
37
- def async_iterator(self, type_: str) -> str:
38
- raise NotImplementedError
39
-
40
- @abc.abstractmethod
41
- def imports(self) -> builtins.dict[str, set[str] | None]:
42
- """
43
- Returns either the direct import as a key with none as value, or a set of
44
- values to import from the key.
45
- """
46
- raise NotImplementedError
47
-
48
- def import_lines(self) -> Iterator:
49
- imports = self.imports()
50
- for key, value in imports.items():
51
- if value is None:
52
- yield f"import {key}"
53
- else:
54
- yield f"from {key} import ("
55
- for v in sorted(value):
56
- yield f" {v},"
57
- yield ")"
58
-
59
-
60
- @dataclass
61
- class DirectImportTypingCompiler(TypingCompiler):
62
- _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
63
-
64
- def optional(self, type_: str) -> str:
65
- self._imports["typing"].add("Optional")
66
- return f"Optional[{type_}]"
67
-
68
- def list(self, type_: str) -> str:
69
- self._imports["typing"].add("List")
70
- return f"List[{type_}]"
71
-
72
- def dict(self, key: str, value: str) -> str:
73
- self._imports["typing"].add("Dict")
74
- return f"Dict[{key}, {value}]"
75
-
76
- def union(self, *types: str) -> str:
77
- self._imports["typing"].add("Union")
78
- return f"Union[{', '.join(types)}]"
79
-
80
- def iterable(self, type_: str) -> str:
81
- self._imports["typing"].add("Iterable")
82
- return f"Iterable[{type_}]"
83
-
84
- def async_iterable(self, type_: str) -> str:
85
- self._imports["typing"].add("AsyncIterable")
86
- return f"AsyncIterable[{type_}]"
87
-
88
- def async_iterator(self, type_: str) -> str:
89
- self._imports["typing"].add("AsyncIterator")
90
- return f"AsyncIterator[{type_}]"
91
-
92
- def imports(self) -> builtins.dict[str, set[str] | None]:
93
- return {k: v if v else None for k, v in self._imports.items()}
94
-
95
-
96
- @dataclass
97
- class TypingImportTypingCompiler(TypingCompiler):
98
- _imported: bool = False
99
-
100
- def optional(self, type_: str) -> str:
101
- self._imported = True
102
- return f"typing.Optional[{type_}]"
103
-
104
- def list(self, type_: str) -> str:
105
- self._imported = True
106
- return f"typing.List[{type_}]"
107
-
108
- def dict(self, key: str, value: str) -> str:
109
- self._imported = True
110
- return f"typing.Dict[{key}, {value}]"
111
-
112
- def union(self, *types: str) -> str:
113
- self._imported = True
114
- return f"typing.Union[{', '.join(types)}]"
115
-
116
- def iterable(self, type_: str) -> str:
117
- self._imported = True
118
- return f"typing.Iterable[{type_}]"
119
-
120
- def async_iterable(self, type_: str) -> str:
121
- self._imported = True
122
- return f"typing.AsyncIterable[{type_}]"
123
-
124
- def async_iterator(self, type_: str) -> str:
125
- self._imported = True
126
- return f"typing.AsyncIterator[{type_}]"
127
-
128
- def imports(self) -> builtins.dict[str, set[str] | None]:
129
- if self._imported:
130
- return {"typing": None}
131
- return {}
132
-
133
-
134
- @dataclass
135
- class NoTyping310TypingCompiler(TypingCompiler):
136
- _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
137
-
138
- def optional(self, type_: str) -> str:
139
- return f"{type_} | None"
140
-
141
- def list(self, type_: str) -> str:
142
- return f"list[{type_}]"
143
-
144
- def dict(self, key: str, value: str) -> str:
145
- return f"dict[{key}, {value}]"
146
-
147
- def union(self, *types: str) -> str:
148
- return f"{' | '.join(types)}"
149
-
150
- def iterable(self, type_: str) -> str:
151
- self._imports["collections.abc"].add("Iterable")
152
- return f"Iterable[{type_}]"
153
-
154
- def async_iterable(self, type_: str) -> str:
155
- self._imports["collections.abc"].add("AsyncIterable")
156
- return f"AsyncIterable[{type_}]"
157
-
158
- def async_iterator(self, type_: str) -> str:
159
- self._imports["collections.abc"].add("AsyncIterator")
160
- return f"AsyncIterator[{type_}]"
161
-
162
- def imports(self) -> builtins.dict[str, set[str] | None]:
163
- return {k: v if v else None for k, v in self._imports.items()}