maxframe 2.2.0__cp310-cp310-macosx_10_9_universal2.whl → 2.3.0rc1__cp310-cp310-macosx_10_9_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of maxframe might be problematic. Click here for more details.
- maxframe/_utils.cpython-310-darwin.so +0 -0
- maxframe/codegen/core.py +3 -2
- maxframe/codegen/spe/dataframe/merge.py +4 -0
- maxframe/codegen/spe/dataframe/misc.py +2 -0
- maxframe/codegen/spe/dataframe/reduction.py +18 -0
- maxframe/codegen/spe/dataframe/sort.py +9 -1
- maxframe/codegen/spe/dataframe/tests/test_reduction.py +13 -0
- maxframe/codegen/spe/dataframe/tseries.py +9 -0
- maxframe/codegen/spe/learn/contrib/lightgbm.py +4 -3
- maxframe/codegen/spe/tensor/datasource.py +1 -0
- maxframe/config/config.py +3 -0
- maxframe/conftest.py +10 -0
- maxframe/core/base.py +2 -1
- maxframe/core/entity/tileables.py +2 -0
- maxframe/core/graph/core.cpython-310-darwin.so +0 -0
- maxframe/core/graph/entity.py +7 -1
- maxframe/core/mode.py +6 -1
- maxframe/dataframe/__init__.py +2 -2
- maxframe/dataframe/arithmetic/__init__.py +4 -0
- maxframe/dataframe/arithmetic/maximum.py +33 -0
- maxframe/dataframe/arithmetic/minimum.py +33 -0
- maxframe/dataframe/core.py +98 -106
- maxframe/dataframe/datasource/core.py +6 -0
- maxframe/dataframe/datasource/direct.py +57 -0
- maxframe/dataframe/datasource/read_csv.py +19 -11
- maxframe/dataframe/datasource/read_odps_query.py +29 -6
- maxframe/dataframe/datasource/read_odps_table.py +32 -10
- maxframe/dataframe/datasource/read_parquet.py +38 -39
- maxframe/dataframe/datastore/__init__.py +6 -0
- maxframe/dataframe/datastore/direct.py +268 -0
- maxframe/dataframe/datastore/to_odps.py +6 -0
- maxframe/dataframe/extensions/flatjson.py +2 -1
- maxframe/dataframe/groupby/__init__.py +5 -1
- maxframe/dataframe/groupby/aggregation.py +10 -6
- maxframe/dataframe/groupby/apply_chunk.py +1 -3
- maxframe/dataframe/groupby/core.py +20 -4
- maxframe/dataframe/indexing/__init__.py +2 -1
- maxframe/dataframe/indexing/insert.py +45 -17
- maxframe/dataframe/merge/__init__.py +3 -0
- maxframe/dataframe/merge/combine.py +244 -0
- maxframe/dataframe/misc/__init__.py +14 -3
- maxframe/dataframe/misc/check_unique.py +41 -10
- maxframe/dataframe/misc/drop.py +31 -0
- maxframe/dataframe/misc/infer_dtypes.py +251 -0
- maxframe/dataframe/misc/map.py +31 -18
- maxframe/dataframe/misc/repeat.py +159 -0
- maxframe/dataframe/misc/tests/test_misc.py +35 -1
- maxframe/dataframe/missing/checkna.py +3 -2
- maxframe/dataframe/reduction/__init__.py +10 -5
- maxframe/dataframe/reduction/aggregation.py +6 -6
- maxframe/dataframe/reduction/argmax.py +7 -4
- maxframe/dataframe/reduction/argmin.py +7 -4
- maxframe/dataframe/reduction/core.py +18 -9
- maxframe/dataframe/reduction/mode.py +144 -0
- maxframe/dataframe/reduction/nunique.py +10 -3
- maxframe/dataframe/reduction/tests/test_reduction.py +12 -0
- maxframe/dataframe/sort/__init__.py +9 -2
- maxframe/dataframe/sort/argsort.py +7 -1
- maxframe/dataframe/sort/core.py +1 -1
- maxframe/dataframe/sort/rank.py +147 -0
- maxframe/dataframe/tseries/__init__.py +19 -0
- maxframe/dataframe/tseries/at_time.py +61 -0
- maxframe/dataframe/tseries/between_time.py +122 -0
- maxframe/dataframe/utils.py +30 -26
- maxframe/learn/contrib/llm/core.py +16 -7
- maxframe/learn/contrib/llm/deploy/__init__.py +13 -0
- maxframe/learn/contrib/llm/deploy/config.py +221 -0
- maxframe/learn/contrib/llm/deploy/core.py +247 -0
- maxframe/learn/contrib/llm/deploy/framework.py +35 -0
- maxframe/learn/contrib/llm/deploy/loader.py +360 -0
- maxframe/learn/contrib/llm/deploy/tests/__init__.py +13 -0
- maxframe/learn/contrib/llm/deploy/tests/test_register_models.py +359 -0
- maxframe/learn/contrib/llm/models/__init__.py +1 -0
- maxframe/learn/contrib/llm/models/dashscope.py +12 -6
- maxframe/learn/contrib/llm/models/managed.py +76 -11
- maxframe/learn/contrib/llm/models/openai.py +72 -0
- maxframe/learn/contrib/llm/tests/__init__.py +13 -0
- maxframe/learn/contrib/llm/tests/test_core.py +34 -0
- maxframe/learn/contrib/llm/tests/test_openai.py +187 -0
- maxframe/learn/contrib/llm/tests/test_text_gen.py +155 -0
- maxframe/learn/contrib/llm/text.py +348 -42
- maxframe/learn/contrib/models.py +4 -1
- maxframe/learn/contrib/xgboost/classifier.py +2 -0
- maxframe/learn/contrib/xgboost/core.py +31 -7
- maxframe/learn/contrib/xgboost/predict.py +4 -2
- maxframe/learn/contrib/xgboost/regressor.py +5 -0
- maxframe/learn/contrib/xgboost/train.py +2 -0
- maxframe/learn/preprocessing/_data/min_max_scaler.py +34 -23
- maxframe/learn/preprocessing/_data/standard_scaler.py +34 -25
- maxframe/learn/utils/__init__.py +1 -0
- maxframe/learn/utils/extmath.py +42 -9
- maxframe/learn/utils/odpsio.py +80 -11
- maxframe/lib/filesystem/_oss_lib/common.py +2 -0
- maxframe/lib/mmh3.cpython-310-darwin.so +0 -0
- maxframe/opcodes.py +9 -1
- maxframe/remote/core.py +4 -0
- maxframe/serialization/core.cpython-310-darwin.so +0 -0
- maxframe/serialization/tests/test_serial.py +2 -2
- maxframe/tensor/arithmetic/__init__.py +1 -1
- maxframe/tensor/arithmetic/core.py +2 -2
- maxframe/tensor/arithmetic/tests/test_arithmetic.py +0 -9
- maxframe/tensor/core.py +3 -0
- maxframe/tensor/misc/copyto.py +1 -1
- maxframe/tests/test_udf.py +61 -0
- maxframe/tests/test_utils.py +8 -5
- maxframe/udf.py +103 -7
- maxframe/utils.py +61 -8
- {maxframe-2.2.0.dist-info → maxframe-2.3.0rc1.dist-info}/METADATA +1 -2
- {maxframe-2.2.0.dist-info → maxframe-2.3.0rc1.dist-info}/RECORD +113 -90
- maxframe_client/session/task.py +8 -1
- maxframe_client/tests/test_session.py +24 -0
- maxframe/dataframe/arrays.py +0 -864
- {maxframe-2.2.0.dist-info → maxframe-2.3.0rc1.dist-info}/WHEEL +0 -0
- {maxframe-2.2.0.dist-info → maxframe-2.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# Copyright 1999-2025 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
import maxframe.dataframe as md
|
|
18
|
+
|
|
19
|
+
from ..models.openai import (
|
|
20
|
+
OpenAICompatibleLLM,
|
|
21
|
+
OpenAICompatibleTextGenOp,
|
|
22
|
+
OpenAICompatibleTextLLM,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_openai_compatible_llm_field_assignment():
    """Fields assigned on an OpenAICompatibleLLM can be read back unchanged."""
    expected = {
        "base_url": "https://api.openai.com/v1",
        "api_key": "test-key",
        "batch_size": 10,
        "batch_timeout": 300,
    }

    llm = OpenAICompatibleLLM()
    # Assign every field first, then verify them all.
    for field_name, value in expected.items():
        setattr(llm, field_name, value)

    for field_name, value in expected.items():
        assert getattr(llm, field_name) == value
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def test_openai_compatible_text_llm_initialization():
    """Constructor kwargs are stored; batching fields default to None."""
    init_kwargs = dict(
        name="gpt-3.5-turbo",
        base_url="https://api.openai.com/v1",
        api_key="test-key",
    )
    model = OpenAICompatibleTextLLM(**init_kwargs)

    for attr_name, expected_value in init_kwargs.items():
        assert getattr(model, attr_name) == expected_value

    # Fields not passed to the constructor keep their inherited defaults.
    for attr_name in ("batch_size", "batch_timeout"):
        assert getattr(model, attr_name) is None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def test_openai_compatible_text_llm_generate_method():
    """generate() builds an OpenAICompatibleTextGenOp carrying the call arguments."""
    model = OpenAICompatibleTextLLM(
        name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", api_key="test-key"
    )

    source_frame = md.DataFrame({"query": ["Hello world"]})
    template = [{"role": "user", "content": "{query}"}]
    gen_params = {"temperature": 0.7}

    output = model.generate(
        data=source_frame,
        prompt_template=template,
        simple_output=True,
        params=gen_params,
        extra_param="test",
    )

    # The returned object should be DataFrame-like.
    for attr_name in ("index_value", "dtypes"):
        assert hasattr(output, attr_name)

    # The operator behind the result must hold what was passed to generate().
    gen_op = output.op
    assert isinstance(gen_op, OpenAICompatibleTextGenOp)
    assert gen_op.model == model
    assert gen_op.prompt_template == template
    assert gen_op.simple_output is True
    assert gen_op.params == gen_params
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def test_openai_compatible_text_llm_generate_with_defaults():
    """generate() without optional arguments yields an op with default settings."""
    model = OpenAICompatibleTextLLM(
        name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", api_key="test-key"
    )

    frame = md.DataFrame({"query": ["Hello world"]})
    template = [{"role": "user", "content": "{query}"}]

    result = model.generate(data=frame, prompt_template=template)

    # The returned object should be DataFrame-like.
    assert all(hasattr(result, attr_name) for attr_name in ("index_value", "dtypes"))

    gen_op = result.op
    assert isinstance(gen_op, OpenAICompatibleTextGenOp)
    assert gen_op.model == model
    assert gen_op.prompt_template == template
    # Optional arguments that were not supplied fall back to their defaults.
    assert gen_op.simple_output is False
    assert gen_op.params is None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def test_openai_compatible_text_generation_operator_default_values():
    """A freshly constructed operator has simple_output disabled."""
    assert OpenAICompatibleTextGenOp().simple_output is False
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def test_openai_compatible_text_generation_operator_field_assignment():
    """simple_output can be switched on after the operator is constructed."""
    text_gen_op = OpenAICompatibleTextGenOp()
    setattr(text_gen_op, "simple_output", True)
    assert text_gen_op.simple_output is True
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def test_openai_compatible_text_generation_operator_output_dtypes():
    """The operator reports an object 'response' and a bool 'success' column."""
    reported = OpenAICompatibleTextGenOp().get_output_dtypes()
    expected = {"response": np.dtype("O"), "success": np.dtype("bool")}
    for column, dtype in expected.items():
        assert reported[column] == dtype
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def test_openai_compatible_text_generation_operator_with_model():
    """Keyword arguments given to the operator constructor are stored on it."""
    model = OpenAICompatibleTextLLM(
        name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", api_key="test-key"
    )

    op_kwargs = {
        "model": model,
        "prompt_template": [{"role": "user", "content": "Hello"}],
        "simple_output": True,
        "params": {"temperature": 0.5},
    }
    op = OpenAICompatibleTextGenOp(**op_kwargs)

    assert op.model == model
    assert op.prompt_template == op_kwargs["prompt_template"]
    assert op.simple_output is True
    assert op.params == op_kwargs["params"]
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def test_openai_compatible_text_llm_inheritance():
    """Test that OpenAICompatibleTextLLM properly inherits from both parent classes.

    ``hasattr`` alone cannot distinguish an inherited member from one defined
    directly on the subclass, so the class hierarchy is also checked
    explicitly with ``isinstance`` against TextGenLLM and OpenAICompatibleLLM.
    """
    # Imported locally because only this test needs the TextGenLLM base class;
    # it lives in the same ``..text`` module that test_text_gen.py imports.
    from ..text import TextGenLLM

    model = OpenAICompatibleTextLLM(
        name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", api_key="test-key"
    )

    # TextGenLLM side: text-generation interface including validate_params.
    assert isinstance(model, TextGenLLM)
    assert callable(getattr(model, "validate_params", None))

    # OpenAICompatibleLLM side: connection and batching configuration fields.
    assert isinstance(model, OpenAICompatibleLLM)
    for field_name in ("base_url", "api_key", "batch_size", "batch_timeout"):
        assert hasattr(model, field_name)

    # validate_params must accept an empty parameter mapping without raising.
    model.validate_params({})
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def test_openai_compatible_text_llm_validate_params():
    """validate_params accepts typical, empty and missing parameter sets."""
    model = OpenAICompatibleTextLLM(
        name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", api_key="test-key"
    )

    # None of these calls is expected to raise.
    candidates = ({"temperature": 0.7, "max_tokens": 100}, {}, None)
    for candidate in candidates:
        model.validate_params(candidate)
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
# Copyright 1999-2025 Alibaba Group Holding Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import mock
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pytest
|
|
18
|
+
|
|
19
|
+
import maxframe.dataframe as md
|
|
20
|
+
|
|
21
|
+
from .. import text as llm_text
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def test_generate_invalid_data_type_raises():
    """generate() raises ValueError when ``data`` is a plain int."""
    mocked_llm = mock.create_autospec(llm_text.TextGenLLM, instance=True)
    template = [{"role": "user", "content": "x"}]
    with pytest.raises(ValueError):
        llm_text.generate(123, mocked_llm, prompt_template=template)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def test_generate_invalid_model_type_raises():
    """generate() raises TypeError when the model argument is an arbitrary object."""
    frame = md.DataFrame({"query": ["x"]})
    template = [{"role": "user", "content": "x"}]
    not_a_model = object()
    with pytest.raises(TypeError):
        llm_text.generate(frame, not_a_model, prompt_template=template)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def test_generate_calls_validate_params_with_default_and_forwards():
    """Without explicit params, generate() validates and forwards an empty dict."""
    frame = md.DataFrame({"query": ["hello"]})
    llm = mock.create_autospec(llm_text.TextGenLLM, instance=True)
    marker = object()
    llm.generate.return_value = marker

    result = llm_text.generate(
        frame, llm, prompt_template=[{"role": "user", "content": "x"}]
    )
    # The model's return value is passed straight back to the caller.
    assert result is marker

    # validate_params runs exactly once, receiving {} as first positional arg.
    llm.validate_params.assert_called_once()
    assert llm.validate_params.call_args.args[0] == {}

    forwarded = llm.generate.call_args.kwargs
    assert forwarded["prompt_template"] == [{"role": "user", "content": "x"}]
    assert forwarded["params"] == {}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_summary_type_and_dtype_validation_and_forward():
    """summary() rejects a non-Series or int Series and delegates otherwise."""
    llm = mock.create_autospec(llm_text.TextGenLLM, instance=True)

    # A plain string is not accepted as input.
    with pytest.raises(ValueError):
        llm_text.summary("not_series", llm)

    # An integer Series must be rejected as well.
    int_series = md.Series(np.array([1], dtype=np.int_))
    with pytest.raises(ValueError):
        llm_text.summary(int_series, llm)

    str_series = md.Series(np.array(["a"], dtype=np.str_))
    llm.summarize.return_value = "OK_SUM"
    # NOTE(review): np.str_ is patched to str_series' dtype, presumably so the
    # module's string-dtype check accepts it — confirm against text.py.
    with mock.patch.object(llm_text.np, "str_", str_series.dtype, create=True):
        summarized = llm_text.summary(str_series, llm)

    assert summarized == "OK_SUM"
    llm.summarize.assert_called_once()
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def test_translate_type_and_dtype_validation_and_forward():
    """translate() rejects a non-Series or int Series and delegates otherwise."""
    llm = mock.create_autospec(llm_text.TextGenLLM, instance=True)
    languages = {"source_language": "en", "target_language": "zh"}

    # A plain string is not accepted as input.
    with pytest.raises(ValueError):
        llm_text.translate("not_series", llm, **languages)

    # An integer Series must be rejected as well.
    int_series = md.Series(np.array([1], dtype=np.int_))
    with pytest.raises(ValueError):
        llm_text.translate(int_series, llm, **languages)

    str_series = md.Series(np.array(["hello"], dtype=np.str_))
    llm.translate.return_value = "OK_TRANS"
    # NOTE(review): np.str_ is patched to str_series' dtype, presumably so the
    # module's string-dtype check accepts it — confirm against text.py.
    with mock.patch.object(llm_text.np, "str_", str_series.dtype, create=True):
        translated = llm_text.translate(str_series, llm, **languages)

    assert translated == "OK_TRANS"
    llm.translate.assert_called_once()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def test_classify_validation_and_forward():
    """classify() checks its series and labels, then forwards to model.classify."""
    llm = mock.create_autospec(llm_text.TextGenLLM, instance=True)

    with pytest.raises(ValueError):
        llm_text.classify("not_series", llm, labels=["A", "B"])

    int_series = md.Series(np.array([1], dtype=np.int_))
    with pytest.raises(ValueError):
        llm_text.classify(int_series, llm, labels=["A", "B"])

    text_series = md.Series(np.array(["text"], dtype=np.str_))

    def patched_str_dtype():
        # NOTE(review): presumably lets text.py's string-dtype check accept
        # text_series so that label validation is reached — confirm.
        return mock.patch.object(llm_text.np, "str_", text_series.dtype, create=True)

    # A non-list labels argument raises TypeError ...
    with pytest.raises(TypeError), patched_str_dtype():
        llm_text.classify(text_series, llm, labels="not_list")

    # ... and an empty label list raises ValueError.
    with pytest.raises(ValueError), patched_str_dtype():
        llm_text.classify(text_series, llm, labels=[])

    llm.classify.return_value = "OK_CLS"
    with patched_str_dtype():
        result = llm_text.classify(
            text_series,
            llm,
            labels=["A", "B"],
            description="desc",
            examples=[{"text": "t", "label": "A", "reason": "r"}],
        )

    assert result == "OK_CLS"
    llm.classify.assert_called_once()
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def test_extract_validation_and_forward():
    """extract() validates series, schema and examples, then delegates to the model.

    The schema/examples checks run with ``np.str_`` patched to the Series'
    dtype, just like the success path. Without the patch, the string-dtype
    check on ``s_ok`` presumably raises ``ValueError`` first. The
    ``pytest.raises(ValueError)`` assertions below would then pass without
    ever reaching schema/examples validation. The patched label checks in
    ``test_classify_validation_and_forward`` follow the same pattern.
    """
    model = mock.create_autospec(llm_text.TextGenLLM, instance=True)

    with pytest.raises(ValueError):
        llm_text.extract("not_series", model, schema={"a": "b"})

    s_wrong = md.Series(np.array([1], dtype=np.int_))
    with pytest.raises(ValueError):
        llm_text.extract(s_wrong, model, schema={"a": "b"})

    s_ok = md.Series(np.array(["text"], dtype=np.str_))

    def _str_dtype_patch():
        # A fresh patcher per use; lets s_ok pass the string-dtype check so
        # that the validation under test is actually reached.
        return mock.patch.object(llm_text.np, "str_", s_ok.dtype, create=True)

    # A missing schema is rejected.
    with pytest.raises(ValueError), _str_dtype_patch():
        llm_text.extract(s_ok, model, schema=None)

    # examples must be a list ...
    with pytest.raises(ValueError), _str_dtype_patch():
        llm_text.extract(s_ok, model, schema={"a": "b"}, examples="not_list")

    # ... of (input, output) tuples.
    with pytest.raises(ValueError), _str_dtype_patch():
        llm_text.extract(s_ok, model, schema={"a": "b"}, examples=[{"not": "tuple"}])

    model.extract.return_value = "OK_EXT"
    with _str_dtype_patch():
        ret = llm_text.extract(s_ok, model, schema={"a": "b"}, examples=[("in", "out")])
    assert ret == "OK_EXT"
    # Only the valid call reaches the model; every failing call raised earlier.
    model.extract.assert_called_once()
|