kiln-ai 0.15.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (45)
  1. kiln_ai/adapters/eval/eval_runner.py +5 -64
  2. kiln_ai/adapters/eval/g_eval.py +3 -3
  3. kiln_ai/adapters/fine_tune/dataset_formatter.py +124 -34
  4. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +264 -7
  5. kiln_ai/adapters/ml_model_list.py +478 -4
  6. kiln_ai/adapters/model_adapters/base_adapter.py +26 -8
  7. kiln_ai/adapters/model_adapters/litellm_adapter.py +41 -7
  8. kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
  9. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
  10. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
  11. kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
  12. kiln_ai/adapters/parsers/base_parser.py +0 -3
  13. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  14. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  15. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  16. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  17. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  18. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  19. kiln_ai/adapters/prompt_builders.py +14 -1
  20. kiln_ai/adapters/provider_tools.py +18 -1
  21. kiln_ai/adapters/repair/test_repair_task.py +3 -2
  22. kiln_ai/adapters/test_prompt_builders.py +24 -3
  23. kiln_ai/adapters/test_provider_tools.py +70 -1
  24. kiln_ai/datamodel/__init__.py +2 -0
  25. kiln_ai/datamodel/datamodel_enums.py +14 -0
  26. kiln_ai/datamodel/dataset_filters.py +69 -1
  27. kiln_ai/datamodel/dataset_split.py +4 -0
  28. kiln_ai/datamodel/eval.py +8 -0
  29. kiln_ai/datamodel/finetune.py +1 -0
  30. kiln_ai/datamodel/prompt_id.py +1 -0
  31. kiln_ai/datamodel/task_output.py +1 -1
  32. kiln_ai/datamodel/task_run.py +39 -7
  33. kiln_ai/datamodel/test_basemodel.py +3 -7
  34. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  35. kiln_ai/datamodel/test_dataset_split.py +2 -0
  36. kiln_ai/datamodel/test_example_models.py +54 -0
  37. kiln_ai/datamodel/test_models.py +50 -2
  38. kiln_ai/utils/async_job_runner.py +106 -0
  39. kiln_ai/utils/dataset_import.py +80 -18
  40. kiln_ai/utils/test_async_job_runner.py +199 -0
  41. kiln_ai/utils/test_dataset_import.py +242 -10
  42. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +1 -1
  43. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/RECORD +45 -41
  44. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
  45. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/litellm_adapter.py
@@ -1,7 +1,9 @@
+import logging
 from typing import Any, Dict

 import litellm
 from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
+from litellm.types.utils import Usage as LiteLlmUsage

 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.ml_model_list import (
@@ -14,14 +16,15 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
     AdapterConfig,
     BaseAdapter,
     RunOutput,
+    Usage,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.datamodel import PromptGenerators, PromptId
 from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

+logger = logging.getLogger(__name__)
+

 class LiteLlmAdapter(BaseAdapter):
     def __init__(
@@ -49,7 +52,7 @@ class LiteLlmAdapter(BaseAdapter):
             config=base_adapter_config,
         )

-    async def _run(self, input: Dict | str) -> RunOutput:
+    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
         provider = self.model_provider()
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")
@@ -139,8 +142,12 @@ class LiteLlmAdapter(BaseAdapter):
             raise RuntimeError("Logprobs were required, but no logprobs were returned.")

         # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
-        if hasattr(message, "reasoning_content") and message.reasoning_content:
-            intermediate_outputs["reasoning"] = message.reasoning_content
+        if (
+            hasattr(message, "reasoning_content")
+            and message.reasoning_content
+            and len(message.reasoning_content.strip()) > 0
+        ):
+            intermediate_outputs["reasoning"] = message.reasoning_content.strip()

         # the string content of the response
         response_content = message.content
@@ -169,7 +176,7 @@ class LiteLlmAdapter(BaseAdapter):
             output=response_content,
             intermediate_outputs=intermediate_outputs,
             output_logprobs=logprobs,
-        )
+        ), self.usage_from_response(response)

     def adapter_name(self) -> str:
         return "kiln_openai_compatible_adapter"
@@ -394,3 +401,30 @@ class LiteLlmAdapter(BaseAdapter):
             completion_kwargs["top_logprobs"] = top_logprobs

         return completion_kwargs
+
+    def usage_from_response(self, response: ModelResponse) -> Usage | None:
+        litellm_usage = response.get("usage", None)
+        cost = response._hidden_params.get("response_cost", None)
+        if not litellm_usage and not cost:
+            return None
+
+        usage = Usage()
+
+        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
+            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
+            usage.output_tokens = litellm_usage.get("completion_tokens", None)
+            usage.total_tokens = litellm_usage.get("total_tokens", None)
+        else:
+            logger.warning(
+                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
+            )
+
+        if isinstance(cost, float):
+            usage.cost = cost
+        elif cost is not None:
+            # None is allowed, but no other types are expected
+            logger.warning(
+                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
+            )
+
+        return usage
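
Note on the adapter change above: `_run` now returns a `(RunOutput, Usage | None)` tuple, and `usage_from_response` degrades gracefully when LiteLLM reports no usage or an unexpected type. A minimal sketch of how a caller might consume the new shape; the adapter instance and input values here are illustrative, not part of this diff:

    # Hypothetical caller (inside an async function); `adapter` is a
    # LiteLlmAdapter constructed elsewhere.
    run_output, usage = await adapter._run({"question": "2 + 2?"})
    if usage is not None:
        # Every field is optional: a provider may report tokens, cost, both, or neither.
        print(usage.input_tokens, usage.output_tokens, usage.total_tokens, usage.cost)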
kiln_ai/adapters/model_adapters/test_base_adapter.py
@@ -3,7 +3,8 @@ from unittest.mock import MagicMock, patch
 import pytest

 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
-from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
+from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
 from kiln_ai.datamodel import Task
 from kiln_ai.datamodel.task import RunConfig

@@ -12,7 +13,7 @@ class MockAdapter(BaseAdapter):
     """Concrete implementation of BaseAdapter for testing"""

     async def _run(self, input):
-        return None
+        return None, None

     def adapter_name(self) -> str:
         return "test"
@@ -42,6 +43,22 @@ def adapter(base_task):
     )


+@pytest.fixture
+def mock_formatter():
+    formatter = MagicMock()
+    formatter.format_input.return_value = {"formatted": "input"}
+    return formatter
+
+
+@pytest.fixture
+def mock_parser():
+    parser = MagicMock()
+    parser.parse_output.return_value = RunOutput(
+        output="test output", intermediate_outputs={}
+    )
+    return parser
+
+
 async def test_model_provider_uses_cache(adapter, mock_provider):
     """Test that cached provider is returned if it exists"""
     # Set up cached provider
@@ -197,3 +214,58 @@ async def test_run_strategy(
     # Test
     result = adapter.run_strategy()
     assert result == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "formatter_id,expected_input,expected_calls",
+    [
+        (None, {"original": "input"}, 0),  # No formatter
+        ("test_formatter", {"formatted": "input"}, 1),  # With formatter
+    ],
+)
+async def test_input_formatting(
+    adapter, mock_formatter, mock_parser, formatter_id, expected_input, expected_calls
+):
+    """Test that input formatting is handled correctly based on formatter configuration"""
+    # Mock the model provider to return our formatter ID and parser
+    provider = MagicMock()
+    provider.formatter = formatter_id
+    provider.parser = "test_parser"
+    provider.reasoning_capable = False
+    adapter.model_provider = MagicMock(return_value=provider)
+
+    # Mock the formatter factory and parser factory
+    with (
+        patch(
+            "kiln_ai.adapters.model_adapters.base_adapter.request_formatter_from_id"
+        ) as mock_factory,
+        patch(
+            "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id"
+        ) as mock_parser_factory,
+    ):
+        mock_factory.return_value = mock_formatter
+        mock_parser_factory.return_value = mock_parser
+
+        # Mock the _run method to capture the input
+        captured_input = None
+
+        async def mock_run(input):
+            nonlocal captured_input
+            captured_input = input
+            return RunOutput(output="test output", intermediate_outputs={}), None
+
+        adapter._run = mock_run
+
+        # Run the adapter
+        original_input = {"original": "input"}
+        await adapter.invoke_returning_run_output(original_input)
+
+        # Verify formatter was called correctly
+        assert captured_input == expected_input
+        assert mock_factory.call_count == (1 if formatter_id else 0)
+        assert mock_formatter.format_input.call_count == expected_calls
+
+        # Verify original input was preserved in the run
+        if formatter_id:
+            mock_formatter.format_input.assert_called_once_with(original_input)
kiln_ai/adapters/model_adapters/test_litellm_adapter.py
@@ -1,6 +1,7 @@
 import json
 from unittest.mock import Mock, patch

+import litellm
 import pytest

 from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
@@ -9,7 +10,7 @@ from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
 from kiln_ai.adapters.model_adapters.litellm_config import (
     LiteLlmConfig,
 )
-from kiln_ai.datamodel import Project, Task
+from kiln_ai.datamodel import Project, Task, Usage


 @pytest.fixture
@@ -405,3 +406,66 @@ async def test_build_completion_kwargs(
     # Verify extra body is included
     for key, value in extra_body.items():
         assert kwargs[key] == value
+
+
+@pytest.mark.parametrize(
+    "litellm_usage,cost,expected_usage",
+    [
+        # No usage data
+        (None, None, None),
+        # Only cost
+        (None, 0.5, Usage(cost=0.5)),
+        # Only token counts
+        (
+            litellm.types.utils.Usage(
+                prompt_tokens=10,
+                completion_tokens=20,
+                total_tokens=30,
+            ),
+            None,
+            Usage(input_tokens=10, output_tokens=20, total_tokens=30),
+        ),
+        # Both cost and token counts
+        (
+            litellm.types.utils.Usage(
+                prompt_tokens=10,
+                completion_tokens=20,
+                total_tokens=30,
+            ),
+            0.5,
+            Usage(input_tokens=10, output_tokens=20, total_tokens=30, cost=0.5),
+        ),
+        # Invalid usage type (should be ignored)
+        ({"prompt_tokens": 10}, None, None),
+        # Invalid cost type (should be ignored)
+        (None, "0.5", None),
+    ],
+)
+def test_usage_from_response(config, mock_task, litellm_usage, cost, expected_usage):
+    """Test usage_from_response with various combinations of usage data and cost"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Create a mock response
+    response = Mock(spec=litellm.types.utils.ModelResponse)
+    response.get.return_value = litellm_usage
+    response._hidden_params = {"response_cost": cost}
+
+    # Call the method
+    result = adapter.usage_from_response(response)
+
+    # Verify the result
+    if expected_usage is None:
+        if result is not None:
+            assert result.input_tokens is None
+            assert result.output_tokens is None
+            assert result.total_tokens is None
+            assert result.cost is None
+    else:
+        assert result is not None
+        assert result.input_tokens == expected_usage.input_tokens
+        assert result.output_tokens == expected_usage.output_tokens
+        assert result.total_tokens == expected_usage.total_tokens
+        assert result.cost == expected_usage.cost
+
+    # Verify the response was queried correctly
+    response.get.assert_called_once_with("usage", None)
kiln_ai/adapters/model_adapters/test_saving_adapter_results.py
@@ -11,14 +11,15 @@ from kiln_ai.datamodel import (
     DataSourceType,
     Project,
     Task,
+    Usage,
 )
 from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config


 class MockAdapter(BaseAdapter):
-    async def _run(self, input: dict | str) -> dict | str:
-        return RunOutput(output="Test output", intermediate_outputs=None)
+    async def _run(self, input: dict | str) -> tuple[RunOutput, Usage | None]:
+        return RunOutput(output="Test output", intermediate_outputs=None), None

     def adapter_name(self) -> str:
         return "mock_adapter"
kiln_ai/adapters/model_adapters/test_structured_output.py
@@ -12,6 +12,7 @@ from kiln_ai.adapters.ml_model_list import (
 from kiln_ai.adapters.model_adapters.base_adapter import (
     BaseAdapter,
     RunOutput,
+    Usage,
 )
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
@@ -54,8 +55,8 @@ class MockAdapter(BaseAdapter):
         )
         self.response = response

-    async def _run(self, input: str) -> RunOutput:
-        return RunOutput(output=self.response, intermediate_outputs=None)
+    async def _run(self, input: str) -> tuple[RunOutput, Usage | None]:
+        return RunOutput(output=self.response, intermediate_outputs=None), None

     def adapter_name(self) -> str:
         return "mock_adapter"
@@ -223,10 +224,7 @@ async def run_structured_input_task(
     with pytest.raises(ValueError):
         # not structured input in dictionary
         await a.invoke("a=1, b=2, c=3")
-    with pytest.raises(
-        ValueError,
-        match="This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema.",
-    ):
+    with pytest.raises(ValueError, match="This task requires a specific input"):
         # invalid structured input
         await a.invoke({"a": 1, "b": 2, "d": 3})

kiln_ai/adapters/parsers/base_parser.py
@@ -2,9 +2,6 @@ from kiln_ai.adapters.run_output import RunOutput


 class BaseParser:
-    def __init__(self, structured_output: bool = False):
-        self.structured_output = structured_output
-
     def parse_output(self, original_output: RunOutput) -> RunOutput:
         """
         Method for parsing the output of a model. Typically overridden by subclasses.
kiln_ai/adapters/parsers/parser_registry.py
@@ -6,14 +6,16 @@ from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


-def model_parser_from_id(parser_id: ModelParserID | None) -> Type[BaseParser]:
+def model_parser_from_id(parser_id: ModelParserID | None) -> BaseParser:
     """
     Get a model parser from its ID.
     """
     match parser_id:
         case None:
-            return BaseParser
+            return BaseParser()
         case ModelParserID.r1_thinking:
-            return R1ThinkingParser
+            return R1ThinkingParser()
+        case ModelParserID.optional_r1_thinking:
+            return R1ThinkingParser(allow_missing_thinking=True)
         case _:
             raise_exhaustive_enum_error(parser_id)
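
Note: `model_parser_from_id` now returns configured parser instances rather than classes, which is what lets the new `optional_r1_thinking` ID carry its `allow_missing_thinking=True` flag. A small sketch, assuming only the enum members visible in this diff:

    from kiln_ai.adapters.ml_model_list import ModelParserID
    from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id

    parser = model_parser_from_id(ModelParserID.optional_r1_thinking)
    # The instance carries its configuration with it; callers no longer instantiate.
    assert parser.allow_missing_thinking is True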
kiln_ai/adapters/parsers/r1_parser.py
@@ -7,6 +7,9 @@ class R1ThinkingParser(BaseParser):
     START_TAG = "<think>"
     END_TAG = "</think>"

+    def __init__(self, allow_missing_thinking: bool = False):
+        self.allow_missing_thinking = allow_missing_thinking
+
     def parse_output(self, original_output: RunOutput) -> RunOutput:
         """
         Parse the <think> </think> tags from the response into the intermediate and final outputs.
@@ -27,6 +30,14 @@ class R1ThinkingParser(BaseParser):
             original_output.intermediate_outputs is not None
             and "reasoning" in original_output.intermediate_outputs
         ):
+            # sometimes the output and reasoning are wrapped in newlines
+            if isinstance(original_output.output, str):
+                original_output.output = original_output.output.strip()
+
+            original_output.intermediate_outputs["reasoning"] = (
+                original_output.intermediate_outputs["reasoning"].strip()
+            )
+
             return original_output

         # This parser only works for strings
@@ -39,7 +50,10 @@ class R1ThinkingParser(BaseParser):
         # Find the thinking tags
         think_end = cleaned_response.find(self.END_TAG)
         if think_end == -1:
-            raise ValueError("Missing </think> tag")
+            if self.allow_missing_thinking:
+                return original_output
+            else:
+                raise ValueError("Missing </think> tag")

         think_tag_start = cleaned_response.find(self.START_TAG)
         if think_tag_start == -1:
@@ -66,7 +80,8 @@ class R1ThinkingParser(BaseParser):

         # Add thinking content to intermediate outputs if it exists
         intermediate_outputs = original_output.intermediate_outputs or {}
-        intermediate_outputs["reasoning"] = thinking_content
+        if thinking_content is not None and len(thinking_content) > 0:
+            intermediate_outputs["reasoning"] = thinking_content

         return RunOutput(
             output=result,
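
Note: the parser changes above add three behaviors: pre-parsed reasoning and output are stripped of wrapping newlines, empty `<think>` blocks no longer produce an empty "reasoning" entry, and a missing `</think>` tag is tolerated when `allow_missing_thinking=True`. A sketch of strict vs. lenient modes, with expected values inferred from the tests later in this diff:

    from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
    from kiln_ai.adapters.run_output import RunOutput

    strict = R1ThinkingParser()
    out = strict.parse_output(
        RunOutput(output="<think>plan</think>\n42", intermediate_outputs=None)
    )
    # out.output.strip() == "42"; out.intermediate_outputs["reasoning"] holds "plan"

    lenient = R1ThinkingParser(allow_missing_thinking=True)
    # No </think> tag: the lenient parser returns the output unchanged instead of raising.
    lenient.parse_output(RunOutput(output="just an answer", intermediate_outputs=None))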
kiln_ai/adapters/parsers/request_formatters.py (new file)
@@ -0,0 +1,40 @@
+import json
+from typing import Dict, Protocol
+
+from kiln_ai.adapters.ml_model_list import ModelFormatterID
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+class RequestFormatter(Protocol):
+    def format_input(self, original_input: Dict | str) -> Dict | str:
+        """
+        Method for formatting the input to a model.
+        """
+        ...
+
+
+class Qwen3StyleNoThinkFormatter:
+    def format_input(self, original_input: Dict | str) -> Dict | str:
+        """
+        Format the input to a model for Qwen3 /no_think instruction
+        """
+        formatted_input = (
+            original_input
+            if isinstance(original_input, str)
+            else json.dumps(original_input, indent=2)
+        )
+
+        return formatted_input + "\n\n/no_think"
+
+
+def request_formatter_from_id(
+    formatter_id: ModelFormatterID,
+) -> RequestFormatter:
+    """
+    Get a request formatter from its ID.
+    """
+    match formatter_id:
+        case ModelFormatterID.qwen3_style_no_think:
+            return Qwen3StyleNoThinkFormatter()
+        case _:
+            raise_exhaustive_enum_error(formatter_id)
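
Note: the new module pairs a structural `Protocol` with an exhaustive factory, mirroring the parser registry. The Qwen3 formatter simply appends the `/no_think` soft switch; dict inputs are JSON-serialized first. A quick usage sketch (the test file later in this diff exercises the same behavior):

    from kiln_ai.adapters.parsers.request_formatters import Qwen3StyleNoThinkFormatter

    formatter = Qwen3StyleNoThinkFormatter()
    formatter.format_input("What is 2 + 2?")
    # -> "What is 2 + 2?\n\n/no_think"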
kiln_ai/adapters/parsers/test_parser_registry.py
@@ -28,5 +28,5 @@ def test_model_parser_from_id_invalid():
 )
 def test_model_parser_from_id_parametrized(parser_id, expected_class):
     """Test all valid parser IDs using parametrize."""
-    parser_class = model_parser_from_id(parser_id)
-    assert parser_class == expected_class
+    parser = model_parser_from_id(parser_id)
+    assert isinstance(parser, expected_class)
kiln_ai/adapters/parsers/test_r1_parser.py
@@ -46,6 +46,21 @@ def test_response_with_whitespace(parser):
     assert parsed.output.strip() == "This is the result"


+def test_empty_thinking_content(parser):
+    response = RunOutput(
+        output="""
+        <think>
+
+        </think>
+        This is the result
+        """,
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert "reasoning" not in parsed.intermediate_outputs
+    assert parsed.output.strip() == "This is the result"
+
+
 def test_missing_start_tag(parser):
     parsed = parser.parse_output(
         RunOutput(output="Some content</think>result", intermediate_outputs=None)
@@ -86,7 +101,7 @@ def test_empty_thinking_content(parser):
         output="<think></think>This is the result", intermediate_outputs=None
     )
     parsed = parser.parse_output(response)
-    assert parsed.intermediate_outputs == {"reasoning": ""}
+    assert "reasoning" not in parsed.intermediate_outputs
     assert parsed.output == "This is the result"


@@ -154,3 +169,31 @@ def test_intermediate_outputs(parser):
         )
     )
     assert out.intermediate_outputs["reasoning"] == "Some content"
+
+
+def test_strip_newlines(parser):
+    # certain providers via LiteLLM for example, add newlines to the output
+    # and to the reasoning. This tests that we strip those newlines.
+    response = RunOutput(
+        output="\n\nSome content",
+        intermediate_outputs={
+            "reasoning": "\n\nSome thinking\n\n",
+        },
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.output == "Some content"
+    assert parsed.intermediate_outputs["reasoning"] == "Some thinking"
+
+
+def test_strip_newlines_with_structured_output(parser):
+    # certain providers via LiteLLM for example, add newlines to the output
+    # and to the reasoning. This tests that we strip those newlines.
+    response = RunOutput(
+        output={"some_key": "Some content"},
+        intermediate_outputs={
+            "reasoning": "\n\nSome thinking\n\n",
+        },
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.output == {"some_key": "Some content"}
+    assert parsed.intermediate_outputs["reasoning"] == "Some thinking"
kiln_ai/adapters/parsers/test_request_formatters.py (new file)
@@ -0,0 +1,76 @@
+import pytest
+
+from kiln_ai.adapters.ml_model_list import ModelFormatterID
+from kiln_ai.adapters.parsers.request_formatters import (
+    Qwen3StyleNoThinkFormatter,
+    request_formatter_from_id,
+)
+
+
+@pytest.fixture
+def qwen_formatter():
+    return Qwen3StyleNoThinkFormatter()
+
+
+def test_qwen_formatter_string_input(qwen_formatter):
+    input_text = "Hello world"
+    formatted = qwen_formatter.format_input(input_text)
+    assert formatted == "Hello world\n\n/no_think"
+
+
+def test_qwen_formatter_dict_input(qwen_formatter):
+    input_dict = {"key": "value", "nested": {"inner": "data"}}
+    formatted = qwen_formatter.format_input(input_dict)
+    expected = """{
+  "key": "value",
+  "nested": {
+    "inner": "data"
+  }
+}
+
+/no_think"""
+    assert formatted == expected
+
+
+def test_qwen_formatter_empty_input(qwen_formatter):
+    # Test empty string
+    assert qwen_formatter.format_input("") == "\n\n/no_think"
+
+    # Test empty dict
+    assert qwen_formatter.format_input({}) == "{}\n\n/no_think"
+
+
+def test_qwen_formatter_special_characters(qwen_formatter):
+    input_text = "Special chars: !@#$%^&*()_+思"
+    formatted = qwen_formatter.format_input(input_text)
+    assert formatted == "Special chars: !@#$%^&*()_+思\n\n/no_think"
+
+
+def test_qwen_formatter_multiline_string(qwen_formatter):
+    input_text = """Line 1
+Line 2
+Line 3"""
+    formatted = qwen_formatter.format_input(input_text)
+    assert (
+        formatted
+        == """Line 1
+Line 2
+Line 3
+
+/no_think"""
+    )
+
+
+def test_request_formatter_factory():
+    # Test valid formatter ID
+    formatter = request_formatter_from_id(ModelFormatterID.qwen3_style_no_think)
+    assert isinstance(formatter, Qwen3StyleNoThinkFormatter)
+
+    # Test that the formatter works
+    assert formatter.format_input("test") == "test\n\n/no_think"
+
+
+def test_request_formatter_factory_invalid_id():
+    # Test with an invalid enum value by using a string that doesn't exist in the enum
+    with pytest.raises(ValueError, match="Unhandled enum value"):
+        request_formatter_from_id("invalid_formatter_id")  # type: ignore
kiln_ai/adapters/prompt_builders.py
@@ -101,7 +101,6 @@ class SimplePromptBuilder(BasePromptBuilder):
         """
         base_prompt = self.task.instruction

-        # TODO: this is just a quick version. Formatting and best practices TBD
         if len(self.task.requirements) > 0:
             base_prompt += (
                 "\n\nYour response should respect the following requirements:\n"
@@ -113,6 +112,18 @@ class SimplePromptBuilder(BasePromptBuilder):
         return base_prompt


+class ShortPromptBuilder(BasePromptBuilder):
+    """A prompt builder that includes the base prompt but excludes the requirements."""
+
+    def build_base_prompt(self) -> str:
+        """Build a short prompt with just the base prompt, no requirements.
+
+        Returns:
+            str: The constructed prompt string.
+        """
+        return self.task.instruction
+
+
 class MultiShotPromptBuilder(BasePromptBuilder):
     """A prompt builder that includes multiple examples in the prompt."""

@@ -414,6 +425,8 @@ def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder
     match typed_prompt_generator:
         case PromptGenerators.SIMPLE:
             return SimplePromptBuilder(task)
+        case PromptGenerators.SHORT:
+            return ShortPromptBuilder(task)
         case PromptGenerators.FEW_SHOT:
             return FewShotPromptBuilder(task)
         case PromptGenerators.MULTI_SHOT:
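
Note: `ShortPromptBuilder` is the requirements-free counterpart of `SimplePromptBuilder`, reachable via the new `PromptGenerators.SHORT` case. A sketch, assuming `task` is an existing kiln_ai `Task`:

    from kiln_ai.adapters.prompt_builders import ShortPromptBuilder

    short_prompt = ShortPromptBuilder(task).build_base_prompt()
    # Returns task.instruction only; SimplePromptBuilder would also append
    # the task's requirements list when one exists.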
kiln_ai/adapters/provider_tools.py
@@ -5,6 +5,7 @@ from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
     ModelName,
+    ModelParserID,
     ModelProviderName,
     StructuredOutputMode,
     built_in_models,
@@ -15,7 +16,7 @@ from kiln_ai.adapters.model_adapters.litellm_config import (
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
-from kiln_ai.datamodel import Finetune, Task
+from kiln_ai.datamodel import Finetune, FinetuneDataStrategy, Task
 from kiln_ai.datamodel.registry import project_from_id
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
@@ -257,6 +258,14 @@ def finetune_from_id(model_id: str) -> Finetune:
     return fine_tune


+def parser_from_data_strategy(
+    data_strategy: FinetuneDataStrategy,
+) -> ModelParserID | None:
+    if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
+        return ModelParserID.r1_thinking
+    return None
+
+
 def finetune_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
@@ -266,6 +275,14 @@ def finetune_provider_model(
     model_provider = KilnModelProvider(
         name=provider,
         model_id=fine_tune.fine_tune_model_id,
+        parser=parser_from_data_strategy(fine_tune.data_strategy),
+        reasoning_capable=(
+            fine_tune.data_strategy
+            in [
+                FinetuneDataStrategy.final_and_intermediate,
+                FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            ]
+        ),
     )

     if provider == ModelProviderName.vertex and fine_tune.fine_tune_model_id:
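
Note: fine-tuned models now derive their parser and `reasoning_capable` flag from the data strategy they were trained with; only the R1-compatible strategy needs the `<think>`-tag parser. A sketch using only the names visible in this diff:

    from kiln_ai.adapters.ml_model_list import ModelParserID
    from kiln_ai.adapters.provider_tools import parser_from_data_strategy
    from kiln_ai.datamodel import FinetuneDataStrategy

    # R1-compatible tunes emit <think> blocks, so they get the thinking parser.
    assert (
        parser_from_data_strategy(
            FinetuneDataStrategy.final_and_intermediate_r1_compatible
        )
        == ModelParserID.r1_thinking
    )
    assert parser_from_data_strategy(FinetuneDataStrategy.final_and_intermediate) is None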