maxframe 0.1.0b5__cp37-cp37m-win32.whl → 1.0.0rc2__cp37-cp37m-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. Click here for more details.

Files changed (92) hide show
  1. maxframe/_utils.cp37-win32.pyd +0 -0
  2. maxframe/codegen.py +6 -2
  3. maxframe/config/config.py +38 -2
  4. maxframe/config/validators.py +1 -0
  5. maxframe/conftest.py +2 -0
  6. maxframe/core/__init__.py +0 -3
  7. maxframe/core/entity/__init__.py +1 -8
  8. maxframe/core/entity/objects.py +3 -45
  9. maxframe/core/graph/core.cp37-win32.pyd +0 -0
  10. maxframe/core/graph/core.pyx +4 -4
  11. maxframe/dataframe/__init__.py +1 -1
  12. maxframe/dataframe/arithmetic/around.py +5 -17
  13. maxframe/dataframe/arithmetic/core.py +15 -7
  14. maxframe/dataframe/arithmetic/docstring.py +5 -55
  15. maxframe/dataframe/arithmetic/tests/test_arithmetic.py +22 -0
  16. maxframe/dataframe/core.py +5 -5
  17. maxframe/dataframe/datasource/date_range.py +2 -2
  18. maxframe/dataframe/datasource/read_odps_query.py +6 -0
  19. maxframe/dataframe/datasource/read_odps_table.py +2 -1
  20. maxframe/dataframe/datasource/tests/test_datasource.py +14 -0
  21. maxframe/dataframe/datastore/tests/__init__.py +13 -0
  22. maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
  23. maxframe/dataframe/datastore/to_odps.py +21 -0
  24. maxframe/dataframe/groupby/cum.py +0 -1
  25. maxframe/dataframe/groupby/tests/test_groupby.py +4 -0
  26. maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
  27. maxframe/dataframe/indexing/align.py +1 -1
  28. maxframe/dataframe/indexing/rename.py +3 -37
  29. maxframe/dataframe/indexing/sample.py +0 -1
  30. maxframe/dataframe/indexing/set_index.py +68 -1
  31. maxframe/dataframe/merge/merge.py +236 -2
  32. maxframe/dataframe/merge/tests/test_merge.py +123 -0
  33. maxframe/dataframe/misc/apply.py +5 -10
  34. maxframe/dataframe/misc/case_when.py +1 -1
  35. maxframe/dataframe/misc/describe.py +2 -2
  36. maxframe/dataframe/misc/drop_duplicates.py +4 -25
  37. maxframe/dataframe/misc/eval.py +4 -0
  38. maxframe/dataframe/misc/memory_usage.py +2 -2
  39. maxframe/dataframe/misc/pct_change.py +1 -83
  40. maxframe/dataframe/misc/tests/test_misc.py +23 -0
  41. maxframe/dataframe/misc/transform.py +1 -30
  42. maxframe/dataframe/misc/value_counts.py +4 -17
  43. maxframe/dataframe/missing/dropna.py +1 -1
  44. maxframe/dataframe/missing/fillna.py +5 -5
  45. maxframe/dataframe/sort/sort_values.py +1 -11
  46. maxframe/dataframe/statistics/corr.py +3 -3
  47. maxframe/dataframe/statistics/quantile.py +5 -17
  48. maxframe/dataframe/utils.py +4 -7
  49. maxframe/errors.py +13 -0
  50. maxframe/extension.py +12 -0
  51. maxframe/learn/contrib/xgboost/dmatrix.py +2 -2
  52. maxframe/learn/contrib/xgboost/predict.py +2 -2
  53. maxframe/learn/contrib/xgboost/train.py +2 -2
  54. maxframe/lib/mmh3.cp37-win32.pyd +0 -0
  55. maxframe/lib/mmh3.pyi +43 -0
  56. maxframe/lib/wrapped_pickle.py +2 -1
  57. maxframe/odpsio/__init__.py +1 -1
  58. maxframe/odpsio/arrow.py +8 -4
  59. maxframe/odpsio/schema.py +10 -7
  60. maxframe/odpsio/tableio.py +388 -14
  61. maxframe/odpsio/tests/test_schema.py +16 -15
  62. maxframe/odpsio/tests/test_tableio.py +48 -21
  63. maxframe/protocol.py +148 -12
  64. maxframe/serialization/core.cp37-win32.pyd +0 -0
  65. maxframe/serialization/core.pxd +3 -0
  66. maxframe/serialization/core.pyi +3 -0
  67. maxframe/serialization/core.pyx +54 -25
  68. maxframe/serialization/exception.py +1 -1
  69. maxframe/serialization/pandas.py +7 -2
  70. maxframe/serialization/serializables/core.py +158 -12
  71. maxframe/serialization/serializables/tests/test_serializable.py +46 -4
  72. maxframe/tensor/__init__.py +59 -0
  73. maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
  74. maxframe/tensor/base/atleast_1d.py +1 -1
  75. maxframe/tensor/base/unique.py +3 -3
  76. maxframe/tensor/reduction/count_nonzero.py +1 -1
  77. maxframe/tensor/statistics/quantile.py +2 -2
  78. maxframe/tests/test_protocol.py +34 -0
  79. maxframe/tests/test_utils.py +0 -12
  80. maxframe/tests/utils.py +11 -2
  81. maxframe/utils.py +24 -13
  82. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/METADATA +75 -2
  83. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/RECORD +91 -89
  84. maxframe_client/__init__.py +0 -1
  85. maxframe_client/fetcher.py +38 -27
  86. maxframe_client/session/odps.py +50 -10
  87. maxframe_client/session/task.py +41 -20
  88. maxframe_client/tests/test_fetcher.py +21 -3
  89. maxframe_client/tests/test_session.py +49 -2
  90. maxframe_client/clients/spe.py +0 -104
  91. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/WHEEL +0 -0
  92. {maxframe-0.1.0b5.dist-info → maxframe-1.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -12,11 +12,9 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import base64
16
- import json
17
15
  from abc import ABC, abstractmethod
18
16
  from numbers import Integral
19
- from typing import Any, Dict, List, Optional, Type, Union
17
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union
20
18
 
21
19
  import pandas as pd
22
20
  import pyarrow as pa
@@ -28,7 +26,7 @@ from tornado import httpclient
28
26
  from maxframe.core import OBJECT_TYPE
29
27
  from maxframe.dataframe.core import DATAFRAME_TYPE
30
28
  from maxframe.lib import wrapped_pickle as pickle
31
- from maxframe.odpsio import HaloTableIO, arrow_to_pandas, build_dataframe_table_meta
29
+ from maxframe.odpsio import ODPSTableIO, arrow_to_pandas, build_dataframe_table_meta
32
30
  from maxframe.protocol import (
33
31
  DataFrameTableMeta,
34
32
  ODPSTableResultInfo,
@@ -38,7 +36,7 @@ from maxframe.protocol import (
38
36
  )
39
37
  from maxframe.tensor.core import TENSOR_TYPE
40
38
  from maxframe.typing_ import PandasObjectTypes, TileableType
41
- from maxframe.utils import ToThreadMixin, deserialize_serializable
39
+ from maxframe.utils import ToThreadMixin
42
40
 
43
41
  _result_fetchers: Dict[ResultType, Type["ResultFetcher"]] = dict()
44
42
 
@@ -109,17 +107,12 @@ class ODPSTableFetcher(ToThreadMixin, ResultFetcher):
109
107
  tileable: TileableType,
110
108
  info: ODPSTableResultInfo,
111
109
  ) -> None:
112
- if isinstance(tileable, DATAFRAME_TYPE) and tileable.dtypes is None:
113
- tb_comment = await self.to_thread(
114
- self._get_table_comment, info.full_table_name
115
- )
116
- if tb_comment: # pragma: no branch
117
- comment_data = json.loads(tb_comment)
118
-
119
- table_meta: DataFrameTableMeta = deserialize_serializable(
120
- base64.b64decode(comment_data["table_meta"])
121
- )
122
- tileable.refresh_from_table_meta(table_meta)
110
+ if (
111
+ isinstance(tileable, DATAFRAME_TYPE)
112
+ and tileable.dtypes is None
113
+ and info.table_meta is not None
114
+ ):
115
+ tileable.refresh_from_table_meta(info.table_meta)
123
116
 
124
117
  if tileable.shape and any(pd.isna(x) for x in tileable.shape):
125
118
  part_specs = [None] if not info.partition_specs else info.partition_specs
@@ -131,16 +124,39 @@ class ODPSTableFetcher(ToThreadMixin, ResultFetcher):
131
124
  )
132
125
  total_records += session.count
133
126
  new_shape_list = list(tileable.shape)
134
- new_shape_list[-1] = total_records
127
+ new_shape_list[0] = total_records
135
128
  tileable.params = {"shape": tuple(new_shape_list)}
136
129
 
130
+ @staticmethod
131
+ def _align_selection_with_shape(
132
+ row_sel: slice, shape: Tuple[Optional[int], ...]
133
+ ) -> dict:
134
+ size = shape[0]
135
+ if not row_sel.start and not row_sel.stop:
136
+ return {}
137
+ is_reversed = row_sel.step is not None and row_sel.step < 0
138
+ read_kw = {
139
+ "start": row_sel.start,
140
+ "stop": row_sel.stop,
141
+ "reverse_range": is_reversed,
142
+ }
143
+ if pd.isna(size):
144
+ return read_kw
145
+
146
+ if is_reversed and row_sel.start is not None:
147
+ read_kw["start"] = min(size - 1, row_sel.start)
148
+ if not is_reversed and row_sel.stop is not None:
149
+ read_kw["stop"] = min(size, row_sel.stop)
150
+ return read_kw
151
+
137
152
  def _read_single_source(
138
153
  self,
139
154
  table_meta: DataFrameTableMeta,
140
155
  info: ODPSTableResultInfo,
141
156
  indexes: List[Union[None, Integral, slice]],
157
+ shape: Tuple[Optional[int], ...],
142
158
  ):
143
- table_io = HaloTableIO(self._odps_entry)
159
+ table_io = ODPSTableIO(self._odps_entry)
144
160
  read_kw = {}
145
161
  row_step = None
146
162
  if indexes:
@@ -148,13 +164,8 @@ class ODPSTableFetcher(ToThreadMixin, ResultFetcher):
148
164
  indexes += [None]
149
165
  row_sel, col_sel = indexes
150
166
  if isinstance(row_sel, slice):
151
- if row_sel.start or row_sel.stop:
152
- read_kw["start"] = row_sel.start
153
- read_kw["stop"] = row_sel.stop
154
- read_kw["reverse_range"] = (
155
- row_sel.step is not None and row_sel.step < 0
156
- )
157
- row_step = row_sel.step
167
+ row_step = row_sel.step
168
+ read_kw = self._align_selection_with_shape(row_sel, shape)
158
169
  elif isinstance(row_sel, int):
159
170
  read_kw["start"] = row_sel
160
171
  read_kw["stop"] = row_sel + 1
@@ -173,8 +184,8 @@ class ODPSTableFetcher(ToThreadMixin, ResultFetcher):
173
184
  with table_io.open_reader(
174
185
  info.full_table_name, info.partition_specs, **read_kw
175
186
  ) as reader:
176
- reader_count = reader.count
177
187
  result = reader.read_all()
188
+ reader_count = result.num_rows
178
189
 
179
190
  if not row_step:
180
191
  return result
@@ -195,7 +206,7 @@ class ODPSTableFetcher(ToThreadMixin, ResultFetcher):
195
206
  ) -> PandasObjectTypes:
196
207
  table_meta = build_dataframe_table_meta(tileable)
197
208
  arrow_table: pa.Table = await self.to_thread(
198
- self._read_single_source, table_meta, info, indexes
209
+ self._read_single_source, table_meta, info, indexes, tileable.shape
199
210
  )
200
211
  return arrow_to_pandas(arrow_table, table_meta)
201
212
 
@@ -31,7 +31,12 @@ from maxframe.dataframe import read_odps_table
31
31
  from maxframe.dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
32
32
  from maxframe.dataframe.datasource import PandasDataSourceOperator
33
33
  from maxframe.dataframe.datasource.read_odps_table import DataFrameReadODPSTable
34
- from maxframe.odpsio import HaloTableIO, pandas_to_arrow, pandas_to_odps_schema
34
+ from maxframe.errors import (
35
+ MaxFrameError,
36
+ NoTaskServerResponseError,
37
+ SessionAlreadyClosedError,
38
+ )
39
+ from maxframe.odpsio import ODPSTableIO, pandas_to_arrow, pandas_to_odps_schema
35
40
  from maxframe.protocol import (
36
41
  DagInfo,
37
42
  DagStatus,
@@ -144,18 +149,23 @@ class MaxFrameSession(ToThreadMixin, IsolatedAsyncSession):
144
149
 
145
150
  schema, table_meta = pandas_to_odps_schema(t, unknown_as_string=True)
146
151
  if self._odps_entry.exist_table(table_meta.table_name):
147
- self._odps_entry.delete_table(table_meta.table_name)
152
+ self._odps_entry.delete_table(
153
+ table_meta.table_name, hints=options.sql.settings
154
+ )
148
155
  table_name = build_temp_table_name(self.session_id, t.key)
149
156
  table_obj = self._odps_entry.create_table(
150
- table_name, schema, lifecycle=options.session.temp_table_lifecycle
157
+ table_name,
158
+ schema,
159
+ lifecycle=options.session.temp_table_lifecycle,
160
+ hints=options.sql.settings,
151
161
  )
152
162
 
153
163
  data = t.op.get_data()
154
164
  batch_size = options.session.upload_batch_size
155
165
 
156
166
  if len(data):
157
- halo_client = HaloTableIO(self._odps_entry)
158
- with halo_client.open_writer(table_obj.full_table_name) as writer:
167
+ table_client = ODPSTableIO(self._odps_entry)
168
+ with table_client.open_writer(table_obj.full_table_name) as writer:
159
169
  for batch_start in range(0, len(data), batch_size):
160
170
  if isinstance(data, pd.Index):
161
171
  batch = data[batch_start : batch_start + batch_size]
@@ -178,7 +188,7 @@ class MaxFrameSession(ToThreadMixin, IsolatedAsyncSession):
178
188
  read_tileable.name = t.name
179
189
  else: # INDEX_TYPE
180
190
  if list(read_tileable.names) != list(t.names):
181
- read_tileable.names = t.names
191
+ read_tileable.rename(t.names, inplace=True)
182
192
  read_tileable._key = t.key
183
193
  read_tileable.params = t.params
184
194
  return read_tileable.data
@@ -264,8 +274,10 @@ class MaxFrameSession(ToThreadMixin, IsolatedAsyncSession):
264
274
  self, dag_info: DagInfo, tileables: List, progress: Progress
265
275
  ):
266
276
  start_time = time.time()
277
+ session_id = dag_info.session_id
267
278
  dag_id = dag_info.dag_id
268
279
  wait_timeout = 10
280
+ server_no_response_time = None
269
281
  with enter_mode(build=True, kernel=True):
270
282
  key_to_tileables = {t.key: t for t in tileables}
271
283
 
@@ -280,9 +292,37 @@ class MaxFrameSession(ToThreadMixin, IsolatedAsyncSession):
280
292
  if timeout_val <= 0:
281
293
  raise TimeoutError("Running DAG timed out")
282
294
 
283
- dag_info: DagInfo = await self.ensure_async_call(
284
- self._caller.get_dag_info, dag_id
285
- )
295
+ try:
296
+ dag_info: DagInfo = await self.ensure_async_call(
297
+ self._caller.get_dag_info, dag_id
298
+ )
299
+ server_no_response_time = None
300
+ except (NoTaskServerResponseError, SessionAlreadyClosedError) as ex:
301
+ # when SessionAlreadyClosedError is received after a NoTaskServerResponseError
302
+ # has already been seen, the task server may have been restarted and the
+ # SessionAlreadyClosedError might be transient. Otherwise, the error
304
+ # should be raised.
305
+ if (
306
+ isinstance(ex, SessionAlreadyClosedError)
307
+ and not server_no_response_time
308
+ ):
309
+ raise
310
+ server_no_response_time = server_no_response_time or time.time()
311
+ if (
312
+ time.time() - server_no_response_time
313
+ > options.client.task_restart_timeout
314
+ ):
315
+ raise MaxFrameError(
316
+ "Failed to get valid response from service. "
317
+ f"Session {self._session_id}."
318
+ ) from None
319
+ await asyncio.sleep(timeout_val)
320
+ continue
321
+
322
+ if dag_info is None:
323
+ raise SystemError(
324
+ f"Cannot find DAG with ID {dag_id} in session {session_id}"
325
+ )
286
326
  progress.value = dag_info.progress
287
327
  if dag_info.status != DagStatus.RUNNING:
288
328
  break
@@ -344,7 +384,7 @@ class MaxFrameSession(ToThreadMixin, IsolatedAsyncSession):
344
384
  data_tileable, indexes = self._get_data_tileable_and_indexes(tileable)
345
385
  info = self._tileable_to_infos[data_tileable]
346
386
  fetcher = get_fetcher_cls(info.result_type)(self._odps_entry)
347
- results.append(await fetcher.fetch(tileable, info, indexes))
387
+ results.append(await fetcher.fetch(data_tileable, info, indexes))
348
388
  return results
349
389
 
350
390
  async def decref(self, *tileable_keys):
@@ -26,6 +26,7 @@ from odps.models import Instance, MaxFrameTask
26
26
 
27
27
  from maxframe.config import options
28
28
  from maxframe.core import TileableGraph
29
+ from maxframe.errors import NoTaskServerResponseError, SessionAlreadyClosedError
29
30
  from maxframe.protocol import DagInfo, JsonSerializable, ResultInfo, SessionInfo
30
31
  from maxframe.utils import deserialize_serializable, serialize_serializable, to_str
31
32
 
@@ -82,6 +83,7 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
82
83
  self._running_cluster = running_cluster
83
84
  self._major_version = major_version
84
85
  self._output_format = output_format or MAXFRAME_OUTPUT_MSGPACK_FORMAT
86
+ self._deleted = False
85
87
 
86
88
  if nested_instance_id is None:
87
89
  self._nested = False
@@ -94,10 +96,18 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
94
96
  self, content: Union[bytes, str, dict], target_cls: Type[JsonSerializable]
95
97
  ):
96
98
  if isinstance(content, (str, bytes)):
99
+ if len(content) == 0:
100
+ content = "{}"
97
101
  json_data = json.loads(to_str(content))
98
102
  else:
99
103
  json_data = content
100
- result_data = base64.b64decode(json_data["result"])
104
+ encoded_result = json_data.get("result")
105
+ if not encoded_result:
106
+ if self._deleted:
107
+ return None
108
+ else:
109
+ raise SessionAlreadyClosedError(self._instance.id)
110
+ result_data = base64.b64decode(encoded_result)
101
111
  if self._output_format == MAXFRAME_OUTPUT_MAXFRAME_FORMAT:
102
112
  return deserialize_serializable(result_data)
103
113
  elif self._output_format == MAXFRAME_OUTPUT_JSON_FORMAT:
@@ -178,6 +188,14 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
178
188
  time.sleep(interval)
179
189
  interval = min(max_interval, interval * 2)
180
190
 
191
+ def _put_task_info(self, method_name: str, json_data: dict):
192
+ resp_data = self._instance.put_task_info(
193
+ self._task_name, method_name, json.dumps(json_data)
194
+ )
195
+ if not resp_data:
196
+ raise NoTaskServerResponseError(f"No response for request {method_name}")
197
+ return resp_data
198
+
181
199
  def get_session(self) -> SessionInfo:
182
200
  req_data = {"output_format": self._output_format}
183
201
  serialized = self._instance.put_task_info(
@@ -192,11 +210,8 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
192
210
  self._instance.stop()
193
211
  else:
194
212
  req_data = {"output_format": self._output_format}
195
- self._instance.put_task_info(
196
- self._task_name,
197
- MAXFRAME_TASK_DELETE_SESSION_METHOD,
198
- json.dumps(req_data),
199
- )
213
+ self._put_task_info(MAXFRAME_TASK_DELETE_SESSION_METHOD, req_data)
214
+ self._deleted = True
200
215
 
201
216
  def submit_dag(
202
217
  self,
@@ -211,9 +226,7 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
211
226
  ).decode(),
212
227
  "output_format": self._output_format,
213
228
  }
214
- res = self._instance.put_task_info(
215
- self._task_name, MAXFRAME_TASK_SUBMIT_DAG_METHOD, json.dumps(req_data)
216
- )
229
+ res = self._put_task_info(MAXFRAME_TASK_SUBMIT_DAG_METHOD, req_data)
217
230
  return self._deserial_task_info_result(res, DagInfo)
218
231
 
219
232
  def get_dag_info(self, dag_id: str) -> DagInfo:
@@ -222,9 +235,7 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
222
235
  "dag_id": dag_id,
223
236
  "output_format": self._output_format,
224
237
  }
225
- res = self._instance.put_task_info(
226
- self._task_name, MAXFRAME_TASK_GET_DAG_INFO_METHOD, json.dumps(req_data)
227
- )
238
+ res = self._put_task_info(MAXFRAME_TASK_GET_DAG_INFO_METHOD, req_data)
228
239
  return self._deserial_task_info_result(res, DagInfo)
229
240
 
230
241
  def cancel_dag(self, dag_id: str) -> DagInfo:
@@ -233,23 +244,33 @@ class MaxFrameInstanceCaller(MaxFrameServiceCaller):
233
244
  "dag_id": dag_id,
234
245
  "output_format": self._output_format,
235
246
  }
236
- res = self._instance.put_task_info(
237
- self._task_name, MAXFRAME_TASK_CANCEL_DAG_METHOD, json.dumps(req_data)
238
- )
247
+ res = self._put_task_info(MAXFRAME_TASK_CANCEL_DAG_METHOD, req_data)
239
248
  return self._deserial_task_info_result(res, DagInfo)
240
249
 
241
250
  def decref(self, tileable_keys: List[str]) -> None:
242
251
  req_data = {
243
252
  "tileable_keys": ",".join(tileable_keys),
244
253
  }
245
- self._instance.put_task_info(
246
- self._task_name, MAXFRAME_TASK_DECREF_METHOD, json.dumps(req_data)
247
- )
254
+ self._put_task_info(MAXFRAME_TASK_DECREF_METHOD, req_data)
248
255
 
249
256
  def get_logview_address(self, dag_id=None, hours=None) -> Optional[str]:
257
+ """
258
+ Generate logview address
259
+
260
+ Parameters
261
+ ----------
262
dag_id: ID of the DAG whose logview detail page is to be accessed
263
hours: number of hours the logview address remains authorized
264
+ Returns
265
+ -------
266
+ Logview address
267
+ """
250
268
  hours = hours or options.session.logview_hours
251
- subquery_suffix = f"&subQuery={dag_id}" if dag_id else ""
252
- return self._instance.get_logview_address(hours) + subquery_suffix
269
+ # notice: maxframe can't reuse subQuery else will conflict with mcqa when fetch resource data,
270
+ # added dagId for maxframe so logview backend will return maxframe data format if
271
+ # instance and dagId is provided.
272
+ dag_suffix = f"&dagId={dag_id}" if dag_id else ""
273
+ return self._instance.get_logview_address(hours) + dag_suffix
253
274
 
254
275
 
255
276
  class MaxFrameTaskSession(MaxFrameSession):
@@ -17,19 +17,32 @@ import uuid
17
17
  import numpy as np
18
18
  import pandas as pd
19
19
  import pyarrow as pa
20
+ import pytest
20
21
  from odps import ODPS
21
22
 
22
23
  import maxframe.dataframe as md
23
- from maxframe.odpsio import HaloTableIO
24
+ from maxframe.config import options
25
+ from maxframe.odpsio import ODPSTableIO
24
26
  from maxframe.protocol import ODPSTableResultInfo, ResultType
25
27
  from maxframe.tests.utils import tn
26
28
 
27
29
  from ..fetcher import ODPSTableFetcher
28
30
 
29
31
 
30
- async def test_table_fetcher():
32
+ @pytest.fixture
33
+ def switch_table_io(request):
34
+ old_use_common_table = options.use_common_table
35
+ try:
36
+ options.use_common_table = request.param
37
+ yield
38
+ finally:
39
+ options.use_common_table = old_use_common_table
40
+
41
+
42
+ @pytest.mark.parametrize("switch_table_io", [False, True], indirect=True)
43
+ async def test_table_fetcher(switch_table_io):
31
44
  odps_entry = ODPS.from_environments()
32
- halo_table_io = HaloTableIO(odps_entry)
45
+ halo_table_io = ODPSTableIO(odps_entry)
33
46
  fetcher = ODPSTableFetcher(odps_entry)
34
47
 
35
48
  data = pd.DataFrame(
@@ -58,6 +71,11 @@ async def test_table_fetcher():
58
71
  assert len(fetched) == 1000
59
72
  pd.testing.assert_frame_equal(raw_data, fetched)
60
73
 
74
+ result_info = ODPSTableResultInfo(ResultType.ODPS_TABLE, full_table_name=table_name)
75
+ fetched = await fetcher.fetch(tileable, result_info, [slice(None, 2000), None])
76
+ assert len(fetched) == 1000
77
+ pd.testing.assert_frame_equal(raw_data, fetched)
78
+
61
79
  result_info = ODPSTableResultInfo(ResultType.ODPS_TABLE, full_table_name=table_name)
62
80
  fetched = await fetcher.fetch(tileable, result_info, [2, None])
63
81
  assert len(fetched) == 1
@@ -23,7 +23,10 @@ from odps import ODPS
23
23
 
24
24
  import maxframe.dataframe as md
25
25
  import maxframe.remote as mr
26
+ from maxframe.config import options
27
+ from maxframe.config.config import option_context
26
28
  from maxframe.core import ExecutableTuple, TileableGraph
29
+ from maxframe.errors import NoTaskServerResponseError
27
30
  from maxframe.lib.aio import stop_isolation
28
31
  from maxframe.protocol import ResultInfo
29
32
  from maxframe.serialization import RemoteException
@@ -35,6 +38,7 @@ from maxframe_framedriver.app.tests.test_framedriver_webapp import ( # noqa: F4
35
38
  )
36
39
 
37
40
  from ..clients.framedriver import FrameDriverClient
41
+ from ..session.odps import MaxFrameRestCaller
38
42
 
39
43
  pytestmark = pytest.mark.maxframe_engine(["MCSQL", "SPE"])
40
44
 
@@ -86,11 +90,25 @@ def test_simple_run_dataframe(start_mock_session):
86
90
  assert len(dag) == 2
87
91
  return await original_submit_dag(self, session_id, dag, managed_input_infos)
88
92
 
93
+ no_task_server_raised = False
94
+ original_get_dag_info = MaxFrameRestCaller.get_dag_info
95
+
96
+ async def patched_get_dag_info(self, dag_id: str):
97
+ nonlocal no_task_server_raised
98
+
99
+ if not no_task_server_raised:
100
+ no_task_server_raised = True
101
+ raise NoTaskServerResponseError
102
+ return await original_get_dag_info(self, dag_id)
103
+
89
104
  df["H"] = "extra_content"
90
105
 
91
106
  with mock.patch(
92
107
  "maxframe_client.clients.framedriver.FrameDriverClient.submit_dag",
93
108
  new=patched_submit_dag,
109
+ ), mock.patch(
110
+ "maxframe_client.session.odps.MaxFrameRestCaller.get_dag_info",
111
+ new=patched_get_dag_info,
94
112
  ):
95
113
  result = df.execute().fetch()
96
114
  assert len(result) == 1000
@@ -177,13 +195,32 @@ def test_run_dataframe_from_to_odps_table(start_mock_session):
177
195
  assert len(result_df) == 10
178
196
  assert len(result_df.columns) == 6
179
197
 
180
- df = md.read_odps_table(table_obj, index_col="index").head(10).execute().fetch()
198
+ df = md.read_odps_table(table_obj, index_col="index").head(10).execute()
199
+ assert df.shape == (10, 5)
181
200
  assert len(df) == 10
182
201
  assert len(df.columns) == 5
183
202
  finally:
184
203
  odps_entry.delete_table(table_name, if_exists=True)
185
204
 
186
205
 
206
+ def test_create_session_with_options(framedriver_app): # noqa: F811
207
+ odps_entry = ODPS.from_environments()
208
+ framedriver_addr = f"mf://localhost:{framedriver_app.port}"
209
+ old_value = options.session.max_alive_seconds
210
+ session = None
211
+ try:
212
+ options.session.max_alive_seconds = 10
213
+ session = new_session(framedriver_addr, odps_entry=odps_entry)
214
+ session_id = session.session_id
215
+ session_conf = framedriver_app.session_manager.get_session_settings(session_id)
216
+ with option_context(session_conf) as session_options:
217
+ assert session_options.session.max_alive_seconds == 10
218
+ finally:
219
+ options.session.max_alive_seconds = old_value
220
+ if session is not None:
221
+ session.destroy()
222
+
223
+
187
224
  def test_run_and_fetch_series(start_mock_session):
188
225
  odps_entry = ODPS.from_environments()
189
226
 
@@ -244,7 +281,7 @@ def test_pivot_dataframe(start_mock_session):
244
281
  df = md.DataFrame(pd_df)
245
282
  pivot = df.pivot_table(values="D", index=["A", "B"], columns=["C"], aggfunc="sum")
246
283
  executed = pivot.execute()
247
- assert pivot.shape == (2, 4)
284
+ assert pivot.shape == (4, 2)
248
285
  pd.testing.assert_index_equal(
249
286
  pivot.dtypes.index, pd.Index(["large", "small"], name="C")
250
287
  )
@@ -253,3 +290,13 @@ def test_pivot_dataframe(start_mock_session):
253
290
  values="D", index=["A", "B"], columns=["C"], aggfunc="sum"
254
291
  )
255
292
  pd.testing.assert_frame_equal(executed.to_pandas(), expected)
293
+
294
+
295
+ def test_index_drop_duplicates(start_mock_session):
296
+ pd_idx = pd.Index(["lame", "cow", "lame", "beetle", "lame", "hippo"])
297
+ idx = md.Index(pd_idx)
298
+ executed = idx.drop_duplicates(keep="first").execute()
299
+ expected = pd_idx.drop_duplicates(keep="first")
300
+ pd.testing.assert_index_equal(
301
+ executed.to_pandas().sort_values(), expected.sort_values()
302
+ )
@@ -1,104 +0,0 @@
1
- # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import Any, Dict, Optional
16
-
17
- from tornado import httpclient
18
-
19
- from maxframe.core import TileableGraph
20
- from maxframe.protocol import ExecuteSubDagRequest, ProtocolBody, SubDagInfo
21
- from maxframe.typing_ import TimeoutType
22
- from maxframe.utils import (
23
- deserialize_serializable,
24
- format_timeout_params,
25
- serialize_serializable,
26
- wait_http_response,
27
- )
28
-
29
-
30
- class SPEClient:
31
- def __init__(
32
- self,
33
- endpoint: str,
34
- session_id: Optional[str] = None,
35
- host: str = None,
36
- ):
37
- self._endpoint = endpoint.rstrip("/")
38
- self._session_id = session_id
39
- self._headers = {"Host": host}
40
-
41
- @staticmethod
42
- def _load_subdag_info(resp: httpclient.HTTPResponse) -> SubDagInfo:
43
- res: ProtocolBody[SubDagInfo] = deserialize_serializable(resp.body)
44
- return res.body
45
-
46
- async def submit_subdag(
47
- self, subdag: TileableGraph, settings: Dict[str, Any] = None
48
- ) -> SubDagInfo:
49
- req_url = f"{self._endpoint}/api/subdags"
50
- req_body: ProtocolBody[ExecuteSubDagRequest] = ProtocolBody(
51
- body=ExecuteSubDagRequest(dag=subdag, settings=settings),
52
- )
53
-
54
- if self._session_id is not None:
55
- req_url += f"?session_id={self._session_id}"
56
-
57
- resp = await httpclient.AsyncHTTPClient().fetch(
58
- req_url,
59
- method="POST",
60
- headers=self._headers,
61
- body=serialize_serializable(req_body),
62
- )
63
- return self._load_subdag_info(resp)
64
-
65
- async def get_subdag_info(self, subdag_id: str) -> SubDagInfo:
66
- req_url = f"{self._endpoint}/api/subdags/{subdag_id}?wait=0"
67
- resp = await httpclient.AsyncHTTPClient().fetch(
68
- req_url,
69
- method="GET",
70
- headers=self._headers,
71
- )
72
- return self._load_subdag_info(resp)
73
-
74
- async def wait_subdag(
75
- self, subdag_id: str, wait_timeout: TimeoutType = None
76
- ) -> SubDagInfo:
77
- req_url = f"{self._endpoint}/api/subdags/{subdag_id}"
78
- params = format_timeout_params(wait_timeout)
79
- try:
80
- resp = await wait_http_response(
81
- req_url + params,
82
- method="GET",
83
- headers=self._headers,
84
- request_timeout=wait_timeout,
85
- )
86
- return self._load_subdag_info(resp)
87
- except TimeoutError:
88
- return await self.get_subdag_info(subdag_id)
89
-
90
- async def cancel_subdag(
91
- self, subdag_id: str, wait_timeout: TimeoutType = None
92
- ) -> SubDagInfo:
93
- req_url = f"{self._endpoint}/api/subdags/{subdag_id}"
94
- params = format_timeout_params(wait_timeout)
95
- try:
96
- resp = await wait_http_response(
97
- req_url + params,
98
- method="DELETE",
99
- headers=self._headers,
100
- request_timeout=wait_timeout,
101
- )
102
- return self._load_subdag_info(resp)
103
- except TimeoutError:
104
- return await self.get_subdag_info(subdag_id)