kiln-ai 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of kiln-ai might be problematic.

Files changed (40)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +218 -304
  16. kiln_ai/adapters/ollama_tools.py +114 -0
  17. kiln_ai/adapters/provider_tools.py +295 -0
  18. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  19. kiln_ai/adapters/test_langchain_adapter.py +46 -18
  20. kiln_ai/adapters/test_ollama_tools.py +42 -0
  21. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  22. kiln_ai/adapters/test_provider_tools.py +312 -0
  23. kiln_ai/adapters/test_structured_output.py +22 -43
  24. kiln_ai/datamodel/__init__.py +235 -22
  25. kiln_ai/datamodel/basemodel.py +30 -0
  26. kiln_ai/datamodel/registry.py +31 -0
  27. kiln_ai/datamodel/test_basemodel.py +29 -1
  28. kiln_ai/datamodel/test_dataset_split.py +234 -0
  29. kiln_ai/datamodel/test_example_models.py +12 -0
  30. kiln_ai/datamodel/test_models.py +91 -1
  31. kiln_ai/datamodel/test_registry.py +96 -0
  32. kiln_ai/utils/config.py +9 -0
  33. kiln_ai/utils/name_generator.py +125 -0
  34. kiln_ai/utils/test_name_generator.py +47 -0
  35. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
  36. kiln_ai-0.7.0.dist-info/RECORD +56 -0
  37. kiln_ai/adapters/test_ml_model_list.py +0 -181
  38. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  39. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
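
The headline additions in 0.7.0 are the fine_tune adapter package, the provider_tools module, and the adapter_registry. As a rough orientation, the sketch below pieces together how those entry points appear to fit, using only names and call shapes taken from the test diffs further down this page; the surrounding workflow and the pre-existing Task object are assumptions, not documented API.

# Hedged usage sketch for kiln-ai 0.7.0, inferred from the tests in this diff.
# The imports and keyword arguments mirror the test code below; everything
# else (the wrapper function, error handling) is illustrative.
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.ml_model_list import ModelName, ModelProviderName
from kiln_ai.adapters.provider_tools import check_provider_warnings, provider_enabled


async def joke_about_cows(task):
    # `task` is assumed to be an existing kiln_ai.datamodel.Task.
    # For key-based providers this raises ValueError when required keys are
    # missing from Kiln's config; Ollama has no keys, so it passes through.
    check_provider_warnings(ModelProviderName.ollama)

    # provider_enabled() is async; for Ollama it checks the local server is
    # reachable and has at least one supported model.
    if not await provider_enabled(ModelProviderName.ollama):
        raise RuntimeError("Ollama is not configured or reachable")

    adapter = adapter_for_task(
        task, model_name=ModelName.phi_3_5, provider=ModelProviderName.ollama
    )
    return await adapter.invoke_returning_raw("Cows")  # a joke about cows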
kiln_ai/adapters/test_provider_tools.py (new file)
@@ -0,0 +1,312 @@
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from kiln_ai.adapters.ml_model_list import (
+    ModelName,
+    ModelProviderName,
+)
+from kiln_ai.adapters.ollama_tools import OllamaConnection
+from kiln_ai.adapters.provider_tools import (
+    check_provider_warnings,
+    get_model_and_provider,
+    kiln_model_provider_from,
+    provider_enabled,
+    provider_name_from_id,
+    provider_options_for_custom_model,
+    provider_warnings,
+)
+
+
+@pytest.fixture
+def mock_config():
+    with patch("kiln_ai.adapters.provider_tools.get_config_value") as mock:
+        yield mock
+
+
+def test_check_provider_warnings_no_warning(mock_config):
+    mock_config.return_value = "some_value"
+
+    # This should not raise an exception
+    check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+
+def test_check_provider_warnings_missing_key(mock_config):
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+    assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
+        exc_info.value
+    )
+
+
+def test_check_provider_warnings_unknown_provider():
+    # This should not raise an exception, as no settings are required for unknown providers
+    check_provider_warnings("unknown_provider")
+
+
+@pytest.mark.parametrize(
+    "provider_name",
+    [
+        ModelProviderName.amazon_bedrock,
+        ModelProviderName.openrouter,
+        ModelProviderName.groq,
+        ModelProviderName.openai,
+        ModelProviderName.fireworks_ai,
+    ],
+)
+def test_check_provider_warnings_all_providers(mock_config, provider_name):
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        check_provider_warnings(provider_name)
+
+    assert provider_warnings[provider_name].message in str(exc_info.value)
+
+
+def test_check_provider_warnings_partial_keys_set(mock_config):
+    def mock_get(key):
+        return "value" if key == "bedrock_access_key" else None
+
+    mock_config.side_effect = mock_get
+
+    with pytest.raises(ValueError) as exc_info:
+        check_provider_warnings(ModelProviderName.amazon_bedrock)
+
+    assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
+        exc_info.value
+    )
+
+
+def test_provider_name_from_id_unknown_provider():
+    assert (
+        provider_name_from_id("unknown_provider")
+        == "Unknown provider: unknown_provider"
+    )
+
+
+def test_provider_name_from_id_case_sensitivity():
+    assert (
+        provider_name_from_id(ModelProviderName.amazon_bedrock.upper())
+        == "Unknown provider: AMAZON_BEDROCK"
+    )
+
+
+@pytest.mark.parametrize(
+    "provider_id, expected_name",
+    [
+        (ModelProviderName.amazon_bedrock, "Amazon Bedrock"),
+        (ModelProviderName.openrouter, "OpenRouter"),
+        (ModelProviderName.groq, "Groq"),
+        (ModelProviderName.ollama, "Ollama"),
+        (ModelProviderName.openai, "OpenAI"),
+        (ModelProviderName.fireworks_ai, "Fireworks AI"),
+    ],
+)
+def test_provider_name_from_id_parametrized(provider_id, expected_name):
+    assert provider_name_from_id(provider_id) == expected_name
+
+
+def test_get_model_and_provider_valid():
+    # Test with a known valid model and provider combination
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.ollama
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.phi_3_5
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+def test_get_model_and_provider_invalid_model():
+    # Test with an invalid model name
+    model, provider = get_model_and_provider(
+        "nonexistent_model", ModelProviderName.ollama
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_invalid_provider():
+    # Test with a valid model but invalid provider
+    model, provider = get_model_and_provider(ModelName.phi_3_5, "nonexistent_provider")
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_valid_model_wrong_provider():
+    # Test with a valid model but a provider that doesn't support it
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.amazon_bedrock
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_multiple_providers():
+    # Test with a model that has multiple providers
+    model, provider = get_model_and_provider(
+        ModelName.llama_3_1_70b, ModelProviderName.groq
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.llama_3_1_70b
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_ollama_success():
+    with patch(
+        "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+    ) as mock_get_ollama:
+        # Mock successful Ollama connection with models
+        mock_get_ollama.return_value = OllamaConnection(
+            message="Connected", supported_models=["phi3.5:latest"]
+        )
+
+        result = await provider_enabled(ModelProviderName.ollama)
+        assert result is True
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_ollama_no_models():
+    with patch(
+        "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+    ) as mock_get_ollama:
+        # Mock Ollama connection but with no models
+        mock_get_ollama.return_value = OllamaConnection(
+            message="Connected but no models",
+            supported_models=[],
+            unsupported_models=[],
+        )
+
+        result = await provider_enabled(ModelProviderName.ollama)
+        assert result is False
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_ollama_connection_error():
+    with patch(
+        "kiln_ai.adapters.provider_tools.get_ollama_connection", new_callable=AsyncMock
+    ) as mock_get_ollama:
+        # Mock Ollama connection failure
+        mock_get_ollama.side_effect = Exception("Connection failed")
+
+        result = await provider_enabled(ModelProviderName.ollama)
+        assert result is False
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_openai_with_key(mock_config):
+    # Mock config to return API key
+    mock_config.return_value = "fake-api-key"
+
+    result = await provider_enabled(ModelProviderName.openai)
+    assert result is True
+    mock_config.assert_called_with("open_ai_api_key")
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_openai_without_key(mock_config):
+    # Mock config to return None for API key
+    mock_config.return_value = None

+    result = await provider_enabled(ModelProviderName.openai)
+    assert result is False
+    mock_config.assert_called_with("open_ai_api_key")
+
+
+@pytest.mark.asyncio
+async def test_provider_enabled_unknown_provider():
+    # Test with a provider that isn't in provider_warnings
+    result = await provider_enabled("unknown_provider")
+    assert result is False
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_model_no_provider():
+    with pytest.raises(ValueError) as exc_info:
+        await kiln_model_provider_from("custom_model")
+    assert str(exc_info.value) == "Provider name is required for custom models"
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_invalid_provider():
+    with pytest.raises(ValueError) as exc_info:
+        await kiln_model_provider_from("custom_model", "invalid_provider")
+    assert str(exc_info.value) == "Invalid provider name: invalid_provider"
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_model_valid(mock_config):
+    # Mock config to pass provider warnings check
+    mock_config.return_value = "fake-api-key"
+
+    provider = await kiln_model_provider_from("custom_model", ModelProviderName.openai)
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.supports_structured_output is False
+    assert provider.supports_data_gen is False
+    assert provider.untested_model is True
+    assert "model" in provider.provider_options
+    assert provider.provider_options["model"] == "custom_model"
+
+
+def test_provider_options_for_custom_model_basic():
+    """Test basic case with custom model name"""
+    options = provider_options_for_custom_model(
+        "custom_model_name", ModelProviderName.openai
+    )
+    assert options == {"model": "custom_model_name"}
+
+
+def test_provider_options_for_custom_model_bedrock():
+    """Test Amazon Bedrock provider options"""
+    options = provider_options_for_custom_model(
+        ModelName.llama_3_1_8b, ModelProviderName.amazon_bedrock
+    )
+    assert options == {"model": ModelName.llama_3_1_8b, "region_name": "us-west-2"}
+
+
+@pytest.mark.parametrize(
+    "provider",
+    [
+        ModelProviderName.openai,
+        ModelProviderName.ollama,
+        ModelProviderName.fireworks_ai,
+        ModelProviderName.openrouter,
+        ModelProviderName.groq,
+    ],
+)
+def test_provider_options_for_custom_model_simple_providers(provider):
+    """Test providers that just need model name"""
+
+    options = provider_options_for_custom_model(ModelName.llama_3_1_8b, provider)
+    assert options == {"model": ModelName.llama_3_1_8b}
+
+
+def test_provider_options_for_custom_model_kiln_fine_tune():
+    """Test that kiln_fine_tune raises appropriate error"""
+    with pytest.raises(ValueError) as exc_info:
+        provider_options_for_custom_model(
+            "model_name", ModelProviderName.kiln_fine_tune
+        )
+    assert (
+        str(exc_info.value)
+        == "Fine tuned models should populate provider options via another path"
+    )
+
+
+def test_provider_options_for_custom_model_invalid_enum():
+    """Test handling of invalid enum value"""
+    with pytest.raises(ValueError):
+        provider_options_for_custom_model("model_name", "invalid_enum_value")
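
The tests above also pin down how provider_tools treats custom, untested models. A hedged sketch of that path follows, using only calls those tests exercise; the wrapper function is illustrative and assumes an OpenAI API key is already stored in Kiln's config (otherwise the provider-warning check appears to raise ValueError).

# Sketch mirroring test_kiln_model_provider_from_custom_model_valid above;
# not documented API, just the calls the test suite exercises.
from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.provider_tools import (
    kiln_model_provider_from,
    provider_options_for_custom_model,
)


async def describe_custom_model(model_id: str):
    provider = await kiln_model_provider_from(model_id, ModelProviderName.openai)
    # Custom models come back flagged as untested, with conservative capabilities.
    assert provider.untested_model is True
    assert provider.supports_structured_output is False
    # Provider options for a custom OpenAI model reduce to the model id itself.
    return provider_options_for_custom_model(model_id, ModelProviderName.openai)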
kiln_ai/adapters/test_structured_output.py (+22 -43)
@@ -6,12 +6,12 @@ import jsonschema.exceptions
 import pytest
 
 import kiln_ai.datamodel as datamodel
+from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
-from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import (
     built_in_models,
-    ollama_online,
 )
+from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
@@ -20,23 +20,6 @@ from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema
 
 
-@pytest.mark.parametrize(
-    "model_name,provider",
-    [
-        ("llama_3_1_8b", "groq"),
-        ("mistral_nemo", "openrouter"),
-        ("llama_3_1_70b", "amazon_bedrock"),
-        ("claude_3_5_sonnet", "openrouter"),
-        ("gemini_1_5_pro", "openrouter"),
-        ("gemini_1_5_flash", "openrouter"),
-        ("gemini_1_5_flash_8b", "openrouter"),
-    ],
-)
-@pytest.mark.paid
-async def test_structured_output(tmp_path, model_name, provider):
-    await run_structured_output_test(tmp_path, model_name, provider)
-
-
 @pytest.mark.ollama
 async def test_structured_output_ollama_phi(tmp_path):
     # https://python.langchain.com/v0.2/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs
@@ -112,28 +95,27 @@ async def test_mock_unstructred_response(tmp_path):
 
 @pytest.mark.paid
 @pytest.mark.ollama
-async def test_all_built_in_models_structured_output(tmp_path):
-    errors = []
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_built_in_models_structured_output(
+    tmp_path, model_name, provider_name
+):
     for model in built_in_models:
+        if model.name != model_name:
+            continue
         if not model.supports_structured_output:
-            print(
+            pytest.skip(
                 f"Skipping {model.name} because it does not support structured output"
             )
-            continue
         for provider in model.providers:
+            if provider.name != provider_name:
+                continue
            if not provider.supports_structured_output:
-                print(
+                pytest.skip(
                    f"Skipping {model.name} {provider.name} because it does not support structured output"
                )
-                continue
-            try:
-                print(f"Running {model.name} {provider.name}")
-                await run_structured_output_test(tmp_path, model.name, provider.name)
-            except Exception as e:
-                print(f"Error running {model.name} {provider.name}")
-                errors.append(f"{model.name} {provider.name}: {e}")
-    if len(errors) > 0:
-        raise RuntimeError(f"Errors: {errors}")
+            await run_structured_output_test(tmp_path, model.name, provider.name)
+            return
+    raise RuntimeError(f"No model {model_name} {provider_name} found")
 
 
 def build_structured_output_test_task(tmp_path: Path):
@@ -157,7 +139,7 @@ def build_structured_output_test_task(tmp_path: Path):
 
 async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_output_test_task(tmp_path)
-    a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
+    a = adapter_for_task(task, model_name=model_name, provider=provider)
     parsed = await a.invoke_returning_raw("Cows")  # a joke about cows
     if parsed is None or not isinstance(parsed, Dict):
         raise RuntimeError(f"structured response is not a dict: {parsed}")
@@ -204,7 +186,7 @@ async def run_structured_input_task(
     provider: str,
     pb: BasePromptBuilder | None = None,
 ):
-    a = LangChainPromptAdapter(
+    a = adapter_for_task(
         task, model_name=model_name, provider=provider, prompt_builder=pb
     )
     with pytest.raises(ValueError):
@@ -235,14 +217,11 @@ async def test_structured_input_gpt_4o_mini(tmp_path):
 
 @pytest.mark.paid
 @pytest.mark.ollama
-async def test_all_built_in_models_structured_input(tmp_path):
-    for model in built_in_models:
-        for provider in model.providers:
-            try:
-                print(f"Running {model.name} {provider.name}")
-                await run_structured_input_test(tmp_path, model.name, provider.name)
-            except Exception as e:
-                raise RuntimeError(f"Error running {model.name} {provider}") from e
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_built_in_models_structured_input(
+    tmp_path, model_name, provider_name
+):
+    await run_structured_input_test(tmp_path, model_name, provider_name)
 
 
 @pytest.mark.paid
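
A note on the refactor above: swapping the hand-rolled loops for @pytest.mark.parametrize over get_all_models_and_providers() turns each model/provider pair into its own test case, so one failing combination no longer hides the others and single pairs can be selected with pytest -k. A minimal, self-contained illustration of the pattern (the pair list here is invented; the real tests take it from get_all_models_and_providers()):

import pytest

# Illustrative stand-in data; the real suite builds this list dynamically.
MODEL_PROVIDER_PAIRS = [
    ("llama_3_1_8b", "groq"),
    ("phi_3_5", "ollama"),
]


@pytest.mark.parametrize("model_name,provider_name", MODEL_PROVIDER_PAIRS)
def test_each_pair_reports_separately(model_name, provider_name):
    # pytest generates one test id per pair, e.g.
    # test_each_pair_reports_separately[llama_3_1_8b-groq]
    assert model_name and provider_name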