betterproto2-compiler 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -83,7 +83,7 @@ def get_type_reference(
83
83
  if unwrap:
84
84
  if source_type in WRAPPER_TYPES:
85
85
  wrapped_type = type(WRAPPER_TYPES[source_type]().value)
86
- return settings.typing_compiler.optional(wrapped_type.__name__)
86
+ return f"{wrapped_type.__name__} | None"
87
87
 
88
88
  if source_type == ".google.protobuf.Duration":
89
89
  return "datetime.timedelta"
@@ -1,3 +1,5 @@
1
+ import typing
2
+
1
3
  import betterproto2
2
4
 
3
5
  from betterproto2_compiler.lib.google.protobuf import Any as VanillaAny
@@ -18,19 +20,37 @@ class Any(VanillaAny):
18
20
  self.type_url = message_pool.type_to_url[type(message)]
19
21
  self.value = bytes(message)
20
22
 
21
- def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message:
23
+ def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message | None:
22
24
  """
23
25
  Return the message packed inside the `Any` object.
24
26
 
25
27
  The target message type must be registered in the message pool, which is done automatically when the module
26
28
  defining the message type is imported.
27
29
  """
30
+ if not self.type_url:
31
+ return None
32
+
28
33
  message_pool = message_pool or default_message_pool
29
34
 
30
- message_type = message_pool.url_to_type[self.type_url]
35
+ try:
36
+ message_type = message_pool.url_to_type[self.type_url]
37
+ except KeyError:
38
+ raise TypeError(f"Can't unpack unregistered type: {self.type_url}")
31
39
 
32
40
  return message_type().parse(self.value)
33
41
 
34
- def to_dict(self) -> dict: # pyright: ignore [reportIncompatibleMethodOverride]
35
- # TOOO improve when dict is updated
36
- return {"@type": self.type_url, "value": self.unpack().to_dict()}
42
+ def to_dict(self, **kwargs) -> dict[str, typing.Any]:
43
+ # TODO allow passing a message pool to `to_dict`
44
+ output: dict[str, typing.Any] = {"@type": self.type_url}
45
+
46
+ value = self.unpack()
47
+
48
+ if value is None:
49
+ return output
50
+
51
+ if type(value).to_dict == betterproto2.Message.to_dict:
52
+ output.update(value.to_dict(**kwargs))
53
+ else:
54
+ output["value"] = value.to_dict(**kwargs)
55
+
56
+ return output
@@ -37,9 +37,9 @@ class Timestamp(VanillaTimestamp):
37
37
  return f"{result}Z"
38
38
  if (nanos % 1e6) == 0:
39
39
  # Serialize 3 fractional digits.
40
- return f"{result}.{int(nanos // 1e6) :03d}Z"
40
+ return f"{result}.{int(nanos // 1e6):03d}Z"
41
41
  if (nanos % 1e3) == 0:
42
42
  # Serialize 6 fractional digits.
43
- return f"{result}.{int(nanos // 1e3) :06d}Z"
43
+ return f"{result}.{int(nanos // 1e3):06d}Z"
44
44
  # Serialize 9 fractional digits.
45
45
  return f"{result}.{nanos:09d}"
@@ -724,35 +724,6 @@ class Any(betterproto2.Message):
724
724
  Must be a valid serialized protocol buffer of the above specified type.
725
725
  """
726
726
 
727
- def pack(self, message: betterproto2.Message, message_pool: "betterproto2.MessagePool | None" = None) -> None:
728
- """
729
- Pack the given message in the `Any` object.
730
-
731
- The message type must be registered in the message pool, which is done automatically when the module defining
732
- the message type is imported.
733
- """
734
- message_pool = message_pool or default_message_pool
735
-
736
- self.type_url = message_pool.type_to_url[type(message)]
737
- self.value = bytes(message)
738
-
739
- def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message:
740
- """
741
- Return the message packed inside the `Any` object.
742
-
743
- The target message type must be registered in the message pool, which is done automatically when the module
744
- defining the message type is imported.
745
- """
746
- message_pool = message_pool or default_message_pool
747
-
748
- message_type = message_pool.url_to_type[self.type_url]
749
-
750
- return message_type().parse(self.value)
751
-
752
- def to_dict(self) -> dict: # pyright: ignore [reportIncompatibleMethodOverride]
753
- # TOOO improve when dict is updated
754
- return {"@type": self.type_url, "value": self.unpack().to_dict()}
755
-
756
727
 
757
728
  default_message_pool.register_message("google.protobuf", "Any", Any)
758
729
 
@@ -57,7 +57,6 @@ from betterproto2_compiler.lib.google.protobuf import (
57
57
  ServiceDescriptorProto,
58
58
  )
59
59
  from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest
60
- from betterproto2_compiler.plugin.typing_compiler import TypingCompiler
61
60
  from betterproto2_compiler.settings import Settings
62
61
 
63
62
  # Organize proto types into categories
@@ -251,14 +250,6 @@ class MessageCompiler(ProtoContentBase):
251
250
  def has_oneof_fields(self) -> bool:
252
251
  return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)
253
252
 
254
- @property
255
- def has_message_field(self) -> bool:
256
- return any(
257
- field.proto_obj.type in PROTO_MESSAGE_TYPES
258
- for field in self.fields
259
- if isinstance(field.proto_obj, FieldDescriptorProto)
260
- )
261
-
262
253
  @property
263
254
  def custom_methods(self) -> list[str]:
264
255
  """
@@ -298,7 +289,6 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
298
289
 
299
290
  @dataclass(kw_only=True)
300
291
  class FieldCompiler(ProtoContentBase):
301
- typing_compiler: TypingCompiler
302
292
  builtins_types: set[str] = field(default_factory=set)
303
293
 
304
294
  message: MessageCompiler
@@ -413,9 +403,9 @@ class FieldCompiler(ProtoContentBase):
413
403
  if self.use_builtins:
414
404
  py_type = f"builtins.{py_type}"
415
405
  if self.repeated:
416
- return self.typing_compiler.list(py_type)
406
+ return f"list[{py_type}]"
417
407
  if self.optional:
418
- return self.typing_compiler.optional(py_type)
408
+ return f"{py_type} | None"
419
409
  return py_type
420
410
 
421
411
 
@@ -449,14 +439,12 @@ class MapEntryCompiler(FieldCompiler):
449
439
  self.py_k_type = FieldCompiler(
450
440
  source_file=self.source_file,
451
441
  proto_obj=nested.field[0], # key
452
- typing_compiler=self.typing_compiler,
453
442
  path=[],
454
443
  message=self.message,
455
444
  ).py_type
456
445
  self.py_v_type = FieldCompiler(
457
446
  source_file=self.source_file,
458
447
  proto_obj=nested.field[1], # value
459
- typing_compiler=self.typing_compiler,
460
448
  path=[],
461
449
  message=self.message,
462
450
  ).py_type
@@ -482,7 +470,7 @@ class MapEntryCompiler(FieldCompiler):
482
470
 
483
471
  @property
484
472
  def annotation(self) -> str:
485
- return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
473
+ return f"dict[{self.py_k_type}, {self.py_v_type}]"
486
474
 
487
475
  @property
488
476
  def repeated(self) -> bool:
@@ -31,11 +31,6 @@ from .models import (
31
31
  is_map,
32
32
  is_oneof,
33
33
  )
34
- from .typing_compiler import (
35
- DirectImportTypingCompiler,
36
- NoTyping310TypingCompiler,
37
- TypingImportTypingCompiler,
38
- )
39
34
 
40
35
 
41
36
  def traverse(
@@ -65,25 +60,7 @@ def traverse(
65
60
 
66
61
 
67
62
  def get_settings(plugin_options: list[str]) -> Settings:
68
- # Gather any typing generation options.
69
- typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")]
70
-
71
- if len(typing_opts) > 1:
72
- raise ValueError("Multiple typing options provided")
73
-
74
- # Set the compiler type.
75
- typing_opt = typing_opts[0] if typing_opts else "direct"
76
- if typing_opt == "direct":
77
- typing_compiler = DirectImportTypingCompiler()
78
- elif typing_opt == "root":
79
- typing_compiler = TypingImportTypingCompiler()
80
- elif typing_opt == "310":
81
- typing_compiler = NoTyping310TypingCompiler()
82
- else:
83
- raise ValueError("Invalid typing option provided")
84
-
85
63
  return Settings(
86
- typing_compiler=typing_compiler,
87
64
  pydantic_dataclasses="pydantic_dataclasses" in plugin_options,
88
65
  )
89
66
 
@@ -203,7 +180,6 @@ def read_protobuf_type(
203
180
  message=message_data,
204
181
  proto_obj=field,
205
182
  path=path + [2, index],
206
- typing_compiler=output_package.settings.typing_compiler,
207
183
  )
208
184
  )
209
185
  elif is_oneof(field):
@@ -213,7 +189,6 @@ def read_protobuf_type(
213
189
  message=message_data,
214
190
  proto_obj=field,
215
191
  path=path + [2, index],
216
- typing_compiler=output_package.settings.typing_compiler,
217
192
  )
218
193
  )
219
194
  else:
@@ -223,7 +198,6 @@ def read_protobuf_type(
223
198
  message=message_data,
224
199
  proto_obj=field,
225
200
  path=path + [2, index],
226
- typing_compiler=output_package.settings.typing_compiler,
227
201
  )
228
202
  )
229
203
 
@@ -1,9 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
 
3
- from .plugin.typing_compiler import TypingCompiler
4
-
5
3
 
6
4
  @dataclass
7
5
  class Settings:
8
6
  pydantic_dataclasses: bool
9
- typing_compiler: TypingCompiler
@@ -21,6 +21,9 @@ __all__ = (
21
21
  import builtins
22
22
  import datetime
23
23
  import warnings
24
+ from collections.abc import AsyncIterable, AsyncIterator, Iterable
25
+ import typing
26
+ from typing import TYPE_CHECKING
24
27
 
25
28
  {% if output_file.settings.pydantic_dataclasses %}
26
29
  from pydantic.dataclasses import dataclass
@@ -29,21 +32,12 @@ from pydantic import model_validator
29
32
  from dataclasses import dataclass
30
33
  {% endif %}
31
34
 
32
- {% set typing_imports = output_file.settings.typing_compiler.imports() %}
33
- {% if typing_imports %}
34
- {% for line in output_file.settings.typing_compiler.import_lines() %}
35
- {{ line }}
36
- {% endfor %}
37
- {% endif %}
38
-
39
35
  import betterproto2
40
36
  {% if output_file.services %}
41
37
  from betterproto2.grpc.grpclib_server import ServiceBase
42
38
  import grpclib
43
39
  {% endif %}
44
40
 
45
- from typing import TYPE_CHECKING
46
-
47
41
  {# Import the message pool of the generated code. #}
48
42
  {% if output_file.package %}
49
43
  from {{ "." * output_file.package.count(".") }}..message_pool import default_message_pool
@@ -100,20 +100,20 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
100
100
  {%- if not method.client_streaming -%}
101
101
  , message:
102
102
  {%- if method.is_input_msg_empty -%}
103
- "{{ output_file.settings.typing_compiler.optional(method.py_input_message_type) }}" = None
103
+ "{{ method.py_input_message_type }} | None" = None
104
104
  {%- else -%}
105
105
  "{{ method.py_input_message_type }}"
106
106
  {%- endif -%}
107
107
  {%- else -%}
108
108
  {# Client streaming: need a request iterator instead #}
109
- , messages: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}"
109
+ , messages: "AsyncIterable[{{ method.py_input_message_type }}] | Iterable[{{ method.py_input_message_type }}]"
110
110
  {%- endif -%}
111
111
  ,
112
112
  *
113
- , timeout: {{ output_file.settings.typing_compiler.optional("float") }} = None
114
- , deadline: "{{ output_file.settings.typing_compiler.optional("Deadline") }}" = None
115
- , metadata: "{{ output_file.settings.typing_compiler.optional("MetadataLike") }}" = None
116
- ) -> "{% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
113
+ , timeout: "float | None" = None
114
+ , deadline: "Deadline | None" = None
115
+ , metadata: "MetadataLike | None" = None
116
+ ) -> "{% if method.server_streaming %}AsyncIterator[{{ method.py_output_message_type }}]{% else %}{{ method.py_output_message_type }}{% endif %}":
117
117
  {% if method.comment %}
118
118
  """
119
119
  {{ method.comment | indent(8) }}
@@ -202,9 +202,9 @@ class {{ service.py_name }}Base(ServiceBase):
202
202
  , message: "{{ method.py_input_message_type }}"
203
203
  {%- else -%}
204
204
  {# Client streaming: need a request iterator instead #}
205
- , messages: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }}
205
+ , messages: "AsyncIterator[{{ method.py_input_message_type }}]"
206
206
  {%- endif -%}
207
- ) -> {% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
207
+ ) -> {% if method.server_streaming %}"AsyncIterator[{{ method.py_output_message_type }}]"{% else %}"{{ method.py_output_message_type }}"{% endif %}:
208
208
  {% if method.comment %}
209
209
  """
210
210
  {{ method.comment | indent(8) }}
@@ -235,7 +235,7 @@ class {{ service.py_name }}Base(ServiceBase):
235
235
 
236
236
  {% endfor %}
237
237
 
238
- def __mapping__(self) -> {{ output_file.settings.typing_compiler.dict("str", "grpclib.const.Handler") }}:
238
+ def __mapping__(self) -> "dict[str, grpclib.const.Handler":
239
239
  return {
240
240
  {% for method in service.methods %}
241
241
  "{{ method.route }}": grpclib.const.Handler(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: betterproto2_compiler
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: Compiler for betterproto2
5
5
  License: MIT
6
6
  Keywords: protobuf,gRPC,compiler
@@ -13,11 +13,12 @@ Classifier: Programming Language :: Python :: 3.10
13
13
  Classifier: Programming Language :: Python :: 3.11
14
14
  Classifier: Programming Language :: Python :: 3.12
15
15
  Classifier: Programming Language :: Python :: 3.13
16
- Requires-Dist: betterproto2 (>=0.2.1,<0.3.0)
16
+ Requires-Dist: betterproto2 (>=0.2.3,<0.3.0)
17
17
  Requires-Dist: grpclib (>=0.4.1,<0.5.0)
18
18
  Requires-Dist: jinja2 (>=3.0.3)
19
- Requires-Dist: ruff (>=0.7.4,<0.8.0)
19
+ Requires-Dist: ruff (>=0.9.3,<0.10.0)
20
20
  Requires-Dist: typing-extensions (>=4.7.1,<5.0.0)
21
+ Project-URL: Documentation, https://betterproto.github.io/python-betterproto2-compiler/
21
22
  Project-URL: Repository, https://github.com/betterproto/python-betterproto2-compiler
22
23
  Description-Content-Type: text/markdown
23
24
 
@@ -1,32 +1,31 @@
1
1
  betterproto2_compiler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  betterproto2_compiler/casing.py,sha256=HSXLXAOqZzEnu-tC1SZjpW0LIjzdPqUNJEwy1BHzfgg,3056
3
3
  betterproto2_compiler/compile/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- betterproto2_compiler/compile/importing.py,sha256=4jPXNBbA3jRDQf5n7GHkC1yvR1cozFoFLKRvu5GIzMk,7152
4
+ betterproto2_compiler/compile/importing.py,sha256=2DF9zpYhjX_H3PCyUHCXhho1J2QdGAtXitlNneZZfFs,7129
5
5
  betterproto2_compiler/compile/naming.py,sha256=zf0VOmNojzyv33upOGelGxjZTEDE8JULEEED5_3inHg,562
6
6
  betterproto2_compiler/known_types/__init__.py,sha256=Exqo-3ubDuik0TZDTw5ZPqf-dVb2uPJTZxMG7X58E6U,780
7
- betterproto2_compiler/known_types/any.py,sha256=QnfSKTXzazZwmsxaFu8_SYNmLww1D2mFdmi29_8Nzb4,1432
7
+ betterproto2_compiler/known_types/any.py,sha256=eRMenvvrn-1Wiss3YxhqRsjZ4XqiqPb1YQuinoA8wI4,1899
8
8
  betterproto2_compiler/known_types/duration.py,sha256=jy9GPnQTT9qhi12pQDG_ptdFAdm2gYkq3NH75zUTDOU,895
9
- betterproto2_compiler/known_types/timestamp.py,sha256=lw7kQPYJn76Q2wR4l3nCshj7VZ3s8h_xRcvrGIno4z0,2155
9
+ betterproto2_compiler/known_types/timestamp.py,sha256=dUfJmdrVg1NW-zkquAc8kxMH6nqdxQ2x_MKeP4N5ksY,2153
10
10
  betterproto2_compiler/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  betterproto2_compiler/lib/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- betterproto2_compiler/lib/google/protobuf/__init__.py,sha256=ajfzhKfLLdrg1eESD9LR18chGNDI2EerRtjLTUnWxKY,109428
12
+ betterproto2_compiler/lib/google/protobuf/__init__.py,sha256=dGC8iW2zPQQxxNMGmS2IjY5bOhYUmNvcasB-GC2CtHQ,108191
13
13
  betterproto2_compiler/lib/google/protobuf/compiler/__init__.py,sha256=bWYhEcL4nz0_4H6FTCmN3F419FQWXkY49LGE2hJJ6jM,8906
14
14
  betterproto2_compiler/lib/message_pool.py,sha256=4-cRhhiM6bmfpUJZ8qxc8LEyqHBHpLCcotjbyZxl7JM,71
15
15
  betterproto2_compiler/plugin/__init__.py,sha256=L3pW0b4CvkM5x53x_sYt1kYiSFPO0_vaeH6EQPq9FAM,43
16
16
  betterproto2_compiler/plugin/__main__.py,sha256=vBQ82334kX06ImDbFlPFgiBRiLIinwNk3z8Khs6hd74,31
17
17
  betterproto2_compiler/plugin/compiler.py,sha256=3sPbCtdzjAGftVvoGe5dvxE8uTzJ78hABA-_f-1rEUo,2471
18
18
  betterproto2_compiler/plugin/main.py,sha256=gI2fSWc9U-fn6MOlkLg7iResr2YsXbdOge6SzNWxBAo,1302
19
- betterproto2_compiler/plugin/models.py,sha256=Kfda57yUyicTVs61HWzFZd7WT6BgntJ19VagFM56mSk,21449
19
+ betterproto2_compiler/plugin/models.py,sha256=qm-ZolRu8_R6LHadjZZBuk0Ht61MwXKLa-Y-S88RMxY,20932
20
20
  betterproto2_compiler/plugin/module_validation.py,sha256=JnP8dSN83eJJVDP_UPJsHzq7E7Md3lah0PnKXDbFW5Q,4808
21
- betterproto2_compiler/plugin/parser.py,sha256=9PT8ArcGwnCviIYo1j_dsfPcdroaSVWC0OfLNzI66hU,10410
21
+ betterproto2_compiler/plugin/parser.py,sha256=MIA5-pAIJsng59wk3KYEKBARNCsQEQeetnVZk_MhL0I,9349
22
22
  betterproto2_compiler/plugin/plugin.bat,sha256=lfLT1WguAXqyerLLsRL6BfHA0RqUE6QG79v-1BYVSpI,48
23
- betterproto2_compiler/plugin/typing_compiler.py,sha256=IK6m4ggHXK7HL98Ed_WjvQ_yeWfIpf_fIBZ9SA8UcyM,4873
24
23
  betterproto2_compiler/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
- betterproto2_compiler/settings.py,sha256=d7XTRMywahR9PcOgaycPXLkHgJMSYDffuSO874yyrh0,182
26
- betterproto2_compiler/templates/header.py.j2,sha256=wYBR4yer77dTGM9e1RijuID-mI4GHrmk66OjdVvmBxc,1734
27
- betterproto2_compiler/templates/template.py.j2,sha256=J3eWLM_zBiRQcu7AaLc6LMg54yiifQUpb0l0FdRyNAk,8870
28
- betterproto2_compiler-0.2.2.dist-info/LICENSE.md,sha256=Pgl2pReU-2yw2miGeQ55UFlyzqAZ_EpYVyZ2nWjwRv4,1121
29
- betterproto2_compiler-0.2.2.dist-info/METADATA,sha256=mfl_LYJm49_hnWbg6Dr98uvgM0Vcvxe0Dx-_Qhy9PKE,1099
30
- betterproto2_compiler-0.2.2.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
31
- betterproto2_compiler-0.2.2.dist-info/entry_points.txt,sha256=re3Qg8lLljbVobeeKH2f1FVQZ114wfZkGv3zCZTD8Ok,84
32
- betterproto2_compiler-0.2.2.dist-info/RECORD,,
24
+ betterproto2_compiler/settings.py,sha256=FQwco5j9ViBXtDYoFqog7SlXhX2YcbgEnFP77znYiwc,94
25
+ betterproto2_compiler/templates/header.py.j2,sha256=HrISX0IKmUsVpIlGIT6XlD9e3LgWWPN7jYRpws5_-CY,1609
26
+ betterproto2_compiler/templates/template.py.j2,sha256=Kwbw302GGzSEQ007eSHmHYGF0lHw01n_c0sGbtV9pWI,8419
27
+ betterproto2_compiler-0.2.4.dist-info/LICENSE.md,sha256=Pgl2pReU-2yw2miGeQ55UFlyzqAZ_EpYVyZ2nWjwRv4,1121
28
+ betterproto2_compiler-0.2.4.dist-info/METADATA,sha256=bQ5WSjI-tm0jNx_WK_1fwlCd3qIP37SZF0EpyFBoI9k,1188
29
+ betterproto2_compiler-0.2.4.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
30
+ betterproto2_compiler-0.2.4.dist-info/entry_points.txt,sha256=re3Qg8lLljbVobeeKH2f1FVQZ114wfZkGv3zCZTD8Ok,84
31
+ betterproto2_compiler-0.2.4.dist-info/RECORD,,
@@ -1,163 +0,0 @@
1
- import abc
2
- import builtins
3
- from collections import defaultdict
4
- from collections.abc import Iterator
5
- from dataclasses import (
6
- dataclass,
7
- field,
8
- )
9
-
10
-
11
- class TypingCompiler(metaclass=abc.ABCMeta):
12
- @abc.abstractmethod
13
- def optional(self, type_: str) -> str:
14
- raise NotImplementedError
15
-
16
- @abc.abstractmethod
17
- def list(self, type_: str) -> str:
18
- raise NotImplementedError
19
-
20
- @abc.abstractmethod
21
- def dict(self, key: str, value: str) -> str:
22
- raise NotImplementedError
23
-
24
- @abc.abstractmethod
25
- def union(self, *types: str) -> str:
26
- raise NotImplementedError
27
-
28
- @abc.abstractmethod
29
- def iterable(self, type_: str) -> str:
30
- raise NotImplementedError
31
-
32
- @abc.abstractmethod
33
- def async_iterable(self, type_: str) -> str:
34
- raise NotImplementedError
35
-
36
- @abc.abstractmethod
37
- def async_iterator(self, type_: str) -> str:
38
- raise NotImplementedError
39
-
40
- @abc.abstractmethod
41
- def imports(self) -> builtins.dict[str, set[str] | None]:
42
- """
43
- Returns either the direct import as a key with none as value, or a set of
44
- values to import from the key.
45
- """
46
- raise NotImplementedError
47
-
48
- def import_lines(self) -> Iterator:
49
- imports = self.imports()
50
- for key, value in imports.items():
51
- if value is None:
52
- yield f"import {key}"
53
- else:
54
- yield f"from {key} import ("
55
- for v in sorted(value):
56
- yield f" {v},"
57
- yield ")"
58
-
59
-
60
- @dataclass
61
- class DirectImportTypingCompiler(TypingCompiler):
62
- _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
63
-
64
- def optional(self, type_: str) -> str:
65
- self._imports["typing"].add("Optional")
66
- return f"Optional[{type_}]"
67
-
68
- def list(self, type_: str) -> str:
69
- self._imports["typing"].add("List")
70
- return f"List[{type_}]"
71
-
72
- def dict(self, key: str, value: str) -> str:
73
- self._imports["typing"].add("Dict")
74
- return f"Dict[{key}, {value}]"
75
-
76
- def union(self, *types: str) -> str:
77
- self._imports["typing"].add("Union")
78
- return f"Union[{', '.join(types)}]"
79
-
80
- def iterable(self, type_: str) -> str:
81
- self._imports["typing"].add("Iterable")
82
- return f"Iterable[{type_}]"
83
-
84
- def async_iterable(self, type_: str) -> str:
85
- self._imports["typing"].add("AsyncIterable")
86
- return f"AsyncIterable[{type_}]"
87
-
88
- def async_iterator(self, type_: str) -> str:
89
- self._imports["typing"].add("AsyncIterator")
90
- return f"AsyncIterator[{type_}]"
91
-
92
- def imports(self) -> builtins.dict[str, set[str] | None]:
93
- return {k: v if v else None for k, v in self._imports.items()}
94
-
95
-
96
- @dataclass
97
- class TypingImportTypingCompiler(TypingCompiler):
98
- _imported: bool = False
99
-
100
- def optional(self, type_: str) -> str:
101
- self._imported = True
102
- return f"typing.Optional[{type_}]"
103
-
104
- def list(self, type_: str) -> str:
105
- self._imported = True
106
- return f"typing.List[{type_}]"
107
-
108
- def dict(self, key: str, value: str) -> str:
109
- self._imported = True
110
- return f"typing.Dict[{key}, {value}]"
111
-
112
- def union(self, *types: str) -> str:
113
- self._imported = True
114
- return f"typing.Union[{', '.join(types)}]"
115
-
116
- def iterable(self, type_: str) -> str:
117
- self._imported = True
118
- return f"typing.Iterable[{type_}]"
119
-
120
- def async_iterable(self, type_: str) -> str:
121
- self._imported = True
122
- return f"typing.AsyncIterable[{type_}]"
123
-
124
- def async_iterator(self, type_: str) -> str:
125
- self._imported = True
126
- return f"typing.AsyncIterator[{type_}]"
127
-
128
- def imports(self) -> builtins.dict[str, set[str] | None]:
129
- if self._imported:
130
- return {"typing": None}
131
- return {}
132
-
133
-
134
- @dataclass
135
- class NoTyping310TypingCompiler(TypingCompiler):
136
- _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
137
-
138
- def optional(self, type_: str) -> str:
139
- return f"{type_} | None"
140
-
141
- def list(self, type_: str) -> str:
142
- return f"list[{type_}]"
143
-
144
- def dict(self, key: str, value: str) -> str:
145
- return f"dict[{key}, {value}]"
146
-
147
- def union(self, *types: str) -> str:
148
- return f"{' | '.join(types)}"
149
-
150
- def iterable(self, type_: str) -> str:
151
- self._imports["collections.abc"].add("Iterable")
152
- return f"Iterable[{type_}]"
153
-
154
- def async_iterable(self, type_: str) -> str:
155
- self._imports["collections.abc"].add("AsyncIterable")
156
- return f"AsyncIterable[{type_}]"
157
-
158
- def async_iterator(self, type_: str) -> str:
159
- self._imports["collections.abc"].add("AsyncIterator")
160
- return f"AsyncIterator[{type_}]"
161
-
162
- def imports(self) -> builtins.dict[str, set[str] | None]:
163
- return {k: v if v else None for k, v in self._imports.items()}