maxframe 1.0.0rc1__cp39-cp39-macosx_10_9_universal2.whl → 1.0.0rc3__cp39-cp39-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. Click here for more details.

Files changed (138)
  1. maxframe/_utils.cpython-39-darwin.so +0 -0
  2. maxframe/codegen.py +3 -6
  3. maxframe/config/config.py +49 -10
  4. maxframe/config/validators.py +42 -11
  5. maxframe/conftest.py +15 -2
  6. maxframe/core/__init__.py +2 -13
  7. maxframe/core/entity/__init__.py +0 -4
  8. maxframe/core/entity/objects.py +46 -3
  9. maxframe/core/entity/output_types.py +0 -3
  10. maxframe/core/entity/tests/test_objects.py +43 -0
  11. maxframe/core/entity/tileables.py +5 -78
  12. maxframe/core/graph/__init__.py +2 -2
  13. maxframe/core/graph/builder/__init__.py +0 -1
  14. maxframe/core/graph/builder/base.py +5 -4
  15. maxframe/core/graph/builder/tileable.py +4 -4
  16. maxframe/core/graph/builder/utils.py +4 -8
  17. maxframe/core/graph/core.cpython-39-darwin.so +0 -0
  18. maxframe/core/graph/entity.py +9 -33
  19. maxframe/core/operator/__init__.py +2 -9
  20. maxframe/core/operator/base.py +3 -5
  21. maxframe/core/operator/objects.py +0 -9
  22. maxframe/core/operator/utils.py +55 -0
  23. maxframe/dataframe/__init__.py +1 -1
  24. maxframe/dataframe/arithmetic/around.py +5 -17
  25. maxframe/dataframe/arithmetic/core.py +15 -7
  26. maxframe/dataframe/arithmetic/docstring.py +5 -55
  27. maxframe/dataframe/arithmetic/tests/test_arithmetic.py +22 -0
  28. maxframe/dataframe/core.py +5 -5
  29. maxframe/dataframe/datasource/date_range.py +2 -2
  30. maxframe/dataframe/datasource/read_odps_query.py +7 -1
  31. maxframe/dataframe/datasource/read_odps_table.py +3 -2
  32. maxframe/dataframe/datasource/tests/test_datasource.py +14 -0
  33. maxframe/dataframe/datastore/to_odps.py +1 -1
  34. maxframe/dataframe/groupby/cum.py +0 -1
  35. maxframe/dataframe/groupby/tests/test_groupby.py +4 -0
  36. maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
  37. maxframe/dataframe/indexing/rename.py +3 -37
  38. maxframe/dataframe/indexing/sample.py +0 -1
  39. maxframe/dataframe/indexing/set_index.py +68 -1
  40. maxframe/dataframe/merge/merge.py +236 -2
  41. maxframe/dataframe/merge/tests/test_merge.py +123 -0
  42. maxframe/dataframe/misc/apply.py +3 -10
  43. maxframe/dataframe/misc/case_when.py +1 -1
  44. maxframe/dataframe/misc/describe.py +2 -2
  45. maxframe/dataframe/misc/drop_duplicates.py +4 -25
  46. maxframe/dataframe/misc/eval.py +4 -0
  47. maxframe/dataframe/misc/pct_change.py +1 -83
  48. maxframe/dataframe/misc/transform.py +1 -30
  49. maxframe/dataframe/misc/value_counts.py +4 -17
  50. maxframe/dataframe/missing/dropna.py +1 -1
  51. maxframe/dataframe/missing/fillna.py +5 -5
  52. maxframe/dataframe/operators.py +1 -17
  53. maxframe/dataframe/reduction/core.py +2 -2
  54. maxframe/dataframe/sort/sort_values.py +1 -11
  55. maxframe/dataframe/statistics/quantile.py +5 -17
  56. maxframe/dataframe/utils.py +4 -7
  57. maxframe/io/objects/__init__.py +24 -0
  58. maxframe/io/objects/core.py +140 -0
  59. maxframe/io/objects/tensor.py +76 -0
  60. maxframe/io/objects/tests/__init__.py +13 -0
  61. maxframe/io/objects/tests/test_object_io.py +97 -0
  62. maxframe/{odpsio → io/odpsio}/__init__.py +3 -1
  63. maxframe/{odpsio → io/odpsio}/arrow.py +12 -8
  64. maxframe/{odpsio → io/odpsio}/schema.py +15 -12
  65. maxframe/io/odpsio/tableio.py +702 -0
  66. maxframe/io/odpsio/tests/__init__.py +13 -0
  67. maxframe/{odpsio → io/odpsio}/tests/test_schema.py +19 -18
  68. maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +50 -23
  69. maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
  70. maxframe/io/odpsio/volumeio.py +57 -0
  71. maxframe/learn/contrib/xgboost/classifier.py +26 -2
  72. maxframe/learn/contrib/xgboost/core.py +87 -2
  73. maxframe/learn/contrib/xgboost/dmatrix.py +3 -6
  74. maxframe/learn/contrib/xgboost/predict.py +21 -7
  75. maxframe/learn/contrib/xgboost/regressor.py +3 -10
  76. maxframe/learn/contrib/xgboost/train.py +27 -17
  77. maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
  78. maxframe/lib/mmh3.cpython-39-darwin.so +0 -0
  79. maxframe/protocol.py +41 -17
  80. maxframe/remote/core.py +4 -8
  81. maxframe/serialization/__init__.py +1 -0
  82. maxframe/serialization/core.cpython-39-darwin.so +0 -0
  83. maxframe/serialization/serializables/core.py +48 -9
  84. maxframe/tensor/__init__.py +69 -2
  85. maxframe/tensor/arithmetic/isclose.py +1 -0
  86. maxframe/tensor/arithmetic/tests/test_arithmetic.py +21 -17
  87. maxframe/tensor/core.py +5 -136
  88. maxframe/tensor/datasource/array.py +3 -0
  89. maxframe/tensor/datasource/full.py +1 -1
  90. maxframe/tensor/datasource/tests/test_datasource.py +1 -1
  91. maxframe/tensor/indexing/flatnonzero.py +1 -1
  92. maxframe/tensor/merge/__init__.py +2 -0
  93. maxframe/tensor/merge/concatenate.py +98 -0
  94. maxframe/tensor/merge/tests/test_merge.py +30 -1
  95. maxframe/tensor/merge/vstack.py +70 -0
  96. maxframe/tensor/{base → misc}/__init__.py +2 -0
  97. maxframe/tensor/{base → misc}/atleast_1d.py +0 -2
  98. maxframe/tensor/misc/atleast_2d.py +70 -0
  99. maxframe/tensor/misc/atleast_3d.py +85 -0
  100. maxframe/tensor/misc/tests/__init__.py +13 -0
  101. maxframe/tensor/{base → misc}/transpose.py +22 -18
  102. maxframe/tensor/{base → misc}/unique.py +2 -2
  103. maxframe/tensor/operators.py +1 -7
  104. maxframe/tensor/random/core.py +1 -1
  105. maxframe/tensor/reduction/count_nonzero.py +1 -0
  106. maxframe/tensor/reduction/mean.py +1 -0
  107. maxframe/tensor/reduction/nanmean.py +1 -0
  108. maxframe/tensor/reduction/nanvar.py +2 -0
  109. maxframe/tensor/reduction/tests/test_reduction.py +12 -1
  110. maxframe/tensor/reduction/var.py +2 -0
  111. maxframe/tensor/statistics/quantile.py +2 -2
  112. maxframe/tensor/utils.py +2 -22
  113. maxframe/tests/utils.py +11 -2
  114. maxframe/typing_.py +4 -1
  115. maxframe/udf.py +8 -9
  116. maxframe/utils.py +32 -70
  117. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/METADATA +25 -25
  118. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/RECORD +133 -123
  119. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/WHEEL +1 -1
  120. maxframe_client/fetcher.py +60 -68
  121. maxframe_client/session/graph.py +8 -2
  122. maxframe_client/session/odps.py +58 -22
  123. maxframe_client/tests/test_fetcher.py +21 -3
  124. maxframe_client/tests/test_session.py +27 -4
  125. maxframe/core/entity/chunks.py +0 -68
  126. maxframe/core/entity/fuse.py +0 -73
  127. maxframe/core/graph/builder/chunk.py +0 -430
  128. maxframe/odpsio/tableio.py +0 -322
  129. maxframe/odpsio/volumeio.py +0 -95
  130. /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
  131. /maxframe/{tensor/base/tests → io}/__init__.py +0 -0
  132. /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
  133. /maxframe/tensor/{base → misc}/astype.py +0 -0
  134. /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
  135. /maxframe/tensor/{base → misc}/ravel.py +0 -0
  136. /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
  137. /maxframe/tensor/{base → misc}/where.py +0 -0
  138. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/top_level.txt +0 -0
maxframe/learn/contrib/xgboost/train.py CHANGED
@@ -15,7 +15,7 @@
15
15
  import logging
16
16
  from collections import OrderedDict
17
17
 
18
- from .... import opcodes as OperandDef
18
+ from .... import opcodes
19
19
  from ....core import OutputType
20
20
  from ....core.operator.base import Operator
21
21
  from ....core.operator.core import TileableOperatorMixin
@@ -29,6 +29,7 @@ from ....serialization.serializables import (
29
29
  KeyField,
30
30
  ListField,
31
31
  )
32
+ from .core import Booster
32
33
  from .dmatrix import ToDMatrix, to_dmatrix
33
34
 
34
35
  logger = logging.getLogger(__name__)
@@ -41,7 +42,7 @@ def _on_serialize_evals(evals_val):
41
42
 
42
43
 
43
44
  class XGBTrain(Operator, TileableOperatorMixin):
44
- _op_type_ = OperandDef.XGBOOST_TRAIN
45
+ _op_type_ = opcodes.XGBOOST_TRAIN
45
46
 
46
47
  params = DictField("params", key_type=FieldTypes.string, default=None)
47
48
  dtrain = KeyField("dtrain", default=None)
@@ -59,49 +60,59 @@ class XGBTrain(Operator, TileableOperatorMixin):
59
60
  num_boost_round = Int64Field("num_boost_round", default=10)
60
61
  num_class = Int64Field("num_class", default=None)
61
62
 
62
- # Store evals_result in local to store the remote evals_result
63
- evals_result: dict = None
64
-
65
63
  def __init__(self, gpu=None, **kw):
66
64
  super().__init__(gpu=gpu, **kw)
67
65
  if self.output_types is None:
68
66
  self.output_types = [OutputType.object]
67
+ if self.has_evals_result:
68
+ self.output_types.append(OutputType.object)
69
69
 
70
70
  def _set_inputs(self, inputs):
71
71
  super()._set_inputs(inputs)
72
72
  self.dtrain = self._inputs[0]
73
73
  rest = self._inputs[1:]
74
- if self.evals is not None:
74
+ if self.has_evals_result:
75
75
  evals_dict = OrderedDict(self.evals)
76
76
  new_evals_dict = OrderedDict()
77
77
  for new_key, val in zip(rest, evals_dict.values()):
78
78
  new_evals_dict[new_key] = val
79
79
  self.evals = list(new_evals_dict.items())
80
80
 
81
- def __call__(self):
81
+ def __call__(self, evals_result):
82
82
  inputs = [self.dtrain]
83
- if self.evals is not None:
83
+ if self.has_evals_result:
84
84
  inputs.extend(e[0] for e in self.evals)
85
- return self.new_tileable(inputs)
85
+ return self.new_tileables(
86
+ inputs, object_class=Booster, evals_result=evals_result
87
+ )[0]
88
+
89
+ @property
90
+ def output_limit(self):
91
+ return 2 if self.has_evals_result else 1
92
+
93
+ @property
94
+ def has_evals_result(self) -> bool:
95
+ return self.evals
86
96
 
87
97
 
88
98
  def train(params, dtrain, evals=None, evals_result=None, num_class=None, **kwargs):
89
99
  """
90
- Train XGBoost model in Mars manner.
100
+ Train XGBoost model in MaxFrame manner.
91
101
 
92
102
  Parameters
93
103
  ----------
94
- Parameters are the same as `xgboost.train`.
104
+ Parameters are the same as `xgboost.train`. Note that train is an eager-execution
105
+ API. The call will be blocked until training finished.
95
106
 
96
107
  Returns
97
108
  -------
98
109
  results: Booster
99
110
  """
100
111
 
101
- evals_result = evals_result or dict()
102
- evals = None or ()
103
-
112
+ evals_result = evals_result if evals_result is not None else dict()
104
113
  processed_evals = []
114
+ session = kwargs.pop("session", None)
115
+ run_kwargs = kwargs.pop("run_kwargs", dict())
105
116
  if evals:
106
117
  for eval_dmatrix, name in evals:
107
118
  if not isinstance(name, str):
@@ -110,12 +121,11 @@ def train(params, dtrain, evals=None, evals_result=None, num_class=None, **kwarg
110
121
  processed_evals.append((eval_dmatrix, name))
111
122
  else:
112
123
  processed_evals.append((to_dmatrix(eval_dmatrix), name))
113
-
114
124
  return XGBTrain(
115
125
  params=params,
116
126
  dtrain=dtrain,
117
127
  evals=processed_evals,
118
128
  evals_result=evals_result,
119
129
  num_class=num_class,
120
- **kwargs
121
- )()
130
+ **kwargs,
131
+ )(evals_result).execute(session=session, **run_kwargs)
maxframe/{core/operator/fuse.py → learn/core.py} CHANGED
@@ -12,18 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from ... import opcodes
16
- from ...serialization.serializables import ReferenceField
17
- from ..graph import ChunkGraph
18
- from .base import Operator
15
+ from ..core.entity.objects import Object, ObjectData
19
16
 
20
17
 
21
- class Fuse(Operator):
22
- __slots__ = ("_fuse_graph",)
23
- _op_type_ = opcodes.FUSE
18
+ class ModelData(ObjectData):
19
+ pass
24
20
 
25
- fuse_graph = ReferenceField("fuse_graph", ChunkGraph)
26
21
 
22
+ class Model(Object):
23
+ pass
27
24
 
28
- class FuseChunkMixin:
29
- __slots__ = ()
25
+
26
+ MODEL_TYPE = (Model, ModelData)
maxframe/lib/mmh3.cpython-39-darwin.so CHANGED
Binary file
maxframe/protocol.py CHANGED
@@ -15,7 +15,7 @@
15
15
  import base64
16
16
  import enum
17
17
  import uuid
18
- from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
18
+ from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
19
19
 
20
20
  import pandas as pd
21
21
 
@@ -38,7 +38,6 @@ from .serialization.serializables import (
38
38
  Serializable,
39
39
  SeriesField,
40
40
  StringField,
41
- TupleField,
42
41
  )
43
42
 
44
43
  pickling_support.install()
@@ -92,19 +91,6 @@ class DataSerializeType(enum.Enum):
92
91
  PICKLE = 0
93
92
 
94
93
 
95
- class VolumeDataMeta(Serializable):
96
- output_type: OutputType = EnumField(
97
- "output_type", OutputType, FieldTypes.int8, default=None
98
- )
99
- serial_type: DataSerializeType = EnumField(
100
- "serial_type", DataSerializeType, FieldTypes.int8, default=None
101
- )
102
- shape: Tuple[int, ...] = TupleField("shape", FieldTypes.int64, default=None)
103
- nsplits: Tuple[Tuple[int, ...], ...] = TupleField(
104
- "nsplits", FieldTypes.tuple(FieldTypes.tuple(FieldTypes.int64)), default=None
105
- )
106
-
107
-
108
94
  _result_type_to_info_cls: Dict[ResultType, Type["ResultInfo"]] = dict()
109
95
 
110
96
 
@@ -154,6 +140,9 @@ class ODPSTableResultInfo(ResultInfo):
154
140
  partition_specs: Optional[List[str]] = ListField(
155
141
  "partition_specs", FieldTypes.string, default=None
156
142
  )
143
+ table_meta: Optional["DataFrameTableMeta"] = ReferenceField(
144
+ "table_meta", default=None
145
+ )
157
146
 
158
147
  def __init__(self, result_type: ResultType = None, **kw):
159
148
  result_type = result_type or ResultType.ODPS_TABLE
@@ -164,8 +153,17 @@ class ODPSTableResultInfo(ResultInfo):
164
153
  ret["full_table_name"] = self.full_table_name
165
154
  if self.partition_specs:
166
155
  ret["partition_specs"] = self.partition_specs
156
+ if self.table_meta:
157
+ ret["table_meta"] = self.table_meta.to_json()
167
158
  return ret
168
159
 
160
+ @classmethod
161
+ def _json_to_kwargs(cls, serialized: dict) -> dict:
162
+ kw = super()._json_to_kwargs(serialized)
163
+ if "table_meta" in kw:
164
+ kw["table_meta"] = DataFrameTableMeta.from_json(kw["table_meta"])
165
+ return kw
166
+
169
167
 
170
168
  class ODPSVolumeResultInfo(ResultInfo):
171
169
  _result_type = ResultType.ODPS_VOLUME
@@ -469,7 +467,7 @@ class DecrefRequest(Serializable):
469
467
  keys: List[str] = ListField("keys", FieldTypes.string, default=None)
470
468
 
471
469
 
472
- class DataFrameTableMeta(Serializable):
470
+ class DataFrameTableMeta(JsonSerializable):
473
471
  __slots__ = "_pd_column_names", "_pd_index_level_names"
474
472
 
475
473
  table_name: Optional[str] = StringField("table_name", default=None)
@@ -500,7 +498,7 @@ class DataFrameTableMeta(Serializable):
500
498
  self._pd_index_level_names = self.pd_index_dtypes.index.tolist()
501
499
  return self._pd_index_level_names
502
500
 
503
- def __eq__(self, other: "Serializable") -> bool:
501
+ def __eq__(self, other: "DataFrameTableMeta") -> bool:
504
502
  if not isinstance(other, type(self)):
505
503
  return False
506
504
  for k in self._FIELDS:
@@ -511,3 +509,29 @@ class DataFrameTableMeta(Serializable):
511
509
  if not is_same:
512
510
  return False
513
511
  return True
512
+
513
+ def to_json(self) -> dict:
514
+ b64_pk = lambda x: base64.b64encode(pickle.dumps(x))
515
+ ret = {
516
+ "table_name": self.table_name,
517
+ "type": self.type.value,
518
+ "table_column_names": self.table_column_names,
519
+ "table_index_column_names": self.table_index_column_names,
520
+ "pd_column_dtypes": b64_pk(self.pd_column_dtypes),
521
+ "pd_column_level_names": b64_pk(self.pd_column_level_names),
522
+ "pd_index_dtypes": b64_pk(self.pd_index_dtypes),
523
+ }
524
+ return ret
525
+
526
+ @classmethod
527
+ def from_json(cls, serialized: dict) -> "DataFrameTableMeta":
528
+ b64_upk = lambda x: pickle.loads(base64.b64decode(x))
529
+ serialized.update(
530
+ {
531
+ "type": OutputType(serialized["type"]),
532
+ "pd_column_dtypes": b64_upk(serialized["pd_column_dtypes"]),
533
+ "pd_column_level_names": b64_upk(serialized["pd_column_level_names"]),
534
+ "pd_index_dtypes": b64_upk(serialized["pd_index_dtypes"]),
535
+ }
536
+ )
537
+ return DataFrameTableMeta(**serialized)
maxframe/remote/core.py CHANGED
@@ -15,7 +15,7 @@
15
15
  from functools import partial
16
16
 
17
17
  from .. import opcodes
18
- from ..core import ENTITY_TYPE, ChunkData
18
+ from ..core import ENTITY_TYPE
19
19
  from ..core.operator import ObjectOperator, ObjectOperatorMixin
20
20
  from ..dataframe.core import DATAFRAME_TYPE, INDEX_TYPE, SERIES_TYPE
21
21
  from ..serialization.serializables import (
@@ -26,7 +26,7 @@ from ..serialization.serializables import (
26
26
  ListField,
27
27
  )
28
28
  from ..tensor.core import TENSOR_TYPE
29
- from ..utils import build_fetch_tileable, find_objects, replace_objects
29
+ from ..utils import find_objects, replace_objects
30
30
 
31
31
 
32
32
  class RemoteFunction(ObjectOperatorMixin, ObjectOperator):
@@ -63,12 +63,8 @@ class RemoteFunction(ObjectOperatorMixin, ObjectOperator):
63
63
  if raw_inputs is not None:
64
64
  for raw_inp in raw_inputs:
65
65
  if self._no_prepare(raw_inp):
66
- if not isinstance(self._inputs[0], ChunkData):
67
- # not in tile, set_inputs from tileable
68
- mapping[raw_inp] = next(function_inputs)
69
- else:
70
- # in tile, set_inputs from chunk
71
- mapping[raw_inp] = build_fetch_tileable(raw_inp)
66
+ # not in tile, set_inputs from tileable
67
+ mapping[raw_inp] = next(function_inputs)
72
68
  else:
73
69
  mapping[raw_inp] = next(function_inputs)
74
70
  self.function_args = replace_objects(self.function_args, mapping)
maxframe/serialization/__init__.py CHANGED
@@ -17,6 +17,7 @@ from .core import (
17
17
  PickleContainer,
18
18
  Serializer,
19
19
  deserialize,
20
+ load_type,
20
21
  pickle_buffers,
21
22
  serialize,
22
23
  serialize_with_spawn,
maxframe/serialization/serializables/core.py CHANGED
@@ -51,7 +51,10 @@ def _is_field_primitive_compound(field: Field):
51
51
  class SerializableMeta(type):
52
52
  def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
53
53
  # All the fields including misc fields.
54
- name_hash = hash(f"{properties.get('__module__')}.{name}")
54
+ legacy_name_hash = hash(f"{properties.get('__module__')}.{name}")
55
+ name_hash = hash(
56
+ f"{properties.get('__module__')}.{properties.get('__qualname__')}"
57
+ )
55
58
  all_fields = dict()
56
59
  # mapping field names to base classes
57
60
  field_to_cls_hash = dict()
@@ -107,6 +110,10 @@ class SerializableMeta(type):
107
110
  slots.update(properties_field_slot_names)
108
111
 
109
112
  properties = properties_without_fields
113
+
114
+ # todo remove this prop when all versions below v1.0.0rc1 is eliminated
115
+ properties["_LEGACY_NAME_HASH"] = legacy_name_hash
116
+
110
117
  properties["_NAME_HASH"] = name_hash
111
118
  properties["_FIELDS"] = all_fields
112
119
  properties["_FIELD_ORDER"] = field_order
@@ -210,8 +217,8 @@ class SerializableSerializer(Serializer):
210
217
  """
211
218
 
212
219
  @classmethod
213
- def _get_obj_field_count_key(cls, obj: Serializable):
214
- return f"FC_{obj._NAME_HASH}"
220
+ def _get_obj_field_count_key(cls, obj: Serializable, legacy: bool = False):
221
+ return f"FC_{obj._NAME_HASH if not legacy else obj._LEGACY_NAME_HASH}"
215
222
 
216
223
  @classmethod
217
224
  def _get_field_values(cls, obj: Serializable, fields):
@@ -290,6 +297,12 @@ class SerializableSerializer(Serializer):
290
297
  server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
291
298
  server_fields = obj_class._NON_PRIMITIVE_FIELDS
292
299
 
300
+ legacy_to_new_hash = {
301
+ c._LEGACY_NAME_HASH: c._NAME_HASH
302
+ for c in obj_class.__mro__
303
+ if hasattr(c, "_NAME_HASH") and c._LEGACY_NAME_HASH != c._NAME_HASH
304
+ }
305
+
293
306
  if client_cls_to_field_count:
294
307
  field_num, server_field_num = 0, 0
295
308
  for cls_hash, count in client_cls_to_field_count.items():
@@ -301,20 +314,40 @@ class SerializableSerializer(Serializer):
301
314
  if not is_primitive or value != {}:
302
315
  cls._set_field_value(obj, field, value)
303
316
  field_num += count
304
- server_field_num += server_cls_to_field_count[cls_hash]
317
+ try:
318
+ server_field_num += server_cls_to_field_count[cls_hash]
319
+ except KeyError:
320
+ try:
321
+ # todo remove this fallback when all
322
+ # versions below v1.0.0rc1 is eliminated
323
+ server_field_num += server_cls_to_field_count[
324
+ legacy_to_new_hash[cls_hash]
325
+ ]
326
+ except KeyError:
327
+ # it is possible that certain type of field does not exist
328
+ # at server side
329
+ pass
305
330
  else:
331
+ # handle legacy serialization style, with all fields sorted by name
306
332
  # todo remove this branch when all versions below v0.1.0b5 is eliminated
307
333
  from .field import AnyField
308
334
 
309
- # legacy serialization style, with all fields sorted by name
310
335
  if is_primitive:
311
- field_attr = "_legacy_deprecated_primitives"
336
+ new_field_attr = "_legacy_new_primitives"
337
+ deprecated_field_attr = "_legacy_deprecated_primitives"
312
338
  else:
313
- field_attr = "_legacy_deprecated_non_primitives"
339
+ new_field_attr = "_legacy_new_non_primitives"
340
+ deprecated_field_attr = "_legacy_deprecated_non_primitives"
341
+
342
+ # remove fields added on later releases
343
+ new_names = set(getattr(obj_class, new_field_attr, None) or [])
344
+ server_fields = [f for f in server_fields if f.name not in new_names]
345
+
346
+ # fill fields deprecated on later releases
314
347
  deprecated_fields = []
315
348
  deprecated_names = set()
316
- if hasattr(obj_class, field_attr):
317
- deprecated_names = set(getattr(obj_class, field_attr))
349
+ if hasattr(obj_class, deprecated_field_attr):
350
+ deprecated_names = set(getattr(obj_class, deprecated_field_attr))
318
351
  for field_name in deprecated_names:
319
352
  field = AnyField(tag=field_name)
320
353
  field.name = field_name
@@ -342,6 +375,12 @@ class SerializableSerializer(Serializer):
342
375
  field_count_data = self.get_public_data(
343
376
  context, self._get_obj_field_count_key(obj)
344
377
  )
378
+ if field_count_data is None:
379
+ # todo remove this fallback when all
380
+ # versions below v1.0.0rc1 is eliminated
381
+ field_count_data = self.get_public_data(
382
+ context, self._get_obj_field_count_key(obj, legacy=True)
383
+ )
345
384
  if field_count_data is not None:
346
385
  cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
347
386
  cls_to_prim_key = dict(cls_to_prim_key)
maxframe/tensor/__init__.py CHANGED
@@ -114,7 +114,6 @@ from .arithmetic import (
114
114
  )
115
115
  from .arithmetic import truediv as true_divide
116
116
  from .arithmetic import trunc
117
- from .base import broadcast_to, transpose, unique, where
118
117
  from .core import Tensor
119
118
  from .datasource import (
120
119
  arange,
@@ -143,7 +142,16 @@ from .indexing import (
143
142
  take,
144
143
  unravel_index,
145
144
  )
146
- from .merge import stack
145
+ from .merge import concatenate, stack, vstack
146
+ from .misc import (
147
+ atleast_1d,
148
+ atleast_2d,
149
+ atleast_3d,
150
+ broadcast_to,
151
+ transpose,
152
+ unique,
153
+ where,
154
+ )
147
155
  from .rechunk import rechunk
148
156
  from .reduction import (
149
157
  all,
@@ -180,4 +188,63 @@ from .reduction import std, sum, var
180
188
  from .reshape import reshape
181
189
  from .ufunc import ufunc
182
190
 
191
+ # isort: off
192
+ # noinspection PyUnresolvedReferences
193
+ from numpy import (
194
+ NAN,
195
+ NINF,
196
+ AxisError,
197
+ Inf,
198
+ NaN,
199
+ e,
200
+ errstate,
201
+ geterr,
202
+ inf,
203
+ nan,
204
+ newaxis,
205
+ pi,
206
+ seterr,
207
+ )
208
+
209
+ # import numpy types
210
+ # noinspection PyUnresolvedReferences
211
+ from numpy import (
212
+ bool_ as bool,
213
+ bytes_,
214
+ cfloat,
215
+ character,
216
+ complex64,
217
+ complex128,
218
+ complexfloating,
219
+ datetime64,
220
+ double,
221
+ dtype,
222
+ flexible,
223
+ float16,
224
+ float32,
225
+ float64,
226
+ floating,
227
+ generic,
228
+ inexact,
229
+ int8,
230
+ int16,
231
+ int32,
232
+ int64,
233
+ intc,
234
+ intp,
235
+ number,
236
+ integer,
237
+ object_ as object,
238
+ signedinteger,
239
+ timedelta64,
240
+ uint,
241
+ uint8,
242
+ uint16,
243
+ uint32,
244
+ uint64,
245
+ unicode_,
246
+ unsignedinteger,
247
+ void,
248
+ )
249
+
183
250
  del fetch, ufunc
maxframe/tensor/arithmetic/isclose.py CHANGED
@@ -23,6 +23,7 @@ from .core import TensorBinOp
23
23
 
24
24
  class TensorIsclose(TensorBinOp):
25
25
  _op_type_ = opcodes.ISCLOSE
26
+ _func_name = "isclose"
26
27
 
27
28
  rtol = Float64Field("rtol", default=None)
28
29
  atol = Float64Field("atol", default=None)
maxframe/tensor/arithmetic/tests/test_arithmetic.py CHANGED
@@ -17,26 +17,13 @@
17
17
  import numpy as np
18
18
  import pytest
19
19
 
20
+ from maxframe.tensor.arithmetic.core import TensorBinOp, TensorUnaryOp
21
+ from maxframe.utils import collect_leaf_operators
22
+
20
23
  from ....core import enter_mode
21
24
  from ...core import SparseTensor, Tensor
22
25
  from ...datasource import array, empty, ones, tensor
23
- from .. import (
24
- TensorAdd,
25
- TensorGreaterThan,
26
- TensorIsclose,
27
- TensorLog,
28
- TensorSubtract,
29
- add,
30
- around,
31
- cos,
32
- frexp,
33
- isclose,
34
- isfinite,
35
- log,
36
- negative,
37
- subtract,
38
- truediv,
39
- )
26
+ from .. import * # noqa: F401
40
27
 
41
28
 
42
29
  def test_add():
@@ -412,3 +399,20 @@ def test_build_mode():
412
399
 
413
400
  with enter_mode(build=True):
414
401
  assert t1 != 2
402
+
403
+
404
+ def test_unary_op_func_name():
405
+ # make sure all the unary op has defined the func name.
406
+
407
+ results = collect_leaf_operators(TensorUnaryOp)
408
+ for op_type in results:
409
+ assert hasattr(op_type, "_func_name")
410
+
411
+
412
+ def test_binary_op_func_name():
413
+ # make sure all the binary op has defined the func name.
414
+
415
+ results = collect_leaf_operators(TensorBinOp)
416
+ for op_type in results:
417
+ if op_type not in (TensorSetImag, TensorSetReal):
418
+ assert hasattr(op_type, "_func_name")