maxframe 0.1.0b5__cp311-cp311-macosx_10_9_universal2.whl → 1.0.0rc2__cp311-cp311-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. Click here for more details.

Files changed (92)
  1. maxframe/_utils.cpython-311-darwin.so +0 -0
  2. maxframe/codegen.py +6 -2
  3. maxframe/config/config.py +38 -2
  4. maxframe/config/validators.py +1 -0
  5. maxframe/conftest.py +2 -0
  6. maxframe/core/__init__.py +0 -3
  7. maxframe/core/entity/__init__.py +1 -8
  8. maxframe/core/entity/objects.py +3 -45
  9. maxframe/core/graph/core.cpython-311-darwin.so +0 -0
  10. maxframe/core/graph/core.pyx +4 -4
  11. maxframe/dataframe/__init__.py +1 -1
  12. maxframe/dataframe/arithmetic/around.py +5 -17
  13. maxframe/dataframe/arithmetic/core.py +15 -7
  14. maxframe/dataframe/arithmetic/docstring.py +5 -55
  15. maxframe/dataframe/arithmetic/tests/test_arithmetic.py +22 -0
  16. maxframe/dataframe/core.py +5 -5
  17. maxframe/dataframe/datasource/date_range.py +2 -2
  18. maxframe/dataframe/datasource/read_odps_query.py +6 -0
  19. maxframe/dataframe/datasource/read_odps_table.py +2 -1
  20. maxframe/dataframe/datasource/tests/test_datasource.py +14 -0
  21. maxframe/dataframe/datastore/tests/__init__.py +13 -0
  22. maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
  23. maxframe/dataframe/datastore/to_odps.py +21 -0
  24. maxframe/dataframe/groupby/cum.py +0 -1
  25. maxframe/dataframe/groupby/tests/test_groupby.py +4 -0
  26. maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
  27. maxframe/dataframe/indexing/align.py +1 -1
  28. maxframe/dataframe/indexing/rename.py +3 -37
  29. maxframe/dataframe/indexing/sample.py +0 -1
  30. maxframe/dataframe/indexing/set_index.py +68 -1
  31. maxframe/dataframe/merge/merge.py +236 -2
  32. maxframe/dataframe/merge/tests/test_merge.py +123 -0
  33. maxframe/dataframe/misc/apply.py +5 -10
  34. maxframe/dataframe/misc/case_when.py +1 -1
  35. maxframe/dataframe/misc/describe.py +2 -2
  36. maxframe/dataframe/misc/drop_duplicates.py +4 -25
  37. maxframe/dataframe/misc/eval.py +4 -0
  38. maxframe/dataframe/misc/memory_usage.py +2 -2
  39. maxframe/dataframe/misc/pct_change.py +1 -83
  40. maxframe/dataframe/misc/tests/test_misc.py +23 -0
  41. maxframe/dataframe/misc/transform.py +1 -30
  42. maxframe/dataframe/misc/value_counts.py +4 -17
  43. maxframe/dataframe/missing/dropna.py +1 -1
  44. maxframe/dataframe/missing/fillna.py +5 -5
  45. maxframe/dataframe/sort/sort_values.py +1 -11
  46. maxframe/dataframe/statistics/corr.py +3 -3
  47. maxframe/dataframe/statistics/quantile.py +5 -17
  48. maxframe/dataframe/utils.py +4 -7
  49. maxframe/errors.py +13 -0
  50. maxframe/extension.py +12 -0
  51. maxframe/learn/contrib/xgboost/dmatrix.py +2 -2
  52. maxframe/learn/contrib/xgboost/predict.py +2 -2
  53. maxframe/learn/contrib/xgboost/train.py +2 -2
  54. maxframe/lib/mmh3.cpython-311-darwin.so +0 -0
  55. maxframe/lib/mmh3.pyi +43 -0
  56. maxframe/lib/wrapped_pickle.py +2 -1
  57. maxframe/odpsio/__init__.py +1 -1
  58. maxframe/odpsio/arrow.py +8 -4
  59. maxframe/odpsio/schema.py +10 -7
  60. maxframe/odpsio/tableio.py +388 -14
  61. maxframe/odpsio/tests/test_schema.py +16 -15
  62. maxframe/odpsio/tests/test_tableio.py +48 -21
  63. maxframe/protocol.py +148 -12
  64. maxframe/serialization/core.cpython-311-darwin.so +0 -0
  65. maxframe/serialization/core.pxd +3 -0
  66. maxframe/serialization/core.pyi +3 -0
  67. maxframe/serialization/core.pyx +54 -25
  68. maxframe/serialization/exception.py +1 -1
  69. maxframe/serialization/pandas.py +7 -2
  70. maxframe/serialization/serializables/core.py +158 -12
  71. maxframe/serialization/serializables/tests/test_serializable.py +46 -4
  72. maxframe/tensor/__init__.py +59 -0
  73. maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
  74. maxframe/tensor/base/atleast_1d.py +1 -1
  75. maxframe/tensor/base/unique.py +3 -3
  76. maxframe/tensor/reduction/count_nonzero.py +1 -1
  77. maxframe/tensor/statistics/quantile.py +2 -2
  78. maxframe/tests/test_protocol.py +34 -0
  79. maxframe/tests/test_utils.py +0 -12
  80. maxframe/tests/utils.py +11 -2
  81. maxframe/utils.py +24 -13
  82. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/METADATA +75 -2
  83. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/RECORD +91 -89
  84. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/WHEEL +1 -1
  85. maxframe_client/__init__.py +0 -1
  86. maxframe_client/fetcher.py +38 -27
  87. maxframe_client/session/odps.py +50 -10
  88. maxframe_client/session/task.py +41 -20
  89. maxframe_client/tests/test_fetcher.py +21 -3
  90. maxframe_client/tests/test_session.py +49 -2
  91. maxframe_client/clients/spe.py +0 -104
  92. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -12,12 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import operator
16
15
  import weakref
17
- from typing import Dict, List, Tuple, Type
16
+ from collections import defaultdict
17
+ from typing import Any, Dict, List, Optional, Tuple, Type
18
18
 
19
19
  import msgpack
20
20
 
21
+ from ...lib.mmh3 import hash
21
22
  from ..core import Placeholder, Serializer, buffered, load_type
22
23
  from .field import Field
23
24
  from .field_type import DictType, ListType, PrimitiveFieldType, TupleType
@@ -50,11 +51,19 @@ def _is_field_primitive_compound(field: Field):
50
51
  class SerializableMeta(type):
51
52
  def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
52
53
  # All the fields including misc fields.
54
+ legacy_name_hash = hash(f"{properties.get('__module__')}.{name}")
55
+ name_hash = hash(
56
+ f"{properties.get('__module__')}.{properties.get('__qualname__')}"
57
+ )
53
58
  all_fields = dict()
59
+ # mapping field names to base classes
60
+ field_to_cls_hash = dict()
54
61
 
55
62
  for base in bases:
56
- if hasattr(base, "_FIELDS"):
57
- all_fields.update(base._FIELDS)
63
+ if not hasattr(base, "_FIELDS"):
64
+ continue
65
+ all_fields.update(base._FIELDS)
66
+ field_to_cls_hash.update(base._FIELD_TO_NAME_HASH)
58
67
 
59
68
  properties_without_fields = {}
60
69
  properties_field_slot_names = []
@@ -64,6 +73,8 @@ class SerializableMeta(type):
64
73
  continue
65
74
 
66
75
  field = all_fields.get(k)
76
+ # record the field for the class being created
77
+ field_to_cls_hash[k] = name_hash
67
78
  if field is None:
68
79
  properties_field_slot_names.append(k)
69
80
  else:
@@ -75,23 +86,44 @@ class SerializableMeta(type):
75
86
 
76
87
  # Make field order deterministic to serialize it as list instead of dict.
77
88
  field_order = list(all_fields)
78
- all_fields = dict(sorted(all_fields.items(), key=operator.itemgetter(0)))
79
89
  primitive_fields = []
90
+ primitive_field_names = set()
80
91
  non_primitive_fields = []
81
- for v in all_fields.values():
92
+ for field_name, v in all_fields.items():
82
93
  if _is_field_primitive_compound(v):
83
94
  primitive_fields.append(v)
95
+ primitive_field_names.add(field_name)
84
96
  else:
85
97
  non_primitive_fields.append(v)
86
98
 
99
+ # count number of fields for every base class
100
+ cls_to_primitive_field_count = defaultdict(lambda: 0)
101
+ cls_to_non_primitive_field_count = defaultdict(lambda: 0)
102
+ for field_name in field_order:
103
+ cls_hash = field_to_cls_hash[field_name]
104
+ if field_name in primitive_field_names:
105
+ cls_to_primitive_field_count[cls_hash] += 1
106
+ else:
107
+ cls_to_non_primitive_field_count[cls_hash] += 1
108
+
87
109
  slots = set(properties.pop("__slots__", set()))
88
110
  slots.update(properties_field_slot_names)
89
111
 
90
112
  properties = properties_without_fields
113
+
114
+ # todo remove this prop when all versions below v1.0.0rc1 is eliminated
115
+ properties["_LEGACY_NAME_HASH"] = legacy_name_hash
116
+
117
+ properties["_NAME_HASH"] = name_hash
91
118
  properties["_FIELDS"] = all_fields
92
119
  properties["_FIELD_ORDER"] = field_order
120
+ properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
93
121
  properties["_PRIMITIVE_FIELDS"] = primitive_fields
122
+ properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = dict(cls_to_primitive_field_count)
94
123
  properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
124
+ properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = dict(
125
+ cls_to_non_primitive_field_count
126
+ )
95
127
  properties["__slots__"] = tuple(slots)
96
128
 
97
129
  clz = type.__new__(mcs, name, bases, properties)
@@ -114,10 +146,14 @@ class Serializable(metaclass=SerializableMeta):
114
146
  _cache_primitive_serial = False
115
147
  _ignore_non_existing_keys = False
116
148
 
149
+ _NAME_HASH: int
117
150
  _FIELDS: Dict[str, Field]
118
151
  _FIELD_ORDER: List[str]
152
+ _FIELD_TO_NAME_HASH: Dict[str, int]
119
153
  _PRIMITIVE_FIELDS: List[str]
154
+ _CLS_TO_PRIMITIVE_FIELD_COUNT: Dict[int, int]
120
155
  _NON_PRIMITIVE_FIELDS: List[str]
156
+ _CLS_TO_NON_PRIMITIVE_FIELD_COUNT: Dict[int, int]
121
157
 
122
158
  def __init__(self, *args, **kwargs):
123
159
  fields = self._FIELDS
@@ -180,6 +216,10 @@ class SerializableSerializer(Serializer):
180
216
  Leverage DictSerializer to perform serde.
181
217
  """
182
218
 
219
+ @classmethod
220
+ def _get_obj_field_count_key(cls, obj: Serializable, legacy: bool = False):
221
+ return f"FC_{obj._NAME_HASH if not legacy else obj._LEGACY_NAME_HASH}"
222
+
183
223
  @classmethod
184
224
  def _get_field_values(cls, obj: Serializable, fields):
185
225
  values = []
@@ -210,6 +250,18 @@ class SerializableSerializer(Serializer):
210
250
 
211
251
  compound_vals = self._get_field_values(obj, obj._NON_PRIMITIVE_FIELDS)
212
252
  cls_module = f"{type(obj).__module__}#{type(obj).__qualname__}"
253
+
254
+ field_count_key = self._get_obj_field_count_key(obj)
255
+ if not self.is_public_data_exist(context, field_count_key):
256
+ # store field distribution for current Serializable
257
+ counts = [
258
+ list(obj._CLS_TO_PRIMITIVE_FIELD_COUNT.items()),
259
+ list(obj._CLS_TO_NON_PRIMITIVE_FIELD_COUNT.items()),
260
+ ]
261
+ field_count_data = msgpack.dumps(counts)
262
+ self.put_public_data(
263
+ context, self._get_obj_field_count_key(obj), field_count_data
264
+ )
213
265
  return [cls_module, primitive_vals], [compound_vals], False
214
266
 
215
267
  @staticmethod
@@ -229,6 +281,88 @@ class SerializableSerializer(Serializer):
229
281
  else:
230
282
  field.set(obj, value)
231
283
 
284
+ @classmethod
285
+ def _set_field_values(
286
+ cls,
287
+ obj: Serializable,
288
+ values: List[Any],
289
+ client_cls_to_field_count: Optional[Dict[str, int]],
290
+ is_primitive: bool = True,
291
+ ):
292
+ obj_class = type(obj)
293
+ if is_primitive:
294
+ server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
295
+ server_fields = obj_class._PRIMITIVE_FIELDS
296
+ else:
297
+ server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
298
+ server_fields = obj_class._NON_PRIMITIVE_FIELDS
299
+
300
+ legacy_to_new_hash = {
301
+ c._LEGACY_NAME_HASH: c._NAME_HASH
302
+ for c in obj_class.__mro__
303
+ if hasattr(c, "_NAME_HASH") and c._LEGACY_NAME_HASH != c._NAME_HASH
304
+ }
305
+
306
+ if client_cls_to_field_count:
307
+ field_num, server_field_num = 0, 0
308
+ for cls_hash, count in client_cls_to_field_count.items():
309
+ # cut values and fields given field distribution
310
+ # at client and server end
311
+ cls_fields = server_fields[server_field_num : field_num + count]
312
+ cls_values = values[field_num : field_num + count]
313
+ for field, value in zip(cls_fields, cls_values):
314
+ if not is_primitive or value != {}:
315
+ cls._set_field_value(obj, field, value)
316
+ field_num += count
317
+ try:
318
+ server_field_num += server_cls_to_field_count[cls_hash]
319
+ except KeyError:
320
+ try:
321
+ # todo remove this fallback when all
322
+ # versions below v1.0.0rc1 is eliminated
323
+ server_field_num += server_cls_to_field_count[
324
+ legacy_to_new_hash[cls_hash]
325
+ ]
326
+ except KeyError:
327
+ # it is possible that certain type of field does not exist
328
+ # at server side
329
+ pass
330
+ else:
331
+ # handle legacy serialization style, with all fields sorted by name
332
+ # todo remove this branch when all versions below v0.1.0b5 is eliminated
333
+ from .field import AnyField
334
+
335
+ if is_primitive:
336
+ new_field_attr = "_legacy_new_primitives"
337
+ deprecated_field_attr = "_legacy_deprecated_primitives"
338
+ else:
339
+ new_field_attr = "_legacy_new_non_primitives"
340
+ deprecated_field_attr = "_legacy_deprecated_non_primitives"
341
+
342
+ # remove fields added on later releases
343
+ new_names = set(getattr(obj_class, new_field_attr, None) or [])
344
+ server_fields = [f for f in server_fields if f.name not in new_names]
345
+
346
+ # fill fields deprecated on later releases
347
+ deprecated_fields = []
348
+ deprecated_names = set()
349
+ if hasattr(obj_class, deprecated_field_attr):
350
+ deprecated_names = set(getattr(obj_class, deprecated_field_attr))
351
+ for field_name in deprecated_names:
352
+ field = AnyField(tag=field_name)
353
+ field.name = field_name
354
+ deprecated_fields.append(field)
355
+ server_fields = sorted(
356
+ server_fields + deprecated_fields, key=lambda f: f.name
357
+ )
358
+ for field, value in zip(server_fields, values):
359
+ if not is_primitive or value != {}:
360
+ try:
361
+ cls._set_field_value(obj, field, value)
362
+ except AttributeError: # pragma: no cover
363
+ if field.name not in deprecated_names:
364
+ raise
365
+
232
366
  def deserial(self, serialized: List, context: Dict, subs: List) -> Serializable:
233
367
  obj_class_name, primitives = serialized
234
368
  obj_class = load_type(obj_class_name, Serializable)
@@ -238,14 +372,26 @@ class SerializableSerializer(Serializer):
238
372
 
239
373
  obj = obj_class.__new__(obj_class)
240
374
 
241
- if primitives:
242
- for field, value in zip(obj_class._PRIMITIVE_FIELDS, primitives):
243
- if value != {}:
244
- self._set_field_value(obj, field, value)
375
+ field_count_data = self.get_public_data(
376
+ context, self._get_obj_field_count_key(obj)
377
+ )
378
+ if field_count_data is None:
379
+ # todo remove this fallback when all
380
+ # versions below v1.0.0rc1 is eliminated
381
+ field_count_data = self.get_public_data(
382
+ context, self._get_obj_field_count_key(obj, legacy=True)
383
+ )
384
+ if field_count_data is not None:
385
+ cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
386
+ cls_to_prim_key = dict(cls_to_prim_key)
387
+ cls_to_non_prim_key = dict(cls_to_non_prim_key)
388
+ else:
389
+ cls_to_prim_key, cls_to_non_prim_key = None, None
245
390
 
391
+ if primitives:
392
+ self._set_field_values(obj, primitives, cls_to_prim_key, True)
246
393
  if obj_class._NON_PRIMITIVE_FIELDS:
247
- for field, value in zip(obj_class._NON_PRIMITIVE_FIELDS, subs[0]):
248
- self._set_field_value(obj, field, value)
394
+ self._set_field_values(obj, subs[0], cls_to_non_prim_key, False)
249
395
  obj.__on_deserialize__()
250
396
  return obj
251
397
 
@@ -94,6 +94,11 @@ class MySimpleSerializable(Serializable):
94
94
  _ref_val = ReferenceField("ref_val", "MySimpleSerializable")
95
95
 
96
96
 
97
+ class MySubSerializable(MySimpleSerializable):
98
+ _m_int_val = Int64Field("m_int_val", default=250)
99
+ _m_str_val = StringField("m_str_val", default="SUB_STR")
100
+
101
+
97
102
  class MySerializable(Serializable):
98
103
  _id = IdentityField("id")
99
104
  _any_val = AnyField("any_val")
@@ -141,7 +146,7 @@ class MySerializable(Serializable):
141
146
 
142
147
 
143
148
  @pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
144
- @switch_unpickle
149
+ @switch_unpickle(forbidden=False)
145
150
  def test_serializable(set_is_ci):
146
151
  my_serializable = MySerializable(
147
152
  _id="1",
@@ -165,7 +170,9 @@ def test_serializable(set_is_ci):
165
170
  _key_val=MyHasKey("aaa"),
166
171
  _ndarray_val=np.random.rand(4, 3),
167
172
  _datetime64_val=pd.Timestamp(123),
168
- _timedelta64_val=pd.Timedelta(days=1),
173
+ _timedelta64_val=pd.Timedelta(
174
+ days=1, seconds=123, microseconds=345, nanoseconds=132
175
+ ),
169
176
  _datatype_val=np.dtype(np.int32),
170
177
  _index_val=pd.Index([1, 2]),
171
178
  _series_val=pd.Series(["a", "b"]),
@@ -187,9 +194,44 @@ def test_serializable(set_is_ci):
187
194
  _assert_serializable_eq(my_serializable, my_serializable2)
188
195
 
189
196
 
197
+ @pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
198
+ @switch_unpickle
199
+ def test_compatible_serializable(set_is_ci):
200
+ global MySimpleSerializable, MySubSerializable
201
+
202
+ old_base, old_sub = MySimpleSerializable, MySubSerializable
203
+
204
+ try:
205
+ my_sub_serializable = MySubSerializable(
206
+ _id="id_val",
207
+ _list_val=["abcd", "wxyz"],
208
+ _ref_val=MyHasKey(),
209
+ _m_int_val=3412,
210
+ _m_str_val="dfgghj",
211
+ )
212
+ header, buffers = serialize(my_sub_serializable)
213
+
214
+ class MySimpleSerializable(Serializable):
215
+ _id = IdentityField("id")
216
+ _int_val = Int64Field("int_val", default=1000)
217
+ _list_val = ListField("list_val", default_factory=list)
218
+ _ref_val = ReferenceField("ref_val", "MySimpleSerializable")
219
+ _dict_val = DictField("dict_val")
220
+
221
+ class MySubSerializable(MySimpleSerializable):
222
+ _m_int_val = Int64Field("m_int_val", default=250)
223
+ _m_str_val = StringField("m_str_val", default="SUB_STR")
224
+
225
+ my_sub_serializable2 = deserialize(header, buffers)
226
+ assert type(my_sub_serializable) is not type(my_sub_serializable2)
227
+ _assert_serializable_eq(my_sub_serializable, my_sub_serializable2)
228
+ finally:
229
+ MySimpleSerializable, MySubSerializable = old_base, old_sub
230
+
231
+
190
232
  def _assert_serializable_eq(my_serializable, my_serializable2):
191
233
  for field_name, field in my_serializable._FIELDS.items():
192
- if not hasattr(my_serializable, field.tag):
234
+ if not hasattr(my_serializable, field.name):
193
235
  continue
194
236
  expect_value = getattr(my_serializable, field_name)
195
237
  actual_value = getattr(my_serializable2, field_name)
@@ -208,7 +250,7 @@ def _assert_serializable_eq(my_serializable, my_serializable2):
208
250
  elif callable(expect_value):
209
251
  assert expect_value(1) == actual_value(1)
210
252
  else:
211
- assert expect_value == actual_value
253
+ assert expect_value == actual_value, f"Field {field_name}"
212
254
 
213
255
 
214
256
  @pytest.mark.parametrize("set_is_ci", [True], indirect=True)
@@ -180,4 +180,63 @@ from .reduction import std, sum, var
180
180
  from .reshape import reshape
181
181
  from .ufunc import ufunc
182
182
 
183
+ # isort: off
184
+ # noinspection PyUnresolvedReferences
185
+ from numpy import (
186
+ NAN,
187
+ NINF,
188
+ AxisError,
189
+ Inf,
190
+ NaN,
191
+ e,
192
+ errstate,
193
+ geterr,
194
+ inf,
195
+ nan,
196
+ newaxis,
197
+ pi,
198
+ seterr,
199
+ )
200
+
201
+ # import numpy types
202
+ # noinspection PyUnresolvedReferences
203
+ from numpy import (
204
+ bool_ as bool,
205
+ bytes_,
206
+ cfloat,
207
+ character,
208
+ complex64,
209
+ complex128,
210
+ complexfloating,
211
+ datetime64,
212
+ double,
213
+ dtype,
214
+ flexible,
215
+ float16,
216
+ float32,
217
+ float64,
218
+ floating,
219
+ generic,
220
+ inexact,
221
+ int8,
222
+ int16,
223
+ int32,
224
+ int64,
225
+ intc,
226
+ intp,
227
+ number,
228
+ integer,
229
+ object_ as object,
230
+ signedinteger,
231
+ timedelta64,
232
+ uint,
233
+ uint8,
234
+ uint16,
235
+ uint32,
236
+ uint64,
237
+ unicode_,
238
+ unsignedinteger,
239
+ void,
240
+ )
241
+
183
242
  del fetch, ufunc
@@ -252,7 +252,7 @@ def test_compare():
252
252
 
253
253
  def test_frexp():
254
254
  t1 = ones((3, 4, 5), chunk_size=2)
255
- t2 = empty((3, 4, 5), dtype=np.float_, chunk_size=2)
255
+ t2 = empty((3, 4, 5), dtype=np.dtype(float), chunk_size=2)
256
256
  op_type = type(t1.op)
257
257
 
258
258
  o1, o2 = frexp(t1)
@@ -1,6 +1,6 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8 -*-
3
- # Copyright 1999-2021 Alibaba Group Holding Ltd.
3
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
6
  # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
15
15
 
16
16
  import numpy as np
17
17
 
18
- from ... import opcodes as OperandDef
18
+ from ... import opcodes
19
19
  from ...serialization.serializables import BoolField, Int32Field
20
20
  from ..core import TensorOrder
21
21
  from ..operators import TensorHasInput, TensorOperatorMixin
@@ -23,7 +23,7 @@ from ..utils import validate_axis
23
23
 
24
24
 
25
25
  class TensorUnique(TensorHasInput, TensorOperatorMixin):
26
- _op_type_ = OperandDef.UNIQUE
26
+ _op_type_ = opcodes.UNIQUE
27
27
 
28
28
  return_index = BoolField("return_index", default=False)
29
29
  return_inverse = BoolField("return_inverse", default=False)
@@ -75,7 +75,7 @@ class TensorUnique(TensorHasInput, TensorOperatorMixin):
75
75
  if self.return_counts:
76
76
  kw = {
77
77
  "shape": (np.nan,),
78
- "dtype": np.dtype(np.int_),
78
+ "dtype": np.dtype(int),
79
79
  "gpu": input_obj.op.gpu,
80
80
  "type": "counts",
81
81
  }
@@ -77,5 +77,5 @@ def count_nonzero(a, axis=None):
77
77
  array([2, 3])
78
78
 
79
79
  """
80
- op = TensorCountNonzero(axis=axis, dtype=np.dtype(np.int_), keepdims=None)
80
+ op = TensorCountNonzero(axis=axis, dtype=np.dtype(int), keepdims=None)
81
81
  return op(a)
@@ -16,7 +16,7 @@ from collections.abc import Iterable
16
16
 
17
17
  import numpy as np
18
18
 
19
- from ... import opcodes as OperandDef
19
+ from ... import opcodes
20
20
  from ...core import ENTITY_TYPE
21
21
  from ...serialization.serializables import AnyField, BoolField, KeyField, StringField
22
22
  from ..core import TENSOR_TYPE, TensorOrder
@@ -43,7 +43,7 @@ q_error_msg = "Quantiles must be in the range [0, 1]"
43
43
 
44
44
  class TensorQuantile(TensorOperator, TensorOperatorMixin):
45
45
  __slots__ = ("q_error_msg",)
46
- _op_type_ = OperandDef.QUANTILE
46
+ _op_type_ = opcodes.QUANTILE
47
47
 
48
48
  a = KeyField("a")
49
49
  q = AnyField("q")
@@ -85,6 +85,40 @@ def test_error_info_json_serialize():
85
85
  deserial_err_info.reraise()
86
86
 
87
87
 
88
+ class CannotPickleException(Exception):
89
+ def __reduce__(self):
90
+ raise ValueError
91
+
92
+
93
+ class CannotUnpickleException(Exception):
94
+ @classmethod
95
+ def load_from_pk(cls, _):
96
+ raise ValueError
97
+
98
+ def __reduce__(self):
99
+ return type(self).load_from_pk, (0,)
100
+
101
+
102
+ def test_error_info_fallback_json_serialize():
103
+ try:
104
+ raise CannotPickleException
105
+ except CannotPickleException as ex:
106
+ err_info1 = ErrorInfo.from_exception(ex)
107
+
108
+ try:
109
+ raise CannotUnpickleException
110
+ except CannotUnpickleException as ex:
111
+ err_info2 = ErrorInfo.from_exception(ex)
112
+
113
+ for err_info in (err_info1, err_info2):
114
+ deserial_err_info = ErrorInfo.from_json(err_info.to_json())
115
+ assert deserial_err_info.raw_error_source is None
116
+ assert deserial_err_info.raw_error_data is None
117
+
118
+ with pytest.raises(RemoteException):
119
+ deserial_err_info.reraise()
120
+
121
+
88
122
  def test_dag_info_json_serialize():
89
123
  try:
90
124
  raise ValueError("ERR_DATA")
@@ -288,15 +288,6 @@ def test_estimate_pandas_size():
288
288
  df2 = pd.DataFrame(np.random.rand(1000, 10))
289
289
  assert utils.estimate_pandas_size(df2) == sys.getsizeof(df2)
290
290
 
291
- df3 = pd.DataFrame(
292
- {
293
- "A": np.random.choice(["abcd", "def", "gh"], size=(1000,)),
294
- "B": np.random.rand(1000),
295
- "C": np.random.rand(1000),
296
- }
297
- )
298
- assert utils.estimate_pandas_size(df3) != sys.getsizeof(df3)
299
-
300
291
  s1 = pd.Series(np.random.rand(1000))
301
292
  assert utils.estimate_pandas_size(s1) == sys.getsizeof(s1)
302
293
 
@@ -307,7 +298,6 @@ def test_estimate_pandas_size():
307
298
  assert utils.estimate_pandas_size(s2) == sys.getsizeof(s2)
308
299
 
309
300
  s3 = pd.Series(np.random.choice(["abcd", "def", "gh"], size=(1000,)))
310
- assert utils.estimate_pandas_size(s3) != sys.getsizeof(s3)
311
301
  assert (
312
302
  pytest.approx(utils.estimate_pandas_size(s3) / sys.getsizeof(s3), abs=0.5) == 1
313
303
  )
@@ -318,7 +308,6 @@ def test_estimate_pandas_size():
318
308
  assert utils.estimate_pandas_size(idx1) == sys.getsizeof(idx1)
319
309
 
320
310
  string_idx = pd.Index(np.random.choice(["a", "bb", "cc"], size=(1000,)))
321
- assert utils.estimate_pandas_size(string_idx) != sys.getsizeof(string_idx)
322
311
  assert (
323
312
  pytest.approx(
324
313
  utils.estimate_pandas_size(string_idx) / sys.getsizeof(string_idx), abs=0.5
@@ -338,7 +327,6 @@ def test_estimate_pandas_size():
338
327
  },
339
328
  index=idx2,
340
329
  )
341
- assert utils.estimate_pandas_size(df4) != sys.getsizeof(df4)
342
330
  assert (
343
331
  pytest.approx(utils.estimate_pandas_size(df4) / sys.getsizeof(df4), abs=0.5)
344
332
  == 1
maxframe/tests/utils.py CHANGED
@@ -14,6 +14,7 @@
14
14
 
15
15
  import asyncio
16
16
  import functools
17
+ import hashlib
17
18
  import os
18
19
  import queue
19
20
  import socket
@@ -25,7 +26,7 @@ import pytest
25
26
  from tornado import netutil
26
27
 
27
28
  from ..core import Tileable, TileableGraph
28
- from ..utils import lazy_import
29
+ from ..utils import create_sync_primitive, lazy_import, to_binary
29
30
 
30
31
  try:
31
32
  from flaky import flaky
@@ -102,7 +103,7 @@ def run_app_in_thread(app_func):
102
103
  def fixture_func(*args, **kwargs):
103
104
  app_loop = asyncio.new_event_loop()
104
105
  q = queue.Queue()
105
- exit_event = asyncio.Event(loop=app_loop)
106
+ exit_event = create_sync_primitive(asyncio.Event, app_loop)
106
107
  app_thread = Thread(
107
108
  name="TestAppThread",
108
109
  target=app_thread_func,
@@ -162,3 +163,11 @@ def require_hadoop(func):
162
163
  not os.environ.get("WITH_HADOOP"), reason="Only run when hadoop is installed"
163
164
  )(func)
164
165
  return func
166
+
167
+
168
+ def get_test_unique_name(size=None):
169
+ test_name = os.getenv("PYTEST_CURRENT_TEST", "pyodps_test")
170
+ digest = hashlib.md5(to_binary(test_name)).hexdigest()
171
+ if size:
172
+ digest = digest[:size]
173
+ return digest + "_" + str(os.getpid())
maxframe/utils.py CHANGED
@@ -33,7 +33,6 @@ import sys
33
33
  import threading
34
34
  import time
35
35
  import tokenize as pytokenize
36
- import traceback
37
36
  import types
38
37
  import weakref
39
38
  import zlib
@@ -396,18 +395,6 @@ def build_tileable_dir_name(tileable_key: str) -> str:
396
395
  return m.hexdigest()
397
396
 
398
397
 
399
- def extract_messages_and_stacks(exc: Exception) -> Tuple[List[str], List[str]]:
400
- cur_exc = exc
401
- messages, stacks = [], []
402
- while True:
403
- messages.append(str(cur_exc))
404
- stacks.append("".join(traceback.format_tb(cur_exc.__traceback__)))
405
- if exc.__cause__ is None:
406
- break
407
- cur_exc = exc.__cause__
408
- return messages, stacks
409
-
410
-
411
398
  async def wait_http_response(
412
399
  url: str, *, request_timeout: TimeoutType = None, **kwargs
413
400
  ) -> httpclient.HTTPResponse:
@@ -449,6 +436,29 @@ async def to_thread_pool(func, *args, pool=None, **kwargs):
449
436
  return await loop.run_in_executor(pool, func_call)
450
437
 
451
438
 
439
+ _PrimitiveType = TypeVar("_PrimitiveType")
440
+
441
+
442
+ def create_sync_primitive(
443
+ cls: Type[_PrimitiveType], loop: asyncio.AbstractEventLoop
444
+ ) -> _PrimitiveType:
445
+ """
446
+ Create an asyncio sync primitive (locks, events, etc.)
447
+ in a certain event loop.
448
+ """
449
+ if sys.version_info[1] < 10:
450
+ return cls(loop=loop)
451
+
452
+ # From Python3.10 the loop parameter has been removed. We should work around here.
453
+ old_loop = asyncio.get_event_loop()
454
+ try:
455
+ asyncio.set_event_loop(loop)
456
+ primitive = cls()
457
+ finally:
458
+ asyncio.set_event_loop(old_loop)
459
+ return primitive
460
+
461
+
452
462
  class ToThreadCancelledError(asyncio.CancelledError):
453
463
  def __init__(self, *args, result=None):
454
464
  super().__init__(*args)
@@ -519,6 +529,7 @@ def config_odps_default_options():
519
529
  "metaservice.client.cache.enable": "false",
520
530
  "odps.sql.session.result.cache.enable": "false",
521
531
  "odps.sql.submit.mode": "script",
532
+ "odps.sql.job.max.time.hours": 72,
522
533
  }
523
534
 
524
535