kiln-ai 0.8.1__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +193 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
from typing import Any, Dict
|
|
2
|
+
|
|
3
|
+
from openai import AsyncOpenAI
|
|
4
|
+
from openai.types.chat import (
|
|
5
|
+
ChatCompletion,
|
|
6
|
+
ChatCompletionAssistantMessageParam,
|
|
7
|
+
ChatCompletionSystemMessageParam,
|
|
8
|
+
ChatCompletionUserMessageParam,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
import kiln_ai.datamodel as datamodel
|
|
12
|
+
from kiln_ai.adapters.ml_model_list import StructuredOutputMode
|
|
13
|
+
from kiln_ai.adapters.model_adapters.base_adapter import (
|
|
14
|
+
COT_FINAL_ANSWER_PROMPT,
|
|
15
|
+
AdapterInfo,
|
|
16
|
+
BaseAdapter,
|
|
17
|
+
BasePromptBuilder,
|
|
18
|
+
RunOutput,
|
|
19
|
+
)
|
|
20
|
+
from kiln_ai.adapters.model_adapters.openai_compatible_config import (
|
|
21
|
+
OpenAICompatibleConfig,
|
|
22
|
+
)
|
|
23
|
+
from kiln_ai.adapters.parsers.json_parser import parse_json_string
|
|
24
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OpenAICompatibleAdapter(BaseAdapter):
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
config: OpenAICompatibleConfig,
|
|
31
|
+
kiln_task: datamodel.Task,
|
|
32
|
+
prompt_builder: BasePromptBuilder | None = None,
|
|
33
|
+
tags: list[str] | None = None,
|
|
34
|
+
):
|
|
35
|
+
self.config = config
|
|
36
|
+
self.client = AsyncOpenAI(
|
|
37
|
+
api_key=config.api_key,
|
|
38
|
+
base_url=config.base_url,
|
|
39
|
+
default_headers=config.default_headers,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
super().__init__(
|
|
43
|
+
kiln_task,
|
|
44
|
+
model_name=config.model_name,
|
|
45
|
+
model_provider_name=config.provider_name,
|
|
46
|
+
prompt_builder=prompt_builder,
|
|
47
|
+
tags=tags,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
async def _run(self, input: Dict | str) -> RunOutput:
|
|
51
|
+
provider = self.model_provider()
|
|
52
|
+
intermediate_outputs: dict[str, str] = {}
|
|
53
|
+
prompt = self.build_prompt()
|
|
54
|
+
user_msg = self.prompt_builder.build_user_message(input)
|
|
55
|
+
messages = [
|
|
56
|
+
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
|
57
|
+
ChatCompletionUserMessageParam(role="user", content=user_msg),
|
|
58
|
+
]
|
|
59
|
+
|
|
60
|
+
run_strategy, cot_prompt = self.run_strategy()
|
|
61
|
+
|
|
62
|
+
if run_strategy == "cot_as_message":
|
|
63
|
+
if not cot_prompt:
|
|
64
|
+
raise ValueError("cot_prompt is required for cot_as_message strategy")
|
|
65
|
+
messages.append(
|
|
66
|
+
ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
|
|
67
|
+
)
|
|
68
|
+
elif run_strategy == "cot_two_call":
|
|
69
|
+
if not cot_prompt:
|
|
70
|
+
raise ValueError("cot_prompt is required for cot_two_call strategy")
|
|
71
|
+
messages.append(
|
|
72
|
+
ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# First call for chain of thought
|
|
76
|
+
cot_response = await self.client.chat.completions.create(
|
|
77
|
+
model=provider.provider_options["model"],
|
|
78
|
+
messages=messages,
|
|
79
|
+
)
|
|
80
|
+
cot_content = cot_response.choices[0].message.content
|
|
81
|
+
if cot_content is not None:
|
|
82
|
+
intermediate_outputs["chain_of_thought"] = cot_content
|
|
83
|
+
|
|
84
|
+
messages.extend(
|
|
85
|
+
[
|
|
86
|
+
ChatCompletionAssistantMessageParam(
|
|
87
|
+
role="assistant", content=cot_content
|
|
88
|
+
),
|
|
89
|
+
ChatCompletionUserMessageParam(
|
|
90
|
+
role="user",
|
|
91
|
+
content=COT_FINAL_ANSWER_PROMPT,
|
|
92
|
+
),
|
|
93
|
+
]
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# OpenRouter specific options for reasoning models
|
|
97
|
+
extra_body = {}
|
|
98
|
+
require_or_reasoning = (
|
|
99
|
+
self.config.openrouter_style_reasoning and provider.reasoning_capable
|
|
100
|
+
)
|
|
101
|
+
if require_or_reasoning:
|
|
102
|
+
extra_body["include_reasoning"] = True
|
|
103
|
+
# Filter to providers that support the reasoning parameter
|
|
104
|
+
extra_body["provider"] = {
|
|
105
|
+
"require_parameters": True,
|
|
106
|
+
# Ugly to have these here, but big range of quality of R1 providers
|
|
107
|
+
"order": ["Fireworks", "Together"],
|
|
108
|
+
# fp8 quants are awful
|
|
109
|
+
"ignore": ["DeepInfra"],
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
# Main completion call
|
|
113
|
+
response_format_options = await self.response_format_options()
|
|
114
|
+
response = await self.client.chat.completions.create(
|
|
115
|
+
model=provider.provider_options["model"],
|
|
116
|
+
messages=messages,
|
|
117
|
+
extra_body=extra_body,
|
|
118
|
+
**response_format_options,
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
if not isinstance(response, ChatCompletion):
|
|
122
|
+
raise RuntimeError(
|
|
123
|
+
f"Expected ChatCompletion response, got {type(response)}."
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
if hasattr(response, "error") and response.error: # pyright: ignore
|
|
127
|
+
raise RuntimeError(
|
|
128
|
+
f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}.\nError: {response.error}" # pyright: ignore
|
|
129
|
+
)
|
|
130
|
+
if not response.choices or len(response.choices) == 0:
|
|
131
|
+
raise RuntimeError(
|
|
132
|
+
"No message content returned in the response from OpenAI compatible API"
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
message = response.choices[0].message
|
|
136
|
+
|
|
137
|
+
# Save reasoning if it exists (OpenRouter specific format)
|
|
138
|
+
if require_or_reasoning:
|
|
139
|
+
if (
|
|
140
|
+
hasattr(message, "reasoning") and message.reasoning # pyright: ignore
|
|
141
|
+
):
|
|
142
|
+
intermediate_outputs["reasoning"] = message.reasoning # pyright: ignore
|
|
143
|
+
else:
|
|
144
|
+
raise RuntimeError(
|
|
145
|
+
"Reasoning is required for this model, but no reasoning was returned from OpenRouter."
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
# the string content of the response
|
|
149
|
+
response_content = message.content
|
|
150
|
+
|
|
151
|
+
# Fallback: Use args of first tool call to task_response if it exists
|
|
152
|
+
if not response_content and message.tool_calls:
|
|
153
|
+
tool_call = next(
|
|
154
|
+
(
|
|
155
|
+
tool_call
|
|
156
|
+
for tool_call in message.tool_calls
|
|
157
|
+
if tool_call.function.name == "task_response"
|
|
158
|
+
),
|
|
159
|
+
None,
|
|
160
|
+
)
|
|
161
|
+
if tool_call:
|
|
162
|
+
response_content = tool_call.function.arguments
|
|
163
|
+
|
|
164
|
+
if not isinstance(response_content, str):
|
|
165
|
+
raise RuntimeError(f"response is not a string: {response_content}")
|
|
166
|
+
|
|
167
|
+
if self.has_structured_output():
|
|
168
|
+
structured_response = parse_json_string(response_content)
|
|
169
|
+
return RunOutput(
|
|
170
|
+
output=structured_response,
|
|
171
|
+
intermediate_outputs=intermediate_outputs,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
return RunOutput(
|
|
175
|
+
output=response_content,
|
|
176
|
+
intermediate_outputs=intermediate_outputs,
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
def adapter_info(self) -> AdapterInfo:
|
|
180
|
+
return AdapterInfo(
|
|
181
|
+
model_name=self.model_name,
|
|
182
|
+
model_provider=self.model_provider_name,
|
|
183
|
+
adapter_name="kiln_openai_compatible_adapter",
|
|
184
|
+
prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
|
|
185
|
+
prompt_id=self.prompt_builder.prompt_id(),
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
async def response_format_options(self) -> dict[str, Any]:
|
|
189
|
+
# Unstructured if task isn't structured
|
|
190
|
+
if not self.has_structured_output():
|
|
191
|
+
return {}
|
|
192
|
+
|
|
193
|
+
provider = self.model_provider()
|
|
194
|
+
match provider.structured_output_mode:
|
|
195
|
+
case StructuredOutputMode.json_mode:
|
|
196
|
+
return {"response_format": {"type": "json_object"}}
|
|
197
|
+
case StructuredOutputMode.json_schema:
|
|
198
|
+
output_schema = self.kiln_task.output_schema()
|
|
199
|
+
return {
|
|
200
|
+
"response_format": {
|
|
201
|
+
"type": "json_schema",
|
|
202
|
+
"json_schema": {
|
|
203
|
+
"name": "task_response",
|
|
204
|
+
"schema": output_schema,
|
|
205
|
+
},
|
|
206
|
+
}
|
|
207
|
+
}
|
|
208
|
+
case StructuredOutputMode.function_calling:
|
|
209
|
+
return self.tool_call_params()
|
|
210
|
+
case StructuredOutputMode.json_instructions:
|
|
211
|
+
# JSON done via instructions in prompt, not the API response format. Do not ask for json_object (see option below).
|
|
212
|
+
return {}
|
|
213
|
+
case StructuredOutputMode.json_instruction_and_object:
|
|
214
|
+
# We set response_format to json_object and also set json instructions in the prompt
|
|
215
|
+
return {"response_format": {"type": "json_object"}}
|
|
216
|
+
case StructuredOutputMode.default:
|
|
217
|
+
# Default to function calling -- it's older than the other modes. Higher compatibility.
|
|
218
|
+
return self.tool_call_params()
|
|
219
|
+
case _:
|
|
220
|
+
raise_exhaustive_enum_error(provider.structured_output_mode)
|
|
221
|
+
|
|
222
|
+
def tool_call_params(self) -> dict[str, Any]:
|
|
223
|
+
# Add additional_properties: false to the schema (OpenAI requires this for some models)
|
|
224
|
+
output_schema = self.kiln_task.output_schema()
|
|
225
|
+
if not isinstance(output_schema, dict):
|
|
226
|
+
raise ValueError(
|
|
227
|
+
"Invalid output schema for this task. Can not use tool calls."
|
|
228
|
+
)
|
|
229
|
+
output_schema["additionalProperties"] = False
|
|
230
|
+
|
|
231
|
+
return {
|
|
232
|
+
"tools": [
|
|
233
|
+
{
|
|
234
|
+
"type": "function",
|
|
235
|
+
"function": {
|
|
236
|
+
"name": "task_response",
|
|
237
|
+
"parameters": output_schema,
|
|
238
|
+
"strict": True,
|
|
239
|
+
},
|
|
240
|
+
}
|
|
241
|
+
],
|
|
242
|
+
"tool_choice": {
|
|
243
|
+
"type": "function",
|
|
244
|
+
"function": {"name": "task_response"},
|
|
245
|
+
},
|
|
246
|
+
}
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
from unittest.mock import MagicMock, patch
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
|
|
6
|
+
from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
|
|
7
|
+
from kiln_ai.datamodel import Task
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MockAdapter(BaseAdapter):
    """Minimal concrete BaseAdapter used to exercise base-class logic in tests.

    ``_run`` performs no model call and returns ``None``; ``adapter_info``
    reports fixed "test" names plus the instance's model and provider names.
    """

    def adapter_info(self) -> AdapterInfo:
        info = AdapterInfo(
            adapter_name="test",
            prompt_builder_name="test",
            model_name=self.model_name,
            model_provider=self.model_provider_name,
        )
        return info

    async def _run(self, input):
        # Nothing to run: the tests only need an instantiable subclass.
        return None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def mock_provider():
    """A KilnModelProvider named "openai"; every other field is left at its default."""
    provider = KilnModelProvider(name="openai")
    return provider
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.fixture
def base_task():
    """A minimal Task with only a name and an instruction."""
    task = Task(instruction="test_instruction", name="test_task")
    return task
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@pytest.fixture
def adapter(base_task):
    """A MockAdapter bound to ``base_task`` with placeholder model/provider names."""
    mock_adapter = MockAdapter(
        model_provider_name="test_provider",
        model_name="test_model",
        kiln_task=base_task,
    )
    return mock_adapter
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
async def test_model_provider_uses_cache(adapter, mock_provider):
    """Test that cached provider is returned if it exists"""
    loader_path = "kiln_ai.adapters.model_adapters.base_adapter.kiln_model_provider_from"

    # Pre-populate the cache so model_provider() has nothing to load.
    adapter._model_provider = mock_provider

    with patch(loader_path) as mock_loader:
        result = adapter.model_provider()

    assert result == mock_provider
    # The loader must never have been consulted.
    mock_loader.assert_not_called()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
async def test_model_provider_loads_and_caches(adapter, mock_provider):
    """Test that provider is loaded and cached if not present"""
    adapter._model_provider = None  # start from an empty cache

    loader_path = "kiln_ai.adapters.model_adapters.base_adapter.kiln_model_provider_from"
    with patch(loader_path, return_value=mock_provider) as mock_loader:
        # Cold cache: the loader runs exactly once with the adapter's names.
        first = adapter.model_provider()
        assert first == mock_provider
        mock_loader.assert_called_once_with("test_model", "test_provider")

        # Warm cache: after clearing the call history, no further loads happen.
        mock_loader.reset_mock()
        second = adapter.model_provider()
        assert second == mock_provider
        mock_loader.assert_not_called()
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def test_model_provider_missing_names(base_task):
    """Test error when model or provider name is missing.

    ``model_provider()`` is synchronous (the adapters call it without ``await``),
    so it is invoked directly. Awaiting its result would raise TypeError instead
    of exposing the real failure if the call ever stopped raising.
    """
    expected_message = "model_name and model_provider_name must be provided"

    # (model_name, model_provider_name): first the model is missing, then the provider.
    for model_name, model_provider_name in [
        ("", "test_provider"),
        ("test_model", ""),
    ]:
        adapter = MockAdapter(
            kiln_task=base_task,
            model_name=model_name,
            model_provider_name=model_provider_name,
        )
        with pytest.raises(ValueError, match=expected_message):
            adapter.model_provider()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def test_model_provider_not_found(adapter):
    """Test error when provider loader returns None.

    ``model_provider()`` is synchronous, so it is called without ``await``.
    """
    # Mock the provider loader to return None
    with patch(
        "kiln_ai.adapters.model_adapters.base_adapter.kiln_model_provider_from"
    ) as mock_loader:
        mock_loader.return_value = None

        with pytest.raises(
            ValueError,
            match="model_provider_name test_provider not found for model test_model",
        ):
            adapter.model_provider()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@pytest.mark.parametrize(
    "has_structured_output,structured_output_mode,expected_json_instructions",
    [
        (False, StructuredOutputMode.json_instructions, False),
        (True, StructuredOutputMode.json_instructions, True),
        (False, StructuredOutputMode.json_instruction_and_object, False),
        (True, StructuredOutputMode.json_instruction_and_object, True),
        (True, StructuredOutputMode.json_mode, False),
        (False, StructuredOutputMode.json_mode, False),
    ],
)
def test_prompt_builder_json_instructions(
    adapter,
    has_structured_output,
    structured_output_mode,
    expected_json_instructions,
):
    """Test that prompt builder is called with correct include_json_instructions value.

    Per the table above, JSON instructions are expected only when the task has
    structured output and the provider's mode is ``json_instructions`` or
    ``json_instruction_and_object``. ``build_prompt`` is synchronous, so this
    test needs no event loop.
    """
    # Mock the prompt builder and has_structured_output method
    mock_prompt_builder = MagicMock()
    adapter.prompt_builder = mock_prompt_builder
    adapter.model_provider_name = "openai"
    adapter.has_structured_output = MagicMock(return_value=has_structured_output)

    # provider mock
    provider = MagicMock()
    provider.structured_output_mode = structured_output_mode
    adapter.model_provider = MagicMock(return_value=provider)

    # Test
    adapter.build_prompt()
    mock_prompt_builder.build_prompt.assert_called_with(
        include_json_instructions=expected_json_instructions
    )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@pytest.mark.parametrize(
    "cot_prompt,has_structured_output,reasoning_capable,expected",
    [
        # CoT prompt with a non-reasoning model, unstructured output
        ("think carefully", False, False, ("cot_two_call", "think carefully")),
        # CoT prompt, structured output, reasoning-capable model
        ("think carefully", True, True, ("cot_as_message", "think carefully")),
        # CoT prompt, structured output, non-reasoning model
        ("think carefully", True, False, ("cot_two_call", "think carefully")),
        # No CoT prompt: always the basic strategy
        (None, True, True, ("basic", None)),
        (None, False, False, ("basic", None)),
        (None, True, False, ("basic", None)),
        (None, False, True, ("basic", None)),
        # CoT prompt, unstructured output, reasoning-capable model
        ("think carefully", False, True, ("cot_as_message", "think carefully")),
    ],
)
async def test_run_strategy(
    adapter, cot_prompt, has_structured_output, reasoning_capable, expected
):
    """Test that run_strategy returns correct strategy based on conditions"""
    # Stub every input run_strategy() consults.
    fake_provider = MagicMock(reasoning_capable=reasoning_capable)
    adapter.model_provider = MagicMock(return_value=fake_provider)
    adapter.has_structured_output = MagicMock(return_value=has_structured_output)
    adapter.prompt_builder.chain_of_thought_prompt = MagicMock(return_value=cot_prompt)

    assert adapter.run_strategy() == expected
|