pydantic_rpc-0.2.0-py3-none-any.whl

This diff shows the content of package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those registries.
pydantic_rpc/core.py ADDED
@@ -0,0 +1,1197 @@
1
+ import asyncio
2
+ import enum
3
+ import importlib.util
4
+ import inspect
5
+ import os
6
+ import signal
7
+ import sys
8
+ import time
9
+ import types
10
+ import datetime
11
+ from concurrent import futures
12
+ from posixpath import basename
13
+ from typing import (
14
+ Callable,
15
+ Type,
16
+ get_args,
17
+ get_origin,
18
+ Union,
19
+ TypeAlias,
20
+ )
21
+
22
+ import grpc
23
+ from grpc_health.v1 import health_pb2, health_pb2_grpc
24
+ from grpc_health.v1.health import HealthServicer
25
+ from grpc_reflection.v1alpha import reflection
26
+ from grpc_tools import protoc
27
+ from pydantic import BaseModel
28
+ from sonora.wsgi import grpcWSGI
29
+ from sonora.asgi import grpcASGI
30
+ from connecpy.asgi import ConnecpyASGIApp as ConnecpyASGI
31
+
32
+ # Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
33
+ from google.protobuf import timestamp_pb2, duration_pb2
34
+
35
+ ###############################################################################
36
+ # 1. Message definitions & converter extensions
37
+ # (datetime.datetime <-> google.protobuf.Timestamp)
38
+ # (datetime.timedelta <-> google.protobuf.Duration)
39
+ ###############################################################################
40
+
41
+
42
+ Message: TypeAlias = BaseModel
43
+
44
+
45
+ def primitiveProtoValueToPythonValue(value):
46
+ # Returns the value as-is (primitive type).
47
+ return value
48
+
49
+
50
+ def timestamp_to_python(ts: timestamp_pb2.Timestamp) -> datetime.datetime: # type: ignore
51
+ """Convert a protobuf Timestamp to a Python datetime object."""
52
+ return ts.ToDatetime()
53
+
54
+
55
+ def python_to_timestamp(dt: datetime.datetime) -> timestamp_pb2.Timestamp: # type: ignore
56
+ """Convert a Python datetime object to a protobuf Timestamp."""
57
+ ts = timestamp_pb2.Timestamp() # type: ignore
58
+ ts.FromDatetime(dt)
59
+ return ts
60
+
61
+
62
+ def duration_to_python(d: duration_pb2.Duration) -> datetime.timedelta: # type: ignore
63
+ """Convert a protobuf Duration to a Python timedelta object."""
64
+ return d.ToTimedelta()
65
+
66
+
67
+ def python_to_duration(td: datetime.timedelta) -> duration_pb2.Duration: # type: ignore
68
+ """Convert a Python timedelta object to a protobuf Duration."""
69
+ d = duration_pb2.Duration() # type: ignore
70
+ d.FromTimedelta(td)
71
+ return d
72
+
73
+
74
+ def generate_converter(annotation: Type) -> Callable:
75
+ """
76
+ Returns a converter function to convert protobuf types to Python types.
77
+ This is used primarily when handling incoming requests.
78
+ """
79
+ # For primitive types
80
+ if annotation in (int, str, bool, bytes, float):
81
+ return primitiveProtoValueToPythonValue
82
+
83
+ # For enum types
84
+ if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
85
+
86
+ def enum_converter(value):
87
+ return annotation(value)
88
+
89
+ return enum_converter
90
+
91
+ # For datetime
92
+ if annotation == datetime.datetime:
93
+
94
+ def ts_converter(value: timestamp_pb2.Timestamp): # type: ignore
95
+ return value.ToDatetime()
96
+
97
+ return ts_converter
98
+
99
+ # For timedelta
100
+ if annotation == datetime.timedelta:
101
+
102
+ def dur_converter(value: duration_pb2.Duration): # type: ignore
103
+ return value.ToTimedelta()
104
+
105
+ return dur_converter
106
+
107
+ # For list types
108
+ if get_origin(annotation) is list:
109
+ item_converter = generate_converter(get_args(annotation)[0])
110
+
111
+ def list_converter(value):
112
+ return [item_converter(v) for v in value]
113
+
114
+ return list_converter
115
+
116
+ # For dict types
117
+ if get_origin(annotation) is dict:
118
+ key_converter = generate_converter(get_args(annotation)[0])
119
+ value_converter = generate_converter(get_args(annotation)[1])
120
+
121
+ def dict_converter(value):
122
+ return {key_converter(k): value_converter(v) for k, v in value.items()}
123
+
124
+ return dict_converter
125
+
126
+ # For Message classes
127
+ if inspect.isclass(annotation) and issubclass(annotation, Message):
128
+ return generate_message_converter(annotation)
129
+
130
+ # For union types or other unsupported cases, just return the value as-is.
131
+ return primitiveProtoValueToPythonValue
132
+
133
+
134
+ def generate_message_converter(arg_type: Type[Message]) -> Callable:
135
+ """Return a converter function for protobuf -> Python Message."""
136
+ if arg_type is None or not issubclass(arg_type, Message):
137
+ raise TypeError("Request arg must be subclass of Message")
138
+
139
+ fields = arg_type.model_fields
140
+ converters = {
141
+ field: generate_converter(field_type.annotation) # type: ignore
142
+ for field, field_type in fields.items()
143
+ }
144
+
145
+ def converter(request):
146
+ rdict = {}
147
+ for field in fields.keys():
148
+ rdict[field] = converters[field](getattr(request, field))
149
+ return arg_type(**rdict)
150
+
151
+ return converter
152
+
153
+
154
+ def python_value_to_proto_value(field_type: Type, value):
155
+ """
156
+ Converts Python values to protobuf values.
157
+ Used primarily when constructing a response object.
158
+ """
159
+ # datetime.datetime -> Timestamp
160
+ if field_type == datetime.datetime:
161
+ return python_to_timestamp(value)
162
+
163
+ # datetime.timedelta -> Duration
164
+ if field_type == datetime.timedelta:
165
+ return python_to_duration(value)
166
+
167
+ # Default behavior: return the value as-is.
168
+ return value
169
+
170
+
171
+ ###############################################################################
172
+ # 2. Stub implementation
173
+ ###############################################################################
174
+
175
+
176
+ def connect_obj_with_stub(pb2_grpc_module, pb2_module, service_obj: object) -> type:
177
+ """
178
+ Connect a Python service object to a gRPC stub, generating server methods.
179
+ """
180
+ service_class = service_obj.__class__
181
+ stub_class_name = service_class.__name__ + "Servicer"
182
+ stub_class = getattr(pb2_grpc_module, stub_class_name)
183
+
184
+ class ConcreteServiceClass(stub_class):
185
+ pass
186
+
187
+ def implement_stub_method(method):
188
+ # Analyze method signature.
189
+ sig = inspect.signature(method)
190
+ arg_type = get_request_arg_type(sig)
191
+ # Convert request from protobuf to Python.
192
+ converter = generate_message_converter(arg_type)
193
+
194
+ response_type = sig.return_annotation
195
+ size_of_parameters = len(sig.parameters)
196
+
197
+ match size_of_parameters:
198
+ case 1:
199
+
200
+ def stub_method1(self, request, context, method=method):
201
+ # Convert request to Python object
202
+ arg = converter(request)
203
+ # Invoke the actual method
204
+ resp_obj = method(arg)
205
+ # Convert the returned Python Message to a protobuf message
206
+ return convert_python_message_to_proto(
207
+ resp_obj, response_type, pb2_module
208
+ )
209
+
210
+ return stub_method1
211
+
212
+ case 2:
213
+
214
+ def stub_method2(self, request, context, method=method):
215
+ arg = converter(request)
216
+ resp_obj = method(arg, context)
217
+ return convert_python_message_to_proto(
218
+ resp_obj, response_type, pb2_module
219
+ )
220
+
221
+ return stub_method2
222
+
223
+ case _:
224
+ raise Exception("Method must have exactly one or two parameters")
225
+
226
+ for method_name, method in get_rpc_methods(service_obj):
227
+ if method.__name__.startswith("_"):
228
+ continue
229
+
230
+ a_method = implement_stub_method(method)
231
+ setattr(ConcreteServiceClass, method_name, a_method)
232
+
233
+ return ConcreteServiceClass
234
+
235
+
236
+ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> type:
237
+ """
238
+ Connect a Python service object to a gRPC stub for async methods.
239
+ """
240
+ service_class = obj.__class__
241
+ stub_class_name = service_class.__name__ + "Servicer"
242
+ stub_class = getattr(pb2_grpc_module, stub_class_name)
243
+
244
+ class ConcreteServiceClass(stub_class):
245
+ pass
246
+
247
+ def implement_stub_method(method):
248
+ sig = inspect.signature(method)
249
+ arg_type = get_request_arg_type(sig)
250
+ converter = generate_message_converter(arg_type)
251
+ response_type = sig.return_annotation
252
+ size_of_parameters = len(sig.parameters)
253
+
254
+ match size_of_parameters:
255
+ case 1:
256
+
257
+ async def stub_method1(self, request, context, method=method):
258
+ arg = converter(request)
259
+ resp_obj = await method(arg)
260
+ return convert_python_message_to_proto(
261
+ resp_obj, response_type, pb2_module
262
+ )
263
+
264
+ return stub_method1
265
+
266
+ case 2:
267
+
268
+ async def stub_method2(self, request, context, method=method):
269
+ arg = converter(request)
270
+ resp_obj = await method(arg, context)
271
+ return convert_python_message_to_proto(
272
+ resp_obj, response_type, pb2_module
273
+ )
274
+
275
+ return stub_method2
276
+
277
+ case _:
278
+ raise Exception("Method must have exactly one or two parameters")
279
+
280
+ for method_name, method in get_rpc_methods(obj):
281
+ if method.__name__.startswith("_"):
282
+ continue
283
+
284
+ a_method = implement_stub_method(method)
285
+ setattr(ConcreteServiceClass, method_name, a_method)
286
+
287
+ return ConcreteServiceClass
288
+
289
+
290
+ def connect_obj_with_stub_async_connecpy(
291
+ connecpy_module, pb2_module, obj: object
292
+ ) -> type:
293
+ """
294
+ Connect a Python service object to a Connecpy stub for async methods.
295
+ """
296
+ service_class = obj.__class__
297
+ stub_class_name = service_class.__name__
298
+ stub_class = getattr(connecpy_module, stub_class_name)
299
+
300
+ class ConcreteServiceClass(stub_class):
301
+ pass
302
+
303
+ def implement_stub_method(method):
304
+ sig = inspect.signature(method)
306
+ arg_type = get_request_arg_type(sig)
307
+ converter = generate_message_converter(arg_type)
308
+ response_type = sig.return_annotation
309
+ size_of_parameters = len(sig.parameters)
310
+
311
+ match size_of_parameters:
312
+ case 1:
313
+
314
+ async def stub_method1(self, request, context, method=method):
315
+ arg = converter(request)
316
+ resp_obj = await method(arg)
317
+ return convert_python_message_to_proto(
318
+ resp_obj, response_type, pb2_module
319
+ )
320
+
321
+ return stub_method1
322
+
323
+ case 2:
324
+
325
+ async def stub_method2(self, request, context, method=method):
326
+ arg = converter(request)
327
+ resp_obj = await method(arg, context)
328
+ return convert_python_message_to_proto(
329
+ resp_obj, response_type, pb2_module
330
+ )
331
+
332
+ return stub_method2
333
+
334
+ case _:
335
+ raise Exception("Method must have exactly one or two parameters")
336
+
337
+ for method_name, method in get_rpc_methods(obj):
338
+ if method.__name__.startswith("_"):
339
+ continue
340
+ if not asyncio.iscoroutinefunction(method):
341
+ raise Exception("Method must be async", method_name)
342
+ a_method = implement_stub_method(method)
343
+ setattr(ConcreteServiceClass, method_name, a_method)
344
+
345
+ return ConcreteServiceClass
346
+
347
+
348
+ def convert_python_message_to_proto(
349
+ py_msg: Message, msg_type: Type, pb2_module
350
+ ) -> object:
351
+ """
352
+ Convert a Python Pydantic Message instance to a protobuf message instance.
353
+ Used for constructing a response.
354
+ """
355
+ # Before calling something like pb2_module.AResponseMessage(...),
356
+ # convert each field from Python to proto.
357
+ field_dict = {}
358
+ for name, field_info in msg_type.model_fields.items():
359
+ value = getattr(py_msg, name)
360
+ if value is None:
361
+ field_dict[name] = None
362
+ continue
363
+
364
+ field_type = field_info.annotation
365
+ field_dict[name] = python_value_to_proto(field_type, value, pb2_module)
366
+
367
+ # Retrieve the appropriate protobuf class dynamically
368
+ proto_class = getattr(pb2_module, msg_type.__name__)
369
+ return proto_class(**field_dict)
370
+
371
+
372
+ def python_value_to_proto(field_type: Type, value, pb2_module):
373
+ """
374
+ Perform Python->protobuf type conversion for each field value.
375
+ """
376
+ import inspect
377
+ import datetime
378
+
379
+ # If datetime
380
+ if field_type == datetime.datetime:
381
+ return python_to_timestamp(value)
382
+
383
+ # If timedelta
384
+ if field_type == datetime.timedelta:
385
+ return python_to_duration(value)
386
+
387
+ # If enum
388
+ if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
389
+ return value.value # proto3 enum is an int
390
+
391
+ # If list
392
+ if get_origin(field_type) is list:
393
+ inner_type = get_args(field_type)[0] # type: ignore
394
+ return [python_value_to_proto(inner_type, v, pb2_module) for v in value]
395
+
396
+ # If dict
397
+ if get_origin(field_type) is dict:
398
+ key_type, val_type = get_args(field_type) # type: ignore
399
+ return {
400
+ python_value_to_proto(key_type, k, pb2_module): python_value_to_proto(
401
+ val_type, v, pb2_module
402
+ )
403
+ for k, v in value.items()
404
+ }
405
+
406
+ # If union -> oneof
407
+ if is_union_type(field_type):
408
+ # Flatten union and check which type matches. If matched, return converted value.
409
+ for sub_type in flatten_union(field_type):
410
+ if sub_type == datetime.datetime and isinstance(value, datetime.datetime):
411
+ return python_to_timestamp(value)
412
+ if sub_type == datetime.timedelta and isinstance(value, datetime.timedelta):
413
+ return python_to_duration(value)
414
+ if (
415
+ inspect.isclass(sub_type)
416
+ and issubclass(sub_type, enum.Enum)
417
+ and isinstance(value, enum.Enum)
418
+ ):
419
+ return value.value
420
+ if sub_type in (int, float, str, bool, bytes) and isinstance(
421
+ value, sub_type
422
+ ):
423
+ return value
424
+ if (
425
+ inspect.isclass(sub_type)
426
+ and issubclass(sub_type, Message)
427
+ and isinstance(value, Message)
428
+ ):
429
+ return convert_python_message_to_proto(value, sub_type, pb2_module)
430
+ return None
431
+
432
+ # If Message
433
+ if inspect.isclass(field_type) and issubclass(field_type, Message):
434
+ return convert_python_message_to_proto(value, field_type, pb2_module)
435
+
436
+ # If primitive
437
+ return value
438
+
439
+
440
+ ###############################################################################
441
+ # 3. Generating proto files (datetime->Timestamp, timedelta->Duration)
442
+ ###############################################################################
443
+
444
+
445
+ def is_enum_type(python_type: Type) -> bool:
446
+ """Return True if the given Python type is an enum."""
447
+ return inspect.isclass(python_type) and issubclass(python_type, enum.Enum)
448
+
449
+
450
+ def is_union_type(python_type: Type) -> bool:
451
+ """
452
+ Check if a given Python type is a Union type (including Python 3.10's UnionType).
453
+ """
454
+ if get_origin(python_type) is Union:
455
+ return True
456
+ if sys.version_info >= (3, 10):
457
+ import types
458
+
459
+ if type(python_type) is types.UnionType:
460
+ return True
461
+ return False
462
+
463
+
464
+ def flatten_union(field_type: Type) -> list[Type]:
465
+ """Recursively flatten nested Unions into a single list of types."""
466
+ if is_union_type(field_type):
467
+ results = []
468
+ for arg in get_args(field_type):
469
+ results.extend(flatten_union(arg))
470
+ return results
471
+ else:
472
+ return [field_type]
473
+
474
+
475
+ def protobuf_type_mapping(python_type: Type) -> str | type | None:
476
+ """
477
+ Map a Python type to a protobuf type name/class.
478
+ Includes support for Timestamp and Duration.
479
+ """
480
+ import datetime
481
+
482
+ mapping = {
483
+ int: "int32",
484
+ str: "string",
485
+ bool: "bool",
486
+ bytes: "bytes",
487
+ float: "float",
488
+ }
489
+
490
+ if python_type == datetime.datetime:
491
+ return "google.protobuf.Timestamp"
492
+
493
+ if python_type == datetime.timedelta:
494
+ return "google.protobuf.Duration"
495
+
496
+ if is_enum_type(python_type):
497
+ return python_type # Will be defined as enum later
498
+
499
+ if is_union_type(python_type):
500
+ return None # Handled separately as oneof
501
+
502
+ if hasattr(python_type, "__origin__"):
503
+ if python_type.__origin__ is list:
504
+ inner_type = python_type.__args__[0]
505
+ inner_proto_type = protobuf_type_mapping(inner_type)
506
+ if inner_proto_type:
507
+ return f"repeated {inner_proto_type}"
508
+ elif python_type.__origin__ is dict:
509
+ key_type = python_type.__args__[0]
510
+ value_type = python_type.__args__[1]
511
+ key_proto_type = protobuf_type_mapping(key_type)
512
+ value_proto_type = protobuf_type_mapping(value_type)
513
+ if key_proto_type and value_proto_type:
514
+ return f"map<{key_proto_type}, {value_proto_type}>"
515
+
516
+ if inspect.isclass(python_type) and issubclass(python_type, Message):
517
+ return python_type
518
+
519
+ return mapping.get(python_type) # type: ignore
520
+
521
+
522
+ def comment_out(docstr: str) -> tuple[str, ...]:
523
+ """Convert docstrings into commented-out lines in a .proto file."""
524
+ if docstr is None:
525
+ return tuple()
526
+ return tuple("//" if line == "" else f"// {line}" for line in docstr.split("\n"))
527
+
528
+
529
+ def indent_lines(lines, indentation=" "):
530
+ """Indent multiple lines with a given indentation string."""
531
+ return "\n".join(indentation + line for line in lines)
532
+
533
+
534
+ def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
535
+ """Generate a protobuf enum definition from a Python enum."""
536
+ enum_name = enum_type.__name__
537
+ members = []
538
+ for _, member in enum_type.__members__.items():
539
+ members.append(f" {member.name} = {member.value};")
540
+ enum_def = f"enum {enum_name} {{\n"
541
+ enum_def += "\n".join(members)
542
+ enum_def += "\n}"
543
+ return enum_def
544
+
545
+
546
+ def generate_oneof_definition(
547
+ field_name: str, union_args: list[Type], start_index: int
548
+ ) -> tuple[list[str], int]:
549
+ """
550
+ Generate a oneof block in protobuf for a union field.
551
+ Returns a tuple of the definition lines and the updated field index.
552
+ """
553
+ lines = []
554
+ lines.append(f"oneof {field_name} {{")
555
+ current = start_index
556
+ for arg_type in union_args:
557
+ proto_typename = protobuf_type_mapping(arg_type)
558
+ if proto_typename is None:
559
+ raise Exception(f"Nested Union not flattened properly: {arg_type}")
560
+
561
+ # If it's an enum or Message, use the type name.
562
+ if is_enum_type(arg_type):
563
+ proto_typename = arg_type.__name__
564
+ elif inspect.isclass(arg_type) and issubclass(arg_type, Message):
565
+ proto_typename = arg_type.__name__
566
+
567
+ field_alias = f"{field_name}_{proto_typename.replace('.', '_')}"
568
+ lines.append(f" {proto_typename} {field_alias} = {current};")
569
+ current += 1
570
+ lines.append("}")
571
+ return lines, current
572
+
573
+
574
+ def generate_message_definition(
575
+ message_type: Type[Message],
576
+ done_enums: set,
577
+ done_messages: set,
578
+ ) -> tuple[str, list[type]]:
579
+ """
580
+ Generate a protobuf message definition for a Pydantic-based Message class.
581
+ Also returns any referenced types (enums, messages) that need to be defined.
582
+ """
583
+ fields = []
584
+ refs = []
585
+ pydantic_fields = message_type.model_fields
586
+ index = 1
587
+
588
+ for field_name, field_info in pydantic_fields.items():
589
+ field_type = field_info.annotation
590
+ if field_type is None:
591
+ raise Exception(f"Field {field_name} has no type annotation.")
592
+
593
+ if is_union_type(field_type):
594
+ union_args = flatten_union(field_type)
595
+ oneof_lines, new_index = generate_oneof_definition(
596
+ field_name, union_args, index
597
+ )
598
+ fields.extend(oneof_lines)
599
+ index = new_index
600
+
601
+ for utype in union_args:
602
+ if is_enum_type(utype) and utype not in done_enums:
603
+ refs.append(utype)
604
+ elif (
605
+ inspect.isclass(utype)
606
+ and issubclass(utype, Message)
607
+ and utype not in done_messages
608
+ ):
609
+ refs.append(utype)
610
+
611
+ else:
612
+ proto_typename = protobuf_type_mapping(field_type)
613
+ if proto_typename is None:
614
+ raise Exception(f"Type {field_type} is not supported.")
615
+
616
+ if is_enum_type(field_type):
617
+ proto_typename = field_type.__name__
618
+ if field_type not in done_enums:
619
+ refs.append(field_type)
620
+ elif inspect.isclass(field_type) and issubclass(field_type, Message):
621
+ proto_typename = field_type.__name__
622
+ if field_type not in done_messages:
623
+ refs.append(field_type)
624
+
625
+ fields.append(f"{proto_typename} {field_name} = {index};")
626
+ index += 1
627
+
628
+ msg_def = f"message {message_type.__name__} {{\n{indent_lines(fields)}\n}}"
629
+ return msg_def, refs
630
+
631
+
632
+ def generate_proto(obj: object, package_name: str = "") -> str:
633
+ """
634
+ Generate a .proto definition from a service class.
635
+ Automatically handles Timestamp and Duration usage.
636
+ """
637
+ import datetime
638
+
639
+ service_class = obj.__class__
640
+ service_name = service_class.__name__
641
+ service_docstr = inspect.getdoc(service_class)
642
+ service_comment = "\n".join(comment_out(service_docstr)) if service_docstr else ""
643
+
644
+ rpc_definitions = []
645
+ all_type_definitions = []
646
+ done_messages = set()
647
+ done_enums = set()
648
+
649
+ uses_timestamp = False
650
+ uses_duration = False
651
+
652
+ def check_and_set_well_known_types(py_type: Type):
653
+ nonlocal uses_timestamp, uses_duration
654
+ if py_type == datetime.datetime:
655
+ uses_timestamp = True
656
+ if py_type == datetime.timedelta:
657
+ uses_duration = True
658
+
659
+ for method_name, method in get_rpc_methods(obj):
660
+ if method.__name__.startswith("_"):
661
+ continue
662
+
663
+ method_sig = inspect.signature(method)
664
+ request_type = get_request_arg_type(method_sig)
665
+ response_type = method_sig.return_annotation
666
+
667
+ # Recursively generate message definitions
668
+ message_types = [request_type, response_type]
669
+ while message_types:
670
+ mt = message_types.pop()
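For reference, a minimal sketch of the Timestamp/Duration helpers above in use. The import path pydantic_rpc.core is an assumption here; the package may also re-export these names elsewhere.

    from datetime import datetime, timedelta, timezone
    from pydantic_rpc.core import (
        python_to_timestamp, timestamp_to_python,
        python_to_duration, duration_to_python,
    )

    # datetime <-> google.protobuf.Timestamp
    ts = python_to_timestamp(datetime(2024, 1, 1, tzinfo=timezone.utc))
    # Timestamp.ToDatetime() returns a naive UTC datetime.
    assert timestamp_to_python(ts) == datetime(2024, 1, 1)

    # timedelta <-> google.protobuf.Duration
    d = python_to_duration(timedelta(seconds=90))
    assert duration_to_python(d) == timedelta(seconds=90)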
671
+ if mt in done_messages:
672
+ continue
673
+ done_messages.add(mt)
674
+
675
+ for _, field_info in mt.model_fields.items():
676
+ t = field_info.annotation
677
+ if is_union_type(t):
678
+ for sub_t in flatten_union(t):
679
+ check_and_set_well_known_types(sub_t)
680
+ else:
681
+ check_and_set_well_known_types(t)
682
+
683
+ msg_def, refs = generate_message_definition(mt, done_enums, done_messages)
684
+ mt_doc = inspect.getdoc(mt)
685
+ if mt_doc:
686
+ for comment_line in comment_out(mt_doc):
687
+ all_type_definitions.append(comment_line)
688
+
689
+ all_type_definitions.append(msg_def)
690
+ all_type_definitions.append("")
691
+
692
+ for r in refs:
693
+ if is_enum_type(r) and r not in done_enums:
694
+ done_enums.add(r)
695
+ enum_def = generate_enum_definition(r)
696
+ all_type_definitions.append(enum_def)
697
+ all_type_definitions.append("")
698
+ elif issubclass(r, Message) and r not in done_messages:
699
+ message_types.append(r)
700
+
701
+ method_docstr = inspect.getdoc(method)
702
+ if method_docstr:
703
+ for comment_line in comment_out(method_docstr):
704
+ rpc_definitions.append(comment_line)
705
+ rpc_definitions.append(
706
+ f"rpc {method_name} ({request_type.__name__}) returns ({response_type.__name__});"
707
+ )
708
+
709
+ if not package_name:
710
+ if service_name.endswith("Service"):
711
+ package_name = service_name[: -len("Service")]
712
+ else:
713
+ package_name = service_name
714
+ package_name = package_name.lower() + ".v1"
715
+
716
+ imports = []
717
+ if uses_timestamp:
718
+ imports.append('import "google/protobuf/timestamp.proto";')
719
+ if uses_duration:
720
+ imports.append('import "google/protobuf/duration.proto";')
721
+
722
+ import_block = "\n".join(imports)
723
+ if import_block:
724
+ import_block += "\n"
725
+
726
+ proto_definition = f"""syntax = "proto3";
727
+
728
+ package {package_name};
729
+
730
+ {import_block}{service_comment}
731
+ service {service_name} {{
732
+ {indent_lines(rpc_definitions)}
733
+ }}
734
+
735
+ {indent_lines(all_type_definitions, "")}
736
+ """
737
+ return proto_definition
738
+
739
+
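As a rough illustration of the conventions in generate_proto (a "Service" suffix is stripped and lowercased into a "<name>.v1" package, and method names are rendered in PascalCase), a hedged sketch; the class names are illustrative and the import path pydantic_rpc.core is an assumption:

    from pydantic_rpc.core import Message, generate_proto

    class HelloRequest(Message):
        name: str

    class HelloReply(Message):
        message: str

    class GreeterService:
        def say_hello(self, request: HelloRequest) -> HelloReply:
            return HelloReply(message=f"Hello, {request.name}!")

    # Prints a proto3 definition containing `package greeter.v1;` and
    # `rpc SayHello (HelloRequest) returns (HelloReply);`.
    print(generate_proto(GreeterService()))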
740
+ def generate_grpc_code(proto_file, grpc_python_out) -> types.ModuleType | None:
741
+ """
742
+ Execute the protoc command to generate Python gRPC code from the .proto file.
743
+ Returns the generated pb2_grpc module on success, or None on failure.
744
+ """
745
+ command = f"-I. --grpc_python_out={grpc_python_out} {proto_file}"
746
+ exit_code = protoc.main(command.split())
747
+ if exit_code != 0:
748
+ return None
749
+
750
+ base = os.path.splitext(proto_file)[0]
751
+ generated_pb2_grpc_file = f"{base}_pb2_grpc.py"
752
+
753
+ if grpc_python_out not in sys.path:
754
+ sys.path.append(grpc_python_out)
755
+
756
+ spec = importlib.util.spec_from_file_location(
757
+ generated_pb2_grpc_file, os.path.join(grpc_python_out, generated_pb2_grpc_file)
758
+ )
759
+ if spec is None:
760
+ return None
761
+ pb2_grpc_module = importlib.util.module_from_spec(spec)
762
+ if spec.loader is None:
763
+ return None
764
+ spec.loader.exec_module(pb2_grpc_module)
765
+
766
+ return pb2_grpc_module
767
+
768
+
769
+ def generate_connecpy_code(
770
+ proto_file: str, connecpy_out: str
771
+ ) -> types.ModuleType | None:
772
+ """
773
+ Execute the protoc command to generate Python Connecpy code from the .proto file.
774
+ Returns the generated Connecpy module on success, or None on failure.
775
+ """
776
+ command = f"-I. --connecpy_out={connecpy_out} {proto_file}"
777
+ exit_code = protoc.main(command.split())
778
+ if exit_code != 0:
779
+ return None
780
+
781
+ base = os.path.splitext(proto_file)[0]
782
+ generated_connecpy_file = f"{base}_connecpy.py"
783
+
784
+ if connecpy_out not in sys.path:
785
+ sys.path.append(connecpy_out)
786
+
787
+ spec = importlib.util.spec_from_file_location(
788
+ generated_connecpy_file, os.path.join(connecpy_out, generated_connecpy_file)
789
+ )
790
+ if spec is None:
791
+ return None
792
+ connecpy_module = importlib.util.module_from_spec(spec)
793
+ if spec.loader is None:
794
+ return None
795
+ spec.loader.exec_module(connecpy_module)
796
+
797
+ return connecpy_module
798
+
799
+
800
+ def generate_pb_code(
801
+ proto_file: str, python_out: str, pyi_out: str
802
+ ) -> types.ModuleType | None:
803
+ """
804
+ Execute the protoc command to generate Python protobuf message code (and .pyi stubs) from the .proto file.
805
+ Returns the generated pb2 module on success, or None on failure.
806
+ """
807
+ command = f"-I. --python_out={python_out} --pyi_out={pyi_out} {proto_file}"
808
+ exit_code = protoc.main(command.split())
809
+ if exit_code != 0:
810
+ return None
811
+
812
+ base = os.path.splitext(proto_file)[0]
813
+ generated_pb2_file = f"{base}_pb2.py"
814
+
815
+ if python_out not in sys.path:
816
+ sys.path.append(python_out)
817
+
818
+ spec = importlib.util.spec_from_file_location(
819
+ generated_pb2_file, os.path.join(python_out, generated_pb2_file)
820
+ )
821
+ if spec is None:
822
+ return None
823
+ pb2_module = importlib.util.module_from_spec(spec)
824
+ if spec.loader is None:
825
+ return None
826
+ spec.loader.exec_module(pb2_module)
827
+
828
+ return pb2_module
829
+
830
+
831
+ def get_request_arg_type(sig):
832
+ """Return the type annotation of the first parameter (request) of a method."""
833
+ num_of_params = len(sig.parameters)
834
+ if not (num_of_params == 1 or num_of_params == 2):
835
+ raise Exception("Method must have exactly one or two parameters")
836
+ return tuple(sig.parameters.values())[0].annotation
837
+
838
+
839
+ def get_rpc_methods(obj: object) -> list[tuple[str, types.MethodType]]:
840
+ """
841
+ Retrieve the list of RPC methods from a service object.
842
+ The method name is converted to PascalCase for .proto compatibility.
843
+ """
844
+
845
+ def to_pascal_case(name: str) -> str:
846
+ return "".join(part.capitalize() for part in name.split("_"))
847
+
848
+ return [
849
+ (to_pascal_case(attr_name), getattr(obj, attr_name))
850
+ for attr_name in dir(obj)
851
+ if inspect.ismethod(getattr(obj, attr_name))
852
+ ]
853
+
854
+
855
+ def is_skip_generation() -> bool:
856
+ """Check if the proto file and code generation should be skipped."""
857
+ return os.getenv("PYDANTIC_RPC_SKIP_GENERATION", "false").lower() == "true"
858
+
859
+
860
+ def generate_and_compile_proto(obj: object, package_name: str = ""):
861
+ if is_skip_generation():
862
+ import importlib
863
+
864
+ pb2_module = importlib.import_module(f"{obj.__class__.__name__.lower()}_pb2")
865
+ pb2_grpc_module = importlib.import_module(
866
+ f"{obj.__class__.__name__.lower()}_pb2_grpc"
867
+ )
868
+ return pb2_grpc_module, pb2_module
869
+
870
+ klass = obj.__class__
871
+ proto_file = generate_proto(obj, package_name)
872
+ proto_file_name = klass.__name__.lower() + ".proto"
873
+
874
+ with open(proto_file_name, "w", encoding="utf-8") as f:
875
+ f.write(proto_file)
876
+
877
+ gen_pb = generate_pb_code(proto_file_name, ".", ".")
878
+ if gen_pb is None:
879
+ raise Exception("Failed to generate pb code")
880
+
881
+ gen_grpc = generate_grpc_code(proto_file_name, ".")
882
+ if gen_grpc is None:
883
+ raise Exception("Failed to generate grpc code")
884
+ return gen_grpc, gen_pb
885
+
886
+
887
+ def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = ""):
888
+ if is_skip_generation():
889
+ import importlib
890
+
891
+ pb2_module = importlib.import_module(f"{obj.__class__.__name__.lower()}_pb2")
892
+ connecpy_module = importlib.import_module(
893
+ f"{obj.__class__.__name__.lower()}_connecpy"
894
+ )
895
+ return connecpy_module, pb2_module
896
+
897
+ klass = obj.__class__
898
+ proto_file = generate_proto(obj, package_name)
899
+ proto_file_name = klass.__name__.lower() + ".proto"
900
+
901
+ with open(proto_file_name, "w", encoding="utf-8") as f:
902
+ f.write(proto_file)
903
+
904
+ gen_pb = generate_pb_code(proto_file_name, ".", ".")
905
+ if gen_pb is None:
906
+ raise Exception("Failed to generate pb code")
907
+
908
+ gen_connecpy = generate_connecpy_code(proto_file_name, ".")
909
+ if gen_connecpy is None:
910
+ raise Exception("Failed to generate Connecpy code")
911
+ return gen_connecpy, gen_pb
912
+
913
+
914
+ ###############################################################################
915
+ # 4. Server Implementations
916
+ ###############################################################################
917
+
918
+
919
+ class Server:
920
+ """A simple gRPC server that uses ThreadPoolExecutor for concurrency."""
921
+
922
+ def __init__(self, max_workers: int = 8, *interceptors) -> None:
923
+ self._server = grpc.server(
924
+ futures.ThreadPoolExecutor(max_workers), interceptors=interceptors
925
+ )
926
+ self._service_names = []
927
+ self._package_name = ""
928
+ self._port = 50051
929
+
930
+ def set_package_name(self, package_name: str):
931
+ """Set the package name for .proto generation."""
932
+ self._package_name = package_name
933
+
934
+ def set_port(self, port: int):
935
+ """Set the port number for the gRPC server."""
936
+ self._port = port
937
+
938
+ def mount(self, obj: object, package_name: str = ""):
939
+ """Generate and compile proto files, then mount the service implementation."""
940
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
941
+ self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
942
+
943
+ def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
944
+ """Connect the compiled gRPC modules with the service implementation."""
945
+ concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
946
+ service_name = obj.__class__.__name__
947
+ service_impl = concreteServiceClass()
948
+ getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
949
+ service_impl, self._server
950
+ )
951
+ full_service_name = pb2_module.DESCRIPTOR.services_by_name[
952
+ service_name
953
+ ].full_name
954
+ self._service_names.append(full_service_name)
955
+
956
+ def run(self, *objs):
957
+ """
958
+ Mount multiple services and run the gRPC server with reflection and health check.
959
+ Press Ctrl+C or send SIGTERM to stop.
960
+ """
961
+ for obj in objs:
962
+ self.mount(obj, self._package_name)
963
+
964
+ SERVICE_NAMES = (
965
+ health_pb2.DESCRIPTOR.services_by_name["Health"].full_name,
966
+ reflection.SERVICE_NAME,
967
+ *self._service_names,
968
+ )
969
+ health_servicer = HealthServicer()
970
+ health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
971
+ reflection.enable_server_reflection(SERVICE_NAMES, self._server)
972
+
973
+ self._server.add_insecure_port(f"[::]:{self._port}")
974
+ self._server.start()
975
+
976
+ def handle_signal(signal, frame):
977
+ print("Received shutdown signal...")
978
+ self._server.stop(grace=10)
979
+ print("gRPC server shutdown.")
980
+ sys.exit(0)
981
+
982
+ signal.signal(signal.SIGINT, handle_signal)
983
+ signal.signal(signal.SIGTERM, handle_signal)
984
+
985
+ print("gRPC server is running...")
986
+ while True:
987
+ time.sleep(86400)
988
+
989
+
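A minimal end-to-end sketch of the blocking Server above. The `from pydantic_rpc import ...` path assumes the package re-exports these names from its root; the service and message names are illustrative.

    from pydantic_rpc import Message, Server

    class EchoRequest(Message):
        text: str

    class EchoResponse(Message):
        text: str

    class EchoService:
        def echo(self, request: EchoRequest) -> EchoResponse:
            return EchoResponse(text=request.text)

    if __name__ == "__main__":
        server = Server()
        server.set_port(50051)
        # Writes and compiles echoservice.proto, then serves with reflection and health checks.
        server.run(EchoService())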
990
+ class AsyncIOServer:
991
+ """An async gRPC server using asyncio."""
992
+
993
+ def __init__(self, *interceptors) -> None:
994
+ self._server = grpc.aio.server(interceptors=interceptors)
995
+ self._service_names = []
996
+ self._package_name = ""
997
+ self._port = 50051
998
+
999
+ def set_package_name(self, package_name: str):
1000
+ """Set the package name for .proto generation."""
1001
+ self._package_name = package_name
1002
+
1003
+ def set_port(self, port: int):
1004
+ """Set the port number for the async gRPC server."""
1005
+ self._port = port
1006
+
1007
+ def mount(self, obj: object, package_name: str = ""):
1008
+ """Generate and compile proto files, then mount the service implementation (async)."""
1009
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
1010
+ self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1011
+
1012
+ def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
1013
+ """Connect the compiled gRPC modules with the async service implementation."""
1014
+ concreteServiceClass = connect_obj_with_stub_async(
1015
+ pb2_grpc_module, pb2_module, obj
1016
+ )
1017
+ service_name = obj.__class__.__name__
1018
+ service_impl = concreteServiceClass()
1019
+ getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
1020
+ service_impl, self._server
1021
+ )
1022
+ full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1023
+ service_name
1024
+ ].full_name
1025
+ self._service_names.append(full_service_name)
1026
+
1027
+ async def run(self, *objs):
1028
+ """
1029
+ Mount multiple async services and run the gRPC server with reflection and health check.
1030
+ Press Ctrl+C or send SIGTERM to stop.
1031
+ """
1032
+ for obj in objs:
1033
+ self.mount(obj, self._package_name)
1034
+
1035
+ SERVICE_NAMES = (
1036
+ health_pb2.DESCRIPTOR.services_by_name["Health"].full_name,
1037
+ reflection.SERVICE_NAME,
1038
+ *self._service_names,
1039
+ )
1040
+ health_servicer = HealthServicer()
1041
+ health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
1042
+ reflection.enable_server_reflection(SERVICE_NAMES, self._server)
1043
+
1044
+ self._server.add_insecure_port(f"[::]:{self._port}")
1045
+ await self._server.start()
1046
+
1047
+ shutdown_event = asyncio.Event()
1048
+
1049
+ def shutdown(signum, frame):
1050
+ print("Received shutdown signal...")
1051
+ shutdown_event.set()
1052
+
1053
+ for s in [signal.SIGTERM, signal.SIGINT]:
1054
+ signal.signal(s, shutdown)
1055
+
1056
+ print("gRPC server is running...")
1057
+ await shutdown_event.wait()
1058
+ await self._server.stop(10)
1059
+ print("gRPC server shutdown.")
1060
+
1061
+
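And an asyncio variant using AsyncIOServer (same import-path assumption; handler methods must be coroutines):

    import asyncio
    from pydantic_rpc import Message, AsyncIOServer

    class PingRequest(Message):
        payload: str

    class PingResponse(Message):
        payload: str

    class PingService:
        async def ping(self, request: PingRequest) -> PingResponse:
            return PingResponse(payload=request.payload)

    if __name__ == "__main__":
        server = AsyncIOServer()
        server.set_port(50052)
        asyncio.run(server.run(PingService()))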
1062
+ class WSGIApp:
1063
+ """
1064
+ A WSGI-compatible application that can serve gRPC via sonora's grpcWSGI.
1065
+ Useful for embedding gRPC within an existing WSGI stack.
1066
+ """
1067
+
1068
+ def __init__(self, app):
1069
+ self._app = grpcWSGI(app)
1070
+ self._service_names = []
1071
+ self._package_name = ""
1072
+
1073
+ def mount(self, obj: object, package_name: str = ""):
1074
+ """Generate and compile proto files, then mount the service implementation."""
1075
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
1076
+ self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1077
+
1078
+ def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
1079
+ """Connect the compiled gRPC modules with the service implementation."""
1080
+ concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
1081
+ service_name = obj.__class__.__name__
1082
+ service_impl = concreteServiceClass()
1083
+ getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
1084
+ service_impl, self._app
1085
+ )
1086
+ full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1087
+ service_name
1088
+ ].full_name
1089
+ self._service_names.append(full_service_name)
1090
+
1091
+ def mount_objs(self, *objs):
1092
+ """Mount multiple service objects into this WSGI app."""
1093
+ for obj in objs:
1094
+ self.mount(obj, self._package_name)
1095
+
1096
+ def __call__(self, environ, start_response):
1097
+ """WSGI entry point."""
1098
+ return self._app(environ, start_response)
1099
+
1100
+
1101
+ class ASGIApp:
1102
+ """
1103
+ An ASGI-compatible application that can serve gRPC via sonora's grpcASGI.
1104
+ Useful for embedding gRPC within an existing ASGI stack.
1105
+ """
1106
+
1107
+ def __init__(self, app):
1108
+ self._app = grpcASGI(app)
1109
+ self._service_names = []
1110
+ self._package_name = ""
1111
+
1112
+ def mount(self, obj: object, package_name: str = ""):
1113
+ """Generate and compile proto files, then mount the async service implementation."""
1114
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
1115
+ self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1116
+
1117
+ def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
1118
+ """Connect the compiled gRPC modules with the async service implementation."""
1119
+ concreteServiceClass = connect_obj_with_stub_async(
1120
+ pb2_grpc_module, pb2_module, obj
1121
+ )
1122
+ service_name = obj.__class__.__name__
1123
+ service_impl = concreteServiceClass()
1124
+ getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
1125
+ service_impl, self._app
1126
+ )
1127
+ full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1128
+ service_name
1129
+ ].full_name
1130
+ self._service_names.append(full_service_name)
1131
+
1132
+ def mount_objs(self, *objs):
1133
+ """Mount multiple service objects into this ASGI app."""
1134
+ for obj in objs:
1135
+ self.mount(obj, self._package_name)
1136
+
1137
+ async def __call__(self, scope, receive, send):
1138
+ """ASGI entry point."""
1139
+ await self._app(scope, receive, send)
1140
+
1141
+
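A hedged sketch of embedding a service into an existing ASGI stack via ASGIApp; the fallback app below is an illustrative stand-in for any ASGI application, and the import path is again assumed. Handler methods must be async because the ASGI path uses the async stub wiring above.

    from pydantic_rpc import ASGIApp, Message

    class GreetRequest(Message):
        name: str

    class GreetResponse(Message):
        greeting: str

    class GreeterService:
        async def greet(self, request: GreetRequest) -> GreetResponse:
            return GreetResponse(greeting=f"Hello, {request.name}!")

    async def fallback_app(scope, receive, send):
        # Handles anything sonora's grpcASGI does not recognise as a gRPC(-Web) request.
        await send({"type": "http.response.start", "status": 404, "headers": []})
        await send({"type": "http.response.body", "body": b"not found"})

    app = ASGIApp(fallback_app)
    app.mount_objs(GreeterService())
    # Serve with any ASGI server, e.g. uvicorn.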
1142
+ def get_connecpy_server_class(connecpy_module, service_name):
1143
+ return getattr(connecpy_module, f"{service_name}Server")
1144
+
1145
+
1146
+ class ConnecpyASGIApp:
1147
+ """
1148
+ An ASGI-compatible application that can serve Connect-RPC via Connecpy's ConnecpyASGIApp.
1149
+ """
1150
+
1151
+ def __init__(self):
1152
+ self._app = ConnecpyASGI()
1153
+ self._service_names = []
1154
+ self._package_name = ""
1155
+
1156
+ def mount(self, obj: object, package_name: str = ""):
1157
+ """Generate and compile proto files, then mount the async service implementation."""
1158
+ connecpy_module, pb2_module = generate_and_compile_proto_using_connecpy(
1159
+ obj, package_name
1160
+ )
1161
+ self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
1162
+
1163
+ def mount_using_pb2_modules(self, connecpy_module, pb2_module, obj: object):
1164
+ """Connect the compiled connecpy and pb2 modules with the async service implementation."""
1165
+ concreteServiceClass = connect_obj_with_stub_async_connecpy(
1166
+ connecpy_module, pb2_module, obj
1167
+ )
1168
+ service_name = obj.__class__.__name__
1169
+ service_impl = concreteServiceClass()
1170
+ connecpy_server = get_connecpy_server_class(connecpy_module, service_name)
1171
+ self._app.add_service(connecpy_server(service=service_impl))
1172
+ full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1173
+ service_name
1174
+ ].full_name
1175
+ self._service_names.append(full_service_name)
1176
+
1177
+ def mount_objs(self, *objs):
1178
+ """Mount multiple service objects into this ASGI app."""
1179
+ for obj in objs:
1180
+ self.mount(obj, self._package_name)
1181
+
1182
+ async def __call__(self, scope, receive, send):
1183
+ """ASGI entry point."""
1184
+ await self._app(scope, receive, send)
1185
+
1186
+
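Mounting onto the Connect-RPC ASGI app is analogous (a sketch, with the same import-path caveat; the connecpy package and its protoc plugin must be available, since generation runs protoc with --connecpy_out):

    from pydantic_rpc import ConnecpyASGIApp, Message

    class AddRequest(Message):
        a: int
        b: int

    class AddResponse(Message):
        total: int

    class CalculatorService:
        async def add(self, request: AddRequest) -> AddResponse:
            return AddResponse(total=request.a + request.b)

    app = ConnecpyASGIApp()
    app.mount(CalculatorService())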
1187
+ if __name__ == "__main__":
1188
+ """
1189
+ If executed as a script, generate the .proto files for a given class.
1190
+ Usage: python core.py some_module.py SomeServiceClass
1191
+ """
1192
+ py_file_name = sys.argv[1]
1193
+ class_name = sys.argv[2]
1194
+ module_name = os.path.splitext(basename(py_file_name))[0]
1195
+ module = importlib.import_module(module_name)
1196
+ klass = getattr(module, class_name)
1197
+ generate_and_compile_proto(klass())