maxframe 0.1.0b4__cp39-cp39-macosx_10_9_universal2.whl → 1.0.0__cp39-cp39-macosx_10_9_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/__init__.py +1 -0
- maxframe/_utils.cpython-39-darwin.so +0 -0
- maxframe/codegen.py +56 -5
- maxframe/config/config.py +78 -10
- maxframe/config/validators.py +42 -11
- maxframe/conftest.py +58 -14
- maxframe/core/__init__.py +2 -16
- maxframe/core/entity/__init__.py +1 -12
- maxframe/core/entity/executable.py +1 -1
- maxframe/core/entity/objects.py +46 -45
- maxframe/core/entity/output_types.py +0 -3
- maxframe/core/entity/tests/test_objects.py +43 -0
- maxframe/core/entity/tileables.py +5 -78
- maxframe/core/graph/__init__.py +2 -2
- maxframe/core/graph/builder/__init__.py +0 -1
- maxframe/core/graph/builder/base.py +5 -4
- maxframe/core/graph/builder/tileable.py +4 -4
- maxframe/core/graph/builder/utils.py +4 -8
- maxframe/core/graph/core.cpython-39-darwin.so +0 -0
- maxframe/core/graph/core.pyx +4 -4
- maxframe/core/graph/entity.py +9 -33
- maxframe/core/operator/__init__.py +2 -9
- maxframe/core/operator/base.py +3 -5
- maxframe/core/operator/objects.py +0 -9
- maxframe/core/operator/utils.py +55 -0
- maxframe/dataframe/__init__.py +2 -1
- maxframe/dataframe/arithmetic/around.py +5 -17
- maxframe/dataframe/arithmetic/core.py +15 -7
- maxframe/dataframe/arithmetic/docstring.py +7 -33
- maxframe/dataframe/arithmetic/equal.py +4 -2
- maxframe/dataframe/arithmetic/greater.py +4 -2
- maxframe/dataframe/arithmetic/greater_equal.py +4 -2
- maxframe/dataframe/arithmetic/less.py +2 -2
- maxframe/dataframe/arithmetic/less_equal.py +4 -2
- maxframe/dataframe/arithmetic/not_equal.py +4 -2
- maxframe/dataframe/arithmetic/tests/test_arithmetic.py +39 -16
- maxframe/dataframe/core.py +58 -12
- maxframe/dataframe/datasource/date_range.py +2 -2
- maxframe/dataframe/datasource/read_odps_query.py +120 -24
- maxframe/dataframe/datasource/read_odps_table.py +9 -4
- maxframe/dataframe/datasource/tests/test_datasource.py +103 -8
- maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
- maxframe/dataframe/datastore/to_odps.py +28 -0
- maxframe/dataframe/extensions/__init__.py +5 -0
- maxframe/dataframe/extensions/flatjson.py +131 -0
- maxframe/dataframe/extensions/flatmap.py +317 -0
- maxframe/dataframe/extensions/reshuffle.py +1 -1
- maxframe/dataframe/extensions/tests/test_extensions.py +108 -3
- maxframe/dataframe/groupby/core.py +1 -1
- maxframe/dataframe/groupby/cum.py +0 -1
- maxframe/dataframe/groupby/fill.py +4 -1
- maxframe/dataframe/groupby/getitem.py +6 -0
- maxframe/dataframe/groupby/tests/test_groupby.py +5 -1
- maxframe/dataframe/groupby/transform.py +5 -1
- maxframe/dataframe/indexing/align.py +1 -1
- maxframe/dataframe/indexing/loc.py +6 -4
- maxframe/dataframe/indexing/rename.py +5 -28
- maxframe/dataframe/indexing/sample.py +0 -1
- maxframe/dataframe/indexing/set_index.py +68 -1
- maxframe/dataframe/initializer.py +11 -1
- maxframe/dataframe/merge/__init__.py +9 -1
- maxframe/dataframe/merge/concat.py +41 -31
- maxframe/dataframe/merge/merge.py +237 -3
- maxframe/dataframe/merge/tests/test_merge.py +126 -1
- maxframe/dataframe/misc/__init__.py +4 -0
- maxframe/dataframe/misc/apply.py +6 -11
- maxframe/dataframe/misc/case_when.py +141 -0
- maxframe/dataframe/misc/describe.py +2 -2
- maxframe/dataframe/misc/drop_duplicates.py +8 -8
- maxframe/dataframe/misc/eval.py +4 -0
- maxframe/dataframe/misc/memory_usage.py +2 -2
- maxframe/dataframe/misc/pct_change.py +1 -83
- maxframe/dataframe/misc/pivot_table.py +262 -0
- maxframe/dataframe/misc/tests/test_misc.py +93 -1
- maxframe/dataframe/misc/transform.py +1 -30
- maxframe/dataframe/misc/value_counts.py +4 -17
- maxframe/dataframe/missing/dropna.py +1 -1
- maxframe/dataframe/missing/fillna.py +5 -5
- maxframe/dataframe/operators.py +1 -17
- maxframe/dataframe/plotting/core.py +2 -2
- maxframe/dataframe/reduction/core.py +4 -3
- maxframe/dataframe/reduction/tests/test_reduction.py +2 -4
- maxframe/dataframe/sort/sort_values.py +1 -11
- maxframe/dataframe/statistics/corr.py +3 -3
- maxframe/dataframe/statistics/quantile.py +13 -19
- maxframe/dataframe/statistics/tests/test_statistics.py +4 -4
- maxframe/dataframe/tests/test_initializer.py +33 -2
- maxframe/dataframe/utils.py +33 -11
- maxframe/dataframe/window/expanding.py +5 -3
- maxframe/dataframe/window/tests/test_expanding.py +2 -2
- maxframe/errors.py +13 -0
- maxframe/extension.py +12 -0
- maxframe/io/__init__.py +13 -0
- maxframe/io/objects/__init__.py +24 -0
- maxframe/io/objects/core.py +140 -0
- maxframe/io/objects/tensor.py +76 -0
- maxframe/io/objects/tests/__init__.py +13 -0
- maxframe/io/objects/tests/test_object_io.py +97 -0
- maxframe/{odpsio → io/odpsio}/__init__.py +3 -1
- maxframe/{odpsio → io/odpsio}/arrow.py +43 -12
- maxframe/{odpsio → io/odpsio}/schema.py +38 -16
- maxframe/io/odpsio/tableio.py +719 -0
- maxframe/io/odpsio/tests/__init__.py +13 -0
- maxframe/{odpsio → io/odpsio}/tests/test_schema.py +75 -33
- maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +50 -23
- maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
- maxframe/io/odpsio/volumeio.py +63 -0
- maxframe/learn/contrib/__init__.py +3 -1
- maxframe/learn/contrib/graph/__init__.py +15 -0
- maxframe/learn/contrib/graph/connected_components.py +215 -0
- maxframe/learn/contrib/graph/tests/__init__.py +13 -0
- maxframe/learn/contrib/graph/tests/test_connected_components.py +53 -0
- maxframe/learn/contrib/llm/__init__.py +16 -0
- maxframe/learn/contrib/llm/core.py +54 -0
- maxframe/learn/contrib/llm/models/__init__.py +14 -0
- maxframe/learn/contrib/llm/models/dashscope.py +73 -0
- maxframe/learn/contrib/llm/multi_modal.py +42 -0
- maxframe/learn/contrib/llm/text.py +42 -0
- maxframe/learn/contrib/utils.py +52 -0
- maxframe/learn/contrib/xgboost/__init__.py +26 -0
- maxframe/learn/contrib/xgboost/classifier.py +110 -0
- maxframe/learn/contrib/xgboost/core.py +241 -0
- maxframe/learn/contrib/xgboost/dmatrix.py +147 -0
- maxframe/learn/contrib/xgboost/predict.py +121 -0
- maxframe/learn/contrib/xgboost/regressor.py +71 -0
- maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
- maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
- maxframe/learn/contrib/xgboost/train.py +132 -0
- maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
- maxframe/learn/utils/__init__.py +15 -0
- maxframe/learn/utils/core.py +29 -0
- maxframe/lib/mmh3.cpython-39-darwin.so +0 -0
- maxframe/lib/mmh3.pyi +43 -0
- maxframe/lib/sparse/tests/test_sparse.py +15 -15
- maxframe/lib/wrapped_pickle.py +2 -1
- maxframe/opcodes.py +11 -0
- maxframe/protocol.py +154 -27
- maxframe/remote/core.py +4 -8
- maxframe/serialization/__init__.py +1 -0
- maxframe/serialization/core.cpython-39-darwin.so +0 -0
- maxframe/serialization/core.pxd +3 -0
- maxframe/serialization/core.pyi +64 -0
- maxframe/serialization/core.pyx +67 -26
- maxframe/serialization/exception.py +1 -1
- maxframe/serialization/pandas.py +52 -17
- maxframe/serialization/serializables/core.py +180 -15
- maxframe/serialization/serializables/field_type.py +4 -1
- maxframe/serialization/serializables/tests/test_serializable.py +54 -5
- maxframe/serialization/tests/test_serial.py +2 -1
- maxframe/session.py +37 -2
- maxframe/tensor/__init__.py +81 -2
- maxframe/tensor/arithmetic/isclose.py +1 -0
- maxframe/tensor/arithmetic/tests/test_arithmetic.py +22 -18
- maxframe/tensor/core.py +5 -136
- maxframe/tensor/datasource/array.py +7 -2
- maxframe/tensor/datasource/full.py +1 -1
- maxframe/tensor/datasource/scalar.py +1 -1
- maxframe/tensor/datasource/tests/test_datasource.py +1 -1
- maxframe/tensor/indexing/flatnonzero.py +1 -1
- maxframe/tensor/indexing/getitem.py +2 -0
- maxframe/tensor/merge/__init__.py +2 -0
- maxframe/tensor/merge/concatenate.py +101 -0
- maxframe/tensor/merge/tests/test_merge.py +30 -1
- maxframe/tensor/merge/vstack.py +74 -0
- maxframe/tensor/{base → misc}/__init__.py +4 -0
- maxframe/tensor/misc/atleast_1d.py +72 -0
- maxframe/tensor/misc/atleast_2d.py +70 -0
- maxframe/tensor/misc/atleast_3d.py +85 -0
- maxframe/tensor/misc/tests/__init__.py +13 -0
- maxframe/tensor/{base → misc}/transpose.py +22 -18
- maxframe/tensor/misc/unique.py +205 -0
- maxframe/tensor/operators.py +1 -7
- maxframe/tensor/random/core.py +1 -1
- maxframe/tensor/reduction/count_nonzero.py +2 -1
- maxframe/tensor/reduction/mean.py +1 -0
- maxframe/tensor/reduction/nanmean.py +1 -0
- maxframe/tensor/reduction/nanvar.py +2 -0
- maxframe/tensor/reduction/tests/test_reduction.py +12 -1
- maxframe/tensor/reduction/var.py +2 -0
- maxframe/tensor/statistics/quantile.py +2 -2
- maxframe/tensor/utils.py +2 -22
- maxframe/tests/test_protocol.py +34 -0
- maxframe/tests/test_utils.py +0 -12
- maxframe/tests/utils.py +17 -2
- maxframe/typing_.py +4 -1
- maxframe/udf.py +62 -3
- maxframe/utils.py +112 -86
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0.dist-info}/METADATA +25 -25
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0.dist-info}/RECORD +208 -167
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0.dist-info}/WHEEL +1 -1
- maxframe_client/__init__.py +0 -1
- maxframe_client/clients/framedriver.py +4 -1
- maxframe_client/fetcher.py +123 -54
- maxframe_client/session/consts.py +3 -0
- maxframe_client/session/graph.py +8 -2
- maxframe_client/session/odps.py +223 -40
- maxframe_client/session/task.py +108 -80
- maxframe_client/tests/test_fetcher.py +21 -3
- maxframe_client/tests/test_session.py +136 -8
- maxframe/core/entity/chunks.py +0 -68
- maxframe/core/entity/fuse.py +0 -73
- maxframe/core/graph/builder/chunk.py +0 -430
- maxframe/odpsio/tableio.py +0 -300
- maxframe/odpsio/volumeio.py +0 -95
- maxframe_client/clients/spe.py +0 -104
- /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
- /maxframe/{tensor/base → dataframe/datastore}/tests/__init__.py +0 -0
- /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
- /maxframe/tensor/{base → misc}/astype.py +0 -0
- /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
- /maxframe/tensor/{base → misc}/ravel.py +0 -0
- /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
- /maxframe/tensor/{base → misc}/where.py +0 -0
- {maxframe-0.1.0b4.dist-info → maxframe-1.0.0.dist-info}/top_level.txt +0 -0
maxframe_client/session/task.py
CHANGED
|
@@ -16,17 +16,23 @@ import base64
|
|
|
16
16
|
import json
|
|
17
17
|
import logging
|
|
18
18
|
import time
|
|
19
|
-
from typing import Dict, List, Optional, Type, Union
|
|
19
|
+
from typing import Any, Dict, List, Optional, Type, Union
|
|
20
20
|
|
|
21
21
|
import msgpack
|
|
22
22
|
from odps import ODPS
|
|
23
23
|
from odps import options as odps_options
|
|
24
|
-
from odps import serializers
|
|
25
24
|
from odps.errors import parse_instance_error
|
|
26
|
-
from odps.models import Instance,
|
|
25
|
+
from odps.models import Instance, MaxFrameTask
|
|
26
|
+
|
|
27
|
+
try:
|
|
28
|
+
from odps.errors import EmptyTaskInfoError
|
|
29
|
+
except ImportError: # pragma: no cover
|
|
30
|
+
# todo remove when pyodps>=0.12.0 is enforced
|
|
31
|
+
EmptyTaskInfoError = type("EmptyTaskInfoError", (Exception,), {})
|
|
27
32
|
|
|
28
33
|
from maxframe.config import options
|
|
29
34
|
from maxframe.core import TileableGraph
|
|
35
|
+
from maxframe.errors import NoTaskServerResponseError, SessionAlreadyClosedError
|
|
30
36
|
from maxframe.protocol import DagInfo, JsonSerializable, ResultInfo, SessionInfo
|
|
31
37
|
from maxframe.utils import deserialize_serializable, serialize_serializable, to_str
|
|
32
38
|
|
|
@@ -36,6 +42,7 @@ except ImportError:
|
|
|
36
42
|
mf_version = None
|
|
37
43
|
|
|
38
44
|
from .consts import (
|
|
45
|
+
EMPTY_RESPONSE_RETRY_COUNT,
|
|
39
46
|
MAXFRAME_DEFAULT_PROTOCOL,
|
|
40
47
|
MAXFRAME_OUTPUT_JSON_FORMAT,
|
|
41
48
|
MAXFRAME_OUTPUT_MAXFRAME_FORMAT,
|
|
@@ -55,55 +62,6 @@ from .odps import MaxFrameServiceCaller, MaxFrameSession
|
|
|
55
62
|
logger = logging.getLogger(__name__)
|
|
56
63
|
|
|
57
64
|
|
|
58
|
-
class MaxFrameTask(Task):
|
|
59
|
-
__slots__ = ("_output_format", "_major_version", "_service_endpoint")
|
|
60
|
-
_root = "MaxFrame"
|
|
61
|
-
_anonymous_task_name = "AnonymousMaxFrameTask"
|
|
62
|
-
|
|
63
|
-
command = serializers.XMLNodeField("Command", default="CREATE_SESSION")
|
|
64
|
-
|
|
65
|
-
def __init__(self, **kwargs):
|
|
66
|
-
kwargs["name"] = kwargs.get("name") or self._anonymous_task_name
|
|
67
|
-
self._output_format = kwargs.pop(
|
|
68
|
-
"output_format", MAXFRAME_OUTPUT_MSGPACK_FORMAT
|
|
69
|
-
)
|
|
70
|
-
self._major_version = kwargs.pop("major_version", None)
|
|
71
|
-
self._service_endpoint = kwargs.pop("service_endpoint", None)
|
|
72
|
-
super().__init__(**kwargs)
|
|
73
|
-
|
|
74
|
-
def serial(self):
|
|
75
|
-
if self.properties is None:
|
|
76
|
-
self.properties = dict()
|
|
77
|
-
|
|
78
|
-
if odps_options.default_task_settings:
|
|
79
|
-
settings = odps_options.default_task_settings
|
|
80
|
-
else:
|
|
81
|
-
settings = dict()
|
|
82
|
-
|
|
83
|
-
if self._major_version is not None:
|
|
84
|
-
settings["odps.task.major.version"] = self._major_version
|
|
85
|
-
|
|
86
|
-
if "settings" in self.properties:
|
|
87
|
-
settings.update(json.loads(self.properties["settings"]))
|
|
88
|
-
|
|
89
|
-
# merge sql options
|
|
90
|
-
sql_settings = (odps_options.sql.settings or {}).copy()
|
|
91
|
-
sql_settings.update(options.sql.settings or {})
|
|
92
|
-
|
|
93
|
-
mf_settings = dict(options.to_dict(remote_only=True).items())
|
|
94
|
-
mf_settings["sql.settings"] = sql_settings
|
|
95
|
-
mf_opts = {
|
|
96
|
-
"odps.maxframe.settings": json.dumps(mf_settings),
|
|
97
|
-
"odps.maxframe.output_format": self._output_format,
|
|
98
|
-
"odps.service.endpoint": self._service_endpoint,
|
|
99
|
-
}
|
|
100
|
-
if mf_version:
|
|
101
|
-
mf_opts["odps.maxframe.client_version"] = mf_version
|
|
102
|
-
settings.update(mf_opts)
|
|
103
|
-
self.properties["settings"] = json.dumps(settings)
|
|
104
|
-
return super().serial()
|
|
105
|
-
|
|
106
|
-
|
|
107
65
|
class MaxFrameInstanceCaller(MaxFrameServiceCaller):
|
|
108
66
|
_instance: Optional[Instance]
|
|
109
67
|
|
|
@@ -132,6 +90,7 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
|
|
|
132
90
|
self._running_cluster = running_cluster
|
|
133
91
|
self._major_version = major_version
|
|
134
92
|
self._output_format = output_format or MAXFRAME_OUTPUT_MSGPACK_FORMAT
|
|
93
|
+
self._deleted = False
|
|
135
94
|
|
|
136
95
|
if nested_instance_id is None:
|
|
137
96
|
self._nested = False
|
|
@@ -140,14 +99,26 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
|
|
|
140
99
|
self._nested = True
|
|
141
100
|
self._instance = odps_entry.get_instance(nested_instance_id)
|
|
142
101
|
|
|
102
|
+
@property
|
|
103
|
+
def instance(self):
|
|
104
|
+
return self._instance
|
|
105
|
+
|
|
143
106
|
def _deserial_task_info_result(
|
|
144
107
|
self, content: Union[bytes, str, dict], target_cls: Type[JsonSerializable]
|
|
145
108
|
):
|
|
146
109
|
if isinstance(content, (str, bytes)):
|
|
110
|
+
if len(content) == 0:
|
|
111
|
+
content = "{}"
|
|
147
112
|
json_data = json.loads(to_str(content))
|
|
148
113
|
else:
|
|
149
114
|
json_data = content
|
|
150
|
-
|
|
115
|
+
encoded_result = json_data.get("result")
|
|
116
|
+
if not encoded_result:
|
|
117
|
+
if self._deleted:
|
|
118
|
+
return None
|
|
119
|
+
else:
|
|
120
|
+
raise SessionAlreadyClosedError(self._instance.id)
|
|
121
|
+
result_data = base64.b64decode(encoded_result)
|
|
151
122
|
if self._output_format == MAXFRAME_OUTPUT_MAXFRAME_FORMAT:
|
|
152
123
|
return deserialize_serializable(result_data)
|
|
153
124
|
elif self._output_format == MAXFRAME_OUTPUT_JSON_FORMAT:
|
|
@@ -159,13 +130,19 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
|
|
|
159
130
|
f"Serialization format {self._output_format} not supported"
|
|
160
131
|
)
|
|
161
132
|
|
|
133
|
+
def _create_maxframe_task(self) -> MaxFrameTask:
|
|
134
|
+
task = MaxFrameTask(name=self._task_name, major_version=self._major_version)
|
|
135
|
+
mf_opts = {
|
|
136
|
+
"odps.maxframe.settings": json.dumps(self.get_settings_to_upload()),
|
|
137
|
+
"odps.maxframe.output_format": self._output_format,
|
|
138
|
+
}
|
|
139
|
+
if mf_version:
|
|
140
|
+
mf_opts["odps.maxframe.client_version"] = mf_version
|
|
141
|
+
task.update_settings(mf_opts)
|
|
142
|
+
return task
|
|
143
|
+
|
|
162
144
|
def create_session(self) -> SessionInfo:
|
|
163
|
-
task =
|
|
164
|
-
name=self._task_name,
|
|
165
|
-
major_version=self._major_version,
|
|
166
|
-
output_format=self._output_format,
|
|
167
|
-
service_endpoint=self._odps_entry.endpoint,
|
|
168
|
-
)
|
|
145
|
+
task = self._create_maxframe_task()
|
|
169
146
|
if not self._nested:
|
|
170
147
|
self._task_name = task.name
|
|
171
148
|
project = self._odps_entry.get_project(self._project)
|
|
@@ -210,11 +187,40 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
|
|
|
210
187
|
time.sleep(interval)
|
|
211
188
|
interval = min(max_interval, interval * 2)
|
|
212
189
|
|
|
190
|
+
def _put_task_info(self, method_name: str, json_data: dict):
|
|
191
|
+
for trial in range(EMPTY_RESPONSE_RETRY_COUNT):
|
|
192
|
+
try:
|
|
193
|
+
return self._instance.put_task_info(
|
|
194
|
+
self._task_name,
|
|
195
|
+
method_name,
|
|
196
|
+
json.dumps(json_data),
|
|
197
|
+
raise_empty=True,
|
|
198
|
+
)
|
|
199
|
+
except TypeError: # pragma: no cover
|
|
200
|
+
# todo remove when pyodps>=0.12.0 is enforced
|
|
201
|
+
resp_data = self._instance.put_task_info(
|
|
202
|
+
self._task_name, method_name, json.dumps(json_data)
|
|
203
|
+
)
|
|
204
|
+
if resp_data:
|
|
205
|
+
return resp_data
|
|
206
|
+
else:
|
|
207
|
+
raise NoTaskServerResponseError(
|
|
208
|
+
f"No response for request {method_name}. "
|
|
209
|
+
f"Instance ID: {self._instance.id}"
|
|
210
|
+
)
|
|
211
|
+
except EmptyTaskInfoError as ex:
|
|
212
|
+
# retry when server returns HTTP 204, which is designed for retry
|
|
213
|
+
if ex.code != 204 or trial >= EMPTY_RESPONSE_RETRY_COUNT - 1:
|
|
214
|
+
raise NoTaskServerResponseError(
|
|
215
|
+
f"No response for request {method_name}. "
|
|
216
|
+
f"Instance ID: {self._instance.id}. "
|
|
217
|
+
f"Request ID: {ex.request_id}"
|
|
218
|
+
) from None
|
|
219
|
+
time.sleep(0.5)
|
|
220
|
+
|
|
213
221
|
def get_session(self) -> SessionInfo:
|
|
214
222
|
req_data = {"output_format": self._output_format}
|
|
215
|
-
serialized = self.
|
|
216
|
-
self._task_name, MAXFRAME_TASK_GET_SESSION_METHOD, json.dumps(req_data)
|
|
217
|
-
)
|
|
223
|
+
serialized = self._put_task_info(MAXFRAME_TASK_GET_SESSION_METHOD, req_data)
|
|
218
224
|
info: SessionInfo = self._deserial_task_info_result(serialized, SessionInfo)
|
|
219
225
|
info.session_id = self._instance.id
|
|
220
226
|
return info
|
|
@@ -224,28 +230,28 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
|
|
|
224
230
|
self._instance.stop()
|
|
225
231
|
else:
|
|
226
232
|
req_data = {"output_format": self._output_format}
|
|
227
|
-
self.
|
|
228
|
-
|
|
229
|
-
MAXFRAME_TASK_DELETE_SESSION_METHOD,
|
|
230
|
-
json.dumps(req_data),
|
|
231
|
-
)
|
|
233
|
+
self._put_task_info(MAXFRAME_TASK_DELETE_SESSION_METHOD, req_data)
|
|
234
|
+
self._deleted = True
|
|
232
235
|
|
|
233
236
|
def submit_dag(
|
|
234
237
|
self,
|
|
235
238
|
dag: TileableGraph,
|
|
236
239
|
managed_input_infos: Optional[Dict[str, ResultInfo]] = None,
|
|
240
|
+
new_settings: Dict[str, Any] = None,
|
|
237
241
|
) -> DagInfo:
|
|
242
|
+
new_settings_value = {
|
|
243
|
+
"odps.maxframe.settings": json.dumps(new_settings),
|
|
244
|
+
}
|
|
238
245
|
req_data = {
|
|
239
246
|
"protocol": MAXFRAME_DEFAULT_PROTOCOL,
|
|
240
247
|
"dag": base64.b64encode(serialize_serializable(dag)).decode(),
|
|
241
248
|
"managed_input_infos": base64.b64encode(
|
|
242
249
|
serialize_serializable(managed_input_infos)
|
|
243
250
|
).decode(),
|
|
251
|
+
"new_settings": json.dumps(new_settings_value),
|
|
244
252
|
"output_format": self._output_format,
|
|
245
253
|
}
|
|
246
|
-
res = self.
|
|
247
|
-
self._task_name, MAXFRAME_TASK_SUBMIT_DAG_METHOD, json.dumps(req_data)
|
|
248
|
-
)
|
|
254
|
+
res = self._put_task_info(MAXFRAME_TASK_SUBMIT_DAG_METHOD, req_data)
|
|
249
255
|
return self._deserial_task_info_result(res, DagInfo)
|
|
250
256
|
|
|
251
257
|
def get_dag_info(self, dag_id: str) -> DagInfo:
|
|
@@ -254,9 +260,7 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
|
|
|
254
260
|
"dag_id": dag_id,
|
|
255
261
|
"output_format": self._output_format,
|
|
256
262
|
}
|
|
257
|
-
res = self.
|
|
258
|
-
self._task_name, MAXFRAME_TASK_GET_DAG_INFO_METHOD, json.dumps(req_data)
|
|
259
|
-
)
|
|
263
|
+
res = self._put_task_info(MAXFRAME_TASK_GET_DAG_INFO_METHOD, req_data)
|
|
260
264
|
return self._deserial_task_info_result(res, DagInfo)
|
|
261
265
|
|
|
262
266
|
def cancel_dag(self, dag_id: str) -> DagInfo:
|
|
@@ -265,24 +269,39 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
|
|
|
265
269
|
"dag_id": dag_id,
|
|
266
270
|
"output_format": self._output_format,
|
|
267
271
|
}
|
|
268
|
-
res = self.
|
|
269
|
-
self._task_name, MAXFRAME_TASK_CANCEL_DAG_METHOD, json.dumps(req_data)
|
|
270
|
-
)
|
|
272
|
+
res = self._put_task_info(MAXFRAME_TASK_CANCEL_DAG_METHOD, req_data)
|
|
271
273
|
return self._deserial_task_info_result(res, DagInfo)
|
|
272
274
|
|
|
273
275
|
def decref(self, tileable_keys: List[str]) -> None:
|
|
274
276
|
req_data = {
|
|
275
277
|
"tileable_keys": ",".join(tileable_keys),
|
|
276
278
|
}
|
|
277
|
-
self.
|
|
278
|
-
|
|
279
|
-
|
|
279
|
+
self._put_task_info(MAXFRAME_TASK_DECREF_METHOD, req_data)
|
|
280
|
+
|
|
281
|
+
def get_logview_address(self, dag_id=None, hours=None) -> Optional[str]:
|
|
282
|
+
"""
|
|
283
|
+
Generate logview address
|
|
284
|
+
|
|
285
|
+
Parameters
|
|
286
|
+
----------
|
|
287
|
+
dag_id: id of dag for which dag logview detail page to access
|
|
288
|
+
hours: hours of the logview address auth limit
|
|
289
|
+
Returns
|
|
290
|
+
-------
|
|
291
|
+
Logview address
|
|
292
|
+
"""
|
|
293
|
+
hours = hours or options.session.logview_hours
|
|
294
|
+
# notice: maxframe can't reuse subQuery else will conflict with mcqa when fetch resource data,
|
|
295
|
+
# added dagId for maxframe so logview backend will return maxframe data format if
|
|
296
|
+
# instance and dagId is provided.
|
|
297
|
+
dag_suffix = f"&dagId={dag_id}" if dag_id else ""
|
|
298
|
+
return self._instance.get_logview_address(hours) + dag_suffix
|
|
280
299
|
|
|
281
300
|
|
|
282
301
|
class MaxFrameTaskSession(MaxFrameSession):
|
|
283
302
|
schemes = [ODPS_SESSION_INSECURE_SCHEME, ODPS_SESSION_SECURE_SCHEME]
|
|
284
303
|
|
|
285
|
-
|
|
304
|
+
_caller: MaxFrameInstanceCaller
|
|
286
305
|
|
|
287
306
|
@classmethod
|
|
288
307
|
def _create_caller(
|
|
@@ -302,6 +321,15 @@ class MaxFrameTaskSession(MaxFrameSession):
|
|
|
302
321
|
**kwargs,
|
|
303
322
|
)
|
|
304
323
|
|
|
324
|
+
@property
|
|
325
|
+
def closed(self) -> bool:
|
|
326
|
+
if super().closed:
|
|
327
|
+
return True
|
|
328
|
+
if not self._caller or not self._caller.instance:
|
|
329
|
+
# session not initialized yet
|
|
330
|
+
return False
|
|
331
|
+
return self._caller.instance.is_terminated()
|
|
332
|
+
|
|
305
333
|
|
|
306
334
|
def register_session_schemes(overwrite: bool = False):
|
|
307
335
|
MaxFrameTaskSession.register_schemes(overwrite=overwrite)
|
|
@@ -17,19 +17,32 @@ import uuid
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import pandas as pd
|
|
19
19
|
import pyarrow as pa
|
|
20
|
+
import pytest
|
|
20
21
|
from odps import ODPS
|
|
21
22
|
|
|
22
23
|
import maxframe.dataframe as md
|
|
23
|
-
from maxframe.
|
|
24
|
+
from maxframe.config import options
|
|
25
|
+
from maxframe.io.odpsio import ODPSTableIO
|
|
24
26
|
from maxframe.protocol import ODPSTableResultInfo, ResultType
|
|
25
27
|
from maxframe.tests.utils import tn
|
|
26
28
|
|
|
27
29
|
from ..fetcher import ODPSTableFetcher
|
|
28
30
|
|
|
29
31
|
|
|
30
|
-
|
|
32
|
+
@pytest.fixture
|
|
33
|
+
def switch_table_io(request):
|
|
34
|
+
old_use_common_table = options.use_common_table
|
|
35
|
+
try:
|
|
36
|
+
options.use_common_table = request.param
|
|
37
|
+
yield
|
|
38
|
+
finally:
|
|
39
|
+
options.use_common_table = old_use_common_table
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.mark.parametrize("switch_table_io", [False, True], indirect=True)
|
|
43
|
+
async def test_table_fetcher(switch_table_io):
|
|
31
44
|
odps_entry = ODPS.from_environments()
|
|
32
|
-
halo_table_io =
|
|
45
|
+
halo_table_io = ODPSTableIO(odps_entry)
|
|
33
46
|
fetcher = ODPSTableFetcher(odps_entry)
|
|
34
47
|
|
|
35
48
|
data = pd.DataFrame(
|
|
@@ -58,6 +71,11 @@ async def test_table_fetcher():
|
|
|
58
71
|
assert len(fetched) == 1000
|
|
59
72
|
pd.testing.assert_frame_equal(raw_data, fetched)
|
|
60
73
|
|
|
74
|
+
result_info = ODPSTableResultInfo(ResultType.ODPS_TABLE, full_table_name=table_name)
|
|
75
|
+
fetched = await fetcher.fetch(tileable, result_info, [slice(None, 2000), None])
|
|
76
|
+
assert len(fetched) == 1000
|
|
77
|
+
pd.testing.assert_frame_equal(raw_data, fetched)
|
|
78
|
+
|
|
61
79
|
result_info = ODPSTableResultInfo(ResultType.ODPS_TABLE, full_table_name=table_name)
|
|
62
80
|
fetched = await fetcher.fetch(tileable, result_info, [2, None])
|
|
63
81
|
assert len(fetched) == 1
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import time
|
|
16
|
-
from typing import Dict
|
|
16
|
+
from typing import Any, Dict
|
|
17
17
|
|
|
18
18
|
import mock
|
|
19
19
|
import numpy as np
|
|
@@ -23,7 +23,10 @@ from odps import ODPS
|
|
|
23
23
|
|
|
24
24
|
import maxframe.dataframe as md
|
|
25
25
|
import maxframe.remote as mr
|
|
26
|
+
from maxframe.config import options
|
|
27
|
+
from maxframe.config.config import option_context
|
|
26
28
|
from maxframe.core import ExecutableTuple, TileableGraph
|
|
29
|
+
from maxframe.errors import NoTaskServerResponseError
|
|
27
30
|
from maxframe.lib.aio import stop_isolation
|
|
28
31
|
from maxframe.protocol import ResultInfo
|
|
29
32
|
from maxframe.serialization import RemoteException
|
|
@@ -35,6 +38,7 @@ from maxframe_framedriver.app.tests.test_framedriver_webapp import ( # noqa: F4
|
|
|
35
38
|
)
|
|
36
39
|
|
|
37
40
|
from ..clients.framedriver import FrameDriverClient
|
|
41
|
+
from ..session.odps import MaxFrameRestCaller
|
|
38
42
|
|
|
39
43
|
pytestmark = pytest.mark.maxframe_engine(["MCSQL", "SPE"])
|
|
40
44
|
|
|
@@ -82,15 +86,32 @@ def test_simple_run_dataframe(start_mock_session):
|
|
|
82
86
|
session_id: str,
|
|
83
87
|
dag: TileableGraph,
|
|
84
88
|
managed_input_infos: Dict[str, ResultInfo] = None,
|
|
89
|
+
new_settings: Dict[str, Any] = None,
|
|
85
90
|
):
|
|
86
91
|
assert len(dag) == 2
|
|
87
|
-
return await original_submit_dag(
|
|
92
|
+
return await original_submit_dag(
|
|
93
|
+
self, session_id, dag, managed_input_infos, new_settings
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
no_task_server_raised = False
|
|
97
|
+
original_get_dag_info = MaxFrameRestCaller.get_dag_info
|
|
98
|
+
|
|
99
|
+
async def patched_get_dag_info(self, dag_id: str):
|
|
100
|
+
nonlocal no_task_server_raised
|
|
101
|
+
|
|
102
|
+
if not no_task_server_raised:
|
|
103
|
+
no_task_server_raised = True
|
|
104
|
+
raise NoTaskServerResponseError
|
|
105
|
+
return await original_get_dag_info(self, dag_id)
|
|
88
106
|
|
|
89
107
|
df["H"] = "extra_content"
|
|
90
108
|
|
|
91
109
|
with mock.patch(
|
|
92
110
|
"maxframe_client.clients.framedriver.FrameDriverClient.submit_dag",
|
|
93
111
|
new=patched_submit_dag,
|
|
112
|
+
), mock.patch(
|
|
113
|
+
"maxframe_client.session.odps.MaxFrameRestCaller.get_dag_info",
|
|
114
|
+
new=patched_get_dag_info,
|
|
94
115
|
):
|
|
95
116
|
result = df.execute().fetch()
|
|
96
117
|
assert len(result) == 1000
|
|
@@ -99,9 +120,12 @@ def test_simple_run_dataframe(start_mock_session):
|
|
|
99
120
|
corner_top, corner_bottom = ExecutableTuple([df.iloc[:10], df.iloc[-10:]]).fetch()
|
|
100
121
|
assert len(corner_top) == len(corner_bottom) == 10
|
|
101
122
|
|
|
102
|
-
# check ellipsis mark in DataFrame
|
|
123
|
+
# check ellipsis mark in DataFrame reprs
|
|
103
124
|
df_str_repr = str(df)
|
|
104
125
|
assert ".." in df_str_repr
|
|
126
|
+
# check ellipsis mark in Series reprs
|
|
127
|
+
series_str_repr = str(df.A.execute())
|
|
128
|
+
assert ".." in series_str_repr
|
|
105
129
|
|
|
106
130
|
key = df.key
|
|
107
131
|
assert odps_entry.exist_table(
|
|
@@ -109,13 +133,30 @@ def test_simple_run_dataframe(start_mock_session):
|
|
|
109
133
|
)
|
|
110
134
|
assert odps_entry.exist_table(build_temp_table_name(start_mock_session, key))
|
|
111
135
|
del df
|
|
112
|
-
|
|
136
|
+
retry_times = 10
|
|
137
|
+
while (
|
|
138
|
+
odps_entry.exist_table(
|
|
139
|
+
build_temp_table_name(start_mock_session, intermediate_key)
|
|
140
|
+
)
|
|
141
|
+
and retry_times > 0
|
|
142
|
+
):
|
|
143
|
+
time.sleep(1)
|
|
144
|
+
retry_times -= 1
|
|
113
145
|
assert not odps_entry.exist_table(
|
|
114
146
|
build_temp_table_name(start_mock_session, intermediate_key)
|
|
115
147
|
)
|
|
116
148
|
assert not odps_entry.exist_table(build_temp_table_name(start_mock_session, key))
|
|
117
149
|
|
|
118
150
|
|
|
151
|
+
def test_run_and_fetch_slice(start_mock_session):
|
|
152
|
+
pd_df = pd.DataFrame(np.random.rand(1000, 5), columns=list("ABCDE"))
|
|
153
|
+
df = md.DataFrame(pd_df)
|
|
154
|
+
result = df.execute()
|
|
155
|
+
|
|
156
|
+
sliced = result.head(10).fetch()
|
|
157
|
+
assert len(sliced) == 10
|
|
158
|
+
|
|
159
|
+
|
|
119
160
|
def test_run_empty_table(start_mock_session):
|
|
120
161
|
odps_entry = ODPS.from_environments()
|
|
121
162
|
|
|
@@ -136,6 +177,25 @@ def test_run_empty_table(start_mock_session):
|
|
|
136
177
|
empty_table.drop()
|
|
137
178
|
|
|
138
179
|
|
|
180
|
+
def test_run_odps_query_without_schema(start_mock_session):
|
|
181
|
+
odps_entry = ODPS.from_environments()
|
|
182
|
+
|
|
183
|
+
table_name = tn("test_session_empty_table")
|
|
184
|
+
odps_entry.delete_table(table_name, if_exists=True)
|
|
185
|
+
test_table = odps_entry.create_table(table_name, "a double, b double", lifecycle=1)
|
|
186
|
+
|
|
187
|
+
with test_table.open_writer() as writer:
|
|
188
|
+
writer.write([123, 456])
|
|
189
|
+
|
|
190
|
+
df = md.read_odps_query(
|
|
191
|
+
f"select a, b, a + b as `special: name` from {table_name}", skip_schema=True
|
|
192
|
+
)
|
|
193
|
+
executed = df.execute().fetch()
|
|
194
|
+
assert len(executed.dtypes) == 3
|
|
195
|
+
|
|
196
|
+
test_table.drop()
|
|
197
|
+
|
|
198
|
+
|
|
139
199
|
def test_run_dataframe_with_pd_source(start_mock_session):
|
|
140
200
|
odps_entry = ODPS.from_environments()
|
|
141
201
|
|
|
@@ -168,19 +228,38 @@ def test_run_dataframe_from_to_odps_table(start_mock_session):
|
|
|
168
228
|
table_name = build_temp_table_name(start_mock_session, "tmp_save")
|
|
169
229
|
table_obj = odps_entry.get_table(table_name)
|
|
170
230
|
try:
|
|
171
|
-
md.to_odps_table(md.DataFrame(pd_df), table_obj).execute().fetch()
|
|
231
|
+
md.to_odps_table(md.DataFrame(pd_df), table_obj, lifecycle=1).execute().fetch()
|
|
172
232
|
with table_obj.open_reader() as reader:
|
|
173
233
|
result_df = reader.to_pandas()
|
|
174
234
|
assert len(result_df) == 10
|
|
175
235
|
assert len(result_df.columns) == 6
|
|
176
236
|
|
|
177
|
-
df = md.read_odps_table(table_obj, index_col="index").head(10).execute()
|
|
237
|
+
df = md.read_odps_table(table_obj, index_col="index").head(10).execute()
|
|
238
|
+
assert df.shape == (10, 5)
|
|
178
239
|
assert len(df) == 10
|
|
179
240
|
assert len(df.columns) == 5
|
|
180
241
|
finally:
|
|
181
242
|
odps_entry.delete_table(table_name, if_exists=True)
|
|
182
243
|
|
|
183
244
|
|
|
245
|
+
def test_create_session_with_options(framedriver_app): # noqa: F811
|
|
246
|
+
odps_entry = ODPS.from_environments()
|
|
247
|
+
framedriver_addr = f"mf://localhost:{framedriver_app.port}"
|
|
248
|
+
old_value = options.session.max_alive_seconds
|
|
249
|
+
session = None
|
|
250
|
+
try:
|
|
251
|
+
options.session.max_alive_seconds = 10
|
|
252
|
+
session = new_session(framedriver_addr, odps_entry=odps_entry)
|
|
253
|
+
session_id = session.session_id
|
|
254
|
+
session_conf = framedriver_app.session_manager.get_session_settings(session_id)
|
|
255
|
+
with option_context(session_conf) as session_options:
|
|
256
|
+
assert session_options.session.max_alive_seconds == 10
|
|
257
|
+
finally:
|
|
258
|
+
options.session.max_alive_seconds = old_value
|
|
259
|
+
if session is not None:
|
|
260
|
+
session.destroy()
|
|
261
|
+
|
|
262
|
+
|
|
184
263
|
def test_run_and_fetch_series(start_mock_session):
|
|
185
264
|
odps_entry = ODPS.from_environments()
|
|
186
265
|
|
|
@@ -207,7 +286,22 @@ def test_run_and_fetch_series(start_mock_session):
|
|
|
207
286
|
)
|
|
208
287
|
|
|
209
288
|
|
|
210
|
-
def
|
|
289
|
+
def test_execute_with_tensor(oss_config, start_mock_session):
|
|
290
|
+
pd_df = pd.DataFrame(
|
|
291
|
+
{"angles": [0, 3, 4], "degrees": [360, 180, 360]},
|
|
292
|
+
index=["circle", "triangle", "rectangle"],
|
|
293
|
+
)
|
|
294
|
+
df = md.DataFrame(pd_df)
|
|
295
|
+
|
|
296
|
+
result = (df - [1, 2]).execute().fetch()
|
|
297
|
+
expected = pd_df - [1, 2]
|
|
298
|
+
# TODO: currently the record order in tensor reading from table is the index
|
|
299
|
+
# sorting order
|
|
300
|
+
expected.sort_index(axis=0, inplace=True)
|
|
301
|
+
pd.testing.assert_frame_equal(result, expected, check_like=True)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def test_run_remote_success(oss_config, start_mock_session):
|
|
211
305
|
def func(a, b):
|
|
212
306
|
return a + b
|
|
213
307
|
|
|
@@ -218,7 +312,7 @@ def test_run_remote_success(start_mock_session):
|
|
|
218
312
|
assert result == 21
|
|
219
313
|
|
|
220
314
|
|
|
221
|
-
def test_run_remote_error(start_mock_session):
|
|
315
|
+
def test_run_remote_error(oss_config, start_mock_session):
|
|
222
316
|
def func():
|
|
223
317
|
raise ValueError
|
|
224
318
|
|
|
@@ -226,3 +320,37 @@ def test_run_remote_error(start_mock_session):
|
|
|
226
320
|
|
|
227
321
|
with pytest.raises((ValueError, RemoteException)):
|
|
228
322
|
v.execute()
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def test_pivot_dataframe(start_mock_session):
|
|
326
|
+
pd_df = pd.DataFrame(
|
|
327
|
+
{
|
|
328
|
+
"A": "foo foo foo foo foo bar bar bar bar".split(),
|
|
329
|
+
"B": "one one one two two one one two two".split(),
|
|
330
|
+
"C": "small large large small small large small small large".split(),
|
|
331
|
+
"D": [1, 2, 2, 3, 3, 4, 5, 6, 7],
|
|
332
|
+
"E": [2, 4, 5, 5, 6, 6, 8, 9, 9],
|
|
333
|
+
}
|
|
334
|
+
)
|
|
335
|
+
df = md.DataFrame(pd_df)
|
|
336
|
+
pivot = df.pivot_table(values="D", index=["A", "B"], columns=["C"], aggfunc="sum")
|
|
337
|
+
executed = pivot.execute()
|
|
338
|
+
assert pivot.shape == (4, 2)
|
|
339
|
+
pd.testing.assert_index_equal(
|
|
340
|
+
pivot.dtypes.index, pd.Index(["large", "small"], name="C")
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
expected = pd_df.pivot_table(
|
|
344
|
+
values="D", index=["A", "B"], columns=["C"], aggfunc="sum"
|
|
345
|
+
)
|
|
346
|
+
pd.testing.assert_frame_equal(executed.to_pandas(), expected)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def test_index_drop_duplicates(start_mock_session):
|
|
350
|
+
pd_idx = pd.Index(["lame", "cow", "lame", "beetle", "lame", "hippo"])
|
|
351
|
+
idx = md.Index(pd_idx)
|
|
352
|
+
executed = idx.drop_duplicates(keep="first").execute()
|
|
353
|
+
expected = pd_idx.drop_duplicates(keep="first")
|
|
354
|
+
pd.testing.assert_index_equal(
|
|
355
|
+
executed.to_pandas().sort_values(), expected.sort_values()
|
|
356
|
+
)
|
maxframe/core/entity/chunks.py
DELETED
|
@@ -1,68 +0,0 @@
|
|
|
1
|
-
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from ...serialization.serializables import BoolField, FieldTypes, TupleField
|
|
16
|
-
from ...utils import tokenize
|
|
17
|
-
from .core import Entity, EntityData
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class ChunkData(EntityData):
|
|
21
|
-
__slots__ = ()
|
|
22
|
-
|
|
23
|
-
is_broadcaster = BoolField("is_broadcaster", default=False)
|
|
24
|
-
# If the operator is a shuffle mapper, this flag indicates whether the current chunk is mapper chunk when
|
|
25
|
-
# the operator produce multiple chunks such as TensorUnique.
|
|
26
|
-
is_mapper = BoolField("is_mapper", default=None)
|
|
27
|
-
# optional fields
|
|
28
|
-
_index = TupleField("index", FieldTypes.uint32)
|
|
29
|
-
|
|
30
|
-
def __repr__(self):
|
|
31
|
-
if self.op.stage is None:
|
|
32
|
-
return (
|
|
33
|
-
f"{type(self).__name__} <op={type(self.op).__name__}, "
|
|
34
|
-
f"key={self.key}>"
|
|
35
|
-
)
|
|
36
|
-
else:
|
|
37
|
-
return (
|
|
38
|
-
f"{type(self).__name__} <op={type(self.op).__name__}, "
|
|
39
|
-
f"stage={self.op.stage.name}, key={self.key}>"
|
|
40
|
-
)
|
|
41
|
-
|
|
42
|
-
@property
|
|
43
|
-
def index(self):
|
|
44
|
-
return getattr(self, "_index", None)
|
|
45
|
-
|
|
46
|
-
@property
|
|
47
|
-
def device(self):
|
|
48
|
-
return self.op.device
|
|
49
|
-
|
|
50
|
-
def _update_key(self):
|
|
51
|
-
object.__setattr__(
|
|
52
|
-
self,
|
|
53
|
-
"_key",
|
|
54
|
-
tokenize(
|
|
55
|
-
type(self).__name__,
|
|
56
|
-
*(getattr(self, k, None) for k in self._keys_ if k != "_index"),
|
|
57
|
-
),
|
|
58
|
-
)
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
class Chunk(Entity):
|
|
62
|
-
_allow_data_type_ = (ChunkData,)
|
|
63
|
-
|
|
64
|
-
def __repr__(self):
|
|
65
|
-
return f"{type(self).__name__}({self._data.__repr__()})"
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
CHUNK_TYPE = (Chunk, ChunkData)
|