pydantic-rpc 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_rpc/__init__.py +10 -0
- pydantic_rpc/core.py +1443 -453
- pydantic_rpc/mcp/__init__.py +5 -0
- pydantic_rpc/mcp/converter.py +115 -0
- pydantic_rpc/mcp/exporter.py +283 -0
- {pydantic_rpc-0.6.1.dist-info → pydantic_rpc-0.7.0.dist-info}/METADATA +149 -10
- pydantic_rpc-0.7.0.dist-info/RECORD +11 -0
- pydantic_rpc-0.6.1.dist-info/RECORD +0 -8
- {pydantic_rpc-0.6.1.dist-info → pydantic_rpc-0.7.0.dist-info}/WHEEL +0 -0
- {pydantic_rpc-0.6.1.dist-info → pydantic_rpc-0.7.0.dist-info}/entry_points.txt +0 -0
- {pydantic_rpc-0.6.1.dist-info → pydantic_rpc-0.7.0.dist-info}/licenses/LICENSE +0 -0
pydantic_rpc/core.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import annotated_types
|
|
2
1
|
import asyncio
|
|
2
|
+
import datetime
|
|
3
3
|
import enum
|
|
4
4
|
import importlib.util
|
|
5
5
|
import inspect
|
|
@@ -8,33 +8,37 @@ import signal
|
|
|
8
8
|
import sys
|
|
9
9
|
import time
|
|
10
10
|
import types
|
|
11
|
-
import
|
|
11
|
+
from typing import Union
|
|
12
|
+
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
12
13
|
from concurrent import futures
|
|
14
|
+
from pathlib import Path
|
|
13
15
|
from posixpath import basename
|
|
14
16
|
from typing import (
|
|
15
|
-
|
|
16
|
-
|
|
17
|
+
Any,
|
|
18
|
+
TypeAlias,
|
|
17
19
|
get_args,
|
|
18
20
|
get_origin,
|
|
19
|
-
|
|
20
|
-
|
|
21
|
+
cast,
|
|
22
|
+
TypeGuard,
|
|
21
23
|
)
|
|
22
|
-
from collections.abc import AsyncIterator
|
|
23
24
|
|
|
25
|
+
import annotated_types
|
|
24
26
|
import grpc
|
|
27
|
+
from grpc import ServicerContext
|
|
28
|
+
import grpc_tools
|
|
29
|
+
from connecpy.server import ConnecpyASGIApplication as ConnecpyASGI
|
|
30
|
+
from connecpy.server import ConnecpyWSGIApplication as ConnecpyWSGI
|
|
31
|
+
from connecpy.code import Code as Errors
|
|
32
|
+
|
|
33
|
+
# Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
|
|
34
|
+
from google.protobuf import duration_pb2, timestamp_pb2, empty_pb2
|
|
25
35
|
from grpc_health.v1 import health_pb2, health_pb2_grpc
|
|
26
36
|
from grpc_health.v1.health import HealthServicer
|
|
27
37
|
from grpc_reflection.v1alpha import reflection
|
|
28
38
|
from grpc_tools import protoc
|
|
29
39
|
from pydantic import BaseModel, ValidationError
|
|
30
|
-
from sonora.wsgi import grpcWSGI
|
|
31
40
|
from sonora.asgi import grpcASGI
|
|
32
|
-
from
|
|
33
|
-
from connecpy.errors import Errors
|
|
34
|
-
from connecpy.wsgi import ConnecpyWSGIApp as ConnecpyWSGI
|
|
35
|
-
|
|
36
|
-
# Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
|
|
37
|
-
from google.protobuf import timestamp_pb2, duration_pb2
|
|
41
|
+
from sonora.wsgi import grpcWSGI
|
|
38
42
|
|
|
39
43
|
###############################################################################
|
|
40
44
|
# 1. Message definitions & converter extensions
|
|
@@ -46,7 +50,12 @@ from google.protobuf import timestamp_pb2, duration_pb2
|
|
|
46
50
|
Message: TypeAlias = BaseModel
|
|
47
51
|
|
|
48
52
|
|
|
49
|
-
def
|
|
53
|
+
def is_none_type(annotation: Any) -> TypeGuard[type[None] | None]:
|
|
54
|
+
"""Check if annotation represents None/NoneType (handles both None and type(None))."""
|
|
55
|
+
return annotation is None or annotation is type(None)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def primitiveProtoValueToPythonValue(value: Any):
|
|
50
59
|
# Returns the value as-is (primitive type).
|
|
51
60
|
return value
|
|
52
61
|
|
|
@@ -75,11 +84,20 @@ def python_to_duration(td: datetime.timedelta) -> duration_pb2.Duration: # type
|
|
|
75
84
|
return d
|
|
76
85
|
|
|
77
86
|
|
|
78
|
-
def generate_converter(annotation:
|
|
87
|
+
def generate_converter(annotation: type[Any] | None) -> Callable[[Any], Any]:
|
|
79
88
|
"""
|
|
80
89
|
Returns a converter function to convert protobuf types to Python types.
|
|
81
90
|
This is used primarily when handling incoming requests.
|
|
82
91
|
"""
|
|
92
|
+
# For NoneType (Empty messages)
|
|
93
|
+
if is_none_type(annotation):
|
|
94
|
+
|
|
95
|
+
def empty_converter(value: empty_pb2.Empty): # type: ignore
|
|
96
|
+
_ = value
|
|
97
|
+
return None
|
|
98
|
+
|
|
99
|
+
return empty_converter
|
|
100
|
+
|
|
83
101
|
# For primitive types
|
|
84
102
|
if annotation in (int, str, bool, bytes, float):
|
|
85
103
|
return primitiveProtoValueToPythonValue
|
|
@@ -87,7 +105,7 @@ def generate_converter(annotation: Type) -> Callable:
|
|
|
87
105
|
# For enum types
|
|
88
106
|
if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
|
|
89
107
|
|
|
90
|
-
def enum_converter(value):
|
|
108
|
+
def enum_converter(value: enum.Enum):
|
|
91
109
|
return annotation(value)
|
|
92
110
|
|
|
93
111
|
return enum_converter
|
|
@@ -114,7 +132,7 @@ def generate_converter(annotation: Type) -> Callable:
|
|
|
114
132
|
if origin in (list, tuple):
|
|
115
133
|
item_converter = generate_converter(get_args(annotation)[0])
|
|
116
134
|
|
|
117
|
-
def seq_converter(value):
|
|
135
|
+
def seq_converter(value: list[Any] | tuple[Any, ...]):
|
|
118
136
|
return [item_converter(v) for v in value]
|
|
119
137
|
|
|
120
138
|
return seq_converter
|
|
@@ -124,7 +142,7 @@ def generate_converter(annotation: Type) -> Callable:
|
|
|
124
142
|
key_converter = generate_converter(get_args(annotation)[0])
|
|
125
143
|
value_converter = generate_converter(get_args(annotation)[1])
|
|
126
144
|
|
|
127
|
-
def dict_converter(value):
|
|
145
|
+
def dict_converter(value: dict[Any, Any]):
|
|
128
146
|
return {key_converter(k): value_converter(v) for k, v in value.items()}
|
|
129
147
|
|
|
130
148
|
return dict_converter
|
|
@@ -137,27 +155,95 @@ def generate_converter(annotation: Type) -> Callable:
|
|
|
137
155
|
return primitiveProtoValueToPythonValue
|
|
138
156
|
|
|
139
157
|
|
|
140
|
-
def generate_message_converter(
|
|
158
|
+
def generate_message_converter(
|
|
159
|
+
arg_type: type[Message] | type[None] | None,
|
|
160
|
+
) -> Callable[[Any], Message | None]:
|
|
141
161
|
"""Return a converter function for protobuf -> Python Message."""
|
|
142
|
-
if arg_type is None or not issubclass(arg_type, Message):
|
|
143
|
-
raise TypeError("Request arg must be subclass of Message")
|
|
144
162
|
|
|
163
|
+
# Handle NoneType (Empty messages)
|
|
164
|
+
if is_none_type(arg_type):
|
|
165
|
+
|
|
166
|
+
def empty_converter(request: Any) -> None:
|
|
167
|
+
_ = request
|
|
168
|
+
return None
|
|
169
|
+
|
|
170
|
+
return empty_converter
|
|
171
|
+
|
|
172
|
+
arg_type = cast("type[Message]", arg_type)
|
|
145
173
|
fields = arg_type.model_fields
|
|
146
174
|
converters = {
|
|
147
175
|
field: generate_converter(field_type.annotation) # type: ignore
|
|
148
176
|
for field, field_type in fields.items()
|
|
149
177
|
}
|
|
150
178
|
|
|
151
|
-
def converter(request):
|
|
179
|
+
def converter(request: Any) -> Message:
|
|
152
180
|
rdict = {}
|
|
153
|
-
for
|
|
154
|
-
|
|
181
|
+
for field_name, field_info in fields.items():
|
|
182
|
+
field_type = field_info.annotation
|
|
183
|
+
|
|
184
|
+
# Check if this is a union type
|
|
185
|
+
if field_type is not None and is_union_type(field_type):
|
|
186
|
+
union_args = flatten_union(field_type)
|
|
187
|
+
has_none = type(None) in union_args
|
|
188
|
+
non_none_args = [arg for arg in union_args if arg is not type(None)]
|
|
189
|
+
|
|
190
|
+
if has_none and len(non_none_args) == 1:
|
|
191
|
+
# This is Optional[T] - check if protobuf field is set
|
|
192
|
+
try:
|
|
193
|
+
if hasattr(request, "HasField") and request.HasField(
|
|
194
|
+
field_name
|
|
195
|
+
):
|
|
196
|
+
rdict[field_name] = converters[field_name](
|
|
197
|
+
getattr(request, field_name)
|
|
198
|
+
)
|
|
199
|
+
else:
|
|
200
|
+
# Field not set in protobuf, set to None for Optional fields
|
|
201
|
+
rdict[field_name] = None
|
|
202
|
+
except ValueError:
|
|
203
|
+
# HasField doesn't work for this field type (e.g., repeated fields)
|
|
204
|
+
# Fall back to regular conversion
|
|
205
|
+
rdict[field_name] = converters[field_name](
|
|
206
|
+
getattr(request, field_name)
|
|
207
|
+
)
|
|
208
|
+
elif len(non_none_args) > 1:
|
|
209
|
+
# This is a oneof field (Union[str, int] etc.)
|
|
210
|
+
# Check which oneof field is set
|
|
211
|
+
try:
|
|
212
|
+
which_field = request.WhichOneof(field_name)
|
|
213
|
+
if which_field:
|
|
214
|
+
# Extract the value from the set oneof field
|
|
215
|
+
proto_value = getattr(request, which_field)
|
|
216
|
+
|
|
217
|
+
# Determine the Python type from the oneof field name
|
|
218
|
+
# e.g., "value_string" -> str, "value_int32" -> int
|
|
219
|
+
for union_arg in non_none_args:
|
|
220
|
+
proto_typename = protobuf_type_mapping(union_arg)
|
|
221
|
+
if (
|
|
222
|
+
proto_typename
|
|
223
|
+
and which_field
|
|
224
|
+
== f"{field_name}_{proto_typename.replace('.', '_')}"
|
|
225
|
+
):
|
|
226
|
+
# Convert using the specific type converter
|
|
227
|
+
type_converter = generate_converter(union_arg)
|
|
228
|
+
rdict[field_name] = type_converter(proto_value)
|
|
229
|
+
break
|
|
230
|
+
except (AttributeError, ValueError):
|
|
231
|
+
# WhichOneof failed, try fallback
|
|
232
|
+
# This shouldn't happen for properly generated oneof fields
|
|
233
|
+
pass
|
|
234
|
+
else:
|
|
235
|
+
# Union with only None type (shouldn't happen)
|
|
236
|
+
pass
|
|
237
|
+
else:
|
|
238
|
+
# For non-union fields, convert normally
|
|
239
|
+
rdict[field_name] = converters[field_name](getattr(request, field_name))
|
|
240
|
+
|
|
155
241
|
return arg_type(**rdict)
|
|
156
242
|
|
|
157
243
|
return converter
|
|
158
244
|
|
|
159
245
|
|
|
160
|
-
def python_value_to_proto_value(field_type:
|
|
246
|
+
def python_value_to_proto_value(field_type: type[Any], value: Any) -> Any:
|
|
161
247
|
"""
|
|
162
248
|
Converts Python values to protobuf values.
|
|
163
249
|
Used primarily when constructing a response object.
|
|
@@ -179,77 +265,112 @@ def python_value_to_proto_value(field_type: Type, value):
|
|
|
179
265
|
###############################################################################
|
|
180
266
|
|
|
181
267
|
|
|
182
|
-
def connect_obj_with_stub(
|
|
268
|
+
def connect_obj_with_stub(
|
|
269
|
+
pb2_grpc_module: Any, pb2_module: Any, service_obj: object
|
|
270
|
+
) -> type:
|
|
183
271
|
"""
|
|
184
272
|
Connect a Python service object to a gRPC stub, generating server methods.
|
|
273
|
+
Returns a subclass of the generated Servicer stub with concrete implementations.
|
|
185
274
|
"""
|
|
186
275
|
service_class = service_obj.__class__
|
|
187
276
|
stub_class_name = service_class.__name__ + "Servicer"
|
|
188
277
|
stub_class = getattr(pb2_grpc_module, stub_class_name)
|
|
189
278
|
|
|
190
279
|
class ConcreteServiceClass(stub_class):
|
|
280
|
+
"""Dynamically generated servicer class with stub methods implemented."""
|
|
281
|
+
|
|
191
282
|
pass
|
|
192
283
|
|
|
193
|
-
def implement_stub_method(
|
|
194
|
-
|
|
284
|
+
def implement_stub_method(
|
|
285
|
+
method: Callable[..., Message],
|
|
286
|
+
) -> Callable[[object, Any, Any], Any]:
|
|
287
|
+
"""
|
|
288
|
+
Wraps a user-defined method (self, *args) -> R into a gRPC stub signature:
|
|
289
|
+
(self, request_proto, context) -> response_proto
|
|
290
|
+
"""
|
|
195
291
|
sig = inspect.signature(method)
|
|
196
292
|
arg_type = get_request_arg_type(sig)
|
|
197
|
-
# Convert request from protobuf to Python.
|
|
198
|
-
converter = generate_message_converter(arg_type)
|
|
199
|
-
|
|
200
293
|
response_type = sig.return_annotation
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
match size_of_parameters:
|
|
204
|
-
case 1:
|
|
294
|
+
param_count = len(sig.parameters)
|
|
295
|
+
converter = generate_message_converter(arg_type)
|
|
205
296
|
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
297
|
+
if param_count == 1:
|
|
298
|
+
|
|
299
|
+
def stub_method(
|
|
300
|
+
self: object,
|
|
301
|
+
request: Any,
|
|
302
|
+
context: Any,
|
|
303
|
+
*,
|
|
304
|
+
original: Callable[..., Message] = method,
|
|
305
|
+
) -> Any:
|
|
306
|
+
_ = self
|
|
307
|
+
try:
|
|
308
|
+
if is_none_type(arg_type):
|
|
309
|
+
resp_obj = original(None) # Fixed: pass None instead of no args
|
|
310
|
+
else:
|
|
209
311
|
arg = converter(request)
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
312
|
+
resp_obj = original(arg)
|
|
313
|
+
|
|
314
|
+
if is_none_type(response_type):
|
|
315
|
+
return empty_pb2.Empty() # type: ignore
|
|
316
|
+
else:
|
|
213
317
|
return convert_python_message_to_proto(
|
|
214
318
|
resp_obj, response_type, pb2_module
|
|
215
319
|
)
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
320
|
+
except ValidationError as e:
|
|
321
|
+
return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
|
322
|
+
except Exception as e:
|
|
323
|
+
return context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
324
|
+
|
|
325
|
+
elif param_count == 2:
|
|
326
|
+
|
|
327
|
+
def stub_method(
|
|
328
|
+
self: object,
|
|
329
|
+
request: Any,
|
|
330
|
+
context: Any,
|
|
331
|
+
*,
|
|
332
|
+
original: Callable[..., Message] = method,
|
|
333
|
+
) -> Any:
|
|
334
|
+
_ = self
|
|
335
|
+
try:
|
|
336
|
+
if is_none_type(arg_type):
|
|
337
|
+
resp_obj = original(
|
|
338
|
+
None, context
|
|
339
|
+
) # Fixed: pass None instead of Empty
|
|
340
|
+
else:
|
|
227
341
|
arg = converter(request)
|
|
228
|
-
resp_obj =
|
|
342
|
+
resp_obj = original(arg, context)
|
|
343
|
+
|
|
344
|
+
if is_none_type(response_type):
|
|
345
|
+
return empty_pb2.Empty() # type: ignore
|
|
346
|
+
else:
|
|
229
347
|
return convert_python_message_to_proto(
|
|
230
348
|
resp_obj, response_type, pb2_module
|
|
231
349
|
)
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
350
|
+
except ValidationError as e:
|
|
351
|
+
return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
|
352
|
+
except Exception as e:
|
|
353
|
+
return context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
236
354
|
|
|
237
|
-
|
|
355
|
+
else:
|
|
356
|
+
raise TypeError(
|
|
357
|
+
f"Method '{method.__name__}' must have exactly 1 or 2 parameters, got {param_count}"
|
|
358
|
+
)
|
|
238
359
|
|
|
239
|
-
|
|
240
|
-
raise Exception("Method must have exactly one or two parameters")
|
|
360
|
+
return stub_method
|
|
241
361
|
|
|
362
|
+
# Attach all RPC methods from service_obj to the concrete servicer
|
|
242
363
|
for method_name, method in get_rpc_methods(service_obj):
|
|
243
|
-
if
|
|
364
|
+
if method_name.startswith("_"):
|
|
244
365
|
continue
|
|
245
|
-
|
|
246
|
-
a_method = implement_stub_method(method)
|
|
247
|
-
setattr(ConcreteServiceClass, method_name, a_method)
|
|
366
|
+
setattr(ConcreteServiceClass, method_name, implement_stub_method(method))
|
|
248
367
|
|
|
249
368
|
return ConcreteServiceClass
|
|
250
369
|
|
|
251
370
|
|
|
252
|
-
def connect_obj_with_stub_async(
|
|
371
|
+
def connect_obj_with_stub_async(
|
|
372
|
+
pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
373
|
+
) -> type:
|
|
253
374
|
"""
|
|
254
375
|
Connect a Python service object to a gRPC stub for async methods.
|
|
255
376
|
"""
|
|
@@ -260,26 +381,151 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
|
|
|
260
381
|
class ConcreteServiceClass(stub_class):
|
|
261
382
|
pass
|
|
262
383
|
|
|
263
|
-
def implement_stub_method(
|
|
384
|
+
def implement_stub_method(
|
|
385
|
+
method: Callable[..., Any],
|
|
386
|
+
) -> Callable[[object, Any, Any], Any]:
|
|
264
387
|
sig = inspect.signature(method)
|
|
265
|
-
|
|
266
|
-
|
|
388
|
+
input_type = get_request_arg_type(sig)
|
|
389
|
+
is_input_stream = is_stream_type(input_type)
|
|
267
390
|
response_type = sig.return_annotation
|
|
391
|
+
is_output_stream = is_stream_type(response_type)
|
|
268
392
|
size_of_parameters = len(sig.parameters)
|
|
269
393
|
|
|
270
|
-
if
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
394
|
+
if size_of_parameters not in (1, 2):
|
|
395
|
+
raise TypeError(
|
|
396
|
+
f"Method '{method.__name__}' must have 1 or 2 parameters, got {size_of_parameters}"
|
|
397
|
+
)
|
|
274
398
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
399
|
+
if is_input_stream:
|
|
400
|
+
input_item_type = get_args(input_type)[0]
|
|
401
|
+
item_converter = generate_message_converter(input_item_type)
|
|
402
|
+
|
|
403
|
+
async def convert_iterator(
|
|
404
|
+
proto_iter: AsyncIterator[Any],
|
|
405
|
+
) -> AsyncIterator[Message]:
|
|
406
|
+
async for proto in proto_iter:
|
|
407
|
+
result = item_converter(proto)
|
|
408
|
+
if result is None:
|
|
409
|
+
raise TypeError(
|
|
410
|
+
f"Unexpected None result from converter for type {input_item_type}"
|
|
411
|
+
)
|
|
412
|
+
yield result
|
|
413
|
+
|
|
414
|
+
if is_output_stream:
|
|
415
|
+
# stream-stream
|
|
416
|
+
output_item_type = get_args(response_type)[0]
|
|
417
|
+
|
|
418
|
+
if size_of_parameters == 1:
|
|
419
|
+
|
|
420
|
+
async def stub_method(
|
|
421
|
+
self: object,
|
|
422
|
+
request_iterator: AsyncIterator[Any],
|
|
423
|
+
context: Any,
|
|
424
|
+
) -> AsyncIterator[Any]:
|
|
425
|
+
_ = self
|
|
426
|
+
try:
|
|
427
|
+
arg_iter = convert_iterator(request_iterator)
|
|
428
|
+
async for resp_obj in method(arg_iter):
|
|
429
|
+
yield convert_python_message_to_proto(
|
|
430
|
+
resp_obj, output_item_type, pb2_module
|
|
431
|
+
)
|
|
432
|
+
except ValidationError as e:
|
|
433
|
+
await context.abort(
|
|
434
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
435
|
+
)
|
|
436
|
+
except Exception as e:
|
|
437
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
438
|
+
|
|
439
|
+
else: # size_of_parameters == 2
|
|
440
|
+
|
|
441
|
+
async def stub_method(
|
|
442
|
+
self: object,
|
|
443
|
+
request_iterator: AsyncIterator[Any],
|
|
444
|
+
context: Any,
|
|
445
|
+
) -> AsyncIterator[Any]:
|
|
446
|
+
_ = self
|
|
447
|
+
try:
|
|
448
|
+
arg_iter = convert_iterator(request_iterator)
|
|
449
|
+
async for resp_obj in method(arg_iter, context):
|
|
450
|
+
yield convert_python_message_to_proto(
|
|
451
|
+
resp_obj, output_item_type, pb2_module
|
|
452
|
+
)
|
|
453
|
+
except ValidationError as e:
|
|
454
|
+
await context.abort(
|
|
455
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
456
|
+
)
|
|
457
|
+
except Exception as e:
|
|
458
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
459
|
+
|
|
460
|
+
return stub_method
|
|
461
|
+
|
|
462
|
+
else:
|
|
463
|
+
# stream-unary
|
|
464
|
+
if size_of_parameters == 1:
|
|
465
|
+
|
|
466
|
+
async def stub_method(
|
|
467
|
+
self: object,
|
|
468
|
+
request_iterator: AsyncIterator[Any],
|
|
469
|
+
context: Any,
|
|
470
|
+
) -> Any:
|
|
471
|
+
_ = self
|
|
472
|
+
try:
|
|
473
|
+
arg_iter = convert_iterator(request_iterator)
|
|
474
|
+
resp_obj = await method(arg_iter)
|
|
475
|
+
return convert_python_message_to_proto(
|
|
476
|
+
resp_obj, response_type, pb2_module
|
|
477
|
+
)
|
|
478
|
+
except ValidationError as e:
|
|
479
|
+
await context.abort(
|
|
480
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
481
|
+
)
|
|
482
|
+
except Exception as e:
|
|
483
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
484
|
+
|
|
485
|
+
else: # size_of_parameters == 2
|
|
486
|
+
|
|
487
|
+
async def stub_method(
|
|
488
|
+
self: object,
|
|
489
|
+
request_iterator: AsyncIterator[Any],
|
|
490
|
+
context: Any,
|
|
491
|
+
) -> Any:
|
|
492
|
+
_ = self
|
|
493
|
+
try:
|
|
494
|
+
arg_iter = convert_iterator(request_iterator)
|
|
495
|
+
resp_obj = await method(arg_iter, context)
|
|
496
|
+
return convert_python_message_to_proto(
|
|
497
|
+
resp_obj, response_type, pb2_module
|
|
498
|
+
)
|
|
499
|
+
except ValidationError as e:
|
|
500
|
+
await context.abort(
|
|
501
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
502
|
+
)
|
|
503
|
+
except Exception as e:
|
|
504
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
505
|
+
|
|
506
|
+
return stub_method
|
|
507
|
+
|
|
508
|
+
else:
|
|
509
|
+
# unary input
|
|
510
|
+
converter = generate_message_converter(input_type)
|
|
511
|
+
|
|
512
|
+
if is_output_stream:
|
|
513
|
+
# unary-stream
|
|
514
|
+
output_item_type = get_args(response_type)[0]
|
|
515
|
+
|
|
516
|
+
if size_of_parameters == 1:
|
|
517
|
+
|
|
518
|
+
async def stub_method(
|
|
519
|
+
self: object,
|
|
520
|
+
request: Any,
|
|
521
|
+
context: Any,
|
|
522
|
+
) -> AsyncIterator[Any]:
|
|
523
|
+
_ = self
|
|
278
524
|
try:
|
|
279
525
|
arg = converter(request)
|
|
280
526
|
async for resp_obj in method(arg):
|
|
281
527
|
yield convert_python_message_to_proto(
|
|
282
|
-
resp_obj,
|
|
528
|
+
resp_obj, output_item_type, pb2_module
|
|
283
529
|
)
|
|
284
530
|
except ValidationError as e:
|
|
285
531
|
await context.abort(
|
|
@@ -288,17 +534,19 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
|
|
|
288
534
|
except Exception as e:
|
|
289
535
|
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
290
536
|
|
|
291
|
-
|
|
292
|
-
case 2:
|
|
537
|
+
else: # size_of_parameters == 2
|
|
293
538
|
|
|
294
|
-
async def
|
|
295
|
-
self
|
|
296
|
-
|
|
539
|
+
async def stub_method(
|
|
540
|
+
self: object,
|
|
541
|
+
request: Any,
|
|
542
|
+
context: Any,
|
|
543
|
+
) -> AsyncIterator[Any]:
|
|
544
|
+
_ = self
|
|
297
545
|
try:
|
|
298
546
|
arg = converter(request)
|
|
299
547
|
async for resp_obj in method(arg, context):
|
|
300
548
|
yield convert_python_message_to_proto(
|
|
301
|
-
resp_obj,
|
|
549
|
+
resp_obj, output_item_type, pb2_module
|
|
302
550
|
)
|
|
303
551
|
except ValidationError as e:
|
|
304
552
|
await context.abort(
|
|
@@ -307,45 +555,67 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
|
|
|
307
555
|
except Exception as e:
|
|
308
556
|
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
309
557
|
|
|
310
|
-
|
|
311
|
-
case _:
|
|
312
|
-
raise Exception("Method must have exactly one or two parameters")
|
|
313
|
-
|
|
314
|
-
match size_of_parameters:
|
|
315
|
-
case 1:
|
|
316
|
-
|
|
317
|
-
async def stub_method1(self, request, context, method=method):
|
|
318
|
-
try:
|
|
319
|
-
arg = converter(request)
|
|
320
|
-
resp_obj = await method(arg)
|
|
321
|
-
return convert_python_message_to_proto(
|
|
322
|
-
resp_obj, response_type, pb2_module
|
|
323
|
-
)
|
|
324
|
-
except ValidationError as e:
|
|
325
|
-
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
|
326
|
-
except Exception as e:
|
|
327
|
-
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
558
|
+
return stub_method
|
|
328
559
|
|
|
329
|
-
|
|
560
|
+
else:
|
|
561
|
+
# unary-unary
|
|
562
|
+
if size_of_parameters == 1:
|
|
330
563
|
|
|
331
|
-
|
|
564
|
+
async def stub_method(
|
|
565
|
+
self: object,
|
|
566
|
+
request: Any,
|
|
567
|
+
context: Any,
|
|
568
|
+
) -> Any:
|
|
569
|
+
_ = self
|
|
570
|
+
try:
|
|
571
|
+
if is_none_type(input_type):
|
|
572
|
+
resp_obj = await method(None)
|
|
573
|
+
else:
|
|
574
|
+
arg = converter(request)
|
|
575
|
+
resp_obj = await method(arg)
|
|
576
|
+
|
|
577
|
+
if is_none_type(response_type):
|
|
578
|
+
return empty_pb2.Empty() # type: ignore
|
|
579
|
+
else:
|
|
580
|
+
return convert_python_message_to_proto(
|
|
581
|
+
resp_obj, response_type, pb2_module
|
|
582
|
+
)
|
|
583
|
+
except ValidationError as e:
|
|
584
|
+
await context.abort(
|
|
585
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
586
|
+
)
|
|
587
|
+
except Exception as e:
|
|
588
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
332
589
|
|
|
333
|
-
|
|
334
|
-
try:
|
|
335
|
-
arg = converter(request)
|
|
336
|
-
resp_obj = await method(arg, context)
|
|
337
|
-
return convert_python_message_to_proto(
|
|
338
|
-
resp_obj, response_type, pb2_module
|
|
339
|
-
)
|
|
340
|
-
except ValidationError as e:
|
|
341
|
-
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
|
342
|
-
except Exception as e:
|
|
343
|
-
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
590
|
+
else: # size_of_parameters == 2
|
|
344
591
|
|
|
345
|
-
|
|
592
|
+
async def stub_method(
|
|
593
|
+
self: object,
|
|
594
|
+
request: Any,
|
|
595
|
+
context: Any,
|
|
596
|
+
) -> Any:
|
|
597
|
+
_ = self
|
|
598
|
+
try:
|
|
599
|
+
if is_none_type(input_type):
|
|
600
|
+
resp_obj = await method(None, context)
|
|
601
|
+
else:
|
|
602
|
+
arg = converter(request)
|
|
603
|
+
resp_obj = await method(arg, context)
|
|
604
|
+
|
|
605
|
+
if is_none_type(response_type):
|
|
606
|
+
return empty_pb2.Empty() # type: ignore
|
|
607
|
+
else:
|
|
608
|
+
return convert_python_message_to_proto(
|
|
609
|
+
resp_obj, response_type, pb2_module
|
|
610
|
+
)
|
|
611
|
+
except ValidationError as e:
|
|
612
|
+
await context.abort(
|
|
613
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
614
|
+
)
|
|
615
|
+
except Exception as e:
|
|
616
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
346
617
|
|
|
347
|
-
|
|
348
|
-
raise Exception("Method must have exactly one or two parameters")
|
|
618
|
+
return stub_method
|
|
349
619
|
|
|
350
620
|
for method_name, method in get_rpc_methods(obj):
|
|
351
621
|
if method.__name__.startswith("_"):
|
|
@@ -357,7 +627,9 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
|
|
|
357
627
|
return ConcreteServiceClass
|
|
358
628
|
|
|
359
629
|
|
|
360
|
-
def connect_obj_with_stub_connecpy(
|
|
630
|
+
def connect_obj_with_stub_connecpy(
|
|
631
|
+
connecpy_module: Any, pb2_module: Any, obj: object
|
|
632
|
+
) -> type:
|
|
361
633
|
"""
|
|
362
634
|
Connect a Python service object to a Connecpy stub.
|
|
363
635
|
"""
|
|
@@ -368,7 +640,9 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
|
|
|
368
640
|
class ConcreteServiceClass(stub_class):
|
|
369
641
|
pass
|
|
370
642
|
|
|
371
|
-
def implement_stub_method(
|
|
643
|
+
def implement_stub_method(
|
|
644
|
+
method: Callable[..., Message],
|
|
645
|
+
) -> Callable[[object, Any, Any], Any]:
|
|
372
646
|
sig = inspect.signature(method)
|
|
373
647
|
arg_type = get_request_arg_type(sig)
|
|
374
648
|
converter = generate_message_converter(arg_type)
|
|
@@ -378,33 +652,59 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
|
|
|
378
652
|
match size_of_parameters:
|
|
379
653
|
case 1:
|
|
380
654
|
|
|
381
|
-
def stub_method1(
|
|
655
|
+
def stub_method1(
|
|
656
|
+
self: object,
|
|
657
|
+
request: Any,
|
|
658
|
+
context: Any,
|
|
659
|
+
method: Callable[..., Message] = method,
|
|
660
|
+
) -> Any:
|
|
661
|
+
_ = self
|
|
382
662
|
try:
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
663
|
+
if is_none_type(arg_type):
|
|
664
|
+
resp_obj = method(None)
|
|
665
|
+
else:
|
|
666
|
+
arg = converter(request)
|
|
667
|
+
resp_obj = method(arg)
|
|
668
|
+
|
|
669
|
+
if is_none_type(response_type):
|
|
670
|
+
return empty_pb2.Empty() # type: ignore
|
|
671
|
+
else:
|
|
672
|
+
return convert_python_message_to_proto(
|
|
673
|
+
resp_obj, response_type, pb2_module
|
|
674
|
+
)
|
|
388
675
|
except ValidationError as e:
|
|
389
|
-
return context.abort(Errors.
|
|
676
|
+
return context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
390
677
|
except Exception as e:
|
|
391
|
-
return context.abort(Errors.
|
|
678
|
+
return context.abort(Errors.INTERNAL, str(e))
|
|
392
679
|
|
|
393
680
|
return stub_method1
|
|
394
681
|
|
|
395
682
|
case 2:
|
|
396
683
|
|
|
397
|
-
def stub_method2(
|
|
684
|
+
def stub_method2(
|
|
685
|
+
self: object,
|
|
686
|
+
request: Any,
|
|
687
|
+
context: Any,
|
|
688
|
+
method: Callable[..., Message] = method,
|
|
689
|
+
) -> Any:
|
|
690
|
+
_ = self
|
|
398
691
|
try:
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
692
|
+
if is_none_type(arg_type):
|
|
693
|
+
resp_obj = method(None, context)
|
|
694
|
+
else:
|
|
695
|
+
arg = converter(request)
|
|
696
|
+
resp_obj = method(arg, context)
|
|
697
|
+
|
|
698
|
+
if is_none_type(response_type):
|
|
699
|
+
return empty_pb2.Empty() # type: ignore
|
|
700
|
+
else:
|
|
701
|
+
return convert_python_message_to_proto(
|
|
702
|
+
resp_obj, response_type, pb2_module
|
|
703
|
+
)
|
|
404
704
|
except ValidationError as e:
|
|
405
|
-
return context.abort(Errors.
|
|
705
|
+
return context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
406
706
|
except Exception as e:
|
|
407
|
-
return context.abort(Errors.
|
|
707
|
+
return context.abort(Errors.INTERNAL, str(e))
|
|
408
708
|
|
|
409
709
|
return stub_method2
|
|
410
710
|
|
|
@@ -421,7 +721,7 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
|
|
|
421
721
|
|
|
422
722
|
|
|
423
723
|
def connect_obj_with_stub_async_connecpy(
|
|
424
|
-
connecpy_module, pb2_module, obj: object
|
|
724
|
+
connecpy_module: Any, pb2_module: Any, obj: object
|
|
425
725
|
) -> type:
|
|
426
726
|
"""
|
|
427
727
|
Connect a Python service object to a Connecpy stub for async methods.
|
|
@@ -433,7 +733,9 @@ def connect_obj_with_stub_async_connecpy(
|
|
|
433
733
|
class ConcreteServiceClass(stub_class):
|
|
434
734
|
pass
|
|
435
735
|
|
|
436
|
-
def implement_stub_method(
|
|
736
|
+
def implement_stub_method(
|
|
737
|
+
method: Callable[..., Awaitable[Message]],
|
|
738
|
+
) -> Callable[[object, Any, Any], Any]:
|
|
437
739
|
sig = inspect.signature(method)
|
|
438
740
|
arg_type = get_request_arg_type(sig)
|
|
439
741
|
converter = generate_message_converter(arg_type)
|
|
@@ -443,33 +745,59 @@ def connect_obj_with_stub_async_connecpy(
|
|
|
443
745
|
match size_of_parameters:
|
|
444
746
|
case 1:
|
|
445
747
|
|
|
446
|
-
async def stub_method1(
|
|
748
|
+
async def stub_method1(
|
|
749
|
+
self: object,
|
|
750
|
+
request: Any,
|
|
751
|
+
context: Any,
|
|
752
|
+
method: Callable[..., Awaitable[Message]] = method,
|
|
753
|
+
) -> Any:
|
|
754
|
+
_ = self
|
|
447
755
|
try:
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
756
|
+
if is_none_type(arg_type):
|
|
757
|
+
resp_obj = await method(None)
|
|
758
|
+
else:
|
|
759
|
+
arg = converter(request)
|
|
760
|
+
resp_obj = await method(arg)
|
|
761
|
+
|
|
762
|
+
if is_none_type(response_type):
|
|
763
|
+
return empty_pb2.Empty() # type: ignore
|
|
764
|
+
else:
|
|
765
|
+
return convert_python_message_to_proto(
|
|
766
|
+
resp_obj, response_type, pb2_module
|
|
767
|
+
)
|
|
453
768
|
except ValidationError as e:
|
|
454
|
-
await context.abort(Errors.
|
|
769
|
+
await context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
455
770
|
except Exception as e:
|
|
456
|
-
await context.abort(Errors.
|
|
771
|
+
await context.abort(Errors.INTERNAL, str(e))
|
|
457
772
|
|
|
458
773
|
return stub_method1
|
|
459
774
|
|
|
460
775
|
case 2:
|
|
461
776
|
|
|
462
|
-
async def stub_method2(
|
|
777
|
+
async def stub_method2(
|
|
778
|
+
self: object,
|
|
779
|
+
request: Any,
|
|
780
|
+
context: Any,
|
|
781
|
+
method: Callable[..., Awaitable[Message]] = method,
|
|
782
|
+
) -> Any:
|
|
783
|
+
_ = self
|
|
463
784
|
try:
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
785
|
+
if is_none_type(arg_type):
|
|
786
|
+
resp_obj = await method(None, context)
|
|
787
|
+
else:
|
|
788
|
+
arg = converter(request)
|
|
789
|
+
resp_obj = await method(arg, context)
|
|
790
|
+
|
|
791
|
+
if is_none_type(response_type):
|
|
792
|
+
return empty_pb2.Empty() # type: ignore
|
|
793
|
+
else:
|
|
794
|
+
return convert_python_message_to_proto(
|
|
795
|
+
resp_obj, response_type, pb2_module
|
|
796
|
+
)
|
|
469
797
|
except ValidationError as e:
|
|
470
|
-
await context.abort(Errors.
|
|
798
|
+
await context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
471
799
|
except Exception as e:
|
|
472
|
-
await context.abort(Errors.
|
|
800
|
+
await context.abort(Errors.INTERNAL, str(e))
|
|
473
801
|
|
|
474
802
|
return stub_method2
|
|
475
803
|
|
|
@@ -487,36 +815,84 @@ def connect_obj_with_stub_async_connecpy(
|
|
|
487
815
|
return ConcreteServiceClass
|
|
488
816
|
|
|
489
817
|
|
|
490
|
-
def
|
|
491
|
-
|
|
492
|
-
) ->
|
|
818
|
+
def python_value_to_proto_oneof(
    field_name: str, field_type: type[Any], value: Any, pb2_module: Any
) -> tuple[str, Any]:
    """
    Converts a Python value from a Union type to a protobuf oneof field.

    The Union member describing ``value`` is searched for in two passes:
    first by exact runtime type, then by ``isinstance``. The exact pass keeps
    a subclass instance from being claimed by a base type listed earlier in
    the Union (``True`` is an ``int`` instance, so with ``int | bool`` a plain
    ``isinstance`` scan would always select ``int``).

    Args:
        field_name: Name of the Pydantic field holding the Union.
        field_type: The Union annotation of that field.
        value: The (non-None) value currently stored in the field.
        pb2_module: Generated protobuf module, forwarded to the value converter.

    Returns:
        A ``(oneof_field_name, converted_value)`` tuple. The field name is
        ``f"{field_name}_{proto_typename}"`` with dots in the protobuf type
        name replaced by underscores.

    Raises:
        TypeError: If no Union member matches ``value`` or the matching member
            has no protobuf type mapping.
    """
    union_args = [arg for arg in flatten_union(field_type) if arg is not type(None)]

    # Generic aliases such as list[int] cannot be used with isinstance, so each
    # member is paired with the runtime class it is checked against.
    candidates = [
        (sub_type, get_origin(sub_type) or sub_type) for sub_type in union_args
    ]

    # Pass 1: exact type identity (deliberately not isinstance, see docstring).
    actual_type = None
    for sub_type, runtime_type in candidates:
        if type(value) is runtime_type:
            actual_type = sub_type
            break

    # Pass 2: fall back to isinstance so subclasses of a member still match.
    if actual_type is None:
        for sub_type, runtime_type in candidates:
            try:
                matched = isinstance(value, runtime_type)
            except TypeError:
                # `runtime_type` is not something isinstance accepts; treat it
                # as a non-match instead of retrying the call that just failed.
                matched = False
            if matched:
                actual_type = sub_type
                break

    if actual_type is None:
        raise TypeError(f"Value of type {type(value)} not found in union {field_type}")

    proto_typename = protobuf_type_mapping(actual_type)
    if proto_typename is None:
        raise TypeError(f"Unsupported type in oneof: {actual_type}")

    oneof_field_name = f"{field_name}_{proto_typename.replace('.', '_')}"
    converted_value = python_value_to_proto(actual_type, value, pb2_module)
    return oneof_field_name, converted_value
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
def convert_python_message_to_proto(
    py_msg: Message, msg_type: type[Message], pb2_module: Any
) -> object:
    """Build the protobuf counterpart of a Pydantic Message instance (used when constructing a response).

    Fields whose value is None, or that carry no annotation, are left unset.
    A Union with more than one non-None member is written to the matching
    oneof member field; every other field is converted under its own name.
    """
    proto_kwargs = {}
    for attr_name, info in msg_type.model_fields.items():
        attr_value = getattr(py_msg, attr_name)
        annotation = info.annotation
        # Nothing to emit for absent values or untyped fields.
        if attr_value is None or annotation is None:
            continue

        if is_union_type(annotation):
            members = [
                member
                for member in flatten_union(annotation)
                if member is not type(None)
            ]
            if len(members) > 1:
                # A real oneof: the helper picks the concrete member and its
                # protobuf field name.
                target_name, converted = python_value_to_proto_oneof(
                    attr_name, annotation, attr_value, pb2_module
                )
                proto_kwargs[target_name] = converted
                continue

        # Plain field, or an Optional[T] that currently holds a value.
        proto_kwargs[attr_name] = python_value_to_proto(
            annotation, attr_value, pb2_module
        )

    # The protobuf class shares its name with the Pydantic message class.
    message_cls = getattr(pb2_module, msg_type.__name__)
    return message_cls(**proto_kwargs)
|
|
512
888
|
|
|
513
889
|
|
|
514
|
-
def python_value_to_proto(field_type:
|
|
890
|
+
def python_value_to_proto(field_type: type[Any], value: Any, pb2_module: Any) -> Any:
|
|
515
891
|
"""
|
|
516
892
|
Perform Python->protobuf type conversion for each field value.
|
|
517
893
|
"""
|
|
518
|
-
import inspect
|
|
519
894
|
import datetime
|
|
895
|
+
import inspect
|
|
520
896
|
|
|
521
897
|
# If datetime
|
|
522
898
|
if field_type == datetime.datetime:
|
|
@@ -533,12 +909,12 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
|
|
|
533
909
|
origin = get_origin(field_type)
|
|
534
910
|
# If seq
|
|
535
911
|
if origin in (list, tuple):
|
|
536
|
-
inner_type = get_args(field_type)[0]
|
|
912
|
+
inner_type = get_args(field_type)[0]
|
|
537
913
|
return [python_value_to_proto(inner_type, v, pb2_module) for v in value]
|
|
538
914
|
|
|
539
915
|
# If dict
|
|
540
916
|
if origin is dict:
|
|
541
|
-
key_type, val_type = get_args(field_type)
|
|
917
|
+
key_type, val_type = get_args(field_type)
|
|
542
918
|
return {
|
|
543
919
|
python_value_to_proto(key_type, k, pb2_module): python_value_to_proto(
|
|
544
920
|
val_type, v, pb2_module
|
|
@@ -546,31 +922,16 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
|
|
|
546
922
|
for k, v in value.items()
|
|
547
923
|
}
|
|
548
924
|
|
|
549
|
-
# If union -> oneof
|
|
925
|
+
# If union -> oneof. This path is now only for Optional[T] where value is not None.
|
|
550
926
|
if is_union_type(field_type):
|
|
551
|
-
#
|
|
552
|
-
|
|
553
|
-
if
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
and issubclass(sub_type, enum.Enum)
|
|
560
|
-
and isinstance(value, enum.Enum)
|
|
561
|
-
):
|
|
562
|
-
return value.value
|
|
563
|
-
if sub_type in (int, float, str, bool, bytes) and isinstance(
|
|
564
|
-
value, sub_type
|
|
565
|
-
):
|
|
566
|
-
return value
|
|
567
|
-
if (
|
|
568
|
-
inspect.isclass(sub_type)
|
|
569
|
-
and issubclass(sub_type, Message)
|
|
570
|
-
and isinstance(value, Message)
|
|
571
|
-
):
|
|
572
|
-
return convert_python_message_to_proto(value, sub_type, pb2_module)
|
|
573
|
-
return None
|
|
927
|
+
# The value is not None, so we need to find the actual type.
|
|
928
|
+
non_none_args = [
|
|
929
|
+
arg for arg in flatten_union(field_type) if arg is not type(None)
|
|
930
|
+
]
|
|
931
|
+
if non_none_args:
|
|
932
|
+
# Assuming it's an Optional[T], so there's one type left.
|
|
933
|
+
return python_value_to_proto(non_none_args[0], value, pb2_module)
|
|
934
|
+
return None # Should not be reached if value is not None
|
|
574
935
|
|
|
575
936
|
# If Message
|
|
576
937
|
if inspect.isclass(field_type) and issubclass(field_type, Message):
|
|
@@ -585,12 +946,12 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
|
|
|
585
946
|
###############################################################################
|
|
586
947
|
|
|
587
948
|
|
|
588
|
-
def is_enum_type(python_type:
|
|
949
|
+
def is_enum_type(python_type: Any) -> bool:
    """Tell whether ``python_type`` is a class derived from ``enum.Enum``."""
    # issubclass() raises on non-classes, so rule those out first.
    if not inspect.isclass(python_type):
        return False
    return issubclass(python_type, enum.Enum)
|
|
591
952
|
|
|
592
953
|
|
|
593
|
-
def is_union_type(python_type:
|
|
954
|
+
def is_union_type(python_type: Any) -> bool:
|
|
594
955
|
"""
|
|
595
956
|
Check if a given Python type is a Union type (including Python 3.10's UnionType).
|
|
596
957
|
"""
|
|
@@ -599,26 +960,28 @@ def is_union_type(python_type: Type) -> bool:
|
|
|
599
960
|
if sys.version_info >= (3, 10):
|
|
600
961
|
import types
|
|
601
962
|
|
|
602
|
-
if
|
|
963
|
+
if isinstance(python_type, types.UnionType):
|
|
603
964
|
return True
|
|
604
965
|
return False
|
|
605
966
|
|
|
606
967
|
|
|
607
|
-
def flatten_union(field_type:
|
|
968
|
+
def flatten_union(field_type: Any) -> list[Any]:
    """Recursively flatten nested Unions into a single list of types.

    Any non-Union type, including ``NoneType``, is returned as a
    single-element list, so callers that want to ignore ``None`` must filter
    it out of the result themselves.
    """
    if not is_union_type(field_type):
        return [field_type]

    flattened: list[Any] = []
    for member in get_args(field_type):
        flattened.extend(flatten_union(member))
    return flattened
|
|
616
979
|
|
|
617
980
|
|
|
618
|
-
def protobuf_type_mapping(python_type:
|
|
981
|
+
def protobuf_type_mapping(python_type: Any) -> str | None:
|
|
619
982
|
"""
|
|
620
983
|
Map a Python type to a protobuf type name/class.
|
|
621
|
-
Includes support for Timestamp and
|
|
984
|
+
Includes support for Timestamp, Duration, and Empty.
|
|
622
985
|
"""
|
|
623
986
|
import datetime
|
|
624
987
|
|
|
@@ -636,8 +999,11 @@ def protobuf_type_mapping(python_type: Type) -> str | type | None:
|
|
|
636
999
|
if python_type == datetime.timedelta:
|
|
637
1000
|
return "google.protobuf.Duration"
|
|
638
1001
|
|
|
1002
|
+
if is_none_type(python_type):
|
|
1003
|
+
return "google.protobuf.Empty"
|
|
1004
|
+
|
|
639
1005
|
if is_enum_type(python_type):
|
|
640
|
-
return python_type
|
|
1006
|
+
return python_type.__name__
|
|
641
1007
|
|
|
642
1008
|
if is_union_type(python_type):
|
|
643
1009
|
return None # Handled separately as oneof
|
|
@@ -657,9 +1023,9 @@ def protobuf_type_mapping(python_type: Type) -> str | type | None:
|
|
|
657
1023
|
return f"map<{key_proto_type}, {value_proto_type}>"
|
|
658
1024
|
|
|
659
1025
|
if inspect.isclass(python_type) and issubclass(python_type, Message):
|
|
660
|
-
return python_type
|
|
1026
|
+
return python_type.__name__
|
|
661
1027
|
|
|
662
|
-
return mapping.get(python_type)
|
|
1028
|
+
return mapping.get(python_type)
|
|
663
1029
|
|
|
664
1030
|
|
|
665
1031
|
def comment_out(docstr: str) -> tuple[str, ...]:
|
|
@@ -673,15 +1039,15 @@ def comment_out(docstr: str) -> tuple[str, ...]:
|
|
|
673
1039
|
return tuple("//" if line == "" else f"// {line}" for line in docstr.split("\n"))
|
|
674
1040
|
|
|
675
1041
|
|
|
676
|
-
def indent_lines(lines, indentation=" "):
|
|
1042
|
+
def indent_lines(lines: list[str], indentation: str = " ") -> str:
    """Prefix each entry of ``lines`` with ``indentation`` and join them with newlines."""
    prefixed = [indentation + text for text in lines]
    return "\n".join(prefixed)
|
|
679
1045
|
|
|
680
1046
|
|
|
681
|
-
def generate_enum_definition(enum_type:
|
|
1047
|
+
def generate_enum_definition(enum_type: Any) -> str:
|
|
682
1048
|
"""Generate a protobuf enum definition from a Python enum."""
|
|
683
1049
|
enum_name = enum_type.__name__
|
|
684
|
-
members = []
|
|
1050
|
+
members: list[str] = []
|
|
685
1051
|
for _, member in enum_type.__members__.items():
|
|
686
1052
|
members.append(f" {member.name} = {member.value};")
|
|
687
1053
|
enum_def = f"enum {enum_name} {{\n"
|
|
@@ -691,7 +1057,7 @@ def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
|
|
|
691
1057
|
|
|
692
1058
|
|
|
693
1059
|
def generate_oneof_definition(
|
|
694
|
-
field_name: str, union_args: list[
|
|
1060
|
+
field_name: str, union_args: list[Any], start_index: int
|
|
695
1061
|
) -> tuple[list[str], int]:
|
|
696
1062
|
"""
|
|
697
1063
|
Generate a oneof block in protobuf for a union field.
|
|
@@ -705,12 +1071,6 @@ def generate_oneof_definition(
|
|
|
705
1071
|
if proto_typename is None:
|
|
706
1072
|
raise Exception(f"Nested Union not flattened properly: {arg_type}")
|
|
707
1073
|
|
|
708
|
-
# If it's an enum or Message, use the type name.
|
|
709
|
-
if is_enum_type(arg_type):
|
|
710
|
-
proto_typename = arg_type.__name__
|
|
711
|
-
elif inspect.isclass(arg_type) and issubclass(arg_type, Message):
|
|
712
|
-
proto_typename = arg_type.__name__
|
|
713
|
-
|
|
714
1074
|
field_alias = f"{field_name}_{proto_typename.replace('.', '_')}"
|
|
715
1075
|
lines.append(f" {proto_typename} {field_alias} = {current};")
|
|
716
1076
|
current += 1
|
|
@@ -718,17 +1078,50 @@ def generate_oneof_definition(
|
|
|
718
1078
|
return lines, current
|
|
719
1079
|
|
|
720
1080
|
|
|
1081
|
+
def extract_nested_types(field_type: Any) -> list[Any]:
    """
    Recursively extract all Message and enum types from a field type,
    including those nested within list, dict, and union types.

    Returns the types in the order they are encountered. ``None`` and
    ``NoneType`` yield an empty list.
    """
    extracted_types: list[Any] = []

    if field_type is None or is_none_type(field_type):
        return extracted_types

    # The type itself is an enum or a Message.
    if is_enum_type(field_type) or (
        inspect.isclass(field_type) and issubclass(field_type, Message)
    ):
        extracted_types.append(field_type)

    # Union members are visited through flatten_union. Return right after:
    # typing.get_origin() is also non-None for unions, so falling through to
    # the generic branch below would walk the same members a second time and
    # report every nested type twice.
    if is_union_type(field_type):
        for arg in flatten_union(field_type):
            if arg is not type(None):
                extracted_types.extend(extract_nested_types(arg))
        return extracted_types

    # Other parameterized generics (list, dict, etc.): visit each type argument.
    if get_origin(field_type) is not None:
        for arg in get_args(field_type):
            extracted_types.extend(extract_nested_types(arg))

    return extracted_types
|
|
1112
|
+
|
|
1113
|
+
|
|
721
1114
|
def generate_message_definition(
|
|
722
|
-
message_type:
|
|
723
|
-
done_enums: set,
|
|
724
|
-
done_messages: set,
|
|
725
|
-
) -> tuple[str, list[
|
|
1115
|
+
message_type: Any,
|
|
1116
|
+
done_enums: set[Any],
|
|
1117
|
+
done_messages: set[Any],
|
|
1118
|
+
) -> tuple[str, list[Any]]:
|
|
726
1119
|
"""
|
|
727
1120
|
Generate a protobuf message definition for a Pydantic-based Message class.
|
|
728
1121
|
Also returns any referenced types (enums, messages) that need to be defined.
|
|
729
1122
|
"""
|
|
730
|
-
fields = []
|
|
731
|
-
refs = []
|
|
1123
|
+
fields: list[str] = []
|
|
1124
|
+
refs: list[Any] = []
|
|
732
1125
|
pydantic_fields = message_type.model_fields
|
|
733
1126
|
index = 1
|
|
734
1127
|
|
|
@@ -737,92 +1130,130 @@ def generate_message_definition(
|
|
|
737
1130
|
if field_type is None:
|
|
738
1131
|
raise Exception(f"Field {field_name} has no type annotation.")
|
|
739
1132
|
|
|
1133
|
+
is_optional = False
|
|
1134
|
+
# Handle Union types, which may be Optional or a oneof.
|
|
740
1135
|
if is_union_type(field_type):
|
|
741
1136
|
union_args = flatten_union(field_type)
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
)
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
1137
|
+
none_type = type(
|
|
1138
|
+
None
|
|
1139
|
+
) # Keep this as type(None) since we're working with union args
|
|
1140
|
+
|
|
1141
|
+
if none_type in union_args or None in union_args:
|
|
1142
|
+
is_optional = True
|
|
1143
|
+
union_args = [arg for arg in union_args if not is_none_type(arg)]
|
|
1144
|
+
|
|
1145
|
+
if len(union_args) == 1:
|
|
1146
|
+
# This is an Optional[T]. Treat it as a simple optional field.
|
|
1147
|
+
field_type = union_args[0]
|
|
1148
|
+
elif len(union_args) > 1:
|
|
1149
|
+
# This is a Union of multiple types, so it becomes a `oneof`.
|
|
1150
|
+
oneof_lines, new_index = generate_oneof_definition(
|
|
1151
|
+
field_name, union_args, index
|
|
1152
|
+
)
|
|
1153
|
+
fields.extend(oneof_lines)
|
|
1154
|
+
index = new_index
|
|
1155
|
+
|
|
1156
|
+
for utype in union_args:
|
|
1157
|
+
if is_enum_type(utype) and utype not in done_enums:
|
|
1158
|
+
refs.append(utype)
|
|
1159
|
+
elif (
|
|
1160
|
+
inspect.isclass(utype)
|
|
1161
|
+
and issubclass(utype, Message)
|
|
1162
|
+
and utype not in done_messages
|
|
1163
|
+
):
|
|
1164
|
+
refs.append(utype)
|
|
1165
|
+
continue # Proceed to the next field
|
|
1166
|
+
else:
|
|
1167
|
+
# This was a field of only `NoneType`, which is not supported.
|
|
1168
|
+
continue
|
|
757
1169
|
|
|
1170
|
+
# For regular fields or optional fields that have been unwrapped.
|
|
1171
|
+
proto_typename = protobuf_type_mapping(field_type)
|
|
1172
|
+
if proto_typename is None:
|
|
1173
|
+
raise Exception(f"Type {field_type} is not supported.")
|
|
1174
|
+
|
|
1175
|
+
# Extract all nested Message and enum types recursively
|
|
1176
|
+
nested_types = extract_nested_types(field_type)
|
|
1177
|
+
for nested_type in nested_types:
|
|
1178
|
+
if is_enum_type(nested_type) and nested_type not in done_enums:
|
|
1179
|
+
refs.append(nested_type)
|
|
1180
|
+
elif (
|
|
1181
|
+
inspect.isclass(nested_type)
|
|
1182
|
+
and issubclass(nested_type, Message)
|
|
1183
|
+
and nested_type not in done_messages
|
|
1184
|
+
):
|
|
1185
|
+
refs.append(nested_type)
|
|
1186
|
+
|
|
1187
|
+
if field_info.description:
|
|
1188
|
+
fields.append("// " + field_info.description)
|
|
1189
|
+
if field_info.metadata:
|
|
1190
|
+
fields.append("// Constraint:")
|
|
1191
|
+
for metadata_item in field_info.metadata:
|
|
1192
|
+
match type(metadata_item):
|
|
1193
|
+
case annotated_types.Ge:
|
|
1194
|
+
fields.append(
|
|
1195
|
+
"// greater than or equal to " + str(metadata_item.ge)
|
|
1196
|
+
)
|
|
1197
|
+
case annotated_types.Le:
|
|
1198
|
+
fields.append(
|
|
1199
|
+
"// less than or equal to " + str(metadata_item.le)
|
|
1200
|
+
)
|
|
1201
|
+
case annotated_types.Gt:
|
|
1202
|
+
fields.append("// greater than " + str(metadata_item.gt))
|
|
1203
|
+
case annotated_types.Lt:
|
|
1204
|
+
fields.append("// less than " + str(metadata_item.lt))
|
|
1205
|
+
case annotated_types.MultipleOf:
|
|
1206
|
+
fields.append(
|
|
1207
|
+
"// multiple of " + str(metadata_item.multiple_of)
|
|
1208
|
+
)
|
|
1209
|
+
case annotated_types.Len:
|
|
1210
|
+
fields.append("// length of " + str(metadata_item.len))
|
|
1211
|
+
case annotated_types.MinLen:
|
|
1212
|
+
fields.append(
|
|
1213
|
+
"// minimum length of " + str(metadata_item.min_len)
|
|
1214
|
+
)
|
|
1215
|
+
case annotated_types.MaxLen:
|
|
1216
|
+
fields.append(
|
|
1217
|
+
"// maximum length of " + str(metadata_item.max_len)
|
|
1218
|
+
)
|
|
1219
|
+
case _:
|
|
1220
|
+
fields.append("// " + str(metadata_item))
|
|
1221
|
+
|
|
1222
|
+
field_definition = f"{proto_typename} {field_name} = {index};"
|
|
1223
|
+
if is_optional:
|
|
1224
|
+
field_definition = f"optional {field_definition}"
|
|
1225
|
+
|
|
1226
|
+
fields.append(field_definition)
|
|
1227
|
+
index += 1
|
|
1228
|
+
|
|
1229
|
+
# Add reserved fields for forward/backward compatibility if specified
|
|
1230
|
+
reserved_count = get_reserved_fields_count()
|
|
1231
|
+
if reserved_count > 0:
|
|
1232
|
+
start_reserved = index
|
|
1233
|
+
end_reserved = index + reserved_count - 1
|
|
1234
|
+
fields.append("")
|
|
1235
|
+
fields.append("// Reserved fields for future compatibility")
|
|
1236
|
+
if reserved_count == 1:
|
|
1237
|
+
fields.append(f"reserved {start_reserved};")
|
|
758
1238
|
else:
|
|
759
|
-
|
|
760
|
-
if proto_typename is None:
|
|
761
|
-
raise Exception(f"Type {field_type} is not supported.")
|
|
762
|
-
|
|
763
|
-
if is_enum_type(field_type):
|
|
764
|
-
proto_typename = field_type.__name__
|
|
765
|
-
if field_type not in done_enums:
|
|
766
|
-
refs.append(field_type)
|
|
767
|
-
elif inspect.isclass(field_type) and issubclass(field_type, Message):
|
|
768
|
-
proto_typename = field_type.__name__
|
|
769
|
-
if field_type not in done_messages:
|
|
770
|
-
refs.append(field_type)
|
|
771
|
-
|
|
772
|
-
if field_info.description:
|
|
773
|
-
fields.append("// " + field_info.description)
|
|
774
|
-
if field_info.metadata:
|
|
775
|
-
fields.append("// Constraint:")
|
|
776
|
-
for metadata_item in field_info.metadata:
|
|
777
|
-
match type(metadata_item):
|
|
778
|
-
case annotated_types.Ge:
|
|
779
|
-
fields.append(
|
|
780
|
-
"// greater than or equal to " + str(metadata_item.ge)
|
|
781
|
-
)
|
|
782
|
-
case annotated_types.Le:
|
|
783
|
-
fields.append(
|
|
784
|
-
"// less than or equal to " + str(metadata_item.le)
|
|
785
|
-
)
|
|
786
|
-
case annotated_types.Gt:
|
|
787
|
-
fields.append("// greater than " + str(metadata_item.gt))
|
|
788
|
-
case annotated_types.Lt:
|
|
789
|
-
fields.append("// less than " + str(metadata_item.lt))
|
|
790
|
-
case annotated_types.MultipleOf:
|
|
791
|
-
fields.append(
|
|
792
|
-
"// multiple of " + str(metadata_item.multiple_of)
|
|
793
|
-
)
|
|
794
|
-
case annotated_types.Len:
|
|
795
|
-
fields.append("// length of " + str(metadata_item.len))
|
|
796
|
-
case annotated_types.MinLen:
|
|
797
|
-
fields.append(
|
|
798
|
-
"// minimum length of " + str(metadata_item.min_len)
|
|
799
|
-
)
|
|
800
|
-
case annotated_types.MaxLen:
|
|
801
|
-
fields.append(
|
|
802
|
-
"// maximum length of " + str(metadata_item.max_len)
|
|
803
|
-
)
|
|
804
|
-
case _:
|
|
805
|
-
fields.append("// " + str(metadata_item))
|
|
806
|
-
|
|
807
|
-
fields.append(f"{proto_typename} {field_name} = {index};")
|
|
808
|
-
index += 1
|
|
1239
|
+
fields.append(f"reserved {start_reserved} to {end_reserved};")
|
|
809
1240
|
|
|
810
1241
|
msg_def = f"message {message_type.__name__} {{\n{indent_lines(fields)}\n}}"
|
|
811
1242
|
return msg_def, refs
|
|
812
1243
|
|
|
813
1244
|
|
|
814
|
-
def is_stream_type(annotation:
|
|
1245
|
+
def is_stream_type(annotation: Any) -> bool:
    """Return True when ``annotation`` is a parameterized ``AsyncIterator[...]``."""
    origin = get_origin(annotation)
    return origin is AsyncIterator
|
|
816
1247
|
|
|
817
1248
|
|
|
818
|
-
def is_generic_alias(annotation:
|
|
1249
|
+
def is_generic_alias(annotation: Any) -> bool:
    """Return True when ``typing.get_origin`` reports an origin for ``annotation``."""
    origin = get_origin(annotation)
    return origin is not None
|
|
820
1251
|
|
|
821
1252
|
|
|
822
1253
|
def generate_proto(obj: object, package_name: str = "") -> str:
|
|
823
1254
|
"""
|
|
824
1255
|
Generate a .proto definition from a service class.
|
|
825
|
-
Automatically handles Timestamp and
|
|
1256
|
+
Automatically handles Timestamp, Duration, and Empty usage.
|
|
826
1257
|
"""
|
|
827
1258
|
import datetime
|
|
828
1259
|
|
|
@@ -831,20 +1262,23 @@ def generate_proto(obj: object, package_name: str = "") -> str:
|
|
|
831
1262
|
service_docstr = inspect.getdoc(service_class)
|
|
832
1263
|
service_comment = "\n".join(comment_out(service_docstr)) if service_docstr else ""
|
|
833
1264
|
|
|
834
|
-
rpc_definitions = []
|
|
835
|
-
all_type_definitions = []
|
|
836
|
-
done_messages = set()
|
|
837
|
-
done_enums = set()
|
|
1265
|
+
rpc_definitions: list[str] = []
|
|
1266
|
+
all_type_definitions: list[str] = []
|
|
1267
|
+
done_messages: set[Any] = set()
|
|
1268
|
+
done_enums: set[Any] = set()
|
|
838
1269
|
|
|
839
1270
|
uses_timestamp = False
|
|
840
1271
|
uses_duration = False
|
|
1272
|
+
uses_empty = False
|
|
841
1273
|
|
|
842
|
-
def
|
|
1274
|
+
def check_and_set_well_known_types_for_fields(py_type: Any):
|
|
1275
|
+
"""Check well-known types for field annotations (excludes None/Empty)."""
|
|
843
1276
|
nonlocal uses_timestamp, uses_duration
|
|
844
1277
|
if py_type == datetime.datetime:
|
|
845
1278
|
uses_timestamp = True
|
|
846
1279
|
if py_type == datetime.timedelta:
|
|
847
1280
|
uses_duration = True
|
|
1281
|
+
# Don't check for None here - Optional fields don't use Empty
|
|
848
1282
|
|
|
849
1283
|
for method_name, method in get_rpc_methods(obj):
|
|
850
1284
|
if method.__name__.startswith("_"):
|
|
@@ -854,26 +1288,59 @@ def generate_proto(obj: object, package_name: str = "") -> str:
|
|
|
854
1288
|
request_type = get_request_arg_type(method_sig)
|
|
855
1289
|
response_type = method_sig.return_annotation
|
|
856
1290
|
|
|
1291
|
+
# Validate that we don't have AsyncIterator[None] which doesn't make any practical sense
|
|
1292
|
+
if is_stream_type(request_type):
|
|
1293
|
+
stream_item_type = get_args(request_type)[0]
|
|
1294
|
+
if is_none_type(stream_item_type):
|
|
1295
|
+
raise TypeError(
|
|
1296
|
+
f"Method '{method_name}' has AsyncIterator[None] as input type, which is not allowed. Streaming Empty messages is meaningless."
|
|
1297
|
+
)
|
|
1298
|
+
|
|
1299
|
+
if is_stream_type(response_type):
|
|
1300
|
+
stream_item_type = get_args(response_type)[0]
|
|
1301
|
+
if is_none_type(stream_item_type):
|
|
1302
|
+
raise TypeError(
|
|
1303
|
+
f"Method '{method_name}' has AsyncIterator[None] as return type, which is not allowed. Streaming Empty messages is meaningless."
|
|
1304
|
+
)
|
|
1305
|
+
|
|
1306
|
+
# Handle NoneType for request and response
|
|
1307
|
+
if is_none_type(request_type):
|
|
1308
|
+
uses_empty = True
|
|
1309
|
+
if is_none_type(response_type):
|
|
1310
|
+
uses_empty = True
|
|
1311
|
+
|
|
857
1312
|
# Recursively generate message definitions
|
|
858
|
-
message_types = [
|
|
1313
|
+
message_types = []
|
|
1314
|
+
if not is_none_type(request_type):
|
|
1315
|
+
message_types.append(request_type)
|
|
1316
|
+
if not is_none_type(response_type):
|
|
1317
|
+
message_types.append(response_type)
|
|
1318
|
+
|
|
859
1319
|
while message_types:
|
|
860
|
-
mt = message_types.pop()
|
|
861
|
-
if mt in done_messages:
|
|
1320
|
+
mt: type[Message] | type[ServicerContext] | None = message_types.pop()
|
|
1321
|
+
if mt in done_messages or mt is ServicerContext or mt is None:
|
|
862
1322
|
continue
|
|
863
1323
|
done_messages.add(mt)
|
|
864
1324
|
|
|
865
1325
|
if is_stream_type(mt):
|
|
866
1326
|
item_type = get_args(mt)[0]
|
|
867
|
-
|
|
1327
|
+
if not is_none_type(item_type):
|
|
1328
|
+
message_types.append(item_type)
|
|
868
1329
|
continue
|
|
869
1330
|
|
|
1331
|
+
mt = cast(type[Message], mt)
|
|
1332
|
+
|
|
870
1333
|
for _, field_info in mt.model_fields.items():
|
|
871
1334
|
t = field_info.annotation
|
|
872
1335
|
if is_union_type(t):
|
|
873
1336
|
for sub_t in flatten_union(t):
|
|
874
|
-
|
|
1337
|
+
check_and_set_well_known_types_for_fields(
|
|
1338
|
+
sub_t
|
|
1339
|
+
) # Use the field-specific version
|
|
875
1340
|
else:
|
|
876
|
-
|
|
1341
|
+
check_and_set_well_known_types_for_fields(
|
|
1342
|
+
t
|
|
1343
|
+
) # Use the field-specific version
|
|
877
1344
|
|
|
878
1345
|
msg_def, refs = generate_message_definition(mt, done_enums, done_messages)
|
|
879
1346
|
mt_doc = inspect.getdoc(mt)
|
|
@@ -898,16 +1365,47 @@ def generate_proto(obj: object, package_name: str = "") -> str:
|
|
|
898
1365
|
for comment_line in comment_out(method_docstr):
|
|
899
1366
|
rpc_definitions.append(comment_line)
|
|
900
1367
|
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
1368
|
+
input_type = request_type
|
|
1369
|
+
input_is_stream = is_stream_type(input_type)
|
|
1370
|
+
output_is_stream = is_stream_type(response_type)
|
|
1371
|
+
|
|
1372
|
+
if input_is_stream:
|
|
1373
|
+
input_msg_type = get_args(input_type)[0]
|
|
1374
|
+
else:
|
|
1375
|
+
input_msg_type = input_type
|
|
1376
|
+
|
|
1377
|
+
if output_is_stream:
|
|
1378
|
+
output_msg_type = get_args(response_type)[0]
|
|
1379
|
+
else:
|
|
1380
|
+
output_msg_type = response_type
|
|
1381
|
+
|
|
1382
|
+
# Handle NoneType by using Empty (but we've already validated no streaming of Empty above)
|
|
1383
|
+
if input_msg_type is None or input_msg_type is ServicerContext:
|
|
1384
|
+
input_str = "google.protobuf.Empty" # No need to check for stream since we validated above
|
|
1385
|
+
if input_msg_type is ServicerContext:
|
|
1386
|
+
uses_empty = True
|
|
1387
|
+
else:
|
|
1388
|
+
input_str = (
|
|
1389
|
+
f"stream {input_msg_type.__name__}"
|
|
1390
|
+
if input_is_stream
|
|
1391
|
+
else input_msg_type.__name__
|
|
905
1392
|
)
|
|
1393
|
+
|
|
1394
|
+
if output_msg_type is None or output_msg_type is ServicerContext:
|
|
1395
|
+
output_str = "google.protobuf.Empty" # No need to check for stream since we validated above
|
|
1396
|
+
if output_msg_type is ServicerContext:
|
|
1397
|
+
uses_empty = True
|
|
906
1398
|
else:
|
|
907
|
-
|
|
908
|
-
f"
|
|
1399
|
+
output_str = (
|
|
1400
|
+
f"stream {output_msg_type.__name__}"
|
|
1401
|
+
if output_is_stream
|
|
1402
|
+
else output_msg_type.__name__
|
|
909
1403
|
)
|
|
910
1404
|
|
|
1405
|
+
rpc_definitions.append(
|
|
1406
|
+
f"rpc {method_name} ({input_str}) returns ({output_str});"
|
|
1407
|
+
)
|
|
1408
|
+
|
|
911
1409
|
if not package_name:
|
|
912
1410
|
if service_name.endswith("Service"):
|
|
913
1411
|
package_name = service_name[: -len("Service")]
|
|
@@ -915,11 +1413,13 @@ def generate_proto(obj: object, package_name: str = "") -> str:
|
|
|
915
1413
|
package_name = service_name
|
|
916
1414
|
package_name = package_name.lower() + ".v1"
|
|
917
1415
|
|
|
918
|
-
imports = []
|
|
1416
|
+
imports: list[str] = []
|
|
919
1417
|
if uses_timestamp:
|
|
920
1418
|
imports.append('import "google/protobuf/timestamp.proto";')
|
|
921
1419
|
if uses_duration:
|
|
922
1420
|
imports.append('import "google/protobuf/duration.proto";')
|
|
1421
|
+
if uses_empty:
|
|
1422
|
+
imports.append('import "google/protobuf/empty.proto";')
|
|
923
1423
|
|
|
924
1424
|
import_block = "\n".join(imports)
|
|
925
1425
|
if import_block:
|
|
@@ -939,98 +1439,168 @@ service {service_name} {{
|
|
|
939
1439
|
return proto_definition
|
|
940
1440
|
|
|
941
1441
|
|
|
942
|
-
def generate_grpc_code(
|
|
1442
|
+
def generate_grpc_code(proto_path: Path) -> types.ModuleType | None:
    """
    Compile proto_path with protoc's gRPC Python output and import the result.

    protoc is invoked from the directory that holds the .proto file, so the
    generated ``<stem>_pb2_grpc.py`` is written next to it. That directory is
    appended to ``sys.path`` (when not already present) before the generated
    module is loaded.

    Returns the loaded module, or None if protoc exits with a non-zero status
    or no import spec/loader can be created for the generated file.
    Raises FileNotFoundError if proto_path is not an existing file.
    """
    # The .proto must already exist.
    if not proto_path.is_file():
        raise FileNotFoundError(f"{proto_path!r} does not exist")

    # Output goes to the directory containing the .proto.
    resolved = proto_path.resolve()
    target_dir = resolved.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    target_dir_str = str(target_dir)

    # Include path for the well-known protos shipped inside grpc_tools.
    include_dir = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
    command = [
        "protoc",  # Dummy program name (required for protoc.main)
        "-I.",
        f"-I{include_dir}",
        f"--grpc_python_out={target_dir_str}",
        resolved.name,
    ]

    # "-I." and the bare file name are relative, so run protoc from target_dir
    # and always restore the previous working directory afterwards.
    previous_cwd = os.getcwd()
    os.chdir(target_dir_str)
    try:
        exit_code = protoc.main(command)
    finally:
        os.chdir(previous_cwd)
    if exit_code != 0:
        return None

    # Generated file: "<stem>_pb2_grpc.py", e.g. "foo_pb2_grpc.py".
    module_name = resolved.stem + "_pb2_grpc"
    module_file = target_dir / f"{module_name}.py"

    # Make the output directory importable.
    if target_dir_str not in sys.path:
        sys.path.append(target_dir_str)

    # Load the generated file as a module and hand it back.
    spec = importlib.util.spec_from_file_location(module_name, str(module_file))
    if spec is None or spec.loader is None:
        return None

    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
|
|
969
1494
|
|
|
970
1495
|
|
|
971
|
-
def generate_connecpy_code(
|
|
972
|
-
proto_file: str, connecpy_out: str
|
|
973
|
-
) -> types.ModuleType | None:
|
|
1496
|
+
def generate_connecpy_code(proto_path: Path) -> types.ModuleType | None:
|
|
974
1497
|
"""
|
|
975
|
-
|
|
976
|
-
|
|
1498
|
+
Run protoc with the Connecpy plugin to generate Python Connecpy code from proto_path.
|
|
1499
|
+
Writes foo_connecpy.py next to proto_path, then imports and returns that module.
|
|
977
1500
|
"""
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
1501
|
+
# 1) Ensure the .proto exists
|
|
1502
|
+
if not proto_path.is_file():
|
|
1503
|
+
raise FileNotFoundError(f"{proto_path!r} does not exist")
|
|
1504
|
+
|
|
1505
|
+
# 2) Determine output directory (same as the .proto's parent)
|
|
1506
|
+
proto_path = proto_path.resolve()
|
|
1507
|
+
out_dir = proto_path.parent
|
|
1508
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
|
1509
|
+
|
|
1510
|
+
# 3) Build and run the protoc command
|
|
1511
|
+
out_str = str(out_dir)
|
|
1512
|
+
well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
|
|
1513
|
+
args = [
|
|
1514
|
+
"protoc", # Dummy program name (required for protoc.main)
|
|
1515
|
+
"-I.",
|
|
1516
|
+
f"-I{well_known_path}",
|
|
1517
|
+
f"--connecpy_out={out_str}",
|
|
1518
|
+
proto_path.name,
|
|
1519
|
+
]
|
|
982
1520
|
|
|
983
|
-
|
|
984
|
-
|
|
1521
|
+
current_dir = os.getcwd()
|
|
1522
|
+
os.chdir(str(out_dir))
|
|
1523
|
+
try:
|
|
1524
|
+
if protoc.main(args) != 0:
|
|
1525
|
+
return None
|
|
1526
|
+
finally:
|
|
1527
|
+
os.chdir(current_dir)
|
|
985
1528
|
|
|
986
|
-
|
|
987
|
-
|
|
1529
|
+
# 4) Locate the generated file
|
|
1530
|
+
base_name = proto_path.stem # "foo"
|
|
1531
|
+
generated_filename = f"{base_name}_connecpy.py" # "foo_connecpy.py"
|
|
1532
|
+
generated_filepath = out_dir / generated_filename
|
|
988
1533
|
|
|
1534
|
+
# 5) Add out_dir to sys.path so we can import by filename
|
|
1535
|
+
if out_str not in sys.path:
|
|
1536
|
+
sys.path.append(out_str)
|
|
1537
|
+
|
|
1538
|
+
# 6) Load and return the module
|
|
989
1539
|
spec = importlib.util.spec_from_file_location(
|
|
990
|
-
|
|
1540
|
+
base_name + "_connecpy", str(generated_filepath)
|
|
991
1541
|
)
|
|
992
|
-
if spec is None:
|
|
993
|
-
return None
|
|
994
|
-
connecpy_module = importlib.util.module_from_spec(spec)
|
|
995
|
-
if spec.loader is None:
|
|
1542
|
+
if spec is None or spec.loader is None:
|
|
996
1543
|
return None
|
|
997
|
-
spec.loader.exec_module(connecpy_module)
|
|
998
1544
|
|
|
999
|
-
|
|
1545
|
+
module = importlib.util.module_from_spec(spec)
|
|
1546
|
+
spec.loader.exec_module(module)
|
|
1547
|
+
return module
|
|
1000
1548
|
|
|
1001
1549
|
|
|
1002
|
-
def generate_pb_code(
|
|
1003
|
-
proto_file: str, python_out: str, pyi_out: str
|
|
1004
|
-
) -> types.ModuleType | None:
|
|
1550
|
+
def generate_pb_code(proto_path: Path) -> types.ModuleType | None:
|
|
1005
1551
|
"""
|
|
1006
|
-
|
|
1007
|
-
|
|
1552
|
+
Run protoc to generate Python gRPC code from proto_path.
|
|
1553
|
+
Writes foo_pb2.py and foo_pb2.pyi next to proto_path, then imports and returns the pb2 module.
|
|
1008
1554
|
"""
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1555
|
+
# 1) Make sure proto_path exists
|
|
1556
|
+
if not proto_path.is_file():
|
|
1557
|
+
raise FileNotFoundError(f"{proto_path!r} does not exist")
|
|
1558
|
+
|
|
1559
|
+
# 2) Determine output directory (same as proto file)
|
|
1560
|
+
proto_path = proto_path.resolve()
|
|
1561
|
+
out_dir = proto_path.parent
|
|
1562
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
|
1563
|
+
|
|
1564
|
+
# 3) Build and run protoc command
|
|
1565
|
+
out_str = str(out_dir)
|
|
1566
|
+
well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
|
|
1567
|
+
args = [
|
|
1568
|
+
"protoc", # Dummy program name (required for protoc.main)
|
|
1569
|
+
"-I.",
|
|
1570
|
+
f"-I{well_known_path}",
|
|
1571
|
+
f"--python_out={out_str}",
|
|
1572
|
+
f"--pyi_out={out_str}",
|
|
1573
|
+
proto_path.name,
|
|
1574
|
+
]
|
|
1575
|
+
|
|
1576
|
+
current_dir = os.getcwd()
|
|
1577
|
+
os.chdir(str(out_dir))
|
|
1578
|
+
try:
|
|
1579
|
+
if protoc.main(args) != 0:
|
|
1580
|
+
return None
|
|
1581
|
+
finally:
|
|
1582
|
+
os.chdir(current_dir)
|
|
1013
1583
|
|
|
1014
|
-
|
|
1015
|
-
|
|
1584
|
+
# 4) Locate generated file
|
|
1585
|
+
base_name = proto_path.stem # e.g. "foo"
|
|
1586
|
+
generated_file = out_dir / f"{base_name}_pb2.py"
|
|
1016
1587
|
|
|
1017
|
-
|
|
1018
|
-
|
|
1588
|
+
# 5) Add to sys.path if needed
|
|
1589
|
+
if out_str not in sys.path:
|
|
1590
|
+
sys.path.append(out_str)
|
|
1019
1591
|
|
|
1592
|
+
# 6) Import it
|
|
1020
1593
|
spec = importlib.util.spec_from_file_location(
|
|
1021
|
-
|
|
1594
|
+
base_name + "_pb2", str(generated_file)
|
|
1022
1595
|
)
|
|
1023
|
-
if spec is None:
|
|
1596
|
+
if spec is None or spec.loader is None:
|
|
1024
1597
|
return None
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
spec.loader.exec_module(pb2_module)
|
|
1029
|
-
|
|
1030
|
-
return pb2_module
|
|
1598
|
+
module = importlib.util.module_from_spec(spec)
|
|
1599
|
+
spec.loader.exec_module(module)
|
|
1600
|
+
return module
|
|
1031
1601
|
|
|
1032
1602
|
|
|
1033
|
-
def get_request_arg_type(sig):
|
|
1603
|
+
def get_request_arg_type(sig: inspect.Signature) -> Any:
|
|
1034
1604
|
"""Return the type annotation of the first parameter (request) of a method."""
|
|
1035
1605
|
num_of_params = len(sig.parameters)
|
|
1036
1606
|
if not (num_of_params == 1 or num_of_params == 2):
|
|
@@ -1038,7 +1608,7 @@ def get_request_arg_type(sig):
|
|
|
1038
1608
|
return tuple(sig.parameters.values())[0].annotation
|
|
1039
1609
|
|
|
1040
1610
|
|
|
1041
|
-
def get_rpc_methods(obj: object) -> list[tuple[str,
|
|
1611
|
+
def get_rpc_methods(obj: object) -> list[tuple[str, Callable[..., Any]]]:
|
|
1042
1612
|
"""
|
|
1043
1613
|
Retrieve the list of RPC methods from a service object.
|
|
1044
1614
|
The method name is converted to PascalCase for .proto compatibility.
|
|
@@ -1059,68 +1629,393 @@ def is_skip_generation() -> bool:
|
|
|
1059
1629
|
return os.getenv("PYDANTIC_RPC_SKIP_GENERATION", "false").lower() == "true"
|
|
1060
1630
|
|
|
1061
1631
|
|
|
1062
|
-
def
|
|
1632
|
+
def get_reserved_fields_count() -> int:
    """
    Read the reserved-field count from ``PYDANTIC_RPC_RESERVED_FIELDS``.

    Returns:
        The variable parsed as an integer and clamped to be non-negative;
        0 when the variable is unset or is not a valid integer.
    """
    raw_value = os.getenv("PYDANTIC_RPC_RESERVED_FIELDS", "0")
    try:
        parsed = int(raw_value)
    except ValueError:
        return 0
    return parsed if parsed > 0 else 0
|
|
1638
|
+
|
|
1639
|
+
|
|
1640
|
+
def generate_and_compile_proto(
    obj: object,
    package_name: str = "",
    existing_proto_path: Path | None = None,
) -> tuple[Any, Any]:
    """
    Return the ``(pb2_grpc_module, pb2_module)`` pair for a service object.

    When is_skip_generation() is true, ``<classname>_pb2`` and
    ``<classname>_pb2_grpc`` (class name lower-cased) are imported and returned
    if both imports succeed. Otherwise the .proto file is compiled: either the
    given ``existing_proto_path`` is used as-is, or a .proto is generated from
    ``obj`` and written to the path returned by get_proto_path().

    Raises:
        Exception: if generate_pb_code() or generate_grpc_code() returns None.
    """
    module_stem = obj.__class__.__name__.lower()

    if is_skip_generation():
        import importlib

        def _import_or_none(module_name: str) -> Any:
            # A missing pre-generated module is not an error; generation is the fallback.
            try:
                return importlib.import_module(module_name)
            except ImportError:
                return None

        pb2_module = _import_or_none(f"{module_stem}_pb2")
        pb2_grpc_module = _import_or_none(f"{module_stem}_pb2_grpc")
        if pb2_module is not None and pb2_grpc_module is not None:
            return pb2_grpc_module, pb2_module
        # Otherwise fall through and generate/compile the proto files.

    if existing_proto_path:
        # Caller supplied a .proto file: compile it without regenerating.
        proto_file_path = existing_proto_path
    else:
        proto_source = generate_proto(obj, package_name)
        proto_file_path = get_proto_path(f"{module_stem}.proto")
        # NOTE(review): the write belongs to this branch; it is the only place
        # where generated proto text exists.
        _ = proto_file_path.write_text(proto_source, encoding="utf-8")

    pb2 = generate_pb_code(proto_file_path)
    if pb2 is None:
        raise Exception("Generating pb code")

    pb2_grpc = generate_grpc_code(proto_file_path)
    if pb2_grpc is None:
        raise Exception("Generating grpc code")

    return pb2_grpc, pb2
|
|
1091
1691
|
|
|
1092
1692
|
|
|
1093
|
-
def
|
|
1693
|
+
def get_proto_path(proto_filename: str) -> Path:
    """
    Build the path of a .proto file inside the configured output directory.

    The directory comes from ``PYDANTIC_RPC_PROTO_PATH`` (``~`` and environment
    variables are expanded) or defaults to the current working directory. It is
    made absolute and created if it does not exist.

    Returns:
        ``<directory> / proto_filename``.

    Raises:
        RuntimeError: if the directory does not exist and cannot be created.
        NotADirectoryError: if the path exists but is not a directory.
        PermissionError: if the directory is not writable.
    """
    configured = os.getenv("PYDANTIC_RPC_PROTO_PATH", None)
    start = Path.cwd() if configured is None else Path(configured)

    expanded = os.path.expandvars(os.path.expanduser(str(start)))
    directory = Path(expanded).resolve()

    if directory.exists():
        if not directory.is_dir():
            raise NotADirectoryError(f"{directory!r} exists but is not a directory")
    else:
        try:
            directory.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            raise RuntimeError(f"Unable to create directory {directory!r}: {e}") from e

    if not os.access(directory, os.W_OK):
        raise PermissionError(f"No write permission for directory {directory!r}")

    return directory / proto_filename
|
|
1716
|
+
|
|
1717
|
+
|
|
1718
|
+
def generate_and_compile_proto_using_connecpy(
    obj: object,
    package_name: str = "",
    existing_proto_path: Path | None = None,
) -> tuple[Any, Any]:
    """
    Return the ``(connecpy_module, pb2_module)`` pair for a service object.

    When is_skip_generation() is true, ``<classname>_pb2`` and
    ``<classname>_connecpy`` (class name lower-cased) are imported and returned
    if both imports succeed. Otherwise the .proto file is compiled: either the
    given ``existing_proto_path`` is used as-is, or a .proto is generated from
    ``obj`` and written to the path returned by get_proto_path().

    Raises:
        Exception: if generate_pb_code() or generate_connecpy_code() returns None.
    """
    module_stem = obj.__class__.__name__.lower()

    if is_skip_generation():
        import importlib

        def _import_or_none(module_name: str) -> Any:
            # A missing pre-generated module is not an error; generation is the fallback.
            try:
                return importlib.import_module(module_name)
            except ImportError:
                return None

        pb2_module = _import_or_none(f"{module_stem}_pb2")
        connecpy_module = _import_or_none(f"{module_stem}_connecpy")
        if connecpy_module is not None and pb2_module is not None:
            return connecpy_module, pb2_module
        # Otherwise fall through and generate/compile the proto files.

    if existing_proto_path:
        # Caller supplied a .proto file: compile it without regenerating.
        proto_file_path = existing_proto_path
    else:
        proto_source = generate_proto(obj, package_name)
        proto_file_path = get_proto_path(f"{module_stem}.proto")
        _ = proto_file_path.write_text(proto_source, encoding="utf-8")

    pb2 = generate_pb_code(proto_file_path)
    if pb2 is None:
        raise Exception("Generating pb code")

    connecpy = generate_connecpy_code(proto_file_path)
    if connecpy is None:
        raise Exception("Generating Connecpy code")

    return connecpy, pb2
|
|
1122
1769
|
|
|
1123
1770
|
|
|
1771
|
+
def is_combined_proto_enabled() -> bool:
    """Return True when ``PYDANTIC_RPC_COMBINED_PROTO`` is set to "true" (case-insensitive)."""
    setting = os.environ.get("PYDANTIC_RPC_COMBINED_PROTO", "false")
    return setting.lower() == "true"
|
|
1774
|
+
|
|
1775
|
+
|
|
1776
|
+
def generate_combined_proto(
    *services: object, package_name: str = "combined.v1"
) -> str:
    """
    Generate a single proto3 definition covering several service objects.

    For every service, each method reported by get_rpc_methods() (skipping
    names that start with "_") becomes an ``rpc`` entry. Request/response
    message types, and the message/enum types they reference, are emitted once
    across all services. ``google.protobuf`` imports for Timestamp, Duration
    and Empty are added only when a datetime/timedelta field or a None
    request/response is encountered.

    Args:
        *services: service instances whose classes define the RPC methods.
        package_name: value of the ``package`` statement.

    Returns:
        The text of the combined .proto file.

    Raises:
        TypeError: if a method streams ``None`` as its request or response item.
    """
    import datetime

    all_type_definitions: list[str] = []
    all_service_definitions: list[str] = []
    # Shared across services so a type used by several services is defined once.
    done_messages: set[Any] = set()
    done_enums: set[Any] = set()

    uses_timestamp = False
    uses_duration = False
    uses_empty = False

    def check_and_set_well_known_types_for_fields(py_type: Any):
        """Check well-known types for field annotations (excludes None/Empty)."""
        nonlocal uses_timestamp, uses_duration
        if py_type == datetime.datetime:
            uses_timestamp = True
        if py_type == datetime.timedelta:
            uses_duration = True

    # Process each service
    for service_obj in services:
        service_class = service_obj.__class__
        service_name = service_class.__name__
        service_docstr = inspect.getdoc(service_class)
        service_comment = (
            "\n".join(comment_out(service_docstr)) if service_docstr else ""
        )

        service_rpc_definitions: list[str] = []

        for method_name, method in get_rpc_methods(service_obj):
            if method.__name__.startswith("_"):
                continue

            method_sig = inspect.signature(method)
            request_type = get_request_arg_type(method_sig)
            response_type = method_sig.return_annotation

            # Validate stream types: a stream of None has no Empty equivalent here.
            if is_stream_type(request_type):
                stream_item_type = get_args(request_type)[0]
                if is_none_type(stream_item_type):
                    raise TypeError(
                        f"Method '{method_name}' has AsyncIterator[None] as input type, which is not allowed."
                    )

            if is_stream_type(response_type):
                stream_item_type = get_args(response_type)[0]
                if is_none_type(stream_item_type):
                    raise TypeError(
                        f"Method '{method_name}' has AsyncIterator[None] as return type, which is not allowed."
                    )

            # A None request or response is rendered as google.protobuf.Empty.
            if is_none_type(request_type) or is_none_type(response_type):
                uses_empty = True

            # Work list of message types still to be emitted for this method.
            message_types = [
                candidate
                for candidate in (request_type, response_type)
                if not is_none_type(candidate)
            ]

            while message_types:
                mt: type[Message] | type[ServicerContext] | None = message_types.pop()
                if mt in done_messages or mt is ServicerContext or mt is None:
                    continue
                done_messages.add(mt)

                if is_stream_type(mt):
                    # Streams are unwrapped; only the item type is a message.
                    item_type = get_args(mt)[0]
                    if not is_none_type(item_type):
                        message_types.append(item_type)
                    continue

                mt = cast(type[Message], mt)

                for field_info in mt.model_fields.values():
                    t = field_info.annotation
                    if is_union_type(t):
                        for sub_t in flatten_union(t):
                            check_and_set_well_known_types_for_fields(sub_t)
                    else:
                        check_and_set_well_known_types_for_fields(t)

                msg_def, refs = generate_message_definition(
                    mt, done_enums, done_messages
                )
                mt_doc = inspect.getdoc(mt)
                if mt_doc:
                    all_type_definitions.extend(comment_out(mt_doc))

                all_type_definitions.append(msg_def)
                all_type_definitions.append("")

                for r in refs:
                    if is_enum_type(r) and r not in done_enums:
                        done_enums.add(r)
                        enum_def = generate_enum_definition(r)
                        all_type_definitions.append(enum_def)
                        all_type_definitions.append("")
                    elif issubclass(r, Message) and r not in done_messages:
                        message_types.append(r)

            # Generate RPC definition
            method_docstr = inspect.getdoc(method)
            if method_docstr:
                service_rpc_definitions.extend(comment_out(method_docstr))

            input_type = request_type
            input_is_stream = is_stream_type(input_type)
            output_is_stream = is_stream_type(response_type)

            input_msg_type = get_args(input_type)[0] if input_is_stream else input_type
            output_msg_type = (
                get_args(response_type)[0] if output_is_stream else response_type
            )

            # Handle NoneType by using Empty
            if input_msg_type is None or input_msg_type is ServicerContext:
                input_str = "google.protobuf.Empty"  # No need to check for stream since we validated above
                if input_msg_type is ServicerContext:
                    uses_empty = True
            else:
                input_str = (
                    f"stream {input_msg_type.__name__}"
                    if input_is_stream
                    else input_msg_type.__name__
                )

            if output_msg_type is None or output_msg_type is ServicerContext:
                output_str = "google.protobuf.Empty"  # No need to check for stream since we validated above
                if output_msg_type is ServicerContext:
                    uses_empty = True
            else:
                output_str = (
                    f"stream {output_msg_type.__name__}"
                    if output_is_stream
                    else output_msg_type.__name__
                )

            service_rpc_definitions.append(
                f"rpc {method_name} ({input_str}) returns ({output_str});"
            )

        # Create service definition
        service_def_lines: list[str] = []
        if service_comment:
            service_def_lines.append(service_comment)
        service_def_lines.append(f"service {service_name} {{")
        service_def_lines.extend([f"  {line}" for line in service_rpc_definitions])
        service_def_lines.append("}")
        service_def_lines.append("")

        all_service_definitions.extend(service_def_lines)

    # Build imports
    imports: list[str] = []
    if uses_timestamp:
        imports.append('import "google/protobuf/timestamp.proto";')
    if uses_duration:
        imports.append('import "google/protobuf/duration.proto";')
    if uses_empty:
        imports.append('import "google/protobuf/empty.proto";')

    import_block = "\n".join(imports)
    if import_block:
        import_block += "\n"

    # The entries of all_service_definitions are individual lines, so they must
    # be joined with newlines. Concatenating them directly glued each line to
    # the next, so a "//" docstring comment swallowed the following
    # "service X {" or "rpc ..." line.
    services_block = "\n".join(all_service_definitions)

    # Combine everything
    proto_definition = f"""syntax = "proto3";

package {package_name};

{import_block}{services_block}
{indent_lines(all_type_definitions, "")}
"""
    return proto_definition
|
|
1969
|
+
|
|
1970
|
+
|
|
1971
|
+
def get_combined_proto_filename() -> str:
    """
    Return the file name used for the combined .proto file.

    Reads ``PYDANTIC_RPC_COMBINED_PROTO_FILENAME`` and falls back to
    "combined_services.proto" when the variable is unset.
    """
    default_name = "combined_services.proto"
    return os.environ.get("PYDANTIC_RPC_COMBINED_PROTO_FILENAME", default_name)
|
|
1974
|
+
|
|
1975
|
+
|
|
1976
|
+
def generate_combined_descriptor_set(
    *services: object, output_path: Path | None = None
) -> bytes:
    """
    Generate a combined protobuf descriptor set from multiple services.

    The combined .proto text is written to get_proto_path(<combined filename>)
    and protoc is run with ``--descriptor_set_out`` and ``--include_imports``.

    Args:
        *services: service instances passed to generate_combined_proto().
        output_path: where protoc writes the descriptor set. A relative path is
            interpreted against the caller's current working directory. When
            None, the descriptor set is written next to the combined .proto
            file as ``<proto file name>.desc``.

    Returns:
        The raw bytes of the generated descriptor set.

    Raises:
        RuntimeError: if protoc exits with a non-zero status.
    """
    filename = get_combined_proto_filename()
    proto_file_path = get_proto_path(filename)

    if output_path is None:
        # The old default reused the .proto path itself, so protoc overwrote
        # the freshly written .proto source with the binary descriptor set.
        output_path = proto_file_path.with_name(proto_file_path.name + ".desc")
    # Anchor the path before chdir so protoc writes the file that is read back below.
    output_path = Path(output_path).absolute()

    # Generate combined proto file
    combined_proto = generate_combined_proto(*services)
    with proto_file_path.open(mode="w", encoding="utf-8") as f:
        _ = f.write(combined_proto)

    # Generate descriptor set using protoc
    out_str = str(proto_file_path.parent)
    well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
    args = [
        "protoc",  # placeholder program name expected by protoc.main
        f"-I{out_str}",
        f"-I{well_known_path}",
        f"--descriptor_set_out={output_path}",
        "--include_imports",
        proto_file_path.name,
    ]

    # The input is passed by bare file name, so protoc is run from its directory.
    current_dir = os.getcwd()
    os.chdir(out_str)
    try:
        if protoc.main(args) != 0:
            raise RuntimeError("Failed to generate combined descriptor set")
    finally:
        os.chdir(current_dir)

    # Read and return the descriptor set
    return output_path.read_bytes()
|
|
2017
|
+
|
|
2018
|
+
|
|
1124
2019
|
###############################################################################
|
|
1125
2020
|
# 4. Server Implementations
|
|
1126
2021
|
###############################################################################
|
|
@@ -1129,13 +2024,13 @@ def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = "
|
|
|
1129
2024
|
class Server:
|
|
1130
2025
|
"""A simple gRPC server that uses ThreadPoolExecutor for concurrency."""
|
|
1131
2026
|
|
|
1132
|
-
def __init__(self, max_workers: int = 8, *interceptors) -> None:
|
|
1133
|
-
self._server = grpc.server(
|
|
2027
|
+
def __init__(self, max_workers: int = 8, *interceptors: Any) -> None:
|
|
2028
|
+
self._server: grpc.Server = grpc.server(
|
|
1134
2029
|
futures.ThreadPoolExecutor(max_workers), interceptors=interceptors
|
|
1135
2030
|
)
|
|
1136
|
-
self._service_names = []
|
|
1137
|
-
self._package_name = ""
|
|
1138
|
-
self._port = 50051
|
|
2031
|
+
self._service_names: list[str] = []
|
|
2032
|
+
self._package_name: str = ""
|
|
2033
|
+
self._port: int = 50051
|
|
1139
2034
|
|
|
1140
2035
|
def set_package_name(self, package_name: str):
|
|
1141
2036
|
"""Set the package name for .proto generation."""
|
|
@@ -1147,10 +2042,15 @@ class Server:
|
|
|
1147
2042
|
|
|
1148
2043
|
def mount(self, obj: object, package_name: str = ""):
|
|
1149
2044
|
"""Generate and compile proto files, then mount the service implementation."""
|
|
1150
|
-
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
|
|
2045
|
+
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
|
|
2046
|
+
None,
|
|
2047
|
+
None,
|
|
2048
|
+
)
|
|
1151
2049
|
self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
|
|
1152
2050
|
|
|
1153
|
-
def mount_using_pb2_modules(
|
|
2051
|
+
def mount_using_pb2_modules(
|
|
2052
|
+
self, pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
2053
|
+
):
|
|
1154
2054
|
"""Connect the compiled gRPC modules with the service implementation."""
|
|
1155
2055
|
concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
|
|
1156
2056
|
service_name = obj.__class__.__name__
|
|
@@ -1163,7 +2063,7 @@ class Server:
|
|
|
1163
2063
|
].full_name
|
|
1164
2064
|
self._service_names.append(full_service_name)
|
|
1165
2065
|
|
|
1166
|
-
def run(self, *objs):
|
|
2066
|
+
def run(self, *objs: object):
|
|
1167
2067
|
"""
|
|
1168
2068
|
Mount multiple services and run the gRPC server with reflection and health check.
|
|
1169
2069
|
Press Ctrl+C or send SIGTERM to stop.
|
|
@@ -1183,14 +2083,16 @@ class Server:
|
|
|
1183
2083
|
self._server.add_insecure_port(f"[::]:{self._port}")
|
|
1184
2084
|
self._server.start()
|
|
1185
2085
|
|
|
1186
|
-
def handle_signal(signal, frame):
|
|
2086
|
+
def handle_signal(signum: signal.Signals, frame: Any):
|
|
2087
|
+
_ = signum
|
|
2088
|
+
_ = frame
|
|
1187
2089
|
print("Received shutdown signal...")
|
|
1188
2090
|
self._server.stop(grace=10)
|
|
1189
2091
|
print("gRPC server shutdown.")
|
|
1190
2092
|
sys.exit(0)
|
|
1191
2093
|
|
|
1192
|
-
signal.signal(signal.SIGINT, handle_signal)
|
|
1193
|
-
signal.signal(signal.SIGTERM, handle_signal)
|
|
2094
|
+
_ = signal.signal(signal.SIGINT, handle_signal) # pyright:ignore[reportArgumentType]
|
|
2095
|
+
_ = signal.signal(signal.SIGTERM, handle_signal) # pyright:ignore[reportArgumentType]
|
|
1194
2096
|
|
|
1195
2097
|
print("gRPC server is running...")
|
|
1196
2098
|
while True:
|
|
@@ -1200,11 +2102,11 @@ class Server:
|
|
|
1200
2102
|
class AsyncIOServer:
|
|
1201
2103
|
"""An async gRPC server using asyncio."""
|
|
1202
2104
|
|
|
1203
|
-
def __init__(self, *interceptors) -> None:
|
|
1204
|
-
self._server = grpc.aio.server(interceptors=interceptors)
|
|
1205
|
-
self._service_names = []
|
|
1206
|
-
self._package_name = ""
|
|
1207
|
-
self._port = 50051
|
|
2105
|
+
def __init__(self, *interceptors: grpc.ServerInterceptor) -> None:
|
|
2106
|
+
self._server: grpc.aio.Server = grpc.aio.server(interceptors=interceptors)
|
|
2107
|
+
self._service_names: list[str] = []
|
|
2108
|
+
self._package_name: str = ""
|
|
2109
|
+
self._port: int = 50051
|
|
1208
2110
|
|
|
1209
2111
|
def set_package_name(self, package_name: str):
|
|
1210
2112
|
"""Set the package name for .proto generation."""
|
|
@@ -1216,10 +2118,15 @@ class AsyncIOServer:
|
|
|
1216
2118
|
|
|
1217
2119
|
def mount(self, obj: object, package_name: str = ""):
|
|
1218
2120
|
"""Generate and compile proto files, then mount the service implementation (async)."""
|
|
1219
|
-
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
|
|
2121
|
+
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
|
|
2122
|
+
None,
|
|
2123
|
+
None,
|
|
2124
|
+
)
|
|
1220
2125
|
self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
|
|
1221
2126
|
|
|
1222
|
-
def mount_using_pb2_modules(
|
|
2127
|
+
def mount_using_pb2_modules(
|
|
2128
|
+
self, pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
2129
|
+
):
|
|
1223
2130
|
"""Connect the compiled gRPC modules with the async service implementation."""
|
|
1224
2131
|
concreteServiceClass = connect_obj_with_stub_async(
|
|
1225
2132
|
pb2_grpc_module, pb2_module, obj
|
|
@@ -1234,7 +2141,7 @@ class AsyncIOServer:
|
|
|
1234
2141
|
].full_name
|
|
1235
2142
|
self._service_names.append(full_service_name)
|
|
1236
2143
|
|
|
1237
|
-
async def run(self, *objs):
|
|
2144
|
+
async def run(self, *objs: object):
|
|
1238
2145
|
"""
|
|
1239
2146
|
Mount multiple async services and run the gRPC server with reflection and health check.
|
|
1240
2147
|
Press Ctrl+C or send SIGTERM to stop.
|
|
@@ -1251,20 +2158,22 @@ class AsyncIOServer:
|
|
|
1251
2158
|
health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
|
|
1252
2159
|
reflection.enable_server_reflection(SERVICE_NAMES, self._server)
|
|
1253
2160
|
|
|
1254
|
-
self._server.add_insecure_port(f"[::]:{self._port}")
|
|
2161
|
+
_ = self._server.add_insecure_port(f"[::]:{self._port}")
|
|
1255
2162
|
await self._server.start()
|
|
1256
2163
|
|
|
1257
2164
|
shutdown_event = asyncio.Event()
|
|
1258
2165
|
|
|
1259
|
-
def shutdown(signum, frame):
|
|
2166
|
+
def shutdown(signum: signal.Signals, frame: Any):
|
|
2167
|
+
_ = signum
|
|
2168
|
+
_ = frame
|
|
1260
2169
|
print("Received shutdown signal...")
|
|
1261
2170
|
shutdown_event.set()
|
|
1262
2171
|
|
|
1263
2172
|
for s in [signal.SIGTERM, signal.SIGINT]:
|
|
1264
|
-
signal.signal(s, shutdown)
|
|
2173
|
+
_ = signal.signal(s, shutdown) # pyright:ignore[reportArgumentType]
|
|
1265
2174
|
|
|
1266
2175
|
print("gRPC server is running...")
|
|
1267
|
-
await shutdown_event.wait()
|
|
2176
|
+
_ = await shutdown_event.wait()
|
|
1268
2177
|
await self._server.stop(10)
|
|
1269
2178
|
print("gRPC server shutdown.")
|
|
1270
2179
|
|
|
@@ -1275,17 +2184,22 @@ class WSGIApp:
|
|
|
1275
2184
|
Useful for embedding gRPC within an existing WSGI stack.
|
|
1276
2185
|
"""
|
|
1277
2186
|
|
|
1278
|
-
def __init__(self, app):
|
|
1279
|
-
self._app = grpcWSGI(app)
|
|
1280
|
-
self._service_names = []
|
|
1281
|
-
self._package_name = ""
|
|
2187
|
+
def __init__(self, app: Any):
|
|
2188
|
+
self._app: grpcWSGI = grpcWSGI(app)
|
|
2189
|
+
self._service_names: list[str] = []
|
|
2190
|
+
self._package_name: str = ""
|
|
1282
2191
|
|
|
1283
2192
|
def mount(self, obj: object, package_name: str = ""):
|
|
1284
2193
|
"""Generate and compile proto files, then mount the service implementation."""
|
|
1285
|
-
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
|
|
2194
|
+
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
|
|
2195
|
+
None,
|
|
2196
|
+
None,
|
|
2197
|
+
)
|
|
1286
2198
|
self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
|
|
1287
2199
|
|
|
1288
|
-
def mount_using_pb2_modules(
|
|
2200
|
+
def mount_using_pb2_modules(
|
|
2201
|
+
self, pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
2202
|
+
):
|
|
1289
2203
|
"""Connect the compiled gRPC modules with the service implementation."""
|
|
1290
2204
|
concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
|
|
1291
2205
|
service_name = obj.__class__.__name__
|
|
@@ -1298,12 +2212,16 @@ class WSGIApp:
|
|
|
1298
2212
|
].full_name
|
|
1299
2213
|
self._service_names.append(full_service_name)
|
|
1300
2214
|
|
|
1301
|
-
def mount_objs(self, *objs):
|
|
2215
|
+
def mount_objs(self, *objs: object):
|
|
1302
2216
|
"""Mount multiple service objects into this WSGI app."""
|
|
1303
2217
|
for obj in objs:
|
|
1304
2218
|
self.mount(obj, self._package_name)
|
|
1305
2219
|
|
|
1306
|
-
def __call__(
|
|
2220
|
+
def __call__(
|
|
2221
|
+
self,
|
|
2222
|
+
environ: dict[str, Any],
|
|
2223
|
+
start_response: Callable[[str, list[tuple[str, str]]], None],
|
|
2224
|
+
) -> Any:
|
|
1307
2225
|
"""WSGI entry point."""
|
|
1308
2226
|
return self._app(environ, start_response)
|
|
1309
2227
|
|
|
@@ -1314,17 +2232,22 @@ class ASGIApp:
|
|
|
1314
2232
|
Useful for embedding gRPC within an existing ASGI stack.
|
|
1315
2233
|
"""
|
|
1316
2234
|
|
|
1317
|
-
def __init__(self, app):
|
|
1318
|
-
self._app = grpcASGI(app)
|
|
1319
|
-
self._service_names = []
|
|
1320
|
-
self._package_name = ""
|
|
2235
|
+
def __init__(self, app: Any):
|
|
2236
|
+
self._app: grpcASGI = grpcASGI(app)
|
|
2237
|
+
self._service_names: list[str] = []
|
|
2238
|
+
self._package_name: str = ""
|
|
1321
2239
|
|
|
1322
2240
|
def mount(self, obj: object, package_name: str = ""):
|
|
1323
2241
|
"""Generate and compile proto files, then mount the async service implementation."""
|
|
1324
|
-
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
|
|
2242
|
+
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
|
|
2243
|
+
None,
|
|
2244
|
+
None,
|
|
2245
|
+
)
|
|
1325
2246
|
self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
|
|
1326
2247
|
|
|
1327
|
-
def mount_using_pb2_modules(
|
|
2248
|
+
def mount_using_pb2_modules(
|
|
2249
|
+
self, pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
2250
|
+
):
|
|
1328
2251
|
"""Connect the compiled gRPC modules with the async service implementation."""
|
|
1329
2252
|
concreteServiceClass = connect_obj_with_stub_async(
|
|
1330
2253
|
pb2_grpc_module, pb2_module, obj
|
|
@@ -1339,18 +2262,27 @@ class ASGIApp:
|
|
|
1339
2262
|
].full_name
|
|
1340
2263
|
self._service_names.append(full_service_name)
|
|
1341
2264
|
|
|
1342
|
-
def mount_objs(self, *objs):
|
|
2265
|
+
def mount_objs(self, *objs: object):
|
|
1343
2266
|
"""Mount multiple service objects into this ASGI app."""
|
|
1344
2267
|
for obj in objs:
|
|
1345
2268
|
self.mount(obj, self._package_name)
|
|
1346
2269
|
|
|
1347
|
-
async def __call__(
|
|
2270
|
+
async def __call__(
|
|
2271
|
+
self,
|
|
2272
|
+
scope: dict[str, Any],
|
|
2273
|
+
receive: Callable[[], Any],
|
|
2274
|
+
send: Callable[[dict[str, Any]], Any],
|
|
2275
|
+
) -> Any:
|
|
1348
2276
|
"""ASGI entry point."""
|
|
1349
|
-
await self._app(scope, receive, send)
|
|
2277
|
+
_ = await self._app(scope, receive, send)
|
|
2278
|
+
|
|
2279
|
+
|
|
2280
|
+
def get_connecpy_asgi_app_class(connecpy_module: Any, service_name: str):
|
|
2281
|
+
return getattr(connecpy_module, f"{service_name}ASGIApplication")
|
|
1350
2282
|
|
|
1351
2283
|
|
|
1352
|
-
def
|
|
1353
|
-
return getattr(connecpy_module, f"{service_name}
|
|
2284
|
+
def get_connecpy_wsgi_app_class(connecpy_module: Any, service_name: str):
|
|
2285
|
+
return getattr(connecpy_module, f"{service_name}WSGIApplication")
|
|
1354
2286
|
|
|
1355
2287
|
|
|
1356
2288
|
class ConnecpyASGIApp:
|
|
@@ -1359,9 +2291,9 @@ class ConnecpyASGIApp:
|
|
|
1359
2291
|
"""
|
|
1360
2292
|
|
|
1361
2293
|
def __init__(self):
|
|
1362
|
-
self.
|
|
1363
|
-
self._service_names = []
|
|
1364
|
-
self._package_name = ""
|
|
2294
|
+
self._services: list[tuple[Any, str]] = [] # List of (app, path) tuples
|
|
2295
|
+
self._service_names: list[str] = []
|
|
2296
|
+
self._package_name: str = ""
|
|
1365
2297
|
|
|
1366
2298
|
def mount(self, obj: object, package_name: str = ""):
|
|
1367
2299
|
"""Generate and compile proto files, then mount the async service implementation."""
|
|
@@ -1370,28 +2302,59 @@ class ConnecpyASGIApp:
|
|
|
1370
2302
|
)
|
|
1371
2303
|
self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
|
|
1372
2304
|
|
|
1373
|
-
def mount_using_pb2_modules(
|
|
2305
|
+
def mount_using_pb2_modules(
|
|
2306
|
+
self, connecpy_module: Any, pb2_module: Any, obj: object
|
|
2307
|
+
):
|
|
1374
2308
|
"""Connect the compiled connecpy and pb2 modules with the async service implementation."""
|
|
1375
2309
|
concreteServiceClass = connect_obj_with_stub_async_connecpy(
|
|
1376
2310
|
connecpy_module, pb2_module, obj
|
|
1377
2311
|
)
|
|
1378
2312
|
service_name = obj.__class__.__name__
|
|
1379
2313
|
service_impl = concreteServiceClass()
|
|
1380
|
-
|
|
1381
|
-
|
|
2314
|
+
|
|
2315
|
+
# Get the service-specific ASGI application class
|
|
2316
|
+
app_class = get_connecpy_asgi_app_class(connecpy_module, service_name)
|
|
2317
|
+
app = app_class(service=service_impl)
|
|
2318
|
+
|
|
2319
|
+
# Store the app and its path for routing
|
|
2320
|
+
self._services.append((app, app.path))
|
|
2321
|
+
|
|
1382
2322
|
full_service_name = pb2_module.DESCRIPTOR.services_by_name[
|
|
1383
2323
|
service_name
|
|
1384
2324
|
].full_name
|
|
1385
2325
|
self._service_names.append(full_service_name)
|
|
1386
2326
|
|
|
1387
|
-
def mount_objs(self, *objs):
|
|
2327
|
+
def mount_objs(self, *objs: object):
|
|
1388
2328
|
"""Mount multiple service objects into this ASGI app."""
|
|
1389
2329
|
for obj in objs:
|
|
1390
2330
|
self.mount(obj, self._package_name)
|
|
1391
2331
|
|
|
1392
|
-
async def __call__(
|
|
1393
|
-
|
|
1394
|
-
|
|
2332
|
+
async def __call__(
|
|
2333
|
+
self,
|
|
2334
|
+
scope: dict[str, Any],
|
|
2335
|
+
receive: Callable[[], Any],
|
|
2336
|
+
send: Callable[[dict[str, Any]], Any],
|
|
2337
|
+
):
|
|
2338
|
+
"""ASGI entry point with routing for multiple services."""
|
|
2339
|
+
if scope["type"] != "http":
|
|
2340
|
+
await send({"type": "http.response.start", "status": 404})
|
|
2341
|
+
await send({"type": "http.response.body", "body": b"Not Found"})
|
|
2342
|
+
return
|
|
2343
|
+
|
|
2344
|
+
path = scope.get("path", "")
|
|
2345
|
+
|
|
2346
|
+
# Route to the appropriate service based on path
|
|
2347
|
+
for app, service_path in self._services:
|
|
2348
|
+
if path.startswith(service_path):
|
|
2349
|
+
return await app(scope, receive, send)
|
|
2350
|
+
|
|
2351
|
+
# If only one service is mounted, use it as default
|
|
2352
|
+
if len(self._services) == 1:
|
|
2353
|
+
return await self._services[0][0](scope, receive, send)
|
|
2354
|
+
|
|
2355
|
+
# No matching service found
|
|
2356
|
+
await send({"type": "http.response.start", "status": 404})
|
|
2357
|
+
await send({"type": "http.response.body", "body": b"Not Found"})
|
|
1395
2358
|
|
|
1396
2359
|
|
|
1397
2360
|
class ConnecpyWSGIApp:
|
|
@@ -1400,55 +2363,82 @@ class ConnecpyWSGIApp:
|
|
|
1400
2363
|
"""
|
|
1401
2364
|
|
|
1402
2365
|
def __init__(self):
|
|
1403
|
-
self.
|
|
1404
|
-
self._service_names = []
|
|
1405
|
-
self._package_name = ""
|
|
2366
|
+
self._services: list[tuple[Any, str]] = [] # List of (app, path) tuples
|
|
2367
|
+
self._service_names: list[str] = []
|
|
2368
|
+
self._package_name: str = ""
|
|
1406
2369
|
|
|
1407
2370
|
def mount(self, obj: object, package_name: str = ""):
|
|
1408
|
-
"""Generate and compile proto files, then mount the
|
|
2371
|
+
"""Generate and compile proto files, then mount the sync service implementation."""
|
|
1409
2372
|
connecpy_module, pb2_module = generate_and_compile_proto_using_connecpy(
|
|
1410
2373
|
obj, package_name
|
|
1411
2374
|
)
|
|
1412
2375
|
self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
|
|
1413
2376
|
|
|
1414
|
-
def mount_using_pb2_modules(
|
|
1415
|
-
|
|
2377
|
+
def mount_using_pb2_modules(
|
|
2378
|
+
self, connecpy_module: Any, pb2_module: Any, obj: object
|
|
2379
|
+
):
|
|
2380
|
+
"""Connect the compiled connecpy and pb2 modules with the sync service implementation."""
|
|
1416
2381
|
concreteServiceClass = connect_obj_with_stub_connecpy(
|
|
1417
2382
|
connecpy_module, pb2_module, obj
|
|
1418
2383
|
)
|
|
1419
2384
|
service_name = obj.__class__.__name__
|
|
1420
2385
|
service_impl = concreteServiceClass()
|
|
1421
|
-
|
|
1422
|
-
|
|
2386
|
+
|
|
2387
|
+
# Get the service-specific WSGI application class
|
|
2388
|
+
app_class = get_connecpy_wsgi_app_class(connecpy_module, service_name)
|
|
2389
|
+
app = app_class(service=service_impl)
|
|
2390
|
+
|
|
2391
|
+
# Store the app and its path for routing
|
|
2392
|
+
self._services.append((app, app.path))
|
|
2393
|
+
|
|
1423
2394
|
full_service_name = pb2_module.DESCRIPTOR.services_by_name[
|
|
1424
2395
|
service_name
|
|
1425
2396
|
].full_name
|
|
1426
2397
|
self._service_names.append(full_service_name)
|
|
1427
2398
|
|
|
1428
|
-
def mount_objs(self, *objs):
|
|
2399
|
+
def mount_objs(self, *objs: object):
|
|
1429
2400
|
"""Mount multiple service objects into this WSGI app."""
|
|
1430
2401
|
for obj in objs:
|
|
1431
2402
|
self.mount(obj, self._package_name)
|
|
1432
2403
|
|
|
1433
|
-
def __call__(
|
|
1434
|
-
|
|
1435
|
-
|
|
2404
|
+
def __call__(
|
|
2405
|
+
self,
|
|
2406
|
+
environ: dict[str, Any],
|
|
2407
|
+
start_response: Callable[[str, list[tuple[str, str]]], None],
|
|
2408
|
+
) -> Any:
|
|
2409
|
+
"""WSGI entry point with routing for multiple services."""
|
|
2410
|
+
path = environ.get("PATH_INFO", "")
|
|
2411
|
+
|
|
2412
|
+
# Route to the appropriate service based on path
|
|
2413
|
+
for app, service_path in self._services:
|
|
2414
|
+
if path.startswith(service_path):
|
|
2415
|
+
return app(environ, start_response)
|
|
2416
|
+
|
|
2417
|
+
# If only one service is mounted, use it as default
|
|
2418
|
+
if len(self._services) == 1:
|
|
2419
|
+
return self._services[0][0](environ, start_response)
|
|
2420
|
+
|
|
2421
|
+
# No matching service found
|
|
2422
|
+
start_response("404 Not Found", [("Content-Type", "text/plain")])
|
|
2423
|
+
return [b"Not Found"]
|
|
1436
2424
|
|
|
1437
2425
|
|
|
1438
2426
|
def main():
|
|
1439
2427
|
import argparse
|
|
1440
2428
|
|
|
1441
2429
|
parser = argparse.ArgumentParser(description="Generate and compile proto files.")
|
|
1442
|
-
parser.add_argument(
|
|
2430
|
+
_ = parser.add_argument(
|
|
1443
2431
|
"py_file", type=str, help="The Python file containing the service class."
|
|
1444
2432
|
)
|
|
1445
|
-
parser.add_argument(
|
|
2433
|
+
_ = parser.add_argument(
|
|
2434
|
+
"class_name", type=str, help="The name of the service class."
|
|
2435
|
+
)
|
|
1446
2436
|
args = parser.parse_args()
|
|
1447
2437
|
|
|
1448
2438
|
module_name = os.path.splitext(basename(args.py_file))[0]
|
|
1449
2439
|
module = importlib.import_module(module_name)
|
|
1450
2440
|
klass = getattr(module, args.class_name)
|
|
1451
|
-
generate_and_compile_proto(klass())
|
|
2441
|
+
_ = generate_and_compile_proto(klass())
|
|
1452
2442
|
|
|
1453
2443
|
|
|
1454
2444
|
if __name__ == "__main__":
|