kiln-ai 0.5.5__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +9 -1
- kiln_ai/adapters/base_adapter.py +24 -35
- kiln_ai/adapters/data_gen/__init__.py +11 -0
- kiln_ai/adapters/data_gen/data_gen_prompts.py +73 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +185 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +293 -0
- kiln_ai/adapters/langchain_adapters.py +39 -7
- kiln_ai/adapters/ml_model_list.py +55 -1
- kiln_ai/adapters/prompt_builders.py +66 -0
- kiln_ai/adapters/repair/test_repair_task.py +4 -1
- kiln_ai/adapters/test_langchain_adapter.py +73 -0
- kiln_ai/adapters/test_ml_model_list.py +56 -0
- kiln_ai/adapters/test_prompt_adaptors.py +52 -18
- kiln_ai/adapters/test_prompt_builders.py +97 -7
- kiln_ai/adapters/test_saving_adapter_results.py +16 -6
- kiln_ai/adapters/test_structured_output.py +33 -5
- kiln_ai/datamodel/__init__.py +28 -7
- kiln_ai/datamodel/json_schema.py +1 -0
- kiln_ai/datamodel/test_models.py +44 -8
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/test_config.py +7 -0
- {kiln_ai-0.5.5.dist-info → kiln_ai-0.6.1.dist-info}/METADATA +1 -2
- kiln_ai-0.6.1.dist-info/RECORD +37 -0
- {kiln_ai-0.5.5.dist-info → kiln_ai-0.6.1.dist-info}/WHEEL +1 -1
- kiln_ai-0.5.5.dist-info/RECORD +0 -33
- {kiln_ai-0.5.5.dist-info → kiln_ai-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/data_gen/test_data_gen_task.py (new file)
@@ -0,0 +1,293 @@
+import json
+
+import pytest
+
+from kiln_ai.adapters.data_gen.data_gen_task import (
+    DataGenCategoriesTask,
+    DataGenCategoriesTaskInput,
+    DataGenCategoriesTaskOutput,
+    DataGenSampleTask,
+    DataGenSampleTaskInput,
+    list_json_schema_for_task,
+)
+from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.ml_model_list import get_model_and_provider
+from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
+from kiln_ai.datamodel import Project, Task
+
+
+@pytest.fixture
+def base_task():
+    project = Project(name="TestProject")
+    return Task(
+        name="Cowboy Speaker",
+        parent=project,
+        description="Reply like a cowboy",
+        instruction="Reply like a cowboy",
+        requirements=[],
+    )
+
+
+def test_data_gen_categories_task_input_initialization(base_task):
+    # Arrange
+    node_path = ["root", "branch", "leaf"]
+    num_subtopics = 4
+    human_guidance = "Test guidance"
+
+    # Act
+    input_model = DataGenCategoriesTaskInput.from_task(
+        task=base_task,
+        node_path=node_path,
+        num_subtopics=num_subtopics,
+        human_guidance=human_guidance,
+    )
+
+    # Assert
+    assert input_model.node_path == node_path
+    assert input_model.num_subtopics == num_subtopics
+    assert input_model.human_guidance == human_guidance
+    assert isinstance(input_model.system_prompt, str)
+    assert "Reply like a cowboy" in input_model.system_prompt
+
+
+def test_data_gen_categories_task_input_default_values(base_task):
+    # Act
+    input_model = DataGenCategoriesTaskInput.from_task(task=base_task)
+
+    # Assert
+    assert input_model.num_subtopics == 6
+    assert input_model.human_guidance is None
+    assert input_model.node_path == []
+
+
+def test_data_gen_categories_task_initialization():
+    # Act
+    task = DataGenCategoriesTask()
+
+    # Assert
+    assert task.name == "DataGen"
+    assert isinstance(task.parent, Project)
+    assert task.description is not None
+    assert task.instruction is not None
+    assert isinstance(task.input_json_schema, str)
+    assert isinstance(task.output_json_schema, str)
+
+
+def test_data_gen_categories_task_schemas():
+    # Act
+    task = DataGenCategoriesTask()
+
+    # Assert
+    input_schema = json.loads(task.input_json_schema)
+    output_schema = json.loads(task.output_json_schema)
+
+    assert isinstance(input_schema, dict)
+    assert isinstance(output_schema, dict)
+    assert output_schema["type"] == "object"
+    assert output_schema["properties"]["subtopics"]["type"] == "array"
+    assert input_schema["properties"]["node_path"]["type"] == "array"
+    assert input_schema["properties"]["num_subtopics"]["type"] == "integer"
+    assert set(input_schema["required"]) == {
+        "node_path",
+        "num_subtopics",
+        "system_prompt",
+    }
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_data_gen_all_models_providers(
+    tmp_path, model_name, provider_name, base_task
+):
+    _, provider = get_model_and_provider(model_name, provider_name)
+    if not provider.supports_data_gen:
+        # pass if the model doesn't support data gen (testing the support flag is part of this)
+        return
+
+    data_gen_task = DataGenCategoriesTask()
+    data_gen_input = DataGenCategoriesTaskInput.from_task(base_task, num_subtopics=6)
+
+    adapter = LangChainPromptAdapter(
+        data_gen_task,
+        model_name=model_name,
+        provider=provider_name,
+    )
+
+    input_dict = data_gen_input.model_dump()
+    run = await adapter.invoke(input_dict)
+    parsed_output = DataGenCategoriesTaskOutput.model_validate_json(run.output.output)
+    assert len(parsed_output.subtopics) == 6
+    for subtopic in parsed_output.subtopics:
+        assert isinstance(subtopic, str)
+
+
+def test_data_gen_sample_task_input_initialization(base_task):
+    # Arrange
+    topic = ["cowboys", "hats"]
+    num_samples = 4
+    human_guidance = "Test guidance"
+
+    # Act
+    input_model = DataGenSampleTaskInput.from_task(
+        task=base_task,
+        topic=topic,
+        num_samples=num_samples,
+        human_guidance=human_guidance,
+    )
+
+    # Assert
+    assert input_model.topic == topic
+    assert input_model.num_samples == num_samples
+    assert input_model.human_guidance == human_guidance
+    assert isinstance(input_model.system_prompt, str)
+    assert "Reply like a cowboy" in input_model.system_prompt
+
+
+def test_data_gen_sample_task_input_default_values(base_task):
+    # Act
+    input_model = DataGenSampleTaskInput.from_task(task=base_task)
+
+    # Assert
+    assert input_model.num_samples == 8
+    assert input_model.human_guidance is None
+    assert input_model.topic == []
+
+
+def test_data_gen_sample_task_initialization(base_task):
+    # Act
+    task = DataGenSampleTask(target_task=base_task)
+
+    # Assert
+    assert task.name == "DataGenSample"
+    assert isinstance(task.parent, Project)
+    assert task.description is not None
+    assert task.instruction is not None
+
+    input_schema = json.loads(task.input_json_schema)
+    output_schema = json.loads(task.output_json_schema)
+
+    assert isinstance(input_schema, dict)
+    assert isinstance(output_schema, dict)
+    assert output_schema["type"] == "object"
+    assert output_schema["properties"]["generated_samples"]["type"] == "array"
+    assert input_schema["properties"]["topic"]["type"] == "array"
+    assert input_schema["properties"]["num_samples"]["type"] == "integer"
+    assert set(input_schema["required"]) == {
+        "topic",
+        "num_samples",
+        "system_prompt",
+    }
+
+
+def test_list_json_schema_for_task_with_output_schema(base_task):
+    # Arrange
+    base_task.input_json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+        }
+    )
+
+    # Act
+    schema = list_json_schema_for_task(base_task)
+    parsed_schema = json.loads(schema)
+
+    # Assert
+    assert parsed_schema["type"] == "object"
+    generated_samples_schema = parsed_schema["properties"]["generated_samples"]
+    assert generated_samples_schema["type"] == "array"
+    assert generated_samples_schema["items"]["type"] == "object"
+    assert generated_samples_schema["items"]["properties"]["name"]["type"] == "string"
+    assert generated_samples_schema["items"]["properties"]["age"]["type"] == "integer"
+
+
+def test_list_json_schema_for_task_without_output_schema(base_task):
+    # Arrange
+    base_task.output_json_schema = None
+
+    # Act
+    schema = list_json_schema_for_task(base_task)
+    parsed_schema = json.loads(schema)
+
+    # Assert
+    assert parsed_schema["type"] == "object"
+    assert parsed_schema["properties"]["generated_samples"]["type"] == "array"
+    assert parsed_schema["properties"]["generated_samples"]["items"]["type"] == "string"
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_data_gen_sample_all_models_providers(
+    tmp_path, model_name, provider_name, base_task
+):
+    _, provider = get_model_and_provider(model_name, provider_name)
+    if not provider.supports_data_gen:
+        # pass if the model doesn't support data gen (testing the support flag is part of this)
+        return
+
+    data_gen_task = DataGenSampleTask(target_task=base_task)
+    data_gen_input = DataGenSampleTaskInput.from_task(
+        base_task, topic=["riding horses"], num_samples=4
+    )
+
+    adapter = LangChainPromptAdapter(
+        data_gen_task,
+        model_name=model_name,
+        provider=provider_name,
+    )
+
+    input_dict = data_gen_input.model_dump()
+    run = await adapter.invoke(input_dict)
+    parsed_output = json.loads(run.output.output)
+    samples = parsed_output["generated_samples"]
+    assert len(samples) == 4
+    for sample in samples:
+        assert isinstance(sample, str)
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_data_gen_sample_all_models_providers_with_structured_output(
+    tmp_path, model_name, provider_name, base_task
+):
+    base_task.output_json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "opening": {"type": "string"},
+                "closing": {"type": "string"},
+            },
+            "required": ["opening", "closing"],
+        }
+    )
+
+    _, provider = get_model_and_provider(model_name, provider_name)
+    if not provider.supports_data_gen:
+        # pass if the model doesn't support data gen (testing the support flag is part of this)
+        return
+
+    data_gen_task = DataGenSampleTask(target_task=base_task)
+    data_gen_input = DataGenSampleTaskInput.from_task(
+        base_task, topic=["riding horses"], num_samples=4
+    )
+
+    adapter = LangChainPromptAdapter(
+        data_gen_task,
+        model_name=model_name,
+        provider=provider_name,
+    )
+
+    input_dict = data_gen_input.model_dump()
+    run = await adapter.invoke(input_dict)
+    parsed_output = json.loads(run.output.output)
+    samples = parsed_output["generated_samples"]
+    assert len(samples) == 4
+    for sample in samples:
+        assert isinstance(sample, dict)
+        assert "opening" in sample
+        assert "closing" in sample
+        assert isinstance(sample["opening"], str)
+        assert isinstance(sample["closing"], str)
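The new data generation tasks are driven through the standard adapter rather than a dedicated API. Condensed from test_data_gen_all_models_providers above, the core invocation pattern looks like the sketch below; the wrapper function and its signature are illustrative and not part of the package.

    import json

    from kiln_ai.adapters.data_gen.data_gen_task import (
        DataGenCategoriesTask,
        DataGenCategoriesTaskInput,
    )
    from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter


    async def generate_subtopics(base_task, model_name: str, provider_name: str) -> list[str]:
        # Condensed from the test above: wrap the data-gen task in the adapter,
        # invoke it, and read the "subtopics" array from the structured output.
        task = DataGenCategoriesTask()
        task_input = DataGenCategoriesTaskInput.from_task(base_task, num_subtopics=6)
        adapter = LangChainPromptAdapter(
            task, model_name=model_name, provider=provider_name
        )
        run = await adapter.invoke(task_input.model_dump())
        return json.loads(run.output.output)["subtopics"]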
kiln_ai/adapters/langchain_adapters.py
@@ -2,14 +2,14 @@ from typing import Dict
 
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.messages.base import BaseMessage
 from langchain_core.runnables import Runnable
 from pydantic import BaseModel
 
 import kiln_ai.datamodel as datamodel
 
-from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder
+from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
 from .ml_model_list import langchain_model_from
 
 LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
@@ -84,15 +84,41 @@ class LangChainPromptAdapter(BaseAdapter):
         )
         return self._model
 
-    async def _run(self, input: Dict | str) ->
+    async def _run(self, input: Dict | str) -> RunOutput:
+        model = await self.model()
+        chain = model
+        intermediate_outputs = {}
+
         prompt = self.build_prompt()
         user_msg = self.prompt_builder.build_user_message(input)
         messages = [
             SystemMessage(content=prompt),
             HumanMessage(content=user_msg),
         ]
-
-
+
+        # COT with structured output
+        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
+        if cot_prompt and self.has_structured_output():
+            # Base model (without structured output) used for COT message
+            base_model = await langchain_model_from(
+                self.model_name, self.model_provider
+            )
+            messages.append(
+                SystemMessage(content=cot_prompt),
+            )
+
+            cot_messages = [*messages]
+            cot_response = base_model.invoke(cot_messages)
+            intermediate_outputs["chain_of_thought"] = cot_response.content
+            messages.append(AIMessage(content=cot_response.content))
+            messages.append(
+                SystemMessage(content="Considering the above, return a final result.")
+            )
+        elif cot_prompt:
+            # for plaintext output, we just add COT instructions. We still only make one call.
+            messages.append(SystemMessage(content=cot_prompt))
+
+        response = chain.invoke(messages)
 
         if self.has_structured_output():
             if (
@@ -102,14 +128,20 @@ class LangChainPromptAdapter(BaseAdapter):
             ):
                 raise RuntimeError(f"structured response not returned: {response}")
             structured_response = response["parsed"]
-            return
+            return RunOutput(
+                output=self._munge_response(structured_response),
+                intermediate_outputs=intermediate_outputs,
+            )
         else:
            if not isinstance(response, BaseMessage):
                 raise RuntimeError(f"response is not a BaseMessage: {response}")
            text_content = response.content
            if not isinstance(text_content, str):
                 raise RuntimeError(f"response is not a string: {text_content}")
-            return
+            return RunOutput(
+                output=text_content,
+                intermediate_outputs=intermediate_outputs,
+            )
 
     def adapter_info(self) -> AdapterInfo:
         return AdapterInfo(
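The hunks above construct a RunOutput imported from .base_adapter; that module also changed in this release (+24 -35) but its hunks are not shown in this view. A minimal sketch of the shape implied by the constructor calls above, with field names taken from those calls; the real definition lives in kiln_ai/adapters/base_adapter.py and may differ.

    # Illustrative only: not the actual kiln_ai definition.
    from typing import Dict

    from pydantic import BaseModel


    class RunOutput(BaseModel):
        # The final model output: parsed dict for structured tasks, plain text otherwise.
        output: Dict | str
        # Optional intermediate artifacts, e.g. the "chain_of_thought" text saved above.
        intermediate_outputs: Dict[str, str] | None = None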
kiln_ai/adapters/ml_model_list.py
@@ -89,6 +89,7 @@ class KilnModelProvider(BaseModel):
 
     name: ModelProviderName
     supports_structured_output: bool = True
+    supports_data_gen: bool = True
     provider_options: Dict = {}
 
 
@@ -176,6 +177,8 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,  # it should, but doesn't work on openrouter
+                supports_data_gen=False,  # doesn't work on openrouter
                 provider_options={"model": "google/gemini-pro-1.5"},
             ),
         ],
@@ -188,6 +191,7 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemini-flash-1.5"},
             ),
         ],
@@ -200,6 +204,8 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemini-flash-1.5-8b"},
             ),
         ],
@@ -213,6 +219,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "nvidia/llama-3.1-nemotron-70b-instruct"},
             ),
         ],
@@ -230,6 +237,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={
                     "model": "meta.llama3-1-8b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2
@@ -237,6 +245,7 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_data_gen=False,
                 provider_options={
                     "model": "llama3.1:8b",
                     "model_aliases": ["llama3.1"],  # 8b is default
@@ -245,6 +254,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
             ),
         ],
@@ -261,7 +271,9 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
+                # not sure how AWS manages to break this, but it's not working
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={
                     "model": "meta.llama3-1-70b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2
@@ -285,6 +297,7 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
+                supports_data_gen=False,
                 provider_options={
                     "model": "meta.llama3-1-405b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2
@@ -344,8 +357,15 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-3b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2"},
+            ),
         ],
     ),
     # Llama 3.2 11B
@@ -357,8 +377,15 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2-vision"},
+            ),
         ],
     ),
     # Llama 3.2 90B
@@ -370,8 +397,15 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2-vision:90b"},
+            ),
         ],
     ),
     # Phi 3.5
@@ -384,10 +418,13 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.ollama,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "phi3.5"},
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
             ),
         ],
@@ -402,6 +439,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.ollama,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={
                     "model": "gemma2:2b",
                 },
@@ -417,12 +455,14 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_data_gen=False,
                 provider_options={
                     "model": "gemma2:9b",
                 },
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-9b-it"},
             ),
         ],
@@ -436,12 +476,14 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_data_gen=False,
                 provider_options={
                     "model": "gemma2:27b",
                 },
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-27b-it"},
             ),
         ],
@@ -449,6 +491,19 @@ built_in_models: List[KilnModel] = [
 ]
 
 
+def get_model_and_provider(
+    model_name: str, provider_name: str
+) -> tuple[KilnModel | None, KilnModelProvider | None]:
+    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
+    if model is None:
+        return None, None
+    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
+    # all or nothing
+    if provider is None or model is None:
+        return None, None
+    return model, provider
+
+
 def provider_name_from_id(id: str) -> str:
     """
     Converts a provider ID to its human-readable name.
@@ -687,7 +742,6 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
         models = tags["models"]
         if isinstance(models, list):
             model_names = [model["model"] for model in models]
-            print(f"model_names: {model_names}")
             available_supported_models = [
                 model
                 for model in model_names
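The new get_model_and_provider helper pairs with the supports_data_gen flag; the data-gen tests above use it to skip providers that cannot reliably produce structured synthetic data. A condensed sketch of that pattern follows; the model and provider names are taken from elsewhere in this diff and the snippet is illustrative only.

    from kiln_ai.adapters.ml_model_list import get_model_and_provider

    # Returns (None, None) if either the model or the provider name is unknown.
    model, provider = get_model_and_provider("llama_3_1_8b", "ollama")
    if provider is not None and not provider.supports_data_gen:
        # Same check the data-gen tests above perform before invoking a model.
        print(f"{model.name}/{provider.name} is not suitable for data generation")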
kiln_ai/adapters/prompt_builders.py
@@ -54,6 +54,28 @@ class BasePromptBuilder(metaclass=ABCMeta):
 
         return f"The input is:\n{input}"
 
+    def chain_of_thought_prompt(self) -> str | None:
+        """Build and return the chain of thought prompt string.
+
+        Returns:
+            str: The constructed chain of thought prompt.
+        """
+        return None
+
+    def build_prompt_for_ui(self) -> str:
+        """Build a prompt for the UI. It includes additional instructions (like chain of thought), even if they are passed to the model in stages.
+
+        Designed for end-user consumption, not for model consumption.
+
+        Returns:
+            str: The constructed prompt string.
+        """
+        base_prompt = self.build_prompt()
+        cot_prompt = self.chain_of_thought_prompt()
+        if cot_prompt:
+            base_prompt += "\n# Thinking Instructions\n\n" + cot_prompt
+        return base_prompt
+
 
 class SimplePromptBuilder(BasePromptBuilder):
     """A basic prompt builder that combines task instruction with requirements."""
@@ -187,11 +209,49 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
         return prompt_section
 
 
+def chain_of_thought_prompt(task: Task) -> str | None:
+    """Standard implementation to build and return the chain of thought prompt string.
+
+    Returns:
+        str: The constructed chain of thought prompt.
+    """
+
+    cot_instruction = task.thinking_instruction
+    if not cot_instruction:
+        cot_instruction = "Think step by step, explaining your reasoning."
+
+    return cot_instruction
+
+
+class SimpleChainOfThoughtPromptBuilder(SimplePromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the simple prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
+class FewShotChainOfThoughtPromptBuilder(FewShotPromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the few shot prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
+class MultiShotChainOfThoughtPromptBuilder(MultiShotPromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the multi shot prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
 prompt_builder_registry = {
     "simple_prompt_builder": SimplePromptBuilder,
     "multi_shot_prompt_builder": MultiShotPromptBuilder,
     "few_shot_prompt_builder": FewShotPromptBuilder,
     "repairs_prompt_builder": RepairsPromptBuilder,
+    "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
+    "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
+    "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
 }
 
 
@@ -217,5 +277,11 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
             return MultiShotPromptBuilder
         case "repairs":
             return RepairsPromptBuilder
+        case "simple_chain_of_thought":
+            return SimpleChainOfThoughtPromptBuilder
+        case "few_shot_chain_of_thought":
+            return FewShotChainOfThoughtPromptBuilder
+        case "multi_shot_chain_of_thought":
+            return MultiShotChainOfThoughtPromptBuilder
         case _:
             raise ValueError(f"Unknown prompt builder: {ui_name}")
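The three chain-of-thought builders are exposed both through the prompt_builder_registry keys and through the shorter UI names handled by prompt_builder_from_ui_name. A small usage sketch follows; the Task mirrors the fixture in test_data_gen_task.py above, and constructing the builder with the task is an assumption based on the self.task references in the new classes.

    from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
    from kiln_ai.datamodel import Project, Task

    task = Task(
        name="Cowboy Speaker",
        parent=Project(name="TestProject"),
        description="Reply like a cowboy",
        instruction="Reply like a cowboy",
        requirements=[],
    )

    # Resolve the builder class from its UI name, then render the combined prompt,
    # which now appends a "# Thinking Instructions" section when a chain-of-thought
    # prompt is present.
    builder_class = prompt_builder_from_ui_name("simple_chain_of_thought")
    builder = builder_class(task)
    print(builder.build_prompt_for_ui())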
kiln_ai/adapters/repair/test_repair_task.py
@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, patch
 import pytest
 from pydantic import ValidationError
 
+from kiln_ai.adapters.base_adapter import RunOutput
 from kiln_ai.adapters.langchain_adapters import (
     LangChainPromptAdapter,
 )
@@ -222,7 +223,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     with patch.object(
         LangChainPromptAdapter, "_run", new_callable=AsyncMock
     ) as mock_run:
-        mock_run.return_value =
+        mock_run.return_value = RunOutput(
+            output=mocked_output, intermediate_outputs=None
+        )
 
         adapter = LangChainPromptAdapter(
             repair_task, model_name="llama_3_1_8b", provider="groq"
|