maxframe-1.0.0rc2-cp37-cp37m-win_amd64.whl → maxframe-1.0.0rc4-cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of maxframe might be problematic.

Files changed (133)
  1. maxframe/_utils.cp37-win_amd64.pyd +0 -0
  2. maxframe/codegen.py +4 -2
  3. maxframe/config/config.py +28 -9
  4. maxframe/config/validators.py +42 -12
  5. maxframe/conftest.py +56 -14
  6. maxframe/core/__init__.py +2 -13
  7. maxframe/core/entity/__init__.py +0 -4
  8. maxframe/core/entity/executable.py +1 -1
  9. maxframe/core/entity/objects.py +45 -2
  10. maxframe/core/entity/output_types.py +0 -3
  11. maxframe/core/entity/tests/test_objects.py +43 -0
  12. maxframe/core/entity/tileables.py +5 -78
  13. maxframe/core/graph/__init__.py +2 -2
  14. maxframe/core/graph/builder/__init__.py +0 -1
  15. maxframe/core/graph/builder/base.py +5 -4
  16. maxframe/core/graph/builder/tileable.py +4 -4
  17. maxframe/core/graph/builder/utils.py +4 -8
  18. maxframe/core/graph/core.cp37-win_amd64.pyd +0 -0
  19. maxframe/core/graph/entity.py +9 -33
  20. maxframe/core/operator/__init__.py +2 -9
  21. maxframe/core/operator/base.py +3 -5
  22. maxframe/core/operator/objects.py +0 -9
  23. maxframe/core/operator/utils.py +55 -0
  24. maxframe/dataframe/arithmetic/docstring.py +26 -2
  25. maxframe/dataframe/arithmetic/equal.py +4 -2
  26. maxframe/dataframe/arithmetic/greater.py +4 -2
  27. maxframe/dataframe/arithmetic/greater_equal.py +4 -2
  28. maxframe/dataframe/arithmetic/less.py +2 -2
  29. maxframe/dataframe/arithmetic/less_equal.py +4 -2
  30. maxframe/dataframe/arithmetic/not_equal.py +4 -2
  31. maxframe/dataframe/core.py +2 -0
  32. maxframe/dataframe/datasource/read_odps_query.py +67 -8
  33. maxframe/dataframe/datasource/read_odps_table.py +4 -2
  34. maxframe/dataframe/datasource/tests/test_datasource.py +35 -6
  35. maxframe/dataframe/datastore/to_odps.py +8 -1
  36. maxframe/dataframe/extensions/__init__.py +3 -0
  37. maxframe/dataframe/extensions/flatmap.py +326 -0
  38. maxframe/dataframe/extensions/tests/test_extensions.py +62 -1
  39. maxframe/dataframe/indexing/add_prefix_suffix.py +1 -1
  40. maxframe/dataframe/indexing/rename.py +11 -0
  41. maxframe/dataframe/initializer.py +11 -1
  42. maxframe/dataframe/misc/drop_duplicates.py +18 -1
  43. maxframe/dataframe/operators.py +1 -17
  44. maxframe/dataframe/reduction/core.py +2 -2
  45. maxframe/dataframe/tests/test_initializer.py +33 -2
  46. maxframe/io/objects/__init__.py +24 -0
  47. maxframe/io/objects/core.py +140 -0
  48. maxframe/io/objects/tensor.py +76 -0
  49. maxframe/io/objects/tests/__init__.py +13 -0
  50. maxframe/io/objects/tests/test_object_io.py +97 -0
  51. maxframe/{odpsio → io/odpsio}/__init__.py +2 -0
  52. maxframe/{odpsio → io/odpsio}/arrow.py +4 -4
  53. maxframe/{odpsio → io/odpsio}/schema.py +10 -8
  54. maxframe/{odpsio → io/odpsio}/tableio.py +50 -38
  55. maxframe/io/odpsio/tests/__init__.py +13 -0
  56. maxframe/{odpsio → io/odpsio}/tests/test_schema.py +3 -7
  57. maxframe/{odpsio → io/odpsio}/tests/test_tableio.py +3 -3
  58. maxframe/{odpsio → io/odpsio}/tests/test_volumeio.py +4 -6
  59. maxframe/io/odpsio/volumeio.py +63 -0
  60. maxframe/learn/contrib/__init__.py +2 -1
  61. maxframe/learn/contrib/graph/__init__.py +15 -0
  62. maxframe/learn/contrib/graph/connected_components.py +215 -0
  63. maxframe/learn/contrib/graph/tests/__init__.py +13 -0
  64. maxframe/learn/contrib/graph/tests/test_connected_components.py +53 -0
  65. maxframe/learn/contrib/xgboost/classifier.py +26 -2
  66. maxframe/learn/contrib/xgboost/core.py +87 -2
  67. maxframe/learn/contrib/xgboost/dmatrix.py +1 -4
  68. maxframe/learn/contrib/xgboost/predict.py +27 -44
  69. maxframe/learn/contrib/xgboost/regressor.py +3 -10
  70. maxframe/learn/contrib/xgboost/train.py +27 -16
  71. maxframe/{core/operator/fuse.py → learn/core.py} +7 -10
  72. maxframe/lib/mmh3.cp37-win_amd64.pyd +0 -0
  73. maxframe/opcodes.py +3 -0
  74. maxframe/protocol.py +7 -16
  75. maxframe/remote/core.py +4 -8
  76. maxframe/serialization/__init__.py +1 -0
  77. maxframe/serialization/core.cp37-win_amd64.pyd +0 -0
  78. maxframe/session.py +9 -2
  79. maxframe/tensor/__init__.py +10 -2
  80. maxframe/tensor/arithmetic/isclose.py +1 -0
  81. maxframe/tensor/arithmetic/tests/test_arithmetic.py +21 -17
  82. maxframe/tensor/core.py +5 -136
  83. maxframe/tensor/datasource/array.py +3 -0
  84. maxframe/tensor/datasource/full.py +1 -1
  85. maxframe/tensor/datasource/tests/test_datasource.py +1 -1
  86. maxframe/tensor/indexing/flatnonzero.py +1 -1
  87. maxframe/tensor/indexing/getitem.py +2 -0
  88. maxframe/tensor/merge/__init__.py +2 -0
  89. maxframe/tensor/merge/concatenate.py +101 -0
  90. maxframe/tensor/merge/tests/test_merge.py +30 -1
  91. maxframe/tensor/merge/vstack.py +74 -0
  92. maxframe/tensor/{base → misc}/__init__.py +2 -0
  93. maxframe/tensor/{base → misc}/atleast_1d.py +0 -2
  94. maxframe/tensor/misc/atleast_2d.py +70 -0
  95. maxframe/tensor/misc/atleast_3d.py +85 -0
  96. maxframe/tensor/misc/tests/__init__.py +13 -0
  97. maxframe/tensor/{base → misc}/transpose.py +22 -18
  98. maxframe/tensor/operators.py +1 -7
  99. maxframe/tensor/random/core.py +1 -1
  100. maxframe/tensor/reduction/count_nonzero.py +1 -0
  101. maxframe/tensor/reduction/mean.py +1 -0
  102. maxframe/tensor/reduction/nanmean.py +1 -0
  103. maxframe/tensor/reduction/nanvar.py +2 -0
  104. maxframe/tensor/reduction/tests/test_reduction.py +12 -1
  105. maxframe/tensor/reduction/var.py +2 -0
  106. maxframe/tensor/utils.py +2 -22
  107. maxframe/typing_.py +4 -1
  108. maxframe/udf.py +8 -9
  109. maxframe/utils.py +49 -73
  110. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc4.dist-info}/METADATA +2 -75
  111. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc4.dist-info}/RECORD +129 -114
  112. maxframe_client/fetcher.py +33 -50
  113. maxframe_client/session/consts.py +3 -0
  114. maxframe_client/session/graph.py +8 -2
  115. maxframe_client/session/odps.py +134 -27
  116. maxframe_client/session/task.py +58 -20
  117. maxframe_client/tests/test_fetcher.py +1 -1
  118. maxframe_client/tests/test_session.py +27 -3
  119. maxframe/core/entity/chunks.py +0 -68
  120. maxframe/core/entity/fuse.py +0 -73
  121. maxframe/core/graph/builder/chunk.py +0 -430
  122. maxframe/odpsio/volumeio.py +0 -95
  123. /maxframe/{odpsio → core/entity}/tests/__init__.py +0 -0
  124. /maxframe/{tensor/base/tests → io}/__init__.py +0 -0
  125. /maxframe/{odpsio → io/odpsio}/tests/test_arrow.py +0 -0
  126. /maxframe/tensor/{base → misc}/astype.py +0 -0
  127. /maxframe/tensor/{base → misc}/broadcast_to.py +0 -0
  128. /maxframe/tensor/{base → misc}/ravel.py +0 -0
  129. /maxframe/tensor/{base/tests/test_base.py → misc/tests/test_misc.py} +0 -0
  130. /maxframe/tensor/{base → misc}/unique.py +0 -0
  131. /maxframe/tensor/{base → misc}/where.py +0 -0
  132. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc4.dist-info}/WHEEL +0 -0
  133. {maxframe-1.0.0rc2.dist-info → maxframe-1.0.0rc4.dist-info}/top_level.txt +0 -0
maxframe/io/odpsio/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
maxframe/io/odpsio/tests/test_schema.py
@@ -18,9 +18,9 @@ import pyarrow as pa
 import pytest
 from odps import types as odps_types
 
-from ... import dataframe as md
-from ... import tensor as mt
-from ...core import OutputType
+from .... import dataframe as md
+from .... import tensor as mt
+from ....core import OutputType
 from ..schema import (
     arrow_schema_to_odps_schema,
     build_dataframe_table_meta,
@@ -270,10 +270,6 @@ def test_odps_arrow_schema_conversion():
 
     with pytest.raises(TypeError):
         arrow_schema_to_odps_schema(pa.schema([("col1", pa.float16())]))
-    with pytest.raises(TypeError):
-        odps_schema_to_arrow_schema(
-            odps_types.OdpsSchema([odps_types.Column("col1", "json")])
-        )
 
 
 def test_build_column_name():
maxframe/io/odpsio/tests/test_tableio.py
@@ -20,9 +20,9 @@ import pyarrow as pa
 import pytest
 from odps import ODPS
 
-from ...config import options
-from ...tests.utils import flaky, tn
-from ...utils import config_odps_default_options
+from ....config import options
+from ....tests.utils import flaky, tn
+from ....utils import config_odps_default_options
 from ..tableio import ODPSTableIO
 
 
maxframe/io/odpsio/tests/test_volumeio.py
@@ -15,7 +15,7 @@
 import pytest
 from odps import ODPS
 
-from ...tests.utils import tn
+from ....tests.utils import tn
 from ..volumeio import ODPSVolumeReader, ODPSVolumeWriter
 
 
@@ -69,19 +69,17 @@ def create_volume(request, oss_config):
         oss_config.oss_bucket.batch_delete_objects(keys)
 
 
-@pytest.mark.parametrize("create_volume", ["parted", "external"], indirect=True)
+@pytest.mark.parametrize("create_volume", ["external"], indirect=True)
 def test_read_write_volume(create_volume):
     test_vol_dir = "test_vol_dir"
 
     odps_entry = ODPS.from_environments()
 
     writer = ODPSVolumeWriter(odps_entry, create_volume, test_vol_dir)
-    write_session_id = writer.create_write_session()
 
     writer = ODPSVolumeWriter(odps_entry, create_volume, test_vol_dir)
-    writer.write_file("file1", b"content1", write_session_id)
-    writer.write_file("file2", b"content2", write_session_id)
-    writer.commit(["file1", "file2"], write_session_id)
+    writer.write_file("file1", b"content1")
+    writer.write_file("file2", b"content2")
 
     reader = ODPSVolumeReader(odps_entry, create_volume, test_vol_dir)
     assert reader.read_file("file1") == b"content1"
maxframe/io/odpsio/volumeio.py
@@ -0,0 +1,63 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Iterator, List, Optional, Union
+
+from odps import ODPS
+
+
+class ODPSVolumeReader:
+    def __init__(self, odps_entry: ODPS, volume_name: str, volume_dir: str):
+        self._odps_entry = odps_entry
+        self._volume = odps_entry.get_volume(volume_name)
+        self._volume_dir = volume_dir
+
+    def list_files(self) -> List[str]:
+        def _get_file_name(vol_file):
+            if hasattr(vol_file, "name"):
+                return vol_file.name
+            return vol_file.path.rsplit("/", 1)[-1]
+
+        return [
+            _get_file_name(f)
+            for f in self._odps_entry.list_volume_files(
+                f"/{self._volume.name}/{self._volume_dir}"
+            )
+        ]
+
+    def read_file(self, file_name: str) -> bytes:
+        with self._volume.open_reader(self._volume_dir + "/" + file_name) as reader:
+            return reader.read()
+
+
+class ODPSVolumeWriter:
+    def __init__(
+        self,
+        odps_entry: ODPS,
+        volume_name: str,
+        volume_dir: str,
+        schema_name: Optional[str] = None,
+    ):
+        self._odps_entry = odps_entry
+        self._volume = odps_entry.get_volume(volume_name, schema=schema_name)
+        self._volume_dir = volume_dir
+
+    def write_file(self, file_name: str, data: Union[bytes, Iterator[bytes]]):
+        with self._volume.open_writer(self._volume_dir + "/" + file_name) as writer:
+            if not inspect.isgenerator(data):
+                writer.write(data)
+            else:
+                for chunk in data:
+                    writer.write(chunk)
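
The rewritten volume IO drops the create_write_session/commit protocol that the old test above still used: write_file now uploads directly and accepts either a bytes object or a generator of byte chunks. A minimal usage sketch, assuming ODPS credentials are configured in the environment; the volume and directory names are placeholders:

    from odps import ODPS

    from maxframe.io.odpsio.volumeio import ODPSVolumeReader, ODPSVolumeWriter

    odps_entry = ODPS.from_environments()

    # "my_volume" and "model_dir" are hypothetical names used for illustration.
    writer = ODPSVolumeWriter(odps_entry, "my_volume", "model_dir")
    writer.write_file("weights.bin", b"\x00\x01\x02")  # plain bytes
    # generators are streamed chunk by chunk
    writer.write_file("log.txt", (s.encode() for s in ("line1\n", "line2\n")))

    reader = ODPSVolumeReader(odps_entry, "my_volume", "model_dir")
    assert "weights.bin" in reader.list_files()
    assert reader.read_file("weights.bin") == b"\x00\x01\x02"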
maxframe/learn/contrib/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import pytorch
+from . import graph, pytorch
 
 del pytorch
+del graph
maxframe/learn/contrib/graph/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .connected_components import connected_components
maxframe/learn/contrib/graph/connected_components.py
@@ -0,0 +1,215 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pandas as pd
+
+from maxframe import opcodes
+
+from ....core import OutputType
+from ....dataframe.operators import DataFrameOperator, DataFrameOperatorMixin
+from ....dataframe.utils import make_dtypes, parse_index
+from ....serialization.serializables import Int32Field, StringField
+
+
+class DataFrameConnectedComponentsOperator(DataFrameOperator, DataFrameOperatorMixin):
+    _op_type_ = opcodes.CONNECTED_COMPONENTS
+
+    vertex_col1 = StringField("vertex_col1", default=None)
+    vertex_col2 = StringField("vertex_col2", default=None)
+    max_iter = Int32Field("max_iter", default=6)
+
+    def __call__(self, df):
+        node_id_dtype = df.dtypes[self.vertex_col1]
+        dtypes = make_dtypes({"id": node_id_dtype, "component": node_id_dtype})
+        # this will return a dataframe and a bool flag
+        new_dataframe_tileable_kw = {
+            "shape": (np.nan, 2),
+            "index_value": parse_index(pd.RangeIndex(0)),
+            "columns_value": parse_index(dtypes.index, store_data=True),
+            "dtypes": dtypes,
+        }
+        new_scalar_tileable_kw = {"dtype": np.dtype(np.bool_), "shape": ()}
+        return self.new_tileables(
+            [df],
+            kws=[new_dataframe_tileable_kw, new_scalar_tileable_kw],
+        )
+
+    @property
+    def output_limit(self):
+        return 2
+
+
+def connected_components(
+    dataframe, vertex_col1: str, vertex_col2: str, max_iter: int = 6
+):
+    """
+    The connected components algorithm labels each node as belonging to a specific connected component with the ID of
+    its lowest-numbered vertex.
+
+    Parameters
+    ----------
+    dataframe : DataFrame
+        A DataFrame containing the edges of the graph.
+
+    vertex_col1 : str
+        The name of the column in `dataframe` that contains the one of edge vertices. The column value must be an
+        integer.
+
+    vertex_col2 : str
+        The name of the column in `dataframe` that contains the other one of edge vertices. The column value must be an
+        integer.
+
+    max_iter : int
+        The algorithm use large and small star transformation to find all connected components, `max_iter`
+        controls the max round of the iterations before finds all edges. Default is 6.
+
+
+    Returns
+    -------
+    DataFrame
+        Return dataFrame contains all connected component edges by two columns `id` and `component`. `component` is
+        the lowest-numbered vertex in the connected components.
+
+    Notes
+    -------
+    After `execute()`, the dataframe has a bool member `flag` to indicate if the `connected_components` already
+    converged in `max_iter` rounds. `True` means the dataframe already contains all edges of the connected components.
+    If `False` you can run `connected_components` more times to reach the converged state.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> import maxframe.dataframe as md
+    >>> import maxframe.learn.contrib.graph.connected_components
+    >>> df = md.DataFrame({'x': [4, 1], 'y': [0, 4]})
+    >>> df.execute()
+       x  y
+    0  4  1
+    1  0  4
+
+    Get connected components with 1 round iteration.
+
+    >>> components, converged = connected_components(df, "x", "y", 1)
+    >>> session.execute(components, converged)
+    >>> components
+       A  B
+    0  1  0
+    1  4  0
+
+    >>> converged
+    True
+
+    Sometimes, a single iteration may not be sufficient to propagate the connectivity of all edges.
+    By default, `connected_components` performs 6 iterations of calculations.
+    If you are unsure whether the connected components have converged, you can check the `flag` variable in
+    the output DataFrame after calling `execute()`.
+
+    >>> df = md.DataFrame({'x': [4, 1, 7, 5, 8, 11, 11], 'y': [0, 4, 4, 7, 7, 9, 13]})
+    >>> df.execute()
+        x   y
+    0   4   0
+    1   1   4
+    2   7   4
+    3   5   7
+    4   8   7
+    5  11   9
+    6  11  13
+
+    >>> components, converged = connected_components(df, "x", "y", 1)
+    >>> session.execute(components, converged)
+    >>> components
+       id  component
+    0   4          0
+    1   7          0
+    2   8          4
+    3  13          9
+    4   1          0
+    5   5          0
+    6  11          9
+
+    If `flag` is True, it means convergence has been achieved.
+
+    >>> converged
+    False
+
+    You can determine whether to continue iterating or to use a larger number of iterations
+    (but not too large, which would result in wasted computational overhead).
+
+    >>> components, converged = connected_components(components, "id", "component", 1)
+    >>> session.execute(components, converged)
+    >>> components
+       id  component
+    0   4          0
+    1   7          0
+    2  13          9
+    3   1          0
+    4   5          0
+    5  11          9
+    6   8          0
+
+    >>> components, converged = connected_components(df, "x", "y")
+    >>> session.execute(components, converged)
+    >>> components
+       id  component
+    0   4          0
+    1   7          0
+    2  13          9
+    3   1          0
+    4   5          0
+    5  11          9
+    6   8          0
+    """
+
+    # Check if vertex columns are provided
+    if not vertex_col1 or not vertex_col2:
+        raise ValueError("Both vertex_col1 and vertex_col2 must be provided.")
+
+    # Check if max_iter is provided and within the valid range
+    if max_iter is None:
+        raise ValueError("max_iter must be provided.")
+    if not (1 <= max_iter <= 50):
+        raise ValueError("max_iter must be an integer between 1 and 50.")
+
+    # Verify that the vertex columns exist in the dataframe
+    missing_cols = [
+        col for col in (vertex_col1, vertex_col2) if col not in dataframe.dtypes
+    ]
+    if missing_cols:
+        raise ValueError(
+            f"The following required columns {missing_cols} are not in {list(dataframe.dtypes.index)}"
+        )
+
+    # Ensure that the vertex columns are of integer type
+    # TODO support string dtype
+    incorrect_dtypes = [
+        col
+        for col in (vertex_col1, vertex_col2)
+        if dataframe[col].dtype != np.dtype("int")
+    ]
+    if incorrect_dtypes:
+        dtypes_str = ", ".join(str(dataframe[col].dtype) for col in incorrect_dtypes)
+        raise ValueError(
+            f"Columns {incorrect_dtypes} should be of integer type, but found {dtypes_str}."
+        )
+
+    op = DataFrameConnectedComponentsOperator(
+        vertex_col1=vertex_col1,
+        vertex_col2=vertex_col2,
+        _output_types=[OutputType.dataframe, OutputType.scalar],
+        max_iter=max_iter,
+    )
+    return op(
+        dataframe,
+    )
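
The docstring pins down the operator's semantics: every vertex ends up labeled with the lowest-numbered vertex of its connected component. For intuition, the same labeling can be reproduced locally with a small union-find over the edge list; this is a plain-Python illustration only, not the distributed large/small-star iteration the operator actually runs:

    import pandas as pd

    def cc_reference(df: pd.DataFrame, col1: str, col2: str) -> pd.DataFrame:
        # Union-find that keeps the smallest vertex id as the representative.
        parent = {}

        def find(x):
            parent.setdefault(x, x)
            while parent[x] != x:
                parent[x] = parent[parent[x]]  # path halving
                x = parent[x]
            return x

        def union(a, b):
            ra, rb = find(a), find(b)
            parent[max(ra, rb)] = min(ra, rb)

        for a, b in zip(df[col1], df[col2]):
            union(a, b)

        nodes = sorted(parent)
        # Note: unlike the operator's output above, this also lists the
        # representative vertices themselves (e.g. rows (0, 0) and (9, 9)).
        return pd.DataFrame({"id": nodes, "component": [find(n) for n in nodes]})

    edges = pd.DataFrame({"x": [4, 1, 7, 5, 8, 11, 11], "y": [0, 4, 4, 7, 7, 9, 13]})
    print(cc_reference(edges, "x", "y"))
    # component 0 covers {0, 1, 4, 5, 7, 8}; component 9 covers {9, 11, 13}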
maxframe/learn/contrib/graph/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
maxframe/learn/contrib/graph/tests/test_connected_components.py
@@ -0,0 +1,53 @@
+# Copyright 1999-2024 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+from ..... import dataframe as md
+from .....dataframe.core import DataFrameData
+from .....tensor.core import TensorData
+from .. import connected_components
+
+
+@pytest.fixture
+def df1():
+    return md.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
+
+
+@pytest.fixture
+def df2():
+    return md.DataFrame(
+        [[1, "2"], [1, "2"]],
+        columns=["a", "b"],
+    )
+
+
+def test_connected_components(df1, df2):
+    edges, flag = connected_components(df1, "a", "b")
+    assert edges.op.max_iter == 6
+    assert edges.shape == (np.nan, 2)
+    assert isinstance(edges.data, DataFrameData)
+    assert isinstance(flag.data, TensorData)
+    assert flag.shape == ()
+    assert "id" in edges.dtypes and "component" in edges.dtypes
+
+    with pytest.raises(ValueError):
+        connected_components(df1, "a", "x")
+
+    with pytest.raises(ValueError):
+        connected_components(df1, "a", "b", 0)
+
+    with pytest.raises(ValueError):
+        connected_components(df2, "a", "b")
maxframe/learn/contrib/xgboost/classifier.py
@@ -14,7 +14,8 @@
 
 import numpy as np
 
-from ....tensor import argmax
+from ....tensor import argmax, transpose
+from ....tensor.merge.vstack import _vstack
 from ..utils import make_import_error_func
 from .core import XGBScikitLearnBase, xgboost
 
@@ -42,7 +43,10 @@ else:
             sample_weight_eval_set=None,
             base_margin_eval_set=None,
             num_class=None,
+            **kw,
         ):
+            session = kw.pop("session", None)
+            run_kwargs = kw.pop("run_kwargs", dict())
             dtrain, evals = wrap_evaluation_matrices(
                 None,
                 X,
@@ -68,6 +72,8 @@ else:
                 evals=evals,
                 evals_result=self.evals_result_,
                 num_class=num_class,
+                session=session,
+                run_kwargs=run_kwargs,
             )
             self._Booster = result
             return self
@@ -83,4 +89,22 @@ else:
         def predict_proba(self, data, ntree_limit=None, flag=False, **kw):
             if ntree_limit is not None:
                 raise NotImplementedError("ntree_limit is not currently supported")
-            return predict(self.get_booster(), data, flag=flag, **kw)
+            prediction = predict(self.get_booster(), data, flag=flag, **kw)
+            if len(prediction.shape) == 2 and prediction.shape[1] == self.n_classes_:
+                # multi-class
+                return prediction
+            if (
+                len(prediction.shape) == 2
+                and self.n_classes_ == 2
+                and prediction.shape[1] >= self.n_classes_
+            ):
+                # multi-label
+                return prediction
+            # binary logistic function
+            classone_probs = prediction
+            classzero_probs = 1.0 - classone_probs
+            return transpose(_vstack((classzero_probs, classone_probs)))
+
+        @property
+        def classes_(self) -> np.ndarray:
+            return np.arange(self.n_classes_)
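
In the binary case the booster returns a single column of P(class = 1), so the transpose(_vstack(...)) expression above rebuilds the (n_samples, 2) probability matrix that scikit-learn's predict_proba contract expects. The equivalent reshaping in plain NumPy:

    import numpy as np

    # One predicted P(class=1) per sample, as the booster emits it.
    classone_probs = np.array([0.9, 0.2, 0.6])
    classzero_probs = 1.0 - classone_probs

    # vstack gives shape (2, n_samples); transposing yields (n_samples, 2)
    # with column 0 = P(class=0) and column 1 = P(class=1).
    proba = np.vstack((classzero_probs, classone_probs)).T
    print(proba)
    # [[0.1 0.9]
    #  [0.8 0.2]
    #  [0.4 0.6]]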
maxframe/learn/contrib/xgboost/core.py
@@ -12,15 +12,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 try:
     import xgboost
 except ImportError:
     xgboost = None
 
+from ...core import Model, ModelData
 from .dmatrix import DMatrix
 
+
+class BoosterData(ModelData):
+    __slots__ = ("_evals_result",)
+
+    _evals_result: Dict
+
+    def __init__(self, *args, evals_result=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._evals_result = evals_result if evals_result is not None else dict()
+
+    def execute(self, session=None, **kw):
+        # The evals_result should be fetched when BoosterData.execute() is called.
+        result = super().execute(session=session, **kw)
+        if self.op.has_evals_result and self.key == self.op.outputs[0].key:
+            self._evals_result.update(self.op.outputs[1].fetch(session=session))
+        return result
+
+    def predict(
+        self,
+        data,
+        output_margin=False,
+        pred_leaf=False,
+        pred_contribs=False,
+        approx_contribs=False,
+        pred_interactions=False,
+        validate_features=True,
+        training=False,
+        iteration_range=None,
+        strict_shape=False,
+    ):
+        from .predict import predict
+
+        return predict(
+            self,
+            data,
+            output_margin=output_margin,
+            pred_leaf=pred_leaf,
+            pred_contribs=pred_contribs,
+            approx_contribs=approx_contribs,
+            pred_interactions=pred_interactions,
+            validate_features=validate_features,
+            training=training,
+            iteration_range=iteration_range,
+            strict_shape=strict_shape,
+        )
+
+
+class Booster(Model):
+    pass
+
+
 if not xgboost:
     XGBScikitLearnBase = None
 else:
@@ -40,7 +92,9 @@ else:
             **kw,
         ):
             """
-            Fit the regressor.
+            Fit the regressor. Note that fit() is an eager-execution
+            API. The call will be blocked until training finished.
+
             Parameters
             ----------
             X : array_like
@@ -72,6 +126,37 @@ else:
             """
             raise NotImplementedError
 
+        def evals_result(self, **kw) -> Dict:
+            """Return the evaluation results.
+
+            If **eval_set** is passed to the :py:meth:`fit` function, you can call
+            ``evals_result()`` to get evaluation results for all passed **eval_sets**. When
+            **eval_metric** is also passed to the :py:meth:`fit` function, the
+            **evals_result** will contain the **eval_metrics** passed to the :py:meth:`fit`
+            function.
+
+            The returned evaluation result is a dictionary:
+
+            .. code-block:: python
+
+                {'validation_0': {'logloss': ['0.604835', '0.531479']},
+                 'validation_1': {'logloss': ['0.41965', '0.17686']}}
+
+            Note that evals_result() will be blocked until the train is finished.
+
+            Returns
+            -------
+            evals_result
+
+            """
+            result = super().evals_result()
+            if not self._Booster.op.has_evals_result or len(result) != 0:
+                return result
+            session = kw.pop("session", None)
+            run_kwargs = kw.pop("run_kwargs", dict())
+            self._Booster.execute(session=session, **run_kwargs)
+            return super().evals_result()
+
     def wrap_evaluation_matrices(
         missing: float,
         X: Any,
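
A sketch of how the new evals_result plumbing fits together, per the docstring above. The XGBClassifier import path and the toy dataset are assumptions for illustration:

    import maxframe.dataframe as md
    from maxframe.learn.contrib.xgboost import XGBClassifier  # import path assumed

    # Tiny illustrative dataset; real workloads would read MaxCompute tables.
    X = md.DataFrame({"f0": [0.1, 0.4, 0.3, 0.8], "f1": [1.0, 0.2, 0.5, 0.9]})
    y = md.Series([0, 1, 0, 1])

    clf = XGBClassifier()
    # fit() now threads session/run_kwargs through **kw and blocks until
    # training finishes.
    clf.fit(X, y, eval_set=[(X, y)])

    # evals_result() triggers Booster.execute() if results were not fetched
    # yet, then returns a dict such as:
    # {'validation_0': {'logloss': ['0.604835', '0.531479']}}
    print(clf.evals_result())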
maxframe/learn/contrib/xgboost/dmatrix.py
@@ -99,10 +99,7 @@ def check_array_like(y: TileableType, name: str) -> TileableType:
     y = convert_to_tensor_or_dataframe(y)
     if isinstance(y, DATAFRAME_TYPE):
         y = y.iloc[:, 0]
-    y = astensor(y)
-    if y.ndim != 1:
-        raise ValueError(f"Expecting 1-d {name}, got: {y.ndim}-d")
-    return y
+    return astensor(y)
 
 
 def to_dmatrix(