kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +81 -10
- kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +434 -93
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +193 -49
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +18 -19
- kiln_ai/adapters/repair/test_repair_task.py +7 -7
- kiln_ai/adapters/run_output.py +11 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +25 -18
- kiln_ai/adapters/test_prompt_builders.py +265 -44
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +51 -772
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +14 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +80 -2
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +127 -6
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +34 -17
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +131 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai/adapters/base_adapter.py +0 -191
- kiln_ai/adapters/langchain_adapters.py +0 -256
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
--- /dev/null
+++ b/kiln_ai/adapters/parsers/test_json_parser.py
@@ -0,0 +1,81 @@
+import pytest
+
+from kiln_ai.adapters.parsers.json_parser import parse_json_string
+
+
+def test_parse_plain_json():
+    json_str = '{"key": "value", "number": 42}'
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_code_block():
+    json_str = """```
+{"key": "value", "number": 42}
+```"""
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_language_block():
+    json_str = """```json
+{"key": "value", "number": 42}
+```"""
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_whitespace():
+    json_str = """
+    {
+        "key": "value",
+        "number": 42
+    }
+    """
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_invalid_json():
+    json_str = '{"key": "value", invalid}'
+    with pytest.raises(ValueError) as exc_info:
+        parse_json_string(json_str)
+    assert (
+        "This task requires JSON output but the model didn't return valid JSON."
+        in str(exc_info.value)
+    )
+
+
+def test_parse_empty_code_block():
+    json_str = """```json
+```"""
+    with pytest.raises(ValueError) as exc_info:
+        parse_json_string(json_str)
+    assert (
+        "This task requires JSON output but the model didn't return valid JSON."
+        in str(exc_info.value)
+    )
+
+
+def test_parse_complex_json():
+    json_str = """```json
+    {
+        "string": "hello",
+        "number": 42,
+        "bool": true,
+        "null": null,
+        "array": [1, 2, 3],
+        "nested": {
+            "inner": "value"
+        }
+    }
+    ```"""
+    result = parse_json_string(json_str)
+    assert result == {
+        "string": "hello",
+        "number": 42,
+        "bool": True,
+        "null": None,
+        "array": [1, 2, 3],
+        "nested": {"inner": "value"},
+    }
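The parser module itself (kiln_ai/adapters/parsers/json_parser.py, +37 lines) is not included in this excerpt. A minimal sketch of a parse_json_string implementation consistent with the behavior the tests above require, offered as an illustration rather than the actual shipped code:

import json
from typing import Any, Dict


def parse_json_string(json_string: str) -> Dict[str, Any]:
    # Strip optional markdown fences (``` or ```json) around the JSON payload.
    cleaned = json_string.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.removeprefix("```json").removeprefix("```")
        cleaned = cleaned.removesuffix("```").strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError as e:
        # Matches the error message the tests assert on, including the empty-code-block case.
        raise ValueError(
            "This task requires JSON output but the model didn't return valid JSON."
        ) from e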
--- /dev/null
+++ b/kiln_ai/adapters/parsers/test_parser_registry.py
@@ -0,0 +1,32 @@
+import pytest
+
+from kiln_ai.adapters.ml_model_list import ModelParserID
+from kiln_ai.adapters.parsers.base_parser import BaseParser
+from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
+from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
+
+
+def test_model_parser_from_id_invalid():
+    """Test that invalid parser ID raises ValueError."""
+
+    # Create a mock enum value that isn't handled
+    class MockModelParserID:
+        mock_value = "mock_value"
+
+    with pytest.raises(ValueError) as exc_info:
+        model_parser_from_id(MockModelParserID.mock_value)  # type: ignore
+
+    assert "Unhandled enum value" in str(exc_info.value)
+
+
+@pytest.mark.parametrize(
+    "parser_id,expected_class",
+    [
+        (None, BaseParser),
+        (ModelParserID.r1_thinking, R1ThinkingParser),
+    ],
+)
+def test_model_parser_from_id_parametrized(parser_id, expected_class):
+    """Test all valid parser IDs using parametrize."""
+    parser_class = model_parser_from_id(parser_id)
+    assert parser_class == expected_class
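A short usage sketch based only on the lookup behavior exercised above; the no-argument instantiation of the returned class is an assumption, mirroring how the R1 tests below construct the parser:

from kiln_ai.adapters.ml_model_list import ModelParserID
from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id

# None maps to the pass-through BaseParser; r1_thinking maps to R1ThinkingParser.
parser_class = model_parser_from_id(ModelParserID.r1_thinking)
parser = parser_class()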
--- /dev/null
+++ b/kiln_ai/adapters/parsers/test_r1_parser.py
@@ -0,0 +1,144 @@
+import pytest
+
+from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
+from kiln_ai.adapters.run_output import RunOutput
+
+
+@pytest.fixture
+def parser():
+    return R1ThinkingParser()
+
+
+def test_valid_response(parser):
+    response = RunOutput(
+        output="<think>This is thinking content</think>This is the result",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "This is thinking content"
+    assert parsed.output == "This is the result"
+
+
+def test_response_with_whitespace(parser):
+    response = RunOutput(
+        output="""
+        <think>
+        This is thinking content
+        </think>
+        This is the result
+        """,
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert (
+        parsed.intermediate_outputs["reasoning"].strip() == "This is thinking content"
+    )
+    assert parsed.output.strip() == "This is the result"
+
+
+def test_missing_start_tag(parser):
+    with pytest.raises(ValueError, match="Response must start with <think> tag"):
+        parser.parse_output(
+            RunOutput(output="Some content</think>result", intermediate_outputs=None)
+        )
+
+
+def test_missing_end_tag(parser):
+    with pytest.raises(ValueError, match="Missing thinking tags"):
+        parser.parse_output(
+            RunOutput(output="<think>Some content", intermediate_outputs=None)
+        )
+
+
+def test_multiple_start_tags(parser):
+    with pytest.raises(ValueError, match="Multiple thinking tags found"):
+        parser.parse_output(
+            RunOutput(
+                output="<think>content1<think>content2</think>result",
+                intermediate_outputs=None,
+            )
+        )
+
+
+def test_multiple_end_tags(parser):
+    with pytest.raises(ValueError, match="Multiple thinking tags found"):
+        parser.parse_output(
+            RunOutput(
+                output="<think>content</think></think>result", intermediate_outputs=None
+            )
+        )
+
+
+def test_empty_thinking_content(parser):
+    response = RunOutput(
+        output="<think></think>This is the result", intermediate_outputs=None
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs == {"reasoning": ""}
+    assert parsed.output == "This is the result"
+
+
+def test_missing_result(parser):
+    with pytest.raises(ValueError, match="No content found after </think> tag"):
+        parser.parse_output(
+            RunOutput(output="<think>Some content</think>", intermediate_outputs=None)
+        )
+
+
+def test_multiline_content(parser):
+    response = RunOutput(
+        output="""<think>Line 1
+Line 2
+Line 3</think>Final result""",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert "Line 1" in parsed.intermediate_outputs["reasoning"]
+    assert "Line 2" in parsed.intermediate_outputs["reasoning"]
+    assert "Line 3" in parsed.intermediate_outputs["reasoning"]
+    assert parsed.output == "Final result"
+
+
+def test_special_characters(parser):
+    response = RunOutput(
+        output="<think>Content with: !@#$%^&*思()</think>Result with: !@#$%^&*思()",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "Content with: !@#$%^&*思()"
+    assert parsed.output == "Result with: !@#$%^&*思()"
+
+
+def test_non_string_input(parser):
+    with pytest.raises(ValueError, match="Response must be a string for R1 parser"):
+        parser.parse_output(RunOutput(output={}, intermediate_outputs=None))
+
+
+def test_intermediate_outputs(parser):
+    # append to existing intermediate outputs
+    out = parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs={"existing": "data"},
+        )
+    )
+    assert out.intermediate_outputs["reasoning"] == "Some content"
+    assert out.intermediate_outputs["existing"] == "data"
+
+    # empty dict is allowed
+    out = parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs={},
+        )
+    )
+    assert out.intermediate_outputs["reasoning"] == "Some content"
+
+    # None is allowed
+    out = parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs=None,
+        )
+    )
+    assert out.intermediate_outputs["reasoning"] == "Some content"
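A minimal end-to-end sketch of the R1 parser, using only the API the tests above exercise; the output text is illustrative:

from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
from kiln_ai.adapters.run_output import RunOutput

raw = RunOutput(
    output="<think>Check the constraints first.</think>The answer is 42.",
    intermediate_outputs=None,
)
parsed = R1ThinkingParser().parse_output(raw)
# The <think> block is moved into intermediate_outputs["reasoning"]; the remainder becomes the output.
assert parsed.intermediate_outputs["reasoning"] == "Check the constraints first."
assert parsed.output == "The answer is 42."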
--- a/kiln_ai/adapters/prompt_builders.py
+++ b/kiln_ai/adapters/prompt_builders.py
@@ -2,8 +2,8 @@ import json
 from abc import ABCMeta, abstractmethod
 from typing import Dict
 
-from kiln_ai.datamodel import Task, TaskRun
-from kiln_ai.utils.
+from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
 class BasePromptBuilder(metaclass=ABCMeta):
@@ -20,25 +20,38 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         self.task = task
 
-
-
+    def prompt_id(self) -> str | None:
+        """Returns the ID of the prompt, scoped to this builder.
+
+        Returns:
+            str | None: The ID of the prompt, or None if not set.
+        """
+        return None
+
+    def build_prompt(self, include_json_instructions) -> str:
         """Build and return the complete prompt string.
 
         Returns:
             str: The constructed prompt.
         """
-
+        prompt = self.build_base_prompt()
 
-
-
-
+        if include_json_instructions and self.task.output_schema():
+            prompt = (
+                prompt
+                + f"\n\n# Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.task.output_schema()}\n```"
+            )
+
+        return prompt
 
-
+    @abstractmethod
+    def build_base_prompt(self) -> str:
+        """Build and return the complete prompt string.
 
         Returns:
-            str: The prompt
+            str: The constructed prompt.
         """
-
+        pass
 
     def build_user_message(self, input: Dict | str) -> str:
         """Build a user message from the input.
@@ -50,7 +63,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
             str: The formatted user message.
         """
         if isinstance(input, Dict):
-            return f"The input is:\n{json.dumps(input, indent=2)}"
+            return f"The input is:\n{json.dumps(input, indent=2, ensure_ascii=False)}"
 
         return f"The input is:\n{input}"
 
@@ -70,7 +83,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
         Returns:
             str: The constructed prompt string.
         """
-        base_prompt = self.build_prompt()
+        base_prompt = self.build_prompt(include_json_instructions=False)
         cot_prompt = self.chain_of_thought_prompt()
         if cot_prompt:
             base_prompt += "\n# Thinking Instructions\n\n" + cot_prompt
@@ -80,7 +93,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
 class SimplePromptBuilder(BasePromptBuilder):
     """A basic prompt builder that combines task instruction with requirements."""
 
-    def
+    def build_base_prompt(self) -> str:
         """Build a simple prompt with instruction and requirements.
 
         Returns:
@@ -95,7 +108,7 @@ class SimplePromptBuilder(BasePromptBuilder):
         )
         # iterate requirements, formatting them in numbereed list like 1) task.instruction\n2)...
         for i, requirement in enumerate(self.task.requirements):
-            base_prompt += f"{i+1}) {requirement.instruction}\n"
+            base_prompt += f"{i + 1}) {requirement.instruction}\n"
 
         return base_prompt
 
@@ -112,18 +125,18 @@ class MultiShotPromptBuilder(BasePromptBuilder):
         """
         return 25
 
-    def
+    def build_base_prompt(self) -> str:
         """Build a prompt with instruction, requirements, and multiple examples.
 
         Returns:
            str: The constructed prompt string with examples.
         """
-        base_prompt = f"# Instruction\n\n{
+        base_prompt = f"# Instruction\n\n{self.task.instruction}\n\n"
 
        if len(self.task.requirements) > 0:
            base_prompt += "# Requirements\n\nYour response should respect the following requirements:\n"
            for i, requirement in enumerate(self.task.requirements):
-                base_prompt += f"{i+1}) {requirement.instruction}\n"
+                base_prompt += f"{i + 1}) {requirement.instruction}\n"
            base_prompt += "\n"
 
        valid_examples = self.collect_examples()
@@ -140,11 +153,11 @@ class MultiShotPromptBuilder(BasePromptBuilder):
     def prompt_section_for_example(self, index: int, example: TaskRun) -> str:
         # Prefer repaired output if it exists, otherwise use the regular output
         output = example.repaired_output or example.output
-        return f"## Example {index+1}\n\nInput: {example.input}\nOutput: {output.output}\n\n"
+        return f"## Example {index + 1}\n\nInput: {example.input}\nOutput: {output.output}\n\n"
 
     def collect_examples(self) -> list[TaskRun]:
         valid_examples: list[TaskRun] = []
-        runs = self.task.runs()
+        runs = self.task.runs(readonly=True)
 
         # first pass, we look for repaired outputs. These are the best examples.
         for run in runs:
@@ -198,7 +211,7 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
         ):
             return super().prompt_section_for_example(index, example)
 
-        prompt_section = f"## Example {index+1}\n\nInput: {example.input}\n\n"
+        prompt_section = f"## Example {index + 1}\n\nInput: {example.input}\n\n"
         prompt_section += (
             f"Initial Output Which Was Insufficient: {example.output.output}\n\n"
         )
@@ -209,7 +222,7 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
         return prompt_section
 
 
-def chain_of_thought_prompt(task: Task) -> str
+def chain_of_thought_prompt(task: Task) -> str:
     """Standard implementation to build and return the chain of thought prompt string.
 
     Returns:
@@ -244,23 +257,132 @@ class MultiShotChainOfThoughtPromptBuilder(MultiShotPromptBuilder):
         return chain_of_thought_prompt(self.task)
 
 
-
-"
-
-
-
-
-
-
-
+class SavedPromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a static prompt."""
+
+    def __init__(self, task: Task, prompt_id: str):
+        super().__init__(task)
+        prompt_model = next(
+            (
+                prompt
+                for prompt in task.prompts(readonly=True)
+                if prompt.id == prompt_id
+            ),
+            None,
+        )
+        if not prompt_model:
+            raise ValueError(f"Prompt ID not found: {prompt_id}")
+        self.prompt_model = prompt_model
+
+    def prompt_id(self) -> str | None:
+        return self.prompt_model.id
+
+    def build_base_prompt(self) -> str:
+        """Returns a saved prompt.
+
+        Returns:
+            str: The prompt string.
+        """
+        return self.prompt_model.prompt
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.prompt_model.chain_of_thought_instructions
+
+
+class TaskRunConfigPromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a static prompt in a task run config."""
+
+    def __init__(self, task: Task, run_config_prompt_id: str):
+        parts = run_config_prompt_id.split("::")
+        if len(parts) != 4:
+            raise ValueError(
+                f"Invalid task run config prompt ID: {run_config_prompt_id}. Expected format: 'task_run_config::[project_id]::[task_id]::[run_config_id]'."
+            )
+
+        task_id = parts[2]
+        if task_id != task.id:
+            raise ValueError(
+                f"Task run config prompt ID: {run_config_prompt_id}. Task ID mismatch. Expected: {task.id}, got: {task_id}."
+            )
+
+        run_config_id = parts[3]
+        run_config = next(
+            (
+                run_config
+                for run_config in task.run_configs(readonly=True)
+                if run_config.id == run_config_id
+            ),
+            None,
+        )
+        if not run_config:
+            raise ValueError(
+                f"Task run config ID not found: {run_config_id} for prompt id {run_config_prompt_id}"
+            )
+        if run_config.prompt is None:
+            raise ValueError(
+                f"Task run config ID {run_config_id} does not have a stored prompt. Used as prompt id {run_config_prompt_id}"
+            )
+
+        # Load the prompt from the model
+        self.prompt = run_config.prompt.prompt
+        self.cot_prompt = run_config.prompt.chain_of_thought_instructions
+        self.id = run_config_prompt_id
+
+        super().__init__(task)
+
+    def prompt_id(self) -> str | None:
+        return self.id
+
+    def build_base_prompt(self) -> str:
+        return self.prompt
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.cot_prompt
+
+
+class FineTunePromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a fine-tune prompt."""
+
+    def __init__(self, task: Task, nested_fine_tune_id: str):
+        super().__init__(task)
+
+        # IDs are in project_id::task_id::fine_tune_id format
+        self.full_fine_tune_id = nested_fine_tune_id
+        parts = nested_fine_tune_id.split("::")
+        if len(parts) != 3:
+            raise ValueError(
+                f"Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id', got: {nested_fine_tune_id}"
+            )
+        fine_tune_id = parts[2]
+
+        fine_tune_model = next(
+            (
+                fine_tune
+                for fine_tune in task.finetunes(readonly=True)
+                if fine_tune.id == fine_tune_id
+            ),
+            None,
+        )
+        if not fine_tune_model:
+            raise ValueError(f"Fine-tune ID not found: {fine_tune_id}")
+        self.fine_tune_model = fine_tune_model
+
+    def prompt_id(self) -> str | None:
+        return self.full_fine_tune_id
+
+    def build_base_prompt(self) -> str:
+        return self.fine_tune_model.system_message
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.fine_tune_model.thinking_instructions
 
 
 # Our UI has some names that are not the same as the class names, which also hint parameters.
-def
+def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder:
     """Convert a name used in the UI to the corresponding prompt builder class.
 
     Args:
-
+        prompt_id (PromptId): The prompt ID.
 
     Returns:
         type[BasePromptBuilder]: The corresponding prompt builder class.
@@ -268,20 +390,42 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
     Raises:
         ValueError: If the UI name is not recognized.
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    # Saved prompts are prefixed with "id::"
+    if prompt_id.startswith("id::"):
+        prompt_id = prompt_id[4:]
+        return SavedPromptBuilder(task, prompt_id)
+
+    # Task run config prompts are prefixed with "task_run_config::"
+    # task_run_config::[project_id]::[task_id]::[run_config_id]
+    if prompt_id.startswith("task_run_config::"):
+        return TaskRunConfigPromptBuilder(task, prompt_id)
+
+    # Fine-tune prompts are prefixed with "fine_tune_prompt::"
+    if prompt_id.startswith("fine_tune_prompt::"):
+        prompt_id = prompt_id[18:]
+        return FineTunePromptBuilder(task, prompt_id)
+
+    # Check if the prompt_id matches any enum value
+    if prompt_id not in [member.value for member in PromptGenerators]:
+        raise ValueError(f"Unknown prompt generator: {prompt_id}")
+    typed_prompt_generator = PromptGenerators(prompt_id)
+
+    match typed_prompt_generator:
+        case PromptGenerators.SIMPLE:
+            return SimplePromptBuilder(task)
+        case PromptGenerators.FEW_SHOT:
+            return FewShotPromptBuilder(task)
+        case PromptGenerators.MULTI_SHOT:
+            return MultiShotPromptBuilder(task)
+        case PromptGenerators.REPAIRS:
+            return RepairsPromptBuilder(task)
+        case PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT:
+            return SimpleChainOfThoughtPromptBuilder(task)
+        case PromptGenerators.FEW_SHOT_CHAIN_OF_THOUGHT:
+            return FewShotChainOfThoughtPromptBuilder(task)
+        case PromptGenerators.MULTI_SHOT_CHAIN_OF_THOUGHT:
+            return MultiShotChainOfThoughtPromptBuilder(task)
         case _:
-
+            # Type checking will find missing cases
+            raise_exhaustive_enum_error(typed_prompt_generator)
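A usage sketch of the new ID-based dispatch shown above. The task variable and the saved-prompt ID "abc123" are hypothetical; only prefixes that appear in the diff are used:

from kiln_ai.adapters.prompt_builders import prompt_builder_from_id

builder = prompt_builder_from_id("id::abc123", task)  # resolves to SavedPromptBuilder
# "task_run_config::..." and "fine_tune_prompt::..." IDs route to their respective builders;
# any other value must match a PromptGenerators enum value or a ValueError is raised.
prompt = builder.build_prompt(include_json_instructions=True)
user_message = builder.build_user_message({"city": "Paris"})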