kiln-ai 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- kiln_ai/adapters/eval/base_eval.py +7 -2
- kiln_ai/adapters/eval/eval_runner.py +5 -64
- kiln_ai/adapters/eval/g_eval.py +3 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +6 -3
- kiln_ai/adapters/fine_tune/dataset_formatter.py +128 -38
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +2 -1
- kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +267 -10
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
- kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
- kiln_ai/adapters/ml_model_list.py +817 -62
- kiln_ai/adapters/model_adapters/base_adapter.py +33 -10
- kiln_ai/adapters/model_adapters/litellm_adapter.py +51 -12
- kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -1
- kiln_ai/adapters/provider_tools.py +25 -1
- kiln_ai/adapters/repair/test_repair_task.py +3 -2
- kiln_ai/adapters/test_prompt_builders.py +24 -3
- kiln_ai/adapters/test_provider_tools.py +86 -1
- kiln_ai/datamodel/__init__.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +14 -0
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +1 -0
- kiln_ai/datamodel/json_schema.py +24 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task_output.py +10 -6
- kiln_ai/datamodel/task_run.py +68 -12
- kiln_ai/datamodel/test_basemodel.py +3 -7
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -0
- kiln_ai/datamodel/test_example_models.py +158 -3
- kiln_ai/datamodel/test_json_schema.py +22 -3
- kiln_ai/datamodel/test_model_perf.py +3 -2
- kiln_ai/datamodel/test_models.py +50 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/dataset_import.py +80 -18
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_dataset_import.py +242 -10
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +3 -2
- kiln_ai-0.16.0.dist-info/RECORD +108 -0
- kiln_ai/adapters/test_generate_docs.py +0 -69
- kiln_ai-0.14.0.dist-info/RECORD +0 -103
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/model_adapters/base_adapter.py

@@ -3,9 +3,12 @@ from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from typing import Dict, Literal, Tuple
 
+import jsonschema
+
 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
 from kiln_ai.adapters.parsers.json_parser import parse_json_string
 from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
+from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
 from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
 from kiln_ai.adapters.run_output import RunOutput
@@ -15,8 +18,9 @@ from kiln_ai.datamodel import (
     Task,
     TaskOutput,
     TaskRun,
+    Usage,
 )
-from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config
 
@@ -103,16 +107,26 @@ class BaseAdapter(metaclass=ABCMeta):
         if self.input_schema is not None:
             if not isinstance(input, dict):
                 raise ValueError(f"structured input is not a dict: {input}")
-            validate_schema(input, self.input_schema)
+
+            validate_schema_with_value_error(
+                input,
+                self.input_schema,
+                "This task requires a specific input schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.",
+            )
+
+        # Format model input for model call (we save the original input in the task without formatting)
+        formatted_input = input
+        formatter_id = self.model_provider().formatter
+        if formatter_id is not None:
+            formatter = request_formatter_from_id(formatter_id)
+            formatted_input = formatter.format_input(input)
 
         # Run
-        run_output = await self._run(input)
+        run_output, usage = await self._run(formatted_input)
 
         # Parse
         provider = self.model_provider()
-        parser = model_parser_from_id(provider.parser)(
-            structured_output=self.has_structured_output()
-        )
+        parser = model_parser_from_id(provider.parser)
         parsed_output = parser.parse_output(original_output=run_output)
 
         # validate output
@@ -125,7 +139,11 @@ class BaseAdapter(metaclass=ABCMeta):
                 raise RuntimeError(
                     f"structured response is not a dict: {parsed_output.output}"
                 )
-            validate_schema(parsed_output.output, self.output_schema)
+            validate_schema_with_value_error(
+                parsed_output.output,
+                self.output_schema,
+                "This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.",
+            )
         else:
             if not isinstance(parsed_output.output, str):
                 raise RuntimeError(
@@ -142,7 +160,7 @@ class BaseAdapter(metaclass=ABCMeta):
         )
 
         # Generate the run and output
-        run = self.generate_run(input, input_source, parsed_output)
+        run = self.generate_run(input, input_source, parsed_output, usage)
 
         # Save the run if configured to do so, and we have a path to save to
         if (
@@ -165,7 +183,7 @@ class BaseAdapter(metaclass=ABCMeta):
             pass
 
     @abstractmethod
-    async def _run(self, input: Dict | str) -> RunOutput:
+    async def _run(self, input: Dict | str) -> Tuple[RunOutput, Usage | None]:
         pass
 
     def build_prompt(self) -> str:
@@ -204,7 +222,11 @@ class BaseAdapter(metaclass=ABCMeta):
 
     # create a run and task output
     def generate_run(
-        self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
+        self,
+        input: Dict | str,
+        input_source: DataSource | None,
+        run_output: RunOutput,
+        usage: Usage | None = None,
     ) -> TaskRun:
         # Convert input and output to JSON strings if they are dictionaries
         input_str = (
@@ -237,6 +259,7 @@ class BaseAdapter(metaclass=ABCMeta):
             ),
             intermediate_outputs=run_output.intermediate_outputs,
             tags=self.base_adapter_config.default_tags or [],
+            usage=usage,
         )
 
         return new_task_run
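
The base_adapter.py hunks above change the adapter contract in three ways: structured input and output now go through validate_schema_with_value_error, a provider-declared formatter can rewrite the input before the model call (the unformatted input is what gets saved with the run), and _run now returns a (RunOutput, Usage | None) tuple that ends up on the saved TaskRun. A minimal sketch of a concrete adapter under the new contract (illustrative only; the echo behavior and the zero usage numbers are invented for the example, and constructor wiring is omitted):

```python
from typing import Dict, Tuple

from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput, Usage


class EchoAdapter(BaseAdapter):
    """Toy adapter for illustration: echoes the (possibly reformatted) input back."""

    async def _run(self, input: Dict | str) -> Tuple[RunOutput, Usage | None]:
        # By the time _run is called, the base class has already validated the
        # input schema and applied any provider request formatter.
        output = RunOutput(output=str(input), intermediate_outputs=None)
        usage = Usage(input_tokens=0, output_tokens=0, total_tokens=0, cost=0.0)
        return output, usage

    def adapter_name(self) -> str:
        return "echo_adapter"
```

Adapters that have no usage information can return (run_output, None), which is exactly what the mock adapters in the test hunks below do.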

kiln_ai/adapters/model_adapters/litellm_adapter.py

@@ -1,7 +1,9 @@
+import logging
 from typing import Any, Dict
 
 import litellm
 from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
+from litellm.types.utils import Usage as LiteLlmUsage
 
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.ml_model_list import (
@@ -14,14 +16,15 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
     AdapterConfig,
     BaseAdapter,
     RunOutput,
+    Usage,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.datamodel import PromptGenerators, PromptId
 from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
+logger = logging.getLogger(__name__)
+
 
 class LiteLlmAdapter(BaseAdapter):
     def __init__(
@@ -49,7 +52,7 @@ class LiteLlmAdapter(BaseAdapter):
             config=base_adapter_config,
         )
 
-    async def _run(self, input: Dict | str) -> RunOutput:
+    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
         provider = self.model_provider()
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")
@@ -65,6 +68,7 @@ class LiteLlmAdapter(BaseAdapter):
         run_strategy, cot_prompt = self.run_strategy()
 
         if run_strategy == "cot_as_message":
+            # Used for reasoning-capable models that can output thinking and structured format
             if not cot_prompt:
                 raise ValueError("cot_prompt is required for cot_as_message strategy")
             messages.append({"role": "system", "content": cot_prompt})
@@ -73,9 +77,11 @@ class LiteLlmAdapter(BaseAdapter):
                 raise ValueError("cot_prompt is required for cot_two_call strategy")
             messages.append({"role": "system", "content": cot_prompt})
 
-            # First call for chain of thought
+            # First call for chain of thought
+            # No response format as this request is for "thinking" in plain text
+            # No logprobs as only needed for final answer
             completion_kwargs = await self.build_completion_kwargs(
-                provider, messages, None
+                provider, messages, None, skip_response_format=True
             )
             cot_response = await litellm.acompletion(**completion_kwargs)
             if (
@@ -136,8 +142,12 @@ class LiteLlmAdapter(BaseAdapter):
             raise RuntimeError("Logprobs were required, but no logprobs were returned.")
 
         # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
-        if hasattr(message, "reasoning_content") and message.reasoning_content:
-            intermediate_outputs["reasoning"] = message.reasoning_content
+        if (
+            hasattr(message, "reasoning_content")
+            and message.reasoning_content
+            and len(message.reasoning_content.strip()) > 0
+        ):
+            intermediate_outputs["reasoning"] = message.reasoning_content.strip()
 
         # the string content of the response
         response_content = message.content
@@ -166,7 +176,7 @@ class LiteLlmAdapter(BaseAdapter):
             output=response_content,
             intermediate_outputs=intermediate_outputs,
             output_logprobs=logprobs,
-        )
+        ), self.usage_from_response(response)
 
     def adapter_name(self) -> str:
         return "kiln_openai_compatible_adapter"
@@ -367,6 +377,7 @@ class LiteLlmAdapter(BaseAdapter):
         provider: KilnModelProvider,
         messages: list[dict[str, Any]],
         top_logprobs: int | None,
+        skip_response_format: bool = False,
     ) -> dict[str, Any]:
         extra_body = self.build_extra_body(provider)
 
@@ -380,12 +391,40 @@ class LiteLlmAdapter(BaseAdapter):
             **self._additional_body_options,
         }
 
-        # Response format: json_schema, json_instructions, json_mode, function_calling, etc
-        response_format_options = await self.response_format_options()
-        completion_kwargs.update(response_format_options)
+        if not skip_response_format:
+            # Response format: json_schema, json_instructions, json_mode, function_calling, etc
+            response_format_options = await self.response_format_options()
+            completion_kwargs.update(response_format_options)
 
         if top_logprobs is not None:
             completion_kwargs["logprobs"] = True
             completion_kwargs["top_logprobs"] = top_logprobs
 
         return completion_kwargs
+
+    def usage_from_response(self, response: ModelResponse) -> Usage | None:
+        litellm_usage = response.get("usage", None)
+        cost = response._hidden_params.get("response_cost", None)
+        if not litellm_usage and not cost:
+            return None
+
+        usage = Usage()
+
+        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
+            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
+            usage.output_tokens = litellm_usage.get("completion_tokens", None)
+            usage.total_tokens = litellm_usage.get("total_tokens", None)
+        else:
+            logger.warning(
+                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
+            )
+
+        if isinstance(cost, float):
+            usage.cost = cost
+        elif cost is not None:
+            # None is allowed, but no other types are expected
+            logger.warning(
+                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
+            )
+
+        return usage
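
Because _run now surfaces a Usage object and generate_run stores it on the saved run (the usage=usage line in the base_adapter.py hunk), token counts and cost travel with each TaskRun. A small hypothetical helper, not part of kiln_ai, showing how a caller might fold those persisted fields, assuming TaskRun exposes the usage field set above:

```python
from typing import Iterable, Tuple

from kiln_ai.datamodel import TaskRun


def total_cost_and_tokens(runs: Iterable[TaskRun]) -> Tuple[float, int]:
    """Hypothetical aggregation over saved runs using the new usage field."""
    cost, tokens = 0.0, 0
    for run in runs:
        if run.usage is None:
            continue
        # Both fields are optional; treat missing values as zero.
        cost += run.usage.cost or 0.0
        tokens += run.usage.total_tokens or 0
    return cost, tokens
```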

kiln_ai/adapters/model_adapters/test_base_adapter.py

@@ -3,7 +3,8 @@ from unittest.mock import MagicMock, patch
 import pytest
 
 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
-from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
+from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
 from kiln_ai.datamodel import Task
 from kiln_ai.datamodel.task import RunConfig
 
@@ -12,7 +13,7 @@ class MockAdapter(BaseAdapter):
     """Concrete implementation of BaseAdapter for testing"""
 
     async def _run(self, input):
-        return None
+        return None, None
 
     def adapter_name(self) -> str:
         return "test"
@@ -42,6 +43,22 @@ def adapter(base_task):
     )
 
 
+@pytest.fixture
+def mock_formatter():
+    formatter = MagicMock()
+    formatter.format_input.return_value = {"formatted": "input"}
+    return formatter
+
+
+@pytest.fixture
+def mock_parser():
+    parser = MagicMock()
+    parser.parse_output.return_value = RunOutput(
+        output="test output", intermediate_outputs={}
+    )
+    return parser
+
+
 async def test_model_provider_uses_cache(adapter, mock_provider):
     """Test that cached provider is returned if it exists"""
     # Set up cached provider
@@ -197,3 +214,58 @@ async def test_run_strategy(
     # Test
     result = adapter.run_strategy()
     assert result == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "formatter_id,expected_input,expected_calls",
+    [
+        (None, {"original": "input"}, 0),  # No formatter
+        ("test_formatter", {"formatted": "input"}, 1),  # With formatter
+    ],
+)
+async def test_input_formatting(
+    adapter, mock_formatter, mock_parser, formatter_id, expected_input, expected_calls
+):
+    """Test that input formatting is handled correctly based on formatter configuration"""
+    # Mock the model provider to return our formatter ID and parser
+    provider = MagicMock()
+    provider.formatter = formatter_id
+    provider.parser = "test_parser"
+    provider.reasoning_capable = False
+    adapter.model_provider = MagicMock(return_value=provider)
+
+    # Mock the formatter factory and parser factory
+    with (
+        patch(
+            "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id"
+        ) as mock_factory,
+        patch(
+            "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id"
+        ) as mock_parser_factory,
+    ):
+        mock_factory.return_value = mock_formatter
+        mock_parser_factory.return_value = mock_parser
+
+        # Mock the _run method to capture the input
+        captured_input = None
+
+        async def mock_run(input):
+            nonlocal captured_input
+            captured_input = input
+            return RunOutput(output="test output", intermediate_outputs={}), None
+
+        adapter._run = mock_run
+
+        # Run the adapter
+        original_input = {"original": "input"}
+        await adapter.invoke_returning_run_output(original_input)
+
+        # Verify formatter was called correctly
+        assert captured_input == expected_input
+        assert mock_factory.call_count == (1 if formatter_id else 0)
+        assert mock_formatter.format_input.call_count == expected_calls
+
+        # Verify original input was preserved in the run
+        if formatter_id:
+            mock_formatter.format_input.assert_called_once_with(original_input)
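
The new fixtures above rely on MagicMock doubles; anything exposing the same format_input(original_input) method satisfies the formatter contract BaseAdapter expects. A hypothetical hand-rolled stub for readers who prefer explicit fakes (the recorded calls list and return values are invented for the example):

```python
from typing import Dict, List, Union


class StubFormatter:
    """Hypothetical test double matching the formatter's format_input contract."""

    def __init__(self) -> None:
        self.calls: List[Union[Dict, str]] = []

    def format_input(self, original_input: Union[Dict, str]) -> Union[Dict, str]:
        self.calls.append(original_input)
        if isinstance(original_input, str):
            return original_input + "\n\n/no_think"
        # Mirror the mock_formatter fixture above, which returns a fixed dict.
        return {"formatted": "input"}
```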

kiln_ai/adapters/model_adapters/test_litellm_adapter.py

@@ -1,6 +1,7 @@
 import json
 from unittest.mock import Mock, patch
 
+import litellm
 import pytest
 
 from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
@@ -9,7 +10,7 @@ from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
 from kiln_ai.adapters.model_adapters.litellm_config import (
     LiteLlmConfig,
 )
-from kiln_ai.datamodel import Project, Task
+from kiln_ai.datamodel import Project, Task, Usage
 
 
 @pytest.fixture
@@ -405,3 +406,66 @@ async def test_build_completion_kwargs(
     # Verify extra body is included
     for key, value in extra_body.items():
         assert kwargs[key] == value
+
+
+@pytest.mark.parametrize(
+    "litellm_usage,cost,expected_usage",
+    [
+        # No usage data
+        (None, None, None),
+        # Only cost
+        (None, 0.5, Usage(cost=0.5)),
+        # Only token counts
+        (
+            litellm.types.utils.Usage(
+                prompt_tokens=10,
+                completion_tokens=20,
+                total_tokens=30,
+            ),
+            None,
+            Usage(input_tokens=10, output_tokens=20, total_tokens=30),
+        ),
+        # Both cost and token counts
+        (
+            litellm.types.utils.Usage(
+                prompt_tokens=10,
+                completion_tokens=20,
+                total_tokens=30,
+            ),
+            0.5,
+            Usage(input_tokens=10, output_tokens=20, total_tokens=30, cost=0.5),
+        ),
+        # Invalid usage type (should be ignored)
+        ({"prompt_tokens": 10}, None, None),
+        # Invalid cost type (should be ignored)
+        (None, "0.5", None),
+    ],
+)
+def test_usage_from_response(config, mock_task, litellm_usage, cost, expected_usage):
+    """Test usage_from_response with various combinations of usage data and cost"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Create a mock response
+    response = Mock(spec=litellm.types.utils.ModelResponse)
+    response.get.return_value = litellm_usage
+    response._hidden_params = {"response_cost": cost}
+
+    # Call the method
+    result = adapter.usage_from_response(response)
+
+    # Verify the result
+    if expected_usage is None:
+        if result is not None:
+            assert result.input_tokens is None
+            assert result.output_tokens is None
+            assert result.total_tokens is None
+            assert result.cost is None
+    else:
+        assert result is not None
+        assert result.input_tokens == expected_usage.input_tokens
+        assert result.output_tokens == expected_usage.output_tokens
+        assert result.total_tokens == expected_usage.total_tokens
+        assert result.cost == expected_usage.cost
+
+    # Verify the response was queried correctly
+    response.get.assert_called_once_with("usage", None)
@@ -11,14 +11,15 @@ from kiln_ai.datamodel import (
|
|
|
11
11
|
DataSourceType,
|
|
12
12
|
Project,
|
|
13
13
|
Task,
|
|
14
|
+
Usage,
|
|
14
15
|
)
|
|
15
16
|
from kiln_ai.datamodel.task import RunConfig
|
|
16
17
|
from kiln_ai.utils.config import Config
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class MockAdapter(BaseAdapter):
|
|
20
|
-
async def _run(self, input: dict | str) ->
|
|
21
|
-
return RunOutput(output="Test output", intermediate_outputs=None)
|
|
21
|
+
async def _run(self, input: dict | str) -> tuple[RunOutput, Usage | None]:
|
|
22
|
+
return RunOutput(output="Test output", intermediate_outputs=None), None
|
|
22
23
|
|
|
23
24
|
def adapter_name(self) -> str:
|
|
24
25
|
return "mock_adapter"
|
|

kiln_ai/adapters/model_adapters/test_structured_output.py

@@ -12,6 +12,7 @@ from kiln_ai.adapters.ml_model_list import (
 from kiln_ai.adapters.model_adapters.base_adapter import (
     BaseAdapter,
     RunOutput,
+    Usage,
 )
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
@@ -54,8 +55,8 @@ class MockAdapter(BaseAdapter):
         )
         self.response = response
 
-    async def _run(self, input: str) -> RunOutput:
-        return RunOutput(output=self.response, intermediate_outputs=None)
+    async def _run(self, input: str) -> tuple[RunOutput, Usage | None]:
+        return RunOutput(output=self.response, intermediate_outputs=None), None
 
     def adapter_name(self) -> str:
         return "mock_adapter"
@@ -223,10 +224,7 @@ async def run_structured_input_task(
     with pytest.raises(ValueError):
         # not structured input in dictionary
         await a.invoke("a=1, b=2, c=3")
-    with pytest.raises(
-        ValueError,
-        match="This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema.",
-    ):
+    with pytest.raises(ValueError, match="This task requires a specific input"):
         # invalid structured input
         await a.invoke({"a": 1, "b": 2, "d": 3})
 

kiln_ai/adapters/parsers/base_parser.py

@@ -2,9 +2,6 @@ from kiln_ai.adapters.run_output import RunOutput
 
 
 class BaseParser:
-    def __init__(self, structured_output: bool = False):
-        self.structured_output = structured_output
-
     def parse_output(self, original_output: RunOutput) -> RunOutput:
         """
         Method for parsing the output of a model. Typically overridden by subclasses.

kiln_ai/adapters/parsers/parser_registry.py

@@ -6,14 +6,16 @@ from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
-def model_parser_from_id(parser_id: ModelParserID | None) ->
+def model_parser_from_id(parser_id: ModelParserID | None) -> BaseParser:
     """
     Get a model parser from its ID.
     """
     match parser_id:
         case None:
-            return BaseParser
+            return BaseParser()
         case ModelParserID.r1_thinking:
-            return R1ThinkingParser
+            return R1ThinkingParser()
+        case ModelParserID.optional_r1_thinking:
+            return R1ThinkingParser(allow_missing_thinking=True)
         case _:
             raise_exhaustive_enum_error(parser_id)
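
The registry now hands back parser instances rather than classes, and the new optional_r1_thinking ID configures the R1 parser to tolerate responses without thinking tags. A usage sketch (the ModelParserID import path is assumed, since it is not shown in this hunk; the example output string is invented):

```python
from kiln_ai.adapters.ml_model_list import ModelParserID  # assumed location
from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
from kiln_ai.adapters.run_output import RunOutput

# Callers no longer instantiate the returned class themselves.
parser = model_parser_from_id(ModelParserID.r1_thinking)
parsed = parser.parse_output(
    original_output=RunOutput(
        output="<think>add the two numbers</think>42", intermediate_outputs=None
    )
)
# parsed.output == "42"; parsed.intermediate_outputs["reasoning"] == "add the two numbers"

# The new ID returns the same parser preconfigured with allow_missing_thinking=True.
lenient = model_parser_from_id(ModelParserID.optional_r1_thinking)
```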

kiln_ai/adapters/parsers/r1_parser.py

@@ -7,6 +7,9 @@ class R1ThinkingParser(BaseParser):
     START_TAG = "<think>"
     END_TAG = "</think>"
 
+    def __init__(self, allow_missing_thinking: bool = False):
+        self.allow_missing_thinking = allow_missing_thinking
+
     def parse_output(self, original_output: RunOutput) -> RunOutput:
         """
         Parse the <think> </think> tags from the response into the intermediate and final outputs.
@@ -27,6 +30,14 @@ class R1ThinkingParser(BaseParser):
             original_output.intermediate_outputs is not None
             and "reasoning" in original_output.intermediate_outputs
         ):
+            # sometimes the output and reasoning are wrapped in newlines
+            if isinstance(original_output.output, str):
+                original_output.output = original_output.output.strip()
+
+            original_output.intermediate_outputs["reasoning"] = (
+                original_output.intermediate_outputs["reasoning"].strip()
+            )
+
             return original_output
 
         # This parser only works for strings
@@ -39,7 +50,10 @@ class R1ThinkingParser(BaseParser):
         # Find the thinking tags
         think_end = cleaned_response.find(self.END_TAG)
         if think_end == -1:
-            raise ValueError("Missing </think> tag")
+            if self.allow_missing_thinking:
+                return original_output
+            else:
+                raise ValueError("Missing </think> tag")
 
         think_tag_start = cleaned_response.find(self.START_TAG)
         if think_tag_start == -1:
@@ -66,7 +80,8 @@ class R1ThinkingParser(BaseParser):
 
         # Add thinking content to intermediate outputs if it exists
         intermediate_outputs = original_output.intermediate_outputs or {}
-        intermediate_outputs["reasoning"] = thinking_content
+        if thinking_content is not None and len(thinking_content) > 0:
+            intermediate_outputs["reasoning"] = thinking_content
 
         return RunOutput(
             output=result,
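
Two behaviors change here: the parser can be constructed with allow_missing_thinking=True so that responses without a </think> tag pass through instead of raising, and reasoning that arrives pre-parsed in intermediate_outputs has its wrapping newlines stripped. A short sketch of both paths (example strings are invented; the stripping case mirrors the test_strip_newlines test below):

```python
from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
from kiln_ai.adapters.run_output import RunOutput

strict = R1ThinkingParser()
lenient = R1ThinkingParser(allow_missing_thinking=True)

# No thinking tags at all: the strict parser raises, the lenient one passes through.
plain = RunOutput(output="just an answer", intermediate_outputs=None)
lenient.parse_output(plain)  # returned unchanged
# strict.parse_output(plain) would raise ValueError("Missing </think> tag")

# Reasoning already extracted upstream (e.g. by LiteLLM): both parsers now strip
# stray newlines from the output and the reasoning before returning.
wrapped = RunOutput(
    output="\n\nSome content",
    intermediate_outputs={"reasoning": "\n\nSome thinking\n\n"},
)
parsed = strict.parse_output(wrapped)
# parsed.output == "Some content"
# parsed.intermediate_outputs["reasoning"] == "Some thinking"
```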

kiln_ai/adapters/parsers/request_formatters.py (new file)

@@ -0,0 +1,40 @@
+import json
+from typing import Dict, Protocol
+
+from kiln_ai.adapters.ml_model_list import ModelFormatterID
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+class RequestFormatter(Protocol):
+    def format_input(self, original_input: Dict | str) -> Dict | str:
+        """
+        Method for formatting the input to a model.
+        """
+        ...
+
+
+class Qwen3StyleNoThinkFormatter:
+    def format_input(self, original_input: Dict | str) -> Dict | str:
+        """
+        Format the input to a model for Qwen3 /no_think instruction
+        """
+        formatted_input = (
+            original_input
+            if isinstance(original_input, str)
+            else json.dumps(original_input, indent=2)
+        )
+
+        return formatted_input + "\n\n/no_think"
+
+
+def request_formatter_from_id(
+    formatter_id: ModelFormatterID,
+) -> RequestFormatter:
+    """
+    Get a model parser from its ID.
+    """
+    match formatter_id:
+        case ModelFormatterID.qwen3_style_no_think:
+            return Qwen3StyleNoThinkFormatter()
+        case _:
+            raise_exhaustive_enum_error(formatter_id)
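
A quick usage sketch of the new formatter module shown above: string inputs get the Qwen3 /no_think marker appended directly, while structured inputs are serialized to indented JSON first so the marker remains the final instruction. The example inputs are invented:

```python
from kiln_ai.adapters.parsers.request_formatters import Qwen3StyleNoThinkFormatter

formatter = Qwen3StyleNoThinkFormatter()

formatter.format_input("Summarize this document")
# -> "Summarize this document\n\n/no_think"

formatter.format_input({"topic": "weather", "city": "Paris"})
# -> '{\n  "topic": "weather",\n  "city": "Paris"\n}\n\n/no_think'
```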

kiln_ai/adapters/parsers/test_parser_registry.py

@@ -28,5 +28,5 @@ def test_model_parser_from_id_invalid():
 )
 def test_model_parser_from_id_parametrized(parser_id, expected_class):
     """Test all valid parser IDs using parametrize."""
-
-    assert
+    parser = model_parser_from_id(parser_id)
+    assert isinstance(parser, expected_class)

kiln_ai/adapters/parsers/test_r1_parser.py

@@ -46,6 +46,21 @@ def test_response_with_whitespace(parser):
     assert parsed.output.strip() == "This is the result"
 
 
+def test_empty_thinking_content(parser):
+    response = RunOutput(
+        output="""
+<think>
+
+</think>
+This is the result
+""",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert "reasoning" not in parsed.intermediate_outputs
+    assert parsed.output.strip() == "This is the result"
+
+
 def test_missing_start_tag(parser):
     parsed = parser.parse_output(
         RunOutput(output="Some content</think>result", intermediate_outputs=None)
@@ -86,7 +101,7 @@ def test_empty_thinking_content(parser):
         output="<think></think>This is the result", intermediate_outputs=None
     )
     parsed = parser.parse_output(response)
-    assert
+    assert "reasoning" not in parsed.intermediate_outputs
     assert parsed.output == "This is the result"
 
 
@@ -154,3 +169,31 @@ def test_intermediate_outputs(parser):
         )
     )
     assert out.intermediate_outputs["reasoning"] == "Some content"
+
+
+def test_strip_newlines(parser):
+    # certain providers via LiteLLM for example, add newlines to the output
+    # and to the reasoning. This tests that we strip those newlines.
+    response = RunOutput(
+        output="\n\nSome content",
+        intermediate_outputs={
+            "reasoning": "\n\nSome thinking\n\n",
+        },
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.output == "Some content"
+    assert parsed.intermediate_outputs["reasoning"] == "Some thinking"
+
+
+def test_strip_newlines_with_structured_output(parser):
+    # certain providers via LiteLLM for example, add newlines to the output
+    # and to the reasoning. This tests that we strip those newlines.
+    response = RunOutput(
+        output={"some_key": "Some content"},
+        intermediate_outputs={
+            "reasoning": "\n\nSome thinking\n\n",
+        },
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.output == {"some_key": "Some content"}
+    assert parsed.intermediate_outputs["reasoning"] == "Some thinking"