kiln-ai 0.5.5__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
- kiln_ai/adapters/__init__.py +9 -1
- kiln_ai/adapters/base_adapter.py +24 -35
- kiln_ai/adapters/data_gen/__init__.py +11 -0
- kiln_ai/adapters/data_gen/data_gen_prompts.py +73 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +185 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +293 -0
- kiln_ai/adapters/langchain_adapters.py +39 -7
- kiln_ai/adapters/ml_model_list.py +55 -1
- kiln_ai/adapters/prompt_builders.py +66 -0
- kiln_ai/adapters/repair/test_repair_task.py +4 -1
- kiln_ai/adapters/test_langchain_adapter.py +73 -0
- kiln_ai/adapters/test_ml_model_list.py +56 -0
- kiln_ai/adapters/test_prompt_adaptors.py +52 -18
- kiln_ai/adapters/test_prompt_builders.py +97 -7
- kiln_ai/adapters/test_saving_adapter_results.py +16 -6
- kiln_ai/adapters/test_structured_output.py +33 -5
- kiln_ai/datamodel/__init__.py +28 -7
- kiln_ai/datamodel/json_schema.py +1 -0
- kiln_ai/datamodel/test_models.py +44 -8
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/test_config.py +7 -0
- {kiln_ai-0.5.5.dist-info → kiln_ai-0.6.1.dist-info}/METADATA +1 -2
- kiln_ai-0.6.1.dist-info/RECORD +37 -0
- {kiln_ai-0.5.5.dist-info → kiln_ai-0.6.1.dist-info}/WHEEL +1 -1
- kiln_ai-0.5.5.dist-info/RECORD +0 -33
- {kiln_ai-0.5.5.dist-info → kiln_ai-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/test_langchain_adapter.py
CHANGED

@@ -1,6 +1,10 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_groq import ChatGroq
 
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
 
 
@@ -49,3 +53,72 @@ def test_langchain_adapter_info(tmp_path):
     assert model_info.adapter_name == "kiln_langchain_adapter"
     assert model_info.model_name == "llama_3_1_8b"
     assert model_info.model_provider == "ollama"
+
+
+async def test_langchain_adapter_with_cot(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = (
+        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
+    )
+    lca = LangChainPromptAdapter(
+        kiln_task=task,
+        model_name="llama_3_1_8b",
+        provider="ollama",
+        prompt_builder=SimpleChainOfThoughtPromptBuilder(task),
+    )
+
+    # Mock the base model and its invoke method
+    mock_base_model = MagicMock()
+    mock_base_model.invoke.return_value = AIMessage(
+        content="Chain of thought reasoning..."
+    )
+
+    # Create a separate mock for self.model()
+    mock_model_instance = MagicMock()
+    mock_model_instance.invoke.return_value = {"parsed": {"count": 1}}
+
+    # Mock the langchain_model_from function to return the base model
+    mock_model_from = AsyncMock(return_value=mock_base_model)
+
+    # Patch both the langchain_model_from function and self.model()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
+        ),
+        patch.object(LangChainPromptAdapter, "model", return_value=mock_model_instance),
+    ):
+        response = await lca._run("test input")
+
+    # First 3 messages are the same for both calls
+    for invoke_args in [
+        mock_base_model.invoke.call_args[0][0],
+        mock_model_instance.invoke.call_args[0][0],
+    ]:
+        assert isinstance(
+            invoke_args[0], SystemMessage
+        )  # First message should be system prompt
+        assert (
+            "You are an assistant which performs math tasks provided in plain text."
+            in invoke_args[0].content
+        )
+        assert isinstance(invoke_args[1], HumanMessage)
+        assert "test input" in invoke_args[1].content
+        assert isinstance(invoke_args[2], SystemMessage)
+        assert "step by step" in invoke_args[2].content
+
+    # the COT should only have 3 messages
+    assert len(mock_base_model.invoke.call_args[0][0]) == 3
+    assert len(mock_model_instance.invoke.call_args[0][0]) == 5
+
+    # the final response should have the COT content and the final instructions
+    invoke_args = mock_model_instance.invoke.call_args[0][0]
+    assert isinstance(invoke_args[3], AIMessage)
+    assert "Chain of thought reasoning..." in invoke_args[3].content
+    assert isinstance(invoke_args[4], SystemMessage)
+    assert "Considering the above, return a final result." in invoke_args[4].content
+
+    assert (
+        response.intermediate_outputs["chain_of_thought"]
+        == "Chain of thought reasoning..."
+    )
+    assert response.output == {"count": 1}
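Note: `test_langchain_adapter_with_cot` pins down a two-phase chain-of-thought flow: the base model is invoked once with three messages (system prompt, user input, thinking instruction) to produce free-form reasoning, then the structured-output model is invoked with five messages, with the reasoning appended as an `AIMessage` followed by a closing instruction. A minimal sketch of that message assembly, using only the message texts the test asserts on (the helper names are illustrative, not Kiln's API):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

FINAL_INSTRUCTION = "Considering the above, return a final result."


def cot_phase_messages(system_prompt: str, user_input: str, thinking_instruction: str):
    # Phase 1: ask the base model to think out loud (3 messages).
    return [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_input),
        SystemMessage(content=thinking_instruction),
    ]


def final_phase_messages(cot_messages: list, reasoning: str):
    # Phase 2: replay the same 3 messages, then append the model's reasoning
    # and a final instruction before requesting the structured answer (5 messages).
    return cot_messages + [
        AIMessage(content=reasoning),
        SystemMessage(content=FINAL_INSTRUCTION),
    ]
```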
kiln_ai/adapters/test_ml_model_list.py
CHANGED

@@ -4,9 +4,11 @@ from unittest.mock import patch
 import pytest
 
 from kiln_ai.adapters.ml_model_list import (
+    ModelName,
     ModelProviderName,
     OllamaConnection,
     check_provider_warnings,
+    get_model_and_provider,
     ollama_model_supported,
     parse_ollama_tags,
     provider_name_from_id,
@@ -123,3 +125,57 @@ def test_ollama_model_supported():
     assert ollama_model_supported(conn, "llama3.1:latest")
     assert ollama_model_supported(conn, "llama3.1")
     assert not ollama_model_supported(conn, "unknown_model")
+
+
+def test_get_model_and_provider_valid():
+    # Test with a known valid model and provider combination
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.ollama
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.phi_3_5
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+def test_get_model_and_provider_invalid_model():
+    # Test with an invalid model name
+    model, provider = get_model_and_provider(
+        "nonexistent_model", ModelProviderName.ollama
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_invalid_provider():
+    # Test with a valid model but invalid provider
+    model, provider = get_model_and_provider(ModelName.phi_3_5, "nonexistent_provider")
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_valid_model_wrong_provider():
+    # Test with a valid model but a provider that doesn't support it
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.amazon_bedrock
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_multiple_providers():
+    # Test with a model that has multiple providers
+    model, provider = get_model_and_provider(
+        ModelName.llama_3_1_70b, ModelProviderName.groq
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.llama_3_1_70b
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
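These tests fix `get_model_and_provider`'s contract: a `(model, provider)` pair when the combination exists, and `(None, None)` for an unknown model, an unknown provider, or a valid model paired with a provider that doesn't serve it. A sketch of a lookup honoring that contract (the registry shape is inferred from how the tests iterate `model.providers`; this is not the package's actual implementation):

```python
def get_model_and_provider_sketch(built_in_models, model_name, provider_name):
    # Return (model, provider) on an exact match; (None, None) otherwise.
    # Lookups never raise, even for entirely unknown names.
    for model in built_in_models:
        if model.name == model_name:
            for provider in model.providers:
                if provider.name == provider_name:
                    return model, provider
            return None, None  # known model, but this provider doesn't serve it
    return None, None  # unknown model
```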
kiln_ai/adapters/test_prompt_adaptors.py
CHANGED

@@ -7,6 +7,18 @@ from langchain_core.language_models.fake_chat_models import FakeListChatModel
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import built_in_models, ollama_online
+from kiln_ai.adapters.prompt_builders import (
+    BasePromptBuilder,
+    SimpleChainOfThoughtPromptBuilder,
+)
+
+
+def get_all_models_and_providers():
+    model_provider_pairs = []
+    for model in built_in_models:
+        for provider in model.providers:
+            model_provider_pairs.append((model.name, provider.name))
+    return model_provider_pairs
 
 
 @pytest.mark.paid

@@ -30,6 +42,7 @@ async def test_groq(tmp_path):
         "llama_3_2_90b",
         "claude_3_5_haiku",
         "claude_3_5_sonnet",
+        "phi_3_5",
     ],
 )
 @pytest.mark.paid

@@ -119,15 +132,19 @@ async def test_mock_returning_run(tmp_path):
 
 @pytest.mark.paid
 @pytest.mark.ollama
-
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_models_providers_plaintext(tmp_path, model_name, provider_name):
     task = build_test_task(tmp_path)
-
-
-
-
-
-
-
+    await run_simple_task(task, model_name, provider_name)
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_cot_prompt_builder(tmp_path, model_name, provider_name):
+    task = build_test_task(tmp_path)
+    pb = SimpleChainOfThoughtPromptBuilder(task)
+    await run_simple_task(task, model_name, provider_name, pb)
 
 
 def build_test_task(tmp_path: Path):

@@ -159,13 +176,25 @@ def build_test_task(tmp_path: Path):
     return task
 
 
-async def run_simple_test(
+async def run_simple_test(
+    tmp_path: Path,
+    model_name: str,
+    provider: str | None = None,
+    prompt_builder: BasePromptBuilder | None = None,
+):
     task = build_test_task(tmp_path)
-    return await run_simple_task(task, model_name, provider)
+    return await run_simple_task(task, model_name, provider, prompt_builder)
 
 
-async def run_simple_task(
-
+async def run_simple_task(
+    task: datamodel.Task,
+    model_name: str,
+    provider: str,
+    prompt_builder: BasePromptBuilder | None = None,
+) -> datamodel.TaskRun:
+    adapter = LangChainPromptAdapter(
+        task, model_name=model_name, provider=provider, prompt_builder=prompt_builder
+    )
 
     run = await adapter.invoke(
         "You should answer the following question: four plus six times 10"

@@ -176,9 +205,14 @@ async def run_simple_task(task: datamodel.Task, model_name: str, provider: str):
         run.input == "You should answer the following question: four plus six times 10"
     )
     assert "64" in run.output.output
-
-
-
-
-
-
+    source_props = run.output.source.properties
+    assert source_props["adapter_name"] == "kiln_langchain_adapter"
+    assert source_props["model_name"] == model_name
+    assert source_props["model_provider"] == provider
+    expected_prompt_builder_name = (
+        prompt_builder.__class__.prompt_builder_name()
+        if prompt_builder
+        else "simple_prompt_builder"
+    )
+    assert source_props["prompt_builder_name"] == expected_prompt_builder_name
+    return run
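This refactor replaces looped live tests with `@pytest.mark.parametrize` over `get_all_models_and_providers()`, so every model/provider pair becomes its own test node and one failing backend no longer masks the rest. The pattern in isolation (the pair list below is illustrative, standing in for the real registry):

```python
import pytest

# Illustrative stand-in for get_all_models_and_providers().
MODEL_PROVIDER_PAIRS = [
    ("llama_3_1_8b", "ollama"),
    ("llama_3_1_70b", "groq"),
]


@pytest.mark.parametrize("model_name,provider_name", MODEL_PROVIDER_PAIRS)
def test_pair_is_well_formed(model_name, provider_name):
    # pytest generates one independent test case per (model, provider) tuple.
    assert model_name and provider_name
```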
kiln_ai/adapters/test_prompt_builders.py
CHANGED

@@ -4,10 +4,14 @@ import pytest
 
 from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
 from kiln_ai.adapters.prompt_builders import (
+    FewShotChainOfThoughtPromptBuilder,
     FewShotPromptBuilder,
+    MultiShotChainOfThoughtPromptBuilder,
     MultiShotPromptBuilder,
     RepairsPromptBuilder,
+    SimpleChainOfThoughtPromptBuilder,
     SimplePromptBuilder,
+    chain_of_thought_prompt,
     prompt_builder_from_ui_name,
 )
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task

@@ -43,9 +47,6 @@ def test_simple_prompt_builder(tmp_path):
 
 
 class MockAdapter(BaseAdapter):
-    def adapter_specific_instructions(self) -> str | None:
-        return "You are a mock, send me the response!"
-
     def _run(self, input: str) -> str:
         return "mock response"
 

@@ -64,10 +65,6 @@ def test_simple_prompt_builder_structured_output(tmp_path):
     prompt = builder.build_prompt()
     assert "You are an assistant which tells a joke, given a subject." in prompt
 
-    # check adapter instructions are included
-    run_adapter = MockAdapter(task, prompt_builder=builder)
-    assert "You are a mock, send me the response!" in run_adapter.build_prompt()
-
     user_msg = builder.build_user_message(input)
     assert input in user_msg
     assert input not in prompt

@@ -313,6 +310,18 @@ def test_prompt_builder_from_ui_name():
     assert prompt_builder_from_ui_name("few_shot") == FewShotPromptBuilder
     assert prompt_builder_from_ui_name("many_shot") == MultiShotPromptBuilder
     assert prompt_builder_from_ui_name("repairs") == RepairsPromptBuilder
+    assert (
+        prompt_builder_from_ui_name("simple_chain_of_thought")
+        == SimpleChainOfThoughtPromptBuilder
+    )
+    assert (
+        prompt_builder_from_ui_name("few_shot_chain_of_thought")
+        == FewShotChainOfThoughtPromptBuilder
+    )
+    assert (
+        prompt_builder_from_ui_name("multi_shot_chain_of_thought")
+        == MultiShotChainOfThoughtPromptBuilder
+    )
 
     with pytest.raises(ValueError, match="Unknown prompt builder: invalid_name"):
         prompt_builder_from_ui_name("invalid_name")

@@ -336,3 +345,84 @@ def test_repair_multi_shot_prompt_builder(task_with_examples):
         'Initial Output Which Was Insufficient: {"joke": "Moo I am a cow joke."}'
         in prompt
     )
+
+
+def test_chain_of_thought_prompt(tmp_path):
+    # Test with default thinking instruction
+    task = Task(
+        name="Test Task",
+        instruction="Test instruction",
+        parent=None,
+        thinking_instruction=None,
+    )
+    assert (
+        chain_of_thought_prompt(task)
+        == "Think step by step, explaining your reasoning."
+    )
+
+    # Test with custom thinking instruction
+    custom_instruction = "First analyze the problem, then break it down into steps."
+    task = Task(
+        name="Test Task",
+        instruction="Test instruction",
+        parent=None,
+        thinking_instruction=custom_instruction,
+    )
+    assert chain_of_thought_prompt(task) == custom_instruction
+
+
+@pytest.mark.parametrize(
+    "builder_class",
+    [
+        SimpleChainOfThoughtPromptBuilder,
+        FewShotChainOfThoughtPromptBuilder,
+        MultiShotChainOfThoughtPromptBuilder,
+    ],
+)
+def test_chain_of_thought_prompt_builders(builder_class, task_with_examples):
+    # Test with default thinking instruction
+    builder = builder_class(task=task_with_examples)
+    assert (
+        builder.chain_of_thought_prompt()
+        == "Think step by step, explaining your reasoning."
+    )
+
+    # Test with custom thinking instruction
+    custom_instruction = "First analyze the problem, then break it down into steps."
+    task_with_custom = task_with_examples.model_copy(
+        update={"thinking_instruction": custom_instruction}
+    )
+    builder = builder_class(task=task_with_custom)
+    assert builder.chain_of_thought_prompt() == custom_instruction
+
+
+def test_build_prompt_for_ui(tmp_path):
+    # Test regular prompt builder
+    task = build_test_task(tmp_path)
+    simple_builder = SimplePromptBuilder(task=task)
+    ui_prompt = simple_builder.build_prompt_for_ui()
+
+    # Should match regular prompt since no chain of thought
+    assert ui_prompt == simple_builder.build_prompt()
+    assert "# Thinking Instructions" not in ui_prompt
+
+    # Test chain of thought prompt builder
+    cot_builder = SimpleChainOfThoughtPromptBuilder(task=task)
+    ui_prompt_cot = cot_builder.build_prompt_for_ui()
+
+    # Should include both base prompt and thinking instructions
+    assert cot_builder.build_prompt() in ui_prompt_cot
+    assert "# Thinking Instructions" in ui_prompt_cot
+    assert "Think step by step" in ui_prompt_cot
+
+    # Test with custom thinking instruction
+    custom_instruction = "First analyze the problem, then solve it."
+    task_with_custom = task.model_copy(
+        update={"thinking_instruction": custom_instruction}
+    )
+    custom_cot_builder = SimpleChainOfThoughtPromptBuilder(task=task_with_custom)
+    ui_prompt_custom = custom_cot_builder.build_prompt_for_ui()
+
+    assert custom_cot_builder.build_prompt() in ui_prompt_custom
+    assert "# Thinking Instructions" in ui_prompt_custom
+    assert custom_instruction in ui_prompt_custom
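Per the assertions above, `chain_of_thought_prompt` returns the task's `thinking_instruction` when one is set and otherwise falls back to a default, while `build_prompt_for_ui` appends the thinking instruction under a `# Thinking Instructions` heading only for chain-of-thought builders. A sketch consistent with those assertions (function bodies are inferred from the tests, not quoted from the package):

```python
DEFAULT_COT = "Think step by step, explaining your reasoning."


def chain_of_thought_prompt_sketch(task) -> str:
    # A custom instruction wins; otherwise use the default wording the tests expect.
    return task.thinking_instruction or DEFAULT_COT


def build_prompt_for_ui_sketch(base_prompt: str, cot_instruction: str | None) -> str:
    # Plain builders show the base prompt unchanged; chain-of-thought builders
    # surface the thinking instruction in a visible section.
    if cot_instruction is None:
        return base_prompt
    return f"{base_prompt}\n\n# Thinking Instructions\n\n{cot_instruction}"
```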
kiln_ai/adapters/test_saving_adapter_results.py
CHANGED

@@ -2,7 +2,7 @@ from unittest.mock import patch
 
 import pytest
 
-from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
 from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,

@@ -14,7 +14,7 @@ from kiln_ai.utils.config import Config
 
 class MockAdapter(BaseAdapter):
     async def _run(self, input: dict | str) -> dict | str:
-        return "Test output"
+        return RunOutput(output="Test output", intermediate_outputs=None)
 
     def adapter_info(self) -> AdapterInfo:
         return AdapterInfo(

@@ -42,9 +42,13 @@ def test_save_run_isolation(test_task):
     adapter = MockAdapter(test_task)
     input_data = "Test input"
     output_data = "Test output"
+    run_output = RunOutput(
+        output=output_data,
+        intermediate_outputs={"chain_of_thought": "Test chain of thought"},
+    )
 
     task_run = adapter.generate_run(
-        input=input_data, input_source=None,
+        input=input_data, input_source=None, run_output=run_output
     )
     task_run.save_to_file()
 

@@ -52,6 +56,9 @@ def test_save_run_isolation(test_task):
     assert task_run.parent == test_task
     assert task_run.input == input_data
     assert task_run.input_source.type == DataSourceType.human
+    assert task_run.intermediate_outputs == {
+        "chain_of_thought": "Test chain of thought"
+    }
     created_by = Config.shared().user_id
     if created_by and created_by != "":
         assert task_run.input_source.properties["created_by"] == created_by

@@ -86,13 +93,16 @@ def test_save_run_isolation(test_task):
     )
 
     # Run again, with same input and different output. Should create a new TaskRun.
-
+    different_run_output = RunOutput(
+        output="Different output", intermediate_outputs=None
+    )
+    task_output = adapter.generate_run(input_data, None, different_run_output)
     task_output.save_to_file()
     assert len(test_task.runs()) == 2
     assert "Different output" in set(run.output.output for run in test_task.runs())
 
     # run again with same input and same output. Should not create a new TaskRun.
-    task_output = adapter.generate_run(input_data, None,
+    task_output = adapter.generate_run(input_data, None, run_output)
     task_output.save_to_file()
     assert len(test_task.runs()) == 2
     assert "Different output" in set(run.output.output for run in test_task.runs())

@@ -110,7 +120,7 @@ def test_save_run_isolation(test_task):
                 "adapter_name": "mock_adapter",
             },
         ),
-
+        run_output,
     )
     task_output.save_to_file()
     assert len(test_task.runs()) == 3
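The thread through these hunks: `_run` now returns a `RunOutput` instead of a bare value, `generate_run` accepts it as a third argument, and any intermediate outputs are persisted onto `TaskRun.intermediate_outputs`. A dataclass-style sketch of the shape implied by the tests (the real class lives in `kiln_ai.adapters.base_adapter`; this stand-in only mirrors the two fields the tests use):

```python
from dataclasses import dataclass


@dataclass
class RunOutputSketch:
    # The model's final answer: plain text or parsed structured output.
    output: str | dict
    # Optional named intermediate steps, e.g. {"chain_of_thought": "..."}.
    intermediate_outputs: dict[str, str] | None = None
```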
kiln_ai/adapters/test_structured_output.py
CHANGED

@@ -6,12 +6,17 @@ import jsonschema.exceptions
 import pytest
 
 import kiln_ai.datamodel as datamodel
-from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import (
     built_in_models,
     ollama_online,
 )
+from kiln_ai.adapters.prompt_builders import (
+    BasePromptBuilder,
+    SimpleChainOfThoughtPromptBuilder,
+)
+from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema
 
 

@@ -59,8 +64,8 @@ class MockAdapter(BaseAdapter):
         super().__init__(kiln_task)
         self.response = response
 
-    async def _run(self, input: str) ->
-        return self.response
+    async def _run(self, input: str) -> RunOutput:
+        return RunOutput(output=self.response, intermediate_outputs=None)
 
     def adapter_info(self) -> AdapterInfo:
         return AdapterInfo(

@@ -190,7 +195,18 @@ def build_structured_input_test_task(tmp_path: Path):
 
 async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_input_test_task(tmp_path)
-
+    await run_structured_input_task(task, model_name, provider)
+
+
+async def run_structured_input_task(
+    task: datamodel.Task,
+    model_name: str,
+    provider: str,
+    pb: BasePromptBuilder | None = None,
+):
+    a = LangChainPromptAdapter(
+        task, model_name=model_name, provider=provider, prompt_builder=pb
+    )
     with pytest.raises(ValueError):
         # not structured input in dictionary
         await a.invoke("a=1, b=2, c=3")

@@ -203,7 +219,10 @@ async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
     assert isinstance(response, str)
     assert "[[equilateral]]" in response
     adapter_info = a.adapter_info()
-
+    expected_pb_name = "simple_prompt_builder"
+    if pb is not None:
+        expected_pb_name = pb.__class__.prompt_builder_name()
+    assert adapter_info.prompt_builder_name == expected_pb_name
     assert adapter_info.model_name == model_name
     assert adapter_info.model_provider == provider
     assert adapter_info.adapter_name == "kiln_langchain_adapter"

@@ -224,3 +243,12 @@ async def test_all_built_in_models_structured_input(tmp_path):
                 await run_structured_input_test(tmp_path, model.name, provider.name)
             except Exception as e:
                 raise RuntimeError(f"Error running {model.name} {provider}") from e
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_structured_cot_prompt_builder(tmp_path, model_name, provider_name):
+    task = build_structured_input_test_task(tmp_path)
+    pb = SimpleChainOfThoughtPromptBuilder(task)
+    await run_structured_input_task(task, model_name, provider_name, pb)
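The `pytest.raises(ValueError)` block documents the adapter's input contract: a task with an input schema rejects plain-string input before any model call. A sketch of that validation step (an assumed shape mirroring the behavior the test requires, not Kiln's actual code):

```python
import jsonschema


def validate_structured_input(input_data, input_schema: dict | None) -> None:
    # Tasks with an input schema require a dict that validates against it;
    # a plain string like "a=1, b=2, c=3" must raise ValueError.
    if input_schema is None:
        return
    if not isinstance(input_data, dict):
        raise ValueError("structured input task requires a dictionary input")
    jsonschema.validate(instance=input_data, schema=input_schema)
```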
kiln_ai/datamodel/__init__.py
CHANGED

@@ -48,8 +48,18 @@ __all__ = [
 
 # Filename compatible names
 NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
-NAME_FIELD = Field(
-
+NAME_FIELD = Field(
+    min_length=1,
+    max_length=120,
+    pattern=NAME_REGEX,
+    description="A name for this entity.",
+)
+SHORT_NAME_FIELD = Field(
+    min_length=1,
+    max_length=32,
+    pattern=NAME_REGEX,
+    description="A name for this entity",
+)
 
 
 class Priority(IntEnum):

@@ -280,6 +290,10 @@ class TaskRun(KilnParentedModel):
         default=None,
         description="An version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field.",
     )
+    intermediate_outputs: Dict[str, str] | None = Field(
+        default=None,
+        description="Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.",
+    )
 
     def parent_task(self) -> Task | None:
         if not isinstance(self.parent, Task):

@@ -372,14 +386,21 @@ class Task(
     """
 
     name: str = NAME_FIELD
-    description: str = Field(
-
-
-
+    description: str | None = Field(
+        default=None,
+        description="A description of the task for you and your team. Will not be used in prompts/training/validation.",
+    )
+    instruction: str = Field(
+        min_length=1,
+        description="The instructions for the task. Will be used in prompts/training/validation.",
+    )
     requirements: List[TaskRequirement] = Field(default=[])
-    # TODO: make this required, or formalize the default message output schema
     output_json_schema: JsonObjectSchema | None = None
     input_json_schema: JsonObjectSchema | None = None
+    thinking_instruction: str | None = Field(
+        default=None,
+        description="Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting.",
+    )
 
     def output_schema(self) -> Dict | None:
         if self.output_json_schema is None:
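In practice the new fields combine like this: `thinking_instruction` feeds the chain-of-thought prompt builders, and `intermediate_outputs` stores what the model produced along the way. A small usage sketch constructing a detached Task, as the tests do with `parent=None` (field values are illustrative; only the field names and constraints come from the diff):

```python
from kiln_ai.datamodel import Task

# `description` may now be omitted; `instruction` must be non-empty (min_length=1);
# `thinking_instruction` customizes chain-of-thought prompting.
task = Task(
    name="Math Helper",  # NAME_FIELD: 1-120 chars matching ^[A-Za-z0-9 _-]+$
    instruction="Answer arithmetic questions given in plain text.",
    parent=None,
    thinking_instruction="Work through the problem in numbered steps first.",
)
```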
kiln_ai/datamodel/json_schema.py
CHANGED

@@ -64,6 +64,7 @@ def schema_from_json_str(v: str) -> Dict:
     jsonschema.Draft202012Validator.check_schema(parsed)
     if not isinstance(parsed, dict):
         raise ValueError(f"JSON schema must be a dict, not {type(parsed)}")
+    # Top level arrays are valid JSON schemas, but we don't want to allow them here as they often cause issues
     if (
         "type" not in parsed
         or parsed["type"] != "object"
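The hunk cuts off mid-condition, but the quoted lines fix the behavior: after draft validation, `schema_from_json_str` insists the top-level schema is an object. A condensed sketch of that path (the surrounding parse step and the error wording outside the quoted lines are assumptions):

```python
import json

import jsonschema


def schema_from_json_str_sketch(v: str) -> dict:
    parsed = json.loads(v)
    # Reject anything that isn't valid JSON Schema (draft 2020-12).
    jsonschema.Draft202012Validator.check_schema(parsed)
    if not isinstance(parsed, dict):
        raise ValueError(f"JSON schema must be a dict, not {type(parsed)}")
    # Top-level arrays are valid JSON Schema, but are rejected here
    # because they often cause issues downstream.
    if "type" not in parsed or parsed["type"] != "object":
        raise ValueError("JSON schema must be of type 'object' at the top level")
    return parsed
```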
|