kiln-ai 0.18.0__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +2 -2
- kiln_ai/adapters/adapter_registry.py +46 -0
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/data_gen/data_gen_task.py +2 -2
- kiln_ai/adapters/data_gen/test_data_gen_task.py +7 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +3 -1
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -1
- kiln_ai/adapters/eval/test_eval_runner.py +6 -12
- kiln_ai/adapters/eval/test_g_eval.py +3 -4
- kiln_ai/adapters/eval/test_g_eval_data.py +1 -1
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/base_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +32 -20
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +30 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/ml_model_list.py +1009 -111
- kiln_ai/adapters/model_adapters/base_adapter.py +62 -28
- kiln_ai/adapters/model_adapters/litellm_adapter.py +397 -80
- kiln_ai/adapters/model_adapters/test_base_adapter.py +194 -18
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +428 -4
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +120 -14
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/parsers/test_r1_parser.py +1 -1
- kiln_ai/adapters/provider_tools.py +35 -20
- kiln_ai/adapters/remote_config.py +57 -10
- kiln_ai/adapters/repair/repair_task.py +1 -1
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +109 -2
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_model_list.py +51 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_provider_tools.py +73 -12
- kiln_ai/adapters/test_remote_config.py +470 -16
- kiln_ai/datamodel/__init__.py +23 -21
- kiln_ai/datamodel/basemodel.py +54 -28
- kiln_ai/datamodel/datamodel_enums.py +3 -0
- kiln_ai/datamodel/dataset_split.py +5 -3
- kiln_ai/datamodel/eval.py +4 -4
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/finetune.py +2 -2
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +11 -4
- kiln_ai/datamodel/prompt.py +2 -2
- kiln_ai/datamodel/prompt_id.py +4 -4
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +8 -83
- kiln_ai/datamodel/task_output.py +7 -2
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_basemodel.py +213 -21
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_model_perf.py +1 -1
- kiln_ai/datamodel/test_prompt_id.py +5 -1
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +20 -47
- kiln_ai/datamodel/test_tool_id.py +239 -0
- kiln_ai/datamodel/tool_id.py +83 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +243 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_tool_registry.py +473 -0
- kiln_ai/tools/tool_registry.py +64 -0
- kiln_ai/utils/config.py +32 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_open_ai_types.py +131 -0
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/METADATA +37 -6
- kiln_ai-0.20.1.dist-info/RECORD +138 -0
- kiln_ai-0.18.0.dist-info/RECORD +0 -115
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,473 @@
|
|
|
1
|
+
from unittest.mock import Mock
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType
|
|
6
|
+
from kiln_ai.datamodel.project import Project
|
|
7
|
+
from kiln_ai.datamodel.task import Task
|
|
8
|
+
from kiln_ai.datamodel.tool_id import (
|
|
9
|
+
MCP_LOCAL_TOOL_ID_PREFIX,
|
|
10
|
+
MCP_REMOTE_TOOL_ID_PREFIX,
|
|
11
|
+
KilnBuiltInToolId,
|
|
12
|
+
_check_tool_id,
|
|
13
|
+
mcp_server_and_tool_name_from_id,
|
|
14
|
+
)
|
|
15
|
+
from kiln_ai.tools.built_in_tools.math_tools import (
|
|
16
|
+
AddTool,
|
|
17
|
+
DivideTool,
|
|
18
|
+
MultiplyTool,
|
|
19
|
+
SubtractTool,
|
|
20
|
+
)
|
|
21
|
+
from kiln_ai.tools.mcp_server_tool import MCPServerTool
|
|
22
|
+
from kiln_ai.tools.tool_registry import tool_from_id
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TestToolRegistry:
    """Test the tool registry functionality.

    Covers ``tool_from_id`` lookups (built-in and MCP tools), ``_check_tool_id``
    validation and ``mcp_server_and_tool_name_from_id`` parsing.
    """

    # Regexes for the "expected format" errors raised by
    # mcp_server_and_tool_name_from_id; shared by the parametrized cases below.
    _REMOTE_FORMAT_ERROR = (
        r"Invalid remote MCP tool ID:.*Expected format.*"
        r"mcp::remote::<server_id>::<tool_name>"
    )
    _LOCAL_FORMAT_ERROR = (
        r"Invalid local MCP tool ID:.*Expected format.*"
        r"mcp::local::<server_id>::<tool_name>"
    )
    _GENERIC_FORMAT_ERROR = (
        r"Invalid MCP tool ID:.*Expected format.*"
        r"mcp::\(remote\|local\)::<server_id>::<tool_name>"
    )

    @staticmethod
    def _task_in_project_with(servers):
        """Return a mocked Task whose parent project exposes ``servers``.

        The mocked project always has the ID ``test_project_id`` so tests can
        assert on error messages that include the project ID.
        """
        project = Mock(spec=Project)
        project.id = "test_project_id"
        project.external_tool_servers.return_value = servers

        task = Mock(spec=Task)
        task.parent_project.return_value = project
        return task

    async def test_tool_from_id_add_numbers(self):
        """Test that ADD_NUMBERS tool ID returns AddTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.ADD_NUMBERS)

        assert isinstance(tool, AddTool)
        assert await tool.id() == KilnBuiltInToolId.ADD_NUMBERS
        assert await tool.name() == "add"
        assert "Add two numbers" in await tool.description()

    async def test_tool_from_id_subtract_numbers(self):
        """Test that SUBTRACT_NUMBERS tool ID returns SubtractTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.SUBTRACT_NUMBERS)

        assert isinstance(tool, SubtractTool)
        assert await tool.id() == KilnBuiltInToolId.SUBTRACT_NUMBERS
        assert await tool.name() == "subtract"

    async def test_tool_from_id_multiply_numbers(self):
        """Test that MULTIPLY_NUMBERS tool ID returns MultiplyTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.MULTIPLY_NUMBERS)

        assert isinstance(tool, MultiplyTool)
        assert await tool.id() == KilnBuiltInToolId.MULTIPLY_NUMBERS
        assert await tool.name() == "multiply"

    async def test_tool_from_id_divide_numbers(self):
        """Test that DIVIDE_NUMBERS tool ID returns DivideTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.DIVIDE_NUMBERS)

        assert isinstance(tool, DivideTool)
        assert await tool.id() == KilnBuiltInToolId.DIVIDE_NUMBERS
        assert await tool.name() == "divide"

    async def test_tool_from_id_with_string_values(self):
        """Test that tool_from_id works with string values of enum members."""
        tool = tool_from_id("kiln_tool::add_numbers")

        assert isinstance(tool, AddTool)
        assert await tool.id() == KilnBuiltInToolId.ADD_NUMBERS

    def test_tool_from_id_invalid_tool_id(self):
        """Test that invalid tool ID raises ValueError."""
        with pytest.raises(
            ValueError, match="Tool ID invalid_tool_id not found in tool registry"
        ):
            tool_from_id("invalid_tool_id")

    def test_tool_from_id_empty_string(self):
        """Test that empty string tool ID raises ValueError.

        The error message interpolates the (empty) ID between two spaces, so
        the pattern accepts any run of whitespace after "Tool ID".
        """
        with pytest.raises(
            ValueError, match=r"Tool ID\s+not found in tool registry"
        ):
            tool_from_id("")

    def test_tool_from_id_mcp_remote_tool_success(self):
        """Test that tool_from_id works with MCP remote tool IDs."""
        mock_server = ExternalToolServer(
            name="test_server",
            type=ToolServerType.remote_mcp,
            properties={
                "server_url": "https://example.com",
                "headers": {},
            },
        )
        mock_task = self._task_in_project_with([mock_server])

        tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}{mock_server.id}::echo"
        tool = tool_from_id(tool_id, task=mock_task)

        assert isinstance(tool, MCPServerTool)
        assert tool._tool_server_model == mock_server
        assert tool._name == "echo"

    def test_tool_from_id_mcp_local_tool_success(self):
        """Test that tool_from_id works with MCP local tool IDs."""
        mock_server = ExternalToolServer(
            name="local_server",
            type=ToolServerType.local_mcp,
            properties={
                "command": "python",
                "args": ["server.py", "--port", "8080"],
                "env_vars": {},
            },
        )
        mock_task = self._task_in_project_with([mock_server])

        tool_id = f"{MCP_LOCAL_TOOL_ID_PREFIX}{mock_server.id}::calculate"
        tool = tool_from_id(tool_id, task=mock_task)

        assert isinstance(tool, MCPServerTool)
        assert tool._tool_server_model == mock_server
        assert tool._name == "calculate"

    def test_tool_from_id_mcp_tool_project_not_found(self):
        """Test that tool_from_id raises ValueError when task is not provided."""
        tool_id = f"{MCP_LOCAL_TOOL_ID_PREFIX}test_server::test_tool"
        with pytest.raises(
            ValueError,
            match="Unable to resolve tool from id.*Requires a parent project/task",
        ):
            tool_from_id(tool_id, task=None)

    @pytest.mark.parametrize(
        "prefix", [MCP_REMOTE_TOOL_ID_PREFIX, MCP_LOCAL_TOOL_ID_PREFIX]
    )
    def test_tool_from_id_mcp_tool_server_not_found(self, prefix):
        """Test that tool_from_id raises ValueError when tool server is not found."""
        # The project only knows a server whose ID differs from the one requested.
        mock_server = ExternalToolServer(
            name="different_server",
            type=ToolServerType.remote_mcp,
            properties={
                "server_url": "https://example.com",
                "headers": {},
            },
        )
        mock_task = self._task_in_project_with([mock_server])

        with pytest.raises(
            ValueError,
            match="External tool server not found: missing_server in project ID test_project_id",
        ):
            tool_from_id(f"{prefix}missing_server::test_tool", task=mock_task)

    @pytest.mark.parametrize("tool_id", list(KilnBuiltInToolId))
    def test_all_built_in_tools_are_registered(self, tool_id):
        """Test that all KilnBuiltInToolId enum members are handled by the registry."""
        # This should not raise an exception
        tool = tool_from_id(tool_id.value)
        assert tool is not None

    async def test_registry_returns_new_instances(self):
        """Test that registry returns new instances each time (not singletons)."""
        tool1 = tool_from_id(KilnBuiltInToolId.ADD_NUMBERS)
        tool2 = tool_from_id(KilnBuiltInToolId.ADD_NUMBERS)

        assert tool1 is not tool2  # Different instances
        assert type(tool1) is type(tool2)  # Same type
        assert await tool1.id() == await tool2.id()  # Same id

    @pytest.mark.parametrize("tool_id", list(KilnBuiltInToolId))
    def test_check_tool_id_valid_built_in_tools(self, tool_id):
        """Test that _check_tool_id accepts valid built-in tool IDs."""
        assert _check_tool_id(tool_id.value) == tool_id.value

    def test_check_tool_id_invalid_tool_id(self):
        """Test that _check_tool_id raises ValueError for invalid tool ID."""
        with pytest.raises(ValueError, match="Invalid tool ID: invalid_tool_id"):
            _check_tool_id("invalid_tool_id")

    def test_check_tool_id_empty_string(self):
        """Test that _check_tool_id raises ValueError for empty string."""
        with pytest.raises(ValueError, match="Invalid tool ID: "):
            _check_tool_id("")

    def test_check_tool_id_none_value(self):
        """Test that _check_tool_id raises ValueError for None."""
        with pytest.raises(ValueError, match="Invalid tool ID: None"):
            _check_tool_id(None)  # type: ignore

    @pytest.mark.parametrize(
        "tool_id",
        [
            f"{MCP_REMOTE_TOOL_ID_PREFIX}server123::tool_name",
            f"{MCP_REMOTE_TOOL_ID_PREFIX}my_server::echo",
            f"{MCP_REMOTE_TOOL_ID_PREFIX}123456789::test_tool",
            f"{MCP_REMOTE_TOOL_ID_PREFIX}server_with_underscores::complex_tool_name",
        ],
    )
    def test_check_tool_id_valid_mcp_remote_tool_id(self, tool_id):
        """Test that _check_tool_id accepts valid MCP remote tool IDs."""
        assert _check_tool_id(tool_id) == tool_id

    @pytest.mark.parametrize(
        "tool_id",
        [
            f"{MCP_LOCAL_TOOL_ID_PREFIX}server123::tool_name",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}my_local_server::calculate",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}local_tool_server::process_data",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}server_with_underscores::complex_tool_name",
        ],
    )
    def test_check_tool_id_valid_mcp_local_tool_id(self, tool_id):
        """Test that _check_tool_id accepts valid MCP local tool IDs."""
        assert _check_tool_id(tool_id) == tool_id

    @pytest.mark.parametrize(
        "invalid_id,error_prefix",
        [
            # These start with the prefix but have wrong format - get specific MCP error
            ("mcp::remote::server", "Invalid remote MCP tool ID: "),  # Missing tool name
            ("mcp::remote::", "Invalid remote MCP tool ID: "),  # Missing server and tool
            ("mcp::remote::::tool", "Invalid remote MCP tool ID: "),  # Empty server name
            (
                "mcp::remote::server::tool::extra",  # Too many parts
                "Invalid remote MCP tool ID: ",
            ),
            # These don't match the prefix - get generic error
            ("mcp::remote:", "Invalid tool ID: "),  # Missing last colon
            ("mcp:remote::server::tool", "Invalid tool ID: "),  # Wrong prefix format
            ("mcp::remote_server::tool", "Invalid tool ID: "),  # Wrong prefix format
            ("remote::server::tool", "Invalid tool ID: "),  # Missing mcp prefix
        ],
    )
    def test_check_tool_id_invalid_mcp_remote_tool_id(self, invalid_id, error_prefix):
        """Test that _check_tool_id rejects invalid MCP-like tool IDs."""
        import re

        # Escape so the ID is matched literally, not interpreted as a regex.
        with pytest.raises(ValueError, match=re.escape(f"{error_prefix}{invalid_id}")):
            _check_tool_id(invalid_id)

    @pytest.mark.parametrize(
        "tool_id,expected",
        [
            # Remote MCP tool IDs
            ("mcp::remote::server123::tool_name", ("server123", "tool_name")),
            ("mcp::remote::my_server::echo", ("my_server", "echo")),
            ("mcp::remote::123456789::test_tool", ("123456789", "test_tool")),
            (
                "mcp::remote::server_with_underscores::complex_tool_name",
                ("server_with_underscores", "complex_tool_name"),
            ),
            ("mcp::remote::a::b", ("a", "b")),  # Minimal valid case
            (
                "mcp::remote::server-with-dashes::tool-with-dashes",
                ("server-with-dashes", "tool-with-dashes"),
            ),
            # Local MCP tool IDs
            ("mcp::local::local_server::calculate", ("local_server", "calculate")),
            ("mcp::local::my_local_tool::process", ("my_local_tool", "process")),
            (
                "mcp::local::123456789::local_test_tool",
                ("123456789", "local_test_tool"),
            ),
            (
                "mcp::local::local_server_with_underscores::complex_local_tool",
                ("local_server_with_underscores", "complex_local_tool"),
            ),
            ("mcp::local::x::y", ("x", "y")),  # Minimal valid case for local
        ],
    )
    def test_mcp_server_and_tool_name_from_id_valid_inputs(self, tool_id, expected):
        """Test that mcp_server_and_tool_name_from_id correctly parses valid MCP tool IDs."""
        result = mcp_server_and_tool_name_from_id(tool_id)
        assert result == expected, (
            f"Failed for {tool_id}: expected {expected}, got {result}"
        )

    @pytest.mark.parametrize(
        "invalid_id,expected_error",
        [
            # Remote MCP format errors
            ("mcp::remote::server", _REMOTE_FORMAT_ERROR),  # Only 3 parts instead of 4
            ("mcp::remote::", _REMOTE_FORMAT_ERROR),  # Missing server and tool
            ("mcp::remote::server::tool::extra", _REMOTE_FORMAT_ERROR),  # 5 parts
            # Local MCP format errors
            ("mcp::local::server", _LOCAL_FORMAT_ERROR),  # Only 3 parts instead of 4
            ("mcp::local::", _LOCAL_FORMAT_ERROR),  # Missing server and tool
            ("mcp::local::server::tool::extra", _LOCAL_FORMAT_ERROR),  # 5 parts
            # Generic MCP format errors (no valid prefix)
            ("invalid::format::here", _GENERIC_FORMAT_ERROR),  # 3 parts, wrong prefix
            ("", _GENERIC_FORMAT_ERROR),  # Empty string
            ("single_part", _GENERIC_FORMAT_ERROR),  # No separators
            ("two::parts", _GENERIC_FORMAT_ERROR),  # Only 2 parts
        ],
    )
    def test_mcp_server_and_tool_name_from_id_invalid_inputs(
        self, invalid_id, expected_error
    ):
        """Test that mcp_server_and_tool_name_from_id raises ValueError for invalid MCP tool IDs."""
        with pytest.raises(ValueError, match=expected_error):
            mcp_server_and_tool_name_from_id(invalid_id)

    @pytest.mark.parametrize(
        "tool_id,expected",
        [
            ("mcp::remote::::tool", ("", "tool")),  # Empty server name
            ("mcp::remote::server::", ("server", "")),  # Empty tool name
            ("mcp::remote::::", ("", "")),  # Both empty
        ],
    )
    def test_mcp_server_and_tool_name_from_id_edge_cases(self, tool_id, expected):
        """Test that mcp_server_and_tool_name_from_id handles edge cases (empty parts allowed by parser)."""
        # These are valid according to the parser (exactly 4 parts),
        # but empty server_id/tool_name validation is handled by _check_tool_id
        result = mcp_server_and_tool_name_from_id(tool_id)
        assert result == expected, (
            f"Failed for {tool_id}: expected {expected}, got {result}"
        )

    @pytest.mark.parametrize(
        "tool_id,expected_server,expected_tool",
        [
            ("mcp::remote::test_server::test_tool", "test_server", "test_tool"),
            ("mcp::remote::s::t", "s", "t"),
            (
                "mcp::remote::long_server_name_123::complex_tool_name_456",
                "long_server_name_123",
                "complex_tool_name_456",
            ),
            (
                "mcp::local::local_test_server::local_test_tool",
                "local_test_server",
                "local_test_tool",
            ),
            ("mcp::local::l::l", "l", "l"),
            (
                "mcp::local::long_local_server_123::complex_local_tool_456",
                "long_local_server_123",
                "complex_local_tool_456",
            ),
        ],
    )
    def test_mcp_server_and_tool_name_from_id_parametrized(
        self, tool_id, expected_server, expected_tool
    ):
        """Parametrized test for mcp_server_and_tool_name_from_id with various valid inputs."""
        server_id, tool_name = mcp_server_and_tool_name_from_id(tool_id)
        assert server_id == expected_server
        assert tool_name == expected_tool

    def test_tool_from_id_mcp_missing_task_raises_error(self):
        """Test that MCP tool ID with missing task raises ValueError."""
        mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}test_server::test_tool"

        with pytest.raises(
            ValueError,
            match="Unable to resolve tool from id.*Requires a parent project/task",
        ):
            tool_from_id(mcp_tool_id, task=None)

    def test_tool_from_id_mcp_functional_case(self):
        """Test that MCP tool ID with valid task and project returns MCPServerTool."""
        mock_server = ExternalToolServer(
            name="test_server",
            type=ToolServerType.remote_mcp,
            description="Test MCP server",
            properties={
                "server_url": "https://example.com",
                "headers": {},
            },
        )
        mock_task = self._task_in_project_with([mock_server])

        mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}{mock_server.id}::test_tool"

        tool = tool_from_id(mcp_tool_id, task=mock_task)

        assert isinstance(tool, MCPServerTool)
        # Verify the tool was created with the correct server and tool name
        assert tool._tool_server_model == mock_server
        assert tool._name == "test_tool"

    def test_tool_from_id_mcp_no_server_found_raises_error(self):
        """Test that MCP tool ID with server not found raises ValueError."""
        # The project only knows a server whose ID differs from the one requested.
        mock_server = ExternalToolServer(
            name="different_server",
            type=ToolServerType.remote_mcp,
            description="Different MCP server",
            properties={
                "server_url": "https://example.com",
                "headers": {},
            },
        )
        mock_task = self._task_in_project_with([mock_server])

        # Use a tool ID with a server that doesn't exist in the project
        mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}nonexistent_server::test_tool"

        with pytest.raises(
            ValueError,
            match="External tool server not found: nonexistent_server in project ID test_project_id",
        ):
            tool_from_id(mcp_tool_id, task=mock_task)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from kiln_ai.datamodel.task import Task
|
|
2
|
+
from kiln_ai.datamodel.tool_id import (
|
|
3
|
+
MCP_LOCAL_TOOL_ID_PREFIX,
|
|
4
|
+
MCP_REMOTE_TOOL_ID_PREFIX,
|
|
5
|
+
KilnBuiltInToolId,
|
|
6
|
+
mcp_server_and_tool_name_from_id,
|
|
7
|
+
)
|
|
8
|
+
from kiln_ai.tools.base_tool import KilnToolInterface
|
|
9
|
+
from kiln_ai.tools.built_in_tools.math_tools import (
|
|
10
|
+
AddTool,
|
|
11
|
+
DivideTool,
|
|
12
|
+
MultiplyTool,
|
|
13
|
+
SubtractTool,
|
|
14
|
+
)
|
|
15
|
+
from kiln_ai.tools.mcp_server_tool import MCPServerTool
|
|
16
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def tool_from_id(tool_id: str, task: Task | None = None) -> KilnToolInterface:
    """
    Resolve a tool ID into a tool instance.

    Built-in IDs (the values of ``KilnBuiltInToolId``) produce a new instance
    of the matching math tool on every call. IDs that start with the remote or
    local MCP prefix are resolved against the external tool servers of the
    task's parent project.

    Args:
        tool_id: The ID of the tool to look up.
        task: Task whose parent project owns the MCP tool servers. Only
            consulted for MCP tool IDs.

    Returns:
        The tool corresponding to ``tool_id``.

    Raises:
        ValueError: If an MCP tool ID is given without a task/parent project,
            if the referenced tool server is not in the project, or if the ID
            is neither a built-in nor an MCP tool ID.
    """
    # Built-in tools
    builtin_values = tuple(member.value for member in KilnBuiltInToolId)
    if tool_id in builtin_values:
        builtin_id = KilnBuiltInToolId(tool_id)
        if builtin_id is KilnBuiltInToolId.ADD_NUMBERS:
            return AddTool()
        elif builtin_id is KilnBuiltInToolId.SUBTRACT_NUMBERS:
            return SubtractTool()
        elif builtin_id is KilnBuiltInToolId.MULTIPLY_NUMBERS:
            return MultiplyTool()
        elif builtin_id is KilnBuiltInToolId.DIVIDE_NUMBERS:
            return DivideTool()
        else:
            # Reached only for an enum member that has no branch above.
            raise_exhaustive_enum_error(builtin_id)

    # Anything that is not an MCP server tool ID is unknown to the registry.
    is_mcp_tool = tool_id.startswith(MCP_REMOTE_TOOL_ID_PREFIX) or tool_id.startswith(
        MCP_LOCAL_TOOL_ID_PREFIX
    )
    if not is_mcp_tool:
        raise ValueError(f"Tool ID {tool_id} not found in tool registry")

    # MCP tools live on servers registered in the task's parent project.
    project = None
    if task is not None:
        project = task.parent_project()
    if project is None:
        raise ValueError(
            f"Unable to resolve tool from id: {tool_id}. Requires a parent project/task."
        )

    # Split the ID into the server it targets and the tool name on that server.
    tool_server_id, tool_name = mcp_server_and_tool_name_from_id(tool_id)

    # The first server with a matching ID wins.
    for candidate in project.external_tool_servers():
        if candidate.id == tool_server_id:
            return MCPServerTool(candidate, tool_name)

    raise ValueError(
        f"External tool server not found: {tool_server_id} in project ID {project.id}"
    )
|
kiln_ai/utils/config.py
CHANGED
|
@@ -6,6 +6,9 @@ from typing import Any, Callable, Dict, List, Optional
|
|
|
6
6
|
|
|
7
7
|
import yaml
|
|
8
8
|
|
|
9
|
+
# Configuration keys
|
|
10
|
+
MCP_SECRETS_KEY = "mcp_secrets"
|
|
11
|
+
|
|
9
12
|
|
|
10
13
|
class ConfigProperty:
|
|
11
14
|
def __init__(
|
|
@@ -54,6 +57,10 @@ class Config:
|
|
|
54
57
|
str,
|
|
55
58
|
env_var="OLLAMA_BASE_URL",
|
|
56
59
|
),
|
|
60
|
+
"docker_model_runner_base_url": ConfigProperty(
|
|
61
|
+
str,
|
|
62
|
+
env_var="DOCKER_MODEL_RUNNER_BASE_URL",
|
|
63
|
+
),
|
|
57
64
|
"bedrock_access_key": ConfigProperty(
|
|
58
65
|
str,
|
|
59
66
|
env_var="AWS_ACCESS_KEY_ID",
|
|
@@ -124,6 +131,11 @@ class Config:
|
|
|
124
131
|
env_var="WANDB_API_KEY",
|
|
125
132
|
sensitive=True,
|
|
126
133
|
),
|
|
134
|
+
"siliconflow_cn_api_key": ConfigProperty(
|
|
135
|
+
str,
|
|
136
|
+
env_var="SILICONFLOW_CN_API_KEY",
|
|
137
|
+
sensitive=True,
|
|
138
|
+
),
|
|
127
139
|
"wandb_base_url": ConfigProperty(
|
|
128
140
|
str,
|
|
129
141
|
env_var="WANDB_BASE_URL",
|
|
@@ -137,6 +149,26 @@ class Config:
|
|
|
137
149
|
default_lambda=lambda: [],
|
|
138
150
|
sensitive_keys=["api_key"],
|
|
139
151
|
),
|
|
152
|
+
"cerebras_api_key": ConfigProperty(
|
|
153
|
+
str,
|
|
154
|
+
env_var="CEREBRAS_API_KEY",
|
|
155
|
+
sensitive=True,
|
|
156
|
+
),
|
|
157
|
+
"enable_demo_tools": ConfigProperty(
|
|
158
|
+
bool,
|
|
159
|
+
env_var="ENABLE_DEMO_TOOLS",
|
|
160
|
+
default=False,
|
|
161
|
+
),
|
|
162
|
+
# Allow the user to set the path to lookup MCP server commands, like npx.
|
|
163
|
+
"custom_mcp_path": ConfigProperty(
|
|
164
|
+
str,
|
|
165
|
+
env_var="CUSTOM_MCP_PATH",
|
|
166
|
+
),
|
|
167
|
+
# Allow the user to set secrets for MCP servers, the key is mcp_server_id::key_name
|
|
168
|
+
MCP_SECRETS_KEY: ConfigProperty(
|
|
169
|
+
dict[str, str],
|
|
170
|
+
sensitive=True,
|
|
171
|
+
),
|
|
140
172
|
}
|
|
141
173
|
self._lock = threading.Lock()
|
|
142
174
|
self._settings = self.load_settings()
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Wrapper for OpenAI types to make them compatible with Pydantic.
|
|
3
|
+
|
|
4
|
+
Pydantic doesn't support Iterable[T] well, so we use List[T] instead for tool_calls,
|
|
5
|
+
https://github.com/pydantic/pydantic/issues/9541
|
|
6
|
+
|
|
7
|
+
Otherwise we are using OpenAI SDK types directly.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import (
|
|
11
|
+
Iterable,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
TypeAlias,
|
|
16
|
+
Union,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from openai.types.chat import (
|
|
20
|
+
ChatCompletionDeveloperMessageParam,
|
|
21
|
+
ChatCompletionFunctionMessageParam,
|
|
22
|
+
ChatCompletionMessageToolCallParam,
|
|
23
|
+
ChatCompletionSystemMessageParam,
|
|
24
|
+
ChatCompletionToolMessageParam,
|
|
25
|
+
ChatCompletionUserMessageParam,
|
|
26
|
+
)
|
|
27
|
+
from openai.types.chat.chat_completion_assistant_message_param import (
|
|
28
|
+
Audio,
|
|
29
|
+
ContentArrayOfContentPart,
|
|
30
|
+
FunctionCall,
|
|
31
|
+
)
|
|
32
|
+
from typing_extensions import Required, TypedDict
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ChatCompletionAssistantMessageParamWrapper(TypedDict, total=False):
    """
    Almost exact copy of ChatCompletionAssistantMessageParam, but two changes.

    First change: List[T] instead of Iterable[T] for tool_calls. Addresses pydantic issue.
    https://github.com/pydantic/pydantic/issues/9541

    Second change: Add reasoning_content to the message. A LiteLLM property for reasoning data.

    The TypedDict is declared with ``total=False``, so every key is optional
    except ``role``, which is marked ``Required``.
    """

    # The only required key; always the literal "assistant".
    role: Required[Literal["assistant"]]
    """The role of the messages author, in this case `assistant`."""

    audio: Optional[Audio]
    """Data about a previous audio response from the model.

    [Learn more](https://platform.openai.com/docs/guides/audio).
    """

    # NOTE(review): content still uses Iterable[...]; only tool_calls was
    # switched to List[...] - confirm this is intentional.
    content: Union[str, Iterable[ContentArrayOfContentPart], None]
    """The contents of the assistant message.

    Required unless `tool_calls` or `function_call` is specified.
    """

    # Addition relative to the OpenAI SDK type (second change above).
    reasoning_content: Optional[str]
    """The reasoning content of the assistant message.

    A LiteLLM property for reasoning data: https://docs.litellm.ai/docs/reasoning_content
    """

    function_call: Optional[FunctionCall]
    """Deprecated and replaced by `tool_calls`.

    The name and arguments of a function that should be called, as generated by the
    model.
    """

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the same
    role.
    """

    refusal: Optional[str]
    """The refusal message by the assistant."""

    # List[...] rather than Iterable[...] (first change above).
    tool_calls: List[ChatCompletionMessageToolCallParam]
    """The tool calls generated by the model, such as function calls."""
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Union of all chat message param types. Assistant messages use the local
# ChatCompletionAssistantMessageParamWrapper; every other role uses the type
# imported from the OpenAI SDK unchanged.
ChatCompletionMessageParam: TypeAlias = Union[
    ChatCompletionDeveloperMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
    ChatCompletionAssistantMessageParamWrapper,
    ChatCompletionToolMessageParam,
    ChatCompletionFunctionMessageParam,
]
|