kiln-ai 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (45)
  1. kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
  2. kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
  3. kiln_ai/adapters/ml_embedding_model_list.py +330 -28
  4. kiln_ai/adapters/ml_model_list.py +503 -23
  5. kiln_ai/adapters/model_adapters/litellm_adapter.py +34 -7
  6. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
  7. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  8. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  9. kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
  10. kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
  11. kiln_ai/adapters/test_ml_model_list.py +0 -10
  12. kiln_ai/datamodel/basemodel.py +31 -3
  13. kiln_ai/datamodel/external_tool_server.py +206 -54
  14. kiln_ai/datamodel/extraction.py +14 -0
  15. kiln_ai/datamodel/task.py +5 -0
  16. kiln_ai/datamodel/task_output.py +41 -11
  17. kiln_ai/datamodel/test_attachment.py +3 -3
  18. kiln_ai/datamodel/test_basemodel.py +269 -13
  19. kiln_ai/datamodel/test_datasource.py +50 -0
  20. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  21. kiln_ai/datamodel/test_extraction_model.py +31 -0
  22. kiln_ai/datamodel/test_task.py +35 -1
  23. kiln_ai/datamodel/test_tool_id.py +106 -1
  24. kiln_ai/datamodel/tool_id.py +36 -0
  25. kiln_ai/tools/base_tool.py +12 -3
  26. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  27. kiln_ai/tools/kiln_task_tool.py +158 -0
  28. kiln_ai/tools/mcp_server_tool.py +2 -2
  29. kiln_ai/tools/mcp_session_manager.py +50 -24
  30. kiln_ai/tools/rag_tools.py +12 -5
  31. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  32. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  33. kiln_ai/tools/test_mcp_session_manager.py +186 -226
  34. kiln_ai/tools/test_rag_tools.py +86 -5
  35. kiln_ai/tools/test_tool_registry.py +199 -5
  36. kiln_ai/tools/tool_registry.py +49 -17
  37. kiln_ai/utils/filesystem.py +4 -4
  38. kiln_ai/utils/open_ai_types.py +19 -2
  39. kiln_ai/utils/pdf_utils.py +21 -0
  40. kiln_ai/utils/test_open_ai_types.py +88 -12
  41. kiln_ai/utils/test_pdf_utils.py +14 -1
  42. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +3 -1
  43. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/RECORD +45 -43
  44. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  45. {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/tools/rag_tools.py

@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import Any, Dict, List
+from typing import Any, Dict, List, TypedDict
 
 from pydantic import BaseModel
 
@@ -18,7 +18,7 @@ from kiln_ai.datamodel.project import Project
 from kiln_ai.datamodel.rag import RagConfig
 from kiln_ai.datamodel.tool_id import ToolId
 from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
-from kiln_ai.tools.base_tool import KilnToolInterface
+from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
@@ -46,6 +46,10 @@ def format_search_results(search_results: List[SearchResult]) -> str:
     return "\n=========\n".join([result.serialize() for result in results])
 
 
+class RagParams(TypedDict):
+    query: str
+
+
 class RagTool(KilnToolInterface):
     """
     A tool that searches the vector store and returns the most relevant chunks.
@@ -126,7 +130,10 @@ class RagTool(KilnToolInterface):
             },
         }
 
-    async def run(self, query: str) -> str:
+    async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
+        kwargs = RagParams(**kwargs)
+        query = kwargs["query"]
+
         _, embedding_adapter = self.embedding
 
         vector_store_adapter = await self.vector_store()
@@ -152,6 +159,6 @@ class RagTool(KilnToolInterface):
         store_query.query_embedding = query_embedding_result.embeddings[0].vector
 
         search_results = await vector_store_adapter.search(store_query)
-        context = format_search_results(search_results)
+        search_results_as_text = format_search_results(search_results)
 
-        return context
+        return search_results_as_text
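Note (not part of the diff): the new run signature accepts a ToolCallContext plus keyword arguments and routes them through the RagParams TypedDict. A TypedDict only describes the expected shape for static type checkers; at runtime RagParams(**kwargs) builds a plain dict with no key or type validation. The standalone sketch below illustrates that behavior and is not code from the package.

from typing import TypedDict


class RagParams(TypedDict):
    query: str


def extract_query(**kwargs) -> str:
    # RagParams(**kwargs) is equivalent to dict(**kwargs) at runtime: no key
    # or type checking happens here. A missing "query" only surfaces as a
    # KeyError on the lookup below.
    params = RagParams(**kwargs)
    return params["query"]


print(extract_query(query="vector store search"))  # -> vector store search
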
kiln_ai/tools/test_kiln_task_tool.py (new file)

@@ -0,0 +1,527 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from kiln_ai.datamodel import Task
+from kiln_ai.datamodel.datamodel_enums import ModelProviderName, StructuredOutputMode
+from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType
+from kiln_ai.datamodel.run_config import RunConfigProperties
+from kiln_ai.datamodel.task import TaskRunConfig
+from kiln_ai.datamodel.task_output import DataSource, DataSourceType
+from kiln_ai.tools.base_tool import ToolCallContext
+from kiln_ai.tools.kiln_task_tool import KilnTaskTool, KilnTaskToolResult
+
+
+class TestKilnTaskToolResult:
+    """Test the KilnTaskToolResult class."""
+
+    def test_init(self):
+        """Test KilnTaskToolResult initialization."""
+        output = "test output"
+        kiln_task_tool_data = "project_id:::tool_id:::task_id:::run_id"
+
+        result = KilnTaskToolResult(output, kiln_task_tool_data)
+
+        assert result.output == output
+        assert result.kiln_task_tool_data == kiln_task_tool_data
+
+    def test_init_with_empty_strings(self):
+        """Test KilnTaskToolResult initialization with empty strings."""
+        result = KilnTaskToolResult("", "")
+
+        assert result.output == ""
+        assert result.kiln_task_tool_data == ""
+
+
+class TestKilnTaskTool:
+    """Test the KilnTaskTool class."""
+
+    @pytest.fixture
+    def mock_external_tool_server(self):
+        """Create a mock ExternalToolServer for testing."""
+        return ExternalToolServer(
+            name="test_tool",
+            type=ToolServerType.kiln_task,
+            description="Test Kiln task tool",
+            properties={
+                "name": "test_task_tool",
+                "description": "A test task tool",
+                "task_id": "test_task_123",
+                "run_config_id": "test_config_456",
+                "is_archived": False,
+            },
+        )
+
+    @pytest.fixture
+    def mock_task(self):
+        """Create a mock Task for testing."""
+        task = MagicMock(spec=Task)
+        task.id = "test_task_123"
+        task.input_json_schema = None
+        task.input_schema.return_value = None
+        task.run_configs.return_value = []
+        return task
+
+    @pytest.fixture
+    def mock_run_config(self):
+        """Create a mock TaskRunConfig for testing."""
+        run_config = MagicMock(spec=TaskRunConfig)
+        run_config.id = "test_config_456"
+        run_config.run_config_properties = {
+            "model_name": "gpt-4",
+            "model_provider_name": "openai",
+            "prompt_id": "simple_prompt_builder",
+            "structured_output_mode": "default",
+        }
+        return run_config
+
+    @pytest.fixture
+    def mock_context(self):
+        """Create a mock ToolCallContext for testing."""
+        context = MagicMock(spec=ToolCallContext)
+        context.allow_saving = True
+        return context
+
+    @pytest.fixture
+    def kiln_task_tool(self, mock_external_tool_server):
+        """Create a KilnTaskTool instance for testing."""
+        return KilnTaskTool(
+            project_id="test_project",
+            tool_id="test_tool_id",
+            data_model=mock_external_tool_server,
+        )
+
+    @pytest.mark.asyncio
+    async def test_init(self, mock_external_tool_server):
+        """Test KilnTaskTool initialization."""
+        tool = KilnTaskTool(
+            project_id="test_project",
+            tool_id="test_tool_id",
+            data_model=mock_external_tool_server,
+        )
+
+        assert tool._project_id == "test_project"
+        assert tool._tool_id == "test_tool_id"
+        assert tool._tool_server_model == mock_external_tool_server
+        assert tool._name == "test_task_tool"
+        assert tool._description == "A test task tool"
+        assert tool._task_id == "test_task_123"
+        assert tool._run_config_id == "test_config_456"
+
+    @pytest.mark.asyncio
+    async def test_init_with_missing_properties(self):
+        """Test KilnTaskTool initialization with missing properties."""
+        # Create a server with minimal required properties
+        server = ExternalToolServer(
+            name="test_tool",
+            type=ToolServerType.kiln_task,
+            description="Test tool",
+            properties={
+                "name": "minimal_tool",
+                "description": "",
+                "task_id": "",
+                "run_config_id": "",
+                "is_archived": False,
+            },
+        )
+
+        tool = KilnTaskTool(
+            project_id="test_project",
+            tool_id="test_tool_id",
+            data_model=server,
+        )
+
+        assert tool._name == "minimal_tool"
+        assert tool._description == ""
+        assert tool._task_id == ""
+        assert tool._run_config_id == ""
+
+    @pytest.mark.asyncio
+    async def test_id(self, kiln_task_tool):
+        """Test the id method."""
+        result = await kiln_task_tool.id()
+        assert result == "test_tool_id"
+
+    @pytest.mark.asyncio
+    async def test_name(self, kiln_task_tool):
+        """Test the name method."""
+        result = await kiln_task_tool.name()
+        assert result == "test_task_tool"
+
+    @pytest.mark.asyncio
+    async def test_description(self, kiln_task_tool):
+        """Test the description method."""
+        result = await kiln_task_tool.description()
+        assert result == "A test task tool"
+
+    @pytest.mark.asyncio
+    async def test_toolcall_definition(self, kiln_task_tool):
+        """Test the toolcall_definition method."""
+        # Mock the parameters_schema property directly
+        kiln_task_tool.parameters_schema = {"type": "object"}
+
+        definition = await kiln_task_tool.toolcall_definition()
+
+        assert definition["type"] == "function"
+        assert definition["function"]["name"] == "test_task_tool"
+        assert definition["function"]["description"] == "A test task tool"
+        assert definition["function"]["parameters"] == {"type": "object"}
+
+    @pytest.mark.asyncio
+    async def test_run_with_plaintext_input(
+        self, kiln_task_tool, mock_context, mock_task, mock_run_config
+    ):
+        """Test the run method with plaintext input."""
+        # Setup mocks
+        kiln_task_tool._task = mock_task
+        kiln_task_tool._run_config = mock_run_config
+
+        with (
+            patch(
+                "kiln_ai.adapters.adapter_registry.adapter_for_task"
+            ) as mock_adapter_for_task,
+            patch(
+                "kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig"
+            ) as mock_adapter_config,
+        ):
+            # Mock adapter and task run
+            mock_adapter = AsyncMock()
+            mock_adapter_for_task.return_value = mock_adapter
+
+            mock_task_run = MagicMock()
+            mock_task_run.id = "run_789"
+            mock_task_run.output.output = "Task completed successfully"
+            mock_adapter.invoke.return_value = mock_task_run
+
+            # Test with plaintext input
+            result = await kiln_task_tool.run(context=mock_context, input="test input")
+
+            # Verify adapter was created correctly
+            mock_adapter_for_task.assert_called_once_with(
+                mock_task,
+                run_config_properties={
+                    "model_name": "gpt-4",
+                    "model_provider_name": "openai",
+                    "prompt_id": "simple_prompt_builder",
+                    "structured_output_mode": "default",
+                },
+                base_adapter_config=mock_adapter_config.return_value,
+            )
+
+            # Verify adapter config
+            mock_adapter_config.assert_called_once_with(
+                allow_saving=True,
+                default_tags=["tool_call"],
+            )
+
+            # Verify adapter invoke was called
+            mock_adapter.invoke.assert_called_once_with(
+                "test input",
+                input_source=DataSource(
+                    type=DataSourceType.tool_call,
+                    run_config=RunConfigProperties(
+                        model_name="gpt-4",
+                        model_provider_name=ModelProviderName.openai,
+                        prompt_id="simple_prompt_builder",
+                        structured_output_mode=StructuredOutputMode.default,
+                    ),
+                ),
+            )
+
+            # Verify result
+            assert isinstance(result, KilnTaskToolResult)
+            assert result.output == "Task completed successfully"
+            assert (
+                result.kiln_task_tool_data
+                == "test_project:::test_tool_id:::test_task_123:::run_789"
+            )
+
+    @pytest.mark.asyncio
+    async def test_run_with_structured_input(
+        self, kiln_task_tool, mock_context, mock_task, mock_run_config
+    ):
+        """Test the run method with structured input."""
+        # Setup task with JSON schema
+        mock_task.input_json_schema = {
+            "type": "object",
+            "properties": {"param1": {"type": "string"}},
+        }
+
+        # Setup mocks
+        kiln_task_tool._task = mock_task
+        kiln_task_tool._run_config = mock_run_config
+
+        with patch(
+            "kiln_ai.adapters.adapter_registry.adapter_for_task"
+        ) as mock_adapter_for_task:
+            # Mock adapter and task run
+            mock_adapter = AsyncMock()
+            mock_adapter_for_task.return_value = mock_adapter
+
+            mock_task_run = MagicMock()
+            mock_task_run.id = "run_789"
+            mock_task_run.output.output = "Structured task completed"
+            mock_adapter.invoke.return_value = mock_task_run
+
+            # Test with structured input
+            result = await kiln_task_tool.run(
+                context=mock_context, param1="value1", param2="value2"
+            )
+
+            # Verify adapter invoke was called with kwargs
+            mock_adapter.invoke.assert_called_once_with(
+                {"param1": "value1", "param2": "value2"},
+                input_source=DataSource(
+                    type=DataSourceType.tool_call,
+                    run_config=RunConfigProperties(
+                        model_name="gpt-4",
+                        model_provider_name=ModelProviderName.openai,
+                        prompt_id="simple_prompt_builder",
+                        structured_output_mode=StructuredOutputMode.default,
+                    ),
+                ),
+            )
+
+            # Verify result
+            assert result.output == "Structured task completed"
+
+    @pytest.mark.asyncio
+    async def test_run_without_context(self, kiln_task_tool):
+        """Test the run method without context raises ValueError."""
+        with pytest.raises(
+            ValueError, match="Context is required for running a KilnTaskTool"
+        ):
+            await kiln_task_tool.run(input="test input")
+
+    @pytest.mark.asyncio
+    async def test_run_plaintext_missing_input(
+        self, kiln_task_tool, mock_context, mock_task
+    ):
+        """Test the run method with plaintext task but missing input parameter."""
+        # Setup mocks
+        kiln_task_tool._task = mock_task
+
+        with pytest.raises(ValueError, match="Input not found in kwargs"):
+            await kiln_task_tool.run(context=mock_context, wrong_param="value")
+
+    @pytest.mark.asyncio
+    async def test_task_property_project_not_found(self, kiln_task_tool):
+        """Test _task property when project is not found."""
+        with patch("kiln_ai.tools.kiln_task_tool.project_from_id", return_value=None):
+            with pytest.raises(ValueError, match="Project not found: test_project"):
+                _ = kiln_task_tool._task
+
+    @pytest.mark.asyncio
+    async def test_task_property_task_not_found(self, kiln_task_tool):
+        """Test _task property when task is not found."""
+        mock_project = MagicMock()
+        mock_project.path = "/test/path"
+
+        with (
+            patch(
+                "kiln_ai.tools.kiln_task_tool.project_from_id",
+                return_value=mock_project,
+            ),
+            patch(
+                "kiln_ai.tools.kiln_task_tool.Task.from_id_and_parent_path",
+                return_value=None,
+            ),
+        ):
+            with pytest.raises(
+                ValueError,
+                match="Task not found: test_task_123 in project test_project",
+            ):
+                _ = kiln_task_tool._task
+
+    @pytest.mark.asyncio
+    async def test_task_property_success(self, kiln_task_tool, mock_task):
+        """Test _task property when task is found successfully."""
+        mock_project = MagicMock()
+        mock_project.path = "/test/path"
+
+        with (
+            patch(
+                "kiln_ai.tools.kiln_task_tool.project_from_id",
+                return_value=mock_project,
+            ),
+            patch(
+                "kiln_ai.tools.kiln_task_tool.Task.from_id_and_parent_path",
+                return_value=mock_task,
+            ),
+        ):
+            result = kiln_task_tool._task
+            assert result == mock_task
+
+    @pytest.mark.asyncio
+    async def test_run_config_property_not_found(self, kiln_task_tool, mock_task):
+        """Test _run_config property when run config is not found."""
+        mock_task.run_configs.return_value = []
+
+        # Setup mocks
+        kiln_task_tool._task = mock_task
+
+        with pytest.raises(
+            ValueError,
+            match="Task run config not found: test_config_456 for task test_task_123 in project test_project",
+        ):
+            _ = kiln_task_tool._run_config
+
+    @pytest.mark.asyncio
+    async def test_run_config_property_success(
+        self, kiln_task_tool, mock_task, mock_run_config
+    ):
+        """Test _run_config property when run config is found successfully."""
+        mock_task.run_configs.return_value = [mock_run_config]
+
+        # Setup mocks
+        kiln_task_tool._task = mock_task
+
+        result = kiln_task_tool._run_config
+        assert result == mock_run_config
+
+    @pytest.mark.asyncio
+    async def test_parameters_schema_with_json_schema(self, kiln_task_tool, mock_task):
+        """Test parameters_schema property with JSON schema."""
+        expected_schema = {
+            "type": "object",
+            "properties": {"param": {"type": "string"}},
+        }
+        mock_task.input_json_schema = expected_schema
+        mock_task.input_schema.return_value = expected_schema
+
+        # Setup mocks
+        kiln_task_tool._task = mock_task
+
+        result = kiln_task_tool.parameters_schema
+        assert result == expected_schema
+
+    @pytest.mark.asyncio
+    async def test_parameters_schema_plaintext(self, kiln_task_tool, mock_task):
+        """Test parameters_schema property for plaintext task."""
+        mock_task.input_json_schema = None
+
+        # Setup mocks
+        kiln_task_tool._task = mock_task
+
+        result = kiln_task_tool.parameters_schema
+
+        expected = {
+            "type": "object",
+            "properties": {
+                "input": {
+                    "type": "string",
+                    "description": "Plaintext input for the tool.",
+                }
+            },
+            "required": ["input"],
+        }
+        assert result == expected
+
+    @pytest.mark.asyncio
+    async def test_parameters_schema_none_raises_error(self, kiln_task_tool, mock_task):
+        """Test parameters_schema property when schema is None raises ValueError."""
+        # Set up a task with JSON schema but input_schema returns None
+        mock_task.input_json_schema = {
+            "type": "object",
+            "properties": {"param": {"type": "string"}},
+        }
+        mock_task.input_schema.return_value = None
+
+        # Setup mocks - directly assign the task to bypass cached property
+        kiln_task_tool._task = mock_task
+
+        with pytest.raises(
+            ValueError,
+            match="Failed to create parameters schema for tool_id test_tool_id",
+        ):
+            _ = kiln_task_tool.parameters_schema
+
+    @pytest.mark.asyncio
+    async def test_cached_properties(self, kiln_task_tool, mock_task, mock_run_config):
+        """Test that cached properties work correctly."""
+        mock_project = MagicMock()
+        mock_project.path = "/test/path"
+
+        with (
+            patch(
+                "kiln_ai.tools.kiln_task_tool.project_from_id",
+                return_value=mock_project,
+            ),
+            patch(
+                "kiln_ai.tools.kiln_task_tool.Task.from_id_and_parent_path",
+                return_value=mock_task,
+            ),
+        ):
+            # First access should call the methods
+            task1 = kiln_task_tool._task
+            task2 = kiln_task_tool._task
+
+            # Should be the same object (cached)
+            assert task1 is task2
+
+            # Verify the methods were called only once
+            assert mock_project is not None  # project_from_id was called
+            # Task.from_id_and_parent_path should have been called once
+            with patch(
+                "kiln_ai.tools.kiln_task_tool.Task.from_id_and_parent_path"
+            ) as mock_from_id:
+                mock_from_id.return_value = mock_task
+                _ = kiln_task_tool._task
+                # Should not be called again due to caching
+                mock_from_id.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_run_with_adapter_exception(
+        self, kiln_task_tool, mock_context, mock_task, mock_run_config
+    ):
+        """Test the run method when adapter raises an exception."""
+        # Setup mocks
+        kiln_task_tool._task = mock_task
+        kiln_task_tool._run_config = mock_run_config
+
+        with patch(
+            "kiln_ai.adapters.adapter_registry.adapter_for_task"
+        ) as mock_adapter_for_task:
+            # Mock adapter to raise an exception
+            mock_adapter = AsyncMock()
+            mock_adapter.invoke.side_effect = Exception("Adapter failed")
+            mock_adapter_for_task.return_value = mock_adapter
+
+            with pytest.raises(Exception, match="Adapter failed"):
+                await kiln_task_tool.run(context=mock_context, input="test input")
+
+    @pytest.mark.asyncio
+    async def test_run_with_different_allow_saving(
+        self, kiln_task_tool, mock_task, mock_run_config
+    ):
+        """Test the run method with different allow_saving values."""
+        mock_context_false = MagicMock(spec=ToolCallContext)
+        mock_context_false.allow_saving = False
+
+        # Setup mocks
+        kiln_task_tool._task = mock_task
+        kiln_task_tool._run_config = mock_run_config
+
+        with (
+            patch(
+                "kiln_ai.adapters.adapter_registry.adapter_for_task"
+            ) as mock_adapter_for_task,
+            patch(
+                "kiln_ai.adapters.model_adapters.base_adapter.AdapterConfig"
+            ) as mock_adapter_config,
+        ):
+            mock_adapter = AsyncMock()
+            mock_adapter_for_task.return_value = mock_adapter
+
+            mock_task_run = MagicMock()
+            mock_task_run.id = "run_789"
+            mock_task_run.output.output = "Task completed"
+            mock_adapter.invoke.return_value = mock_task_run
+
+            await kiln_task_tool.run(context=mock_context_false, input="test input")
+
+            # Verify adapter config was called with allow_saving=False
+            mock_adapter_config.assert_called_once_with(
+                allow_saving=False,
+                default_tags=["tool_call"],
+            )
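Note (not part of the diff): the tests above expect KilnTaskToolResult.kiln_task_tool_data to be a ":::"-delimited string of project, tool, task, and run IDs. The sketch below only illustrates that format; the helper names are hypothetical and not part of kiln_ai.

def encode_tool_data(project_id: str, tool_id: str, task_id: str, run_id: str) -> str:
    # Hypothetical helper: join the four IDs with the ":::" separator the
    # tests assert on.
    return ":::".join([project_id, tool_id, task_id, run_id])


def decode_tool_data(data: str) -> tuple[str, str, str, str]:
    # Hypothetical helper: split the delimited string back into its parts.
    project_id, tool_id, task_id, run_id = data.split(":::")
    return project_id, tool_id, task_id, run_id


data = encode_tool_data("test_project", "test_tool_id", "test_task_123", "run_789")
assert data == "test_project:::test_tool_id:::test_task_123:::run_789"
assert decode_tool_data(data)[3] == "run_789"
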
kiln_ai/tools/test_mcp_server_tool.py

@@ -10,7 +10,10 @@ from mcp.types import (
     Tool,
 )
 
-from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType
+from kiln_ai.datamodel.external_tool_server import (
+    ExternalToolServer,
+    ToolServerType,
+)
 from kiln_ai.datamodel.tool_id import MCP_REMOTE_TOOL_ID_PREFIX
 from kiln_ai.tools.mcp_server_tool import MCPServerTool
 
@@ -27,7 +30,6 @@ class TestMCPServerTool:
             description="Test server",
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
 
@@ -60,7 +62,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -90,7 +91,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -116,7 +116,6 @@ class TestMCPServerTool:
            type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -143,7 +142,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -170,7 +168,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -196,7 +193,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -231,7 +227,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "target_tool")
@@ -258,7 +253,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "missing_tool")
@@ -287,7 +281,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -321,7 +314,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -347,7 +339,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -365,7 +356,6 @@ class TestMCPServerTool:
             type=ToolServerType.remote_mcp,
             properties={
                 "server_url": "https://example.com",
-                "headers": {},
             },
         )
         tool = MCPServerTool(server, "test_tool")
@@ -415,7 +405,6 @@ class TestMCPServerToolIntegration:
             description="Postman Echo MCP Server for testing",
             properties={
                 "server_url": "https://postman-echo-mcp.fly.dev/",
-                "headers": {},
             },
         )
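
Note (not part of the diff): every updated fixture above drops the empty "headers" mapping when defining a remote MCP server, so the tests now treat that key as optional. The fixture-style sketch below mirrors the changed hunks; the name value is an assumption, since it lies outside the lines shown.

from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType

server = ExternalToolServer(
    name="test_server",  # assumed; not shown in the changed lines
    type=ToolServerType.remote_mcp,
    description="Test server",
    properties={
        "server_url": "https://example.com",
        # "headers" omitted, matching the updated test fixtures
    },
)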