kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic; see the registry's advisory for this release for more details.

Files changed (133)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,141 @@
1
+ from enum import Enum
2
+ from typing import TYPE_CHECKING, Union
3
+
4
+ from pydantic import BaseModel, Field, model_validator
5
+
6
+ from kiln_ai.datamodel.basemodel import FilenameString, KilnParentedModel
7
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
8
+ from kiln_ai.utils.validation import (
9
+ validate_return_dict_prop,
10
+ validate_return_dict_prop_optional,
11
+ )
12
+
13
+ if TYPE_CHECKING:
14
+ from kiln_ai.datamodel.project import Project
15
+
16
+
17
class VectorStoreType(str, Enum):
    """Kinds of vector store a VectorStoreConfig can be configured with.

    Every current member is a LanceDB variant; they differ in how a search
    query is matched against stored chunks.
    """

    # Full-text search; searches with this store type do not embed the query.
    LANCE_DB_FTS = "lancedb_fts"
    # Hybrid search; the query is embedded before searching.
    LANCE_DB_HYBRID = "lancedb_hybrid"
    # Vector similarity search; the query is embedded before searching.
    LANCE_DB_VECTOR = "lancedb_vector"
21
+
22
+
23
class LanceDBConfigBaseProperties(BaseModel):
    """Typed form of the LanceDB settings kept in a VectorStoreConfig's
    untyped ``properties`` dict.

    Every field is required except ``nprobes``, which defaults to ``None``.
    """

    # Search sizing.
    similarity_top_k: int = Field(
        description="The number of results to return from the vector store.",
    )
    overfetch_factor: int = Field(
        description="The overfetch factor to use for the vector search.",
    )

    # Column names inside the LanceDB table.
    vector_column_name: str = Field(
        description="The name of the vector column in the vector store.",
    )
    text_key: str = Field(
        description="The name of the text column in the vector store.",
    )
    doc_id_key: str = Field(
        description="The name of the document id column in the vector store.",
    )

    # Optional search tuning knob.
    nprobes: int | None = Field(
        default=None,
        description="The number of probes to use for the vector search.",
    )
43
+
44
+
45
class VectorStoreConfig(KilnParentedModel):
    """
    A saved vector store configuration.

    ``properties`` is a loosely typed dict whose required keys depend on
    ``store_type``. It is checked when the model is validated
    (``validate_properties``) and can be read back in typed form through
    ``lancedb_properties``. ``parent_project`` returns the parent only when
    it is a ``Project``.
    """

    name: FilenameString = Field(
        description="A name for your own reference to identify the vector store config.",
    )
    description: str | None = Field(
        default=None,
        description="A description for your own reference.",
    )
    store_type: VectorStoreType = Field(
        description="The type of vector store to use.",
    )
    properties: dict[str, str | int | float | None] = Field(
        description="The properties of the vector store config, specific to the selected store_type.",
    )

    @model_validator(mode="after")
    def validate_properties(self):
        """Validate ``properties`` according to ``store_type``."""
        lancedb_store_types = (
            VectorStoreType.LANCE_DB_FTS,
            VectorStoreType.LANCE_DB_HYBRID,
            VectorStoreType.LANCE_DB_VECTOR,
        )
        if self.store_type in lancedb_store_types:
            return self.validate_lancedb_properties(self.store_type)
        raise_exhaustive_enum_error(self.store_type)

    def validate_lancedb_properties(self, store_type: VectorStoreType):
        """
        Check that ``properties`` holds each key a LanceDB store of
        ``store_type`` needs, with the expected type.

        Returns ``self`` so the result can be used as the validator's return
        value. Errors raised by ``validate_return_dict_prop`` propagate as-is.
        """
        err_msg_prefix = f"LanceDB vector store configs properties for {store_type}:"
        expected_props: list[tuple[str, type]] = [
            ("similarity_top_k", int),
            ("overfetch_factor", int),
            ("vector_column_name", str),
            ("text_key", str),
            ("doc_id_key", str),
        ]

        # nprobes is only used for vector and hybrid queries
        if store_type in (
            VectorStoreType.LANCE_DB_VECTOR,
            VectorStoreType.LANCE_DB_HYBRID,
        ):
            expected_props.append(("nprobes", int))

        for prop_key, prop_type in expected_props:
            validate_return_dict_prop(
                self.properties, prop_key, prop_type, err_msg_prefix
            )
        return self

    @property
    def lancedb_properties(self) -> LanceDBConfigBaseProperties:
        """
        The LanceDB settings from ``properties`` as a typed model.

        The five core keys go through ``validate_return_dict_prop``;
        ``nprobes`` goes through ``validate_return_dict_prop_optional``.
        """
        err_msg_prefix = "LanceDB vector store configs properties:"
        core_specs = (
            ("similarity_top_k", int),
            ("overfetch_factor", int),
            ("vector_column_name", str),
            ("text_key", str),
            ("doc_id_key", str),
        )
        core_values = {
            prop_key: validate_return_dict_prop(
                self.properties, prop_key, prop_type, err_msg_prefix
            )
            for prop_key, prop_type in core_specs
        }
        return LanceDBConfigBaseProperties(
            **core_values,
            nprobes=validate_return_dict_prop_optional(
                self.properties, "nprobes", int, err_msg_prefix
            ),
        )

    # Workaround to return typed parent without importing Project
    def parent_project(self) -> Union["Project", None]:
        """Return the parent if its class is named ``Project``, else ``None``."""
        if self.parent is not None and self.parent.__class__.__name__ == "Project":
            return self.parent  # type: ignore
        return None
@@ -1,10 +1,19 @@
1
1
  from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
2
3
  from typing import Any, Dict
3
4
 
4
5
  from kiln_ai.datamodel.json_schema import validate_schema_dict
5
6
  from kiln_ai.datamodel.tool_id import KilnBuiltInToolId, ToolId
6
7
 
7
8
 
9
@dataclass
class ToolCallContext:
    """Context passed to tools when they are called, containing information from the calling task."""

    # Used for Kiln Tasks as Tools, to know if the tool call should save the
    # task run it invoked to that task's Dataset. (Previously a bare string
    # literal placed before the field, which is a no-op statement that no
    # documentation tool associates with the attribute.)
    allow_saving: bool = True
15
+
16
+
8
17
  class KilnToolInterface(ABC):
9
18
  """
10
19
  Abstract interface defining the core API that all Kiln tools must implement.
@@ -12,8 +21,8 @@ class KilnToolInterface(ABC):
12
21
  """
13
22
 
14
23
  @abstractmethod
15
- async def run(self, **kwargs) -> Any:
16
- """Execute the tool with the given parameters."""
24
async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
    """
    Execute the tool with the given keyword parameters.

    ``context`` carries information from the calling task (such as whether
    task runs may be saved) and is optional.
    """
18
27
 
19
28
  @abstractmethod
@@ -77,6 +86,6 @@ class KilnTool(KilnToolInterface):
77
86
  }
78
87
 
79
88
  @abstractmethod
80
- async def run(self, **kwargs) -> str:
89
async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
    """
    Subclasses must implement the actual tool logic.

    ``context`` describes the calling task when provided; the remaining
    keyword arguments are the tool-call parameters.
    """
@@ -27,7 +27,9 @@ class AddTool(KilnTool):
27
27
  parameters_schema=parameters_schema,
28
28
  )
29
29
 
30
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
30
async def run(
    self,
    context=None,
    *,
    a: Union[int, float],
    b: Union[int, float],
) -> str:
    """Return the sum of ``a`` and ``b`` as a string; ``context`` is unused."""
    total = a + b
    return str(total)
33
35
 
@@ -57,7 +59,9 @@ class SubtractTool(KilnTool):
57
59
  parameters_schema=parameters_schema,
58
60
  )
59
61
 
60
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
62
async def run(
    self,
    context=None,
    *,
    a: Union[int, float],
    b: Union[int, float],
) -> str:
    """Return ``a`` minus ``b`` as a string; ``context`` is unused."""
    difference = a - b
    return str(difference)
63
67
 
@@ -84,7 +88,9 @@ class MultiplyTool(KilnTool):
84
88
  parameters_schema=parameters_schema,
85
89
  )
86
90
 
87
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
91
async def run(
    self,
    context=None,
    *,
    a: Union[int, float],
    b: Union[int, float],
) -> str:
    """Return the product of ``a`` and ``b`` as a string; ``context`` is unused."""
    product = a * b
    return str(product)
90
96
 
@@ -117,7 +123,9 @@ class DivideTool(KilnTool):
117
123
  parameters_schema=parameters_schema,
118
124
  )
119
125
 
120
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
126
+ async def run(
127
+ self, context=None, *, a: Union[int, float], b: Union[int, float]
128
+ ) -> str:
121
129
  """Divide a by b and return the result."""
122
130
  if b == 0:
123
131
  raise ZeroDivisionError("Cannot divide by zero")
@@ -0,0 +1,158 @@
1
+ from dataclasses import dataclass
2
+ from functools import cached_property
3
+ from typing import Any, Dict
4
+
5
+ from kiln_ai.datamodel import Task
6
+ from kiln_ai.datamodel.external_tool_server import ExternalToolServer
7
+ from kiln_ai.datamodel.task import TaskRunConfig
8
+ from kiln_ai.datamodel.task_output import DataSource, DataSourceType
9
+ from kiln_ai.datamodel.tool_id import ToolId
10
+ from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
11
+ from kiln_ai.utils.project_utils import project_from_id
12
+
13
+
14
@dataclass
class KilnTaskToolResult:
    """Value returned by ``KilnTaskTool.run``."""

    # The wrapped task run's output.
    output: str
    # ":::"-joined project id, tool id, task id and task run id.
    kiln_task_tool_data: str
18
+
19
+
20
class KilnTaskTool(KilnToolInterface):
    """
    A tool that wraps a Kiln task, allowing it to be called as a function.

    This tool loads a task by ID and executes it using the specified run configuration.
    The task, its run config and the parameter schema are resolved on first use
    and cached on the instance (``cached_property``).
    """

    def __init__(
        self,
        project_id: str,
        tool_id: str,
        data_model: ExternalToolServer,
    ):
        self._project_id = project_id
        self._tool_server_model = data_model
        self._tool_id = tool_id

        # The tool server's properties hold the metadata exposed to the model
        # and the identity of the wrapped task / run config; missing keys
        # default to "".
        self._name = data_model.properties.get("name", "")
        self._description = data_model.properties.get("description", "")
        self._task_id = data_model.properties.get("task_id", "")
        self._run_config_id = data_model.properties.get("run_config_id", "")

    async def id(self) -> ToolId:
        return self._tool_id

    async def name(self) -> str:
        return self._name

    async def description(self) -> str:
        return self._description

    async def toolcall_definition(self) -> Dict[str, Any]:
        """Generate OpenAI-compatible tool definition."""
        return {
            "type": "function",
            "function": {
                "name": await self.name(),
                "description": await self.description(),
                "parameters": self.parameters_schema,
            },
        }

    async def run(
        self, context: ToolCallContext | None = None, **kwargs
    ) -> KilnTaskToolResult:
        """Execute the wrapped Kiln task with the given parameters and calling context.

        Args:
            context: Required. Its ``allow_saving`` flag is forwarded to the
                adapter config, so the calling task decides whether this run
                is saved.
            **kwargs: Tool-call arguments. For tasks with an input JSON schema
                they are passed to the task as-is. For plaintext tasks the
                ``input`` argument is used.

        Returns:
            A KilnTaskToolResult holding the task output and a ":::"-joined
            project/tool/task/run identifier string.

        Raises:
            ValueError: If ``context`` is missing, a plaintext task is called
                without ``input``, or the project, task or run config cannot
                be found.
        """
        if context is None:
            raise ValueError("Context is required for running a KilnTaskTool.")

        # Determine the input format. Named task_input so the builtin `input`
        # is not shadowed.
        if self._task.input_json_schema:
            # Structured input - pass kwargs directly
            task_input = kwargs
        elif "input" in kwargs:
            # Plaintext input - extract from 'input' parameter
            task_input = kwargs["input"]
        else:
            raise ValueError(f"Input not found in kwargs: {kwargs}")

        # These imports are here to avoid circular chains
        from kiln_ai.adapters.adapter_registry import adapter_for_task
        from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig

        # Create adapter and run the task using the calling task's allow_saving setting
        adapter = adapter_for_task(
            self._task,
            run_config_properties=self._run_config.run_config_properties,
            base_adapter_config=AdapterConfig(
                allow_saving=context.allow_saving,
                default_tags=["tool_call"],
            ),
        )
        task_run = await adapter.invoke(
            task_input,
            input_source=DataSource(
                type=DataSourceType.tool_call,
                run_config=self._run_config.run_config_properties,
            ),
        )

        return KilnTaskToolResult(
            output=task_run.output.output,
            kiln_task_tool_data=f"{self._project_id}:::{self._tool_id}:::{self._task.id}:::{task_run.id}",
        )

    @cached_property
    def _task(self) -> Task:
        """Load (once) the wrapped task from its project; raise ValueError if missing."""
        # Load the project first
        project = project_from_id(self._project_id)
        if project is None:
            raise ValueError(f"Project not found: {self._project_id}")

        # Load the task from the project
        task = Task.from_id_and_parent_path(self._task_id, project.path)
        if task is None:
            raise ValueError(
                f"Task not found: {self._task_id} in project {self._project_id}"
            )
        return task

    @cached_property
    def _run_config(self) -> TaskRunConfig:
        """Find (once) the task run config matching ``run_config_id``; raise ValueError if missing."""
        run_config = next(
            (
                run_config
                for run_config in self._task.run_configs(readonly=True)
                if run_config.id == self._run_config_id
            ),
            None,
        )
        if run_config is None:
            raise ValueError(
                f"Task run config not found: {self._run_config_id} for task {self._task_id} in project {self._project_id}"
            )
        return run_config

    @cached_property
    def parameters_schema(self) -> Dict[str, Any]:
        """JSON schema for the tool's arguments: the task's input schema, or a
        single required string ``input`` for plaintext tasks."""
        if self._task.input_json_schema:
            # Use the task's input schema directly if it exists
            parameters_schema = self._task.input_schema()
        else:
            # For plaintext tasks, create a simple string input parameter
            parameters_schema = {
                "type": "object",
                "properties": {
                    "input": {
                        "type": "string",
                        "description": "Plaintext input for the tool.",
                    }
                },
                "required": ["input"],
            }
        if parameters_schema is None:
            raise ValueError(
                f"Failed to create parameters schema for tool_id {self._tool_id}"
            )
        return parameters_schema
@@ -5,7 +5,7 @@ from mcp.types import Tool as MCPTool
5
5
 
6
6
  from kiln_ai.datamodel.external_tool_server import ExternalToolServer
7
7
  from kiln_ai.datamodel.tool_id import MCP_REMOTE_TOOL_ID_PREFIX, ToolId
8
- from kiln_ai.tools.base_tool import KilnToolInterface
8
+ from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
9
9
  from kiln_ai.tools.mcp_session_manager import MCPSessionManager
10
10
 
11
11
 
@@ -38,7 +38,7 @@ class MCPServerTool(KilnToolInterface):
38
38
  },
39
39
  }
40
40
 
41
- async def run(self, **kwargs) -> Any:
41
+ async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
42
42
  result = await self._call_tool(**kwargs)
43
43
 
44
44
  if result.isError:
@@ -2,7 +2,9 @@ import logging
2
2
  import os
3
3
  import subprocess
4
4
  import sys
5
+ import tempfile
5
6
  from contextlib import asynccontextmanager
7
+ from datetime import timedelta
6
8
  from typing import AsyncGenerator
7
9
 
8
10
  import httpx
@@ -18,6 +20,8 @@ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
18
20
 
19
21
  logger = logging.getLogger(__name__)
20
22
 
23
+ LOCAL_MCP_ERROR_INSTRUCTION = "Please verify your command, arguments, and environment variables, and consult the server's documentation for the correct setup."
24
+
21
25
 
22
26
  class MCPSessionManager:
23
27
  """
@@ -50,6 +54,8 @@ class MCPSessionManager:
50
54
  case ToolServerType.local_mcp:
51
55
  async with self._create_local_mcp_session(tool_server) as session:
52
56
  yield session
57
+ case ToolServerType.kiln_task:
58
+ raise ValueError("Kiln task tools are not available from an MCP server")
53
59
  case _:
54
60
  raise_exhaustive_enum_error(tool_server.type)
55
61
 
@@ -163,33 +169,56 @@ class MCPSessionManager:
163
169
  env_vars["PATH"] = self._get_path()
164
170
 
165
171
  # Set the server parameters
172
+ cwd = os.path.join(Config.settings_dir(), "cache", "mcp_cache")
173
+ os.makedirs(cwd, exist_ok=True)
166
174
  server_params = StdioServerParameters(
167
- command=command,
168
- args=args,
169
- env=env_vars,
175
+ command=command, args=args, env=env_vars, cwd=cwd
170
176
  )
171
177
 
172
- try:
173
- async with stdio_client(server_params) as (read, write):
174
- async with ClientSession(read, write) as session:
175
- await session.initialize()
176
- yield session
177
- except Exception as e:
178
- # Check for MCP errors. Things like wrong arguments would fall here.
179
- mcp_error = self._extract_first_exception(e, McpError)
180
- if mcp_error and isinstance(mcp_error, McpError):
181
- self._raise_local_mcp_error(mcp_error)
182
-
183
- # Re-raise the original error but with a friendlier message
184
- self._raise_local_mcp_error(e)
185
-
186
- def _raise_local_mcp_error(self, e: Exception):
178
+ # Create temporary file to capture MCP server stderr
179
+ # Use errors="replace" to handle non-UTF-8 bytes gracefully
180
+ with tempfile.TemporaryFile(
181
+ mode="w+", encoding="utf-8", errors="replace"
182
+ ) as err_log:
183
+ try:
184
+ async with stdio_client(server_params, errlog=err_log) as (
185
+ read,
186
+ write,
187
+ ):
188
+ async with ClientSession(
189
+ read, write, read_timeout_seconds=timedelta(seconds=30)
190
+ ) as session:
191
+ await session.initialize()
192
+ yield session
193
+ except Exception as e:
194
+ # Read stderr content from temporary file for debugging
195
+ err_log.seek(0) # Read from the start of the file
196
+ stderr_content = err_log.read()
197
+ if stderr_content:
198
+ logger.error(
199
+ f"MCP server '{tool_server.name}' stderr output: {stderr_content}"
200
+ )
201
+
202
+ # Check for MCP errors. Things like wrong arguments would fall here.
203
+ mcp_error = self._extract_first_exception(e, McpError)
204
+ if mcp_error and isinstance(mcp_error, McpError):
205
+ self._raise_local_mcp_error(mcp_error, stderr_content)
206
+
207
+ # Re-raise the original error but with a friendlier message
208
+ self._raise_local_mcp_error(e, stderr_content)
209
+
210
def _raise_local_mcp_error(self, e: Exception, stderr: str):
    """
    Raise a RuntimeError with a friendlier message for local MCP errors.

    The message has these newline-separated lines:
    - the original error in single quotes;
    - the captured server stderr, only when non-empty;
    - LOCAL_MCP_ERROR_INSTRUCTION.

    ``e`` is chained as the cause.
    """
    message_lines = [f"'{e}'"]
    if stderr:
        message_lines.append(f"MCP server error: {stderr}")
    message_lines.append(LOCAL_MCP_ERROR_INSTRUCTION)
    raise RuntimeError("\n".join(message_lines)) from e
193
222
 
194
223
  def _get_path(self) -> str:
195
224
  """
@@ -0,0 +1,164 @@
1
+ from functools import cached_property
2
+ from typing import Any, Dict, List, TypedDict
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from kiln_ai.adapters.embedding.base_embedding_adapter import BaseEmbeddingAdapter
7
+ from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
8
+ from kiln_ai.adapters.vector_store.base_vector_store_adapter import (
9
+ BaseVectorStoreAdapter,
10
+ SearchResult,
11
+ VectorStoreQuery,
12
+ )
13
+ from kiln_ai.adapters.vector_store.vector_store_registry import (
14
+ vector_store_adapter_for_config,
15
+ )
16
+ from kiln_ai.datamodel.embedding import EmbeddingConfig
17
+ from kiln_ai.datamodel.project import Project
18
+ from kiln_ai.datamodel.rag import RagConfig
19
+ from kiln_ai.datamodel.tool_id import ToolId
20
+ from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
21
+ from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
22
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
23
+
24
+
25
class ChunkContext(BaseModel):
    """A retrieved chunk's text together with the metadata shown alongside it."""

    metadata: dict
    text: str

    def serialize(self) -> str:
        """Render as "[key: value, ...]", a newline, the text, then a blank line."""
        rendered_pairs = (f"{key}: {value}" for key, value in self.metadata.items())
        header = ", ".join(rendered_pairs)
        return f"[{header}]\n{self.text}\n\n"
32
+
33
+
34
def format_search_results(search_results: List[SearchResult]) -> str:
    """Turn vector store hits into one text block for the calling model.

    Each hit is serialized via ChunkContext, tagged with its document_id and
    chunk_idx. Hits are separated by a "=========" line.
    """
    chunk_contexts = [
        ChunkContext(
            metadata={
                "document_id": hit.document_id,
                "chunk_idx": hit.chunk_idx,
            },
            text=hit.chunk_text,
        )
        for hit in search_results
    ]
    separator = "\n=========\n"
    return separator.join(chunk.serialize() for chunk in chunk_contexts)
47
+
48
+
49
class RagParams(TypedDict):
    """Keyword arguments accepted by the RAG search tool."""

    # Free-text search query.
    query: str
51
+
52
+
53
class RagTool(KilnToolInterface):
    """
    Search tool backed by a RAG config's vector store.

    Called with a ``query`` argument, it searches the configured vector store
    and returns the matching chunks formatted as text. For vector and hybrid
    stores the query is embedded first.
    """

    def __init__(self, tool_id: str, rag_config: RagConfig):
        self._id = tool_id
        self._name = rag_config.tool_name
        self._description = rag_config.tool_description
        # JSON schema advertised to the model: one required string argument.
        query_property = {
            "type": "string",
            "description": "The search query",
        }
        self._parameters_schema = {
            "type": "object",
            "properties": {"query": query_property},
            "required": ["query"],
        }
        self._rag_config = rag_config

        # Loaded eagerly from the owning project: a missing config raises here.
        store_config_id = self._rag_config.vector_store_config_id
        loaded_store_config = VectorStoreConfig.from_id_and_parent_path(
            str(store_config_id), self.project.path
        )
        if loaded_store_config is None:
            raise ValueError(f"Vector store config not found: {store_config_id}")
        self._vector_store_config = loaded_store_config
        # Created on first use by vector_store().
        self._vector_store_adapter: BaseVectorStoreAdapter | None = None

    @cached_property
    def project(self) -> Project:
        """The project owning the RAG config; raises ValueError if there is none."""
        owning_project = self._rag_config.parent_project()
        if owning_project is None:
            raise ValueError(f"RAG config {self._rag_config.id} has no project")
        return owning_project

    @cached_property
    def embedding(self) -> tuple[EmbeddingConfig, BaseEmbeddingAdapter]:
        """The RAG config's embedding config and an adapter built from it (loaded once)."""
        embedding_config_id = self._rag_config.embedding_config_id
        loaded_embedding_config = EmbeddingConfig.from_id_and_parent_path(
            str(embedding_config_id), self.project.path
        )
        if loaded_embedding_config is None:
            raise ValueError(f"Embedding config not found: {embedding_config_id}")
        adapter = embedding_adapter_from_type(loaded_embedding_config)
        return loaded_embedding_config, adapter

    async def vector_store(self) -> BaseVectorStoreAdapter:
        """Return the vector store adapter, creating it on the first call."""
        if self._vector_store_adapter is not None:
            return self._vector_store_adapter
        self._vector_store_adapter = await vector_store_adapter_for_config(
            vector_store_config=self._vector_store_config,
            rag_config=self._rag_config,
        )
        return self._vector_store_adapter

    async def id(self) -> ToolId:
        return self._id

    async def name(self) -> str:
        return self._name

    async def description(self) -> str:
        return self._description

    async def toolcall_definition(self) -> Dict[str, Any]:
        """Return the OpenAI-compatible tool definition for this tool."""
        return {
            "type": "function",
            "function": {
                "name": await self.name(),
                "description": await self.description(),
                "parameters": self._parameters_schema,
            },
        }

    async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
        """Search the vector store for ``query`` and return the hits as text.

        ``context`` is accepted for interface compatibility and not used.
        """
        params = RagParams(**kwargs)
        query = params["query"]

        _, embedding_adapter = self.embedding

        vector_store_adapter = await self.vector_store()
        store_query = VectorStoreQuery(query_embedding=None, query_string=query)

        store_type = self._vector_store_config.store_type
        if store_type in (
            VectorStoreType.LANCE_DB_HYBRID,
            VectorStoreType.LANCE_DB_VECTOR,
        ):
            needs_query_embedding = True
        elif store_type == VectorStoreType.LANCE_DB_FTS:
            needs_query_embedding = False
        else:
            raise_exhaustive_enum_error(store_type)

        if needs_query_embedding:
            embedding_result = await embedding_adapter.generate_embeddings([query])
            if len(embedding_result.embeddings) == 0:
                raise ValueError("No embeddings generated")
            store_query.query_embedding = embedding_result.embeddings[0].vector

        hits = await vector_store_adapter.search(store_query)
        return format_search_results(hits)