maxframe 1.2.1__cp311-cp311-win32.whl → 1.3.0__cp311-cp311-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of maxframe might be problematic.
Files changed (70)
  1. maxframe/_utils.cp311-win32.pyd +0 -0
  2. maxframe/codegen.py +70 -21
  3. maxframe/config/config.py +6 -0
  4. maxframe/core/accessor.py +1 -0
  5. maxframe/core/graph/core.cp311-win32.pyd +0 -0
  6. maxframe/dataframe/accessors/__init__.py +1 -1
  7. maxframe/dataframe/accessors/dict_/accessor.py +1 -0
  8. maxframe/dataframe/accessors/dict_/length.py +1 -0
  9. maxframe/dataframe/accessors/dict_/setitem.py +1 -0
  10. maxframe/dataframe/accessors/dict_/tests/test_dict_accessor.py +5 -7
  11. maxframe/dataframe/accessors/list_/__init__.py +37 -0
  12. maxframe/dataframe/accessors/list_/accessor.py +39 -0
  13. maxframe/dataframe/accessors/list_/getitem.py +135 -0
  14. maxframe/dataframe/accessors/list_/length.py +73 -0
  15. maxframe/dataframe/accessors/list_/tests/__init__.py +13 -0
  16. maxframe/dataframe/accessors/list_/tests/test_list_accessor.py +79 -0
  17. maxframe/dataframe/accessors/plotting/__init__.py +2 -0
  18. maxframe/dataframe/accessors/string_/__init__.py +1 -0
  19. maxframe/dataframe/datastore/to_odps.py +6 -0
  20. maxframe/dataframe/extensions/accessor.py +1 -0
  21. maxframe/dataframe/extensions/apply_chunk.py +34 -21
  22. maxframe/dataframe/extensions/flatmap.py +8 -1
  23. maxframe/dataframe/extensions/tests/test_apply_chunk.py +2 -1
  24. maxframe/dataframe/extensions/tests/test_extensions.py +1 -0
  25. maxframe/dataframe/merge/concat.py +7 -4
  26. maxframe/dataframe/merge/merge.py +1 -0
  27. maxframe/dataframe/merge/tests/test_merge.py +97 -47
  28. maxframe/dataframe/missing/tests/test_missing.py +1 -0
  29. maxframe/dataframe/tests/test_utils.py +7 -0
  30. maxframe/dataframe/ufunc/ufunc.py +1 -0
  31. maxframe/dataframe/utils.py +3 -0
  32. maxframe/io/odpsio/schema.py +1 -0
  33. maxframe/learn/contrib/__init__.py +2 -4
  34. maxframe/learn/contrib/llm/__init__.py +1 -0
  35. maxframe/learn/contrib/llm/core.py +31 -10
  36. maxframe/learn/contrib/llm/models/__init__.py +1 -0
  37. maxframe/learn/contrib/llm/models/dashscope.py +4 -3
  38. maxframe/learn/contrib/llm/models/managed.py +39 -0
  39. maxframe/learn/contrib/llm/multi_modal.py +1 -0
  40. maxframe/learn/contrib/llm/text.py +252 -8
  41. maxframe/learn/contrib/models.py +77 -0
  42. maxframe/learn/contrib/utils.py +1 -0
  43. maxframe/learn/contrib/xgboost/__init__.py +8 -1
  44. maxframe/learn/contrib/xgboost/classifier.py +15 -4
  45. maxframe/learn/contrib/xgboost/core.py +108 -1
  46. maxframe/learn/contrib/xgboost/dmatrix.py +1 -1
  47. maxframe/learn/contrib/xgboost/predict.py +8 -3
  48. maxframe/learn/contrib/xgboost/regressor.py +15 -1
  49. maxframe/learn/contrib/xgboost/train.py +5 -4
  50. maxframe/lib/dtypes_extension/__init__.py +2 -1
  51. maxframe/lib/dtypes_extension/dtypes.py +21 -0
  52. maxframe/lib/dtypes_extension/tests/test_dtypes.py +13 -3
  53. maxframe/lib/mmh3.cp311-win32.pyd +0 -0
  54. maxframe/opcodes.py +19 -0
  55. maxframe/serialization/__init__.py +1 -0
  56. maxframe/serialization/core.cp311-win32.pyd +0 -0
  57. maxframe/serialization/core.pyx +12 -1
  58. maxframe/serialization/numpy.py +12 -4
  59. maxframe/serialization/serializables/tests/test_serializable.py +13 -2
  60. maxframe/serialization/tests/test_serial.py +2 -0
  61. maxframe/tensor/merge/concatenate.py +1 -0
  62. maxframe/tensor/misc/unique.py +11 -10
  63. maxframe/tensor/reshape/reshape.py +4 -1
  64. maxframe/utils.py +4 -0
  65. {maxframe-1.2.1.dist-info → maxframe-1.3.0.dist-info}/METADATA +2 -2
  66. {maxframe-1.2.1.dist-info → maxframe-1.3.0.dist-info}/RECORD +70 -62
  67. {maxframe-1.2.1.dist-info → maxframe-1.3.0.dist-info}/WHEEL +1 -1
  68. maxframe_client/session/odps.py +3 -0
  69. maxframe_client/session/tests/test_task.py +1 -0
  70. {maxframe-1.2.1.dist-info → maxframe-1.3.0.dist-info}/top_level.txt +0 -0
maxframe/io/odpsio/schema.py
@@ -28,6 +28,7 @@ from ...protocol import DataFrameTableMeta
 from ...tensor.core import TENSOR_TYPE
 
 _TEMP_TABLE_PREFIX = "tmp_mf_"
+DEFAULT_SINGLE_INDEX_NAME = "_idx_0"
 
 _arrow_to_odps_types = {
     pa.string(): odps_types.string,
maxframe/learn/contrib/__init__.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import graph, llm, pytorch
+from . import graph, llm, models, pytorch
 
-del graph
-del llm
-del pytorch
+del graph, llm, models, pytorch
maxframe/learn/contrib/llm/__init__.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from . import models, multi_modal, text
 
 del models
maxframe/learn/contrib/llm/core.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import Any, Dict
 
 import numpy as np
@@ -19,6 +20,8 @@ import pandas as pd
 from ....core.entity.output_types import OutputType
 from ....core.operator.base import Operator
 from ....core.operator.core import TileableOperatorMixin
+from ....dataframe.core import SERIES_TYPE
+from ....dataframe.operators import DataFrameOperatorMixin
 from ....dataframe.utils import parse_index
 from ....serialization.serializables.core import Serializable
 from ....serialization.serializables.field import AnyField, DictField, StringField
@@ -31,24 +34,42 @@ class LLM(Serializable):
     pass
 
 
-class LLMOperator(Operator, TileableOperatorMixin):
+class LLMTaskOperator(Operator, DataFrameOperatorMixin):
+    task = AnyField("task", default=None)
     model = AnyField("model", default=None)
-    prompt_template = AnyField("prompt_template", default=None)
     params = DictField("params", default=None)
+    running_options: Dict[str, Any] = DictField("running_options", default=None)
 
     def __init__(self, output_types=None, **kw):
         if output_types is None:
             output_types = [OutputType.dataframe]
         super().__init__(_output_types=output_types, **kw)
 
-    def __call__(self, data):
-        col_names = ["response", "success"]
-        columns = parse_index(pd.Index(col_names), store_data=True)
-        out_dtypes = pd.Series([np.dtype("O"), np.dtype("bool")], index=col_names)
-        return self.new_tileable(
+    def get_output_dtypes(self) -> Dict[str, np.dtype]:
+        raise NotImplementedError
+
+    def __call__(self, data, index=None):
+        outputs = self.get_output_dtypes()
+        col_name = list(outputs.keys())
+        columns = parse_index(pd.Index(col_name), store_data=True)
+        out_dtypes = pd.Series(list(outputs.values()), index=col_name)
+        index_value = index or (
+            parse_index(pd.RangeIndex(-1), data)
+            if isinstance(data, SERIES_TYPE)
+            else data.index_value
+        )
+
+        return self.new_dataframe(
             inputs=[data],
-            dtypes=out_dtypes,
-            shape=(data.shape[0], len(col_names)),
-            index_value=data.index_value,
+            shape=(np.nan, len(col_name)),
+            index_value=index_value,
             columns_value=columns,
+            dtypes=out_dtypes,
         )
+
+
+class LLMTextGenOperator(LLMTaskOperator, TileableOperatorMixin):
+    prompt_template = AnyField("prompt_template", default=None)
+
+    def get_output_dtypes(self) -> Dict[str, np.dtype]:
+        return {"response": np.dtype("O"), "success": np.dtype("bool")}
maxframe/learn/contrib/llm/models/__init__.py
@@ -11,4 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from .dashscope import DashScopeMultiModalLLM, DashScopeTextLLM
maxframe/learn/contrib/llm/models/dashscope.py
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import Any, Dict
 
 from ..... import opcodes
 from .....serialization.serializables.core import Serializable
 from .....serialization.serializables.field import StringField
-from ..core import LLMOperator
+from ..core import LLMTextGenOperator
 from ..multi_modal import MultiModalLLM
 from ..text import TextLLM
 
@@ -65,9 +66,9 @@ class DashScopeMultiModalLLM(MultiModalLLM, DashScopeLLMMixin):
         )(data)
 
 
-class DashScopeTextGenerationOperator(LLMOperator):
+class DashScopeTextGenerationOperator(LLMTextGenOperator):
     _op_type_ = opcodes.DASHSCOPE_TEXT_GENERATION
 
 
-class DashScopeMultiModalGenerationOperator(LLMOperator):
+class DashScopeMultiModalGenerationOperator(LLMTextGenOperator):
     _op_type_ = opcodes.DASHSCOPE_MULTI_MODAL_GENERATION
maxframe/learn/contrib/llm/models/managed.py (new file)
@@ -0,0 +1,39 @@
+# Copyright 1999-2025 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List
+
+from ..... import opcodes
+from .....serialization.serializables import StringField
+from ..core import LLMTextGenOperator
+from ..text import TextLLM
+
+
+class ManagedLLMTextGenOperator(LLMTextGenOperator):
+    _op_type_ = opcodes.MANAGED_TEXT_MODAL_GENERATION
+
+    inference_framework: str = StringField("inference_framework", default=None)
+
+
+class ManagedTextLLM(TextLLM):
+    def generate(
+        self,
+        data,
+        prompt_template: List[Dict[str, Any]],
+        params: Dict[str, Any] = None,
+        **kw
+    ):
+        return ManagedLLMTextGenOperator(
+            model=self, prompt_template=prompt_template, params=params, **kw
+        )(data)
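The new ManagedTextLLM simply routes generate() through ManagedLLMTextGenOperator. Combined with the docstring example added to text.py further down in this diff, usage looks roughly like the sketch below (the model name comes from that docstring; treat the rest as illustrative):

import maxframe.dataframe as md
from maxframe.learn.contrib.llm.models.managed import ManagedTextLLM

llm = ManagedTextLLM(name="Qwen2.5-1.5B")
messages = [{"role": "user", "content": "{query}"}]  # "{query}" is filled from the column

df = md.DataFrame({"query": ["What is machine learning?"]})
result = llm.generate(df, prompt_template=messages)
result.execute()  # DataFrame with "response" and "success" columns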
maxframe/learn/contrib/llm/multi_modal.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import Any, Dict
 
 from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
maxframe/learn/contrib/llm/text.py
@@ -11,32 +11,276 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict
 
-from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
-from .core import LLM
+from typing import Any, Dict, List
+
+import numpy as np
+
+from .... import opcodes
+from ....dataframe.core import DataFrame, Series
+from ....serialization.serializables import FieldTypes, ListField, StringField
+from .core import LLM, LLMTaskOperator
+
+
+class TextLLMSummarizeOperator(LLMTaskOperator):
+    _op_type_ = opcodes.LLM_TEXT_SUMMARIZE_TASK
+
+    def get_output_dtypes(self) -> Dict[str, np.dtype]:
+        return {
+            "summary": np.dtype("O"),
+            "success": np.dtype("bool"),
+        }
+
+
+class TextLLMTranslateOperator(LLMTaskOperator):
+    _op_type_ = opcodes.LLM_TEXT_TRANSLATE_TASK
+
+    source_language = StringField("source_language")
+    target_language = StringField("target_language")
+
+    def get_output_dtypes(self) -> Dict[str, np.dtype]:
+        return {
+            "target": np.dtype("O"),
+            "success": np.dtype("bool"),
+        }
+
+
+class TextLLMClassifyOperator(LLMTaskOperator):
+    _op_type_ = opcodes.LLM_TEXT_CLASSIFY_TASK
+
+    labels = ListField("labels")
+    description = StringField("description", default=None)
+    examples = ListField("examples", FieldTypes.dict, default=None)
+
+    def get_output_dtypes(self) -> Dict[str, np.dtype]:
+        return {
+            "label": np.dtype("O"),
+            "reason": np.dtype("O"),
+            "success": np.dtype("bool"),
+        }
 
 
 class TextLLM(LLM):
     def generate(
         self,
         data,
-        prompt_template: Dict[str, Any],
+        prompt_template: List[Dict[str, str]],
         params: Dict[str, Any] = None,
     ):
         raise NotImplementedError
 
+    def summarize(self, series, index=None, **kw):
+        return TextLLMSummarizeOperator(model=self, task="summarize", **kw)(
+            series, index
+        )
+
+    def translate(
+        self,
+        series,
+        target_language: str,
+        source_language: str = None,
+        index=None,
+        **kw
+    ):
+        return TextLLMTranslateOperator(
+            model=self,
+            task="translate",
+            source_language=source_language,
+            target_language=target_language,
+            **kw
+        )(series, index)
+
+    def classify(
+        self,
+        series,
+        labels: List[str],
+        description=None,
+        examples=None,
+        index=None,
+        **kw
+    ):
+        return TextLLMClassifyOperator(
+            model=self,
+            labels=labels,
+            task="classify",
+            description=description,
+            examples=examples,
+            **kw
+        )(series, index)
+
 
 
 def generate(
     data,
     model: TextLLM,
-    prompt_template: Dict[str, Any],
+    prompt_template: List[Dict[str, Any]],
     params: Dict[str, Any] = None,
 ):
-    if not isinstance(data, DATAFRAME_TYPE) and not isinstance(data, SERIES_TYPE):
+    """
+    Generate text using a text language model based on given data and prompt template.
+
+    Parameters
+    ----------
+    data : DataFrame or Series
+        Input data used for generation. Can be maxframe DataFrame, Series that contain text to be processed.
+    model : TextLLM
+        Language model instance used for text generation.
+    prompt_template : List[Dict[str, str]]
+        Dictionary containing the conversation messages template. Use ``{col_name}`` as a placeholder to reference
+        column data from input data.
+
+        Usually in format of [{"role": "user", "content": "{query}"}], same with openai api schema.
+    params : Dict[str, Any], optional
+        Additional parameters for generation configuration, by default None.
+        Can include settings like temperature, max_tokens, etc.
+
+    Returns
+    -------
+    DataFrame
+        Generated text raw response and success status. If the success is False, the generated text will return the
+        error message.
+
+    Examples
+    --------
+    >>> from maxframe.learn.contrib.llm.models.managed import ManagedTextLLM
+    >>> import maxframe.dataframe as md
+    >>>
+    >>> # Initialize the model
+    >>> llm = ManagedTextLLM(name="Qwen2.5-1.5B")
+    >>>
+    >>> # Prepare prompt template
+    >>> messages = [
+    ...     {
+    ...         "role": "user",
+    ...         "content": "{query}",
+    ...     },
+    ... ]
+    >>>
+    >>> # Create sample data
+    >>> df = md.DataFrame({"query": ["What is machine learning?"]})
+    >>>
+    >>> # Generate response
+    >>> result = generate(df, llm, prompt_template=messages)
+    >>> result.execute()
+    """
+    if not isinstance(data, DataFrame) and not isinstance(data, Series):
         raise ValueError("data must be a maxframe dataframe or series object")
     if not isinstance(model, TextLLM):
-        raise ValueError("model must be a TextLLM object")
+        raise TypeError("model must be a TextLLM object")
     params = params if params is not None else dict()
     model.validate_params(params)
-    return model.generate(data, prompt_template, params)
+    return model.generate(data, prompt_template=prompt_template, params=params)
+
+
+def summary(series, model: TextLLM, index=None):
+    """
+    Generate summaries for text content in a series using a language model.
+
+    Parameters
+    ----------
+    series : pandas.Series
+        A maxframe Series containing text data to be summarized.
+        Each element should be a text string.
+    model : TextLLM
+        Language model instance used for text summarization.
+    index : array-like, optional
+        Index for the output series, by default None, will generate new index.
+
+    Returns
+    -------
+    maxframe.Series
+        A pandas Series containing the generated summaries and success status.
+    """
+    if not isinstance(series, Series):
+        raise ValueError("series must be a maxframe series object")
+
+    if series.dtype != np.str_:
+        raise ValueError("summary input must be a string series")
+
+    return model.summarize(series, index=index)
+
+
+def translate(
+    series, model: TextLLM, source_language: str, target_language: str, index=None
+):
+    """
+    Translate text content in a series using a language model from source language to target language.
+
+    Parameters
+    ----------
+    series : pandas.Series
+        A maxframe Series containing text data to be translate.
+        Each element should be a text string.
+    model : TextLLM
+        Language model instance used for text summarization.
+    source_language : str
+        Source language of the text.
+    target_language : str
+        Target language of the text.
+    index : array-like, optional
+        Index for the output series, by default None, will generate new index.
+
+    Returns
+    -------
+    maxframe.Series
+        A pandas Series containing the generated translation and success status.
+    """
+    if not isinstance(series, Series):
+        raise ValueError("series must be a maxframe series object")
+    if series.dtype != np.str_:
+        raise ValueError("translate input must be a string series")
+    return model.translate(
+        series,
+        source_language=source_language,
+        target_language=target_language,
+        index=index,
+    )
+
+
+def classify(
+    series,
+    model: TextLLM,
+    labels: List[str],
+    description: str = None,
+    examples: List[Dict[str, str]] = None,
+    index=None,
+):
+    """
+    Classify text content in a series with given labels.
+
+    Parameters
+    ----------
+    series : pandas.Series
+        A maxframe Series containing text data to be classified.
+        Each element should be a text string.
+    model : TextLLM
+        Language model instance used for text summarization.
+    labels : List[str]
+        List of labels to classify the text.
+    description : str
+        Description of the classification task.
+    examples : List[Dict[str, Dict[str, str]]]
+        Examples of the classification task, like [{ "text": "text...", "label":"A", reason : "reason..."}], help
+        LLM to better understand your rules.
+    index : array-like, optional
+        Index for the output series, by default None, will generate new index.
+
+    Returns
+    -------
+    maxframe.Series
+        A pandas Series containing the generated classification results and success status.
+    """
+    if not isinstance(series, Series):
+        raise ValueError("series must be a maxframe series object")
+
+    if series.dtype != np.str_:
+        raise ValueError("classify input must be a string series")
+
+    if not isinstance(labels, list):
+        raise TypeError("labels must be a list")
+
+    if not labels:
+        raise ValueError("labels must not be empty")
+
+    return model.classify(
+        series, labels=labels, description=description, examples=examples, index=index
+    )
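Besides the free-form generate() API, this file adds three task-specific helpers (summary, translate, classify) that map onto the new LLMTaskOperator subclasses and their output columns. A hedged usage sketch; the model name and inputs are illustrative, and per the dtype checks above the input series must resolve to a string dtype:

import maxframe.dataframe as md
from maxframe.learn.contrib.llm.models.managed import ManagedTextLLM
from maxframe.learn.contrib.llm.text import classify, summary, translate

llm = ManagedTextLLM(name="Qwen2.5-1.5B")
s = md.Series(["MaxFrame scales pandas-style code on MaxCompute."])

summaries = summary(s, llm)          # output columns: summary, success
translated = translate(s, llm, source_language="English",
                       target_language="Chinese")
                                     # output columns: target, success
topics = classify(s, llm, labels=["tech", "sports", "finance"])
                                     # output columns: label, reason, success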
maxframe/learn/contrib/models.py (new file)
@@ -0,0 +1,77 @@
+# Copyright 1999-2025 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ... import opcodes
+from ...core import ENTITY_TYPE, OutputType
+from ...core.operator import ObjectOperator, ObjectOperatorMixin
+from ...serialization.serializables import (
+    AnyField,
+    DictField,
+    FunctionField,
+    TupleField,
+)
+from ...utils import find_objects, replace_objects
+
+
+class ModelDataSource(ObjectOperator, ObjectOperatorMixin):
+    _op_type_ = opcodes.MODEL_DATA_SOURCE
+
+    data = AnyField("data")
+
+    def __call__(self, model_cls):
+        self._output_types = [OutputType.object]
+        return self.new_tileable(None, object_class=model_cls)
+
+
+class ModelApplyChunk(ObjectOperator, ObjectOperatorMixin):
+    _op_module_ = "maxframe.learn.contrib.models"
+    _op_type_ = opcodes.APPLY_CHUNK
+
+    func = FunctionField("func")
+    args = TupleField("args")
+    kwargs = DictField("kwargs")
+
+    def __init__(self, output_types=None, **kwargs):
+        if not isinstance(output_types, (tuple, list)):
+            output_types = [output_types]
+        self._output_types = list(output_types)
+        super().__init__(**kwargs)
+
+    def _set_inputs(self, inputs):
+        super()._set_inputs(inputs)
+        old_inputs = find_objects(self.args, ENTITY_TYPE) + find_objects(
+            self.kwargs, ENTITY_TYPE
+        )
+        mapping = {o: n for o, n in zip(old_inputs, self._inputs[1:])}
+        self.args = replace_objects(self.args, mapping)
+        self.kwargs = replace_objects(self.kwargs, mapping)
+
+    @property
+    def output_limit(self) -> int:
+        return len(self._output_types)
+
+    def __call__(self, t, output_kws, args=None, **kwargs):
+        self.args = args or ()
+        self.kwargs = kwargs
+        inputs = (
+            [t]
+            + find_objects(self.args, ENTITY_TYPE)
+            + find_objects(self.kwargs, ENTITY_TYPE)
+        )
+        return self.new_tileables(inputs, kws=output_kws)
+
+
+def to_remote_model(model, model_cls):
+    op = ModelDataSource(data=model)
+    return op(model_cls)
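models.py introduces a generic way to carry a locally trained model into the MaxFrame graph: ModelDataSource wraps the model as an object tileable, ModelApplyChunk applies a function against it, and to_remote_model is the public entry point. A rough sketch under the assumption that the Booster class newly exported from maxframe.learn.contrib.xgboost (see the __init__.py change below) is one valid model_cls; the exact wiring lives in xgboost/core.py, which is not shown in this excerpt:

import xgboost as xgb

from maxframe.learn.contrib.models import to_remote_model
from maxframe.learn.contrib.xgboost import Booster  # exported by this release

# local_model: an xgboost.Booster trained outside MaxFrame (details omitted)
local_model: xgb.Booster = ...
remote_model = to_remote_model(local_model, model_cls=Booster)
# remote_model is an object tileable that downstream operators can consume.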
maxframe/learn/contrib/utils.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import sys
 
 
maxframe/learn/contrib/xgboost/__init__.py
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils import config_mod_getattr as _config_mod_getattr
+from .core import Booster
 from .dmatrix import DMatrix
 from .predict import predict
 from .train import train
 
+# isort: off
+from ..utils import config_mod_getattr as _config_mod_getattr
+
 _config_mod_getattr(
     {
         "XGBClassifier": ".classifier.XGBClassifier",
@@ -24,3 +27,7 @@ _config_mod_getattr(
     },
     globals(),
 )
+
+del _config_mod_getattr
+
+__all__ = ["Booster", "DMatrix", "XGBClassifier", "XGBRegressor", "predict", "train"]
maxframe/learn/contrib/xgboost/classifier.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Union
+
 import numpy as np
 
-from ....tensor import argmax, transpose
+from .... import tensor as mt
 from ....tensor.merge.vstack import _vstack
 from ..utils import make_import_error_func
 from .core import XGBScikitLearnBase, xgboost
@@ -33,6 +35,14 @@ else:
         Implementation of the scikit-learn API for XGBoost classification.
         """
 
+        def __init__(
+            self,
+            xgb_model: Union[xgboost.XGBClassifier, xgboost.Booster] = None,
+            **kwargs,
+        ):
+            super().__init__(**kwargs)
+            self._set_model(xgb_model)
+
         def fit(
             self,
             X,
@@ -46,7 +56,7 @@ else:
             **kw,
         ):
             session = kw.pop("session", None)
-            run_kwargs = kw.pop("run_kwargs", dict())
+            run_kwargs = kw.pop("run_kwargs", None) or dict()
             dtrain, evals = wrap_evaluation_matrices(
                 None,
                 X,
@@ -58,6 +68,7 @@ else:
                 base_margin_eval_set,
             )
            params = self.get_xgb_params()
+            self._n_features_in = X.shape[1]
            self.n_classes_ = num_class or 1
            if self.n_classes_ > 2:
                params["objective"] = "multi:softprob"
@@ -81,7 +92,7 @@ else:
         def predict(self, data, **kw):
             prob = self.predict_proba(data, flag=True, **kw)
             if prob.ndim > 1:
-                prediction = argmax(prob, axis=1)
+                prediction = mt.argmax(prob, axis=1)
             else:
                 prediction = (prob > 0.5).astype(np.int64)
             return prediction
@@ -103,7 +114,7 @@ else:
                 # binary logistic function
                 classone_probs = prediction
                 classzero_probs = 1.0 - classone_probs
-                return transpose(_vstack((classzero_probs, classone_probs)))
+                return mt.transpose(_vstack((classzero_probs, classone_probs)))
 
         @property
         def classes_(self) -> np.ndarray:
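The new constructor means an XGBClassifier can be seeded from an existing local xgboost model instead of always being trained through fit(). A sketch, assuming _set_model (added in the enlarged core.py, which is not shown in this excerpt) stores the model for later predict() calls:

import xgboost as xgb

from maxframe.learn.contrib.xgboost import XGBClassifier

# Assumption: a classifier trained outside MaxFrame can be wrapped directly.
local_clf = xgb.XGBClassifier(n_estimators=10)
# local_clf.fit(X_local, y_local)              # trained elsewhere

clf = XGBClassifier(xgb_model=local_clf)       # new xgb_model argument in 1.3.0
# predictions = clf.predict(maxframe_df)       # maxframe_df: features as a MaxFrame DataFrame
# predictions.execute()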