maxframe 1.2.1__cp310-cp310-macosx_10_9_universal2.whl → 1.3.1__cp310-cp310-macosx_10_9_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/_utils.cpython-310-darwin.so +0 -0
- maxframe/codegen.py +70 -21
- maxframe/config/config.py +6 -0
- maxframe/core/accessor.py +1 -0
- maxframe/core/graph/core.cpython-310-darwin.so +0 -0
- maxframe/dataframe/accessors/__init__.py +1 -1
- maxframe/dataframe/accessors/dict_/accessor.py +1 -0
- maxframe/dataframe/accessors/dict_/length.py +1 -0
- maxframe/dataframe/accessors/dict_/setitem.py +1 -0
- maxframe/dataframe/accessors/dict_/tests/test_dict_accessor.py +5 -7
- maxframe/dataframe/accessors/list_/__init__.py +37 -0
- maxframe/dataframe/accessors/list_/accessor.py +39 -0
- maxframe/dataframe/accessors/list_/getitem.py +135 -0
- maxframe/dataframe/accessors/list_/length.py +73 -0
- maxframe/dataframe/accessors/list_/tests/__init__.py +13 -0
- maxframe/dataframe/accessors/list_/tests/test_list_accessor.py +79 -0
- maxframe/dataframe/accessors/plotting/__init__.py +2 -0
- maxframe/dataframe/accessors/string_/__init__.py +1 -0
- maxframe/dataframe/datastore/to_odps.py +6 -0
- maxframe/dataframe/extensions/accessor.py +1 -0
- maxframe/dataframe/extensions/apply_chunk.py +34 -21
- maxframe/dataframe/extensions/flatmap.py +8 -1
- maxframe/dataframe/extensions/tests/test_apply_chunk.py +2 -1
- maxframe/dataframe/extensions/tests/test_extensions.py +1 -0
- maxframe/dataframe/groupby/aggregation.py +53 -1
- maxframe/dataframe/merge/concat.py +7 -4
- maxframe/dataframe/merge/merge.py +1 -0
- maxframe/dataframe/merge/tests/test_merge.py +97 -47
- maxframe/dataframe/missing/tests/test_missing.py +1 -0
- maxframe/dataframe/reduction/aggregation.py +63 -0
- maxframe/dataframe/reduction/core.py +17 -5
- maxframe/dataframe/tests/test_utils.py +7 -0
- maxframe/dataframe/ufunc/ufunc.py +1 -0
- maxframe/dataframe/utils.py +3 -0
- maxframe/io/odpsio/schema.py +1 -0
- maxframe/learn/contrib/__init__.py +2 -4
- maxframe/learn/contrib/llm/__init__.py +1 -0
- maxframe/learn/contrib/llm/core.py +31 -10
- maxframe/learn/contrib/llm/models/__init__.py +1 -0
- maxframe/learn/contrib/llm/models/dashscope.py +38 -3
- maxframe/learn/contrib/llm/models/managed.py +54 -0
- maxframe/learn/contrib/llm/multi_modal.py +93 -0
- maxframe/learn/contrib/llm/text.py +268 -8
- maxframe/learn/contrib/models.py +77 -0
- maxframe/learn/contrib/utils.py +1 -0
- maxframe/learn/contrib/xgboost/__init__.py +8 -1
- maxframe/learn/contrib/xgboost/classifier.py +15 -4
- maxframe/learn/contrib/xgboost/core.py +108 -1
- maxframe/learn/contrib/xgboost/dmatrix.py +1 -1
- maxframe/learn/contrib/xgboost/predict.py +6 -3
- maxframe/learn/contrib/xgboost/regressor.py +15 -1
- maxframe/learn/contrib/xgboost/train.py +5 -4
- maxframe/lib/dtypes_extension/__init__.py +2 -1
- maxframe/lib/dtypes_extension/dtypes.py +21 -0
- maxframe/lib/dtypes_extension/tests/test_dtypes.py +13 -3
- maxframe/lib/mmh3.cpython-310-darwin.so +0 -0
- maxframe/opcodes.py +19 -0
- maxframe/serialization/__init__.py +1 -0
- maxframe/serialization/core.cpython-310-darwin.so +0 -0
- maxframe/serialization/core.pyx +12 -1
- maxframe/serialization/numpy.py +12 -4
- maxframe/serialization/serializables/tests/test_serializable.py +13 -2
- maxframe/serialization/tests/test_serial.py +2 -0
- maxframe/tensor/merge/concatenate.py +1 -0
- maxframe/tensor/misc/unique.py +11 -10
- maxframe/tensor/reshape/reshape.py +4 -1
- maxframe/utils.py +4 -0
- {maxframe-1.2.1.dist-info → maxframe-1.3.1.dist-info}/METADATA +3 -2
- {maxframe-1.2.1.dist-info → maxframe-1.3.1.dist-info}/RECORD +73 -65
- {maxframe-1.2.1.dist-info → maxframe-1.3.1.dist-info}/WHEEL +1 -1
- maxframe_client/session/odps.py +3 -0
- maxframe_client/session/tests/test_task.py +1 -0
- {maxframe-1.2.1.dist-info → maxframe-1.3.1.dist-info}/top_level.txt +0 -0
|
@@ -11,6 +11,7 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
|
|
14
15
|
from typing import Any, Dict
|
|
15
16
|
|
|
16
17
|
from ....dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
|
|
@@ -33,6 +34,98 @@ def generate(
|
|
|
33
34
|
prompt_template: Dict[str, Any],
|
|
34
35
|
params: Dict[str, Any] = None,
|
|
35
36
|
):
|
|
37
|
+
"""
|
|
38
|
+
Generate text with multi model llm based on given data and prompt template.
|
|
39
|
+
|
|
40
|
+
Parameters
|
|
41
|
+
----------
|
|
42
|
+
data : DataFrame or Series
|
|
43
|
+
Input data used for generation. Can be maxframe DataFrame, Series that contain text to be processed.
|
|
44
|
+
model : MultiModalLLM
|
|
45
|
+
Language model instance support **MultiModal** inputs used for text generation.
|
|
46
|
+
prompt_template : List[Dict[str, List[Dict[str, str]]]]
|
|
47
|
+
List of message with column names as placeholders. Each message contains a role and content. Content is a list of dict, each dict contains a text or image, the value can reference column data from input.
|
|
48
|
+
|
|
49
|
+
Here is an example of prompt template.
|
|
50
|
+
|
|
51
|
+
.. code-block:: python
|
|
52
|
+
|
|
53
|
+
[
|
|
54
|
+
{
|
|
55
|
+
"role": "<role>", # e.g. "user" or "assistant"
|
|
56
|
+
"content": [
|
|
57
|
+
{
|
|
58
|
+
# At least one of these fields is required
|
|
59
|
+
"image": "<image_data_url>", # optional
|
|
60
|
+
"text": "<prompt_text_template>" # optional
|
|
61
|
+
},
|
|
62
|
+
...
|
|
63
|
+
]
|
|
64
|
+
}
|
|
65
|
+
]
|
|
66
|
+
|
|
67
|
+
Where:
|
|
68
|
+
|
|
69
|
+
- ``text`` can be a Python format string using column names from input data as parameters (e.g. ``"{column_name}"``)
|
|
70
|
+
- ``image`` should be a DataURL string following `RFC2397 <https://en.wikipedia.org/wiki/Data_URI_scheme>`_ standard with format.
|
|
71
|
+
|
|
72
|
+
.. code-block:: none
|
|
73
|
+
|
|
74
|
+
data:<mime_type>[;base64],<column_name>
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
params : Dict[str, Any], optional
|
|
78
|
+
Additional parameters for generation configuration, by default None.
|
|
79
|
+
Can include settings like temperature, max_tokens, etc.
|
|
80
|
+
|
|
81
|
+
Returns
|
|
82
|
+
-------
|
|
83
|
+
DataFrame
|
|
84
|
+
Generated text raw response and success status. If the success is False, the generated text will return the
|
|
85
|
+
error message.
|
|
86
|
+
|
|
87
|
+
Notes
|
|
88
|
+
-----
|
|
89
|
+
- The ``api_key_resource`` parameter should reference a text file resource in MaxCompute that contains only your DashScope API key.
|
|
90
|
+
|
|
91
|
+
- Using DashScope services requires enabling public network access for your MaxCompute project. This can be configured through the MaxCompute console by `enabling the Internet access feature <https://help.aliyun.com/zh/maxcompute/user-guide/network-connection-process>`_ for your project. Without this configuration, the API calls to DashScope will fail due to network connectivity issues.
|
|
92
|
+
|
|
93
|
+
Examples
|
|
94
|
+
--------
|
|
95
|
+
You can initialize a DashScope multi-modal model (such as qwen-vl-max) by providing a model name and an ``api_key_resource``.
|
|
96
|
+
The ``api_key_resource`` is a MaxCompute resource name that points to a text file containing a `DashScope <https://dashscope.aliyun.com/>`_ API key.
|
|
97
|
+
|
|
98
|
+
>>> from maxframe.learn.contrib.llm.models.dashscope import DashScopeMultiModalLLM
|
|
99
|
+
>>> import maxframe.dataframe as md
|
|
100
|
+
>>>
|
|
101
|
+
>>> model = DashScopeMultiModalLLM(
|
|
102
|
+
... name="qwen-vl-max",
|
|
103
|
+
... api_key_resource="<api-key-resource-name>"
|
|
104
|
+
... )
|
|
105
|
+
|
|
106
|
+
We use Data Url Schema to provide multi modal input in prompt template, here is an example to fill in the image from table.
|
|
107
|
+
|
|
108
|
+
Assuming you have a MaxCompute table with two columns: ``image_id`` (as the index) and ``encoded_image_data_base64`` (containing Base64 encoded image data),
|
|
109
|
+
you can construct a prompt message template as follows:
|
|
110
|
+
|
|
111
|
+
>>> df = md.read_odps_table("image_content", index_col="image_id")
|
|
112
|
+
|
|
113
|
+
>>> prompt_template = [
|
|
114
|
+
... {
|
|
115
|
+
... "role": "user",
|
|
116
|
+
... "content": [
|
|
117
|
+
... {
|
|
118
|
+
... "image": "data:image/png;base64,encoded_image_data_base64",
|
|
119
|
+
... },
|
|
120
|
+
... {
|
|
121
|
+
... "text": "Analyze this image in detail",
|
|
122
|
+
... },
|
|
123
|
+
... ],
|
|
124
|
+
... },
|
|
125
|
+
... ]
|
|
126
|
+
>>> result = model.generate(df, prompt_template)
|
|
127
|
+
>>> result.execute()
|
|
128
|
+
"""
|
|
36
129
|
if not isinstance(data, DATAFRAME_TYPE) and not isinstance(data, SERIES_TYPE):
|
|
37
130
|
raise ValueError("data must be a maxframe dataframe or series object")
|
|
38
131
|
if not isinstance(model, MultiModalLLM):
|
|
@@ -11,32 +11,292 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
from typing import Any, Dict
|
|
15
14
|
|
|
16
|
-
from
|
|
17
|
-
|
|
15
|
+
from typing import Any, Dict, List
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from .... import opcodes
|
|
20
|
+
from ....dataframe.core import DataFrame, Series
|
|
21
|
+
from ....serialization.serializables import FieldTypes, ListField, StringField
|
|
22
|
+
from .core import LLM, LLMTaskOperator
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TextLLMSummarizeOperator(LLMTaskOperator):
    """
    LLM task operator that summarizes each input text.

    Declares two outputs: ``summary`` (object dtype) and ``success`` (bool).
    """

    _op_type_ = opcodes.LLM_TEXT_SUMMARIZE_TASK

    def get_output_dtypes(self) -> Dict[str, np.dtype]:
        dtypes = dict()
        dtypes["summary"] = np.dtype("O")
        dtypes["success"] = np.dtype("bool")
        return dtypes
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TextLLMTranslateOperator(LLMTaskOperator):
    """
    LLM task operator that translates each input text.

    Declares two outputs: ``target`` (object dtype) and ``success`` (bool).
    """

    _op_type_ = opcodes.LLM_TEXT_TRANSLATE_TASK

    # Language hints serialized with the operator; ``TextLLM.translate``
    # passes ``source_language=None`` when the caller gives none.
    source_language = StringField("source_language")
    target_language = StringField("target_language")

    def get_output_dtypes(self) -> Dict[str, np.dtype]:
        return dict(target=np.dtype("O"), success=np.dtype("bool"))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class TextLLMClassifyOperator(LLMTaskOperator):
    """
    LLM task operator that assigns one of ``labels`` to each input text.

    Declares three outputs: ``label`` and ``reason`` (object dtype) and
    ``success`` (bool).
    """

    _op_type_ = opcodes.LLM_TEXT_CLASSIFY_TASK

    # Candidate labels, an optional task description, and optional few-shot
    # examples (a list of dicts), all serialized with the operator.
    labels = ListField("labels")
    description = StringField("description", default=None)
    examples = ListField("examples", FieldTypes.dict, default=None)

    def get_output_dtypes(self) -> Dict[str, np.dtype]:
        text_dtype = np.dtype("O")
        return {
            "label": text_dtype,
            "reason": text_dtype,
            "success": np.dtype("bool"),
        }
|
|
18
61
|
|
|
19
62
|
|
|
20
63
|
class TextLLM(LLM):
    """
    Base class for language models that operate on plain-text inputs.

    Subclasses implement :meth:`generate`. The ``summarize``, ``translate``
    and ``classify`` helpers build the matching task operator bound to this
    model and apply it to ``series`` (with optional ``index``).
    """

    def generate(
        self,
        data,
        prompt_template: List[Dict[str, str]],
        params: Dict[str, Any] = None,
    ):
        # Concrete models provide the generation logic.
        raise NotImplementedError

    def summarize(self, series, index=None, **kw):
        """Build and apply a summarize task operator for ``series``."""
        op = TextLLMSummarizeOperator(model=self, task="summarize", **kw)
        return op(series, index)

    def translate(
        self,
        series,
        target_language: str,
        source_language: str = None,
        index=None,
        **kw
    ):
        """Build and apply a translate task operator for ``series``."""
        op = TextLLMTranslateOperator(
            model=self,
            task="translate",
            source_language=source_language,
            target_language=target_language,
            **kw
        )
        return op(series, index)

    def classify(
        self,
        series,
        labels: List[str],
        description=None,
        examples=None,
        index=None,
        **kw
    ):
        """Build and apply a classify task operator for ``series``."""
        op = TextLLMClassifyOperator(
            model=self,
            labels=labels,
            task="classify",
            description=description,
            examples=examples,
            **kw
        )
        return op(series, index)
|
|
110
|
+
|
|
29
111
|
|
|
30
112
|
def generate(
    data,
    model: TextLLM,
    prompt_template: List[Dict[str, Any]],
    params: Dict[str, Any] = None,
):
    """
    Generate text using a text language model based on given data and prompt template.

    Parameters
    ----------
    data : DataFrame or Series
        Input data used for generation. Can be a maxframe DataFrame or Series
        containing the text to be processed.
    model : TextLLM
        Language model instance used for text generation.
    prompt_template : List[Dict[str, str]]
        List of conversation messages used as the template. Use ``{col_name}``
        as a placeholder to reference column data from the input data.

        Usually in the format of ``[{"role": "user", "content": "{query}"}]``,
        the same as the OpenAI chat API message schema.
    params : Dict[str, Any], optional
        Additional parameters for generation configuration, by default None.
        Can include settings like temperature, max_tokens, etc. The parameters
        are checked with ``model.validate_params`` before generation.

    Returns
    -------
    DataFrame
        Generated text raw response and success status. If the success is
        False, the generated text will contain the error message.

    Raises
    ------
    ValueError
        If ``data`` is not a maxframe DataFrame or Series.
    TypeError
        If ``model`` is not a ``TextLLM`` instance.

    Examples
    --------
    >>> from maxframe.learn.contrib.llm.models.managed import ManagedTextLLM
    >>> import maxframe.dataframe as md
    >>>
    >>> # Initialize the model
    >>> llm = ManagedTextLLM(name="Qwen2.5-0.5B-instruct")
    >>>
    >>> # Prepare prompt template
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": "Help answer following question: {query}",
    ...     },
    ... ]

    >>> # Create sample data
    >>> df = md.DataFrame({"query": ["What is machine learning?"]})
    >>>
    >>> # Generate response
    >>> result = generate(df, llm, prompt_template=messages)
    >>> result.execute()
    """
    if not isinstance(data, DataFrame) and not isinstance(data, Series):
        raise ValueError("data must be a maxframe dataframe or series object")
    if not isinstance(model, TextLLM):
        raise TypeError("model must be a TextLLM object")
    # Normalize a missing ``params`` to an empty dict before validation.
    params = params if params is not None else dict()
    model.validate_params(params)
    return model.generate(data, prompt_template=prompt_template, params=params)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def summary(series, model: TextLLM, index=None):
    """
    Generate summaries for text content in a series using a language model.

    Parameters
    ----------
    series : Series
        A maxframe Series containing text data to be summarized.
        Each element should be a text string.
    model : TextLLM
        Language model instance used for text summarization.
    index : array-like, optional
        Index for the output series, by default None, will generate new index.

    Returns
    -------
    maxframe.Series
        The generated summaries and success status, as produced by
        ``model.summarize``.

    Raises
    ------
    ValueError
        If ``series`` is not a maxframe Series or does not have a string dtype.

    Notes
    -----
    **Preview:** This API is in preview state and may be unstable.
    The interface may change in future releases.
    """
    import pandas as pd

    if not isinstance(series, Series):
        raise ValueError("series must be a maxframe series object")

    # Text columns are normally stored with ``object`` dtype or a pandas
    # string/arrow extension dtype, and all of these compare unequal to
    # ``np.str_``. Checking the dtype family accepts them, as well as the
    # numpy unicode dtypes the old equality check allowed.
    if not pd.api.types.is_string_dtype(series.dtype):
        raise ValueError("summary input must be a string series")

    return model.summarize(series, index=index)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def translate(
    series, model: TextLLM, source_language: str, target_language: str, index=None
):
    """
    Translate text content in a series using a language model from source language to target language.

    Parameters
    ----------
    series : Series
        A maxframe Series containing text data to translate.
        Each element should be a text string.
    model : TextLLM
        Language model instance used for text translation.
    source_language : str
        Source language of the text.
    target_language : str
        Target language of the text.
    index : array-like, optional
        Index for the output series, by default None, will generate new index.

    Returns
    -------
    maxframe.Series
        The generated translations and success status, as produced by
        ``model.translate``.

    Raises
    ------
    ValueError
        If ``series`` is not a maxframe Series or does not have a string dtype.

    Notes
    -----
    **Preview:** This API is in preview state and may be unstable.
    The interface may change in future releases.
    """
    import pandas as pd

    if not isinstance(series, Series):
        raise ValueError("series must be a maxframe series object")
    # Text columns are normally stored with ``object`` dtype or a pandas
    # string/arrow extension dtype, and all of these compare unequal to
    # ``np.str_``. Checking the dtype family accepts them.
    if not pd.api.types.is_string_dtype(series.dtype):
        raise ValueError("translate input must be a string series")
    return model.translate(
        series,
        source_language=source_language,
        target_language=target_language,
        index=index,
    )
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def classify(
    series,
    model: TextLLM,
    labels: List[str],
    description: str = None,
    examples: List[Dict[str, str]] = None,
    index=None,
):
    """
    Classify text content in a series with given labels.

    Parameters
    ----------
    series : Series
        A maxframe Series containing text data to be classified.
        Each element should be a text string.
    model : TextLLM
        Language model instance used for text classification.
    labels : List[str]
        Non-empty list of labels to classify the text.
    description : str, optional
        Description of the classification task.
    examples : List[Dict[str, str]], optional
        Examples of the classification task that help the LLM better understand
        your rules, e.g.
        ``[{"text": "text...", "label": "A", "reason": "reason..."}]``.
    index : array-like, optional
        Index for the output series, by default None, will generate new index.

    Returns
    -------
    maxframe.Series
        The generated classification results and success status, as produced
        by ``model.classify``.

    Raises
    ------
    ValueError
        If ``series`` is not a maxframe Series, does not have a string dtype,
        or ``labels`` is empty.
    TypeError
        If ``labels`` is not a list.

    Notes
    -----
    **Preview:** This API is in preview state and may be unstable.
    The interface may change in future releases.
    """
    import pandas as pd

    if not isinstance(series, Series):
        raise ValueError("series must be a maxframe series object")

    # Text columns are normally stored with ``object`` dtype or a pandas
    # string/arrow extension dtype, and all of these compare unequal to
    # ``np.str_``. Checking the dtype family accepts them.
    if not pd.api.types.is_string_dtype(series.dtype):
        raise ValueError("classify input must be a string series")

    if not isinstance(labels, list):
        raise TypeError("labels must be a list")

    if not labels:
        raise ValueError("labels must not be empty")

    return model.classify(
        series, labels=labels, description=description, examples=examples, index=index
    )
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright 1999-2025 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ... import opcodes
|
|
16
|
+
from ...core import ENTITY_TYPE, OutputType
|
|
17
|
+
from ...core.operator import ObjectOperator, ObjectOperatorMixin
|
|
18
|
+
from ...serialization.serializables import (
|
|
19
|
+
AnyField,
|
|
20
|
+
DictField,
|
|
21
|
+
FunctionField,
|
|
22
|
+
TupleField,
|
|
23
|
+
)
|
|
24
|
+
from ...utils import find_objects, replace_objects
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ModelDataSource(ObjectOperator, ObjectOperatorMixin):
    """
    Source operator that carries a local model object as ``data`` and
    exposes it as an object tileable.
    """

    _op_type_ = opcodes.MODEL_DATA_SOURCE

    # The model payload held by this operator.
    data = AnyField("data")

    def __call__(self, model_cls):
        # No upstream inputs; a single object output of class ``model_cls``.
        object_output = [OutputType.object]
        self._output_types = object_output
        return self.new_tileable(None, object_class=model_cls)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ModelApplyChunk(ObjectOperator, ObjectOperatorMixin):
    """
    Operator applying ``func`` to an input tileable with extra ``args`` and
    ``kwargs``.

    Any entities (``ENTITY_TYPE``) found inside ``args``/``kwargs`` become
    additional operator inputs placed after the primary input. When inputs
    are re-set, those entities are swapped for the new input objects.
    """

    _op_module_ = "maxframe.learn.contrib.models"
    _op_type_ = opcodes.APPLY_CHUNK

    func = FunctionField("func")
    args = TupleField("args")
    kwargs = DictField("kwargs")

    def __init__(self, output_types=None, **kwargs):
        # A single output type (including the default None) is wrapped into a
        # one-element list; tuples/lists are copied into a new list.
        if isinstance(output_types, (tuple, list)):
            normalized = list(output_types)
        else:
            normalized = [output_types]
        self._output_types = normalized
        super().__init__(**kwargs)

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        # Entities in args/kwargs occupy input slots 1.. in discovery order;
        # map each stale reference onto its replacement input.
        stale = find_objects(self.args, ENTITY_TYPE) + find_objects(
            self.kwargs, ENTITY_TYPE
        )
        replacements = dict(zip(stale, self._inputs[1:]))
        self.args = replace_objects(self.args, replacements)
        self.kwargs = replace_objects(self.kwargs, replacements)

    @property
    def output_limit(self) -> int:
        return len(self._output_types)

    def __call__(self, t, output_kws, args=None, **kwargs):
        self.args = args or ()
        self.kwargs = kwargs
        # Primary input first, followed by entities referenced in args/kwargs.
        extra_inputs = find_objects(self.args, ENTITY_TYPE) + find_objects(
            self.kwargs, ENTITY_TYPE
        )
        return self.new_tileables([t] + extra_inputs, kws=output_kws)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def to_remote_model(model, model_cls):
    """
    Wrap a local ``model`` into an object tileable of class ``model_cls``
    via a ``ModelDataSource`` operator.
    """
    return ModelDataSource(data=model)(model_cls)
|
maxframe/learn/contrib/utils.py
CHANGED
|
@@ -12,11 +12,14 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from
|
|
15
|
+
from .core import Booster
|
|
16
16
|
from .dmatrix import DMatrix
|
|
17
17
|
from .predict import predict
|
|
18
18
|
from .train import train
|
|
19
19
|
|
|
20
|
+
# isort: off
|
|
21
|
+
from ..utils import config_mod_getattr as _config_mod_getattr
|
|
22
|
+
|
|
20
23
|
_config_mod_getattr(
|
|
21
24
|
{
|
|
22
25
|
"XGBClassifier": ".classifier.XGBClassifier",
|
|
@@ -24,3 +27,7 @@ _config_mod_getattr(
|
|
|
24
27
|
},
|
|
25
28
|
globals(),
|
|
26
29
|
)
|
|
30
|
+
|
|
31
|
+
del _config_mod_getattr
|
|
32
|
+
|
|
33
|
+
__all__ = ["Booster", "DMatrix", "XGBClassifier", "XGBRegressor", "predict", "train"]
|
|
@@ -12,9 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
+
from typing import Union
|
|
16
|
+
|
|
15
17
|
import numpy as np
|
|
16
18
|
|
|
17
|
-
from ....
|
|
19
|
+
from .... import tensor as mt
|
|
18
20
|
from ....tensor.merge.vstack import _vstack
|
|
19
21
|
from ..utils import make_import_error_func
|
|
20
22
|
from .core import XGBScikitLearnBase, xgboost
|
|
@@ -33,6 +35,14 @@ else:
|
|
|
33
35
|
Implementation of the scikit-learn API for XGBoost classification.
|
|
34
36
|
"""
|
|
35
37
|
|
|
38
|
+
def __init__(
    self,
    xgb_model: Union[xgboost.XGBClassifier, xgboost.Booster] = None,
    **kwargs,
):
    """
    Create the classifier, optionally wrapping an existing xgboost model.

    Parameters
    ----------
    xgb_model : xgboost.XGBClassifier or xgboost.Booster, optional
        Model to attach through ``_set_model`` after base initialization.
    **kwargs
        Forwarded to the base class constructor.
    """
    super().__init__(**kwargs)
    # Attach the given model (or None) only after the base class is set up.
    self._set_model(xgb_model)
|
|
45
|
+
|
|
36
46
|
def fit(
|
|
37
47
|
self,
|
|
38
48
|
X,
|
|
@@ -46,7 +56,7 @@ else:
|
|
|
46
56
|
**kw,
|
|
47
57
|
):
|
|
48
58
|
session = kw.pop("session", None)
|
|
49
|
-
run_kwargs = kw.pop("run_kwargs", dict()
|
|
59
|
+
run_kwargs = kw.pop("run_kwargs", None) or dict()
|
|
50
60
|
dtrain, evals = wrap_evaluation_matrices(
|
|
51
61
|
None,
|
|
52
62
|
X,
|
|
@@ -58,6 +68,7 @@ else:
|
|
|
58
68
|
base_margin_eval_set,
|
|
59
69
|
)
|
|
60
70
|
params = self.get_xgb_params()
|
|
71
|
+
self._n_features_in = X.shape[1]
|
|
61
72
|
self.n_classes_ = num_class or 1
|
|
62
73
|
if self.n_classes_ > 2:
|
|
63
74
|
params["objective"] = "multi:softprob"
|
|
@@ -81,7 +92,7 @@ else:
|
|
|
81
92
|
def predict(self, data, **kw):
    """
    Predict class labels for ``data``.

    Probabilities come from ``predict_proba(data, flag=True, **kw)``. A
    multi-dimensional result is reduced with ``mt.argmax`` along axis 1. A
    1-D result is thresholded at 0.5 and cast to ``int64``.
    """
    prob = self.predict_proba(data, flag=True, **kw)
    if prob.ndim <= 1:
        return (prob > 0.5).astype(np.int64)
    return mt.argmax(prob, axis=1)
|
|
@@ -103,7 +114,7 @@ else:
|
|
|
103
114
|
# binary logistic function
|
|
104
115
|
classone_probs = prediction
|
|
105
116
|
classzero_probs = 1.0 - classone_probs
|
|
106
|
-
return transpose(_vstack((classzero_probs, classone_probs)))
|
|
117
|
+
return mt.transpose(_vstack((classzero_probs, classone_probs)))
|
|
107
118
|
|
|
108
119
|
@property
|
|
109
120
|
def classes_(self) -> np.ndarray:
|