betterproto2-compiler 0.0.3__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,11 +3,6 @@ from __future__ import annotations
 import os
 from typing import (
     TYPE_CHECKING,
-    Dict,
-    List,
-    Set,
-    Tuple,
-    Type,
 )

 from ..casing import safe_snake_case
@@ -18,7 +13,7 @@ if TYPE_CHECKING:
     from ..plugin.models import PluginRequestCompiler
     from ..plugin.typing_compiler import TypingCompiler

-WRAPPER_TYPES: Dict[str, Type] = {
+WRAPPER_TYPES: dict[str, type] = {
     ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
     ".google.protobuf.FloatValue": google_protobuf.FloatValue,
     ".google.protobuf.Int32Value": google_protobuf.Int32Value,
@@ -31,7 +26,7 @@ WRAPPER_TYPES: Dict[str, Type] = {
 }


-def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler") -> Tuple[str, str]:
+def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) -> tuple[str, str]:
     """
     Split full source type name into package and type name.
     E.g. 'root.package.Message' -> ('root.package', 'Message')
@@ -77,7 +72,7 @@ def get_type_reference(
     imports: set,
     source_type: str,
     typing_compiler: TypingCompiler,
-    request: "PluginRequestCompiler",
+    request: PluginRequestCompiler,
     unwrap: bool = True,
     pydantic: bool = False,
 ) -> str:
@@ -98,8 +93,8 @@ def get_type_reference(

     source_package, source_type = parse_source_type_name(source_type, request)

-    current_package: List[str] = package.split(".") if package else []
-    py_package: List[str] = source_package.split(".") if source_package else []
+    current_package: list[str] = package.split(".") if package else []
+    py_package: list[str] = source_package.split(".") if source_package else []
     py_type: str = pythonize_class_name(source_type)

     compiling_google_protobuf = current_package == ["google", "protobuf"]
@@ -122,7 +117,7 @@ def get_type_reference(
     return reference_cousin(current_package, imports, py_package, py_type)


-def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str:
+def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> str:
     """
     Returns a reference to a python type located in the root, i.e. sys.path.
     """
@@ -139,7 +134,7 @@ def reference_sibling(py_type: str) -> str:
     return f"{py_type}"


-def reference_descendent(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str:
+def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
     """
     Returns a reference to a python type in a package that is a descendent of the
     current package, and adds the required import that is aliased to avoid name
@@ -157,7 +152,7 @@ def reference_descendent(current_package: List[str], imports: Set[str], py_packa
     return f"{string_import}.{py_type}"


-def reference_ancestor(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str:
+def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
     """
     Returns a reference to a python type in a package which is an ancestor to the
     current package, and adds the required import that is aliased (if possible) to avoid
@@ -178,7 +173,7 @@ def reference_ancestor(current_package: List[str], imports: Set[str], py_package
     return string_alias


-def reference_cousin(current_package: List[str], imports: Set[str], py_package: List[str], py_type: str) -> str:
+def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
     """
     Returns a reference to a python type in a package that is not descendent, ancestor
     or sibling, and adds the required import that is aliased to avoid name conflicts.
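Note on the hunks above: 0.1.0 drops the deprecated `typing.Dict`/`List`/`Set`/`Tuple`/`Type` aliases in favor of the built-in generics of PEP 585 and the `X | Y` unions of PEP 604, so most `from typing import ...` lines disappear. A minimal sketch of the pattern, with hypothetical names not taken from the package:

    # Built-in generics (PEP 585) run natively on Python 3.9+,
    # "X | Y" unions (PEP 604) on Python 3.10+.
    WRAPPER_TYPES: dict[str, type] = {}  # was: Dict[str, Type]
    maybe_package: str | None = None     # was: Optional[str]

    def parse(name: str) -> tuple[str, str]:  # was: -> Tuple[str, str]
        package, _, type_name = name.rpartition(".")
        return package, type_name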
@@ -2401,13 +2401,13 @@ class Value(betterproto2_compiler.Message):
     )
     """Represents a null value."""

-    number_value: Optional[float] = betterproto2_compiler.double_field(2, optional=True, group="kind")
+    number_value: float | None = betterproto2_compiler.double_field(2, optional=True, group="kind")
     """Represents a double value."""

-    string_value: Optional[str] = betterproto2_compiler.string_field(3, optional=True, group="kind")
+    string_value: str | None = betterproto2_compiler.string_field(3, optional=True, group="kind")
     """Represents a string value."""

-    bool_value: Optional[bool] = betterproto2_compiler.bool_field(4, optional=True, group="kind")
+    bool_value: bool | None = betterproto2_compiler.bool_field(4, optional=True, group="kind")
     """Represents a boolean value."""

     struct_value: Optional["Struct"] = betterproto2_compiler.message_field(5, optional=True, group="kind")
@@ -78,7 +78,6 @@ from typing import (
     Dict,
     List,
     Mapping,
-    Optional,
 )

 import betterproto2
@@ -1022,7 +1021,7 @@ class FieldDescriptorProto(betterproto2.Message):
     TODO(kenton): Base-64 encode?
     """

-    oneof_index: Optional[int] = betterproto2.int32_field(9, optional=True)
+    oneof_index: int | None = betterproto2.int32_field(9, optional=True)
     """
     If set, gives the index of a oneof in the containing type's oneof_decl
     list. This field is a member of that oneof.
@@ -1,6 +1,7 @@
 import os.path
 import subprocess
 import sys
+from importlib import metadata

 from .module_validation import ModuleValidator

@@ -14,7 +15,7 @@ except ImportError as err:
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
-        "\033[0m"
+        "\033[0m",
     )
     raise SystemExit(1)

@@ -24,6 +25,8 @@ from .models import OutputTemplate
 def outputfile_compiler(output_file: OutputTemplate) -> str:
     templates_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "templates"))

+    version = metadata.version("betterproto2_compiler")
+
     env = jinja2.Environment(
         trim_blocks=True,
         lstrip_blocks=True,
@@ -35,7 +38,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
     header_template = env.get_template("header.py.j2")

     code = body_template.render(output_file=output_file)
-    code = header_template.render(output_file=output_file) + "\n" + code
+    code = header_template.render(output_file=output_file, version=version) + "\n" + code

     # Sort imports, delete unused ones
     code = subprocess.check_output(
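Note: with the two hunks above, `outputfile_compiler` now stamps each generated file with the version of the compiler that produced it, read at run time from the installed distribution's metadata. A minimal sketch of the same mechanism, with a hypothetical one-line template standing in for header.py.j2:

    from importlib import metadata

    import jinja2

    # importlib.metadata reads the dist-info of the *installed* package, so the
    # rendered header always reflects the compiler version actually running.
    version = metadata.version("betterproto2_compiler")  # e.g. "0.1.0"

    template = jinja2.Environment().from_string(
        'betterproto2.check_compiler_version("{{ version }}")'
    )
    print(template.render(version=version))
    # betterproto2.check_compiler_version("0.1.0")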
@@ -31,18 +31,12 @@ reference to `A` to `B`'s `fields` attribute.

 import builtins
 import re
+from collections.abc import Iterable, Iterator
 from dataclasses import (
     dataclass,
     field,
 )
 from typing import (
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Set,
-    Type,
     Union,
 )

@@ -59,6 +53,7 @@ from betterproto2_compiler.lib.google.protobuf import (
     FieldDescriptorProto,
     FieldDescriptorProtoLabel,
     FieldDescriptorProtoType,
+    FieldDescriptorProtoType as FieldType,
     FileDescriptorProto,
     MethodDescriptorProto,
 )
@@ -146,7 +141,7 @@ PROTO_PACKED_TYPES = (

 def get_comment(
     proto_file: "FileDescriptorProto",
-    path: List[int],
+    path: list[int],
 ) -> str:
     for sci_loc in proto_file.source_code_info.location:
         if list(sci_loc.path) == path:
@@ -182,10 +177,10 @@ class ProtoContentBase:

     source_file: FileDescriptorProto
     typing_compiler: TypingCompiler
-    path: List[int]
+    path: list[int]
     parent: Union["betterproto2.Message", "OutputTemplate"]

-    __dataclass_fields__: Dict[str, object]
+    __dataclass_fields__: dict[str, object]

     def __post_init__(self) -> None:
         """Checks that no fake default fields were left as placeholders."""
@@ -225,10 +220,10 @@ class ProtoContentBase:
 @dataclass
 class PluginRequestCompiler:
     plugin_request_obj: CodeGeneratorRequest
-    output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
+    output_packages: dict[str, "OutputTemplate"] = field(default_factory=dict)

     @property
-    def all_messages(self) -> List["MessageCompiler"]:
+    def all_messages(self) -> list["MessageCompiler"]:
         """All of the messages in this request.

         Returns
@@ -250,11 +245,11 @@ class OutputTemplate:

     parent_request: PluginRequestCompiler
     package_proto_obj: FileDescriptorProto
-    input_files: List[str] = field(default_factory=list)
-    imports_end: Set[str] = field(default_factory=set)
-    messages: Dict[str, "MessageCompiler"] = field(default_factory=dict)
-    enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
-    services: Dict[str, "ServiceCompiler"] = field(default_factory=dict)
+    input_files: list[str] = field(default_factory=list)
+    imports_end: set[str] = field(default_factory=set)
+    messages: dict[str, "MessageCompiler"] = field(default_factory=dict)
+    enums: dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
+    services: dict[str, "ServiceCompiler"] = field(default_factory=dict)
     pydantic_dataclasses: bool = False
     output: bool = True
     typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
@@ -290,9 +285,9 @@ class MessageCompiler(ProtoContentBase):
     typing_compiler: TypingCompiler
     parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
     proto_obj: DescriptorProto = PLACEHOLDER
-    path: List[int] = PLACEHOLDER
-    fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list)
-    builtins_types: Set[str] = field(default_factory=set)
+    path: list[int] = PLACEHOLDER
+    fields: list[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list)
+    builtins_types: set[str] = field(default_factory=set)

     def __post_init__(self) -> None:
         # Add message to output file
@@ -328,11 +323,9 @@ class MessageCompiler(ProtoContentBase):
     @property
     def has_message_field(self) -> bool:
         return any(
-            (
-                field.proto_obj.type in PROTO_MESSAGE_TYPES
-                for field in self.fields
-                if isinstance(field.proto_obj, FieldDescriptorProto)
-            )
+            field.proto_obj.type in PROTO_MESSAGE_TYPES
+            for field in self.fields
+            if isinstance(field.proto_obj, FieldDescriptorProto)
         )


@@ -347,7 +340,7 @@ def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProt
     map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
     if message_type == map_entry:
         for nested in parent_message.nested_type:  # parent message
-            if nested.name.replace("_", "").lower() == map_entry and nested.options.map_entry:
+            if nested.name.replace("_", "").lower() == map_entry and nested.options and nested.options.map_entry:
                 return True
     return False

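Note: the added `nested.options and ...` guard matters because `options` is an optional submessage on `DescriptorProto` and can be `None` when a nested type carries no options; `nested.options.map_entry` alone would then raise `AttributeError`. A contrived sketch of the failure mode, using hypothetical stand-in classes:

    class MessageOptions:  # stand-in for the descriptor options message
        map_entry: bool = True

    class Nested:  # stand-in for DescriptorProto
        options: MessageOptions | None = None  # options may simply be unset

    nested = Nested()
    # 0.0.3 behaviour: nested.options.map_entry -> AttributeError on None
    # 0.1.0 behaviour: short-circuits safely to False
    print(bool(nested.options and nested.options.map_entry))  # False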
@@ -374,8 +367,8 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
 class FieldCompiler(ProtoContentBase):
     source_file: FileDescriptorProto
     typing_compiler: TypingCompiler
-    path: List[int] = PLACEHOLDER
-    builtins_types: Set[str] = field(default_factory=set)
+    path: list[int] = PLACEHOLDER
+    builtins_types: set[str] = field(default_factory=set)

     parent: MessageCompiler = PLACEHOLDER
     proto_obj: FieldDescriptorProto = PLACEHOLDER
@@ -390,13 +383,16 @@ class FieldCompiler(ProtoContentBase):
         """Construct string representation of this field as a field."""
         name = f"{self.py_name}"
         field_args = ", ".join(([""] + self.betterproto_field_args) if self.betterproto_field_args else [])
-        betterproto_field_type = f"betterproto2.{self.field_type}_field({self.proto_obj.number}{field_args})"
+
+        betterproto_field_type = (
+            f"betterproto2.field({self.proto_obj.number}, betterproto2.{str(self.field_type)}{field_args})"
+        )
         if self.py_name in dir(builtins):
             self.parent.builtins_types.add(self.py_name)
         return f'{name}: "{self.annotation}" = {betterproto_field_type}'

     @property
-    def betterproto_field_args(self) -> List[str]:
+    def betterproto_field_args(self) -> list[str]:
         args = []
         if self.field_wraps:
             args.append(f"wraps={self.field_wraps}")
@@ -404,9 +400,9 @@ class FieldCompiler(ProtoContentBase):
             args.append("optional=True")
         if self.repeated:
             args.append("repeated=True")
-        if self.field_type == "enum":
+        if self.field_type == FieldType.TYPE_ENUM:
             t = self.py_type
-            args.append(f"enum_default_value=lambda: {t}.try_value(0)")
+            args.append(f"default_factory=lambda: {t}.try_value(0)")
         return args

     @property
@@ -416,7 +412,7 @@ class FieldCompiler(ProtoContentBase):
         )

     @property
-    def field_wraps(self) -> Optional[str]:
+    def field_wraps(self) -> str | None:
         """Returns betterproto wrapped field type or None."""
         match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name)
         if match_wrapper:
@@ -428,17 +424,19 @@ class FieldCompiler(ProtoContentBase):
     @property
     def repeated(self) -> bool:
         return self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED and not is_map(
-            self.proto_obj, self.parent
+            self.proto_obj,
+            self.parent,
         )

     @property
     def optional(self) -> bool:
-        return self.proto_obj.proto3_optional or (self.field_type == "message" and not self.repeated)
+        # TODO not for maps
+        return self.proto_obj.proto3_optional or (self.field_type == FieldType.TYPE_MESSAGE and not self.repeated)

     @property
-    def field_type(self) -> str:
-        """String representation of proto field type."""
-        return FieldDescriptorProtoType(self.proto_obj.type).name.lower().replace("type_", "")
+    def field_type(self) -> FieldType:
+        # TODO it should be possible to remove constructor
+        return FieldType(self.proto_obj.type)

     @property
     def packed(self) -> bool:
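Note: `field_type` now returns the `FieldDescriptorProtoType` member itself (imported as `FieldType`) rather than a lowered string, so call sites compare enum constants instead of matching strings. A small self-contained sketch using the standard descriptor type numbers:

    from enum import IntEnum

    class FieldType(IntEnum):  # stand-in; values are from descriptor.proto
        TYPE_MESSAGE = 11
        TYPE_ENUM = 14

    wire_type = 14  # as read from proto_obj.type

    # 0.0.3: FieldType(wire_type).name.lower().replace("type_", "") == "enum"
    # 0.1.0: symbolic comparison, visible to the type checker
    assert FieldType(wire_type) is FieldType.TYPE_ENUM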
@@ -500,7 +498,7 @@ class OneOfFieldCompiler(FieldCompiler):
         return True

     @property
-    def betterproto_field_args(self) -> List[str]:
+    def betterproto_field_args(self) -> list[str]:
         args = super().betterproto_field_args
         group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
         args.append(f'group="{group}"')
@@ -509,8 +507,8 @@ class OneOfFieldCompiler(FieldCompiler):

 @dataclass
 class MapEntryCompiler(FieldCompiler):
-    py_k_type: Optional[Type] = None
-    py_v_type: Optional[Type] = None
+    py_k_type: type | None = None
+    py_v_type: type | None = None
     proto_k_type: str = ""
     proto_v_type: str = ""

@@ -547,13 +545,17 @@ class MapEntryCompiler(FieldCompiler):

         raise ValueError("can't find enum")

-    @property
-    def betterproto_field_args(self) -> List[str]:
-        return [f"betterproto2.{self.proto_k_type}", f"betterproto2.{self.proto_v_type}"]
-
-    @property
-    def field_type(self) -> str:
-        return "map"
+    def get_field_string(self) -> str:
+        """Construct string representation of this field as a field."""
+        betterproto_field_type = (
+            f"betterproto2.field({self.proto_obj.number}, "
+            "betterproto2.TYPE_MAP, "
+            f"map_types=(betterproto2.{self.proto_k_type}, "
+            f"betterproto2.{self.proto_v_type}))"
+        )
+        if self.py_name in dir(builtins):
+            self.parent.builtins_types.add(self.py_name)
+        return f'{self.py_name}: "{self.annotation}" = {betterproto_field_type}'

     @property
     def annotation(self) -> str:
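Note: map fields no longer masquerade as a `"map"` field type; `MapEntryCompiler` overrides `get_field_string` wholesale and emits `betterproto2.TYPE_MAP` plus a `map_types` pair. For a hypothetical `map<string, int32> counts = 1;` the emitted declaration plausibly reads (illustrative only; the exact spelling depends on the betterproto2 runtime):

    # counts: "dict[str, int]" = betterproto2.field(
    #     1,
    #     betterproto2.TYPE_MAP,
    #     map_types=(betterproto2.TYPE_STRING, betterproto2.TYPE_INT32),
    # )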
@@ -569,7 +571,7 @@ class EnumDefinitionCompiler(MessageCompiler):
     """Representation of a proto Enum definition."""

     proto_obj: EnumDescriptorProto = PLACEHOLDER
-    entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
+    entries: list["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER

     @dataclass(unsafe_hash=True)
     class EnumEntry:
@@ -597,8 +599,8 @@ class ServiceCompiler(ProtoContentBase):
     source_file: FileDescriptorProto
     parent: OutputTemplate = PLACEHOLDER
     proto_obj: DescriptorProto = PLACEHOLDER
-    path: List[int] = PLACEHOLDER
-    methods: List["ServiceMethodCompiler"] = field(default_factory=list)
+    path: list[int] = PLACEHOLDER
+    methods: list["ServiceMethodCompiler"] = field(default_factory=list)

     def __post_init__(self) -> None:
         # Add service to output file
@@ -619,7 +621,7 @@ class ServiceMethodCompiler(ProtoContentBase):
     source_file: FileDescriptorProto
     parent: ServiceCompiler
     proto_obj: MethodDescriptorProto
-    path: List[int] = PLACEHOLDER
+    path: list[int] = PLACEHOLDER

     def __post_init__(self) -> None:
         # Add method to service
@@ -1,15 +1,10 @@
 import re
 from collections import defaultdict
+from collections.abc import Iterator
 from dataclasses import (
     dataclass,
     field,
 )
-from typing import (
-    Dict,
-    Iterator,
-    List,
-    Tuple,
-)


 @dataclass
@@ -17,7 +12,7 @@ class ModuleValidator:
     line_iterator: Iterator[str]
     line_number: int = field(init=False, default=0)

-    collisions: Dict[str, List[Tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list))
+    collisions: dict[str, list[tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list))

     def add_import(self, imp: str, number: int, full_line: str):
         """
@@ -1,12 +1,6 @@
 import pathlib
 import sys
-from typing import (
-    Generator,
-    List,
-    Set,
-    Tuple,
-    Union,
-)
+from collections.abc import Generator

 from betterproto2_compiler.lib.google.protobuf import (
     DescriptorProto,
@@ -45,13 +39,13 @@ from .typing_compiler import (

 def traverse(
     proto_file: FileDescriptorProto,
-) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]:
+) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
     # Todo: Keep information about nested hierarchy
     def _traverse(
-        path: List[int],
-        items: Union[List[EnumDescriptorProto], List[DescriptorProto]],
+        path: list[int],
+        items: list[EnumDescriptorProto] | list[DescriptorProto],
         prefix: str = "",
-    ) -> Generator[Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None]:
+    ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
         for i, item in enumerate(items):
             # Adjust the name since we flatten the hierarchy.
             # Todo: don't change the name, but include full name in returned tuple
@@ -82,7 +76,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
         if output_package_name not in request_data.output_packages:
             # Create a new output if there is no output for this package
             request_data.output_packages[output_package_name] = OutputTemplate(
-                parent_request=request_data, package_proto_obj=proto_file
+                parent_request=request_data,
+                package_proto_obj=proto_file,
             )
         # Add this input file to the output corresponding to this package
         request_data.output_packages[output_package_name].input_files.append(proto_file)
@@ -144,7 +139,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
         service.ready()

     # Generate output files
-    output_paths: Set[pathlib.Path] = set()
+    output_paths: set[pathlib.Path] = set()
     for output_package_name, output_package in request_data.output_packages.items():
         if not output_package.output:
             continue
@@ -158,7 +153,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
                 name=str(output_path),
                 # Render and then format the output file
                 content=outputfile_compiler(output_file=output_package),
-            )
+            ),
         )

     # Make each output directory a package with __init__ file
@@ -183,7 +178,7 @@ def _make_one_of_field_compiler(
     source_file: "FileDescriptorProto",
     parent: MessageCompiler,
     proto_obj: "FieldDescriptorProto",
-    path: List[int],
+    path: list[int],
 ) -> FieldCompiler:
     return OneOfFieldCompiler(
         source_file=source_file,
@@ -196,7 +191,7 @@ def _make_one_of_field_compiler(

 def read_protobuf_type(
     item: DescriptorProto,
-    path: List[int],
+    path: list[int],
     source_file: "FileDescriptorProto",
     output_package: OutputTemplate,
 ) -> None:
@@ -1,15 +1,11 @@
 import abc
+import builtins
 from collections import defaultdict
+from collections.abc import Iterator
 from dataclasses import (
     dataclass,
     field,
 )
-from typing import (
-    Dict,
-    Iterator,
-    Optional,
-    Set,
-)


 class TypingCompiler(metaclass=abc.ABCMeta):
@@ -42,7 +38,7 @@ class TypingCompiler(metaclass=abc.ABCMeta):
         raise NotImplementedError

     @abc.abstractmethod
-    def imports(self) -> Dict[str, Optional[Set[str]]]:
+    def imports(self) -> builtins.dict[str, set[str] | None]:
         """
         Returns either the direct import as a key with none as value, or a set of
         values to import from the key.
@@ -63,7 +59,7 @@ class TypingCompiler(metaclass=abc.ABCMeta):

 @dataclass
 class DirectImportTypingCompiler(TypingCompiler):
-    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
+    _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))

     def optional(self, type_: str) -> str:
         self._imports["typing"].add("Optional")
@@ -93,7 +89,7 @@ class DirectImportTypingCompiler(TypingCompiler):
         self._imports["typing"].add("AsyncIterator")
         return f"AsyncIterator[{type_}]"

-    def imports(self) -> Dict[str, Optional[Set[str]]]:
+    def imports(self) -> builtins.dict[str, set[str] | None]:
         return {k: v if v else None for k, v in self._imports.items()}


@@ -129,7 +125,7 @@ class TypingImportTypingCompiler(TypingCompiler):
         self._imported = True
         return f"typing.AsyncIterator[{type_}]"

-    def imports(self) -> Dict[str, Optional[Set[str]]]:
+    def imports(self) -> builtins.dict[str, set[str] | None]:
         if self._imported:
             return {"typing": None}
         return {}
@@ -137,7 +133,7 @@ class TypingImportTypingCompiler(TypingCompiler):

 @dataclass
 class NoTyping310TypingCompiler(TypingCompiler):
-    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
+    _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))

     def optional(self, type_: str) -> str:
         return f"{type_} | None"
@@ -163,5 +159,5 @@ class NoTyping310TypingCompiler(TypingCompiler):
         self._imports["collections.abc"].add("AsyncIterator")
         return f"AsyncIterator[{type_}]"

-    def imports(self) -> Dict[str, Optional[Set[str]]]:
+    def imports(self) -> builtins.dict[str, set[str] | None]:
         return {k: v if v else None for k, v in self._imports.items()}
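Note: the three typing compilers share one interface and differ only in how they spell annotations and which imports they record. A short usage sketch based on the methods visible in this diff:

    from betterproto2_compiler.plugin.typing_compiler import (
        DirectImportTypingCompiler,
        NoTyping310TypingCompiler,
    )

    direct = DirectImportTypingCompiler()
    print(direct.optional("int"))  # "Optional[int]"
    print(direct.imports())        # {"typing": {"Optional"}}

    modern = NoTyping310TypingCompiler()
    print(modern.optional("int"))  # "int | None" -- no import recorded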
@@ -48,3 +48,5 @@ if TYPE_CHECKING:
     import grpclib.server
     from betterproto2.grpc.grpclib_client import MetadataLike
     from grpclib.metadata import Deadline
+
+betterproto2.check_compiler_version("{{ version }}")
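Note: together with the compiler.py change above, this template line is how generated modules pin themselves to the compiler that emitted them; at generation time `{{ version }}` is replaced by the installed compiler version, so the tail of a module generated by this release presumably reads:

    # tail of a generated module (betterproto2 is imported earlier in the header)
    betterproto2.check_compiler_version("0.1.0")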
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: betterproto2_compiler
-Version: 0.0.3
+Version: 0.1.0
 Summary: Compiler for betterproto2
 Home-page: https://github.com/betterproto/python-betterproto2-compiler
 License: MIT
@@ -1,41 +1,34 @@
 betterproto2_compiler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-betterproto2_compiler/_types.py,sha256=nIsUxcId43N1Gu8EqdeuflR9iUZB1JWu4JTGQV9NeUI,294
 betterproto2_compiler/casing.py,sha256=bMdI4W0hfYh6kV-DQIqFEjSfGYEqUtPciAzP64z5HLQ,3587
 betterproto2_compiler/compile/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-betterproto2_compiler/compile/importing.py,sha256=UNBvayHZmvvwQMIEZ0Zzk0KsqIGqGSbPKhv_DUg2qV8,7442
+betterproto2_compiler/compile/importing.py,sha256=MN1pzFISG96wd6Djsym8q9yfsknBC5logvPrW6qnWyc,7388
 betterproto2_compiler/compile/naming.py,sha256=zf0VOmNojzyv33upOGelGxjZTEDE8JULEEED5_3inHg,562
-betterproto2_compiler/enum.py,sha256=LcILQf1BEjnszouUCtPwifJAR_8u2tf9U9hfwP4vXTc,5396
-betterproto2_compiler/grpc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-betterproto2_compiler/grpc/grpclib_client.py,sha256=6yKUqLFEfiMUPT85g81ajvVI4bvH4kBULZa7pIL18Y8,5275
-betterproto2_compiler/grpc/grpclib_server.py,sha256=Tv3NIGPrxdA48_THUl3jut_IekQgkecX8NPfjvF0kdg,872
-betterproto2_compiler/grpc/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-betterproto2_compiler/grpc/util/async_channel.py,sha256=4sfqoHtS_-qU1GFc0LatnV_dLsmVrGEm74WOJ9RpV1Y,6778
 betterproto2_compiler/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 betterproto2_compiler/lib/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 betterproto2_compiler/lib/google/protobuf/__init__.py,sha256=RvtT0CjNfAegDv42ITEvaIVrGpEh2fXvGNyA7bd99oo,60
 betterproto2_compiler/lib/google/protobuf/compiler/__init__.py,sha256=1EvLU05Ck-rwv2Y_jW4ylJlCgqd7VpSHeFgbBFWfAmg,69
 betterproto2_compiler/lib/pydantic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 betterproto2_compiler/lib/pydantic/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py,sha256=eAvwcNussnueselEJwX0KcmGbhqz4oTux7e1ugCV6KE,96981
+betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py,sha256=RgGj0p9-KkaO-CFGV_CLTXm1yhPf-Vm02cUkF7z39SI,96972
 betterproto2_compiler/lib/pydantic/google/protobuf/compiler/__init__.py,sha256=uI4F0DytRxbrwXqiRKOJD0nJFewlQE0Lj_3WJdiGoDU,9217
 betterproto2_compiler/lib/std/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 betterproto2_compiler/lib/std/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-betterproto2_compiler/lib/std/google/protobuf/__init__.py,sha256=s5Tkt3684xnq2VL0cHZXytN2ja4D73C6ZNf7APEYdcI,80454
+betterproto2_compiler/lib/std/google/protobuf/__init__.py,sha256=bfxwQzYXJ-ImFktGBN5_8HY9JgASJ8BbBg9fKeLTVQ0,80437
 betterproto2_compiler/lib/std/google/protobuf/compiler/__init__.py,sha256=z63o4vkjnDWB7yZpYnx-T0FZCbmSs5gBzfKrNP_qf8c,8646
 betterproto2_compiler/plugin/__init__.py,sha256=L3pW0b4CvkM5x53x_sYt1kYiSFPO0_vaeH6EQPq9FAM,43
 betterproto2_compiler/plugin/__main__.py,sha256=vBQ82334kX06ImDbFlPFgiBRiLIinwNk3z8Khs6hd74,31
-betterproto2_compiler/plugin/compiler.py,sha256=J_0WvuOVXuIINogNTOtU9Kyhzbu3NDrKh7ojbMjSjJk,2032
+betterproto2_compiler/plugin/compiler.py,sha256=jICLI4-5rAOkWQI1v5j7JqIvoao-ZM9szMuq0OBRteA,2138
 betterproto2_compiler/plugin/main.py,sha256=Q9PmcJqXuYYFe51l7AqHVzJrHqi2LWCUu80CZSQOOwk,1469
-betterproto2_compiler/plugin/models.py,sha256=Ljo08MGmZ68q7bFTI7532KrqyjjJhHMe15a-tvxAGyI,24288
-betterproto2_compiler/plugin/module_validation.py,sha256=vye8PjsZFs1Ikh0yNLQXuy12EdM0em0Bflgx7xYrMhk,4853
-betterproto2_compiler/plugin/parser.py,sha256=PFFlK7Di7BF7_tCzkIWUeaRqDfVKJQc9YjSVKeHVXWM,9651
+betterproto2_compiler/plugin/models.py,sha256=7l1P8-ijdatENy-ERO5twHeMpYxLG3YEqhYM9Slm5TY,24655
+betterproto2_compiler/plugin/module_validation.py,sha256=RdPFwdmkbD6NKADaHC5eaPix_pz-yGxHvYJj8Ev48fA,4822
+betterproto2_compiler/plugin/parser.py,sha256=i_HDW0O4ieIatbq1hckOj0kM9MCfdzuib0fumSJqIA4,9610
 betterproto2_compiler/plugin/plugin.bat,sha256=lfLT1WguAXqyerLLsRL6BfHA0RqUE6QG79v-1BYVSpI,48
-betterproto2_compiler/plugin/typing_compiler.py,sha256=gMrKsrA_xFoy33tbm4VjiktXyGAARFGHPL6iVmUiPLU,4866
+betterproto2_compiler/plugin/typing_compiler.py,sha256=IK6m4ggHXK7HL98Ed_WjvQ_yeWfIpf_fIBZ9SA8UcyM,4873
 betterproto2_compiler/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-betterproto2_compiler/templates/header.py.j2,sha256=H3R2v5MiCebp5p1zMKCDjsVlnOLklqrwO24SUtjfdN0,1410
+betterproto2_compiler/templates/header.py.j2,sha256=nxqsengMcM_IRqQYNVntPQ0gUFUPCy_1P1mcoLvbDos,1464
 betterproto2_compiler/templates/template.py.j2,sha256=icyiNdSTJRgyD20e_lTgTAvSjgnSFSn4t1L1-yZnkEM,8712
-betterproto2_compiler-0.0.3.dist-info/LICENSE.md,sha256=Pgl2pReU-2yw2miGeQ55UFlyzqAZ_EpYVyZ2nWjwRv4,1121
-betterproto2_compiler-0.0.3.dist-info/METADATA,sha256=-h0sesZ9kJlHnXU645Jmgbo-DA2BMVqfTZu7b8pCbiQ,1163
-betterproto2_compiler-0.0.3.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
-betterproto2_compiler-0.0.3.dist-info/entry_points.txt,sha256=DE80wLfBwKlvu82d9pAYzEo7Cp22WNqwU7WJZq6JAWk,83
-betterproto2_compiler-0.0.3.dist-info/RECORD,,
+betterproto2_compiler-0.1.0.dist-info/LICENSE.md,sha256=Pgl2pReU-2yw2miGeQ55UFlyzqAZ_EpYVyZ2nWjwRv4,1121
+betterproto2_compiler-0.1.0.dist-info/METADATA,sha256=sJiWJ2eOTCN0ARCo0VmNcMlCPgEFli0tfsnM2vYM-Gc,1163
+betterproto2_compiler-0.1.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+betterproto2_compiler-0.1.0.dist-info/entry_points.txt,sha256=re3Qg8lLljbVobeeKH2f1FVQZ114wfZkGv3zCZTD8Ok,84
+betterproto2_compiler-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,3 @@
+[console_scripts]
+protoc-gen-python_betterproto2=betterproto2_compiler.plugin:main
+
@@ -1,13 +0,0 @@
-from typing import (
-    TYPE_CHECKING,
-    TypeVar,
-)
-
-if TYPE_CHECKING:
-    from grpclib._typing import IProtoMessage
-
-    from . import Message
-
-# Bound type variable to allow methods to return `self` of subclasses
-T = TypeVar("T", bound="Message")
-ST = TypeVar("ST", bound="IProtoMessage")
@@ -1,180 +0,0 @@
-from __future__ import annotations
-
-from enum import (
-    EnumMeta,
-    IntEnum,
-)
-from types import MappingProxyType
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Optional,
-    Tuple,
-)
-
-if TYPE_CHECKING:
-    from collections.abc import (
-        Generator,
-        Mapping,
-    )
-
-    from typing_extensions import (
-        Never,
-        Self,
-    )
-
-
-def _is_descriptor(obj: object) -> bool:
-    return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
-
-
-class EnumType(EnumMeta if TYPE_CHECKING else type):
-    _value_map_: Mapping[int, Enum]
-    _member_map_: Mapping[str, Enum]
-
-    def __new__(mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]) -> Self:
-        value_map = {}
-        member_map = {}
-
-        new_mcs = type(
-            f"{name}Type",
-            tuple(
-                dict.fromkeys([base.__class__ for base in bases if base.__class__ is not type] + [EnumType, type])
-            ),  # reorder the bases so EnumType and type are last to avoid conflicts
-            {"_value_map_": value_map, "_member_map_": member_map},
-        )
-
-        members = {
-            name: value for name, value in namespace.items() if not _is_descriptor(value) and not name.startswith("__")
-        }
-
-        cls = type.__new__(
-            new_mcs,
-            name,
-            bases,
-            {key: value for key, value in namespace.items() if key not in members},
-        )
-        # this allows us to disallow member access from other members as
-        # members become proper class variables
-
-        for name, value in members.items():
-            member = value_map.get(value)
-            if member is None:
-                member = cls.__new__(cls, name=name, value=value)  # type: ignore
-                value_map[value] = member
-            member_map[name] = member
-            type.__setattr__(new_mcs, name, member)
-
-        return cls
-
-    if not TYPE_CHECKING:
-
-        def __call__(cls, value: int) -> Enum:
-            try:
-                return cls._value_map_[value]
-            except (KeyError, TypeError):
-                raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
-
-    def __iter__(cls) -> Generator[Enum, None, None]:
-        yield from cls._member_map_.values()
-
-    def __reversed__(cls) -> Generator[Enum, None, None]:
-        yield from reversed(cls._member_map_.values())
-
-    def __getitem__(cls, key: str) -> Enum:
-        return cls._member_map_[key]
-
-    @property
-    def __members__(cls) -> MappingProxyType[str, Enum]:
-        return MappingProxyType(cls._member_map_)
-
-    def __repr__(cls) -> str:
-        return f"<enum {cls.__name__!r}>"
-
-    def __len__(cls) -> int:
-        return len(cls._member_map_)
-
-    def __setattr__(cls, name: str, value: Any) -> Never:
-        raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
-
-    def __delattr__(cls, name: str) -> Never:
-        raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
-
-    def __contains__(cls, member: object) -> bool:
-        return isinstance(member, cls) and member.name in cls._member_map_
-
-
-class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
-    """
-    The base class for protobuf enumerations, all generated enumerations will
-    inherit from this. Emulates `enum.IntEnum`.
-    """
-
-    name: Optional[str]
-    value: int
-
-    if not TYPE_CHECKING:
-
-        def __new__(cls, *, name: Optional[str], value: int) -> Self:
-            self = super().__new__(cls, value)
-            super().__setattr__(self, "name", name)
-            super().__setattr__(self, "value", value)
-            return self
-
-    def __str__(self) -> str:
-        return self.name or "None"
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}.{self.name}"
-
-    def __setattr__(self, key: str, value: Any) -> Never:
-        raise AttributeError(f"{self.__class__.__name__} Cannot reassign a member's attributes.")
-
-    def __delattr__(self, item: Any) -> Never:
-        raise AttributeError(f"{self.__class__.__name__} Cannot delete a member's attributes.")
-
-    def __copy__(self) -> Self:
-        return self
-
-    def __deepcopy__(self, memo: Any) -> Self:
-        return self
-
-    @classmethod
-    def try_value(cls, value: int = 0) -> Self:
-        """Return the value which corresponds to the value.
-
-        Parameters
-        -----------
-        value: :class:`int`
-            The value of the enum member to get.
-
-        Returns
-        -------
-        :class:`Enum`
-            The corresponding member or a new instance of the enum if
-            ``value`` isn't actually a member.
-        """
-        try:
-            return cls._value_map_[value]
-        except (KeyError, TypeError):
-            return cls.__new__(cls, name=None, value=value)
-
-    @classmethod
-    def from_string(cls, name: str) -> Self:
-        """Return the value which corresponds to the string name.
-
-        Parameters
-        -----------
-        name: :class:`str`
-            The name of the enum member to get.
-
-        Raises
-        -------
-        :exc:`ValueError`
-            The member was not found in the Enum.
-        """
-        try:
-            return cls._member_map_[name]
-        except KeyError as e:
-            raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
@@ -1,172 +0,0 @@
-import asyncio
-from abc import ABC
-from typing import (
-    TYPE_CHECKING,
-    AsyncIterable,
-    AsyncIterator,
-    Collection,
-    Iterable,
-    Mapping,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-)
-
-import grpclib.const
-
-if TYPE_CHECKING:
-    from grpclib.client import Channel
-    from grpclib.metadata import Deadline
-
-    from .._types import (
-        IProtoMessage,
-        T,
-    )
-
-
-Value = Union[str, bytes]
-MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
-MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
-
-
-class ServiceStub(ABC):
-    """
-    Base class for async gRPC clients.
-    """
-
-    def __init__(
-        self,
-        channel: "Channel",
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> None:
-        self.channel = channel
-        self.timeout = timeout
-        self.deadline = deadline
-        self.metadata = metadata
-
-    def __resolve_request_kwargs(
-        self,
-        timeout: Optional[float],
-        deadline: Optional["Deadline"],
-        metadata: Optional[MetadataLike],
-    ):
-        return {
-            "timeout": self.timeout if timeout is None else timeout,
-            "deadline": self.deadline if deadline is None else deadline,
-            "metadata": self.metadata if metadata is None else metadata,
-        }
-
-    async def _unary_unary(
-        self,
-        route: str,
-        request: "IProtoMessage",
-        response_type: Type["T"],
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> "T":
-        """Make a unary request and return the response."""
-        async with self.channel.request(
-            route,
-            grpclib.const.Cardinality.UNARY_UNARY,
-            type(request),
-            response_type,
-            **self.__resolve_request_kwargs(timeout, deadline, metadata),
-        ) as stream:
-            await stream.send_message(request, end=True)
-            response = await stream.recv_message()
-            assert response is not None
-            return response
-
-    async def _unary_stream(
-        self,
-        route: str,
-        request: "IProtoMessage",
-        response_type: Type["T"],
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> AsyncIterator["T"]:
-        """Make a unary request and return the stream response iterator."""
-        async with self.channel.request(
-            route,
-            grpclib.const.Cardinality.UNARY_STREAM,
-            type(request),
-            response_type,
-            **self.__resolve_request_kwargs(timeout, deadline, metadata),
-        ) as stream:
-            await stream.send_message(request, end=True)
-            async for message in stream:
-                yield message
-
-    async def _stream_unary(
-        self,
-        route: str,
-        request_iterator: MessageSource,
-        request_type: Type["IProtoMessage"],
-        response_type: Type["T"],
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> "T":
-        """Make a stream request and return the response."""
-        async with self.channel.request(
-            route,
-            grpclib.const.Cardinality.STREAM_UNARY,
-            request_type,
-            response_type,
-            **self.__resolve_request_kwargs(timeout, deadline, metadata),
-        ) as stream:
-            await stream.send_request()
-            await self._send_messages(stream, request_iterator)
-            response = await stream.recv_message()
-            assert response is not None
-            return response
-
-    async def _stream_stream(
-        self,
-        route: str,
-        request_iterator: MessageSource,
-        request_type: Type["IProtoMessage"],
-        response_type: Type["T"],
-        *,
-        timeout: Optional[float] = None,
-        deadline: Optional["Deadline"] = None,
-        metadata: Optional[MetadataLike] = None,
-    ) -> AsyncIterator["T"]:
-        """
-        Make a stream request and return an AsyncIterator to iterate over response
-        messages.
-        """
-        async with self.channel.request(
-            route,
-            grpclib.const.Cardinality.STREAM_STREAM,
-            request_type,
-            response_type,
-            **self.__resolve_request_kwargs(timeout, deadline, metadata),
-        ) as stream:
-            await stream.send_request()
-            sending_task = asyncio.ensure_future(self._send_messages(stream, request_iterator))
-            try:
-                async for response in stream:
-                    yield response
-            except:
-                sending_task.cancel()
-                raise
-
-    @staticmethod
-    async def _send_messages(stream, messages: MessageSource):
-        if isinstance(messages, AsyncIterable):
-            async for message in messages:
-                await stream.send_message(message)
-        else:
-            for message in messages:
-                await stream.send_message(message)
-        await stream.end()
@@ -1,32 +0,0 @@
-from abc import ABC
-from collections.abc import AsyncIterable
-from typing import (
-    Any,
-    Callable,
-)
-
-import grpclib
-import grpclib.server
-
-
-class ServiceBase(ABC):
-    """
-    Base class for async gRPC servers.
-    """
-
-    async def _call_rpc_handler_server_stream(
-        self,
-        handler: Callable,
-        stream: grpclib.server.Stream,
-        request: Any,
-    ) -> None:
-        response_iter = handler(request)
-        # check if response is actually an AsyncIterator
-        # this might be false if the method just returns without
-        # yielding at least once
-        # in that case, we just interpret it as an empty iterator
-        if isinstance(response_iter, AsyncIterable):
-            async for response_message in response_iter:
-                await stream.send_message(response_message)
-        else:
-            response_iter.close()
@@ -1,190 +0,0 @@
-import asyncio
-from typing import (
-    AsyncIterable,
-    AsyncIterator,
-    Iterable,
-    Optional,
-    TypeVar,
-    Union,
-)
-
-T = TypeVar("T")
-
-
-class ChannelClosed(Exception):
-    """
-    An exception raised on an attempt to send through a closed channel
-    """
-
-
-class ChannelDone(Exception):
-    """
-    An exception raised on an attempt to send receive from a channel that is both closed
-    and empty.
-    """
-
-
-class AsyncChannel(AsyncIterable[T]):
-    """
-    A buffered async channel for sending items between coroutines with FIFO ordering.
-
-    This makes decoupled bidirectional steaming gRPC requests easy if used like:
-
-    .. code-block:: python
-        client = GeneratedStub(grpclib_chan)
-        request_channel = await AsyncChannel()
-        # We can start be sending all the requests we already have
-        await request_channel.send_from([RequestObject(...), RequestObject(...)])
-        async for response in client.rpc_call(request_channel):
-            # The response iterator will remain active until the connection is closed
-            ...
-            # More items can be sent at any time
-            await request_channel.send(RequestObject(...))
-            ...
-            # The channel must be closed to complete the gRPC connection
-            request_channel.close()
-
-    Items can be sent through the channel by either:
-    - providing an iterable to the send_from method
-    - passing them to the send method one at a time
-
-    Items can be received from the channel by either:
-    - iterating over the channel with a for loop to get all items
-    - calling the receive method to get one item at a time
-
-    If the channel is empty then receivers will wait until either an item appears or the
-    channel is closed.
-
-    Once the channel is closed then subsequent attempt to send through the channel will
-    fail with a ChannelClosed exception.
-
-    When th channel is closed and empty then it is done, and further attempts to receive
-    from it will fail with a ChannelDone exception
-
-    If multiple coroutines receive from the channel concurrently, each item sent will be
-    received by only one of the receivers.
-
-    :param source:
-        An optional iterable will items that should be sent through the channel
-        immediately.
-    :param buffer_limit:
-        Limit the number of items that can be buffered in the channel, A value less than
-        1 implies no limit. If the channel is full then attempts to send more items will
-        result in the sender waiting until an item is received from the channel.
-    :param close:
-        If set to True then the channel will automatically close after exhausting source
-        or immediately if no source is provided.
-    """
-
-    def __init__(self, *, buffer_limit: int = 0, close: bool = False):
-        self._queue: asyncio.Queue[T] = asyncio.Queue(buffer_limit)
-        self._closed = False
-        self._waiting_receivers: int = 0
-        # Track whether flush has been invoked so it can only happen once
-        self._flushed = False
-
-    def __aiter__(self) -> AsyncIterator[T]:
-        return self
-
-    async def __anext__(self) -> T:
-        if self.done():
-            raise StopAsyncIteration
-        self._waiting_receivers += 1
-        try:
-            result = await self._queue.get()
-            if result is self.__flush:
-                raise StopAsyncIteration
-            return result
-        finally:
-            self._waiting_receivers -= 1
-            self._queue.task_done()
-
-    def closed(self) -> bool:
-        """
-        Returns True if this channel is closed and no-longer accepting new items
-        """
-        return self._closed
-
-    def done(self) -> bool:
-        """
-        Check if this channel is done.
-
-        :return: True if this channel is closed and and has been drained of items in
-        which case any further attempts to receive an item from this channel will raise
-        a ChannelDone exception.
-        """
-        # After close the channel is not yet done until there is at least one waiting
-        # receiver per enqueued item.
-        return self._closed and self._queue.qsize() <= self._waiting_receivers
-
-    async def send_from(self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False) -> "AsyncChannel[T]":
-        """
-        Iterates the given [Async]Iterable and sends all the resulting items.
-        If close is set to True then subsequent send calls will be rejected with a
-        ChannelClosed exception.
-        :param source: an iterable of items to send
-        :param close:
-            if True then the channel will be closed after the source has been exhausted
-
-        """
-        if self._closed:
-            raise ChannelClosed("Cannot send through a closed channel")
-        if isinstance(source, AsyncIterable):
-            async for item in source:
-                await self._queue.put(item)
-        else:
-            for item in source:
-                await self._queue.put(item)
-        if close:
-            # Complete the closing process
-            self.close()
-        return self
-
-    async def send(self, item: T) -> "AsyncChannel[T]":
-        """
-        Send a single item over this channel.
-        :param item: The item to send
-        """
-        if self._closed:
-            raise ChannelClosed("Cannot send through a closed channel")
-        await self._queue.put(item)
-        return self
-
-    async def receive(self) -> Optional[T]:
-        """
-        Returns the next item from this channel when it becomes available,
-        or None if the channel is closed before another item is sent.
-        :return: An item from the channel
-        """
-        if self.done():
-            raise ChannelDone("Cannot receive from a closed channel")
-        self._waiting_receivers += 1
-        try:
-            result = await self._queue.get()
-            if result is self.__flush:
-                return None
-            return result
-        finally:
-            self._waiting_receivers -= 1
-            self._queue.task_done()
-
-    def close(self):
-        """
-        Close this channel to new items
-        """
-        self._closed = True
-        asyncio.ensure_future(self._flush_queue())
-
-    async def _flush_queue(self):
-        """
-        To be called after the channel is closed. Pushes a number of self.__flush
-        objects to the queue to ensure no waiting consumers get deadlocked.
-        """
-        if not self._flushed:
-            self._flushed = True
-            deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize())
-            for _ in range(deadlocked_receivers):
-                await self._queue.put(self.__flush)
-
-    # A special signal object for flushing the queue when the channel is closed
-    __flush = object()
@@ -1,3 +0,0 @@
-[console_scripts]
-protoc-gen-python_betterproto=betterproto2_compiler.plugin:main
-