kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +239 -303
- kiln_ai/adapters/ollama_tools.py +115 -0
- kiln_ai/adapters/provider_tools.py +308 -0
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +229 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +531 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +287 -24
- kiln_ai/datamodel/basemodel.py +122 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +167 -4
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +215 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +14 -1
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- kiln_ai-0.7.1.dist-info/RECORD +58 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/METADATA +0 -88
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/test_provider_tools.py (new file)

@@ -0,0 +1,531 @@
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+
+from kiln_ai.adapters.ml_model_list import (
+    KilnModel,
+    ModelName,
+    ModelProviderName,
+)
+from kiln_ai.adapters.ollama_tools import OllamaConnection
+from kiln_ai.adapters.provider_tools import (
+    builtin_model_from,
+    check_provider_warnings,
+    finetune_cache,
+    finetune_provider_model,
+    get_model_and_provider,
+    kiln_model_provider_from,
+    provider_enabled,
+    provider_name_from_id,
+    provider_options_for_custom_model,
+    provider_warnings,
+)
+from kiln_ai.datamodel import Finetune, Task
+
+
+@pytest.fixture(autouse=True)
+def clear_finetune_cache():
+    """Clear the finetune provider model cache before each test"""
+    finetune_cache.clear()
+    yield
+
+
+@pytest.fixture
+def mock_config():
+    with patch("kiln_ai.adapters.provider_tools.get_config_value") as mock:
+        yield mock
+
+
+@pytest.fixture
+def mock_project():
+    with patch("kiln_ai.adapters.provider_tools.project_from_id") as mock:
+        project = Mock()
+        project.path = "/fake/path"
+        mock.return_value = project
+        yield mock
+
+
+@pytest.fixture
+def mock_task():
+    with patch("kiln_ai.datamodel.Task.from_id_and_parent_path") as mock:
+        task = Mock(spec=Task)
+        task.path = "/fake/path/task"
+        mock.return_value = task
+        yield mock
+
+
+@pytest.fixture
+def mock_finetune():
+    with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
+        finetune = Mock(spec=Finetune)
+        finetune.provider = ModelProviderName.openai
+        finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
+        mock.return_value = finetune
+        yield mock
+
+
+def test_check_provider_warnings_no_warning(mock_config):
+    mock_config.return_value = "some_value"
+
+    # This should not raise an exception
+    check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+
+def test_check_provider_warnings_missing_key(mock_config):
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+    assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
+        exc_info.value
+    )
+
+
+def test_check_provider_warnings_unknown_provider():
+    # This should not raise an exception, as no settings are required for unknown providers
+    check_provider_warnings("unknown_provider")
+
+
+@pytest.mark.parametrize(
+    "provider_name",
+    [
+        ModelProviderName.amazon_bedrock,
+        ModelProviderName.openrouter,
+        ModelProviderName.groq,
+        ModelProviderName.openai,
+        ModelProviderName.fireworks_ai,
+    ],
+)
+def test_check_provider_warnings_all_providers(mock_config, provider_name):
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        check_provider_warnings(provider_name)
+
+    assert provider_warnings[provider_name].message in str(exc_info.value)
+
+
+def test_check_provider_warnings_partial_keys_set(mock_config):
+    def mock_get(key):
+        return "value" if key == "bedrock_access_key" else None
+
+    mock_config.side_effect = mock_get
+
+    with pytest.raises(ValueError) as exc_info:
+        check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+    assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
+        exc_info.value
+    )
+
+
+def test_provider_name_from_id_unknown_provider():
+    assert (
+        provider_name_from_id("unknown_provider")
+        == "Unknown provider: unknown_provider"
+    )
+
+
+def test_provider_name_from_id_case_sensitivity():
+    assert (
+        provider_name_from_id(ModelProviderName.amazon_bedrock.upper())
+        == "Unknown provider: AMAZON_BEDROCK"
+    )
+
+
+@pytest.mark.parametrize(
+    "provider_id, expected_name",
+    [
+        (ModelProviderName.amazon_bedrock, "Amazon Bedrock"),
+        (ModelProviderName.openrouter, "OpenRouter"),
+        (ModelProviderName.groq, "Groq"),
+        (ModelProviderName.ollama, "Ollama"),
+        (ModelProviderName.openai, "OpenAI"),
+        (ModelProviderName.fireworks_ai, "Fireworks AI"),
+        (ModelProviderName.kiln_fine_tune, "Fine Tuned Models"),
+        (ModelProviderName.kiln_custom_registry, "Custom Models"),
+    ],
+)
+def test_provider_name_from_id_parametrized(provider_id, expected_name):
+    assert provider_name_from_id(provider_id) == expected_name
+
+
+def test_get_model_and_provider_valid():
+    # Test with a known valid model and provider combination
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.ollama
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.phi_3_5
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+def test_get_model_and_provider_invalid_model():
+    # Test with an invalid model name
+    model, provider = get_model_and_provider(
+        "nonexistent_model", ModelProviderName.ollama
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_invalid_provider():
+    # Test with a valid model but invalid provider
+    model, provider = get_model_and_provider(ModelName.phi_3_5, "nonexistent_provider")
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_valid_model_wrong_provider():
+    # Test with a valid model but a provider that doesn't support it
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.amazon_bedrock
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_multiple_providers():
+    # Test with a model that has multiple providers
+    model, provider = get_model_and_provider(
+        ModelName.llama_3_1_70b, ModelProviderName.groq
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.llama_3_1_70b
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_ollama_success():
+    with patch(
+        "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+    ) as mock_get_ollama:
+        # Mock successful Ollama connection with models
+        mock_get_ollama.return_value = OllamaConnection(
+            message="Connected", supported_models=["phi3.5:latest"]
+        )
+
+        result = await provider_enabled(ModelProviderName.ollama)
+        assert result is True
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_ollama_no_models():
+    with patch(
+        "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+    ) as mock_get_ollama:
+        # Mock Ollama connection but with no models
+        mock_get_ollama.return_value = OllamaConnection(
+            message="Connected but no models",
+            supported_models=[],
+            unsupported_models=[],
+        )
+
+        result = await provider_enabled(ModelProviderName.ollama)
+        assert result is False
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_ollama_connection_error():
+    with patch(
+        "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+    ) as mock_get_ollama:
+        # Mock Ollama connection failure
+        mock_get_ollama.side_effect = Exception("Connection failed")
+
+        result = await provider_enabled(ModelProviderName.ollama)
+        assert result is False
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_openai_with_key(mock_config):
+    # Mock config to return API key
+    mock_config.return_value = "fake-api-key"
+
+    result = await provider_enabled(ModelProviderName.openai)
+    assert result is True
+    mock_config.assert_called_with("open_ai_api_key")
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_openai_without_key(mock_config):
+    # Mock config to return None for API key
+    mock_config.return_value = None
+
+    result = await provider_enabled(ModelProviderName.openai)
+    assert result is False
+    mock_config.assert_called_with("open_ai_api_key")
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_unknown_provider():
+    # Test with a provider that isn't in provider_warnings
+    result = await provider_enabled("unknown_provider")
+    assert result is False
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_model_no_provider():
+    with pytest.raises(ValueError) as exc_info:
+        await kiln_model_provider_from("custom_model")
+    assert str(exc_info.value) == "Provider name is required for custom models"
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_invalid_provider():
+    with pytest.raises(ValueError) as exc_info:
+        await kiln_model_provider_from("custom_model", "invalid_provider")
+    assert str(exc_info.value) == "Invalid provider name: invalid_provider"
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_model_valid(mock_config):
+    # Mock config to pass provider warnings check
+    mock_config.return_value = "fake-api-key"
+
+    provider = await kiln_model_provider_from("custom_model", ModelProviderName.openai)
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.supports_structured_output is False
+    assert provider.supports_data_gen is False
+    assert provider.untested_model is True
+    assert "model" in provider.provider_options
+    assert provider.provider_options["model"] == "custom_model"
+
+
+def test_provider_options_for_custom_model_basic():
+    """Test basic case with custom model name"""
+    options = provider_options_for_custom_model(
+        "custom_model_name", ModelProviderName.openai
+    )
+    assert options == {"model": "custom_model_name"}
+
+
+def test_provider_options_for_custom_model_bedrock():
+    """Test Amazon Bedrock provider options"""
+    options = provider_options_for_custom_model(
+        ModelName.llama_3_1_8b, ModelProviderName.amazon_bedrock
+    )
+    assert options == {"model": ModelName.llama_3_1_8b, "region_name": "us-west-2"}
+
+
+@pytest.mark.parametrize(
+    "provider",
+    [
+        ModelProviderName.openai,
+        ModelProviderName.ollama,
+        ModelProviderName.fireworks_ai,
+        ModelProviderName.openrouter,
+        ModelProviderName.groq,
+    ],
+)
+def test_provider_options_for_custom_model_simple_providers(provider):
+    """Test providers that just need model name"""
+
+    options = provider_options_for_custom_model(ModelName.llama_3_1_8b, provider)
+    assert options == {"model": ModelName.llama_3_1_8b}
+
+
+def test_provider_options_for_custom_model_kiln_fine_tune():
+    """Test that kiln_fine_tune raises appropriate error"""
+    with pytest.raises(ValueError) as exc_info:
+        provider_options_for_custom_model(
+            "model_name", ModelProviderName.kiln_fine_tune
+        )
+    assert (
+        str(exc_info.value)
+        == "Fine tuned models should populate provider options via another path"
+    )
+
+
+def test_provider_options_for_custom_model_invalid_enum():
+    """Test handling of invalid enum value"""
+    with pytest.raises(ValueError):
+        provider_options_for_custom_model("model_name", "invalid_enum_value")
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_registry(mock_config):
+    # Mock config to pass provider warnings check
+    mock_config.return_value = "fake-api-key"
+
+    # Test with a custom registry model ID in format "provider::model_name"
+    provider = await kiln_model_provider_from(
+        "openai::gpt-4-turbo", ModelProviderName.kiln_custom_registry
+    )
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.supports_structured_output is False
+    assert provider.supports_data_gen is False
+    assert provider.untested_model is True
+    assert provider.provider_options == {"model": "gpt-4-turbo"}
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_invalid_model():
+    """Test that an invalid model name returns None"""
+    result = await builtin_model_from("non_existent_model")
+    assert result is None
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_valid_model_default_provider(mock_config):
+    """Test getting a valid model with default provider"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(ModelName.phi_3_5)
+
+    assert provider is not None
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_valid_model_specific_provider(mock_config):
+    """Test getting a valid model with specific provider"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(
+        ModelName.llama_3_1_70b, provider_name=ModelProviderName.groq
+    )
+
+    assert provider is not None
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_invalid_provider(mock_config):
+    """Test that requesting an invalid provider returns None"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(
+        ModelName.phi_3_5, provider_name="invalid_provider"
+    )
+
+    assert provider is None
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_model_no_providers():
+    """Test handling of a model with no providers"""
+    with patch("kiln_ai.adapters.provider_tools.built_in_models") as mock_models:
+        # Create a mock model with no providers
+        mock_model = KilnModel(
+            name=ModelName.phi_3_5,
+            friendly_name="Test Model",
+            providers=[],
+            family="test_family",
+        )
+        mock_models.__iter__.return_value = [mock_model]
+
+        with pytest.raises(ValueError) as exc_info:
+            await builtin_model_from(ModelName.phi_3_5)
+
+        assert str(exc_info.value) == f"Model {ModelName.phi_3_5} has no providers"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_provider_warning_check(mock_config):
+    """Test that provider warnings are checked"""
+    # Make the config check fail
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        await builtin_model_from(ModelName.llama_3_1_70b, ModelProviderName.groq)
+
+    assert provider_warnings[ModelProviderName.groq].message in str(exc_info.value)
+
+
+def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune):
+    """Test successful creation of a fine-tuned model provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    provider = finetune_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.provider_options == {"model": "ft:gpt-3.5-turbo:custom:model-123"}
+
+    # Test cache
+    cached_provider = finetune_provider_model(model_id)
+    assert cached_provider is provider
+
+
+def test_finetune_provider_model_invalid_id():
+    """Test handling of invalid model ID format"""
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("invalid-id-format")
+    assert str(exc_info.value) == "Invalid fine tune ID: invalid-id-format"
+
+
+def test_finetune_provider_model_project_not_found(mock_project):
+    """Test handling of non-existent project"""
+    mock_project.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Project project-123 not found"
+
+
+def test_finetune_provider_model_task_not_found(mock_project, mock_task):
+    """Test handling of non-existent task"""
+    mock_task.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Task task-456 not found"
+
+
+def test_finetune_provider_model_finetune_not_found(
+    mock_project, mock_task, mock_finetune
+):
+    """Test handling of non-existent fine-tune"""
+    mock_finetune.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Fine tune finetune-789 not found"
+
+
+def test_finetune_provider_model_incomplete_finetune(
+    mock_project, mock_task, mock_finetune
+):
+    """Test handling of incomplete fine-tune"""
+    finetune = Mock(spec=Finetune)
+    finetune.fine_tune_model_id = None
+    mock_finetune.return_value = finetune
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert (
+        str(exc_info.value)
+        == "Fine tune finetune-789 not completed. Refresh it's status in the fine-tune tab."
+    )
+
+
+def test_finetune_provider_model_fireworks_provider(
+    mock_project, mock_task, mock_finetune
+):
+    """Test creation of Fireworks AI provider with specific adapter options"""
+    finetune = Mock(spec=Finetune)
+    finetune.provider = ModelProviderName.fireworks_ai
+    finetune.fine_tune_model_id = "fireworks-model-123"
+    mock_finetune.return_value = finetune
+
+    provider = finetune_provider_model("project-123::task-456::finetune-789")
+
+    assert provider.name == ModelProviderName.fireworks_ai
+    assert provider.provider_options == {"model": "fireworks-model-123"}
+    assert provider.adapter_options == {
+        "langchain": {"with_structured_output_options": {"method": "json_mode"}}
+    }
kiln_ai/adapters/test_structured_output.py

@@ -6,12 +6,12 @@ import jsonschema.exceptions
 import pytest
 
 import kiln_ai.datamodel as datamodel
+from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
-from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import (
     built_in_models,
-    ollama_online,
 )
+from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
@@ -20,23 +20,6 @@ from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema
 
 
-@pytest.mark.parametrize(
-    "model_name,provider",
-    [
-        ("llama_3_1_8b", "groq"),
-        ("mistral_nemo", "openrouter"),
-        ("llama_3_1_70b", "amazon_bedrock"),
-        ("claude_3_5_sonnet", "openrouter"),
-        ("gemini_1_5_pro", "openrouter"),
-        ("gemini_1_5_flash", "openrouter"),
-        ("gemini_1_5_flash_8b", "openrouter"),
-    ],
-)
-@pytest.mark.paid
-async def test_structured_output(tmp_path, model_name, provider):
-    await run_structured_output_test(tmp_path, model_name, provider)
-
-
 @pytest.mark.ollama
 async def test_structured_output_ollama_phi(tmp_path):
     # https://python.langchain.com/v0.2/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs
@@ -112,28 +95,27 @@ async def test_mock_unstructred_response(tmp_path):
 
 @pytest.mark.paid
 @pytest.mark.ollama
-
-
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_built_in_models_structured_output(
+    tmp_path, model_name, provider_name
+):
     for model in built_in_models:
+        if model.name != model_name:
+            continue
         if not model.supports_structured_output:
-
+            pytest.skip(
                 f"Skipping {model.name} because it does not support structured output"
             )
-            continue
         for provider in model.providers:
+            if provider.name != provider_name:
+                continue
             if not provider.supports_structured_output:
-
+                pytest.skip(
                     f"Skipping {model.name} {provider.name} because it does not support structured output"
                 )
-
-
-
-                await run_structured_output_test(tmp_path, model.name, provider.name)
-            except Exception as e:
-                print(f"Error running {model.name} {provider.name}")
-                errors.append(f"{model.name} {provider.name}: {e}")
-    if len(errors) > 0:
-        raise RuntimeError(f"Errors: {errors}")
+            await run_structured_output_test(tmp_path, model.name, provider.name)
+            return
+    raise RuntimeError(f"No model {model_name} {provider_name} found")
 
 
 def build_structured_output_test_task(tmp_path: Path):
@@ -157,7 +139,7 @@ def build_structured_output_test_task(tmp_path: Path):
 
 async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_output_test_task(tmp_path)
-    a =
+    a = adapter_for_task(task, model_name=model_name, provider=provider)
     parsed = await a.invoke_returning_raw("Cows")  # a joke about cows
     if parsed is None or not isinstance(parsed, Dict):
         raise RuntimeError(f"structured response is not a dict: {parsed}")
@@ -204,7 +186,7 @@ async def run_structured_input_task(
     provider: str,
     pb: BasePromptBuilder | None = None,
 ):
-    a =
+    a = adapter_for_task(
         task, model_name=model_name, provider=provider, prompt_builder=pb
     )
     with pytest.raises(ValueError):
@@ -235,14 +217,11 @@ async def test_structured_input_gpt_4o_mini(tmp_path):
 
 @pytest.mark.paid
 @pytest.mark.ollama
-
-
-
-
-
-                await run_structured_input_test(tmp_path, model.name, provider.name)
-            except Exception as e:
-                raise RuntimeError(f"Error running {model.name} {provider}") from e
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_built_in_models_structured_input(
+    tmp_path, model_name, provider_name
+):
+    await run_structured_input_test(tmp_path, model_name, provider_name)
 
 
 @pytest.mark.paid