kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (158)
  1. kiln_ai/adapters/__init__.py +8 -2
  2. kiln_ai/adapters/adapter_registry.py +43 -208
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/chunkers/__init__.py +13 -0
  6. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  7. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  8. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  9. kiln_ai/adapters/chunkers/helpers.py +23 -0
  10. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  11. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  12. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  13. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  14. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  15. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  16. kiln_ai/adapters/embedding/__init__.py +0 -0
  17. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  18. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  19. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  20. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  21. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  22. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  23. kiln_ai/adapters/eval/base_eval.py +2 -2
  24. kiln_ai/adapters/eval/eval_runner.py +9 -3
  25. kiln_ai/adapters/eval/g_eval.py +2 -2
  26. kiln_ai/adapters/eval/test_base_eval.py +2 -4
  27. kiln_ai/adapters/eval/test_g_eval.py +4 -5
  28. kiln_ai/adapters/extractors/__init__.py +18 -0
  29. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  30. kiln_ai/adapters/extractors/encoding.py +20 -0
  31. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  32. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  33. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  34. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  35. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  36. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  37. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  38. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  39. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  40. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  41. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  42. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  43. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  44. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  45. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  46. kiln_ai/adapters/ml_model_list.py +761 -37
  47. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  48. kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
  49. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  50. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
  51. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  52. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  53. kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
  54. kiln_ai/adapters/ollama_tools.py +69 -12
  55. kiln_ai/adapters/parsers/__init__.py +1 -1
  56. kiln_ai/adapters/provider_tools.py +205 -47
  57. kiln_ai/adapters/rag/deduplication.py +49 -0
  58. kiln_ai/adapters/rag/progress.py +252 -0
  59. kiln_ai/adapters/rag/rag_runners.py +844 -0
  60. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  61. kiln_ai/adapters/rag/test_progress.py +785 -0
  62. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  63. kiln_ai/adapters/remote_config.py +80 -8
  64. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  65. kiln_ai/adapters/run_output.py +3 -0
  66. kiln_ai/adapters/test_adapter_registry.py +657 -85
  67. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  68. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  69. kiln_ai/adapters/test_ml_model_list.py +251 -1
  70. kiln_ai/adapters/test_ollama_tools.py +340 -1
  71. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  72. kiln_ai/adapters/test_prompt_builders.py +1 -1
  73. kiln_ai/adapters/test_provider_tools.py +254 -8
  74. kiln_ai/adapters/test_remote_config.py +651 -58
  75. kiln_ai/adapters/vector_store/__init__.py +1 -0
  76. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  77. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  78. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  79. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  80. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  81. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  82. kiln_ai/datamodel/__init__.py +39 -34
  83. kiln_ai/datamodel/basemodel.py +170 -1
  84. kiln_ai/datamodel/chunk.py +158 -0
  85. kiln_ai/datamodel/datamodel_enums.py +28 -0
  86. kiln_ai/datamodel/embedding.py +64 -0
  87. kiln_ai/datamodel/eval.py +1 -1
  88. kiln_ai/datamodel/external_tool_server.py +298 -0
  89. kiln_ai/datamodel/extraction.py +303 -0
  90. kiln_ai/datamodel/json_schema.py +25 -10
  91. kiln_ai/datamodel/project.py +40 -1
  92. kiln_ai/datamodel/rag.py +79 -0
  93. kiln_ai/datamodel/registry.py +0 -15
  94. kiln_ai/datamodel/run_config.py +62 -0
  95. kiln_ai/datamodel/task.py +2 -77
  96. kiln_ai/datamodel/task_output.py +6 -1
  97. kiln_ai/datamodel/task_run.py +41 -0
  98. kiln_ai/datamodel/test_attachment.py +649 -0
  99. kiln_ai/datamodel/test_basemodel.py +4 -4
  100. kiln_ai/datamodel/test_chunk_models.py +317 -0
  101. kiln_ai/datamodel/test_dataset_split.py +1 -1
  102. kiln_ai/datamodel/test_embedding_models.py +448 -0
  103. kiln_ai/datamodel/test_eval_model.py +6 -6
  104. kiln_ai/datamodel/test_example_models.py +175 -0
  105. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  106. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  107. kiln_ai/datamodel/test_extraction_model.py +470 -0
  108. kiln_ai/datamodel/test_rag.py +641 -0
  109. kiln_ai/datamodel/test_registry.py +8 -3
  110. kiln_ai/datamodel/test_task.py +15 -47
  111. kiln_ai/datamodel/test_tool_id.py +320 -0
  112. kiln_ai/datamodel/test_vector_store.py +320 -0
  113. kiln_ai/datamodel/tool_id.py +105 -0
  114. kiln_ai/datamodel/vector_store.py +141 -0
  115. kiln_ai/tools/__init__.py +8 -0
  116. kiln_ai/tools/base_tool.py +82 -0
  117. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  118. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  119. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  120. kiln_ai/tools/mcp_server_tool.py +95 -0
  121. kiln_ai/tools/mcp_session_manager.py +246 -0
  122. kiln_ai/tools/rag_tools.py +157 -0
  123. kiln_ai/tools/test_base_tools.py +199 -0
  124. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  125. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  126. kiln_ai/tools/test_rag_tools.py +848 -0
  127. kiln_ai/tools/test_tool_registry.py +562 -0
  128. kiln_ai/tools/tool_registry.py +85 -0
  129. kiln_ai/utils/__init__.py +3 -0
  130. kiln_ai/utils/async_job_runner.py +62 -17
  131. kiln_ai/utils/config.py +24 -2
  132. kiln_ai/utils/env.py +15 -0
  133. kiln_ai/utils/filesystem.py +14 -0
  134. kiln_ai/utils/filesystem_cache.py +60 -0
  135. kiln_ai/utils/litellm.py +94 -0
  136. kiln_ai/utils/lock.py +100 -0
  137. kiln_ai/utils/mime_type.py +38 -0
  138. kiln_ai/utils/open_ai_types.py +94 -0
  139. kiln_ai/utils/pdf_utils.py +38 -0
  140. kiln_ai/utils/project_utils.py +17 -0
  141. kiln_ai/utils/test_async_job_runner.py +151 -35
  142. kiln_ai/utils/test_config.py +138 -1
  143. kiln_ai/utils/test_env.py +142 -0
  144. kiln_ai/utils/test_filesystem_cache.py +316 -0
  145. kiln_ai/utils/test_litellm.py +206 -0
  146. kiln_ai/utils/test_lock.py +185 -0
  147. kiln_ai/utils/test_mime_type.py +66 -0
  148. kiln_ai/utils/test_open_ai_types.py +131 -0
  149. kiln_ai/utils/test_pdf_utils.py +73 -0
  150. kiln_ai/utils/test_uuid.py +111 -0
  151. kiln_ai/utils/test_validation.py +524 -0
  152. kiln_ai/utils/uuid.py +9 -0
  153. kiln_ai/utils/validation.py +90 -0
  154. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
  155. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  156. kiln_ai-0.19.0.dist-info/RECORD +0 -115
  157. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  158. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,562 @@
1
+ from pathlib import Path
2
+ from unittest.mock import Mock
3
+
4
+ import pytest
5
+
6
+ from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType
7
+ from kiln_ai.datamodel.project import Project
8
+ from kiln_ai.datamodel.task import Task
9
+ from kiln_ai.datamodel.tool_id import (
10
+ MCP_LOCAL_TOOL_ID_PREFIX,
11
+ MCP_REMOTE_TOOL_ID_PREFIX,
12
+ RAG_TOOL_ID_PREFIX,
13
+ KilnBuiltInToolId,
14
+ _check_tool_id,
15
+ mcp_server_and_tool_name_from_id,
16
+ )
17
+ from kiln_ai.tools.built_in_tools.math_tools import (
18
+ AddTool,
19
+ DivideTool,
20
+ MultiplyTool,
21
+ SubtractTool,
22
+ )
23
+ from kiln_ai.tools.mcp_server_tool import MCPServerTool
24
+ from kiln_ai.tools.tool_registry import tool_from_id
25
+
26
+
27
class TestToolRegistry:
    """Tests for ``tool_from_id`` and the tool ID helpers it relies on."""

    @staticmethod
    def _task_in_project(*, servers=None, path=None):
        """Build a mock Task whose parent project has ID ``test_project_id``.

        Args:
            servers: If given, the list returned by the mock project's
                ``external_tool_servers()``.
            path: If given, the value of the mock project's ``path``.

        Returns:
            A ``Mock(spec=Task)`` whose ``parent_project()`` returns the
            mock project.
        """
        project = Mock(spec=Project)
        project.id = "test_project_id"
        if servers is not None:
            project.external_tool_servers.return_value = servers
        if path is not None:
            project.path = path

        task = Mock(spec=Task)
        task.parent_project.return_value = project
        return task

    async def test_tool_from_id_add_numbers(self):
        """Test that ADD_NUMBERS tool ID returns AddTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.ADD_NUMBERS)

        assert isinstance(tool, AddTool)
        assert await tool.id() == KilnBuiltInToolId.ADD_NUMBERS
        assert await tool.name() == "add"
        assert "Add two numbers" in await tool.description()

    async def test_tool_from_id_subtract_numbers(self):
        """Test that SUBTRACT_NUMBERS tool ID returns SubtractTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.SUBTRACT_NUMBERS)

        assert isinstance(tool, SubtractTool)
        assert await tool.id() == KilnBuiltInToolId.SUBTRACT_NUMBERS
        assert await tool.name() == "subtract"

    async def test_tool_from_id_multiply_numbers(self):
        """Test that MULTIPLY_NUMBERS tool ID returns MultiplyTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.MULTIPLY_NUMBERS)

        assert isinstance(tool, MultiplyTool)
        assert await tool.id() == KilnBuiltInToolId.MULTIPLY_NUMBERS
        assert await tool.name() == "multiply"

    async def test_tool_from_id_divide_numbers(self):
        """Test that DIVIDE_NUMBERS tool ID returns DivideTool instance."""
        tool = tool_from_id(KilnBuiltInToolId.DIVIDE_NUMBERS)

        assert isinstance(tool, DivideTool)
        assert await tool.id() == KilnBuiltInToolId.DIVIDE_NUMBERS
        assert await tool.name() == "divide"

    async def test_tool_from_id_with_string_values(self):
        """Test that tool_from_id works with string values of enum members."""
        tool = tool_from_id("kiln_tool::add_numbers")

        assert isinstance(tool, AddTool)
        assert await tool.id() == KilnBuiltInToolId.ADD_NUMBERS

    def test_tool_from_id_invalid_tool_id(self):
        """Test that invalid tool ID raises ValueError."""
        with pytest.raises(
            ValueError, match="Tool ID invalid_tool_id not found in tool registry"
        ):
            tool_from_id("invalid_tool_id")

    def test_tool_from_id_empty_string(self):
        """Test that empty string tool ID raises ValueError."""
        # An empty ID is interpolated between two literal spaces, so the
        # message reads "Tool ID  not found ..." (two spaces). Spell the
        # count out so the pattern survives whitespace normalisation.
        with pytest.raises(
            ValueError, match=r"Tool ID\s{2}not found in tool registry"
        ):
            tool_from_id("")

    def test_tool_from_id_mcp_remote_tool_success(self):
        """Test that tool_from_id works with MCP remote tool IDs."""
        mock_server = ExternalToolServer(
            name="test_server",
            type=ToolServerType.remote_mcp,
            properties={
                "server_url": "https://example.com",
                "headers": {},
            },
        )
        mock_task = self._task_in_project(servers=[mock_server])

        tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}{mock_server.id}::echo"
        tool = tool_from_id(tool_id, task=mock_task)

        assert isinstance(tool, MCPServerTool)
        assert tool._tool_server_model == mock_server
        assert tool._name == "echo"

    def test_tool_from_id_mcp_local_tool_success(self):
        """Test that tool_from_id works with MCP local tool IDs."""
        mock_server = ExternalToolServer(
            name="local_server",
            type=ToolServerType.local_mcp,
            properties={
                "command": "python",
                "args": ["server.py", "--port", "8080"],
                "env_vars": {},
            },
        )
        mock_task = self._task_in_project(servers=[mock_server])

        tool_id = f"{MCP_LOCAL_TOOL_ID_PREFIX}{mock_server.id}::calculate"
        tool = tool_from_id(tool_id, task=mock_task)

        assert isinstance(tool, MCPServerTool)
        assert tool._tool_server_model == mock_server
        assert tool._name == "calculate"

    def test_tool_from_id_mcp_tool_project_not_found(self):
        """Test that a local MCP tool ID without a task raises ValueError."""
        tool_id = f"{MCP_LOCAL_TOOL_ID_PREFIX}test_server::test_tool"
        with pytest.raises(
            ValueError,
            match=r"Unable to resolve tool from id.*Requires a parent project/task",
        ):
            tool_from_id(tool_id, task=None)

    def test_tool_from_id_mcp_tool_server_not_found(self):
        """Test that tool_from_id raises ValueError when tool server is not found."""
        # The project only knows a server whose ID differs from the one
        # referenced by the tool IDs below.
        mock_server = ExternalToolServer(
            name="different_server",
            type=ToolServerType.remote_mcp,
            properties={
                "server_url": "https://example.com",
                "headers": {},
            },
        )
        mock_task = self._task_in_project(servers=[mock_server])

        # Both remote and local tool IDs that reference a nonexistent server
        test_cases = [
            f"{MCP_REMOTE_TOOL_ID_PREFIX}missing_server::test_tool",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}missing_server::test_tool",
        ]

        for tool_id in test_cases:
            with pytest.raises(
                ValueError,
                match="External tool server not found: missing_server in project ID test_project_id",
            ):
                tool_from_id(tool_id, task=mock_task)

    def test_tool_from_id_rag_tool_success(self):
        """Test that tool_from_id works with RAG tool IDs."""
        from unittest.mock import patch

        with (
            patch("kiln_ai.tools.tool_registry.RagConfig") as mock_rag_config_class,
            patch("kiln_ai.tools.rag_tools.RagTool") as mock_rag_tool_class,
        ):
            # The patched RagConfig lookup hands back a mock config
            mock_rag_config = Mock()
            mock_rag_config.id = "test_rag_config"
            mock_rag_config_class.from_id_and_parent_path.return_value = mock_rag_config

            # The patched RagTool constructor hands back a mock tool
            mock_rag_tool = Mock()
            mock_rag_tool_class.return_value = mock_rag_tool

            mock_task = self._task_in_project(path=Path("/test/path"))

            tool_id = f"{RAG_TOOL_ID_PREFIX}test_rag_config"
            tool = tool_from_id(tool_id, task=mock_task)

            assert tool == mock_rag_tool
            mock_rag_config_class.from_id_and_parent_path.assert_called_once_with(
                "test_rag_config", Path("/test/path")
            )
            mock_rag_tool_class.assert_called_once_with(tool_id, mock_rag_config)

    def test_tool_from_id_rag_tool_no_task(self):
        """Test that RAG tool ID without task raises ValueError."""
        tool_id = f"{RAG_TOOL_ID_PREFIX}test_rag_config"

        with pytest.raises(
            ValueError,
            match=r"Unable to resolve tool from id.*Requires a parent project/task",
        ):
            tool_from_id(tool_id, task=None)

    def test_tool_from_id_rag_tool_no_project(self):
        """Test that RAG tool ID with task but no project raises ValueError."""
        mock_task = Mock(spec=Task)
        mock_task.parent_project.return_value = None

        tool_id = f"{RAG_TOOL_ID_PREFIX}test_rag_config"

        with pytest.raises(
            ValueError,
            match=r"Unable to resolve tool from id.*Requires a parent project/task",
        ):
            tool_from_id(tool_id, task=mock_task)

    def test_tool_from_id_rag_config_not_found(self):
        """Test that RAG tool ID with missing RAG config raises ValueError."""
        from unittest.mock import patch

        with patch("kiln_ai.tools.tool_registry.RagConfig") as mock_rag_config_class:
            # The lookup returns None, i.e. the config was not found
            mock_rag_config_class.from_id_and_parent_path.return_value = None

            mock_task = self._task_in_project(path=Path("/test/path"))

            tool_id = f"{RAG_TOOL_ID_PREFIX}missing_rag_config"

            with pytest.raises(
                ValueError,
                match="RAG config not found: missing_rag_config in project test_project_id for tool",
            ):
                tool_from_id(tool_id, task=mock_task)

    def test_all_built_in_tools_are_registered(self):
        """Test that all KilnBuiltInToolId enum members are handled by the registry."""
        for tool_id in KilnBuiltInToolId:
            # This should not raise an exception
            tool = tool_from_id(tool_id.value)
            assert tool is not None

    async def test_registry_returns_new_instances(self):
        """Test that registry returns new instances each time (not singletons)."""
        tool1 = tool_from_id(KilnBuiltInToolId.ADD_NUMBERS)
        tool2 = tool_from_id(KilnBuiltInToolId.ADD_NUMBERS)

        assert tool1 is not tool2  # Different instances
        assert type(tool1) is type(tool2)  # Same type
        assert await tool1.id() == await tool2.id()  # Same id

    def test_check_tool_id_valid_built_in_tools(self):
        """Test that _check_tool_id accepts valid built-in tool IDs."""
        for tool_id in KilnBuiltInToolId:
            result = _check_tool_id(tool_id.value)
            assert result == tool_id.value

    def test_check_tool_id_invalid_tool_id(self):
        """Test that _check_tool_id raises ValueError for invalid tool ID."""
        with pytest.raises(ValueError, match="Invalid tool ID: invalid_tool_id"):
            _check_tool_id("invalid_tool_id")

    def test_check_tool_id_empty_string(self):
        """Test that _check_tool_id raises ValueError for empty string."""
        with pytest.raises(ValueError, match="Invalid tool ID: "):
            _check_tool_id("")

    def test_check_tool_id_none_value(self):
        """Test that _check_tool_id raises ValueError for None."""
        with pytest.raises(ValueError, match="Invalid tool ID: None"):
            _check_tool_id(None)  # type: ignore

    def test_check_tool_id_valid_mcp_remote_tool_id(self):
        """Test that _check_tool_id accepts valid MCP remote tool IDs."""
        valid_mcp_ids = [
            f"{MCP_REMOTE_TOOL_ID_PREFIX}server123::tool_name",
            f"{MCP_REMOTE_TOOL_ID_PREFIX}my_server::echo",
            f"{MCP_REMOTE_TOOL_ID_PREFIX}123456789::test_tool",
            f"{MCP_REMOTE_TOOL_ID_PREFIX}server_with_underscores::complex_tool_name",
        ]

        for tool_id in valid_mcp_ids:
            result = _check_tool_id(tool_id)
            assert result == tool_id

    def test_check_tool_id_valid_mcp_local_tool_id(self):
        """Test that _check_tool_id accepts valid MCP local tool IDs."""
        valid_mcp_local_ids = [
            f"{MCP_LOCAL_TOOL_ID_PREFIX}server123::tool_name",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}my_local_server::calculate",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}local_tool_server::process_data",
            f"{MCP_LOCAL_TOOL_ID_PREFIX}server_with_underscores::complex_tool_name",
        ]

        for tool_id in valid_mcp_local_ids:
            result = _check_tool_id(tool_id)
            assert result == tool_id

    def test_check_tool_id_invalid_mcp_remote_tool_id(self):
        """Test that _check_tool_id rejects invalid MCP-like tool IDs."""
        import re

        # These start with the prefix but have wrong format - get specific MCP error
        invalid_mcp_format_ids = [
            "mcp::remote::server",  # Missing tool name (only 3 parts instead of 4)
            "mcp::remote::",  # Missing server and tool name (only 3 parts)
            "mcp::remote::::tool",  # Empty server ID (4 parts, server segment is empty)
            "mcp::remote::server::tool::extra",  # Too many parts (5 instead of 4)
        ]

        for invalid_id in invalid_mcp_format_ids:
            # Escape the ID so it is matched literally inside the regex
            with pytest.raises(
                ValueError,
                match=f"Invalid remote MCP tool ID: {re.escape(invalid_id)}",
            ):
                _check_tool_id(invalid_id)

        # These don't match the prefix - get generic error
        invalid_generic_ids = [
            "mcp::remote:",  # Missing last colon (doesn't match full prefix)
            "mcp:remote::server::tool",  # Wrong prefix format
            "mcp::remote_server::tool",  # Wrong prefix format
            "remote::server::tool",  # Missing mcp prefix
        ]

        for invalid_id in invalid_generic_ids:
            with pytest.raises(
                ValueError, match=f"Invalid tool ID: {re.escape(invalid_id)}"
            ):
                _check_tool_id(invalid_id)

    def test_mcp_server_and_tool_name_from_id_valid_inputs(self):
        """Test that mcp_server_and_tool_name_from_id correctly parses valid MCP tool IDs."""
        test_cases = [
            # Remote MCP tool IDs
            ("mcp::remote::server123::tool_name", ("server123", "tool_name")),
            ("mcp::remote::my_server::echo", ("my_server", "echo")),
            ("mcp::remote::123456789::test_tool", ("123456789", "test_tool")),
            (
                "mcp::remote::server_with_underscores::complex_tool_name",
                ("server_with_underscores", "complex_tool_name"),
            ),
            ("mcp::remote::a::b", ("a", "b")),  # Minimal valid case
            (
                "mcp::remote::server-with-dashes::tool-with-dashes",
                ("server-with-dashes", "tool-with-dashes"),
            ),
            # Local MCP tool IDs
            ("mcp::local::local_server::calculate", ("local_server", "calculate")),
            ("mcp::local::my_local_tool::process", ("my_local_tool", "process")),
            (
                "mcp::local::123456789::local_test_tool",
                ("123456789", "local_test_tool"),
            ),
            (
                "mcp::local::local_server_with_underscores::complex_local_tool",
                ("local_server_with_underscores", "complex_local_tool"),
            ),
            ("mcp::local::x::y", ("x", "y")),  # Minimal valid case for local
        ]

        for tool_id, expected in test_cases:
            result = mcp_server_and_tool_name_from_id(tool_id)
            assert result == expected, (
                f"Failed for {tool_id}: expected {expected}, got {result}"
            )

    def test_mcp_server_and_tool_name_from_id_invalid_inputs(self):
        """Test that mcp_server_and_tool_name_from_id raises ValueError for invalid MCP tool IDs."""
        # Remote MCP format errors
        remote_invalid_inputs = [
            "mcp::remote::server",  # Only 3 parts instead of 4
            "mcp::remote::",  # Only 3 parts, missing server and tool
            "mcp::remote::server::tool::extra",  # 5 parts instead of 4
        ]

        for invalid_id in remote_invalid_inputs:
            with pytest.raises(
                ValueError,
                match=r"Invalid remote MCP tool ID:.*Expected format.*mcp::remote::<server_id>::<tool_name>",
            ):
                mcp_server_and_tool_name_from_id(invalid_id)

        # Local MCP format errors
        local_invalid_inputs = [
            "mcp::local::server",  # Only 3 parts instead of 4
            "mcp::local::",  # Only 3 parts, missing server and tool
            "mcp::local::server::tool::extra",  # 5 parts instead of 4
        ]

        for invalid_id in local_invalid_inputs:
            with pytest.raises(
                ValueError,
                match=r"Invalid local MCP tool ID:.*Expected format.*mcp::local::<server_id>::<tool_name>",
            ):
                mcp_server_and_tool_name_from_id(invalid_id)

        # Generic MCP format errors (no valid prefix)
        generic_invalid_inputs = [
            "invalid::format::here",  # 3 parts, wrong prefix
            "",  # Empty string
            "single_part",  # No separators
            "two::parts",  # Only 2 parts
        ]

        for invalid_id in generic_invalid_inputs:
            with pytest.raises(
                ValueError,
                match=r"Invalid MCP tool ID:.*Expected format.*mcp::\(remote\|local\)::<server_id>::<tool_name>",
            ):
                mcp_server_and_tool_name_from_id(invalid_id)

    def test_mcp_server_and_tool_name_from_id_edge_cases(self):
        """Test that mcp_server_and_tool_name_from_id handles edge cases (empty parts allowed by parser)."""
        # These are valid according to the parser (exactly 4 parts),
        # but empty server_id/tool_name validation is handled by _check_tool_id
        edge_cases = [
            ("mcp::remote::::tool", ("", "tool")),  # Empty server name
            ("mcp::remote::server::", ("server", "")),  # Empty tool name
            ("mcp::remote::::", ("", "")),  # Both empty
        ]

        for tool_id, expected in edge_cases:
            result = mcp_server_and_tool_name_from_id(tool_id)
            assert result == expected, (
                f"Failed for {tool_id}: expected {expected}, got {result}"
            )

    @pytest.mark.parametrize(
        "tool_id,expected_server,expected_tool",
        [
            ("mcp::remote::test_server::test_tool", "test_server", "test_tool"),
            ("mcp::remote::s::t", "s", "t"),
            (
                "mcp::remote::long_server_name_123::complex_tool_name_456",
                "long_server_name_123",
                "complex_tool_name_456",
            ),
            (
                "mcp::local::local_test_server::local_test_tool",
                "local_test_server",
                "local_test_tool",
            ),
            ("mcp::local::l::l", "l", "l"),
            (
                "mcp::local::long_local_server_123::complex_local_tool_456",
                "long_local_server_123",
                "complex_local_tool_456",
            ),
        ],
    )
    def test_mcp_server_and_tool_name_from_id_parametrized(
        self, tool_id, expected_server, expected_tool
    ):
        """Parametrized test for mcp_server_and_tool_name_from_id with various valid inputs."""
        server_id, tool_name = mcp_server_and_tool_name_from_id(tool_id)
        assert server_id == expected_server
        assert tool_name == expected_tool

    def test_tool_from_id_mcp_missing_task_raises_error(self):
        """Test that MCP tool ID with missing task raises ValueError."""
        mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}test_server::test_tool"

        with pytest.raises(
            ValueError,
            match=r"Unable to resolve tool from id.*Requires a parent project/task",
        ):
            tool_from_id(mcp_tool_id, task=None)

    def test_tool_from_id_mcp_functional_case(self):
        """Test that MCP tool ID with valid task and project returns MCPServerTool."""
        mock_server = ExternalToolServer(
            name="test_server",
            type=ToolServerType.remote_mcp,
            description="Test MCP server",
            properties={
                "server_url": "https://example.com",
                "headers": {},
            },
        )
        mock_task = self._task_in_project(servers=[mock_server])

        mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}{mock_server.id}::test_tool"

        tool = tool_from_id(mcp_tool_id, task=mock_task)

        assert isinstance(tool, MCPServerTool)
        # Verify the tool was created with the correct server and tool name
        assert tool._tool_server_model == mock_server
        assert tool._name == "test_tool"

    def test_tool_from_id_mcp_no_server_found_raises_error(self):
        """Test that MCP tool ID with server not found raises ValueError."""
        # The project only knows a server whose ID differs from the one
        # referenced by the tool ID below.
        mock_server = ExternalToolServer(
            name="different_server",
            type=ToolServerType.remote_mcp,
            description="Different MCP server",
            properties={
                "server_url": "https://example.com",
                "headers": {},
            },
        )
        mock_task = self._task_in_project(servers=[mock_server])

        # Use a tool ID with a server that doesn't exist in the project
        mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}nonexistent_server::test_tool"

        with pytest.raises(
            ValueError,
            match="External tool server not found: nonexistent_server in project ID test_project_id",
        ):
            tool_from_id(mcp_tool_id, task=mock_task)
@@ -0,0 +1,85 @@
1
+ from kiln_ai.datamodel.rag import RagConfig
2
+ from kiln_ai.datamodel.task import Task
3
+ from kiln_ai.datamodel.tool_id import (
4
+ MCP_LOCAL_TOOL_ID_PREFIX,
5
+ MCP_REMOTE_TOOL_ID_PREFIX,
6
+ RAG_TOOL_ID_PREFIX,
7
+ KilnBuiltInToolId,
8
+ mcp_server_and_tool_name_from_id,
9
+ rag_config_id_from_id,
10
+ )
11
+ from kiln_ai.tools.base_tool import KilnToolInterface
12
+ from kiln_ai.tools.built_in_tools.math_tools import (
13
+ AddTool,
14
+ DivideTool,
15
+ MultiplyTool,
16
+ SubtractTool,
17
+ )
18
+ from kiln_ai.tools.mcp_server_tool import MCPServerTool
19
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
20
+
21
+
22
def tool_from_id(tool_id: str, task: Task | None = None) -> KilnToolInterface:
    """
    Resolve a tool ID to a tool instance.

    Built-in tool IDs need no context. MCP server tool IDs and RAG tool IDs
    are resolved against the parent project of ``task``.

    Args:
        tool_id: The ID of the tool to look up.
        task: The task whose parent project is used to resolve MCP and RAG
            tools. May be omitted for built-in tools.

    Returns:
        A new tool instance for the given ID.

    Raises:
        ValueError: If the ID matches no known kind of tool, if an MCP/RAG
            tool is requested without a resolvable parent project, or if the
            referenced tool server or RAG config is not found in the project.
    """
    # Built-in tools: construct a fresh instance for the matching enum member
    if any(tool_id == member.value for member in KilnBuiltInToolId):
        builtin_id = KilnBuiltInToolId(tool_id)
        if builtin_id == KilnBuiltInToolId.ADD_NUMBERS:
            return AddTool()
        elif builtin_id == KilnBuiltInToolId.SUBTRACT_NUMBERS:
            return SubtractTool()
        elif builtin_id == KilnBuiltInToolId.MULTIPLY_NUMBERS:
            return MultiplyTool()
        elif builtin_id == KilnBuiltInToolId.DIVIDE_NUMBERS:
            return DivideTool()
        else:
            raise_exhaustive_enum_error(builtin_id)

    # Anything that is neither an MCP server tool nor a RAG tool is unknown
    is_mcp_tool = tool_id.startswith(
        (MCP_REMOTE_TOOL_ID_PREFIX, MCP_LOCAL_TOOL_ID_PREFIX)
    )
    if not is_mcp_tool and not tool_id.startswith(RAG_TOOL_ID_PREFIX):
        raise ValueError(f"Tool ID {tool_id} not found in tool registry")

    # Both remaining kinds of tool live inside a project
    project = None if task is None else task.parent_project()
    if project is None:
        raise ValueError(
            f"Unable to resolve tool from id: {tool_id}. Requires a parent project/task."
        )

    if is_mcp_tool:
        # Split the ID into the tool server ID and the tool name
        tool_server_id, tool_name = mcp_server_and_tool_name_from_id(tool_id)

        # The first server with a matching ID wins
        for candidate in project.external_tool_servers():
            if candidate.id == tool_server_id:
                return MCPServerTool(candidate, tool_name)

        raise ValueError(
            f"External tool server not found: {tool_server_id} in project ID {project.id}"
        )

    # RAG tool
    rag_config_id = rag_config_id_from_id(tool_id)
    rag_config = RagConfig.from_id_and_parent_path(rag_config_id, project.path)
    if rag_config is None:
        raise ValueError(
            f"RAG config not found: {rag_config_id} in project {project.id} for tool {tool_id}"
        )

    # Lazy import to avoid circular dependency
    from kiln_ai.tools.rag_tools import RagTool

    return RagTool(tool_id, rag_config)
kiln_ai/utils/__init__.py CHANGED
@@ -5,8 +5,11 @@ Misc utilities used in the kiln_ai library.
5
5
  """
6
6
 
7
7
  from . import config, formatting
8
+ from .lock import AsyncLockManager, shared_async_lock_manager
8
9
 
9
10
  __all__ = [
11
+ "AsyncLockManager",
10
12
  "config",
11
13
  "formatting",
14
+ "shared_async_lock_manager",
12
15
  ]