maxframe 0.1.0b2__cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl → 0.1.0b3__cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


maxframe/codegen.py CHANGED
@@ -17,7 +17,7 @@ import base64
 import dataclasses
 import logging
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

 from odps.types import OdpsSchema
 from odps.utils import camel_to_underline
@@ -30,6 +30,7 @@ from .odpsio import build_dataframe_table_meta
 from .odpsio.schema import pandas_to_odps_schema
 from .protocol import DataFrameTableMeta, ResultInfo
 from .serialization import PickleContainer
+from .serialization.serializables import Serializable, StringField
 from .typing_ import PandasObjectTypes
 from .udf import MarkedFunction

@@ -48,8 +49,11 @@ class CodeGenResult:
     constants: Dict[str, Any]


-class AbstractUDF(abc.ABC):
-    _session_id: str
+class AbstractUDF(Serializable):
+    _session_id: str = StringField("session_id")
+
+    def __init__(self, session_id: Optional[str] = None, **kw):
+        super().__init__(_session_id=session_id, **kw)

     @property
     def name(self) -> str:
@@ -74,7 +78,66 @@ class AbstractUDF(abc.ABC):

 class UserCodeMixin:
     @classmethod
-    def generate_pickled_codes(cls, code_to_pickle: Any) -> List[str]:
+    def obj_to_python_expr(cls, obj: Any = None) -> str:
+        """
+        Parameters
+        ----------
+        obj
+            The object to convert to python expr.
+        Returns
+        -------
+        str :
+            The str type content equals to the object when use in the python code directly.
+        """
+        if obj is None:
+            return "None"
+
+        if isinstance(obj, (int, float)):
+            return repr(obj)
+
+        if isinstance(obj, bool):
+            return "True" if obj else "False"
+
+        if isinstance(obj, bytes):
+            base64_bytes = base64.b64encode(obj)
+            return f"base64.b64decode({base64_bytes})"
+
+        if isinstance(obj, str):
+            return repr(obj)
+
+        if isinstance(obj, list):
+            return (
+                f"[{', '.join([cls.obj_to_python_expr(element) for element in obj])}]"
+            )
+
+        if isinstance(obj, dict):
+            items = (
+                f"{repr(key)}: {cls.obj_to_python_expr(value)}"
+                for key, value in obj.items()
+            )
+            return f"{{{', '.join(items)}}}"
+
+        if isinstance(obj, tuple):
+            return f"({', '.join([cls.obj_to_python_expr(sub_obj) for sub_obj in obj])}{',' if len(obj) == 1 else ''})"
+
+        if isinstance(obj, set):
+            return (
+                f"{{{', '.join([cls.obj_to_python_expr(sub_obj) for sub_obj in obj])}}}"
+                if obj
+                else "set()"
+            )
+
+        if isinstance(obj, PickleContainer):
+            return UserCodeMixin.generate_pickled_codes(obj, None)
+
+        raise ValueError(f"not support arg type {type(obj)}")
+
+    @classmethod
+    def generate_pickled_codes(
+        cls,
+        code_to_pickle: Any,
+        unpicked_data_var_name: Union[str, None] = "pickled_data",
+    ) -> str:
         """
         Generate pickled codes. The final pickled variable is called 'pickled_data'.

@@ -82,20 +145,20 @@ class UserCodeMixin:
         ----------
         code_to_pickle: Any
             The code to be pickled.
+        unpicked_data_var_name: str
+            The variables in code used to hold the loads object from the cloudpickle

         Returns
         -------
-        List[str] :
-            The code snippets of pickling, the final variable is called 'pickled_data'.
+        str :
+            The code snippets of pickling, the final variable is called 'pickled_data' by default.
         """
         pickled, buffers = cls.dump_pickled_data(code_to_pickle)
-        pickled = base64.b64encode(pickled)
-        buffers = [base64.b64encode(b) for b in buffers]
-        buffers_str = ", ".join(f"base64.b64decode(b'{b.decode()}')" for b in buffers)
-        return [
-            f"base64_data = base64.b64decode(b'{pickled.decode()}')",
-            f"pickled_data = cloudpickle.loads(base64_data, buffers=[{buffers_str}])",
-        ]
+        pickle_loads_expr = f"cloudpickle.loads({cls.obj_to_python_expr(pickled)}, buffers={cls.obj_to_python_expr(buffers)})"
+        if unpicked_data_var_name:
+            return f"{unpicked_data_var_name} = {pickle_loads_expr}"
+
+        return pickle_loads_expr

     @staticmethod
     def dump_pickled_data(
@@ -114,8 +177,9 @@ class UserCodeMixin:


 class BigDagCodeContext(metaclass=abc.ABCMeta):
-    def __init__(self, session_id: str = None):
+    def __init__(self, session_id: str = None, subdag_id: str = None):
         self._session_id = session_id
+        self._subdag_id = subdag_id
         self._tileable_key_to_variables = dict()
         self.constants = dict()
         self._data_table_meta_cache = dict()
@@ -142,10 +206,14 @@ class BigDagCodeContext(metaclass=abc.ABCMeta):
         except KeyError:
             var_name = self._tileable_key_to_variables[
                 tileable.key
-            ] = f"var_{self._next_var_id}"
-            self._next_var_id += 1
+            ] = self.next_var_name()
         return var_name

+    def next_var_name(self) -> str:
+        var_name = f"var_{self._next_var_id}"
+        self._next_var_id += 1
+        return var_name
+
     def get_odps_schema(
         self, data: PandasObjectTypes, unknown_as_string: bool = False
     ) -> OdpsSchema:
@@ -275,9 +343,10 @@ class BigDagCodeGenerator(metaclass=abc.ABCMeta):
     engine_priority: int = 0
     _extension_loaded = False

-    def __init__(self, session_id: str):
+    def __init__(self, session_id: str, subdag_id: str = None):
         self._session_id = session_id
-        self._context = self._init_context(session_id)
+        self._subdag_id = subdag_id
+        self._context = self._init_context(session_id, subdag_id)

     @classmethod
     def _load_engine_extensions(cls):
@@ -307,7 +376,7 @@ class BigDagCodeGenerator(metaclass=abc.ABCMeta):
         raise NotImplementedError

     @abc.abstractmethod
-    def _init_context(self, session_id: str) -> BigDagCodeContext:
+    def _init_context(self, session_id: str, subdag_id: str) -> BigDagCodeContext:
         raise NotImplementedError

     def _generate_comments(
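
A rough sketch of how the reworked codegen helpers behave, based on the new code above and the unit tests added later in this diff (the lambda payload below is only illustrative):

    from maxframe.codegen import UserCodeMixin

    # Literal containers are rendered as Python source text.
    UserCodeMixin.obj_to_python_expr({"a": 1, "b": [2, 3]})  # "{'a': 1, 'b': [2, 3]}"
    UserCodeMixin.obj_to_python_expr((1,))                   # "(1,)"
    UserCodeMixin.obj_to_python_expr(set())                  # "set()"

    # generate_pickled_codes now returns a single string instead of a list of
    # snippets: with the default variable name it emits an assignment such as
    # "pickled_data = cloudpickle.loads(base64.b64decode(...), buffers=[...])",
    # and with unpicked_data_var_name=None it emits the bare loads expression.
    snippet = UserCodeMixin.generate_pickled_codes(lambda x: x + 1)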
maxframe/config/config.py CHANGED
@@ -340,6 +340,12 @@ default_options.register_option(
     validator=is_integer,
     remote=True,
 )
+default_options.register_option(
+    "session.subinstance_priority",
+    None,
+    validator=any_validator(is_null, is_integer),
+    remote=True,
+)

 default_options.register_option("warn_duplicated_execution", False, validator=is_bool)
 default_options.register_option("dataframe.use_arrow_dtype", True, validator=is_bool)
@@ -66,6 +66,7 @@ class DecrefRunner:
         if self._decref_thread:  # pragma: no branch
             self._queue.put_nowait((None, None, None))
             self._decref_thread.join(1)
+            self._decref_thread = None

     def put(self, key: str, session_ref: ref):
         if self._decref_thread is None:
@@ -15,6 +15,7 @@
 from typing import Any, Dict

 from ...serialization.serializables import FieldTypes, ListField
+from ...utils import skip_na_call
 from .chunks import Chunk, ChunkData
 from .core import Entity
 from .executable import _ToObjectMixin
@@ -62,8 +63,8 @@ class ObjectData(TileableData, _ToObjectMixin):
     _chunks = ListField(
         "chunks",
         FieldTypes.reference(ObjectChunkData),
-        on_serialize=lambda x: [it.data for it in x] if x is not None else x,
-        on_deserialize=lambda x: [ObjectChunk(it) for it in x] if x is not None else x,
+        on_serialize=skip_na_call(lambda x: [it.data for it in x]),
+        on_deserialize=skip_na_call(lambda x: [ObjectChunk(it) for it in x]),
     )

     def __init__(self, op=None, nsplits=None, **kw):
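
skip_na_call itself is not shown in this diff; judging from the lambdas it replaces, it presumably wraps a callable so that None passes through untouched. A minimal sketch of that presumed helper:

    # Presumed shape of maxframe.utils.skip_na_call, inferred from the code it
    # replaces above; the real implementation may differ.
    def skip_na_call(func):
        def wrapper(value):
            return func(value) if value is not None else value

        return wrapper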
@@ -39,6 +39,7 @@ from .datasource.read_odps_query import read_odps_query
 from .datasource.read_odps_table import read_odps_table
 from .datasource.read_parquet import read_parquet
 from .datastore.to_odps import to_odps_table
+from .groupby import NamedAgg
 from .initializer import DataFrame, Index, Series, read_pandas
 from .merge import concat, merge
 from .misc.cut import cut
@@ -52,7 +53,7 @@ from .reduction import CustomReduction, unique
 from .tseries.to_datetime import to_datetime

 try:
-    from pandas import NA, NamedAgg, Timestamp
+    from pandas import NA, Timestamp
 except ImportError:  # pragma: no cover
     pass

@@ -46,7 +46,7 @@ _EXPLAIN_TASK_SCHEMA_REGEX = re.compile(
     r"In Task ([^:]+)[\S\s]+FS: output: ([^\n #]+)[\s\S]+schema:\s+([\S\s]+)$",
     re.MULTILINE,
 )
-_EXPLAIN_COLUMN_REGEX = re.compile(r"([^ ]+) \(([^)]+)\)(?:| AS ([^ ]+))(?:\n|$)")
+_EXPLAIN_COLUMN_REGEX = re.compile(r"([^\(]+) \(([^)]+)\)(?:| AS ([^ ]+))(?:\n|$)")


 @dataclasses.dataclass
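
The first capture group is widened from [^ ]+ to [^\(]+, so column names containing spaces (such as the "unnamed: 0" column asserted in the new test below) are no longer truncated. A small standalone illustration:

    import re

    old = re.compile(r"([^ ]+) \(([^)]+)\)(?:| AS ([^ ]+))(?:\n|$)")
    new = re.compile(r"([^\(]+) \(([^)]+)\)(?:| AS ([^ ]+))(?:\n|$)")

    line = "unnamed: 0 (bigint)"
    old.findall(line)  # [('0', 'bigint', '')] -- the name is cut at the space
    new.findall(line)  # [('unnamed: 0', 'bigint', '')]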
@@ -69,7 +69,7 @@ class DataFrameReadODPSTable(
         return getattr(self, "partition_spec", None)

     def get_columns(self):
-        return self.columns
+        return self.columns or list(self.dtypes.index)

     def set_pruned_columns(self, columns, *, keep_order=None):  # pragma: no cover
         self.columns = columns
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from collections import OrderedDict

 import numpy as np
@@ -33,6 +34,7 @@ from ..from_tensor import (
 )
 from ..index import from_pandas as from_pandas_index
 from ..index import from_tileable
+from ..read_odps_query import ColumnSchema, _resolve_task_sector
 from ..series import from_pandas as from_pandas_series

 ray = lazy_import("ray")
@@ -228,6 +230,7 @@ def test_from_odps_table():
     assert df.op.table_name == test_table.full_table_name
     assert df.index_value.name is None
     assert isinstance(df.index_value.value, IndexValue.RangeIndex)
+    assert df.op.get_columns() == ["col1", "col2", "col3"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series(
@@ -247,6 +250,7 @@ def test_from_odps_table():
     assert df.op.table_name == test_table.full_table_name
     assert df.index_value.name is None
     assert isinstance(df.index_value.value, IndexValue.RangeIndex)
+    assert df.op.get_columns() == ["col1", "col2"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series([np.dtype("O"), np.dtype("int64")], index=["col1", "col2"]),
@@ -257,6 +261,7 @@ def test_from_odps_table():
     assert df.index_value.name == "col1"
     assert isinstance(df.index_value.value, IndexValue.Index)
     assert df.index.dtype == np.dtype("O")
+    assert df.op.get_columns() == ["col2", "col3"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series([np.dtype("int64"), np.dtype("float64")], index=["col2", "col3"]),
@@ -267,6 +272,7 @@ def test_from_odps_table():

     df = read_odps_table(test_parted_table, append_partitions=True)
     assert df.op.append_partitions is True
+    assert df.op.get_columns() == ["col1", "col2", "col3", "pt"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series(
@@ -280,6 +286,7 @@ def test_from_odps_table():
     )
     assert df.op.append_partitions is True
     assert df.op.partitions == ["pt=20240103"]
+    assert df.op.get_columns() == ["col1", "col2", "pt"]
     pd.testing.assert_series_equal(
         df.dtypes,
         pd.Series(
@@ -377,3 +384,18 @@ def test_date_range():
     assert dr.index_value.is_unique == expected.is_unique
     assert dr.index_value.is_monotonic_increasing == expected.is_monotonic_increasing
     assert dr.name == expected.name
+
+
+def test_resolve_task_sector():
+    input_path = os.path.join(os.path.dirname(__file__), "test-data", "task-input.txt")
+    with open(input_path, "r") as f:
+        sector = f.read()
+    actual_sector = _resolve_task_sector("job0", sector)
+
+    assert actual_sector.job_name == "job0"
+    assert actual_sector.task_name == "M1"
+    assert actual_sector.output_target == "Screen"
+    assert len(actual_sector.schema) == 78
+    assert actual_sector.schema[0] == ColumnSchema("unnamed: 0", "bigint", "")
+    assert actual_sector.schema[1] == ColumnSchema("id", "bigint", "id_alias")
+    assert actual_sector.schema[2] == ColumnSchema("listing_url", "string", "")
@@ -14,6 +14,7 @@

 # noinspection PyUnresolvedReferences
 from ..core import DataFrameGroupBy, GroupBy, SeriesGroupBy
+from .core import NamedAgg


 def _install():
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import namedtuple
+
 import pandas as pd

 from ... import opcodes
@@ -30,6 +32,9 @@ _GROUP_KEYS_NO_DEFAULT = pd_release_version >= (1, 5, 0)
 _default_group_keys = no_default if _GROUP_KEYS_NO_DEFAULT else True


+NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
+
+
 class DataFrameGroupByOperator(MapReduceOperator, DataFrameOperatorMixin):
     _op_type_ = opcodes.GROUPBY

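
With the pandas import removed, NamedAgg is defined locally as a plain namedtuple, so the pandas-style named-aggregation spelling stays available regardless of the installed pandas version. A hedged usage sketch, assuming MaxFrame's groupby aggregation accepts named-aggregation keywords as the new export in maxframe/dataframe/__init__.py suggests:

    import maxframe.dataframe as md

    df = md.DataFrame({"key": ["a", "a", "b"], "value": [1, 2, 3]})
    # NamedAgg is the namedtuple defined above, so a plain ("value", "sum")
    # tuple would work in the same position.
    result = df.groupby("key").agg(total=md.NamedAgg(column="value", aggfunc="sum"))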
@@ -14,11 +14,14 @@

 import asyncio
 import atexit
+import itertools
 import threading
 from typing import Dict, Optional


 class Isolation:
+    _counter = itertools.count().__next__
+
     loop: asyncio.AbstractEventLoop
     _stopped: Optional[asyncio.Event]
     _thread: Optional[threading.Thread]
@@ -38,7 +41,9 @@ class Isolation:

     def start(self):
         if self._threaded:
-            self._thread = thread = threading.Thread(target=self._run)
+            self._thread = thread = threading.Thread(
+                name=f"IsolationThread-{self._counter()}", target=self._run
+            )
             thread.daemon = True
             thread.start()
             self._thread_ident = thread.ident
maxframe/protocol.py CHANGED
@@ -46,6 +46,8 @@ BodyType = TypeVar("BodyType", bound="Serializable")


 class JsonSerializable(Serializable):
+    _ignore_non_existing_keys = True
+
     @classmethod
     def from_json(cls, serialized: dict) -> "JsonSerializable":
         raise NotImplementedError
@@ -245,6 +247,8 @@ class DagInfo(JsonSerializable):
         default_factory=dict,
     )
     error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
+    start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
+    end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)

     @classmethod
     def from_json(cls, serialized: dict) -> "DagInfo":
@@ -265,7 +269,10 @@ class DagInfo(JsonSerializable):
             "dag_id": self.dag_id,
             "status": self.status.value,
             "progress": self.progress,
+            "start_timestamp": self.start_timestamp,
+            "end_timestamp": self.end_timestamp,
         }
+        ret = {k: v for k, v in ret.items() if v is not None}
         if self.tileable_to_result_infos:
             ret["tileable_to_result_infos"] = {
                 k: v.to_json() for k, v in self.tileable_to_result_infos.items()
@@ -112,6 +112,7 @@ class Serializable(metaclass=SerializableMeta):
     __slots__ = ("__weakref__",)

     _cache_primitive_serial = False
+    _ignore_non_existing_keys = False

     _FIELDS: Dict[str, Field]
     _FIELD_ORDER: List[str]
@@ -128,7 +129,11 @@ class Serializable(metaclass=SerializableMeta):
         else:
             values = kwargs
         for k, v in values.items():
-            fields[k].set(self, v)
+            try:
+                fields[k].set(self, v)
+            except KeyError:
+                if not self._ignore_non_existing_keys:
+                    raise

     def __on_deserialize__(self):
         pass
@@ -507,12 +507,14 @@ class ReferenceField(Field):
         tag: str,
         reference_type: Union[str, Type] = None,
         default: Any = no_default,
+        default_factory: Optional[Callable] = None,
         on_serialize: Callable[[Any], Any] = None,
         on_deserialize: Callable[[Any], Any] = None,
     ):
         super().__init__(
             tag,
             default=default,
+            default_factory=default_factory,
             on_serialize=on_serialize,
             on_deserialize=on_deserialize,
         )
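
Setting _ignore_non_existing_keys to True makes the constructor swallow keys that do not map to declared fields instead of raising KeyError; JsonSerializable opts in above, which is what the protocol test further down relies on. A toy illustration (the class below is not part of the package):

    from maxframe.serialization.serializables import Serializable, StringField

    class Tolerant(Serializable):
        _ignore_non_existing_keys = True

        name: str = StringField("name")

    # The unknown keyword is silently dropped; with the default (False) it
    # would still raise KeyError.
    obj = Tolerant(name="x", not_a_field="ignored")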
maxframe/tensor/core.py CHANGED
@@ -43,7 +43,7 @@ from ..serialization.serializables import (
     StringField,
     TupleField,
 )
-from ..utils import on_deserialize_shape, on_serialize_shape
+from ..utils import on_deserialize_shape, on_serialize_shape, skip_na_call
 from .utils import fetch_corner_data, get_chunk_slices

 logger = logging.getLogger(__name__)
@@ -181,8 +181,8 @@ class TensorData(HasShapeTileableData, _ExecuteAndFetchMixin):
     _chunks = ListField(
         "chunks",
         FieldTypes.reference(TensorChunkData),
-        on_serialize=lambda x: [it.data for it in x] if x is not None else x,
-        on_deserialize=lambda x: [TensorChunk(it) for it in x] if x is not None else x,
+        on_serialize=skip_na_call(lambda x: [it.data for it in x]),
+        on_deserialize=skip_na_call(lambda x: [TensorChunk(it) for it in x]),
     )

     def __init__(
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from typing import List, Tuple
+
+# Unit tests generated with pytest
+import pytest
+
+from maxframe.codegen import UserCodeMixin
+from maxframe.lib import wrapped_pickle
+from maxframe.serialization.core import PickleContainer
+
+
+@pytest.mark.parametrize(
+    "input_obj, expected_output",
+    [
+        (None, "None"),
+        (10, "10"),
+        (3.14, "3.14"),
+        (True, "True"),
+        (False, "False"),
+        (b"hello", "base64.b64decode(b'aGVsbG8=')"),
+        ("hello", "'hello'"),
+        ([1, 2, 3], "[1, 2, 3]"),
+        ({"a": 1, "b": 2}, "{'a': 1, 'b': 2}"),
+        ((1, 2, 3), "(1, 2, 3)"),
+        ((1,), "(1,)"),
+        ((), "()"),
+        ({1, 2, 3}, "{1, 2, 3}"),
+        (set(), "set()"),
+    ],
+)
+def test_obj_to_python_expr(input_obj, expected_output):
+    assert UserCodeMixin.obj_to_python_expr(input_obj) == expected_output
+
+
+def test_obj_to_python_expr_custom_object():
+    class CustomClass:
+        def __init__(self, a: int, b: List[int], c: Tuple[int, int]):
+            self.a = a
+            self.b = b
+            self.c = c
+
+    custom_obj = CustomClass(1, [2, 3], (4, 5))
+    pickle_data = wrapped_pickle.dumps(custom_obj)
+    pickle_str = base64.b64encode(pickle_data)
+    custom_obj_pickle_container = PickleContainer([pickle_data])
+
+    # passing a class instance directly is not supported currently
+    with pytest.raises(ValueError):
+        UserCodeMixin.obj_to_python_expr(custom_obj)
+
+    assert (
+        UserCodeMixin.obj_to_python_expr(custom_obj_pickle_container)
+        == f"cloudpickle.loads(base64.b64decode({pickle_str}), buffers=[])"
+    )
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import json
 import time

 import pytest
@@ -29,28 +31,32 @@ from ..serialization import RemoteException
 from ..utils import deserialize_serializable, serialize_serializable


+def _json_round_trip(json_data: dict) -> dict:
+    return json.loads(json.dumps(json_data))
+
+
 def test_result_info_json_serialize():
-    ri = ResultInfo.from_json(ResultInfo().to_json())
+    ri = ResultInfo.from_json(_json_round_trip(ResultInfo().to_json()))
     assert type(ri) is ResultInfo

     ri = ODPSTableResultInfo(
         full_table_name="table_name", partition_specs=["pt=partition"]
     )
-    deserial_ri = ResultInfo.from_json(ri.to_json())
+    deserial_ri = ResultInfo.from_json(_json_round_trip(ri.to_json()))
     assert type(ri) is ODPSTableResultInfo
     assert ri.result_type == deserial_ri.result_type
     assert ri.full_table_name == deserial_ri.full_table_name
     assert ri.partition_specs == deserial_ri.partition_specs

     ri = ODPSTableResultInfo(full_table_name="table_name")
-    deserial_ri = ResultInfo.from_json(ri.to_json())
+    deserial_ri = ResultInfo.from_json(_json_round_trip(ri.to_json()))
     assert type(ri) is ODPSTableResultInfo
     assert ri.result_type == deserial_ri.result_type
     assert ri.full_table_name == deserial_ri.full_table_name
     assert ri.partition_specs == deserial_ri.partition_specs

     ri = ODPSVolumeResultInfo(volume_name="vol_name", volume_path="vol_path")
-    deserial_ri = ResultInfo.from_json(ri.to_json())
+    deserial_ri = ResultInfo.from_json(_json_round_trip(ri.to_json()))
     assert type(ri) is ODPSVolumeResultInfo
     assert ri.result_type == deserial_ri.result_type
     assert ri.volume_name == deserial_ri.volume_name
@@ -63,7 +69,7 @@ def test_error_info_json_serialize():
     except ValueError as ex:
         err_info = ErrorInfo.from_exception(ex)

-    deserial_err_info = ErrorInfo.from_json(err_info.to_json())
+    deserial_err_info = ErrorInfo.from_json(_json_round_trip(err_info.to_json()))
     assert deserial_err_info.error_messages == err_info.error_messages
     assert isinstance(deserial_err_info.raw_error_data, ValueError)

@@ -73,7 +79,7 @@ def test_error_info_json_serialize():
     with pytest.raises(RemoteException):
         mf_err_info.reraise()

-    deserial_err_info = ErrorInfo.from_json(mf_err_info.to_json())
+    deserial_err_info = ErrorInfo.from_json(_json_round_trip(mf_err_info.to_json()))
     assert isinstance(deserial_err_info.raw_error_data, ValueError)
     with pytest.raises(ValueError):
         deserial_err_info.reraise()
@@ -94,7 +100,9 @@ def test_dag_info_json_serialize():
         },
         error_info=err_info,
     )
-    deserial_info = DagInfo.from_json(info.to_json())
+    json_info = info.to_json()
+    json_info["non_existing_field"] = "non_existing"
+    deserial_info = DagInfo.from_json(_json_round_trip(json_info))
     assert deserial_info.session_id == info.session_id
     assert deserial_info.dag_id == info.dag_id
     assert deserial_info.status == info.status
@@ -121,7 +129,7 @@ def test_session_info_json_serialize():
         idle_timestamp=None,
         dag_infos={"test_dag_id": dag_info},
     )
-    deserial_info = SessionInfo.from_json(info.to_json())
+    deserial_info = SessionInfo.from_json(_json_round_trip(info.to_json()))
     assert deserial_info.session_id == info.session_id
     assert deserial_info.settings == info.settings
     assert deserial_info.start_timestamp == info.start_timestamp
maxframe/tests/utils.py CHANGED
@@ -104,6 +104,7 @@ def run_app_in_thread(app_func):
         q = queue.Queue()
         exit_event = asyncio.Event(loop=app_loop)
         app_thread = Thread(
+            name="TestAppThread",
             target=app_thread_func,
             args=(app_loop, q, exit_event, args, kwargs),
         )