kiln-ai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (54)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +234 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
  7. kiln_ai/adapters/eval/base_eval.py +8 -6
  8. kiln_ai/adapters/eval/eval_runner.py +4 -1
  9. kiln_ai/adapters/eval/g_eval.py +23 -5
  10. kiln_ai/adapters/eval/test_base_eval.py +166 -15
  11. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  12. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +138 -272
  15. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  16. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
  17. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  18. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  19. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  20. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
  21. kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
  22. kiln_ai/adapters/ml_model_list.py +80 -43
  23. kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
  24. kiln_ai/adapters/model_adapters/litellm_adapter.py +79 -97
  25. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  26. kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -60
  27. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +56 -21
  28. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
  29. kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
  30. kiln_ai/adapters/prompt_builders.py +0 -16
  31. kiln_ai/adapters/provider_tools.py +27 -9
  32. kiln_ai/adapters/repair/test_repair_task.py +24 -3
  33. kiln_ai/adapters/test_adapter_registry.py +88 -28
  34. kiln_ai/adapters/test_ml_model_list.py +158 -0
  35. kiln_ai/adapters/test_prompt_adaptors.py +17 -3
  36. kiln_ai/adapters/test_prompt_builders.py +3 -16
  37. kiln_ai/adapters/test_provider_tools.py +69 -20
  38. kiln_ai/datamodel/__init__.py +0 -2
  39. kiln_ai/datamodel/datamodel_enums.py +38 -13
  40. kiln_ai/datamodel/finetune.py +12 -7
  41. kiln_ai/datamodel/task.py +68 -7
  42. kiln_ai/datamodel/test_basemodel.py +2 -1
  43. kiln_ai/datamodel/test_dataset_split.py +0 -8
  44. kiln_ai/datamodel/test_models.py +33 -10
  45. kiln_ai/datamodel/test_task.py +168 -2
  46. kiln_ai/utils/config.py +3 -2
  47. kiln_ai/utils/dataset_import.py +1 -1
  48. kiln_ai/utils/logging.py +165 -0
  49. kiln_ai/utils/test_config.py +23 -0
  50. kiln_ai/utils/test_dataset_import.py +30 -0
  51. {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
  52. {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/RECORD +54 -49
  53. {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
  54. {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/litellm_adapter.py

@@ -12,15 +12,13 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import (
-    COT_FINAL_ANSWER_PROMPT,
     AdapterConfig,
     BaseAdapter,
     RunOutput,
     Usage,
 )
 from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
-from kiln_ai.datamodel import PromptGenerators, PromptId
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.task import run_config_from_run_config_properties
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 logger = logging.getLogger(__name__)
@@ -31,7 +29,6 @@ class LiteLlmAdapter(BaseAdapter):
         self,
         config: LiteLlmConfig,
         kiln_task: datamodel.Task,
-        prompt_id: PromptId | None = None,
         base_adapter_config: AdapterConfig | None = None,
     ):
         self.config = config
@@ -40,11 +37,10 @@ class LiteLlmAdapter(BaseAdapter):
         self._headers = config.default_headers
         self._litellm_model_id: str | None = None
 
-        run_config = RunConfig(
+        # Create a RunConfig, adding the task to the RunConfigProperties
+        run_config = run_config_from_run_config_properties(
             task=kiln_task,
-            model_name=config.model_name,
-            model_provider_name=config.provider_name,
-            prompt_id=prompt_id or PromptGenerators.SIMPLE,
+            run_config_properties=config.run_config_properties,
         )
 
         super().__init__(
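
Note: run_config_from_run_config_properties lives in kiln_ai/datamodel/task.py and its body is not part of this diff. A minimal sketch of what such a helper presumably does, based only on the call site above and the RunConfig fields exercised in the tests further down; the _sketch name and the body are illustrative assumptions, not the shipped code:

    # Assumption: RunConfigProperties is a pydantic model whose fields
    # (model_name, model_provider_name, prompt_id, structured_output_mode,
    # temperature, top_p) map 1:1 onto RunConfig, with the task attached.
    from kiln_ai.datamodel.task import RunConfig, RunConfigProperties


    def run_config_from_run_config_properties_sketch(
        task, run_config_properties: RunConfigProperties
    ) -> RunConfig:
        # Copy every property onto a RunConfig bound to this task
        return RunConfig(task=task, **run_config_properties.model_dump())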
@@ -57,79 +53,69 @@ class LiteLlmAdapter(BaseAdapter):
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")
 
-        intermediate_outputs: dict[str, str] = {}
-        prompt = self.build_prompt()
-        user_msg = self.prompt_builder.build_user_message(input)
-        messages = [
-            {"role": "system", "content": prompt},
-            {"role": "user", "content": user_msg},
-        ]
-
-        run_strategy, cot_prompt = self.run_strategy()
-
-        if run_strategy == "cot_as_message":
-            # Used for reasoning-capable models that can output thinking and structured format
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_as_message strategy")
-            messages.append({"role": "system", "content": cot_prompt})
-        elif run_strategy == "cot_two_call":
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_two_call strategy")
-            messages.append({"role": "system", "content": cot_prompt})
-
-            # First call for chain of thought
-            # No response format as this request is for "thinking" in plain text
-            # No logprobs as only needed for final answer
+        chat_formatter = self.build_chat_formatter(input)
+
+        prior_output = None
+        prior_message = None
+        response = None
+        turns = 0
+        while True:
+            turns += 1
+            if turns > 10:
+                raise RuntimeError(
+                    "Too many turns. Stopping iteration to avoid using too many tokens."
+                )
+
+            turn = chat_formatter.next_turn(prior_output)
+            if turn is None:
+                break
+
+            skip_response_format = not turn.final_call
+            all_messages = chat_formatter.message_dicts()
             completion_kwargs = await self.build_completion_kwargs(
-                provider, messages, None, skip_response_format=True
+                provider,
+                all_messages,
+                self.base_adapter_config.top_logprobs if turn.final_call else None,
+                skip_response_format,
             )
-            cot_response = await litellm.acompletion(**completion_kwargs)
+            response = await litellm.acompletion(**completion_kwargs)
             if (
-                not isinstance(cot_response, ModelResponse)
-                or not cot_response.choices
-                or len(cot_response.choices) == 0
-                or not isinstance(cot_response.choices[0], Choices)
+                not isinstance(response, ModelResponse)
+                or not response.choices
+                or len(response.choices) == 0
+                or not isinstance(response.choices[0], Choices)
             ):
                 raise RuntimeError(
-                    f"Expected ModelResponse with Choices, got {type(cot_response)}."
+                    f"Expected ModelResponse with Choices, got {type(response)}."
                 )
-            cot_content = cot_response.choices[0].message.content
-            if cot_content is not None:
-                intermediate_outputs["chain_of_thought"] = cot_content
-
-            messages.extend(
-                [
-                    {"role": "assistant", "content": cot_content or ""},
-                    {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
-                ]
-            )
+            prior_message = response.choices[0].message
+            prior_output = prior_message.content
 
-        # Make the API call using litellm
-        completion_kwargs = await self.build_completion_kwargs(
-            provider, messages, self.base_adapter_config.top_logprobs
-        )
-        response = await litellm.acompletion(**completion_kwargs)
+            # Fallback: Use args of first tool call to task_response if it exists
+            if (
+                not prior_output
+                and hasattr(prior_message, "tool_calls")
+                and prior_message.tool_calls
+            ):
+                tool_call = next(
+                    (
+                        tool_call
+                        for tool_call in prior_message.tool_calls
+                        if tool_call.function.name == "task_response"
+                    ),
+                    None,
+                )
+                if tool_call:
+                    prior_output = tool_call.function.arguments
 
-        if not isinstance(response, ModelResponse):
-            raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")
+        if not prior_output:
+            raise RuntimeError("No output returned from model")
 
-        # Maybe remove this? There is no error attribute on the response object.
-        # # Keeping in typesafe way as we added it for a reason, but should investigate what that was and if it still applies.
-        if hasattr(response, "error") and response.__getattribute__("error"):
-            raise RuntimeError(
-                f"LLM API returned an error: {response.__getattribute__('error')}"
-            )
+        if response is None or prior_message is None:
+            raise RuntimeError("No response returned from model")
 
-        if (
-            not response.choices
-            or len(response.choices) == 0
-            or not isinstance(response.choices[0], Choices)
-        ):
-            raise RuntimeError(
-                "No message content returned in the response from LLM API"
-            )
+        intermediate_outputs = chat_formatter.intermediate_outputs()
 
-        message = response.choices[0].message
         logprobs = (
            response.choices[0].logprobs
            if hasattr(response.choices[0], "logprobs")
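
Note: the adapter now delegates prompt assembly to the new kiln_ai/adapters/chat package (chat_formatter.py in the file list above) and simply loops over turns. That module is not included in this diff; the sketch below is a simplified, hypothetical formatter that only illustrates the protocol the loop depends on (next_turn / message_dicts / intermediate_outputs). Names prefixed with Sketch are not real kiln classes.

    from dataclasses import dataclass, field


    @dataclass
    class SketchTurn:
        final_call: bool  # True on the last call: the adapter then requests response_format/logprobs


    @dataclass
    class SketchTwoMessageCot:
        system_message: str
        user_input: str
        thinking_instructions: str
        _messages: list = field(default_factory=list)
        _intermediate: dict = field(default_factory=dict)
        _turn: int = 0

        def message_dicts(self) -> list:
            return self._messages

        def intermediate_outputs(self) -> dict:
            return self._intermediate  # e.g. {"chain_of_thought": "..."}

        def next_turn(self, prior_output):
            self._turn += 1
            if self._turn == 1:
                # Turn 1: ask the model to think; the adapter skips response_format here
                self._messages = [
                    {"role": "system", "content": self.system_message},
                    {"role": "user", "content": self.user_input},
                    {"role": "system", "content": self.thinking_instructions},
                ]
                return SketchTurn(final_call=False)
            if self._turn == 2:
                # Turn 2: record the thinking, then ask for the final answer
                self._intermediate["chain_of_thought"] = prior_output or ""
                self._messages += [
                    {"role": "assistant", "content": prior_output or ""},
                    {"role": "user", "content": "Return the final answer."},
                ]
                return SketchTurn(final_call=True)
            return None  # done; the adapter's while-loop breaks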
@@ -143,31 +129,15 @@ class LiteLlmAdapter(BaseAdapter):
 
         # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
         if (
-            hasattr(message, "reasoning_content")
-            and message.reasoning_content
-            and len(message.reasoning_content.strip()) > 0
+            prior_message is not None
+            and hasattr(prior_message, "reasoning_content")
+            and prior_message.reasoning_content
+            and len(prior_message.reasoning_content.strip()) > 0
         ):
-            intermediate_outputs["reasoning"] = message.reasoning_content.strip()
+            intermediate_outputs["reasoning"] = prior_message.reasoning_content.strip()
 
         # the string content of the response
-        response_content = message.content
-
-        # Fallback: Use args of first tool call to task_response if it exists
-        if (
-            not response_content
-            and hasattr(message, "tool_calls")
-            and message.tool_calls
-        ):
-            tool_call = next(
-                (
-                    tool_call
-                    for tool_call in message.tool_calls
-                    if tool_call.function.name == "task_response"
-                ),
-                None,
-            )
-            if tool_call:
-                response_content = tool_call.function.arguments
+        response_content = prior_output
 
         if not isinstance(response_content, str):
             raise RuntimeError(f"response is not a string: {response_content}")
@@ -186,8 +156,9 @@ class LiteLlmAdapter(BaseAdapter):
         if not self.has_structured_output():
             return {}
 
-        provider = self.model_provider()
-        match provider.structured_output_mode:
+        structured_output_mode = self.run_config.structured_output_mode
+
+        match structured_output_mode:
             case StructuredOutputMode.json_mode:
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.json_schema:
@@ -206,16 +177,20 @@ class LiteLlmAdapter(BaseAdapter):
                 # We set response_format to json_object and also set json instructions in the prompt
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.default:
-                if provider.name == ModelProviderName.ollama:
+                provider_name = self.run_config.model_provider_name
+                if provider_name == ModelProviderName.ollama:
                     # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                     return self.json_schema_response_format()
                 else:
                     # Default to function calling -- it's older than the other modes. Higher compatibility.
                     # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
-                    strict = provider.name == ModelProviderName.openai
+                    strict = provider_name == ModelProviderName.openai
                     return self.tool_call_params(strict=strict)
+            case StructuredOutputMode.unknown:
+                # See above, but this case should never happen.
+                raise ValueError("Structured output mode is unknown.")
             case _:
-                raise_exhaustive_enum_error(provider.structured_output_mode)
+                raise_exhaustive_enum_error(structured_output_mode)
 
     def json_schema_response_format(self) -> dict[str, Any]:
         output_schema = self.task().output_schema()
@@ -387,6 +362,13 @@ class LiteLlmAdapter(BaseAdapter):
             "messages": messages,
             "api_base": self._api_base,
             "headers": self._headers,
+            "temperature": self.run_config.temperature,
+            "top_p": self.run_config.top_p,
+            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
+            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
+            # Better to ignore them than to fail the model call.
+            # https://docs.litellm.ai/docs/completion/input
+            "drop_params": True,
             **extra_body,
             **self._additional_body_options,
         }
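
Note: temperature and top_p from the run config are now passed straight to LiteLLM, with drop_params=True so that providers which reject a sampling parameter (the diff comment cites o3 and top_p) don't fail the whole call. A minimal standalone example of the same LiteLLM behavior, outside kiln, assuming an OpenAI key is configured; the model id is a placeholder:

    import asyncio

    import litellm


    async def main():
        response = await litellm.acompletion(
            model="openai/gpt-4o-mini",  # placeholder model id
            messages=[{"role": "user", "content": "Say hello."}],
            temperature=0.7,
            top_p=0.9,
            drop_params=True,  # unsupported OpenAI-style params are dropped, not fatal
        )
        print(response.choices[0].message.content)


    asyncio.run(main())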
kiln_ai/adapters/model_adapters/litellm_config.py

@@ -1,10 +1,11 @@
 from dataclasses import dataclass, field
 
+from kiln_ai.datamodel.task import RunConfigProperties
+
 
 @dataclass
 class LiteLlmConfig:
-    model_name: str
-    provider_name: str
+    run_config_properties: RunConfigProperties
     # If set, over rides the provider-name based URL from litellm
     base_url: str | None = None
     # Headers to send with every request
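
Note: LiteLlmConfig no longer takes model_name/provider_name directly; it wraps a RunConfigProperties. A migration sketch of constructing the 0.17.0 config, with field values as placeholders; which RunConfigProperties fields are required versus defaulted is an assumption here, based on the fields exercised in the tests below:

    from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
    from kiln_ai.datamodel.task import RunConfigProperties

    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            model_name="gpt_4o_mini",  # placeholder model id
            model_provider_name="openai",
            prompt_id="simple_prompt_builder",
            structured_output_mode="json_schema",
            temperature=0.7,
            top_p=0.9,
        ),
        base_url=None,  # optional provider URL override, unchanged from 0.16.0
    )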
kiln_ai/adapters/model_adapters/test_base_adapter.py

@@ -6,7 +6,8 @@ from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMo
 from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
 from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
 from kiln_ai.datamodel import Task
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
+from kiln_ai.datamodel.task import RunConfig, RunConfigProperties
 
 
 class MockAdapter(BaseAdapter):
@@ -37,8 +38,9 @@ def adapter(base_task):
         run_config=RunConfig(
             task=base_task,
             model_name="test_model",
-            model_provider_name="test_provider",
+            model_provider_name="openai",
             prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
         ),
     )
 
@@ -88,7 +90,7 @@ async def test_model_provider_loads_and_caches(adapter, mock_provider):
     # First call should load and cache
     provider1 = adapter.model_provider()
     assert provider1 == mock_provider
-    mock_loader.assert_called_once_with("test_model", "test_provider")
+    mock_loader.assert_called_once_with("test_model", "openai")
 
     # Second call should use cache
     mock_loader.reset_mock()
@@ -97,29 +99,30 @@ async def test_model_provider_loads_and_caches(adapter, mock_provider):
     mock_loader.assert_not_called()
 
 
-async def test_model_provider_missing_names(base_task):
+async def test_model_provider_invalid_provider_model_name(base_task):
+    """Test error when model or provider name is missing"""
+    # Test with missing model name
+    with pytest.raises(ValueError, match="Input should be"):
+        adapter = MockAdapter(
+            run_config=RunConfig(
+                task=base_task,
+                model_name="test_model",
+                model_provider_name="invalid",
+                prompt_id="simple_prompt_builder",
+            ),
+        )
+
+
+async def test_model_provider_missing_model_names(base_task):
     """Test error when model or provider name is missing"""
     # Test with missing model name
     adapter = MockAdapter(
         run_config=RunConfig(
             task=base_task,
             model_name="",
-            model_provider_name="",
-            prompt_id="simple_prompt_builder",
-        ),
-    )
-    with pytest.raises(
-        ValueError, match="model_name and model_provider_name must be provided"
-    ):
-        await adapter.model_provider()
-
-    # Test with missing provider name
-    adapter = MockAdapter(
-        run_config=RunConfig(
-            task=base_task,
-            model_name="test_model",
-            model_provider_name="",
+            model_provider_name="openai",
             prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
         ),
     )
     with pytest.raises(
@@ -138,7 +141,7 @@ async def test_model_provider_not_found(adapter):
 
     with pytest.raises(
         ValueError,
-        match="model_provider_name test_provider not found for model test_model",
+        match="not found for model test_model",
     ):
         await adapter.model_provider()
 
@@ -168,11 +171,7 @@ async def test_prompt_builder_json_instructions(
     adapter.prompt_builder = mock_prompt_builder
     adapter.model_provider_name = "openai"
     adapter.has_structured_output = MagicMock(return_value=output_schema)
-
-    # provider mock
-    provider = MagicMock()
-    provider.structured_output_mode = structured_output_mode
-    adapter.model_provider = MagicMock(return_value=provider)
+    adapter.run_config.structured_output_mode = structured_output_mode
 
     # Test
     adapter.build_prompt()
@@ -181,41 +180,6 @@
     )
 
 
-@pytest.mark.parametrize(
-    "cot_prompt,has_structured_output,reasoning_capable,expected",
-    [
-        # COT and normal LLM
-        ("think carefully", False, False, ("cot_two_call", "think carefully")),
-        # Structured output with thinking-capable LLM
-        ("think carefully", True, True, ("cot_as_message", "think carefully")),
-        # Structured output with normal LLM
-        ("think carefully", True, False, ("cot_two_call", "think carefully")),
-        # Basic cases - no COT
-        (None, True, True, ("basic", None)),
-        (None, False, False, ("basic", None)),
-        (None, True, False, ("basic", None)),
-        (None, False, True, ("basic", None)),
-        # Edge case - COT prompt exists but structured output is False and reasoning_capable is True
-        ("think carefully", False, True, ("cot_as_message", "think carefully")),
-    ],
-)
-async def test_run_strategy(
-    adapter, cot_prompt, has_structured_output, reasoning_capable, expected
-):
-    """Test that run_strategy returns correct strategy based on conditions"""
-    # Mock dependencies
-    adapter.prompt_builder.chain_of_thought_prompt = MagicMock(return_value=cot_prompt)
-    adapter.has_structured_output = MagicMock(return_value=has_structured_output)
-
-    provider = MagicMock()
-    provider.reasoning_capable = reasoning_capable
-    adapter.model_provider = MagicMock(return_value=provider)
-
-    # Test
-    result = adapter.run_strategy()
-    assert result == expected
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "formatter_id,expected_input,expected_calls",
@@ -269,3 +233,214 @@ async def test_input_formatting(
     # Verify original input was preserved in the run
     if formatter_id:
         mock_formatter.format_input.assert_called_once_with(original_input)
+
+
+async def test_properties_for_task_output_includes_all_run_config_properties(adapter):
+    """Test that all properties from RunConfigProperties are saved in task output properties"""
+    # Get all field names from RunConfigProperties
+    run_config_properties_fields = set(RunConfigProperties.model_fields.keys())
+
+    # Get the properties saved by the adapter
+    saved_properties = adapter._properties_for_task_output()
+    saved_property_keys = set(saved_properties.keys())
+
+    # Check which RunConfigProperties fields are missing from saved properties
+    # Note: model_provider_name becomes model_provider in saved properties
+    expected_mappings = {
+        "model_name": "model_name",
+        "model_provider_name": "model_provider",
+        "prompt_id": "prompt_id",
+        "temperature": "temperature",
+        "top_p": "top_p",
+        "structured_output_mode": "structured_output_mode",
+    }
+
+    missing_properties = []
+    for field_name in run_config_properties_fields:
+        expected_key = expected_mappings.get(field_name, field_name)
+        if expected_key not in saved_property_keys:
+            missing_properties.append(
+                f"RunConfigProperties.{field_name} -> {expected_key}"
+            )
+
+    assert not missing_properties, (
+        f"The following RunConfigProperties fields are not saved by _properties_for_task_output: {missing_properties}. Please update the method to include them."
+    )
+
+
+async def test_properties_for_task_output_catches_missing_new_property(adapter):
+    """Test that demonstrates our test will catch when new properties are added to RunConfigProperties but not to _properties_for_task_output"""
+    # Simulate what happens if a new property was added to RunConfigProperties
+    # We'll mock the model_fields to include a fake new property
+    original_fields = RunConfigProperties.model_fields.copy()
+
+    # Create a mock field to simulate a new property being added
+    from pydantic.fields import FieldInfo
+
+    mock_field = FieldInfo(annotation=str, default="default_value")
+
+    try:
+        # Add a fake new field to simulate someone adding a property
+        RunConfigProperties.model_fields["new_fake_property"] = mock_field
+
+        # Get all field names from RunConfigProperties (now includes our fake property)
+        run_config_properties_fields = set(RunConfigProperties.model_fields.keys())
+
+        # Get the properties saved by the adapter (won't include our fake property)
+        saved_properties = adapter._properties_for_task_output()
+        saved_property_keys = set(saved_properties.keys())
+
+        # The mappings don't include our fake property
+        expected_mappings = {
+            "model_name": "model_name",
+            "model_provider_name": "model_provider",
+            "prompt_id": "prompt_id",
+            "temperature": "temperature",
+            "top_p": "top_p",
+            "structured_output_mode": "structured_output_mode",
+        }
+
+        missing_properties = []
+        for field_name in run_config_properties_fields:
+            expected_key = expected_mappings.get(field_name, field_name)
+            if expected_key not in saved_property_keys:
+                missing_properties.append(
+                    f"RunConfigProperties.{field_name} -> {expected_key}"
+                )
+
+        # This should find our missing fake property
+        assert missing_properties == [
+            "RunConfigProperties.new_fake_property -> new_fake_property"
+        ], f"Expected to find missing fake property, but got: {missing_properties}"
+
+    finally:
+        # Restore the original fields
+        RunConfigProperties.model_fields.clear()
+        RunConfigProperties.model_fields.update(original_fields)
+
+
+@pytest.mark.parametrize(
+    "cot_prompt,tuned_strategy,reasoning_capable,expected_formatter_class",
+    [
+        # No COT prompt -> always single turn
+        (None, None, False, "SingleTurnFormatter"),
+        (None, ChatStrategy.two_message_cot, False, "SingleTurnFormatter"),
+        (None, ChatStrategy.single_turn_r1_thinking, True, "SingleTurnFormatter"),
+        # With COT prompt:
+        # - Tuned strategy takes precedence (except single turn)
+        (
+            "think step by step",
+            ChatStrategy.two_message_cot,
+            False,
+            "TwoMessageCotFormatter",
+        ),
+        (
+            "think step by step",
+            ChatStrategy.single_turn_r1_thinking,
+            False,
+            "SingleTurnR1ThinkingFormatter",
+        ),
+        # - Tuned single turn is ignored when COT exists
+        (
+            "think step by step",
+            ChatStrategy.single_turn,
+            True,
+            "SingleTurnR1ThinkingFormatter",
+        ),
+        # - Reasoning capable -> single turn R1 thinking
+        ("think step by step", None, True, "SingleTurnR1ThinkingFormatter"),
+        # - Not reasoning capable -> two message COT
+        ("think step by step", None, False, "TwoMessageCotFormatter"),
+    ],
+)
+def test_build_chat_formatter(
+    adapter,
+    cot_prompt,
+    tuned_strategy,
+    reasoning_capable,
+    expected_formatter_class,
+):
+    """Test chat formatter strategy selection based on COT prompt, tuned strategy, and model capabilities"""
+    # Mock the prompt builder
+    mock_prompt_builder = MagicMock()
+    mock_prompt_builder.chain_of_thought_prompt.return_value = cot_prompt
+    mock_prompt_builder.build_prompt.return_value = "system message"
+    adapter.prompt_builder = mock_prompt_builder
+
+    # Mock the model provider
+    mock_provider = MagicMock()
+    mock_provider.tuned_chat_strategy = tuned_strategy
+    mock_provider.reasoning_capable = reasoning_capable
+    adapter.model_provider = MagicMock(return_value=mock_provider)
+
+    # Get the formatter
+    formatter = adapter.build_chat_formatter("test input")
+
+    # Verify the formatter type
+    assert formatter.__class__.__name__ == expected_formatter_class
+
+    # Verify the formatter was created with correct parameters
+    assert formatter.system_message == "system message"
+    assert formatter.user_input == "test input"
+    # Only check thinking_instructions for formatters that use it
+    if expected_formatter_class == "TwoMessageCotFormatter":
+        if cot_prompt:
+            assert formatter.thinking_instructions == cot_prompt
+        else:
+            assert formatter.thinking_instructions is None
+    # For other formatters, don't assert thinking_instructions
+
+    # Verify prompt builder was called correctly
+    mock_prompt_builder.build_prompt.assert_called_once()
+    mock_prompt_builder.chain_of_thought_prompt.assert_called_once()
+
+
+@pytest.mark.parametrize(
+    "initial_mode,expected_mode",
+    [
+        (
+            StructuredOutputMode.json_schema,
+            StructuredOutputMode.json_schema,
+        ),  # Should not change
+        (
+            StructuredOutputMode.unknown,
+            StructuredOutputMode.json_mode,
+        ),  # Should update to default
+    ],
+)
+async def test_update_run_config_unknown_structured_output_mode(
+    base_task, initial_mode, expected_mode
+):
+    """Test that unknown structured output mode is updated to the default for the model provider"""
+    # Create a run config with the initial mode
+    run_config = RunConfig(
+        task=base_task,
+        model_name="test_model",
+        model_provider_name="openai",
+        prompt_id="simple_prompt_builder",
+        structured_output_mode=initial_mode,
+        temperature=0.7,  # Add some other properties to verify they're preserved
+        top_p=0.9,
+    )
+
+    # Mock the default mode lookup
+    with patch(
+        "kiln_ai.adapters.model_adapters.base_adapter.default_structured_output_mode_for_model_provider"
+    ) as mock_default:
+        mock_default.return_value = StructuredOutputMode.json_mode
+
+        # Create the adapter
+        adapter = MockAdapter(run_config=run_config)
+
+        # Verify the mode was updated correctly
+        assert adapter.run_config.structured_output_mode == expected_mode
+
+        # Verify other properties were preserved
+        assert adapter.run_config.temperature == 0.7
+        assert adapter.run_config.top_p == 0.9
+
+        # Verify the default mode lookup was only called when needed
+        if initial_mode == StructuredOutputMode.unknown:
+            mock_default.assert_called_once_with("test_model", "openai")
+        else:
+            mock_default.assert_not_called()
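
Note: taken together, the parametrized cases in test_build_chat_formatter above pin down the new formatter-selection rules. The following is a reader's sketch of that decision table, not the shipped build_chat_formatter implementation; it only imports ChatStrategy, which the test file itself uses.

    from kiln_ai.datamodel.datamodel_enums import ChatStrategy


    def pick_formatter_name(
        cot_prompt: str | None,
        tuned_strategy: ChatStrategy | None,
        reasoning_capable: bool,
    ) -> str:
        # Without a chain-of-thought prompt, always a plain single turn
        if not cot_prompt:
            return "SingleTurnFormatter"
        # A tuned strategy (other than single turn) takes precedence
        if tuned_strategy == ChatStrategy.two_message_cot:
            return "TwoMessageCotFormatter"
        if tuned_strategy == ChatStrategy.single_turn_r1_thinking:
            return "SingleTurnR1ThinkingFormatter"
        # A tuned single-turn strategy is ignored once a COT prompt exists;
        # fall back to the model's reasoning capability
        if reasoning_capable:
            return "SingleTurnR1ThinkingFormatter"
        return "TwoMessageCotFormatter"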