maxframe 1.0.0rc3__cp311-cp311-win_amd64.whl → 1.0.0rc4__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of maxframe might be problematic.

Files changed (57)
  1. maxframe/_utils.cp311-win_amd64.pyd +0 -0
  2. maxframe/codegen.py +1 -0
  3. maxframe/config/config.py +13 -1
  4. maxframe/conftest.py +43 -12
  5. maxframe/core/entity/executable.py +1 -1
  6. maxframe/core/graph/core.cp311-win_amd64.pyd +0 -0
  7. maxframe/dataframe/arithmetic/docstring.py +26 -2
  8. maxframe/dataframe/arithmetic/equal.py +4 -2
  9. maxframe/dataframe/arithmetic/greater.py +4 -2
  10. maxframe/dataframe/arithmetic/greater_equal.py +4 -2
  11. maxframe/dataframe/arithmetic/less.py +2 -2
  12. maxframe/dataframe/arithmetic/less_equal.py +4 -2
  13. maxframe/dataframe/arithmetic/not_equal.py +4 -2
  14. maxframe/dataframe/core.py +2 -0
  15. maxframe/dataframe/datasource/read_odps_query.py +66 -7
  16. maxframe/dataframe/datasource/read_odps_table.py +3 -1
  17. maxframe/dataframe/datasource/tests/test_datasource.py +35 -6
  18. maxframe/dataframe/datastore/to_odps.py +7 -0
  19. maxframe/dataframe/extensions/__init__.py +3 -0
  20. maxframe/dataframe/extensions/flatmap.py +326 -0
  21. maxframe/dataframe/extensions/tests/test_extensions.py +62 -1
  22. maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
  23. maxframe/dataframe/indexing/rename.py +11 -0
  24. maxframe/dataframe/initializer.py +11 -1
  25. maxframe/dataframe/misc/drop_duplicates.py +18 -1
  26. maxframe/dataframe/tests/test_initializer.py +33 -2
  27. maxframe/io/odpsio/schema.py +5 -3
  28. maxframe/io/odpsio/tableio.py +44 -38
  29. maxframe/io/odpsio/tests/test_schema.py +0 -4
  30. maxframe/io/odpsio/volumeio.py +9 -3
  31. maxframe/learn/contrib/__init__.py +2 -1
  32. maxframe/learn/contrib/graph/__init__.py +15 -0
  33. maxframe/learn/contrib/graph/connected_components.py +215 -0
  34. maxframe/learn/contrib/graph/tests/__init__.py +13 -0
  35. maxframe/learn/contrib/graph/tests/test_connected_components.py +53 -0
  36. maxframe/learn/contrib/xgboost/classifier.py +3 -3
  37. maxframe/learn/contrib/xgboost/predict.py +8 -39
  38. maxframe/learn/contrib/xgboost/train.py +4 -3
  39. maxframe/lib/mmh3.cp311-win_amd64.pyd +0 -0
  40. maxframe/opcodes.py +3 -0
  41. maxframe/protocol.py +6 -1
  42. maxframe/serialization/core.cp311-win_amd64.pyd +0 -0
  43. maxframe/session.py +9 -2
  44. maxframe/tensor/indexing/getitem.py +2 -0
  45. maxframe/tensor/merge/concatenate.py +23 -20
  46. maxframe/tensor/merge/vstack.py +5 -1
  47. maxframe/tensor/misc/transpose.py +1 -1
  48. maxframe/utils.py +34 -12
  49. {maxframe-1.0.0rc3.dist-info → maxframe-1.0.0rc4.dist-info}/METADATA +1 -1
  50. {maxframe-1.0.0rc3.dist-info → maxframe-1.0.0rc4.dist-info}/RECORD +57 -52
  51. {maxframe-1.0.0rc3.dist-info → maxframe-1.0.0rc4.dist-info}/WHEEL +1 -1
  52. maxframe_client/fetcher.py +10 -8
  53. maxframe_client/session/consts.py +3 -0
  54. maxframe_client/session/odps.py +84 -13
  55. maxframe_client/session/task.py +58 -20
  56. maxframe_client/tests/test_session.py +14 -2
  57. {maxframe-1.0.0rc3.dist-info → maxframe-1.0.0rc4.dist-info}/top_level.txt +0 -0
maxframe/io/odpsio/tableio.py

@@ -18,6 +18,7 @@ from abc import ABC, abstractmethod
  from contextlib import contextmanager
  from typing import Dict, List, Optional, Union

+ import numpy as np
  import pyarrow as pa
  from odps import ODPS
  from odps import __version__ as pyodps_version
@@ -26,7 +27,6 @@ from odps.apis.storage_api import (
      TableBatchScanResponse,
      TableBatchWriteResponse,
  )
- from odps.config import option_context as pyodps_option_context
  from odps.tunnel import TableTunnel
  from odps.types import OdpsSchema, PartitionSpec, timestamp_ntz

@@ -38,19 +38,13 @@ except ImportError:
  from ...config import options
  from ...env import ODPS_STORAGE_API_ENDPOINT
  from ...lib.version import Version
+ from ...utils import sync_pyodps_options
  from .schema import odps_schema_to_arrow_schema

  PartitionsType = Union[List[str], str, None]

  _DEFAULT_ROW_BATCH_SIZE = 4096
- _need_convert_timezone = Version(pyodps_version) < Version("0.11.7")
-
-
- @contextmanager
- def _sync_pyodps_timezone():
-     with pyodps_option_context() as cfg:
-         cfg.local_timezone = options.local_timezone
-         yield
+ _need_patch_batch = Version(pyodps_version) < Version("0.12.0")


  class ODPSTableIO(ABC):
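
The removed `_sync_pyodps_timezone` helper above is superseded by a shared `sync_pyodps_options` utility imported from `maxframe/utils.py` (also changed in this release). Its implementation is not shown in this diff; a minimal sketch, assuming it simply generalizes the removed timezone-only helper on top of the same PyODPS `option_context` API, could look like this:

from contextlib import contextmanager

from odps.config import option_context as pyodps_option_context

from maxframe.config import options


@contextmanager
def sync_pyodps_options():
    # Mirror MaxFrame options into PyODPS options for the duration of the block.
    # The real helper in maxframe/utils.py may sync more options than the timezone.
    with pyodps_option_context() as cfg:
        cfg.local_timezone = options.local_timezone
        yield
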
@@ -166,10 +160,15 @@ class TunnelMultiPartitionReader:
              self._cur_partition_id += 1

              part_str = self._partitions[self._cur_partition_id]
-             with _sync_pyodps_timezone():
+
+             # todo make this more formal when PyODPS 0.12.0 is released
+             req_columns = self._columns
+             if not _need_patch_batch:
+                 req_columns = self._schema.names
+             with sync_pyodps_options():
                  self._cur_reader = self._table.open_reader(
                      part_str,
-                     columns=self._columns,
+                     columns=req_columns,
                      arrow=True,
                      download_id=self._partition_to_download_ids.get(part_str),
                  )
@@ -180,7 +179,7 @@ class TunnelMultiPartitionReader:
              else:
                  count = min(self._count, self._cur_reader.count - start)

-             with _sync_pyodps_timezone():
+             with sync_pyodps_options():
                  self._reader_iter = self._cur_reader.read(start, count)
              break
          self._reader_start_pos += self._cur_reader.count
@@ -194,7 +193,7 @@ class TunnelMultiPartitionReader:
          arrays = []
          for idx in range(batch.num_columns):
              col = batch.column(idx)
-             if _need_convert_timezone and isinstance(col.type, pa.TimestampType):
+             if isinstance(col.type, pa.TimestampType):
                  if col.type.tz is not None:
                      target_type = pa.timestamp(
                          self._schema.types[idx].unit, col.type.tz
@@ -212,11 +211,12 @@ class TunnelMultiPartitionReader:
          for part_col in self._partition_cols or []:
              names.append(part_col)
              col_type = self._schema.field_by_name(part_col).type
-             arrays.append(pa.array([pt_spec[part_col]] * batch.num_rows).cast(col_type))
+             pt_col = np.repeat([pt_spec[part_col]], batch.num_rows)
+             arrays.append(pa.array(pt_col).cast(col_type))
          return pa.RecordBatch.from_arrays(arrays, names)

      def read(self):
-         with _sync_pyodps_timezone():
+         with sync_pyodps_options():
              if self._cur_reader is None:
                  self._open_next_reader()
                  if self._cur_reader is None:
@@ -227,7 +227,10 @@ class TunnelMultiPartitionReader:
                  if batch is not None:
                      if self._row_left is not None:
                          self._row_left -= batch.num_rows
-                     return self._fill_batch_partition(batch)
+                     if _need_patch_batch:
+                         return self._fill_batch_partition(batch)
+                     else:
+                         return batch
              except StopIteration:
                  self._open_next_reader()
          return None
@@ -285,7 +288,9 @@ class TunnelTableIO(ODPSTableIO):
          reverse_range: bool = False,
          row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
      ):
-         table = self._odps.get_table(full_table_name)
+         with sync_pyodps_options():
+             table = self._odps.get_table(full_table_name)
+
          if partition_columns is True:
              partition_columns = [c.name for c in table.table_schema.partitions]

@@ -296,21 +301,22 @@ class TunnelTableIO(ODPSTableIO):
              or (stop is not None and stop < 0)
              or (reverse_range and start is None)
          ):
-             table = self._odps.get_table(full_table_name)
-             tunnel = TableTunnel(self._odps)
-             parts = (
-                 [partitions]
-                 if partitions is None or isinstance(partitions, str)
-                 else partitions
-             )
-             part_to_down_id = dict()
-             total_records = 0
-             for part in parts:
-                 down_session = tunnel.create_download_session(
-                     table, async_mode=True, partition_spec=part
+             with sync_pyodps_options():
+                 table = self._odps.get_table(full_table_name)
+                 tunnel = TableTunnel(self._odps)
+                 parts = (
+                     [partitions]
+                     if partitions is None or isinstance(partitions, str)
+                     else partitions
                  )
-                 part_to_down_id[part] = down_session.id
-                 total_records += down_session.count
+                 part_to_down_id = dict()
+                 total_records = 0
+                 for part in parts:
+                     down_session = tunnel.create_download_session(
+                         table, async_mode=True, partition_spec=part
+                     )
+                     part_to_down_id[part] = down_session.id
+                     total_records += down_session.count

          count = None
          if start is not None or stop is not None:
@@ -347,7 +353,7 @@ class TunnelTableIO(ODPSTableIO):
          overwrite: bool = True,
      ):
          table = self._odps.get_table(full_table_name)
-         with _sync_pyodps_timezone():
+         with sync_pyodps_options():
              with table.open_writer(
                  partition=partition,
                  arrow=True,
@@ -357,7 +363,7 @@ class TunnelTableIO(ODPSTableIO):
              # fixme should yield writer directly once pyodps fixes
              # related arrow timestamp bug when provided schema and
              # table schema is identical.
-             if _need_convert_timezone:
+             if _need_patch_batch:
                  yield TunnelWrappedWriter(writer)
              else:
                  yield writer
@@ -596,8 +602,8 @@ class HaloTableIO(ODPSTableIO):
      ):
          from odps.apis.storage_api import (
              SessionRequest,
+             SessionStatus,
              SplitOptions,
-             Status,
              TableBatchScanRequest,
          )

@@ -628,13 +634,13 @@ class HaloTableIO(ODPSTableIO):
          resp = client.create_read_session(req)

          session_id = resp.session_id
-         status = resp.status
-         while status == Status.WAIT:
+         status = resp.session_status
+         while status == SessionStatus.INIT:
              resp = client.get_read_session(SessionRequest(session_id))
-             status = resp.status
+             status = resp.session_status
              time.sleep(1.0)

-         assert status == Status.OK
+         assert status == SessionStatus.NORMAL

          count = None
          if start is not None or stop is not None:
maxframe/io/odpsio/tests/test_schema.py

@@ -270,10 +270,6 @@ def test_odps_arrow_schema_conversion():

      with pytest.raises(TypeError):
          arrow_schema_to_odps_schema(pa.schema([("col1", pa.float16())]))
-     with pytest.raises(TypeError):
-         odps_schema_to_arrow_schema(
-             odps_types.OdpsSchema([odps_types.Column("col1", "json")])
-         )


  def test_build_column_name():
maxframe/io/odpsio/volumeio.py

@@ -13,7 +13,7 @@
  # limitations under the License.

  import inspect
- from typing import Iterator, List, Union
+ from typing import Iterator, List, Optional, Union

  from odps import ODPS

@@ -43,9 +43,15 @@ class ODPSVolumeReader:


  class ODPSVolumeWriter:
-     def __init__(self, odps_entry: ODPS, volume_name: str, volume_dir: str):
+     def __init__(
+         self,
+         odps_entry: ODPS,
+         volume_name: str,
+         volume_dir: str,
+         schema_name: Optional[str] = None,
+     ):
          self._odps_entry = odps_entry
-         self._volume = odps_entry.get_volume(volume_name)
+         self._volume = odps_entry.get_volume(volume_name, schema=schema_name)
          self._volume_dir = volume_dir

      def write_file(self, file_name: str, data: Union[bytes, Iterator[bytes]]):
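
The writer now threads an optional `schema_name` through to `odps_entry.get_volume`. A usage sketch follows; the project, volume, directory and schema names are placeholders, the credentials come from the standard PyODPS environment variables, and the import path is taken directly from the file shown above:

import os

from odps import ODPS

from maxframe.io.odpsio.volumeio import ODPSVolumeWriter

o = ODPS(
    os.getenv("ALIBABA_CLOUD_ACCESS_KEY_ID"),
    os.getenv("ALIBABA_CLOUD_ACCESS_KEY_SECRET"),
    project="my_project",
    endpoint="https://service.odps.example.com/api",
)
# schema_name is the new optional argument; omitting it keeps the old behavior
writer = ODPSVolumeWriter(o, "my_volume", "my_dir", schema_name="my_schema")
writer.write_file("part-0.bin", b"some bytes")
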
maxframe/learn/contrib/__init__.py

@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from . import pytorch
+ from . import graph, pytorch

  del pytorch
+ del graph
maxframe/learn/contrib/graph/__init__.py

@@ -0,0 +1,15 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .connected_components import connected_components
maxframe/learn/contrib/graph/connected_components.py

@@ -0,0 +1,215 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+ import pandas as pd
+
+ from maxframe import opcodes
+
+ from ....core import OutputType
+ from ....dataframe.operators import DataFrameOperator, DataFrameOperatorMixin
+ from ....dataframe.utils import make_dtypes, parse_index
+ from ....serialization.serializables import Int32Field, StringField
+
+
+ class DataFrameConnectedComponentsOperator(DataFrameOperator, DataFrameOperatorMixin):
+     _op_type_ = opcodes.CONNECTED_COMPONENTS
+
+     vertex_col1 = StringField("vertex_col1", default=None)
+     vertex_col2 = StringField("vertex_col2", default=None)
+     max_iter = Int32Field("max_iter", default=6)
+
+     def __call__(self, df):
+         node_id_dtype = df.dtypes[self.vertex_col1]
+         dtypes = make_dtypes({"id": node_id_dtype, "component": node_id_dtype})
+         # this will return a dataframe and a bool flag
+         new_dataframe_tileable_kw = {
+             "shape": (np.nan, 2),
+             "index_value": parse_index(pd.RangeIndex(0)),
+             "columns_value": parse_index(dtypes.index, store_data=True),
+             "dtypes": dtypes,
+         }
+         new_scalar_tileable_kw = {"dtype": np.dtype(np.bool_), "shape": ()}
+         return self.new_tileables(
+             [df],
+             kws=[new_dataframe_tileable_kw, new_scalar_tileable_kw],
+         )
+
+     @property
+     def output_limit(self):
+         return 2
+
+
+ def connected_components(
+     dataframe, vertex_col1: str, vertex_col2: str, max_iter: int = 6
+ ):
+     """
+     The connected components algorithm labels each node as belonging to a specific connected component,
+     identified by the ID of its lowest-numbered vertex.
+
+     Parameters
+     ----------
+     dataframe : DataFrame
+         A DataFrame containing the edges of the graph.
+
+     vertex_col1 : str
+         The name of the column in `dataframe` that contains one of the edge vertices. The column values must be
+         integers.
+
+     vertex_col2 : str
+         The name of the column in `dataframe` that contains the other edge vertex. The column values must be
+         integers.
+
+     max_iter : int
+         The algorithm uses large-star and small-star transformations to find all connected components; `max_iter`
+         controls the maximum number of iteration rounds used to find all edges. Default is 6.
+
+     Returns
+     -------
+     DataFrame
+         A DataFrame containing all connected component edges in two columns, `id` and `component`, where
+         `component` is the lowest-numbered vertex of the connected component.
+
+     Notes
+     -----
+     After `execute()`, the DataFrame has a bool member `flag` indicating whether `connected_components` has
+     already converged within `max_iter` rounds. `True` means the DataFrame already contains all edges of the
+     connected components. If `False`, you can run `connected_components` again to reach the converged state.
+
+     Examples
+     --------
+     >>> import numpy as np
+     >>> import maxframe.dataframe as md
+     >>> from maxframe.learn.contrib.graph import connected_components
+     >>> df = md.DataFrame({'x': [4, 1], 'y': [0, 4]})
+     >>> df.execute()
+        x  y
+     0  4  0
+     1  1  4
+
+     Get connected components with 1 round of iteration.
+
+     >>> components, converged = connected_components(df, "x", "y", 1)
+     >>> session.execute(components, converged)
+     >>> components
+        id  component
+     0   1          0
+     1   4          0
+
+     >>> converged
+     True
+
+     Sometimes a single iteration is not sufficient to propagate the connectivity of all edges.
+     By default, `connected_components` performs 6 iterations.
+     If you are unsure whether the connected components have converged, check the `flag` variable of
+     the output DataFrame after calling `execute()`.
+
+     >>> df = md.DataFrame({'x': [4, 1, 7, 5, 8, 11, 11], 'y': [0, 4, 4, 7, 7, 9, 13]})
+     >>> df.execute()
+         x   y
+     0   4   0
+     1   1   4
+     2   7   4
+     3   5   7
+     4   8   7
+     5  11   9
+     6  11  13
+
+     >>> components, converged = connected_components(df, "x", "y", 1)
+     >>> session.execute(components, converged)
+     >>> components
+         id  component
+     0    4          0
+     1    7          0
+     2    8          4
+     3   13          9
+     4    1          0
+     5    5          0
+     6   11          9
+
+     If `flag` is `True`, convergence has been achieved; here it is still `False`:
+
+     >>> converged
+     False
+
+     You can then decide whether to continue iterating or to start with a larger number of iterations
+     (though an excessively large value wastes computation).
+
+     >>> components, converged = connected_components(components, "id", "component", 1)
+     >>> session.execute(components, converged)
+     >>> components
+         id  component
+     0    4          0
+     1    7          0
+     2   13          9
+     3    1          0
+     4    5          0
+     5   11          9
+     6    8          0
+
+     >>> components, converged = connected_components(df, "x", "y")
+     >>> session.execute(components, converged)
+     >>> components
+         id  component
+     0    4          0
+     1    7          0
+     2   13          9
+     3    1          0
+     4    5          0
+     5   11          9
+     6    8          0
+     """
+
+     # Check if vertex columns are provided
+     if not vertex_col1 or not vertex_col2:
+         raise ValueError("Both vertex_col1 and vertex_col2 must be provided.")
+
+     # Check if max_iter is provided and within the valid range
+     if max_iter is None:
+         raise ValueError("max_iter must be provided.")
+     if not (1 <= max_iter <= 50):
+         raise ValueError("max_iter must be an integer between 1 and 50.")
+
+     # Verify that the vertex columns exist in the dataframe
+     missing_cols = [
+         col for col in (vertex_col1, vertex_col2) if col not in dataframe.dtypes
+     ]
+     if missing_cols:
+         raise ValueError(
+             f"The following required columns {missing_cols} are not in {list(dataframe.dtypes.index)}"
+         )
+
+     # Ensure that the vertex columns are of integer type
+     # TODO support string dtype
+     incorrect_dtypes = [
+         col
+         for col in (vertex_col1, vertex_col2)
+         if dataframe[col].dtype != np.dtype("int")
+     ]
+     if incorrect_dtypes:
+         dtypes_str = ", ".join(str(dataframe[col].dtype) for col in incorrect_dtypes)
+         raise ValueError(
+             f"Columns {incorrect_dtypes} should be of integer type, but found {dtypes_str}."
+         )
+
+     op = DataFrameConnectedComponentsOperator(
+         vertex_col1=vertex_col1,
+         vertex_col2=vertex_col2,
+         _output_types=[OutputType.dataframe, OutputType.scalar],
+         max_iter=max_iter,
+     )
+     return op(
+         dataframe,
+     )
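
Putting the docstring above together, the intended usage is to re-run `connected_components` on its own output until the convergence flag turns true. A minimal sketch, assuming an already-created MaxFrame `session` object as in the docstring examples and the Mars-style `fetch()` call to pull the scalar flag back to the client:

import maxframe.dataframe as md

from maxframe.learn.contrib.graph import connected_components

# `session` is assumed to be an existing MaxFrame session, as in the docstring examples
df = md.DataFrame({"x": [4, 1, 7, 5, 8, 11, 11], "y": [0, 4, 4, 7, 7, 9, 13]})

# one round at a time; keep feeding the result back in until converged
components, converged = connected_components(df, "x", "y", max_iter=1)
session.execute(components, converged)
while not converged.fetch():
    components, converged = connected_components(
        components, "id", "component", max_iter=1
    )
    session.execute(components, converged)
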
maxframe/learn/contrib/graph/tests/__init__.py

@@ -0,0 +1,13 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
maxframe/learn/contrib/graph/tests/test_connected_components.py

@@ -0,0 +1,53 @@
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+ import pytest
+
+ from ..... import dataframe as md
+ from .....dataframe.core import DataFrameData
+ from .....tensor.core import TensorData
+ from .. import connected_components
+
+
+ @pytest.fixture
+ def df1():
+     return md.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
+
+
+ @pytest.fixture
+ def df2():
+     return md.DataFrame(
+         [[1, "2"], [1, "2"]],
+         columns=["a", "b"],
+     )
+
+
+ def test_connected_components(df1, df2):
+     edges, flag = connected_components(df1, "a", "b")
+     assert edges.op.max_iter == 6
+     assert edges.shape == (np.nan, 2)
+     assert isinstance(edges.data, DataFrameData)
+     assert isinstance(flag.data, TensorData)
+     assert flag.shape == ()
+     assert "id" in edges.dtypes and "component" in edges.dtypes
+
+     with pytest.raises(ValueError):
+         connected_components(df1, "a", "x")
+
+     with pytest.raises(ValueError):
+         connected_components(df1, "a", "b", 0)
+
+     with pytest.raises(ValueError):
+         connected_components(df2, "a", "b")
maxframe/learn/contrib/xgboost/classifier.py

@@ -14,7 +14,8 @@

  import numpy as np

- from ....tensor import argmax, transpose, vstack
+ from ....tensor import argmax, transpose
+ from ....tensor.merge.vstack import _vstack
  from ..utils import make_import_error_func
  from .core import XGBScikitLearnBase, xgboost

@@ -89,7 +90,6 @@ else:
          if ntree_limit is not None:
              raise NotImplementedError("ntree_limit is not currently supported")
          prediction = predict(self.get_booster(), data, flag=flag, **kw)
-
          if len(prediction.shape) == 2 and prediction.shape[1] == self.n_classes_:
              # multi-class
              return prediction
@@ -103,7 +103,7 @@ else:
              # binary logistic function
              classone_probs = prediction
              classzero_probs = 1.0 - classone_probs
-             return transpose(vstack((classzero_probs, classone_probs)))
+             return transpose(_vstack((classzero_probs, classone_probs)))

      @property
      def classes_(self) -> np.ndarray:
maxframe/learn/contrib/xgboost/predict.py

@@ -14,20 +14,18 @@


  import numpy as np
- import pandas as pd

  from .... import opcodes
  from ....core.entity.output_types import OutputType
  from ....core.operator.base import Operator
  from ....core.operator.core import TileableOperatorMixin
- from ....dataframe.utils import parse_index
  from ....serialization.serializables import (
      BoolField,
      KeyField,
      ReferenceField,
      TupleField,
  )
- from ....tensor.core import TENSOR_TYPE, TensorOrder
+ from ....tensor.core import TensorOrder
  from .core import BoosterData
  from .dmatrix import check_data

@@ -65,35 +63,12 @@ class XGBPredict(Operator, TileableOperatorMixin):
          else:
              shape = (self.data.shape[0],)
          inputs = [self.data, self.model]
-         if self.output_types[0] == OutputType.tensor:
-             # tensor
-             return self.new_tileable(
-                 inputs,
-                 shape=shape,
-                 dtype=self.output_dtype,
-                 order=TensorOrder.C_ORDER,
-             )
-         elif self.output_types[0] == OutputType.dataframe:
-             # dataframe
-             dtypes = pd.DataFrame(
-                 np.random.rand(0, num_class), dtype=self.output_dtype
-             ).dtypes
-             return self.new_tileable(
-                 inputs,
-                 shape=shape,
-                 dtypes=dtypes,
-                 columns_value=parse_index(dtypes.index),
-                 index_value=self.data.index_value,
-             )
-         else:
-             # series
-             return self.new_tileable(
-                 inputs,
-                 shape=shape,
-                 index_value=self.data.index_value,
-                 name="predictions",
-                 dtype=self.output_dtype,
-             )
+         return self.new_tileable(
+             inputs,
+             shape=shape,
+             dtype=self.output_dtype,
+             order=TensorOrder.C_ORDER,
+         )


  def predict(
@@ -124,13 +99,7 @@ def predict(
      data = check_data(data)
      # TODO: check model datatype

-     num_class = getattr(model.op, "num_class", None)
-     if isinstance(data, TENSOR_TYPE):
-         output_types = [OutputType.tensor]
-     elif num_class is not None:
-         output_types = [OutputType.dataframe]
-     else:
-         output_types = [OutputType.series]
+     output_types = [OutputType.tensor]

      iteration_range = iteration_range or (0, 0)

maxframe/learn/contrib/xgboost/train.py

@@ -102,7 +102,7 @@ def train(params, dtrain, evals=None, evals_result=None, num_class=None, **kwarg
      Parameters
      ----------
      Parameters are the same as `xgboost.train`. Note that train is an eager-execution
-     API. The call will be blocked until training finished.
+     API if evals is passed, thus the call will be blocked until training finished.

      Returns
      -------
@@ -121,11 +121,12 @@ def train(params, dtrain, evals=None, evals_result=None, num_class=None, **kwarg
              processed_evals.append((eval_dmatrix, name))
          else:
              processed_evals.append((to_dmatrix(eval_dmatrix), name))
-     return XGBTrain(
+     data = XGBTrain(
          params=params,
          dtrain=dtrain,
          evals=processed_evals,
          evals_result=evals_result,
          num_class=num_class,
          **kwargs,
-     )(evals_result).execute(session=session, **run_kwargs)
+     )(evals_result)
+     return data.execute(session=session, **run_kwargs) if evals else data
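
With this change `train` only blocks when `evals` is passed; otherwise it returns a lazy Booster tileable that can be executed later. A rough sketch of the two calling patterns, assuming `train` and `to_dmatrix` are importable from `maxframe.learn.contrib.xgboost` as the diff suggests and that the data construction below is purely illustrative:

import maxframe.dataframe as md

from maxframe.learn.contrib.xgboost import train
from maxframe.learn.contrib.xgboost.dmatrix import to_dmatrix

df = md.DataFrame({"f0": [1.0, 2.0, 3.0, 4.0], "label": [0, 1, 0, 1]})
dtrain = to_dmatrix(df[["f0"]], label=df["label"])

# no evals: returns a lazy booster, nothing runs yet
booster = train({"objective": "binary:logistic"}, dtrain)

# with evals: train() calls execute() itself and blocks until training finishes
evals_result = {}
booster = train(
    {"objective": "binary:logistic"},
    dtrain,
    evals=[(dtrain, "train")],
    evals_result=evals_result,
)
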
Binary file