pydantic-rpc 0.6.1__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_rpc/__init__.py +10 -0
- pydantic_rpc/core.py +1829 -469
- pydantic_rpc/mcp/__init__.py +5 -0
- pydantic_rpc/mcp/converter.py +115 -0
- pydantic_rpc/mcp/exporter.py +283 -0
- pydantic_rpc/py.typed +0 -0
- {pydantic_rpc-0.6.1.dist-info → pydantic_rpc-0.8.0.dist-info}/METADATA +353 -18
- pydantic_rpc-0.8.0.dist-info/RECORD +10 -0
- pydantic_rpc-0.8.0.dist-info/WHEEL +4 -0
- {pydantic_rpc-0.6.1.dist-info → pydantic_rpc-0.8.0.dist-info}/entry_points.txt +1 -0
- pydantic_rpc-0.6.1.dist-info/RECORD +0 -8
- pydantic_rpc-0.6.1.dist-info/WHEEL +0 -4
- pydantic_rpc-0.6.1.dist-info/licenses/LICENSE +0 -21
pydantic_rpc/core.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import annotated_types
|
|
2
1
|
import asyncio
|
|
2
|
+
import datetime
|
|
3
3
|
import enum
|
|
4
4
|
import importlib.util
|
|
5
5
|
import inspect
|
|
@@ -8,33 +8,35 @@ import signal
|
|
|
8
8
|
import sys
|
|
9
9
|
import time
|
|
10
10
|
import types
|
|
11
|
-
import
|
|
11
|
+
from typing import Union
|
|
12
|
+
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
12
13
|
from concurrent import futures
|
|
14
|
+
from pathlib import Path
|
|
13
15
|
from posixpath import basename
|
|
14
16
|
from typing import (
|
|
15
|
-
|
|
16
|
-
|
|
17
|
+
Any,
|
|
18
|
+
TypeAlias,
|
|
17
19
|
get_args,
|
|
18
20
|
get_origin,
|
|
19
|
-
|
|
20
|
-
|
|
21
|
+
cast,
|
|
22
|
+
TypeGuard,
|
|
21
23
|
)
|
|
22
|
-
from collections.abc import AsyncIterator
|
|
23
24
|
|
|
25
|
+
import annotated_types
|
|
24
26
|
import grpc
|
|
27
|
+
from grpc import ServicerContext
|
|
28
|
+
import grpc_tools
|
|
29
|
+
from connecpy.code import Code as Errors
|
|
30
|
+
|
|
31
|
+
# Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
|
|
32
|
+
from google.protobuf import duration_pb2, timestamp_pb2, empty_pb2
|
|
25
33
|
from grpc_health.v1 import health_pb2, health_pb2_grpc
|
|
26
34
|
from grpc_health.v1.health import HealthServicer
|
|
27
35
|
from grpc_reflection.v1alpha import reflection
|
|
28
36
|
from grpc_tools import protoc
|
|
29
37
|
from pydantic import BaseModel, ValidationError
|
|
30
|
-
from sonora.wsgi import grpcWSGI
|
|
31
38
|
from sonora.asgi import grpcASGI
|
|
32
|
-
from
|
|
33
|
-
from connecpy.errors import Errors
|
|
34
|
-
from connecpy.wsgi import ConnecpyWSGIApp as ConnecpyWSGI
|
|
35
|
-
|
|
36
|
-
# Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
|
|
37
|
-
from google.protobuf import timestamp_pb2, duration_pb2
|
|
39
|
+
from sonora.wsgi import grpcWSGI
|
|
38
40
|
|
|
39
41
|
###############################################################################
|
|
40
42
|
# 1. Message definitions & converter extensions
|
|
@@ -45,8 +47,57 @@ from google.protobuf import timestamp_pb2, duration_pb2
|
|
|
45
47
|
|
|
46
48
|
Message: TypeAlias = BaseModel
|
|
47
49
|
|
|
50
|
+
# Cache for serializer checks
|
|
51
|
+
_has_serializers_cache: dict[type[Message], bool] = {}
|
|
52
|
+
_has_field_serializers_cache: dict[type[Message], set[str]] = {}
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Serializer strategy configuration
|
|
56
|
+
class SerializerStrategy(enum.Enum):
|
|
57
|
+
"""Strategy for applying Pydantic serializers."""
|
|
58
|
+
|
|
59
|
+
NONE = "none" # No serializers applied
|
|
60
|
+
SHALLOW = "shallow" # Only top-level serializers
|
|
61
|
+
DEEP = "deep" # Nested serializers too (default)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Get strategy from environment variable
|
|
65
|
+
_SERIALIZER_STRATEGY = SerializerStrategy(
|
|
66
|
+
os.getenv("PYDANTIC_RPC_SERIALIZER_STRATEGY", "deep").lower()
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def has_serializers(msg_type: type[Message]) -> bool:
|
|
71
|
+
"""Check if a Message type has any serializers (cached for performance)."""
|
|
72
|
+
if msg_type not in _has_serializers_cache:
|
|
73
|
+
# Check for both field and model serializers
|
|
74
|
+
has_field = bool(getattr(msg_type, "__pydantic_serializer__", None))
|
|
75
|
+
has_model = bool(
|
|
76
|
+
getattr(msg_type, "model_dump", None) and msg_type != BaseModel
|
|
77
|
+
)
|
|
78
|
+
_has_serializers_cache[msg_type] = has_field or has_model
|
|
79
|
+
return _has_serializers_cache[msg_type]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def get_field_serializers(msg_type: type[Message]) -> set[str]:
|
|
83
|
+
"""Get the set of fields that have serializers (cached for performance)."""
|
|
84
|
+
if msg_type not in _has_field_serializers_cache:
|
|
85
|
+
fields_with_serializers = set()
|
|
86
|
+
# Check model_fields_set or similar for fields with serializers
|
|
87
|
+
if hasattr(msg_type, "__pydantic_decorators__"):
|
|
88
|
+
decorators = msg_type.__pydantic_decorators__
|
|
89
|
+
if hasattr(decorators, "field_serializers"):
|
|
90
|
+
fields_with_serializers = set(decorators.field_serializers.keys())
|
|
91
|
+
_has_field_serializers_cache[msg_type] = fields_with_serializers
|
|
92
|
+
return _has_field_serializers_cache[msg_type]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def is_none_type(annotation: Any) -> TypeGuard[type[None] | None]:
|
|
96
|
+
"""Check if annotation represents None/NoneType (handles both None and type(None))."""
|
|
97
|
+
return annotation is None or annotation is type(None)
|
|
98
|
+
|
|
48
99
|
|
|
49
|
-
def primitiveProtoValueToPythonValue(value):
|
|
100
|
+
def primitiveProtoValueToPythonValue(value: Any):
|
|
50
101
|
# Returns the value as-is (primitive type).
|
|
51
102
|
return value
|
|
52
103
|
|
|
@@ -75,11 +126,20 @@ def python_to_duration(td: datetime.timedelta) -> duration_pb2.Duration: # type
|
|
|
75
126
|
return d
|
|
76
127
|
|
|
77
128
|
|
|
78
|
-
def generate_converter(annotation:
|
|
129
|
+
def generate_converter(annotation: type[Any] | None) -> Callable[[Any], Any]:
|
|
79
130
|
"""
|
|
80
131
|
Returns a converter function to convert protobuf types to Python types.
|
|
81
132
|
This is used primarily when handling incoming requests.
|
|
82
133
|
"""
|
|
134
|
+
# For NoneType (Empty messages)
|
|
135
|
+
if is_none_type(annotation):
|
|
136
|
+
|
|
137
|
+
def empty_converter(value: empty_pb2.Empty): # type: ignore
|
|
138
|
+
_ = value
|
|
139
|
+
return None
|
|
140
|
+
|
|
141
|
+
return empty_converter
|
|
142
|
+
|
|
83
143
|
# For primitive types
|
|
84
144
|
if annotation in (int, str, bool, bytes, float):
|
|
85
145
|
return primitiveProtoValueToPythonValue
|
|
@@ -87,7 +147,7 @@ def generate_converter(annotation: Type) -> Callable:
|
|
|
87
147
|
# For enum types
|
|
88
148
|
if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
|
|
89
149
|
|
|
90
|
-
def enum_converter(value):
|
|
150
|
+
def enum_converter(value: enum.Enum):
|
|
91
151
|
return annotation(value)
|
|
92
152
|
|
|
93
153
|
return enum_converter
|
|
@@ -114,7 +174,7 @@ def generate_converter(annotation: Type) -> Callable:
|
|
|
114
174
|
if origin in (list, tuple):
|
|
115
175
|
item_converter = generate_converter(get_args(annotation)[0])
|
|
116
176
|
|
|
117
|
-
def seq_converter(value):
|
|
177
|
+
def seq_converter(value: list[Any] | tuple[Any, ...]):
|
|
118
178
|
return [item_converter(v) for v in value]
|
|
119
179
|
|
|
120
180
|
return seq_converter
|
|
@@ -124,40 +184,130 @@ def generate_converter(annotation: Type) -> Callable:
|
|
|
124
184
|
key_converter = generate_converter(get_args(annotation)[0])
|
|
125
185
|
value_converter = generate_converter(get_args(annotation)[1])
|
|
126
186
|
|
|
127
|
-
def dict_converter(value):
|
|
187
|
+
def dict_converter(value: dict[Any, Any]):
|
|
128
188
|
return {key_converter(k): value_converter(v) for k, v in value.items()}
|
|
129
189
|
|
|
130
190
|
return dict_converter
|
|
131
191
|
|
|
132
192
|
# For Message classes
|
|
133
193
|
if inspect.isclass(annotation) and issubclass(annotation, Message):
|
|
194
|
+
# Check if it's an empty message class
|
|
195
|
+
if not annotation.model_fields:
|
|
196
|
+
|
|
197
|
+
def empty_message_converter(value: empty_pb2.Empty): # type: ignore
|
|
198
|
+
_ = value
|
|
199
|
+
return annotation() # Return instance of the empty message class
|
|
200
|
+
|
|
201
|
+
return empty_message_converter
|
|
134
202
|
return generate_message_converter(annotation)
|
|
135
203
|
|
|
136
204
|
# For union types or other unsupported cases, just return the value as-is.
|
|
137
205
|
return primitiveProtoValueToPythonValue
|
|
138
206
|
|
|
139
207
|
|
|
140
|
-
def generate_message_converter(
|
|
208
|
+
def generate_message_converter(
|
|
209
|
+
arg_type: type[Message] | type[None] | None,
|
|
210
|
+
) -> Callable[[Any], Message | None]:
|
|
141
211
|
"""Return a converter function for protobuf -> Python Message."""
|
|
142
|
-
if arg_type is None or not issubclass(arg_type, Message):
|
|
143
|
-
raise TypeError("Request arg must be subclass of Message")
|
|
144
212
|
|
|
213
|
+
# Handle NoneType (Empty messages)
|
|
214
|
+
if is_none_type(arg_type):
|
|
215
|
+
|
|
216
|
+
def empty_converter(request: Any) -> None:
|
|
217
|
+
_ = request
|
|
218
|
+
return None
|
|
219
|
+
|
|
220
|
+
return empty_converter
|
|
221
|
+
|
|
222
|
+
arg_type = cast("type[Message]", arg_type)
|
|
145
223
|
fields = arg_type.model_fields
|
|
224
|
+
|
|
225
|
+
# Handle empty message classes (no fields)
|
|
226
|
+
if not fields:
|
|
227
|
+
|
|
228
|
+
def empty_message_converter(request: Any) -> Message:
|
|
229
|
+
_ = request # The incoming request will be google.protobuf.Empty
|
|
230
|
+
return arg_type() # Return an instance of the empty message class
|
|
231
|
+
|
|
232
|
+
return empty_message_converter
|
|
146
233
|
converters = {
|
|
147
234
|
field: generate_converter(field_type.annotation) # type: ignore
|
|
148
235
|
for field, field_type in fields.items()
|
|
149
236
|
}
|
|
150
237
|
|
|
151
|
-
def converter(request):
|
|
238
|
+
def converter(request: Any) -> Message:
|
|
152
239
|
rdict = {}
|
|
153
|
-
for
|
|
154
|
-
|
|
155
|
-
|
|
240
|
+
for field_name, field_info in fields.items():
|
|
241
|
+
field_type = field_info.annotation
|
|
242
|
+
|
|
243
|
+
# Check if this is a union type
|
|
244
|
+
if field_type is not None and is_union_type(field_type):
|
|
245
|
+
union_args = flatten_union(field_type)
|
|
246
|
+
has_none = type(None) in union_args
|
|
247
|
+
non_none_args = [arg for arg in union_args if arg is not type(None)]
|
|
248
|
+
|
|
249
|
+
if has_none and len(non_none_args) == 1:
|
|
250
|
+
# This is Optional[T] - check if protobuf field is set
|
|
251
|
+
try:
|
|
252
|
+
if hasattr(request, "HasField") and request.HasField(
|
|
253
|
+
field_name
|
|
254
|
+
):
|
|
255
|
+
rdict[field_name] = converters[field_name](
|
|
256
|
+
getattr(request, field_name)
|
|
257
|
+
)
|
|
258
|
+
else:
|
|
259
|
+
# Field not set in protobuf, set to None for Optional fields
|
|
260
|
+
rdict[field_name] = None
|
|
261
|
+
except ValueError:
|
|
262
|
+
# HasField doesn't work for this field type (e.g., repeated fields)
|
|
263
|
+
# Fall back to regular conversion
|
|
264
|
+
rdict[field_name] = converters[field_name](
|
|
265
|
+
getattr(request, field_name)
|
|
266
|
+
)
|
|
267
|
+
elif len(non_none_args) > 1:
|
|
268
|
+
# This is a oneof field (Union[str, int] etc.)
|
|
269
|
+
# Check which oneof field is set
|
|
270
|
+
try:
|
|
271
|
+
which_field = request.WhichOneof(field_name)
|
|
272
|
+
if which_field:
|
|
273
|
+
# Extract the value from the set oneof field
|
|
274
|
+
proto_value = getattr(request, which_field)
|
|
275
|
+
|
|
276
|
+
# Determine the Python type from the oneof field name
|
|
277
|
+
# e.g., "value_string" -> str, "value_int32" -> int
|
|
278
|
+
for union_arg in non_none_args:
|
|
279
|
+
proto_typename = protobuf_type_mapping(union_arg)
|
|
280
|
+
if (
|
|
281
|
+
proto_typename
|
|
282
|
+
and which_field
|
|
283
|
+
== f"{field_name}_{proto_typename.replace('.', '_')}"
|
|
284
|
+
):
|
|
285
|
+
# Convert using the specific type converter
|
|
286
|
+
type_converter = generate_converter(union_arg)
|
|
287
|
+
rdict[field_name] = type_converter(proto_value)
|
|
288
|
+
break
|
|
289
|
+
except (AttributeError, ValueError):
|
|
290
|
+
# WhichOneof failed, try fallback
|
|
291
|
+
# This shouldn't happen for properly generated oneof fields
|
|
292
|
+
pass
|
|
293
|
+
else:
|
|
294
|
+
# Union with only None type (shouldn't happen)
|
|
295
|
+
pass
|
|
296
|
+
else:
|
|
297
|
+
# For non-union fields, convert normally
|
|
298
|
+
rdict[field_name] = converters[field_name](getattr(request, field_name))
|
|
299
|
+
|
|
300
|
+
# Use model_validate to support @model_validator
|
|
301
|
+
try:
|
|
302
|
+
return arg_type.model_validate(rdict)
|
|
303
|
+
except AttributeError:
|
|
304
|
+
# Fallback for older Pydantic versions or if model_validate doesn't exist
|
|
305
|
+
return arg_type(**rdict)
|
|
156
306
|
|
|
157
307
|
return converter
|
|
158
308
|
|
|
159
309
|
|
|
160
|
-
def python_value_to_proto_value(field_type:
|
|
310
|
+
def python_value_to_proto_value(field_type: type[Any], value: Any) -> Any:
|
|
161
311
|
"""
|
|
162
312
|
Converts Python values to protobuf values.
|
|
163
313
|
Used primarily when constructing a response object.
|
|
@@ -179,77 +329,129 @@ def python_value_to_proto_value(field_type: Type, value):
|
|
|
179
329
|
###############################################################################
|
|
180
330
|
|
|
181
331
|
|
|
182
|
-
def connect_obj_with_stub(
|
|
332
|
+
def connect_obj_with_stub(
|
|
333
|
+
pb2_grpc_module: Any, pb2_module: Any, service_obj: object
|
|
334
|
+
) -> type:
|
|
183
335
|
"""
|
|
184
336
|
Connect a Python service object to a gRPC stub, generating server methods.
|
|
337
|
+
Returns a subclass of the generated Servicer stub with concrete implementations.
|
|
185
338
|
"""
|
|
186
339
|
service_class = service_obj.__class__
|
|
187
340
|
stub_class_name = service_class.__name__ + "Servicer"
|
|
188
341
|
stub_class = getattr(pb2_grpc_module, stub_class_name)
|
|
189
342
|
|
|
190
343
|
class ConcreteServiceClass(stub_class):
|
|
344
|
+
"""Dynamically generated servicer class with stub methods implemented."""
|
|
345
|
+
|
|
191
346
|
pass
|
|
192
347
|
|
|
193
|
-
def implement_stub_method(
|
|
194
|
-
|
|
348
|
+
def implement_stub_method(
|
|
349
|
+
method: Callable[..., Message],
|
|
350
|
+
) -> Callable[[object, Any, Any], Any]:
|
|
351
|
+
"""
|
|
352
|
+
Wraps a user-defined method (self, *args) -> R into a gRPC stub signature:
|
|
353
|
+
(self, request_proto, context) -> response_proto
|
|
354
|
+
"""
|
|
195
355
|
sig = inspect.signature(method)
|
|
196
356
|
arg_type = get_request_arg_type(sig)
|
|
197
|
-
# Convert request from protobuf to Python.
|
|
198
|
-
converter = generate_message_converter(arg_type)
|
|
199
|
-
|
|
200
357
|
response_type = sig.return_annotation
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
match size_of_parameters:
|
|
204
|
-
case 1:
|
|
358
|
+
param_count = len(sig.parameters)
|
|
359
|
+
converter = generate_message_converter(arg_type)
|
|
205
360
|
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
361
|
+
if param_count == 0 or param_count == 1:
|
|
362
|
+
|
|
363
|
+
def stub_method(
|
|
364
|
+
self: object,
|
|
365
|
+
request: Any,
|
|
366
|
+
context: Any,
|
|
367
|
+
*,
|
|
368
|
+
original: Callable[..., Message] = method,
|
|
369
|
+
) -> Any:
|
|
370
|
+
_ = self
|
|
371
|
+
try:
|
|
372
|
+
if param_count == 0:
|
|
373
|
+
# Method takes no parameters
|
|
374
|
+
resp_obj = original()
|
|
375
|
+
elif is_none_type(arg_type):
|
|
376
|
+
resp_obj = original(None) # Fixed: pass None instead of no args
|
|
377
|
+
else:
|
|
209
378
|
arg = converter(request)
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
379
|
+
resp_obj = original(arg)
|
|
380
|
+
|
|
381
|
+
if is_none_type(response_type):
|
|
382
|
+
return empty_pb2.Empty() # type: ignore
|
|
383
|
+
elif (
|
|
384
|
+
inspect.isclass(response_type)
|
|
385
|
+
and issubclass(response_type, Message)
|
|
386
|
+
and not response_type.model_fields
|
|
387
|
+
):
|
|
388
|
+
# Empty message class
|
|
389
|
+
return empty_pb2.Empty() # type: ignore
|
|
390
|
+
else:
|
|
213
391
|
return convert_python_message_to_proto(
|
|
214
392
|
resp_obj, response_type, pb2_module
|
|
215
393
|
)
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
394
|
+
except ValidationError as e:
|
|
395
|
+
return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
|
396
|
+
except Exception as e:
|
|
397
|
+
return context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
398
|
+
|
|
399
|
+
elif param_count == 2:
|
|
400
|
+
|
|
401
|
+
def stub_method(
|
|
402
|
+
self: object,
|
|
403
|
+
request: Any,
|
|
404
|
+
context: Any,
|
|
405
|
+
*,
|
|
406
|
+
original: Callable[..., Message] = method,
|
|
407
|
+
) -> Any:
|
|
408
|
+
_ = self
|
|
409
|
+
try:
|
|
410
|
+
if is_none_type(arg_type):
|
|
411
|
+
resp_obj = original(
|
|
412
|
+
None, context
|
|
413
|
+
) # Fixed: pass None instead of Empty
|
|
414
|
+
else:
|
|
227
415
|
arg = converter(request)
|
|
228
|
-
resp_obj =
|
|
416
|
+
resp_obj = original(arg, context)
|
|
417
|
+
|
|
418
|
+
if is_none_type(response_type):
|
|
419
|
+
return empty_pb2.Empty() # type: ignore
|
|
420
|
+
elif (
|
|
421
|
+
inspect.isclass(response_type)
|
|
422
|
+
and issubclass(response_type, Message)
|
|
423
|
+
and not response_type.model_fields
|
|
424
|
+
):
|
|
425
|
+
# Empty message class
|
|
426
|
+
return empty_pb2.Empty() # type: ignore
|
|
427
|
+
else:
|
|
229
428
|
return convert_python_message_to_proto(
|
|
230
429
|
resp_obj, response_type, pb2_module
|
|
231
430
|
)
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
431
|
+
except ValidationError as e:
|
|
432
|
+
return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
|
433
|
+
except Exception as e:
|
|
434
|
+
return context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
236
435
|
|
|
237
|
-
|
|
436
|
+
else:
|
|
437
|
+
raise TypeError(
|
|
438
|
+
f"Method '{method.__name__}' must have exactly 1 or 2 parameters, got {param_count}"
|
|
439
|
+
)
|
|
238
440
|
|
|
239
|
-
|
|
240
|
-
raise Exception("Method must have exactly one or two parameters")
|
|
441
|
+
return stub_method
|
|
241
442
|
|
|
443
|
+
# Attach all RPC methods from service_obj to the concrete servicer
|
|
242
444
|
for method_name, method in get_rpc_methods(service_obj):
|
|
243
|
-
if
|
|
445
|
+
if method_name.startswith("_"):
|
|
244
446
|
continue
|
|
245
|
-
|
|
246
|
-
a_method = implement_stub_method(method)
|
|
247
|
-
setattr(ConcreteServiceClass, method_name, a_method)
|
|
447
|
+
setattr(ConcreteServiceClass, method_name, implement_stub_method(method))
|
|
248
448
|
|
|
249
449
|
return ConcreteServiceClass
|
|
250
450
|
|
|
251
451
|
|
|
252
|
-
def connect_obj_with_stub_async(
|
|
452
|
+
def connect_obj_with_stub_async(
|
|
453
|
+
pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
454
|
+
) -> type:
|
|
253
455
|
"""
|
|
254
456
|
Connect a Python service object to a gRPC stub for async methods.
|
|
255
457
|
"""
|
|
@@ -260,26 +462,151 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
|
|
|
260
462
|
class ConcreteServiceClass(stub_class):
|
|
261
463
|
pass
|
|
262
464
|
|
|
263
|
-
def implement_stub_method(
|
|
465
|
+
def implement_stub_method(
|
|
466
|
+
method: Callable[..., Any],
|
|
467
|
+
) -> Callable[[object, Any, Any], Any]:
|
|
264
468
|
sig = inspect.signature(method)
|
|
265
|
-
|
|
266
|
-
|
|
469
|
+
input_type = get_request_arg_type(sig)
|
|
470
|
+
is_input_stream = is_stream_type(input_type)
|
|
267
471
|
response_type = sig.return_annotation
|
|
472
|
+
is_output_stream = is_stream_type(response_type)
|
|
268
473
|
size_of_parameters = len(sig.parameters)
|
|
269
474
|
|
|
270
|
-
if
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
475
|
+
if size_of_parameters not in (0, 1, 2):
|
|
476
|
+
raise TypeError(
|
|
477
|
+
f"Method '{method.__name__}' must have 0, 1 or 2 parameters, got {size_of_parameters}"
|
|
478
|
+
)
|
|
274
479
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
480
|
+
if is_input_stream:
|
|
481
|
+
input_item_type = get_args(input_type)[0]
|
|
482
|
+
item_converter = generate_message_converter(input_item_type)
|
|
483
|
+
|
|
484
|
+
async def convert_iterator(
|
|
485
|
+
proto_iter: AsyncIterator[Any],
|
|
486
|
+
) -> AsyncIterator[Message]:
|
|
487
|
+
async for proto in proto_iter:
|
|
488
|
+
result = item_converter(proto)
|
|
489
|
+
if result is None:
|
|
490
|
+
raise TypeError(
|
|
491
|
+
f"Unexpected None result from converter for type {input_item_type}"
|
|
492
|
+
)
|
|
493
|
+
yield result
|
|
494
|
+
|
|
495
|
+
if is_output_stream:
|
|
496
|
+
# stream-stream
|
|
497
|
+
output_item_type = get_args(response_type)[0]
|
|
498
|
+
|
|
499
|
+
if size_of_parameters == 1:
|
|
500
|
+
|
|
501
|
+
async def stub_method(
|
|
502
|
+
self: object,
|
|
503
|
+
request_iterator: AsyncIterator[Any],
|
|
504
|
+
context: Any,
|
|
505
|
+
) -> AsyncIterator[Any]:
|
|
506
|
+
_ = self
|
|
507
|
+
try:
|
|
508
|
+
arg_iter = convert_iterator(request_iterator)
|
|
509
|
+
async for resp_obj in method(arg_iter):
|
|
510
|
+
yield convert_python_message_to_proto(
|
|
511
|
+
resp_obj, output_item_type, pb2_module
|
|
512
|
+
)
|
|
513
|
+
except ValidationError as e:
|
|
514
|
+
await context.abort(
|
|
515
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
516
|
+
)
|
|
517
|
+
except Exception as e:
|
|
518
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
519
|
+
|
|
520
|
+
else: # size_of_parameters == 2
|
|
521
|
+
|
|
522
|
+
async def stub_method(
|
|
523
|
+
self: object,
|
|
524
|
+
request_iterator: AsyncIterator[Any],
|
|
525
|
+
context: Any,
|
|
526
|
+
) -> AsyncIterator[Any]:
|
|
527
|
+
_ = self
|
|
528
|
+
try:
|
|
529
|
+
arg_iter = convert_iterator(request_iterator)
|
|
530
|
+
async for resp_obj in method(arg_iter, context):
|
|
531
|
+
yield convert_python_message_to_proto(
|
|
532
|
+
resp_obj, output_item_type, pb2_module
|
|
533
|
+
)
|
|
534
|
+
except ValidationError as e:
|
|
535
|
+
await context.abort(
|
|
536
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
537
|
+
)
|
|
538
|
+
except Exception as e:
|
|
539
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
540
|
+
|
|
541
|
+
return stub_method
|
|
542
|
+
|
|
543
|
+
else:
|
|
544
|
+
# stream-unary
|
|
545
|
+
if size_of_parameters == 1:
|
|
546
|
+
|
|
547
|
+
async def stub_method(
|
|
548
|
+
self: object,
|
|
549
|
+
request_iterator: AsyncIterator[Any],
|
|
550
|
+
context: Any,
|
|
551
|
+
) -> Any:
|
|
552
|
+
_ = self
|
|
553
|
+
try:
|
|
554
|
+
arg_iter = convert_iterator(request_iterator)
|
|
555
|
+
resp_obj = await method(arg_iter)
|
|
556
|
+
return convert_python_message_to_proto(
|
|
557
|
+
resp_obj, response_type, pb2_module
|
|
558
|
+
)
|
|
559
|
+
except ValidationError as e:
|
|
560
|
+
await context.abort(
|
|
561
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
562
|
+
)
|
|
563
|
+
except Exception as e:
|
|
564
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
565
|
+
|
|
566
|
+
else: # size_of_parameters == 2
|
|
567
|
+
|
|
568
|
+
async def stub_method(
|
|
569
|
+
self: object,
|
|
570
|
+
request_iterator: AsyncIterator[Any],
|
|
571
|
+
context: Any,
|
|
572
|
+
) -> Any:
|
|
573
|
+
_ = self
|
|
574
|
+
try:
|
|
575
|
+
arg_iter = convert_iterator(request_iterator)
|
|
576
|
+
resp_obj = await method(arg_iter, context)
|
|
577
|
+
return convert_python_message_to_proto(
|
|
578
|
+
resp_obj, response_type, pb2_module
|
|
579
|
+
)
|
|
580
|
+
except ValidationError as e:
|
|
581
|
+
await context.abort(
|
|
582
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
583
|
+
)
|
|
584
|
+
except Exception as e:
|
|
585
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
586
|
+
|
|
587
|
+
return stub_method
|
|
588
|
+
|
|
589
|
+
else:
|
|
590
|
+
# unary input
|
|
591
|
+
converter = generate_message_converter(input_type)
|
|
592
|
+
|
|
593
|
+
if is_output_stream:
|
|
594
|
+
# unary-stream
|
|
595
|
+
output_item_type = get_args(response_type)[0]
|
|
596
|
+
|
|
597
|
+
if size_of_parameters == 1:
|
|
598
|
+
|
|
599
|
+
async def stub_method(
|
|
600
|
+
self: object,
|
|
601
|
+
request: Any,
|
|
602
|
+
context: Any,
|
|
603
|
+
) -> AsyncIterator[Any]:
|
|
604
|
+
_ = self
|
|
278
605
|
try:
|
|
279
606
|
arg = converter(request)
|
|
280
607
|
async for resp_obj in method(arg):
|
|
281
608
|
yield convert_python_message_to_proto(
|
|
282
|
-
resp_obj,
|
|
609
|
+
resp_obj, output_item_type, pb2_module
|
|
283
610
|
)
|
|
284
611
|
except ValidationError as e:
|
|
285
612
|
await context.abort(
|
|
@@ -288,17 +615,19 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
|
|
|
288
615
|
except Exception as e:
|
|
289
616
|
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
290
617
|
|
|
291
|
-
|
|
292
|
-
case 2:
|
|
618
|
+
else: # size_of_parameters == 2
|
|
293
619
|
|
|
294
|
-
async def
|
|
295
|
-
self
|
|
296
|
-
|
|
620
|
+
async def stub_method(
|
|
621
|
+
self: object,
|
|
622
|
+
request: Any,
|
|
623
|
+
context: Any,
|
|
624
|
+
) -> AsyncIterator[Any]:
|
|
625
|
+
_ = self
|
|
297
626
|
try:
|
|
298
627
|
arg = converter(request)
|
|
299
628
|
async for resp_obj in method(arg, context):
|
|
300
629
|
yield convert_python_message_to_proto(
|
|
301
|
-
resp_obj,
|
|
630
|
+
resp_obj, output_item_type, pb2_module
|
|
302
631
|
)
|
|
303
632
|
except ValidationError as e:
|
|
304
633
|
await context.abort(
|
|
@@ -307,45 +636,84 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
|
|
|
307
636
|
except Exception as e:
|
|
308
637
|
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
309
638
|
|
|
310
|
-
|
|
311
|
-
case _:
|
|
312
|
-
raise Exception("Method must have exactly one or two parameters")
|
|
639
|
+
return stub_method
|
|
313
640
|
|
|
314
|
-
|
|
315
|
-
|
|
641
|
+
else:
|
|
642
|
+
# unary-unary
|
|
643
|
+
if size_of_parameters == 0 or size_of_parameters == 1:
|
|
316
644
|
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
645
|
+
async def stub_method(
|
|
646
|
+
self: object,
|
|
647
|
+
request: Any,
|
|
648
|
+
context: Any,
|
|
649
|
+
) -> Any:
|
|
650
|
+
_ = self
|
|
651
|
+
try:
|
|
652
|
+
if size_of_parameters == 0:
|
|
653
|
+
# Method takes no parameters
|
|
654
|
+
resp_obj = await method()
|
|
655
|
+
elif is_none_type(input_type):
|
|
656
|
+
resp_obj = await method(None)
|
|
657
|
+
else:
|
|
658
|
+
arg = converter(request)
|
|
659
|
+
resp_obj = await method(arg)
|
|
660
|
+
|
|
661
|
+
if is_none_type(response_type):
|
|
662
|
+
return empty_pb2.Empty() # type: ignore
|
|
663
|
+
elif (
|
|
664
|
+
inspect.isclass(response_type)
|
|
665
|
+
and issubclass(response_type, Message)
|
|
666
|
+
and not response_type.model_fields
|
|
667
|
+
):
|
|
668
|
+
# Empty message class
|
|
669
|
+
return empty_pb2.Empty() # type: ignore
|
|
670
|
+
else:
|
|
671
|
+
return convert_python_message_to_proto(
|
|
672
|
+
resp_obj, response_type, pb2_module
|
|
673
|
+
)
|
|
674
|
+
except ValidationError as e:
|
|
675
|
+
await context.abort(
|
|
676
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
677
|
+
)
|
|
678
|
+
except Exception as e:
|
|
679
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
332
680
|
|
|
333
|
-
|
|
334
|
-
try:
|
|
335
|
-
arg = converter(request)
|
|
336
|
-
resp_obj = await method(arg, context)
|
|
337
|
-
return convert_python_message_to_proto(
|
|
338
|
-
resp_obj, response_type, pb2_module
|
|
339
|
-
)
|
|
340
|
-
except ValidationError as e:
|
|
341
|
-
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
|
342
|
-
except Exception as e:
|
|
343
|
-
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
681
|
+
else: # size_of_parameters == 2
|
|
344
682
|
|
|
345
|
-
|
|
683
|
+
async def stub_method(
|
|
684
|
+
self: object,
|
|
685
|
+
request: Any,
|
|
686
|
+
context: Any,
|
|
687
|
+
) -> Any:
|
|
688
|
+
_ = self
|
|
689
|
+
try:
|
|
690
|
+
if is_none_type(input_type):
|
|
691
|
+
resp_obj = await method(None, context)
|
|
692
|
+
else:
|
|
693
|
+
arg = converter(request)
|
|
694
|
+
resp_obj = await method(arg, context)
|
|
695
|
+
|
|
696
|
+
if is_none_type(response_type):
|
|
697
|
+
return empty_pb2.Empty() # type: ignore
|
|
698
|
+
elif (
|
|
699
|
+
inspect.isclass(response_type)
|
|
700
|
+
and issubclass(response_type, Message)
|
|
701
|
+
and not response_type.model_fields
|
|
702
|
+
):
|
|
703
|
+
# Empty message class
|
|
704
|
+
return empty_pb2.Empty() # type: ignore
|
|
705
|
+
else:
|
|
706
|
+
return convert_python_message_to_proto(
|
|
707
|
+
resp_obj, response_type, pb2_module
|
|
708
|
+
)
|
|
709
|
+
except ValidationError as e:
|
|
710
|
+
await context.abort(
|
|
711
|
+
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
712
|
+
)
|
|
713
|
+
except Exception as e:
|
|
714
|
+
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
346
715
|
|
|
347
|
-
|
|
348
|
-
raise Exception("Method must have exactly one or two parameters")
|
|
716
|
+
return stub_method
|
|
349
717
|
|
|
350
718
|
for method_name, method in get_rpc_methods(obj):
|
|
351
719
|
if method.__name__.startswith("_"):
|
|
@@ -357,7 +725,9 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
|
|
|
357
725
|
return ConcreteServiceClass
|
|
358
726
|
|
|
359
727
|
|
|
360
|
-
def connect_obj_with_stub_connecpy(
|
|
728
|
+
def connect_obj_with_stub_connecpy(
|
|
729
|
+
connecpy_module: Any, pb2_module: Any, obj: object
|
|
730
|
+
) -> type:
|
|
361
731
|
"""
|
|
362
732
|
Connect a Python service object to a Connecpy stub.
|
|
363
733
|
"""
|
|
@@ -368,7 +738,9 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
|
|
|
368
738
|
class ConcreteServiceClass(stub_class):
|
|
369
739
|
pass
|
|
370
740
|
|
|
371
|
-
def implement_stub_method(
|
|
741
|
+
def implement_stub_method(
|
|
742
|
+
method: Callable[..., Message],
|
|
743
|
+
) -> Callable[[object, Any, Any], Any]:
|
|
372
744
|
sig = inspect.signature(method)
|
|
373
745
|
arg_type = get_request_arg_type(sig)
|
|
374
746
|
converter = generate_message_converter(arg_type)
|
|
@@ -376,40 +748,91 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
|
|
|
376
748
|
size_of_parameters = len(sig.parameters)
|
|
377
749
|
|
|
378
750
|
match size_of_parameters:
|
|
751
|
+
case 0:
|
|
752
|
+
# Method with no parameters (empty request)
|
|
753
|
+
def stub_method0(
|
|
754
|
+
self: object,
|
|
755
|
+
request: Any,
|
|
756
|
+
context: Any,
|
|
757
|
+
method: Callable[..., Message] = method,
|
|
758
|
+
) -> Any:
|
|
759
|
+
_ = self
|
|
760
|
+
try:
|
|
761
|
+
resp_obj = method()
|
|
762
|
+
|
|
763
|
+
if is_none_type(response_type):
|
|
764
|
+
return empty_pb2.Empty() # type: ignore
|
|
765
|
+
else:
|
|
766
|
+
return convert_python_message_to_proto(
|
|
767
|
+
resp_obj, response_type, pb2_module
|
|
768
|
+
)
|
|
769
|
+
except ValidationError as e:
|
|
770
|
+
return context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
771
|
+
except Exception as e:
|
|
772
|
+
return context.abort(Errors.INTERNAL, str(e))
|
|
773
|
+
|
|
774
|
+
return stub_method0
|
|
775
|
+
|
|
379
776
|
case 1:
|
|
380
777
|
|
|
381
|
-
def stub_method1(
|
|
778
|
+
def stub_method1(
|
|
779
|
+
self: object,
|
|
780
|
+
request: Any,
|
|
781
|
+
context: Any,
|
|
782
|
+
method: Callable[..., Message] = method,
|
|
783
|
+
) -> Any:
|
|
784
|
+
_ = self
|
|
382
785
|
try:
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
786
|
+
if is_none_type(arg_type):
|
|
787
|
+
resp_obj = method(None)
|
|
788
|
+
else:
|
|
789
|
+
arg = converter(request)
|
|
790
|
+
resp_obj = method(arg)
|
|
791
|
+
|
|
792
|
+
if is_none_type(response_type):
|
|
793
|
+
return empty_pb2.Empty() # type: ignore
|
|
794
|
+
else:
|
|
795
|
+
return convert_python_message_to_proto(
|
|
796
|
+
resp_obj, response_type, pb2_module
|
|
797
|
+
)
|
|
388
798
|
except ValidationError as e:
|
|
389
|
-
return context.abort(Errors.
|
|
799
|
+
return context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
390
800
|
except Exception as e:
|
|
391
|
-
return context.abort(Errors.
|
|
801
|
+
return context.abort(Errors.INTERNAL, str(e))
|
|
392
802
|
|
|
393
803
|
return stub_method1
|
|
394
804
|
|
|
395
805
|
case 2:
|
|
396
806
|
|
|
397
|
-
def stub_method2(
|
|
807
|
+
def stub_method2(
|
|
808
|
+
self: object,
|
|
809
|
+
request: Any,
|
|
810
|
+
context: Any,
|
|
811
|
+
method: Callable[..., Message] = method,
|
|
812
|
+
) -> Any:
|
|
813
|
+
_ = self
|
|
398
814
|
try:
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
815
|
+
if is_none_type(arg_type):
|
|
816
|
+
resp_obj = method(None, context)
|
|
817
|
+
else:
|
|
818
|
+
arg = converter(request)
|
|
819
|
+
resp_obj = method(arg, context)
|
|
820
|
+
|
|
821
|
+
if is_none_type(response_type):
|
|
822
|
+
return empty_pb2.Empty() # type: ignore
|
|
823
|
+
else:
|
|
824
|
+
return convert_python_message_to_proto(
|
|
825
|
+
resp_obj, response_type, pb2_module
|
|
826
|
+
)
|
|
404
827
|
except ValidationError as e:
|
|
405
|
-
return context.abort(Errors.
|
|
828
|
+
return context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
406
829
|
except Exception as e:
|
|
407
|
-
return context.abort(Errors.
|
|
830
|
+
return context.abort(Errors.INTERNAL, str(e))
|
|
408
831
|
|
|
409
832
|
return stub_method2
|
|
410
833
|
|
|
411
834
|
case _:
|
|
412
|
-
raise Exception("Method must have
|
|
835
|
+
raise Exception("Method must have 0, 1, or 2 parameters")
|
|
413
836
|
|
|
414
837
|
for method_name, method in get_rpc_methods(obj):
|
|
415
838
|
if method.__name__.startswith("_"):
|
|
@@ -421,7 +844,7 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
|
|
|
421
844
|
|
|
422
845
|
|
|
423
846
|
def connect_obj_with_stub_async_connecpy(
|
|
424
|
-
connecpy_module, pb2_module, obj: object
|
|
847
|
+
connecpy_module: Any, pb2_module: Any, obj: object
|
|
425
848
|
) -> type:
|
|
426
849
|
"""
|
|
427
850
|
Connect a Python service object to a Connecpy stub for async methods.
|
|
@@ -433,7 +856,9 @@ def connect_obj_with_stub_async_connecpy(
|
|
|
433
856
|
class ConcreteServiceClass(stub_class):
|
|
434
857
|
pass
|
|
435
858
|
|
|
436
|
-
def implement_stub_method(
|
|
859
|
+
def implement_stub_method(
|
|
860
|
+
method: Callable[..., Awaitable[Message]],
|
|
861
|
+
) -> Callable[[object, Any, Any], Any]:
|
|
437
862
|
sig = inspect.signature(method)
|
|
438
863
|
arg_type = get_request_arg_type(sig)
|
|
439
864
|
converter = generate_message_converter(arg_type)
|
|
@@ -441,40 +866,91 @@ def connect_obj_with_stub_async_connecpy(
|
|
|
441
866
|
size_of_parameters = len(sig.parameters)
|
|
442
867
|
|
|
443
868
|
match size_of_parameters:
|
|
869
|
+
case 0:
|
|
870
|
+
# Method with no parameters (empty request)
|
|
871
|
+
async def stub_method0(
|
|
872
|
+
self: object,
|
|
873
|
+
request: Any,
|
|
874
|
+
context: Any,
|
|
875
|
+
method: Callable[..., Awaitable[Message]] = method,
|
|
876
|
+
) -> Any:
|
|
877
|
+
_ = self
|
|
878
|
+
try:
|
|
879
|
+
resp_obj = await method()
|
|
880
|
+
|
|
881
|
+
if is_none_type(response_type):
|
|
882
|
+
return empty_pb2.Empty() # type: ignore
|
|
883
|
+
else:
|
|
884
|
+
return convert_python_message_to_proto(
|
|
885
|
+
resp_obj, response_type, pb2_module
|
|
886
|
+
)
|
|
887
|
+
except ValidationError as e:
|
|
888
|
+
await context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
889
|
+
except Exception as e:
|
|
890
|
+
await context.abort(Errors.INTERNAL, str(e))
|
|
891
|
+
|
|
892
|
+
return stub_method0
|
|
893
|
+
|
|
444
894
|
case 1:
|
|
445
895
|
|
|
446
|
-
async def stub_method1(
|
|
896
|
+
async def stub_method1(
|
|
897
|
+
self: object,
|
|
898
|
+
request: Any,
|
|
899
|
+
context: Any,
|
|
900
|
+
method: Callable[..., Awaitable[Message]] = method,
|
|
901
|
+
) -> Any:
|
|
902
|
+
_ = self
|
|
447
903
|
try:
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
904
|
+
if is_none_type(arg_type):
|
|
905
|
+
resp_obj = await method(None)
|
|
906
|
+
else:
|
|
907
|
+
arg = converter(request)
|
|
908
|
+
resp_obj = await method(arg)
|
|
909
|
+
|
|
910
|
+
if is_none_type(response_type):
|
|
911
|
+
return empty_pb2.Empty() # type: ignore
|
|
912
|
+
else:
|
|
913
|
+
return convert_python_message_to_proto(
|
|
914
|
+
resp_obj, response_type, pb2_module
|
|
915
|
+
)
|
|
453
916
|
except ValidationError as e:
|
|
454
|
-
await context.abort(Errors.
|
|
917
|
+
await context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
455
918
|
except Exception as e:
|
|
456
|
-
await context.abort(Errors.
|
|
919
|
+
await context.abort(Errors.INTERNAL, str(e))
|
|
457
920
|
|
|
458
921
|
return stub_method1
|
|
459
922
|
|
|
460
923
|
case 2:
|
|
461
924
|
|
|
462
|
-
async def stub_method2(
|
|
925
|
+
async def stub_method2(
|
|
926
|
+
self: object,
|
|
927
|
+
request: Any,
|
|
928
|
+
context: Any,
|
|
929
|
+
method: Callable[..., Awaitable[Message]] = method,
|
|
930
|
+
) -> Any:
|
|
931
|
+
_ = self
|
|
463
932
|
try:
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
933
|
+
if is_none_type(arg_type):
|
|
934
|
+
resp_obj = await method(None, context)
|
|
935
|
+
else:
|
|
936
|
+
arg = converter(request)
|
|
937
|
+
resp_obj = await method(arg, context)
|
|
938
|
+
|
|
939
|
+
if is_none_type(response_type):
|
|
940
|
+
return empty_pb2.Empty() # type: ignore
|
|
941
|
+
else:
|
|
942
|
+
return convert_python_message_to_proto(
|
|
943
|
+
resp_obj, response_type, pb2_module
|
|
944
|
+
)
|
|
469
945
|
except ValidationError as e:
|
|
470
|
-
await context.abort(Errors.
|
|
946
|
+
await context.abort(Errors.INVALID_ARGUMENT, str(e))
|
|
471
947
|
except Exception as e:
|
|
472
|
-
await context.abort(Errors.
|
|
948
|
+
await context.abort(Errors.INTERNAL, str(e))
|
|
473
949
|
|
|
474
950
|
return stub_method2
|
|
475
951
|
|
|
476
952
|
case _:
|
|
477
|
-
raise Exception("Method must have
|
|
953
|
+
raise Exception("Method must have 0, 1, or 2 parameters")
|
|
478
954
|
|
|
479
955
|
for method_name, method in get_rpc_methods(obj):
|
|
480
956
|
if method.__name__.startswith("_"):
|
|
@@ -487,36 +963,202 @@ def connect_obj_with_stub_async_connecpy(
|
|
|
487
963
|
return ConcreteServiceClass
|
|
488
964
|
|
|
489
965
|
|
|
966
|
+
def python_value_to_proto_oneof(
|
|
967
|
+
field_name: str,
|
|
968
|
+
field_type: type[Any],
|
|
969
|
+
value: Any,
|
|
970
|
+
pb2_module: Any,
|
|
971
|
+
_visited: set[int] | None = None,
|
|
972
|
+
_depth: int = 0,
|
|
973
|
+
) -> tuple[str, Any]:
|
|
974
|
+
"""
|
|
975
|
+
Converts a Python value from a Union type to a protobuf oneof field.
|
|
976
|
+
Returns the field name to set and the converted value.
|
|
977
|
+
"""
|
|
978
|
+
union_args = [arg for arg in flatten_union(field_type) if arg is not type(None)]
|
|
979
|
+
|
|
980
|
+
# Find which subtype in the Union matches the value's type.
|
|
981
|
+
actual_type = None
|
|
982
|
+
for sub_type in union_args:
|
|
983
|
+
origin = get_origin(sub_type)
|
|
984
|
+
type_to_check = origin or sub_type
|
|
985
|
+
try:
|
|
986
|
+
if isinstance(value, type_to_check):
|
|
987
|
+
actual_type = sub_type
|
|
988
|
+
break
|
|
989
|
+
except TypeError:
|
|
990
|
+
# This can happen if `sub_type` is not a class, e.g. a generic alias
|
|
991
|
+
if isinstance(value, type_to_check):
|
|
992
|
+
actual_type = sub_type
|
|
993
|
+
break
|
|
994
|
+
|
|
995
|
+
if actual_type is None:
|
|
996
|
+
raise TypeError(f"Value of type {type(value)} not found in union {field_type}")
|
|
997
|
+
|
|
998
|
+
proto_typename = protobuf_type_mapping(actual_type)
|
|
999
|
+
if proto_typename is None:
|
|
1000
|
+
raise TypeError(f"Unsupported type in oneof: {actual_type}")
|
|
1001
|
+
|
|
1002
|
+
oneof_field_name = f"{field_name}_{proto_typename.replace('.', '_')}"
|
|
1003
|
+
converted_value = python_value_to_proto(
|
|
1004
|
+
actual_type, value, pb2_module, _visited, _depth
|
|
1005
|
+
)
|
|
1006
|
+
return oneof_field_name, converted_value
|
|
1007
|
+
|
|
1008
|
+
|
|
490
1009
|
def convert_python_message_to_proto(
|
|
491
|
-
py_msg: Message,
|
|
1010
|
+
py_msg: Message,
|
|
1011
|
+
msg_type: type[Message],
|
|
1012
|
+
pb2_module: Any,
|
|
1013
|
+
_visited: set[int] | None = None,
|
|
1014
|
+
_depth: int = 0,
|
|
492
1015
|
) -> object:
|
|
1016
|
+
"""Convert a Python Pydantic Message instance to a protobuf message instance. Used for constructing a response.
|
|
1017
|
+
|
|
1018
|
+
Args:
|
|
1019
|
+
py_msg: The Pydantic Message instance to convert
|
|
1020
|
+
msg_type: The type of the Message
|
|
1021
|
+
pb2_module: The protobuf module
|
|
1022
|
+
_visited: Internal set to track visited objects for circular reference detection
|
|
1023
|
+
_depth: Current recursion depth for strategy control
|
|
493
1024
|
"""
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
1025
|
+
# Handle empty message classes
|
|
1026
|
+
if not msg_type.model_fields:
|
|
1027
|
+
# Return google.protobuf.Empty instance
|
|
1028
|
+
return empty_pb2.Empty()
|
|
1029
|
+
|
|
1030
|
+
# Initialize visited set for circular reference detection
|
|
1031
|
+
if _visited is None:
|
|
1032
|
+
_visited = set()
|
|
1033
|
+
|
|
1034
|
+
# Check for circular references
|
|
1035
|
+
obj_id = id(py_msg)
|
|
1036
|
+
if obj_id in _visited:
|
|
1037
|
+
# Return empty proto to avoid infinite recursion
|
|
1038
|
+
proto_class = getattr(pb2_module, msg_type.__name__)
|
|
1039
|
+
return proto_class()
|
|
1040
|
+
_visited.add(obj_id)
|
|
1041
|
+
|
|
1042
|
+
# Determine if we should apply serializers based on strategy
|
|
1043
|
+
apply_serializers = False
|
|
1044
|
+
if _SERIALIZER_STRATEGY == SerializerStrategy.DEEP:
|
|
1045
|
+
apply_serializers = True # Always apply at any depth
|
|
1046
|
+
elif _SERIALIZER_STRATEGY == SerializerStrategy.SHALLOW:
|
|
1047
|
+
apply_serializers = _depth == 0 # Only at top level
|
|
1048
|
+
# SerializerStrategy.NONE: never apply
|
|
1049
|
+
|
|
1050
|
+
# Only use model_dump if there are serializers and strategy allows
|
|
1051
|
+
serialized_data = None
|
|
1052
|
+
if apply_serializers and has_serializers(msg_type):
|
|
1053
|
+
try:
|
|
1054
|
+
serialized_data = py_msg.model_dump(mode="python")
|
|
1055
|
+
except Exception:
|
|
1056
|
+
# Fallback to the old approach if model_dump fails
|
|
1057
|
+
serialized_data = None
|
|
1058
|
+
|
|
499
1059
|
field_dict = {}
|
|
500
1060
|
for name, field_info in msg_type.model_fields.items():
|
|
1061
|
+
field_type = field_info.annotation
|
|
1062
|
+
|
|
1063
|
+
# Check if this field type contains Message types
|
|
1064
|
+
contains_message = False
|
|
1065
|
+
if field_type:
|
|
1066
|
+
if inspect.isclass(field_type) and issubclass(field_type, Message):
|
|
1067
|
+
contains_message = True
|
|
1068
|
+
elif is_union_type(field_type):
|
|
1069
|
+
# Check if any of the union args are Messages
|
|
1070
|
+
for arg in flatten_union(field_type):
|
|
1071
|
+
if arg and inspect.isclass(arg) and issubclass(arg, Message):
|
|
1072
|
+
contains_message = True
|
|
1073
|
+
break
|
|
1074
|
+
else:
|
|
1075
|
+
# Check for list/dict of Messages
|
|
1076
|
+
origin = get_origin(field_type)
|
|
1077
|
+
if origin in (list, tuple):
|
|
1078
|
+
inner_type = (
|
|
1079
|
+
get_args(field_type)[0] if get_args(field_type) else None
|
|
1080
|
+
)
|
|
1081
|
+
if (
|
|
1082
|
+
inner_type
|
|
1083
|
+
and inspect.isclass(inner_type)
|
|
1084
|
+
and issubclass(inner_type, Message)
|
|
1085
|
+
):
|
|
1086
|
+
contains_message = True
|
|
1087
|
+
elif origin is dict:
|
|
1088
|
+
args = get_args(field_type)
|
|
1089
|
+
if len(args) >= 2:
|
|
1090
|
+
val_type = args[1]
|
|
1091
|
+
if inspect.isclass(val_type) and issubclass(val_type, Message):
|
|
1092
|
+
contains_message = True
|
|
1093
|
+
|
|
1094
|
+
# Get the value
|
|
501
1095
|
value = getattr(py_msg, name)
|
|
502
1096
|
if value is None:
|
|
503
|
-
field_dict[name] = None
|
|
504
1097
|
continue
|
|
505
1098
|
|
|
1099
|
+
# For Message types, recursively apply serialization
|
|
1100
|
+
if contains_message and not is_union_type(field_type):
|
|
1101
|
+
# Direct Message field
|
|
1102
|
+
if inspect.isclass(field_type) and issubclass(field_type, Message):
|
|
1103
|
+
# Recursively convert nested Message with serializers
|
|
1104
|
+
field_dict[name] = convert_python_message_to_proto(
|
|
1105
|
+
value, field_type, pb2_module, _visited, _depth + 1
|
|
1106
|
+
)
|
|
1107
|
+
continue
|
|
1108
|
+
|
|
1109
|
+
# Use serialized data for non-Message types to respect custom serializers
|
|
1110
|
+
if (
|
|
1111
|
+
serialized_data is not None
|
|
1112
|
+
and name in serialized_data
|
|
1113
|
+
and not contains_message
|
|
1114
|
+
):
|
|
1115
|
+
value = serialized_data[name]
|
|
1116
|
+
|
|
506
1117
|
field_type = field_info.annotation
|
|
507
|
-
|
|
1118
|
+
|
|
1119
|
+
# Handle oneof fields, which are represented as Unions.
|
|
1120
|
+
if field_type is not None and is_union_type(field_type):
|
|
1121
|
+
union_args = [
|
|
1122
|
+
arg for arg in flatten_union(field_type) if arg is not type(None)
|
|
1123
|
+
]
|
|
1124
|
+
if len(union_args) > 1:
|
|
1125
|
+
# It's a oneof field. We need to determine the concrete type and
|
|
1126
|
+
# the corresponding protobuf field name.
|
|
1127
|
+
(
|
|
1128
|
+
oneof_field_name,
|
|
1129
|
+
converted_value,
|
|
1130
|
+
) = python_value_to_proto_oneof(
|
|
1131
|
+
name, field_type, value, pb2_module, _visited, _depth
|
|
1132
|
+
)
|
|
1133
|
+
field_dict[oneof_field_name] = converted_value
|
|
1134
|
+
continue
|
|
1135
|
+
|
|
1136
|
+
# For regular and Optional fields that have a value.
|
|
1137
|
+
if field_type is not None:
|
|
1138
|
+
field_dict[name] = python_value_to_proto(
|
|
1139
|
+
field_type, value, pb2_module, _visited, _depth
|
|
1140
|
+
)
|
|
1141
|
+
|
|
1142
|
+
# Remove from visited set when done
|
|
1143
|
+
_visited.discard(obj_id)
|
|
508
1144
|
|
|
509
1145
|
# Retrieve the appropriate protobuf class dynamically
|
|
510
1146
|
proto_class = getattr(pb2_module, msg_type.__name__)
|
|
511
1147
|
return proto_class(**field_dict)
|
|
512
1148
|
|
|
513
1149
|
|
|
514
|
-
def python_value_to_proto(
|
|
1150
|
+
def python_value_to_proto(
|
|
1151
|
+
field_type: type[Any],
|
|
1152
|
+
value: Any,
|
|
1153
|
+
pb2_module: Any,
|
|
1154
|
+
_visited: set[int] | None = None,
|
|
1155
|
+
_depth: int = 0,
|
|
1156
|
+
) -> Any:
|
|
515
1157
|
"""
|
|
516
1158
|
Perform Python->protobuf type conversion for each field value.
|
|
517
1159
|
"""
|
|
518
|
-
import inspect
|
|
519
1160
|
import datetime
|
|
1161
|
+
import inspect
|
|
520
1162
|
|
|
521
1163
|
# If datetime
|
|
522
1164
|
if field_type == datetime.datetime:
|
|
@@ -533,48 +1175,58 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
|
|
|
533
1175
|
origin = get_origin(field_type)
|
|
534
1176
|
# If seq
|
|
535
1177
|
if origin in (list, tuple):
|
|
536
|
-
inner_type = get_args(field_type)[0]
|
|
537
|
-
|
|
1178
|
+
inner_type = get_args(field_type)[0]
|
|
1179
|
+
# Handle list of Messages
|
|
1180
|
+
if inspect.isclass(inner_type) and issubclass(inner_type, Message):
|
|
1181
|
+
return [
|
|
1182
|
+
convert_python_message_to_proto(
|
|
1183
|
+
v, inner_type, pb2_module, _visited, _depth + 1
|
|
1184
|
+
)
|
|
1185
|
+
for v in value
|
|
1186
|
+
]
|
|
1187
|
+
return [
|
|
1188
|
+
python_value_to_proto(inner_type, v, pb2_module, _visited, _depth)
|
|
1189
|
+
for v in value
|
|
1190
|
+
]
|
|
538
1191
|
|
|
539
1192
|
# If dict
|
|
540
1193
|
if origin is dict:
|
|
541
|
-
key_type, val_type = get_args(field_type)
|
|
1194
|
+
key_type, val_type = get_args(field_type)
|
|
1195
|
+
# Handle dict with Message values
|
|
1196
|
+
if inspect.isclass(val_type) and issubclass(val_type, Message):
|
|
1197
|
+
return {
|
|
1198
|
+
python_value_to_proto(
|
|
1199
|
+
key_type, k, pb2_module, _visited, _depth
|
|
1200
|
+
): convert_python_message_to_proto(
|
|
1201
|
+
v, val_type, pb2_module, _visited, _depth + 1
|
|
1202
|
+
)
|
|
1203
|
+
for k, v in value.items()
|
|
1204
|
+
}
|
|
542
1205
|
return {
|
|
543
|
-
python_value_to_proto(
|
|
544
|
-
|
|
545
|
-
)
|
|
1206
|
+
python_value_to_proto(
|
|
1207
|
+
key_type, k, pb2_module, _visited, _depth
|
|
1208
|
+
): python_value_to_proto(val_type, v, pb2_module, _visited, _depth)
|
|
546
1209
|
for k, v in value.items()
|
|
547
1210
|
}
|
|
548
1211
|
|
|
549
|
-
# If union -> oneof
|
|
1212
|
+
# If union -> oneof. This path is now only for Optional[T] where value is not None.
|
|
550
1213
|
if is_union_type(field_type):
|
|
551
|
-
#
|
|
552
|
-
|
|
553
|
-
if
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
):
|
|
562
|
-
return value.value
|
|
563
|
-
if sub_type in (int, float, str, bool, bytes) and isinstance(
|
|
564
|
-
value, sub_type
|
|
565
|
-
):
|
|
566
|
-
return value
|
|
567
|
-
if (
|
|
568
|
-
inspect.isclass(sub_type)
|
|
569
|
-
and issubclass(sub_type, Message)
|
|
570
|
-
and isinstance(value, Message)
|
|
571
|
-
):
|
|
572
|
-
return convert_python_message_to_proto(value, sub_type, pb2_module)
|
|
573
|
-
return None
|
|
1214
|
+
# The value is not None, so we need to find the actual type.
|
|
1215
|
+
non_none_args = [
|
|
1216
|
+
arg for arg in flatten_union(field_type) if arg is not type(None)
|
|
1217
|
+
]
|
|
1218
|
+
if non_none_args:
|
|
1219
|
+
# Assuming it's an Optional[T], so there's one type left.
|
|
1220
|
+
return python_value_to_proto(
|
|
1221
|
+
non_none_args[0], value, pb2_module, _visited, _depth
|
|
1222
|
+
)
|
|
1223
|
+
return None # Should not be reached if value is not None
|
|
574
1224
|
|
|
575
1225
|
# If Message
|
|
576
1226
|
if inspect.isclass(field_type) and issubclass(field_type, Message):
|
|
577
|
-
return convert_python_message_to_proto(
|
|
1227
|
+
return convert_python_message_to_proto(
|
|
1228
|
+
value, field_type, pb2_module, _visited, _depth + 1
|
|
1229
|
+
)
|
|
578
1230
|
|
|
579
1231
|
# If primitive
|
|
580
1232
|
return value
|
|
@@ -585,12 +1237,12 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
|
|
|
585
1237
|
###############################################################################
|
|
586
1238
|
|
|
587
1239
|
|
|
588
|
-
def is_enum_type(python_type:
|
|
1240
|
+
def is_enum_type(python_type: Any) -> bool:
|
|
589
1241
|
"""Return True if the given Python type is an enum."""
|
|
590
1242
|
return inspect.isclass(python_type) and issubclass(python_type, enum.Enum)
|
|
591
1243
|
|
|
592
1244
|
|
|
593
|
-
def is_union_type(python_type:
|
|
1245
|
+
def is_union_type(python_type: Any) -> bool:
|
|
594
1246
|
"""
|
|
595
1247
|
Check if a given Python type is a Union type (including Python 3.10's UnionType).
|
|
596
1248
|
"""
|
|
@@ -599,26 +1251,28 @@ def is_union_type(python_type: Type) -> bool:
|
|
|
599
1251
|
if sys.version_info >= (3, 10):
|
|
600
1252
|
import types
|
|
601
1253
|
|
|
602
|
-
if
|
|
1254
|
+
if isinstance(python_type, types.UnionType):
|
|
603
1255
|
return True
|
|
604
1256
|
return False
|
|
605
1257
|
|
|
606
1258
|
|
|
607
|
-
def flatten_union(field_type:
|
|
1259
|
+
def flatten_union(field_type: Any) -> list[Any]:
|
|
608
1260
|
"""Recursively flatten nested Unions into a single list of types."""
|
|
609
1261
|
if is_union_type(field_type):
|
|
610
1262
|
results = []
|
|
611
1263
|
for arg in get_args(field_type):
|
|
612
1264
|
results.extend(flatten_union(arg))
|
|
613
1265
|
return results
|
|
1266
|
+
elif is_none_type(field_type):
|
|
1267
|
+
return [field_type]
|
|
614
1268
|
else:
|
|
615
1269
|
return [field_type]
|
|
616
1270
|
|
|
617
1271
|
|
|
618
|
-
def protobuf_type_mapping(python_type:
|
|
1272
|
+
def protobuf_type_mapping(python_type: Any) -> str | None:
|
|
619
1273
|
"""
|
|
620
1274
|
Map a Python type to a protobuf type name/class.
|
|
621
|
-
Includes support for Timestamp and
|
|
1275
|
+
Includes support for Timestamp, Duration, and Empty.
|
|
622
1276
|
"""
|
|
623
1277
|
import datetime
|
|
624
1278
|
|
|
@@ -636,8 +1290,11 @@ def protobuf_type_mapping(python_type: Type) -> str | type | None:
|
|
|
636
1290
|
if python_type == datetime.timedelta:
|
|
637
1291
|
return "google.protobuf.Duration"
|
|
638
1292
|
|
|
1293
|
+
if is_none_type(python_type):
|
|
1294
|
+
return "google.protobuf.Empty"
|
|
1295
|
+
|
|
639
1296
|
if is_enum_type(python_type):
|
|
640
|
-
return python_type
|
|
1297
|
+
return python_type.__name__
|
|
641
1298
|
|
|
642
1299
|
if is_union_type(python_type):
|
|
643
1300
|
return None # Handled separately as oneof
|
|
@@ -657,9 +1314,12 @@ def protobuf_type_mapping(python_type: Type) -> str | type | None:
|
|
|
657
1314
|
return f"map<{key_proto_type}, {value_proto_type}>"
|
|
658
1315
|
|
|
659
1316
|
if inspect.isclass(python_type) and issubclass(python_type, Message):
|
|
660
|
-
|
|
1317
|
+
# Check if it's an empty message
|
|
1318
|
+
if not python_type.model_fields:
|
|
1319
|
+
return "google.protobuf.Empty"
|
|
1320
|
+
return python_type.__name__
|
|
661
1321
|
|
|
662
|
-
return mapping.get(python_type)
|
|
1322
|
+
return mapping.get(python_type)
|
|
663
1323
|
|
|
664
1324
|
|
|
665
1325
|
def comment_out(docstr: str) -> tuple[str, ...]:
|
|
@@ -673,15 +1333,15 @@ def comment_out(docstr: str) -> tuple[str, ...]:
|
|
|
673
1333
|
return tuple("//" if line == "" else f"// {line}" for line in docstr.split("\n"))
|
|
674
1334
|
|
|
675
1335
|
|
|
676
|
-
def indent_lines(lines, indentation=" "):
|
|
1336
|
+
def indent_lines(lines: list[str], indentation: str = " ") -> str:
|
|
677
1337
|
"""Indent multiple lines with a given indentation string."""
|
|
678
1338
|
return "\n".join(indentation + line for line in lines)
|
|
679
1339
|
|
|
680
1340
|
|
|
681
|
-
def generate_enum_definition(enum_type:
|
|
1341
|
+
def generate_enum_definition(enum_type: Any) -> str:
|
|
682
1342
|
"""Generate a protobuf enum definition from a Python enum."""
|
|
683
1343
|
enum_name = enum_type.__name__
|
|
684
|
-
members = []
|
|
1344
|
+
members: list[str] = []
|
|
685
1345
|
for _, member in enum_type.__members__.items():
|
|
686
1346
|
members.append(f" {member.name} = {member.value};")
|
|
687
1347
|
enum_def = f"enum {enum_name} {{\n"
|
|
@@ -691,7 +1351,7 @@ def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
|
|
|
691
1351
|
|
|
692
1352
|
|
|
693
1353
|
def generate_oneof_definition(
|
|
694
|
-
field_name: str, union_args: list[
|
|
1354
|
+
field_name: str, union_args: list[Any], start_index: int
|
|
695
1355
|
) -> tuple[list[str], int]:
|
|
696
1356
|
"""
|
|
697
1357
|
Generate a oneof block in protobuf for a union field.
|
|
@@ -705,12 +1365,6 @@ def generate_oneof_definition(
|
|
|
705
1365
|
if proto_typename is None:
|
|
706
1366
|
raise Exception(f"Nested Union not flattened properly: {arg_type}")
|
|
707
1367
|
|
|
708
|
-
# If it's an enum or Message, use the type name.
|
|
709
|
-
if is_enum_type(arg_type):
|
|
710
|
-
proto_typename = arg_type.__name__
|
|
711
|
-
elif inspect.isclass(arg_type) and issubclass(arg_type, Message):
|
|
712
|
-
proto_typename = arg_type.__name__
|
|
713
|
-
|
|
714
1368
|
field_alias = f"{field_name}_{proto_typename.replace('.', '_')}"
|
|
715
1369
|
lines.append(f" {proto_typename} {field_alias} = {current};")
|
|
716
1370
|
current += 1
|
|
@@ -718,18 +1372,57 @@ def generate_oneof_definition(
|
|
|
718
1372
|
return lines, current
|
|
719
1373
|
|
|
720
1374
|
|
|
1375
|
+
def extract_nested_types(field_type: Any) -> list[Any]:
|
|
1376
|
+
"""
|
|
1377
|
+
Recursively extract all Message and enum types from a field type,
|
|
1378
|
+
including those nested within list, dict, and union types.
|
|
1379
|
+
"""
|
|
1380
|
+
extracted_types = []
|
|
1381
|
+
|
|
1382
|
+
if field_type is None or is_none_type(field_type):
|
|
1383
|
+
return extracted_types
|
|
1384
|
+
|
|
1385
|
+
# Check if the type itself is an enum or Message
|
|
1386
|
+
if is_enum_type(field_type):
|
|
1387
|
+
extracted_types.append(field_type)
|
|
1388
|
+
elif inspect.isclass(field_type) and issubclass(field_type, Message):
|
|
1389
|
+
extracted_types.append(field_type)
|
|
1390
|
+
|
|
1391
|
+
# Handle Union types
|
|
1392
|
+
if is_union_type(field_type):
|
|
1393
|
+
union_args = flatten_union(field_type)
|
|
1394
|
+
for arg in union_args:
|
|
1395
|
+
if arg is not type(None):
|
|
1396
|
+
extracted_types.extend(extract_nested_types(arg))
|
|
1397
|
+
|
|
1398
|
+
# Handle generic types (list, dict, etc.)
|
|
1399
|
+
origin = get_origin(field_type)
|
|
1400
|
+
if origin is not None:
|
|
1401
|
+
args = get_args(field_type)
|
|
1402
|
+
for arg in args:
|
|
1403
|
+
extracted_types.extend(extract_nested_types(arg))
|
|
1404
|
+
|
|
1405
|
+
return extracted_types
|
|
1406
|
+
|
|
1407
|
+
|
|
721
1408
|
def generate_message_definition(
|
|
722
|
-
message_type:
|
|
723
|
-
done_enums: set,
|
|
724
|
-
done_messages: set,
|
|
725
|
-
) -> tuple[str, list[
|
|
1409
|
+
message_type: Any,
|
|
1410
|
+
done_enums: set[Any],
|
|
1411
|
+
done_messages: set[Any],
|
|
1412
|
+
) -> tuple[str, list[Any]]:
|
|
726
1413
|
"""
|
|
727
1414
|
Generate a protobuf message definition for a Pydantic-based Message class.
|
|
728
1415
|
Also returns any referenced types (enums, messages) that need to be defined.
|
|
729
1416
|
"""
|
|
730
|
-
fields = []
|
|
731
|
-
refs = []
|
|
1417
|
+
fields: list[str] = []
|
|
1418
|
+
refs: list[Any] = []
|
|
732
1419
|
pydantic_fields = message_type.model_fields
|
|
1420
|
+
|
|
1421
|
+
# Check if this is an empty message and should use google.protobuf.Empty
|
|
1422
|
+
if not pydantic_fields:
|
|
1423
|
+
# Return a special marker that indicates this should use google.protobuf.Empty
|
|
1424
|
+
return "__EMPTY__", []
|
|
1425
|
+
|
|
733
1426
|
index = 1
|
|
734
1427
|
|
|
735
1428
|
for field_name, field_info in pydantic_fields.items():
|
|
@@ -737,92 +1430,130 @@ def generate_message_definition(
|
|
|
737
1430
|
if field_type is None:
|
|
738
1431
|
raise Exception(f"Field {field_name} has no type annotation.")
|
|
739
1432
|
|
|
1433
|
+
is_optional = False
|
|
1434
|
+
# Handle Union types, which may be Optional or a oneof.
|
|
740
1435
|
if is_union_type(field_type):
|
|
741
1436
|
union_args = flatten_union(field_type)
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
)
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
1437
|
+
none_type = type(
|
|
1438
|
+
None
|
|
1439
|
+
) # Keep this as type(None) since we're working with union args
|
|
1440
|
+
|
|
1441
|
+
if none_type in union_args or None in union_args:
|
|
1442
|
+
is_optional = True
|
|
1443
|
+
union_args = [arg for arg in union_args if not is_none_type(arg)]
|
|
1444
|
+
|
|
1445
|
+
if len(union_args) == 1:
|
|
1446
|
+
# This is an Optional[T]. Treat it as a simple optional field.
|
|
1447
|
+
field_type = union_args[0]
|
|
1448
|
+
elif len(union_args) > 1:
|
|
1449
|
+
# This is a Union of multiple types, so it becomes a `oneof`.
|
|
1450
|
+
oneof_lines, new_index = generate_oneof_definition(
|
|
1451
|
+
field_name, union_args, index
|
|
1452
|
+
)
|
|
1453
|
+
fields.extend(oneof_lines)
|
|
1454
|
+
index = new_index
|
|
1455
|
+
|
|
1456
|
+
for utype in union_args:
|
|
1457
|
+
if is_enum_type(utype) and utype not in done_enums:
|
|
1458
|
+
refs.append(utype)
|
|
1459
|
+
elif (
|
|
1460
|
+
inspect.isclass(utype)
|
|
1461
|
+
and issubclass(utype, Message)
|
|
1462
|
+
and utype not in done_messages
|
|
1463
|
+
):
|
|
1464
|
+
refs.append(utype)
|
|
1465
|
+
continue # Proceed to the next field
|
|
1466
|
+
else:
|
|
1467
|
+
# This was a field of only `NoneType`, which is not supported.
|
|
1468
|
+
continue
|
|
757
1469
|
|
|
1470
|
+
# For regular fields or optional fields that have been unwrapped.
|
|
1471
|
+
proto_typename = protobuf_type_mapping(field_type)
|
|
1472
|
+
if proto_typename is None:
|
|
1473
|
+
raise Exception(f"Type {field_type} is not supported.")
|
|
1474
|
+
|
|
1475
|
+
# Extract all nested Message and enum types recursively
|
|
1476
|
+
nested_types = extract_nested_types(field_type)
|
|
1477
|
+
for nested_type in nested_types:
|
|
1478
|
+
if is_enum_type(nested_type) and nested_type not in done_enums:
|
|
1479
|
+
refs.append(nested_type)
|
|
1480
|
+
elif (
|
|
1481
|
+
inspect.isclass(nested_type)
|
|
1482
|
+
and issubclass(nested_type, Message)
|
|
1483
|
+
and nested_type not in done_messages
|
|
1484
|
+
):
|
|
1485
|
+
refs.append(nested_type)
|
|
1486
|
+
|
|
1487
|
+
if field_info.description:
|
|
1488
|
+
fields.append("// " + field_info.description)
|
|
1489
|
+
if field_info.metadata:
|
|
1490
|
+
fields.append("// Constraint:")
|
|
1491
|
+
for metadata_item in field_info.metadata:
|
|
1492
|
+
match type(metadata_item):
|
|
1493
|
+
case annotated_types.Ge:
|
|
1494
|
+
fields.append(
|
|
1495
|
+
"// greater than or equal to " + str(metadata_item.ge)
|
|
1496
|
+
)
|
|
1497
|
+
case annotated_types.Le:
|
|
1498
|
+
fields.append(
|
|
1499
|
+
"// less than or equal to " + str(metadata_item.le)
|
|
1500
|
+
)
|
|
1501
|
+
case annotated_types.Gt:
|
|
1502
|
+
fields.append("// greater than " + str(metadata_item.gt))
|
|
1503
|
+
case annotated_types.Lt:
|
|
1504
|
+
fields.append("// less than " + str(metadata_item.lt))
|
|
1505
|
+
case annotated_types.MultipleOf:
|
|
1506
|
+
fields.append(
|
|
1507
|
+
"// multiple of " + str(metadata_item.multiple_of)
|
|
1508
|
+
)
|
|
1509
|
+
case annotated_types.Len:
|
|
1510
|
+
fields.append("// length of " + str(metadata_item.len))
|
|
1511
|
+
case annotated_types.MinLen:
|
|
1512
|
+
fields.append(
|
|
1513
|
+
"// minimum length of " + str(metadata_item.min_len)
|
|
1514
|
+
)
|
|
1515
|
+
case annotated_types.MaxLen:
|
|
1516
|
+
fields.append(
|
|
1517
|
+
"// maximum length of " + str(metadata_item.max_len)
|
|
1518
|
+
)
|
|
1519
|
+
case _:
|
|
1520
|
+
fields.append("// " + str(metadata_item))
|
|
1521
|
+
|
|
1522
|
+
field_definition = f"{proto_typename} {field_name} = {index};"
|
|
1523
|
+
if is_optional:
|
|
1524
|
+
field_definition = f"optional {field_definition}"
|
|
1525
|
+
|
|
1526
|
+
fields.append(field_definition)
|
|
1527
|
+
index += 1
|
|
1528
|
+
|
|
1529
|
+
# Add reserved fields for forward/backward compatibility if specified
|
|
1530
|
+
reserved_count = get_reserved_fields_count()
|
|
1531
|
+
if reserved_count > 0:
|
|
1532
|
+
start_reserved = index
|
|
1533
|
+
end_reserved = index + reserved_count - 1
|
|
1534
|
+
fields.append("")
|
|
1535
|
+
fields.append("// Reserved fields for future compatibility")
|
|
1536
|
+
if reserved_count == 1:
|
|
1537
|
+
fields.append(f"reserved {start_reserved};")
|
|
758
1538
|
else:
|
|
759
|
-
|
|
760
|
-
if proto_typename is None:
|
|
761
|
-
raise Exception(f"Type {field_type} is not supported.")
|
|
762
|
-
|
|
763
|
-
if is_enum_type(field_type):
|
|
764
|
-
proto_typename = field_type.__name__
|
|
765
|
-
if field_type not in done_enums:
|
|
766
|
-
refs.append(field_type)
|
|
767
|
-
elif inspect.isclass(field_type) and issubclass(field_type, Message):
|
|
768
|
-
proto_typename = field_type.__name__
|
|
769
|
-
if field_type not in done_messages:
|
|
770
|
-
refs.append(field_type)
|
|
771
|
-
|
|
772
|
-
if field_info.description:
|
|
773
|
-
fields.append("// " + field_info.description)
|
|
774
|
-
if field_info.metadata:
|
|
775
|
-
fields.append("// Constraint:")
|
|
776
|
-
for metadata_item in field_info.metadata:
|
|
777
|
-
match type(metadata_item):
|
|
778
|
-
case annotated_types.Ge:
|
|
779
|
-
fields.append(
|
|
780
|
-
"// greater than or equal to " + str(metadata_item.ge)
|
|
781
|
-
)
|
|
782
|
-
case annotated_types.Le:
|
|
783
|
-
fields.append(
|
|
784
|
-
"// less than or equal to " + str(metadata_item.le)
|
|
785
|
-
)
|
|
786
|
-
case annotated_types.Gt:
|
|
787
|
-
fields.append("// greater than " + str(metadata_item.gt))
|
|
788
|
-
case annotated_types.Lt:
|
|
789
|
-
fields.append("// less than " + str(metadata_item.lt))
|
|
790
|
-
case annotated_types.MultipleOf:
|
|
791
|
-
fields.append(
|
|
792
|
-
"// multiple of " + str(metadata_item.multiple_of)
|
|
793
|
-
)
|
|
794
|
-
case annotated_types.Len:
|
|
795
|
-
fields.append("// length of " + str(metadata_item.len))
|
|
796
|
-
case annotated_types.MinLen:
|
|
797
|
-
fields.append(
|
|
798
|
-
"// minimum length of " + str(metadata_item.min_len)
|
|
799
|
-
)
|
|
800
|
-
case annotated_types.MaxLen:
|
|
801
|
-
fields.append(
|
|
802
|
-
"// maximum length of " + str(metadata_item.max_len)
|
|
803
|
-
)
|
|
804
|
-
case _:
|
|
805
|
-
fields.append("// " + str(metadata_item))
|
|
806
|
-
|
|
807
|
-
fields.append(f"{proto_typename} {field_name} = {index};")
|
|
808
|
-
index += 1
|
|
1539
|
+
fields.append(f"reserved {start_reserved} to {end_reserved};")
|
|
809
1540
|
|
|
810
1541
|
msg_def = f"message {message_type.__name__} {{\n{indent_lines(fields)}\n}}"
|
|
811
1542
|
return msg_def, refs
|
|
812
1543
|
|
|
813
1544
|
|
|
814
|
-
def is_stream_type(annotation:
|
|
1545
|
+
def is_stream_type(annotation: Any) -> bool:
|
|
815
1546
|
return get_origin(annotation) is AsyncIterator
|
|
816
1547
|
|
|
817
1548
|
|
|
818
|
-
def is_generic_alias(annotation:
|
|
1549
|
+
def is_generic_alias(annotation: Any) -> bool:
|
|
819
1550
|
return get_origin(annotation) is not None
|
|
820
1551
|
|
|
821
1552
|
|
|
822
1553
|
def generate_proto(obj: object, package_name: str = "") -> str:
|
|
823
1554
|
"""
|
|
824
1555
|
Generate a .proto definition from a service class.
|
|
825
|
-
Automatically handles Timestamp and
|
|
1556
|
+
Automatically handles Timestamp, Duration, and Empty usage.
|
|
826
1557
|
"""
|
|
827
1558
|
import datetime
|
|
828
1559
|
|
|
@@ -831,20 +1562,23 @@ def generate_proto(obj: object, package_name: str = "") -> str:
|
|
|
831
1562
|
service_docstr = inspect.getdoc(service_class)
|
|
832
1563
|
service_comment = "\n".join(comment_out(service_docstr)) if service_docstr else ""
|
|
833
1564
|
|
|
834
|
-
rpc_definitions = []
|
|
835
|
-
all_type_definitions = []
|
|
836
|
-
done_messages = set()
|
|
837
|
-
done_enums = set()
|
|
1565
|
+
rpc_definitions: list[str] = []
|
|
1566
|
+
all_type_definitions: list[str] = []
|
|
1567
|
+
done_messages: set[Any] = set()
|
|
1568
|
+
done_enums: set[Any] = set()
|
|
838
1569
|
|
|
839
1570
|
uses_timestamp = False
|
|
840
1571
|
uses_duration = False
|
|
1572
|
+
uses_empty = False
|
|
841
1573
|
|
|
842
|
-
def
|
|
1574
|
+
def check_and_set_well_known_types_for_fields(py_type: Any):
|
|
1575
|
+
"""Check well-known types for field annotations (excludes None/Empty)."""
|
|
843
1576
|
nonlocal uses_timestamp, uses_duration
|
|
844
1577
|
if py_type == datetime.datetime:
|
|
845
1578
|
uses_timestamp = True
|
|
846
1579
|
if py_type == datetime.timedelta:
|
|
847
1580
|
uses_duration = True
|
|
1581
|
+
# Don't check for None here - Optional fields don't use Empty
|
|
848
1582
|
|
|
849
1583
|
for method_name, method in get_rpc_methods(obj):
|
|
850
1584
|
if method.__name__.startswith("_"):
|
|
@@ -854,35 +1588,79 @@ def generate_proto(obj: object, package_name: str = "") -> str:
|
|
|
854
1588
|
request_type = get_request_arg_type(method_sig)
|
|
855
1589
|
response_type = method_sig.return_annotation
|
|
856
1590
|
|
|
1591
|
+
# Validate that we don't have AsyncIterator[None] which doesn't make any practical sense
|
|
1592
|
+
if is_stream_type(request_type):
|
|
1593
|
+
stream_item_type = get_args(request_type)[0]
|
|
1594
|
+
if is_none_type(stream_item_type):
|
|
1595
|
+
raise TypeError(
|
|
1596
|
+
f"Method '{method_name}' has AsyncIterator[None] as input type, which is not allowed. Streaming Empty messages is meaningless."
|
|
1597
|
+
)
|
|
1598
|
+
|
|
1599
|
+
if is_stream_type(response_type):
|
|
1600
|
+
stream_item_type = get_args(response_type)[0]
|
|
1601
|
+
if is_none_type(stream_item_type):
|
|
1602
|
+
raise TypeError(
|
|
1603
|
+
f"Method '{method_name}' has AsyncIterator[None] as return type, which is not allowed. Streaming Empty messages is meaningless."
|
|
1604
|
+
)
|
|
1605
|
+
|
|
1606
|
+
# Handle NoneType for request and response
|
|
1607
|
+
if is_none_type(request_type):
|
|
1608
|
+
uses_empty = True
|
|
1609
|
+
if is_none_type(response_type):
|
|
1610
|
+
uses_empty = True
|
|
1611
|
+
|
|
857
1612
|
# Recursively generate message definitions
|
|
858
|
-
message_types = [
|
|
1613
|
+
message_types = []
|
|
1614
|
+
if not is_none_type(request_type):
|
|
1615
|
+
message_types.append(request_type)
|
|
1616
|
+
if not is_none_type(response_type):
|
|
1617
|
+
message_types.append(response_type)
|
|
1618
|
+
|
|
859
1619
|
while message_types:
|
|
860
|
-
mt = message_types.pop()
|
|
861
|
-
if mt in done_messages:
|
|
1620
|
+
mt: type[Message] | type[ServicerContext] | None = message_types.pop()
|
|
1621
|
+
if mt in done_messages or mt is ServicerContext or mt is None:
|
|
862
1622
|
continue
|
|
863
1623
|
done_messages.add(mt)
|
|
864
1624
|
|
|
865
1625
|
if is_stream_type(mt):
|
|
866
1626
|
item_type = get_args(mt)[0]
|
|
867
|
-
|
|
1627
|
+
if not is_none_type(item_type):
|
|
1628
|
+
message_types.append(item_type)
|
|
868
1629
|
continue
|
|
869
1630
|
|
|
1631
|
+
# Check if mt is actually a Message type
|
|
1632
|
+
if not (inspect.isclass(mt) and issubclass(mt, Message)):
|
|
1633
|
+
# Not a Message type, skip processing
|
|
1634
|
+
continue
|
|
1635
|
+
|
|
1636
|
+
mt = cast(type[Message], mt)
|
|
1637
|
+
|
|
870
1638
|
for _, field_info in mt.model_fields.items():
|
|
871
1639
|
t = field_info.annotation
|
|
872
1640
|
if is_union_type(t):
|
|
873
1641
|
for sub_t in flatten_union(t):
|
|
874
|
-
|
|
1642
|
+
check_and_set_well_known_types_for_fields(
|
|
1643
|
+
sub_t
|
|
1644
|
+
) # Use the field-specific version
|
|
875
1645
|
else:
|
|
876
|
-
|
|
1646
|
+
check_and_set_well_known_types_for_fields(
|
|
1647
|
+
t
|
|
1648
|
+
) # Use the field-specific version
|
|
877
1649
|
|
|
878
1650
|
msg_def, refs = generate_message_definition(mt, done_enums, done_messages)
|
|
879
|
-
mt_doc = inspect.getdoc(mt)
|
|
880
|
-
if mt_doc:
|
|
881
|
-
for comment_line in comment_out(mt_doc):
|
|
882
|
-
all_type_definitions.append(comment_line)
|
|
883
1651
|
|
|
884
|
-
|
|
885
|
-
|
|
1652
|
+
# Skip adding definition if it's an empty message (will use google.protobuf.Empty)
|
|
1653
|
+
if msg_def != "__EMPTY__":
|
|
1654
|
+
mt_doc = inspect.getdoc(mt)
|
|
1655
|
+
if mt_doc:
|
|
1656
|
+
for comment_line in comment_out(mt_doc):
|
|
1657
|
+
all_type_definitions.append(comment_line)
|
|
1658
|
+
|
|
1659
|
+
all_type_definitions.append(msg_def)
|
|
1660
|
+
all_type_definitions.append("")
|
|
1661
|
+
else:
|
|
1662
|
+
# Mark that we need google.protobuf.Empty import
|
|
1663
|
+
uses_empty = True
|
|
886
1664
|
|
|
887
1665
|
for r in refs:
|
|
888
1666
|
if is_enum_type(r) and r not in done_enums:
|
|
@@ -898,16 +1676,63 @@ def generate_proto(obj: object, package_name: str = "") -> str:
|
|
|
898
1676
|
for comment_line in comment_out(method_docstr):
|
|
899
1677
|
rpc_definitions.append(comment_line)
|
|
900
1678
|
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
1679
|
+
input_type = request_type
|
|
1680
|
+
input_is_stream = is_stream_type(input_type)
|
|
1681
|
+
output_is_stream = is_stream_type(response_type)
|
|
1682
|
+
|
|
1683
|
+
if input_is_stream:
|
|
1684
|
+
input_msg_type = get_args(input_type)[0]
|
|
1685
|
+
else:
|
|
1686
|
+
input_msg_type = input_type
|
|
1687
|
+
|
|
1688
|
+
if output_is_stream:
|
|
1689
|
+
output_msg_type = get_args(response_type)[0]
|
|
1690
|
+
else:
|
|
1691
|
+
output_msg_type = response_type
|
|
1692
|
+
|
|
1693
|
+
# Handle NoneType and empty messages by using Empty
|
|
1694
|
+
if input_msg_type is None or input_msg_type is ServicerContext:
|
|
1695
|
+
input_str = "google.protobuf.Empty" # No need to check for stream since we validated above
|
|
1696
|
+
if input_msg_type is ServicerContext:
|
|
1697
|
+
uses_empty = True
|
|
1698
|
+
elif (
|
|
1699
|
+
inspect.isclass(input_msg_type)
|
|
1700
|
+
and issubclass(input_msg_type, Message)
|
|
1701
|
+
and not input_msg_type.model_fields
|
|
1702
|
+
):
|
|
1703
|
+
# Empty message class
|
|
1704
|
+
input_str = "google.protobuf.Empty"
|
|
1705
|
+
uses_empty = True
|
|
1706
|
+
else:
|
|
1707
|
+
input_str = (
|
|
1708
|
+
f"stream {input_msg_type.__name__}"
|
|
1709
|
+
if input_is_stream
|
|
1710
|
+
else input_msg_type.__name__
|
|
905
1711
|
)
|
|
1712
|
+
|
|
1713
|
+
if output_msg_type is None or output_msg_type is ServicerContext:
|
|
1714
|
+
output_str = "google.protobuf.Empty" # No need to check for stream since we validated above
|
|
1715
|
+
if output_msg_type is ServicerContext:
|
|
1716
|
+
uses_empty = True
|
|
1717
|
+
elif (
|
|
1718
|
+
inspect.isclass(output_msg_type)
|
|
1719
|
+
and issubclass(output_msg_type, Message)
|
|
1720
|
+
and not output_msg_type.model_fields
|
|
1721
|
+
):
|
|
1722
|
+
# Empty message class
|
|
1723
|
+
output_str = "google.protobuf.Empty"
|
|
1724
|
+
uses_empty = True
|
|
906
1725
|
else:
|
|
907
|
-
|
|
908
|
-
f"
|
|
1726
|
+
output_str = (
|
|
1727
|
+
f"stream {output_msg_type.__name__}"
|
|
1728
|
+
if output_is_stream
|
|
1729
|
+
else output_msg_type.__name__
|
|
909
1730
|
)
|
|
910
1731
|
|
|
1732
|
+
rpc_definitions.append(
|
|
1733
|
+
f"rpc {method_name} ({input_str}) returns ({output_str});"
|
|
1734
|
+
)
|
|
1735
|
+
|
|
911
1736
|
if not package_name:
|
|
912
1737
|
if service_name.endswith("Service"):
|
|
913
1738
|
package_name = service_name[: -len("Service")]
|
|
@@ -915,11 +1740,13 @@ def generate_proto(obj: object, package_name: str = "") -> str:
|
|
|
915
1740
|
package_name = service_name
|
|
916
1741
|
package_name = package_name.lower() + ".v1"
|
|
917
1742
|
|
|
918
|
-
imports = []
|
|
1743
|
+
imports: list[str] = []
|
|
919
1744
|
if uses_timestamp:
|
|
920
1745
|
imports.append('import "google/protobuf/timestamp.proto";')
|
|
921
1746
|
if uses_duration:
|
|
922
1747
|
imports.append('import "google/protobuf/duration.proto";')
|
|
1748
|
+
if uses_empty:
|
|
1749
|
+
imports.append('import "google/protobuf/empty.proto";')
|
|
923
1750
|
|
|
924
1751
|
import_block = "\n".join(imports)
|
|
925
1752
|
if import_block:
|
|
@@ -939,106 +1766,183 @@ service {service_name} {{
|
|
|
939
1766
|
return proto_definition
|
|
940
1767
|
|
|
941
1768
|
|
|
942
|
-
def generate_grpc_code(
|
|
1769
|
+
def generate_grpc_code(proto_path: Path) -> types.ModuleType | None:
|
|
943
1770
|
"""
|
|
944
|
-
|
|
945
|
-
|
|
1771
|
+
Run protoc to generate Python gRPC code from proto_path.
|
|
1772
|
+
Writes foo_pb2_grpc.py next to proto_path, then imports and returns that module.
|
|
946
1773
|
"""
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
1774
|
+
# 1) Ensure the .proto exists
|
|
1775
|
+
if not proto_path.is_file():
|
|
1776
|
+
raise FileNotFoundError(f"{proto_path!r} does not exist")
|
|
1777
|
+
|
|
1778
|
+
# 2) Determine output directory (same as the .proto's parent)
|
|
1779
|
+
proto_path = proto_path.resolve()
|
|
1780
|
+
out_dir = proto_path.parent
|
|
1781
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
|
1782
|
+
|
|
1783
|
+
# 3) Build and run the protoc command
|
|
1784
|
+
out_str = str(out_dir)
|
|
1785
|
+
well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
|
|
1786
|
+
args = [
|
|
1787
|
+
"protoc", # Dummy program name (required for protoc.main)
|
|
1788
|
+
"-I.",
|
|
1789
|
+
f"-I{well_known_path}",
|
|
1790
|
+
f"--grpc_python_out={out_str}",
|
|
1791
|
+
proto_path.name,
|
|
1792
|
+
]
|
|
1793
|
+
|
|
1794
|
+
current_dir = os.getcwd()
|
|
1795
|
+
os.chdir(str(out_dir))
|
|
1796
|
+
try:
|
|
1797
|
+
if protoc.main(args) != 0:
|
|
1798
|
+
return None
|
|
1799
|
+
finally:
|
|
1800
|
+
os.chdir(current_dir)
|
|
951
1801
|
|
|
952
|
-
|
|
953
|
-
|
|
1802
|
+
# 4) Locate the generated gRPC file
|
|
1803
|
+
base_name = proto_path.stem # "foo"
|
|
1804
|
+
generated_filename = f"{base_name}_pb2_grpc.py" # "foo_pb2_grpc.py"
|
|
1805
|
+
generated_filepath = out_dir / generated_filename
|
|
954
1806
|
|
|
955
|
-
|
|
956
|
-
|
|
1807
|
+
# 5) Add out_dir to sys.path so we can import it
|
|
1808
|
+
if out_str not in sys.path:
|
|
1809
|
+
sys.path.append(out_str)
|
|
957
1810
|
|
|
1811
|
+
# 6) Load and return the module
|
|
958
1812
|
spec = importlib.util.spec_from_file_location(
|
|
959
|
-
|
|
1813
|
+
base_name + "_pb2_grpc", str(generated_filepath)
|
|
960
1814
|
)
|
|
961
|
-
if spec is None:
|
|
1815
|
+
if spec is None or spec.loader is None:
|
|
962
1816
|
return None
|
|
963
|
-
pb2_grpc_module = importlib.util.module_from_spec(spec)
|
|
964
|
-
if spec.loader is None:
|
|
965
|
-
return None
|
|
966
|
-
spec.loader.exec_module(pb2_grpc_module)
|
|
967
1817
|
|
|
968
|
-
|
|
1818
|
+
module = importlib.util.module_from_spec(spec)
|
|
1819
|
+
spec.loader.exec_module(module)
|
|
1820
|
+
return module
|
|
969
1821
|
|
|
970
1822
|
|
|
971
|
-
def generate_connecpy_code(
|
|
972
|
-
proto_file: str, connecpy_out: str
|
|
973
|
-
) -> types.ModuleType | None:
|
|
1823
|
+
def generate_connecpy_code(proto_path: Path) -> types.ModuleType | None:
|
|
974
1824
|
"""
|
|
975
|
-
|
|
976
|
-
|
|
1825
|
+
Run protoc with the Connecpy plugin to generate Python Connecpy code from proto_path.
|
|
1826
|
+
Writes foo_connecpy.py next to proto_path, then imports and returns that module.
|
|
977
1827
|
"""
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
1828
|
+
# 1) Ensure the .proto exists
|
|
1829
|
+
if not proto_path.is_file():
|
|
1830
|
+
raise FileNotFoundError(f"{proto_path!r} does not exist")
|
|
1831
|
+
|
|
1832
|
+
# 2) Determine output directory (same as the .proto's parent)
|
|
1833
|
+
proto_path = proto_path.resolve()
|
|
1834
|
+
out_dir = proto_path.parent
|
|
1835
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
|
1836
|
+
|
|
1837
|
+
# 3) Build and run the protoc command
|
|
1838
|
+
out_str = str(out_dir)
|
|
1839
|
+
well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
|
|
1840
|
+
args = [
|
|
1841
|
+
"protoc", # Dummy program name (required for protoc.main)
|
|
1842
|
+
"-I.",
|
|
1843
|
+
f"-I{well_known_path}",
|
|
1844
|
+
f"--connecpy_out={out_str}",
|
|
1845
|
+
proto_path.name,
|
|
1846
|
+
]
|
|
982
1847
|
|
|
983
|
-
|
|
984
|
-
|
|
1848
|
+
current_dir = os.getcwd()
|
|
1849
|
+
os.chdir(str(out_dir))
|
|
1850
|
+
try:
|
|
1851
|
+
if protoc.main(args) != 0:
|
|
1852
|
+
return None
|
|
1853
|
+
finally:
|
|
1854
|
+
os.chdir(current_dir)
|
|
985
1855
|
|
|
986
|
-
|
|
987
|
-
|
|
1856
|
+
# 4) Locate the generated file
|
|
1857
|
+
base_name = proto_path.stem # "foo"
|
|
1858
|
+
generated_filename = f"{base_name}_connecpy.py" # "foo_connecpy.py"
|
|
1859
|
+
generated_filepath = out_dir / generated_filename
|
|
988
1860
|
|
|
1861
|
+
# 5) Add out_dir to sys.path so we can import by filename
|
|
1862
|
+
if out_str not in sys.path:
|
|
1863
|
+
sys.path.append(out_str)
|
|
1864
|
+
|
|
1865
|
+
# 6) Load and return the module
|
|
989
1866
|
spec = importlib.util.spec_from_file_location(
|
|
990
|
-
|
|
1867
|
+
base_name + "_connecpy", str(generated_filepath)
|
|
991
1868
|
)
|
|
992
|
-
if spec is None:
|
|
993
|
-
return None
|
|
994
|
-
connecpy_module = importlib.util.module_from_spec(spec)
|
|
995
|
-
if spec.loader is None:
|
|
1869
|
+
if spec is None or spec.loader is None:
|
|
996
1870
|
return None
|
|
997
|
-
spec.loader.exec_module(connecpy_module)
|
|
998
1871
|
|
|
999
|
-
|
|
1872
|
+
module = importlib.util.module_from_spec(spec)
|
|
1873
|
+
spec.loader.exec_module(module)
|
|
1874
|
+
return module
|
|
1000
1875
|
|
|
1001
1876
|
|
|
1002
|
-
def generate_pb_code(
|
|
1003
|
-
proto_file: str, python_out: str, pyi_out: str
|
|
1004
|
-
) -> types.ModuleType | None:
|
|
1877
|
+
def generate_pb_code(proto_path: Path) -> types.ModuleType | None:
|
|
1005
1878
|
"""
|
|
1006
|
-
|
|
1007
|
-
|
|
1879
|
+
Run protoc to generate Python gRPC code from proto_path.
|
|
1880
|
+
Writes foo_pb2.py and foo_pb2.pyi next to proto_path, then imports and returns the pb2 module.
|
|
1008
1881
|
"""
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1882
|
+
# 1) Make sure proto_path exists
|
|
1883
|
+
if not proto_path.is_file():
|
|
1884
|
+
raise FileNotFoundError(f"{proto_path!r} does not exist")
|
|
1885
|
+
|
|
1886
|
+
# 2) Determine output directory (same as proto file)
|
|
1887
|
+
proto_path = proto_path.resolve()
|
|
1888
|
+
out_dir = proto_path.parent
|
|
1889
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
|
1890
|
+
|
|
1891
|
+
# 3) Build and run protoc command
|
|
1892
|
+
out_str = str(out_dir)
|
|
1893
|
+
well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
|
|
1894
|
+
args = [
|
|
1895
|
+
"protoc", # Dummy program name (required for protoc.main)
|
|
1896
|
+
"-I.",
|
|
1897
|
+
f"-I{well_known_path}",
|
|
1898
|
+
f"--python_out={out_str}",
|
|
1899
|
+
f"--pyi_out={out_str}",
|
|
1900
|
+
proto_path.name,
|
|
1901
|
+
]
|
|
1902
|
+
|
|
1903
|
+
current_dir = os.getcwd()
|
|
1904
|
+
os.chdir(str(out_dir))
|
|
1905
|
+
try:
|
|
1906
|
+
if protoc.main(args) != 0:
|
|
1907
|
+
return None
|
|
1908
|
+
finally:
|
|
1909
|
+
os.chdir(current_dir)
|
|
1013
1910
|
|
|
1014
|
-
|
|
1015
|
-
|
|
1911
|
+
# 4) Locate generated file
|
|
1912
|
+
base_name = proto_path.stem # e.g. "foo"
|
|
1913
|
+
generated_file = out_dir / f"{base_name}_pb2.py"
|
|
1016
1914
|
|
|
1017
|
-
|
|
1018
|
-
|
|
1915
|
+
# 5) Add to sys.path if needed
|
|
1916
|
+
if out_str not in sys.path:
|
|
1917
|
+
sys.path.append(out_str)
|
|
1019
1918
|
|
|
1919
|
+
# 6) Import it
|
|
1020
1920
|
spec = importlib.util.spec_from_file_location(
|
|
1021
|
-
|
|
1921
|
+
base_name + "_pb2", str(generated_file)
|
|
1022
1922
|
)
|
|
1023
|
-
if spec is None:
|
|
1024
|
-
return None
|
|
1025
|
-
pb2_module = importlib.util.module_from_spec(spec)
|
|
1026
|
-
if spec.loader is None:
|
|
1923
|
+
if spec is None or spec.loader is None:
|
|
1027
1924
|
return None
|
|
1028
|
-
|
|
1925
|
+
module = importlib.util.module_from_spec(spec)
|
|
1926
|
+
spec.loader.exec_module(module)
|
|
1927
|
+
return module
|
|
1029
1928
|
|
|
1030
|
-
return pb2_module
|
|
1031
1929
|
|
|
1930
|
+
def get_request_arg_type(sig: inspect.Signature) -> Any:
|
|
1931
|
+
"""Return the type annotation of the first parameter (request) of a method.
|
|
1032
1932
|
|
|
1033
|
-
|
|
1034
|
-
"""
|
|
1933
|
+
If the method has no parameters, return None (implying an empty request).
|
|
1934
|
+
"""
|
|
1035
1935
|
num_of_params = len(sig.parameters)
|
|
1036
|
-
if
|
|
1037
|
-
|
|
1038
|
-
|
|
1936
|
+
if num_of_params == 0:
|
|
1937
|
+
# No parameters means empty request
|
|
1938
|
+
return None
|
|
1939
|
+
elif num_of_params == 1 or num_of_params == 2:
|
|
1940
|
+
return tuple(sig.parameters.values())[0].annotation
|
|
1941
|
+
else:
|
|
1942
|
+
raise Exception("Method must have 0, 1, or 2 parameters")
|
|
1039
1943
|
|
|
1040
1944
|
|
|
1041
|
-
def get_rpc_methods(obj: object) -> list[tuple[str,
|
|
1945
|
+
def get_rpc_methods(obj: object) -> list[tuple[str, Callable[..., Any]]]:
|
|
1042
1946
|
"""
|
|
1043
1947
|
Retrieve the list of RPC methods from a service object.
|
|
1044
1948
|
The method name is converted to PascalCase for .proto compatibility.
|
|
@@ -1059,68 +1963,429 @@ def is_skip_generation() -> bool:
|
|
|
1059
1963
|
return os.getenv("PYDANTIC_RPC_SKIP_GENERATION", "false").lower() == "true"
|
|
1060
1964
|
|
|
1061
1965
|
|
|
1062
|
-
def
|
|
1966
|
+
def get_reserved_fields_count() -> int:
|
|
1967
|
+
"""Get the number of reserved fields to add to each message from environment variable."""
|
|
1968
|
+
try:
|
|
1969
|
+
return max(0, int(os.getenv("PYDANTIC_RPC_RESERVED_FIELDS", "0")))
|
|
1970
|
+
except ValueError:
|
|
1971
|
+
return 0
|
|
1972
|
+
|
|
1973
|
+
|
|
1974
|
+
def generate_and_compile_proto(
|
|
1975
|
+
obj: object,
|
|
1976
|
+
package_name: str = "",
|
|
1977
|
+
existing_proto_path: Path | None = None,
|
|
1978
|
+
) -> tuple[Any, Any]:
|
|
1063
1979
|
if is_skip_generation():
|
|
1064
1980
|
import importlib
|
|
1065
1981
|
|
|
1066
|
-
pb2_module =
|
|
1067
|
-
pb2_grpc_module =
|
|
1068
|
-
|
|
1069
|
-
|
|
1982
|
+
pb2_module = None
|
|
1983
|
+
pb2_grpc_module = None
|
|
1984
|
+
|
|
1985
|
+
try:
|
|
1986
|
+
pb2_module = importlib.import_module(
|
|
1987
|
+
f"{obj.__class__.__name__.lower()}_pb2"
|
|
1988
|
+
)
|
|
1989
|
+
except ImportError:
|
|
1990
|
+
pass
|
|
1991
|
+
|
|
1992
|
+
try:
|
|
1993
|
+
pb2_grpc_module = importlib.import_module(
|
|
1994
|
+
f"{obj.__class__.__name__.lower()}_pb2_grpc"
|
|
1995
|
+
)
|
|
1996
|
+
except ImportError:
|
|
1997
|
+
pass
|
|
1070
1998
|
|
|
1071
1999
|
if pb2_grpc_module is not None and pb2_module is not None:
|
|
1072
2000
|
return pb2_grpc_module, pb2_module
|
|
1073
2001
|
|
|
1074
2002
|
# If the modules are not found, generate and compile the proto files.
|
|
1075
2003
|
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
2004
|
+
if existing_proto_path:
|
|
2005
|
+
# Use the provided existing proto file (skip generation)
|
|
2006
|
+
proto_file_path = existing_proto_path
|
|
2007
|
+
else:
|
|
2008
|
+
# Generate as before
|
|
2009
|
+
klass = obj.__class__
|
|
2010
|
+
proto_file = generate_proto(obj, package_name)
|
|
2011
|
+
proto_file_name = klass.__name__.lower() + ".proto"
|
|
2012
|
+
proto_file_path = get_proto_path(proto_file_name)
|
|
1079
2013
|
|
|
1080
|
-
|
|
1081
|
-
|
|
2014
|
+
with proto_file_path.open(mode="w", encoding="utf-8") as f:
|
|
2015
|
+
_ = f.write(proto_file)
|
|
1082
2016
|
|
|
1083
|
-
gen_pb = generate_pb_code(
|
|
2017
|
+
gen_pb = generate_pb_code(proto_file_path)
|
|
1084
2018
|
if gen_pb is None:
|
|
1085
2019
|
raise Exception("Generating pb code")
|
|
1086
2020
|
|
|
1087
|
-
gen_grpc = generate_grpc_code(
|
|
2021
|
+
gen_grpc = generate_grpc_code(proto_file_path)
|
|
1088
2022
|
if gen_grpc is None:
|
|
1089
2023
|
raise Exception("Generating grpc code")
|
|
1090
2024
|
return gen_grpc, gen_pb
|
|
1091
2025
|
|
|
1092
2026
|
|
|
1093
|
-
def
|
|
2027
|
+
def get_proto_path(proto_filename: str) -> Path:
|
|
2028
|
+
# 1. Get raw env var (or default to cwd)
|
|
2029
|
+
raw = os.getenv("PYDANTIC_RPC_PROTO_PATH", None)
|
|
2030
|
+
base = Path(raw) if raw is not None else Path.cwd()
|
|
2031
|
+
|
|
2032
|
+
# 2. Expand ~ and env-vars, then make absolute
|
|
2033
|
+
base = Path(os.path.expandvars(os.path.expanduser(str(base)))).resolve()
|
|
2034
|
+
|
|
2035
|
+
# 3. Ensure it's a directory (or create it)
|
|
2036
|
+
if not base.exists():
|
|
2037
|
+
try:
|
|
2038
|
+
base.mkdir(parents=True, exist_ok=True)
|
|
2039
|
+
except OSError as e:
|
|
2040
|
+
raise RuntimeError(f"Unable to create directory {base!r}: {e}") from e
|
|
2041
|
+
elif not base.is_dir():
|
|
2042
|
+
raise NotADirectoryError(f"{base!r} exists but is not a directory")
|
|
2043
|
+
|
|
2044
|
+
# 4. Check writability
|
|
2045
|
+
if not os.access(base, os.W_OK):
|
|
2046
|
+
raise PermissionError(f"No write permission for directory {base!r}")
|
|
2047
|
+
|
|
2048
|
+
# 5. Return the final file path
|
|
2049
|
+
return base / proto_filename
|
|
2050
|
+
|
|
2051
|
+
|
|
2052
|
+
def generate_and_compile_proto_using_connecpy(
|
|
2053
|
+
obj: object,
|
|
2054
|
+
package_name: str = "",
|
|
2055
|
+
existing_proto_path: Path | None = None,
|
|
2056
|
+
) -> tuple[Any, Any]:
|
|
1094
2057
|
if is_skip_generation():
|
|
1095
2058
|
import importlib
|
|
1096
2059
|
|
|
1097
|
-
pb2_module =
|
|
1098
|
-
connecpy_module =
|
|
1099
|
-
|
|
1100
|
-
|
|
2060
|
+
pb2_module = None
|
|
2061
|
+
connecpy_module = None
|
|
2062
|
+
|
|
2063
|
+
try:
|
|
2064
|
+
pb2_module = importlib.import_module(
|
|
2065
|
+
f"{obj.__class__.__name__.lower()}_pb2"
|
|
2066
|
+
)
|
|
2067
|
+
except ImportError:
|
|
2068
|
+
pass
|
|
2069
|
+
|
|
2070
|
+
try:
|
|
2071
|
+
connecpy_module = importlib.import_module(
|
|
2072
|
+
f"{obj.__class__.__name__.lower()}_connecpy"
|
|
2073
|
+
)
|
|
2074
|
+
except ImportError:
|
|
2075
|
+
pass
|
|
1101
2076
|
|
|
1102
2077
|
if connecpy_module is not None and pb2_module is not None:
|
|
1103
2078
|
return connecpy_module, pb2_module
|
|
1104
2079
|
|
|
1105
2080
|
# If the modules are not found, generate and compile the proto files.
|
|
1106
2081
|
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
2082
|
+
if existing_proto_path:
|
|
2083
|
+
# Use the provided existing proto file (skip generation)
|
|
2084
|
+
proto_file_path = existing_proto_path
|
|
2085
|
+
else:
|
|
2086
|
+
# Generate as before
|
|
2087
|
+
klass = obj.__class__
|
|
2088
|
+
proto_file = generate_proto(obj, package_name)
|
|
2089
|
+
proto_file_name = klass.__name__.lower() + ".proto"
|
|
1110
2090
|
|
|
1111
|
-
|
|
1112
|
-
|
|
2091
|
+
proto_file_path = get_proto_path(proto_file_name)
|
|
2092
|
+
with proto_file_path.open(mode="w", encoding="utf-8") as f:
|
|
2093
|
+
_ = f.write(proto_file)
|
|
1113
2094
|
|
|
1114
|
-
gen_pb = generate_pb_code(
|
|
2095
|
+
gen_pb = generate_pb_code(proto_file_path)
|
|
1115
2096
|
if gen_pb is None:
|
|
1116
2097
|
raise Exception("Generating pb code")
|
|
1117
2098
|
|
|
1118
|
-
gen_connecpy = generate_connecpy_code(
|
|
2099
|
+
gen_connecpy = generate_connecpy_code(proto_file_path)
|
|
1119
2100
|
if gen_connecpy is None:
|
|
1120
2101
|
raise Exception("Generating Connecpy code")
|
|
1121
2102
|
return gen_connecpy, gen_pb
|
|
1122
2103
|
|
|
1123
2104
|
|
|
2105
|
+
def is_combined_proto_enabled() -> bool:
|
|
2106
|
+
"""Check if combined proto file generation is enabled."""
|
|
2107
|
+
return os.getenv("PYDANTIC_RPC_COMBINED_PROTO", "false").lower() == "true"
|
|
2108
|
+
|
|
2109
|
+
|
|
2110
|
+
def generate_combined_proto(
|
|
2111
|
+
*services: object, package_name: str = "combined.v1"
|
|
2112
|
+
) -> str:
|
|
2113
|
+
"""Generate a combined .proto definition from multiple service classes."""
|
|
2114
|
+
import datetime
|
|
2115
|
+
|
|
2116
|
+
all_type_definitions: list[str] = []
|
|
2117
|
+
all_service_definitions: list[str] = []
|
|
2118
|
+
done_messages: set[Any] = set()
|
|
2119
|
+
done_enums: set[Any] = set()
|
|
2120
|
+
|
|
2121
|
+
uses_timestamp = False
|
|
2122
|
+
uses_duration = False
|
|
2123
|
+
uses_empty = False
|
|
2124
|
+
|
|
2125
|
+
def check_and_set_well_known_types_for_fields(py_type: Any):
|
|
2126
|
+
"""Check well-known types for field annotations (excludes None/Empty)."""
|
|
2127
|
+
nonlocal uses_timestamp, uses_duration
|
|
2128
|
+
if py_type == datetime.datetime:
|
|
2129
|
+
uses_timestamp = True
|
|
2130
|
+
if py_type == datetime.timedelta:
|
|
2131
|
+
uses_duration = True
|
|
2132
|
+
|
|
2133
|
+
# Process each service
|
|
2134
|
+
for service_obj in services:
|
|
2135
|
+
service_class = service_obj.__class__
|
|
2136
|
+
service_name = service_class.__name__
|
|
2137
|
+
service_docstr = inspect.getdoc(service_class)
|
|
2138
|
+
service_comment = (
|
|
2139
|
+
"\n".join(comment_out(service_docstr)) if service_docstr else ""
|
|
2140
|
+
)
|
|
2141
|
+
|
|
2142
|
+
service_rpc_definitions: list[str] = []
|
|
2143
|
+
|
|
2144
|
+
for method_name, method in get_rpc_methods(service_obj):
|
|
2145
|
+
if method.__name__.startswith("_"):
|
|
2146
|
+
continue
|
|
2147
|
+
|
|
2148
|
+
method_sig = inspect.signature(method)
|
|
2149
|
+
request_type = get_request_arg_type(method_sig)
|
|
2150
|
+
response_type = method_sig.return_annotation
|
|
2151
|
+
|
|
2152
|
+
# Validate stream types
|
|
2153
|
+
if is_stream_type(request_type):
|
|
2154
|
+
stream_item_type = get_args(request_type)[0]
|
|
2155
|
+
if is_none_type(stream_item_type):
|
|
2156
|
+
raise TypeError(
|
|
2157
|
+
f"Method '{method_name}' has AsyncIterator[None] as input type, which is not allowed."
|
|
2158
|
+
)
|
|
2159
|
+
|
|
2160
|
+
if is_stream_type(response_type):
|
|
2161
|
+
stream_item_type = get_args(response_type)[0]
|
|
2162
|
+
if is_none_type(stream_item_type):
|
|
2163
|
+
raise TypeError(
|
|
2164
|
+
f"Method '{method_name}' has AsyncIterator[None] as return type, which is not allowed."
|
|
2165
|
+
)
|
|
2166
|
+
|
|
2167
|
+
# Handle NoneType for request and response
|
|
2168
|
+
if is_none_type(request_type):
|
|
2169
|
+
uses_empty = True
|
|
2170
|
+
if is_none_type(response_type):
|
|
2171
|
+
uses_empty = True
|
|
2172
|
+
|
|
2173
|
+
# Validate that users aren't using protobuf messages directly
|
|
2174
|
+
if hasattr(request_type, "DESCRIPTOR") and not is_none_type(request_type):
|
|
2175
|
+
raise TypeError(
|
|
2176
|
+
f"Method '{method_name}' uses protobuf message '{request_type.__name__}' directly. "
|
|
2177
|
+
f"Please use Pydantic Message classes instead, or None/empty Message for empty requests."
|
|
2178
|
+
)
|
|
2179
|
+
if hasattr(response_type, "DESCRIPTOR") and not is_none_type(response_type):
|
|
2180
|
+
raise TypeError(
|
|
2181
|
+
f"Method '{method_name}' uses protobuf message '{response_type.__name__}' directly. "
|
|
2182
|
+
f"Please use Pydantic Message classes instead, or None/empty Message for empty responses."
|
|
2183
|
+
)
|
|
2184
|
+
|
|
2185
|
+
# Collect message types for processing
|
|
2186
|
+
message_types = []
|
|
2187
|
+
if not is_none_type(request_type):
|
|
2188
|
+
message_types.append(request_type)
|
|
2189
|
+
if not is_none_type(response_type):
|
|
2190
|
+
message_types.append(response_type)
|
|
2191
|
+
|
|
2192
|
+
# Process message types
|
|
2193
|
+
while message_types:
|
|
2194
|
+
mt: type[Message] | type[ServicerContext] | None = message_types.pop()
|
|
2195
|
+
if mt in done_messages or mt is ServicerContext or mt is None:
|
|
2196
|
+
continue
|
|
2197
|
+
done_messages.add(mt)
|
|
2198
|
+
|
|
2199
|
+
if is_stream_type(mt):
|
|
2200
|
+
item_type = get_args(mt)[0]
|
|
2201
|
+
if not is_none_type(item_type):
|
|
2202
|
+
message_types.append(item_type)
|
|
2203
|
+
continue
|
|
2204
|
+
|
|
2205
|
+
mt = cast(type[Message], mt)
|
|
2206
|
+
|
|
2207
|
+
for _, field_info in mt.model_fields.items():
|
|
2208
|
+
t = field_info.annotation
|
|
2209
|
+
if is_union_type(t):
|
|
2210
|
+
for sub_t in flatten_union(t):
|
|
2211
|
+
check_and_set_well_known_types_for_fields(sub_t)
|
|
2212
|
+
else:
|
|
2213
|
+
check_and_set_well_known_types_for_fields(t)
|
|
2214
|
+
|
|
2215
|
+
msg_def, refs = generate_message_definition(
|
|
2216
|
+
mt, done_enums, done_messages
|
|
2217
|
+
)
|
|
2218
|
+
|
|
2219
|
+
# Skip adding definition if it's an empty message (will use google.protobuf.Empty)
|
|
2220
|
+
if msg_def != "__EMPTY__":
|
|
2221
|
+
mt_doc = inspect.getdoc(mt)
|
|
2222
|
+
if mt_doc:
|
|
2223
|
+
for comment_line in comment_out(mt_doc):
|
|
2224
|
+
all_type_definitions.append(comment_line)
|
|
2225
|
+
|
|
2226
|
+
all_type_definitions.append(msg_def)
|
|
2227
|
+
all_type_definitions.append("")
|
|
2228
|
+
else:
|
|
2229
|
+
# Mark that we need google.protobuf.Empty import
|
|
2230
|
+
uses_empty = True
|
|
2231
|
+
|
|
2232
|
+
for r in refs:
|
|
2233
|
+
if is_enum_type(r) and r not in done_enums:
|
|
2234
|
+
done_enums.add(r)
|
|
2235
|
+
enum_def = generate_enum_definition(r)
|
|
2236
|
+
all_type_definitions.append(enum_def)
|
|
2237
|
+
all_type_definitions.append("")
|
|
2238
|
+
elif issubclass(r, Message) and r not in done_messages:
|
|
2239
|
+
message_types.append(r)
|
|
2240
|
+
|
|
2241
|
+
# Generate RPC definition
|
|
2242
|
+
method_docstr = inspect.getdoc(method)
|
|
2243
|
+
if method_docstr:
|
|
2244
|
+
for comment_line in comment_out(method_docstr):
|
|
2245
|
+
service_rpc_definitions.append(comment_line)
|
|
2246
|
+
|
|
2247
|
+
input_type = request_type
|
|
2248
|
+
input_is_stream = is_stream_type(input_type)
|
|
2249
|
+
output_is_stream = is_stream_type(response_type)
|
|
2250
|
+
|
|
2251
|
+
if input_is_stream:
|
|
2252
|
+
input_msg_type = get_args(input_type)[0]
|
|
2253
|
+
else:
|
|
2254
|
+
input_msg_type = input_type
|
|
2255
|
+
|
|
2256
|
+
if output_is_stream:
|
|
2257
|
+
output_msg_type = get_args(response_type)[0]
|
|
2258
|
+
else:
|
|
2259
|
+
output_msg_type = response_type
|
|
2260
|
+
|
|
2261
|
+
# Handle NoneType and empty messages by using Empty
|
|
2262
|
+
if input_msg_type is None or input_msg_type is ServicerContext:
|
|
2263
|
+
input_str = "google.protobuf.Empty" # No need to check for stream since we validated above
|
|
2264
|
+
if input_msg_type is ServicerContext:
|
|
2265
|
+
uses_empty = True
|
|
2266
|
+
elif (
|
|
2267
|
+
inspect.isclass(input_msg_type)
|
|
2268
|
+
and issubclass(input_msg_type, Message)
|
|
2269
|
+
and not input_msg_type.model_fields
|
|
2270
|
+
):
|
|
2271
|
+
# Empty message class
|
|
2272
|
+
input_str = "google.protobuf.Empty"
|
|
2273
|
+
uses_empty = True
|
|
2274
|
+
else:
|
|
2275
|
+
input_str = (
|
|
2276
|
+
f"stream {input_msg_type.__name__}"
|
|
2277
|
+
if input_is_stream
|
|
2278
|
+
else input_msg_type.__name__
|
|
2279
|
+
)
|
|
2280
|
+
|
|
2281
|
+
if output_msg_type is None or output_msg_type is ServicerContext:
|
|
2282
|
+
output_str = "google.protobuf.Empty" # No need to check for stream since we validated above
|
|
2283
|
+
if output_msg_type is ServicerContext:
|
|
2284
|
+
uses_empty = True
|
|
2285
|
+
elif (
|
|
2286
|
+
inspect.isclass(output_msg_type)
|
|
2287
|
+
and issubclass(output_msg_type, Message)
|
|
2288
|
+
and not output_msg_type.model_fields
|
|
2289
|
+
):
|
|
2290
|
+
# Empty message class
|
|
2291
|
+
output_str = "google.protobuf.Empty"
|
|
2292
|
+
uses_empty = True
|
|
2293
|
+
else:
|
|
2294
|
+
output_str = (
|
|
2295
|
+
f"stream {output_msg_type.__name__}"
|
|
2296
|
+
if output_is_stream
|
|
2297
|
+
else output_msg_type.__name__
|
|
2298
|
+
)
|
|
2299
|
+
|
|
2300
|
+
service_rpc_definitions.append(
|
|
2301
|
+
f"rpc {method_name} ({input_str}) returns ({output_str});"
|
|
2302
|
+
)
|
|
2303
|
+
|
|
2304
|
+
# Create service definition
|
|
2305
|
+
service_def_lines: list[str] = []
|
|
2306
|
+
if service_comment:
|
|
2307
|
+
service_def_lines.append(service_comment)
|
|
2308
|
+
service_def_lines.append(f"service {service_name} {{")
|
|
2309
|
+
service_def_lines.extend([f" {line}" for line in service_rpc_definitions])
|
|
2310
|
+
service_def_lines.append("}")
|
|
2311
|
+
service_def_lines.append("")
|
|
2312
|
+
|
|
2313
|
+
all_service_definitions.extend(service_def_lines)
|
|
2314
|
+
|
|
2315
|
+
# Build imports
|
|
2316
|
+
imports: list[str] = []
|
|
2317
|
+
if uses_timestamp:
|
|
2318
|
+
imports.append('import "google/protobuf/timestamp.proto";')
|
|
2319
|
+
if uses_duration:
|
|
2320
|
+
imports.append('import "google/protobuf/duration.proto";')
|
|
2321
|
+
if uses_empty:
|
|
2322
|
+
imports.append('import "google/protobuf/empty.proto";')
|
|
2323
|
+
|
|
2324
|
+
import_block = "\n".join(imports)
|
|
2325
|
+
if import_block:
|
|
2326
|
+
import_block += "\n"
|
|
2327
|
+
|
|
2328
|
+
# Combine everything
|
|
2329
|
+
service_defs = "\n".join(all_service_definitions)
|
|
2330
|
+
type_defs = "\n".join(all_type_definitions)
|
|
2331
|
+
proto_definition = f"""syntax = "proto3";
|
|
2332
|
+
|
|
2333
|
+
package {package_name};
|
|
2334
|
+
|
|
2335
|
+
{import_block}{service_defs}
|
|
2336
|
+
{type_defs}
|
|
2337
|
+
"""
|
|
2338
|
+
return proto_definition
|
|
2339
|
+
|
|
2340
|
+
|
|
2341
|
+
def get_combined_proto_filename() -> str:
    """Return the filename for the combined proto file.

    Honors the PYDANTIC_RPC_COMBINED_PROTO_FILENAME environment variable,
    falling back to "combined_services.proto".
    """
    configured = os.environ.get("PYDANTIC_RPC_COMBINED_PROTO_FILENAME")
    return configured if configured is not None else "combined_services.proto"
|
|
2344
|
+
|
|
2345
|
+
|
|
2346
|
+
def generate_combined_descriptor_set(
    *services: object, output_path: Path | None = None
) -> bytes:
    """Generate a combined protobuf descriptor set from multiple services.

    Writes the combined .proto file into the proto directory, invokes protoc
    with --descriptor_set_out/--include_imports, and returns the serialized
    FileDescriptorSet bytes.

    Args:
        *services: Service objects whose RPCs/messages go into the set.
        output_path: Destination for the descriptor set; defaults to the
            proto directory using the combined proto filename.

    Returns:
        The raw descriptor-set bytes read back from *output_path*.

    Raises:
        RuntimeError: If protoc exits with a non-zero status.
    """
    filename = get_combined_proto_filename()

    if output_path is None:
        output_path = get_proto_path(filename)
    # Resolve to an absolute path up front: protoc runs after a chdir into the
    # proto directory while the final read happens after chdir'ing back, so a
    # caller-supplied *relative* path would otherwise name two different files.
    output_path = Path(output_path).resolve()

    # Generate combined proto file.
    combined_proto = generate_combined_proto(*services)
    proto_file_path = get_proto_path(filename)

    with proto_file_path.open(mode="w", encoding="utf-8") as f:
        _ = f.write(combined_proto)

    # Generate descriptor set using protoc; include the bundled well-known types.
    out_str = str(proto_file_path.parent)
    well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
    args = [
        "protoc",
        f"-I{out_str}",
        f"-I{well_known_path}",
        f"--descriptor_set_out={output_path}",
        "--include_imports",
        proto_file_path.name,
    ]

    # Run protoc from the proto directory so the bare proto filename resolves.
    current_dir = os.getcwd()
    os.chdir(out_str)
    try:
        if protoc.main(args) != 0:
            raise RuntimeError("Failed to generate combined descriptor set")
    finally:
        os.chdir(current_dir)

    # Read and return the descriptor set.
    with open(output_path, "rb") as f:
        descriptor_data = f.read()

    return descriptor_data
|
|
2387
|
+
|
|
2388
|
+
|
|
1124
2389
|
###############################################################################
|
|
1125
2390
|
# 4. Server Implementations
|
|
1126
2391
|
###############################################################################
|
|
@@ -1129,13 +2394,13 @@ def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = "
|
|
|
1129
2394
|
class Server:
|
|
1130
2395
|
"""A simple gRPC server that uses ThreadPoolExecutor for concurrency."""
|
|
1131
2396
|
|
|
1132
|
-
def __init__(self, max_workers: int = 8, *interceptors) -> None:
|
|
1133
|
-
self._server = grpc.server(
|
|
2397
|
+
def __init__(self, max_workers: int = 8, *interceptors: Any) -> None:
|
|
2398
|
+
self._server: grpc.Server = grpc.server(
|
|
1134
2399
|
futures.ThreadPoolExecutor(max_workers), interceptors=interceptors
|
|
1135
2400
|
)
|
|
1136
|
-
self._service_names = []
|
|
1137
|
-
self._package_name = ""
|
|
1138
|
-
self._port = 50051
|
|
2401
|
+
self._service_names: list[str] = []
|
|
2402
|
+
self._package_name: str = ""
|
|
2403
|
+
self._port: int = 50051
|
|
1139
2404
|
|
|
1140
2405
|
def set_package_name(self, package_name: str):
|
|
1141
2406
|
"""Set the package name for .proto generation."""
|
|
@@ -1147,10 +2412,15 @@ class Server:
|
|
|
1147
2412
|
|
|
1148
2413
|
def mount(self, obj: object, package_name: str = ""):
    """Generate and compile proto files, then mount the service implementation.

    Raises:
        RuntimeError: If proto generation/compilation produced no modules.
    """
    modules = generate_and_compile_proto(obj, package_name)
    if modules is None:
        # Fail fast with a clear error instead of forwarding (None, None) and
        # crashing later with an opaque AttributeError inside mounting.
        raise RuntimeError(
            f"Proto generation failed for service {obj.__class__.__name__!r}"
        )
    pb2_grpc_module, pb2_module = modules
    self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
|
|
1152
2420
|
|
|
1153
|
-
def mount_using_pb2_modules(
|
|
2421
|
+
def mount_using_pb2_modules(
|
|
2422
|
+
self, pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
2423
|
+
):
|
|
1154
2424
|
"""Connect the compiled gRPC modules with the service implementation."""
|
|
1155
2425
|
concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
|
|
1156
2426
|
service_name = obj.__class__.__name__
|
|
@@ -1163,7 +2433,7 @@ class Server:
|
|
|
1163
2433
|
].full_name
|
|
1164
2434
|
self._service_names.append(full_service_name)
|
|
1165
2435
|
|
|
1166
|
-
def run(self, *objs):
|
|
2436
|
+
def run(self, *objs: object):
|
|
1167
2437
|
"""
|
|
1168
2438
|
Mount multiple services and run the gRPC server with reflection and health check.
|
|
1169
2439
|
Press Ctrl+C or send SIGTERM to stop.
|
|
@@ -1183,14 +2453,16 @@ class Server:
|
|
|
1183
2453
|
self._server.add_insecure_port(f"[::]:{self._port}")
|
|
1184
2454
|
self._server.start()
|
|
1185
2455
|
|
|
1186
|
-
def handle_signal(signal, frame):
|
|
2456
|
+
def handle_signal(signum: signal.Signals, frame: Any):
|
|
2457
|
+
_ = signum
|
|
2458
|
+
_ = frame
|
|
1187
2459
|
print("Received shutdown signal...")
|
|
1188
2460
|
self._server.stop(grace=10)
|
|
1189
2461
|
print("gRPC server shutdown.")
|
|
1190
2462
|
sys.exit(0)
|
|
1191
2463
|
|
|
1192
|
-
signal.signal(signal.SIGINT, handle_signal)
|
|
1193
|
-
signal.signal(signal.SIGTERM, handle_signal)
|
|
2464
|
+
_ = signal.signal(signal.SIGINT, handle_signal) # pyright:ignore[reportArgumentType]
|
|
2465
|
+
_ = signal.signal(signal.SIGTERM, handle_signal) # pyright:ignore[reportArgumentType]
|
|
1194
2466
|
|
|
1195
2467
|
print("gRPC server is running...")
|
|
1196
2468
|
while True:
|
|
@@ -1200,11 +2472,11 @@ class Server:
|
|
|
1200
2472
|
class AsyncIOServer:
|
|
1201
2473
|
"""An async gRPC server using asyncio."""
|
|
1202
2474
|
|
|
1203
|
-
def __init__(self, *interceptors) -> None:
|
|
1204
|
-
self._server = grpc.aio.server(interceptors=interceptors)
|
|
1205
|
-
self._service_names = []
|
|
1206
|
-
self._package_name = ""
|
|
1207
|
-
self._port = 50051
|
|
2475
|
+
def __init__(self, *interceptors: grpc.ServerInterceptor) -> None:
|
|
2476
|
+
self._server: grpc.aio.Server = grpc.aio.server(interceptors=interceptors)
|
|
2477
|
+
self._service_names: list[str] = []
|
|
2478
|
+
self._package_name: str = ""
|
|
2479
|
+
self._port: int = 50051
|
|
1208
2480
|
|
|
1209
2481
|
def set_package_name(self, package_name: str):
|
|
1210
2482
|
"""Set the package name for .proto generation."""
|
|
@@ -1216,10 +2488,15 @@ class AsyncIOServer:
|
|
|
1216
2488
|
|
|
1217
2489
|
def mount(self, obj: object, package_name: str = ""):
    """Generate and compile proto files, then mount the service implementation (async).

    Raises:
        RuntimeError: If proto generation/compilation produced no modules.
    """
    modules = generate_and_compile_proto(obj, package_name)
    if modules is None:
        # Surface generation failure immediately rather than passing None
        # modules downstream and failing with an unrelated AttributeError.
        raise RuntimeError(
            f"Proto generation failed for service {obj.__class__.__name__!r}"
        )
    pb2_grpc_module, pb2_module = modules
    self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
|
|
1221
2496
|
|
|
1222
|
-
def mount_using_pb2_modules(
|
|
2497
|
+
def mount_using_pb2_modules(
|
|
2498
|
+
self, pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
2499
|
+
):
|
|
1223
2500
|
"""Connect the compiled gRPC modules with the async service implementation."""
|
|
1224
2501
|
concreteServiceClass = connect_obj_with_stub_async(
|
|
1225
2502
|
pb2_grpc_module, pb2_module, obj
|
|
@@ -1234,7 +2511,7 @@ class AsyncIOServer:
|
|
|
1234
2511
|
].full_name
|
|
1235
2512
|
self._service_names.append(full_service_name)
|
|
1236
2513
|
|
|
1237
|
-
async def run(self, *objs):
|
|
2514
|
+
async def run(self, *objs: object):
|
|
1238
2515
|
"""
|
|
1239
2516
|
Mount multiple async services and run the gRPC server with reflection and health check.
|
|
1240
2517
|
Press Ctrl+C or send SIGTERM to stop.
|
|
@@ -1251,20 +2528,22 @@ class AsyncIOServer:
|
|
|
1251
2528
|
health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
|
|
1252
2529
|
reflection.enable_server_reflection(SERVICE_NAMES, self._server)
|
|
1253
2530
|
|
|
1254
|
-
self._server.add_insecure_port(f"[::]:{self._port}")
|
|
2531
|
+
_ = self._server.add_insecure_port(f"[::]:{self._port}")
|
|
1255
2532
|
await self._server.start()
|
|
1256
2533
|
|
|
1257
2534
|
shutdown_event = asyncio.Event()
|
|
1258
2535
|
|
|
1259
|
-
def shutdown(signum, frame):
|
|
2536
|
+
def shutdown(signum: signal.Signals, frame: Any):
|
|
2537
|
+
_ = signum
|
|
2538
|
+
_ = frame
|
|
1260
2539
|
print("Received shutdown signal...")
|
|
1261
2540
|
shutdown_event.set()
|
|
1262
2541
|
|
|
1263
2542
|
for s in [signal.SIGTERM, signal.SIGINT]:
|
|
1264
|
-
signal.signal(s, shutdown)
|
|
2543
|
+
_ = signal.signal(s, shutdown) # pyright:ignore[reportArgumentType]
|
|
1265
2544
|
|
|
1266
2545
|
print("gRPC server is running...")
|
|
1267
|
-
await shutdown_event.wait()
|
|
2546
|
+
_ = await shutdown_event.wait()
|
|
1268
2547
|
await self._server.stop(10)
|
|
1269
2548
|
print("gRPC server shutdown.")
|
|
1270
2549
|
|
|
@@ -1275,17 +2554,22 @@ class WSGIApp:
|
|
|
1275
2554
|
Useful for embedding gRPC within an existing WSGI stack.
|
|
1276
2555
|
"""
|
|
1277
2556
|
|
|
1278
|
-
def __init__(self, app):
|
|
1279
|
-
self._app = grpcWSGI(app)
|
|
1280
|
-
self._service_names = []
|
|
1281
|
-
self._package_name = ""
|
|
2557
|
+
def __init__(self, app: Any):
|
|
2558
|
+
self._app: grpcWSGI = grpcWSGI(app)
|
|
2559
|
+
self._service_names: list[str] = []
|
|
2560
|
+
self._package_name: str = ""
|
|
1282
2561
|
|
|
1283
2562
|
def mount(self, obj: object, package_name: str = ""):
|
|
1284
2563
|
"""Generate and compile proto files, then mount the service implementation."""
|
|
1285
|
-
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
|
|
2564
|
+
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
|
|
2565
|
+
None,
|
|
2566
|
+
None,
|
|
2567
|
+
)
|
|
1286
2568
|
self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
|
|
1287
2569
|
|
|
1288
|
-
def mount_using_pb2_modules(
|
|
2570
|
+
def mount_using_pb2_modules(
|
|
2571
|
+
self, pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
2572
|
+
):
|
|
1289
2573
|
"""Connect the compiled gRPC modules with the service implementation."""
|
|
1290
2574
|
concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
|
|
1291
2575
|
service_name = obj.__class__.__name__
|
|
@@ -1298,12 +2582,16 @@ class WSGIApp:
|
|
|
1298
2582
|
].full_name
|
|
1299
2583
|
self._service_names.append(full_service_name)
|
|
1300
2584
|
|
|
1301
|
-
def mount_objs(self, *objs):
|
|
2585
|
+
def mount_objs(self, *objs: object):
|
|
1302
2586
|
"""Mount multiple service objects into this WSGI app."""
|
|
1303
2587
|
for obj in objs:
|
|
1304
2588
|
self.mount(obj, self._package_name)
|
|
1305
2589
|
|
|
1306
|
-
def __call__(
    self,
    environ: dict[str, Any],
    start_response: Callable[[str, list[tuple[str, str]]], None],
) -> Any:
    """WSGI entry point: hand the request straight to the wrapped grpcWSGI app."""
    wsgi_app = self._app
    return wsgi_app(environ, start_response)
|
|
1309
2597
|
|
|
@@ -1314,17 +2602,22 @@ class ASGIApp:
|
|
|
1314
2602
|
Useful for embedding gRPC within an existing ASGI stack.
|
|
1315
2603
|
"""
|
|
1316
2604
|
|
|
1317
|
-
def __init__(self, app):
|
|
1318
|
-
self._app = grpcASGI(app)
|
|
1319
|
-
self._service_names = []
|
|
1320
|
-
self._package_name = ""
|
|
2605
|
+
def __init__(self, app: Any):
|
|
2606
|
+
self._app: grpcASGI = grpcASGI(app)
|
|
2607
|
+
self._service_names: list[str] = []
|
|
2608
|
+
self._package_name: str = ""
|
|
1321
2609
|
|
|
1322
2610
|
def mount(self, obj: object, package_name: str = ""):
|
|
1323
2611
|
"""Generate and compile proto files, then mount the async service implementation."""
|
|
1324
|
-
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
|
|
2612
|
+
pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
|
|
2613
|
+
None,
|
|
2614
|
+
None,
|
|
2615
|
+
)
|
|
1325
2616
|
self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
|
|
1326
2617
|
|
|
1327
|
-
def mount_using_pb2_modules(
|
|
2618
|
+
def mount_using_pb2_modules(
|
|
2619
|
+
self, pb2_grpc_module: Any, pb2_module: Any, obj: object
|
|
2620
|
+
):
|
|
1328
2621
|
"""Connect the compiled gRPC modules with the async service implementation."""
|
|
1329
2622
|
concreteServiceClass = connect_obj_with_stub_async(
|
|
1330
2623
|
pb2_grpc_module, pb2_module, obj
|
|
@@ -1339,18 +2632,27 @@ class ASGIApp:
|
|
|
1339
2632
|
].full_name
|
|
1340
2633
|
self._service_names.append(full_service_name)
|
|
1341
2634
|
|
|
1342
|
-
def mount_objs(self, *objs):
|
|
2635
|
+
def mount_objs(self, *objs: object):
|
|
1343
2636
|
"""Mount multiple service objects into this ASGI app."""
|
|
1344
2637
|
for obj in objs:
|
|
1345
2638
|
self.mount(obj, self._package_name)
|
|
1346
2639
|
|
|
1347
|
-
async def __call__(
    self,
    scope: dict[str, Any],
    receive: Callable[[], Any],
    send: Callable[[dict[str, Any]], Any],
) -> Any:
    """ASGI entry point: await the wrapped grpcASGI app for this event."""
    asgi_app = self._app
    _ = await asgi_app(scope, receive, send)
|
|
1350
2648
|
|
|
1351
2649
|
|
|
1352
|
-
def
|
|
1353
|
-
return getattr(connecpy_module, f"{service_name}
|
|
2650
|
+
def get_connecpy_asgi_app_class(connecpy_module: Any, service_name: str):
    """Look up the generated ``<Service>ASGIApplication`` class on *connecpy_module*."""
    attr_name = service_name + "ASGIApplication"
    return getattr(connecpy_module, attr_name)
|
|
2652
|
+
|
|
2653
|
+
|
|
2654
|
+
def get_connecpy_wsgi_app_class(connecpy_module: Any, service_name: str):
    """Look up the generated ``<Service>WSGIApplication`` class on *connecpy_module*."""
    attr_name = service_name + "WSGIApplication"
    return getattr(connecpy_module, attr_name)
|
|
1354
2656
|
|
|
1355
2657
|
|
|
1356
2658
|
class ConnecpyASGIApp:
|
|
@@ -1359,9 +2661,9 @@ class ConnecpyASGIApp:
|
|
|
1359
2661
|
"""
|
|
1360
2662
|
|
|
1361
2663
|
def __init__(self):
|
|
1362
|
-
self.
|
|
1363
|
-
self._service_names = []
|
|
1364
|
-
self._package_name = ""
|
|
2664
|
+
self._services: list[tuple[Any, str]] = [] # List of (app, path) tuples
|
|
2665
|
+
self._service_names: list[str] = []
|
|
2666
|
+
self._package_name: str = ""
|
|
1365
2667
|
|
|
1366
2668
|
def mount(self, obj: object, package_name: str = ""):
|
|
1367
2669
|
"""Generate and compile proto files, then mount the async service implementation."""
|
|
@@ -1370,28 +2672,59 @@ class ConnecpyASGIApp:
|
|
|
1370
2672
|
)
|
|
1371
2673
|
self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
|
|
1372
2674
|
|
|
1373
|
-
def mount_using_pb2_modules(
|
|
2675
|
+
def mount_using_pb2_modules(
|
|
2676
|
+
self, connecpy_module: Any, pb2_module: Any, obj: object
|
|
2677
|
+
):
|
|
1374
2678
|
"""Connect the compiled connecpy and pb2 modules with the async service implementation."""
|
|
1375
2679
|
concreteServiceClass = connect_obj_with_stub_async_connecpy(
|
|
1376
2680
|
connecpy_module, pb2_module, obj
|
|
1377
2681
|
)
|
|
1378
2682
|
service_name = obj.__class__.__name__
|
|
1379
2683
|
service_impl = concreteServiceClass()
|
|
1380
|
-
|
|
1381
|
-
|
|
2684
|
+
|
|
2685
|
+
# Get the service-specific ASGI application class
|
|
2686
|
+
app_class = get_connecpy_asgi_app_class(connecpy_module, service_name)
|
|
2687
|
+
app = app_class(service=service_impl)
|
|
2688
|
+
|
|
2689
|
+
# Store the app and its path for routing
|
|
2690
|
+
self._services.append((app, app.path))
|
|
2691
|
+
|
|
1382
2692
|
full_service_name = pb2_module.DESCRIPTOR.services_by_name[
|
|
1383
2693
|
service_name
|
|
1384
2694
|
].full_name
|
|
1385
2695
|
self._service_names.append(full_service_name)
|
|
1386
2696
|
|
|
1387
|
-
def mount_objs(self, *objs):
|
|
2697
|
+
def mount_objs(self, *objs: object):
|
|
1388
2698
|
"""Mount multiple service objects into this ASGI app."""
|
|
1389
2699
|
for obj in objs:
|
|
1390
2700
|
self.mount(obj, self._package_name)
|
|
1391
2701
|
|
|
1392
|
-
async def __call__(
    self,
    scope: dict[str, Any],
    receive: Callable[[], Any],
    send: Callable[[dict[str, Any]], Any],
):
    """ASGI entry point: route each HTTP request to the mounted service whose
    path prefix matches; fall back to the sole service when exactly one is
    mounted, otherwise answer 404."""

    async def not_found() -> None:
        # Minimal ASGI 404 response.
        await send({"type": "http.response.start", "status": 404})
        await send({"type": "http.response.body", "body": b"Not Found"})

    if scope["type"] != "http":
        await not_found()
        return

    request_path = scope.get("path", "")

    # Pick the first service whose mount path prefixes the request path.
    matched = next(
        (app for app, prefix in self._services if request_path.startswith(prefix)),
        None,
    )
    # With a single mounted service, use it as the default route.
    if matched is None and len(self._services) == 1:
        matched = self._services[0][0]

    if matched is not None:
        return await matched(scope, receive, send)

    # No matching service found.
    await not_found()
|
|
1395
2728
|
|
|
1396
2729
|
|
|
1397
2730
|
class ConnecpyWSGIApp:
|
|
@@ -1400,55 +2733,82 @@ class ConnecpyWSGIApp:
|
|
|
1400
2733
|
"""
|
|
1401
2734
|
|
|
1402
2735
|
def __init__(self):
|
|
1403
|
-
self.
|
|
1404
|
-
self._service_names = []
|
|
1405
|
-
self._package_name = ""
|
|
2736
|
+
self._services: list[tuple[Any, str]] = [] # List of (app, path) tuples
|
|
2737
|
+
self._service_names: list[str] = []
|
|
2738
|
+
self._package_name: str = ""
|
|
1406
2739
|
|
|
1407
2740
|
def mount(self, obj: object, package_name: str = ""):
|
|
1408
|
-
"""Generate and compile proto files, then mount the
|
|
2741
|
+
"""Generate and compile proto files, then mount the sync service implementation."""
|
|
1409
2742
|
connecpy_module, pb2_module = generate_and_compile_proto_using_connecpy(
|
|
1410
2743
|
obj, package_name
|
|
1411
2744
|
)
|
|
1412
2745
|
self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
|
|
1413
2746
|
|
|
1414
|
-
def mount_using_pb2_modules(
|
|
1415
|
-
|
|
2747
|
+
def mount_using_pb2_modules(
|
|
2748
|
+
self, connecpy_module: Any, pb2_module: Any, obj: object
|
|
2749
|
+
):
|
|
2750
|
+
"""Connect the compiled connecpy and pb2 modules with the sync service implementation."""
|
|
1416
2751
|
concreteServiceClass = connect_obj_with_stub_connecpy(
|
|
1417
2752
|
connecpy_module, pb2_module, obj
|
|
1418
2753
|
)
|
|
1419
2754
|
service_name = obj.__class__.__name__
|
|
1420
2755
|
service_impl = concreteServiceClass()
|
|
1421
|
-
|
|
1422
|
-
|
|
2756
|
+
|
|
2757
|
+
# Get the service-specific WSGI application class
|
|
2758
|
+
app_class = get_connecpy_wsgi_app_class(connecpy_module, service_name)
|
|
2759
|
+
app = app_class(service=service_impl)
|
|
2760
|
+
|
|
2761
|
+
# Store the app and its path for routing
|
|
2762
|
+
self._services.append((app, app.path))
|
|
2763
|
+
|
|
1423
2764
|
full_service_name = pb2_module.DESCRIPTOR.services_by_name[
|
|
1424
2765
|
service_name
|
|
1425
2766
|
].full_name
|
|
1426
2767
|
self._service_names.append(full_service_name)
|
|
1427
2768
|
|
|
1428
|
-
def mount_objs(self, *objs):
|
|
2769
|
+
def mount_objs(self, *objs: object):
|
|
1429
2770
|
"""Mount multiple service objects into this WSGI app."""
|
|
1430
2771
|
for obj in objs:
|
|
1431
2772
|
self.mount(obj, self._package_name)
|
|
1432
2773
|
|
|
1433
|
-
def __call__(
    self,
    environ: dict[str, Any],
    start_response: Callable[[str, list[tuple[str, str]]], None],
) -> Any:
    """WSGI entry point: dispatch by path prefix, defaulting to the only
    mounted service, otherwise reply 404."""
    request_path = environ.get("PATH_INFO", "")

    # Pick the first service whose mount path prefixes the request path.
    target = next(
        (app for app, prefix in self._services if request_path.startswith(prefix)),
        None,
    )
    # With a single mounted service, use it as the default route.
    if target is None and len(self._services) == 1:
        target = self._services[0][0]

    if target is not None:
        return target(environ, start_response)

    # No matching service found.
    start_response("404 Not Found", [("Content-Type", "text/plain")])
    return [b"Not Found"]
|
|
1436
2794
|
|
|
1437
2795
|
|
|
1438
2796
|
def main():
    """CLI entry point: generate and compile proto files for a service class.

    Usage: <prog> py_file class_name
    """
    import argparse

    parser = argparse.ArgumentParser(description="Generate and compile proto files.")
    _ = parser.add_argument(
        "py_file", type=str, help="The Python file containing the service class."
    )
    _ = parser.add_argument(
        "class_name", type=str, help="The name of the service class."
    )
    args = parser.parse_args()

    # The bare module name only resolves if the service file's directory is on
    # sys.path; add it so the CLI works regardless of the invocation cwd.
    module_dir = os.path.dirname(os.path.abspath(args.py_file))
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)

    module_name = os.path.splitext(basename(args.py_file))[0]
    module = importlib.import_module(module_name)
    klass = getattr(module, args.class_name)
    _ = generate_and_compile_proto(klass())
|
|
1452
2812
|
|
|
1453
2813
|
|
|
1454
2814
|
if __name__ == "__main__":
|