kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.


Files changed (44)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +239 -303
  16. kiln_ai/adapters/ollama_tools.py +115 -0
  17. kiln_ai/adapters/provider_tools.py +308 -0
  18. kiln_ai/adapters/repair/repair_task.py +4 -2
  19. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  20. kiln_ai/adapters/test_langchain_adapter.py +229 -18
  21. kiln_ai/adapters/test_ollama_tools.py +42 -0
  22. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  23. kiln_ai/adapters/test_provider_tools.py +531 -0
  24. kiln_ai/adapters/test_structured_output.py +22 -43
  25. kiln_ai/datamodel/__init__.py +287 -24
  26. kiln_ai/datamodel/basemodel.py +122 -38
  27. kiln_ai/datamodel/model_cache.py +116 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +167 -4
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_model_cache.py +244 -0
  33. kiln_ai/datamodel/test_models.py +215 -1
  34. kiln_ai/datamodel/test_registry.py +96 -0
  35. kiln_ai/utils/config.py +14 -1
  36. kiln_ai/utils/name_generator.py +125 -0
  37. kiln_ai/utils/test_name_geneator.py +47 -0
  38. kiln_ai-0.7.1.dist-info/METADATA +237 -0
  39. kiln_ai-0.7.1.dist-info/RECORD +58 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
  41. kiln_ai/adapters/test_ml_model_list.py +0 -181
  42. kiln_ai-0.6.1.dist-info/METADATA +0 -88
  43. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  44. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
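
The headline additions in 0.7.1 are the fine-tuning adapters and the new provider_tools module. As the new tests below show, fine-tuned models are addressed by a three-part ID (project, task, and fine-tune IDs joined by "::"), and malformed IDs are rejected with a ValueError. A minimal sketch of that convention, assuming a hypothetical helper (the function name below is illustrative, not part of the kiln_ai API):

    # Illustrative sketch only: the "project::task::finetune" ID convention
    # exercised by test_provider_tools.py below. parse_finetune_model_id is
    # a made-up name, not the package's implementation.
    def parse_finetune_model_id(model_id: str) -> tuple[str, str, str]:
        parts = model_id.split("::")
        if len(parts) != 3:
            # The real code also raises ValueError for malformed IDs (see tests below)
            raise ValueError(f"Invalid fine tune ID: {model_id}")
        project_id, task_id, finetune_id = parts
        return project_id, task_id, finetune_id

    parse_finetune_model_id("project-123::task-456::finetune-789")
    # -> ("project-123", "task-456", "finetune-789")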
kiln_ai/adapters/test_provider_tools.py (new file)
@@ -0,0 +1,531 @@
+ from unittest.mock import AsyncMock, Mock, patch
+
+ import pytest
+
+ from kiln_ai.adapters.ml_model_list import (
+     KilnModel,
+     ModelName,
+     ModelProviderName,
+ )
+ from kiln_ai.adapters.ollama_tools import OllamaConnection
+ from kiln_ai.adapters.provider_tools import (
+     builtin_model_from,
+     check_provider_warnings,
+     finetune_cache,
+     finetune_provider_model,
+     get_model_and_provider,
+     kiln_model_provider_from,
+     provider_enabled,
+     provider_name_from_id,
+     provider_options_for_custom_model,
+     provider_warnings,
+ )
+ from kiln_ai.datamodel import Finetune, Task
+
+
+ @pytest.fixture(autouse=True)
+ def clear_finetune_cache():
+     """Clear the finetune provider model cache before each test"""
+     finetune_cache.clear()
+     yield
+
+
+ @pytest.fixture
+ def mock_config():
+     with patch("kiln_ai.adapters.provider_tools.get_config_value") as mock:
+         yield mock
+
+
+ @pytest.fixture
+ def mock_project():
+     with patch("kiln_ai.adapters.provider_tools.project_from_id") as mock:
+         project = Mock()
+         project.path = "/fake/path"
+         mock.return_value = project
+         yield mock
+
+
+ @pytest.fixture
+ def mock_task():
+     with patch("kiln_ai.datamodel.Task.from_id_and_parent_path") as mock:
+         task = Mock(spec=Task)
+         task.path = "/fake/path/task"
+         mock.return_value = task
+         yield mock
+
+
+ @pytest.fixture
+ def mock_finetune():
+     with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
+         finetune = Mock(spec=Finetune)
+         finetune.provider = ModelProviderName.openai
+         finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
+         mock.return_value = finetune
+         yield mock
+
+
+ def test_check_provider_warnings_no_warning(mock_config):
+     mock_config.return_value = "some_value"
+
+     # This should not raise an exception
+     check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+
+ def test_check_provider_warnings_missing_key(mock_config):
+     mock_config.return_value = None
+
+     with pytest.raises(ValueError) as exc_info:
+         check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+     assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
+         exc_info.value
+     )
+
+
+ def test_check_provider_warnings_unknown_provider():
+     # This should not raise an exception, as no settings are required for unknown providers
+     check_provider_warnings("unknown_provider")
+
+
+ @pytest.mark.parametrize(
+     "provider_name",
+     [
+         ModelProviderName.amazon_bedrock,
+         ModelProviderName.openrouter,
+         ModelProviderName.groq,
+         ModelProviderName.openai,
+         ModelProviderName.fireworks_ai,
+     ],
+ )
+ def test_check_provider_warnings_all_providers(mock_config, provider_name):
+     mock_config.return_value = None
+
+     with pytest.raises(ValueError) as exc_info:
+         check_provider_warnings(provider_name)
+
+     assert provider_warnings[provider_name].message in str(exc_info.value)
+
+
+ def test_check_provider_warnings_partial_keys_set(mock_config):
+     def mock_get(key):
+         return "value" if key == "bedrock_access_key" else None
+
+     mock_config.side_effect = mock_get
+
+     with pytest.raises(ValueError) as exc_info:
+         check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+     assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
+         exc_info.value
+     )
+
+
+ def test_provider_name_from_id_unknown_provider():
+     assert (
+         provider_name_from_id("unknown_provider")
+         == "Unknown provider: unknown_provider"
+     )
+
+
+ def test_provider_name_from_id_case_sensitivity():
+     assert (
+         provider_name_from_id(ModelProviderName.amazon_bedrock.upper())
+         == "Unknown provider: AMAZON_BEDROCK"
+     )
+
+
+ @pytest.mark.parametrize(
+     "provider_id, expected_name",
+     [
+         (ModelProviderName.amazon_bedrock, "Amazon Bedrock"),
+         (ModelProviderName.openrouter, "OpenRouter"),
+         (ModelProviderName.groq, "Groq"),
+         (ModelProviderName.ollama, "Ollama"),
+         (ModelProviderName.openai, "OpenAI"),
+         (ModelProviderName.fireworks_ai, "Fireworks AI"),
+         (ModelProviderName.kiln_fine_tune, "Fine Tuned Models"),
+         (ModelProviderName.kiln_custom_registry, "Custom Models"),
+     ],
+ )
+ def test_provider_name_from_id_parametrized(provider_id, expected_name):
+     assert provider_name_from_id(provider_id) == expected_name
+
+
+ def test_get_model_and_provider_valid():
+     # Test with a known valid model and provider combination
+     model, provider = get_model_and_provider(
+         ModelName.phi_3_5, ModelProviderName.ollama
+     )
+
+     assert model is not None
+     assert provider is not None
+     assert model.name == ModelName.phi_3_5
+     assert provider.name == ModelProviderName.ollama
+     assert provider.provider_options["model"] == "phi3.5"
+
+
+ def test_get_model_and_provider_invalid_model():
+     # Test with an invalid model name
+     model, provider = get_model_and_provider(
+         "nonexistent_model", ModelProviderName.ollama
+     )
+
+     assert model is None
+     assert provider is None
+
+
+ def test_get_model_and_provider_invalid_provider():
+     # Test with a valid model but invalid provider
+     model, provider = get_model_and_provider(ModelName.phi_3_5, "nonexistent_provider")
+
+     assert model is None
+     assert provider is None
+
+
+ def test_get_model_and_provider_valid_model_wrong_provider():
+     # Test with a valid model but a provider that doesn't support it
+     model, provider = get_model_and_provider(
+         ModelName.phi_3_5, ModelProviderName.amazon_bedrock
+     )
+
+     assert model is None
+     assert provider is None
+
+
+ def test_get_model_and_provider_multiple_providers():
+     # Test with a model that has multiple providers
+     model, provider = get_model_and_provider(
+         ModelName.llama_3_1_70b, ModelProviderName.groq
+     )
+
+     assert model is not None
+     assert provider is not None
+     assert model.name == ModelName.llama_3_1_70b
+     assert provider.name == ModelProviderName.groq
+     assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
+
+
+ @pytest.mark.asyncio
+ async def test_provider_enabled_ollama_success():
+     with patch(
+         "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+     ) as mock_get_ollama:
+         # Mock successful Ollama connection with models
+         mock_get_ollama.return_value = OllamaConnection(
+             message="Connected", supported_models=["phi3.5:latest"]
+         )
+
+         result = await provider_enabled(ModelProviderName.ollama)
+         assert result is True
+
+
+ @pytest.mark.asyncio
+ async def test_provider_enabled_ollama_no_models():
+     with patch(
+         "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+     ) as mock_get_ollama:
+         # Mock Ollama connection but with no models
+         mock_get_ollama.return_value = OllamaConnection(
+             message="Connected but no models",
+             supported_models=[],
+             unsupported_models=[],
+         )
+
+         result = await provider_enabled(ModelProviderName.ollama)
+         assert result is False
+
+
+ @pytest.mark.asyncio
+ async def test_provider_enabled_ollama_connection_error():
+     with patch(
+         "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+     ) as mock_get_ollama:
+         # Mock Ollama connection failure
+         mock_get_ollama.side_effect = Exception("Connection failed")
+
+         result = await provider_enabled(ModelProviderName.ollama)
+         assert result is False
+
+
+ @pytest.mark.asyncio
+ async def test_provider_enabled_openai_with_key(mock_config):
+     # Mock config to return API key
+     mock_config.return_value = "fake-api-key"
+
+     result = await provider_enabled(ModelProviderName.openai)
+     assert result is True
+     mock_config.assert_called_with("open_ai_api_key")
+
+
+ @pytest.mark.asyncio
+ async def test_provider_enabled_openai_without_key(mock_config):
+     # Mock config to return None for API key
+     mock_config.return_value = None
+
+     result = await provider_enabled(ModelProviderName.openai)
+     assert result is False
+     mock_config.assert_called_with("open_ai_api_key")
+
+
+ @pytest.mark.asyncio
+ async def test_provider_enabled_unknown_provider():
+     # Test with a provider that isn't in provider_warnings
+     result = await provider_enabled("unknown_provider")
+     assert result is False
+
+
+ @pytest.mark.asyncio
+ async def test_kiln_model_provider_from_custom_model_no_provider():
+     with pytest.raises(ValueError) as exc_info:
+         await kiln_model_provider_from("custom_model")
+     assert str(exc_info.value) == "Provider name is required for custom models"
+
+
+ @pytest.mark.asyncio
+ async def test_kiln_model_provider_from_invalid_provider():
+     with pytest.raises(ValueError) as exc_info:
+         await kiln_model_provider_from("custom_model", "invalid_provider")
+     assert str(exc_info.value) == "Invalid provider name: invalid_provider"
+
+
+ @pytest.mark.asyncio
+ async def test_kiln_model_provider_from_custom_model_valid(mock_config):
+     # Mock config to pass provider warnings check
+     mock_config.return_value = "fake-api-key"
+
+     provider = await kiln_model_provider_from("custom_model", ModelProviderName.openai)
+
+     assert provider.name == ModelProviderName.openai
+     assert provider.supports_structured_output is False
+     assert provider.supports_data_gen is False
+     assert provider.untested_model is True
+     assert "model" in provider.provider_options
+     assert provider.provider_options["model"] == "custom_model"
+
+
+ def test_provider_options_for_custom_model_basic():
+     """Test basic case with custom model name"""
+     options = provider_options_for_custom_model(
+         "custom_model_name", ModelProviderName.openai
+     )
+     assert options == {"model": "custom_model_name"}
+
+
+ def test_provider_options_for_custom_model_bedrock():
+     """Test Amazon Bedrock provider options"""
+     options = provider_options_for_custom_model(
+         ModelName.llama_3_1_8b, ModelProviderName.amazon_bedrock
+     )
+     assert options == {"model": ModelName.llama_3_1_8b, "region_name": "us-west-2"}
+
+
+ @pytest.mark.parametrize(
+     "provider",
+     [
+         ModelProviderName.openai,
+         ModelProviderName.ollama,
+         ModelProviderName.fireworks_ai,
+         ModelProviderName.openrouter,
+         ModelProviderName.groq,
+     ],
+ )
+ def test_provider_options_for_custom_model_simple_providers(provider):
+     """Test providers that just need model name"""
+
+     options = provider_options_for_custom_model(ModelName.llama_3_1_8b, provider)
+     assert options == {"model": ModelName.llama_3_1_8b}
+
+
+ def test_provider_options_for_custom_model_kiln_fine_tune():
+     """Test that kiln_fine_tune raises appropriate error"""
+     with pytest.raises(ValueError) as exc_info:
+         provider_options_for_custom_model(
+             "model_name", ModelProviderName.kiln_fine_tune
+         )
+     assert (
+         str(exc_info.value)
+         == "Fine tuned models should populate provider options via another path"
+     )
+
+
+ def test_provider_options_for_custom_model_invalid_enum():
+     """Test handling of invalid enum value"""
+     with pytest.raises(ValueError):
+         provider_options_for_custom_model("model_name", "invalid_enum_value")
+
+
+ @pytest.mark.asyncio
+ async def test_kiln_model_provider_from_custom_registry(mock_config):
+     # Mock config to pass provider warnings check
+     mock_config.return_value = "fake-api-key"
+
+     # Test with a custom registry model ID in format "provider::model_name"
+     provider = await kiln_model_provider_from(
+         "openai::gpt-4-turbo", ModelProviderName.kiln_custom_registry
+     )
+
+     assert provider.name == ModelProviderName.openai
+     assert provider.supports_structured_output is False
+     assert provider.supports_data_gen is False
+     assert provider.untested_model is True
+     assert provider.provider_options == {"model": "gpt-4-turbo"}
+
+
+ @pytest.mark.asyncio
+ async def test_builtin_model_from_invalid_model():
+     """Test that an invalid model name returns None"""
+     result = await builtin_model_from("non_existent_model")
+     assert result is None
+
+
+ @pytest.mark.asyncio
+ async def test_builtin_model_from_valid_model_default_provider(mock_config):
+     """Test getting a valid model with default provider"""
+     mock_config.return_value = "fake-api-key"
+
+     provider = await builtin_model_from(ModelName.phi_3_5)
+
+     assert provider is not None
+     assert provider.name == ModelProviderName.ollama
+     assert provider.provider_options["model"] == "phi3.5"
+
+
+ @pytest.mark.asyncio
+ async def test_builtin_model_from_valid_model_specific_provider(mock_config):
+     """Test getting a valid model with specific provider"""
+     mock_config.return_value = "fake-api-key"
+
+     provider = await builtin_model_from(
+         ModelName.llama_3_1_70b, provider_name=ModelProviderName.groq
+     )
+
+     assert provider is not None
+     assert provider.name == ModelProviderName.groq
+     assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
+
+
+ @pytest.mark.asyncio
+ async def test_builtin_model_from_invalid_provider(mock_config):
+     """Test that requesting an invalid provider returns None"""
+     mock_config.return_value = "fake-api-key"
+
+     provider = await builtin_model_from(
+         ModelName.phi_3_5, provider_name="invalid_provider"
+     )
+
+     assert provider is None
+
+
+ @pytest.mark.asyncio
+ async def test_builtin_model_from_model_no_providers():
+     """Test handling of a model with no providers"""
+     with patch("kiln_ai.adapters.provider_tools.built_in_models") as mock_models:
+         # Create a mock model with no providers
+         mock_model = KilnModel(
+             name=ModelName.phi_3_5,
+             friendly_name="Test Model",
+             providers=[],
+             family="test_family",
+         )
+         mock_models.__iter__.return_value = [mock_model]
+
+         with pytest.raises(ValueError) as exc_info:
+             await builtin_model_from(ModelName.phi_3_5)
+
+         assert str(exc_info.value) == f"Model {ModelName.phi_3_5} has no providers"
+
+
+ @pytest.mark.asyncio
+ async def test_builtin_model_from_provider_warning_check(mock_config):
+     """Test that provider warnings are checked"""
+     # Make the config check fail
+     mock_config.return_value = None
+
+     with pytest.raises(ValueError) as exc_info:
+         await builtin_model_from(ModelName.llama_3_1_70b, ModelProviderName.groq)
+
+     assert provider_warnings[ModelProviderName.groq].message in str(exc_info.value)
+
+
+ def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune):
+     """Test successful creation of a fine-tuned model provider"""
+     model_id = "project-123::task-456::finetune-789"
+
+     provider = finetune_provider_model(model_id)
+
+     assert provider.name == ModelProviderName.openai
+     assert provider.provider_options == {"model": "ft:gpt-3.5-turbo:custom:model-123"}
+
+     # Test cache
+     cached_provider = finetune_provider_model(model_id)
+     assert cached_provider is provider
+
+
+ def test_finetune_provider_model_invalid_id():
+     """Test handling of invalid model ID format"""
+     with pytest.raises(ValueError) as exc_info:
+         finetune_provider_model("invalid-id-format")
+     assert str(exc_info.value) == "Invalid fine tune ID: invalid-id-format"
+
+
+ def test_finetune_provider_model_project_not_found(mock_project):
+     """Test handling of non-existent project"""
+     mock_project.return_value = None
+
+     with pytest.raises(ValueError) as exc_info:
+         finetune_provider_model("project-123::task-456::finetune-789")
+     assert str(exc_info.value) == "Project project-123 not found"
+
+
+ def test_finetune_provider_model_task_not_found(mock_project, mock_task):
+     """Test handling of non-existent task"""
+     mock_task.return_value = None
+
+     with pytest.raises(ValueError) as exc_info:
+         finetune_provider_model("project-123::task-456::finetune-789")
+     assert str(exc_info.value) == "Task task-456 not found"
+
+
+ def test_finetune_provider_model_finetune_not_found(
+     mock_project, mock_task, mock_finetune
+ ):
+     """Test handling of non-existent fine-tune"""
+     mock_finetune.return_value = None
+
+     with pytest.raises(ValueError) as exc_info:
+         finetune_provider_model("project-123::task-456::finetune-789")
+     assert str(exc_info.value) == "Fine tune finetune-789 not found"
+
+
+ def test_finetune_provider_model_incomplete_finetune(
+     mock_project, mock_task, mock_finetune
+ ):
+     """Test handling of incomplete fine-tune"""
+     finetune = Mock(spec=Finetune)
+     finetune.fine_tune_model_id = None
+     mock_finetune.return_value = finetune
+
+     with pytest.raises(ValueError) as exc_info:
+         finetune_provider_model("project-123::task-456::finetune-789")
+     assert (
+         str(exc_info.value)
+         == "Fine tune finetune-789 not completed. Refresh it's status in the fine-tune tab."
+     )
+
+
+ def test_finetune_provider_model_fireworks_provider(
+     mock_project, mock_task, mock_finetune
+ ):
+     """Test creation of Fireworks AI provider with specific adapter options"""
+     finetune = Mock(spec=Finetune)
+     finetune.provider = ModelProviderName.fireworks_ai
+     finetune.fine_tune_model_id = "fireworks-model-123"
+     mock_finetune.return_value = finetune
+
+     provider = finetune_provider_model("project-123::task-456::finetune-789")
+
+     assert provider.name == ModelProviderName.fireworks_ai
+     assert provider.provider_options == {"model": "fireworks-model-123"}
+     assert provider.adapter_options == {
+         "langchain": {"with_structured_output_options": {"method": "json_mode"}}
+     }
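
The remaining hunks update the structured output tests: the hand-rolled loops over built_in_models (which printed skips and collected errors) become pytest parametrization over get_all_models_and_providers(), and adapters are now built through adapter_for_task from the new adapter_registry instead of constructing LangChainPromptAdapter directly. For readers unfamiliar with the pattern, a minimal standalone sketch of parametrization (the model/provider pairs here are invented for illustration):

    import pytest

    # Stand-in for get_all_models_and_providers(); these pairs are made up.
    MODEL_PROVIDER_PAIRS = [("phi_3_5", "ollama"), ("llama_3_1_70b", "groq")]

    @pytest.mark.parametrize("model_name,provider_name", MODEL_PROVIDER_PAIRS)
    def test_each_model_provider(model_name, provider_name):
        # Each pair becomes its own test case, so one failing or skipped
        # combination no longer hides the results of the others.
        assert model_name and provider_name

That per-case reporting is the practical win of the refactor shown below.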
kiln_ai/adapters/test_structured_output.py
@@ -6,12 +6,12 @@ import jsonschema.exceptions
  import pytest

  import kiln_ai.datamodel as datamodel
+ from kiln_ai.adapters.adapter_registry import adapter_for_task
  from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
- from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
  from kiln_ai.adapters.ml_model_list import (
      built_in_models,
-     ollama_online,
  )
+ from kiln_ai.adapters.ollama_tools import ollama_online
  from kiln_ai.adapters.prompt_builders import (
      BasePromptBuilder,
      SimpleChainOfThoughtPromptBuilder,
@@ -20,23 +20,6 @@ from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
  from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


- @pytest.mark.parametrize(
-     "model_name,provider",
-     [
-         ("llama_3_1_8b", "groq"),
-         ("mistral_nemo", "openrouter"),
-         ("llama_3_1_70b", "amazon_bedrock"),
-         ("claude_3_5_sonnet", "openrouter"),
-         ("gemini_1_5_pro", "openrouter"),
-         ("gemini_1_5_flash", "openrouter"),
-         ("gemini_1_5_flash_8b", "openrouter"),
-     ],
- )
- @pytest.mark.paid
- async def test_structured_output(tmp_path, model_name, provider):
-     await run_structured_output_test(tmp_path, model_name, provider)
-
-
  @pytest.mark.ollama
  async def test_structured_output_ollama_phi(tmp_path):
      # https://python.langchain.com/v0.2/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs
@@ -112,28 +95,27 @@ async def test_mock_unstructred_response(tmp_path):

  @pytest.mark.paid
  @pytest.mark.ollama
- async def test_all_built_in_models_structured_output(tmp_path):
-     errors = []
+ @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+ async def test_all_built_in_models_structured_output(
+     tmp_path, model_name, provider_name
+ ):
      for model in built_in_models:
+         if model.name != model_name:
+             continue
          if not model.supports_structured_output:
-             print(
+             pytest.skip(
                  f"Skipping {model.name} because it does not support structured output"
              )
-             continue
          for provider in model.providers:
+             if provider.name != provider_name:
+                 continue
              if not provider.supports_structured_output:
-                 print(
+                 pytest.skip(
                      f"Skipping {model.name} {provider.name} because it does not support structured output"
                  )
-                 continue
-             try:
-                 print(f"Running {model.name} {provider.name}")
-                 await run_structured_output_test(tmp_path, model.name, provider.name)
-             except Exception as e:
-                 print(f"Error running {model.name} {provider.name}")
-                 errors.append(f"{model.name} {provider.name}: {e}")
-     if len(errors) > 0:
-         raise RuntimeError(f"Errors: {errors}")
+             await run_structured_output_test(tmp_path, model.name, provider.name)
+             return
+     raise RuntimeError(f"No model {model_name} {provider_name} found")


  def build_structured_output_test_task(tmp_path: Path):
@@ -157,7 +139,7 @@ def build_structured_output_test_task(tmp_path: Path):

  async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
      task = build_structured_output_test_task(tmp_path)
-     a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
+     a = adapter_for_task(task, model_name=model_name, provider=provider)
      parsed = await a.invoke_returning_raw("Cows")  # a joke about cows
      if parsed is None or not isinstance(parsed, Dict):
          raise RuntimeError(f"structured response is not a dict: {parsed}")
@@ -204,7 +186,7 @@ async def run_structured_input_task(
      provider: str,
      pb: BasePromptBuilder | None = None,
  ):
-     a = LangChainPromptAdapter(
+     a = adapter_for_task(
          task, model_name=model_name, provider=provider, prompt_builder=pb
      )
      with pytest.raises(ValueError):
@@ -235,14 +217,11 @@ async def test_structured_input_gpt_4o_mini(tmp_path):

  @pytest.mark.paid
  @pytest.mark.ollama
- async def test_all_built_in_models_structured_input(tmp_path):
-     for model in built_in_models:
-         for provider in model.providers:
-             try:
-                 print(f"Running {model.name} {provider.name}")
-                 await run_structured_input_test(tmp_path, model.name, provider.name)
-             except Exception as e:
-                 raise RuntimeError(f"Error running {model.name} {provider}") from e
+ @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+ async def test_all_built_in_models_structured_input(
+     tmp_path, model_name, provider_name
+ ):
+     await run_structured_input_test(tmp_path, model_name, provider_name)


  @pytest.mark.paid