maxframe 0.1.0b5__cp310-cp310-win32.whl → 1.0.0rc1__cp310-cp310-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/_utils.cp310-win32.pyd +0 -0
- maxframe/codegen.py +10 -2
- maxframe/config/config.py +4 -0
- maxframe/core/__init__.py +0 -3
- maxframe/core/entity/__init__.py +1 -8
- maxframe/core/entity/objects.py +3 -45
- maxframe/core/graph/core.cp310-win32.pyd +0 -0
- maxframe/core/graph/core.pyx +4 -4
- maxframe/dataframe/datastore/tests/__init__.py +13 -0
- maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
- maxframe/dataframe/datastore/to_odps.py +21 -0
- maxframe/dataframe/indexing/align.py +1 -1
- maxframe/dataframe/misc/apply.py +2 -0
- maxframe/dataframe/misc/memory_usage.py +2 -2
- maxframe/dataframe/misc/tests/test_misc.py +23 -0
- maxframe/dataframe/statistics/corr.py +3 -3
- maxframe/errors.py +13 -0
- maxframe/extension.py +12 -0
- maxframe/lib/mmh3.cp310-win32.pyd +0 -0
- maxframe/lib/mmh3.pyi +43 -0
- maxframe/lib/wrapped_pickle.py +2 -1
- maxframe/protocol.py +108 -10
- maxframe/serialization/core.cp310-win32.pyd +0 -0
- maxframe/serialization/core.pxd +3 -0
- maxframe/serialization/core.pyi +3 -0
- maxframe/serialization/core.pyx +54 -25
- maxframe/serialization/exception.py +1 -1
- maxframe/serialization/pandas.py +7 -2
- maxframe/serialization/serializables/core.py +119 -12
- maxframe/serialization/serializables/tests/test_serializable.py +46 -4
- maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
- maxframe/tensor/base/atleast_1d.py +1 -1
- maxframe/tensor/base/unique.py +1 -1
- maxframe/tensor/reduction/count_nonzero.py +1 -1
- maxframe/tests/test_protocol.py +34 -0
- maxframe/tests/test_utils.py +0 -12
- maxframe/tests/utils.py +2 -2
- maxframe/utils.py +16 -13
- {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc1.dist-info}/METADATA +2 -2
- {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc1.dist-info}/RECORD +46 -44
- maxframe_client/__init__.py +0 -1
- maxframe_client/session/odps.py +45 -5
- maxframe_client/session/task.py +41 -20
- maxframe_client/tests/test_session.py +36 -0
- maxframe_client/clients/spe.py +0 -104
- {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc1.dist-info}/WHEEL +0 -0
- {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc1.dist-info}/top_level.txt +0 -0
maxframe/protocol.py
CHANGED
|
@@ -32,6 +32,7 @@ from .serialization.serializables import (
|
|
|
32
32
|
EnumField,
|
|
33
33
|
FieldTypes,
|
|
34
34
|
Float64Field,
|
|
35
|
+
Int32Field,
|
|
35
36
|
ListField,
|
|
36
37
|
ReferenceField,
|
|
37
38
|
Serializable,
|
|
@@ -71,6 +72,9 @@ class DagStatus(enum.Enum):
|
|
|
71
72
|
CANCELLING = 4
|
|
72
73
|
CANCELLED = 5
|
|
73
74
|
|
|
75
|
+
def is_terminated(self):
|
|
76
|
+
return self in (DagStatus.CANCELLED, DagStatus.SUCCEEDED, DagStatus.FAILED)
|
|
77
|
+
|
|
74
78
|
|
|
75
79
|
class DimensionIndex(Serializable):
|
|
76
80
|
is_slice: bool = BoolField("is_slice", default=None)
|
|
@@ -190,9 +194,9 @@ class ErrorInfo(JsonSerializable):
|
|
|
190
194
|
"error_tracebacks", FieldTypes.list
|
|
191
195
|
)
|
|
192
196
|
raw_error_source: ErrorSource = EnumField(
|
|
193
|
-
"raw_error_source", ErrorSource, FieldTypes.int8
|
|
197
|
+
"raw_error_source", ErrorSource, FieldTypes.int8, default=None
|
|
194
198
|
)
|
|
195
|
-
raw_error_data: Optional[Exception] = AnyField("raw_error_data")
|
|
199
|
+
raw_error_data: Optional[Exception] = AnyField("raw_error_data", default=None)
|
|
196
200
|
|
|
197
201
|
@classmethod
|
|
198
202
|
def from_exception(cls, exc: Exception):
|
|
@@ -201,20 +205,29 @@ class ErrorInfo(JsonSerializable):
|
|
|
201
205
|
return cls(messages, tracebacks, ErrorSource.PYTHON, exc)
|
|
202
206
|
|
|
203
207
|
def reraise(self):
|
|
204
|
-
if
|
|
208
|
+
if (
|
|
209
|
+
self.raw_error_source == ErrorSource.PYTHON
|
|
210
|
+
and self.raw_error_data is not None
|
|
211
|
+
):
|
|
205
212
|
raise self.raw_error_data
|
|
206
213
|
raise RemoteException(self.error_messages, self.error_tracebacks, [])
|
|
207
214
|
|
|
208
215
|
@classmethod
|
|
209
216
|
def from_json(cls, serialized: dict) -> "ErrorInfo":
|
|
210
217
|
kw = serialized.copy()
|
|
211
|
-
kw
|
|
218
|
+
if kw.get("raw_error_source") is not None:
|
|
219
|
+
kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
|
|
220
|
+
else:
|
|
221
|
+
kw["raw_error_source"] = None
|
|
222
|
+
|
|
212
223
|
if kw.get("raw_error_data"):
|
|
213
224
|
bufs = [base64.b64decode(s) for s in kw["raw_error_data"]]
|
|
214
225
|
try:
|
|
215
226
|
kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
|
|
216
227
|
except:
|
|
217
|
-
|
|
228
|
+
# both error source and data shall be None to make sure
|
|
229
|
+
# RemoteException is raised.
|
|
230
|
+
kw["raw_error_source"] = kw["raw_error_data"] = None
|
|
218
231
|
return cls(**kw)
|
|
219
232
|
|
|
220
233
|
def to_json(self) -> dict:
|
|
@@ -227,7 +240,12 @@ class ErrorInfo(JsonSerializable):
|
|
|
227
240
|
if isinstance(self.raw_error_data, (PickleContainer, RemoteException)):
|
|
228
241
|
err_data_bufs = self.raw_error_data.get_buffers()
|
|
229
242
|
elif isinstance(self.raw_error_data, BaseException):
|
|
230
|
-
|
|
243
|
+
try:
|
|
244
|
+
err_data_bufs = pickle_buffers(self.raw_error_data)
|
|
245
|
+
except:
|
|
246
|
+
err_data_bufs = None
|
|
247
|
+
ret["raw_error_source"] = None
|
|
248
|
+
|
|
231
249
|
if err_data_bufs:
|
|
232
250
|
ret["raw_error_data"] = [
|
|
233
251
|
base64.b64encode(s).decode() for s in err_data_bufs
|
|
@@ -249,9 +267,17 @@ class DagInfo(JsonSerializable):
|
|
|
249
267
|
error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
|
|
250
268
|
start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
|
|
251
269
|
end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
|
|
270
|
+
subdag_infos: Dict[str, "SubDagInfo"] = DictField(
|
|
271
|
+
"subdag_infos",
|
|
272
|
+
key_type=FieldTypes.string,
|
|
273
|
+
value_type=FieldTypes.reference,
|
|
274
|
+
default_factory=dict,
|
|
275
|
+
)
|
|
252
276
|
|
|
253
277
|
@classmethod
|
|
254
|
-
def from_json(cls, serialized: dict) -> "DagInfo":
|
|
278
|
+
def from_json(cls, serialized: dict) -> Optional["DagInfo"]:
|
|
279
|
+
if serialized is None:
|
|
280
|
+
return None
|
|
255
281
|
kw = serialized.copy()
|
|
256
282
|
kw["status"] = DagStatus(kw["status"])
|
|
257
283
|
if kw.get("tileable_to_result_infos"):
|
|
@@ -261,6 +287,10 @@ class DagInfo(JsonSerializable):
|
|
|
261
287
|
}
|
|
262
288
|
if kw.get("error_info"):
|
|
263
289
|
kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
|
|
290
|
+
if kw.get("subdag_infos"):
|
|
291
|
+
kw["subdag_infos"] = {
|
|
292
|
+
k: SubDagInfo.from_json(v) for k, v in kw["subdag_infos"].items()
|
|
293
|
+
}
|
|
264
294
|
return DagInfo(**kw)
|
|
265
295
|
|
|
266
296
|
def to_json(self) -> dict:
|
|
@@ -279,6 +309,8 @@ class DagInfo(JsonSerializable):
|
|
|
279
309
|
}
|
|
280
310
|
if self.error_info:
|
|
281
311
|
ret["error_info"] = self.error_info.to_json()
|
|
312
|
+
if self.subdag_infos:
|
|
313
|
+
ret["subdag_infos"] = {k: v.to_json() for k, v in self.subdag_infos.items()}
|
|
282
314
|
return ret
|
|
283
315
|
|
|
284
316
|
|
|
@@ -302,7 +334,9 @@ class SessionInfo(JsonSerializable):
|
|
|
302
334
|
error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
|
|
303
335
|
|
|
304
336
|
@classmethod
|
|
305
|
-
def from_json(cls, serialized: dict) -> "SessionInfo":
|
|
337
|
+
def from_json(cls, serialized: dict) -> Optional["SessionInfo"]:
|
|
338
|
+
if serialized is None:
|
|
339
|
+
return None
|
|
306
340
|
kw = serialized.copy()
|
|
307
341
|
if kw.get("dag_infos"):
|
|
308
342
|
kw["dag_infos"] = {
|
|
@@ -320,7 +354,10 @@ class SessionInfo(JsonSerializable):
|
|
|
320
354
|
"idle_timestamp": self.idle_timestamp,
|
|
321
355
|
}
|
|
322
356
|
if self.dag_infos:
|
|
323
|
-
ret["dag_infos"] = {
|
|
357
|
+
ret["dag_infos"] = {
|
|
358
|
+
k: v.to_json() if v is not None else None
|
|
359
|
+
for k, v in self.dag_infos.items()
|
|
360
|
+
}
|
|
324
361
|
if self.error_info:
|
|
325
362
|
ret["error_info"] = self.error_info.to_json()
|
|
326
363
|
return ret
|
|
@@ -342,7 +379,25 @@ class ExecuteDagRequest(Serializable):
|
|
|
342
379
|
)
|
|
343
380
|
|
|
344
381
|
|
|
345
|
-
class SubDagInfo(Serializable):
|
|
382
|
+
class SubDagSubmitInstanceInfo(JsonSerializable):
|
|
383
|
+
submit_reason: str = StringField("submit_reason")
|
|
384
|
+
instance_id: str = StringField("instance_id")
|
|
385
|
+
subquery_id: Optional[int] = Int32Field("subquery_id", default=None)
|
|
386
|
+
|
|
387
|
+
@classmethod
|
|
388
|
+
def from_json(cls, serialized: dict) -> "SubDagSubmitInstanceInfo":
|
|
389
|
+
return SubDagSubmitInstanceInfo(**serialized)
|
|
390
|
+
|
|
391
|
+
def to_json(self) -> dict:
|
|
392
|
+
ret = {
|
|
393
|
+
"submit_reason": self.submit_reason,
|
|
394
|
+
"instance_id": self.instance_id,
|
|
395
|
+
"subquery_id": self.subquery_id,
|
|
396
|
+
}
|
|
397
|
+
return ret
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
class SubDagInfo(JsonSerializable):
|
|
346
401
|
subdag_id: str = StringField("subdag_id")
|
|
347
402
|
status: DagStatus = EnumField("status", DagStatus, FieldTypes.int8, default=None)
|
|
348
403
|
progress: float = Float64Field("progress", default=None)
|
|
@@ -355,9 +410,52 @@ class SubDagInfo(Serializable):
|
|
|
355
410
|
FieldTypes.reference,
|
|
356
411
|
default_factory=dict,
|
|
357
412
|
)
|
|
413
|
+
start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
|
|
414
|
+
end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
|
|
415
|
+
submit_instances: List[SubDagSubmitInstanceInfo] = ListField(
|
|
416
|
+
"submit_instances",
|
|
417
|
+
FieldTypes.reference,
|
|
418
|
+
default_factory=list,
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
@classmethod
|
|
422
|
+
def from_json(cls, serialized: dict) -> "SubDagInfo":
|
|
423
|
+
kw = serialized.copy()
|
|
424
|
+
kw["status"] = DagStatus(kw["status"])
|
|
425
|
+
if kw.get("tileable_to_result_infos"):
|
|
426
|
+
kw["tileable_to_result_infos"] = {
|
|
427
|
+
k: ResultInfo.from_json(s)
|
|
428
|
+
for k, s in kw["tileable_to_result_infos"].items()
|
|
429
|
+
}
|
|
430
|
+
if kw.get("error_info"):
|
|
431
|
+
kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
|
|
432
|
+
if kw.get("submit_instances"):
|
|
433
|
+
kw["submit_instances"] = [
|
|
434
|
+
SubDagSubmitInstanceInfo.from_json(s) for s in kw["submit_instances"]
|
|
435
|
+
]
|
|
436
|
+
return SubDagInfo(**kw)
|
|
437
|
+
|
|
438
|
+
def to_json(self) -> dict:
|
|
439
|
+
ret = {
|
|
440
|
+
"subdag_id": self.subdag_id,
|
|
441
|
+
"status": self.status.value,
|
|
442
|
+
"progress": self.progress,
|
|
443
|
+
"start_timestamp": self.start_timestamp,
|
|
444
|
+
"end_timestamp": self.end_timestamp,
|
|
445
|
+
}
|
|
446
|
+
if self.error_info:
|
|
447
|
+
ret["error_info"] = self.error_info.to_json()
|
|
448
|
+
if self.tileable_to_result_infos:
|
|
449
|
+
ret["tileable_to_result_infos"] = {
|
|
450
|
+
k: v.to_json() for k, v in self.tileable_to_result_infos.items()
|
|
451
|
+
}
|
|
452
|
+
if self.submit_instances:
|
|
453
|
+
ret["submit_instances"] = [i.to_json() for i in self.submit_instances]
|
|
454
|
+
return ret
|
|
358
455
|
|
|
359
456
|
|
|
360
457
|
class ExecuteSubDagRequest(Serializable):
|
|
458
|
+
subdag_id: str = StringField("subdag_id")
|
|
361
459
|
dag: TileableGraph = ReferenceField(
|
|
362
460
|
"dag",
|
|
363
461
|
on_serialize=SerializableGraph.from_graph,
|
|
Binary file
|
maxframe/serialization/core.pxd
CHANGED
|
@@ -18,6 +18,9 @@ from libc.stdint cimport int32_t, uint64_t
|
|
|
18
18
|
cdef class Serializer:
|
|
19
19
|
cdef int _serializer_id
|
|
20
20
|
|
|
21
|
+
cpdef bint is_public_data_exist(self, dict context, object key)
|
|
22
|
+
cpdef put_public_data(self, dict context, object key, object value)
|
|
23
|
+
cpdef get_public_data(self, dict context, object key)
|
|
21
24
|
cpdef serial(self, object obj, dict context)
|
|
22
25
|
cpdef deserial(self, list serialized, dict context, list subs)
|
|
23
26
|
cpdef on_deserial_error(
|
maxframe/serialization/core.pyi
CHANGED
|
@@ -29,6 +29,9 @@ class PickleContainer:
|
|
|
29
29
|
|
|
30
30
|
class Serializer:
|
|
31
31
|
serializer_id: int
|
|
32
|
+
def is_public_data_exist(self, context: Dict, key: Any) -> bool: ...
|
|
33
|
+
def put_public_data(self, context: Dict, key: Any, value: Any) -> None: ...
|
|
34
|
+
def get_public_data(self, context: Dict, key: Any) -> Any: ...
|
|
32
35
|
def serial(self, obj: Any, context: Dict): ...
|
|
33
36
|
def deserial(self, serialized: List, context: Dict, subs: List[Any]): ...
|
|
34
37
|
def on_deserial_error(
|
maxframe/serialization/core.pyx
CHANGED
|
@@ -130,11 +130,30 @@ cdef Serializer get_deserializer(int32_t deserializer_id):
|
|
|
130
130
|
|
|
131
131
|
cdef class Serializer:
|
|
132
132
|
serializer_id = None
|
|
133
|
+
_public_data_context_key = 0x7fffffff - 1
|
|
133
134
|
|
|
134
135
|
def __cinit__(self):
|
|
135
136
|
# make the value can be referenced with C code
|
|
136
137
|
self._serializer_id = self.serializer_id
|
|
137
138
|
|
|
139
|
+
cpdef bint is_public_data_exist(self, dict context, object key):
|
|
140
|
+
cdef dict public_dict = context.get(self._public_data_context_key, None)
|
|
141
|
+
if public_dict is None:
|
|
142
|
+
return False
|
|
143
|
+
return key in public_dict
|
|
144
|
+
|
|
145
|
+
cpdef put_public_data(self, dict context, object key, object value):
|
|
146
|
+
cdef dict public_dict = context.get(self._public_data_context_key, None)
|
|
147
|
+
if public_dict is None:
|
|
148
|
+
public_dict = context[self._public_data_context_key] = {}
|
|
149
|
+
public_dict[key] = value
|
|
150
|
+
|
|
151
|
+
cpdef get_public_data(self, dict context, object key):
|
|
152
|
+
cdef dict public_dict = context.get(self._public_data_context_key, None)
|
|
153
|
+
if public_dict is None:
|
|
154
|
+
return None
|
|
155
|
+
return public_dict.get(key)
|
|
156
|
+
|
|
138
157
|
cpdef serial(self, object obj, dict context):
|
|
139
158
|
"""
|
|
140
159
|
Returns intermediate serialization result of certain object.
|
|
@@ -993,17 +1012,20 @@ def serialize(obj, dict context = None):
|
|
|
993
1012
|
cdef list subs
|
|
994
1013
|
cdef bint final
|
|
995
1014
|
cdef _IdContextHolder id_context_holder = _IdContextHolder()
|
|
1015
|
+
cdef tuple result
|
|
996
1016
|
|
|
997
1017
|
context = context if context is not None else dict()
|
|
998
1018
|
serialized, subs, final = _serial_single(obj, context, id_context_holder)
|
|
999
1019
|
if final or not subs:
|
|
1000
1020
|
# marked as a leaf node, return directly
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1021
|
+
result = [{}, serialized], subs
|
|
1022
|
+
else:
|
|
1023
|
+
serial_stack.append(_SerialStackItem(serialized, subs))
|
|
1024
|
+
result = _serialize_with_stack(
|
|
1025
|
+
serial_stack, None, context, id_context_holder, result_bufs_list
|
|
1026
|
+
)
|
|
1027
|
+
result[0][0]["_PUB"] = context.get(Serializer._public_data_context_key)
|
|
1028
|
+
return result
|
|
1007
1029
|
|
|
1008
1030
|
|
|
1009
1031
|
async def serialize_with_spawn(
|
|
@@ -1036,31 +1058,38 @@ async def serialize_with_spawn(
|
|
|
1036
1058
|
cdef list subs
|
|
1037
1059
|
cdef bint final
|
|
1038
1060
|
cdef _IdContextHolder id_context_holder = _IdContextHolder()
|
|
1061
|
+
cdef tuple result
|
|
1039
1062
|
|
|
1040
1063
|
context = context if context is not None else dict()
|
|
1041
1064
|
serialized, subs, final = _serial_single(obj, context, id_context_holder)
|
|
1042
1065
|
if final or not subs:
|
|
1043
1066
|
# marked as a leaf node, return directly
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1067
|
+
result = [{}, serialized], subs
|
|
1068
|
+
else:
|
|
1069
|
+
serial_stack.append(_SerialStackItem(serialized, subs))
|
|
1047
1070
|
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1071
|
+
try:
|
|
1072
|
+
result = _serialize_with_stack(
|
|
1073
|
+
serial_stack,
|
|
1074
|
+
None,
|
|
1075
|
+
context,
|
|
1076
|
+
id_context_holder,
|
|
1077
|
+
result_bufs_list,
|
|
1078
|
+
spawn_threshold,
|
|
1079
|
+
)
|
|
1080
|
+
except _SerializeObjectOverflow as ex:
|
|
1081
|
+
result = await asyncio.get_running_loop().run_in_executor(
|
|
1082
|
+
executor,
|
|
1083
|
+
_serialize_with_stack,
|
|
1084
|
+
serial_stack,
|
|
1085
|
+
ex.cur_serialized,
|
|
1086
|
+
context,
|
|
1087
|
+
id_context_holder,
|
|
1088
|
+
result_bufs_list,
|
|
1089
|
+
0,
|
|
1090
|
+
ex.num_total_serialized,
|
|
1091
|
+
)
|
|
1092
|
+
result[0][0]["_PUB"] = context.get(Serializer._public_data_context_key)
|
|
1064
1093
|
return result
|
|
1065
1094
|
|
|
1066
1095
|
|
maxframe/serialization/pandas.py
CHANGED
|
@@ -176,11 +176,16 @@ class PdTimestampSerializer(Serializer):
|
|
|
176
176
|
|
|
177
177
|
class PdTimedeltaSerializer(Serializer):
|
|
178
178
|
def serial(self, obj: pd.Timedelta, context: Dict):
|
|
179
|
-
return [int(obj.seconds), obj.microseconds, obj.nanoseconds], [], True
|
|
179
|
+
return [int(obj.seconds), obj.microseconds, obj.nanoseconds, obj.days], [], True
|
|
180
180
|
|
|
181
181
|
def deserial(self, serialized: List, context: Dict, subs: List):
|
|
182
|
+
days = 0 if len(serialized) < 4 else serialized[3]
|
|
183
|
+
seconds, microseconds, nanoseconds = serialized[:3]
|
|
182
184
|
return pd.Timedelta(
|
|
183
|
-
|
|
185
|
+
days=days,
|
|
186
|
+
seconds=seconds,
|
|
187
|
+
microseconds=microseconds,
|
|
188
|
+
nanoseconds=nanoseconds,
|
|
184
189
|
)
|
|
185
190
|
|
|
186
191
|
|
|
maxframe/serialization/serializables/core.py
CHANGED
|
@@ -12,12 +12,13 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
import operator
|
|
16
15
|
import weakref
|
|
17
|
-
from
|
|
16
|
+
from collections import defaultdict
|
|
17
|
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
|
18
18
|
|
|
19
19
|
import msgpack
|
|
20
20
|
|
|
21
|
+
from ...lib.mmh3 import hash
|
|
21
22
|
from ..core import Placeholder, Serializer, buffered, load_type
|
|
22
23
|
from .field import Field
|
|
23
24
|
from .field_type import DictType, ListType, PrimitiveFieldType, TupleType
|
|
@@ -50,11 +51,16 @@ def _is_field_primitive_compound(field: Field):
|
|
|
50
51
|
class SerializableMeta(type):
|
|
51
52
|
def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
|
|
52
53
|
# All the fields including misc fields.
|
|
54
|
+
name_hash = hash(f"{properties.get('__module__')}.{name}")
|
|
53
55
|
all_fields = dict()
|
|
56
|
+
# mapping field names to base classes
|
|
57
|
+
field_to_cls_hash = dict()
|
|
54
58
|
|
|
55
59
|
for base in bases:
|
|
56
|
-
if hasattr(base, "_FIELDS"):
|
|
57
|
-
|
|
60
|
+
if not hasattr(base, "_FIELDS"):
|
|
61
|
+
continue
|
|
62
|
+
all_fields.update(base._FIELDS)
|
|
63
|
+
field_to_cls_hash.update(base._FIELD_TO_NAME_HASH)
|
|
58
64
|
|
|
59
65
|
properties_without_fields = {}
|
|
60
66
|
properties_field_slot_names = []
|
|
@@ -64,6 +70,8 @@ class SerializableMeta(type):
|
|
|
64
70
|
continue
|
|
65
71
|
|
|
66
72
|
field = all_fields.get(k)
|
|
73
|
+
# record the field for the class being created
|
|
74
|
+
field_to_cls_hash[k] = name_hash
|
|
67
75
|
if field is None:
|
|
68
76
|
properties_field_slot_names.append(k)
|
|
69
77
|
else:
|
|
@@ -75,23 +83,40 @@ class SerializableMeta(type):
|
|
|
75
83
|
|
|
76
84
|
# Make field order deterministic to serialize it as list instead of dict.
|
|
77
85
|
field_order = list(all_fields)
|
|
78
|
-
all_fields = dict(sorted(all_fields.items(), key=operator.itemgetter(0)))
|
|
79
86
|
primitive_fields = []
|
|
87
|
+
primitive_field_names = set()
|
|
80
88
|
non_primitive_fields = []
|
|
81
|
-
for v in all_fields.values():
|
|
89
|
+
for field_name, v in all_fields.items():
|
|
82
90
|
if _is_field_primitive_compound(v):
|
|
83
91
|
primitive_fields.append(v)
|
|
92
|
+
primitive_field_names.add(field_name)
|
|
84
93
|
else:
|
|
85
94
|
non_primitive_fields.append(v)
|
|
86
95
|
|
|
96
|
+
# count number of fields for every base class
|
|
97
|
+
cls_to_primitive_field_count = defaultdict(lambda: 0)
|
|
98
|
+
cls_to_non_primitive_field_count = defaultdict(lambda: 0)
|
|
99
|
+
for field_name in field_order:
|
|
100
|
+
cls_hash = field_to_cls_hash[field_name]
|
|
101
|
+
if field_name in primitive_field_names:
|
|
102
|
+
cls_to_primitive_field_count[cls_hash] += 1
|
|
103
|
+
else:
|
|
104
|
+
cls_to_non_primitive_field_count[cls_hash] += 1
|
|
105
|
+
|
|
87
106
|
slots = set(properties.pop("__slots__", set()))
|
|
88
107
|
slots.update(properties_field_slot_names)
|
|
89
108
|
|
|
90
109
|
properties = properties_without_fields
|
|
110
|
+
properties["_NAME_HASH"] = name_hash
|
|
91
111
|
properties["_FIELDS"] = all_fields
|
|
92
112
|
properties["_FIELD_ORDER"] = field_order
|
|
113
|
+
properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
|
|
93
114
|
properties["_PRIMITIVE_FIELDS"] = primitive_fields
|
|
115
|
+
properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = dict(cls_to_primitive_field_count)
|
|
94
116
|
properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
|
|
117
|
+
properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = dict(
|
|
118
|
+
cls_to_non_primitive_field_count
|
|
119
|
+
)
|
|
95
120
|
properties["__slots__"] = tuple(slots)
|
|
96
121
|
|
|
97
122
|
clz = type.__new__(mcs, name, bases, properties)
|
|
@@ -114,10 +139,14 @@ class Serializable(metaclass=SerializableMeta):
|
|
|
114
139
|
_cache_primitive_serial = False
|
|
115
140
|
_ignore_non_existing_keys = False
|
|
116
141
|
|
|
142
|
+
_NAME_HASH: int
|
|
117
143
|
_FIELDS: Dict[str, Field]
|
|
118
144
|
_FIELD_ORDER: List[str]
|
|
145
|
+
_FIELD_TO_NAME_HASH: Dict[str, int]
|
|
119
146
|
_PRIMITIVE_FIELDS: List[str]
|
|
147
|
+
_CLS_TO_PRIMITIVE_FIELD_COUNT: Dict[int, int]
|
|
120
148
|
_NON_PRIMITIVE_FIELDS: List[str]
|
|
149
|
+
_CLS_TO_NON_PRIMITIVE_FIELD_COUNT: Dict[int, int]
|
|
121
150
|
|
|
122
151
|
def __init__(self, *args, **kwargs):
|
|
123
152
|
fields = self._FIELDS
|
|
@@ -180,6 +209,10 @@ class SerializableSerializer(Serializer):
|
|
|
180
209
|
Leverage DictSerializer to perform serde.
|
|
181
210
|
"""
|
|
182
211
|
|
|
212
|
+
@classmethod
|
|
213
|
+
def _get_obj_field_count_key(cls, obj: Serializable):
|
|
214
|
+
return f"FC_{obj._NAME_HASH}"
|
|
215
|
+
|
|
183
216
|
@classmethod
|
|
184
217
|
def _get_field_values(cls, obj: Serializable, fields):
|
|
185
218
|
values = []
|
|
@@ -210,6 +243,18 @@ class SerializableSerializer(Serializer):
|
|
|
210
243
|
|
|
211
244
|
compound_vals = self._get_field_values(obj, obj._NON_PRIMITIVE_FIELDS)
|
|
212
245
|
cls_module = f"{type(obj).__module__}#{type(obj).__qualname__}"
|
|
246
|
+
|
|
247
|
+
field_count_key = self._get_obj_field_count_key(obj)
|
|
248
|
+
if not self.is_public_data_exist(context, field_count_key):
|
|
249
|
+
# store field distribution for current Serializable
|
|
250
|
+
counts = [
|
|
251
|
+
list(obj._CLS_TO_PRIMITIVE_FIELD_COUNT.items()),
|
|
252
|
+
list(obj._CLS_TO_NON_PRIMITIVE_FIELD_COUNT.items()),
|
|
253
|
+
]
|
|
254
|
+
field_count_data = msgpack.dumps(counts)
|
|
255
|
+
self.put_public_data(
|
|
256
|
+
context, self._get_obj_field_count_key(obj), field_count_data
|
|
257
|
+
)
|
|
213
258
|
return [cls_module, primitive_vals], [compound_vals], False
|
|
214
259
|
|
|
215
260
|
@staticmethod
|
|
@@ -229,6 +274,62 @@ class SerializableSerializer(Serializer):
|
|
|
229
274
|
else:
|
|
230
275
|
field.set(obj, value)
|
|
231
276
|
|
|
277
|
+
@classmethod
|
|
278
|
+
def _set_field_values(
|
|
279
|
+
cls,
|
|
280
|
+
obj: Serializable,
|
|
281
|
+
values: List[Any],
|
|
282
|
+
client_cls_to_field_count: Optional[Dict[str, int]],
|
|
283
|
+
is_primitive: bool = True,
|
|
284
|
+
):
|
|
285
|
+
obj_class = type(obj)
|
|
286
|
+
if is_primitive:
|
|
287
|
+
server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
|
|
288
|
+
server_fields = obj_class._PRIMITIVE_FIELDS
|
|
289
|
+
else:
|
|
290
|
+
server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
|
|
291
|
+
server_fields = obj_class._NON_PRIMITIVE_FIELDS
|
|
292
|
+
|
|
293
|
+
if client_cls_to_field_count:
|
|
294
|
+
field_num, server_field_num = 0, 0
|
|
295
|
+
for cls_hash, count in client_cls_to_field_count.items():
|
|
296
|
+
# cut values and fields given field distribution
|
|
297
|
+
# at client and server end
|
|
298
|
+
cls_fields = server_fields[server_field_num : field_num + count]
|
|
299
|
+
cls_values = values[field_num : field_num + count]
|
|
300
|
+
for field, value in zip(cls_fields, cls_values):
|
|
301
|
+
if not is_primitive or value != {}:
|
|
302
|
+
cls._set_field_value(obj, field, value)
|
|
303
|
+
field_num += count
|
|
304
|
+
server_field_num += server_cls_to_field_count[cls_hash]
|
|
305
|
+
else:
|
|
306
|
+
# todo remove this branch when all versions below v0.1.0b5 is eliminated
|
|
307
|
+
from .field import AnyField
|
|
308
|
+
|
|
309
|
+
# legacy serialization style, with all fields sorted by name
|
|
310
|
+
if is_primitive:
|
|
311
|
+
field_attr = "_legacy_deprecated_primitives"
|
|
312
|
+
else:
|
|
313
|
+
field_attr = "_legacy_deprecated_non_primitives"
|
|
314
|
+
deprecated_fields = []
|
|
315
|
+
deprecated_names = set()
|
|
316
|
+
if hasattr(obj_class, field_attr):
|
|
317
|
+
deprecated_names = set(getattr(obj_class, field_attr))
|
|
318
|
+
for field_name in deprecated_names:
|
|
319
|
+
field = AnyField(tag=field_name)
|
|
320
|
+
field.name = field_name
|
|
321
|
+
deprecated_fields.append(field)
|
|
322
|
+
server_fields = sorted(
|
|
323
|
+
server_fields + deprecated_fields, key=lambda f: f.name
|
|
324
|
+
)
|
|
325
|
+
for field, value in zip(server_fields, values):
|
|
326
|
+
if not is_primitive or value != {}:
|
|
327
|
+
try:
|
|
328
|
+
cls._set_field_value(obj, field, value)
|
|
329
|
+
except AttributeError: # pragma: no cover
|
|
330
|
+
if field.name not in deprecated_names:
|
|
331
|
+
raise
|
|
332
|
+
|
|
232
333
|
def deserial(self, serialized: List, context: Dict, subs: List) -> Serializable:
|
|
233
334
|
obj_class_name, primitives = serialized
|
|
234
335
|
obj_class = load_type(obj_class_name, Serializable)
|
|
@@ -238,14 +339,20 @@ class SerializableSerializer(Serializer):
|
|
|
238
339
|
|
|
239
340
|
obj = obj_class.__new__(obj_class)
|
|
240
341
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
342
|
+
field_count_data = self.get_public_data(
|
|
343
|
+
context, self._get_obj_field_count_key(obj)
|
|
344
|
+
)
|
|
345
|
+
if field_count_data is not None:
|
|
346
|
+
cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
|
|
347
|
+
cls_to_prim_key = dict(cls_to_prim_key)
|
|
348
|
+
cls_to_non_prim_key = dict(cls_to_non_prim_key)
|
|
349
|
+
else:
|
|
350
|
+
cls_to_prim_key, cls_to_non_prim_key = None, None
|
|
245
351
|
|
|
352
|
+
if primitives:
|
|
353
|
+
self._set_field_values(obj, primitives, cls_to_prim_key, True)
|
|
246
354
|
if obj_class._NON_PRIMITIVE_FIELDS:
|
|
247
|
-
|
|
248
|
-
self._set_field_value(obj, field, value)
|
|
355
|
+
self._set_field_values(obj, subs[0], cls_to_non_prim_key, False)
|
|
249
356
|
obj.__on_deserialize__()
|
|
250
357
|
return obj
|
|
251
358
|
|