maxframe 0.1.0b4-cp39-cp39-win32.whl → 1.0.0rc1-cp39-cp39-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (81)
  1. maxframe/__init__.py +1 -0
  2. maxframe/_utils.cp39-win32.pyd +0 -0
  3. maxframe/codegen.py +56 -3
  4. maxframe/config/config.py +15 -1
  5. maxframe/core/__init__.py +0 -3
  6. maxframe/core/entity/__init__.py +1 -8
  7. maxframe/core/entity/objects.py +3 -45
  8. maxframe/core/graph/core.cp39-win32.pyd +0 -0
  9. maxframe/core/graph/core.pyx +4 -4
  10. maxframe/dataframe/__init__.py +1 -0
  11. maxframe/dataframe/core.py +30 -8
  12. maxframe/dataframe/datasource/read_odps_query.py +3 -1
  13. maxframe/dataframe/datasource/read_odps_table.py +3 -1
  14. maxframe/dataframe/datastore/tests/__init__.py +13 -0
  15. maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
  16. maxframe/dataframe/datastore/to_odps.py +21 -0
  17. maxframe/dataframe/indexing/align.py +1 -1
  18. maxframe/dataframe/misc/__init__.py +4 -0
  19. maxframe/dataframe/misc/apply.py +3 -1
  20. maxframe/dataframe/misc/case_when.py +141 -0
  21. maxframe/dataframe/misc/memory_usage.py +2 -2
  22. maxframe/dataframe/misc/pivot_table.py +262 -0
  23. maxframe/dataframe/misc/tests/test_misc.py +84 -0
  24. maxframe/dataframe/plotting/core.py +2 -2
  25. maxframe/dataframe/reduction/core.py +2 -1
  26. maxframe/dataframe/statistics/corr.py +3 -3
  27. maxframe/dataframe/utils.py +7 -0
  28. maxframe/errors.py +13 -0
  29. maxframe/extension.py +12 -0
  30. maxframe/learn/contrib/utils.py +52 -0
  31. maxframe/learn/contrib/xgboost/__init__.py +26 -0
  32. maxframe/learn/contrib/xgboost/classifier.py +86 -0
  33. maxframe/learn/contrib/xgboost/core.py +156 -0
  34. maxframe/learn/contrib/xgboost/dmatrix.py +150 -0
  35. maxframe/learn/contrib/xgboost/predict.py +138 -0
  36. maxframe/learn/contrib/xgboost/regressor.py +78 -0
  37. maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
  38. maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
  39. maxframe/learn/contrib/xgboost/train.py +121 -0
  40. maxframe/learn/utils/__init__.py +15 -0
  41. maxframe/learn/utils/core.py +29 -0
  42. maxframe/lib/mmh3.cp39-win32.pyd +0 -0
  43. maxframe/lib/mmh3.pyi +43 -0
  44. maxframe/lib/wrapped_pickle.py +2 -1
  45. maxframe/odpsio/arrow.py +2 -3
  46. maxframe/odpsio/tableio.py +22 -0
  47. maxframe/odpsio/tests/test_schema.py +16 -11
  48. maxframe/opcodes.py +3 -0
  49. maxframe/protocol.py +108 -10
  50. maxframe/serialization/core.cp39-win32.pyd +0 -0
  51. maxframe/serialization/core.pxd +3 -0
  52. maxframe/serialization/core.pyi +64 -0
  53. maxframe/serialization/core.pyx +54 -25
  54. maxframe/serialization/exception.py +1 -1
  55. maxframe/serialization/pandas.py +7 -2
  56. maxframe/serialization/serializables/core.py +119 -12
  57. maxframe/serialization/serializables/tests/test_serializable.py +46 -4
  58. maxframe/session.py +28 -0
  59. maxframe/tensor/__init__.py +1 -1
  60. maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
  61. maxframe/tensor/base/__init__.py +2 -0
  62. maxframe/tensor/base/atleast_1d.py +74 -0
  63. maxframe/tensor/base/unique.py +205 -0
  64. maxframe/tensor/datasource/array.py +4 -2
  65. maxframe/tensor/datasource/scalar.py +1 -1
  66. maxframe/tensor/reduction/count_nonzero.py +1 -1
  67. maxframe/tests/test_protocol.py +34 -0
  68. maxframe/tests/test_utils.py +0 -12
  69. maxframe/tests/utils.py +2 -2
  70. maxframe/udf.py +63 -3
  71. maxframe/utils.py +22 -13
  72. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/METADATA +3 -3
  73. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/RECORD +80 -61
  74. maxframe_client/__init__.py +0 -1
  75. maxframe_client/fetcher.py +65 -3
  76. maxframe_client/session/odps.py +74 -5
  77. maxframe_client/session/task.py +65 -71
  78. maxframe_client/tests/test_session.py +64 -1
  79. maxframe_client/clients/spe.py +0 -104
  80. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/WHEEL +0 -0
  81. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/top_level.txt +0 -0
maxframe/learn/contrib/xgboost/train.py ADDED
@@ -0,0 +1,121 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from collections import OrderedDict
+
+ from .... import opcodes as OperandDef
+ from ....core import OutputType
+ from ....core.operator.base import Operator
+ from ....core.operator.core import TileableOperatorMixin
+ from ....serialization.serializables import (
+     AnyField,
+     BoolField,
+     DictField,
+     FieldTypes,
+     FunctionField,
+     Int64Field,
+     KeyField,
+     ListField,
+ )
+ from .dmatrix import ToDMatrix, to_dmatrix
+
+ logger = logging.getLogger(__name__)
+
+
+ def _on_serialize_evals(evals_val):
+     if evals_val is None:
+         return None
+     return [list(x) for x in evals_val]
+
+
+ class XGBTrain(Operator, TileableOperatorMixin):
+     _op_type_ = OperandDef.XGBOOST_TRAIN
+
+     params = DictField("params", key_type=FieldTypes.string, default=None)
+     dtrain = KeyField("dtrain", default=None)
+     evals = ListField("evals", on_serialize=_on_serialize_evals, default=None)
+     obj = FunctionField("obj", default=None)
+     feval = FunctionField("feval", default=None)
+     maximize = BoolField("maximize", default=None)
+     early_stopping_rounds = Int64Field("early_stopping_rounds", default=None)
+     verbose_eval = AnyField("verbose_eval", default=None)
+     xgb_model = AnyField("xgb_model", default=None)
+     callbacks = ListField(
+         "callbacks", field_type=FunctionField.field_type, default=None
+     )
+     custom_metric = FunctionField("custom_metric", default=None)
+     num_boost_round = Int64Field("num_boost_round", default=10)
+     num_class = Int64Field("num_class", default=None)
+
+     # stored locally to hold the evals_result fetched from the remote side
+     evals_result: dict = None
+
+     def __init__(self, gpu=None, **kw):
+         super().__init__(gpu=gpu, **kw)
+         if self.output_types is None:
+             self.output_types = [OutputType.object]
+
+     def _set_inputs(self, inputs):
+         super()._set_inputs(inputs)
+         self.dtrain = self._inputs[0]
+         rest = self._inputs[1:]
+         if self.evals is not None:
+             evals_dict = OrderedDict(self.evals)
+             new_evals_dict = OrderedDict()
+             for new_key, val in zip(rest, evals_dict.values()):
+                 new_evals_dict[new_key] = val
+             self.evals = list(new_evals_dict.items())
+
+     def __call__(self):
+         inputs = [self.dtrain]
+         if self.evals is not None:
+             inputs.extend(e[0] for e in self.evals)
+         return self.new_tileable(inputs)
+
+
+ def train(params, dtrain, evals=None, evals_result=None, num_class=None, **kwargs):
+     """
+     Train an XGBoost model in the MaxFrame manner.
+
+     Parameters
+     ----------
+     Parameters are the same as `xgboost.train`.
+
+     Returns
+     -------
+     results: Booster
+     """
+
+     evals_result = evals_result or dict()
+     evals = evals or ()
+
+     processed_evals = []
+     if evals:
+         for eval_dmatrix, name in evals:
+             if not isinstance(name, str):
+                 raise TypeError("evals must be a list of pairs (DMatrix, string)")
+             if hasattr(eval_dmatrix, "op") and isinstance(eval_dmatrix.op, ToDMatrix):
+                 processed_evals.append((eval_dmatrix, name))
+             else:
+                 processed_evals.append((to_dmatrix(eval_dmatrix), name))
+
+     return XGBTrain(
+         params=params,
+         dtrain=dtrain,
+         evals=processed_evals,
+         evals_result=evals_result,
+         num_class=num_class,
+         **kwargs
+     )()
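
For orientation, a minimal usage sketch of the `train` entry point above. It assumes `train` is re-exported from `maxframe.learn.contrib.xgboost` (consistent with the new `__init__.py` in this release), that `to_dmatrix` accepts a `label=` keyword as in the Mars API, and that a MaxFrame session is already attached; table and column names are made up.

    import maxframe.dataframe as md
    from maxframe.learn.contrib.xgboost import train
    from maxframe.learn.contrib.xgboost.dmatrix import to_dmatrix

    df = md.read_odps_table("train_table")                    # hypothetical table
    dtrain = to_dmatrix(df[["f0", "f1"]], label=df["label"])  # wraps the input in a ToDMatrix op

    evals_result = {}
    booster = train(
        {"eta": 0.1, "max_depth": 4},
        dtrain,
        evals=[(dtrain, "train")],   # non-DMatrix eval entries are converted via to_dmatrix
        evals_result=evals_result,
        num_boost_round=10,
    )
    booster.execute()  # lazy tileable; executing it submits the XGBTrain operator
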
maxframe/learn/utils/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .core import convert_to_tensor_or_dataframe
maxframe/learn/utils/core.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import pandas as pd
+
+ from ...dataframe import DataFrame, Series
+ from ...dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
+ from ...tensor import tensor as astensor
+
+
+ def convert_to_tensor_or_dataframe(item):
+     if isinstance(item, (DATAFRAME_TYPE, pd.DataFrame)):
+         item = DataFrame(item)
+     elif isinstance(item, (SERIES_TYPE, pd.Series)):
+         item = Series(item)
+     else:
+         item = astensor(item)
+     return item
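
The helper dispatches purely on input type: pandas objects are promoted to their lazy MaxFrame counterparts, and anything else falls through to `astensor`. A quick illustration of that contract; the `TENSOR_TYPE` import path is taken from its use in `maxframe/odpsio/arrow.py` below.

    import numpy as np
    import pandas as pd

    from maxframe.dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
    from maxframe.learn.utils import convert_to_tensor_or_dataframe
    from maxframe.tensor.core import TENSOR_TYPE

    # pandas objects become lazy MaxFrame objects ...
    assert isinstance(convert_to_tensor_or_dataframe(pd.DataFrame({"a": [1]})), DATAFRAME_TYPE)
    assert isinstance(convert_to_tensor_or_dataframe(pd.Series([1.0, 2.0])), SERIES_TYPE)
    # ... and any remaining array-like falls through to a tensor
    assert isinstance(convert_to_tensor_or_dataframe(np.arange(3)), TENSOR_TYPE)
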
maxframe/lib/mmh3.cp39-win32.pyd CHANGED
Binary file

maxframe/lib/mmh3.pyi ADDED
@@ -0,0 +1,43 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple
+
+ def hash(key, seed=0, signed=True) -> int:
+     """
+     Return a 32-bit integer.
+     """
+
+ def hash_from_buffer(key, seed=0, signed=True) -> int:
+     """
+     Return a 32-bit integer. Designed for large memory views such as numpy arrays.
+     """
+
+ def hash64(key, seed=0, x64arch=True, signed=True) -> Tuple[int, int]:
+     """
+     Return a tuple of two 64-bit integers for a string. Optimized for
+     the x64 architecture when x64arch=True, otherwise for x86.
+     """
+
+ def hash128(key, seed=0, x64arch=True, signed=False) -> int:
+     """
+     Return a 128-bit long integer. Optimized for the x64 architecture
+     when x64arch=True, otherwise for x86.
+     """
+
+ def hash_bytes(key, seed=0, x64arch=True) -> bytes:
+     """
+     Return a 128-bit hash value as bytes for a string. Optimized for the
+     x64 architecture when x64arch=True, otherwise for x86.
+     """
maxframe/lib/wrapped_pickle.py CHANGED
@@ -120,7 +120,8 @@ class _UnpickleSwitch:
          @functools.wraps(func)
          async def wrapped(*args, **kwargs):
              with _UnpickleSwitch(forbidden=self._forbidden):
-                 return await func(*args, **kwargs)
+                 ret = await func(*args, **kwargs)
+                 return ret

      else:

maxframe/odpsio/arrow.py CHANGED
@@ -17,10 +17,9 @@ from typing import Any, Tuple, Union
  import pandas as pd
  import pyarrow as pa

- import maxframe.tensor as mt
-
  from ..core import OutputType
  from ..protocol import DataFrameTableMeta
+ from ..tensor.core import TENSOR_TYPE
  from ..typing_ import ArrowTableType, PandasObjectTypes
  from .schema import build_dataframe_table_meta

@@ -83,7 +82,7 @@ def pandas_to_arrow(
          df = df.to_frame(name=names[0] if len(names) == 1 else names)
      elif table_meta.type == OutputType.scalar:
          names = ["_idx_0"]
-         if isinstance(df, mt.Tensor):
+         if isinstance(df, TENSOR_TYPE):
              df = pd.DataFrame([], columns=names).astype({names[0]: df.dtype})
          else:
              df = pd.DataFrame([[df]], columns=names)
maxframe/odpsio/tableio.py CHANGED
@@ -183,6 +183,28 @@ class HaloTableIO(MCTableIO):
              for pt in partitions
          ]

+     def get_table_record_count(
+         self, full_table_name: str, partitions: PartitionsType = None
+     ):
+         from odps.apis.storage_api import SplitOptions, TableBatchScanRequest
+
+         table = self._odps.get_table(full_table_name)
+         client = StorageApiArrowClient(
+             self._odps, table, rest_endpoint=self._storage_api_endpoint
+         )
+
+         split_option = SplitOptions.SplitMode.SIZE
+
+         scan_kw = {
+             "required_partitions": self._convert_partitions(partitions),
+             "split_options": SplitOptions.get_default_options(split_option),
+         }
+
+         # todo add more options for partition column handling
+         req = TableBatchScanRequest(**scan_kw)
+         resp = client.create_read_session(req)
+         return resp.record_count
+
      @contextmanager
      def open_reader(
          self,
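
`get_table_record_count` reads the record count from the Storage API read-session metadata rather than scanning data, so it is cheap even for large tables. A hedged sketch follows; this diff does not show how `HaloTableIO` is constructed, so the constructor call below is an assumption and the credentials are placeholders.

    from odps import ODPS
    from maxframe.odpsio.tableio import HaloTableIO

    o = ODPS("<access_id>", "<secret_key>", project="my_project",
             endpoint="<odps_endpoint>")
    table_io = HaloTableIO(o)  # assumed constructor, for illustration only

    # the count comes from create_read_session(), not from a table scan
    count = table_io.get_table_record_count(
        "my_project.my_table", partitions=["pt=20240101"]
    )
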
maxframe/odpsio/tests/test_schema.py CHANGED
@@ -30,20 +30,23 @@ from ..schema import (
  )


- def _wrap_maxframe_obj(obj, wrap=True):
-     if not wrap:
+ def _wrap_maxframe_obj(obj, wrap="no"):
+     if wrap == "no":
          return obj
      if isinstance(obj, pd.DataFrame):
-         return md.DataFrame(obj)
+         obj = md.DataFrame(obj)
      elif isinstance(obj, pd.Series):
-         return md.Series(obj)
+         obj = md.Series(obj)
      elif isinstance(obj, pd.Index):
-         return md.Index(obj)
+         obj = md.Index(obj)
      else:
-         return mt.scalar(obj)
+         obj = mt.scalar(obj)
+     if wrap == "data":
+         return obj.data
+     return obj


- @pytest.mark.parametrize("wrap_obj", [False, True])
+ @pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
  def test_pandas_to_odps_schema_dataframe(wrap_obj):
      data = pd.DataFrame(np.random.rand(100, 5), columns=list("ABCDE"))

@@ -94,7 +97,7 @@ def test_pandas_to_odps_schema_dataframe(wrap_obj):
      assert meta.pd_index_level_names == [None, None]


- @pytest.mark.parametrize("wrap_obj", [False, True])
+ @pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
  def test_pandas_to_odps_schema_series(wrap_obj):
      data = pd.Series(np.random.rand(100))

@@ -135,7 +138,7 @@ def test_pandas_to_odps_schema_series(wrap_obj):
      assert meta.pd_index_level_names == ["c1", "c2"]


- @pytest.mark.parametrize("wrap_obj", [False, True])
+ @pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
  def test_pandas_to_odps_schema_index(wrap_obj):
      data = pd.Index(np.random.randint(0, 100, 100))

@@ -167,11 +170,13 @@
      assert meta.pd_index_level_names == ["c1", "c2"]


- @pytest.mark.parametrize("wrap_obj", [False, True])
+ @pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
  def test_pandas_to_odps_schema_scalar(wrap_obj):
      data = 1234.56

      test_scalar = _wrap_maxframe_obj(data, wrap=wrap_obj)
+     if wrap_obj != "no":
+         test_scalar.op.data = None
      schema, meta = pandas_to_odps_schema(test_scalar, unknown_as_string=True)
      assert schema.columns[0].name == "_idx_0"
      assert schema.columns[0].type.name == "double"

@@ -279,7 +284,7 @@
      assert build_table_column_name(4, ("A", 1), records) == "a_1"


- @pytest.mark.parametrize("wrap_obj", [False, True])
+ @pytest.mark.parametrize("wrap_obj", ["no", "yes", "data"])
  def test_build_table_meta(wrap_obj):
      data = pd.DataFrame(
          np.random.rand(100, 7),
maxframe/opcodes.py CHANGED
@@ -386,6 +386,9 @@ DATAFRAME_EVAL = 738
  DUPLICATED = 739
  DELETE = 740
  ALIGN = 741
+ CASE_WHEN = 742
+ PIVOT = 743
+ PIVOT_TABLE = 744

  FUSE = 801

maxframe/protocol.py CHANGED
@@ -32,6 +32,7 @@ from .serialization.serializables import (
      EnumField,
      FieldTypes,
      Float64Field,
+     Int32Field,
      ListField,
      ReferenceField,
      Serializable,
@@ -71,6 +72,9 @@ class DagStatus(enum.Enum):
      CANCELLING = 4
      CANCELLED = 5

+     def is_terminated(self):
+         return self in (DagStatus.CANCELLED, DagStatus.SUCCEEDED, DagStatus.FAILED)
+

  class DimensionIndex(Serializable):
      is_slice: bool = BoolField("is_slice", default=None)
@@ -190,9 +194,9 @@ class ErrorInfo(JsonSerializable):
          "error_tracebacks", FieldTypes.list
      )
      raw_error_source: ErrorSource = EnumField(
-         "raw_error_source", ErrorSource, FieldTypes.int8
+         "raw_error_source", ErrorSource, FieldTypes.int8, default=None
      )
-     raw_error_data: Optional[Exception] = AnyField("raw_error_data")
+     raw_error_data: Optional[Exception] = AnyField("raw_error_data", default=None)

      @classmethod
      def from_exception(cls, exc: Exception):
@@ -201,20 +205,29 @@ class ErrorInfo(JsonSerializable):
          return cls(messages, tracebacks, ErrorSource.PYTHON, exc)

      def reraise(self):
-         if self.raw_error_source == ErrorSource.PYTHON:
+         if (
+             self.raw_error_source == ErrorSource.PYTHON
+             and self.raw_error_data is not None
+         ):
              raise self.raw_error_data
          raise RemoteException(self.error_messages, self.error_tracebacks, [])

      @classmethod
      def from_json(cls, serialized: dict) -> "ErrorInfo":
          kw = serialized.copy()
-         kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
+         if kw.get("raw_error_source") is not None:
+             kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
+         else:
+             kw["raw_error_source"] = None
+
          if kw.get("raw_error_data"):
              bufs = [base64.b64decode(s) for s in kw["raw_error_data"]]
              try:
                  kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
              except:
-                 kw["raw_error_data"] = None
+                 # both error source and data shall be None to make sure
+                 # RemoteException is raised.
+                 kw["raw_error_source"] = kw["raw_error_data"] = None
          return cls(**kw)

      def to_json(self) -> dict:
@@ -227,7 +240,12 @@ class ErrorInfo(JsonSerializable):
          if isinstance(self.raw_error_data, (PickleContainer, RemoteException)):
              err_data_bufs = self.raw_error_data.get_buffers()
          elif isinstance(self.raw_error_data, BaseException):
-             err_data_bufs = pickle_buffers(self.raw_error_data)
+             try:
+                 err_data_bufs = pickle_buffers(self.raw_error_data)
+             except:
+                 err_data_bufs = None
+                 ret["raw_error_source"] = None
+
          if err_data_bufs:
              ret["raw_error_data"] = [
                  base64.b64encode(s).decode() for s in err_data_bufs
@@ -249,9 +267,17 @@ class DagInfo(JsonSerializable):
      error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
      start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
      end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
+     subdag_infos: Dict[str, "SubDagInfo"] = DictField(
+         "subdag_infos",
+         key_type=FieldTypes.string,
+         value_type=FieldTypes.reference,
+         default_factory=dict,
+     )

      @classmethod
-     def from_json(cls, serialized: dict) -> "DagInfo":
+     def from_json(cls, serialized: dict) -> Optional["DagInfo"]:
+         if serialized is None:
+             return None
          kw = serialized.copy()
          kw["status"] = DagStatus(kw["status"])
          if kw.get("tileable_to_result_infos"):
@@ -261,6 +287,10 @@
              }
          if kw.get("error_info"):
              kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
+         if kw.get("subdag_infos"):
+             kw["subdag_infos"] = {
+                 k: SubDagInfo.from_json(v) for k, v in kw["subdag_infos"].items()
+             }
          return DagInfo(**kw)

      def to_json(self) -> dict:
@@ -279,6 +309,8 @@
          }
          if self.error_info:
              ret["error_info"] = self.error_info.to_json()
+         if self.subdag_infos:
+             ret["subdag_infos"] = {k: v.to_json() for k, v in self.subdag_infos.items()}
          return ret


@@ -302,7 +334,9 @@ class SessionInfo(JsonSerializable):
      error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)

      @classmethod
-     def from_json(cls, serialized: dict) -> "SessionInfo":
+     def from_json(cls, serialized: dict) -> Optional["SessionInfo"]:
+         if serialized is None:
+             return None
          kw = serialized.copy()
          if kw.get("dag_infos"):
              kw["dag_infos"] = {
@@ -320,7 +354,10 @@
              "idle_timestamp": self.idle_timestamp,
          }
          if self.dag_infos:
-             ret["dag_infos"] = {k: v.to_json() for k, v in self.dag_infos.items()}
+             ret["dag_infos"] = {
+                 k: v.to_json() if v is not None else None
+                 for k, v in self.dag_infos.items()
+             }
          if self.error_info:
              ret["error_info"] = self.error_info.to_json()
          return ret
@@ -342,7 +379,25 @@ class ExecuteDagRequest(Serializable):
      )


- class SubDagInfo(Serializable):
+ class SubDagSubmitInstanceInfo(JsonSerializable):
+     submit_reason: str = StringField("submit_reason")
+     instance_id: str = StringField("instance_id")
+     subquery_id: Optional[int] = Int32Field("subquery_id", default=None)
+
+     @classmethod
+     def from_json(cls, serialized: dict) -> "SubDagSubmitInstanceInfo":
+         return SubDagSubmitInstanceInfo(**serialized)
+
+     def to_json(self) -> dict:
+         ret = {
+             "submit_reason": self.submit_reason,
+             "instance_id": self.instance_id,
+             "subquery_id": self.subquery_id,
+         }
+         return ret
+
+
+ class SubDagInfo(JsonSerializable):
      subdag_id: str = StringField("subdag_id")
      status: DagStatus = EnumField("status", DagStatus, FieldTypes.int8, default=None)
      progress: float = Float64Field("progress", default=None)
@@ -355,9 +410,52 @@ class SubDagInfo(Serializable):
          FieldTypes.reference,
          default_factory=dict,
      )
+     start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
+     end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
+     submit_instances: List[SubDagSubmitInstanceInfo] = ListField(
+         "submit_instances",
+         FieldTypes.reference,
+         default_factory=list,
+     )
+
+     @classmethod
+     def from_json(cls, serialized: dict) -> "SubDagInfo":
+         kw = serialized.copy()
+         kw["status"] = DagStatus(kw["status"])
+         if kw.get("tileable_to_result_infos"):
+             kw["tileable_to_result_infos"] = {
+                 k: ResultInfo.from_json(s)
+                 for k, s in kw["tileable_to_result_infos"].items()
+             }
+         if kw.get("error_info"):
+             kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
+         if kw.get("submit_instances"):
+             kw["submit_instances"] = [
+                 SubDagSubmitInstanceInfo.from_json(s) for s in kw["submit_instances"]
+             ]
+         return SubDagInfo(**kw)
+
+     def to_json(self) -> dict:
+         ret = {
+             "subdag_id": self.subdag_id,
+             "status": self.status.value,
+             "progress": self.progress,
+             "start_timestamp": self.start_timestamp,
+             "end_timestamp": self.end_timestamp,
+         }
+         if self.error_info:
+             ret["error_info"] = self.error_info.to_json()
+         if self.tileable_to_result_infos:
+             ret["tileable_to_result_infos"] = {
+                 k: v.to_json() for k, v in self.tileable_to_result_infos.items()
+             }
+         if self.submit_instances:
+             ret["submit_instances"] = [i.to_json() for i in self.submit_instances]
+         return ret


  class ExecuteSubDagRequest(Serializable):
+     subdag_id: str = StringField("subdag_id")
      dag: TileableGraph = ReferenceField(
          "dag",
          on_serialize=SerializableGraph.from_graph,
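
With `SubDagInfo` promoted to `JsonSerializable`, sub-DAG state round-trips through plain dicts just like `DagInfo`. A minimal sketch of that contract, using only members this diff confirms (`SUCCEEDED`, `FAILED` and `CANCELLED` are all referenced by `is_terminated`); the keyword construction mirrors the `SubDagInfo(**kw)` call in `from_json`.

    from maxframe.protocol import DagStatus, SubDagInfo

    info = SubDagInfo(subdag_id="subdag-0", status=DagStatus.SUCCEEDED, progress=1.0)

    payload = info.to_json()            # plain dict; the enum is stored as its int value
    restored = SubDagInfo.from_json(payload)

    assert restored.status is DagStatus.SUCCEEDED
    assert restored.status.is_terminated()  # true only for SUCCEEDED / FAILED / CANCELLED
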
maxframe/serialization/core.cp39-win32.pyd CHANGED
Binary file
maxframe/serialization/core.pxd CHANGED
@@ -18,6 +18,9 @@ from libc.stdint cimport int32_t, uint64_t
  cdef class Serializer:
      cdef int _serializer_id

+     cpdef bint is_public_data_exist(self, dict context, object key)
+     cpdef put_public_data(self, dict context, object key, object value)
+     cpdef get_public_data(self, dict context, object key)
      cpdef serial(self, object obj, dict context)
      cpdef deserial(self, list serialized, dict context, list subs)
      cpdef on_deserial_error(
maxframe/serialization/core.pyi ADDED
@@ -0,0 +1,64 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from concurrent.futures import Executor
+ from typing import Any, Callable, Dict, List, TypeVar
+
+ def buffered(func: Callable) -> Callable: ...
+ def fast_id(obj: Any) -> int: ...
+
+ LoadType = TypeVar("LoadType")
+
+ def load_type(class_name: str, parent_class: LoadType) -> LoadType: ...
+
+ class PickleContainer:
+     def __init__(self, buffers: List[bytes]): ...
+     def get(self) -> Any: ...
+     def get_buffers(self) -> List[bytes]: ...
+
+ class Serializer:
+     serializer_id: int
+     def is_public_data_exist(self, context: Dict, key: Any) -> bool: ...
+     def put_public_data(self, context: Dict, key: Any, value: Any) -> None: ...
+     def get_public_data(self, context: Dict, key: Any) -> Any: ...
+     def serial(self, obj: Any, context: Dict): ...
+     def deserial(self, serialized: List, context: Dict, subs: List[Any]): ...
+     def on_deserial_error(
+         self,
+         serialized: List,
+         context: Dict,
+         subs_serialized: List,
+         error_index: int,
+         exc: BaseException,
+     ): ...
+     @classmethod
+     def register(cls, obj_type): ...
+     @classmethod
+     def unregister(cls, obj_type): ...
+
+ class Placeholder:
+     id: int
+     callbacks: List[Callable]
+     def __init__(self, id_: int): ...
+     def __hash__(self): ...
+     def __eq__(self, other): ...
+
+ def serialize(obj: Any, context: Dict = None): ...
+ async def serialize_with_spawn(
+     obj: Any,
+     context: Dict = None,
+     spawn_threshold: int = 100,
+     executor: Executor = None,
+ ): ...
+ def deserialize(headers: List, buffers: List, context: Dict = None): ...
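
Judging from the `deserialize(headers, buffers)` signature in this stub, `serialize` is expected to produce a matching pair, as in Mars. A hedged round-trip sketch; the `maxframe.serialization` re-exports are assumed.

    from maxframe.serialization import serialize, deserialize  # assumed re-exports

    # headers describe the object structure; buffers carry the payloads
    headers, buffers = serialize({"a": 1, "b": [1, 2, 3]})
    obj = deserialize(headers, buffers)
    assert obj == {"a": 1, "b": [1, 2, 3]}
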