maxframe 1.0.0rc1__cp311-cp311-win32.whl → 1.0.0rc3__cp311-cp311-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. Click here for more details.

Files changed (138) hide show
  1. maxframe/_utils.cp311-win32.pyd +0 -0
  2. maxframe/codegen.py +3 -6
  3. maxframe/config/config.py +49 -10
  4. maxframe/config/validators.py +42 -11
  5. maxframe/conftest.py +15 -2
  6. maxframe/core/__init__.py +2 -13
  7. maxframe/core/entity/__init__.py +0 -4
  8. maxframe/core/entity/objects.py +46 -3
  9. maxframe/core/entity/output_types.py +0 -3
  10. maxframe/core/entity/tests/test_objects.py +43 -0
  11. maxframe/core/entity/tileables.py +5 -78
  12. maxframe/core/graph/__init__.py +2 -2
  13. maxframe/core/graph/builder/__init__.py +0 -1
  14. maxframe/core/graph/builder/base.py +5 -4
  15. maxframe/core/graph/builder/tileable.py +4 -4
  16. maxframe/core/graph/builder/utils.py +4 -8
  17. maxframe/core/graph/core.cp311-win32.pyd +0 -0
  18. maxframe/core/graph/entity.py +9 -33
  19. maxframe/core/operator/__init__.py +2 -9
  20. maxframe/core/operator/base.py +3 -5
  21. maxframe/core/operator/objects.py +0 -9
  22. maxframe/core/operator/utils.py +55 -0
  23. maxframe/dataframe/__init__.py +1 -1
  24. maxframe/dataframe/arithmetic/around.py +5 -17
  25. maxframe/dataframe/arithmetic/core.py +15 -7
  26. maxframe/dataframe/arithmetic/docstring.py +5 -55
  27. maxframe/dataframe/arithmetic/tests/test_arithmetic.py +22 -0
  28. maxframe/dataframe/core.py +5 -5
  29. maxframe/dataframe/datasource/date_range.py +2 -2
  30. maxframe/dataframe/datasource/read_odps_query.py +7 -1
  31. maxframe/dataframe/datasource/read_odps_table.py +3 -2
  32. maxframe/dataframe/datasource/tests/test_datasource.py +14 -0
  33. maxframe/dataframe/datastore/to_odps.py +1 -1
  34. maxframe/dataframe/groupby/cum.py +0 -1
  35. maxframe/dataframe/groupby/tests/test_groupby.py +4 -0
  36. maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
  37. maxframe/dataframe/indexing/rename.py +3 -37
  38. maxframe/dataframe/indexing/sample.py +0 -1
  39. maxframe/dataframe/indexing/set_index.py +68 -1
  40. maxframe/dataframe/merge/merge.py +236 -2
  41. maxframe/dataframe/merge/tests/test_merge.py +123 -0
  42. maxframe/dataframe/misc/apply.py +3 -10
  43. maxframe/dataframe/misc/case_when.py +1 -1
  44. maxframe/dataframe/misc/describe.py +2 -2
  45. maxframe/dataframe/misc/drop_duplicates.py +4 -25
  46. maxframe/dataframe/misc/eval.py +4 -0
  47. maxframe/dataframe/misc/pct_change.py +1 -83
  48. maxframe/dataframe/misc/transform.py +1 -30
  49. maxframe/dataframe/misc/value_counts.py +4 -17
  50. maxframe/dataframe/missing/dropna.py +1 -1
  51. maxframe/dataframe/missing/fillna.py +5 -5
  52. maxframe/dataframe/operators.py +1 -17
  53. maxframe/dataframe/reduction/core.py +2 -2
  54. maxframe/dataframe/sort/sort_values.py +1 -11
  55. maxframe/dataframe/statistics/quantile.py +5 -17
  56. maxframe/dataframe/utils.py +4 -7
  57. maxframe/io/objects/__init__.py +24 -0
  58. maxframe/io/objects/core.py +140 -0
  59. maxframe/io/objects/tensor.py +76 -0
  60. maxframe/io/objects/tests/__init__.py +13 -0
  61. maxframe/io/objects/tests/test_object_io.py +97 -0
  62. maxframe/{odpsio → io/odpsio}/__init__.py +3 -1
  63. maxframe/{odpsio → io/odpsio}/arrow.py +12 -8
  64. maxframe/{odpsio → io/odpsio}/schema.py +15 -12
  65. maxframe/io/odpsio/tableio.py +702 -0
  66. maxframe/io/odpsio/tests/__init__.py +13 -0
  67. maxframe/{odpsio → io/odpsio}/tests/test_schema.py +19 -18
  68. maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +50 -23
  69. maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
  70. maxframe/io/odpsio/volumeio.py +57 -0
  71. maxframe/learn/contrib/xgboost/classifier.py +26 -2
  72. maxframe/learn/contrib/xgboost/core.py +87 -2
  73. maxframe/learn/contrib/xgboost/dmatrix.py +3 -6
  74. maxframe/learn/contrib/xgboost/predict.py +21 -7
  75. maxframe/learn/contrib/xgboost/regressor.py +3 -10
  76. maxframe/learn/contrib/xgboost/train.py +27 -17
  77. maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
  78. maxframe/lib/mmh3.cp311-win32.pyd +0 -0
  79. maxframe/protocol.py +41 -17
  80. maxframe/remote/core.py +4 -8
  81. maxframe/serialization/__init__.py +1 -0
  82. maxframe/serialization/core.cp311-win32.pyd +0 -0
  83. maxframe/serialization/serializables/core.py +48 -9
  84. maxframe/tensor/__init__.py +69 -2
  85. maxframe/tensor/arithmetic/isclose.py +1 -0
  86. maxframe/tensor/arithmetic/tests/test_arithmetic.py +21 -17
  87. maxframe/tensor/core.py +5 -136
  88. maxframe/tensor/datasource/array.py +3 -0
  89. maxframe/tensor/datasource/full.py +1 -1
  90. maxframe/tensor/datasource/tests/test_datasource.py +1 -1
  91. maxframe/tensor/indexing/flatnonzero.py +1 -1
  92. maxframe/tensor/merge/__init__.py +2 -0
  93. maxframe/tensor/merge/concatenate.py +98 -0
  94. maxframe/tensor/merge/tests/test_merge.py +30 -1
  95. maxframe/tensor/merge/vstack.py +70 -0
  96. maxframe/tensor/{base → misc}/__init__.py +2 -0
  97. maxframe/tensor/{base → misc}/atleast_1d.py +0 -2
  98. maxframe/tensor/misc/atleast_2d.py +70 -0
  99. maxframe/tensor/misc/atleast_3d.py +85 -0
  100. maxframe/tensor/misc/tests/__init__.py +13 -0
  101. maxframe/tensor/{base → misc}/transpose.py +22 -18
  102. maxframe/tensor/{base → misc}/unique.py +2 -2
  103. maxframe/tensor/operators.py +1 -7
  104. maxframe/tensor/random/core.py +1 -1
  105. maxframe/tensor/reduction/count_nonzero.py +1 -0
  106. maxframe/tensor/reduction/mean.py +1 -0
  107. maxframe/tensor/reduction/nanmean.py +1 -0
  108. maxframe/tensor/reduction/nanvar.py +2 -0
  109. maxframe/tensor/reduction/tests/test_reduction.py +12 -1
  110. maxframe/tensor/reduction/var.py +2 -0
  111. maxframe/tensor/statistics/quantile.py +2 -2
  112. maxframe/tensor/utils.py +2 -22
  113. maxframe/tests/utils.py +11 -2
  114. maxframe/typing_.py +4 -1
  115. maxframe/udf.py +8 -9
  116. maxframe/utils.py +32 -70
  117. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/METADATA +25 -25
  118. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/RECORD +133 -123
  119. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/WHEEL +1 -1
  120. maxframe_client/fetcher.py +60 -68
  121. maxframe_client/session/graph.py +8 -2
  122. maxframe_client/session/odps.py +58 -22
  123. maxframe_client/tests/test_fetcher.py +21 -3
  124. maxframe_client/tests/test_session.py +27 -4
  125. maxframe/core/entity/chunks.py +0 -68
  126. maxframe/core/entity/fuse.py +0 -73
  127. maxframe/core/graph/builder/chunk.py +0 -430
  128. maxframe/odpsio/tableio.py +0 -322
  129. maxframe/odpsio/volumeio.py +0 -95
  130. /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
  131. /maxframe/{tensor/base/tests → io}/__init__.py +0 -0
  132. /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
  133. /maxframe/tensor/{base → misc}/astype.py +0 -0
  134. /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
  135. /maxframe/tensor/{base → misc}/ravel.py +0 -0
  136. /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
  137. /maxframe/tensor/{base → misc}/where.py +0 -0
  138. {maxframe-1.0.0rc1.dist-info → maxframe-1.0.0rc3.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,7 @@ from odps.types import PartitionSpec
23
23
  from ... import opcodes
24
24
  from ...config import options
25
25
  from ...core import OutputType
26
- from ...odpsio import build_dataframe_table_meta
26
+ from ...io.odpsio import build_dataframe_table_meta
27
27
  from ...serialization.serializables import (
28
28
  BoolField,
29
29
  FieldTypes,
@@ -59,7 +59,6 @@ class GroupByCumReductionOperator(DataFrameOperatorMixin, DataFrameOperator):
59
59
  out_dtypes = self._calc_out_dtypes(groupby)
60
60
 
61
61
  kw = in_df.params.copy()
62
- kw["index_value"] = parse_index(pd.RangeIndex(-1), groupby.key)
63
62
  if self.output_types[0] == OutputType.dataframe:
64
63
  kw.update(
65
64
  dict(
@@ -282,14 +282,17 @@ def test_groupby_cum():
282
282
  r = getattr(mdf.groupby("b"), fun)()
283
283
  assert r.op.output_types[0] == OutputType.dataframe
284
284
  assert r.shape == (len(df1), 2)
285
+ assert r.index_value.key == mdf.index_value.key
285
286
 
286
287
  r = getattr(mdf.groupby("b"), fun)(axis=1)
287
288
  assert r.op.output_types[0] == OutputType.dataframe
288
289
  assert r.shape == (len(df1), 3)
290
+ assert r.index_value.key == mdf.index_value.key
289
291
 
290
292
  r = mdf.groupby("b").cumcount()
291
293
  assert r.op.output_types[0] == OutputType.series
292
294
  assert r.shape == (len(df1),)
295
+ assert r.index_value.key == mdf.index_value.key
293
296
 
294
297
  series1 = pd.Series([2, 2, 5, 7, 3, 7, 8, 8, 5, 6])
295
298
  ms1 = md.Series(series1, chunk_size=3)
@@ -298,6 +301,7 @@ def test_groupby_cum():
298
301
  r = getattr(ms1.groupby(lambda x: x % 2), fun)()
299
302
  assert r.op.output_types[0] == OutputType.series
300
303
  assert r.shape == (len(series1),)
304
+ assert r.index_value.key == ms1.index_value.key
301
305
 
302
306
 
303
307
  def test_groupby_fill():
@@ -51,7 +51,7 @@ def _get_prefix_suffix_docs(is_prefix: bool):
51
51
  Examples
52
52
  --------
53
53
  >>> import maxframe.dataframe as md
54
- >>> s = md.Series([1, 2, 3, 4])
54
+ >>> s = md.Series([1, 2, 3, 4])
55
55
  >>> s.execute()
56
56
  0 1
57
57
  1 2
@@ -17,7 +17,7 @@ import warnings
17
17
  from ... import opcodes
18
18
  from ...core import get_output_types
19
19
  from ...serialization.serializables import AnyField, StringField
20
- from ..core import SERIES_TYPE
20
+ from ..core import INDEX_TYPE, SERIES_TYPE
21
21
  from ..operators import DataFrameOperator, DataFrameOperatorMixin
22
22
  from ..utils import build_df, build_series, parse_index, validate_axis
23
23
 
@@ -73,6 +73,8 @@ class DataFrameRename(DataFrameOperator, DataFrameOperatorMixin):
73
73
  params["index_value"] = parse_index(new_index)
74
74
  if df.ndim == 1:
75
75
  params["name"] = new_df.name
76
+ if isinstance(df, INDEX_TYPE):
77
+ params["names"] = new_df.names
76
78
  return self.new_tileable([df], **params)
77
79
 
78
80
 
@@ -303,11 +305,6 @@ def series_rename(
303
305
  1 2
304
306
  2 3
305
307
  Name: my_name, dtype: int64
306
- >>> s.rename(lambda x: x ** 2).execute() # function, changes labels.execute()
307
- 0 1
308
- 1 2
309
- 4 3
310
- dtype: int64
311
308
  >>> s.rename({1: 3, 2: 5}).execute() # mapping, changes labels.execute()
312
309
  0 1
313
310
  3 2
@@ -410,37 +407,6 @@ def index_set_names(index, names, level=None, inplace=False):
410
407
  See Also
411
408
  --------
412
409
  Index.rename : Able to set new names without level.
413
-
414
- Examples
415
- --------
416
- >>> import maxframe.dataframe as md
417
- >>> idx = md.Index([1, 2, 3, 4])
418
- >>> idx.execute()
419
- Int64Index([1, 2, 3, 4], dtype='int64')
420
- >>> idx.set_names('quarter').execute()
421
- Int64Index([1, 2, 3, 4], dtype='int64', name='quarter')
422
-
423
- >>> idx = md.MultiIndex.from_product([['python', 'cobra'],
424
- ... [2018, 2019]])
425
- >>> idx.execute()
426
- MultiIndex([('python', 2018),
427
- ('python', 2019),
428
- ( 'cobra', 2018),
429
- ( 'cobra', 2019)],
430
- )
431
- >>> idx.set_names(['kind', 'year'], inplace=True)
432
- >>> idx.execute()
433
- MultiIndex([('python', 2018),
434
- ('python', 2019),
435
- ( 'cobra', 2018),
436
- ( 'cobra', 2019)],
437
- names=['kind', 'year'])
438
- >>> idx.set_names('species', level=0).execute()
439
- MultiIndex([('python', 2018),
440
- ('python', 2019),
441
- ( 'cobra', 2018),
442
- ( 'cobra', 2019)],
443
- names=['species', 'year'])
444
410
  """
445
411
  op = DataFrameRename(
446
412
  index_mapper=names, level=level, output_types=get_output_types(index)
@@ -195,7 +195,6 @@ def sample(
195
195
  num_legs num_wings num_specimen_seen
196
196
  falcon 2 2 10
197
197
  fish 0 0 8
198
-
199
198
  """
200
199
  axis = validate_axis(axis or 0, df_or_series)
201
200
  if axis == 1:
@@ -31,7 +31,7 @@ class DataFrameSetIndex(DataFrameOperator, DataFrameOperatorMixin):
31
31
  super().__init__(_output_types=output_types, **kw)
32
32
 
33
33
  def __call__(self, df):
34
- new_df = build_empty_df(df.dtypes).set_index(
34
+ new_df = build_empty_df(df.dtypes, index=df.index_value.to_pandas()).set_index(
35
35
  keys=self.keys,
36
36
  drop=self.drop,
37
37
  append=self.append,
@@ -47,6 +47,73 @@ class DataFrameSetIndex(DataFrameOperator, DataFrameOperatorMixin):
47
47
 
48
48
 
49
49
  def set_index(df, keys, drop=True, append=False, inplace=False, verify_integrity=False):
50
+ # TODO add support for set index by series, index, mt.ndarray, etc.
51
+ """
52
+ Set the DataFrame index using existing columns.
53
+
54
+ Set the DataFrame index (row labels) using one or more existing
55
+ columns. The index can replace the existing index or expand on it.
56
+
57
+ Parameters
58
+ ----------
59
+ keys : label or array-like or list of labels
60
+ This parameter can be either a single column key, or a list containing column keys.
61
+ drop : bool, default True
62
+ Delete columns to be used as the new index.
63
+ append : bool, default False
64
+ Whether to append columns to existing index.
65
+ inplace : bool, default False
66
+ If True, modifies the DataFrame in place (do not create a new object).
67
+ verify_integrity : bool, default False
68
+ Check the new index for duplicates. Otherwise defer the check until
69
+ necessary. Setting to False will improve the performance of this
70
+ method.
71
+
72
+ Returns
73
+ -------
74
+ DataFrame or None
75
+ Changed row labels or None if ``inplace=True``.
76
+
77
+ See Also
78
+ --------
79
+ DataFrame.reset_index : Opposite of set_index.
80
+ DataFrame.reindex : Change to new indices or expand indices.
81
+ DataFrame.reindex_like : Change to same indices as other DataFrame.
82
+
83
+ Examples
84
+ --------
85
+ >>> import maxframe.dataframe as md
86
+
87
+ >>> df = md.DataFrame({'month': [1, 4, 7, 10],
88
+ ... 'year': [2012, 2014, 2013, 2014],
89
+ ... 'sale': [55, 40, 84, 31]})
90
+ >>> df
91
+ month year sale
92
+ 0 1 2012 55
93
+ 1 4 2014 40
94
+ 2 7 2013 84
95
+ 3 10 2014 31
96
+
97
+ Set the index to become the 'month' column:
98
+
99
+ >>> df.set_index('month')
100
+ year sale
101
+ month
102
+ 1 2012 55
103
+ 4 2014 40
104
+ 7 2013 84
105
+ 10 2014 31
106
+
107
+ Create a MultiIndex using columns 'year' and 'month':
108
+
109
+ >>> df.set_index(['year', 'month'])
110
+ sale
111
+ year month
112
+ 2012 1 55
113
+ 2014 4 40
114
+ 2013 7 84
115
+ 2014 10 31
116
+ """
50
117
  op = DataFrameSetIndex(
51
118
  keys=keys,
52
119
  drop=drop,
@@ -11,12 +11,13 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
14
  import logging
15
+ from abc import abstractmethod
16
16
  from collections import namedtuple
17
- from typing import Any, Dict, Optional, Tuple, Union
17
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
18
18
 
19
19
  import numpy as np
20
+ from pandas import Index
20
21
 
21
22
  from ... import opcodes
22
23
  from ...core import OutputType
@@ -28,6 +29,7 @@ from ...serialization.serializables import (
28
29
  Int32Field,
29
30
  KeyField,
30
31
  NamedTupleField,
32
+ Serializable,
31
33
  StringField,
32
34
  TupleField,
33
35
  )
@@ -73,9 +75,208 @@ class DataFrameMergeAlign(MapReduceOperator, DataFrameOperatorMixin):
73
75
  MergeSplitInfo = namedtuple("MergeSplitInfo", "split_side, split_index, nsplits")
74
76
 
75
77
 
78
+ class JoinHint(Serializable):
79
+ @abstractmethod
80
+ def verify_params(
81
+ self,
82
+ hint_on_df: Union[DataFrame, Series],
83
+ on: str,
84
+ is_on_index: bool,
85
+ how: str,
86
+ is_hint_for_left: bool,
87
+ ):
88
+ pass
89
+
90
+ @abstractmethod
91
+ def verify_can_work_with(self, other: "JoinHint"):
92
+ pass
93
+
94
+
95
+ class MapJoinHint(JoinHint):
96
+ def verify_params(
97
+ self,
98
+ hint_on_df: Union[DataFrame, Series],
99
+ on: str,
100
+ is_on_index: bool,
101
+ how: str,
102
+ is_hint_for_left: bool,
103
+ ):
104
+ if how in ("cross", "outer"):
105
+ raise ValueError(
106
+ "Invalid join hint, MapJoinHint is not support in cross and outer join"
107
+ )
108
+
109
+ def verify_can_work_with(self, other: JoinHint):
110
+ if isinstance(other, SkewJoinHint):
111
+ raise ValueError(
112
+ "Invalid join hint, SkewJoinHint cannot work with MapJoinHint"
113
+ )
114
+
115
+
116
+ class DistributedMapJoinHint(JoinHint):
117
+ shard_count = Int32Field("shard_count")
118
+ replica_count = Int32Field("replica_count", default=1)
119
+
120
+ def verify_params(
121
+ self,
122
+ hint_on_df: Union[DataFrame, Series],
123
+ on: str,
124
+ is_on_index: bool,
125
+ how: str,
126
+ is_hint_for_left: bool,
127
+ ):
128
+ if how in ("cross", "outer"):
129
+ raise ValueError(
130
+ "Invalid join hint, DistributedMapJoinHint is not support in cross and outer join"
131
+ )
132
+ if not hasattr(self, "shard_count"):
133
+ raise ValueError(
134
+ "Invalid DistributedMapJoinHint, shard_count must be specified"
135
+ )
136
+ if self.shard_count <= 0 or self.replica_count <= 0:
137
+ raise ValueError(
138
+ "Invalid DistributedMapJoinHint, shard_count and replica_count must be greater than 0"
139
+ )
140
+
141
+ def verify_can_work_with(self, other: JoinHint):
142
+ pass
143
+
144
+
145
+ class SkewJoinHint(JoinHint):
146
+ columns = AnyField("columns", default=None)
147
+
148
+ @staticmethod
149
+ def _check_index_levels(index, level_list):
150
+ selected_levels = set()
151
+ valid_levels = set(range(index.nlevels))
152
+ valid_level_names = set(index.names)
153
+
154
+ for item in level_list:
155
+ if isinstance(item, int):
156
+ if item not in valid_levels:
157
+ raise ValueError(f"Level {item} is not a valid index level")
158
+ if item in selected_levels:
159
+ raise ValueError(f"Level {item} is selected multiple times")
160
+ selected_levels.add(item)
161
+ elif isinstance(item, str):
162
+ if item not in valid_level_names:
163
+ raise ValueError(f"'{item}' is not a valid index level name")
164
+ level = index.names.index(item)
165
+ if level in selected_levels:
166
+ raise ValueError(
167
+ f"'{item}' (Level {level}) is selected multiple times"
168
+ )
169
+ selected_levels.add(level)
170
+ else:
171
+ raise ValueError(f"Invalid input type: {type(item)}")
172
+
173
+ @staticmethod
174
+ def _check_columns(join_on_columns, column_list):
175
+ selected_columns = set()
176
+ valid_columns = set(join_on_columns)
177
+
178
+ for item in column_list:
179
+ if isinstance(item, int):
180
+ if item < 0 or item >= len(join_on_columns):
181
+ raise ValueError(f"Column index {item} is out of range")
182
+ col_name = join_on_columns[item]
183
+ if col_name in selected_columns:
184
+ raise ValueError(
185
+ f"Column '{col_name}' (index {item}) is selected multiple times"
186
+ )
187
+ selected_columns.add(col_name)
188
+ elif isinstance(item, str):
189
+ if item not in valid_columns:
190
+ raise ValueError(f"'{item}' is not a valid column name")
191
+ if item in selected_columns:
192
+ raise ValueError(f"Column '{item}' is selected multiple times")
193
+ selected_columns.add(item)
194
+ else:
195
+ raise ValueError(f"Invalid input type: {type(item)}")
196
+
197
+ def verify_params(
198
+ self,
199
+ hint_on_df: Union[DataFrame, Series],
200
+ on: str,
201
+ is_on_index: bool,
202
+ how: str,
203
+ is_hint_for_left: bool,
204
+ ):
205
+ if how in ("cross", "outer"):
206
+ raise ValueError(
207
+ "Invalid join hint, map join is not support in cross and outer join"
208
+ )
209
+ if is_hint_for_left and how == "right":
210
+ raise ValueError(
211
+ "Invalid join hint, right join can only use SkewJoinHint on right frame"
212
+ )
213
+ elif not is_hint_for_left and how == "left":
214
+ raise ValueError(
215
+ "Invalid join hint, left join can only use SkewJoinHint on left frame"
216
+ )
217
+
218
+ # check columns
219
+ if self.columns is None:
220
+ return
221
+
222
+ if not isinstance(self.columns, list):
223
+ raise TypeError("Invalid SkewJoinHint, `columns` must be a list")
224
+
225
+ if all(isinstance(item, (int, str)) for item in self.columns):
226
+ # if elements are int (levels) or str (index names or column names)
227
+ self._verify_valid_index_or_columns(
228
+ self.columns, hint_on_df.index_value.to_pandas(), on, is_on_index
229
+ )
230
+ elif all(isinstance(c, dict) for c in self.columns):
231
+ # dict with column names and values
232
+ cols_set = set(self.columns[0].keys())
233
+ if any(cols_set != set(c.keys()) for c in self.columns):
234
+ raise ValueError(
235
+ "Invalid SkewJoinHint, all values in `columns` need to have same columns"
236
+ )
237
+
238
+ self._verify_valid_index_or_columns(
239
+ cols_set, hint_on_df.index_value.to_pandas(), on, is_on_index
240
+ )
241
+ else:
242
+ raise TypeError("Invalid SkewJoinHint, annot accept `columns` type")
243
+
244
+ def verify_can_work_with(self, other: JoinHint):
245
+ if isinstance(other, SkewJoinHint):
246
+ raise ValueError(
247
+ "Invalid join hint, SkewJoinHint cannot work with MapJoinHint"
248
+ )
249
+
250
+ @staticmethod
251
+ def _verify_valid_index_or_columns(
252
+ skew_join_columns: Iterable[Union[int, str]],
253
+ frame_index: Index,
254
+ on: Union[str, List[str]],
255
+ is_on_index: bool,
256
+ ):
257
+ if isinstance(on, str):
258
+ on = [on]
259
+ on_columns = set(frame_index.names if is_on_index else on)
260
+ for col in skew_join_columns:
261
+ if isinstance(col, int):
262
+ if col < 0 or col >= len(on_columns):
263
+ raise ValueError(
264
+ f"Invalid, SkeJoinHint, `{col}` is out of join on columns range"
265
+ )
266
+ else:
267
+ if col not in on_columns:
268
+ raise ValueError(
269
+ f"Invalid SkewJoinHint, '{col}' is not a valid column name"
270
+ )
271
+
272
+
76
273
  class DataFrameMerge(DataFrameOperator, DataFrameOperatorMixin):
77
274
  _op_type_ = opcodes.DATAFRAME_MERGE
78
275
 
276
+ # workaround for new field since v1.0.0rc2
277
+ # todo remove this when all versions below v1.0.0rc1 is eliminated
278
+ _legacy_new_non_primitives = ["left_hint", "right_hint"]
279
+
79
280
  how = StringField("how")
80
281
  on = AnyField("on")
81
282
  left_on = AnyField("left_on")
@@ -95,6 +296,8 @@ class DataFrameMerge(DataFrameOperator, DataFrameOperatorMixin):
95
296
 
96
297
  # only for broadcast merge
97
298
  split_info = NamedTupleField("split_info")
299
+ left_hint = AnyField("left_hint", default=None)
300
+ right_hint = AnyField("right_hint", default=None)
98
301
 
99
302
  def __init__(self, copy=None, **kwargs):
100
303
  super().__init__(copy_=copy, **kwargs)
@@ -165,6 +368,8 @@ def merge(
165
368
  auto_merge_threshold: int = 8,
166
369
  bloom_filter: Union[bool, str] = "auto",
167
370
  bloom_filter_options: Dict[str, Any] = None,
371
+ left_hint: JoinHint = None,
372
+ right_hint: JoinHint = None,
168
373
  ) -> DataFrame:
169
374
  """
170
375
  Merge DataFrame or named Series objects with a database-style join.
@@ -267,6 +472,12 @@ def merge(
267
472
  when chunk size of left and right is greater than this threshold, apply bloom filter
268
473
  * "filter": "large", "small", "both", default "large"
269
474
  decides to filter on large, small or both DataFrames.
475
+ left_hint: JoinHint, default None
476
+ Join strategy to use for left frame. When data skew occurs, consider these strategies to avoid long-tail issues,
477
+ but use them cautiously to prevent OOM and unnecessary overhead.
478
+ right_hint: JoinHint, default None
479
+ Join strategy to use for right frame.
480
+
270
481
 
271
482
  Returns
272
483
  -------
@@ -381,6 +592,18 @@ def merge(
381
592
  raise ValueError(
382
593
  f"Invalid filter {k}, available: {BLOOM_FILTER_ON_OPTIONS}"
383
594
  )
595
+
596
+ if left_hint:
597
+ if not isinstance(left_hint, JoinHint):
598
+ raise TypeError(f"left_hint must be a JoinHint, got {type(left_hint)}")
599
+ left_hint.verify_can_work_with(right_hint)
600
+ left_hint.verify_params(df, on or left_on, left_index, how, True)
601
+
602
+ if right_hint:
603
+ if not isinstance(right_hint, JoinHint):
604
+ raise TypeError(f"right_hint must be a JoinHint, got {type(right_hint)}")
605
+ right_hint.verify_params(right, on or right_on, right_index, how, False)
606
+
384
607
  op = DataFrameMerge(
385
608
  how=how,
386
609
  on=on,
@@ -399,6 +622,8 @@ def merge(
399
622
  bloom_filter=bloom_filter,
400
623
  bloom_filter_options=bloom_filter_options,
401
624
  output_types=[OutputType.dataframe],
625
+ left_hint=left_hint,
626
+ right_hint=right_hint,
402
627
  )
403
628
  return op(df, right)
404
629
 
@@ -416,6 +641,8 @@ def join(
416
641
  auto_merge_threshold: int = 8,
417
642
  bloom_filter: Union[bool, Dict] = True,
418
643
  bloom_filter_options: Dict[str, Any] = None,
644
+ left_hint: JoinHint = None,
645
+ right_hint: JoinHint = None,
419
646
  ) -> DataFrame:
420
647
  """
421
648
  Join columns of another DataFrame.
@@ -480,6 +707,11 @@ def join(
480
707
  when chunk size of left and right is greater than this threshold, apply bloom filter
481
708
  * "filter": "large", "small", "both", default "large"
482
709
  decides to filter on large, small or both DataFrames.
710
+ left_hint: JoinHint, default None
711
+ Join strategy to use for left frame. When data skew occurs, consider these strategies to avoid long-tail issues,
712
+ but use them cautiously to prevent OOM and unnecessary overhead.
713
+ right_hint: JoinHint, default None
714
+ Join strategy to use for right frame.
483
715
 
484
716
  Returns
485
717
  -------
@@ -590,4 +822,6 @@ def join(
590
822
  auto_merge_threshold=auto_merge_threshold,
591
823
  bloom_filter=bloom_filter,
592
824
  bloom_filter_options=bloom_filter_options,
825
+ left_hint=left_hint,
826
+ right_hint=right_hint,
593
827
  )
@@ -19,6 +19,7 @@ import pytest
19
19
  from ...core import IndexValue
20
20
  from ...datasource.dataframe import from_pandas
21
21
  from .. import DataFrameMerge, concat
22
+ from ..merge import DistributedMapJoinHint, MapJoinHint, SkewJoinHint
22
23
 
23
24
 
24
25
  def test_merge():
@@ -30,14 +31,39 @@ def test_merge():
30
31
  mdf1 = from_pandas(df1, chunk_size=2)
31
32
  mdf2 = from_pandas(df2, chunk_size=3)
32
33
 
34
+ mapjoin = MapJoinHint()
35
+ dist_mapjoin1 = DistributedMapJoinHint(shard_count=5)
36
+ skew_join1 = SkewJoinHint()
37
+ skew_join2 = SkewJoinHint(columns=[0])
38
+ skew_join3 = SkewJoinHint(columns=[{"a": 4}, {"a": 6}])
39
+ skew_join4 = SkewJoinHint(columns=[{"a": 4, "b": "test"}, {"a": 5, "b": "hello"}])
40
+
33
41
  parameters = [
34
42
  {},
35
43
  {"how": "left", "right_on": "x", "left_index": True},
44
+ {
45
+ "how": "left",
46
+ "right_on": "x",
47
+ "left_index": True,
48
+ "left_hint": mapjoin,
49
+ "right_hint": mapjoin,
50
+ },
36
51
  {"how": "right", "left_on": "a", "right_index": True},
52
+ {
53
+ "how": "right",
54
+ "left_on": "a",
55
+ "right_index": True,
56
+ "left_hint": mapjoin,
57
+ "right_hint": dist_mapjoin1,
58
+ },
37
59
  {"how": "left", "left_on": "a", "right_on": "x"},
60
+ {"how": "left", "left_on": "a", "right_on": "x", "left_hint": skew_join1},
38
61
  {"how": "right", "left_on": "a", "right_index": True},
62
+ {"how": "right", "left_on": "a", "right_index": True, "right_hint": skew_join2},
39
63
  {"how": "right", "on": "a"},
64
+ {"how": "right", "on": "a", "right_hint": skew_join3},
40
65
  {"how": "inner", "on": ["a", "b"]},
66
+ {"how": "inner", "on": ["a", "b"], "left_hint": skew_join4},
41
67
  ]
42
68
 
43
69
  for kw in parameters:
@@ -213,3 +239,100 @@ def test_concat():
213
239
  mdf2 = from_pandas(df2, chunk_size=3)
214
240
  r = concat([mdf1, mdf2], join="inner")
215
241
  assert r.shape == (20, 3)
242
+
243
+
244
+ def test_invalid_join_hint():
245
+ df1 = pd.DataFrame(
246
+ np.arange(20).reshape((4, 5)) + 1, columns=["a", "b", "c", "d", "e"]
247
+ )
248
+ df2 = pd.DataFrame(np.arange(20).reshape((5, 4)) + 1, columns=["a", "b", "x", "y"])
249
+
250
+ mdf1 = from_pandas(df1, chunk_size=2)
251
+ mdf2 = from_pandas(df2, chunk_size=3)
252
+
253
+ # type error
254
+ parameters = [
255
+ {"how": "left", "right_on": "x", "left_index": True, "left_hint": [1]},
256
+ {
257
+ "how": "left",
258
+ "right_on": "x",
259
+ "left_index": True,
260
+ "left_hint": {"key": "value"},
261
+ },
262
+ {
263
+ "how": "right",
264
+ "left_on": "a",
265
+ "right_index": True,
266
+ "right_hint": SkewJoinHint(columns=2),
267
+ },
268
+ {
269
+ "how": "left",
270
+ "left_on": "a",
271
+ "right_on": "x",
272
+ "left_hint": SkewJoinHint(columns="a"),
273
+ },
274
+ {
275
+ "how": "right",
276
+ "left_on": "a",
277
+ "right_index": True,
278
+ "right_hint": SkewJoinHint(columns=["0", []]),
279
+ },
280
+ ]
281
+
282
+ for kw in parameters:
283
+ print(kw)
284
+ with pytest.raises(TypeError):
285
+ mdf1.merge(mdf2, **kw)
286
+
287
+ # value error
288
+ parameters = [
289
+ # mapjoin can't working with skew join
290
+ {
291
+ "how": "left",
292
+ "right_on": "x",
293
+ "left_index": True,
294
+ "left_hint": MapJoinHint(),
295
+ "right_hint": SkewJoinHint(),
296
+ },
297
+ # right join can't apply to skew join left frame
298
+ {
299
+ "how": "right",
300
+ "left_on": "a",
301
+ "right_index": True,
302
+ "left_hint": SkewJoinHint(),
303
+ },
304
+ # invalid columns
305
+ {
306
+ "how": "left",
307
+ "left_on": "a",
308
+ "right_on": "x",
309
+ "left_hint": SkewJoinHint(columns=["b"]),
310
+ },
311
+ # invalid index level
312
+ {
313
+ "how": "right",
314
+ "left_on": "a",
315
+ "right_index": True,
316
+ "right_hint": SkewJoinHint(columns=[5]),
317
+ },
318
+ # unmatched skew join columns
319
+ {
320
+ "how": "right",
321
+ "left_on": "a",
322
+ "right_index": True,
323
+ "right_hint": SkewJoinHint(columns=[{0: "value1"}, {1: "value2"}]),
324
+ },
325
+ # invalid dist_mapjoin shard_count
326
+ {"how": "right", "on": "a", "right_hint": DistributedMapJoinHint()},
327
+ # all can't work with outer join
328
+ {"how": "outer", "on": ["a", "b"], "left_hint": MapJoinHint()},
329
+ {
330
+ "how": "outer",
331
+ "on": ["a", "b"],
332
+ "left_hint": DistributedMapJoinHint(shard_count=5),
333
+ },
334
+ {"how": "outer", "on": ["a", "b"], "left_hint": SkewJoinHint()},
335
+ ]
336
+ for kw in parameters:
337
+ with pytest.raises(ValueError):
338
+ mdf1.merge(mdf2, **kw)