kiln-ai 0.7.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/ml_model_list.py +34 -12
- kiln_ai/adapters/ollama_tools.py +4 -3
- kiln_ai/adapters/provider_tools.py +15 -2
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/test_langchain_adapter.py +183 -0
- kiln_ai/adapters/test_provider_tools.py +220 -1
- kiln_ai/datamodel/__init__.py +55 -5
- kiln_ai/datamodel/basemodel.py +92 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/test_basemodel.py +138 -3
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +124 -0
- kiln_ai/utils/config.py +5 -1
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.7.1.dist-info}/RECORD +17 -15
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai-0.7.0.dist-info/METADATA +0 -90
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ml_model_list.py CHANGED

@@ -22,6 +22,7 @@ class ModelProviderName(str, Enum):
     openrouter = "openrouter"
     fireworks_ai = "fireworks_ai"
     kiln_fine_tune = "kiln_fine_tune"
+    kiln_custom_registry = "kiln_custom_registry"


 class ModelFamily(str, Enum):
@@ -54,6 +55,7 @@ class ModelName(str, Enum):
     llama_3_2_3b = "llama_3_2_3b"
     llama_3_2_11b = "llama_3_2_11b"
     llama_3_2_90b = "llama_3_2_90b"
+    llama_3_3_70b = "llama_3_3_70b"
     gpt_4o_mini = "gpt_4o_mini"
     gpt_4o = "gpt_4o"
     phi_3_5 = "phi_3_5"
@@ -502,6 +504,38 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Llama 3.3 70B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_3_70b,
+        friendly_name="Llama 3.3 70B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.3-70b-instruct"},
+                # OpenRouter not supporting tools yet. Once they do, this can probably be removed. JSON mode sometimes works, but not consistently.
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.3"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # Finetuning not live yet
+                # provider_finetune_id="accounts/fireworks/models/llama-v3p3-70b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p3-70b-instruct"
+                },
+            ),
+        ],
+    ),
     # Phi 3.5
     KilnModel(
         family=ModelFamily.phi,
@@ -598,18 +632,6 @@ built_in_models: List[KilnModel] = [
         name=ModelName.mixtral_8x7b,
         friendly_name="Mixtral 8x7B",
         providers=[
-            KilnModelProvider(
-                name=ModelProviderName.fireworks_ai,
-                provider_options={
-                    "model": "accounts/fireworks/models/mixtral-8x7b-instruct-hf",
-                },
-                provider_finetune_id="accounts/fireworks/models/mixtral-8x7b-instruct-hf",
-                adapter_options={
-                    "langchain": {
-                        "with_structured_output_options": {"method": "json_mode"}
-                    }
-                },
-            ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "mistralai/mixtral-8x7b-instruct"},
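For context on how these registry entries are consumed, a minimal sketch (not part of the diff) that looks up the new Llama 3.3 70B entry and inspects its per-provider options, using only names shown in the hunks above:

    from kiln_ai.adapters.ml_model_list import ModelName, built_in_models

    # built_in_models is a flat list; each KilnModel carries one
    # KilnModelProvider entry per provider that hosts the model.
    llama_3_3 = next(m for m in built_in_models if m.name == ModelName.llama_3_3_70b)
    for provider in llama_3_3.providers:
        # e.g. openrouter -> "meta-llama/llama-3.3-70b-instruct",
        #      ollama     -> "llama3.3"
        print(provider.name, provider.provider_options["model"])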
kiln_ai/adapters/ollama_tools.py CHANGED

@@ -6,6 +6,7 @@ import requests
 from pydantic import BaseModel, Field

 from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
+from kiln_ai.utils.config import Config


 def ollama_base_url() -> str:

@@ -16,9 +17,9 @@ def ollama_base_url() -> str:
     The base URL to use for Ollama API calls, using environment variable if set
     or falling back to localhost default
     """
-
-    if
-    return
+    config_base_url = Config.shared().ollama_base_url
+    if config_base_url:
+        return config_base_url
     return "http://localhost:11434"
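A minimal sketch of the new resolution order (shared config value first, then the localhost default). The in-process assignment to the shared Config object is an assumption for illustration, not something the diff shows:

    from kiln_ai.adapters.ollama_tools import ollama_base_url
    from kiln_ai.utils.config import Config

    # Assumption: the shared Config singleton allows assigning
    # ollama_base_url in-process for this demonstration.
    Config.shared().ollama_base_url = None
    print(ollama_base_url())  # -> "http://localhost:11434"

    Config.shared().ollama_base_url = "http://my-ollama-host:11434"
    print(ollama_base_url())  # -> "http://my-ollama-host:11434"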
kiln_ai/adapters/provider_tools.py CHANGED

@@ -11,6 +11,7 @@ from kiln_ai.adapters.ml_model_list import (
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
+from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.registry import project_from_id

 from ..utils.config import Config

@@ -111,6 +112,11 @@ async def kiln_model_provider_from(
     if built_in_model:
         return built_in_model

+    # For custom registry, get the provider name and model name from the model id
+    if provider_name == ModelProviderName.kiln_custom_registry:
+        provider_name = name.split("::", 1)[0]
+        name = name.split("::", 1)[1]
+
     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
         raise ValueError("Provider name is required for custom models")

@@ -143,10 +149,10 @@ def finetune_provider_model(
     project = project_from_id(project_id)
     if project is None:
         raise ValueError(f"Project {project_id} not found")
-    task =
+    task = Task.from_id_and_parent_path(task_id, project.path)
     if task is None:
         raise ValueError(f"Task {task_id} not found")
-    fine_tune =
+    fine_tune = Finetune.from_id_and_parent_path(fine_tune_id, task.path)
     if fine_tune is None:
         raise ValueError(f"Fine tune {fine_tune_id} not found")
     if fine_tune.fine_tune_model_id is None:

@@ -220,6 +226,8 @@ def provider_name_from_id(id: str) -> str:
             return "Fine Tuned Models"
         case ModelProviderName.fireworks_ai:
             return "Fireworks AI"
+        case ModelProviderName.kiln_custom_registry:
+            return "Custom Models"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_error(enum_id)

@@ -233,6 +241,7 @@ def provider_options_for_custom_model(
     """
     Generated model provider options for a custom model. Each has their own format/options.
     """
+
     if provider_name not in ModelProviderName.__members__:
         raise ValueError(f"Invalid provider name: {provider_name}")

@@ -249,6 +258,10 @@ def provider_options_for_custom_model(
             | ModelProviderName.groq
         ):
             return {"model": model_name}
+        case ModelProviderName.kiln_custom_registry:
+            raise ValueError(
+                "Custom models from registry should be parsed into provider/model before calling this."
+            )
         case ModelProviderName.kiln_fine_tune:
             raise ValueError(
                 "Fine tuned models should populate provider options via another path"
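The custom-registry path packs provider and model into one ID separated by "::". A minimal usage sketch, mirroring the test added further below and assuming an API key is configured for the underlying provider:

    import asyncio

    from kiln_ai.adapters.ml_model_list import ModelProviderName
    from kiln_ai.adapters.provider_tools import kiln_model_provider_from

    async def main():
        # "openai::gpt-4-turbo" splits into provider "openai" and model "gpt-4-turbo".
        provider = await kiln_model_provider_from(
            "openai::gpt-4-turbo", ModelProviderName.kiln_custom_registry
        )
        # Custom-registry models are marked untested and conservative:
        # structured output and data generation are disabled by default.
        print(provider.name, provider.provider_options, provider.untested_model)

    asyncio.run(main())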
kiln_ai/adapters/repair/repair_task.py CHANGED

@@ -43,8 +43,10 @@ feedback describing what should be improved. Your job is to understand the evalu
     @classmethod
     def _original_prompt(cls, run: TaskRun, task: Task) -> str:
         prompt_builder_class: Type[BasePromptBuilder] | None = None
-        prompt_builder_name =
-            "prompt_builder_name", None
+        prompt_builder_name = (
+            run.output.source.properties.get("prompt_builder_name", None)
+            if run.output.source
+            else None
         )
         if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
             prompt_builder_class = prompt_builder_registry.get(
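The fix above avoids an attribute error when a run's output has no source. A standalone sketch of the same guard, using hypothetical stand-ins for the run output objects rather than Kiln's own types:

    from types import SimpleNamespace

    def prompt_builder_name_from(output) -> str | None:
        # Mirrors the guarded lookup: no source means no recorded prompt builder.
        return (
            output.source.properties.get("prompt_builder_name", None)
            if output.source
            else None
        )

    with_source = SimpleNamespace(
        source=SimpleNamespace(properties={"prompt_builder_name": "simple_prompt_builder"})
    )
    without_source = SimpleNamespace(source=None)

    print(prompt_builder_name_from(with_source))     # -> "simple_prompt_builder"
    print(prompt_builder_name_from(without_source))  # -> None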
kiln_ai/adapters/test_langchain_adapter.py CHANGED

@@ -1,12 +1,20 @@
+import os
 from unittest.mock import AsyncMock, MagicMock, patch

+import pytest
+from langchain_aws import ChatBedrockConverse
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_fireworks import ChatFireworks
 from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI

 from kiln_ai.adapters.langchain_adapters import (
     LangchainAdapter,
     get_structured_output_options,
+    langchain_model_from_provider,
 )
+from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
 from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task

@@ -150,3 +158,178 @@ async def test_get_structured_output_options():
     ):
         options = await get_structured_output_options("model_name", "provider")
         assert options == {}
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_openai():
+    provider = KilnModelProvider(
+        name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.open_ai_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "gpt-4")
+        assert isinstance(model, ChatOpenAI)
+        assert model.model_name == "gpt-4"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_groq():
+    provider = KilnModelProvider(
+        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.groq_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatGroq)
+        assert model.model_name == "mixtral-8x7b"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_bedrock():
+    provider = KilnModelProvider(
+        name=ModelProviderName.amazon_bedrock,
+        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.bedrock_access_key = "test_access"
+        mock_config.return_value.bedrock_secret_key = "test_secret"
+        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
+        assert isinstance(model, ChatBedrockConverse)
+        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
+        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_fireworks():
+    provider = KilnModelProvider(
+        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.fireworks_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatFireworks)
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_ollama():
+    provider = KilnModelProvider(
+        name=ModelProviderName.ollama,
+        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
+    )
+
+    mock_connection = MagicMock()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_ollama_connection",
+            return_value=AsyncMock(return_value=mock_connection),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_model_installed",
+            return_value=True,
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_base_url",
+            return_value="http://localhost:11434",
+        ),
+    ):
+        model = await langchain_model_from_provider(provider, "llama2")
+        assert isinstance(model, ChatOllama)
+        assert model.model == "llama2"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_invalid():
+    provider = KilnModelProvider.model_construct(
+        name="invalid_provider", provider_options={}
+    )
+
+    with pytest.raises(ValueError, match="Invalid model or provider"):
+        await langchain_model_from_provider(provider, "test_model")
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_caching(tmp_path):
+    task = build_test_task(tmp_path)
+    custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
+
+    adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
+
+    # First call should return the cached model
+    model1 = await adapter.model()
+    assert model1 is custom_model
+
+    # Second call should return the same cached instance
+    model2 = await adapter.model()
+    assert model2 is model1
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_structured_output(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = """
+    {
+        "type": "object",
+        "properties": {
+            "count": {"type": "integer"}
+        }
+    }
+    """
+
+    mock_model = MagicMock()
+    mock_model.with_structured_output = MagicMock(return_value="structured_model")
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+            AsyncMock(return_value=mock_model),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_structured_output_options",
+            AsyncMock(return_value={"option1": "value1"}),
+        ),
+    ):
+        model = await adapter.model()
+
+        # Verify the model was configured with structured output
+        mock_model.with_structured_output.assert_called_once_with(
+            {
+                "type": "object",
+                "properties": {"count": {"type": "integer"}},
+                "title": "task_response",
+                "description": "A response from the task",
+            },
+            include_raw=True,
+            option1="value1",
+        )
+        assert model == "structured_model"
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = (
+        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
+    )
+
+    mock_model = MagicMock()
+    # Remove with_structured_output method
+    del mock_model.with_structured_output
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+        AsyncMock(return_value=mock_model),
+    ):
+        with pytest.raises(ValueError, match="does not support structured output"):
+            await adapter.model()
kiln_ai/adapters/test_provider_tools.py CHANGED

@@ -1,14 +1,18 @@
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, Mock, patch

 import pytest

 from kiln_ai.adapters.ml_model_list import (
+    KilnModel,
     ModelName,
     ModelProviderName,
 )
 from kiln_ai.adapters.ollama_tools import OllamaConnection
 from kiln_ai.adapters.provider_tools import (
+    builtin_model_from,
     check_provider_warnings,
+    finetune_cache,
+    finetune_provider_model,
     get_model_and_provider,
     kiln_model_provider_from,
     provider_enabled,

@@ -16,6 +20,14 @@ from kiln_ai.adapters.provider_tools import (
     provider_options_for_custom_model,
     provider_warnings,
 )
+from kiln_ai.datamodel import Finetune, Task
+
+
+@pytest.fixture(autouse=True)
+def clear_finetune_cache():
+    """Clear the finetune provider model cache before each test"""
+    finetune_cache.clear()
+    yield


 @pytest.fixture

@@ -24,6 +36,34 @@ def mock_config():
     yield mock


+@pytest.fixture
+def mock_project():
+    with patch("kiln_ai.adapters.provider_tools.project_from_id") as mock:
+        project = Mock()
+        project.path = "/fake/path"
+        mock.return_value = project
+        yield mock
+
+
+@pytest.fixture
+def mock_task():
+    with patch("kiln_ai.datamodel.Task.from_id_and_parent_path") as mock:
+        task = Mock(spec=Task)
+        task.path = "/fake/path/task"
+        mock.return_value = task
+        yield mock
+
+
+@pytest.fixture
+def mock_finetune():
+    with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
+        finetune = Mock(spec=Finetune)
+        finetune.provider = ModelProviderName.openai
+        finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
+        mock.return_value = finetune
+        yield mock
+
+
 def test_check_provider_warnings_no_warning(mock_config):
     mock_config.return_value = "some_value"

@@ -103,6 +143,8 @@ def test_provider_name_from_id_case_sensitivity():
         (ModelProviderName.ollama, "Ollama"),
         (ModelProviderName.openai, "OpenAI"),
         (ModelProviderName.fireworks_ai, "Fireworks AI"),
+        (ModelProviderName.kiln_fine_tune, "Fine Tuned Models"),
+        (ModelProviderName.kiln_custom_registry, "Custom Models"),
     ],
 )
 def test_provider_name_from_id_parametrized(provider_id, expected_name):

@@ -310,3 +352,180 @@ def test_provider_options_for_custom_model_invalid_enum():
     """Test handling of invalid enum value"""
     with pytest.raises(ValueError):
         provider_options_for_custom_model("model_name", "invalid_enum_value")
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_registry(mock_config):
+    # Mock config to pass provider warnings check
+    mock_config.return_value = "fake-api-key"
+
+    # Test with a custom registry model ID in format "provider::model_name"
+    provider = await kiln_model_provider_from(
+        "openai::gpt-4-turbo", ModelProviderName.kiln_custom_registry
+    )
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.supports_structured_output is False
+    assert provider.supports_data_gen is False
+    assert provider.untested_model is True
+    assert provider.provider_options == {"model": "gpt-4-turbo"}
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_invalid_model():
+    """Test that an invalid model name returns None"""
+    result = await builtin_model_from("non_existent_model")
+    assert result is None
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_valid_model_default_provider(mock_config):
+    """Test getting a valid model with default provider"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(ModelName.phi_3_5)
+
+    assert provider is not None
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_valid_model_specific_provider(mock_config):
+    """Test getting a valid model with specific provider"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(
+        ModelName.llama_3_1_70b, provider_name=ModelProviderName.groq
+    )
+
+    assert provider is not None
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_invalid_provider(mock_config):
+    """Test that requesting an invalid provider returns None"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(
+        ModelName.phi_3_5, provider_name="invalid_provider"
+    )
+
+    assert provider is None
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_model_no_providers():
+    """Test handling of a model with no providers"""
+    with patch("kiln_ai.adapters.provider_tools.built_in_models") as mock_models:
+        # Create a mock model with no providers
+        mock_model = KilnModel(
+            name=ModelName.phi_3_5,
+            friendly_name="Test Model",
+            providers=[],
+            family="test_family",
+        )
+        mock_models.__iter__.return_value = [mock_model]
+
+        with pytest.raises(ValueError) as exc_info:
+            await builtin_model_from(ModelName.phi_3_5)
+
+        assert str(exc_info.value) == f"Model {ModelName.phi_3_5} has no providers"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_provider_warning_check(mock_config):
+    """Test that provider warnings are checked"""
+    # Make the config check fail
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        await builtin_model_from(ModelName.llama_3_1_70b, ModelProviderName.groq)
+
+    assert provider_warnings[ModelProviderName.groq].message in str(exc_info.value)
+
+
+def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune):
+    """Test successful creation of a fine-tuned model provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    provider = finetune_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.provider_options == {"model": "ft:gpt-3.5-turbo:custom:model-123"}
+
+    # Test cache
+    cached_provider = finetune_provider_model(model_id)
+    assert cached_provider is provider
+
+
+def test_finetune_provider_model_invalid_id():
+    """Test handling of invalid model ID format"""
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("invalid-id-format")
+    assert str(exc_info.value) == "Invalid fine tune ID: invalid-id-format"
+
+
+def test_finetune_provider_model_project_not_found(mock_project):
+    """Test handling of non-existent project"""
+    mock_project.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Project project-123 not found"
+
+
+def test_finetune_provider_model_task_not_found(mock_project, mock_task):
+    """Test handling of non-existent task"""
+    mock_task.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Task task-456 not found"
+
+
+def test_finetune_provider_model_finetune_not_found(
+    mock_project, mock_task, mock_finetune
+):
+    """Test handling of non-existent fine-tune"""
+    mock_finetune.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Fine tune finetune-789 not found"
+
+
+def test_finetune_provider_model_incomplete_finetune(
+    mock_project, mock_task, mock_finetune
+):
+    """Test handling of incomplete fine-tune"""
+    finetune = Mock(spec=Finetune)
+    finetune.fine_tune_model_id = None
+    mock_finetune.return_value = finetune
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert (
+        str(exc_info.value)
+        == "Fine tune finetune-789 not completed. Refresh it's status in the fine-tune tab."
+    )
+
+
+def test_finetune_provider_model_fireworks_provider(
+    mock_project, mock_task, mock_finetune
+):
+    """Test creation of Fireworks AI provider with specific adapter options"""
+    finetune = Mock(spec=Finetune)
+    finetune.provider = ModelProviderName.fireworks_ai
+    finetune.fine_tune_model_id = "fireworks-model-123"
+    mock_finetune.return_value = finetune
+
+    provider = finetune_provider_model("project-123::task-456::finetune-789")
+
+    assert provider.name == ModelProviderName.fireworks_ai
+    assert provider.provider_options == {"model": "fireworks-model-123"}
+    assert provider.adapter_options == {
+        "langchain": {"with_structured_output_options": {"method": "json_mode"}}
+    }