maxframe 0.1.0b1-cp38-cp38-win_amd64.whl → 0.1.0b3-cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of maxframe might be problematic.
- maxframe/_utils.cp38-win_amd64.pyd +0 -0
- maxframe/codegen.py +88 -19
- maxframe/config/config.py +10 -0
- maxframe/core/entity/executable.py +1 -0
- maxframe/core/entity/objects.py +3 -2
- maxframe/core/graph/core.cp38-win_amd64.pyd +0 -0
- maxframe/core/graph/core.pyx +2 -2
- maxframe/core/operator/base.py +14 -0
- maxframe/dataframe/__init__.py +3 -1
- maxframe/dataframe/datasource/from_records.py +4 -0
- maxframe/dataframe/datasource/read_odps_query.py +295 -0
- maxframe/dataframe/datasource/read_odps_table.py +1 -1
- maxframe/dataframe/datasource/tests/test_datasource.py +84 -1
- maxframe/dataframe/groupby/__init__.py +4 -0
- maxframe/dataframe/groupby/core.py +5 -0
- maxframe/dataframe/misc/to_numeric.py +4 -0
- maxframe/dataframe/window/aggregation.py +1 -24
- maxframe/dataframe/window/ewm.py +0 -7
- maxframe/dataframe/window/tests/test_ewm.py +0 -6
- maxframe/errors.py +21 -0
- maxframe/lib/aio/isolation.py +6 -1
- maxframe/lib/mmh3.cp38-win_amd64.pyd +0 -0
- maxframe/opcodes.py +1 -0
- maxframe/protocol.py +25 -5
- maxframe/serialization/core.cp38-win_amd64.pyd +0 -0
- maxframe/serialization/exception.py +2 -1
- maxframe/serialization/serializables/core.py +6 -1
- maxframe/serialization/serializables/field.py +2 -0
- maxframe/tensor/core.py +3 -3
- maxframe/tests/test_codegen.py +69 -0
- maxframe/tests/test_protocol.py +16 -8
- maxframe/tests/utils.py +1 -0
- maxframe/udf.py +15 -16
- maxframe/utils.py +21 -1
- {maxframe-0.1.0b1.dist-info → maxframe-0.1.0b3.dist-info}/METADATA +1 -74
- {maxframe-0.1.0b1.dist-info → maxframe-0.1.0b3.dist-info}/RECORD +42 -39
- {maxframe-0.1.0b1.dist-info → maxframe-0.1.0b3.dist-info}/WHEEL +1 -1
- maxframe_client/clients/framedriver.py +7 -7
- maxframe_client/session/task.py +31 -3
- maxframe_client/session/tests/test_task.py +29 -11
- maxframe_client/tests/test_session.py +2 -0
- {maxframe-0.1.0b1.dist-info → maxframe-0.1.0b3.dist-info}/top_level.txt +0 -0
maxframe/dataframe/datasource/tests/test_datasource.py CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from collections import OrderedDict
 
 import numpy as np
@@ -22,7 +23,7 @@ from odps import ODPS
 from .... import tensor as mt
 from ....tests.utils import tn
 from ....utils import lazy_import
-from ... import read_odps_table
+from ... import read_odps_query, read_odps_table
 from ...core import DatetimeIndex, Float64Index, IndexValue, Int64Index, MultiIndex
 from ..dataframe import from_pandas as from_pandas_df
 from ..date_range import date_range
@@ -33,6 +34,7 @@ from ..from_tensor import (
 )
 from ..index import from_pandas as from_pandas_index
 from ..index import from_tileable
+from ..read_odps_query import ColumnSchema, _resolve_task_sector
 from ..series import from_pandas as from_pandas_series
 
 ray = lazy_import("ray")
@@ -228,6 +230,7 @@ def test_from_odps_table():
     assert df.op.table_name == test_table.full_table_name
     assert df.index_value.name is None
     assert isinstance(df.index_value.value, IndexValue.RangeIndex)
+    assert df.op.get_columns() == ["col1", "col2", "col3"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series(
@@ -247,6 +250,7 @@ def test_from_odps_table():
     assert df.op.table_name == test_table.full_table_name
     assert df.index_value.name is None
     assert isinstance(df.index_value.value, IndexValue.RangeIndex)
+    assert df.op.get_columns() == ["col1", "col2"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series([np.dtype("O"), np.dtype("int64")], index=["col1", "col2"]),
@@ -257,6 +261,7 @@ def test_from_odps_table():
     assert df.index_value.name == "col1"
     assert isinstance(df.index_value.value, IndexValue.Index)
     assert df.index.dtype == np.dtype("O")
+    assert df.op.get_columns() == ["col2", "col3"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series([np.dtype("int64"), np.dtype("float64")], index=["col2", "col3"]),
@@ -267,6 +272,7 @@ def test_from_odps_table():
 
     df = read_odps_table(test_parted_table, append_partitions=True)
     assert df.op.append_partitions is True
+    assert df.op.get_columns() == ["col1", "col2", "col3", "pt"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series(
@@ -280,6 +286,7 @@ def test_from_odps_table():
     )
     assert df.op.append_partitions is True
     assert df.op.partitions == ["pt=20240103"]
+    assert df.op.get_columns() == ["col1", "col2", "pt"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series(
@@ -292,6 +299,67 @@ def test_from_odps_table():
     test_parted_table.drop()
 
 
+def test_from_odps_query():
+    odps_entry = ODPS.from_environments()
+    table1_name = tn("test_from_odps_query_src1")
+    table2_name = tn("test_from_odps_query_src2")
+    odps_entry.delete_table(table1_name, if_exists=True)
+    odps_entry.delete_table(table2_name, if_exists=True)
+    test_table = odps_entry.create_table(
+        table1_name, "col1 string, col2 bigint, col3 double", lifecycle=1
+    )
+    # need some data to produce complicated plans
+    odps_entry.write_table(test_table, [["A", 10, 3.5]])
+    test_table2 = odps_entry.create_table(
+        table2_name, "col1 string, col2 bigint, col3 double", lifecycle=1
+    )
+    odps_entry.write_table(test_table2, [["A", 10, 4.5]])
+
+    with pytest.raises(ValueError) as err_info:
+        read_odps_query(f"CREATE TABLE dummy_table AS SELECT * FROM {table1_name}")
+    assert "instant query" in err_info.value.args[0]
+
+    query1 = f"SELECT * FROM {table1_name} WHERE col1 > 10"
+    df = read_odps_query(query1)
+    assert df.op.query == query1
+    assert df.index_value.name is None
+    assert isinstance(df.index_value.value, IndexValue.RangeIndex)
+    pd.testing.assert_series_equal(
+        df.dtypes,
+        pd.Series(
+            [np.dtype("O"), np.dtype("int64"), np.dtype("float64")],
+            index=["col1", "col2", "col3"],
+        ),
+    )
+
+    df = read_odps_query(query1, index_col="col1")
+    assert df.op.query == query1
+    assert df.index_value.name == "col1"
+    assert isinstance(df.index_value.value, IndexValue.Index)
+    pd.testing.assert_series_equal(
+        df.dtypes,
+        pd.Series([np.dtype("int64"), np.dtype("float64")], index=["col2", "col3"]),
+    )
+
+    query2 = (
+        f"SELECT t1.col1, t1.col2, t1.col3 as c31, t2.col3 as c32 "
+        f"FROM {table1_name} t1 "
+        f"INNER JOIN {table2_name} t2 "
+        f"ON t1.col1 = t2.col1 AND t1.col2 = t2.col2"
+    )
+    df = read_odps_query(query2, index_col=["col1", "col2"])
+    assert df.op.query == query2
+    assert df.index_value.names == ["col1", "col2"]
+    assert isinstance(df.index_value.value, IndexValue.MultiIndex)
+    pd.testing.assert_series_equal(
+        df.dtypes,
+        pd.Series([np.dtype("float64"), np.dtype("float64")], index=["c31", "c32"]),
+    )
+
+    test_table.drop()
+    test_table2.drop()
+
+
 def test_date_range():
     with pytest.raises(TypeError):
         _ = date_range("2020-1-1", periods="2")
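The test above exercises the new `read_odps_query` entry point end to end: it rejects non-instant statements such as CREATE TABLE, and it derives dtypes and index metadata from the schema MaxCompute infers for a SELECT. A minimal client-side sketch based only on the calls shown in the test (the table name is a placeholder, and the usual maxframe execute/fetch session flow is assumed):

    import maxframe.dataframe as md

    # Build a lazy DataFrame from a SQL query; nothing runs yet.
    df = md.read_odps_query(
        "SELECT col1, col2 FROM my_project.my_table WHERE col2 > 10",
        index_col="col1",  # promoted to the DataFrame index, as in the test
    )
    result = df.execute().fetch()  # submits the DAG to MaxCompute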
@@ -316,3 +384,18 @@
     assert dr.index_value.is_unique == expected.is_unique
     assert dr.index_value.is_monotonic_increasing == expected.is_monotonic_increasing
     assert dr.name == expected.name
+
+
+def test_resolve_task_sector():
+    input_path = os.path.join(os.path.dirname(__file__), "test-data", "task-input.txt")
+    with open(input_path, "r") as f:
+        sector = f.read()
+    actual_sector = _resolve_task_sector("job0", sector)
+
+    assert actual_sector.job_name == "job0"
+    assert actual_sector.task_name == "M1"
+    assert actual_sector.output_target == "Screen"
+    assert len(actual_sector.schema) == 78
+    assert actual_sector.schema[0] == ColumnSchema("unnamed: 0", "bigint", "")
+    assert actual_sector.schema[1] == ColumnSchema("id", "bigint", "id_alias")
+    assert actual_sector.schema[2] == ColumnSchema("listing_url", "string", "")
maxframe/dataframe/groupby/__init__.py CHANGED
@@ -14,6 +14,7 @@
 
 # noinspection PyUnresolvedReferences
 from ..core import DataFrameGroupBy, GroupBy, SeriesGroupBy
+from .core import NamedAgg
 
 
 def _install():
@@ -25,6 +26,7 @@ def _install():
     from .fill import bfill, ffill, fillna
     from .getitem import df_groupby_getitem
     from .head import head
+    from .sample import groupby_sample
     from .transform import groupby_transform
 
     for cls in DATAFRAME_TYPE:
@@ -65,6 +67,8 @@ def _install():
 
         setattr(cls, "head", head)
 
+        setattr(cls, "sample", groupby_sample)
+
         setattr(cls, "ffill", ffill)
         setattr(cls, "bfill", bfill)
         setattr(cls, "backfill", bfill)
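This registers `groupby_sample` as a `sample` method on groupby objects. A sketch of how it would be called, assuming a pandas-compatible `GroupBy.sample` signature:

    import pandas as pd
    import maxframe.dataframe as md

    df = md.DataFrame(pd.DataFrame({"key": ["a", "a", "b"], "val": [1, 2, 3]}))
    # Draw one row per group, mirroring pandas' GroupBy.sample
    sampled = df.groupby("key").sample(n=1)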
maxframe/dataframe/groupby/core.py CHANGED
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import namedtuple
+
 import pandas as pd
 
 from ... import opcodes
@@ -30,6 +32,9 @@ _GROUP_KEYS_NO_DEFAULT = pd_release_version >= (1, 5, 0)
 _default_group_keys = no_default if _GROUP_KEYS_NO_DEFAULT else True
 
 
+NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
+
+
 class DataFrameGroupByOperator(MapReduceOperator, DataFrameOperatorMixin):
     _op_type_ = opcodes.GROUPBY
 
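`NamedAgg` is a plain namedtuple with `column` and `aggfunc` fields, matching the shape of `pandas.NamedAgg`. Assuming the aggregation path accepts it the same way pandas does, named aggregation would look like:

    import pandas as pd
    import maxframe.dataframe as md
    from maxframe.dataframe.groupby import NamedAgg

    df = md.DataFrame(pd.DataFrame({"key": ["a", "a", "b"], "val": [1, 2, 3]}))
    out = df.groupby("key").agg(
        val_max=NamedAgg(column="val", aggfunc="max"),
        val_sum=NamedAgg(column="val", aggfunc="sum"),
    )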
maxframe/dataframe/misc/to_numeric.py CHANGED
@@ -29,6 +29,10 @@ class DataFrameToNumeric(DataFrameOperator, DataFrameOperatorMixin):
     def __init__(self, errors="raise", downcast=None, **kw):
         super().__init__(errors=errors, downcast=downcast, **kw)
 
+    @property
+    def input(self):
+        return self.inputs[0]
+
     def __call__(self, arg):
         if isinstance(arg, pd.Series):
             arg = asseries(arg)
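The added `input` property exposes the operator's first input tileable, a convention used by other maxframe operators. At the user level the operator backs the usual conversion call; a sketch, assuming `to_numeric` and `Series` are exported from maxframe.dataframe like their pandas counterparts:

    import maxframe.dataframe as md

    s = md.Series(["1", "2", "3.5"])
    converted = md.to_numeric(s)  # deferred; casts to numeric dtype on execute()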
maxframe/dataframe/window/aggregation.py CHANGED
@@ -18,14 +18,7 @@ from collections.abc import Iterable
 import numpy as np
 import pandas as pd
 
-from ...serialization.serializables import (
-    AnyField,
-    BoolField,
-    DictField,
-    Int32Field,
-    Int64Field,
-    StringField,
-)
+from ...serialization.serializables import AnyField, BoolField, Int32Field, Int64Field
 from ..core import DATAFRAME_TYPE
 from ..operators import DataFrameOperator, DataFrameOperatorMixin
 from ..utils import build_df, build_empty_series, parse_index
@@ -41,22 +34,6 @@ class BaseDataFrameExpandingAgg(DataFrameOperator, DataFrameOperatorMixin):
     # True if function name is treated as new index
     append_index = BoolField("append_index", default=None)
 
-    # chunk params
-    output_agg = BoolField("output_agg", default=None)
-
-    map_groups = DictField("map_groups", default=None)
-    map_sources = DictField("map_sources", default=None)
-    combine_sources = DictField("combine_sources", default=None)
-    combine_columns = DictField("combine_columns", default=None)
-    combine_funcs = DictField("combine_funcs", default=None)
-    key_to_funcs = DictField("keys_to_funcs", default=None)
-
-    min_periods_func_name = StringField("min_periods_func_name", default=None)
-
-    @property
-    def output_limit(self):
-        return 2 if self.output_agg else 1
-
     def __call__(self, expanding):
         inp = expanding.input
         raw_func = self.func
maxframe/dataframe/window/ewm.py CHANGED
@@ -233,13 +233,6 @@ def ewm(
     if alpha <= 0 or alpha > 1:
         raise ValueError("alpha must satisfy: 0 < alpha <= 1")
 
-    if not adjust and not ignore_na:
-        raise NotImplementedError(
-            "adjust == False when ignore_na == False not implemented"
-        )
-    if axis == 1:
-        raise NotImplementedError("axis other than 0 is not supported")
-
     if alpha == 1:
         return obj.expanding(min_periods=min_periods, axis=axis)
 
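With the two NotImplementedError guards removed, the parameter combinations they rejected in 0.1.0b1 are now accepted at construction time. A sketch mirroring the calls from the removed tests below (assuming the backend now supports these combinations rather than failing later):

    import numpy as np
    import pandas as pd
    import maxframe.dataframe as md

    df2 = md.DataFrame(pd.DataFrame(np.random.rand(4, 3), columns=list("abc")))
    # Both calls raised NotImplementedError in 0.1.0b1
    r1 = df2.ewm(2, adjust=False, ignore_na=False).agg("mean")
    r2 = df2.ewm(span=3, axis=1).agg("mean")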
maxframe/dataframe/window/tests/test_ewm.py CHANGED
@@ -23,9 +23,6 @@ def test_ewm():
     df = pd.DataFrame(np.random.rand(4, 3), columns=list("abc"))
     df2 = md.DataFrame(df)
 
-    with pytest.raises(NotImplementedError):
-        _ = df2.ewm(2, adjust=False, ignore_na=False)
-
     with pytest.raises(ValueError):
         _ = df2.ewm()
 
@@ -59,9 +56,6 @@ def test_ewm_agg():
     df = pd.DataFrame(np.random.rand(4, 3), columns=list("abc"))
     df2 = md.DataFrame(df, chunk_size=3)
 
-    with pytest.raises(NotImplementedError):
-        _ = df2.ewm(span=3, axis=1).agg("mean")
-
     r = df2.ewm(span=3).agg("mean")
     expected = df.ewm(span=3).agg("mean")
 
maxframe/errors.py ADDED
@@ -0,0 +1,21 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class MaxFrameError(Exception):
+    pass
+
+
+class MaxFrameUserError(MaxFrameError):
+    pass
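This new module gives MaxFrame a common exception root; `RemoteException` in maxframe/serialization/exception.py is rebased onto it later in this diff, so one handler can catch all framework errors. A sketch (`run_dag` is a hypothetical operation for illustration):

    from maxframe.errors import MaxFrameError

    try:
        run_dag()  # any operation that may fail inside MaxFrame
    except MaxFrameError as exc:
        # catches MaxFrameUserError, RemoteException and other subclasses
        print(f"MaxFrame operation failed: {exc}")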
maxframe/lib/aio/isolation.py CHANGED
@@ -14,11 +14,14 @@
 
 import asyncio
 import atexit
+import itertools
 import threading
 from typing import Dict, Optional
 
 
 class Isolation:
+    _counter = itertools.count().__next__
+
     loop: asyncio.AbstractEventLoop
     _stopped: Optional[asyncio.Event]
     _thread: Optional[threading.Thread]
@@ -38,7 +41,9 @@ class Isolation:
 
     def start(self):
         if self._threaded:
-            self._thread = thread = threading.Thread(target=self._run)
+            self._thread = thread = threading.Thread(
+                name=f"IsolationThread-{self._counter()}", target=self._run
+            )
             thread.daemon = True
             thread.start()
             self._thread_ident = thread.ident

maxframe/lib/mmh3.cp38-win_amd64.pyd CHANGED
Binary file
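Naming each isolation thread from a class-level `itertools.count()` makes threads distinguishable in debuggers and thread dumps. The pattern in isolation (a generic `Worker` class is used here for illustration):

    import itertools
    import threading

    class Worker:
        # one counter shared by the class; each call returns the next integer
        _counter = itertools.count().__next__

        def start(self):
            thread = threading.Thread(
                name=f"WorkerThread-{self._counter()}", target=self._run
            )
            thread.daemon = True
            thread.start()

        def _run(self):
            pass  # placeholder body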
maxframe/opcodes.py CHANGED
maxframe/protocol.py CHANGED
@@ -46,6 +46,8 @@ BodyType = TypeVar("BodyType", bound="Serializable")
 
 
 class JsonSerializable(Serializable):
+    _ignore_non_existing_keys = True
+
     @classmethod
     def from_json(cls, serialized: dict) -> "JsonSerializable":
         raise NotImplementedError
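Setting `_ignore_non_existing_keys = True` on `JsonSerializable` (backed by the `Serializable.__init__` change later in this diff) makes unknown keys in a payload be dropped instead of raising `KeyError`, so older clients can read responses from newer servers. The behaviour is pinned down by the `non_existing_field` case in test_protocol.py; sketched directly on a constructor:

    # Unknown keyword arguments are silently ignored on JsonSerializable
    # subclasses such as SessionInfo, instead of raising KeyError.
    info = SessionInfo(
        session_id="s1",
        non_existing_field="added by a newer server",  # dropped, no error
    )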
@@ -209,7 +211,10 @@ class ErrorInfo(JsonSerializable):
             kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
         if kw.get("raw_error_data"):
             bufs = [base64.b64decode(s) for s in kw["raw_error_data"]]
-            kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
+            try:
+                kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
+            except:
+                kw["raw_error_data"] = None
         return cls(**kw)
 
     def to_json(self) -> dict:
@@ -242,6 +247,8 @@ class DagInfo(JsonSerializable):
         default_factory=dict,
     )
     error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
+    start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
+    end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
 
     @classmethod
     def from_json(cls, serialized: dict) -> "DagInfo":
@@ -262,7 +269,10 @@ class DagInfo(JsonSerializable):
             "dag_id": self.dag_id,
             "status": self.status.value,
             "progress": self.progress,
+            "start_timestamp": self.start_timestamp,
+            "end_timestamp": self.end_timestamp,
         }
+        ret = {k: v for k, v in ret.items() if v is not None}
         if self.tileable_to_result_infos:
             ret["tileable_to_result_infos"] = {
                 k: v.to_json() for k, v in self.tileable_to_result_infos.items()
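With `start_timestamp` and `end_timestamp` carried on `DagInfo` (and stripped from the JSON when unset by the `None` filter above), a client can derive wall-clock duration. A sketch, assuming both fields hold epoch seconds and using a hypothetical accessor name:

    dag_info = session.get_dag_info(dag_id)  # hypothetical accessor
    if dag_info.start_timestamp and dag_info.end_timestamp:
        duration = dag_info.end_timestamp - dag_info.start_timestamp
        print(f"DAG finished in {duration:.1f}s")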
@@ -278,12 +288,18 @@ class CreateSessionRequest(Serializable):
 
 class SessionInfo(JsonSerializable):
     session_id: str = StringField("session_id")
-    settings: Dict[str, Any] = DictField(
-        "settings", key_type=FieldTypes.string
-    )
+    settings: Dict[str, Any] = DictField(
+        "settings", key_type=FieldTypes.string, default=None
+    )
+    start_timestamp: float = Float64Field("start_timestamp", default=None)
+    idle_timestamp: float = Float64Field("idle_timestamp", default=None)
     dag_infos: Dict[str, Optional[DagInfo]] = DictField(
-        "dag_infos",
+        "dag_infos",
+        key_type=FieldTypes.string,
+        value_type=FieldTypes.reference,
+        default=None,
     )
+    error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
 
     @classmethod
     def from_json(cls, serialized: dict) -> "SessionInfo":
@@ -292,6 +308,8 @@ class SessionInfo(JsonSerializable):
             kw["dag_infos"] = {
                 k: DagInfo.from_json(v) for k, v in kw["dag_infos"].items()
             }
+        if kw.get("error_info"):
+            kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
         return SessionInfo(**kw)
 
     def to_json(self) -> dict:
@@ -303,6 +321,8 @@ class SessionInfo(JsonSerializable):
         }
         if self.dag_infos:
            ret["dag_infos"] = {k: v.to_json() for k, v in self.dag_infos.items()}
+        if self.error_info:
+            ret["error_info"] = self.error_info.to_json()
         return ret
 
 

maxframe/serialization/core.cp38-win_amd64.pyd CHANGED
Binary file
maxframe/serialization/exception.py CHANGED
@@ -16,13 +16,14 @@ import logging
 import traceback
 from typing import Dict, List
 
+from ..errors import MaxFrameError
 from ..lib import wrapped_pickle as pickle
 from .core import Serializer, buffered, pickle_buffers, unpickle_buffers
 
 logger = logging.getLogger(__name__)
 
 
-class RemoteException(Exception):
+class RemoteException(MaxFrameError):
     def __init__(
         self, messages: List[str], tracebacks: List[List[str]], buffers: List[bytes]
     ):
maxframe/serialization/serializables/core.py CHANGED
@@ -112,6 +112,7 @@ class Serializable(metaclass=SerializableMeta):
     __slots__ = ("__weakref__",)
 
     _cache_primitive_serial = False
+    _ignore_non_existing_keys = False
 
     _FIELDS: Dict[str, Field]
     _FIELD_ORDER: List[str]
@@ -128,7 +129,11 @@ class Serializable(metaclass=SerializableMeta):
         else:
             values = kwargs
         for k, v in values.items():
-            fields[k].set(self, v)
+            try:
+                fields[k].set(self, v)
+            except KeyError:
+                if not self._ignore_non_existing_keys:
+                    raise
 
     def __on_deserialize__(self):
         pass
maxframe/serialization/serializables/field.py CHANGED
@@ -507,12 +507,14 @@ class ReferenceField(Field):
         tag: str,
         reference_type: Union[str, Type] = None,
         default: Any = no_default,
+        default_factory: Optional[Callable] = None,
         on_serialize: Callable[[Any], Any] = None,
         on_deserialize: Callable[[Any], Any] = None,
     ):
         super().__init__(
             tag,
             default=default,
+            default_factory=default_factory,
             on_serialize=on_serialize,
             on_deserialize=on_deserialize,
         )
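`ReferenceField` now forwards `default_factory` to the base `Field`, the same mutable-default pattern as `dataclasses.field(default_factory=...)`: a fresh default object per instance instead of one shared value. Sketched with a hypothetical serializable (field names are illustrative only):

    from maxframe.protocol import ErrorInfo
    from maxframe.serialization.serializables import ReferenceField, Serializable

    class NodeInfo(Serializable):
        # without a factory, one shared default object would leak state
        # across instances; the factory builds a fresh ErrorInfo each time
        error_info = ReferenceField("error_info", default_factory=ErrorInfo)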
maxframe/tensor/core.py CHANGED
@@ -43,7 +43,7 @@ from ..serialization.serializables import (
     StringField,
     TupleField,
 )
-from ..utils import on_deserialize_shape, on_serialize_shape
+from ..utils import on_deserialize_shape, on_serialize_shape, skip_na_call
 from .utils import fetch_corner_data, get_chunk_slices
 
 logger = logging.getLogger(__name__)
@@ -181,8 +181,8 @@ class TensorData(HasShapeTileableData, _ExecuteAndFetchMixin):
     _chunks = ListField(
         "chunks",
         FieldTypes.reference(TensorChunkData),
-        on_serialize=lambda x: [it.data for it in x],
-        on_deserialize=lambda x: [TensorChunk(it) for it in x],
+        on_serialize=skip_na_call(lambda x: [it.data for it in x]),
+        on_deserialize=skip_na_call(lambda x: [TensorChunk(it) for it in x]),
     )
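`skip_na_call` comes from maxframe/utils.py, which is also touched in this diff but whose body is not shown here. From its use above it is presumably a wrapper that passes `None` through instead of applying the callback, so the serialize/deserialize hooks survive unset chunk lists. A plausible sketch only, not the actual implementation:

    import functools

    def skip_na_call(func):
        """Presumed shape: return the input unchanged when it is None."""
        @functools.wraps(func)
        def wrapper(value):
            return value if value is None else func(value)
        return wrapper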
maxframe/tests/test_codegen.py ADDED
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from typing import List, Tuple
+
+# generate unit tests with pytest
+import pytest
+
+from maxframe.codegen import UserCodeMixin
+from maxframe.lib import wrapped_pickle
+from maxframe.serialization.core import PickleContainer
+
+
+@pytest.mark.parametrize(
+    "input_obj, expected_output",
+    [
+        (None, "None"),
+        (10, "10"),
+        (3.14, "3.14"),
+        (True, "True"),
+        (False, "False"),
+        (b"hello", "base64.b64decode(b'aGVsbG8=')"),
+        ("hello", "'hello'"),
+        ([1, 2, 3], "[1, 2, 3]"),
+        ({"a": 1, "b": 2}, "{'a': 1, 'b': 2}"),
+        ((1, 2, 3), "(1, 2, 3)"),
+        ((1,), "(1,)"),
+        ((), "()"),
+        ({1, 2, 3}, "{1, 2, 3}"),
+        (set(), "set()"),
+    ],
+)
+def test_obj_to_python_expr(input_obj, expected_output):
+    assert UserCodeMixin.obj_to_python_expr(input_obj) == expected_output
+
+
+def test_obj_to_python_expr_custom_object():
+    class CustomClass:
+        def __init__(self, a: int, b: List[int], c: Tuple[int, int]):
+            self.a = a
+            self.b = b
+            self.c = c
+
+    custom_obj = CustomClass(1, [2, 3], (4, 5))
+    pickle_data = wrapped_pickle.dumps(custom_obj)
+    pickle_str = base64.b64encode(pickle_data)
+    custom_obj_pickle_container = PickleContainer([pickle_data])
+
+    # class objects are not supported currently
+    with pytest.raises(ValueError):
+        UserCodeMixin.obj_to_python_expr(custom_obj)
+
+    assert (
+        UserCodeMixin.obj_to_python_expr(custom_obj_pickle_container)
+        == f"cloudpickle.loads(base64.b64decode({pickle_str}), buffers=[])"
+    )
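This new test pins down `UserCodeMixin.obj_to_python_expr`: plain literals become their source form, `bytes` become a `base64.b64decode(...)` expression, and pickled payloads become a `cloudpickle.loads(...)` expression, so any supported value can be embedded in generated Python source. Evaluating the generated text recovers the value, for example:

    import base64

    expr = "base64.b64decode(b'aGVsbG8=')"  # as produced for b"hello"
    assert eval(expr) == b"hello"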
maxframe/tests/test_protocol.py CHANGED
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import json
 import time
 
 import pytest
@@ -29,28 +31,32 @@ from ..serialization import RemoteException
 from ..utils import deserialize_serializable, serialize_serializable
 
 
+def _json_round_trip(json_data: dict) -> dict:
+    return json.loads(json.dumps(json_data))
+
+
 def test_result_info_json_serialize():
-    ri = ResultInfo.from_json(ResultInfo().to_json())
+    ri = ResultInfo.from_json(_json_round_trip(ResultInfo().to_json()))
     assert type(ri) is ResultInfo
 
     ri = ODPSTableResultInfo(
         full_table_name="table_name", partition_specs=["pt=partition"]
     )
-    deserial_ri = ResultInfo.from_json(ri.to_json())
+    deserial_ri = ResultInfo.from_json(_json_round_trip(ri.to_json()))
     assert type(ri) is ODPSTableResultInfo
     assert ri.result_type == deserial_ri.result_type
     assert ri.full_table_name == deserial_ri.full_table_name
     assert ri.partition_specs == deserial_ri.partition_specs
 
     ri = ODPSTableResultInfo(full_table_name="table_name")
-    deserial_ri = ResultInfo.from_json(ri.to_json())
+    deserial_ri = ResultInfo.from_json(_json_round_trip(ri.to_json()))
     assert type(ri) is ODPSTableResultInfo
     assert ri.result_type == deserial_ri.result_type
     assert ri.full_table_name == deserial_ri.full_table_name
     assert ri.partition_specs == deserial_ri.partition_specs
 
     ri = ODPSVolumeResultInfo(volume_name="vol_name", volume_path="vol_path")
-    deserial_ri = ResultInfo.from_json(ri.to_json())
+    deserial_ri = ResultInfo.from_json(_json_round_trip(ri.to_json()))
     assert type(ri) is ODPSVolumeResultInfo
     assert ri.result_type == deserial_ri.result_type
     assert ri.volume_name == deserial_ri.volume_name
@@ -63,7 +69,7 @@ def test_error_info_json_serialize():
     except ValueError as ex:
         err_info = ErrorInfo.from_exception(ex)
 
-    deserial_err_info = ErrorInfo.from_json(err_info.to_json())
+    deserial_err_info = ErrorInfo.from_json(_json_round_trip(err_info.to_json()))
     assert deserial_err_info.error_messages == err_info.error_messages
     assert isinstance(deserial_err_info.raw_error_data, ValueError)
 
@@ -73,7 +79,7 @@ def test_error_info_json_serialize():
     with pytest.raises(RemoteException):
         mf_err_info.reraise()
 
-    deserial_err_info = ErrorInfo.from_json(mf_err_info.to_json())
+    deserial_err_info = ErrorInfo.from_json(_json_round_trip(mf_err_info.to_json()))
     assert isinstance(deserial_err_info.raw_error_data, ValueError)
     with pytest.raises(ValueError):
         deserial_err_info.reraise()
@@ -94,7 +100,9 @@ def test_dag_info_json_serialize():
         },
         error_info=err_info,
     )
-    deserial_info = DagInfo.from_json(info.to_json())
+    json_info = info.to_json()
+    json_info["non_existing_field"] = "non_existing"
+    deserial_info = DagInfo.from_json(_json_round_trip(json_info))
     assert deserial_info.session_id == info.session_id
     assert deserial_info.dag_id == info.dag_id
     assert deserial_info.status == info.status
@@ -121,7 +129,7 @@ def test_session_info_json_serialize():
         idle_timestamp=None,
         dag_infos={"test_dag_id": dag_info},
     )
-    deserial_info = SessionInfo.from_json(info.to_json())
+    deserial_info = SessionInfo.from_json(_json_round_trip(info.to_json()))
     assert deserial_info.session_id == info.session_id
     assert deserial_info.settings == info.settings
     assert deserial_info.start_timestamp == info.start_timestamp
maxframe/tests/utils.py CHANGED