kiln-ai 0.19.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (70)
  1. kiln_ai/adapters/__init__.py +2 -2
  2. kiln_ai/adapters/adapter_registry.py +19 -1
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  6. kiln_ai/adapters/eval/base_eval.py +2 -2
  7. kiln_ai/adapters/eval/eval_runner.py +3 -1
  8. kiln_ai/adapters/eval/g_eval.py +2 -2
  9. kiln_ai/adapters/eval/test_base_eval.py +1 -1
  10. kiln_ai/adapters/eval/test_g_eval.py +3 -4
  11. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  12. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  14. kiln_ai/adapters/ml_model_list.py +380 -34
  15. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  16. kiln_ai/adapters/model_adapters/litellm_adapter.py +383 -79
  17. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  18. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +406 -1
  19. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  20. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  21. kiln_ai/adapters/model_adapters/test_structured_output.py +110 -4
  22. kiln_ai/adapters/parsers/__init__.py +1 -1
  23. kiln_ai/adapters/provider_tools.py +15 -1
  24. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  25. kiln_ai/adapters/run_output.py +3 -0
  26. kiln_ai/adapters/test_adapter_registry.py +80 -1
  27. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  28. kiln_ai/adapters/test_ml_model_list.py +39 -1
  29. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  30. kiln_ai/adapters/test_provider_tools.py +55 -0
  31. kiln_ai/adapters/test_remote_config.py +98 -0
  32. kiln_ai/datamodel/__init__.py +23 -21
  33. kiln_ai/datamodel/datamodel_enums.py +1 -0
  34. kiln_ai/datamodel/eval.py +1 -1
  35. kiln_ai/datamodel/external_tool_server.py +298 -0
  36. kiln_ai/datamodel/json_schema.py +25 -10
  37. kiln_ai/datamodel/project.py +8 -1
  38. kiln_ai/datamodel/registry.py +0 -15
  39. kiln_ai/datamodel/run_config.py +62 -0
  40. kiln_ai/datamodel/task.py +2 -77
  41. kiln_ai/datamodel/task_output.py +6 -1
  42. kiln_ai/datamodel/task_run.py +41 -0
  43. kiln_ai/datamodel/test_basemodel.py +3 -3
  44. kiln_ai/datamodel/test_example_models.py +175 -0
  45. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  46. kiln_ai/datamodel/test_registry.py +8 -3
  47. kiln_ai/datamodel/test_task.py +15 -47
  48. kiln_ai/datamodel/test_tool_id.py +239 -0
  49. kiln_ai/datamodel/tool_id.py +83 -0
  50. kiln_ai/tools/__init__.py +8 -0
  51. kiln_ai/tools/base_tool.py +82 -0
  52. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  53. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  54. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  55. kiln_ai/tools/mcp_server_tool.py +95 -0
  56. kiln_ai/tools/mcp_session_manager.py +243 -0
  57. kiln_ai/tools/test_base_tools.py +199 -0
  58. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  59. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  60. kiln_ai/tools/test_tool_registry.py +473 -0
  61. kiln_ai/tools/tool_registry.py +64 -0
  62. kiln_ai/utils/config.py +22 -0
  63. kiln_ai/utils/open_ai_types.py +94 -0
  64. kiln_ai/utils/project_utils.py +17 -0
  65. kiln_ai/utils/test_config.py +138 -1
  66. kiln_ai/utils/test_open_ai_types.py +131 -0
  67. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/METADATA +6 -5
  68. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/RECORD +70 -47
  69. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/WHEEL +0 -0
  70. {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,473 @@
1
+ from unittest.mock import Mock
2
+
3
+ import pytest
4
+
5
+ from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType
6
+ from kiln_ai.datamodel.project import Project
7
+ from kiln_ai.datamodel.task import Task
8
+ from kiln_ai.datamodel.tool_id import (
9
+ MCP_LOCAL_TOOL_ID_PREFIX,
10
+ MCP_REMOTE_TOOL_ID_PREFIX,
11
+ KilnBuiltInToolId,
12
+ _check_tool_id,
13
+ mcp_server_and_tool_name_from_id,
14
+ )
15
+ from kiln_ai.tools.built_in_tools.math_tools import (
16
+ AddTool,
17
+ DivideTool,
18
+ MultiplyTool,
19
+ SubtractTool,
20
+ )
21
+ from kiln_ai.tools.mcp_server_tool import MCPServerTool
22
+ from kiln_ai.tools.tool_registry import tool_from_id
23
+
24
+
25
class TestToolRegistry:
    """Test the tool registry functionality."""

    @staticmethod
    def _mock_task_with_servers(*servers: ExternalToolServer) -> Mock:
        """Return a mock Task whose parent project (ID "test_project_id") holds ``servers``."""
        mock_project = Mock(spec=Project)
        mock_project.id = "test_project_id"
        mock_project.external_tool_servers.return_value = list(servers)

        mock_task = Mock(spec=Task)
        mock_task.parent_project.return_value = mock_project
        return mock_task

    @staticmethod
    def _remote_server(name: str, **extra) -> ExternalToolServer:
        """Return a remote MCP server model pointing at https://example.com.

        Extra keyword arguments (e.g. ``description``) are forwarded to ExternalToolServer.
        """
        return ExternalToolServer(
            name=name,
            type=ToolServerType.remote_mcp,
            properties={
                "server_url": "https://example.com",
                "headers": {},
            },
            **extra,
        )

    async def test_tool_from_id_add_numbers(self):
        """Test that ADD_NUMBERS tool ID returns AddTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.ADD_NUMBERS)

        assert isinstance(tool, AddTool)
        assert await tool.id() == KilnBuiltInToolId.ADD_NUMBERS
        assert await tool.name() == "add"
        assert "Add two numbers" in await tool.description()

    async def test_tool_from_id_subtract_numbers(self):
        """Test that SUBTRACT_NUMBERS tool ID returns SubtractTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.SUBTRACT_NUMBERS)

        assert isinstance(tool, SubtractTool)
        assert await tool.id() == KilnBuiltInToolId.SUBTRACT_NUMBERS
        assert await tool.name() == "subtract"

    async def test_tool_from_id_multiply_numbers(self):
        """Test that MULTIPLY_NUMBERS tool ID returns MultiplyTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.MULTIPLY_NUMBERS)

        assert isinstance(tool, MultiplyTool)
        assert await tool.id() == KilnBuiltInToolId.MULTIPLY_NUMBERS
        assert await tool.name() == "multiply"

    async def test_tool_from_id_divide_numbers(self):
        """Test that DIVIDE_NUMBERS tool ID returns DivideTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.DIVIDE_NUMBERS)

        assert isinstance(tool, DivideTool)
        assert await tool.id() == KilnBuiltInToolId.DIVIDE_NUMBERS
        assert await tool.name() == "divide"

    async def test_tool_from_id_with_string_values(self):
        """Test that tool_from_id works with string values of enum members."""
        tool = tool_from_id("kiln_tool::add_numbers")

        assert isinstance(tool, AddTool)
        assert await tool.id() == KilnBuiltInToolId.ADD_NUMBERS

    async def test_tool_from_id_invalid_tool_id(self):
        """Test that invalid tool ID raises ValueError."""
        with pytest.raises(
            ValueError, match="Tool ID invalid_tool_id not found in tool registry"
        ):
            tool_from_id("invalid_tool_id")

    def test_tool_from_id_empty_string(self):
        """Test that empty string tool ID raises ValueError."""
        # The registry interpolates the (empty) ID into its message, producing
        # "Tool ID  not found ..." with two spaces, so match any whitespace run.
        with pytest.raises(ValueError, match=r"Tool ID\s+not found in tool registry"):
            tool_from_id("")

    def test_tool_from_id_mcp_remote_tool_success(self):
        """Test that tool_from_id works with MCP remote tool IDs."""
        mock_server = self._remote_server("test_server")
        mock_task = self._mock_task_with_servers(mock_server)

        tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}{mock_server.id}::echo"
        tool = tool_from_id(tool_id, task=mock_task)

        # Verify the tool is MCPServerTool bound to the right server and tool name
        assert isinstance(tool, MCPServerTool)
        assert tool._tool_server_model == mock_server
        assert tool._name == "echo"

    def test_tool_from_id_mcp_local_tool_success(self):
        """Test that tool_from_id works with MCP local tool IDs."""
        mock_server = ExternalToolServer(
            name="local_server",
            type=ToolServerType.local_mcp,
            properties={
                "command": "python",
                "args": ["server.py", "--port", "8080"],
                "env_vars": {},
            },
        )
        mock_task = self._mock_task_with_servers(mock_server)

        tool_id = f"{MCP_LOCAL_TOOL_ID_PREFIX}{mock_server.id}::calculate"
        tool = tool_from_id(tool_id, task=mock_task)

        # Verify the tool is MCPServerTool bound to the right server and tool name
        assert isinstance(tool, MCPServerTool)
        assert tool._tool_server_model == mock_server
        assert tool._name == "calculate"

    def test_tool_from_id_mcp_tool_project_not_found(self):
        """Test that tool_from_id raises ValueError when no parent project can be resolved."""
        tool_id = f"{MCP_LOCAL_TOOL_ID_PREFIX}test_server::test_tool"
        expected = "Unable to resolve tool from id.*Requires a parent project/task"

        # No task provided at all.
        with pytest.raises(ValueError, match=expected):
            tool_from_id(tool_id, task=None)

        # A task is provided, but it has no parent project.
        orphan_task = Mock(spec=Task)
        orphan_task.parent_project.return_value = None
        with pytest.raises(ValueError, match=expected):
            tool_from_id(tool_id, task=orphan_task)

    def test_tool_from_id_mcp_tool_server_not_found(self):
        """Test that tool_from_id raises ValueError when tool server is not found."""
        # The project only knows about a server with a different (generated) ID.
        mock_task = self._mock_task_with_servers(self._remote_server("different_server"))

        # Test with both remote and local tool IDs that reference nonexistent servers
        test_cases = [
            f"{MCP_REMOTE_TOOL_ID_PREFIX}missing_server::test_tool",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}missing_server::test_tool",
        ]

        for tool_id in test_cases:
            with pytest.raises(
                ValueError,
                match="External tool server not found: missing_server in project ID test_project_id",
            ):
                tool_from_id(tool_id, task=mock_task)

    def test_all_built_in_tools_are_registered(self):
        """Test that all KilnBuiltInToolId enum members are handled by the registry."""
        for tool_id in KilnBuiltInToolId:
            # This should not raise an exception
            tool = tool_from_id(tool_id.value)
            assert tool is not None

    async def test_registry_returns_new_instances(self):
        """Test that registry returns new instances each time (not singletons)."""
        tool1 = tool_from_id(KilnBuiltInToolId.ADD_NUMBERS)
        tool2 = tool_from_id(KilnBuiltInToolId.ADD_NUMBERS)

        assert tool1 is not tool2  # Different instances
        assert type(tool1) is type(tool2)  # Same type
        assert await tool1.id() == await tool2.id()  # Same id

    def test_check_tool_id_valid_built_in_tools(self):
        """Test that _check_tool_id accepts valid built-in tool IDs."""
        for tool_id in KilnBuiltInToolId:
            result = _check_tool_id(tool_id.value)
            assert result == tool_id.value

    def test_check_tool_id_invalid_tool_id(self):
        """Test that _check_tool_id raises ValueError for invalid tool ID."""
        with pytest.raises(ValueError, match="Invalid tool ID: invalid_tool_id"):
            _check_tool_id("invalid_tool_id")

    def test_check_tool_id_empty_string(self):
        """Test that _check_tool_id raises ValueError for empty string."""
        with pytest.raises(ValueError, match="Invalid tool ID: "):
            _check_tool_id("")

    def test_check_tool_id_none_value(self):
        """Test that _check_tool_id raises ValueError for None."""
        with pytest.raises(ValueError, match="Invalid tool ID: None"):
            _check_tool_id(None)  # type: ignore

    def test_check_tool_id_valid_mcp_remote_tool_id(self):
        """Test that _check_tool_id accepts valid MCP remote tool IDs."""
        valid_mcp_ids = [
            f"{MCP_REMOTE_TOOL_ID_PREFIX}server123::tool_name",
            f"{MCP_REMOTE_TOOL_ID_PREFIX}my_server::echo",
            f"{MCP_REMOTE_TOOL_ID_PREFIX}123456789::test_tool",
            f"{MCP_REMOTE_TOOL_ID_PREFIX}server_with_underscores::complex_tool_name",
        ]

        for tool_id in valid_mcp_ids:
            result = _check_tool_id(tool_id)
            assert result == tool_id

    def test_check_tool_id_valid_mcp_local_tool_id(self):
        """Test that _check_tool_id accepts valid MCP local tool IDs."""
        valid_mcp_local_ids = [
            f"{MCP_LOCAL_TOOL_ID_PREFIX}server123::tool_name",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}my_local_server::calculate",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}local_tool_server::process_data",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}server_with_underscores::complex_tool_name",
        ]

        for tool_id in valid_mcp_local_ids:
            result = _check_tool_id(tool_id)
            assert result == tool_id

    def test_check_tool_id_invalid_mcp_remote_tool_id(self):
        """Test that _check_tool_id rejects invalid MCP-like tool IDs."""
        # These start with the prefix but have wrong format - get specific MCP error
        invalid_mcp_format_ids = [
            "mcp::remote::server",  # Missing tool name (only 3 parts instead of 4)
            "mcp::remote::",  # Missing server and tool name (only 3 parts)
            "mcp::remote::::tool",  # Empty server name (5 parts instead of 4)
            "mcp::remote::server::tool::extra",  # Too many parts (5 instead of 4)
        ]

        for invalid_id in invalid_mcp_format_ids:
            with pytest.raises(
                ValueError, match=f"Invalid remote MCP tool ID: {invalid_id}"
            ):
                _check_tool_id(invalid_id)

        # These don't match the prefix - get generic error
        invalid_generic_ids = [
            "mcp::remote:",  # Missing last colon (doesn't match full prefix)
            "mcp:remote::server::tool",  # Wrong prefix format
            "mcp::remote_server::tool",  # Wrong prefix format
            "remote::server::tool",  # Missing mcp prefix
        ]

        for invalid_id in invalid_generic_ids:
            with pytest.raises(ValueError, match=f"Invalid tool ID: {invalid_id}"):
                _check_tool_id(invalid_id)

    def test_mcp_server_and_tool_name_from_id_valid_inputs(self):
        """Test that mcp_server_and_tool_name_from_id correctly parses valid MCP tool IDs."""
        test_cases = [
            # Remote MCP tool IDs
            ("mcp::remote::server123::tool_name", ("server123", "tool_name")),
            ("mcp::remote::my_server::echo", ("my_server", "echo")),
            ("mcp::remote::123456789::test_tool", ("123456789", "test_tool")),
            (
                "mcp::remote::server_with_underscores::complex_tool_name",
                ("server_with_underscores", "complex_tool_name"),
            ),
            ("mcp::remote::a::b", ("a", "b")),  # Minimal valid case
            (
                "mcp::remote::server-with-dashes::tool-with-dashes",
                ("server-with-dashes", "tool-with-dashes"),
            ),
            # Local MCP tool IDs
            ("mcp::local::local_server::calculate", ("local_server", "calculate")),
            ("mcp::local::my_local_tool::process", ("my_local_tool", "process")),
            (
                "mcp::local::123456789::local_test_tool",
                ("123456789", "local_test_tool"),
            ),
            (
                "mcp::local::local_server_with_underscores::complex_local_tool",
                ("local_server_with_underscores", "complex_local_tool"),
            ),
            ("mcp::local::x::y", ("x", "y")),  # Minimal valid case for local
        ]

        for tool_id, expected in test_cases:
            result = mcp_server_and_tool_name_from_id(tool_id)
            assert result == expected, (
                f"Failed for {tool_id}: expected {expected}, got {result}"
            )

    def test_mcp_server_and_tool_name_from_id_invalid_inputs(self):
        """Test that mcp_server_and_tool_name_from_id raises ValueError for invalid MCP tool IDs."""
        # Test remote MCP format errors
        remote_invalid_inputs = [
            "mcp::remote::server",  # Only 3 parts instead of 4
            "mcp::remote::",  # Only 3 parts, missing server and tool
            "mcp::remote::server::tool::extra",  # 5 parts instead of 4
        ]

        for invalid_id in remote_invalid_inputs:
            with pytest.raises(
                ValueError,
                match=r"Invalid remote MCP tool ID:.*Expected format.*mcp::remote::<server_id>::<tool_name>",
            ):
                mcp_server_and_tool_name_from_id(invalid_id)

        # Test local MCP format errors
        local_invalid_inputs = [
            "mcp::local::server",  # Only 3 parts instead of 4
            "mcp::local::",  # Only 3 parts, missing server and tool
            "mcp::local::server::tool::extra",  # 5 parts instead of 4
        ]

        for invalid_id in local_invalid_inputs:
            with pytest.raises(
                ValueError,
                match=r"Invalid local MCP tool ID:.*Expected format.*mcp::local::<server_id>::<tool_name>",
            ):
                mcp_server_and_tool_name_from_id(invalid_id)

        # Test generic MCP format errors (no valid prefix)
        generic_invalid_inputs = [
            "invalid::format::here",  # 3 parts, wrong prefix
            "",  # Empty string
            "single_part",  # No separators
            "two::parts",  # Only 2 parts
        ]

        for invalid_id in generic_invalid_inputs:
            with pytest.raises(
                ValueError,
                match=r"Invalid MCP tool ID:.*Expected format.*mcp::\(remote\|local\)::<server_id>::<tool_name>",
            ):
                mcp_server_and_tool_name_from_id(invalid_id)

    def test_mcp_server_and_tool_name_from_id_edge_cases(self):
        """Test that mcp_server_and_tool_name_from_id handles edge cases (empty parts allowed by parser)."""
        # These are valid according to the parser (exactly 4 parts),
        # but empty server_id/tool_name validation is handled by _check_tool_id
        edge_cases = [
            ("mcp::remote::::tool", ("", "tool")),  # Empty server name
            ("mcp::remote::server::", ("server", "")),  # Empty tool name
            ("mcp::remote::::", ("", "")),  # Both empty
        ]

        for tool_id, expected in edge_cases:
            result = mcp_server_and_tool_name_from_id(tool_id)
            assert result == expected, (
                f"Failed for {tool_id}: expected {expected}, got {result}"
            )

    @pytest.mark.parametrize(
        "tool_id,expected_server,expected_tool",
        [
            ("mcp::remote::test_server::test_tool", "test_server", "test_tool"),
            ("mcp::remote::s::t", "s", "t"),
            (
                "mcp::remote::long_server_name_123::complex_tool_name_456",
                "long_server_name_123",
                "complex_tool_name_456",
            ),
            (
                "mcp::local::local_test_server::local_test_tool",
                "local_test_server",
                "local_test_tool",
            ),
            ("mcp::local::l::l", "l", "l"),
            (
                "mcp::local::long_local_server_123::complex_local_tool_456",
                "long_local_server_123",
                "complex_local_tool_456",
            ),
        ],
    )
    def test_mcp_server_and_tool_name_from_id_parametrized(
        self, tool_id, expected_server, expected_tool
    ):
        """Parametrized test for mcp_server_and_tool_name_from_id with various valid inputs."""
        server_id, tool_name = mcp_server_and_tool_name_from_id(tool_id)
        assert server_id == expected_server
        assert tool_name == expected_tool

    def test_tool_from_id_mcp_missing_task_raises_error(self):
        """Test that MCP tool ID with missing task raises ValueError."""
        mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}test_server::test_tool"

        with pytest.raises(
            ValueError,
            match="Unable to resolve tool from id.*Requires a parent project/task",
        ):
            tool_from_id(mcp_tool_id, task=None)

    def test_tool_from_id_mcp_functional_case(self):
        """Test that MCP tool ID with valid task and project returns MCPServerTool."""
        mock_server = self._remote_server("test_server", description="Test MCP server")
        mock_task = self._mock_task_with_servers(mock_server)

        mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}{mock_server.id}::test_tool"

        tool = tool_from_id(mcp_tool_id, task=mock_task)

        assert isinstance(tool, MCPServerTool)
        # Verify the tool was created with the correct server and tool name
        assert tool._tool_server_model == mock_server
        assert tool._name == "test_tool"

    def test_tool_from_id_mcp_no_server_found_raises_error(self):
        """Test that MCP tool ID with server not found raises ValueError."""
        # The project only knows about a server with a different (generated) ID.
        mock_task = self._mock_task_with_servers(
            self._remote_server("different_server", description="Different MCP server")
        )

        # Use a tool ID with a server that doesn't exist in the project
        mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}nonexistent_server::test_tool"

        with pytest.raises(
            ValueError,
            match="External tool server not found: nonexistent_server in project ID test_project_id",
        ):
            tool_from_id(mcp_tool_id, task=mock_task)
@@ -0,0 +1,64 @@
1
+ from kiln_ai.datamodel.task import Task
2
+ from kiln_ai.datamodel.tool_id import (
3
+ MCP_LOCAL_TOOL_ID_PREFIX,
4
+ MCP_REMOTE_TOOL_ID_PREFIX,
5
+ KilnBuiltInToolId,
6
+ mcp_server_and_tool_name_from_id,
7
+ )
8
+ from kiln_ai.tools.base_tool import KilnToolInterface
9
+ from kiln_ai.tools.built_in_tools.math_tools import (
10
+ AddTool,
11
+ DivideTool,
12
+ MultiplyTool,
13
+ SubtractTool,
14
+ )
15
+ from kiln_ai.tools.mcp_server_tool import MCPServerTool
16
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
17
+
18
+
19
+ def tool_from_id(tool_id: str, task: Task | None = None) -> KilnToolInterface:
20
+ """
21
+ Get a tool from its ID.
22
+ """
23
+ # Check built-in tools
24
+ if tool_id in [member.value for member in KilnBuiltInToolId]:
25
+ typed_tool_id = KilnBuiltInToolId(tool_id)
26
+ match typed_tool_id:
27
+ case KilnBuiltInToolId.ADD_NUMBERS:
28
+ return AddTool()
29
+ case KilnBuiltInToolId.SUBTRACT_NUMBERS:
30
+ return SubtractTool()
31
+ case KilnBuiltInToolId.MULTIPLY_NUMBERS:
32
+ return MultiplyTool()
33
+ case KilnBuiltInToolId.DIVIDE_NUMBERS:
34
+ return DivideTool()
35
+ case _:
36
+ raise_exhaustive_enum_error(typed_tool_id)
37
+
38
+ # Check MCP Server Tools
39
+ if tool_id.startswith((MCP_REMOTE_TOOL_ID_PREFIX, MCP_LOCAL_TOOL_ID_PREFIX)):
40
+ project = task.parent_project() if task is not None else None
41
+ if project is None:
42
+ raise ValueError(
43
+ f"Unable to resolve tool from id: {tool_id}. Requires a parent project/task."
44
+ )
45
+
46
+ # Get the tool server ID and tool name from the ID
47
+ tool_server_id, tool_name = mcp_server_and_tool_name_from_id(tool_id)
48
+
49
+ server = next(
50
+ (
51
+ server
52
+ for server in project.external_tool_servers()
53
+ if server.id == tool_server_id
54
+ ),
55
+ None,
56
+ )
57
+ if server is None:
58
+ raise ValueError(
59
+ f"External tool server not found: {tool_server_id} in project ID {project.id}"
60
+ )
61
+
62
+ return MCPServerTool(server, tool_name)
63
+
64
+ raise ValueError(f"Tool ID {tool_id} not found in tool registry")
kiln_ai/utils/config.py CHANGED
@@ -6,6 +6,9 @@ from typing import Any, Callable, Dict, List, Optional
6
6
 
7
7
  import yaml
8
8
 
9
# Configuration keys
# Settings key under which MCP server secrets are stored. In Config it is registered
# as a sensitive dict[str, str] property whose entries are keyed "mcp_server_id::key_name".
MCP_SECRETS_KEY = "mcp_secrets"
11
+
9
12
 
10
13
  class ConfigProperty:
11
14
  def __init__(
@@ -54,6 +57,10 @@ class Config:
54
57
  str,
55
58
  env_var="OLLAMA_BASE_URL",
56
59
  ),
60
+ "docker_model_runner_base_url": ConfigProperty(
61
+ str,
62
+ env_var="DOCKER_MODEL_RUNNER_BASE_URL",
63
+ ),
57
64
  "bedrock_access_key": ConfigProperty(
58
65
  str,
59
66
  env_var="AWS_ACCESS_KEY_ID",
@@ -147,6 +154,21 @@ class Config:
147
154
  env_var="CEREBRAS_API_KEY",
148
155
  sensitive=True,
149
156
  ),
157
+ "enable_demo_tools": ConfigProperty(
158
+ bool,
159
+ env_var="ENABLE_DEMO_TOOLS",
160
+ default=False,
161
+ ),
162
+ # Allow the user to set the path used to look up MCP server commands, like npx.
163
+ "custom_mcp_path": ConfigProperty(
164
+ str,
165
+ env_var="CUSTOM_MCP_PATH",
166
+ ),
167
+ # Allow the user to set secrets for MCP servers; each key has the form mcp_server_id::key_name
168
+ MCP_SECRETS_KEY: ConfigProperty(
169
+ dict[str, str],
170
+ sensitive=True,
171
+ ),
150
172
  }
151
173
  self._lock = threading.Lock()
152
174
  self._settings = self.load_settings()
@@ -0,0 +1,94 @@
1
+ """
2
+ Wrapper for OpenAI types to make them compatible with Pydantic.
3
+
4
+ Pydantic doesn't support Iterable[T] well, so we use List[T] instead for tool_calls,
5
+ https://github.com/pydantic/pydantic/issues/9541
6
+
7
+ Otherwise we are using OpenAI SDK types directly.
8
+ """
9
+
10
+ from typing import (
11
+ Iterable,
12
+ List,
13
+ Literal,
14
+ Optional,
15
+ TypeAlias,
16
+ Union,
17
+ )
18
+
19
+ from openai.types.chat import (
20
+ ChatCompletionDeveloperMessageParam,
21
+ ChatCompletionFunctionMessageParam,
22
+ ChatCompletionMessageToolCallParam,
23
+ ChatCompletionSystemMessageParam,
24
+ ChatCompletionToolMessageParam,
25
+ ChatCompletionUserMessageParam,
26
+ )
27
+ from openai.types.chat.chat_completion_assistant_message_param import (
28
+ Audio,
29
+ ContentArrayOfContentPart,
30
+ FunctionCall,
31
+ )
32
+ from typing_extensions import Required, TypedDict
33
+
34
+
35
class ChatCompletionAssistantMessageParamWrapper(TypedDict, total=False):
    """
    Almost exact copy of OpenAI's ChatCompletionAssistantMessageParam, with two changes.

    First change: List[T] instead of Iterable[T] for tool_calls. Addresses pydantic issue
    https://github.com/pydantic/pydantic/issues/9541

    Second change: Add reasoning_content to the message. A LiteLLM property for reasoning data.

    The TypedDict is declared with ``total=False``, so every key is optional except
    ``role``, which is explicitly marked ``Required``.
    """

    # The only required key; always the literal string "assistant".
    role: Required[Literal["assistant"]]
    """The role of the message's author, in this case `assistant`."""

    audio: Optional[Audio]
    """Data about a previous audio response from the model.

    [Learn more](https://platform.openai.com/docs/guides/audio).
    """

    # NOTE(review): content still uses Iterable[...] as in the SDK type, while tool_calls
    # below was switched to List[...] because of the pydantic Iterable issue. If array-form
    # content is ever validated through pydantic, confirm it is not affected the same way.
    content: Union[str, Iterable[ContentArrayOfContentPart], None]
    """The contents of the assistant message.

    Required unless `tool_calls` or `function_call` is specified.
    """

    # Not part of the OpenAI SDK type; added for LiteLLM reasoning output.
    reasoning_content: Optional[str]
    """The reasoning content of the assistant message.

    A LiteLLM property for reasoning data: https://docs.litellm.ai/docs/reasoning_content
    """

    function_call: Optional[FunctionCall]
    """Deprecated and replaced by `tool_calls`.

    The name and arguments of a function that should be called, as generated by the
    model.
    """

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the same
    role.
    """

    refusal: Optional[str]
    """The refusal message by the assistant."""

    # List rather than the SDK's Iterable, so pydantic validates it eagerly (see class docstring).
    tool_calls: List[ChatCompletionMessageToolCallParam]
    """The tool calls generated by the model, such as function calls."""
85
+
86
+
87
# Union of all chat message param types accepted in a conversation. Apart from the
# assistant variant, which is replaced by ChatCompletionAssistantMessageParamWrapper,
# every member is an OpenAI SDK type used directly.
# NOTE(review): pydantic union validation can be sensitive to member order when several
# members match; check before reordering.
ChatCompletionMessageParam: TypeAlias = Union[
    ChatCompletionDeveloperMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
    ChatCompletionAssistantMessageParamWrapper,
    ChatCompletionToolMessageParam,
    ChatCompletionFunctionMessageParam,
]
@@ -0,0 +1,17 @@
1
+ from kiln_ai.datamodel.project import Project
2
+ from kiln_ai.utils.config import Config
3
+
4
+
5
def project_from_id(project_id: str) -> Project | None:
    """
    Find a registered project by its ID.

    Loads each project file listed in ``Config.shared().projects`` in order and returns
    the first one whose ``id`` equals ``project_id``.

    Args:
        project_id: ID of the project to look up.

    Returns:
        The matching Project, or None if no loadable registered project has that ID.
    """
    import logging

    logger = logging.getLogger(__name__)

    for project_path in Config.shared().projects or []:
        try:
            project = Project.load_from_file(project_path)
        except Exception:
            # Best effort: a registered project file may have been deleted or be
            # unreadable. Skip it and continue with the rest, but leave a trace
            # rather than swallowing the error silently.
            logger.debug(
                "Skipping project at %s: unable to load", project_path, exc_info=True
            )
            continue

        if project.id == project_id:
            return project

    return None