kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import TYPE_CHECKING, Union
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, Field, model_validator
|
|
5
|
+
|
|
6
|
+
from kiln_ai.datamodel.basemodel import FilenameString, KilnParentedModel
|
|
7
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
8
|
+
from kiln_ai.utils.validation import (
|
|
9
|
+
validate_return_dict_prop,
|
|
10
|
+
validate_return_dict_prop_optional,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from kiln_ai.datamodel.project import Project
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class VectorStoreType(str, Enum):
    """Identifiers for the supported vector store backends.

    Members are also ``str`` instances, so each one compares equal to its
    identifier string (for example ``"lancedb_fts"``).
    """

    LANCE_DB_FTS = "lancedb_fts"
    LANCE_DB_HYBRID = "lancedb_hybrid"
    LANCE_DB_VECTOR = "lancedb_vector"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LanceDBConfigBaseProperties(BaseModel):
    """Typed view of the properties used by the LanceDB vector store types.

    All fields are required except ``nprobes``, which defaults to ``None``.
    """

    similarity_top_k: int = Field(
        description="The number of results to return from the vector store.",
    )
    overfetch_factor: int = Field(
        description="The overfetch factor to use for the vector search.",
    )
    vector_column_name: str = Field(
        description="The name of the vector column in the vector store.",
    )
    text_key: str = Field(
        description="The name of the text column in the vector store.",
    )
    doc_id_key: str = Field(
        description="The name of the document id column in the vector store.",
    )
    nprobes: int | None = Field(
        description="The number of probes to use for the vector search.",
        default=None,
    )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class VectorStoreConfig(KilnParentedModel):
|
|
46
|
+
name: FilenameString = Field(
|
|
47
|
+
description="A name for your own reference to identify the vector store config.",
|
|
48
|
+
)
|
|
49
|
+
description: str | None = Field(
|
|
50
|
+
description="A description for your own reference.",
|
|
51
|
+
default=None,
|
|
52
|
+
)
|
|
53
|
+
store_type: VectorStoreType = Field(
|
|
54
|
+
description="The type of vector store to use.",
|
|
55
|
+
)
|
|
56
|
+
properties: dict[str, str | int | float | None] = Field(
|
|
57
|
+
description="The properties of the vector store config, specific to the selected store_type.",
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
@model_validator(mode="after")
|
|
61
|
+
def validate_properties(self):
|
|
62
|
+
match self.store_type:
|
|
63
|
+
case (
|
|
64
|
+
VectorStoreType.LANCE_DB_FTS
|
|
65
|
+
| VectorStoreType.LANCE_DB_HYBRID
|
|
66
|
+
| VectorStoreType.LANCE_DB_VECTOR
|
|
67
|
+
):
|
|
68
|
+
return self.validate_lancedb_properties(self.store_type)
|
|
69
|
+
case _:
|
|
70
|
+
raise_exhaustive_enum_error(self.store_type)
|
|
71
|
+
|
|
72
|
+
def validate_lancedb_properties(self, store_type: VectorStoreType):
|
|
73
|
+
err_msg_prefix = f"LanceDB vector store configs properties for {store_type}:"
|
|
74
|
+
validate_return_dict_prop(
|
|
75
|
+
self.properties, "similarity_top_k", int, err_msg_prefix
|
|
76
|
+
)
|
|
77
|
+
validate_return_dict_prop(
|
|
78
|
+
self.properties, "overfetch_factor", int, err_msg_prefix
|
|
79
|
+
)
|
|
80
|
+
validate_return_dict_prop(
|
|
81
|
+
self.properties, "vector_column_name", str, err_msg_prefix
|
|
82
|
+
)
|
|
83
|
+
validate_return_dict_prop(self.properties, "text_key", str, err_msg_prefix)
|
|
84
|
+
validate_return_dict_prop(self.properties, "doc_id_key", str, err_msg_prefix)
|
|
85
|
+
|
|
86
|
+
# nprobes is only used for vector and hybrid queries
|
|
87
|
+
if (
|
|
88
|
+
store_type == VectorStoreType.LANCE_DB_VECTOR
|
|
89
|
+
or store_type == VectorStoreType.LANCE_DB_HYBRID
|
|
90
|
+
):
|
|
91
|
+
validate_return_dict_prop(self.properties, "nprobes", int, err_msg_prefix)
|
|
92
|
+
|
|
93
|
+
return self
|
|
94
|
+
|
|
95
|
+
@property
|
|
96
|
+
def lancedb_properties(self) -> LanceDBConfigBaseProperties:
|
|
97
|
+
err_msg_prefix = "LanceDB vector store configs properties:"
|
|
98
|
+
return LanceDBConfigBaseProperties(
|
|
99
|
+
similarity_top_k=validate_return_dict_prop(
|
|
100
|
+
self.properties,
|
|
101
|
+
"similarity_top_k",
|
|
102
|
+
int,
|
|
103
|
+
err_msg_prefix,
|
|
104
|
+
),
|
|
105
|
+
overfetch_factor=validate_return_dict_prop(
|
|
106
|
+
self.properties,
|
|
107
|
+
"overfetch_factor",
|
|
108
|
+
int,
|
|
109
|
+
err_msg_prefix,
|
|
110
|
+
),
|
|
111
|
+
vector_column_name=validate_return_dict_prop(
|
|
112
|
+
self.properties,
|
|
113
|
+
"vector_column_name",
|
|
114
|
+
str,
|
|
115
|
+
err_msg_prefix,
|
|
116
|
+
),
|
|
117
|
+
text_key=validate_return_dict_prop(
|
|
118
|
+
self.properties,
|
|
119
|
+
"text_key",
|
|
120
|
+
str,
|
|
121
|
+
err_msg_prefix,
|
|
122
|
+
),
|
|
123
|
+
doc_id_key=validate_return_dict_prop(
|
|
124
|
+
self.properties,
|
|
125
|
+
"doc_id_key",
|
|
126
|
+
str,
|
|
127
|
+
err_msg_prefix,
|
|
128
|
+
),
|
|
129
|
+
nprobes=validate_return_dict_prop_optional(
|
|
130
|
+
self.properties,
|
|
131
|
+
"nprobes",
|
|
132
|
+
int,
|
|
133
|
+
err_msg_prefix,
|
|
134
|
+
),
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
# Workaround to return typed parent without importing Project
|
|
138
|
+
def parent_project(self) -> Union["Project", None]:
|
|
139
|
+
if self.parent is None or self.parent.__class__.__name__ != "Project":
|
|
140
|
+
return None
|
|
141
|
+
return self.parent # type: ignore
|
kiln_ai/tools/base_tool.py
CHANGED
|
@@ -1,10 +1,19 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
2
3
|
from typing import Any, Dict
|
|
3
4
|
|
|
4
5
|
from kiln_ai.datamodel.json_schema import validate_schema_dict
|
|
5
6
|
from kiln_ai.datamodel.tool_id import KilnBuiltInToolId, ToolId
|
|
6
7
|
|
|
7
8
|
|
|
9
|
+
@dataclass
class ToolCallContext:
    """Context passed to tools when they are called, containing information from the calling task."""

    # Used for Kiln Tasks as Tools, to know if the tool call should save the
    # task run it invoked to that task's Dataset.
    allow_saving: bool = True
|
|
15
|
+
|
|
16
|
+
|
|
8
17
|
class KilnToolInterface(ABC):
|
|
9
18
|
"""
|
|
10
19
|
Abstract interface defining the core API that all Kiln tools must implement.
|
|
@@ -12,8 +21,8 @@ class KilnToolInterface(ABC):
|
|
|
12
21
|
"""
|
|
13
22
|
|
|
14
23
|
@abstractmethod
|
|
15
|
-
async def run(self, **kwargs) -> Any:
|
|
16
|
-
"""Execute the tool with the given parameters."""
|
|
24
|
+
async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
|
|
25
|
+
"""Execute the tool with the given parameters and calling context if provided."""
|
|
17
26
|
pass
|
|
18
27
|
|
|
19
28
|
@abstractmethod
|
|
@@ -77,6 +86,6 @@ class KilnTool(KilnToolInterface):
|
|
|
77
86
|
}
|
|
78
87
|
|
|
79
88
|
@abstractmethod
|
|
80
|
-
async def run(self, **kwargs) ->
|
|
89
|
+
async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
|
|
81
90
|
"""Subclasses must implement the actual tool logic."""
|
|
82
91
|
pass
|
|
@@ -27,7 +27,9 @@ class AddTool(KilnTool):
|
|
|
27
27
|
parameters_schema=parameters_schema,
|
|
28
28
|
)
|
|
29
29
|
|
|
30
|
-
async def run(
|
|
30
|
+
async def run(
|
|
31
|
+
self, context=None, *, a: Union[int, float], b: Union[int, float]
|
|
32
|
+
) -> str:
|
|
31
33
|
"""Add two numbers and return the result."""
|
|
32
34
|
return str(a + b)
|
|
33
35
|
|
|
@@ -57,7 +59,9 @@ class SubtractTool(KilnTool):
|
|
|
57
59
|
parameters_schema=parameters_schema,
|
|
58
60
|
)
|
|
59
61
|
|
|
60
|
-
async def run(
|
|
62
|
+
async def run(
|
|
63
|
+
self, context=None, *, a: Union[int, float], b: Union[int, float]
|
|
64
|
+
) -> str:
|
|
61
65
|
"""Subtract b from a and return the result."""
|
|
62
66
|
return str(a - b)
|
|
63
67
|
|
|
@@ -84,7 +88,9 @@ class MultiplyTool(KilnTool):
|
|
|
84
88
|
parameters_schema=parameters_schema,
|
|
85
89
|
)
|
|
86
90
|
|
|
87
|
-
async def run(
|
|
91
|
+
async def run(
|
|
92
|
+
self, context=None, *, a: Union[int, float], b: Union[int, float]
|
|
93
|
+
) -> str:
|
|
88
94
|
"""Multiply two numbers and return the result."""
|
|
89
95
|
return str(a * b)
|
|
90
96
|
|
|
@@ -117,7 +123,9 @@ class DivideTool(KilnTool):
|
|
|
117
123
|
parameters_schema=parameters_schema,
|
|
118
124
|
)
|
|
119
125
|
|
|
120
|
-
async def run(
|
|
126
|
+
async def run(
|
|
127
|
+
self, context=None, *, a: Union[int, float], b: Union[int, float]
|
|
128
|
+
) -> str:
|
|
121
129
|
"""Divide a by b and return the result."""
|
|
122
130
|
if b == 0:
|
|
123
131
|
raise ZeroDivisionError("Cannot divide by zero")
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
from kiln_ai.datamodel import Task
|
|
6
|
+
from kiln_ai.datamodel.external_tool_server import ExternalToolServer
|
|
7
|
+
from kiln_ai.datamodel.task import TaskRunConfig
|
|
8
|
+
from kiln_ai.datamodel.task_output import DataSource, DataSourceType
|
|
9
|
+
from kiln_ai.datamodel.tool_id import ToolId
|
|
10
|
+
from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
|
|
11
|
+
from kiln_ai.utils.project_utils import project_from_id
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class KilnTaskToolResult:
    """Value returned when a Kiln task is run as a tool."""

    # The output produced by the task run.
    output: str
    # Identifiers describing the run that produced the output, as one string.
    kiln_task_tool_data: str
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class KilnTaskTool(KilnToolInterface):
    """
    A tool that wraps a Kiln task, allowing it to be called as a function.

    This tool loads a task by ID and executes it using the specified run configuration.
    """

    def __init__(
        self,
        project_id: str,
        tool_id: str,
        data_model: ExternalToolServer,
    ):
        """Store the identifiers and read the tool settings from ``data_model``.

        The ``name``, ``description``, ``task_id`` and ``run_config_id`` values
        come from ``data_model.properties``; each falls back to an empty
        string when absent. The task itself is not loaded here.
        """
        self._project_id = project_id
        self._tool_server_model = data_model
        self._tool_id = tool_id

        properties = data_model.properties
        self._name = properties.get("name", "")
        self._description = properties.get("description", "")
        self._task_id = properties.get("task_id", "")
        self._run_config_id = properties.get("run_config_id", "")

    async def id(self) -> ToolId:
        return self._tool_id

    async def name(self) -> str:
        return self._name

    async def description(self) -> str:
        return self._description

    async def toolcall_definition(self) -> Dict[str, Any]:
        """Generate OpenAI-compatible tool definition."""
        return {
            "type": "function",
            "function": {
                "name": await self.name(),
                "description": await self.description(),
                "parameters": self.parameters_schema,
            },
        }

    async def run(
        self, context: ToolCallContext | None = None, **kwargs
    ) -> KilnTaskToolResult:
        """Execute the wrapped Kiln task with the given parameters and calling context.

        Args:
            context: Calling context; its ``allow_saving`` flag is forwarded
                to the adapter config. Required despite the ``None`` default.
            **kwargs: The task input. Passed through as a dict when the task
                has an input JSON schema; otherwise the ``"input"`` entry is
                used as the plaintext input.

        Raises:
            ValueError: If ``context`` is ``None``, or the task takes
                plaintext input and ``kwargs`` has no ``"input"`` key.
        """
        if context is None:
            raise ValueError("Context is required for running a KilnTaskTool.")

        # Determine the input format
        if self._task.input_json_schema:
            # Structured input - pass kwargs directly
            task_input = kwargs
        elif "input" in kwargs:
            # Plaintext input - extract from 'input' parameter
            task_input = kwargs["input"]
        else:
            raise ValueError(f"Input not found in kwargs: {kwargs}")

        # These imports are here to avoid circular chains
        from kiln_ai.adapters.adapter_registry import adapter_for_task
        from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig

        # Create adapter and run the task using the calling task's allow_saving setting
        run_config_properties = self._run_config.run_config_properties
        adapter = adapter_for_task(
            self._task,
            run_config_properties=run_config_properties,
            base_adapter_config=AdapterConfig(
                allow_saving=context.allow_saving,
                default_tags=["tool_call"],
            ),
        )
        task_run = await adapter.invoke(
            task_input,
            input_source=DataSource(
                type=DataSourceType.tool_call,
                run_config=self._run_config.run_config_properties,
            ),
        )

        return KilnTaskToolResult(
            output=task_run.output.output,
            kiln_task_tool_data=f"{self._project_id}:::{self._tool_id}:::{self._task.id}:::{task_run.id}",
        )

    @cached_property
    def _task(self) -> Task:
        """Load the wrapped task from its project (cached after first access).

        Raises:
            ValueError: If the project or the task cannot be found.
        """
        # Load the project first
        project = project_from_id(self._project_id)
        if project is None:
            raise ValueError(f"Project not found: {self._project_id}")

        # Load the task from the project
        task = Task.from_id_and_parent_path(self._task_id, project.path)
        if task is None:
            raise ValueError(
                f"Task not found: {self._task_id} in project {self._project_id}"
            )
        return task

    @cached_property
    def _run_config(self) -> TaskRunConfig:
        """Find the task's run config whose id matches ``run_config_id`` (cached).

        Raises:
            ValueError: If the task has no run config with that id.
        """
        run_config = next(
            (
                candidate
                for candidate in self._task.run_configs(readonly=True)
                if candidate.id == self._run_config_id
            ),
            None,
        )
        if run_config is None:
            raise ValueError(
                f"Task run config not found: {self._run_config_id} for task {self._task_id} in project {self._project_id}"
            )
        return run_config

    @cached_property
    def parameters_schema(self) -> Dict[str, Any]:
        """Return the JSON schema describing this tool's parameters (cached).

        Raises:
            ValueError: If the task reports an input JSON schema but
                ``input_schema()`` returns ``None``.
        """
        if self._task.input_json_schema:
            # Use the task's input schema directly if it exists
            parameters_schema = self._task.input_schema()
        else:
            # For plaintext tasks, create a simple string input parameter
            parameters_schema = {
                "type": "object",
                "properties": {
                    "input": {
                        "type": "string",
                        "description": "Plaintext input for the tool.",
                    }
                },
                "required": ["input"],
            }
        if parameters_schema is None:
            raise ValueError(
                f"Failed to create parameters schema for tool_id {self._tool_id}"
            )
        return parameters_schema
|
kiln_ai/tools/mcp_server_tool.py
CHANGED
|
@@ -5,7 +5,7 @@ from mcp.types import Tool as MCPTool
|
|
|
5
5
|
|
|
6
6
|
from kiln_ai.datamodel.external_tool_server import ExternalToolServer
|
|
7
7
|
from kiln_ai.datamodel.tool_id import MCP_REMOTE_TOOL_ID_PREFIX, ToolId
|
|
8
|
-
from kiln_ai.tools.base_tool import KilnToolInterface
|
|
8
|
+
from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
|
|
9
9
|
from kiln_ai.tools.mcp_session_manager import MCPSessionManager
|
|
10
10
|
|
|
11
11
|
|
|
@@ -38,7 +38,7 @@ class MCPServerTool(KilnToolInterface):
|
|
|
38
38
|
},
|
|
39
39
|
}
|
|
40
40
|
|
|
41
|
-
async def run(self, **kwargs) ->
|
|
41
|
+
async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
|
|
42
42
|
result = await self._call_tool(**kwargs)
|
|
43
43
|
|
|
44
44
|
if result.isError:
|
|
@@ -2,7 +2,9 @@ import logging
|
|
|
2
2
|
import os
|
|
3
3
|
import subprocess
|
|
4
4
|
import sys
|
|
5
|
+
import tempfile
|
|
5
6
|
from contextlib import asynccontextmanager
|
|
7
|
+
from datetime import timedelta
|
|
6
8
|
from typing import AsyncGenerator
|
|
7
9
|
|
|
8
10
|
import httpx
|
|
@@ -18,6 +20,8 @@ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
|
18
20
|
|
|
19
21
|
logger = logging.getLogger(__name__)
|
|
20
22
|
|
|
23
|
+
LOCAL_MCP_ERROR_INSTRUCTION = "Please verify your command, arguments, and environment variables, and consult the server's documentation for the correct setup."
|
|
24
|
+
|
|
21
25
|
|
|
22
26
|
class MCPSessionManager:
|
|
23
27
|
"""
|
|
@@ -50,6 +54,8 @@ class MCPSessionManager:
|
|
|
50
54
|
case ToolServerType.local_mcp:
|
|
51
55
|
async with self._create_local_mcp_session(tool_server) as session:
|
|
52
56
|
yield session
|
|
57
|
+
case ToolServerType.kiln_task:
|
|
58
|
+
raise ValueError("Kiln task tools are not available from an MCP server")
|
|
53
59
|
case _:
|
|
54
60
|
raise_exhaustive_enum_error(tool_server.type)
|
|
55
61
|
|
|
@@ -163,33 +169,56 @@ class MCPSessionManager:
|
|
|
163
169
|
env_vars["PATH"] = self._get_path()
|
|
164
170
|
|
|
165
171
|
# Set the server parameters
|
|
172
|
+
cwd = os.path.join(Config.settings_dir(), "cache", "mcp_cache")
|
|
173
|
+
os.makedirs(cwd, exist_ok=True)
|
|
166
174
|
server_params = StdioServerParameters(
|
|
167
|
-
command=command,
|
|
168
|
-
args=args,
|
|
169
|
-
env=env_vars,
|
|
175
|
+
command=command, args=args, env=env_vars, cwd=cwd
|
|
170
176
|
)
|
|
171
177
|
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
178
|
+
# Create temporary file to capture MCP server stderr
|
|
179
|
+
# Use errors="replace" to handle non-UTF-8 bytes gracefully
|
|
180
|
+
with tempfile.TemporaryFile(
|
|
181
|
+
mode="w+", encoding="utf-8", errors="replace"
|
|
182
|
+
) as err_log:
|
|
183
|
+
try:
|
|
184
|
+
async with stdio_client(server_params, errlog=err_log) as (
|
|
185
|
+
read,
|
|
186
|
+
write,
|
|
187
|
+
):
|
|
188
|
+
async with ClientSession(
|
|
189
|
+
read, write, read_timeout_seconds=timedelta(seconds=30)
|
|
190
|
+
) as session:
|
|
191
|
+
await session.initialize()
|
|
192
|
+
yield session
|
|
193
|
+
except Exception as e:
|
|
194
|
+
# Read stderr content from temporary file for debugging
|
|
195
|
+
err_log.seek(0) # Read from the start of the file
|
|
196
|
+
stderr_content = err_log.read()
|
|
197
|
+
if stderr_content:
|
|
198
|
+
logger.error(
|
|
199
|
+
f"MCP server '{tool_server.name}' stderr output: {stderr_content}"
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
# Check for MCP errors. Things like wrong arguments would fall here.
|
|
203
|
+
mcp_error = self._extract_first_exception(e, McpError)
|
|
204
|
+
if mcp_error and isinstance(mcp_error, McpError):
|
|
205
|
+
self._raise_local_mcp_error(mcp_error, stderr_content)
|
|
206
|
+
|
|
207
|
+
# Re-raise the original error but with a friendlier message
|
|
208
|
+
self._raise_local_mcp_error(e, stderr_content)
|
|
209
|
+
|
|
210
|
+
def _raise_local_mcp_error(self, e: Exception, stderr: str):
|
|
187
211
|
"""
|
|
188
|
-
Raise a
|
|
212
|
+
Raise a RuntimeError with a friendlier message for local MCP errors.
|
|
189
213
|
"""
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
214
|
+
error_msg = f"'{e}'"
|
|
215
|
+
|
|
216
|
+
if stderr:
|
|
217
|
+
error_msg += f"\nMCP server error: {stderr}"
|
|
218
|
+
|
|
219
|
+
error_msg += f"\n{LOCAL_MCP_ERROR_INSTRUCTION}"
|
|
220
|
+
|
|
221
|
+
raise RuntimeError(error_msg) from e
|
|
193
222
|
|
|
194
223
|
def _get_path(self) -> str:
|
|
195
224
|
"""
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import Any, Dict, List, TypedDict
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
from kiln_ai.adapters.embedding.base_embedding_adapter import BaseEmbeddingAdapter
|
|
7
|
+
from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
|
|
8
|
+
from kiln_ai.adapters.vector_store.base_vector_store_adapter import (
|
|
9
|
+
BaseVectorStoreAdapter,
|
|
10
|
+
SearchResult,
|
|
11
|
+
VectorStoreQuery,
|
|
12
|
+
)
|
|
13
|
+
from kiln_ai.adapters.vector_store.vector_store_registry import (
|
|
14
|
+
vector_store_adapter_for_config,
|
|
15
|
+
)
|
|
16
|
+
from kiln_ai.datamodel.embedding import EmbeddingConfig
|
|
17
|
+
from kiln_ai.datamodel.project import Project
|
|
18
|
+
from kiln_ai.datamodel.rag import RagConfig
|
|
19
|
+
from kiln_ai.datamodel.tool_id import ToolId
|
|
20
|
+
from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
|
|
21
|
+
from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
|
|
22
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ChunkContext(BaseModel):
    """A piece of retrieved text together with its metadata."""

    metadata: dict
    text: str

    def serialize(self) -> str:
        """Render the chunk as text.

        The result is a bracketed ``key: value`` list of the metadata on the
        first line, then the chunk text, then two newlines.
        """
        metadata_str = ", ".join(f"{k}: {v}" for k, v in self.metadata.items())
        return f"[{metadata_str}]\n{self.text}\n\n"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def format_search_results(search_results: List[SearchResult]) -> str:
    """Format search results as one text block.

    Each result becomes a ``ChunkContext`` carrying its ``document_id`` and
    ``chunk_idx`` as metadata. The serialized chunks are joined with a
    ``=========`` separator line. An empty input yields an empty string.
    """
    chunks: List[ChunkContext] = [
        ChunkContext(
            metadata={
                "document_id": search_result.document_id,
                "chunk_idx": search_result.chunk_idx,
            },
            text=search_result.chunk_text,
        )
        for search_result in search_results
    ]
    return "\n=========\n".join(chunk.serialize() for chunk in chunks)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class RagParams(TypedDict):
    """Keyword arguments for a RAG search: a single ``query`` string."""

    query: str
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class RagTool(KilnToolInterface):
|
|
54
|
+
"""
|
|
55
|
+
A tool that searches the vector store and returns the most relevant chunks.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
def __init__(self, tool_id: str, rag_config: RagConfig):
|
|
59
|
+
self._id = tool_id
|
|
60
|
+
self._name = rag_config.tool_name
|
|
61
|
+
self._description = rag_config.tool_description
|
|
62
|
+
self._parameters_schema = {
|
|
63
|
+
"type": "object",
|
|
64
|
+
"properties": {
|
|
65
|
+
"query": {
|
|
66
|
+
"type": "string",
|
|
67
|
+
"description": "The search query",
|
|
68
|
+
},
|
|
69
|
+
},
|
|
70
|
+
"required": ["query"],
|
|
71
|
+
}
|
|
72
|
+
self._rag_config = rag_config
|
|
73
|
+
vector_store_config = VectorStoreConfig.from_id_and_parent_path(
|
|
74
|
+
str(self._rag_config.vector_store_config_id), self.project.path
|
|
75
|
+
)
|
|
76
|
+
if vector_store_config is None:
|
|
77
|
+
raise ValueError(
|
|
78
|
+
f"Vector store config not found: {self._rag_config.vector_store_config_id}"
|
|
79
|
+
)
|
|
80
|
+
self._vector_store_config = vector_store_config
|
|
81
|
+
self._vector_store_adapter: BaseVectorStoreAdapter | None = None
|
|
82
|
+
|
|
83
|
+
@cached_property
|
|
84
|
+
def project(self) -> Project:
|
|
85
|
+
project = self._rag_config.parent_project()
|
|
86
|
+
if project is None:
|
|
87
|
+
raise ValueError(f"RAG config {self._rag_config.id} has no project")
|
|
88
|
+
return project
|
|
89
|
+
|
|
90
|
+
@cached_property
|
|
91
|
+
def embedding(
|
|
92
|
+
self,
|
|
93
|
+
) -> tuple[EmbeddingConfig, BaseEmbeddingAdapter]:
|
|
94
|
+
embedding_config = EmbeddingConfig.from_id_and_parent_path(
|
|
95
|
+
str(self._rag_config.embedding_config_id), self.project.path
|
|
96
|
+
)
|
|
97
|
+
if embedding_config is None:
|
|
98
|
+
raise ValueError(
|
|
99
|
+
f"Embedding config not found: {self._rag_config.embedding_config_id}"
|
|
100
|
+
)
|
|
101
|
+
return embedding_config, embedding_adapter_from_type(embedding_config)
|
|
102
|
+
|
|
103
|
+
async def vector_store(
|
|
104
|
+
self,
|
|
105
|
+
) -> BaseVectorStoreAdapter:
|
|
106
|
+
if self._vector_store_adapter is None:
|
|
107
|
+
self._vector_store_adapter = await vector_store_adapter_for_config(
|
|
108
|
+
vector_store_config=self._vector_store_config,
|
|
109
|
+
rag_config=self._rag_config,
|
|
110
|
+
)
|
|
111
|
+
return self._vector_store_adapter
|
|
112
|
+
|
|
113
|
+
async def id(self) -> ToolId:
|
|
114
|
+
return self._id
|
|
115
|
+
|
|
116
|
+
async def name(self) -> str:
|
|
117
|
+
return self._name
|
|
118
|
+
|
|
119
|
+
async def description(self) -> str:
|
|
120
|
+
return self._description
|
|
121
|
+
|
|
122
|
+
async def toolcall_definition(self) -> Dict[str, Any]:
|
|
123
|
+
"""Return the OpenAI-compatible tool definition for this tool."""
|
|
124
|
+
return {
|
|
125
|
+
"type": "function",
|
|
126
|
+
"function": {
|
|
127
|
+
"name": await self.name(),
|
|
128
|
+
"description": await self.description(),
|
|
129
|
+
"parameters": self._parameters_schema,
|
|
130
|
+
},
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
|
|
134
|
+
kwargs = RagParams(**kwargs)
|
|
135
|
+
query = kwargs["query"]
|
|
136
|
+
|
|
137
|
+
_, embedding_adapter = self.embedding
|
|
138
|
+
|
|
139
|
+
vector_store_adapter = await self.vector_store()
|
|
140
|
+
store_query = VectorStoreQuery(
|
|
141
|
+
query_embedding=None,
|
|
142
|
+
query_string=query,
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
match self._vector_store_config.store_type:
|
|
146
|
+
case VectorStoreType.LANCE_DB_HYBRID | VectorStoreType.LANCE_DB_VECTOR:
|
|
147
|
+
is_vector_query = True
|
|
148
|
+
case VectorStoreType.LANCE_DB_FTS:
|
|
149
|
+
is_vector_query = False
|
|
150
|
+
case _:
|
|
151
|
+
raise_exhaustive_enum_error(self._vector_store_config.store_type)
|
|
152
|
+
|
|
153
|
+
if is_vector_query:
|
|
154
|
+
query_embedding_result = await embedding_adapter.generate_embeddings(
|
|
155
|
+
[query]
|
|
156
|
+
)
|
|
157
|
+
if len(query_embedding_result.embeddings) == 0:
|
|
158
|
+
raise ValueError("No embeddings generated")
|
|
159
|
+
store_query.query_embedding = query_embedding_result.embeddings[0].vector
|
|
160
|
+
|
|
161
|
+
search_results = await vector_store_adapter.search(store_query)
|
|
162
|
+
search_results_as_text = format_search_results(search_results)
|
|
163
|
+
|
|
164
|
+
return search_results_as_text
|