maxframe 0.1.0b4__cp39-cp39-win32.whl → 0.1.0b5__cp39-cp39-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/__init__.py +1 -0
- maxframe/_utils.cp39-win32.pyd +0 -0
- maxframe/codegen.py +46 -1
- maxframe/config/config.py +11 -1
- maxframe/core/graph/core.cp39-win32.pyd +0 -0
- maxframe/dataframe/__init__.py +1 -0
- maxframe/dataframe/core.py +30 -8
- maxframe/dataframe/datasource/read_odps_query.py +3 -1
- maxframe/dataframe/datasource/read_odps_table.py +3 -1
- maxframe/dataframe/misc/__init__.py +4 -0
- maxframe/dataframe/misc/apply.py +1 -1
- maxframe/dataframe/misc/case_when.py +141 -0
- maxframe/dataframe/misc/pivot_table.py +262 -0
- maxframe/dataframe/misc/tests/test_misc.py +61 -0
- maxframe/dataframe/plotting/core.py +2 -2
- maxframe/dataframe/reduction/core.py +2 -1
- maxframe/dataframe/utils.py +7 -0
- maxframe/learn/contrib/utils.py +52 -0
- maxframe/learn/contrib/xgboost/__init__.py +26 -0
- maxframe/learn/contrib/xgboost/classifier.py +86 -0
- maxframe/learn/contrib/xgboost/core.py +156 -0
- maxframe/learn/contrib/xgboost/dmatrix.py +150 -0
- maxframe/learn/contrib/xgboost/predict.py +138 -0
- maxframe/learn/contrib/xgboost/regressor.py +78 -0
- maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
- maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
- maxframe/learn/contrib/xgboost/train.py +121 -0
- maxframe/learn/utils/__init__.py +15 -0
- maxframe/learn/utils/core.py +29 -0
- maxframe/lib/mmh3.cp39-win32.pyd +0 -0
- maxframe/odpsio/arrow.py +2 -3
- maxframe/odpsio/tableio.py +22 -0
- maxframe/odpsio/tests/test_schema.py +16 -11
- maxframe/opcodes.py +3 -0
- maxframe/serialization/core.cp39-win32.pyd +0 -0
- maxframe/serialization/core.pyi +61 -0
- maxframe/session.py +28 -0
- maxframe/tensor/__init__.py +1 -1
- maxframe/tensor/base/__init__.py +2 -0
- maxframe/tensor/base/atleast_1d.py +74 -0
- maxframe/tensor/base/unique.py +205 -0
- maxframe/tensor/datasource/array.py +4 -2
- maxframe/tensor/datasource/scalar.py +1 -1
- maxframe/udf.py +63 -3
- maxframe/utils.py +6 -0
- {maxframe-0.1.0b4.dist-info → maxframe-0.1.0b5.dist-info}/METADATA +2 -2
- {maxframe-0.1.0b4.dist-info → maxframe-0.1.0b5.dist-info}/RECORD +53 -36
- maxframe_client/fetcher.py +65 -3
- maxframe_client/session/odps.py +30 -1
- maxframe_client/session/task.py +26 -53
- maxframe_client/tests/test_session.py +28 -1
- {maxframe-0.1.0b4.dist-info → maxframe-0.1.0b5.dist-info}/WHEEL +0 -0
- {maxframe-0.1.0b4.dist-info → maxframe-0.1.0b5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pandas as pd
|
|
17
|
+
from pandas.api.types import is_list_like
|
|
18
|
+
|
|
19
|
+
from ... import opcodes
|
|
20
|
+
from ...core import OutputType
|
|
21
|
+
from ...serialization.serializables import AnyField, BoolField, StringField
|
|
22
|
+
from ...utils import no_default
|
|
23
|
+
from ..operators import DataFrameOperator, DataFrameOperatorMixin
|
|
24
|
+
from ..utils import build_df, parse_index
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DataFramePivotTable(DataFrameOperator, DataFrameOperatorMixin):
    """
    Operator backing ``pivot_table`` calls that are not lowered to a plain
    groupby aggregation, i.e. when ``columns`` is given or ``margins=True``.

    ``values``, ``index`` and ``columns`` are stored exactly as the caller
    passed them (a single column label or a list-like of labels). They are
    only normalised to label lists locally in ``__call__`` for metadata
    inference.
    """

    _op_type_ = opcodes.PIVOT_TABLE

    values = AnyField("values", default=None)
    index = AnyField("index", default=None)
    columns = AnyField("columns", default=None)
    aggfunc = AnyField("aggfunc", default="mean")
    fill_value = AnyField("fill_value", default=None)
    margins = BoolField("margins", default=False)
    dropna = BoolField("dropna", default=True)
    margins_name = StringField("margins_name", default=None)
    sort = BoolField("sort", default=False)

    def __init__(self, **kw):
        super().__init__(**kw)
        self.output_types = [OutputType.dataframe]

    @staticmethod
    def _to_label_list(df, labels):
        """
        Normalise a key given to ``pivot_table`` into a list of column labels.

        A single label that is a column of ``df`` becomes a one-element list.
        Any other list-like is converted with ``list``, a remaining scalar is
        wrapped into a list, and ``None`` is returned unchanged.
        """
        if labels is None:
            return None
        try:
            if labels in df.dtypes.index:
                return [labels]
        except TypeError:
            # unhashable list-likes (e.g. lists) cannot be a single label
            pass
        return list(labels) if is_list_like(labels) else [labels]

    def __call__(self, df):
        """
        Create the output DataFrame tileable of this pivot over ``df``.

        The output shape is always unknown. Index metadata is derived from the
        ``index`` columns. When ``columns`` is not given, index, column and
        dtype metadata are taken from the equivalent groupby aggregation.
        """
        # Work on label lists so single labels of any type (e.g. integer or
        # multi-character string column names) are handled uniformly.
        index_list = self._to_label_list(df, self.index)
        values_list = self._to_label_list(df, self.values)

        index_value = columns_value = dtypes = None
        if index_list is not None:
            if len(index_list) == 1:
                index_data = pd.Index(
                    [], dtype=df.dtypes[index_list[0]], name=index_list[0]
                )
            else:
                index_data = pd.MultiIndex.from_frame(build_df(df[index_list]))
            index_value = parse_index(index_data)

        if self.columns is None:
            # Without column keys the output layout equals that of a groupby
            # aggregation, so all metadata can be inferred from it.
            groupby_obj = df.groupby(index_list)
            if values_list:
                # select with a list so the aggregation always yields a DataFrame
                groupby_obj = groupby_obj[values_list]
            aggregated_df = groupby_obj.agg(self.aggfunc)
            index_value = aggregated_df.index_value
            columns_value = aggregated_df.columns_value
            dtypes = aggregated_df.dtypes

        return self.new_dataframe(
            [df],
            shape=(np.nan, np.nan),
            dtypes=dtypes,
            columns_value=columns_value,
            index_value=index_value,
        )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def pivot_table(
    data,
    values=None,
    index=None,
    columns=None,
    aggfunc="mean",
    fill_value=None,
    margins=False,
    dropna=True,
    margins_name="All",
    sort=True,
):
    """
    Create a spreadsheet-style pivot table as a DataFrame.

    The levels in the pivot table will be stored in MultiIndex objects
    (hierarchical indexes) on the index and columns of the result DataFrame.

    Parameters
    ----------
    values : column label or list of column labels, optional
        Column(s) to aggregate.
    index : column label or list of column labels
        Keys to group by on the pivot table index. Every label must be a
        column of ``data``.
    columns : column label or list of column labels
        Keys to group by on the pivot table columns. Every label must be a
        column of ``data``.
    aggfunc : function, list of functions, dict, default "mean"
        If list of functions passed, the resulting pivot table will have
        hierarchical columns whose top level are the function names
        (inferred from the function objects themselves).
        If dict is passed, the key is column to aggregate and value
        is function or list of functions.
    fill_value : scalar, default None
        Value to replace missing values with (in the resulting pivot table,
        after aggregation).
    margins : bool, default False
        Add all row / columns (e.g. for subtotal / grand totals).
    dropna : bool, default True
        Do not include columns whose entries are all NaN.
    margins_name : str, default 'All'
        Name of the row / column that will contain the totals
        when margins is True.
    sort : bool, default True
        Specifies if the result should be sorted.

    Returns
    -------
    DataFrame
        An Excel style pivot table.

    Raises
    ------
    ValueError
        If neither ``index`` nor ``columns`` is given, if ``values``,
        ``index`` or ``columns`` is neither a column label nor a list-like
        of column labels, or if one of them refers to a column that does
        not exist in ``data``.

    See Also
    --------
    DataFrame.pivot : Pivot without aggregation that can handle
        non-numeric data.
    DataFrame.melt : Unpivot a DataFrame from wide to long format,
        optionally leaving identifiers set.
    wide_to_long : Wide panel to long format. Less flexible but more
        user-friendly than melt.

    Examples
    --------
    >>> import numpy as np
    >>> import maxframe.dataframe as md
    >>> df = md.DataFrame({"A": ["foo", "foo", "foo", "foo", "foo",
    ...                          "bar", "bar", "bar", "bar"],
    ...                    "B": ["one", "one", "one", "two", "two",
    ...                          "one", "one", "two", "two"],
    ...                    "C": ["small", "large", "large", "small",
    ...                          "small", "large", "small", "small",
    ...                          "large"],
    ...                    "D": [1, 2, 2, 3, 3, 4, 5, 6, 7],
    ...                    "E": [2, 4, 5, 5, 6, 6, 8, 9, 9]})
    >>> df.execute()
         A    B      C  D  E
    0  foo  one  small  1  2
    1  foo  one  large  2  4
    2  foo  one  large  2  5
    3  foo  two  small  3  5
    4  foo  two  small  3  6
    5  bar  one  large  4  6
    6  bar  one  small  5  8
    7  bar  two  small  6  9
    8  bar  two  large  7  9

    This first example aggregates values by taking the sum.

    >>> table = md.pivot_table(df, values='D', index=['A', 'B'],
    ...                        columns=['C'], aggfunc=np.sum)
    >>> table.execute()
    C        large  small
    A   B
    bar one    4.0    5.0
        two    7.0    6.0
    foo one    4.0    1.0
        two    NaN    6.0

    We can also fill missing values using the `fill_value` parameter.

    >>> table = md.pivot_table(df, values='D', index=['A', 'B'],
    ...                        columns=['C'], aggfunc=np.sum, fill_value=0)
    >>> table.execute()
    C        large  small
    A   B
    bar one      4      5
        two      7      6
    foo one      4      1
        two      0      6

    The next example aggregates by taking the mean across multiple columns.

    >>> table = md.pivot_table(df, values=['D', 'E'], index=['A', 'C'],
    ...                        aggfunc={'D': np.mean,
    ...                                 'E': np.mean})
    >>> table.execute()
                      D         E
    A   C
    bar large  5.500000  7.500000
        small  5.500000  8.500000
    foo large  2.000000  4.500000
        small  2.333333  4.333333

    We can also calculate multiple types of aggregations for any given
    value column.

    >>> table = md.pivot_table(df, values=['D', 'E'], index=['A', 'C'],
    ...                        aggfunc={'D': np.mean,
    ...                                 'E': [min, max, np.mean]})
    >>> table.execute()
                      D    E
                   mean  max      mean  min
    A   C
    bar large  5.500000  9.0  7.500000  6.0
        small  5.500000  9.0  8.500000  8.0
    foo large  2.000000  5.0  4.500000  4.0
        small  2.333333  6.0  4.333333  2.0
    """
    if index is None and columns is None:
        raise ValueError(
            "No group keys passed, need to specify at least one of index or columns"
        )

    all_cols = data.dtypes.index

    def make_col_list(col):
        # Wrap a single existing column label into a list. Anything else is
        # returned unchanged and validated below.
        try:
            if col in all_cols:
                return [col]
        except TypeError:
            # unhashable input (e.g. a list) cannot be a single label
            pass
        return col

    values_list = make_col_list(values)
    index_list = make_col_list(index)
    columns_list = make_col_list(columns)

    # private sentinel: ``None`` could itself be a (missing) column label
    missing = object()
    name_to_attr = {"values": values_list, "index": index_list, "columns": columns_list}
    for key, val in name_to_attr.items():
        if val is None:
            continue
        if not is_list_like(val):
            raise ValueError(f"Need to specify {key} as a list-like object.")
        non_exist_key = next((c for c in val if c not in all_cols), missing)
        if non_exist_key is not missing:
            raise ValueError(
                f"Column {non_exist_key} specified in {key} is not a valid column."
            )

    if columns is None and not margins:
        # Without column keys or margins the pivot table is a plain groupby
        # aggregation. Keys are materialised as lists so that tuples or arrays
        # of labels neither break ``+`` nor get treated as one groupby key.
        # NOTE(review): ``dropna`` is not applied on this path — confirm this
        # matches the intended pandas semantics.
        index_list = list(index_list)
        selected = list(values_list) if values_list is not None else []
        if selected:
            data = data[index_list + selected]
        result = data.groupby(index_list, sort=sort).agg(aggfunc)
        if fill_value is not None:
            # fill missing aggregated values, as documented for ``fill_value``
            result = result.fillna(fill_value)
        return result

    op = DataFramePivotTable(
        values=values,
        index=index,
        columns=columns,
        aggfunc=aggfunc,
        fill_value=fill_value,
        margins=margins,
        dropna=dropna,
        margins_name=margins_name,
        sort=sort,
    )
    return op(data)
|
|
@@ -21,6 +21,7 @@ from ....core import OutputType
|
|
|
21
21
|
from ....tensor.core import TENSOR_TYPE
|
|
22
22
|
from ... import eval as maxframe_eval
|
|
23
23
|
from ... import get_dummies, to_numeric
|
|
24
|
+
from ...arithmetic import DataFrameGreater, DataFrameLess
|
|
24
25
|
from ...core import CATEGORICAL_TYPE, DATAFRAME_TYPE, INDEX_TYPE, SERIES_TYPE
|
|
25
26
|
from ...datasource.dataframe import from_pandas as from_pandas_df
|
|
26
27
|
from ...datasource.index import from_pandas as from_pandas_index
|
|
@@ -405,3 +406,63 @@ def test_to_numeric():
|
|
|
405
406
|
|
|
406
407
|
with pytest.raises(ValueError):
|
|
407
408
|
_ = to_numeric([])
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def test_case_when():
    """Check ``Series.case_when`` argument validation and graph inputs."""
    rng = np.random.RandomState(0)
    col_names = [f"c{i}" for i in range(1, 9)]
    pdf = pd.DataFrame(rng.randint(1000, size=(20, 8)), columns=col_names)
    mdf = from_pandas_df(pdf, chunk_size=8)

    # malformed caselists must be rejected; lambdas defer evaluation into
    # the ``pytest.raises`` context
    bad_calls = [
        (TypeError, lambda: mdf.c1.case_when(mdf.c2)),
        (ValueError, lambda: mdf.c1.case_when([])),
        (TypeError, lambda: mdf.c1.case_when([[]])),
        (ValueError, lambda: mdf.c1.case_when([()])),
    ]
    for expected_exc, call in bad_calls:
        with pytest.raises(expected_exc):
            call()

    result = mdf.c1.case_when([(mdf.c2 < 10, 10), (mdf.c2 > 20, mdf.c3)])
    inputs = result.inputs
    assert len(inputs) == 4
    assert isinstance(inputs[1].op, DataFrameLess)
    assert isinstance(inputs[2].op, DataFrameGreater)
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def test_pivot_table():
    """Check ``pivot_table`` validation and which operator it produces."""
    from ...groupby.aggregation import DataFrameGroupByAgg
    from ...misc.pivot_table import DataFramePivotTable

    pdf = pd.DataFrame(
        {
            "A": "foo foo foo foo foo bar bar bar bar".split(),
            "B": "one one one two two one one two two".split(),
            "C": "small large large small small large small small large".split(),
            "D": [1, 2, 2, 3, 3, 4, 5, 6, 7],
            "E": [2, 4, 5, 5, 6, 6, 8, 9, 9],
        }
    )
    mdf = from_pandas_df(pdf, chunk_size=8)

    # non-list-like key, missing column, and no group keys at all
    for bad_kwargs in (
        {"index": 123},
        {"index": ["F"]},
        {"values": ["D", "E"], "aggfunc": "sum"},
    ):
        with pytest.raises(ValueError):
            mdf.pivot_table(**bad_kwargs)

    # no ``columns`` and no margins: lowered to a groupby aggregation
    for kwargs in (
        {"index": "A"},
        {"index": "A", "values": ["D", "E"], "aggfunc": "sum"},
    ):
        assert isinstance(mdf.pivot_table(**kwargs).op, DataFrameGroupByAgg)

    # margins requested: dedicated pivot operator
    pivoted = mdf.pivot_table(
        index=["A", "B"], values=["D", "E"], aggfunc="sum", margins=True
    )
    assert isinstance(pivoted.op, DataFramePivotTable)

    # explicit ``columns``: pivot operator with unknown output shape
    for kwargs in (
        {"index": "A", "columns": ["B", "C"]},
        {"index": ["A", "B"], "columns": "C"},
    ):
        pivoted = mdf.pivot_table(aggfunc="sum", **kwargs)
        assert isinstance(pivoted.op, DataFramePivotTable)
        assert pivoted.shape == (np.nan, np.nan)
|
|
@@ -17,7 +17,7 @@ from collections import OrderedDict
|
|
|
17
17
|
import pandas as pd
|
|
18
18
|
|
|
19
19
|
from ...core import ENTITY_TYPE, ExecutableTuple
|
|
20
|
-
from ...utils import adapt_docstring
|
|
20
|
+
from ...utils import adapt_docstring, get_item_if_scalar
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class PlotAccessor:
|
|
@@ -34,7 +34,7 @@ class PlotAccessor:
|
|
|
34
34
|
.fetch(session=session)
|
|
35
35
|
)
|
|
36
36
|
for p, v in zip(to_executes, executed):
|
|
37
|
-
result[p] = v
|
|
37
|
+
result[p] = get_item_if_scalar(v)
|
|
38
38
|
|
|
39
39
|
data = result.pop("__object__")
|
|
40
40
|
pd_kwargs = kwargs.copy()
|
|
@@ -30,7 +30,7 @@ from ...serialization.serializables import (
|
|
|
30
30
|
StringField,
|
|
31
31
|
)
|
|
32
32
|
from ...typing_ import TileableType
|
|
33
|
-
from ...utils import pd_release_version, tokenize
|
|
33
|
+
from ...utils import get_item_if_scalar, pd_release_version, tokenize
|
|
34
34
|
from ..operators import DATAFRAME_TYPE, DataFrameOperator, DataFrameOperatorMixin
|
|
35
35
|
from ..utils import (
|
|
36
36
|
build_df,
|
|
@@ -715,6 +715,7 @@ class ReductionCompiler:
|
|
|
715
715
|
keys_to_vars = {inp.key: local_key_to_var[inp.key] for inp in t.inputs}
|
|
716
716
|
|
|
717
717
|
def _interpret_var(v):
|
|
718
|
+
v = get_item_if_scalar(v)
|
|
718
719
|
# get representation for variables
|
|
719
720
|
if hasattr(v, "key"):
|
|
720
721
|
return keys_to_vars[v.key]
|
maxframe/dataframe/utils.py
CHANGED
|
@@ -1136,6 +1136,13 @@ def concat_on_columns(objs: List) -> Any:
|
|
|
1136
1136
|
return result
|
|
1137
1137
|
|
|
1138
1138
|
|
|
1139
|
+
def apply_if_callable(maybe_callable, obj, **kwargs):
    """
    Evaluate ``maybe_callable`` against ``obj`` if it is callable.

    Parameters
    ----------
    maybe_callable : object
        Either a callable invoked as ``maybe_callable(obj, **kwargs)``, or any
        other value, which is returned as-is.
    obj : object
        First positional argument passed to the callable.
    **kwargs
        Extra keyword arguments forwarded to the callable.

    Returns
    -------
    object
        The callable's result, or ``maybe_callable`` itself when not callable.
    """
    if not callable(maybe_callable):
        return maybe_callable
    return maybe_callable(obj, **kwargs)
|
|
1144
|
+
|
|
1145
|
+
|
|
1139
1146
|
def patch_sa_engine_execute():
|
|
1140
1147
|
"""
|
|
1141
1148
|
pandas did not resolve compatibility issue of sqlalchemy 2.0, the issue
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import sys
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def make_import_error_func(package_name):
    """
    Build a placeholder callable for an optional package that is unavailable.

    Parameters
    ----------
    package_name : str
        Name of the package, used in the error message.

    Returns
    -------
    callable
        A function that accepts any arguments and always raises
        ``ImportError``.
    """

    def _func(*_args, **_kwargs):  # pragma: no cover
        message = f"Cannot import {package_name}, please reinstall that package."
        raise ImportError(message)

    return _func
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def config_mod_getattr(mod_dict, globals_):
    """
    Install lazily-imported attributes on a module.

    Parameters
    ----------
    mod_dict : dict
        Maps an attribute name to a dotted path ``"<module>.<attr>"``. The
        module part may be relative (e.g. ``".classifier.XGBClassifier"``).
        It is resolved with ``globals_["__name__"]`` as the anchor package,
        which suits package ``__init__`` modules.
    globals_ : dict
        The ``globals()`` of the module being configured. ``__getattr__``,
        ``__dir__``, ``__all__`` and ``__warningregistry__`` are written
        into it.
    """

    def __getattr__(name):
        import importlib

        if name in mod_dict:
            mod_name, cls_name = mod_dict[name].rsplit(".", 1)
            mod = importlib.import_module(mod_name, globals_["__name__"])
            # cache in the module namespace so later lookups skip __getattr__
            cls = globals_[name] = getattr(mod, cls_name)
            return cls
        else:  # pragma: no cover
            raise AttributeError(name)

    if sys.version_info[:2] < (3, 7):
        # module-level __getattr__ only exists from Python 3.7 on, so resolve
        # every lazy attribute eagerly on older interpreters
        for _mod in mod_dict.keys():
            __getattr__(_mod)

    def __dir__():
        # Use a set: once __getattr__ has cached a lazy name into globals_,
        # it would otherwise be listed twice.
        public_names = {n for n in globals_ if not n.startswith("_")}
        return sorted(public_names | set(mod_dict))

    globals_.update(
        {
            "__getattr__": __getattr__,
            "__dir__": __dir__,
            "__all__": list(__dir__()),
            "__warningregistry__": dict(),
        }
    )
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ..utils import config_mod_getattr as _config_mod_getattr
|
|
16
|
+
from .dmatrix import DMatrix
|
|
17
|
+
from .predict import predict
|
|
18
|
+
from .train import train
|
|
19
|
+
|
|
20
|
+
# Expose the scikit-learn style estimators lazily: they are imported from
# their submodules only when first accessed on this package.
_config_mod_getattr(
    dict(
        XGBClassifier=".classifier.XGBClassifier",
        XGBRegressor=".regressor.XGBRegressor",
    ),
    globals(),
)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# Copyright 1999-2024 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
from ....tensor import argmax
|
|
18
|
+
from ..utils import make_import_error_func
|
|
19
|
+
from .core import XGBScikitLearnBase, xgboost
|
|
20
|
+
|
|
21
|
+
if xgboost:
    from xgboost.sklearn import XGBClassifierBase

    from .core import wrap_evaluation_matrices
    from .predict import predict
    from .train import train

    class XGBClassifier(XGBScikitLearnBase, XGBClassifierBase):
        """
        Implementation of the scikit-learn API for XGBoost classification.
        """

        def fit(
            self,
            X,
            y,
            sample_weight=None,
            base_margin=None,
            eval_set=None,
            sample_weight_eval_set=None,
            base_margin_eval_set=None,
            num_class=None,
        ):
            """
            Train the booster on ``X`` / ``y`` and return ``self``.

            More than two classes (``num_class``) selects the
            ``multi:softprob`` objective and sets ``num_class`` in the
            parameters. Otherwise ``binary:logistic`` is used. ``n_classes_``
            falls back to 1 when ``num_class`` is not given.
            """
            dtrain, evals = wrap_evaluation_matrices(
                None,
                X,
                y,
                sample_weight,
                base_margin,
                eval_set,
                sample_weight_eval_set,
                base_margin_eval_set,
            )
            xgb_params = self.get_xgb_params()
            self.n_classes_ = num_class or 1
            if self.n_classes_ > 2:
                xgb_params.update(
                    objective="multi:softprob", num_class=self.n_classes_
                )
            else:
                xgb_params["objective"] = "binary:logistic"

            self.evals_result_ = {}
            self._Booster = train(
                xgb_params,
                dtrain,
                num_boost_round=self.get_num_boosting_rounds(),
                evals=evals,
                evals_result=self.evals_result_,
                num_class=num_class,
            )
            return self

        def predict(self, data, **kw):
            """Return predicted class labels for ``data``."""
            proba = self.predict_proba(data, flag=True, **kw)
            if proba.ndim > 1:
                # one column per class: take the most probable one
                return argmax(proba, axis=1)
            # 1-D probabilities are thresholded at 0.5
            return (proba > 0.5).astype(np.int64)

        def predict_proba(self, data, ntree_limit=None, flag=False, **kw):
            """
            Return class probabilities for ``data``.

            Raises ``NotImplementedError`` if ``ntree_limit`` is given.
            """
            if ntree_limit is not None:
                raise NotImplementedError("ntree_limit is not currently supported")
            booster = self.get_booster()
            return predict(booster, data, flag=flag, **kw)

else:
    XGBClassifier = make_import_error_func("xgboost")
|