kiln-ai 0.8.1__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +193 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
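
The renames above move the adapter implementations under a new `model_adapters` package and add a new `parsers` package. A minimal migration sketch for downstream imports, assuming the old modules are not re-exported from their previous locations (the diffs below update Kiln's own tests in exactly this way):

```python
# Old import paths (kiln-ai 0.8.1) -- shown for comparison only:
# from kiln_ai.adapters.base_adapter import BaseAdapter
# from kiln_ai.adapters.langchain_adapters import LangchainAdapter

# New import paths (kiln-ai 0.11.1), matching the renamed files above:
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
```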
kiln_ai/adapters/test_adapter_registry.py (new file)

@@ -0,0 +1,177 @@
+from unittest.mock import patch
+
+import pytest
+
+from kiln_ai import datamodel
+from kiln_ai.adapters.adapter_registry import adapter_for_task
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder
+from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+
+
+@pytest.fixture
+def mock_config():
+    with patch("kiln_ai.adapters.adapter_registry.Config") as mock:
+        mock.shared.return_value.open_ai_api_key = "test-openai-key"
+        mock.shared.return_value.open_router_api_key = "test-openrouter-key"
+        yield mock
+
+
+@pytest.fixture
+def basic_task():
+    return datamodel.Task(
+        task_id="test-task",
+        task_type="test",
+        input_text="test input",
+        name="test-task",
+        instruction="test-task",
+    )
+
+
+@pytest.fixture
+def mock_finetune_from_id():
+    with patch("kiln_ai.adapters.provider_tools.finetune_from_id") as mock:
+        mock.return_value.provider = ModelProviderName.openai
+        mock.return_value.fine_tune_model_id = "test-model"
+        yield mock
+
+
+def test_openai_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task, model_name="gpt-4", provider=ModelProviderName.openai
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.model_name == "gpt-4"
+    assert adapter.config.api_key == "test-openai-key"
+    assert adapter.config.provider_name == ModelProviderName.openai
+    assert adapter.config.base_url is None  # OpenAI url is default
+    assert adapter.config.default_headers is None
+
+
+def test_openrouter_adapter_creation(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="anthropic/claude-3-opus",
+        provider=ModelProviderName.openrouter,
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.model_name == "anthropic/claude-3-opus"
+    assert adapter.config.api_key == "test-openrouter-key"
+    assert adapter.config.provider_name == ModelProviderName.openrouter
+    assert adapter.config.base_url == "https://openrouter.ai/api/v1"
+    assert adapter.config.default_headers == {
+        "HTTP-Referer": "https://getkiln.ai/openrouter",
+        "X-Title": "KilnAI",
+    }
+
+
+@pytest.mark.parametrize(
+    "provider",
+    [
+        ModelProviderName.groq,
+        ModelProviderName.amazon_bedrock,
+        ModelProviderName.ollama,
+        ModelProviderName.fireworks_ai,
+    ],
+)
+def test_langchain_adapter_creation(mock_config, basic_task, provider):
+    adapter = adapter_for_task(
+        kiln_task=basic_task, model_name="test-model", provider=provider
+    )
+
+    assert isinstance(adapter, LangchainAdapter)
+    assert adapter.model_name == "test-model"
+
+
+# TODO should run for all cases
+def test_custom_prompt_builder(mock_config, basic_task):
+    class TestPromptBuilder(BasePromptBuilder):
+        def build_base_prompt(self, kiln_task) -> str:
+            return "test-prompt"
+
+    prompt_builder = TestPromptBuilder(basic_task)
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="gpt-4",
+        provider=ModelProviderName.openai,
+        prompt_builder=prompt_builder,
+    )
+
+    assert adapter.prompt_builder == prompt_builder
+
+
+# TODO should run for all cases
+def test_tags_passed_through(mock_config, basic_task):
+    tags = ["test-tag-1", "test-tag-2"]
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="gpt-4",
+        provider=ModelProviderName.openai,
+        tags=tags,
+    )
+
+    assert adapter.default_tags == tags
+
+
+def test_invalid_provider(mock_config, basic_task):
+    with pytest.raises(ValueError, match="Unhandled enum value"):
+        adapter_for_task(
+            kiln_task=basic_task, model_name="test-model", provider="invalid"
+        )
+
+
+@patch("kiln_ai.adapters.adapter_registry.openai_compatible_config")
+def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_task):
+    mock_compatible_config.return_value.model_name = "test-model"
+    mock_compatible_config.return_value.api_key = "test-key"
+    mock_compatible_config.return_value.base_url = "https://test.com/v1"
+
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="provider::test-model",
+        provider=ModelProviderName.openai_compatible,
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    mock_compatible_config.assert_called_once_with("provider::test-model")
+    assert adapter.config.model_name == "test-model"
+    assert adapter.config.api_key == "test-key"
+    assert adapter.config.base_url == "https://test.com/v1"
+
+
+def test_custom_openai_compatible_provider(mock_config, basic_task):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="openai::test-model",
+        provider=ModelProviderName.kiln_custom_registry,
+    )
+
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.model_name == "openai::test-model"
+    assert adapter.config.api_key == "test-openai-key"
+    assert adapter.config.base_url is None  # openai is none
+    assert adapter.config.provider_name == ModelProviderName.kiln_custom_registry
+
+
+async def test_fine_tune_provider(mock_config, basic_task, mock_finetune_from_id):
+    adapter = adapter_for_task(
+        kiln_task=basic_task,
+        model_name="proj::task::tune",
+        provider=ModelProviderName.kiln_fine_tune,
+    )
+
+    mock_finetune_from_id.assert_called_once_with("proj::task::tune")
+    assert isinstance(adapter, OpenAICompatibleAdapter)
+    assert adapter.config.provider_name == ModelProviderName.kiln_fine_tune
+    # Kiln model name here, but the underlying openai model id below
+    assert adapter.config.model_name == "proj::task::tune"
+
+    provider = kiln_model_provider_from(
+        "proj::task::tune", provider_name=ModelProviderName.kiln_fine_tune
+    )
+    # The actual model name from the fine tune object
+    assert provider.provider_options["model"] == "test-model"
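The tests above exercise `adapter_for_task`, which returns an `OpenAICompatibleAdapter` for OpenAI, OpenRouter, custom-registry and fine-tuned models, and a `LangchainAdapter` for Groq, Bedrock, Ollama and Fireworks. A usage sketch based only on the calls shown in these tests; the task construction and the async entry point are assumptions, and real provider credentials must be configured in Kiln's `Config`:

```python
import asyncio

from kiln_ai import datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.ml_model_list import ModelProviderName


async def main() -> None:
    # A throwaway task, mirroring the basic_task fixture above (assumed sufficient
    # for illustration; persisted tasks may need a parent project on disk).
    task = datamodel.Task(name="example-task", instruction="Answer the question.")

    adapter = adapter_for_task(
        kiln_task=task,
        model_name="gpt-4",
        provider=ModelProviderName.openai,
        tags=["example-run"],  # optional; surfaced as adapter.default_tags in the tests
    )
    run = await adapter.invoke("What is 2 + 2?")
    print(run.output.output)


if __name__ == "__main__":
    asyncio.run(main())
```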
kiln_ai/adapters/test_generate_docs.py (new file)

@@ -0,0 +1,69 @@
+from typing import List
+
+import pytest
+
+from libs.core.kiln_ai.adapters.ml_model_list import (
+    KilnModelProvider,
+    built_in_models,
+)
+from libs.core.kiln_ai.adapters.provider_tools import provider_name_from_id
+
+
+def _all_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
+    """Check if all providers support a given feature"""
+    return all(getattr(provider, attribute) for provider in providers)
+
+
+def _any_providers_support(providers: List[KilnModelProvider], attribute: str) -> bool:
+    """Check if any providers support a given feature"""
+    return any(getattr(provider, attribute) for provider in providers)
+
+
+def _get_support_status(providers: List[KilnModelProvider], attribute: str) -> str:
+    """Get the support status for a feature"""
+    if _all_providers_support(providers, attribute):
+        return "✅︎"
+    elif _any_providers_support(providers, attribute):
+        return "✅︎ (some providers)"
+    return ""
+
+
+def _has_finetune_support(providers: List[KilnModelProvider]) -> str:
+    """Check if any provider supports fine-tuning"""
+    return "✅︎" if any(p.provider_finetune_id for p in providers) else ""
+
+
+@pytest.mark.paid(reason="Marking as paid so it isn't run by default")
+def test_generate_model_table():
+    """Generate a markdown table of all models and their capabilities"""
+
+    # Table header
+    table = [
+        "| Model Name | Providers | Structured Output | Reasoning | Synthetic Data | API Fine-Tuneable |",
+        "|------------|-----------|-------------------|-----------|----------------|-------------------|",
+    ]
+
+    for model in built_in_models:
+        provider_names = ", ".join(
+            sorted(provider_name_from_id(p.name.value) for p in model.providers)
+        )
+        structured_output = _get_support_status(
+            model.providers, "supports_structured_output"
+        )
+        reasoning = _get_support_status(model.providers, "reasoning_capable")
+        data_gen = _get_support_status(model.providers, "supports_data_gen")
+        finetune = _has_finetune_support(model.providers)
+
+        row = f"| {model.friendly_name} | {provider_names} | {structured_output} | {reasoning} | {data_gen} | {finetune} |"
+        table.append(row)
+
+    # Print the table (useful for documentation)
+    print("\nModel Capability Matrix:\n")
+    print("\n".join(table))
+
+    # Basic assertions to ensure the table is well-formed
+    assert len(table) > 2, "Table should have header and at least one row"
+    assert all("|" in row for row in table), "All rows should be properly formatted"
+    assert len(table[0].split("|")) == len(table[1].split("|")), (
+        "Header and separator should have same number of columns"
+    )
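The helper functions above reduce a model's provider list to one of three table cells. A standalone sketch of that reduction, using a hypothetical `FakeProvider` stand-in so it runs without the kiln_ai package installed:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeProvider:
    # Stand-in for KilnModelProvider, for illustration only.
    supports_structured_output: bool = False


def get_support_status(providers: List[FakeProvider], attribute: str) -> str:
    # Mirrors _get_support_status above: all -> full support, any -> partial, none -> blank cell.
    if all(getattr(p, attribute) for p in providers):
        return "✅︎"
    if any(getattr(p, attribute) for p in providers):
        return "✅︎ (some providers)"
    return ""


print(get_support_status([FakeProvider(True), FakeProvider(True)], "supports_structured_output"))
print(get_support_status([FakeProvider(True), FakeProvider(False)], "supports_structured_output"))
print(get_support_status([FakeProvider(False)], "supports_structured_output"))
```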
kiln_ai/adapters/test_prompt_adaptors.py

@@ -6,8 +6,8 @@ from langchain_core.language_models.fake_chat_models import FakeListChatModel
 
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
@@ -108,7 +108,11 @@ async def test_amazon_bedrock(tmp_path):
 async def test_mock(tmp_path):
     task = build_test_task(tmp_path)
     mockChatModel = FakeListChatModel(responses=["mock response"])
-    adapter = LangchainAdapter(task, custom_model=mockChatModel)
+    adapter = LangchainAdapter(
+        task,
+        custom_model=mockChatModel,
+        provider="ollama",
+    )
     run = await adapter.invoke("You are a mock, send me the response!")
     assert "mock response" in run.output.output
 
@@ -116,7 +120,7 @@ async def test_mock(tmp_path):
 async def test_mock_returning_run(tmp_path):
     task = build_test_task(tmp_path)
     mockChatModel = FakeListChatModel(responses=["mock response"])
-    adapter = LangchainAdapter(task, custom_model=mockChatModel)
+    adapter = LangchainAdapter(task, custom_model=mockChatModel, provider="ollama")
     run = await adapter.invoke("You are a mock, send me the response!")
     assert run.output.output == "mock response"
     assert run is not None
@@ -127,7 +131,7 @@ async def test_mock_returning_run(tmp_path):
     assert run.output.source.properties == {
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "custom.langchain:unknown_model",
-        "model_provider": "
+        "model_provider": "ollama",
         "prompt_builder_name": "simple_prompt_builder",
     }
 
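These hunks show that `LangchainAdapter` now takes an explicit `provider` argument when constructed with a `custom_model`, and that the recorded `model_provider` property reflects it. A minimal sketch mirroring the updated tests; the helper function and its `task` parameter are assumptions (any saved `kiln_ai.datamodel.Task` should do):

```python
from langchain_core.language_models.fake_chat_models import FakeListChatModel

from kiln_ai import datamodel
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter


async def run_with_fake_model(task: datamodel.Task) -> str:
    # provider is now passed alongside custom_model, per the diff above.
    fake_model = FakeListChatModel(responses=["mock response"])
    adapter = LangchainAdapter(task, custom_model=fake_model, provider="ollama")
    run = await adapter.invoke("You are a mock, send me the response!")
    # run.output.source.properties["model_provider"] will be "ollama" here.
    return run.output.output
```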
kiln_ai/adapters/test_prompt_builders.py

@@ -2,24 +2,31 @@ import json
 
 import pytest
 
-from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.model_adapters.test_structured_output import (
+    build_structured_output_test_task,
+)
 from kiln_ai.adapters.prompt_builders import (
     FewShotChainOfThoughtPromptBuilder,
     FewShotPromptBuilder,
+    FineTunePromptBuilder,
     MultiShotChainOfThoughtPromptBuilder,
     MultiShotPromptBuilder,
     RepairsPromptBuilder,
+    SavedPromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
     SimplePromptBuilder,
     chain_of_thought_prompt,
     prompt_builder_from_ui_name,
 )
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
-from kiln_ai.adapters.test_structured_output import build_structured_output_test_task
 from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,
+    Finetune,
+    FinetuneDataStrategy,
     Project,
+    Prompt,
     Task,
     TaskOutput,
     TaskOutputRating,
@@ -31,7 +38,7 @@ def test_simple_prompt_builder(tmp_path):
     task = build_test_task(tmp_path)
     builder = SimplePromptBuilder(task=task)
     input = "two plus two"
-    prompt = builder.build_prompt()
+    prompt = builder.build_prompt(include_json_instructions=False)
     assert (
         "You are an assistant which performs math tasks provided in plain text."
         in prompt
@@ -62,7 +69,7 @@ def test_simple_prompt_builder_structured_output(tmp_path):
     task = build_structured_output_test_task(tmp_path)
     builder = SimplePromptBuilder(task=task)
     input = "Cows"
-    prompt = builder.build_prompt()
+    prompt = builder.build_prompt(include_json_instructions=False)
     assert "You are an assistant which tells a joke, given a subject." in prompt
 
     user_msg = builder.build_user_message(input)
@@ -70,6 +77,14 @@
     assert input not in prompt
 
 
+def test_simple_prompt_builder_structured_input_non_ascii(tmp_path):
+    task = build_structured_output_test_task(tmp_path)
+    builder = SimplePromptBuilder(task=task)
+    input = {"key": "你好👋"}
+    user_msg = builder.build_user_message(input)
+    assert "你好👋" in user_msg
+
+
 @pytest.fixture
 def task_with_examples(tmp_path):
     # Create a project and task hierarchy
@@ -198,7 +213,7 @@ def task_with_examples(tmp_path):
 def test_multi_shot_prompt_builder(task_with_examples):
     # Verify the order of examples
     prompt_builder = MultiShotPromptBuilder(task=task_with_examples)
-    prompt = prompt_builder.build_prompt()
+    prompt = prompt_builder.build_prompt(include_json_instructions=False)
     assert "Why did the cow cross the road?" in prompt
     assert prompt.index("Why did the cow cross the road?") < prompt.index(
         "Why don't cats play poker in the jungle?"
@@ -239,14 +254,14 @@ def test_few_shot_prompt_builder(tmp_path):
     # Create 6 examples (2 repaired, 4 high-quality)
     for i in range(6):
         run = TaskRun(
-            input=f'{{"subject": "Subject {i+1}"}}',
+            input=f'{{"subject": "Subject {i + 1}"}}',
             input_source=DataSource(
                 type=DataSourceType.human,
                 properties={"created_by": "john_doe"},
             ),
             parent=task,
             output=TaskOutput(
-                output=f'{{"joke": "Joke Initial Output {i+1}"}}',
+                output=f'{{"joke": "Joke Initial Output {i + 1}"}}',
                 source=DataSource(
                     type=DataSourceType.human,
                     properties={"created_by": "john_doe"},
@@ -260,7 +275,7 @@ def test_few_shot_prompt_builder(tmp_path):
             update={
                 "repair_instructions": "Fix the joke",
                 "repaired_output": TaskOutput(
-                    output=f'{{"joke": "Repaired Joke {i+1}"}}',
+                    output=f'{{"joke": "Repaired Joke {i + 1}"}}',
                     source=DataSource(
                         type=DataSourceType.human,
                         properties={"created_by": "jane_doe"},
@@ -272,7 +287,7 @@ def test_few_shot_prompt_builder(tmp_path):
 
     # Check that only 4 examples are included
     prompt_builder = FewShotPromptBuilder(task=task)
-    prompt = prompt_builder.build_prompt()
+    prompt = prompt_builder.build_prompt(include_json_instructions=False)
     assert prompt.count("## Example") == 4
 
     print("PROMPT", prompt)
@@ -289,7 +304,7 @@
 
 def check_example_outputs(task: Task, count: int):
     prompt_builder = MultiShotPromptBuilder(task=task)
-    prompt = prompt_builder.build_prompt()
+    prompt = prompt_builder.build_prompt(include_json_instructions=False)
     assert "# Instruction" in prompt
     assert task.instruction in prompt
     if count == 0:
@@ -305,26 +320,84 @@ def test_prompt_builder_name():
     assert RepairsPromptBuilder.prompt_builder_name() == "repairs_prompt_builder"
 
 
-def test_prompt_builder_from_ui_name():
-
-    assert prompt_builder_from_ui_name("
-    assert
-
-    assert (
-        prompt_builder_from_ui_name("simple_chain_of_thought")
-        == SimpleChainOfThoughtPromptBuilder
+def test_prompt_builder_from_ui_name(task_with_examples):
+    task = task_with_examples
+    assert isinstance(prompt_builder_from_ui_name("basic", task), SimplePromptBuilder)
+    assert isinstance(
+        prompt_builder_from_ui_name("few_shot", task), FewShotPromptBuilder
     )
-    assert (
-        prompt_builder_from_ui_name("
-        == FewShotChainOfThoughtPromptBuilder
+    assert isinstance(
+        prompt_builder_from_ui_name("many_shot", task), MultiShotPromptBuilder
     )
-    assert (
-        prompt_builder_from_ui_name("
-
+    assert isinstance(
+        prompt_builder_from_ui_name("repairs", task), RepairsPromptBuilder
+    )
+    assert isinstance(
+        prompt_builder_from_ui_name("simple_chain_of_thought", task),
+        SimpleChainOfThoughtPromptBuilder,
+    )
+    assert isinstance(
+        prompt_builder_from_ui_name("few_shot_chain_of_thought", task),
+        FewShotChainOfThoughtPromptBuilder,
+    )
+    assert isinstance(
+        prompt_builder_from_ui_name("multi_shot_chain_of_thought", task),
+        MultiShotChainOfThoughtPromptBuilder,
     )
 
     with pytest.raises(ValueError, match="Unknown prompt builder: invalid_name"):
-        prompt_builder_from_ui_name("invalid_name")
+        prompt_builder_from_ui_name("invalid_name", task)
+
+    with pytest.raises(ValueError, match="Prompt ID not found: 123"):
+        prompt_builder_from_ui_name("id::123", task)
+
+    with pytest.raises(
+        ValueError,
+        match="Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id'",
+    ):
+        prompt_builder_from_ui_name("fine_tune_prompt::123", task)
+
+    with pytest.raises(
+        ValueError,
+        match="Fine-tune ID not found",
+    ):
+        prompt_builder_from_ui_name("fine_tune_prompt::123::456::789", task)
+
+    prompt = Prompt(
+        name="test_prompt_name",
+        prompt="test_prompt",
+        chain_of_thought_instructions="coti",
+        parent=task,
+    )
+    prompt.save_to_file()
+    pb = prompt_builder_from_ui_name("id::" + prompt.id, task)
+    assert isinstance(pb, SavedPromptBuilder)
+    assert pb.prompt_id() == prompt.id
+    assert pb.build_prompt(include_json_instructions=False) == "test_prompt"
+    assert pb.chain_of_thought_prompt() == "coti"
+
+    finetune = Finetune(
+        name="test_finetune_name",
+        system_message="test_system_message",
+        thinking_instructions="test_thinking_instructions",
+        parent=task,
+        base_model_id="test_base_model_id",
+        dataset_split_id="asdf",
+        provider="test_provider",
+        data_strategy=FinetuneDataStrategy.final_and_intermediate,
+    )
+    finetune.save_to_file()
+    nested_fine_tune_id = (
+        task_with_examples.parent.id + "::" + task_with_examples.id + "::" + finetune.id
+    )
+    pb = prompt_builder_from_ui_name(
+        "fine_tune_prompt::" + nested_fine_tune_id,
+        task_with_examples,
+    )
+    assert isinstance(pb, FineTunePromptBuilder)
+    assert pb.prompt_id() == nested_fine_tune_id
+    assert pb.build_base_prompt() == "test_system_message"
+    assert pb.chain_of_thought_prompt() == "test_thinking_instructions"
 
 
 def test_example_count():
@@ -335,7 +408,7 @@ def test_example_count():
 def test_repair_multi_shot_prompt_builder(task_with_examples):
     # Verify the order of examples
     prompt_builder = RepairsPromptBuilder(task=task_with_examples)
-    prompt = prompt_builder.build_prompt()
+    prompt = prompt_builder.build_prompt(include_json_instructions=False)
     assert (
         'Repaired Output Which is Sufficient: {"joke": "Why did the cow cross the road? To get to the udder side!"}'
         in prompt
@@ -403,7 +476,7 @@ def test_build_prompt_for_ui(tmp_path):
     ui_prompt = simple_builder.build_prompt_for_ui()
 
     # Should match regular prompt since no chain of thought
-    assert ui_prompt == simple_builder.build_prompt()
+    assert ui_prompt == simple_builder.build_prompt(include_json_instructions=False)
    assert "# Thinking Instructions" not in ui_prompt
 
     # Test chain of thought prompt builder
@@ -411,7 +484,7 @@ def test_build_prompt_for_ui(tmp_path):
     ui_prompt_cot = cot_builder.build_prompt_for_ui()
 
     # Should include both base prompt and thinking instructions
-    assert cot_builder.build_prompt() in ui_prompt_cot
+    assert cot_builder.build_prompt(include_json_instructions=False) in ui_prompt_cot
     assert "# Thinking Instructions" in ui_prompt_cot
     assert "Think step by step" in ui_prompt_cot
 
@@ -423,6 +496,94 @@ def test_build_prompt_for_ui(tmp_path):
     custom_cot_builder = SimpleChainOfThoughtPromptBuilder(task=task_with_custom)
     ui_prompt_custom = custom_cot_builder.build_prompt_for_ui()
 
-    assert custom_cot_builder.build_prompt() in ui_prompt_custom
+    assert (
+        custom_cot_builder.build_prompt(include_json_instructions=False)
+        in ui_prompt_custom
+    )
     assert "# Thinking Instructions" in ui_prompt_custom
     assert custom_instruction in ui_prompt_custom
+
+
+def test_saved_prompt_builder(tmp_path):
+    task = build_test_task(tmp_path)
+
+    prompt = Prompt(
+        name="test_prompt_name",
+        prompt="test_prompt",
+        parent=task,
+    )
+    prompt.save_to_file()
+
+    builder = SavedPromptBuilder(task=task, prompt_id=prompt.id)
+    assert builder.build_prompt(include_json_instructions=False) == "test_prompt"
+    assert builder.chain_of_thought_prompt() is None
+    assert builder.build_prompt_for_ui() == "test_prompt"
+    assert builder.prompt_id() == prompt.id
+
+
+def test_saved_prompt_builder_with_chain_of_thought(tmp_path):
+    task = build_test_task(tmp_path)
+
+    prompt = Prompt(
+        name="test_prompt_name",
+        prompt="test_prompt",
+        chain_of_thought_instructions="Think step by step",
+        parent=task,
+    )
+    prompt.save_to_file()
+
+    builder = SavedPromptBuilder(task=task, prompt_id=prompt.id)
+    assert builder.build_prompt(include_json_instructions=False) == "test_prompt"
+    assert builder.chain_of_thought_prompt() == "Think step by step"
+    assert "Think step by step" in builder.build_prompt_for_ui()
+    assert builder.prompt_id() == prompt.id
+
+
+def test_saved_prompt_builder_not_found(tmp_path):
+    task = build_test_task(tmp_path)
+
+    with pytest.raises(ValueError, match="Prompt ID not found: 123"):
+        SavedPromptBuilder(task=task, prompt_id="123")
+
+
+def test_build_prompt_with_json_instructions(tmp_path):
+    task = build_test_task(tmp_path)
+    task = task.model_copy(
+        update={
+            "output_json_schema": json.dumps(
+                {
+                    "type": "object",
+                    "properties": {"result": {"type": "string"}},
+                    "required": ["result"],
+                }
+            )
+        }
+    )
+
+    builder = SimplePromptBuilder(task=task)
+
+    # Test without JSON instructions
+    prompt_without_json = builder.build_prompt(include_json_instructions=False)
+    assert "Format Instructions" not in prompt_without_json
+    assert (
+        "Return a JSON object conforming to the following schema:"
+        not in prompt_without_json
+    )
+    assert task.output_json_schema not in prompt_without_json
+
+    # Test with JSON instructions
+    prompt_with_json = builder.build_prompt(include_json_instructions=True)
+    assert "# Format Instructions" in prompt_with_json
+    assert (
+        "Return a JSON object conforming to the following schema:" in prompt_with_json
+    )
+    assert "```" in prompt_with_json
+    assert (
+        "{'type': 'object', 'properties': {'result': {'type': 'string'}}, 'required': ['result']}"
+        in prompt_with_json
+    )
+
+    # Verify base prompt is still included
+    assert task.instruction in prompt_with_json
+    for requirement in task.requirements:
+        assert requirement.instruction in prompt_with_json