kiln-ai 0.15.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/eval/eval_runner.py +5 -64
- kiln_ai/adapters/eval/g_eval.py +3 -3
- kiln_ai/adapters/fine_tune/dataset_formatter.py +124 -34
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +264 -7
- kiln_ai/adapters/ml_model_list.py +478 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +26 -8
- kiln_ai/adapters/model_adapters/litellm_adapter.py +41 -7
- kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -1
- kiln_ai/adapters/provider_tools.py +18 -1
- kiln_ai/adapters/repair/test_repair_task.py +3 -2
- kiln_ai/adapters/test_prompt_builders.py +24 -3
- kiln_ai/adapters/test_provider_tools.py +70 -1
- kiln_ai/datamodel/__init__.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +14 -0
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +1 -0
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task_output.py +1 -1
- kiln_ai/datamodel/task_run.py +39 -7
- kiln_ai/datamodel/test_basemodel.py +3 -7
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -0
- kiln_ai/datamodel/test_example_models.py +54 -0
- kiln_ai/datamodel/test_models.py +50 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/dataset_import.py +80 -18
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_dataset_import.py +242 -10
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +1 -1
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/RECORD +45 -41
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -1,7 +1,9 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from typing import Any, Dict
|
|
2
3
|
|
|
3
4
|
import litellm
|
|
4
5
|
from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
|
|
6
|
+
from litellm.types.utils import Usage as LiteLlmUsage
|
|
5
7
|
|
|
6
8
|
import kiln_ai.datamodel as datamodel
|
|
7
9
|
from kiln_ai.adapters.ml_model_list import (
|
|
@@ -14,14 +16,15 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
|
|
|
14
16
|
AdapterConfig,
|
|
15
17
|
BaseAdapter,
|
|
16
18
|
RunOutput,
|
|
19
|
+
Usage,
|
|
17
20
|
)
|
|
18
|
-
from kiln_ai.adapters.model_adapters.litellm_config import
|
|
19
|
-
LiteLlmConfig,
|
|
20
|
-
)
|
|
21
|
+
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
|
|
21
22
|
from kiln_ai.datamodel import PromptGenerators, PromptId
|
|
22
23
|
from kiln_ai.datamodel.task import RunConfig
|
|
23
24
|
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
24
25
|
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
25
28
|
|
|
26
29
|
class LiteLlmAdapter(BaseAdapter):
|
|
27
30
|
def __init__(
|
|
@@ -49,7 +52,7 @@ class LiteLlmAdapter(BaseAdapter):
|
|
|
49
52
|
config=base_adapter_config,
|
|
50
53
|
)
|
|
51
54
|
|
|
52
|
-
async def _run(self, input: Dict | str) -> RunOutput:
|
|
55
|
+
async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
|
|
53
56
|
provider = self.model_provider()
|
|
54
57
|
if not provider.model_id:
|
|
55
58
|
raise ValueError("Model ID is required for OpenAI compatible models")
|
|
@@ -139,8 +142,12 @@ class LiteLlmAdapter(BaseAdapter):
|
|
|
139
142
|
raise RuntimeError("Logprobs were required, but no logprobs were returned.")
|
|
140
143
|
|
|
141
144
|
# Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
|
|
142
|
-
if
|
|
143
|
-
|
|
145
|
+
if (
|
|
146
|
+
hasattr(message, "reasoning_content")
|
|
147
|
+
and message.reasoning_content
|
|
148
|
+
and len(message.reasoning_content.strip()) > 0
|
|
149
|
+
):
|
|
150
|
+
intermediate_outputs["reasoning"] = message.reasoning_content.strip()
|
|
144
151
|
|
|
145
152
|
# the string content of the response
|
|
146
153
|
response_content = message.content
|
|
@@ -169,7 +176,7 @@ class LiteLlmAdapter(BaseAdapter):
|
|
|
169
176
|
output=response_content,
|
|
170
177
|
intermediate_outputs=intermediate_outputs,
|
|
171
178
|
output_logprobs=logprobs,
|
|
172
|
-
)
|
|
179
|
+
), self.usage_from_response(response)
|
|
173
180
|
|
|
174
181
|
def adapter_name(self) -> str:
|
|
175
182
|
return "kiln_openai_compatible_adapter"
|
|
@@ -394,3 +401,30 @@ class LiteLlmAdapter(BaseAdapter):
|
|
|
394
401
|
completion_kwargs["top_logprobs"] = top_logprobs
|
|
395
402
|
|
|
396
403
|
return completion_kwargs
|
|
404
|
+
|
|
405
|
+
def usage_from_response(self, response: ModelResponse) -> Usage | None:
|
|
406
|
+
litellm_usage = response.get("usage", None)
|
|
407
|
+
cost = response._hidden_params.get("response_cost", None)
|
|
408
|
+
if not litellm_usage and not cost:
|
|
409
|
+
return None
|
|
410
|
+
|
|
411
|
+
usage = Usage()
|
|
412
|
+
|
|
413
|
+
if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
|
|
414
|
+
usage.input_tokens = litellm_usage.get("prompt_tokens", None)
|
|
415
|
+
usage.output_tokens = litellm_usage.get("completion_tokens", None)
|
|
416
|
+
usage.total_tokens = litellm_usage.get("total_tokens", None)
|
|
417
|
+
else:
|
|
418
|
+
logger.warning(
|
|
419
|
+
f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
if isinstance(cost, float):
|
|
423
|
+
usage.cost = cost
|
|
424
|
+
elif cost is not None:
|
|
425
|
+
# None is allowed, but no other types are expected
|
|
426
|
+
logger.warning(
|
|
427
|
+
f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
|
|
428
|
+
)
|
|
429
|
+
|
|
430
|
+
return usage
|
|
@@ -3,7 +3,8 @@ from unittest.mock import MagicMock, patch
|
|
|
3
3
|
import pytest
|
|
4
4
|
|
|
5
5
|
from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
|
|
6
|
-
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
|
|
6
|
+
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
|
|
7
|
+
from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
|
|
7
8
|
from kiln_ai.datamodel import Task
|
|
8
9
|
from kiln_ai.datamodel.task import RunConfig
|
|
9
10
|
|
|
@@ -12,7 +13,7 @@ class MockAdapter(BaseAdapter):
|
|
|
12
13
|
"""Concrete implementation of BaseAdapter for testing"""
|
|
13
14
|
|
|
14
15
|
async def _run(self, input):
|
|
15
|
-
return None
|
|
16
|
+
return None, None
|
|
16
17
|
|
|
17
18
|
def adapter_name(self) -> str:
|
|
18
19
|
return "test"
|
|
@@ -42,6 +43,22 @@ def adapter(base_task):
|
|
|
42
43
|
)
|
|
43
44
|
|
|
44
45
|
|
|
46
|
+
@pytest.fixture
|
|
47
|
+
def mock_formatter():
|
|
48
|
+
formatter = MagicMock()
|
|
49
|
+
formatter.format_input.return_value = {"formatted": "input"}
|
|
50
|
+
return formatter
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@pytest.fixture
|
|
54
|
+
def mock_parser():
|
|
55
|
+
parser = MagicMock()
|
|
56
|
+
parser.parse_output.return_value = RunOutput(
|
|
57
|
+
output="test output", intermediate_outputs={}
|
|
58
|
+
)
|
|
59
|
+
return parser
|
|
60
|
+
|
|
61
|
+
|
|
45
62
|
async def test_model_provider_uses_cache(adapter, mock_provider):
|
|
46
63
|
"""Test that cached provider is returned if it exists"""
|
|
47
64
|
# Set up cached provider
|
|
@@ -197,3 +214,58 @@ async def test_run_strategy(
|
|
|
197
214
|
# Test
|
|
198
215
|
result = adapter.run_strategy()
|
|
199
216
|
assert result == expected
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@pytest.mark.asyncio
|
|
220
|
+
@pytest.mark.parametrize(
|
|
221
|
+
"formatter_id,expected_input,expected_calls",
|
|
222
|
+
[
|
|
223
|
+
(None, {"original": "input"}, 0), # No formatter
|
|
224
|
+
("test_formatter", {"formatted": "input"}, 1), # With formatter
|
|
225
|
+
],
|
|
226
|
+
)
|
|
227
|
+
async def test_input_formatting(
|
|
228
|
+
adapter, mock_formatter, mock_parser, formatter_id, expected_input, expected_calls
|
|
229
|
+
):
|
|
230
|
+
"""Test that input formatting is handled correctly based on formatter configuration"""
|
|
231
|
+
# Mock the model provider to return our formatter ID and parser
|
|
232
|
+
provider = MagicMock()
|
|
233
|
+
provider.formatter = formatter_id
|
|
234
|
+
provider.parser = "test_parser"
|
|
235
|
+
provider.reasoning_capable = False
|
|
236
|
+
adapter.model_provider = MagicMock(return_value=provider)
|
|
237
|
+
|
|
238
|
+
# Mock the formatter factory and parser factory
|
|
239
|
+
with (
|
|
240
|
+
patch(
|
|
241
|
+
"kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id"
|
|
242
|
+
) as mock_factory,
|
|
243
|
+
patch(
|
|
244
|
+
"kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id"
|
|
245
|
+
) as mock_parser_factory,
|
|
246
|
+
):
|
|
247
|
+
mock_factory.return_value = mock_formatter
|
|
248
|
+
mock_parser_factory.return_value = mock_parser
|
|
249
|
+
|
|
250
|
+
# Mock the _run method to capture the input
|
|
251
|
+
captured_input = None
|
|
252
|
+
|
|
253
|
+
async def mock_run(input):
|
|
254
|
+
nonlocal captured_input
|
|
255
|
+
captured_input = input
|
|
256
|
+
return RunOutput(output="test output", intermediate_outputs={}), None
|
|
257
|
+
|
|
258
|
+
adapter._run = mock_run
|
|
259
|
+
|
|
260
|
+
# Run the adapter
|
|
261
|
+
original_input = {"original": "input"}
|
|
262
|
+
await adapter.invoke_returning_run_output(original_input)
|
|
263
|
+
|
|
264
|
+
# Verify formatter was called correctly
|
|
265
|
+
assert captured_input == expected_input
|
|
266
|
+
assert mock_factory.call_count == (1 if formatter_id else 0)
|
|
267
|
+
assert mock_formatter.format_input.call_count == expected_calls
|
|
268
|
+
|
|
269
|
+
# Verify the formatter received the original, unmodified input
|
|
270
|
+
if formatter_id:
|
|
271
|
+
mock_formatter.format_input.assert_called_once_with(original_input)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from unittest.mock import Mock, patch
|
|
3
3
|
|
|
4
|
+
import litellm
|
|
4
5
|
import pytest
|
|
5
6
|
|
|
6
7
|
from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
|
|
@@ -9,7 +10,7 @@ from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
|
|
|
9
10
|
from kiln_ai.adapters.model_adapters.litellm_config import (
|
|
10
11
|
LiteLlmConfig,
|
|
11
12
|
)
|
|
12
|
-
from kiln_ai.datamodel import Project, Task
|
|
13
|
+
from kiln_ai.datamodel import Project, Task, Usage
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
@pytest.fixture
|
|
@@ -405,3 +406,66 @@ async def test_build_completion_kwargs(
|
|
|
405
406
|
# Verify extra body is included
|
|
406
407
|
for key, value in extra_body.items():
|
|
407
408
|
assert kwargs[key] == value
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
@pytest.mark.parametrize(
|
|
412
|
+
"litellm_usage,cost,expected_usage",
|
|
413
|
+
[
|
|
414
|
+
# No usage data
|
|
415
|
+
(None, None, None),
|
|
416
|
+
# Only cost
|
|
417
|
+
(None, 0.5, Usage(cost=0.5)),
|
|
418
|
+
# Only token counts
|
|
419
|
+
(
|
|
420
|
+
litellm.types.utils.Usage(
|
|
421
|
+
prompt_tokens=10,
|
|
422
|
+
completion_tokens=20,
|
|
423
|
+
total_tokens=30,
|
|
424
|
+
),
|
|
425
|
+
None,
|
|
426
|
+
Usage(input_tokens=10, output_tokens=20, total_tokens=30),
|
|
427
|
+
),
|
|
428
|
+
# Both cost and token counts
|
|
429
|
+
(
|
|
430
|
+
litellm.types.utils.Usage(
|
|
431
|
+
prompt_tokens=10,
|
|
432
|
+
completion_tokens=20,
|
|
433
|
+
total_tokens=30,
|
|
434
|
+
),
|
|
435
|
+
0.5,
|
|
436
|
+
Usage(input_tokens=10, output_tokens=20, total_tokens=30, cost=0.5),
|
|
437
|
+
),
|
|
438
|
+
# Invalid usage type (should be ignored)
|
|
439
|
+
({"prompt_tokens": 10}, None, None),
|
|
440
|
+
# Invalid cost type (should be ignored)
|
|
441
|
+
(None, "0.5", None),
|
|
442
|
+
],
|
|
443
|
+
)
|
|
444
|
+
def test_usage_from_response(config, mock_task, litellm_usage, cost, expected_usage):
|
|
445
|
+
"""Test usage_from_response with various combinations of usage data and cost"""
|
|
446
|
+
adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
|
|
447
|
+
|
|
448
|
+
# Create a mock response
|
|
449
|
+
response = Mock(spec=litellm.types.utils.ModelResponse)
|
|
450
|
+
response.get.return_value = litellm_usage
|
|
451
|
+
response._hidden_params = {"response_cost": cost}
|
|
452
|
+
|
|
453
|
+
# Call the method
|
|
454
|
+
result = adapter.usage_from_response(response)
|
|
455
|
+
|
|
456
|
+
# Verify the result
|
|
457
|
+
if expected_usage is None:
|
|
458
|
+
if result is not None:
|
|
459
|
+
assert result.input_tokens is None
|
|
460
|
+
assert result.output_tokens is None
|
|
461
|
+
assert result.total_tokens is None
|
|
462
|
+
assert result.cost is None
|
|
463
|
+
else:
|
|
464
|
+
assert result is not None
|
|
465
|
+
assert result.input_tokens == expected_usage.input_tokens
|
|
466
|
+
assert result.output_tokens == expected_usage.output_tokens
|
|
467
|
+
assert result.total_tokens == expected_usage.total_tokens
|
|
468
|
+
assert result.cost == expected_usage.cost
|
|
469
|
+
|
|
470
|
+
# Verify the response was queried correctly
|
|
471
|
+
response.get.assert_called_once_with("usage", None)
|
|
@@ -11,14 +11,15 @@ from kiln_ai.datamodel import (
|
|
|
11
11
|
DataSourceType,
|
|
12
12
|
Project,
|
|
13
13
|
Task,
|
|
14
|
+
Usage,
|
|
14
15
|
)
|
|
15
16
|
from kiln_ai.datamodel.task import RunConfig
|
|
16
17
|
from kiln_ai.utils.config import Config
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class MockAdapter(BaseAdapter):
|
|
20
|
-
async def _run(self, input: dict | str) ->
|
|
21
|
-
return RunOutput(output="Test output", intermediate_outputs=None)
|
|
21
|
+
async def _run(self, input: dict | str) -> tuple[RunOutput, Usage | None]:
|
|
22
|
+
return RunOutput(output="Test output", intermediate_outputs=None), None
|
|
22
23
|
|
|
23
24
|
def adapter_name(self) -> str:
|
|
24
25
|
return "mock_adapter"
|
|
@@ -12,6 +12,7 @@ from kiln_ai.adapters.ml_model_list import (
|
|
|
12
12
|
from kiln_ai.adapters.model_adapters.base_adapter import (
|
|
13
13
|
BaseAdapter,
|
|
14
14
|
RunOutput,
|
|
15
|
+
Usage,
|
|
15
16
|
)
|
|
16
17
|
from kiln_ai.adapters.ollama_tools import ollama_online
|
|
17
18
|
from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
|
|
@@ -54,8 +55,8 @@ class MockAdapter(BaseAdapter):
|
|
|
54
55
|
)
|
|
55
56
|
self.response = response
|
|
56
57
|
|
|
57
|
-
async def _run(self, input: str) -> RunOutput:
|
|
58
|
-
return RunOutput(output=self.response, intermediate_outputs=None)
|
|
58
|
+
async def _run(self, input: str) -> tuple[RunOutput, Usage | None]:
|
|
59
|
+
return RunOutput(output=self.response, intermediate_outputs=None), None
|
|
59
60
|
|
|
60
61
|
def adapter_name(self) -> str:
|
|
61
62
|
return "mock_adapter"
|
|
@@ -223,10 +224,7 @@ async def run_structured_input_task(
|
|
|
223
224
|
with pytest.raises(ValueError):
|
|
224
225
|
# input is a plain string, not a dictionary of structured input
|
|
225
226
|
await a.invoke("a=1, b=2, c=3")
|
|
226
|
-
with pytest.raises(
|
|
227
|
-
ValueError,
|
|
228
|
-
match="This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema.",
|
|
229
|
-
):
|
|
227
|
+
with pytest.raises(ValueError, match="This task requires a specific input"):
|
|
230
228
|
# invalid structured input
|
|
231
229
|
await a.invoke({"a": 1, "b": 2, "d": 3})
|
|
232
230
|
|
|
@@ -2,9 +2,6 @@ from kiln_ai.adapters.run_output import RunOutput
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
class BaseParser:
|
|
5
|
-
def __init__(self, structured_output: bool = False):
|
|
6
|
-
self.structured_output = structured_output
|
|
7
|
-
|
|
8
5
|
def parse_output(self, original_output: RunOutput) -> RunOutput:
|
|
9
6
|
"""
|
|
10
7
|
Method for parsing the output of a model. Typically overridden by subclasses.
|
|
@@ -6,14 +6,16 @@ from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
|
|
|
6
6
|
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
def model_parser_from_id(parser_id: ModelParserID | None) ->
|
|
9
|
+
def model_parser_from_id(parser_id: ModelParserID | None) -> BaseParser:
|
|
10
10
|
"""
|
|
11
11
|
Get a model parser from its ID.
|
|
12
12
|
"""
|
|
13
13
|
match parser_id:
|
|
14
14
|
case None:
|
|
15
|
-
return BaseParser
|
|
15
|
+
return BaseParser()
|
|
16
16
|
case ModelParserID.r1_thinking:
|
|
17
|
-
return R1ThinkingParser
|
|
17
|
+
return R1ThinkingParser()
|
|
18
|
+
case ModelParserID.optional_r1_thinking:
|
|
19
|
+
return R1ThinkingParser(allow_missing_thinking=True)
|
|
18
20
|
case _:
|
|
19
21
|
raise_exhaustive_enum_error(parser_id)
|
|
@@ -7,6 +7,9 @@ class R1ThinkingParser(BaseParser):
|
|
|
7
7
|
START_TAG = "<think>"
|
|
8
8
|
END_TAG = "</think>"
|
|
9
9
|
|
|
10
|
+
def __init__(self, allow_missing_thinking: bool = False):
|
|
11
|
+
self.allow_missing_thinking = allow_missing_thinking
|
|
12
|
+
|
|
10
13
|
def parse_output(self, original_output: RunOutput) -> RunOutput:
|
|
11
14
|
"""
|
|
12
15
|
Parse the <think> </think> tags from the response into the intermediate and final outputs.
|
|
@@ -27,6 +30,14 @@ class R1ThinkingParser(BaseParser):
|
|
|
27
30
|
original_output.intermediate_outputs is not None
|
|
28
31
|
and "reasoning" in original_output.intermediate_outputs
|
|
29
32
|
):
|
|
33
|
+
# sometimes the output and reasoning are wrapped in newlines
|
|
34
|
+
if isinstance(original_output.output, str):
|
|
35
|
+
original_output.output = original_output.output.strip()
|
|
36
|
+
|
|
37
|
+
original_output.intermediate_outputs["reasoning"] = (
|
|
38
|
+
original_output.intermediate_outputs["reasoning"].strip()
|
|
39
|
+
)
|
|
40
|
+
|
|
30
41
|
return original_output
|
|
31
42
|
|
|
32
43
|
# This parser only works for strings
|
|
@@ -39,7 +50,10 @@ class R1ThinkingParser(BaseParser):
|
|
|
39
50
|
# Find the thinking tags
|
|
40
51
|
think_end = cleaned_response.find(self.END_TAG)
|
|
41
52
|
if think_end == -1:
|
|
42
|
-
|
|
53
|
+
if self.allow_missing_thinking:
|
|
54
|
+
return original_output
|
|
55
|
+
else:
|
|
56
|
+
raise ValueError("Missing </think> tag")
|
|
43
57
|
|
|
44
58
|
think_tag_start = cleaned_response.find(self.START_TAG)
|
|
45
59
|
if think_tag_start == -1:
|
|
@@ -66,7 +80,8 @@ class R1ThinkingParser(BaseParser):
|
|
|
66
80
|
|
|
67
81
|
# Add thinking content to intermediate outputs if it exists
|
|
68
82
|
intermediate_outputs = original_output.intermediate_outputs or {}
|
|
69
|
-
|
|
83
|
+
if thinking_content is not None and len(thinking_content) > 0:
|
|
84
|
+
intermediate_outputs["reasoning"] = thinking_content
|
|
70
85
|
|
|
71
86
|
return RunOutput(
|
|
72
87
|
output=result,
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Dict, Protocol
|
|
3
|
+
|
|
4
|
+
from kiln_ai.adapters.ml_model_list import ModelFormatterID
|
|
5
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RequestFormatter(Protocol):
|
|
9
|
+
def format_input(self, original_input: Dict | str) -> Dict | str:
|
|
10
|
+
"""
|
|
11
|
+
Method for formatting the input to a model.
|
|
12
|
+
"""
|
|
13
|
+
...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Qwen3StyleNoThinkFormatter:
|
|
17
|
+
def format_input(self, original_input: Dict | str) -> Dict | str:
|
|
18
|
+
"""
|
|
19
|
+
Format the input for a model by appending the Qwen3 /no_think instruction
|
|
20
|
+
"""
|
|
21
|
+
formatted_input = (
|
|
22
|
+
original_input
|
|
23
|
+
if isinstance(original_input, str)
|
|
24
|
+
else json.dumps(original_input, indent=2)
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
return formatted_input + "\n\n/no_think"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def request_formatter_from_id(
|
|
31
|
+
formatter_id: ModelFormatterID,
|
|
32
|
+
) -> RequestFormatter:
|
|
33
|
+
"""
|
|
34
|
+
Get a request formatter from its ID.
|
|
35
|
+
"""
|
|
36
|
+
match formatter_id:
|
|
37
|
+
case ModelFormatterID.qwen3_style_no_think:
|
|
38
|
+
return Qwen3StyleNoThinkFormatter()
|
|
39
|
+
case _:
|
|
40
|
+
raise_exhaustive_enum_error(formatter_id)
|
|
@@ -28,5 +28,5 @@ def test_model_parser_from_id_invalid():
|
|
|
28
28
|
)
|
|
29
29
|
def test_model_parser_from_id_parametrized(parser_id, expected_class):
|
|
30
30
|
"""Test all valid parser IDs using parametrize."""
|
|
31
|
-
|
|
32
|
-
assert
|
|
31
|
+
parser = model_parser_from_id(parser_id)
|
|
32
|
+
assert isinstance(parser, expected_class)
|
|
@@ -46,6 +46,21 @@ def test_response_with_whitespace(parser):
|
|
|
46
46
|
assert parsed.output.strip() == "This is the result"
|
|
47
47
|
|
|
48
48
|
|
|
49
|
+
def test_empty_thinking_content(parser):
|
|
50
|
+
response = RunOutput(
|
|
51
|
+
output="""
|
|
52
|
+
<think>
|
|
53
|
+
|
|
54
|
+
</think>
|
|
55
|
+
This is the result
|
|
56
|
+
""",
|
|
57
|
+
intermediate_outputs=None,
|
|
58
|
+
)
|
|
59
|
+
parsed = parser.parse_output(response)
|
|
60
|
+
assert "reasoning" not in parsed.intermediate_outputs
|
|
61
|
+
assert parsed.output.strip() == "This is the result"
|
|
62
|
+
|
|
63
|
+
|
|
49
64
|
def test_missing_start_tag(parser):
|
|
50
65
|
parsed = parser.parse_output(
|
|
51
66
|
RunOutput(output="Some content</think>result", intermediate_outputs=None)
|
|
@@ -86,7 +101,7 @@ def test_empty_thinking_content(parser):
|
|
|
86
101
|
output="<think></think>This is the result", intermediate_outputs=None
|
|
87
102
|
)
|
|
88
103
|
parsed = parser.parse_output(response)
|
|
89
|
-
assert
|
|
104
|
+
assert "reasoning" not in parsed.intermediate_outputs
|
|
90
105
|
assert parsed.output == "This is the result"
|
|
91
106
|
|
|
92
107
|
|
|
@@ -154,3 +169,31 @@ def test_intermediate_outputs(parser):
|
|
|
154
169
|
)
|
|
155
170
|
)
|
|
156
171
|
assert out.intermediate_outputs["reasoning"] == "Some content"
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def test_strip_newlines(parser):
|
|
175
|
+
# certain providers via LiteLLM for example, add newlines to the output
|
|
176
|
+
# and to the reasoning. This tests that we strip those newlines.
|
|
177
|
+
response = RunOutput(
|
|
178
|
+
output="\n\nSome content",
|
|
179
|
+
intermediate_outputs={
|
|
180
|
+
"reasoning": "\n\nSome thinking\n\n",
|
|
181
|
+
},
|
|
182
|
+
)
|
|
183
|
+
parsed = parser.parse_output(response)
|
|
184
|
+
assert parsed.output == "Some content"
|
|
185
|
+
assert parsed.intermediate_outputs["reasoning"] == "Some thinking"
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def test_strip_newlines_with_structured_output(parser):
|
|
189
|
+
# certain providers via LiteLLM, for example, add newlines to the reasoning.
|
|
190
|
+
# This tests that we strip the reasoning and leave structured (dict) output untouched.
|
|
191
|
+
response = RunOutput(
|
|
192
|
+
output={"some_key": "Some content"},
|
|
193
|
+
intermediate_outputs={
|
|
194
|
+
"reasoning": "\n\nSome thinking\n\n",
|
|
195
|
+
},
|
|
196
|
+
)
|
|
197
|
+
parsed = parser.parse_output(response)
|
|
198
|
+
assert parsed.output == {"some_key": "Some content"}
|
|
199
|
+
assert parsed.intermediate_outputs["reasoning"] == "Some thinking"
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from kiln_ai.adapters.ml_model_list import ModelFormatterID
|
|
4
|
+
from kiln_ai.adapters.parsers.request_formatters import (
|
|
5
|
+
Qwen3StyleNoThinkFormatter,
|
|
6
|
+
request_formatter_from_id,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@pytest.fixture
|
|
11
|
+
def qwen_formatter():
|
|
12
|
+
return Qwen3StyleNoThinkFormatter()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_qwen_formatter_string_input(qwen_formatter):
|
|
16
|
+
input_text = "Hello world"
|
|
17
|
+
formatted = qwen_formatter.format_input(input_text)
|
|
18
|
+
assert formatted == "Hello world\n\n/no_think"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def test_qwen_formatter_dict_input(qwen_formatter):
|
|
22
|
+
input_dict = {"key": "value", "nested": {"inner": "data"}}
|
|
23
|
+
formatted = qwen_formatter.format_input(input_dict)
|
|
24
|
+
expected = """{
|
|
25
|
+
"key": "value",
|
|
26
|
+
"nested": {
|
|
27
|
+
"inner": "data"
|
|
28
|
+
}
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
/no_think"""
|
|
32
|
+
assert formatted == expected
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_qwen_formatter_empty_input(qwen_formatter):
|
|
36
|
+
# Test empty string
|
|
37
|
+
assert qwen_formatter.format_input("") == "\n\n/no_think"
|
|
38
|
+
|
|
39
|
+
# Test empty dict
|
|
40
|
+
assert qwen_formatter.format_input({}) == "{}\n\n/no_think"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def test_qwen_formatter_special_characters(qwen_formatter):
|
|
44
|
+
input_text = "Special chars: !@#$%^&*()_+思"
|
|
45
|
+
formatted = qwen_formatter.format_input(input_text)
|
|
46
|
+
assert formatted == "Special chars: !@#$%^&*()_+思\n\n/no_think"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_qwen_formatter_multiline_string(qwen_formatter):
|
|
50
|
+
input_text = """Line 1
|
|
51
|
+
Line 2
|
|
52
|
+
Line 3"""
|
|
53
|
+
formatted = qwen_formatter.format_input(input_text)
|
|
54
|
+
assert (
|
|
55
|
+
formatted
|
|
56
|
+
== """Line 1
|
|
57
|
+
Line 2
|
|
58
|
+
Line 3
|
|
59
|
+
|
|
60
|
+
/no_think"""
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def test_request_formatter_factory():
|
|
65
|
+
# Test valid formatter ID
|
|
66
|
+
formatter = request_formatter_from_id(ModelFormatterID.qwen3_style_no_think)
|
|
67
|
+
assert isinstance(formatter, Qwen3StyleNoThinkFormatter)
|
|
68
|
+
|
|
69
|
+
# Test that the formatter works
|
|
70
|
+
assert formatter.format_input("test") == "test\n\n/no_think"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def test_request_formatter_factory_invalid_id():
|
|
74
|
+
# Test with an invalid enum value by using a string that doesn't exist in the enum
|
|
75
|
+
with pytest.raises(ValueError, match="Unhandled enum value"):
|
|
76
|
+
request_formatter_from_id("invalid_formatter_id") # type: ignore
|
|
@@ -101,7 +101,6 @@ class SimplePromptBuilder(BasePromptBuilder):
|
|
|
101
101
|
"""
|
|
102
102
|
base_prompt = self.task.instruction
|
|
103
103
|
|
|
104
|
-
# TODO: this is just a quick version. Formatting and best practices TBD
|
|
105
104
|
if len(self.task.requirements) > 0:
|
|
106
105
|
base_prompt += (
|
|
107
106
|
"\n\nYour response should respect the following requirements:\n"
|
|
@@ -113,6 +112,18 @@ class SimplePromptBuilder(BasePromptBuilder):
|
|
|
113
112
|
return base_prompt
|
|
114
113
|
|
|
115
114
|
|
|
115
|
+
class ShortPromptBuilder(BasePromptBuilder):
|
|
116
|
+
"""A prompt builder that includes the base prompt but excludes the requirements."""
|
|
117
|
+
|
|
118
|
+
def build_base_prompt(self) -> str:
|
|
119
|
+
"""Build a short prompt with just the base prompt, no requirements.
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
str: The constructed prompt string.
|
|
123
|
+
"""
|
|
124
|
+
return self.task.instruction
|
|
125
|
+
|
|
126
|
+
|
|
116
127
|
class MultiShotPromptBuilder(BasePromptBuilder):
|
|
117
128
|
"""A prompt builder that includes multiple examples in the prompt."""
|
|
118
129
|
|
|
@@ -414,6 +425,8 @@ def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder
|
|
|
414
425
|
match typed_prompt_generator:
|
|
415
426
|
case PromptGenerators.SIMPLE:
|
|
416
427
|
return SimplePromptBuilder(task)
|
|
428
|
+
case PromptGenerators.SHORT:
|
|
429
|
+
return ShortPromptBuilder(task)
|
|
417
430
|
case PromptGenerators.FEW_SHOT:
|
|
418
431
|
return FewShotPromptBuilder(task)
|
|
419
432
|
case PromptGenerators.MULTI_SHOT:
|
|
@@ -5,6 +5,7 @@ from kiln_ai.adapters.ml_model_list import (
|
|
|
5
5
|
KilnModel,
|
|
6
6
|
KilnModelProvider,
|
|
7
7
|
ModelName,
|
|
8
|
+
ModelParserID,
|
|
8
9
|
ModelProviderName,
|
|
9
10
|
StructuredOutputMode,
|
|
10
11
|
built_in_models,
|
|
@@ -15,7 +16,7 @@ from kiln_ai.adapters.model_adapters.litellm_config import (
|
|
|
15
16
|
from kiln_ai.adapters.ollama_tools import (
|
|
16
17
|
get_ollama_connection,
|
|
17
18
|
)
|
|
18
|
-
from kiln_ai.datamodel import Finetune, Task
|
|
19
|
+
from kiln_ai.datamodel import Finetune, FinetuneDataStrategy, Task
|
|
19
20
|
from kiln_ai.datamodel.registry import project_from_id
|
|
20
21
|
from kiln_ai.utils.config import Config
|
|
21
22
|
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
@@ -257,6 +258,14 @@ def finetune_from_id(model_id: str) -> Finetune:
|
|
|
257
258
|
return fine_tune
|
|
258
259
|
|
|
259
260
|
|
|
261
|
+
def parser_from_data_strategy(
|
|
262
|
+
data_strategy: FinetuneDataStrategy,
|
|
263
|
+
) -> ModelParserID | None:
|
|
264
|
+
if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
|
|
265
|
+
return ModelParserID.r1_thinking
|
|
266
|
+
return None
|
|
267
|
+
|
|
268
|
+
|
|
260
269
|
def finetune_provider_model(
|
|
261
270
|
model_id: str,
|
|
262
271
|
) -> KilnModelProvider:
|
|
@@ -266,6 +275,14 @@ def finetune_provider_model(
|
|
|
266
275
|
model_provider = KilnModelProvider(
|
|
267
276
|
name=provider,
|
|
268
277
|
model_id=fine_tune.fine_tune_model_id,
|
|
278
|
+
parser=parser_from_data_strategy(fine_tune.data_strategy),
|
|
279
|
+
reasoning_capable=(
|
|
280
|
+
fine_tune.data_strategy
|
|
281
|
+
in [
|
|
282
|
+
FinetuneDataStrategy.final_and_intermediate,
|
|
283
|
+
FinetuneDataStrategy.final_and_intermediate_r1_compatible,
|
|
284
|
+
]
|
|
285
|
+
),
|
|
269
286
|
)
|
|
270
287
|
|
|
271
288
|
if provider == ModelProviderName.vertex and fine_tune.fine_tune_model_id:
|