kiln-ai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +234 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
- kiln_ai/adapters/eval/base_eval.py +8 -6
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +23 -5
- kiln_ai/adapters/eval/test_base_eval.py +166 -15
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
- kiln_ai/adapters/fine_tune/dataset_formatter.py +138 -272
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
- kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
- kiln_ai/adapters/ml_model_list.py +80 -43
- kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
- kiln_ai/adapters/model_adapters/litellm_adapter.py +79 -97
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -60
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +56 -21
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
- kiln_ai/adapters/prompt_builders.py +0 -16
- kiln_ai/adapters/provider_tools.py +27 -9
- kiln_ai/adapters/repair/test_repair_task.py +24 -3
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +158 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -3
- kiln_ai/adapters/test_prompt_builders.py +3 -16
- kiln_ai/adapters/test_provider_tools.py +69 -20
- kiln_ai/datamodel/__init__.py +0 -2
- kiln_ai/datamodel/datamodel_enums.py +38 -13
- kiln_ai/datamodel/finetune.py +12 -7
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/test_basemodel.py +2 -1
- kiln_ai/datamodel/test_dataset_split.py +0 -8
- kiln_ai/datamodel/test_models.py +33 -10
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +1 -1
- kiln_ai/utils/logging.py +165 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +30 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/RECORD +54 -49
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/model_adapters/litellm_adapter.py
@@ -12,15 +12,13 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import (
-    COT_FINAL_ANSWER_PROMPT,
     AdapterConfig,
     BaseAdapter,
     RunOutput,
     Usage,
 )
 from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
-from kiln_ai.datamodel import
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.task import run_config_from_run_config_properties
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

 logger = logging.getLogger(__name__)
@@ -31,7 +29,6 @@ class LiteLlmAdapter(BaseAdapter):
         self,
         config: LiteLlmConfig,
         kiln_task: datamodel.Task,
-        prompt_id: PromptId | None = None,
         base_adapter_config: AdapterConfig | None = None,
     ):
         self.config = config
@@ -40,11 +37,10 @@ class LiteLlmAdapter(BaseAdapter):
         self._headers = config.default_headers
         self._litellm_model_id: str | None = None

-
+        # Create a RunConfig, adding the task to the RunConfigProperties
+        run_config = run_config_from_run_config_properties(
             task=kiln_task,
-
-            model_provider_name=config.provider_name,
-            prompt_id=prompt_id or PromptGenerators.SIMPLE,
+            run_config_properties=config.run_config_properties,
         )

         super().__init__(
@@ -57,79 +53,69 @@ class LiteLlmAdapter(BaseAdapter):
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # First call for chain of thought
-        # No response format as this request is for "thinking" in plain text
-        # No logprobs as only needed for final answer
+        chat_formatter = self.build_chat_formatter(input)
+
+        prior_output = None
+        prior_message = None
+        response = None
+        turns = 0
+        while True:
+            turns += 1
+            if turns > 10:
+                raise RuntimeError(
+                    "Too many turns. Stopping iteration to avoid using too many tokens."
+                )
+
+            turn = chat_formatter.next_turn(prior_output)
+            if turn is None:
+                break
+
+            skip_response_format = not turn.final_call
+            all_messages = chat_formatter.message_dicts()
             completion_kwargs = await self.build_completion_kwargs(
-            provider,
+                provider,
+                all_messages,
+                self.base_adapter_config.top_logprobs if turn.final_call else None,
+                skip_response_format,
             )
-
+            response = await litellm.acompletion(**completion_kwargs)
             if (
-            not isinstance(
-            or not
-            or len(
-            or not isinstance(
+                not isinstance(response, ModelResponse)
+                or not response.choices
+                or len(response.choices) == 0
+                or not isinstance(response.choices[0], Choices)
             ):
                 raise RuntimeError(
-                f"Expected ModelResponse with Choices, got {type(
+                    f"Expected ModelResponse with Choices, got {type(response)}."
                 )
-
-
-        intermediate_outputs["chain_of_thought"] = cot_content
-
-        messages.extend(
-            [
-                {"role": "assistant", "content": cot_content or ""},
-                {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
-            ]
-        )
+            prior_message = response.choices[0].message
+            prior_output = prior_message.content

-
-
-
-
-
+            # Fallback: Use args of first tool call to task_response if it exists
+            if (
+                not prior_output
+                and hasattr(prior_message, "tool_calls")
+                and prior_message.tool_calls
+            ):
+                tool_call = next(
+                    (
+                        tool_call
+                        for tool_call in prior_message.tool_calls
+                        if tool_call.function.name == "task_response"
+                    ),
+                    None,
+                )
+                if tool_call:
+                    prior_output = tool_call.function.arguments

-
-
+            if not prior_output:
+                raise RuntimeError("No output returned from model")

-
-
-        if hasattr(response, "error") and response.__getattribute__("error"):
-            raise RuntimeError(
-                f"LLM API returned an error: {response.__getattribute__('error')}"
-            )
+        if response is None or prior_message is None:
+            raise RuntimeError("No response returned from model")

-
-            not response.choices
-            or len(response.choices) == 0
-            or not isinstance(response.choices[0], Choices)
-        ):
-            raise RuntimeError(
-                "No message content returned in the response from LLM API"
-            )
+        intermediate_outputs = chat_formatter.intermediate_outputs()

-        message = response.choices[0].message
         logprobs = (
             response.choices[0].logprobs
             if hasattr(response.choices[0], "logprobs")
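
Commentary (not part of the diff): the rewritten loop above consumes a chat formatter instead of hard-coding a two-call chain-of-thought flow. Below is a minimal sketch of the turn protocol the loop relies on, using only the method and field names visible in this diff (next_turn, final_call, message_dicts, intermediate_outputs); the real classes live in the new kiln_ai/adapters/chat/chat_formatter.py and may differ in detail.

```python
# Illustrative sketch only: a trivial single-turn formatter that satisfies the
# protocol the adapter loop uses. Names come from the diff; everything else is
# an assumption for illustration.
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ChatTurn:
    messages: list[dict]  # messages added for this turn
    final_call: bool      # True when this turn should produce the final answer


@dataclass
class SingleTurnExampleFormatter:
    system_message: str
    user_input: str
    _messages: list[dict] = field(default_factory=list)
    _sent: bool = False

    def next_turn(self, previous_output: str | None) -> ChatTurn | None:
        # Returning None tells the adapter loop to stop iterating.
        if self._sent:
            return None
        self._sent = True
        self._messages = [
            {"role": "system", "content": self.system_message},
            {"role": "user", "content": self.user_input},
        ]
        return ChatTurn(messages=self._messages, final_call=True)

    def message_dicts(self) -> list[dict]:
        # The full message list sent to litellm for the current call.
        return self._messages

    def intermediate_outputs(self) -> dict[str, str]:
        # Multi-turn formatters would surface chain-of-thought text here.
        return {}
```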
@@ -143,31 +129,15 @@ class LiteLlmAdapter(BaseAdapter):

         # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
         if (
-
-            and
-            and
+            prior_message is not None
+            and hasattr(prior_message, "reasoning_content")
+            and prior_message.reasoning_content
+            and len(prior_message.reasoning_content.strip()) > 0
         ):
-            intermediate_outputs["reasoning"] =
+            intermediate_outputs["reasoning"] = prior_message.reasoning_content.strip()

         # the string content of the response
-        response_content =
-
-        # Fallback: Use args of first tool call to task_response if it exists
-        if (
-            not response_content
-            and hasattr(message, "tool_calls")
-            and message.tool_calls
-        ):
-            tool_call = next(
-                (
-                    tool_call
-                    for tool_call in message.tool_calls
-                    if tool_call.function.name == "task_response"
-                ),
-                None,
-            )
-            if tool_call:
-                response_content = tool_call.function.arguments
+        response_content = prior_output

         if not isinstance(response_content, str):
             raise RuntimeError(f"response is not a string: {response_content}")
@@ -186,8 +156,9 @@ class LiteLlmAdapter(BaseAdapter):
         if not self.has_structured_output():
             return {}

-
-
+        structured_output_mode = self.run_config.structured_output_mode
+
+        match structured_output_mode:
             case StructuredOutputMode.json_mode:
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.json_schema:
@@ -206,16 +177,20 @@ class LiteLlmAdapter(BaseAdapter):
                 # We set response_format to json_object and also set json instructions in the prompt
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.default:
-
+                provider_name = self.run_config.model_provider_name
+                if provider_name == ModelProviderName.ollama:
                     # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                     return self.json_schema_response_format()
                 else:
                     # Default to function calling -- it's older than the other modes. Higher compatibility.
                     # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
-                    strict =
+                    strict = provider_name == ModelProviderName.openai
                     return self.tool_call_params(strict=strict)
+            case StructuredOutputMode.unknown:
+                # See above, but this case should never happen.
+                raise ValueError("Structured output mode is unknown.")
             case _:
-                raise_exhaustive_enum_error(
+                raise_exhaustive_enum_error(structured_output_mode)

     def json_schema_response_format(self) -> dict[str, Any]:
         output_schema = self.task().output_schema()
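
Commentary: the match statement above picks an OpenAI-style response_format per structured output mode. For reference, the two payload shapes involved look roughly like the following; this is a generic illustration, and kiln's own json_schema_response_format() (not shown in this diff) may wrap the schema differently.

```python
# Generic OpenAI-compatible payloads, for illustration only.

# "json_mode": the model must emit valid JSON, but no schema is enforced.
json_mode_params = {"response_format": {"type": "json_object"}}

# "json_schema": output is constrained to a specific JSON Schema. The wrapper
# keys follow the OpenAI structured-outputs convention; the schema below is a
# made-up example, not a kiln task schema.
task_output_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}
json_schema_params = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": task_output_schema},
    }
}
```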
@@ -387,6 +362,13 @@ class LiteLlmAdapter(BaseAdapter):
             "messages": messages,
             "api_base": self._api_base,
             "headers": self._headers,
+            "temperature": self.run_config.temperature,
+            "top_p": self.run_config.top_p,
+            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
+            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
+            # Better to ignore them than to fail the model call.
+            # https://docs.litellm.ai/docs/completion/input
+            "drop_params": True,
             **extra_body,
             **self._additional_body_options,
         }
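
Commentary: build_completion_kwargs now forwards temperature and top_p from the run config and sets LiteLLM's drop_params flag. A standalone sketch of what that flag does, with a hypothetical model id (drop_params is a documented litellm option; see the URL cited in the diff):

```python
# Illustration only: with drop_params=True, litellm silently drops OpenAI-style
# parameters a given provider/model does not accept (e.g. top_p on some models)
# instead of raising an error for the whole call.
import litellm

response = litellm.completion(
    model="openai/gpt-4o-mini",  # hypothetical model id
    messages=[{"role": "user", "content": "Say hi"}],
    temperature=0.2,
    top_p=0.9,
    drop_params=True,  # unsupported params are ignored rather than failing the call
)
print(response.choices[0].message.content)
```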

kiln_ai/adapters/model_adapters/litellm_config.py
@@ -1,10 +1,11 @@
 from dataclasses import dataclass, field

+from kiln_ai.datamodel.task import RunConfigProperties
+

 @dataclass
 class LiteLlmConfig:
-
-    provider_name: str
+    run_config_properties: RunConfigProperties
     # If set, over rides the provider-name based URL from litellm
     base_url: str | None = None
     # Headers to send with every request
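
Commentary: the flat provider_name field is gone; model, provider, prompt, and sampling settings now travel inside RunConfigProperties. A hedged sketch of constructing the 0.17.0-style config follows, with field names taken from this diff's tests and hypothetical values; defaults and validation may differ in the released package.

```python
# Sketch only: field names come from this diff (litellm_config.py and
# test_base_adapter.py); the concrete values are placeholders.
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.task import RunConfigProperties

config = LiteLlmConfig(
    run_config_properties=RunConfigProperties(
        model_name="gpt_4o_mini",              # hypothetical Kiln model name
        model_provider_name="openai",
        prompt_id="simple_prompt_builder",
        structured_output_mode="json_schema",
    ),
    base_url=None,  # optional override of the provider-name based URL
    # default_headers (headers sent with every request) is also available,
    # per the adapter's use of config.default_headers above.
)
```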

kiln_ai/adapters/model_adapters/test_base_adapter.py
@@ -6,7 +6,8 @@ from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMo
 from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
 from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
 from kiln_ai.datamodel import Task
-from kiln_ai.datamodel.
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
+from kiln_ai.datamodel.task import RunConfig, RunConfigProperties


 class MockAdapter(BaseAdapter):
@@ -37,8 +38,9 @@ def adapter(base_task):
         run_config=RunConfig(
             task=base_task,
             model_name="test_model",
-            model_provider_name="
+            model_provider_name="openai",
             prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
         ),
     )

@@ -88,7 +90,7 @@ async def test_model_provider_loads_and_caches(adapter, mock_provider):
     # First call should load and cache
     provider1 = adapter.model_provider()
     assert provider1 == mock_provider
-    mock_loader.assert_called_once_with("test_model", "
+    mock_loader.assert_called_once_with("test_model", "openai")

     # Second call should use cache
     mock_loader.reset_mock()
@@ -97,29 +99,30 @@ async def test_model_provider_loads_and_caches(adapter, mock_provider):
     mock_loader.assert_not_called()


-async def
+async def test_model_provider_invalid_provider_model_name(base_task):
+    """Test error when model or provider name is missing"""
+    # Test with missing model name
+    with pytest.raises(ValueError, match="Input should be"):
+        adapter = MockAdapter(
+            run_config=RunConfig(
+                task=base_task,
+                model_name="test_model",
+                model_provider_name="invalid",
+                prompt_id="simple_prompt_builder",
+            ),
+        )
+
+
+async def test_model_provider_missing_model_names(base_task):
     """Test error when model or provider name is missing"""
     # Test with missing model name
     adapter = MockAdapter(
         run_config=RunConfig(
             task=base_task,
             model_name="",
-            model_provider_name="",
-            prompt_id="simple_prompt_builder",
-        ),
-    )
-    with pytest.raises(
-        ValueError, match="model_name and model_provider_name must be provided"
-    ):
-        await adapter.model_provider()
-
-    # Test with missing provider name
-    adapter = MockAdapter(
-        run_config=RunConfig(
-            task=base_task,
-            model_name="test_model",
-            model_provider_name="",
+            model_provider_name="openai",
             prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
         ),
     )
     with pytest.raises(
@@ -138,7 +141,7 @@ async def test_model_provider_not_found(adapter):

     with pytest.raises(
         ValueError,
-        match="
+        match="not found for model test_model",
     ):
         await adapter.model_provider()

@@ -168,11 +171,7 @@ async def test_prompt_builder_json_instructions(
     adapter.prompt_builder = mock_prompt_builder
     adapter.model_provider_name = "openai"
     adapter.has_structured_output = MagicMock(return_value=output_schema)
-
-    # provider mock
-    provider = MagicMock()
-    provider.structured_output_mode = structured_output_mode
-    adapter.model_provider = MagicMock(return_value=provider)
+    adapter.run_config.structured_output_mode = structured_output_mode

     # Test
     adapter.build_prompt()
@@ -181,41 +180,6 @@ async def test_prompt_builder_json_instructions(
     )


-@pytest.mark.parametrize(
-    "cot_prompt,has_structured_output,reasoning_capable,expected",
-    [
-        # COT and normal LLM
-        ("think carefully", False, False, ("cot_two_call", "think carefully")),
-        # Structured output with thinking-capable LLM
-        ("think carefully", True, True, ("cot_as_message", "think carefully")),
-        # Structured output with normal LLM
-        ("think carefully", True, False, ("cot_two_call", "think carefully")),
-        # Basic cases - no COT
-        (None, True, True, ("basic", None)),
-        (None, False, False, ("basic", None)),
-        (None, True, False, ("basic", None)),
-        (None, False, True, ("basic", None)),
-        # Edge case - COT prompt exists but structured output is False and reasoning_capable is True
-        ("think carefully", False, True, ("cot_as_message", "think carefully")),
-    ],
-)
-async def test_run_strategy(
-    adapter, cot_prompt, has_structured_output, reasoning_capable, expected
-):
-    """Test that run_strategy returns correct strategy based on conditions"""
-    # Mock dependencies
-    adapter.prompt_builder.chain_of_thought_prompt = MagicMock(return_value=cot_prompt)
-    adapter.has_structured_output = MagicMock(return_value=has_structured_output)
-
-    provider = MagicMock()
-    provider.reasoning_capable = reasoning_capable
-    adapter.model_provider = MagicMock(return_value=provider)
-
-    # Test
-    result = adapter.run_strategy()
-    assert result == expected
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "formatter_id,expected_input,expected_calls",
@@ -269,3 +233,214 @@ async def test_input_formatting(
     # Verify original input was preserved in the run
     if formatter_id:
         mock_formatter.format_input.assert_called_once_with(original_input)
+
+
+async def test_properties_for_task_output_includes_all_run_config_properties(adapter):
+    """Test that all properties from RunConfigProperties are saved in task output properties"""
+    # Get all field names from RunConfigProperties
+    run_config_properties_fields = set(RunConfigProperties.model_fields.keys())
+
+    # Get the properties saved by the adapter
+    saved_properties = adapter._properties_for_task_output()
+    saved_property_keys = set(saved_properties.keys())
+
+    # Check which RunConfigProperties fields are missing from saved properties
+    # Note: model_provider_name becomes model_provider in saved properties
+    expected_mappings = {
+        "model_name": "model_name",
+        "model_provider_name": "model_provider",
+        "prompt_id": "prompt_id",
+        "temperature": "temperature",
+        "top_p": "top_p",
+        "structured_output_mode": "structured_output_mode",
+    }
+
+    missing_properties = []
+    for field_name in run_config_properties_fields:
+        expected_key = expected_mappings.get(field_name, field_name)
+        if expected_key not in saved_property_keys:
+            missing_properties.append(
+                f"RunConfigProperties.{field_name} -> {expected_key}"
+            )
+
+    assert not missing_properties, (
+        f"The following RunConfigProperties fields are not saved by _properties_for_task_output: {missing_properties}. Please update the method to include them."
+    )
+
+
+async def test_properties_for_task_output_catches_missing_new_property(adapter):
+    """Test that demonstrates our test will catch when new properties are added to RunConfigProperties but not to _properties_for_task_output"""
+    # Simulate what happens if a new property was added to RunConfigProperties
+    # We'll mock the model_fields to include a fake new property
+    original_fields = RunConfigProperties.model_fields.copy()
+
+    # Create a mock field to simulate a new property being added
+    from pydantic.fields import FieldInfo
+
+    mock_field = FieldInfo(annotation=str, default="default_value")
+
+    try:
+        # Add a fake new field to simulate someone adding a property
+        RunConfigProperties.model_fields["new_fake_property"] = mock_field
+
+        # Get all field names from RunConfigProperties (now includes our fake property)
+        run_config_properties_fields = set(RunConfigProperties.model_fields.keys())
+
+        # Get the properties saved by the adapter (won't include our fake property)
+        saved_properties = adapter._properties_for_task_output()
+        saved_property_keys = set(saved_properties.keys())
+
+        # The mappings don't include our fake property
+        expected_mappings = {
+            "model_name": "model_name",
+            "model_provider_name": "model_provider",
+            "prompt_id": "prompt_id",
+            "temperature": "temperature",
+            "top_p": "top_p",
+            "structured_output_mode": "structured_output_mode",
+        }
+
+        missing_properties = []
+        for field_name in run_config_properties_fields:
+            expected_key = expected_mappings.get(field_name, field_name)
+            if expected_key not in saved_property_keys:
+                missing_properties.append(
+                    f"RunConfigProperties.{field_name} -> {expected_key}"
+                )
+
+        # This should find our missing fake property
+        assert missing_properties == [
+            "RunConfigProperties.new_fake_property -> new_fake_property"
+        ], f"Expected to find missing fake property, but got: {missing_properties}"
+
+    finally:
+        # Restore the original fields
+        RunConfigProperties.model_fields.clear()
+        RunConfigProperties.model_fields.update(original_fields)
+
+
+@pytest.mark.parametrize(
+    "cot_prompt,tuned_strategy,reasoning_capable,expected_formatter_class",
+    [
+        # No COT prompt -> always single turn
+        (None, None, False, "SingleTurnFormatter"),
+        (None, ChatStrategy.two_message_cot, False, "SingleTurnFormatter"),
+        (None, ChatStrategy.single_turn_r1_thinking, True, "SingleTurnFormatter"),
+        # With COT prompt:
+        # - Tuned strategy takes precedence (except single turn)
+        (
+            "think step by step",
+            ChatStrategy.two_message_cot,
+            False,
+            "TwoMessageCotFormatter",
+        ),
+        (
+            "think step by step",
+            ChatStrategy.single_turn_r1_thinking,
+            False,
+            "SingleTurnR1ThinkingFormatter",
+        ),
+        # - Tuned single turn is ignored when COT exists
+        (
+            "think step by step",
+            ChatStrategy.single_turn,
+            True,
+            "SingleTurnR1ThinkingFormatter",
+        ),
+        # - Reasoning capable -> single turn R1 thinking
+        ("think step by step", None, True, "SingleTurnR1ThinkingFormatter"),
+        # - Not reasoning capable -> two message COT
+        ("think step by step", None, False, "TwoMessageCotFormatter"),
+    ],
+)
+def test_build_chat_formatter(
+    adapter,
+    cot_prompt,
+    tuned_strategy,
+    reasoning_capable,
+    expected_formatter_class,
+):
+    """Test chat formatter strategy selection based on COT prompt, tuned strategy, and model capabilities"""
+    # Mock the prompt builder
+    mock_prompt_builder = MagicMock()
+    mock_prompt_builder.chain_of_thought_prompt.return_value = cot_prompt
+    mock_prompt_builder.build_prompt.return_value = "system message"
+    adapter.prompt_builder = mock_prompt_builder
+
+    # Mock the model provider
+    mock_provider = MagicMock()
+    mock_provider.tuned_chat_strategy = tuned_strategy
+    mock_provider.reasoning_capable = reasoning_capable
+    adapter.model_provider = MagicMock(return_value=mock_provider)
+
+    # Get the formatter
+    formatter = adapter.build_chat_formatter("test input")
+
+    # Verify the formatter type
+    assert formatter.__class__.__name__ == expected_formatter_class
+
+    # Verify the formatter was created with correct parameters
+    assert formatter.system_message == "system message"
+    assert formatter.user_input == "test input"
+    # Only check thinking_instructions for formatters that use it
+    if expected_formatter_class == "TwoMessageCotFormatter":
+        if cot_prompt:
+            assert formatter.thinking_instructions == cot_prompt
+        else:
+            assert formatter.thinking_instructions is None
+    # For other formatters, don't assert thinking_instructions
+
+    # Verify prompt builder was called correctly
+    mock_prompt_builder.build_prompt.assert_called_once()
+    mock_prompt_builder.chain_of_thought_prompt.assert_called_once()
+
+
+@pytest.mark.parametrize(
+    "initial_mode,expected_mode",
+    [
+        (
+            StructuredOutputMode.json_schema,
+            StructuredOutputMode.json_schema,
+        ), # Should not change
+        (
+            StructuredOutputMode.unknown,
+            StructuredOutputMode.json_mode,
+        ), # Should update to default
+    ],
+)
+async def test_update_run_config_unknown_structured_output_mode(
+    base_task, initial_mode, expected_mode
+):
+    """Test that unknown structured output mode is updated to the default for the model provider"""
+    # Create a run config with the initial mode
+    run_config = RunConfig(
+        task=base_task,
+        model_name="test_model",
+        model_provider_name="openai",
+        prompt_id="simple_prompt_builder",
+        structured_output_mode=initial_mode,
+        temperature=0.7, # Add some other properties to verify they're preserved
+        top_p=0.9,
+    )
+
+    # Mock the default mode lookup
+    with patch(
+        "kiln_ai.adapters.model_adapters.base_adapter.default_structured_output_mode_for_model_provider"
+    ) as mock_default:
+        mock_default.return_value = StructuredOutputMode.json_mode
+
+        # Create the adapter
+        adapter = MockAdapter(run_config=run_config)
+
+        # Verify the mode was updated correctly
+        assert adapter.run_config.structured_output_mode == expected_mode
+
+        # Verify other properties were preserved
+        assert adapter.run_config.temperature == 0.7
+        assert adapter.run_config.top_p == 0.9
+
+        # Verify the default mode lookup was only called when needed
+        if initial_mode == StructuredOutputMode.unknown:
+            mock_default.assert_called_once_with("test_model", "openai")
+        else:
+            mock_default.assert_not_called()