pydantic-rpc 0.6.1__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
pydantic_rpc/core.py CHANGED
@@ -1,5 +1,5 @@
1
- import annotated_types
2
1
  import asyncio
2
+ import datetime
3
3
  import enum
4
4
  import importlib.util
5
5
  import inspect
@@ -8,33 +8,35 @@ import signal
8
8
  import sys
9
9
  import time
10
10
  import types
11
- import datetime
11
+ from typing import Union
12
+ from collections.abc import AsyncIterator, Awaitable, Callable
12
13
  from concurrent import futures
14
+ from pathlib import Path
13
15
  from posixpath import basename
14
16
  from typing import (
15
- Callable,
16
- Type,
17
+ Any,
18
+ TypeAlias,
17
19
  get_args,
18
20
  get_origin,
19
- Union,
20
- TypeAlias,
21
+ cast,
22
+ TypeGuard,
21
23
  )
22
- from collections.abc import AsyncIterator
23
24
 
25
+ import annotated_types
24
26
  import grpc
27
+ from grpc import ServicerContext
28
+ import grpc_tools
29
+ from connecpy.code import Code as Errors
30
+
31
+ # Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
32
+ from google.protobuf import duration_pb2, timestamp_pb2, empty_pb2
25
33
  from grpc_health.v1 import health_pb2, health_pb2_grpc
26
34
  from grpc_health.v1.health import HealthServicer
27
35
  from grpc_reflection.v1alpha import reflection
28
36
  from grpc_tools import protoc
29
37
  from pydantic import BaseModel, ValidationError
30
- from sonora.wsgi import grpcWSGI
31
38
  from sonora.asgi import grpcASGI
32
- from connecpy.asgi import ConnecpyASGIApp as ConnecpyASGI
33
- from connecpy.errors import Errors
34
- from connecpy.wsgi import ConnecpyWSGIApp as ConnecpyWSGI
35
-
36
- # Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
37
- from google.protobuf import timestamp_pb2, duration_pb2
39
+ from sonora.wsgi import grpcWSGI
38
40
 
39
41
  ###############################################################################
40
42
  # 1. Message definitions & converter extensions
@@ -45,8 +47,57 @@ from google.protobuf import timestamp_pb2, duration_pb2
45
47
 
46
48
  Message: TypeAlias = BaseModel
47
49
 
50
+ # Cache for serializer checks
51
+ _has_serializers_cache: dict[type[Message], bool] = {}
52
+ _has_field_serializers_cache: dict[type[Message], set[str]] = {}
53
+
54
+
55
+ # Serializer strategy configuration
56
+ class SerializerStrategy(enum.Enum):
57
+ """Strategy for applying Pydantic serializers."""
58
+
59
+ NONE = "none" # No serializers applied
60
+ SHALLOW = "shallow" # Only top-level serializers
61
+ DEEP = "deep" # Nested serializers too (default)
62
+
63
+
64
+ # Get strategy from environment variable
65
+ _SERIALIZER_STRATEGY = SerializerStrategy(
66
+ os.getenv("PYDANTIC_RPC_SERIALIZER_STRATEGY", "deep").lower()
67
+ )
68
+
69
+
70
+ def has_serializers(msg_type: type[Message]) -> bool:
71
+ """Check if a Message type has any serializers (cached for performance)."""
72
+ if msg_type not in _has_serializers_cache:
73
+ # Check for both field and model serializers
74
+ has_field = bool(getattr(msg_type, "__pydantic_serializer__", None))
75
+ has_model = bool(
76
+ getattr(msg_type, "model_dump", None) and msg_type != BaseModel
77
+ )
78
+ _has_serializers_cache[msg_type] = has_field or has_model
79
+ return _has_serializers_cache[msg_type]
80
+
81
+
82
+ def get_field_serializers(msg_type: type[Message]) -> set[str]:
83
+ """Get the set of fields that have serializers (cached for performance)."""
84
+ if msg_type not in _has_field_serializers_cache:
85
+ fields_with_serializers = set()
86
+ # Check model_fields_set or similar for fields with serializers
87
+ if hasattr(msg_type, "__pydantic_decorators__"):
88
+ decorators = msg_type.__pydantic_decorators__
89
+ if hasattr(decorators, "field_serializers"):
90
+ fields_with_serializers = set(decorators.field_serializers.keys())
91
+ _has_field_serializers_cache[msg_type] = fields_with_serializers
92
+ return _has_field_serializers_cache[msg_type]
93
+
94
+
95
+ def is_none_type(annotation: Any) -> TypeGuard[type[None] | None]:
96
+ """Check if annotation represents None/NoneType (handles both None and type(None))."""
97
+ return annotation is None or annotation is type(None)
98
+
48
99
 
49
- def primitiveProtoValueToPythonValue(value):
100
+ def primitiveProtoValueToPythonValue(value: Any):
50
101
  # Returns the value as-is (primitive type).
51
102
  return value
52
103
 
@@ -75,11 +126,20 @@ def python_to_duration(td: datetime.timedelta) -> duration_pb2.Duration: # type
75
126
  return d
76
127
 
77
128
 
78
- def generate_converter(annotation: Type) -> Callable:
129
+ def generate_converter(annotation: type[Any] | None) -> Callable[[Any], Any]:
79
130
  """
80
131
  Returns a converter function to convert protobuf types to Python types.
81
132
  This is used primarily when handling incoming requests.
82
133
  """
134
+ # For NoneType (Empty messages)
135
+ if is_none_type(annotation):
136
+
137
+ def empty_converter(value: empty_pb2.Empty): # type: ignore
138
+ _ = value
139
+ return None
140
+
141
+ return empty_converter
142
+
83
143
  # For primitive types
84
144
  if annotation in (int, str, bool, bytes, float):
85
145
  return primitiveProtoValueToPythonValue
@@ -87,7 +147,7 @@ def generate_converter(annotation: Type) -> Callable:
87
147
  # For enum types
88
148
  if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
89
149
 
90
- def enum_converter(value):
150
+ def enum_converter(value: enum.Enum):
91
151
  return annotation(value)
92
152
 
93
153
  return enum_converter
@@ -114,7 +174,7 @@ def generate_converter(annotation: Type) -> Callable:
114
174
  if origin in (list, tuple):
115
175
  item_converter = generate_converter(get_args(annotation)[0])
116
176
 
117
- def seq_converter(value):
177
+ def seq_converter(value: list[Any] | tuple[Any, ...]):
118
178
  return [item_converter(v) for v in value]
119
179
 
120
180
  return seq_converter
@@ -124,40 +184,130 @@ def generate_converter(annotation: Type) -> Callable:
124
184
  key_converter = generate_converter(get_args(annotation)[0])
125
185
  value_converter = generate_converter(get_args(annotation)[1])
126
186
 
127
- def dict_converter(value):
187
+ def dict_converter(value: dict[Any, Any]):
128
188
  return {key_converter(k): value_converter(v) for k, v in value.items()}
129
189
 
130
190
  return dict_converter
131
191
 
132
192
  # For Message classes
133
193
  if inspect.isclass(annotation) and issubclass(annotation, Message):
194
+ # Check if it's an empty message class
195
+ if not annotation.model_fields:
196
+
197
+ def empty_message_converter(value: empty_pb2.Empty): # type: ignore
198
+ _ = value
199
+ return annotation() # Return instance of the empty message class
200
+
201
+ return empty_message_converter
134
202
  return generate_message_converter(annotation)
135
203
 
136
204
  # For union types or other unsupported cases, just return the value as-is.
137
205
  return primitiveProtoValueToPythonValue
138
206
 
139
207
 
140
- def generate_message_converter(arg_type: Type[Message]) -> Callable:
208
+ def generate_message_converter(
209
+ arg_type: type[Message] | type[None] | None,
210
+ ) -> Callable[[Any], Message | None]:
141
211
  """Return a converter function for protobuf -> Python Message."""
142
- if arg_type is None or not issubclass(arg_type, Message):
143
- raise TypeError("Request arg must be subclass of Message")
144
212
 
213
+ # Handle NoneType (Empty messages)
214
+ if is_none_type(arg_type):
215
+
216
+ def empty_converter(request: Any) -> None:
217
+ _ = request
218
+ return None
219
+
220
+ return empty_converter
221
+
222
+ arg_type = cast("type[Message]", arg_type)
145
223
  fields = arg_type.model_fields
224
+
225
+ # Handle empty message classes (no fields)
226
+ if not fields:
227
+
228
+ def empty_message_converter(request: Any) -> Message:
229
+ _ = request # The incoming request will be google.protobuf.Empty
230
+ return arg_type() # Return an instance of the empty message class
231
+
232
+ return empty_message_converter
146
233
  converters = {
147
234
  field: generate_converter(field_type.annotation) # type: ignore
148
235
  for field, field_type in fields.items()
149
236
  }
150
237
 
151
- def converter(request):
238
+ def converter(request: Any) -> Message:
152
239
  rdict = {}
153
- for field in fields.keys():
154
- rdict[field] = converters[field](getattr(request, field))
155
- return arg_type(**rdict)
240
+ for field_name, field_info in fields.items():
241
+ field_type = field_info.annotation
242
+
243
+ # Check if this is a union type
244
+ if field_type is not None and is_union_type(field_type):
245
+ union_args = flatten_union(field_type)
246
+ has_none = type(None) in union_args
247
+ non_none_args = [arg for arg in union_args if arg is not type(None)]
248
+
249
+ if has_none and len(non_none_args) == 1:
250
+ # This is Optional[T] - check if protobuf field is set
251
+ try:
252
+ if hasattr(request, "HasField") and request.HasField(
253
+ field_name
254
+ ):
255
+ rdict[field_name] = converters[field_name](
256
+ getattr(request, field_name)
257
+ )
258
+ else:
259
+ # Field not set in protobuf, set to None for Optional fields
260
+ rdict[field_name] = None
261
+ except ValueError:
262
+ # HasField doesn't work for this field type (e.g., repeated fields)
263
+ # Fall back to regular conversion
264
+ rdict[field_name] = converters[field_name](
265
+ getattr(request, field_name)
266
+ )
267
+ elif len(non_none_args) > 1:
268
+ # This is a oneof field (Union[str, int] etc.)
269
+ # Check which oneof field is set
270
+ try:
271
+ which_field = request.WhichOneof(field_name)
272
+ if which_field:
273
+ # Extract the value from the set oneof field
274
+ proto_value = getattr(request, which_field)
275
+
276
+ # Determine the Python type from the oneof field name
277
+ # e.g., "value_string" -> str, "value_int32" -> int
278
+ for union_arg in non_none_args:
279
+ proto_typename = protobuf_type_mapping(union_arg)
280
+ if (
281
+ proto_typename
282
+ and which_field
283
+ == f"{field_name}_{proto_typename.replace('.', '_')}"
284
+ ):
285
+ # Convert using the specific type converter
286
+ type_converter = generate_converter(union_arg)
287
+ rdict[field_name] = type_converter(proto_value)
288
+ break
289
+ except (AttributeError, ValueError):
290
+ # WhichOneof failed, try fallback
291
+ # This shouldn't happen for properly generated oneof fields
292
+ pass
293
+ else:
294
+ # Union with only None type (shouldn't happen)
295
+ pass
296
+ else:
297
+ # For non-union fields, convert normally
298
+ rdict[field_name] = converters[field_name](getattr(request, field_name))
299
+
300
+ # Use model_validate to support @model_validator
301
+ try:
302
+ return arg_type.model_validate(rdict)
303
+ except AttributeError:
304
+ # Fallback for older Pydantic versions or if model_validate doesn't exist
305
+ return arg_type(**rdict)
156
306
 
157
307
  return converter
158
308
 
159
309
 
160
- def python_value_to_proto_value(field_type: Type, value):
310
+ def python_value_to_proto_value(field_type: type[Any], value: Any) -> Any:
161
311
  """
162
312
  Converts Python values to protobuf values.
163
313
  Used primarily when constructing a response object.
@@ -179,77 +329,129 @@ def python_value_to_proto_value(field_type: Type, value):
179
329
  ###############################################################################
180
330
 
181
331
 
182
- def connect_obj_with_stub(pb2_grpc_module, pb2_module, service_obj: object) -> type:
332
+ def connect_obj_with_stub(
333
+ pb2_grpc_module: Any, pb2_module: Any, service_obj: object
334
+ ) -> type:
183
335
  """
184
336
  Connect a Python service object to a gRPC stub, generating server methods.
337
+ Returns a subclass of the generated Servicer stub with concrete implementations.
185
338
  """
186
339
  service_class = service_obj.__class__
187
340
  stub_class_name = service_class.__name__ + "Servicer"
188
341
  stub_class = getattr(pb2_grpc_module, stub_class_name)
189
342
 
190
343
  class ConcreteServiceClass(stub_class):
344
+ """Dynamically generated servicer class with stub methods implemented."""
345
+
191
346
  pass
192
347
 
193
- def implement_stub_method(method):
194
- # Analyze method signature.
348
+ def implement_stub_method(
349
+ method: Callable[..., Message],
350
+ ) -> Callable[[object, Any, Any], Any]:
351
+ """
352
+ Wraps a user-defined method (self, *args) -> R into a gRPC stub signature:
353
+ (self, request_proto, context) -> response_proto
354
+ """
195
355
  sig = inspect.signature(method)
196
356
  arg_type = get_request_arg_type(sig)
197
- # Convert request from protobuf to Python.
198
- converter = generate_message_converter(arg_type)
199
-
200
357
  response_type = sig.return_annotation
201
- size_of_parameters = len(sig.parameters)
202
-
203
- match size_of_parameters:
204
- case 1:
358
+ param_count = len(sig.parameters)
359
+ converter = generate_message_converter(arg_type)
205
360
 
206
- def stub_method1(self, request, context, method=method):
207
- try:
208
- # Convert request to Python object
361
+ if param_count == 0 or param_count == 1:
362
+
363
+ def stub_method(
364
+ self: object,
365
+ request: Any,
366
+ context: Any,
367
+ *,
368
+ original: Callable[..., Message] = method,
369
+ ) -> Any:
370
+ _ = self
371
+ try:
372
+ if param_count == 0:
373
+ # Method takes no parameters
374
+ resp_obj = original()
375
+ elif is_none_type(arg_type):
376
+ resp_obj = original(None) # Fixed: pass None instead of no args
377
+ else:
209
378
  arg = converter(request)
210
- # Invoke the actual method
211
- resp_obj = method(arg)
212
- # Convert the returned Python Message to a protobuf message
379
+ resp_obj = original(arg)
380
+
381
+ if is_none_type(response_type):
382
+ return empty_pb2.Empty() # type: ignore
383
+ elif (
384
+ inspect.isclass(response_type)
385
+ and issubclass(response_type, Message)
386
+ and not response_type.model_fields
387
+ ):
388
+ # Empty message class
389
+ return empty_pb2.Empty() # type: ignore
390
+ else:
213
391
  return convert_python_message_to_proto(
214
392
  resp_obj, response_type, pb2_module
215
393
  )
216
- except ValidationError as e:
217
- return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
218
- except Exception as e:
219
- return context.abort(grpc.StatusCode.INTERNAL, str(e))
220
-
221
- return stub_method1
222
-
223
- case 2:
224
-
225
- def stub_method2(self, request, context, method=method):
226
- try:
394
+ except ValidationError as e:
395
+ return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
396
+ except Exception as e:
397
+ return context.abort(grpc.StatusCode.INTERNAL, str(e))
398
+
399
+ elif param_count == 2:
400
+
401
+ def stub_method(
402
+ self: object,
403
+ request: Any,
404
+ context: Any,
405
+ *,
406
+ original: Callable[..., Message] = method,
407
+ ) -> Any:
408
+ _ = self
409
+ try:
410
+ if is_none_type(arg_type):
411
+ resp_obj = original(
412
+ None, context
413
+ ) # Fixed: pass None instead of Empty
414
+ else:
227
415
  arg = converter(request)
228
- resp_obj = method(arg, context)
416
+ resp_obj = original(arg, context)
417
+
418
+ if is_none_type(response_type):
419
+ return empty_pb2.Empty() # type: ignore
420
+ elif (
421
+ inspect.isclass(response_type)
422
+ and issubclass(response_type, Message)
423
+ and not response_type.model_fields
424
+ ):
425
+ # Empty message class
426
+ return empty_pb2.Empty() # type: ignore
427
+ else:
229
428
  return convert_python_message_to_proto(
230
429
  resp_obj, response_type, pb2_module
231
430
  )
232
- except ValidationError as e:
233
- return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
234
- except Exception as e:
235
- return context.abort(grpc.StatusCode.INTERNAL, str(e))
431
+ except ValidationError as e:
432
+ return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
433
+ except Exception as e:
434
+ return context.abort(grpc.StatusCode.INTERNAL, str(e))
236
435
 
237
- return stub_method2
436
+ else:
437
+ raise TypeError(
438
+ f"Method '{method.__name__}' must have exactly 1 or 2 parameters, got {param_count}"
439
+ )
238
440
 
239
- case _:
240
- raise Exception("Method must have exactly one or two parameters")
441
+ return stub_method
241
442
 
443
+ # Attach all RPC methods from service_obj to the concrete servicer
242
444
  for method_name, method in get_rpc_methods(service_obj):
243
- if method.__name__.startswith("_"):
445
+ if method_name.startswith("_"):
244
446
  continue
245
-
246
- a_method = implement_stub_method(method)
247
- setattr(ConcreteServiceClass, method_name, a_method)
447
+ setattr(ConcreteServiceClass, method_name, implement_stub_method(method))
248
448
 
249
449
  return ConcreteServiceClass
250
450
 
251
451
 
252
- def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> type:
452
+ def connect_obj_with_stub_async(
453
+ pb2_grpc_module: Any, pb2_module: Any, obj: object
454
+ ) -> type:
253
455
  """
254
456
  Connect a Python service object to a gRPC stub for async methods.
255
457
  """
@@ -260,26 +462,151 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
260
462
  class ConcreteServiceClass(stub_class):
261
463
  pass
262
464
 
263
- def implement_stub_method(method):
465
+ def implement_stub_method(
466
+ method: Callable[..., Any],
467
+ ) -> Callable[[object, Any, Any], Any]:
264
468
  sig = inspect.signature(method)
265
- arg_type = get_request_arg_type(sig)
266
- converter = generate_message_converter(arg_type)
469
+ input_type = get_request_arg_type(sig)
470
+ is_input_stream = is_stream_type(input_type)
267
471
  response_type = sig.return_annotation
472
+ is_output_stream = is_stream_type(response_type)
268
473
  size_of_parameters = len(sig.parameters)
269
474
 
270
- if is_stream_type(response_type):
271
- item_type = get_args(response_type)[0]
272
- match size_of_parameters:
273
- case 1:
475
+ if size_of_parameters not in (0, 1, 2):
476
+ raise TypeError(
477
+ f"Method '{method.__name__}' must have 0, 1 or 2 parameters, got {size_of_parameters}"
478
+ )
274
479
 
275
- async def stub_method_stream1(
276
- self, request, context, method=method
277
- ):
480
+ if is_input_stream:
481
+ input_item_type = get_args(input_type)[0]
482
+ item_converter = generate_message_converter(input_item_type)
483
+
484
+ async def convert_iterator(
485
+ proto_iter: AsyncIterator[Any],
486
+ ) -> AsyncIterator[Message]:
487
+ async for proto in proto_iter:
488
+ result = item_converter(proto)
489
+ if result is None:
490
+ raise TypeError(
491
+ f"Unexpected None result from converter for type {input_item_type}"
492
+ )
493
+ yield result
494
+
495
+ if is_output_stream:
496
+ # stream-stream
497
+ output_item_type = get_args(response_type)[0]
498
+
499
+ if size_of_parameters == 1:
500
+
501
+ async def stub_method(
502
+ self: object,
503
+ request_iterator: AsyncIterator[Any],
504
+ context: Any,
505
+ ) -> AsyncIterator[Any]:
506
+ _ = self
507
+ try:
508
+ arg_iter = convert_iterator(request_iterator)
509
+ async for resp_obj in method(arg_iter):
510
+ yield convert_python_message_to_proto(
511
+ resp_obj, output_item_type, pb2_module
512
+ )
513
+ except ValidationError as e:
514
+ await context.abort(
515
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
516
+ )
517
+ except Exception as e:
518
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
519
+
520
+ else: # size_of_parameters == 2
521
+
522
+ async def stub_method(
523
+ self: object,
524
+ request_iterator: AsyncIterator[Any],
525
+ context: Any,
526
+ ) -> AsyncIterator[Any]:
527
+ _ = self
528
+ try:
529
+ arg_iter = convert_iterator(request_iterator)
530
+ async for resp_obj in method(arg_iter, context):
531
+ yield convert_python_message_to_proto(
532
+ resp_obj, output_item_type, pb2_module
533
+ )
534
+ except ValidationError as e:
535
+ await context.abort(
536
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
537
+ )
538
+ except Exception as e:
539
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
540
+
541
+ return stub_method
542
+
543
+ else:
544
+ # stream-unary
545
+ if size_of_parameters == 1:
546
+
547
+ async def stub_method(
548
+ self: object,
549
+ request_iterator: AsyncIterator[Any],
550
+ context: Any,
551
+ ) -> Any:
552
+ _ = self
553
+ try:
554
+ arg_iter = convert_iterator(request_iterator)
555
+ resp_obj = await method(arg_iter)
556
+ return convert_python_message_to_proto(
557
+ resp_obj, response_type, pb2_module
558
+ )
559
+ except ValidationError as e:
560
+ await context.abort(
561
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
562
+ )
563
+ except Exception as e:
564
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
565
+
566
+ else: # size_of_parameters == 2
567
+
568
+ async def stub_method(
569
+ self: object,
570
+ request_iterator: AsyncIterator[Any],
571
+ context: Any,
572
+ ) -> Any:
573
+ _ = self
574
+ try:
575
+ arg_iter = convert_iterator(request_iterator)
576
+ resp_obj = await method(arg_iter, context)
577
+ return convert_python_message_to_proto(
578
+ resp_obj, response_type, pb2_module
579
+ )
580
+ except ValidationError as e:
581
+ await context.abort(
582
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
583
+ )
584
+ except Exception as e:
585
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
586
+
587
+ return stub_method
588
+
589
+ else:
590
+ # unary input
591
+ converter = generate_message_converter(input_type)
592
+
593
+ if is_output_stream:
594
+ # unary-stream
595
+ output_item_type = get_args(response_type)[0]
596
+
597
+ if size_of_parameters == 1:
598
+
599
+ async def stub_method(
600
+ self: object,
601
+ request: Any,
602
+ context: Any,
603
+ ) -> AsyncIterator[Any]:
604
+ _ = self
278
605
  try:
279
606
  arg = converter(request)
280
607
  async for resp_obj in method(arg):
281
608
  yield convert_python_message_to_proto(
282
- resp_obj, item_type, pb2_module
609
+ resp_obj, output_item_type, pb2_module
283
610
  )
284
611
  except ValidationError as e:
285
612
  await context.abort(
@@ -288,17 +615,19 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
288
615
  except Exception as e:
289
616
  await context.abort(grpc.StatusCode.INTERNAL, str(e))
290
617
 
291
- return stub_method_stream1
292
- case 2:
618
+ else: # size_of_parameters == 2
293
619
 
294
- async def stub_method_stream2(
295
- self, request, context, method=method
296
- ):
620
+ async def stub_method(
621
+ self: object,
622
+ request: Any,
623
+ context: Any,
624
+ ) -> AsyncIterator[Any]:
625
+ _ = self
297
626
  try:
298
627
  arg = converter(request)
299
628
  async for resp_obj in method(arg, context):
300
629
  yield convert_python_message_to_proto(
301
- resp_obj, item_type, pb2_module
630
+ resp_obj, output_item_type, pb2_module
302
631
  )
303
632
  except ValidationError as e:
304
633
  await context.abort(
@@ -307,45 +636,84 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
307
636
  except Exception as e:
308
637
  await context.abort(grpc.StatusCode.INTERNAL, str(e))
309
638
 
310
- return stub_method_stream2
311
- case _:
312
- raise Exception("Method must have exactly one or two parameters")
639
+ return stub_method
313
640
 
314
- match size_of_parameters:
315
- case 1:
641
+ else:
642
+ # unary-unary
643
+ if size_of_parameters == 0 or size_of_parameters == 1:
316
644
 
317
- async def stub_method1(self, request, context, method=method):
318
- try:
319
- arg = converter(request)
320
- resp_obj = await method(arg)
321
- return convert_python_message_to_proto(
322
- resp_obj, response_type, pb2_module
323
- )
324
- except ValidationError as e:
325
- await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
326
- except Exception as e:
327
- await context.abort(grpc.StatusCode.INTERNAL, str(e))
328
-
329
- return stub_method1
330
-
331
- case 2:
645
+ async def stub_method(
646
+ self: object,
647
+ request: Any,
648
+ context: Any,
649
+ ) -> Any:
650
+ _ = self
651
+ try:
652
+ if size_of_parameters == 0:
653
+ # Method takes no parameters
654
+ resp_obj = await method()
655
+ elif is_none_type(input_type):
656
+ resp_obj = await method(None)
657
+ else:
658
+ arg = converter(request)
659
+ resp_obj = await method(arg)
660
+
661
+ if is_none_type(response_type):
662
+ return empty_pb2.Empty() # type: ignore
663
+ elif (
664
+ inspect.isclass(response_type)
665
+ and issubclass(response_type, Message)
666
+ and not response_type.model_fields
667
+ ):
668
+ # Empty message class
669
+ return empty_pb2.Empty() # type: ignore
670
+ else:
671
+ return convert_python_message_to_proto(
672
+ resp_obj, response_type, pb2_module
673
+ )
674
+ except ValidationError as e:
675
+ await context.abort(
676
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
677
+ )
678
+ except Exception as e:
679
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
332
680
 
333
- async def stub_method2(self, request, context, method=method):
334
- try:
335
- arg = converter(request)
336
- resp_obj = await method(arg, context)
337
- return convert_python_message_to_proto(
338
- resp_obj, response_type, pb2_module
339
- )
340
- except ValidationError as e:
341
- await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
342
- except Exception as e:
343
- await context.abort(grpc.StatusCode.INTERNAL, str(e))
681
+ else: # size_of_parameters == 2
344
682
 
345
- return stub_method2
683
+ async def stub_method(
684
+ self: object,
685
+ request: Any,
686
+ context: Any,
687
+ ) -> Any:
688
+ _ = self
689
+ try:
690
+ if is_none_type(input_type):
691
+ resp_obj = await method(None, context)
692
+ else:
693
+ arg = converter(request)
694
+ resp_obj = await method(arg, context)
695
+
696
+ if is_none_type(response_type):
697
+ return empty_pb2.Empty() # type: ignore
698
+ elif (
699
+ inspect.isclass(response_type)
700
+ and issubclass(response_type, Message)
701
+ and not response_type.model_fields
702
+ ):
703
+ # Empty message class
704
+ return empty_pb2.Empty() # type: ignore
705
+ else:
706
+ return convert_python_message_to_proto(
707
+ resp_obj, response_type, pb2_module
708
+ )
709
+ except ValidationError as e:
710
+ await context.abort(
711
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
712
+ )
713
+ except Exception as e:
714
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
346
715
 
347
- case _:
348
- raise Exception("Method must have exactly one or two parameters")
716
+ return stub_method
349
717
 
350
718
  for method_name, method in get_rpc_methods(obj):
351
719
  if method.__name__.startswith("_"):
@@ -357,7 +725,9 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
357
725
  return ConcreteServiceClass
358
726
 
359
727
 
360
- def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) -> type:
728
+ def connect_obj_with_stub_connecpy(
729
+ connecpy_module: Any, pb2_module: Any, obj: object
730
+ ) -> type:
361
731
  """
362
732
  Connect a Python service object to a Connecpy stub.
363
733
  """
@@ -368,7 +738,9 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
368
738
  class ConcreteServiceClass(stub_class):
369
739
  pass
370
740
 
371
- def implement_stub_method(method):
741
+ def implement_stub_method(
742
+ method: Callable[..., Message],
743
+ ) -> Callable[[object, Any, Any], Any]:
372
744
  sig = inspect.signature(method)
373
745
  arg_type = get_request_arg_type(sig)
374
746
  converter = generate_message_converter(arg_type)
@@ -376,40 +748,91 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
376
748
  size_of_parameters = len(sig.parameters)
377
749
 
378
750
  match size_of_parameters:
751
+ case 0:
752
+ # Method with no parameters (empty request)
753
+ def stub_method0(
754
+ self: object,
755
+ request: Any,
756
+ context: Any,
757
+ method: Callable[..., Message] = method,
758
+ ) -> Any:
759
+ _ = self
760
+ try:
761
+ resp_obj = method()
762
+
763
+ if is_none_type(response_type):
764
+ return empty_pb2.Empty() # type: ignore
765
+ else:
766
+ return convert_python_message_to_proto(
767
+ resp_obj, response_type, pb2_module
768
+ )
769
+ except ValidationError as e:
770
+ return context.abort(Errors.INVALID_ARGUMENT, str(e))
771
+ except Exception as e:
772
+ return context.abort(Errors.INTERNAL, str(e))
773
+
774
+ return stub_method0
775
+
379
776
  case 1:
380
777
 
381
- def stub_method1(self, request, context, method=method):
778
+ def stub_method1(
779
+ self: object,
780
+ request: Any,
781
+ context: Any,
782
+ method: Callable[..., Message] = method,
783
+ ) -> Any:
784
+ _ = self
382
785
  try:
383
- arg = converter(request)
384
- resp_obj = method(arg)
385
- return convert_python_message_to_proto(
386
- resp_obj, response_type, pb2_module
387
- )
786
+ if is_none_type(arg_type):
787
+ resp_obj = method(None)
788
+ else:
789
+ arg = converter(request)
790
+ resp_obj = method(arg)
791
+
792
+ if is_none_type(response_type):
793
+ return empty_pb2.Empty() # type: ignore
794
+ else:
795
+ return convert_python_message_to_proto(
796
+ resp_obj, response_type, pb2_module
797
+ )
388
798
  except ValidationError as e:
389
- return context.abort(Errors.InvalidArgument, str(e))
799
+ return context.abort(Errors.INVALID_ARGUMENT, str(e))
390
800
  except Exception as e:
391
- return context.abort(Errors.Internal, str(e))
801
+ return context.abort(Errors.INTERNAL, str(e))
392
802
 
393
803
  return stub_method1
394
804
 
395
805
  case 2:
396
806
 
397
- def stub_method2(self, request, context, method=method):
807
+ def stub_method2(
808
+ self: object,
809
+ request: Any,
810
+ context: Any,
811
+ method: Callable[..., Message] = method,
812
+ ) -> Any:
813
+ _ = self
398
814
  try:
399
- arg = converter(request)
400
- resp_obj = method(arg, context)
401
- return convert_python_message_to_proto(
402
- resp_obj, response_type, pb2_module
403
- )
815
+ if is_none_type(arg_type):
816
+ resp_obj = method(None, context)
817
+ else:
818
+ arg = converter(request)
819
+ resp_obj = method(arg, context)
820
+
821
+ if is_none_type(response_type):
822
+ return empty_pb2.Empty() # type: ignore
823
+ else:
824
+ return convert_python_message_to_proto(
825
+ resp_obj, response_type, pb2_module
826
+ )
404
827
  except ValidationError as e:
405
- return context.abort(Errors.InvalidArgument, str(e))
828
+ return context.abort(Errors.INVALID_ARGUMENT, str(e))
406
829
  except Exception as e:
407
- return context.abort(Errors.Internal, str(e))
830
+ return context.abort(Errors.INTERNAL, str(e))
408
831
 
409
832
  return stub_method2
410
833
 
411
834
  case _:
412
- raise Exception("Method must have exactly one or two parameters")
835
+ raise Exception("Method must have 0, 1, or 2 parameters")
413
836
 
414
837
  for method_name, method in get_rpc_methods(obj):
415
838
  if method.__name__.startswith("_"):
@@ -421,7 +844,7 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
421
844
 
422
845
 
423
846
  def connect_obj_with_stub_async_connecpy(
424
- connecpy_module, pb2_module, obj: object
847
+ connecpy_module: Any, pb2_module: Any, obj: object
425
848
  ) -> type:
426
849
  """
427
850
  Connect a Python service object to a Connecpy stub for async methods.
@@ -433,7 +856,9 @@ def connect_obj_with_stub_async_connecpy(
433
856
  class ConcreteServiceClass(stub_class):
434
857
  pass
435
858
 
436
- def implement_stub_method(method):
859
+ def implement_stub_method(
860
+ method: Callable[..., Awaitable[Message]],
861
+ ) -> Callable[[object, Any, Any], Any]:
437
862
  sig = inspect.signature(method)
438
863
  arg_type = get_request_arg_type(sig)
439
864
  converter = generate_message_converter(arg_type)
@@ -441,40 +866,91 @@ def connect_obj_with_stub_async_connecpy(
441
866
  size_of_parameters = len(sig.parameters)
442
867
 
443
868
  match size_of_parameters:
869
+ case 0:
870
+ # Method with no parameters (empty request)
871
+ async def stub_method0(
872
+ self: object,
873
+ request: Any,
874
+ context: Any,
875
+ method: Callable[..., Awaitable[Message]] = method,
876
+ ) -> Any:
877
+ _ = self
878
+ try:
879
+ resp_obj = await method()
880
+
881
+ if is_none_type(response_type):
882
+ return empty_pb2.Empty() # type: ignore
883
+ else:
884
+ return convert_python_message_to_proto(
885
+ resp_obj, response_type, pb2_module
886
+ )
887
+ except ValidationError as e:
888
+ await context.abort(Errors.INVALID_ARGUMENT, str(e))
889
+ except Exception as e:
890
+ await context.abort(Errors.INTERNAL, str(e))
891
+
892
+ return stub_method0
893
+
444
894
  case 1:
445
895
 
446
- async def stub_method1(self, request, context, method=method):
896
+ async def stub_method1(
897
+ self: object,
898
+ request: Any,
899
+ context: Any,
900
+ method: Callable[..., Awaitable[Message]] = method,
901
+ ) -> Any:
902
+ _ = self
447
903
  try:
448
- arg = converter(request)
449
- resp_obj = await method(arg)
450
- return convert_python_message_to_proto(
451
- resp_obj, response_type, pb2_module
452
- )
904
+ if is_none_type(arg_type):
905
+ resp_obj = await method(None)
906
+ else:
907
+ arg = converter(request)
908
+ resp_obj = await method(arg)
909
+
910
+ if is_none_type(response_type):
911
+ return empty_pb2.Empty() # type: ignore
912
+ else:
913
+ return convert_python_message_to_proto(
914
+ resp_obj, response_type, pb2_module
915
+ )
453
916
  except ValidationError as e:
454
- await context.abort(Errors.InvalidArgument, str(e))
917
+ await context.abort(Errors.INVALID_ARGUMENT, str(e))
455
918
  except Exception as e:
456
- await context.abort(Errors.Internal, str(e))
919
+ await context.abort(Errors.INTERNAL, str(e))
457
920
 
458
921
  return stub_method1
459
922
 
460
923
  case 2:
461
924
 
462
- async def stub_method2(self, request, context, method=method):
925
+ async def stub_method2(
926
+ self: object,
927
+ request: Any,
928
+ context: Any,
929
+ method: Callable[..., Awaitable[Message]] = method,
930
+ ) -> Any:
931
+ _ = self
463
932
  try:
464
- arg = converter(request)
465
- resp_obj = await method(arg, context)
466
- return convert_python_message_to_proto(
467
- resp_obj, response_type, pb2_module
468
- )
933
+ if is_none_type(arg_type):
934
+ resp_obj = await method(None, context)
935
+ else:
936
+ arg = converter(request)
937
+ resp_obj = await method(arg, context)
938
+
939
+ if is_none_type(response_type):
940
+ return empty_pb2.Empty() # type: ignore
941
+ else:
942
+ return convert_python_message_to_proto(
943
+ resp_obj, response_type, pb2_module
944
+ )
469
945
  except ValidationError as e:
470
- await context.abort(Errors.InvalidArgument, str(e))
946
+ await context.abort(Errors.INVALID_ARGUMENT, str(e))
471
947
  except Exception as e:
472
- await context.abort(Errors.Internal, str(e))
948
+ await context.abort(Errors.INTERNAL, str(e))
473
949
 
474
950
  return stub_method2
475
951
 
476
952
  case _:
477
- raise Exception("Method must have exactly one or two parameters")
953
+ raise Exception("Method must have 0, 1, or 2 parameters")
478
954
 
479
955
  for method_name, method in get_rpc_methods(obj):
480
956
  if method.__name__.startswith("_"):
@@ -487,36 +963,202 @@ def connect_obj_with_stub_async_connecpy(
487
963
  return ConcreteServiceClass
488
964
 
489
965
 
966
+ def python_value_to_proto_oneof(
967
+ field_name: str,
968
+ field_type: type[Any],
969
+ value: Any,
970
+ pb2_module: Any,
971
+ _visited: set[int] | None = None,
972
+ _depth: int = 0,
973
+ ) -> tuple[str, Any]:
974
+ """
975
+ Converts a Python value from a Union type to a protobuf oneof field.
976
+ Returns the field name to set and the converted value.
977
+ """
978
+ union_args = [arg for arg in flatten_union(field_type) if arg is not type(None)]
979
+
980
+ # Find which subtype in the Union matches the value's type.
981
+ actual_type = None
982
+ for sub_type in union_args:
983
+ origin = get_origin(sub_type)
984
+ type_to_check = origin or sub_type
985
+ try:
986
+ if isinstance(value, type_to_check):
987
+ actual_type = sub_type
988
+ break
989
+ except TypeError:
990
+ # This can happen if `sub_type` is not a class, e.g. a generic alias
991
+ if isinstance(value, type_to_check):
992
+ actual_type = sub_type
993
+ break
994
+
995
+ if actual_type is None:
996
+ raise TypeError(f"Value of type {type(value)} not found in union {field_type}")
997
+
998
+ proto_typename = protobuf_type_mapping(actual_type)
999
+ if proto_typename is None:
1000
+ raise TypeError(f"Unsupported type in oneof: {actual_type}")
1001
+
1002
+ oneof_field_name = f"{field_name}_{proto_typename.replace('.', '_')}"
1003
+ converted_value = python_value_to_proto(
1004
+ actual_type, value, pb2_module, _visited, _depth
1005
+ )
1006
+ return oneof_field_name, converted_value
1007
+
1008
+
490
1009
  def convert_python_message_to_proto(
491
- py_msg: Message, msg_type: Type, pb2_module
1010
+ py_msg: Message,
1011
+ msg_type: type[Message],
1012
+ pb2_module: Any,
1013
+ _visited: set[int] | None = None,
1014
+ _depth: int = 0,
492
1015
  ) -> object:
1016
+ """Convert a Python Pydantic Message instance to a protobuf message instance. Used for constructing a response.
1017
+
1018
+ Args:
1019
+ py_msg: The Pydantic Message instance to convert
1020
+ msg_type: The type of the Message
1021
+ pb2_module: The protobuf module
1022
+ _visited: Internal set to track visited objects for circular reference detection
1023
+ _depth: Current recursion depth for strategy control
493
1024
  """
494
- Convert a Python Pydantic Message instance to a protobuf message instance.
495
- Used for constructing a response.
496
- """
497
- # Before calling something like pb2_module.AResponseMessage(...),
498
- # convert each field from Python to proto.
1025
+ # Handle empty message classes
1026
+ if not msg_type.model_fields:
1027
+ # Return google.protobuf.Empty instance
1028
+ return empty_pb2.Empty()
1029
+
1030
+ # Initialize visited set for circular reference detection
1031
+ if _visited is None:
1032
+ _visited = set()
1033
+
1034
+ # Check for circular references
1035
+ obj_id = id(py_msg)
1036
+ if obj_id in _visited:
1037
+ # Return empty proto to avoid infinite recursion
1038
+ proto_class = getattr(pb2_module, msg_type.__name__)
1039
+ return proto_class()
1040
+ _visited.add(obj_id)
1041
+
1042
+ # Determine if we should apply serializers based on strategy
1043
+ apply_serializers = False
1044
+ if _SERIALIZER_STRATEGY == SerializerStrategy.DEEP:
1045
+ apply_serializers = True # Always apply at any depth
1046
+ elif _SERIALIZER_STRATEGY == SerializerStrategy.SHALLOW:
1047
+ apply_serializers = _depth == 0 # Only at top level
1048
+ # SerializerStrategy.NONE: never apply
1049
+
1050
+ # Only use model_dump if there are serializers and strategy allows
1051
+ serialized_data = None
1052
+ if apply_serializers and has_serializers(msg_type):
1053
+ try:
1054
+ serialized_data = py_msg.model_dump(mode="python")
1055
+ except Exception:
1056
+ # Fallback to the old approach if model_dump fails
1057
+ serialized_data = None
1058
+
499
1059
  field_dict = {}
500
1060
  for name, field_info in msg_type.model_fields.items():
1061
+ field_type = field_info.annotation
1062
+
1063
+ # Check if this field type contains Message types
1064
+ contains_message = False
1065
+ if field_type:
1066
+ if inspect.isclass(field_type) and issubclass(field_type, Message):
1067
+ contains_message = True
1068
+ elif is_union_type(field_type):
1069
+ # Check if any of the union args are Messages
1070
+ for arg in flatten_union(field_type):
1071
+ if arg and inspect.isclass(arg) and issubclass(arg, Message):
1072
+ contains_message = True
1073
+ break
1074
+ else:
1075
+ # Check for list/dict of Messages
1076
+ origin = get_origin(field_type)
1077
+ if origin in (list, tuple):
1078
+ inner_type = (
1079
+ get_args(field_type)[0] if get_args(field_type) else None
1080
+ )
1081
+ if (
1082
+ inner_type
1083
+ and inspect.isclass(inner_type)
1084
+ and issubclass(inner_type, Message)
1085
+ ):
1086
+ contains_message = True
1087
+ elif origin is dict:
1088
+ args = get_args(field_type)
1089
+ if len(args) >= 2:
1090
+ val_type = args[1]
1091
+ if inspect.isclass(val_type) and issubclass(val_type, Message):
1092
+ contains_message = True
1093
+
1094
+ # Get the value
501
1095
  value = getattr(py_msg, name)
502
1096
  if value is None:
503
- field_dict[name] = None
504
1097
  continue
505
1098
 
1099
+ # For Message types, recursively apply serialization
1100
+ if contains_message and not is_union_type(field_type):
1101
+ # Direct Message field
1102
+ if inspect.isclass(field_type) and issubclass(field_type, Message):
1103
+ # Recursively convert nested Message with serializers
1104
+ field_dict[name] = convert_python_message_to_proto(
1105
+ value, field_type, pb2_module, _visited, _depth + 1
1106
+ )
1107
+ continue
1108
+
1109
+ # Use serialized data for non-Message types to respect custom serializers
1110
+ if (
1111
+ serialized_data is not None
1112
+ and name in serialized_data
1113
+ and not contains_message
1114
+ ):
1115
+ value = serialized_data[name]
1116
+
506
1117
  field_type = field_info.annotation
507
- field_dict[name] = python_value_to_proto(field_type, value, pb2_module)
1118
+
1119
+ # Handle oneof fields, which are represented as Unions.
1120
+ if field_type is not None and is_union_type(field_type):
1121
+ union_args = [
1122
+ arg for arg in flatten_union(field_type) if arg is not type(None)
1123
+ ]
1124
+ if len(union_args) > 1:
1125
+ # It's a oneof field. We need to determine the concrete type and
1126
+ # the corresponding protobuf field name.
1127
+ (
1128
+ oneof_field_name,
1129
+ converted_value,
1130
+ ) = python_value_to_proto_oneof(
1131
+ name, field_type, value, pb2_module, _visited, _depth
1132
+ )
1133
+ field_dict[oneof_field_name] = converted_value
1134
+ continue
1135
+
1136
+ # For regular and Optional fields that have a value.
1137
+ if field_type is not None:
1138
+ field_dict[name] = python_value_to_proto(
1139
+ field_type, value, pb2_module, _visited, _depth
1140
+ )
1141
+
1142
+ # Remove from visited set when done
1143
+ _visited.discard(obj_id)
508
1144
 
509
1145
  # Retrieve the appropriate protobuf class dynamically
510
1146
  proto_class = getattr(pb2_module, msg_type.__name__)
511
1147
  return proto_class(**field_dict)
512
1148
 
513
1149
 
514
- def python_value_to_proto(field_type: Type, value, pb2_module):
1150
+ def python_value_to_proto(
1151
+ field_type: type[Any],
1152
+ value: Any,
1153
+ pb2_module: Any,
1154
+ _visited: set[int] | None = None,
1155
+ _depth: int = 0,
1156
+ ) -> Any:
515
1157
  """
516
1158
  Perform Python->protobuf type conversion for each field value.
517
1159
  """
518
- import inspect
519
1160
  import datetime
1161
+ import inspect
520
1162
 
521
1163
  # If datetime
522
1164
  if field_type == datetime.datetime:
@@ -533,48 +1175,58 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
533
1175
  origin = get_origin(field_type)
534
1176
  # If seq
535
1177
  if origin in (list, tuple):
536
- inner_type = get_args(field_type)[0] # type: ignore
537
- return [python_value_to_proto(inner_type, v, pb2_module) for v in value]
1178
+ inner_type = get_args(field_type)[0]
1179
+ # Handle list of Messages
1180
+ if inspect.isclass(inner_type) and issubclass(inner_type, Message):
1181
+ return [
1182
+ convert_python_message_to_proto(
1183
+ v, inner_type, pb2_module, _visited, _depth + 1
1184
+ )
1185
+ for v in value
1186
+ ]
1187
+ return [
1188
+ python_value_to_proto(inner_type, v, pb2_module, _visited, _depth)
1189
+ for v in value
1190
+ ]
538
1191
 
539
1192
  # If dict
540
1193
  if origin is dict:
541
- key_type, val_type = get_args(field_type) # type: ignore
1194
+ key_type, val_type = get_args(field_type)
1195
+ # Handle dict with Message values
1196
+ if inspect.isclass(val_type) and issubclass(val_type, Message):
1197
+ return {
1198
+ python_value_to_proto(
1199
+ key_type, k, pb2_module, _visited, _depth
1200
+ ): convert_python_message_to_proto(
1201
+ v, val_type, pb2_module, _visited, _depth + 1
1202
+ )
1203
+ for k, v in value.items()
1204
+ }
542
1205
  return {
543
- python_value_to_proto(key_type, k, pb2_module): python_value_to_proto(
544
- val_type, v, pb2_module
545
- )
1206
+ python_value_to_proto(
1207
+ key_type, k, pb2_module, _visited, _depth
1208
+ ): python_value_to_proto(val_type, v, pb2_module, _visited, _depth)
546
1209
  for k, v in value.items()
547
1210
  }
548
1211
 
549
- # If union -> oneof
1212
+ # If union -> oneof. This path is now only for Optional[T] where value is not None.
550
1213
  if is_union_type(field_type):
551
- # Flatten union and check which type matches. If matched, return converted value.
552
- for sub_type in flatten_union(field_type):
553
- if sub_type == datetime.datetime and isinstance(value, datetime.datetime):
554
- return python_to_timestamp(value)
555
- if sub_type == datetime.timedelta and isinstance(value, datetime.timedelta):
556
- return python_to_duration(value)
557
- if (
558
- inspect.isclass(sub_type)
559
- and issubclass(sub_type, enum.Enum)
560
- and isinstance(value, enum.Enum)
561
- ):
562
- return value.value
563
- if sub_type in (int, float, str, bool, bytes) and isinstance(
564
- value, sub_type
565
- ):
566
- return value
567
- if (
568
- inspect.isclass(sub_type)
569
- and issubclass(sub_type, Message)
570
- and isinstance(value, Message)
571
- ):
572
- return convert_python_message_to_proto(value, sub_type, pb2_module)
573
- return None
1214
+ # The value is not None, so we need to find the actual type.
1215
+ non_none_args = [
1216
+ arg for arg in flatten_union(field_type) if arg is not type(None)
1217
+ ]
1218
+ if non_none_args:
1219
+ # Assuming it's an Optional[T], so there's one type left.
1220
+ return python_value_to_proto(
1221
+ non_none_args[0], value, pb2_module, _visited, _depth
1222
+ )
1223
+ return None # Should not be reached if value is not None
574
1224
 
575
1225
  # If Message
576
1226
  if inspect.isclass(field_type) and issubclass(field_type, Message):
577
- return convert_python_message_to_proto(value, field_type, pb2_module)
1227
+ return convert_python_message_to_proto(
1228
+ value, field_type, pb2_module, _visited, _depth + 1
1229
+ )
578
1230
 
579
1231
  # If primitive
580
1232
  return value
@@ -585,12 +1237,12 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
585
1237
  ###############################################################################
586
1238
 
587
1239
 
588
- def is_enum_type(python_type: Type) -> bool:
1240
+ def is_enum_type(python_type: Any) -> bool:
589
1241
  """Return True if the given Python type is an enum."""
590
1242
  return inspect.isclass(python_type) and issubclass(python_type, enum.Enum)
591
1243
 
592
1244
 
593
- def is_union_type(python_type: Type) -> bool:
1245
+ def is_union_type(python_type: Any) -> bool:
594
1246
  """
595
1247
  Check if a given Python type is a Union type (including Python 3.10's UnionType).
596
1248
  """
@@ -599,26 +1251,28 @@ def is_union_type(python_type: Type) -> bool:
599
1251
  if sys.version_info >= (3, 10):
600
1252
  import types
601
1253
 
602
- if type(python_type) is types.UnionType:
1254
+ if isinstance(python_type, types.UnionType):
603
1255
  return True
604
1256
  return False
605
1257
 
606
1258
 
607
- def flatten_union(field_type: Type) -> list[Type]:
1259
+ def flatten_union(field_type: Any) -> list[Any]:
608
1260
  """Recursively flatten nested Unions into a single list of types."""
609
1261
  if is_union_type(field_type):
610
1262
  results = []
611
1263
  for arg in get_args(field_type):
612
1264
  results.extend(flatten_union(arg))
613
1265
  return results
1266
+ elif is_none_type(field_type):
1267
+ return [field_type]
614
1268
  else:
615
1269
  return [field_type]
616
1270
 
617
1271
 
618
- def protobuf_type_mapping(python_type: Type) -> str | type | None:
1272
+ def protobuf_type_mapping(python_type: Any) -> str | None:
619
1273
  """
620
1274
  Map a Python type to a protobuf type name/class.
621
- Includes support for Timestamp and Duration.
1275
+ Includes support for Timestamp, Duration, and Empty.
622
1276
  """
623
1277
  import datetime
624
1278
 
@@ -636,8 +1290,11 @@ def protobuf_type_mapping(python_type: Type) -> str | type | None:
636
1290
  if python_type == datetime.timedelta:
637
1291
  return "google.protobuf.Duration"
638
1292
 
1293
+ if is_none_type(python_type):
1294
+ return "google.protobuf.Empty"
1295
+
639
1296
  if is_enum_type(python_type):
640
- return python_type # Will be defined as enum later
1297
+ return python_type.__name__
641
1298
 
642
1299
  if is_union_type(python_type):
643
1300
  return None # Handled separately as oneof
@@ -657,9 +1314,12 @@ def protobuf_type_mapping(python_type: Type) -> str | type | None:
657
1314
  return f"map<{key_proto_type}, {value_proto_type}>"
658
1315
 
659
1316
  if inspect.isclass(python_type) and issubclass(python_type, Message):
660
- return python_type
1317
+ # Check if it's an empty message
1318
+ if not python_type.model_fields:
1319
+ return "google.protobuf.Empty"
1320
+ return python_type.__name__
661
1321
 
662
- return mapping.get(python_type) # type: ignore
1322
+ return mapping.get(python_type)
663
1323
 
664
1324
 
665
1325
  def comment_out(docstr: str) -> tuple[str, ...]:
@@ -673,15 +1333,15 @@ def comment_out(docstr: str) -> tuple[str, ...]:
673
1333
  return tuple("//" if line == "" else f"// {line}" for line in docstr.split("\n"))
674
1334
 
675
1335
 
676
- def indent_lines(lines, indentation=" "):
1336
+ def indent_lines(lines: list[str], indentation: str = " ") -> str:
677
1337
  """Indent multiple lines with a given indentation string."""
678
1338
  return "\n".join(indentation + line for line in lines)
679
1339
 
680
1340
 
681
- def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
1341
+ def generate_enum_definition(enum_type: Any) -> str:
682
1342
  """Generate a protobuf enum definition from a Python enum."""
683
1343
  enum_name = enum_type.__name__
684
- members = []
1344
+ members: list[str] = []
685
1345
  for _, member in enum_type.__members__.items():
686
1346
  members.append(f" {member.name} = {member.value};")
687
1347
  enum_def = f"enum {enum_name} {{\n"
@@ -691,7 +1351,7 @@ def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
691
1351
 
692
1352
 
693
1353
  def generate_oneof_definition(
694
- field_name: str, union_args: list[Type], start_index: int
1354
+ field_name: str, union_args: list[Any], start_index: int
695
1355
  ) -> tuple[list[str], int]:
696
1356
  """
697
1357
  Generate a oneof block in protobuf for a union field.
@@ -705,12 +1365,6 @@ def generate_oneof_definition(
705
1365
  if proto_typename is None:
706
1366
  raise Exception(f"Nested Union not flattened properly: {arg_type}")
707
1367
 
708
- # If it's an enum or Message, use the type name.
709
- if is_enum_type(arg_type):
710
- proto_typename = arg_type.__name__
711
- elif inspect.isclass(arg_type) and issubclass(arg_type, Message):
712
- proto_typename = arg_type.__name__
713
-
714
1368
  field_alias = f"{field_name}_{proto_typename.replace('.', '_')}"
715
1369
  lines.append(f" {proto_typename} {field_alias} = {current};")
716
1370
  current += 1
@@ -718,18 +1372,57 @@ def generate_oneof_definition(
718
1372
  return lines, current
719
1373
 
720
1374
 
1375
+ def extract_nested_types(field_type: Any) -> list[Any]:
1376
+ """
1377
+ Recursively extract all Message and enum types from a field type,
1378
+ including those nested within list, dict, and union types.
1379
+ """
1380
+ extracted_types = []
1381
+
1382
+ if field_type is None or is_none_type(field_type):
1383
+ return extracted_types
1384
+
1385
+ # Check if the type itself is an enum or Message
1386
+ if is_enum_type(field_type):
1387
+ extracted_types.append(field_type)
1388
+ elif inspect.isclass(field_type) and issubclass(field_type, Message):
1389
+ extracted_types.append(field_type)
1390
+
1391
+ # Handle Union types
1392
+ if is_union_type(field_type):
1393
+ union_args = flatten_union(field_type)
1394
+ for arg in union_args:
1395
+ if arg is not type(None):
1396
+ extracted_types.extend(extract_nested_types(arg))
1397
+
1398
+ # Handle generic types (list, dict, etc.)
1399
+ origin = get_origin(field_type)
1400
+ if origin is not None:
1401
+ args = get_args(field_type)
1402
+ for arg in args:
1403
+ extracted_types.extend(extract_nested_types(arg))
1404
+
1405
+ return extracted_types
1406
+
1407
+
721
1408
  def generate_message_definition(
722
- message_type: Type[Message],
723
- done_enums: set,
724
- done_messages: set,
725
- ) -> tuple[str, list[type]]:
1409
+ message_type: Any,
1410
+ done_enums: set[Any],
1411
+ done_messages: set[Any],
1412
+ ) -> tuple[str, list[Any]]:
726
1413
  """
727
1414
  Generate a protobuf message definition for a Pydantic-based Message class.
728
1415
  Also returns any referenced types (enums, messages) that need to be defined.
729
1416
  """
730
- fields = []
731
- refs = []
1417
+ fields: list[str] = []
1418
+ refs: list[Any] = []
732
1419
  pydantic_fields = message_type.model_fields
1420
+
1421
+ # Check if this is an empty message and should use google.protobuf.Empty
1422
+ if not pydantic_fields:
1423
+ # Return a special marker that indicates this should use google.protobuf.Empty
1424
+ return "__EMPTY__", []
1425
+
733
1426
  index = 1
734
1427
 
735
1428
  for field_name, field_info in pydantic_fields.items():
@@ -737,92 +1430,130 @@ def generate_message_definition(
737
1430
  if field_type is None:
738
1431
  raise Exception(f"Field {field_name} has no type annotation.")
739
1432
 
1433
+ is_optional = False
1434
+ # Handle Union types, which may be Optional or a oneof.
740
1435
  if is_union_type(field_type):
741
1436
  union_args = flatten_union(field_type)
742
- oneof_lines, new_index = generate_oneof_definition(
743
- field_name, union_args, index
744
- )
745
- fields.extend(oneof_lines)
746
- index = new_index
747
-
748
- for utype in union_args:
749
- if is_enum_type(utype) and utype not in done_enums:
750
- refs.append(utype)
751
- elif (
752
- inspect.isclass(utype)
753
- and issubclass(utype, Message)
754
- and utype not in done_messages
755
- ):
756
- refs.append(utype)
1437
+ none_type = type(
1438
+ None
1439
+ ) # Keep this as type(None) since we're working with union args
1440
+
1441
+ if none_type in union_args or None in union_args:
1442
+ is_optional = True
1443
+ union_args = [arg for arg in union_args if not is_none_type(arg)]
1444
+
1445
+ if len(union_args) == 1:
1446
+ # This is an Optional[T]. Treat it as a simple optional field.
1447
+ field_type = union_args[0]
1448
+ elif len(union_args) > 1:
1449
+ # This is a Union of multiple types, so it becomes a `oneof`.
1450
+ oneof_lines, new_index = generate_oneof_definition(
1451
+ field_name, union_args, index
1452
+ )
1453
+ fields.extend(oneof_lines)
1454
+ index = new_index
1455
+
1456
+ for utype in union_args:
1457
+ if is_enum_type(utype) and utype not in done_enums:
1458
+ refs.append(utype)
1459
+ elif (
1460
+ inspect.isclass(utype)
1461
+ and issubclass(utype, Message)
1462
+ and utype not in done_messages
1463
+ ):
1464
+ refs.append(utype)
1465
+ continue # Proceed to the next field
1466
+ else:
1467
+ # This was a field of only `NoneType`, which is not supported.
1468
+ continue
757
1469
 
1470
+ # For regular fields or optional fields that have been unwrapped.
1471
+ proto_typename = protobuf_type_mapping(field_type)
1472
+ if proto_typename is None:
1473
+ raise Exception(f"Type {field_type} is not supported.")
1474
+
1475
+ # Extract all nested Message and enum types recursively
1476
+ nested_types = extract_nested_types(field_type)
1477
+ for nested_type in nested_types:
1478
+ if is_enum_type(nested_type) and nested_type not in done_enums:
1479
+ refs.append(nested_type)
1480
+ elif (
1481
+ inspect.isclass(nested_type)
1482
+ and issubclass(nested_type, Message)
1483
+ and nested_type not in done_messages
1484
+ ):
1485
+ refs.append(nested_type)
1486
+
1487
+ if field_info.description:
1488
+ fields.append("// " + field_info.description)
1489
+ if field_info.metadata:
1490
+ fields.append("// Constraint:")
1491
+ for metadata_item in field_info.metadata:
1492
+ match type(metadata_item):
1493
+ case annotated_types.Ge:
1494
+ fields.append(
1495
+ "// greater than or equal to " + str(metadata_item.ge)
1496
+ )
1497
+ case annotated_types.Le:
1498
+ fields.append(
1499
+ "// less than or equal to " + str(metadata_item.le)
1500
+ )
1501
+ case annotated_types.Gt:
1502
+ fields.append("// greater than " + str(metadata_item.gt))
1503
+ case annotated_types.Lt:
1504
+ fields.append("// less than " + str(metadata_item.lt))
1505
+ case annotated_types.MultipleOf:
1506
+ fields.append(
1507
+ "// multiple of " + str(metadata_item.multiple_of)
1508
+ )
1509
+ case annotated_types.Len:
1510
+ fields.append("// length of " + str(metadata_item.len))
1511
+ case annotated_types.MinLen:
1512
+ fields.append(
1513
+ "// minimum length of " + str(metadata_item.min_len)
1514
+ )
1515
+ case annotated_types.MaxLen:
1516
+ fields.append(
1517
+ "// maximum length of " + str(metadata_item.max_len)
1518
+ )
1519
+ case _:
1520
+ fields.append("// " + str(metadata_item))
1521
+
1522
+ field_definition = f"{proto_typename} {field_name} = {index};"
1523
+ if is_optional:
1524
+ field_definition = f"optional {field_definition}"
1525
+
1526
+ fields.append(field_definition)
1527
+ index += 1
1528
+
1529
+ # Add reserved fields for forward/backward compatibility if specified
1530
+ reserved_count = get_reserved_fields_count()
1531
+ if reserved_count > 0:
1532
+ start_reserved = index
1533
+ end_reserved = index + reserved_count - 1
1534
+ fields.append("")
1535
+ fields.append("// Reserved fields for future compatibility")
1536
+ if reserved_count == 1:
1537
+ fields.append(f"reserved {start_reserved};")
758
1538
  else:
759
- proto_typename = protobuf_type_mapping(field_type)
760
- if proto_typename is None:
761
- raise Exception(f"Type {field_type} is not supported.")
762
-
763
- if is_enum_type(field_type):
764
- proto_typename = field_type.__name__
765
- if field_type not in done_enums:
766
- refs.append(field_type)
767
- elif inspect.isclass(field_type) and issubclass(field_type, Message):
768
- proto_typename = field_type.__name__
769
- if field_type not in done_messages:
770
- refs.append(field_type)
771
-
772
- if field_info.description:
773
- fields.append("// " + field_info.description)
774
- if field_info.metadata:
775
- fields.append("// Constraint:")
776
- for metadata_item in field_info.metadata:
777
- match type(metadata_item):
778
- case annotated_types.Ge:
779
- fields.append(
780
- "// greater than or equal to " + str(metadata_item.ge)
781
- )
782
- case annotated_types.Le:
783
- fields.append(
784
- "// less than or equal to " + str(metadata_item.le)
785
- )
786
- case annotated_types.Gt:
787
- fields.append("// greater than " + str(metadata_item.gt))
788
- case annotated_types.Lt:
789
- fields.append("// less than " + str(metadata_item.lt))
790
- case annotated_types.MultipleOf:
791
- fields.append(
792
- "// multiple of " + str(metadata_item.multiple_of)
793
- )
794
- case annotated_types.Len:
795
- fields.append("// length of " + str(metadata_item.len))
796
- case annotated_types.MinLen:
797
- fields.append(
798
- "// minimum length of " + str(metadata_item.min_len)
799
- )
800
- case annotated_types.MaxLen:
801
- fields.append(
802
- "// maximum length of " + str(metadata_item.max_len)
803
- )
804
- case _:
805
- fields.append("// " + str(metadata_item))
806
-
807
- fields.append(f"{proto_typename} {field_name} = {index};")
808
- index += 1
1539
+ fields.append(f"reserved {start_reserved} to {end_reserved};")
809
1540
 
810
1541
  msg_def = f"message {message_type.__name__} {{\n{indent_lines(fields)}\n}}"
811
1542
  return msg_def, refs
812
1543
 
813
1544
 
814
- def is_stream_type(annotation: Type) -> bool:
1545
+ def is_stream_type(annotation: Any) -> bool:
815
1546
  return get_origin(annotation) is AsyncIterator
816
1547
 
817
1548
 
818
- def is_generic_alias(annotation: Type) -> bool:
1549
+ def is_generic_alias(annotation: Any) -> bool:
819
1550
  return get_origin(annotation) is not None
820
1551
 
821
1552
 
822
1553
  def generate_proto(obj: object, package_name: str = "") -> str:
823
1554
  """
824
1555
  Generate a .proto definition from a service class.
825
- Automatically handles Timestamp and Duration usage.
1556
+ Automatically handles Timestamp, Duration, and Empty usage.
826
1557
  """
827
1558
  import datetime
828
1559
 
@@ -831,20 +1562,23 @@ def generate_proto(obj: object, package_name: str = "") -> str:
831
1562
  service_docstr = inspect.getdoc(service_class)
832
1563
  service_comment = "\n".join(comment_out(service_docstr)) if service_docstr else ""
833
1564
 
834
- rpc_definitions = []
835
- all_type_definitions = []
836
- done_messages = set()
837
- done_enums = set()
1565
+ rpc_definitions: list[str] = []
1566
+ all_type_definitions: list[str] = []
1567
+ done_messages: set[Any] = set()
1568
+ done_enums: set[Any] = set()
838
1569
 
839
1570
  uses_timestamp = False
840
1571
  uses_duration = False
1572
+ uses_empty = False
841
1573
 
842
- def check_and_set_well_known_types(py_type: Type):
1574
+ def check_and_set_well_known_types_for_fields(py_type: Any):
1575
+ """Check well-known types for field annotations (excludes None/Empty)."""
843
1576
  nonlocal uses_timestamp, uses_duration
844
1577
  if py_type == datetime.datetime:
845
1578
  uses_timestamp = True
846
1579
  if py_type == datetime.timedelta:
847
1580
  uses_duration = True
1581
+ # Don't check for None here - Optional fields don't use Empty
848
1582
 
849
1583
  for method_name, method in get_rpc_methods(obj):
850
1584
  if method.__name__.startswith("_"):
@@ -854,35 +1588,79 @@ def generate_proto(obj: object, package_name: str = "") -> str:
854
1588
  request_type = get_request_arg_type(method_sig)
855
1589
  response_type = method_sig.return_annotation
856
1590
 
1591
+ # Validate that we don't have AsyncIterator[None] which doesn't make any practical sense
1592
+ if is_stream_type(request_type):
1593
+ stream_item_type = get_args(request_type)[0]
1594
+ if is_none_type(stream_item_type):
1595
+ raise TypeError(
1596
+ f"Method '{method_name}' has AsyncIterator[None] as input type, which is not allowed. Streaming Empty messages is meaningless."
1597
+ )
1598
+
1599
+ if is_stream_type(response_type):
1600
+ stream_item_type = get_args(response_type)[0]
1601
+ if is_none_type(stream_item_type):
1602
+ raise TypeError(
1603
+ f"Method '{method_name}' has AsyncIterator[None] as return type, which is not allowed. Streaming Empty messages is meaningless."
1604
+ )
1605
+
1606
+ # Handle NoneType for request and response
1607
+ if is_none_type(request_type):
1608
+ uses_empty = True
1609
+ if is_none_type(response_type):
1610
+ uses_empty = True
1611
+
857
1612
  # Recursively generate message definitions
858
- message_types = [request_type, response_type]
1613
+ message_types = []
1614
+ if not is_none_type(request_type):
1615
+ message_types.append(request_type)
1616
+ if not is_none_type(response_type):
1617
+ message_types.append(response_type)
1618
+
859
1619
  while message_types:
860
- mt = message_types.pop()
861
- if mt in done_messages:
1620
+ mt: type[Message] | type[ServicerContext] | None = message_types.pop()
1621
+ if mt in done_messages or mt is ServicerContext or mt is None:
862
1622
  continue
863
1623
  done_messages.add(mt)
864
1624
 
865
1625
  if is_stream_type(mt):
866
1626
  item_type = get_args(mt)[0]
867
- message_types.append(item_type)
1627
+ if not is_none_type(item_type):
1628
+ message_types.append(item_type)
868
1629
  continue
869
1630
 
1631
+ # Check if mt is actually a Message type
1632
+ if not (inspect.isclass(mt) and issubclass(mt, Message)):
1633
+ # Not a Message type, skip processing
1634
+ continue
1635
+
1636
+ mt = cast(type[Message], mt)
1637
+
870
1638
  for _, field_info in mt.model_fields.items():
871
1639
  t = field_info.annotation
872
1640
  if is_union_type(t):
873
1641
  for sub_t in flatten_union(t):
874
- check_and_set_well_known_types(sub_t)
1642
+ check_and_set_well_known_types_for_fields(
1643
+ sub_t
1644
+ ) # Use the field-specific version
875
1645
  else:
876
- check_and_set_well_known_types(t)
1646
+ check_and_set_well_known_types_for_fields(
1647
+ t
1648
+ ) # Use the field-specific version
877
1649
 
878
1650
  msg_def, refs = generate_message_definition(mt, done_enums, done_messages)
879
- mt_doc = inspect.getdoc(mt)
880
- if mt_doc:
881
- for comment_line in comment_out(mt_doc):
882
- all_type_definitions.append(comment_line)
883
1651
 
884
- all_type_definitions.append(msg_def)
885
- all_type_definitions.append("")
1652
+ # Skip adding definition if it's an empty message (will use google.protobuf.Empty)
1653
+ if msg_def != "__EMPTY__":
1654
+ mt_doc = inspect.getdoc(mt)
1655
+ if mt_doc:
1656
+ for comment_line in comment_out(mt_doc):
1657
+ all_type_definitions.append(comment_line)
1658
+
1659
+ all_type_definitions.append(msg_def)
1660
+ all_type_definitions.append("")
1661
+ else:
1662
+ # Mark that we need google.protobuf.Empty import
1663
+ uses_empty = True
886
1664
 
887
1665
  for r in refs:
888
1666
  if is_enum_type(r) and r not in done_enums:
@@ -898,16 +1676,63 @@ def generate_proto(obj: object, package_name: str = "") -> str:
898
1676
  for comment_line in comment_out(method_docstr):
899
1677
  rpc_definitions.append(comment_line)
900
1678
 
901
- if is_stream_type(response_type):
902
- item_type = get_args(response_type)[0]
903
- rpc_definitions.append(
904
- f"rpc {method_name} ({request_type.__name__}) returns (stream {item_type.__name__});"
1679
+ input_type = request_type
1680
+ input_is_stream = is_stream_type(input_type)
1681
+ output_is_stream = is_stream_type(response_type)
1682
+
1683
+ if input_is_stream:
1684
+ input_msg_type = get_args(input_type)[0]
1685
+ else:
1686
+ input_msg_type = input_type
1687
+
1688
+ if output_is_stream:
1689
+ output_msg_type = get_args(response_type)[0]
1690
+ else:
1691
+ output_msg_type = response_type
1692
+
1693
+ # Handle NoneType and empty messages by using Empty
1694
+ if input_msg_type is None or input_msg_type is ServicerContext:
1695
+ input_str = "google.protobuf.Empty" # No need to check for stream since we validated above
1696
+ if input_msg_type is ServicerContext:
1697
+ uses_empty = True
1698
+ elif (
1699
+ inspect.isclass(input_msg_type)
1700
+ and issubclass(input_msg_type, Message)
1701
+ and not input_msg_type.model_fields
1702
+ ):
1703
+ # Empty message class
1704
+ input_str = "google.protobuf.Empty"
1705
+ uses_empty = True
1706
+ else:
1707
+ input_str = (
1708
+ f"stream {input_msg_type.__name__}"
1709
+ if input_is_stream
1710
+ else input_msg_type.__name__
905
1711
  )
1712
+
1713
+ if output_msg_type is None or output_msg_type is ServicerContext:
1714
+ output_str = "google.protobuf.Empty" # No need to check for stream since we validated above
1715
+ if output_msg_type is ServicerContext:
1716
+ uses_empty = True
1717
+ elif (
1718
+ inspect.isclass(output_msg_type)
1719
+ and issubclass(output_msg_type, Message)
1720
+ and not output_msg_type.model_fields
1721
+ ):
1722
+ # Empty message class
1723
+ output_str = "google.protobuf.Empty"
1724
+ uses_empty = True
906
1725
  else:
907
- rpc_definitions.append(
908
- f"rpc {method_name} ({request_type.__name__}) returns ({response_type.__name__});"
1726
+ output_str = (
1727
+ f"stream {output_msg_type.__name__}"
1728
+ if output_is_stream
1729
+ else output_msg_type.__name__
909
1730
  )
910
1731
 
1732
+ rpc_definitions.append(
1733
+ f"rpc {method_name} ({input_str}) returns ({output_str});"
1734
+ )
1735
+
911
1736
  if not package_name:
912
1737
  if service_name.endswith("Service"):
913
1738
  package_name = service_name[: -len("Service")]
@@ -915,11 +1740,13 @@ def generate_proto(obj: object, package_name: str = "") -> str:
915
1740
  package_name = service_name
916
1741
  package_name = package_name.lower() + ".v1"
917
1742
 
918
- imports = []
1743
+ imports: list[str] = []
919
1744
  if uses_timestamp:
920
1745
  imports.append('import "google/protobuf/timestamp.proto";')
921
1746
  if uses_duration:
922
1747
  imports.append('import "google/protobuf/duration.proto";')
1748
+ if uses_empty:
1749
+ imports.append('import "google/protobuf/empty.proto";')
923
1750
 
924
1751
  import_block = "\n".join(imports)
925
1752
  if import_block:
@@ -939,106 +1766,183 @@ service {service_name} {{
939
1766
  return proto_definition
940
1767
 
941
1768
 
942
- def generate_grpc_code(proto_file, grpc_python_out) -> types.ModuleType | None:
1769
+ def generate_grpc_code(proto_path: Path) -> types.ModuleType | None:
943
1770
  """
944
- Execute the protoc command to generate Python gRPC code from the .proto file.
945
- Returns a tuple of (pb2_grpc_module, pb2_module) on success, or None if failed.
1771
+ Run protoc to generate Python gRPC code from proto_path.
1772
+ Writes foo_pb2_grpc.py next to proto_path, then imports and returns that module.
946
1773
  """
947
- command = f"-I. --grpc_python_out={grpc_python_out} {proto_file}"
948
- exit_code = protoc.main(command.split())
949
- if exit_code != 0:
950
- return None
1774
+ # 1) Ensure the .proto exists
1775
+ if not proto_path.is_file():
1776
+ raise FileNotFoundError(f"{proto_path!r} does not exist")
1777
+
1778
+ # 2) Determine output directory (same as the .proto's parent)
1779
+ proto_path = proto_path.resolve()
1780
+ out_dir = proto_path.parent
1781
+ out_dir.mkdir(parents=True, exist_ok=True)
1782
+
1783
+ # 3) Build and run the protoc command
1784
+ out_str = str(out_dir)
1785
+ well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
1786
+ args = [
1787
+ "protoc", # Dummy program name (required for protoc.main)
1788
+ "-I.",
1789
+ f"-I{well_known_path}",
1790
+ f"--grpc_python_out={out_str}",
1791
+ proto_path.name,
1792
+ ]
1793
+
1794
+ current_dir = os.getcwd()
1795
+ os.chdir(str(out_dir))
1796
+ try:
1797
+ if protoc.main(args) != 0:
1798
+ return None
1799
+ finally:
1800
+ os.chdir(current_dir)
951
1801
 
952
- base = os.path.splitext(proto_file)[0]
953
- generated_pb2_grpc_file = f"{base}_pb2_grpc.py"
1802
+ # 4) Locate the generated gRPC file
1803
+ base_name = proto_path.stem # "foo"
1804
+ generated_filename = f"{base_name}_pb2_grpc.py" # "foo_pb2_grpc.py"
1805
+ generated_filepath = out_dir / generated_filename
954
1806
 
955
- if grpc_python_out not in sys.path:
956
- sys.path.append(grpc_python_out)
1807
+ # 5) Add out_dir to sys.path so we can import it
1808
+ if out_str not in sys.path:
1809
+ sys.path.append(out_str)
957
1810
 
1811
+ # 6) Load and return the module
958
1812
  spec = importlib.util.spec_from_file_location(
959
- generated_pb2_grpc_file, os.path.join(grpc_python_out, generated_pb2_grpc_file)
1813
+ base_name + "_pb2_grpc", str(generated_filepath)
960
1814
  )
961
- if spec is None:
1815
+ if spec is None or spec.loader is None:
962
1816
  return None
963
- pb2_grpc_module = importlib.util.module_from_spec(spec)
964
- if spec.loader is None:
965
- return None
966
- spec.loader.exec_module(pb2_grpc_module)
967
1817
 
968
- return pb2_grpc_module
1818
+ module = importlib.util.module_from_spec(spec)
1819
+ spec.loader.exec_module(module)
1820
+ return module
969
1821
 
970
1822
 
971
- def generate_connecpy_code(
972
- proto_file: str, connecpy_out: str
973
- ) -> types.ModuleType | None:
1823
+ def generate_connecpy_code(proto_path: Path) -> types.ModuleType | None:
974
1824
  """
975
- Execute the protoc command to generate Python Connecpy code from the .proto file.
976
- Returns a tuple of (connecpy_module, pb2_module) on success, or None if failed.
1825
+ Run protoc with the Connecpy plugin to generate Python Connecpy code from proto_path.
1826
+ Writes foo_connecpy.py next to proto_path, then imports and returns that module.
977
1827
  """
978
- command = f"-I. --connecpy_out={connecpy_out} {proto_file}"
979
- exit_code = protoc.main(command.split())
980
- if exit_code != 0:
981
- return None
1828
+ # 1) Ensure the .proto exists
1829
+ if not proto_path.is_file():
1830
+ raise FileNotFoundError(f"{proto_path!r} does not exist")
1831
+
1832
+ # 2) Determine output directory (same as the .proto's parent)
1833
+ proto_path = proto_path.resolve()
1834
+ out_dir = proto_path.parent
1835
+ out_dir.mkdir(parents=True, exist_ok=True)
1836
+
1837
+ # 3) Build and run the protoc command
1838
+ out_str = str(out_dir)
1839
+ well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
1840
+ args = [
1841
+ "protoc", # Dummy program name (required for protoc.main)
1842
+ "-I.",
1843
+ f"-I{well_known_path}",
1844
+ f"--connecpy_out={out_str}",
1845
+ proto_path.name,
1846
+ ]
982
1847
 
983
- base = os.path.splitext(proto_file)[0]
984
- generated_connecpy_file = f"{base}_connecpy.py"
1848
+ current_dir = os.getcwd()
1849
+ os.chdir(str(out_dir))
1850
+ try:
1851
+ if protoc.main(args) != 0:
1852
+ return None
1853
+ finally:
1854
+ os.chdir(current_dir)
985
1855
 
986
- if connecpy_out not in sys.path:
987
- sys.path.append(connecpy_out)
1856
+ # 4) Locate the generated file
1857
+ base_name = proto_path.stem # "foo"
1858
+ generated_filename = f"{base_name}_connecpy.py" # "foo_connecpy.py"
1859
+ generated_filepath = out_dir / generated_filename
988
1860
 
1861
+ # 5) Add out_dir to sys.path so we can import by filename
1862
+ if out_str not in sys.path:
1863
+ sys.path.append(out_str)
1864
+
1865
+ # 6) Load and return the module
989
1866
  spec = importlib.util.spec_from_file_location(
990
- generated_connecpy_file, os.path.join(connecpy_out, generated_connecpy_file)
1867
+ base_name + "_connecpy", str(generated_filepath)
991
1868
  )
992
- if spec is None:
993
- return None
994
- connecpy_module = importlib.util.module_from_spec(spec)
995
- if spec.loader is None:
1869
+ if spec is None or spec.loader is None:
996
1870
  return None
997
- spec.loader.exec_module(connecpy_module)
998
1871
 
999
- return connecpy_module
1872
+ module = importlib.util.module_from_spec(spec)
1873
+ spec.loader.exec_module(module)
1874
+ return module
1000
1875
 
1001
1876
 
1002
- def generate_pb_code(
1003
- proto_file: str, python_out: str, pyi_out: str
1004
- ) -> types.ModuleType | None:
1877
+ def generate_pb_code(proto_path: Path) -> types.ModuleType | None:
1005
1878
  """
1006
- Execute the protoc command to generate Python gRPC code from the .proto file.
1007
- Returns a tuple of (pb2_grpc_module, pb2_module) on success, or None if failed.
1879
+ Run protoc to generate Python gRPC code from proto_path.
1880
+ Writes foo_pb2.py and foo_pb2.pyi next to proto_path, then imports and returns the pb2 module.
1008
1881
  """
1009
- command = f"-I. --python_out={python_out} --pyi_out={pyi_out} {proto_file}"
1010
- exit_code = protoc.main(command.split())
1011
- if exit_code != 0:
1012
- return None
1882
+ # 1) Make sure proto_path exists
1883
+ if not proto_path.is_file():
1884
+ raise FileNotFoundError(f"{proto_path!r} does not exist")
1885
+
1886
+ # 2) Determine output directory (same as proto file)
1887
+ proto_path = proto_path.resolve()
1888
+ out_dir = proto_path.parent
1889
+ out_dir.mkdir(parents=True, exist_ok=True)
1890
+
1891
+ # 3) Build and run protoc command
1892
+ out_str = str(out_dir)
1893
+ well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
1894
+ args = [
1895
+ "protoc", # Dummy program name (required for protoc.main)
1896
+ "-I.",
1897
+ f"-I{well_known_path}",
1898
+ f"--python_out={out_str}",
1899
+ f"--pyi_out={out_str}",
1900
+ proto_path.name,
1901
+ ]
1902
+
1903
+ current_dir = os.getcwd()
1904
+ os.chdir(str(out_dir))
1905
+ try:
1906
+ if protoc.main(args) != 0:
1907
+ return None
1908
+ finally:
1909
+ os.chdir(current_dir)
1013
1910
 
1014
- base = os.path.splitext(proto_file)[0]
1015
- generated_pb2_file = f"{base}_pb2.py"
1911
+ # 4) Locate generated file
1912
+ base_name = proto_path.stem # e.g. "foo"
1913
+ generated_file = out_dir / f"{base_name}_pb2.py"
1016
1914
 
1017
- if python_out not in sys.path:
1018
- sys.path.append(python_out)
1915
+ # 5) Add to sys.path if needed
1916
+ if out_str not in sys.path:
1917
+ sys.path.append(out_str)
1019
1918
 
1919
+ # 6) Import it
1020
1920
  spec = importlib.util.spec_from_file_location(
1021
- generated_pb2_file, os.path.join(python_out, generated_pb2_file)
1921
+ base_name + "_pb2", str(generated_file)
1022
1922
  )
1023
- if spec is None:
1024
- return None
1025
- pb2_module = importlib.util.module_from_spec(spec)
1026
- if spec.loader is None:
1923
+ if spec is None or spec.loader is None:
1027
1924
  return None
1028
- spec.loader.exec_module(pb2_module)
1925
+ module = importlib.util.module_from_spec(spec)
1926
+ spec.loader.exec_module(module)
1927
+ return module
1029
1928
 
1030
- return pb2_module
1031
1929
 
1930
+ def get_request_arg_type(sig: inspect.Signature) -> Any:
1931
+ """Return the type annotation of the first parameter (request) of a method.
1032
1932
 
1033
- def get_request_arg_type(sig):
1034
- """Return the type annotation of the first parameter (request) of a method."""
1933
+ If the method has no parameters, return None (implying an empty request).
1934
+ """
1035
1935
  num_of_params = len(sig.parameters)
1036
- if not (num_of_params == 1 or num_of_params == 2):
1037
- raise Exception("Method must have exactly one or two parameters")
1038
- return tuple(sig.parameters.values())[0].annotation
1936
+ if num_of_params == 0:
1937
+ # No parameters means empty request
1938
+ return None
1939
+ elif num_of_params == 1 or num_of_params == 2:
1940
+ return tuple(sig.parameters.values())[0].annotation
1941
+ else:
1942
+ raise Exception("Method must have 0, 1, or 2 parameters")
1039
1943
 
1040
1944
 
1041
- def get_rpc_methods(obj: object) -> list[tuple[str, types.MethodType]]:
1945
+ def get_rpc_methods(obj: object) -> list[tuple[str, Callable[..., Any]]]:
1042
1946
  """
1043
1947
  Retrieve the list of RPC methods from a service object.
1044
1948
  The method name is converted to PascalCase for .proto compatibility.
@@ -1059,68 +1963,429 @@ def is_skip_generation() -> bool:
1059
1963
  return os.getenv("PYDANTIC_RPC_SKIP_GENERATION", "false").lower() == "true"
1060
1964
 
1061
1965
 
1062
- def generate_and_compile_proto(obj: object, package_name: str = ""):
1966
+ def get_reserved_fields_count() -> int:
1967
+ """Get the number of reserved fields to add to each message from environment variable."""
1968
+ try:
1969
+ return max(0, int(os.getenv("PYDANTIC_RPC_RESERVED_FIELDS", "0")))
1970
+ except ValueError:
1971
+ return 0
1972
+
1973
+
1974
+ def generate_and_compile_proto(
1975
+ obj: object,
1976
+ package_name: str = "",
1977
+ existing_proto_path: Path | None = None,
1978
+ ) -> tuple[Any, Any]:
1063
1979
  if is_skip_generation():
1064
1980
  import importlib
1065
1981
 
1066
- pb2_module = importlib.import_module(f"{obj.__class__.__name__.lower()}_pb2")
1067
- pb2_grpc_module = importlib.import_module(
1068
- f"{obj.__class__.__name__.lower()}_pb2_grpc"
1069
- )
1982
+ pb2_module = None
1983
+ pb2_grpc_module = None
1984
+
1985
+ try:
1986
+ pb2_module = importlib.import_module(
1987
+ f"{obj.__class__.__name__.lower()}_pb2"
1988
+ )
1989
+ except ImportError:
1990
+ pass
1991
+
1992
+ try:
1993
+ pb2_grpc_module = importlib.import_module(
1994
+ f"{obj.__class__.__name__.lower()}_pb2_grpc"
1995
+ )
1996
+ except ImportError:
1997
+ pass
1070
1998
 
1071
1999
  if pb2_grpc_module is not None and pb2_module is not None:
1072
2000
  return pb2_grpc_module, pb2_module
1073
2001
 
1074
2002
  # If the modules are not found, generate and compile the proto files.
1075
2003
 
1076
- klass = obj.__class__
1077
- proto_file = generate_proto(obj, package_name)
1078
- proto_file_name = klass.__name__.lower() + ".proto"
2004
+ if existing_proto_path:
2005
+ # Use the provided existing proto file (skip generation)
2006
+ proto_file_path = existing_proto_path
2007
+ else:
2008
+ # Generate as before
2009
+ klass = obj.__class__
2010
+ proto_file = generate_proto(obj, package_name)
2011
+ proto_file_name = klass.__name__.lower() + ".proto"
2012
+ proto_file_path = get_proto_path(proto_file_name)
1079
2013
 
1080
- with open(proto_file_name, "w", encoding="utf-8") as f:
1081
- f.write(proto_file)
2014
+ with proto_file_path.open(mode="w", encoding="utf-8") as f:
2015
+ _ = f.write(proto_file)
1082
2016
 
1083
- gen_pb = generate_pb_code(proto_file_name, ".", ".")
2017
+ gen_pb = generate_pb_code(proto_file_path)
1084
2018
  if gen_pb is None:
1085
2019
  raise Exception("Generating pb code")
1086
2020
 
1087
- gen_grpc = generate_grpc_code(proto_file_name, ".")
2021
+ gen_grpc = generate_grpc_code(proto_file_path)
1088
2022
  if gen_grpc is None:
1089
2023
  raise Exception("Generating grpc code")
1090
2024
  return gen_grpc, gen_pb
1091
2025
 
1092
2026
 
1093
- def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = ""):
2027
+ def get_proto_path(proto_filename: str) -> Path:
2028
+ # 1. Get raw env var (or default to cwd)
2029
+ raw = os.getenv("PYDANTIC_RPC_PROTO_PATH", None)
2030
+ base = Path(raw) if raw is not None else Path.cwd()
2031
+
2032
+ # 2. Expand ~ and env-vars, then make absolute
2033
+ base = Path(os.path.expandvars(os.path.expanduser(str(base)))).resolve()
2034
+
2035
+ # 3. Ensure it's a directory (or create it)
2036
+ if not base.exists():
2037
+ try:
2038
+ base.mkdir(parents=True, exist_ok=True)
2039
+ except OSError as e:
2040
+ raise RuntimeError(f"Unable to create directory {base!r}: {e}") from e
2041
+ elif not base.is_dir():
2042
+ raise NotADirectoryError(f"{base!r} exists but is not a directory")
2043
+
2044
+ # 4. Check writability
2045
+ if not os.access(base, os.W_OK):
2046
+ raise PermissionError(f"No write permission for directory {base!r}")
2047
+
2048
+ # 5. Return the final file path
2049
+ return base / proto_filename
2050
+
2051
+
2052
+ def generate_and_compile_proto_using_connecpy(
2053
+ obj: object,
2054
+ package_name: str = "",
2055
+ existing_proto_path: Path | None = None,
2056
+ ) -> tuple[Any, Any]:
1094
2057
  if is_skip_generation():
1095
2058
  import importlib
1096
2059
 
1097
- pb2_module = importlib.import_module(f"{obj.__class__.__name__.lower()}_pb2")
1098
- connecpy_module = importlib.import_module(
1099
- f"{obj.__class__.__name__.lower()}_connecpy"
1100
- )
2060
+ pb2_module = None
2061
+ connecpy_module = None
2062
+
2063
+ try:
2064
+ pb2_module = importlib.import_module(
2065
+ f"{obj.__class__.__name__.lower()}_pb2"
2066
+ )
2067
+ except ImportError:
2068
+ pass
2069
+
2070
+ try:
2071
+ connecpy_module = importlib.import_module(
2072
+ f"{obj.__class__.__name__.lower()}_connecpy"
2073
+ )
2074
+ except ImportError:
2075
+ pass
1101
2076
 
1102
2077
  if connecpy_module is not None and pb2_module is not None:
1103
2078
  return connecpy_module, pb2_module
1104
2079
 
1105
2080
  # If the modules are not found, generate and compile the proto files.
1106
2081
 
1107
- klass = obj.__class__
1108
- proto_file = generate_proto(obj, package_name)
1109
- proto_file_name = klass.__name__.lower() + ".proto"
2082
+ if existing_proto_path:
2083
+ # Use the provided existing proto file (skip generation)
2084
+ proto_file_path = existing_proto_path
2085
+ else:
2086
+ # Generate as before
2087
+ klass = obj.__class__
2088
+ proto_file = generate_proto(obj, package_name)
2089
+ proto_file_name = klass.__name__.lower() + ".proto"
1110
2090
 
1111
- with open(proto_file_name, "w", encoding="utf-8") as f:
1112
- f.write(proto_file)
2091
+ proto_file_path = get_proto_path(proto_file_name)
2092
+ with proto_file_path.open(mode="w", encoding="utf-8") as f:
2093
+ _ = f.write(proto_file)
1113
2094
 
1114
- gen_pb = generate_pb_code(proto_file_name, ".", ".")
2095
+ gen_pb = generate_pb_code(proto_file_path)
1115
2096
  if gen_pb is None:
1116
2097
  raise Exception("Generating pb code")
1117
2098
 
1118
- gen_connecpy = generate_connecpy_code(proto_file_name, ".")
2099
+ gen_connecpy = generate_connecpy_code(proto_file_path)
1119
2100
  if gen_connecpy is None:
1120
2101
  raise Exception("Generating Connecpy code")
1121
2102
  return gen_connecpy, gen_pb
1122
2103
 
1123
2104
 
2105
+ def is_combined_proto_enabled() -> bool:
2106
+ """Check if combined proto file generation is enabled."""
2107
+ return os.getenv("PYDANTIC_RPC_COMBINED_PROTO", "false").lower() == "true"
2108
+
2109
+
2110
+ def generate_combined_proto(
2111
+ *services: object, package_name: str = "combined.v1"
2112
+ ) -> str:
2113
+ """Generate a combined .proto definition from multiple service classes."""
2114
+ import datetime
2115
+
2116
+ all_type_definitions: list[str] = []
2117
+ all_service_definitions: list[str] = []
2118
+ done_messages: set[Any] = set()
2119
+ done_enums: set[Any] = set()
2120
+
2121
+ uses_timestamp = False
2122
+ uses_duration = False
2123
+ uses_empty = False
2124
+
2125
+ def check_and_set_well_known_types_for_fields(py_type: Any):
2126
+ """Check well-known types for field annotations (excludes None/Empty)."""
2127
+ nonlocal uses_timestamp, uses_duration
2128
+ if py_type == datetime.datetime:
2129
+ uses_timestamp = True
2130
+ if py_type == datetime.timedelta:
2131
+ uses_duration = True
2132
+
2133
+ # Process each service
2134
+ for service_obj in services:
2135
+ service_class = service_obj.__class__
2136
+ service_name = service_class.__name__
2137
+ service_docstr = inspect.getdoc(service_class)
2138
+ service_comment = (
2139
+ "\n".join(comment_out(service_docstr)) if service_docstr else ""
2140
+ )
2141
+
2142
+ service_rpc_definitions: list[str] = []
2143
+
2144
+ for method_name, method in get_rpc_methods(service_obj):
2145
+ if method.__name__.startswith("_"):
2146
+ continue
2147
+
2148
+ method_sig = inspect.signature(method)
2149
+ request_type = get_request_arg_type(method_sig)
2150
+ response_type = method_sig.return_annotation
2151
+
2152
+ # Validate stream types
2153
+ if is_stream_type(request_type):
2154
+ stream_item_type = get_args(request_type)[0]
2155
+ if is_none_type(stream_item_type):
2156
+ raise TypeError(
2157
+ f"Method '{method_name}' has AsyncIterator[None] as input type, which is not allowed."
2158
+ )
2159
+
2160
+ if is_stream_type(response_type):
2161
+ stream_item_type = get_args(response_type)[0]
2162
+ if is_none_type(stream_item_type):
2163
+ raise TypeError(
2164
+ f"Method '{method_name}' has AsyncIterator[None] as return type, which is not allowed."
2165
+ )
2166
+
2167
+ # Handle NoneType for request and response
2168
+ if is_none_type(request_type):
2169
+ uses_empty = True
2170
+ if is_none_type(response_type):
2171
+ uses_empty = True
2172
+
2173
+ # Validate that users aren't using protobuf messages directly
2174
+ if hasattr(request_type, "DESCRIPTOR") and not is_none_type(request_type):
2175
+ raise TypeError(
2176
+ f"Method '{method_name}' uses protobuf message '{request_type.__name__}' directly. "
2177
+ f"Please use Pydantic Message classes instead, or None/empty Message for empty requests."
2178
+ )
2179
+ if hasattr(response_type, "DESCRIPTOR") and not is_none_type(response_type):
2180
+ raise TypeError(
2181
+ f"Method '{method_name}' uses protobuf message '{response_type.__name__}' directly. "
2182
+ f"Please use Pydantic Message classes instead, or None/empty Message for empty responses."
2183
+ )
2184
+
2185
+ # Collect message types for processing
2186
+ message_types = []
2187
+ if not is_none_type(request_type):
2188
+ message_types.append(request_type)
2189
+ if not is_none_type(response_type):
2190
+ message_types.append(response_type)
2191
+
2192
+ # Process message types
2193
+ while message_types:
2194
+ mt: type[Message] | type[ServicerContext] | None = message_types.pop()
2195
+ if mt in done_messages or mt is ServicerContext or mt is None:
2196
+ continue
2197
+ done_messages.add(mt)
2198
+
2199
+ if is_stream_type(mt):
2200
+ item_type = get_args(mt)[0]
2201
+ if not is_none_type(item_type):
2202
+ message_types.append(item_type)
2203
+ continue
2204
+
2205
+ mt = cast(type[Message], mt)
2206
+
2207
+ for _, field_info in mt.model_fields.items():
2208
+ t = field_info.annotation
2209
+ if is_union_type(t):
2210
+ for sub_t in flatten_union(t):
2211
+ check_and_set_well_known_types_for_fields(sub_t)
2212
+ else:
2213
+ check_and_set_well_known_types_for_fields(t)
2214
+
2215
+ msg_def, refs = generate_message_definition(
2216
+ mt, done_enums, done_messages
2217
+ )
2218
+
2219
+ # Skip adding definition if it's an empty message (will use google.protobuf.Empty)
2220
+ if msg_def != "__EMPTY__":
2221
+ mt_doc = inspect.getdoc(mt)
2222
+ if mt_doc:
2223
+ for comment_line in comment_out(mt_doc):
2224
+ all_type_definitions.append(comment_line)
2225
+
2226
+ all_type_definitions.append(msg_def)
2227
+ all_type_definitions.append("")
2228
+ else:
2229
+ # Mark that we need google.protobuf.Empty import
2230
+ uses_empty = True
2231
+
2232
+ for r in refs:
2233
+ if is_enum_type(r) and r not in done_enums:
2234
+ done_enums.add(r)
2235
+ enum_def = generate_enum_definition(r)
2236
+ all_type_definitions.append(enum_def)
2237
+ all_type_definitions.append("")
2238
+ elif issubclass(r, Message) and r not in done_messages:
2239
+ message_types.append(r)
2240
+
2241
+ # Generate RPC definition
2242
+ method_docstr = inspect.getdoc(method)
2243
+ if method_docstr:
2244
+ for comment_line in comment_out(method_docstr):
2245
+ service_rpc_definitions.append(comment_line)
2246
+
2247
+ input_type = request_type
2248
+ input_is_stream = is_stream_type(input_type)
2249
+ output_is_stream = is_stream_type(response_type)
2250
+
2251
+ if input_is_stream:
2252
+ input_msg_type = get_args(input_type)[0]
2253
+ else:
2254
+ input_msg_type = input_type
2255
+
2256
+ if output_is_stream:
2257
+ output_msg_type = get_args(response_type)[0]
2258
+ else:
2259
+ output_msg_type = response_type
2260
+
2261
+ # Handle NoneType and empty messages by using Empty
2262
+ if input_msg_type is None or input_msg_type is ServicerContext:
2263
+ input_str = "google.protobuf.Empty" # No need to check for stream since we validated above
2264
+ if input_msg_type is ServicerContext:
2265
+ uses_empty = True
2266
+ elif (
2267
+ inspect.isclass(input_msg_type)
2268
+ and issubclass(input_msg_type, Message)
2269
+ and not input_msg_type.model_fields
2270
+ ):
2271
+ # Empty message class
2272
+ input_str = "google.protobuf.Empty"
2273
+ uses_empty = True
2274
+ else:
2275
+ input_str = (
2276
+ f"stream {input_msg_type.__name__}"
2277
+ if input_is_stream
2278
+ else input_msg_type.__name__
2279
+ )
2280
+
2281
+ if output_msg_type is None or output_msg_type is ServicerContext:
2282
+ output_str = "google.protobuf.Empty" # No need to check for stream since we validated above
2283
+ if output_msg_type is ServicerContext:
2284
+ uses_empty = True
2285
+ elif (
2286
+ inspect.isclass(output_msg_type)
2287
+ and issubclass(output_msg_type, Message)
2288
+ and not output_msg_type.model_fields
2289
+ ):
2290
+ # Empty message class
2291
+ output_str = "google.protobuf.Empty"
2292
+ uses_empty = True
2293
+ else:
2294
+ output_str = (
2295
+ f"stream {output_msg_type.__name__}"
2296
+ if output_is_stream
2297
+ else output_msg_type.__name__
2298
+ )
2299
+
2300
+ service_rpc_definitions.append(
2301
+ f"rpc {method_name} ({input_str}) returns ({output_str});"
2302
+ )
2303
+
2304
+ # Create service definition
2305
+ service_def_lines: list[str] = []
2306
+ if service_comment:
2307
+ service_def_lines.append(service_comment)
2308
+ service_def_lines.append(f"service {service_name} {{")
2309
+ service_def_lines.extend([f" {line}" for line in service_rpc_definitions])
2310
+ service_def_lines.append("}")
2311
+ service_def_lines.append("")
2312
+
2313
+ all_service_definitions.extend(service_def_lines)
2314
+
2315
+ # Build imports
2316
+ imports: list[str] = []
2317
+ if uses_timestamp:
2318
+ imports.append('import "google/protobuf/timestamp.proto";')
2319
+ if uses_duration:
2320
+ imports.append('import "google/protobuf/duration.proto";')
2321
+ if uses_empty:
2322
+ imports.append('import "google/protobuf/empty.proto";')
2323
+
2324
+ import_block = "\n".join(imports)
2325
+ if import_block:
2326
+ import_block += "\n"
2327
+
2328
+ # Combine everything
2329
+ service_defs = "\n".join(all_service_definitions)
2330
+ type_defs = "\n".join(all_type_definitions)
2331
+ proto_definition = f"""syntax = "proto3";
2332
+
2333
+ package {package_name};
2334
+
2335
+ {import_block}{service_defs}
2336
+ {type_defs}
2337
+ """
2338
+ return proto_definition
2339
+
2340
+
2341
+ def get_combined_proto_filename() -> str:
2342
+ """Get the combined proto filename."""
2343
+ return os.getenv("PYDANTIC_RPC_COMBINED_PROTO_FILENAME", "combined_services.proto")
2344
+
2345
+
2346
+ def generate_combined_descriptor_set(
2347
+ *services: object, output_path: Path | None = None
2348
+ ) -> bytes:
2349
+ """Generate a combined protobuf descriptor set from multiple services."""
2350
+ filename = get_combined_proto_filename()
2351
+
2352
+ if output_path is None:
2353
+ output_path = get_proto_path(filename)
2354
+
2355
+ # Generate combined proto file
2356
+ combined_proto = generate_combined_proto(*services)
2357
+ proto_file_path = get_proto_path(filename)
2358
+
2359
+ with proto_file_path.open(mode="w", encoding="utf-8") as f:
2360
+ _ = f.write(combined_proto)
2361
+
2362
+ # Generate descriptor set using protoc
2363
+ out_str = str(proto_file_path.parent)
2364
+ well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
2365
+ args = [
2366
+ "protoc",
2367
+ f"-I{out_str}",
2368
+ f"-I{well_known_path}",
2369
+ f"--descriptor_set_out={output_path}",
2370
+ "--include_imports",
2371
+ proto_file_path.name,
2372
+ ]
2373
+
2374
+ current_dir = os.getcwd()
2375
+ os.chdir(out_str)
2376
+ try:
2377
+ if protoc.main(args) != 0:
2378
+ raise RuntimeError("Failed to generate combined descriptor set")
2379
+ finally:
2380
+ os.chdir(current_dir)
2381
+
2382
+ # Read and return the descriptor set
2383
+ with open(output_path, "rb") as f:
2384
+ descriptor_data = f.read()
2385
+
2386
+ return descriptor_data
2387
+
2388
+
1124
2389
  ###############################################################################
1125
2390
  # 4. Server Implementations
1126
2391
  ###############################################################################
@@ -1129,13 +2394,13 @@ def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = "
1129
2394
  class Server:
1130
2395
  """A simple gRPC server that uses ThreadPoolExecutor for concurrency."""
1131
2396
 
1132
- def __init__(self, max_workers: int = 8, *interceptors) -> None:
1133
- self._server = grpc.server(
2397
+ def __init__(self, max_workers: int = 8, *interceptors: Any) -> None:
2398
+ self._server: grpc.Server = grpc.server(
1134
2399
  futures.ThreadPoolExecutor(max_workers), interceptors=interceptors
1135
2400
  )
1136
- self._service_names = []
1137
- self._package_name = ""
1138
- self._port = 50051
2401
+ self._service_names: list[str] = []
2402
+ self._package_name: str = ""
2403
+ self._port: int = 50051
1139
2404
 
1140
2405
  def set_package_name(self, package_name: str):
1141
2406
  """Set the package name for .proto generation."""
@@ -1147,10 +2412,15 @@ class Server:
1147
2412
 
1148
2413
  def mount(self, obj: object, package_name: str = ""):
1149
2414
  """Generate and compile proto files, then mount the service implementation."""
1150
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
2415
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
2416
+ None,
2417
+ None,
2418
+ )
1151
2419
  self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1152
2420
 
1153
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
2421
+ def mount_using_pb2_modules(
2422
+ self, pb2_grpc_module: Any, pb2_module: Any, obj: object
2423
+ ):
1154
2424
  """Connect the compiled gRPC modules with the service implementation."""
1155
2425
  concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
1156
2426
  service_name = obj.__class__.__name__
@@ -1163,7 +2433,7 @@ class Server:
1163
2433
  ].full_name
1164
2434
  self._service_names.append(full_service_name)
1165
2435
 
1166
- def run(self, *objs):
2436
+ def run(self, *objs: object):
1167
2437
  """
1168
2438
  Mount multiple services and run the gRPC server with reflection and health check.
1169
2439
  Press Ctrl+C or send SIGTERM to stop.
@@ -1183,14 +2453,16 @@ class Server:
1183
2453
  self._server.add_insecure_port(f"[::]:{self._port}")
1184
2454
  self._server.start()
1185
2455
 
1186
- def handle_signal(signal, frame):
2456
+ def handle_signal(signum: signal.Signals, frame: Any):
2457
+ _ = signum
2458
+ _ = frame
1187
2459
  print("Received shutdown signal...")
1188
2460
  self._server.stop(grace=10)
1189
2461
  print("gRPC server shutdown.")
1190
2462
  sys.exit(0)
1191
2463
 
1192
- signal.signal(signal.SIGINT, handle_signal)
1193
- signal.signal(signal.SIGTERM, handle_signal)
2464
+ _ = signal.signal(signal.SIGINT, handle_signal) # pyright:ignore[reportArgumentType]
2465
+ _ = signal.signal(signal.SIGTERM, handle_signal) # pyright:ignore[reportArgumentType]
1194
2466
 
1195
2467
  print("gRPC server is running...")
1196
2468
  while True:
@@ -1200,11 +2472,11 @@ class Server:
1200
2472
  class AsyncIOServer:
1201
2473
  """An async gRPC server using asyncio."""
1202
2474
 
1203
- def __init__(self, *interceptors) -> None:
1204
- self._server = grpc.aio.server(interceptors=interceptors)
1205
- self._service_names = []
1206
- self._package_name = ""
1207
- self._port = 50051
2475
+ def __init__(self, *interceptors: grpc.ServerInterceptor) -> None:
2476
+ self._server: grpc.aio.Server = grpc.aio.server(interceptors=interceptors)
2477
+ self._service_names: list[str] = []
2478
+ self._package_name: str = ""
2479
+ self._port: int = 50051
1208
2480
 
1209
2481
  def set_package_name(self, package_name: str):
1210
2482
  """Set the package name for .proto generation."""
@@ -1216,10 +2488,15 @@ class AsyncIOServer:
1216
2488
 
1217
2489
  def mount(self, obj: object, package_name: str = ""):
1218
2490
  """Generate and compile proto files, then mount the service implementation (async)."""
1219
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
2491
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
2492
+ None,
2493
+ None,
2494
+ )
1220
2495
  self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1221
2496
 
1222
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
2497
+ def mount_using_pb2_modules(
2498
+ self, pb2_grpc_module: Any, pb2_module: Any, obj: object
2499
+ ):
1223
2500
  """Connect the compiled gRPC modules with the async service implementation."""
1224
2501
  concreteServiceClass = connect_obj_with_stub_async(
1225
2502
  pb2_grpc_module, pb2_module, obj
@@ -1234,7 +2511,7 @@ class AsyncIOServer:
1234
2511
  ].full_name
1235
2512
  self._service_names.append(full_service_name)
1236
2513
 
1237
- async def run(self, *objs):
2514
+ async def run(self, *objs: object):
1238
2515
  """
1239
2516
  Mount multiple async services and run the gRPC server with reflection and health check.
1240
2517
  Press Ctrl+C or send SIGTERM to stop.
@@ -1251,20 +2528,22 @@ class AsyncIOServer:
1251
2528
  health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
1252
2529
  reflection.enable_server_reflection(SERVICE_NAMES, self._server)
1253
2530
 
1254
- self._server.add_insecure_port(f"[::]:{self._port}")
2531
+ _ = self._server.add_insecure_port(f"[::]:{self._port}")
1255
2532
  await self._server.start()
1256
2533
 
1257
2534
  shutdown_event = asyncio.Event()
1258
2535
 
1259
- def shutdown(signum, frame):
2536
+ def shutdown(signum: signal.Signals, frame: Any):
2537
+ _ = signum
2538
+ _ = frame
1260
2539
  print("Received shutdown signal...")
1261
2540
  shutdown_event.set()
1262
2541
 
1263
2542
  for s in [signal.SIGTERM, signal.SIGINT]:
1264
- signal.signal(s, shutdown)
2543
+ _ = signal.signal(s, shutdown) # pyright:ignore[reportArgumentType]
1265
2544
 
1266
2545
  print("gRPC server is running...")
1267
- await shutdown_event.wait()
2546
+ _ = await shutdown_event.wait()
1268
2547
  await self._server.stop(10)
1269
2548
  print("gRPC server shutdown.")
1270
2549
 
@@ -1275,17 +2554,22 @@ class WSGIApp:
1275
2554
  Useful for embedding gRPC within an existing WSGI stack.
1276
2555
  """
1277
2556
 
1278
- def __init__(self, app):
1279
- self._app = grpcWSGI(app)
1280
- self._service_names = []
1281
- self._package_name = ""
2557
+ def __init__(self, app: Any):
2558
+ self._app: grpcWSGI = grpcWSGI(app)
2559
+ self._service_names: list[str] = []
2560
+ self._package_name: str = ""
1282
2561
 
1283
2562
  def mount(self, obj: object, package_name: str = ""):
1284
2563
  """Generate and compile proto files, then mount the service implementation."""
1285
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
2564
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
2565
+ None,
2566
+ None,
2567
+ )
1286
2568
  self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1287
2569
 
1288
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
2570
+ def mount_using_pb2_modules(
2571
+ self, pb2_grpc_module: Any, pb2_module: Any, obj: object
2572
+ ):
1289
2573
  """Connect the compiled gRPC modules with the service implementation."""
1290
2574
  concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
1291
2575
  service_name = obj.__class__.__name__
@@ -1298,12 +2582,16 @@ class WSGIApp:
1298
2582
  ].full_name
1299
2583
  self._service_names.append(full_service_name)
1300
2584
 
1301
- def mount_objs(self, *objs):
2585
+ def mount_objs(self, *objs: object):
1302
2586
  """Mount multiple service objects into this WSGI app."""
1303
2587
  for obj in objs:
1304
2588
  self.mount(obj, self._package_name)
1305
2589
 
1306
- def __call__(self, environ, start_response):
2590
+ def __call__(
2591
+ self,
2592
+ environ: dict[str, Any],
2593
+ start_response: Callable[[str, list[tuple[str, str]]], None],
2594
+ ) -> Any:
1307
2595
  """WSGI entry point."""
1308
2596
  return self._app(environ, start_response)
1309
2597
 
@@ -1314,17 +2602,22 @@ class ASGIApp:
1314
2602
  Useful for embedding gRPC within an existing ASGI stack.
1315
2603
  """
1316
2604
 
1317
- def __init__(self, app):
1318
- self._app = grpcASGI(app)
1319
- self._service_names = []
1320
- self._package_name = ""
2605
+ def __init__(self, app: Any):
2606
+ self._app: grpcASGI = grpcASGI(app)
2607
+ self._service_names: list[str] = []
2608
+ self._package_name: str = ""
1321
2609
 
1322
2610
  def mount(self, obj: object, package_name: str = ""):
1323
2611
  """Generate and compile proto files, then mount the async service implementation."""
1324
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
2612
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
2613
+ None,
2614
+ None,
2615
+ )
1325
2616
  self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1326
2617
 
1327
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
2618
+ def mount_using_pb2_modules(
2619
+ self, pb2_grpc_module: Any, pb2_module: Any, obj: object
2620
+ ):
1328
2621
  """Connect the compiled gRPC modules with the async service implementation."""
1329
2622
  concreteServiceClass = connect_obj_with_stub_async(
1330
2623
  pb2_grpc_module, pb2_module, obj
@@ -1339,18 +2632,27 @@ class ASGIApp:
1339
2632
  ].full_name
1340
2633
  self._service_names.append(full_service_name)
1341
2634
 
1342
- def mount_objs(self, *objs):
2635
+ def mount_objs(self, *objs: object):
1343
2636
  """Mount multiple service objects into this ASGI app."""
1344
2637
  for obj in objs:
1345
2638
  self.mount(obj, self._package_name)
1346
2639
 
1347
- async def __call__(self, scope, receive, send):
2640
+ async def __call__(
2641
+ self,
2642
+ scope: dict[str, Any],
2643
+ receive: Callable[[], Any],
2644
+ send: Callable[[dict[str, Any]], Any],
2645
+ ) -> Any:
1348
2646
  """ASGI entry point."""
1349
- await self._app(scope, receive, send)
2647
+ _ = await self._app(scope, receive, send)
1350
2648
 
1351
2649
 
1352
- def get_connecpy_server_class(connecpy_module, service_name):
1353
- return getattr(connecpy_module, f"{service_name}Server")
2650
+ def get_connecpy_asgi_app_class(connecpy_module: Any, service_name: str):
2651
+ return getattr(connecpy_module, f"{service_name}ASGIApplication")
2652
+
2653
+
2654
+ def get_connecpy_wsgi_app_class(connecpy_module: Any, service_name: str):
2655
+ return getattr(connecpy_module, f"{service_name}WSGIApplication")
1354
2656
 
1355
2657
 
1356
2658
  class ConnecpyASGIApp:
@@ -1359,9 +2661,9 @@ class ConnecpyASGIApp:
1359
2661
  """
1360
2662
 
1361
2663
  def __init__(self):
1362
- self._app = ConnecpyASGI()
1363
- self._service_names = []
1364
- self._package_name = ""
2664
+ self._services: list[tuple[Any, str]] = [] # List of (app, path) tuples
2665
+ self._service_names: list[str] = []
2666
+ self._package_name: str = ""
1365
2667
 
1366
2668
  def mount(self, obj: object, package_name: str = ""):
1367
2669
  """Generate and compile proto files, then mount the async service implementation."""
@@ -1370,28 +2672,59 @@ class ConnecpyASGIApp:
1370
2672
  )
1371
2673
  self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
1372
2674
 
1373
- def mount_using_pb2_modules(self, connecpy_module, pb2_module, obj: object):
2675
+ def mount_using_pb2_modules(
2676
+ self, connecpy_module: Any, pb2_module: Any, obj: object
2677
+ ):
1374
2678
  """Connect the compiled connecpy and pb2 modules with the async service implementation."""
1375
2679
  concreteServiceClass = connect_obj_with_stub_async_connecpy(
1376
2680
  connecpy_module, pb2_module, obj
1377
2681
  )
1378
2682
  service_name = obj.__class__.__name__
1379
2683
  service_impl = concreteServiceClass()
1380
- connecpy_server = get_connecpy_server_class(connecpy_module, service_name)
1381
- self._app.add_service(connecpy_server(service=service_impl))
2684
+
2685
+ # Get the service-specific ASGI application class
2686
+ app_class = get_connecpy_asgi_app_class(connecpy_module, service_name)
2687
+ app = app_class(service=service_impl)
2688
+
2689
+ # Store the app and its path for routing
2690
+ self._services.append((app, app.path))
2691
+
1382
2692
  full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1383
2693
  service_name
1384
2694
  ].full_name
1385
2695
  self._service_names.append(full_service_name)
1386
2696
 
1387
- def mount_objs(self, *objs):
2697
+ def mount_objs(self, *objs: object):
1388
2698
  """Mount multiple service objects into this ASGI app."""
1389
2699
  for obj in objs:
1390
2700
  self.mount(obj, self._package_name)
1391
2701
 
1392
- async def __call__(self, scope, receive, send):
1393
- """ASGI entry point."""
1394
- await self._app(scope, receive, send)
2702
+ async def __call__(
2703
+ self,
2704
+ scope: dict[str, Any],
2705
+ receive: Callable[[], Any],
2706
+ send: Callable[[dict[str, Any]], Any],
2707
+ ):
2708
+ """ASGI entry point with routing for multiple services."""
2709
+ if scope["type"] != "http":
2710
+ await send({"type": "http.response.start", "status": 404})
2711
+ await send({"type": "http.response.body", "body": b"Not Found"})
2712
+ return
2713
+
2714
+ path = scope.get("path", "")
2715
+
2716
+ # Route to the appropriate service based on path
2717
+ for app, service_path in self._services:
2718
+ if path.startswith(service_path):
2719
+ return await app(scope, receive, send)
2720
+
2721
+ # If only one service is mounted, use it as default
2722
+ if len(self._services) == 1:
2723
+ return await self._services[0][0](scope, receive, send)
2724
+
2725
+ # No matching service found
2726
+ await send({"type": "http.response.start", "status": 404})
2727
+ await send({"type": "http.response.body", "body": b"Not Found"})
1395
2728
 
1396
2729
 
1397
2730
  class ConnecpyWSGIApp:
@@ -1400,55 +2733,82 @@ class ConnecpyWSGIApp:
1400
2733
  """
1401
2734
 
1402
2735
  def __init__(self):
1403
- self._app = ConnecpyWSGI()
1404
- self._service_names = []
1405
- self._package_name = ""
2736
+ self._services: list[tuple[Any, str]] = [] # List of (app, path) tuples
2737
+ self._service_names: list[str] = []
2738
+ self._package_name: str = ""
1406
2739
 
1407
2740
  def mount(self, obj: object, package_name: str = ""):
1408
- """Generate and compile proto files, then mount the async service implementation."""
2741
+ """Generate and compile proto files, then mount the sync service implementation."""
1409
2742
  connecpy_module, pb2_module = generate_and_compile_proto_using_connecpy(
1410
2743
  obj, package_name
1411
2744
  )
1412
2745
  self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
1413
2746
 
1414
- def mount_using_pb2_modules(self, connecpy_module, pb2_module, obj: object):
1415
- """Connect the compiled connecpy and pb2 modules with the async service implementation."""
2747
+ def mount_using_pb2_modules(
2748
+ self, connecpy_module: Any, pb2_module: Any, obj: object
2749
+ ):
2750
+ """Connect the compiled connecpy and pb2 modules with the sync service implementation."""
1416
2751
  concreteServiceClass = connect_obj_with_stub_connecpy(
1417
2752
  connecpy_module, pb2_module, obj
1418
2753
  )
1419
2754
  service_name = obj.__class__.__name__
1420
2755
  service_impl = concreteServiceClass()
1421
- connecpy_server = get_connecpy_server_class(connecpy_module, service_name)
1422
- self._app.add_service(connecpy_server(service=service_impl))
2756
+
2757
+ # Get the service-specific WSGI application class
2758
+ app_class = get_connecpy_wsgi_app_class(connecpy_module, service_name)
2759
+ app = app_class(service=service_impl)
2760
+
2761
+ # Store the app and its path for routing
2762
+ self._services.append((app, app.path))
2763
+
1423
2764
  full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1424
2765
  service_name
1425
2766
  ].full_name
1426
2767
  self._service_names.append(full_service_name)
1427
2768
 
1428
- def mount_objs(self, *objs):
2769
+ def mount_objs(self, *objs: object):
1429
2770
  """Mount multiple service objects into this WSGI app."""
1430
2771
  for obj in objs:
1431
2772
  self.mount(obj, self._package_name)
1432
2773
 
1433
- def __call__(self, environ, start_response):
1434
- """WSGI entry point."""
1435
- return self._app(environ, start_response)
2774
+ def __call__(
2775
+ self,
2776
+ environ: dict[str, Any],
2777
+ start_response: Callable[[str, list[tuple[str, str]]], None],
2778
+ ) -> Any:
2779
+ """WSGI entry point with routing for multiple services."""
2780
+ path = environ.get("PATH_INFO", "")
2781
+
2782
+ # Route to the appropriate service based on path
2783
+ for app, service_path in self._services:
2784
+ if path.startswith(service_path):
2785
+ return app(environ, start_response)
2786
+
2787
+ # If only one service is mounted, use it as default
2788
+ if len(self._services) == 1:
2789
+ return self._services[0][0](environ, start_response)
2790
+
2791
+ # No matching service found
2792
+ start_response("404 Not Found", [("Content-Type", "text/plain")])
2793
+ return [b"Not Found"]
1436
2794
 
1437
2795
 
1438
2796
  def main():
1439
2797
  import argparse
1440
2798
 
1441
2799
  parser = argparse.ArgumentParser(description="Generate and compile proto files.")
1442
- parser.add_argument(
2800
+ _ = parser.add_argument(
1443
2801
  "py_file", type=str, help="The Python file containing the service class."
1444
2802
  )
1445
- parser.add_argument("class_name", type=str, help="The name of the service class.")
2803
+ _ = parser.add_argument(
2804
+ "class_name", type=str, help="The name of the service class."
2805
+ )
1446
2806
  args = parser.parse_args()
1447
2807
 
1448
2808
  module_name = os.path.splitext(basename(args.py_file))[0]
1449
2809
  module = importlib.import_module(module_name)
1450
2810
  klass = getattr(module, args.class_name)
1451
- generate_and_compile_proto(klass())
2811
+ _ = generate_and_compile_proto(klass())
1452
2812
 
1453
2813
 
1454
2814
  if __name__ == "__main__":