aspyx-service 0.10.6__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aspyx-service might be problematic. Click here for more details.

@@ -0,0 +1,1083 @@
1
+ """
2
+ Protobuf channel and utilities
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import inspect
7
+ import threading
8
+ from dataclasses import is_dataclass, fields as dc_fields
9
+ from typing import Type, get_type_hints, Callable, Tuple, get_origin, get_args, List, Dict, Any, Union, Sequence, \
10
+ Optional
11
+
12
+ import httpx
13
+ from google.protobuf.message_factory import GetMessageClass
14
+ from pydantic import BaseModel
15
+
16
+ from google.protobuf import descriptor_pb2, descriptor_pool, message_factory
17
+ from google.protobuf.descriptor_pool import DescriptorPool
18
+ from google.protobuf.message import Message
19
+ from google.protobuf.descriptor import FieldDescriptor, Descriptor
20
+
21
+ from aspyx.di import injectable
22
+ from aspyx.reflection import DynamicProxy, TypeDescriptor
23
+ from aspyx.util import CopyOnWriteCache
24
+
25
+ from .service import channel, ServiceException
26
+ from .channels import HTTPXChannel
27
+ from .service import ServiceManager, ServiceCommunicationException, AuthorizationException, RemoteServiceException
28
+
29
def get_inner_type(typ: Type) -> Type:
    """
    Extract the inner type of a list or optional annotation.

    * ``List[X]`` / ``list[X]`` -> ``X`` (``Any`` for an unparameterized list)
    * ``Optional[X]`` / ``Union[X, None]`` / ``X | None`` -> ``X``
    * any other type is returned unchanged

    :param typ: the type annotation to inspect
    :return: the inner type
    """
    import types  # local: only needed for the PEP 604 ``X | None`` union type

    origin = get_origin(typ)
    args = get_args(typ)

    if origin in (list, List):
        return args[0] if args else Any

    # Optional[X] / Union[X, None] / X | None -> X
    # (``X | None`` has origin ``types.UnionType`` instead of ``typing.Union``; absent before 3.10)
    union_origins = (Union, getattr(types, "UnionType", Union))
    if origin in union_origins and len(args) == 2 and type(None) in args:
        return args[0] if args[1] is type(None) else args[1]

    return typ
44
+
45
+
46
def defaults_dict(model_cls: Type[BaseModel]) -> dict[str, Any]:
    """
    Collect the default values of a pydantic model's fields.

    * fields with a ``default_factory`` get a freshly produced value
    * fields with a plain default (including an explicit ``None``) get that default
    * required fields are omitted

    :param model_cls: a pydantic (v2) model class; its ``model_fields`` are inspected
    :return: mapping of field name to default value, usable as ``model_cls.model_construct(**defaults)``
    """
    result: dict[str, Any] = {}
    for name, field in model_cls.model_fields.items():
        # check the factory first: factory-backed fields carry pydantic's "undefined" sentinel as ``default``
        if field.default_factory is not None:
            result[name] = field.default_factory()
        elif not field.is_required():
            # required fields also carry the "undefined" sentinel, which must not leak into the result
            result[name] = field.default
    return result
54
+
55
class ProtobufBuilder:
    """
    Builds protobuf descriptors at runtime for service interfaces.

    Every python module that contributes a service or message type gets its own ``.proto`` file
    (a :class:`ProtobufBuilder.Module`) whose package is the module name with ``.`` replaced by ``_``.
    Dataclasses and pydantic models become messages; each service method becomes a
    ``<Service><method>Request`` / ``<Service><method>Response`` message pair. :meth:`check` builds
    all services of a component once and adds the resulting files to the descriptor pool.
    """

    # slots

    __slots__ = [
        "pool",
        "factory",
        "modules",
        "components",
        "lock"
    ]

    @classmethod
    def get_message_name(cls, type: Type, suffix="") -> str:
        """
        Return the fully qualified message name ``<module>.<TypeName><suffix>`` of ``type``,
        where ``<module>`` is the type's module with ``.`` replaced by ``_``.
        """
        module = type.__module__.replace(".", "_")
        name = type.__name__

        return f"{module}.{name}{suffix}"

    @classmethod
    def get_request_message_name(cls, type: Type, method: Callable) -> str:
        """Return the fully qualified name of the request message of ``method`` of service ``type``."""
        return cls.get_message_name(type, f"{method.__name__}Request")

    @classmethod
    def get_response_message_name(cls, type: Type, method: Callable) -> str:
        """Return the fully qualified name of the response message of ``method`` of service ``type``."""
        return cls.get_message_name(type, f"{method.__name__}Response")

    # local classes

    class Module:
        """
        Collects the message and service descriptors of one python module in a single
        ``FileDescriptorProto``. Once finalized (added to the pool) it rejects further additions.
        """

        # constructor

        def __init__(self, builder: ProtobufBuilder, name: str):
            self.builder = builder
            self.name = name.replace(".", "_")
            self.file_desc_proto = descriptor_pb2.FileDescriptorProto() # type: ignore
            self.file_desc_proto.name = f"{self.name}.proto"
            self.file_desc_proto.package = self.name
            self.types : dict[Type, Any] = {}  # python type -> fully qualified message name (".pkg.Name")
            self.finalized = False
            self.lock = threading.RLock()

        # internal

        def _check_open(self):
            """Raise a ServiceException if this module has already been finalized."""
            if self.finalized:
                raise ServiceException(f"module {self.name} is already sealed")

        def _add_field(self, message: descriptor_pb2.DescriptorProto, name: str, number: int, py_type: Type):
            """Append field ``name`` / ``number`` to ``message``, mapping ``py_type`` via the builder."""
            field_type, label, type_name = self.builder.to_proto_type(self, py_type)

            field = message.field.add()
            field.name = name
            field.number = number
            field.type = field_type
            field.label = label
            if type_name:
                field.type_name = type_name

        # public

        def get_fields_and_types(self, type: Type) -> List[Tuple[str, Type]]:
            """
            Return ``(field name, type hint)`` pairs of a dataclass or pydantic model class;
            fields without a resolvable hint default to ``str``.

            :raises TypeError: if ``type`` is neither a dataclass nor a pydantic model
            """
            hints = get_type_hints(type)

            if is_dataclass(type):
                return [(f.name, hints.get(f.name, str)) for f in dc_fields(type)]

            if issubclass(type, BaseModel):
                return [(name, hints.get(name, str)) for name in type.model_fields]

            raise TypeError("Expected a dataclass or Pydantic model class.")

        def add_message(self, cls: Type) -> str:
            """
            Add a message descriptor for a dataclass or pydantic model class (only once per name).

            :return: the fully qualified message name with a leading '.'
            :raises ServiceException: if the module is already sealed
            """
            self._check_open()

            name = cls.__name__
            full_name = f".{self.name}.{name}"

            # already defined?

            if any(m.name == name for m in self.file_desc_proto.message_type):
                return full_name

            desc = descriptor_pb2.DescriptorProto() # type: ignore
            desc.name = name

            # one field per dataclass / pydantic field, numbered from 1 in declaration order

            if is_dataclass(cls) or issubclass(cls, BaseModel):
                for number, (field_name, field_type) in enumerate(self.get_fields_and_types(cls), start=1):
                    self._add_field(desc, field_name, number, field_type)

            self.file_desc_proto.message_type.add().CopyFrom(desc)

            return full_name

        def check_message(self, origin, type: Type) -> str:
            """
            Return the message name of ``type`` (declared in this module), creating the message on first use.

            If ``origin`` - the module containing the referencing field - is a different module, this
            module's file is registered as its dependency. This happens for every referencing module,
            not only for the one that triggered the creation of the message.

            :param origin: the module that references ``type``
            :param type: dataclass or pydantic model class declared in this module
            :return: fully qualified message name with a leading '.'
            """
            if self is not origin:
                dependency = self.file_desc_proto.name
                if dependency not in origin.file_desc_proto.dependency:
                    origin.file_desc_proto.dependency.append(dependency)

            message_name = self.types.get(type)
            if message_name is None:
                # register the name before building so that self-referencing types terminate
                self.types[type] = f".{self.name}.{type.__name__}"
                try:
                    message_name = self.add_message(type)
                except Exception:
                    del self.types[type]
                    raise

                self.types[type] = message_name

            return message_name

        def build_request_message(self, method: TypeDescriptor.MethodDescriptor, request_name: str):
            """
            Add the request message of ``method``: one field per parameter, numbered from 1.

            :raises ServiceException: if the module is already sealed
            """
            self._check_open()

            request_msg = descriptor_pb2.DescriptorProto() # type: ignore
            request_msg.name = request_name.split(".")[-1]

            for number, param in enumerate(method.params, start=1):
                self._add_field(request_msg, param.name, number, param.type)

            self.file_desc_proto.message_type.add().CopyFrom(request_msg)

        def build_response_message(self, method: TypeDescriptor.MethodDescriptor, response_name: str):
            """
            Add the response message of ``method`` with field 1 ``result`` (the return type)
            and field 2 ``exception`` (a string).

            :raises ServiceException: if the module is already sealed
            """
            self._check_open()

            response_msg = descriptor_pb2.DescriptorProto() # type: ignore
            response_msg.name = response_name.split(".")[-1]

            self._add_field(response_msg, "result", 1, method.return_type)
            self._add_field(response_msg, "exception", 2, str)

            self.file_desc_proto.message_type.add().CopyFrom(response_msg)

        def build_service_method(self, service_desc: descriptor_pb2.ServiceDescriptorProto, service_type: TypeDescriptor, method: TypeDescriptor.MethodDescriptor):
            """Add ``method`` to ``service_desc`` together with its request / response messages."""
            name = f"{service_type.cls.__name__}{method.get_name()}"
            package = self.name

            method_desc = descriptor_pb2.MethodDescriptorProto()

            request_name = f".{package}.{name}Request"
            response_name = f".{package}.{name}Response"

            method_desc.name = method.get_name()
            method_desc.input_type = request_name
            method_desc.output_type = response_name

            self.build_request_message(method, request_name)
            self.build_response_message(method, response_name)

            service_desc.method.add().CopyFrom(method_desc)

        def add_service(self, service_type: TypeDescriptor):
            """
            Add a service descriptor with one rpc method per method of ``service_type``.

            :raises ServiceException: if the module is already sealed
            """
            self._check_open()

            service_desc = descriptor_pb2.ServiceDescriptorProto() # type: ignore
            service_desc.name = service_type.cls.__name__

            for method in service_type.get_methods():
                self.build_service_method(service_desc, service_type, method)

            self.file_desc_proto.service.add().CopyFrom(service_desc)

        def finalize(self, builder: ProtobufBuilder):
            """
            Seal the module and add its file to ``builder.pool``. Dependencies are finalized first,
            since the pool needs imported files to be present. Repeated calls do nothing.
            """
            if self.finalized:
                return

            self.finalized = True

            for dependency in self.file_desc_proto.dependency:
                builder.modules[dependency].finalize(builder)

            builder.pool.Add(self.file_desc_proto)

    # constructor

    def __init__(self):
        # NOTE(review): this is the process-wide default pool, shared with every other user of it
        self.pool: DescriptorPool = descriptor_pool.Default()
        self.factory = message_factory.MessageFactory(self.pool)
        self.modules: Dict[str, ProtobufBuilder.Module] = {}  # keyed by proto file name ("<module>.proto")
        self.components = {}
        self.lock = threading.RLock()

    # internal

    def to_proto_type(self, module_origin, py_type: Type) -> Tuple[int, int, Optional[str]]:
        """
        Convert a python type hint to protobuf ``(field_type, label, type_name)``.

        ``Optional[X]`` / ``X | None`` are treated as ``X``; ``List[X]`` becomes a repeated field.

        Returns:
         - field_type: int (descriptor_pb2.FieldDescriptorProto.TYPE_*)
         - label: int (descriptor_pb2.FieldDescriptorProto.LABEL_*)
         - type_name: Optional[str] (fully qualified message name for messages)
        """
        py_type = self._strip_optional(py_type)
        origin = get_origin(py_type)
        args = get_args(py_type)

        # repeated fields (list / List); an unparameterized list is a list of strings
        if origin in (list, List):
            item_type = args[0] if args else str
            field_type, _, type_name = self._resolve_type(module_origin, item_type)
            return (
                field_type,
                descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED, # type: ignore
                type_name,
            )

        return self._resolve_type(module_origin, py_type)

    @staticmethod
    def _strip_optional(py_type: Type) -> Type:
        """Return ``X`` for ``Optional[X]`` / ``Union[X, None]`` / ``X | None``; any other type unchanged."""
        import types  # local: only needed for the PEP 604 ``X | None`` union type

        origin = get_origin(py_type)
        if origin is Union or origin is getattr(types, "UnionType", Union):
            non_none = [arg for arg in get_args(py_type) if arg is not type(None)]
            if len(non_none) == 1:
                return non_none[0]

        return py_type

    def _resolve_type(self, origin, py_type: Type) -> Tuple[int, int, Optional[str]]:
        """Resolve a python type to a protobuf scalar or message type with label=optional."""
        py_type = self._strip_optional(py_type)

        # structured message (dataclass or pydantic BaseModel)
        if is_dataclass(py_type) or (inspect.isclass(py_type) and issubclass(py_type, BaseModel)):
            type_name = self.get_module(py_type).check_message(origin, py_type)
            return (
                descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, # type: ignore
                descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, # type: ignore
                type_name,
            )

        # scalar

        scalar = self._map_scalar_type(py_type)

        return scalar, descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL, None # type: ignore

    def _map_scalar_type(self, py_type: Type) -> int:
        """
        Map python scalar types to protobuf field types; unknown types fall back to TYPE_STRING.

        NOTE(review): int maps to TYPE_INT32 and float to TYPE_FLOAT (32 bit), so large ints are rejected
        and floats lose precision. Changing this alters the wire format seen by peers, hence left as is.
        """
        mapping = {
            str: descriptor_pb2.FieldDescriptorProto.TYPE_STRING, # type: ignore
            int: descriptor_pb2.FieldDescriptorProto.TYPE_INT32, # type: ignore
            float: descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT, # type: ignore
            bool: descriptor_pb2.FieldDescriptorProto.TYPE_BOOL, # type: ignore
            bytes: descriptor_pb2.FieldDescriptorProto.TYPE_BYTES, # type: ignore
        }

        return mapping.get(py_type, descriptor_pb2.FieldDescriptorProto.TYPE_STRING) # type: ignore

    def check_type(self, type: Type):
        """
        Make sure a message exists for the dataclass / pydantic model ``type`` in its own module.

        :return: the fully qualified message name with a leading '.'
        """
        module = self.get_module(type)
        return module.check_message(module, type)

    def get_message_type(self, full_name: str):
        """Return the generated message class for a fully qualified message name."""
        return GetMessageClass(self.pool.FindMessageTypeByName(full_name))

    def get_request_message(self, type: Type, method: Callable):
        """Return the generated request message class of ``method`` of service ``type``."""
        return self.get_message_type(self.get_request_message_name(type, method))

    def get_response_message(self, type: Type, method: Callable):
        """Return the generated response message class of ``method`` of service ``type``."""
        return self.get_message_type(self.get_response_message_name(type, method))

    def get_module(self, type: Type):
        """Return (creating on demand) the module collecting the descriptors of ``type``'s python module."""
        name = type.__module__
        key = name.replace(".", "_") + ".proto"  # same as the module's file_desc_proto.name
        module = self.modules.get(key, None)
        if module is None:
            module = ProtobufBuilder.Module(self, name)
            self.modules[key] = module

        return module

    def build_service(self, service: TypeDescriptor):
        """Add the service descriptor of ``service`` to its module."""
        self.get_module(service.cls).add_service(service)

    # public

    def check(self, service_type: Type):
        """
        Make sure the descriptors for all services of ``service_type``'s component are in the pool.

        The first call for a component builds every service of that component and then finalizes all
        modules; later calls for the same component do nothing. Guarded by ``self.lock``.
        """
        descriptor = getattr(service_type, "__descriptor__")
        component_descriptor = descriptor.component_descriptor

        with self.lock:
            if component_descriptor in self.components:
                return

            for service in component_descriptor.services:
                self.build_service(TypeDescriptor.for_type(service.type))

            for module in self.modules.values():
                module.finalize(self)

            self.components[component_descriptor] = True
385
+
386
+
387
@injectable()
class ProtobufManager(ProtobufBuilder):
    """
    :class:`ProtobufBuilder` that also creates and caches the (de)serializers translating between
    python call arguments / results and the generated request / response messages.
    """

    # helpers

    @staticmethod
    def _unwrap_optional(py_type: Type) -> Type:
        """Return ``X`` for ``Optional[X]`` / ``Union[X, None]`` / ``X | None``; any other type unchanged."""
        import types  # local: only needed for the PEP 604 ``X | None`` union type

        origin = get_origin(py_type)
        if origin is Union or origin is getattr(types, "UnionType", Union):
            non_none = [arg for arg in get_args(py_type) if arg is not type(None)]
            if len(non_none) == 1:
                return non_none[0]

        return py_type

    @staticmethod
    def _is_model(py_type: Any) -> bool:
        """Return True if ``py_type`` is a dataclass or pydantic model *class* (generic aliases are not)."""
        if not inspect.isclass(py_type) or get_origin(py_type) is not None:
            return False

        return is_dataclass(py_type) or issubclass(py_type, BaseModel)

    @staticmethod
    def _fields_by_name(descriptor: Any):
        """Return ``fields_by_name`` for either a message ``Descriptor`` or a generated message class."""
        return getattr(descriptor, "DESCRIPTOR", descriptor).fields_by_name

    # local classes

    class MethodDeserializer:
        """
        Turns a protobuf message back into python values: either the positional call arguments
        (:meth:`args` + :meth:`deserialize`) or a call result (:meth:`result` + :meth:`deserialize_result`).

        Every getter has the signature ``getter(message, target, setter)`` and hands the extracted
        value to ``setter(target, field_name, value)``.
        """
        __slots__ = [
            "manager",
            "descriptor",
            "getters"
        ]

        # constructor

        def __init__(self, manager: ProtobufManager, descriptor: Descriptor):
            self.manager = manager
            self.descriptor = descriptor  # message Descriptor or generated message class

            self.getters = []

        # internal

        def args(self, method: Callable) -> ProtobufManager.MethodDeserializer:
            """Add one getter per parameter of ``method`` (except ``self``), in declaration order."""
            type_hints = get_type_hints(method)
            fields_by_name = ProtobufManager._fields_by_name(self.descriptor)

            for param_name in inspect.signature(method).parameters:
                if param_name == "self":
                    continue

                field_desc = fields_by_name[param_name]

                self.getters.append(self._create_getter(field_desc, param_name, type_hints.get(param_name, str)))

            return self

        def result(self, method: Callable) -> ProtobufManager.MethodDeserializer:
            """Add the getters for the ``result`` and ``exception`` fields of a response message."""
            type_hints = get_type_hints(method)
            fields_by_name = ProtobufManager._fields_by_name(self.descriptor)

            self.getters.append(self._create_getter(fields_by_name["result"], "result", type_hints.get("return")))
            self.getters.append(self._create_getter(fields_by_name["exception"], "exception", str))

            return self

        def get_fields_and_types(self, type: Type) -> List[Tuple[str, Type]]:
            """
            Return ``(field name, type hint)`` pairs of a dataclass or pydantic model class;
            fields without a resolvable hint default to ``str``.

            :raises TypeError: if ``type`` is neither a dataclass nor a pydantic model
            """
            hints = get_type_hints(type)

            if is_dataclass(type):
                return [(f.name, hints.get(f.name, str)) for f in dc_fields(type)]

            if issubclass(type, BaseModel):
                return [(name, hints.get(name, str)) for name in type.model_fields]

            raise TypeError("Expected a dataclass or Pydantic model class.")

        def _create_getter(self, field_desc: FieldDescriptor, field_name: str, type: Type):
            """
            Create a getter that reads ``field_name`` from a message and passes the python value to a setter.

            Missing scalar and message fields are reported as ``None``.

            :param field_desc: descriptor of the protobuf field
            :param field_name: field name (also the name handed to the setter)
            :param type: python type hint of the value; ``Optional[...]`` is unwrapped
            :return: ``getter(message, target, setter=setattr)``
            :raises TypeError: if a message field is not typed as a dataclass or pydantic model
            """
            manager = self.manager
            py_type = manager._unwrap_optional(type)
            is_repeated = field_desc.label == field_desc.LABEL_REPEATED
            is_message = field_desc.message_type is not None

            # local func: getters for all fields of a model class (cache factory, called with the class)

            def compute_class_getters(model_type: Type) -> list[Callable]:
                message_type = manager.pool.FindMessageTypeByName(ProtobufManager.get_message_name(model_type))

                return [
                    self._create_getter(message_type.fields_by_name[sub_field_name], sub_field_name, sub_field_type)
                    for sub_field_name, sub_field_type in self.get_fields_and_types(model_type)
                ]

            # list

            if is_repeated:
                type_args = get_args(py_type)
                item_type = type_args[0] if get_origin(py_type) in (list, List) and type_args else str
                item_type = manager._unwrap_optional(item_type)

                # list of scalars

                if not manager._is_model(item_type):
                    def deserialize_list(msg: Message, val: Any, setter=setattr):
                        setter(val, field_name, list(getattr(msg, field_name)))

                    return deserialize_list

                # list of messages

                item_getters = manager.getter_lambdas_cache.get(item_type, compute_class_getters)

                if is_dataclass(item_type):
                    def deserialize_dataclass_list(msg: Message, val: Any, setter=setattr, getters=item_getters):
                        items = []
                        for item in getattr(msg, field_name):
                            # bypass __init__; object.__setattr__ also works for frozen dataclasses
                            instance = item_type.__new__(item_type)
                            for getter in getters:
                                getter(item, instance, object.__setattr__)

                            items.append(instance)

                        setter(val, field_name, items)

                    return deserialize_dataclass_list

                item_defaults = defaults_dict(item_type)

                def deserialize_pydantic_list(msg: Message, val: Any, setter=setattr, getters=item_getters):
                    items = []
                    for item in getattr(msg, field_name):
                        instance = item_type.model_construct(**item_defaults)
                        for getter in getters:
                            getter(item, instance, setattr)

                        items.append(instance)

                    setter(val, field_name, items)

                return deserialize_pydantic_list

            # message

            if is_message:
                if not manager._is_model(py_type):
                    raise TypeError(f"Expected dataclass or BaseModel for field '{field_name}', got {py_type}")

                sub_getters = manager.getter_lambdas_cache.get(py_type, compute_class_getters)

                if is_dataclass(py_type):
                    def deserialize_dataclass(msg: Message, val: Any, setter=setattr, getters=sub_getters):
                        if not msg.HasField(field_name):
                            setter(val, field_name, None)
                            return

                        sub_message = getattr(msg, field_name)

                        # bypass __init__; object.__setattr__ also works for frozen dataclasses
                        instance = py_type.__new__(py_type)
                        for getter in getters:
                            getter(sub_message, instance, object.__setattr__)

                        setter(val, field_name, instance)

                    return deserialize_dataclass

                defaults = defaults_dict(py_type)

                def deserialize_pydantic(msg: Message, val: Any, setter=setattr, getters=sub_getters):
                    if not msg.HasField(field_name):
                        setter(val, field_name, None)
                        return

                    sub_message = getattr(msg, field_name)

                    instance = py_type.model_construct(**defaults)
                    for getter in getters:
                        getter(sub_message, instance, setattr)

                    setter(val, field_name, instance)

                return deserialize_pydantic

            # scalar

            def deserialize_scalar(msg: Message, val: Any, setter=setattr):
                value = getattr(msg, field_name) if msg.HasField(field_name) else None
                setter(val, field_name, value)

            return deserialize_scalar

        # public

        def deserialize(self, message: Message) -> list[Any]:
            """Return the values of all getters (the call arguments) in order."""
            values: list[Any] = []

            def collect(obj, prop, value):
                values.append(value)

            for getter in self.getters:
                getter(message, values, collect)

            return values

        def deserialize_result(self, message: Message) -> Any:
            """
            Return the ``result`` of a response message (which may legitimately be ``None``).

            :raises RemoteServiceException: if the response carries an ``exception``
            """
            result = None
            exception = None

            def set_result(obj, prop, value):
                nonlocal result, exception

                if prop == "result":
                    result = value
                else:
                    exception = value

            for getter in self.getters:
                getter(message, None, set_result)

            # only a transmitted exception signals failure; None is a valid result (e.g. "-> None" methods)
            if exception is not None:
                raise RemoteServiceException(f"server side exception {exception}")

            return result

    class MethodSerializer:
        """
        Fills a generated message class from python values: call arguments (:meth:`args`) or a
        call result (:meth:`result`). Every setter has the signature ``setter(message, value)``;
        ``None`` values leave the corresponding field unset.
        """
        __slots__ = [
            "manager",
            "message_type",
            "setters"
        ]

        # constructor

        def __init__(self, manager: ProtobufManager, message_type):
            self.manager = manager
            self.message_type = message_type  # generated message class

            self.setters = []

        def result(self, method: Callable) -> ProtobufManager.MethodSerializer:
            """Add the setters for the ``result`` and ``exception`` fields of a response message."""
            fields_by_name = self.message_type.DESCRIPTOR.fields_by_name
            type_hints = get_type_hints(method)

            # .get: a missing return annotation must not raise KeyError (consistent with the deserializer)
            self.setters.append(self._create_setter(fields_by_name["result"], "result", type_hints.get("return")))
            self.setters.append(self._create_setter(fields_by_name["exception"], "exception", str))

            return self

        def args(self, method: Callable) -> ProtobufManager.MethodSerializer:
            """Add one setter per parameter of ``method`` (except ``self``), in declaration order."""
            fields_by_name = self.message_type.DESCRIPTOR.fields_by_name
            type_hints = get_type_hints(method)

            for param_name in inspect.signature(method).parameters:
                if param_name == "self":
                    continue

                field_desc = fields_by_name[param_name]

                self.setters.append(self._create_setter(field_desc, param_name, type_hints.get(param_name, str)))

            return self

        def get_fields_and_types(self, type: Type) -> List[Tuple[str, Type]]:
            """
            Return ``(field name, type hint)`` pairs of a dataclass or pydantic model class;
            fields without a resolvable hint default to ``str``.

            :raises TypeError: if ``type`` is neither a dataclass nor a pydantic model
            """
            hints = get_type_hints(type)

            if is_dataclass(type):
                return [(f.name, hints.get(f.name, str)) for f in dc_fields(type)]

            if issubclass(type, BaseModel):
                return [(name, hints.get(name, str)) for name in type.model_fields]

            raise TypeError("Expected a dataclass or Pydantic model class.")

        def _create_setter(self, field_desc: FieldDescriptor, field_name: str, type: Type):
            """
            Create a setter that writes a python value into field ``field_name`` of a message.

            :param field_desc: descriptor of the protobuf field
            :param field_name: field name (also the attribute read from model instances)
            :param type: python type hint of the value; ``Optional[...]`` is unwrapped
            :return: ``setter(message, value)``
            :raises TypeError: if a message field is not typed as a dataclass or pydantic model
            """
            manager = self.manager
            py_type = manager._unwrap_optional(type)
            is_repeated = field_desc.label == field_desc.LABEL_REPEATED
            is_message = field_desc.message_type is not None

            # local func: (setters, field names) for a model class (cache factory, called with the class)

            def compute_class_setters(model_type: Type) -> Tuple[list[Callable], list[str]]:
                message_type = manager.pool.FindMessageTypeByName(ProtobufManager.get_message_name(model_type))

                setters = []
                names = []
                for sub_field_name, sub_field_type in self.get_fields_and_types(model_type):
                    names.append(sub_field_name)
                    setters.append(self._create_setter(message_type.fields_by_name[sub_field_name], sub_field_name, sub_field_type))

                return setters, names

            # list

            if is_repeated:
                type_args = get_args(py_type)
                item_type = type_args[0] if get_origin(py_type) in (list, List) and type_args else str
                item_type = manager._unwrap_optional(item_type)

                # list of scalars

                if not manager._is_model(item_type):
                    def serialize_list(msg: Message, val: Any):
                        if val is not None:
                            getattr(msg, field_name).extend(val)

                    return serialize_list

                # list of messages

                item_setters, item_fields = manager.setter_lambdas_cache.get(item_type, compute_class_setters)

                def serialize_message_list(msg: Message, val: Any, fields=item_fields, setters=item_setters):
                    if val is None:
                        return

                    repeated = getattr(msg, field_name)
                    for item in val:
                        msg_item = repeated.add()
                        for name, setter in zip(fields, setters):
                            setter(msg_item, getattr(item, name))

                return serialize_message_list

            # message

            if is_message:
                if not manager._is_model(py_type):
                    raise TypeError(f"Expected dataclass or BaseModel for field '{field_name}', got {py_type}")

                sub_setters, sub_fields = manager.setter_lambdas_cache.get(py_type, compute_class_setters)

                def serialize_message(msg: Message, val: Any, fields=sub_fields, setters=sub_setters):
                    if val is None:
                        return  # leave the sub-message unset -> deserialized as None

                    sub_message = getattr(msg, field_name)
                    sub_message.SetInParent()  # mark as present even if every sub-field is None
                    for name, setter in zip(fields, setters):
                        setter(sub_message, getattr(val, name))

                return serialize_message

            # scalar

            def set_attr(msg, val):
                # None leaves the field unset, so HasField() reports it as missing on the receiving side
                if val is not None:
                    setattr(msg, field_name, val)

            return set_attr

        def serialize(self, value: Any) -> Any:
            """Create a message and apply every setter to the same ``value``."""
            message = self.message_type()

            for setter in self.setters:
                setter(message, value)

            return message

        def serialize_result(self, value: Any, exception: str) -> Any:
            """Create a response message from ``value`` and ``exception``; ``None`` leaves a field unset."""
            message = self.message_type()

            if value is not None:
                self.setters[0](message, value)

            if exception is not None:
                self.setters[1](message, exception)

            return message

        def serialize_args(self, args: Sequence[Any]) -> Any:
            """
            Create a request message from the positional ``args`` (one per setter, in order).

            :raises IndexError: if fewer arguments than parameters are supplied
            """
            message = self.message_type()

            for index, setter in enumerate(self.setters):
                setter(message, args[index])

            return message

    # slots

    __slots__ = [
        "serializer_cache",
        "deserializer_cache",
        "result_serializer_cache",
        "result_deserializer_cache",
        "setter_lambdas_cache",
        "getter_lambdas_cache"
    ]

    # constructor

    def __init__(self):
        super().__init__()

        self.serializer_cache = CopyOnWriteCache[Callable, ProtobufManager.MethodSerializer]()
        self.deserializer_cache = CopyOnWriteCache[Descriptor, ProtobufManager.MethodDeserializer]()

        self.result_serializer_cache = CopyOnWriteCache[Descriptor, ProtobufManager.MethodSerializer]()
        self.result_deserializer_cache = CopyOnWriteCache[Descriptor, ProtobufManager.MethodDeserializer]()

        # per model class: (setters, field names) resp. getters for all of its fields
        self.setter_lambdas_cache = CopyOnWriteCache[Type, Tuple[list[Callable], list[str]]]()
        self.getter_lambdas_cache = CopyOnWriteCache[Type, list[Callable]]()

    # public

    def create_serializer(self, type: Type, method: Callable) -> ProtobufManager.MethodSerializer:
        """Return the (cached) request serializer for ``method`` of service ``type``, building descriptors if needed."""
        serializer = self.serializer_cache.get(method)
        if serializer is None:
            self.check(type) # make sure all messages are created

            serializer = ProtobufManager.MethodSerializer(self, self.get_request_message(type, method)).args(method)

            self.serializer_cache.put(method, serializer)

        return serializer

    def create_deserializer(self, descriptor: Descriptor, method: Callable) -> ProtobufManager.MethodDeserializer:
        """Return the (cached) request deserializer for ``method``, keyed by ``descriptor``."""
        deserializer = self.deserializer_cache.get(descriptor)
        if deserializer is None:
            deserializer = ProtobufManager.MethodDeserializer(self, descriptor).args(method)

            self.deserializer_cache.put(descriptor, deserializer)

        return deserializer

    def create_result_serializer(self, descriptor: Descriptor, method: Callable) -> ProtobufManager.MethodSerializer:
        """Return the (cached) response serializer for ``method``; ``descriptor`` is the response message class."""
        serializer = self.result_serializer_cache.get(descriptor)
        if serializer is None:
            serializer = ProtobufManager.MethodSerializer(self, descriptor).result(method)

            self.result_serializer_cache.put(descriptor, serializer)

        return serializer

    def create_result_deserializer(self, descriptor: Descriptor,
                                   method: Callable) -> ProtobufManager.MethodDeserializer:
        """Return the (cached) response deserializer for ``method``, keyed by ``descriptor``."""
        deserializer = self.result_deserializer_cache.get(descriptor)
        if deserializer is None:
            deserializer = ProtobufManager.MethodDeserializer(self, descriptor).result(method)

            self.result_deserializer_cache.put(descriptor, deserializer)

        return deserializer
847
+
848
@channel("dispatch-protobuf")
class ProtobufChannel(HTTPXChannel):
    """
    Channel ``dispatch-protobuf``.

    Posts protobuf-encoded request messages to ``<url>/invoke`` and decodes protobuf response messages.
    The target method is named in the ``x-rpc-method`` header as ``<component>:<service>:<method>``.
    """

    # local classes

    class Call:
        """
        Per-method invocation data.

        Holds the rpc method name, the request serializer, the response message class and the
        response deserializer.
        """
        __slots__ = [
            "method_name",
            "serializer",
            "response_type",
            "deserializer"
        ]

        # constructor

        def __init__(self, method_name: str, serializer: ProtobufManager.MethodSerializer, response_type, deserializer: ProtobufManager.MethodDeserializer):
            self.method_name = method_name
            self.serializer = serializer
            self.response_type = response_type
            self.deserializer = deserializer

        # public

        def serialize(self, args: Sequence[Any]) -> Any:
            """Return the wire bytes of the request message built from the positional ``args``."""
            return self.serializer.serialize_args(args).SerializeToString()

        def deserialize(self, http_response: httpx.Response) -> Any:
            """Parse the HTTP response body as ``response_type`` and return the deserialized call result."""
            parsed = self.response_type()
            parsed.ParseFromString(http_response.content)

            return self.deserializer.deserialize_result(parsed)

    # slots

    __slots__ = [
        "manager",
        "environment",
        "protobuf_manager",
        "cache"
    ]

    # constructor

    def __init__(self, manager: ServiceManager, protobuf_manager: ProtobufManager):
        super().__init__()

        self.manager = manager
        self.environment = None
        self.protobuf_manager = protobuf_manager
        self.cache = CopyOnWriteCache[Callable, ProtobufChannel.Call]()

    # internal

    @staticmethod
    def _protobuf_headers(call: ProtobufChannel.Call) -> dict:
        """HTTP headers of an invoke request for ``call``."""
        # NOTE: an "Accept: application/x-protobuf" header is deliberately not sent (commented out originally)
        return {
            "Content-Type": "application/x-protobuf",
            "x-rpc-method": call.method_name
        }

    def get_call(self, type: Type, method: Callable) -> ProtobufChannel.Call:
        """Return the cached :class:`Call` for ``method``, creating it on first use."""
        cached = self.cache.get(method)
        if cached is not None:
            return cached

        pm = self.protobuf_manager

        # order matters: create_serializer() builds the descriptors the response lookup relies on
        rpc_method = f"{self.component_descriptor.name}:{self.service_names[type]}:{method.__name__}"
        request_serializer = pm.create_serializer(type, method)
        response_class = pm.get_message_type(pm.get_response_message_name(type, method))
        response_deserializer = pm.create_result_deserializer(response_class, method)

        new_call = ProtobufChannel.Call(rpc_method, request_serializer, response_class, response_deserializer)
        self.cache.put(method, new_call)

        return new_call

    # implement

    async def invoke_async(self, invocation: DynamicProxy.Invocation):
        """
        Serialize the invocation, POST it and return the deserialized result.

        Service communication, authorization and remote exceptions propagate unchanged; any other
        error is wrapped in a ServiceCommunicationException.
        """
        call = self.get_call(invocation.type, invocation.method)

        try:
            url = f"{self.get_url()}/invoke"
            payload = call.serialize(invocation.args)
            http_result = await self.request_async("post", url, content=payload,
                                                   timeout=self.timeout, headers=self._protobuf_headers(call))

            return call.deserialize(http_result)
        except (ServiceCommunicationException, AuthorizationException, RemoteServiceException):
            raise
        except Exception as e:
            raise ServiceCommunicationException(f"communication exception {e}") from e

    def invoke(self, invocation: DynamicProxy.Invocation):
        """
        Serialize the invocation, POST it and return the deserialized result (synchronous variant).

        Service communication, authorization and remote exceptions propagate unchanged; any other
        error is wrapped in a ServiceCommunicationException.
        """
        call = self.get_call(invocation.type, invocation.method)

        try:
            url = f"{self.get_url()}/invoke"
            payload = call.serialize(invocation.args)
            http_result = self.request("post", url, content=payload,
                                       timeout=self.timeout, headers=self._protobuf_headers(call))

            return call.deserialize(http_result)
        except (ServiceCommunicationException, AuthorizationException, RemoteServiceException):
            raise
        except Exception as e:
            raise ServiceCommunicationException(f"communication exception {e}") from e
951
+
952
class ProtobufDumper:
    """
    Render a ``descriptor_pb2.FileDescriptorProto`` back into ``.proto`` source text.

    Covers syntax, package, imports, scalar/enum file options, enums, messages
    (nested types, ``map<K, V>`` fields, oneofs, ``packed`` and proto2 ``default``
    field options) and services. Other descriptor features (e.g. message/enum
    options, reserved ranges, extensions) are not reproduced, so the output is
    meant for inspection and debugging.
    """

    _INDENT = "  "

    # FieldDescriptorProto.Label values
    _LABEL_OPTIONAL = 1
    _LABEL_REQUIRED = 2
    _LABEL_REPEATED = 3

    _LABELS = {
        _LABEL_OPTIONAL: "optional",
        _LABEL_REQUIRED: "required",
        _LABEL_REPEATED: "repeated",
    }

    # FieldDescriptorProto.Type values that need special treatment
    _TYPE_STRING = 9
    _TYPE_GROUP = 10
    _TYPE_MESSAGE = 11
    _TYPE_BYTES = 12
    _TYPE_ENUM = 14

    # FieldDescriptorProto.Type -> .proto keyword, used for fields without a type_name
    _SCALAR_TYPES = {
        1: "double",
        2: "float",
        3: "int64",
        4: "uint64",
        5: "int32",
        6: "fixed64",
        7: "fixed32",
        8: "bool",
        9: "string",
        10: "group",  # deprecated
        11: "message",
        12: "bytes",
        13: "uint32",
        14: "enum",
        15: "sfixed32",
        16: "sfixed64",
        17: "sint32",
        18: "sint64",
    }

    @classmethod
    def dump_proto(cls, fd: descriptor_pb2.FileDescriptorProto) -> str:
        """
        Render a file descriptor as ``.proto`` text.

        An empty ``syntax`` is treated as ``proto2``, as in descriptor.proto.

        Args:
            fd: the file descriptor to render.

        Returns:
            the newline-joined ``.proto`` source.
        """
        syntax = fd.syntax if fd.syntax else "proto2"
        proto3 = syntax == "proto3"

        lines = [f'syntax = "{syntax}";\n']

        if fd.package:
            lines.append(f"package {fd.package};\n")

        for dep in fd.dependency:
            lines.append(f'import "{dep}";')
        if fd.dependency:
            lines.append("")  # blank line after imports

        if fd.HasField("options"):
            for option, value in fd.options.ListFields():
                rendered = cls._format_option_value(option, value)
                if rendered is not None:
                    lines.append(f"option {option.name} = {rendered};")
            lines.append("")

        for enum in fd.enum_type:
            lines.extend(cls._dump_enum(enum, ""))

        for msg in fd.message_type:
            lines.extend(cls._dump_message(msg, "", proto3))

        for svc in fd.service:
            lines.extend(cls._dump_service(svc, ""))

        return "\n".join(lines)

    # enums

    @classmethod
    def _dump_enum(cls, enum: descriptor_pb2.EnumDescriptorProto, indent: str) -> List[str]:
        """Render an enum definition at the given indentation."""
        enum_lines = [f"{indent}enum {enum.name} {{"]
        enum_lines.extend(f"{indent}{cls._INDENT}{value.name} = {value.number};" for value in enum.value)
        enum_lines.append(f"{indent}}}\n")
        return enum_lines

    # messages

    @classmethod
    def _dump_message(cls, msg: descriptor_pb2.DescriptorProto, indent: str, proto3: bool) -> List[str]:
        """Render a message (recursively including nested enums and messages)."""
        inner = indent + cls._INDENT
        msg_lines = [f"{indent}message {msg.name} {{"]

        for enum in msg.enum_type:
            msg_lines.extend(cls._dump_enum(enum, inner))

        # Map entry messages are synthetic (protoc generates them from `map<K, V>` fields),
        # so they are not dumped; the field referencing them is rendered as a map instead.
        map_entries = {}
        for nested in msg.nested_type:
            if nested.options.map_entry:
                map_entries[nested.name] = nested
            else:
                msg_lines.extend(cls._dump_message(nested, inner, proto3))

        # group members of real (non-synthetic) oneofs, preserving field order
        oneof_members = {}
        for field in msg.field:
            index = cls._real_oneof_index(field)
            if index is not None:
                oneof_members.setdefault(index, []).append(field)

        emitted_oneofs = set()
        for field in msg.field:
            index = cls._real_oneof_index(field)
            if index is None:
                msg_lines.append(cls._dump_field(field, inner, proto3, map_entries, in_oneof=False))
            elif index not in emitted_oneofs:
                # emit the whole oneof at the position of its first member
                emitted_oneofs.add(index)
                msg_lines.append(f"{inner}oneof {msg.oneof_decl[index].name} {{")
                for member in oneof_members[index]:
                    msg_lines.append(cls._dump_field(member, inner + cls._INDENT, proto3, map_entries, in_oneof=True))
                msg_lines.append(f"{inner}}}")

        msg_lines.append(f"{indent}}}\n")
        return msg_lines

    @classmethod
    def _dump_field(cls, field: descriptor_pb2.FieldDescriptorProto, indent: str, proto3: bool,
                    map_entries: dict, in_oneof: bool) -> str:
        """Render a single field declaration line."""
        key_value = cls._map_key_value(field, map_entries)
        if key_value is not None:
            key, value = key_value
            type_str = f"map<{cls._field_type(key)}, {cls._field_type(value)}>"
            label = ""  # map fields carry no label in .proto syntax
        else:
            type_str = cls._field_type(field)
            label = "" if in_oneof else cls._field_label(field, proto3)

        opts = []
        if field.options.HasField("packed"):
            opts.append(f"packed = {str(field.options.packed).lower()}")
        if field.HasField("default_value"):
            opts.append(f"default = {cls._format_default(field)}")
        opts_str = f" [{', '.join(opts)}]" if opts else ""

        prefix = f"{label} " if label else ""
        return f"{indent}{prefix}{type_str} {field.name} = {field.number}{opts_str};"

    @classmethod
    def _map_key_value(cls, field: descriptor_pb2.FieldDescriptorProto, map_entries: dict):
        """
        Return the ``(key, value)`` field descriptors if ``field`` is a map field, else ``None``.

        A map field is a repeated field whose type is one of the enclosing message's own
        map-entry nested types (matched by the last component of its type name).
        """
        if field.label != cls._LABEL_REPEATED or not field.type_name:
            return None

        entry = map_entries.get(field.type_name.rsplit(".", 1)[-1])
        if entry is None:
            return None

        by_number = {entry_field.number: entry_field for entry_field in entry.field}
        if 1 not in by_number or 2 not in by_number:
            return None  # malformed entry: fall back to rendering a plain field

        return by_number[1], by_number[2]

    @staticmethod
    def _real_oneof_index(field: descriptor_pb2.FieldDescriptorProto) -> Optional[int]:
        """Return the oneof index of ``field``, ignoring the synthetic oneofs of proto3 ``optional`` fields."""
        if field.HasField("oneof_index") and not field.proto3_optional:
            return field.oneof_index
        return None

    @classmethod
    def _field_label(cls, field: descriptor_pb2.FieldDescriptorProto, proto3: bool) -> str:
        """Return the label keyword to print for ``field`` ('' if none)."""
        if proto3 and field.label == cls._LABEL_OPTIONAL:
            # Every singular proto3 field carries LABEL_OPTIONAL, but the `optional`
            # keyword means explicit presence there, so only print it for proto3_optional fields.
            return "optional" if field.proto3_optional else ""
        return cls._LABELS.get(field.label, "")

    @classmethod
    def _field_type(cls, field: descriptor_pb2.FieldDescriptorProto) -> str:
        """Return the type name of ``field`` (message/enum name without leading dot, or scalar keyword)."""
        if field.type_name:
            return field.type_name.lstrip(".")
        return cls._SCALAR_TYPES.get(field.type, f"TYPE_{field.type}")

    @classmethod
    def _format_default(cls, field: descriptor_pb2.FieldDescriptorProto) -> str:
        """Render a proto2 ``default_value`` as a .proto literal."""
        # descriptor.proto: string defaults are unescaped text, bytes defaults are already
        # C-escaped; both need surrounding quotes. Other types are emitted verbatim.
        if field.type == cls._TYPE_STRING:
            return cls._quote(field.default_value)
        if field.type == cls._TYPE_BYTES:
            return f'"{field.default_value}"'
        return field.default_value

    @classmethod
    def _format_option_value(cls, option, value) -> Optional[str]:
        """
        Render a file option value as a .proto literal.

        Returns ``None`` for message-typed options (e.g. ``uninterpreted_option``),
        which would need aggregate syntax and are therefore skipped.
        """
        if option.type in (cls._TYPE_MESSAGE, cls._TYPE_GROUP):
            return None
        if option.type == cls._TYPE_ENUM:
            enum_value = option.enum_type.values_by_number.get(value)
            return enum_value.name if enum_value is not None else str(value)
        if isinstance(value, bool):
            return "true" if value else "false"
        if isinstance(value, str):
            return cls._quote(value)
        return str(value)

    @staticmethod
    def _quote(text: str) -> str:
        """Wrap ``text`` in double quotes, escaping backslashes, quotes and newlines."""
        escaped = text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
        return f'"{escaped}"'

    # services

    @classmethod
    def _dump_service(cls, svc: descriptor_pb2.ServiceDescriptorProto, indent: str) -> List[str]:
        """Render a service definition with one `rpc` line per method."""
        svc_lines = [f"{indent}service {svc.name} {{"]
        for method in svc.method:
            input_type = method.input_type.lstrip(".") if method.input_type else "Unknown"
            output_type = method.output_type.lstrip(".") if method.output_type else "Unknown"
            client_streaming = "stream " if method.client_streaming else ""
            server_streaming = "stream " if method.server_streaming else ""
            svc_lines.append(
                f"{indent}{cls._INDENT}rpc {method.name} ({client_streaming}{input_type}) "
                f"returns ({server_streaming}{output_type});"
            )
        svc_lines.append(f"{indent}}}\n")
        return svc_lines