kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of kiln-ai has been flagged for review.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +239 -303
- kiln_ai/adapters/ollama_tools.py +115 -0
- kiln_ai/adapters/provider_tools.py +308 -0
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +229 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +531 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +287 -24
- kiln_ai/datamodel/basemodel.py +122 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +167 -4
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +215 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +14 -1
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- kiln_ai-0.7.1.dist-info/RECORD +58 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/METADATA +0 -88
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
--- kiln_ai/adapters/test_langchain_adapter.py (0.6.1)
+++ kiln_ai/adapters/test_langchain_adapter.py (0.7.1)
@@ -1,18 +1,27 @@
+import os
 from unittest.mock import AsyncMock, MagicMock, patch
 
+import pytest
+from langchain_aws import ChatBedrockConverse
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_fireworks import ChatFireworks
 from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI
 
-from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.langchain_adapters import (
+    LangchainAdapter,
+    get_structured_output_options,
+    langchain_model_from_provider,
+)
+from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
 from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
 
 
 def test_langchain_adapter_munge_response(tmp_path):
     task = build_test_task(tmp_path)
-    lca = LangChainPromptAdapter(
-        kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
-    )
+    lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
     # Mistral Large tool calling format is a bit different
     response = {
         "name": "task_response",
@@ -35,7 +44,7 @@ def test_langchain_adapter_infer_model_name(tmp_path):
     task = build_test_task(tmp_path)
     custom = ChatGroq(model="llama-3.1-8b-instant", groq_api_key="test")
 
-    lca = LangChainPromptAdapter(kiln_task=task, custom_model=custom)
+    lca = LangchainAdapter(kiln_task=task, custom_model=custom)
 
     model_info = lca.adapter_info()
     assert model_info.model_name == "custom.langchain:llama-3.1-8b-instant"
@@ -45,9 +54,7 @@ def test_langchain_adapter_infer_model_name(tmp_path):
 def test_langchain_adapter_info(tmp_path):
     task = build_test_task(tmp_path)
 
-    lca = LangChainPromptAdapter(
-        kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
-    )
+    lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
 
     model_info = lca.adapter_info()
     assert model_info.adapter_name == "kiln_langchain_adapter"
@@ -60,7 +67,7 @@ async def test_langchain_adapter_with_cot(tmp_path):
     task.output_json_schema = (
         '{"type": "object", "properties": {"count": {"type": "integer"}}}'
     )
-    lca = LangChainPromptAdapter(
+    lca = LangchainAdapter(
         kiln_task=task,
         model_name="llama_3_1_8b",
         provider="ollama",
@@ -69,13 +76,13 @@ async def test_langchain_adapter_with_cot(tmp_path):
 
     # Mock the base model and its invoke method
     mock_base_model = MagicMock()
-    mock_base_model.invoke.return_value = AIMessage(
-        content="Chain of thought reasoning..."
+    mock_base_model.ainvoke = AsyncMock(
+        return_value=AIMessage(content="Chain of thought reasoning...")
     )
 
     # Create a separate mock for self.model()
     mock_model_instance = MagicMock()
-    mock_model_instance.invoke.return_value = {"parsed": {"count": 1}}
+    mock_model_instance.ainvoke = AsyncMock(return_value={"parsed": {"count": 1}})
 
     # Mock the langchain_model_from function to return the base model
     mock_model_from = AsyncMock(return_value=mock_base_model)
@@ -85,14 +92,14 @@ async def test_langchain_adapter_with_cot(tmp_path):
         patch(
             "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
         ),
-        patch.object(LangChainPromptAdapter, "model", return_value=mock_model_instance),
+        patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
     ):
         response = await lca._run("test input")
 
     # First 3 messages are the same for both calls
     for invoke_args in [
-        mock_base_model.invoke.call_args[0][0],
-        mock_model_instance.invoke.call_args[0][0],
+        mock_base_model.ainvoke.call_args[0][0],
+        mock_model_instance.ainvoke.call_args[0][0],
     ]:
         assert isinstance(
             invoke_args[0], SystemMessage
@@ -107,11 +114,11 @@ async def test_langchain_adapter_with_cot(tmp_path):
         assert "step by step" in invoke_args[2].content
 
     # the COT should only have 3 messages
-    assert len(mock_base_model.invoke.call_args[0][0]) == 3
-    assert len(mock_model_instance.invoke.call_args[0][0]) == 5
+    assert len(mock_base_model.ainvoke.call_args[0][0]) == 3
+    assert len(mock_model_instance.ainvoke.call_args[0][0]) == 5
 
     # the final response should have the COT content and the final instructions
-    invoke_args = mock_model_instance.invoke.call_args[0][0]
+    invoke_args = mock_model_instance.ainvoke.call_args[0][0]
     assert isinstance(invoke_args[3], AIMessage)
     assert "Chain of thought reasoning..." in invoke_args[3].content
     assert isinstance(invoke_args[4], SystemMessage)
@@ -122,3 +129,207 @@ async def test_langchain_adapter_with_cot(tmp_path):
         == "Chain of thought reasoning..."
     )
     assert response.output == {"count": 1}
+
+
+async def test_get_structured_output_options():
+    # Mock the provider response
+    mock_provider = MagicMock()
+    mock_provider.adapter_options = {
+        "langchain": {
+            "with_structured_output_options": {
+                "force_json_response": True,
+                "max_retries": 3,
+            }
+        }
+    }
+
+    # Test with provider that has options
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
+        AsyncMock(return_value=mock_provider),
+    ):
+        options = await get_structured_output_options("model_name", "provider")
+        assert options == {"force_json_response": True, "max_retries": 3}
+
+    # Test with provider that has no options
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
+        AsyncMock(return_value=None),
+    ):
+        options = await get_structured_output_options("model_name", "provider")
+        assert options == {}
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_openai():
+    provider = KilnModelProvider(
+        name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.open_ai_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "gpt-4")
+        assert isinstance(model, ChatOpenAI)
+        assert model.model_name == "gpt-4"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_groq():
+    provider = KilnModelProvider(
+        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.groq_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatGroq)
+        assert model.model_name == "mixtral-8x7b"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_bedrock():
+    provider = KilnModelProvider(
+        name=ModelProviderName.amazon_bedrock,
+        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.bedrock_access_key = "test_access"
+        mock_config.return_value.bedrock_secret_key = "test_secret"
+        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
+        assert isinstance(model, ChatBedrockConverse)
+        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
+        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_fireworks():
+    provider = KilnModelProvider(
+        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.fireworks_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatFireworks)
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_ollama():
+    provider = KilnModelProvider(
+        name=ModelProviderName.ollama,
+        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
+    )
+
+    mock_connection = MagicMock()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_ollama_connection",
+            return_value=AsyncMock(return_value=mock_connection),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_model_installed",
+            return_value=True,
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_base_url",
+            return_value="http://localhost:11434",
+        ),
+    ):
+        model = await langchain_model_from_provider(provider, "llama2")
+        assert isinstance(model, ChatOllama)
+        assert model.model == "llama2"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_invalid():
+    provider = KilnModelProvider.model_construct(
+        name="invalid_provider", provider_options={}
+    )
+
+    with pytest.raises(ValueError, match="Invalid model or provider"):
+        await langchain_model_from_provider(provider, "test_model")
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_caching(tmp_path):
+    task = build_test_task(tmp_path)
+    custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
+
+    adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
+
+    # First call should return the cached model
+    model1 = await adapter.model()
+    assert model1 is custom_model
+
+    # Second call should return the same cached instance
+    model2 = await adapter.model()
+    assert model2 is model1
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_structured_output(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = """
+    {
+        "type": "object",
+        "properties": {
+            "count": {"type": "integer"}
+        }
+    }
+    """
+
+    mock_model = MagicMock()
+    mock_model.with_structured_output = MagicMock(return_value="structured_model")
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+            AsyncMock(return_value=mock_model),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_structured_output_options",
+            AsyncMock(return_value={"option1": "value1"}),
+        ),
+    ):
+        model = await adapter.model()
+
+        # Verify the model was configured with structured output
+        mock_model.with_structured_output.assert_called_once_with(
+            {
+                "type": "object",
+                "properties": {"count": {"type": "integer"}},
+                "title": "task_response",
+                "description": "A response from the task",
+            },
+            include_raw=True,
+            option1="value1",
+        )
+        assert model == "structured_model"
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = (
+        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
+    )
+
+    mock_model = MagicMock()
+    # Remove with_structured_output method
+    del mock_model.with_structured_output
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+        AsyncMock(return_value=mock_model),
+    ):
+        with pytest.raises(ValueError, match="does not support structured output"):
+            await adapter.model()
--- /dev/null
+++ kiln_ai/adapters/test_ollama_tools.py (0.7.1)
@@ -0,0 +1,42 @@
+import json
+
+from kiln_ai.adapters.ollama_tools import (
+    OllamaConnection,
+    ollama_model_installed,
+    parse_ollama_tags,
+)
+
+
+def test_parse_ollama_tags_no_models():
+    json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
+    tags = json.loads(json_response)
+    print(json.dumps(tags, indent=2))
+    conn = parse_ollama_tags(tags)
+    assert "phi3.5:latest" in conn.supported_models
+    assert "gemma2:2b" in conn.supported_models
+    assert "llama3.1:latest" in conn.supported_models
+    assert "scosman_net:latest" in conn.untested_models
+
+
+def test_parse_ollama_tags_only_untested_models():
+    json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"}]}'
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+    assert conn.supported_models == []
+    assert conn.untested_models == ["scosman_net:latest"]
+
+
+def test_ollama_model_installed():
+    conn = OllamaConnection(
+        supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
+        message="Connected",
+        untested_models=["scosman_net:latest"],
+    )
+    assert ollama_model_installed(conn, "phi3.5:latest")
+    assert ollama_model_installed(conn, "phi3.5")
+    assert ollama_model_installed(conn, "gemma2:2b")
+    assert ollama_model_installed(conn, "llama3.1:latest")
+    assert ollama_model_installed(conn, "llama3.1")
+    assert ollama_model_installed(conn, "scosman_net:latest")
+    assert ollama_model_installed(conn, "scosman_net")
+    assert not ollama_model_installed(conn, "unknown_model")
--- kiln_ai/adapters/test_prompt_adaptors.py (0.6.1)
+++ kiln_ai/adapters/test_prompt_adaptors.py (0.7.1)
@@ -5,8 +5,10 @@ import pytest
 from langchain_core.language_models.fake_chat_models import FakeListChatModel
 
 import kiln_ai.datamodel as datamodel
-from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
-from kiln_ai.adapters.ml_model_list import built_in_models, ollama_online
+from kiln_ai.adapters.adapter_registry import adapter_for_task
+from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
@@ -106,7 +108,7 @@ async def test_amazon_bedrock(tmp_path):
 async def test_mock(tmp_path):
     task = build_test_task(tmp_path)
     mockChatModel = FakeListChatModel(responses=["mock response"])
-    adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
+    adapter = LangchainAdapter(task, custom_model=mockChatModel)
     run = await adapter.invoke("You are a mock, send me the response!")
     assert "mock response" in run.output.output
 
@@ -114,7 +116,7 @@ async def test_mock(tmp_path):
 async def test_mock_returning_run(tmp_path):
     task = build_test_task(tmp_path)
     mockChatModel = FakeListChatModel(responses=["mock response"])
-    adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
+    adapter = LangchainAdapter(task, custom_model=mockChatModel)
     run = await adapter.invoke("You are a mock, send me the response!")
     assert run.output.output == "mock response"
     assert run is not None
@@ -192,7 +194,7 @@ async def run_simple_task(
     provider: str,
     prompt_builder: BasePromptBuilder | None = None,
 ) -> datamodel.TaskRun:
-    adapter = LangChainPromptAdapter(
+    adapter = adapter_for_task(
         task, model_name=model_name, provider=provider, prompt_builder=prompt_builder
     )
 