kiln-ai 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +153 -28
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +2 -1
- kiln_ai/adapters/eval/test_base_eval.py +1 -0
- kiln_ai/adapters/eval/test_eval_runner.py +1 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +638 -155
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
- kiln_ai/adapters/ollama_tools.py +3 -2
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +1 -1
- kiln_ai/adapters/test_adapter_registry.py +17 -20
- kiln_ai/adapters/test_generate_docs.py +2 -2
- kiln_ai/adapters/test_prompt_adaptors.py +30 -19
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/json_schema.py +1 -1
- kiln_ai/datamodel/task_output.py +13 -6
- kiln_ai/datamodel/test_basemodel.py +9 -0
- kiln_ai/datamodel/test_datasource.py +19 -0
- kiln_ai/utils/config.py +37 -0
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +51 -7
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/RECORD +42 -39
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
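
Judging by the file list (langchain_adapters.py and openai_model_adapter.py removed; litellm_adapter.py and litellm_config.py added), the release appears to consolidate provider calls behind a single LiteLLM-backed adapter. Below is a minimal construction sketch mirroring the fixtures in the new test file that follows; the LiteLlmConfig fields and LiteLlmAdapter arguments are taken from those tests, while the project/task setup and all values are illustrative only.

# Minimal sketch, assuming the constructor shapes shown in
# test_litellm_adapter.py below; values are illustrative.
import json
from pathlib import Path

from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel import Project, Task

# A task needs a parent project, as in the test fixture
project_path = Path("demo_project") / "project.kiln"
project_path.parent.mkdir(exist_ok=True)
project = Project(name="Demo Project", path=str(project_path))
project.save_to_file()

task = Task(
    name="Demo Task",
    instruction="Reply with JSON",
    parent=project,
    output_json_schema=json.dumps(
        {"type": "object", "properties": {"test": {"type": "string"}}}
    ),
)
task.save_to_file()

config = LiteLlmConfig(
    base_url="https://api.test.com",  # illustrative endpoint
    model_name="test-model",
    provider_name="openrouter",
    default_headers={"X-Test": "test"},
    additional_body_options={"api_key": "your-key-here"},
)
adapter = LiteLlmAdapter(
    config=config,
    kiln_task=task,
    prompt_id="simple_prompt_builder",
    base_adapter_config=AdapterConfig(default_tags=["demo"]),
)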
kiln_ai/adapters/model_adapters/test_litellm_adapter.py
ADDED
@@ -0,0 +1,407 @@
+import json
+from unittest.mock import Mock, patch
+
+import pytest
+
+from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
+from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
+from kiln_ai.adapters.model_adapters.litellm_config import (
+    LiteLlmConfig,
+)
+from kiln_ai.datamodel import Project, Task
+
+
+@pytest.fixture
+def mock_task(tmp_path):
+    # Create a project first since Task requires a parent
+    project_path = tmp_path / "test_project" / "project.kiln"
+    project_path.parent.mkdir()
+
+    project = Project(name="Test Project", path=str(project_path))
+    project.save_to_file()
+
+    schema = {
+        "type": "object",
+        "properties": {"test": {"type": "string"}},
+    }
+
+    task = Task(
+        name="Test Task",
+        instruction="Test instruction",
+        parent=project,
+        output_json_schema=json.dumps(schema),
+    )
+    task.save_to_file()
+    return task
+
+
+@pytest.fixture
+def config():
+    return LiteLlmConfig(
+        base_url="https://api.test.com",
+        model_name="test-model",
+        provider_name="openrouter",
+        default_headers={"X-Test": "test"},
+        additional_body_options={"api_key": "test_key"},
+    )
+
+
+def test_initialization(config, mock_task):
+    adapter = LiteLlmAdapter(
+        config=config,
+        kiln_task=mock_task,
+        prompt_id="simple_prompt_builder",
+        base_adapter_config=AdapterConfig(default_tags=["test-tag"]),
+    )
+
+    assert adapter.config == config
+    assert adapter.run_config.task == mock_task
+    assert adapter.run_config.prompt_id == "simple_prompt_builder"
+    assert adapter.base_adapter_config.default_tags == ["test-tag"]
+    assert adapter.run_config.model_name == config.model_name
+    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.config.additional_body_options["api_key"] == "test_key"
+    assert adapter._api_base == config.base_url
+    assert adapter._headers == config.default_headers
+
+
+def test_adapter_info(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    assert adapter.adapter_name() == "kiln_openai_compatible_adapter"
+
+    assert adapter.run_config.model_name == config.model_name
+    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.run_config.prompt_id == "simple_prompt_builder"
+
+
+@pytest.mark.asyncio
+async def test_response_format_options_unstructured(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock has_structured_output to return False
+    with patch.object(adapter, "has_structured_output", return_value=False):
+        options = await adapter.response_format_options()
+        assert options == {}
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.json_mode,
+        StructuredOutputMode.json_instruction_and_object,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_json_mode(config, mock_task, mode):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = mode
+
+        options = await adapter.response_format_options()
+        assert options == {"response_format": {"type": "json_object"}}
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.default,
+        StructuredOutputMode.function_calling,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_function_calling(config, mock_task, mode):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = mode
+
+        options = await adapter.response_format_options()
+        assert "tools" in options
+        # full tool structure validated below
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.json_custom_instructions,
+        StructuredOutputMode.json_instructions,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_json_instructions(config, mock_task, mode):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = (
+            StructuredOutputMode.json_instructions
+        )
+        options = await adapter.response_format_options()
+        assert options == {}
+
+
+@pytest.mark.asyncio
+async def test_response_format_options_json_schema(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = (
+            StructuredOutputMode.json_schema
+        )
+        options = await adapter.response_format_options()
+        assert options == {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "task_response",
+                    "schema": mock_task.output_schema(),
+                },
+            }
+        }
+
+
+def test_tool_call_params_weak(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    params = adapter.tool_call_params(strict=False)
+    expected_schema = mock_task.output_schema()
+    expected_schema["additionalProperties"] = False
+
+    assert params == {
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "task_response",
+                    "parameters": expected_schema,
+                },
+            }
+        ],
+        "tool_choice": {
+            "type": "function",
+            "function": {"name": "task_response"},
+        },
+    }
+
+
+def test_tool_call_params_strict(config, mock_task):
+    config.provider_name = "openai"
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    params = adapter.tool_call_params(strict=True)
+    expected_schema = mock_task.output_schema()
+    expected_schema["additionalProperties"] = False
+
+    assert params == {
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "task_response",
+                    "parameters": expected_schema,
+                    "strict": True,
+                },
+            }
+        ],
+        "tool_choice": {
+            "type": "function",
+            "function": {"name": "task_response"},
+        },
+    }
+
+
+@pytest.mark.parametrize(
+    "provider_name,expected_prefix",
+    [
+        (ModelProviderName.openrouter, "openrouter"),
+        (ModelProviderName.openai, "openai"),
+        (ModelProviderName.groq, "groq"),
+        (ModelProviderName.anthropic, "anthropic"),
+        (ModelProviderName.ollama, "openai"),
+        (ModelProviderName.gemini_api, "gemini"),
+        (ModelProviderName.fireworks_ai, "fireworks_ai"),
+        (ModelProviderName.amazon_bedrock, "bedrock"),
+        (ModelProviderName.azure_openai, "azure"),
+        (ModelProviderName.huggingface, "huggingface"),
+        (ModelProviderName.vertex, "vertex_ai"),
+        (ModelProviderName.together_ai, "together_ai"),
+    ],
+)
+def test_litellm_model_id_standard_providers(
+    config, mock_task, provider_name, expected_prefix
+):
+    """Test litellm_model_id for standard providers"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method to return a provider with the specified name
+    mock_provider = Mock()
+    mock_provider.name = provider_name
+    mock_provider.model_id = "test-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        model_id = adapter.litellm_model_id()
+
+    assert model_id == f"{expected_prefix}/test-model"
+    # Verify caching works
+    assert adapter._litellm_model_id == model_id
+
+
+@pytest.mark.parametrize(
+    "provider_name",
+    [
+        ModelProviderName.openai_compatible,
+        ModelProviderName.kiln_custom_registry,
+        ModelProviderName.kiln_fine_tune,
+    ],
+)
+def test_litellm_model_id_custom_providers(config, mock_task, provider_name):
+    """Test litellm_model_id for custom providers that require a base URL"""
+    config.base_url = "https://api.custom.com"
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method
+    mock_provider = Mock()
+    mock_provider.name = provider_name
+    mock_provider.model_id = "custom-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        model_id = adapter.litellm_model_id()
+
+    # Custom providers should use "openai" as the provider name
+    assert model_id == "openai/custom-model"
+    assert adapter._litellm_model_id == model_id
+
+
+def test_litellm_model_id_custom_provider_no_base_url(config, mock_task):
+    """Test litellm_model_id raises error for custom providers without base URL"""
+    config.base_url = None
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method
+    mock_provider = Mock()
+    mock_provider.name = ModelProviderName.openai_compatible
+    mock_provider.model_id = "custom-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        with pytest.raises(ValueError, match="Explicit Base URL is required"):
+            adapter.litellm_model_id()
+
+
+def test_litellm_model_id_no_model_id(config, mock_task):
+    """Test litellm_model_id raises error when provider has no model_id"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method to return a provider with no model_id
+    mock_provider = Mock()
+    mock_provider.name = ModelProviderName.openai
+    mock_provider.model_id = None
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        with pytest.raises(ValueError, match="Model ID is required"):
+            adapter.litellm_model_id()
+
+
+def test_litellm_model_id_caching(config, mock_task):
+    """Test that litellm_model_id caches the result"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Set the cached value directly
+    adapter._litellm_model_id = "cached-value"
+
+    # The method should return the cached value without calling model_provider
+    with patch.object(adapter, "model_provider") as mock_model_provider:
+        model_id = adapter.litellm_model_id()
+
+    assert model_id == "cached-value"
+    mock_model_provider.assert_not_called()
+
+
+def test_litellm_model_id_unknown_provider(config, mock_task):
+    """Test litellm_model_id raises error for unknown provider"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Create a mock provider with an unknown name
+    mock_provider = Mock()
+    mock_provider.name = "unknown_provider"  # Not in ModelProviderName enum
+    mock_provider.model_id = "test-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        with patch(
+            "kiln_ai.adapters.model_adapters.litellm_adapter.raise_exhaustive_enum_error"
+        ) as mock_raise_error:
+            mock_raise_error.side_effect = Exception("Test error")
+
+            with pytest.raises(Exception, match="Test error"):
+                adapter.litellm_model_id()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "top_logprobs,response_format,extra_body",
+    [
+        (None, {}, {}),  # Basic case
+        (5, {}, {}),  # With logprobs
+        (
+            None,
+            {"response_format": {"type": "json_object"}},
+            {},
+        ),  # With response format
+        (
+            3,
+            {"tools": [{"type": "function"}]},
+            {"reasoning_effort": 0.8},
+        ),  # Combined options
+    ],
+)
+async def test_build_completion_kwargs(
+    config, mock_task, top_logprobs, response_format, extra_body
+):
+    """Test build_completion_kwargs with various configurations"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
+        patch.object(adapter, "build_extra_body", return_value=extra_body),
+        patch.object(adapter, "response_format_options", return_value=response_format),
+    ):
+        kwargs = await adapter.build_completion_kwargs(
+            mock_provider, messages, top_logprobs
+        )
+
+    # Verify core functionality
+    assert kwargs["model"] == "openai/test-model"
+    assert kwargs["messages"] == messages
+    assert kwargs["api_base"] == config.base_url
+
+    # Verify optional parameters
+    if top_logprobs is not None:
+        assert kwargs["logprobs"] is True
+        assert kwargs["top_logprobs"] == top_logprobs
+    else:
+        assert "logprobs" not in kwargs
+        assert "top_logprobs" not in kwargs
+
+    # Verify response format is included
+    for key, value in response_format.items():
+        assert kwargs[key] == value
+
+    # Verify extra body is included
+    for key, value in extra_body.items():
+        assert kwargs[key] == value
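
Read together, the tests above pin down how litellm_model_id() appears to behave: each standard provider maps to a LiteLLM route prefix (with Ollama routed through the OpenAI-compatible path), the custom, registry, and fine-tune providers are sent as openai/<model_id> and require an explicit base URL, a missing model_id is an error, and the result is cached on _litellm_model_id. A rough standalone sketch of that mapping, inferred only from the assertions above and not taken from the adapter's actual implementation:

# Sketch of the prefixing behaviour implied by the tests above; the dictionary
# keys, the check order, and the function itself are illustrative only.
LITELLM_PREFIXES = {
    "openrouter": "openrouter",
    "openai": "openai",
    "groq": "groq",
    "anthropic": "anthropic",
    "ollama": "openai",  # Ollama goes through the OpenAI-compatible route
    "gemini_api": "gemini",
    "fireworks_ai": "fireworks_ai",
    "amazon_bedrock": "bedrock",
    "azure_openai": "azure",
    "huggingface": "huggingface",
    "vertex": "vertex_ai",
    "together_ai": "together_ai",
}
CUSTOM_PROVIDERS = {"openai_compatible", "kiln_custom_registry", "kiln_fine_tune"}


def litellm_model_id_sketch(
    provider_name: str, model_id: str | None, base_url: str | None
) -> str:
    if not model_id:
        raise ValueError("Model ID is required")
    if provider_name in CUSTOM_PROVIDERS:
        # Custom endpoints are addressed through the OpenAI-compatible route
        if not base_url:
            raise ValueError("Explicit Base URL is required")
        return f"openai/{model_id}"
    prefix = LITELLM_PREFIXES.get(provider_name)
    if prefix is None:
        raise ValueError(f"Unhandled provider: {provider_name}")
    return f"{prefix}/{model_id}"

For example, litellm_model_id_sketch("anthropic", "test-model", None) returns "anthropic/test-model", matching the parametrized expectation above.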
kiln_ai/adapters/model_adapters/test_structured_output.py
CHANGED
@@ -66,7 +66,8 @@ async def test_mock_unstructred_response(tmp_path):
 
     # don't error on valid response
     adapter = MockAdapter(task, response={"setup": "asdf", "punchline": "asdf"})
-
+    run = await adapter.invoke("You are a mock, send me the response!")
+    answer = json.loads(run.output.output)
     assert answer["setup"] == "asdf"
     assert answer["punchline"] == "asdf"
 
@@ -76,9 +77,12 @@ async def test_mock_unstructred_response(tmp_path):
     answer = await adapter.invoke("You are a mock, send me the response!")
 
     adapter = MockAdapter(task, response="string instead of dict")
-    with pytest.raises(
+    with pytest.raises(
+        ValueError,
+        match="This task requires JSON output but the model didn't return valid JSON",
+    ):
         # Not a structed response so should error
-
+        run = await adapter.invoke("You are a mock, send me the response!")
 
     # Should error, expecting a string, not a dict
     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
@@ -143,7 +147,8 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     task = build_structured_output_test_task(tmp_path)
     a = adapter_for_task(task, model_name=model_name, provider=provider)
     try:
-
+        run = await a.invoke("Cows")  # a joke about cows
+        parsed = json.loads(run.output.output)
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -162,6 +167,12 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     assert rating >= 0
     assert rating <= 10
 
+    # Check reasoning models
+    assert a._model_provider is not None
+    if a._model_provider.reasoning_capable:
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)
+
 
 def build_structured_input_test_task(tmp_path: Path):
     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
@@ -220,7 +231,8 @@ async def run_structured_input_task(
     await a.invoke({"a": 1, "b": 2, "d": 3})
 
     try:
-
+        run = await a.invoke({"a": 2, "b": 2, "c": 2})
+        response = run.output.output
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -241,6 +253,12 @@ async def run_structured_input_task(
     assert a.run_config.model_name == model_name
     assert a.run_config.model_provider_name == provider
 
+    # Check reasoning models
+    assert a._model_provider is not None
+    if a._model_provider.reasoning_capable:
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)
+
 
 @pytest.mark.paid
 async def test_structured_input_gpt_4o_mini(tmp_path):
kiln_ai/adapters/ollama_tools.py
CHANGED
@@ -38,6 +38,7 @@ async def ollama_online() -> bool:
 
 class OllamaConnection(BaseModel):
     message: str
+    version: str | None = None
     supported_models: List[str]
     untested_models: List[str] = Field(default_factory=list)
 
@@ -49,7 +50,7 @@ class OllamaConnection(BaseModel):
 def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
     # Build a list of models we support for Ollama from the built-in model list
     supported_ollama_models = [
-        provider.
+        provider.model_id
        for model in built_in_models
        for provider in model.providers
        if provider.name == ModelProviderName.ollama
@@ -60,7 +61,7 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
            alias
            for model in built_in_models
            for provider in model.providers
-            for alias in provider.
+            for alias in provider.ollama_model_aliases or []
        ]
    )
 
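
For readability, here are the two comprehensions from parse_ollama_tags with the truncated removed lines resolved, i.e. roughly the new code the hunks above produce. The enclosing set(...) call and the alias variable name are assumptions (those lines sit outside the hunks), and built_in_models is assumed to live in kiln_ai.adapters.ml_model_list alongside ModelProviderName, which the new test file above imports from that module.

# Re-assembled from the hunks above; wrapper call, variable names outside the
# shown lines, and the import of built_in_models are assumptions.
from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models

supported_ollama_models = [
    provider.model_id
    for model in built_in_models
    for provider in model.providers
    if provider.name == ModelProviderName.ollama
]

ollama_model_aliases = set(
    alias
    for model in built_in_models
    for provider in model.providers
    for alias in provider.ollama_model_aliases or []
)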
kiln_ai/adapters/parsers/r1_parser.py
CHANGED
@@ -20,21 +20,33 @@ class R1ThinkingParser(BaseParser):
         Raises:
             ValueError: If response format is invalid (missing tags, multiple tags, or no content after closing tag)
         """
+
+        # The upstream providers (litellm, openrouter, fireworks) all keep changing their response formats, sometimes adding reasoning parsing where it didn't previously exist.
+        # If they do it already, great just return. If not we parse it ourselves. Not ideal, but better than upstream changes breaking the app.
+        if (
+            original_output.intermediate_outputs is not None
+            and "reasoning" in original_output.intermediate_outputs
+        ):
+            return original_output
+
         # This parser only works for strings
         if not isinstance(original_output.output, str):
             raise ValueError("Response must be a string for R1 parser")
 
         # Strip whitespace and validate basic structure
         cleaned_response = original_output.output.strip()
-        if not cleaned_response.startswith(self.START_TAG):
-            raise ValueError("Response must start with <think> tag")
 
         # Find the thinking tags
-        think_start = cleaned_response.find(self.START_TAG)
         think_end = cleaned_response.find(self.END_TAG)
+        if think_end == -1:
+            raise ValueError("Missing </think> tag")
 
-
-
+        think_tag_start = cleaned_response.find(self.START_TAG)
+        if think_tag_start == -1:
+            # We allow no start <think>, thinking starts on first char. QwQ does this.
+            think_start = 0
+        else:
+            think_start = think_tag_start + len(self.START_TAG)
 
         # Check for multiple tags
         if (
@@ -44,9 +56,7 @@ class R1ThinkingParser(BaseParser):
             raise ValueError("Multiple thinking tags found")
 
         # Extract thinking content
-        thinking_content = cleaned_response[
-            think_start + len(self.START_TAG) : think_end
-        ].strip()
+        thinking_content = cleaned_response[think_start:think_end].strip()
 
         # Extract result (everything after </think>)
         result = cleaned_response[think_end + len(self.END_TAG) :].strip()
@@ -54,16 +64,11 @@ class R1ThinkingParser(BaseParser):
         if not result or len(result) == 0:
             raise ValueError("No content found after </think> tag")
 
-        # Parse JSON if needed
-        output = result
-        if self.structured_output:
-            output = parse_json_string(result)
-
         # Add thinking content to intermediate outputs if it exists
         intermediate_outputs = original_output.intermediate_outputs or {}
         intermediate_outputs["reasoning"] = thinking_content
 
         return RunOutput(
-            output=
+            output=result,
             intermediate_outputs=intermediate_outputs,
         )
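
The net effect of the change above, which the test updates below exercise: reasoning already extracted upstream is passed through untouched, a closing </think> tag is still required, and the opening <think> tag is now optional (QwQ-style output starts thinking at the first character). A small standalone sketch of that splitting rule, written against plain strings rather than the package's RunOutput type, follows; it is not the package's actual parser.

# Standalone sketch of the new splitting rules, inferred from the diff above.
START_TAG = "<think>"
END_TAG = "</think>"


def split_thinking(response: str) -> tuple[str, str]:
    """Return (reasoning, result) for an R1-style response."""
    cleaned = response.strip()
    think_end = cleaned.find(END_TAG)
    if think_end == -1:
        raise ValueError("Missing </think> tag")

    think_tag_start = cleaned.find(START_TAG)
    if think_tag_start == -1:
        # No opening tag: thinking starts at the first character (QwQ does this)
        think_start = 0
    else:
        think_start = think_tag_start + len(START_TAG)

    reasoning = cleaned[think_start:think_end].strip()
    result = cleaned[think_end + len(END_TAG):].strip()
    if not result:
        raise ValueError("No content found after </think> tag")
    return reasoning, result


# e.g. split_thinking("Some content</think>result") == ("Some content", "result"),
# matching test_missing_start_tag below.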
kiln_ai/adapters/parsers/test_r1_parser.py
CHANGED
@@ -19,6 +19,16 @@ def test_valid_response(parser):
     assert parsed.output == "This is the result"
 
 
+def test_already_parsed_response(parser):
+    response = RunOutput(
+        output="This is the result",
+        intermediate_outputs={"reasoning": "This is thinking content"},
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "This is thinking content"
+    assert parsed.output == "This is the result"
+
+
 def test_response_with_whitespace(parser):
     response = RunOutput(
         output="""
@@ -37,14 +47,16 @@ def test_response_with_whitespace(parser):
 
 
 def test_missing_start_tag(parser):
-
-
-
-
+    parsed = parser.parse_output(
+        RunOutput(output="Some content</think>result", intermediate_outputs=None)
+    )
+
+    assert parsed.intermediate_outputs["reasoning"] == "Some content"
+    assert parsed.output == "result"
 
 
 def test_missing_end_tag(parser):
-    with pytest.raises(ValueError, match="Missing
+    with pytest.raises(ValueError, match="Missing </think> tag"):
         parser.parse_output(
             RunOutput(output="<think>Some content", intermediate_outputs=None)
         )