kiln-ai 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
- kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
- kiln_ai/adapters/ml_embedding_model_list.py +330 -28
- kiln_ai/adapters/ml_model_list.py +503 -23
- kiln_ai/adapters/model_adapters/litellm_adapter.py +34 -7
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
- kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
- kiln_ai/adapters/test_ml_model_list.py +0 -10
- kiln_ai/datamodel/basemodel.py +31 -3
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +14 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +3 -3
- kiln_ai/datamodel/test_basemodel.py +269 -13
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_model.py +31 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +106 -1
- kiln_ai/datamodel/tool_id.py +36 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +50 -24
- kiln_ai/tools/rag_tools.py +12 -5
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +186 -226
- kiln_ai/tools/test_rag_tools.py +86 -5
- kiln_ai/tools/test_tool_registry.py +199 -5
- kiln_ai/tools/tool_registry.py +49 -17
- kiln_ai/utils/filesystem.py +4 -4
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +21 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +14 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +3 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/RECORD +45 -43
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/tools/test_rag_tools.py
CHANGED
|
@@ -8,6 +8,7 @@ from kiln_ai.datamodel.embedding import EmbeddingConfig
|
|
|
8
8
|
from kiln_ai.datamodel.project import Project
|
|
9
9
|
from kiln_ai.datamodel.rag import RagConfig
|
|
10
10
|
from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
|
|
11
|
+
from kiln_ai.tools.base_tool import ToolCallContext
|
|
11
12
|
from kiln_ai.tools.rag_tools import ChunkContext, RagTool, format_search_results
|
|
12
13
|
|
|
13
14
|
|
|
@@ -420,7 +421,7 @@ class TestRagTool:
|
|
|
420
421
|
tool = RagTool("tool_123", mock_rag_config)
|
|
421
422
|
|
|
422
423
|
# Run the tool
|
|
423
|
-
result = await tool.run("test query")
|
|
424
|
+
result = await tool.run(context=None, query="test query")
|
|
424
425
|
|
|
425
426
|
# Verify the result format
|
|
426
427
|
expected_result = (
|
|
@@ -500,7 +501,7 @@ class TestRagTool:
|
|
|
500
501
|
tool = RagTool("tool_123", mock_rag_config)
|
|
501
502
|
|
|
502
503
|
# Run the tool
|
|
503
|
-
result = await tool.run("hybrid query")
|
|
504
|
+
result = await tool.run(context=None, query="hybrid query")
|
|
504
505
|
|
|
505
506
|
# Verify embedding generation was called
|
|
506
507
|
mock_embedding_adapter.generate_embeddings.assert_called_once_with(
|
|
@@ -566,7 +567,7 @@ class TestRagTool:
|
|
|
566
567
|
tool = RagTool("tool_123", mock_rag_config)
|
|
567
568
|
|
|
568
569
|
# Run the tool
|
|
569
|
-
result = await tool.run("fts query")
|
|
570
|
+
result = await tool.run(context=None, query="fts query")
|
|
570
571
|
|
|
571
572
|
# Verify the result format
|
|
572
573
|
expected_result = (
|
|
@@ -629,7 +630,7 @@ class TestRagTool:
|
|
|
629
630
|
|
|
630
631
|
# Run the tool and expect an error
|
|
631
632
|
with pytest.raises(ValueError, match="No embeddings generated"):
|
|
632
|
-
await tool.run("query with no embeddings")
|
|
633
|
+
await tool.run(context=None, query="query with no embeddings")
|
|
633
634
|
|
|
634
635
|
async def test_rag_tool_run_empty_search_results(
|
|
635
636
|
self, mock_rag_config, mock_project
|
|
@@ -675,11 +676,91 @@ class TestRagTool:
|
|
|
675
676
|
tool = RagTool("tool_123", mock_rag_config)
|
|
676
677
|
|
|
677
678
|
# Run the tool
|
|
678
|
-
result = await tool.run("query with no results")
|
|
679
|
+
result = await tool.run(context=None, query="query with no results")
|
|
679
680
|
|
|
680
681
|
# Should return empty string for no results
|
|
681
682
|
assert result == ""
|
|
682
683
|
|
|
684
|
+
async def test_rag_tool_run_with_context_is_accepted(
|
|
685
|
+
self, mock_rag_config, mock_project
|
|
686
|
+
):
|
|
687
|
+
"""Ensure RagTool.run accepts and works when a ToolCallContext is provided."""
|
|
688
|
+
mock_rag_config.parent_project.return_value = mock_project
|
|
689
|
+
|
|
690
|
+
# Mock search results
|
|
691
|
+
search_results = [
|
|
692
|
+
SearchResult(
|
|
693
|
+
document_id="doc_ctx",
|
|
694
|
+
chunk_idx=3,
|
|
695
|
+
chunk_text="Context ok",
|
|
696
|
+
similarity=0.77,
|
|
697
|
+
)
|
|
698
|
+
]
|
|
699
|
+
|
|
700
|
+
with (
|
|
701
|
+
patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as mock_vs_config_class,
|
|
702
|
+
patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as mock_embed_config_class,
|
|
703
|
+
patch(
|
|
704
|
+
"kiln_ai.tools.rag_tools.embedding_adapter_from_type"
|
|
705
|
+
) as mock_adapter_factory,
|
|
706
|
+
patch(
|
|
707
|
+
"kiln_ai.tools.rag_tools.vector_store_adapter_for_config",
|
|
708
|
+
new_callable=AsyncMock,
|
|
709
|
+
) as mock_vs_adapter_factory,
|
|
710
|
+
):
|
|
711
|
+
# VECTOR type → embedding path taken
|
|
712
|
+
mock_vector_store_config = Mock()
|
|
713
|
+
mock_vector_store_config.store_type = VectorStoreType.LANCE_DB_VECTOR
|
|
714
|
+
mock_vs_config_class.from_id_and_parent_path.return_value = (
|
|
715
|
+
mock_vector_store_config
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
mock_embedding_config = Mock()
|
|
719
|
+
mock_embed_config_class.from_id_and_parent_path.return_value = (
|
|
720
|
+
mock_embedding_config
|
|
721
|
+
)
|
|
722
|
+
|
|
723
|
+
mock_embedding_adapter = AsyncMock()
|
|
724
|
+
mock_embedding_result = Mock()
|
|
725
|
+
mock_embedding_result.embeddings = [Mock(vector=[1.0])]
|
|
726
|
+
mock_embedding_adapter.generate_embeddings.return_value = (
|
|
727
|
+
mock_embedding_result
|
|
728
|
+
)
|
|
729
|
+
mock_adapter_factory.return_value = mock_embedding_adapter
|
|
730
|
+
|
|
731
|
+
mock_vector_store_adapter = AsyncMock()
|
|
732
|
+
mock_vector_store_adapter.search.return_value = search_results
|
|
733
|
+
mock_vs_adapter_factory.return_value = mock_vector_store_adapter
|
|
734
|
+
|
|
735
|
+
tool = RagTool("tool_ctx", mock_rag_config)
|
|
736
|
+
|
|
737
|
+
ctx = ToolCallContext(allow_saving=False)
|
|
738
|
+
result = await tool.run(context=ctx, query="with context")
|
|
739
|
+
|
|
740
|
+
# Works and returns formatted text
|
|
741
|
+
assert result == "[document_id: doc_ctx, chunk_idx: 3]\nContext ok\n\n"
|
|
742
|
+
|
|
743
|
+
# Normal behavior still occurs
|
|
744
|
+
mock_embedding_adapter.generate_embeddings.assert_called_once_with(
|
|
745
|
+
["with context"]
|
|
746
|
+
)
|
|
747
|
+
mock_vector_store_adapter.search.assert_called_once()
|
|
748
|
+
|
|
749
|
+
async def test_rag_tool_run_missing_query_raises(
|
|
750
|
+
self, mock_rag_config, mock_project
|
|
751
|
+
):
|
|
752
|
+
"""Ensure RagTool.run enforces the 'if not query' guard."""
|
|
753
|
+
mock_rag_config.parent_project.return_value = mock_project
|
|
754
|
+
|
|
755
|
+
with (
|
|
756
|
+
patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as mock_vs_config_class,
|
|
757
|
+
):
|
|
758
|
+
mock_vs_config_class.from_id_and_parent_path.return_value = Mock()
|
|
759
|
+
tool = RagTool("tool_err", mock_rag_config)
|
|
760
|
+
|
|
761
|
+
with pytest.raises(KeyError, match="query"):
|
|
762
|
+
await tool.run(context=None)
|
|
763
|
+
|
|
683
764
|
|
|
684
765
|
class TestRagToolNameAndDescription:
|
|
685
766
|
"""Test RagTool name and description functionality with tool_name and tool_description fields."""
|
|
kiln_ai/tools/test_tool_registry.py
CHANGED
|
@@ -3,15 +3,20 @@ from unittest.mock import Mock
|
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
5
|
|
|
6
|
-
from kiln_ai.datamodel.external_tool_server import
|
|
6
|
+
from kiln_ai.datamodel.external_tool_server import (
|
|
7
|
+
ExternalToolServer,
|
|
8
|
+
ToolServerType,
|
|
9
|
+
)
|
|
7
10
|
from kiln_ai.datamodel.project import Project
|
|
8
11
|
from kiln_ai.datamodel.task import Task
|
|
9
12
|
from kiln_ai.datamodel.tool_id import (
|
|
13
|
+
KILN_TASK_TOOL_ID_PREFIX,
|
|
10
14
|
MCP_LOCAL_TOOL_ID_PREFIX,
|
|
11
15
|
MCP_REMOTE_TOOL_ID_PREFIX,
|
|
12
16
|
RAG_TOOL_ID_PREFIX,
|
|
13
17
|
KilnBuiltInToolId,
|
|
14
18
|
_check_tool_id,
|
|
19
|
+
kiln_task_server_id_from_tool_id,
|
|
15
20
|
mcp_server_and_tool_name_from_id,
|
|
16
21
|
)
|
|
17
22
|
from kiln_ai.tools.built_in_tools.math_tools import (
|
|
@@ -20,6 +25,7 @@ from kiln_ai.tools.built_in_tools.math_tools import (
|
|
|
20
25
|
MultiplyTool,
|
|
21
26
|
SubtractTool,
|
|
22
27
|
)
|
|
28
|
+
from kiln_ai.tools.kiln_task_tool import KilnTaskTool
|
|
23
29
|
from kiln_ai.tools.mcp_server_tool import MCPServerTool
|
|
24
30
|
from kiln_ai.tools.tool_registry import tool_from_id
|
|
25
31
|
|
|
@@ -87,7 +93,6 @@ class TestToolRegistry:
|
|
|
87
93
|
type=ToolServerType.remote_mcp,
|
|
88
94
|
properties={
|
|
89
95
|
"server_url": "https://example.com",
|
|
90
|
-
"headers": {},
|
|
91
96
|
},
|
|
92
97
|
)
|
|
93
98
|
|
|
@@ -157,7 +162,6 @@ class TestToolRegistry:
|
|
|
157
162
|
type=ToolServerType.remote_mcp,
|
|
158
163
|
properties={
|
|
159
164
|
"server_url": "https://example.com",
|
|
160
|
-
"headers": {},
|
|
161
165
|
},
|
|
162
166
|
)
|
|
163
167
|
|
|
@@ -361,6 +365,48 @@ class TestToolRegistry:
|
|
|
361
365
|
with pytest.raises(ValueError, match=f"Invalid tool ID: {invalid_id}"):
|
|
362
366
|
_check_tool_id(invalid_id)
|
|
363
367
|
|
|
368
|
+
def test_check_tool_id_valid_kiln_task_tool_id(self):
|
|
369
|
+
"""Test that _check_tool_id accepts valid Kiln task tool IDs."""
|
|
370
|
+
valid_kiln_task_ids = [
|
|
371
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server123",
|
|
372
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}my_task_server",
|
|
373
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}123456789",
|
|
374
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server_with_underscores",
|
|
375
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server-with-dashes",
|
|
376
|
+
]
|
|
377
|
+
|
|
378
|
+
for tool_id in valid_kiln_task_ids:
|
|
379
|
+
result = _check_tool_id(tool_id)
|
|
380
|
+
assert result == tool_id
|
|
381
|
+
|
|
382
|
+
def test_check_tool_id_invalid_kiln_task_tool_id(self):
|
|
383
|
+
"""Test that _check_tool_id rejects invalid Kiln task tool IDs."""
|
|
384
|
+
# These start with the prefix but have wrong format
|
|
385
|
+
invalid_kiln_task_format_ids = [
|
|
386
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}", # Missing server ID
|
|
387
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}::", # Empty server ID
|
|
388
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server::tool", # Too many parts (3 instead of 2)
|
|
389
|
+
f"{KILN_TASK_TOOL_ID_PREFIX}server::tool::extra", # Too many parts (4 instead of 2)
|
|
390
|
+
]
|
|
391
|
+
|
|
392
|
+
for invalid_id in invalid_kiln_task_format_ids:
|
|
393
|
+
with pytest.raises(
|
|
394
|
+
ValueError, match=f"Invalid Kiln task tool ID format: {invalid_id}"
|
|
395
|
+
):
|
|
396
|
+
_check_tool_id(invalid_id)
|
|
397
|
+
|
|
398
|
+
# These don't match the prefix - get generic error
|
|
399
|
+
invalid_generic_ids = [
|
|
400
|
+
"kiln_task:", # Missing last colon (doesn't match full prefix)
|
|
401
|
+
"kiln:task::server", # Wrong prefix format
|
|
402
|
+
"kiln_task_server", # Missing colons
|
|
403
|
+
"task::server", # Missing kiln prefix
|
|
404
|
+
]
|
|
405
|
+
|
|
406
|
+
for invalid_id in invalid_generic_ids:
|
|
407
|
+
with pytest.raises(ValueError, match=f"Invalid tool ID: {invalid_id}"):
|
|
408
|
+
_check_tool_id(invalid_id)
|
|
409
|
+
|
|
364
410
|
def test_mcp_server_and_tool_name_from_id_valid_inputs(self):
|
|
365
411
|
"""Test that mcp_server_and_tool_name_from_id correctly parses valid MCP tool IDs."""
|
|
366
412
|
test_cases = [
|
|
@@ -489,6 +535,64 @@ class TestToolRegistry:
|
|
|
489
535
|
assert server_id == expected_server
|
|
490
536
|
assert tool_name == expected_tool
|
|
491
537
|
|
|
538
|
+
def test_kiln_task_server_id_from_tool_id_valid_inputs(self):
|
|
539
|
+
"""Test that kiln_task_server_id_from_tool_id correctly parses valid Kiln task tool IDs."""
|
|
540
|
+
test_cases = [
|
|
541
|
+
("kiln_task::server123", "server123"),
|
|
542
|
+
("kiln_task::my_task_server", "my_task_server"),
|
|
543
|
+
("kiln_task::123456789", "123456789"),
|
|
544
|
+
("kiln_task::server_with_underscores", "server_with_underscores"),
|
|
545
|
+
("kiln_task::server-with-dashes", "server-with-dashes"),
|
|
546
|
+
("kiln_task::a", "a"), # Minimal valid case
|
|
547
|
+
(
|
|
548
|
+
"kiln_task::very_long_server_name_with_numbers_123",
|
|
549
|
+
"very_long_server_name_with_numbers_123",
|
|
550
|
+
),
|
|
551
|
+
]
|
|
552
|
+
|
|
553
|
+
for tool_id, expected_server_id in test_cases:
|
|
554
|
+
result = kiln_task_server_id_from_tool_id(tool_id)
|
|
555
|
+
assert result == expected_server_id, (
|
|
556
|
+
f"Failed for {tool_id}: expected {expected_server_id}, got {result}"
|
|
557
|
+
)
|
|
558
|
+
|
|
559
|
+
def test_kiln_task_server_id_from_tool_id_invalid_inputs(self):
|
|
560
|
+
"""Test that kiln_task_server_id_from_tool_id raises ValueError for invalid Kiln task tool IDs."""
|
|
561
|
+
invalid_inputs = [
|
|
562
|
+
"kiln_task::", # Empty server ID
|
|
563
|
+
"kiln_task::server::tool", # Too many parts (3 instead of 2)
|
|
564
|
+
"kiln_task::server::tool::extra", # Too many parts (4 instead of 2)
|
|
565
|
+
"invalid::format", # Wrong prefix
|
|
566
|
+
"", # Empty string
|
|
567
|
+
"single_part", # No separators
|
|
568
|
+
"two::parts", # Only 2 parts but wrong prefix
|
|
569
|
+
"kiln_task", # Missing colons
|
|
570
|
+
]
|
|
571
|
+
|
|
572
|
+
for invalid_id in invalid_inputs:
|
|
573
|
+
with pytest.raises(
|
|
574
|
+
ValueError,
|
|
575
|
+
match=r"Invalid Kiln task tool ID format:.*Expected format.*kiln_task::<server_id>",
|
|
576
|
+
):
|
|
577
|
+
kiln_task_server_id_from_tool_id(invalid_id)
|
|
578
|
+
|
|
579
|
+
@pytest.mark.parametrize(
|
|
580
|
+
"tool_id,expected_server_id",
|
|
581
|
+
[
|
|
582
|
+
("kiln_task::test_server", "test_server"),
|
|
583
|
+
("kiln_task::s", "s"),
|
|
584
|
+
("kiln_task::long_server_name_123", "long_server_name_123"),
|
|
585
|
+
("kiln_task::server-with-dashes", "server-with-dashes"),
|
|
586
|
+
("kiln_task::server_with_underscores", "server_with_underscores"),
|
|
587
|
+
],
|
|
588
|
+
)
|
|
589
|
+
def test_kiln_task_server_id_from_tool_id_parametrized(
|
|
590
|
+
self, tool_id, expected_server_id
|
|
591
|
+
):
|
|
592
|
+
"""Parametrized test for kiln_task_server_id_from_tool_id with various valid inputs."""
|
|
593
|
+
server_id = kiln_task_server_id_from_tool_id(tool_id)
|
|
594
|
+
assert server_id == expected_server_id
|
|
595
|
+
|
|
492
596
|
def test_tool_from_id_mcp_missing_task_raises_error(self):
|
|
493
597
|
"""Test that MCP tool ID with missing task raises ValueError."""
|
|
494
598
|
mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}test_server::test_tool"
|
|
@@ -508,7 +612,6 @@ class TestToolRegistry:
|
|
|
508
612
|
description="Test MCP server",
|
|
509
613
|
properties={
|
|
510
614
|
"server_url": "https://example.com",
|
|
511
|
-
"headers": {},
|
|
512
615
|
},
|
|
513
616
|
)
|
|
514
617
|
|
|
@@ -539,7 +642,6 @@ class TestToolRegistry:
|
|
|
539
642
|
description="Different MCP server",
|
|
540
643
|
properties={
|
|
541
644
|
"server_url": "https://example.com",
|
|
542
|
-
"headers": {},
|
|
543
645
|
},
|
|
544
646
|
)
|
|
545
647
|
|
|
@@ -560,3 +662,95 @@ class TestToolRegistry:
|
|
|
560
662
|
match="External tool server not found: nonexistent_server in project ID test_project_id",
|
|
561
663
|
):
|
|
562
664
|
tool_from_id(mcp_tool_id, task=mock_task)
|
|
665
|
+
|
|
666
|
+
def test_tool_from_id_kiln_task_tool_success(self):
|
|
667
|
+
"""Test that tool_from_id works with Kiln task tool IDs."""
|
|
668
|
+
# Create mock external tool server for Kiln task
|
|
669
|
+
mock_server = ExternalToolServer(
|
|
670
|
+
name="test_kiln_task_server",
|
|
671
|
+
type=ToolServerType.kiln_task,
|
|
672
|
+
description="Test Kiln task server",
|
|
673
|
+
properties={
|
|
674
|
+
"name": "test_task_tool",
|
|
675
|
+
"description": "A test task tool",
|
|
676
|
+
"task_id": "test_task_123",
|
|
677
|
+
"run_config_id": "test_config_456",
|
|
678
|
+
"is_archived": False,
|
|
679
|
+
},
|
|
680
|
+
)
|
|
681
|
+
|
|
682
|
+
# Create mock project with the external tool server
|
|
683
|
+
mock_project = Mock(spec=Project)
|
|
684
|
+
mock_project.id = "test_project_id"
|
|
685
|
+
mock_project.external_tool_servers.return_value = [mock_server]
|
|
686
|
+
|
|
687
|
+
# Create mock task with parent project
|
|
688
|
+
mock_task = Mock(spec=Task)
|
|
689
|
+
mock_task.parent_project.return_value = mock_project
|
|
690
|
+
|
|
691
|
+
# Test with Kiln task tool ID
|
|
692
|
+
tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}{mock_server.id}"
|
|
693
|
+
tool = tool_from_id(tool_id, task=mock_task)
|
|
694
|
+
|
|
695
|
+
# Verify the tool is KilnTaskTool
|
|
696
|
+
assert isinstance(tool, KilnTaskTool)
|
|
697
|
+
assert tool._project_id == "test_project_id"
|
|
698
|
+
assert tool._tool_id == tool_id
|
|
699
|
+
assert tool._tool_server_model == mock_server
|
|
700
|
+
|
|
701
|
+
def test_tool_from_id_kiln_task_tool_no_task(self):
|
|
702
|
+
"""Test that Kiln task tool ID without task raises ValueError."""
|
|
703
|
+
tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}test_server"
|
|
704
|
+
with pytest.raises(
|
|
705
|
+
ValueError,
|
|
706
|
+
match=r"Unable to resolve tool from id.*Requires a parent project/task",
|
|
707
|
+
):
|
|
708
|
+
tool_from_id(tool_id, task=None)
|
|
709
|
+
|
|
710
|
+
def test_tool_from_id_kiln_task_tool_no_project(self):
|
|
711
|
+
"""Test that Kiln task tool ID with task but no project raises ValueError."""
|
|
712
|
+
# Create mock task without parent project
|
|
713
|
+
mock_task = Mock(spec=Task)
|
|
714
|
+
mock_task.parent_project.return_value = None
|
|
715
|
+
|
|
716
|
+
tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}test_server"
|
|
717
|
+
|
|
718
|
+
with pytest.raises(
|
|
719
|
+
ValueError,
|
|
720
|
+
match=r"Unable to resolve tool from id.*Requires a parent project/task",
|
|
721
|
+
):
|
|
722
|
+
tool_from_id(tool_id, task=mock_task)
|
|
723
|
+
|
|
724
|
+
def test_tool_from_id_kiln_task_tool_server_not_found(self):
|
|
725
|
+
"""Test that Kiln task tool ID with server not found raises ValueError."""
|
|
726
|
+
# Create mock external tool server with different ID
|
|
727
|
+
mock_server = ExternalToolServer(
|
|
728
|
+
name="different_server",
|
|
729
|
+
type=ToolServerType.kiln_task,
|
|
730
|
+
description="Different Kiln task server",
|
|
731
|
+
properties={
|
|
732
|
+
"name": "different_tool",
|
|
733
|
+
"description": "A different task tool",
|
|
734
|
+
"task_id": "different_task_123",
|
|
735
|
+
"run_config_id": "different_config_456",
|
|
736
|
+
"is_archived": False,
|
|
737
|
+
},
|
|
738
|
+
)
|
|
739
|
+
|
|
740
|
+
# Create mock project with the external tool server
|
|
741
|
+
mock_project = Mock(spec=Project)
|
|
742
|
+
mock_project.id = "test_project_id"
|
|
743
|
+
mock_project.external_tool_servers.return_value = [mock_server]
|
|
744
|
+
|
|
745
|
+
# Create mock task with parent project
|
|
746
|
+
mock_task = Mock(spec=Task)
|
|
747
|
+
mock_task.parent_project.return_value = mock_project
|
|
748
|
+
|
|
749
|
+
# Use a tool ID with a server that doesn't exist in the project
|
|
750
|
+
tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}nonexistent_server"
|
|
751
|
+
|
|
752
|
+
with pytest.raises(
|
|
753
|
+
ValueError,
|
|
754
|
+
match="Kiln Task External tool server not found: nonexistent_server in project ID test_project_id",
|
|
755
|
+
):
|
|
756
|
+
tool_from_id(tool_id, task=mock_task)
|
kiln_ai/tools/tool_registry.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
from kiln_ai.datamodel.rag import RagConfig
|
|
2
2
|
from kiln_ai.datamodel.task import Task
|
|
3
3
|
from kiln_ai.datamodel.tool_id import (
|
|
4
|
+
KILN_TASK_TOOL_ID_PREFIX,
|
|
4
5
|
MCP_LOCAL_TOOL_ID_PREFIX,
|
|
5
6
|
MCP_REMOTE_TOOL_ID_PREFIX,
|
|
6
7
|
RAG_TOOL_ID_PREFIX,
|
|
7
8
|
KilnBuiltInToolId,
|
|
9
|
+
kiln_task_server_id_from_tool_id,
|
|
8
10
|
mcp_server_and_tool_name_from_id,
|
|
9
11
|
rag_config_id_from_id,
|
|
10
12
|
)
|
|
@@ -15,6 +17,7 @@ from kiln_ai.tools.built_in_tools.math_tools import (
|
|
|
15
17
|
MultiplyTool,
|
|
16
18
|
SubtractTool,
|
|
17
19
|
)
|
|
20
|
+
from kiln_ai.tools.kiln_task_tool import KilnTaskTool
|
|
18
21
|
from kiln_ai.tools.mcp_server_tool import MCPServerTool
|
|
19
22
|
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
20
23
|
|
|
@@ -38,31 +41,60 @@ def tool_from_id(tool_id: str, task: Task | None = None) -> KilnToolInterface:
|
|
|
38
41
|
case _:
|
|
39
42
|
raise_exhaustive_enum_error(typed_tool_id)
|
|
40
43
|
|
|
41
|
-
# Check MCP
|
|
42
|
-
|
|
44
|
+
# Check if this looks like an MCP or Kiln Task tool ID that requires a project
|
|
45
|
+
is_mcp_tool = tool_id.startswith(
|
|
46
|
+
(MCP_REMOTE_TOOL_ID_PREFIX, MCP_LOCAL_TOOL_ID_PREFIX)
|
|
47
|
+
)
|
|
48
|
+
is_kiln_task_tool = tool_id.startswith(KILN_TASK_TOOL_ID_PREFIX)
|
|
49
|
+
|
|
50
|
+
if is_mcp_tool or is_kiln_task_tool:
|
|
43
51
|
project = task.parent_project() if task is not None else None
|
|
44
|
-
if project is None:
|
|
52
|
+
if project is None or project.id is None:
|
|
45
53
|
raise ValueError(
|
|
46
54
|
f"Unable to resolve tool from id: {tool_id}. Requires a parent project/task."
|
|
47
55
|
)
|
|
48
56
|
|
|
49
|
-
#
|
|
50
|
-
|
|
57
|
+
# Check MCP Server Tools
|
|
58
|
+
if is_mcp_tool:
|
|
59
|
+
# Get the tool server ID and tool name from the ID
|
|
60
|
+
tool_server_id, tool_name = mcp_server_and_tool_name_from_id(
|
|
61
|
+
tool_id
|
|
62
|
+
) # Fixed function name
|
|
51
63
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
64
|
+
server = next(
|
|
65
|
+
(
|
|
66
|
+
server
|
|
67
|
+
for server in project.external_tool_servers()
|
|
68
|
+
if server.id == tool_server_id
|
|
69
|
+
),
|
|
70
|
+
None,
|
|
71
|
+
)
|
|
72
|
+
if server is None:
|
|
73
|
+
raise ValueError(
|
|
74
|
+
f"External tool server not found: {tool_server_id} in project ID {project.id}"
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
return MCPServerTool(server, tool_name)
|
|
78
|
+
|
|
79
|
+
# Check Kiln Task Tools
|
|
80
|
+
if is_kiln_task_tool:
|
|
81
|
+
server_id = kiln_task_server_id_from_tool_id(tool_id)
|
|
82
|
+
|
|
83
|
+
server = next(
|
|
84
|
+
(
|
|
85
|
+
server
|
|
86
|
+
for server in project.external_tool_servers()
|
|
87
|
+
if server.id == server_id
|
|
88
|
+
),
|
|
89
|
+
None,
|
|
63
90
|
)
|
|
91
|
+
if server is None:
|
|
92
|
+
raise ValueError(
|
|
93
|
+
f"Kiln Task External tool server not found: {server_id} in project ID {project.id}"
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
return KilnTaskTool(project.id, tool_id, server)
|
|
64
97
|
|
|
65
|
-
return MCPServerTool(server, tool_name)
|
|
66
98
|
elif tool_id.startswith(RAG_TOOL_ID_PREFIX):
|
|
67
99
|
project = task.parent_project() if task is not None else None
|
|
68
100
|
if project is None:
|
kiln_ai/utils/filesystem.py
CHANGED
|
@@ -5,10 +5,10 @@ from pathlib import Path
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def open_folder(path: str | Path) -> None:
|
|
8
|
-
|
|
8
|
+
dir = os.path.dirname(path)
|
|
9
9
|
if sys.platform.startswith("darwin"):
|
|
10
|
-
subprocess.run(["open",
|
|
10
|
+
subprocess.run(["open", dir], check=True)
|
|
11
11
|
elif sys.platform.startswith("win"):
|
|
12
|
-
os.startfile(
|
|
12
|
+
os.startfile(dir) # type: ignore[attr-defined]
|
|
13
13
|
else:
|
|
14
|
-
subprocess.run(["xdg-open",
|
|
14
|
+
subprocess.run(["xdg-open", dir], check=True)
|
kiln_ai/utils/open_ai_types.py
CHANGED
|
@@ -17,11 +17,11 @@ from typing import (
|
|
|
17
17
|
)
|
|
18
18
|
|
|
19
19
|
from openai.types.chat import (
|
|
20
|
+
ChatCompletionContentPartTextParam,
|
|
20
21
|
ChatCompletionDeveloperMessageParam,
|
|
21
22
|
ChatCompletionFunctionMessageParam,
|
|
22
23
|
ChatCompletionMessageToolCallParam,
|
|
23
24
|
ChatCompletionSystemMessageParam,
|
|
24
|
-
ChatCompletionToolMessageParam,
|
|
25
25
|
ChatCompletionUserMessageParam,
|
|
26
26
|
)
|
|
27
27
|
from openai.types.chat.chat_completion_assistant_message_param import (
|
|
@@ -84,11 +84,28 @@ class ChatCompletionAssistantMessageParamWrapper(TypedDict, total=False):
|
|
|
84
84
|
"""The tool calls generated by the model, such as function calls."""
|
|
85
85
|
|
|
86
86
|
|
|
87
|
+
class ChatCompletionToolMessageParamWrapper(TypedDict, total=False):
|
|
88
|
+
content: Required[Union[str, Iterable[ChatCompletionContentPartTextParam]]]
|
|
89
|
+
"""The contents of the tool message."""
|
|
90
|
+
|
|
91
|
+
role: Required[Literal["tool"]]
|
|
92
|
+
"""The role of the messages author, in this case `tool`."""
|
|
93
|
+
|
|
94
|
+
tool_call_id: Required[str]
|
|
95
|
+
"""Tool call that this message is responding to."""
|
|
96
|
+
|
|
97
|
+
kiln_task_tool_data: Optional[str]
|
|
98
|
+
"""The data for the Kiln task tool that this message is responding to.
|
|
99
|
+
|
|
100
|
+
Formatted as `<project_id>:::<tool_id>:::<task_id>:::<run_id>`
|
|
101
|
+
"""
|
|
102
|
+
|
|
103
|
+
|
|
87
104
|
ChatCompletionMessageParam: TypeAlias = Union[
|
|
88
105
|
ChatCompletionDeveloperMessageParam,
|
|
89
106
|
ChatCompletionSystemMessageParam,
|
|
90
107
|
ChatCompletionUserMessageParam,
|
|
91
108
|
ChatCompletionAssistantMessageParamWrapper,
|
|
92
|
-
|
|
109
|
+
ChatCompletionToolMessageParamWrapper,
|
|
93
110
|
ChatCompletionFunctionMessageParam,
|
|
94
111
|
]
|
kiln_ai/utils/pdf_utils.py
CHANGED
|
@@ -8,6 +8,7 @@ from contextlib import asynccontextmanager
|
|
|
8
8
|
from pathlib import Path
|
|
9
9
|
from typing import AsyncGenerator
|
|
10
10
|
|
|
11
|
+
import pypdfium2
|
|
11
12
|
from pypdf import PdfReader, PdfWriter
|
|
12
13
|
|
|
13
14
|
|
|
@@ -36,3 +37,23 @@ async def split_pdf_into_pages(pdf_path: Path) -> AsyncGenerator[list[Path], Non
|
|
|
36
37
|
page_paths.append(page_path)
|
|
37
38
|
|
|
38
39
|
yield page_paths
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
async def convert_pdf_to_images(pdf_path: Path, output_dir: Path) -> list[Path]:
|
|
43
|
+
image_paths = []
|
|
44
|
+
|
|
45
|
+
# note: doing this in a thread causes a segfault - but this is slow and blocking
|
|
46
|
+
# so we should try to find a better way
|
|
47
|
+
pdf = pypdfium2.PdfDocument(pdf_path)
|
|
48
|
+
try:
|
|
49
|
+
for idx, page in enumerate(pdf):
|
|
50
|
+
await asyncio.sleep(0)
|
|
51
|
+
# scale=2 is legible for ~A4 pages (research papers, etc.) - lower than this is blurry
|
|
52
|
+
bitmap = page.render(scale=2).to_pil()
|
|
53
|
+
target_path = output_dir / f"img-{pdf_path.name}-{idx}.png"
|
|
54
|
+
bitmap.save(target_path)
|
|
55
|
+
image_paths.append(target_path)
|
|
56
|
+
|
|
57
|
+
return image_paths
|
|
58
|
+
finally:
|
|
59
|
+
pdf.close()
|