maxframe 0.1.0b4__cp310-cp310-win_amd64.whl → 1.0.0rc1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. Click here for more details.

Files changed (81) hide show
  1. maxframe/__init__.py +1 -0
  2. maxframe/_utils.cp310-win_amd64.pyd +0 -0
  3. maxframe/codegen.py +56 -3
  4. maxframe/config/config.py +15 -1
  5. maxframe/core/__init__.py +0 -3
  6. maxframe/core/entity/__init__.py +1 -8
  7. maxframe/core/entity/objects.py +3 -45
  8. maxframe/core/graph/core.cp310-win_amd64.pyd +0 -0
  9. maxframe/core/graph/core.pyx +4 -4
  10. maxframe/dataframe/__init__.py +1 -0
  11. maxframe/dataframe/core.py +30 -8
  12. maxframe/dataframe/datasource/read_odps_query.py +3 -1
  13. maxframe/dataframe/datasource/read_odps_table.py +3 -1
  14. maxframe/dataframe/datastore/tests/__init__.py +13 -0
  15. maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
  16. maxframe/dataframe/datastore/to_odps.py +21 -0
  17. maxframe/dataframe/indexing/align.py +1 -1
  18. maxframe/dataframe/misc/__init__.py +4 -0
  19. maxframe/dataframe/misc/apply.py +3 -1
  20. maxframe/dataframe/misc/case_when.py +141 -0
  21. maxframe/dataframe/misc/memory_usage.py +2 -2
  22. maxframe/dataframe/misc/pivot_table.py +262 -0
  23. maxframe/dataframe/misc/tests/test_misc.py +84 -0
  24. maxframe/dataframe/plotting/core.py +2 -2
  25. maxframe/dataframe/reduction/core.py +2 -1
  26. maxframe/dataframe/statistics/corr.py +3 -3
  27. maxframe/dataframe/utils.py +7 -0
  28. maxframe/errors.py +13 -0
  29. maxframe/extension.py +12 -0
  30. maxframe/learn/contrib/utils.py +52 -0
  31. maxframe/learn/contrib/xgboost/__init__.py +26 -0
  32. maxframe/learn/contrib/xgboost/classifier.py +86 -0
  33. maxframe/learn/contrib/xgboost/core.py +156 -0
  34. maxframe/learn/contrib/xgboost/dmatrix.py +150 -0
  35. maxframe/learn/contrib/xgboost/predict.py +138 -0
  36. maxframe/learn/contrib/xgboost/regressor.py +78 -0
  37. maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
  38. maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
  39. maxframe/learn/contrib/xgboost/train.py +121 -0
  40. maxframe/learn/utils/__init__.py +15 -0
  41. maxframe/learn/utils/core.py +29 -0
  42. maxframe/lib/mmh3.cp310-win_amd64.pyd +0 -0
  43. maxframe/lib/mmh3.pyi +43 -0
  44. maxframe/lib/wrapped_pickle.py +2 -1
  45. maxframe/odpsio/arrow.py +2 -3
  46. maxframe/odpsio/tableio.py +22 -0
  47. maxframe/odpsio/tests/test_schema.py +16 -11
  48. maxframe/opcodes.py +3 -0
  49. maxframe/protocol.py +108 -10
  50. maxframe/serialization/core.cp310-win_amd64.pyd +0 -0
  51. maxframe/serialization/core.pxd +3 -0
  52. maxframe/serialization/core.pyi +64 -0
  53. maxframe/serialization/core.pyx +54 -25
  54. maxframe/serialization/exception.py +1 -1
  55. maxframe/serialization/pandas.py +7 -2
  56. maxframe/serialization/serializables/core.py +119 -12
  57. maxframe/serialization/serializables/tests/test_serializable.py +46 -4
  58. maxframe/session.py +28 -0
  59. maxframe/tensor/__init__.py +1 -1
  60. maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
  61. maxframe/tensor/base/__init__.py +2 -0
  62. maxframe/tensor/base/atleast_1d.py +74 -0
  63. maxframe/tensor/base/unique.py +205 -0
  64. maxframe/tensor/datasource/array.py +4 -2
  65. maxframe/tensor/datasource/scalar.py +1 -1
  66. maxframe/tensor/reduction/count_nonzero.py +1 -1
  67. maxframe/tests/test_protocol.py +34 -0
  68. maxframe/tests/test_utils.py +0 -12
  69. maxframe/tests/utils.py +2 -2
  70. maxframe/udf.py +63 -3
  71. maxframe/utils.py +22 -13
  72. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/METADATA +3 -3
  73. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/RECORD +80 -61
  74. maxframe_client/__init__.py +0 -1
  75. maxframe_client/fetcher.py +65 -3
  76. maxframe_client/session/odps.py +74 -5
  77. maxframe_client/session/task.py +65 -71
  78. maxframe_client/tests/test_session.py +64 -1
  79. maxframe_client/clients/spe.py +0 -104
  80. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/WHEEL +0 -0
  81. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -130,11 +130,30 @@ cdef Serializer get_deserializer(int32_t deserializer_id):
130
130
 
131
131
  cdef class Serializer:
132
132
  serializer_id = None
133
+ _public_data_context_key = 0x7fffffff - 1
133
134
 
134
135
  def __cinit__(self):
135
136
  # make the value can be referenced with C code
136
137
  self._serializer_id = self.serializer_id
137
138
 
139
+ cpdef bint is_public_data_exist(self, dict context, object key):
140
+ cdef dict public_dict = context.get(self._public_data_context_key, None)
141
+ if public_dict is None:
142
+ return False
143
+ return key in public_dict
144
+
145
+ cpdef put_public_data(self, dict context, object key, object value):
146
+ cdef dict public_dict = context.get(self._public_data_context_key, None)
147
+ if public_dict is None:
148
+ public_dict = context[self._public_data_context_key] = {}
149
+ public_dict[key] = value
150
+
151
+ cpdef get_public_data(self, dict context, object key):
152
+ cdef dict public_dict = context.get(self._public_data_context_key, None)
153
+ if public_dict is None:
154
+ return None
155
+ return public_dict.get(key)
156
+
138
157
  cpdef serial(self, object obj, dict context):
139
158
  """
140
159
  Returns intermediate serialization result of certain object.
@@ -993,17 +1012,20 @@ def serialize(obj, dict context = None):
993
1012
  cdef list subs
994
1013
  cdef bint final
995
1014
  cdef _IdContextHolder id_context_holder = _IdContextHolder()
1015
+ cdef tuple result
996
1016
 
997
1017
  context = context if context is not None else dict()
998
1018
  serialized, subs, final = _serial_single(obj, context, id_context_holder)
999
1019
  if final or not subs:
1000
1020
  # marked as a leaf node, return directly
1001
- return [{}, serialized], subs
1002
-
1003
- serial_stack.append(_SerialStackItem(serialized, subs))
1004
- return _serialize_with_stack(
1005
- serial_stack, None, context, id_context_holder, result_bufs_list
1006
- )
1021
+ result = [{}, serialized], subs
1022
+ else:
1023
+ serial_stack.append(_SerialStackItem(serialized, subs))
1024
+ result = _serialize_with_stack(
1025
+ serial_stack, None, context, id_context_holder, result_bufs_list
1026
+ )
1027
+ result[0][0]["_PUB"] = context.get(Serializer._public_data_context_key)
1028
+ return result
1007
1029
 
1008
1030
 
1009
1031
  async def serialize_with_spawn(
@@ -1036,31 +1058,38 @@ async def serialize_with_spawn(
1036
1058
  cdef list subs
1037
1059
  cdef bint final
1038
1060
  cdef _IdContextHolder id_context_holder = _IdContextHolder()
1061
+ cdef tuple result
1039
1062
 
1040
1063
  context = context if context is not None else dict()
1041
1064
  serialized, subs, final = _serial_single(obj, context, id_context_holder)
1042
1065
  if final or not subs:
1043
1066
  # marked as a leaf node, return directly
1044
- return [{}, serialized], subs
1045
-
1046
- serial_stack.append(_SerialStackItem(serialized, subs))
1067
+ result = [{}, serialized], subs
1068
+ else:
1069
+ serial_stack.append(_SerialStackItem(serialized, subs))
1047
1070
 
1048
- try:
1049
- result = _serialize_with_stack(
1050
- serial_stack, None, context, id_context_holder, result_bufs_list, spawn_threshold
1051
- )
1052
- except _SerializeObjectOverflow as ex:
1053
- result = await asyncio.get_running_loop().run_in_executor(
1054
- executor,
1055
- _serialize_with_stack,
1056
- serial_stack,
1057
- ex.cur_serialized,
1058
- context,
1059
- id_context_holder,
1060
- result_bufs_list,
1061
- 0,
1062
- ex.num_total_serialized,
1063
- )
1071
+ try:
1072
+ result = _serialize_with_stack(
1073
+ serial_stack,
1074
+ None,
1075
+ context,
1076
+ id_context_holder,
1077
+ result_bufs_list,
1078
+ spawn_threshold,
1079
+ )
1080
+ except _SerializeObjectOverflow as ex:
1081
+ result = await asyncio.get_running_loop().run_in_executor(
1082
+ executor,
1083
+ _serialize_with_stack,
1084
+ serial_stack,
1085
+ ex.cur_serialized,
1086
+ context,
1087
+ id_context_holder,
1088
+ result_bufs_list,
1089
+ 0,
1090
+ ex.num_total_serialized,
1091
+ )
1092
+ result[0][0]["_PUB"] = context.get(Serializer._public_data_context_key)
1064
1093
  return result
1065
1094
 
1066
1095
 
@@ -35,7 +35,7 @@ class RemoteException(MaxFrameError):
35
35
  def from_exception(cls, exc: Exception):
36
36
  try:
37
37
  buffers = pickle_buffers(exc)
38
- except (TypeError, pickle.PicklingError):
38
+ except:
39
39
  logger.exception("Cannot pickle exception %s", exc)
40
40
  buffers = []
41
41
 
@@ -176,11 +176,16 @@ class PdTimestampSerializer(Serializer):
176
176
 
177
177
  class PdTimedeltaSerializer(Serializer):
178
178
  def serial(self, obj: pd.Timedelta, context: Dict):
179
- return [int(obj.seconds), obj.microseconds, obj.nanoseconds], [], True
179
+ return [int(obj.seconds), obj.microseconds, obj.nanoseconds, obj.days], [], True
180
180
 
181
181
  def deserial(self, serialized: List, context: Dict, subs: List):
182
+ days = 0 if len(serialized) < 4 else serialized[3]
183
+ seconds, microseconds, nanoseconds = serialized[:3]
182
184
  return pd.Timedelta(
183
- seconds=serialized[0], microseconds=serialized[1], nanoseconds=serialized[2]
185
+ days=days,
186
+ seconds=seconds,
187
+ microseconds=microseconds,
188
+ nanoseconds=nanoseconds,
184
189
  )
185
190
 
186
191
 
@@ -12,12 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import operator
16
15
  import weakref
17
- from typing import Dict, List, Tuple, Type
16
+ from collections import defaultdict
17
+ from typing import Any, Dict, List, Optional, Tuple, Type
18
18
 
19
19
  import msgpack
20
20
 
21
+ from ...lib.mmh3 import hash
21
22
  from ..core import Placeholder, Serializer, buffered, load_type
22
23
  from .field import Field
23
24
  from .field_type import DictType, ListType, PrimitiveFieldType, TupleType
@@ -50,11 +51,16 @@ def _is_field_primitive_compound(field: Field):
50
51
  class SerializableMeta(type):
51
52
  def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
52
53
  # All the fields including misc fields.
54
+ name_hash = hash(f"{properties.get('__module__')}.{name}")
53
55
  all_fields = dict()
56
+ # mapping field names to base classes
57
+ field_to_cls_hash = dict()
54
58
 
55
59
  for base in bases:
56
- if hasattr(base, "_FIELDS"):
57
- all_fields.update(base._FIELDS)
60
+ if not hasattr(base, "_FIELDS"):
61
+ continue
62
+ all_fields.update(base._FIELDS)
63
+ field_to_cls_hash.update(base._FIELD_TO_NAME_HASH)
58
64
 
59
65
  properties_without_fields = {}
60
66
  properties_field_slot_names = []
@@ -64,6 +70,8 @@ class SerializableMeta(type):
64
70
  continue
65
71
 
66
72
  field = all_fields.get(k)
73
+ # record the field for the class being created
74
+ field_to_cls_hash[k] = name_hash
67
75
  if field is None:
68
76
  properties_field_slot_names.append(k)
69
77
  else:
@@ -75,23 +83,40 @@ class SerializableMeta(type):
75
83
 
76
84
  # Make field order deterministic to serialize it as list instead of dict.
77
85
  field_order = list(all_fields)
78
- all_fields = dict(sorted(all_fields.items(), key=operator.itemgetter(0)))
79
86
  primitive_fields = []
87
+ primitive_field_names = set()
80
88
  non_primitive_fields = []
81
- for v in all_fields.values():
89
+ for field_name, v in all_fields.items():
82
90
  if _is_field_primitive_compound(v):
83
91
  primitive_fields.append(v)
92
+ primitive_field_names.add(field_name)
84
93
  else:
85
94
  non_primitive_fields.append(v)
86
95
 
96
+ # count number of fields for every base class
97
+ cls_to_primitive_field_count = defaultdict(lambda: 0)
98
+ cls_to_non_primitive_field_count = defaultdict(lambda: 0)
99
+ for field_name in field_order:
100
+ cls_hash = field_to_cls_hash[field_name]
101
+ if field_name in primitive_field_names:
102
+ cls_to_primitive_field_count[cls_hash] += 1
103
+ else:
104
+ cls_to_non_primitive_field_count[cls_hash] += 1
105
+
87
106
  slots = set(properties.pop("__slots__", set()))
88
107
  slots.update(properties_field_slot_names)
89
108
 
90
109
  properties = properties_without_fields
110
+ properties["_NAME_HASH"] = name_hash
91
111
  properties["_FIELDS"] = all_fields
92
112
  properties["_FIELD_ORDER"] = field_order
113
+ properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
93
114
  properties["_PRIMITIVE_FIELDS"] = primitive_fields
115
+ properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = dict(cls_to_primitive_field_count)
94
116
  properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
117
+ properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = dict(
118
+ cls_to_non_primitive_field_count
119
+ )
95
120
  properties["__slots__"] = tuple(slots)
96
121
 
97
122
  clz = type.__new__(mcs, name, bases, properties)
@@ -114,10 +139,14 @@ class Serializable(metaclass=SerializableMeta):
114
139
  _cache_primitive_serial = False
115
140
  _ignore_non_existing_keys = False
116
141
 
142
+ _NAME_HASH: int
117
143
  _FIELDS: Dict[str, Field]
118
144
  _FIELD_ORDER: List[str]
145
+ _FIELD_TO_NAME_HASH: Dict[str, int]
119
146
  _PRIMITIVE_FIELDS: List[str]
147
+ _CLS_TO_PRIMITIVE_FIELD_COUNT: Dict[int, int]
120
148
  _NON_PRIMITIVE_FIELDS: List[str]
149
+ _CLS_TO_NON_PRIMITIVE_FIELD_COUNT: Dict[int, int]
121
150
 
122
151
  def __init__(self, *args, **kwargs):
123
152
  fields = self._FIELDS
@@ -180,6 +209,10 @@ class SerializableSerializer(Serializer):
180
209
  Leverage DictSerializer to perform serde.
181
210
  """
182
211
 
212
+ @classmethod
213
+ def _get_obj_field_count_key(cls, obj: Serializable):
214
+ return f"FC_{obj._NAME_HASH}"
215
+
183
216
  @classmethod
184
217
  def _get_field_values(cls, obj: Serializable, fields):
185
218
  values = []
@@ -210,6 +243,18 @@ class SerializableSerializer(Serializer):
210
243
 
211
244
  compound_vals = self._get_field_values(obj, obj._NON_PRIMITIVE_FIELDS)
212
245
  cls_module = f"{type(obj).__module__}#{type(obj).__qualname__}"
246
+
247
+ field_count_key = self._get_obj_field_count_key(obj)
248
+ if not self.is_public_data_exist(context, field_count_key):
249
+ # store field distribution for current Serializable
250
+ counts = [
251
+ list(obj._CLS_TO_PRIMITIVE_FIELD_COUNT.items()),
252
+ list(obj._CLS_TO_NON_PRIMITIVE_FIELD_COUNT.items()),
253
+ ]
254
+ field_count_data = msgpack.dumps(counts)
255
+ self.put_public_data(
256
+ context, self._get_obj_field_count_key(obj), field_count_data
257
+ )
213
258
  return [cls_module, primitive_vals], [compound_vals], False
214
259
 
215
260
  @staticmethod
@@ -229,6 +274,62 @@ class SerializableSerializer(Serializer):
229
274
  else:
230
275
  field.set(obj, value)
231
276
 
277
+ @classmethod
278
+ def _set_field_values(
279
+ cls,
280
+ obj: Serializable,
281
+ values: List[Any],
282
+ client_cls_to_field_count: Optional[Dict[str, int]],
283
+ is_primitive: bool = True,
284
+ ):
285
+ obj_class = type(obj)
286
+ if is_primitive:
287
+ server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
288
+ server_fields = obj_class._PRIMITIVE_FIELDS
289
+ else:
290
+ server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
291
+ server_fields = obj_class._NON_PRIMITIVE_FIELDS
292
+
293
+ if client_cls_to_field_count:
294
+ field_num, server_field_num = 0, 0
295
+ for cls_hash, count in client_cls_to_field_count.items():
296
+ # cut values and fields given field distribution
297
+ # at client and server end
298
+ cls_fields = server_fields[server_field_num : field_num + count]
299
+ cls_values = values[field_num : field_num + count]
300
+ for field, value in zip(cls_fields, cls_values):
301
+ if not is_primitive or value != {}:
302
+ cls._set_field_value(obj, field, value)
303
+ field_num += count
304
+ server_field_num += server_cls_to_field_count[cls_hash]
305
+ else:
306
+ # todo remove this branch when all versions below v0.1.0b5 are eliminated
307
+ from .field import AnyField
308
+
309
+ # legacy serialization style, with all fields sorted by name
310
+ if is_primitive:
311
+ field_attr = "_legacy_deprecated_primitives"
312
+ else:
313
+ field_attr = "_legacy_deprecated_non_primitives"
314
+ deprecated_fields = []
315
+ deprecated_names = set()
316
+ if hasattr(obj_class, field_attr):
317
+ deprecated_names = set(getattr(obj_class, field_attr))
318
+ for field_name in deprecated_names:
319
+ field = AnyField(tag=field_name)
320
+ field.name = field_name
321
+ deprecated_fields.append(field)
322
+ server_fields = sorted(
323
+ server_fields + deprecated_fields, key=lambda f: f.name
324
+ )
325
+ for field, value in zip(server_fields, values):
326
+ if not is_primitive or value != {}:
327
+ try:
328
+ cls._set_field_value(obj, field, value)
329
+ except AttributeError: # pragma: no cover
330
+ if field.name not in deprecated_names:
331
+ raise
332
+
232
333
  def deserial(self, serialized: List, context: Dict, subs: List) -> Serializable:
233
334
  obj_class_name, primitives = serialized
234
335
  obj_class = load_type(obj_class_name, Serializable)
@@ -238,14 +339,20 @@ class SerializableSerializer(Serializer):
238
339
 
239
340
  obj = obj_class.__new__(obj_class)
240
341
 
241
- if primitives:
242
- for field, value in zip(obj_class._PRIMITIVE_FIELDS, primitives):
243
- if value != {}:
244
- self._set_field_value(obj, field, value)
342
+ field_count_data = self.get_public_data(
343
+ context, self._get_obj_field_count_key(obj)
344
+ )
345
+ if field_count_data is not None:
346
+ cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
347
+ cls_to_prim_key = dict(cls_to_prim_key)
348
+ cls_to_non_prim_key = dict(cls_to_non_prim_key)
349
+ else:
350
+ cls_to_prim_key, cls_to_non_prim_key = None, None
245
351
 
352
+ if primitives:
353
+ self._set_field_values(obj, primitives, cls_to_prim_key, True)
246
354
  if obj_class._NON_PRIMITIVE_FIELDS:
247
- for field, value in zip(obj_class._NON_PRIMITIVE_FIELDS, subs[0]):
248
- self._set_field_value(obj, field, value)
355
+ self._set_field_values(obj, subs[0], cls_to_non_prim_key, False)
249
356
  obj.__on_deserialize__()
250
357
  return obj
251
358
 
@@ -94,6 +94,11 @@ class MySimpleSerializable(Serializable):
94
94
  _ref_val = ReferenceField("ref_val", "MySimpleSerializable")
95
95
 
96
96
 
97
+ class MySubSerializable(MySimpleSerializable):
98
+ _m_int_val = Int64Field("m_int_val", default=250)
99
+ _m_str_val = StringField("m_str_val", default="SUB_STR")
100
+
101
+
97
102
  class MySerializable(Serializable):
98
103
  _id = IdentityField("id")
99
104
  _any_val = AnyField("any_val")
@@ -141,7 +146,7 @@ class MySerializable(Serializable):
141
146
 
142
147
 
143
148
  @pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
144
- @switch_unpickle
149
+ @switch_unpickle(forbidden=False)
145
150
  def test_serializable(set_is_ci):
146
151
  my_serializable = MySerializable(
147
152
  _id="1",
@@ -165,7 +170,9 @@ def test_serializable(set_is_ci):
165
170
  _key_val=MyHasKey("aaa"),
166
171
  _ndarray_val=np.random.rand(4, 3),
167
172
  _datetime64_val=pd.Timestamp(123),
168
- _timedelta64_val=pd.Timedelta(days=1),
173
+ _timedelta64_val=pd.Timedelta(
174
+ days=1, seconds=123, microseconds=345, nanoseconds=132
175
+ ),
169
176
  _datatype_val=np.dtype(np.int32),
170
177
  _index_val=pd.Index([1, 2]),
171
178
  _series_val=pd.Series(["a", "b"]),
@@ -187,9 +194,44 @@ def test_serializable(set_is_ci):
187
194
  _assert_serializable_eq(my_serializable, my_serializable2)
188
195
 
189
196
 
197
+ @pytest.mark.parametrize("set_is_ci", [False, True], indirect=True)
198
+ @switch_unpickle
199
+ def test_compatible_serializable(set_is_ci):
200
+ global MySimpleSerializable, MySubSerializable
201
+
202
+ old_base, old_sub = MySimpleSerializable, MySubSerializable
203
+
204
+ try:
205
+ my_sub_serializable = MySubSerializable(
206
+ _id="id_val",
207
+ _list_val=["abcd", "wxyz"],
208
+ _ref_val=MyHasKey(),
209
+ _m_int_val=3412,
210
+ _m_str_val="dfgghj",
211
+ )
212
+ header, buffers = serialize(my_sub_serializable)
213
+
214
+ class MySimpleSerializable(Serializable):
215
+ _id = IdentityField("id")
216
+ _int_val = Int64Field("int_val", default=1000)
217
+ _list_val = ListField("list_val", default_factory=list)
218
+ _ref_val = ReferenceField("ref_val", "MySimpleSerializable")
219
+ _dict_val = DictField("dict_val")
220
+
221
+ class MySubSerializable(MySimpleSerializable):
222
+ _m_int_val = Int64Field("m_int_val", default=250)
223
+ _m_str_val = StringField("m_str_val", default="SUB_STR")
224
+
225
+ my_sub_serializable2 = deserialize(header, buffers)
226
+ assert type(my_sub_serializable) is not type(my_sub_serializable2)
227
+ _assert_serializable_eq(my_sub_serializable, my_sub_serializable2)
228
+ finally:
229
+ MySimpleSerializable, MySubSerializable = old_base, old_sub
230
+
231
+
190
232
  def _assert_serializable_eq(my_serializable, my_serializable2):
191
233
  for field_name, field in my_serializable._FIELDS.items():
192
- if not hasattr(my_serializable, field.tag):
234
+ if not hasattr(my_serializable, field.name):
193
235
  continue
194
236
  expect_value = getattr(my_serializable, field_name)
195
237
  actual_value = getattr(my_serializable2, field_name)
@@ -208,7 +250,7 @@ def _assert_serializable_eq(my_serializable, my_serializable2):
208
250
  elif callable(expect_value):
209
251
  assert expect_value(1) == actual_value(1)
210
252
  else:
211
- assert expect_value == actual_value
253
+ assert expect_value == actual_value, f"Field {field_name}"
212
254
 
213
255
 
214
256
  @pytest.mark.parametrize("set_is_ci", [True], indirect=True)
maxframe/session.py CHANGED
@@ -365,6 +365,15 @@ class AbstractAsyncSession(AbstractSession, metaclass=ABCMeta):
365
365
  Stop server.
366
366
  """
367
367
 
368
+ @abstractmethod
369
+ async def get_logview_address(self, hours=None) -> Optional[str]:
370
+ """
371
+ Get Logview address
372
+ Returns
373
+ -------
374
+ Logview address
375
+ """
376
+
368
377
  def close(self):
369
378
  asyncio.run(self.destroy())
370
379
 
@@ -549,6 +558,15 @@ class AbstractSyncSession(AbstractSession, metaclass=ABCMeta):
549
558
 
550
559
  return fetch(tileables, self, offsets=offsets, sizes=sizes)
551
560
 
561
+ @abstractmethod
562
+ def get_logview_address(self, hours=None) -> Optional[str]:
563
+ """
564
+ Get logview address
565
+ Returns
566
+ -------
567
+ logview address
568
+ """
569
+
552
570
 
553
571
  def _delegate_to_isolated_session(func: Union[Callable, Coroutine]):
554
572
  if asyncio.iscoroutinefunction(func):
@@ -728,6 +746,11 @@ class AsyncSession(AbstractAsyncSession):
728
746
  await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(coro, self._loop))
729
747
  stop_isolation()
730
748
 
749
+ @implements(AbstractAsyncSession.get_logview_address)
750
+ @_delegate_to_isolated_session
751
+ async def get_logview_address(self, hours=None) -> Optional[str]:
752
+ pass # pragma: no cover
753
+
731
754
 
732
755
  class ProgressBar:
733
756
  def __init__(self, show_progress):
@@ -949,6 +972,11 @@ class SyncSession(AbstractSyncSession):
949
972
  def get_cluster_versions(self) -> List[str]:
950
973
  pass # pragma: no cover
951
974
 
975
+ @implements(AbstractSyncSession.get_logview_address)
976
+ @_delegate_to_isolated_session
977
+ def get_logview_address(self, hours=None) -> Optional[str]:
978
+ pass # pragma: no cover
979
+
952
980
  def destroy(self):
953
981
  coro = self._isolated_session.destroy()
954
982
  asyncio.run_coroutine_threadsafe(coro, self._loop).result()
@@ -114,7 +114,7 @@ from .arithmetic import (
114
114
  )
115
115
  from .arithmetic import truediv as true_divide
116
116
  from .arithmetic import trunc
117
- from .base import broadcast_to, transpose, where
117
+ from .base import broadcast_to, transpose, unique, where
118
118
  from .core import Tensor
119
119
  from .datasource import (
120
120
  arange,
@@ -252,7 +252,7 @@ def test_compare():
252
252
 
253
253
  def test_frexp():
254
254
  t1 = ones((3, 4, 5), chunk_size=2)
255
- t2 = empty((3, 4, 5), dtype=np.float_, chunk_size=2)
255
+ t2 = empty((3, 4, 5), dtype=np.dtype(float), chunk_size=2)
256
256
  op_type = type(t1.op)
257
257
 
258
258
  o1, o2 = frexp(t1)
@@ -13,9 +13,11 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from .astype import TensorAstype
16
+ from .atleast_1d import atleast_1d
16
17
  from .broadcast_to import TensorBroadcastTo, broadcast_to
17
18
  from .ravel import ravel
18
19
  from .transpose import transpose
20
+ from .unique import unique
19
21
  from .where import TensorWhere, where
20
22
 
21
23
 
@@ -0,0 +1,74 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import numpy as np
18
+
19
+ from ...core import ExecutableTuple
20
+ from ..datasource import tensor as astensor
21
+
22
+
23
+ def atleast_1d(*tensors):
24
+ """
25
+ Convert inputs to tensors with at least one dimension.
26
+
27
+ Scalar inputs are converted to 1-dimensional tensors, whilst
28
+ higher-dimensional inputs are preserved.
29
+
30
+ Parameters
31
+ ----------
32
+ tensors1, tensors2, ... : array_like
33
+ One or more input tensors.
34
+
35
+ Returns
36
+ -------
37
+ ret : Tensor
38
+ A tensor, or list of tensors, each with ``a.ndim >= 1``.
39
+ Copies are made only if necessary.
40
+
41
+ See Also
42
+ --------
43
+ atleast_2d, atleast_3d
44
+
45
+ Examples
46
+ --------
47
+ >>> import maxframe.tensor as mt
48
+
49
+ >>> mt.atleast_1d(1.0).execute()
50
+ array([ 1.])
51
+
52
+ >>> x = mt.arange(9.0).reshape(3,3)
53
+ >>> mt.atleast_1d(x).execute()
54
+ array([[ 0., 1., 2.],
55
+ [ 3., 4., 5.],
56
+ [ 6., 7., 8.]])
57
+ >>> mt.atleast_1d(x) is x
58
+ True
59
+
60
+ >>> mt.atleast_1d(1, [3, 4]).execute()
61
+ [array([1]), array([3, 4])]
62
+
63
+ """
64
+ new_tensors = []
65
+ for x in tensors:
66
+ x = astensor(x)
67
+ if x.ndim == 0:
68
+ x = x[np.newaxis]
69
+
70
+ new_tensors.append(x)
71
+
72
+ if len(new_tensors) == 1:
73
+ return new_tensors[0]
74
+ return ExecutableTuple(new_tensors)