pydantic-rpc 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydantic_rpc/core.py CHANGED
@@ -1,1308 +1,1343 @@
1
- import asyncio
2
- import enum
3
- import importlib.util
4
- import inspect
5
- import os
6
- import signal
7
- import sys
8
- import time
9
- import types
10
- import datetime
11
- from concurrent import futures
12
- from posixpath import basename
13
- from typing import (
14
- Callable,
15
- Type,
16
- get_args,
17
- get_origin,
18
- Union,
19
- TypeAlias,
20
- # AsyncIterator, # Add if not already present
21
- )
22
- from collections.abc import AsyncIterator
23
-
24
- import grpc
25
- from grpc_health.v1 import health_pb2, health_pb2_grpc
26
- from grpc_health.v1.health import HealthServicer
27
- from grpc_reflection.v1alpha import reflection
28
- from grpc_tools import protoc
29
- from pydantic import BaseModel, ValidationError
30
- from sonora.wsgi import grpcWSGI
31
- from sonora.asgi import grpcASGI
32
- from connecpy.asgi import ConnecpyASGIApp as ConnecpyASGI
33
- from connecpy.errors import Errors
34
-
35
- # Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
36
- from google.protobuf import timestamp_pb2, duration_pb2
37
-
38
- ###############################################################################
39
- # 1. Message definitions & converter extensions
40
- # (datetime.datetime <-> google.protobuf.Timestamp)
41
- # (datetime.timedelta <-> google.protobuf.Duration)
42
- ###############################################################################
43
-
44
-
45
- Message: TypeAlias = BaseModel
46
-
47
-
48
- def primitiveProtoValueToPythonValue(value):
49
- # Returns the value as-is (primitive type).
50
- return value
51
-
52
-
53
- def timestamp_to_python(ts: timestamp_pb2.Timestamp) -> datetime.datetime: # type: ignore
54
- """Convert a protobuf Timestamp to a Python datetime object."""
55
- return ts.ToDatetime()
56
-
57
-
58
- def python_to_timestamp(dt: datetime.datetime) -> timestamp_pb2.Timestamp: # type: ignore
59
- """Convert a Python datetime object to a protobuf Timestamp."""
60
- ts = timestamp_pb2.Timestamp() # type: ignore
61
- ts.FromDatetime(dt)
62
- return ts
63
-
64
-
65
- def duration_to_python(d: duration_pb2.Duration) -> datetime.timedelta: # type: ignore
66
- """Convert a protobuf Duration to a Python timedelta object."""
67
- return d.ToTimedelta()
68
-
69
-
70
- def python_to_duration(td: datetime.timedelta) -> duration_pb2.Duration: # type: ignore
71
- """Convert a Python timedelta object to a protobuf Duration."""
72
- d = duration_pb2.Duration() # type: ignore
73
- d.FromTimedelta(td)
74
- return d
75
-
76
-
77
- def generate_converter(annotation: Type) -> Callable:
78
- """
79
- Returns a converter function to convert protobuf types to Python types.
80
- This is used primarily when handling incoming requests.
81
- """
82
- # For primitive types
83
- if annotation in (int, str, bool, bytes, float):
84
- return primitiveProtoValueToPythonValue
85
-
86
- # For enum types
87
- if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
88
-
89
- def enum_converter(value):
90
- return annotation(value)
91
-
92
- return enum_converter
93
-
94
- # For datetime
95
- if annotation == datetime.datetime:
96
-
97
- def ts_converter(value: timestamp_pb2.Timestamp): # type: ignore
98
- return value.ToDatetime()
99
-
100
- return ts_converter
101
-
102
- # For timedelta
103
- if annotation == datetime.timedelta:
104
-
105
- def dur_converter(value: duration_pb2.Duration): # type: ignore
106
- return value.ToTimedelta()
107
-
108
- return dur_converter
109
-
110
- origin = get_origin(annotation)
111
- if origin is not None:
112
- # For seq types
113
- if origin in (list, tuple):
114
- item_converter = generate_converter(get_args(annotation)[0])
115
-
116
- def seq_converter(value):
117
- return [item_converter(v) for v in value]
118
-
119
- return seq_converter
120
-
121
- # For dict types
122
- if origin is dict:
123
- key_converter = generate_converter(get_args(annotation)[0])
124
- value_converter = generate_converter(get_args(annotation)[1])
125
-
126
- def dict_converter(value):
127
- return {key_converter(k): value_converter(v) for k, v in value.items()}
128
-
129
- return dict_converter
130
-
131
- # For Message classes
132
- if inspect.isclass(annotation) and issubclass(annotation, Message):
133
- return generate_message_converter(annotation)
134
-
135
- # For union types or other unsupported cases, just return the value as-is.
136
- return primitiveProtoValueToPythonValue
137
-
138
-
139
- def generate_message_converter(arg_type: Type[Message]) -> Callable:
140
- """Return a converter function for protobuf -> Python Message."""
141
- if arg_type is None or not issubclass(arg_type, Message):
142
- raise TypeError("Request arg must be subclass of Message")
143
-
144
- fields = arg_type.model_fields
145
- converters = {
146
- field: generate_converter(field_type.annotation) # type: ignore
147
- for field, field_type in fields.items()
148
- }
149
-
150
- def converter(request):
151
- rdict = {}
152
- for field in fields.keys():
153
- rdict[field] = converters[field](getattr(request, field))
154
- return arg_type(**rdict)
155
-
156
- return converter
157
-
158
-
159
- def python_value_to_proto_value(field_type: Type, value):
160
- """
161
- Converts Python values to protobuf values.
162
- Used primarily when constructing a response object.
163
- """
164
- # datetime.datetime -> Timestamp
165
- if field_type == datetime.datetime:
166
- return python_to_timestamp(value)
167
-
168
- # datetime.timedelta -> Duration
169
- if field_type == datetime.timedelta:
170
- return python_to_duration(value)
171
-
172
- # Default behavior: return the value as-is.
173
- return value
174
-
175
-
176
- ###############################################################################
177
- # 2. Stub implementation
178
- ###############################################################################
179
-
180
-
181
- def connect_obj_with_stub(pb2_grpc_module, pb2_module, service_obj: object) -> type:
182
- """
183
- Connect a Python service object to a gRPC stub, generating server methods.
184
- """
185
- service_class = service_obj.__class__
186
- stub_class_name = service_class.__name__ + "Servicer"
187
- stub_class = getattr(pb2_grpc_module, stub_class_name)
188
-
189
- class ConcreteServiceClass(stub_class):
190
- pass
191
-
192
- def implement_stub_method(method):
193
- # Analyze method signature.
194
- sig = inspect.signature(method)
195
- arg_type = get_request_arg_type(sig)
196
- # Convert request from protobuf to Python.
197
- converter = generate_message_converter(arg_type)
198
-
199
- response_type = sig.return_annotation
200
- size_of_parameters = len(sig.parameters)
201
-
202
- match size_of_parameters:
203
- case 1:
204
-
205
- def stub_method1(self, request, context, method=method):
206
- try:
207
- # Convert request to Python object
208
- arg = converter(request)
209
- # Invoke the actual method
210
- resp_obj = method(arg)
211
- # Convert the returned Python Message to a protobuf message
212
- return convert_python_message_to_proto(
213
- resp_obj, response_type, pb2_module
214
- )
215
- except ValidationError as e:
216
- return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
217
- except Exception as e:
218
- return context.abort(grpc.StatusCode.INTERNAL, str(e))
219
-
220
- return stub_method1
221
-
222
- case 2:
223
-
224
- def stub_method2(self, request, context, method=method):
225
- try:
226
- arg = converter(request)
227
- resp_obj = method(arg, context)
228
- return convert_python_message_to_proto(
229
- resp_obj, response_type, pb2_module
230
- )
231
- except ValidationError as e:
232
- return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
233
- except Exception as e:
234
- return context.abort(grpc.StatusCode.INTERNAL, str(e))
235
-
236
- return stub_method2
237
-
238
- case _:
239
- raise Exception("Method must have exactly one or two parameters")
240
-
241
- for method_name, method in get_rpc_methods(service_obj):
242
- if method.__name__.startswith("_"):
243
- continue
244
-
245
- a_method = implement_stub_method(method)
246
- setattr(ConcreteServiceClass, method_name, a_method)
247
-
248
- return ConcreteServiceClass
249
-
250
-
251
- def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> type:
252
- """
253
- Connect a Python service object to a gRPC stub for async methods.
254
- """
255
- service_class = obj.__class__
256
- stub_class_name = service_class.__name__ + "Servicer"
257
- stub_class = getattr(pb2_grpc_module, stub_class_name)
258
-
259
- class ConcreteServiceClass(stub_class):
260
- pass
261
-
262
- def implement_stub_method(method):
263
- sig = inspect.signature(method)
264
- arg_type = get_request_arg_type(sig)
265
- converter = generate_message_converter(arg_type)
266
- response_type = sig.return_annotation
267
- size_of_parameters = len(sig.parameters)
268
-
269
- if is_stream_type(response_type):
270
- item_type = get_args(response_type)[0]
271
- match size_of_parameters:
272
- case 1:
273
-
274
- async def stub_method_stream1(
275
- self, request, context, method=method
276
- ):
277
- try:
278
- arg = converter(request)
279
- async for resp_obj in method(arg):
280
- yield convert_python_message_to_proto(
281
- resp_obj, item_type, pb2_module
282
- )
283
- except ValidationError as e:
284
- await context.abort(
285
- grpc.StatusCode.INVALID_ARGUMENT, str(e)
286
- )
287
- except Exception as e:
288
- await context.abort(grpc.StatusCode.INTERNAL, str(e))
289
-
290
- return stub_method_stream1
291
- case 2:
292
-
293
- async def stub_method_stream2(
294
- self, request, context, method=method
295
- ):
296
- try:
297
- arg = converter(request)
298
- async for resp_obj in method(arg, context):
299
- yield convert_python_message_to_proto(
300
- resp_obj, item_type, pb2_module
301
- )
302
- except ValidationError as e:
303
- await context.abort(
304
- grpc.StatusCode.INVALID_ARGUMENT, str(e)
305
- )
306
- except Exception as e:
307
- await context.abort(grpc.StatusCode.INTERNAL, str(e))
308
-
309
- return stub_method_stream2
310
- case _:
311
- raise Exception("Method must have exactly one or two parameters")
312
-
313
- match size_of_parameters:
314
- case 1:
315
-
316
- async def stub_method1(self, request, context, method=method):
317
- try:
318
- arg = converter(request)
319
- resp_obj = await method(arg)
320
- return convert_python_message_to_proto(
321
- resp_obj, response_type, pb2_module
322
- )
323
- except ValidationError as e:
324
- await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
325
- except Exception as e:
326
- await context.abort(grpc.StatusCode.INTERNAL, str(e))
327
-
328
- return stub_method1
329
-
330
- case 2:
331
-
332
- async def stub_method2(self, request, context, method=method):
333
- try:
334
- arg = converter(request)
335
- resp_obj = await method(arg, context)
336
- return convert_python_message_to_proto(
337
- resp_obj, response_type, pb2_module
338
- )
339
- except ValidationError as e:
340
- await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
341
- except Exception as e:
342
- await context.abort(grpc.StatusCode.INTERNAL, str(e))
343
-
344
- return stub_method2
345
-
346
- case _:
347
- raise Exception("Method must have exactly one or two parameters")
348
-
349
- for method_name, method in get_rpc_methods(obj):
350
- if method.__name__.startswith("_"):
351
- continue
352
-
353
- a_method = implement_stub_method(method)
354
- setattr(ConcreteServiceClass, method_name, a_method)
355
-
356
- return ConcreteServiceClass
357
-
358
-
359
- def connect_obj_with_stub_async_connecpy(
360
- connecpy_module, pb2_module, obj: object
361
- ) -> type:
362
- """
363
- Connect a Python service object to a Connecpy stub for async methods.
364
- """
365
- service_class = obj.__class__
366
- stub_class_name = service_class.__name__
367
- stub_class = getattr(connecpy_module, stub_class_name)
368
-
369
- class ConcreteServiceClass(stub_class):
370
- pass
371
-
372
- def implement_stub_method(method):
373
- sig = inspect.signature(method)
374
- arg_type = get_request_arg_type(sig)
375
- converter = generate_message_converter(arg_type)
376
- response_type = sig.return_annotation
377
- size_of_parameters = len(sig.parameters)
378
-
379
- match size_of_parameters:
380
- case 1:
381
-
382
- async def stub_method1(self, request, context, method=method):
383
- try:
384
- arg = converter(request)
385
- resp_obj = await method(arg)
386
- return convert_python_message_to_proto(
387
- resp_obj, response_type, pb2_module
388
- )
389
- except ValidationError as e:
390
- await context.abort(Errors.InvalidArgument, str(e))
391
- except Exception as e:
392
- await context.abort(Errors.Internal, str(e))
393
-
394
- return stub_method1
395
-
396
- case 2:
397
-
398
- async def stub_method2(self, request, context, method=method):
399
- try:
400
- arg = converter(request)
401
- resp_obj = await method(arg, context)
402
- return convert_python_message_to_proto(
403
- resp_obj, response_type, pb2_module
404
- )
405
- except ValidationError as e:
406
- await context.abort(Errors.InvalidArgument, str(e))
407
- except Exception as e:
408
- await context.abort(Errors.Internal, str(e))
409
-
410
- return stub_method2
411
-
412
- case _:
413
- raise Exception("Method must have exactly one or two parameters")
414
-
415
- for method_name, method in get_rpc_methods(obj):
416
- if method.__name__.startswith("_"):
417
- continue
418
- if not asyncio.iscoroutinefunction(method):
419
- raise Exception("Method must be async", method_name)
420
- a_method = implement_stub_method(method)
421
- setattr(ConcreteServiceClass, method_name, a_method)
422
-
423
- return ConcreteServiceClass
424
-
425
-
426
- def convert_python_message_to_proto(
427
- py_msg: Message, msg_type: Type, pb2_module
428
- ) -> object:
429
- """
430
- Convert a Python Pydantic Message instance to a protobuf message instance.
431
- Used for constructing a response.
432
- """
433
- # Before calling something like pb2_module.AResponseMessage(...),
434
- # convert each field from Python to proto.
435
- field_dict = {}
436
- for name, field_info in msg_type.model_fields.items():
437
- value = getattr(py_msg, name)
438
- if value is None:
439
- field_dict[name] = None
440
- continue
441
-
442
- field_type = field_info.annotation
443
- field_dict[name] = python_value_to_proto(field_type, value, pb2_module)
444
-
445
- # Retrieve the appropriate protobuf class dynamically
446
- proto_class = getattr(pb2_module, msg_type.__name__)
447
- return proto_class(**field_dict)
448
-
449
-
450
- def python_value_to_proto(field_type: Type, value, pb2_module):
451
- """
452
- Perform Python->protobuf type conversion for each field value.
453
- """
454
- import inspect
455
- import datetime
456
-
457
- # If datetime
458
- if field_type == datetime.datetime:
459
- return python_to_timestamp(value)
460
-
461
- # If timedelta
462
- if field_type == datetime.timedelta:
463
- return python_to_duration(value)
464
-
465
- # If enum
466
- if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
467
- return value.value # proto3 enum is an int
468
-
469
- origin = get_origin(field_type)
470
- # If seq
471
- if origin in (list, tuple):
472
- inner_type = get_args(field_type)[0] # type: ignore
473
- return [python_value_to_proto(inner_type, v, pb2_module) for v in value]
474
-
475
- # If dict
476
- if origin is dict:
477
- key_type, val_type = get_args(field_type) # type: ignore
478
- return {
479
- python_value_to_proto(key_type, k, pb2_module): python_value_to_proto(
480
- val_type, v, pb2_module
481
- )
482
- for k, v in value.items()
483
- }
484
-
485
- # If union -> oneof
486
- if is_union_type(field_type):
487
- # Flatten union and check which type matches. If matched, return converted value.
488
- for sub_type in flatten_union(field_type):
489
- if sub_type == datetime.datetime and isinstance(value, datetime.datetime):
490
- return python_to_timestamp(value)
491
- if sub_type == datetime.timedelta and isinstance(value, datetime.timedelta):
492
- return python_to_duration(value)
493
- if (
494
- inspect.isclass(sub_type)
495
- and issubclass(sub_type, enum.Enum)
496
- and isinstance(value, enum.Enum)
497
- ):
498
- return value.value
499
- if sub_type in (int, float, str, bool, bytes) and isinstance(
500
- value, sub_type
501
- ):
502
- return value
503
- if (
504
- inspect.isclass(sub_type)
505
- and issubclass(sub_type, Message)
506
- and isinstance(value, Message)
507
- ):
508
- return convert_python_message_to_proto(value, sub_type, pb2_module)
509
- return None
510
-
511
- # If Message
512
- if inspect.isclass(field_type) and issubclass(field_type, Message):
513
- return convert_python_message_to_proto(value, field_type, pb2_module)
514
-
515
- # If primitive
516
- return value
517
-
518
-
519
- ###############################################################################
520
- # 3. Generating proto files (datetime->Timestamp, timedelta->Duration)
521
- ###############################################################################
522
-
523
-
524
- def is_enum_type(python_type: Type) -> bool:
525
- """Return True if the given Python type is an enum."""
526
- return inspect.isclass(python_type) and issubclass(python_type, enum.Enum)
527
-
528
-
529
- def is_union_type(python_type: Type) -> bool:
530
- """
531
- Check if a given Python type is a Union type (including Python 3.10's UnionType).
532
- """
533
- if get_origin(python_type) is Union:
534
- return True
535
- if sys.version_info >= (3, 10):
536
- import types
537
-
538
- if type(python_type) is types.UnionType:
539
- return True
540
- return False
541
-
542
-
543
- def flatten_union(field_type: Type) -> list[Type]:
544
- """Recursively flatten nested Unions into a single list of types."""
545
- if is_union_type(field_type):
546
- results = []
547
- for arg in get_args(field_type):
548
- results.extend(flatten_union(arg))
549
- return results
550
- else:
551
- return [field_type]
552
-
553
-
554
- def protobuf_type_mapping(python_type: Type) -> str | type | None:
555
- """
556
- Map a Python type to a protobuf type name/class.
557
- Includes support for Timestamp and Duration.
558
- """
559
- import datetime
560
-
561
- mapping = {
562
- int: "int32",
563
- str: "string",
564
- bool: "bool",
565
- bytes: "bytes",
566
- float: "float",
567
- }
568
-
569
- if python_type == datetime.datetime:
570
- return "google.protobuf.Timestamp"
571
-
572
- if python_type == datetime.timedelta:
573
- return "google.protobuf.Duration"
574
-
575
- if is_enum_type(python_type):
576
- return python_type # Will be defined as enum later
577
-
578
- if is_union_type(python_type):
579
- return None # Handled separately as oneof
580
-
581
- if hasattr(python_type, "__origin__"):
582
- if python_type.__origin__ in (list, tuple):
583
- inner_type = python_type.__args__[0]
584
- inner_proto_type = protobuf_type_mapping(inner_type)
585
- if inner_proto_type:
586
- return f"repeated {inner_proto_type}"
587
- elif python_type.__origin__ is dict:
588
- key_type = python_type.__args__[0]
589
- value_type = python_type.__args__[1]
590
- key_proto_type = protobuf_type_mapping(key_type)
591
- value_proto_type = protobuf_type_mapping(value_type)
592
- if key_proto_type and value_proto_type:
593
- return f"map<{key_proto_type}, {value_proto_type}>"
594
-
595
- if inspect.isclass(python_type) and issubclass(python_type, Message):
596
- return python_type
597
-
598
- return mapping.get(python_type) # type: ignore
599
-
600
-
601
- def comment_out(docstr: str) -> tuple[str, ...]:
602
- """Convert docstrings into commented-out lines in a .proto file."""
603
- if docstr is None:
604
- return tuple()
605
-
606
- if docstr.startswith("Usage docs: https://docs.pydantic.dev/2.10/concepts/models/"):
607
- return tuple()
608
-
609
- return tuple(f"//" if line == "" else f"// {line}" for line in docstr.split("\n"))
610
-
611
-
612
- def indent_lines(lines, indentation=" "):
613
- """Indent multiple lines with a given indentation string."""
614
- return "\n".join(indentation + line for line in lines)
615
-
616
-
617
- def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
618
- """Generate a protobuf enum definition from a Python enum."""
619
- enum_name = enum_type.__name__
620
- members = []
621
- for _, member in enum_type.__members__.items():
622
- members.append(f" {member.name} = {member.value};")
623
- enum_def = f"enum {enum_name} {{\n"
624
- enum_def += "\n".join(members)
625
- enum_def += "\n}"
626
- return enum_def
627
-
628
-
629
- def generate_oneof_definition(
630
- field_name: str, union_args: list[Type], start_index: int
631
- ) -> tuple[list[str], int]:
632
- """
633
- Generate a oneof block in protobuf for a union field.
634
- Returns a tuple of the definition lines and the updated field index.
635
- """
636
- lines = []
637
- lines.append(f"oneof {field_name} {{")
638
- current = start_index
639
- for arg_type in union_args:
640
- proto_typename = protobuf_type_mapping(arg_type)
641
- if proto_typename is None:
642
- raise Exception(f"Nested Union not flattened properly: {arg_type}")
643
-
644
- # If it's an enum or Message, use the type name.
645
- if is_enum_type(arg_type):
646
- proto_typename = arg_type.__name__
647
- elif inspect.isclass(arg_type) and issubclass(arg_type, Message):
648
- proto_typename = arg_type.__name__
649
-
650
- field_alias = f"{field_name}_{proto_typename.replace('.', '_')}"
651
- lines.append(f" {proto_typename} {field_alias} = {current};")
652
- current += 1
653
- lines.append("}")
654
- return lines, current
655
-
656
-
657
- def generate_message_definition(
658
- message_type: Type[Message],
659
- done_enums: set,
660
- done_messages: set,
661
- ) -> tuple[str, list[type]]:
662
- """
663
- Generate a protobuf message definition for a Pydantic-based Message class.
664
- Also returns any referenced types (enums, messages) that need to be defined.
665
- """
666
- fields = []
667
- refs = []
668
- pydantic_fields = message_type.model_fields
669
- index = 1
670
-
671
- for field_name, field_info in pydantic_fields.items():
672
- field_type = field_info.annotation
673
- if field_type is None:
674
- raise Exception(f"Field {field_name} has no type annotation.")
675
-
676
- if is_union_type(field_type):
677
- union_args = flatten_union(field_type)
678
- oneof_lines, new_index = generate_oneof_definition(
679
- field_name, union_args, index
680
- )
681
- fields.extend(oneof_lines)
682
- index = new_index
683
-
684
- for utype in union_args:
685
- if is_enum_type(utype) and utype not in done_enums:
686
- refs.append(utype)
687
- elif (
688
- inspect.isclass(utype)
689
- and issubclass(utype, Message)
690
- and utype not in done_messages
691
- ):
692
- refs.append(utype)
693
-
694
- else:
695
- proto_typename = protobuf_type_mapping(field_type)
696
- if proto_typename is None:
697
- raise Exception(f"Type {field_type} is not supported.")
698
-
699
- if is_enum_type(field_type):
700
- proto_typename = field_type.__name__
701
- if field_type not in done_enums:
702
- refs.append(field_type)
703
- elif inspect.isclass(field_type) and issubclass(field_type, Message):
704
- proto_typename = field_type.__name__
705
- if field_type not in done_messages:
706
- refs.append(field_type)
707
-
708
- fields.append(f"{proto_typename} {field_name} = {index};")
709
- index += 1
710
-
711
- msg_def = f"message {message_type.__name__} {{\n{indent_lines(fields)}\n}}"
712
- return msg_def, refs
713
-
714
-
715
- def is_stream_type(annotation: Type) -> bool:
716
- return get_origin(annotation) is AsyncIterator
717
-
718
-
719
- def is_generic_alias(annotation: Type) -> bool:
720
- return get_origin(annotation) is not None
721
-
722
-
723
- def generate_proto(obj: object, package_name: str = "") -> str:
724
- """
725
- Generate a .proto definition from a service class.
726
- Automatically handles Timestamp and Duration usage.
727
- """
728
- import datetime
729
-
730
- service_class = obj.__class__
731
- service_name = service_class.__name__
732
- service_docstr = inspect.getdoc(service_class)
733
- service_comment = "\n".join(comment_out(service_docstr)) if service_docstr else ""
734
-
735
- rpc_definitions = []
736
- all_type_definitions = []
737
- done_messages = set()
738
- done_enums = set()
739
-
740
- uses_timestamp = False
741
- uses_duration = False
742
-
743
- def check_and_set_well_known_types(py_type: Type):
744
- nonlocal uses_timestamp, uses_duration
745
- if py_type == datetime.datetime:
746
- uses_timestamp = True
747
- if py_type == datetime.timedelta:
748
- uses_duration = True
749
-
750
- for method_name, method in get_rpc_methods(obj):
751
- if method.__name__.startswith("_"):
752
- continue
753
-
754
- method_sig = inspect.signature(method)
755
- request_type = get_request_arg_type(method_sig)
756
- response_type = method_sig.return_annotation
757
-
758
- # Recursively generate message definitions
759
- message_types = [request_type, response_type]
760
- while message_types:
761
- mt = message_types.pop()
762
- if mt in done_messages:
763
- continue
764
- done_messages.add(mt)
765
-
766
- if is_stream_type(mt):
767
- item_type = get_args(mt)[0]
768
- message_types.append(item_type)
769
- continue
770
-
771
- for _, field_info in mt.model_fields.items():
772
- t = field_info.annotation
773
- if is_union_type(t):
774
- for sub_t in flatten_union(t):
775
- check_and_set_well_known_types(sub_t)
776
- else:
777
- check_and_set_well_known_types(t)
778
-
779
- msg_def, refs = generate_message_definition(mt, done_enums, done_messages)
780
- mt_doc = inspect.getdoc(mt)
781
- if mt_doc:
782
- for comment_line in comment_out(mt_doc):
783
- all_type_definitions.append(comment_line)
784
-
785
- all_type_definitions.append(msg_def)
786
- all_type_definitions.append("")
787
-
788
- for r in refs:
789
- if is_enum_type(r) and r not in done_enums:
790
- done_enums.add(r)
791
- enum_def = generate_enum_definition(r)
792
- all_type_definitions.append(enum_def)
793
- all_type_definitions.append("")
794
- elif issubclass(r, Message) and r not in done_messages:
795
- message_types.append(r)
796
-
797
- method_docstr = inspect.getdoc(method)
798
- if method_docstr:
799
- for comment_line in comment_out(method_docstr):
800
- rpc_definitions.append(comment_line)
801
-
802
- if is_stream_type(response_type):
803
- item_type = get_args(response_type)[0]
804
- rpc_definitions.append(
805
- f"rpc {method_name} ({request_type.__name__}) returns (stream {item_type.__name__});"
806
- )
807
- else:
808
- rpc_definitions.append(
809
- f"rpc {method_name} ({request_type.__name__}) returns ({response_type.__name__});"
810
- )
811
-
812
- if not package_name:
813
- if service_name.endswith("Service"):
814
- package_name = service_name[: -len("Service")]
815
- else:
816
- package_name = service_name
817
- package_name = package_name.lower() + ".v1"
818
-
819
- imports = []
820
- if uses_timestamp:
821
- imports.append('import "google/protobuf/timestamp.proto";')
822
- if uses_duration:
823
- imports.append('import "google/protobuf/duration.proto";')
824
-
825
- import_block = "\n".join(imports)
826
- if import_block:
827
- import_block += "\n"
828
-
829
- proto_definition = f"""syntax = "proto3";
830
-
831
- package {package_name};
832
-
833
- {import_block}{service_comment}
834
- service {service_name} {{
835
- {indent_lines(rpc_definitions)}
836
- }}
837
-
838
- {indent_lines(all_type_definitions, "")}
839
- """
840
- return proto_definition
841
-
842
-
843
- def generate_grpc_code(proto_file, grpc_python_out) -> types.ModuleType | None:
844
- """
845
- Execute the protoc command to generate Python gRPC code from the .proto file.
846
- Returns a tuple of (pb2_grpc_module, pb2_module) on success, or None if failed.
847
- """
848
- command = f"-I. --grpc_python_out={grpc_python_out} {proto_file}"
849
- exit_code = protoc.main(command.split())
850
- if exit_code != 0:
851
- return None
852
-
853
- base = os.path.splitext(proto_file)[0]
854
- generated_pb2_grpc_file = f"{base}_pb2_grpc.py"
855
-
856
- if grpc_python_out not in sys.path:
857
- sys.path.append(grpc_python_out)
858
-
859
- spec = importlib.util.spec_from_file_location(
860
- generated_pb2_grpc_file, os.path.join(grpc_python_out, generated_pb2_grpc_file)
861
- )
862
- if spec is None:
863
- return None
864
- pb2_grpc_module = importlib.util.module_from_spec(spec)
865
- if spec.loader is None:
866
- return None
867
- spec.loader.exec_module(pb2_grpc_module)
868
-
869
- return pb2_grpc_module
870
-
871
-
872
- def generate_connecpy_code(
873
- proto_file: str, connecpy_out: str
874
- ) -> types.ModuleType | None:
875
- """
876
- Execute the protoc command to generate Python Connecpy code from the .proto file.
877
- Returns a tuple of (connecpy_module, pb2_module) on success, or None if failed.
878
- """
879
- command = f"-I. --connecpy_out={connecpy_out} {proto_file}"
880
- exit_code = protoc.main(command.split())
881
- if exit_code != 0:
882
- return None
883
-
884
- base = os.path.splitext(proto_file)[0]
885
- generated_connecpy_file = f"{base}_connecpy.py"
886
-
887
- if connecpy_out not in sys.path:
888
- sys.path.append(connecpy_out)
889
-
890
- spec = importlib.util.spec_from_file_location(
891
- generated_connecpy_file, os.path.join(connecpy_out, generated_connecpy_file)
892
- )
893
- if spec is None:
894
- return None
895
- connecpy_module = importlib.util.module_from_spec(spec)
896
- if spec.loader is None:
897
- return None
898
- spec.loader.exec_module(connecpy_module)
899
-
900
- return connecpy_module
901
-
902
-
903
- def generate_pb_code(
904
- proto_file: str, python_out: str, pyi_out: str
905
- ) -> types.ModuleType | None:
906
- """
907
- Execute the protoc command to generate Python gRPC code from the .proto file.
908
- Returns a tuple of (pb2_grpc_module, pb2_module) on success, or None if failed.
909
- """
910
- command = f"-I. --python_out={python_out} --pyi_out={pyi_out} {proto_file}"
911
- exit_code = protoc.main(command.split())
912
- if exit_code != 0:
913
- return None
914
-
915
- base = os.path.splitext(proto_file)[0]
916
- generated_pb2_file = f"{base}_pb2.py"
917
-
918
- if python_out not in sys.path:
919
- sys.path.append(python_out)
920
-
921
- spec = importlib.util.spec_from_file_location(
922
- generated_pb2_file, os.path.join(python_out, generated_pb2_file)
923
- )
924
- if spec is None:
925
- return None
926
- pb2_module = importlib.util.module_from_spec(spec)
927
- if spec.loader is None:
928
- return None
929
- spec.loader.exec_module(pb2_module)
930
-
931
- return pb2_module
932
-
933
-
934
- def get_request_arg_type(sig):
935
- """Return the type annotation of the first parameter (request) of a method."""
936
- num_of_params = len(sig.parameters)
937
- if not (num_of_params == 1 or num_of_params == 2):
938
- raise Exception("Method must have exactly one or two parameters")
939
- return tuple(sig.parameters.values())[0].annotation
940
-
941
-
942
- def get_rpc_methods(obj: object) -> list[tuple[str, types.MethodType]]:
943
- """
944
- Retrieve the list of RPC methods from a service object.
945
- The method name is converted to PascalCase for .proto compatibility.
946
- """
947
-
948
- def to_pascal_case(name: str) -> str:
949
- return "".join(part.capitalize() for part in name.split("_"))
950
-
951
- return [
952
- (to_pascal_case(attr_name), getattr(obj, attr_name))
953
- for attr_name in dir(obj)
954
- if inspect.ismethod(getattr(obj, attr_name))
955
- ]
956
-
957
-
958
- def is_skip_generation() -> bool:
959
- """Check if the proto file and code generation should be skipped."""
960
- return os.getenv("PYDANTIC_RPC_SKIP_GENERATION", "false").lower() == "true"
961
-
962
-
963
- def generate_and_compile_proto(obj: object, package_name: str = ""):
964
- if is_skip_generation():
965
- import importlib
966
-
967
- pb2_module = importlib.import_module(f"{obj.__class__.__name__.lower()}_pb2")
968
- pb2_grpc_module = importlib.import_module(
969
- f"{obj.__class__.__name__.lower()}_pb2_grpc"
970
- )
971
-
972
- if pb2_grpc_module is not None and pb2_module is not None:
973
- return pb2_grpc_module, pb2_module
974
-
975
- # If the modules are not found, generate and compile the proto files.
976
-
977
- klass = obj.__class__
978
- proto_file = generate_proto(obj, package_name)
979
- proto_file_name = klass.__name__.lower() + ".proto"
980
-
981
- with open(proto_file_name, "w", encoding="utf-8") as f:
982
- f.write(proto_file)
983
-
984
- gen_pb = generate_pb_code(proto_file_name, ".", ".")
985
- if gen_pb is None:
986
- raise Exception("Generating pb code")
987
-
988
- gen_grpc = generate_grpc_code(proto_file_name, ".")
989
- if gen_grpc is None:
990
- raise Exception("Generating grpc code")
991
- return gen_grpc, gen_pb
992
-
993
-
994
- def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = ""):
995
- if is_skip_generation():
996
- import importlib
997
-
998
- pb2_module = importlib.import_module(f"{obj.__class__.__name__.lower()}_pb2")
999
- connecpy_module = importlib.import_module(
1000
- f"{obj.__class__.__name__.lower()}_connecpy"
1001
- )
1002
-
1003
- if connecpy_module is not None and pb2_module is not None:
1004
- return connecpy_module, pb2_module
1005
-
1006
- # If the modules are not found, generate and compile the proto files.
1007
-
1008
- klass = obj.__class__
1009
- proto_file = generate_proto(obj, package_name)
1010
- proto_file_name = klass.__name__.lower() + ".proto"
1011
-
1012
- with open(proto_file_name, "w", encoding="utf-8") as f:
1013
- f.write(proto_file)
1014
-
1015
- gen_pb = generate_pb_code(proto_file_name, ".", ".")
1016
- if gen_pb is None:
1017
- raise Exception("Generating pb code")
1018
-
1019
- gen_connecpy = generate_connecpy_code(proto_file_name, ".")
1020
- if gen_connecpy is None:
1021
- raise Exception("Generating Connecpy code")
1022
- return gen_connecpy, gen_pb
1023
-
1024
-
1025
- ###############################################################################
1026
- # 4. Server Implementations
1027
- ###############################################################################
1028
-
1029
-
1030
- class Server:
1031
- """A simple gRPC server that uses ThreadPoolExecutor for concurrency."""
1032
-
1033
- def __init__(self, max_workers: int = 8, *interceptors) -> None:
1034
- self._server = grpc.server(
1035
- futures.ThreadPoolExecutor(max_workers), interceptors=interceptors
1036
- )
1037
- self._service_names = []
1038
- self._package_name = ""
1039
- self._port = 50051
1040
-
1041
- def set_package_name(self, package_name: str):
1042
- """Set the package name for .proto generation."""
1043
- self._package_name = package_name
1044
-
1045
- def set_port(self, port: int):
1046
- """Set the port number for the gRPC server."""
1047
- self._port = port
1048
-
1049
- def mount(self, obj: object, package_name: str = ""):
1050
- """Generate and compile proto files, then mount the service implementation."""
1051
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
1052
- self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1053
-
1054
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
1055
- """Connect the compiled gRPC modules with the service implementation."""
1056
- concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
1057
- service_name = obj.__class__.__name__
1058
- service_impl = concreteServiceClass()
1059
- getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
1060
- service_impl, self._server
1061
- )
1062
- full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1063
- service_name
1064
- ].full_name
1065
- self._service_names.append(full_service_name)
1066
-
1067
- def run(self, *objs):
1068
- """
1069
- Mount multiple services and run the gRPC server with reflection and health check.
1070
- Press Ctrl+C or send SIGTERM to stop.
1071
- """
1072
- for obj in objs:
1073
- self.mount(obj, self._package_name)
1074
-
1075
- SERVICE_NAMES = (
1076
- health_pb2.DESCRIPTOR.services_by_name["Health"].full_name,
1077
- reflection.SERVICE_NAME,
1078
- *self._service_names,
1079
- )
1080
- health_servicer = HealthServicer()
1081
- health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
1082
- reflection.enable_server_reflection(SERVICE_NAMES, self._server)
1083
-
1084
- self._server.add_insecure_port(f"[::]:{self._port}")
1085
- self._server.start()
1086
-
1087
- def handle_signal(signal, frame):
1088
- print("Received shutdown signal...")
1089
- self._server.stop(grace=10)
1090
- print("gRPC server shutdown.")
1091
- sys.exit(0)
1092
-
1093
- signal.signal(signal.SIGINT, handle_signal)
1094
- signal.signal(signal.SIGTERM, handle_signal)
1095
-
1096
- print("gRPC server is running...")
1097
- while True:
1098
- time.sleep(86400)
1099
-
1100
-
1101
- class AsyncIOServer:
1102
- """An async gRPC server using asyncio."""
1103
-
1104
- def __init__(self, *interceptors) -> None:
1105
- self._server = grpc.aio.server(interceptors=interceptors)
1106
- self._service_names = []
1107
- self._package_name = ""
1108
- self._port = 50051
1109
-
1110
- def set_package_name(self, package_name: str):
1111
- """Set the package name for .proto generation."""
1112
- self._package_name = package_name
1113
-
1114
- def set_port(self, port: int):
1115
- """Set the port number for the async gRPC server."""
1116
- self._port = port
1117
-
1118
- def mount(self, obj: object, package_name: str = ""):
1119
- """Generate and compile proto files, then mount the service implementation (async)."""
1120
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
1121
- self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1122
-
1123
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
1124
- """Connect the compiled gRPC modules with the async service implementation."""
1125
- concreteServiceClass = connect_obj_with_stub_async(
1126
- pb2_grpc_module, pb2_module, obj
1127
- )
1128
- service_name = obj.__class__.__name__
1129
- service_impl = concreteServiceClass()
1130
- getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
1131
- service_impl, self._server
1132
- )
1133
- full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1134
- service_name
1135
- ].full_name
1136
- self._service_names.append(full_service_name)
1137
-
1138
- async def run(self, *objs):
1139
- """
1140
- Mount multiple async services and run the gRPC server with reflection and health check.
1141
- Press Ctrl+C or send SIGTERM to stop.
1142
- """
1143
- for obj in objs:
1144
- self.mount(obj, self._package_name)
1145
-
1146
- SERVICE_NAMES = (
1147
- health_pb2.DESCRIPTOR.services_by_name["Health"].full_name,
1148
- reflection.SERVICE_NAME,
1149
- *self._service_names,
1150
- )
1151
- health_servicer = HealthServicer()
1152
- health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
1153
- reflection.enable_server_reflection(SERVICE_NAMES, self._server)
1154
-
1155
- self._server.add_insecure_port(f"[::]:{self._port}")
1156
- await self._server.start()
1157
-
1158
- shutdown_event = asyncio.Event()
1159
-
1160
- def shutdown(signum, frame):
1161
- print("Received shutdown signal...")
1162
- shutdown_event.set()
1163
-
1164
- for s in [signal.SIGTERM, signal.SIGINT]:
1165
- signal.signal(s, shutdown)
1166
-
1167
- print("gRPC server is running...")
1168
- await shutdown_event.wait()
1169
- await self._server.stop(10)
1170
- print("gRPC server shutdown.")
1171
-
1172
-
1173
- class WSGIApp:
1174
- """
1175
- A WSGI-compatible application that can serve gRPC via sonora's grpcWSGI.
1176
- Useful for embedding gRPC within an existing WSGI stack.
1177
- """
1178
-
1179
- def __init__(self, app):
1180
- self._app = grpcWSGI(app)
1181
- self._service_names = []
1182
- self._package_name = ""
1183
-
1184
- def mount(self, obj: object, package_name: str = ""):
1185
- """Generate and compile proto files, then mount the service implementation."""
1186
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
1187
- self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1188
-
1189
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
1190
- """Connect the compiled gRPC modules with the service implementation."""
1191
- concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
1192
- service_name = obj.__class__.__name__
1193
- service_impl = concreteServiceClass()
1194
- getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
1195
- service_impl, self._app
1196
- )
1197
- full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1198
- service_name
1199
- ].full_name
1200
- self._service_names.append(full_service_name)
1201
-
1202
- def mount_objs(self, *objs):
1203
- """Mount multiple service objects into this WSGI app."""
1204
- for obj in objs:
1205
- self.mount(obj, self._package_name)
1206
-
1207
- def __call__(self, environ, start_response):
1208
- """WSGI entry point."""
1209
- return self._app(environ, start_response)
1210
-
1211
-
1212
- class ASGIApp:
1213
- """
1214
- An ASGI-compatible application that can serve gRPC via sonora's grpcASGI.
1215
- Useful for embedding gRPC within an existing ASGI stack.
1216
- """
1217
-
1218
- def __init__(self, app):
1219
- self._app = grpcASGI(app)
1220
- self._service_names = []
1221
- self._package_name = ""
1222
-
1223
- def mount(self, obj: object, package_name: str = ""):
1224
- """Generate and compile proto files, then mount the async service implementation."""
1225
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
1226
- self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1227
-
1228
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
1229
- """Connect the compiled gRPC modules with the async service implementation."""
1230
- concreteServiceClass = connect_obj_with_stub_async(
1231
- pb2_grpc_module, pb2_module, obj
1232
- )
1233
- service_name = obj.__class__.__name__
1234
- service_impl = concreteServiceClass()
1235
- getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
1236
- service_impl, self._app
1237
- )
1238
- full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1239
- service_name
1240
- ].full_name
1241
- self._service_names.append(full_service_name)
1242
-
1243
- def mount_objs(self, *objs):
1244
- """Mount multiple service objects into this ASGI app."""
1245
- for obj in objs:
1246
- self.mount(obj, self._package_name)
1247
-
1248
- async def __call__(self, scope, receive, send):
1249
- """ASGI entry point."""
1250
- await self._app(scope, receive, send)
1251
-
1252
-
1253
- def get_connecpy_server_class(connecpy_module, service_name):
1254
- return getattr(connecpy_module, f"{service_name}Server")
1255
-
1256
-
1257
- class ConnecpyASGIApp:
1258
- """
1259
- An ASGI-compatible application that can serve Connect-RPC via Connecpy's ConnecpyASGIApp.
1260
- """
1261
-
1262
- def __init__(self):
1263
- self._app = ConnecpyASGI()
1264
- self._service_names = []
1265
- self._package_name = ""
1266
-
1267
- def mount(self, obj: object, package_name: str = ""):
1268
- """Generate and compile proto files, then mount the async service implementation."""
1269
- connecpy_module, pb2_module = generate_and_compile_proto_using_connecpy(
1270
- obj, package_name
1271
- )
1272
- self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
1273
-
1274
- def mount_using_pb2_modules(self, connecpy_module, pb2_module, obj: object):
1275
- """Connect the compiled connecpy and pb2 modules with the async service implementation."""
1276
- concreteServiceClass = connect_obj_with_stub_async_connecpy(
1277
- connecpy_module, pb2_module, obj
1278
- )
1279
- service_name = obj.__class__.__name__
1280
- service_impl = concreteServiceClass()
1281
- connecpy_server = get_connecpy_server_class(connecpy_module, service_name)
1282
- self._app.add_service(connecpy_server(service=service_impl))
1283
- full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1284
- service_name
1285
- ].full_name
1286
- self._service_names.append(full_service_name)
1287
-
1288
- def mount_objs(self, *objs):
1289
- """Mount multiple service objects into this ASGI app."""
1290
- for obj in objs:
1291
- self.mount(obj, self._package_name)
1292
-
1293
- async def __call__(self, scope, receive, send):
1294
- """ASGI entry point."""
1295
- await self._app(scope, receive, send)
1296
-
1297
-
1298
- if __name__ == "__main__":
1299
- """
1300
- If executed as a script, generate the .proto files for a given class.
1301
- Usage: python core.py some_module.py SomeServiceClass
1302
- """
1303
- py_file_name = sys.argv[1]
1304
- class_name = sys.argv[2]
1305
- module_name = os.path.splitext(basename(py_file_name))[0]
1306
- module = importlib.import_module(module_name)
1307
- klass = getattr(module, class_name)
1308
- generate_and_compile_proto(klass())
1
+ import annotated_types
2
+ import asyncio
3
+ import enum
4
+ import importlib.util
5
+ import inspect
6
+ import os
7
+ import signal
8
+ import sys
9
+ import time
10
+ import types
11
+ import datetime
12
+ from concurrent import futures
13
+ from posixpath import basename
14
+ from typing import (
15
+ Callable,
16
+ Type,
17
+ get_args,
18
+ get_origin,
19
+ Union,
20
+ TypeAlias,
21
+ )
22
+ from collections.abc import AsyncIterator
23
+
24
+ import grpc
25
+ from grpc_health.v1 import health_pb2, health_pb2_grpc
26
+ from grpc_health.v1.health import HealthServicer
27
+ from grpc_reflection.v1alpha import reflection
28
+ from grpc_tools import protoc
29
+ from pydantic import BaseModel, ValidationError
30
+ from sonora.wsgi import grpcWSGI
31
+ from sonora.asgi import grpcASGI
32
+ from connecpy.asgi import ConnecpyASGIApp as ConnecpyASGI
33
+ from connecpy.errors import Errors
34
+
35
+ # Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
36
+ from google.protobuf import timestamp_pb2, duration_pb2
37
+
38
+ ###############################################################################
39
+ # 1. Message definitions & converter extensions
40
+ # (datetime.datetime <-> google.protobuf.Timestamp)
41
+ # (datetime.timedelta <-> google.protobuf.Duration)
42
+ ###############################################################################
43
+
44
+
45
+ Message: TypeAlias = BaseModel
46
+
47
+
48
+ def primitiveProtoValueToPythonValue(value):
49
+ # Returns the value as-is (primitive type).
50
+ return value
51
+
52
+
53
+ def timestamp_to_python(ts: timestamp_pb2.Timestamp) -> datetime.datetime: # type: ignore
54
+ """Convert a protobuf Timestamp to a Python datetime object."""
55
+ return ts.ToDatetime()
56
+
57
+
58
+ def python_to_timestamp(dt: datetime.datetime) -> timestamp_pb2.Timestamp: # type: ignore
59
+ """Convert a Python datetime object to a protobuf Timestamp."""
60
+ ts = timestamp_pb2.Timestamp() # type: ignore
61
+ ts.FromDatetime(dt)
62
+ return ts
63
+
64
+
65
+ def duration_to_python(d: duration_pb2.Duration) -> datetime.timedelta: # type: ignore
66
+ """Convert a protobuf Duration to a Python timedelta object."""
67
+ return d.ToTimedelta()
68
+
69
+
70
+ def python_to_duration(td: datetime.timedelta) -> duration_pb2.Duration: # type: ignore
71
+ """Convert a Python timedelta object to a protobuf Duration."""
72
+ d = duration_pb2.Duration() # type: ignore
73
+ d.FromTimedelta(td)
74
+ return d
75
+
76
+
77
+ def generate_converter(annotation: Type) -> Callable:
78
+ """
79
+ Returns a converter function to convert protobuf types to Python types.
80
+ This is used primarily when handling incoming requests.
81
+ """
82
+ # For primitive types
83
+ if annotation in (int, str, bool, bytes, float):
84
+ return primitiveProtoValueToPythonValue
85
+
86
+ # For enum types
87
+ if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
88
+
89
+ def enum_converter(value):
90
+ return annotation(value)
91
+
92
+ return enum_converter
93
+
94
+ # For datetime
95
+ if annotation == datetime.datetime:
96
+
97
+ def ts_converter(value: timestamp_pb2.Timestamp): # type: ignore
98
+ return value.ToDatetime()
99
+
100
+ return ts_converter
101
+
102
+ # For timedelta
103
+ if annotation == datetime.timedelta:
104
+
105
+ def dur_converter(value: duration_pb2.Duration): # type: ignore
106
+ return value.ToTimedelta()
107
+
108
+ return dur_converter
109
+
110
+ origin = get_origin(annotation)
111
+ if origin is not None:
112
+ # For seq types
113
+ if origin in (list, tuple):
114
+ item_converter = generate_converter(get_args(annotation)[0])
115
+
116
+ def seq_converter(value):
117
+ return [item_converter(v) for v in value]
118
+
119
+ return seq_converter
120
+
121
+ # For dict types
122
+ if origin is dict:
123
+ key_converter = generate_converter(get_args(annotation)[0])
124
+ value_converter = generate_converter(get_args(annotation)[1])
125
+
126
+ def dict_converter(value):
127
+ return {key_converter(k): value_converter(v) for k, v in value.items()}
128
+
129
+ return dict_converter
130
+
131
+ # For Message classes
132
+ if inspect.isclass(annotation) and issubclass(annotation, Message):
133
+ return generate_message_converter(annotation)
134
+
135
+ # For union types or other unsupported cases, just return the value as-is.
136
+ return primitiveProtoValueToPythonValue
137
+
138
+
139
+ def generate_message_converter(arg_type: Type[Message]) -> Callable:
140
+ """Return a converter function for protobuf -> Python Message."""
141
+ if arg_type is None or not issubclass(arg_type, Message):
142
+ raise TypeError("Request arg must be subclass of Message")
143
+
144
+ fields = arg_type.model_fields
145
+ converters = {
146
+ field: generate_converter(field_type.annotation) # type: ignore
147
+ for field, field_type in fields.items()
148
+ }
149
+
150
+ def converter(request):
151
+ rdict = {}
152
+ for field in fields.keys():
153
+ rdict[field] = converters[field](getattr(request, field))
154
+ return arg_type(**rdict)
155
+
156
+ return converter
157
+
158
+
159
+ def python_value_to_proto_value(field_type: Type, value):
160
+ """
161
+ Converts Python values to protobuf values.
162
+ Used primarily when constructing a response object.
163
+ """
164
+ # datetime.datetime -> Timestamp
165
+ if field_type == datetime.datetime:
166
+ return python_to_timestamp(value)
167
+
168
+ # datetime.timedelta -> Duration
169
+ if field_type == datetime.timedelta:
170
+ return python_to_duration(value)
171
+
172
+ # Default behavior: return the value as-is.
173
+ return value
174
+
175
+
176
+ ###############################################################################
177
+ # 2. Stub implementation
178
+ ###############################################################################
179
+
180
+
181
+ def connect_obj_with_stub(pb2_grpc_module, pb2_module, service_obj: object) -> type:
182
+ """
183
+ Connect a Python service object to a gRPC stub, generating server methods.
184
+ """
185
+ service_class = service_obj.__class__
186
+ stub_class_name = service_class.__name__ + "Servicer"
187
+ stub_class = getattr(pb2_grpc_module, stub_class_name)
188
+
189
+ class ConcreteServiceClass(stub_class):
190
+ pass
191
+
192
+ def implement_stub_method(method):
193
+ # Analyze method signature.
194
+ sig = inspect.signature(method)
195
+ arg_type = get_request_arg_type(sig)
196
+ # Convert request from protobuf to Python.
197
+ converter = generate_message_converter(arg_type)
198
+
199
+ response_type = sig.return_annotation
200
+ size_of_parameters = len(sig.parameters)
201
+
202
+ match size_of_parameters:
203
+ case 1:
204
+
205
+ def stub_method1(self, request, context, method=method):
206
+ try:
207
+ # Convert request to Python object
208
+ arg = converter(request)
209
+ # Invoke the actual method
210
+ resp_obj = method(arg)
211
+ # Convert the returned Python Message to a protobuf message
212
+ return convert_python_message_to_proto(
213
+ resp_obj, response_type, pb2_module
214
+ )
215
+ except ValidationError as e:
216
+ return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
217
+ except Exception as e:
218
+ return context.abort(grpc.StatusCode.INTERNAL, str(e))
219
+
220
+ return stub_method1
221
+
222
+ case 2:
223
+
224
+ def stub_method2(self, request, context, method=method):
225
+ try:
226
+ arg = converter(request)
227
+ resp_obj = method(arg, context)
228
+ return convert_python_message_to_proto(
229
+ resp_obj, response_type, pb2_module
230
+ )
231
+ except ValidationError as e:
232
+ return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
233
+ except Exception as e:
234
+ return context.abort(grpc.StatusCode.INTERNAL, str(e))
235
+
236
+ return stub_method2
237
+
238
+ case _:
239
+ raise Exception("Method must have exactly one or two parameters")
240
+
241
+ for method_name, method in get_rpc_methods(service_obj):
242
+ if method.__name__.startswith("_"):
243
+ continue
244
+
245
+ a_method = implement_stub_method(method)
246
+ setattr(ConcreteServiceClass, method_name, a_method)
247
+
248
+ return ConcreteServiceClass
249
+
250
+
251
+ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> type:
252
+ """
253
+ Connect a Python service object to a gRPC stub for async methods.
254
+ """
255
+ service_class = obj.__class__
256
+ stub_class_name = service_class.__name__ + "Servicer"
257
+ stub_class = getattr(pb2_grpc_module, stub_class_name)
258
+
259
+ class ConcreteServiceClass(stub_class):
260
+ pass
261
+
262
+ def implement_stub_method(method):
263
+ sig = inspect.signature(method)
264
+ arg_type = get_request_arg_type(sig)
265
+ converter = generate_message_converter(arg_type)
266
+ response_type = sig.return_annotation
267
+ size_of_parameters = len(sig.parameters)
268
+
269
+ if is_stream_type(response_type):
270
+ item_type = get_args(response_type)[0]
271
+ match size_of_parameters:
272
+ case 1:
273
+
274
+ async def stub_method_stream1(
275
+ self, request, context, method=method
276
+ ):
277
+ try:
278
+ arg = converter(request)
279
+ async for resp_obj in method(arg):
280
+ yield convert_python_message_to_proto(
281
+ resp_obj, item_type, pb2_module
282
+ )
283
+ except ValidationError as e:
284
+ await context.abort(
285
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
286
+ )
287
+ except Exception as e:
288
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
289
+
290
+ return stub_method_stream1
291
+ case 2:
292
+
293
+ async def stub_method_stream2(
294
+ self, request, context, method=method
295
+ ):
296
+ try:
297
+ arg = converter(request)
298
+ async for resp_obj in method(arg, context):
299
+ yield convert_python_message_to_proto(
300
+ resp_obj, item_type, pb2_module
301
+ )
302
+ except ValidationError as e:
303
+ await context.abort(
304
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
305
+ )
306
+ except Exception as e:
307
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
308
+
309
+ return stub_method_stream2
310
+ case _:
311
+ raise Exception("Method must have exactly one or two parameters")
312
+
313
+ match size_of_parameters:
314
+ case 1:
315
+
316
+ async def stub_method1(self, request, context, method=method):
317
+ try:
318
+ arg = converter(request)
319
+ resp_obj = await method(arg)
320
+ return convert_python_message_to_proto(
321
+ resp_obj, response_type, pb2_module
322
+ )
323
+ except ValidationError as e:
324
+ await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
325
+ except Exception as e:
326
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
327
+
328
+ return stub_method1
329
+
330
+ case 2:
331
+
332
+ async def stub_method2(self, request, context, method=method):
333
+ try:
334
+ arg = converter(request)
335
+ resp_obj = await method(arg, context)
336
+ return convert_python_message_to_proto(
337
+ resp_obj, response_type, pb2_module
338
+ )
339
+ except ValidationError as e:
340
+ await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
341
+ except Exception as e:
342
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
343
+
344
+ return stub_method2
345
+
346
+ case _:
347
+ raise Exception("Method must have exactly one or two parameters")
348
+
349
+ for method_name, method in get_rpc_methods(obj):
350
+ if method.__name__.startswith("_"):
351
+ continue
352
+
353
+ a_method = implement_stub_method(method)
354
+ setattr(ConcreteServiceClass, method_name, a_method)
355
+
356
+ return ConcreteServiceClass
357
+
358
+
359
+ def connect_obj_with_stub_async_connecpy(
360
+ connecpy_module, pb2_module, obj: object
361
+ ) -> type:
362
+ """
363
+ Connect a Python service object to a Connecpy stub for async methods.
364
+ """
365
+ service_class = obj.__class__
366
+ stub_class_name = service_class.__name__
367
+ stub_class = getattr(connecpy_module, stub_class_name)
368
+
369
+ class ConcreteServiceClass(stub_class):
370
+ pass
371
+
372
+ def implement_stub_method(method):
373
+ sig = inspect.signature(method)
374
+ arg_type = get_request_arg_type(sig)
375
+ converter = generate_message_converter(arg_type)
376
+ response_type = sig.return_annotation
377
+ size_of_parameters = len(sig.parameters)
378
+
379
+ match size_of_parameters:
380
+ case 1:
381
+
382
+ async def stub_method1(self, request, context, method=method):
383
+ try:
384
+ arg = converter(request)
385
+ resp_obj = await method(arg)
386
+ return convert_python_message_to_proto(
387
+ resp_obj, response_type, pb2_module
388
+ )
389
+ except ValidationError as e:
390
+ await context.abort(Errors.InvalidArgument, str(e))
391
+ except Exception as e:
392
+ await context.abort(Errors.Internal, str(e))
393
+
394
+ return stub_method1
395
+
396
+ case 2:
397
+
398
+ async def stub_method2(self, request, context, method=method):
399
+ try:
400
+ arg = converter(request)
401
+ resp_obj = await method(arg, context)
402
+ return convert_python_message_to_proto(
403
+ resp_obj, response_type, pb2_module
404
+ )
405
+ except ValidationError as e:
406
+ await context.abort(Errors.InvalidArgument, str(e))
407
+ except Exception as e:
408
+ await context.abort(Errors.Internal, str(e))
409
+
410
+ return stub_method2
411
+
412
+ case _:
413
+ raise Exception("Method must have exactly one or two parameters")
414
+
415
+ for method_name, method in get_rpc_methods(obj):
416
+ if method.__name__.startswith("_"):
417
+ continue
418
+ if not asyncio.iscoroutinefunction(method):
419
+ raise Exception("Method must be async", method_name)
420
+ a_method = implement_stub_method(method)
421
+ setattr(ConcreteServiceClass, method_name, a_method)
422
+
423
+ return ConcreteServiceClass
424
+
425
+
426
+ def convert_python_message_to_proto(
427
+ py_msg: Message, msg_type: Type, pb2_module
428
+ ) -> object:
429
+ """
430
+ Convert a Python Pydantic Message instance to a protobuf message instance.
431
+ Used for constructing a response.
432
+ """
433
+ # Before calling something like pb2_module.AResponseMessage(...),
434
+ # convert each field from Python to proto.
435
+ field_dict = {}
436
+ for name, field_info in msg_type.model_fields.items():
437
+ value = getattr(py_msg, name)
438
+ if value is None:
439
+ field_dict[name] = None
440
+ continue
441
+
442
+ field_type = field_info.annotation
443
+ field_dict[name] = python_value_to_proto(field_type, value, pb2_module)
444
+
445
+ # Retrieve the appropriate protobuf class dynamically
446
+ proto_class = getattr(pb2_module, msg_type.__name__)
447
+ return proto_class(**field_dict)
448
+
449
+
450
+ def python_value_to_proto(field_type: Type, value, pb2_module):
451
+ """
452
+ Perform Python->protobuf type conversion for each field value.
453
+ """
454
+ import inspect
455
+ import datetime
456
+
457
+ # If datetime
458
+ if field_type == datetime.datetime:
459
+ return python_to_timestamp(value)
460
+
461
+ # If timedelta
462
+ if field_type == datetime.timedelta:
463
+ return python_to_duration(value)
464
+
465
+ # If enum
466
+ if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
467
+ return value.value # proto3 enum is an int
468
+
469
+ origin = get_origin(field_type)
470
+ # If seq
471
+ if origin in (list, tuple):
472
+ inner_type = get_args(field_type)[0] # type: ignore
473
+ return [python_value_to_proto(inner_type, v, pb2_module) for v in value]
474
+
475
+ # If dict
476
+ if origin is dict:
477
+ key_type, val_type = get_args(field_type) # type: ignore
478
+ return {
479
+ python_value_to_proto(key_type, k, pb2_module): python_value_to_proto(
480
+ val_type, v, pb2_module
481
+ )
482
+ for k, v in value.items()
483
+ }
484
+
485
+ # If union -> oneof
486
+ if is_union_type(field_type):
487
+ # Flatten union and check which type matches. If matched, return converted value.
488
+ for sub_type in flatten_union(field_type):
489
+ if sub_type == datetime.datetime and isinstance(value, datetime.datetime):
490
+ return python_to_timestamp(value)
491
+ if sub_type == datetime.timedelta and isinstance(value, datetime.timedelta):
492
+ return python_to_duration(value)
493
+ if (
494
+ inspect.isclass(sub_type)
495
+ and issubclass(sub_type, enum.Enum)
496
+ and isinstance(value, enum.Enum)
497
+ ):
498
+ return value.value
499
+ if sub_type in (int, float, str, bool, bytes) and isinstance(
500
+ value, sub_type
501
+ ):
502
+ return value
503
+ if (
504
+ inspect.isclass(sub_type)
505
+ and issubclass(sub_type, Message)
506
+ and isinstance(value, Message)
507
+ ):
508
+ return convert_python_message_to_proto(value, sub_type, pb2_module)
509
+ return None
510
+
511
+ # If Message
512
+ if inspect.isclass(field_type) and issubclass(field_type, Message):
513
+ return convert_python_message_to_proto(value, field_type, pb2_module)
514
+
515
+ # If primitive
516
+ return value
517
+
518
+
519
+ ###############################################################################
520
+ # 3. Generating proto files (datetime->Timestamp, timedelta->Duration)
521
+ ###############################################################################
522
+
523
+
524
+ def is_enum_type(python_type: Type) -> bool:
525
+ """Return True if the given Python type is an enum."""
526
+ return inspect.isclass(python_type) and issubclass(python_type, enum.Enum)
527
+
528
+
529
+ def is_union_type(python_type: Type) -> bool:
530
+ """
531
+ Check if a given Python type is a Union type (including Python 3.10's UnionType).
532
+ """
533
+ if get_origin(python_type) is Union:
534
+ return True
535
+ if sys.version_info >= (3, 10):
536
+ import types
537
+
538
+ if type(python_type) is types.UnionType:
539
+ return True
540
+ return False
541
+
542
+
543
+ def flatten_union(field_type: Type) -> list[Type]:
544
+ """Recursively flatten nested Unions into a single list of types."""
545
+ if is_union_type(field_type):
546
+ results = []
547
+ for arg in get_args(field_type):
548
+ results.extend(flatten_union(arg))
549
+ return results
550
+ else:
551
+ return [field_type]
552
+
553
+
554
def protobuf_type_mapping(python_type: Type) -> str | type | None:
    """
    Map a Python type to a protobuf type name/class.
    Includes support for Timestamp and Duration.

    Returns:
        A proto type name (str) for primitives, containers and well-known
        types; the class itself for enums and Message subclasses (callers
        substitute the class name when rendering); or None for Union types
        (handled as oneof) and unsupported types.
    """
    import datetime

    mapping = {
        int: "int32",
        str: "string",
        bool: "bool",
        bytes: "bytes",
        float: "float",
    }

    def _as_proto_name(proto_type: str | type) -> str:
        # Enum/Message lookups return the class object; nested positions
        # (repeated/map) need the textual type name, not the class repr.
        return proto_type if isinstance(proto_type, str) else proto_type.__name__

    if python_type == datetime.datetime:
        return "google.protobuf.Timestamp"

    if python_type == datetime.timedelta:
        return "google.protobuf.Duration"

    if is_enum_type(python_type):
        return python_type  # Will be defined as enum later

    if is_union_type(python_type):
        return None  # Handled separately as oneof

    if hasattr(python_type, "__origin__"):
        if python_type.__origin__ in (list, tuple):
            inner_type = python_type.__args__[0]
            inner_proto_type = protobuf_type_mapping(inner_type)
            if inner_proto_type:
                # Fix: embed the type *name* so repeated message/enum fields
                # render valid proto instead of "repeated <class '...'>".
                return f"repeated {_as_proto_name(inner_proto_type)}"
        elif python_type.__origin__ is dict:
            key_type = python_type.__args__[0]
            value_type = python_type.__args__[1]
            key_proto_type = protobuf_type_mapping(key_type)
            value_proto_type = protobuf_type_mapping(value_type)
            if key_proto_type and value_proto_type:
                # Same name-vs-class fix for map key/value positions.
                return (
                    f"map<{_as_proto_name(key_proto_type)}, "
                    f"{_as_proto_name(value_proto_type)}>"
                )

    if inspect.isclass(python_type) and issubclass(python_type, Message):
        return python_type

    return mapping.get(python_type)  # type: ignore
599
+
600
+
601
def comment_out(docstr: str) -> tuple[str, ...]:
    """Convert docstrings into commented-out lines in a .proto file."""
    # Empty docstrings and pydantic's auto-injected usage banner produce no comments.
    if not docstr or docstr.startswith(
        "Usage docs: https://docs.pydantic.dev/2.10/concepts/models/"
    ):
        return tuple()

    commented = []
    for line in docstr.split("\n"):
        commented.append("//" if line == "" else f"// {line}")
    return tuple(commented)
610
+
611
+
612
def indent_lines(lines, indentation=" "):
    """Prefix each line with *indentation* and join the result with newlines."""
    return "\n".join(f"{indentation}{line}" for line in lines)
615
+
616
+
617
def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
    """Generate a protobuf enum definition from a Python enum."""
    member_lines = [
        f" {member.name} = {member.value};"
        for member in enum_type.__members__.values()
    ]
    body = "\n".join(member_lines)
    return f"enum {enum_type.__name__} {{\n{body}\n}}"
627
+
628
+
629
def generate_oneof_definition(
    field_name: str, union_args: list[Type], start_index: int
) -> tuple[list[str], int]:
    """
    Generate a oneof block in protobuf for a union field.
    Returns a tuple of the definition lines and the updated field index.
    """
    definition = [f"oneof {field_name} {{"]
    next_index = start_index
    for member_type in union_args:
        proto_typename = protobuf_type_mapping(member_type)
        if proto_typename is None:
            raise Exception(f"Nested Union not flattened properly: {member_type}")

        # Enums and Messages are referenced by their class name.
        if is_enum_type(member_type) or (
            inspect.isclass(member_type) and issubclass(member_type, Message)
        ):
            proto_typename = member_type.__name__

        # Derive a unique member field name; dots (well-known types) are not
        # legal in identifiers, so replace them.
        field_alias = f"{field_name}_{proto_typename.replace('.', '_')}"
        definition.append(f" {proto_typename} {field_alias} = {next_index};")
        next_index += 1
    definition.append("}")
    return definition, next_index
655
+
656
+
657
def generate_message_definition(
    message_type: Type[Message],
    done_enums: set,
    done_messages: set,
) -> tuple[str, list[type]]:
    """
    Generate a protobuf message definition for a Pydantic-based Message class.
    Also returns any referenced types (enums, messages) that need to be defined.

    Args:
        message_type: The pydantic model to render as a proto message.
        done_enums: Enum types already emitted elsewhere (not re-collected).
        done_messages: Message types already emitted elsewhere (not re-collected).

    Returns:
        (message_definition_text, referenced_types_still_to_define)

    Raises:
        Exception: if a field has no annotation or an unsupported type.
    """
    fields = []  # rendered field/comment lines, later indented into the message
    refs = []  # enum/Message types referenced here but not yet defined
    pydantic_fields = message_type.model_fields
    index = 1  # proto field numbers start at 1

    for field_name, field_info in pydantic_fields.items():
        field_type = field_info.annotation
        if field_type is None:
            raise Exception(f"Field {field_name} has no type annotation.")

        if is_union_type(field_type):
            # Union fields become a oneof block; the oneof consumes one field
            # number per member, so the running index is advanced accordingly.
            union_args = flatten_union(field_type)
            oneof_lines, new_index = generate_oneof_definition(
                field_name, union_args, index
            )
            fields.extend(oneof_lines)
            index = new_index

            # Collect union members that still need their own definitions.
            for utype in union_args:
                if is_enum_type(utype) and utype not in done_enums:
                    refs.append(utype)
                elif (
                    inspect.isclass(utype)
                    and issubclass(utype, Message)
                    and utype not in done_messages
                ):
                    refs.append(utype)

        else:
            proto_typename = protobuf_type_mapping(field_type)
            if proto_typename is None:
                raise Exception(f"Type {field_type} is not supported.")

            # Enum/Message mappings return the class; render its name and
            # queue the type for definition if not yet emitted.
            if is_enum_type(field_type):
                proto_typename = field_type.__name__
                if field_type not in done_enums:
                    refs.append(field_type)
            elif inspect.isclass(field_type) and issubclass(field_type, Message):
                proto_typename = field_type.__name__
                if field_type not in done_messages:
                    refs.append(field_type)

            # Surface the pydantic field description and constraints as
            # proto comments above the field.
            if field_info.description:
                fields.append("// " + field_info.description)
            if field_info.metadata:
                fields.append("// Constraint:")
                # NOTE(review): annotated_types is not among the imports
                # visible at the top of this file — confirm it is imported,
                # otherwise this match raises NameError at runtime.
                for metadata_item in field_info.metadata:
                    match type(metadata_item):
                        case annotated_types.Ge:
                            fields.append(
                                "// greater than or equal to " + str(metadata_item.ge)
                            )
                        case annotated_types.Le:
                            fields.append(
                                "// less than or equal to " + str(metadata_item.le)
                            )
                        case annotated_types.Gt:
                            fields.append("// greater than " + str(metadata_item.gt))
                        case annotated_types.Lt:
                            fields.append("// less than " + str(metadata_item.lt))
                        case annotated_types.MultipleOf:
                            fields.append(
                                "// multiple of " + str(metadata_item.multiple_of)
                            )
                        case annotated_types.Len:
                            fields.append("// length of " + str(metadata_item.len))
                        case annotated_types.MinLen:
                            fields.append(
                                "// minimum length of " + str(metadata_item.min_len)
                            )
                        case annotated_types.MaxLen:
                            fields.append(
                                "// maximum length of " + str(metadata_item.max_len)
                            )
                        case _:
                            # Unknown constraint kinds are still surfaced verbatim.
                            fields.append("// " + str(metadata_item))

            fields.append(f"{proto_typename} {field_name} = {index};")
            index += 1

    msg_def = f"message {message_type.__name__} {{\n{indent_lines(fields)}\n}}"
    return msg_def, refs
748
+
749
+
750
def is_stream_type(annotation: Type) -> bool:
    """Return True when *annotation* is AsyncIterator[...] (a streaming response)."""
    origin = get_origin(annotation)
    return origin is AsyncIterator
752
+
753
+
754
def is_generic_alias(annotation: Type) -> bool:
    """Return True when *annotation* is a parameterized generic (has an origin)."""
    return get_origin(annotation) is not None
756
+
757
+
758
def generate_proto(obj: object, package_name: str = "") -> str:
    """
    Generate a .proto definition from a service class.
    Automatically handles Timestamp and Duration usage.

    Args:
        obj: Service instance; its public methods become RPCs.
        package_name: Optional proto package. When empty, it is derived from
            the class name ("FooService" -> "foo.v1").

    Returns:
        The full .proto file content as a string.
    """
    import datetime

    service_class = obj.__class__
    service_name = service_class.__name__
    service_docstr = inspect.getdoc(service_class)
    service_comment = "\n".join(comment_out(service_docstr)) if service_docstr else ""

    rpc_definitions = []  # "rpc Name (Req) returns (Res);" lines (+ comments)
    all_type_definitions = []  # message/enum blocks, in discovery order
    done_messages = set()
    done_enums = set()

    # Track whether any field uses datetime/timedelta so the matching
    # well-known-type imports are emitted exactly once.
    uses_timestamp = False
    uses_duration = False

    def check_and_set_well_known_types(py_type: Type):
        nonlocal uses_timestamp, uses_duration
        if py_type == datetime.datetime:
            uses_timestamp = True
        if py_type == datetime.timedelta:
            uses_duration = True

    for method_name, method in get_rpc_methods(obj):
        # Skip private helpers (underscore-prefixed methods).
        if method.__name__.startswith("_"):
            continue

        method_sig = inspect.signature(method)
        request_type = get_request_arg_type(method_sig)
        response_type = method_sig.return_annotation

        # Recursively generate message definitions
        message_types = [request_type, response_type]
        while message_types:
            mt = message_types.pop()
            if mt in done_messages:
                continue
            done_messages.add(mt)

            # AsyncIterator[X] marks a server-streaming response; only the
            # item type X needs a message definition.
            if is_stream_type(mt):
                item_type = get_args(mt)[0]
                message_types.append(item_type)
                continue

            # Scan fields (including union members) for well-known types.
            for _, field_info in mt.model_fields.items():
                t = field_info.annotation
                if is_union_type(t):
                    for sub_t in flatten_union(t):
                        check_and_set_well_known_types(sub_t)
                else:
                    check_and_set_well_known_types(t)

            msg_def, refs = generate_message_definition(mt, done_enums, done_messages)
            mt_doc = inspect.getdoc(mt)
            if mt_doc:
                for comment_line in comment_out(mt_doc):
                    all_type_definitions.append(comment_line)

            all_type_definitions.append(msg_def)
            all_type_definitions.append("")

            # Emit referenced enums immediately; queue referenced messages.
            for r in refs:
                if is_enum_type(r) and r not in done_enums:
                    done_enums.add(r)
                    enum_def = generate_enum_definition(r)
                    all_type_definitions.append(enum_def)
                    all_type_definitions.append("")
                elif issubclass(r, Message) and r not in done_messages:
                    message_types.append(r)

        method_docstr = inspect.getdoc(method)
        if method_docstr:
            for comment_line in comment_out(method_docstr):
                rpc_definitions.append(comment_line)

        if is_stream_type(response_type):
            item_type = get_args(response_type)[0]
            rpc_definitions.append(
                f"rpc {method_name} ({request_type.__name__}) returns (stream {item_type.__name__});"
            )
        else:
            rpc_definitions.append(
                f"rpc {method_name} ({request_type.__name__}) returns ({response_type.__name__});"
            )

    # Derive the default package name: "FooService" -> "foo.v1".
    if not package_name:
        if service_name.endswith("Service"):
            package_name = service_name[: -len("Service")]
        else:
            package_name = service_name
        package_name = package_name.lower() + ".v1"

    imports = []
    if uses_timestamp:
        imports.append('import "google/protobuf/timestamp.proto";')
    if uses_duration:
        imports.append('import "google/protobuf/duration.proto";')

    import_block = "\n".join(imports)
    if import_block:
        import_block += "\n"

    proto_definition = f"""syntax = "proto3";

package {package_name};

{import_block}{service_comment}
service {service_name} {{
{indent_lines(rpc_definitions)}
}}

{indent_lines(all_type_definitions, "")}
"""
    return proto_definition
876
+
877
+
878
def generate_grpc_code(proto_file, grpc_python_out) -> types.ModuleType | None:
    """
    Execute the protoc command to generate Python gRPC code from the .proto file.

    Args:
        proto_file: Path to the .proto file, relative to the current directory.
        grpc_python_out: Output directory for the generated *_pb2_grpc.py file.

    Returns:
        The loaded *_pb2_grpc module on success, or None if generation or
        loading failed.
    """
    command = f"-I. --grpc_python_out={grpc_python_out} {proto_file}"
    exit_code = protoc.main(command.split())
    if exit_code != 0:
        return None

    base = os.path.splitext(proto_file)[0]
    generated_pb2_grpc_file = f"{base}_pb2_grpc.py"

    # Make the output directory importable so the generated module can
    # resolve its own "<name>_pb2" import.
    if grpc_python_out not in sys.path:
        sys.path.append(grpc_python_out)

    spec = importlib.util.spec_from_file_location(
        generated_pb2_grpc_file, os.path.join(grpc_python_out, generated_pb2_grpc_file)
    )
    if spec is None:
        return None
    pb2_grpc_module = importlib.util.module_from_spec(spec)
    if spec.loader is None:
        return None
    spec.loader.exec_module(pb2_grpc_module)

    return pb2_grpc_module
905
+
906
+
907
def generate_connecpy_code(
    proto_file: str, connecpy_out: str
) -> types.ModuleType | None:
    """
    Execute the protoc command to generate Python Connecpy code from the .proto file.

    Args:
        proto_file: Path to the .proto file, relative to the current directory.
        connecpy_out: Output directory for the generated *_connecpy.py file.

    Returns:
        The loaded *_connecpy module on success, or None if generation or
        loading failed.
    """
    # Requires the protoc connecpy plugin to be installed and on PATH.
    command = f"-I. --connecpy_out={connecpy_out} {proto_file}"
    exit_code = protoc.main(command.split())
    if exit_code != 0:
        return None

    base = os.path.splitext(proto_file)[0]
    generated_connecpy_file = f"{base}_connecpy.py"

    # Make the output directory importable so the generated module can
    # resolve its own "<name>_pb2" import.
    if connecpy_out not in sys.path:
        sys.path.append(connecpy_out)

    spec = importlib.util.spec_from_file_location(
        generated_connecpy_file, os.path.join(connecpy_out, generated_connecpy_file)
    )
    if spec is None:
        return None
    connecpy_module = importlib.util.module_from_spec(spec)
    if spec.loader is None:
        return None
    spec.loader.exec_module(connecpy_module)

    return connecpy_module
936
+
937
+
938
def generate_pb_code(
    proto_file: str, python_out: str, pyi_out: str
) -> types.ModuleType | None:
    """
    Execute the protoc command to generate Python protobuf code from the .proto file.

    Args:
        proto_file: Path to the .proto file, relative to the current directory.
        python_out: Output directory for the generated *_pb2.py file.
        pyi_out: Output directory for the generated *_pb2.pyi stub file.

    Returns:
        The loaded *_pb2 module on success, or None if generation or
        loading failed.
    """
    command = f"-I. --python_out={python_out} --pyi_out={pyi_out} {proto_file}"
    exit_code = protoc.main(command.split())
    if exit_code != 0:
        return None

    base = os.path.splitext(proto_file)[0]
    generated_pb2_file = f"{base}_pb2.py"

    # Make the output directory importable for downstream generated modules.
    if python_out not in sys.path:
        sys.path.append(python_out)

    spec = importlib.util.spec_from_file_location(
        generated_pb2_file, os.path.join(python_out, generated_pb2_file)
    )
    if spec is None:
        return None
    pb2_module = importlib.util.module_from_spec(spec)
    if spec.loader is None:
        return None
    spec.loader.exec_module(pb2_module)

    return pb2_module
967
+
968
+
969
def get_request_arg_type(sig):
    """
    Return the type annotation of the first parameter (request) of a method.

    Raises:
        Exception: if the signature does not have exactly one or two parameters.
    """
    params = list(sig.parameters.values())
    if len(params) not in (1, 2):
        raise Exception("Method must have exactly one or two parameters")
    return params[0].annotation
975
+
976
+
977
def get_rpc_methods(obj: object) -> list[tuple[str, types.MethodType]]:
    """
    Retrieve the list of RPC methods from a service object.
    The method name is converted to PascalCase for .proto compatibility.
    """

    def to_pascal_case(name: str) -> str:
        return "".join(part.capitalize() for part in name.split("_"))

    methods = []
    for attr_name in dir(obj):
        candidate = getattr(obj, attr_name)
        if inspect.ismethod(candidate):
            methods.append((to_pascal_case(attr_name), candidate))
    return methods
991
+
992
+
993
def is_skip_generation() -> bool:
    """True when PYDANTIC_RPC_SKIP_GENERATION is set to "true" (case-insensitive)."""
    flag = os.getenv("PYDANTIC_RPC_SKIP_GENERATION", "false")
    return flag.lower() == "true"
996
+
997
+
998
def generate_and_compile_proto(obj: object, package_name: str = ""):
    """
    Return (pb2_grpc_module, pb2_module) for *obj*'s service class.

    When PYDANTIC_RPC_SKIP_GENERATION is true, previously generated modules
    are imported and reused; if they are missing, we fall through and
    regenerate. Otherwise the .proto file is written to the current
    directory and compiled with protoc.

    Raises:
        Exception: if protoc fails to produce either module.
    """
    if is_skip_generation():
        import importlib

        module_stem = obj.__class__.__name__.lower()
        try:
            # Reuse pre-generated modules when they are importable.
            pb2_module = importlib.import_module(f"{module_stem}_pb2")
            pb2_grpc_module = importlib.import_module(f"{module_stem}_pb2_grpc")
            return pb2_grpc_module, pb2_module
        except ImportError:
            # Fix: import_module raises when the modules are absent; fall
            # through to generation instead of crashing.
            pass

    klass = obj.__class__
    proto_file = generate_proto(obj, package_name)
    proto_file_name = klass.__name__.lower() + ".proto"

    with open(proto_file_name, "w", encoding="utf-8") as f:
        f.write(proto_file)

    gen_pb = generate_pb_code(proto_file_name, ".", ".")
    if gen_pb is None:
        raise Exception("Generating pb code")

    gen_grpc = generate_grpc_code(proto_file_name, ".")
    if gen_grpc is None:
        raise Exception("Generating grpc code")
    return gen_grpc, gen_pb
1027
+
1028
+
1029
def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = ""):
    """
    Return (connecpy_module, pb2_module) for *obj*'s service class.

    When PYDANTIC_RPC_SKIP_GENERATION is true, previously generated modules
    are imported and reused; if they are missing, we fall through and
    regenerate. Otherwise the .proto file is written to the current
    directory and compiled with protoc (+ the connecpy plugin).

    Raises:
        Exception: if protoc fails to produce either module.
    """
    if is_skip_generation():
        import importlib

        module_stem = obj.__class__.__name__.lower()
        try:
            # Reuse pre-generated modules when they are importable.
            pb2_module = importlib.import_module(f"{module_stem}_pb2")
            connecpy_module = importlib.import_module(f"{module_stem}_connecpy")
            return connecpy_module, pb2_module
        except ImportError:
            # Fix: import_module raises when the modules are absent; fall
            # through to generation instead of crashing.
            pass

    klass = obj.__class__
    proto_file = generate_proto(obj, package_name)
    proto_file_name = klass.__name__.lower() + ".proto"

    with open(proto_file_name, "w", encoding="utf-8") as f:
        f.write(proto_file)

    gen_pb = generate_pb_code(proto_file_name, ".", ".")
    if gen_pb is None:
        raise Exception("Generating pb code")

    gen_connecpy = generate_connecpy_code(proto_file_name, ".")
    if gen_connecpy is None:
        raise Exception("Generating Connecpy code")
    return gen_connecpy, gen_pb
1058
+
1059
+
1060
+ ###############################################################################
1061
+ # 4. Server Implementations
1062
+ ###############################################################################
1063
+
1064
+
1065
class Server:
    """A simple gRPC server that uses ThreadPoolExecutor for concurrency."""

    def __init__(self, max_workers: int = 8, *interceptors) -> None:
        # Extra positional arguments are installed as gRPC server interceptors.
        self._server = grpc.server(
            futures.ThreadPoolExecutor(max_workers), interceptors=interceptors
        )
        self._service_names = []  # fully-qualified service names, for reflection
        self._package_name = ""  # default proto package for mounted services
        self._port = 50051  # default listen port

    def set_package_name(self, package_name: str):
        """Set the package name for .proto generation."""
        self._package_name = package_name

    def set_port(self, port: int):
        """Set the port number for the gRPC server."""
        self._port = port

    def mount(self, obj: object, package_name: str = ""):
        """Generate and compile proto files, then mount the service implementation."""
        pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
        self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)

    def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
        """Connect the compiled gRPC modules with the service implementation."""
        concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
        service_name = obj.__class__.__name__
        service_impl = concreteServiceClass()
        # The generated pb2_grpc module exposes add_<Name>Servicer_to_server.
        getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
            service_impl, self._server
        )
        # Record the fully-qualified name (package.Service) for reflection.
        full_service_name = pb2_module.DESCRIPTOR.services_by_name[
            service_name
        ].full_name
        self._service_names.append(full_service_name)

    def run(self, *objs):
        """
        Mount multiple services and run the gRPC server with reflection and health check.
        Press Ctrl+C or send SIGTERM to stop.
        """
        for obj in objs:
            self.mount(obj, self._package_name)

        SERVICE_NAMES = (
            health_pb2.DESCRIPTOR.services_by_name["Health"].full_name,
            reflection.SERVICE_NAME,
            *self._service_names,
        )
        health_servicer = HealthServicer()
        health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
        reflection.enable_server_reflection(SERVICE_NAMES, self._server)

        self._server.add_insecure_port(f"[::]:{self._port}")
        self._server.start()

        # NOTE(review): the "signal" parameter shadows the signal module
        # inside the handler body (harmless here, but easy to trip over).
        def handle_signal(signal, frame):
            print("Received shutdown signal...")
            # Allow in-flight RPCs up to 10 seconds to finish.
            self._server.stop(grace=10)
            print("gRPC server shutdown.")
            sys.exit(0)

        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)

        print("gRPC server is running...")
        # Park the main thread; shutdown happens inside handle_signal.
        while True:
            time.sleep(86400)
1134
+
1135
+
1136
class AsyncIOServer:
    """An async gRPC server using asyncio."""

    def __init__(self, *interceptors) -> None:
        # Positional arguments are installed as gRPC aio interceptors.
        self._server = grpc.aio.server(interceptors=interceptors)
        self._service_names = []  # fully-qualified service names, for reflection
        self._package_name = ""  # default proto package for mounted services
        self._port = 50051  # default listen port

    def set_package_name(self, package_name: str):
        """Set the package name for .proto generation."""
        self._package_name = package_name

    def set_port(self, port: int):
        """Set the port number for the async gRPC server."""
        self._port = port

    def mount(self, obj: object, package_name: str = ""):
        """Generate and compile proto files, then mount the service implementation (async)."""
        pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
        self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)

    def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
        """Connect the compiled gRPC modules with the async service implementation."""
        concreteServiceClass = connect_obj_with_stub_async(
            pb2_grpc_module, pb2_module, obj
        )
        service_name = obj.__class__.__name__
        service_impl = concreteServiceClass()
        # The generated pb2_grpc module exposes add_<Name>Servicer_to_server.
        getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")(
            service_impl, self._server
        )
        # Record the fully-qualified name (package.Service) for reflection.
        full_service_name = pb2_module.DESCRIPTOR.services_by_name[
            service_name
        ].full_name
        self._service_names.append(full_service_name)

    async def run(self, *objs):
        """
        Mount multiple async services and run the gRPC server with reflection and health check.
        Press Ctrl+C or send SIGTERM to stop.
        """
        for obj in objs:
            self.mount(obj, self._package_name)

        SERVICE_NAMES = (
            health_pb2.DESCRIPTOR.services_by_name["Health"].full_name,
            reflection.SERVICE_NAME,
            *self._service_names,
        )
        health_servicer = HealthServicer()
        health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
        reflection.enable_server_reflection(SERVICE_NAMES, self._server)

        self._server.add_insecure_port(f"[::]:{self._port}")
        await self._server.start()

        # Signal handlers flip this event; the coroutine below waits on it.
        shutdown_event = asyncio.Event()

        def shutdown(signum, frame):
            print("Received shutdown signal...")
            shutdown_event.set()

        # NOTE(review): asyncio code usually registers handlers via
        # loop.add_signal_handler; signal.signal only works on the main
        # thread — confirm this server is always started there.
        for s in [signal.SIGTERM, signal.SIGINT]:
            signal.signal(s, shutdown)

        print("gRPC server is running...")
        await shutdown_event.wait()
        # Allow in-flight RPCs up to 10 seconds to finish.
        await self._server.stop(10)
        print("gRPC server shutdown.")
1206
+
1207
+
1208
class WSGIApp:
    """
    A WSGI-compatible application that can serve gRPC via sonora's grpcWSGI.
    Useful for embedding gRPC within an existing WSGI stack.
    """

    def __init__(self, app):
        self._app = grpcWSGI(app)
        self._service_names = []
        self._package_name = ""

    def mount(self, obj: object, package_name: str = ""):
        """Compile the proto for *obj*, then register its service implementation."""
        pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
        self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)

    def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
        """Register *obj* with this app using already-compiled gRPC modules."""
        servicer_class = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
        service_name = obj.__class__.__name__
        register = getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")
        register(servicer_class(), self._app)
        descriptor = pb2_module.DESCRIPTOR.services_by_name[service_name]
        self._service_names.append(descriptor.full_name)

    def mount_objs(self, *objs):
        """Mount multiple service objects into this WSGI app."""
        for obj in objs:
            self.mount(obj, self._package_name)

    def __call__(self, environ, start_response):
        """WSGI entry point."""
        return self._app(environ, start_response)
1245
+
1246
+
1247
class ASGIApp:
    """
    An ASGI-compatible application that can serve gRPC via sonora's grpcASGI.
    Useful for embedding gRPC within an existing ASGI stack.
    """

    def __init__(self, app):
        self._app = grpcASGI(app)
        self._service_names = []
        self._package_name = ""

    def mount(self, obj: object, package_name: str = ""):
        """Compile the proto for *obj*, then register its async service implementation."""
        pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
        self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)

    def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
        """Register *obj* with this app using already-compiled gRPC modules."""
        servicer_class = connect_obj_with_stub_async(
            pb2_grpc_module, pb2_module, obj
        )
        service_name = obj.__class__.__name__
        register = getattr(pb2_grpc_module, f"add_{service_name}Servicer_to_server")
        register(servicer_class(), self._app)
        descriptor = pb2_module.DESCRIPTOR.services_by_name[service_name]
        self._service_names.append(descriptor.full_name)

    def mount_objs(self, *objs):
        """Mount multiple service objects into this ASGI app."""
        for obj in objs:
            self.mount(obj, self._package_name)

    async def __call__(self, scope, receive, send):
        """ASGI entry point."""
        await self._app(scope, receive, send)
1286
+
1287
+
1288
def get_connecpy_server_class(connecpy_module, service_name):
    """Look up the generated `<ServiceName>Server` class in a connecpy module."""
    server_class_name = f"{service_name}Server"
    return getattr(connecpy_module, server_class_name)
1290
+
1291
+
1292
class ConnecpyASGIApp:
    """
    An ASGI-compatible application that can serve Connect-RPC via Connecpy's ConnecpyASGIApp.
    """

    def __init__(self):
        self._app = ConnecpyASGI()
        self._service_names = []
        self._package_name = ""

    def mount(self, obj: object, package_name: str = ""):
        """Compile the proto for *obj*, then register its async Connect-RPC service."""
        connecpy_module, pb2_module = generate_and_compile_proto_using_connecpy(
            obj, package_name
        )
        self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)

    def mount_using_pb2_modules(self, connecpy_module, pb2_module, obj: object):
        """Register *obj* with this app using already-compiled connecpy/pb2 modules."""
        servicer_class = connect_obj_with_stub_async_connecpy(
            connecpy_module, pb2_module, obj
        )
        service_name = obj.__class__.__name__
        server_class = get_connecpy_server_class(connecpy_module, service_name)
        self._app.add_service(server_class(service=servicer_class()))
        descriptor = pb2_module.DESCRIPTOR.services_by_name[service_name]
        self._service_names.append(descriptor.full_name)

    def mount_objs(self, *objs):
        """Mount multiple service objects into this ASGI app."""
        for obj in objs:
            self.mount(obj, self._package_name)

    async def __call__(self, scope, receive, send):
        """ASGI entry point."""
        await self._app(scope, receive, send)
1331
+
1332
+
1333
if __name__ == "__main__":
    """
    If executed as a script, generate the .proto files for a given class.
    Usage: python core.py some_module.py SomeServiceClass
    """
    # argv[1]: path to the Python file that defines the service class.
    py_file_name = sys.argv[1]
    # argv[2]: name of the service class to generate a .proto for.
    class_name = sys.argv[2]
    # Import by bare module name — assumes the file's directory is importable.
    module_name = os.path.splitext(basename(py_file_name))[0]
    module = importlib.import_module(module_name)
    klass = getattr(module, class_name)
    # Instantiate with no arguments, then emit and compile the .proto files
    # into the current working directory.
    generate_and_compile_proto(klass())