pydantic-rpc 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydantic_rpc/core.py CHANGED
@@ -1,5 +1,5 @@
1
- import annotated_types
2
1
  import asyncio
2
+ import datetime
3
3
  import enum
4
4
  import importlib.util
5
5
  import inspect
@@ -8,33 +8,37 @@ import signal
8
8
  import sys
9
9
  import time
10
10
  import types
11
- import datetime
11
+ from typing import Union
12
+ from collections.abc import AsyncIterator, Awaitable, Callable
12
13
  from concurrent import futures
14
+ from pathlib import Path
13
15
  from posixpath import basename
14
16
  from typing import (
15
- Callable,
16
- Type,
17
+ Any,
18
+ TypeAlias,
17
19
  get_args,
18
20
  get_origin,
19
- Union,
20
- TypeAlias,
21
+ cast,
22
+ TypeGuard,
21
23
  )
22
- from collections.abc import AsyncIterator
23
24
 
25
+ import annotated_types
24
26
  import grpc
27
+ from grpc import ServicerContext
28
+ import grpc_tools
29
+ from connecpy.server import ConnecpyASGIApplication as ConnecpyASGI
30
+ from connecpy.server import ConnecpyWSGIApplication as ConnecpyWSGI
31
+ from connecpy.code import Code as Errors
32
+
33
+ # Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
34
+ from google.protobuf import duration_pb2, timestamp_pb2, empty_pb2
25
35
  from grpc_health.v1 import health_pb2, health_pb2_grpc
26
36
  from grpc_health.v1.health import HealthServicer
27
37
  from grpc_reflection.v1alpha import reflection
28
38
  from grpc_tools import protoc
29
39
  from pydantic import BaseModel, ValidationError
30
- from sonora.wsgi import grpcWSGI
31
40
  from sonora.asgi import grpcASGI
32
- from connecpy.asgi import ConnecpyASGIApp as ConnecpyASGI
33
- from connecpy.errors import Errors
34
- from connecpy.wsgi import ConnecpyWSGIApp as ConnecpyWSGI
35
-
36
- # Protobuf Python modules for Timestamp, Duration (requires protobuf / grpcio)
37
- from google.protobuf import timestamp_pb2, duration_pb2
41
+ from sonora.wsgi import grpcWSGI
38
42
 
39
43
  ###############################################################################
40
44
  # 1. Message definitions & converter extensions
@@ -46,7 +50,12 @@ from google.protobuf import timestamp_pb2, duration_pb2
46
50
  Message: TypeAlias = BaseModel
47
51
 
48
52
 
49
- def primitiveProtoValueToPythonValue(value):
53
+ def is_none_type(annotation: Any) -> TypeGuard[type[None] | None]:
54
+ """Check if annotation represents None/NoneType (handles both None and type(None))."""
55
+ return annotation is None or annotation is type(None)
56
+
57
+
58
+ def primitiveProtoValueToPythonValue(value: Any):
50
59
  # Returns the value as-is (primitive type).
51
60
  return value
52
61
 
@@ -75,11 +84,20 @@ def python_to_duration(td: datetime.timedelta) -> duration_pb2.Duration: # type
75
84
  return d
76
85
 
77
86
 
78
- def generate_converter(annotation: Type) -> Callable:
87
+ def generate_converter(annotation: type[Any] | None) -> Callable[[Any], Any]:
79
88
  """
80
89
  Returns a converter function to convert protobuf types to Python types.
81
90
  This is used primarily when handling incoming requests.
82
91
  """
92
+ # For NoneType (Empty messages)
93
+ if is_none_type(annotation):
94
+
95
+ def empty_converter(value: empty_pb2.Empty): # type: ignore
96
+ _ = value
97
+ return None
98
+
99
+ return empty_converter
100
+
83
101
  # For primitive types
84
102
  if annotation in (int, str, bool, bytes, float):
85
103
  return primitiveProtoValueToPythonValue
@@ -87,7 +105,7 @@ def generate_converter(annotation: Type) -> Callable:
87
105
  # For enum types
88
106
  if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
89
107
 
90
- def enum_converter(value):
108
+ def enum_converter(value: enum.Enum):
91
109
  return annotation(value)
92
110
 
93
111
  return enum_converter
@@ -114,7 +132,7 @@ def generate_converter(annotation: Type) -> Callable:
114
132
  if origin in (list, tuple):
115
133
  item_converter = generate_converter(get_args(annotation)[0])
116
134
 
117
- def seq_converter(value):
135
+ def seq_converter(value: list[Any] | tuple[Any, ...]):
118
136
  return [item_converter(v) for v in value]
119
137
 
120
138
  return seq_converter
@@ -124,7 +142,7 @@ def generate_converter(annotation: Type) -> Callable:
124
142
  key_converter = generate_converter(get_args(annotation)[0])
125
143
  value_converter = generate_converter(get_args(annotation)[1])
126
144
 
127
- def dict_converter(value):
145
+ def dict_converter(value: dict[Any, Any]):
128
146
  return {key_converter(k): value_converter(v) for k, v in value.items()}
129
147
 
130
148
  return dict_converter
@@ -137,27 +155,95 @@ def generate_converter(annotation: Type) -> Callable:
137
155
  return primitiveProtoValueToPythonValue
138
156
 
139
157
 
140
- def generate_message_converter(arg_type: Type[Message]) -> Callable:
158
+ def generate_message_converter(
159
+ arg_type: type[Message] | type[None] | None,
160
+ ) -> Callable[[Any], Message | None]:
141
161
  """Return a converter function for protobuf -> Python Message."""
142
- if arg_type is None or not issubclass(arg_type, Message):
143
- raise TypeError("Request arg must be subclass of Message")
144
162
 
163
+ # Handle NoneType (Empty messages)
164
+ if is_none_type(arg_type):
165
+
166
+ def empty_converter(request: Any) -> None:
167
+ _ = request
168
+ return None
169
+
170
+ return empty_converter
171
+
172
+ arg_type = cast("type[Message]", arg_type)
145
173
  fields = arg_type.model_fields
146
174
  converters = {
147
175
  field: generate_converter(field_type.annotation) # type: ignore
148
176
  for field, field_type in fields.items()
149
177
  }
150
178
 
151
- def converter(request):
179
+ def converter(request: Any) -> Message:
152
180
  rdict = {}
153
- for field in fields.keys():
154
- rdict[field] = converters[field](getattr(request, field))
181
+ for field_name, field_info in fields.items():
182
+ field_type = field_info.annotation
183
+
184
+ # Check if this is a union type
185
+ if field_type is not None and is_union_type(field_type):
186
+ union_args = flatten_union(field_type)
187
+ has_none = type(None) in union_args
188
+ non_none_args = [arg for arg in union_args if arg is not type(None)]
189
+
190
+ if has_none and len(non_none_args) == 1:
191
+ # This is Optional[T] - check if protobuf field is set
192
+ try:
193
+ if hasattr(request, "HasField") and request.HasField(
194
+ field_name
195
+ ):
196
+ rdict[field_name] = converters[field_name](
197
+ getattr(request, field_name)
198
+ )
199
+ else:
200
+ # Field not set in protobuf, set to None for Optional fields
201
+ rdict[field_name] = None
202
+ except ValueError:
203
+ # HasField doesn't work for this field type (e.g., repeated fields)
204
+ # Fall back to regular conversion
205
+ rdict[field_name] = converters[field_name](
206
+ getattr(request, field_name)
207
+ )
208
+ elif len(non_none_args) > 1:
209
+ # This is a oneof field (Union[str, int] etc.)
210
+ # Check which oneof field is set
211
+ try:
212
+ which_field = request.WhichOneof(field_name)
213
+ if which_field:
214
+ # Extract the value from the set oneof field
215
+ proto_value = getattr(request, which_field)
216
+
217
+ # Determine the Python type from the oneof field name
218
+ # e.g., "value_string" -> str, "value_int32" -> int
219
+ for union_arg in non_none_args:
220
+ proto_typename = protobuf_type_mapping(union_arg)
221
+ if (
222
+ proto_typename
223
+ and which_field
224
+ == f"{field_name}_{proto_typename.replace('.', '_')}"
225
+ ):
226
+ # Convert using the specific type converter
227
+ type_converter = generate_converter(union_arg)
228
+ rdict[field_name] = type_converter(proto_value)
229
+ break
230
+ except (AttributeError, ValueError):
231
+ # WhichOneof failed, try fallback
232
+ # This shouldn't happen for properly generated oneof fields
233
+ pass
234
+ else:
235
+ # Union with only None type (shouldn't happen)
236
+ pass
237
+ else:
238
+ # For non-union fields, convert normally
239
+ rdict[field_name] = converters[field_name](getattr(request, field_name))
240
+
155
241
  return arg_type(**rdict)
156
242
 
157
243
  return converter
158
244
 
159
245
 
160
- def python_value_to_proto_value(field_type: Type, value):
246
+ def python_value_to_proto_value(field_type: type[Any], value: Any) -> Any:
161
247
  """
162
248
  Converts Python values to protobuf values.
163
249
  Used primarily when constructing a response object.
@@ -179,77 +265,112 @@ def python_value_to_proto_value(field_type: Type, value):
179
265
  ###############################################################################
180
266
 
181
267
 
182
- def connect_obj_with_stub(pb2_grpc_module, pb2_module, service_obj: object) -> type:
268
+ def connect_obj_with_stub(
269
+ pb2_grpc_module: Any, pb2_module: Any, service_obj: object
270
+ ) -> type:
183
271
  """
184
272
  Connect a Python service object to a gRPC stub, generating server methods.
273
+ Returns a subclass of the generated Servicer stub with concrete implementations.
185
274
  """
186
275
  service_class = service_obj.__class__
187
276
  stub_class_name = service_class.__name__ + "Servicer"
188
277
  stub_class = getattr(pb2_grpc_module, stub_class_name)
189
278
 
190
279
  class ConcreteServiceClass(stub_class):
280
+ """Dynamically generated servicer class with stub methods implemented."""
281
+
191
282
  pass
192
283
 
193
- def implement_stub_method(method):
194
- # Analyze method signature.
284
+ def implement_stub_method(
285
+ method: Callable[..., Message],
286
+ ) -> Callable[[object, Any, Any], Any]:
287
+ """
288
+ Wraps a user-defined method (self, *args) -> R into a gRPC stub signature:
289
+ (self, request_proto, context) -> response_proto
290
+ """
195
291
  sig = inspect.signature(method)
196
292
  arg_type = get_request_arg_type(sig)
197
- # Convert request from protobuf to Python.
198
- converter = generate_message_converter(arg_type)
199
-
200
293
  response_type = sig.return_annotation
201
- size_of_parameters = len(sig.parameters)
202
-
203
- match size_of_parameters:
204
- case 1:
294
+ param_count = len(sig.parameters)
295
+ converter = generate_message_converter(arg_type)
205
296
 
206
- def stub_method1(self, request, context, method=method):
207
- try:
208
- # Convert request to Python object
297
+ if param_count == 1:
298
+
299
+ def stub_method(
300
+ self: object,
301
+ request: Any,
302
+ context: Any,
303
+ *,
304
+ original: Callable[..., Message] = method,
305
+ ) -> Any:
306
+ _ = self
307
+ try:
308
+ if is_none_type(arg_type):
309
+ resp_obj = original(None) # Fixed: pass None instead of no args
310
+ else:
209
311
  arg = converter(request)
210
- # Invoke the actual method
211
- resp_obj = method(arg)
212
- # Convert the returned Python Message to a protobuf message
312
+ resp_obj = original(arg)
313
+
314
+ if is_none_type(response_type):
315
+ return empty_pb2.Empty() # type: ignore
316
+ else:
213
317
  return convert_python_message_to_proto(
214
318
  resp_obj, response_type, pb2_module
215
319
  )
216
- except ValidationError as e:
217
- return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
218
- except Exception as e:
219
- return context.abort(grpc.StatusCode.INTERNAL, str(e))
220
-
221
- return stub_method1
222
-
223
- case 2:
224
-
225
- def stub_method2(self, request, context, method=method):
226
- try:
320
+ except ValidationError as e:
321
+ return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
322
+ except Exception as e:
323
+ return context.abort(grpc.StatusCode.INTERNAL, str(e))
324
+
325
+ elif param_count == 2:
326
+
327
+ def stub_method(
328
+ self: object,
329
+ request: Any,
330
+ context: Any,
331
+ *,
332
+ original: Callable[..., Message] = method,
333
+ ) -> Any:
334
+ _ = self
335
+ try:
336
+ if is_none_type(arg_type):
337
+ resp_obj = original(
338
+ None, context
339
+ ) # Fixed: pass None instead of Empty
340
+ else:
227
341
  arg = converter(request)
228
- resp_obj = method(arg, context)
342
+ resp_obj = original(arg, context)
343
+
344
+ if is_none_type(response_type):
345
+ return empty_pb2.Empty() # type: ignore
346
+ else:
229
347
  return convert_python_message_to_proto(
230
348
  resp_obj, response_type, pb2_module
231
349
  )
232
- except ValidationError as e:
233
- return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
234
- except Exception as e:
235
- return context.abort(grpc.StatusCode.INTERNAL, str(e))
350
+ except ValidationError as e:
351
+ return context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
352
+ except Exception as e:
353
+ return context.abort(grpc.StatusCode.INTERNAL, str(e))
236
354
 
237
- return stub_method2
355
+ else:
356
+ raise TypeError(
357
+ f"Method '{method.__name__}' must have exactly 1 or 2 parameters, got {param_count}"
358
+ )
238
359
 
239
- case _:
240
- raise Exception("Method must have exactly one or two parameters")
360
+ return stub_method
241
361
 
362
+ # Attach all RPC methods from service_obj to the concrete servicer
242
363
  for method_name, method in get_rpc_methods(service_obj):
243
- if method.__name__.startswith("_"):
364
+ if method_name.startswith("_"):
244
365
  continue
245
-
246
- a_method = implement_stub_method(method)
247
- setattr(ConcreteServiceClass, method_name, a_method)
366
+ setattr(ConcreteServiceClass, method_name, implement_stub_method(method))
248
367
 
249
368
  return ConcreteServiceClass
250
369
 
251
370
 
252
- def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> type:
371
+ def connect_obj_with_stub_async(
372
+ pb2_grpc_module: Any, pb2_module: Any, obj: object
373
+ ) -> type:
253
374
  """
254
375
  Connect a Python service object to a gRPC stub for async methods.
255
376
  """
@@ -260,26 +381,151 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
260
381
  class ConcreteServiceClass(stub_class):
261
382
  pass
262
383
 
263
- def implement_stub_method(method):
384
+ def implement_stub_method(
385
+ method: Callable[..., Any],
386
+ ) -> Callable[[object, Any, Any], Any]:
264
387
  sig = inspect.signature(method)
265
- arg_type = get_request_arg_type(sig)
266
- converter = generate_message_converter(arg_type)
388
+ input_type = get_request_arg_type(sig)
389
+ is_input_stream = is_stream_type(input_type)
267
390
  response_type = sig.return_annotation
391
+ is_output_stream = is_stream_type(response_type)
268
392
  size_of_parameters = len(sig.parameters)
269
393
 
270
- if is_stream_type(response_type):
271
- item_type = get_args(response_type)[0]
272
- match size_of_parameters:
273
- case 1:
394
+ if size_of_parameters not in (1, 2):
395
+ raise TypeError(
396
+ f"Method '{method.__name__}' must have 1 or 2 parameters, got {size_of_parameters}"
397
+ )
274
398
 
275
- async def stub_method_stream1(
276
- self, request, context, method=method
277
- ):
399
+ if is_input_stream:
400
+ input_item_type = get_args(input_type)[0]
401
+ item_converter = generate_message_converter(input_item_type)
402
+
403
+ async def convert_iterator(
404
+ proto_iter: AsyncIterator[Any],
405
+ ) -> AsyncIterator[Message]:
406
+ async for proto in proto_iter:
407
+ result = item_converter(proto)
408
+ if result is None:
409
+ raise TypeError(
410
+ f"Unexpected None result from converter for type {input_item_type}"
411
+ )
412
+ yield result
413
+
414
+ if is_output_stream:
415
+ # stream-stream
416
+ output_item_type = get_args(response_type)[0]
417
+
418
+ if size_of_parameters == 1:
419
+
420
+ async def stub_method(
421
+ self: object,
422
+ request_iterator: AsyncIterator[Any],
423
+ context: Any,
424
+ ) -> AsyncIterator[Any]:
425
+ _ = self
426
+ try:
427
+ arg_iter = convert_iterator(request_iterator)
428
+ async for resp_obj in method(arg_iter):
429
+ yield convert_python_message_to_proto(
430
+ resp_obj, output_item_type, pb2_module
431
+ )
432
+ except ValidationError as e:
433
+ await context.abort(
434
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
435
+ )
436
+ except Exception as e:
437
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
438
+
439
+ else: # size_of_parameters == 2
440
+
441
+ async def stub_method(
442
+ self: object,
443
+ request_iterator: AsyncIterator[Any],
444
+ context: Any,
445
+ ) -> AsyncIterator[Any]:
446
+ _ = self
447
+ try:
448
+ arg_iter = convert_iterator(request_iterator)
449
+ async for resp_obj in method(arg_iter, context):
450
+ yield convert_python_message_to_proto(
451
+ resp_obj, output_item_type, pb2_module
452
+ )
453
+ except ValidationError as e:
454
+ await context.abort(
455
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
456
+ )
457
+ except Exception as e:
458
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
459
+
460
+ return stub_method
461
+
462
+ else:
463
+ # stream-unary
464
+ if size_of_parameters == 1:
465
+
466
+ async def stub_method(
467
+ self: object,
468
+ request_iterator: AsyncIterator[Any],
469
+ context: Any,
470
+ ) -> Any:
471
+ _ = self
472
+ try:
473
+ arg_iter = convert_iterator(request_iterator)
474
+ resp_obj = await method(arg_iter)
475
+ return convert_python_message_to_proto(
476
+ resp_obj, response_type, pb2_module
477
+ )
478
+ except ValidationError as e:
479
+ await context.abort(
480
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
481
+ )
482
+ except Exception as e:
483
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
484
+
485
+ else: # size_of_parameters == 2
486
+
487
+ async def stub_method(
488
+ self: object,
489
+ request_iterator: AsyncIterator[Any],
490
+ context: Any,
491
+ ) -> Any:
492
+ _ = self
493
+ try:
494
+ arg_iter = convert_iterator(request_iterator)
495
+ resp_obj = await method(arg_iter, context)
496
+ return convert_python_message_to_proto(
497
+ resp_obj, response_type, pb2_module
498
+ )
499
+ except ValidationError as e:
500
+ await context.abort(
501
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
502
+ )
503
+ except Exception as e:
504
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
505
+
506
+ return stub_method
507
+
508
+ else:
509
+ # unary input
510
+ converter = generate_message_converter(input_type)
511
+
512
+ if is_output_stream:
513
+ # unary-stream
514
+ output_item_type = get_args(response_type)[0]
515
+
516
+ if size_of_parameters == 1:
517
+
518
+ async def stub_method(
519
+ self: object,
520
+ request: Any,
521
+ context: Any,
522
+ ) -> AsyncIterator[Any]:
523
+ _ = self
278
524
  try:
279
525
  arg = converter(request)
280
526
  async for resp_obj in method(arg):
281
527
  yield convert_python_message_to_proto(
282
- resp_obj, item_type, pb2_module
528
+ resp_obj, output_item_type, pb2_module
283
529
  )
284
530
  except ValidationError as e:
285
531
  await context.abort(
@@ -288,17 +534,19 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
288
534
  except Exception as e:
289
535
  await context.abort(grpc.StatusCode.INTERNAL, str(e))
290
536
 
291
- return stub_method_stream1
292
- case 2:
537
+ else: # size_of_parameters == 2
293
538
 
294
- async def stub_method_stream2(
295
- self, request, context, method=method
296
- ):
539
+ async def stub_method(
540
+ self: object,
541
+ request: Any,
542
+ context: Any,
543
+ ) -> AsyncIterator[Any]:
544
+ _ = self
297
545
  try:
298
546
  arg = converter(request)
299
547
  async for resp_obj in method(arg, context):
300
548
  yield convert_python_message_to_proto(
301
- resp_obj, item_type, pb2_module
549
+ resp_obj, output_item_type, pb2_module
302
550
  )
303
551
  except ValidationError as e:
304
552
  await context.abort(
@@ -307,45 +555,67 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
307
555
  except Exception as e:
308
556
  await context.abort(grpc.StatusCode.INTERNAL, str(e))
309
557
 
310
- return stub_method_stream2
311
- case _:
312
- raise Exception("Method must have exactly one or two parameters")
313
-
314
- match size_of_parameters:
315
- case 1:
316
-
317
- async def stub_method1(self, request, context, method=method):
318
- try:
319
- arg = converter(request)
320
- resp_obj = await method(arg)
321
- return convert_python_message_to_proto(
322
- resp_obj, response_type, pb2_module
323
- )
324
- except ValidationError as e:
325
- await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
326
- except Exception as e:
327
- await context.abort(grpc.StatusCode.INTERNAL, str(e))
558
+ return stub_method
328
559
 
329
- return stub_method1
560
+ else:
561
+ # unary-unary
562
+ if size_of_parameters == 1:
330
563
 
331
- case 2:
564
+ async def stub_method(
565
+ self: object,
566
+ request: Any,
567
+ context: Any,
568
+ ) -> Any:
569
+ _ = self
570
+ try:
571
+ if is_none_type(input_type):
572
+ resp_obj = await method(None)
573
+ else:
574
+ arg = converter(request)
575
+ resp_obj = await method(arg)
576
+
577
+ if is_none_type(response_type):
578
+ return empty_pb2.Empty() # type: ignore
579
+ else:
580
+ return convert_python_message_to_proto(
581
+ resp_obj, response_type, pb2_module
582
+ )
583
+ except ValidationError as e:
584
+ await context.abort(
585
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
586
+ )
587
+ except Exception as e:
588
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
332
589
 
333
- async def stub_method2(self, request, context, method=method):
334
- try:
335
- arg = converter(request)
336
- resp_obj = await method(arg, context)
337
- return convert_python_message_to_proto(
338
- resp_obj, response_type, pb2_module
339
- )
340
- except ValidationError as e:
341
- await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
342
- except Exception as e:
343
- await context.abort(grpc.StatusCode.INTERNAL, str(e))
590
+ else: # size_of_parameters == 2
344
591
 
345
- return stub_method2
592
+ async def stub_method(
593
+ self: object,
594
+ request: Any,
595
+ context: Any,
596
+ ) -> Any:
597
+ _ = self
598
+ try:
599
+ if is_none_type(input_type):
600
+ resp_obj = await method(None, context)
601
+ else:
602
+ arg = converter(request)
603
+ resp_obj = await method(arg, context)
604
+
605
+ if is_none_type(response_type):
606
+ return empty_pb2.Empty() # type: ignore
607
+ else:
608
+ return convert_python_message_to_proto(
609
+ resp_obj, response_type, pb2_module
610
+ )
611
+ except ValidationError as e:
612
+ await context.abort(
613
+ grpc.StatusCode.INVALID_ARGUMENT, str(e)
614
+ )
615
+ except Exception as e:
616
+ await context.abort(grpc.StatusCode.INTERNAL, str(e))
346
617
 
347
- case _:
348
- raise Exception("Method must have exactly one or two parameters")
618
+ return stub_method
349
619
 
350
620
  for method_name, method in get_rpc_methods(obj):
351
621
  if method.__name__.startswith("_"):
@@ -357,7 +627,9 @@ def connect_obj_with_stub_async(pb2_grpc_module, pb2_module, obj: object) -> typ
357
627
  return ConcreteServiceClass
358
628
 
359
629
 
360
- def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) -> type:
630
+ def connect_obj_with_stub_connecpy(
631
+ connecpy_module: Any, pb2_module: Any, obj: object
632
+ ) -> type:
361
633
  """
362
634
  Connect a Python service object to a Connecpy stub.
363
635
  """
@@ -368,7 +640,9 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
368
640
  class ConcreteServiceClass(stub_class):
369
641
  pass
370
642
 
371
- def implement_stub_method(method):
643
+ def implement_stub_method(
644
+ method: Callable[..., Message],
645
+ ) -> Callable[[object, Any, Any], Any]:
372
646
  sig = inspect.signature(method)
373
647
  arg_type = get_request_arg_type(sig)
374
648
  converter = generate_message_converter(arg_type)
@@ -378,33 +652,59 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
378
652
  match size_of_parameters:
379
653
  case 1:
380
654
 
381
- def stub_method1(self, request, context, method=method):
655
+ def stub_method1(
656
+ self: object,
657
+ request: Any,
658
+ context: Any,
659
+ method: Callable[..., Message] = method,
660
+ ) -> Any:
661
+ _ = self
382
662
  try:
383
- arg = converter(request)
384
- resp_obj = method(arg)
385
- return convert_python_message_to_proto(
386
- resp_obj, response_type, pb2_module
387
- )
663
+ if is_none_type(arg_type):
664
+ resp_obj = method(None)
665
+ else:
666
+ arg = converter(request)
667
+ resp_obj = method(arg)
668
+
669
+ if is_none_type(response_type):
670
+ return empty_pb2.Empty() # type: ignore
671
+ else:
672
+ return convert_python_message_to_proto(
673
+ resp_obj, response_type, pb2_module
674
+ )
388
675
  except ValidationError as e:
389
- return context.abort(Errors.InvalidArgument, str(e))
676
+ return context.abort(Errors.INVALID_ARGUMENT, str(e))
390
677
  except Exception as e:
391
- return context.abort(Errors.Internal, str(e))
678
+ return context.abort(Errors.INTERNAL, str(e))
392
679
 
393
680
  return stub_method1
394
681
 
395
682
  case 2:
396
683
 
397
- def stub_method2(self, request, context, method=method):
684
+ def stub_method2(
685
+ self: object,
686
+ request: Any,
687
+ context: Any,
688
+ method: Callable[..., Message] = method,
689
+ ) -> Any:
690
+ _ = self
398
691
  try:
399
- arg = converter(request)
400
- resp_obj = method(arg, context)
401
- return convert_python_message_to_proto(
402
- resp_obj, response_type, pb2_module
403
- )
692
+ if is_none_type(arg_type):
693
+ resp_obj = method(None, context)
694
+ else:
695
+ arg = converter(request)
696
+ resp_obj = method(arg, context)
697
+
698
+ if is_none_type(response_type):
699
+ return empty_pb2.Empty() # type: ignore
700
+ else:
701
+ return convert_python_message_to_proto(
702
+ resp_obj, response_type, pb2_module
703
+ )
404
704
  except ValidationError as e:
405
- return context.abort(Errors.InvalidArgument, str(e))
705
+ return context.abort(Errors.INVALID_ARGUMENT, str(e))
406
706
  except Exception as e:
407
- return context.abort(Errors.Internal, str(e))
707
+ return context.abort(Errors.INTERNAL, str(e))
408
708
 
409
709
  return stub_method2
410
710
 
@@ -421,7 +721,7 @@ def connect_obj_with_stub_connecpy(connecpy_module, pb2_module, obj: object) ->
421
721
 
422
722
 
423
723
  def connect_obj_with_stub_async_connecpy(
424
- connecpy_module, pb2_module, obj: object
724
+ connecpy_module: Any, pb2_module: Any, obj: object
425
725
  ) -> type:
426
726
  """
427
727
  Connect a Python service object to a Connecpy stub for async methods.
@@ -433,7 +733,9 @@ def connect_obj_with_stub_async_connecpy(
433
733
  class ConcreteServiceClass(stub_class):
434
734
  pass
435
735
 
436
- def implement_stub_method(method):
736
+ def implement_stub_method(
737
+ method: Callable[..., Awaitable[Message]],
738
+ ) -> Callable[[object, Any, Any], Any]:
437
739
  sig = inspect.signature(method)
438
740
  arg_type = get_request_arg_type(sig)
439
741
  converter = generate_message_converter(arg_type)
@@ -443,33 +745,59 @@ def connect_obj_with_stub_async_connecpy(
443
745
  match size_of_parameters:
444
746
  case 1:
445
747
 
446
- async def stub_method1(self, request, context, method=method):
748
+ async def stub_method1(
749
+ self: object,
750
+ request: Any,
751
+ context: Any,
752
+ method: Callable[..., Awaitable[Message]] = method,
753
+ ) -> Any:
754
+ _ = self
447
755
  try:
448
- arg = converter(request)
449
- resp_obj = await method(arg)
450
- return convert_python_message_to_proto(
451
- resp_obj, response_type, pb2_module
452
- )
756
+ if is_none_type(arg_type):
757
+ resp_obj = await method(None)
758
+ else:
759
+ arg = converter(request)
760
+ resp_obj = await method(arg)
761
+
762
+ if is_none_type(response_type):
763
+ return empty_pb2.Empty() # type: ignore
764
+ else:
765
+ return convert_python_message_to_proto(
766
+ resp_obj, response_type, pb2_module
767
+ )
453
768
  except ValidationError as e:
454
- await context.abort(Errors.InvalidArgument, str(e))
769
+ await context.abort(Errors.INVALID_ARGUMENT, str(e))
455
770
  except Exception as e:
456
- await context.abort(Errors.Internal, str(e))
771
+ await context.abort(Errors.INTERNAL, str(e))
457
772
 
458
773
  return stub_method1
459
774
 
460
775
  case 2:
461
776
 
462
- async def stub_method2(self, request, context, method=method):
777
+ async def stub_method2(
778
+ self: object,
779
+ request: Any,
780
+ context: Any,
781
+ method: Callable[..., Awaitable[Message]] = method,
782
+ ) -> Any:
783
+ _ = self
463
784
  try:
464
- arg = converter(request)
465
- resp_obj = await method(arg, context)
466
- return convert_python_message_to_proto(
467
- resp_obj, response_type, pb2_module
468
- )
785
+ if is_none_type(arg_type):
786
+ resp_obj = await method(None, context)
787
+ else:
788
+ arg = converter(request)
789
+ resp_obj = await method(arg, context)
790
+
791
+ if is_none_type(response_type):
792
+ return empty_pb2.Empty() # type: ignore
793
+ else:
794
+ return convert_python_message_to_proto(
795
+ resp_obj, response_type, pb2_module
796
+ )
469
797
  except ValidationError as e:
470
- await context.abort(Errors.InvalidArgument, str(e))
798
+ await context.abort(Errors.INVALID_ARGUMENT, str(e))
471
799
  except Exception as e:
472
- await context.abort(Errors.Internal, str(e))
800
+ await context.abort(Errors.INTERNAL, str(e))
473
801
 
474
802
  return stub_method2
475
803
 
@@ -487,36 +815,84 @@ def connect_obj_with_stub_async_connecpy(
487
815
  return ConcreteServiceClass
488
816
 
489
817
 
490
- def convert_python_message_to_proto(
491
- py_msg: Message, msg_type: Type, pb2_module
492
- ) -> object:
818
+ def python_value_to_proto_oneof(
819
+ field_name: str, field_type: type[Any], value: Any, pb2_module: Any
820
+ ) -> tuple[str, Any]:
493
821
  """
494
- Convert a Python Pydantic Message instance to a protobuf message instance.
495
- Used for constructing a response.
822
+ Converts a Python value from a Union type to a protobuf oneof field.
823
+ Returns the field name to set and the converted value.
496
824
  """
497
- # Before calling something like pb2_module.AResponseMessage(...),
498
- # convert each field from Python to proto.
825
+ union_args = [arg for arg in flatten_union(field_type) if arg is not type(None)]
826
+
827
+ # Find which subtype in the Union matches the value's type.
828
+ actual_type = None
829
+ for sub_type in union_args:
830
+ origin = get_origin(sub_type)
831
+ type_to_check = origin or sub_type
832
+ try:
833
+ if isinstance(value, type_to_check):
834
+ actual_type = sub_type
835
+ break
836
+ except TypeError:
837
+ # This can happen if `sub_type` is not a class, e.g. a generic alias
838
+ if isinstance(value, type_to_check):
839
+ actual_type = sub_type
840
+ break
841
+
842
+ if actual_type is None:
843
+ raise TypeError(f"Value of type {type(value)} not found in union {field_type}")
844
+
845
+ proto_typename = protobuf_type_mapping(actual_type)
846
+ if proto_typename is None:
847
+ raise TypeError(f"Unsupported type in oneof: {actual_type}")
848
+
849
+ oneof_field_name = f"{field_name}_{proto_typename.replace('.', '_')}"
850
+ converted_value = python_value_to_proto(actual_type, value, pb2_module)
851
+ return oneof_field_name, converted_value
852
+
853
+
854
+ def convert_python_message_to_proto(
855
+ py_msg: Message, msg_type: type[Message], pb2_module: Any
856
+ ) -> object:
857
+ """Convert a Python Pydantic Message instance to a protobuf message instance. Used for constructing a response."""
499
858
  field_dict = {}
500
859
  for name, field_info in msg_type.model_fields.items():
501
860
  value = getattr(py_msg, name)
502
861
  if value is None:
503
- field_dict[name] = None
504
862
  continue
505
863
 
506
864
  field_type = field_info.annotation
507
- field_dict[name] = python_value_to_proto(field_type, value, pb2_module)
865
+
866
+ # Handle oneof fields, which are represented as Unions.
867
+ if field_type is not None and is_union_type(field_type):
868
+ union_args = [
869
+ arg for arg in flatten_union(field_type) if arg is not type(None)
870
+ ]
871
+ if len(union_args) > 1:
872
+ # It's a oneof field. We need to determine the concrete type and
873
+ # the corresponding protobuf field name.
874
+ (
875
+ oneof_field_name,
876
+ converted_value,
877
+ ) = python_value_to_proto_oneof(name, field_type, value, pb2_module)
878
+ field_dict[oneof_field_name] = converted_value
879
+ continue
880
+
881
+ # For regular and Optional fields that have a value.
882
+ if field_type is not None:
883
+ field_dict[name] = python_value_to_proto(field_type, value, pb2_module)
508
884
 
509
885
  # Retrieve the appropriate protobuf class dynamically
510
886
  proto_class = getattr(pb2_module, msg_type.__name__)
511
887
  return proto_class(**field_dict)
512
888
 
513
889
 
514
- def python_value_to_proto(field_type: Type, value, pb2_module):
890
+ def python_value_to_proto(field_type: type[Any], value: Any, pb2_module: Any) -> Any:
515
891
  """
516
892
  Perform Python->protobuf type conversion for each field value.
517
893
  """
518
- import inspect
519
894
  import datetime
895
+ import inspect
520
896
 
521
897
  # If datetime
522
898
  if field_type == datetime.datetime:
@@ -533,12 +909,12 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
533
909
  origin = get_origin(field_type)
534
910
  # If seq
535
911
  if origin in (list, tuple):
536
- inner_type = get_args(field_type)[0] # type: ignore
912
+ inner_type = get_args(field_type)[0]
537
913
  return [python_value_to_proto(inner_type, v, pb2_module) for v in value]
538
914
 
539
915
  # If dict
540
916
  if origin is dict:
541
- key_type, val_type = get_args(field_type) # type: ignore
917
+ key_type, val_type = get_args(field_type)
542
918
  return {
543
919
  python_value_to_proto(key_type, k, pb2_module): python_value_to_proto(
544
920
  val_type, v, pb2_module
@@ -546,31 +922,16 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
546
922
  for k, v in value.items()
547
923
  }
548
924
 
549
- # If union -> oneof
925
+ # If union -> oneof. This path is now only for Optional[T] where value is not None.
550
926
  if is_union_type(field_type):
551
- # Flatten union and check which type matches. If matched, return converted value.
552
- for sub_type in flatten_union(field_type):
553
- if sub_type == datetime.datetime and isinstance(value, datetime.datetime):
554
- return python_to_timestamp(value)
555
- if sub_type == datetime.timedelta and isinstance(value, datetime.timedelta):
556
- return python_to_duration(value)
557
- if (
558
- inspect.isclass(sub_type)
559
- and issubclass(sub_type, enum.Enum)
560
- and isinstance(value, enum.Enum)
561
- ):
562
- return value.value
563
- if sub_type in (int, float, str, bool, bytes) and isinstance(
564
- value, sub_type
565
- ):
566
- return value
567
- if (
568
- inspect.isclass(sub_type)
569
- and issubclass(sub_type, Message)
570
- and isinstance(value, Message)
571
- ):
572
- return convert_python_message_to_proto(value, sub_type, pb2_module)
573
- return None
927
+ # The value is not None, so we need to find the actual type.
928
+ non_none_args = [
929
+ arg for arg in flatten_union(field_type) if arg is not type(None)
930
+ ]
931
+ if non_none_args:
932
+ # Assuming it's an Optional[T], so there's one type left.
933
+ return python_value_to_proto(non_none_args[0], value, pb2_module)
934
+ return None # Should not be reached if value is not None
574
935
 
575
936
  # If Message
576
937
  if inspect.isclass(field_type) and issubclass(field_type, Message):
@@ -585,12 +946,12 @@ def python_value_to_proto(field_type: Type, value, pb2_module):
585
946
  ###############################################################################
586
947
 
587
948
 
588
- def is_enum_type(python_type: Type) -> bool:
949
+ def is_enum_type(python_type: Any) -> bool:
589
950
  """Return True if the given Python type is an enum."""
590
951
  return inspect.isclass(python_type) and issubclass(python_type, enum.Enum)
591
952
 
592
953
 
593
- def is_union_type(python_type: Type) -> bool:
954
+ def is_union_type(python_type: Any) -> bool:
594
955
  """
595
956
  Check if a given Python type is a Union type (including Python 3.10's UnionType).
596
957
  """
@@ -599,26 +960,28 @@ def is_union_type(python_type: Type) -> bool:
599
960
  if sys.version_info >= (3, 10):
600
961
  import types
601
962
 
602
- if type(python_type) is types.UnionType:
963
+ if isinstance(python_type, types.UnionType):
603
964
  return True
604
965
  return False
605
966
 
606
967
 
607
- def flatten_union(field_type: Type) -> list[Type]:
968
+ def flatten_union(field_type: Any) -> list[Any]:
608
969
  """Recursively flatten nested Unions into a single list of types."""
609
970
  if is_union_type(field_type):
610
971
  results = []
611
972
  for arg in get_args(field_type):
612
973
  results.extend(flatten_union(arg))
613
974
  return results
975
+ elif is_none_type(field_type):
976
+ return [field_type]
614
977
  else:
615
978
  return [field_type]
616
979
 
617
980
 
618
- def protobuf_type_mapping(python_type: Type) -> str | type | None:
981
+ def protobuf_type_mapping(python_type: Any) -> str | None:
619
982
  """
620
983
  Map a Python type to a protobuf type name/class.
621
- Includes support for Timestamp and Duration.
984
+ Includes support for Timestamp, Duration, and Empty.
622
985
  """
623
986
  import datetime
624
987
 
@@ -636,8 +999,11 @@ def protobuf_type_mapping(python_type: Type) -> str | type | None:
636
999
  if python_type == datetime.timedelta:
637
1000
  return "google.protobuf.Duration"
638
1001
 
1002
+ if is_none_type(python_type):
1003
+ return "google.protobuf.Empty"
1004
+
639
1005
  if is_enum_type(python_type):
640
- return python_type # Will be defined as enum later
1006
+ return python_type.__name__
641
1007
 
642
1008
  if is_union_type(python_type):
643
1009
  return None # Handled separately as oneof
@@ -657,9 +1023,9 @@ def protobuf_type_mapping(python_type: Type) -> str | type | None:
657
1023
  return f"map<{key_proto_type}, {value_proto_type}>"
658
1024
 
659
1025
  if inspect.isclass(python_type) and issubclass(python_type, Message):
660
- return python_type
1026
+ return python_type.__name__
661
1027
 
662
- return mapping.get(python_type) # type: ignore
1028
+ return mapping.get(python_type)
663
1029
 
664
1030
 
665
1031
  def comment_out(docstr: str) -> tuple[str, ...]:
@@ -673,15 +1039,15 @@ def comment_out(docstr: str) -> tuple[str, ...]:
673
1039
  return tuple("//" if line == "" else f"// {line}" for line in docstr.split("\n"))
674
1040
 
675
1041
 
676
- def indent_lines(lines, indentation=" "):
1042
+ def indent_lines(lines: list[str], indentation: str = " ") -> str:
677
1043
  """Indent multiple lines with a given indentation string."""
678
1044
  return "\n".join(indentation + line for line in lines)
679
1045
 
680
1046
 
681
- def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
1047
+ def generate_enum_definition(enum_type: Any) -> str:
682
1048
  """Generate a protobuf enum definition from a Python enum."""
683
1049
  enum_name = enum_type.__name__
684
- members = []
1050
+ members: list[str] = []
685
1051
  for _, member in enum_type.__members__.items():
686
1052
  members.append(f" {member.name} = {member.value};")
687
1053
  enum_def = f"enum {enum_name} {{\n"
@@ -691,7 +1057,7 @@ def generate_enum_definition(enum_type: Type[enum.Enum]) -> str:
691
1057
 
692
1058
 
693
1059
  def generate_oneof_definition(
694
- field_name: str, union_args: list[Type], start_index: int
1060
+ field_name: str, union_args: list[Any], start_index: int
695
1061
  ) -> tuple[list[str], int]:
696
1062
  """
697
1063
  Generate a oneof block in protobuf for a union field.
@@ -705,12 +1071,6 @@ def generate_oneof_definition(
705
1071
  if proto_typename is None:
706
1072
  raise Exception(f"Nested Union not flattened properly: {arg_type}")
707
1073
 
708
- # If it's an enum or Message, use the type name.
709
- if is_enum_type(arg_type):
710
- proto_typename = arg_type.__name__
711
- elif inspect.isclass(arg_type) and issubclass(arg_type, Message):
712
- proto_typename = arg_type.__name__
713
-
714
1074
  field_alias = f"{field_name}_{proto_typename.replace('.', '_')}"
715
1075
  lines.append(f" {proto_typename} {field_alias} = {current};")
716
1076
  current += 1
@@ -718,17 +1078,50 @@ def generate_oneof_definition(
718
1078
  return lines, current
719
1079
 
720
1080
 
1081
+ def extract_nested_types(field_type: Any) -> list[Any]:
1082
+ """
1083
+ Recursively extract all Message and enum types from a field type,
1084
+ including those nested within list, dict, and union types.
1085
+ """
1086
+ extracted_types = []
1087
+
1088
+ if field_type is None or is_none_type(field_type):
1089
+ return extracted_types
1090
+
1091
+ # Check if the type itself is an enum or Message
1092
+ if is_enum_type(field_type):
1093
+ extracted_types.append(field_type)
1094
+ elif inspect.isclass(field_type) and issubclass(field_type, Message):
1095
+ extracted_types.append(field_type)
1096
+
1097
+ # Handle Union types
1098
+ if is_union_type(field_type):
1099
+ union_args = flatten_union(field_type)
1100
+ for arg in union_args:
1101
+ if arg is not type(None):
1102
+ extracted_types.extend(extract_nested_types(arg))
1103
+
1104
+ # Handle generic types (list, dict, etc.)
1105
+ origin = get_origin(field_type)
1106
+ if origin is not None:
1107
+ args = get_args(field_type)
1108
+ for arg in args:
1109
+ extracted_types.extend(extract_nested_types(arg))
1110
+
1111
+ return extracted_types
1112
+
1113
+
721
1114
  def generate_message_definition(
722
- message_type: Type[Message],
723
- done_enums: set,
724
- done_messages: set,
725
- ) -> tuple[str, list[type]]:
1115
+ message_type: Any,
1116
+ done_enums: set[Any],
1117
+ done_messages: set[Any],
1118
+ ) -> tuple[str, list[Any]]:
726
1119
  """
727
1120
  Generate a protobuf message definition for a Pydantic-based Message class.
728
1121
  Also returns any referenced types (enums, messages) that need to be defined.
729
1122
  """
730
- fields = []
731
- refs = []
1123
+ fields: list[str] = []
1124
+ refs: list[Any] = []
732
1125
  pydantic_fields = message_type.model_fields
733
1126
  index = 1
734
1127
 
@@ -737,92 +1130,130 @@ def generate_message_definition(
737
1130
  if field_type is None:
738
1131
  raise Exception(f"Field {field_name} has no type annotation.")
739
1132
 
1133
+ is_optional = False
1134
+ # Handle Union types, which may be Optional or a oneof.
740
1135
  if is_union_type(field_type):
741
1136
  union_args = flatten_union(field_type)
742
- oneof_lines, new_index = generate_oneof_definition(
743
- field_name, union_args, index
744
- )
745
- fields.extend(oneof_lines)
746
- index = new_index
747
-
748
- for utype in union_args:
749
- if is_enum_type(utype) and utype not in done_enums:
750
- refs.append(utype)
751
- elif (
752
- inspect.isclass(utype)
753
- and issubclass(utype, Message)
754
- and utype not in done_messages
755
- ):
756
- refs.append(utype)
1137
+ none_type = type(
1138
+ None
1139
+ ) # Keep this as type(None) since we're working with union args
1140
+
1141
+ if none_type in union_args or None in union_args:
1142
+ is_optional = True
1143
+ union_args = [arg for arg in union_args if not is_none_type(arg)]
1144
+
1145
+ if len(union_args) == 1:
1146
+ # This is an Optional[T]. Treat it as a simple optional field.
1147
+ field_type = union_args[0]
1148
+ elif len(union_args) > 1:
1149
+ # This is a Union of multiple types, so it becomes a `oneof`.
1150
+ oneof_lines, new_index = generate_oneof_definition(
1151
+ field_name, union_args, index
1152
+ )
1153
+ fields.extend(oneof_lines)
1154
+ index = new_index
1155
+
1156
+ for utype in union_args:
1157
+ if is_enum_type(utype) and utype not in done_enums:
1158
+ refs.append(utype)
1159
+ elif (
1160
+ inspect.isclass(utype)
1161
+ and issubclass(utype, Message)
1162
+ and utype not in done_messages
1163
+ ):
1164
+ refs.append(utype)
1165
+ continue # Proceed to the next field
1166
+ else:
1167
+ # This was a field of only `NoneType`, which is not supported.
1168
+ continue
757
1169
 
1170
+ # For regular fields or optional fields that have been unwrapped.
1171
+ proto_typename = protobuf_type_mapping(field_type)
1172
+ if proto_typename is None:
1173
+ raise Exception(f"Type {field_type} is not supported.")
1174
+
1175
+ # Extract all nested Message and enum types recursively
1176
+ nested_types = extract_nested_types(field_type)
1177
+ for nested_type in nested_types:
1178
+ if is_enum_type(nested_type) and nested_type not in done_enums:
1179
+ refs.append(nested_type)
1180
+ elif (
1181
+ inspect.isclass(nested_type)
1182
+ and issubclass(nested_type, Message)
1183
+ and nested_type not in done_messages
1184
+ ):
1185
+ refs.append(nested_type)
1186
+
1187
+ if field_info.description:
1188
+ fields.append("// " + field_info.description)
1189
+ if field_info.metadata:
1190
+ fields.append("// Constraint:")
1191
+ for metadata_item in field_info.metadata:
1192
+ match type(metadata_item):
1193
+ case annotated_types.Ge:
1194
+ fields.append(
1195
+ "// greater than or equal to " + str(metadata_item.ge)
1196
+ )
1197
+ case annotated_types.Le:
1198
+ fields.append(
1199
+ "// less than or equal to " + str(metadata_item.le)
1200
+ )
1201
+ case annotated_types.Gt:
1202
+ fields.append("// greater than " + str(metadata_item.gt))
1203
+ case annotated_types.Lt:
1204
+ fields.append("// less than " + str(metadata_item.lt))
1205
+ case annotated_types.MultipleOf:
1206
+ fields.append(
1207
+ "// multiple of " + str(metadata_item.multiple_of)
1208
+ )
1209
+ case annotated_types.Len:
1210
+ fields.append("// length of " + str(metadata_item.len))
1211
+ case annotated_types.MinLen:
1212
+ fields.append(
1213
+ "// minimum length of " + str(metadata_item.min_len)
1214
+ )
1215
+ case annotated_types.MaxLen:
1216
+ fields.append(
1217
+ "// maximum length of " + str(metadata_item.max_len)
1218
+ )
1219
+ case _:
1220
+ fields.append("// " + str(metadata_item))
1221
+
1222
+ field_definition = f"{proto_typename} {field_name} = {index};"
1223
+ if is_optional:
1224
+ field_definition = f"optional {field_definition}"
1225
+
1226
+ fields.append(field_definition)
1227
+ index += 1
1228
+
1229
+ # Add reserved fields for forward/backward compatibility if specified
1230
+ reserved_count = get_reserved_fields_count()
1231
+ if reserved_count > 0:
1232
+ start_reserved = index
1233
+ end_reserved = index + reserved_count - 1
1234
+ fields.append("")
1235
+ fields.append("// Reserved fields for future compatibility")
1236
+ if reserved_count == 1:
1237
+ fields.append(f"reserved {start_reserved};")
758
1238
  else:
759
- proto_typename = protobuf_type_mapping(field_type)
760
- if proto_typename is None:
761
- raise Exception(f"Type {field_type} is not supported.")
762
-
763
- if is_enum_type(field_type):
764
- proto_typename = field_type.__name__
765
- if field_type not in done_enums:
766
- refs.append(field_type)
767
- elif inspect.isclass(field_type) and issubclass(field_type, Message):
768
- proto_typename = field_type.__name__
769
- if field_type not in done_messages:
770
- refs.append(field_type)
771
-
772
- if field_info.description:
773
- fields.append("// " + field_info.description)
774
- if field_info.metadata:
775
- fields.append("// Constraint:")
776
- for metadata_item in field_info.metadata:
777
- match type(metadata_item):
778
- case annotated_types.Ge:
779
- fields.append(
780
- "// greater than or equal to " + str(metadata_item.ge)
781
- )
782
- case annotated_types.Le:
783
- fields.append(
784
- "// less than or equal to " + str(metadata_item.le)
785
- )
786
- case annotated_types.Gt:
787
- fields.append("// greater than " + str(metadata_item.gt))
788
- case annotated_types.Lt:
789
- fields.append("// less than " + str(metadata_item.lt))
790
- case annotated_types.MultipleOf:
791
- fields.append(
792
- "// multiple of " + str(metadata_item.multiple_of)
793
- )
794
- case annotated_types.Len:
795
- fields.append("// length of " + str(metadata_item.len))
796
- case annotated_types.MinLen:
797
- fields.append(
798
- "// minimum length of " + str(metadata_item.min_len)
799
- )
800
- case annotated_types.MaxLen:
801
- fields.append(
802
- "// maximum length of " + str(metadata_item.max_len)
803
- )
804
- case _:
805
- fields.append("// " + str(metadata_item))
806
-
807
- fields.append(f"{proto_typename} {field_name} = {index};")
808
- index += 1
1239
+ fields.append(f"reserved {start_reserved} to {end_reserved};")
809
1240
 
810
1241
  msg_def = f"message {message_type.__name__} {{\n{indent_lines(fields)}\n}}"
811
1242
  return msg_def, refs
812
1243
 
813
1244
 
814
- def is_stream_type(annotation: Type) -> bool:
1245
+ def is_stream_type(annotation: Any) -> bool:
815
1246
  return get_origin(annotation) is AsyncIterator
816
1247
 
817
1248
 
818
- def is_generic_alias(annotation: Type) -> bool:
1249
+ def is_generic_alias(annotation: Any) -> bool:
819
1250
  return get_origin(annotation) is not None
820
1251
 
821
1252
 
822
1253
  def generate_proto(obj: object, package_name: str = "") -> str:
823
1254
  """
824
1255
  Generate a .proto definition from a service class.
825
- Automatically handles Timestamp and Duration usage.
1256
+ Automatically handles Timestamp, Duration, and Empty usage.
826
1257
  """
827
1258
  import datetime
828
1259
 
@@ -831,20 +1262,23 @@ def generate_proto(obj: object, package_name: str = "") -> str:
831
1262
  service_docstr = inspect.getdoc(service_class)
832
1263
  service_comment = "\n".join(comment_out(service_docstr)) if service_docstr else ""
833
1264
 
834
- rpc_definitions = []
835
- all_type_definitions = []
836
- done_messages = set()
837
- done_enums = set()
1265
+ rpc_definitions: list[str] = []
1266
+ all_type_definitions: list[str] = []
1267
+ done_messages: set[Any] = set()
1268
+ done_enums: set[Any] = set()
838
1269
 
839
1270
  uses_timestamp = False
840
1271
  uses_duration = False
1272
+ uses_empty = False
841
1273
 
842
- def check_and_set_well_known_types(py_type: Type):
1274
+ def check_and_set_well_known_types_for_fields(py_type: Any):
1275
+ """Check well-known types for field annotations (excludes None/Empty)."""
843
1276
  nonlocal uses_timestamp, uses_duration
844
1277
  if py_type == datetime.datetime:
845
1278
  uses_timestamp = True
846
1279
  if py_type == datetime.timedelta:
847
1280
  uses_duration = True
1281
+ # Don't check for None here - Optional fields don't use Empty
848
1282
 
849
1283
  for method_name, method in get_rpc_methods(obj):
850
1284
  if method.__name__.startswith("_"):
@@ -854,26 +1288,59 @@ def generate_proto(obj: object, package_name: str = "") -> str:
854
1288
  request_type = get_request_arg_type(method_sig)
855
1289
  response_type = method_sig.return_annotation
856
1290
 
1291
+ # Validate that we don't have AsyncIterator[None] which doesn't make any practical sense
1292
+ if is_stream_type(request_type):
1293
+ stream_item_type = get_args(request_type)[0]
1294
+ if is_none_type(stream_item_type):
1295
+ raise TypeError(
1296
+ f"Method '{method_name}' has AsyncIterator[None] as input type, which is not allowed. Streaming Empty messages is meaningless."
1297
+ )
1298
+
1299
+ if is_stream_type(response_type):
1300
+ stream_item_type = get_args(response_type)[0]
1301
+ if is_none_type(stream_item_type):
1302
+ raise TypeError(
1303
+ f"Method '{method_name}' has AsyncIterator[None] as return type, which is not allowed. Streaming Empty messages is meaningless."
1304
+ )
1305
+
1306
+ # Handle NoneType for request and response
1307
+ if is_none_type(request_type):
1308
+ uses_empty = True
1309
+ if is_none_type(response_type):
1310
+ uses_empty = True
1311
+
857
1312
  # Recursively generate message definitions
858
- message_types = [request_type, response_type]
1313
+ message_types = []
1314
+ if not is_none_type(request_type):
1315
+ message_types.append(request_type)
1316
+ if not is_none_type(response_type):
1317
+ message_types.append(response_type)
1318
+
859
1319
  while message_types:
860
- mt = message_types.pop()
861
- if mt in done_messages:
1320
+ mt: type[Message] | type[ServicerContext] | None = message_types.pop()
1321
+ if mt in done_messages or mt is ServicerContext or mt is None:
862
1322
  continue
863
1323
  done_messages.add(mt)
864
1324
 
865
1325
  if is_stream_type(mt):
866
1326
  item_type = get_args(mt)[0]
867
- message_types.append(item_type)
1327
+ if not is_none_type(item_type):
1328
+ message_types.append(item_type)
868
1329
  continue
869
1330
 
1331
+ mt = cast(type[Message], mt)
1332
+
870
1333
  for _, field_info in mt.model_fields.items():
871
1334
  t = field_info.annotation
872
1335
  if is_union_type(t):
873
1336
  for sub_t in flatten_union(t):
874
- check_and_set_well_known_types(sub_t)
1337
+ check_and_set_well_known_types_for_fields(
1338
+ sub_t
1339
+ ) # Use the field-specific version
875
1340
  else:
876
- check_and_set_well_known_types(t)
1341
+ check_and_set_well_known_types_for_fields(
1342
+ t
1343
+ ) # Use the field-specific version
877
1344
 
878
1345
  msg_def, refs = generate_message_definition(mt, done_enums, done_messages)
879
1346
  mt_doc = inspect.getdoc(mt)
@@ -898,16 +1365,47 @@ def generate_proto(obj: object, package_name: str = "") -> str:
898
1365
  for comment_line in comment_out(method_docstr):
899
1366
  rpc_definitions.append(comment_line)
900
1367
 
901
- if is_stream_type(response_type):
902
- item_type = get_args(response_type)[0]
903
- rpc_definitions.append(
904
- f"rpc {method_name} ({request_type.__name__}) returns (stream {item_type.__name__});"
1368
+ input_type = request_type
1369
+ input_is_stream = is_stream_type(input_type)
1370
+ output_is_stream = is_stream_type(response_type)
1371
+
1372
+ if input_is_stream:
1373
+ input_msg_type = get_args(input_type)[0]
1374
+ else:
1375
+ input_msg_type = input_type
1376
+
1377
+ if output_is_stream:
1378
+ output_msg_type = get_args(response_type)[0]
1379
+ else:
1380
+ output_msg_type = response_type
1381
+
1382
+ # Handle NoneType by using Empty (but we've already validated no streaming of Empty above)
1383
+ if input_msg_type is None or input_msg_type is ServicerContext:
1384
+ input_str = "google.protobuf.Empty" # No need to check for stream since we validated above
1385
+ if input_msg_type is ServicerContext:
1386
+ uses_empty = True
1387
+ else:
1388
+ input_str = (
1389
+ f"stream {input_msg_type.__name__}"
1390
+ if input_is_stream
1391
+ else input_msg_type.__name__
905
1392
  )
1393
+
1394
+ if output_msg_type is None or output_msg_type is ServicerContext:
1395
+ output_str = "google.protobuf.Empty" # No need to check for stream since we validated above
1396
+ if output_msg_type is ServicerContext:
1397
+ uses_empty = True
906
1398
  else:
907
- rpc_definitions.append(
908
- f"rpc {method_name} ({request_type.__name__}) returns ({response_type.__name__});"
1399
+ output_str = (
1400
+ f"stream {output_msg_type.__name__}"
1401
+ if output_is_stream
1402
+ else output_msg_type.__name__
909
1403
  )
910
1404
 
1405
+ rpc_definitions.append(
1406
+ f"rpc {method_name} ({input_str}) returns ({output_str});"
1407
+ )
1408
+
911
1409
  if not package_name:
912
1410
  if service_name.endswith("Service"):
913
1411
  package_name = service_name[: -len("Service")]
@@ -915,11 +1413,13 @@ def generate_proto(obj: object, package_name: str = "") -> str:
915
1413
  package_name = service_name
916
1414
  package_name = package_name.lower() + ".v1"
917
1415
 
918
- imports = []
1416
+ imports: list[str] = []
919
1417
  if uses_timestamp:
920
1418
  imports.append('import "google/protobuf/timestamp.proto";')
921
1419
  if uses_duration:
922
1420
  imports.append('import "google/protobuf/duration.proto";')
1421
+ if uses_empty:
1422
+ imports.append('import "google/protobuf/empty.proto";')
923
1423
 
924
1424
  import_block = "\n".join(imports)
925
1425
  if import_block:
@@ -939,98 +1439,168 @@ service {service_name} {{
939
1439
  return proto_definition
940
1440
 
941
1441
 
942
- def generate_grpc_code(proto_file, grpc_python_out) -> types.ModuleType | None:
1442
+ def generate_grpc_code(proto_path: Path) -> types.ModuleType | None:
943
1443
  """
944
- Execute the protoc command to generate Python gRPC code from the .proto file.
945
- Returns a tuple of (pb2_grpc_module, pb2_module) on success, or None if failed.
1444
+ Run protoc to generate Python gRPC code from proto_path.
1445
+ Writes foo_pb2_grpc.py next to proto_path, then imports and returns that module.
946
1446
  """
947
- command = f"-I. --grpc_python_out={grpc_python_out} {proto_file}"
948
- exit_code = protoc.main(command.split())
949
- if exit_code != 0:
950
- return None
1447
+ # 1) Ensure the .proto exists
1448
+ if not proto_path.is_file():
1449
+ raise FileNotFoundError(f"{proto_path!r} does not exist")
1450
+
1451
+ # 2) Determine output directory (same as the .proto's parent)
1452
+ proto_path = proto_path.resolve()
1453
+ out_dir = proto_path.parent
1454
+ out_dir.mkdir(parents=True, exist_ok=True)
1455
+
1456
+ # 3) Build and run the protoc command
1457
+ out_str = str(out_dir)
1458
+ well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
1459
+ args = [
1460
+ "protoc", # Dummy program name (required for protoc.main)
1461
+ "-I.",
1462
+ f"-I{well_known_path}",
1463
+ f"--grpc_python_out={out_str}",
1464
+ proto_path.name,
1465
+ ]
1466
+
1467
+ current_dir = os.getcwd()
1468
+ os.chdir(str(out_dir))
1469
+ try:
1470
+ if protoc.main(args) != 0:
1471
+ return None
1472
+ finally:
1473
+ os.chdir(current_dir)
951
1474
 
952
- base = os.path.splitext(proto_file)[0]
953
- generated_pb2_grpc_file = f"{base}_pb2_grpc.py"
1475
+ # 4) Locate the generated gRPC file
1476
+ base_name = proto_path.stem # "foo"
1477
+ generated_filename = f"{base_name}_pb2_grpc.py" # "foo_pb2_grpc.py"
1478
+ generated_filepath = out_dir / generated_filename
954
1479
 
955
- if grpc_python_out not in sys.path:
956
- sys.path.append(grpc_python_out)
1480
+ # 5) Add out_dir to sys.path so we can import it
1481
+ if out_str not in sys.path:
1482
+ sys.path.append(out_str)
957
1483
 
1484
+ # 6) Load and return the module
958
1485
  spec = importlib.util.spec_from_file_location(
959
- generated_pb2_grpc_file, os.path.join(grpc_python_out, generated_pb2_grpc_file)
1486
+ base_name + "_pb2_grpc", str(generated_filepath)
960
1487
  )
961
- if spec is None:
1488
+ if spec is None or spec.loader is None:
962
1489
  return None
963
- pb2_grpc_module = importlib.util.module_from_spec(spec)
964
- if spec.loader is None:
965
- return None
966
- spec.loader.exec_module(pb2_grpc_module)
967
1490
 
968
- return pb2_grpc_module
1491
+ module = importlib.util.module_from_spec(spec)
1492
+ spec.loader.exec_module(module)
1493
+ return module
969
1494
 
970
1495
 
971
- def generate_connecpy_code(
972
- proto_file: str, connecpy_out: str
973
- ) -> types.ModuleType | None:
1496
+ def generate_connecpy_code(proto_path: Path) -> types.ModuleType | None:
974
1497
  """
975
- Execute the protoc command to generate Python Connecpy code from the .proto file.
976
- Returns a tuple of (connecpy_module, pb2_module) on success, or None if failed.
1498
+ Run protoc with the Connecpy plugin to generate Python Connecpy code from proto_path.
1499
+ Writes foo_connecpy.py next to proto_path, then imports and returns that module.
977
1500
  """
978
- command = f"-I. --connecpy_out={connecpy_out} {proto_file}"
979
- exit_code = protoc.main(command.split())
980
- if exit_code != 0:
981
- return None
1501
+ # 1) Ensure the .proto exists
1502
+ if not proto_path.is_file():
1503
+ raise FileNotFoundError(f"{proto_path!r} does not exist")
1504
+
1505
+ # 2) Determine output directory (same as the .proto's parent)
1506
+ proto_path = proto_path.resolve()
1507
+ out_dir = proto_path.parent
1508
+ out_dir.mkdir(parents=True, exist_ok=True)
1509
+
1510
+ # 3) Build and run the protoc command
1511
+ out_str = str(out_dir)
1512
+ well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
1513
+ args = [
1514
+ "protoc", # Dummy program name (required for protoc.main)
1515
+ "-I.",
1516
+ f"-I{well_known_path}",
1517
+ f"--connecpy_out={out_str}",
1518
+ proto_path.name,
1519
+ ]
982
1520
 
983
- base = os.path.splitext(proto_file)[0]
984
- generated_connecpy_file = f"{base}_connecpy.py"
1521
+ current_dir = os.getcwd()
1522
+ os.chdir(str(out_dir))
1523
+ try:
1524
+ if protoc.main(args) != 0:
1525
+ return None
1526
+ finally:
1527
+ os.chdir(current_dir)
985
1528
 
986
- if connecpy_out not in sys.path:
987
- sys.path.append(connecpy_out)
1529
+ # 4) Locate the generated file
1530
+ base_name = proto_path.stem # "foo"
1531
+ generated_filename = f"{base_name}_connecpy.py" # "foo_connecpy.py"
1532
+ generated_filepath = out_dir / generated_filename
988
1533
 
1534
+ # 5) Add out_dir to sys.path so we can import by filename
1535
+ if out_str not in sys.path:
1536
+ sys.path.append(out_str)
1537
+
1538
+ # 6) Load and return the module
989
1539
  spec = importlib.util.spec_from_file_location(
990
- generated_connecpy_file, os.path.join(connecpy_out, generated_connecpy_file)
1540
+ base_name + "_connecpy", str(generated_filepath)
991
1541
  )
992
- if spec is None:
993
- return None
994
- connecpy_module = importlib.util.module_from_spec(spec)
995
- if spec.loader is None:
1542
+ if spec is None or spec.loader is None:
996
1543
  return None
997
- spec.loader.exec_module(connecpy_module)
998
1544
 
999
- return connecpy_module
1545
+ module = importlib.util.module_from_spec(spec)
1546
+ spec.loader.exec_module(module)
1547
+ return module
1000
1548
 
1001
1549
 
1002
- def generate_pb_code(
1003
- proto_file: str, python_out: str, pyi_out: str
1004
- ) -> types.ModuleType | None:
1550
+ def generate_pb_code(proto_path: Path) -> types.ModuleType | None:
1005
1551
  """
1006
- Execute the protoc command to generate Python gRPC code from the .proto file.
1007
- Returns a tuple of (pb2_grpc_module, pb2_module) on success, or None if failed.
1552
+ Run protoc to generate Python gRPC code from proto_path.
1553
+ Writes foo_pb2.py and foo_pb2.pyi next to proto_path, then imports and returns the pb2 module.
1008
1554
  """
1009
- command = f"-I. --python_out={python_out} --pyi_out={pyi_out} {proto_file}"
1010
- exit_code = protoc.main(command.split())
1011
- if exit_code != 0:
1012
- return None
1555
+ # 1) Make sure proto_path exists
1556
+ if not proto_path.is_file():
1557
+ raise FileNotFoundError(f"{proto_path!r} does not exist")
1558
+
1559
+ # 2) Determine output directory (same as proto file)
1560
+ proto_path = proto_path.resolve()
1561
+ out_dir = proto_path.parent
1562
+ out_dir.mkdir(parents=True, exist_ok=True)
1563
+
1564
+ # 3) Build and run protoc command
1565
+ out_str = str(out_dir)
1566
+ well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
1567
+ args = [
1568
+ "protoc", # Dummy program name (required for protoc.main)
1569
+ "-I.",
1570
+ f"-I{well_known_path}",
1571
+ f"--python_out={out_str}",
1572
+ f"--pyi_out={out_str}",
1573
+ proto_path.name,
1574
+ ]
1575
+
1576
+ current_dir = os.getcwd()
1577
+ os.chdir(str(out_dir))
1578
+ try:
1579
+ if protoc.main(args) != 0:
1580
+ return None
1581
+ finally:
1582
+ os.chdir(current_dir)
1013
1583
 
1014
- base = os.path.splitext(proto_file)[0]
1015
- generated_pb2_file = f"{base}_pb2.py"
1584
+ # 4) Locate generated file
1585
+ base_name = proto_path.stem # e.g. "foo"
1586
+ generated_file = out_dir / f"{base_name}_pb2.py"
1016
1587
 
1017
- if python_out not in sys.path:
1018
- sys.path.append(python_out)
1588
+ # 5) Add to sys.path if needed
1589
+ if out_str not in sys.path:
1590
+ sys.path.append(out_str)
1019
1591
 
1592
+ # 6) Import it
1020
1593
  spec = importlib.util.spec_from_file_location(
1021
- generated_pb2_file, os.path.join(python_out, generated_pb2_file)
1594
+ base_name + "_pb2", str(generated_file)
1022
1595
  )
1023
- if spec is None:
1596
+ if spec is None or spec.loader is None:
1024
1597
  return None
1025
- pb2_module = importlib.util.module_from_spec(spec)
1026
- if spec.loader is None:
1027
- return None
1028
- spec.loader.exec_module(pb2_module)
1029
-
1030
- return pb2_module
1598
+ module = importlib.util.module_from_spec(spec)
1599
+ spec.loader.exec_module(module)
1600
+ return module
1031
1601
 
1032
1602
 
1033
- def get_request_arg_type(sig):
1603
+ def get_request_arg_type(sig: inspect.Signature) -> Any:
1034
1604
  """Return the type annotation of the first parameter (request) of a method."""
1035
1605
  num_of_params = len(sig.parameters)
1036
1606
  if not (num_of_params == 1 or num_of_params == 2):
@@ -1038,7 +1608,7 @@ def get_request_arg_type(sig):
1038
1608
  return tuple(sig.parameters.values())[0].annotation
1039
1609
 
1040
1610
 
1041
- def get_rpc_methods(obj: object) -> list[tuple[str, types.MethodType]]:
1611
+ def get_rpc_methods(obj: object) -> list[tuple[str, Callable[..., Any]]]:
1042
1612
  """
1043
1613
  Retrieve the list of RPC methods from a service object.
1044
1614
  The method name is converted to PascalCase for .proto compatibility.
@@ -1059,68 +1629,393 @@ def is_skip_generation() -> bool:
1059
1629
  return os.getenv("PYDANTIC_RPC_SKIP_GENERATION", "false").lower() == "true"
1060
1630
 
1061
1631
 
1062
- def generate_and_compile_proto(obj: object, package_name: str = ""):
1632
+ def get_reserved_fields_count() -> int:
1633
+ """Get the number of reserved fields to add to each message from environment variable."""
1634
+ try:
1635
+ return max(0, int(os.getenv("PYDANTIC_RPC_RESERVED_FIELDS", "0")))
1636
+ except ValueError:
1637
+ return 0
1638
+
1639
+
1640
+ def generate_and_compile_proto(
1641
+ obj: object,
1642
+ package_name: str = "",
1643
+ existing_proto_path: Path | None = None,
1644
+ ) -> tuple[Any, Any]:
1063
1645
  if is_skip_generation():
1064
1646
  import importlib
1065
1647
 
1066
- pb2_module = importlib.import_module(f"{obj.__class__.__name__.lower()}_pb2")
1067
- pb2_grpc_module = importlib.import_module(
1068
- f"{obj.__class__.__name__.lower()}_pb2_grpc"
1069
- )
1648
+ pb2_module = None
1649
+ pb2_grpc_module = None
1650
+
1651
+ try:
1652
+ pb2_module = importlib.import_module(
1653
+ f"{obj.__class__.__name__.lower()}_pb2"
1654
+ )
1655
+ except ImportError:
1656
+ pass
1657
+
1658
+ try:
1659
+ pb2_grpc_module = importlib.import_module(
1660
+ f"{obj.__class__.__name__.lower()}_pb2_grpc"
1661
+ )
1662
+ except ImportError:
1663
+ pass
1070
1664
 
1071
1665
  if pb2_grpc_module is not None and pb2_module is not None:
1072
1666
  return pb2_grpc_module, pb2_module
1073
1667
 
1074
1668
  # If the modules are not found, generate and compile the proto files.
1075
1669
 
1076
- klass = obj.__class__
1077
- proto_file = generate_proto(obj, package_name)
1078
- proto_file_name = klass.__name__.lower() + ".proto"
1670
+ if existing_proto_path:
1671
+ # Use the provided existing proto file (skip generation)
1672
+ proto_file_path = existing_proto_path
1673
+ else:
1674
+ # Generate as before
1675
+ klass = obj.__class__
1676
+ proto_file = generate_proto(obj, package_name)
1677
+ proto_file_name = klass.__name__.lower() + ".proto"
1678
+ proto_file_path = get_proto_path(proto_file_name)
1079
1679
 
1080
- with open(proto_file_name, "w", encoding="utf-8") as f:
1081
- f.write(proto_file)
1680
+ with proto_file_path.open(mode="w", encoding="utf-8") as f:
1681
+ _ = f.write(proto_file)
1082
1682
 
1083
- gen_pb = generate_pb_code(proto_file_name, ".", ".")
1683
+ gen_pb = generate_pb_code(proto_file_path)
1084
1684
  if gen_pb is None:
1085
1685
  raise Exception("Generating pb code")
1086
1686
 
1087
- gen_grpc = generate_grpc_code(proto_file_name, ".")
1687
+ gen_grpc = generate_grpc_code(proto_file_path)
1088
1688
  if gen_grpc is None:
1089
1689
  raise Exception("Generating grpc code")
1090
1690
  return gen_grpc, gen_pb
1091
1691
 
1092
1692
 
1093
- def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = ""):
1693
+ def get_proto_path(proto_filename: str) -> Path:
1694
+ # 1. Get raw env var (or default to cwd)
1695
+ raw = os.getenv("PYDANTIC_RPC_PROTO_PATH", None)
1696
+ base = Path(raw) if raw is not None else Path.cwd()
1697
+
1698
+ # 2. Expand ~ and env-vars, then make absolute
1699
+ base = Path(os.path.expandvars(os.path.expanduser(str(base)))).resolve()
1700
+
1701
+ # 3. Ensure it's a directory (or create it)
1702
+ if not base.exists():
1703
+ try:
1704
+ base.mkdir(parents=True, exist_ok=True)
1705
+ except OSError as e:
1706
+ raise RuntimeError(f"Unable to create directory {base!r}: {e}") from e
1707
+ elif not base.is_dir():
1708
+ raise NotADirectoryError(f"{base!r} exists but is not a directory")
1709
+
1710
+ # 4. Check writability
1711
+ if not os.access(base, os.W_OK):
1712
+ raise PermissionError(f"No write permission for directory {base!r}")
1713
+
1714
+ # 5. Return the final file path
1715
+ return base / proto_filename
1716
+
1717
+
1718
+ def generate_and_compile_proto_using_connecpy(
1719
+ obj: object,
1720
+ package_name: str = "",
1721
+ existing_proto_path: Path | None = None,
1722
+ ) -> tuple[Any, Any]:
1094
1723
  if is_skip_generation():
1095
1724
  import importlib
1096
1725
 
1097
- pb2_module = importlib.import_module(f"{obj.__class__.__name__.lower()}_pb2")
1098
- connecpy_module = importlib.import_module(
1099
- f"{obj.__class__.__name__.lower()}_connecpy"
1100
- )
1726
+ pb2_module = None
1727
+ connecpy_module = None
1728
+
1729
+ try:
1730
+ pb2_module = importlib.import_module(
1731
+ f"{obj.__class__.__name__.lower()}_pb2"
1732
+ )
1733
+ except ImportError:
1734
+ pass
1735
+
1736
+ try:
1737
+ connecpy_module = importlib.import_module(
1738
+ f"{obj.__class__.__name__.lower()}_connecpy"
1739
+ )
1740
+ except ImportError:
1741
+ pass
1101
1742
 
1102
1743
  if connecpy_module is not None and pb2_module is not None:
1103
1744
  return connecpy_module, pb2_module
1104
1745
 
1105
1746
  # If the modules are not found, generate and compile the proto files.
1106
1747
 
1107
- klass = obj.__class__
1108
- proto_file = generate_proto(obj, package_name)
1109
- proto_file_name = klass.__name__.lower() + ".proto"
1748
+ if existing_proto_path:
1749
+ # Use the provided existing proto file (skip generation)
1750
+ proto_file_path = existing_proto_path
1751
+ else:
1752
+ # Generate as before
1753
+ klass = obj.__class__
1754
+ proto_file = generate_proto(obj, package_name)
1755
+ proto_file_name = klass.__name__.lower() + ".proto"
1110
1756
 
1111
- with open(proto_file_name, "w", encoding="utf-8") as f:
1112
- f.write(proto_file)
1757
+ proto_file_path = get_proto_path(proto_file_name)
1758
+ with proto_file_path.open(mode="w", encoding="utf-8") as f:
1759
+ _ = f.write(proto_file)
1113
1760
 
1114
- gen_pb = generate_pb_code(proto_file_name, ".", ".")
1761
+ gen_pb = generate_pb_code(proto_file_path)
1115
1762
  if gen_pb is None:
1116
1763
  raise Exception("Generating pb code")
1117
1764
 
1118
- gen_connecpy = generate_connecpy_code(proto_file_name, ".")
1765
+ gen_connecpy = generate_connecpy_code(proto_file_path)
1119
1766
  if gen_connecpy is None:
1120
1767
  raise Exception("Generating Connecpy code")
1121
1768
  return gen_connecpy, gen_pb
1122
1769
 
1123
1770
 
1771
+ def is_combined_proto_enabled() -> bool:
1772
+ """Check if combined proto file generation is enabled."""
1773
+ return os.getenv("PYDANTIC_RPC_COMBINED_PROTO", "false").lower() == "true"
1774
+
1775
+
1776
+ def generate_combined_proto(
1777
+ *services: object, package_name: str = "combined.v1"
1778
+ ) -> str:
1779
+ """Generate a combined .proto definition from multiple service classes."""
1780
+ import datetime
1781
+
1782
+ all_type_definitions: list[str] = []
1783
+ all_service_definitions: list[str] = []
1784
+ done_messages: set[Any] = set()
1785
+ done_enums: set[Any] = set()
1786
+
1787
+ uses_timestamp = False
1788
+ uses_duration = False
1789
+ uses_empty = False
1790
+
1791
+ def check_and_set_well_known_types_for_fields(py_type: Any):
1792
+ """Check well-known types for field annotations (excludes None/Empty)."""
1793
+ nonlocal uses_timestamp, uses_duration
1794
+ if py_type == datetime.datetime:
1795
+ uses_timestamp = True
1796
+ if py_type == datetime.timedelta:
1797
+ uses_duration = True
1798
+
1799
+ # Process each service
1800
+ for service_obj in services:
1801
+ service_class = service_obj.__class__
1802
+ service_name = service_class.__name__
1803
+ service_docstr = inspect.getdoc(service_class)
1804
+ service_comment = (
1805
+ "\n".join(comment_out(service_docstr)) if service_docstr else ""
1806
+ )
1807
+
1808
+ service_rpc_definitions: list[str] = []
1809
+
1810
+ for method_name, method in get_rpc_methods(service_obj):
1811
+ if method.__name__.startswith("_"):
1812
+ continue
1813
+
1814
+ method_sig = inspect.signature(method)
1815
+ request_type = get_request_arg_type(method_sig)
1816
+ response_type = method_sig.return_annotation
1817
+
1818
+ # Validate stream types
1819
+ if is_stream_type(request_type):
1820
+ stream_item_type = get_args(request_type)[0]
1821
+ if is_none_type(stream_item_type):
1822
+ raise TypeError(
1823
+ f"Method '{method_name}' has AsyncIterator[None] as input type, which is not allowed."
1824
+ )
1825
+
1826
+ if is_stream_type(response_type):
1827
+ stream_item_type = get_args(response_type)[0]
1828
+ if is_none_type(stream_item_type):
1829
+ raise TypeError(
1830
+ f"Method '{method_name}' has AsyncIterator[None] as return type, which is not allowed."
1831
+ )
1832
+
1833
+ # Handle NoneType for request and response
1834
+ if is_none_type(request_type):
1835
+ uses_empty = True
1836
+ if is_none_type(response_type):
1837
+ uses_empty = True
1838
+
1839
+ # Collect message types for processing
1840
+ message_types = []
1841
+ if not is_none_type(request_type):
1842
+ message_types.append(request_type)
1843
+ if not is_none_type(response_type):
1844
+ message_types.append(response_type)
1845
+
1846
+ # Process message types
1847
+ while message_types:
1848
+ mt: type[Message] | type[ServicerContext] | None = message_types.pop()
1849
+ if mt in done_messages or mt is ServicerContext or mt is None:
1850
+ continue
1851
+ done_messages.add(mt)
1852
+
1853
+ if is_stream_type(mt):
1854
+ item_type = get_args(mt)[0]
1855
+ if not is_none_type(item_type):
1856
+ message_types.append(item_type)
1857
+ continue
1858
+
1859
+ mt = cast(type[Message], mt)
1860
+
1861
+ for _, field_info in mt.model_fields.items():
1862
+ t = field_info.annotation
1863
+ if is_union_type(t):
1864
+ for sub_t in flatten_union(t):
1865
+ check_and_set_well_known_types_for_fields(sub_t)
1866
+ else:
1867
+ check_and_set_well_known_types_for_fields(t)
1868
+
1869
+ msg_def, refs = generate_message_definition(
1870
+ mt, done_enums, done_messages
1871
+ )
1872
+ mt_doc = inspect.getdoc(mt)
1873
+ if mt_doc:
1874
+ for comment_line in comment_out(mt_doc):
1875
+ all_type_definitions.append(comment_line)
1876
+
1877
+ all_type_definitions.append(msg_def)
1878
+ all_type_definitions.append("")
1879
+
1880
+ for r in refs:
1881
+ if is_enum_type(r) and r not in done_enums:
1882
+ done_enums.add(r)
1883
+ enum_def = generate_enum_definition(r)
1884
+ all_type_definitions.append(enum_def)
1885
+ all_type_definitions.append("")
1886
+ elif issubclass(r, Message) and r not in done_messages:
1887
+ message_types.append(r)
1888
+
1889
+ # Generate RPC definition
1890
+ method_docstr = inspect.getdoc(method)
1891
+ if method_docstr:
1892
+ for comment_line in comment_out(method_docstr):
1893
+ service_rpc_definitions.append(comment_line)
1894
+
1895
+ input_type = request_type
1896
+ input_is_stream = is_stream_type(input_type)
1897
+ output_is_stream = is_stream_type(response_type)
1898
+
1899
+ if input_is_stream:
1900
+ input_msg_type = get_args(input_type)[0]
1901
+ else:
1902
+ input_msg_type = input_type
1903
+
1904
+ if output_is_stream:
1905
+ output_msg_type = get_args(response_type)[0]
1906
+ else:
1907
+ output_msg_type = response_type
1908
+
1909
+ # Handle NoneType by using Empty
1910
+ if input_msg_type is None or input_msg_type is ServicerContext:
1911
+ input_str = "google.protobuf.Empty" # No need to check for stream since we validated above
1912
+ if input_msg_type is ServicerContext:
1913
+ uses_empty = True
1914
+ else:
1915
+ input_str = (
1916
+ f"stream {input_msg_type.__name__}"
1917
+ if input_is_stream
1918
+ else input_msg_type.__name__
1919
+ )
1920
+
1921
+ if output_msg_type is None or output_msg_type is ServicerContext:
1922
+ output_str = "google.protobuf.Empty" # No need to check for stream since we validated above
1923
+ if output_msg_type is ServicerContext:
1924
+ uses_empty = True
1925
+ else:
1926
+ output_str = (
1927
+ f"stream {output_msg_type.__name__}"
1928
+ if output_is_stream
1929
+ else output_msg_type.__name__
1930
+ )
1931
+
1932
+ service_rpc_definitions.append(
1933
+ f"rpc {method_name} ({input_str}) returns ({output_str});"
1934
+ )
1935
+
1936
+ # Create service definition
1937
+ service_def_lines: list[str] = []
1938
+ if service_comment:
1939
+ service_def_lines.append(service_comment)
1940
+ service_def_lines.append(f"service {service_name} {{")
1941
+ service_def_lines.extend([f" {line}" for line in service_rpc_definitions])
1942
+ service_def_lines.append("}")
1943
+ service_def_lines.append("")
1944
+
1945
+ all_service_definitions.extend(service_def_lines)
1946
+
1947
+ # Build imports
1948
+ imports: list[str] = []
1949
+ if uses_timestamp:
1950
+ imports.append('import "google/protobuf/timestamp.proto";')
1951
+ if uses_duration:
1952
+ imports.append('import "google/protobuf/duration.proto";')
1953
+ if uses_empty:
1954
+ imports.append('import "google/protobuf/empty.proto";')
1955
+
1956
+ import_block = "\n".join(imports)
1957
+ if import_block:
1958
+ import_block += "\n"
1959
+
1960
+ # Combine everything
1961
+ proto_definition = f"""syntax = "proto3";
1962
+
1963
+ package {package_name};
1964
+
1965
+ {import_block}{"".join(all_service_definitions)}
1966
+ {indent_lines(all_type_definitions, "")}
1967
+ """
1968
+ return proto_definition
1969
+
1970
+
1971
+ def get_combined_proto_filename() -> str:
1972
+ """Get the combined proto filename."""
1973
+ return os.getenv("PYDANTIC_RPC_COMBINED_PROTO_FILENAME", "combined_services.proto")
1974
+
1975
+
1976
+ def generate_combined_descriptor_set(
1977
+ *services: object, output_path: Path | None = None
1978
+ ) -> bytes:
1979
+ """Generate a combined protobuf descriptor set from multiple services."""
1980
+ filename = get_combined_proto_filename()
1981
+
1982
+ if output_path is None:
1983
+ output_path = get_proto_path(filename)
1984
+
1985
+ # Generate combined proto file
1986
+ combined_proto = generate_combined_proto(*services)
1987
+ proto_file_path = get_proto_path(filename)
1988
+
1989
+ with proto_file_path.open(mode="w", encoding="utf-8") as f:
1990
+ _ = f.write(combined_proto)
1991
+
1992
+ # Generate descriptor set using protoc
1993
+ out_str = str(proto_file_path.parent)
1994
+ well_known_path = os.path.join(os.path.dirname(grpc_tools.__file__), "_proto")
1995
+ args = [
1996
+ "protoc",
1997
+ f"-I{out_str}",
1998
+ f"-I{well_known_path}",
1999
+ f"--descriptor_set_out={output_path}",
2000
+ "--include_imports",
2001
+ proto_file_path.name,
2002
+ ]
2003
+
2004
+ current_dir = os.getcwd()
2005
+ os.chdir(out_str)
2006
+ try:
2007
+ if protoc.main(args) != 0:
2008
+ raise RuntimeError("Failed to generate combined descriptor set")
2009
+ finally:
2010
+ os.chdir(current_dir)
2011
+
2012
+ # Read and return the descriptor set
2013
+ with open(output_path, "rb") as f:
2014
+ descriptor_data = f.read()
2015
+
2016
+ return descriptor_data
2017
+
2018
+
1124
2019
  ###############################################################################
1125
2020
  # 4. Server Implementations
1126
2021
  ###############################################################################
@@ -1129,13 +2024,13 @@ def generate_and_compile_proto_using_connecpy(obj: object, package_name: str = "
1129
2024
  class Server:
1130
2025
  """A simple gRPC server that uses ThreadPoolExecutor for concurrency."""
1131
2026
 
1132
- def __init__(self, max_workers: int = 8, *interceptors) -> None:
1133
- self._server = grpc.server(
2027
+ def __init__(self, max_workers: int = 8, *interceptors: Any) -> None:
2028
+ self._server: grpc.Server = grpc.server(
1134
2029
  futures.ThreadPoolExecutor(max_workers), interceptors=interceptors
1135
2030
  )
1136
- self._service_names = []
1137
- self._package_name = ""
1138
- self._port = 50051
2031
+ self._service_names: list[str] = []
2032
+ self._package_name: str = ""
2033
+ self._port: int = 50051
1139
2034
 
1140
2035
  def set_package_name(self, package_name: str):
1141
2036
  """Set the package name for .proto generation."""
@@ -1147,10 +2042,15 @@ class Server:
1147
2042
 
1148
2043
  def mount(self, obj: object, package_name: str = ""):
1149
2044
  """Generate and compile proto files, then mount the service implementation."""
1150
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
2045
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
2046
+ None,
2047
+ None,
2048
+ )
1151
2049
  self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1152
2050
 
1153
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
2051
+ def mount_using_pb2_modules(
2052
+ self, pb2_grpc_module: Any, pb2_module: Any, obj: object
2053
+ ):
1154
2054
  """Connect the compiled gRPC modules with the service implementation."""
1155
2055
  concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
1156
2056
  service_name = obj.__class__.__name__
@@ -1163,7 +2063,7 @@ class Server:
1163
2063
  ].full_name
1164
2064
  self._service_names.append(full_service_name)
1165
2065
 
1166
- def run(self, *objs):
2066
+ def run(self, *objs: object):
1167
2067
  """
1168
2068
  Mount multiple services and run the gRPC server with reflection and health check.
1169
2069
  Press Ctrl+C or send SIGTERM to stop.
@@ -1183,14 +2083,16 @@ class Server:
1183
2083
  self._server.add_insecure_port(f"[::]:{self._port}")
1184
2084
  self._server.start()
1185
2085
 
1186
- def handle_signal(signal, frame):
2086
+ def handle_signal(signum: signal.Signals, frame: Any):
2087
+ _ = signum
2088
+ _ = frame
1187
2089
  print("Received shutdown signal...")
1188
2090
  self._server.stop(grace=10)
1189
2091
  print("gRPC server shutdown.")
1190
2092
  sys.exit(0)
1191
2093
 
1192
- signal.signal(signal.SIGINT, handle_signal)
1193
- signal.signal(signal.SIGTERM, handle_signal)
2094
+ _ = signal.signal(signal.SIGINT, handle_signal) # pyright:ignore[reportArgumentType]
2095
+ _ = signal.signal(signal.SIGTERM, handle_signal) # pyright:ignore[reportArgumentType]
1194
2096
 
1195
2097
  print("gRPC server is running...")
1196
2098
  while True:
@@ -1200,11 +2102,11 @@ class Server:
1200
2102
  class AsyncIOServer:
1201
2103
  """An async gRPC server using asyncio."""
1202
2104
 
1203
- def __init__(self, *interceptors) -> None:
1204
- self._server = grpc.aio.server(interceptors=interceptors)
1205
- self._service_names = []
1206
- self._package_name = ""
1207
- self._port = 50051
2105
+ def __init__(self, *interceptors: grpc.ServerInterceptor) -> None:
2106
+ self._server: grpc.aio.Server = grpc.aio.server(interceptors=interceptors)
2107
+ self._service_names: list[str] = []
2108
+ self._package_name: str = ""
2109
+ self._port: int = 50051
1208
2110
 
1209
2111
  def set_package_name(self, package_name: str):
1210
2112
  """Set the package name for .proto generation."""
@@ -1216,10 +2118,15 @@ class AsyncIOServer:
1216
2118
 
1217
2119
  def mount(self, obj: object, package_name: str = ""):
1218
2120
  """Generate and compile proto files, then mount the service implementation (async)."""
1219
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
2121
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
2122
+ None,
2123
+ None,
2124
+ )
1220
2125
  self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1221
2126
 
1222
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
2127
+ def mount_using_pb2_modules(
2128
+ self, pb2_grpc_module: Any, pb2_module: Any, obj: object
2129
+ ):
1223
2130
  """Connect the compiled gRPC modules with the async service implementation."""
1224
2131
  concreteServiceClass = connect_obj_with_stub_async(
1225
2132
  pb2_grpc_module, pb2_module, obj
@@ -1234,7 +2141,7 @@ class AsyncIOServer:
1234
2141
  ].full_name
1235
2142
  self._service_names.append(full_service_name)
1236
2143
 
1237
- async def run(self, *objs):
2144
+ async def run(self, *objs: object):
1238
2145
  """
1239
2146
  Mount multiple async services and run the gRPC server with reflection and health check.
1240
2147
  Press Ctrl+C or send SIGTERM to stop.
@@ -1251,20 +2158,22 @@ class AsyncIOServer:
1251
2158
  health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self._server)
1252
2159
  reflection.enable_server_reflection(SERVICE_NAMES, self._server)
1253
2160
 
1254
- self._server.add_insecure_port(f"[::]:{self._port}")
2161
+ _ = self._server.add_insecure_port(f"[::]:{self._port}")
1255
2162
  await self._server.start()
1256
2163
 
1257
2164
  shutdown_event = asyncio.Event()
1258
2165
 
1259
- def shutdown(signum, frame):
2166
+ def shutdown(signum: signal.Signals, frame: Any):
2167
+ _ = signum
2168
+ _ = frame
1260
2169
  print("Received shutdown signal...")
1261
2170
  shutdown_event.set()
1262
2171
 
1263
2172
  for s in [signal.SIGTERM, signal.SIGINT]:
1264
- signal.signal(s, shutdown)
2173
+ _ = signal.signal(s, shutdown) # pyright:ignore[reportArgumentType]
1265
2174
 
1266
2175
  print("gRPC server is running...")
1267
- await shutdown_event.wait()
2176
+ _ = await shutdown_event.wait()
1268
2177
  await self._server.stop(10)
1269
2178
  print("gRPC server shutdown.")
1270
2179
 
@@ -1275,17 +2184,22 @@ class WSGIApp:
1275
2184
  Useful for embedding gRPC within an existing WSGI stack.
1276
2185
  """
1277
2186
 
1278
- def __init__(self, app):
1279
- self._app = grpcWSGI(app)
1280
- self._service_names = []
1281
- self._package_name = ""
2187
+ def __init__(self, app: Any):
2188
+ self._app: grpcWSGI = grpcWSGI(app)
2189
+ self._service_names: list[str] = []
2190
+ self._package_name: str = ""
1282
2191
 
1283
2192
  def mount(self, obj: object, package_name: str = ""):
1284
2193
  """Generate and compile proto files, then mount the service implementation."""
1285
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
2194
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
2195
+ None,
2196
+ None,
2197
+ )
1286
2198
  self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1287
2199
 
1288
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
2200
+ def mount_using_pb2_modules(
2201
+ self, pb2_grpc_module: Any, pb2_module: Any, obj: object
2202
+ ):
1289
2203
  """Connect the compiled gRPC modules with the service implementation."""
1290
2204
  concreteServiceClass = connect_obj_with_stub(pb2_grpc_module, pb2_module, obj)
1291
2205
  service_name = obj.__class__.__name__
@@ -1298,12 +2212,16 @@ class WSGIApp:
1298
2212
  ].full_name
1299
2213
  self._service_names.append(full_service_name)
1300
2214
 
1301
- def mount_objs(self, *objs):
2215
+ def mount_objs(self, *objs: object):
1302
2216
  """Mount multiple service objects into this WSGI app."""
1303
2217
  for obj in objs:
1304
2218
  self.mount(obj, self._package_name)
1305
2219
 
1306
- def __call__(self, environ, start_response):
2220
+ def __call__(
2221
+ self,
2222
+ environ: dict[str, Any],
2223
+ start_response: Callable[[str, list[tuple[str, str]]], None],
2224
+ ) -> Any:
1307
2225
  """WSGI entry point."""
1308
2226
  return self._app(environ, start_response)
1309
2227
 
@@ -1314,17 +2232,22 @@ class ASGIApp:
1314
2232
  Useful for embedding gRPC within an existing ASGI stack.
1315
2233
  """
1316
2234
 
1317
- def __init__(self, app):
1318
- self._app = grpcASGI(app)
1319
- self._service_names = []
1320
- self._package_name = ""
2235
+ def __init__(self, app: Any):
2236
+ self._app: grpcASGI = grpcASGI(app)
2237
+ self._service_names: list[str] = []
2238
+ self._package_name: str = ""
1321
2239
 
1322
2240
  def mount(self, obj: object, package_name: str = ""):
1323
2241
  """Generate and compile proto files, then mount the async service implementation."""
1324
- pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name)
2242
+ pb2_grpc_module, pb2_module = generate_and_compile_proto(obj, package_name) or (
2243
+ None,
2244
+ None,
2245
+ )
1325
2246
  self.mount_using_pb2_modules(pb2_grpc_module, pb2_module, obj)
1326
2247
 
1327
- def mount_using_pb2_modules(self, pb2_grpc_module, pb2_module, obj: object):
2248
+ def mount_using_pb2_modules(
2249
+ self, pb2_grpc_module: Any, pb2_module: Any, obj: object
2250
+ ):
1328
2251
  """Connect the compiled gRPC modules with the async service implementation."""
1329
2252
  concreteServiceClass = connect_obj_with_stub_async(
1330
2253
  pb2_grpc_module, pb2_module, obj
@@ -1339,18 +2262,27 @@ class ASGIApp:
1339
2262
  ].full_name
1340
2263
  self._service_names.append(full_service_name)
1341
2264
 
1342
- def mount_objs(self, *objs):
2265
+ def mount_objs(self, *objs: object):
1343
2266
  """Mount multiple service objects into this ASGI app."""
1344
2267
  for obj in objs:
1345
2268
  self.mount(obj, self._package_name)
1346
2269
 
1347
- async def __call__(self, scope, receive, send):
2270
+ async def __call__(
2271
+ self,
2272
+ scope: dict[str, Any],
2273
+ receive: Callable[[], Any],
2274
+ send: Callable[[dict[str, Any]], Any],
2275
+ ) -> Any:
1348
2276
  """ASGI entry point."""
1349
- await self._app(scope, receive, send)
2277
+ _ = await self._app(scope, receive, send)
2278
+
2279
+
2280
+ def get_connecpy_asgi_app_class(connecpy_module: Any, service_name: str):
2281
+ return getattr(connecpy_module, f"{service_name}ASGIApplication")
1350
2282
 
1351
2283
 
1352
- def get_connecpy_server_class(connecpy_module, service_name):
1353
- return getattr(connecpy_module, f"{service_name}Server")
2284
+ def get_connecpy_wsgi_app_class(connecpy_module: Any, service_name: str):
2285
+ return getattr(connecpy_module, f"{service_name}WSGIApplication")
1354
2286
 
1355
2287
 
1356
2288
  class ConnecpyASGIApp:
@@ -1359,9 +2291,9 @@ class ConnecpyASGIApp:
1359
2291
  """
1360
2292
 
1361
2293
  def __init__(self):
1362
- self._app = ConnecpyASGI()
1363
- self._service_names = []
1364
- self._package_name = ""
2294
+ self._services: list[tuple[Any, str]] = [] # List of (app, path) tuples
2295
+ self._service_names: list[str] = []
2296
+ self._package_name: str = ""
1365
2297
 
1366
2298
  def mount(self, obj: object, package_name: str = ""):
1367
2299
  """Generate and compile proto files, then mount the async service implementation."""
@@ -1370,28 +2302,59 @@ class ConnecpyASGIApp:
1370
2302
  )
1371
2303
  self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
1372
2304
 
1373
- def mount_using_pb2_modules(self, connecpy_module, pb2_module, obj: object):
2305
+ def mount_using_pb2_modules(
2306
+ self, connecpy_module: Any, pb2_module: Any, obj: object
2307
+ ):
1374
2308
  """Connect the compiled connecpy and pb2 modules with the async service implementation."""
1375
2309
  concreteServiceClass = connect_obj_with_stub_async_connecpy(
1376
2310
  connecpy_module, pb2_module, obj
1377
2311
  )
1378
2312
  service_name = obj.__class__.__name__
1379
2313
  service_impl = concreteServiceClass()
1380
- connecpy_server = get_connecpy_server_class(connecpy_module, service_name)
1381
- self._app.add_service(connecpy_server(service=service_impl))
2314
+
2315
+ # Get the service-specific ASGI application class
2316
+ app_class = get_connecpy_asgi_app_class(connecpy_module, service_name)
2317
+ app = app_class(service=service_impl)
2318
+
2319
+ # Store the app and its path for routing
2320
+ self._services.append((app, app.path))
2321
+
1382
2322
  full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1383
2323
  service_name
1384
2324
  ].full_name
1385
2325
  self._service_names.append(full_service_name)
1386
2326
 
1387
- def mount_objs(self, *objs):
2327
+ def mount_objs(self, *objs: object):
1388
2328
  """Mount multiple service objects into this ASGI app."""
1389
2329
  for obj in objs:
1390
2330
  self.mount(obj, self._package_name)
1391
2331
 
1392
- async def __call__(self, scope, receive, send):
1393
- """ASGI entry point."""
1394
- await self._app(scope, receive, send)
2332
+ async def __call__(
2333
+ self,
2334
+ scope: dict[str, Any],
2335
+ receive: Callable[[], Any],
2336
+ send: Callable[[dict[str, Any]], Any],
2337
+ ):
2338
+ """ASGI entry point with routing for multiple services."""
2339
+ if scope["type"] != "http":
2340
+ await send({"type": "http.response.start", "status": 404})
2341
+ await send({"type": "http.response.body", "body": b"Not Found"})
2342
+ return
2343
+
2344
+ path = scope.get("path", "")
2345
+
2346
+ # Route to the appropriate service based on path
2347
+ for app, service_path in self._services:
2348
+ if path.startswith(service_path):
2349
+ return await app(scope, receive, send)
2350
+
2351
+ # If only one service is mounted, use it as default
2352
+ if len(self._services) == 1:
2353
+ return await self._services[0][0](scope, receive, send)
2354
+
2355
+ # No matching service found
2356
+ await send({"type": "http.response.start", "status": 404})
2357
+ await send({"type": "http.response.body", "body": b"Not Found"})
1395
2358
 
1396
2359
 
1397
2360
  class ConnecpyWSGIApp:
@@ -1400,55 +2363,82 @@ class ConnecpyWSGIApp:
1400
2363
  """
1401
2364
 
1402
2365
  def __init__(self):
1403
- self._app = ConnecpyWSGI()
1404
- self._service_names = []
1405
- self._package_name = ""
2366
+ self._services: list[tuple[Any, str]] = [] # List of (app, path) tuples
2367
+ self._service_names: list[str] = []
2368
+ self._package_name: str = ""
1406
2369
 
1407
2370
  def mount(self, obj: object, package_name: str = ""):
1408
- """Generate and compile proto files, then mount the async service implementation."""
2371
+ """Generate and compile proto files, then mount the sync service implementation."""
1409
2372
  connecpy_module, pb2_module = generate_and_compile_proto_using_connecpy(
1410
2373
  obj, package_name
1411
2374
  )
1412
2375
  self.mount_using_pb2_modules(connecpy_module, pb2_module, obj)
1413
2376
 
1414
- def mount_using_pb2_modules(self, connecpy_module, pb2_module, obj: object):
1415
- """Connect the compiled connecpy and pb2 modules with the async service implementation."""
2377
+ def mount_using_pb2_modules(
2378
+ self, connecpy_module: Any, pb2_module: Any, obj: object
2379
+ ):
2380
+ """Connect the compiled connecpy and pb2 modules with the sync service implementation."""
1416
2381
  concreteServiceClass = connect_obj_with_stub_connecpy(
1417
2382
  connecpy_module, pb2_module, obj
1418
2383
  )
1419
2384
  service_name = obj.__class__.__name__
1420
2385
  service_impl = concreteServiceClass()
1421
- connecpy_server = get_connecpy_server_class(connecpy_module, service_name)
1422
- self._app.add_service(connecpy_server(service=service_impl))
2386
+
2387
+ # Get the service-specific WSGI application class
2388
+ app_class = get_connecpy_wsgi_app_class(connecpy_module, service_name)
2389
+ app = app_class(service=service_impl)
2390
+
2391
+ # Store the app and its path for routing
2392
+ self._services.append((app, app.path))
2393
+
1423
2394
  full_service_name = pb2_module.DESCRIPTOR.services_by_name[
1424
2395
  service_name
1425
2396
  ].full_name
1426
2397
  self._service_names.append(full_service_name)
1427
2398
 
1428
- def mount_objs(self, *objs):
2399
+ def mount_objs(self, *objs: object):
1429
2400
  """Mount multiple service objects into this WSGI app."""
1430
2401
  for obj in objs:
1431
2402
  self.mount(obj, self._package_name)
1432
2403
 
1433
- def __call__(self, environ, start_response):
1434
- """WSGI entry point."""
1435
- return self._app(environ, start_response)
2404
+ def __call__(
2405
+ self,
2406
+ environ: dict[str, Any],
2407
+ start_response: Callable[[str, list[tuple[str, str]]], None],
2408
+ ) -> Any:
2409
+ """WSGI entry point with routing for multiple services."""
2410
+ path = environ.get("PATH_INFO", "")
2411
+
2412
+ # Route to the appropriate service based on path
2413
+ for app, service_path in self._services:
2414
+ if path.startswith(service_path):
2415
+ return app(environ, start_response)
2416
+
2417
+ # If only one service is mounted, use it as default
2418
+ if len(self._services) == 1:
2419
+ return self._services[0][0](environ, start_response)
2420
+
2421
+ # No matching service found
2422
+ start_response("404 Not Found", [("Content-Type", "text/plain")])
2423
+ return [b"Not Found"]
1436
2424
 
1437
2425
 
1438
2426
  def main():
1439
2427
  import argparse
1440
2428
 
1441
2429
  parser = argparse.ArgumentParser(description="Generate and compile proto files.")
1442
- parser.add_argument(
2430
+ _ = parser.add_argument(
1443
2431
  "py_file", type=str, help="The Python file containing the service class."
1444
2432
  )
1445
- parser.add_argument("class_name", type=str, help="The name of the service class.")
2433
+ _ = parser.add_argument(
2434
+ "class_name", type=str, help="The name of the service class."
2435
+ )
1446
2436
  args = parser.parse_args()
1447
2437
 
1448
2438
  module_name = os.path.splitext(basename(args.py_file))[0]
1449
2439
  module = importlib.import_module(module_name)
1450
2440
  klass = getattr(module, args.class_name)
1451
- generate_and_compile_proto(klass())
2441
+ _ = generate_and_compile_proto(klass())
1452
2442
 
1453
2443
 
1454
2444
  if __name__ == "__main__":