aspyx-service 0.10.7__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aspyx-service might be problematic. Click here for more details.

@@ -0,0 +1,1093 @@
1
+ """
2
+ Protobuf channel and utilities
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import inspect
7
+ import logging
8
+ import threading
9
+ from dataclasses import is_dataclass, fields as dc_fields
10
+ from typing import Type, get_type_hints, Callable, Tuple, get_origin, get_args, List, Dict, Any, Union, Sequence, \
11
+ Optional, cast
12
+
13
+ import httpx
14
+ from google.protobuf.message_factory import GetMessageClass
15
+ from pydantic import BaseModel
16
+
17
+ from google.protobuf import descriptor_pb2, descriptor_pool, message_factory
18
+ from google.protobuf.descriptor_pool import DescriptorPool
19
+ from google.protobuf.message import Message
20
+ from google.protobuf.descriptor import FieldDescriptor, Descriptor
21
+ from starlette.responses import PlainTextResponse
22
+
23
+ from aspyx.di import injectable, Environment
24
+ from aspyx.reflection import DynamicProxy, TypeDescriptor
25
+ from aspyx.util import CopyOnWriteCache, StringBuilder
26
+
27
+ from .service import channel, ServiceException, Server, ComponentDescriptor
28
+ from .channels import HTTPXChannel
29
+ from .service import ServiceManager, ServiceCommunicationException, AuthorizationException, RemoteServiceException
30
+
31
def get_inner_type(typ: Type) -> Type:
    """
    Extract the inner type from List[InnerType], Optional[InnerType], etc.

    Args:
        typ: the (possibly generic) type to unwrap

    Returns:
        - the element type for ``List[X]`` / ``list[X]`` (``Any`` if unparameterized)
        - ``X`` for ``Optional[X]``, ``Union[X, None]`` and the PEP 604 form ``X | None``
        - ``typ`` itself in every other case
    """
    import types  # local import: only needed to recognize PEP 604 unions

    origin = getattr(typ, "__origin__", None)
    args = getattr(typ, "__args__", None) or ()

    if origin in (list, List):
        return args[0] if args else Any

    # `X | None` is a types.UnionType (python >= 3.10) which carries no __origin__,
    # so it has to be detected by instance check instead
    union_type = getattr(types, "UnionType", None)
    is_union = origin is Union or (union_type is not None and isinstance(typ, union_type))

    # Handle Optional[X] -> X
    if is_union and len(args) == 2 and type(None) in args:
        return args[0] if args[1] is type(None) else args[1]

    return typ
46
+
47
+
48
def defaults_dict(model_cls: Type[BaseModel]) -> dict[str, Any]:
    """
    Collect the default values of a pydantic model class.

    Required fields are skipped: pydantic marks them with an "undefined" sentinel
    as `default`, which is not `None` and must not leak into the result.
    Defaults coming from a `default_factory` are produced by calling the factory.
    Fields whose default is `None` are left out.

    Args:
        model_cls: the pydantic model class

    Returns:
        a dictionary mapping field names to their (non-None) default values
    """
    result = {}
    for name, field in model_cls.model_fields.items():
        if field.is_required():
            continue  # no default at all

        value = field.get_default(call_default_factory=True)
        if value is not None:
            result[name] = value

    return result
56
+
57
class ProtobufBuilder:
    """
    used to infer protobuf services and messages given component and service structures.

    Every python module that contributes a service or a message type is represented
    by one `Module`, which wraps a protobuf file descriptor. Once all services of a
    component have been added, the modules are sealed, which registers their file
    descriptors in the descriptor pool.
    """
    logger = logging.getLogger("aspyx.service.protobuf")

    # slots

    __slots__ = [
        "pool",
        "factory",
        "modules",
        "components",
        "lock"
    ]

    @classmethod
    def get_message_name(cls, type: Type, suffix="") -> str:
        """
        Return the protobuf name `<package>.<name><suffix>` of a python type, where
        the package is the python module name with dots replaced by underscores.
        """
        package = type.__module__.replace(".", "_")

        return f"{package}.{type.__name__}{suffix}"

    @classmethod
    def get_request_message_name(cls, type: Type, method: Callable) -> str:
        """Return the name of the request message of a service method."""
        return cls.get_message_name(type, f"_{method.__name__}_Request")

    @classmethod
    def get_response_message_name(cls, type: Type, method: Callable) -> str:
        """Return the name of the response message of a service method."""
        return cls.get_message_name(type, f"_{method.__name__}_Response")

    # local classes

    class Module:
        """
        A protobuf file descriptor collecting the messages and services that
        originate from one python module.
        """

        # constructor

        def __init__(self, builder: ProtobufBuilder, name: str):
            self.builder = builder
            self.name = name.replace(".", "_")
            self.file_desc_proto = descriptor_pb2.FileDescriptorProto()  # type: ignore
            self.file_desc_proto.name = f"{self.name}.proto"
            self.file_desc_proto.package = self.name
            self.types: dict[Type, Any] = {}  # python type -> ".<package>.<message>"
            self.sealed = False
            self.lock = threading.RLock()

        # internal

        def _check_not_sealed(self):
            """Raise a ServiceException if the module has already been sealed."""
            if self.sealed:
                raise ServiceException(f"module {self.name} is already sealed")

        def _add_field(self, message, name: str, number: int, py_type: Type):
            """Append a field for the given python type to a message descriptor proto."""
            field_type, label, type_name = self.builder.to_proto_type(self, py_type)

            field = message.field.add()
            field.name = name
            field.number = number
            field.type = field_type
            field.label = label
            if type_name:
                field.type_name = type_name

        # public

        def get_fields_and_types(self, type: Type) -> List[Tuple[str, Type]]:
            """
            Return the (name, type) pairs of a dataclass or pydantic model class.
            Fields without a resolvable type hint are reported as `str`.

            Raises:
                TypeError: if the type is neither a dataclass nor a pydantic model
            """
            hints = get_type_hints(type)

            if is_dataclass(type):
                return [(f.name, hints.get(f.name, str)) for f in dc_fields(type)]

            if issubclass(type, BaseModel):
                return [(name, hints.get(name, str)) for name in type.model_fields]

            raise TypeError("Expected a dataclass or Pydantic model class.")

        def add_message(self, cls: Type) -> str:
            """
            Add a message descriptor for a dataclass or pydantic model class,
            unless a message with the same name already exists in this file.

            Returns:
                the fully qualified message name with a leading dot

            Raises:
                ServiceException: if the module is already sealed
            """
            self._check_not_sealed()

            name = cls.__name__
            full_name = f"{self.name}.{name}"

            ProtobufBuilder.logger.debug("adding message %s", full_name)

            # Check if a message type is already defined

            if any(m.name == name for m in self.file_desc_proto.message_type):
                return f".{full_name}"

            desc = descriptor_pb2.DescriptorProto()  # type: ignore
            desc.name = name

            # Extract fields from dataclass or pydantic model

            if is_dataclass(cls) or issubclass(cls, BaseModel):
                for index, (field_name, field_type) in enumerate(self.get_fields_and_types(cls), start=1):
                    self._add_field(desc, field_name, index, field_type)

            # add message type descriptor to the file descriptor proto

            self.file_desc_proto.message_type.add().CopyFrom(desc)

            return f".{full_name}"

        def check_message(self, origin, type: Type) -> str:
            """
            Make sure that a message for the type exists in this module and that
            the referencing module imports this module's file.

            Args:
                origin: the module that references the type
                type: the dataclass or pydantic model class

            Returns:
                the fully qualified message name with a leading dot
            """
            # every referencing module needs the import, no matter if the message
            # itself has already been registered by somebody else

            if self is not origin:
                file_name = self.file_desc_proto.name
                if file_name not in origin.file_desc_proto.dependency:
                    origin.file_desc_proto.dependency.append(file_name)

            if type not in self.types:
                # register the name before descending into the fields, so that
                # self-referential types do not recurse endlessly

                self.types[type] = f".{self.name}.{type.__name__}"
                try:
                    self.types[type] = self.add_message(type)
                except Exception:
                    del self.types[type]
                    raise

            return self.types[type]

        def build_request_message(self, method: TypeDescriptor.MethodDescriptor, request_name: str):
            """
            Add the request message of a method: one field per parameter,
            numbered in declaration order starting at 1.

            Raises:
                ServiceException: if the module is already sealed
            """
            self._check_not_sealed()

            request_msg = descriptor_pb2.DescriptorProto()  # type: ignore
            request_msg.name = request_name.split(".")[-1]

            ProtobufBuilder.logger.debug("adding request message %s", request_msg.name)

            # loop over parameters

            for field_index, param in enumerate(method.params, start=1):
                self._add_field(request_msg, param.name, field_index, param.type)

            # add to service file descriptor

            self.file_desc_proto.message_type.add().CopyFrom(request_msg)

        def build_response_message(self, method: TypeDescriptor.MethodDescriptor, response_name: str):
            """
            Add the response message of a method. It consists of the fields
            `result` (1, typed after the return type) and `exception` (2, string).

            Raises:
                ServiceException: if the module is already sealed
            """
            self._check_not_sealed()

            response_msg = descriptor_pb2.DescriptorProto()  # type: ignore
            response_msg.name = response_name.split(".")[-1]

            ProtobufBuilder.logger.debug("adding response message %s", response_msg.name)

            # return

            self._add_field(response_msg, "result", 1, method.return_type)

            # exception

            self._add_field(response_msg, "exception", 2, str)

            # add to service file descriptor

            self.file_desc_proto.message_type.add().CopyFrom(response_msg)

        def build_service_method(self, service_desc: descriptor_pb2.ServiceDescriptorProto, service_type: TypeDescriptor, method: TypeDescriptor.MethodDescriptor):
            """Add a method including its request and response messages to a service descriptor."""
            name = f"{service_type.cls.__name__}_{method.get_name()}"
            package = self.name

            request_name = f".{package}.{name}_Request"
            response_name = f".{package}.{name}_Response"

            method_desc = descriptor_pb2.MethodDescriptorProto()
            method_desc.name = method.get_name()
            method_desc.input_type = request_name
            method_desc.output_type = response_name

            # Build request and response message types

            self.build_request_message(method, request_name)
            self.build_response_message(method, response_name)

            # Add method to service descriptor

            service_desc.method.add().CopyFrom(method_desc)

        def add_service(self, service_type: TypeDescriptor):
            """
            Add a service descriptor covering all methods of the service type.

            Raises:
                ServiceException: if the module is already sealed
            """
            self._check_not_sealed()

            service_desc = descriptor_pb2.ServiceDescriptorProto()  # type: ignore
            service_desc.name = service_type.cls.__name__

            ProtobufBuilder.logger.debug("add service %s", service_desc.name)

            # check methods

            for method in service_type.get_methods():
                self.build_service_method(service_desc, service_type, method)

            # done

            self.file_desc_proto.service.add().CopyFrom(service_desc)

        def seal(self, builder: ProtobufBuilder):
            """
            Register the file descriptor in the builder's pool, after all files it
            depends on have been sealed. Sealing an already sealed module is a no-op.
            """
            if self.sealed:
                return

            ProtobufBuilder.logger.debug("create protobuf %s", self.file_desc_proto.name)

            # mark first, so that the recursion below cannot come back to this module

            self.sealed = True

            # add dependency first

            for dependency in self.file_desc_proto.dependency:
                builder.modules[dependency].seal(builder)

            builder.pool.Add(self.file_desc_proto)

    # constructor

    def __init__(self):
        self.pool: DescriptorPool = descriptor_pool.Default()
        self.factory = message_factory.MessageFactory(self.pool)
        self.modules: Dict[str, ProtobufBuilder.Module] = {}  # keyed by proto file name
        self.components = {}
        self.lock = threading.RLock()

    # internal

    def to_proto_type(self, module_origin, py_type: Type) -> Tuple[int, int, Optional[str]]:
        """
        Convert Python type to protobuf (field_type, label, type_name).
        Returns:
        - field_type: int (descriptor_pb2.FieldDescriptorProto.TYPE_*)
        - label: int (descriptor_pb2.FieldDescriptorProto.LABEL_*)
        - type_name: Optional[str] (fully qualified message name for messages)
        """
        # Check for repeated fields (list / List)

        if get_origin(py_type) is list:
            # Assume single-argument generic list e.g. List[int], List[FooModel]
            args = get_args(py_type)
            item_type = args[0] if args else str

            field_type, _, type_name = self._resolve_type(module_origin, item_type)

            return (
                field_type,
                descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED,  # type: ignore
                type_name,
            )

        return self._resolve_type(module_origin, py_type)

    def _resolve_type(self, origin, py_type: Type) -> Tuple[int, int, Optional[str]]:
        """Resolves Python type to protobuf scalar or message type with label=optional."""
        # Structured message (dataclass or Pydantic BaseModel)

        if is_dataclass(py_type) or (inspect.isclass(py_type) and issubclass(py_type, BaseModel)):
            type_name = self.get_module(py_type).check_message(origin, py_type)

            return (
                descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE,  # type: ignore
                descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,  # type: ignore
                type_name,
            )

        # Scalar mappings

        scalar = self._map_scalar_type(py_type)

        return scalar, descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, None  # type: ignore

    def _map_scalar_type(self, py_type: Type) -> int:
        """
        Map Python scalar types to protobuf field types.
        Unknown types fall back to TYPE_STRING.
        """
        # NOTE(review): int maps to a 32 bit integer and float to a single precision
        # float, both narrower than the python types - confirm that this is intended
        mapping = {
            str: descriptor_pb2.FieldDescriptorProto.TYPE_STRING,  # type: ignore
            int: descriptor_pb2.FieldDescriptorProto.TYPE_INT32,  # type: ignore
            float: descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT,  # type: ignore
            bool: descriptor_pb2.FieldDescriptorProto.TYPE_BOOL,  # type: ignore
            bytes: descriptor_pb2.FieldDescriptorProto.TYPE_BYTES,  # type: ignore
        }

        return mapping.get(py_type, descriptor_pb2.FieldDescriptorProto.TYPE_STRING)  # type: ignore

    def check_type(self, type: Type):
        """Make sure that a message for the type exists in the module it belongs to."""
        module = self.get_module(type)

        # the type is referenced from its own module, so no import is required
        module.check_message(module, type)

    def get_message_type(self, full_name: str):
        """Return the message class for a fully qualified message name found in the pool."""
        return GetMessageClass(self.pool.FindMessageTypeByName(full_name))

    def get_request_message(self, type: Type, method: Callable):
        """Return the request message class of a service method."""
        return self.get_message_type(self.get_request_message_name(type, method))

    def get_response_message(self, type: Type, method: Callable):
        """Return the response message class of a service method."""
        return self.get_message_type(self.get_response_message_name(type, method))

    def get_module(self, type: Type):
        """Return the module responsible for the python module of a type, creating it on demand."""
        name = type.__module__
        key = name.replace(".", "_") + ".proto"  # identical to the proto file name of the module

        module = self.modules.get(key, None)
        if module is None:
            module = ProtobufBuilder.Module(self, name)
            self.modules[key] = module

        return module

    def build_service(self, service: TypeDescriptor):
        """Add the service to the module of its class."""
        self.get_module(service.cls).add_service(service)

    # public

    def prepare_component(self, component_descriptor: ComponentDescriptor):
        """
        Build the protobuf structures for all services of a component and seal all
        modules. A component that has been prepared before is ignored.
        """
        with self.lock:
            if component_descriptor in self.components:
                return

            for service in component_descriptor.services:
                self.build_service(TypeDescriptor.for_type(service.type))

            # finalize

            for module in self.modules.values():
                module.seal(self)

            # done

            self.components[component_descriptor] = True
396
+
397
+
398
@injectable()
class ProtobufManager(ProtobufBuilder):
    """
    Protobuf builder that additionally creates and caches the serializers and
    deserializers which convert between python values and protobuf messages.
    """
    # local classes

    class MethodDeserializer:
        """
        Reads python values from a protobuf message by running a list of getters.

        A getter is a callable `getter(message, target, setter)` which reads one
        field from the message and hands the value over by calling
        `setter(target, field_name, value)`.
        """
        __slots__ = [
            "manager",
            "descriptor",
            "getters"
        ]

        # constructor

        def __init__(self, manager: ProtobufManager, descriptor: Descriptor):
            self.manager = manager
            self.descriptor = descriptor

            self.getters = []

        # internal

        def args(self, method: Callable) -> ProtobufManager.MethodDeserializer:
            """Add one getter per parameter of the method ( excluding `self` ). Returns self."""
            type_hints = get_type_hints(method)

            # loop over parameters

            for param_name in inspect.signature(method).parameters:
                if param_name == "self":
                    continue

                field_desc = self.descriptor.fields_by_name[param_name]

                self.getters.append(self._create_getter(field_desc, param_name, type_hints.get(param_name, str)))

            return self

        def result(self, method: Callable) -> 'ProtobufManager.MethodDeserializer':
            """Add the getters for the `result` and `exception` fields of a response. Returns self."""
            return_type = get_type_hints(method).get('return')

            # in contrast to args(), the fields are looked up via `DESCRIPTOR`,
            # ProtobufChannel.get_call passes the response message class here

            fields_by_name = self.descriptor.DESCRIPTOR.fields_by_name

            self.getters.append(self._create_getter(fields_by_name["result"], "result", return_type))
            self.getters.append(self._create_getter(fields_by_name["exception"], "exception", str))

            return self

        def get_fields_and_types(self, type: Type) -> List[Tuple[str, Type]]:
            """
            Return the (name, type) pairs of a dataclass or pydantic model class.
            Fields without a resolvable type hint are reported as `str`.

            Raises:
                TypeError: if the type is neither a dataclass nor a pydantic model
            """
            hints = get_type_hints(type)

            if is_dataclass(type):
                return [(f.name, hints.get(f.name, str)) for f in dc_fields(type)]

            if issubclass(type, BaseModel):
                return [(name, hints.get(name, str)) for name in type.model_fields]

            raise TypeError("Expected a dataclass or Pydantic model class.")

        def _create_getter(self, field_desc: FieldDescriptor, field_name: str, type: Type):
            """
            Create the getter for a single field.

            Args:
                field_desc: the protobuf field descriptor
                field_name: the name of the field
                type: the python type of the field

            Raises:
                TypeError: if a message field is neither typed as dataclass nor as pydantic model
            """
            is_repeated = field_desc.label == field_desc.LABEL_REPEATED
            is_message = field_desc.message_type is not None

            ## local func

            def is_structured(candidate) -> bool:
                # issubclass must only see classes, same guard as in _resolve_type
                return is_dataclass(candidate) or (inspect.isclass(candidate) and issubclass(candidate, BaseModel))

            def compute_class_getters(item_type: Type) -> list[Callable]:
                # `message_type` is assigned by the branches below before the cache calls this factory
                return [
                    self._create_getter(message_type.fields_by_name[sub_field_name], sub_field_name, field_type)
                    for sub_field_name, field_type in self.get_fields_and_types(item_type)
                ]

            # list

            if is_repeated:
                item_type = get_args(type)[0] if get_origin(type) in (list, List) else str

                # list of scalars

                if not is_structured(item_type):
                    def deserialize_list(msg: Message, val, setter=setattr):
                        setter(val, field_name, list(getattr(msg, field_name)))

                    return deserialize_list

                # list of messages

                message_type = self.manager.pool.FindMessageTypeByName(ProtobufManager.get_message_name(item_type))

                getters = self.manager.getter_lambdas_cache.get(item_type, compute_class_getters)

                if is_dataclass(item_type):
                    def deserialize_dataclass_list(msg: Message, val: Any, setter=setattr, getters=getters):
                        items = []

                        for item in getattr(msg, field_name):
                            # bypass __init__, all fields are assigned by the getters
                            instance = item_type.__new__(item_type)

                            for getter in getters:
                                getter(item, instance, object.__setattr__)

                            items.append(instance)

                        setter(val, field_name, items)

                    return deserialize_dataclass_list

                default = defaults_dict(item_type)

                def deserialize_pydantic_list(msg: Message, val: Any, setter=setattr, getters=getters):
                    items = []

                    for item in getattr(msg, field_name):
                        instance = item_type.model_construct(**default)

                        for getter in getters:
                            getter(item, instance, setattr)

                        items.append(instance)

                    setter(val, field_name, items)

                return deserialize_pydantic_list

            # message

            if is_message:
                if not is_structured(type):
                    raise TypeError(f"Expected dataclass or BaseModel for field '{field_name}', got {type}")

                message_type = self.manager.pool.FindMessageTypeByName(ProtobufManager.get_message_name(type))

                sub_getters = self.manager.getter_lambdas_cache.get(type, compute_class_getters)

                if is_dataclass(type):
                    def deserialize_dataclass(msg: Message, val: Any, setter=setattr, getters=sub_getters):
                        sub_message = getattr(msg, field_name)

                        # bypass __init__, all fields are assigned by the getters
                        instance = type.__new__(type)

                        # object.__setattr__ as in the list case, plain setattr fails for frozen dataclasses
                        for getter in getters:
                            getter(sub_message, instance, object.__setattr__)

                        setter(val, field_name, instance)

                    return deserialize_dataclass

                default = defaults_dict(type)

                def deserialize_pydantic(msg: Message, val: Any, setter=setattr, getters=sub_getters):
                    sub_message = getattr(msg, field_name)

                    instance = type.model_construct(**default)

                    for getter in getters:
                        getter(sub_message, instance, setattr)

                    setter(val, field_name, instance)

                return deserialize_pydantic

            # scalar

            def deserialize_scalar(msg: Message, val: Any, setter=setattr):
                # an unset field is reported as None
                if msg.HasField(field_name):
                    setter(val, field_name, getattr(msg, field_name))
                else:
                    setter(val, field_name, None)

            return deserialize_scalar

        # public

        def deserialize(self, message: Message) -> list[Any]:
            """Return the values of all fields in getter order."""
            values = []

            def collect(obj, prop, value):
                values.append(value)

            # call setters

            for getter in self.getters:
                getter(message, values, collect)

            return values

        def deserialize_result(self, message: Message) -> Any:
            """
            Return the `result` of a response message.

            Raises:
                RemoteServiceException: if the response carries an exception
            """
            result = None
            exception = None

            def set_result(obj, prop, value):
                nonlocal result, exception

                if prop == "result":
                    result = value
                else:
                    exception = value

            # call setters

            for getter in self.getters:
                getter(message, None, set_result)

            # the exception field decides, not the result: None is a legal result,
            # while message and list results are never None

            if exception is not None:
                raise RemoteServiceException(f"server side exception {exception}")

            return result

    class MethodSerializer:
        """
        Fills a protobuf message by running a list of setters.

        A setter is a callable `setter(message, value)` which writes one value
        into the corresponding field of the message.
        """
        __slots__ = [
            "manager",
            "message_type",
            "setters"
        ]

        # constructor

        def __init__(self, manager: ProtobufManager, message_type):
            self.manager = manager
            self.message_type = message_type

            self.setters = []

        def result(self, method: Callable) -> ProtobufManager.MethodSerializer:
            """Add the setters for the `result` ( index 0 ) and `exception` ( index 1 ) fields. Returns self."""
            fields_by_name = self.message_type.DESCRIPTOR.fields_by_name

            # .get, so that a missing return annotation is treated as in MethodDeserializer.result
            return_type = get_type_hints(method).get("return")

            self.setters.append(self._create_setter(fields_by_name["result"], "result", return_type))
            self.setters.append(self._create_setter(fields_by_name["exception"], "exception", str))

            return self

        def args(self, method: Callable) -> ProtobufManager.MethodSerializer:
            """Add one setter per parameter of the method ( excluding `self` ). Returns self."""
            msg_descriptor = self.message_type.DESCRIPTOR
            type_hints = get_type_hints(method)

            # loop over parameters

            for param_name in inspect.signature(method).parameters:
                if param_name == "self":
                    continue

                field_desc = msg_descriptor.fields_by_name[param_name]

                self.setters.append(self._create_setter(field_desc, param_name, type_hints.get(param_name, str)))

            # done

            return self

        def get_fields_and_types(self, type: Type) -> List[Tuple[str, Type]]:
            """
            Return the (name, type) pairs of a dataclass or pydantic model class.
            Fields without a resolvable type hint are reported as `str`.

            Raises:
                TypeError: if the type is neither a dataclass nor a pydantic model
            """
            hints = get_type_hints(type)

            if is_dataclass(type):
                return [(f.name, hints.get(f.name, str)) for f in dc_fields(type)]

            if issubclass(type, BaseModel):
                return [(name, hints.get(name, str)) for name in type.model_fields]

            raise TypeError("Expected a dataclass or Pydantic model class.")

        def _create_setter(self, field_desc: FieldDescriptor, field_name: str, type: Type):
            """
            Create the setter for a single field.

            Args:
                field_desc: the protobuf field descriptor
                field_name: the name of the field
                type: the python type of the field

            Raises:
                TypeError: if a message field is neither typed as dataclass nor as pydantic model
            """
            is_repeated = field_desc.label == field_desc.LABEL_REPEATED
            is_message = field_desc.message_type is not None

            # local func

            def is_structured(candidate) -> bool:
                # issubclass must only see classes, same guard as in _resolve_type
                return is_dataclass(candidate) or (inspect.isclass(candidate) and issubclass(candidate, BaseModel))

            def create(message_type: Descriptor, item_type: Type) -> Tuple[list[Callable], list[str]]:
                # the setters and the names of the attributes they read, index by index
                setters = []
                fields = []
                for sub_field_name, field_type in self.get_fields_and_types(item_type):
                    fields.append(sub_field_name)
                    setters.append(self._create_setter(message_type.fields_by_name[sub_field_name], sub_field_name, field_type))

                return setters, fields

            # list

            if is_repeated:
                item_type = get_args(type)[0] if get_origin(type) in (list, List) else str

                # list of scalars

                if not is_structured(item_type):
                    return lambda msg, val: getattr(msg, field_name).extend(val)

                # list of messages

                message_type = self.manager.pool.FindMessageTypeByName(ProtobufManager.get_message_name(item_type))

                setters, fields = self.manager.setter_lambdas_cache.get(item_type, lambda t: create(message_type, item_type))

                def serialize_message_list(msg: Message, val: Any, fields=fields, setters=setters):
                    for item in val:
                        msg_item = getattr(msg, field_name).add()
                        for setter, attribute in zip(setters, fields):
                            setter(msg_item, getattr(item, attribute))

                return serialize_message_list

            # message

            if is_message:
                if not is_structured(type):
                    raise TypeError(f"Expected dataclass or BaseModel for field '{field_name}', got {type}")

                message_type = self.manager.pool.FindMessageTypeByName(ProtobufManager.get_message_name(type))

                sub_setters, fields = self.manager.setter_lambdas_cache.get(type, lambda t: create(message_type, type))

                def serialize_message(msg: Message, val: Any, fields=fields, setters=sub_setters):
                    field = getattr(msg, field_name)
                    for setter, attribute in zip(setters, fields):
                        setter(field, getattr(val, attribute))

                return serialize_message

            # scalar

            def set_attr(msg, val):
                # None is expressed by leaving the field unset
                if val is not None:
                    setattr(msg, field_name, val)

            return set_attr

        def serialize(self, value: Any) -> Any:
            """Create a message and pass the value to every setter."""
            # create message instance

            message = self.message_type()

            # call setters

            for setter in self.setters:
                setter(message, value)

            return message

        def serialize_result(self, value: Any, exception: str) -> Any:
            """Create a response message, setting `result` and `exception` unless they are None."""
            # create message instance

            message = self.message_type()

            # call setters

            if value is not None:
                self.setters[0](message, value)

            if exception is not None:
                self.setters[1](message, exception)

            return message

        def serialize_args(self, args: Sequence[Any]) -> Any:
            """
            Create a request message from positional arguments.

            Raises:
                IndexError: if fewer arguments than setters are supplied
            """
            # create message instance

            message = self.message_type()

            # call setters

            for index, setter in enumerate(self.setters):
                setter(message, args[index])

            return message

    # slots

    __slots__ = [
        "serializer_cache",
        "deserializer_cache",
        "result_serializer_cache",
        "result_deserializer_cache",
        "setter_lambdas_cache",
        "getter_lambdas_cache"
    ]

    # constructor

    def __init__(self):
        super().__init__()

        self.serializer_cache = CopyOnWriteCache[Callable, ProtobufManager.MethodSerializer]()
        self.deserializer_cache = CopyOnWriteCache[Descriptor, ProtobufManager.MethodDeserializer]()

        self.result_serializer_cache = CopyOnWriteCache[Descriptor, ProtobufManager.MethodSerializer]()
        self.result_deserializer_cache = CopyOnWriteCache[Descriptor, ProtobufManager.MethodDeserializer]()

        # per class: the setters together with the attribute names they read
        self.setter_lambdas_cache = CopyOnWriteCache[Type, Tuple[list[Callable], list[str]]]()
        # per class: the getters
        self.getter_lambdas_cache = CopyOnWriteCache[Type, list[Callable]]()

    # public

    def create_serializer(self, type: Type, method: Callable) -> ProtobufManager.MethodSerializer:
        """Return the cached serializer for the arguments of a service method."""
        return self.serializer_cache.get(method, lambda m: ProtobufManager.MethodSerializer(self, self.get_request_message(type, m)).args(m))

    def create_deserializer(self, descriptor: Descriptor, method: Callable) -> ProtobufManager.MethodDeserializer:
        """Return the cached deserializer for the arguments of a service method."""
        return self.deserializer_cache.get(descriptor, lambda d: ProtobufManager.MethodDeserializer(self, d).args(method))

    def create_result_serializer(self, descriptor: Descriptor, method: Callable) -> ProtobufManager.MethodSerializer:
        """Return the cached serializer for the response of a service method."""
        return self.result_serializer_cache.get(descriptor, lambda d: ProtobufManager.MethodSerializer(self, d).result(method))

    def create_result_deserializer(self, descriptor: Descriptor, method: Callable) -> ProtobufManager.MethodDeserializer:
        """Return the cached deserializer for the response of a service method."""
        return self.result_deserializer_cache.get(descriptor, lambda d: ProtobufManager.MethodDeserializer(self, d).result(method))

    def report(self) -> str:
        """Return the textual dump of the file descriptors of all modules."""
        builder = StringBuilder()
        for module in self.modules.values():
            builder.append(ProtobufDumper.dump_proto(module.file_desc_proto))

        return str(builder)
830
+
831
+ @channel("dispatch-protobuf")
832
+ class ProtobufChannel(HTTPXChannel):
833
+ """
834
+ channel, encoding requests and responses with protobuf
835
+ """
836
+ # class methods
837
+
838
+ @classmethod
839
+ def prepare(cls, server: Server, component_descriptor: ComponentDescriptor):
840
+ protobuf_manager = server.get(ProtobufManager)
841
+ protobuf_manager.prepare_component(component_descriptor)
842
+
843
+ def report_protobuf():
844
+ return protobuf_manager.report()
845
+
846
+ server.add_route(path="/report-protobuf", endpoint=report_protobuf, methods=["GET"], response_class=PlainTextResponse)
847
+
848
+ # local classes
849
+
850
+ class Call:
851
+ __slots__ = [
852
+ "method_name",
853
+ "serializer",
854
+ "response_type",
855
+ "deserializer"
856
+ ]
857
+
858
+ # constructor
859
+
860
+ def __init__(self, method_name: str, serializer: ProtobufManager.MethodSerializer, response_type, deserializer: ProtobufManager.MethodDeserializer):
861
+ self.method_name = method_name
862
+ self.serializer = serializer
863
+ self.response_type = response_type
864
+ self.deserializer = deserializer
865
+
866
+ # public
867
+
868
+ def serialize(self, args: Sequence[Any]) -> Any:
869
+ message = self.serializer.serialize_args(args)
870
+ return message.SerializeToString()
871
+
872
+ def deserialize(self, http_response: httpx.Response) -> Any:
873
+ response = self.response_type()
874
+ response.ParseFromString(http_response.content)
875
+
876
+ return self.deserializer.deserialize_result(response)
877
+
878
+ # slots
879
+
880
+ __slots__ = [
881
+ "manager",
882
+ "environment",
883
+ "protobuf_manager",
884
+ "cache"
885
+ ]
886
+
887
+ # constructor
888
+
889
+ def __init__(self, manager: ServiceManager, protobuf_manager: ProtobufManager):
890
+ super().__init__()
891
+
892
+ self.manager = manager
893
+ self.environment = None
894
+ self.protobuf_manager = protobuf_manager
895
+ self.cache = CopyOnWriteCache[Callable, ProtobufChannel.Call]()
896
+
897
+ # make sure, all protobuf messages are created
898
+
899
+ for descriptor in manager.descriptors.values():
900
+ if descriptor.is_component():
901
+ protobuf_manager.prepare_component(cast(ComponentDescriptor, descriptor))
902
+
903
+ # internal
904
+
905
+ def get_call(self, type: Type, method: Callable) -> ProtobufChannel.Call:
906
+ call = self.cache.get(method)
907
+ if call is None:
908
+ method_name = f"{self.component_descriptor.name}:{self.service_names[type]}:{method.__name__}"
909
+ serializer = self.protobuf_manager.create_serializer(type, method)
910
+ response_type = self.protobuf_manager.get_message_type(self.protobuf_manager.get_response_message_name(type, method))
911
+ deserializer = self.protobuf_manager.create_result_deserializer(response_type, method)
912
+
913
+ call = ProtobufChannel.Call(method_name, serializer, response_type, deserializer)
914
+
915
+ self.cache.put(method, call)
916
+
917
+ return call
918
+
919
+ # implement
920
+
921
+ async def invoke_async(self, invocation: DynamicProxy.Invocation):
922
+ call = self.get_call(invocation.type, invocation.method)
923
+
924
+ try:
925
+ http_result = await self.request_async("post", f"{self.get_url()}/invoke", content=call.serialize(invocation.args),
926
+ timeout=self.timeout, headers={
927
+ "Content-Type": "application/x-protobuf",
928
+ # "Accept": "application/x-protobuf",
929
+ "x-rpc-method": call.method_name
930
+ })
931
+
932
+ return call.deserialize(http_result)
933
+ except (ServiceCommunicationException, AuthorizationException, RemoteServiceException) as e:
934
+ raise
935
+
936
+ except Exception as e:
937
+ raise ServiceCommunicationException(f"communication exception {e}") from e
938
+
939
def invoke(self, invocation: DynamicProxy.Invocation):
    """
    Execute the invocation synchronously by posting the serialized arguments to `<url>/invoke`.

    Args:
        invocation: the intercepted call (type, method and arguments)

    Returns:
        the deserialized result of the remote call

    Raises:
        ServiceCommunicationException, AuthorizationException, RemoteServiceException: passed through unchanged
        ServiceCommunicationException: wraps any other exception raised while sending or decoding
    """
    call = self.get_call(invocation.type, invocation.method)

    try:
        # url and payload are built inside the try so that their failures are wrapped as well
        url = f"{self.get_url()}/invoke"
        payload = call.serialize(invocation.args)
        timeout = self.timeout
        headers = {
            "Content-Type": "application/x-protobuf",
            #"Accept": "application/x-protobuf",
            "x-rpc-method": call.method_name
        }

        http_result = self.request("post", url, content=payload, timeout=timeout, headers=headers)

        return call.deserialize(http_result)
    except (ServiceCommunicationException, AuthorizationException, RemoteServiceException):
        raise

    except Exception as e:
        raise ServiceCommunicationException(f"communication exception {e}") from e
955
+
956
class ProtobufDumper:
    """
    Renders a protobuf ``FileDescriptorProto`` as ``.proto`` source text.

    The dump covers syntax, package, imports, file options, enums, messages
    (including nested types and map fields) and services. ``oneof`` groups,
    extensions and options other than file-level options and the field option
    ``packed`` are not rendered.
    """

    # FieldDescriptorProto.Label number -> .proto keyword
    _LABELS = {
        1: 'optional',
        2: 'required',
        3: 'repeated'
    }

    # FieldDescriptorProto.Type number -> .proto type keyword, used when a field carries no type_name
    _SCALAR_TYPES = {
        1: "double",
        2: "float",
        3: "int64",
        4: "uint64",
        5: "int32",
        6: "fixed64",
        7: "fixed32",
        8: "bool",
        9: "string",
        10: "group",  # deprecated
        11: "message",
        12: "bytes",
        13: "uint32",
        14: "enum",
        15: "sfixed32",
        16: "sfixed64",
        17: "sint32",
        18: "sint64",
    }

    _LABEL_REPEATED = 3
    _TYPE_MESSAGE = 11

    @classmethod
    def dump_proto(cls, fd: descriptor_pb2.FileDescriptorProto) -> str:
        """
        Render a file descriptor as ``.proto`` text.

        Args:
            fd: the file descriptor proto to dump

        Returns:
            the ``.proto`` source; an empty ``fd.syntax`` is written as ``proto2``
        """
        syntax = fd.syntax if fd.syntax else "proto2"
        lines = [f'syntax = "{syntax}";\n']

        # Package

        if fd.package:
            lines.append(f'package {fd.package};\n')

        # Imports

        lines.extend(f'import "{dep}";' for dep in fd.dependency)
        if fd.dependency:
            lines.append('')  # blank line

        # Options (basic)

        if fd.HasField('options'):
            for descriptor, value in fd.options.ListFields():
                lines.append(f'option {descriptor.name} = {cls._format_option_value(descriptor, value)};')
            lines.append('')

        # Enums, messages and services at file level

        for enum in fd.enum_type:
            lines.extend(cls._dump_enum(enum))

        for msg in fd.message_type:
            lines.extend(cls._dump_message(msg, syntax))

        for svc in fd.service:
            lines.extend(cls._dump_service(svc))

        return "\n".join(lines)

    @staticmethod
    def _format_option_value(descriptor, value) -> str:
        """Format an option value as a .proto literal: quoted string, lowercase bool or enum value name."""
        if isinstance(value, bool):  # checked first, bool is a subclass of int
            return 'true' if value else 'false'

        if isinstance(value, str):
            escaped = value.replace('\\', '\\\\').replace('"', '\\"')
            return f'"{escaped}"'

        enum_type = getattr(descriptor, 'enum_type', None)
        if enum_type is not None:
            enum_value = enum_type.values_by_number.get(value)
            if enum_value is not None:
                return enum_value.name

        # numbers and anything more complex keep their plain string form
        return str(value)

    @staticmethod
    def _dump_enum(enum, indent=''):
        """Return the lines of one enum definition."""
        enum_lines = [f"{indent}enum {enum.name} {{"]
        enum_lines.extend(f"{indent} {value.name} = {value.number};" for value in enum.value)
        enum_lines.append(f"{indent}}}\n")
        return enum_lines

    @classmethod
    def _dump_message(cls, msg, syntax, indent=''):
        """Return the lines of one message: nested enums, nested messages, then fields."""
        msg_lines = [f"{indent}message {msg.name} {{"]
        inner_indent = indent + ' '

        for enum in msg.enum_type:
            msg_lines.extend(cls._dump_enum(enum, inner_indent))

        # synthetic map entry messages are not dumped; they are remembered so that
        # the fields referring to them can be written as map<K, V>
        map_entries = {}
        for nested in msg.nested_type:
            if nested.options.map_entry:
                map_entries[nested.name] = nested
                continue
            msg_lines.extend(cls._dump_message(nested, syntax, inner_indent))

        for field in msg.field:
            msg_lines.append(f"{indent} {cls._dump_field(field, msg.name, map_entries, syntax)}")

        msg_lines.append(f"{indent}}}\n")

        return msg_lines

    @classmethod
    def _dump_field(cls, field, owner_name, map_entries, syntax) -> str:
        """Return the declaration of one field, without indentation."""
        # Field options (only packed example)
        opts = []
        if field.options.HasField('packed'):
            opts.append(f"packed = {str(field.options.packed).lower()}")
        opts_str = f" [{', '.join(opts)}]" if opts else ""

        map_type = cls._map_type(field, owner_name, map_entries)
        if map_type is not None:
            declaration = f"{map_type} {field.name}"  # map fields carry no label
        else:
            parts = (cls._field_label(field, syntax), cls._type_name(field), field.name)
            declaration = ' '.join(part for part in parts if part)  # an empty label leaves no gap

        return f"{declaration} = {field.number}{opts_str};"

    @classmethod
    def _field_label(cls, field, syntax) -> str:
        """Return the label keyword of a field, honouring proto3 rules."""
        label = cls._LABELS.get(field.label, '')
        if syntax == 'proto3' and label != 'repeated':
            # proto3 has no 'required'; 'optional' is only written for explicit presence
            return 'optional' if getattr(field, 'proto3_optional', False) else ''

        return label

    @classmethod
    def _type_name(cls, field) -> str:
        """Return the type of a field: its message/enum name without leading dot, or the scalar keyword."""
        if field.type_name:
            return field.type_name.lstrip('.')

        return cls._SCALAR_TYPES.get(field.type, f"TYPE_{field.type}")

    @classmethod
    def _map_type(cls, field, owner_name, map_entries):
        """Return ``map<K, V>`` if the field refers to a map entry of its own message, else None."""
        if field.label != cls._LABEL_REPEATED or field.type != cls._TYPE_MESSAGE or not field.type_name:
            return None

        entry_name = field.type_name.rsplit('.', 1)[-1]
        entry = map_entries.get(entry_name)
        if entry is None:
            return None

        # the entry must be the one nested in the message that owns the field
        if not f".{field.type_name.lstrip('.')}".endswith(f".{owner_name}.{entry_name}"):
            return None

        # map entries hold the key as field 1 and the value as field 2
        by_number = {part.number: part for part in entry.field}
        key, value = by_number.get(1), by_number.get(2)
        if key is None or value is None:
            return None

        return f"map<{cls._type_name(key)}, {cls._type_name(value)}>"

    @staticmethod
    def _dump_service(svc, indent=''):
        """Return the lines of one service definition."""
        svc_lines = [f"{indent}service {svc.name} {{"]
        for method in svc.method:
            input_type = method.input_type.lstrip('.') if method.input_type else 'Unknown'
            output_type = method.output_type.lstrip('.') if method.output_type else 'Unknown'
            client_streaming = 'stream ' if method.client_streaming else ''
            server_streaming = 'stream ' if method.server_streaming else ''
            svc_lines.append(f"{indent} rpc {method.name} ({client_streaming}{input_type}) returns ({server_streaming}{output_type});")
        svc_lines.append(f"{indent}}}\n")
        return svc_lines