kiln-ai 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
- kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
- kiln_ai/adapters/ml_embedding_model_list.py +330 -28
- kiln_ai/adapters/ml_model_list.py +503 -23
- kiln_ai/adapters/model_adapters/litellm_adapter.py +39 -8
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
- kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
- kiln_ai/adapters/test_ml_model_list.py +0 -10
- kiln_ai/adapters/vector_store/lancedb_adapter.py +24 -70
- kiln_ai/adapters/vector_store/lancedb_helpers.py +101 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +9 -16
- kiln_ai/adapters/vector_store/test_lancedb_helpers.py +142 -0
- kiln_ai/adapters/vector_store_loaders/__init__.py +0 -0
- kiln_ai/adapters/vector_store_loaders/test_lancedb_loader.py +282 -0
- kiln_ai/adapters/vector_store_loaders/test_vector_store_loader.py +544 -0
- kiln_ai/adapters/vector_store_loaders/vector_store_loader.py +91 -0
- kiln_ai/datamodel/basemodel.py +31 -3
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +14 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +3 -3
- kiln_ai/datamodel/test_basemodel.py +269 -13
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_model.py +31 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +106 -1
- kiln_ai/datamodel/tool_id.py +49 -0
- kiln_ai/tools/base_tool.py +30 -6
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +162 -0
- kiln_ai/tools/mcp_server_tool.py +7 -5
- kiln_ai/tools/mcp_session_manager.py +50 -24
- kiln_ai/tools/rag_tools.py +17 -6
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +186 -226
- kiln_ai/tools/test_rag_tools.py +86 -5
- kiln_ai/tools/test_tool_registry.py +199 -5
- kiln_ai/tools/tool_registry.py +49 -17
- kiln_ai/utils/filesystem.py +4 -4
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +21 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +14 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/METADATA +79 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/RECORD +53 -45
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -2,6 +2,7 @@ import logging
|
|
|
2
2
|
import os
|
|
3
3
|
import subprocess
|
|
4
4
|
import sys
|
|
5
|
+
import tempfile
|
|
5
6
|
from contextlib import asynccontextmanager
|
|
6
7
|
from datetime import timedelta
|
|
7
8
|
from typing import AsyncGenerator
|
|
@@ -19,6 +20,8 @@ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
|
19
20
|
|
|
20
21
|
logger = logging.getLogger(__name__)
|
|
21
22
|
|
|
23
|
+
LOCAL_MCP_ERROR_INSTRUCTION = "Please verify your command, arguments, and environment variables, and consult the server's documentation for the correct setup."
|
|
24
|
+
|
|
22
25
|
|
|
23
26
|
class MCPSessionManager:
|
|
24
27
|
"""
|
|
@@ -51,6 +54,8 @@ class MCPSessionManager:
|
|
|
51
54
|
case ToolServerType.local_mcp:
|
|
52
55
|
async with self._create_local_mcp_session(tool_server) as session:
|
|
53
56
|
yield session
|
|
57
|
+
case ToolServerType.kiln_task:
|
|
58
|
+
raise ValueError("Kiln task tools are not available from an MCP server")
|
|
54
59
|
case _:
|
|
55
60
|
raise_exhaustive_enum_error(tool_server.type)
|
|
56
61
|
|
|
@@ -164,35 +169,56 @@ class MCPSessionManager:
|
|
|
164
169
|
env_vars["PATH"] = self._get_path()
|
|
165
170
|
|
|
166
171
|
# Set the server parameters
|
|
172
|
+
cwd = os.path.join(Config.settings_dir(), "cache", "mcp_cache")
|
|
173
|
+
os.makedirs(cwd, exist_ok=True)
|
|
167
174
|
server_params = StdioServerParameters(
|
|
168
|
-
command=command,
|
|
169
|
-
args=args,
|
|
170
|
-
env=env_vars,
|
|
175
|
+
command=command, args=args, env=env_vars, cwd=cwd
|
|
171
176
|
)
|
|
172
177
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
178
|
+
# Create temporary file to capture MCP server stderr
|
|
179
|
+
# Use errors="replace" to handle non-UTF-8 bytes gracefully
|
|
180
|
+
with tempfile.TemporaryFile(
|
|
181
|
+
mode="w+", encoding="utf-8", errors="replace"
|
|
182
|
+
) as err_log:
|
|
183
|
+
try:
|
|
184
|
+
async with stdio_client(server_params, errlog=err_log) as (
|
|
185
|
+
read,
|
|
186
|
+
write,
|
|
187
|
+
):
|
|
188
|
+
async with ClientSession(
|
|
189
|
+
read, write, read_timeout_seconds=timedelta(seconds=30)
|
|
190
|
+
) as session:
|
|
191
|
+
await session.initialize()
|
|
192
|
+
yield session
|
|
193
|
+
except Exception as e:
|
|
194
|
+
# Read stderr content from temporary file for debugging
|
|
195
|
+
err_log.seek(0) # Read from the start of the file
|
|
196
|
+
stderr_content = err_log.read()
|
|
197
|
+
if stderr_content:
|
|
198
|
+
logger.error(
|
|
199
|
+
f"MCP server '{tool_server.name}' stderr output: {stderr_content}"
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
# Check for MCP errors. Things like wrong arguments would fall here.
|
|
203
|
+
mcp_error = self._extract_first_exception(e, McpError)
|
|
204
|
+
if mcp_error and isinstance(mcp_error, McpError):
|
|
205
|
+
self._raise_local_mcp_error(mcp_error, stderr_content)
|
|
206
|
+
|
|
207
|
+
# Re-raise the original error but with a friendlier message
|
|
208
|
+
self._raise_local_mcp_error(e, stderr_content)
|
|
209
|
+
|
|
210
|
+
def _raise_local_mcp_error(self, e: Exception, stderr: str):
|
|
190
211
|
"""
|
|
191
|
-
Raise a
|
|
212
|
+
Raise a RuntimeError with a friendlier message for local MCP errors.
|
|
192
213
|
"""
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
214
|
+
error_msg = f"'{e}'"
|
|
215
|
+
|
|
216
|
+
if stderr:
|
|
217
|
+
error_msg += f"\nMCP server error: {stderr}"
|
|
218
|
+
|
|
219
|
+
error_msg += f"\n{LOCAL_MCP_ERROR_INSTRUCTION}"
|
|
220
|
+
|
|
221
|
+
raise RuntimeError(error_msg) from e
|
|
196
222
|
|
|
197
223
|
def _get_path(self) -> str:
|
|
198
224
|
"""
|
kiln_ai/tools/rag_tools.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from functools import cached_property
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import List, TypedDict
|
|
3
3
|
|
|
4
4
|
from pydantic import BaseModel
|
|
5
5
|
|
|
@@ -18,7 +18,11 @@ from kiln_ai.datamodel.project import Project
|
|
|
18
18
|
from kiln_ai.datamodel.rag import RagConfig
|
|
19
19
|
from kiln_ai.datamodel.tool_id import ToolId
|
|
20
20
|
from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
|
|
21
|
-
from kiln_ai.tools.base_tool import
|
|
21
|
+
from kiln_ai.tools.base_tool import (
|
|
22
|
+
KilnToolInterface,
|
|
23
|
+
ToolCallContext,
|
|
24
|
+
ToolCallDefinition,
|
|
25
|
+
)
|
|
22
26
|
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
23
27
|
|
|
24
28
|
|
|
@@ -46,6 +50,10 @@ def format_search_results(search_results: List[SearchResult]) -> str:
|
|
|
46
50
|
return "\n=========\n".join([result.serialize() for result in results])
|
|
47
51
|
|
|
48
52
|
|
|
53
|
+
class RagParams(TypedDict):
|
|
54
|
+
query: str
|
|
55
|
+
|
|
56
|
+
|
|
49
57
|
class RagTool(KilnToolInterface):
|
|
50
58
|
"""
|
|
51
59
|
A tool that searches the vector store and returns the most relevant chunks.
|
|
@@ -115,7 +123,7 @@ class RagTool(KilnToolInterface):
|
|
|
115
123
|
async def description(self) -> str:
|
|
116
124
|
return self._description
|
|
117
125
|
|
|
118
|
-
async def toolcall_definition(self) ->
|
|
126
|
+
async def toolcall_definition(self) -> ToolCallDefinition:
|
|
119
127
|
"""Return the OpenAI-compatible tool definition for this tool."""
|
|
120
128
|
return {
|
|
121
129
|
"type": "function",
|
|
@@ -126,7 +134,10 @@ class RagTool(KilnToolInterface):
|
|
|
126
134
|
},
|
|
127
135
|
}
|
|
128
136
|
|
|
129
|
-
async def run(self,
|
|
137
|
+
async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
|
|
138
|
+
kwargs = RagParams(**kwargs)
|
|
139
|
+
query = kwargs["query"]
|
|
140
|
+
|
|
130
141
|
_, embedding_adapter = self.embedding
|
|
131
142
|
|
|
132
143
|
vector_store_adapter = await self.vector_store()
|
|
@@ -152,6 +163,6 @@ class RagTool(KilnToolInterface):
|
|
|
152
163
|
store_query.query_embedding = query_embedding_result.embeddings[0].vector
|
|
153
164
|
|
|
154
165
|
search_results = await vector_store_adapter.search(store_query)
|
|
155
|
-
|
|
166
|
+
search_results_as_text = format_search_results(search_results)
|
|
156
167
|
|
|
157
|
-
return
|
|
168
|
+
return search_results_as_text
|
|
@@ -0,0 +1,527 @@
|
|
|
1
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from kiln_ai.datamodel import Task
|
|
6
|
+
from kiln_ai.datamodel.datamodel_enums import ModelProviderName, StructuredOutputMode
|
|
7
|
+
from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType
|
|
8
|
+
from kiln_ai.datamodel.run_config import RunConfigProperties
|
|
9
|
+
from kiln_ai.datamodel.task import TaskRunConfig
|
|
10
|
+
from kiln_ai.datamodel.task_output import DataSource, DataSourceType
|
|
11
|
+
from kiln_ai.tools.base_tool import ToolCallContext
|
|
12
|
+
from kiln_ai.tools.kiln_task_tool import KilnTaskTool, KilnTaskToolResult
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TestKilnTaskToolResult:
|
|
16
|
+
"""Test the KilnTaskToolResult class."""
|
|
17
|
+
|
|
18
|
+
def test_init(self):
|
|
19
|
+
"""Test KilnTaskToolResult initialization."""
|
|
20
|
+
output = "test output"
|
|
21
|
+
kiln_task_tool_data = "project_id:::tool_id:::task_id:::run_id"
|
|
22
|
+
|
|
23
|
+
result = KilnTaskToolResult(output, kiln_task_tool_data)
|
|
24
|
+
|
|
25
|
+
assert result.output == output
|
|
26
|
+
assert result.kiln_task_tool_data == kiln_task_tool_data
|
|
27
|
+
|
|
28
|
+
def test_init_with_empty_strings(self):
|
|
29
|
+
"""Test KilnTaskToolResult initialization with empty strings."""
|
|
30
|
+
result = KilnTaskToolResult("", "")
|
|
31
|
+
|
|
32
|
+
assert result.output == ""
|
|
33
|
+
assert result.kiln_task_tool_data == ""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TestKilnTaskTool:
|
|
37
|
+
"""Test the KilnTaskTool class."""
|
|
38
|
+
|
|
39
|
+
@pytest.fixture
|
|
40
|
+
def mock_external_tool_server(self):
|
|
41
|
+
"""Create a mock ExternalToolServer for testing."""
|
|
42
|
+
return ExternalToolServer(
|
|
43
|
+
name="test_tool",
|
|
44
|
+
type=ToolServerType.kiln_task,
|
|
45
|
+
description="Test Kiln task tool",
|
|
46
|
+
properties={
|
|
47
|
+
"name": "test_task_tool",
|
|
48
|
+
"description": "A test task tool",
|
|
49
|
+
"task_id": "test_task_123",
|
|
50
|
+
"run_config_id": "test_config_456",
|
|
51
|
+
"is_archived": False,
|
|
52
|
+
},
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
@pytest.fixture
|
|
56
|
+
def mock_task(self):
|
|
57
|
+
"""Create a mock Task for testing."""
|
|
58
|
+
task = MagicMock(spec=Task)
|
|
59
|
+
task.id = "test_task_123"
|
|
60
|
+
task.input_json_schema = None
|
|
61
|
+
task.input_schema.return_value = None
|
|
62
|
+
task.run_configs.return_value = []
|
|
63
|
+
return task
|
|
64
|
+
|
|
65
|
+
@pytest.fixture
|
|
66
|
+
def mock_run_config(self):
|
|
67
|
+
"""Create a mock TaskRunConfig for testing."""
|
|
68
|
+
run_config = MagicMock(spec=TaskRunConfig)
|
|
69
|
+
run_config.id = "test_config_456"
|
|
70
|
+
run_config.run_config_properties = {
|
|
71
|
+
"model_name": "gpt-4",
|
|
72
|
+
"model_provider_name": "openai",
|
|
73
|
+
"prompt_id": "simple_prompt_builder",
|
|
74
|
+
"structured_output_mode": "default",
|
|
75
|
+
}
|
|
76
|
+
return run_config
|
|
77
|
+
|
|
78
|
+
@pytest.fixture
|
|
79
|
+
def mock_context(self):
|
|
80
|
+
"""Create a mock ToolCallContext for testing."""
|
|
81
|
+
context = MagicMock(spec=ToolCallContext)
|
|
82
|
+
context.allow_saving = True
|
|
83
|
+
return context
|
|
84
|
+
|
|
85
|
+
@pytest.fixture
|
|
86
|
+
def kiln_task_tool(self, mock_external_tool_server):
|
|
87
|
+
"""Create a KilnTaskTool instance for testing."""
|
|
88
|
+
return KilnTaskTool(
|
|
89
|
+
project_id="test_project",
|
|
90
|
+
tool_id="test_tool_id",
|
|
91
|
+
data_model=mock_external_tool_server,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
@pytest.mark.asyncio
|
|
95
|
+
async def test_init(self, mock_external_tool_server):
|
|
96
|
+
"""Test KilnTaskTool initialization."""
|
|
97
|
+
tool = KilnTaskTool(
|
|
98
|
+
project_id="test_project",
|
|
99
|
+
tool_id="test_tool_id",
|
|
100
|
+
data_model=mock_external_tool_server,
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
assert tool._project_id == "test_project"
|
|
104
|
+
assert tool._tool_id == "test_tool_id"
|
|
105
|
+
assert tool._tool_server_model == mock_external_tool_server
|
|
106
|
+
assert tool._name == "test_task_tool"
|
|
107
|
+
assert tool._description == "A test task tool"
|
|
108
|
+
assert tool._task_id == "test_task_123"
|
|
109
|
+
assert tool._run_config_id == "test_config_456"
|
|
110
|
+
|
|
111
|
+
@pytest.mark.asyncio
|
|
112
|
+
async def test_init_with_missing_properties(self):
|
|
113
|
+
"""Test KilnTaskTool initialization with missing properties."""
|
|
114
|
+
# Create a server with minimal required properties
|
|
115
|
+
server = ExternalToolServer(
|
|
116
|
+
name="test_tool",
|
|
117
|
+
type=ToolServerType.kiln_task,
|
|
118
|
+
description="Test tool",
|
|
119
|
+
properties={
|
|
120
|
+
"name": "minimal_tool",
|
|
121
|
+
"description": "",
|
|
122
|
+
"task_id": "",
|
|
123
|
+
"run_config_id": "",
|
|
124
|
+
"is_archived": False,
|
|
125
|
+
},
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
tool = KilnTaskTool(
|
|
129
|
+
project_id="test_project",
|
|
130
|
+
tool_id="test_tool_id",
|
|
131
|
+
data_model=server,
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
assert tool._name == "minimal_tool"
|
|
135
|
+
assert tool._description == ""
|
|
136
|
+
assert tool._task_id == ""
|
|
137
|
+
assert tool._run_config_id == ""
|
|
138
|
+
|
|
139
|
+
@pytest.mark.asyncio
|
|
140
|
+
async def test_id(self, kiln_task_tool):
|
|
141
|
+
"""Test the id method."""
|
|
142
|
+
result = await kiln_task_tool.id()
|
|
143
|
+
assert result == "test_tool_id"
|
|
144
|
+
|
|
145
|
+
@pytest.mark.asyncio
|
|
146
|
+
async def test_name(self, kiln_task_tool):
|
|
147
|
+
"""Test the name method."""
|
|
148
|
+
result = await kiln_task_tool.name()
|
|
149
|
+
assert result == "test_task_tool"
|
|
150
|
+
|
|
151
|
+
@pytest.mark.asyncio
|
|
152
|
+
async def test_description(self, kiln_task_tool):
|
|
153
|
+
"""Test the description method."""
|
|
154
|
+
result = await kiln_task_tool.description()
|
|
155
|
+
assert result == "A test task tool"
|
|
156
|
+
|
|
157
|
+
@pytest.mark.asyncio
|
|
158
|
+
async def test_toolcall_definition(self, kiln_task_tool):
|
|
159
|
+
"""Test the toolcall_definition method."""
|
|
160
|
+
# Mock the parameters_schema property directly
|
|
161
|
+
kiln_task_tool.parameters_schema = {"type": "object"}
|
|
162
|
+
|
|
163
|
+
definition = await kiln_task_tool.toolcall_definition()
|
|
164
|
+
|
|
165
|
+
assert definition["type"] == "function"
|
|
166
|
+
assert definition["function"]["name"] == "test_task_tool"
|
|
167
|
+
assert definition["function"]["description"] == "A test task tool"
|
|
168
|
+
assert definition["function"]["parameters"] == {"type": "object"}
|
|
169
|
+
|
|
170
|
+
@pytest.mark.asyncio
|
|
171
|
+
async def test_run_with_plaintext_input(
|
|
172
|
+
self, kiln_task_tool, mock_context, mock_task, mock_run_config
|
|
173
|
+
):
|
|
174
|
+
"""Test the run method with plaintext input."""
|
|
175
|
+
# Setup mocks
|
|
176
|
+
kiln_task_tool._task = mock_task
|
|
177
|
+
kiln_task_tool._run_config = mock_run_config
|
|
178
|
+
|
|
179
|
+
with (
|
|
180
|
+
patch(
|
|
181
|
+
"kiln_ai.adapters.adapter_registry.adapter_for_task"
|
|
182
|
+
) as mock_adapter_for_task,
|
|
183
|
+
patch(
|
|
184
|
+
"kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig"
|
|
185
|
+
) as mock_adapter_config,
|
|
186
|
+
):
|
|
187
|
+
# Mock adapter and task run
|
|
188
|
+
mock_adapter = AsyncMock()
|
|
189
|
+
mock_adapter_for_task.return_value = mock_adapter
|
|
190
|
+
|
|
191
|
+
mock_task_run = MagicMock()
|
|
192
|
+
mock_task_run.id = "run_789"
|
|
193
|
+
mock_task_run.output.output = "Task completed successfully"
|
|
194
|
+
mock_adapter.invoke.return_value = mock_task_run
|
|
195
|
+
|
|
196
|
+
# Test with plaintext input
|
|
197
|
+
result = await kiln_task_tool.run(context=mock_context, input="test input")
|
|
198
|
+
|
|
199
|
+
# Verify adapter was created correctly
|
|
200
|
+
mock_adapter_for_task.assert_called_once_with(
|
|
201
|
+
mock_task,
|
|
202
|
+
run_config_properties={
|
|
203
|
+
"model_name": "gpt-4",
|
|
204
|
+
"model_provider_name": "openai",
|
|
205
|
+
"prompt_id": "simple_prompt_builder",
|
|
206
|
+
"structured_output_mode": "default",
|
|
207
|
+
},
|
|
208
|
+
base_adapter_config=mock_adapter_config.return_value,
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
# Verify adapter config
|
|
212
|
+
mock_adapter_config.assert_called_once_with(
|
|
213
|
+
allow_saving=True,
|
|
214
|
+
default_tags=["tool_call"],
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
# Verify adapter invoke was called
|
|
218
|
+
mock_adapter.invoke.assert_called_once_with(
|
|
219
|
+
"test input",
|
|
220
|
+
input_source=DataSource(
|
|
221
|
+
type=DataSourceType.tool_call,
|
|
222
|
+
run_config=RunConfigProperties(
|
|
223
|
+
model_name="gpt-4",
|
|
224
|
+
model_provider_name=ModelProviderName.openai,
|
|
225
|
+
prompt_id="simple_prompt_builder",
|
|
226
|
+
structured_output_mode=StructuredOutputMode.default,
|
|
227
|
+
),
|
|
228
|
+
),
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
# Verify result
|
|
232
|
+
assert isinstance(result, KilnTaskToolResult)
|
|
233
|
+
assert result.output == "Task completed successfully"
|
|
234
|
+
assert (
|
|
235
|
+
result.kiln_task_tool_data
|
|
236
|
+
== "test_project:::test_tool_id:::test_task_123:::run_789"
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
@pytest.mark.asyncio
|
|
240
|
+
async def test_run_with_structured_input(
|
|
241
|
+
self, kiln_task_tool, mock_context, mock_task, mock_run_config
|
|
242
|
+
):
|
|
243
|
+
"""Test the run method with structured input."""
|
|
244
|
+
# Setup task with JSON schema
|
|
245
|
+
mock_task.input_json_schema = {
|
|
246
|
+
"type": "object",
|
|
247
|
+
"properties": {"param1": {"type": "string"}},
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
# Setup mocks
|
|
251
|
+
kiln_task_tool._task = mock_task
|
|
252
|
+
kiln_task_tool._run_config = mock_run_config
|
|
253
|
+
|
|
254
|
+
with patch(
|
|
255
|
+
"kiln_ai.adapters.adapter_registry.adapter_for_task"
|
|
256
|
+
) as mock_adapter_for_task:
|
|
257
|
+
# Mock adapter and task run
|
|
258
|
+
mock_adapter = AsyncMock()
|
|
259
|
+
mock_adapter_for_task.return_value = mock_adapter
|
|
260
|
+
|
|
261
|
+
mock_task_run = MagicMock()
|
|
262
|
+
mock_task_run.id = "run_789"
|
|
263
|
+
mock_task_run.output.output = "Structured task completed"
|
|
264
|
+
mock_adapter.invoke.return_value = mock_task_run
|
|
265
|
+
|
|
266
|
+
# Test with structured input
|
|
267
|
+
result = await kiln_task_tool.run(
|
|
268
|
+
context=mock_context, param1="value1", param2="value2"
|
|
269
|
+
)
|
|
270
|
+
|
|
271
|
+
# Verify adapter invoke was called with kwargs
|
|
272
|
+
mock_adapter.invoke.assert_called_once_with(
|
|
273
|
+
{"param1": "value1", "param2": "value2"},
|
|
274
|
+
input_source=DataSource(
|
|
275
|
+
type=DataSourceType.tool_call,
|
|
276
|
+
run_config=RunConfigProperties(
|
|
277
|
+
model_name="gpt-4",
|
|
278
|
+
model_provider_name=ModelProviderName.openai,
|
|
279
|
+
prompt_id="simple_prompt_builder",
|
|
280
|
+
structured_output_mode=StructuredOutputMode.default,
|
|
281
|
+
),
|
|
282
|
+
),
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
# Verify result
|
|
286
|
+
assert result.output == "Structured task completed"
|
|
287
|
+
|
|
288
|
+
@pytest.mark.asyncio
|
|
289
|
+
async def test_run_without_context(self, kiln_task_tool):
|
|
290
|
+
"""Test the run method without context raises ValueError."""
|
|
291
|
+
with pytest.raises(
|
|
292
|
+
ValueError, match="Context is required for running a KilnTaskTool"
|
|
293
|
+
):
|
|
294
|
+
await kiln_task_tool.run(input="test input")
|
|
295
|
+
|
|
296
|
+
@pytest.mark.asyncio
|
|
297
|
+
async def test_run_plaintext_missing_input(
|
|
298
|
+
self, kiln_task_tool, mock_context, mock_task
|
|
299
|
+
):
|
|
300
|
+
"""Test the run method with plaintext task but missing input parameter."""
|
|
301
|
+
# Setup mocks
|
|
302
|
+
kiln_task_tool._task = mock_task
|
|
303
|
+
|
|
304
|
+
with pytest.raises(ValueError, match="Input not found in kwargs"):
|
|
305
|
+
await kiln_task_tool.run(context=mock_context, wrong_param="value")
|
|
306
|
+
|
|
307
|
+
@pytest.mark.asyncio
|
|
308
|
+
async def test_task_property_project_not_found(self, kiln_task_tool):
|
|
309
|
+
"""Test _task property when project is not found."""
|
|
310
|
+
with patch("kiln_ai.tools.kiln_task_tool.project_from_id", return_value=None):
|
|
311
|
+
with pytest.raises(ValueError, match="Project not found: test_project"):
|
|
312
|
+
_ = kiln_task_tool._task
|
|
313
|
+
|
|
314
|
+
@pytest.mark.asyncio
|
|
315
|
+
async def test_task_property_task_not_found(self, kiln_task_tool):
|
|
316
|
+
"""Test _task property when task is not found."""
|
|
317
|
+
mock_project = MagicMock()
|
|
318
|
+
mock_project.path = "/test/path"
|
|
319
|
+
|
|
320
|
+
with (
|
|
321
|
+
patch(
|
|
322
|
+
"kiln_ai.tools.kiln_task_tool.project_from_id",
|
|
323
|
+
return_value=mock_project,
|
|
324
|
+
),
|
|
325
|
+
patch(
|
|
326
|
+
"kiln_ai.tools.kiln_task_tool.Task.from_id_and_parent_path",
|
|
327
|
+
return_value=None,
|
|
328
|
+
),
|
|
329
|
+
):
|
|
330
|
+
with pytest.raises(
|
|
331
|
+
ValueError,
|
|
332
|
+
match="Task not found: test_task_123 in project test_project",
|
|
333
|
+
):
|
|
334
|
+
_ = kiln_task_tool._task
|
|
335
|
+
|
|
336
|
+
@pytest.mark.asyncio
|
|
337
|
+
async def test_task_property_success(self, kiln_task_tool, mock_task):
|
|
338
|
+
"""Test _task property when task is found successfully."""
|
|
339
|
+
mock_project = MagicMock()
|
|
340
|
+
mock_project.path = "/test/path"
|
|
341
|
+
|
|
342
|
+
with (
|
|
343
|
+
patch(
|
|
344
|
+
"kiln_ai.tools.kiln_task_tool.project_from_id",
|
|
345
|
+
return_value=mock_project,
|
|
346
|
+
),
|
|
347
|
+
patch(
|
|
348
|
+
"kiln_ai.tools.kiln_task_tool.Task.from_id_and_parent_path",
|
|
349
|
+
return_value=mock_task,
|
|
350
|
+
),
|
|
351
|
+
):
|
|
352
|
+
result = kiln_task_tool._task
|
|
353
|
+
assert result == mock_task
|
|
354
|
+
|
|
355
|
+
@pytest.mark.asyncio
|
|
356
|
+
async def test_run_config_property_not_found(self, kiln_task_tool, mock_task):
|
|
357
|
+
"""Test _run_config property when run config is not found."""
|
|
358
|
+
mock_task.run_configs.return_value = []
|
|
359
|
+
|
|
360
|
+
# Setup mocks
|
|
361
|
+
kiln_task_tool._task = mock_task
|
|
362
|
+
|
|
363
|
+
with pytest.raises(
|
|
364
|
+
ValueError,
|
|
365
|
+
match="Task run config not found: test_config_456 for task test_task_123 in project test_project",
|
|
366
|
+
):
|
|
367
|
+
_ = kiln_task_tool._run_config
|
|
368
|
+
|
|
369
|
+
@pytest.mark.asyncio
|
|
370
|
+
async def test_run_config_property_success(
|
|
371
|
+
self, kiln_task_tool, mock_task, mock_run_config
|
|
372
|
+
):
|
|
373
|
+
"""Test _run_config property when run config is found successfully."""
|
|
374
|
+
mock_task.run_configs.return_value = [mock_run_config]
|
|
375
|
+
|
|
376
|
+
# Setup mocks
|
|
377
|
+
kiln_task_tool._task = mock_task
|
|
378
|
+
|
|
379
|
+
result = kiln_task_tool._run_config
|
|
380
|
+
assert result == mock_run_config
|
|
381
|
+
|
|
382
|
+
@pytest.mark.asyncio
|
|
383
|
+
async def test_parameters_schema_with_json_schema(self, kiln_task_tool, mock_task):
|
|
384
|
+
"""Test parameters_schema property with JSON schema."""
|
|
385
|
+
expected_schema = {
|
|
386
|
+
"type": "object",
|
|
387
|
+
"properties": {"param": {"type": "string"}},
|
|
388
|
+
}
|
|
389
|
+
mock_task.input_json_schema = expected_schema
|
|
390
|
+
mock_task.input_schema.return_value = expected_schema
|
|
391
|
+
|
|
392
|
+
# Setup mocks
|
|
393
|
+
kiln_task_tool._task = mock_task
|
|
394
|
+
|
|
395
|
+
result = kiln_task_tool.parameters_schema
|
|
396
|
+
assert result == expected_schema
|
|
397
|
+
|
|
398
|
+
@pytest.mark.asyncio
|
|
399
|
+
async def test_parameters_schema_plaintext(self, kiln_task_tool, mock_task):
|
|
400
|
+
"""Test parameters_schema property for plaintext task."""
|
|
401
|
+
mock_task.input_json_schema = None
|
|
402
|
+
|
|
403
|
+
# Setup mocks
|
|
404
|
+
kiln_task_tool._task = mock_task
|
|
405
|
+
|
|
406
|
+
result = kiln_task_tool.parameters_schema
|
|
407
|
+
|
|
408
|
+
expected = {
|
|
409
|
+
"type": "object",
|
|
410
|
+
"properties": {
|
|
411
|
+
"input": {
|
|
412
|
+
"type": "string",
|
|
413
|
+
"description": "Plaintext input for the tool.",
|
|
414
|
+
}
|
|
415
|
+
},
|
|
416
|
+
"required": ["input"],
|
|
417
|
+
}
|
|
418
|
+
assert result == expected
|
|
419
|
+
|
|
420
|
+
@pytest.mark.asyncio
|
|
421
|
+
async def test_parameters_schema_none_raises_error(self, kiln_task_tool, mock_task):
|
|
422
|
+
"""Test parameters_schema property when schema is None raises ValueError."""
|
|
423
|
+
# Set up a task with JSON schema but input_schema returns None
|
|
424
|
+
mock_task.input_json_schema = {
|
|
425
|
+
"type": "object",
|
|
426
|
+
"properties": {"param": {"type": "string"}},
|
|
427
|
+
}
|
|
428
|
+
mock_task.input_schema.return_value = None
|
|
429
|
+
|
|
430
|
+
# Setup mocks - directly assign the task to bypass cached property
|
|
431
|
+
kiln_task_tool._task = mock_task
|
|
432
|
+
|
|
433
|
+
with pytest.raises(
|
|
434
|
+
ValueError,
|
|
435
|
+
match="Failed to create parameters schema for tool_id test_tool_id",
|
|
436
|
+
):
|
|
437
|
+
_ = kiln_task_tool.parameters_schema
|
|
438
|
+
|
|
439
|
+
@pytest.mark.asyncio
|
|
440
|
+
async def test_cached_properties(self, kiln_task_tool, mock_task, mock_run_config):
|
|
441
|
+
"""Test that cached properties work correctly."""
|
|
442
|
+
mock_project = MagicMock()
|
|
443
|
+
mock_project.path = "/test/path"
|
|
444
|
+
|
|
445
|
+
with (
|
|
446
|
+
patch(
|
|
447
|
+
"kiln_ai.tools.kiln_task_tool.project_from_id",
|
|
448
|
+
return_value=mock_project,
|
|
449
|
+
),
|
|
450
|
+
patch(
|
|
451
|
+
"kiln_ai.tools.kiln_task_tool.Task.from_id_and_parent_path",
|
|
452
|
+
return_value=mock_task,
|
|
453
|
+
),
|
|
454
|
+
):
|
|
455
|
+
# First access should call the methods
|
|
456
|
+
task1 = kiln_task_tool._task
|
|
457
|
+
task2 = kiln_task_tool._task
|
|
458
|
+
|
|
459
|
+
# Should be the same object (cached)
|
|
460
|
+
assert task1 is task2
|
|
461
|
+
|
|
462
|
+
# Verify the methods were called only once
|
|
463
|
+
assert mock_project is not None # project_from_id was called
|
|
464
|
+
# Task.from_id_and_parent_path should have been called once
|
|
465
|
+
with patch(
|
|
466
|
+
"kiln_ai.tools.kiln_task_tool.Task.from_id_and_parent_path"
|
|
467
|
+
) as mock_from_id:
|
|
468
|
+
mock_from_id.return_value = mock_task
|
|
469
|
+
_ = kiln_task_tool._task
|
|
470
|
+
# Should not be called again due to caching
|
|
471
|
+
mock_from_id.assert_not_called()
|
|
472
|
+
|
|
473
|
+
@pytest.mark.asyncio
|
|
474
|
+
async def test_run_with_adapter_exception(
|
|
475
|
+
self, kiln_task_tool, mock_context, mock_task, mock_run_config
|
|
476
|
+
):
|
|
477
|
+
"""Test the run method when adapter raises an exception."""
|
|
478
|
+
# Setup mocks
|
|
479
|
+
kiln_task_tool._task = mock_task
|
|
480
|
+
kiln_task_tool._run_config = mock_run_config
|
|
481
|
+
|
|
482
|
+
with patch(
|
|
483
|
+
"kiln_ai.adapters.adapter_registry.adapter_for_task"
|
|
484
|
+
) as mock_adapter_for_task:
|
|
485
|
+
# Mock adapter to raise an exception
|
|
486
|
+
mock_adapter = AsyncMock()
|
|
487
|
+
mock_adapter.invoke.side_effect = Exception("Adapter failed")
|
|
488
|
+
mock_adapter_for_task.return_value = mock_adapter
|
|
489
|
+
|
|
490
|
+
with pytest.raises(Exception, match="Adapter failed"):
|
|
491
|
+
await kiln_task_tool.run(context=mock_context, input="test input")
|
|
492
|
+
|
|
493
|
+
@pytest.mark.asyncio
|
|
494
|
+
async def test_run_with_different_allow_saving(
|
|
495
|
+
self, kiln_task_tool, mock_task, mock_run_config
|
|
496
|
+
):
|
|
497
|
+
"""Test the run method with different allow_saving values."""
|
|
498
|
+
mock_context_false = MagicMock(spec=ToolCallContext)
|
|
499
|
+
mock_context_false.allow_saving = False
|
|
500
|
+
|
|
501
|
+
# Setup mocks
|
|
502
|
+
kiln_task_tool._task = mock_task
|
|
503
|
+
kiln_task_tool._run_config = mock_run_config
|
|
504
|
+
|
|
505
|
+
with (
|
|
506
|
+
patch(
|
|
507
|
+
"kiln_ai.adapters.adapter_registry.adapter_for_task"
|
|
508
|
+
) as mock_adapter_for_task,
|
|
509
|
+
patch(
|
|
510
|
+
"kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig"
|
|
511
|
+
) as mock_adapter_config,
|
|
512
|
+
):
|
|
513
|
+
mock_adapter = AsyncMock()
|
|
514
|
+
mock_adapter_for_task.return_value = mock_adapter
|
|
515
|
+
|
|
516
|
+
mock_task_run = MagicMock()
|
|
517
|
+
mock_task_run.id = "run_789"
|
|
518
|
+
mock_task_run.output.output = "Task completed"
|
|
519
|
+
mock_adapter.invoke.return_value = mock_task_run
|
|
520
|
+
|
|
521
|
+
await kiln_task_tool.run(context=mock_context_false, input="test input")
|
|
522
|
+
|
|
523
|
+
# Verify adapter config was called with allow_saving=False
|
|
524
|
+
mock_adapter_config.assert_called_once_with(
|
|
525
|
+
allow_saving=False,
|
|
526
|
+
default_tags=["tool_call"],
|
|
527
|
+
)
|