maxframe 1.2.1__cp311-cp311-win_amd64.whl → 1.3.1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of maxframe might be problematic. Click here for more details.

Files changed (73)
  1. maxframe/_utils.cp311-win_amd64.pyd +0 -0
  2. maxframe/codegen.py +70 -21
  3. maxframe/config/config.py +6 -0
  4. maxframe/core/accessor.py +1 -0
  5. maxframe/core/graph/core.cp311-win_amd64.pyd +0 -0
  6. maxframe/dataframe/accessors/__init__.py +1 -1
  7. maxframe/dataframe/accessors/dict_/accessor.py +1 -0
  8. maxframe/dataframe/accessors/dict_/length.py +1 -0
  9. maxframe/dataframe/accessors/dict_/setitem.py +1 -0
  10. maxframe/dataframe/accessors/dict_/tests/test_dict_accessor.py +5 -7
  11. maxframe/dataframe/accessors/list_/__init__.py +37 -0
  12. maxframe/dataframe/accessors/list_/accessor.py +39 -0
  13. maxframe/dataframe/accessors/list_/getitem.py +135 -0
  14. maxframe/dataframe/accessors/list_/length.py +73 -0
  15. maxframe/dataframe/accessors/list_/tests/__init__.py +13 -0
  16. maxframe/dataframe/accessors/list_/tests/test_list_accessor.py +79 -0
  17. maxframe/dataframe/accessors/plotting/__init__.py +2 -0
  18. maxframe/dataframe/accessors/string_/__init__.py +1 -0
  19. maxframe/dataframe/datastore/to_odps.py +6 -0
  20. maxframe/dataframe/extensions/accessor.py +1 -0
  21. maxframe/dataframe/extensions/apply_chunk.py +34 -21
  22. maxframe/dataframe/extensions/flatmap.py +8 -1
  23. maxframe/dataframe/extensions/tests/test_apply_chunk.py +2 -1
  24. maxframe/dataframe/extensions/tests/test_extensions.py +1 -0
  25. maxframe/dataframe/groupby/aggregation.py +53 -1
  26. maxframe/dataframe/merge/concat.py +7 -4
  27. maxframe/dataframe/merge/merge.py +1 -0
  28. maxframe/dataframe/merge/tests/test_merge.py +97 -47
  29. maxframe/dataframe/missing/tests/test_missing.py +1 -0
  30. maxframe/dataframe/reduction/aggregation.py +63 -0
  31. maxframe/dataframe/reduction/core.py +17 -5
  32. maxframe/dataframe/tests/test_utils.py +7 -0
  33. maxframe/dataframe/ufunc/ufunc.py +1 -0
  34. maxframe/dataframe/utils.py +3 -0
  35. maxframe/io/odpsio/schema.py +1 -0
  36. maxframe/learn/contrib/__init__.py +2 -4
  37. maxframe/learn/contrib/llm/__init__.py +1 -0
  38. maxframe/learn/contrib/llm/core.py +31 -10
  39. maxframe/learn/contrib/llm/models/__init__.py +1 -0
  40. maxframe/learn/contrib/llm/models/dashscope.py +38 -3
  41. maxframe/learn/contrib/llm/models/managed.py +54 -0
  42. maxframe/learn/contrib/llm/multi_modal.py +93 -0
  43. maxframe/learn/contrib/llm/text.py +268 -8
  44. maxframe/learn/contrib/models.py +77 -0
  45. maxframe/learn/contrib/utils.py +1 -0
  46. maxframe/learn/contrib/xgboost/__init__.py +8 -1
  47. maxframe/learn/contrib/xgboost/classifier.py +15 -4
  48. maxframe/learn/contrib/xgboost/core.py +108 -1
  49. maxframe/learn/contrib/xgboost/dmatrix.py +1 -1
  50. maxframe/learn/contrib/xgboost/predict.py +6 -3
  51. maxframe/learn/contrib/xgboost/regressor.py +15 -1
  52. maxframe/learn/contrib/xgboost/train.py +5 -4
  53. maxframe/lib/dtypes_extension/__init__.py +2 -1
  54. maxframe/lib/dtypes_extension/dtypes.py +21 -0
  55. maxframe/lib/dtypes_extension/tests/test_dtypes.py +13 -3
  56. maxframe/lib/mmh3.cp311-win_amd64.pyd +0 -0
  57. maxframe/opcodes.py +19 -0
  58. maxframe/serialization/__init__.py +1 -0
  59. maxframe/serialization/core.cp311-win_amd64.pyd +0 -0
  60. maxframe/serialization/core.pyx +12 -1
  61. maxframe/serialization/numpy.py +12 -4
  62. maxframe/serialization/serializables/tests/test_serializable.py +13 -2
  63. maxframe/serialization/tests/test_serial.py +2 -0
  64. maxframe/tensor/merge/concatenate.py +1 -0
  65. maxframe/tensor/misc/unique.py +11 -10
  66. maxframe/tensor/reshape/reshape.py +4 -1
  67. maxframe/utils.py +4 -0
  68. {maxframe-1.2.1.dist-info → maxframe-1.3.1.dist-info}/METADATA +3 -2
  69. {maxframe-1.2.1.dist-info → maxframe-1.3.1.dist-info}/RECORD +73 -65
  70. {maxframe-1.2.1.dist-info → maxframe-1.3.1.dist-info}/WHEEL +1 -1
  71. maxframe_client/session/odps.py +3 -0
  72. maxframe_client/session/tests/test_task.py +1 -0
  73. {maxframe-1.2.1.dist-info → maxframe-1.3.1.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+
14
15
  from typing import Any, Dict
15
16
 
16
17
  from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
@@ -33,6 +34,98 @@ def generate(
33
34
  prompt_template: Dict[str, Any],
34
35
  params: Dict[str, Any] = None,
35
36
  ):
37
+ """
38
+ Generate text with multi model llm based on given data and prompt template.
39
+
40
+ Parameters
41
+ ----------
42
+ data : DataFrame or Series
43
+ Input data used for generation. Can be maxframe DataFrame, Series that contain text to be processed.
44
+ model : MultiModalLLM
45
+ Language model instance support **MultiModal** inputs used for text generation.
46
+ prompt_template : List[Dict[str, List[Dict[str, str]]]]
47
+ List of message with column names as placeholders. Each message contains a role and content. Content is a list of dict, each dict contains a text or image, the value can reference column data from input.
48
+
49
+ Here is an example of prompt template.
50
+
51
+ .. code-block:: python
52
+
53
+ [
54
+ {
55
+ "role": "<role>", # e.g. "user" or "assistant"
56
+ "content": [
57
+ {
58
+ # At least one of these fields is required
59
+ "image": "<image_data_url>", # optional
60
+ "text": "<prompt_text_template>" # optional
61
+ },
62
+ ...
63
+ ]
64
+ }
65
+ ]
66
+
67
+ Where:
68
+
69
+ - ``text`` can be a Python format string using column names from input data as parameters (e.g. ``"{column_name}"``)
70
+ - ``image`` should be a DataURL string following `RFC2397 <https://en.wikipedia.org/wiki/Data_URI_scheme>`_ standard with format.
71
+
72
+ .. code-block:: none
73
+
74
+ data:<mime_type>[;base64],<column_name>
75
+
76
+
77
+ params : Dict[str, Any], optional
78
+ Additional parameters for generation configuration, by default None.
79
+ Can include settings like temperature, max_tokens, etc.
80
+
81
+ Returns
82
+ -------
83
+ DataFrame
84
+ Generated text raw response and success status. If the success is False, the generated text will return the
85
+ error message.
86
+
87
+ Notes
88
+ -----
89
+ - The ``api_key_resource`` parameter should reference a text file resource in MaxCompute that contains only your DashScope API key.
90
+
91
+ - Using DashScope services requires enabling public network access for your MaxCompute project. This can be configured through the MaxCompute console by `enabling the Internet access feature <https://help.aliyun.com/zh/maxcompute/user-guide/network-connection-process>`_ for your project. Without this configuration, the API calls to DashScope will fail due to network connectivity issues.
92
+
93
+ Examples
94
+ --------
95
+ You can initialize a DashScope multi-modal model (such as qwen-vl-max) by providing a model name and an ``api_key_resource``.
96
+ The ``api_key_resource`` is a MaxCompute resource name that points to a text file containing a `DashScope <https://dashscope.aliyun.com/>`_ API key.
97
+
98
+ >>> from maxframe.learn.contrib.llm.models.dashscope import DashScopeMultiModalLLM
99
+ >>> import maxframe.dataframe as md
100
+ >>>
101
+ >>> model = DashScopeMultiModalLLM(
102
+ ... name="qwen-vl-max",
103
+ ... api_key_resource="<api-key-resource-name>"
104
+ ... )
105
+
106
+ We use Data Url Schema to provide multi modal input in prompt template, here is an example to fill in the image from table.
107
+
108
+ Assuming you have a MaxCompute table with two columns: ``image_id`` (as the index) and ``encoded_image_data_base64`` (containing Base64 encoded image data),
109
+ you can construct a prompt message template as follows:
110
+
111
+ >>> df = md.read_odps_table("image_content", index_col="image_id")
112
+
113
+ >>> prompt_template = [
114
+ ... {
115
+ ... "role": "user",
116
+ ... "content": [
117
+ ... {
118
+ ... "image": "_image_data_base64",
119
+ ... },
120
+ ... {
121
+ ... "text": "Analyze this image in detail",
122
+ ... },
123
+ ... ],
124
+ ... },
125
+ ... ]
126
+ >>> result = model.generate(df, prompt_template)
127
+ >>> result.execute()
128
+ """
36
129
  if not isinstance(data, DATAFRAME_TYPE) and not isinstance(data, SERIES_TYPE):
37
130
  raise ValueError("data must be a maxframe dataframe or series object")
38
131
  if not isinstance(model, MultiModalLLM):
@@ -11,32 +11,292 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- from typing import Any, Dict
15
14
 
16
- from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
17
- from .core import LLM
15
+ from typing import Any, Dict, List
16
+
17
+ import numpy as np
18
+
19
+ from .... import opcodes
20
+ from ....dataframe.core import DataFrame, Series
21
+ from ....serialization.serializables import FieldTypes, ListField, StringField
22
+ from .core import LLM, LLMTaskOperator
23
+
24
+
class TextLLMSummarizeOperator(LLMTaskOperator):
    """Task operator backing the ``summarize`` text-LLM task."""

    _op_type_ = opcodes.LLM_TEXT_SUMMARIZE_TASK

    def get_output_dtypes(self) -> Dict[str, np.dtype]:
        # Output fields of the task: the generated summary text (object)
        # and a flag telling whether the model call succeeded.
        output_dtypes = dict()
        output_dtypes["summary"] = np.dtype("O")
        output_dtypes["success"] = np.dtype("bool")
        return output_dtypes
33
+
34
+
class TextLLMTranslateOperator(LLMTaskOperator):
    """Task operator backing the ``translate`` text-LLM task."""

    _op_type_ = opcodes.LLM_TEXT_TRANSLATE_TASK

    # Language of the input text; TextLLM.translate passes None when unknown.
    source_language = StringField("source_language")
    # Language the text should be translated into.
    target_language = StringField("target_language")

    def get_output_dtypes(self) -> Dict[str, np.dtype]:
        # Output fields of the task: the translated text (object) and a
        # flag telling whether the model call succeeded.
        output_dtypes = dict()
        output_dtypes["target"] = np.dtype("O")
        output_dtypes["success"] = np.dtype("bool")
        return output_dtypes
46
+
47
+
class TextLLMClassifyOperator(LLMTaskOperator):
    """Task operator backing the ``classify`` text-LLM task."""

    _op_type_ = opcodes.LLM_TEXT_CLASSIFY_TASK

    # Candidate labels the model may choose from.
    labels = ListField("labels")
    # Optional free-text description of the classification task.
    description = StringField("description", default=None)
    # Optional few-shot examples, each stored as a dict.
    examples = ListField("examples", FieldTypes.dict, default=None)

    def get_output_dtypes(self) -> Dict[str, np.dtype]:
        # Output fields of the task: chosen label and reasoning (both object)
        # plus a flag telling whether the model call succeeded.
        output_dtypes = dict()
        output_dtypes["label"] = np.dtype("O")
        output_dtypes["reason"] = np.dtype("O")
        output_dtypes["success"] = np.dtype("bool")
        return output_dtypes
18
61
 
19
62
 
class TextLLM(LLM):
    """
    Base class for language models operating on text.

    Subclasses provide :meth:`generate`. The ``summarize``, ``translate`` and
    ``classify`` helpers build the matching task operator bound to this model
    and apply it to ``series`` (with an optional output ``index``).
    """

    def generate(
        self,
        data,
        prompt_template: List[Dict[str, str]],
        params: Dict[str, Any] = None,
    ):
        """Generate text from ``data``; concrete models must override this."""
        raise NotImplementedError

    def summarize(self, series, index=None, **kw):
        """Apply a summarize task operator bound to this model to ``series``."""
        summarize_op = TextLLMSummarizeOperator(model=self, task="summarize", **kw)
        return summarize_op(series, index)

    def translate(
        self,
        series,
        target_language: str,
        source_language: str = None,
        index=None,
        **kw
    ):
        """Apply a translate task operator bound to this model to ``series``."""
        translate_op = TextLLMTranslateOperator(
            model=self,
            task="translate",
            source_language=source_language,
            target_language=target_language,
            **kw
        )
        return translate_op(series, index)

    def classify(
        self,
        series,
        labels: List[str],
        description=None,
        examples=None,
        index=None,
        **kw
    ):
        """Apply a classify task operator bound to this model to ``series``."""
        classify_op = TextLLMClassifyOperator(
            model=self,
            labels=labels,
            task="classify",
            description=description,
            examples=examples,
            **kw
        )
        return classify_op(series, index)
110
+
29
111
 
def generate(
    data,
    model: TextLLM,
    prompt_template: List[Dict[str, Any]],
    params: Dict[str, Any] = None,
):
    """
    Generate text using a text language model based on given data and prompt template.

    Parameters
    ----------
    data : DataFrame or Series
        Input data used for generation. Must be a maxframe DataFrame or Series
        containing the text to be processed.
    model : TextLLM
        Language model instance used for text generation.
    prompt_template : List[Dict[str, Any]]
        List of conversation messages forming the prompt template. Use
        ``{col_name}`` as a placeholder to reference column data from the input
        data.

        Usually in the format of ``[{"role": "user", "content": "{query}"}]``,
        the same as the OpenAI chat API message schema.
    params : Dict[str, Any], optional
        Additional parameters for generation configuration, by default None.
        Can include settings like temperature, max_tokens, etc. The dict is
        checked by ``model.validate_params`` before generation.

    Returns
    -------
    DataFrame
        Generated text raw response and success status. If success is False,
        the generated text contains the error message instead.

    Raises
    ------
    ValueError
        If ``data`` is not a maxframe DataFrame or Series.
    TypeError
        If ``model`` is not a :class:`TextLLM` instance.

    Examples
    --------
    >>> from maxframe.learn.contrib.llm.models.managed import ManagedTextLLM
    >>> import maxframe.dataframe as md
    >>>
    >>> # Initialize the model
    >>> llm = ManagedTextLLM(name="Qwen2.5-0.5B-instruct")
    >>>
    >>> # Prepare prompt template
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": "Help answer following question: {query}",
    ...     },
    ... ]

    >>> # Create sample data
    >>> df = md.DataFrame({"query": ["What is machine learning?"]})
    >>>
    >>> # Generate response
    >>> result = generate(df, llm, prompt_template=messages)
    >>> result.execute()
    """
    if not isinstance(data, (DataFrame, Series)):
        raise ValueError("data must be a maxframe dataframe or series object")
    if not isinstance(model, TextLLM):
        raise TypeError("model must be a TextLLM object")
    # A missing params dict is treated as "no extra generation options".
    params = params if params is not None else dict()
    model.validate_params(params)
    return model.generate(data, prompt_template=prompt_template, params=params)
172
+
173
+
def _is_text_dtype(dtype) -> bool:
    """
    Return True if ``dtype`` is a dtype that can hold text values.

    Accepts numpy object / unicode / bytes dtypes and pandas string dtypes.
    A plain ``dtype != np.str_`` comparison is not usable for this check:
    object dtype (pandas' default storage for text) never compares equal to
    ``np.str_``, and sized unicode dtypes such as ``<U5`` are not equal to the
    unsized ``np.str_`` dtype either.
    """
    import pandas as pd

    if pd.api.types.is_string_dtype(dtype):
        return True
    # is_string_dtype may not report arrow-backed string dtypes in every pandas
    # version; those still expose a unicode ``kind``.
    return getattr(dtype, "kind", None) == "U"


def summary(series, model: TextLLM, index=None):
    """
    Generate summaries for text content in a series using a language model.

    Parameters
    ----------
    series : Series
        A maxframe Series containing text data to be summarized.
        Each element should be a text string.
    model : TextLLM
        Language model instance used for text summarization.
    index : array-like, optional
        Index for the output, by default None, will generate new index.

    Returns
    -------
    maxframe DataFrame or Series
        Output of :meth:`TextLLM.summarize`, carrying a ``summary`` (object)
        and a ``success`` (bool) field as declared by
        ``TextLLMSummarizeOperator.get_output_dtypes``.

    Raises
    ------
    ValueError
        If ``series`` is not a maxframe Series or its dtype cannot hold text.

    Notes
    -----
    **Preview:** This API is in preview state and may be unstable.
    The interface may change in future releases.
    """
    if not isinstance(series, Series):
        raise ValueError("series must be a maxframe series object")

    if not _is_text_dtype(series.dtype):
        raise ValueError("summary input must be a string series")

    return model.summarize(series, index=index)


def translate(
    series, model: TextLLM, source_language: str, target_language: str, index=None
):
    """
    Translate text content in a series from source language to target language.

    Parameters
    ----------
    series : Series
        A maxframe Series containing text data to translate.
        Each element should be a text string.
    model : TextLLM
        Language model instance used for text translation.
    source_language : str
        Source language of the text.
    target_language : str
        Target language of the text.
    index : array-like, optional
        Index for the output, by default None, will generate new index.

    Returns
    -------
    maxframe DataFrame or Series
        Output of :meth:`TextLLM.translate`, carrying a ``target`` (object)
        and a ``success`` (bool) field as declared by
        ``TextLLMTranslateOperator.get_output_dtypes``.

    Raises
    ------
    ValueError
        If ``series`` is not a maxframe Series or its dtype cannot hold text.

    Notes
    -----
    **Preview:** This API is in preview state and may be unstable.
    The interface may change in future releases.
    """
    if not isinstance(series, Series):
        raise ValueError("series must be a maxframe series object")
    if not _is_text_dtype(series.dtype):
        raise ValueError("translate input must be a string series")
    return model.translate(
        series,
        source_language=source_language,
        target_language=target_language,
        index=index,
    )


def classify(
    series,
    model: TextLLM,
    labels: List[str],
    description: str = None,
    examples: List[Dict[str, str]] = None,
    index=None,
):
    """
    Classify text content in a series with given labels.

    Parameters
    ----------
    series : Series
        A maxframe Series containing text data to be classified.
        Each element should be a text string.
    model : TextLLM
        Language model instance used for text classification.
    labels : List[str]
        Non-empty list of labels to classify the text into.
    description : str, optional
        Description of the classification task.
    examples : List[Dict[str, str]], optional
        Examples of the classification task, like
        ``[{"text": "text...", "label": "A", "reason": "reason..."}]``, which
        help the LLM better understand your rules.
    index : array-like, optional
        Index for the output, by default None, will generate new index.

    Returns
    -------
    maxframe DataFrame or Series
        Output of :meth:`TextLLM.classify`, carrying ``label`` and ``reason``
        (object) and ``success`` (bool) fields as declared by
        ``TextLLMClassifyOperator.get_output_dtypes``.

    Raises
    ------
    ValueError
        If ``series`` is not a maxframe Series, its dtype cannot hold text,
        or ``labels`` is empty.
    TypeError
        If ``labels`` is not a list.

    Notes
    -----
    **Preview:** This API is in preview state and may be unstable.
    The interface may change in future releases.
    """
    if not isinstance(series, Series):
        raise ValueError("series must be a maxframe series object")

    if not _is_text_dtype(series.dtype):
        raise ValueError("classify input must be a string series")

    if not isinstance(labels, list):
        raise TypeError("labels must be a list")

    if not labels:
        raise ValueError("labels must not be empty")

    return model.classify(
        series, labels=labels, description=description, examples=examples, index=index
    )
@@ -0,0 +1,77 @@
1
+ # Copyright 1999-2025 Alibaba Group Holding Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ... import opcodes
16
+ from ...core import ENTITY_TYPE, OutputType
17
+ from ...core.operator import ObjectOperator, ObjectOperatorMixin
18
+ from ...serialization.serializables import (
19
+ AnyField,
20
+ DictField,
21
+ FunctionField,
22
+ TupleField,
23
+ )
24
+ from ...utils import find_objects, replace_objects
25
+
26
+
class ModelDataSource(ObjectOperator, ObjectOperatorMixin):
    """
    Source operator holding a model object in ``data``.

    Calling it yields an object-typed tileable whose object class is the given
    ``model_cls``.
    """

    _op_type_ = opcodes.MODEL_DATA_SOURCE

    # The wrapped model object (any serializable value).
    data = AnyField("data")

    def __call__(self, model_cls):
        self._output_types = [OutputType.object]
        # No upstream tileables are passed to this source (``None`` inputs).
        tileable = self.new_tileable(None, object_class=model_cls)
        return tileable
35
+
36
+
class ModelApplyChunk(ObjectOperator, ObjectOperatorMixin):
    """
    Operator carrying ``func`` together with ``args`` / ``kwargs``.

    Its first input is the tileable passed as ``t`` to ``__call__``
    (presumably a model object). Any entities nested inside ``args`` /
    ``kwargs`` become additional inputs that follow it.
    """

    _op_module_ = "maxframe.learn.contrib.models"
    _op_type_ = opcodes.APPLY_CHUNK

    func = FunctionField("func")
    args = TupleField("args")
    kwargs = DictField("kwargs")

    def __init__(self, output_types=None, **kwargs):
        # Accept either a single output type or a sequence of them.
        # NOTE(review): output_types=None becomes [None]; presumably callers
        # always pass explicit output types - confirm.
        if isinstance(output_types, (tuple, list)):
            normalized = output_types
        else:
            normalized = [output_types]
        self._output_types = list(normalized)
        super().__init__(**kwargs)

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        # Entities embedded in args/kwargs occupy input slots 1..n in the same
        # order __call__ collected them; point them at the current inputs.
        embedded = find_objects(self.args, ENTITY_TYPE) + find_objects(
            self.kwargs, ENTITY_TYPE
        )
        replacement = dict(zip(embedded, self._inputs[1:]))
        self.args = replace_objects(self.args, replacement)
        self.kwargs = replace_objects(self.kwargs, replacement)

    @property
    def output_limit(self) -> int:
        # One output per declared output type.
        return len(self._output_types)

    def __call__(self, t, output_kws, args=None, **kwargs):
        self.args = args or ()
        self.kwargs = kwargs
        embedded = find_objects(self.args, ENTITY_TYPE) + find_objects(
            self.kwargs, ENTITY_TYPE
        )
        return self.new_tileables([t] + embedded, kws=output_kws)
73
+
74
+
def to_remote_model(model, model_cls):
    """
    Wrap ``model`` into an object tileable whose object class is ``model_cls``.

    This builds a :class:`ModelDataSource` holding ``model`` and invokes it.
    """
    return ModelDataSource(data=model)(model_cls)
@@ -11,6 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+
14
15
  import sys
15
16
 
16
17
 
@@ -12,11 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from ..utils import config_mod_getattr as _config_mod_getattr
15
+ from .core import Booster
16
16
  from .dmatrix import DMatrix
17
17
  from .predict import predict
18
18
  from .train import train
19
19
 
20
+ # isort: off
21
+ from ..utils import config_mod_getattr as _config_mod_getattr
22
+
20
23
  _config_mod_getattr(
21
24
  {
22
25
  "XGBClassifier": ".classifier.XGBClassifier",
@@ -24,3 +27,7 @@ _config_mod_getattr(
24
27
  },
25
28
  globals(),
26
29
  )
30
+
31
+ del _config_mod_getattr
32
+
33
+ __all__ = ["Booster", "DMatrix", "XGBClassifier", "XGBRegressor", "predict", "train"]
@@ -12,9 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from typing import Union
16
+
15
17
  import numpy as np
16
18
 
17
- from ....tensor import argmax, transpose
19
+ from .... import tensor as mt
18
20
  from ....tensor.merge.vstack import _vstack
19
21
  from ..utils import make_import_error_func
20
22
  from .core import XGBScikitLearnBase, xgboost
@@ -33,6 +35,14 @@ else:
33
35
  Implementation of the scikit-learn API for XGBoost classification.
34
36
  """
35
37
 
38
+ def __init__(
39
+ self,
40
+ xgb_model: Union[xgboost.XGBClassifier, xgboost.Booster] = None,
41
+ **kwargs,
42
+ ):
43
+ super().__init__(**kwargs)
44
+ self._set_model(xgb_model)
45
+
36
46
  def fit(
37
47
  self,
38
48
  X,
@@ -46,7 +56,7 @@ else:
46
56
  **kw,
47
57
  ):
48
58
  session = kw.pop("session", None)
49
- run_kwargs = kw.pop("run_kwargs", dict())
59
+ run_kwargs = kw.pop("run_kwargs", None) or dict()
50
60
  dtrain, evals = wrap_evaluation_matrices(
51
61
  None,
52
62
  X,
@@ -58,6 +68,7 @@ else:
58
68
  base_margin_eval_set,
59
69
  )
60
70
  params = self.get_xgb_params()
71
+ self._n_features_in = X.shape[1]
61
72
  self.n_classes_ = num_class or 1
62
73
  if self.n_classes_ > 2:
63
74
  params["objective"] = "multi:softprob"
@@ -81,7 +92,7 @@ else:
81
92
  def predict(self, data, **kw):
82
93
  prob = self.predict_proba(data, flag=True, **kw)
83
94
  if prob.ndim > 1:
84
- prediction = argmax(prob, axis=1)
95
+ prediction = mt.argmax(prob, axis=1)
85
96
  else:
86
97
  prediction = (prob > 0.5).astype(np.int64)
87
98
  return prediction
@@ -103,7 +114,7 @@ else:
103
114
  # binary logistic function
104
115
  classone_probs = prediction
105
116
  classzero_probs = 1.0 - classone_probs
106
- return transpose(_vstack((classzero_probs, classone_probs)))
117
+ return mt.transpose(_vstack((classzero_probs, classone_probs)))
107
118
 
108
119
  @property
109
120
  def classes_(self) -> np.ndarray: