betterproto2-compiler 0.0.3__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- betterproto2_compiler/compile/importing.py +9 -14
- betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py +3 -3
- betterproto2_compiler/lib/std/google/protobuf/__init__.py +1 -2
- betterproto2_compiler/plugin/compiler.py +5 -2
- betterproto2_compiler/plugin/models.py +54 -52
- betterproto2_compiler/plugin/module_validation.py +2 -7
- betterproto2_compiler/plugin/parser.py +11 -16
- betterproto2_compiler/plugin/typing_compiler.py +8 -12
- betterproto2_compiler/templates/header.py.j2 +2 -0
- {betterproto2_compiler-0.0.3.dist-info → betterproto2_compiler-0.1.0.dist-info}/METADATA +1 -1
- {betterproto2_compiler-0.0.3.dist-info → betterproto2_compiler-0.1.0.dist-info}/RECORD +14 -21
- betterproto2_compiler-0.1.0.dist-info/entry_points.txt +3 -0
- betterproto2_compiler/_types.py +0 -13
- betterproto2_compiler/enum.py +0 -180
- betterproto2_compiler/grpc/__init__.py +0 -0
- betterproto2_compiler/grpc/grpclib_client.py +0 -172
- betterproto2_compiler/grpc/grpclib_server.py +0 -32
- betterproto2_compiler/grpc/util/__init__.py +0 -0
- betterproto2_compiler/grpc/util/async_channel.py +0 -190
- betterproto2_compiler-0.0.3.dist-info/entry_points.txt +0 -3
- {betterproto2_compiler-0.0.3.dist-info → betterproto2_compiler-0.1.0.dist-info}/LICENSE.md +0 -0
- {betterproto2_compiler-0.0.3.dist-info → betterproto2_compiler-0.1.0.dist-info}/WHEEL +0 -0
@@ -3,11 +3,6 @@ from __future__ import annotations
|
|
3
3
|
import os
|
4
4
|
from typing import (
|
5
5
|
TYPE_CHECKING,
|
6
|
-
Dict,
|
7
|
-
List,
|
8
|
-
Set,
|
9
|
-
Tuple,
|
10
|
-
Type,
|
11
6
|
)
|
12
7
|
|
13
8
|
from ..casing import safe_snake_case
|
@@ -18,7 +13,7 @@ if TYPE_CHECKING:
|
|
18
13
|
from ..plugin.models import PluginRequestCompiler
|
19
14
|
from ..plugin.typing_compiler import TypingCompiler
|
20
15
|
|
21
|
-
WRAPPER_TYPES:
|
16
|
+
WRAPPER_TYPES: dict[str, type] = {
|
22
17
|
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
23
18
|
".google.protobuf.FloatValue": google_protobuf.FloatValue,
|
24
19
|
".google.protobuf.Int32Value": google_protobuf.Int32Value,
|
@@ -31,7 +26,7 @@ WRAPPER_TYPES: Dict[str, Type] = {
|
|
31
26
|
}
|
32
27
|
|
33
28
|
|
34
|
-
def parse_source_type_name(field_type_name: str, request:
|
29
|
+
def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) -> tuple[str, str]:
|
35
30
|
"""
|
36
31
|
Split full source type name into package and type name.
|
37
32
|
E.g. 'root.package.Message' -> ('root.package', 'Message')
|
@@ -77,7 +72,7 @@ def get_type_reference(
|
|
77
72
|
imports: set,
|
78
73
|
source_type: str,
|
79
74
|
typing_compiler: TypingCompiler,
|
80
|
-
request:
|
75
|
+
request: PluginRequestCompiler,
|
81
76
|
unwrap: bool = True,
|
82
77
|
pydantic: bool = False,
|
83
78
|
) -> str:
|
@@ -98,8 +93,8 @@ def get_type_reference(
|
|
98
93
|
|
99
94
|
source_package, source_type = parse_source_type_name(source_type, request)
|
100
95
|
|
101
|
-
current_package:
|
102
|
-
py_package:
|
96
|
+
current_package: list[str] = package.split(".") if package else []
|
97
|
+
py_package: list[str] = source_package.split(".") if source_package else []
|
103
98
|
py_type: str = pythonize_class_name(source_type)
|
104
99
|
|
105
100
|
compiling_google_protobuf = current_package == ["google", "protobuf"]
|
@@ -122,7 +117,7 @@ def get_type_reference(
|
|
122
117
|
return reference_cousin(current_package, imports, py_package, py_type)
|
123
118
|
|
124
119
|
|
125
|
-
def reference_absolute(imports:
|
120
|
+
def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> str:
|
126
121
|
"""
|
127
122
|
Returns a reference to a python type located in the root, i.e. sys.path.
|
128
123
|
"""
|
@@ -139,7 +134,7 @@ def reference_sibling(py_type: str) -> str:
|
|
139
134
|
return f"{py_type}"
|
140
135
|
|
141
136
|
|
142
|
-
def reference_descendent(current_package:
|
137
|
+
def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
|
143
138
|
"""
|
144
139
|
Returns a reference to a python type in a package that is a descendent of the
|
145
140
|
current package, and adds the required import that is aliased to avoid name
|
@@ -157,7 +152,7 @@ def reference_descendent(current_package: List[str], imports: Set[str], py_packa
|
|
157
152
|
return f"{string_import}.{py_type}"
|
158
153
|
|
159
154
|
|
160
|
-
def reference_ancestor(current_package:
|
155
|
+
def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
|
161
156
|
"""
|
162
157
|
Returns a reference to a python type in a package which is an ancestor to the
|
163
158
|
current package, and adds the required import that is aliased (if possible) to avoid
|
@@ -178,7 +173,7 @@ def reference_ancestor(current_package: List[str], imports: Set[str], py_package
|
|
178
173
|
return string_alias
|
179
174
|
|
180
175
|
|
181
|
-
def reference_cousin(current_package:
|
176
|
+
def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str:
|
182
177
|
"""
|
183
178
|
Returns a reference to a python type in a package that is not descendent, ancestor
|
184
179
|
or sibling, and adds the required import that is aliased to avoid name conflicts.
|
@@ -2401,13 +2401,13 @@ class Value(betterproto2_compiler.Message):
|
|
2401
2401
|
)
|
2402
2402
|
"""Represents a null value."""
|
2403
2403
|
|
2404
|
-
number_value:
|
2404
|
+
number_value: float | None = betterproto2_compiler.double_field(2, optional=True, group="kind")
|
2405
2405
|
"""Represents a double value."""
|
2406
2406
|
|
2407
|
-
string_value:
|
2407
|
+
string_value: str | None = betterproto2_compiler.string_field(3, optional=True, group="kind")
|
2408
2408
|
"""Represents a string value."""
|
2409
2409
|
|
2410
|
-
bool_value:
|
2410
|
+
bool_value: bool | None = betterproto2_compiler.bool_field(4, optional=True, group="kind")
|
2411
2411
|
"""Represents a boolean value."""
|
2412
2412
|
|
2413
2413
|
struct_value: Optional["Struct"] = betterproto2_compiler.message_field(5, optional=True, group="kind")
|
@@ -78,7 +78,6 @@ from typing import (
|
|
78
78
|
Dict,
|
79
79
|
List,
|
80
80
|
Mapping,
|
81
|
-
Optional,
|
82
81
|
)
|
83
82
|
|
84
83
|
import betterproto2
|
@@ -1022,7 +1021,7 @@ class FieldDescriptorProto(betterproto2.Message):
|
|
1022
1021
|
TODO(kenton): Base-64 encode?
|
1023
1022
|
"""
|
1024
1023
|
|
1025
|
-
oneof_index:
|
1024
|
+
oneof_index: int | None = betterproto2.int32_field(9, optional=True)
|
1026
1025
|
"""
|
1027
1026
|
If set, gives the index of a oneof in the containing type's oneof_decl
|
1028
1027
|
list. This field is a member of that oneof.
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import os.path
|
2
2
|
import subprocess
|
3
3
|
import sys
|
4
|
+
from importlib import metadata
|
4
5
|
|
5
6
|
from .module_validation import ModuleValidator
|
6
7
|
|
@@ -14,7 +15,7 @@ except ImportError as err:
|
|
14
15
|
"Please ensure that you've installed betterproto as "
|
15
16
|
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
16
17
|
"are included."
|
17
|
-
"\033[0m"
|
18
|
+
"\033[0m",
|
18
19
|
)
|
19
20
|
raise SystemExit(1)
|
20
21
|
|
@@ -24,6 +25,8 @@ from .models import OutputTemplate
|
|
24
25
|
def outputfile_compiler(output_file: OutputTemplate) -> str:
|
25
26
|
templates_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "templates"))
|
26
27
|
|
28
|
+
version = metadata.version("betterproto2_compiler")
|
29
|
+
|
27
30
|
env = jinja2.Environment(
|
28
31
|
trim_blocks=True,
|
29
32
|
lstrip_blocks=True,
|
@@ -35,7 +38,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
|
|
35
38
|
header_template = env.get_template("header.py.j2")
|
36
39
|
|
37
40
|
code = body_template.render(output_file=output_file)
|
38
|
-
code = header_template.render(output_file=output_file) + "\n" + code
|
41
|
+
code = header_template.render(output_file=output_file, version=version) + "\n" + code
|
39
42
|
|
40
43
|
# Sort imports, delete unused ones
|
41
44
|
code = subprocess.check_output(
|
@@ -31,18 +31,12 @@ reference to `A` to `B`'s `fields` attribute.
|
|
31
31
|
|
32
32
|
import builtins
|
33
33
|
import re
|
34
|
+
from collections.abc import Iterable, Iterator
|
34
35
|
from dataclasses import (
|
35
36
|
dataclass,
|
36
37
|
field,
|
37
38
|
)
|
38
39
|
from typing import (
|
39
|
-
Dict,
|
40
|
-
Iterable,
|
41
|
-
Iterator,
|
42
|
-
List,
|
43
|
-
Optional,
|
44
|
-
Set,
|
45
|
-
Type,
|
46
40
|
Union,
|
47
41
|
)
|
48
42
|
|
@@ -59,6 +53,7 @@ from betterproto2_compiler.lib.google.protobuf import (
|
|
59
53
|
FieldDescriptorProto,
|
60
54
|
FieldDescriptorProtoLabel,
|
61
55
|
FieldDescriptorProtoType,
|
56
|
+
FieldDescriptorProtoType as FieldType,
|
62
57
|
FileDescriptorProto,
|
63
58
|
MethodDescriptorProto,
|
64
59
|
)
|
@@ -146,7 +141,7 @@ PROTO_PACKED_TYPES = (
|
|
146
141
|
|
147
142
|
def get_comment(
|
148
143
|
proto_file: "FileDescriptorProto",
|
149
|
-
path:
|
144
|
+
path: list[int],
|
150
145
|
) -> str:
|
151
146
|
for sci_loc in proto_file.source_code_info.location:
|
152
147
|
if list(sci_loc.path) == path:
|
@@ -182,10 +177,10 @@ class ProtoContentBase:
|
|
182
177
|
|
183
178
|
source_file: FileDescriptorProto
|
184
179
|
typing_compiler: TypingCompiler
|
185
|
-
path:
|
180
|
+
path: list[int]
|
186
181
|
parent: Union["betterproto2.Message", "OutputTemplate"]
|
187
182
|
|
188
|
-
__dataclass_fields__:
|
183
|
+
__dataclass_fields__: dict[str, object]
|
189
184
|
|
190
185
|
def __post_init__(self) -> None:
|
191
186
|
"""Checks that no fake default fields were left as placeholders."""
|
@@ -225,10 +220,10 @@ class ProtoContentBase:
|
|
225
220
|
@dataclass
|
226
221
|
class PluginRequestCompiler:
|
227
222
|
plugin_request_obj: CodeGeneratorRequest
|
228
|
-
output_packages:
|
223
|
+
output_packages: dict[str, "OutputTemplate"] = field(default_factory=dict)
|
229
224
|
|
230
225
|
@property
|
231
|
-
def all_messages(self) ->
|
226
|
+
def all_messages(self) -> list["MessageCompiler"]:
|
232
227
|
"""All of the messages in this request.
|
233
228
|
|
234
229
|
Returns
|
@@ -250,11 +245,11 @@ class OutputTemplate:
|
|
250
245
|
|
251
246
|
parent_request: PluginRequestCompiler
|
252
247
|
package_proto_obj: FileDescriptorProto
|
253
|
-
input_files:
|
254
|
-
imports_end:
|
255
|
-
messages:
|
256
|
-
enums:
|
257
|
-
services:
|
248
|
+
input_files: list[str] = field(default_factory=list)
|
249
|
+
imports_end: set[str] = field(default_factory=set)
|
250
|
+
messages: dict[str, "MessageCompiler"] = field(default_factory=dict)
|
251
|
+
enums: dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
|
252
|
+
services: dict[str, "ServiceCompiler"] = field(default_factory=dict)
|
258
253
|
pydantic_dataclasses: bool = False
|
259
254
|
output: bool = True
|
260
255
|
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
|
@@ -290,9 +285,9 @@ class MessageCompiler(ProtoContentBase):
|
|
290
285
|
typing_compiler: TypingCompiler
|
291
286
|
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
|
292
287
|
proto_obj: DescriptorProto = PLACEHOLDER
|
293
|
-
path:
|
294
|
-
fields:
|
295
|
-
builtins_types:
|
288
|
+
path: list[int] = PLACEHOLDER
|
289
|
+
fields: list[Union["FieldCompiler", "MessageCompiler"]] = field(default_factory=list)
|
290
|
+
builtins_types: set[str] = field(default_factory=set)
|
296
291
|
|
297
292
|
def __post_init__(self) -> None:
|
298
293
|
# Add message to output file
|
@@ -328,11 +323,9 @@ class MessageCompiler(ProtoContentBase):
|
|
328
323
|
@property
|
329
324
|
def has_message_field(self) -> bool:
|
330
325
|
return any(
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
if isinstance(field.proto_obj, FieldDescriptorProto)
|
335
|
-
)
|
326
|
+
field.proto_obj.type in PROTO_MESSAGE_TYPES
|
327
|
+
for field in self.fields
|
328
|
+
if isinstance(field.proto_obj, FieldDescriptorProto)
|
336
329
|
)
|
337
330
|
|
338
331
|
|
@@ -347,7 +340,7 @@ def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProt
|
|
347
340
|
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
|
348
341
|
if message_type == map_entry:
|
349
342
|
for nested in parent_message.nested_type: # parent message
|
350
|
-
if nested.name.replace("_", "").lower() == map_entry and nested.options.map_entry:
|
343
|
+
if nested.name.replace("_", "").lower() == map_entry and nested.options and nested.options.map_entry:
|
351
344
|
return True
|
352
345
|
return False
|
353
346
|
|
@@ -374,8 +367,8 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
|
374
367
|
class FieldCompiler(ProtoContentBase):
|
375
368
|
source_file: FileDescriptorProto
|
376
369
|
typing_compiler: TypingCompiler
|
377
|
-
path:
|
378
|
-
builtins_types:
|
370
|
+
path: list[int] = PLACEHOLDER
|
371
|
+
builtins_types: set[str] = field(default_factory=set)
|
379
372
|
|
380
373
|
parent: MessageCompiler = PLACEHOLDER
|
381
374
|
proto_obj: FieldDescriptorProto = PLACEHOLDER
|
@@ -390,13 +383,16 @@ class FieldCompiler(ProtoContentBase):
|
|
390
383
|
"""Construct string representation of this field as a field."""
|
391
384
|
name = f"{self.py_name}"
|
392
385
|
field_args = ", ".join(([""] + self.betterproto_field_args) if self.betterproto_field_args else [])
|
393
|
-
|
386
|
+
|
387
|
+
betterproto_field_type = (
|
388
|
+
f"betterproto2.field({self.proto_obj.number}, betterproto2.{str(self.field_type)}{field_args})"
|
389
|
+
)
|
394
390
|
if self.py_name in dir(builtins):
|
395
391
|
self.parent.builtins_types.add(self.py_name)
|
396
392
|
return f'{name}: "{self.annotation}" = {betterproto_field_type}'
|
397
393
|
|
398
394
|
@property
|
399
|
-
def betterproto_field_args(self) ->
|
395
|
+
def betterproto_field_args(self) -> list[str]:
|
400
396
|
args = []
|
401
397
|
if self.field_wraps:
|
402
398
|
args.append(f"wraps={self.field_wraps}")
|
@@ -404,9 +400,9 @@ class FieldCompiler(ProtoContentBase):
|
|
404
400
|
args.append("optional=True")
|
405
401
|
if self.repeated:
|
406
402
|
args.append("repeated=True")
|
407
|
-
if self.field_type ==
|
403
|
+
if self.field_type == FieldType.TYPE_ENUM:
|
408
404
|
t = self.py_type
|
409
|
-
args.append(f"
|
405
|
+
args.append(f"default_factory=lambda: {t}.try_value(0)")
|
410
406
|
return args
|
411
407
|
|
412
408
|
@property
|
@@ -416,7 +412,7 @@ class FieldCompiler(ProtoContentBase):
|
|
416
412
|
)
|
417
413
|
|
418
414
|
@property
|
419
|
-
def field_wraps(self) ->
|
415
|
+
def field_wraps(self) -> str | None:
|
420
416
|
"""Returns betterproto wrapped field type or None."""
|
421
417
|
match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name)
|
422
418
|
if match_wrapper:
|
@@ -428,17 +424,19 @@ class FieldCompiler(ProtoContentBase):
|
|
428
424
|
@property
|
429
425
|
def repeated(self) -> bool:
|
430
426
|
return self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED and not is_map(
|
431
|
-
self.proto_obj,
|
427
|
+
self.proto_obj,
|
428
|
+
self.parent,
|
432
429
|
)
|
433
430
|
|
434
431
|
@property
|
435
432
|
def optional(self) -> bool:
|
436
|
-
|
433
|
+
# TODO not for maps
|
434
|
+
return self.proto_obj.proto3_optional or (self.field_type == FieldType.TYPE_MESSAGE and not self.repeated)
|
437
435
|
|
438
436
|
@property
|
439
|
-
def field_type(self) ->
|
440
|
-
|
441
|
-
return
|
437
|
+
def field_type(self) -> FieldType:
|
438
|
+
# TODO it should be possible to remove constructor
|
439
|
+
return FieldType(self.proto_obj.type)
|
442
440
|
|
443
441
|
@property
|
444
442
|
def packed(self) -> bool:
|
@@ -500,7 +498,7 @@ class OneOfFieldCompiler(FieldCompiler):
|
|
500
498
|
return True
|
501
499
|
|
502
500
|
@property
|
503
|
-
def betterproto_field_args(self) ->
|
501
|
+
def betterproto_field_args(self) -> list[str]:
|
504
502
|
args = super().betterproto_field_args
|
505
503
|
group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
|
506
504
|
args.append(f'group="{group}"')
|
@@ -509,8 +507,8 @@ class OneOfFieldCompiler(FieldCompiler):
|
|
509
507
|
|
510
508
|
@dataclass
|
511
509
|
class MapEntryCompiler(FieldCompiler):
|
512
|
-
py_k_type:
|
513
|
-
py_v_type:
|
510
|
+
py_k_type: type | None = None
|
511
|
+
py_v_type: type | None = None
|
514
512
|
proto_k_type: str = ""
|
515
513
|
proto_v_type: str = ""
|
516
514
|
|
@@ -547,13 +545,17 @@ class MapEntryCompiler(FieldCompiler):
|
|
547
545
|
|
548
546
|
raise ValueError("can't find enum")
|
549
547
|
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
548
|
+
def get_field_string(self) -> str:
|
549
|
+
"""Construct string representation of this field as a field."""
|
550
|
+
betterproto_field_type = (
|
551
|
+
f"betterproto2.field({self.proto_obj.number}, "
|
552
|
+
"betterproto2.TYPE_MAP, "
|
553
|
+
f"map_types=(betterproto2.{self.proto_k_type}, "
|
554
|
+
f"betterproto2.{self.proto_v_type}))"
|
555
|
+
)
|
556
|
+
if self.py_name in dir(builtins):
|
557
|
+
self.parent.builtins_types.add(self.py_name)
|
558
|
+
return f'{self.py_name}: "{self.annotation}" = {betterproto_field_type}'
|
557
559
|
|
558
560
|
@property
|
559
561
|
def annotation(self) -> str:
|
@@ -569,7 +571,7 @@ class EnumDefinitionCompiler(MessageCompiler):
|
|
569
571
|
"""Representation of a proto Enum definition."""
|
570
572
|
|
571
573
|
proto_obj: EnumDescriptorProto = PLACEHOLDER
|
572
|
-
entries:
|
574
|
+
entries: list["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
|
573
575
|
|
574
576
|
@dataclass(unsafe_hash=True)
|
575
577
|
class EnumEntry:
|
@@ -597,8 +599,8 @@ class ServiceCompiler(ProtoContentBase):
|
|
597
599
|
source_file: FileDescriptorProto
|
598
600
|
parent: OutputTemplate = PLACEHOLDER
|
599
601
|
proto_obj: DescriptorProto = PLACEHOLDER
|
600
|
-
path:
|
601
|
-
methods:
|
602
|
+
path: list[int] = PLACEHOLDER
|
603
|
+
methods: list["ServiceMethodCompiler"] = field(default_factory=list)
|
602
604
|
|
603
605
|
def __post_init__(self) -> None:
|
604
606
|
# Add service to output file
|
@@ -619,7 +621,7 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|
619
621
|
source_file: FileDescriptorProto
|
620
622
|
parent: ServiceCompiler
|
621
623
|
proto_obj: MethodDescriptorProto
|
622
|
-
path:
|
624
|
+
path: list[int] = PLACEHOLDER
|
623
625
|
|
624
626
|
def __post_init__(self) -> None:
|
625
627
|
# Add method to service
|
@@ -1,15 +1,10 @@
|
|
1
1
|
import re
|
2
2
|
from collections import defaultdict
|
3
|
+
from collections.abc import Iterator
|
3
4
|
from dataclasses import (
|
4
5
|
dataclass,
|
5
6
|
field,
|
6
7
|
)
|
7
|
-
from typing import (
|
8
|
-
Dict,
|
9
|
-
Iterator,
|
10
|
-
List,
|
11
|
-
Tuple,
|
12
|
-
)
|
13
8
|
|
14
9
|
|
15
10
|
@dataclass
|
@@ -17,7 +12,7 @@ class ModuleValidator:
|
|
17
12
|
line_iterator: Iterator[str]
|
18
13
|
line_number: int = field(init=False, default=0)
|
19
14
|
|
20
|
-
collisions:
|
15
|
+
collisions: dict[str, list[tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list))
|
21
16
|
|
22
17
|
def add_import(self, imp: str, number: int, full_line: str):
|
23
18
|
"""
|
@@ -1,12 +1,6 @@
|
|
1
1
|
import pathlib
|
2
2
|
import sys
|
3
|
-
from
|
4
|
-
Generator,
|
5
|
-
List,
|
6
|
-
Set,
|
7
|
-
Tuple,
|
8
|
-
Union,
|
9
|
-
)
|
3
|
+
from collections.abc import Generator
|
10
4
|
|
11
5
|
from betterproto2_compiler.lib.google.protobuf import (
|
12
6
|
DescriptorProto,
|
@@ -45,13 +39,13 @@ from .typing_compiler import (
|
|
45
39
|
|
46
40
|
def traverse(
|
47
41
|
proto_file: FileDescriptorProto,
|
48
|
-
) -> Generator[
|
42
|
+
) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
|
49
43
|
# Todo: Keep information about nested hierarchy
|
50
44
|
def _traverse(
|
51
|
-
path:
|
52
|
-
items:
|
45
|
+
path: list[int],
|
46
|
+
items: list[EnumDescriptorProto] | list[DescriptorProto],
|
53
47
|
prefix: str = "",
|
54
|
-
) -> Generator[
|
48
|
+
) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
|
55
49
|
for i, item in enumerate(items):
|
56
50
|
# Adjust the name since we flatten the hierarchy.
|
57
51
|
# Todo: don't change the name, but include full name in returned tuple
|
@@ -82,7 +76,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
|
82
76
|
if output_package_name not in request_data.output_packages:
|
83
77
|
# Create a new output if there is no output for this package
|
84
78
|
request_data.output_packages[output_package_name] = OutputTemplate(
|
85
|
-
parent_request=request_data,
|
79
|
+
parent_request=request_data,
|
80
|
+
package_proto_obj=proto_file,
|
86
81
|
)
|
87
82
|
# Add this input file to the output corresponding to this package
|
88
83
|
request_data.output_packages[output_package_name].input_files.append(proto_file)
|
@@ -144,7 +139,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
|
144
139
|
service.ready()
|
145
140
|
|
146
141
|
# Generate output files
|
147
|
-
output_paths:
|
142
|
+
output_paths: set[pathlib.Path] = set()
|
148
143
|
for output_package_name, output_package in request_data.output_packages.items():
|
149
144
|
if not output_package.output:
|
150
145
|
continue
|
@@ -158,7 +153,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
|
158
153
|
name=str(output_path),
|
159
154
|
# Render and then format the output file
|
160
155
|
content=outputfile_compiler(output_file=output_package),
|
161
|
-
)
|
156
|
+
),
|
162
157
|
)
|
163
158
|
|
164
159
|
# Make each output directory a package with __init__ file
|
@@ -183,7 +178,7 @@ def _make_one_of_field_compiler(
|
|
183
178
|
source_file: "FileDescriptorProto",
|
184
179
|
parent: MessageCompiler,
|
185
180
|
proto_obj: "FieldDescriptorProto",
|
186
|
-
path:
|
181
|
+
path: list[int],
|
187
182
|
) -> FieldCompiler:
|
188
183
|
return OneOfFieldCompiler(
|
189
184
|
source_file=source_file,
|
@@ -196,7 +191,7 @@ def _make_one_of_field_compiler(
|
|
196
191
|
|
197
192
|
def read_protobuf_type(
|
198
193
|
item: DescriptorProto,
|
199
|
-
path:
|
194
|
+
path: list[int],
|
200
195
|
source_file: "FileDescriptorProto",
|
201
196
|
output_package: OutputTemplate,
|
202
197
|
) -> None:
|
@@ -1,15 +1,11 @@
|
|
1
1
|
import abc
|
2
|
+
import builtins
|
2
3
|
from collections import defaultdict
|
4
|
+
from collections.abc import Iterator
|
3
5
|
from dataclasses import (
|
4
6
|
dataclass,
|
5
7
|
field,
|
6
8
|
)
|
7
|
-
from typing import (
|
8
|
-
Dict,
|
9
|
-
Iterator,
|
10
|
-
Optional,
|
11
|
-
Set,
|
12
|
-
)
|
13
9
|
|
14
10
|
|
15
11
|
class TypingCompiler(metaclass=abc.ABCMeta):
|
@@ -42,7 +38,7 @@ class TypingCompiler(metaclass=abc.ABCMeta):
|
|
42
38
|
raise NotImplementedError
|
43
39
|
|
44
40
|
@abc.abstractmethod
|
45
|
-
def imports(self) ->
|
41
|
+
def imports(self) -> builtins.dict[str, set[str] | None]:
|
46
42
|
"""
|
47
43
|
Returns either the direct import as a key with none as value, or a set of
|
48
44
|
values to import from the key.
|
@@ -63,7 +59,7 @@ class TypingCompiler(metaclass=abc.ABCMeta):
|
|
63
59
|
|
64
60
|
@dataclass
|
65
61
|
class DirectImportTypingCompiler(TypingCompiler):
|
66
|
-
_imports:
|
62
|
+
_imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
|
67
63
|
|
68
64
|
def optional(self, type_: str) -> str:
|
69
65
|
self._imports["typing"].add("Optional")
|
@@ -93,7 +89,7 @@ class DirectImportTypingCompiler(TypingCompiler):
|
|
93
89
|
self._imports["typing"].add("AsyncIterator")
|
94
90
|
return f"AsyncIterator[{type_}]"
|
95
91
|
|
96
|
-
def imports(self) ->
|
92
|
+
def imports(self) -> builtins.dict[str, set[str] | None]:
|
97
93
|
return {k: v if v else None for k, v in self._imports.items()}
|
98
94
|
|
99
95
|
|
@@ -129,7 +125,7 @@ class TypingImportTypingCompiler(TypingCompiler):
|
|
129
125
|
self._imported = True
|
130
126
|
return f"typing.AsyncIterator[{type_}]"
|
131
127
|
|
132
|
-
def imports(self) ->
|
128
|
+
def imports(self) -> builtins.dict[str, set[str] | None]:
|
133
129
|
if self._imported:
|
134
130
|
return {"typing": None}
|
135
131
|
return {}
|
@@ -137,7 +133,7 @@ class TypingImportTypingCompiler(TypingCompiler):
|
|
137
133
|
|
138
134
|
@dataclass
|
139
135
|
class NoTyping310TypingCompiler(TypingCompiler):
|
140
|
-
_imports:
|
136
|
+
_imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
|
141
137
|
|
142
138
|
def optional(self, type_: str) -> str:
|
143
139
|
return f"{type_} | None"
|
@@ -163,5 +159,5 @@ class NoTyping310TypingCompiler(TypingCompiler):
|
|
163
159
|
self._imports["collections.abc"].add("AsyncIterator")
|
164
160
|
return f"AsyncIterator[{type_}]"
|
165
161
|
|
166
|
-
def imports(self) ->
|
162
|
+
def imports(self) -> builtins.dict[str, set[str] | None]:
|
167
163
|
return {k: v if v else None for k, v in self._imports.items()}
|
@@ -1,41 +1,34 @@
|
|
1
1
|
betterproto2_compiler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
betterproto2_compiler/_types.py,sha256=nIsUxcId43N1Gu8EqdeuflR9iUZB1JWu4JTGQV9NeUI,294
|
3
2
|
betterproto2_compiler/casing.py,sha256=bMdI4W0hfYh6kV-DQIqFEjSfGYEqUtPciAzP64z5HLQ,3587
|
4
3
|
betterproto2_compiler/compile/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
betterproto2_compiler/compile/importing.py,sha256=
|
4
|
+
betterproto2_compiler/compile/importing.py,sha256=MN1pzFISG96wd6Djsym8q9yfsknBC5logvPrW6qnWyc,7388
|
6
5
|
betterproto2_compiler/compile/naming.py,sha256=zf0VOmNojzyv33upOGelGxjZTEDE8JULEEED5_3inHg,562
|
7
|
-
betterproto2_compiler/enum.py,sha256=LcILQf1BEjnszouUCtPwifJAR_8u2tf9U9hfwP4vXTc,5396
|
8
|
-
betterproto2_compiler/grpc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
-
betterproto2_compiler/grpc/grpclib_client.py,sha256=6yKUqLFEfiMUPT85g81ajvVI4bvH4kBULZa7pIL18Y8,5275
|
10
|
-
betterproto2_compiler/grpc/grpclib_server.py,sha256=Tv3NIGPrxdA48_THUl3jut_IekQgkecX8NPfjvF0kdg,872
|
11
|
-
betterproto2_compiler/grpc/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
betterproto2_compiler/grpc/util/async_channel.py,sha256=4sfqoHtS_-qU1GFc0LatnV_dLsmVrGEm74WOJ9RpV1Y,6778
|
13
6
|
betterproto2_compiler/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
7
|
betterproto2_compiler/lib/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
8
|
betterproto2_compiler/lib/google/protobuf/__init__.py,sha256=RvtT0CjNfAegDv42ITEvaIVrGpEh2fXvGNyA7bd99oo,60
|
16
9
|
betterproto2_compiler/lib/google/protobuf/compiler/__init__.py,sha256=1EvLU05Ck-rwv2Y_jW4ylJlCgqd7VpSHeFgbBFWfAmg,69
|
17
10
|
betterproto2_compiler/lib/pydantic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
11
|
betterproto2_compiler/lib/pydantic/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py,sha256=
|
12
|
+
betterproto2_compiler/lib/pydantic/google/protobuf/__init__.py,sha256=RgGj0p9-KkaO-CFGV_CLTXm1yhPf-Vm02cUkF7z39SI,96972
|
20
13
|
betterproto2_compiler/lib/pydantic/google/protobuf/compiler/__init__.py,sha256=uI4F0DytRxbrwXqiRKOJD0nJFewlQE0Lj_3WJdiGoDU,9217
|
21
14
|
betterproto2_compiler/lib/std/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
15
|
betterproto2_compiler/lib/std/google/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
-
betterproto2_compiler/lib/std/google/protobuf/__init__.py,sha256=
|
16
|
+
betterproto2_compiler/lib/std/google/protobuf/__init__.py,sha256=bfxwQzYXJ-ImFktGBN5_8HY9JgASJ8BbBg9fKeLTVQ0,80437
|
24
17
|
betterproto2_compiler/lib/std/google/protobuf/compiler/__init__.py,sha256=z63o4vkjnDWB7yZpYnx-T0FZCbmSs5gBzfKrNP_qf8c,8646
|
25
18
|
betterproto2_compiler/plugin/__init__.py,sha256=L3pW0b4CvkM5x53x_sYt1kYiSFPO0_vaeH6EQPq9FAM,43
|
26
19
|
betterproto2_compiler/plugin/__main__.py,sha256=vBQ82334kX06ImDbFlPFgiBRiLIinwNk3z8Khs6hd74,31
|
27
|
-
betterproto2_compiler/plugin/compiler.py,sha256=
|
20
|
+
betterproto2_compiler/plugin/compiler.py,sha256=jICLI4-5rAOkWQI1v5j7JqIvoao-ZM9szMuq0OBRteA,2138
|
28
21
|
betterproto2_compiler/plugin/main.py,sha256=Q9PmcJqXuYYFe51l7AqHVzJrHqi2LWCUu80CZSQOOwk,1469
|
29
|
-
betterproto2_compiler/plugin/models.py,sha256=
|
30
|
-
betterproto2_compiler/plugin/module_validation.py,sha256=
|
31
|
-
betterproto2_compiler/plugin/parser.py,sha256=
|
22
|
+
betterproto2_compiler/plugin/models.py,sha256=7l1P8-ijdatENy-ERO5twHeMpYxLG3YEqhYM9Slm5TY,24655
|
23
|
+
betterproto2_compiler/plugin/module_validation.py,sha256=RdPFwdmkbD6NKADaHC5eaPix_pz-yGxHvYJj8Ev48fA,4822
|
24
|
+
betterproto2_compiler/plugin/parser.py,sha256=i_HDW0O4ieIatbq1hckOj0kM9MCfdzuib0fumSJqIA4,9610
|
32
25
|
betterproto2_compiler/plugin/plugin.bat,sha256=lfLT1WguAXqyerLLsRL6BfHA0RqUE6QG79v-1BYVSpI,48
|
33
|
-
betterproto2_compiler/plugin/typing_compiler.py,sha256=
|
26
|
+
betterproto2_compiler/plugin/typing_compiler.py,sha256=IK6m4ggHXK7HL98Ed_WjvQ_yeWfIpf_fIBZ9SA8UcyM,4873
|
34
27
|
betterproto2_compiler/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
35
|
-
betterproto2_compiler/templates/header.py.j2,sha256=
|
28
|
+
betterproto2_compiler/templates/header.py.j2,sha256=nxqsengMcM_IRqQYNVntPQ0gUFUPCy_1P1mcoLvbDos,1464
|
36
29
|
betterproto2_compiler/templates/template.py.j2,sha256=icyiNdSTJRgyD20e_lTgTAvSjgnSFSn4t1L1-yZnkEM,8712
|
37
|
-
betterproto2_compiler-0.0.
|
38
|
-
betterproto2_compiler-0.0.
|
39
|
-
betterproto2_compiler-0.0.
|
40
|
-
betterproto2_compiler-0.0.
|
41
|
-
betterproto2_compiler-0.0.
|
30
|
+
betterproto2_compiler-0.1.0.dist-info/LICENSE.md,sha256=Pgl2pReU-2yw2miGeQ55UFlyzqAZ_EpYVyZ2nWjwRv4,1121
|
31
|
+
betterproto2_compiler-0.1.0.dist-info/METADATA,sha256=sJiWJ2eOTCN0ARCo0VmNcMlCPgEFli0tfsnM2vYM-Gc,1163
|
32
|
+
betterproto2_compiler-0.1.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
33
|
+
betterproto2_compiler-0.1.0.dist-info/entry_points.txt,sha256=re3Qg8lLljbVobeeKH2f1FVQZ114wfZkGv3zCZTD8Ok,84
|
34
|
+
betterproto2_compiler-0.1.0.dist-info/RECORD,,
|
betterproto2_compiler/_types.py
DELETED
@@ -1,13 +0,0 @@
|
|
1
|
-
from typing import (
|
2
|
-
TYPE_CHECKING,
|
3
|
-
TypeVar,
|
4
|
-
)
|
5
|
-
|
6
|
-
if TYPE_CHECKING:
|
7
|
-
from grpclib._typing import IProtoMessage
|
8
|
-
|
9
|
-
from . import Message
|
10
|
-
|
11
|
-
# Bound type variable to allow methods to return `self` of subclasses
|
12
|
-
T = TypeVar("T", bound="Message")
|
13
|
-
ST = TypeVar("ST", bound="IProtoMessage")
|
betterproto2_compiler/enum.py
DELETED
@@ -1,180 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from enum import (
|
4
|
-
EnumMeta,
|
5
|
-
IntEnum,
|
6
|
-
)
|
7
|
-
from types import MappingProxyType
|
8
|
-
from typing import (
|
9
|
-
TYPE_CHECKING,
|
10
|
-
Any,
|
11
|
-
Dict,
|
12
|
-
Optional,
|
13
|
-
Tuple,
|
14
|
-
)
|
15
|
-
|
16
|
-
if TYPE_CHECKING:
|
17
|
-
from collections.abc import (
|
18
|
-
Generator,
|
19
|
-
Mapping,
|
20
|
-
)
|
21
|
-
|
22
|
-
from typing_extensions import (
|
23
|
-
Never,
|
24
|
-
Self,
|
25
|
-
)
|
26
|
-
|
27
|
-
|
28
|
-
def _is_descriptor(obj: object) -> bool:
|
29
|
-
return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
|
30
|
-
|
31
|
-
|
32
|
-
class EnumType(EnumMeta if TYPE_CHECKING else type):
|
33
|
-
_value_map_: Mapping[int, Enum]
|
34
|
-
_member_map_: Mapping[str, Enum]
|
35
|
-
|
36
|
-
def __new__(mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]) -> Self:
|
37
|
-
value_map = {}
|
38
|
-
member_map = {}
|
39
|
-
|
40
|
-
new_mcs = type(
|
41
|
-
f"{name}Type",
|
42
|
-
tuple(
|
43
|
-
dict.fromkeys([base.__class__ for base in bases if base.__class__ is not type] + [EnumType, type])
|
44
|
-
), # reorder the bases so EnumType and type are last to avoid conflicts
|
45
|
-
{"_value_map_": value_map, "_member_map_": member_map},
|
46
|
-
)
|
47
|
-
|
48
|
-
members = {
|
49
|
-
name: value for name, value in namespace.items() if not _is_descriptor(value) and not name.startswith("__")
|
50
|
-
}
|
51
|
-
|
52
|
-
cls = type.__new__(
|
53
|
-
new_mcs,
|
54
|
-
name,
|
55
|
-
bases,
|
56
|
-
{key: value for key, value in namespace.items() if key not in members},
|
57
|
-
)
|
58
|
-
# this allows us to disallow member access from other members as
|
59
|
-
# members become proper class variables
|
60
|
-
|
61
|
-
for name, value in members.items():
|
62
|
-
member = value_map.get(value)
|
63
|
-
if member is None:
|
64
|
-
member = cls.__new__(cls, name=name, value=value) # type: ignore
|
65
|
-
value_map[value] = member
|
66
|
-
member_map[name] = member
|
67
|
-
type.__setattr__(new_mcs, name, member)
|
68
|
-
|
69
|
-
return cls
|
70
|
-
|
71
|
-
if not TYPE_CHECKING:
|
72
|
-
|
73
|
-
def __call__(cls, value: int) -> Enum:
|
74
|
-
try:
|
75
|
-
return cls._value_map_[value]
|
76
|
-
except (KeyError, TypeError):
|
77
|
-
raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None
|
78
|
-
|
79
|
-
def __iter__(cls) -> Generator[Enum, None, None]:
|
80
|
-
yield from cls._member_map_.values()
|
81
|
-
|
82
|
-
def __reversed__(cls) -> Generator[Enum, None, None]:
|
83
|
-
yield from reversed(cls._member_map_.values())
|
84
|
-
|
85
|
-
def __getitem__(cls, key: str) -> Enum:
|
86
|
-
return cls._member_map_[key]
|
87
|
-
|
88
|
-
@property
|
89
|
-
def __members__(cls) -> MappingProxyType[str, Enum]:
|
90
|
-
return MappingProxyType(cls._member_map_)
|
91
|
-
|
92
|
-
def __repr__(cls) -> str:
|
93
|
-
return f"<enum {cls.__name__!r}>"
|
94
|
-
|
95
|
-
def __len__(cls) -> int:
|
96
|
-
return len(cls._member_map_)
|
97
|
-
|
98
|
-
def __setattr__(cls, name: str, value: Any) -> Never:
|
99
|
-
raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")
|
100
|
-
|
101
|
-
def __delattr__(cls, name: str) -> Never:
|
102
|
-
raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")
|
103
|
-
|
104
|
-
def __contains__(cls, member: object) -> bool:
|
105
|
-
return isinstance(member, cls) and member.name in cls._member_map_
|
106
|
-
|
107
|
-
|
108
|
-
class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
|
109
|
-
"""
|
110
|
-
The base class for protobuf enumerations, all generated enumerations will
|
111
|
-
inherit from this. Emulates `enum.IntEnum`.
|
112
|
-
"""
|
113
|
-
|
114
|
-
name: Optional[str]
|
115
|
-
value: int
|
116
|
-
|
117
|
-
if not TYPE_CHECKING:
|
118
|
-
|
119
|
-
def __new__(cls, *, name: Optional[str], value: int) -> Self:
|
120
|
-
self = super().__new__(cls, value)
|
121
|
-
super().__setattr__(self, "name", name)
|
122
|
-
super().__setattr__(self, "value", value)
|
123
|
-
return self
|
124
|
-
|
125
|
-
def __str__(self) -> str:
|
126
|
-
return self.name or "None"
|
127
|
-
|
128
|
-
def __repr__(self) -> str:
|
129
|
-
return f"{self.__class__.__name__}.{self.name}"
|
130
|
-
|
131
|
-
def __setattr__(self, key: str, value: Any) -> Never:
|
132
|
-
raise AttributeError(f"{self.__class__.__name__} Cannot reassign a member's attributes.")
|
133
|
-
|
134
|
-
def __delattr__(self, item: Any) -> Never:
|
135
|
-
raise AttributeError(f"{self.__class__.__name__} Cannot delete a member's attributes.")
|
136
|
-
|
137
|
-
def __copy__(self) -> Self:
|
138
|
-
return self
|
139
|
-
|
140
|
-
def __deepcopy__(self, memo: Any) -> Self:
|
141
|
-
return self
|
142
|
-
|
143
|
-
@classmethod
|
144
|
-
def try_value(cls, value: int = 0) -> Self:
|
145
|
-
"""Return the value which corresponds to the value.
|
146
|
-
|
147
|
-
Parameters
|
148
|
-
-----------
|
149
|
-
value: :class:`int`
|
150
|
-
The value of the enum member to get.
|
151
|
-
|
152
|
-
Returns
|
153
|
-
-------
|
154
|
-
:class:`Enum`
|
155
|
-
The corresponding member or a new instance of the enum if
|
156
|
-
``value`` isn't actually a member.
|
157
|
-
"""
|
158
|
-
try:
|
159
|
-
return cls._value_map_[value]
|
160
|
-
except (KeyError, TypeError):
|
161
|
-
return cls.__new__(cls, name=None, value=value)
|
162
|
-
|
163
|
-
@classmethod
|
164
|
-
def from_string(cls, name: str) -> Self:
|
165
|
-
"""Return the value which corresponds to the string name.
|
166
|
-
|
167
|
-
Parameters
|
168
|
-
-----------
|
169
|
-
name: :class:`str`
|
170
|
-
The name of the enum member to get.
|
171
|
-
|
172
|
-
Raises
|
173
|
-
-------
|
174
|
-
:exc:`ValueError`
|
175
|
-
The member was not found in the Enum.
|
176
|
-
"""
|
177
|
-
try:
|
178
|
-
return cls._member_map_[name]
|
179
|
-
except KeyError as e:
|
180
|
-
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
|
File without changes
|
@@ -1,172 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
from abc import ABC
|
3
|
-
from typing import (
|
4
|
-
TYPE_CHECKING,
|
5
|
-
AsyncIterable,
|
6
|
-
AsyncIterator,
|
7
|
-
Collection,
|
8
|
-
Iterable,
|
9
|
-
Mapping,
|
10
|
-
Optional,
|
11
|
-
Tuple,
|
12
|
-
Type,
|
13
|
-
Union,
|
14
|
-
)
|
15
|
-
|
16
|
-
import grpclib.const
|
17
|
-
|
18
|
-
if TYPE_CHECKING:
|
19
|
-
from grpclib.client import Channel
|
20
|
-
from grpclib.metadata import Deadline
|
21
|
-
|
22
|
-
from .._types import (
|
23
|
-
IProtoMessage,
|
24
|
-
T,
|
25
|
-
)
|
26
|
-
|
27
|
-
|
28
|
-
Value = Union[str, bytes]
|
29
|
-
MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
|
30
|
-
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
|
31
|
-
|
32
|
-
|
33
|
-
class ServiceStub(ABC):
|
34
|
-
"""
|
35
|
-
Base class for async gRPC clients.
|
36
|
-
"""
|
37
|
-
|
38
|
-
def __init__(
|
39
|
-
self,
|
40
|
-
channel: "Channel",
|
41
|
-
*,
|
42
|
-
timeout: Optional[float] = None,
|
43
|
-
deadline: Optional["Deadline"] = None,
|
44
|
-
metadata: Optional[MetadataLike] = None,
|
45
|
-
) -> None:
|
46
|
-
self.channel = channel
|
47
|
-
self.timeout = timeout
|
48
|
-
self.deadline = deadline
|
49
|
-
self.metadata = metadata
|
50
|
-
|
51
|
-
def __resolve_request_kwargs(
|
52
|
-
self,
|
53
|
-
timeout: Optional[float],
|
54
|
-
deadline: Optional["Deadline"],
|
55
|
-
metadata: Optional[MetadataLike],
|
56
|
-
):
|
57
|
-
return {
|
58
|
-
"timeout": self.timeout if timeout is None else timeout,
|
59
|
-
"deadline": self.deadline if deadline is None else deadline,
|
60
|
-
"metadata": self.metadata if metadata is None else metadata,
|
61
|
-
}
|
62
|
-
|
63
|
-
async def _unary_unary(
|
64
|
-
self,
|
65
|
-
route: str,
|
66
|
-
request: "IProtoMessage",
|
67
|
-
response_type: Type["T"],
|
68
|
-
*,
|
69
|
-
timeout: Optional[float] = None,
|
70
|
-
deadline: Optional["Deadline"] = None,
|
71
|
-
metadata: Optional[MetadataLike] = None,
|
72
|
-
) -> "T":
|
73
|
-
"""Make a unary request and return the response."""
|
74
|
-
async with self.channel.request(
|
75
|
-
route,
|
76
|
-
grpclib.const.Cardinality.UNARY_UNARY,
|
77
|
-
type(request),
|
78
|
-
response_type,
|
79
|
-
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
80
|
-
) as stream:
|
81
|
-
await stream.send_message(request, end=True)
|
82
|
-
response = await stream.recv_message()
|
83
|
-
assert response is not None
|
84
|
-
return response
|
85
|
-
|
86
|
-
async def _unary_stream(
|
87
|
-
self,
|
88
|
-
route: str,
|
89
|
-
request: "IProtoMessage",
|
90
|
-
response_type: Type["T"],
|
91
|
-
*,
|
92
|
-
timeout: Optional[float] = None,
|
93
|
-
deadline: Optional["Deadline"] = None,
|
94
|
-
metadata: Optional[MetadataLike] = None,
|
95
|
-
) -> AsyncIterator["T"]:
|
96
|
-
"""Make a unary request and return the stream response iterator."""
|
97
|
-
async with self.channel.request(
|
98
|
-
route,
|
99
|
-
grpclib.const.Cardinality.UNARY_STREAM,
|
100
|
-
type(request),
|
101
|
-
response_type,
|
102
|
-
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
103
|
-
) as stream:
|
104
|
-
await stream.send_message(request, end=True)
|
105
|
-
async for message in stream:
|
106
|
-
yield message
|
107
|
-
|
108
|
-
async def _stream_unary(
|
109
|
-
self,
|
110
|
-
route: str,
|
111
|
-
request_iterator: MessageSource,
|
112
|
-
request_type: Type["IProtoMessage"],
|
113
|
-
response_type: Type["T"],
|
114
|
-
*,
|
115
|
-
timeout: Optional[float] = None,
|
116
|
-
deadline: Optional["Deadline"] = None,
|
117
|
-
metadata: Optional[MetadataLike] = None,
|
118
|
-
) -> "T":
|
119
|
-
"""Make a stream request and return the response."""
|
120
|
-
async with self.channel.request(
|
121
|
-
route,
|
122
|
-
grpclib.const.Cardinality.STREAM_UNARY,
|
123
|
-
request_type,
|
124
|
-
response_type,
|
125
|
-
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
126
|
-
) as stream:
|
127
|
-
await stream.send_request()
|
128
|
-
await self._send_messages(stream, request_iterator)
|
129
|
-
response = await stream.recv_message()
|
130
|
-
assert response is not None
|
131
|
-
return response
|
132
|
-
|
133
|
-
async def _stream_stream(
|
134
|
-
self,
|
135
|
-
route: str,
|
136
|
-
request_iterator: MessageSource,
|
137
|
-
request_type: Type["IProtoMessage"],
|
138
|
-
response_type: Type["T"],
|
139
|
-
*,
|
140
|
-
timeout: Optional[float] = None,
|
141
|
-
deadline: Optional["Deadline"] = None,
|
142
|
-
metadata: Optional[MetadataLike] = None,
|
143
|
-
) -> AsyncIterator["T"]:
|
144
|
-
"""
|
145
|
-
Make a stream request and return an AsyncIterator to iterate over response
|
146
|
-
messages.
|
147
|
-
"""
|
148
|
-
async with self.channel.request(
|
149
|
-
route,
|
150
|
-
grpclib.const.Cardinality.STREAM_STREAM,
|
151
|
-
request_type,
|
152
|
-
response_type,
|
153
|
-
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
154
|
-
) as stream:
|
155
|
-
await stream.send_request()
|
156
|
-
sending_task = asyncio.ensure_future(self._send_messages(stream, request_iterator))
|
157
|
-
try:
|
158
|
-
async for response in stream:
|
159
|
-
yield response
|
160
|
-
except:
|
161
|
-
sending_task.cancel()
|
162
|
-
raise
|
163
|
-
|
164
|
-
@staticmethod
|
165
|
-
async def _send_messages(stream, messages: MessageSource):
|
166
|
-
if isinstance(messages, AsyncIterable):
|
167
|
-
async for message in messages:
|
168
|
-
await stream.send_message(message)
|
169
|
-
else:
|
170
|
-
for message in messages:
|
171
|
-
await stream.send_message(message)
|
172
|
-
await stream.end()
|
@@ -1,32 +0,0 @@
|
|
1
|
-
from abc import ABC
|
2
|
-
from collections.abc import AsyncIterable
|
3
|
-
from typing import (
|
4
|
-
Any,
|
5
|
-
Callable,
|
6
|
-
)
|
7
|
-
|
8
|
-
import grpclib
|
9
|
-
import grpclib.server
|
10
|
-
|
11
|
-
|
12
|
-
class ServiceBase(ABC):
|
13
|
-
"""
|
14
|
-
Base class for async gRPC servers.
|
15
|
-
"""
|
16
|
-
|
17
|
-
async def _call_rpc_handler_server_stream(
|
18
|
-
self,
|
19
|
-
handler: Callable,
|
20
|
-
stream: grpclib.server.Stream,
|
21
|
-
request: Any,
|
22
|
-
) -> None:
|
23
|
-
response_iter = handler(request)
|
24
|
-
# check if response is actually an AsyncIterator
|
25
|
-
# this might be false if the method just returns without
|
26
|
-
# yielding at least once
|
27
|
-
# in that case, we just interpret it as an empty iterator
|
28
|
-
if isinstance(response_iter, AsyncIterable):
|
29
|
-
async for response_message in response_iter:
|
30
|
-
await stream.send_message(response_message)
|
31
|
-
else:
|
32
|
-
response_iter.close()
|
File without changes
|
@@ -1,190 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
from typing import (
|
3
|
-
AsyncIterable,
|
4
|
-
AsyncIterator,
|
5
|
-
Iterable,
|
6
|
-
Optional,
|
7
|
-
TypeVar,
|
8
|
-
Union,
|
9
|
-
)
|
10
|
-
|
11
|
-
T = TypeVar("T")
|
12
|
-
|
13
|
-
|
14
|
-
class ChannelClosed(Exception):
|
15
|
-
"""
|
16
|
-
An exception raised on an attempt to send through a closed channel
|
17
|
-
"""
|
18
|
-
|
19
|
-
|
20
|
-
class ChannelDone(Exception):
|
21
|
-
"""
|
22
|
-
An exception raised on an attempt to send receive from a channel that is both closed
|
23
|
-
and empty.
|
24
|
-
"""
|
25
|
-
|
26
|
-
|
27
|
-
class AsyncChannel(AsyncIterable[T]):
|
28
|
-
"""
|
29
|
-
A buffered async channel for sending items between coroutines with FIFO ordering.
|
30
|
-
|
31
|
-
This makes decoupled bidirectional steaming gRPC requests easy if used like:
|
32
|
-
|
33
|
-
.. code-block:: python
|
34
|
-
client = GeneratedStub(grpclib_chan)
|
35
|
-
request_channel = await AsyncChannel()
|
36
|
-
# We can start be sending all the requests we already have
|
37
|
-
await request_channel.send_from([RequestObject(...), RequestObject(...)])
|
38
|
-
async for response in client.rpc_call(request_channel):
|
39
|
-
# The response iterator will remain active until the connection is closed
|
40
|
-
...
|
41
|
-
# More items can be sent at any time
|
42
|
-
await request_channel.send(RequestObject(...))
|
43
|
-
...
|
44
|
-
# The channel must be closed to complete the gRPC connection
|
45
|
-
request_channel.close()
|
46
|
-
|
47
|
-
Items can be sent through the channel by either:
|
48
|
-
- providing an iterable to the send_from method
|
49
|
-
- passing them to the send method one at a time
|
50
|
-
|
51
|
-
Items can be received from the channel by either:
|
52
|
-
- iterating over the channel with a for loop to get all items
|
53
|
-
- calling the receive method to get one item at a time
|
54
|
-
|
55
|
-
If the channel is empty then receivers will wait until either an item appears or the
|
56
|
-
channel is closed.
|
57
|
-
|
58
|
-
Once the channel is closed then subsequent attempt to send through the channel will
|
59
|
-
fail with a ChannelClosed exception.
|
60
|
-
|
61
|
-
When th channel is closed and empty then it is done, and further attempts to receive
|
62
|
-
from it will fail with a ChannelDone exception
|
63
|
-
|
64
|
-
If multiple coroutines receive from the channel concurrently, each item sent will be
|
65
|
-
received by only one of the receivers.
|
66
|
-
|
67
|
-
:param source:
|
68
|
-
An optional iterable will items that should be sent through the channel
|
69
|
-
immediately.
|
70
|
-
:param buffer_limit:
|
71
|
-
Limit the number of items that can be buffered in the channel, A value less than
|
72
|
-
1 implies no limit. If the channel is full then attempts to send more items will
|
73
|
-
result in the sender waiting until an item is received from the channel.
|
74
|
-
:param close:
|
75
|
-
If set to True then the channel will automatically close after exhausting source
|
76
|
-
or immediately if no source is provided.
|
77
|
-
"""
|
78
|
-
|
79
|
-
def __init__(self, *, buffer_limit: int = 0, close: bool = False):
|
80
|
-
self._queue: asyncio.Queue[T] = asyncio.Queue(buffer_limit)
|
81
|
-
self._closed = False
|
82
|
-
self._waiting_receivers: int = 0
|
83
|
-
# Track whether flush has been invoked so it can only happen once
|
84
|
-
self._flushed = False
|
85
|
-
|
86
|
-
def __aiter__(self) -> AsyncIterator[T]:
|
87
|
-
return self
|
88
|
-
|
89
|
-
async def __anext__(self) -> T:
|
90
|
-
if self.done():
|
91
|
-
raise StopAsyncIteration
|
92
|
-
self._waiting_receivers += 1
|
93
|
-
try:
|
94
|
-
result = await self._queue.get()
|
95
|
-
if result is self.__flush:
|
96
|
-
raise StopAsyncIteration
|
97
|
-
return result
|
98
|
-
finally:
|
99
|
-
self._waiting_receivers -= 1
|
100
|
-
self._queue.task_done()
|
101
|
-
|
102
|
-
def closed(self) -> bool:
|
103
|
-
"""
|
104
|
-
Returns True if this channel is closed and no-longer accepting new items
|
105
|
-
"""
|
106
|
-
return self._closed
|
107
|
-
|
108
|
-
def done(self) -> bool:
|
109
|
-
"""
|
110
|
-
Check if this channel is done.
|
111
|
-
|
112
|
-
:return: True if this channel is closed and and has been drained of items in
|
113
|
-
which case any further attempts to receive an item from this channel will raise
|
114
|
-
a ChannelDone exception.
|
115
|
-
"""
|
116
|
-
# After close the channel is not yet done until there is at least one waiting
|
117
|
-
# receiver per enqueued item.
|
118
|
-
return self._closed and self._queue.qsize() <= self._waiting_receivers
|
119
|
-
|
120
|
-
async def send_from(self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False) -> "AsyncChannel[T]":
|
121
|
-
"""
|
122
|
-
Iterates the given [Async]Iterable and sends all the resulting items.
|
123
|
-
If close is set to True then subsequent send calls will be rejected with a
|
124
|
-
ChannelClosed exception.
|
125
|
-
:param source: an iterable of items to send
|
126
|
-
:param close:
|
127
|
-
if True then the channel will be closed after the source has been exhausted
|
128
|
-
|
129
|
-
"""
|
130
|
-
if self._closed:
|
131
|
-
raise ChannelClosed("Cannot send through a closed channel")
|
132
|
-
if isinstance(source, AsyncIterable):
|
133
|
-
async for item in source:
|
134
|
-
await self._queue.put(item)
|
135
|
-
else:
|
136
|
-
for item in source:
|
137
|
-
await self._queue.put(item)
|
138
|
-
if close:
|
139
|
-
# Complete the closing process
|
140
|
-
self.close()
|
141
|
-
return self
|
142
|
-
|
143
|
-
async def send(self, item: T) -> "AsyncChannel[T]":
|
144
|
-
"""
|
145
|
-
Send a single item over this channel.
|
146
|
-
:param item: The item to send
|
147
|
-
"""
|
148
|
-
if self._closed:
|
149
|
-
raise ChannelClosed("Cannot send through a closed channel")
|
150
|
-
await self._queue.put(item)
|
151
|
-
return self
|
152
|
-
|
153
|
-
async def receive(self) -> Optional[T]:
|
154
|
-
"""
|
155
|
-
Returns the next item from this channel when it becomes available,
|
156
|
-
or None if the channel is closed before another item is sent.
|
157
|
-
:return: An item from the channel
|
158
|
-
"""
|
159
|
-
if self.done():
|
160
|
-
raise ChannelDone("Cannot receive from a closed channel")
|
161
|
-
self._waiting_receivers += 1
|
162
|
-
try:
|
163
|
-
result = await self._queue.get()
|
164
|
-
if result is self.__flush:
|
165
|
-
return None
|
166
|
-
return result
|
167
|
-
finally:
|
168
|
-
self._waiting_receivers -= 1
|
169
|
-
self._queue.task_done()
|
170
|
-
|
171
|
-
def close(self):
|
172
|
-
"""
|
173
|
-
Close this channel to new items
|
174
|
-
"""
|
175
|
-
self._closed = True
|
176
|
-
asyncio.ensure_future(self._flush_queue())
|
177
|
-
|
178
|
-
async def _flush_queue(self):
|
179
|
-
"""
|
180
|
-
To be called after the channel is closed. Pushes a number of self.__flush
|
181
|
-
objects to the queue to ensure no waiting consumers get deadlocked.
|
182
|
-
"""
|
183
|
-
if not self._flushed:
|
184
|
-
self._flushed = True
|
185
|
-
deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize())
|
186
|
-
for _ in range(deadlocked_receivers):
|
187
|
-
await self._queue.put(self.__flush)
|
188
|
-
|
189
|
-
# A special signal object for flushing the queue when the channel is closed
|
190
|
-
__flush = object()
|
File without changes
|
File without changes
|