kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from kiln_ai.datamodel.tool_id import KilnBuiltInToolId
|
|
4
|
+
from kiln_ai.tools.base_tool import KilnTool
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AddTool(KilnTool):
    """
    Built-in tool that returns the sum of two numbers.

    Also serves as a minimal, concrete example of subclassing KilnTool.
    """

    def __init__(self):
        operand_a = {"type": "number", "description": "The first number to add"}
        operand_b = {"type": "number", "description": "The second number to add"}

        super().__init__(
            tool_id=KilnBuiltInToolId.ADD_NUMBERS,
            name="add",
            description="Add two numbers together and return the result",
            # JSON schema for the tool arguments; both operands are mandatory.
            parameters_schema={
                "type": "object",
                "properties": {"a": operand_a, "b": operand_b},
                "required": ["a", "b"],
            },
        )

    async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
        """Return ``a + b`` rendered with ``str`` (e.g. ``"3"`` or ``"6.0"``)."""
        total = a + b
        return str(total)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SubtractTool(KilnTool):
    """
    Built-in tool that subtracts the second number from the first.
    """

    def __init__(self):
        minuend = {"type": "number", "description": "The first number (minuend)"}
        subtrahend = {
            "type": "number",
            "description": "The second number to subtract (subtrahend)",
        }

        super().__init__(
            tool_id=KilnBuiltInToolId.SUBTRACT_NUMBERS,
            name="subtract",
            description="Subtract the second number from the first number and return the result",
            # JSON schema for the tool arguments; both operands are mandatory.
            parameters_schema={
                "type": "object",
                "properties": {"a": minuend, "b": subtrahend},
                "required": ["a", "b"],
            },
        )

    async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
        """Return ``a - b`` rendered with ``str`` (e.g. ``"2"`` or ``"3.0"``)."""
        difference = a - b
        return str(difference)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class MultiplyTool(KilnTool):
    """
    Built-in tool that returns the product of two numbers.
    """

    def __init__(self):
        operand_a = {"type": "number", "description": "The first number to multiply"}
        operand_b = {"type": "number", "description": "The second number to multiply"}

        super().__init__(
            tool_id=KilnBuiltInToolId.MULTIPLY_NUMBERS,
            name="multiply",
            description="Multiply two numbers together and return the result",
            # JSON schema for the tool arguments; both operands are mandatory.
            parameters_schema={
                "type": "object",
                "properties": {"a": operand_a, "b": operand_b},
                "required": ["a", "b"],
            },
        )

    async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
        """Return ``a * b`` rendered with ``str`` (e.g. ``"6"`` or ``"10.0"``)."""
        product = a * b
        return str(product)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class DivideTool(KilnTool):
    """
    Built-in tool that divides the first number by the second.

    Uses true division (``/``). A zero divisor is rejected with an explicit
    ZeroDivisionError carrying a user-facing message.
    """

    def __init__(self):
        dividend = {
            "type": "number",
            "description": "The dividend (number to be divided)",
        }
        divisor = {
            "type": "number",
            "description": "The divisor (number to divide by)",
        }

        super().__init__(
            tool_id=KilnBuiltInToolId.DIVIDE_NUMBERS,
            name="divide",
            description="Divide the first number by the second number and return the result",
            # JSON schema for the tool arguments; both operands are mandatory.
            parameters_schema={
                "type": "object",
                "properties": {"a": dividend, "b": divisor},
                "required": ["a", "b"],
            },
        )

    async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
        """
        Return ``a / b`` rendered with ``str``.

        Raises:
            ZeroDivisionError: if ``b`` compares equal to zero.
        """
        if b != 0:
            return str(a / b)
        raise ZeroDivisionError("Cannot divide by zero")
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from kiln_ai.datamodel.tool_id import KilnBuiltInToolId
|
|
4
|
+
from kiln_ai.tools.built_in_tools.math_tools import (
|
|
5
|
+
AddTool,
|
|
6
|
+
DivideTool,
|
|
7
|
+
MultiplyTool,
|
|
8
|
+
SubtractTool,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestAddTool:
    """Tests covering the AddTool built-in tool."""

    async def test_init(self):
        """AddTool exposes the expected id, name and description."""
        add_tool = AddTool()
        expected_description = "Add two numbers together and return the result"

        assert await add_tool.id() == KilnBuiltInToolId.ADD_NUMBERS
        assert await add_tool.name() == "add"
        assert await add_tool.description() == expected_description

    async def test_toolcall_definition(self):
        """The OpenAI-style definition carries name, description and both params."""
        definition = await AddTool().toolcall_definition()
        function_spec = definition["function"]
        params = function_spec["parameters"]

        assert definition["type"] == "function"
        assert function_spec["name"] == "add"
        assert (
            function_spec["description"]
            == "Add two numbers together and return the result"
        )
        assert "properties" in params
        for param_name in ("a", "b"):
            assert param_name in params["properties"]

    @pytest.mark.parametrize(
        "a, b, expected",
        [
            (1, 2, "3"),
            (0, 0, "0"),
            (-1, 1, "0"),
            (2.5, 3.5, "6.0"),
            (-2.5, -3.5, "-6.0"),
            (100, 200, "300"),
        ],
    )
    async def test_run_various_inputs(self, a, b, expected):
        """run() returns the sum as a string for ints, floats and negatives."""
        outcome = await AddTool().run(a=a, b=b)
        assert outcome == expected
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class TestSubtractTool:
    """Tests covering the SubtractTool built-in tool."""

    async def test_init(self):
        """SubtractTool exposes the expected id, name and description."""
        subtract_tool = SubtractTool()
        expected_description = (
            "Subtract the second number from the first number and return the result"
        )

        assert await subtract_tool.id() == KilnBuiltInToolId.SUBTRACT_NUMBERS
        assert await subtract_tool.name() == "subtract"
        assert await subtract_tool.description() == expected_description

    async def test_toolcall_definition(self):
        """The OpenAI-style definition carries name, description and both params."""
        definition = await SubtractTool().toolcall_definition()
        function_spec = definition["function"]
        params = function_spec["parameters"]

        assert definition["type"] == "function"
        assert function_spec["name"] == "subtract"
        assert (
            function_spec["description"]
            == "Subtract the second number from the first number and return the result"
        )
        assert "properties" in params
        for param_name in ("a", "b"):
            assert param_name in params["properties"]

    @pytest.mark.parametrize(
        "a, b, expected",
        [
            (5, 3, "2"),
            (0, 0, "0"),
            (1, -1, "2"),
            (5.5, 2.5, "3.0"),
            (-2.5, -3.5, "1.0"),
            (100, 200, "-100"),
        ],
    )
    async def test_run_various_inputs(self, a, b, expected):
        """run() returns a - b as a string, including negative results."""
        outcome = await SubtractTool().run(a=a, b=b)
        assert outcome == expected
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class TestMultiplyTool:
    """Tests covering the MultiplyTool built-in tool."""

    async def test_init(self):
        """MultiplyTool exposes the expected id, name and description."""
        multiply_tool = MultiplyTool()
        expected_description = "Multiply two numbers together and return the result"

        assert await multiply_tool.id() == KilnBuiltInToolId.MULTIPLY_NUMBERS
        assert await multiply_tool.name() == "multiply"
        assert await multiply_tool.description() == expected_description

    async def test_toolcall_definition(self):
        """The OpenAI-style definition carries name, description and both params."""
        definition = await MultiplyTool().toolcall_definition()
        function_spec = definition["function"]
        params = function_spec["parameters"]

        assert definition["type"] == "function"
        assert function_spec["name"] == "multiply"
        assert (
            function_spec["description"]
            == "Multiply two numbers together and return the result"
        )
        assert "properties" in params
        for param_name in ("a", "b"):
            assert param_name in params["properties"]

    @pytest.mark.parametrize(
        "a, b, expected",
        [
            (2, 3, "6"),
            (0, 5, "0"),
            (-2, 3, "-6"),
            (2.5, 4, "10.0"),
            (-2.5, -4, "10.0"),
            (1, 1, "1"),
        ],
    )
    async def test_run_various_inputs(self, a, b, expected):
        """run() returns the product as a string for ints, floats and signs."""
        outcome = await MultiplyTool().run(a=a, b=b)
        assert outcome == expected
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class TestDivideTool:
    """Tests covering the DivideTool built-in tool, including zero divisors."""

    async def test_init(self):
        """Test DivideTool initialization."""
        tool = DivideTool()
        assert await tool.id() == KilnBuiltInToolId.DIVIDE_NUMBERS
        assert await tool.name() == "divide"
        assert (
            await tool.description()
            == "Divide the first number by the second number and return the result"
        )

    async def test_toolcall_definition(self):
        """Test DivideTool toolcall definition structure."""
        tool = DivideTool()
        definition = await tool.toolcall_definition()

        assert definition["type"] == "function"
        assert definition["function"]["name"] == "divide"
        assert (
            definition["function"]["description"]
            == "Divide the first number by the second number and return the result"
        )
        assert "properties" in definition["function"]["parameters"]
        assert "a" in definition["function"]["parameters"]["properties"]
        assert "b" in definition["function"]["parameters"]["properties"]

    @pytest.mark.parametrize(
        "a, b, expected",
        [
            (6, 2, "3.0"),
            (1, 1, "1.0"),
            (-6, 2, "-3.0"),
            (7.5, 2.5, "3.0"),
            (-10, -2, "5.0"),
            (0, 5, "0.0"),
            # Non-integral quotient: true division, not floor division.
            (1, 4, "0.25"),
        ],
    )
    async def test_run_various_inputs(self, a, b, expected):
        """Test DivideTool run method with various inputs."""
        tool = DivideTool()
        result = await tool.run(a=a, b=b)
        assert result == expected

    async def test_divide_by_zero(self):
        """Test that division by zero raises ZeroDivisionError."""
        tool = DivideTool()
        with pytest.raises(ZeroDivisionError, match="Cannot divide by zero"):
            await tool.run(a=5, b=0)

    async def test_divide_zero_by_zero(self):
        """Test that zero divided by zero raises ZeroDivisionError."""
        tool = DivideTool()
        with pytest.raises(ZeroDivisionError, match="Cannot divide by zero"):
            await tool.run(a=0, b=0)

    @pytest.mark.parametrize("zero", [0.0, -0.0])
    async def test_divide_by_float_zero(self, zero):
        """Float zero divisors (including -0.0) hit the same guard as integer 0."""
        tool = DivideTool()
        with pytest.raises(ZeroDivisionError, match="Cannot divide by zero"):
            await tool.run(a=5, b=zero)
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from typing import Any, Dict
|
|
2
|
+
|
|
3
|
+
from mcp.types import CallToolResult, TextContent
|
|
4
|
+
from mcp.types import Tool as MCPTool
|
|
5
|
+
|
|
6
|
+
from kiln_ai.datamodel.external_tool_server import ExternalToolServer
|
|
7
|
+
from kiln_ai.datamodel.tool_id import MCP_REMOTE_TOOL_ID_PREFIX, ToolId
|
|
8
|
+
from kiln_ai.tools.base_tool import KilnToolInterface
|
|
9
|
+
from kiln_ai.tools.mcp_session_manager import MCPSessionManager
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MCPServerTool(KilnToolInterface):
    """
    A Kiln tool that proxies a single named tool exposed by an MCP server.

    The tool's description and input schema are not known at construction
    time. They are fetched from the server on first use (see
    `_load_tool_properties`) and cached on this instance.
    """

    def __init__(self, data_model: ExternalToolServer, name: str):
        # ID layout: MCP_REMOTE_TOOL_ID_PREFIX + "<server id>::<tool name>"
        self._tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}{data_model.id}::{name}"
        self._tool_server_model = data_model
        self._name = name
        # Set by _load_tool_properties() together with _description and
        # _parameters_schema; None means the properties are not loaded yet.
        self._tool: MCPTool | None = None

    async def id(self) -> ToolId:
        """Return this tool's Kiln tool ID."""
        return self._tool_id

    async def name(self) -> str:
        """Return the tool name as exposed by the MCP server."""
        return self._name

    async def description(self) -> str:
        """Return the server-provided description, or "N/A" if it has none."""
        await self._load_tool_properties()
        return self._description

    async def toolcall_definition(self) -> Dict[str, Any]:
        """Generate OpenAI-compatible tool definition."""
        await self._load_tool_properties()
        return {
            "type": "function",
            "function": {
                "name": await self.name(),
                "description": await self.description(),
                "parameters": self._parameters_schema,
            },
        }

    async def run(self, **kwargs) -> Any:
        """
        Call the MCP tool with `kwargs` as its arguments and return its text output.

        Raises:
            ValueError: if the server flags the result as an error, returns no
                content, returns a non-text first block, or returns more than
                one content block.
        """
        result = await self._call_tool(**kwargs)

        if result.isError:
            raise ValueError(
                f"Tool {await self.name()} returned an error: {result.content}"
            )

        if not result.content:
            raise ValueError("Tool returned no content")

        # raise error if the first block is not a text block
        if not isinstance(result.content[0], TextContent):
            raise ValueError("First block must be a text block")

        # raise error if there is more than one content block
        if len(result.content) > 1:
            raise ValueError("Tool returned multiple content blocks, expected one")

        return result.content[0].text

    # Call the MCP Tool
    async def _call_tool(self, **kwargs) -> CallToolResult:
        """Open a session to the tool server and invoke this tool with `kwargs`."""
        async with MCPSessionManager.shared().mcp_client(
            self._tool_server_model
        ) as session:
            result = await session.call_tool(
                name=await self.name(),
                arguments=kwargs,
            )
        return result

    async def _load_tool_properties(self):
        """Fetch and cache description/input schema from the server (once)."""
        if self._tool is not None:
            return

        tool = await self._get_tool(self._name)
        self._tool = tool
        self._description = tool.description or "N/A"
        # Fall back to an empty object schema if the server supplies none.
        self._parameters_schema = tool.inputSchema or {
            "type": "object",
            "properties": {},
        }

    # Get the MCP Tool from the server
    async def _get_tool(self, tool_name: str) -> MCPTool:
        """
        Look up `tool_name` in the server's tool list.

        `tools/list` results are paginated: each page may carry a `nextCursor`
        pointing at the next page. Pages are followed until the tool is found
        or there are no more pages.

        Raises:
            ValueError: if no page lists a tool named `tool_name`.
        """
        async with MCPSessionManager.shared().mcp_client(
            self._tool_server_model
        ) as session:
            page = await session.list_tools()
            seen_cursors: set[str] = set()
            while True:
                tool = next((t for t in page.tools if t.name == tool_name), None)
                if tool is not None:
                    return tool

                cursor = getattr(page, "nextCursor", None)
                # Stop when there is no further page. Also stop if the cursor is
                # not a string, or was already requested, so a misbehaving
                # server cannot make this loop forever.
                if (
                    not isinstance(cursor, str)
                    or not cursor
                    or cursor in seen_cursors
                ):
                    break
                seen_cursors.add(cursor)
                page = await session.list_tools(cursor=cursor)

        raise ValueError(f"Tool {tool_name} not found")
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import subprocess
|
|
4
|
+
import sys
|
|
5
|
+
from contextlib import asynccontextmanager
|
|
6
|
+
from datetime import timedelta
|
|
7
|
+
from typing import AsyncGenerator
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
from mcp import StdioServerParameters
|
|
11
|
+
from mcp.client.session import ClientSession
|
|
12
|
+
from mcp.client.stdio import stdio_client
|
|
13
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
14
|
+
from mcp.shared.exceptions import McpError
|
|
15
|
+
|
|
16
|
+
from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType
|
|
17
|
+
from kiln_ai.utils.config import Config
|
|
18
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
19
|
+
|
|
20
|
+
# Logger named after this module (its dotted import path).
logger = logging.getLogger(name=__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MCPSessionManager:
|
|
24
|
+
"""
|
|
25
|
+
This class is a singleton that manages MCP sessions for remote MCP servers.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
_shared_instance = None
|
|
29
|
+
|
|
30
|
+
def __init__(self):
|
|
31
|
+
self._shell_path = None
|
|
32
|
+
|
|
33
|
+
@classmethod
|
|
34
|
+
def shared(cls):
|
|
35
|
+
if cls._shared_instance is None:
|
|
36
|
+
cls._shared_instance = cls()
|
|
37
|
+
return cls._shared_instance
|
|
38
|
+
|
|
39
|
+
@asynccontextmanager
|
|
40
|
+
async def mcp_client(
|
|
41
|
+
self,
|
|
42
|
+
tool_server: ExternalToolServer,
|
|
43
|
+
) -> AsyncGenerator[
|
|
44
|
+
ClientSession,
|
|
45
|
+
None,
|
|
46
|
+
]:
|
|
47
|
+
match tool_server.type:
|
|
48
|
+
case ToolServerType.remote_mcp:
|
|
49
|
+
async with self._create_remote_mcp_session(tool_server) as session:
|
|
50
|
+
yield session
|
|
51
|
+
case ToolServerType.local_mcp:
|
|
52
|
+
async with self._create_local_mcp_session(tool_server) as session:
|
|
53
|
+
yield session
|
|
54
|
+
case _:
|
|
55
|
+
raise_exhaustive_enum_error(tool_server.type)
|
|
56
|
+
|
|
57
|
+
def _extract_first_exception(
|
|
58
|
+
self, exception: Exception, target_type: type | tuple[type, ...]
|
|
59
|
+
) -> Exception | None:
|
|
60
|
+
"""
|
|
61
|
+
Extract first relevant exception from ExceptionGroup or handle direct exceptions
|
|
62
|
+
"""
|
|
63
|
+
# Check if the exception itself is of the target type
|
|
64
|
+
if isinstance(exception, target_type):
|
|
65
|
+
return exception
|
|
66
|
+
|
|
67
|
+
# Handle ExceptionGroup
|
|
68
|
+
if hasattr(exception, "exceptions"):
|
|
69
|
+
exceptions_attr = getattr(exception, "exceptions", None)
|
|
70
|
+
if exceptions_attr:
|
|
71
|
+
for nested_exc in exceptions_attr:
|
|
72
|
+
result = self._extract_first_exception(nested_exc, target_type)
|
|
73
|
+
if result:
|
|
74
|
+
return result
|
|
75
|
+
|
|
76
|
+
return None
|
|
77
|
+
|
|
78
|
+
@asynccontextmanager
|
|
79
|
+
async def _create_remote_mcp_session(
|
|
80
|
+
self,
|
|
81
|
+
tool_server: ExternalToolServer,
|
|
82
|
+
) -> AsyncGenerator[ClientSession, None]:
|
|
83
|
+
"""
|
|
84
|
+
Create a session for a remote MCP server.
|
|
85
|
+
"""
|
|
86
|
+
# Make sure the server_url is set
|
|
87
|
+
server_url = tool_server.properties.get("server_url")
|
|
88
|
+
if not server_url:
|
|
89
|
+
raise ValueError("server_url is required")
|
|
90
|
+
|
|
91
|
+
# Make a copy of the headers to avoid modifying the original object
|
|
92
|
+
headers = tool_server.properties.get("headers", {}).copy()
|
|
93
|
+
|
|
94
|
+
# Retrieve secret headers from configuration and merge with regular headers
|
|
95
|
+
secret_headers, _ = tool_server.retrieve_secrets()
|
|
96
|
+
headers.update(secret_headers)
|
|
97
|
+
|
|
98
|
+
try:
|
|
99
|
+
async with streamablehttp_client(server_url, headers=headers) as (
|
|
100
|
+
read_stream,
|
|
101
|
+
write_stream,
|
|
102
|
+
_,
|
|
103
|
+
):
|
|
104
|
+
# Create a session using the client streams
|
|
105
|
+
async with ClientSession(read_stream, write_stream) as session:
|
|
106
|
+
await session.initialize()
|
|
107
|
+
yield session
|
|
108
|
+
except Exception as e:
|
|
109
|
+
# Handle HTTP errors with user-friendly messages
|
|
110
|
+
|
|
111
|
+
# Check for HTTPStatusError
|
|
112
|
+
http_error = self._extract_first_exception(e, httpx.HTTPStatusError)
|
|
113
|
+
if http_error and isinstance(http_error, httpx.HTTPStatusError):
|
|
114
|
+
raise ValueError(
|
|
115
|
+
f"The MCP server rejected the request. "
|
|
116
|
+
f"Status {http_error.response.status_code}. "
|
|
117
|
+
f"Response from server:\n{http_error.response.reason_phrase}"
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
# Check for connection errors
|
|
121
|
+
connection_error_types = (ConnectionError, OSError, httpx.RequestError)
|
|
122
|
+
connection_error = self._extract_first_exception(e, connection_error_types)
|
|
123
|
+
if connection_error and isinstance(
|
|
124
|
+
connection_error, connection_error_types
|
|
125
|
+
):
|
|
126
|
+
raise RuntimeError(
|
|
127
|
+
f"Unable to connect to MCP server. Please verify the configurations are correct, the server is running, and your network connection is working. Original error: {connection_error}"
|
|
128
|
+
) from e
|
|
129
|
+
|
|
130
|
+
# If no known error types found, re-raise the original exception
|
|
131
|
+
raise RuntimeError(
|
|
132
|
+
f"Failed to connect to the MCP Server. Check the server's docs for troubleshooting. Original error: {e}"
|
|
133
|
+
) from e
|
|
134
|
+
|
|
135
|
+
@asynccontextmanager
|
|
136
|
+
async def _create_local_mcp_session(
|
|
137
|
+
self,
|
|
138
|
+
tool_server: ExternalToolServer,
|
|
139
|
+
) -> AsyncGenerator[ClientSession, None]:
|
|
140
|
+
"""
|
|
141
|
+
Create a session for a local MCP server.
|
|
142
|
+
"""
|
|
143
|
+
command = tool_server.properties.get("command")
|
|
144
|
+
if not command:
|
|
145
|
+
raise ValueError(
|
|
146
|
+
"Attempted to start local MCP server, but no command was provided"
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
args = tool_server.properties.get("args", [])
|
|
150
|
+
if not isinstance(args, list):
|
|
151
|
+
raise ValueError(
|
|
152
|
+
"Attempted to start local MCP server, but args is not a list of strings"
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
# Make a copy of the env_vars to avoid modifying the original object
|
|
156
|
+
env_vars = tool_server.properties.get("env_vars", {}).copy()
|
|
157
|
+
|
|
158
|
+
# Retrieve secret environment variables from configuration and merge with regular env_vars
|
|
159
|
+
secret_env_vars, _ = tool_server.retrieve_secrets()
|
|
160
|
+
env_vars.update(secret_env_vars)
|
|
161
|
+
|
|
162
|
+
# Set PATH, only if not explicitly set during MCP tool setup
|
|
163
|
+
if "PATH" not in env_vars:
|
|
164
|
+
env_vars["PATH"] = self._get_path()
|
|
165
|
+
|
|
166
|
+
# Set the server parameters
|
|
167
|
+
server_params = StdioServerParameters(
|
|
168
|
+
command=command,
|
|
169
|
+
args=args,
|
|
170
|
+
env=env_vars,
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
try:
|
|
174
|
+
async with stdio_client(server_params) as (read, write):
|
|
175
|
+
async with ClientSession(
|
|
176
|
+
read, write, read_timeout_seconds=timedelta(seconds=8)
|
|
177
|
+
) as session:
|
|
178
|
+
await session.initialize()
|
|
179
|
+
yield session
|
|
180
|
+
except Exception as e:
|
|
181
|
+
# Check for MCP errors. Things like wrong arguments would fall here.
|
|
182
|
+
mcp_error = self._extract_first_exception(e, McpError)
|
|
183
|
+
if mcp_error and isinstance(mcp_error, McpError):
|
|
184
|
+
self._raise_local_mcp_error(mcp_error)
|
|
185
|
+
|
|
186
|
+
# Re-raise the original error but with a friendlier message
|
|
187
|
+
self._raise_local_mcp_error(e)
|
|
188
|
+
|
|
189
|
+
def _raise_local_mcp_error(self, e: Exception):
|
|
190
|
+
"""
|
|
191
|
+
Raise a ValueError with a friendlier message for local MCP errors.
|
|
192
|
+
"""
|
|
193
|
+
raise RuntimeError(
|
|
194
|
+
f"MCP server failed to start. Please verify your command, arguments, and environment variables, and consult the server's documentation for the correct setup. Original error: {e}"
|
|
195
|
+
) from e
|
|
196
|
+
|
|
197
|
+
def _get_path(self) -> str:
|
|
198
|
+
"""
|
|
199
|
+
Builds a PATH environment variable. From environment, Kiln Config, and loading rc files.
|
|
200
|
+
"""
|
|
201
|
+
|
|
202
|
+
# If the user sets a custom MCP path, use only it. This also functions as a way to disable the shell path loading.
|
|
203
|
+
custom_mcp_path = Config.shared().get_value("custom_mcp_path")
|
|
204
|
+
if custom_mcp_path is not None:
|
|
205
|
+
return custom_mcp_path
|
|
206
|
+
else:
|
|
207
|
+
return self.get_shell_path()
|
|
208
|
+
|
|
209
|
+
def get_shell_path(self) -> str:
|
|
210
|
+
# Windows has a global PATH, so we don't need to source rc files
|
|
211
|
+
if sys.platform in ("win32", "Windows"):
|
|
212
|
+
return os.environ.get("PATH", "")
|
|
213
|
+
|
|
214
|
+
# Cache
|
|
215
|
+
if self._shell_path is not None:
|
|
216
|
+
return self._shell_path
|
|
217
|
+
|
|
218
|
+
# Attempt to get shell PATH from preferred shell, which will source rc files, run scripts like `brew shellenv`, etc.
|
|
219
|
+
shell_path = None
|
|
220
|
+
try:
|
|
221
|
+
shell = os.environ.get("SHELL", "/bin/bash")
|
|
222
|
+
# Use -l (login) flag to source ~/.profile, ~/.bash_profile, ~/.zprofile, etc.
|
|
223
|
+
result = subprocess.run(
|
|
224
|
+
[shell, "-l", "-c", "echo $PATH"],
|
|
225
|
+
capture_output=True,
|
|
226
|
+
text=True,
|
|
227
|
+
timeout=3,
|
|
228
|
+
)
|
|
229
|
+
if result.returncode == 0:
|
|
230
|
+
shell_path = result.stdout.strip()
|
|
231
|
+
except (subprocess.TimeoutExpired, subprocess.SubprocessError, Exception) as e:
|
|
232
|
+
logger.error(f"Shell path exception details: {e}")
|
|
233
|
+
|
|
234
|
+
# Fallback to environment PATH
|
|
235
|
+
if shell_path is None:
|
|
236
|
+
logger.error(
|
|
237
|
+
"Error getting shell PATH. You may not be able to find MCP server commands like 'npx'. You can set a custom MCP path in the Kiln config file. See docs for details."
|
|
238
|
+
)
|
|
239
|
+
shell_path = os.environ.get("PATH", "")
|
|
240
|
+
|
|
241
|
+
self._shell_path = shell_path
|
|
242
|
+
return shell_path
|
|
243
|
+
|
|
244
|
+
def clear_shell_path_cache(self):
|
|
245
|
+
"""Clear the cached shell path. Typically used when adding a new tool, which might have just been installed."""
|
|
246
|
+
self._shell_path = None
|