kiln-ai 0.0.4__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

Files changed (33)
  1. kiln_ai/adapters/base_adapter.py +168 -0
  2. kiln_ai/adapters/langchain_adapters.py +113 -0
  3. kiln_ai/adapters/ml_model_list.py +436 -0
  4. kiln_ai/adapters/prompt_builders.py +122 -0
  5. kiln_ai/adapters/repair/repair_task.py +71 -0
  6. kiln_ai/adapters/repair/test_repair_task.py +248 -0
  7. kiln_ai/adapters/test_langchain_adapter.py +50 -0
  8. kiln_ai/adapters/test_ml_model_list.py +99 -0
  9. kiln_ai/adapters/test_prompt_adaptors.py +167 -0
  10. kiln_ai/adapters/test_prompt_builders.py +315 -0
  11. kiln_ai/adapters/test_saving_adapter_results.py +168 -0
  12. kiln_ai/adapters/test_structured_output.py +218 -0
  13. kiln_ai/datamodel/__init__.py +362 -2
  14. kiln_ai/datamodel/basemodel.py +372 -0
  15. kiln_ai/datamodel/json_schema.py +45 -0
  16. kiln_ai/datamodel/test_basemodel.py +277 -0
  17. kiln_ai/datamodel/test_datasource.py +107 -0
  18. kiln_ai/datamodel/test_example_models.py +644 -0
  19. kiln_ai/datamodel/test_json_schema.py +124 -0
  20. kiln_ai/datamodel/test_models.py +190 -0
  21. kiln_ai/datamodel/test_nested_save.py +205 -0
  22. kiln_ai/datamodel/test_output_rating.py +88 -0
  23. kiln_ai/utils/config.py +170 -0
  24. kiln_ai/utils/formatting.py +5 -0
  25. kiln_ai/utils/test_config.py +245 -0
  26. {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.1.dist-info}/METADATA +22 -1
  27. kiln_ai-0.5.1.dist-info/RECORD +29 -0
  28. kiln_ai/__init.__.py +0 -3
  29. kiln_ai/coreadd.py +0 -3
  30. kiln_ai/datamodel/project.py +0 -15
  31. kiln_ai-0.0.4.dist-info/RECORD +0 -8
  32. {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.1.dist-info}/LICENSE.txt +0 -0
  33. {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.1.dist-info}/WHEEL +0 -0
kiln_ai/adapters/test_prompt_builders.py
@@ -0,0 +1,315 @@
+ import json
+
+ import pytest
+ from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+ from kiln_ai.adapters.prompt_builders import (
+     FewShotPromptBuilder,
+     MultiShotPromptBuilder,
+     SimplePromptBuilder,
+     prompt_builder_from_ui_name,
+ )
+ from kiln_ai.adapters.test_prompt_adaptors import build_test_task
+ from kiln_ai.adapters.test_structured_output import build_structured_output_test_task
+ from kiln_ai.datamodel import (
+     DataSource,
+     DataSourceType,
+     Project,
+     Task,
+     TaskOutput,
+     TaskOutputRating,
+     TaskRun,
+ )
+
+
+ def test_simple_prompt_builder(tmp_path):
+     task = build_test_task(tmp_path)
+     builder = SimplePromptBuilder(task=task)
+     input = "two plus two"
+     prompt = builder.build_prompt()
+     assert (
+         "You are an assistant which performs math tasks provided in plain text."
+         in prompt
+     )
+
+     assert "1) " + task.requirements[0].instruction in prompt
+     assert "2) " + task.requirements[1].instruction in prompt
+     assert "3) " + task.requirements[2].instruction in prompt
+
+     user_msg = builder.build_user_message(input)
+     assert input in user_msg
+     assert input not in prompt
+
+
+ class MockAdapter(BaseAdapter):
+     def adapter_specific_instructions(self) -> str | None:
+         return "You are a mock, send me the response!"
+
+     def _run(self, input: str) -> str:
+         return "mock response"
+
+     def adapter_info(self) -> AdapterInfo:
+         return AdapterInfo(
+             adapter_name="mock_adapter",
+             model_name="mock_model",
+             model_provider="mock_provider",
+         )
+
+
+ def test_simple_prompt_builder_structured_output(tmp_path):
+     task = build_structured_output_test_task(tmp_path)
+     builder = SimplePromptBuilder(task=task)
+     input = "Cows"
+     prompt = builder.build_prompt()
+     assert "You are an assistant which tells a joke, given a subject." in prompt
+
+     # Check that adapter instructions are included
+     run_adapter = MockAdapter(task, prompt_builder=builder)
+     assert "You are a mock, send me the response!" in run_adapter.build_prompt()
+
+     user_msg = builder.build_user_message(input)
+     assert input in user_msg
+     assert input not in prompt
+
+
+ def test_multi_shot_prompt_builder(tmp_path):
+     # Create a project and task hierarchy
+     project = Project(name="Test Project", path=(tmp_path / "test_project.kiln"))
+     project.save_to_file()
+     task = Task(
+         name="Test Task",
+         instruction="You are an assistant which tells a joke, given a subject.",
+         parent=project,
+         input_json_schema=json.dumps(
+             {
+                 "type": "object",
+                 "properties": {
+                     "subject": {"type": "string"},
+                 },
+                 "required": ["subject"],
+             }
+         ),
+         output_json_schema=json.dumps(
+             {
+                 "type": "object",
+                 "properties": {"joke": {"type": "string"}},
+                 "required": ["joke"],
+             }
+         ),
+     )
+     task.save_to_file()
+
+     check_example_outputs(task, 0)
+
+     # Create a task run whose output is neither repaired nor rated
+     e1 = TaskRun(
+         input='{"subject": "Cows"}',
+         input_source=DataSource(
+             type=DataSourceType.human,
+             properties={"created_by": "john_doe"},
+         ),
+         parent=task,
+         output=TaskOutput(
+             output='{"joke": "Moo I am a cow joke."}',
+             source=DataSource(
+                 type=DataSourceType.human,
+                 properties={"created_by": "john_doe"},
+             ),
+         ),
+     )
+     e1.save_to_file()
+
+     # Still zero, since the output is neither repaired nor highly rated
+     check_example_outputs(task, 0)
+
+     e1.output.rating = TaskOutputRating(value=4)
+     e1.save_to_file()
+     # Now that it's highly rated, it should be included
+     check_example_outputs(task, 1)
+
+     # Test with repaired output (highest priority)
+     e1 = TaskRun(
+         input='{"subject": "Cows"}',
+         input_source=DataSource(
+             type=DataSourceType.human,
+             properties={"created_by": "john_doe"},
+         ),
+         parent=task,
+         output=TaskOutput(
+             output='{"joke": "Moo I am a cow joke."}',
+             source=DataSource(
+                 type=DataSourceType.human,
+                 properties={"created_by": "john_doe"},
+             ),
+         ),
+         repair_instructions="Fix the joke",
+         repaired_output=TaskOutput(
+             output='{"joke": "Why did the cow cross the road? To get to the udder side!"}',
+             source=DataSource(
+                 type=DataSourceType.human,
+                 properties={"created_by": "jane_doe"},
+             ),
+         ),
+     )
+     e1.save_to_file()
+     check_example_outputs(task, 1)
+
+     # Test with high-quality output (second priority)
+     e2 = TaskRun(
+         input='{"subject": "Dogs"}',
+         input_source=DataSource(
+             type=DataSourceType.human,
+             properties={"created_by": "john_doe"},
+         ),
+         parent=task,
+         output=TaskOutput(
+             output='{"joke": "Why did the dog get a job? He wanted to be a collar-ary!"}',
+             source=DataSource(
+                 type=DataSourceType.human,
+                 properties={"created_by": "john_doe"},
+             ),
+             rating=TaskOutputRating(value=4, reason="Good pun"),
+         ),
+     )
+     e2.save_to_file()
+     check_example_outputs(task, 2)
+
+     # Test sorting by rating value
+     e3 = TaskRun(
+         input='{"subject": "Cats"}',
+         input_source=DataSource(
+             type=DataSourceType.human,
+             properties={"created_by": "john_doe"},
+         ),
+         parent=task,
+         output=TaskOutput(
+             output='{"joke": "Why don\'t cats play poker in the jungle? Too many cheetahs!"}',
+             source=DataSource(
+                 type=DataSourceType.human,
+                 properties={"created_by": "john_doe"},
+             ),
+             rating=TaskOutputRating(value=5, reason="Excellent joke"),
+         ),
+     )
+     e3.save_to_file()
+     check_example_outputs(task, 3)
+
+     # Verify the order of examples
+     prompt_builder = MultiShotPromptBuilder(task=task)
+     prompt = prompt_builder.build_prompt()
+     assert "Why did the cow cross the road?" in prompt
+     assert prompt.index("Why did the cow cross the road?") < prompt.index(
+         "Why don't cats play poker in the jungle?"
+     )
+     assert prompt.index("Why don't cats play poker in the jungle?") < prompt.index(
+         "Why did the dog get a job?"
+     )
+
+
+ # Add a new test for the FewShotPromptBuilder
+ def test_few_shot_prompt_builder(tmp_path):
+     # Create a project and task hierarchy (similar to test_multi_shot_prompt_builder)
+     project = Project(name="Test Project", path=(tmp_path / "test_project.kiln"))
+     project.save_to_file()
+     task = Task(
+         name="Test Task",
+         instruction="You are an assistant which tells a joke, given a subject.",
+         parent=project,
+         input_json_schema=json.dumps(
+             {
+                 "type": "object",
+                 "properties": {
+                     "subject": {"type": "string"},
+                 },
+                 "required": ["subject"],
+             }
+         ),
+         output_json_schema=json.dumps(
+             {
+                 "type": "object",
+                 "properties": {"joke": {"type": "string"}},
+                 "required": ["joke"],
+             }
+         ),
+     )
+     task.save_to_file()
+
+     # Create 6 examples (2 repaired, 4 high-quality)
+     for i in range(6):
+         run = TaskRun(
+             input=f'{{"subject": "Subject {i+1}"}}',
+             input_source=DataSource(
+                 type=DataSourceType.human,
+                 properties={"created_by": "john_doe"},
+             ),
+             parent=task,
+             output=TaskOutput(
+                 output=f'{{"joke": "Joke Initial Output {i+1}"}}',
+                 source=DataSource(
+                     type=DataSourceType.human,
+                     properties={"created_by": "john_doe"},
+                 ),
+                 rating=TaskOutputRating(value=4 + (i % 2), reason="Good joke"),
+             ),
+         )
+         print("RATING", "Joke Initial Output ", i + 1, " - RATED:", 4 + (i % 2), "\n")
+         if i < 2:
+             run = run.model_copy(
+                 update={
+                     "repair_instructions": "Fix the joke",
+                     "repaired_output": TaskOutput(
+                         output=f'{{"joke": "Repaired Joke {i+1}"}}',
+                         source=DataSource(
+                             type=DataSourceType.human,
+                             properties={"created_by": "jane_doe"},
+                         ),
+                     ),
+                 }
+             )
+         run.save_to_file()
+
+     # Check that only 4 examples are included
+     prompt_builder = FewShotPromptBuilder(task=task)
+     prompt = prompt_builder.build_prompt()
+     assert prompt.count("## Example") == 4
+
+     print("PROMPT", prompt)
+     # Verify the order of examples (2 repaired, then 2 highest-rated)
+     assert "Repaired Joke 1" in prompt
+     assert "Repaired Joke 2" in prompt
+     assert "Joke Initial Output 6" in prompt  # Rating 5
+     assert "Joke Initial Output 4" in prompt  # Rating 5
+     assert "Joke Initial Output 5" not in prompt  # Rating 4, not included
+     assert "Joke Initial Output 3" not in prompt  # Rating 4, not included
+     assert "Joke Initial Output 1" not in prompt  # Repaired output used instead
+     assert "Joke Initial Output 2" not in prompt  # Repaired output used instead
+
+
+ def check_example_outputs(task: Task, count: int):
+     prompt_builder = MultiShotPromptBuilder(task=task)
+     prompt = prompt_builder.build_prompt()
+     assert "# Instruction" in prompt
+     assert task.instruction in prompt
+     if count == 0:
+         assert "# Example Outputs" not in prompt
+     else:
+         assert "# Example Outputs" in prompt
+         assert f"## Example {count}" in prompt
+
+
+ def test_prompt_builder_name():
+     assert SimplePromptBuilder.prompt_builder_name() == "simple_prompt_builder"
+     assert MultiShotPromptBuilder.prompt_builder_name() == "multi_shot_prompt_builder"
+
+
+ def test_prompt_builder_from_ui_name():
+     assert prompt_builder_from_ui_name("basic") == SimplePromptBuilder
+     assert prompt_builder_from_ui_name("few_shot") == FewShotPromptBuilder
+     assert prompt_builder_from_ui_name("many_shot") == MultiShotPromptBuilder
+
+     with pytest.raises(ValueError, match="Unknown prompt builder: invalid_name"):
+         prompt_builder_from_ui_name("invalid_name")
+
+
+ def test_example_count():
+     assert FewShotPromptBuilder.example_count() == 4
+     assert MultiShotPromptBuilder.example_count() == 25
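
The tests above exercise the new prompt-builder API: SimplePromptBuilder, FewShotPromptBuilder, MultiShotPromptBuilder, and the prompt_builder_from_ui_name lookup ("basic", "few_shot", "many_shot"). As a minimal usage sketch inferred from those tests (not part of the diff; the project and task names here are illustrative):

from pathlib import Path

from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
from kiln_ai.datamodel import Project, Task

# Build a project/task pair on disk, the same way the tests do.
project = Project(name="Demo Project", path=Path("demo_project.kiln"))
project.save_to_file()
task = Task(
    name="Joke Task",
    instruction="You are an assistant which tells a joke, given a subject.",
    parent=project,
)
task.save_to_file()

# Look up a builder class by its UI name and build the prompt pieces.
builder_class = prompt_builder_from_ui_name("few_shot")
builder = builder_class(task=task)
system_prompt = builder.build_prompt()         # includes up to 4 saved examples
user_msg = builder.build_user_message("Cows")  # wraps the task input
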
kiln_ai/adapters/test_saving_adapter_results.py
@@ -0,0 +1,168 @@
+ from unittest.mock import patch
+
+ import pytest
+ from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+ from kiln_ai.datamodel import (
+     DataSource,
+     DataSourceType,
+     Project,
+     Task,
+ )
+ from kiln_ai.utils.config import Config
+
+
+ class MockAdapter(BaseAdapter):
+     async def _run(self, input: dict | str) -> dict | str:
+         return "Test output"
+
+     def adapter_info(self) -> AdapterInfo:
+         return AdapterInfo(
+             adapter_name="mock_adapter",
+             model_name="mock_model",
+             model_provider="mock_provider",
+             prompt_builder_name="mock_prompt_builder",
+         )
+
+
+ @pytest.fixture
+ def test_task(tmp_path):
+     project = Project(name="test_project", path=tmp_path / "test_project.kiln")
+     project.save_to_file()
+     task = Task(
+         parent=project,
+         name="test_task",
+         instruction="Task instruction",
+     )
+     task.save_to_file()
+     return task
+
+
+ def test_save_run_isolation(test_task):
+     adapter = MockAdapter(test_task)
+     input_data = "Test input"
+     output_data = "Test output"
+
+     task_run = adapter.generate_run(
+         input=input_data, input_source=None, output=output_data
+     )
+     task_run.save_to_file()
+
+     # Check that the task input was saved correctly
+     assert task_run.parent == test_task
+     assert task_run.input == input_data
+     assert task_run.input_source.type == DataSourceType.human
+     created_by = Config.shared().user_id
+     if created_by and created_by != "":
+         assert task_run.input_source.properties["created_by"] == created_by
+     else:
+         assert "created_by" not in task_run.input_source.properties
+
+     # Check that the task output was saved correctly
+     saved_output = task_run.output
+     assert saved_output.output == output_data
+     assert saved_output.source.type == DataSourceType.synthetic
+     assert saved_output.rating is None
+
+     # Verify that the data can be read back from disk
+     reloaded_task = Task.load_from_file(test_task.path)
+     reloaded_runs = reloaded_task.runs()
+     assert len(reloaded_runs) == 1
+     reloaded_run = reloaded_runs[0]
+     assert reloaded_run.input == input_data
+     assert reloaded_run.input_source.type == DataSourceType.human
+     reloaded_output = reloaded_run.output
+
+     reloaded_output = reloaded_run.output
+     assert reloaded_output.output == output_data
+     assert reloaded_output.source.type == DataSourceType.synthetic
+     assert reloaded_output.rating is None
+     assert reloaded_output.source.properties["adapter_name"] == "mock_adapter"
+     assert reloaded_output.source.properties["model_name"] == "mock_model"
+     assert reloaded_output.source.properties["model_provider"] == "mock_provider"
+     assert (
+         reloaded_output.source.properties["prompt_builder_name"]
+         == "mock_prompt_builder"
+     )
+
+     # Run again, with same input and different output. Should create a new TaskRun.
+     task_output = adapter.generate_run(input_data, None, "Different output")
+     task_output.save_to_file()
+     assert len(test_task.runs()) == 2
+     assert "Different output" in set(run.output.output for run in test_task.runs())
+
+     # Run again with same input and same output. Should not create a new TaskRun.
+     task_output = adapter.generate_run(input_data, None, output_data)
+     task_output.save_to_file()
+     assert len(test_task.runs()) == 2
+     assert "Different output" in set(run.output.output for run in test_task.runs())
+     assert output_data in set(run.output.output for run in test_task.runs())
+
+     # Run again with an input source of a different type. Should create a new TaskRun and TaskOutput.
+     task_output = adapter.generate_run(
+         input_data,
+         DataSource(
+             type=DataSourceType.synthetic,
+             properties={
+                 "model_name": "mock_model",
+                 "model_provider": "mock_provider",
+                 "prompt_builder_name": "mock_prompt_builder",
+                 "adapter_name": "mock_adapter",
+             },
+         ),
+         output_data,
+     )
+     task_output.save_to_file()
+     assert len(test_task.runs()) == 3
+     assert task_output.input == input_data
+     assert task_output.input_source.type == DataSourceType.synthetic
+     assert "Different output" in set(run.output.output for run in test_task.runs())
+     assert output_data in set(run.output.output for run in test_task.runs())
+
+
+ @pytest.mark.asyncio
+ async def test_autosave_false(test_task):
+     with patch("kiln_ai.utils.config.Config.shared") as mock_shared:
+         mock_config = mock_shared.return_value
+         mock_config.autosave_runs = False
+         mock_config.user_id = "test_user"
+
+         adapter = MockAdapter(test_task)
+         input_data = "Test input"
+
+         run = await adapter.invoke(input_data)
+
+         # Check that no runs were saved
+         assert len(test_task.runs()) == 0
+
+         # Check that the run ID is not set
+         assert run.id is None
+
+
+ @pytest.mark.asyncio
+ async def test_autosave_true(test_task):
+     with patch("kiln_ai.utils.config.Config.shared") as mock_shared:
+         mock_config = mock_shared.return_value
+         mock_config.autosave_runs = True
+         mock_config.user_id = "test_user"
+
+         adapter = MockAdapter(test_task)
+         input_data = "Test input"
+
+         run = await adapter.invoke(input_data)
+
+         # Check that the run ID is set
+         assert run.id is not None
+
+         # Check that a task run was saved
+         task_runs = test_task.runs()
+         assert len(task_runs) == 1
+         assert task_runs[0].input == input_data
+         assert task_runs[0].input_source.type == DataSourceType.human
+
+         output = task_runs[0].output
+         assert output.output == "Test output"
+         assert output.source.type == DataSourceType.synthetic
+         assert output.source.properties["adapter_name"] == "mock_adapter"
+         assert output.source.properties["model_name"] == "mock_model"
+         assert output.source.properties["model_provider"] == "mock_provider"
+         assert output.source.properties["prompt_builder_name"] == "mock_prompt_builder"
kiln_ai/adapters/test_structured_output.py
@@ -0,0 +1,218 @@
+ from pathlib import Path
+ from typing import Dict
+
+ import jsonschema
+ import jsonschema.exceptions
+ import kiln_ai.datamodel as datamodel
+ import pytest
+ from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+ from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+ from kiln_ai.adapters.ml_model_list import (
+     built_in_models,
+     ollama_online,
+ )
+ from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema
+
+
+ @pytest.mark.paid
+ async def test_structured_output_groq(tmp_path):
+     await run_structured_output_test(tmp_path, "llama_3_1_8b", "groq")
+
+
+ @pytest.mark.paid
+ async def test_structured_output_openrouter(tmp_path):
+     await run_structured_output_test(tmp_path, "mistral_nemo", "openrouter")
+
+
+ @pytest.mark.paid
+ async def test_structured_output_bedrock(tmp_path):
+     await run_structured_output_test(tmp_path, "llama_3_1_70b", "amazon_bedrock")
+
+
+ @pytest.mark.ollama
+ async def test_structured_output_ollama_phi(tmp_path):
+     # https://python.langchain.com/v0.2/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs
+     pytest.skip(
+         "not working yet - phi3.5 does not support tools. Need json_mode + format in prompt"
+     )
+     await run_structured_output_test(tmp_path, "phi_3_5", "ollama")
+
+
+ @pytest.mark.ollama
+ async def test_structured_output_gpt_4o_mini(tmp_path):
+     await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")
+
+
+ @pytest.mark.ollama
+ async def test_structured_output_ollama_llama(tmp_path):
+     if not await ollama_online():
+         pytest.skip("Ollama API not running. Expect it running on localhost:11434")
+     await run_structured_output_test(tmp_path, "llama_3_1_8b", "ollama")
+
+
+ class MockAdapter(BaseAdapter):
+     def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
+         super().__init__(kiln_task)
+         self.response = response
+
+     async def _run(self, input: str) -> Dict | str:
+         return self.response
+
+     def adapter_info(self) -> AdapterInfo:
+         return AdapterInfo(
+             adapter_name="mock_adapter",
+             model_name="mock_model",
+             model_provider="mock_provider",
+             prompt_builder_name="mock_prompt_builder",
+         )
+
+
+ async def test_mock_unstructred_response(tmp_path):
+     task = build_structured_output_test_task(tmp_path)
+
+     # Don't error on a valid response
+     adapter = MockAdapter(task, response={"setup": "asdf", "punchline": "asdf"})
+     answer = await adapter.invoke_returning_raw("You are a mock, send me the response!")
+     assert answer["setup"] == "asdf"
+     assert answer["punchline"] == "asdf"
+
+     # Error on a response that doesn't match the schema
+     adapter = MockAdapter(task, response={"setup": "asdf"})
+     with pytest.raises(Exception):
+         answer = await adapter.invoke("You are a mock, send me the response!")
+
+     adapter = MockAdapter(task, response="string instead of dict")
+     with pytest.raises(RuntimeError):
+         # Not a structured response, so it should error
+         answer = await adapter.invoke("You are a mock, send me the response!")
+
+     # Should error: this task expects a string output, not a dict
+     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
+     task = datamodel.Task(
+         parent=project,
+         name="test task",
+         instruction="You are an assistant which performs math tasks provided in plain text.",
+     )
+     task.instruction = (
+         "You are an assistant which performs math tasks provided in plain text."
+     )
+     adapter = MockAdapter(task, response={"dict": "value"})
+     with pytest.raises(RuntimeError):
+         answer = await adapter.invoke("You are a mock, send me the response!")
+
+
+ @pytest.mark.paid
+ @pytest.mark.ollama
+ async def test_all_built_in_models_structured_output(tmp_path):
+     for model in built_in_models:
+         if not model.supports_structured_output:
+             print(
+                 f"Skipping {model.name} because it does not support structured output"
+             )
+             continue
+         for provider in model.providers:
+             if not provider.supports_structured_output:
+                 print(
+                     f"Skipping {model.name} {provider.name} because it does not support structured output"
+                 )
+                 continue
+             try:
+                 print(f"Running {model.name} {provider.name}")
+                 await run_structured_output_test(tmp_path, model.name, provider.name)
+             except Exception as e:
+                 raise RuntimeError(f"Error running {model.name} {provider}") from e
+
+
+ def build_structured_output_test_task(tmp_path: Path):
+     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
+     project.save_to_file()
+     task = datamodel.Task(
+         parent=project,
+         name="test task",
+         instruction="You are an assistant which tells a joke, given a subject.",
+     )
+     task.output_json_schema = json_joke_schema
+     schema = task.output_schema()
+     assert schema is not None
+     assert schema["properties"]["setup"]["type"] == "string"
+     assert schema["properties"]["punchline"]["type"] == "string"
+     task.save_to_file()
+     assert task.name == "test task"
+     assert len(task.requirements) == 0
+     return task
+
+
+ async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
+     task = build_structured_output_test_task(tmp_path)
+     a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
+     parsed = await a.invoke_returning_raw("Cows")  # a joke about cows
+     if parsed is None or not isinstance(parsed, Dict):
+         raise RuntimeError(f"structured response is not a dict: {parsed}")
+     assert parsed["setup"] is not None
+     assert parsed["punchline"] is not None
+     if "rating" in parsed and parsed["rating"] is not None:
+         rating = parsed["rating"]
+         # Note: should be an int according to the JSON schema, but Mistral returns a string
+         if isinstance(rating, str):
+             rating = int(rating)
+         assert rating >= 0
+         assert rating <= 10
+
+
+ def build_structured_input_test_task(tmp_path: Path):
+     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
+     project.save_to_file()
+     task = datamodel.Task(
+         parent=project,
+         name="test task",
+         instruction="You are an assistant which classifies a triangle given the lengths of its sides. If all sides are of equal length, the triangle is equilateral. If two sides are equal, the triangle is isosceles. Otherwise, it is scalene.\n\nAt the end of your response return the result in double square brackets. It should be plain text. It should be exactly one of the three following strings: '[[equilateral]]', or '[[isosceles]]', or '[[scalene]]'.",
+     )
+     task.input_json_schema = json_triangle_schema
+     schema = task.input_schema()
+     assert schema is not None
+     assert schema["properties"]["a"]["type"] == "integer"
+     assert schema["properties"]["b"]["type"] == "integer"
+     assert schema["properties"]["c"]["type"] == "integer"
+     assert schema["required"] == ["a", "b", "c"]
+     task.save_to_file()
+     assert task.name == "test task"
+     assert len(task.requirements) == 0
+     return task
+
+
+ async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
+     task = build_structured_input_test_task(tmp_path)
+     a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
+     with pytest.raises(ValueError):
+         # Not structured input: a plain string instead of a dictionary
+         await a.invoke("a=1, b=2, c=3")
+     with pytest.raises(jsonschema.exceptions.ValidationError):
+         # Structured input that does not match the schema
+         await a.invoke({"a": 1, "b": 2, "d": 3})
+
+     response = await a.invoke_returning_raw({"a": 2, "b": 2, "c": 2})
+     assert response is not None
+     assert isinstance(response, str)
+     assert "[[equilateral]]" in response
+     adapter_info = a.adapter_info()
+     assert adapter_info.prompt_builder_name == "SimplePromptBuilder"
+     assert adapter_info.model_name == model_name
+     assert adapter_info.model_provider == provider
+     assert adapter_info.adapter_name == "kiln_langchain_adapter"
+
+
+ @pytest.mark.paid
+ async def test_structured_input_gpt_4o_mini(tmp_path):
+     await run_structured_input_test(tmp_path, "llama_3_1_8b", "groq")
+
+
+ @pytest.mark.paid
+ @pytest.mark.ollama
+ async def test_all_built_in_models_structured_input(tmp_path):
+     for model in built_in_models:
+         for provider in model.providers:
+             try:
+                 print(f"Running {model.name} {provider.name}")
+                 await run_structured_input_test(tmp_path, model.name, provider.name)
+             except Exception as e:
+                 raise RuntimeError(f"Error running {model.name} {provider}") from e
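
Taken together, the structured-output tests show the main adapter entry points: LangChainPromptAdapter(task, model_name=..., provider=...), invoke() for validated runs, and invoke_returning_raw() for the raw parsed output. A minimal end-to-end sketch inferred from those tests (not part of the diff), assuming a local Ollama server and reusing the "llama_3_1_8b" / "ollama" names from test_structured_output_ollama_llama:

import asyncio
from pathlib import Path

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter


async def main() -> None:
    # Create and save a project/task pair, as the test helpers do.
    project = datamodel.Project(name="demo", path=Path("demo.kiln"))
    project.save_to_file()
    task = datamodel.Task(
        parent=project,
        name="joke task",
        instruction="You are an assistant which tells a joke, given a subject.",
    )
    task.save_to_file()

    # Run the task through the LangChain adapter and print the raw result.
    adapter = LangChainPromptAdapter(task, model_name="llama_3_1_8b", provider="ollama")
    result = await adapter.invoke_returning_raw("Cows")
    print(result)


asyncio.run(main())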