kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in the supported public registries. It is provided for informational purposes only.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py
@@ -0,0 +1,1103 @@
import json
from pathlib import Path
from unittest.mock import Mock, patch

import pytest
from litellm.types.utils import ModelResponse
from litellm.types.utils import Usage as LiteLlmUsage

from kiln_ai import datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.ml_model_list import KilnModelProvider, built_in_models
from kiln_ai.adapters.model_adapters.litellm_adapter import (
    LiteLlmAdapter,
    ModelTurnResult,
)
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
from kiln_ai.datamodel import PromptId
from kiln_ai.datamodel.datamodel_enums import ModelProviderName, StructuredOutputMode
from kiln_ai.datamodel.task import RunConfigProperties
from kiln_ai.tools.built_in_tools.math_tools import (
    AddTool,
    DivideTool,
    MultiplyTool,
    SubtractTool,
)
from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam


def build_test_task(tmp_path: Path):
    project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
    project.save_to_file()
    assert project.name == "test"

    r1 = datamodel.TaskRequirement(
        name="BEDMAS",
        instruction="You follow order of mathematical operation (BEDMAS)",
    )
    r2 = datamodel.TaskRequirement(
        name="only basic math",
        instruction="If the problem has anything other than addition, subtraction, multiplication, division, and brackets, you will not answer it. Reply instead with 'I'm just a basic calculator, I don't know how to do that'.",
    )
    r3 = datamodel.TaskRequirement(
        name="use tools for math",
        instruction="Always use the tools provided for math tasks",
    )
    r4 = datamodel.TaskRequirement(
        name="Answer format",
        instruction="The answer can contain any content about your reasoning, but at the end it should include the final answer in numerals in square brackets. For example if the answer is one hundred, the end of your response should be [100].",
    )
    task = datamodel.Task(
        parent=project,
        name="test task",
        instruction="You are an assistant which performs math tasks provided in plain text using functions/tools.\n\nYou must use function calling (tools) for math tasks or you will be penalized. For example if requested to answer 2+2, you must call the 'add' function with a=2 and b=2 or the answer will be rejected.",
        requirements=[r1, r2, r3, r4],
    )
    task.save_to_file()
    assert task.name == "test task"
    assert len(task.requirements) == 4
    return task


async def run_simple_task_with_tools(
    task: datamodel.Task,
    model_name: str,
    provider: str,
    simplified: bool = False,
    prompt_id: PromptId | None = None,
) -> datamodel.TaskRun:
    adapter = adapter_for_task(
        task,
        RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name=model_name,
            model_provider_name=ModelProviderName(provider),
            prompt_id=prompt_id or "simple_prompt_builder",
        ),
    )

    # Create tools with MultiplyTool wrapped in a spy
    multiply_tool = MultiplyTool()
    multiply_spy = Mock(wraps=multiply_tool)
    add_tool = AddTool()
    add_spy = Mock(wraps=add_tool)
    mock_math_tools = [add_spy, SubtractTool(), multiply_spy, DivideTool()]

    with patch.object(adapter, "available_tools", return_value=mock_math_tools):
        if simplified:
            run = await adapter.invoke("what is 2+2")

            # Verify that AddTool.run was called with correct parameters
            add_spy.run.assert_called()
            add_call_args = add_spy.run.call_args
            add_kwargs = add_call_args.kwargs
            assert add_kwargs.get("a") == 2
            assert add_kwargs.get("b") == 2

            assert "4" in run.output.output

            trace = run.trace
            assert trace is not None
            assert len(trace) == 5
            assert trace[0]["role"] == "system"
            assert trace[1]["role"] == "user"
            assert trace[2]["role"] == "assistant"
            assert trace[3]["role"] == "tool"
            assert trace[3]["content"] == "4"
            assert trace[3]["tool_call_id"] is not None
            assert trace[4]["role"] == "assistant"
            assert "[4]" in trace[4]["content"]  # type: ignore

            # Deep dive on tool_calls, which we build ourselves
            tool_calls = trace[2].get("tool_calls", None)
            assert tool_calls is not None
            assert len(tool_calls) == 1
            assert tool_calls[0]["id"]  # not None or empty
            assert tool_calls[0]["function"]["name"] == "add"
            json_args = json.loads(tool_calls[0]["function"]["arguments"])
            assert json_args["a"] == 2
            assert json_args["b"] == 2
        else:
            run = await adapter.invoke(
                "You should answer the following question: four plus six times 10"
            )

            # Verify that MultiplyTool.run was called with correct parameters
            multiply_spy.run.assert_called()
            multiply_call_args = multiply_spy.run.call_args
            multiply_kwargs = multiply_call_args.kwargs
            # Check that multiply was called with a=6, b=10 (or vice versa)
            assert (
                multiply_kwargs.get("a") == 6 and multiply_kwargs.get("b") == 10
            ) or (multiply_kwargs.get("a") == 10 and multiply_kwargs.get("b") == 6), (
                f"Expected multiply to be called with a=6, b=10 or a=10, b=6, but got {multiply_kwargs}"
            )

            # Verify that AddTool.run was called with correct parameters
            add_spy.run.assert_called()
            add_call_args = add_spy.run.call_args
            add_kwargs = add_call_args.kwargs
            # Check that add was called with a=60, b=4 (or vice versa)
            assert (add_kwargs.get("a") == 60 and add_kwargs.get("b") == 4) or (
                add_kwargs.get("a") == 4 and add_kwargs.get("b") == 60
            ), (
                f"Expected add to be called with a=60, b=4 or a=4, b=60, but got {add_kwargs}"
            )

            assert "64" in run.output.output
            assert (
                run.input
                == "You should answer the following question: four plus six times 10"
            )
            assert "64" in run.output.output

            trace = run.trace
            assert trace is not None
            assert len(trace) == 7
            assert trace[0]["role"] == "system"
            assert trace[1]["role"] == "user"
            assert trace[2]["role"] == "assistant"
            assert trace[3]["role"] == "tool"
            assert trace[3]["content"] == "60"
            assert trace[4]["role"] == "assistant"
            assert trace[5]["role"] == "tool"
            assert trace[5]["content"] == "64"
            assert trace[6]["role"] == "assistant"
            assert "[64]" in trace[6]["content"]  # type: ignore

    assert run.id is not None
    source_props = run.output.source.properties if run.output.source else {}
    assert source_props["adapter_name"] in [
        "kiln_langchain_adapter",
        "kiln_openai_compatible_adapter",
    ]
    assert source_props["model_name"] == model_name
    assert source_props["model_provider"] == provider
    if prompt_id is None:
        assert source_props["prompt_id"] == "simple_prompt_builder"
    else:
        assert source_props["prompt_id"] == prompt_id
    return run


@pytest.mark.paid
async def test_tools_gpt_4_1_mini(tmp_path):
    task = build_test_task(tmp_path)
    await run_simple_task_with_tools(task, "gpt_4_1_mini", ModelProviderName.openai)


@pytest.mark.paid
async def test_tools_gpt_4_1_mini_simplified(tmp_path):
    task = build_test_task(tmp_path)
    await run_simple_task_with_tools(
        task, "gpt_4_1_mini", ModelProviderName.openai, simplified=True
    )


def check_supports_structured_output(model_name: str, provider_name: str):
    for model in built_in_models:
        if model.name != model_name:
            continue
        for provider in model.providers:
            if provider.name != provider_name:
                continue
            if not provider.supports_function_calling:
                pytest.skip(
                    f"Skipping {model.name} {provider.name} because it does not support function calling"
                )
            return
    raise RuntimeError(f"No model {model_name} {provider_name} found")


@pytest.mark.paid
@pytest.mark.ollama
@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
async def test_tools_all_built_in_models(tmp_path, model_name, provider_name):
    check_supports_structured_output(model_name, provider_name)
    task = build_test_task(tmp_path)
    # For the test of all models run the simplified test, we're checking if it can handle any tool calls, not getting fancy with it
    await run_simple_task_with_tools(task, model_name, provider_name, simplified=True)


async def test_tools_simplied_mocked(tmp_path):
    task = build_test_task(tmp_path)

    # Usage should add up, not just return the last one.
    usage = LiteLlmUsage(
        prompt_tokens=10,
        completion_tokens=20,
        total_tokens=30,
        cost=0.5,
    )

    # Mock 2 responses using tool calls adding 2+2
    # First response: requests add tool call for 2+2
    # Second response: final answer: 4
    # this should trigger proper asserts in the run_simple_task_with_tools function

    # First response: requests add tool call
    mock_response_1 = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {
                "message": {
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "tool_call_add",
                            "type": "function",
                            "function": {
                                "name": "add",
                                "arguments": '{"a": 2, "b": 2}',
                            },
                        }
                    ],
                }
            }
        ],
        usage=usage,
    )

    # Second response: final answer
    mock_response_2 = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {
                "message": {
                    "content": "The answer is [4]",
                    "tool_calls": None,
                    "reasoning_content": "I used a tool",
                }
            }
        ],
        usage=usage,
    )

    # Mock the Config.shared() method to return a mock config with required attributes
    mock_config = Mock()
    mock_config.open_ai_api_key = "mock_api_key"
    mock_config.user_id = "test_user"

    with (
        patch(
            "litellm.acompletion",
            side_effect=[mock_response_1, mock_response_2],
        ),
        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
    ):
        task_run = await run_simple_task_with_tools(
            task, "gpt_4_1_mini", ModelProviderName.openai, simplified=True
        )
        assert task_run.usage is not None
        assert task_run.usage.input_tokens == 20
        assert task_run.usage.output_tokens == 40
        assert task_run.usage.total_tokens == 60
        assert task_run.usage.cost == 1.0

        # Check reasoning content in the trace
        trace = task_run.trace
        assert trace is not None
        assert len(trace) == 5
        assert trace[4].get("reasoning_content") == "I used a tool"


async def test_tools_mocked(tmp_path):
    task = build_test_task(tmp_path)

    # Usage should add up, not just return the last one.
    usage = LiteLlmUsage(
        prompt_tokens=10,
        completion_tokens=20,
        total_tokens=30,
        cost=0.5,
    )

    # Mock 3 responses using tool calls for BEDMAS operations matching the test math problem: (6*10)+4
    # First response: requests multiply tool call for 6*10
    # Second response: requests add tool call for 60+4
    # Third response: final answer: 64
    # this should trigger proper asserts in the run_simple_task_with_tools function

    # First response: requests multiply tool call
    mock_response_1 = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {
                "message": {
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "tool_call_multiply",
                            "type": "function",
                            "function": {
                                "name": "multiply",
                                "arguments": '{"a": 6, "b": 10}',
                            },
                        }
                    ],
                }
            }
        ],
        usage=usage,
    )

    # Second response: requests add tool call
    mock_response_2 = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {
                "message": {
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "tool_call_add",
                            "type": "function",
                            "function": {
                                "name": "add",
                                "arguments": '{"a": 60, "b": 4}',
                            },
                        }
                    ],
                }
            }
        ],
        usage=usage,
    )

    # Third response: final answer
    mock_response_3 = ModelResponse(
        model="gpt-4o-mini",
        choices=[{"message": {"content": "The answer is [64]", "tool_calls": None}}],
        usage=usage,
    )

    # Mock the Config.shared() method to return a mock config with required attributes
    mock_config = Mock()
    mock_config.open_ai_api_key = "mock_api_key"
    mock_config.user_id = "test_user"

    with (
        patch(
            "litellm.acompletion",
            side_effect=[mock_response_1, mock_response_2, mock_response_3],
        ),
        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
    ):
        task_run = await run_simple_task_with_tools(
            task, "gpt_4_1_mini", ModelProviderName.openai
        )
        assert task_run.usage is not None
        assert task_run.usage.input_tokens == 30
        assert task_run.usage.output_tokens == 60
        assert task_run.usage.total_tokens == 90
        assert task_run.usage.cost == 1.5


async def test_run_model_turn_parallel_tools(tmp_path):
    """Test _run_model_turn with multiple parallel tool calls in a single response."""
    task = build_test_task(tmp_path)
    # Cast to LiteLlmAdapter to access _run_model_turn
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    # Mock multiple parallel tool calls
    mock_response = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {
                "message": {
                    "content": "I'll solve this step by step using the tools.",
                    "tool_calls": [
                        {
                            "id": "tool_call_multiply",
                            "type": "function",
                            "function": {
                                "name": "multiply",
                                "arguments": '{"a": 6, "b": 10}',
                            },
                        },
                        {
                            "id": "tool_call_add",
                            "type": "function",
                            "function": {
                                "name": "add",
                                "arguments": '{"a": 2, "b": 3}',
                            },
                        },
                    ],
                }
            }
        ],
    )

    # Mock final response after tool execution
    final_response = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {"message": {"content": "The results are 60 and 5", "tool_calls": None}}
        ],
    )

    provider = KilnModelProvider(name=ModelProviderName.openai, model_id="gpt_4_1_mini")

    prior_messages: list[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Calculate 6*10 and 2+3"}
    ]

    # Create tools with spies
    multiply_tool = MultiplyTool()
    multiply_spy = Mock(wraps=multiply_tool)

    add_tool = AddTool()
    add_spy = Mock(wraps=add_tool)

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[multiply_spy, add_spy]
    ):
        with patch(
            "litellm.acompletion",
            side_effect=[mock_response, final_response],
        ):
            with patch.object(
                litellm_adapter, "build_completion_kwargs", return_value={}
            ):
                with patch.object(
                    litellm_adapter,
                    "acompletion_checking_response",
                    side_effect=[
                        (mock_response, mock_response.choices[0]),
                        (final_response, final_response.choices[0]),
                    ],
                ):
                    result = await litellm_adapter._run_model_turn(
                        provider, prior_messages, None, False
                    )

    # Verify both tools were called in parallel
    multiply_spy.run.assert_called_once_with(a=6, b=10)
    add_spy.run.assert_called_once_with(a=2, b=3)

    # Verify the result structure
    assert isinstance(result, ModelTurnResult)
    assert result.assistant_message == "The results are 60 and 5"
    assert (
        len(result.all_messages) == 5
    )  # user + assistant + 2 tool results + final assistant


async def test_run_model_turn_sequential_tools(tmp_path):
    """Test _run_model_turn with sequential tool calls across multiple turns."""
    task = build_test_task(tmp_path)
    # Cast to LiteLlmAdapter to access _run_model_turn
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    # First response: requests multiply tool call
    mock_response_1 = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {
                "message": {
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "tool_call_multiply",
                            "type": "function",
                            "function": {
                                "name": "multiply",
                                "arguments": '{"a": 6, "b": 10}',
                            },
                        }
                    ],
                }
            }
        ],
    )

    # Second response: requests add tool call using result from first
    mock_response_2 = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {
                "message": {
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "tool_call_add",
                            "type": "function",
                            "function": {
                                "name": "add",
                                "arguments": '{"a": 60, "b": 4}',
                            },
                        }
                    ],
                }
            }
        ],
    )

    # Final response with answer
    mock_response_3 = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {"message": {"content": "The final answer is 64", "tool_calls": None}}
        ],
    )

    provider = KilnModelProvider(name=ModelProviderName.openai, model_id="gpt_4_1_mini")

    prior_messages: list[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Calculate (6*10)+4"}
    ]

    # Create tools with spies
    multiply_tool = MultiplyTool()
    multiply_spy = Mock(wraps=multiply_tool)

    add_tool = AddTool()
    add_spy = Mock(wraps=add_tool)

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[multiply_spy, add_spy]
    ):
        with patch(
            "litellm.acompletion",
            side_effect=[mock_response_1, mock_response_2, mock_response_3],
        ):
            with patch.object(
                litellm_adapter, "build_completion_kwargs", return_value={}
            ):
                with patch.object(
                    litellm_adapter,
                    "acompletion_checking_response",
                    side_effect=[
                        (mock_response_1, mock_response_1.choices[0]),
                        (mock_response_2, mock_response_2.choices[0]),
                        (mock_response_3, mock_response_3.choices[0]),
                    ],
                ):
                    result = await litellm_adapter._run_model_turn(
                        provider, prior_messages, None, False
                    )

    # Verify tools were called sequentially
    multiply_spy.run.assert_called_once_with(a=6, b=10)
    add_spy.run.assert_called_once_with(a=60, b=4)

    # Verify the result structure
    assert isinstance(result, ModelTurnResult)
    assert result.assistant_message == "The final answer is 64"
    # Messages: user + assistant1 + tool1 + assistant2 + tool2 + final assistant
    assert len(result.all_messages) == 6


async def test_run_model_turn_max_tool_calls_exceeded(tmp_path):
    """Test _run_model_turn raises error when MAX_TOOL_CALLS_PER_TURN is exceeded."""
    task = build_test_task(tmp_path)
    # Cast to LiteLlmAdapter to access _run_model_turn
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    # Mock response that always returns a tool call (creates infinite loop)
    mock_response = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {
                "message": {
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "tool_call_add",
                            "type": "function",
                            "function": {
                                "name": "add",
                                "arguments": '{"a": 1, "b": 1}',
                            },
                        }
                    ],
                }
            }
        ],
    )

    provider = KilnModelProvider(name=ModelProviderName.openai, model_id="gpt_4_1_mini")

    prior_messages: list[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Keep adding 1+1"}
    ]

    # Create tool with spy
    add_tool = AddTool()
    add_spy = Mock(wraps=add_tool)

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[add_spy]
    ):
        with patch(
            "litellm.acompletion",
            return_value=mock_response,
        ):
            with patch.object(
                litellm_adapter, "build_completion_kwargs", return_value={}
            ):
                with patch.object(
                    litellm_adapter,
                    "acompletion_checking_response",
                    return_value=(mock_response, mock_response.choices[0]),
                ):
                    with pytest.raises(RuntimeError, match="Too many tool calls"):
                        await litellm_adapter._run_model_turn(
                            provider, prior_messages, None, False
                        )


async def test_run_model_turn_no_tool_calls(tmp_path):
    """Test _run_model_turn with a simple response that doesn't use tools."""
    task = build_test_task(tmp_path)
    # Cast to LiteLlmAdapter to access _run_model_turn
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    # Mock response without tool calls
    mock_response = ModelResponse(
        model="gpt-4o-mini",
        choices=[
            {"message": {"content": "This is a simple response", "tool_calls": None}}
        ],
    )

    provider = KilnModelProvider(name=ModelProviderName.openai, model_id="gpt_4_1_mini")

    prior_messages: list[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Hello, how are you?"}
    ]

    with patch.object(litellm_adapter, "build_completion_kwargs", return_value={}):
        with patch.object(
            litellm_adapter,
            "acompletion_checking_response",
            return_value=(mock_response, mock_response.choices[0]),
        ):
            result = await litellm_adapter._run_model_turn(
                provider, prior_messages, None, False
            )

    # Verify the result structure
    assert isinstance(result, ModelTurnResult)
    assert result.assistant_message == "This is a simple response"
    assert len(result.all_messages) == 2  # user + assistant


# Unit tests for process_tool_calls method
class MockToolCall:
    """Mock class for ChatCompletionMessageToolCall"""

    def __init__(self, id: str, function_name: str, arguments: str):
        self.id = id
        self.function = Mock()
        self.function.name = function_name
        self.function.arguments = arguments
        self.type = "function"


class MockTool:
    """Mock tool class for testing"""

    def __init__(
        self,
        name: str,
        raise_on_run: Exception | None = None,
        return_value: str = "test_result",
    ):
        self._name = name
        self._raise_on_run = raise_on_run
        self._return_value = return_value

    async def name(self) -> str:
        return self._name

    async def toolcall_definition(self) -> dict:
        return {
            "function": {
                "parameters": {
                    "type": "object",
                    "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
                    "required": ["a", "b"],
                }
            }
        }

    async def run(self, **kwargs) -> str:
        if self._raise_on_run:
            raise self._raise_on_run
        return self._return_value


async def test_process_tool_calls_none_input(tmp_path):
    """Test process_tool_calls with None input"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    assistant_output, tool_messages = await litellm_adapter.process_tool_calls(None)

    assert assistant_output is None
    assert tool_messages == []


async def test_process_tool_calls_empty_list(tmp_path):
    """Test process_tool_calls with empty tool calls list"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    assistant_output, tool_messages = await litellm_adapter.process_tool_calls([])

    assert assistant_output is None
    assert tool_messages == []


async def test_process_tool_calls_task_response_only(tmp_path):
    """Test process_tool_calls with only task_response tool call"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    tool_calls = [MockToolCall("call_1", "task_response", '{"answer": "42"}')]

    assistant_output, tool_messages = await litellm_adapter.process_tool_calls(
        tool_calls  # type: ignore
    )

    assert assistant_output == '{"answer": "42"}'
    assert tool_messages == []


async def test_process_tool_calls_multiple_task_response(tmp_path):
    """Test process_tool_calls with multiple task_response calls - should keep the last one"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    tool_calls = [
        MockToolCall("call_1", "task_response", '{"answer": "first"}'),
        MockToolCall("call_2", "task_response", '{"answer": "second"}'),
    ]

    assistant_output, tool_messages = await litellm_adapter.process_tool_calls(
        tool_calls  # type: ignore
    )

    # Should keep the last task_response
    assert assistant_output == '{"answer": "second"}'
    assert tool_messages == []


async def test_process_tool_calls_normal_tool_success(tmp_path):
    """Test process_tool_calls with successful normal tool call"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    mock_tool = MockTool("add", return_value="5")
    tool_calls = [MockToolCall("call_1", "add", '{"a": 2, "b": 3}')]

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[mock_tool]
    ):
        assistant_output, tool_messages = await litellm_adapter.process_tool_calls(
            tool_calls  # type: ignore
        )

    assert assistant_output is None
    assert len(tool_messages) == 1
    assert tool_messages[0] == {
        "role": "tool",
        "tool_call_id": "call_1",
        "content": "5",
    }


async def test_process_tool_calls_multiple_normal_tools(tmp_path):
    """Test process_tool_calls with multiple normal tool calls"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    mock_tool_add = MockTool("add", return_value="5")
    mock_tool_multiply = MockTool("multiply", return_value="6")
    tool_calls = [
        MockToolCall("call_1", "add", '{"a": 2, "b": 3}'),
        MockToolCall("call_2", "multiply", '{"a": 2, "b": 3}'),
    ]

    with patch.object(
        litellm_adapter,
        "cached_available_tools",
        return_value=[mock_tool_add, mock_tool_multiply],
    ):
        assistant_output, tool_messages = await litellm_adapter.process_tool_calls(
            tool_calls  # type: ignore
        )

    assert assistant_output is None
    assert len(tool_messages) == 2
    assert tool_messages[0]["tool_call_id"] == "call_1"
    assert tool_messages[0]["content"] == "5"
    assert tool_messages[1]["tool_call_id"] == "call_2"
    assert tool_messages[1]["content"] == "6"


async def test_process_tool_calls_tool_not_found(tmp_path):
    """Test process_tool_calls when tool is not found"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    tool_calls = [MockToolCall("call_1", "nonexistent_tool", '{"a": 2, "b": 3}')]

    with patch.object(litellm_adapter, "cached_available_tools", return_value=[]):
        with pytest.raises(
            RuntimeError,
            match="A tool named 'nonexistent_tool' was invoked by a model, but was not available",
        ):
            await litellm_adapter.process_tool_calls(tool_calls)  # type: ignore


async def test_process_tool_calls_invalid_json_arguments(tmp_path):
    """Test process_tool_calls with invalid JSON arguments"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    mock_tool = MockTool("add")
    tool_calls = [MockToolCall("call_1", "add", "invalid json")]

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[mock_tool]
    ):
        with pytest.raises(
            RuntimeError, match="Failed to parse arguments for tool 'add'"
        ):
            await litellm_adapter.process_tool_calls(tool_calls)  # type: ignore


async def test_process_tool_calls_empty_arguments(tmp_path):
    """Test process_tool_calls with empty arguments string"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    mock_tool = MockTool("add")
    tool_calls = [MockToolCall("call_1", "add", "")]

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[mock_tool]
    ):
        with pytest.raises(
            RuntimeError, match="Failed to parse arguments for tool 'add'"
        ):
            await litellm_adapter.process_tool_calls(tool_calls)  # type: ignore


async def test_process_tool_calls_schema_validation_error(tmp_path):
    """Test process_tool_calls with schema validation error"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    mock_tool = MockTool("add")
    # Missing required field 'b'
    tool_calls = [MockToolCall("call_1", "add", '{"a": 2}')]

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[mock_tool]
    ):
        with pytest.raises(
            RuntimeError, match="Failed to validate arguments for tool 'add'"
        ):
            await litellm_adapter.process_tool_calls(tool_calls)  # type: ignore


async def test_process_tool_calls_tool_execution_error(tmp_path):
    """Test process_tool_calls when tool execution raises exception"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    # Mock tool that raises exception when run
    mock_tool = MockTool("add", raise_on_run=ValueError("Tool execution failed"))
    tool_calls = [MockToolCall("call_1", "add", '{"a": 2, "b": 3}')]

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[mock_tool]
    ):
        # This should raise the ValueError from the tool
        with pytest.raises(ValueError, match="Tool execution failed"):
            await litellm_adapter.process_tool_calls(tool_calls)  # type: ignore


async def test_process_tool_calls_complex_result(tmp_path):
    """Test process_tool_calls when tool returns complex object"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    complex_result = json.dumps(
        {"status": "success", "result": 42, "metadata": [1, 2, 3]}
    )
    mock_tool = MockTool("add", return_value=complex_result)
    tool_calls = [MockToolCall("call_1", "add", '{"a": 2, "b": 3}')]

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[mock_tool]
    ):
        assistant_output, tool_messages = await litellm_adapter.process_tool_calls(
            tool_calls  # type: ignore
        )

    assert assistant_output is None
    assert len(tool_messages) == 1
    assert tool_messages[0]["content"] == complex_result


async def test_process_tool_calls_task_response_with_normal_tools_error(tmp_path):
    """Test process_tool_calls raises error when mixing task_response with normal tools"""
    task = build_test_task(tmp_path)
    config = LiteLlmConfig(
        run_config_properties=RunConfigProperties(
            structured_output_mode=StructuredOutputMode.json_schema,
            model_name="gpt_4_1_mini",
            model_provider_name=ModelProviderName.openai,
            prompt_id="simple_prompt_builder",
        )
    )
    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)

    mock_tool = MockTool("add", return_value="5")
    tool_calls = [
        MockToolCall("call_1", "task_response", '{"answer": "42"}'),
        MockToolCall("call_2", "add", '{"a": 2, "b": 3}'),
    ]

    with patch.object(
        litellm_adapter, "cached_available_tools", return_value=[mock_tool]
    ):
        with pytest.raises(
            RuntimeError,
            match="task_response tool call and other tool calls were both provided",
        ):
            await litellm_adapter.process_tool_calls(tool_calls)  # type: ignore