maxframe-0.1.0b4-cp311-cp311-win32.whl → maxframe-0.1.0b5-cp311-cp311-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic; see the package's registry page for more details.

Files changed (53):
  1. maxframe/__init__.py +1 -0
  2. maxframe/_utils.cp311-win32.pyd +0 -0
  3. maxframe/codegen.py +46 -1
  4. maxframe/config/config.py +11 -1
  5. maxframe/core/graph/core.cp311-win32.pyd +0 -0
  6. maxframe/dataframe/__init__.py +1 -0
  7. maxframe/dataframe/core.py +30 -8
  8. maxframe/dataframe/datasource/read_odps_query.py +3 -1
  9. maxframe/dataframe/datasource/read_odps_table.py +3 -1
  10. maxframe/dataframe/misc/__init__.py +4 -0
  11. maxframe/dataframe/misc/apply.py +1 -1
  12. maxframe/dataframe/misc/case_when.py +141 -0
  13. maxframe/dataframe/misc/pivot_table.py +262 -0
  14. maxframe/dataframe/misc/tests/test_misc.py +61 -0
  15. maxframe/dataframe/plotting/core.py +2 -2
  16. maxframe/dataframe/reduction/core.py +2 -1
  17. maxframe/dataframe/utils.py +7 -0
  18. maxframe/learn/contrib/utils.py +52 -0
  19. maxframe/learn/contrib/xgboost/__init__.py +26 -0
  20. maxframe/learn/contrib/xgboost/classifier.py +86 -0
  21. maxframe/learn/contrib/xgboost/core.py +156 -0
  22. maxframe/learn/contrib/xgboost/dmatrix.py +150 -0
  23. maxframe/learn/contrib/xgboost/predict.py +138 -0
  24. maxframe/learn/contrib/xgboost/regressor.py +78 -0
  25. maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
  26. maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
  27. maxframe/learn/contrib/xgboost/train.py +121 -0
  28. maxframe/learn/utils/__init__.py +15 -0
  29. maxframe/learn/utils/core.py +29 -0
  30. maxframe/lib/mmh3.cp311-win32.pyd +0 -0
  31. maxframe/odpsio/arrow.py +2 -3
  32. maxframe/odpsio/tableio.py +22 -0
  33. maxframe/odpsio/tests/test_schema.py +16 -11
  34. maxframe/opcodes.py +3 -0
  35. maxframe/serialization/core.cp311-win32.pyd +0 -0
  36. maxframe/serialization/core.pyi +61 -0
  37. maxframe/session.py +28 -0
  38. maxframe/tensor/__init__.py +1 -1
  39. maxframe/tensor/base/__init__.py +2 -0
  40. maxframe/tensor/base/atleast_1d.py +74 -0
  41. maxframe/tensor/base/unique.py +205 -0
  42. maxframe/tensor/datasource/array.py +4 -2
  43. maxframe/tensor/datasource/scalar.py +1 -1
  44. maxframe/udf.py +63 -3
  45. maxframe/utils.py +6 -0
  46. {maxframe-0.1.0b4.dist-info → maxframe-0.1.0b5.dist-info}/METADATA +2 -2
  47. {maxframe-0.1.0b4.dist-info → maxframe-0.1.0b5.dist-info}/RECORD +53 -36
  48. maxframe_client/fetcher.py +65 -3
  49. maxframe_client/session/odps.py +30 -1
  50. maxframe_client/session/task.py +26 -53
  51. maxframe_client/tests/test_session.py +28 -1
  52. {maxframe-0.1.0b4.dist-info → maxframe-0.1.0b5.dist-info}/WHEEL +0 -0
  53. {maxframe-0.1.0b4.dist-info → maxframe-0.1.0b5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,262 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+ from pandas.api.types import is_list_like
18
+
19
+ from ... import opcodes
20
+ from ...core import OutputType
21
+ from ...serialization.serializables import AnyField, BoolField, StringField
22
+ from ...utils import no_default
23
+ from ..operators import DataFrameOperator, DataFrameOperatorMixin
24
+ from ..utils import build_df, parse_index
25
+
26
+
27
class DataFramePivotTable(DataFrameOperator, DataFrameOperatorMixin):
    """
    Operator building a spreadsheet-style pivot table from a DataFrame.

    Its fields mirror the keyword arguments of ``pivot_table``. Only the
    index / column metadata that can be derived without execution is filled
    in; the output shape is always reported as unknown.
    """

    _op_type_ = opcodes.PIVOT_TABLE

    # column label(s) to aggregate, or None
    values = AnyField("values", default=None)
    # key(s) grouped into the result index
    index = AnyField("index", default=None)
    # key(s) spread into the result columns
    columns = AnyField("columns", default=None)
    aggfunc = AnyField("aggfunc", default="mean")
    # replacement for missing values in the result
    fill_value = AnyField("fill_value", default=None)
    margins = BoolField("margins", default=False)
    dropna = BoolField("dropna", default=True)
    margins_name = StringField("margins_name", default=None)
    sort = BoolField("sort", default=False)

    def __init__(self, **kw):
        super().__init__(**kw)
        self.output_types = [OutputType.dataframe]

    def __call__(self, df):
        """
        Create the output DataFrame tileable for input ``df``.

        When ``columns`` is None the output layout equals that of a
        groupby-aggregation, so index, columns and dtypes are inferred from
        one. Otherwise the output columns depend on data values and are left
        unknown.
        """

        def to_label_list(labels):
            # a scalar label denotes a single key, as in pandas; this keeps a
            # string such as "AB" from being treated as the labels "A", "B"
            if labels is None:
                return []
            if is_list_like(labels):
                return list(labels)
            return [labels]

        index = to_label_list(self.index)
        values = to_label_list(self.values)
        index_value = columns_value = dtypes = None

        if index:
            if len(index) == 1:
                index_data = pd.Index([], dtype=df.dtypes[index[0]], name=index[0])
            else:
                index_data = pd.MultiIndex.from_frame(build_df(df[index]))
            index_value = parse_index(index_data)

        if self.columns is None:  # output columns can be determined
            groupby_obj = df.groupby(index)
            if values:
                # select with a list so the aggregation yields a DataFrame
                # even for a single value column
                groupby_obj = groupby_obj[values]
            aggregated_df = groupby_obj.agg(self.aggfunc)
            index_value = aggregated_df.index_value
            columns_value = aggregated_df.columns_value
            dtypes = aggregated_df.dtypes

        return self.new_dataframe(
            [df],
            shape=(np.nan, np.nan),
            dtypes=dtypes,
            columns_value=columns_value,
            index_value=index_value,
        )
74
+
75
+
76
def pivot_table(
    data,
    values=None,
    index=None,
    columns=None,
    aggfunc="mean",
    fill_value=None,
    margins=False,
    dropna=True,
    margins_name="All",
    sort=True,
):
    """
    Create a spreadsheet-style pivot table as a DataFrame.

    The levels in the pivot table will be stored in MultiIndex objects
    (hierarchical indexes) on the index and columns of the result DataFrame.

    Parameters
    ----------
    data : DataFrame
        Input data.
    values : column label or list of column labels, optional
        Column(s) to aggregate.
    index : column label or list of column labels
        Keys to group by on the pivot table index. Every key must be an
        existing column of ``data``; arrays and Groupers are not supported.
    columns : column label or list of column labels
        Keys to group by on the pivot table column. Every key must be an
        existing column of ``data``; arrays and Groupers are not supported.
    aggfunc : function, str, list of functions, dict, default "mean"
        If list of functions passed, the resulting pivot table will have
        hierarchical columns whose top level are the function names
        (inferred from the function objects themselves).
        If dict is passed, the key is column to aggregate and value
        is function or list of functions.
    fill_value : scalar, default None
        Value to replace missing values with (in the resulting pivot table,
        after aggregation).
    margins : bool, default False
        Add all row / columns (e.g. for subtotal / grand totals).
    dropna : bool, default True
        Do not include columns whose entries are all NaN.
    margins_name : str, default 'All'
        Name of the row / column that will contain the totals
        when margins is True.
    sort : bool, default True
        Specifies if the result should be sorted.

    Returns
    -------
    DataFrame
        An Excel style pivot table.

    Raises
    ------
    ValueError
        If neither ``index`` nor ``columns`` is given, or if ``values``,
        ``index`` or ``columns`` is not a column label or a list of column
        labels of ``data``.

    Notes
    -----
    When ``columns`` is omitted and ``margins`` is False, the result is
    computed as a plain groupby-aggregation over ``index``.

    See Also
    --------
    DataFrame.pivot : Pivot without aggregation that can handle
        non-numeric data.
    DataFrame.melt: Unpivot a DataFrame from wide to long format,
        optionally leaving identifiers set.
    wide_to_long : Wide panel to long format. Less flexible but more
        user-friendly than melt.

    Examples
    --------
    >>> import numpy as np
    >>> import maxframe.dataframe as md
    >>> df = md.DataFrame({"A": ["foo", "foo", "foo", "foo", "foo",
    ...                          "bar", "bar", "bar", "bar"],
    ...                    "B": ["one", "one", "one", "two", "two",
    ...                          "one", "one", "two", "two"],
    ...                    "C": ["small", "large", "large", "small",
    ...                          "small", "large", "small", "small",
    ...                          "large"],
    ...                    "D": [1, 2, 2, 3, 3, 4, 5, 6, 7],
    ...                    "E": [2, 4, 5, 5, 6, 6, 8, 9, 9]})
    >>> df.execute()
         A    B      C  D  E
    0  foo  one  small  1  2
    1  foo  one  large  2  4
    2  foo  one  large  2  5
    3  foo  two  small  3  5
    4  foo  two  small  3  6
    5  bar  one  large  4  6
    6  bar  one  small  5  8
    7  bar  two  small  6  9
    8  bar  two  large  7  9

    This first example aggregates values by taking the sum.

    >>> table = md.pivot_table(df, values='D', index=['A', 'B'],
    ...                        columns=['C'], aggfunc=np.sum)
    >>> table.execute()
    C        large  small
    A   B
    bar one    4.0    5.0
        two    7.0    6.0
    foo one    4.0    1.0
        two    NaN    6.0

    We can also fill missing values using the `fill_value` parameter.

    >>> table = md.pivot_table(df, values='D', index=['A', 'B'],
    ...                        columns=['C'], aggfunc=np.sum, fill_value=0)
    >>> table.execute()
    C        large  small
    A   B
    bar one      4      5
        two      7      6
    foo one      4      1
        two      0      6

    The next example aggregates by taking the mean across multiple columns.

    >>> table = md.pivot_table(df, values=['D', 'E'], index=['A', 'C'],
    ...                        aggfunc={'D': np.mean,
    ...                                 'E': np.mean})
    >>> table.execute()
                      D         E
    A   C
    bar large  5.500000  7.500000
        small  5.500000  8.500000
    foo large  2.000000  4.500000
        small  2.333333  4.333333

    We can also calculate multiple types of aggregations for any given
    value column.

    >>> table = md.pivot_table(df, values=['D', 'E'], index=['A', 'C'],
    ...                        aggfunc={'D': np.mean,
    ...                                 'E': [min, max, np.mean]})
    >>> table.execute()
                      D    E
                   mean  max      mean  min
    A   C
    bar large  5.500000  9.0  7.500000  6.0
        small  5.500000  9.0  8.500000  8.0
    foo large  2.000000  5.0  4.500000  4.0
        small  2.333333  6.0  4.333333  2.0
    """
    if index is None and columns is None:
        raise ValueError(
            "No group keys passed, need to specify at least one of index or columns"
        )

    column_labels = data.dtypes.index

    def make_col_list(col):
        # wrap a single existing column label into a list; anything else
        # (None, list-likes, invalid labels) is returned untouched and is
        # validated below
        try:
            if col in column_labels:
                return [col]
        except TypeError:
            # unhashable objects such as lists cannot be column labels
            pass
        return col

    values_list = make_col_list(values)
    index_list = make_col_list(index)
    columns_list = make_col_list(columns)

    name_to_attr = {"values": values_list, "index": index_list, "columns": columns_list}
    for key, val in name_to_attr.items():
        if val is None:
            continue
        if not is_list_like(val):
            raise ValueError(f"Need to specify {key} as a list-like object.")
        for col in val:
            if col not in column_labels:
                raise ValueError(
                    f"Column {col} specified in {key} is not a valid column."
                )

    if columns is None and not margins:
        # without column keys and margins the pivot table is exactly a
        # groupby-aggregation over the index keys
        if values_list:
            data = data[list(index_list) + list(values_list)]
        result = data.groupby(index, sort=sort).agg(aggfunc)
        if fill_value is not None:
            # pandas replaces missing values after aggregation
            result = result.fillna(fill_value)
        return result

    op = DataFramePivotTable(
        # ``values`` is passed as given: pandas lays out columns differently
        # for a scalar ``values`` than for a list of them
        values=values,
        # scalar and single-element-list keys are equivalent in pandas;
        # normalized lists keep the operator from splitting e.g. "AB"
        index=index_list,
        columns=columns_list,
        aggfunc=aggfunc,
        fill_value=fill_value,
        margins=margins,
        dropna=dropna,
        margins_name=margins_name,
        sort=sort,
    )
    return op(data)
@@ -21,6 +21,7 @@ from ....core import OutputType
21
21
  from ....tensor.core import TENSOR_TYPE
22
22
  from ... import eval as maxframe_eval
23
23
  from ... import get_dummies, to_numeric
24
+ from ...arithmetic import DataFrameGreater, DataFrameLess
24
25
  from ...core import CATEGORICAL_TYPE, DATAFRAME_TYPE, INDEX_TYPE, SERIES_TYPE
25
26
  from ...datasource.dataframe import from_pandas as from_pandas_df
26
27
  from ...datasource.index import from_pandas as from_pandas_index
@@ -405,3 +406,63 @@ def test_to_numeric():
405
406
 
406
407
  with pytest.raises(ValueError):
407
408
  _ = to_numeric([])
409
+
410
+
411
def test_case_when():
    rng = np.random.RandomState(0)
    raw = pd.DataFrame(
        rng.randint(1000, size=(20, 8)), columns=[f"c{i}" for i in range(1, 9)]
    )
    df = from_pandas_df(raw, chunk_size=8)

    # malformed caselists paired with the exception each one must raise
    bad_caselists = [
        (TypeError, df.c2),
        (ValueError, []),
        (TypeError, [[]]),
        (ValueError, [()]),
    ]
    for expected_error, caselist in bad_caselists:
        with pytest.raises(expected_error):
            df.c1.case_when(caselist)

    result = df.c1.case_when([(df.c2 < 10, 10), (df.c2 > 20, df.c3)])
    inputs = result.inputs
    # the two condition tileables are expected at positions 1 and 2
    assert len(inputs) == 4
    assert isinstance(inputs[1].op, DataFrameLess)
    assert isinstance(inputs[2].op, DataFrameGreater)
431
+
432
+
433
def test_pivot_table():
    from ...groupby.aggregation import DataFrameGroupByAgg
    from ...misc.pivot_table import DataFramePivotTable

    raw = pd.DataFrame(
        {
            "A": "foo foo foo foo foo bar bar bar bar".split(),
            "B": "one one one two two one one two two".split(),
            "C": "small large large small small large small small large".split(),
            "D": [1, 2, 2, 3, 3, 4, 5, 6, 7],
            "E": [2, 4, 5, 5, 6, 6, 8, 9, 9],
        }
    )
    df = from_pandas_df(raw, chunk_size=8)

    # invalid or missing group keys raise ValueError
    invalid_kwargs = [
        dict(index=123),
        dict(index=["F"]),
        dict(values=["D", "E"], aggfunc="sum"),
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            df.pivot_table(**kwargs)

    # no ``columns`` and no margins: lowered to a groupby aggregation
    groupby_kwargs = [
        dict(index="A"),
        dict(index="A", values=["D", "E"], aggfunc="sum"),
    ]
    for kwargs in groupby_kwargs:
        assert isinstance(df.pivot_table(**kwargs).op, DataFrameGroupByAgg)

    # margins need the dedicated operator even without ``columns``
    with_margins = df.pivot_table(
        index=["A", "B"], values=["D", "E"], aggfunc="sum", margins=True
    )
    assert isinstance(with_margins.op, DataFramePivotTable)

    # with ``columns`` the output shape is unknown before execution
    pivot_kwargs = [
        dict(index="A", columns=["B", "C"]),
        dict(index=["A", "B"], columns="C"),
    ]
    for kwargs in pivot_kwargs:
        t = df.pivot_table(aggfunc="sum", **kwargs)
        assert isinstance(t.op, DataFramePivotTable)
        assert t.shape == (np.nan, np.nan)
@@ -17,7 +17,7 @@ from collections import OrderedDict
17
17
  import pandas as pd
18
18
 
19
19
  from ...core import ENTITY_TYPE, ExecutableTuple
20
- from ...utils import adapt_docstring
20
+ from ...utils import adapt_docstring, get_item_if_scalar
21
21
 
22
22
 
23
23
  class PlotAccessor:
@@ -34,7 +34,7 @@ class PlotAccessor:
34
34
  .fetch(session=session)
35
35
  )
36
36
  for p, v in zip(to_executes, executed):
37
- result[p] = v
37
+ result[p] = get_item_if_scalar(v)
38
38
 
39
39
  data = result.pop("__object__")
40
40
  pd_kwargs = kwargs.copy()
@@ -30,7 +30,7 @@ from ...serialization.serializables import (
30
30
  StringField,
31
31
  )
32
32
  from ...typing_ import TileableType
33
- from ...utils import pd_release_version, tokenize
33
+ from ...utils import get_item_if_scalar, pd_release_version, tokenize
34
34
  from ..operators import DATAFRAME_TYPE, DataFrameOperator, DataFrameOperatorMixin
35
35
  from ..utils import (
36
36
  build_df,
@@ -715,6 +715,7 @@ class ReductionCompiler:
715
715
  keys_to_vars = {inp.key: local_key_to_var[inp.key] for inp in t.inputs}
716
716
 
717
717
  def _interpret_var(v):
718
+ v = get_item_if_scalar(v)
718
719
  # get representation for variables
719
720
  if hasattr(v, "key"):
720
721
  return keys_to_vars[v.key]
@@ -1136,6 +1136,13 @@ def concat_on_columns(objs: List) -> Any:
1136
1136
  return result
1137
1137
 
1138
1138
 
1139
def apply_if_callable(maybe_callable, obj, **kwargs):
    """
    Call ``maybe_callable(obj, **kwargs)`` when it is callable.

    Non-callable values are returned unchanged, so an argument may be given
    either as a constant or as a function of ``obj``.
    """
    if not callable(maybe_callable):
        return maybe_callable
    return maybe_callable(obj, **kwargs)
1144
+
1145
+
1139
1146
  def patch_sa_engine_execute():
1140
1147
  """
1141
1148
  pandas did not resolve compatibility issue of sqlalchemy 2.0, the issue
@@ -0,0 +1,52 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+
16
+
17
def make_import_error_func(package_name):
    """
    Build a placeholder callable for an unavailable optional package.

    Parameters
    ----------
    package_name : str
        Name of the package, quoted in the error message.

    Returns
    -------
    callable
        A function accepting any arguments that always raises ImportError.
    """

    def _func(*args, **kwargs):  # pragma: no cover
        message = f"Cannot import {package_name}, please reinstall that package."
        raise ImportError(message)

    return _func
24
+
25
+
26
def config_mod_getattr(mod_dict, globals_):
    """
    Install lazy attribute loading (PEP 562) into a module namespace.

    Parameters
    ----------
    mod_dict : dict
        Maps a public attribute name to ``"<module>.<attribute>"``. The module
        part may be relative (e.g. ``".classifier.XGBClassifier"``) and is
        resolved against the package of the module owning ``globals_``.
    globals_ : dict
        The ``globals()`` of the module to patch. ``__getattr__``,
        ``__dir__``, ``__all__`` and ``__warningregistry__`` are written
        into it.
    """
    # relative targets resolve against the module's package; for a package
    # ``__init__`` this is the same as ``__name__``
    package = globals_.get("__package__") or globals_["__name__"]

    def __getattr__(name):
        import importlib

        if name not in mod_dict:  # pragma: no cover
            raise AttributeError(
                f"module {globals_['__name__']!r} has no attribute {name!r}"
            )
        mod_name, attr_name = mod_dict[name].rsplit(".", 1)
        mod = importlib.import_module(mod_name, package)
        # cache the loaded object so later lookups bypass __getattr__
        attr = globals_[name] = getattr(mod, attr_name)
        return attr

    if sys.version_info[:2] < (3, 7):
        # module-level __getattr__ needs Python 3.7+, so load eagerly
        for _mod in mod_dict.keys():
            __getattr__(_mod)

    def __dir__():
        # loaded names are also cached in globals_; a set avoids duplicates
        public = {n for n in globals_ if not n.startswith("_")}
        return sorted(public | set(mod_dict))

    globals_.update(
        {
            "__getattr__": __getattr__,
            "__dir__": __dir__,
            "__all__": list(__dir__()),
            # NOTE(review): presumably pre-seeded so lookups of this name do
            # not reach __getattr__ — confirm
            "__warningregistry__": dict(),
        }
    )
@@ -0,0 +1,26 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..utils import config_mod_getattr as _config_mod_getattr
16
+ from .dmatrix import DMatrix
17
+ from .predict import predict
18
+ from .train import train
19
+
20
# Register XGBClassifier / XGBRegressor as lazily imported attributes: their
# modules are loaded on first access (eagerly on Python < 3.7).
_config_mod_getattr(
    mod_dict={
        "XGBClassifier": ".classifier.XGBClassifier",
        "XGBRegressor": ".regressor.XGBRegressor",
    },
    globals_=globals(),
)
@@ -0,0 +1,86 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+
17
+ from ....tensor import argmax
18
+ from ..utils import make_import_error_func
19
+ from .core import XGBScikitLearnBase, xgboost
20
+
21
if not xgboost:
    # xgboost is unavailable: expose a stub raising ImportError when called
    XGBClassifier = make_import_error_func("xgboost")
else:
    from xgboost.sklearn import XGBClassifierBase

    from .core import wrap_evaluation_matrices
    from .predict import predict
    from .train import train

    class XGBClassifier(XGBScikitLearnBase, XGBClassifierBase):
        """
        Implementation of the scikit-learn API for XGBoost classification.
        """

        def fit(
            self,
            X,
            y,
            sample_weight=None,
            base_margin=None,
            eval_set=None,
            sample_weight_eval_set=None,
            base_margin_eval_set=None,
            num_class=None,
        ):
            """
            Train the booster on ``X`` / ``y`` and return ``self``.

            More than two classes (``num_class``) selects the
            ``multi:softprob`` objective; otherwise ``binary:logistic`` is
            used.
            """
            dtrain, evals = wrap_evaluation_matrices(
                None,
                X,
                y,
                sample_weight,
                base_margin,
                eval_set,
                sample_weight_eval_set,
                base_margin_eval_set,
            )
            params = self.get_xgb_params()
            n_classes = self.n_classes_ = num_class or 1
            is_multiclass = n_classes > 2
            params["objective"] = (
                "multi:softprob" if is_multiclass else "binary:logistic"
            )
            if is_multiclass:
                params["num_class"] = n_classes
            # handed to train(); presumably filled with evaluation history
            self.evals_result_ = {}
            self._Booster = train(
                params,
                dtrain,
                num_boost_round=self.get_num_boosting_rounds(),
                evals=evals,
                evals_result=self.evals_result_,
                num_class=num_class,
            )
            return self

        def predict(self, data, **kw):
            """Predict class labels for ``data``."""
            prob = self.predict_proba(data, flag=True, **kw)
            if prob.ndim <= 1:
                # one-dimensional output: threshold the scores at 0.5
                return (prob > 0.5).astype(np.int64)
            return argmax(prob, axis=1)

        def predict_proba(self, data, ntree_limit=None, flag=False, **kw):
            """Predict class probabilities for ``data``."""
            if ntree_limit is None:
                return predict(self.get_booster(), data, flag=flag, **kw)
            raise NotImplementedError("ntree_limit is not currently supported")