betterproto2-compiler 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,11 +3,6 @@ from __future__ import annotations
3
3
  import os
4
4
  from typing import (
5
5
  TYPE_CHECKING,
6
- Dict,
7
- List,
8
- Set,
9
- Tuple,
10
- Type,
11
6
  )
12
7
 
13
8
  from ..casing import safe_snake_case
@@ -18,7 +13,7 @@ if TYPE_CHECKING:
18
13
  from ..plugin.models import PluginRequestCompiler
19
14
  from ..plugin.typing_compiler import TypingCompiler
20
15
 
21
- WRAPPER_TYPES: Dict[str, Type] = {
16
+ WRAPPER_TYPES: dict[str, type] = {
22
17
  ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
23
18
  ".google.protobuf.FloatValue": google_protobuf.FloatValue,
24
19
  ".google.protobuf.Int32Value": google_protobuf.Int32Value,
@@ -31,7 +26,7 @@ WRAPPER_TYPES: Dict[str, Type] = {
31
26
  }
32
27
 
33
28
 
34
- def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler") -> Tuple[str, str]:
29
+ def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) -> tuple[str, str]:
35
30
  """
36
31
  Split full source type name into package and type name.
37
32
  E.g. 'root.package.Message' -> ('root.package', 'Message')
@@ -77,7 +72,7 @@ def get_type_reference(
77
72
  imports: set,
78
73
  source_type: str,
79
74
  typing_compiler: TypingCompiler,
80
- request: "PluginRequestCompiler",
75
+ request: PluginRequestCompiler,
81
76
  unwrap: bool = True,
82
77
  pydantic: bool = False,
83
78
  ) -> str:
@@ -98,16 +93,16 @@ def get_type_reference(
98
93
 
99
94
  source_package, source_type = parse_source_type_name(source_type, request)
100
95
 
101
- current_package: List[str] = package.split(".") if package else []
102
- py_package: List[str] = source_package.split(".") if source_package else []
96
+ current_package: list[str] = package.split(".") if package else []
97
+ py_package: list[str] = source_package.split(".") if source_package else []
103
98
  py_type: str = pythonize_class_name(source_type)
104
99
 
105
100
  compiling_google_protobuf = current_package == ["google", "protobuf"]
106
101
  importing_google_protobuf = py_package == ["google", "protobuf"]
107
102
  if importing_google_protobuf and not compiling_google_protobuf:
108
- py_package = ["betterproto", "lib"] + (["pydantic"] if pydantic else []) + py_package
103
+ py_package = ["betterproto2", "lib"] + (["pydantic"] if pydantic else []) + py_package
109
104
 
110
- if py_package[:1] == ["betterproto"]:
105
+ if py_package[:1] == ["betterproto2"]:
111
106
  return reference_absolute(imports, py_package, py_type)
112
107
 
113
108
  if py_package == current_package:
@@ -122,7 +117,7 @@ def get_type_reference(
122
117
  return reference_cousin(current_package, imports, py_package, py_type)
123
118
 
124
119
 
125
- def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str:
120
+ def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> str:
126
121
  """
127
122
  Returns a reference to a python type located in the root, i.e. sys.path.
128
123
  """
@@ -139,7 +134,7 @@ def reference_sibling(py_type: str) -> str:
139
134
  return f"{py_type}"
140
135
 
141
136
 
142
- def reference_descendent(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str:
137
+ def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
143
138
  """
144
139
  Returns a reference to a python type in a package that is a descendent of the
145
140
  current package, and adds the required import that is aliased to avoid name
@@ -157,7 +152,7 @@ def reference_descendent(current_package: List[str], imports: Set[str], py_packa
157
152
  return f"{string_import}.{py_type}"
158
153
 
159
154
 
160
- def reference_ancestor(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str:
155
+ def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
161
156
  """
162
157
  Returns a reference to a python type in a package which is an ancestor to the
163
158
  current package, and adds the required import that is aliased (if possible) to avoid
@@ -178,7 +173,7 @@ def reference_ancestor(current_package: List[str], imports: Set[str], py_package
178
173
  return string_alias
179
174
 
180
175
 
181
- def reference_cousin(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str:
176
+ def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
182
177
  """
183
178
  Returns a reference to a python type in a package that is not descendent, ancestor
184
179
  or sibling, and adds the required import that is aliased to avoid name conflicts.
@@ -2401,13 +2401,13 @@ class Value(betterproto2_compiler.Message):
2401
2401
  )
2402
2402
  """Represents a null value."""
2403
2403
 
2404
- number_value: Optional[float] = betterproto2_compiler.double_field(2, optional=True, group="kind")
2404
+ number_value: float | None = betterproto2_compiler.double_field(2, optional=True, group="kind")
2405
2405
  """Represents a double value."""
2406
2406
 
2407
- string_value: Optional[str] = betterproto2_compiler.string_field(3, optional=True, group="kind")
2407
+ string_value: str | None = betterproto2_compiler.string_field(3, optional=True, group="kind")
2408
2408
  """Represents a string value."""
2409
2409
 
2410
- bool_value: Optional[bool] = betterproto2_compiler.bool_field(4, optional=True, group="kind")
2410
+ bool_value: bool | None = betterproto2_compiler.bool_field(4, optional=True, group="kind")
2411
2411
  """Represents a boolean value."""
2412
2412
 
2413
2413
  struct_value: Optional["Struct"] = betterproto2_compiler.message_field(5, optional=True, group="kind")
@@ -78,7 +78,6 @@ from typing import (
78
78
  Dict,
79
79
  List,
80
80
  Mapping,
81
- Optional,
82
81
  )
83
82
 
84
83
  import betterproto2
@@ -1022,7 +1021,7 @@ class FieldDescriptorProto(betterproto2.Message):
1022
1021
  TODO(kenton): Base-64 encode?
1023
1022
  """
1024
1023
 
1025
- oneof_index: Optional[int] = betterproto2.int32_field(9, optional=True)
1024
+ oneof_index: int | None = betterproto2.int32_field(9, optional=True)
1026
1025
  """
1027
1026
  If set, gives the index of a oneof in the containing type's oneof_decl
1028
1027
  list. This field is a member of that oneof.
@@ -1,6 +1,7 @@
1
1
  import os.path
2
2
  import subprocess
3
3
  import sys
4
+ from importlib import metadata
4
5
 
5
6
  from .module_validation import ModuleValidator
6
7
 
@@ -14,7 +15,7 @@ except ImportError as err:
14
15
  "Please ensure that you've installed betterproto as "
15
16
  '`pip install "betterproto[compiler]"` so that compiler dependencies '
16
17
  "are included."
17
- "\033[0m"
18
+ "\033[0m",
18
19
  )
19
20
  raise SystemExit(1)
20
21
 
@@ -24,6 +25,8 @@ from .models import OutputTemplate
24
25
  def outputfile_compiler(output_file: OutputTemplate) -> str:
25
26
  templates_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "templates"))
26
27
 
28
+ version = metadata.version("betterproto2_compiler")
29
+
27
30
  env = jinja2.Environment(
28
31
  trim_blocks=True,
29
32
  lstrip_blocks=True,
@@ -35,7 +38,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
35
38
  header_template = env.get_template("header.py.j2")
36
39
 
37
40
  code = body_template.render(output_file=output_file)
38
- code = header_template.render(output_file=output_file) + "\n" + code
41
+ code = header_template.render(output_file=output_file, version=version) + "\n" + code
39
42
 
40
43
  # Sort imports, delete unused ones
41
44
  code = subprocess.check_output(
@@ -31,22 +31,17 @@ reference to `A` to `B`'s `fields` attribute.
31
31
 
32
32
  import builtins
33
33
  import re
34
+ from collections.abc import Iterable, Iterator
34
35
  from dataclasses import (
35
36
  dataclass,
36
37
  field,
37
38
  )
38
39
  from typing import (
39
- Dict,
40
- Iterable,
41
- Iterator,
42
- List,
43
- Optional,
44
- Set,
45
- Type,
46
40
  Union,
47
41
  )
48
42
 
49
- import betterproto2_compiler
43
+ import betterproto2
44
+
50
45
  from betterproto2_compiler.compile.naming import (
51
46
  pythonize_class_name,
52
47
  pythonize_field_name,
@@ -58,6 +53,7 @@ from betterproto2_compiler.lib.google.protobuf import (
58
53
  FieldDescriptorProto,
59
54
  FieldDescriptorProtoLabel,
60
55
  FieldDescriptorProtoType,
56
+ FieldDescriptorProtoType as FieldType,
61
57
  FileDescriptorProto,
62
58
  MethodDescriptorProto,
63
59
  )
@@ -145,7 +141,7 @@ PROTO_PACKED_TYPES = (
145
141
 
146
142
  def get_comment(
147
143
  proto_file: "FileDescriptorProto",
148
- path: List[int],
144
+ path: list[int],
149
145
  ) -> str:
150
146
  for sci_loc in proto_file.source_code_info.location:
151
147
  if list(sci_loc.path) == path:
@@ -181,10 +177,10 @@ class ProtoContentBase:
181
177
 
182
178
  source_file: FileDescriptorProto
183
179
  typing_compiler: TypingCompiler
184
- path: List[int]
185
- parent: Union["betterproto2_compiler.Message", "OutputTemplate"]
180
+ path: list[int]
181
+ parent: Union["betterproto2.Message", "OutputTemplate"]
186
182
 
187
- __dataclass_fields__: Dict[str, object]
183
+ __dataclass_fields__: dict[str, object]
188
184
 
189
185
  def __post_init__(self) -> None:
190
186
  """Checks that no fake default fields were left as placeholders."""
@@ -224,10 +220,10 @@ class ProtoContentBase:
224
220
  @dataclass
225
221
  class PluginRequestCompiler:
226
222
  plugin_request_obj: CodeGeneratorRequest
227
- output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
223
+ output_packages: dict[str, "OutputTemplate"] = field(default_factory=dict)
228
224
 
229
225
  @property
230
- def all_messages(self) -> List["MessageCompiler"]:
226
+ def all_messages(self) -> list["MessageCompiler"]:
231
227
  """All of the messages in this request.
232
228
 
233
229
  Returns
@@ -249,11 +245,11 @@ class OutputTemplate:
249
245
 
250
246
  parent_request: PluginRequestCompiler
251
247
  package_proto_obj: FileDescriptorProto
252
- input_files: List[str] = field(default_factory=list)
253
- imports_end: Set[str] = field(default_factory=set)
254
- messages: Dict[str, "MessageCompiler"] = field(default_factory=dict)
255
- enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
256
- services: Dict[str, "ServiceCompiler"] = field(default_factory=dict)
248
+ input_files: list[str] = field(default_factory=list)
249
+ imports_end: set[str] = field(default_factory=set)
250
+ messages: dict[str, "MessageCompiler"] = field(default_factory=dict)
251
+ enums: dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
252
+ services: dict[str, "ServiceCompiler"] = field(default_factory=dict)
257
253
  pydantic_dataclasses: bool = False
258
254
  output: bool = True
259
255
  typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
@@ -289,9 +285,9 @@ class MessageCompiler(ProtoContentBase):
289
285
  typing_compiler: TypingCompiler
290
286
  parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
291
287
  proto_obj: DescriptorProto = PLACEHOLDER
292
- path: List[int] = PLACEHOLDER
293
- fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list)
294
- builtins_types: Set[str] = field(default_factory=set)
288
+ path: list[int] = PLACEHOLDER
289
+ fields: list[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list)
290
+ builtins_types: set[str] = field(default_factory=set)
295
291
 
296
292
  def __post_init__(self) -> None:
297
293
  # Add message to output file
@@ -327,11 +323,9 @@ class MessageCompiler(ProtoContentBase):
327
323
  @property
328
324
  def has_message_field(self) -> bool:
329
325
  return any(
330
- (
331
- field.proto_obj.type in PROTO_MESSAGE_TYPES
332
- for field in self.fields
333
- if isinstance(field.proto_obj, FieldDescriptorProto)
334
- )
326
+ field.proto_obj.type in PROTO_MESSAGE_TYPES
327
+ for field in self.fields
328
+ if isinstance(field.proto_obj, FieldDescriptorProto)
335
329
  )
336
330
 
337
331
 
@@ -346,7 +340,7 @@ def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProt
346
340
  map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
347
341
  if message_type == map_entry:
348
342
  for nested in parent_message.nested_type: # parent message
349
- if nested.name.replace("_", "").lower() == map_entry and nested.options.map_entry:
343
+ if nested.name.replace("_", "").lower() == map_entry and nested.options and nested.options.map_entry:
350
344
  return True
351
345
  return False
352
346
 
@@ -373,8 +367,8 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
373
367
  class FieldCompiler(ProtoContentBase):
374
368
  source_file: FileDescriptorProto
375
369
  typing_compiler: TypingCompiler
376
- path: List[int] = PLACEHOLDER
377
- builtins_types: Set[str] = field(default_factory=set)
370
+ path: list[int] = PLACEHOLDER
371
+ builtins_types: set[str] = field(default_factory=set)
378
372
 
379
373
  parent: MessageCompiler = PLACEHOLDER
380
374
  proto_obj: FieldDescriptorProto = PLACEHOLDER
@@ -389,13 +383,16 @@ class FieldCompiler(ProtoContentBase):
389
383
  """Construct string representation of this field as a field."""
390
384
  name = f"{self.py_name}"
391
385
  field_args = ", ".join(([""] + self.betterproto_field_args) if self.betterproto_field_args else [])
392
- betterproto_field_type = f"betterproto2.{self.field_type}_field({self.proto_obj.number}{field_args})"
386
+
387
+ betterproto_field_type = (
388
+ f"betterproto2.field({self.proto_obj.number}, betterproto2.{str(self.field_type)}{field_args})"
389
+ )
393
390
  if self.py_name in dir(builtins):
394
391
  self.parent.builtins_types.add(self.py_name)
395
392
  return f'{name}: "{self.annotation}" = {betterproto_field_type}'
396
393
 
397
394
  @property
398
- def betterproto_field_args(self) -> List[str]:
395
+ def betterproto_field_args(self) -> list[str]:
399
396
  args = []
400
397
  if self.field_wraps:
401
398
  args.append(f"wraps={self.field_wraps}")
@@ -403,9 +400,9 @@ class FieldCompiler(ProtoContentBase):
403
400
  args.append("optional=True")
404
401
  if self.repeated:
405
402
  args.append("repeated=True")
406
- if self.field_type == "enum":
403
+ if self.field_type == FieldType.TYPE_ENUM:
407
404
  t = self.py_type
408
- args.append(f"enum_default_value=lambda: {t}.try_value(0)")
405
+ args.append(f"default_factory=lambda: {t}.try_value(0)")
409
406
  return args
410
407
 
411
408
  @property
@@ -415,29 +412,31 @@ class FieldCompiler(ProtoContentBase):
415
412
  )
416
413
 
417
414
  @property
418
- def field_wraps(self) -> Optional[str]:
415
+ def field_wraps(self) -> str | None:
419
416
  """Returns betterproto wrapped field type or None."""
420
417
  match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name)
421
418
  if match_wrapper:
422
419
  wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
423
- if hasattr(betterproto2_compiler, wrapped_type):
420
+ if hasattr(betterproto2, wrapped_type):
424
421
  return f"betterproto2.{wrapped_type}"
425
422
  return None
426
423
 
427
424
  @property
428
425
  def repeated(self) -> bool:
429
426
  return self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED and not is_map(
430
- self.proto_obj, self.parent
427
+ self.proto_obj,
428
+ self.parent,
431
429
  )
432
430
 
433
431
  @property
434
432
  def optional(self) -> bool:
435
- return self.proto_obj.proto3_optional or (self.field_type == "message" and not self.repeated)
433
+ # TODO not for maps
434
+ return self.proto_obj.proto3_optional or (self.field_type == FieldType.TYPE_MESSAGE and not self.repeated)
436
435
 
437
436
  @property
438
- def field_type(self) -> str:
439
- """String representation of proto field type."""
440
- return FieldDescriptorProtoType(self.proto_obj.type).name.lower().replace("type_", "")
437
+ def field_type(self) -> FieldType:
438
+ # TODO it should be possible to remove constructor
439
+ return FieldType(self.proto_obj.type)
441
440
 
442
441
  @property
443
442
  def packed(self) -> bool:
@@ -499,7 +498,7 @@ class OneOfFieldCompiler(FieldCompiler):
499
498
  return True
500
499
 
501
500
  @property
502
- def betterproto_field_args(self) -> List[str]:
501
+ def betterproto_field_args(self) -> list[str]:
503
502
  args = super().betterproto_field_args
504
503
  group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
505
504
  args.append(f'group="{group}"')
@@ -508,8 +507,8 @@ class OneOfFieldCompiler(FieldCompiler):
508
507
 
509
508
  @dataclass
510
509
  class MapEntryCompiler(FieldCompiler):
511
- py_k_type: Optional[Type] = None
512
- py_v_type: Optional[Type] = None
510
+ py_k_type: type | None = None
511
+ py_v_type: type | None = None
513
512
  proto_k_type: str = ""
514
513
  proto_v_type: str = ""
515
514
 
@@ -546,13 +545,17 @@ class MapEntryCompiler(FieldCompiler):
546
545
 
547
546
  raise ValueError("can't find enum")
548
547
 
549
- @property
550
- def betterproto_field_args(self) -> List[str]:
551
- return [f"betterproto2.{self.proto_k_type}", f"betterproto2.{self.proto_v_type}"]
552
-
553
- @property
554
- def field_type(self) -> str:
555
- return "map"
548
+ def get_field_string(self) -> str:
549
+ """Construct string representation of this field as a field."""
550
+ betterproto_field_type = (
551
+ f"betterproto2.field({self.proto_obj.number}, "
552
+ "betterproto2.TYPE_MAP, "
553
+ f"map_types=(betterproto2.{self.proto_k_type}, "
554
+ f"betterproto2.{self.proto_v_type}))"
555
+ )
556
+ if self.py_name in dir(builtins):
557
+ self.parent.builtins_types.add(self.py_name)
558
+ return f'{self.py_name}: "{self.annotation}" = {betterproto_field_type}'
556
559
 
557
560
  @property
558
561
  def annotation(self) -> str:
@@ -568,7 +571,7 @@ class EnumDefinitionCompiler(MessageCompiler):
568
571
  """Representation of a proto Enum definition."""
569
572
 
570
573
  proto_obj: EnumDescriptorProto = PLACEHOLDER
571
- entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
574
+ entries: list["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
572
575
 
573
576
  @dataclass(unsafe_hash=True)
574
577
  class EnumEntry:
@@ -596,8 +599,8 @@ class ServiceCompiler(ProtoContentBase):
596
599
  source_file: FileDescriptorProto
597
600
  parent: OutputTemplate = PLACEHOLDER
598
601
  proto_obj: DescriptorProto = PLACEHOLDER
599
- path: List[int] = PLACEHOLDER
600
- methods: List["ServiceMethodCompiler"] = field(default_factory=list)
602
+ path: list[int] = PLACEHOLDER
603
+ methods: list["ServiceMethodCompiler"] = field(default_factory=list)
601
604
 
602
605
  def __post_init__(self) -> None:
603
606
  # Add service to output file
@@ -618,7 +621,7 @@ class ServiceMethodCompiler(ProtoContentBase):
618
621
  source_file: FileDescriptorProto
619
622
  parent: ServiceCompiler
620
623
  proto_obj: MethodDescriptorProto
621
- path: List[int] = PLACEHOLDER
624
+ path: list[int] = PLACEHOLDER
622
625
 
623
626
  def __post_init__(self) -> None:
624
627
  # Add method to service
@@ -1,15 +1,10 @@
1
1
  import re
2
2
  from collections import defaultdict
3
+ from collections.abc import Iterator
3
4
  from dataclasses import (
4
5
  dataclass,
5
6
  field,
6
7
  )
7
- from typing import (
8
- Dict,
9
- Iterator,
10
- List,
11
- Tuple,
12
- )
13
8
 
14
9
 
15
10
  @dataclass
@@ -17,7 +12,7 @@ class ModuleValidator:
17
12
  line_iterator: Iterator[str]
18
13
  line_number: int = field(init=False, default=0)
19
14
 
20
- collisions: Dict[str, List[Tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list))
15
+ collisions: dict[str, list[tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list))
21
16
 
22
17
  def add_import(self, imp: str, number: int, full_line: str):
23
18
  """
@@ -1,12 +1,6 @@
1
1
  import pathlib
2
2
  import sys
3
- from typing import (
4
- Generator,
5
- List,
6
- Set,
7
- Tuple,
8
- Union,
9
- )
3
+ from collections.abc import Generator
10
4
 
11
5
  from betterproto2_compiler.lib.google.protobuf import (
12
6
  DescriptorProto,
@@ -45,13 +39,13 @@ from .typing_compiler import (
45
39
 
46
40
  def traverse(
47
41
  proto_file: FileDescriptorProto,
48
- ) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]:
42
+ ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
49
43
  # Todo: Keep information about nested hierarchy
50
44
  def _traverse(
51
- path: List[int],
52
- items: Union[List[EnumDescriptorProto], List[DescriptorProto]],
45
+ path: list[int],
46
+ items: list[EnumDescriptorProto] | list[DescriptorProto],
53
47
  prefix: str = "",
54
- ) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]:
48
+ ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
55
49
  for i, item in enumerate(items):
56
50
  # Adjust the name since we flatten the hierarchy.
57
51
  # Todo: don't change the name, but include full name in returned tuple
@@ -82,7 +76,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
82
76
  if output_package_name not in request_data.output_packages:
83
77
  # Create a new output if there is no output for this package
84
78
  request_data.output_packages[output_package_name] = OutputTemplate(
85
- parent_request=request_data, package_proto_obj=proto_file
79
+ parent_request=request_data,
80
+ package_proto_obj=proto_file,
86
81
  )
87
82
  # Add this input file to the output corresponding to this package
88
83
  request_data.output_packages[output_package_name].input_files.append(proto_file)
@@ -144,7 +139,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
144
139
  service.ready()
145
140
 
146
141
  # Generate output files
147
- output_paths: Set[pathlib.Path] = set()
142
+ output_paths: set[pathlib.Path] = set()
148
143
  for output_package_name, output_package in request_data.output_packages.items():
149
144
  if not output_package.output:
150
145
  continue
@@ -158,7 +153,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
158
153
  name=str(output_path),
159
154
  # Render and then format the output file
160
155
  content=outputfile_compiler(output_file=output_package),
161
- )
156
+ ),
162
157
  )
163
158
 
164
159
  # Make each output directory a package with __init__ file
@@ -183,7 +178,7 @@ def _make_one_of_field_compiler(
183
178
  source_file: "FileDescriptorProto",
184
179
  parent: MessageCompiler,
185
180
  proto_obj: "FieldDescriptorProto",
186
- path: List[int],
181
+ path: list[int],
187
182
  ) -> FieldCompiler:
188
183
  return OneOfFieldCompiler(
189
184
  source_file=source_file,
@@ -196,7 +191,7 @@ def _make_one_of_field_compiler(
196
191
 
197
192
  def read_protobuf_type(
198
193
  item: DescriptorProto,
199
- path: List[int],
194
+ path: list[int],
200
195
  source_file: "FileDescriptorProto",
201
196
  output_package: OutputTemplate,
202
197
  ) -> None:
@@ -1,15 +1,11 @@
1
1
  import abc
2
+ import builtins
2
3
  from collections import defaultdict
4
+ from collections.abc import Iterator
3
5
  from dataclasses import (
4
6
  dataclass,
5
7
  field,
6
8
  )
7
- from typing import (
8
- Dict,
9
- Iterator,
10
- Optional,
11
- Set,
12
- )
13
9
 
14
10
 
15
11
  class TypingCompiler(metaclass=abc.ABCMeta):
@@ -42,7 +38,7 @@ class TypingCompiler(metaclass=abc.ABCMeta):
42
38
  raise NotImplementedError
43
39
 
44
40
  @abc.abstractmethod
45
- def imports(self) -> Dict[str, Optional[Set[str]]]:
41
+ def imports(self) -> builtins.dict[str, set[str] | None]:
46
42
  """
47
43
  Returns either the direct import as a key with none as value, or a set of
48
44
  values to import from the key.
@@ -63,7 +59,7 @@ class TypingCompiler(metaclass=abc.ABCMeta):
63
59
 
64
60
  @dataclass
65
61
  class DirectImportTypingCompiler(TypingCompiler):
66
- _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
62
+ _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
67
63
 
68
64
  def optional(self, type_: str) -> str:
69
65
  self._imports["typing"].add("Optional")
@@ -93,7 +89,7 @@ class DirectImportTypingCompiler(TypingCompiler):
93
89
  self._imports["typing"].add("AsyncIterator")
94
90
  return f"AsyncIterator[{type_}]"
95
91
 
96
- def imports(self) -> Dict[str, Optional[Set[str]]]:
92
+ def imports(self) -> builtins.dict[str, set[str] | None]:
97
93
  return {k: v if v else None for k, v in self._imports.items()}
98
94
 
99
95
 
@@ -129,7 +125,7 @@ class TypingImportTypingCompiler(TypingCompiler):
129
125
  self._imported = True
130
126
  return f"typing.AsyncIterator[{type_}]"
131
127
 
132
- def imports(self) -> Dict[str, Optional[Set[str]]]:
128
+ def imports(self) -> builtins.dict[str, set[str] | None]:
133
129
  if self._imported:
134
130
  return {"typing": None}
135
131
  return {}
@@ -137,7 +133,7 @@ class TypingImportTypingCompiler(TypingCompiler):
137
133
 
138
134
  @dataclass
139
135
  class NoTyping310TypingCompiler(TypingCompiler):
140
- _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
136
+ _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
141
137
 
142
138
  def optional(self, type_: str) -> str:
143
139
  return f"{type_} | None"
@@ -163,5 +159,5 @@ class NoTyping310TypingCompiler(TypingCompiler):
163
159
  self._imports["collections.abc"].add("AsyncIterator")
164
160
  return f"AsyncIterator[{type_}]"
165
161
 
166
- def imports(self) -> Dict[str, Optional[Set[str]]]:
162
+ def imports(self) -> builtins.dict[str, set[str] | None]:
167
163
  return {k: v if v else None for k, v in self._imports.items()}
@@ -48,3 +48,5 @@ if TYPE_CHECKING:
48
48
  import grpclib.server
49
49
  from betterproto2.grpc.grpclib_client import MetadataLike
50
50
  from grpclib.metadata import Deadline
51
+
52
+ betterproto2.check_compiler_version("{{ version }}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: betterproto2_compiler
3
- Version: 0.0.2
3
+ Version: 0.1.0
4
4
  Summary: Compiler for betterproto2
5
5
  Home-page: https://github.com/betterproto/python-betterproto2-compiler
6
6
  License: MIT
@@ -1,41 +1,34 @@
1
1
  betterproto2_compiler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- betterproto2_compiler/_types.py,sha256=nIsUxcId43N1Gu8EqdeuflR9iUZB1JWu4JTGQV9NeUI,294
3
2
  betterproto2_compiler/casing.py,sha256=bMdI4W0hfYh6kV-DQIqFEjSfGYEqUtPciAzP64z5HLQ,3587
4
3
  betterproto2_compiler/compile/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- betterproto2_compiler/compile/importing.py,sha256=SpDU88rUbBYg5EQ4xOmiym8Xrwlg9GbUuIpdKgrLMmo,7440
4
+ betterproto2_compiler/compile/importing.py,sha256=MN1pzFISG96wd6Djsym8q9yfsknBC5logvPrW6qnWyc,7388
6
5
  betterproto2_compiler/compile/naming.py,sha256=zf0VOmNojzyv33upOGelGxjZTEDE8JULEEED5_3inHg,562
7
- betterproto2_compiler/enum.py,sha256=LcILQf1BEjnszouUCtPwifJAR_8u2tf9U9hfwP4vXTc,5396
8
- betterproto2_compiler/grpc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- betterproto2_compiler/grpc/grpclib_client.py,sha256=6yKUqLFEfiMUPT85g81ajvVI4bvH4kBULZa7pIL18Y8,5275
10
- betterproto2_compiler/grpc/grpclib_server.py,sha256=Tv3NIGPrxdA48_THUl3jut_IekQgkecX8NPfjvF0kdg,872
11
- betterproto2_compiler/grpc/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- betterproto2_compiler/grpc/util/async_channel.py,sha256=4sfqoHtS_-qU1GFc0LatnV_dLsmVrGEm74WOJ9RpV1Y,6778
13
6
  betterproto2_compiler/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
7
  betterproto2_compiler/lib/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
8
  betterproto2_compiler/lib/google/protobuf/__init__.py,sha256=RvtT0CjNfAegDv42ITEvaIVrGpEh2fXvGNyA7bd99oo,60
16
9
  betterproto2_compiler/lib/google/protobuf/compiler/__init__.py,sha256=1EvLU05Ck-rwv2Y_jW4ylJlCgqd7VpSHeFgbBFWfAmg,69
17
10
  betterproto2_compiler/lib/pydantic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
11
  betterproto2_compiler/lib/pydantic/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py,sha256=eAvwcNussnueselEJwX0KcmGbhqz4oTux7e1ugCV6KE,96981
12
+ betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py,sha256=RgGj0p9-KkaO-CFGV_CLTXm1yhPf-Vm02cUkF7z39SI,96972
20
13
  betterproto2_compiler/lib/pydantic/google/protobuf/compiler/__init__.py,sha256=uI4F0DytRxbrwXqiRKOJD0nJFewlQE0Lj_3WJdiGoDU,9217
21
14
  betterproto2_compiler/lib/std/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
15
  betterproto2_compiler/lib/std/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
- betterproto2_compiler/lib/std/google/protobuf/__init__.py,sha256=s5Tkt3684xnq2VL0cHZXytN2ja4D73C6ZNf7APEYdcI,80454
16
+ betterproto2_compiler/lib/std/google/protobuf/__init__.py,sha256=bfxwQzYXJ-ImFktGBN5_8HY9JgASJ8BbBg9fKeLTVQ0,80437
24
17
  betterproto2_compiler/lib/std/google/protobuf/compiler/__init__.py,sha256=z63o4vkjnDWB7yZpYnx-T0FZCbmSs5gBzfKrNP_qf8c,8646
25
18
  betterproto2_compiler/plugin/__init__.py,sha256=L3pW0b4CvkM5x53x_sYt1kYiSFPO0_vaeH6EQPq9FAM,43
26
19
  betterproto2_compiler/plugin/__main__.py,sha256=vBQ82334kX06ImDbFlPFgiBRiLIinwNk3z8Khs6hd74,31
27
- betterproto2_compiler/plugin/compiler.py,sha256=J_0WvuOVXuIINogNTOtU9Kyhzbu3NDrKh7ojbMjSjJk,2032
20
+ betterproto2_compiler/plugin/compiler.py,sha256=jICLI4-5rAOkWQI1v5j7JqIvoao-ZM9szMuq0OBRteA,2138
28
21
  betterproto2_compiler/plugin/main.py,sha256=Q9PmcJqXuYYFe51l7AqHVzJrHqi2LWCUu80CZSQOOwk,1469
29
- betterproto2_compiler/plugin/models.py,sha256=Qmpf7HZCGdbUShVBmb7wirFwMKW4xB9phAniwf-H5vk,24314
30
- betterproto2_compiler/plugin/module_validation.py,sha256=vye8PjsZFs1Ikh0yNLQXuy12EdM0em0Bflgx7xYrMhk,4853
31
- betterproto2_compiler/plugin/parser.py,sha256=PFFlK7Di7BF7_tCzkIWUeaRqDfVKJQc9YjSVKeHVXWM,9651
22
+ betterproto2_compiler/plugin/models.py,sha256=7l1P8-ijdatENy-ERO5twHeMpYxLG3YEqhYM9Slm5TY,24655
23
+ betterproto2_compiler/plugin/module_validation.py,sha256=RdPFwdmkbD6NKADaHC5eaPix_pz-yGxHvYJj8Ev48fA,4822
24
+ betterproto2_compiler/plugin/parser.py,sha256=i_HDW0O4ieIatbq1hckOj0kM9MCfdzuib0fumSJqIA4,9610
32
25
  betterproto2_compiler/plugin/plugin.bat,sha256=lfLT1WguAXqyerLLsRL6BfHA0RqUE6QG79v-1BYVSpI,48
33
- betterproto2_compiler/plugin/typing_compiler.py,sha256=gMrKsrA_xFoy33tbm4VjiktXyGAARFGHPL6iVmUiPLU,4866
26
+ betterproto2_compiler/plugin/typing_compiler.py,sha256=IK6m4ggHXK7HL98Ed_WjvQ_yeWfIpf_fIBZ9SA8UcyM,4873
34
27
  betterproto2_compiler/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
- betterproto2_compiler/templates/header.py.j2,sha256=H3R2v5MiCebp5p1zMKCDjsVlnOLklqrwO24SUtjfdN0,1410
28
+ betterproto2_compiler/templates/header.py.j2,sha256=nxqsengMcM_IRqQYNVntPQ0gUFUPCy_1P1mcoLvbDos,1464
36
29
  betterproto2_compiler/templates/template.py.j2,sha256=icyiNdSTJRgyD20e_lTgTAvSjgnSFSn4t1L1-yZnkEM,8712
37
- betterproto2_compiler-0.0.2.dist-info/LICENSE.md,sha256=Pgl2pReU-2yw2miGeQ55UFlyzqAZ_EpYVyZ2nWjwRv4,1121
38
- betterproto2_compiler-0.0.2.dist-info/METADATA,sha256=QtFZcJYtNY44aPlst2o8H30uyuX0nUB5e5QE3CV8rTQ,1163
39
- betterproto2_compiler-0.0.2.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
40
- betterproto2_compiler-0.0.2.dist-info/entry_points.txt,sha256=DE80wLfBwKlvu82d9pAYzEo7Cp22WNqwU7WJZq6JAWk,83
41
- betterproto2_compiler-0.0.2.dist-info/RECORD,,
30
+ betterproto2_compiler-0.1.0.dist-info/LICENSE.md,sha256=Pgl2pReU-2yw2miGeQ55UFlyzqAZ_EpYVyZ2nWjwRv4,1121
31
+ betterproto2_compiler-0.1.0.dist-info/METADATA,sha256=sJiWJ2eOTCN0ARCo0VmNcMlCPgEFli0tfsnM2vYM-Gc,1163
32
+ betterproto2_compiler-0.1.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
33
+ betterproto2_compiler-0.1.0.dist-info/entry_points.txt,sha256=re3Qg8lLljbVobeeKH2f1FVQZ114wfZkGv3zCZTD8Ok,84
34
+ betterproto2_compiler-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ protoc-gen-python_betterproto2=betterproto2_compiler.plugin:main
3
+
@@ -1,13 +0,0 @@
1
- from typing import (
2
- TYPE_CHECKING,
3
- TypeVar,
4
- )
5
-
6
- if TYPE_CHECKING:
7
- from grpclib._typing import IProtoMessage
8
-
9
- from . import Message
10
-
11
- # Bound type variable to allow methods to return `self` of subclasses
12
- T = TypeVar("T", bound="Message")
13
- ST = TypeVar("ST", bound="IProtoMessage")
@@ -1,180 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from enum import (
4
- EnumMeta,
5
- IntEnum,
6
- )
7
- from types import MappingProxyType
8
- from typing import (
9
- TYPE_CHECKING,
10
- Any,
11
- Dict,
12
- Optional,
13
- Tuple,
14
- )
15
-
16
- if TYPE_CHECKING:
17
- from collections.abc import (
18
- Generator,
19
- Mapping,
20
- )
21
-
22
- from typing_extensions import (
23
- Never,
24
- Self,
25
- )
26
-
27
-
28
- def _is_descriptor(obj: object) -> bool:
29
- return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
30
-
31
-
32
- class EnumType(EnumMeta if TYPE_CHECKING else type):
33
- _value_map_: Mapping[int, Enum]
34
- _member_map_: Mapping[str, Enum]
35
-
36
- def __new__(mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]) -> Self:
37
- value_map = {}
38
- member_map = {}
39
-
40
- new_mcs = type(
41
- f"{name}Type",
42
- tuple(
43
- dict.fromkeys([base.__class__ for base in bases if base.__class__ is not type] + [EnumType, type])
44
- ), # reorder the bases so EnumType and type are last to avoid conflicts
45
- {"_value_map_": value_map, "_member_map_": member_map},
46
- )
47
-
48
- members = {
49
- name: value for name, value in namespace.items() if not _is_descriptor(value) and not name.startswith("__")
50
- }
51
-
52
- cls = type.__new__(
53
- new_mcs,
54
- name,
55
- bases,
56
- {key: value for key, value in namespace.items() if key not in members},
57
- )
58
- # this allows us to disallow member access from other members as
59
- # members become proper class variables
60
-
61
- for name, value in members.items():
62
- member = value_map.get(value)
63
- if member is None:
64
- member = cls.__new__(cls, name=name, value=value) # type: ignore
65
- value_map[value] = member
66
- member_map[name] = member
67
- type.__setattr__(new_mcs, name, member)
68
-
69
- return cls
70
-
71
- if not TYPE_CHECKING:
72
-
73
- def __call__(cls, value: int) -> Enum:
74
- try:
75
- return cls._value_map_[value]
76
- except (KeyError, TypeError):
77
- raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
78
-
79
- def __iter__(cls) -> Generator[Enum, None, None]:
80
- yield from cls._member_map_.values()
81
-
82
- def __reversed__(cls) -> Generator[Enum, None, None]:
83
- yield from reversed(cls._member_map_.values())
84
-
85
- def __getitem__(cls, key: str) -> Enum:
86
- return cls._member_map_[key]
87
-
88
- @property
89
- def __members__(cls) -> MappingProxyType[str, Enum]:
90
- return MappingProxyType(cls._member_map_)
91
-
92
- def __repr__(cls) -> str:
93
- return f"<enum {cls.__name__!r}>"
94
-
95
- def __len__(cls) -> int:
96
- return len(cls._member_map_)
97
-
98
- def __setattr__(cls, name: str, value: Any) -> Never:
99
- raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
100
-
101
- def __delattr__(cls, name: str) -> Never:
102
- raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
103
-
104
- def __contains__(cls, member: object) -> bool:
105
- return isinstance(member, cls) and member.name in cls._member_map_
106
-
107
-
108
- class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
109
- """
110
- The base class for protobuf enumerations, all generated enumerations will
111
- inherit from this. Emulates `enum.IntEnum`.
112
- """
113
-
114
- name: Optional[str]
115
- value: int
116
-
117
- if not TYPE_CHECKING:
118
-
119
- def __new__(cls, *, name: Optional[str], value: int) -> Self:
120
- self = super().__new__(cls, value)
121
- super().__setattr__(self, "name", name)
122
- super().__setattr__(self, "value", value)
123
- return self
124
-
125
- def __str__(self) -> str:
126
- return self.name or "None"
127
-
128
- def __repr__(self) -> str:
129
- return f"{self.__class__.__name__}.{self.name}"
130
-
131
- def __setattr__(self, key: str, value: Any) -> Never:
132
- raise AttributeError(f"{self.__class__.__name__} Cannot reassign a member's attributes.")
133
-
134
- def __delattr__(self, item: Any) -> Never:
135
- raise AttributeError(f"{self.__class__.__name__} Cannot delete a member's attributes.")
136
-
137
- def __copy__(self) -> Self:
138
- return self
139
-
140
- def __deepcopy__(self, memo: Any) -> Self:
141
- return self
142
-
143
- @classmethod
144
- def try_value(cls, value: int = 0) -> Self:
145
- """Return the value which corresponds to the value.
146
-
147
- Parameters
148
- -----------
149
- value: :class:`int`
150
- The value of the enum member to get.
151
-
152
- Returns
153
- -------
154
- :class:`Enum`
155
- The corresponding member or a new instance of the enum if
156
- ``value`` isn't actually a member.
157
- """
158
- try:
159
- return cls._value_map_[value]
160
- except (KeyError, TypeError):
161
- return cls.__new__(cls, name=None, value=value)
162
-
163
- @classmethod
164
- def from_string(cls, name: str) -> Self:
165
- """Return the value which corresponds to the string name.
166
-
167
- Parameters
168
- -----------
169
- name: :class:`str`
170
- The name of the enum member to get.
171
-
172
- Raises
173
- -------
174
- :exc:`ValueError`
175
- The member was not found in the Enum.
176
- """
177
- try:
178
- return cls._member_map_[name]
179
- except KeyError as e:
180
- raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
File without changes
@@ -1,172 +0,0 @@
1
- import asyncio
2
- from abc import ABC
3
- from typing import (
4
- TYPE_CHECKING,
5
- AsyncIterable,
6
- AsyncIterator,
7
- Collection,
8
- Iterable,
9
- Mapping,
10
- Optional,
11
- Tuple,
12
- Type,
13
- Union,
14
- )
15
-
16
- import grpclib.const
17
-
18
- if TYPE_CHECKING:
19
- from grpclib.client import Channel
20
- from grpclib.metadata import Deadline
21
-
22
- from .._types import (
23
- IProtoMessage,
24
- T,
25
- )
26
-
27
-
28
- Value = Union[str, bytes]
29
- MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
30
- MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
31
-
32
-
33
- class ServiceStub(ABC):
34
- """
35
- Base class for async gRPC clients.
36
- """
37
-
38
- def __init__(
39
- self,
40
- channel: "Channel",
41
- *,
42
- timeout: Optional[float] = None,
43
- deadline: Optional["Deadline"] = None,
44
- metadata: Optional[MetadataLike] = None,
45
- ) -> None:
46
- self.channel = channel
47
- self.timeout = timeout
48
- self.deadline = deadline
49
- self.metadata = metadata
50
-
51
- def __resolve_request_kwargs(
52
- self,
53
- timeout: Optional[float],
54
- deadline: Optional["Deadline"],
55
- metadata: Optional[MetadataLike],
56
- ):
57
- return {
58
- "timeout": self.timeout if timeout is None else timeout,
59
- "deadline": self.deadline if deadline is None else deadline,
60
- "metadata": self.metadata if metadata is None else metadata,
61
- }
62
-
63
- async def _unary_unary(
64
- self,
65
- route: str,
66
- request: "IProtoMessage",
67
- response_type: Type["T"],
68
- *,
69
- timeout: Optional[float] = None,
70
- deadline: Optional["Deadline"] = None,
71
- metadata: Optional[MetadataLike] = None,
72
- ) -> "T":
73
- """Make a unary request and return the response."""
74
- async with self.channel.request(
75
- route,
76
- grpclib.const.Cardinality.UNARY_UNARY,
77
- type(request),
78
- response_type,
79
- **self.__resolve_request_kwargs(timeout, deadline, metadata),
80
- ) as stream:
81
- await stream.send_message(request, end=True)
82
- response = await stream.recv_message()
83
- assert response is not None
84
- return response
85
-
86
- async def _unary_stream(
87
- self,
88
- route: str,
89
- request: "IProtoMessage",
90
- response_type: Type["T"],
91
- *,
92
- timeout: Optional[float] = None,
93
- deadline: Optional["Deadline"] = None,
94
- metadata: Optional[MetadataLike] = None,
95
- ) -> AsyncIterator["T"]:
96
- """Make a unary request and return the stream response iterator."""
97
- async with self.channel.request(
98
- route,
99
- grpclib.const.Cardinality.UNARY_STREAM,
100
- type(request),
101
- response_type,
102
- **self.__resolve_request_kwargs(timeout, deadline, metadata),
103
- ) as stream:
104
- await stream.send_message(request, end=True)
105
- async for message in stream:
106
- yield message
107
-
108
- async def _stream_unary(
109
- self,
110
- route: str,
111
- request_iterator: MessageSource,
112
- request_type: Type["IProtoMessage"],
113
- response_type: Type["T"],
114
- *,
115
- timeout: Optional[float] = None,
116
- deadline: Optional["Deadline"] = None,
117
- metadata: Optional[MetadataLike] = None,
118
- ) -> "T":
119
- """Make a stream request and return the response."""
120
- async with self.channel.request(
121
- route,
122
- grpclib.const.Cardinality.STREAM_UNARY,
123
- request_type,
124
- response_type,
125
- **self.__resolve_request_kwargs(timeout, deadline, metadata),
126
- ) as stream:
127
- await stream.send_request()
128
- await self._send_messages(stream, request_iterator)
129
- response = await stream.recv_message()
130
- assert response is not None
131
- return response
132
-
133
- async def _stream_stream(
134
- self,
135
- route: str,
136
- request_iterator: MessageSource,
137
- request_type: Type["IProtoMessage"],
138
- response_type: Type["T"],
139
- *,
140
- timeout: Optional[float] = None,
141
- deadline: Optional["Deadline"] = None,
142
- metadata: Optional[MetadataLike] = None,
143
- ) -> AsyncIterator["T"]:
144
- """
145
- Make a stream request and return an AsyncIterator to iterate over response
146
- messages.
147
- """
148
- async with self.channel.request(
149
- route,
150
- grpclib.const.Cardinality.STREAM_STREAM,
151
- request_type,
152
- response_type,
153
- **self.__resolve_request_kwargs(timeout, deadline, metadata),
154
- ) as stream:
155
- await stream.send_request()
156
- sending_task = asyncio.ensure_future(self._send_messages(stream, request_iterator))
157
- try:
158
- async for response in stream:
159
- yield response
160
- except:
161
- sending_task.cancel()
162
- raise
163
-
164
- @staticmethod
165
- async def _send_messages(stream, messages: MessageSource):
166
- if isinstance(messages, AsyncIterable):
167
- async for message in messages:
168
- await stream.send_message(message)
169
- else:
170
- for message in messages:
171
- await stream.send_message(message)
172
- await stream.end()
@@ -1,32 +0,0 @@
1
- from abc import ABC
2
- from collections.abc import AsyncIterable
3
- from typing import (
4
- Any,
5
- Callable,
6
- )
7
-
8
- import grpclib
9
- import grpclib.server
10
-
11
-
12
- class ServiceBase(ABC):
13
- """
14
- Base class for async gRPC servers.
15
- """
16
-
17
- async def _call_rpc_handler_server_stream(
18
- self,
19
- handler: Callable,
20
- stream: grpclib.server.Stream,
21
- request: Any,
22
- ) -> None:
23
- response_iter = handler(request)
24
- # check if response is actually an AsyncIterator
25
- # this might be false if the method just returns without
26
- # yielding at least once
27
- # in that case, we just interpret it as an empty iterator
28
- if isinstance(response_iter, AsyncIterable):
29
- async for response_message in response_iter:
30
- await stream.send_message(response_message)
31
- else:
32
- response_iter.close()
File without changes
@@ -1,190 +0,0 @@
1
- import asyncio
2
- from typing import (
3
- AsyncIterable,
4
- AsyncIterator,
5
- Iterable,
6
- Optional,
7
- TypeVar,
8
- Union,
9
- )
10
-
11
- T = TypeVar("T")
12
-
13
-
14
- class ChannelClosed(Exception):
15
- """
16
- An exception raised on an attempt to send through a closed channel
17
- """
18
-
19
-
20
- class ChannelDone(Exception):
21
- """
22
- An exception raised on an attempt to send receive from a channel that is both closed
23
- and empty.
24
- """
25
-
26
-
27
- class AsyncChannel(AsyncIterable[T]):
28
- """
29
- A buffered async channel for sending items between coroutines with FIFO ordering.
30
-
31
- This makes decoupled bidirectional steaming gRPC requests easy if used like:
32
-
33
- .. code-block:: python
34
- client = GeneratedStub(grpclib_chan)
35
- request_channel = await AsyncChannel()
36
- # We can start be sending all the requests we already have
37
- await request_channel.send_from([RequestObject(...), RequestObject(...)])
38
- async for response in client.rpc_call(request_channel):
39
- # The response iterator will remain active until the connection is closed
40
- ...
41
- # More items can be sent at any time
42
- await request_channel.send(RequestObject(...))
43
- ...
44
- # The channel must be closed to complete the gRPC connection
45
- request_channel.close()
46
-
47
- Items can be sent through the channel by either:
48
- - providing an iterable to the send_from method
49
- - passing them to the send method one at a time
50
-
51
- Items can be received from the channel by either:
52
- - iterating over the channel with a for loop to get all items
53
- - calling the receive method to get one item at a time
54
-
55
- If the channel is empty then receivers will wait until either an item appears or the
56
- channel is closed.
57
-
58
- Once the channel is closed then subsequent attempt to send through the channel will
59
- fail with a ChannelClosed exception.
60
-
61
- When th channel is closed and empty then it is done, and further attempts to receive
62
- from it will fail with a ChannelDone exception
63
-
64
- If multiple coroutines receive from the channel concurrently, each item sent will be
65
- received by only one of the receivers.
66
-
67
- :param source:
68
- An optional iterable will items that should be sent through the channel
69
- immediately.
70
- :param buffer_limit:
71
- Limit the number of items that can be buffered in the channel, A value less than
72
- 1 implies no limit. If the channel is full then attempts to send more items will
73
- result in the sender waiting until an item is received from the channel.
74
- :param close:
75
- If set to True then the channel will automatically close after exhausting source
76
- or immediately if no source is provided.
77
- """
78
-
79
- def __init__(self, *, buffer_limit: int = 0, close: bool = False):
80
- self._queue: asyncio.Queue[T] = asyncio.Queue(buffer_limit)
81
- self._closed = False
82
- self._waiting_receivers: int = 0
83
- # Track whether flush has been invoked so it can only happen once
84
- self._flushed = False
85
-
86
- def __aiter__(self) -> AsyncIterator[T]:
87
- return self
88
-
89
- async def __anext__(self) -> T:
90
- if self.done():
91
- raise StopAsyncIteration
92
- self._waiting_receivers += 1
93
- try:
94
- result = await self._queue.get()
95
- if result is self.__flush:
96
- raise StopAsyncIteration
97
- return result
98
- finally:
99
- self._waiting_receivers -= 1
100
- self._queue.task_done()
101
-
102
- def closed(self) -> bool:
103
- """
104
- Returns True if this channel is closed and no-longer accepting new items
105
- """
106
- return self._closed
107
-
108
- def done(self) -> bool:
109
- """
110
- Check if this channel is done.
111
-
112
- :return: True if this channel is closed and and has been drained of items in
113
- which case any further attempts to receive an item from this channel will raise
114
- a ChannelDone exception.
115
- """
116
- # After close the channel is not yet done until there is at least one waiting
117
- # receiver per enqueued item.
118
- return self._closed and self._queue.qsize() <= self._waiting_receivers
119
-
120
- async def send_from(self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False) -> "AsyncChannel[T]":
121
- """
122
- Iterates the given [Async]Iterable and sends all the resulting items.
123
- If close is set to True then subsequent send calls will be rejected with a
124
- ChannelClosed exception.
125
- :param source: an iterable of items to send
126
- :param close:
127
- if True then the channel will be closed after the source has been exhausted
128
-
129
- """
130
- if self._closed:
131
- raise ChannelClosed("Cannot send through a closed channel")
132
- if isinstance(source, AsyncIterable):
133
- async for item in source:
134
- await self._queue.put(item)
135
- else:
136
- for item in source:
137
- await self._queue.put(item)
138
- if close:
139
- # Complete the closing process
140
- self.close()
141
- return self
142
-
143
- async def send(self, item: T) -> "AsyncChannel[T]":
144
- """
145
- Send a single item over this channel.
146
- :param item: The item to send
147
- """
148
- if self._closed:
149
- raise ChannelClosed("Cannot send through a closed channel")
150
- await self._queue.put(item)
151
- return self
152
-
153
- async def receive(self) -> Optional[T]:
154
- """
155
- Returns the next item from this channel when it becomes available,
156
- or None if the channel is closed before another item is sent.
157
- :return: An item from the channel
158
- """
159
- if self.done():
160
- raise ChannelDone("Cannot receive from a closed channel")
161
- self._waiting_receivers += 1
162
- try:
163
- result = await self._queue.get()
164
- if result is self.__flush:
165
- return None
166
- return result
167
- finally:
168
- self._waiting_receivers -= 1
169
- self._queue.task_done()
170
-
171
- def close(self):
172
- """
173
- Close this channel to new items
174
- """
175
- self._closed = True
176
- asyncio.ensure_future(self._flush_queue())
177
-
178
- async def _flush_queue(self):
179
- """
180
- To be called after the channel is closed. Pushes a number of self.__flush
181
- objects to the queue to ensure no waiting consumers get deadlocked.
182
- """
183
- if not self._flushed:
184
- self._flushed = True
185
- deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize())
186
- for _ in range(deadlocked_receivers):
187
- await self._queue.put(self.__flush)
188
-
189
- # A special signal object for flushing the queue when the channel is closed
190
- __flush = object()
@@ -1,3 +0,0 @@
1
- [console_scripts]
2
- protoc-gen-python_betterproto=betterproto2_compiler.plugin:main
3
-