maxframe 0.1.0b5 → 1.0.0rc2 (cp38-cp38-macosx_10_9_universal2.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- maxframe/_utils.cpython-38-darwin.so +0 -0
- maxframe/codegen.py +6 -2
- maxframe/config/config.py +38 -2
- maxframe/config/validators.py +1 -0
- maxframe/conftest.py +2 -0
- maxframe/core/__init__.py +0 -3
- maxframe/core/entity/__init__.py +1 -8
- maxframe/core/entity/objects.py +3 -45
- maxframe/core/graph/core.cpython-38-darwin.so +0 -0
- maxframe/core/graph/core.pyx +4 -4
- maxframe/dataframe/__init__.py +1 -1
- maxframe/dataframe/arithmetic/around.py +5 -17
- maxframe/dataframe/arithmetic/core.py +15 -7
- maxframe/dataframe/arithmetic/docstring.py +5 -55
- maxframe/dataframe/arithmetic/tests/test_arithmetic.py +22 -0
- maxframe/dataframe/core.py +5 -5
- maxframe/dataframe/datasource/date_range.py +2 -2
- maxframe/dataframe/datasource/read_odps_query.py +6 -0
- maxframe/dataframe/datasource/read_odps_table.py +2 -1
- maxframe/dataframe/datasource/tests/test_datasource.py +14 -0
- maxframe/dataframe/datastore/tests/__init__.py +13 -0
- maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
- maxframe/dataframe/datastore/to_odps.py +21 -0
- maxframe/dataframe/groupby/cum.py +0 -1
- maxframe/dataframe/groupby/tests/test_groupby.py +4 -0
- maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
- maxframe/dataframe/indexing/align.py +1 -1
- maxframe/dataframe/indexing/rename.py +3 -37
- maxframe/dataframe/indexing/sample.py +0 -1
- maxframe/dataframe/indexing/set_index.py +68 -1
- maxframe/dataframe/merge/merge.py +236 -2
- maxframe/dataframe/merge/tests/test_merge.py +123 -0
- maxframe/dataframe/misc/apply.py +5 -10
- maxframe/dataframe/misc/case_when.py +1 -1
- maxframe/dataframe/misc/describe.py +2 -2
- maxframe/dataframe/misc/drop_duplicates.py +4 -25
- maxframe/dataframe/misc/eval.py +4 -0
- maxframe/dataframe/misc/memory_usage.py +2 -2
- maxframe/dataframe/misc/pct_change.py +1 -83
- maxframe/dataframe/misc/tests/test_misc.py +23 -0
- maxframe/dataframe/misc/transform.py +1 -30
- maxframe/dataframe/misc/value_counts.py +4 -17
- maxframe/dataframe/missing/dropna.py +1 -1
- maxframe/dataframe/missing/fillna.py +5 -5
- maxframe/dataframe/sort/sort_values.py +1 -11
- maxframe/dataframe/statistics/corr.py +3 -3
- maxframe/dataframe/statistics/quantile.py +5 -17
- maxframe/dataframe/utils.py +4 -7
- maxframe/errors.py +13 -0
- maxframe/extension.py +12 -0
- maxframe/learn/contrib/xgboost/dmatrix.py +2 -2
- maxframe/learn/contrib/xgboost/predict.py +2 -2
- maxframe/learn/contrib/xgboost/train.py +2 -2
- maxframe/lib/mmh3.cpython-38-darwin.so +0 -0
- maxframe/lib/mmh3.pyi +43 -0
- maxframe/lib/wrapped_pickle.py +2 -1
- maxframe/odpsio/__init__.py +1 -1
- maxframe/odpsio/arrow.py +8 -4
- maxframe/odpsio/schema.py +10 -7
- maxframe/odpsio/tableio.py +388 -14
- maxframe/odpsio/tests/test_schema.py +16 -15
- maxframe/odpsio/tests/test_tableio.py +48 -21
- maxframe/protocol.py +148 -12
- maxframe/serialization/core.cpython-38-darwin.so +0 -0
- maxframe/serialization/core.pxd +3 -0
- maxframe/serialization/core.pyi +3 -0
- maxframe/serialization/core.pyx +54 -25
- maxframe/serialization/exception.py +1 -1
- maxframe/serialization/pandas.py +7 -2
- maxframe/serialization/serializables/core.py +158 -12
- maxframe/serialization/serializables/tests/test_serializable.py +46 -4
- maxframe/tensor/__init__.py +59 -0
- maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
- maxframe/tensor/base/atleast_1d.py +1 -1
- maxframe/tensor/base/unique.py +3 -3
- maxframe/tensor/reduction/count_nonzero.py +1 -1
- maxframe/tensor/statistics/quantile.py +2 -2
- maxframe/tests/test_protocol.py +34 -0
- maxframe/tests/test_utils.py +0 -12
- maxframe/tests/utils.py +11 -2
- maxframe/utils.py +24 -13
- {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/METADATA +75 -2
- {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/RECORD +91 -89
- {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/WHEEL +1 -1
- maxframe_client/__init__.py +0 -1
- maxframe_client/fetcher.py +38 -27
- maxframe_client/session/odps.py +50 -10
- maxframe_client/session/task.py +41 -20
- maxframe_client/tests/test_fetcher.py +21 -3
- maxframe_client/tests/test_session.py +49 -2
- maxframe_client/clients/spe.py +0 -104
- {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/top_level.txt +0 -0
maxframe/serialization/serializables/core.py
CHANGED

@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import operator
 import weakref
-from
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import msgpack
 
+from ...lib.mmh3 import hash
 from ..core import Placeholder, Serializer, buffered, load_type
 from .field import Field
 from .field_type import DictType, ListType, PrimitiveFieldType, TupleType

@@ -50,11 +51,19 @@ def _is_field_primitive_compound(field: Field):
 class SerializableMeta(type):
     def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
         # All the fields including misc fields.
+        legacy_name_hash = hash(f"{properties.get('__module__')}.{name}")
+        name_hash = hash(
+            f"{properties.get('__module__')}.{properties.get('__qualname__')}"
+        )
         all_fields = dict()
+        # mapping field names to base classes
+        field_to_cls_hash = dict()
 
         for base in bases:
-            if hasattr(base, "_FIELDS"):
-
+            if not hasattr(base, "_FIELDS"):
+                continue
+            all_fields.update(base._FIELDS)
+            field_to_cls_hash.update(base._FIELD_TO_NAME_HASH)
 
         properties_without_fields = {}
         properties_field_slot_names = []

@@ -64,6 +73,8 @@ class SerializableMeta(type):
                 continue
 
             field = all_fields.get(k)
+            # record the field for the class being created
+            field_to_cls_hash[k] = name_hash
             if field is None:
                 properties_field_slot_names.append(k)
             else:

@@ -75,23 +86,44 @@ class SerializableMeta(type):
 
         # Make field order deterministic to serialize it as list instead of dict.
         field_order = list(all_fields)
-        all_fields = dict(sorted(all_fields.items(), key=operator.itemgetter(0)))
         primitive_fields = []
+        primitive_field_names = set()
         non_primitive_fields = []
-        for v in all_fields.
+        for field_name, v in all_fields.items():
             if _is_field_primitive_compound(v):
                 primitive_fields.append(v)
+                primitive_field_names.add(field_name)
             else:
                 non_primitive_fields.append(v)
 
+        # count number of fields for every base class
+        cls_to_primitive_field_count = defaultdict(lambda: 0)
+        cls_to_non_primitive_field_count = defaultdict(lambda: 0)
+        for field_name in field_order:
+            cls_hash = field_to_cls_hash[field_name]
+            if field_name in primitive_field_names:
+                cls_to_primitive_field_count[cls_hash] += 1
+            else:
+                cls_to_non_primitive_field_count[cls_hash] += 1
+
         slots = set(properties.pop("__slots__", set()))
         slots.update(properties_field_slot_names)
 
         properties = properties_without_fields
+
+        # todo remove this prop when all versions below v1.0.0rc1 is eliminated
+        properties["_LEGACY_NAME_HASH"] = legacy_name_hash
+
+        properties["_NAME_HASH"] = name_hash
         properties["_FIELDS"] = all_fields
         properties["_FIELD_ORDER"] = field_order
+        properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
         properties["_PRIMITIVE_FIELDS"] = primitive_fields
+        properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = dict(cls_to_primitive_field_count)
         properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
+        properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = dict(
+            cls_to_non_primitive_field_count
+        )
         properties["__slots__"] = tuple(slots)
 
         clz = type.__new__(mcs, name, bases, properties)

@@ -114,10 +146,14 @@ class Serializable(metaclass=SerializableMeta):
     _cache_primitive_serial = False
     _ignore_non_existing_keys = False
 
+    _NAME_HASH: int
     _FIELDS: Dict[str, Field]
     _FIELD_ORDER: List[str]
+    _FIELD_TO_NAME_HASH: Dict[str, int]
     _PRIMITIVE_FIELDS: List[str]
+    _CLS_TO_PRIMITIVE_FIELD_COUNT: Dict[int, int]
     _NON_PRIMITIVE_FIELDS: List[str]
+    _CLS_TO_NON_PRIMITIVE_FIELD_COUNT: Dict[int, int]
 
     def __init__(self, *args, **kwargs):
         fields = self._FIELDS

@@ -180,6 +216,10 @@ class SerializableSerializer(Serializer):
     Leverage DictSerializer to perform serde.
     """
 
+    @classmethod
+    def _get_obj_field_count_key(cls, obj: Serializable, legacy: bool = False):
+        return f"FC_{obj._NAME_HASH if not legacy else obj._LEGACY_NAME_HASH}"
+
     @classmethod
     def _get_field_values(cls, obj: Serializable, fields):
         values = []

@@ -210,6 +250,18 @@ class SerializableSerializer(Serializer):
 
         compound_vals = self._get_field_values(obj, obj._NON_PRIMITIVE_FIELDS)
         cls_module = f"{type(obj).__module__}#{type(obj).__qualname__}"
+
+        field_count_key = self._get_obj_field_count_key(obj)
+        if not self.is_public_data_exist(context, field_count_key):
+            # store field distribution for current Serializable
+            counts = [
+                list(obj._CLS_TO_PRIMITIVE_FIELD_COUNT.items()),
+                list(obj._CLS_TO_NON_PRIMITIVE_FIELD_COUNT.items()),
+            ]
+            field_count_data = msgpack.dumps(counts)
+            self.put_public_data(
+                context, self._get_obj_field_count_key(obj), field_count_data
+            )
         return [cls_module, primitive_vals], [compound_vals], False
 
     @staticmethod

@@ -229,6 +281,88 @@ class SerializableSerializer(Serializer):
         else:
             field.set(obj, value)
 
+    @classmethod
+    def _set_field_values(
+        cls,
+        obj: Serializable,
+        values: List[Any],
+        client_cls_to_field_count: Optional[Dict[str, int]],
+        is_primitive: bool = True,
+    ):
+        obj_class = type(obj)
+        if is_primitive:
+            server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
+            server_fields = obj_class._PRIMITIVE_FIELDS
+        else:
+            server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
+            server_fields = obj_class._NON_PRIMITIVE_FIELDS
+
+        legacy_to_new_hash = {
+            c._LEGACY_NAME_HASH: c._NAME_HASH
+            for c in obj_class.__mro__
+            if hasattr(c, "_NAME_HASH") and c._LEGACY_NAME_HASH != c._NAME_HASH
+        }
+
+        if client_cls_to_field_count:
+            field_num, server_field_num = 0, 0
+            for cls_hash, count in client_cls_to_field_count.items():
+                # cut values and fields given field distribution
+                # at client and server end
+                cls_fields = server_fields[server_field_num : field_num + count]
+                cls_values = values[field_num : field_num + count]
+                for field, value in zip(cls_fields, cls_values):
+                    if not is_primitive or value != {}:
+                        cls._set_field_value(obj, field, value)
+                field_num += count
+                try:
+                    server_field_num += server_cls_to_field_count[cls_hash]
+                except KeyError:
+                    try:
+                        # todo remove this fallback when all
+                        # versions below v1.0.0rc1 is eliminated
+                        server_field_num += server_cls_to_field_count[
+                            legacy_to_new_hash[cls_hash]
+                        ]
+                    except KeyError:
+                        # it is possible that certain type of field does not exist
+                        # at server side
+                        pass
+        else:
+            # handle legacy serialization style, with all fields sorted by name
+            # todo remove this branch when all versions below v0.1.0b5 is eliminated
+            from .field import AnyField
+
+            if is_primitive:
+                new_field_attr = "_legacy_new_primitives"
+                deprecated_field_attr = "_legacy_deprecated_primitives"
+            else:
+                new_field_attr = "_legacy_new_non_primitives"
+                deprecated_field_attr = "_legacy_deprecated_non_primitives"
+
+            # remove fields added on later releases
+            new_names = set(getattr(obj_class, new_field_attr, None) or [])
+            server_fields = [f for f in server_fields if f.name not in new_names]
+
+            # fill fields deprecated on later releases
+            deprecated_fields = []
+            deprecated_names = set()
+            if hasattr(obj_class, deprecated_field_attr):
+                deprecated_names = set(getattr(obj_class, deprecated_field_attr))
+            for field_name in deprecated_names:
+                field = AnyField(tag=field_name)
+                field.name = field_name
+                deprecated_fields.append(field)
+            server_fields = sorted(
+                server_fields + deprecated_fields, key=lambda f: f.name
+            )
+            for field, value in zip(server_fields, values):
+                if not is_primitive or value != {}:
+                    try:
+                        cls._set_field_value(obj, field, value)
+                    except AttributeError:  # pragma: no cover
+                        if field.name not in deprecated_names:
+                            raise
+
     def deserial(self, serialized: List, context: Dict, subs: List) -> Serializable:
         obj_class_name, primitives = serialized
         obj_class = load_type(obj_class_name, Serializable)

@@ -238,14 +372,26 @@ class SerializableSerializer(Serializer):
 
         obj = obj_class.__new__(obj_class)
 
-
-
-
-
+        field_count_data = self.get_public_data(
+            context, self._get_obj_field_count_key(obj)
+        )
+        if field_count_data is None:
+            # todo remove this fallback when all
+            # versions below v1.0.0rc1 is eliminated
+            field_count_data = self.get_public_data(
+                context, self._get_obj_field_count_key(obj, legacy=True)
+            )
+        if field_count_data is not None:
+            cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
+            cls_to_prim_key = dict(cls_to_prim_key)
+            cls_to_non_prim_key = dict(cls_to_non_prim_key)
+        else:
+            cls_to_prim_key, cls_to_non_prim_key = None, None
 
+        if primitives:
+            self._set_field_values(obj, primitives, cls_to_prim_key, True)
         if obj_class._NON_PRIMITIVE_FIELDS:
-
-            self._set_field_value(obj, field, value)
+            self._set_field_values(obj, subs[0], cls_to_non_prim_key, False)
         obj.__on_deserialize__()
         return obj
 
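The serialization change above is the heart of this release: instead of sorting all fields by name (which breaks as soon as one side adds or removes a field), the metaclass now hashes every class's qualified name with mmh3 and records how many primitive and non-primitive fields each class in the hierarchy contributes; the serializer publishes that distribution once per type, and the deserializer slices the incoming flat value list class by class using the sender's counts. Below is a minimal standalone sketch of the slicing idea with plain dicts and made-up names rather than maxframe internals; the real _set_field_values above additionally handles legacy name hashes and skipped empty values.

    from typing import Dict, List, Tuple


    def slice_values_by_class(
        values: List,
        client_counts: Dict[int, int],
        server_counts: Dict[int, int],
        server_fields: List[str],
    ) -> List[Tuple[str, object]]:
        """Pair received values with local fields, one class chunk at a time."""
        pairs = []
        value_pos, field_pos = 0, 0
        for cls_hash, count in client_counts.items():
            # the sender contributed `count` values for this class ...
            cls_values = values[value_pos : value_pos + count]
            # ... while the receiver may own a different number of fields for it
            server_count = server_counts.get(cls_hash, 0)
            cls_fields = server_fields[field_pos : field_pos + server_count]
            # zip drops unmatched extras on either side
            pairs.extend(zip(cls_fields, cls_values))
            value_pos += count
            field_pos += server_count
        return pairs


    # an older client serialized 2 + 1 fields per class ...
    client_counts = {0xA1: 2, 0xB2: 1}
    # ... while the receiving class hierarchy now has 3 + 1 fields
    server_counts = {0xA1: 3, 0xB2: 1}
    server_fields = ["base_x", "base_y", "base_z", "sub_a"]

    print(slice_values_by_class([1, 2, 3], client_counts, server_counts, server_fields))
    # [('base_x', 1), ('base_y', 2), ('sub_a', 3)]

Without the counts, the third value would land on base_z; with them, each class's chunk stays aligned even though the receiver gained a field.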
maxframe/serialization/serializables/tests/test_serializable.py
CHANGED

@@ -94,6 +94,11 @@ class MySimpleSerializable(Serializable):
     _ref_val = ReferenceField("ref_val", "MySimpleSerializable")
 
 
+class MySubSerializable(MySimpleSerializable):
+    _m_int_val = Int64Field("m_int_val", default=250)
+    _m_str_val = StringField("m_str_val", default="SUB_STR")
+
+
 class MySerializable(Serializable):
     _id = IdentityField("id")
     _any_val = AnyField("any_val")

@@ -141,7 +146,7 @@ class MySerializable(Serializable):
 
 
 @pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
-@switch_unpickle
+@switch_unpickle(forbidden=False)
 def test_serializable(set_is_ci):
     my_serializable = MySerializable(
         _id="1",

@@ -165,7 +170,9 @@ def test_serializable(set_is_ci):
         _key_val=MyHasKey("aaa"),
         _ndarray_val=np.random.rand(4, 3),
         _datetime64_val=pd.Timestamp(123),
-        _timedelta64_val=pd.Timedelta(
+        _timedelta64_val=pd.Timedelta(
+            days=1, seconds=123, microseconds=345, nanoseconds=132
+        ),
         _datatype_val=np.dtype(np.int32),
         _index_val=pd.Index([1, 2]),
         _series_val=pd.Series(["a", "b"]),

@@ -187,9 +194,44 @@ def test_serializable(set_is_ci):
     _assert_serializable_eq(my_serializable, my_serializable2)
 
 
+@pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
+@switch_unpickle
+def test_compatible_serializable(set_is_ci):
+    global MySimpleSerializable, MySubSerializable
+
+    old_base, old_sub = MySimpleSerializable, MySubSerializable
+
+    try:
+        my_sub_serializable = MySubSerializable(
+            _id="id_val",
+            _list_val=["abcd", "wxyz"],
+            _ref_val=MyHasKey(),
+            _m_int_val=3412,
+            _m_str_val="dfgghj",
+        )
+        header, buffers = serialize(my_sub_serializable)
+
+        class MySimpleSerializable(Serializable):
+            _id = IdentityField("id")
+            _int_val = Int64Field("int_val", default=1000)
+            _list_val = ListField("list_val", default_factory=list)
+            _ref_val = ReferenceField("ref_val", "MySimpleSerializable")
+            _dict_val = DictField("dict_val")
+
+        class MySubSerializable(MySimpleSerializable):
+            _m_int_val = Int64Field("m_int_val", default=250)
+            _m_str_val = StringField("m_str_val", default="SUB_STR")
+
+        my_sub_serializable2 = deserialize(header, buffers)
+        assert type(my_sub_serializable) is not type(my_sub_serializable2)
+        _assert_serializable_eq(my_sub_serializable, my_sub_serializable2)
+    finally:
+        MySimpleSerializable, MySubSerializable = old_base, old_sub
+
+
 def _assert_serializable_eq(my_serializable, my_serializable2):
     for field_name, field in my_serializable._FIELDS.items():
-        if not hasattr(my_serializable, field.
+        if not hasattr(my_serializable, field.name):
             continue
         expect_value = getattr(my_serializable, field_name)
         actual_value = getattr(my_serializable2, field_name)

@@ -208,7 +250,7 @@ def _assert_serializable_eq(my_serializable, my_serializable2):
         elif callable(expect_value):
             assert expect_value(1) == actual_value(1)
         else:
-            assert expect_value == actual_value
+            assert expect_value == actual_value, f"Field {field_name}"
 
 
 @pytest.mark.parametrize("set_is_ci", [True], indirect=True)
maxframe/tensor/__init__.py
CHANGED

@@ -180,4 +180,63 @@ from .reduction import std, sum, var
 from .reshape import reshape
 from .ufunc import ufunc
 
+# isort: off
+# noinspection PyUnresolvedReferences
+from numpy import (
+    NAN,
+    NINF,
+    AxisError,
+    Inf,
+    NaN,
+    e,
+    errstate,
+    geterr,
+    inf,
+    nan,
+    newaxis,
+    pi,
+    seterr,
+)
+
+# import numpy types
+# noinspection PyUnresolvedReferences
+from numpy import (
+    bool_ as bool,
+    bytes_,
+    cfloat,
+    character,
+    complex64,
+    complex128,
+    complexfloating,
+    datetime64,
+    double,
+    dtype,
+    flexible,
+    float16,
+    float32,
+    float64,
+    floating,
+    generic,
+    inexact,
+    int8,
+    int16,
+    int32,
+    int64,
+    intc,
+    intp,
+    number,
+    integer,
+    object_ as object,
+    signedinteger,
+    timedelta64,
+    uint,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+    unicode_,
+    unsignedinteger,
+    void,
+)
+
 del fetch, ufunc
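The block above makes maxframe.tensor mirror numpy's top-level constants and scalar types, so tensor code can stay inside one namespace; note that bool and object within this module now name numpy's bool_ and object_, shadowing the builtins there. A small sketch of what the re-export guarantees, assuming the wheel is installed alongside the numpy version it was built against:

    import numpy as np

    import maxframe.tensor as mt

    # the re-exported names are the numpy objects themselves
    assert mt.float64 is np.float64
    assert mt.nan is np.nan
    assert mt.bool is np.bool_  # shadows the builtin inside mt
    print(mt.pi)                # 3.141592653589793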
maxframe/tensor/arithmetic/tests/test_arithmetic.py
CHANGED

@@ -252,7 +252,7 @@ def test_compare():
 
 def test_frexp():
     t1 = ones((3, 4, 5), chunk_size=2)
-    t2 = empty((3, 4, 5), dtype=np.
+    t2 = empty((3, 4, 5), dtype=np.dtype(float), chunk_size=2)
     op_type = type(t1.op)
 
     o1, o2 = frexp(t1)
maxframe/tensor/base/atleast_1d.py
CHANGED

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# Copyright 1999-
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
maxframe/tensor/base/unique.py
CHANGED

@@ -15,7 +15,7 @@
 
 import numpy as np
 
-from ... import opcodes
+from ... import opcodes
 from ...serialization.serializables import BoolField, Int32Field
 from ..core import TensorOrder
 from ..operators import TensorHasInput, TensorOperatorMixin

@@ -23,7 +23,7 @@ from ..utils import validate_axis
 
 
 class TensorUnique(TensorHasInput, TensorOperatorMixin):
-    _op_type_ =
+    _op_type_ = opcodes.UNIQUE
 
     return_index = BoolField("return_index", default=False)
     return_inverse = BoolField("return_inverse", default=False)

@@ -75,7 +75,7 @@ class TensorUnique(TensorHasInput, TensorOperatorMixin):
         if self.return_counts:
             kw = {
                 "shape": (np.nan,),
-                "dtype": np.dtype(
+                "dtype": np.dtype(int),
                 "gpu": input_obj.op.gpu,
                 "type": "counts",
             }
maxframe/tensor/statistics/quantile.py
CHANGED

@@ -16,7 +16,7 @@ from collections.abc import Iterable
 
 import numpy as np
 
-from ... import opcodes
+from ... import opcodes
 from ...core import ENTITY_TYPE
 from ...serialization.serializables import AnyField, BoolField, KeyField, StringField
 from ..core import TENSOR_TYPE, TensorOrder

@@ -43,7 +43,7 @@ q_error_msg = "Quantiles must be in the range [0, 1]"
 
 class TensorQuantile(TensorOperator, TensorOperatorMixin):
     __slots__ = ("q_error_msg",)
-    _op_type_ =
+    _op_type_ = opcodes.QUANTILE
 
     a = KeyField("a")
     q = AnyField("q")
maxframe/tests/test_protocol.py
CHANGED

@@ -85,6 +85,40 @@ def test_error_info_json_serialize():
         deserial_err_info.reraise()
 
 
+class CannotPickleException(Exception):
+    def __reduce__(self):
+        raise ValueError
+
+
+class CannotUnpickleException(Exception):
+    @classmethod
+    def load_from_pk(cls, _):
+        raise ValueError
+
+    def __reduce__(self):
+        return type(self).load_from_pk, (0,)
+
+
+def test_error_info_fallback_json_serialize():
+    try:
+        raise CannotPickleException
+    except CannotPickleException as ex:
+        err_info1 = ErrorInfo.from_exception(ex)
+
+    try:
+        raise CannotUnpickleException
+    except CannotUnpickleException as ex:
+        err_info2 = ErrorInfo.from_exception(ex)
+
+    for err_info in (err_info1, err_info2):
+        deserial_err_info = ErrorInfo.from_json(err_info.to_json())
+        assert deserial_err_info.raw_error_source is None
+        assert deserial_err_info.raw_error_data is None
+
+        with pytest.raises(RemoteException):
+            deserial_err_info.reraise()
+
+
 def test_dag_info_json_serialize():
     try:
         raise ValueError("ERR_DATA")
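The new test above exercises ErrorInfo's fallback for exceptions that cannot travel through pickle in one direction or the other: pickle consults __reduce__ while dumping and calls the returned callable while loading, so either step can be sabotaged. A stdlib-only sketch of those two failure modes, with illustrative class names:

    import pickle


    class CannotPickle(Exception):
        def __reduce__(self):          # called by pickle.dumps
            raise ValueError("no pickling")


    def _boom(_):
        raise ValueError("no unpickling")


    class CannotUnpickle(Exception):
        def __reduce__(self):          # dumps fine, blows up on loads
            return _boom, (0,)


    try:
        pickle.dumps(CannotPickle())
    except ValueError as ex:
        print("dump failed:", ex)

    data = pickle.dumps(CannotUnpickle())
    try:
        pickle.loads(data)
    except ValueError as ex:
        print("load failed:", ex)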
maxframe/tests/test_utils.py
CHANGED

@@ -288,15 +288,6 @@ def test_estimate_pandas_size():
     df2 = pd.DataFrame(np.random.rand(1000, 10))
     assert utils.estimate_pandas_size(df2) == sys.getsizeof(df2)
 
-    df3 = pd.DataFrame(
-        {
-            "A": np.random.choice(["abcd", "def", "gh"], size=(1000,)),
-            "B": np.random.rand(1000),
-            "C": np.random.rand(1000),
-        }
-    )
-    assert utils.estimate_pandas_size(df3) != sys.getsizeof(df3)
-
     s1 = pd.Series(np.random.rand(1000))
     assert utils.estimate_pandas_size(s1) == sys.getsizeof(s1)

@@ -307,7 +298,6 @@ def test_estimate_pandas_size():
     assert utils.estimate_pandas_size(s2) == sys.getsizeof(s2)
 
     s3 = pd.Series(np.random.choice(["abcd", "def", "gh"], size=(1000,)))
-    assert utils.estimate_pandas_size(s3) != sys.getsizeof(s3)
     assert (
         pytest.approx(utils.estimate_pandas_size(s3) / sys.getsizeof(s3), abs=0.5) == 1
     )

@@ -318,7 +308,6 @@ def test_estimate_pandas_size():
     assert utils.estimate_pandas_size(idx1) == sys.getsizeof(idx1)
 
     string_idx = pd.Index(np.random.choice(["a", "bb", "cc"], size=(1000,)))
-    assert utils.estimate_pandas_size(string_idx) != sys.getsizeof(string_idx)
     assert (
         pytest.approx(
             utils.estimate_pandas_size(string_idx) / sys.getsizeof(string_idx), abs=0.5

@@ -338,7 +327,6 @@ def test_estimate_pandas_size():
         },
         index=idx2,
     )
-    assert utils.estimate_pandas_size(df4) != sys.getsizeof(df4)
     assert (
         pytest.approx(utils.estimate_pandas_size(df4) / sys.getsizeof(df4), abs=0.5)
         == 1
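The assertions above stop demanding that the estimate merely differ from sys.getsizeof and instead require the estimate-to-exact ratio to land near 1 within a tolerance, via pytest.approx. A tiny illustration of that assertion style, with made-up numbers:

    import pytest

    # an estimate within ±0.5 of the exact ratio 1.0 passes; a 2x estimate fails
    estimated, exact = 1_300_000, 1_000_000
    assert pytest.approx(estimated / exact, abs=0.5) == 1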
maxframe/tests/utils.py
CHANGED

@@ -14,6 +14,7 @@
 
 import asyncio
 import functools
+import hashlib
 import os
 import queue
 import socket

@@ -25,7 +26,7 @@ import pytest
 from tornado import netutil
 
 from ..core import Tileable, TileableGraph
-from ..utils import lazy_import
+from ..utils import create_sync_primitive, lazy_import, to_binary
 
 try:
     from flaky import flaky

@@ -102,7 +103,7 @@ def run_app_in_thread(app_func):
     def fixture_func(*args, **kwargs):
         app_loop = asyncio.new_event_loop()
         q = queue.Queue()
-        exit_event = asyncio.Event
+        exit_event = create_sync_primitive(asyncio.Event, app_loop)
         app_thread = Thread(
             name="TestAppThread",
             target=app_thread_func,

@@ -162,3 +163,11 @@ def require_hadoop(func):
         not os.environ.get("WITH_HADOOP"), reason="Only run when hadoop is installed"
     )(func)
     return func
+
+
+def get_test_unique_name(size=None):
+    test_name = os.getenv("PYTEST_CURRENT_TEST", "pyodps_test")
+    digest = hashlib.md5(to_binary(test_name)).hexdigest()
+    if size:
+        digest = digest[:size]
+    return digest + "_" + str(os.getpid())
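get_test_unique_name above builds a collision-free per-test identifier: pytest exposes the running test through the PYTEST_CURRENT_TEST environment variable, and the helper hashes that string (optionally truncating the digest to size characters) and appends the process id. A quick sketch of the equivalent computation for size=8; the test id string is a made-up example of pytest's format:

    import hashlib
    import os

    # example value in pytest's PYTEST_CURRENT_TEST format (made up)
    test_name = "tests/test_io.py::test_read_table (call)"
    digest = hashlib.md5(test_name.encode()).hexdigest()[:8]
    print(digest + "_" + str(os.getpid()))  # e.g. "0f6c2a91_12345"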
maxframe/utils.py
CHANGED

@@ -33,7 +33,6 @@ import sys
 import threading
 import time
 import tokenize as pytokenize
-import traceback
 import types
 import weakref
 import zlib

@@ -396,18 +395,6 @@ def build_tileable_dir_name(tileable_key: str) -> str:
     return m.hexdigest()
 
 
-def extract_messages_and_stacks(exc: Exception) -> Tuple[List[str], List[str]]:
-    cur_exc = exc
-    messages, stacks = [], []
-    while True:
-        messages.append(str(cur_exc))
-        stacks.append("".join(traceback.format_tb(cur_exc.__traceback__)))
-        if exc.__cause__ is None:
-            break
-        cur_exc = exc.__cause__
-    return messages, stacks
-
-
 async def wait_http_response(
     url: str, *, request_timeout: TimeoutType = None, **kwargs
 ) -> httpclient.HTTPResponse:

@@ -449,6 +436,29 @@ async def to_thread_pool(func, *args, pool=None, **kwargs):
     return await loop.run_in_executor(pool, func_call)
 
 
+_PrimitiveType = TypeVar("_PrimitiveType")
+
+
+def create_sync_primitive(
+    cls: Type[_PrimitiveType], loop: asyncio.AbstractEventLoop
+) -> _PrimitiveType:
+    """
+    Create an asyncio sync primitive (locks, events, etc.)
+    in a certain event loop.
+    """
+    if sys.version_info[1] < 10:
+        return cls(loop=loop)
+
+    # From Python3.10 the loop parameter has been removed. We should work around here.
+    old_loop = asyncio.get_event_loop()
+    try:
+        asyncio.set_event_loop(loop)
+        primitive = cls()
+    finally:
+        asyncio.set_event_loop(old_loop)
+    return primitive
+
+
 class ToThreadCancelledError(asyncio.CancelledError):
     def __init__(self, *args, result=None):
         super().__init__(*args)

@@ -519,6 +529,7 @@ def config_odps_default_options():
         "metaservice.client.cache.enable": "false",
         "odps.sql.session.result.cache.enable": "false",
         "odps.sql.submit.mode": "script",
+        "odps.sql.job.max.time.hours": 72,
     }
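create_sync_primitive above exists because asyncio primitives lost their loop parameter in Python 3.10, while code that hands an Event to a loop running in another thread still wants it bound there; run_app_in_thread in maxframe/tests/utils.py is the in-tree caller. A usage sketch along the same lines, assuming the installed wheel exposes maxframe.utils.create_sync_primitive:

    import asyncio
    import threading

    from maxframe.utils import create_sync_primitive

    app_loop = asyncio.new_event_loop()
    exit_event = create_sync_primitive(asyncio.Event, app_loop)

    def run_app():
        async def main():
            await exit_event.wait()   # waiter and event share app_loop
            print("exit requested")
        app_loop.run_until_complete(main())

    t = threading.Thread(target=run_app, name="TestAppThread")
    t.start()
    app_loop.call_soon_threadsafe(exit_event.set)  # signal from another thread
    t.join()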