maxframe 0.1.0b4-cp39-cp39-win32.whl → 1.0.0rc1-cp39-cp39-win32.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
- maxframe/__init__.py +1 -0
- maxframe/_utils.cp39-win32.pyd +0 -0
- maxframe/codegen.py +56 -3
- maxframe/config/config.py +15 -1
- maxframe/core/__init__.py +0 -3
- maxframe/core/entity/__init__.py +1 -8
- maxframe/core/entity/objects.py +3 -45
- maxframe/core/graph/core.cp39-win32.pyd +0 -0
- maxframe/core/graph/core.pyx +4 -4
- maxframe/dataframe/__init__.py +1 -0
- maxframe/dataframe/core.py +30 -8
- maxframe/dataframe/datasource/read_odps_query.py +3 -1
- maxframe/dataframe/datasource/read_odps_table.py +3 -1
- maxframe/dataframe/datastore/tests/__init__.py +13 -0
- maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
- maxframe/dataframe/datastore/to_odps.py +21 -0
- maxframe/dataframe/indexing/align.py +1 -1
- maxframe/dataframe/misc/__init__.py +4 -0
- maxframe/dataframe/misc/apply.py +3 -1
- maxframe/dataframe/misc/case_when.py +141 -0
- maxframe/dataframe/misc/memory_usage.py +2 -2
- maxframe/dataframe/misc/pivot_table.py +262 -0
- maxframe/dataframe/misc/tests/test_misc.py +84 -0
- maxframe/dataframe/plotting/core.py +2 -2
- maxframe/dataframe/reduction/core.py +2 -1
- maxframe/dataframe/statistics/corr.py +3 -3
- maxframe/dataframe/utils.py +7 -0
- maxframe/errors.py +13 -0
- maxframe/extension.py +12 -0
- maxframe/learn/contrib/utils.py +52 -0
- maxframe/learn/contrib/xgboost/__init__.py +26 -0
- maxframe/learn/contrib/xgboost/classifier.py +86 -0
- maxframe/learn/contrib/xgboost/core.py +156 -0
- maxframe/learn/contrib/xgboost/dmatrix.py +150 -0
- maxframe/learn/contrib/xgboost/predict.py +138 -0
- maxframe/learn/contrib/xgboost/regressor.py +78 -0
- maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
- maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
- maxframe/learn/contrib/xgboost/train.py +121 -0
- maxframe/learn/utils/__init__.py +15 -0
- maxframe/learn/utils/core.py +29 -0
- maxframe/lib/mmh3.cp39-win32.pyd +0 -0
- maxframe/lib/mmh3.pyi +43 -0
- maxframe/lib/wrapped_pickle.py +2 -1
- maxframe/odpsio/arrow.py +2 -3
- maxframe/odpsio/tableio.py +22 -0
- maxframe/odpsio/tests/test_schema.py +16 -11
- maxframe/opcodes.py +3 -0
- maxframe/protocol.py +108 -10
- maxframe/serialization/core.cp39-win32.pyd +0 -0
- maxframe/serialization/core.pxd +3 -0
- maxframe/serialization/core.pyi +64 -0
- maxframe/serialization/core.pyx +54 -25
- maxframe/serialization/exception.py +1 -1
- maxframe/serialization/pandas.py +7 -2
- maxframe/serialization/serializables/core.py +119 -12
- maxframe/serialization/serializables/tests/test_serializable.py +46 -4
- maxframe/session.py +28 -0
- maxframe/tensor/__init__.py +1 -1
- maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
- maxframe/tensor/base/__init__.py +2 -0
- maxframe/tensor/base/atleast_1d.py +74 -0
- maxframe/tensor/base/unique.py +205 -0
- maxframe/tensor/datasource/array.py +4 -2
- maxframe/tensor/datasource/scalar.py +1 -1
- maxframe/tensor/reduction/count_nonzero.py +1 -1
- maxframe/tests/test_protocol.py +34 -0
- maxframe/tests/test_utils.py +0 -12
- maxframe/tests/utils.py +2 -2
- maxframe/udf.py +63 -3
- maxframe/utils.py +22 -13
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/METADATA +3 -3
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/RECORD +80 -61
- maxframe_client/__init__.py +0 -1
- maxframe_client/fetcher.py +65 -3
- maxframe_client/session/odps.py +74 -5
- maxframe_client/session/task.py +65 -71
- maxframe_client/tests/test_session.py +64 -1
- maxframe_client/clients/spe.py +0 -104
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/WHEEL +0 -0
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/top_level.txt +0 -0
maxframe/learn/contrib/xgboost/train.py
ADDED

@@ -0,0 +1,121 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import OrderedDict
+
+from .... import opcodes as OperandDef
+from ....core import OutputType
+from ....core.operator.base import Operator
+from ....core.operator.core import TileableOperatorMixin
+from ....serialization.serializables import (
+    AnyField,
+    BoolField,
+    DictField,
+    FieldTypes,
+    FunctionField,
+    Int64Field,
+    KeyField,
+    ListField,
+)
+from .dmatrix import ToDMatrix, to_dmatrix
+
+logger = logging.getLogger(__name__)
+
+
+def _on_serialize_evals(evals_val):
+    if evals_val is None:
+        return None
+    return [list(x) for x in evals_val]
+
+
+class XGBTrain(Operator, TileableOperatorMixin):
+    _op_type_ = OperandDef.XGBOOST_TRAIN
+
+    params = DictField("params", key_type=FieldTypes.string, default=None)
+    dtrain = KeyField("dtrain", default=None)
+    evals = ListField("evals", on_serialize=_on_serialize_evals, default=None)
+    obj = FunctionField("obj", default=None)
+    feval = FunctionField("feval", default=None)
+    maximize = BoolField("maximize", default=None)
+    early_stopping_rounds = Int64Field("early_stopping_rounds", default=None)
+    verbose_eval = AnyField("verbose_eval", default=None)
+    xgb_model = AnyField("xgb_model", default=None)
+    callbacks = ListField(
+        "callbacks", field_type=FunctionField.field_type, default=None
+    )
+    custom_metric = FunctionField("custom_metric", default=None)
+    num_boost_round = Int64Field("num_boost_round", default=10)
+    num_class = Int64Field("num_class", default=None)
+
+    # Stored locally to hold the evals_result fetched from the remote side
+    evals_result: dict = None
+
+    def __init__(self, gpu=None, **kw):
+        super().__init__(gpu=gpu, **kw)
+        if self.output_types is None:
+            self.output_types = [OutputType.object]
+
+    def _set_inputs(self, inputs):
+        super()._set_inputs(inputs)
+        self.dtrain = self._inputs[0]
+        rest = self._inputs[1:]
+        if self.evals is not None:
+            evals_dict = OrderedDict(self.evals)
+            new_evals_dict = OrderedDict()
+            for new_key, val in zip(rest, evals_dict.values()):
+                new_evals_dict[new_key] = val
+            self.evals = list(new_evals_dict.items())
+
+    def __call__(self):
+        inputs = [self.dtrain]
+        if self.evals is not None:
+            inputs.extend(e[0] for e in self.evals)
+        return self.new_tileable(inputs)
+
+
+def train(params, dtrain, evals=None, evals_result=None, num_class=None, **kwargs):
+    """
+    Train XGBoost model in Mars manner.
+
+    Parameters
+    ----------
+    Parameters are the same as `xgboost.train`.
+
+    Returns
+    -------
+    results: Booster
+    """
+    evals_result = evals_result or dict()
+    evals = evals or ()
+
+    processed_evals = []
+    if evals:
+        for eval_dmatrix, name in evals:
+            if not isinstance(name, str):
+                raise TypeError("evals must be a list of pairs (DMatrix, string)")
+            if hasattr(eval_dmatrix, "op") and isinstance(eval_dmatrix.op, ToDMatrix):
+                processed_evals.append((eval_dmatrix, name))
+            else:
+                processed_evals.append((to_dmatrix(eval_dmatrix), name))
+
+    return XGBTrain(
+        params=params,
+        dtrain=dtrain,
+        evals=processed_evals,
+        evals_result=evals_result,
+        num_class=num_class,
+        **kwargs
+    )()
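Together with the new dmatrix.py, classifier.py and regressor.py modules, this gives MaxFrame a deferred XGBoost training entry point. A minimal usage sketch, assuming a configured MaxFrame session; the `label` keyword on `to_dmatrix` and the `execute()` call follow the usual MaxFrame deferred-execution pattern and are assumptions, not taken from the diff:

import numpy as np
import pandas as pd

import maxframe.dataframe as md
from maxframe.learn.contrib.xgboost.dmatrix import to_dmatrix
from maxframe.learn.contrib.xgboost.train import train

# Local pandas data wrapped into a MaxFrame DataFrame.
raw = pd.DataFrame(np.random.rand(100, 4), columns=list("abcd"))
raw["label"] = (raw["a"] > 0.5).astype(int)
df = md.DataFrame(raw)

# Deferred DMatrix construction; the `label` keyword is assumed here.
dtrain = to_dmatrix(df[list("abcd")], label=df["label"])

# Same calling convention as xgboost.train; evals holds (DMatrix, name) pairs.
evals_result = {}
booster = train(
    {"objective": "binary:logistic"},
    dtrain,
    evals=[(dtrain, "train")],
    evals_result=evals_result,
)
booster.execute()  # assumes the usual deferred-execution session flow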
maxframe/learn/utils/__init__.py
ADDED

@@ -0,0 +1,15 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .core import convert_to_tensor_or_dataframe
maxframe/learn/utils/core.py
ADDED

@@ -0,0 +1,29 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+
+from ...dataframe import DataFrame, Series
+from ...dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
+from ...tensor import tensor as astensor
+
+
+def convert_to_tensor_or_dataframe(item):
+    if isinstance(item, (DATAFRAME_TYPE, pd.DataFrame)):
+        item = DataFrame(item)
+    elif isinstance(item, (SERIES_TYPE, pd.Series)):
+        item = Series(item)
+    else:
+        item = astensor(item)
+    return item
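The helper simply dispatches on the input type. A quick illustration of the three branches; the import path follows the re-export added in learn/utils/__init__.py above:

import numpy as np
import pandas as pd

from maxframe.learn.utils import convert_to_tensor_or_dataframe

df = convert_to_tensor_or_dataframe(pd.DataFrame({"a": [1, 2]}))  # MaxFrame DataFrame
s = convert_to_tensor_or_dataframe(pd.Series([1.0, 2.0]))         # MaxFrame Series
t = convert_to_tensor_or_dataframe(np.arange(4))                  # MaxFrame tensor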
maxframe/lib/mmh3.cp39-win32.pyd
CHANGED
Binary file
maxframe/lib/mmh3.pyi
ADDED

@@ -0,0 +1,43 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+def hash(key, seed=0, signed=True) -> int:
+    """
+    Return a 32 bit integer.
+    """
+
+def hash_from_buffer(key, seed=0, signed=True) -> int:
+    """
+    Return a 32 bit integer. Designed for large memory-views such as numpy arrays.
+    """
+
+def hash64(key, seed=0, x64arch=True, signed=True) -> Tuple[int, int]:
+    """
+    Return a tuple of two 64 bit integers for a string. Optimized for
+    the x64 bit architecture when x64arch=True, otherwise for x86.
+    """
+
+def hash128(key, seed=0, x64arch=True, signed=False) -> int:
+    """
+    Return a 128 bit long integer. Optimized for the x64 bit architecture
+    when x64arch=True, otherwise for x86.
+    """
+
+def hash_bytes(key, seed=0, x64arch=True) -> bytes:
+    """
+    Return a 128 bit hash value as bytes for a string. Optimized for the
+    x64 bit architecture when x64arch=True, otherwise for the x86.
+    """
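The stub documents the hashing surface of the bundled MurmurHash3 extension. Calls like the following match the declared signatures; the `maxframe.lib.mmh3` import path comes from the file's location in the wheel:

from maxframe.lib import mmh3

h32 = mmh3.hash("maxframe", seed=42)  # signed 32-bit integer
lo, hi = mmh3.hash64(b"maxframe")     # pair of 64-bit integers
h128 = mmh3.hash128("maxframe")       # unsigned 128-bit integer
digest = mmh3.hash_bytes("maxframe")  # 16-byte digest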
maxframe/lib/wrapped_pickle.py
CHANGED

@@ -120,7 +120,8 @@ class _UnpickleSwitch:
             @functools.wraps(func)
             async def wrapped(*args, **kwargs):
                 with _UnpickleSwitch(forbidden=self._forbidden):
-                    return await func(*args, **kwargs)
+                    ret = await func(*args, **kwargs)
+                    return ret

         else:
maxframe/odpsio/arrow.py
CHANGED

@@ -17,10 +17,9 @@ from typing import Any, Tuple, Union
 import pandas as pd
 import pyarrow as pa

-import maxframe.tensor as mt
-
 from ..core import OutputType
 from ..protocol import DataFrameTableMeta
+from ..tensor.core import TENSOR_TYPE
 from ..typing_ import ArrowTableType, PandasObjectTypes
 from .schema import build_dataframe_table_meta

@@ -83,7 +82,7 @@ def pandas_to_arrow(
         df = df.to_frame(name=names[0] if len(names) == 1 else names)
     elif table_meta.type == OutputType.scalar:
         names = ["_idx_0"]
-        if isinstance(df,
+        if isinstance(df, TENSOR_TYPE):
             df = pd.DataFrame([], columns=names).astype({names[0]: df.dtype})
         else:
             df = pd.DataFrame([[df]], columns=names)
maxframe/odpsio/tableio.py
CHANGED

@@ -183,6 +183,28 @@ class HaloTableIO(MCTableIO):
             for pt in partitions
         ]

+    def get_table_record_count(
+        self, full_table_name: str, partitions: PartitionsType = None
+    ):
+        from odps.apis.storage_api import SplitOptions, TableBatchScanRequest
+
+        table = self._odps.get_table(full_table_name)
+        client = StorageApiArrowClient(
+            self._odps, table, rest_endpoint=self._storage_api_endpoint
+        )
+
+        split_option = SplitOptions.SplitMode.SIZE
+
+        scan_kw = {
+            "required_partitions": self._convert_partitions(partitions),
+            "split_options": SplitOptions.get_default_options(split_option),
+        }
+
+        # todo add more options for partition column handling
+        req = TableBatchScanRequest(**scan_kw)
+        resp = client.create_read_session(req)
+        return resp.record_count
+
     @contextmanager
     def open_reader(
         self,
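The new method lets callers size a table (optionally restricted to partitions) before opening a reader, by creating a storage-api read session and returning its record count. A hypothetical call site; only `get_table_record_count` itself comes from the diff, while the pyodps entry point and the `HaloTableIO` constructor arguments are assumptions:

from odps import ODPS

from maxframe.odpsio.tableio import HaloTableIO

o = ODPS("<access-id>", "<secret-key>", "<project>", "<endpoint>")  # credentials elided
table_io = HaloTableIO(o)  # constructor signature assumed
n_total = table_io.get_table_record_count("my_project.my_table")
n_part = table_io.get_table_record_count(
    "my_project.my_table", partitions=["pt=20240101"]
)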
maxframe/odpsio/tests/test_schema.py
CHANGED

@@ -30,20 +30,23 @@ from ..schema import (
 )


-def _wrap_maxframe_obj(obj, wrap=
-    if
+def _wrap_maxframe_obj(obj, wrap="no"):
+    if wrap == "no":
         return obj
     if isinstance(obj, pd.DataFrame):
-
+        obj = md.DataFrame(obj)
     elif isinstance(obj, pd.Series):
-
+        obj = md.Series(obj)
     elif isinstance(obj, pd.Index):
-
+        obj = md.Index(obj)
     else:
-
+        obj = mt.scalar(obj)
+    if wrap == "data":
+        return obj.data
+    return obj


-@pytest.mark.parametrize("wrap_obj", [
+@pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
 def test_pandas_to_odps_schema_dataframe(wrap_obj):
     data = pd.DataFrame(np.random.rand(100, 5), columns=list("ABCDE"))

@@ -94,7 +97,7 @@ def test_pandas_to_odps_schema_dataframe(wrap_obj):
     assert meta.pd_index_level_names == [None, None]


-@pytest.mark.parametrize("wrap_obj", [
+@pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
 def test_pandas_to_odps_schema_series(wrap_obj):
     data = pd.Series(np.random.rand(100))

@@ -135,7 +138,7 @@ def test_pandas_to_odps_schema_series(wrap_obj):
     assert meta.pd_index_level_names == ["c1", "c2"]


-@pytest.mark.parametrize("wrap_obj", [
+@pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
 def test_pandas_to_odps_schema_index(wrap_obj):
     data = pd.Index(np.random.randint(0, 100, 100))

@@ -167,11 +170,13 @@ def test_pandas_to_odps_schema_index(wrap_obj):
     assert meta.pd_index_level_names == ["c1", "c2"]


-@pytest.mark.parametrize("wrap_obj", [
+@pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
 def test_pandas_to_odps_schema_scalar(wrap_obj):
     data = 1234.56

     test_scalar = _wrap_maxframe_obj(data, wrap=wrap_obj)
+    if wrap_obj != "no":
+        test_scalar.op.data = None
     schema, meta = pandas_to_odps_schema(test_scalar, unknown_as_string=True)
     assert schema.columns[0].name == "_idx_0"
     assert schema.columns[0].type.name == "double"

@@ -279,7 +284,7 @@ def test_build_column_name():
     assert build_table_column_name(4, ("A", 1), records) == "a_1"


-@pytest.mark.parametrize("wrap_obj", [
+@pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
 def test_build_table_meta(wrap_obj):
     data = pd.DataFrame(
         np.random.rand(100, 7),
maxframe/opcodes.py
CHANGED
maxframe/protocol.py
CHANGED

@@ -32,6 +32,7 @@ from .serialization.serializables import (
     EnumField,
     FieldTypes,
     Float64Field,
+    Int32Field,
     ListField,
     ReferenceField,
     Serializable,

@@ -71,6 +72,9 @@ class DagStatus(enum.Enum):
     CANCELLING = 4
     CANCELLED = 5

+    def is_terminated(self):
+        return self in (DagStatus.CANCELLED, DagStatus.SUCCEEDED, DagStatus.FAILED)
+

 class DimensionIndex(Serializable):
     is_slice: bool = BoolField("is_slice", default=None)

@@ -190,9 +194,9 @@ class ErrorInfo(JsonSerializable):
         "error_tracebacks", FieldTypes.list
     )
     raw_error_source: ErrorSource = EnumField(
-        "raw_error_source", ErrorSource, FieldTypes.int8
+        "raw_error_source", ErrorSource, FieldTypes.int8, default=None
     )
-    raw_error_data: Optional[Exception] = AnyField("raw_error_data")
+    raw_error_data: Optional[Exception] = AnyField("raw_error_data", default=None)

     @classmethod
     def from_exception(cls, exc: Exception):

@@ -201,20 +205,29 @@
         return cls(messages, tracebacks, ErrorSource.PYTHON, exc)

     def reraise(self):
-        if
+        if (
+            self.raw_error_source == ErrorSource.PYTHON
+            and self.raw_error_data is not None
+        ):
             raise self.raw_error_data
         raise RemoteException(self.error_messages, self.error_tracebacks, [])

     @classmethod
     def from_json(cls, serialized: dict) -> "ErrorInfo":
         kw = serialized.copy()
-        kw
+        if kw.get("raw_error_source") is not None:
+            kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
+        else:
+            kw["raw_error_source"] = None
+
         if kw.get("raw_error_data"):
             bufs = [base64.b64decode(s) for s in kw["raw_error_data"]]
             try:
                 kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
             except:
-
+                # both error source and data shall be None to make sure
+                # RemoteException is raised.
+                kw["raw_error_source"] = kw["raw_error_data"] = None
         return cls(**kw)

     def to_json(self) -> dict:

@@ -227,7 +240,12 @@
         if isinstance(self.raw_error_data, (PickleContainer, RemoteException)):
             err_data_bufs = self.raw_error_data.get_buffers()
         elif isinstance(self.raw_error_data, BaseException):
-
+            try:
+                err_data_bufs = pickle_buffers(self.raw_error_data)
+            except:
+                err_data_bufs = None
+                ret["raw_error_source"] = None
+
         if err_data_bufs:
             ret["raw_error_data"] = [
                 base64.b64encode(s).decode() for s in err_data_bufs

@@ -249,9 +267,17 @@ class DagInfo(JsonSerializable):
     error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
     start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
     end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
+    subdag_infos: Dict[str, "SubDagInfo"] = DictField(
+        "subdag_infos",
+        key_type=FieldTypes.string,
+        value_type=FieldTypes.reference,
+        default_factory=dict,
+    )

     @classmethod
-    def from_json(cls, serialized: dict) -> "DagInfo":
+    def from_json(cls, serialized: dict) -> Optional["DagInfo"]:
+        if serialized is None:
+            return None
         kw = serialized.copy()
         kw["status"] = DagStatus(kw["status"])
         if kw.get("tileable_to_result_infos"):

@@ -261,6 +287,10 @@
         }
         if kw.get("error_info"):
             kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
+        if kw.get("subdag_infos"):
+            kw["subdag_infos"] = {
+                k: SubDagInfo.from_json(v) for k, v in kw["subdag_infos"].items()
+            }
         return DagInfo(**kw)

     def to_json(self) -> dict:

@@ -279,6 +309,8 @@
         }
         if self.error_info:
             ret["error_info"] = self.error_info.to_json()
+        if self.subdag_infos:
+            ret["subdag_infos"] = {k: v.to_json() for k, v in self.subdag_infos.items()}
         return ret


@@ -302,7 +334,9 @@ class SessionInfo(JsonSerializable):
     error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)

     @classmethod
-    def from_json(cls, serialized: dict) -> "SessionInfo":
+    def from_json(cls, serialized: dict) -> Optional["SessionInfo"]:
+        if serialized is None:
+            return None
         kw = serialized.copy()
         if kw.get("dag_infos"):
             kw["dag_infos"] = {

@@ -320,7 +354,10 @@
             "idle_timestamp": self.idle_timestamp,
         }
         if self.dag_infos:
-            ret["dag_infos"] = {
+            ret["dag_infos"] = {
+                k: v.to_json() if v is not None else None
+                for k, v in self.dag_infos.items()
+            }
         if self.error_info:
             ret["error_info"] = self.error_info.to_json()
         return ret

@@ -342,7 +379,25 @@ class ExecuteDagRequest(Serializable):
     )


-class SubDagInfo(Serializable):
+class SubDagSubmitInstanceInfo(JsonSerializable):
+    submit_reason: str = StringField("submit_reason")
+    instance_id: str = StringField("instance_id")
+    subquery_id: Optional[int] = Int32Field("subquery_id", default=None)
+
+    @classmethod
+    def from_json(cls, serialized: dict) -> "SubDagSubmitInstanceInfo":
+        return SubDagSubmitInstanceInfo(**serialized)
+
+    def to_json(self) -> dict:
+        ret = {
+            "submit_reason": self.submit_reason,
+            "instance_id": self.instance_id,
+            "subquery_id": self.subquery_id,
+        }
+        return ret
+
+
+class SubDagInfo(JsonSerializable):
     subdag_id: str = StringField("subdag_id")
     status: DagStatus = EnumField("status", DagStatus, FieldTypes.int8, default=None)
     progress: float = Float64Field("progress", default=None)

@@ -355,9 +410,52 @@ class SubDagInfo(Serializable):
         FieldTypes.reference,
         default_factory=dict,
     )
+    start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
+    end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
+    submit_instances: List[SubDagSubmitInstanceInfo] = ListField(
+        "submit_instances",
+        FieldTypes.reference,
+        default_factory=list,
+    )
+
+    @classmethod
+    def from_json(cls, serialized: dict) -> "SubDagInfo":
+        kw = serialized.copy()
+        kw["status"] = DagStatus(kw["status"])
+        if kw.get("tileable_to_result_infos"):
+            kw["tileable_to_result_infos"] = {
+                k: ResultInfo.from_json(s)
+                for k, s in kw["tileable_to_result_infos"].items()
+            }
+        if kw.get("error_info"):
+            kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
+        if kw.get("submit_instances"):
+            kw["submit_instances"] = [
+                SubDagSubmitInstanceInfo.from_json(s) for s in kw["submit_instances"]
+            ]
+        return SubDagInfo(**kw)
+
+    def to_json(self) -> dict:
+        ret = {
+            "subdag_id": self.subdag_id,
+            "status": self.status.value,
+            "progress": self.progress,
+            "start_timestamp": self.start_timestamp,
+            "end_timestamp": self.end_timestamp,
+        }
+        if self.error_info:
+            ret["error_info"] = self.error_info.to_json()
+        if self.tileable_to_result_infos:
+            ret["tileable_to_result_infos"] = {
+                k: v.to_json() for k, v in self.tileable_to_result_infos.items()
+            }
+        if self.submit_instances:
+            ret["submit_instances"] = [i.to_json() for i in self.submit_instances]
+        return ret


 class ExecuteSubDagRequest(Serializable):
+    subdag_id: str = StringField("subdag_id")
     dag: TileableGraph = ReferenceField(
         "dag",
         on_serialize=SerializableGraph.from_graph,

maxframe/serialization/core.cp39-win32.pyd
CHANGED
Binary file
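The new DagStatus.is_terminated helper gives clients a single test for terminal DAG states. The following follows directly from the enum values and the method body shown above; the module path matches the file's location:

from maxframe.protocol import DagStatus

# Terminal states: no further status changes are expected.
assert DagStatus.SUCCEEDED.is_terminated()
assert DagStatus.FAILED.is_terminated()
assert DagStatus.CANCELLED.is_terminated()

# CANCELLING (value 4) is still in flight, so it is not terminal.
assert not DagStatus.CANCELLING.is_terminated()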
maxframe/serialization/core.pxd
CHANGED

@@ -18,6 +18,9 @@ from libc.stdint cimport int32_t, uint64_t
 cdef class Serializer:
     cdef int _serializer_id

+    cpdef bint is_public_data_exist(self, dict context, object key)
+    cpdef put_public_data(self, dict context, object key, object value)
+    cpdef get_public_data(self, dict context, object key)
     cpdef serial(self, object obj, dict context)
     cpdef deserial(self, list serialized, dict context, list subs)
     cpdef on_deserial_error(
maxframe/serialization/core.pyi
ADDED

@@ -0,0 +1,64 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from concurrent.futures import Executor
+from typing import Any, Callable, Dict, List, TypeVar
+
+def buffered(func: Callable) -> Callable: ...
+def fast_id(obj: Any) -> int: ...
+
+LoadType = TypeVar("LoadType")
+
+def load_type(class_name: str, parent_class: LoadType) -> LoadType: ...
+
+class PickleContainer:
+    def __init__(self, buffers: List[bytes]): ...
+    def get(self) -> Any: ...
+    def get_buffers(self) -> List[bytes]: ...
+
+class Serializer:
+    serializer_id: int
+    def is_public_data_exist(self, context: Dict, key: Any) -> bool: ...
+    def put_public_data(self, context: Dict, key: Any, value: Any) -> None: ...
+    def get_public_data(self, context: Dict, key: Any) -> Any: ...
+    def serial(self, obj: Any, context: Dict): ...
+    def deserial(self, serialized: List, context: Dict, subs: List[Any]): ...
+    def on_deserial_error(
+        self,
+        serialized: List,
+        context: Dict,
+        subs_serialized: List,
+        error_index: int,
+        exc: BaseException,
+    ): ...
+    @classmethod
+    def register(cls, obj_type): ...
+    @classmethod
+    def unregister(cls, obj_type): ...
+
+class Placeholder:
+    id: int
+    callbacks: List[Callable]
+    def __init__(self, id_: int): ...
+    def __hash__(self): ...
+    def __eq__(self, other): ...
+
+def serialize(obj: Any, context: Dict = None): ...
+async def serialize_with_spawn(
+    obj: Any,
+    context: Dict = None,
+    spawn_threshold: int = 100,
+    executor: Executor = None,
+): ...
+def deserialize(headers: List, buffers: List, context: Dict = None): ...
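The stub also pins down the module-level entry points. A round-trip sketch consistent with the declared signatures; that `serialize` returns exactly the `(headers, buffers)` pair consumed by `deserialize` is an assumption, though the parameter names strongly suggest it:

from maxframe.serialization.core import deserialize, serialize

headers, buffers = serialize({"a": 1, "b": [1, 2, 3]})
restored = deserialize(headers, buffers)
assert restored == {"a": 1, "b": [1, 2, 3]}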