kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic.

Files changed (88)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +81 -10
  3. kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +267 -0
  7. kiln_ai/adapters/eval/g_eval.py +367 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +324 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +640 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +497 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  15. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  16. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  21. kiln_ai/adapters/ml_model_list.py +434 -93
  22. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  23. kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
  24. kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
  25. kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
  26. kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
  27. kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
  28. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
  29. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
  30. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
  31. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
  32. kiln_ai/adapters/ollama_tools.py +0 -1
  33. kiln_ai/adapters/parsers/__init__.py +10 -0
  34. kiln_ai/adapters/parsers/base_parser.py +12 -0
  35. kiln_ai/adapters/parsers/json_parser.py +37 -0
  36. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  37. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  38. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  39. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  40. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  41. kiln_ai/adapters/prompt_builders.py +193 -49
  42. kiln_ai/adapters/provider_tools.py +91 -36
  43. kiln_ai/adapters/repair/repair_task.py +18 -19
  44. kiln_ai/adapters/repair/test_repair_task.py +7 -7
  45. kiln_ai/adapters/run_output.py +11 -0
  46. kiln_ai/adapters/test_adapter_registry.py +177 -0
  47. kiln_ai/adapters/test_generate_docs.py +69 -0
  48. kiln_ai/adapters/test_ollama_tools.py +0 -1
  49. kiln_ai/adapters/test_prompt_adaptors.py +25 -18
  50. kiln_ai/adapters/test_prompt_builders.py +265 -44
  51. kiln_ai/adapters/test_provider_tools.py +268 -46
  52. kiln_ai/datamodel/__init__.py +51 -772
  53. kiln_ai/datamodel/basemodel.py +31 -11
  54. kiln_ai/datamodel/datamodel_enums.py +58 -0
  55. kiln_ai/datamodel/dataset_filters.py +114 -0
  56. kiln_ai/datamodel/dataset_split.py +170 -0
  57. kiln_ai/datamodel/eval.py +298 -0
  58. kiln_ai/datamodel/finetune.py +105 -0
  59. kiln_ai/datamodel/json_schema.py +14 -3
  60. kiln_ai/datamodel/model_cache.py +8 -3
  61. kiln_ai/datamodel/project.py +23 -0
  62. kiln_ai/datamodel/prompt.py +37 -0
  63. kiln_ai/datamodel/prompt_id.py +83 -0
  64. kiln_ai/datamodel/strict_mode.py +24 -0
  65. kiln_ai/datamodel/task.py +181 -0
  66. kiln_ai/datamodel/task_output.py +321 -0
  67. kiln_ai/datamodel/task_run.py +164 -0
  68. kiln_ai/datamodel/test_basemodel.py +80 -2
  69. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  70. kiln_ai/datamodel/test_dataset_split.py +127 -6
  71. kiln_ai/datamodel/test_datasource.py +3 -2
  72. kiln_ai/datamodel/test_eval_model.py +635 -0
  73. kiln_ai/datamodel/test_example_models.py +34 -17
  74. kiln_ai/datamodel/test_json_schema.py +23 -0
  75. kiln_ai/datamodel/test_model_cache.py +24 -0
  76. kiln_ai/datamodel/test_model_perf.py +125 -0
  77. kiln_ai/datamodel/test_models.py +131 -2
  78. kiln_ai/datamodel/test_prompt_id.py +129 -0
  79. kiln_ai/datamodel/test_task.py +159 -0
  80. kiln_ai/utils/config.py +6 -1
  81. kiln_ai/utils/exhaustive_error.py +6 -0
  82. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
  83. kiln_ai-0.12.0.dist-info/RECORD +100 -0
  84. kiln_ai/adapters/base_adapter.py +0 -191
  85. kiln_ai/adapters/langchain_adapters.py +0 -256
  86. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  87. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
  88. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/parsers/test_json_parser.py (new file)
@@ -0,0 +1,81 @@
+ import pytest
+
+ from kiln_ai.adapters.parsers.json_parser import parse_json_string
+
+
+ def test_parse_plain_json():
+     json_str = '{"key": "value", "number": 42}'
+     result = parse_json_string(json_str)
+     assert result == {"key": "value", "number": 42}
+
+
+ def test_parse_json_with_code_block():
+     json_str = """```
+ {"key": "value", "number": 42}
+ ```"""
+     result = parse_json_string(json_str)
+     assert result == {"key": "value", "number": 42}
+
+
+ def test_parse_json_with_language_block():
+     json_str = """```json
+ {"key": "value", "number": 42}
+ ```"""
+     result = parse_json_string(json_str)
+     assert result == {"key": "value", "number": 42}
+
+
+ def test_parse_json_with_whitespace():
+     json_str = """
+ {
+ "key": "value",
+ "number": 42
+ }
+ """
+     result = parse_json_string(json_str)
+     assert result == {"key": "value", "number": 42}
+
+
+ def test_parse_invalid_json():
+     json_str = '{"key": "value", invalid}'
+     with pytest.raises(ValueError) as exc_info:
+         parse_json_string(json_str)
+     assert (
+         "This task requires JSON output but the model didn't return valid JSON."
+         in str(exc_info.value)
+     )
+
+
+ def test_parse_empty_code_block():
+     json_str = """```json
+ ```"""
+     with pytest.raises(ValueError) as exc_info:
+         parse_json_string(json_str)
+     assert (
+         "This task requires JSON output but the model didn't return valid JSON."
+         in str(exc_info.value)
+     )
+
+
+ def test_parse_complex_json():
+     json_str = """```json
+ {
+ "string": "hello",
+ "number": 42,
+ "bool": true,
+ "null": null,
+ "array": [1, 2, 3],
+ "nested": {
+ "inner": "value"
+ }
+ }
+ ```"""
+     result = parse_json_string(json_str)
+     assert result == {
+         "string": "hello",
+         "number": 42,
+         "bool": True,
+         "null": None,
+         "array": [1, 2, 3],
+         "nested": {"inner": "value"},
+     }
kiln_ai/adapters/parsers/test_parser_registry.py (new file)
@@ -0,0 +1,32 @@
+ import pytest
+
+ from kiln_ai.adapters.ml_model_list import ModelParserID
+ from kiln_ai.adapters.parsers.base_parser import BaseParser
+ from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
+ from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
+
+
+ def test_model_parser_from_id_invalid():
+     """Test that invalid parser ID raises ValueError."""
+
+     # Create a mock enum value that isn't handled
+     class MockModelParserID:
+         mock_value = "mock_value"
+
+     with pytest.raises(ValueError) as exc_info:
+         model_parser_from_id(MockModelParserID.mock_value)  # type: ignore
+
+     assert "Unhandled enum value" in str(exc_info.value)
+
+
+ @pytest.mark.parametrize(
+     "parser_id,expected_class",
+     [
+         (None, BaseParser),
+         (ModelParserID.r1_thinking, R1ThinkingParser),
+     ],
+ )
+ def test_model_parser_from_id_parametrized(parser_id, expected_class):
+     """Test all valid parser IDs using parametrize."""
+     parser_class = model_parser_from_id(parser_id)
+     assert parser_class == expected_class
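
As exercised here, model_parser_from_id maps a ModelParserID (or None) to a parser class rather than an instance. A minimal sketch consistent with these tests, assuming the registry simply returns classes, could be:

from kiln_ai.adapters.ml_model_list import ModelParserID
from kiln_ai.adapters.parsers.base_parser import BaseParser
from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser


def model_parser_from_id(parser_id: ModelParserID | None) -> type[BaseParser]:
    # None means the model needs no special parsing; fall back to the base parser
    match parser_id:
        case None:
            return BaseParser
        case ModelParserID.r1_thinking:
            return R1ThinkingParser
        case _:
            # the packaged registry raises for unknown IDs; the test only checks
            # that the message contains "Unhandled enum value"
            raise ValueError(f"Unhandled enum value: {parser_id}")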
kiln_ai/adapters/parsers/test_r1_parser.py (new file)
@@ -0,0 +1,144 @@
+ import pytest
+
+ from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
+ from kiln_ai.adapters.run_output import RunOutput
+
+
+ @pytest.fixture
+ def parser():
+     return R1ThinkingParser()
+
+
+ def test_valid_response(parser):
+     response = RunOutput(
+         output="<think>This is thinking content</think>This is the result",
+         intermediate_outputs=None,
+     )
+     parsed = parser.parse_output(response)
+     assert parsed.intermediate_outputs["reasoning"] == "This is thinking content"
+     assert parsed.output == "This is the result"
+
+
+ def test_response_with_whitespace(parser):
+     response = RunOutput(
+         output="""
+ <think>
+ This is thinking content
+ </think>
+ This is the result
+ """,
+         intermediate_outputs=None,
+     )
+     parsed = parser.parse_output(response)
+     assert (
+         parsed.intermediate_outputs["reasoning"].strip() == "This is thinking content"
+     )
+     assert parsed.output.strip() == "This is the result"
+
+
+ def test_missing_start_tag(parser):
+     with pytest.raises(ValueError, match="Response must start with <think> tag"):
+         parser.parse_output(
+             RunOutput(output="Some content</think>result", intermediate_outputs=None)
+         )
+
+
+ def test_missing_end_tag(parser):
+     with pytest.raises(ValueError, match="Missing thinking tags"):
+         parser.parse_output(
+             RunOutput(output="<think>Some content", intermediate_outputs=None)
+         )
+
+
+ def test_multiple_start_tags(parser):
+     with pytest.raises(ValueError, match="Multiple thinking tags found"):
+         parser.parse_output(
+             RunOutput(
+                 output="<think>content1<think>content2</think>result",
+                 intermediate_outputs=None,
+             )
+         )
+
+
+ def test_multiple_end_tags(parser):
+     with pytest.raises(ValueError, match="Multiple thinking tags found"):
+         parser.parse_output(
+             RunOutput(
+                 output="<think>content</think></think>result", intermediate_outputs=None
+             )
+         )
+
+
+ def test_empty_thinking_content(parser):
+     response = RunOutput(
+         output="<think></think>This is the result", intermediate_outputs=None
+     )
+     parsed = parser.parse_output(response)
+     assert parsed.intermediate_outputs == {"reasoning": ""}
+     assert parsed.output == "This is the result"
+
+
+ def test_missing_result(parser):
+     with pytest.raises(ValueError, match="No content found after </think> tag"):
+         parser.parse_output(
+             RunOutput(output="<think>Some content</think>", intermediate_outputs=None)
+         )
+
+
+ def test_multiline_content(parser):
+     response = RunOutput(
+         output="""<think>Line 1
+ Line 2
+ Line 3</think>Final result""",
+         intermediate_outputs=None,
+     )
+     parsed = parser.parse_output(response)
+     assert "Line 1" in parsed.intermediate_outputs["reasoning"]
+     assert "Line 2" in parsed.intermediate_outputs["reasoning"]
+     assert "Line 3" in parsed.intermediate_outputs["reasoning"]
+     assert parsed.output == "Final result"
+
+
+ def test_special_characters(parser):
+     response = RunOutput(
+         output="<think>Content with: !@#$%^&*思()</think>Result with: !@#$%^&*思()",
+         intermediate_outputs=None,
+     )
+     parsed = parser.parse_output(response)
+     assert parsed.intermediate_outputs["reasoning"] == "Content with: !@#$%^&*思()"
+     assert parsed.output == "Result with: !@#$%^&*思()"
+
+
+ def test_non_string_input(parser):
+     with pytest.raises(ValueError, match="Response must be a string for R1 parser"):
+         parser.parse_output(RunOutput(output={}, intermediate_outputs=None))
+
+
+ def test_intermediate_outputs(parser):
+     # append to existing intermediate outputs
+     out = parser.parse_output(
+         RunOutput(
+             output="<think>Some content</think>result",
+             intermediate_outputs={"existing": "data"},
+         )
+     )
+     assert out.intermediate_outputs["reasoning"] == "Some content"
+     assert out.intermediate_outputs["existing"] == "data"
+
+     # empty dict is allowed
+     out = parser.parse_output(
+         RunOutput(
+             output="<think>Some content</think>result",
+             intermediate_outputs={},
+         )
+     )
+     assert out.intermediate_outputs["reasoning"] == "Some content"
+
+     # None is allowed
+     out = parser.parse_output(
+         RunOutput(
+             output="<think>Some content</think>result",
+             intermediate_outputs=None,
+         )
+     )
+     assert out.intermediate_outputs["reasoning"] == "Some content"
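
The R1 parser contract implied by these tests: the response must be a single <think>...</think> block followed by the final answer, the reasoning is stored under intermediate_outputs["reasoning"], and it is merged into any existing intermediate outputs. A sketch consistent with these tests follows; the shipped r1_parser.py is not part of this diff, and BaseParser subclassing details are omitted:

from kiln_ai.adapters.run_output import RunOutput


class R1ThinkingParser:
    START_TAG = "<think>"
    END_TAG = "</think>"

    def parse_output(self, original_output: RunOutput) -> RunOutput:
        raw = original_output.output
        if not isinstance(raw, str):
            raise ValueError("Response must be a string for R1 parser")

        text = raw.strip()
        if not text.startswith(self.START_TAG):
            raise ValueError("Response must start with <think> tag")
        if self.END_TAG not in text:
            raise ValueError("Missing thinking tags")
        if text.count(self.START_TAG) > 1 or text.count(self.END_TAG) > 1:
            raise ValueError("Multiple thinking tags found")

        # Everything between the tags is reasoning; everything after is the answer
        thinking, _, result = text.partition(self.END_TAG)
        thinking = thinking.removeprefix(self.START_TAG).strip()
        if not result.strip():
            raise ValueError("No content found after </think> tag")

        intermediate_outputs = dict(original_output.intermediate_outputs or {})
        intermediate_outputs["reasoning"] = thinking
        return RunOutput(output=result, intermediate_outputs=intermediate_outputs)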
kiln_ai/adapters/prompt_builders.py
@@ -2,8 +2,8 @@ import json
  from abc import ABCMeta, abstractmethod
  from typing import Dict
 
- from kiln_ai.datamodel import Task, TaskRun
- from kiln_ai.utils.formatting import snake_case
+ from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
  class BasePromptBuilder(metaclass=ABCMeta):
@@ -20,25 +20,38 @@ class BasePromptBuilder(metaclass=ABCMeta):
          """
          self.task = task
 
-     @abstractmethod
-     def build_prompt(self) -> str:
+     def prompt_id(self) -> str | None:
+         """Returns the ID of the prompt, scoped to this builder.
+
+         Returns:
+             str | None: The ID of the prompt, or None if not set.
+         """
+         return None
+
+     def build_prompt(self, include_json_instructions) -> str:
          """Build and return the complete prompt string.
 
          Returns:
              str: The constructed prompt.
          """
-         pass
+         prompt = self.build_base_prompt()
 
-     @classmethod
-     def prompt_builder_name(cls) -> str:
-         """Returns the name of the prompt builder, to be used for persisting into the datastore.
+         if include_json_instructions and self.task.output_schema():
+             prompt = (
+                 prompt
+                 + f"\n\n# Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.task.output_schema()}\n```"
+             )
+
+         return prompt
 
-         Default implementation gets the name of the prompt builder in snake case. If you change the class name, you should override this so prior saved data is compatible.
+     @abstractmethod
+     def build_base_prompt(self) -> str:
+         """Build and return the complete prompt string.
 
          Returns:
-             str: The prompt builder name in snake_case format.
+             str: The constructed prompt.
          """
-         return snake_case(cls.__name__)
+         pass
 
      def build_user_message(self, input: Dict | str) -> str:
          """Build a user message from the input.
@@ -50,7 +63,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
              str: The formatted user message.
          """
          if isinstance(input, Dict):
-             return f"The input is:\n{json.dumps(input, indent=2)}"
+             return f"The input is:\n{json.dumps(input, indent=2, ensure_ascii=False)}"
 
          return f"The input is:\n{input}"
 
@@ -70,7 +83,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
          Returns:
              str: The constructed prompt string.
          """
-         base_prompt = self.build_prompt()
+         base_prompt = self.build_prompt(include_json_instructions=False)
          cot_prompt = self.chain_of_thought_prompt()
          if cot_prompt:
              base_prompt += "\n# Thinking Instructions\n\n" + cot_prompt
@@ -80,7 +93,7 @@
  class SimplePromptBuilder(BasePromptBuilder):
      """A basic prompt builder that combines task instruction with requirements."""
 
-     def build_prompt(self) -> str:
+     def build_base_prompt(self) -> str:
          """Build a simple prompt with instruction and requirements.
 
          Returns:
@@ -95,7 +108,7 @@ class SimplePromptBuilder(BasePromptBuilder):
          )
          # iterate requirements, formatting them in numbereed list like 1) task.instruction\n2)...
          for i, requirement in enumerate(self.task.requirements):
-             base_prompt += f"{i+1}) {requirement.instruction}\n"
+             base_prompt += f"{i + 1}) {requirement.instruction}\n"
 
          return base_prompt
 
@@ -112,18 +125,18 @@ class MultiShotPromptBuilder(BasePromptBuilder):
          """
          return 25
 
-     def build_prompt(self) -> str:
+     def build_base_prompt(self) -> str:
          """Build a prompt with instruction, requirements, and multiple examples.
 
          Returns:
              str: The constructed prompt string with examples.
          """
-         base_prompt = f"# Instruction\n\n{ self.task.instruction }\n\n"
+         base_prompt = f"# Instruction\n\n{self.task.instruction}\n\n"
 
          if len(self.task.requirements) > 0:
              base_prompt += "# Requirements\n\nYour response should respect the following requirements:\n"
              for i, requirement in enumerate(self.task.requirements):
-                 base_prompt += f"{i+1}) {requirement.instruction}\n"
+                 base_prompt += f"{i + 1}) {requirement.instruction}\n"
              base_prompt += "\n"
 
          valid_examples = self.collect_examples()
@@ -140,11 +153,11 @@ class MultiShotPromptBuilder(BasePromptBuilder):
      def prompt_section_for_example(self, index: int, example: TaskRun) -> str:
          # Prefer repaired output if it exists, otherwise use the regular output
          output = example.repaired_output or example.output
-         return f"## Example {index+1}\n\nInput: {example.input}\nOutput: {output.output}\n\n"
+         return f"## Example {index + 1}\n\nInput: {example.input}\nOutput: {output.output}\n\n"
 
      def collect_examples(self) -> list[TaskRun]:
          valid_examples: list[TaskRun] = []
-         runs = self.task.runs()
+         runs = self.task.runs(readonly=True)
 
          # first pass, we look for repaired outputs. These are the best examples.
          for run in runs:
@@ -198,7 +211,7 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
          ):
              return super().prompt_section_for_example(index, example)
 
-         prompt_section = f"## Example {index+1}\n\nInput: {example.input}\n\n"
+         prompt_section = f"## Example {index + 1}\n\nInput: {example.input}\n\n"
          prompt_section += (
              f"Initial Output Which Was Insufficient: {example.output.output}\n\n"
          )
@@ -209,7 +222,7 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
          return prompt_section
 
 
- def chain_of_thought_prompt(task: Task) -> str | None:
+ def chain_of_thought_prompt(task: Task) -> str:
      """Standard implementation to build and return the chain of thought prompt string.
 
      Returns:
@@ -244,23 +257,132 @@ class MultiShotChainOfThoughtPromptBuilder(MultiShotPromptBuilder):
          return chain_of_thought_prompt(self.task)
 
 
- prompt_builder_registry = {
-     "simple_prompt_builder": SimplePromptBuilder,
-     "multi_shot_prompt_builder": MultiShotPromptBuilder,
-     "few_shot_prompt_builder": FewShotPromptBuilder,
-     "repairs_prompt_builder": RepairsPromptBuilder,
-     "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
-     "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
-     "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
- }
+ class SavedPromptBuilder(BasePromptBuilder):
+     """A prompt builder that looks up a static prompt."""
+
+     def __init__(self, task: Task, prompt_id: str):
+         super().__init__(task)
+         prompt_model = next(
+             (
+                 prompt
+                 for prompt in task.prompts(readonly=True)
+                 if prompt.id == prompt_id
+             ),
+             None,
+         )
+         if not prompt_model:
+             raise ValueError(f"Prompt ID not found: {prompt_id}")
+         self.prompt_model = prompt_model
+
+     def prompt_id(self) -> str | None:
+         return self.prompt_model.id
+
+     def build_base_prompt(self) -> str:
+         """Returns a saved prompt.
+
+         Returns:
+             str: The prompt string.
+         """
+         return self.prompt_model.prompt
+
+     def chain_of_thought_prompt(self) -> str | None:
+         return self.prompt_model.chain_of_thought_instructions
+
+
+ class TaskRunConfigPromptBuilder(BasePromptBuilder):
+     """A prompt builder that looks up a static prompt in a task run config."""
+
+     def __init__(self, task: Task, run_config_prompt_id: str):
+         parts = run_config_prompt_id.split("::")
+         if len(parts) != 4:
+             raise ValueError(
+                 f"Invalid task run config prompt ID: {run_config_prompt_id}. Expected format: 'task_run_config::[project_id]::[task_id]::[run_config_id]'."
+             )
+
+         task_id = parts[2]
+         if task_id != task.id:
+             raise ValueError(
+                 f"Task run config prompt ID: {run_config_prompt_id}. Task ID mismatch. Expected: {task.id}, got: {task_id}."
+             )
+
+         run_config_id = parts[3]
+         run_config = next(
+             (
+                 run_config
+                 for run_config in task.run_configs(readonly=True)
+                 if run_config.id == run_config_id
+             ),
+             None,
+         )
+         if not run_config:
+             raise ValueError(
+                 f"Task run config ID not found: {run_config_id} for prompt id {run_config_prompt_id}"
+             )
+         if run_config.prompt is None:
+             raise ValueError(
+                 f"Task run config ID {run_config_id} does not have a stored prompt. Used as prompt id {run_config_prompt_id}"
+             )
+
+         # Load the prompt from the model
+         self.prompt = run_config.prompt.prompt
+         self.cot_prompt = run_config.prompt.chain_of_thought_instructions
+         self.id = run_config_prompt_id
+
+         super().__init__(task)
+
+     def prompt_id(self) -> str | None:
+         return self.id
+
+     def build_base_prompt(self) -> str:
+         return self.prompt
+
+     def chain_of_thought_prompt(self) -> str | None:
+         return self.cot_prompt
+
+
+ class FineTunePromptBuilder(BasePromptBuilder):
+     """A prompt builder that looks up a fine-tune prompt."""
+
+     def __init__(self, task: Task, nested_fine_tune_id: str):
+         super().__init__(task)
+
+         # IDs are in project_id::task_id::fine_tune_id format
+         self.full_fine_tune_id = nested_fine_tune_id
+         parts = nested_fine_tune_id.split("::")
+         if len(parts) != 3:
+             raise ValueError(
+                 f"Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id', got: {nested_fine_tune_id}"
+             )
+         fine_tune_id = parts[2]
+
+         fine_tune_model = next(
+             (
+                 fine_tune
+                 for fine_tune in task.finetunes(readonly=True)
+                 if fine_tune.id == fine_tune_id
+             ),
+             None,
+         )
+         if not fine_tune_model:
+             raise ValueError(f"Fine-tune ID not found: {fine_tune_id}")
+         self.fine_tune_model = fine_tune_model
+
+     def prompt_id(self) -> str | None:
+         return self.full_fine_tune_id
+
+     def build_base_prompt(self) -> str:
+         return self.fine_tune_model.system_message
+
+     def chain_of_thought_prompt(self) -> str | None:
+         return self.fine_tune_model.thinking_instructions
 
 
  # Our UI has some names that are not the same as the class names, which also hint parameters.
- def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
+ def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder:
      """Convert a name used in the UI to the corresponding prompt builder class.
 
      Args:
-         ui_name (str): The UI name for the prompt builder type.
+         prompt_id (PromptId): The prompt ID.
 
      Returns:
          type[BasePromptBuilder]: The corresponding prompt builder class.
@@ -268,20 +390,42 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
      Raises:
          ValueError: If the UI name is not recognized.
      """
-     match ui_name:
-         case "basic":
-             return SimplePromptBuilder
-         case "few_shot":
-             return FewShotPromptBuilder
-         case "many_shot":
-             return MultiShotPromptBuilder
-         case "repairs":
-             return RepairsPromptBuilder
-         case "simple_chain_of_thought":
-             return SimpleChainOfThoughtPromptBuilder
-         case "few_shot_chain_of_thought":
-             return FewShotChainOfThoughtPromptBuilder
-         case "multi_shot_chain_of_thought":
-             return MultiShotChainOfThoughtPromptBuilder
+
+     # Saved prompts are prefixed with "id::"
+     if prompt_id.startswith("id::"):
+         prompt_id = prompt_id[4:]
+         return SavedPromptBuilder(task, prompt_id)
+
+     # Task run config prompts are prefixed with "task_run_config::"
+     # task_run_config::[project_id]::[task_id]::[run_config_id]
+     if prompt_id.startswith("task_run_config::"):
+         return TaskRunConfigPromptBuilder(task, prompt_id)
+
+     # Fine-tune prompts are prefixed with "fine_tune_prompt::"
+     if prompt_id.startswith("fine_tune_prompt::"):
+         prompt_id = prompt_id[18:]
+         return FineTunePromptBuilder(task, prompt_id)
+
+     # Check if the prompt_id matches any enum value
+     if prompt_id not in [member.value for member in PromptGenerators]:
+         raise ValueError(f"Unknown prompt generator: {prompt_id}")
+     typed_prompt_generator = PromptGenerators(prompt_id)
+
+     match typed_prompt_generator:
+         case PromptGenerators.SIMPLE:
+             return SimplePromptBuilder(task)
+         case PromptGenerators.FEW_SHOT:
+             return FewShotPromptBuilder(task)
+         case PromptGenerators.MULTI_SHOT:
+             return MultiShotPromptBuilder(task)
+         case PromptGenerators.REPAIRS:
+             return RepairsPromptBuilder(task)
+         case PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT:
+             return SimpleChainOfThoughtPromptBuilder(task)
+         case PromptGenerators.FEW_SHOT_CHAIN_OF_THOUGHT:
+             return FewShotChainOfThoughtPromptBuilder(task)
+         case PromptGenerators.MULTI_SHOT_CHAIN_OF_THOUGHT:
+             return MultiShotChainOfThoughtPromptBuilder(task)
          case _:
-             raise ValueError(f"Unknown prompt builder: {ui_name}")
+             # Type checking will find missing cases
+             raise_exhaustive_enum_error(typed_prompt_generator)
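
Taken together, prompt_builder_from_id replaces both the old prompt_builder_registry dict and prompt_builder_from_ui_name, and build_prompt(include_json_instructions=...) replaces the old abstract build_prompt. A usage sketch of the new entry point, assuming a Task already loaded from a Kiln project (the PromptGenerators enum values are defined outside this diff):

from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
from kiln_ai.datamodel import PromptGenerators, Task


def render_prompt(task: Task, prompt_id: str) -> str:
    # Accepted ID forms, per the parsing logic above:
    #   a PromptGenerators value, e.g. PromptGenerators.SIMPLE.value
    #   "id::<prompt_id>" for a prompt saved on the task
    #   "task_run_config::<project_id>::<task_id>::<run_config_id>"
    #   "fine_tune_prompt::<project_id>::<task_id>::<fine_tune_id>"
    builder = prompt_builder_from_id(prompt_id, task)
    return builder.build_prompt(include_json_instructions=False)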