kiln-ai 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (53)
  1. kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
  2. kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
  3. kiln_ai/adapters/ml_embedding_model_list.py +330 -28
  4. kiln_ai/adapters/ml_model_list.py +503 -23
  5. kiln_ai/adapters/model_adapters/litellm_adapter.py +39 -8
  6. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
  7. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  8. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  9. kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
  10. kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
  11. kiln_ai/adapters/test_ml_model_list.py +0 -10
  12. kiln_ai/adapters/vector_store/lancedb_adapter.py +24 -70
  13. kiln_ai/adapters/vector_store/lancedb_helpers.py +101 -0
  14. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +9 -16
  15. kiln_ai/adapters/vector_store/test_lancedb_helpers.py +142 -0
  16. kiln_ai/adapters/vector_store_loaders/__init__.py +0 -0
  17. kiln_ai/adapters/vector_store_loaders/test_lancedb_loader.py +282 -0
  18. kiln_ai/adapters/vector_store_loaders/test_vector_store_loader.py +544 -0
  19. kiln_ai/adapters/vector_store_loaders/vector_store_loader.py +91 -0
  20. kiln_ai/datamodel/basemodel.py +31 -3
  21. kiln_ai/datamodel/external_tool_server.py +206 -54
  22. kiln_ai/datamodel/extraction.py +14 -0
  23. kiln_ai/datamodel/task.py +5 -0
  24. kiln_ai/datamodel/task_output.py +41 -11
  25. kiln_ai/datamodel/test_attachment.py +3 -3
  26. kiln_ai/datamodel/test_basemodel.py +269 -13
  27. kiln_ai/datamodel/test_datasource.py +50 -0
  28. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  29. kiln_ai/datamodel/test_extraction_model.py +31 -0
  30. kiln_ai/datamodel/test_task.py +35 -1
  31. kiln_ai/datamodel/test_tool_id.py +106 -1
  32. kiln_ai/datamodel/tool_id.py +49 -0
  33. kiln_ai/tools/base_tool.py +30 -6
  34. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  35. kiln_ai/tools/kiln_task_tool.py +162 -0
  36. kiln_ai/tools/mcp_server_tool.py +7 -5
  37. kiln_ai/tools/mcp_session_manager.py +50 -24
  38. kiln_ai/tools/rag_tools.py +17 -6
  39. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  40. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  41. kiln_ai/tools/test_mcp_session_manager.py +186 -226
  42. kiln_ai/tools/test_rag_tools.py +86 -5
  43. kiln_ai/tools/test_tool_registry.py +199 -5
  44. kiln_ai/tools/tool_registry.py +49 -17
  45. kiln_ai/utils/filesystem.py +4 -4
  46. kiln_ai/utils/open_ai_types.py +19 -2
  47. kiln_ai/utils/pdf_utils.py +21 -0
  48. kiln_ai/utils/test_open_ai_types.py +88 -12
  49. kiln_ai/utils/test_pdf_utils.py +14 -1
  50. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/METADATA +79 -1
  51. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/RECORD +53 -45
  52. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/WHEEL +0 -0
  53. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -468,3 +468,34 @@ def test_document_invalid_mime_type(
468
468
  )
469
469
  def test_get_kind_from_mime_type(mime_type, expected_kind):
470
470
  assert get_kind_from_mime_type(mime_type) == expected_kind
471
+
472
+
473
def test_document_friendly_name(mock_project, mock_attachment_factory):
    """friendly_name falls back to name, then prefers name_override, including after a reload."""
    override = "Test Document Override"
    base_name = f"Test Document {uuid.uuid4()!s}"
    description = f"Test description {uuid.uuid4()!s}"
    source_file = FileInfo(
        filename=f"test_{base_name}.txt",
        size=100,
        mime_type="text/plain",
        attachment=mock_attachment_factory(),
    )

    document = Document(
        name=base_name,
        description=description,
        kind=Kind.DOCUMENT,
        original_file=source_file,
        parent=mock_project,
    )
    document.save_to_file()

    # Backward compatibility: a document without name_override reports its
    # plain name as the friendly name.
    assert document.name_override is None
    assert document.friendly_name == base_name

    # Once an override is set it takes precedence.
    document.name_override = override
    assert document.friendly_name == override

    # The override must still be in effect after saving and loading again.
    document.save_to_file()
    document = Document.from_id_and_parent_path(str(document.id), mock_project.path)
    assert document is not None
    assert document.friendly_name == override
@@ -254,7 +254,7 @@ def test_run_config_upgrade_old_entries():
254
254
  },
255
255
  "prompt": {
256
256
  "name": "Dazzling Unicorn",
257
- "description": "Frozen copy of prompt 'simple_prompt_builder', created for evaluations.",
257
+ "description": "Frozen copy of prompt 'simple_prompt_builder'.",
258
258
  "generator_id": "simple_prompt_builder",
259
259
  "prompt": "Generate a joke, given a theme. The theme will be provided as a word or phrase as the input to the model. The assistant should output a joke that is funny and relevant to the theme. If a style is provided, the joke should be in that style. The output should include a setup and punchline.\n\nYour response should respect the following requirements:\n1) Keep the joke on topic. If the user specifies a theme, the joke must be related to that theme.\n2) Avoid any jokes that are offensive or inappropriate. Keep the joke clean and appropriate for all audiences.\n3) Make the joke funny and engaging. It should be something that someone would want to tell to their friends. Something clever, not just a simple pun.\n",
260
260
  "chain_of_thought_instructions": None,
@@ -296,3 +296,37 @@ def test_run_config_upgrade_old_entries():
296
296
  def test_task_name_unicode_name():
297
297
  task = Task(name="你好", instruction="Do something")
298
298
  assert task.name == "你好"
299
+
300
+
301
def test_task_default_run_config_id_property(tmp_path):
    """default_run_config_id starts as None, accepts an ID, and can be reset to None."""
    task_file = tmp_path / "task.kiln"
    task = Task(name="Test Task", instruction="Test instruction", path=task_file)
    task.save_to_file()

    # Persist a run config as a child of the task.
    config_properties = RunConfigProperties(
        model_name="gpt-4",
        model_provider_name="openai",
        prompt_id=PromptGenerators.SIMPLE,
        structured_output_mode=StructuredOutputMode.json_schema,
    )
    run_config = TaskRunConfig(
        name="Test Config",
        run_config_properties=config_properties,
        parent=task,
    )
    run_config.save_to_file()

    # Unset by default.
    assert task.default_run_config_id is None

    # Accepts an ID and returns it unchanged.
    task.default_run_config_id = "123456789012"
    assert task.default_run_config_id == "123456789012"

    # Can be cleared again.
    task.default_run_config_id = None
    assert task.default_run_config_id is None
@@ -8,6 +8,7 @@ from kiln_ai.datamodel.tool_id import (
8
8
  KilnBuiltInToolId,
9
9
  ToolId,
10
10
  _check_tool_id,
11
+ kiln_task_server_id_from_tool_id,
11
12
  mcp_server_and_tool_name_from_id,
12
13
  rag_config_id_from_id,
13
14
  )
@@ -145,6 +146,39 @@ class TestCheckToolId:
145
146
  with pytest.raises(ValueError, match="Invalid RAG tool ID"):
146
147
  _check_tool_id("kiln_tool::rag::")
147
148
 
149
def test_valid_kiln_task_tools(self):
    """Well-formed Kiln task tool IDs pass validation and come back unchanged."""
    accepted_ids = (
        "kiln_task::server1",
        "kiln_task::my_server",
        "kiln_task::test_server_123",
        "kiln_task::server_with_underscores",
        "kiln_task::server-with-dashes",
        "kiln_task::server.with.dots",
    )
    for candidate in accepted_ids:
        # _check_tool_id returns the ID it was given when the ID is valid.
        assert _check_tool_id(candidate) == candidate
162
+
163
def test_invalid_kiln_task_format(self):
    """IDs that carry the Kiln task prefix but are malformed are rejected."""
    malformed_ids = (
        "kiln_task::",  # Missing server ID
        "kiln_task::server::extra",  # Too many parts
        "kiln_task::server::tool::extra",  # Too many parts
    )

    for candidate in malformed_ids:
        with pytest.raises(ValueError, match="Invalid Kiln task tool ID"):
            _check_tool_id(candidate)
175
+
176
def test_kiln_task_tool_empty_server_id(self):
    """A Kiln task tool ID with nothing after the prefix is rejected."""
    empty_server_tool_id = "kiln_task::"
    # _check_tool_id parses the ID with kiln_task_server_id_from_tool_id; an
    # empty server ID must surface as a ValueError matching this message.
    with pytest.raises(ValueError, match="Invalid Kiln task tool ID"):
        _check_tool_id(empty_server_tool_id)
181
+
148
182
 
149
183
  class TestMcpServerAndToolNameFromId:
150
184
  """Test the mcp_server_and_tool_name_from_id function."""
@@ -220,7 +254,7 @@ class TestToolIdPydanticType:
220
254
  model = self._ModelWithToolId(tool_id=tool_id.value)
221
255
  assert model.tool_id == tool_id.value
222
256
 
223
- def test_valid_mcp_tools(self):
257
+ def test_valid_tool_ids(self):
224
258
  """Test ToolId validates MCP remote and local tools."""
225
259
  valid_ids = [
226
260
  # Remote MCP tools
@@ -232,6 +266,9 @@ class TestToolIdPydanticType:
232
266
  # RAG tools
233
267
  "kiln_tool::rag::config1",
234
268
  "kiln_tool::rag::my_rag_config",
269
+ # Kiln task tools
270
+ "kiln_task::server1",
271
+ "kiln_task::my_server",
235
272
  ]
236
273
 
237
274
  for tool_id in valid_ids:
@@ -249,6 +286,8 @@ class TestToolIdPydanticType:
249
286
  "mcp::local::server",
250
287
  "kiln_tool::rag::",
251
288
  "kiln_tool::rag::config::extra",
289
+ "kiln_task::",
290
+ "kiln_task::server::extra",
252
291
  ]
253
292
 
254
293
  for invalid_id in invalid_ids:
@@ -318,3 +357,69 @@ class TestRagConfigIdFromId:
318
357
  # The validation for empty config ID happens in _check_tool_id
319
358
  result = rag_config_id_from_id("kiln_tool::rag::")
320
359
  assert result == ""
360
+
361
+
362
class TestKilnTaskServerIdFromToolId:
    """Test the kiln_task_server_id_from_tool_id function."""

    def test_valid_kiln_task_ids(self):
        """Test parsing valid Kiln task tool IDs."""
        test_cases = [
            ("kiln_task::server1", "server1"),
            ("kiln_task::my_server", "my_server"),
            ("kiln_task::test_server_123", "test_server_123"),
            ("kiln_task::a", "a"),  # Minimal valid case
            ("kiln_task::server_with_underscores", "server_with_underscores"),
            ("kiln_task::server-with-dashes", "server-with-dashes"),
            ("kiln_task::server.with.dots", "server.with.dots"),
        ]

        for tool_id, expected in test_cases:
            result = kiln_task_server_id_from_tool_id(tool_id)
            assert result == expected

    def test_invalid_kiln_task_ids(self):
        """Test parsing fails for invalid Kiln task tool IDs."""
        # Test various invalid formats
        invalid_ids = [
            "kiln_task::",  # Empty server ID
            "kiln_task::server::extra",  # Too many parts
            "kiln_task::server::tool::extra",  # Too many parts
            "wrong::server",  # Wrong prefix
            "kiln_wrong::server",  # Wrong prefix
            "task::server",  # Wrong prefix (missing "kiln_")
            "",  # Empty string
            "single_part",  # No prefix at all
            "kiln_task",  # Prefix without the "::" separator or a server ID
        ]

        for invalid_id in invalid_ids:
            with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
                kiln_task_server_id_from_tool_id(invalid_id)

    def test_kiln_task_id_with_empty_server_id(self):
        """Test that Kiln task tool ID with empty server ID raises error."""
        with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
            kiln_task_server_id_from_tool_id("kiln_task::")

    def test_kiln_task_id_with_whitespace_server_id(self):
        """Test that Kiln task tool ID with whitespace-only server ID raises error."""
        # The server ID must actually contain whitespace here; with nothing
        # after the prefix this would only repeat the empty-server-ID test and
        # never exercise the whitespace check.
        with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
            kiln_task_server_id_from_tool_id("kiln_task::   ")

    def test_kiln_task_id_with_multiple_colons(self):
        """Test that Kiln task tool ID with an extra '::' segment raises error."""
        with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
            kiln_task_server_id_from_tool_id("kiln_task::server::extra")

    def test_kiln_task_id_case_sensitivity(self):
        """Test that the server ID is returned with its original casing."""
        # These should work
        result1 = kiln_task_server_id_from_tool_id("kiln_task::Server")
        assert result1 == "Server"

        result2 = kiln_task_server_id_from_tool_id("kiln_task::SERVER")
        assert result2 == "SERVER"

        result3 = kiln_task_server_id_from_tool_id("kiln_task::server")
        assert result3 == "server"
@@ -3,6 +3,8 @@ from typing import Annotated
3
3
 
4
4
  from pydantic import AfterValidator
5
5
 
6
+ from kiln_ai.datamodel.basemodel import ID_TYPE
7
+
6
8
  ToolId = Annotated[
7
9
  str,
8
10
  AfterValidator(lambda v: _check_tool_id(v)),
@@ -14,6 +16,7 @@ Tool IDs can be one of:
14
16
  - A kiln built-in tool name: kiln_tool::add_numbers
15
17
  - A remote MCP tool: mcp::remote::<server_id>::<tool_name>
16
18
  - A local MCP tool: mcp::local::<server_id>::<tool_name>
19
+ - A Kiln task tool: kiln_task::<server_id>
17
20
  - More coming soon like kiln_project_tool::rag::RAG_CONFIG_ID
18
21
  """
19
22
 
@@ -28,6 +31,7 @@ class KilnBuiltInToolId(str, Enum):
28
31
  MCP_REMOTE_TOOL_ID_PREFIX = "mcp::remote::"
29
32
  RAG_TOOL_ID_PREFIX = "kiln_tool::rag::"
30
33
  MCP_LOCAL_TOOL_ID_PREFIX = "mcp::local::"
34
+ KILN_TASK_TOOL_ID_PREFIX = "kiln_task::"
31
35
 
32
36
 
33
37
  def _check_tool_id(id: str) -> str:
@@ -68,6 +72,15 @@ def _check_tool_id(id: str) -> str:
68
72
  )
69
73
  return id
70
74
 
75
+ # Kiln task tools must have format: kiln_task::<server_id>
76
+ if id.startswith(KILN_TASK_TOOL_ID_PREFIX):
77
+ server_id = kiln_task_server_id_from_tool_id(id)
78
+ if not server_id:
79
+ raise ValueError(
80
+ f"Invalid Kiln task tool ID: {id}. Expected format: 'kiln_task::<server_id>'."
81
+ )
82
+ return id
83
+
71
84
  raise ValueError(f"Invalid tool ID: {id}")
72
85
 
73
86
 
@@ -103,3 +116,39 @@ def rag_config_id_from_id(id: str) -> str:
103
116
  f"Invalid RAG tool ID: {id}. Expected format: 'kiln_tool::rag::<rag_config_id>'."
104
117
  )
105
118
  return parts[2]
119
+
120
+
121
def build_rag_tool_id(rag_config_id: ID_TYPE) -> str:
    """Return the tool ID for a RAG configuration: the RAG tool prefix followed by the config ID."""
    config_part = f"{rag_config_id}"
    return RAG_TOOL_ID_PREFIX + config_part
125
+
126
+
127
def build_kiln_task_tool_id(server_id: ID_TYPE) -> str:
    """Return the tool ID for a Kiln task server: the Kiln task prefix followed by the server ID."""
    server_part = f"{server_id}"
    return KILN_TASK_TOOL_ID_PREFIX + server_part
130
+
131
+
132
def kiln_task_server_id_from_tool_id(tool_id: str) -> str:
    """
    Extract the server ID from a Kiln task tool ID.

    The ID must be the Kiln task prefix followed by a single server ID that is
    not blank and contains no further "::" separator.

    Raises:
        ValueError: If the prefix is missing, the server ID is empty or
            whitespace-only, or extra "::"-separated parts follow it.
    """

    def _format_error() -> ValueError:
        # Every rejection path reports the same message.
        return ValueError(
            f"Invalid Kiln task tool ID format: {tool_id}. Expected format: 'kiln_task::<server_id>'."
        )

    if not tool_id.startswith(KILN_TASK_TOOL_ID_PREFIX):
        raise _format_error()

    # Everything after the prefix must be exactly one non-blank segment.
    server_id = tool_id[len(KILN_TASK_TOOL_ID_PREFIX) :]
    if "::" in server_id or not server_id.strip():
        raise _format_error()

    return server_id
@@ -1,10 +1,34 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Any, Dict
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, TypedDict
3
4
 
4
5
  from kiln_ai.datamodel.json_schema import validate_schema_dict
5
6
  from kiln_ai.datamodel.tool_id import KilnBuiltInToolId, ToolId
6
7
 
7
8
 
9
class ToolFunction(TypedDict):
    """The "function" entry of a tool call definition: its name, description and parameters."""

    # Name of the function.
    name: str
    # Description of the function.
    description: str
    # Parameters of the function, as a dictionary.
    parameters: Dict[str, Any]
15
+
16
+
17
class ToolCallDefinition(TypedDict):
    """An OpenAI-compatible tool call definition: a type tag plus the function it describes."""

    # Must be "function"
    type: str
    # The function's name, description and parameters.
    function: ToolFunction
22
+
23
+
24
@dataclass
class ToolCallContext:
    """Context passed to tools when they are called, containing information from the calling task.

    Attributes:
        allow_saving: Used for Kiln Tasks as Tools, to know if the tool call
            should save the task run it invoked to that task's Dataset.
            Defaults to True.
    """

    allow_saving: bool = True
30
+
31
+
8
32
  class KilnToolInterface(ABC):
9
33
  """
10
34
  Abstract interface defining the core API that all Kiln tools must implement.
@@ -12,12 +36,12 @@ class KilnToolInterface(ABC):
12
36
  """
13
37
 
14
38
  @abstractmethod
15
- async def run(self, **kwargs) -> Any:
16
- """Execute the tool with the given parameters."""
39
+ async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
40
+ """Execute the tool with the given parameters and calling context if provided."""
17
41
  pass
18
42
 
19
43
  @abstractmethod
20
- async def toolcall_definition(self) -> Dict[str, Any]:
44
+ async def toolcall_definition(self) -> ToolCallDefinition:
21
45
  """Return the OpenAI-compatible tool definition for this tool."""
22
46
  pass
23
47
 
@@ -65,7 +89,7 @@ class KilnTool(KilnToolInterface):
65
89
  async def description(self) -> str:
66
90
  return self._description
67
91
 
68
- async def toolcall_definition(self) -> Dict[str, Any]:
92
+ async def toolcall_definition(self) -> ToolCallDefinition:
69
93
  """Generate OpenAI-compatible tool definition."""
70
94
  return {
71
95
  "type": "function",
@@ -77,6 +101,6 @@ class KilnTool(KilnToolInterface):
77
101
  }
78
102
 
79
103
  @abstractmethod
80
- async def run(self, **kwargs) -> str:
104
+ async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
81
105
  """Subclasses must implement the actual tool logic."""
82
106
  pass
@@ -27,7 +27,9 @@ class AddTool(KilnTool):
27
27
  parameters_schema=parameters_schema,
28
28
  )
29
29
 
30
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
30
+ async def run(
31
+ self, context=None, *, a: Union[int, float], b: Union[int, float]
32
+ ) -> str:
31
33
  """Add two numbers and return the result."""
32
34
  return str(a + b)
33
35
 
@@ -57,7 +59,9 @@ class SubtractTool(KilnTool):
57
59
  parameters_schema=parameters_schema,
58
60
  )
59
61
 
60
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
62
+ async def run(
63
+ self, context=None, *, a: Union[int, float], b: Union[int, float]
64
+ ) -> str:
61
65
  """Subtract b from a and return the result."""
62
66
  return str(a - b)
63
67
 
@@ -84,7 +88,9 @@ class MultiplyTool(KilnTool):
84
88
  parameters_schema=parameters_schema,
85
89
  )
86
90
 
87
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
91
+ async def run(
92
+ self, context=None, *, a: Union[int, float], b: Union[int, float]
93
+ ) -> str:
88
94
  """Multiply two numbers and return the result."""
89
95
  return str(a * b)
90
96
 
@@ -117,7 +123,9 @@ class DivideTool(KilnTool):
117
123
  parameters_schema=parameters_schema,
118
124
  )
119
125
 
120
- async def run(self, a: Union[int, float], b: Union[int, float]) -> str:
126
+ async def run(
127
+ self, context=None, *, a: Union[int, float], b: Union[int, float]
128
+ ) -> str:
121
129
  """Divide a by b and return the result."""
122
130
  if b == 0:
123
131
  raise ZeroDivisionError("Cannot divide by zero")
@@ -0,0 +1,162 @@
1
+ from dataclasses import dataclass
2
+ from functools import cached_property
3
+ from typing import Any, Dict
4
+
5
+ from kiln_ai.datamodel import Task
6
+ from kiln_ai.datamodel.external_tool_server import ExternalToolServer
7
+ from kiln_ai.datamodel.task import TaskRunConfig
8
+ from kiln_ai.datamodel.task_output import DataSource, DataSourceType
9
+ from kiln_ai.datamodel.tool_id import ToolId
10
+ from kiln_ai.tools.base_tool import (
11
+ KilnToolInterface,
12
+ ToolCallContext,
13
+ ToolCallDefinition,
14
+ )
15
+ from kiln_ai.utils.project_utils import project_from_id
16
+
17
+
18
@dataclass
class KilnTaskToolResult:
    """Result of running a Kiln task as a tool: the output text plus an identifying data string."""

    # Output of the invoked task run.
    output: str
    # String carrying data about the Kiln task tool call.
    kiln_task_tool_data: str
22
+
23
+
24
class KilnTaskTool(KilnToolInterface):
    """
    A tool that wraps a Kiln task, allowing it to be called as a function.

    This tool loads a task by ID and executes it using the specified run configuration.
    The task and its run configuration are loaded on first use and then cached.
    """

    def __init__(
        self,
        project_id: str,
        tool_id: str,
        data_model: ExternalToolServer,
    ):
        """
        Args:
            project_id: ID of the project that contains the wrapped task.
            tool_id: The tool ID this tool reports from ``id()``.
            data_model: Tool server whose ``properties`` supply the tool's
                name, description, task ID and run config ID.
        """
        self._project_id = project_id
        self._tool_server_model = data_model
        self._tool_id = tool_id

        # Missing properties fall back to empty strings.
        properties = data_model.properties
        self._name = properties.get("name", "")
        self._description = properties.get("description", "")
        self._task_id = properties.get("task_id", "")
        self._run_config_id = properties.get("run_config_id", "")

    async def id(self) -> ToolId:
        return self._tool_id

    async def name(self) -> str:
        return self._name

    async def description(self) -> str:
        return self._description

    async def toolcall_definition(self) -> ToolCallDefinition:
        """Generate OpenAI-compatible tool definition."""
        return {
            "type": "function",
            "function": {
                "name": await self.name(),
                "description": await self.description(),
                "parameters": self.parameters_schema,
            },
        }

    async def run(
        self, context: ToolCallContext | None = None, **kwargs
    ) -> KilnTaskToolResult:
        """
        Execute the wrapped Kiln task with the given parameters and calling context.

        Args:
            context: Calling context; required. Its ``allow_saving`` flag is
                forwarded to the adapter that runs the task.
            **kwargs: The tool call arguments. For a task with an input JSON
                schema they are passed through as the structured input;
                otherwise the ``input`` argument is used as plaintext input.

        Returns:
            The task run's output together with a data string made of the
            project ID, tool ID, task ID and task run ID.

        Raises:
            ValueError: If no context is given, or a plaintext task is called
                without an ``input`` argument.
        """
        if context is None:
            raise ValueError("Context is required for running a KilnTaskTool.")

        # Determine the input format. (Named task_input so the builtin
        # input() is not shadowed.)
        if self._task.input_json_schema:
            # Structured input - pass kwargs directly
            task_input: Any = kwargs
        elif "input" in kwargs:
            # Plaintext input - extract from 'input' parameter
            task_input = kwargs["input"]
        else:
            raise ValueError(f"Input not found in kwargs: {kwargs}")

        # These imports are here to avoid circular chains
        from kiln_ai.adapters.adapter_registry import adapter_for_task
        from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig

        run_config_properties = self._run_config.run_config_properties

        # Create adapter and run the task using the calling task's allow_saving setting
        adapter = adapter_for_task(
            self._task,
            run_config_properties=run_config_properties,
            base_adapter_config=AdapterConfig(
                allow_saving=context.allow_saving,
                default_tags=["tool_call"],
            ),
        )
        task_run = await adapter.invoke(
            task_input,
            input_source=DataSource(
                type=DataSourceType.tool_call,
                run_config=run_config_properties,
            ),
        )

        return KilnTaskToolResult(
            output=task_run.output.output,
            kiln_task_tool_data=f"{self._project_id}:::{self._tool_id}:::{self._task.id}:::{task_run.id}",
        )

    @cached_property
    def _task(self) -> Task:
        """The wrapped task, loaded from the project; raises ValueError if either is missing."""
        # Load the project first
        project = project_from_id(self._project_id)
        if project is None:
            raise ValueError(f"Project not found: {self._project_id}")

        # Load the task from the project
        task = Task.from_id_and_parent_path(self._task_id, project.path)
        if task is None:
            raise ValueError(
                f"Task not found: {self._task_id} in project {self._project_id}"
            )
        return task

    @cached_property
    def _run_config(self) -> TaskRunConfig:
        """The task's run config whose ID matches the configured one; raises ValueError if none does."""
        for candidate in self._task.run_configs(readonly=True):
            if candidate.id == self._run_config_id:
                return candidate
        raise ValueError(
            f"Task run config not found: {self._run_config_id} for task {self._task_id} in project {self._project_id}"
        )

    @cached_property
    def parameters_schema(self) -> Dict[str, Any]:
        """
        JSON schema describing the tool's parameters.

        Tasks with an input JSON schema expose that schema; plaintext tasks
        expose a single required string parameter named ``input``.

        Raises:
            ValueError: If the task has an input JSON schema but its
                ``input_schema()`` returns None.
        """
        if self._task.input_json_schema:
            # Use the task's input schema directly if it exists
            task_schema = self._task.input_schema()
            if task_schema is None:
                raise ValueError(
                    f"Failed to create parameters schema for tool_id {self._tool_id}"
                )
            return task_schema

        # For plaintext tasks, create a simple string input parameter
        return {
            "type": "object",
            "properties": {
                "input": {
                    "type": "string",
                    "description": "Plaintext input for the tool.",
                }
            },
            "required": ["input"],
        }
@@ -1,11 +1,13 @@
1
- from typing import Any, Dict
2
-
3
1
  from mcp.types import CallToolResult, TextContent
4
2
  from mcp.types import Tool as MCPTool
5
3
 
6
4
  from kiln_ai.datamodel.external_tool_server import ExternalToolServer
7
5
  from kiln_ai.datamodel.tool_id import MCP_REMOTE_TOOL_ID_PREFIX, ToolId
8
- from kiln_ai.tools.base_tool import KilnToolInterface
6
+ from kiln_ai.tools.base_tool import (
7
+ KilnToolInterface,
8
+ ToolCallContext,
9
+ ToolCallDefinition,
10
+ )
9
11
  from kiln_ai.tools.mcp_session_manager import MCPSessionManager
10
12
 
11
13
 
@@ -26,7 +28,7 @@ class MCPServerTool(KilnToolInterface):
26
28
  await self._load_tool_properties()
27
29
  return self._description
28
30
 
29
- async def toolcall_definition(self) -> Dict[str, Any]:
31
+ async def toolcall_definition(self) -> ToolCallDefinition:
30
32
  """Generate OpenAI-compatible tool definition."""
31
33
  await self._load_tool_properties()
32
34
  return {
@@ -38,7 +40,7 @@ class MCPServerTool(KilnToolInterface):
38
40
  },
39
41
  }
40
42
 
41
- async def run(self, **kwargs) -> Any:
43
+ async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
42
44
  result = await self._call_tool(**kwargs)
43
45
 
44
46
  if result.isError: