pydantic-rpc 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_rpc/core.py +1343 -1308
- {pydantic_rpc-0.3.1.dist-info → pydantic_rpc-0.4.0.dist-info}/METADATA +29 -38
- pydantic_rpc-0.4.0.dist-info/RECORD +7 -0
- pydantic_rpc-0.3.1.dist-info/RECORD +0 -7
- {pydantic_rpc-0.3.1.dist-info → pydantic_rpc-0.4.0.dist-info}/WHEEL +0 -0
- {pydantic_rpc-0.3.1.dist-info → pydantic_rpc-0.4.0.dist-info}/licenses/LICENSE +0 -0
pydantic_rpc/core.py
CHANGED
|
@@ -1,1308 +1,1343 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
3
|
-
import
|
|
4
|
-
import
|
|
5
|
-
import
|
|
6
|
-
import
|
|
7
|
-
import
|
|
8
|
-
import
|
|
9
|
-
import
|
|
10
|
-
import
|
|
11
|
-
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
)
|
|
22
|
-
from collections.abc import AsyncIterator
|
|
23
|
-
|
|
24
|
-
import grpc
|
|
25
|
-
from grpc_health.v1 import health_pb2, health_pb2_grpc
|
|
26
|
-
from grpc_health.v1.health import HealthServicer
|
|
27
|
-
from grpc_reflection.v1alpha import reflection
|
|
28
|
-
from grpc_tools import protoc
|
|
29
|
-
from pydantic import BaseModel, ValidationError
|
|
30
|
-
from sonora.wsgi import grpcWSGI
|
|
31
|
-
from sonora.asgi import grpcASGI
|
|
32
|
-
from connecpy.asgi import ConnecpyASGIApp as ConnecpyASGI
|
|
33
|
-
from connecpy.errors import Errors
|
|
34
|
-
|
|
35
|
-
# Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
|
|
36
|
-
from google.protobuf import timestamp_pb2, duration_pb2
|
|
37
|
-
|
|
38
|
-
###############################################################################
|
|
39
|
-
# 1. Message definitions & converter extensions
|
|
40
|
-
# (datetime.datetime <-> google.protobuf.Timestamp)
|
|
41
|
-
# (datetime.timedelta <-> google.protobuf.Duration)
|
|
42
|
-
###############################################################################
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
# Alias used throughout this module for request/response models: any pydantic
# BaseModel subclass is treated as a protobuf "message".
Message: TypeAlias = BaseModel
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def primitiveProtoValueToPythonValue(value):
    """Identity converter: hand *value* back unchanged.

    Used for scalar protobuf fields (which need no translation) and as the
    fallback converter for annotations that have no dedicated conversion.
    """
    unchanged = value
    return unchanged
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def timestamp_to_python(ts: timestamp_pb2.Timestamp) -> datetime.datetime:  # type: ignore
    """Translate a ``google.protobuf.Timestamp`` into a ``datetime.datetime``.

    Thin wrapper over ``Timestamp.ToDatetime()`` called without a ``tzinfo``
    argument, so the result follows protobuf's default for that call (per the
    protobuf docs, a naive datetime expressed in UTC).
    """
    as_datetime = ts.ToDatetime()
    return as_datetime
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def python_to_timestamp(dt: datetime.datetime) -> timestamp_pb2.Timestamp:  # type: ignore
    """Build a new ``google.protobuf.Timestamp`` holding the instant *dt*.

    A fresh message is created and populated through ``Timestamp.FromDatetime``.
    """
    message = timestamp_pb2.Timestamp()  # type: ignore
    message.FromDatetime(dt)
    return message
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
def duration_to_python(d: duration_pb2.Duration) -> datetime.timedelta:  # type: ignore
    """Translate a ``google.protobuf.Duration`` into a ``datetime.timedelta``.

    Thin wrapper over ``Duration.ToTimedelta()``.
    """
    as_timedelta = d.ToTimedelta()
    return as_timedelta
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def python_to_duration(td: datetime.timedelta) -> duration_pb2.Duration:  # type: ignore
    """Build a new ``google.protobuf.Duration`` spanning the interval *td*.

    A fresh message is created and populated through ``Duration.FromTimedelta``.
    """
    message = duration_pb2.Duration()  # type: ignore
    message.FromTimedelta(td)
    return message
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def generate_converter(annotation: Type) -> Callable:
    """Return a callable that turns a protobuf field value into a Python value.

    The converter is selected from *annotation* (a pydantic field annotation)
    and is used when decoding incoming requests. Checks run in this order:

    * ``int``/``str``/``bool``/``bytes``/``float`` -> value passed through
    * ``enum.Enum`` subclass -> ``annotation(value)``
    * ``datetime.datetime`` -> ``value.ToDatetime()``
    * ``datetime.timedelta`` -> ``value.ToTimedelta()``
    * ``list[T]`` / ``tuple[T, ...]`` -> a list with every item converted as
      ``T`` (only the first type argument is consulted)
    * ``dict[K, V]`` -> a dict with keys converted as ``K``, values as ``V``
    * ``Message`` subclass -> ``generate_message_converter(annotation)``
    * anything else (unions included) -> value passed through
    """
    if annotation in (int, str, bool, bytes, float):
        return primitiveProtoValueToPythonValue

    if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
        enum_cls = annotation
        return lambda value: enum_cls(value)

    if annotation == datetime.datetime:
        return lambda value: value.ToDatetime()

    if annotation == datetime.timedelta:
        return lambda value: value.ToTimedelta()

    # Non-generic annotations have origin None, which matches neither check.
    origin = get_origin(annotation)

    if origin in (list, tuple):
        type_args = get_args(annotation)
        convert_item = generate_converter(type_args[0])
        return lambda value: [convert_item(item) for item in value]

    if origin is dict:
        type_args = get_args(annotation)
        convert_key = generate_converter(type_args[0])
        convert_value = generate_converter(type_args[1])
        return lambda value: {
            convert_key(key): convert_value(val) for key, val in value.items()
        }

    if inspect.isclass(annotation) and issubclass(annotation, Message):
        return generate_message_converter(annotation)

    return primitiveProtoValueToPythonValue
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
def generate_message_converter(arg_type: Type[Message]) -> Callable:
    """Build a converter that turns a protobuf request into an ``arg_type``.

    One field converter per pydantic field is prepared up front with
    ``generate_converter``. The returned callable reads each field from the
    protobuf message by attribute name, converts it, and instantiates
    ``arg_type`` with the results (so pydantic validation runs there).

    Raises:
        TypeError: if ``arg_type`` is ``None`` or not a ``Message`` subclass.
    """
    if arg_type is None or not issubclass(arg_type, Message):
        raise TypeError("Request arg must be subclass of Message")

    field_converters = {
        name: generate_converter(info.annotation)  # type: ignore
        for name, info in arg_type.model_fields.items()
    }

    def converter(request):
        values = {
            name: convert(getattr(request, name))
            for name, convert in field_converters.items()
        }
        return arg_type(**values)

    return converter
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
def python_value_to_proto_value(field_type: Type, value):
    """Convert one Python value to its protobuf form based on *field_type*.

    Only ``datetime.datetime`` (-> Timestamp) and ``datetime.timedelta``
    (-> Duration) annotations are translated; any other value is returned as-is.
    """
    if field_type == datetime.datetime:
        converted = python_to_timestamp(value)
    elif field_type == datetime.timedelta:
        converted = python_to_duration(value)
    else:
        converted = value
    return converted
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
###############################################################################
|
|
177
|
-
# 2. Stub implementation
|
|
178
|
-
###############################################################################
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
def connect_obj_with_stub(pb2_grpc_module, pb2_module, service_obj: object) -> type:
    """Build a synchronous gRPC servicer class that forwards calls to *service_obj*.

    The generated base class ``<ClassName>Servicer`` is fetched from
    *pb2_grpc_module* and subclassed. For every RPC method yielded by
    ``get_rpc_methods(service_obj)`` whose ``__name__`` does not start with an
    underscore, a stub is installed on the subclass under the same name. The
    stub:

    * converts the protobuf request into the method's pydantic argument type,
    * calls ``method(arg)`` or ``method(arg, context)`` depending on whether
      the method declares one or two parameters,
    * converts the returned pydantic message back to protobuf via *pb2_module*.

    A pydantic ``ValidationError`` aborts the RPC with INVALID_ARGUMENT; any
    other exception aborts with INTERNAL. ``str(exc)`` is used as the details.

    Raises:
        Exception: when a service method has neither one nor two parameters.
    """
    base_class = getattr(pb2_grpc_module, service_obj.__class__.__name__ + "Servicer")

    class ConcreteServiceClass(base_class):
        pass

    def implement_stub_method(method):
        signature = inspect.signature(method)
        to_python = generate_message_converter(get_request_arg_type(signature))
        response_type = signature.return_annotation
        n_params = len(signature.parameters)

        # Normalize both supported call shapes to call_service(arg, context).
        if n_params == 1:

            def call_service(arg, context):
                return method(arg)

        elif n_params == 2:

            def call_service(arg, context):
                return method(arg, context)

        else:
            raise Exception("Method must have exactly one or two parameters")

        def stub_method(self, request, context):
            try:
                result = call_service(to_python(request), context)
                return convert_python_message_to_proto(
                    result, response_type, pb2_module
                )
            except ValidationError as e:
                return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
            except Exception as e:
                return context.abort(grpc.StatusCode.INTERNAL, str(e))

        return stub_method

    for method_name, method in get_rpc_methods(service_obj):
        if method.__name__.startswith("_"):
            continue
        setattr(ConcreteServiceClass, method_name, implement_stub_method(method))

    return ConcreteServiceClass
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> type:
|
|
252
|
-
"""
|
|
253
|
-
Connect a Python service object to a gRPC stub for async methods.
|
|
254
|
-
"""
|
|
255
|
-
service_class = obj.__class__
|
|
256
|
-
stub_class_name = service_class.__name__ + "Servicer"
|
|
257
|
-
stub_class = getattr(pb2_grpc_module, stub_class_name)
|
|
258
|
-
|
|
259
|
-
class ConcreteServiceClass(stub_class):
|
|
260
|
-
pass
|
|
261
|
-
|
|
262
|
-
def implement_stub_method(method):
|
|
263
|
-
sig = inspect.signature(method)
|
|
264
|
-
arg_type = get_request_arg_type(sig)
|
|
265
|
-
converter = generate_message_converter(arg_type)
|
|
266
|
-
response_type = sig.return_annotation
|
|
267
|
-
size_of_parameters = len(sig.parameters)
|
|
268
|
-
|
|
269
|
-
if is_stream_type(response_type):
|
|
270
|
-
item_type = get_args(response_type)[0]
|
|
271
|
-
match size_of_parameters:
|
|
272
|
-
case 1:
|
|
273
|
-
|
|
274
|
-
async def stub_method_stream1(
|
|
275
|
-
self, request, context, method=method
|
|
276
|
-
):
|
|
277
|
-
try:
|
|
278
|
-
arg = converter(request)
|
|
279
|
-
async for resp_obj in method(arg):
|
|
280
|
-
yield convert_python_message_to_proto(
|
|
281
|
-
resp_obj, item_type, pb2_module
|
|
282
|
-
)
|
|
283
|
-
except ValidationError as e:
|
|
284
|
-
await context.abort(
|
|
285
|
-
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
286
|
-
)
|
|
287
|
-
except Exception as e:
|
|
288
|
-
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
289
|
-
|
|
290
|
-
return stub_method_stream1
|
|
291
|
-
case 2:
|
|
292
|
-
|
|
293
|
-
async def stub_method_stream2(
|
|
294
|
-
self, request, context, method=method
|
|
295
|
-
):
|
|
296
|
-
try:
|
|
297
|
-
arg = converter(request)
|
|
298
|
-
async for resp_obj in method(arg, context):
|
|
299
|
-
yield convert_python_message_to_proto(
|
|
300
|
-
resp_obj, item_type, pb2_module
|
|
301
|
-
)
|
|
302
|
-
except ValidationError as e:
|
|
303
|
-
await context.abort(
|
|
304
|
-
grpc.StatusCode.INVALID_ARGUMENT, str(e)
|
|
305
|
-
)
|
|
306
|
-
except Exception as e:
|
|
307
|
-
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
308
|
-
|
|
309
|
-
return stub_method_stream2
|
|
310
|
-
case _:
|
|
311
|
-
raise Exception("Method must have exactly one or two parameters")
|
|
312
|
-
|
|
313
|
-
match size_of_parameters:
|
|
314
|
-
case 1:
|
|
315
|
-
|
|
316
|
-
async def stub_method1(self, request, context, method=method):
|
|
317
|
-
try:
|
|
318
|
-
arg = converter(request)
|
|
319
|
-
resp_obj = await method(arg)
|
|
320
|
-
return convert_python_message_to_proto(
|
|
321
|
-
resp_obj, response_type, pb2_module
|
|
322
|
-
)
|
|
323
|
-
except ValidationError as e:
|
|
324
|
-
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
|
325
|
-
except Exception as e:
|
|
326
|
-
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
327
|
-
|
|
328
|
-
return stub_method1
|
|
329
|
-
|
|
330
|
-
case 2:
|
|
331
|
-
|
|
332
|
-
async def stub_method2(self, request, context, method=method):
|
|
333
|
-
try:
|
|
334
|
-
arg = converter(request)
|
|
335
|
-
resp_obj = await method(arg, context)
|
|
336
|
-
return convert_python_message_to_proto(
|
|
337
|
-
resp_obj, response_type, pb2_module
|
|
338
|
-
)
|
|
339
|
-
except ValidationError as e:
|
|
340
|
-
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
|
341
|
-
except Exception as e:
|
|
342
|
-
await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
|
343
|
-
|
|
344
|
-
return stub_method2
|
|
345
|
-
|
|
346
|
-
case _:
|
|
347
|
-
raise Exception("Method must have exactly one or two parameters")
|
|
348
|
-
|
|
349
|
-
for method_name, method in get_rpc_methods(obj):
|
|
350
|
-
if method.__name__.startswith("_"):
|
|
351
|
-
continue
|
|
352
|
-
|
|
353
|
-
a_method = implement_stub_method(method)
|
|
354
|
-
setattr(ConcreteServiceClass, method_name, a_method)
|
|
355
|
-
|
|
356
|
-
return ConcreteServiceClass
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
def connect_obj_with_stub_async_connecpy(
    connecpy_module, pb2_module, obj: object
) -> type:
    """Build a Connecpy service class whose async RPC methods delegate to *obj*.

    The base class named exactly like ``obj``'s class is looked up in
    *connecpy_module* and subclassed. Every RPC method from
    ``get_rpc_methods(obj)`` whose ``__name__`` does not start with an
    underscore must be a coroutine function. It is wrapped in an async stub
    that converts the protobuf request to the method's pydantic argument type,
    awaits ``method(arg)`` or ``method(arg, context)`` (by parameter count),
    and converts the result back to protobuf via *pb2_module*.

    ``ValidationError`` aborts with ``Errors.InvalidArgument``; any other
    exception aborts with ``Errors.Internal``.

    Raises:
        Exception: if a method is not async, or declares neither one nor two
            parameters.
    """
    base_service = getattr(connecpy_module, obj.__class__.__name__)

    class ConcreteServiceClass(base_service):
        pass

    def implement_stub_method(method):
        signature = inspect.signature(method)
        to_python = generate_message_converter(get_request_arg_type(signature))
        response_type = signature.return_annotation
        n_params = len(signature.parameters)

        # Normalize both supported call shapes to call_service(arg, context).
        if n_params == 1:

            def call_service(arg, context):
                return method(arg)

        elif n_params == 2:

            def call_service(arg, context):
                return method(arg, context)

        else:
            raise Exception("Method must have exactly one or two parameters")

        async def stub_method(self, request, context):
            try:
                result = await call_service(to_python(request), context)
                return convert_python_message_to_proto(
                    result, response_type, pb2_module
                )
            except ValidationError as e:
                await context.abort(Errors.InvalidArgument, str(e))
            except Exception as e:
                await context.abort(Errors.Internal, str(e))

        return stub_method

    for method_name, method in get_rpc_methods(obj):
        if method.__name__.startswith("_"):
            continue
        if not asyncio.iscoroutinefunction(method):
            raise Exception("Method must be async", method_name)
        setattr(ConcreteServiceClass, method_name, implement_stub_method(method))

    return ConcreteServiceClass
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
def convert_python_message_to_proto(
    py_msg: Message, msg_type: Type, pb2_module
) -> object:
    """Convert a pydantic Message instance into the matching protobuf message.

    Each field declared on ``msg_type`` is read from ``py_msg`` and converted
    with ``python_value_to_proto``. The protobuf class with the same name as
    ``msg_type`` is then fetched from ``pb2_module`` and instantiated with the
    converted values.

    Union-typed fields are generated as ``oneof`` blocks whose member fields
    are named ``<field>_<proto type>`` (see ``generate_oneof_definition``). The
    protobuf constructor does not accept the oneof name itself as a keyword,
    so the value is stored under the member field matching its runtime type.

    Fields whose value is ``None`` are omitted, which leaves them unset.

    Raises:
        TypeError: if a union field's value matches none of its member types.
    """
    field_dict = {}
    for name, field_info in msg_type.model_fields.items():
        value = getattr(py_msg, name)
        if value is None:
            # Omitting a keyword leaves the protobuf field unset; passing None
            # under a oneof name would be rejected by the constructor.
            continue

        field_type = field_info.annotation
        converted = python_value_to_proto(field_type, value, pb2_module)

        if is_union_type(field_type):
            member_name = _oneof_member_name(name, field_type, value)
            if member_name is None:
                raise TypeError(
                    f"Value {value!r} of field '{name}' matches no member of {field_type}"
                )
            field_dict[member_name] = converted
        else:
            field_dict[name] = converted

    # Retrieve the appropriate protobuf class dynamically
    proto_class = getattr(pb2_module, msg_type.__name__)
    return proto_class(**field_dict)


def _oneof_member_name(field_name: str, union_type: Type, value) -> str | None:
    """Return the ``oneof`` member field that *value* belongs in, or ``None``.

    The member type is chosen with the same predicates, in the same order, as
    the union branch of ``python_value_to_proto``. The member field name is
    built the way ``generate_oneof_definition`` names it.
    """
    for sub_type in flatten_union(union_type):
        is_enum_member_type = is_enum_type(sub_type)
        is_message_type = inspect.isclass(sub_type) and issubclass(sub_type, Message)
        matches = (
            (sub_type == datetime.datetime and isinstance(value, datetime.datetime))
            or (sub_type == datetime.timedelta and isinstance(value, datetime.timedelta))
            or (is_enum_member_type and isinstance(value, enum.Enum))
            or (
                sub_type in (int, float, str, bool, bytes)
                and isinstance(value, sub_type)
            )
            or (is_message_type and isinstance(value, Message))
        )
        if not matches:
            continue
        if is_enum_member_type or is_message_type:
            type_name = sub_type.__name__
        else:
            type_name = protobuf_type_mapping(sub_type)
        return f"{field_name}_{type_name.replace('.', '_')}"
    return None
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
def python_value_to_proto(field_type: Type, value, pb2_module):
    """Convert one Python field value into the form a protobuf constructor accepts.

    Conversion is driven by the field annotation *field_type*:

    * ``datetime.datetime`` -> ``google.protobuf.Timestamp``
    * ``datetime.timedelta`` -> ``google.protobuf.Duration``
    * ``enum.Enum`` subclass -> the member's ``.value`` (proto3 enums are
      ints). A value that is not an enum member (e.g. a model configured with
      pydantic's ``use_enum_values=True``) is passed through unchanged.
    * ``list[T]`` / ``tuple[T, ...]`` -> list of items converted as ``T``
    * ``dict[K, V]`` -> dict with keys/values converted as ``K`` / ``V``
    * union -> converted according to the first member type the runtime value
      matches, or ``None`` if none matches
    * ``Message`` subclass -> nested protobuf message built from *pb2_module*
    * anything else -> returned unchanged
    """
    if field_type == datetime.datetime:
        return python_to_timestamp(value)

    if field_type == datetime.timedelta:
        return python_to_duration(value)

    if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
        # Only unwrap real members; raw values are already proto-ready.
        return value.value if isinstance(value, enum.Enum) else value

    origin = get_origin(field_type)
    if origin in (list, tuple):
        inner_type = get_args(field_type)[0]  # type: ignore
        return [python_value_to_proto(inner_type, v, pb2_module) for v in value]

    if origin is dict:
        key_type, val_type = get_args(field_type)  # type: ignore
        return {
            python_value_to_proto(key_type, k, pb2_module): python_value_to_proto(
                val_type, v, pb2_module
            )
            for k, v in value.items()
        }

    if is_union_type(field_type):
        # Flatten the union and convert according to the first matching member.
        for sub_type in flatten_union(field_type):
            if sub_type == datetime.datetime and isinstance(value, datetime.datetime):
                return python_to_timestamp(value)
            if sub_type == datetime.timedelta and isinstance(value, datetime.timedelta):
                return python_to_duration(value)
            if (
                inspect.isclass(sub_type)
                and issubclass(sub_type, enum.Enum)
                and isinstance(value, enum.Enum)
            ):
                return value.value
            if sub_type in (int, float, str, bool, bytes) and isinstance(
                value, sub_type
            ):
                return value
            if (
                inspect.isclass(sub_type)
                and issubclass(sub_type, Message)
                and isinstance(value, Message)
            ):
                return convert_python_message_to_proto(value, sub_type, pb2_module)
        return None

    if inspect.isclass(field_type) and issubclass(field_type, Message):
        return convert_python_message_to_proto(value, field_type, pb2_module)

    # Primitive or otherwise unhandled: protobuf accepts it as-is.
    return value
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
###############################################################################
|
|
520
|
-
# 3. Generating proto files (datetime->Timestamp, timedelta->Duration)
|
|
521
|
-
###############################################################################
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
def is_enum_type(python_type: Type) -> bool:
    """Tell whether *python_type* is an ``enum.Enum`` subclass.

    Only classes qualify: enum members and non-class objects yield ``False``.
    """
    if not inspect.isclass(python_type):
        return False
    return issubclass(python_type, enum.Enum)
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
def is_union_type(python_type: Type) -> bool:
|
|
530
|
-
"""
|
|
531
|
-
Check if a given Python type is a Union type (including Python 3.10's UnionType).
|
|
532
|
-
"""
|
|
533
|
-
if get_origin(python_type) is Union:
|
|
534
|
-
return True
|
|
535
|
-
if sys.version_info >= (3, 10):
|
|
536
|
-
import types
|
|
537
|
-
|
|
538
|
-
if type(python_type) is types.UnionType:
|
|
539
|
-
return True
|
|
540
|
-
return False
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
def flatten_union(field_type: Type) -> list[Type]:
    """Expand a possibly nested union into a flat list of its member types.

    A non-union type yields a one-element list. Member order is preserved and
    duplicates are not removed.
    """
    if not is_union_type(field_type):
        return [field_type]
    return [
        member
        for arg in get_args(field_type)
        for member in flatten_union(arg)
    ]
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
def protobuf_type_mapping(python_type: Type) -> str | type | None:
    """Map a Python annotation to the protobuf type used in a generated .proto.

    Returns:
        * a string for:
          - scalars (``int`` -> ``int32``, ``str`` -> ``string``, ``bool`` ->
            ``bool``, ``bytes`` -> ``bytes``, ``float`` -> ``float``);
          - ``datetime.datetime`` / ``datetime.timedelta``
            (``google.protobuf.Timestamp`` / ``google.protobuf.Duration``);
          - containers (``repeated T`` for list/tuple, ``map<K, V>`` for dict).
            Enum and ``Message`` element types appear by class name.
        * the class itself for a bare enum or ``Message`` type (the caller
          substitutes its name and emits its definition);
        * ``None`` for unions (emitted separately as ``oneof``) and for
          unsupported types.
    """
    mapping = {
        int: "int32",
        str: "string",
        bool: "bool",
        bytes: "bytes",
        float: "float",
    }

    def as_type_name(mapped):
        # Nested enum/Message types come back as classes; inside a container
        # string they must be spelled by name, not by their repr.
        return mapped.__name__ if isinstance(mapped, type) else mapped

    if python_type == datetime.datetime:
        return "google.protobuf.Timestamp"

    if python_type == datetime.timedelta:
        return "google.protobuf.Duration"

    if is_enum_type(python_type):
        return python_type  # Will be defined as enum later

    if is_union_type(python_type):
        return None  # Handled separately as oneof

    if hasattr(python_type, "__origin__"):
        if python_type.__origin__ in (list, tuple):
            inner_proto_type = as_type_name(
                protobuf_type_mapping(python_type.__args__[0])
            )
            if inner_proto_type:
                return f"repeated {inner_proto_type}"
        elif python_type.__origin__ is dict:
            key_proto_type = as_type_name(
                protobuf_type_mapping(python_type.__args__[0])
            )
            value_proto_type = as_type_name(
                protobuf_type_mapping(python_type.__args__[1])
            )
            if key_proto_type and value_proto_type:
                return f"map<{key_proto_type}, {value_proto_type}>"

    if inspect.isclass(python_type) and issubclass(python_type, Message):
        return python_type

    return mapping.get(python_type)  # type: ignore
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
def comment_out(docstr: str) -> tuple[str, ...]:
|
|
602
|
-
"""Convert docstrings into commented-out lines in a .proto file."""
|
|
603
|
-
if docstr
|
|
604
|
-
return tuple()
|
|
605
|
-
|
|
606
|
-
if docstr.startswith("Usage docs: https://docs.pydantic.dev/2.10/concepts/models/"):
|
|
607
|
-
return tuple()
|
|
608
|
-
|
|
609
|
-
return tuple(f"//" if line == "" else f"// {line}" for line in docstr.split("\n"))
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
def indent_lines(lines, indentation=" "):
    """Prefix every entry of *lines* with *indentation*; join with newlines."""
    prefixed = [indentation + line for line in lines]
    return "\n".join(prefixed)
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
    """Render *enum_type* as a proto3 ``enum`` block.

    proto3 requires the first declared value of every enum to be zero. The
    member whose value is 0 is therefore emitted first. If no member has value
    0, a ``<ENUMNAME>_UNSPECIFIED = 0`` placeholder is prepended so protoc
    accepts the definition. Remaining members keep their declaration order.

    Enum aliases (names bound to an already-defined value) are skipped.
    Emitting them would repeat the canonical member's name and value, which
    protoc rejects as a duplicate definition.
    """
    enum_name = enum_type.__name__
    # Iterating the enum class yields canonical members only (no aliases).
    # A stable sort moves the zero-valued member, if any, to the front.
    ordered = sorted(enum_type, key=lambda member: member.value != 0)

    members = []
    if not any(member.value == 0 for member in ordered):
        members.append(f" {enum_name.upper()}_UNSPECIFIED = 0;")
    for member in ordered:
        members.append(f" {member.name} = {member.value};")

    enum_def = f"enum {enum_name} {{\n"
    enum_def += "\n".join(members)
    enum_def += "\n}"
    return enum_def
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
def generate_oneof_definition(
    field_name: str, union_args: list[Type], start_index: int
) -> tuple[list[str], int]:
    """Render a protobuf ``oneof`` block for a union-typed field.

    Each member type becomes one entry named ``<field_name>_<proto type>``,
    with dots in well-known type names replaced by underscores. Entries are
    numbered consecutively from *start_index*. Enum and ``Message`` members
    are referenced by class name.

    Returns:
        The block's lines (header, one line per member, closing brace) and the
        next unused field number.

    Raises:
        Exception: if a member type has no protobuf mapping.
    """
    body = [f"oneof {field_name} {{"]
    next_index = start_index

    for member_type in union_args:
        mapped = protobuf_type_mapping(member_type)
        if mapped is None:
            raise Exception(f"Nested Union not flattened properly: {member_type}")

        named_type = is_enum_type(member_type) or (
            inspect.isclass(member_type) and issubclass(member_type, Message)
        )
        type_name = member_type.__name__ if named_type else mapped

        alias = f"{field_name}_{type_name.replace('.', '_')}"
        body.append(f" {type_name} {alias} = {next_index};")
        next_index += 1

    body.append("}")
    return body, next_index
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
def generate_message_definition(
|
|
658
|
-
message_type: Type[Message],
|
|
659
|
-
done_enums: set,
|
|
660
|
-
done_messages: set,
|
|
661
|
-
) -> tuple[str, list[type]]:
|
|
662
|
-
"""
|
|
663
|
-
Generate a protobuf message definition for a Pydantic-based Message class.
|
|
664
|
-
Also returns any referenced types (enums, messages) that need to be defined.
|
|
665
|
-
"""
|
|
666
|
-
fields = []
|
|
667
|
-
refs = []
|
|
668
|
-
pydantic_fields = message_type.model_fields
|
|
669
|
-
index = 1
|
|
670
|
-
|
|
671
|
-
for field_name, field_info in pydantic_fields.items():
|
|
672
|
-
field_type = field_info.annotation
|
|
673
|
-
if field_type is None:
|
|
674
|
-
raise Exception(f"Field {field_name} has no type annotation.")
|
|
675
|
-
|
|
676
|
-
if is_union_type(field_type):
|
|
677
|
-
union_args = flatten_union(field_type)
|
|
678
|
-
oneof_lines, new_index = generate_oneof_definition(
|
|
679
|
-
field_name, union_args, index
|
|
680
|
-
)
|
|
681
|
-
fields.extend(oneof_lines)
|
|
682
|
-
index = new_index
|
|
683
|
-
|
|
684
|
-
for utype in union_args:
|
|
685
|
-
if is_enum_type(utype) and utype not in done_enums:
|
|
686
|
-
refs.append(utype)
|
|
687
|
-
elif (
|
|
688
|
-
inspect.isclass(utype)
|
|
689
|
-
and issubclass(utype, Message)
|
|
690
|
-
and utype not in done_messages
|
|
691
|
-
):
|
|
692
|
-
refs.append(utype)
|
|
693
|
-
|
|
694
|
-
else:
|
|
695
|
-
proto_typename = protobuf_type_mapping(field_type)
|
|
696
|
-
if proto_typename is None:
|
|
697
|
-
raise Exception(f"Type {field_type} is not supported.")
|
|
698
|
-
|
|
699
|
-
if is_enum_type(field_type):
|
|
700
|
-
proto_typename = field_type.__name__
|
|
701
|
-
if field_type not in done_enums:
|
|
702
|
-
refs.append(field_type)
|
|
703
|
-
elif inspect.isclass(field_type) and issubclass(field_type, Message):
|
|
704
|
-
proto_typename = field_type.__name__
|
|
705
|
-
if field_type not in done_messages:
|
|
706
|
-
refs.append(field_type)
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
)
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
if
|
|
1021
|
-
raise Exception("Generating
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
)
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
self.
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
""
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
""
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
"
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
]
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
def
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
""
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
""
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1
|
+
import annotated_types
|
|
2
|
+
import asyncio
|
|
3
|
+
import enum
|
|
4
|
+
import importlib.util
|
|
5
|
+
import inspect
|
|
6
|
+
import os
|
|
7
|
+
import signal
|
|
8
|
+
import sys
|
|
9
|
+
import time
|
|
10
|
+
import types
|
|
11
|
+
import datetime
|
|
12
|
+
from concurrent import futures
|
|
13
|
+
from posixpath import basename
|
|
14
|
+
from typing import (
|
|
15
|
+
Callable,
|
|
16
|
+
Type,
|
|
17
|
+
get_args,
|
|
18
|
+
get_origin,
|
|
19
|
+
Union,
|
|
20
|
+
TypeAlias,
|
|
21
|
+
)
|
|
22
|
+
from collections.abc import AsyncIterator
|
|
23
|
+
|
|
24
|
+
import grpc
|
|
25
|
+
from grpc_health.v1 import health_pb2, health_pb2_grpc
|
|
26
|
+
from grpc_health.v1.health import HealthServicer
|
|
27
|
+
from grpc_reflection.v1alpha import reflection
|
|
28
|
+
from grpc_tools import protoc
|
|
29
|
+
from pydantic import BaseModel, ValidationError
|
|
30
|
+
from sonora.wsgi import grpcWSGI
|
|
31
|
+
from sonora.asgi import grpcASGI
|
|
32
|
+
from connecpy.asgi import ConnecpyASGIApp as ConnecpyASGI
|
|
33
|
+
from connecpy.errors import Errors
|
|
34
|
+
|
|
35
|
+
# Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
|
|
36
|
+
from google.protobuf import timestamp_pb2, duration_pb2
|
|
37
|
+
|
|
38
|
+
###############################################################################
|
|
39
|
+
# 1. Message definitions & converter extensions
|
|
40
|
+
# (datetime.datetime <-> google.protobuf.Timestamp)
|
|
41
|
+
# (datetime.timedelta <-> google.protobuf.Duration)
|
|
42
|
+
###############################################################################
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Every pydantic ``BaseModel`` subclass is treated as an RPC message type.
# ``Message`` is a plain alias of ``BaseModel``; it adds no behaviour of its own.
Message: TypeAlias = BaseModel
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def primitiveProtoValueToPythonValue(value):
    """Identity converter for values that need no translation.

    Used for protobuf scalars (and as the fallback converter): the incoming
    object is handed back unchanged.
    """
    result = value
    return result
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def timestamp_to_python(ts: timestamp_pb2.Timestamp) -> datetime.datetime:  # type: ignore
    """Decode a protobuf ``Timestamp`` message.

    Args:
        ts: the ``google.protobuf.Timestamp`` to decode.

    Returns:
        Whatever ``ts.ToDatetime()`` produces (called without a ``tzinfo``).
    """
    decoded = ts.ToDatetime()
    return decoded
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def python_to_timestamp(dt: datetime.datetime) -> timestamp_pb2.Timestamp:  # type: ignore
    """Encode a ``datetime.datetime`` as a protobuf ``Timestamp``.

    Args:
        dt: the datetime to encode; passed to ``Timestamp.FromDatetime``.

    Returns:
        A newly created ``Timestamp`` message holding ``dt``.
    """
    encoded = timestamp_pb2.Timestamp()  # type: ignore
    encoded.FromDatetime(dt)
    return encoded
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def duration_to_python(d: duration_pb2.Duration) -> datetime.timedelta:  # type: ignore
    """Decode a protobuf ``Duration`` message.

    Args:
        d: the ``google.protobuf.Duration`` to decode.

    Returns:
        The ``datetime.timedelta`` returned by ``d.ToTimedelta()``.
    """
    decoded = d.ToTimedelta()
    return decoded
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def python_to_duration(td: datetime.timedelta) -> duration_pb2.Duration:  # type: ignore
    """Encode a ``datetime.timedelta`` as a protobuf ``Duration``.

    Args:
        td: the interval to encode; passed to ``Duration.FromTimedelta``.

    Returns:
        A newly created ``Duration`` message holding ``td``.
    """
    encoded = duration_pb2.Duration()  # type: ignore
    encoded.FromTimedelta(td)
    return encoded
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def generate_converter(annotation: Type) -> Callable:
    """Return a callable that maps an incoming protobuf value to its Python form.

    Used while decoding requests. The converter is chosen from ``annotation``:

    - ``int``/``str``/``bool``/``bytes``/``float``: the identity function
      ``primitiveProtoValueToPythonValue``.
    - ``enum.Enum`` subclass: ``annotation(value)``.
    - ``datetime.datetime``: ``value.ToDatetime()``.
    - ``datetime.timedelta``: ``value.ToTimedelta()``.
    - ``list[T]`` / ``tuple[T, ...]``: a ``list`` whose items are converted with
      the converter for the FIRST type argument.
    - ``dict[K, V]``: a ``dict`` with keys and values converted recursively.
    - ``Message`` subclass: ``generate_message_converter(annotation)``.
    - anything else (unions included): the identity function.
    """
    if annotation in (int, str, bool, bytes, float):
        return primitiveProtoValueToPythonValue

    if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
        enum_cls = annotation

        def to_enum(raw):
            return enum_cls(raw)

        return to_enum

    if annotation == datetime.datetime:

        def to_datetime(raw):
            return raw.ToDatetime()

        return to_datetime

    if annotation == datetime.timedelta:

        def to_timedelta(raw):
            return raw.ToTimedelta()

        return to_timedelta

    origin = get_origin(annotation)

    # ``origin`` is None for plain classes, so these membership tests are False.
    if origin in (list, tuple):
        convert_item = generate_converter(get_args(annotation)[0])

        def to_list(raw):
            return list(map(convert_item, raw))

        return to_list

    if origin is dict:
        type_args = get_args(annotation)
        convert_key = generate_converter(type_args[0])
        convert_value = generate_converter(type_args[1])

        def to_dict(raw):
            return {convert_key(k): convert_value(v) for k, v in raw.items()}

        return to_dict

    if inspect.isclass(annotation) and issubclass(annotation, Message):
        return generate_message_converter(annotation)

    # Unions and unsupported annotations: pass the value through untouched.
    return primitiveProtoValueToPythonValue
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def generate_message_converter(arg_type: Type[Message]) -> Callable:
    """Return a converter that builds an ``arg_type`` instance from a protobuf message.

    The returned callable reads every field declared in ``arg_type.model_fields``
    from the incoming protobuf message, converts it and instantiates
    ``arg_type`` with the results, so pydantic validation runs (and may raise
    ``ValidationError``).

    - Ordinary fields are read with ``getattr`` and converted with
      ``generate_converter(annotation)``.
    - Union-typed fields are generated as a ``oneof`` whose members are named
      ``<field>_<proto type>`` (the naming used by ``generate_oneof_definition``);
      the protobuf message has no attribute called ``<field>``. They are read
      via ``WhichOneof``. The value of the member that is set is decoded with
      that member type's converter, and an unset oneof becomes ``None``.

    Per-field converters are built lazily on the first call. A message type
    that refers to itself (e.g. ``children: list["Node"]``) would otherwise
    recurse forever while the converters are being built.

    Raises:
        TypeError: if ``arg_type`` is not a subclass of ``Message``.
    """
    if (
        arg_type is None
        or not inspect.isclass(arg_type)
        or not issubclass(arg_type, Message)
    ):
        raise TypeError("Request arg must be subclass of Message")

    fields = arg_type.model_fields
    readers = None  # field name -> callable(request) -> python value; built lazily

    def oneof_reader(field_name, annotation):
        # Map each oneof member name to the converter for its Python type.
        by_member = {}
        for member_type in flatten_union(annotation):
            if not inspect.isclass(member_type) or get_origin(member_type) is not None:
                continue
            if is_enum_type(member_type) or issubclass(member_type, Message):
                type_name = member_type.__name__
            else:
                type_name = protobuf_type_mapping(member_type)
                if not isinstance(type_name, str):
                    continue  # no proto mapping (e.g. NoneType): not a oneof member
            member_name = f"{field_name}_{type_name.replace('.', '_')}"
            by_member[member_name] = generate_converter(member_type)

        def read(request):
            which = request.WhichOneof(field_name)
            if which is None:
                return None
            return by_member[which](getattr(request, which))

        return read

    def plain_reader(field_name, annotation):
        convert = generate_converter(annotation)

        def read(request):
            return convert(getattr(request, field_name))

        return read

    def build_readers():
        return {
            name: (
                oneof_reader(name, info.annotation)
                if is_union_type(info.annotation)
                else plain_reader(name, info.annotation)
            )
            for name, info in fields.items()
        }

    def converter(request):
        nonlocal readers
        if readers is None:
            readers = build_readers()
        return arg_type(**{name: read(request) for name, read in readers.items()})

    return converter
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def python_value_to_proto_value(field_type: Type, value):
    """Encode a single Python value for placement in a protobuf message.

    The conversion is chosen by ``field_type``, not by the runtime type of
    ``value``:

    - ``datetime.datetime``: the value becomes a ``Timestamp`` message.
    - ``datetime.timedelta``: the value becomes a ``Duration`` message.
    - any other ``field_type``: ``value`` is returned untouched.
    """
    if field_type == datetime.datetime:
        converted = python_to_timestamp(value)
    elif field_type == datetime.timedelta:
        converted = python_to_duration(value)
    else:
        converted = value
    return converted
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
###############################################################################
|
|
177
|
+
# 2. Stub implementation
|
|
178
|
+
###############################################################################
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def connect_obj_with_stub(pb2_grpc_module, pb2_module, service_obj: object) -> type:
    """Create a concrete synchronous gRPC servicer class forwarding to ``service_obj``.

    The generated base class ``<ServiceName>Servicer`` is looked up in
    ``pb2_grpc_module``. Every RPC method of ``service_obj`` reported by
    ``get_rpc_methods`` (except names starting with ``_``) becomes a handler
    that:

    1. decodes the protobuf request into the method's pydantic request type;
    2. calls the method with ``(request)`` or ``(request, context)``, depending
       on whether it declares one or two parameters;
    3. encodes the returned pydantic message using ``pb2_module``.

    A ``ValidationError`` aborts the call with ``INVALID_ARGUMENT``; any other
    exception aborts it with ``INTERNAL``.

    NOTE(review): the broad ``except Exception`` also wraps the user method. If
    ``context.abort`` signals by raising an ``Exception`` subclass, an abort made
    inside the user method is caught here and ``abort`` is called again with
    ``INTERNAL``. Confirm whether that overrides the user's status.

    Raises:
        Exception: if an RPC method declares neither one nor two parameters.
    """
    servicer_base = getattr(
        pb2_grpc_module, service_obj.__class__.__name__ + "Servicer"
    )

    class ConcreteServiceClass(servicer_base):
        pass

    def implement_stub_method(method):
        sig = inspect.signature(method)
        converter = generate_message_converter(get_request_arg_type(sig))
        response_type = sig.return_annotation
        arity = len(sig.parameters)
        if arity not in (1, 2):
            raise Exception("Method must have exactly one or two parameters")
        # Two-parameter methods also receive the servicer context.
        pass_context = arity == 2

        def stub_method(self, request, context, method=method):
            try:
                arg = converter(request)
                resp_obj = method(arg, context) if pass_context else method(arg)
                return convert_python_message_to_proto(
                    resp_obj, response_type, pb2_module
                )
            except ValidationError as e:
                return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
            except Exception as e:
                return context.abort(grpc.StatusCode.INTERNAL, str(e))

        return stub_method

    for method_name, method in get_rpc_methods(service_obj):
        if method.__name__.startswith("_"):
            continue
        setattr(ConcreteServiceClass, method_name, implement_stub_method(method))

    return ConcreteServiceClass
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> type:
    """Create a concrete asyncio gRPC servicer class forwarding to ``obj``.

    The generated base class ``<ServiceName>Servicer`` is looked up in
    ``pb2_grpc_module``. Every RPC method of ``obj`` reported by
    ``get_rpc_methods`` (except names starting with ``_``) becomes a coroutine
    handler, called with ``(request)`` or ``(request, context)`` depending on
    how many parameters it declares.

    - Methods annotated to return ``AsyncIterator[T]`` (see ``is_stream_type``)
      become streaming handlers. Each item the user's async iterator yields is
      encoded as ``T`` and yielded on.
    - All other methods are awaited, and their result is encoded as the
      declared return type.

    A ``ValidationError`` aborts with ``INVALID_ARGUMENT``; any other exception
    aborts with ``INTERNAL``.

    NOTE(review): if ``context.abort`` raises an ``Exception`` subclass, an abort
    made inside the user method is caught by the broad ``except`` and ``abort``
    is awaited a second time. Confirm how the server treats the repeated call.

    Raises:
        Exception: if an RPC method declares neither one nor two parameters.
    """
    servicer_base = getattr(pb2_grpc_module, obj.__class__.__name__ + "Servicer")

    class ConcreteServiceClass(servicer_base):
        pass

    def implement_stub_method(method):
        sig = inspect.signature(method)
        converter = generate_message_converter(get_request_arg_type(sig))
        response_type = sig.return_annotation
        arity = len(sig.parameters)
        streaming = is_stream_type(response_type)
        if streaming:
            item_type = get_args(response_type)[0]
        if arity not in (1, 2):
            raise Exception("Method must have exactly one or two parameters")
        # Two-parameter methods also receive the servicer context.
        pass_context = arity == 2

        if streaming:

            async def stream_stub(self, request, context, method=method):
                try:
                    arg = converter(request)
                    source = method(arg, context) if pass_context else method(arg)
                    async for resp_obj in source:
                        yield convert_python_message_to_proto(
                            resp_obj, item_type, pb2_module
                        )
                except ValidationError as e:
                    await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
                except Exception as e:
                    await context.abort(grpc.StatusCode.INTERNAL, str(e))

            return stream_stub

        async def unary_stub(self, request, context, method=method):
            try:
                arg = converter(request)
                if pass_context:
                    resp_obj = await method(arg, context)
                else:
                    resp_obj = await method(arg)
                return convert_python_message_to_proto(
                    resp_obj, response_type, pb2_module
                )
            except ValidationError as e:
                await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
            except Exception as e:
                await context.abort(grpc.StatusCode.INTERNAL, str(e))

        return unary_stub

    for method_name, method in get_rpc_methods(obj):
        if method.__name__.startswith("_"):
            continue
        setattr(ConcreteServiceClass, method_name, implement_stub_method(method))

    return ConcreteServiceClass
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def connect_obj_with_stub_async_connecpy(
    connecpy_module, pb2_module, obj: object
) -> type:
    """Create a concrete Connecpy service class forwarding to ``obj``.

    The base class named exactly like ``obj``'s class is looked up in
    ``connecpy_module``. Every RPC method of ``obj`` reported by
    ``get_rpc_methods`` (except names starting with ``_``) must be a coroutine
    function. Each becomes an async handler that:

    1. decodes the protobuf request into the method's pydantic request type;
    2. awaits the method with ``(request)`` or ``(request, context)``;
    3. encodes the result as the declared return type.

    A ``ValidationError`` aborts with ``Errors.InvalidArgument``; any other
    exception aborts with ``Errors.Internal``.

    Raises:
        Exception: if an RPC method is not a coroutine function, or declares
            neither one nor two parameters.
    """
    stub_class = getattr(connecpy_module, obj.__class__.__name__)

    class ConcreteServiceClass(stub_class):
        pass

    def implement_stub_method(method):
        sig = inspect.signature(method)
        converter = generate_message_converter(get_request_arg_type(sig))
        response_type = sig.return_annotation
        arity = len(sig.parameters)
        if arity not in (1, 2):
            raise Exception("Method must have exactly one or two parameters")
        # Two-parameter methods also receive the request context.
        pass_context = arity == 2

        async def stub_method(self, request, context, method=method):
            try:
                arg = converter(request)
                if pass_context:
                    resp_obj = await method(arg, context)
                else:
                    resp_obj = await method(arg)
                return convert_python_message_to_proto(
                    resp_obj, response_type, pb2_module
                )
            except ValidationError as e:
                await context.abort(Errors.InvalidArgument, str(e))
            except Exception as e:
                await context.abort(Errors.Internal, str(e))

        return stub_method

    for method_name, method in get_rpc_methods(obj):
        if method.__name__.startswith("_"):
            continue
        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated since Python 3.14.
        if not inspect.iscoroutinefunction(method):
            # A single formatted message; the former two-argument Exception
            # rendered as a tuple.
            raise Exception(f"Method must be async: {method_name}")
        setattr(ConcreteServiceClass, method_name, implement_stub_method(method))

    return ConcreteServiceClass
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def convert_python_message_to_proto(
    py_msg: Message, msg_type: Type, pb2_module
) -> object:
    """Convert a pydantic ``Message`` instance into its generated protobuf message.

    Used for constructing responses. For every field of ``msg_type``, the value
    is read from ``py_msg`` and converted with ``python_value_to_proto``. The
    protobuf class with the same name as ``msg_type`` is then looked up in
    ``pb2_module`` and instantiated with the converted values as keyword
    arguments. ``None`` values of ordinary fields are passed through as ``None``.

    Union-typed fields are generated as a ``oneof`` whose members are named
    ``<field>_<proto type>`` (the naming used by ``generate_oneof_definition``).
    The protobuf class has no field called ``<field>``, so such a value is
    stored under the member matching its type: an exact type match first, then
    ``isinstance``. ``None``, or a value matching no member, leaves the oneof
    unset.

    Args:
        py_msg: the pydantic instance to convert.
        msg_type: its pydantic class; supplies the field list and annotations.
        pb2_module: module holding the generated protobuf classes.

    Returns:
        An instance of ``getattr(pb2_module, msg_type.__name__)``.
    """

    def oneof_members(field_name, union_type):
        # (member field name, member python type) pairs, in declaration order.
        members = []
        for member_type in flatten_union(union_type):
            if not inspect.isclass(member_type) or get_origin(member_type) is not None:
                continue
            if is_enum_type(member_type) or issubclass(member_type, Message):
                type_name = member_type.__name__
            else:
                type_name = protobuf_type_mapping(member_type)
                if not isinstance(type_name, str):
                    continue  # no proto mapping (e.g. NoneType)
            members.append(
                (f"{field_name}_{type_name.replace('.', '_')}", member_type)
            )
        return members

    def pick_member(members, value):
        # Exact type wins so that e.g. True picks a ``bool`` member over ``int``.
        for member in members:
            if type(value) is member[1]:
                return member
        for member in members:
            if isinstance(value, member[1]):
                return member
        return None

    field_dict = {}
    for name, field_info in msg_type.model_fields.items():
        value = getattr(py_msg, name)
        field_type = field_info.annotation

        if is_union_type(field_type):
            if value is not None:
                chosen = pick_member(oneof_members(name, field_type), value)
                if chosen is not None:
                    member_name, member_type = chosen
                    field_dict[member_name] = python_value_to_proto(
                        member_type, value, pb2_module
                    )
            continue

        if value is None:
            field_dict[name] = None
            continue

        field_dict[name] = python_value_to_proto(field_type, value, pb2_module)

    # Retrieve the appropriate protobuf class dynamically
    proto_class = getattr(pb2_module, msg_type.__name__)
    return proto_class(**field_dict)
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def python_value_to_proto(field_type: Type, value, pb2_module):
    """Convert one Python field value into what the protobuf constructor expects.

    Checks are made in this order:

    - ``datetime.datetime``: converted to ``Timestamp``.
    - ``datetime.timedelta``: converted to ``Duration``.
    - ``enum.Enum`` subclass: ``value.value``.
    - ``list[T]`` / ``tuple[T, ...]``: a list, converted recursively with ``T``.
    - ``dict[K, V]``: a dict, converted recursively.
    - union: converted as the first supported member type that ``value`` is an
      instance of, or ``None`` if none matches. Supported members are
      datetime/timedelta, the scalar builtins, enums and ``Message``.
    - ``Message`` subclass: a nested protobuf message via
      ``convert_python_message_to_proto``.
    - anything else: ``value`` unchanged.
    """
    # If datetime
    if field_type == datetime.datetime:
        return python_to_timestamp(value)

    # If timedelta
    if field_type == datetime.timedelta:
        return python_to_duration(value)

    # If enum
    if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
        return value.value  # proto3 enum is an int

    origin = get_origin(field_type)
    # If seq
    if origin in (list, tuple):
        inner_type = get_args(field_type)[0]  # type: ignore
        return [python_value_to_proto(inner_type, v, pb2_module) for v in value]

    # If dict
    if origin is dict:
        key_type, val_type = get_args(field_type)  # type: ignore
        return {
            python_value_to_proto(key_type, k, pb2_module): python_value_to_proto(
                val_type, v, pb2_module
            )
            for k, v in value.items()
        }

    # If union -> oneof
    if is_union_type(field_type):
        scalar_like = (
            datetime.datetime,
            datetime.timedelta,
            int,
            float,
            str,
            bool,
            bytes,
        )
        for sub_type in flatten_union(field_type):
            if not inspect.isclass(sub_type) or get_origin(sub_type) is not None:
                continue
            supported = sub_type in scalar_like or issubclass(
                sub_type, (enum.Enum, Message)
            )
            # Match the value against THIS member type, so that a Message value
            # is converted with its own class, not the first Message in the union.
            if supported and isinstance(value, sub_type):
                return python_value_to_proto(sub_type, value, pb2_module)
        return None

    # If Message
    if inspect.isclass(field_type) and issubclass(field_type, Message):
        return convert_python_message_to_proto(value, field_type, pb2_module)

    # If primitive
    return value
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
###############################################################################
|
|
520
|
+
# 3. Generating proto files (datetime->Timestamp, timedelta->Duration)
|
|
521
|
+
###############################################################################
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def is_enum_type(python_type: Type) -> bool:
    """Report whether *python_type* is a class deriving from ``enum.Enum``.

    Non-class inputs (including enum members themselves) yield ``False``.
    """
    if not inspect.isclass(python_type):
        return False
    return issubclass(python_type, enum.Enum)
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def is_union_type(python_type: Type) -> bool:
|
|
530
|
+
"""
|
|
531
|
+
Check if a given Python type is a Union type (including Python 3.10's UnionType).
|
|
532
|
+
"""
|
|
533
|
+
if get_origin(python_type) is Union:
|
|
534
|
+
return True
|
|
535
|
+
if sys.version_info >= (3, 10):
|
|
536
|
+
import types
|
|
537
|
+
|
|
538
|
+
if type(python_type) is types.UnionType:
|
|
539
|
+
return True
|
|
540
|
+
return False
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def flatten_union(field_type: Type) -> list[Type]:
    """Expand a (possibly nested) union into a flat list of member types.

    Members keep their declaration order, depth first. A non-union type is
    returned as a one-element list.
    """
    if not is_union_type(field_type):
        return [field_type]
    return [
        leaf
        for member in get_args(field_type)
        for leaf in flatten_union(member)
    ]
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def protobuf_type_mapping(python_type: Type) -> str | type | None:
    """Map a Python annotation to the protobuf type used in generated .proto text.

    Returns:
        - a scalar name (``"int32"``, ``"string"``, ``"bool"``, ``"bytes"``,
          ``"float"``) for the matching builtin;
        - ``"google.protobuf.Timestamp"`` / ``"google.protobuf.Duration"`` for
          ``datetime.datetime`` / ``datetime.timedelta``;
        - the class itself for enum and ``Message`` types (callers substitute
          the class ``__name__``);
        - ``"repeated <T>"`` for ``list[T]`` / ``tuple[T, ...]`` and
          ``"map<K, V>"`` for ``dict[K, V]``, where enum / ``Message`` element
          types are written by class name;
        - ``None`` for unions (handled separately as a ``oneof``) and for
          anything without a mapping.
    """

    def type_name(mapped):
        # A container element that maps to a class must be written by name.
        # Formatting the class object itself would insert its repr
        # (e.g. "<class '...'>") into the .proto text.
        return mapped.__name__ if isinstance(mapped, type) else mapped

    mapping = {
        int: "int32",
        str: "string",
        bool: "bool",
        bytes: "bytes",
        float: "float",
    }

    if python_type == datetime.datetime:
        return "google.protobuf.Timestamp"

    if python_type == datetime.timedelta:
        return "google.protobuf.Duration"

    if is_enum_type(python_type):
        return python_type  # Will be defined as enum later

    if is_union_type(python_type):
        return None  # Handled separately as oneof

    if hasattr(python_type, "__origin__"):
        if python_type.__origin__ in (list, tuple):
            inner_proto_type = protobuf_type_mapping(python_type.__args__[0])
            if inner_proto_type:
                return f"repeated {type_name(inner_proto_type)}"
        elif python_type.__origin__ is dict:
            key_proto_type = protobuf_type_mapping(python_type.__args__[0])
            value_proto_type = protobuf_type_mapping(python_type.__args__[1])
            if key_proto_type and value_proto_type:
                return (
                    f"map<{type_name(key_proto_type)}, {type_name(value_proto_type)}>"
                )

    if inspect.isclass(python_type) and issubclass(python_type, Message):
        return python_type

    return mapping.get(python_type)  # type: ignore
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
def comment_out(docstr: str) -> tuple[str, ...]:
    """Convert a docstring into ``//`` comment lines for a .proto file.

    Each line of ``docstr`` becomes ``// <line>``; an empty line becomes a bare
    ``//``.

    An empty tuple is returned for an empty or ``None`` docstring, and for a
    docstring that starts with
    ``Usage docs: https://docs.pydantic.dev/<version>/concepts/models/``.
    Presumably that is the docstring inherited from pydantic's ``BaseModel``
    by models that have none of their own.
    """
    if not docstr:
        return tuple()

    # Accept any version segment: pinning one (e.g. "2.10") stops matching as
    # soon as pydantic publishes its docs under another version path.
    prefix = "Usage docs: https://docs.pydantic.dev/"
    if docstr.startswith(prefix):
        _version, _, rest = docstr[len(prefix):].partition("/")
        if rest.startswith("concepts/models/"):
            return tuple()

    return tuple("//" if line == "" else f"// {line}" for line in docstr.split("\n"))
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def indent_lines(lines, indentation=" "):
    """Prefix each element of *lines* with *indentation* and newline-join them."""
    prefixed = []
    for line in lines:
        prefixed.append(indentation + line)
    return "\n".join(prefixed)
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
    """Render a Python ``enum.Enum`` as a protobuf ``enum`` block.

    Each canonical member becomes a line ``NAME = value;``, in definition
    order. Aliases are skipped. An alias is an extra name bound to an existing
    member, and ``__members__`` lists it with the canonical member as its
    value. Emitting it would therefore repeat the canonical name and produce a
    duplicate entry.

    NOTE(review): values are formatted as-is, so a non-integer value (e.g. a
    ``str``-valued enum) or a first member that is not 0 presumably yields a
    proto3 enum that protoc rejects. Confirm whether callers guard this.
    """
    members = [
        f" {member.name} = {member.value};"
        for name, member in enum_type.__members__.items()
        if name == member.name  # drop aliases
    ]
    enum_def = f"enum {enum_type.__name__} {{\n"
    enum_def += "\n".join(members)
    enum_def += "\n}"
    return enum_def
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
def generate_oneof_definition(
    field_name: str, union_args: list[Type], start_index: int
) -> tuple[list[str], int]:
    """Build the ``oneof`` block representing a union-typed field.

    Each member type becomes one line ``<type> <field_name>_<type> = <n>;``.
    Dots in the type name are replaced by ``_`` in the member field name.
    Enum and ``Message`` members are written by class name. Field numbers
    start at ``start_index`` and increase by one per member.

    Returns:
        ``(lines, next_index)``: the block's lines (header through closing
        brace) and the first field number not used by it.

    Raises:
        Exception: if a member type has no protobuf mapping.
    """
    body = [f"oneof {field_name} {{"]
    next_index = start_index
    for member_type in union_args:
        type_name = protobuf_type_mapping(member_type)
        if type_name is None:
            raise Exception(f"Nested Union not flattened properly: {member_type}")

        named_by_class = is_enum_type(member_type) or (
            inspect.isclass(member_type) and issubclass(member_type, Message)
        )
        if named_by_class:
            type_name = member_type.__name__

        member_field = f"{field_name}_{type_name.replace('.', '_')}"
        body.append(f" {type_name} {member_field} = {next_index};")
        next_index += 1
    body.append("}")
    return body, next_index
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
def generate_message_definition(
|
|
658
|
+
message_type: Type[Message],
|
|
659
|
+
done_enums: set,
|
|
660
|
+
done_messages: set,
|
|
661
|
+
) -> tuple[str, list[type]]:
|
|
662
|
+
"""
|
|
663
|
+
Generate a protobuf message definition for a Pydantic-based Message class.
|
|
664
|
+
Also returns any referenced types (enums, messages) that need to be defined.
|
|
665
|
+
"""
|
|
666
|
+
fields = []
|
|
667
|
+
refs = []
|
|
668
|
+
pydantic_fields = message_type.model_fields
|
|
669
|
+
index = 1
|
|
670
|
+
|
|
671
|
+
for field_name, field_info in pydantic_fields.items():
|
|
672
|
+
field_type = field_info.annotation
|
|
673
|
+
if field_type is None:
|
|
674
|
+
raise Exception(f"Field {field_name} has no type annotation.")
|
|
675
|
+
|
|
676
|
+
if is_union_type(field_type):
|
|
677
|
+
union_args = flatten_union(field_type)
|
|
678
|
+
oneof_lines, new_index = generate_oneof_definition(
|
|
679
|
+
field_name, union_args, index
|
|
680
|
+
)
|
|
681
|
+
fields.extend(oneof_lines)
|
|
682
|
+
index = new_index
|
|
683
|
+
|
|
684
|
+
for utype in union_args:
|
|
685
|
+
if is_enum_type(utype) and utype not in done_enums:
|
|
686
|
+
refs.append(utype)
|
|
687
|
+
elif (
|
|
688
|
+
inspect.isclass(utype)
|
|
689
|
+
and issubclass(utype, Message)
|
|
690
|
+
and utype not in done_messages
|
|
691
|
+
):
|
|
692
|
+
refs.append(utype)
|
|
693
|
+
|
|
694
|
+
else:
|
|
695
|
+
proto_typename = protobuf_type_mapping(field_type)
|
|
696
|
+
if proto_typename is None:
|
|
697
|
+
raise Exception(f"Type {field_type} is not supported.")
|
|
698
|
+
|
|
699
|
+
if is_enum_type(field_type):
|
|
700
|
+
proto_typename = field_type.__name__
|
|
701
|
+
if field_type not in done_enums:
|
|
702
|
+
refs.append(field_type)
|
|
703
|
+
elif inspect.isclass(field_type) and issubclass(field_type, Message):
|
|
704
|
+
proto_typename = field_type.__name__
|
|
705
|
+
if field_type not in done_messages:
|
|
706
|
+
refs.append(field_type)
|
|
707
|
+
|
|
708
|
+
if field_info.description:
|
|
709
|
+
fields.append("// " + field_info.description)
|
|
710
|
+
if field_info.metadata:
|
|
711
|
+
fields.append("// Constraint:")
|
|
712
|
+
for metadata_item in field_info.metadata:
|
|
713
|
+
match type(metadata_item):
|
|
714
|
+
case annotated_types.Ge:
|
|
715
|
+
fields.append(
|
|
716
|
+
"// greater than or equal to " + str(metadata_item.ge)
|
|
717
|
+
)
|
|
718
|
+
case annotated_types.Le:
|
|
719
|
+
fields.append(
|
|
720
|
+
"// less than or equal to " + str(metadata_item.le)
|
|
721
|
+
)
|
|
722
|
+
case annotated_types.Gt:
|
|
723
|
+
fields.append("// greater than " + str(metadata_item.gt))
|
|
724
|
+
case annotated_types.Lt:
|
|
725
|
+
fields.append("// less than " + str(metadata_item.lt))
|
|
726
|
+
case annotated_types.MultipleOf:
|
|
727
|
+
fields.append(
|
|
728
|
+
"// multiple of " + str(metadata_item.multiple_of)
|
|
729
|
+
)
|
|
730
|
+
case annotated_types.Len:
|
|
731
|
+
fields.append("// length of " + str(metadata_item.len))
|
|
732
|
+
case annotated_types.MinLen:
|
|
733
|
+
fields.append(
|
|
734
|
+
"// minimum length of " + str(metadata_item.min_len)
|
|
735
|
+
)
|
|
736
|
+
case annotated_types.MaxLen:
|
|
737
|
+
fields.append(
|
|
738
|
+
"// maximum length of " + str(metadata_item.max_len)
|
|
739
|
+
)
|
|
740
|
+
case _:
|
|
741
|
+
fields.append("// " + str(metadata_item))
|
|
742
|
+
|
|
743
|
+
fields.append(f"{proto_typename} {field_name} = {index};")
|
|
744
|
+
index += 1
|
|
745
|
+
|
|
746
|
+
msg_def = f"message {message_type.__name__} {{\n{indent_lines(fields)}\n}}"
|
|
747
|
+
return msg_def, refs
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
def is_stream_type(annotation: Type) -> bool:
    """Return True when ``annotation`` is a parametrized ``AsyncIterator`` (e.g. ``AsyncIterator[Foo]``)."""
    origin = get_origin(annotation)
    return origin is AsyncIterator
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
def is_generic_alias(annotation: Type) -> bool:
|
|
755
|
+
return get_origin(annotation) is not None
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
def generate_proto(obj: object, package_name: str = "") -> str:
    """
    Generate a .proto definition from a service class.

    Every public method returned by ``get_rpc_methods(obj)`` becomes an rpc.
    Its request/response message types, and the messages and enums they
    reference, are collected and rendered as proto3 definitions.
    Timestamp and Duration imports are added automatically whenever a
    ``datetime.datetime`` / ``datetime.timedelta`` appears anywhere in a field
    annotation, including inside generics such as ``list[...]``, ``dict[...]``
    or unions.

    Args:
        obj: An instance of the service class; its class name is the service name.
        package_name: Proto package name. When empty, it is derived from the
            service class name (a trailing "Service" is stripped), lowercased,
            and suffixed with ".v1".

    Returns:
        The complete .proto file contents.
    """
    import datetime

    service_class = obj.__class__
    service_name = service_class.__name__
    service_docstr = inspect.getdoc(service_class)
    service_comment = "\n".join(comment_out(service_docstr)) if service_docstr else ""

    rpc_definitions = []
    all_type_definitions = []
    done_messages = set()
    done_enums = set()

    uses_timestamp = False
    uses_duration = False

    def check_and_set_well_known_types(py_type: Type):
        # Recurse through generic arguments so list[datetime], dict[str, timedelta],
        # Optional[...] and other unions also trigger the matching import.
        nonlocal uses_timestamp, uses_duration
        if py_type == datetime.datetime:
            uses_timestamp = True
        if py_type == datetime.timedelta:
            uses_duration = True
        for arg in get_args(py_type):
            check_and_set_well_known_types(arg)

    def rpc_type_name(annotation) -> str:
        # AsyncIterator[Foo] is rendered as "stream Foo"; plain messages as "Foo".
        if is_stream_type(annotation):
            return f"stream {get_args(annotation)[0].__name__}"
        return annotation.__name__

    for method_name, method in get_rpc_methods(obj):
        if method.__name__.startswith("_"):
            continue

        method_sig = inspect.signature(method)
        request_type = get_request_arg_type(method_sig)
        response_type = method_sig.return_annotation

        # Recursively generate message definitions
        message_types = [request_type, response_type]
        while message_types:
            mt = message_types.pop()
            if mt in done_messages:
                continue
            done_messages.add(mt)

            if is_stream_type(mt):
                item_type = get_args(mt)[0]
                message_types.append(item_type)
                continue

            for _, field_info in mt.model_fields.items():
                check_and_set_well_known_types(field_info.annotation)

            msg_def, refs = generate_message_definition(mt, done_enums, done_messages)
            # Use only the class's own docstring: inspect.getdoc() falls back to
            # inherited docstrings, which would emit the base model class's
            # docstring for every message that does not define one.
            mt_doc = inspect.cleandoc(mt.__doc__) if mt.__doc__ else None
            if mt_doc:
                for comment_line in comment_out(mt_doc):
                    all_type_definitions.append(comment_line)

            all_type_definitions.append(msg_def)
            all_type_definitions.append("")

            for r in refs:
                if is_enum_type(r) and r not in done_enums:
                    done_enums.add(r)
                    enum_def = generate_enum_definition(r)
                    all_type_definitions.append(enum_def)
                    all_type_definitions.append("")
                elif issubclass(r, Message) and r not in done_messages:
                    message_types.append(r)

        method_docstr = inspect.getdoc(method)
        if method_docstr:
            for comment_line in comment_out(method_docstr):
                rpc_definitions.append(comment_line)

        rpc_definitions.append(
            f"rpc {method_name} ({rpc_type_name(request_type)}) returns ({rpc_type_name(response_type)});"
        )

    if not package_name:
        if service_name.endswith("Service"):
            package_name = service_name[: -len("Service")]
        else:
            package_name = service_name
        package_name = package_name.lower() + ".v1"

    imports = []
    if uses_timestamp:
        imports.append('import "google/protobuf/timestamp.proto";')
    if uses_duration:
        imports.append('import "google/protobuf/duration.proto";')

    import_block = "\n".join(imports)
    if import_block:
        import_block += "\n"

    proto_definition = f"""syntax = "proto3";

package {package_name};

{import_block}{service_comment}
service {service_name} {{
{indent_lines(rpc_definitions)}
}}

{indent_lines(all_type_definitions, "")}
"""
    return proto_definition
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
def generate_grpc_code(proto_file, grpc_python_out) -> types.ModuleType | None:
    """
    Run protoc to generate the Python gRPC stub module for ``proto_file`` and load it.

    Args:
        proto_file: Path of the .proto file, relative to the current directory.
        grpc_python_out: Output directory for the generated ``<base>_pb2_grpc.py``.
            It is appended to ``sys.path`` so the generated module can import
            its sibling ``<base>_pb2`` module.

    Returns:
        The loaded ``<base>_pb2_grpc`` module, or None if protoc fails or the
        generated file cannot be loaded.
    """
    # Build the argument list directly so paths containing whitespace stay intact
    # (the previous str.split() would have broken them apart).
    # NOTE(review): grpc_tools.protoc.main appears to forward this list to protoc
    # as argv, in which case the first element is taken as the program name and
    # "-I." is effectively ignored. Kept unchanged to preserve behavior.
    args = ["-I.", f"--grpc_python_out={grpc_python_out}", str(proto_file)]
    exit_code = protoc.main(args)
    if exit_code != 0:
        return None

    base = os.path.splitext(proto_file)[0]
    generated_pb2_grpc_file = f"{base}_pb2_grpc.py"

    if grpc_python_out not in sys.path:
        sys.path.append(grpc_python_out)

    spec = importlib.util.spec_from_file_location(
        generated_pb2_grpc_file, os.path.join(grpc_python_out, generated_pb2_grpc_file)
    )
    if spec is None:
        return None
    pb2_grpc_module = importlib.util.module_from_spec(spec)
    if spec.loader is None:
        return None
    spec.loader.exec_module(pb2_grpc_module)

    return pb2_grpc_module
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
def generate_connecpy_code(
    proto_file: str, connecpy_out: str
) -> types.ModuleType | None:
    """
    Run protoc to generate the Python Connecpy module for ``proto_file`` and load it.

    Args:
        proto_file: Path of the .proto file, relative to the current directory.
        connecpy_out: Output directory for the generated ``<base>_connecpy.py``.
            It is appended to ``sys.path`` so the generated module can import
            its sibling ``<base>_pb2`` module.

    Returns:
        The loaded ``<base>_connecpy`` module, or None if protoc fails or the
        generated file cannot be loaded.
    """
    # Build the argument list directly so paths containing whitespace stay intact
    # (the previous str.split() would have broken them apart).
    # NOTE(review): grpc_tools.protoc.main appears to forward this list to protoc
    # as argv, in which case the first element is taken as the program name and
    # "-I." is effectively ignored. Kept unchanged to preserve behavior.
    args = ["-I.", f"--connecpy_out={connecpy_out}", str(proto_file)]
    exit_code = protoc.main(args)
    if exit_code != 0:
        return None

    base = os.path.splitext(proto_file)[0]
    generated_connecpy_file = f"{base}_connecpy.py"

    if connecpy_out not in sys.path:
        sys.path.append(connecpy_out)

    spec = importlib.util.spec_from_file_location(
        generated_connecpy_file, os.path.join(connecpy_out, generated_connecpy_file)
    )
    if spec is None:
        return None
    connecpy_module = importlib.util.module_from_spec(spec)
    if spec.loader is None:
        return None
    spec.loader.exec_module(connecpy_module)

    return connecpy_module
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
def generate_pb_code(
    proto_file: str, python_out: str, pyi_out: str
) -> types.ModuleType | None:
    """
    Run protoc to generate the Python message module (and .pyi stub) for ``proto_file``.

    Args:
        proto_file: Path of the .proto file, relative to the current directory.
        python_out: Output directory for the generated ``<base>_pb2.py``; it is
            appended to ``sys.path``.
        pyi_out: Output directory for the generated ``<base>_pb2.pyi`` stub.

    Returns:
        The loaded ``<base>_pb2`` module, or None if protoc fails or the
        generated file cannot be loaded.
    """
    # Build the argument list directly so paths containing whitespace stay intact
    # (the previous str.split() would have broken them apart).
    # NOTE(review): grpc_tools.protoc.main appears to forward this list to protoc
    # as argv, in which case the first element is taken as the program name and
    # "-I." is effectively ignored. Kept unchanged to preserve behavior.
    args = [
        "-I.",
        f"--python_out={python_out}",
        f"--pyi_out={pyi_out}",
        str(proto_file),
    ]
    exit_code = protoc.main(args)
    if exit_code != 0:
        return None

    base = os.path.splitext(proto_file)[0]
    generated_pb2_file = f"{base}_pb2.py"

    if python_out not in sys.path:
        sys.path.append(python_out)

    spec = importlib.util.spec_from_file_location(
        generated_pb2_file, os.path.join(python_out, generated_pb2_file)
    )
    if spec is None:
        return None
    pb2_module = importlib.util.module_from_spec(spec)
    if spec.loader is None:
        return None
    spec.loader.exec_module(pb2_module)

    return pb2_module
|
|
967
|
+
|
|
968
|
+
|
|
969
|
+
def get_request_arg_type(sig):
    """
    Return the annotation of the request parameter described by ``sig``.

    ``sig`` is expected to be the signature of a bound RPC method (as produced
    from ``get_rpc_methods``), i.e. ``(request)`` or ``(request, context)``.

    Raises:
        Exception: if the signature does not have exactly one or two parameters.
    """
    params = list(sig.parameters.values())
    if len(params) not in (1, 2):
        raise Exception("Method must have exactly one or two parameters")
    request_param = params[0]
    return request_param.annotation
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
def get_rpc_methods(obj: object) -> list[tuple[str, types.MethodType]]:
    """
    Collect the bound methods of a service object as (PascalCaseName, method) pairs.

    Attributes are visited in ``dir(obj)`` order. Names are converted to
    PascalCase for .proto compatibility (``say_hello`` -> ``SayHello``). Private
    methods are included here; callers filter them out as needed.
    """

    def pascalize(snake: str) -> str:
        return "".join(map(str.capitalize, snake.split("_")))

    methods = []
    for name in dir(obj):
        if not inspect.ismethod(getattr(obj, name)):
            continue
        methods.append((pascalize(name), getattr(obj, name)))
    return methods
|
|
991
|
+
|
|
992
|
+
|
|
993
|
+
def is_skip_generation() -> bool:
    """Return True when PYDANTIC_RPC_SKIP_GENERATION equals "true" (case-insensitive)."""
    flag = os.environ.get("PYDANTIC_RPC_SKIP_GENERATION", "false")
    return flag.lower() == "true"
|
|
996
|
+
|
|
997
|
+
|
|
998
|
+
def generate_and_compile_proto(obj: object, package_name: str = ""):
    """
    Obtain the ``*_pb2_grpc`` and ``*_pb2`` modules for a service object.

    When ``is_skip_generation()`` is true, pre-generated modules named
    ``<classname>_pb2`` and ``<classname>_pb2_grpc`` (class name lowercased) are
    imported. If either cannot be found, or skipping is disabled, the proto is
    generated into ``<classname>.proto`` in the current directory and compiled
    there.

    Args:
        obj: The service instance.
        package_name: Proto package name passed to ``generate_proto``.

    Returns:
        A ``(pb2_grpc_module, pb2_module)`` tuple.

    Raises:
        Exception: if protoc fails to produce either module.
    """
    module_base = obj.__class__.__name__.lower()

    if is_skip_generation():
        import importlib

        pb2_name = f"{module_base}_pb2"
        pb2_grpc_name = f"{module_base}_pb2_grpc"
        try:
            pb2_module = importlib.import_module(pb2_name)
            pb2_grpc_module = importlib.import_module(pb2_grpc_name)
        except ModuleNotFoundError as err:
            # import_module raises rather than returning None. Fall back to
            # generation only when one of *our* modules is missing; a missing
            # dependency inside them is a real error.
            if err.name not in (pb2_name, pb2_grpc_name):
                raise
        else:
            return pb2_grpc_module, pb2_module

    # If the modules are not found (or skipping is disabled), generate and
    # compile the proto files.
    proto_file = generate_proto(obj, package_name)
    proto_file_name = module_base + ".proto"

    with open(proto_file_name, "w", encoding="utf-8") as f:
        f.write(proto_file)

    gen_pb = generate_pb_code(proto_file_name, ".", ".")
    if gen_pb is None:
        raise Exception("Generating pb code")

    gen_grpc = generate_grpc_code(proto_file_name, ".")
    if gen_grpc is None:
        raise Exception("Generating grpc code")
    return gen_grpc, gen_pb
|
|
1027
|
+
|
|
1028
|
+
|
|
1029
|
+
def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = ""):
    """
    Obtain the ``*_connecpy`` and ``*_pb2`` modules for a service object.

    When ``is_skip_generation()`` is true, pre-generated modules named
    ``<classname>_pb2`` and ``<classname>_connecpy`` (class name lowercased) are
    imported. If either cannot be found, or skipping is disabled, the proto is
    generated into ``<classname>.proto`` in the current directory and compiled
    there.

    Args:
        obj: The service instance.
        package_name: Proto package name passed to ``generate_proto``.

    Returns:
        A ``(connecpy_module, pb2_module)`` tuple.

    Raises:
        Exception: if protoc fails to produce either module.
    """
    module_base = obj.__class__.__name__.lower()

    if is_skip_generation():
        import importlib

        pb2_name = f"{module_base}_pb2"
        connecpy_name = f"{module_base}_connecpy"
        try:
            pb2_module = importlib.import_module(pb2_name)
            connecpy_module = importlib.import_module(connecpy_name)
        except ModuleNotFoundError as err:
            # import_module raises rather than returning None. Fall back to
            # generation only when one of *our* modules is missing; a missing
            # dependency inside them is a real error.
            if err.name not in (pb2_name, connecpy_name):
                raise
        else:
            return connecpy_module, pb2_module

    # If the modules are not found (or skipping is disabled), generate and
    # compile the proto files.
    proto_file = generate_proto(obj, package_name)
    proto_file_name = module_base + ".proto"

    with open(proto_file_name, "w", encoding="utf-8") as f:
        f.write(proto_file)

    gen_pb = generate_pb_code(proto_file_name, ".", ".")
    if gen_pb is None:
        raise Exception("Generating pb code")

    gen_connecpy = generate_connecpy_code(proto_file_name, ".")
    if gen_connecpy is None:
        raise Exception("Generating Connecpy code")
    return gen_connecpy, gen_pb
|
|
1058
|
+
|
|
1059
|
+
|
|
1060
|
+
###############################################################################
|
|
1061
|
+
# 4. Server Implementations
|
|
1062
|
+
###############################################################################
|
|
1063
|
+
|
|
1064
|
+
|
|
1065
|
+
class Server:
    """A simple gRPC server that uses ThreadPoolExecutor for concurrency."""

    def __init__(self, max_workers: int = 8, *interceptors) -> None:
        """
        Create the underlying ``grpc.server``.

        Args:
            max_workers: Number of threads in the executor that runs handlers.
            *interceptors: Server interceptors forwarded to ``grpc.server``.
        """
        self._server = grpc.server(
            futures.ThreadPoolExecutor(max_workers), interceptors=interceptors
        )
        # Fully-qualified names of mounted services, advertised via reflection.
        self._service_names = []
        self._package_name = ""
        self._port = 50051

    def set_package_name(self, package_name: str):
        """Set the package name for .proto generation."""
        self._package_name = package_name

    def set_port(self, port: int):
        """Set the port number for the gRPC server."""
        self._port = port

    def mount(self, obj: object, package_name: str = ""):
        """Generate and compile proto files, then mount the service implementation."""
        pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
        self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)

    def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
        """
        Connect the compiled gRPC modules with the service implementation.

        The generated ``add_<ServiceName>Servicer_to_server`` function is looked
        up by the object's class name, and the service's full proto name is
        recorded so it can be advertised through reflection.
        """
        concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
        service_name = obj.__class__.__name__
        service_impl = concreteServiceClass()
        getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
            service_impl, self._server
        )
        full_service_name = pb2_module.DESCRIPTOR.services_by_name[
            service_name
        ].full_name
        self._service_names.append(full_service_name)

    def run(self, *objs):
        """
        Mount multiple services and run the gRPC server with reflection and health check.
        Press Ctrl+C or send SIGTERM to stop.
        """
        for obj in objs:
            self.mount(obj, self._package_name)

        SERVICE_NAMES = (
            health_pb2.DESCRIPTOR.services_by_name["Health"].full_name,
            reflection.SERVICE_NAME,
            *self._service_names,
        )
        health_servicer = HealthServicer()
        health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
        reflection.enable_server_reflection(SERVICE_NAMES, self._server)

        self._server.add_insecure_port(f"[::]:{self._port}")
        self._server.start()

        # Parameter named signum so it does not shadow the signal module.
        def handle_signal(signum, frame):
            print("Received shutdown signal...")
            # grpc.Server.stop() returns immediately with a threading.Event that
            # is set once shutdown completes; wait on it so in-flight RPCs get
            # their grace period before the process exits.
            self._server.stop(grace=10).wait()
            print("gRPC server shutdown.")
            sys.exit(0)

        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)

        print("gRPC server is running...")
        while True:
            time.sleep(86400)
|
|
1134
|
+
|
|
1135
|
+
|
|
1136
|
+
class AsyncIOServer:
    """An async gRPC server using asyncio."""

    def __init__(self, *interceptors) -> None:
        """
        Create the underlying ``grpc.aio.server``.

        Args:
            *interceptors: Server interceptors forwarded to ``grpc.aio.server``.
        """
        self._server = grpc.aio.server(interceptors=interceptors)
        # Fully-qualified names of mounted services, advertised via reflection.
        self._service_names = []
        self._package_name = ""
        self._port = 50051

    def set_package_name(self, package_name: str):
        """Set the package name for .proto generation."""
        self._package_name = package_name

    def set_port(self, port: int):
        """Set the port number for the async gRPC server."""
        self._port = port

    def mount(self, obj: object, package_name: str = ""):
        """Generate and compile proto files, then mount the service implementation (async)."""
        pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
        self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)

    def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
        """
        Connect the compiled gRPC modules with the async service implementation.

        The generated ``add_<ServiceName>Servicer_to_server`` function is looked
        up by the object's class name, and the service's full proto name is
        recorded so it can be advertised through reflection.
        """
        concreteServiceClass = connect_obj_with_stub_async(
            pb2_grpc_module, pb2_module, obj
        )
        service_name = obj.__class__.__name__
        service_impl = concreteServiceClass()
        getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
            service_impl, self._server
        )
        full_service_name = pb2_module.DESCRIPTOR.services_by_name[
            service_name
        ].full_name
        self._service_names.append(full_service_name)

    async def run(self, *objs):
        """
        Mount multiple async services and run the gRPC server with reflection and health check.
        Press Ctrl+C or send SIGTERM to stop.
        """
        for obj in objs:
            self.mount(obj, self._package_name)

        SERVICE_NAMES = (
            health_pb2.DESCRIPTOR.services_by_name["Health"].full_name,
            reflection.SERVICE_NAME,
            *self._service_names,
        )
        health_servicer = HealthServicer()
        health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
        reflection.enable_server_reflection(SERVICE_NAMES, self._server)

        self._server.add_insecure_port(f"[::]:{self._port}")
        await self._server.start()

        shutdown_event = asyncio.Event()
        loop = asyncio.get_running_loop()

        def shutdown(signum, frame):
            print("Received shutdown signal...")
            # Calling shutdown_event.set() directly from a signal.signal handler
            # only schedules the waiter with call_soon(), which does not wake an
            # idle loop. call_soon_threadsafe() queues the set and wakes the loop.
            loop.call_soon_threadsafe(shutdown_event.set)

        for s in [signal.SIGTERM, signal.SIGINT]:
            signal.signal(s, shutdown)

        print("gRPC server is running...")
        await shutdown_event.wait()
        await self._server.stop(10)
        print("gRPC server shutdown.")
|
|
1206
|
+
|
|
1207
|
+
|
|
1208
|
+
class WSGIApp:
    """
    A WSGI-compatible application that can serve gRPC via sonora's grpcWSGI.
    Useful for embedding gRPC within an existing WSGI stack.
    """

    def __init__(self, app):
        """
        Args:
            app: The WSGI application wrapped by ``grpcWSGI``. Presumably sonora
                forwards non-gRPC-Web requests to it; confirm against sonora docs.
        """
        self._app = grpcWSGI(app)
        # Fully-qualified names of mounted services.
        self._service_names = []
        self._package_name = ""

    def set_package_name(self, package_name: str):
        """Set the package name used for .proto generation by ``mount_objs``."""
        self._package_name = package_name

    def mount(self, obj: object, package_name: str = ""):
        """Generate and compile proto files, then mount the service implementation."""
        pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
        self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)

    def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
        """Connect the compiled gRPC modules with the service implementation."""
        concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
        service_name = obj.__class__.__name__
        service_impl = concreteServiceClass()
        getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
            service_impl, self._app
        )
        full_service_name = pb2_module.DESCRIPTOR.services_by_name[
            service_name
        ].full_name
        self._service_names.append(full_service_name)

    def mount_objs(self, *objs):
        """Mount multiple service objects into this WSGI app."""
        for obj in objs:
            self.mount(obj, self._package_name)

    def __call__(self, environ, start_response):
        """WSGI entry point."""
        return self._app(environ, start_response)
|
|
1245
|
+
|
|
1246
|
+
|
|
1247
|
+
class ASGIApp:
    """
    An ASGI-compatible application that can serve gRPC via sonora's grpcASGI.
    Useful for embedding gRPC within an existing ASGI stack.
    """

    def __init__(self, app):
        """
        Args:
            app: The ASGI application wrapped by ``grpcASGI``. Presumably sonora
                forwards non-gRPC-Web requests to it; confirm against sonora docs.
        """
        self._app = grpcASGI(app)
        # Fully-qualified names of mounted services.
        self._service_names = []
        self._package_name = ""

    def set_package_name(self, package_name: str):
        """Set the package name used for .proto generation by ``mount_objs``."""
        self._package_name = package_name

    def mount(self, obj: object, package_name: str = ""):
        """Generate and compile proto files, then mount the async service implementation."""
        pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
        self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)

    def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
        """Connect the compiled gRPC modules with the async service implementation."""
        concreteServiceClass = connect_obj_with_stub_async(
            pb2_grpc_module, pb2_module, obj
        )
        service_name = obj.__class__.__name__
        service_impl = concreteServiceClass()
        getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
            service_impl, self._app
        )
        full_service_name = pb2_module.DESCRIPTOR.services_by_name[
            service_name
        ].full_name
        self._service_names.append(full_service_name)

    def mount_objs(self, *objs):
        """Mount multiple service objects into this ASGI app."""
        for obj in objs:
            self.mount(obj, self._package_name)

    async def __call__(self, scope, receive, send):
        """ASGI entry point."""
        await self._app(scope, receive, send)
|
|
1286
|
+
|
|
1287
|
+
|
|
1288
|
+
def get_connecpy_server_class(connecpy_module, service_name):
    """Look up the generated ``<service_name>Server`` class in a Connecpy module."""
    class_name = service_name + "Server"
    return getattr(connecpy_module, class_name)
|
|
1290
|
+
|
|
1291
|
+
|
|
1292
|
+
class ConnecpyASGIApp:
    """
    An ASGI-compatible application that can serve Connect-RPC via Connecpy's ConnecpyASGIApp.
    """

    def __init__(self):
        self._app = ConnecpyASGI()
        # Fully-qualified names of mounted services.
        self._service_names = []
        self._package_name = ""

    def set_package_name(self, package_name: str):
        """Set the package name used for .proto generation by ``mount_objs``."""
        self._package_name = package_name

    def mount(self, obj: object, package_name: str = ""):
        """Generate and compile proto files, then mount the async service implementation."""
        connecpy_module, pb2_module = generate_and_compile_proto_using_connecpy(
            obj, package_name
        )
        self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)

    def mount_using_pb2_modules(self, connecpy_module, pb2_module, obj: object):
        """
        Connect the compiled connecpy and pb2 modules with the async service implementation.

        The generated ``<ServiceName>Server`` class is looked up by the object's
        class name and registered on the underlying Connecpy ASGI app.
        """
        concreteServiceClass = connect_obj_with_stub_async_connecpy(
            connecpy_module, pb2_module, obj
        )
        service_name = obj.__class__.__name__
        service_impl = concreteServiceClass()
        connecpy_server = get_connecpy_server_class(connecpy_module, service_name)
        self._app.add_service(connecpy_server(service=service_impl))
        full_service_name = pb2_module.DESCRIPTOR.services_by_name[
            service_name
        ].full_name
        self._service_names.append(full_service_name)

    def mount_objs(self, *objs):
        """Mount multiple service objects into this ASGI app."""
        for obj in objs:
            self.mount(obj, self._package_name)

    async def __call__(self, scope, receive, send):
        """ASGI entry point."""
        await self._app(scope, receive, send)
|
|
1331
|
+
|
|
1332
|
+
|
|
1333
|
+
if __name__ == "__main__":
    # If executed as a script, generate the .proto files for a given class.
    # Usage: python core.py some_module.py SomeServiceClass
    if len(sys.argv) < 3:
        print(
            "Usage: python core.py some_module.py SomeServiceClass", file=sys.stderr
        )
        sys.exit(1)
    py_file_name = sys.argv[1]
    class_name = sys.argv[2]
    # When run as a script, sys.path[0] is this file's directory rather than the
    # target module's; make the target module's directory importable.
    module_dir = os.path.dirname(os.path.abspath(py_file_name))
    if module_dir not in sys.path:
        sys.path.append(module_dir)
    module_name = os.path.splitext(os.path.basename(py_file_name))[0]
    module = importlib.import_module(module_name)
    klass = getattr(module, class_name)
    generate_and_compile_proto(klass())
|