maxframe 1.0.0rc1__cp311-cp311-win_amd64.whl → 1.0.0rc3__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/_utils.cp311-win_amd64.pyd +0 -0
- maxframe/codegen.py +3 -6
- maxframe/config/config.py +49 -10
- maxframe/config/validators.py +42 -11
- maxframe/conftest.py +15 -2
- maxframe/core/__init__.py +2 -13
- maxframe/core/entity/__init__.py +0 -4
- maxframe/core/entity/objects.py +46 -3
- maxframe/core/entity/output_types.py +0 -3
- maxframe/core/entity/tests/test_objects.py +43 -0
- maxframe/core/entity/tileables.py +5 -78
- maxframe/core/graph/__init__.py +2 -2
- maxframe/core/graph/builder/__init__.py +0 -1
- maxframe/core/graph/builder/base.py +5 -4
- maxframe/core/graph/builder/tileable.py +4 -4
- maxframe/core/graph/builder/utils.py +4 -8
- maxframe/core/graph/core.cp311-win_amd64.pyd +0 -0
- maxframe/core/graph/entity.py +9 -33
- maxframe/core/operator/__init__.py +2 -9
- maxframe/core/operator/base.py +3 -5
- maxframe/core/operator/objects.py +0 -9
- maxframe/core/operator/utils.py +55 -0
- maxframe/dataframe/__init__.py +1 -1
- maxframe/dataframe/arithmetic/around.py +5 -17
- maxframe/dataframe/arithmetic/core.py +15 -7
- maxframe/dataframe/arithmetic/docstring.py +5 -55
- maxframe/dataframe/arithmetic/tests/test_arithmetic.py +22 -0
- maxframe/dataframe/core.py +5 -5
- maxframe/dataframe/datasource/date_range.py +2 -2
- maxframe/dataframe/datasource/read_odps_query.py +7 -1
- maxframe/dataframe/datasource/read_odps_table.py +3 -2
- maxframe/dataframe/datasource/tests/test_datasource.py +14 -0
- maxframe/dataframe/datastore/to_odps.py +1 -1
- maxframe/dataframe/groupby/cum.py +0 -1
- maxframe/dataframe/groupby/tests/test_groupby.py +4 -0
- maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
- maxframe/dataframe/indexing/rename.py +3 -37
- maxframe/dataframe/indexing/sample.py +0 -1
- maxframe/dataframe/indexing/set_index.py +68 -1
- maxframe/dataframe/merge/merge.py +236 -2
- maxframe/dataframe/merge/tests/test_merge.py +123 -0
- maxframe/dataframe/misc/apply.py +3 -10
- maxframe/dataframe/misc/case_when.py +1 -1
- maxframe/dataframe/misc/describe.py +2 -2
- maxframe/dataframe/misc/drop_duplicates.py +4 -25
- maxframe/dataframe/misc/eval.py +4 -0
- maxframe/dataframe/misc/pct_change.py +1 -83
- maxframe/dataframe/misc/transform.py +1 -30
- maxframe/dataframe/misc/value_counts.py +4 -17
- maxframe/dataframe/missing/dropna.py +1 -1
- maxframe/dataframe/missing/fillna.py +5 -5
- maxframe/dataframe/operators.py +1 -17
- maxframe/dataframe/reduction/core.py +2 -2
- maxframe/dataframe/sort/sort_values.py +1 -11
- maxframe/dataframe/statistics/quantile.py +5 -17
- maxframe/dataframe/utils.py +4 -7
- maxframe/io/objects/__init__.py +24 -0
- maxframe/io/objects/core.py +140 -0
- maxframe/io/objects/tensor.py +76 -0
- maxframe/io/objects/tests/__init__.py +13 -0
- maxframe/io/objects/tests/test_object_io.py +97 -0
- maxframe/{odpsio → io/odpsio}/__init__.py +3 -1
- maxframe/{odpsio → io/odpsio}/arrow.py +12 -8
- maxframe/{odpsio → io/odpsio}/schema.py +15 -12
- maxframe/io/odpsio/tableio.py +702 -0
- maxframe/io/odpsio/tests/__init__.py +13 -0
- maxframe/{odpsio → io/odpsio}/tests/test_schema.py +19 -18
- maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +50 -23
- maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
- maxframe/io/odpsio/volumeio.py +57 -0
- maxframe/learn/contrib/xgboost/classifier.py +26 -2
- maxframe/learn/contrib/xgboost/core.py +87 -2
- maxframe/learn/contrib/xgboost/dmatrix.py +3 -6
- maxframe/learn/contrib/xgboost/predict.py +21 -7
- maxframe/learn/contrib/xgboost/regressor.py +3 -10
- maxframe/learn/contrib/xgboost/train.py +27 -17
- maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
- maxframe/lib/mmh3.cp311-win_amd64.pyd +0 -0
- maxframe/protocol.py +41 -17
- maxframe/remote/core.py +4 -8
- maxframe/serialization/__init__.py +1 -0
- maxframe/serialization/core.cp311-win_amd64.pyd +0 -0
- maxframe/serialization/serializables/core.py +48 -9
- maxframe/tensor/__init__.py +69 -2
- maxframe/tensor/arithmetic/isclose.py +1 -0
- maxframe/tensor/arithmetic/tests/test_arithmetic.py +21 -17
- maxframe/tensor/core.py +5 -136
- maxframe/tensor/datasource/array.py +3 -0
- maxframe/tensor/datasource/full.py +1 -1
- maxframe/tensor/datasource/tests/test_datasource.py +1 -1
- maxframe/tensor/indexing/flatnonzero.py +1 -1
- maxframe/tensor/merge/__init__.py +2 -0
- maxframe/tensor/merge/concatenate.py +98 -0
- maxframe/tensor/merge/tests/test_merge.py +30 -1
- maxframe/tensor/merge/vstack.py +70 -0
- maxframe/tensor/{base → misc}/__init__.py +2 -0
- maxframe/tensor/{base → misc}/atleast_1d.py +0 -2
- maxframe/tensor/misc/atleast_2d.py +70 -0
- maxframe/tensor/misc/atleast_3d.py +85 -0
- maxframe/tensor/misc/tests/__init__.py +13 -0
- maxframe/tensor/{base → misc}/transpose.py +22 -18
- maxframe/tensor/{base → misc}/unique.py +2 -2
- maxframe/tensor/operators.py +1 -7
- maxframe/tensor/random/core.py +1 -1
- maxframe/tensor/reduction/count_nonzero.py +1 -0
- maxframe/tensor/reduction/mean.py +1 -0
- maxframe/tensor/reduction/nanmean.py +1 -0
- maxframe/tensor/reduction/nanvar.py +2 -0
- maxframe/tensor/reduction/tests/test_reduction.py +12 -1
- maxframe/tensor/reduction/var.py +2 -0
- maxframe/tensor/statistics/quantile.py +2 -2
- maxframe/tensor/utils.py +2 -22
- maxframe/tests/utils.py +11 -2
- maxframe/typing_.py +4 -1
- maxframe/udf.py +8 -9
- maxframe/utils.py +32 -70
- {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/METADATA +25 -25
- {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/RECORD +133 -123
- {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/WHEEL +1 -1
- maxframe_client/fetcher.py +60 -68
- maxframe_client/session/graph.py +8 -2
- maxframe_client/session/odps.py +58 -22
- maxframe_client/tests/test_fetcher.py +21 -3
- maxframe_client/tests/test_session.py +27 -4
- maxframe/core/entity/chunks.py +0 -68
- maxframe/core/entity/fuse.py +0 -73
- maxframe/core/graph/builder/chunk.py +0 -430
- maxframe/odpsio/tableio.py +0 -322
- maxframe/odpsio/volumeio.py +0 -95
- /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
- /maxframe/{tensor/base/tests → io}/__init__.py +0 -0
- /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
- /maxframe/tensor/{base → misc}/astype.py +0 -0
- /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
- /maxframe/tensor/{base → misc}/ravel.py +0 -0
- /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
- /maxframe/tensor/{base → misc}/where.py +0 -0
- {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/top_level.txt +0 -0
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
import logging
|
|
16
16
|
from collections import OrderedDict
|
|
17
17
|
|
|
18
|
-
from .... import opcodes
|
|
18
|
+
from .... import opcodes
|
|
19
19
|
from ....core import OutputType
|
|
20
20
|
from ....core.operator.base import Operator
|
|
21
21
|
from ....core.operator.core import TileableOperatorMixin
|
|
@@ -29,6 +29,7 @@ from ....serialization.serializables import (
|
|
|
29
29
|
KeyField,
|
|
30
30
|
ListField,
|
|
31
31
|
)
|
|
32
|
+
from .core import Booster
|
|
32
33
|
from .dmatrix import ToDMatrix, to_dmatrix
|
|
33
34
|
|
|
34
35
|
logger = logging.getLogger(__name__)
|
|
@@ -41,7 +42,7 @@ def _on_serialize_evals(evals_val):
|
|
|
41
42
|
|
|
42
43
|
|
|
43
44
|
class XGBTrain(Operator, TileableOperatorMixin):
|
|
44
|
-
_op_type_ =
|
|
45
|
+
_op_type_ = opcodes.XGBOOST_TRAIN
|
|
45
46
|
|
|
46
47
|
params = DictField("params", key_type=FieldTypes.string, default=None)
|
|
47
48
|
dtrain = KeyField("dtrain", default=None)
|
|
@@ -59,49 +60,59 @@ class XGBTrain(Operator, TileableOperatorMixin):
|
|
|
59
60
|
num_boost_round = Int64Field("num_boost_round", default=10)
|
|
60
61
|
num_class = Int64Field("num_class", default=None)
|
|
61
62
|
|
|
62
|
-
# Store evals_result in local to store the remote evals_result
|
|
63
|
-
evals_result: dict = None
|
|
64
|
-
|
|
65
63
|
def __init__(self, gpu=None, **kw):
|
|
66
64
|
super().__init__(gpu=gpu, **kw)
|
|
67
65
|
if self.output_types is None:
|
|
68
66
|
self.output_types = [OutputType.object]
|
|
67
|
+
if self.has_evals_result:
|
|
68
|
+
self.output_types.append(OutputType.object)
|
|
69
69
|
|
|
70
70
|
def _set_inputs(self, inputs):
|
|
71
71
|
super()._set_inputs(inputs)
|
|
72
72
|
self.dtrain = self._inputs[0]
|
|
73
73
|
rest = self._inputs[1:]
|
|
74
|
-
if self.
|
|
74
|
+
if self.has_evals_result:
|
|
75
75
|
evals_dict = OrderedDict(self.evals)
|
|
76
76
|
new_evals_dict = OrderedDict()
|
|
77
77
|
for new_key, val in zip(rest, evals_dict.values()):
|
|
78
78
|
new_evals_dict[new_key] = val
|
|
79
79
|
self.evals = list(new_evals_dict.items())
|
|
80
80
|
|
|
81
|
-
def __call__(self):
|
|
81
|
+
def __call__(self, evals_result):
|
|
82
82
|
inputs = [self.dtrain]
|
|
83
|
-
if self.
|
|
83
|
+
if self.has_evals_result:
|
|
84
84
|
inputs.extend(e[0] for e in self.evals)
|
|
85
|
-
return self.
|
|
85
|
+
return self.new_tileables(
|
|
86
|
+
inputs, object_class=Booster, evals_result=evals_result
|
|
87
|
+
)[0]
|
|
88
|
+
|
|
89
|
+
@property
|
|
90
|
+
def output_limit(self):
|
|
91
|
+
return 2 if self.has_evals_result else 1
|
|
92
|
+
|
|
93
|
+
@property
|
|
94
|
+
def has_evals_result(self) -> bool:
|
|
95
|
+
return self.evals
|
|
86
96
|
|
|
87
97
|
|
|
88
98
|
def train(params, dtrain, evals=None, evals_result=None, num_class=None, **kwargs):
|
|
89
99
|
"""
|
|
90
|
-
Train XGBoost model in
|
|
100
|
+
Train XGBoost model in MaxFrame manner.
|
|
91
101
|
|
|
92
102
|
Parameters
|
|
93
103
|
----------
|
|
94
|
-
Parameters are the same as `xgboost.train`.
|
|
104
|
+
Parameters are the same as `xgboost.train`. Note that train is an eager-execution
|
|
105
|
+
API. The call will be blocked until training finished.
|
|
95
106
|
|
|
96
107
|
Returns
|
|
97
108
|
-------
|
|
98
109
|
results: Booster
|
|
99
110
|
"""
|
|
100
111
|
|
|
101
|
-
evals_result = evals_result
|
|
102
|
-
evals = None or ()
|
|
103
|
-
|
|
112
|
+
evals_result = evals_result if evals_result is not None else dict()
|
|
104
113
|
processed_evals = []
|
|
114
|
+
session = kwargs.pop("session", None)
|
|
115
|
+
run_kwargs = kwargs.pop("run_kwargs", dict())
|
|
105
116
|
if evals:
|
|
106
117
|
for eval_dmatrix, name in evals:
|
|
107
118
|
if not isinstance(name, str):
|
|
@@ -110,12 +121,11 @@ def train(params, dtrain, evals=None, evals_result=None, num_class=None, **kwarg
|
|
|
110
121
|
processed_evals.append((eval_dmatrix, name))
|
|
111
122
|
else:
|
|
112
123
|
processed_evals.append((to_dmatrix(eval_dmatrix), name))
|
|
113
|
-
|
|
114
124
|
return XGBTrain(
|
|
115
125
|
params=params,
|
|
116
126
|
dtrain=dtrain,
|
|
117
127
|
evals=processed_evals,
|
|
118
128
|
evals_result=evals_result,
|
|
119
129
|
num_class=num_class,
|
|
120
|
-
**kwargs
|
|
121
|
-
)()
|
|
130
|
+
**kwargs,
|
|
131
|
+
)(evals_result).execute(session=session, **run_kwargs)
|
|
@@ -12,18 +12,15 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from
|
|
16
|
-
from ...serialization.serializables import ReferenceField
|
|
17
|
-
from ..graph import ChunkGraph
|
|
18
|
-
from .base import Operator
|
|
15
|
+
from ..core.entity.objects import Object, ObjectData
|
|
19
16
|
|
|
20
17
|
|
|
21
|
-
class
|
|
22
|
-
|
|
23
|
-
_op_type_ = opcodes.FUSE
|
|
18
|
+
class ModelData(ObjectData):
|
|
19
|
+
pass
|
|
24
20
|
|
|
25
|
-
fuse_graph = ReferenceField("fuse_graph", ChunkGraph)
|
|
26
21
|
|
|
22
|
+
class Model(Object):
|
|
23
|
+
pass
|
|
27
24
|
|
|
28
|
-
|
|
29
|
-
|
|
25
|
+
|
|
26
|
+
MODEL_TYPE = (Model, ModelData)
|
|
Binary file
|
maxframe/protocol.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
import base64
|
|
16
16
|
import enum
|
|
17
17
|
import uuid
|
|
18
|
-
from typing import Any, Dict, Generic, List, Optional,
|
|
18
|
+
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
|
|
19
19
|
|
|
20
20
|
import pandas as pd
|
|
21
21
|
|
|
@@ -38,7 +38,6 @@ from .serialization.serializables import (
|
|
|
38
38
|
Serializable,
|
|
39
39
|
SeriesField,
|
|
40
40
|
StringField,
|
|
41
|
-
TupleField,
|
|
42
41
|
)
|
|
43
42
|
|
|
44
43
|
pickling_support.install()
|
|
@@ -92,19 +91,6 @@ class DataSerializeType(enum.Enum):
|
|
|
92
91
|
PICKLE = 0
|
|
93
92
|
|
|
94
93
|
|
|
95
|
-
class VolumeDataMeta(Serializable):
|
|
96
|
-
output_type: OutputType = EnumField(
|
|
97
|
-
"output_type", OutputType, FieldTypes.int8, default=None
|
|
98
|
-
)
|
|
99
|
-
serial_type: DataSerializeType = EnumField(
|
|
100
|
-
"serial_type", DataSerializeType, FieldTypes.int8, default=None
|
|
101
|
-
)
|
|
102
|
-
shape: Tuple[int, ...] = TupleField("shape", FieldTypes.int64, default=None)
|
|
103
|
-
nsplits: Tuple[Tuple[int, ...], ...] = TupleField(
|
|
104
|
-
"nsplits", FieldTypes.tuple(FieldTypes.tuple(FieldTypes.int64)), default=None
|
|
105
|
-
)
|
|
106
|
-
|
|
107
|
-
|
|
108
94
|
_result_type_to_info_cls: Dict[ResultType, Type["ResultInfo"]] = dict()
|
|
109
95
|
|
|
110
96
|
|
|
@@ -154,6 +140,9 @@ class ODPSTableResultInfo(ResultInfo):
|
|
|
154
140
|
partition_specs: Optional[List[str]] = ListField(
|
|
155
141
|
"partition_specs", FieldTypes.string, default=None
|
|
156
142
|
)
|
|
143
|
+
table_meta: Optional["DataFrameTableMeta"] = ReferenceField(
|
|
144
|
+
"table_meta", default=None
|
|
145
|
+
)
|
|
157
146
|
|
|
158
147
|
def __init__(self, result_type: ResultType = None, **kw):
|
|
159
148
|
result_type = result_type or ResultType.ODPS_TABLE
|
|
@@ -164,8 +153,17 @@ class ODPSTableResultInfo(ResultInfo):
|
|
|
164
153
|
ret["full_table_name"] = self.full_table_name
|
|
165
154
|
if self.partition_specs:
|
|
166
155
|
ret["partition_specs"] = self.partition_specs
|
|
156
|
+
if self.table_meta:
|
|
157
|
+
ret["table_meta"] = self.table_meta.to_json()
|
|
167
158
|
return ret
|
|
168
159
|
|
|
160
|
+
@classmethod
|
|
161
|
+
def _json_to_kwargs(cls, serialized: dict) -> dict:
|
|
162
|
+
kw = super()._json_to_kwargs(serialized)
|
|
163
|
+
if "table_meta" in kw:
|
|
164
|
+
kw["table_meta"] = DataFrameTableMeta.from_json(kw["table_meta"])
|
|
165
|
+
return kw
|
|
166
|
+
|
|
169
167
|
|
|
170
168
|
class ODPSVolumeResultInfo(ResultInfo):
|
|
171
169
|
_result_type = ResultType.ODPS_VOLUME
|
|
@@ -469,7 +467,7 @@ class DecrefRequest(Serializable):
|
|
|
469
467
|
keys: List[str] = ListField("keys", FieldTypes.string, default=None)
|
|
470
468
|
|
|
471
469
|
|
|
472
|
-
class DataFrameTableMeta(
|
|
470
|
+
class DataFrameTableMeta(JsonSerializable):
|
|
473
471
|
__slots__ = "_pd_column_names", "_pd_index_level_names"
|
|
474
472
|
|
|
475
473
|
table_name: Optional[str] = StringField("table_name", default=None)
|
|
@@ -500,7 +498,7 @@ class DataFrameTableMeta(Serializable):
|
|
|
500
498
|
self._pd_index_level_names = self.pd_index_dtypes.index.tolist()
|
|
501
499
|
return self._pd_index_level_names
|
|
502
500
|
|
|
503
|
-
def __eq__(self, other: "
|
|
501
|
+
def __eq__(self, other: "DataFrameTableMeta") -> bool:
|
|
504
502
|
if not isinstance(other, type(self)):
|
|
505
503
|
return False
|
|
506
504
|
for k in self._FIELDS:
|
|
@@ -511,3 +509,29 @@ class DataFrameTableMeta(Serializable):
|
|
|
511
509
|
if not is_same:
|
|
512
510
|
return False
|
|
513
511
|
return True
|
|
512
|
+
|
|
513
|
+
def to_json(self) -> dict:
|
|
514
|
+
b64_pk = lambda x: base64.b64encode(pickle.dumps(x))
|
|
515
|
+
ret = {
|
|
516
|
+
"table_name": self.table_name,
|
|
517
|
+
"type": self.type.value,
|
|
518
|
+
"table_column_names": self.table_column_names,
|
|
519
|
+
"table_index_column_names": self.table_index_column_names,
|
|
520
|
+
"pd_column_dtypes": b64_pk(self.pd_column_dtypes),
|
|
521
|
+
"pd_column_level_names": b64_pk(self.pd_column_level_names),
|
|
522
|
+
"pd_index_dtypes": b64_pk(self.pd_index_dtypes),
|
|
523
|
+
}
|
|
524
|
+
return ret
|
|
525
|
+
|
|
526
|
+
@classmethod
|
|
527
|
+
def from_json(cls, serialized: dict) -> "DataFrameTableMeta":
|
|
528
|
+
b64_upk = lambda x: pickle.loads(base64.b64decode(x))
|
|
529
|
+
serialized.update(
|
|
530
|
+
{
|
|
531
|
+
"type": OutputType(serialized["type"]),
|
|
532
|
+
"pd_column_dtypes": b64_upk(serialized["pd_column_dtypes"]),
|
|
533
|
+
"pd_column_level_names": b64_upk(serialized["pd_column_level_names"]),
|
|
534
|
+
"pd_index_dtypes": b64_upk(serialized["pd_index_dtypes"]),
|
|
535
|
+
}
|
|
536
|
+
)
|
|
537
|
+
return DataFrameTableMeta(**serialized)
|
maxframe/remote/core.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
from functools import partial
|
|
16
16
|
|
|
17
17
|
from .. import opcodes
|
|
18
|
-
from ..core import ENTITY_TYPE
|
|
18
|
+
from ..core import ENTITY_TYPE
|
|
19
19
|
from ..core.operator import ObjectOperator, ObjectOperatorMixin
|
|
20
20
|
from ..dataframe.core import DATAFRAME_TYPE, INDEX_TYPE, SERIES_TYPE
|
|
21
21
|
from ..serialization.serializables import (
|
|
@@ -26,7 +26,7 @@ from ..serialization.serializables import (
|
|
|
26
26
|
ListField,
|
|
27
27
|
)
|
|
28
28
|
from ..tensor.core import TENSOR_TYPE
|
|
29
|
-
from ..utils import
|
|
29
|
+
from ..utils import find_objects, replace_objects
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class RemoteFunction(ObjectOperatorMixin, ObjectOperator):
|
|
@@ -63,12 +63,8 @@ class RemoteFunction(ObjectOperatorMixin, ObjectOperator):
|
|
|
63
63
|
if raw_inputs is not None:
|
|
64
64
|
for raw_inp in raw_inputs:
|
|
65
65
|
if self._no_prepare(raw_inp):
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
mapping[raw_inp] = next(function_inputs)
|
|
69
|
-
else:
|
|
70
|
-
# in tile, set_inputs from chunk
|
|
71
|
-
mapping[raw_inp] = build_fetch_tileable(raw_inp)
|
|
66
|
+
# not in tile, set_inputs from tileable
|
|
67
|
+
mapping[raw_inp] = next(function_inputs)
|
|
72
68
|
else:
|
|
73
69
|
mapping[raw_inp] = next(function_inputs)
|
|
74
70
|
self.function_args = replace_objects(self.function_args, mapping)
|
|
Binary file
|
|
@@ -51,7 +51,10 @@ def _is_field_primitive_compound(field: Field):
|
|
|
51
51
|
class SerializableMeta(type):
|
|
52
52
|
def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
|
|
53
53
|
# All the fields including misc fields.
|
|
54
|
-
|
|
54
|
+
legacy_name_hash = hash(f"{properties.get('__module__')}.{name}")
|
|
55
|
+
name_hash = hash(
|
|
56
|
+
f"{properties.get('__module__')}.{properties.get('__qualname__')}"
|
|
57
|
+
)
|
|
55
58
|
all_fields = dict()
|
|
56
59
|
# mapping field names to base classes
|
|
57
60
|
field_to_cls_hash = dict()
|
|
@@ -107,6 +110,10 @@ class SerializableMeta(type):
|
|
|
107
110
|
slots.update(properties_field_slot_names)
|
|
108
111
|
|
|
109
112
|
properties = properties_without_fields
|
|
113
|
+
|
|
114
|
+
# todo remove this prop when all versions below v1.0.0rc1 is eliminated
|
|
115
|
+
properties["_LEGACY_NAME_HASH"] = legacy_name_hash
|
|
116
|
+
|
|
110
117
|
properties["_NAME_HASH"] = name_hash
|
|
111
118
|
properties["_FIELDS"] = all_fields
|
|
112
119
|
properties["_FIELD_ORDER"] = field_order
|
|
@@ -210,8 +217,8 @@ class SerializableSerializer(Serializer):
|
|
|
210
217
|
"""
|
|
211
218
|
|
|
212
219
|
@classmethod
|
|
213
|
-
def _get_obj_field_count_key(cls, obj: Serializable):
|
|
214
|
-
return f"FC_{obj._NAME_HASH}"
|
|
220
|
+
def _get_obj_field_count_key(cls, obj: Serializable, legacy: bool = False):
|
|
221
|
+
return f"FC_{obj._NAME_HASH if not legacy else obj._LEGACY_NAME_HASH}"
|
|
215
222
|
|
|
216
223
|
@classmethod
|
|
217
224
|
def _get_field_values(cls, obj: Serializable, fields):
|
|
@@ -290,6 +297,12 @@ class SerializableSerializer(Serializer):
|
|
|
290
297
|
server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
|
|
291
298
|
server_fields = obj_class._NON_PRIMITIVE_FIELDS
|
|
292
299
|
|
|
300
|
+
legacy_to_new_hash = {
|
|
301
|
+
c._LEGACY_NAME_HASH: c._NAME_HASH
|
|
302
|
+
for c in obj_class.__mro__
|
|
303
|
+
if hasattr(c, "_NAME_HASH") and c._LEGACY_NAME_HASH != c._NAME_HASH
|
|
304
|
+
}
|
|
305
|
+
|
|
293
306
|
if client_cls_to_field_count:
|
|
294
307
|
field_num, server_field_num = 0, 0
|
|
295
308
|
for cls_hash, count in client_cls_to_field_count.items():
|
|
@@ -301,20 +314,40 @@ class SerializableSerializer(Serializer):
|
|
|
301
314
|
if not is_primitive or value != {}:
|
|
302
315
|
cls._set_field_value(obj, field, value)
|
|
303
316
|
field_num += count
|
|
304
|
-
|
|
317
|
+
try:
|
|
318
|
+
server_field_num += server_cls_to_field_count[cls_hash]
|
|
319
|
+
except KeyError:
|
|
320
|
+
try:
|
|
321
|
+
# todo remove this fallback when all
|
|
322
|
+
# versions below v1.0.0rc1 is eliminated
|
|
323
|
+
server_field_num += server_cls_to_field_count[
|
|
324
|
+
legacy_to_new_hash[cls_hash]
|
|
325
|
+
]
|
|
326
|
+
except KeyError:
|
|
327
|
+
# it is possible that certain type of field does not exist
|
|
328
|
+
# at server side
|
|
329
|
+
pass
|
|
305
330
|
else:
|
|
331
|
+
# handle legacy serialization style, with all fields sorted by name
|
|
306
332
|
# todo remove this branch when all versions below v0.1.0b5 is eliminated
|
|
307
333
|
from .field import AnyField
|
|
308
334
|
|
|
309
|
-
# legacy serialization style, with all fields sorted by name
|
|
310
335
|
if is_primitive:
|
|
311
|
-
|
|
336
|
+
new_field_attr = "_legacy_new_primitives"
|
|
337
|
+
deprecated_field_attr = "_legacy_deprecated_primitives"
|
|
312
338
|
else:
|
|
313
|
-
|
|
339
|
+
new_field_attr = "_legacy_new_non_primitives"
|
|
340
|
+
deprecated_field_attr = "_legacy_deprecated_non_primitives"
|
|
341
|
+
|
|
342
|
+
# remove fields added on later releases
|
|
343
|
+
new_names = set(getattr(obj_class, new_field_attr, None) or [])
|
|
344
|
+
server_fields = [f for f in server_fields if f.name not in new_names]
|
|
345
|
+
|
|
346
|
+
# fill fields deprecated on later releases
|
|
314
347
|
deprecated_fields = []
|
|
315
348
|
deprecated_names = set()
|
|
316
|
-
if hasattr(obj_class,
|
|
317
|
-
deprecated_names = set(getattr(obj_class,
|
|
349
|
+
if hasattr(obj_class, deprecated_field_attr):
|
|
350
|
+
deprecated_names = set(getattr(obj_class, deprecated_field_attr))
|
|
318
351
|
for field_name in deprecated_names:
|
|
319
352
|
field = AnyField(tag=field_name)
|
|
320
353
|
field.name = field_name
|
|
@@ -342,6 +375,12 @@ class SerializableSerializer(Serializer):
|
|
|
342
375
|
field_count_data = self.get_public_data(
|
|
343
376
|
context, self._get_obj_field_count_key(obj)
|
|
344
377
|
)
|
|
378
|
+
if field_count_data is None:
|
|
379
|
+
# todo remove this fallback when all
|
|
380
|
+
# versions below v1.0.0rc1 is eliminated
|
|
381
|
+
field_count_data = self.get_public_data(
|
|
382
|
+
context, self._get_obj_field_count_key(obj, legacy=True)
|
|
383
|
+
)
|
|
345
384
|
if field_count_data is not None:
|
|
346
385
|
cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
|
|
347
386
|
cls_to_prim_key = dict(cls_to_prim_key)
|
maxframe/tensor/__init__.py
CHANGED
|
@@ -114,7 +114,6 @@ from .arithmetic import (
|
|
|
114
114
|
)
|
|
115
115
|
from .arithmetic import truediv as true_divide
|
|
116
116
|
from .arithmetic import trunc
|
|
117
|
-
from .base import broadcast_to, transpose, unique, where
|
|
118
117
|
from .core import Tensor
|
|
119
118
|
from .datasource import (
|
|
120
119
|
arange,
|
|
@@ -143,7 +142,16 @@ from .indexing import (
|
|
|
143
142
|
take,
|
|
144
143
|
unravel_index,
|
|
145
144
|
)
|
|
146
|
-
from .merge import stack
|
|
145
|
+
from .merge import concatenate, stack, vstack
|
|
146
|
+
from .misc import (
|
|
147
|
+
atleast_1d,
|
|
148
|
+
atleast_2d,
|
|
149
|
+
atleast_3d,
|
|
150
|
+
broadcast_to,
|
|
151
|
+
transpose,
|
|
152
|
+
unique,
|
|
153
|
+
where,
|
|
154
|
+
)
|
|
147
155
|
from .rechunk import rechunk
|
|
148
156
|
from .reduction import (
|
|
149
157
|
all,
|
|
@@ -180,4 +188,63 @@ from .reduction import std, sum, var
|
|
|
180
188
|
from .reshape import reshape
|
|
181
189
|
from .ufunc import ufunc
|
|
182
190
|
|
|
191
|
+
# isort: off
|
|
192
|
+
# noinspection PyUnresolvedReferences
|
|
193
|
+
from numpy import (
|
|
194
|
+
NAN,
|
|
195
|
+
NINF,
|
|
196
|
+
AxisError,
|
|
197
|
+
Inf,
|
|
198
|
+
NaN,
|
|
199
|
+
e,
|
|
200
|
+
errstate,
|
|
201
|
+
geterr,
|
|
202
|
+
inf,
|
|
203
|
+
nan,
|
|
204
|
+
newaxis,
|
|
205
|
+
pi,
|
|
206
|
+
seterr,
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
# import numpy types
|
|
210
|
+
# noinspection PyUnresolvedReferences
|
|
211
|
+
from numpy import (
|
|
212
|
+
bool_ as bool,
|
|
213
|
+
bytes_,
|
|
214
|
+
cfloat,
|
|
215
|
+
character,
|
|
216
|
+
complex64,
|
|
217
|
+
complex128,
|
|
218
|
+
complexfloating,
|
|
219
|
+
datetime64,
|
|
220
|
+
double,
|
|
221
|
+
dtype,
|
|
222
|
+
flexible,
|
|
223
|
+
float16,
|
|
224
|
+
float32,
|
|
225
|
+
float64,
|
|
226
|
+
floating,
|
|
227
|
+
generic,
|
|
228
|
+
inexact,
|
|
229
|
+
int8,
|
|
230
|
+
int16,
|
|
231
|
+
int32,
|
|
232
|
+
int64,
|
|
233
|
+
intc,
|
|
234
|
+
intp,
|
|
235
|
+
number,
|
|
236
|
+
integer,
|
|
237
|
+
object_ as object,
|
|
238
|
+
signedinteger,
|
|
239
|
+
timedelta64,
|
|
240
|
+
uint,
|
|
241
|
+
uint8,
|
|
242
|
+
uint16,
|
|
243
|
+
uint32,
|
|
244
|
+
uint64,
|
|
245
|
+
unicode_,
|
|
246
|
+
unsignedinteger,
|
|
247
|
+
void,
|
|
248
|
+
)
|
|
249
|
+
|
|
183
250
|
del fetch, ufunc
|
|
@@ -17,26 +17,13 @@
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import pytest
|
|
19
19
|
|
|
20
|
+
from maxframe.tensor.arithmetic.core import TensorBinOp, TensorUnaryOp
|
|
21
|
+
from maxframe.utils import collect_leaf_operators
|
|
22
|
+
|
|
20
23
|
from ....core import enter_mode
|
|
21
24
|
from ...core import SparseTensor, Tensor
|
|
22
25
|
from ...datasource import array, empty, ones, tensor
|
|
23
|
-
from .. import
|
|
24
|
-
TensorAdd,
|
|
25
|
-
TensorGreaterThan,
|
|
26
|
-
TensorIsclose,
|
|
27
|
-
TensorLog,
|
|
28
|
-
TensorSubtract,
|
|
29
|
-
add,
|
|
30
|
-
around,
|
|
31
|
-
cos,
|
|
32
|
-
frexp,
|
|
33
|
-
isclose,
|
|
34
|
-
isfinite,
|
|
35
|
-
log,
|
|
36
|
-
negative,
|
|
37
|
-
subtract,
|
|
38
|
-
truediv,
|
|
39
|
-
)
|
|
26
|
+
from .. import * # noqa: F401
|
|
40
27
|
|
|
41
28
|
|
|
42
29
|
def test_add():
|
|
@@ -412,3 +399,20 @@ def test_build_mode():
|
|
|
412
399
|
|
|
413
400
|
with enter_mode(build=True):
|
|
414
401
|
assert t1 != 2
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def test_unary_op_func_name():
|
|
405
|
+
# make sure all the unary op has defined the func name.
|
|
406
|
+
|
|
407
|
+
results = collect_leaf_operators(TensorUnaryOp)
|
|
408
|
+
for op_type in results:
|
|
409
|
+
assert hasattr(op_type, "_func_name")
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def test_binary_op_func_name():
|
|
413
|
+
# make sure all the binary op has defined the func name.
|
|
414
|
+
|
|
415
|
+
results = collect_leaf_operators(TensorBinOp)
|
|
416
|
+
for op_type in results:
|
|
417
|
+
if op_type not in (TensorSetImag, TensorSetReal):
|
|
418
|
+
assert hasattr(op_type, "_func_name")
|