kiln-ai 0.8.0__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai has been flagged as potentially problematic.

Files changed (57)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +77 -5
  3. kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  7. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  8. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  9. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  10. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
  11. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
  12. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  13. kiln_ai/adapters/ml_model_list.py +323 -94
  14. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  15. kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
  16. kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
  17. kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
  18. kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
  19. kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
  20. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
  21. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
  22. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
  23. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
  24. kiln_ai/adapters/parsers/__init__.py +10 -0
  25. kiln_ai/adapters/parsers/base_parser.py +12 -0
  26. kiln_ai/adapters/parsers/json_parser.py +37 -0
  27. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  28. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  29. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  30. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  31. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  32. kiln_ai/adapters/prompt_builders.py +126 -20
  33. kiln_ai/adapters/provider_tools.py +91 -36
  34. kiln_ai/adapters/repair/repair_task.py +17 -6
  35. kiln_ai/adapters/repair/test_repair_task.py +4 -4
  36. kiln_ai/adapters/run_output.py +8 -0
  37. kiln_ai/adapters/test_adapter_registry.py +177 -0
  38. kiln_ai/adapters/test_generate_docs.py +69 -0
  39. kiln_ai/adapters/test_prompt_adaptors.py +8 -4
  40. kiln_ai/adapters/test_prompt_builders.py +190 -29
  41. kiln_ai/adapters/test_provider_tools.py +268 -46
  42. kiln_ai/datamodel/__init__.py +199 -12
  43. kiln_ai/datamodel/basemodel.py +31 -11
  44. kiln_ai/datamodel/json_schema.py +8 -3
  45. kiln_ai/datamodel/model_cache.py +8 -3
  46. kiln_ai/datamodel/test_basemodel.py +81 -2
  47. kiln_ai/datamodel/test_dataset_split.py +100 -3
  48. kiln_ai/datamodel/test_example_models.py +25 -4
  49. kiln_ai/datamodel/test_model_cache.py +24 -0
  50. kiln_ai/datamodel/test_model_perf.py +125 -0
  51. kiln_ai/datamodel/test_models.py +129 -0
  52. kiln_ai/utils/exhaustive_error.py +6 -0
  53. {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
  54. kiln_ai-0.11.1.dist-info/RECORD +76 -0
  55. kiln_ai-0.8.0.dist-info/RECORD +0 -58
  56. {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
  57. {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
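
The headline change in this release is the package layout: adapter implementations move from the kiln_ai.adapters package root into kiln_ai.adapters.model_adapters, a new kiln_ai.adapters.parsers package appears (JSON and R1 output parsers plus a parser registry), and adapter construction is centralized in kiln_ai.adapters.adapter_registry. A minimal sketch of the import updates a downstream project would make, assuming only the renames listed above (the exact re-exports of the new packages are not visible in this diff):

# kiln-ai 0.8.0 import paths (renamed in 0.11.1):
#   from kiln_ai.adapters.base_adapter import BaseAdapter
#   from kiln_ai.adapters.langchain_adapters import LangchainAdapter

# kiln-ai 0.11.1 paths, following the file moves listed above:
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter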
kiln_ai/adapters/test_adapter_registry.py
@@ -0,0 +1,177 @@
+ from unittest.mock import patch
+
+ import pytest
+
+ from kiln_ai import datamodel
+ from kiln_ai.adapters.adapter_registry import adapter_for_task
+ from kiln_ai.adapters.ml_model_list import ModelProviderName
+ from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
+ from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
+ from kiln_ai.adapters.prompt_builders import BasePromptBuilder
+ from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+
+
+ @pytest.fixture
+ def mock_config():
+     with patch("kiln_ai.adapters.adapter_registry.Config") as mock:
+         mock.shared.return_value.open_ai_api_key = "test-openai-key"
+         mock.shared.return_value.open_router_api_key = "test-openrouter-key"
+         yield mock
+
+
+ @pytest.fixture
+ def basic_task():
+     return datamodel.Task(
+         task_id="test-task",
+         task_type="test",
+         input_text="test input",
+         name="test-task",
+         instruction="test-task",
+     )
+
+
+ @pytest.fixture
+ def mock_finetune_from_id():
+     with patch("kiln_ai.adapters.provider_tools.finetune_from_id") as mock:
+         mock.return_value.provider = ModelProviderName.openai
+         mock.return_value.fine_tune_model_id = "test-model"
+         yield mock
+
+
+ def test_openai_adapter_creation(mock_config, basic_task):
+     adapter = adapter_for_task(
+         kiln_task=basic_task, model_name="gpt-4", provider=ModelProviderName.openai
+     )
+
+     assert isinstance(adapter, OpenAICompatibleAdapter)
+     assert adapter.config.model_name == "gpt-4"
+     assert adapter.config.api_key == "test-openai-key"
+     assert adapter.config.provider_name == ModelProviderName.openai
+     assert adapter.config.base_url is None  # OpenAI url is default
+     assert adapter.config.default_headers is None
+
+
+ def test_openrouter_adapter_creation(mock_config, basic_task):
+     adapter = adapter_for_task(
+         kiln_task=basic_task,
+         model_name="anthropic/claude-3-opus",
+         provider=ModelProviderName.openrouter,
+     )
+
+     assert isinstance(adapter, OpenAICompatibleAdapter)
+     assert adapter.config.model_name == "anthropic/claude-3-opus"
+     assert adapter.config.api_key == "test-openrouter-key"
+     assert adapter.config.provider_name == ModelProviderName.openrouter
+     assert adapter.config.base_url == "https://openrouter.ai/api/v1"
+     assert adapter.config.default_headers == {
+         "HTTP-Referer": "https://getkiln.ai/openrouter",
+         "X-Title": "KilnAI",
+     }
+
+
+ @pytest.mark.parametrize(
+     "provider",
+     [
+         ModelProviderName.groq,
+         ModelProviderName.amazon_bedrock,
+         ModelProviderName.ollama,
+         ModelProviderName.fireworks_ai,
+     ],
+ )
+ def test_langchain_adapter_creation(mock_config, basic_task, provider):
+     adapter = adapter_for_task(
+         kiln_task=basic_task, model_name="test-model", provider=provider
+     )
+
+     assert isinstance(adapter, LangchainAdapter)
+     assert adapter.model_name == "test-model"
+
+
+ # TODO should run for all cases
+ def test_custom_prompt_builder(mock_config, basic_task):
+     class TestPromptBuilder(BasePromptBuilder):
+         def build_base_prompt(self, kiln_task) -> str:
+             return "test-prompt"
+
+     prompt_builder = TestPromptBuilder(basic_task)
+     adapter = adapter_for_task(
+         kiln_task=basic_task,
+         model_name="gpt-4",
+         provider=ModelProviderName.openai,
+         prompt_builder=prompt_builder,
+     )
+
+     assert adapter.prompt_builder == prompt_builder
+
+
+ # TODO should run for all cases
+ def test_tags_passed_through(mock_config, basic_task):
+     tags = ["test-tag-1", "test-tag-2"]
+     adapter = adapter_for_task(
+         kiln_task=basic_task,
+         model_name="gpt-4",
+         provider=ModelProviderName.openai,
+         tags=tags,
+     )
+
+     assert adapter.default_tags == tags
+
+
+ def test_invalid_provider(mock_config, basic_task):
+     with pytest.raises(ValueError, match="Unhandled enum value"):
+         adapter_for_task(
+             kiln_task=basic_task, model_name="test-model", provider="invalid"
+         )
+
+
+ @patch("kiln_ai.adapters.adapter_registry.openai_compatible_config")
+ def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_task):
+     mock_compatible_config.return_value.model_name = "test-model"
+     mock_compatible_config.return_value.api_key = "test-key"
+     mock_compatible_config.return_value.base_url = "https://test.com/v1"
+
+     adapter = adapter_for_task(
+         kiln_task=basic_task,
+         model_name="provider::test-model",
+         provider=ModelProviderName.openai_compatible,
+     )
+
+     assert isinstance(adapter, OpenAICompatibleAdapter)
+     mock_compatible_config.assert_called_once_with("provider::test-model")
+     assert adapter.config.model_name == "test-model"
+     assert adapter.config.api_key == "test-key"
+     assert adapter.config.base_url == "https://test.com/v1"
+
+
+ def test_custom_openai_compatible_provider(mock_config, basic_task):
+     adapter = adapter_for_task(
+         kiln_task=basic_task,
+         model_name="openai::test-model",
+         provider=ModelProviderName.kiln_custom_registry,
+     )
+
+     assert isinstance(adapter, OpenAICompatibleAdapter)
+     assert adapter.config.model_name == "openai::test-model"
+     assert adapter.config.api_key == "test-openai-key"
+     assert adapter.config.base_url is None  # openai is none
+     assert adapter.config.provider_name == ModelProviderName.kiln_custom_registry
+
+
+ async def test_fine_tune_provider(mock_config, basic_task, mock_finetune_from_id):
+     adapter = adapter_for_task(
+         kiln_task=basic_task,
+         model_name="proj::task::tune",
+         provider=ModelProviderName.kiln_fine_tune,
+     )
+
+     mock_finetune_from_id.assert_called_once_with("proj::task::tune")
+     assert isinstance(adapter, OpenAICompatibleAdapter)
+     assert adapter.config.provider_name == ModelProviderName.kiln_fine_tune
+     # Kiln model name here, but the underlying openai model id below
+     assert adapter.config.model_name == "proj::task::tune"
+
+     provider = kiln_model_provider_from(
+         "proj::task::tune", provider_name=ModelProviderName.kiln_fine_tune
+     )
+     # The actual model name from the fine tune object
+     assert provider.provider_options["model"] == "test-model"
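
The new test_adapter_registry.py above pins down the adapter_for_task factory: OpenAI and OpenRouter requests now return an OpenAICompatibleAdapter, while Groq, Amazon Bedrock, Ollama and Fireworks still return a LangchainAdapter. A minimal usage sketch inferred from those tests; run_once is an illustrative helper, the task is assumed to already exist, and API keys are assumed to be configured in the shared Config object that the fixture mocks:

from kiln_ai import datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.ml_model_list import ModelProviderName


async def run_once(task: datamodel.Task) -> str:
    adapter = adapter_for_task(
        kiln_task=task,
        model_name="gpt-4",
        provider=ModelProviderName.openai,
        tags=["experiment-1"],  # optional; surfaced as adapter.default_tags
    )
    run = await adapter.invoke("2 + 2")
    return run.output.output

# Scoped model-name formats exercised by the tests above:
#   ModelProviderName.openai_compatible    -> "provider::model"
#   ModelProviderName.kiln_custom_registry -> "openai::model"
#   ModelProviderName.kiln_fine_tune       -> "project_id::task_id::fine_tune_id"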
kiln_ai/adapters/test_generate_docs.py
@@ -0,0 +1,69 @@
+ from typing import List
+
+ import pytest
+
+ from libs.core.kiln_ai.adapters.ml_model_list import (
+     KilnModelProvider,
+     built_in_models,
+ )
+ from libs.core.kiln_ai.adapters.provider_tools import provider_name_from_id
+
+
+ def _all_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
+     """Check if all providers support a given feature"""
+     return all(getattr(provider, attribute) for provider in providers)
+
+
+ def _any_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
+     """Check if any providers support a given feature"""
+     return any(getattr(provider, attribute) for provider in providers)
+
+
+ def _get_support_status(providers: List[KilnModelProvider], attribute: str) -> str:
+     """Get the support status for a feature"""
+     if _all_providers_support(providers, attribute):
+         return "✅︎"
+     elif _any_providers_support(providers, attribute):
+         return "✅︎ (some providers)"
+     return ""
+
+
+ def _has_finetune_support(providers: List[KilnModelProvider]) -> str:
+     """Check if any provider supports fine-tuning"""
+     return "✅︎" if any(p.provider_finetune_id for p in providers) else ""
+
+
+ @pytest.mark.paid(reason="Marking as paid so it isn't run by default")
+ def test_generate_model_table():
+     """Generate a markdown table of all models and their capabilities"""
+
+     # Table header
+     table = [
+         "| Model Name | Providers | Structured Output | Reasoning | Synthetic Data | API Fine-Tuneable |",
+         "|------------|-----------|-------------------|-----------|----------------|-------------------|",
+     ]
+
+     for model in built_in_models:
+         provider_names = ", ".join(
+             sorted(provider_name_from_id(p.name.value) for p in model.providers)
+         )
+         structured_output = _get_support_status(
+             model.providers, "supports_structured_output"
+         )
+         reasoning = _get_support_status(model.providers, "reasoning_capable")
+         data_gen = _get_support_status(model.providers, "supports_data_gen")
+         finetune = _has_finetune_support(model.providers)
+
+         row = f"| {model.friendly_name} | {provider_names} | {structured_output} | {reasoning} | {data_gen} | {finetune} |"
+         table.append(row)
+
+     # Print the table (useful for documentation)
+     print("\nModel Capability Matrix:\n")
+     print("\n".join(table))
+
+     # Basic assertions to ensure the table is well-formed
+     assert len(table) > 2, "Table should have header and at least one row"
+     assert all("|" in row for row in table), "All rows should be properly formatted"
+     assert len(table[0].split("|")) == len(table[1].split("|")), (
+         "Header and separator should have same number of columns"
+     )
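
test_generate_docs.py is less a regression test than a documentation helper: it is marked with the project's paid marker so it is skipped by default, and when run it prints a markdown capability matrix for every built-in model. A small sketch of the same idea against the installed package; note the test itself imports through the repository's libs.core.* layout, so the kiln_ai.* import paths below are an assumption based on the wheel's module list:

from kiln_ai.adapters.ml_model_list import built_in_models
from kiln_ai.adapters.provider_tools import provider_name_from_id

for model in built_in_models:
    providers = ", ".join(
        sorted(provider_name_from_id(p.name.value) for p in model.providers)
    )
    # Same provider-level flags the capability matrix above inspects
    structured = all(p.supports_structured_output for p in model.providers)
    data_gen = any(p.supports_data_gen for p in model.providers)
    print(f"{model.friendly_name}: [{providers}] structured={structured} data_gen={data_gen}")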
kiln_ai/adapters/test_prompt_adaptors.py
@@ -6,8 +6,8 @@ from langchain_core.language_models.fake_chat_models import FakeListChatModel
  
  import kiln_ai.datamodel as datamodel
  from kiln_ai.adapters.adapter_registry import adapter_for_task
- from kiln_ai.adapters.langchain_adapters import LangchainAdapter
  from kiln_ai.adapters.ml_model_list import built_in_models
+ from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
  from kiln_ai.adapters.ollama_tools import ollama_online
  from kiln_ai.adapters.prompt_builders import (
      BasePromptBuilder,
@@ -108,7 +108,11 @@ async def test_amazon_bedrock(tmp_path):
  async def test_mock(tmp_path):
      task = build_test_task(tmp_path)
      mockChatModel = FakeListChatModel(responses=["mock response"])
-     adapter = LangchainAdapter(task, custom_model=mockChatModel)
+     adapter = LangchainAdapter(
+         task,
+         custom_model=mockChatModel,
+         provider="ollama",
+     )
      run = await adapter.invoke("You are a mock, send me the response!")
      assert "mock response" in run.output.output
  
@@ -116,7 +120,7 @@ async def test_mock(tmp_path):
  async def test_mock_returning_run(tmp_path):
      task = build_test_task(tmp_path)
      mockChatModel = FakeListChatModel(responses=["mock response"])
-     adapter = LangchainAdapter(task, custom_model=mockChatModel)
+     adapter = LangchainAdapter(task, custom_model=mockChatModel, provider="ollama")
      run = await adapter.invoke("You are a mock, send me the response!")
      assert run.output.output == "mock response"
      assert run is not None
@@ -127,7 +131,7 @@ async def test_mock_returning_run(tmp_path):
      assert run.output.source.properties == {
          "adapter_name": "kiln_langchain_adapter",
          "model_name": "custom.langchain:unknown_model",
-         "model_provider": "custom.langchain:FakeListChatModel",
+         "model_provider": "ollama",
          "prompt_builder_name": "simple_prompt_builder",
      }
  
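
These hunks are the visible side of the LangchainAdapter move: it now lives under kiln_ai.adapters.model_adapters, and wrapping a custom LangChain chat model requires an explicit provider name, which is what gets recorded as model_provider on the resulting run. A minimal sketch based on the updated tests; run_with_mock_model is an illustrative helper name and the task is assumed to be an existing Kiln task:

from langchain_core.language_models.fake_chat_models import FakeListChatModel

from kiln_ai import datamodel
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter


async def run_with_mock_model(task: datamodel.Task) -> str:
    mock_model = FakeListChatModel(responses=["mock response"])
    # provider is now required alongside custom_model; it ends up in the
    # run's source properties as "model_provider"
    adapter = LangchainAdapter(task, custom_model=mock_model, provider="ollama")
    run = await adapter.invoke("You are a mock, send me the response!")
    return run.output.output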
kiln_ai/adapters/test_prompt_builders.py
@@ -2,24 +2,31 @@ import json
  
  import pytest
  
- from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+ from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
+ from kiln_ai.adapters.model_adapters.test_structured_output import (
+     build_structured_output_test_task,
+ )
  from kiln_ai.adapters.prompt_builders import (
      FewShotChainOfThoughtPromptBuilder,
      FewShotPromptBuilder,
+     FineTunePromptBuilder,
      MultiShotChainOfThoughtPromptBuilder,
      MultiShotPromptBuilder,
      RepairsPromptBuilder,
+     SavedPromptBuilder,
      SimpleChainOfThoughtPromptBuilder,
      SimplePromptBuilder,
      chain_of_thought_prompt,
      prompt_builder_from_ui_name,
  )
  from kiln_ai.adapters.test_prompt_adaptors import build_test_task
- from kiln_ai.adapters.test_structured_output import build_structured_output_test_task
  from kiln_ai.datamodel import (
      DataSource,
      DataSourceType,
+     Finetune,
+     FinetuneDataStrategy,
      Project,
+     Prompt,
      Task,
      TaskOutput,
      TaskOutputRating,
@@ -31,7 +38,7 @@ def test_simple_prompt_builder(tmp_path):
      task = build_test_task(tmp_path)
      builder = SimplePromptBuilder(task=task)
      input = "two plus two"
-     prompt = builder.build_prompt()
+     prompt = builder.build_prompt(include_json_instructions=False)
      assert (
          "You are an assistant which performs math tasks provided in plain text."
          in prompt
@@ -62,7 +69,7 @@ def test_simple_prompt_builder_structured_output(tmp_path):
      task = build_structured_output_test_task(tmp_path)
      builder = SimplePromptBuilder(task=task)
      input = "Cows"
-     prompt = builder.build_prompt()
+     prompt = builder.build_prompt(include_json_instructions=False)
      assert "You are an assistant which tells a joke, given a subject." in prompt
  
      user_msg = builder.build_user_message(input)
@@ -70,6 +77,14 @@
      assert input not in prompt
  
  
+ def test_simple_prompt_builder_structured_input_non_ascii(tmp_path):
+     task = build_structured_output_test_task(tmp_path)
+     builder = SimplePromptBuilder(task=task)
+     input = {"key": "你好👋"}
+     user_msg = builder.build_user_message(input)
+     assert "你好👋" in user_msg
+
+
  @pytest.fixture
  def task_with_examples(tmp_path):
      # Create a project and task hierarchy
@@ -198,7 +213,7 @@ def task_with_examples(tmp_path):
  def test_multi_shot_prompt_builder(task_with_examples):
      # Verify the order of examples
      prompt_builder = MultiShotPromptBuilder(task=task_with_examples)
-     prompt = prompt_builder.build_prompt()
+     prompt = prompt_builder.build_prompt(include_json_instructions=False)
      assert "Why did the cow cross the road?" in prompt
      assert prompt.index("Why did the cow cross the road?") < prompt.index(
          "Why don't cats play poker in the jungle?"
@@ -239,14 +254,14 @@
      # Create 6 examples (2 repaired, 4 high-quality)
      for i in range(6):
          run = TaskRun(
-             input=f'{{"subject": "Subject {i+1}"}}',
+             input=f'{{"subject": "Subject {i + 1}"}}',
              input_source=DataSource(
                  type=DataSourceType.human,
                  properties={"created_by": "john_doe"},
              ),
              parent=task,
              output=TaskOutput(
-                 output=f'{{"joke": "Joke Initial Output {i+1}"}}',
+                 output=f'{{"joke": "Joke Initial Output {i + 1}"}}',
                  source=DataSource(
                      type=DataSourceType.human,
                      properties={"created_by": "john_doe"},
@@ -260,7 +275,7 @@
              update={
                  "repair_instructions": "Fix the joke",
                  "repaired_output": TaskOutput(
-                     output=f'{{"joke": "Repaired Joke {i+1}"}}',
+                     output=f'{{"joke": "Repaired Joke {i + 1}"}}',
                      source=DataSource(
                          type=DataSourceType.human,
                          properties={"created_by": "jane_doe"},
@@ -272,7 +287,7 @@
  
      # Check that only 4 examples are included
      prompt_builder = FewShotPromptBuilder(task=task)
-     prompt = prompt_builder.build_prompt()
+     prompt = prompt_builder.build_prompt(include_json_instructions=False)
      assert prompt.count("## Example") == 4
  
      print("PROMPT", prompt)
@@ -289,7 +304,7 @@
  
  def check_example_outputs(task: Task, count: int):
      prompt_builder = MultiShotPromptBuilder(task=task)
-     prompt = prompt_builder.build_prompt()
+     prompt = prompt_builder.build_prompt(include_json_instructions=False)
      assert "# Instruction" in prompt
      assert task.instruction in prompt
      if count == 0:
@@ -305,26 +320,84 @@ def test_prompt_builder_name():
      assert RepairsPromptBuilder.prompt_builder_name() == "repairs_prompt_builder"
  
  
- def test_prompt_builder_from_ui_name():
-     assert prompt_builder_from_ui_name("basic") == SimplePromptBuilder
-     assert prompt_builder_from_ui_name("few_shot") == FewShotPromptBuilder
-     assert prompt_builder_from_ui_name("many_shot") == MultiShotPromptBuilder
-     assert prompt_builder_from_ui_name("repairs") == RepairsPromptBuilder
-     assert (
-         prompt_builder_from_ui_name("simple_chain_of_thought")
-         == SimpleChainOfThoughtPromptBuilder
+ def test_prompt_builder_from_ui_name(task_with_examples):
+     task = task_with_examples
+     assert isinstance(prompt_builder_from_ui_name("basic", task), SimplePromptBuilder)
+     assert isinstance(
+         prompt_builder_from_ui_name("few_shot", task), FewShotPromptBuilder
      )
-     assert (
-         prompt_builder_from_ui_name("few_shot_chain_of_thought")
-         == FewShotChainOfThoughtPromptBuilder
+     assert isinstance(
+         prompt_builder_from_ui_name("many_shot", task), MultiShotPromptBuilder
      )
-     assert (
-         prompt_builder_from_ui_name("multi_shot_chain_of_thought")
-         == MultiShotChainOfThoughtPromptBuilder
+     assert isinstance(
+         prompt_builder_from_ui_name("repairs", task), RepairsPromptBuilder
+     )
+     assert isinstance(
+         prompt_builder_from_ui_name("simple_chain_of_thought", task),
+         SimpleChainOfThoughtPromptBuilder,
+     )
+     assert isinstance(
+         prompt_builder_from_ui_name("few_shot_chain_of_thought", task),
+         FewShotChainOfThoughtPromptBuilder,
+     )
+     assert isinstance(
+         prompt_builder_from_ui_name("multi_shot_chain_of_thought", task),
+         MultiShotChainOfThoughtPromptBuilder,
      )
  
      with pytest.raises(ValueError, match="Unknown prompt builder: invalid_name"):
-         prompt_builder_from_ui_name("invalid_name")
+         prompt_builder_from_ui_name("invalid_name", task)
+
+     with pytest.raises(ValueError, match="Prompt ID not found: 123"):
+         prompt_builder_from_ui_name("id::123", task)
+
+     with pytest.raises(
+         ValueError,
+         match="Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id'",
+     ):
+         prompt_builder_from_ui_name("fine_tune_prompt::123", task)
+
+     with pytest.raises(
+         ValueError,
+         match="Fine-tune ID not found",
+     ):
+         prompt_builder_from_ui_name("fine_tune_prompt::123::456::789", task)
+
+     prompt = Prompt(
+         name="test_prompt_name",
+         prompt="test_prompt",
+         chain_of_thought_instructions="coti",
+         parent=task,
+     )
+     prompt.save_to_file()
+     pb = prompt_builder_from_ui_name("id::" + prompt.id, task)
+     assert isinstance(pb, SavedPromptBuilder)
+     assert pb.prompt_id() == prompt.id
+     assert pb.build_prompt(include_json_instructions=False) == "test_prompt"
+     assert pb.chain_of_thought_prompt() == "coti"
+
+     finetune = Finetune(
+         name="test_finetune_name",
+         system_message="test_system_message",
+         thinking_instructions="test_thinking_instructions",
+         parent=task,
+         base_model_id="test_base_model_id",
+         dataset_split_id="asdf",
+         provider="test_provider",
+         data_strategy=FinetuneDataStrategy.final_and_intermediate,
+     )
+     finetune.save_to_file()
+     nested_fine_tune_id = (
+         task_with_examples.parent.id + "::" + task_with_examples.id + "::" + finetune.id
+     )
+     pb = prompt_builder_from_ui_name(
+         "fine_tune_prompt::" + nested_fine_tune_id,
+         task_with_examples,
+     )
+     assert isinstance(pb, FineTunePromptBuilder)
+     assert pb.prompt_id() == nested_fine_tune_id
+     assert pb.build_base_prompt() == "test_system_message"
+     assert pb.chain_of_thought_prompt() == "test_thinking_instructions"
  
  
  def test_example_count():
@@ -335,7 +408,7 @@
  def test_repair_multi_shot_prompt_builder(task_with_examples):
      # Verify the order of examples
      prompt_builder = RepairsPromptBuilder(task=task_with_examples)
-     prompt = prompt_builder.build_prompt()
+     prompt = prompt_builder.build_prompt(include_json_instructions=False)
      assert (
          'Repaired Output Which is Sufficient: {"joke": "Why did the cow cross the road? To get to the udder side!"}'
          in prompt
@@ -403,7 +476,7 @@ def test_build_prompt_for_ui(tmp_path):
      ui_prompt = simple_builder.build_prompt_for_ui()
  
      # Should match regular prompt since no chain of thought
-     assert ui_prompt == simple_builder.build_prompt()
+     assert ui_prompt == simple_builder.build_prompt(include_json_instructions=False)
      assert "# Thinking Instructions" not in ui_prompt
  
      # Test chain of thought prompt builder
@@ -411,7 +484,7 @@
      ui_prompt_cot = cot_builder.build_prompt_for_ui()
  
      # Should include both base prompt and thinking instructions
-     assert cot_builder.build_prompt() in ui_prompt_cot
+     assert cot_builder.build_prompt(include_json_instructions=False) in ui_prompt_cot
      assert "# Thinking Instructions" in ui_prompt_cot
      assert "Think step by step" in ui_prompt_cot
  
@@ -423,6 +496,94 @@
      custom_cot_builder = SimpleChainOfThoughtPromptBuilder(task=task_with_custom)
      ui_prompt_custom = custom_cot_builder.build_prompt_for_ui()
  
-     assert custom_cot_builder.build_prompt() in ui_prompt_custom
+     assert (
+         custom_cot_builder.build_prompt(include_json_instructions=False)
+         in ui_prompt_custom
+     )
      assert "# Thinking Instructions" in ui_prompt_custom
      assert custom_instruction in ui_prompt_custom
+
+
+ def test_saved_prompt_builder(tmp_path):
+     task = build_test_task(tmp_path)
+
+     prompt = Prompt(
+         name="test_prompt_name",
+         prompt="test_prompt",
+         parent=task,
+     )
+     prompt.save_to_file()
+
+     builder = SavedPromptBuilder(task=task, prompt_id=prompt.id)
+     assert builder.build_prompt(include_json_instructions=False) == "test_prompt"
+     assert builder.chain_of_thought_prompt() is None
+     assert builder.build_prompt_for_ui() == "test_prompt"
+     assert builder.prompt_id() == prompt.id
+
+
+ def test_saved_prompt_builder_with_chain_of_thought(tmp_path):
+     task = build_test_task(tmp_path)
+
+     prompt = Prompt(
+         name="test_prompt_name",
+         prompt="test_prompt",
+         chain_of_thought_instructions="Think step by step",
+         parent=task,
+     )
+     prompt.save_to_file()
+
+     builder = SavedPromptBuilder(task=task, prompt_id=prompt.id)
+     assert builder.build_prompt(include_json_instructions=False) == "test_prompt"
+     assert builder.chain_of_thought_prompt() == "Think step by step"
+     assert "Think step by step" in builder.build_prompt_for_ui()
+     assert builder.prompt_id() == prompt.id
+
+
+ def test_saved_prompt_builder_not_found(tmp_path):
+     task = build_test_task(tmp_path)
+
+     with pytest.raises(ValueError, match="Prompt ID not found: 123"):
+         SavedPromptBuilder(task=task, prompt_id="123")
+
+
+ def test_build_prompt_with_json_instructions(tmp_path):
+     task = build_test_task(tmp_path)
+     task = task.model_copy(
+         update={
+             "output_json_schema": json.dumps(
+                 {
+                     "type": "object",
+                     "properties": {"result": {"type": "string"}},
+                     "required": ["result"],
+                 }
+             )
+         }
+     )
+
+     builder = SimplePromptBuilder(task=task)
+
+     # Test without JSON instructions
+     prompt_without_json = builder.build_prompt(include_json_instructions=False)
+     assert "Format Instructions" not in prompt_without_json
+     assert (
+         "Return a JSON object conforming to the following schema:"
+         not in prompt_without_json
+     )
+     assert task.output_json_schema not in prompt_without_json
+
+     # Test with JSON instructions
+     prompt_with_json = builder.build_prompt(include_json_instructions=True)
+     assert "# Format Instructions" in prompt_with_json
+     assert (
+         "Return a JSON object conforming to the following schema:" in prompt_with_json
+     )
+     assert "```" in prompt_with_json
+     assert (
+         "{'type': 'object', 'properties': {'result': {'type': 'string'}}, 'required': ['result']}"
+         in prompt_with_json
+     )
+
+     # Verify base prompt is still included
+     assert task.instruction in prompt_with_json
+     for requirement in task.requirements:
+         assert requirement.instruction in prompt_with_json
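
Taken together, the prompt-builder changes above mean prompt_builder_from_ui_name now takes the task and returns a builder instance rather than a class, ID-scoped names resolve saved prompts and fine-tune prompts, and build_prompt takes an include_json_instructions flag. A minimal sketch of the new call shapes; build_prompts is an illustrative helper and the task is assumed to define an output_json_schema:

from kiln_ai import datamodel
from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name


def build_prompts(task: datamodel.Task) -> tuple[str, str]:
    # Besides the built-in names ("basic", "few_shot", "many_shot", "repairs",
    # and the chain-of-thought variants), ID-scoped names now resolve too:
    #   "id::<prompt_id>"                                           -> SavedPromptBuilder
    #   "fine_tune_prompt::<project_id>::<task_id>::<fine_tune_id>" -> FineTunePromptBuilder
    builder = prompt_builder_from_ui_name("basic", task)

    # With include_json_instructions=True and an output_json_schema on the task,
    # a "# Format Instructions" section containing the schema is appended.
    plain = builder.build_prompt(include_json_instructions=False)
    with_schema = builder.build_prompt(include_json_instructions=True)
    return plain, with_schema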