maxframe 0.1.0b4__cp39-cp39-win32.whl → 1.0.0rc1__cp39-cp39-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/__init__.py +1 -0
- maxframe/_utils.cp39-win32.pyd +0 -0
- maxframe/codegen.py +56 -3
- maxframe/config/config.py +15 -1
- maxframe/core/__init__.py +0 -3
- maxframe/core/entity/__init__.py +1 -8
- maxframe/core/entity/objects.py +3 -45
- maxframe/core/graph/core.cp39-win32.pyd +0 -0
- maxframe/core/graph/core.pyx +4 -4
- maxframe/dataframe/__init__.py +1 -0
- maxframe/dataframe/core.py +30 -8
- maxframe/dataframe/datasource/read_odps_query.py +3 -1
- maxframe/dataframe/datasource/read_odps_table.py +3 -1
- maxframe/dataframe/datastore/tests/__init__.py +13 -0
- maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
- maxframe/dataframe/datastore/to_odps.py +21 -0
- maxframe/dataframe/indexing/align.py +1 -1
- maxframe/dataframe/misc/__init__.py +4 -0
- maxframe/dataframe/misc/apply.py +3 -1
- maxframe/dataframe/misc/case_when.py +141 -0
- maxframe/dataframe/misc/memory_usage.py +2 -2
- maxframe/dataframe/misc/pivot_table.py +262 -0
- maxframe/dataframe/misc/tests/test_misc.py +84 -0
- maxframe/dataframe/plotting/core.py +2 -2
- maxframe/dataframe/reduction/core.py +2 -1
- maxframe/dataframe/statistics/corr.py +3 -3
- maxframe/dataframe/utils.py +7 -0
- maxframe/errors.py +13 -0
- maxframe/extension.py +12 -0
- maxframe/learn/contrib/utils.py +52 -0
- maxframe/learn/contrib/xgboost/__init__.py +26 -0
- maxframe/learn/contrib/xgboost/classifier.py +86 -0
- maxframe/learn/contrib/xgboost/core.py +156 -0
- maxframe/learn/contrib/xgboost/dmatrix.py +150 -0
- maxframe/learn/contrib/xgboost/predict.py +138 -0
- maxframe/learn/contrib/xgboost/regressor.py +78 -0
- maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
- maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
- maxframe/learn/contrib/xgboost/train.py +121 -0
- maxframe/learn/utils/__init__.py +15 -0
- maxframe/learn/utils/core.py +29 -0
- maxframe/lib/mmh3.cp39-win32.pyd +0 -0
- maxframe/lib/mmh3.pyi +43 -0
- maxframe/lib/wrapped_pickle.py +2 -1
- maxframe/odpsio/arrow.py +2 -3
- maxframe/odpsio/tableio.py +22 -0
- maxframe/odpsio/tests/test_schema.py +16 -11
- maxframe/opcodes.py +3 -0
- maxframe/protocol.py +108 -10
- maxframe/serialization/core.cp39-win32.pyd +0 -0
- maxframe/serialization/core.pxd +3 -0
- maxframe/serialization/core.pyi +64 -0
- maxframe/serialization/core.pyx +54 -25
- maxframe/serialization/exception.py +1 -1
- maxframe/serialization/pandas.py +7 -2
- maxframe/serialization/serializables/core.py +119 -12
- maxframe/serialization/serializables/tests/test_serializable.py +46 -4
- maxframe/session.py +28 -0
- maxframe/tensor/__init__.py +1 -1
- maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
- maxframe/tensor/base/__init__.py +2 -0
- maxframe/tensor/base/atleast_1d.py +74 -0
- maxframe/tensor/base/unique.py +205 -0
- maxframe/tensor/datasource/array.py +4 -2
- maxframe/tensor/datasource/scalar.py +1 -1
- maxframe/tensor/reduction/count_nonzero.py +1 -1
- maxframe/tests/test_protocol.py +34 -0
- maxframe/tests/test_utils.py +0 -12
- maxframe/tests/utils.py +2 -2
- maxframe/udf.py +63 -3
- maxframe/utils.py +22 -13
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/METADATA +3 -3
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/RECORD +80 -61
- maxframe_client/__init__.py +0 -1
- maxframe_client/fetcher.py +65 -3
- maxframe_client/session/odps.py +74 -5
- maxframe_client/session/task.py +65 -71
- maxframe_client/tests/test_session.py +64 -1
- maxframe_client/clients/spe.py +0 -104
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/WHEEL +0 -0
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/top_level.txt +0 -0
maxframe/serialization/core.pyx
CHANGED
|
@@ -130,11 +130,30 @@ cdef Serializer get_deserializer(int32_t deserializer_id):
|
|
|
130
130
|
|
|
131
131
|
cdef class Serializer:
|
|
132
132
|
serializer_id = None
|
|
133
|
+
_public_data_context_key = 0x7fffffff - 1
|
|
133
134
|
|
|
134
135
|
def __cinit__(self):
|
|
135
136
|
# make the value can be referenced with C code
|
|
136
137
|
self._serializer_id = self.serializer_id
|
|
137
138
|
|
|
139
|
+
cpdef bint is_public_data_exist(self, dict context, object key):
|
|
140
|
+
cdef dict public_dict = context.get(self._public_data_context_key, None)
|
|
141
|
+
if public_dict is None:
|
|
142
|
+
return False
|
|
143
|
+
return key in public_dict
|
|
144
|
+
|
|
145
|
+
cpdef put_public_data(self, dict context, object key, object value):
|
|
146
|
+
cdef dict public_dict = context.get(self._public_data_context_key, None)
|
|
147
|
+
if public_dict is None:
|
|
148
|
+
public_dict = context[self._public_data_context_key] = {}
|
|
149
|
+
public_dict[key] = value
|
|
150
|
+
|
|
151
|
+
cpdef get_public_data(self, dict context, object key):
|
|
152
|
+
cdef dict public_dict = context.get(self._public_data_context_key, None)
|
|
153
|
+
if public_dict is None:
|
|
154
|
+
return None
|
|
155
|
+
return public_dict.get(key)
|
|
156
|
+
|
|
138
157
|
cpdef serial(self, object obj, dict context):
|
|
139
158
|
"""
|
|
140
159
|
Returns intermediate serialization result of certain object.
|
|
@@ -993,17 +1012,20 @@ def serialize(obj, dict context = None):
|
|
|
993
1012
|
cdef list subs
|
|
994
1013
|
cdef bint final
|
|
995
1014
|
cdef _IdContextHolder id_context_holder = _IdContextHolder()
|
|
1015
|
+
cdef tuple result
|
|
996
1016
|
|
|
997
1017
|
context = context if context is not None else dict()
|
|
998
1018
|
serialized, subs, final = _serial_single(obj, context, id_context_holder)
|
|
999
1019
|
if final or not subs:
|
|
1000
1020
|
# marked as a leaf node, return directly
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1021
|
+
result = [{}, serialized], subs
|
|
1022
|
+
else:
|
|
1023
|
+
serial_stack.append(_SerialStackItem(serialized, subs))
|
|
1024
|
+
result = _serialize_with_stack(
|
|
1025
|
+
serial_stack, None, context, id_context_holder, result_bufs_list
|
|
1026
|
+
)
|
|
1027
|
+
result[0][0]["_PUB"] = context.get(Serializer._public_data_context_key)
|
|
1028
|
+
return result
|
|
1007
1029
|
|
|
1008
1030
|
|
|
1009
1031
|
async def serialize_with_spawn(
|
|
@@ -1036,31 +1058,38 @@ async def serialize_with_spawn(
|
|
|
1036
1058
|
cdef list subs
|
|
1037
1059
|
cdef bint final
|
|
1038
1060
|
cdef _IdContextHolder id_context_holder = _IdContextHolder()
|
|
1061
|
+
cdef tuple result
|
|
1039
1062
|
|
|
1040
1063
|
context = context if context is not None else dict()
|
|
1041
1064
|
serialized, subs, final = _serial_single(obj, context, id_context_holder)
|
|
1042
1065
|
if final or not subs:
|
|
1043
1066
|
# marked as a leaf node, return directly
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1067
|
+
result = [{}, serialized], subs
|
|
1068
|
+
else:
|
|
1069
|
+
serial_stack.append(_SerialStackItem(serialized, subs))
|
|
1047
1070
|
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1071
|
+
try:
|
|
1072
|
+
result = _serialize_with_stack(
|
|
1073
|
+
serial_stack,
|
|
1074
|
+
None,
|
|
1075
|
+
context,
|
|
1076
|
+
id_context_holder,
|
|
1077
|
+
result_bufs_list,
|
|
1078
|
+
spawn_threshold,
|
|
1079
|
+
)
|
|
1080
|
+
except _SerializeObjectOverflow as ex:
|
|
1081
|
+
result = await asyncio.get_running_loop().run_in_executor(
|
|
1082
|
+
executor,
|
|
1083
|
+
_serialize_with_stack,
|
|
1084
|
+
serial_stack,
|
|
1085
|
+
ex.cur_serialized,
|
|
1086
|
+
context,
|
|
1087
|
+
id_context_holder,
|
|
1088
|
+
result_bufs_list,
|
|
1089
|
+
0,
|
|
1090
|
+
ex.num_total_serialized,
|
|
1091
|
+
)
|
|
1092
|
+
result[0][0]["_PUB"] = context.get(Serializer._public_data_context_key)
|
|
1064
1093
|
return result
|
|
1065
1094
|
|
|
1066
1095
|
|
maxframe/serialization/pandas.py
CHANGED
|
@@ -176,11 +176,16 @@ class PdTimestampSerializer(Serializer):
|
|
|
176
176
|
|
|
177
177
|
class PdTimedeltaSerializer(Serializer):
|
|
178
178
|
def serial(self, obj: pd.Timedelta, context: Dict):
|
|
179
|
-
return [int(obj.seconds), obj.microseconds, obj.nanoseconds], [], True
|
|
179
|
+
return [int(obj.seconds), obj.microseconds, obj.nanoseconds, obj.days], [], True
|
|
180
180
|
|
|
181
181
|
def deserial(self, serialized: List, context: Dict, subs: List):
|
|
182
|
+
days = 0 if len(serialized) < 4 else serialized[3]
|
|
183
|
+
seconds, microseconds, nanoseconds = serialized[:3]
|
|
182
184
|
return pd.Timedelta(
|
|
183
|
-
|
|
185
|
+
days=days,
|
|
186
|
+
seconds=seconds,
|
|
187
|
+
microseconds=microseconds,
|
|
188
|
+
nanoseconds=nanoseconds,
|
|
184
189
|
)
|
|
185
190
|
|
|
186
191
|
|
|
@@ -12,12 +12,13 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
import operator
|
|
16
15
|
import weakref
|
|
17
|
-
from
|
|
16
|
+
from collections import defaultdict
|
|
17
|
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
|
18
18
|
|
|
19
19
|
import msgpack
|
|
20
20
|
|
|
21
|
+
from ...lib.mmh3 import hash
|
|
21
22
|
from ..core import Placeholder, Serializer, buffered, load_type
|
|
22
23
|
from .field import Field
|
|
23
24
|
from .field_type import DictType, ListType, PrimitiveFieldType, TupleType
|
|
@@ -50,11 +51,16 @@ def _is_field_primitive_compound(field: Field):
|
|
|
50
51
|
class SerializableMeta(type):
|
|
51
52
|
def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
|
|
52
53
|
# All the fields including misc fields.
|
|
54
|
+
name_hash = hash(f"{properties.get('__module__')}.{name}")
|
|
53
55
|
all_fields = dict()
|
|
56
|
+
# mapping field names to base classes
|
|
57
|
+
field_to_cls_hash = dict()
|
|
54
58
|
|
|
55
59
|
for base in bases:
|
|
56
|
-
if hasattr(base, "_FIELDS"):
|
|
57
|
-
|
|
60
|
+
if not hasattr(base, "_FIELDS"):
|
|
61
|
+
continue
|
|
62
|
+
all_fields.update(base._FIELDS)
|
|
63
|
+
field_to_cls_hash.update(base._FIELD_TO_NAME_HASH)
|
|
58
64
|
|
|
59
65
|
properties_without_fields = {}
|
|
60
66
|
properties_field_slot_names = []
|
|
@@ -64,6 +70,8 @@ class SerializableMeta(type):
|
|
|
64
70
|
continue
|
|
65
71
|
|
|
66
72
|
field = all_fields.get(k)
|
|
73
|
+
# record the field for the class being created
|
|
74
|
+
field_to_cls_hash[k] = name_hash
|
|
67
75
|
if field is None:
|
|
68
76
|
properties_field_slot_names.append(k)
|
|
69
77
|
else:
|
|
@@ -75,23 +83,40 @@ class SerializableMeta(type):
|
|
|
75
83
|
|
|
76
84
|
# Make field order deterministic to serialize it as list instead of dict.
|
|
77
85
|
field_order = list(all_fields)
|
|
78
|
-
all_fields = dict(sorted(all_fields.items(), key=operator.itemgetter(0)))
|
|
79
86
|
primitive_fields = []
|
|
87
|
+
primitive_field_names = set()
|
|
80
88
|
non_primitive_fields = []
|
|
81
|
-
for v in all_fields.
|
|
89
|
+
for field_name, v in all_fields.items():
|
|
82
90
|
if _is_field_primitive_compound(v):
|
|
83
91
|
primitive_fields.append(v)
|
|
92
|
+
primitive_field_names.add(field_name)
|
|
84
93
|
else:
|
|
85
94
|
non_primitive_fields.append(v)
|
|
86
95
|
|
|
96
|
+
# count number of fields for every base class
|
|
97
|
+
cls_to_primitive_field_count = defaultdict(lambda: 0)
|
|
98
|
+
cls_to_non_primitive_field_count = defaultdict(lambda: 0)
|
|
99
|
+
for field_name in field_order:
|
|
100
|
+
cls_hash = field_to_cls_hash[field_name]
|
|
101
|
+
if field_name in primitive_field_names:
|
|
102
|
+
cls_to_primitive_field_count[cls_hash] += 1
|
|
103
|
+
else:
|
|
104
|
+
cls_to_non_primitive_field_count[cls_hash] += 1
|
|
105
|
+
|
|
87
106
|
slots = set(properties.pop("__slots__", set()))
|
|
88
107
|
slots.update(properties_field_slot_names)
|
|
89
108
|
|
|
90
109
|
properties = properties_without_fields
|
|
110
|
+
properties["_NAME_HASH"] = name_hash
|
|
91
111
|
properties["_FIELDS"] = all_fields
|
|
92
112
|
properties["_FIELD_ORDER"] = field_order
|
|
113
|
+
properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
|
|
93
114
|
properties["_PRIMITIVE_FIELDS"] = primitive_fields
|
|
115
|
+
properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = dict(cls_to_primitive_field_count)
|
|
94
116
|
properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
|
|
117
|
+
properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = dict(
|
|
118
|
+
cls_to_non_primitive_field_count
|
|
119
|
+
)
|
|
95
120
|
properties["__slots__"] = tuple(slots)
|
|
96
121
|
|
|
97
122
|
clz = type.__new__(mcs, name, bases, properties)
|
|
@@ -114,10 +139,14 @@ class Serializable(metaclass=SerializableMeta):
|
|
|
114
139
|
_cache_primitive_serial = False
|
|
115
140
|
_ignore_non_existing_keys = False
|
|
116
141
|
|
|
142
|
+
_NAME_HASH: int
|
|
117
143
|
_FIELDS: Dict[str, Field]
|
|
118
144
|
_FIELD_ORDER: List[str]
|
|
145
|
+
_FIELD_TO_NAME_HASH: Dict[str, int]
|
|
119
146
|
_PRIMITIVE_FIELDS: List[str]
|
|
147
|
+
_CLS_TO_PRIMITIVE_FIELD_COUNT: Dict[int, int]
|
|
120
148
|
_NON_PRIMITIVE_FIELDS: List[str]
|
|
149
|
+
_CLS_TO_NON_PRIMITIVE_FIELD_COUNT: Dict[int, int]
|
|
121
150
|
|
|
122
151
|
def __init__(self, *args, **kwargs):
|
|
123
152
|
fields = self._FIELDS
|
|
@@ -180,6 +209,10 @@ class SerializableSerializer(Serializer):
|
|
|
180
209
|
Leverage DictSerializer to perform serde.
|
|
181
210
|
"""
|
|
182
211
|
|
|
212
|
+
@classmethod
|
|
213
|
+
def _get_obj_field_count_key(cls, obj: Serializable):
|
|
214
|
+
return f"FC_{obj._NAME_HASH}"
|
|
215
|
+
|
|
183
216
|
@classmethod
|
|
184
217
|
def _get_field_values(cls, obj: Serializable, fields):
|
|
185
218
|
values = []
|
|
@@ -210,6 +243,18 @@ class SerializableSerializer(Serializer):
|
|
|
210
243
|
|
|
211
244
|
compound_vals = self._get_field_values(obj, obj._NON_PRIMITIVE_FIELDS)
|
|
212
245
|
cls_module = f"{type(obj).__module__}#{type(obj).__qualname__}"
|
|
246
|
+
|
|
247
|
+
field_count_key = self._get_obj_field_count_key(obj)
|
|
248
|
+
if not self.is_public_data_exist(context, field_count_key):
|
|
249
|
+
# store field distribution for current Serializable
|
|
250
|
+
counts = [
|
|
251
|
+
list(obj._CLS_TO_PRIMITIVE_FIELD_COUNT.items()),
|
|
252
|
+
list(obj._CLS_TO_NON_PRIMITIVE_FIELD_COUNT.items()),
|
|
253
|
+
]
|
|
254
|
+
field_count_data = msgpack.dumps(counts)
|
|
255
|
+
self.put_public_data(
|
|
256
|
+
context, self._get_obj_field_count_key(obj), field_count_data
|
|
257
|
+
)
|
|
213
258
|
return [cls_module, primitive_vals], [compound_vals], False
|
|
214
259
|
|
|
215
260
|
@staticmethod
|
|
@@ -229,6 +274,62 @@ class SerializableSerializer(Serializer):
|
|
|
229
274
|
else:
|
|
230
275
|
field.set(obj, value)
|
|
231
276
|
|
|
277
|
+
@classmethod
|
|
278
|
+
def _set_field_values(
|
|
279
|
+
cls,
|
|
280
|
+
obj: Serializable,
|
|
281
|
+
values: List[Any],
|
|
282
|
+
client_cls_to_field_count: Optional[Dict[str, int]],
|
|
283
|
+
is_primitive: bool = True,
|
|
284
|
+
):
|
|
285
|
+
obj_class = type(obj)
|
|
286
|
+
if is_primitive:
|
|
287
|
+
server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
|
|
288
|
+
server_fields = obj_class._PRIMITIVE_FIELDS
|
|
289
|
+
else:
|
|
290
|
+
server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
|
|
291
|
+
server_fields = obj_class._NON_PRIMITIVE_FIELDS
|
|
292
|
+
|
|
293
|
+
if client_cls_to_field_count:
|
|
294
|
+
field_num, server_field_num = 0, 0
|
|
295
|
+
for cls_hash, count in client_cls_to_field_count.items():
|
|
296
|
+
# cut values and fields given field distribution
|
|
297
|
+
# at client and server end
|
|
298
|
+
cls_fields = server_fields[server_field_num : field_num + count]
|
|
299
|
+
cls_values = values[field_num : field_num + count]
|
|
300
|
+
for field, value in zip(cls_fields, cls_values):
|
|
301
|
+
if not is_primitive or value != {}:
|
|
302
|
+
cls._set_field_value(obj, field, value)
|
|
303
|
+
field_num += count
|
|
304
|
+
server_field_num += server_cls_to_field_count[cls_hash]
|
|
305
|
+
else:
|
|
306
|
+
# todo remove this branch when all versions below v0.1.0b5 is eliminated
|
|
307
|
+
from .field import AnyField
|
|
308
|
+
|
|
309
|
+
# legacy serialization style, with all fields sorted by name
|
|
310
|
+
if is_primitive:
|
|
311
|
+
field_attr = "_legacy_deprecated_primitives"
|
|
312
|
+
else:
|
|
313
|
+
field_attr = "_legacy_deprecated_non_primitives"
|
|
314
|
+
deprecated_fields = []
|
|
315
|
+
deprecated_names = set()
|
|
316
|
+
if hasattr(obj_class, field_attr):
|
|
317
|
+
deprecated_names = set(getattr(obj_class, field_attr))
|
|
318
|
+
for field_name in deprecated_names:
|
|
319
|
+
field = AnyField(tag=field_name)
|
|
320
|
+
field.name = field_name
|
|
321
|
+
deprecated_fields.append(field)
|
|
322
|
+
server_fields = sorted(
|
|
323
|
+
server_fields + deprecated_fields, key=lambda f: f.name
|
|
324
|
+
)
|
|
325
|
+
for field, value in zip(server_fields, values):
|
|
326
|
+
if not is_primitive or value != {}:
|
|
327
|
+
try:
|
|
328
|
+
cls._set_field_value(obj, field, value)
|
|
329
|
+
except AttributeError: # pragma: no cover
|
|
330
|
+
if field.name not in deprecated_names:
|
|
331
|
+
raise
|
|
332
|
+
|
|
232
333
|
def deserial(self, serialized: List, context: Dict, subs: List) -> Serializable:
|
|
233
334
|
obj_class_name, primitives = serialized
|
|
234
335
|
obj_class = load_type(obj_class_name, Serializable)
|
|
@@ -238,14 +339,20 @@ class SerializableSerializer(Serializer):
|
|
|
238
339
|
|
|
239
340
|
obj = obj_class.__new__(obj_class)
|
|
240
341
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
342
|
+
field_count_data = self.get_public_data(
|
|
343
|
+
context, self._get_obj_field_count_key(obj)
|
|
344
|
+
)
|
|
345
|
+
if field_count_data is not None:
|
|
346
|
+
cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
|
|
347
|
+
cls_to_prim_key = dict(cls_to_prim_key)
|
|
348
|
+
cls_to_non_prim_key = dict(cls_to_non_prim_key)
|
|
349
|
+
else:
|
|
350
|
+
cls_to_prim_key, cls_to_non_prim_key = None, None
|
|
245
351
|
|
|
352
|
+
if primitives:
|
|
353
|
+
self._set_field_values(obj, primitives, cls_to_prim_key, True)
|
|
246
354
|
if obj_class._NON_PRIMITIVE_FIELDS:
|
|
247
|
-
|
|
248
|
-
self._set_field_value(obj, field, value)
|
|
355
|
+
self._set_field_values(obj, subs[0], cls_to_non_prim_key, False)
|
|
249
356
|
obj.__on_deserialize__()
|
|
250
357
|
return obj
|
|
251
358
|
|
|
@@ -94,6 +94,11 @@ class MySimpleSerializable(Serializable):
|
|
|
94
94
|
_ref_val = ReferenceField("ref_val", "MySimpleSerializable")
|
|
95
95
|
|
|
96
96
|
|
|
97
|
+
class MySubSerializable(MySimpleSerializable):
|
|
98
|
+
_m_int_val = Int64Field("m_int_val", default=250)
|
|
99
|
+
_m_str_val = StringField("m_str_val", default="SUB_STR")
|
|
100
|
+
|
|
101
|
+
|
|
97
102
|
class MySerializable(Serializable):
|
|
98
103
|
_id = IdentityField("id")
|
|
99
104
|
_any_val = AnyField("any_val")
|
|
@@ -141,7 +146,7 @@ class MySerializable(Serializable):
|
|
|
141
146
|
|
|
142
147
|
|
|
143
148
|
@pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
|
|
144
|
-
@switch_unpickle
|
|
149
|
+
@switch_unpickle(forbidden=False)
|
|
145
150
|
def test_serializable(set_is_ci):
|
|
146
151
|
my_serializable = MySerializable(
|
|
147
152
|
_id="1",
|
|
@@ -165,7 +170,9 @@ def test_serializable(set_is_ci):
|
|
|
165
170
|
_key_val=MyHasKey("aaa"),
|
|
166
171
|
_ndarray_val=np.random.rand(4, 3),
|
|
167
172
|
_datetime64_val=pd.Timestamp(123),
|
|
168
|
-
_timedelta64_val=pd.Timedelta(
|
|
173
|
+
_timedelta64_val=pd.Timedelta(
|
|
174
|
+
days=1, seconds=123, microseconds=345, nanoseconds=132
|
|
175
|
+
),
|
|
169
176
|
_datatype_val=np.dtype(np.int32),
|
|
170
177
|
_index_val=pd.Index([1, 2]),
|
|
171
178
|
_series_val=pd.Series(["a", "b"]),
|
|
@@ -187,9 +194,44 @@ def test_serializable(set_is_ci):
|
|
|
187
194
|
_assert_serializable_eq(my_serializable, my_serializable2)
|
|
188
195
|
|
|
189
196
|
|
|
197
|
+
@pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
|
|
198
|
+
@switch_unpickle
|
|
199
|
+
def test_compatible_serializable(set_is_ci):
|
|
200
|
+
global MySimpleSerializable, MySubSerializable
|
|
201
|
+
|
|
202
|
+
old_base, old_sub = MySimpleSerializable, MySubSerializable
|
|
203
|
+
|
|
204
|
+
try:
|
|
205
|
+
my_sub_serializable = MySubSerializable(
|
|
206
|
+
_id="id_val",
|
|
207
|
+
_list_val=["abcd", "wxyz"],
|
|
208
|
+
_ref_val=MyHasKey(),
|
|
209
|
+
_m_int_val=3412,
|
|
210
|
+
_m_str_val="dfgghj",
|
|
211
|
+
)
|
|
212
|
+
header, buffers = serialize(my_sub_serializable)
|
|
213
|
+
|
|
214
|
+
class MySimpleSerializable(Serializable):
|
|
215
|
+
_id = IdentityField("id")
|
|
216
|
+
_int_val = Int64Field("int_val", default=1000)
|
|
217
|
+
_list_val = ListField("list_val", default_factory=list)
|
|
218
|
+
_ref_val = ReferenceField("ref_val", "MySimpleSerializable")
|
|
219
|
+
_dict_val = DictField("dict_val")
|
|
220
|
+
|
|
221
|
+
class MySubSerializable(MySimpleSerializable):
|
|
222
|
+
_m_int_val = Int64Field("m_int_val", default=250)
|
|
223
|
+
_m_str_val = StringField("m_str_val", default="SUB_STR")
|
|
224
|
+
|
|
225
|
+
my_sub_serializable2 = deserialize(header, buffers)
|
|
226
|
+
assert type(my_sub_serializable) is not type(my_sub_serializable2)
|
|
227
|
+
_assert_serializable_eq(my_sub_serializable, my_sub_serializable2)
|
|
228
|
+
finally:
|
|
229
|
+
MySimpleSerializable, MySubSerializable = old_base, old_sub
|
|
230
|
+
|
|
231
|
+
|
|
190
232
|
def _assert_serializable_eq(my_serializable, my_serializable2):
|
|
191
233
|
for field_name, field in my_serializable._FIELDS.items():
|
|
192
|
-
if not hasattr(my_serializable, field.
|
|
234
|
+
if not hasattr(my_serializable, field.name):
|
|
193
235
|
continue
|
|
194
236
|
expect_value = getattr(my_serializable, field_name)
|
|
195
237
|
actual_value = getattr(my_serializable2, field_name)
|
|
@@ -208,7 +250,7 @@ def _assert_serializable_eq(my_serializable, my_serializable2):
|
|
|
208
250
|
elif callable(expect_value):
|
|
209
251
|
assert expect_value(1) == actual_value(1)
|
|
210
252
|
else:
|
|
211
|
-
assert expect_value == actual_value
|
|
253
|
+
assert expect_value == actual_value, f"Field {field_name}"
|
|
212
254
|
|
|
213
255
|
|
|
214
256
|
@pytest.mark.parametrize("set_is_ci", [True], indirect=True)
|
maxframe/session.py
CHANGED
|
@@ -365,6 +365,15 @@ class AbstractAsyncSession(AbstractSession, metaclass=ABCMeta):
|
|
|
365
365
|
Stop server.
|
|
366
366
|
"""
|
|
367
367
|
|
|
368
|
+
@abstractmethod
|
|
369
|
+
async def get_logview_address(self, hours=None) -> Optional[str]:
|
|
370
|
+
"""
|
|
371
|
+
Get Logview address
|
|
372
|
+
Returns
|
|
373
|
+
-------
|
|
374
|
+
Logview address
|
|
375
|
+
"""
|
|
376
|
+
|
|
368
377
|
def close(self):
|
|
369
378
|
asyncio.run(self.destroy())
|
|
370
379
|
|
|
@@ -549,6 +558,15 @@ class AbstractSyncSession(AbstractSession, metaclass=ABCMeta):
|
|
|
549
558
|
|
|
550
559
|
return fetch(tileables, self, offsets=offsets, sizes=sizes)
|
|
551
560
|
|
|
561
|
+
@abstractmethod
|
|
562
|
+
def get_logview_address(self, hours=None) -> Optional[str]:
|
|
563
|
+
"""
|
|
564
|
+
Get logview address
|
|
565
|
+
Returns
|
|
566
|
+
-------
|
|
567
|
+
logview address
|
|
568
|
+
"""
|
|
569
|
+
|
|
552
570
|
|
|
553
571
|
def _delegate_to_isolated_session(func: Union[Callable, Coroutine]):
|
|
554
572
|
if asyncio.iscoroutinefunction(func):
|
|
@@ -728,6 +746,11 @@ class AsyncSession(AbstractAsyncSession):
|
|
|
728
746
|
await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(coro, self._loop))
|
|
729
747
|
stop_isolation()
|
|
730
748
|
|
|
749
|
+
@implements(AbstractAsyncSession.get_logview_address)
|
|
750
|
+
@_delegate_to_isolated_session
|
|
751
|
+
async def get_logview_address(self, hours=None) -> Optional[str]:
|
|
752
|
+
pass # pragma: no cover
|
|
753
|
+
|
|
731
754
|
|
|
732
755
|
class ProgressBar:
|
|
733
756
|
def __init__(self, show_progress):
|
|
@@ -949,6 +972,11 @@ class SyncSession(AbstractSyncSession):
|
|
|
949
972
|
def get_cluster_versions(self) -> List[str]:
|
|
950
973
|
pass # pragma: no cover
|
|
951
974
|
|
|
975
|
+
@implements(AbstractSyncSession.get_logview_address)
|
|
976
|
+
@_delegate_to_isolated_session
|
|
977
|
+
def get_logview_address(self, hours=None) -> Optional[str]:
|
|
978
|
+
pass # pragma: no cover
|
|
979
|
+
|
|
952
980
|
def destroy(self):
|
|
953
981
|
coro = self._isolated_session.destroy()
|
|
954
982
|
asyncio.run_coroutine_threadsafe(coro, self._loop).result()
|
maxframe/tensor/__init__.py
CHANGED
|
@@ -114,7 +114,7 @@ from .arithmetic import (
|
|
|
114
114
|
)
|
|
115
115
|
from .arithmetic import truediv as true_divide
|
|
116
116
|
from .arithmetic import trunc
|
|
117
|
-
from .base import broadcast_to, transpose, where
|
|
117
|
+
from .base import broadcast_to, transpose, unique, where
|
|
118
118
|
from .core import Tensor
|
|
119
119
|
from .datasource import (
|
|
120
120
|
arange,
|
|
@@ -252,7 +252,7 @@ def test_compare():
|
|
|
252
252
|
|
|
253
253
|
def test_frexp():
|
|
254
254
|
t1 = ones((3, 4, 5), chunk_size=2)
|
|
255
|
-
t2 = empty((3, 4, 5), dtype=np.
|
|
255
|
+
t2 = empty((3, 4, 5), dtype=np.dtype(float), chunk_size=2)
|
|
256
256
|
op_type = type(t1.op)
|
|
257
257
|
|
|
258
258
|
o1, o2 = frexp(t1)
|
maxframe/tensor/base/__init__.py
CHANGED
|
@@ -13,9 +13,11 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
from .astype import TensorAstype
|
|
16
|
+
from .atleast_1d import atleast_1d
|
|
16
17
|
from .broadcast_to import TensorBroadcastTo, broadcast_to
|
|
17
18
|
from .ravel import ravel
|
|
18
19
|
from .transpose import transpose
|
|
20
|
+
from .unique import unique
|
|
19
21
|
from .where import TensorWhere, where
|
|
20
22
|
|
|
21
23
|
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from ...core import ExecutableTuple
|
|
20
|
+
from ..datasource import tensor as astensor
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def atleast_1d(*tensors):
|
|
24
|
+
"""
|
|
25
|
+
Convert inputs to tensors with at least one dimension.
|
|
26
|
+
|
|
27
|
+
Scalar inputs are converted to 1-dimensional tensors, whilst
|
|
28
|
+
higher-dimensional inputs are preserved.
|
|
29
|
+
|
|
30
|
+
Parameters
|
|
31
|
+
----------
|
|
32
|
+
tensors1, tensors2, ... : array_like
|
|
33
|
+
One or more input tensors.
|
|
34
|
+
|
|
35
|
+
Returns
|
|
36
|
+
-------
|
|
37
|
+
ret : Tensor
|
|
38
|
+
An tensor, or list of tensors, each with ``a.ndim >= 1``.
|
|
39
|
+
Copies are made only if necessary.
|
|
40
|
+
|
|
41
|
+
See Also
|
|
42
|
+
--------
|
|
43
|
+
atleast_2d, atleast_3d
|
|
44
|
+
|
|
45
|
+
Examples
|
|
46
|
+
--------
|
|
47
|
+
>>> import maxframe.tensor as mt
|
|
48
|
+
|
|
49
|
+
>>> mt.atleast_1d(1.0).execute()
|
|
50
|
+
array([ 1.])
|
|
51
|
+
|
|
52
|
+
>>> x = mt.arange(9.0).reshape(3,3)
|
|
53
|
+
>>> mt.atleast_1d(x).execute()
|
|
54
|
+
array([[ 0., 1., 2.],
|
|
55
|
+
[ 3., 4., 5.],
|
|
56
|
+
[ 6., 7., 8.]])
|
|
57
|
+
>>> mt.atleast_1d(x) is x
|
|
58
|
+
True
|
|
59
|
+
|
|
60
|
+
>>> mt.atleast_1d(1, [3, 4]).execute()
|
|
61
|
+
[array([1]), array([3, 4])]
|
|
62
|
+
|
|
63
|
+
"""
|
|
64
|
+
new_tensors = []
|
|
65
|
+
for x in tensors:
|
|
66
|
+
x = astensor(x)
|
|
67
|
+
if x.ndim == 0:
|
|
68
|
+
x = x[np.newaxis]
|
|
69
|
+
|
|
70
|
+
new_tensors.append(x)
|
|
71
|
+
|
|
72
|
+
if len(new_tensors) == 1:
|
|
73
|
+
return new_tensors[0]
|
|
74
|
+
return ExecutableTuple(new_tensors)
|