maxframe-0.1.0b1-cp311-cp311-win_amd64.whl → maxframe-0.1.0b3-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of maxframe has been flagged as possibly problematic.

Files changed (42)
  1. maxframe/_utils.cp311-win_amd64.pyd +0 -0
  2. maxframe/codegen.py +88 -19
  3. maxframe/config/config.py +10 -0
  4. maxframe/core/entity/executable.py +1 -0
  5. maxframe/core/entity/objects.py +3 -2
  6. maxframe/core/graph/core.cp311-win_amd64.pyd +0 -0
  7. maxframe/core/graph/core.pyx +2 -2
  8. maxframe/core/operator/base.py +14 -0
  9. maxframe/dataframe/__init__.py +3 -1
  10. maxframe/dataframe/datasource/from_records.py +4 -0
  11. maxframe/dataframe/datasource/read_odps_query.py +295 -0
  12. maxframe/dataframe/datasource/read_odps_table.py +1 -1
  13. maxframe/dataframe/datasource/tests/test_datasource.py +84 -1
  14. maxframe/dataframe/groupby/__init__.py +4 -0
  15. maxframe/dataframe/groupby/core.py +5 -0
  16. maxframe/dataframe/misc/to_numeric.py +4 -0
  17. maxframe/dataframe/window/aggregation.py +1 -24
  18. maxframe/dataframe/window/ewm.py +0 -7
  19. maxframe/dataframe/window/tests/test_ewm.py +0 -6
  20. maxframe/errors.py +21 -0
  21. maxframe/lib/aio/isolation.py +6 -1
  22. maxframe/lib/mmh3.cp311-win_amd64.pyd +0 -0
  23. maxframe/opcodes.py +1 -0
  24. maxframe/protocol.py +25 -5
  25. maxframe/serialization/core.cp311-win_amd64.pyd +0 -0
  26. maxframe/serialization/exception.py +2 -1
  27. maxframe/serialization/serializables/core.py +6 -1
  28. maxframe/serialization/serializables/field.py +2 -0
  29. maxframe/tensor/core.py +3 -3
  30. maxframe/tests/test_codegen.py +69 -0
  31. maxframe/tests/test_protocol.py +16 -8
  32. maxframe/tests/utils.py +1 -0
  33. maxframe/udf.py +15 -16
  34. maxframe/utils.py +21 -1
  35. {maxframe-0.1.0b1.dist-info → maxframe-0.1.0b3.dist-info}/METADATA +1 -74
  36. {maxframe-0.1.0b1.dist-info → maxframe-0.1.0b3.dist-info}/RECORD +42 -39
  37. {maxframe-0.1.0b1.dist-info → maxframe-0.1.0b3.dist-info}/WHEEL +1 -1
  38. maxframe_client/clients/framedriver.py +7 -7
  39. maxframe_client/session/task.py +31 -3
  40. maxframe_client/session/tests/test_task.py +29 -11
  41. maxframe_client/tests/test_session.py +2 -0
  42. {maxframe-0.1.0b1.dist-info → maxframe-0.1.0b3.dist-info}/top_level.txt +0 -0
maxframe/dataframe/datasource/tests/test_datasource.py CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from collections import OrderedDict
 
 import numpy as np
@@ -22,7 +23,7 @@ from odps import ODPS
 from .... import tensor as mt
 from ....tests.utils import tn
 from ....utils import lazy_import
-from ... import read_odps_table
+from ... import read_odps_query, read_odps_table
 from ...core import DatetimeIndex, Float64Index, IndexValue, Int64Index, MultiIndex
 from ..dataframe import from_pandas as from_pandas_df
 from ..date_range import date_range
@@ -33,6 +34,7 @@ from ..from_tensor import (
 )
 from ..index import from_pandas as from_pandas_index
 from ..index import from_tileable
+from ..read_odps_query import ColumnSchema, _resolve_task_sector
 from ..series import from_pandas as from_pandas_series
 
 ray = lazy_import("ray")
@@ -228,6 +230,7 @@ def test_from_odps_table():
     assert df.op.table_name == test_table.full_table_name
     assert df.index_value.name is None
     assert isinstance(df.index_value.value, IndexValue.RangeIndex)
+    assert df.op.get_columns() == ["col1", "col2", "col3"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series(
@@ -247,6 +250,7 @@ def test_from_odps_table():
     assert df.op.table_name == test_table.full_table_name
     assert df.index_value.name is None
     assert isinstance(df.index_value.value, IndexValue.RangeIndex)
+    assert df.op.get_columns() == ["col1", "col2"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series([np.dtype("O"), np.dtype("int64")], index=["col1", "col2"]),
@@ -257,6 +261,7 @@ def test_from_odps_table():
     assert df.index_value.name == "col1"
     assert isinstance(df.index_value.value, IndexValue.Index)
     assert df.index.dtype == np.dtype("O")
+    assert df.op.get_columns() == ["col2", "col3"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series([np.dtype("int64"), np.dtype("float64")], index=["col2", "col3"]),
@@ -267,6 +272,7 @@ def test_from_odps_table():
 
     df = read_odps_table(test_parted_table, append_partitions=True)
     assert df.op.append_partitions is True
+    assert df.op.get_columns() == ["col1", "col2", "col3", "pt"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series(
@@ -280,6 +286,7 @@ def test_from_odps_table():
     )
     assert df.op.append_partitions is True
     assert df.op.partitions == ["pt=20240103"]
+    assert df.op.get_columns() == ["col1", "col2", "pt"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series(
@@ -292,6 +299,67 @@ def test_from_odps_table():
     test_parted_table.drop()
 
 
+def test_from_odps_query():
+    odps_entry = ODPS.from_environments()
+    table1_name = tn("test_from_odps_query_src1")
+    table2_name = tn("test_from_odps_query_src2")
+    odps_entry.delete_table(table1_name, if_exists=True)
+    odps_entry.delete_table(table2_name, if_exists=True)
+    test_table = odps_entry.create_table(
+        table1_name, "col1 string, col2 bigint, col3 double", lifecycle=1
+    )
+    # need some data to produce complicated plans
+    odps_entry.write_table(test_table, [["A", 10, 3.5]])
+    test_table2 = odps_entry.create_table(
+        table2_name, "col1 string, col2 bigint, col3 double", lifecycle=1
+    )
+    odps_entry.write_table(test_table2, [["A", 10, 4.5]])
+
+    with pytest.raises(ValueError) as err_info:
+        read_odps_query(f"CREATE TABLE dummy_table AS SELECT * FROM {table1_name}")
+    assert "instant query" in err_info.value.args[0]
+
+    query1 = f"SELECT * FROM {table1_name} WHERE col1 > 10"
+    df = read_odps_query(query1)
+    assert df.op.query == query1
+    assert df.index_value.name is None
+    assert isinstance(df.index_value.value, IndexValue.RangeIndex)
+    pd.testing.assert_series_equal(
+        df.dtypes,
+        pd.Series(
+            [np.dtype("O"), np.dtype("int64"), np.dtype("float64")],
+            index=["col1", "col2", "col3"],
+        ),
+    )
+
+    df = read_odps_query(query1, index_col="col1")
+    assert df.op.query == query1
+    assert df.index_value.name == "col1"
+    assert isinstance(df.index_value.value, IndexValue.Index)
+    pd.testing.assert_series_equal(
+        df.dtypes,
+        pd.Series([np.dtype("int64"), np.dtype("float64")], index=["col2", "col3"]),
+    )
+
+    query2 = (
+        f"SELECT t1.col1, t1.col2, t1.col3 as c31, t2.col3 as c32 "
+        f"FROM {table1_name} t1 "
+        f"INNER JOIN {table2_name} t2 "
+        f"ON t1.col1 = t2.col1 AND t1.col2 = t2.col2"
+    )
+    df = read_odps_query(query2, index_col=["col1", "col2"])
+    assert df.op.query == query2
+    assert df.index_value.names == ["col1", "col2"]
+    assert isinstance(df.index_value.value, IndexValue.MultiIndex)
+    pd.testing.assert_series_equal(
+        df.dtypes,
+        pd.Series([np.dtype("float64"), np.dtype("float64")], index=["c31", "c32"]),
+    )
+
+    test_table.drop()
+    test_table2.drop()
+
+
 def test_date_range():
     with pytest.raises(TypeError):
         _ = date_range("2020-1-1", periods="2")
@@ -316,3 +384,18 @@ def test_date_range():
     assert dr.index_value.is_unique == expected.is_unique
     assert dr.index_value.is_monotonic_increasing == expected.is_monotonic_increasing
     assert dr.name == expected.name
+
+
+def test_resolve_task_sector():
+    input_path = os.path.join(os.path.dirname(__file__), "test-data", "task-input.txt")
+    with open(input_path, "r") as f:
+        sector = f.read()
+    actual_sector = _resolve_task_sector("job0", sector)
+
+    assert actual_sector.job_name == "job0"
+    assert actual_sector.task_name == "M1"
+    assert actual_sector.output_target == "Screen"
+    assert len(actual_sector.schema) == 78
+    assert actual_sector.schema[0] == ColumnSchema("unnamed: 0", "bigint", "")
+    assert actual_sector.schema[1] == ColumnSchema("id", "bigint", "id_alias")
+    assert actual_sector.schema[2] == ColumnSchema("listing_url", "string", "")
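
The new tests above pin down the public surface of read_odps_query, the main addition of this release (backed by the new maxframe/dataframe/datasource/read_odps_query.py, +295 lines). A minimal usage sketch inferred from those tests; the table name is hypothetical and session setup is omitted:

    from maxframe.dataframe import read_odps_query

    # Only instant (SELECT-style) queries are accepted; DDL such as
    # CREATE TABLE raises ValueError.
    df = read_odps_query("SELECT col1, col2, col3 FROM my_table")

    # One or more result columns can be promoted to the index,
    # producing an Index or MultiIndex accordingly.
    df2 = read_odps_query(
        "SELECT col1, col2, col3 FROM my_table",
        index_col=["col1", "col2"],
    )
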
maxframe/dataframe/groupby/__init__.py CHANGED
@@ -14,6 +14,7 @@
 
 # noinspection PyUnresolvedReferences
 from ..core import DataFrameGroupBy, GroupBy, SeriesGroupBy
+from .core import NamedAgg
 
 
 def _install():
@@ -25,6 +26,7 @@ def _install():
     from .fill import bfill, ffill, fillna
     from .getitem import df_groupby_getitem
     from .head import head
+    from .sample import groupby_sample
     from .transform import groupby_transform
 
     for cls in DATAFRAME_TYPE:
@@ -65,6 +67,8 @@ def _install():
 
         setattr(cls, "head", head)
 
+        setattr(cls, "sample", groupby_sample)
+
         setattr(cls, "ffill", ffill)
         setattr(cls, "bfill", bfill)
        setattr(cls, "backfill", bfill)
maxframe/dataframe/groupby/core.py CHANGED
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import namedtuple
+
 import pandas as pd
 
 from ... import opcodes
@@ -30,6 +32,9 @@ _GROUP_KEYS_NO_DEFAULT = pd_release_version >= (1, 5, 0)
 _default_group_keys = no_default if _GROUP_KEYS_NO_DEFAULT else True
 
 
+NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
+
+
 class DataFrameGroupByOperator(MapReduceOperator, DataFrameOperatorMixin):
     _op_type_ = opcodes.GROUPBY
 
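
NamedAgg is a plain namedtuple re-exported through the groupby package (see the __init__.py hunk above), shaped like pandas.NamedAgg. A sketch of the intended named-aggregation style, assuming agg accepts the keyword form the way pandas does:

    import pandas as pd
    import maxframe.dataframe as md
    from maxframe.dataframe.groupby import NamedAgg

    raw = pd.DataFrame({"a": [1, 1, 2], "b": [1.0, 2.0, 3.0]})
    df = md.DataFrame(raw)
    result = df.groupby("a").agg(
        b_sum=NamedAgg(column="b", aggfunc="sum"),
        b_max=NamedAgg(column="b", aggfunc="max"),
    )
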
maxframe/dataframe/misc/to_numeric.py CHANGED
@@ -29,6 +29,10 @@ class DataFrameToNumeric(DataFrameOperator, DataFrameOperatorMixin):
     def __init__(self, errors="raise", downcast=None, **kw):
         super().__init__(errors=errors, downcast=downcast, **kw)
 
+    @property
+    def input(self):
+        return self.inputs[0]
+
     def __call__(self, arg):
         if isinstance(arg, pd.Series):
             arg = asseries(arg)
maxframe/dataframe/window/aggregation.py CHANGED
@@ -18,14 +18,7 @@ from collections.abc import Iterable
 import numpy as np
 import pandas as pd
 
-from ...serialization.serializables import (
-    AnyField,
-    BoolField,
-    DictField,
-    Int32Field,
-    Int64Field,
-    StringField,
-)
+from ...serialization.serializables import AnyField, BoolField, Int32Field, Int64Field
 from ..core import DATAFRAME_TYPE
 from ..operators import DataFrameOperator, DataFrameOperatorMixin
 from ..utils import build_df, build_empty_series, parse_index
@@ -41,22 +34,6 @@ class BaseDataFrameExpandingAgg(DataFrameOperator, DataFrameOperatorMixin):
     # True if function name is treated as new index
     append_index = BoolField("append_index", default=None)
 
-    # chunk params
-    output_agg = BoolField("output_agg", default=None)
-
-    map_groups = DictField("map_groups", default=None)
-    map_sources = DictField("map_sources", default=None)
-    combine_sources = DictField("combine_sources", default=None)
-    combine_columns = DictField("combine_columns", default=None)
-    combine_funcs = DictField("combine_funcs", default=None)
-    key_to_funcs = DictField("keys_to_funcs", default=None)
-
-    min_periods_func_name = StringField("min_periods_func_name", default=None)
-
-    @property
-    def output_limit(self):
-        return 2 if self.output_agg else 1
-
     def __call__(self, expanding):
         inp = expanding.input
         raw_func = self.func
maxframe/dataframe/window/ewm.py CHANGED
@@ -233,13 +233,6 @@ def ewm(
     if alpha <= 0 or alpha > 1:
         raise ValueError("alpha must satisfy: 0 < alpha <= 1")
 
-    if not adjust and not ignore_na:
-        raise NotImplementedError(
-            "adjust == False when ignore_na == False not implemented"
-        )
-    if axis == 1:
-        raise NotImplementedError("axis other than 0 is not supported")
-
     if alpha == 1:
         return obj.expanding(min_periods=min_periods, axis=axis)
 
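
With these guards removed, adjust=False combined with ignore_na=False, as well as axis=1, are no longer rejected at call time; the matching tests are deleted in the next file. Whether the backend fully computes these combinations is not visible in this diff, but the construction that used to raise should now succeed:

    import pandas as pd
    import maxframe.dataframe as md

    df = md.DataFrame(pd.DataFrame({"a": [1.0, 2.0, 4.0]}))
    # raised NotImplementedError in 0.1.0b1, accepted as of 0.1.0b3
    ewm = df.ewm(2, adjust=False, ignore_na=False)
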
maxframe/dataframe/window/tests/test_ewm.py CHANGED
@@ -23,9 +23,6 @@ def test_ewm():
     df = pd.DataFrame(np.random.rand(4, 3), columns=list("abc"))
     df2 = md.DataFrame(df)
 
-    with pytest.raises(NotImplementedError):
-        _ = df2.ewm(2, adjust=False, ignore_na=False)
-
     with pytest.raises(ValueError):
         _ = df2.ewm()
 
@@ -59,9 +56,6 @@ def test_ewm_agg():
     df = pd.DataFrame(np.random.rand(4, 3), columns=list("abc"))
     df2 = md.DataFrame(df, chunk_size=3)
 
-    with pytest.raises(NotImplementedError):
-        _ = df2.ewm(span=3, axis=1).agg("mean")
-
     r = df2.ewm(span=3).agg("mean")
     expected = df.ewm(span=3).agg("mean")
 
maxframe/errors.py ADDED
@@ -0,0 +1,21 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class MaxFrameError(Exception):
+    pass
+
+
+class MaxFrameUserError(MaxFrameError):
+    pass
maxframe/lib/aio/isolation.py CHANGED
@@ -14,11 +14,14 @@
 
 import asyncio
 import atexit
+import itertools
 import threading
 from typing import Dict, Optional
 
 
 class Isolation:
+    _counter = itertools.count().__next__
+
     loop: asyncio.AbstractEventLoop
     _stopped: Optional[asyncio.Event]
     _thread: Optional[threading.Thread]
@@ -38,7 +41,9 @@
 
     def start(self):
         if self._threaded:
-            self._thread = thread = threading.Thread(target=self._run)
+            self._thread = thread = threading.Thread(
+                name=f"IsolationThread-{self._counter()}", target=self._run
+            )
             thread.daemon = True
             thread.start()
             self._thread_ident = thread.ident
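
The class-level itertools.count().__next__ gives each isolation thread a unique, monotonically increasing name, which makes thread dumps easier to read. The same pattern in a standalone sketch (not MaxFrame code):

    import itertools
    import threading

    class Worker:
        # bound at class definition time: each call yields the next integer
        _counter = itertools.count().__next__

        def start(self):
            thread = threading.Thread(
                name=f"WorkerThread-{self._counter()}", target=lambda: None
            )
            thread.daemon = True
            thread.start()
            return thread.name  # "WorkerThread-0", "WorkerThread-1", ...
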
maxframe/lib/mmh3.cp311-win_amd64.pyd CHANGED
Binary file (no diff shown)
maxframe/opcodes.py CHANGED
@@ -462,6 +462,7 @@ READ_ODPS_TABLE = 20111
 TO_ODPS_TABLE = 20112
 READ_ODPS_VOLUME = 20113
 TO_ODPS_VOLUME = 20114
+READ_ODPS_QUERY = 20115
 
 TO_CSV_STAT = 2102
 
maxframe/protocol.py CHANGED
@@ -46,6 +46,8 @@ BodyType = TypeVar("BodyType", bound="Serializable")
 
 
 class JsonSerializable(Serializable):
+    _ignore_non_existing_keys = True
+
     @classmethod
     def from_json(cls, serialized: dict) -> "JsonSerializable":
         raise NotImplementedError
@@ -209,7 +211,10 @@ class ErrorInfo(JsonSerializable):
             kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
         if kw.get("raw_error_data"):
             bufs = [base64.b64decode(s) for s in kw["raw_error_data"]]
-            kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
+            try:
+                kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
+            except:
+                kw["raw_error_data"] = None
         return cls(**kw)
 
     def to_json(self) -> dict:
@@ -242,6 +247,8 @@ class DagInfo(JsonSerializable):
         default_factory=dict,
     )
     error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
+    start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
+    end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
 
     @classmethod
     def from_json(cls, serialized: dict) -> "DagInfo":
@@ -262,7 +269,10 @@ class DagInfo(JsonSerializable):
             "dag_id": self.dag_id,
             "status": self.status.value,
             "progress": self.progress,
+            "start_timestamp": self.start_timestamp,
+            "end_timestamp": self.end_timestamp,
         }
+        ret = {k: v for k, v in ret.items() if v is not None}
         if self.tileable_to_result_infos:
             ret["tileable_to_result_infos"] = {
                 k: v.to_json() for k, v in self.tileable_to_result_infos.items()
@@ -278,12 +288,18 @@ class CreateSessionRequest(Serializable):
 
 class SessionInfo(JsonSerializable):
     session_id: str = StringField("session_id")
-    settings: Dict[str, Any] = DictField("settings", key_type=FieldTypes.string)
-    start_timestamp: float = Float64Field("start_timestamp")
-    idle_timestamp: float = Float64Field("idle_timestamp")
+    settings: Dict[str, Any] = DictField(
+        "settings", key_type=FieldTypes.string, default=None
+    )
+    start_timestamp: float = Float64Field("start_timestamp", default=None)
+    idle_timestamp: float = Float64Field("idle_timestamp", default=None)
     dag_infos: Dict[str, Optional[DagInfo]] = DictField(
-        "dag_infos", key_type=FieldTypes.string, value_type=FieldTypes.reference
+        "dag_infos",
+        key_type=FieldTypes.string,
+        value_type=FieldTypes.reference,
+        default=None,
     )
+    error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
 
     @classmethod
     def from_json(cls, serialized: dict) -> "SessionInfo":
@@ -292,6 +308,8 @@ class SessionInfo(JsonSerializable):
             kw["dag_infos"] = {
                 k: DagInfo.from_json(v) for k, v in kw["dag_infos"].items()
             }
+        if kw.get("error_info"):
+            kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
         return SessionInfo(**kw)
 
     def to_json(self) -> dict:
@@ -303,6 +321,8 @@ class SessionInfo(JsonSerializable):
         }
         if self.dag_infos:
             ret["dag_infos"] = {k: v.to_json() for k, v in self.dag_infos.items()}
+        if self.error_info:
+            ret["error_info"] = self.error_info.to_json()
         return ret
 
 
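
Two behavioral notes on the hunks above: ErrorInfo.from_json now degrades an unpicklable raw_error_data to None instead of failing the whole deserialization, and DagInfo.to_json strips None-valued entries, so clients never see the new timestamp keys until they are populated. A sketch of the latter; DagStatus and its RUNNING member are assumed from context, as they do not appear in this diff:

    from maxframe.protocol import DagInfo, DagStatus  # DagStatus assumed

    info = DagInfo(session_id="s1", dag_id="d1", status=DagStatus.RUNNING, progress=0.5)
    payload = info.to_json()
    # start_timestamp / end_timestamp default to None and are dropped
    assert "start_timestamp" not in payload
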
maxframe/serialization/exception.py CHANGED
@@ -16,13 +16,14 @@ import logging
 import traceback
 from typing import Dict, List
 
+from ..errors import MaxFrameError
 from ..lib import wrapped_pickle as pickle
 from .core import Serializer, buffered, pickle_buffers, unpickle_buffers
 
 logger = logging.getLogger(__name__)
 
 
-class RemoteException(Exception):
+class RemoteException(MaxFrameError):
     def __init__(
         self, messages: List[str], tracebacks: List[List[str]], buffers: List[bytes]
     ):
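
Since RemoteException now derives from MaxFrameError (defined in the new maxframe/errors.py above), client code can catch framework errors with a single except clause. A hedged sketch; df.execute() stands in for any client call that may surface a server-side failure:

    import logging

    from maxframe.errors import MaxFrameError

    logger = logging.getLogger(__name__)

    def run_safely(df):
        try:
            return df.execute()  # hypothetical client call
        except MaxFrameError:
            # catches RemoteException and other framework errors alike
            logger.exception("MaxFrame job failed")
            raise
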
maxframe/serialization/serializables/core.py CHANGED
@@ -112,6 +112,7 @@ class Serializable(metaclass=SerializableMeta):
     __slots__ = ("__weakref__",)
 
     _cache_primitive_serial = False
+    _ignore_non_existing_keys = False
 
     _FIELDS: Dict[str, Field]
     _FIELD_ORDER: List[str]
@@ -128,7 +129,11 @@ class Serializable(metaclass=SerializableMeta):
         else:
             values = kwargs
         for k, v in values.items():
-            fields[k].set(self, v)
+            try:
+                fields[k].set(self, v)
+            except KeyError:
+                if not self._ignore_non_existing_keys:
+                    raise
 
     def __on_deserialize__(self):
         pass
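
This is the mechanism behind the protocol changes above: field assignment now tolerates unknown keys when a subclass opts in, so payloads produced by a newer or older peer no longer raise KeyError. A minimal sketch using only what the hunks show; import paths are assumed:

    from maxframe.serialization.serializables import Serializable, StringField

    class TolerantInfo(Serializable):
        _ignore_non_existing_keys = True  # as JsonSerializable now sets
        name: str = StringField("name", default=None)

    # "stale_field" matches no declared Field and is silently dropped
    info = TolerantInfo(name="x", stale_field=1)
    assert info.name == "x"
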
maxframe/serialization/serializables/field.py CHANGED
@@ -507,12 +507,14 @@ class ReferenceField(Field):
         tag: str,
         reference_type: Union[str, Type] = None,
         default: Any = no_default,
+        default_factory: Optional[Callable] = None,
         on_serialize: Callable[[Any], Any] = None,
         on_deserialize: Callable[[Any], Any] = None,
     ):
         super().__init__(
             tag,
             default=default,
+            default_factory=default_factory,
             on_serialize=on_serialize,
             on_deserialize=on_deserialize,
         )
maxframe/tensor/core.py CHANGED
@@ -43,7 +43,7 @@ from ..serialization.serializables import (
     StringField,
     TupleField,
 )
-from ..utils import on_deserialize_shape, on_serialize_shape
+from ..utils import on_deserialize_shape, on_serialize_shape, skip_na_call
 from .utils import fetch_corner_data, get_chunk_slices
 
 logger = logging.getLogger(__name__)
@@ -181,8 +181,8 @@ class TensorData(HasShapeTileableData, _ExecuteAndFetchMixin):
     _chunks = ListField(
         "chunks",
         FieldTypes.reference(TensorChunkData),
-        on_serialize=lambda x: [it.data for it in x] if x is not None else x,
-        on_deserialize=lambda x: [TensorChunk(it) for it in x] if x is not None else x,
+        on_serialize=skip_na_call(lambda x: [it.data for it in x]),
+        on_deserialize=skip_na_call(lambda x: [TensorChunk(it) for it in x]),
     )
 
     def __init__(
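
skip_na_call comes from maxframe/utils.py (+21 lines, not shown in this diff) and factors out the repeated "... if x is not None else x" guard. Judging from the before/after of this hunk, a functionally equivalent sketch would be:

    def skip_na_call(func):
        """Wrap func so that None passes through untouched."""
        def wrapper(x):
            return func(x) if x is not None else x
        return wrapper

    # equivalent to the replaced lambda:
    # on_serialize=skip_na_call(lambda x: [it.data for it in x])
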
maxframe/tests/test_codegen.py ADDED
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from typing import List, Tuple
+
+# generate unit tests with pytest
+import pytest
+
+from maxframe.codegen import UserCodeMixin
+from maxframe.lib import wrapped_pickle
+from maxframe.serialization.core import PickleContainer
+
+
+@pytest.mark.parametrize(
+    "input_obj, expected_output",
+    [
+        (None, "None"),
+        (10, "10"),
+        (3.14, "3.14"),
+        (True, "True"),
+        (False, "False"),
+        (b"hello", "base64.b64decode(b'aGVsbG8=')"),
+        ("hello", "'hello'"),
+        ([1, 2, 3], "[1, 2, 3]"),
+        ({"a": 1, "b": 2}, "{'a': 1, 'b': 2}"),
+        ((1, 2, 3), "(1, 2, 3)"),
+        ((1,), "(1,)"),
+        ((), "()"),
+        ({1, 2, 3}, "{1, 2, 3}"),
+        (set(), "set()"),
+    ],
+)
+def test_obj_to_python_expr(input_obj, expected_output):
+    assert UserCodeMixin.obj_to_python_expr(input_obj) == expected_output
+
+
+def test_obj_to_python_expr_custom_object():
+    class CustomClass:
+        def __init__(self, a: int, b: List[int], c: Tuple[int, int]):
+            self.a = a
+            self.b = b
+            self.c = c
+
+    custom_obj = CustomClass(1, [2, 3], (4, 5))
+    pickle_data = wrapped_pickle.dumps(custom_obj)
+    pickle_str = base64.b64encode(pickle_data)
+    custom_obj_pickle_container = PickleContainer([pickle_data])
+
+    # class objects are not supported currently
+    with pytest.raises(ValueError):
+        UserCodeMixin.obj_to_python_expr(custom_obj)
+
+    assert (
+        UserCodeMixin.obj_to_python_expr(custom_obj_pickle_container)
+        == f"cloudpickle.loads(base64.b64decode({pickle_str}), buffers=[])"
+    )
maxframe/tests/test_protocol.py CHANGED
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import json
 import time
 
 import pytest
@@ -29,28 +31,32 @@ from ..serialization import RemoteException
 from ..utils import deserialize_serializable, serialize_serializable
 
 
+def _json_round_trip(json_data: dict) -> dict:
+    return json.loads(json.dumps(json_data))
+
+
 def test_result_info_json_serialize():
-    ri = ResultInfo.from_json(ResultInfo().to_json())
+    ri = ResultInfo.from_json(_json_round_trip(ResultInfo().to_json()))
     assert type(ri) is ResultInfo
 
     ri = ODPSTableResultInfo(
         full_table_name="table_name", partition_specs=["pt=partition"]
     )
-    deserial_ri = ResultInfo.from_json(ri.to_json())
+    deserial_ri = ResultInfo.from_json(_json_round_trip(ri.to_json()))
     assert type(ri) is ODPSTableResultInfo
     assert ri.result_type == deserial_ri.result_type
     assert ri.full_table_name == deserial_ri.full_table_name
     assert ri.partition_specs == deserial_ri.partition_specs
 
     ri = ODPSTableResultInfo(full_table_name="table_name")
-    deserial_ri = ResultInfo.from_json(ri.to_json())
+    deserial_ri = ResultInfo.from_json(_json_round_trip(ri.to_json()))
     assert type(ri) is ODPSTableResultInfo
     assert ri.result_type == deserial_ri.result_type
     assert ri.full_table_name == deserial_ri.full_table_name
     assert ri.partition_specs == deserial_ri.partition_specs
 
     ri = ODPSVolumeResultInfo(volume_name="vol_name", volume_path="vol_path")
-    deserial_ri = ResultInfo.from_json(ri.to_json())
+    deserial_ri = ResultInfo.from_json(_json_round_trip(ri.to_json()))
     assert type(ri) is ODPSVolumeResultInfo
     assert ri.result_type == deserial_ri.result_type
     assert ri.volume_name == deserial_ri.volume_name
@@ -63,7 +69,7 @@ def test_error_info_json_serialize():
     except ValueError as ex:
         err_info = ErrorInfo.from_exception(ex)
 
-    deserial_err_info = ErrorInfo.from_json(err_info.to_json())
+    deserial_err_info = ErrorInfo.from_json(_json_round_trip(err_info.to_json()))
     assert deserial_err_info.error_messages == err_info.error_messages
     assert isinstance(deserial_err_info.raw_error_data, ValueError)
 
@@ -73,7 +79,7 @@ def test_error_info_json_serialize():
     with pytest.raises(RemoteException):
         mf_err_info.reraise()
 
-    deserial_err_info = ErrorInfo.from_json(mf_err_info.to_json())
+    deserial_err_info = ErrorInfo.from_json(_json_round_trip(mf_err_info.to_json()))
     assert isinstance(deserial_err_info.raw_error_data, ValueError)
     with pytest.raises(ValueError):
         deserial_err_info.reraise()
@@ -94,7 +100,9 @@ def test_dag_info_json_serialize():
         },
         error_info=err_info,
     )
-    deserial_info = DagInfo.from_json(info.to_json())
+    json_info = info.to_json()
+    json_info["non_existing_field"] = "non_existing"
+    deserial_info = DagInfo.from_json(_json_round_trip(json_info))
     assert deserial_info.session_id == info.session_id
     assert deserial_info.dag_id == info.dag_id
     assert deserial_info.status == info.status
@@ -121,7 +129,7 @@ def test_session_info_json_serialize():
         idle_timestamp=None,
         dag_infos={"test_dag_id": dag_info},
    )
-    deserial_info = SessionInfo.from_json(info.to_json())
+    deserial_info = SessionInfo.from_json(_json_round_trip(info.to_json()))
     assert deserial_info.session_id == info.session_id
     assert deserial_info.settings == info.settings
     assert deserial_info.start_timestamp == info.start_timestamp
maxframe/tests/utils.py CHANGED
@@ -104,6 +104,7 @@ def run_app_in_thread(app_func):
     q = queue.Queue()
     exit_event = asyncio.Event(loop=app_loop)
     app_thread = Thread(
+        name="TestAppThread",
         target=app_thread_func,
         args=(app_loop, q, exit_event, args, kwargs),
     )