betterproto2-compiler 0.2.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- betterproto2_compiler/__init__.py +0 -0
- betterproto2_compiler/casing.py +140 -0
- betterproto2_compiler/compile/__init__.py +0 -0
- betterproto2_compiler/compile/importing.py +180 -0
- betterproto2_compiler/compile/naming.py +21 -0
- betterproto2_compiler/known_types/__init__.py +14 -0
- betterproto2_compiler/known_types/any.py +36 -0
- betterproto2_compiler/known_types/duration.py +25 -0
- betterproto2_compiler/known_types/timestamp.py +45 -0
- betterproto2_compiler/lib/__init__.py +0 -0
- betterproto2_compiler/lib/google/__init__.py +0 -0
- betterproto2_compiler/lib/google/protobuf/__init__.py +3338 -0
- betterproto2_compiler/lib/google/protobuf/compiler/__init__.py +235 -0
- betterproto2_compiler/lib/message_pool.py +3 -0
- betterproto2_compiler/plugin/__init__.py +3 -0
- betterproto2_compiler/plugin/__main__.py +3 -0
- betterproto2_compiler/plugin/compiler.py +70 -0
- betterproto2_compiler/plugin/main.py +47 -0
- betterproto2_compiler/plugin/models.py +643 -0
- betterproto2_compiler/plugin/module_validation.py +156 -0
- betterproto2_compiler/plugin/parser.py +272 -0
- betterproto2_compiler/plugin/plugin.bat +2 -0
- betterproto2_compiler/plugin/typing_compiler.py +163 -0
- betterproto2_compiler/py.typed +0 -0
- betterproto2_compiler/settings.py +9 -0
- betterproto2_compiler/templates/header.py.j2 +59 -0
- betterproto2_compiler/templates/template.py.j2 +258 -0
- betterproto2_compiler-0.2.0.dist-info/LICENSE.md +22 -0
- betterproto2_compiler-0.2.0.dist-info/METADATA +35 -0
- betterproto2_compiler-0.2.0.dist-info/RECORD +32 -0
- betterproto2_compiler-0.2.0.dist-info/WHEEL +4 -0
- betterproto2_compiler-0.2.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,156 @@
|
|
1
|
+
import re
|
2
|
+
from collections import defaultdict
|
3
|
+
from collections.abc import Iterator
|
4
|
+
from dataclasses import (
|
5
|
+
dataclass,
|
6
|
+
field,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
class ModuleValidator:
    """Detect top-level name collisions in a generated Python module.

    The module is scanned line by line and every top-level binding is recorded:
    names brought in by ``import`` / ``from ... import`` (``as`` aliases count under
    their alias), ``class``, ``def`` and ``async def`` names, and simple
    ``name = ...`` assignments. Indented lines, blank lines, comments, decorators
    and top-level docstrings are skipped. This is a lightweight line scanner, not a
    Python parser.

    After :meth:`validate`, :attr:`collisions` maps each name bound more than once
    to the ``(line_number, line)`` pairs where it was bound (0-based line numbers).
    """

    # Lines of the module to scan. A blank line is recognised as "\n", so lines
    # presumably keep their trailing newline — TODO confirm against callers.
    line_iterator: Iterator[str]
    # 0-based index of the line currently being examined.
    line_number: int = field(init=False, default=0)

    # name -> every (line_number, source line) where that name is bound.
    collisions: dict[str, list[tuple[int, str]]] = field(init=False, default_factory=lambda: defaultdict(list))

    def add_import(self, imp: str, number: int, full_line: str) -> None:
        """Record that the name ``imp`` is bound on line ``number`` (``full_line`` is kept for reporting)."""
        self.collisions[imp].append((number, full_line))

    def process_import(self, imp: str) -> str:
        """Reduce one comma-separated fragment of an import list to the name it binds.

        ``"a as b"`` becomes ``"b"``; surrounding whitespace and a trailing
        line-continuation backslash are dropped. An empty fragment yields ``""``.
        """
        if " as " in imp:
            imp = imp[imp.index(" as ") + 4 :]

        imp = imp.strip().removesuffix("\\").strip()
        assert " " not in imp, imp
        return imp

    @staticmethod
    def _strip_comment(line: str) -> str:
        """Drop a trailing ``#`` comment (import statements contain no string literals)."""
        return line.split("#", 1)[0]

    def _names_after_import(self, line: str) -> str:
        """Return the part of an import statement after the ``import`` keyword, minus any comment."""
        # Word boundaries stop module names such as ``importlib`` from matching the keyword.
        return self._strip_comment(re.split(r"\bimport\b", line, maxsplit=1)[1])

    def _add_names(self, text: str, full_line: str) -> None:
        """Record every non-empty name in the comma-separated ``text`` at the current line."""
        for fragment in text.split(","):
            name = self.process_import(fragment)
            if name:
                self.add_import(name, self.line_number, full_line)

    def evaluate_multiline_import(self, line: str) -> None:
        """Evaluate an import statement that may span several lines, starting at ``line``.

        Both parenthesised import lists and backslash continuations are supported.
        Continuation lines are consumed from :attr:`line_iterator`, advancing
        :attr:`line_number` once per consumed line, and the closing line is
        evaluated as well.
        """
        full_line = line
        line = self._names_after_import(line)
        parenthesized = "(" in line
        if parenthesized:
            # Remove everything up to and including the opening parenthesis.
            line = line[line.index("(") + 1 :]

        def continues(text: str) -> bool:
            """Whether the statement goes on past ``text`` (comment already removed)."""
            if parenthesized:
                return ")" not in text
            return text.rstrip().endswith("\\")

        while continues(line):
            self._add_names(line, full_line)
            full_line = next(self.line_iterator)
            line = self._strip_comment(full_line)
            self.line_number += 1

        # Closing line: for parenthesised lists keep what precedes ")"; for backslash
        # continuations the whole (comment-free) line still holds names.
        if parenthesized:
            line = line[: line.index(")")]
        self._add_names(line, full_line)

    def evaluate_import(self, line: str) -> None:
        """Record the names bound by a single-line import statement."""
        self._add_names(self._names_after_import(line), line)

    def next(self) -> None:
        """Consume and evaluate one logical line, recording any top-level name it binds.

        Raises ``StopIteration`` (from :attr:`line_iterator`) once the input is exhausted.
        """
        line = next(self.line_iterator)

        # Indented code, blank lines, comments and decorators bind no new top-level name.
        if line == "\n" or line.startswith((" ", "\t", "#", "@")):
            self.line_number += 1
            return

        # Skip a top-level docstring, which may span several lines.
        if line.startswith(('"""', "'''")):
            quote = line[:3]
            rest = line[3:]
            while quote not in rest:
                rest = next(self.line_iterator)
                self.line_number += 1
            # Count the closing line too so later names keep correct line numbers.
            self.line_number += 1
            return

        if line.startswith(("from ", "import ")):
            if "(" in line or "\\" in line:
                self.evaluate_multiline_import(line)
            else:
                self.evaluate_import(line)
        elif found := re.search(r"^(?:class|(?:async )?def) (\w+)", line):
            self.add_import(found.group(1), self.line_number, line)
        elif found := re.search(r"^(\w+)\s*=", line):
            self.add_import(found.group(1), self.line_number, line)

        self.line_number += 1

    def validate(self) -> bool:
        """Scan the whole module and report whether it is free of name collisions.

        Afterwards :attr:`collisions` only keeps names bound more than once.

        :return: ``True`` if no top-level name is bound twice.
        """
        try:
            while True:
                self.next()
        except StopIteration:
            pass  # Input exhausted: scanning is complete.

        # Keep only names bound more than once.
        self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}

        return not self.collisions
|
@@ -0,0 +1,272 @@
|
|
1
|
+
import pathlib
|
2
|
+
import sys
|
3
|
+
from collections.abc import Generator
|
4
|
+
|
5
|
+
from betterproto2_compiler.lib.google.protobuf import (
|
6
|
+
DescriptorProto,
|
7
|
+
EnumDescriptorProto,
|
8
|
+
FileDescriptorProto,
|
9
|
+
ServiceDescriptorProto,
|
10
|
+
)
|
11
|
+
from betterproto2_compiler.lib.google.protobuf.compiler import (
|
12
|
+
CodeGeneratorRequest,
|
13
|
+
CodeGeneratorResponse,
|
14
|
+
CodeGeneratorResponseFeature,
|
15
|
+
CodeGeneratorResponseFile,
|
16
|
+
)
|
17
|
+
from betterproto2_compiler.settings import Settings
|
18
|
+
|
19
|
+
from .compiler import outputfile_compiler
|
20
|
+
from .models import (
|
21
|
+
EnumDefinitionCompiler,
|
22
|
+
FieldCompiler,
|
23
|
+
MapEntryCompiler,
|
24
|
+
MessageCompiler,
|
25
|
+
OneofCompiler,
|
26
|
+
OneOfFieldCompiler,
|
27
|
+
OutputTemplate,
|
28
|
+
PluginRequestCompiler,
|
29
|
+
ServiceCompiler,
|
30
|
+
ServiceMethodCompiler,
|
31
|
+
is_map,
|
32
|
+
is_oneof,
|
33
|
+
)
|
34
|
+
from .typing_compiler import (
|
35
|
+
DirectImportTypingCompiler,
|
36
|
+
NoTyping310TypingCompiler,
|
37
|
+
TypingImportTypingCompiler,
|
38
|
+
)
|
39
|
+
|
40
|
+
|
41
|
+
def traverse(
    proto_file: FileDescriptorProto,
) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
    """Yield every enum and message declared in ``proto_file`` together with its path.

    Top-level enums come first, then top-level messages; each message is followed
    by its nested enums and then its nested messages (depth first). Paths are
    source-code-info paths: ``[5, i]`` / ``[4, i]`` for top-level enums / messages,
    then ``4`` (nested enums) or ``3`` (nested messages) plus the index per level.

    Side effect: nested descriptors are renamed in place to their dotted name
    (``Outer.Inner``); map-entry messages keep their own name.
    """
    # TODO: keep information about the nested hierarchy instead of flattening it.

    def _walk(
        base_path: list[int],
        descriptors: list[EnumDescriptorProto] | list[DescriptorProto],
        parent_name: str = "",
    ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]:
        for position, descriptor in enumerate(descriptors):
            # TODO: don't change the name, but include the full name in the returned tuple.
            is_map_entry = (
                isinstance(descriptor, DescriptorProto)
                and bool(descriptor.options)
                and bool(descriptor.options.map_entry)
            )
            if parent_name and not is_map_entry:
                qualified_name = f"{parent_name}.{descriptor.name}"
            else:
                qualified_name = descriptor.name
            descriptor.name = qualified_name

            yield descriptor, [*base_path, position]

            if isinstance(descriptor, DescriptorProto):
                # Nested enums before nested messages.
                yield from _walk([*base_path, position, 4], descriptor.enum_type, qualified_name)
                yield from _walk([*base_path, position, 3], descriptor.nested_type, qualified_name)

    yield from _walk([5], proto_file.enum_type)
    yield from _walk([4], proto_file.message_type)
|
65
|
+
|
66
|
+
|
67
|
+
def get_settings(plugin_options: list[str]) -> Settings:
|
68
|
+
# Gather any typing generation options.
|
69
|
+
typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")]
|
70
|
+
|
71
|
+
if len(typing_opts) > 1:
|
72
|
+
raise ValueError("Multiple typing options provided")
|
73
|
+
|
74
|
+
# Set the compiler type.
|
75
|
+
typing_opt = typing_opts[0] if typing_opts else "direct"
|
76
|
+
if typing_opt == "direct":
|
77
|
+
typing_compiler = DirectImportTypingCompiler()
|
78
|
+
elif typing_opt == "root":
|
79
|
+
typing_compiler = TypingImportTypingCompiler()
|
80
|
+
elif typing_opt == "310":
|
81
|
+
typing_compiler = NoTyping310TypingCompiler()
|
82
|
+
else:
|
83
|
+
raise ValueError("Invalid typing option provided")
|
84
|
+
|
85
|
+
return Settings(
|
86
|
+
typing_compiler=typing_compiler,
|
87
|
+
pydantic_dataclasses="pydantic_dataclasses" in plugin_options,
|
88
|
+
)
|
89
|
+
|
90
|
+
|
91
|
+
def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
    """Translate a protoc plugin request into the files betterproto2 generates.

    One ``<package path>/__init__.py`` module is produced per protobuf package,
    an empty ``__init__.py`` is added for enclosing directories (see the note
    below), and a shared ``message_pool.py`` is written at the output root.

    :param request: The request protoc sent to the plugin.
    :return: A response listing every generated file.
    :raises ValueError: If the plugin options are invalid (raised by ``get_settings``).
    """
    response = CodeGeneratorResponse(supported_features=CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL)

    # Options are comma-separated. Ignore surrounding whitespace and empty entries
    # (e.g. "typing.310, pydantic_dataclasses" or a trailing comma) so that such
    # options are honoured instead of silently failing to match.
    raw_options = request.parameter.split(",") if request.parameter else []
    plugin_options = [option.strip() for option in raw_options if option.strip()]
    settings = get_settings(plugin_options)

    request_data = PluginRequestCompiler(plugin_request_obj=request)

    # Group input files by protobuf package: each package becomes one output module.
    for proto_file in request.proto_file:
        output_package_name = proto_file.package
        if output_package_name not in request_data.output_packages:
            # Create a new output if there is no output for this package
            request_data.output_packages[output_package_name] = OutputTemplate(
                parent_request=request_data, package_proto_obj=proto_file, settings=settings
            )
        # Add this input file to the output corresponding to this package
        request_data.output_packages[output_package_name].input_files.append(proto_file)

    # Read Messages and Enums before Services so that services can resolve
    # references to their input/output messages.
    for output_package in request_data.output_packages.values():
        for proto_input_file in output_package.input_files:
            for item, path in traverse(proto_input_file):
                read_protobuf_type(
                    source_file=proto_input_file,
                    item=item,
                    path=path,
                    output_package=output_package,
                )

    # Read Services
    for output_package in request_data.output_packages.values():
        for proto_input_file in output_package.input_files:
            for index, service in enumerate(proto_input_file.service):
                read_protobuf_service(proto_input_file, service, index, output_package)

    # The whole hierarchy is known: run pre-computations before rendering.
    for package in request_data.output_packages.values():
        for message in package.messages.values():
            for message_field in message.fields:
                message_field.ready()
            message.ready()
        for enum in package.enums.values():
            enum.ready()
        for service in package.services.values():
            for method in service.methods:
                method.ready()
            service.ready()

    # Generate one module per package.
    output_paths: set[pathlib.Path] = set()
    for output_package_name, output_package in request_data.output_packages.items():
        output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
        output_paths.add(output_path)

        response.file.append(
            CodeGeneratorResponseFile(
                name=str(output_path),
                # Render and then format the output file
                content=outputfile_compiler(output_file=output_package),
            ),
        )

    # Make each output directory a package with an __init__ file.
    # NOTE(review): ``exists()`` is resolved against the plugin's current working
    # directory, not protoc's output directory — confirm this is intended.
    init_files = {
        directory.joinpath("__init__.py")
        for path in output_paths
        for directory in path.parents
        if not directory.joinpath("__init__.py").exists()
    } - output_paths

    # Sorted so the response does not depend on set (hash) iteration order.
    for init_file in sorted(init_files):
        response.file.append(CodeGeneratorResponseFile(name=str(init_file)))

    response.file.append(
        CodeGeneratorResponseFile(
            name="message_pool.py", content="import betterproto2\n\ndefault_message_pool = betterproto2.MessagePool()\n"
        )
    )

    for written_path in sorted(output_paths.union(init_files)):
        print(f"Writing {written_path}", file=sys.stderr)

    return response
|
177
|
+
|
178
|
+
|
179
|
+
def read_protobuf_type(
    item: DescriptorProto | EnumDescriptorProto,
    path: list[int],
    source_file: "FileDescriptorProto",
    output_package: OutputTemplate,
) -> None:
    """Register one message or enum descriptor with ``output_package``.

    A message gets a ``MessageCompiler`` plus one compiler per field and per oneof.
    Synthetic map-entry messages are skipped because maps are rendered as dicts.
    An enum gets an ``EnumDefinitionCompiler``. Any other item is ignored.

    :param item: The descriptor to read.
    :param path: Its source-code-info path inside ``source_file``.
    :param source_file: The file that declares ``item``.
    :param output_package: The package output that receives the compiled item.
    """
    if isinstance(item, DescriptorProto):
        if item.options and item.options.map_entry:
            return  # map<K, V> entry messages: we just use dicts

        message_data = MessageCompiler(
            source_file=source_file,
            output_file=output_package,
            proto_obj=item,
            path=path,
        )
        output_package.messages[message_data.proto_name] = message_data

        # Field 2 of DescriptorProto is ``field``.
        for field_index, field_proto in enumerate(item.field):
            if is_map(field_proto, item):
                compiler_cls = MapEntryCompiler
            elif is_oneof(field_proto):
                compiler_cls = OneOfFieldCompiler
            else:
                compiler_cls = FieldCompiler

            message_data.fields.append(
                compiler_cls(
                    source_file=source_file,
                    message=message_data,
                    proto_obj=field_proto,
                    path=[*path, 2, field_index],
                    typing_compiler=output_package.settings.typing_compiler,
                )
            )

        # Field 8 of DescriptorProto is ``oneof_decl``.
        for oneof_index, oneof_proto in enumerate(item.oneof_decl):
            message_data.oneofs.append(
                OneofCompiler(
                    source_file=source_file,
                    path=[*path, 8, oneof_index],
                    proto_obj=oneof_proto,
                )
            )

    elif isinstance(item, EnumDescriptorProto):
        enum_data = EnumDefinitionCompiler(
            source_file=source_file,
            output_file=output_package,
            proto_obj=item,
            path=path,
        )
        output_package.enums[enum_data.proto_name] = enum_data
|
248
|
+
|
249
|
+
|
250
|
+
def read_protobuf_service(
    source_file: FileDescriptorProto,
    service: ServiceDescriptorProto,
    index: int,
    output_package: OutputTemplate,
) -> None:
    """Register a service and each of its methods with ``output_package``.

    :param source_file: The file declaring the service.
    :param service: The service descriptor.
    :param index: Position of the service in ``source_file.service``.
    :param output_package: The package output that receives the service.
    """
    # Source-code-info path: field 6 of FileDescriptorProto is ``service``.
    service_data = ServiceCompiler(
        source_file=source_file,
        output_file=output_package,
        proto_obj=service,
        path=[6, index],
    )
    service_data.output_file.services[service_data.proto_name] = service_data

    # Field 2 of ServiceDescriptorProto is ``method``.
    for method_index, method_proto in enumerate(service.method):
        method_data = ServiceMethodCompiler(
            source_file=source_file,
            parent=service_data,
            proto_obj=method_proto,
            path=[6, index, 2, method_index],
        )
        service_data.methods.append(method_data)
|
@@ -0,0 +1,163 @@
|
|
1
|
+
import abc
|
2
|
+
import builtins
|
3
|
+
from collections import defaultdict
|
4
|
+
from collections.abc import Iterator
|
5
|
+
from dataclasses import (
|
6
|
+
dataclass,
|
7
|
+
field,
|
8
|
+
)
|
9
|
+
|
10
|
+
|
11
|
+
class TypingCompiler(metaclass=abc.ABCMeta):
|
12
|
+
@abc.abstractmethod
|
13
|
+
def optional(self, type_: str) -> str:
|
14
|
+
raise NotImplementedError
|
15
|
+
|
16
|
+
@abc.abstractmethod
|
17
|
+
def list(self, type_: str) -> str:
|
18
|
+
raise NotImplementedError
|
19
|
+
|
20
|
+
@abc.abstractmethod
|
21
|
+
def dict(self, key: str, value: str) -> str:
|
22
|
+
raise NotImplementedError
|
23
|
+
|
24
|
+
@abc.abstractmethod
|
25
|
+
def union(self, *types: str) -> str:
|
26
|
+
raise NotImplementedError
|
27
|
+
|
28
|
+
@abc.abstractmethod
|
29
|
+
def iterable(self, type_: str) -> str:
|
30
|
+
raise NotImplementedError
|
31
|
+
|
32
|
+
@abc.abstractmethod
|
33
|
+
def async_iterable(self, type_: str) -> str:
|
34
|
+
raise NotImplementedError
|
35
|
+
|
36
|
+
@abc.abstractmethod
|
37
|
+
def async_iterator(self, type_: str) -> str:
|
38
|
+
raise NotImplementedError
|
39
|
+
|
40
|
+
@abc.abstractmethod
|
41
|
+
def imports(self) -> builtins.dict[str, set[str] | None]:
|
42
|
+
"""
|
43
|
+
Returns either the direct import as a key with none as value, or a set of
|
44
|
+
values to import from the key.
|
45
|
+
"""
|
46
|
+
raise NotImplementedError
|
47
|
+
|
48
|
+
def import_lines(self) -> Iterator:
|
49
|
+
imports = self.imports()
|
50
|
+
for key, value in imports.items():
|
51
|
+
if value is None:
|
52
|
+
yield f"import {key}"
|
53
|
+
else:
|
54
|
+
yield f"from {key} import ("
|
55
|
+
for v in sorted(value):
|
56
|
+
yield f" {v},"
|
57
|
+
yield ")"
|
58
|
+
|
59
|
+
|
60
|
+
@dataclass
class DirectImportTypingCompiler(TypingCompiler):
    """Spell annotations with bare ``typing`` names, e.g. ``Optional[int]``.

    Every name used is recorded under the ``"typing"`` module so that the header
    can emit ``from typing import (...)``.
    """

    # module -> names to import from it, filled in as annotations are produced.
    _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))

    def _generic(self, name: str, *args: str) -> str:
        """Record ``typing.<name>`` as used and return ``name[arg, ...]``."""
        self._imports["typing"].add(name)
        return f"{name}[{', '.join(args)}]"

    def optional(self, type_: str) -> str:
        return self._generic("Optional", type_)

    def list(self, type_: str) -> str:
        return self._generic("List", type_)

    def dict(self, key: str, value: str) -> str:
        return self._generic("Dict", key, value)

    def union(self, *types: str) -> str:
        return self._generic("Union", *types)

    def iterable(self, type_: str) -> str:
        return self._generic("Iterable", type_)

    def async_iterable(self, type_: str) -> str:
        return self._generic("AsyncIterable", type_)

    def async_iterator(self, type_: str) -> str:
        return self._generic("AsyncIterator", type_)

    def imports(self) -> builtins.dict[str, set[str] | None]:
        # Entries are only created by adding a name, so sets are normally non-empty;
        # an empty one would be reported as None (a plain ``import <module>``).
        return {module: (names or None) for module, names in self._imports.items()}
|
94
|
+
|
95
|
+
|
96
|
+
@dataclass
class TypingImportTypingCompiler(TypingCompiler):
    """Spell annotations through the ``typing`` module, e.g. ``typing.Optional[int]``.

    Only ``import typing`` is ever needed, and only once an annotation was produced.
    """

    # Becomes True as soon as any annotation has been produced.
    _imported: bool = False

    def _qualified(self, name: str, *args: str) -> str:
        """Mark ``typing`` as needed and return ``typing.name[arg, ...]``."""
        self._imported = True
        return f"typing.{name}[{', '.join(args)}]"

    def optional(self, type_: str) -> str:
        return self._qualified("Optional", type_)

    def list(self, type_: str) -> str:
        return self._qualified("List", type_)

    def dict(self, key: str, value: str) -> str:
        return self._qualified("Dict", key, value)

    def union(self, *types: str) -> str:
        return self._qualified("Union", *types)

    def iterable(self, type_: str) -> str:
        return self._qualified("Iterable", type_)

    def async_iterable(self, type_: str) -> str:
        return self._qualified("AsyncIterable", type_)

    def async_iterator(self, type_: str) -> str:
        return self._qualified("AsyncIterator", type_)

    def imports(self) -> builtins.dict[str, set[str] | None]:
        return {"typing": None} if self._imported else {}
|
132
|
+
|
133
|
+
|
134
|
+
@dataclass
class NoTyping310TypingCompiler(TypingCompiler):
    """Spell annotations with Python 3.10 syntax: builtin generics and ``X | Y`` unions.

    Only the iterable/async types need an import, taken from ``collections.abc``.
    """

    # module -> names to import from it, filled in as annotations are produced.
    _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))

    def _abc_generic(self, name: str, type_: str) -> str:
        """Record ``collections.abc.<name>`` as used and return ``name[type_]``."""
        self._imports["collections.abc"].add(name)
        return f"{name}[{type_}]"

    def optional(self, type_: str) -> str:
        return f"{type_} | None"

    def list(self, type_: str) -> str:
        return f"list[{type_}]"

    def dict(self, key: str, value: str) -> str:
        return f"dict[{key}, {value}]"

    def union(self, *types: str) -> str:
        return " | ".join(types)

    def iterable(self, type_: str) -> str:
        return self._abc_generic("Iterable", type_)

    def async_iterable(self, type_: str) -> str:
        return self._abc_generic("AsyncIterable", type_)

    def async_iterator(self, type_: str) -> str:
        return self._abc_generic("AsyncIterator", type_)

    def imports(self) -> builtins.dict[str, set[str] | None]:
        # Entries are only created by adding a name, so sets are normally non-empty;
        # an empty one would be reported as None (a plain ``import <module>``).
        return {module: (names or None) for module, names in self._imports.items()}
|
File without changes
|
@@ -0,0 +1,59 @@
|
|
1
|
+
{#
  Header of every generated package module: provenance comment, __all__, and
  all the imports needed for this file. The useless imports will be removed by Ruff.
#}

# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto2
# This file has been @generated

__all__ = (
{% for _, enum in output_file.enums|dictsort(by="key") %}
    "{{ enum.py_name }}",
{%- endfor -%}
{% for _, message in output_file.messages|dictsort(by="key") %}
    "{{ message.py_name }}",
{%- endfor -%}
{% for _, service in output_file.services|dictsort(by="key") %}
    "{{ service.py_name }}Stub",
    "{{ service.py_name }}Base",
{%- endfor -%}
)

import builtins
import datetime
import warnings

{% if output_file.settings.pydantic_dataclasses %}
from pydantic.dataclasses import dataclass
from pydantic import model_validator
{%- else -%}
from dataclasses import dataclass
{% endif %}

{% set typing_imports = output_file.settings.typing_compiler.imports() %}{# The typing compiler instance is shared by all packages of a request, so this may list names used elsewhere (Ruff drops unused ones). NOTE(review): it only knows names used by already-rendered code — confirm the body is rendered before this header. #}
{% if typing_imports %}
{% for line in output_file.settings.typing_compiler.import_lines() %}
{{ line }}
{% endfor %}
{% endif %}

import betterproto2
{% if output_file.services %}{# gRPC imports are only emitted for packages that define services. #}
from betterproto2.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}

from typing import TYPE_CHECKING

{# Import the message pool of the generated code. message_pool.py is written at the output root, so a package with N dotted components needs N + 1 leading dots to reach it. #}
{% if output_file.package %}
from {{ "." * output_file.package.count(".") }}..message_pool import default_message_pool
{% else %}
from .message_pool import default_message_pool
{% endif %}

if TYPE_CHECKING:
    import grpclib.server
    from betterproto2.grpc.grpclib_client import MetadataLike
    from grpclib.metadata import Deadline

betterproto2.check_compiler_version("{{ version }}")
|