maxframe 0.1.0b4__cp310-cp310-win_amd64.whl → 1.0.0rc1__cp310-cp310-win_amd64.whl

This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. Click here for more details.

Files changed (81):
  1. maxframe/__init__.py +1 -0
  2. maxframe/_utils.cp310-win_amd64.pyd +0 -0
  3. maxframe/codegen.py +56 -3
  4. maxframe/config/config.py +15 -1
  5. maxframe/core/__init__.py +0 -3
  6. maxframe/core/entity/__init__.py +1 -8
  7. maxframe/core/entity/objects.py +3 -45
  8. maxframe/core/graph/core.cp310-win_amd64.pyd +0 -0
  9. maxframe/core/graph/core.pyx +4 -4
  10. maxframe/dataframe/__init__.py +1 -0
  11. maxframe/dataframe/core.py +30 -8
  12. maxframe/dataframe/datasource/read_odps_query.py +3 -1
  13. maxframe/dataframe/datasource/read_odps_table.py +3 -1
  14. maxframe/dataframe/datastore/tests/__init__.py +13 -0
  15. maxframe/dataframe/datastore/tests/test_to_odps.py +48 -0
  16. maxframe/dataframe/datastore/to_odps.py +21 -0
  17. maxframe/dataframe/indexing/align.py +1 -1
  18. maxframe/dataframe/misc/__init__.py +4 -0
  19. maxframe/dataframe/misc/apply.py +3 -1
  20. maxframe/dataframe/misc/case_when.py +141 -0
  21. maxframe/dataframe/misc/memory_usage.py +2 -2
  22. maxframe/dataframe/misc/pivot_table.py +262 -0
  23. maxframe/dataframe/misc/tests/test_misc.py +84 -0
  24. maxframe/dataframe/plotting/core.py +2 -2
  25. maxframe/dataframe/reduction/core.py +2 -1
  26. maxframe/dataframe/statistics/corr.py +3 -3
  27. maxframe/dataframe/utils.py +7 -0
  28. maxframe/errors.py +13 -0
  29. maxframe/extension.py +12 -0
  30. maxframe/learn/contrib/utils.py +52 -0
  31. maxframe/learn/contrib/xgboost/__init__.py +26 -0
  32. maxframe/learn/contrib/xgboost/classifier.py +86 -0
  33. maxframe/learn/contrib/xgboost/core.py +156 -0
  34. maxframe/learn/contrib/xgboost/dmatrix.py +150 -0
  35. maxframe/learn/contrib/xgboost/predict.py +138 -0
  36. maxframe/learn/contrib/xgboost/regressor.py +78 -0
  37. maxframe/learn/contrib/xgboost/tests/__init__.py +13 -0
  38. maxframe/learn/contrib/xgboost/tests/test_core.py +43 -0
  39. maxframe/learn/contrib/xgboost/train.py +121 -0
  40. maxframe/learn/utils/__init__.py +15 -0
  41. maxframe/learn/utils/core.py +29 -0
  42. maxframe/lib/mmh3.cp310-win_amd64.pyd +0 -0
  43. maxframe/lib/mmh3.pyi +43 -0
  44. maxframe/lib/wrapped_pickle.py +2 -1
  45. maxframe/odpsio/arrow.py +2 -3
  46. maxframe/odpsio/tableio.py +22 -0
  47. maxframe/odpsio/tests/test_schema.py +16 -11
  48. maxframe/opcodes.py +3 -0
  49. maxframe/protocol.py +108 -10
  50. maxframe/serialization/core.cp310-win_amd64.pyd +0 -0
  51. maxframe/serialization/core.pxd +3 -0
  52. maxframe/serialization/core.pyi +64 -0
  53. maxframe/serialization/core.pyx +54 -25
  54. maxframe/serialization/exception.py +1 -1
  55. maxframe/serialization/pandas.py +7 -2
  56. maxframe/serialization/serializables/core.py +119 -12
  57. maxframe/serialization/serializables/tests/test_serializable.py +46 -4
  58. maxframe/session.py +28 -0
  59. maxframe/tensor/__init__.py +1 -1
  60. maxframe/tensor/arithmetic/tests/test_arithmetic.py +1 -1
  61. maxframe/tensor/base/__init__.py +2 -0
  62. maxframe/tensor/base/atleast_1d.py +74 -0
  63. maxframe/tensor/base/unique.py +205 -0
  64. maxframe/tensor/datasource/array.py +4 -2
  65. maxframe/tensor/datasource/scalar.py +1 -1
  66. maxframe/tensor/reduction/count_nonzero.py +1 -1
  67. maxframe/tests/test_protocol.py +34 -0
  68. maxframe/tests/test_utils.py +0 -12
  69. maxframe/tests/utils.py +2 -2
  70. maxframe/udf.py +63 -3
  71. maxframe/utils.py +22 -13
  72. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/METADATA +3 -3
  73. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/RECORD +80 -61
  74. maxframe_client/__init__.py +0 -1
  75. maxframe_client/fetcher.py +65 -3
  76. maxframe_client/session/odps.py +74 -5
  77. maxframe_client/session/task.py +65 -71
  78. maxframe_client/tests/test_session.py +64 -1
  79. maxframe_client/clients/spe.py +0 -104
  80. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/WHEEL +0 -0
  81. {maxframe-0.1.0b4.dist-info → maxframe-1.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,26 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..utils import config_mod_getattr as _config_mod_getattr
16
+ from .dmatrix import DMatrix
17
+ from .predict import predict
18
+ from .train import train
19
+
20
# Register the scikit-learn style estimators on this package. Each public
# name maps to the dotted path of its implementation; the helper receives
# this module's globals (presumably to resolve the names lazily on attribute
# access -- confirm against ..utils.config_mod_getattr).
_config_mod_getattr(
    dict(
        XGBClassifier=".classifier.XGBClassifier",
        XGBRegressor=".regressor.XGBRegressor",
    ),
    globals(),
)
@@ -0,0 +1,86 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+
17
+ from ....tensor import argmax
18
+ from ..utils import make_import_error_func
19
+ from .core import XGBScikitLearnBase, xgboost
20
+
21
if not xgboost:
    # xgboost could not be imported: bind the public name to the stand-in
    # produced by make_import_error_func (presumably raises on use --
    # confirm against ..utils).
    XGBClassifier = make_import_error_func("xgboost")
else:
    from xgboost.sklearn import XGBClassifierBase

    from .core import wrap_evaluation_matrices
    from .predict import predict
    from .train import train

    class XGBClassifier(XGBScikitLearnBase, XGBClassifierBase):
        """
        Implementation of the scikit-learn API for XGBoost classification.
        """

        def fit(
            self,
            X,
            y,
            sample_weight=None,
            base_margin=None,
            eval_set=None,
            sample_weight_eval_set=None,
            base_margin_eval_set=None,
            num_class=None,
        ):
            """
            Train the classifier and store the resulting booster on ``self``.

            Parameters
            ----------
            X : array_like
                Feature matrix.
            y : array_like
                Labels.
            sample_weight : array_like, optional
                Instance weights for the training data.
            base_margin : array_like, optional
                Base margin for the training data.
            eval_set : list, optional
                ``(X, y)`` pairs used as validation sets.
            sample_weight_eval_set : list, optional
                One weight entry per validation set.
            base_margin_eval_set : list, optional
                One base-margin entry per validation set.
            num_class : int, optional
                Number of classes. Values above 2 select the
                ``multi:softprob`` objective; anything else (including
                ``None``) selects ``binary:logistic``.

            Returns
            -------
            self
            """
            # `missing` is passed as None here.
            dtrain, evals = wrap_evaluation_matrices(
                None,
                X,
                y,
                sample_weight,
                base_margin,
                eval_set,
                sample_weight_eval_set,
                base_margin_eval_set,
            )
            xgb_params = self.get_xgb_params()

            # A missing (or falsy) num_class is recorded as a single class.
            self.n_classes_ = num_class if num_class else 1
            is_multiclass = self.n_classes_ > 2
            if is_multiclass:
                xgb_params["objective"] = "multi:softprob"
                xgb_params["num_class"] = self.n_classes_
            else:
                xgb_params["objective"] = "binary:logistic"

            self.evals_result_ = {}
            booster = train(
                xgb_params,
                dtrain,
                num_boost_round=self.get_num_boosting_rounds(),
                evals=evals,
                evals_result=self.evals_result_,
                num_class=num_class,
            )
            self._Booster = booster
            return self

        def predict(self, data, **kw):
            """
            Predict class labels for ``data``.

            Multi-dimensional probabilities are reduced with ``argmax`` along
            axis 1; one-dimensional probabilities are thresholded at 0.5 and
            cast to ``int64``.
            """
            proba = self.predict_proba(data, flag=True, **kw)
            if proba.ndim <= 1:
                return (proba > 0.5).astype(np.int64)
            return argmax(proba, axis=1)

        def predict_proba(self, data, ntree_limit=None, flag=False, **kw):
            """
            Predict probabilities for ``data`` with the fitted booster.

            Raises
            ------
            NotImplementedError
                If ``ntree_limit`` is given.
            """
            if ntree_limit is None:
                return predict(self.get_booster(), data, flag=flag, **kw)
            raise NotImplementedError("ntree_limit is not currently supported")
@@ -0,0 +1,156 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, List, Optional, Tuple
16
+
17
+ try:
18
+ import xgboost
19
+ except ImportError:
20
+ xgboost = None
21
+
22
+ from .dmatrix import DMatrix
23
+
24
if not xgboost:
    # Without xgboost there is no XGBModel to derive from; expose a None
    # placeholder so importing this module still succeeds.
    XGBScikitLearnBase = None
else:

    class XGBScikitLearnBase(xgboost.XGBModel):
        """
        Base class for implementing scikit-learn interface
        """

        def fit(
            self,
            X,
            y,
            sample_weights=None,
            eval_set=None,
            sample_weight_eval_set=None,
            **kw,
        ):
            """
            Fit the model. Subclasses must override this method.

            Parameters
            ----------
            X : array_like
                Feature matrix
            y : array_like
                Labels
            sample_weights : array_like
                instance weights (the concrete estimators in this package
                name this parameter ``sample_weight``)
            eval_set : list, optional
                A list of (X, y) tuple pairs to use as validation sets, for which
                metrics will be computed.
                Validation metrics will help us track the performance of the model.
            sample_weight_eval_set : list, optional
                A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
                of group weights on the i-th validation set.

            Raises
            ------
            NotImplementedError
                Always; the base class provides no implementation.
            """
            raise NotImplementedError

        def predict(self, data, **kw):
            """
            Predict with `data`. Subclasses must override this method.

            Parameters
            ----------
            data: data that can be used to perform prediction

            Returns
            -------
            prediction : maxframe.tensor.Tensor

            Raises
            ------
            NotImplementedError
                Always; the base class provides no implementation.
            """
            raise NotImplementedError

    def wrap_evaluation_matrices(
        missing: float,
        X: Any,
        y: Any,
        sample_weight: Optional[Any],
        base_margin: Optional[Any],
        eval_set: Optional[List[Tuple[Any, Any]]],
        sample_weight_eval_set: Optional[List[Any]],
        base_margin_eval_set: Optional[List[Any]],
        label_transform: Callable = lambda x: x,
    ) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
        """
        Convert array_like training and evaluation data into DMatrix objects.

        Parameters
        ----------
        missing
            Forwarded as ``missing`` to every DMatrix that is built.
        X, y
            Training features and labels; ``label_transform`` is applied to
            ``y`` before it is used as the label.
        sample_weight, base_margin
            Optional weight / base margin of the training data.
        eval_set
            Optional list of ``(X, y)`` validation pairs.
        sample_weight_eval_set, base_margin_eval_set
            Optional lists holding one entry per validation pair.
        label_transform
            Callable applied to every label; identity by default.

        Returns
        -------
        tuple
            ``(train_dmatrix, evals)`` where ``evals`` is a list of
            ``(dmatrix, "validation_<i>")`` pairs, empty when ``eval_set``
            is None.

        Raises
        ------
        ValueError
            If a meta list's length differs from ``eval_set``'s, or if a meta
            list is given while ``eval_set`` is None.
        """
        train_dmatrix = DMatrix(
            data=X,
            label=label_transform(y),
            weight=sample_weight,
            base_margin=base_margin,
            missing=missing,
        )

        if eval_set is None:
            # Per-validation-set meta info is meaningless without eval_set.
            meta_infos = (sample_weight_eval_set, base_margin_eval_set)
            if not all(meta is None for meta in meta_infos):
                raise ValueError(
                    "`eval_set` is not set but one of the other evaluation meta info is "
                    "not None."
                )
            return train_dmatrix, []

        n_validation = len(eval_set)

        def validate_or_none(meta: Optional[List], name: str) -> List:
            # A missing meta list is expanded to one None per validation set.
            if meta is None:
                return [None] * n_validation
            if len(meta) == n_validation:
                return meta
            raise ValueError(
                f"{name}'s length does not equal `eval_set`'s length, "
                f"expecting {n_validation}, got {len(meta)}"
            )

        sample_weight_eval_set = validate_or_none(
            sample_weight_eval_set, "sample_weight_eval_set"
        )
        base_margin_eval_set = validate_or_none(
            base_margin_eval_set, "base_margin_eval_set"
        )

        matrices = []
        for idx, (valid_X, valid_y) in enumerate(eval_set):
            valid_weight = sample_weight_eval_set[idx]
            valid_margin = base_margin_eval_set[idx]
            # Skip the duplicated entry: when a validation set consists of the
            # very same objects as the training data, reuse the training
            # DMatrix instead of building another one.
            same_as_training = (
                valid_X is X
                and valid_y is y
                and valid_weight is sample_weight
                and valid_margin is base_margin
            )
            if same_as_training:
                matrices.append(train_dmatrix)
                continue
            matrices.append(
                DMatrix(
                    data=valid_X,
                    label=label_transform(valid_y),
                    weight=valid_weight,
                    base_margin=valid_margin,
                    missing=missing,
                )
            )

        evals = [
            (matrix, f"validation_{idx}") for idx, matrix in enumerate(matrices)
        ]
        return train_dmatrix, evals
@@ -0,0 +1,150 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from .... import opcodes as OperandDef
17
+ from ....core.entity.output_types import get_output_types
18
+ from ....core.operator.base import Operator
19
+ from ....core.operator.core import TileableOperatorMixin
20
+ from ....dataframe.core import DATAFRAME_TYPE
21
+ from ....serialization.serializables import Float64Field, KeyField, ListField
22
+ from ....serialization.serializables.field import AnyField, Int64Field
23
+ from ....tensor import tensor as astensor
24
+ from ....tensor.core import TENSOR_TYPE
25
+ from ....typing_ import TileableType
26
+ from ...utils import convert_to_tensor_or_dataframe
27
+
28
+
29
class ToDMatrix(Operator, TileableOperatorMixin):
    """
    Operator that bundles data, label, weight and base margin into a single
    tileable to be used as an XGBoost DMatrix.
    """

    _op_type_ = OperandDef.TO_DMATRIX

    # Tileable inputs; re-bound from the operator inputs in _set_inputs.
    data = KeyField("data", default=None)
    label = KeyField("label", default=None)
    missing = Float64Field("missing", default=None)
    weight = KeyField("weight", default=None)
    base_margin = KeyField("base_margin", default=None)
    # Remaining options are stored as given by to_dmatrix.
    feature_names = ListField("feature_names", default=None)
    feature_types = ListField("feature_types", default=None)
    feature_weights = AnyField("feature_weights", default=None)
    nthread = Int64Field("nthread", default=None)
    group = AnyField("group", default=None)
    qid = AnyField("qid", default=None)
    label_lower_bound = AnyField("label_lower_bound", default=None)
    label_upper_bound = AnyField("label_upper_bound", default=None)

    @property
    def output_limit(self):
        """The operator always produces exactly one output."""
        return 1

    def _set_inputs(self, inputs):
        """
        Re-bind the key fields to the entries of ``self._inputs``.

        The positions mirror the order used by ``__call__``: data first, then
        label, then weight, with base_margin always last.
        """
        super()._set_inputs(inputs)
        bound = self._inputs
        if self.data is not None:
            self.data = bound[0]
        has_label = self.label is not None
        if has_label:
            self.label = bound[1]
        if self.weight is not None:
            # weight follows label when a label is present, otherwise data
            self.weight = bound[2 if has_label else 1]
        if self.base_margin is not None:
            self.base_margin = bound[-1]

    @staticmethod
    def _get_kw(obj):
        """
        Collect the metadata of ``obj`` needed to create the output tileable.

        Tensors contribute shape, dtype and order; any other object is
        expected to offer shape, dtypes, index_value and columns_value.
        """
        if not isinstance(obj, TENSOR_TYPE):
            return {
                "shape": obj.shape,
                "dtypes": obj.dtypes,
                "index_value": obj.index_value,
                "columns_value": obj.columns_value,
            }
        return {"shape": obj.shape, "dtype": obj.dtype, "order": obj.order}

    def __call__(self):
        """Create the output tileable, carrying the metadata of ``data``."""
        out_kw = self._get_kw(self.data)
        optional_inputs = (self.label, self.weight, self.base_margin)
        inputs = [self.data] + [t for t in optional_inputs if t is not None]
        return self.new_tileable(inputs, **out_kw)
86
+
87
+
88
def check_data(data):
    """
    Convert ``data`` with ``convert_to_tensor_or_dataframe`` and make sure the
    result is two-dimensional.

    Raises
    ------
    ValueError
        If the converted data is not 2-d.
    """
    converted = convert_to_tensor_or_dataframe(data)
    if converted.ndim == 2:
        return converted
    raise ValueError(f"Expecting 2-d data, got: {converted.ndim}-d")
94
+
95
+
96
def check_array_like(y: TileableType, name: str) -> TileableType:
    """
    Normalize an optional 1-d input (label, weight, base margin) to a tensor.

    ``None`` is passed through unchanged. A dataframe contributes only its
    first column. ``name`` is used solely in the error message.

    Raises
    ------
    ValueError
        If the resulting tensor is not 1-d.
    """
    if y is None:
        return None
    converted = convert_to_tensor_or_dataframe(y)
    if isinstance(converted, DATAFRAME_TYPE):
        # only the first column of a dataframe is used
        converted = converted.iloc[:, 0]
    as_tensor = astensor(converted)
    if as_tensor.ndim == 1:
        return as_tensor
    raise ValueError(f"Expecting 1-d {name}, got: {as_tensor.ndim}-d")
106
+
107
+
108
def to_dmatrix(
    data,
    label=None,
    missing=None,
    weight=None,
    base_margin=None,
    feature_names=None,
    feature_types=None,
    feature_weights=None,
    nthread=None,
    group=None,
    qid=None,
    label_lower_bound=None,
    label_upper_bound=None,
):
    """
    Build a ``ToDMatrix`` operator from the given inputs and return its output.

    ``data`` must convert to a 2-d object (see ``check_data``); ``label``,
    ``weight`` and ``base_margin`` are optional and must convert to 1-d
    tensors (see ``check_array_like``). All other arguments are stored on the
    operator unchanged.
    """
    data = check_data(data)

    # The GPU flag and output type of the operator are taken from `data`.
    op_kwargs = dict(
        data=data,
        label=check_array_like(label, "label"),
        missing=missing,
        weight=check_array_like(weight, "weight"),
        base_margin=check_array_like(base_margin, "base_margin"),
        feature_names=feature_names,
        feature_types=feature_types,
        feature_weights=feature_weights,
        nthread=nthread,
        group=group,
        qid=qid,
        label_lower_bound=label_lower_bound,
        label_upper_bound=label_upper_bound,
        gpu=data.op.gpu,
        _output_types=get_output_types(data),
    )
    return ToDMatrix(**op_kwargs)()


# Public alias, named after the corresponding xgboost class.
DMatrix = to_dmatrix
@@ -0,0 +1,138 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pickle
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ from .... import opcodes as OperandDef
21
+ from ....core.entity.output_types import OutputType
22
+ from ....core.operator.base import Operator
23
+ from ....core.operator.core import TileableOperatorMixin
24
+ from ....dataframe.utils import parse_index
25
+ from ....serialization.serializables import BoolField, BytesField, KeyField, TupleField
26
+ from ....tensor.core import TENSOR_TYPE, TensorOrder
27
+ from .dmatrix import check_data
28
+
29
+
30
class XGBPredict(Operator, TileableOperatorMixin):
    """
    Operator describing a prediction made with a trained XGBoost model.

    ``__call__`` builds the output tileable: a tensor, a dataframe or a
    series, depending on ``output_types``. The output has one row per row of
    ``data`` and, when the model's operator exposes ``num_class``, one column
    per class.
    """

    _op_type_ = OperandDef.XGBOOST_PREDICT
    # dtype reported for every output created by __call__
    output_dtype = np.dtype(np.float32)

    data = KeyField("data", default=None)
    # SECURITY: this field unpickles its payload on deserialization; only
    # deserialize operators coming from trusted sources.
    # NOTE(review): _set_inputs binds a tileable (an operator input) to this
    # bytes/pickle field -- confirm the field type is the intended one.
    model = BytesField(
        "model", on_serialize=pickle.dumps, on_deserialize=pickle.loads, default=None
    )
    pred_leaf = BoolField("pred_leaf", default=False)
    pred_contribs = BoolField("pred_contribs", default=False)
    approx_contribs = BoolField("approx_contribs", default=False)
    pred_interactions = BoolField("pred_interactions", default=False)
    validate_features = BoolField("validate_features", default=True)
    training = BoolField("training", default=False)
    # The factory originally required one positional argument and would fail
    # when invoked without arguments (the usual default-factory convention;
    # presumably also this framework's -- confirm). Accept any arity.
    iteration_range = TupleField("iteration_range", default_factory=lambda *_: (0, 0))
    strict_shape = BoolField("strict_shape", default=False)
    # Set to True by XGBClassifier.predict; its meaning is not visible here.
    flag = BoolField("flag", default=False)
    # NOTE(review): predict() also passes `output_margin`, for which no field
    # is declared on this operator -- confirm how the base class treats it.

    def __init__(self, output_types=None, gpu=None, **kw):
        super().__init__(_output_types=output_types, gpu=gpu, **kw)

    def _set_inputs(self, inputs):
        """Re-bind ``data`` and ``model`` to the first two operator inputs."""
        super()._set_inputs(inputs)
        self.data = self._inputs[0]
        self.model = self._inputs[1]

    def __call__(self):
        """
        Create the output tileable for this prediction.

        Returns
        -------
        A tensor, dataframe or series tileable (chosen by
        ``self.output_types[0]``) of shape ``(n_rows, num_class)`` when the
        model's operator has a non-None ``num_class`` and ``(n_rows,)``
        otherwise.
        """
        num_class = getattr(self.model.op, "num_class", None)
        n_rows = self.data.shape[0]
        if num_class is not None:
            num_class = int(num_class)
            shape = (n_rows, num_class)
        else:
            shape = (n_rows,)

        inputs = [self.data, self.model]
        output_type = self.output_types[0]

        if output_type == OutputType.tensor:
            return self.new_tileable(
                inputs,
                shape=shape,
                dtype=self.output_dtype,
                order=TensorOrder.C_ORDER,
            )

        if output_type == OutputType.dataframe:
            # An empty frame is enough to derive the column dtypes; there is
            # no need to involve the random number generator for metadata.
            dtypes = pd.DataFrame(
                np.empty((0, num_class)), dtype=self.output_dtype
            ).dtypes
            return self.new_tileable(
                inputs,
                shape=shape,
                dtypes=dtypes,
                columns_value=parse_index(dtypes.index),
                index_value=self.data.index_value,
            )

        # series
        return self.new_tileable(
            inputs,
            shape=shape,
            index_value=self.data.index_value,
            name="predictions",
            dtype=self.output_dtype,
        )
94
+
95
+
96
def predict(
    model,
    data,
    output_margin=False,
    pred_leaf=False,
    pred_contribs=False,
    approx_contribs=False,
    pred_interactions=False,
    validate_features=True,
    training=False,
    iteration_range=None,
    strict_shape=False,
    flag=False,
):
    """
    Build an ``XGBPredict`` operator for ``model`` and ``data`` and return its
    output tileable.

    The output type is a tensor when ``data`` is a tensor; otherwise it is a
    dataframe when ``model.op`` has a non-None ``num_class`` and a series when
    it does not. A falsy ``iteration_range`` is replaced by ``(0, 0)``.

    Raises
    ------
    ValueError
        If ``data`` is not 2-d (raised by ``check_data``).
    """
    data = check_data(data)
    # TODO: check model datatype

    num_class = getattr(model.op, "num_class", None)
    if isinstance(data, TENSOR_TYPE):
        output_kind = OutputType.tensor
    elif num_class is None:
        output_kind = OutputType.series
    else:
        output_kind = OutputType.dataframe

    if not iteration_range:
        iteration_range = (0, 0)

    op = XGBPredict(
        data=data,
        model=model,
        output_margin=output_margin,
        pred_leaf=pred_leaf,
        pred_contribs=pred_contribs,
        approx_contribs=approx_contribs,
        pred_interactions=pred_interactions,
        validate_features=validate_features,
        training=training,
        iteration_range=iteration_range,
        strict_shape=strict_shape,
        gpu=data.op.gpu,
        output_types=[output_kind],
        flag=flag,
    )
    return op()
@@ -0,0 +1,78 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from ..utils import make_import_error_func
17
+ from .core import XGBScikitLearnBase, xgboost
18
+
19
if not xgboost:
    # xgboost could not be imported: bind the public name to the stand-in
    # produced by make_import_error_func (presumably raises on use --
    # confirm against ..utils).
    XGBRegressor = make_import_error_func("xgboost")
else:
    from .core import wrap_evaluation_matrices
    from .predict import predict
    from .train import train

    class XGBRegressor(XGBScikitLearnBase):
        """
        Implementation of the scikit-learn API for XGBoost regressor.
        """

        def fit(
            self,
            X,
            y,
            sample_weight=None,
            base_margin=None,
            eval_set=None,
            sample_weight_eval_set=None,
            base_margin_eval_set=None,
            **kw,
        ):
            """
            Train the regressor and store the resulting booster on ``self``.

            Parameters
            ----------
            X : array_like
                Feature matrix.
            y : array_like
                Labels.
            sample_weight : array_like, optional
                Instance weights for the training data.
            base_margin : array_like, optional
                Base margin for the training data.
            eval_set : list, optional
                ``(X, y)`` pairs used as validation sets.
            sample_weight_eval_set : list, optional
                One weight entry per validation set.
            base_margin_eval_set : list, optional
                One base-margin entry per validation set.
            **kw
                Only ``session`` and ``run_kwargs`` are accepted; both are
                forwarded to ``train``.

            Returns
            -------
            self

            Raises
            ------
            TypeError
                If any other keyword argument is passed.
            """
            session = kw.pop("session", None)
            run_kwargs = kw.pop("run_kwargs", dict())
            if kw:
                raise TypeError(
                    f"fit got an unexpected keyword argument '{next(iter(kw))}'"
                )

            # `missing` is passed as None here.
            dtrain, evals = wrap_evaluation_matrices(
                None,
                X,
                y,
                sample_weight,
                base_margin,
                eval_set,
                sample_weight_eval_set,
                base_margin_eval_set,
            )
            params = self.get_xgb_params()
            self.evals_result_ = dict()
            result = train(
                params,
                dtrain,
                num_boost_round=self.get_num_boosting_rounds(),
                evals=evals,
                evals_result=self.evals_result_,
                session=session,
                run_kwargs=run_kwargs,
            )
            self._Booster = result
            return self

        def predict(self, data, **kw):
            """
            Predict with the fitted booster.

            Parameters
            ----------
            data
                Data to predict on.
            **kw
                Forwarded to the module-level ``predict``. ``session`` and
                ``run_kwargs`` are accepted for backward compatibility but
                are not forwarded.
            """
            # The module-level predict() declares neither `session` nor
            # `run_kwargs` and takes no **kwargs, so forwarding them made
            # every call fail with TypeError. Drop them instead.
            kw.pop("session", None)
            kw.pop("run_kwargs", None)
            return predict(self.get_booster(), data, **kw)
@@ -0,0 +1,13 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,43 @@
1
+ # Copyright 1999-2024 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pytest
16
+
17
+ try:
18
+ import xgboost
19
+ except ImportError:
20
+ xgboost = None
21
+
22
+
23
+ from ..... import tensor as mt
24
+
25
+ if xgboost:
26
+ from ..core import wrap_evaluation_matrices
27
+
28
+
29
@pytest.mark.skipif(xgboost is None, reason="XGBoost not installed")
def test_wrap_evaluation_matrices():
    """Check the validation and the happy path of wrap_evaluation_matrices."""
    features = mt.random.rand(100, 3)
    labels = mt.random.randint(3, size=(100,))
    validation = [(mt.random.rand(10, 3), mt.random.randint(3, size=10))]

    # sample_weight_eval_set size wrong
    with pytest.raises(ValueError):
        wrap_evaluation_matrices(
            0.0, features, labels, None, None, validation, [], None
        )

    # meta info given although eval_set itself is None
    with pytest.raises(ValueError):
        wrap_evaluation_matrices(
            0.0, features, labels, None, None, None, validation, None
        )

    _, evals = wrap_evaluation_matrices(
        0.0, features, labels, None, None, validation, None, None
    )
    assert len(evals) > 0