kiln-ai 0.0.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/base_adapter.py +168 -0
- kiln_ai/adapters/langchain_adapters.py +113 -0
- kiln_ai/adapters/ml_model_list.py +436 -0
- kiln_ai/adapters/prompt_builders.py +122 -0
- kiln_ai/adapters/repair/repair_task.py +71 -0
- kiln_ai/adapters/repair/test_repair_task.py +248 -0
- kiln_ai/adapters/test_langchain_adapter.py +50 -0
- kiln_ai/adapters/test_ml_model_list.py +99 -0
- kiln_ai/adapters/test_prompt_adaptors.py +167 -0
- kiln_ai/adapters/test_prompt_builders.py +315 -0
- kiln_ai/adapters/test_saving_adapter_results.py +168 -0
- kiln_ai/adapters/test_structured_output.py +218 -0
- kiln_ai/datamodel/__init__.py +362 -2
- kiln_ai/datamodel/basemodel.py +372 -0
- kiln_ai/datamodel/json_schema.py +45 -0
- kiln_ai/datamodel/test_basemodel.py +277 -0
- kiln_ai/datamodel/test_datasource.py +107 -0
- kiln_ai/datamodel/test_example_models.py +644 -0
- kiln_ai/datamodel/test_json_schema.py +124 -0
- kiln_ai/datamodel/test_models.py +190 -0
- kiln_ai/datamodel/test_nested_save.py +205 -0
- kiln_ai/datamodel/test_output_rating.py +88 -0
- kiln_ai/utils/config.py +170 -0
- kiln_ai/utils/formatting.py +5 -0
- kiln_ai/utils/test_config.py +245 -0
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/METADATA +20 -1
- kiln_ai-0.5.0.dist-info/RECORD +29 -0
- kiln_ai/__init.__.py +0 -3
- kiln_ai/coreadd.py +0 -3
- kiln_ai/datamodel/project.py +0 -15
- kiln_ai-0.0.4.dist-info/RECORD +0 -8
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/LICENSE.txt +0 -0
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/WHEEL +0 -0
kiln_ai/adapters/repair/test_repair_task.py
@@ -0,0 +1,248 @@
import json
import os
from unittest.mock import AsyncMock, patch

import pytest
from kiln_ai.adapters.langchain_adapters import (
    LangChainPromptAdapter,
)
from kiln_ai.adapters.repair.repair_task import (
    RepairTaskInput,
    RepairTaskRun,
)
from kiln_ai.datamodel import (
    DataSource,
    DataSourceType,
    Priority,
    Task,
    TaskOutput,
    TaskRequirement,
    TaskRun,
)
from pydantic import ValidationError

json_joke_schema = """{
  "type": "object",
  "properties": {
    "setup": {
      "description": "The setup of the joke",
      "title": "Setup",
      "type": "string"
    },
    "punchline": {
      "description": "The punchline to the joke",
      "title": "Punchline",
      "type": "string"
    },
    "rating": {
      "anyOf": [
        {
          "type": "integer"
        },
        {
          "type": "null"
        }
      ],
      "default": null,
      "description": "How funny the joke is, from 1 to 10",
      "title": "Rating"
    }
  },
  "required": [
    "setup",
    "punchline"
  ]
}
"""


@pytest.fixture
def sample_task(tmp_path):
    task_path = tmp_path / "task.json"
    task = Task(
        name="Joke Generator",
        path=task_path,
        description="Generate a funny joke",
        instruction="Create a joke with a setup and punchline",
        requirements=[
            TaskRequirement(
                id="req1",
                name="Humor",
                instruction="The joke should be funny and appropriate",
                priority=Priority.p1,
            )
        ],
        output_json_schema=json_joke_schema,
    )
    task.save_to_file()
    return task


@pytest.fixture
def sample_task_run(sample_task):
    task_run = TaskRun(
        parent=sample_task,
        input='{"topic": "chicken"}',
        input_source=DataSource(
            type=DataSourceType.human, properties={"created_by": "Jane Doe"}
        ),
        output=TaskOutput(
            output='{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side", "rating": null}',
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "model_name": "gpt_4o",
                    "model_provider": "openai",
                    "adapter_name": "langchain_adapter",
                    "prompt_builder_name": "simple_prompt_builder",
                },
            ),
        ),
    )
    task_run.save_to_file()
    return task_run


@pytest.fixture
def sample_repair_data(sample_task, sample_task_run):
    return {
        "original_task": sample_task,
        "task_run": sample_task_run,
        "evaluator_feedback": "The joke is too cliché. Please come up with a more original chicken-related joke.",
    }


def test_build_repair_task_input(sample_repair_data):
    result = RepairTaskRun.build_repair_task_input(**sample_repair_data)

    assert isinstance(result, RepairTaskInput)
    assert "Create a joke with a setup and punchline" in result.original_prompt
    assert "1) The joke should be funny and appropriate" in result.original_prompt
    assert result.original_input == '{"topic": "chicken"}'
    assert (
        result.original_output
        == '{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side", "rating": null}'
    )
    assert (
        result.evaluator_feedback
        == "The joke is too cliché. Please come up with a more original chicken-related joke."
    )


def test_repair_input_schema():
    schema = RepairTaskInput.model_json_schema()
    assert schema["type"] == "object"
    assert "original_prompt" in schema["properties"]
    assert "original_input" in schema["properties"]
    assert "original_output" in schema["properties"]
    assert "evaluator_feedback" in schema["properties"]


def test_repair_task_initialization(sample_task):
    repair_task = RepairTaskRun(sample_task)

    assert repair_task.name == "Repair"
    assert "Repair a task run" in repair_task.description
    assert "You are an assistant which helps improve output" in repair_task.instruction
    assert len(repair_task.requirements) == 1
    assert repair_task.requirements[0].name == "Follow Eval Feedback"
    assert repair_task.input_json_schema == json.dumps(
        RepairTaskInput.model_json_schema()
    )
    assert repair_task.output_json_schema == sample_task.output_json_schema


def test_build_repair_task_input_with_empty_values(sample_task, sample_task_run):
    # Arrange
    sample_task_run.input = ""
    sample_task_run.output.output = ""

    # Act & Assert
    with pytest.raises(ValidationError, match="evaluator_feedback"):
        RepairTaskRun.build_repair_task_input(
            original_task=sample_task, task_run=sample_task_run, evaluator_feedback=""
        )

    # Test that it works with non-empty feedback
    result = RepairTaskRun.build_repair_task_input(
        original_task=sample_task,
        task_run=sample_task_run,
        evaluator_feedback="Some feedback",
    )
    assert isinstance(result, RepairTaskInput)
    assert result.evaluator_feedback == "Some feedback"


@pytest.mark.parametrize("invalid_input", [{}])
def test_build_repair_task_input_with_invalid_input(invalid_input):
    # Act & Assert
    with pytest.raises(TypeError):
        RepairTaskRun.build_repair_task_input(invalid_input)


@pytest.mark.paid
async def test_live_run(sample_task, sample_task_run, sample_repair_data):
    if os.getenv("GROQ_API_KEY") is None:
        pytest.skip("GROQ_API_KEY not set")
    repair_task = RepairTaskRun(sample_task)
    repair_task_input = RepairTaskRun.build_repair_task_input(**sample_repair_data)
    assert isinstance(repair_task_input, RepairTaskInput)

    adapter = LangChainPromptAdapter(
        repair_task, model_name="llama_3_1_8b", provider="groq"
    )

    run = await adapter.invoke(repair_task_input.model_dump())
    assert run is not None
    assert "Please come up with a more original chicken-related joke." in run.input
    parsed_output = json.loads(run.output.output)
    assert "setup" in parsed_output
    assert "punchline" in parsed_output
    assert run.output.source.properties == {
        "adapter_name": "kiln_langchain_adapter",
        "model_name": "llama_3_1_8b",
        "model_provider": "groq",
        "prompt_builder_name": "simple_prompt_builder",
    }


@pytest.mark.asyncio
async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repair_data):
    repair_task = RepairTaskRun(sample_task)
    repair_task_input = RepairTaskRun.build_repair_task_input(**sample_repair_data)
    assert isinstance(repair_task_input, RepairTaskInput)

    mocked_output = {
        "setup": "Why did the chicken join a band?",
        "punchline": "Because it had excellent drumsticks!",
        "rating": 8,
    }

    with patch.object(
        LangChainPromptAdapter, "_run", new_callable=AsyncMock
    ) as mock_run:
        mock_run.return_value = mocked_output

        adapter = LangChainPromptAdapter(
            repair_task, model_name="llama_3_1_8b", provider="groq"
        )

        run = await adapter.invoke(repair_task_input.model_dump())

        assert run is not None
        assert run.id is None
        assert "Please come up with a more original chicken-related joke." in run.input

        parsed_output = json.loads(run.output.output)
        assert parsed_output == mocked_output
        assert run.output.source.properties == {
            "adapter_name": "kiln_langchain_adapter",
            "model_name": "llama_3_1_8b",
            "model_provider": "groq",
            "prompt_builder_name": "simple_prompt_builder",
        }
        assert run.input_source.type == DataSourceType.human
        assert "created_by" in run.input_source.properties

        # Verify that the mock was called
        mock_run.assert_called_once()
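Taken together, the repair tests above exercise a small workflow: build a RepairTaskInput from an original Task, a prior TaskRun, and evaluator feedback, then run the generated repair task through LangChainPromptAdapter. The following is a minimal, illustrative sketch distilled from those tests, not code from the package itself; task and task_run are assumed to already exist and be saved to disk, and the llama_3_1_8b/groq pair is simply the one the live test uses.

# Illustrative only, distilled from the repair tests above.
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.adapters.repair.repair_task import RepairTaskRun


async def repair_run(task, task_run, feedback: str):
    repair_task = RepairTaskRun(task)
    repair_input = RepairTaskRun.build_repair_task_input(
        original_task=task, task_run=task_run, evaluator_feedback=feedback
    )
    adapter = LangChainPromptAdapter(
        repair_task, model_name="llama_3_1_8b", provider="groq"
    )
    # invoke() returns a TaskRun; its output.output holds the repaired (JSON) answer
    return await adapter.invoke(repair_input.model_dump())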
kiln_ai/adapters/test_langchain_adapter.py
@@ -0,0 +1,50 @@
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.adapters.test_prompt_adaptors import build_test_task
from langchain_groq import ChatGroq


def test_langchain_adapter_munge_response(tmp_path):
    task = build_test_task(tmp_path)
    lca = LangChainPromptAdapter(
        kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
    )
    # Mistral Large tool calling format is a bit different
    response = {
        "name": "task_response",
        "arguments": {
            "setup": "Why did the cow join a band?",
            "punchline": "Because she wanted to be a moo-sician!",
        },
    }
    munged = lca._munge_response(response)
    assert munged["setup"] == "Why did the cow join a band?"
    assert munged["punchline"] == "Because she wanted to be a moo-sician!"

    # non mistral format should continue to work
    munged = lca._munge_response(response["arguments"])
    assert munged["setup"] == "Why did the cow join a band?"
    assert munged["punchline"] == "Because she wanted to be a moo-sician!"


def test_langchain_adapter_infer_model_name(tmp_path):
    task = build_test_task(tmp_path)
    custom = ChatGroq(model="llama-3.1-8b-instant", groq_api_key="test")

    lca = LangChainPromptAdapter(kiln_task=task, custom_model=custom)

    model_info = lca.adapter_info()
    assert model_info.model_name == "custom.langchain:llama-3.1-8b-instant"
    assert model_info.model_provider == "custom.langchain:ChatGroq"


def test_langchain_adapter_info(tmp_path):
    task = build_test_task(tmp_path)

    lca = LangChainPromptAdapter(
        kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
    )

    model_info = lca.adapter_info()
    assert model_info.adapter_name == "kiln_langchain_adapter"
    assert model_info.model_name == "llama_3_1_8b"
    assert model_info.model_provider == "ollama"
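The munge-response test above pins down one behaviour of the adapter: Mistral Large wraps tool-call results in a {"name": ..., "arguments": {...}} envelope, while other models hand back the arguments dict directly, and both must normalize to the same dict. A free-standing sketch of that normalization follows; it is illustrative only, the real logic being LangChainPromptAdapter._munge_response in langchain_adapters.py.

# Hypothetical stand-alone version of the normalization the test pins down.
def normalize_tool_response(response: dict) -> dict:
    # Mistral Large style: {"name": "task_response", "arguments": {...}}
    if "name" in response and "arguments" in response:
        return response["arguments"]
    # Other models already return the arguments dict directly
    return response


assert normalize_tool_response(
    {"name": "task_response", "arguments": {"setup": "s", "punchline": "p"}}
) == {"setup": "s", "punchline": "p"}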
kiln_ai/adapters/test_ml_model_list.py
@@ -0,0 +1,99 @@
from unittest.mock import patch

import pytest

from libs.core.kiln_ai.adapters.ml_model_list import (
    ModelProviderName,
    check_provider_warnings,
    provider_name_from_id,
    provider_warnings,
)


@pytest.fixture
def mock_config():
    with patch("libs.core.kiln_ai.adapters.ml_model_list.get_config_value") as mock:
        yield mock


def test_check_provider_warnings_no_warning(mock_config):
    mock_config.return_value = "some_value"

    # This should not raise an exception
    check_provider_warnings(ModelProviderName.amazon_bedrock)


def test_check_provider_warnings_missing_key(mock_config):
    mock_config.return_value = None

    with pytest.raises(ValueError) as exc_info:
        check_provider_warnings(ModelProviderName.amazon_bedrock)

    assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
        exc_info.value
    )


def test_check_provider_warnings_unknown_provider():
    # This should not raise an exception, as no settings are required for unknown providers
    check_provider_warnings("unknown_provider")


@pytest.mark.parametrize(
    "provider_name",
    [
        ModelProviderName.amazon_bedrock,
        ModelProviderName.openrouter,
        ModelProviderName.groq,
        ModelProviderName.openai,
    ],
)
def test_check_provider_warnings_all_providers(mock_config, provider_name):
    mock_config.return_value = None

    with pytest.raises(ValueError) as exc_info:
        check_provider_warnings(provider_name)

    assert provider_warnings[provider_name].message in str(exc_info.value)


def test_check_provider_warnings_partial_keys_set(mock_config):
    def mock_get(key):
        return "value" if key == "bedrock_access_key" else None

    mock_config.side_effect = mock_get

    with pytest.raises(ValueError) as exc_info:
        check_provider_warnings(ModelProviderName.amazon_bedrock)

    assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
        exc_info.value
    )


def test_provider_name_from_id_unknown_provider():
    assert (
        provider_name_from_id("unknown_provider")
        == "Unknown provider: unknown_provider"
    )


def test_provider_name_from_id_case_sensitivity():
    assert (
        provider_name_from_id(ModelProviderName.amazon_bedrock.upper())
        == "Unknown provider: AMAZON_BEDROCK"
    )


@pytest.mark.parametrize(
    "provider_id, expected_name",
    [
        (ModelProviderName.amazon_bedrock, "Amazon Bedrock"),
        (ModelProviderName.openrouter, "OpenRouter"),
        (ModelProviderName.groq, "Groq"),
        (ModelProviderName.ollama, "Ollama"),
        (ModelProviderName.openai, "OpenAI"),
    ],
)
def test_provider_name_from_id_parametrized(provider_id, expected_name):
    assert provider_name_from_id(provider_id) == expected_name
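Read together, these tests imply a simple pattern: each entry in provider_warnings pairs a human-readable message with the config keys that provider needs, and check_provider_warnings raises a ValueError carrying that message when any key is unset, while unknown provider names pass silently. The sketch below illustrates that pattern only; it is not the library's actual code, and apart from bedrock_access_key (which the partial-keys test references) the key names and message text are invented.

# Illustrative pattern only; the real definitions live in kiln_ai.adapters.ml_model_list.
from dataclasses import dataclass


def get_config_value(key: str) -> str | None:
    ...  # stand-in for the config lookup the tests patch out


@dataclass
class ProviderWarning:
    required_config_keys: list[str]
    message: str  # surfaced in the ValueError, per the tests


provider_warnings = {
    "amazon_bedrock": ProviderWarning(
        # "bedrock_access_key" appears in the tests; the second key is assumed
        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
        message="Amazon Bedrock requires AWS keys to be configured.",
    ),
}


def check_provider_warnings(provider_name: str) -> None:
    warning = provider_warnings.get(provider_name)
    if warning is None:
        return  # unknown providers need no settings, so no error
    if any(get_config_value(key) is None for key in warning.required_config_keys):
        raise ValueError(warning.message)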
kiln_ai/adapters/test_prompt_adaptors.py
@@ -0,0 +1,167 @@
import os
from pathlib import Path

import kiln_ai.datamodel as datamodel
import pytest
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.adapters.ml_model_list import built_in_models, ollama_online
from langchain_core.language_models.fake_chat_models import FakeListChatModel


@pytest.mark.paid
async def test_groq(tmp_path):
    if os.getenv("GROQ_API_KEY") is None:
        pytest.skip("GROQ_API_KEY not set")
    await run_simple_test(tmp_path, "llama_3_1_8b", "groq")


@pytest.mark.paid
async def test_openrouter(tmp_path):
    await run_simple_test(tmp_path, "llama_3_1_8b", "openrouter")


@pytest.mark.ollama
async def test_ollama_phi(tmp_path):
    # Check if Ollama API is running
    if not await ollama_online():
        pytest.skip("Ollama API not running. Expect it running on localhost:11434")

    await run_simple_test(tmp_path, "phi_3_5", "ollama")


@pytest.mark.ollama
async def test_ollama_gemma(tmp_path):
    # Check if Ollama API is running
    if not await ollama_online():
        pytest.skip("Ollama API not running. Expect it running on localhost:11434")

    await run_simple_test(tmp_path, "gemma_2_2b", "ollama")


@pytest.mark.ollama
async def test_autoselect_provider(tmp_path):
    # Check if Ollama API is running
    if not await ollama_online():
        pytest.skip("Ollama API not running. Expect it running on localhost:11434")

    await run_simple_test(tmp_path, "phi_3_5")


@pytest.mark.ollama
async def test_ollama_llama(tmp_path):
    # Check if Ollama API is running
    if not await ollama_online():
        pytest.skip("Ollama API not running. Expect it running on localhost:11434")

    await run_simple_test(tmp_path, "llama_3_1_8b", "ollama")


@pytest.mark.paid
async def test_openai(tmp_path):
    if os.getenv("OPENAI_API_KEY") is None:
        pytest.skip("OPENAI_API_KEY not set")
    await run_simple_test(tmp_path, "gpt_4o_mini", "openai")


@pytest.mark.paid
async def test_amazon_bedrock(tmp_path):
    if (
        os.getenv("AWS_SECRET_ACCESS_KEY") is None
        or os.getenv("AWS_ACCESS_KEY_ID") is None
    ):
        pytest.skip("AWS keys not set")
    await run_simple_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")


async def test_mock(tmp_path):
    task = build_test_task(tmp_path)
    mockChatModel = FakeListChatModel(responses=["mock response"])
    adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
    run = await adapter.invoke("You are a mock, send me the response!")
    assert "mock response" in run.output.output


async def test_mock_returning_run(tmp_path):
    task = build_test_task(tmp_path)
    mockChatModel = FakeListChatModel(responses=["mock response"])
    adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
    run = await adapter.invoke("You are a mock, send me the response!")
    assert run.output.output == "mock response"
    assert run is not None
    assert run.id is not None
    assert run.input == "You are a mock, send me the response!"
    assert run.output.output == "mock response"
    assert "created_by" in run.input_source.properties
    assert run.output.source.properties == {
        "adapter_name": "kiln_langchain_adapter",
        "model_name": "custom.langchain:unknown_model",
        "model_provider": "custom.langchain:FakeListChatModel",
        "prompt_builder_name": "simple_prompt_builder",
    }


@pytest.mark.paid
@pytest.mark.ollama
async def test_all_built_in_models(tmp_path):
    task = build_test_task(tmp_path)
    for model in built_in_models:
        for provider in model.providers:
            try:
                print(f"Running {model.name} {provider.name}")
                await run_simple_task(task, model.name, provider.name)
            except Exception as e:
                raise RuntimeError(f"Error running {model.name} {provider}") from e


def build_test_task(tmp_path: Path):
    project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
    project.save_to_file()
    assert project.name == "test"

    r1 = datamodel.TaskRequirement(
        name="BEDMAS",
        instruction="You follow order of mathematical operation (BEDMAS)",
    )
    r2 = datamodel.TaskRequirement(
        name="only basic math",
        instruction="If the problem has anything other than addition, subtraction, multiplication, division, and brackets, you will not answer it. Reply instead with 'I'm just a basic calculator, I don't know how to do that'.",
    )
    r3 = datamodel.TaskRequirement(
        name="Answer format",
        instruction="The answer can contain any content about your reasoning, but at the end it should include the final answer in numerals in square brackets. For example if the answer is one hundred, the end of your response should be [100].",
    )
    task = datamodel.Task(
        parent=project,
        name="test task",
        instruction="You are an assistant which performs math tasks provided in plain text.",
        requirements=[r1, r2, r3],
    )
    task.save_to_file()
    assert task.name == "test task"
    assert len(task.requirements) == 3
    return task


async def run_simple_test(tmp_path: Path, model_name: str, provider: str | None = None):
    task = build_test_task(tmp_path)
    return await run_simple_task(task, model_name, provider)


async def run_simple_task(task: datamodel.Task, model_name: str, provider: str):
    adapter = LangChainPromptAdapter(task, model_name=model_name, provider=provider)

    run = await adapter.invoke(
        "You should answer the following question: four plus six times 10"
    )
    assert "64" in run.output.output
    assert run.id is not None
    assert (
        run.input == "You should answer the following question: four plus six times 10"
    )
    assert "64" in run.output.output
    assert run.output.source.properties == {
        "adapter_name": "kiln_langchain_adapter",
        "model_name": model_name,
        "model_provider": provider,
        "prompt_builder_name": "simple_prompt_builder",
    }
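These prompt-adaptor tests gate networked runs behind custom paid and ollama markers. If those markers are not already registered in the package's own pytest configuration (it may well have one), a minimal, hypothetical conftest.py registration like the following lets pytest -m "not paid and not ollama" run only the offline tests without unknown-marker warnings:

# conftest.py (hypothetical): register the custom markers used above so that
# `pytest -m "not paid and not ollama"` cleanly skips networked runs.
def pytest_configure(config):
    config.addinivalue_line("markers", "paid: tests that call paid model APIs")
    config.addinivalue_line(
        "markers", "ollama: tests that need a local Ollama server on localhost:11434"
    )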