kiln-ai 0.0.4__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares the contents of publicly released versions of this package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/base_adapter.py +168 -0
- kiln_ai/adapters/langchain_adapters.py +113 -0
- kiln_ai/adapters/ml_model_list.py +436 -0
- kiln_ai/adapters/prompt_builders.py +122 -0
- kiln_ai/adapters/repair/repair_task.py +71 -0
- kiln_ai/adapters/repair/test_repair_task.py +248 -0
- kiln_ai/adapters/test_langchain_adapter.py +50 -0
- kiln_ai/adapters/test_ml_model_list.py +99 -0
- kiln_ai/adapters/test_prompt_adaptors.py +167 -0
- kiln_ai/adapters/test_prompt_builders.py +315 -0
- kiln_ai/adapters/test_saving_adapter_results.py +168 -0
- kiln_ai/adapters/test_structured_output.py +218 -0
- kiln_ai/datamodel/__init__.py +362 -2
- kiln_ai/datamodel/basemodel.py +372 -0
- kiln_ai/datamodel/json_schema.py +45 -0
- kiln_ai/datamodel/test_basemodel.py +277 -0
- kiln_ai/datamodel/test_datasource.py +107 -0
- kiln_ai/datamodel/test_example_models.py +644 -0
- kiln_ai/datamodel/test_json_schema.py +124 -0
- kiln_ai/datamodel/test_models.py +190 -0
- kiln_ai/datamodel/test_nested_save.py +205 -0
- kiln_ai/datamodel/test_output_rating.py +88 -0
- kiln_ai/utils/config.py +170 -0
- kiln_ai/utils/formatting.py +5 -0
- kiln_ai/utils/test_config.py +245 -0
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.1.dist-info}/METADATA +22 -1
- kiln_ai-0.5.1.dist-info/RECORD +29 -0
- kiln_ai/__init.__.py +0 -3
- kiln_ai/coreadd.py +0 -3
- kiln_ai/datamodel/project.py +0 -15
- kiln_ai-0.0.4.dist-info/RECORD +0 -8
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.1.dist-info}/LICENSE.txt +0 -0
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.1.dist-info}/WHEEL +0 -0
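
The new adapter, prompt-builder, and datamodel modules listed above are exercised by the test files reproduced below. As a quick orientation, here is a minimal usage sketch assembled only from calls that appear in those tests; the model and provider names are illustrative, and exact signatures should be confirmed against the released source:

# Sketch assembled from the tests in this diff; not an official example.
import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.adapters.prompt_builders import SimplePromptBuilder


async def tell_a_joke(tmp_path):
    # Projects and tasks are datamodel objects persisted to .kiln files.
    project = datamodel.Project(name="demo", path=tmp_path / "demo.kiln")
    project.save_to_file()
    task = datamodel.Task(
        parent=project,
        name="joke task",
        instruction="You are an assistant which tells a joke, given a subject.",
    )
    task.save_to_file()

    # Prompt builders turn a task into a system prompt plus a user message.
    builder = SimplePromptBuilder(task=task)
    print(builder.build_prompt())
    print(builder.build_user_message("Cows"))

    # The LangChain adapter runs the task against a model/provider pair;
    # "llama_3_1_8b" on "ollama" is just one pairing used in the tests.
    adapter = LangChainPromptAdapter(task, model_name="llama_3_1_8b", provider="ollama")
    return await adapter.invoke_returning_raw("Cows")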
kiln_ai/adapters/test_prompt_builders.py
@@ -0,0 +1,315 @@
+import json
+
+import pytest
+from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.prompt_builders import (
+    FewShotPromptBuilder,
+    MultiShotPromptBuilder,
+    SimplePromptBuilder,
+    prompt_builder_from_ui_name,
+)
+from kiln_ai.adapters.test_prompt_adaptors import build_test_task
+from kiln_ai.adapters.test_structured_output import build_structured_output_test_task
+from kiln_ai.datamodel import (
+    DataSource,
+    DataSourceType,
+    Project,
+    Task,
+    TaskOutput,
+    TaskOutputRating,
+    TaskRun,
+)
+
+
+def test_simple_prompt_builder(tmp_path):
+    task = build_test_task(tmp_path)
+    builder = SimplePromptBuilder(task=task)
+    input = "two plus two"
+    prompt = builder.build_prompt()
+    assert (
+        "You are an assistant which performs math tasks provided in plain text."
+        in prompt
+    )
+
+    assert "1) " + task.requirements[0].instruction in prompt
+    assert "2) " + task.requirements[1].instruction in prompt
+    assert "3) " + task.requirements[2].instruction in prompt
+
+    user_msg = builder.build_user_message(input)
+    assert input in user_msg
+    assert input not in prompt
+
+
+class MockAdapter(BaseAdapter):
+    def adapter_specific_instructions(self) -> str | None:
+        return "You are a mock, send me the response!"
+
+    def _run(self, input: str) -> str:
+        return "mock response"
+
+    def adapter_info(self) -> AdapterInfo:
+        return AdapterInfo(
+            adapter_name="mock_adapter",
+            model_name="mock_model",
+            model_provider="mock_provider",
+        )
+
+
+def test_simple_prompt_builder_structured_output(tmp_path):
+    task = build_structured_output_test_task(tmp_path)
+    builder = SimplePromptBuilder(task=task)
+    input = "Cows"
+    prompt = builder.build_prompt()
+    assert "You are an assistant which tells a joke, given a subject." in prompt
+
+    # check adapter instructions are included
+    run_adapter = MockAdapter(task, prompt_builder=builder)
+    assert "You are a mock, send me the response!" in run_adapter.build_prompt()
+
+    user_msg = builder.build_user_message(input)
+    assert input in user_msg
+    assert input not in prompt
+
+
+def test_multi_shot_prompt_builder(tmp_path):
+    # Create a project and task hierarchy
+    project = Project(name="Test Project", path=(tmp_path / "test_project.kiln"))
+    project.save_to_file()
+    task = Task(
+        name="Test Task",
+        instruction="You are an assistant which tells a joke, given a subject.",
+        parent=project,
+        input_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {
+                    "subject": {"type": "string"},
+                },
+                "required": ["subject"],
+            }
+        ),
+        output_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {"joke": {"type": "string"}},
+                "required": ["joke"],
+            }
+        ),
+    )
+    task.save_to_file()
+
+    check_example_outputs(task, 0)
+
+    # Create an task input, but with no output
+    e1 = TaskRun(
+        input='{"subject": "Cows"}',
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        parent=task,
+        output=TaskOutput(
+            output='{"joke": "Moo I am a cow joke."}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+    )
+    e1.save_to_file()
+
+    ## still zero since not fixed and not rated highly
+    check_example_outputs(task, 0)
+
+    e1.output.rating = TaskOutputRating(value=4)
+    e1.save_to_file()
+    # Now that it's highly rated, it should be included
+    check_example_outputs(task, 1)
+
+    # Test with repaired output (highest priority)
+    e1 = TaskRun(
+        input='{"subject": "Cows"}',
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        parent=task,
+        output=TaskOutput(
+            output='{"joke": "Moo I am a cow joke."}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+        repair_instructions="Fix the joke",
+        repaired_output=TaskOutput(
+            output='{"joke": "Why did the cow cross the road? To get to the udder side!"}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "jane_doe"},
+            ),
+        ),
+    )
+    e1.save_to_file()
+    check_example_outputs(task, 1)
+
+    # Test with high-quality output (second priority)
+    e2 = TaskRun(
+        input='{"subject": "Dogs"}',
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        parent=task,
+        output=TaskOutput(
+            output='{"joke": "Why did the dog get a job? He wanted to be a collar-ary!"}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            rating=TaskOutputRating(value=4, reason="Good pun"),
+        ),
+    )
+    e2.save_to_file()
+    check_example_outputs(task, 2)
+
+    # Test sorting by rating value
+    e3 = TaskRun(
+        input='{"subject": "Cats"}',
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        parent=task,
+        output=TaskOutput(
+            output='{"joke": "Why don\'t cats play poker in the jungle? Too many cheetahs!"}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            rating=TaskOutputRating(value=5, reason="Excellent joke"),
+        ),
+    )
+    e3.save_to_file()
+    check_example_outputs(task, 3)
+
+    # Verify the order of examples
+    prompt_builder = MultiShotPromptBuilder(task=task)
+    prompt = prompt_builder.build_prompt()
+    assert "Why did the cow cross the road?" in prompt
+    assert prompt.index("Why did the cow cross the road?") < prompt.index(
+        "Why don't cats play poker in the jungle?"
+    )
+    assert prompt.index("Why don't cats play poker in the jungle?") < prompt.index(
+        "Why did the dog get a job?"
+    )
+
+
+# Add a new test for the FewShotPromptBuilder
+def test_few_shot_prompt_builder(tmp_path):
+    # Create a project and task hierarchy (similar to test_multi_shot_prompt_builder)
+    project = Project(name="Test Project", path=(tmp_path / "test_project.kiln"))
+    project.save_to_file()
+    task = Task(
+        name="Test Task",
+        instruction="You are an assistant which tells a joke, given a subject.",
+        parent=project,
+        input_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {
+                    "subject": {"type": "string"},
+                },
+                "required": ["subject"],
+            }
+        ),
+        output_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {"joke": {"type": "string"}},
+                "required": ["joke"],
+            }
+        ),
+    )
+    task.save_to_file()
+
+    # Create 6 examples (2 repaired, 4 high-quality)
+    for i in range(6):
+        run = TaskRun(
+            input=f'{{"subject": "Subject {i+1}"}}',
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            parent=task,
+            output=TaskOutput(
+                output=f'{{"joke": "Joke Initial Output {i+1}"}}',
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+                rating=TaskOutputRating(value=4 + (i % 2), reason="Good joke"),
+            ),
+        )
+        print("RATING", "Joke Initial Output ", i + 1, " - RATED:", 4 + (i % 2), "\n")
+        if i < 2:
+            run = run.model_copy(
+                update={
+                    "repair_instructions": "Fix the joke",
+                    "repaired_output": TaskOutput(
+                        output=f'{{"joke": "Repaired Joke {i+1}"}}',
+                        source=DataSource(
+                            type=DataSourceType.human,
+                            properties={"created_by": "jane_doe"},
+                        ),
+                    ),
+                }
+            )
+        run.save_to_file()
+
+    # Check that only 4 examples are included
+    prompt_builder = FewShotPromptBuilder(task=task)
+    prompt = prompt_builder.build_prompt()
+    assert prompt.count("## Example") == 4
+
+    print("PROMPT", prompt)
+    # Verify the order of examples (2 repaired, then 2 highest-rated)
+    assert "Repaired Joke 1" in prompt
+    assert "Repaired Joke 2" in prompt
+    assert "Joke Initial Output 6" in prompt  # Rating 5
+    assert "Joke Initial Output 4" in prompt  # Rating 5
+    assert "Joke Initial Output 5" not in prompt  # Rating 4, not included
+    assert "Joke Initial Output 3" not in prompt  # Rating 4, not included
+    assert "Joke Initial Output 1" not in prompt  # Repaired, so using that
+    assert "Joke Initial Output 2" not in prompt  # Repaired, so using that
+
+
+def check_example_outputs(task: Task, count: int):
+    prompt_builder = MultiShotPromptBuilder(task=task)
+    prompt = prompt_builder.build_prompt()
+    assert "# Instruction" in prompt
+    assert task.instruction in prompt
+    if count == 0:
+        assert "# Example Outputs" not in prompt
+    else:
+        assert "# Example Outputs" in prompt
+        assert f"## Example {count}" in prompt
+
+
+def test_prompt_builder_name():
+    assert SimplePromptBuilder.prompt_builder_name() == "simple_prompt_builder"
+    assert MultiShotPromptBuilder.prompt_builder_name() == "multi_shot_prompt_builder"
+
+
+def test_prompt_builder_from_ui_name():
+    assert prompt_builder_from_ui_name("basic") == SimplePromptBuilder
+    assert prompt_builder_from_ui_name("few_shot") == FewShotPromptBuilder
+    assert prompt_builder_from_ui_name("many_shot") == MultiShotPromptBuilder
+
+    with pytest.raises(ValueError, match="Unknown prompt builder: invalid_name"):
+        prompt_builder_from_ui_name("invalid_name")
+
+
+def test_example_count():
+    assert FewShotPromptBuilder.example_count() == 4
+    assert MultiShotPromptBuilder.example_count() == 25
kiln_ai/adapters/test_saving_adapter_results.py
@@ -0,0 +1,168 @@
+from unittest.mock import patch
+
+import pytest
+from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.datamodel import (
+    DataSource,
+    DataSourceType,
+    Project,
+    Task,
+)
+from kiln_ai.utils.config import Config
+
+
+class MockAdapter(BaseAdapter):
+    async def _run(self, input: dict | str) -> dict | str:
+        return "Test output"
+
+    def adapter_info(self) -> AdapterInfo:
+        return AdapterInfo(
+            adapter_name="mock_adapter",
+            model_name="mock_model",
+            model_provider="mock_provider",
+            prompt_builder_name="mock_prompt_builder",
+        )
+
+
+@pytest.fixture
+def test_task(tmp_path):
+    project = Project(name="test_project", path=tmp_path / "test_project.kiln")
+    project.save_to_file()
+    task = Task(
+        parent=project,
+        name="test_task",
+        instruction="Task instruction",
+    )
+    task.save_to_file()
+    return task
+
+
+def test_save_run_isolation(test_task):
+    adapter = MockAdapter(test_task)
+    input_data = "Test input"
+    output_data = "Test output"
+
+    task_run = adapter.generate_run(
+        input=input_data, input_source=None, output=output_data
+    )
+    task_run.save_to_file()
+
+    # Check that the task input was saved correctly
+    assert task_run.parent == test_task
+    assert task_run.input == input_data
+    assert task_run.input_source.type == DataSourceType.human
+    created_by = Config.shared().user_id
+    if created_by and created_by != "":
+        assert task_run.input_source.properties["created_by"] == created_by
+    else:
+        assert "created_by" not in task_run.input_source.properties
+
+    # Check that the task output was saved correctly
+    saved_output = task_run.output
+    assert saved_output.output == output_data
+    assert saved_output.source.type == DataSourceType.synthetic
+    assert saved_output.rating is None
+
+    # Verify that the data can be read back from disk
+    reloaded_task = Task.load_from_file(test_task.path)
+    reloaded_runs = reloaded_task.runs()
+    assert len(reloaded_runs) == 1
+    reloaded_run = reloaded_runs[0]
+    assert reloaded_run.input == input_data
+    assert reloaded_run.input_source.type == DataSourceType.human
+    reloaded_output = reloaded_run.output
+
+    reloaded_output = reloaded_run.output
+    assert reloaded_output.output == output_data
+    assert reloaded_output.source.type == DataSourceType.synthetic
+    assert reloaded_output.rating is None
+    assert reloaded_output.source.properties["adapter_name"] == "mock_adapter"
+    assert reloaded_output.source.properties["model_name"] == "mock_model"
+    assert reloaded_output.source.properties["model_provider"] == "mock_provider"
+    assert (
+        reloaded_output.source.properties["prompt_builder_name"]
+        == "mock_prompt_builder"
+    )
+
+    # Run again, with same input and different output. Should create a new TaskRun.
+    task_output = adapter.generate_run(input_data, None, "Different output")
+    task_output.save_to_file()
+    assert len(test_task.runs()) == 2
+    assert "Different output" in set(run.output.output for run in test_task.runs())
+
+    # run again with same input and same output. Should not create a new TaskRun.
+    task_output = adapter.generate_run(input_data, None, output_data)
+    task_output.save_to_file()
+    assert len(test_task.runs()) == 2
+    assert "Different output" in set(run.output.output for run in test_task.runs())
+    assert output_data in set(run.output.output for run in test_task.runs())
+
+    # run again with input of different type. Should create a new TaskRun and TaskOutput.
+    task_output = adapter.generate_run(
+        input_data,
+        DataSource(
+            type=DataSourceType.synthetic,
+            properties={
+                "model_name": "mock_model",
+                "model_provider": "mock_provider",
+                "prompt_builder_name": "mock_prompt_builder",
+                "adapter_name": "mock_adapter",
+            },
+        ),
+        output_data,
+    )
+    task_output.save_to_file()
+    assert len(test_task.runs()) == 3
+    assert task_output.input == input_data
+    assert task_output.input_source.type == DataSourceType.synthetic
+    assert "Different output" in set(run.output.output for run in test_task.runs())
+    assert output_data in set(run.output.output for run in test_task.runs())
+
+
+@pytest.mark.asyncio
+async def test_autosave_false(test_task):
+    with patch("kiln_ai.utils.config.Config.shared") as mock_shared:
+        mock_config = mock_shared.return_value
+        mock_config.autosave_runs = False
+        mock_config.user_id = "test_user"
+
+        adapter = MockAdapter(test_task)
+        input_data = "Test input"
+
+        run = await adapter.invoke(input_data)
+
+        # Check that no runs were saved
+        assert len(test_task.runs()) == 0
+
+        # Check that the run ID is not set
+        assert run.id is None
+
+
+@pytest.mark.asyncio
+async def test_autosave_true(test_task):
+    with patch("kiln_ai.utils.config.Config.shared") as mock_shared:
+        mock_config = mock_shared.return_value
+        mock_config.autosave_runs = True
+        mock_config.user_id = "test_user"
+
+        adapter = MockAdapter(test_task)
+        input_data = "Test input"
+
+        run = await adapter.invoke(input_data)
+
+        # Check that the run ID is set
+        assert run.id is not None
+
+        # Check that an task input was saved
+        task_runs = test_task.runs()
+        assert len(task_runs) == 1
+        assert task_runs[0].input == input_data
+        assert task_runs[0].input_source.type == DataSourceType.human
+
+        output = task_runs[0].output
+        assert output.output == "Test output"
+        assert output.source.type == DataSourceType.synthetic
+        assert output.source.properties["adapter_name"] == "mock_adapter"
+        assert output.source.properties["model_name"] == "mock_model"
+        assert output.source.properties["model_provider"] == "mock_provider"
+        assert output.source.properties["prompt_builder_name"] == "mock_prompt_builder"
kiln_ai/adapters/test_structured_output.py
@@ -0,0 +1,218 @@
+from pathlib import Path
+from typing import Dict
+
+import jsonschema
+import jsonschema.exceptions
+import kiln_ai.datamodel as datamodel
+import pytest
+from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.ml_model_list import (
+    built_in_models,
+    ollama_online,
+)
+from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema
+
+
+@pytest.mark.paid
+async def test_structured_output_groq(tmp_path):
+    await run_structured_output_test(tmp_path, "llama_3_1_8b", "groq")
+
+
+@pytest.mark.paid
+async def test_structured_output_openrouter(tmp_path):
+    await run_structured_output_test(tmp_path, "mistral_nemo", "openrouter")
+
+
+@pytest.mark.paid
+async def test_structured_output_bedrock(tmp_path):
+    await run_structured_output_test(tmp_path, "llama_3_1_70b", "amazon_bedrock")
+
+
+@pytest.mark.ollama
+async def test_structured_output_ollama_phi(tmp_path):
+    # https://python.langchain.com/v0.2/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs
+    pytest.skip(
+        "not working yet - phi3.5 does not support tools. Need json_mode + format in prompt"
+    )
+    await run_structured_output_test(tmp_path, "phi_3_5", "ollama")
+
+
+@pytest.mark.ollama
+async def test_structured_output_gpt_4o_mini(tmp_path):
+    await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")
+
+
+@pytest.mark.ollama
+async def test_structured_output_ollama_llama(tmp_path):
+    if not await ollama_online():
+        pytest.skip("Ollama API not running. Expect it running on localhost:11434")
+    await run_structured_output_test(tmp_path, "llama_3_1_8b", "ollama")
+
+
+class MockAdapter(BaseAdapter):
+    def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
+        super().__init__(kiln_task)
+        self.response = response
+
+    async def _run(self, input: str) -> Dict | str:
+        return self.response
+
+    def adapter_info(self) -> AdapterInfo:
+        return AdapterInfo(
+            adapter_name="mock_adapter",
+            model_name="mock_model",
+            model_provider="mock_provider",
+            prompt_builder_name="mock_prompt_builder",
+        )
+
+
+async def test_mock_unstructred_response(tmp_path):
+    task = build_structured_output_test_task(tmp_path)
+
+    # don't error on valid response
+    adapter = MockAdapter(task, response={"setup": "asdf", "punchline": "asdf"})
+    answer = await adapter.invoke_returning_raw("You are a mock, send me the response!")
+    assert answer["setup"] == "asdf"
+    assert answer["punchline"] == "asdf"
+
+    # error on response that doesn't match schema
+    adapter = MockAdapter(task, response={"setup": "asdf"})
+    with pytest.raises(Exception):
+        answer = await adapter.invoke("You are a mock, send me the response!")
+
+    adapter = MockAdapter(task, response="string instead of dict")
+    with pytest.raises(RuntimeError):
+        # Not a structed response so should error
+        answer = await adapter.invoke("You are a mock, send me the response!")
+
+    # Should error, expecting a string, not a dict
+    project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
+    task = datamodel.Task(
+        parent=project,
+        name="test task",
+        instruction="You are an assistant which performs math tasks provided in plain text.",
+    )
+    task.instruction = (
+        "You are an assistant which performs math tasks provided in plain text."
+    )
+    adapter = MockAdapter(task, response={"dict": "value"})
+    with pytest.raises(RuntimeError):
+        answer = await adapter.invoke("You are a mock, send me the response!")
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+async def test_all_built_in_models_structured_output(tmp_path):
+    for model in built_in_models:
+        if not model.supports_structured_output:
+            print(
+                f"Skipping {model.name} because it does not support structured output"
+            )
+            continue
+        for provider in model.providers:
+            if not provider.supports_structured_output:
+                print(
+                    f"Skipping {model.name} {provider.name} because it does not support structured output"
+                )
+                continue
+            try:
+                print(f"Running {model.name} {provider.name}")
+                await run_structured_output_test(tmp_path, model.name, provider.name)
+            except Exception as e:
+                raise RuntimeError(f"Error running {model.name} {provider}") from e
+
+
+def build_structured_output_test_task(tmp_path: Path):
+    project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
+    project.save_to_file()
+    task = datamodel.Task(
+        parent=project,
+        name="test task",
+        instruction="You are an assistant which tells a joke, given a subject.",
+    )
+    task.output_json_schema = json_joke_schema
+    schema = task.output_schema()
+    assert schema is not None
+    assert schema["properties"]["setup"]["type"] == "string"
+    assert schema["properties"]["punchline"]["type"] == "string"
+    task.save_to_file()
+    assert task.name == "test task"
+    assert len(task.requirements) == 0
+    return task
+
+
+async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
+    task = build_structured_output_test_task(tmp_path)
+    a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
+    parsed = await a.invoke_returning_raw("Cows")  # a joke about cows
+    if parsed is None or not isinstance(parsed, Dict):
+        raise RuntimeError(f"structured response is not a dict: {parsed}")
+    assert parsed["setup"] is not None
+    assert parsed["punchline"] is not None
+    if "rating" in parsed and parsed["rating"] is not None:
+        rating = parsed["rating"]
+        # Note: really should be an int according to json schema, but mistral returns a string
+        if isinstance(rating, str):
+            rating = int(rating)
+        assert rating >= 0
+        assert rating <= 10
+
+
+def build_structured_input_test_task(tmp_path: Path):
+    project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
+    project.save_to_file()
+    task = datamodel.Task(
+        parent=project,
+        name="test task",
+        instruction="You are an assistant which classifies a triangle given the lengths of its sides. If all sides are of equal length, the triangle is equilateral. If two sides are equal, the triangle is isosceles. Otherwise, it is scalene.\n\nAt the end of your response return the result in double square brackets. It should be plain text. It should be exactly one of the three following strings: '[[equilateral]]', or '[[isosceles]]', or '[[scalene]]'.",
+    )
+    task.input_json_schema = json_triangle_schema
+    schema = task.input_schema()
+    assert schema is not None
+    assert schema["properties"]["a"]["type"] == "integer"
+    assert schema["properties"]["b"]["type"] == "integer"
+    assert schema["properties"]["c"]["type"] == "integer"
+    assert schema["required"] == ["a", "b", "c"]
+    task.save_to_file()
+    assert task.name == "test task"
+    assert len(task.requirements) == 0
+    return task
+
+
+async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
+    task = build_structured_input_test_task(tmp_path)
+    a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
+    with pytest.raises(ValueError):
+        # not structured input in dictionary
+        await a.invoke("a=1, b=2, c=3")
+    with pytest.raises(jsonschema.exceptions.ValidationError):
+        # invalid structured input
+        await a.invoke({"a": 1, "b": 2, "d": 3})
+
+    response = await a.invoke_returning_raw({"a": 2, "b": 2, "c": 2})
+    assert response is not None
+    assert isinstance(response, str)
+    assert "[[equilateral]]" in response
+    adapter_info = a.adapter_info()
+    assert adapter_info.prompt_builder_name == "SimplePromptBuilder"
+    assert adapter_info.model_name == model_name
+    assert adapter_info.model_provider == provider
+    assert adapter_info.adapter_name == "kiln_langchain_adapter"
+
+
+@pytest.mark.paid
+async def test_structured_input_gpt_4o_mini(tmp_path):
+    await run_structured_input_test(tmp_path, "llama_3_1_8b", "groq")
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+async def test_all_built_in_models_structured_input(tmp_path):
+    for model in built_in_models:
+        for provider in model.providers:
+            try:
+                print(f"Running {model.name} {provider.name}")
+                await run_structured_input_test(tmp_path, model.name, provider.name)
+            except Exception as e:
+                raise RuntimeError(f"Error running {model.name} {provider}") from e