kiln-ai 0.8.0__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +199 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.0.dist-info/RECORD +0 -58
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
--- a/kiln_ai/adapters/test_structured_output.py
+++ b/kiln_ai/adapters/model_adapters/test_structured_output.py
@@ -1,3 +1,4 @@
+import json
 from pathlib import Path
 from typing import Dict
 
@@ -7,10 +8,14 @@ import pytest
 
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
 from kiln_ai.adapters.ml_model_list import (
     built_in_models,
 )
+from kiln_ai.adapters.model_adapters.base_adapter import (
+    AdapterInfo,
+    BaseAdapter,
+    RunOutput,
+)
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
@@ -44,7 +49,7 @@ async def test_structured_output_ollama_llama(tmp_path, model_name):
 
 class MockAdapter(BaseAdapter):
     def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
-        super().__init__(kiln_task)
+        super().__init__(kiln_task, model_name="phi_3_5", model_provider_name="ollama")
         self.response = response
 
     async def _run(self, input: str) -> RunOutput:
@@ -93,19 +98,10 @@ async def test_mock_unstructred_response(tmp_path):
     answer = await adapter.invoke("You are a mock, send me the response!")
 
 
-
-@pytest.mark.ollama
-@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
-async def test_all_built_in_models_structured_output(
-    tmp_path, model_name, provider_name
-):
+def check_supports_structured_output(model_name: str, provider_name: str):
     for model in built_in_models:
         if model.name != model_name:
             continue
-        if not model.supports_structured_output:
-            pytest.skip(
-                f"Skipping {model.name} because it does not support structured output"
-            )
         for provider in model.providers:
             if provider.name != provider_name:
                 continue
@@ -113,11 +109,20 @@ async def test_all_built_in_models_structured_output(
                 pytest.skip(
                     f"Skipping {model.name} {provider.name} because it does not support structured output"
                 )
-            await run_structured_output_test(tmp_path, model.name, provider.name)
             return
     raise RuntimeError(f"No model {model_name} {provider_name} found")
 
 
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_built_in_models_structured_output(
+    tmp_path, model_name, provider_name
+):
+    check_supports_structured_output(model_name, provider_name)
+    await run_structured_output_test(tmp_path, model_name, provider_name)
+
+
 def build_structured_output_test_task(tmp_path: Path):
     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
     project.save_to_file()
@@ -140,7 +145,14 @@ def build_structured_output_test_task(tmp_path: Path):
 async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_output_test_task(tmp_path)
     a = adapter_for_task(task, model_name=model_name, provider=provider)
-
+    try:
+        parsed = await a.invoke_returning_raw("Cows") # a joke about cows
+    except ValueError as e:
+        if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
+            pytest.skip(
+                f"Skipping {model_name} {provider} because Ollama is not running"
+            )
+        raise e
     if parsed is None or not isinstance(parsed, Dict):
         raise RuntimeError(f"structured response is not a dict: {parsed}")
     assert parsed["setup"] is not None
@@ -161,6 +173,7 @@ def build_structured_input_test_task(tmp_path: Path):
         parent=project,
         name="test task",
         instruction="You are an assistant which classifies a triangle given the lengths of its sides. If all sides are of equal length, the triangle is equilateral. If two sides are equal, the triangle is isosceles. Otherwise, it is scalene.\n\nAt the end of your response return the result in double square brackets. It should be plain text. It should be exactly one of the three following strings: '[[equilateral]]', or '[[isosceles]]', or '[[scalene]]'.",
+        thinking_prompt="Think step by step.",
     )
     task.input_json_schema = json_triangle_schema
     schema = task.input_schema()
@@ -177,7 +190,14 @@ def build_structured_input_test_task(tmp_path: Path):
 
 async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_input_test_task(tmp_path)
-
+    try:
+        await run_structured_input_task(task, model_name, provider)
+    except ValueError as e:
+        if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
+            pytest.skip(
+                f"Skipping {model_name} {provider} because Ollama is not running"
+            )
+        raise e
 
 
 async def run_structured_input_task(
@@ -196,10 +216,19 @@ async def run_structured_input_task(
         # invalid structured input
         await a.invoke({"a": 1, "b": 2, "d": 3})
 
-
+    try:
+        response = await a.invoke_returning_raw({"a": 2, "b": 2, "c": 2})
+    except ValueError as e:
+        if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
+            pytest.skip(
+                f"Skipping {model_name} {provider} because Ollama is not running"
+            )
+        raise e
     assert response is not None
-
-
+    if isinstance(response, str):
+        assert "[[equilateral]]" in response
+    else:
+        assert response["is_equilateral"] is True
     adapter_info = a.adapter_info()
     expected_pb_name = "simple_prompt_builder"
     if pb is not None:
@@ -207,7 +236,6 @@ async def run_structured_input_task(
     assert adapter_info.prompt_builder_name == expected_pb_name
     assert adapter_info.model_name == model_name
     assert adapter_info.model_provider == provider
-    assert adapter_info.adapter_name == "kiln_langchain_adapter"
 
 
 @pytest.mark.paid
@@ -227,7 +255,52 @@ async def test_all_built_in_models_structured_input(
 @pytest.mark.paid
 @pytest.mark.ollama
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
-async def
+async def test_structured_input_cot_prompt_builder(tmp_path, model_name, provider_name):
     task = build_structured_input_test_task(tmp_path)
     pb = SimpleChainOfThoughtPromptBuilder(task)
     await run_structured_input_task(task, model_name, provider_name, pb)
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_structured_output_cot_prompt_builder(
+    tmp_path, model_name, provider_name
+):
+    check_supports_structured_output(model_name, provider_name)
+    triangle_schema = {
+        "type": "object",
+        "properties": {
+            "is_equilateral": {
+                "type": "boolean",
+                "description": "True if all sides of the triangle are equal in length",
+            },
+            "is_scalene": {
+                "type": "boolean",
+                "description": "True if all sides of the triangle have different lengths",
+            },
+            "is_obtuse": {
+                "type": "boolean",
+                "description": "True if one of the angles is greater than 90 degrees",
+            },
+        },
+        "required": ["is_equilateral", "is_scalene", "is_obtuse"],
+        "additionalProperties": False,
+    }
+    task = build_structured_input_test_task(tmp_path)
+    task.instruction = """
+You are an assistant which classifies a triangle given the lengths of its sides. If all sides are of equal length, the triangle is equilateral. If two sides are equal, the triangle is isosceles. Otherwise, it is scalene.\n\n"
+
+When asked for a final result, this is the format (for an equilateral example):
+```json
+{
+    "is_equilateral": true,
+    "is_scalene": false,
+    "is_obtuse": false
+}
+```
+"""
+    task.output_json_schema = json.dumps(triangle_schema)
+    task.save_to_file()
+    pb = SimpleChainOfThoughtPromptBuilder(task)
+    await run_structured_input_task(task, model_name, provider_name, pb)
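Context for the constructor change above: `BaseAdapter` subclasses now pass the model name and provider name through to the base class, as `MockAdapter` does. A minimal sketch of a custom adapter under the new signature (the hypothetical `EchoAdapter` and its echo behavior are illustrative, not part of the package):

```python
import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput


class EchoAdapter(BaseAdapter):
    """Hypothetical adapter that echoes its input back as the output."""

    def __init__(self, kiln_task: datamodel.Task):
        # model_name and model_provider_name are now required by BaseAdapter.
        super().__init__(kiln_task, model_name="phi_3_5", model_provider_name="ollama")

    async def _run(self, input: str) -> RunOutput:
        return RunOutput(output=input, intermediate_outputs=None)
```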
--- /dev/null
+++ b/kiln_ai/adapters/parsers/base_parser.py
@@ -0,0 +1,12 @@
+from kiln_ai.adapters.run_output import RunOutput
+
+
+class BaseParser:
+    def __init__(self, structured_output: bool = False):
+        self.structured_output = structured_output
+
+    def parse_output(self, original_output: RunOutput) -> RunOutput:
+        """
+        Method for parsing the output of a model. Typically overridden by subclasses.
+        """
+        return original_output
--- /dev/null
+++ b/kiln_ai/adapters/parsers/json_parser.py
@@ -0,0 +1,37 @@
+import json
+from typing import Any, Dict
+
+
+def parse_json_string(json_string: str) -> Dict[str, Any]:
+    """
+    Parse a JSON string into a dictionary. Handles multiple formats:
+    - Plain JSON
+    - JSON wrapped in ```json code blocks
+    - JSON wrapped in ``` code blocks
+
+    Args:
+        json_string: String containing JSON data, possibly wrapped in code blocks
+
+    Returns:
+        Dict containing parsed JSON data
+
+    Raises:
+        ValueError: If JSON parsing fails
+    """
+    # Remove code block markers if present
+    cleaned_string = json_string.strip()
+    if cleaned_string.startswith("```"):
+        # Split by newlines and remove first/last lines if they contain ```
+        lines = cleaned_string.split("\n")
+        if lines[0].startswith("```"):
+            lines = lines[1:]
+        if lines and lines[-1].strip() == "```":
+            lines = lines[:-1]
+        cleaned_string = "\n".join(lines)
+
+    try:
+        return json.loads(cleaned_string)
+    except json.JSONDecodeError as e:
+        raise ValueError(
+            f"This task requires JSON output but the model didn't return valid JSON. Search 'Troubleshooting Structured Data Issues' in our docs for more information. The model produced the following: {cleaned_string}"
+        ) from e
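A short usage sketch of the new helper, based on the behaviors its docstring and tests describe: plain JSON and code-fenced JSON parse to the same dict.

```python
from kiln_ai.adapters.parsers.json_parser import parse_json_string

plain = parse_json_string('{"setup": "Why?", "punchline": "Moo."}')
fenced = parse_json_string('```json\n{"setup": "Why?", "punchline": "Moo."}\n```')
assert plain == fenced  # code-block markers are stripped before json.loads
```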
--- /dev/null
+++ b/kiln_ai/adapters/parsers/parser_registry.py
@@ -0,0 +1,19 @@
+from typing import Type
+
+from kiln_ai.adapters.ml_model_list import ModelParserID
+from kiln_ai.adapters.parsers.base_parser import BaseParser
+from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+def model_parser_from_id(parser_id: ModelParserID | None) -> Type[BaseParser]:
+    """
+    Get a model parser from its ID.
+    """
+    match parser_id:
+        case None:
+            return BaseParser
+        case ModelParserID.r1_thinking:
+            return R1ThinkingParser
+        case _:
+            raise_exhaustive_enum_error(parser_id)
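Note that the registry returns a parser class, not an instance, so callers construct it with the task's output mode. A minimal sketch of the dispatch, using the `ModelParserID` enum from `ml_model_list.py`:

```python
from kiln_ai.adapters.ml_model_list import ModelParserID
from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id

# None maps to the pass-through BaseParser; R1-style models get R1ThinkingParser.
parser_class = model_parser_from_id(ModelParserID.r1_thinking)
parser = parser_class(structured_output=True)
```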
--- /dev/null
+++ b/kiln_ai/adapters/parsers/r1_parser.py
@@ -0,0 +1,69 @@
+from kiln_ai.adapters.parsers.base_parser import BaseParser
+from kiln_ai.adapters.parsers.json_parser import parse_json_string
+from kiln_ai.adapters.run_output import RunOutput
+
+
+class R1ThinkingParser(BaseParser):
+    START_TAG = "<think>"
+    END_TAG = "</think>"
+
+    def parse_output(self, original_output: RunOutput) -> RunOutput:
+        """
+        Parse the <think> </think> tags from the response into the intermediate and final outputs.
+
+        Args:
+            original_output: RunOutput containing the raw response string
+
+        Returns:
+            ParsedOutput containing the intermediate content (thinking content) and final result
+
+        Raises:
+            ValueError: If response format is invalid (missing tags, multiple tags, or no content after closing tag)
+        """
+        # This parser only works for strings
+        if not isinstance(original_output.output, str):
+            raise ValueError("Response must be a string for R1 parser")
+
+        # Strip whitespace and validate basic structure
+        cleaned_response = original_output.output.strip()
+        if not cleaned_response.startswith(self.START_TAG):
+            raise ValueError("Response must start with <think> tag")
+
+        # Find the thinking tags
+        think_start = cleaned_response.find(self.START_TAG)
+        think_end = cleaned_response.find(self.END_TAG)
+
+        if think_start == -1 or think_end == -1:
+            raise ValueError("Missing thinking tags")
+
+        # Check for multiple tags
+        if (
+            cleaned_response.count(self.START_TAG) > 1
+            or cleaned_response.count(self.END_TAG) > 1
+        ):
+            raise ValueError("Multiple thinking tags found")
+
+        # Extract thinking content
+        thinking_content = cleaned_response[
+            think_start + len(self.START_TAG) : think_end
+        ].strip()
+
+        # Extract result (everything after </think>)
+        result = cleaned_response[think_end + len(self.END_TAG) :].strip()
+
+        if not result or len(result) == 0:
+            raise ValueError("No content found after </think> tag")
+
+        # Parse JSON if needed
+        output = result
+        if self.structured_output:
+            output = parse_json_string(result)
+
+        # Add thinking content to intermediate outputs if it exists
+        intermediate_outputs = original_output.intermediate_outputs or {}
+        intermediate_outputs["reasoning"] = thinking_content
+
+        return RunOutput(
+            output=output,
+            intermediate_outputs=intermediate_outputs,
+        )
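A usage sketch combining the parser with structured output, mirroring the tests below: the `<think>` block lands in `intermediate_outputs["reasoning"]`, and the remainder is JSON-parsed into the final output.

```python
from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
from kiln_ai.adapters.run_output import RunOutput

parser = R1ThinkingParser(structured_output=True)
raw = RunOutput(
    output='<think>All sides equal 2.</think>{"is_equilateral": true}',
    intermediate_outputs=None,
)
parsed = parser.parse_output(raw)
assert parsed.intermediate_outputs["reasoning"] == "All sides equal 2."
assert parsed.output == {"is_equilateral": True}
```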
--- /dev/null
+++ b/kiln_ai/adapters/parsers/test_json_parser.py
@@ -0,0 +1,81 @@
+import pytest
+
+from kiln_ai.adapters.parsers.json_parser import parse_json_string
+
+
+def test_parse_plain_json():
+    json_str = '{"key": "value", "number": 42}'
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_code_block():
+    json_str = """```
+{"key": "value", "number": 42}
+```"""
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_language_block():
+    json_str = """```json
+{"key": "value", "number": 42}
+```"""
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_whitespace():
+    json_str = """
+    {
+        "key": "value",
+        "number": 42
+    }
+    """
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_invalid_json():
+    json_str = '{"key": "value", invalid}'
+    with pytest.raises(ValueError) as exc_info:
+        parse_json_string(json_str)
+    assert (
+        "This task requires JSON output but the model didn't return valid JSON."
+        in str(exc_info.value)
+    )
+
+
+def test_parse_empty_code_block():
+    json_str = """```json
+```"""
+    with pytest.raises(ValueError) as exc_info:
+        parse_json_string(json_str)
+    assert (
+        "This task requires JSON output but the model didn't return valid JSON."
+        in str(exc_info.value)
+    )
+
+
+def test_parse_complex_json():
+    json_str = """```json
+{
+    "string": "hello",
+    "number": 42,
+    "bool": true,
+    "null": null,
+    "array": [1, 2, 3],
+    "nested": {
+        "inner": "value"
+    }
+}
+```"""
+    result = parse_json_string(json_str)
+    assert result == {
+        "string": "hello",
+        "number": 42,
+        "bool": True,
+        "null": None,
+        "array": [1, 2, 3],
+        "nested": {"inner": "value"},
+    }
--- /dev/null
+++ b/kiln_ai/adapters/parsers/test_parser_registry.py
@@ -0,0 +1,32 @@
+import pytest
+
+from kiln_ai.adapters.ml_model_list import ModelParserID
+from kiln_ai.adapters.parsers.base_parser import BaseParser
+from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
+from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
+
+
+def test_model_parser_from_id_invalid():
+    """Test that invalid parser ID raises ValueError."""
+
+    # Create a mock enum value that isn't handled
+    class MockModelParserID:
+        mock_value = "mock_value"
+
+    with pytest.raises(ValueError) as exc_info:
+        model_parser_from_id(MockModelParserID.mock_value)  # type: ignore
+
+    assert "Unhandled enum value" in str(exc_info.value)
+
+
+@pytest.mark.parametrize(
+    "parser_id,expected_class",
+    [
+        (None, BaseParser),
+        (ModelParserID.r1_thinking, R1ThinkingParser),
+    ],
+)
+def test_model_parser_from_id_parametrized(parser_id, expected_class):
+    """Test all valid parser IDs using parametrize."""
+    parser_class = model_parser_from_id(parser_id)
+    assert parser_class == expected_class
--- /dev/null
+++ b/kiln_ai/adapters/parsers/test_r1_parser.py
@@ -0,0 +1,144 @@
+import pytest
+
+from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
+from kiln_ai.adapters.run_output import RunOutput
+
+
+@pytest.fixture
+def parser():
+    return R1ThinkingParser()
+
+
+def test_valid_response(parser):
+    response = RunOutput(
+        output="<think>This is thinking content</think>This is the result",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "This is thinking content"
+    assert parsed.output == "This is the result"
+
+
+def test_response_with_whitespace(parser):
+    response = RunOutput(
+        output="""
+        <think>
+        This is thinking content
+        </think>
+        This is the result
+        """,
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert (
+        parsed.intermediate_outputs["reasoning"].strip() == "This is thinking content"
+    )
+    assert parsed.output.strip() == "This is the result"
+
+
+def test_missing_start_tag(parser):
+    with pytest.raises(ValueError, match="Response must start with <think> tag"):
+        parser.parse_output(
+            RunOutput(output="Some content</think>result", intermediate_outputs=None)
+        )
+
+
+def test_missing_end_tag(parser):
+    with pytest.raises(ValueError, match="Missing thinking tags"):
+        parser.parse_output(
+            RunOutput(output="<think>Some content", intermediate_outputs=None)
+        )
+
+
+def test_multiple_start_tags(parser):
+    with pytest.raises(ValueError, match="Multiple thinking tags found"):
+        parser.parse_output(
+            RunOutput(
+                output="<think>content1<think>content2</think>result",
+                intermediate_outputs=None,
+            )
+        )
+
+
+def test_multiple_end_tags(parser):
+    with pytest.raises(ValueError, match="Multiple thinking tags found"):
+        parser.parse_output(
+            RunOutput(
+                output="<think>content</think></think>result", intermediate_outputs=None
+            )
+        )
+
+
+def test_empty_thinking_content(parser):
+    response = RunOutput(
+        output="<think></think>This is the result", intermediate_outputs=None
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs == {"reasoning": ""}
+    assert parsed.output == "This is the result"
+
+
+def test_missing_result(parser):
+    with pytest.raises(ValueError, match="No content found after </think> tag"):
+        parser.parse_output(
+            RunOutput(output="<think>Some content</think>", intermediate_outputs=None)
+        )
+
+
+def test_multiline_content(parser):
+    response = RunOutput(
+        output="""<think>Line 1
+Line 2
+Line 3</think>Final result""",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert "Line 1" in parsed.intermediate_outputs["reasoning"]
+    assert "Line 2" in parsed.intermediate_outputs["reasoning"]
+    assert "Line 3" in parsed.intermediate_outputs["reasoning"]
+    assert parsed.output == "Final result"
+
+
+def test_special_characters(parser):
+    response = RunOutput(
+        output="<think>Content with: !@#$%^&*思()</think>Result with: !@#$%^&*思()",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "Content with: !@#$%^&*思()"
+    assert parsed.output == "Result with: !@#$%^&*思()"
+
+
+def test_non_string_input(parser):
+    with pytest.raises(ValueError, match="Response must be a string for R1 parser"):
+        parser.parse_output(RunOutput(output={}, intermediate_outputs=None))
+
+
+def test_intermediate_outputs(parser):
+    # append to existing intermediate outputs
+    out = parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs={"existing": "data"},
+        )
+    )
+    assert out.intermediate_outputs["reasoning"] == "Some content"
+    assert out.intermediate_outputs["existing"] == "data"
+
+    # empty dict is allowed
+    out = parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs={},
+        )
+    )
+    assert out.intermediate_outputs["reasoning"] == "Some content"
+
+    # None is allowed
+    out = parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs=None,
+        )
+    )
+    assert out.intermediate_outputs["reasoning"] == "Some content"