kiln-ai 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
- kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
- kiln_ai/adapters/ml_embedding_model_list.py +330 -28
- kiln_ai/adapters/ml_model_list.py +503 -23
- kiln_ai/adapters/model_adapters/litellm_adapter.py +39 -8
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
- kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
- kiln_ai/adapters/test_ml_model_list.py +0 -10
- kiln_ai/adapters/vector_store/lancedb_adapter.py +24 -70
- kiln_ai/adapters/vector_store/lancedb_helpers.py +101 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +9 -16
- kiln_ai/adapters/vector_store/test_lancedb_helpers.py +142 -0
- kiln_ai/adapters/vector_store_loaders/__init__.py +0 -0
- kiln_ai/adapters/vector_store_loaders/test_lancedb_loader.py +282 -0
- kiln_ai/adapters/vector_store_loaders/test_vector_store_loader.py +544 -0
- kiln_ai/adapters/vector_store_loaders/vector_store_loader.py +91 -0
- kiln_ai/datamodel/basemodel.py +31 -3
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +14 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +3 -3
- kiln_ai/datamodel/test_basemodel.py +269 -13
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_model.py +31 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +106 -1
- kiln_ai/datamodel/tool_id.py +49 -0
- kiln_ai/tools/base_tool.py +30 -6
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +162 -0
- kiln_ai/tools/mcp_server_tool.py +7 -5
- kiln_ai/tools/mcp_session_manager.py +50 -24
- kiln_ai/tools/rag_tools.py +17 -6
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +186 -226
- kiln_ai/tools/test_rag_tools.py +86 -5
- kiln_ai/tools/test_tool_registry.py +199 -5
- kiln_ai/tools/tool_registry.py +49 -17
- kiln_ai/utils/filesystem.py +4 -4
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +21 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +14 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/METADATA +79 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/RECORD +53 -45
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -468,3 +468,34 @@ def test_document_invalid_mime_type(
|
|
|
468
468
|
)
|
|
469
469
|
def test_get_kind_from_mime_type(mime_type, expected_kind):
|
|
470
470
|
assert get_kind_from_mime_type(mime_type) == expected_kind
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def test_document_friendly_name(mock_project, mock_attachment_factory):
    """friendly_name falls back to name until name_override is set; the override survives save/reload."""
    doc_name = f"Test Document {uuid.uuid4()!s}"
    description = f"Test description {uuid.uuid4()!s}"
    source_file = FileInfo(
        filename=f"test_{doc_name}.txt",
        size=100,
        mime_type="text/plain",
        attachment=mock_attachment_factory(),
    )
    document = Document(
        name=doc_name,
        description=description,
        kind=Kind.DOCUMENT,
        original_file=source_file,
        parent=mock_project,
    )
    document.save_to_file()

    # Older documents were saved without name_override; they must still show their name.
    assert document.name_override is None
    assert document.friendly_name == doc_name

    # Once an override is set it takes precedence over the name.
    override = "Test Document Override"
    document.name_override = override
    assert document.friendly_name == override

    # The override must persist through a save and reload from disk.
    document.save_to_file()
    reloaded = Document.from_id_and_parent_path(str(document.id), mock_project.path)
    assert reloaded is not None
    assert reloaded.friendly_name == override
|
kiln_ai/datamodel/test_task.py
CHANGED
|
@@ -254,7 +254,7 @@ def test_run_config_upgrade_old_entries():
|
|
|
254
254
|
},
|
|
255
255
|
"prompt": {
|
|
256
256
|
"name": "Dazzling Unicorn",
|
|
257
|
-
"description": "Frozen copy of prompt 'simple_prompt_builder'
|
|
257
|
+
"description": "Frozen copy of prompt 'simple_prompt_builder'.",
|
|
258
258
|
"generator_id": "simple_prompt_builder",
|
|
259
259
|
"prompt": "Generate a joke, given a theme. The theme will be provided as a word or phrase as the input to the model. The assistant should output a joke that is funny and relevant to the theme. If a style is provided, the joke should be in that style. The output should include a setup and punchline.\n\nYour response should respect the following requirements:\n1) Keep the joke on topic. If the user specifies a theme, the joke must be related to that theme.\n2) Avoid any jokes that are offensive or inappropriate. Keep the joke clean and appropriate for all audiences.\n3) Make the joke funny and engaging. It should be something that someone would want to tell to their friends. Something clever, not just a simple pun.\n",
|
|
260
260
|
"chain_of_thought_instructions": None,
|
|
@@ -296,3 +296,37 @@ def test_run_config_upgrade_old_entries():
|
|
|
296
296
|
def test_task_name_unicode_name():
    """Task names may contain non-ASCII (CJK) characters and are stored unchanged."""
    unicode_name = "你好"
    task = Task(name=unicode_name, instruction="Do something")
    assert task.name == unicode_name
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def test_task_default_run_config_id_property(tmp_path):
    """Test that default_run_config_id can be set to a real run config's ID, read back, and cleared."""

    # Create and persist a task so the run config below has a parent on disk
    task = Task(
        name="Test Task", instruction="Test instruction", path=tmp_path / "task.kiln"
    )
    task.save_to_file()

    # Create a run config for the task; the default is pointed at its generated ID
    run_config = TaskRunConfig(
        name="Test Config",
        run_config_properties=RunConfigProperties(
            model_name="gpt-4",
            model_provider_name="openai",
            prompt_id=PromptGenerators.SIMPLE,
            structured_output_mode=StructuredOutputMode.json_schema,
        ),
        parent=task,
    )
    run_config.save_to_file()
    assert run_config.id is not None

    # No default run config is selected initially
    assert task.default_run_config_id is None

    # Point the default at the run config created above (not an arbitrary literal)
    task.default_run_config_id = run_config.id
    assert task.default_run_config_id == run_config.id

    # Clearing the default back to None is allowed
    task.default_run_config_id = None
    assert task.default_run_config_id is None
|
|
@@ -8,6 +8,7 @@ from kiln_ai.datamodel.tool_id import (
|
|
|
8
8
|
KilnBuiltInToolId,
|
|
9
9
|
ToolId,
|
|
10
10
|
_check_tool_id,
|
|
11
|
+
kiln_task_server_id_from_tool_id,
|
|
11
12
|
mcp_server_and_tool_name_from_id,
|
|
12
13
|
rag_config_id_from_id,
|
|
13
14
|
)
|
|
@@ -145,6 +146,39 @@ class TestCheckToolId:
|
|
|
145
146
|
with pytest.raises(ValueError, match="Invalid RAG tool ID"):
|
|
146
147
|
_check_tool_id("kiln_tool::rag::")
|
|
147
148
|
|
|
149
|
+
def test_valid_kiln_task_tools(self):
    """_check_tool_id accepts well-formed Kiln task tool IDs and returns them unchanged."""
    server_ids = (
        "server1",
        "my_server",
        "test_server_123",
        "server_with_underscores",
        "server-with-dashes",
        "server.with.dots",
    )
    for server_id in server_ids:
        tool_id = f"kiln_task::{server_id}"
        assert _check_tool_id(tool_id) == tool_id
|
|
162
|
+
|
|
163
|
+
def test_invalid_kiln_task_format(self):
    """_check_tool_id rejects IDs that carry the Kiln task prefix but are malformed."""
    malformed_ids = (
        "kiln_task::",  # server ID missing entirely
        "kiln_task::server::extra",  # one extra "::" segment
        "kiln_task::server::tool::extra",  # several extra "::" segments
    )
    for malformed_id in malformed_ids:
        with pytest.raises(ValueError, match="Invalid Kiln task tool ID"):
            _check_tool_id(malformed_id)
|
|
175
|
+
|
|
176
|
+
def test_kiln_task_tool_empty_server_id(self):
    """Test that Kiln task tool IDs with an empty or blank server ID are rejected."""
    # kiln_task_server_id_from_tool_id raises ValueError (it never returns an empty
    # string) when the server ID is missing or whitespace-only; _check_tool_id must
    # surface that as a ValueError mentioning an invalid Kiln task tool ID.
    for invalid_id in ("kiln_task::", "kiln_task::   "):
        with pytest.raises(ValueError, match="Invalid Kiln task tool ID"):
            _check_tool_id(invalid_id)
|
|
181
|
+
|
|
148
182
|
|
|
149
183
|
class TestMcpServerAndToolNameFromId:
|
|
150
184
|
"""Test the mcp_server_and_tool_name_from_id function."""
|
|
@@ -220,7 +254,7 @@ class TestToolIdPydanticType:
|
|
|
220
254
|
model = self._ModelWithToolId(tool_id=tool_id.value)
|
|
221
255
|
assert model.tool_id == tool_id.value
|
|
222
256
|
|
|
223
|
-
def
|
|
257
|
+
def test_valid_tool_ids(self):
|
|
224
258
|
"""Test ToolId validates MCP remote and local tools."""
|
|
225
259
|
valid_ids = [
|
|
226
260
|
# Remote MCP tools
|
|
@@ -232,6 +266,9 @@ class TestToolIdPydanticType:
|
|
|
232
266
|
# RAG tools
|
|
233
267
|
"kiln_tool::rag::config1",
|
|
234
268
|
"kiln_tool::rag::my_rag_config",
|
|
269
|
+
# Kiln task tools
|
|
270
|
+
"kiln_task::server1",
|
|
271
|
+
"kiln_task::my_server",
|
|
235
272
|
]
|
|
236
273
|
|
|
237
274
|
for tool_id in valid_ids:
|
|
@@ -249,6 +286,8 @@ class TestToolIdPydanticType:
|
|
|
249
286
|
"mcp::local::server",
|
|
250
287
|
"kiln_tool::rag::",
|
|
251
288
|
"kiln_tool::rag::config::extra",
|
|
289
|
+
"kiln_task::",
|
|
290
|
+
"kiln_task::server::extra",
|
|
252
291
|
]
|
|
253
292
|
|
|
254
293
|
for invalid_id in invalid_ids:
|
|
@@ -318,3 +357,69 @@ class TestRagConfigIdFromId:
|
|
|
318
357
|
# The validation for empty config ID happens in _check_tool_id
|
|
319
358
|
result = rag_config_id_from_id("kiln_tool::rag::")
|
|
320
359
|
assert result == ""
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
class TestKilnTaskServerIdFromToolId:
    """Test the kiln_task_server_id_from_tool_id function."""

    def test_valid_kiln_task_ids(self):
        """Test parsing valid Kiln task tool IDs."""
        test_cases = [
            ("kiln_task::server1", "server1"),
            ("kiln_task::my_server", "my_server"),
            ("kiln_task::test_server_123", "test_server_123"),
            ("kiln_task::a", "a"),  # Minimal valid case
            ("kiln_task::server_with_underscores", "server_with_underscores"),
            ("kiln_task::server-with-dashes", "server-with-dashes"),
            ("kiln_task::server.with.dots", "server.with.dots"),
        ]

        for tool_id, expected in test_cases:
            result = kiln_task_server_id_from_tool_id(tool_id)
            assert result == expected

    def test_invalid_kiln_task_ids(self):
        """Test parsing fails for invalid Kiln task tool IDs."""
        # Test various invalid formats
        invalid_ids = [
            "kiln_task::",  # Empty server ID
            "kiln_task::server::extra",  # Too many parts (3 parts)
            "kiln_task::server::tool::extra",  # Too many parts (4 parts)
            "wrong::server",  # Wrong prefix
            "kiln_wrong::server",  # Wrong prefix
            "task::server",  # Too few parts (2 parts)
            "",  # Empty string
            "single_part",  # Only 1 part
            "kiln_task",  # Missing server ID
        ]

        for invalid_id in invalid_ids:
            with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
                kiln_task_server_id_from_tool_id(invalid_id)

    def test_kiln_task_id_with_empty_server_id(self):
        """Test that Kiln task tool ID with empty server ID raises error."""
        with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
            kiln_task_server_id_from_tool_id("kiln_task::")

    def test_kiln_task_id_with_whitespace_server_id(self):
        """Test that Kiln task tool ID with whitespace-only server ID raises error."""
        # Previously this passed "kiln_task::" (no whitespace), which only repeated
        # the empty case above; exercise genuinely blank server IDs instead.
        for tool_id in ("kiln_task:: ", "kiln_task::   ", "kiln_task::\t"):
            with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
                kiln_task_server_id_from_tool_id(tool_id)

    def test_kiln_task_id_with_multiple_colons(self):
        """Test that Kiln task tool ID with multiple colons raises error."""
        with pytest.raises(ValueError, match="Invalid Kiln task tool ID format"):
            kiln_task_server_id_from_tool_id("kiln_task::server::extra")

    def test_kiln_task_id_case_sensitivity(self):
        """Test that Kiln task tool IDs are case sensitive."""
        # These should work
        result1 = kiln_task_server_id_from_tool_id("kiln_task::Server")
        assert result1 == "Server"

        result2 = kiln_task_server_id_from_tool_id("kiln_task::SERVER")
        assert result2 == "SERVER"

        result3 = kiln_task_server_id_from_tool_id("kiln_task::server")
        assert result3 == "server"
|
kiln_ai/datamodel/tool_id.py
CHANGED
|
@@ -3,6 +3,8 @@ from typing import Annotated
|
|
|
3
3
|
|
|
4
4
|
from pydantic import AfterValidator
|
|
5
5
|
|
|
6
|
+
from kiln_ai.datamodel.basemodel import ID_TYPE
|
|
7
|
+
|
|
6
8
|
ToolId = Annotated[
|
|
7
9
|
str,
|
|
8
10
|
AfterValidator(lambda v: _check_tool_id(v)),
|
|
@@ -14,6 +16,7 @@ Tool IDs can be one of:
|
|
|
14
16
|
- A kiln built-in tool name: kiln_tool::add_numbers
|
|
15
17
|
- A remote MCP tool: mcp::remote::<server_id>::<tool_name>
|
|
16
18
|
- A local MCP tool: mcp::local::<server_id>::<tool_name>
|
|
19
|
+
- A Kiln task tool: kiln_task::<server_id>
|
|
17
20
|
- More coming soon like kiln_project_tool::rag::RAG_CONFIG_ID
|
|
18
21
|
"""
|
|
19
22
|
|
|
@@ -28,6 +31,7 @@ class KilnBuiltInToolId(str, Enum):
|
|
|
28
31
|
MCP_REMOTE_TOOL_ID_PREFIX = "mcp::remote::"
|
|
29
32
|
RAG_TOOL_ID_PREFIX = "kiln_tool::rag::"
|
|
30
33
|
MCP_LOCAL_TOOL_ID_PREFIX = "mcp::local::"
|
|
34
|
+
KILN_TASK_TOOL_ID_PREFIX = "kiln_task::"
|
|
31
35
|
|
|
32
36
|
|
|
33
37
|
def _check_tool_id(id: str) -> str:
|
|
@@ -68,6 +72,15 @@ def _check_tool_id(id: str) -> str:
|
|
|
68
72
|
)
|
|
69
73
|
return id
|
|
70
74
|
|
|
75
|
+
# Kiln task tools must have format: kiln_task::<server_id>
|
|
76
|
+
if id.startswith(KILN_TASK_TOOL_ID_PREFIX):
|
|
77
|
+
server_id = kiln_task_server_id_from_tool_id(id)
|
|
78
|
+
if not server_id:
|
|
79
|
+
raise ValueError(
|
|
80
|
+
f"Invalid Kiln task tool ID: {id}. Expected format: 'kiln_task::<server_id>'."
|
|
81
|
+
)
|
|
82
|
+
return id
|
|
83
|
+
|
|
71
84
|
raise ValueError(f"Invalid tool ID: {id}")
|
|
72
85
|
|
|
73
86
|
|
|
@@ -103,3 +116,39 @@ def rag_config_id_from_id(id: str) -> str:
|
|
|
103
116
|
f"Invalid RAG tool ID: {id}. Expected format: 'kiln_tool::rag::<rag_config_id>'."
|
|
104
117
|
)
|
|
105
118
|
return parts[2]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def build_rag_tool_id(rag_config_id: ID_TYPE) -> str:
    """
    Build the tool ID that refers to a RAG configuration.

    The result is RAG_TOOL_ID_PREFIX followed by the config ID rendered as text,
    e.g. ``kiln_tool::rag::<rag_config_id>``.
    """
    return RAG_TOOL_ID_PREFIX + str(rag_config_id)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def build_kiln_task_tool_id(server_id: ID_TYPE) -> str:
    """
    Build the tool ID that refers to a Kiln task tool server.

    The result is KILN_TASK_TOOL_ID_PREFIX followed by the server ID rendered as
    text, e.g. ``kiln_task::<server_id>``.
    """
    return KILN_TASK_TOOL_ID_PREFIX + str(server_id)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def kiln_task_server_id_from_tool_id(tool_id: str) -> str:
    """
    Get the server ID from a Kiln task tool ID.

    Kiln task tool IDs have the form ``kiln_task::<server_id>``: the prefix
    followed by exactly one non-blank segment containing no further ``::``.

    Args:
        tool_id: The tool ID to parse.

    Returns:
        The ``<server_id>`` segment, exactly as written (it is not stripped).

    Raises:
        ValueError: If the prefix is missing, the server ID is empty or
            whitespace-only, or the remainder contains another ``::`` separator.
    """
    if tool_id.startswith(KILN_TASK_TOOL_ID_PREFIX):
        server_id = tool_id[len(KILN_TASK_TOOL_ID_PREFIX) :]
        # A single non-blank segment is required; an empty string also fails strip().
        if server_id.strip() and "::" not in server_id:
            return server_id

    # One error message for every malformed case (previously built three times).
    raise ValueError(
        f"Invalid Kiln task tool ID format: {tool_id}. Expected format: 'kiln_task::<server_id>'."
    )
|
kiln_ai/tools/base_tool.py
CHANGED
|
@@ -1,10 +1,34 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Dict, TypedDict
|
|
3
4
|
|
|
4
5
|
from kiln_ai.datamodel.json_schema import validate_schema_dict
|
|
5
6
|
from kiln_ai.datamodel.tool_id import KilnBuiltInToolId, ToolId
|
|
6
7
|
|
|
7
8
|
|
|
9
|
+
class ToolFunction(TypedDict):
    """Typed dict for the function definition within a tool call definition.

    This is the value stored under the ``"function"`` key of a ToolCallDefinition.
    """

    name: str  # the tool's name
    description: str  # human-readable description of what the tool does
    parameters: Dict[str, Any]  # JSON schema object describing the tool's arguments
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ToolCallDefinition(TypedDict):
    """Typed dict for OpenAI-compatible tool call definitions.

    Shape: ``{"type": "function", "function": {name, description, parameters}}``,
    as returned by ``toolcall_definition()`` implementations.
    """

    type: str  # Must be "function"
    function: ToolFunction  # the tool's name, description and parameter schema
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class ToolCallContext:
    """Context passed to tools when they are called, containing information from the calling task."""

    # Used for Kiln Tasks as Tools, to know if the tool call should save the task
    # run it invoked to that task's Dataset. (Previously a bare string statement
    # placed before the field, which attached to nothing and was dead code.)
    allow_saving: bool = True
|
|
30
|
+
|
|
31
|
+
|
|
8
32
|
class KilnToolInterface(ABC):
|
|
9
33
|
"""
|
|
10
34
|
Abstract interface defining the core API that all Kiln tools must implement.
|
|
@@ -12,12 +36,12 @@ class KilnToolInterface(ABC):
|
|
|
12
36
|
"""
|
|
13
37
|
|
|
14
38
|
@abstractmethod
|
|
15
|
-
async def run(self, **kwargs) -> Any:
|
|
16
|
-
"""Execute the tool with the given parameters."""
|
|
39
|
+
async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
|
|
40
|
+
"""Execute the tool with the given parameters and calling context if provided."""
|
|
17
41
|
pass
|
|
18
42
|
|
|
19
43
|
@abstractmethod
|
|
20
|
-
async def toolcall_definition(self) ->
|
|
44
|
+
async def toolcall_definition(self) -> ToolCallDefinition:
|
|
21
45
|
"""Return the OpenAI-compatible tool definition for this tool."""
|
|
22
46
|
pass
|
|
23
47
|
|
|
@@ -65,7 +89,7 @@ class KilnTool(KilnToolInterface):
|
|
|
65
89
|
async def description(self) -> str:
|
|
66
90
|
return self._description
|
|
67
91
|
|
|
68
|
-
async def toolcall_definition(self) ->
|
|
92
|
+
async def toolcall_definition(self) -> ToolCallDefinition:
|
|
69
93
|
"""Generate OpenAI-compatible tool definition."""
|
|
70
94
|
return {
|
|
71
95
|
"type": "function",
|
|
@@ -77,6 +101,6 @@ class KilnTool(KilnToolInterface):
|
|
|
77
101
|
}
|
|
78
102
|
|
|
79
103
|
@abstractmethod
|
|
80
|
-
async def run(self, **kwargs) ->
|
|
104
|
+
async def run(self, context: ToolCallContext | None = None, **kwargs) -> Any:
|
|
81
105
|
"""Subclasses must implement the actual tool logic."""
|
|
82
106
|
pass
|
|
@@ -27,7 +27,9 @@ class AddTool(KilnTool):
|
|
|
27
27
|
parameters_schema=parameters_schema,
|
|
28
28
|
)
|
|
29
29
|
|
|
30
|
-
async def run(
|
|
30
|
+
async def run(
    self, context=None, *, a: Union[int, float], b: Union[int, float]
) -> str:
    """Add two numbers and return the result.

    ``context`` is accepted for interface compatibility and is not used here.
    """
    total = a + b
    return str(total)
|
|
33
35
|
|
|
@@ -57,7 +59,9 @@ class SubtractTool(KilnTool):
|
|
|
57
59
|
parameters_schema=parameters_schema,
|
|
58
60
|
)
|
|
59
61
|
|
|
60
|
-
async def run(
|
|
62
|
+
async def run(
    self, context=None, *, a: Union[int, float], b: Union[int, float]
) -> str:
    """Subtract b from a and return the result.

    ``context`` is accepted for interface compatibility and is not used here.
    """
    difference = a - b
    return str(difference)
|
|
63
67
|
|
|
@@ -84,7 +88,9 @@ class MultiplyTool(KilnTool):
|
|
|
84
88
|
parameters_schema=parameters_schema,
|
|
85
89
|
)
|
|
86
90
|
|
|
87
|
-
async def run(
|
|
91
|
+
async def run(
    self, context=None, *, a: Union[int, float], b: Union[int, float]
) -> str:
    """Multiply two numbers and return the result.

    ``context`` is accepted for interface compatibility and is not used here.
    """
    product = a * b
    return str(product)
|
|
90
96
|
|
|
@@ -117,7 +123,9 @@ class DivideTool(KilnTool):
|
|
|
117
123
|
parameters_schema=parameters_schema,
|
|
118
124
|
)
|
|
119
125
|
|
|
120
|
-
async def run(
|
|
126
|
+
async def run(
|
|
127
|
+
self, context=None, *, a: Union[int, float], b: Union[int, float]
|
|
128
|
+
) -> str:
|
|
121
129
|
"""Divide a by b and return the result."""
|
|
122
130
|
if b == 0:
|
|
123
131
|
raise ZeroDivisionError("Cannot divide by zero")
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
from kiln_ai.datamodel import Task
|
|
6
|
+
from kiln_ai.datamodel.external_tool_server import ExternalToolServer
|
|
7
|
+
from kiln_ai.datamodel.task import TaskRunConfig
|
|
8
|
+
from kiln_ai.datamodel.task_output import DataSource, DataSourceType
|
|
9
|
+
from kiln_ai.datamodel.tool_id import ToolId
|
|
10
|
+
from kiln_ai.tools.base_tool import (
|
|
11
|
+
KilnToolInterface,
|
|
12
|
+
ToolCallContext,
|
|
13
|
+
ToolCallDefinition,
|
|
14
|
+
)
|
|
15
|
+
from kiln_ai.utils.project_utils import project_from_id
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class KilnTaskToolResult:
    """Result returned by KilnTaskTool.run."""

    # The invoked task run's output (task_run.output.output).
    output: str
    # Identifies the invoked run as "project_id:::tool_id:::task_id:::task_run_id".
    kiln_task_tool_data: str
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class KilnTaskTool(KilnToolInterface):
    """
    Expose a saved Kiln task as a callable tool.

    The task and run configuration are named by the ``task_id`` and
    ``run_config_id`` entries of the backing ExternalToolServer's properties, and
    are loaded lazily from the project the first time they are needed.
    """

    def __init__(
        self,
        project_id: str,
        tool_id: str,
        data_model: ExternalToolServer,
    ):
        self._project_id = project_id
        self._tool_server_model = data_model
        self._tool_id = tool_id

        # Missing properties default to empty strings; nothing is validated here.
        props = data_model.properties
        self._name = props.get("name", "")
        self._description = props.get("description", "")
        self._task_id = props.get("task_id", "")
        self._run_config_id = props.get("run_config_id", "")

    async def id(self) -> ToolId:
        return self._tool_id

    async def name(self) -> str:
        return self._name

    async def description(self) -> str:
        return self._description

    async def toolcall_definition(self) -> ToolCallDefinition:
        """Build the OpenAI-compatible function definition for this tool."""
        tool_name = await self.name()
        tool_description = await self.description()
        return {
            "type": "function",
            "function": {
                "name": tool_name,
                "description": tool_description,
                "parameters": self.parameters_schema,
            },
        }

    async def run(
        self, context: ToolCallContext | None = None, **kwargs
    ) -> KilnTaskToolResult:
        """Invoke the wrapped task with the tool-call arguments.

        Raises:
            ValueError: If no context is given, or a plaintext task is called
                without an ``input`` argument.
        """
        if context is None:
            raise ValueError("Context is required for running a KilnTaskTool.")

        if self._task.input_json_schema:
            # Structured task: the tool-call arguments are the task input object.
            task_input = kwargs
        elif "input" in kwargs:
            # Plaintext task: the single "input" argument carries the text.
            task_input = kwargs["input"]
        else:
            raise ValueError(f"Input not found in kwargs: {kwargs}")

        # Imported here rather than at module level to avoid circular imports.
        from kiln_ai.adapters.adapter_registry import adapter_for_task
        from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig

        run_config_properties = self._run_config.run_config_properties
        # The nested run inherits the caller's allow_saving policy.
        adapter_config = AdapterConfig(
            allow_saving=context.allow_saving,
            default_tags=["tool_call"],
        )
        adapter = adapter_for_task(
            self._task,
            run_config_properties=run_config_properties,
            base_adapter_config=adapter_config,
        )
        source = DataSource(
            type=DataSourceType.tool_call,
            run_config=run_config_properties,
        )
        task_run = await adapter.invoke(task_input, input_source=source)

        return KilnTaskToolResult(
            output=task_run.output.output,
            kiln_task_tool_data=f"{self._project_id}:::{self._tool_id}:::{self._task.id}:::{task_run.id}",
        )

    @cached_property
    def _task(self) -> Task:
        """The wrapped task, loaded from its project on first access."""
        project = project_from_id(self._project_id)
        if project is None:
            raise ValueError(f"Project not found: {self._project_id}")

        loaded = Task.from_id_and_parent_path(self._task_id, project.path)
        if loaded is None:
            raise ValueError(
                f"Task not found: {self._task_id} in project {self._project_id}"
            )
        return loaded

    @cached_property
    def _run_config(self) -> TaskRunConfig:
        """The task run config whose ID matches ``run_config_id``."""
        for candidate in self._task.run_configs(readonly=True):
            if candidate.id == self._run_config_id:
                return candidate
        raise ValueError(
            f"Task run config not found: {self._run_config_id} for task {self._task_id} in project {self._project_id}"
        )

    @cached_property
    def parameters_schema(self) -> Dict[str, Any]:
        """JSON schema describing the arguments this tool accepts."""
        if not self._task.input_json_schema:
            # Plaintext task: a single required string argument named "input".
            return {
                "type": "object",
                "properties": {
                    "input": {
                        "type": "string",
                        "description": "Plaintext input for the tool.",
                    }
                },
                "required": ["input"],
            }

        # Structured task: reuse the task's own input schema.
        schema = self._task.input_schema()
        if schema is None:
            raise ValueError(
                f"Failed to create parameters schema for tool_id {self._tool_id}"
            )
        return schema
|
kiln_ai/tools/mcp_server_tool.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
|
-
from typing import Any, Dict
|
|
2
|
-
|
|
3
1
|
from mcp.types import CallToolResult, TextContent
|
|
4
2
|
from mcp.types import Tool as MCPTool
|
|
5
3
|
|
|
6
4
|
from kiln_ai.datamodel.external_tool_server import ExternalToolServer
|
|
7
5
|
from kiln_ai.datamodel.tool_id import MCP_REMOTE_TOOL_ID_PREFIX, ToolId
|
|
8
|
-
from kiln_ai.tools.base_tool import
|
|
6
|
+
from kiln_ai.tools.base_tool import (
|
|
7
|
+
KilnToolInterface,
|
|
8
|
+
ToolCallContext,
|
|
9
|
+
ToolCallDefinition,
|
|
10
|
+
)
|
|
9
11
|
from kiln_ai.tools.mcp_session_manager import MCPSessionManager
|
|
10
12
|
|
|
11
13
|
|
|
@@ -26,7 +28,7 @@ class MCPServerTool(KilnToolInterface):
|
|
|
26
28
|
await self._load_tool_properties()
|
|
27
29
|
return self._description
|
|
28
30
|
|
|
29
|
-
async def toolcall_definition(self) ->
|
|
31
|
+
async def toolcall_definition(self) -> ToolCallDefinition:
|
|
30
32
|
"""Generate OpenAI-compatible tool definition."""
|
|
31
33
|
await self._load_tool_properties()
|
|
32
34
|
return {
|
|
@@ -38,7 +40,7 @@ class MCPServerTool(KilnToolInterface):
|
|
|
38
40
|
},
|
|
39
41
|
}
|
|
40
42
|
|
|
41
|
-
async def run(self, **kwargs) ->
|
|
43
|
+
async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
|
|
42
44
|
result = await self._call_tool(**kwargs)
|
|
43
45
|
|
|
44
46
|
if result.isError:
|