maxframe-0.1.0b3-cp37-cp37m-win_amd64.whl → maxframe-0.1.0b5-cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic.

Files changed (58)
  1. maxframe/__init__.py +1 -0
  2. maxframe/_utils.cp37-win_amd64.pyd +0 -0
  3. maxframe/codegen.py +46 -1
  4. maxframe/config/config.py +14 -1
  5. maxframe/core/graph/core.cp37-win_amd64.pyd +0 -0
  6. maxframe/dataframe/__init__.py +6 -0
  7. maxframe/dataframe/core.py +34 -10
  8. maxframe/dataframe/datasource/read_odps_query.py +6 -2
  9. maxframe/dataframe/datasource/read_odps_table.py +5 -1
  10. maxframe/dataframe/datastore/core.py +19 -0
  11. maxframe/dataframe/datastore/to_csv.py +2 -2
  12. maxframe/dataframe/datastore/to_odps.py +2 -2
  13. maxframe/dataframe/indexing/reset_index.py +1 -17
  14. maxframe/dataframe/misc/__init__.py +4 -0
  15. maxframe/dataframe/misc/apply.py +1 -1
  16. maxframe/dataframe/misc/case_when.py +141 -0
  17. maxframe/dataframe/misc/pivot_table.py +262 -0
  18. maxframe/dataframe/misc/tests/test_misc.py +61 -0
  19. maxframe/dataframe/plotting/core.py +2 -2
  20. maxframe/dataframe/reduction/core.py +2 -1
  21. maxframe/dataframe/utils.py +7 -0
  22. maxframe/learn/contrib/utils.py +52 -0
  23. maxframe/learn/contrib/xgboost/__init__.py +26 -0
  24. maxframe/learn/contrib/xgboost/classifier.py +86 -0
  25. maxframe/learn/contrib/xgboost/core.py +156 -0
  26. maxframe/learn/contrib/xgboost/dmatrix.py +150 -0
  27. maxframe/learn/contrib/xgboost/predict.py +138 -0
  28. maxframe/learn/contrib/xgboost/regressor.py +78 -0
  29. maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
  30. maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
  31. maxframe/learn/contrib/xgboost/train.py +121 -0
  32. maxframe/learn/utils/__init__.py +15 -0
  33. maxframe/learn/utils/core.py +29 -0
  34. maxframe/lib/mmh3.cp37-win_amd64.pyd +0 -0
  35. maxframe/odpsio/arrow.py +10 -6
  36. maxframe/odpsio/schema.py +18 -5
  37. maxframe/odpsio/tableio.py +22 -0
  38. maxframe/odpsio/tests/test_schema.py +41 -11
  39. maxframe/opcodes.py +8 -0
  40. maxframe/serialization/core.cp37-win_amd64.pyd +0 -0
  41. maxframe/serialization/core.pyi +61 -0
  42. maxframe/session.py +32 -2
  43. maxframe/tensor/__init__.py +1 -1
  44. maxframe/tensor/base/__init__.py +2 -0
  45. maxframe/tensor/base/atleast_1d.py +74 -0
  46. maxframe/tensor/base/unique.py +205 -0
  47. maxframe/tensor/datasource/array.py +4 -2
  48. maxframe/tensor/datasource/scalar.py +1 -1
  49. maxframe/udf.py +63 -3
  50. maxframe/utils.py +11 -0
  51. {maxframe-0.1.0b3.dist-info → maxframe-0.1.0b5.dist-info}/METADATA +2 -2
  52. {maxframe-0.1.0b3.dist-info → maxframe-0.1.0b5.dist-info}/RECORD +58 -40
  53. maxframe_client/fetcher.py +65 -3
  54. maxframe_client/session/odps.py +41 -11
  55. maxframe_client/session/task.py +26 -53
  56. maxframe_client/tests/test_session.py +49 -1
  57. {maxframe-0.1.0b3.dist-info → maxframe-0.1.0b5.dist-info}/WHEEL +0 -0
  58. {maxframe-0.1.0b3.dist-info → maxframe-0.1.0b5.dist-info}/top_level.txt +0 -0
maxframe/__init__.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from . import dataframe, learn, remote, tensor
+from .config import options
 from .session import execute, fetch, new_session, stop_server
 
 
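The new import re-exports the global configuration object at the package root. A minimal usage sketch (not part of the diff; the option shown is registered later in this release's config.py changes):

    import maxframe

    # "options" is now reachable directly from the top-level package
    maxframe.options.session.logview_hours = 72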
maxframe/_utils.cp37-win_amd64.pyd CHANGED
Binary file
maxframe/codegen.py CHANGED
@@ -16,6 +16,7 @@ import abc
 import base64
 import dataclasses
 import logging
+from collections import defaultdict
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
 
@@ -32,7 +33,7 @@ from .protocol import DataFrameTableMeta, ResultInfo
 from .serialization import PickleContainer
 from .serialization.serializables import Serializable, StringField
 from .typing_ import PandasObjectTypes
-from .udf import MarkedFunction
+from .udf import MarkedFunction, PythonPackOptions
 
 if TYPE_CHECKING:
     from odpsctx import ODPSSessionContext
@@ -75,6 +76,14 @@ class AbstractUDF(Serializable):
     def unregister(self, odps: "ODPSSessionContext"):
         raise NotImplementedError
 
+    @abc.abstractmethod
+    def collect_pythonpack(self) -> List[PythonPackOptions]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def load_pythonpack_resources(self, odps_ctx: "ODPSSessionContext") -> None:
+        raise NotImplementedError
+
 
 class UserCodeMixin:
     @classmethod
@@ -469,6 +478,42 @@ class BigDagCodeGenerator(metaclass=abc.ABCMeta):
             output_key_to_result_infos=self._context.get_tileable_result_infos(),
         )
 
+    def run_pythonpacks(
+        self,
+        odps_ctx: "ODPSSessionContext",
+        python_tag: str,
+        is_production: bool = False,
+        schedule_id: Optional[str] = None,
+        hints: Optional[dict] = None,
+        priority: Optional[int] = None,
+    ) -> Dict[str, PythonPackOptions]:
+        key_to_packs = defaultdict(list)
+        for udf in self._context.get_udfs():
+            for pack in udf.collect_pythonpack():
+                key_to_packs[pack.key].append(pack)
+        distinct_packs = []
+        for packs in key_to_packs.values():
+            distinct_packs.append(packs[0])
+
+        inst_id_to_req = {}
+        for pack in distinct_packs:
+            inst = odps_ctx.run_pythonpack(
+                requirements=pack.requirements,
+                prefer_binary=pack.prefer_binary,
+                pre_release=pack.pre_release,
+                force_rebuild=pack.force_rebuild,
+                python_tag=python_tag,
+                is_production=is_production,
+                schedule_id=schedule_id,
+                hints=hints,
+                priority=priority,
+            )
+            # fulfill instance id of pythonpacks with same keys
+            for same_pack in key_to_packs[pack.key]:
+                same_pack.pack_instance_id = inst.id
+            inst_id_to_req[inst.id] = pack
+        return inst_id_to_req
+
     def register_udfs(self, odps_ctx: "ODPSSessionContext"):
         for udf in self._context.get_udfs():
             logger.info("[Session %s] Registering UDF %s", self._session_id, udf.name)
maxframe/config/config.py CHANGED
@@ -40,6 +40,7 @@ _DEFAULT_SPE_OPERATION_TIMEOUT_SECONDS = 120
 _DEFAULT_UPLOAD_BATCH_SIZE = 4096
 _DEFAULT_TEMP_LIFECYCLE = 1
 _DEFAULT_TASK_START_TIMEOUT = 60
+_DEFAULT_LOGVIEW_HOURS = 24 * 60
 
 
 class OptionError(Exception):
@@ -296,13 +297,15 @@
 
 
 default_options = Config()
-
 default_options.register_option(
     "execution_mode", "trigger", validator=is_in(["trigger", "eager"])
 )
 default_options.register_option(
     "python_tag", get_python_tag(), validator=is_string, remote=True
 )
+default_options.register_option(
+    "session.logview_hours", _DEFAULT_LOGVIEW_HOURS, validator=is_integer, remote=True
+)
 default_options.register_option(
     "client.task_start_timeout", _DEFAULT_TASK_START_TIMEOUT, validator=is_integer
 )
@@ -312,6 +315,9 @@ default_options.register_option(
 )
 default_options.register_option("sql.settings", {}, validator=is_dict, remote=True)
 
+default_options.register_option("is_production", False, validator=is_bool, remote=True)
+default_options.register_option("schedule_id", "", validator=is_string, remote=True)
+
 default_options.register_option(
     "session.max_alive_seconds",
     _DEFAULT_MAX_ALIVE_SECONDS,
@@ -358,6 +364,9 @@ default_options.register_option(
 default_options.register_option(
     "show_progress", "auto", validator=any_validator(is_bool, is_string)
 )
+default_options.register_option(
+    "dag.settings", value=dict(), validator=is_dict, remote=True
+)
 
 ################
 # SPE Settings #
@@ -373,6 +382,10 @@ default_options.register_option(
     "spe.task.settings", dict(), validator=is_dict, remote=True
 )
 
+default_options.register_option(
+    "pythonpack.task.settings", {}, validator=is_dict, remote=True
+)
+
 _options_ctx_var = contextvars.ContextVar("_options_ctx_var")
 
 
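A sketch of how the newly registered options can be set from client code. Names and defaults are taken from the registrations above; the descriptions in the comments are inferred, not authoritative:

    from maxframe.config import options

    options.session.logview_hours = 24 * 60   # LogView lifetime, sent to service (remote=True)
    options.is_production = False             # production flag forwarded to the service
    options.schedule_id = ""                  # scheduler id, empty by default
    options.dag.settings = {}                 # extra DAG-level settings dict
    options.pythonpack.task.settings = {}     # hints for PythonPack build tasks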
maxframe/core/graph/core.cp37-win_amd64.pyd CHANGED
Binary file
maxframe/dataframe/__init__.py CHANGED
@@ -46,6 +46,7 @@ from .misc.cut import cut
 from .misc.eval import maxframe_eval as eval  # pylint: disable=redefined-builtin
 from .misc.get_dummies import get_dummies
 from .misc.melt import melt
+from .misc.pivot_table import pivot_table
 from .misc.qcut import qcut
 from .misc.to_numeric import to_numeric
 from .missing import isna, isnull, notna, notnull
@@ -57,6 +58,11 @@ try:
 except ImportError:  # pragma: no cover
     pass
 
+try:
+    from . import _internal
+except ImportError:  # pragma: no cover
+    pass
+
 del (
     arithmetic,
     datasource,
maxframe/dataframe/core.py CHANGED
@@ -35,6 +35,7 @@ from ..core import (
     register_output_types,
 )
 from ..core.entity.utils import refresh_tileable_shape
+from ..protocol import DataFrameTableMeta
 from ..serialization.serializables import (
     AnyField,
     BoolField,
@@ -59,7 +60,13 @@ from ..utils import (
     on_serialize_numpy_type,
     tokenize,
 )
-from .utils import ReprSeries, fetch_corner_data, merge_index_value, parse_index
+from .utils import (
+    ReprSeries,
+    apply_if_callable,
+    fetch_corner_data,
+    merge_index_value,
+    parse_index,
+)
 
 
 class IndexValue(Serializable):
@@ -616,6 +623,9 @@ class IndexData(HasShapeTileableData, _ToPandasMixin):
         if self._name is None:
             self._name = self.chunks[0].name
 
+    def refresh_from_table_meta(self, table_meta: DataFrameTableMeta) -> None:
+        pass
+
     def _to_str(self, representation=False):
         if is_build_mode() or len(self._executed_sessions) == 0:
             # in build mode, or not executed, just return representation
@@ -945,6 +955,9 @@ class BaseSeriesData(HasShapeTileableData, _ToPandasMixin):
         if self._name is None:
             self._name = self.chunks[0].name
 
+    def refresh_from_table_meta(self, table_meta: DataFrameTableMeta) -> None:
+        pass
+
     def _to_str(self, representation=False):
         if is_build_mode() or len(self._executed_sessions) == 0:
             # in build mode, or not executed, just return representation
@@ -960,7 +973,9 @@ class BaseSeriesData(HasShapeTileableData, _ToPandasMixin):
         buf = StringIO()
         max_rows = pd.get_option("display.max_rows")
         corner_max_rows = (
-            max_rows if self.shape[0] <= max_rows else corner_data.shape[0] - 1
+            max_rows
+            if self.shape[0] <= max_rows or corner_data.shape[0] == 0
+            else corner_data.shape[0] - 1
         )  # make sure max_rows < corner_data
 
         with pd.option_context("display.max_rows", corner_max_rows):
@@ -976,7 +991,7 @@ class BaseSeriesData(HasShapeTileableData, _ToPandasMixin):
         return self._to_str(representation=False)
 
     def __repr__(self):
-        return self._to_str(representation=False)
+        return self._to_str(representation=True)
 
     @property
     def dtype(self):
@@ -1499,6 +1514,15 @@ class BaseDataFrameData(HasShapeTileableData, _ToPandasMixin):
         refresh_index_value(self)
         refresh_dtypes(self)
 
+    def refresh_from_table_meta(self, table_meta: DataFrameTableMeta) -> None:
+        dtypes = table_meta.pd_column_dtypes
+        self._dtypes = dtypes
+        self._columns_value = parse_index(dtypes.index, store_data=True)
+        self._dtypes_value = DtypesValue(key=tokenize(dtypes), value=dtypes)
+        new_shape = list(self._shape)
+        new_shape[0] = len(dtypes)
+        self._shape = tuple(new_shape)
+
     @property
     def dtypes(self):
         dt = getattr(self, "_dtypes", None)
@@ -1605,7 +1629,7 @@ class DataFrameData(_BatchedFetcher, BaseDataFrameData):
         buf = StringIO()
         max_rows = pd.get_option("display.max_rows")
 
-        if self.shape[0] <= max_rows:
+        if self.shape[0] <= max_rows or corner_data.shape[0] == 0:
             buf.write(repr(corner_data) if representation else str(corner_data))
         else:
             # remember we cannot directly call repr(df),
@@ -1995,12 +2019,6 @@ class DataFrame(HasShapeTileable, _ToPandasMixin):
         Berkeley  25.0  77.0  298.15
         """
 
-        def apply_if_callable(maybe_callable, obj, **kwargs):
-            if callable(maybe_callable):
-                return maybe_callable(obj, **kwargs)
-
-            return maybe_callable
-
         data = self.copy()
 
         for k, v in kwargs.items():
@@ -2195,6 +2213,9 @@ class CategoricalData(HasShapeTileableData, _ToPandasMixin):
             pd.Categorical(categories).categories, store_data=True
         )
 
+    def refresh_from_table_meta(self, table_meta: DataFrameTableMeta) -> None:
+        pass
+
     def _to_str(self, representation=False):
         if is_build_mode() or len(self._executed_sessions) == 0:
             # in build mode, or not executed, just return representation
@@ -2345,6 +2366,9 @@ class DataFrameOrSeriesData(HasShapeTileableData, _ToPandasMixin):
             data_params["name"] = self.chunks[0].name
         self._data_params.update(data_params)
 
+    def refresh_from_table_meta(self, table_meta: DataFrameTableMeta) -> None:
+        pass
+
     def ensure_data(self):
         from .fetch.core import DataFrameFetch
 
maxframe/dataframe/datasource/read_odps_query.py CHANGED
@@ -216,7 +216,9 @@ class DataFrameReadODPSQuery(
             index_value = parse_index(pd.RangeIndex(0))
         elif len(self.index_columns) == 1:
             index_value = parse_index(
-                pd.Index([], name=self.index_columns[0]).astype(self.index_dtypes[0])
+                pd.Index([], name=self.index_columns[0]).astype(
+                    self.index_dtypes.iloc[0]
+                )
             )
         else:
             idx = pd.MultiIndex.from_frame(
@@ -263,7 +265,9 @@ def read_odps_query(
     result: DataFrame
         DataFrame read from MaxCompute (ODPS) table
     """
-    odps_entry = odps_entry or ODPS.from_environments()
+    odps_entry = odps_entry or ODPS.from_global() or ODPS.from_environments()
+    if odps_entry is None:
+        raise ValueError("Missing odps_entry parameter")
     inst = odps_entry.execute_sql(f"EXPLAIN {query}")
     explain_str = list(inst.get_task_results().values())[0]
 
maxframe/dataframe/datasource/read_odps_table.py CHANGED
@@ -82,7 +82,9 @@ class DataFrameReadODPSTable(
             index_value = parse_index(pd.RangeIndex(shape[0]))
         elif len(self.index_columns) == 1:
             index_value = parse_index(
-                pd.Index([], name=self.index_columns[0]).astype(self.index_dtypes[0])
+                pd.Index([], name=self.index_columns[0]).astype(
+                    self.index_dtypes.iloc[0]
+                )
             )
         else:
             idx = pd.MultiIndex.from_frame(
@@ -164,6 +166,8 @@ def read_odps_table(
         DataFrame read from MaxCompute (ODPS) table
     """
     odps_entry = odps_entry or ODPS.from_global() or ODPS.from_environments()
+    if odps_entry is None:
+        raise ValueError("Missing odps_entry parameter")
     if isinstance(table_name, Table):
         table = table_name
     else:
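Both readers now resolve their MaxCompute entry through ODPS.from_global() before falling back to environment variables, and fail fast when neither is configured. A usage sketch (the table name is a placeholder):

    import maxframe.dataframe as md
    from odps import ODPS

    o = ODPS.from_environments()  # or an explicitly constructed ODPS entry
    df = md.read_odps_table("my_table", odps_entry=o)

    # with no entry passed, no global entry and no environment config,
    # the new guard raises ValueError("Missing odps_entry parameter")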
maxframe/dataframe/datastore/core.py ADDED
@@ -0,0 +1,19 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..operators import DataFrameOperator, DataFrameOperatorMixin
+
+
+class DataFrameDataStore(DataFrameOperator, DataFrameOperatorMixin):
+    pass
maxframe/dataframe/datastore/to_csv.py CHANGED
@@ -23,11 +23,11 @@ from ...serialization.serializables import (
     ListField,
     StringField,
 )
-from ..operators import DataFrameOperator, DataFrameOperatorMixin
 from ..utils import parse_index
+from .core import DataFrameDataStore
 
 
-class DataFrameToCSV(DataFrameOperator, DataFrameOperatorMixin):
+class DataFrameToCSV(DataFrameDataStore):
     _op_type_ = opcodes.TO_CSV
 
     input = KeyField("input")
maxframe/dataframe/datastore/to_odps.py CHANGED
@@ -32,13 +32,13 @@ from ...serialization.serializables import (
 )
 from ...typing_ import TileableType
 from ..core import DataFrame  # noqa: F401
-from ..operators import DataFrameOperator, DataFrameOperatorMixin
 from ..utils import parse_index
+from .core import DataFrameDataStore
 
 logger = logging.getLogger(__name__)
 
 
-class DataFrameToODPSTable(DataFrameOperator, DataFrameOperatorMixin):
+class DataFrameToODPSTable(DataFrameDataStore):
     _op_type_ = opcodes.TO_ODPS_TABLE
 
     dtypes = SeriesField("dtypes")
maxframe/dataframe/indexing/reset_index.py CHANGED
@@ -107,7 +107,6 @@ def df_reset_index(
     inplace=False,
     col_level=0,
     col_fill="",
-    incremental_index=False,
 ):
     """
     Reset the index, or a level of it.
@@ -133,12 +132,6 @@
     col_fill : object, default ''
         If the columns have multiple levels, determines how the other
         levels are named. If None then the index name is repeated.
-    incremental_index: bool, default False
-        Ensure RangeIndex incremental, when output DataFrame has multiple chunks,
-        ensuring index incremental costs more computation,
-        so by default, each chunk will have index which starts from 0,
-        setting incremental_index=True,reset_index will guarantee that
-        output DataFrame's index is from 0 to n - 1.
 
     Returns
     -------
@@ -264,7 +257,6 @@
         drop=drop,
         col_level=col_level,
         col_fill=col_fill,
-        incremental_index=incremental_index,
         output_types=[OutputType.dataframe],
     )
     ret = op(df)
@@ -280,7 +272,6 @@ def series_reset_index(
     drop=False,
     name=no_default,
     inplace=False,
-    incremental_index=False,
 ):
     """
     Generate a new DataFrame or Series with the index reset.
@@ -303,12 +294,6 @@
         when `drop` is True.
     inplace : bool, default False
         Modify the Series in place (do not create a new object).
-    incremental_index: bool, default False
-        Ensure RangeIndex incremental, when output Series has multiple chunks,
-        ensuring index incremental costs more computation,
-        so by default, each chunk will have index which starts from 0,
-        setting incremental_index=True,reset_index will guarantee that
-        output Series's index is from 0 to n - 1.
 
     Returns
     -------
@@ -406,8 +391,7 @@
         level=level,
         drop=drop,
         name=name,
-        incremental_index=incremental_index,
-        output_types=[OutputType.series],
+        output_types=[OutputType.series if drop else OutputType.dataframe],
     )
     ret = op(series)
     if not inplace:
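With incremental_index removed, series_reset_index also derives its output type from drop instead of always emitting a Series, matching pandas semantics. A sketch of the visible behavior:

    import maxframe.dataframe as md

    s = md.Series([1, 2, 3], name="a")
    s.reset_index()           # now typed as a DataFrame (index column plus "a")
    s.reset_index(drop=True)  # still a Series, index replaced by a fresh RangeIndex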
maxframe/dataframe/misc/__init__.py CHANGED
@@ -14,6 +14,7 @@
 
 from .apply import df_apply, series_apply
 from .astype import astype, index_astype
+from .case_when import case_when
 from .check_monotonic import (
     check_monotonic,
     is_monotonic,
@@ -37,6 +38,7 @@ from .map import index_map, series_map
 from .melt import melt
 from .memory_usage import df_memory_usage, index_memory_usage, series_memory_usage
 from .pct_change import pct_change
+from .pivot_table import pivot_table
 from .qcut import qcut
 from .select_dtypes import select_dtypes
 from .shift import shift, tshift
@@ -69,6 +71,7 @@ def _install():
         setattr(t, "melt", melt)
         setattr(t, "memory_usage", df_memory_usage)
         setattr(t, "pct_change", pct_change)
+        setattr(t, "pivot_table", pivot_table)
         setattr(t, "pop", df_pop)
         setattr(t, "query", df_query)
         setattr(t, "select_dtypes", select_dtypes)
@@ -81,6 +84,7 @@
     for t in SERIES_TYPE:
         setattr(t, "apply", series_apply)
        setattr(t, "astype", astype)
+        setattr(t, "case_when", case_when)
         setattr(t, "check_monotonic", check_monotonic)
         setattr(t, "describe", describe)
         setattr(t, "diff", series_diff)
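After _install() runs, the new operators are available as instance methods. A hypothetical sketch assuming a pandas-compatible pivot_table signature (the operator's own module is not shown in this diff):

    import maxframe.dataframe as md

    df = md.DataFrame({"k": ["x", "x", "y"], "v": [1, 2, 3]})
    df.pivot_table(values="v", index="k", aggfunc="sum")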
maxframe/dataframe/misc/apply.py CHANGED
@@ -225,7 +225,7 @@ class ApplyOperator(
             else:  # pragma: no cover
                 index_value = parse_index(infer_series.index)
         else:
-            index_value = parse_index(None, series)
+            index_value = parse_index(series.index_value)
 
         if output_type == OutputType.dataframe:
             if dtypes is None:
maxframe/dataframe/misc/case_when.py ADDED
@@ -0,0 +1,141 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from pandas.core.dtypes.cast import find_common_type
+
+from ... import opcodes
+from ...core import TILEABLE_TYPE
+from ...serialization.serializables import FieldTypes, ListField
+from ..core import SERIES_TYPE
+from ..operators import DataFrameOperator, DataFrameOperatorMixin
+from ..utils import apply_if_callable
+
+
+class DataFrameCaseWhen(DataFrameOperator, DataFrameOperatorMixin):
+    _op_type_ = opcodes.CASE_WHEN
+
+    conditions = ListField("conditions", FieldTypes.reference, default=None)
+    replacements = ListField("replacements", FieldTypes.reference, default=None)
+
+    def __init__(self, output_types=None, **kw):
+        super().__init__(_output_types=output_types, **kw)
+
+    def _set_inputs(self, inputs):
+        super()._set_inputs(inputs)
+        it = iter(inputs)
+        next(it)
+        self.conditions = [
+            next(it) if isinstance(t, TILEABLE_TYPE) else t for t in self.conditions
+        ]
+        self.replacements = [
+            next(it) if isinstance(t, TILEABLE_TYPE) else t for t in self.replacements
+        ]
+
+    def __call__(self, series):
+        replacement_dtypes = [
+            it.dtype if isinstance(it, SERIES_TYPE) else np.array(it).dtype
+            for it in self.replacements
+        ]
+        dtype = find_common_type([series.dtype] + replacement_dtypes)
+
+        condition_tileables = [
+            it for it in self.conditions if isinstance(it, TILEABLE_TYPE)
+        ]
+        replacement_tileables = [
+            it for it in self.replacements if isinstance(it, TILEABLE_TYPE)
+        ]
+        inputs = [series] + condition_tileables + replacement_tileables
+
+        params = series.params
+        params["dtype"] = dtype
+        return self.new_series(inputs, **params)
+
+
+def case_when(series, caselist):
+    """
+    Replace values where the conditions are True.
+
+    Parameters
+    ----------
+    caselist : A list of tuples of conditions and expected replacements
+        Takes the form:  ``(condition0, replacement0)``,
+        ``(condition1, replacement1)``, ... .
+        ``condition`` should be a 1-D boolean array-like object
+        or a callable. If ``condition`` is a callable,
+        it is computed on the Series
+        and should return a boolean Series or array.
+        The callable must not change the input Series
+        (though pandas doesn`t check it). ``replacement`` should be a
+        1-D array-like object, a scalar or a callable.
+        If ``replacement`` is a callable, it is computed on the Series
+        and should return a scalar or Series. The callable
+        must not change the input Series.
+
+    Returns
+    -------
+    Series
+
+    See Also
+    --------
+    Series.mask : Replace values where the condition is True.
+
+    Examples
+    --------
+    >>> import maxframe.dataframe as md
+    >>> c = md.Series([6, 7, 8, 9], name='c')
+    >>> a = md.Series([0, 0, 1, 2])
+    >>> b = md.Series([0, 3, 4, 5])
+
+    >>> c.case_when(caselist=[(a.gt(0), a),  # condition, replacement
+    ...                       (b.gt(0), b)])
+    0    6
+    1    3
+    2    1
+    3    2
+    Name: c, dtype: int64
+    """
+    if not isinstance(caselist, list):
+        raise TypeError(
+            f"The caselist argument should be a list; instead got {type(caselist)}"
+        )
+
+    if not caselist:
+        raise ValueError(
+            "provide at least one boolean condition, "
+            "with a corresponding replacement."
+        )
+
+    for num, entry in enumerate(caselist):
+        if not isinstance(entry, tuple):
+            raise TypeError(
+                f"Argument {num} must be a tuple; instead got {type(entry)}."
+            )
+        if len(entry) != 2:
+            raise ValueError(
+                f"Argument {num} must have length 2; "
+                "a condition and replacement; "
+                f"instead got length {len(entry)}."
+            )
+    caselist = [
+        (
+            apply_if_callable(condition, series),
+            apply_if_callable(replacement, series),
+        )
+        for condition, replacement in caselist
+    ]
+    conditions = [case[0] for case in caselist]
+    replacements = [case[1] for case in caselist]
+    op = DataFrameCaseWhen(conditions=conditions, replacements=replacements)
+    return op(series)
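Per the docstring, conditions and replacements may also be callables, which apply_if_callable evaluates against the Series before the operator is built. A small usage sketch:

    import maxframe.dataframe as md

    c = md.Series([6, 7, 8, 9], name="c")
    # the lambda receives the Series and must return a boolean mask
    out = c.case_when(caselist=[(lambda s: s > 7, 0)])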