kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +163 -39
- kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +270 -0
- kiln_ai/adapters/eval/g_eval.py +368 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +325 -0
- kiln_ai/adapters/eval/test_eval_runner.py +641 -0
- kiln_ai/adapters/eval/test_g_eval.py +498 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +758 -163
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
- kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
- kiln_ai/adapters/ollama_tools.py +3 -3
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/prompt_builders.py +80 -42
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/repair_task.py +9 -21
- kiln_ai/adapters/repair/test_repair_task.py +6 -6
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +26 -29
- kiln_ai/adapters/test_generate_docs.py +4 -4
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +47 -33
- kiln_ai/adapters/test_prompt_builders.py +91 -31
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/__init__.py +50 -952
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +60 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +7 -1
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +328 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +19 -11
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +32 -8
- kiln_ai/datamodel/test_datasource.py +22 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +9 -13
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_models.py +2 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +43 -1
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
- kiln_ai-0.13.0.dist-info/RECORD +103 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
- kiln_ai-0.11.1.dist-info/RECORD +0 -76
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/test_litellm_adapter.py
@@ -0,0 +1,407 @@
+import json
+from unittest.mock import Mock, patch
+
+import pytest
+
+from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
+from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
+from kiln_ai.adapters.model_adapters.litellm_config import (
+    LiteLlmConfig,
+)
+from kiln_ai.datamodel import Project, Task
+
+
+@pytest.fixture
+def mock_task(tmp_path):
+    # Create a project first since Task requires a parent
+    project_path = tmp_path / "test_project" / "project.kiln"
+    project_path.parent.mkdir()
+
+    project = Project(name="Test Project", path=str(project_path))
+    project.save_to_file()
+
+    schema = {
+        "type": "object",
+        "properties": {"test": {"type": "string"}},
+    }
+
+    task = Task(
+        name="Test Task",
+        instruction="Test instruction",
+        parent=project,
+        output_json_schema=json.dumps(schema),
+    )
+    task.save_to_file()
+    return task
+
+
+@pytest.fixture
+def config():
+    return LiteLlmConfig(
+        base_url="https://api.test.com",
+        model_name="test-model",
+        provider_name="openrouter",
+        default_headers={"X-Test": "test"},
+        additional_body_options={"api_key": "test_key"},
+    )
+
+
+def test_initialization(config, mock_task):
+    adapter = LiteLlmAdapter(
+        config=config,
+        kiln_task=mock_task,
+        prompt_id="simple_prompt_builder",
+        base_adapter_config=AdapterConfig(default_tags=["test-tag"]),
+    )
+
+    assert adapter.config == config
+    assert adapter.run_config.task == mock_task
+    assert adapter.run_config.prompt_id == "simple_prompt_builder"
+    assert adapter.base_adapter_config.default_tags == ["test-tag"]
+    assert adapter.run_config.model_name == config.model_name
+    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.config.additional_body_options["api_key"] == "test_key"
+    assert adapter._api_base == config.base_url
+    assert adapter._headers == config.default_headers
+
+
+def test_adapter_info(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    assert adapter.adapter_name() == "kiln_openai_compatible_adapter"
+
+    assert adapter.run_config.model_name == config.model_name
+    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.run_config.prompt_id == "simple_prompt_builder"
+
+
+@pytest.mark.asyncio
+async def test_response_format_options_unstructured(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock has_structured_output to return False
+    with patch.object(adapter, "has_structured_output", return_value=False):
+        options = await adapter.response_format_options()
+        assert options == {}
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.json_mode,
+        StructuredOutputMode.json_instruction_and_object,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_json_mode(config, mock_task, mode):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = mode
+
+        options = await adapter.response_format_options()
+        assert options == {"response_format": {"type": "json_object"}}
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.default,
+        StructuredOutputMode.function_calling,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_function_calling(config, mock_task, mode):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = mode
+
+        options = await adapter.response_format_options()
+        assert "tools" in options
+        # full tool structure validated below
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.json_custom_instructions,
+        StructuredOutputMode.json_instructions,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_json_instructions(config, mock_task, mode):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = (
+            StructuredOutputMode.json_instructions
+        )
+        options = await adapter.response_format_options()
+        assert options == {}
+
+
+@pytest.mark.asyncio
+async def test_response_format_options_json_schema(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = (
+            StructuredOutputMode.json_schema
+        )
+        options = await adapter.response_format_options()
+        assert options == {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "task_response",
+                    "schema": mock_task.output_schema(),
+                },
+            }
+        }
+
+
+def test_tool_call_params_weak(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    params = adapter.tool_call_params(strict=False)
+    expected_schema = mock_task.output_schema()
+    expected_schema["additionalProperties"] = False
+
+    assert params == {
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "task_response",
+                    "parameters": expected_schema,
+                },
+            }
+        ],
+        "tool_choice": {
+            "type": "function",
+            "function": {"name": "task_response"},
+        },
+    }
+
+
+def test_tool_call_params_strict(config, mock_task):
+    config.provider_name = "openai"
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    params = adapter.tool_call_params(strict=True)
+    expected_schema = mock_task.output_schema()
+    expected_schema["additionalProperties"] = False
+
+    assert params == {
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "task_response",
+                    "parameters": expected_schema,
+                    "strict": True,
+                },
+            }
+        ],
+        "tool_choice": {
+            "type": "function",
+            "function": {"name": "task_response"},
+        },
+    }
+
+
+@pytest.mark.parametrize(
+    "provider_name,expected_prefix",
+    [
+        (ModelProviderName.openrouter, "openrouter"),
+        (ModelProviderName.openai, "openai"),
+        (ModelProviderName.groq, "groq"),
+        (ModelProviderName.anthropic, "anthropic"),
+        (ModelProviderName.ollama, "openai"),
+        (ModelProviderName.gemini_api, "gemini"),
+        (ModelProviderName.fireworks_ai, "fireworks_ai"),
+        (ModelProviderName.amazon_bedrock, "bedrock"),
+        (ModelProviderName.azure_openai, "azure"),
+        (ModelProviderName.huggingface, "huggingface"),
+        (ModelProviderName.vertex, "vertex_ai"),
+        (ModelProviderName.together_ai, "together_ai"),
+    ],
+)
+def test_litellm_model_id_standard_providers(
+    config, mock_task, provider_name, expected_prefix
+):
+    """Test litellm_model_id for standard providers"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method to return a provider with the specified name
+    mock_provider = Mock()
+    mock_provider.name = provider_name
+    mock_provider.model_id = "test-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        model_id = adapter.litellm_model_id()
+
+        assert model_id == f"{expected_prefix}/test-model"
+        # Verify caching works
+        assert adapter._litellm_model_id == model_id
+
+
+@pytest.mark.parametrize(
+    "provider_name",
+    [
+        ModelProviderName.openai_compatible,
+        ModelProviderName.kiln_custom_registry,
+        ModelProviderName.kiln_fine_tune,
+    ],
+)
+def test_litellm_model_id_custom_providers(config, mock_task, provider_name):
+    """Test litellm_model_id for custom providers that require a base URL"""
+    config.base_url = "https://api.custom.com"
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method
+    mock_provider = Mock()
+    mock_provider.name = provider_name
+    mock_provider.model_id = "custom-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        model_id = adapter.litellm_model_id()
+
+        # Custom providers should use "openai" as the provider name
+        assert model_id == "openai/custom-model"
+        assert adapter._litellm_model_id == model_id
+
+
+def test_litellm_model_id_custom_provider_no_base_url(config, mock_task):
+    """Test litellm_model_id raises error for custom providers without base URL"""
+    config.base_url = None
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method
+    mock_provider = Mock()
+    mock_provider.name = ModelProviderName.openai_compatible
+    mock_provider.model_id = "custom-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        with pytest.raises(ValueError, match="Explicit Base URL is required"):
+            adapter.litellm_model_id()
+
+
+def test_litellm_model_id_no_model_id(config, mock_task):
+    """Test litellm_model_id raises error when provider has no model_id"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method to return a provider with no model_id
+    mock_provider = Mock()
+    mock_provider.name = ModelProviderName.openai
+    mock_provider.model_id = None
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        with pytest.raises(ValueError, match="Model ID is required"):
+            adapter.litellm_model_id()
+
+
+def test_litellm_model_id_caching(config, mock_task):
+    """Test that litellm_model_id caches the result"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Set the cached value directly
+    adapter._litellm_model_id = "cached-value"
+
+    # The method should return the cached value without calling model_provider
+    with patch.object(adapter, "model_provider") as mock_model_provider:
+        model_id = adapter.litellm_model_id()
+
+        assert model_id == "cached-value"
+        mock_model_provider.assert_not_called()
+
+
+def test_litellm_model_id_unknown_provider(config, mock_task):
+    """Test litellm_model_id raises error for unknown provider"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Create a mock provider with an unknown name
+    mock_provider = Mock()
+    mock_provider.name = "unknown_provider"  # Not in ModelProviderName enum
+    mock_provider.model_id = "test-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        with patch(
+            "kiln_ai.adapters.model_adapters.litellm_adapter.raise_exhaustive_enum_error"
+        ) as mock_raise_error:
+            mock_raise_error.side_effect = Exception("Test error")
+
+            with pytest.raises(Exception, match="Test error"):
+                adapter.litellm_model_id()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "top_logprobs,response_format,extra_body",
+    [
+        (None, {}, {}),  # Basic case
+        (5, {}, {}),  # With logprobs
+        (
+            None,
+            {"response_format": {"type": "json_object"}},
+            {},
+        ),  # With response format
+        (
+            3,
+            {"tools": [{"type": "function"}]},
+            {"reasoning_effort": 0.8},
+        ),  # Combined options
+    ],
+)
+async def test_build_completion_kwargs(
+    config, mock_task, top_logprobs, response_format, extra_body
+):
+    """Test build_completion_kwargs with various configurations"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
+        patch.object(adapter, "build_extra_body", return_value=extra_body),
+        patch.object(adapter, "response_format_options", return_value=response_format),
+    ):
+        kwargs = await adapter.build_completion_kwargs(
+            mock_provider, messages, top_logprobs
+        )

+        # Verify core functionality
+        assert kwargs["model"] == "openai/test-model"
+        assert kwargs["messages"] == messages
+        assert kwargs["api_base"] == config.base_url
+
+        # Verify optional parameters
+        if top_logprobs is not None:
+            assert kwargs["logprobs"] is True
+            assert kwargs["top_logprobs"] == top_logprobs
+        else:
+            assert "logprobs" not in kwargs
+            assert "top_logprobs" not in kwargs
+
+        # Verify response format is included
+        for key, value in response_format.items():
+            assert kwargs[key] == value
+
+        # Verify extra body is included
+        for key, value in extra_body.items():
+            assert kwargs[key] == value
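
Note: the parametrized cases above pin down how litellm_model_id() is expected to behave in 0.13.0: each known provider maps to a LiteLLM route prefix, while the custom providers (openai_compatible, kiln_custom_registry, kiln_fine_tune) are routed as "openai/<model>" and require an explicit base URL. The sketch below captures only what those tests assert; the function name, lookup table, and error strings are illustrative and are not the shipped LiteLlmAdapter implementation.

# Hedged sketch of the provider-to-prefix mapping the tests above assert.
LITELLM_PREFIX = {
    "openrouter": "openrouter",
    "openai": "openai",
    "groq": "groq",
    "anthropic": "anthropic",
    "ollama": "openai",  # driven through an OpenAI-compatible endpoint
    "gemini_api": "gemini",
    "fireworks_ai": "fireworks_ai",
    "amazon_bedrock": "bedrock",
    "azure_openai": "azure",
    "huggingface": "huggingface",
    "vertex": "vertex_ai",
    "together_ai": "together_ai",
}
CUSTOM_PROVIDERS = {"openai_compatible", "kiln_custom_registry", "kiln_fine_tune"}


def litellm_model_id(provider_name: str, model_id: str | None, base_url: str | None) -> str:
    """Build the 'prefix/model' string the tests expect (illustrative only)."""
    if not model_id:
        raise ValueError("Model ID is required")
    if provider_name in CUSTOM_PROVIDERS:
        if not base_url:
            # Same error text the tests match with pytest.raises(..., match=...)
            raise ValueError("Explicit Base URL is required for custom providers")
        return f"openai/{model_id}"
    return f"{LITELLM_PREFIX[provider_name]}/{model_id}"


litellm_model_id("ollama", "test-model", None)  # -> "openai/test-model"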
kiln_ai/adapters/model_adapters/test_saving_adapter_results.py
@@ -3,7 +3,6 @@ from unittest.mock import patch
 import pytest
 
 from kiln_ai.adapters.model_adapters.base_adapter import (
-    AdapterInfo,
     BaseAdapter,
     RunOutput,
 )
@@ -13,6 +12,7 @@ from kiln_ai.datamodel import (
     Project,
     Task,
 )
+from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config
 
 
@@ -20,14 +20,8 @@ class MockAdapter(BaseAdapter):
     async def _run(self, input: dict | str) -> dict | str:
         return RunOutput(output="Test output", intermediate_outputs=None)
 
-    def adapter_info(self) -> AdapterInfo:
-        return AdapterInfo(
-            adapter_name="mock_adapter",
-            model_name="mock_model",
-            model_provider="mock_provider",
-            prompt_builder_name="mock_prompt_builder",
-            prompt_id="mock_prompt_id",
-        )
+    def adapter_name(self) -> str:
+        return "mock_adapter"
 
 
 @pytest.fixture
@@ -45,7 +39,14 @@ def test_task(tmp_path):
 
 @pytest.fixture
 def adapter(test_task):
-    return MockAdapter(
+    return MockAdapter(
+        run_config=RunConfig(
+            task=test_task,
+            model_name="phi_3_5",
+            model_provider_name="ollama",
+            prompt_id="simple_chain_of_thought_prompt_builder",
+        ),
+    )
 
 
 def test_save_run_isolation(test_task, adapter):
@@ -94,13 +95,12 @@ def test_save_run_isolation(test_task, adapter):
     assert reloaded_output.source.type == DataSourceType.synthetic
     assert reloaded_output.rating is None
    assert reloaded_output.source.properties["adapter_name"] == "mock_adapter"
-    assert reloaded_output.source.properties["model_name"] == "mock_model"
-    assert reloaded_output.source.properties["model_provider"] == "mock_provider"
+    assert reloaded_output.source.properties["model_name"] == "phi_3_5"
+    assert reloaded_output.source.properties["model_provider"] == "ollama"
     assert (
-        reloaded_output.source.properties["prompt_builder_name"]
-        == "mock_prompt_builder"
+        reloaded_output.source.properties["prompt_id"]
+        == "simple_chain_of_thought_prompt_builder"
     )
-    assert reloaded_output.source.properties["prompt_id"] == "mock_prompt_id"
     # Run again, with same input and different output. Should create a new TaskRun.
     different_run_output = RunOutput(
         output="Different output", intermediate_outputs=None
@@ -118,7 +118,7 @@ def test_save_run_isolation(test_task, adapter):
             properties={
                 "model_name": "mock_model",
                 "model_provider": "mock_provider",
-                "prompt_builder_name": "mock_prompt_builder",
+                "prompt_id": "mock_prompt_builder",
                 "adapter_name": "mock_adapter",
             },
         ),
@@ -178,6 +178,25 @@ async def test_autosave_false(test_task, adapter):
     assert run.id is None
 
 
+@pytest.mark.asyncio
+async def test_autosave_true_with_disabled(test_task, adapter):
+    with patch("kiln_ai.utils.config.Config.shared") as mock_shared:
+        mock_config = mock_shared.return_value
+        mock_config.autosave_runs = True
+        mock_config.user_id = "test_user"
+
+        input_data = "Test input"
+
+        adapter.base_adapter_config.allow_saving = False
+        run = await adapter.invoke(input_data)
+
+        # Check that no runs were saved
+        assert len(test_task.runs()) == 0
+
+        # Check that the run ID is not set
+        assert run.id is None
+
+
 @pytest.mark.asyncio
 async def test_autosave_true(test_task, adapter):
     with patch("kiln_ai.utils.config.Config.shared") as mock_shared:
@@ -202,6 +221,9 @@ async def test_autosave_true(test_task, adapter):
         assert output.output == "Test output"
         assert output.source.type == DataSourceType.synthetic
         assert output.source.properties["adapter_name"] == "mock_adapter"
-        assert output.source.properties["model_name"] == "mock_model"
-        assert output.source.properties["model_provider"] == "mock_provider"
-        assert output.source.properties["prompt_builder_name"] == "mock_prompt_builder"
+        assert output.source.properties["model_name"] == "phi_3_5"
+        assert output.source.properties["model_provider"] == "ollama"
+        assert (
+            output.source.properties["prompt_id"]
+            == "simple_chain_of_thought_prompt_builder"
+        )
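
Note: together with the test_litellm_adapter.py diff above, these changes show the shape of the new adapter API in 0.13.0: run metadata (model_name, model_provider_name, prompt_id) now lives on a RunConfig passed to the adapter, and adapters report only an adapter_name() string in place of the removed adapter_info()/AdapterInfo pair, with saved runs recording those run_config values in their source properties. The sketch below is assembled from the fixtures in this diff; the make_adapter helper is illustrative and not part of the package.

from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
from kiln_ai.datamodel import Task
from kiln_ai.datamodel.task import RunConfig


class MockAdapter(BaseAdapter):
    # Same stub the tests use: no model call, fixed output.
    async def _run(self, input: dict | str) -> dict | str:
        return RunOutput(output="Test output", intermediate_outputs=None)

    def adapter_name(self) -> str:
        return "mock_adapter"


def make_adapter(task: Task) -> MockAdapter:
    # Illustrative helper mirroring the `adapter` fixture above; saved runs
    # will record these run_config values in their source properties.
    return MockAdapter(
        run_config=RunConfig(
            task=task,
            model_name="phi_3_5",
            model_provider_name="ollama",
            prompt_id="simple_chain_of_thought_prompt_builder",
        ),
    )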