kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/tools/test_tool_registry.py
CHANGED
|
@@ -1,15 +1,22 @@
|
|
|
1
|
+
from pathlib import Path
|
|
1
2
|
from unittest.mock import Mock
|
|
2
3
|
|
|
3
4
|
import pytest
|
|
4
5
|
|
|
5
|
-
from kiln_ai.datamodel.external_tool_server import
|
|
6
|
+
from kiln_ai.datamodel.external_tool_server import (
|
|
7
|
+
ExternalToolServer,
|
|
8
|
+
ToolServerType,
|
|
9
|
+
)
|
|
6
10
|
from kiln_ai.datamodel.project import Project
|
|
7
11
|
from kiln_ai.datamodel.task import Task
|
|
8
12
|
from kiln_ai.datamodel.tool_id import (
|
|
13
|
+
KILN_TASK_TOOL_ID_PREFIX,
|
|
9
14
|
MCP_LOCAL_TOOL_ID_PREFIX,
|
|
10
15
|
MCP_REMOTE_TOOL_ID_PREFIX,
|
|
16
|
+
RAG_TOOL_ID_PREFIX,
|
|
11
17
|
KilnBuiltInToolId,
|
|
12
18
|
_check_tool_id,
|
|
19
|
+
kiln_task_server_id_from_tool_id,
|
|
13
20
|
mcp_server_and_tool_name_from_id,
|
|
14
21
|
)
|
|
15
22
|
from kiln_ai.tools.built_in_tools.math_tools import (
|
|
@@ -18,6 +25,7 @@ from kiln_ai.tools.built_in_tools.math_tools import (
|
|
|
18
25
|
MultiplyTool,
|
|
19
26
|
SubtractTool,
|
|
20
27
|
)
|
|
28
|
+
from kiln_ai.tools.kiln_task_tool import KilnTaskTool
|
|
21
29
|
from kiln_ai.tools.mcp_server_tool import MCPServerTool
|
|
22
30
|
from kiln_ai.tools.tool_registry import tool_from_id
|
|
23
31
|
|
|
@@ -85,7 +93,6 @@ class TestToolRegistry:
|
|
|
85
93
|
type=ToolServerType.remote_mcp,
|
|
86
94
|
properties={
|
|
87
95
|
"server_url": "https://example.com",
|
|
88
|
-
"headers": {},
|
|
89
96
|
},
|
|
90
97
|
)
|
|
91
98
|
|
|
@@ -143,7 +150,7 @@ class TestToolRegistry:
|
|
|
143
150
|
tool_id = f"{MCP_LOCAL_TOOL_ID_PREFIX}test_server::test_tool"
|
|
144
151
|
with pytest.raises(
|
|
145
152
|
ValueError,
|
|
146
|
-
match="Unable to resolve tool from id.*Requires a parent project/task",
|
|
153
|
+
match=r"Unable to resolve tool from id.*Requires a parent project/task",
|
|
147
154
|
):
|
|
148
155
|
tool_from_id(tool_id, task=None)
|
|
149
156
|
|
|
@@ -155,7 +162,6 @@ class TestToolRegistry:
|
|
|
155
162
|
type=ToolServerType.remote_mcp,
|
|
156
163
|
properties={
|
|
157
164
|
"server_url": "https://example.com",
|
|
158
|
-
"headers": {},
|
|
159
165
|
},
|
|
160
166
|
)
|
|
161
167
|
|
|
@@ -181,6 +187,93 @@ class TestToolRegistry:
|
|
|
181
187
|
):
|
|
182
188
|
tool_from_id(tool_id, task=mock_task)
|
|
183
189
|
|
|
190
|
+
def test_tool_from_id_rag_tool_success(self):
|
|
191
|
+
"""Test that tool_from_id works with RAG tool IDs."""
|
|
192
|
+
# Create mock RAG config
|
|
193
|
+
from unittest.mock import patch
|
|
194
|
+
|
|
195
|
+
with (
|
|
196
|
+
patch("kiln_ai.tools.tool_registry.RagConfig") as mock_rag_config_class,
|
|
197
|
+
patch("kiln_ai.tools.rag_tools.RagTool") as mock_rag_tool_class,
|
|
198
|
+
):
|
|
199
|
+
# Setup mock RAG config
|
|
200
|
+
mock_rag_config = Mock()
|
|
201
|
+
mock_rag_config.id = "test_rag_config"
|
|
202
|
+
mock_rag_config_class.from_id_and_parent_path.return_value = mock_rag_config
|
|
203
|
+
|
|
204
|
+
# Setup mock RAG tool
|
|
205
|
+
mock_rag_tool = Mock()
|
|
206
|
+
mock_rag_tool_class.return_value = mock_rag_tool
|
|
207
|
+
|
|
208
|
+
# Create mock project
|
|
209
|
+
mock_project = Mock(spec=Project)
|
|
210
|
+
mock_project.id = "test_project_id"
|
|
211
|
+
mock_project.path = Path("/test/path")
|
|
212
|
+
|
|
213
|
+
# Create mock task with parent project
|
|
214
|
+
mock_task = Mock(spec=Task)
|
|
215
|
+
mock_task.parent_project.return_value = mock_project
|
|
216
|
+
|
|
217
|
+
# Test with RAG tool ID
|
|
218
|
+
tool_id = f"{RAG_TOOL_ID_PREFIX}test_rag_config"
|
|
219
|
+
tool = tool_from_id(tool_id, task=mock_task)
|
|
220
|
+
|
|
221
|
+
# Verify the tool is RagTool
|
|
222
|
+
assert tool == mock_rag_tool
|
|
223
|
+
mock_rag_config_class.from_id_and_parent_path.assert_called_once_with(
|
|
224
|
+
"test_rag_config", Path("/test/path")
|
|
225
|
+
)
|
|
226
|
+
mock_rag_tool_class.assert_called_once_with(tool_id, mock_rag_config)
|
|
227
|
+
|
|
228
|
+
def test_tool_from_id_rag_tool_no_task(self):
|
|
229
|
+
"""Test that RAG tool ID without task raises ValueError."""
|
|
230
|
+
tool_id = f"{RAG_TOOL_ID_PREFIX}test_rag_config"
|
|
231
|
+
|
|
232
|
+
with pytest.raises(
|
|
233
|
+
ValueError,
|
|
234
|
+
match=r"Unable to resolve tool from id.*Requires a parent project/task",
|
|
235
|
+
):
|
|
236
|
+
tool_from_id(tool_id, task=None)
|
|
237
|
+
|
|
238
|
+
def test_tool_from_id_rag_tool_no_project(self):
|
|
239
|
+
"""Test that RAG tool ID with task but no project raises ValueError."""
|
|
240
|
+
# Create mock task without parent project
|
|
241
|
+
mock_task = Mock(spec=Task)
|
|
242
|
+
mock_task.parent_project.return_value = None
|
|
243
|
+
|
|
244
|
+
tool_id = f"{RAG_TOOL_ID_PREFIX}test_rag_config"
|
|
245
|
+
|
|
246
|
+
with pytest.raises(
|
|
247
|
+
ValueError,
|
|
248
|
+
match=r"Unable to resolve tool from id.*Requires a parent project/task",
|
|
249
|
+
):
|
|
250
|
+
tool_from_id(tool_id, task=mock_task)
|
|
251
|
+
|
|
252
|
+
def test_tool_from_id_rag_config_not_found(self):
|
|
253
|
+
"""Test that RAG tool ID with missing RAG config raises ValueError."""
|
|
254
|
+
from unittest.mock import patch
|
|
255
|
+
|
|
256
|
+
with patch("kiln_ai.tools.tool_registry.RagConfig") as mock_rag_config_class:
|
|
257
|
+
# Setup mock to return None (config not found)
|
|
258
|
+
mock_rag_config_class.from_id_and_parent_path.return_value = None
|
|
259
|
+
|
|
260
|
+
# Create mock project
|
|
261
|
+
mock_project = Mock(spec=Project)
|
|
262
|
+
mock_project.id = "test_project_id"
|
|
263
|
+
mock_project.path = Path("/test/path")
|
|
264
|
+
|
|
265
|
+
# Create mock task with parent project
|
|
266
|
+
mock_task = Mock(spec=Task)
|
|
267
|
+
mock_task.parent_project.return_value = mock_project
|
|
268
|
+
|
|
269
|
+
tool_id = f"{RAG_TOOL_ID_PREFIX}missing_rag_config"
|
|
270
|
+
|
|
271
|
+
with pytest.raises(
|
|
272
|
+
ValueError,
|
|
273
|
+
match="RAG config not found: missing_rag_config in project test_project_id for tool",
|
|
274
|
+
):
|
|
275
|
+
tool_from_id(tool_id, task=mock_task)
|
|
276
|
+
|
|
184
277
|
def test_all_built_in_tools_are_registered(self):
|
|
185
278
|
"""Test that all KilnBuiltInToolId enum members are handled by the registry."""
|
|
186
279
|
for tool_id in KilnBuiltInToolId:
|
|
@@ -272,6 +365,48 @@ class TestToolRegistry:
|
|
|
272
365
|
with pytest.raises(ValueError, match=f"Invalid tool ID: {invalid_id}"):
|
|
273
366
|
_check_tool_id(invalid_id)
|
|
274
367
|
|
|
368
|
+
def test_check_tool_id_valid_kiln_task_tool_id(self):
|
|
369
|
+
"""Test that _check_tool_id accepts valid Kiln task tool IDs."""
|
|
370
|
+
valid_kiln_task_ids = [
|
|
371
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server123",
|
|
372
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}my_task_server",
|
|
373
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}123456789",
|
|
374
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server_with_underscores",
|
|
375
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server-with-dashes",
|
|
376
|
+
]
|
|
377
|
+
|
|
378
|
+
for tool_id in valid_kiln_task_ids:
|
|
379
|
+
result = _check_tool_id(tool_id)
|
|
380
|
+
assert result == tool_id
|
|
381
|
+
|
|
382
|
+
def test_check_tool_id_invalid_kiln_task_tool_id(self):
|
|
383
|
+
"""Test that _check_tool_id rejects invalid Kiln task tool IDs."""
|
|
384
|
+
# These start with the prefix but have wrong format
|
|
385
|
+
invalid_kiln_task_format_ids = [
|
|
386
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}", # Missing server ID
|
|
387
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}::", # Empty server ID
|
|
388
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server::tool", # Too many parts (3 instead of 2)
|
|
389
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server::tool::extra", # Too many parts (4 instead of 2)
|
|
390
|
+
]
|
|
391
|
+
|
|
392
|
+
for invalid_id in invalid_kiln_task_format_ids:
|
|
393
|
+
with pytest.raises(
|
|
394
|
+
ValueError, match=f"Invalid Kiln task tool ID format: {invalid_id}"
|
|
395
|
+
):
|
|
396
|
+
_check_tool_id(invalid_id)
|
|
397
|
+
|
|
398
|
+
# These don't match the prefix - get generic error
|
|
399
|
+
invalid_generic_ids = [
|
|
400
|
+
"kiln_task:", # Missing last colon (doesn't match full prefix)
|
|
401
|
+
"kiln:task::server", # Wrong prefix format
|
|
402
|
+
"kiln_task_server", # Missing colons
|
|
403
|
+
"task::server", # Missing kiln prefix
|
|
404
|
+
]
|
|
405
|
+
|
|
406
|
+
for invalid_id in invalid_generic_ids:
|
|
407
|
+
with pytest.raises(ValueError, match=f"Invalid tool ID: {invalid_id}"):
|
|
408
|
+
_check_tool_id(invalid_id)
|
|
409
|
+
|
|
275
410
|
def test_mcp_server_and_tool_name_from_id_valid_inputs(self):
|
|
276
411
|
"""Test that mcp_server_and_tool_name_from_id correctly parses valid MCP tool IDs."""
|
|
277
412
|
test_cases = [
|
|
@@ -400,13 +535,71 @@ class TestToolRegistry:
|
|
|
400
535
|
assert server_id == expected_server
|
|
401
536
|
assert tool_name == expected_tool
|
|
402
537
|
|
|
538
|
+
def test_kiln_task_server_id_from_tool_id_valid_inputs(self):
|
|
539
|
+
"""Test that kiln_task_server_id_from_tool_id correctly parses valid Kiln task tool IDs."""
|
|
540
|
+
test_cases = [
|
|
541
|
+
("kiln_task::server123", "server123"),
|
|
542
|
+
("kiln_task::my_task_server", "my_task_server"),
|
|
543
|
+
("kiln_task::123456789", "123456789"),
|
|
544
|
+
("kiln_task::server_with_underscores", "server_with_underscores"),
|
|
545
|
+
("kiln_task::server-with-dashes", "server-with-dashes"),
|
|
546
|
+
("kiln_task::a", "a"), # Minimal valid case
|
|
547
|
+
(
|
|
548
|
+
"kiln_task::very_long_server_name_with_numbers_123",
|
|
549
|
+
"very_long_server_name_with_numbers_123",
|
|
550
|
+
),
|
|
551
|
+
]
|
|
552
|
+
|
|
553
|
+
for tool_id, expected_server_id in test_cases:
|
|
554
|
+
result = kiln_task_server_id_from_tool_id(tool_id)
|
|
555
|
+
assert result == expected_server_id, (
|
|
556
|
+
f"Failed for {tool_id}: expected {expected_server_id}, got {result}"
|
|
557
|
+
)
|
|
558
|
+
|
|
559
|
+
def test_kiln_task_server_id_from_tool_id_invalid_inputs(self):
|
|
560
|
+
"""Test that kiln_task_server_id_from_tool_id raises ValueError for invalid Kiln task tool IDs."""
|
|
561
|
+
invalid_inputs = [
|
|
562
|
+
"kiln_task::", # Empty server ID
|
|
563
|
+
"kiln_task::server::tool", # Too many parts (3 instead of 2)
|
|
564
|
+
"kiln_task::server::tool::extra", # Too many parts (4 instead of 2)
|
|
565
|
+
"invalid::format", # Wrong prefix
|
|
566
|
+
"", # Empty string
|
|
567
|
+
"single_part", # No separators
|
|
568
|
+
"two::parts", # Only 2 parts but wrong prefix
|
|
569
|
+
"kiln_task", # Missing colons
|
|
570
|
+
]
|
|
571
|
+
|
|
572
|
+
for invalid_id in invalid_inputs:
|
|
573
|
+
with pytest.raises(
|
|
574
|
+
ValueError,
|
|
575
|
+
match=r"Invalid Kiln task tool ID format:.*Expected format.*kiln_task::<server_id>",
|
|
576
|
+
):
|
|
577
|
+
kiln_task_server_id_from_tool_id(invalid_id)
|
|
578
|
+
|
|
579
|
+
@pytest.mark.parametrize(
|
|
580
|
+
"tool_id,expected_server_id",
|
|
581
|
+
[
|
|
582
|
+
("kiln_task::test_server", "test_server"),
|
|
583
|
+
("kiln_task::s", "s"),
|
|
584
|
+
("kiln_task::long_server_name_123", "long_server_name_123"),
|
|
585
|
+
("kiln_task::server-with-dashes", "server-with-dashes"),
|
|
586
|
+
("kiln_task::server_with_underscores", "server_with_underscores"),
|
|
587
|
+
],
|
|
588
|
+
)
|
|
589
|
+
def test_kiln_task_server_id_from_tool_id_parametrized(
|
|
590
|
+
self, tool_id, expected_server_id
|
|
591
|
+
):
|
|
592
|
+
"""Parametrized test for kiln_task_server_id_from_tool_id with various valid inputs."""
|
|
593
|
+
server_id = kiln_task_server_id_from_tool_id(tool_id)
|
|
594
|
+
assert server_id == expected_server_id
|
|
595
|
+
|
|
403
596
|
def test_tool_from_id_mcp_missing_task_raises_error(self):
|
|
404
597
|
"""Test that MCP tool ID with missing task raises ValueError."""
|
|
405
598
|
mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}test_server::test_tool"
|
|
406
599
|
|
|
407
600
|
with pytest.raises(
|
|
408
601
|
ValueError,
|
|
409
|
-
match="Unable to resolve tool from id.*Requires a parent project/task",
|
|
602
|
+
match=r"Unable to resolve tool from id.*Requires a parent project/task",
|
|
410
603
|
):
|
|
411
604
|
tool_from_id(mcp_tool_id, task=None)
|
|
412
605
|
|
|
@@ -419,7 +612,6 @@ class TestToolRegistry:
|
|
|
419
612
|
description="Test MCP server",
|
|
420
613
|
properties={
|
|
421
614
|
"server_url": "https://example.com",
|
|
422
|
-
"headers": {},
|
|
423
615
|
},
|
|
424
616
|
)
|
|
425
617
|
|
|
@@ -450,7 +642,6 @@ class TestToolRegistry:
|
|
|
450
642
|
description="Different MCP server",
|
|
451
643
|
properties={
|
|
452
644
|
"server_url": "https://example.com",
|
|
453
|
-
"headers": {},
|
|
454
645
|
},
|
|
455
646
|
)
|
|
456
647
|
|
|
@@ -471,3 +662,95 @@ class TestToolRegistry:
|
|
|
471
662
|
match="External tool server not found: nonexistent_server in project ID test_project_id",
|
|
472
663
|
):
|
|
473
664
|
tool_from_id(mcp_tool_id, task=mock_task)
|
|
665
|
+
|
|
666
|
+
def test_tool_from_id_kiln_task_tool_success(self):
|
|
667
|
+
"""Test that tool_from_id works with Kiln task tool IDs."""
|
|
668
|
+
# Create mock external tool server for Kiln task
|
|
669
|
+
mock_server = ExternalToolServer(
|
|
670
|
+
name="test_kiln_task_server",
|
|
671
|
+
type=ToolServerType.kiln_task,
|
|
672
|
+
description="Test Kiln task server",
|
|
673
|
+
properties={
|
|
674
|
+
"name": "test_task_tool",
|
|
675
|
+
"description": "A test task tool",
|
|
676
|
+
"task_id": "test_task_123",
|
|
677
|
+
"run_config_id": "test_config_456",
|
|
678
|
+
"is_archived": False,
|
|
679
|
+
},
|
|
680
|
+
)
|
|
681
|
+
|
|
682
|
+
# Create mock project with the external tool server
|
|
683
|
+
mock_project = Mock(spec=Project)
|
|
684
|
+
mock_project.id = "test_project_id"
|
|
685
|
+
mock_project.external_tool_servers.return_value = [mock_server]
|
|
686
|
+
|
|
687
|
+
# Create mock task with parent project
|
|
688
|
+
mock_task = Mock(spec=Task)
|
|
689
|
+
mock_task.parent_project.return_value = mock_project
|
|
690
|
+
|
|
691
|
+
# Test with Kiln task tool ID
|
|
692
|
+
tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}{mock_server.id}"
|
|
693
|
+
tool = tool_from_id(tool_id, task=mock_task)
|
|
694
|
+
|
|
695
|
+
# Verify the tool is KilnTaskTool
|
|
696
|
+
assert isinstance(tool, KilnTaskTool)
|
|
697
|
+
assert tool._project_id == "test_project_id"
|
|
698
|
+
assert tool._tool_id == tool_id
|
|
699
|
+
assert tool._tool_server_model == mock_server
|
|
700
|
+
|
|
701
|
+
def test_tool_from_id_kiln_task_tool_no_task(self):
|
|
702
|
+
"""Test that Kiln task tool ID without task raises ValueError."""
|
|
703
|
+
tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}test_server"
|
|
704
|
+
with pytest.raises(
|
|
705
|
+
ValueError,
|
|
706
|
+
match=r"Unable to resolve tool from id.*Requires a parent project/task",
|
|
707
|
+
):
|
|
708
|
+
tool_from_id(tool_id, task=None)
|
|
709
|
+
|
|
710
|
+
def test_tool_from_id_kiln_task_tool_no_project(self):
|
|
711
|
+
"""Test that Kiln task tool ID with task but no project raises ValueError."""
|
|
712
|
+
# Create mock task without parent project
|
|
713
|
+
mock_task = Mock(spec=Task)
|
|
714
|
+
mock_task.parent_project.return_value = None
|
|
715
|
+
|
|
716
|
+
tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}test_server"
|
|
717
|
+
|
|
718
|
+
with pytest.raises(
|
|
719
|
+
ValueError,
|
|
720
|
+
match=r"Unable to resolve tool from id.*Requires a parent project/task",
|
|
721
|
+
):
|
|
722
|
+
tool_from_id(tool_id, task=mock_task)
|
|
723
|
+
|
|
724
|
+
def test_tool_from_id_kiln_task_tool_server_not_found(self):
|
|
725
|
+
"""Test that Kiln task tool ID with server not found raises ValueError."""
|
|
726
|
+
# Create mock external tool server with different ID
|
|
727
|
+
mock_server = ExternalToolServer(
|
|
728
|
+
name="different_server",
|
|
729
|
+
type=ToolServerType.kiln_task,
|
|
730
|
+
description="Different Kiln task server",
|
|
731
|
+
properties={
|
|
732
|
+
"name": "different_tool",
|
|
733
|
+
"description": "A different task tool",
|
|
734
|
+
"task_id": "different_task_123",
|
|
735
|
+
"run_config_id": "different_config_456",
|
|
736
|
+
"is_archived": False,
|
|
737
|
+
},
|
|
738
|
+
)
|
|
739
|
+
|
|
740
|
+
# Create mock project with the external tool server
|
|
741
|
+
mock_project = Mock(spec=Project)
|
|
742
|
+
mock_project.id = "test_project_id"
|
|
743
|
+
mock_project.external_tool_servers.return_value = [mock_server]
|
|
744
|
+
|
|
745
|
+
# Create mock task with parent project
|
|
746
|
+
mock_task = Mock(spec=Task)
|
|
747
|
+
mock_task.parent_project.return_value = mock_project
|
|
748
|
+
|
|
749
|
+
# Use a tool ID with a server that doesn't exist in the project
|
|
750
|
+
tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}nonexistent_server"
|
|
751
|
+
|
|
752
|
+
with pytest.raises(
|
|
753
|
+
ValueError,
|
|
754
|
+
match="Kiln Task External tool server not found: nonexistent_server in project ID test_project_id",
|
|
755
|
+
):
|
|
756
|
+
tool_from_id(tool_id, task=mock_task)
|
kiln_ai/tools/tool_registry.py
CHANGED
|
@@ -1,9 +1,14 @@
|
|
|
1
|
+
from kiln_ai.datamodel.rag import RagConfig
|
|
1
2
|
from kiln_ai.datamodel.task import Task
|
|
2
3
|
from kiln_ai.datamodel.tool_id import (
|
|
4
|
+
KILN_TASK_TOOL_ID_PREFIX,
|
|
3
5
|
MCP_LOCAL_TOOL_ID_PREFIX,
|
|
4
6
|
MCP_REMOTE_TOOL_ID_PREFIX,
|
|
7
|
+
RAG_TOOL_ID_PREFIX,
|
|
5
8
|
KilnBuiltInToolId,
|
|
9
|
+
kiln_task_server_id_from_tool_id,
|
|
6
10
|
mcp_server_and_tool_name_from_id,
|
|
11
|
+
rag_config_id_from_id,
|
|
7
12
|
)
|
|
8
13
|
from kiln_ai.tools.base_tool import KilnToolInterface
|
|
9
14
|
from kiln_ai.tools.built_in_tools.math_tools import (
|
|
@@ -12,6 +17,7 @@ from kiln_ai.tools.built_in_tools.math_tools import (
|
|
|
12
17
|
MultiplyTool,
|
|
13
18
|
SubtractTool,
|
|
14
19
|
)
|
|
20
|
+
from kiln_ai.tools.kiln_task_tool import KilnTaskTool
|
|
15
21
|
from kiln_ai.tools.mcp_server_tool import MCPServerTool
|
|
16
22
|
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
17
23
|
|
|
@@ -35,30 +41,77 @@ def tool_from_id(tool_id: str, task: Task | None = None) -> KilnToolInterface:
|
|
|
35
41
|
case _:
|
|
36
42
|
raise_exhaustive_enum_error(typed_tool_id)
|
|
37
43
|
|
|
38
|
-
# Check MCP
|
|
39
|
-
|
|
44
|
+
# Check if this looks like an MCP or Kiln Task tool ID that requires a project
|
|
45
|
+
is_mcp_tool = tool_id.startswith(
|
|
46
|
+
(MCP_REMOTE_TOOL_ID_PREFIX, MCP_LOCAL_TOOL_ID_PREFIX)
|
|
47
|
+
)
|
|
48
|
+
is_kiln_task_tool = tool_id.startswith(KILN_TASK_TOOL_ID_PREFIX)
|
|
49
|
+
|
|
50
|
+
if is_mcp_tool or is_kiln_task_tool:
|
|
40
51
|
project = task.parent_project() if task is not None else None
|
|
41
|
-
if project is None:
|
|
52
|
+
if project is None or project.id is None:
|
|
42
53
|
raise ValueError(
|
|
43
54
|
f"Unable to resolve tool from id: {tool_id}. Requires a parent project/task."
|
|
44
55
|
)
|
|
45
56
|
|
|
46
|
-
#
|
|
47
|
-
|
|
57
|
+
# Check MCP Server Tools
|
|
58
|
+
if is_mcp_tool:
|
|
59
|
+
# Get the tool server ID and tool name from the ID
|
|
60
|
+
tool_server_id, tool_name = mcp_server_and_tool_name_from_id(
|
|
61
|
+
tool_id
|
|
62
|
+
) # Fixed function name
|
|
63
|
+
|
|
64
|
+
server = next(
|
|
65
|
+
(
|
|
66
|
+
server
|
|
67
|
+
for server in project.external_tool_servers()
|
|
68
|
+
if server.id == tool_server_id
|
|
69
|
+
),
|
|
70
|
+
None,
|
|
71
|
+
)
|
|
72
|
+
if server is None:
|
|
73
|
+
raise ValueError(
|
|
74
|
+
f"External tool server not found: {tool_server_id} in project ID {project.id}"
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
return MCPServerTool(server, tool_name)
|
|
78
|
+
|
|
79
|
+
# Check Kiln Task Tools
|
|
80
|
+
if is_kiln_task_tool:
|
|
81
|
+
server_id = kiln_task_server_id_from_tool_id(tool_id)
|
|
82
|
+
|
|
83
|
+
server = next(
|
|
84
|
+
(
|
|
85
|
+
server
|
|
86
|
+
for server in project.external_tool_servers()
|
|
87
|
+
if server.id == server_id
|
|
88
|
+
),
|
|
89
|
+
None,
|
|
90
|
+
)
|
|
91
|
+
if server is None:
|
|
92
|
+
raise ValueError(
|
|
93
|
+
f"Kiln Task External tool server not found: {server_id} in project ID {project.id}"
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
return KilnTaskTool(project.id, tool_id, server)
|
|
97
|
+
|
|
98
|
+
elif tool_id.startswith(RAG_TOOL_ID_PREFIX):
|
|
99
|
+
project = task.parent_project() if task is not None else None
|
|
100
|
+
if project is None:
|
|
101
|
+
raise ValueError(
|
|
102
|
+
f"Unable to resolve tool from id: {tool_id}. Requires a parent project/task."
|
|
103
|
+
)
|
|
48
104
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
for server in project.external_tool_servers()
|
|
53
|
-
if server.id == tool_server_id
|
|
54
|
-
),
|
|
55
|
-
None,
|
|
56
|
-
)
|
|
57
|
-
if server is None:
|
|
105
|
+
rag_config_id = rag_config_id_from_id(tool_id)
|
|
106
|
+
rag_config = RagConfig.from_id_and_parent_path(rag_config_id, project.path)
|
|
107
|
+
if rag_config is None:
|
|
58
108
|
raise ValueError(
|
|
59
|
-
f"
|
|
109
|
+
f"RAG config not found: {rag_config_id} in project {project.id} for tool {tool_id}"
|
|
60
110
|
)
|
|
61
111
|
|
|
62
|
-
|
|
112
|
+
# Lazy import to avoid circular dependency
|
|
113
|
+
from kiln_ai.tools.rag_tools import RagTool
|
|
114
|
+
|
|
115
|
+
return RagTool(tool_id, rag_config)
|
|
63
116
|
|
|
64
117
|
raise ValueError(f"Tool ID {tool_id} not found in tool registry")
|
kiln_ai/utils/__init__.py
CHANGED
kiln_ai/utils/async_job_runner.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
|
|
4
|
+
from typing import AsyncGenerator, Awaitable, Callable, Generic, List, TypeVar
|
|
5
5
|
|
|
6
6
|
logger = logging.getLogger(__name__)
|
|
7
7
|
|
|
@@ -15,29 +15,66 @@ class Progress:
|
|
|
15
15
|
errors: int
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
class
|
|
19
|
-
def
|
|
18
|
+
class AsyncJobRunnerObserver(Generic[T]):
|
|
19
|
+
async def on_error(self, job: T, error: Exception):
|
|
20
|
+
"""
|
|
21
|
+
Called when a job raises an unhandled exception.
|
|
22
|
+
"""
|
|
23
|
+
pass
|
|
24
|
+
|
|
25
|
+
async def on_success(self, job: T):
|
|
26
|
+
"""
|
|
27
|
+
Called when a job completes successfully.
|
|
28
|
+
"""
|
|
29
|
+
pass
|
|
30
|
+
|
|
31
|
+
async def on_job_start(self, job: T):
|
|
32
|
+
"""
|
|
33
|
+
Called when a job starts.
|
|
34
|
+
"""
|
|
35
|
+
pass
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class AsyncJobRunner(Generic[T]):
|
|
39
|
+
def __init__(
|
|
40
|
+
self,
|
|
41
|
+
jobs: List[T],
|
|
42
|
+
run_job_fn: Callable[[T], Awaitable[bool]],
|
|
43
|
+
concurrency: int = 1,
|
|
44
|
+
observers: List[AsyncJobRunnerObserver[T]] | None = None,
|
|
45
|
+
):
|
|
20
46
|
if concurrency < 1:
|
|
21
47
|
raise ValueError("concurrency must be ≥ 1")
|
|
22
48
|
self.concurrency = concurrency
|
|
49
|
+
self.jobs = jobs
|
|
50
|
+
self.run_job_fn = run_job_fn
|
|
51
|
+
self.observers = observers or []
|
|
23
52
|
|
|
24
|
-
async def
|
|
25
|
-
self
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
53
|
+
async def notify_error(self, job: T, error: Exception):
|
|
54
|
+
for observer in self.observers:
|
|
55
|
+
await observer.on_error(job, error)
|
|
56
|
+
|
|
57
|
+
async def notify_success(self, job: T):
|
|
58
|
+
for observer in self.observers:
|
|
59
|
+
await observer.on_success(job)
|
|
60
|
+
|
|
61
|
+
async def notify_job_start(self, job: T):
|
|
62
|
+
for observer in self.observers:
|
|
63
|
+
await observer.on_job_start(job)
|
|
64
|
+
|
|
65
|
+
async def run(self) -> AsyncGenerator[Progress, None]:
|
|
29
66
|
"""
|
|
30
67
|
Runs the jobs with parallel workers and yields progress updates.
|
|
31
68
|
"""
|
|
32
69
|
complete = 0
|
|
33
70
|
errors = 0
|
|
34
|
-
total = len(jobs)
|
|
71
|
+
total = len(self.jobs)
|
|
35
72
|
|
|
36
73
|
# Send initial status
|
|
37
74
|
yield Progress(complete=complete, total=total, errors=errors)
|
|
38
75
|
|
|
39
76
|
worker_queue: asyncio.Queue[T] = asyncio.Queue()
|
|
40
|
-
for job in jobs:
|
|
77
|
+
for job in self.jobs:
|
|
41
78
|
worker_queue.put_nowait(job)
|
|
42
79
|
|
|
43
80
|
# simple status queue to return progress. True=success, False=error
|
|
@@ -46,7 +83,7 @@ class AsyncJobRunner:
|
|
|
46
83
|
workers = []
|
|
47
84
|
for _ in range(self.concurrency):
|
|
48
85
|
task = asyncio.create_task(
|
|
49
|
-
self._run_worker(worker_queue, status_queue,
|
|
86
|
+
self._run_worker(worker_queue, status_queue, self.run_job_fn),
|
|
50
87
|
)
|
|
51
88
|
workers.append(task)
|
|
52
89
|
|
|
@@ -64,7 +101,11 @@ class AsyncJobRunner:
|
|
|
64
101
|
else:
|
|
65
102
|
errors += 1
|
|
66
103
|
|
|
67
|
-
yield Progress(
|
|
104
|
+
yield Progress(
|
|
105
|
+
complete=complete,
|
|
106
|
+
total=total,
|
|
107
|
+
errors=errors,
|
|
108
|
+
)
|
|
68
109
|
except asyncio.TimeoutError:
|
|
69
110
|
# Timeout is expected, just continue to recheck worker status
|
|
70
111
|
# Don't love this but beats sentinels for reliability
|
|
@@ -82,7 +123,7 @@ class AsyncJobRunner:
|
|
|
82
123
|
self,
|
|
83
124
|
worker_queue: asyncio.Queue[T],
|
|
84
125
|
status_queue: asyncio.Queue[bool],
|
|
85
|
-
|
|
126
|
+
run_job_fn: Callable[[T], Awaitable[bool]],
|
|
86
127
|
):
|
|
87
128
|
while True:
|
|
88
129
|
try:
|
|
@@ -92,13 +133,17 @@ class AsyncJobRunner:
|
|
|
92
133
|
break
|
|
93
134
|
|
|
94
135
|
try:
|
|
95
|
-
|
|
96
|
-
|
|
136
|
+
await self.notify_job_start(job)
|
|
137
|
+
result = await run_job_fn(job)
|
|
138
|
+
if result:
|
|
139
|
+
await self.notify_success(job)
|
|
140
|
+
except Exception as e:
|
|
97
141
|
logger.error("Job failed to complete", exc_info=True)
|
|
98
|
-
|
|
142
|
+
await self.notify_error(job, e)
|
|
143
|
+
result = False
|
|
99
144
|
|
|
100
145
|
try:
|
|
101
|
-
await status_queue.put(
|
|
146
|
+
await status_queue.put(result)
|
|
102
147
|
except Exception:
|
|
103
148
|
logger.error("Failed to enqueue status for job", exc_info=True)
|
|
104
149
|
finally:
|
kiln_ai/utils/config.py
CHANGED
|
@@ -221,14 +221,14 @@ class Config:
|
|
|
221
221
|
raise AttributeError(f"Config has no attribute '{name}'")
|
|
222
222
|
|
|
223
223
|
@classmethod
|
|
224
|
-
def settings_dir(cls, create=True):
|
|
224
|
+
def settings_dir(cls, create=True) -> str:
|
|
225
225
|
settings_dir = os.path.join(Path.home(), ".kiln_ai")
|
|
226
226
|
if create and not os.path.exists(settings_dir):
|
|
227
227
|
os.makedirs(settings_dir)
|
|
228
228
|
return settings_dir
|
|
229
229
|
|
|
230
230
|
@classmethod
|
|
231
|
-
def settings_path(cls, create=True):
|
|
231
|
+
def settings_path(cls, create=True) -> str:
|
|
232
232
|
settings_dir = cls.settings_dir(create)
|
|
233
233
|
return os.path.join(settings_dir, "settings.yaml")
|
|
234
234
|
|
kiln_ai/utils/env.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@contextmanager
|
|
6
|
+
def temporary_env(var_name: str, value: str):
|
|
7
|
+
old_value = os.environ.get(var_name)
|
|
8
|
+
os.environ[var_name] = value
|
|
9
|
+
try:
|
|
10
|
+
yield
|
|
11
|
+
finally:
|
|
12
|
+
if old_value is None:
|
|
13
|
+
os.environ.pop(var_name, None) # remove if it did not exist before
|
|
14
|
+
else:
|
|
15
|
+
os.environ[var_name] = old_value
|