kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +81 -10
- kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +434 -93
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +193 -49
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +18 -19
- kiln_ai/adapters/repair/test_repair_task.py +7 -7
- kiln_ai/adapters/run_output.py +11 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +25 -18
- kiln_ai/adapters/test_prompt_builders.py +265 -44
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +51 -772
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +14 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +80 -2
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +127 -6
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +34 -17
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +131 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai/adapters/base_adapter.py +0 -191
- kiln_ai/adapters/langchain_adapters.py +0 -256
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
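
Taken together, the list above amounts to a reorganization of the adapter layer: the LangChain and OpenAI-compatible adapters now live under kiln_ai.adapters.model_adapters, response parsers are split out into kiln_ai.adapters.parsers, and prompt selection moves from prompt-builder objects to string prompt IDs. The sketch below is assembled from the test code shown further down in this diff; the import paths and constructor arguments come from those tests, while the task definition values are illustrative placeholders only.

```python
# Hedged sketch of the 0.12.x import layout, assembled from the tests in this diff.
# Project/Task construction mirrors the mock_task fixture below; values are placeholders.
import json

from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.datamodel import Project, Task

project = Project(name="Demo Project", path="demo_project/project.kiln")
task = Task(
    name="Demo Task",
    instruction="Classify the input sentiment",  # placeholder instruction
    parent=project,
    output_json_schema=json.dumps(
        {"type": "object", "properties": {"sentiment": {"type": "string"}}}
    ),
)

# New-style adapter construction: prompts are referenced by a string prompt_id
# (see kiln_ai/datamodel/prompt_id.py in the file list) rather than a builder object.
adapter = LangchainAdapter(
    kiln_task=task,
    model_name="llama_3_1_8b",
    provider="ollama",
    prompt_id="simple_chain_of_thought_prompt_builder",
)
```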
kiln_ai/adapters/test_langchain_adapter.py → kiln_ai/adapters/model_adapters/test_langchain_adapter.py

@@ -7,21 +7,31 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_fireworks import ChatFireworks
 from langchain_groq import ChatGroq
 from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
 
-from kiln_ai.adapters.
+from kiln_ai.adapters.ml_model_list import (
+    KilnModelProvider,
+    ModelProviderName,
+    StructuredOutputMode,
+)
+from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
+from kiln_ai.adapters.model_adapters.langchain_adapters import (
     LangchainAdapter,
-    get_structured_output_options,
     langchain_model_from_provider,
 )
-from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
-from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
+from kiln_ai.datamodel.task import RunConfig
 
 
-
-
-
+@pytest.fixture
+def mock_adapter(tmp_path):
+    return LangchainAdapter(
+        kiln_task=build_test_task(tmp_path),
+        model_name="llama_3_1_8b",
+        provider="ollama",
+    )
+
+
+def test_langchain_adapter_munge_response(mock_adapter):
     # Mistral Large tool calling format is a bit different
     response = {
         "name": "task_response",

@@ -30,12 +40,12 @@ def test_langchain_adapter_munge_response(tmp_path):
             "punchline": "Because she wanted to be a moo-sician!",
         },
     }
-    munged =
+    munged = mock_adapter._munge_response(response)
     assert munged["setup"] == "Why did the cow join a band?"
     assert munged["punchline"] == "Because she wanted to be a moo-sician!"
 
     # non mistral format should continue to work
-    munged =
+    munged = mock_adapter._munge_response(response["arguments"])
     assert munged["setup"] == "Why did the cow join a band?"
     assert munged["punchline"] == "Because she wanted to be a moo-sician!"
 

@@ -46,9 +56,8 @@ def test_langchain_adapter_infer_model_name(tmp_path):
 
     lca = LangchainAdapter(kiln_task=task, custom_model=custom)
 
-
-    assert
-    assert model_info.model_provider == "custom.langchain:ChatGroq"
+    assert lca.run_config.model_name == "custom.langchain:llama-3.1-8b-instant"
+    assert lca.run_config.model_provider_name == "custom.langchain:ChatGroq"
 
 
 def test_langchain_adapter_info(tmp_path):

@@ -56,10 +65,9 @@ def test_langchain_adapter_info(tmp_path):
 
     lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
 
-
-    assert
-    assert
-    assert model_info.model_provider == "ollama"
+    assert lca.adapter_name() == "kiln_langchain_adapter"
+    assert lca.run_config.model_name == "llama_3_1_8b"
+    assert lca.run_config.model_provider_name == "ollama"
 
 
 async def test_langchain_adapter_with_cot(tmp_path):

@@ -71,7 +79,7 @@ async def test_langchain_adapter_with_cot(tmp_path):
         kiln_task=task,
         model_name="llama_3_1_8b",
         provider="ollama",
-
+        prompt_id="simple_chain_of_thought_prompt_builder",
     )
 
     # Mock the base model and its invoke method

@@ -89,9 +97,7 @@ async def test_langchain_adapter_with_cot(tmp_path):
 
     # Patch both the langchain_model_from function and self.model()
     with (
-        patch(
-            "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
-        ),
+        patch.object(LangchainAdapter, "langchain_model_from", mock_model_from),
         patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
     ):
         response = await lca._run("test input")

@@ -121,8 +127,8 @@ async def test_langchain_adapter_with_cot(tmp_path):
     invoke_args = mock_model_instance.ainvoke.call_args[0][0]
     assert isinstance(invoke_args[3], AIMessage)
     assert "Chain of thought reasoning..." in invoke_args[3].content
-    assert isinstance(invoke_args[4],
-    assert
+    assert isinstance(invoke_args[4], HumanMessage)
+    assert COT_FINAL_ANSWER_PROMPT in invoke_args[4].content
 
     assert (
         response.intermediate_outputs["chain_of_thought"]

@@ -131,46 +137,28 @@ async def test_langchain_adapter_with_cot(tmp_path):
     assert response.output == {"count": 1}
 
 
-
+@pytest.mark.parametrize(
+    "structured_output_mode,expected_method",
+    [
+        (StructuredOutputMode.function_calling, "function_calling"),
+        (StructuredOutputMode.json_mode, "json_mode"),
+        (StructuredOutputMode.json_schema, "json_schema"),
+        (StructuredOutputMode.json_instruction_and_object, "json_mode"),
+        (StructuredOutputMode.default, None),
+    ],
+)
+async def test_get_structured_output_options(
+    mock_adapter, structured_output_mode, expected_method
+):
     # Mock the provider response
     mock_provider = MagicMock()
-    mock_provider.
-        "langchain": {
-            "with_structured_output_options": {
-                "force_json_response": True,
-                "max_retries": 3,
-            }
-        }
-    }
-
-    # Test with provider that has options
-    with patch(
-        "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
-        AsyncMock(return_value=mock_provider),
-    ):
-        options = await get_structured_output_options("model_name", "provider")
-        assert options == {"force_json_response": True, "max_retries": 3}
-
-    # Test with provider that has no options
-    with patch(
-        "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
-        AsyncMock(return_value=None),
-    ):
-        options = await get_structured_output_options("model_name", "provider")
-        assert options == {}
+    mock_provider.structured_output_mode = structured_output_mode
 
+    # Mock adapter.model_provider()
+    mock_adapter.model_provider = MagicMock(return_value=mock_provider)
 
-
-
-    provider = KilnModelProvider(
-        name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
-    )
-
-    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
-        mock_config.return_value.open_ai_api_key = "test_key"
-        model = await langchain_model_from_provider(provider, "gpt-4")
-        assert isinstance(model, ChatOpenAI)
-        assert model.model_name == "gpt-4"
+    options = mock_adapter.get_structured_output_options("model_name", "provider")
+    assert options.get("method") == expected_method
 
 
 @pytest.mark.asyncio

@@ -179,7 +167,9 @@ async def test_langchain_model_from_provider_groq():
         name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
     )
 
-    with patch(
+    with patch(
+        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
+    ) as mock_config:
         mock_config.return_value.groq_api_key = "test_key"
         model = await langchain_model_from_provider(provider, "mixtral-8x7b")
         assert isinstance(model, ChatGroq)

@@ -193,7 +183,9 @@ async def test_langchain_model_from_provider_bedrock():
         provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
     )
 
-    with patch(
+    with patch(
+        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
+    ) as mock_config:
         mock_config.return_value.bedrock_access_key = "test_access"
         mock_config.return_value.bedrock_secret_key = "test_secret"
         model = await langchain_model_from_provider(provider, "anthropic.claude-v2")

@@ -208,7 +200,9 @@ async def test_langchain_model_from_provider_fireworks():
         name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
     )
 
-    with patch(
+    with patch(
+        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
+    ) as mock_config:
         mock_config.return_value.fireworks_api_key = "test_key"
         model = await langchain_model_from_provider(provider, "mixtral-8x7b")
         assert isinstance(model, ChatFireworks)

@@ -224,15 +218,15 @@ async def test_langchain_model_from_provider_ollama():
     mock_connection = MagicMock()
     with (
         patch(
-            "kiln_ai.adapters.langchain_adapters.get_ollama_connection",
+            "kiln_ai.adapters.model_adapters.langchain_adapters.get_ollama_connection",
             return_value=AsyncMock(return_value=mock_connection),
         ),
         patch(
-            "kiln_ai.adapters.langchain_adapters.ollama_model_installed",
+            "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_model_installed",
             return_value=True,
         ),
         patch(
-            "kiln_ai.adapters.langchain_adapters.ollama_base_url",
+            "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_base_url",
             return_value="http://localhost:11434",
         ),
     ):

@@ -283,33 +277,27 @@ async def test_langchain_adapter_model_structured_output(tmp_path):
     mock_model.with_structured_output = MagicMock(return_value="structured_model")
 
     adapter = LangchainAdapter(
-        kiln_task=task, model_name="test_model", provider="
+        kiln_task=task, model_name="test_model", provider="ollama"
+    )
+    adapter.get_structured_output_options = MagicMock(
+        return_value={"option1": "value1"}
     )
+    adapter.langchain_model_from = AsyncMock(return_value=mock_model)
 
-
-
-
-
-
-
-        "
-
-
-
-
-
-
-
-        {
-            "type": "object",
-            "properties": {"count": {"type": "integer"}},
-            "title": "task_response",
-            "description": "A response from the task",
-        },
-        include_raw=True,
-        option1="value1",
-    )
-    assert model == "structured_model"
+    model = await adapter.model()
+
+    # Verify the model was configured with structured output
+    mock_model.with_structured_output.assert_called_once_with(
+        {
+            "type": "object",
+            "properties": {"count": {"type": "integer"}},
+            "title": "task_response",
+            "description": "A response from the task",
+        },
+        include_raw=True,
+        option1="value1",
+    )
+    assert model == "structured_model"
 
 
 @pytest.mark.asyncio

@@ -324,12 +312,32 @@ async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
     del mock_model.with_structured_output
 
     adapter = LangchainAdapter(
-        kiln_task=task, model_name="test_model", provider="
+        kiln_task=task, model_name="test_model", provider="ollama"
     )
+    adapter.langchain_model_from = AsyncMock(return_value=mock_model)
 
-    with
-
-
-
-
-
+    with pytest.raises(ValueError, match="does not support structured output"):
+        await adapter.model()
+
+
+import pytest
+
+
+@pytest.mark.parametrize(
+    "provider_name",
+    [
+        (ModelProviderName.openai),
+        (ModelProviderName.openai_compatible),
+        (ModelProviderName.openrouter),
+    ],
+)
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_unsupported_providers(provider_name):
+    # Arrange
+    provider = KilnModelProvider(
+        name=provider_name, provider_options={}, structured_output_mode="default"
+    )
+
+    # Assert unsupported providers raise an error
+    with pytest.raises(ValueError):
+        await langchain_model_from_provider(provider, "test-model")
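
One detail worth calling out from the hunks above: the old get_structured_output_options helper, which read provider options from the model list, is gone, and the adapter now derives LangChain's with_structured_output method from the provider's StructuredOutputMode. The mapping below simply restates what the parametrized test_get_structured_output_options test asserts; it is a summary of the test expectations, not a claim about the adapter's internals.

```python
from kiln_ai.adapters.ml_model_list import StructuredOutputMode

# Expected "method" option per mode, as asserted by the parametrized test above.
EXPECTED_METHOD = {
    StructuredOutputMode.function_calling: "function_calling",
    StructuredOutputMode.json_mode: "json_mode",
    StructuredOutputMode.json_schema: "json_schema",
    StructuredOutputMode.json_instruction_and_object: "json_mode",
    StructuredOutputMode.default: None,  # no "method" key is set in this case
}
```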
kiln_ai/adapters/model_adapters/test_openai_model_adapter.py (new file)

@@ -0,0 +1,216 @@
+import json
+from unittest.mock import Mock, patch
+
+import pytest
+from openai import AsyncOpenAI
+
+from kiln_ai.adapters.ml_model_list import StructuredOutputMode
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
+from kiln_ai.adapters.model_adapters.openai_compatible_config import (
+    OpenAICompatibleConfig,
+)
+from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
+from kiln_ai.datamodel import Project, Task
+
+
+@pytest.fixture
+def mock_task(tmp_path):
+    # Create a project first since Task requires a parent
+    project_path = tmp_path / "test_project" / "project.kiln"
+    project_path.parent.mkdir()
+
+    project = Project(name="Test Project", path=str(project_path))
+    project.save_to_file()
+
+    schema = {
+        "type": "object",
+        "properties": {"test": {"type": "string"}},
+    }
+
+    task = Task(
+        name="Test Task",
+        instruction="Test instruction",
+        parent=project,
+        output_json_schema=json.dumps(schema),
+    )
+    task.save_to_file()
+    return task
+
+
+@pytest.fixture
+def config():
+    return OpenAICompatibleConfig(
+        api_key="test_key",
+        base_url="https://api.test.com",
+        model_name="test-model",
+        provider_name="openrouter",
+        default_headers={"X-Test": "test"},
+    )
+
+
+def test_initialization(config, mock_task):
+    adapter = OpenAICompatibleAdapter(
+        config=config,
+        kiln_task=mock_task,
+        prompt_id="simple_prompt_builder",
+        base_adapter_config=AdapterConfig(default_tags=["test-tag"]),
+    )
+
+    assert isinstance(adapter.client, AsyncOpenAI)
+    assert adapter.config == config
+    assert adapter.run_config.task == mock_task
+    assert adapter.run_config.prompt_id == "simple_prompt_builder"
+    assert adapter.base_adapter_config.default_tags == ["test-tag"]
+    assert adapter.run_config.model_name == config.model_name
+    assert adapter.run_config.model_provider_name == config.provider_name
+
+
+def test_adapter_info(config, mock_task):
+    adapter = OpenAICompatibleAdapter(config=config, kiln_task=mock_task)
+
+    assert adapter.adapter_name() == "kiln_openai_compatible_adapter"
+
+    assert adapter.run_config.model_name == config.model_name
+    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.run_config.prompt_id == "simple_prompt_builder"
+
+
+@pytest.mark.asyncio
+async def test_response_format_options_unstructured(config, mock_task):
+    adapter = OpenAICompatibleAdapter(config=config, kiln_task=mock_task)
+
+    # Mock has_structured_output to return False
+    with patch.object(adapter, "has_structured_output", return_value=False):
+        options = await adapter.response_format_options()
+        assert options == {}
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.json_mode,
+        StructuredOutputMode.json_instruction_and_object,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_json_mode(config, mock_task, mode):
+    adapter = OpenAICompatibleAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = mode
+
+        options = await adapter.response_format_options()
+        assert options == {"response_format": {"type": "json_object"}}
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.default,
+        StructuredOutputMode.function_calling,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_function_calling(config, mock_task, mode):
+    adapter = OpenAICompatibleAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = mode
+
+        options = await adapter.response_format_options()
+        assert "tools" in options
+        # full tool structure validated below
+
+
+@pytest.mark.asyncio
+async def test_response_format_options_json_instructions(config, mock_task):
+    adapter = OpenAICompatibleAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = (
+            StructuredOutputMode.json_instructions
+        )
+        options = await adapter.response_format_options()
+        assert options == {}
+
+
+@pytest.mark.asyncio
+async def test_response_format_options_json_schema(config, mock_task):
+    adapter = OpenAICompatibleAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = (
+            StructuredOutputMode.json_schema
+        )
+        options = await adapter.response_format_options()
+        assert options == {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "task_response",
+                    "schema": mock_task.output_schema(),
+                },
+            }
+        }
+
+
+def test_tool_call_params_weak(config, mock_task):
+    adapter = OpenAICompatibleAdapter(config=config, kiln_task=mock_task)
+
+    params = adapter.tool_call_params(strict=False)
+    expected_schema = mock_task.output_schema()
+    expected_schema["additionalProperties"] = False
+
+    assert params == {
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "task_response",
+                    "parameters": expected_schema,
+                },
+            }
+        ],
+        "tool_choice": {
+            "type": "function",
+            "function": {"name": "task_response"},
+        },
+    }
+
+
+def test_tool_call_params_strict(config, mock_task):
+    config.provider_name = "openai"
+    adapter = OpenAICompatibleAdapter(config=config, kiln_task=mock_task)
+
+    params = adapter.tool_call_params(strict=True)
+    expected_schema = mock_task.output_schema()
+    expected_schema["additionalProperties"] = False
+
+    assert params == {
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "task_response",
+                    "parameters": expected_schema,
+                    "strict": True,
+                },
+            }
+        ],
+        "tool_choice": {
+            "type": "function",
+            "function": {"name": "task_response"},
+        },
+    }