kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (72)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +234 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
  7. kiln_ai/adapters/eval/base_eval.py +8 -6
  8. kiln_ai/adapters/eval/eval_runner.py +9 -65
  9. kiln_ai/adapters/eval/g_eval.py +26 -8
  10. kiln_ai/adapters/eval/test_base_eval.py +166 -15
  11. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  12. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
  15. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  16. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
  17. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  18. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  19. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  20. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
  21. kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
  22. kiln_ai/adapters/ml_model_list.py +556 -45
  23. kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
  24. kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
  25. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  26. kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
  27. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
  28. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
  29. kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
  30. kiln_ai/adapters/parsers/base_parser.py +0 -3
  31. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  32. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  33. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  34. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  35. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  36. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  37. kiln_ai/adapters/prompt_builders.py +14 -17
  38. kiln_ai/adapters/provider_tools.py +39 -4
  39. kiln_ai/adapters/repair/test_repair_task.py +27 -5
  40. kiln_ai/adapters/test_adapter_registry.py +88 -28
  41. kiln_ai/adapters/test_ml_model_list.py +158 -0
  42. kiln_ai/adapters/test_prompt_adaptors.py +17 -3
  43. kiln_ai/adapters/test_prompt_builders.py +27 -19
  44. kiln_ai/adapters/test_provider_tools.py +130 -12
  45. kiln_ai/datamodel/__init__.py +2 -2
  46. kiln_ai/datamodel/datamodel_enums.py +43 -4
  47. kiln_ai/datamodel/dataset_filters.py +69 -1
  48. kiln_ai/datamodel/dataset_split.py +4 -0
  49. kiln_ai/datamodel/eval.py +8 -0
  50. kiln_ai/datamodel/finetune.py +13 -7
  51. kiln_ai/datamodel/prompt_id.py +1 -0
  52. kiln_ai/datamodel/task.py +68 -7
  53. kiln_ai/datamodel/task_output.py +1 -1
  54. kiln_ai/datamodel/task_run.py +39 -7
  55. kiln_ai/datamodel/test_basemodel.py +5 -8
  56. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  57. kiln_ai/datamodel/test_dataset_split.py +2 -8
  58. kiln_ai/datamodel/test_example_models.py +54 -0
  59. kiln_ai/datamodel/test_models.py +80 -9
  60. kiln_ai/datamodel/test_task.py +168 -2
  61. kiln_ai/utils/async_job_runner.py +106 -0
  62. kiln_ai/utils/config.py +3 -2
  63. kiln_ai/utils/dataset_import.py +81 -19
  64. kiln_ai/utils/logging.py +165 -0
  65. kiln_ai/utils/test_async_job_runner.py +199 -0
  66. kiln_ai/utils/test_config.py +23 -0
  67. kiln_ai/utils/test_dataset_import.py +272 -10
  68. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
  69. kiln_ai-0.17.0.dist-info/RECORD +113 -0
  70. kiln_ai-0.15.0.dist-info/RECORD +0 -104
  71. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
  72. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -11,14 +11,15 @@ from kiln_ai.datamodel import (
     DataSourceType,
     Project,
     Task,
+    Usage,
 )
 from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config


 class MockAdapter(BaseAdapter):
-    async def _run(self, input: dict | str) -> dict | str:
-        return RunOutput(output="Test output", intermediate_outputs=None)
+    async def _run(self, input: dict | str) -> tuple[RunOutput, Usage | None]:
+        return RunOutput(output="Test output", intermediate_outputs=None), None

     def adapter_name(self) -> str:
         return "mock_adapter"
@@ -45,6 +46,7 @@ def adapter(test_task):
             model_name="phi_3_5",
             model_provider_name="ollama",
             prompt_id="simple_chain_of_thought_prompt_builder",
+            structured_output_mode="json_schema",
         ),
     )

@@ -101,6 +103,9 @@ def test_save_run_isolation(test_task, adapter):
         reloaded_output.source.properties["prompt_id"]
         == "simple_chain_of_thought_prompt_builder"
     )
+    assert reloaded_output.source.properties["structured_output_mode"] == "json_schema"
+    assert reloaded_output.source.properties["temperature"] == 1.0
+    assert reloaded_output.source.properties["top_p"] == 1.0
     # Run again, with same input and different output. Should create a new TaskRun.
     different_run_output = RunOutput(
         output="Different output", intermediate_outputs=None
@@ -227,3 +232,40 @@ async def test_autosave_true(test_task, adapter):
         output.source.properties["prompt_id"]
         == "simple_chain_of_thought_prompt_builder"
     )
+    assert output.source.properties["structured_output_mode"] == "json_schema"
+    assert output.source.properties["temperature"] == 1.0
+    assert output.source.properties["top_p"] == 1.0
+
+
+def test_properties_for_task_output_custom_values(test_task):
+    """Test that _properties_for_task_output includes custom temperature, top_p, and structured_output_mode"""
+    adapter = MockAdapter(
+        run_config=RunConfig(
+            task=test_task,
+            model_name="gpt-4",
+            model_provider_name="openai",
+            prompt_id="simple_prompt_builder",
+            temperature=0.7,
+            top_p=0.9,
+            structured_output_mode="json_schema",
+        ),
+    )
+
+    input_data = "Test input"
+    output_data = "Test output"
+    run_output = RunOutput(output=output_data, intermediate_outputs=None)
+
+    task_run = adapter.generate_run(
+        input=input_data, input_source=None, run_output=run_output
+    )
+    task_run.save_to_file()
+
+    # Verify custom values are preserved in properties
+    output = task_run.output
+    assert output.source.properties["adapter_name"] == "mock_adapter"
+    assert output.source.properties["model_name"] == "gpt-4"
+    assert output.source.properties["model_provider"] == "openai"
+    assert output.source.properties["prompt_id"] == "simple_prompt_builder"
+    assert output.source.properties["structured_output_mode"] == "json_schema"
+    assert output.source.properties["temperature"] == 0.7
+    assert output.source.properties["top_p"] == 0.9
@@ -12,11 +12,12 @@ from kiln_ai.adapters.ml_model_list import (
 from kiln_ai.adapters.model_adapters.base_adapter import (
     BaseAdapter,
     RunOutput,
+    Usage,
 )
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.task import RunConfig, RunConfigProperties
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@@ -50,12 +51,13 @@ class MockAdapter(BaseAdapter):
                 model_name="phi_3_5",
                 model_provider_name="ollama",
                 prompt_id="simple_chain_of_thought_prompt_builder",
+                structured_output_mode="json_schema",
             ),
         )
         self.response = response

-    async def _run(self, input: str) -> RunOutput:
-        return RunOutput(output=self.response, intermediate_outputs=None)
+    async def _run(self, input: str) -> tuple[RunOutput, Usage | None]:
+        return RunOutput(output=self.response, intermediate_outputs=None), None

     def adapter_name(self) -> str:
         return "mock_adapter"
@@ -145,7 +147,15 @@ def build_structured_output_test_task(tmp_path: Path):

 async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_output_test_task(tmp_path)
-    a = adapter_for_task(task, model_name=model_name, provider=provider)
+    a = adapter_for_task(
+        task,
+        run_config_properties=RunConfigProperties(
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_id="simple_prompt_builder",
+            structured_output_mode="unknown",
+        ),
+    )
     try:
         run = await a.invoke("Cows")  # a joke about cows
         parsed = json.loads(run.output.output)
@@ -196,10 +206,12 @@ def build_structured_input_test_task(tmp_path: Path):
     return task


-async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
+async def run_structured_input_test(
+    tmp_path: Path, model_name: str, provider: str, prompt_id: PromptId
+):
     task = build_structured_input_test_task(tmp_path)
     try:
-        await run_structured_input_task(task, model_name, provider)
+        await run_structured_input_task(task, model_name, provider, prompt_id)
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -208,43 +220,54 @@ async def run_structured_input_test(tmp_path: Path, model_name: str, provider: s
         raise e


-async def run_structured_input_task(
+async def run_structured_input_task_no_validation(
     task: datamodel.Task,
     model_name: str,
     provider: str,
-    prompt_id: PromptId | None = None,
+    prompt_id: PromptId,
 ):
     a = adapter_for_task(
         task,
-        model_name=model_name,
-        provider=provider,
-        prompt_id=prompt_id,
+        run_config_properties=RunConfigProperties(
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_id=prompt_id,
+            structured_output_mode="unknown",
+        ),
     )
     with pytest.raises(ValueError):
         # not structured input in dictionary
         await a.invoke("a=1, b=2, c=3")
-    with pytest.raises(
-        ValueError,
-        match="This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema.",
-    ):
+    with pytest.raises(ValueError, match="This task requires a specific input"):
         # invalid structured input
         await a.invoke({"a": 1, "b": 2, "d": 3})

     try:
         run = await a.invoke({"a": 2, "b": 2, "c": 2})
         response = run.output.output
+        return response, a
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
                 f"Skipping {model_name} {provider} because Ollama is not running"
             )
         raise e
+
+
+async def run_structured_input_task(
+    task: datamodel.Task,
+    model_name: str,
+    provider: str,
+    prompt_id: PromptId,
+):
+    response, a = await run_structured_input_task_no_validation(
+        task, model_name, provider, prompt_id
+    )
     assert response is not None
     if isinstance(response, str):
         assert "[[equilateral]]" in response
     else:
         assert response["is_equilateral"] is True
-
     expected_pb_name = "simple_prompt_builder"
     if prompt_id is not None:
         expected_pb_name = prompt_id
@@ -271,7 +294,9 @@ async def test_structured_input_gpt_4o_mini(tmp_path):
 async def test_all_built_in_models_structured_input(
     tmp_path, model_name, provider_name
 ):
-    await run_structured_input_test(tmp_path, model_name, provider_name)
+    await run_structured_input_test(
+        tmp_path, model_name, provider_name, "simple_prompt_builder"
+    )


 @pytest.mark.paid
@@ -325,6 +350,11 @@ When asked for a final result, this is the format (for an equilateral example):
     """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    await run_structured_input_task(
+    response, adapter = await run_structured_input_task_no_validation(
         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
     )
+
+    formatted_response = json.loads(response)
+    assert formatted_response["is_equilateral"] is True
+    assert formatted_response["is_scalene"] is False
+    assert formatted_response["is_obtuse"] is False
@@ -2,9 +2,6 @@ from kiln_ai.adapters.run_output import RunOutput


 class BaseParser:
-    def __init__(self, structured_output: bool = False):
-        self.structured_output = structured_output
-
     def parse_output(self, original_output: RunOutput) -> RunOutput:
         """
         Method for parsing the output of a model. Typically overridden by subclasses.
@@ -6,14 +6,16 @@ from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


-def model_parser_from_id(parser_id: ModelParserID | None) -> Type[BaseParser]:
+def model_parser_from_id(parser_id: ModelParserID | None) -> BaseParser:
     """
     Get a model parser from its ID.
     """
     match parser_id:
         case None:
-            return BaseParser
+            return BaseParser()
         case ModelParserID.r1_thinking:
-            return R1ThinkingParser()
+            return R1ThinkingParser()
+        case ModelParserID.optional_r1_thinking:
+            return R1ThinkingParser(allow_missing_thinking=True)
         case _:
             raise_exhaustive_enum_error(parser_id)
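Since model_parser_from_id now returns configured parser instances rather than classes, callers can use the result directly. A small usage sketch, assuming only the imports already shown elsewhere in this diff:

    from kiln_ai.adapters.ml_model_list import ModelParserID
    from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
    from kiln_ai.adapters.run_output import RunOutput

    # The registry returns a ready-to-use instance (no separate construction step).
    parser = model_parser_from_id(ModelParserID.optional_r1_thinking)
    parsed = parser.parse_output(
        RunOutput(output="plain output, no think tags", intermediate_outputs=None)
    )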
@@ -7,6 +7,9 @@ class R1ThinkingParser(BaseParser):
     START_TAG = "<think>"
     END_TAG = "</think>"

+    def __init__(self, allow_missing_thinking: bool = False):
+        self.allow_missing_thinking = allow_missing_thinking
+
     def parse_output(self, original_output: RunOutput) -> RunOutput:
         """
         Parse the <think> </think> tags from the response into the intermediate and final outputs.
@@ -27,6 +30,14 @@ class R1ThinkingParser(BaseParser):
             original_output.intermediate_outputs is not None
             and "reasoning" in original_output.intermediate_outputs
         ):
+            # sometimes the output and reasoning are wrapped in newlines
+            if isinstance(original_output.output, str):
+                original_output.output = original_output.output.strip()
+
+            original_output.intermediate_outputs["reasoning"] = (
+                original_output.intermediate_outputs["reasoning"].strip()
+            )
+
             return original_output

         # This parser only works for strings
@@ -39,7 +50,10 @@ class R1ThinkingParser(BaseParser):
         # Find the thinking tags
         think_end = cleaned_response.find(self.END_TAG)
         if think_end == -1:
-            raise ValueError("Missing </think> tag")
+            if self.allow_missing_thinking:
+                return original_output
+            else:
+                raise ValueError("Missing </think> tag")

         think_tag_start = cleaned_response.find(self.START_TAG)
         if think_tag_start == -1:
@@ -66,7 +80,8 @@ class R1ThinkingParser(BaseParser):

         # Add thinking content to intermediate outputs if it exists
         intermediate_outputs = original_output.intermediate_outputs or {}
-        intermediate_outputs["reasoning"] = thinking_content
+        if thinking_content is not None and len(thinking_content) > 0:
+            intermediate_outputs["reasoning"] = thinking_content

         return RunOutput(
             output=result,
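Taken together, the parser changes above make the strict tag requirement opt-out. A short sketch of both modes, using only the pieces shown in these hunks:

    from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
    from kiln_ai.adapters.run_output import RunOutput

    strict = R1ThinkingParser()
    lenient = R1ThinkingParser(allow_missing_thinking=True)

    # Tagged output: reasoning moves to intermediate_outputs, the rest is the result.
    parsed = strict.parse_output(
        RunOutput(output="<think>work it out</think>42", intermediate_outputs=None)
    )
    assert parsed.output.strip() == "42"
    assert "reasoning" in parsed.intermediate_outputs

    # Untagged output: the lenient parser passes it through instead of raising.
    untagged = RunOutput(output="42", intermediate_outputs=None)
    assert lenient.parse_output(untagged).output == "42"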
@@ -0,0 +1,40 @@
+import json
+from typing import Dict, Protocol
+
+from kiln_ai.adapters.ml_model_list import ModelFormatterID
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+class RequestFormatter(Protocol):
+    def format_input(self, original_input: Dict | str) -> Dict | str:
+        """
+        Method for formatting the input to a model.
+        """
+        ...
+
+
+class Qwen3StyleNoThinkFormatter:
+    def format_input(self, original_input: Dict | str) -> Dict | str:
+        """
+        Format the input to a model for Qwen3 /no_think instruction
+        """
+        formatted_input = (
+            original_input
+            if isinstance(original_input, str)
+            else json.dumps(original_input, indent=2)
+        )
+
+        return formatted_input + "\n\n/no_think"
+
+
+def request_formatter_from_id(
+    formatter_id: ModelFormatterID,
+) -> RequestFormatter:
+    """
+    Get a model parser from its ID.
+    """
+    match formatter_id:
+        case ModelFormatterID.qwen3_style_no_think:
+            return Qwen3StyleNoThinkFormatter()
+        case _:
+            raise_exhaustive_enum_error(formatter_id)
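A brief usage sketch of the new formatter module added above; the expected strings mirror the tests added later in this diff:

    from kiln_ai.adapters.ml_model_list import ModelFormatterID
    from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id

    formatter = request_formatter_from_id(ModelFormatterID.qwen3_style_no_think)

    # String inputs pass through unchanged; dict inputs are JSON-encoded first,
    # then the "/no_think" instruction is appended.
    assert formatter.format_input("Hello world") == "Hello world\n\n/no_think"
    assert formatter.format_input({}) == "{}\n\n/no_think"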
@@ -28,5 +28,5 @@ def test_model_parser_from_id_invalid():
 )
 def test_model_parser_from_id_parametrized(parser_id, expected_class):
     """Test all valid parser IDs using parametrize."""
-    parser_class = model_parser_from_id(parser_id)
-    assert parser_class == expected_class
+    parser = model_parser_from_id(parser_id)
+    assert isinstance(parser, expected_class)
@@ -46,6 +46,21 @@ def test_response_with_whitespace(parser):
     assert parsed.output.strip() == "This is the result"


+def test_empty_thinking_content(parser):
+    response = RunOutput(
+        output="""
+    <think>
+
+    </think>
+    This is the result
+    """,
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert "reasoning" not in parsed.intermediate_outputs
+    assert parsed.output.strip() == "This is the result"
+
+
 def test_missing_start_tag(parser):
     parsed = parser.parse_output(
         RunOutput(output="Some content</think>result", intermediate_outputs=None)
@@ -86,7 +101,7 @@ def test_empty_thinking_content(parser):
         output="<think></think>This is the result", intermediate_outputs=None
     )
     parsed = parser.parse_output(response)
-    assert parsed.intermediate_outputs == {"reasoning": ""}
+    assert "reasoning" not in parsed.intermediate_outputs
     assert parsed.output == "This is the result"


@@ -154,3 +169,31 @@ def test_intermediate_outputs(parser):
         )
     )
     assert out.intermediate_outputs["reasoning"] == "Some content"
+
+
+def test_strip_newlines(parser):
+    # certain providers via LiteLLM for example, add newlines to the output
+    # and to the reasoning. This tests that we strip those newlines.
+    response = RunOutput(
+        output="\n\nSome content",
+        intermediate_outputs={
+            "reasoning": "\n\nSome thinking\n\n",
+        },
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.output == "Some content"
+    assert parsed.intermediate_outputs["reasoning"] == "Some thinking"
+
+
+def test_strip_newlines_with_structured_output(parser):
+    # certain providers via LiteLLM for example, add newlines to the output
+    # and to the reasoning. This tests that we strip those newlines.
+    response = RunOutput(
+        output={"some_key": "Some content"},
+        intermediate_outputs={
+            "reasoning": "\n\nSome thinking\n\n",
+        },
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.output == {"some_key": "Some content"}
+    assert parsed.intermediate_outputs["reasoning"] == "Some thinking"
@@ -0,0 +1,76 @@
+import pytest
+
+from kiln_ai.adapters.ml_model_list import ModelFormatterID
+from kiln_ai.adapters.parsers.request_formatters import (
+    Qwen3StyleNoThinkFormatter,
+    request_formatter_from_id,
+)
+
+
+@pytest.fixture
+def qwen_formatter():
+    return Qwen3StyleNoThinkFormatter()
+
+
+def test_qwen_formatter_string_input(qwen_formatter):
+    input_text = "Hello world"
+    formatted = qwen_formatter.format_input(input_text)
+    assert formatted == "Hello world\n\n/no_think"
+
+
+def test_qwen_formatter_dict_input(qwen_formatter):
+    input_dict = {"key": "value", "nested": {"inner": "data"}}
+    formatted = qwen_formatter.format_input(input_dict)
+    expected = """{
+  "key": "value",
+  "nested": {
+    "inner": "data"
+  }
+}
+
+/no_think"""
+    assert formatted == expected
+
+
+def test_qwen_formatter_empty_input(qwen_formatter):
+    # Test empty string
+    assert qwen_formatter.format_input("") == "\n\n/no_think"
+
+    # Test empty dict
+    assert qwen_formatter.format_input({}) == "{}\n\n/no_think"
+
+
+def test_qwen_formatter_special_characters(qwen_formatter):
+    input_text = "Special chars: !@#$%^&*()_+思"
+    formatted = qwen_formatter.format_input(input_text)
+    assert formatted == "Special chars: !@#$%^&*()_+思\n\n/no_think"
+
+
+def test_qwen_formatter_multiline_string(qwen_formatter):
+    input_text = """Line 1
+Line 2
+Line 3"""
+    formatted = qwen_formatter.format_input(input_text)
+    assert (
+        formatted
+        == """Line 1
+Line 2
+Line 3
+
+/no_think"""
+    )
+
+
+def test_request_formatter_factory():
+    # Test valid formatter ID
+    formatter = request_formatter_from_id(ModelFormatterID.qwen3_style_no_think)
+    assert isinstance(formatter, Qwen3StyleNoThinkFormatter)
+
+    # Test that the formatter works
+    assert formatter.format_input("test") == "test\n\n/no_think"
+
+
+def test_request_formatter_factory_invalid_id():
+    # Test with an invalid enum value by using a string that doesn't exist in the enum
+    with pytest.raises(ValueError, match="Unhandled enum value"):
+        request_formatter_from_id("invalid_formatter_id")  # type: ignore
@@ -1,6 +1,4 @@
-import json
 from abc import ABCMeta, abstractmethod
-from typing import Dict

 from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
@@ -53,20 +51,6 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         pass

-    def build_user_message(self, input: Dict | str) -> str:
-        """Build a user message from the input.
-
-        Args:
-            input (Union[Dict, str]): The input to format into a message.
-
-        Returns:
-            str: The formatted user message.
-        """
-        if isinstance(input, Dict):
-            return f"The input is:\n{json.dumps(input, indent=2, ensure_ascii=False)}"
-
-        return f"The input is:\n{input}"
-
     def chain_of_thought_prompt(self) -> str | None:
         """Build and return the chain of thought prompt string.

@@ -101,7 +85,6 @@ class SimplePromptBuilder(BasePromptBuilder):
         """
         base_prompt = self.task.instruction

-        # TODO: this is just a quick version. Formatting and best practices TBD
         if len(self.task.requirements) > 0:
             base_prompt += (
                 "\n\nYour response should respect the following requirements:\n"
@@ -113,6 +96,18 @@ class SimplePromptBuilder(BasePromptBuilder):
         return base_prompt


+class ShortPromptBuilder(BasePromptBuilder):
+    """A prompt builder that includes a the base prompt but excludes the requirements."""
+
+    def build_base_prompt(self) -> str:
+        """Build a short prompt with just the base prompt, no requirements.
+
+        Returns:
+            str: The constructed prompt string.
+        """
+        return self.task.instruction
+
+
 class MultiShotPromptBuilder(BasePromptBuilder):
     """A prompt builder that includes multiple examples in the prompt."""

@@ -414,6 +409,8 @@ def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder
     match typed_prompt_generator:
         case PromptGenerators.SIMPLE:
             return SimplePromptBuilder(task)
+        case PromptGenerators.SHORT:
+            return ShortPromptBuilder(task)
         case PromptGenerators.FEW_SHOT:
             return FewShotPromptBuilder(task)
         case PromptGenerators.MULTI_SHOT:
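The new ShortPromptBuilder (registered under PromptGenerators.SHORT above) emits only the task instruction, with no requirements section. A minimal sketch of calling it directly; the helper name is illustrative and assumes an existing Task:

    from kiln_ai.adapters.prompt_builders import ShortPromptBuilder
    from kiln_ai.datamodel import Task


    def short_prompt_for(task: Task) -> str:
        # Returns task.instruction only; requirements are intentionally omitted.
        return ShortPromptBuilder(task).build_base_prompt()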
@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass
 from typing import Dict, List

@@ -5,6 +6,7 @@ from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
     ModelName,
+    ModelParserID,
     ModelProviderName,
     StructuredOutputMode,
     built_in_models,
@@ -16,10 +18,14 @@ from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
 from kiln_ai.datamodel import Finetune, Task
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.registry import project_from_id
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

+logger = logging.getLogger(__name__)
+

 async def provider_enabled(provider_name: ModelProviderName) -> bool:
     if provider_name == ModelProviderName.ollama:
@@ -162,6 +168,10 @@ def kiln_model_provider_from(
     # For custom registry, get the provider name and model name from the model id
     if provider_name == ModelProviderName.kiln_custom_registry:
         provider_name, name = parse_custom_model_id(name)
+    else:
+        logger.warning(
+            f"Unexpected model/provider pair. Will treat as custom model but check your model settings. Provider: {provider_name}/{name}"
+        )

     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
@@ -176,12 +186,15 @@ def kiln_model_provider_from(
         supports_data_gen=False,
         untested_model=True,
         model_id=name,
+        # We don't know the structured output mode for custom models, so we default to json_instructions which is the only one that works everywhere.
+        structured_output_mode=StructuredOutputMode.json_instructions,
     )


-def lite_llm_config(
-    model_id: str,
+def lite_llm_config_for_openai_compatible(
+    run_config_properties: RunConfigProperties,
 ) -> LiteLlmConfig:
+    model_id = run_config_properties.model_name
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -205,10 +218,16 @@ def lite_llm_config(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )

+    # Update a copy of the run config properties to use the openai compatible provider
+    updated_run_config_properties = run_config_properties.model_copy(deep=True)
+    updated_run_config_properties.model_provider_name = (
+        ModelProviderName.openai_compatible
+    )
+    updated_run_config_properties.model_name = model_id
+
     return LiteLlmConfig(
         # OpenAI compatible, with a custom base URL
-        model_name=model_id,
-        provider_name=ModelProviderName.openai_compatible,
+        run_config_properties=updated_run_config_properties,
         base_url=base_url,
         additional_body_options={
             "api_key": api_key,
@@ -257,6 +276,14 @@ def finetune_from_id(model_id: str) -> Finetune:
     return fine_tune


+def parser_from_data_strategy(
+    data_strategy: ChatStrategy,
+) -> ModelParserID | None:
+    if data_strategy == ChatStrategy.single_turn_r1_thinking:
+        return ModelParserID.r1_thinking
+    return None
+
+
 def finetune_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
@@ -266,6 +293,14 @@ def finetune_provider_model(
     model_provider = KilnModelProvider(
         name=provider,
         model_id=fine_tune.fine_tune_model_id,
+        parser=parser_from_data_strategy(fine_tune.data_strategy),
+        reasoning_capable=(
+            fine_tune.data_strategy
+            in [
+                ChatStrategy.single_turn_r1_thinking,
+            ]
+        ),
+        tuned_chat_strategy=fine_tune.data_strategy,
     )

     if provider == ModelProviderName.vertex and fine_tune.fine_tune_model_id: