kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +234 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
- kiln_ai/adapters/eval/base_eval.py +8 -6
- kiln_ai/adapters/eval/eval_runner.py +9 -65
- kiln_ai/adapters/eval/g_eval.py +26 -8
- kiln_ai/adapters/eval/test_base_eval.py +166 -15
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
- kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
- kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
- kiln_ai/adapters/ml_model_list.py +556 -45
- kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
- kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -17
- kiln_ai/adapters/provider_tools.py +39 -4
- kiln_ai/adapters/repair/test_repair_task.py +27 -5
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +158 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -3
- kiln_ai/adapters/test_prompt_builders.py +27 -19
- kiln_ai/adapters/test_provider_tools.py +130 -12
- kiln_ai/datamodel/__init__.py +2 -2
- kiln_ai/datamodel/datamodel_enums.py +43 -4
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +13 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/task_output.py +1 -1
- kiln_ai/datamodel/task_run.py +39 -7
- kiln_ai/datamodel/test_basemodel.py +5 -8
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -8
- kiln_ai/datamodel/test_example_models.py +54 -0
- kiln_ai/datamodel/test_models.py +80 -9
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +81 -19
- kiln_ai/utils/logging.py +165 -0
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +272 -10
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
- kiln_ai-0.17.0.dist-info/RECORD +113 -0
- kiln_ai-0.15.0.dist-info/RECORD +0 -104
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -3,16 +3,18 @@ from unittest.mock import MagicMock, patch
|
|
|
3
3
|
import pytest
|
|
4
4
|
|
|
5
5
|
from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
|
|
6
|
-
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
|
|
6
|
+
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
|
|
7
|
+
from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
|
|
7
8
|
from kiln_ai.datamodel import Task
|
|
8
|
-
from kiln_ai.datamodel.
|
|
9
|
+
from kiln_ai.datamodel.datamodel_enums import ChatStrategy
|
|
10
|
+
from kiln_ai.datamodel.task import RunConfig, RunConfigProperties
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
class MockAdapter(BaseAdapter):
|
|
12
14
|
"""Concrete implementation of BaseAdapter for testing"""
|
|
13
15
|
|
|
14
16
|
async def _run(self, input):
|
|
15
|
-
return None
|
|
17
|
+
return None, None
|
|
16
18
|
|
|
17
19
|
def adapter_name(self) -> str:
|
|
18
20
|
return "test"
|
|
@@ -36,12 +38,29 @@ def adapter(base_task):
|
|
|
36
38
|
run_config=RunConfig(
|
|
37
39
|
task=base_task,
|
|
38
40
|
model_name="test_model",
|
|
39
|
-
model_provider_name="
|
|
41
|
+
model_provider_name="openai",
|
|
40
42
|
prompt_id="simple_prompt_builder",
|
|
43
|
+
structured_output_mode="json_schema",
|
|
41
44
|
),
|
|
42
45
|
)
|
|
43
46
|
|
|
44
47
|
|
|
48
|
+
@pytest.fixture
|
|
49
|
+
def mock_formatter():
|
|
50
|
+
formatter = MagicMock()
|
|
51
|
+
formatter.format_input.return_value = {"formatted": "input"}
|
|
52
|
+
return formatter
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@pytest.fixture
|
|
56
|
+
def mock_parser():
|
|
57
|
+
parser = MagicMock()
|
|
58
|
+
parser.parse_output.return_value = RunOutput(
|
|
59
|
+
output="test output", intermediate_outputs={}
|
|
60
|
+
)
|
|
61
|
+
return parser
|
|
62
|
+
|
|
63
|
+
|
|
45
64
|
async def test_model_provider_uses_cache(adapter, mock_provider):
|
|
46
65
|
"""Test that cached provider is returned if it exists"""
|
|
47
66
|
# Set up cached provider
|
|
@@ -71,7 +90,7 @@ async def test_model_provider_loads_and_caches(adapter, mock_provider):
|
|
|
71
90
|
# First call should load and cache
|
|
72
91
|
provider1 = adapter.model_provider()
|
|
73
92
|
assert provider1 == mock_provider
|
|
74
|
-
mock_loader.assert_called_once_with("test_model", "
|
|
93
|
+
mock_loader.assert_called_once_with("test_model", "openai")
|
|
75
94
|
|
|
76
95
|
# Second call should use cache
|
|
77
96
|
mock_loader.reset_mock()
|
|
@@ -80,29 +99,30 @@ async def test_model_provider_loads_and_caches(adapter, mock_provider):
|
|
|
80
99
|
mock_loader.assert_not_called()
|
|
81
100
|
|
|
82
101
|
|
|
83
|
-
async def
|
|
102
|
+
async def test_model_provider_invalid_provider_model_name(base_task):
|
|
103
|
+
"""Test error when model or provider name is missing"""
|
|
104
|
+
# Test with missing model name
|
|
105
|
+
with pytest.raises(ValueError, match="Input should be"):
|
|
106
|
+
adapter = MockAdapter(
|
|
107
|
+
run_config=RunConfig(
|
|
108
|
+
task=base_task,
|
|
109
|
+
model_name="test_model",
|
|
110
|
+
model_provider_name="invalid",
|
|
111
|
+
prompt_id="simple_prompt_builder",
|
|
112
|
+
),
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
async def test_model_provider_missing_model_names(base_task):
|
|
84
117
|
"""Test error when model or provider name is missing"""
|
|
85
118
|
# Test with missing model name
|
|
86
119
|
adapter = MockAdapter(
|
|
87
120
|
run_config=RunConfig(
|
|
88
121
|
task=base_task,
|
|
89
122
|
model_name="",
|
|
90
|
-
model_provider_name="",
|
|
91
|
-
prompt_id="simple_prompt_builder",
|
|
92
|
-
),
|
|
93
|
-
)
|
|
94
|
-
with pytest.raises(
|
|
95
|
-
ValueError, match="model_name and model_provider_name must be provided"
|
|
96
|
-
):
|
|
97
|
-
await adapter.model_provider()
|
|
98
|
-
|
|
99
|
-
# Test with missing provider name
|
|
100
|
-
adapter = MockAdapter(
|
|
101
|
-
run_config=RunConfig(
|
|
102
|
-
task=base_task,
|
|
103
|
-
model_name="test_model",
|
|
104
|
-
model_provider_name="",
|
|
123
|
+
model_provider_name="openai",
|
|
105
124
|
prompt_id="simple_prompt_builder",
|
|
125
|
+
structured_output_mode="json_schema",
|
|
106
126
|
),
|
|
107
127
|
)
|
|
108
128
|
with pytest.raises(
|
|
@@ -121,7 +141,7 @@ async def test_model_provider_not_found(adapter):
|
|
|
121
141
|
|
|
122
142
|
with pytest.raises(
|
|
123
143
|
ValueError,
|
|
124
|
-
match="
|
|
144
|
+
match="not found for model test_model",
|
|
125
145
|
):
|
|
126
146
|
await adapter.model_provider()
|
|
127
147
|
|
|
@@ -151,11 +171,7 @@ async def test_prompt_builder_json_instructions(
|
|
|
151
171
|
adapter.prompt_builder = mock_prompt_builder
|
|
152
172
|
adapter.model_provider_name = "openai"
|
|
153
173
|
adapter.has_structured_output = MagicMock(return_value=output_schema)
|
|
154
|
-
|
|
155
|
-
# provider mock
|
|
156
|
-
provider = MagicMock()
|
|
157
|
-
provider.structured_output_mode = structured_output_mode
|
|
158
|
-
adapter.model_provider = MagicMock(return_value=provider)
|
|
174
|
+
adapter.run_config.structured_output_mode = structured_output_mode
|
|
159
175
|
|
|
160
176
|
# Test
|
|
161
177
|
adapter.build_prompt()
|
|
@@ -164,36 +180,267 @@ async def test_prompt_builder_json_instructions(
|
|
|
164
180
|
)
|
|
165
181
|
|
|
166
182
|
|
|
183
|
+
@pytest.mark.asyncio
|
|
167
184
|
@pytest.mark.parametrize(
|
|
168
|
-
"
|
|
185
|
+
"formatter_id,expected_input,expected_calls",
|
|
169
186
|
[
|
|
170
|
-
|
|
171
|
-
("
|
|
172
|
-
# Structured output with thinking-capable LLM
|
|
173
|
-
("think carefully", True, True, ("cot_as_message", "think carefully")),
|
|
174
|
-
# Structured output with normal LLM
|
|
175
|
-
("think carefully", True, False, ("cot_two_call", "think carefully")),
|
|
176
|
-
# Basic cases - no COT
|
|
177
|
-
(None, True, True, ("basic", None)),
|
|
178
|
-
(None, False, False, ("basic", None)),
|
|
179
|
-
(None, True, False, ("basic", None)),
|
|
180
|
-
(None, False, True, ("basic", None)),
|
|
181
|
-
# Edge case - COT prompt exists but structured output is False and reasoning_capable is True
|
|
182
|
-
("think carefully", False, True, ("cot_as_message", "think carefully")),
|
|
187
|
+
(None, {"original": "input"}, 0), # No formatter
|
|
188
|
+
("test_formatter", {"formatted": "input"}, 1), # With formatter
|
|
183
189
|
],
|
|
184
190
|
)
|
|
185
|
-
async def
|
|
186
|
-
adapter,
|
|
191
|
+
async def test_input_formatting(
|
|
192
|
+
adapter, mock_formatter, mock_parser, formatter_id, expected_input, expected_calls
|
|
187
193
|
):
|
|
188
|
-
"""Test that
|
|
189
|
-
# Mock
|
|
190
|
-
adapter.prompt_builder.chain_of_thought_prompt = MagicMock(return_value=cot_prompt)
|
|
191
|
-
adapter.has_structured_output = MagicMock(return_value=has_structured_output)
|
|
192
|
-
|
|
194
|
+
"""Test that input formatting is handled correctly based on formatter configuration"""
|
|
195
|
+
# Mock the model provider to return our formatter ID and parser
|
|
193
196
|
provider = MagicMock()
|
|
194
|
-
provider.
|
|
197
|
+
provider.formatter = formatter_id
|
|
198
|
+
provider.parser = "test_parser"
|
|
199
|
+
provider.reasoning_capable = False
|
|
195
200
|
adapter.model_provider = MagicMock(return_value=provider)
|
|
196
201
|
|
|
197
|
-
#
|
|
198
|
-
|
|
199
|
-
|
|
202
|
+
# Mock the formatter factory and parser factory
|
|
203
|
+
with (
|
|
204
|
+
patch(
|
|
205
|
+
"kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id"
|
|
206
|
+
) as mock_factory,
|
|
207
|
+
patch(
|
|
208
|
+
"kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id"
|
|
209
|
+
) as mock_parser_factory,
|
|
210
|
+
):
|
|
211
|
+
mock_factory.return_value = mock_formatter
|
|
212
|
+
mock_parser_factory.return_value = mock_parser
|
|
213
|
+
|
|
214
|
+
# Mock the _run method to capture the input
|
|
215
|
+
captured_input = None
|
|
216
|
+
|
|
217
|
+
async def mock_run(input):
|
|
218
|
+
nonlocal captured_input
|
|
219
|
+
captured_input = input
|
|
220
|
+
return RunOutput(output="test output", intermediate_outputs={}), None
|
|
221
|
+
|
|
222
|
+
adapter._run = mock_run
|
|
223
|
+
|
|
224
|
+
# Run the adapter
|
|
225
|
+
original_input = {"original": "input"}
|
|
226
|
+
await adapter.invoke_returning_run_output(original_input)
|
|
227
|
+
|
|
228
|
+
# Verify formatter was called correctly
|
|
229
|
+
assert captured_input == expected_input
|
|
230
|
+
assert mock_factory.call_count == (1 if formatter_id else 0)
|
|
231
|
+
assert mock_formatter.format_input.call_count == expected_calls
|
|
232
|
+
|
|
233
|
+
# Verify original input was preserved in the run
|
|
234
|
+
if formatter_id:
|
|
235
|
+
mock_formatter.format_input.assert_called_once_with(original_input)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
async def test_properties_for_task_output_includes_all_run_config_properties(adapter):
|
|
239
|
+
"""Test that all properties from RunConfigProperties are saved in task output properties"""
|
|
240
|
+
# Get all field names from RunConfigProperties
|
|
241
|
+
run_config_properties_fields = set(RunConfigProperties.model_fields.keys())
|
|
242
|
+
|
|
243
|
+
# Get the properties saved by the adapter
|
|
244
|
+
saved_properties = adapter._properties_for_task_output()
|
|
245
|
+
saved_property_keys = set(saved_properties.keys())
|
|
246
|
+
|
|
247
|
+
# Check which RunConfigProperties fields are missing from saved properties
|
|
248
|
+
# Note: model_provider_name becomes model_provider in saved properties
|
|
249
|
+
expected_mappings = {
|
|
250
|
+
"model_name": "model_name",
|
|
251
|
+
"model_provider_name": "model_provider",
|
|
252
|
+
"prompt_id": "prompt_id",
|
|
253
|
+
"temperature": "temperature",
|
|
254
|
+
"top_p": "top_p",
|
|
255
|
+
"structured_output_mode": "structured_output_mode",
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
missing_properties = []
|
|
259
|
+
for field_name in run_config_properties_fields:
|
|
260
|
+
expected_key = expected_mappings.get(field_name, field_name)
|
|
261
|
+
if expected_key not in saved_property_keys:
|
|
262
|
+
missing_properties.append(
|
|
263
|
+
f"RunConfigProperties.{field_name} -> {expected_key}"
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
assert not missing_properties, (
|
|
267
|
+
f"The following RunConfigProperties fields are not saved by _properties_for_task_output: {missing_properties}. Please update the method to include them."
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
async def test_properties_for_task_output_catches_missing_new_property(adapter):
|
|
272
|
+
"""Test that demonstrates our test will catch when new properties are added to RunConfigProperties but not to _properties_for_task_output"""
|
|
273
|
+
# Simulate what happens if a new property was added to RunConfigProperties
|
|
274
|
+
# We'll mock the model_fields to include a fake new property
|
|
275
|
+
original_fields = RunConfigProperties.model_fields.copy()
|
|
276
|
+
|
|
277
|
+
# Create a mock field to simulate a new property being added
|
|
278
|
+
from pydantic.fields import FieldInfo
|
|
279
|
+
|
|
280
|
+
mock_field = FieldInfo(annotation=str, default="default_value")
|
|
281
|
+
|
|
282
|
+
try:
|
|
283
|
+
# Add a fake new field to simulate someone adding a property
|
|
284
|
+
RunConfigProperties.model_fields["new_fake_property"] = mock_field
|
|
285
|
+
|
|
286
|
+
# Get all field names from RunConfigProperties (now includes our fake property)
|
|
287
|
+
run_config_properties_fields = set(RunConfigProperties.model_fields.keys())
|
|
288
|
+
|
|
289
|
+
# Get the properties saved by the adapter (won't include our fake property)
|
|
290
|
+
saved_properties = adapter._properties_for_task_output()
|
|
291
|
+
saved_property_keys = set(saved_properties.keys())
|
|
292
|
+
|
|
293
|
+
# The mappings don't include our fake property
|
|
294
|
+
expected_mappings = {
|
|
295
|
+
"model_name": "model_name",
|
|
296
|
+
"model_provider_name": "model_provider",
|
|
297
|
+
"prompt_id": "prompt_id",
|
|
298
|
+
"temperature": "temperature",
|
|
299
|
+
"top_p": "top_p",
|
|
300
|
+
"structured_output_mode": "structured_output_mode",
|
|
301
|
+
}
|
|
302
|
+
|
|
303
|
+
missing_properties = []
|
|
304
|
+
for field_name in run_config_properties_fields:
|
|
305
|
+
expected_key = expected_mappings.get(field_name, field_name)
|
|
306
|
+
if expected_key not in saved_property_keys:
|
|
307
|
+
missing_properties.append(
|
|
308
|
+
f"RunConfigProperties.{field_name} -> {expected_key}"
|
|
309
|
+
)
|
|
310
|
+
|
|
311
|
+
# This should find our missing fake property
|
|
312
|
+
assert missing_properties == [
|
|
313
|
+
"RunConfigProperties.new_fake_property -> new_fake_property"
|
|
314
|
+
], f"Expected to find missing fake property, but got: {missing_properties}"
|
|
315
|
+
|
|
316
|
+
finally:
|
|
317
|
+
# Restore the original fields
|
|
318
|
+
RunConfigProperties.model_fields.clear()
|
|
319
|
+
RunConfigProperties.model_fields.update(original_fields)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@pytest.mark.parametrize(
|
|
323
|
+
"cot_prompt,tuned_strategy,reasoning_capable,expected_formatter_class",
|
|
324
|
+
[
|
|
325
|
+
# No COT prompt -> always single turn
|
|
326
|
+
(None, None, False, "SingleTurnFormatter"),
|
|
327
|
+
(None, ChatStrategy.two_message_cot, False, "SingleTurnFormatter"),
|
|
328
|
+
(None, ChatStrategy.single_turn_r1_thinking, True, "SingleTurnFormatter"),
|
|
329
|
+
# With COT prompt:
|
|
330
|
+
# - Tuned strategy takes precedence (except single turn)
|
|
331
|
+
(
|
|
332
|
+
"think step by step",
|
|
333
|
+
ChatStrategy.two_message_cot,
|
|
334
|
+
False,
|
|
335
|
+
"TwoMessageCotFormatter",
|
|
336
|
+
),
|
|
337
|
+
(
|
|
338
|
+
"think step by step",
|
|
339
|
+
ChatStrategy.single_turn_r1_thinking,
|
|
340
|
+
False,
|
|
341
|
+
"SingleTurnR1ThinkingFormatter",
|
|
342
|
+
),
|
|
343
|
+
# - Tuned single turn is ignored when COT exists
|
|
344
|
+
(
|
|
345
|
+
"think step by step",
|
|
346
|
+
ChatStrategy.single_turn,
|
|
347
|
+
True,
|
|
348
|
+
"SingleTurnR1ThinkingFormatter",
|
|
349
|
+
),
|
|
350
|
+
# - Reasoning capable -> single turn R1 thinking
|
|
351
|
+
("think step by step", None, True, "SingleTurnR1ThinkingFormatter"),
|
|
352
|
+
# - Not reasoning capable -> two message COT
|
|
353
|
+
("think step by step", None, False, "TwoMessageCotFormatter"),
|
|
354
|
+
],
|
|
355
|
+
)
|
|
356
|
+
def test_build_chat_formatter(
|
|
357
|
+
adapter,
|
|
358
|
+
cot_prompt,
|
|
359
|
+
tuned_strategy,
|
|
360
|
+
reasoning_capable,
|
|
361
|
+
expected_formatter_class,
|
|
362
|
+
):
|
|
363
|
+
"""Test chat formatter strategy selection based on COT prompt, tuned strategy, and model capabilities"""
|
|
364
|
+
# Mock the prompt builder
|
|
365
|
+
mock_prompt_builder = MagicMock()
|
|
366
|
+
mock_prompt_builder.chain_of_thought_prompt.return_value = cot_prompt
|
|
367
|
+
mock_prompt_builder.build_prompt.return_value = "system message"
|
|
368
|
+
adapter.prompt_builder = mock_prompt_builder
|
|
369
|
+
|
|
370
|
+
# Mock the model provider
|
|
371
|
+
mock_provider = MagicMock()
|
|
372
|
+
mock_provider.tuned_chat_strategy = tuned_strategy
|
|
373
|
+
mock_provider.reasoning_capable = reasoning_capable
|
|
374
|
+
adapter.model_provider = MagicMock(return_value=mock_provider)
|
|
375
|
+
|
|
376
|
+
# Get the formatter
|
|
377
|
+
formatter = adapter.build_chat_formatter("test input")
|
|
378
|
+
|
|
379
|
+
# Verify the formatter type
|
|
380
|
+
assert formatter.__class__.__name__ == expected_formatter_class
|
|
381
|
+
|
|
382
|
+
# Verify the formatter was created with correct parameters
|
|
383
|
+
assert formatter.system_message == "system message"
|
|
384
|
+
assert formatter.user_input == "test input"
|
|
385
|
+
# Only check thinking_instructions for formatters that use it
|
|
386
|
+
if expected_formatter_class == "TwoMessageCotFormatter":
|
|
387
|
+
if cot_prompt:
|
|
388
|
+
assert formatter.thinking_instructions == cot_prompt
|
|
389
|
+
else:
|
|
390
|
+
assert formatter.thinking_instructions is None
|
|
391
|
+
# For other formatters, don't assert thinking_instructions
|
|
392
|
+
|
|
393
|
+
# Verify prompt builder was called correctly
|
|
394
|
+
mock_prompt_builder.build_prompt.assert_called_once()
|
|
395
|
+
mock_prompt_builder.chain_of_thought_prompt.assert_called_once()
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
@pytest.mark.parametrize(
|
|
399
|
+
"initial_mode,expected_mode",
|
|
400
|
+
[
|
|
401
|
+
(
|
|
402
|
+
StructuredOutputMode.json_schema,
|
|
403
|
+
StructuredOutputMode.json_schema,
|
|
404
|
+
), # Should not change
|
|
405
|
+
(
|
|
406
|
+
StructuredOutputMode.unknown,
|
|
407
|
+
StructuredOutputMode.json_mode,
|
|
408
|
+
), # Should update to default
|
|
409
|
+
],
|
|
410
|
+
)
|
|
411
|
+
async def test_update_run_config_unknown_structured_output_mode(
|
|
412
|
+
base_task, initial_mode, expected_mode
|
|
413
|
+
):
|
|
414
|
+
"""Test that unknown structured output mode is updated to the default for the model provider"""
|
|
415
|
+
# Create a run config with the initial mode
|
|
416
|
+
run_config = RunConfig(
|
|
417
|
+
task=base_task,
|
|
418
|
+
model_name="test_model",
|
|
419
|
+
model_provider_name="openai",
|
|
420
|
+
prompt_id="simple_prompt_builder",
|
|
421
|
+
structured_output_mode=initial_mode,
|
|
422
|
+
temperature=0.7, # Add some other properties to verify they're preserved
|
|
423
|
+
top_p=0.9,
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
# Mock the default mode lookup
|
|
427
|
+
with patch(
|
|
428
|
+
"kiln_ai.adapters.model_adapters.base_adapter.default_structured_output_mode_for_model_provider"
|
|
429
|
+
) as mock_default:
|
|
430
|
+
mock_default.return_value = StructuredOutputMode.json_mode
|
|
431
|
+
|
|
432
|
+
# Create the adapter
|
|
433
|
+
adapter = MockAdapter(run_config=run_config)
|
|
434
|
+
|
|
435
|
+
# Verify the mode was updated correctly
|
|
436
|
+
assert adapter.run_config.structured_output_mode == expected_mode
|
|
437
|
+
|
|
438
|
+
# Verify other properties were preserved
|
|
439
|
+
assert adapter.run_config.temperature == 0.7
|
|
440
|
+
assert adapter.run_config.top_p == 0.9
|
|
441
|
+
|
|
442
|
+
# Verify the default mode lookup was only called when needed
|
|
443
|
+
if initial_mode == StructuredOutputMode.unknown:
|
|
444
|
+
mock_default.assert_called_once_with("test_model", "openai")
|
|
445
|
+
else:
|
|
446
|
+
mock_default.assert_not_called()
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from unittest.mock import Mock, patch
|
|
3
3
|
|
|
4
|
+
import litellm
|
|
4
5
|
import pytest
|
|
5
6
|
|
|
6
7
|
from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
|
|
@@ -9,7 +10,8 @@ from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
|
|
|
9
10
|
from kiln_ai.adapters.model_adapters.litellm_config import (
|
|
10
11
|
LiteLlmConfig,
|
|
11
12
|
)
|
|
12
|
-
from kiln_ai.datamodel import Project, Task
|
|
13
|
+
from kiln_ai.datamodel import Project, Task, Usage
|
|
14
|
+
from kiln_ai.datamodel.task import RunConfigProperties
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
@pytest.fixture
|
|
@@ -40,8 +42,12 @@ def mock_task(tmp_path):
|
|
|
40
42
|
def config():
|
|
41
43
|
return LiteLlmConfig(
|
|
42
44
|
base_url="https://api.test.com",
|
|
43
|
-
|
|
44
|
-
|
|
45
|
+
run_config_properties=RunConfigProperties(
|
|
46
|
+
model_name="test-model",
|
|
47
|
+
model_provider_name="openrouter",
|
|
48
|
+
prompt_id="simple_prompt_builder",
|
|
49
|
+
structured_output_mode="json_schema",
|
|
50
|
+
),
|
|
45
51
|
default_headers={"X-Test": "test"},
|
|
46
52
|
additional_body_options={"api_key": "test_key"},
|
|
47
53
|
)
|
|
@@ -51,7 +57,6 @@ def test_initialization(config, mock_task):
|
|
|
51
57
|
adapter = LiteLlmAdapter(
|
|
52
58
|
config=config,
|
|
53
59
|
kiln_task=mock_task,
|
|
54
|
-
prompt_id="simple_prompt_builder",
|
|
55
60
|
base_adapter_config=AdapterConfig(default_tags=["test-tag"]),
|
|
56
61
|
)
|
|
57
62
|
|
|
@@ -59,8 +64,11 @@ def test_initialization(config, mock_task):
|
|
|
59
64
|
assert adapter.run_config.task == mock_task
|
|
60
65
|
assert adapter.run_config.prompt_id == "simple_prompt_builder"
|
|
61
66
|
assert adapter.base_adapter_config.default_tags == ["test-tag"]
|
|
62
|
-
assert adapter.run_config.model_name == config.model_name
|
|
63
|
-
assert
|
|
67
|
+
assert adapter.run_config.model_name == config.run_config_properties.model_name
|
|
68
|
+
assert (
|
|
69
|
+
adapter.run_config.model_provider_name
|
|
70
|
+
== config.run_config_properties.model_provider_name
|
|
71
|
+
)
|
|
64
72
|
assert adapter.config.additional_body_options["api_key"] == "test_key"
|
|
65
73
|
assert adapter._api_base == config.base_url
|
|
66
74
|
assert adapter._headers == config.default_headers
|
|
@@ -71,8 +79,11 @@ def test_adapter_info(config, mock_task):
|
|
|
71
79
|
|
|
72
80
|
assert adapter.adapter_name() == "kiln_openai_compatible_adapter"
|
|
73
81
|
|
|
74
|
-
assert adapter.run_config.model_name == config.model_name
|
|
75
|
-
assert
|
|
82
|
+
assert adapter.run_config.model_name == config.run_config_properties.model_name
|
|
83
|
+
assert (
|
|
84
|
+
adapter.run_config.model_provider_name
|
|
85
|
+
== config.run_config_properties.model_provider_name
|
|
86
|
+
)
|
|
76
87
|
assert adapter.run_config.prompt_id == "simple_prompt_builder"
|
|
77
88
|
|
|
78
89
|
|
|
@@ -95,14 +106,12 @@ async def test_response_format_options_unstructured(config, mock_task):
|
|
|
95
106
|
)
|
|
96
107
|
@pytest.mark.asyncio
|
|
97
108
|
async def test_response_format_options_json_mode(config, mock_task, mode):
|
|
109
|
+
config.run_config_properties.structured_output_mode = mode
|
|
98
110
|
adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
|
|
99
111
|
|
|
100
112
|
with (
|
|
101
113
|
patch.object(adapter, "has_structured_output", return_value=True),
|
|
102
|
-
patch.object(adapter, "model_provider") as mock_provider,
|
|
103
114
|
):
|
|
104
|
-
mock_provider.return_value.structured_output_mode = mode
|
|
105
|
-
|
|
106
115
|
options = await adapter.response_format_options()
|
|
107
116
|
assert options == {"response_format": {"type": "json_object"}}
|
|
108
117
|
|
|
@@ -116,14 +125,12 @@ async def test_response_format_options_json_mode(config, mock_task, mode):
|
|
|
116
125
|
)
|
|
117
126
|
@pytest.mark.asyncio
|
|
118
127
|
async def test_response_format_options_function_calling(config, mock_task, mode):
|
|
128
|
+
config.run_config_properties.structured_output_mode = mode
|
|
119
129
|
adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
|
|
120
130
|
|
|
121
131
|
with (
|
|
122
132
|
patch.object(adapter, "has_structured_output", return_value=True),
|
|
123
|
-
patch.object(adapter, "model_provider") as mock_provider,
|
|
124
133
|
):
|
|
125
|
-
mock_provider.return_value.structured_output_mode = mode
|
|
126
|
-
|
|
127
134
|
options = await adapter.response_format_options()
|
|
128
135
|
assert "tools" in options
|
|
129
136
|
# full tool structure validated below
|
|
@@ -138,30 +145,26 @@ async def test_response_format_options_function_calling(config, mock_task, mode)
|
|
|
138
145
|
)
|
|
139
146
|
@pytest.mark.asyncio
|
|
140
147
|
async def test_response_format_options_json_instructions(config, mock_task, mode):
|
|
148
|
+
config.run_config_properties.structured_output_mode = mode
|
|
141
149
|
adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
|
|
142
150
|
|
|
143
151
|
with (
|
|
144
152
|
patch.object(adapter, "has_structured_output", return_value=True),
|
|
145
|
-
patch.object(adapter, "model_provider") as mock_provider,
|
|
146
153
|
):
|
|
147
|
-
mock_provider.return_value.structured_output_mode = (
|
|
148
|
-
StructuredOutputMode.json_instructions
|
|
149
|
-
)
|
|
150
154
|
options = await adapter.response_format_options()
|
|
151
155
|
assert options == {}
|
|
152
156
|
|
|
153
157
|
|
|
154
158
|
@pytest.mark.asyncio
|
|
155
159
|
async def test_response_format_options_json_schema(config, mock_task):
|
|
160
|
+
config.run_config_properties.structured_output_mode = (
|
|
161
|
+
StructuredOutputMode.json_schema
|
|
162
|
+
)
|
|
156
163
|
adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
|
|
157
164
|
|
|
158
165
|
with (
|
|
159
166
|
patch.object(adapter, "has_structured_output", return_value=True),
|
|
160
|
-
patch.object(adapter, "model_provider") as mock_provider,
|
|
161
167
|
):
|
|
162
|
-
mock_provider.return_value.structured_output_mode = (
|
|
163
|
-
StructuredOutputMode.json_schema
|
|
164
|
-
)
|
|
165
168
|
options = await adapter.response_format_options()
|
|
166
169
|
assert options == {
|
|
167
170
|
"response_format": {
|
|
@@ -349,6 +352,32 @@ def test_litellm_model_id_unknown_provider(config, mock_task):
|
|
|
349
352
|
adapter.litellm_model_id()
|
|
350
353
|
|
|
351
354
|
|
|
355
|
+
@pytest.mark.asyncio
|
|
356
|
+
async def test_build_completion_kwargs_custom_temperature_top_p(config, mock_task):
|
|
357
|
+
"""Test build_completion_kwargs with custom temperature and top_p values"""
|
|
358
|
+
# Create config with custom temperature and top_p
|
|
359
|
+
config.run_config_properties.temperature = 0.7
|
|
360
|
+
config.run_config_properties.top_p = 0.9
|
|
361
|
+
|
|
362
|
+
adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
|
|
363
|
+
mock_provider = Mock()
|
|
364
|
+
messages = [{"role": "user", "content": "Hello"}]
|
|
365
|
+
|
|
366
|
+
with (
|
|
367
|
+
patch.object(adapter, "model_provider", return_value=mock_provider),
|
|
368
|
+
patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
|
|
369
|
+
patch.object(adapter, "build_extra_body", return_value={}),
|
|
370
|
+
patch.object(adapter, "response_format_options", return_value={}),
|
|
371
|
+
):
|
|
372
|
+
kwargs = await adapter.build_completion_kwargs(mock_provider, messages, None)
|
|
373
|
+
|
|
374
|
+
# Verify custom temperature and top_p are passed through
|
|
375
|
+
assert kwargs["temperature"] == 0.7
|
|
376
|
+
assert kwargs["top_p"] == 0.9
|
|
377
|
+
# Verify drop_params is set correctly
|
|
378
|
+
assert kwargs["drop_params"] is True
|
|
379
|
+
|
|
380
|
+
|
|
352
381
|
@pytest.mark.asyncio
|
|
353
382
|
@pytest.mark.parametrize(
|
|
354
383
|
"top_logprobs,response_format,extra_body",
|
|
@@ -390,6 +419,13 @@ async def test_build_completion_kwargs(
|
|
|
390
419
|
assert kwargs["messages"] == messages
|
|
391
420
|
assert kwargs["api_base"] == config.base_url
|
|
392
421
|
|
|
422
|
+
# Verify temperature and top_p are included with default values
|
|
423
|
+
assert kwargs["temperature"] == 1.0 # Default from RunConfigProperties
|
|
424
|
+
assert kwargs["top_p"] == 1.0 # Default from RunConfigProperties
|
|
425
|
+
|
|
426
|
+
# Verify drop_params is set correctly
|
|
427
|
+
assert kwargs["drop_params"] is True
|
|
428
|
+
|
|
393
429
|
# Verify optional parameters
|
|
394
430
|
if top_logprobs is not None:
|
|
395
431
|
assert kwargs["logprobs"] is True
|
|
@@ -405,3 +441,66 @@ async def test_build_completion_kwargs(
|
|
|
405
441
|
# Verify extra body is included
|
|
406
442
|
for key, value in extra_body.items():
|
|
407
443
|
assert kwargs[key] == value
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@pytest.mark.parametrize(
|
|
447
|
+
"litellm_usage,cost,expected_usage",
|
|
448
|
+
[
|
|
449
|
+
# No usage data
|
|
450
|
+
(None, None, None),
|
|
451
|
+
# Only cost
|
|
452
|
+
(None, 0.5, Usage(cost=0.5)),
|
|
453
|
+
# Only token counts
|
|
454
|
+
(
|
|
455
|
+
litellm.types.utils.Usage(
|
|
456
|
+
prompt_tokens=10,
|
|
457
|
+
completion_tokens=20,
|
|
458
|
+
total_tokens=30,
|
|
459
|
+
),
|
|
460
|
+
None,
|
|
461
|
+
Usage(input_tokens=10, output_tokens=20, total_tokens=30),
|
|
462
|
+
),
|
|
463
|
+
# Both cost and token counts
|
|
464
|
+
(
|
|
465
|
+
litellm.types.utils.Usage(
|
|
466
|
+
prompt_tokens=10,
|
|
467
|
+
completion_tokens=20,
|
|
468
|
+
total_tokens=30,
|
|
469
|
+
),
|
|
470
|
+
0.5,
|
|
471
|
+
Usage(input_tokens=10, output_tokens=20, total_tokens=30, cost=0.5),
|
|
472
|
+
),
|
|
473
|
+
# Invalid usage type (should be ignored)
|
|
474
|
+
({"prompt_tokens": 10}, None, None),
|
|
475
|
+
# Invalid cost type (should be ignored)
|
|
476
|
+
(None, "0.5", None),
|
|
477
|
+
],
|
|
478
|
+
)
|
|
479
|
+
def test_usage_from_response(config, mock_task, litellm_usage, cost, expected_usage):
|
|
480
|
+
"""Test usage_from_response with various combinations of usage data and cost"""
|
|
481
|
+
adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
|
|
482
|
+
|
|
483
|
+
# Create a mock response
|
|
484
|
+
response = Mock(spec=litellm.types.utils.ModelResponse)
|
|
485
|
+
response.get.return_value = litellm_usage
|
|
486
|
+
response._hidden_params = {"response_cost": cost}
|
|
487
|
+
|
|
488
|
+
# Call the method
|
|
489
|
+
result = adapter.usage_from_response(response)
|
|
490
|
+
|
|
491
|
+
# Verify the result
|
|
492
|
+
if expected_usage is None:
|
|
493
|
+
if result is not None:
|
|
494
|
+
assert result.input_tokens is None
|
|
495
|
+
assert result.output_tokens is None
|
|
496
|
+
assert result.total_tokens is None
|
|
497
|
+
assert result.cost is None
|
|
498
|
+
else:
|
|
499
|
+
assert result is not None
|
|
500
|
+
assert result.input_tokens == expected_usage.input_tokens
|
|
501
|
+
assert result.output_tokens == expected_usage.output_tokens
|
|
502
|
+
assert result.total_tokens == expected_usage.total_tokens
|
|
503
|
+
assert result.cost == expected_usage.cost
|
|
504
|
+
|
|
505
|
+
# Verify the response was queried correctly
|
|
506
|
+
response.get.assert_called_once_with("usage", None)
|