maxframe 0.1.0b5__cp39-cp39-win_amd64.whl → 1.0.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (47)
  1. maxframe/_utils.cp39-win_amd64.pyd +0 -0
  2. maxframe/codegen.py +10 -2
  3. maxframe/config/config.py +4 -0
  4. maxframe/core/__init__.py +0 -3
  5. maxframe/core/entity/__init__.py +1 -8
  6. maxframe/core/entity/objects.py +3 -45
  7. maxframe/core/graph/core.cp39-win_amd64.pyd +0 -0
  8. maxframe/core/graph/core.pyx +4 -4
  9. maxframe/dataframe/datastore/tests/__init__.py +13 -0
  10. maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
  11. maxframe/dataframe/datastore/to_odps.py +21 -0
  12. maxframe/dataframe/indexing/align.py +1 -1
  13. maxframe/dataframe/misc/apply.py +2 -0
  14. maxframe/dataframe/misc/memory_usage.py +2 -2
  15. maxframe/dataframe/misc/tests/test_misc.py +23 -0
  16. maxframe/dataframe/statistics/corr.py +3 -3
  17. maxframe/errors.py +13 -0
  18. maxframe/extension.py +12 -0
  19. maxframe/lib/mmh3.cp39-win_amd64.pyd +0 -0
  20. maxframe/lib/mmh3.pyi +43 -0
  21. maxframe/lib/wrapped_pickle.py +2 -1
  22. maxframe/protocol.py +108 -10
  23. maxframe/serialization/core.cp39-win_amd64.pyd +0 -0
  24. maxframe/serialization/core.pxd +3 -0
  25. maxframe/serialization/core.pyi +3 -0
  26. maxframe/serialization/core.pyx +54 -25
  27. maxframe/serialization/exception.py +1 -1
  28. maxframe/serialization/pandas.py +7 -2
  29. maxframe/serialization/serializables/core.py +119 -12
  30. maxframe/serialization/serializables/tests/test_serializable.py +46 -4
  31. maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
  32. maxframe/tensor/base/atleast_1d.py +1 -1
  33. maxframe/tensor/base/unique.py +1 -1
  34. maxframe/tensor/reduction/count_nonzero.py +1 -1
  35. maxframe/tests/test_protocol.py +34 -0
  36. maxframe/tests/test_utils.py +0 -12
  37. maxframe/tests/utils.py +2 -2
  38. maxframe/utils.py +16 -13
  39. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc1.dist-info}/METADATA +2 -2
  40. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc1.dist-info}/RECORD +46 -44
  41. maxframe_client/__init__.py +0 -1
  42. maxframe_client/session/odps.py +45 -5
  43. maxframe_client/session/task.py +41 -20
  44. maxframe_client/tests/test_session.py +36 -0
  45. maxframe_client/clients/spe.py +0 -104
  46. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc1.dist-info}/WHEEL +0 -0
  47. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc1.dist-info}/top_level.txt +0 -0
maxframe/protocol.py CHANGED
@@ -32,6 +32,7 @@ from .serialization.serializables import (
     EnumField,
     FieldTypes,
     Float64Field,
+    Int32Field,
     ListField,
     ReferenceField,
     Serializable,
@@ -71,6 +72,9 @@ class DagStatus(enum.Enum):
     CANCELLING = 4
     CANCELLED = 5
 
+    def is_terminated(self):
+        return self in (DagStatus.CANCELLED, DagStatus.SUCCEEDED, DagStatus.FAILED)
+
 
 class DimensionIndex(Serializable):
     is_slice: bool = BoolField("is_slice", default=None)
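
Note: the new `DagStatus.is_terminated` gives callers a single test covering all terminal states (SUCCEEDED, FAILED, CANCELLED). A minimal sketch of client-side polling built on it; `get_dag_info` is a hypothetical stand-in for whatever call fetches the current `DagInfo`:

    import time

    from maxframe.protocol import DagStatus

    def wait_for_dag(get_dag_info, interval: float = 1.0):
        # Poll until the DAG reaches a terminal status.
        while True:
            info = get_dag_info()
            if info.status.is_terminated():
                return info
            time.sleep(interval)
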
@@ -190,9 +194,9 @@ class ErrorInfo(JsonSerializable):
         "error_tracebacks", FieldTypes.list
     )
     raw_error_source: ErrorSource = EnumField(
-        "raw_error_source", ErrorSource, FieldTypes.int8
+        "raw_error_source", ErrorSource, FieldTypes.int8, default=None
     )
-    raw_error_data: Optional[Exception] = AnyField("raw_error_data")
+    raw_error_data: Optional[Exception] = AnyField("raw_error_data", default=None)
 
     @classmethod
     def from_exception(cls, exc: Exception):
@@ -201,20 +205,29 @@ class ErrorInfo(JsonSerializable):
         return cls(messages, tracebacks, ErrorSource.PYTHON, exc)
 
     def reraise(self):
-        if self.raw_error_source == ErrorSource.PYTHON:
+        if (
+            self.raw_error_source == ErrorSource.PYTHON
+            and self.raw_error_data is not None
+        ):
             raise self.raw_error_data
         raise RemoteException(self.error_messages, self.error_tracebacks, [])
 
     @classmethod
     def from_json(cls, serialized: dict) -> "ErrorInfo":
         kw = serialized.copy()
-        kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
+        if kw.get("raw_error_source") is not None:
+            kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
+        else:
+            kw["raw_error_source"] = None
+
         if kw.get("raw_error_data"):
             bufs = [base64.b64decode(s) for s in kw["raw_error_data"]]
             try:
                 kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
             except:
-                kw["raw_error_data"] = None
+                # both error source and data shall be None to make sure
+                # RemoteException is raised.
+                kw["raw_error_source"] = kw["raw_error_data"] = None
         return cls(**kw)
 
     def to_json(self) -> dict:
@@ -227,7 +240,12 @@ class ErrorInfo(JsonSerializable):
         if isinstance(self.raw_error_data, (PickleContainer, RemoteException)):
             err_data_bufs = self.raw_error_data.get_buffers()
         elif isinstance(self.raw_error_data, BaseException):
-            err_data_bufs = pickle_buffers(self.raw_error_data)
+            try:
+                err_data_bufs = pickle_buffers(self.raw_error_data)
+            except:
+                err_data_bufs = None
+                ret["raw_error_source"] = None
+
         if err_data_bufs:
             ret["raw_error_data"] = [
                 base64.b64encode(s).decode() for s in err_data_bufs
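
Taken together, the `ErrorInfo` changes make error transport best-effort: if the raw exception cannot be pickled on `to_json`, or cannot be unpickled on `from_json`, both `raw_error_source` and `raw_error_data` are nulled so that `reraise` degrades to a generic `RemoteException` instead of failing outright. A contrived sketch of that degradation, assuming the elided parts of `ErrorInfo` behave as in earlier releases:

    import threading

    from maxframe.protocol import ErrorInfo

    class UnpicklableError(Exception):
        def __init__(self):
            super().__init__("boom")
            self.lock = threading.Lock()  # thread locks cannot be pickled

    info = ErrorInfo.from_exception(UnpicklableError())
    payload = info.to_json()              # pickling fails; raw_error_source -> None
    restored = ErrorInfo.from_json(payload)
    try:
        restored.reraise()
    except Exception as e:
        print(type(e).__name__)           # RemoteException, not UnpicklableError
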
@@ -249,9 +267,17 @@ class DagInfo(JsonSerializable):
     error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
     start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
     end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
+    subdag_infos: Dict[str, "SubDagInfo"] = DictField(
+        "subdag_infos",
+        key_type=FieldTypes.string,
+        value_type=FieldTypes.reference,
+        default_factory=dict,
+    )
 
     @classmethod
-    def from_json(cls, serialized: dict) -> "DagInfo":
+    def from_json(cls, serialized: dict) -> Optional["DagInfo"]:
+        if serialized is None:
+            return None
         kw = serialized.copy()
         kw["status"] = DagStatus(kw["status"])
         if kw.get("tileable_to_result_infos"):
@@ -261,6 +287,10 @@ class DagInfo(JsonSerializable):
             }
         if kw.get("error_info"):
             kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
+        if kw.get("subdag_infos"):
+            kw["subdag_infos"] = {
+                k: SubDagInfo.from_json(v) for k, v in kw["subdag_infos"].items()
+            }
         return DagInfo(**kw)
 
     def to_json(self) -> dict:
@@ -279,6 +309,8 @@ class DagInfo(JsonSerializable):
             }
         if self.error_info:
             ret["error_info"] = self.error_info.to_json()
+        if self.subdag_infos:
+            ret["subdag_infos"] = {k: v.to_json() for k, v in self.subdag_infos.items()}
         return ret
 
 
@@ -302,7 +334,9 @@ class SessionInfo(JsonSerializable):
     error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
 
     @classmethod
-    def from_json(cls, serialized: dict) -> "SessionInfo":
+    def from_json(cls, serialized: dict) -> Optional["SessionInfo"]:
+        if serialized is None:
+            return None
         kw = serialized.copy()
         if kw.get("dag_infos"):
             kw["dag_infos"] = {
@@ -320,7 +354,10 @@ class SessionInfo(JsonSerializable):
             "idle_timestamp": self.idle_timestamp,
         }
         if self.dag_infos:
-            ret["dag_infos"] = {k: v.to_json() for k, v in self.dag_infos.items()}
+            ret["dag_infos"] = {
+                k: v.to_json() if v is not None else None
+                for k, v in self.dag_infos.items()
+            }
         if self.error_info:
             ret["error_info"] = self.error_info.to_json()
         return ret
@@ -342,7 +379,25 @@ class ExecuteDagRequest(Serializable):
     )
 
 
-class SubDagInfo(Serializable):
+class SubDagSubmitInstanceInfo(JsonSerializable):
+    submit_reason: str = StringField("submit_reason")
+    instance_id: str = StringField("instance_id")
+    subquery_id: Optional[int] = Int32Field("subquery_id", default=None)
+
+    @classmethod
+    def from_json(cls, serialized: dict) -> "SubDagSubmitInstanceInfo":
+        return SubDagSubmitInstanceInfo(**serialized)
+
+    def to_json(self) -> dict:
+        ret = {
+            "submit_reason": self.submit_reason,
+            "instance_id": self.instance_id,
+            "subquery_id": self.subquery_id,
+        }
+        return ret
+
+
+class SubDagInfo(JsonSerializable):
     subdag_id: str = StringField("subdag_id")
     status: DagStatus = EnumField("status", DagStatus, FieldTypes.int8, default=None)
     progress: float = Float64Field("progress", default=None)
@@ -355,9 +410,52 @@ class SubDagInfo(Serializable):
         FieldTypes.reference,
         default_factory=dict,
     )
+    start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
+    end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
+    submit_instances: List[SubDagSubmitInstanceInfo] = ListField(
+        "submit_instances",
+        FieldTypes.reference,
+        default_factory=list,
+    )
+
+    @classmethod
+    def from_json(cls, serialized: dict) -> "SubDagInfo":
+        kw = serialized.copy()
+        kw["status"] = DagStatus(kw["status"])
+        if kw.get("tileable_to_result_infos"):
+            kw["tileable_to_result_infos"] = {
+                k: ResultInfo.from_json(s)
+                for k, s in kw["tileable_to_result_infos"].items()
+            }
+        if kw.get("error_info"):
+            kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
+        if kw.get("submit_instances"):
+            kw["submit_instances"] = [
+                SubDagSubmitInstanceInfo.from_json(s) for s in kw["submit_instances"]
+            ]
+        return SubDagInfo(**kw)
+
+    def to_json(self) -> dict:
+        ret = {
+            "subdag_id": self.subdag_id,
+            "status": self.status.value,
+            "progress": self.progress,
+            "start_timestamp": self.start_timestamp,
+            "end_timestamp": self.end_timestamp,
+        }
+        if self.error_info:
+            ret["error_info"] = self.error_info.to_json()
+        if self.tileable_to_result_infos:
+            ret["tileable_to_result_infos"] = {
+                k: v.to_json() for k, v in self.tileable_to_result_infos.items()
+            }
+        if self.submit_instances:
+            ret["submit_instances"] = [i.to_json() for i in self.submit_instances]
+        return ret
 
 
 class ExecuteSubDagRequest(Serializable):
+    subdag_id: str = StringField("subdag_id")
     dag: TileableGraph = ReferenceField(
         "dag",
         on_serialize=SerializableGraph.from_graph,
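
`DagInfo` now aggregates per-subdag progress, and each `SubDagInfo` records the instances submitted on its behalf. A hedged round-trip sketch; field values are placeholders, and `DagStatus.SUCCEEDED` is one of the enum members referenced in `is_terminated` above:

    from maxframe.protocol import DagStatus, SubDagInfo, SubDagSubmitInstanceInfo

    inst = SubDagSubmitInstanceInfo(
        submit_reason="initial",            # placeholder reason
        instance_id="<odps-instance-id>",   # placeholder id
        subquery_id=0,
    )
    sub = SubDagInfo(
        subdag_id="subdag_0",
        status=DagStatus.SUCCEEDED,
        progress=1.0,
        submit_instances=[inst],
    )
    restored = SubDagInfo.from_json(sub.to_json())
    assert restored.submit_instances[0].instance_id == "<odps-instance-id>"
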
maxframe/serialization/core.pxd CHANGED
@@ -18,6 +18,9 @@ from libc.stdint cimport int32_t, uint64_t
 cdef class Serializer:
     cdef int _serializer_id
 
+    cpdef bint is_public_data_exist(self, dict context, object key)
+    cpdef put_public_data(self, dict context, object key, object value)
+    cpdef get_public_data(self, dict context, object key)
     cpdef serial(self, object obj, dict context)
     cpdef deserial(self, list serialized, dict context, list subs)
     cpdef on_deserial_error(
maxframe/serialization/core.pyi CHANGED
@@ -29,6 +29,9 @@ class PickleContainer:
 
 class Serializer:
     serializer_id: int
+    def is_public_data_exist(self, context: Dict, key: Any) -> bool: ...
+    def put_public_data(self, context: Dict, key: Any, value: Any) -> None: ...
+    def get_public_data(self, context: Dict, key: Any) -> Any: ...
     def serial(self, obj: Any, context: Dict): ...
     def deserial(self, serialized: List, context: Dict, subs: List[Any]): ...
     def on_deserial_error(
maxframe/serialization/core.pyx CHANGED
@@ -130,11 +130,30 @@ cdef Serializer get_deserializer(int32_t deserializer_id):
 
 cdef class Serializer:
     serializer_id = None
+    _public_data_context_key = 0x7fffffff - 1
 
     def __cinit__(self):
         # make the value can be referenced with C code
         self._serializer_id = self.serializer_id
 
+    cpdef bint is_public_data_exist(self, dict context, object key):
+        cdef dict public_dict = context.get(self._public_data_context_key, None)
+        if public_dict is None:
+            return False
+        return key in public_dict
+
+    cpdef put_public_data(self, dict context, object key, object value):
+        cdef dict public_dict = context.get(self._public_data_context_key, None)
+        if public_dict is None:
+            public_dict = context[self._public_data_context_key] = {}
+        public_dict[key] = value
+
+    cpdef get_public_data(self, dict context, object key):
+        cdef dict public_dict = context.get(self._public_data_context_key, None)
+        if public_dict is None:
+            return None
+        return public_dict.get(key)
+
     cpdef serial(self, object obj, dict context):
         """
         Returns intermediate serialization result of certain object.
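
These three `cpdef` methods give every `Serializer` access to a shared "public data" area inside the serialization context: a plain dict stashed under a reserved integer key (`0x7ffffffe`) so it cannot collide with per-object context entries. A pure-Python paraphrase of the compiled code above, not the implementation itself:

    _PUBLIC_KEY = 0x7FFFFFFF - 1  # Serializer._public_data_context_key

    def put_public_data(context: dict, key, value) -> None:
        context.setdefault(_PUBLIC_KEY, {})[key] = value

    def get_public_data(context: dict, key):
        return context.get(_PUBLIC_KEY, {}).get(key)

    def is_public_data_exist(context: dict, key) -> bool:
        return key in context.get(_PUBLIC_KEY, {})
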
@@ -993,17 +1012,20 @@ def serialize(obj, dict context = None):
     cdef list subs
     cdef bint final
     cdef _IdContextHolder id_context_holder = _IdContextHolder()
+    cdef tuple result
 
     context = context if context is not None else dict()
     serialized, subs, final = _serial_single(obj, context, id_context_holder)
     if final or not subs:
         # marked as a leaf node, return directly
-        return [{}, serialized], subs
-
-    serial_stack.append(_SerialStackItem(serialized, subs))
-    return _serialize_with_stack(
-        serial_stack, None, context, id_context_holder, result_bufs_list
-    )
+        result = [{}, serialized], subs
+    else:
+        serial_stack.append(_SerialStackItem(serialized, subs))
+        result = _serialize_with_stack(
+            serial_stack, None, context, id_context_holder, result_bufs_list
+        )
+    result[0][0]["_PUB"] = context.get(Serializer._public_data_context_key)
+    return result
 
 
 async def serialize_with_spawn(
@@ -1036,31 +1058,38 @@ async def serialize_with_spawn(
     cdef list subs
     cdef bint final
     cdef _IdContextHolder id_context_holder = _IdContextHolder()
+    cdef tuple result
 
     context = context if context is not None else dict()
     serialized, subs, final = _serial_single(obj, context, id_context_holder)
     if final or not subs:
         # marked as a leaf node, return directly
-        return [{}, serialized], subs
-
-    serial_stack.append(_SerialStackItem(serialized, subs))
+        result = [{}, serialized], subs
+    else:
+        serial_stack.append(_SerialStackItem(serialized, subs))
 
-    try:
-        result = _serialize_with_stack(
-            serial_stack, None, context, id_context_holder, result_bufs_list, spawn_threshold
-        )
-    except _SerializeObjectOverflow as ex:
-        result = await asyncio.get_running_loop().run_in_executor(
-            executor,
-            _serialize_with_stack,
-            serial_stack,
-            ex.cur_serialized,
-            context,
-            id_context_holder,
-            result_bufs_list,
-            0,
-            ex.num_total_serialized,
-        )
+        try:
+            result = _serialize_with_stack(
+                serial_stack,
+                None,
+                context,
+                id_context_holder,
+                result_bufs_list,
+                spawn_threshold,
+            )
+        except _SerializeObjectOverflow as ex:
+            result = await asyncio.get_running_loop().run_in_executor(
+                executor,
+                _serialize_with_stack,
+                serial_stack,
+                ex.cur_serialized,
+                context,
+                id_context_holder,
+                result_bufs_list,
+                0,
+                ex.num_total_serialized,
+            )
+    result[0][0]["_PUB"] = context.get(Serializer._public_data_context_key)
    return result
 
 
maxframe/serialization/exception.py CHANGED
@@ -35,7 +35,7 @@ class RemoteException(MaxFrameError):
     def from_exception(cls, exc: Exception):
         try:
             buffers = pickle_buffers(exc)
-        except (TypeError, pickle.PicklingError):
+        except:
             logger.exception("Cannot pickle exception %s", exc)
             buffers = []
 
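
The widened `except` acknowledges that pickling can raise arbitrary exception types out of an object's `__reduce__`, not just `TypeError` or `pickle.PicklingError`. For instance:

    import pickle

    class StubbornError(Exception):
        def __reduce__(self):
            # raises neither TypeError nor pickle.PicklingError
            raise RuntimeError("refusing to be pickled")

    try:
        pickle.dumps(StubbornError())
    except RuntimeError:
        # the old narrow clause missed this; the bare except now catches it
        pass
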
maxframe/serialization/pandas.py CHANGED
@@ -176,11 +176,16 @@ class PdTimestampSerializer(Serializer):
 
 class PdTimedeltaSerializer(Serializer):
     def serial(self, obj: pd.Timedelta, context: Dict):
-        return [int(obj.seconds), obj.microseconds, obj.nanoseconds], [], True
+        return [int(obj.seconds), obj.microseconds, obj.nanoseconds, obj.days], [], True
 
     def deserial(self, serialized: List, context: Dict, subs: List):
+        days = 0 if len(serialized) < 4 else serialized[3]
+        seconds, microseconds, nanoseconds = serialized[:3]
         return pd.Timedelta(
-            seconds=serialized[0], microseconds=serialized[1], nanoseconds=serialized[2]
+            days=days,
+            seconds=seconds,
+            microseconds=microseconds,
+            nanoseconds=nanoseconds,
         )
 
 
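
This fixes a real round-trip bug: `pd.Timedelta.seconds` only carries the sub-day remainder (0 to 86399), so any timedelta of one day or more silently lost its `days` component. The fourth list element preserves it, and `deserial` defaults `days` to 0 so three-element payloads from older clients still decode. An illustration of the underlying pandas behavior:

    import pandas as pd

    td = pd.Timedelta(days=2, hours=3)
    old = pd.Timedelta(seconds=td.seconds, microseconds=td.microseconds,
                       nanoseconds=td.nanoseconds)   # 3 hours: the 2 days are gone
    new = pd.Timedelta(days=td.days, seconds=td.seconds,
                       microseconds=td.microseconds, nanoseconds=td.nanoseconds)
    assert new == td and old != td
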
maxframe/serialization/serializables/core.py CHANGED
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import operator
 import weakref
-from typing import Dict, List, Tuple, Type
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import msgpack
 
+from ...lib.mmh3 import hash
 from ..core import Placeholder, Serializer, buffered, load_type
 from .field import Field
 from .field_type import DictType, ListType, PrimitiveFieldType, TupleType
@@ -50,11 +51,16 @@ def _is_field_primitive_compound(field: Field):
 class SerializableMeta(type):
     def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
         # All the fields including misc fields.
+        name_hash = hash(f"{properties.get('__module__')}.{name}")
         all_fields = dict()
+        # mapping field names to base classes
+        field_to_cls_hash = dict()
 
         for base in bases:
-            if hasattr(base, "_FIELDS"):
-                all_fields.update(base._FIELDS)
+            if not hasattr(base, "_FIELDS"):
+                continue
+            all_fields.update(base._FIELDS)
+            field_to_cls_hash.update(base._FIELD_TO_NAME_HASH)
 
         properties_without_fields = {}
         properties_field_slot_names = []
@@ -64,6 +70,8 @@ class SerializableMeta(type):
                 continue
 
             field = all_fields.get(k)
+            # record the field for the class being created
+            field_to_cls_hash[k] = name_hash
             if field is None:
                 properties_field_slot_names.append(k)
             else:
@@ -75,23 +83,40 @@ class SerializableMeta(type):
 
         # Make field order deterministic to serialize it as list instead of dict.
         field_order = list(all_fields)
-        all_fields = dict(sorted(all_fields.items(), key=operator.itemgetter(0)))
         primitive_fields = []
+        primitive_field_names = set()
         non_primitive_fields = []
-        for v in all_fields.values():
+        for field_name, v in all_fields.items():
             if _is_field_primitive_compound(v):
                 primitive_fields.append(v)
+                primitive_field_names.add(field_name)
             else:
                 non_primitive_fields.append(v)
 
+        # count number of fields for every base class
+        cls_to_primitive_field_count = defaultdict(lambda: 0)
+        cls_to_non_primitive_field_count = defaultdict(lambda: 0)
+        for field_name in field_order:
+            cls_hash = field_to_cls_hash[field_name]
+            if field_name in primitive_field_names:
+                cls_to_primitive_field_count[cls_hash] += 1
+            else:
+                cls_to_non_primitive_field_count[cls_hash] += 1
+
         slots = set(properties.pop("__slots__", set()))
         slots.update(properties_field_slot_names)
 
         properties = properties_without_fields
+        properties["_NAME_HASH"] = name_hash
         properties["_FIELDS"] = all_fields
         properties["_FIELD_ORDER"] = field_order
+        properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
         properties["_PRIMITIVE_FIELDS"] = primitive_fields
+        properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = dict(cls_to_primitive_field_count)
         properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
+        properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = dict(
+            cls_to_non_primitive_field_count
+        )
         properties["__slots__"] = tuple(slots)
 
         clz = type.__new__(mcs, name, bases, properties)
@@ -114,10 +139,14 @@ class Serializable(metaclass=SerializableMeta):
     _cache_primitive_serial = False
     _ignore_non_existing_keys = False
 
+    _NAME_HASH: int
     _FIELDS: Dict[str, Field]
     _FIELD_ORDER: List[str]
+    _FIELD_TO_NAME_HASH: Dict[str, int]
     _PRIMITIVE_FIELDS: List[str]
+    _CLS_TO_PRIMITIVE_FIELD_COUNT: Dict[int, int]
     _NON_PRIMITIVE_FIELDS: List[str]
+    _CLS_TO_NON_PRIMITIVE_FIELD_COUNT: Dict[int, int]
 
     def __init__(self, *args, **kwargs):
         fields = self._FIELDS
@@ -180,6 +209,10 @@ class SerializableSerializer(Serializer):
     Leverage DictSerializer to perform serde.
     """
 
+    @classmethod
+    def _get_obj_field_count_key(cls, obj: Serializable):
+        return f"FC_{obj._NAME_HASH}"
+
     @classmethod
     def _get_field_values(cls, obj: Serializable, fields):
         values = []
@@ -210,6 +243,18 @@ class SerializableSerializer(Serializer):
 
         compound_vals = self._get_field_values(obj, obj._NON_PRIMITIVE_FIELDS)
         cls_module = f"{type(obj).__module__}#{type(obj).__qualname__}"
+
+        field_count_key = self._get_obj_field_count_key(obj)
+        if not self.is_public_data_exist(context, field_count_key):
+            # store field distribution for current Serializable
+            counts = [
+                list(obj._CLS_TO_PRIMITIVE_FIELD_COUNT.items()),
+                list(obj._CLS_TO_NON_PRIMITIVE_FIELD_COUNT.items()),
+            ]
+            field_count_data = msgpack.dumps(counts)
+            self.put_public_data(
+                context, self._get_obj_field_count_key(obj), field_count_data
+            )
         return [cls_module, primitive_vals], [compound_vals], False
 
     @staticmethod
@@ -229,6 +274,62 @@ class SerializableSerializer(Serializer):
         else:
             field.set(obj, value)
 
+    @classmethod
+    def _set_field_values(
+        cls,
+        obj: Serializable,
+        values: List[Any],
+        client_cls_to_field_count: Optional[Dict[str, int]],
+        is_primitive: bool = True,
+    ):
+        obj_class = type(obj)
+        if is_primitive:
+            server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
+            server_fields = obj_class._PRIMITIVE_FIELDS
+        else:
+            server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
+            server_fields = obj_class._NON_PRIMITIVE_FIELDS
+
+        if client_cls_to_field_count:
+            field_num, server_field_num = 0, 0
+            for cls_hash, count in client_cls_to_field_count.items():
+                # cut values and fields given field distribution
+                # at client and server end
+                cls_fields = server_fields[server_field_num : field_num + count]
+                cls_values = values[field_num : field_num + count]
+                for field, value in zip(cls_fields, cls_values):
+                    if not is_primitive or value != {}:
+                        cls._set_field_value(obj, field, value)
+                field_num += count
+                server_field_num += server_cls_to_field_count[cls_hash]
+        else:
+            # todo remove this branch when all versions below v0.1.0b5 is eliminated
+            from .field import AnyField
+
+            # legacy serialization style, with all fields sorted by name
+            if is_primitive:
+                field_attr = "_legacy_deprecated_primitives"
+            else:
+                field_attr = "_legacy_deprecated_non_primitives"
+            deprecated_fields = []
+            deprecated_names = set()
+            if hasattr(obj_class, field_attr):
+                deprecated_names = set(getattr(obj_class, field_attr))
+            for field_name in deprecated_names:
+                field = AnyField(tag=field_name)
+                field.name = field_name
+                deprecated_fields.append(field)
+            server_fields = sorted(
+                server_fields + deprecated_fields, key=lambda f: f.name
+            )
+            for field, value in zip(server_fields, values):
+                if not is_primitive or value != {}:
+                    try:
+                        cls._set_field_value(obj, field, value)
+                    except AttributeError:  # pragma: no cover
+                        if field.name not in deprecated_names:
+                            raise
+
     def deserial(self, serialized: List, context: Dict, subs: List) -> Serializable:
         obj_class_name, primitives = serialized
         obj_class = load_type(obj_class_name, Serializable)
@@ -238,14 +339,20 @@ class SerializableSerializer(Serializer):
 
         obj = obj_class.__new__(obj_class)
 
-        if primitives:
-            for field, value in zip(obj_class._PRIMITIVE_FIELDS, primitives):
-                if value != {}:
-                    self._set_field_value(obj, field, value)
+        field_count_data = self.get_public_data(
+            context, self._get_obj_field_count_key(obj)
+        )
+        if field_count_data is not None:
+            cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
+            cls_to_prim_key = dict(cls_to_prim_key)
+            cls_to_non_prim_key = dict(cls_to_non_prim_key)
+        else:
+            cls_to_prim_key, cls_to_non_prim_key = None, None
 
+        if primitives:
+            self._set_field_values(obj, primitives, cls_to_prim_key, True)
         if obj_class._NON_PRIMITIVE_FIELDS:
-            for field, value in zip(obj_class._NON_PRIMITIVE_FIELDS, subs[0]):
-                self._set_field_value(obj, field, value)
+            self._set_field_values(obj, subs[0], cls_to_non_prim_key, False)
         obj.__on_deserialize__()
        return obj
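
The net effect of the `serializables/core.py` changes is version-tolerant field matching. The metaclass hashes each class's qualified name with mmh3 and records how many primitive and non-primitive fields every class in the hierarchy contributes; `serial` publishes those counts once per type (msgpack-encoded, keyed `FC_<name_hash>`) through the public-data channel carried in `"_PUB"`, and `deserial` uses the sender's counts to slice the incoming value list class-by-class against the receiver's own layout; with no counts present it falls back to the legacy name-sorted order. A pure-Python paraphrase of the alignment loop, for a sender that predates a field-contributing class known to the receiver (hashes and names here are illustrative, not real values):

    client_counts = {0xA: 2, 0xB: 1}          # counts shipped via "_PUB"
    server_counts = {0xA: 2, 0xB: 1, 0xC: 1}  # receiver's own layout
    server_fields = ["a1", "a2", "b1", "c1"]  # receiver declaration order
    values = ["va1", "va2", "vb1"]            # sender never knew about c1

    field_num = server_field_num = 0
    for cls_hash, count in client_counts.items():
        cls_fields = server_fields[server_field_num : field_num + count]
        cls_values = values[field_num : field_num + count]
        print(list(zip(cls_fields, cls_values)))
        field_num += count
        server_field_num += server_counts[cls_hash]
    # [('a1', 'va1'), ('a2', 'va2')] then [('b1', 'vb1')]; c1 keeps its default
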