soia-client 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of soia-client might be problematic. See the package's registry page for more details.

soia/_impl/enums.py CHANGED
@@ -1,15 +1,16 @@
1
1
  import copy
2
2
  from collections.abc import Callable, Sequence
3
3
  from dataclasses import FrozenInstanceError, dataclass
4
- from typing import Any, Final, Union
4
+ from typing import Any, Final, Generic, Union
5
5
 
6
6
  from soia import _spec, reflection
7
+ from soia._impl.binary import decode_int64, decode_unused, encode_int64
7
8
  from soia._impl.function_maker import BodyBuilder, Expr, ExprLike, Line, make_function
8
9
  from soia._impl.repr import repr_impl
9
- from soia._impl.type_adapter import TypeAdapter
10
+ from soia._impl.type_adapter import T, ByteStream, TypeAdapter
10
11
 
11
12
 
12
- class EnumAdapter(TypeAdapter):
13
+ class EnumAdapter(Generic[T], TypeAdapter[T]):
13
14
  __slots__ = (
14
15
  "spec",
15
16
  "gen_class",
@@ -30,6 +31,12 @@ class EnumAdapter(TypeAdapter):
30
31
  self.spec = spec
31
32
  base_class = self.gen_class = _make_base_class(spec)
32
33
 
34
+ def forward_decode(stream: ByteStream) -> T:
35
+ return base_class._decode(stream)
36
+
37
+ # Will be overridden at finalization time.
38
+ base_class._decode = forward_decode
39
+
33
40
  private_is_enum_attr = _name_private_is_enum_attr(spec.id)
34
41
  self.private_is_enum_attr = private_is_enum_attr
35
42
  setattr(base_class, private_is_enum_attr, True)
@@ -80,11 +87,21 @@ class EnumAdapter(TypeAdapter):
80
87
  create_fn = _make_create_fn(wrap_fn, frozen_class)
81
88
  setattr(base_class, f"create_{value_field.spec.name}", create_fn)
82
89
 
90
+ unrecognized_class = _make_unrecognized_class(base_class)
91
+
83
92
  base_class._fj = _make_from_json_fn(
84
93
  self.all_constant_fields,
85
94
  value_fields,
86
95
  set(self.spec.removed_numbers),
87
- base_class,
96
+ base_class=base_class,
97
+ unrecognized_class=unrecognized_class,
98
+ )
99
+ base_class._decode = _make_decode_fn(
100
+ self.all_constant_fields,
101
+ value_fields,
102
+ set(self.spec.removed_numbers),
103
+ base_class=base_class,
104
+ unrecognized_class=unrecognized_class,
88
105
  )
89
106
 
90
107
  # Mark finalization as done.
@@ -117,16 +134,36 @@ class EnumAdapter(TypeAdapter):
117
134
  def to_json_expr(self, in_expr: ExprLike, readable: bool) -> Expr:
118
135
  return Expr.join(in_expr, "._rj" if readable else "._dj")
119
136
 
120
- def from_json_expr(self, json_expr: ExprLike) -> Expr:
137
+ def from_json_expr(
138
+ self, json_expr: ExprLike, keep_unrecognized_expr: ExprLike
139
+ ) -> Expr:
121
140
  fn_name = "_fj"
122
141
  from_json_fn = getattr(self.gen_class, fn_name, None)
123
142
  if from_json_fn:
124
- return Expr.join(Expr.local("_fj?", from_json_fn), "(", json_expr, ")")
143
+ return Expr.join(
144
+ Expr.local("_fj?", from_json_fn),
145
+ "(",
146
+ json_expr,
147
+ ", ",
148
+ keep_unrecognized_expr,
149
+ ")",
150
+ )
125
151
  else:
126
152
  return Expr.join(
127
- Expr.local("_cls?", self.gen_class), f".{fn_name}(", json_expr, ")"
153
+ Expr.local("_cls?", self.gen_class),
154
+ f".{fn_name}(",
155
+ json_expr,
156
+ ", ",
157
+ keep_unrecognized_expr,
158
+ ")",
128
159
  )
129
160
 
161
+ def encode_fn(self) -> Callable[[T, bytearray], None]:
162
+ return _encode_impl
163
+
164
+ def decode_fn(self) -> Callable[[ByteStream], T]:
165
+ return self.gen_class._decode
166
+
130
167
  def get_type(self) -> reflection.Type:
131
168
  return reflection.RecordType(
132
169
  kind="record",
@@ -209,6 +246,9 @@ def _make_base_class(spec: _spec.Enum) -> type:
209
246
 
210
247
 
211
248
  def _make_constant_class(base_class: type, spec: _spec.ConstantField) -> type:
249
+ byte_array = bytearray()
250
+ encode_int64(spec.number, byte_array)
251
+
212
252
  class Constant(base_class):
213
253
  __slots__ = ()
214
254
 
@@ -220,6 +260,7 @@ def _make_constant_class(base_class: type, spec: _spec.ConstantField) -> type:
220
260
  _rj: Final[str] = spec.name
221
261
  # has value
222
262
  _hv: Final[bool] = False
263
+ _bytes: Final[bytes | None] = bytes(byte_array)
223
264
 
224
265
  def __init__(self):
225
266
  # Do not call super().__init__().
@@ -239,20 +280,22 @@ def _make_unrecognized_class(base_class: type) -> type:
239
280
  """
240
281
 
241
282
  class Unrecognized(base_class):
242
- __slots__ = ("_dj",)
283
+ __slots__ = ("_dj", "_bytes")
243
284
 
244
285
  kind: Final[str] = "?"
245
286
  _number: Final[int] = 0
246
287
  # dense JSON
247
- _dj: Any
288
+ _dj: list[Any] | int
289
+ _bytes: bytes
248
290
  # readable JSON
249
291
  _rj: Final[str] = "?"
250
292
  # has value
251
293
  _hv: Final[bool] = False
252
294
 
253
- def __init__(self, dj: Any):
295
+ def __init__(self, dj: list[Any] | int, bytes: bytes):
254
296
  # Do not call super().__init__().
255
297
  object.__setattr__(self, "_dj", copy.deepcopy(dj))
298
+ object.__setattr__(self, "_bytes", bytes)
256
299
  object.__setattr__(self, "value", None)
257
300
 
258
301
  def __repr__(self) -> str:
@@ -275,6 +318,7 @@ def _make_value_class(
275
318
  _number: Final[int] = number
276
319
  # has value
277
320
  _hv: Final[bool] = True
321
+ _bytes: Final[None] = None
278
322
 
279
323
  def __init__(self):
280
324
  # Do not call super().__init__().
@@ -319,9 +363,38 @@ def _make_value_class(
319
363
  )
320
364
  )
321
365
 
366
+ bytes_prefix = bytearray()
367
+ if number in range(1, 5):
368
+ bytes_prefix.append(250 + number)
369
+ else:
370
+ bytes_prefix.append(248)
371
+ encode_int64(number, bytes_prefix)
372
+
373
+ ret._enc = make_function(
374
+ name="encode",
375
+ params=["self", "buffer"],
376
+ body=[
377
+ f"buffer.extend({bytes_prefix})",
378
+ Line.join(
379
+ Expr.local("encode_value", field_type.encode_fn()),
380
+ "(self.value, buffer)",
381
+ ),
382
+ ],
383
+ )
384
+
322
385
  return ret
323
386
 
324
387
 
388
+ def _encode_impl(
389
+ value: Any,
390
+ buffer: bytearray,
391
+ ) -> None:
392
+ if value._bytes:
393
+ buffer.extend(value._bytes)
394
+ else:
395
+ value._enc(buffer)
396
+
397
+
325
398
  @dataclass(frozen=True)
326
399
  class _ValueField:
327
400
  spec: _spec.ValueField
@@ -370,11 +443,11 @@ def _make_from_json_fn(
370
443
  value_fields: Sequence[_ValueField],
371
444
  removed_numbers: set[int],
372
445
  base_class: type,
446
+ unrecognized_class: type,
373
447
  ) -> Callable[[Any], Any]:
374
- unrecognized_class = _make_unrecognized_class(base_class)
375
448
  unrecognized_class_local = Expr.local("Unrecognized", unrecognized_class)
376
449
  obj_setattr_local = Expr.local("obj_settatr", object.__setattr__)
377
- removed_numbers_local = Expr.local("removed_numbers", removed_numbers)
450
+ removed_numbers_tuple = tuple(sorted(removed_numbers))
378
451
 
379
452
  key_to_constant: dict[Union[int, str], Any] = {}
380
453
  for field in constant_fields:
@@ -385,15 +458,13 @@ def _make_from_json_fn(
385
458
  unknown_constant = key_to_constant[0]
386
459
  unknown_constant_local = Expr.local("unknown_constant", unknown_constant)
387
460
 
388
- numbers: list[int] = []
389
- names: list[str] = []
390
- key_to_field: dict[Union[int, str], _ValueField] = {}
461
+ number_to_value_field: dict[int, _ValueField] = {}
462
+ name_to_value_field: dict[str, _ValueField] = {}
391
463
  for field in value_fields:
392
- numbers.append(field.spec.number)
393
- names.append(field.spec.name)
394
- key_to_field[field.spec.number] = field
395
- key_to_field[field.spec.name] = field
396
- value_keys_local = Expr.local("value_keys", set(key_to_field.keys()))
464
+ number_to_value_field[field.spec.number] = field
465
+ name_to_value_field[field.spec.name] = field
466
+ value_field_numbers = tuple(sorted(number_to_value_field.keys()))
467
+ value_field_names = tuple(sorted(name_to_value_field.keys()))
397
468
 
398
469
  builder = BodyBuilder()
399
470
  # The reason why we wrap the function inside a 'while' is explained below.
@@ -410,16 +481,20 @@ def _make_from_json_fn(
410
481
  builder.append_ln(" return ", key_to_constant_local, "[json]")
411
482
  builder.append_ln(" except:")
412
483
  if removed_numbers:
413
- builder.append_ln(" if json in ", removed_numbers_local, ":")
484
+ builder.append_ln(
485
+ f" if json in {removed_numbers_tuple} or not keep_unrecognized_fields:"
486
+ )
414
487
  builder.append_ln(" return ", unknown_constant_local)
415
- builder.append_ln(" return ", unrecognized_class_local, "(json)")
488
+ builder.append_ln(" return ", unrecognized_class_local, "(json, b'\\0')")
416
489
 
417
- def append_number_branches(numbers: list[int], indent: str) -> None:
490
+ def append_number_branches(numbers: Sequence[int], indent: str) -> None:
418
491
  if len(numbers) == 1:
419
492
  number = numbers[0]
420
- field = key_to_field[number]
493
+ field = number_to_value_field[number]
421
494
  value_class_local = Expr.local("cls?", field.value_class)
422
- value_expr = field.field_type.from_json_expr("json[1]")
495
+ value_expr = field.field_type.from_json_expr(
496
+ "json[1]", "keep_unrecognized_fields"
497
+ )
423
498
  builder.append_ln(f"{indent}ret = ", value_class_local, "()")
424
499
  builder.append_ln(
425
500
  indent, obj_setattr_local, '(ret, "value", ', value_expr, ")"
@@ -441,19 +516,20 @@ def _make_from_json_fn(
441
516
  if not value_fields:
442
517
  # The field was either removed or is an unrecognized field.
443
518
  if removed_numbers:
444
- builder.append_ln(" if number in ", removed_numbers_local, ":")
519
+ builder.append_ln(
520
+ f" if number in {removed_numbers_tuple} or not keep_unrecognized_fields:"
521
+ )
445
522
  builder.append_ln(" return ", unknown_constant_local)
446
- builder.append_ln(" return ", unrecognized_class_local, "(json)")
523
+ builder.append_ln(" return ", unrecognized_class_local, "(json, b'\\0')")
447
524
  else:
448
- if len(value_fields) == 1:
449
- builder.append_ln(f" if number != {value_fields[0].spec.number}:")
450
- else:
451
- builder.append_ln(" if number not in ", value_keys_local, ":")
525
+ builder.append_ln(f" if number not in {value_field_numbers}:")
452
526
  if removed_numbers:
453
- builder.append_ln(" if number in ", removed_numbers_local, ":")
527
+ builder.append_ln(
528
+ f" if number in {removed_numbers_tuple} or not keep_unrecognized_fields:"
529
+ )
454
530
  builder.append_ln(" return ", unknown_constant_local)
455
- builder.append_ln(" return ", unrecognized_class_local, "(json)")
456
- append_number_branches(sorted(numbers), " ")
531
+ builder.append_ln(" return ", unrecognized_class_local, "(json, b'\\0')")
532
+ append_number_branches(value_field_numbers, " ")
457
533
 
458
534
  # READABLE FORMAT
459
535
  if len(constant_fields) == 1:
@@ -467,12 +543,14 @@ def _make_from_json_fn(
467
543
  # In readable mode, drop unrecognized values and use UNKNOWN instead.
468
544
  builder.append_ln(" return ", unknown_constant_local)
469
545
 
470
- def append_name_branches(names: list[str], indent: str) -> None:
546
+ def append_name_branches(names: Sequence[str], indent: str) -> None:
471
547
  if len(names) == 1:
472
548
  name = names[0]
473
- field = key_to_field[name]
549
+ field = name_to_value_field[name]
474
550
  value_class_local = Expr.local("cls?", field.value_class)
475
- value_expr = field.field_type.from_json_expr("json['value']")
551
+ value_expr = field.field_type.from_json_expr(
552
+ "json['value']", "keep_unrecognized_fields"
553
+ )
476
554
  builder.append_ln(f"{indent}ret = ", value_class_local, "()")
477
555
  builder.append_ln(
478
556
  indent, obj_setattr_local, '(ret, "value", ', value_expr, ")"
@@ -493,10 +571,10 @@ def _make_from_json_fn(
493
571
  builder.append_ln(" return ", unknown_constant_local)
494
572
  else:
495
573
  builder.append_ln(" kind = json['kind']")
496
- builder.append_ln(" if kind not in ", value_keys_local, ":")
574
+ builder.append_ln(f" if kind not in {value_field_names}:")
497
575
  builder.append_ln(" return ", unknown_constant_local)
498
576
  builder.append_ln(" else:")
499
- append_name_branches(sorted(names), " ")
577
+ append_name_branches(value_field_names, " ")
500
578
 
501
579
  # In the unlikely event that json.loads() returns an instance of a subclass of int.
502
580
  builder.append_ln(" elif isinstance(json, int):")
@@ -508,7 +586,112 @@ def _make_from_json_fn(
508
586
 
509
587
  return make_function(
510
588
  name="from_json",
511
- params=["json"],
589
+ params=["json", "keep_unrecognized_fields"],
590
+ body=builder.build(),
591
+ )
592
+
593
+
594
def _make_decode_fn(
    constant_fields: Sequence[_spec.ConstantField],
    value_fields: Sequence[_ValueField],
    removed_numbers: set[int],
    base_class: type,
    unrecognized_class: type,
) -> Callable[[ByteStream], Any]:
    """Generates the function decoding an enum instance from its binary form.

    The generated `decode(stream)` reads one enum instance at
    `stream.position` and advances the position past it. Wire layouts:
      - wire byte <= 238: a constant, encoded as its number alone (the wire
        byte itself when < 232, otherwise an int64 read by `decode_int64`);
      - wire byte 248: a value field; an int64 number follows, then the value;
      - any other wire byte: a value field numbered `wire - 250`, followed by
        the value.

    A number matching no field decodes to the constant numbered 0 if it is
    listed in `removed_numbers`. Otherwise it decodes to an
    `unrecognized_class` instance holding the raw bytes consumed for the whole
    instance, so that encoding it again writes those same bytes back.

    Args:
        constant_fields: the constant fields of the enum.
        value_fields: the value fields of the enum.
        removed_numbers: field numbers removed from the schema.
        base_class: the generated enum class; constants are read from it.
        unrecognized_class: the class instantiated for unrecognized numbers.

    Returns:
        The generated `decode(stream)` function.
    """
    # NOTE(review): the previous version tested a `keep_unrecognized_fields`
    # name that is not a parameter of the generated function. This raised
    # NameError on any unrecognized, non-removed number. The binary decoder
    # receives no such flag, so unrecognized numbers are always kept, as they
    # already were when `removed_numbers` is empty.
    unrecognized_class_local = Expr.local("Unrecognized", unrecognized_class)
    obj_setattr_local = Expr.local("obj_setattr", object.__setattr__)
    decode_int64_local = Expr.local("decode_int64", decode_int64)
    decode_unused_local = Expr.local("decode_unused", decode_unused)

    number_to_constant: dict[int, Any] = {
        field.number: getattr(base_class, field.attribute)
        for field in constant_fields
    }
    number_to_constant_local = Expr.local("number_to_constant", number_to_constant)
    unknown_constant_local = Expr.local("unknown_constant", number_to_constant[0])
    removed_numbers_tuple = tuple(sorted(removed_numbers))

    number_to_value_field: dict[int, _ValueField] = {
        field.spec.number: field for field in value_fields
    }
    value_field_numbers = tuple(sorted(number_to_value_field))

    builder = BodyBuilder()

    def append_return_unknown_or_unrecognized(indent: str) -> None:
        # Emits the tail shared by every "number matches no field" path. The
        # stream must already be positioned past the whole encoded instance.
        if removed_numbers:
            builder.append_ln(f"{indent}if number in {removed_numbers_tuple}:")
            builder.append_ln(f"{indent}  return ", unknown_constant_local)
        builder.append_ln(
            f"{indent}raw = bytes(stream.buffer[start_offset:stream.position])"
        )
        builder.append_ln(f"{indent}return ", unrecognized_class_local, "(0, raw)")

    def append_number_branches(numbers: Sequence[int], indent: str) -> None:
        # Emits a binary search over the sorted `numbers`. Each leaf decodes
        # the value of the matching value field.
        if len(numbers) == 1:
            field = number_to_value_field[numbers[0]]
            value_class_local = Expr.local("cls?", field.value_class)
            decode_local = Expr.local("decode?", field.field_type.decode_fn())
            value_expr = Expr.join(decode_local, "(stream)")
            builder.append_ln(f"{indent}ret = ", value_class_local, "()")
            builder.append_ln(
                indent, obj_setattr_local, '(ret, "value", ', value_expr, ")"
            )
            builder.append_ln(f"{indent}return ret")
            return
        mid_index = len(numbers) // 2
        mid_number = numbers[mid_index - 1]
        operator = "==" if mid_index == 1 else "<="
        builder.append_ln(f"{indent}if number {operator} {mid_number}:")
        append_number_branches(numbers[:mid_index], f"{indent}  ")
        builder.append_ln(f"{indent}else:")
        append_number_branches(numbers[mid_index:], f"{indent}  ")

    builder.append_ln("start_offset = stream.position")
    builder.append_ln("wire = stream.buffer[start_offset]")
    builder.append_ln("if wire <= 238:")
    # A constant: the number is the whole encoding (see how constants build
    # their bytes with encode_int64 alone), so there is no payload to skip.
    builder.append_ln("  if wire < 232:")
    builder.append_ln("    stream.position += 1")
    builder.append_ln("    number = wire")
    builder.append_ln("  else:")
    builder.append_ln("    number = ", decode_int64_local, "(stream)")
    builder.append_ln("  constant = ", number_to_constant_local, ".get(number)")
    builder.append_ln("  if constant is not None:")
    builder.append_ln("    return constant")
    append_return_unknown_or_unrecognized("  ")

    # A value field: the wire byte, an int64 number if wire == 248, then the
    # encoded value.
    builder.append_ln("stream.position += 1")
    builder.append_ln("if wire == 248:")
    builder.append_ln("  number = ", decode_int64_local, "(stream)")
    builder.append_ln("else:")
    builder.append_ln("  number = wire - 250")
    if value_fields:
        builder.append_ln(f"if number not in {value_field_numbers}:")
        indent = "  "
    else:
        # No value field is known: every value-field number is unknown.
        indent = ""
    # Skip the unknown value's payload so the stream stays aligned and the
    # raw bytes cover the whole instance. This assumes decode_unused reads
    # and discards exactly one encoded value — TODO confirm in binary.py.
    builder.append_ln(indent, decode_unused_local, "(stream)")
    append_return_unknown_or_unrecognized(indent)
    if value_fields:
        append_number_branches(value_field_numbers, "")

    return make_function(
        name="decode",
        params=["stream"],
        body=builder.build(),
    )
514
697
 
soia/_impl/optionals.py CHANGED
@@ -1,23 +1,22 @@
1
1
  from collections.abc import Callable
2
2
  from dataclasses import dataclass
3
- from typing import TypeVar
3
+ from functools import cached_property
4
+ from typing import Generic
4
5
  from weakref import WeakValueDictionary
5
6
 
6
7
  from soia import _spec, reflection
7
8
  from soia._impl.function_maker import Expr, ExprLike
8
- from soia._impl.type_adapter import TypeAdapter
9
+ from soia._impl.type_adapter import T, ByteStream, TypeAdapter
9
10
 
10
- Other = TypeVar("Other")
11
11
 
12
-
13
- def get_optional_adapter(other_adapter: TypeAdapter) -> TypeAdapter:
12
def get_optional_adapter(other_adapter: TypeAdapter[T]) -> TypeAdapter[T | None]:
    """Returns the adapter for optional values of `other_adapter`'s type.

    Adapters are memoized in `_other_adapter_to_optional_adapter`. Calls with
    the same `other_adapter` return the same instance for as long as that
    instance remains stored there.
    """
    # Look up first: `setdefault` evaluates its default eagerly, so it would
    # build a throwaway `_OptionalAdapter` on every call, even on a cache hit.
    optional_adapter = _other_adapter_to_optional_adapter.get(other_adapter)
    if optional_adapter is None:
        optional_adapter = _OptionalAdapter(other_adapter)
        _other_adapter_to_optional_adapter[other_adapter] = optional_adapter
    return optional_adapter
17
16
 
18
17
 
19
18
  @dataclass(frozen=True)
20
- class _OptionalAdapter(TypeAdapter):
19
+ class _OptionalAdapter(Generic[T], TypeAdapter[T | None]):
21
20
  __slots__ = ("other_adapter",)
22
21
 
23
22
  other_adapter: TypeAdapter
@@ -49,8 +48,12 @@ class _OptionalAdapter(TypeAdapter):
49
48
  ")",
50
49
  )
51
50
 
52
- def from_json_expr(self, json_expr: ExprLike) -> ExprLike:
53
- other_from_json = self.other_adapter.from_json_expr(json_expr)
51
+ def from_json_expr(
52
+ self, json_expr: ExprLike, keep_unrecognized_expr: ExprLike
53
+ ) -> ExprLike:
54
+ other_from_json = self.other_adapter.from_json_expr(
55
+ json_expr, keep_unrecognized_expr
56
+ )
54
57
  if other_from_json == json_expr:
55
58
  return json_expr
56
59
  return Expr.join(
@@ -61,6 +64,42 @@ class _OptionalAdapter(TypeAdapter):
61
64
  ")",
62
65
  )
63
66
 
67
+ @cached_property
68
+ def encode_fn_impl(self) -> Callable[[T | None, bytearray], None]:
69
+ encode_value = self.other_adapter.encode_fn()
70
+
71
+ def encode(
72
+ value: T | None,
73
+ buffer: bytearray,
74
+ ) -> None:
75
+ if value is None:
76
+ buffer.append(255)
77
+ else:
78
+ encode_value(value, buffer)
79
+
80
+ return encode
81
+
82
+ def encode_fn(self) -> Callable[[T | None, bytearray], None]:
83
+ return self.encode_fn_impl
84
+
85
+ @cached_property
86
+ def decode_fn_impl(self) -> Callable[[ByteStream], T | None]:
87
+ decode_value = self.other_adapter.decode_fn()
88
+
89
+ def decode(
90
+ stream: ByteStream,
91
+ ) -> T | None:
92
+ if stream.buffer[stream.position] == 255:
93
+ stream.position += 1
94
+ return None
95
+ else:
96
+ return decode_value(stream)
97
+
98
+ return decode
99
+
100
+ def decode_fn(self) -> Callable[[ByteStream], T | None]:
101
+ return self.decode_fn_impl
102
+
64
103
  def finalize(
65
104
  self,
66
105
  resolve_type_fn: Callable[[_spec.Type], "TypeAdapter"],