kiln-ai 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +218 -304
- kiln_ai/adapters/ollama_tools.py +114 -0
- kiln_ai/adapters/provider_tools.py +295 -0
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +46 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +312 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +235 -22
- kiln_ai/datamodel/basemodel.py +30 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +29 -1
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_models.py +91 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +9 -0
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
- kiln_ai-0.7.0.dist-info/RECORD +56 -0
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/test_provider_tools.py
@@ -0,0 +1,312 @@
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from kiln_ai.adapters.ml_model_list import (
+    ModelName,
+    ModelProviderName,
+)
+from kiln_ai.adapters.ollama_tools import OllamaConnection
+from kiln_ai.adapters.provider_tools import (
+    check_provider_warnings,
+    get_model_and_provider,
+    kiln_model_provider_from,
+    provider_enabled,
+    provider_name_from_id,
+    provider_options_for_custom_model,
+    provider_warnings,
+)
+
+
+@pytest.fixture
+def mock_config():
+    with patch("kiln_ai.adapters.provider_tools.get_config_value") as mock:
+        yield mock
+
+
+def test_check_provider_warnings_no_warning(mock_config):
+    mock_config.return_value = "some_value"
+
+    # This should not raise an exception
+    check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+
+def test_check_provider_warnings_missing_key(mock_config):
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+    assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
+        exc_info.value
+    )
+
+
+def test_check_provider_warnings_unknown_provider():
+    # This should not raise an exception, as no settings are required for unknown providers
+    check_provider_warnings("unknown_provider")
+
+
+@pytest.mark.parametrize(
+    "provider_name",
+    [
+        ModelProviderName.amazon_bedrock,
+        ModelProviderName.openrouter,
+        ModelProviderName.groq,
+        ModelProviderName.openai,
+        ModelProviderName.fireworks_ai,
+    ],
+)
+def test_check_provider_warnings_all_providers(mock_config, provider_name):
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        check_provider_warnings(provider_name)
+
+    assert provider_warnings[provider_name].message in str(exc_info.value)
+
+
+def test_check_provider_warnings_partial_keys_set(mock_config):
+    def mock_get(key):
+        return "value" if key == "bedrock_access_key" else None
+
+    mock_config.side_effect = mock_get
+
+    with pytest.raises(ValueError) as exc_info:
+        check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+    assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
+        exc_info.value
+    )
+
+
+def test_provider_name_from_id_unknown_provider():
+    assert (
+        provider_name_from_id("unknown_provider")
+        == "Unknown provider: unknown_provider"
+    )
+
+
+def test_provider_name_from_id_case_sensitivity():
+    assert (
+        provider_name_from_id(ModelProviderName.amazon_bedrock.upper())
+        == "Unknown provider: AMAZON_BEDROCK"
+    )
+
+
+@pytest.mark.parametrize(
+    "provider_id, expected_name",
+    [
+        (ModelProviderName.amazon_bedrock, "Amazon Bedrock"),
+        (ModelProviderName.openrouter, "OpenRouter"),
+        (ModelProviderName.groq, "Groq"),
+        (ModelProviderName.ollama, "Ollama"),
+        (ModelProviderName.openai, "OpenAI"),
+        (ModelProviderName.fireworks_ai, "Fireworks AI"),
+    ],
+)
+def test_provider_name_from_id_parametrized(provider_id, expected_name):
+    assert provider_name_from_id(provider_id) == expected_name
+
+
+def test_get_model_and_provider_valid():
+    # Test with a known valid model and provider combination
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.ollama
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.phi_3_5
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+def test_get_model_and_provider_invalid_model():
+    # Test with an invalid model name
+    model, provider = get_model_and_provider(
+        "nonexistent_model", ModelProviderName.ollama
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_invalid_provider():
+    # Test with a valid model but invalid provider
+    model, provider = get_model_and_provider(ModelName.phi_3_5, "nonexistent_provider")
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_valid_model_wrong_provider():
+    # Test with a valid model but a provider that doesn't support it
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.amazon_bedrock
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_multiple_providers():
+    # Test with a model that has multiple providers
+    model, provider = get_model_and_provider(
+        ModelName.llama_3_1_70b, ModelProviderName.groq
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.llama_3_1_70b
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_ollama_success():
+    with patch(
+        "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+    ) as mock_get_ollama:
+        # Mock successful Ollama connection with models
+        mock_get_ollama.return_value = OllamaConnection(
+            message="Connected", supported_models=["phi3.5:latest"]
+        )
+
+        result = await provider_enabled(ModelProviderName.ollama)
+        assert result is True
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_ollama_no_models():
+    with patch(
+        "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+    ) as mock_get_ollama:
+        # Mock Ollama connection but with no models
+        mock_get_ollama.return_value = OllamaConnection(
+            message="Connected but no models",
+            supported_models=[],
+            unsupported_models=[],
+        )
+
+        result = await provider_enabled(ModelProviderName.ollama)
+        assert result is False
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_ollama_connection_error():
+    with patch(
+        "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+    ) as mock_get_ollama:
+        # Mock Ollama connection failure
+        mock_get_ollama.side_effect = Exception("Connection failed")
+
+        result = await provider_enabled(ModelProviderName.ollama)
+        assert result is False
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_openai_with_key(mock_config):
+    # Mock config to return API key
+    mock_config.return_value = "fake-api-key"
+
+    result = await provider_enabled(ModelProviderName.openai)
+    assert result is True
+    mock_config.assert_called_with("open_ai_api_key")
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_openai_without_key(mock_config):
+    # Mock config to return None for API key
+    mock_config.return_value = None
+
+    result = await provider_enabled(ModelProviderName.openai)
+    assert result is False
+    mock_config.assert_called_with("open_ai_api_key")
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_unknown_provider():
+    # Test with a provider that isn't in provider_warnings
+    result = await provider_enabled("unknown_provider")
+    assert result is False
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_model_no_provider():
+    with pytest.raises(ValueError) as exc_info:
+        await kiln_model_provider_from("custom_model")
+    assert str(exc_info.value) == "Provider name is required for custom models"
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_invalid_provider():
+    with pytest.raises(ValueError) as exc_info:
+        await kiln_model_provider_from("custom_model", "invalid_provider")
+    assert str(exc_info.value) == "Invalid provider name: invalid_provider"
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_model_valid(mock_config):
+    # Mock config to pass provider warnings check
+    mock_config.return_value = "fake-api-key"
+
+    provider = await kiln_model_provider_from("custom_model", ModelProviderName.openai)
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.supports_structured_output is False
+    assert provider.supports_data_gen is False
+    assert provider.untested_model is True
+    assert "model" in provider.provider_options
+    assert provider.provider_options["model"] == "custom_model"
+
+
+def test_provider_options_for_custom_model_basic():
+    """Test basic case with custom model name"""
+    options = provider_options_for_custom_model(
+        "custom_model_name", ModelProviderName.openai
+    )
+    assert options == {"model": "custom_model_name"}
+
+
+def test_provider_options_for_custom_model_bedrock():
+    """Test Amazon Bedrock provider options"""
+    options = provider_options_for_custom_model(
+        ModelName.llama_3_1_8b, ModelProviderName.amazon_bedrock
+    )
+    assert options == {"model": ModelName.llama_3_1_8b, "region_name": "us-west-2"}
+
+
+@pytest.mark.parametrize(
+    "provider",
+    [
+        ModelProviderName.openai,
+        ModelProviderName.ollama,
+        ModelProviderName.fireworks_ai,
+        ModelProviderName.openrouter,
+        ModelProviderName.groq,
+    ],
+)
+def test_provider_options_for_custom_model_simple_providers(provider):
+    """Test providers that just need model name"""
+
+    options = provider_options_for_custom_model(ModelName.llama_3_1_8b, provider)
+    assert options == {"model": ModelName.llama_3_1_8b}
+
+
+def test_provider_options_for_custom_model_kiln_fine_tune():
+    """Test that kiln_fine_tune raises appropriate error"""
+    with pytest.raises(ValueError) as exc_info:
+        provider_options_for_custom_model(
+            "model_name", ModelProviderName.kiln_fine_tune
+        )
+    assert (
+        str(exc_info.value)
+        == "Fine tuned models should populate provider options via another path"
+    )
+
+
+def test_provider_options_for_custom_model_invalid_enum():
+    """Test handling of invalid enum value"""
+    with pytest.raises(ValueError):
+        provider_options_for_custom_model("model_name", "invalid_enum_value")
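The new test_provider_tools.py above is the clearest picture of the provider_tools API added in 0.7.0. For orientation, here is a minimal sketch of the same calls outside of pytest, assuming API keys have already been saved to Kiln's config; the script scaffolding itself is illustrative and not part of the package.

import asyncio

from kiln_ai.adapters.ml_model_list import ModelName, ModelProviderName
from kiln_ai.adapters.provider_tools import (
    check_provider_warnings,
    get_model_and_provider,
    provider_enabled,
)


async def main() -> None:
    # Raises ValueError with a provider-specific message when the required
    # API key is missing from config (see the warning tests above).
    check_provider_warnings(ModelProviderName.openai)

    # Resolves a built-in model/provider pair; returns (None, None) when the
    # combination is unknown or unsupported.
    model, provider = get_model_and_provider(
        ModelName.phi_3_5, ModelProviderName.ollama
    )
    if provider is not None:
        print(provider.provider_options["model"])  # "phi3.5" per the tests

    # True when the provider is usable, e.g. Ollama reachable with models
    # installed, or the OpenAI key present in config.
    print(await provider_enabled(ModelProviderName.ollama))


asyncio.run(main())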
kiln_ai/adapters/test_structured_output.py
@@ -6,12 +6,12 @@ import jsonschema.exceptions
 import pytest
 
 import kiln_ai.datamodel as datamodel
+from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
-from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import (
     built_in_models,
-    ollama_online,
 )
+from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
@@ -20,23 +20,6 @@ from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema
 
 
-@pytest.mark.parametrize(
-    "model_name,provider",
-    [
-        ("llama_3_1_8b", "groq"),
-        ("mistral_nemo", "openrouter"),
-        ("llama_3_1_70b", "amazon_bedrock"),
-        ("claude_3_5_sonnet", "openrouter"),
-        ("gemini_1_5_pro", "openrouter"),
-        ("gemini_1_5_flash", "openrouter"),
-        ("gemini_1_5_flash_8b", "openrouter"),
-    ],
-)
-@pytest.mark.paid
-async def test_structured_output(tmp_path, model_name, provider):
-    await run_structured_output_test(tmp_path, model_name, provider)
-
-
 @pytest.mark.ollama
 async def test_structured_output_ollama_phi(tmp_path):
     # https://python.langchain.com/v0.2/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs
@@ -112,28 +95,27 @@ async def test_mock_unstructred_response(tmp_path):
 
 @pytest.mark.paid
 @pytest.mark.ollama
-
-
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_built_in_models_structured_output(
+    tmp_path, model_name, provider_name
+):
     for model in built_in_models:
+        if model.name != model_name:
+            continue
         if not model.supports_structured_output:
-
+            pytest.skip(
                 f"Skipping {model.name} because it does not support structured output"
             )
-            continue
         for provider in model.providers:
+            if provider.name != provider_name:
+                continue
             if not provider.supports_structured_output:
-
+                pytest.skip(
                     f"Skipping {model.name} {provider.name} because it does not support structured output"
                 )
-
-
-
-                await run_structured_output_test(tmp_path, model.name, provider.name)
-            except Exception as e:
-                print(f"Error running {model.name} {provider.name}")
-                errors.append(f"{model.name} {provider.name}: {e}")
-    if len(errors) > 0:
-        raise RuntimeError(f"Errors: {errors}")
+            await run_structured_output_test(tmp_path, model.name, provider.name)
+            return
+    raise RuntimeError(f"No model {model_name} {provider_name} found")
 
 
 def build_structured_output_test_task(tmp_path: Path):
@@ -157,7 +139,7 @@ def build_structured_output_test_task(tmp_path: Path):
 
 async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_output_test_task(tmp_path)
-    a =
+    a = adapter_for_task(task, model_name=model_name, provider=provider)
     parsed = await a.invoke_returning_raw("Cows")  # a joke about cows
     if parsed is None or not isinstance(parsed, Dict):
         raise RuntimeError(f"structured response is not a dict: {parsed}")
@@ -204,7 +186,7 @@ async def run_structured_input_task(
     provider: str,
     pb: BasePromptBuilder | None = None,
 ):
-    a =
+    a = adapter_for_task(
         task, model_name=model_name, provider=provider, prompt_builder=pb
     )
     with pytest.raises(ValueError):
@@ -235,14 +217,11 @@ async def test_structured_input_gpt_4o_mini(tmp_path):
 
 @pytest.mark.paid
 @pytest.mark.ollama
-
-
-
-
-
-            await run_structured_input_test(tmp_path, model.name, provider.name)
-        except Exception as e:
-            raise RuntimeError(f"Error running {model.name} {provider}") from e
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_built_in_models_structured_input(
+    tmp_path, model_name, provider_name
+):
+    await run_structured_input_test(tmp_path, model_name, provider_name)
 
 
 @pytest.mark.paid
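The test_structured_output.py hunks above also show 0.7.0 dropping the direct LangChainPromptAdapter import in favor of the new adapter_for_task registry entry point. A minimal sketch of that call pattern, assuming an existing kiln_ai Task; building one is outside this diff, so run_joke_task and its arguments are illustrative rather than part of the package.

from kiln_ai.adapters.adapter_registry import adapter_for_task


async def run_joke_task(task, model_name: str, provider: str):
    # adapter_for_task resolves an adapter for the model/provider pair,
    # replacing the LangChainPromptAdapter import removed in this diff.
    adapter = adapter_for_task(task, model_name=model_name, provider=provider)
    # invoke_returning_raw returns the raw parsed output (a dict for
    # structured-output tasks, per the assertions in the tests above).
    return await adapter.invoke_returning_raw("Cows")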