kiln-ai 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic; see the registry's release page for more details.

Files changed (45):
  1. kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
  2. kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
  3. kiln_ai/adapters/ml_embedding_model_list.py +330 -28
  4. kiln_ai/adapters/ml_model_list.py +503 -23
  5. kiln_ai/adapters/model_adapters/litellm_adapter.py +34 -7
  6. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
  7. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  8. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  9. kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
  10. kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
  11. kiln_ai/adapters/test_ml_model_list.py +0 -10
  12. kiln_ai/datamodel/basemodel.py +31 -3
  13. kiln_ai/datamodel/external_tool_server.py +206 -54
  14. kiln_ai/datamodel/extraction.py +14 -0
  15. kiln_ai/datamodel/task.py +5 -0
  16. kiln_ai/datamodel/task_output.py +41 -11
  17. kiln_ai/datamodel/test_attachment.py +3 -3
  18. kiln_ai/datamodel/test_basemodel.py +269 -13
  19. kiln_ai/datamodel/test_datasource.py +50 -0
  20. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  21. kiln_ai/datamodel/test_extraction_model.py +31 -0
  22. kiln_ai/datamodel/test_task.py +35 -1
  23. kiln_ai/datamodel/test_tool_id.py +106 -1
  24. kiln_ai/datamodel/tool_id.py +36 -0
  25. kiln_ai/tools/base_tool.py +12 -3
  26. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  27. kiln_ai/tools/kiln_task_tool.py +158 -0
  28. kiln_ai/tools/mcp_server_tool.py +2 -2
  29. kiln_ai/tools/mcp_session_manager.py +50 -24
  30. kiln_ai/tools/rag_tools.py +12 -5
  31. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  32. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  33. kiln_ai/tools/test_mcp_session_manager.py +186 -226
  34. kiln_ai/tools/test_rag_tools.py +86 -5
  35. kiln_ai/tools/test_tool_registry.py +199 -5
  36. kiln_ai/tools/tool_registry.py +49 -17
  37. kiln_ai/utils/filesystem.py +4 -4
  38. kiln_ai/utils/open_ai_types.py +19 -2
  39. kiln_ai/utils/pdf_utils.py +21 -0
  40. kiln_ai/utils/test_open_ai_types.py +88 -12
  41. kiln_ai/utils/test_pdf_utils.py +14 -1
  42. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +3 -1
  43. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/RECORD +45 -43
  44. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  45. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -468,3 +468,34 @@ def test_document_invalid_mime_type(
468
468
  )
469
469
  def test_get_kind_from_mime_type(mime_type, expected_kind):
470
470
  assert get_kind_from_mime_type(mime_type) == expected_kind
471
+
472
+
473
+ def test_document_friendly_name(mock_project, mock_attachment_factory):
474
+ name = f"Test Document {uuid.uuid4()!s}"
475
+ document = Document(
476
+ name=name,
477
+ description=f"Test description {uuid.uuid4()!s}",
478
+ kind=Kind.DOCUMENT,
479
+ original_file=FileInfo(
480
+ filename=f"test_{name}.txt",
481
+ size=100,
482
+ mime_type="text/plain",
483
+ attachment=mock_attachment_factory(),
484
+ ),
485
+ parent=mock_project,
486
+ )
487
+ document.save_to_file()
488
+
489
+ # backward compatibility: old documents did not have name_override
490
+ assert document.name_override is None
491
+ assert document.friendly_name == name
492
+
493
+ # new documents have name_override
494
+ document.name_override = "Test Document Override"
495
+ assert document.friendly_name == "Test Document Override"
496
+
497
+ document.save_to_file()
498
+
499
+ document = Document.from_id_and_parent_path(str(document.id), mock_project.path)
500
+ assert document is not None
501
+ assert document.friendly_name == "Test Document Override"
@@ -254,7 +254,7 @@ def test_run_config_upgrade_old_entries():
254
254
  },
255
255
  "prompt": {
256
256
  "name": "Dazzling Unicorn",
257
- "description": "Frozen copy of prompt 'simple_prompt_builder', created for evaluations.",
257
+ "description": "Frozen copy of prompt 'simple_prompt_builder'.",
258
258
  "generator_id": "simple_prompt_builder",
259
259
  "prompt": "Generate a joke, given a theme. The theme will be provided as a word or phrase as the input to the model. The assistant should output a joke that is funny and relevant to the theme. If a style is provided, the joke should be in that style. The output should include a setup and punchline.\n\nYour response should respect the following requirements:\n1) Keep the joke on topic. If the user specifies a theme, the joke must be related to that theme.\n2) Avoid any jokes that are offensive or inappropriate. Keep the joke clean and appropriate for all audiences.\n3) Make the joke funny and engaging. It should be something that someone would want to tell to their friends. Something clever, not just a simple pun.\n",
260
260
  "chain_of_thought_instructions": None,
@@ -296,3 +296,37 @@ def test_run_config_upgrade_old_entries():
296
296
  def test_task_name_unicode_name():
297
297
  task = Task(name="你好", instruction="Do something")
298
298
  assert task.name == "你好"
299
+
300
+
301
+ def test_task_default_run_config_id_property(tmp_path):
302
+ """Test that default_run_config_id can be set and retrieved."""
303
+
304
+ # Create a task
305
+ task = Task(
306
+ name="Test Task", instruction="Test instruction", path=tmp_path / "task.kiln"
307
+ )
308
+ task.save_to_file()
309
+
310
+ # Create a run config for the task
311
+ run_config = TaskRunConfig(
312
+ name="Test Config",
313
+ run_config_properties=RunConfigProperties(
314
+ model_name="gpt-4",
315
+ model_provider_name="openai",
316
+ prompt_id=PromptGenerators.SIMPLE,
317
+ structured_output_mode=StructuredOutputMode.json_schema,
318
+ ),
319
+ parent=task,
320
+ )
321
+ run_config.save_to_file()
322
+
323
+ # Test None default (should be valid)
324
+ assert task.default_run_config_id is None
325
+
326
+ # Test setting a valid ID
327
+ task.default_run_config_id = "123456789012"
328
+ assert task.default_run_config_id == "123456789012"
329
+
330
+ # Test setting back to None
331
+ task.default_run_config_id = None
332
+ assert task.default_run_config_id is None
@@ -8,6 +8,7 @@ from kiln_ai.datamodel.tool_id import (
8
8
  KilnBuiltInToolId,
9
9
  ToolId,
10
10
  _check_tool_id,
11
+ kiln_task_server_id_from_tool_id,
11
12
  mcp_server_and_tool_name_from_id,
12
13
  rag_config_id_from_id,
13
14
  )
@@ -145,6 +146,39 @@ class TestCheckToolId:
145
146
  with pytest.raises(ValueError, match="Invalid RAG tool ID"):
146
147
  _check_tool_id("kiln_tool::rag::")
147
148
 
149
+ def test_valid_kiln_task_tools(self):
150
+ """Test validation of valid Kiln task tools."""
151
+ valid_ids = [
152
+ "kiln_task::server1",
153
+ "kiln_task::my_server",
154
+ "kiln_task::test_server_123",
155
+ "kiln_task::server_with_underscores",
156
+ "kiln_task::server-with-dashes",
157
+ "kiln_task::server.with.dots",
158
+ ]
159
+ for tool_id in valid_ids:
160
+ result = _check_tool_id(tool_id)
161
+ assert result == tool_id
162
+
163
+ def test_invalid_kiln_task_format(self):
164
+ """Test validation fails for invalid Kiln task tool formats."""
165
+ # These IDs start with the Kiln task prefix but have invalid formats
166
+ kiln_task_invalid_ids = [
167
+ "kiln_task::", # Missing server ID
168
+ "kiln_task::server::extra", # Too many parts
169
+ "kiln_task::server::tool::extra", # Too many parts
170
+ ]
171
+
172
+ for invalid_id in kiln_task_invalid_ids:
173
+ with pytest.raises(ValueError, match="Invalid Kiln task tool ID"):
174
+ _check_tool_id(invalid_id)
175
+
176
+ def test_kiln_task_tool_empty_server_id(self):
177
+ """Test that Kiln task tool with empty server ID is handled properly."""
178
+ # This tests the case where kiln_task_server_id_from_tool_id returns empty string which should raise an error
179
+ with pytest.raises(ValueError, match="Invalid Kiln task tool ID"):
180
+ _check_tool_id("kiln_task::")
181
+
148
182
 
149
183
  class TestMcpServerAndToolNameFromId:
150
184
  """Test the mcp_server_and_tool_name_from_id function."""
@@ -220,7 +254,7 @@ class TestToolIdPydanticType:
220
254
  model = self._ModelWithToolId(tool_id=tool_id.value)
221
255
  assert model.tool_id == tool_id.value
222
256
 
223
- def test_valid_mcp_tools(self):
257
+ def test_valid_tool_ids(self):
224
258
  """Test ToolId validates MCP remote and local tools."""
225
259
  valid_ids = [
226
260
  # Remote MCP tools
@@ -232,6 +266,9 @@ class TestToolIdPydanticType:
232
266
  # RAG tools
233
267
  "kiln_tool::rag::config1",
234
268
  "kiln_tool::rag::my_rag_config",
269
+ # Kiln task tools
270
+ "kiln_task::server1",
271
+ "kiln_task::my_server",
235
272
  ]
236
273
 
237
274
  for tool_id in valid_ids:
@@ -249,6 +286,8 @@ class TestToolIdPydanticType:
249
286
  "mcp::local::server",
250
287
  "kiln_tool::rag::",
251
288
  "kiln_tool::rag::config::extra",
289
+ "kiln_task::",
290
+ "kiln_task::server::extra",
252
291
  ]
253
292
 
254
293
  for invalid_id in invalid_ids:
@@ -318,3 +357,69 @@ class TestRagConfigIdFromId:
318
357
  # The validation for empty config ID happens in _check_tool_id
319
358
  result = rag_config_id_from_id("kiln_tool::rag::")
320
359
  assert result == ""
360
+
361
+
362
+ class TestKilnTaskServerIdFromToolId:
363
+ """Test the kiln_task_server_id_from_tool_id function."""
364
+
365
+ def test_valid_kiln_task_ids(self):
366
+ """Test parsing valid Kiln task tool IDs."""
367
+ test_cases = [
368
+ ("kiln_task::server1", "server1"),
369
+ ("kiln_task::my_server", "my_server"),
370
+ ("kiln_task::test_server_123", "test_server_123"),
371
+ ("kiln_task::a", "a"), # Minimal valid case
372
+ ("kiln_task::server_with_underscores", "server_with_underscores"),
373
+ ("kiln_task::server-with-dashes", "server-with-dashes"),
374
+ ("kiln_task::server.with.dots", "server.with.dots"),
375
+ ]
376
+
377
+ for tool_id, expected in test_cases:
378
+ result = kiln_task_server_id_from_tool_id(tool_id)
379
+ assert result == expected
380
+
381
+ def test_invalid_kiln_task_ids(self):
382
+ """Test parsing fails for invalid Kiln task tool IDs."""
383
+ # Test various invalid formats
384
+ invalid_ids = [
385
+ "kiln_task::", # Empty server ID
386
+ "kiln_task::server::extra", # Too many parts (3 parts)
387
+ "kiln_task::server::tool::extra", # Too many parts (4 parts)
388
+ "wrong::server", # Wrong prefix
389
+ "kiln_wrong::server", # Wrong prefix
390
+ "task::server", # Too few parts (2 parts)
391
+ "", # Empty string
392
+ "single_part", # Only 1 part
393
+ "kiln_task", # Missing server ID
394
+ ]
395
+
396
+ for invalid_id in invalid_ids:
397
+ with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
398
+ kiln_task_server_id_from_tool_id(invalid_id)
399
+
400
+ def test_kiln_task_id_with_empty_server_id(self):
401
+ """Test that Kiln task tool ID with empty server ID raises error."""
402
+ with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
403
+ kiln_task_server_id_from_tool_id("kiln_task::")
404
+
405
+ def test_kiln_task_id_with_whitespace_server_id(self):
406
+ """Test that Kiln task tool ID with whitespace-only server ID raises error."""
407
+ with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
408
+ kiln_task_server_id_from_tool_id("kiln_task::")
409
+
410
+ def test_kiln_task_id_with_multiple_colons(self):
411
+ """Test that Kiln task tool ID with multiple colons raises error."""
412
+ with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
413
+ kiln_task_server_id_from_tool_id("kiln_task::server::extra")
414
+
415
+ def test_kiln_task_id_case_sensitivity(self):
416
+ """Test that Kiln task tool IDs are case sensitive."""
417
+ # These should work
418
+ result1 = kiln_task_server_id_from_tool_id("kiln_task::Server")
419
+ assert result1 == "Server"
420
+
421
+ result2 = kiln_task_server_id_from_tool_id("kiln_task::SERVER")
422
+ assert result2 == "SERVER"
423
+
424
+ result3 = kiln_task_server_id_from_tool_id("kiln_task::server")
425
+ assert result3 == "server"
@@ -14,6 +14,7 @@ Tool IDs can be one of:
14
14
  - A kiln built-in tool name: kiln_tool::add_numbers
15
15
  - A remote MCP tool: mcp::remote::<server_id>::<tool_name>
16
16
  - A local MCP tool: mcp::local::<server_id>::<tool_name>
17
+ - A Kiln task tool: kiln_task::<server_id>
17
18
  - More coming soon like kiln_project_tool::rag::RAG_CONFIG_ID
18
19
  """
19
20
 
@@ -28,6 +29,7 @@ class KilnBuiltInToolId(str, Enum):
28
29
  MCP_REMOTE_TOOL_ID_PREFIX = "mcp::remote::"
29
30
  RAG_TOOL_ID_PREFIX = "kiln_tool::rag::"
30
31
  MCP_LOCAL_TOOL_ID_PREFIX = "mcp::local::"
32
+ KILN_TASK_TOOL_ID_PREFIX = "kiln_task::"
31
33
 
32
34
 
33
35
  def _check_tool_id(id: str) -> str:
@@ -68,6 +70,15 @@ def _check_tool_id(id: str) -> str:
68
70
  )
69
71
  return id
70
72
 
73
+ # Kiln task tools must have format: kiln_task::<server_id>
74
+ if id.startswith(KILN_TASK_TOOL_ID_PREFIX):
75
+ server_id = kiln_task_server_id_from_tool_id(id)
76
+ if not server_id:
77
+ raise ValueError(
78
+ f"Invalid Kiln task tool ID: {id}. Expected format: 'kiln_task::<server_id>'."
79
+ )
80
+ return id
81
+
71
82
  raise ValueError(f"Invalid tool ID: {id}")
72
83
 
73
84
 
@@ -103,3 +114,28 @@ def rag_config_id_from_id(id: str) -> str:
103
114
  f"Invalid RAG tool ID: {id}. Expected format: 'kiln_tool::rag::<rag_config_id>'."
104
115
  )
105
116
  return parts[2]
117
+
118
+
119
+ def kiln_task_server_id_from_tool_id(tool_id: str) -> str:
120
+ """
121
+ Get the server ID from the tool ID.
122
+ """
123
+ if not tool_id.startswith(KILN_TASK_TOOL_ID_PREFIX):
124
+ raise ValueError(
125
+ f"Invalid Kiln task tool ID format: {tool_id}. Expected format: 'kiln_task::<server_id>'."
126
+ )
127
+
128
+ # Remove prefix and split on ::
129
+ remaining = tool_id[len(KILN_TASK_TOOL_ID_PREFIX) :]
130
+ if not remaining:
131
+ raise ValueError(
132
+ f"Invalid Kiln task tool ID format: {tool_id}. Expected format: 'kiln_task::<server_id>'."
133
+ )
134
+ parts = remaining.split("::")
135
+
136
+ if len(parts) != 1 or not parts[0].strip():
137
+ raise ValueError(
138
+ f"Invalid Kiln task tool ID format: {tool_id}. Expected format: 'kiln_task::<server_id>'."
139
+ )
140
+
141
+ return parts[0] # server_id
@@ -1,10 +1,19 @@
1
1
  from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
2
3
  from typing import Any, Dict
3
4
 
4
5
  from kiln_ai.datamodel.json_schema import validate_schema_dict
5
6
  from kiln_ai.datamodel.tool_id import KilnBuiltInToolId, ToolId
6
7
 
7
8
 
9
+ @dataclass
10
+ class ToolCallContext:
11
+ """Context passed to tools when they are called, containing information from the calling task."""
12
+
13
+ """Used for Kiln Tasks as Tools, to know if the tool call should save the task run it invoked to that task's Dataset."""
14
+ allow_saving: bool = True
15
+
16
+
8
17
  class KilnToolInterface(ABC):
9
18
  """
10
19
  Abstract interface defining the core API that all Kiln tools must implement.
@@ -12,8 +21,8 @@ class KilnToolInterface(ABC):
12
21
  """
13
22
 
14
23
  @abstractmethod
15
- async def run(self, **kwargs) -> Any:
16
- """Execute the tool with the given parameters."""
24
+ async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
25
+ """Execute the tool with the given parameters and calling context if provided."""
17
26
  pass
18
27
 
19
28
  @abstractmethod
@@ -77,6 +86,6 @@ class KilnTool(KilnToolInterface):
77
86
  }
78
87
 
79
88
  @abstractmethod
80
- async def run(self, **kwargs) -> str:
89
+ async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
81
90
  """Subclasses must implement the actual tool logic."""
82
91
  pass
@@ -27,7 +27,9 @@ class AddTool(KilnTool):
27
27
  parameters_schema=parameters_schema,
28
28
  )
29
29
 
30
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
30
+ async def run(
31
+ self, context=None, *, a: Union[int, float], b: Union[int, float]
32
+ ) -> str:
31
33
  """Add two numbers and return the result."""
32
34
  return str(a + b)
33
35
 
@@ -57,7 +59,9 @@ class SubtractTool(KilnTool):
57
59
  parameters_schema=parameters_schema,
58
60
  )
59
61
 
60
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
62
+ async def run(
63
+ self, context=None, *, a: Union[int, float], b: Union[int, float]
64
+ ) -> str:
61
65
  """Subtract b from a and return the result."""
62
66
  return str(a - b)
63
67
 
@@ -84,7 +88,9 @@ class MultiplyTool(KilnTool):
84
88
  parameters_schema=parameters_schema,
85
89
  )
86
90
 
87
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
91
+ async def run(
92
+ self, context=None, *, a: Union[int, float], b: Union[int, float]
93
+ ) -> str:
88
94
  """Multiply two numbers and return the result."""
89
95
  return str(a * b)
90
96
 
@@ -117,7 +123,9 @@ class DivideTool(KilnTool):
117
123
  parameters_schema=parameters_schema,
118
124
  )
119
125
 
120
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
126
+ async def run(
127
+ self, context=None, *, a: Union[int, float], b: Union[int, float]
128
+ ) -> str:
121
129
  """Divide a by b and return the result."""
122
130
  if b == 0:
123
131
  raise ZeroDivisionError("Cannot divide by zero")
@@ -0,0 +1,158 @@
1
+ from dataclasses import dataclass
2
+ from functools import cached_property
3
+ from typing import Any, Dict
4
+
5
+ from kiln_ai.datamodel import Task
6
+ from kiln_ai.datamodel.external_tool_server import ExternalToolServer
7
+ from kiln_ai.datamodel.task import TaskRunConfig
8
+ from kiln_ai.datamodel.task_output import DataSource, DataSourceType
9
+ from kiln_ai.datamodel.tool_id import ToolId
10
+ from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
11
+ from kiln_ai.utils.project_utils import project_from_id
12
+
13
+
14
+ @dataclass
15
+ class KilnTaskToolResult:
16
+ output: str
17
+ kiln_task_tool_data: str
18
+
19
+
20
+ class KilnTaskTool(KilnToolInterface):
21
+ """
22
+ A tool that wraps a Kiln task, allowing it to be called as a function.
23
+
24
+ This tool loads a task by ID and executes it using the specified run configuration.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ project_id: str,
30
+ tool_id: str,
31
+ data_model: ExternalToolServer,
32
+ ):
33
+ self._project_id = project_id
34
+ self._tool_server_model = data_model
35
+ self._tool_id = tool_id
36
+
37
+ self._name = data_model.properties.get("name", "")
38
+ self._description = data_model.properties.get("description", "")
39
+ self._task_id = data_model.properties.get("task_id", "")
40
+ self._run_config_id = data_model.properties.get("run_config_id", "")
41
+
42
+ async def id(self) -> ToolId:
43
+ return self._tool_id
44
+
45
+ async def name(self) -> str:
46
+ return self._name
47
+
48
+ async def description(self) -> str:
49
+ return self._description
50
+
51
+ async def toolcall_definition(self) -> Dict[str, Any]:
52
+ """Generate OpenAI-compatible tool definition."""
53
+ return {
54
+ "type": "function",
55
+ "function": {
56
+ "name": await self.name(),
57
+ "description": await self.description(),
58
+ "parameters": self.parameters_schema,
59
+ },
60
+ }
61
+
62
+ async def run(
63
+ self, context: ToolCallContext | None = None, **kwargs
64
+ ) -> KilnTaskToolResult:
65
+ """Execute the wrapped Kiln task with the given parameters and calling context."""
66
+ if context is None:
67
+ raise ValueError("Context is required for running a KilnTaskTool.")
68
+
69
+ # Determine the input format
70
+ if self._task.input_json_schema:
71
+ # Structured input - pass kwargs directly
72
+ input = kwargs
73
+ else:
74
+ # Plaintext input - extract from 'input' parameter
75
+ if "input" in kwargs:
76
+ input = kwargs["input"]
77
+ else:
78
+ raise ValueError(f"Input not found in kwargs: {kwargs}")
79
+
80
+ # These imports are here to avoid circular chains
81
+ from kiln_ai.adapters.adapter_registry import adapter_for_task
82
+ from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
83
+
84
+ # Create adapter and run the task using the calling task's allow_saving setting
85
+ adapter = adapter_for_task(
86
+ self._task,
87
+ run_config_properties=self._run_config.run_config_properties,
88
+ base_adapter_config=AdapterConfig(
89
+ allow_saving=context.allow_saving,
90
+ default_tags=["tool_call"],
91
+ ),
92
+ )
93
+ task_run = await adapter.invoke(
94
+ input,
95
+ input_source=DataSource(
96
+ type=DataSourceType.tool_call,
97
+ run_config=self._run_config.run_config_properties,
98
+ ),
99
+ )
100
+
101
+ return KilnTaskToolResult(
102
+ output=task_run.output.output,
103
+ kiln_task_tool_data=f"{self._project_id}:::{self._tool_id}:::{self._task.id}:::{task_run.id}",
104
+ )
105
+
106
+ @cached_property
107
+ def _task(self) -> Task:
108
+ # Load the project first
109
+ project = project_from_id(self._project_id)
110
+ if project is None:
111
+ raise ValueError(f"Project not found: {self._project_id}")
112
+
113
+ # Load the task from the project
114
+ task = Task.from_id_and_parent_path(self._task_id, project.path)
115
+ if task is None:
116
+ raise ValueError(
117
+ f"Task not found: {self._task_id} in project {self._project_id}"
118
+ )
119
+ return task
120
+
121
+ @cached_property
122
+ def _run_config(self) -> TaskRunConfig:
123
+ run_config = next(
124
+ (
125
+ run_config
126
+ for run_config in self._task.run_configs(readonly=True)
127
+ if run_config.id == self._run_config_id
128
+ ),
129
+ None,
130
+ )
131
+ if run_config is None:
132
+ raise ValueError(
133
+ f"Task run config not found: {self._run_config_id} for task {self._task_id} in project {self._project_id}"
134
+ )
135
+ return run_config
136
+
137
+ @cached_property
138
+ def parameters_schema(self) -> Dict[str, Any]:
139
+ if self._task.input_json_schema:
140
+ # Use the task's input schema directly if it exists
141
+ parameters_schema = self._task.input_schema()
142
+ else:
143
+ # For plaintext tasks, create a simple string input parameter
144
+ parameters_schema = {
145
+ "type": "object",
146
+ "properties": {
147
+ "input": {
148
+ "type": "string",
149
+ "description": "Plaintext input for the tool.",
150
+ }
151
+ },
152
+ "required": ["input"],
153
+ }
154
+ if parameters_schema is None:
155
+ raise ValueError(
156
+ f"Failed to create parameters schema for tool_id {self._tool_id}"
157
+ )
158
+ return parameters_schema
@@ -5,7 +5,7 @@ from mcp.types import Tool as MCPTool
5
5
 
6
6
  from kiln_ai.datamodel.external_tool_server import ExternalToolServer
7
7
  from kiln_ai.datamodel.tool_id import MCP_REMOTE_TOOL_ID_PREFIX, ToolId
8
- from kiln_ai.tools.base_tool import KilnToolInterface
8
+ from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
9
9
  from kiln_ai.tools.mcp_session_manager import MCPSessionManager
10
10
 
11
11
 
@@ -38,7 +38,7 @@ class MCPServerTool(KilnToolInterface):
38
38
  },
39
39
  }
40
40
 
41
- async def run(self, **kwargs) -> Any:
41
+ async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
42
42
  result = await self._call_tool(**kwargs)
43
43
 
44
44
  if result.isError:
@@ -2,6 +2,7 @@ import logging
2
2
  import os
3
3
  import subprocess
4
4
  import sys
5
+ import tempfile
5
6
  from contextlib import asynccontextmanager
6
7
  from datetime import timedelta
7
8
  from typing import AsyncGenerator
@@ -19,6 +20,8 @@ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
19
20
 
20
21
  logger = logging.getLogger(__name__)
21
22
 
23
+ LOCAL_MCP_ERROR_INSTRUCTION = "Please verify your command, arguments, and environment variables, and consult the server's documentation for the correct setup."
24
+
22
25
 
23
26
  class MCPSessionManager:
24
27
  """
@@ -51,6 +54,8 @@ class MCPSessionManager:
51
54
  case ToolServerType.local_mcp:
52
55
  async with self._create_local_mcp_session(tool_server) as session:
53
56
  yield session
57
+ case ToolServerType.kiln_task:
58
+ raise ValueError("Kiln task tools are not available from an MCP server")
54
59
  case _:
55
60
  raise_exhaustive_enum_error(tool_server.type)
56
61
 
@@ -164,35 +169,56 @@ class MCPSessionManager:
164
169
  env_vars["PATH"] = self._get_path()
165
170
 
166
171
  # Set the server parameters
172
+ cwd = os.path.join(Config.settings_dir(), "cache", "mcp_cache")
173
+ os.makedirs(cwd, exist_ok=True)
167
174
  server_params = StdioServerParameters(
168
- command=command,
169
- args=args,
170
- env=env_vars,
175
+ command=command, args=args, env=env_vars, cwd=cwd
171
176
  )
172
177
 
173
- try:
174
- async with stdio_client(server_params) as (read, write):
175
- async with ClientSession(
176
- read, write, read_timeout_seconds=timedelta(seconds=8)
177
- ) as session:
178
- await session.initialize()
179
- yield session
180
- except Exception as e:
181
- # Check for MCP errors. Things like wrong arguments would fall here.
182
- mcp_error = self._extract_first_exception(e, McpError)
183
- if mcp_error and isinstance(mcp_error, McpError):
184
- self._raise_local_mcp_error(mcp_error)
185
-
186
- # Re-raise the original error but with a friendlier message
187
- self._raise_local_mcp_error(e)
188
-
189
- def _raise_local_mcp_error(self, e: Exception):
178
+ # Create temporary file to capture MCP server stderr
179
+ # Use errors="replace" to handle non-UTF-8 bytes gracefully
180
+ with tempfile.TemporaryFile(
181
+ mode="w+", encoding="utf-8", errors="replace"
182
+ ) as err_log:
183
+ try:
184
+ async with stdio_client(server_params, errlog=err_log) as (
185
+ read,
186
+ write,
187
+ ):
188
+ async with ClientSession(
189
+ read, write, read_timeout_seconds=timedelta(seconds=30)
190
+ ) as session:
191
+ await session.initialize()
192
+ yield session
193
+ except Exception as e:
194
+ # Read stderr content from temporary file for debugging
195
+ err_log.seek(0) # Read from the start of the file
196
+ stderr_content = err_log.read()
197
+ if stderr_content:
198
+ logger.error(
199
+ f"MCP server '{tool_server.name}' stderr output: {stderr_content}"
200
+ )
201
+
202
+ # Check for MCP errors. Things like wrong arguments would fall here.
203
+ mcp_error = self._extract_first_exception(e, McpError)
204
+ if mcp_error and isinstance(mcp_error, McpError):
205
+ self._raise_local_mcp_error(mcp_error, stderr_content)
206
+
207
+ # Re-raise the original error but with a friendlier message
208
+ self._raise_local_mcp_error(e, stderr_content)
209
+
210
+ def _raise_local_mcp_error(self, e: Exception, stderr: str):
190
211
  """
191
- Raise a ValueError with a friendlier message for local MCP errors.
212
+ Raise a RuntimeError with a friendlier message for local MCP errors.
192
213
  """
193
- raise RuntimeError(
194
- f"MCP server failed to start. Please verify your command, arguments, and environment variables, and consult the server's documentation for the correct setup. Original error: {e}"
195
- ) from e
214
+ error_msg = f"'{e}'"
215
+
216
+ if stderr:
217
+ error_msg += f"\nMCP server error: {stderr}"
218
+
219
+ error_msg += f"\n{LOCAL_MCP_ERROR_INSTRUCTION}"
220
+
221
+ raise RuntimeError(error_msg) from e
196
222
 
197
223
  def _get_path(self) -> str:
198
224
  """