kiln-ai 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/adapter_registry.py +2 -0
- kiln_ai/adapters/base_adapter.py +6 -1
- kiln_ai/adapters/langchain_adapters.py +5 -1
- kiln_ai/adapters/ml_model_list.py +43 -12
- kiln_ai/adapters/ollama_tools.py +4 -3
- kiln_ai/adapters/provider_tools.py +63 -2
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/test_langchain_adapter.py +183 -0
- kiln_ai/adapters/test_provider_tools.py +315 -1
- kiln_ai/datamodel/__init__.py +162 -19
- kiln_ai/datamodel/basemodel.py +90 -42
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/test_basemodel.py +138 -3
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +173 -0
- kiln_ai/datamodel/test_output_rating.py +377 -10
- kiln_ai/utils/config.py +33 -10
- kiln_ai/utils/test_config.py +48 -0
- kiln_ai-0.8.0.dist-info/METADATA +237 -0
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/RECORD +23 -21
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/WHEEL +1 -1
- kiln_ai-0.7.0.dist-info/METADATA +0 -90
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -9,6 +9,7 @@ def adapter_for_task(
|
|
|
9
9
|
model_name: str | None = None,
|
|
10
10
|
provider: str | None = None,
|
|
11
11
|
prompt_builder: BasePromptBuilder | None = None,
|
|
12
|
+
tags: list[str] | None = None,
|
|
12
13
|
) -> BaseAdapter:
|
|
13
14
|
# We use langchain for everything right now, but can add any others here
|
|
14
15
|
return LangchainAdapter(
|
|
@@ -16,4 +17,5 @@ def adapter_for_task(
|
|
|
16
17
|
model_name=model_name,
|
|
17
18
|
provider=provider,
|
|
18
19
|
prompt_builder=prompt_builder,
|
|
20
|
+
tags=tags,
|
|
19
21
|
)
|
kiln_ai/adapters/base_adapter.py
CHANGED
|
@@ -45,12 +45,16 @@ class BaseAdapter(metaclass=ABCMeta):
|
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
47
|
def __init__(
|
|
48
|
-
self,
|
|
48
|
+
self,
|
|
49
|
+
kiln_task: Task,
|
|
50
|
+
prompt_builder: BasePromptBuilder | None = None,
|
|
51
|
+
tags: list[str] | None = None,
|
|
49
52
|
):
|
|
50
53
|
self.prompt_builder = prompt_builder or SimplePromptBuilder(kiln_task)
|
|
51
54
|
self.kiln_task = kiln_task
|
|
52
55
|
self.output_schema = self.kiln_task.output_json_schema
|
|
53
56
|
self.input_schema = self.kiln_task.input_json_schema
|
|
57
|
+
self.default_tags = tags
|
|
54
58
|
|
|
55
59
|
async def invoke_returning_raw(
|
|
56
60
|
self,
|
|
@@ -148,6 +152,7 @@ class BaseAdapter(metaclass=ABCMeta):
|
|
|
148
152
|
),
|
|
149
153
|
),
|
|
150
154
|
intermediate_outputs=run_output.intermediate_outputs,
|
|
155
|
+
tags=self.default_tags or [],
|
|
151
156
|
)
|
|
152
157
|
|
|
153
158
|
exclude_fields = {
|
|
@@ -39,8 +39,9 @@ class LangchainAdapter(BaseAdapter):
|
|
|
39
39
|
model_name: str | None = None,
|
|
40
40
|
provider: str | None = None,
|
|
41
41
|
prompt_builder: BasePromptBuilder | None = None,
|
|
42
|
+
tags: list[str] | None = None,
|
|
42
43
|
):
|
|
43
|
-
super().__init__(kiln_task, prompt_builder=prompt_builder)
|
|
44
|
+
super().__init__(kiln_task, prompt_builder=prompt_builder, tags=tags)
|
|
44
45
|
if custom_model is not None:
|
|
45
46
|
self._model = custom_model
|
|
46
47
|
|
|
@@ -198,6 +199,9 @@ async def langchain_model_from_provider(
|
|
|
198
199
|
if provider.name == ModelProviderName.openai:
|
|
199
200
|
api_key = Config.shared().open_ai_api_key
|
|
200
201
|
return ChatOpenAI(**provider.provider_options, openai_api_key=api_key) # type: ignore[arg-type]
|
|
202
|
+
elif provider.name == ModelProviderName.openai_compatible:
|
|
203
|
+
# See provider_tools.py for how base_url, key and other parameters are set
|
|
204
|
+
return ChatOpenAI(**provider.provider_options) # type: ignore[arg-type]
|
|
201
205
|
elif provider.name == ModelProviderName.groq:
|
|
202
206
|
api_key = Config.shared().groq_api_key
|
|
203
207
|
if api_key is None:
|
|
@@ -22,6 +22,8 @@ class ModelProviderName(str, Enum):
|
|
|
22
22
|
openrouter = "openrouter"
|
|
23
23
|
fireworks_ai = "fireworks_ai"
|
|
24
24
|
kiln_fine_tune = "kiln_fine_tune"
|
|
25
|
+
kiln_custom_registry = "kiln_custom_registry"
|
|
26
|
+
openai_compatible = "openai_compatible"
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
class ModelFamily(str, Enum):
|
|
@@ -54,6 +56,7 @@ class ModelName(str, Enum):
|
|
|
54
56
|
llama_3_2_3b = "llama_3_2_3b"
|
|
55
57
|
llama_3_2_11b = "llama_3_2_11b"
|
|
56
58
|
llama_3_2_90b = "llama_3_2_90b"
|
|
59
|
+
llama_3_3_70b = "llama_3_3_70b"
|
|
57
60
|
gpt_4o_mini = "gpt_4o_mini"
|
|
58
61
|
gpt_4o = "gpt_4o"
|
|
59
62
|
phi_3_5 = "phi_3_5"
|
|
@@ -502,6 +505,46 @@ built_in_models: List[KilnModel] = [
|
|
|
502
505
|
),
|
|
503
506
|
],
|
|
504
507
|
),
|
|
508
|
+
# Llama 3.3 70B
|
|
509
|
+
KilnModel(
|
|
510
|
+
family=ModelFamily.llama,
|
|
511
|
+
name=ModelName.llama_3_3_70b,
|
|
512
|
+
friendly_name="Llama 3.3 70B",
|
|
513
|
+
providers=[
|
|
514
|
+
KilnModelProvider(
|
|
515
|
+
name=ModelProviderName.openrouter,
|
|
516
|
+
provider_options={"model": "meta-llama/llama-3.3-70b-instruct"},
|
|
517
|
+
# Openrouter not supporing tools yet. Once they do probably can remove. JSON mode sometimes works, but not consistently.
|
|
518
|
+
supports_structured_output=False,
|
|
519
|
+
supports_data_gen=False,
|
|
520
|
+
adapter_options={
|
|
521
|
+
"langchain": {
|
|
522
|
+
"with_structured_output_options": {"method": "json_mode"}
|
|
523
|
+
}
|
|
524
|
+
},
|
|
525
|
+
),
|
|
526
|
+
KilnModelProvider(
|
|
527
|
+
name=ModelProviderName.groq,
|
|
528
|
+
supports_structured_output=True,
|
|
529
|
+
supports_data_gen=True,
|
|
530
|
+
provider_options={"model": "llama-3.3-70b-versatile"},
|
|
531
|
+
),
|
|
532
|
+
KilnModelProvider(
|
|
533
|
+
name=ModelProviderName.ollama,
|
|
534
|
+
provider_options={"model": "llama3.3"},
|
|
535
|
+
),
|
|
536
|
+
KilnModelProvider(
|
|
537
|
+
name=ModelProviderName.fireworks_ai,
|
|
538
|
+
# Finetuning not live yet
|
|
539
|
+
# provider_finetune_id="accounts/fireworks/models/llama-v3p3-70b-instruct",
|
|
540
|
+
supports_structured_output=True,
|
|
541
|
+
supports_data_gen=True,
|
|
542
|
+
provider_options={
|
|
543
|
+
"model": "accounts/fireworks/models/llama-v3p3-70b-instruct"
|
|
544
|
+
},
|
|
545
|
+
),
|
|
546
|
+
],
|
|
547
|
+
),
|
|
505
548
|
# Phi 3.5
|
|
506
549
|
KilnModel(
|
|
507
550
|
family=ModelFamily.phi,
|
|
@@ -598,18 +641,6 @@ built_in_models: List[KilnModel] = [
|
|
|
598
641
|
name=ModelName.mixtral_8x7b,
|
|
599
642
|
friendly_name="Mixtral 8x7B",
|
|
600
643
|
providers=[
|
|
601
|
-
KilnModelProvider(
|
|
602
|
-
name=ModelProviderName.fireworks_ai,
|
|
603
|
-
provider_options={
|
|
604
|
-
"model": "accounts/fireworks/models/mixtral-8x7b-instruct-hf",
|
|
605
|
-
},
|
|
606
|
-
provider_finetune_id="accounts/fireworks/models/mixtral-8x7b-instruct-hf",
|
|
607
|
-
adapter_options={
|
|
608
|
-
"langchain": {
|
|
609
|
-
"with_structured_output_options": {"method": "json_mode"}
|
|
610
|
-
}
|
|
611
|
-
},
|
|
612
|
-
),
|
|
613
644
|
KilnModelProvider(
|
|
614
645
|
name=ModelProviderName.openrouter,
|
|
615
646
|
provider_options={"model": "mistralai/mixtral-8x7b-instruct"},
|
kiln_ai/adapters/ollama_tools.py
CHANGED
|
@@ -6,6 +6,7 @@ import requests
|
|
|
6
6
|
from pydantic import BaseModel, Field
|
|
7
7
|
|
|
8
8
|
from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
|
|
9
|
+
from kiln_ai.utils.config import Config
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def ollama_base_url() -> str:
|
|
@@ -16,9 +17,9 @@ def ollama_base_url() -> str:
|
|
|
16
17
|
The base URL to use for Ollama API calls, using environment variable if set
|
|
17
18
|
or falling back to localhost default
|
|
18
19
|
"""
|
|
19
|
-
|
|
20
|
-
if
|
|
21
|
-
return
|
|
20
|
+
config_base_url = Config.shared().ollama_base_url
|
|
21
|
+
if config_base_url:
|
|
22
|
+
return config_base_url
|
|
22
23
|
return "http://localhost:11434"
|
|
23
24
|
|
|
24
25
|
|
|
@@ -11,6 +11,7 @@ from kiln_ai.adapters.ml_model_list import (
|
|
|
11
11
|
from kiln_ai.adapters.ollama_tools import (
|
|
12
12
|
get_ollama_connection,
|
|
13
13
|
)
|
|
14
|
+
from kiln_ai.datamodel import Finetune, Task
|
|
14
15
|
from kiln_ai.datamodel.registry import project_from_id
|
|
15
16
|
|
|
16
17
|
from ..utils.config import Config
|
|
@@ -107,10 +108,18 @@ async def kiln_model_provider_from(
|
|
|
107
108
|
if provider_name == ModelProviderName.kiln_fine_tune:
|
|
108
109
|
return finetune_provider_model(name)
|
|
109
110
|
|
|
111
|
+
if provider_name == ModelProviderName.openai_compatible:
|
|
112
|
+
return openai_compatible_provider_model(name)
|
|
113
|
+
|
|
110
114
|
built_in_model = await builtin_model_from(name, provider_name)
|
|
111
115
|
if built_in_model:
|
|
112
116
|
return built_in_model
|
|
113
117
|
|
|
118
|
+
# For custom registry, get the provider name and model name from the model id
|
|
119
|
+
if provider_name == ModelProviderName.kiln_custom_registry:
|
|
120
|
+
provider_name = name.split("::", 1)[0]
|
|
121
|
+
name = name.split("::", 1)[1]
|
|
122
|
+
|
|
114
123
|
# Custom/untested model. Set untested, and build a ModelProvider at runtime
|
|
115
124
|
if provider_name is None:
|
|
116
125
|
raise ValueError("Provider name is required for custom models")
|
|
@@ -130,6 +139,45 @@ async def kiln_model_provider_from(
|
|
|
130
139
|
finetune_cache: dict[str, KilnModelProvider] = {}
|
|
131
140
|
|
|
132
141
|
|
|
142
|
+
def openai_compatible_provider_model(
|
|
143
|
+
model_id: str,
|
|
144
|
+
) -> KilnModelProvider:
|
|
145
|
+
try:
|
|
146
|
+
openai_provider_name, model_id = model_id.split("::")
|
|
147
|
+
except Exception:
|
|
148
|
+
raise ValueError(f"Invalid openai compatible model ID: {model_id}")
|
|
149
|
+
|
|
150
|
+
openai_compatible_providers = Config.shared().openai_compatible_providers or []
|
|
151
|
+
provider = next(
|
|
152
|
+
filter(
|
|
153
|
+
lambda p: p.get("name") == openai_provider_name, openai_compatible_providers
|
|
154
|
+
),
|
|
155
|
+
None,
|
|
156
|
+
)
|
|
157
|
+
if provider is None:
|
|
158
|
+
raise ValueError(f"OpenAI compatible provider {openai_provider_name} not found")
|
|
159
|
+
|
|
160
|
+
# API key optional some providers don't use it
|
|
161
|
+
api_key = provider.get("api_key")
|
|
162
|
+
base_url = provider.get("base_url")
|
|
163
|
+
if base_url is None:
|
|
164
|
+
raise ValueError(
|
|
165
|
+
f"OpenAI compatible provider {openai_provider_name} has no base URL"
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
return KilnModelProvider(
|
|
169
|
+
name=ModelProviderName.openai_compatible,
|
|
170
|
+
provider_options={
|
|
171
|
+
"model": model_id,
|
|
172
|
+
"api_key": api_key,
|
|
173
|
+
"openai_api_base": base_url,
|
|
174
|
+
},
|
|
175
|
+
supports_structured_output=False,
|
|
176
|
+
supports_data_gen=False,
|
|
177
|
+
untested_model=True,
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
|
|
133
181
|
def finetune_provider_model(
|
|
134
182
|
model_id: str,
|
|
135
183
|
) -> KilnModelProvider:
|
|
@@ -143,10 +191,10 @@ def finetune_provider_model(
|
|
|
143
191
|
project = project_from_id(project_id)
|
|
144
192
|
if project is None:
|
|
145
193
|
raise ValueError(f"Project {project_id} not found")
|
|
146
|
-
task =
|
|
194
|
+
task = Task.from_id_and_parent_path(task_id, project.path)
|
|
147
195
|
if task is None:
|
|
148
196
|
raise ValueError(f"Task {task_id} not found")
|
|
149
|
-
fine_tune =
|
|
197
|
+
fine_tune = Finetune.from_id_and_parent_path(fine_tune_id, task.path)
|
|
150
198
|
if fine_tune is None:
|
|
151
199
|
raise ValueError(f"Fine tune {fine_tune_id} not found")
|
|
152
200
|
if fine_tune.fine_tune_model_id is None:
|
|
@@ -220,6 +268,10 @@ def provider_name_from_id(id: str) -> str:
|
|
|
220
268
|
return "Fine Tuned Models"
|
|
221
269
|
case ModelProviderName.fireworks_ai:
|
|
222
270
|
return "Fireworks AI"
|
|
271
|
+
case ModelProviderName.kiln_custom_registry:
|
|
272
|
+
return "Custom Models"
|
|
273
|
+
case ModelProviderName.openai_compatible:
|
|
274
|
+
return "OpenAI Compatible"
|
|
223
275
|
case _:
|
|
224
276
|
# triggers pyright warning if I miss a case
|
|
225
277
|
raise_exhaustive_error(enum_id)
|
|
@@ -233,6 +285,7 @@ def provider_options_for_custom_model(
|
|
|
233
285
|
"""
|
|
234
286
|
Generated model provider options for a custom model. Each has their own format/options.
|
|
235
287
|
"""
|
|
288
|
+
|
|
236
289
|
if provider_name not in ModelProviderName.__members__:
|
|
237
290
|
raise ValueError(f"Invalid provider name: {provider_name}")
|
|
238
291
|
|
|
@@ -249,10 +302,18 @@ def provider_options_for_custom_model(
|
|
|
249
302
|
| ModelProviderName.groq
|
|
250
303
|
):
|
|
251
304
|
return {"model": model_name}
|
|
305
|
+
case ModelProviderName.kiln_custom_registry:
|
|
306
|
+
raise ValueError(
|
|
307
|
+
"Custom models from registry should be parsed into provider/model before calling this."
|
|
308
|
+
)
|
|
252
309
|
case ModelProviderName.kiln_fine_tune:
|
|
253
310
|
raise ValueError(
|
|
254
311
|
"Fine tuned models should populate provider options via another path"
|
|
255
312
|
)
|
|
313
|
+
case ModelProviderName.openai_compatible:
|
|
314
|
+
raise ValueError(
|
|
315
|
+
"OpenAI compatible models should populate provider options via another path"
|
|
316
|
+
)
|
|
256
317
|
case _:
|
|
257
318
|
# triggers pyright warning if I miss a case
|
|
258
319
|
raise_exhaustive_error(enum_id)
|
|
@@ -43,8 +43,10 @@ feedback describing what should be improved. Your job is to understand the evalu
|
|
|
43
43
|
@classmethod
|
|
44
44
|
def _original_prompt(cls, run: TaskRun, task: Task) -> str:
|
|
45
45
|
prompt_builder_class: Type[BasePromptBuilder] | None = None
|
|
46
|
-
prompt_builder_name =
|
|
47
|
-
"prompt_builder_name", None
|
|
46
|
+
prompt_builder_name = (
|
|
47
|
+
run.output.source.properties.get("prompt_builder_name", None)
|
|
48
|
+
if run.output.source
|
|
49
|
+
else None
|
|
48
50
|
)
|
|
49
51
|
if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
|
|
50
52
|
prompt_builder_class = prompt_builder_registry.get(
|
|
@@ -1,12 +1,20 @@
|
|
|
1
|
+
import os
|
|
1
2
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
2
3
|
|
|
4
|
+
import pytest
|
|
5
|
+
from langchain_aws import ChatBedrockConverse
|
|
3
6
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
7
|
+
from langchain_fireworks import ChatFireworks
|
|
4
8
|
from langchain_groq import ChatGroq
|
|
9
|
+
from langchain_ollama import ChatOllama
|
|
10
|
+
from langchain_openai import ChatOpenAI
|
|
5
11
|
|
|
6
12
|
from kiln_ai.adapters.langchain_adapters import (
|
|
7
13
|
LangchainAdapter,
|
|
8
14
|
get_structured_output_options,
|
|
15
|
+
langchain_model_from_provider,
|
|
9
16
|
)
|
|
17
|
+
from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
|
|
10
18
|
from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
|
|
11
19
|
from kiln_ai.adapters.test_prompt_adaptors import build_test_task
|
|
12
20
|
|
|
@@ -150,3 +158,178 @@ async def test_get_structured_output_options():
|
|
|
150
158
|
):
|
|
151
159
|
options = await get_structured_output_options("model_name", "provider")
|
|
152
160
|
assert options == {}
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@pytest.mark.asyncio
|
|
164
|
+
async def test_langchain_model_from_provider_openai():
|
|
165
|
+
provider = KilnModelProvider(
|
|
166
|
+
name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
|
|
170
|
+
mock_config.return_value.open_ai_api_key = "test_key"
|
|
171
|
+
model = await langchain_model_from_provider(provider, "gpt-4")
|
|
172
|
+
assert isinstance(model, ChatOpenAI)
|
|
173
|
+
assert model.model_name == "gpt-4"
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@pytest.mark.asyncio
|
|
177
|
+
async def test_langchain_model_from_provider_groq():
|
|
178
|
+
provider = KilnModelProvider(
|
|
179
|
+
name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
|
|
183
|
+
mock_config.return_value.groq_api_key = "test_key"
|
|
184
|
+
model = await langchain_model_from_provider(provider, "mixtral-8x7b")
|
|
185
|
+
assert isinstance(model, ChatGroq)
|
|
186
|
+
assert model.model_name == "mixtral-8x7b"
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@pytest.mark.asyncio
|
|
190
|
+
async def test_langchain_model_from_provider_bedrock():
|
|
191
|
+
provider = KilnModelProvider(
|
|
192
|
+
name=ModelProviderName.amazon_bedrock,
|
|
193
|
+
provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
|
|
197
|
+
mock_config.return_value.bedrock_access_key = "test_access"
|
|
198
|
+
mock_config.return_value.bedrock_secret_key = "test_secret"
|
|
199
|
+
model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
|
|
200
|
+
assert isinstance(model, ChatBedrockConverse)
|
|
201
|
+
assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
|
|
202
|
+
assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@pytest.mark.asyncio
|
|
206
|
+
async def test_langchain_model_from_provider_fireworks():
|
|
207
|
+
provider = KilnModelProvider(
|
|
208
|
+
name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
|
|
212
|
+
mock_config.return_value.fireworks_api_key = "test_key"
|
|
213
|
+
model = await langchain_model_from_provider(provider, "mixtral-8x7b")
|
|
214
|
+
assert isinstance(model, ChatFireworks)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@pytest.mark.asyncio
|
|
218
|
+
async def test_langchain_model_from_provider_ollama():
|
|
219
|
+
provider = KilnModelProvider(
|
|
220
|
+
name=ModelProviderName.ollama,
|
|
221
|
+
provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
mock_connection = MagicMock()
|
|
225
|
+
with (
|
|
226
|
+
patch(
|
|
227
|
+
"kiln_ai.adapters.langchain_adapters.get_ollama_connection",
|
|
228
|
+
return_value=AsyncMock(return_value=mock_connection),
|
|
229
|
+
),
|
|
230
|
+
patch(
|
|
231
|
+
"kiln_ai.adapters.langchain_adapters.ollama_model_installed",
|
|
232
|
+
return_value=True,
|
|
233
|
+
),
|
|
234
|
+
patch(
|
|
235
|
+
"kiln_ai.adapters.langchain_adapters.ollama_base_url",
|
|
236
|
+
return_value="http://localhost:11434",
|
|
237
|
+
),
|
|
238
|
+
):
|
|
239
|
+
model = await langchain_model_from_provider(provider, "llama2")
|
|
240
|
+
assert isinstance(model, ChatOllama)
|
|
241
|
+
assert model.model == "llama2"
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
@pytest.mark.asyncio
|
|
245
|
+
async def test_langchain_model_from_provider_invalid():
|
|
246
|
+
provider = KilnModelProvider.model_construct(
|
|
247
|
+
name="invalid_provider", provider_options={}
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
with pytest.raises(ValueError, match="Invalid model or provider"):
|
|
251
|
+
await langchain_model_from_provider(provider, "test_model")
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@pytest.mark.asyncio
|
|
255
|
+
async def test_langchain_adapter_model_caching(tmp_path):
|
|
256
|
+
task = build_test_task(tmp_path)
|
|
257
|
+
custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
|
|
258
|
+
|
|
259
|
+
adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
|
|
260
|
+
|
|
261
|
+
# First call should return the cached model
|
|
262
|
+
model1 = await adapter.model()
|
|
263
|
+
assert model1 is custom_model
|
|
264
|
+
|
|
265
|
+
# Second call should return the same cached instance
|
|
266
|
+
model2 = await adapter.model()
|
|
267
|
+
assert model2 is model1
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@pytest.mark.asyncio
|
|
271
|
+
async def test_langchain_adapter_model_structured_output(tmp_path):
|
|
272
|
+
task = build_test_task(tmp_path)
|
|
273
|
+
task.output_json_schema = """
|
|
274
|
+
{
|
|
275
|
+
"type": "object",
|
|
276
|
+
"properties": {
|
|
277
|
+
"count": {"type": "integer"}
|
|
278
|
+
}
|
|
279
|
+
}
|
|
280
|
+
"""
|
|
281
|
+
|
|
282
|
+
mock_model = MagicMock()
|
|
283
|
+
mock_model.with_structured_output = MagicMock(return_value="structured_model")
|
|
284
|
+
|
|
285
|
+
adapter = LangchainAdapter(
|
|
286
|
+
kiln_task=task, model_name="test_model", provider="test_provider"
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
with (
|
|
290
|
+
patch(
|
|
291
|
+
"kiln_ai.adapters.langchain_adapters.langchain_model_from",
|
|
292
|
+
AsyncMock(return_value=mock_model),
|
|
293
|
+
),
|
|
294
|
+
patch(
|
|
295
|
+
"kiln_ai.adapters.langchain_adapters.get_structured_output_options",
|
|
296
|
+
AsyncMock(return_value={"option1": "value1"}),
|
|
297
|
+
),
|
|
298
|
+
):
|
|
299
|
+
model = await adapter.model()
|
|
300
|
+
|
|
301
|
+
# Verify the model was configured with structured output
|
|
302
|
+
mock_model.with_structured_output.assert_called_once_with(
|
|
303
|
+
{
|
|
304
|
+
"type": "object",
|
|
305
|
+
"properties": {"count": {"type": "integer"}},
|
|
306
|
+
"title": "task_response",
|
|
307
|
+
"description": "A response from the task",
|
|
308
|
+
},
|
|
309
|
+
include_raw=True,
|
|
310
|
+
option1="value1",
|
|
311
|
+
)
|
|
312
|
+
assert model == "structured_model"
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
@pytest.mark.asyncio
|
|
316
|
+
async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
|
|
317
|
+
task = build_test_task(tmp_path)
|
|
318
|
+
task.output_json_schema = (
|
|
319
|
+
'{"type": "object", "properties": {"count": {"type": "integer"}}}'
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
mock_model = MagicMock()
|
|
323
|
+
# Remove with_structured_output method
|
|
324
|
+
del mock_model.with_structured_output
|
|
325
|
+
|
|
326
|
+
adapter = LangchainAdapter(
|
|
327
|
+
kiln_task=task, model_name="test_model", provider="test_provider"
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
with patch(
|
|
331
|
+
"kiln_ai.adapters.langchain_adapters.langchain_model_from",
|
|
332
|
+
AsyncMock(return_value=mock_model),
|
|
333
|
+
):
|
|
334
|
+
with pytest.raises(ValueError, match="does not support structured output"):
|
|
335
|
+
await adapter.model()
|