kiln-ai 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/extractors/litellm_extractor.py +52 -32
- kiln_ai/adapters/extractors/test_litellm_extractor.py +169 -71
- kiln_ai/adapters/ml_embedding_model_list.py +330 -28
- kiln_ai/adapters/ml_model_list.py +503 -23
- kiln_ai/adapters/model_adapters/litellm_adapter.py +39 -8
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +78 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +6 -9
- kiln_ai/adapters/test_ml_embedding_model_list.py +89 -279
- kiln_ai/adapters/test_ml_model_list.py +0 -10
- kiln_ai/adapters/vector_store/lancedb_adapter.py +24 -70
- kiln_ai/adapters/vector_store/lancedb_helpers.py +101 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +9 -16
- kiln_ai/adapters/vector_store/test_lancedb_helpers.py +142 -0
- kiln_ai/adapters/vector_store_loaders/__init__.py +0 -0
- kiln_ai/adapters/vector_store_loaders/test_lancedb_loader.py +282 -0
- kiln_ai/adapters/vector_store_loaders/test_vector_store_loader.py +544 -0
- kiln_ai/adapters/vector_store_loaders/vector_store_loader.py +91 -0
- kiln_ai/datamodel/basemodel.py +31 -3
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +14 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +3 -3
- kiln_ai/datamodel/test_basemodel.py +269 -13
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_model.py +31 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +106 -1
- kiln_ai/datamodel/tool_id.py +49 -0
- kiln_ai/tools/base_tool.py +30 -6
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +162 -0
- kiln_ai/tools/mcp_server_tool.py +7 -5
- kiln_ai/tools/mcp_session_manager.py +50 -24
- kiln_ai/tools/rag_tools.py +17 -6
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +186 -226
- kiln_ai/tools/test_rag_tools.py +86 -5
- kiln_ai/tools/test_tool_registry.py +199 -5
- kiln_ai/tools/tool_registry.py +49 -17
- kiln_ai/utils/filesystem.py +4 -4
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +21 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +14 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/METADATA +79 -1
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/RECORD +53 -45
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.21.0.dist-info → kiln_ai-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
The extracted hunks below are grouped by source file.

kiln_ai/adapters/model_adapters/litellm_adapter.py:

@@ -13,7 +13,6 @@ from litellm.types.utils import (
 )
 from litellm.types.utils import Message as LiteLLMMessage
 from litellm.types.utils import Usage as LiteLlmUsage
-from openai.types.chat import ChatCompletionToolMessageParam
 from openai.types.chat.chat_completion_message_tool_call_param import (
     ChatCompletionMessageToolCallParam,
 )
@@ -32,12 +31,18 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
 )
 from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
-from kiln_ai.tools.base_tool import
+from kiln_ai.tools.base_tool import (
+    KilnToolInterface,
+    ToolCallContext,
+    ToolCallDefinition,
+)
+from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 from kiln_ai.utils.litellm import get_litellm_provider_info
 from kiln_ai.utils.open_ai_types import (
     ChatCompletionAssistantMessageParamWrapper,
     ChatCompletionMessageParam,
+    ChatCompletionToolMessageParamWrapper,
 )
 
 MAX_CALLS_PER_TURN = 10
@@ -488,6 +493,21 @@ class LiteLlmAdapter(BaseAdapter):
             completion_kwargs["tools"] = tool_calls
             completion_kwargs["tool_choice"] = "auto"
 
+        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
+        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
+        if provider.temp_top_p_exclusive:
+            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
+                del completion_kwargs["top_p"]
+            if (
+                "temperature" in completion_kwargs
+                and completion_kwargs["temperature"] == 1.0
+            ):
+                del completion_kwargs["temperature"]
+            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
+                raise ValueError(
+                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
+                )
+
         if not skip_response_format:
             # Response format: json_schema, json_instructions, json_mode, function_calling, etc
             response_format_options = await self.response_format_options()
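The added block can be read in isolation: a value of 1.0 is treated as "not customized" and dropped, and an error is raised only when both knobs survive the pruning. A minimal sketch of the same behavior on a plain dict (`prune_temp_top_p` is a hypothetical helper, not a kiln_ai API):

```python
# Minimal sketch of the pruning behavior added above; prune_temp_top_p is a
# hypothetical helper, not part of kiln_ai.
def prune_temp_top_p(completion_kwargs: dict) -> dict:
    # 1.0 is the provider default, so it is treated as "not customized".
    if completion_kwargs.get("top_p") == 1.0:
        del completion_kwargs["top_p"]
    if completion_kwargs.get("temperature") == 1.0:
        del completion_kwargs["temperature"]
    # Only a genuinely custom pair is an error.
    if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
        raise ValueError(
            "top_p and temperature can not both have custom values for this model."
        )
    return completion_kwargs


assert prune_temp_top_p({"temperature": 1.0, "top_p": 0.9}) == {"top_p": 0.9}
assert prune_temp_top_p({"temperature": 0.7, "top_p": 1.0}) == {"temperature": 0.7}
```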
@@ -544,7 +564,7 @@ class LiteLlmAdapter(BaseAdapter):
             self._cached_available_tools = await self.available_tools()
         return self._cached_available_tools
 
-    async def litellm_tools(self) -> list[
+    async def litellm_tools(self) -> list[ToolCallDefinition]:
         available_tools = await self.cached_available_tools()
 
         # LiteLLM takes the standard OpenAI-compatible tool call format
@@ -552,12 +572,12 @@ class LiteLlmAdapter(BaseAdapter):
 
     async def process_tool_calls(
         self, tool_calls: list[ChatCompletionMessageToolCall] | None
-    ) -> tuple[str | None, list[
+    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
        if tool_calls is None:
             return None, []
 
         assistant_output_from_toolcall: str | None = None
-        tool_call_response_messages: list[
+        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
 
         for tool_call in tool_calls:
             # Kiln "task_response" tool is used for returning structured output via tool calls.
@@ -594,13 +614,24 @@ class LiteLlmAdapter(BaseAdapter):
                     f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                 ) from e
 
-
+            # Create context with the calling task's allow_saving setting
+            context = ToolCallContext(
+                allow_saving=self.base_adapter_config.allow_saving
+            )
+            result = await tool.run(context, **parsed_args)
+            if isinstance(result, KilnTaskToolResult):
+                content = result.output
+                kiln_task_tool_data = result.kiln_task_tool_data
+            else:
+                content = result
+                kiln_task_tool_data = None
 
             tool_call_response_messages.append(
-
+                ChatCompletionToolMessageParamWrapper(
                     role="tool",
                     tool_call_id=tool_call.id,
-                    content=
+                    content=content,
+                    kiln_task_tool_data=kiln_task_tool_data,
                 )
             )
 
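The new branch separates the two result shapes a tool can return: plain tools return a string, while Kiln task tools return a `KilnTaskToolResult` carrying both the output text and a composite id linking the tool message back to the subtask run. A self-contained sketch of that dispatch, using a stand-in dataclass with the field names shown in the diff:

```python
from dataclasses import dataclass


@dataclass
class KilnTaskToolResult:
    # Stand-in for kiln_ai.tools.kiln_task_tool.KilnTaskToolResult;
    # field names are taken from the diff above.
    output: str
    kiln_task_tool_data: str


def unpack_tool_result(result) -> tuple[str, str | None]:
    # Hypothetical helper mirroring the isinstance branch in process_tool_calls.
    if isinstance(result, KilnTaskToolResult):
        return result.output, result.kiln_task_tool_data
    return result, None


assert unpack_tool_result("42") == ("42", None)
assert unpack_tool_result(
    KilnTaskToolResult(output="done", kiln_task_tool_data="proj:::tool:::task:::run")
) == ("done", "proj:::tool:::task:::run")
```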
kiln_ai/adapters/model_adapters/test_litellm_adapter.py:

@@ -405,6 +405,7 @@ async def test_build_completion_kwargs_custom_temperature_top_p(config, mock_tas
 
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]
 
     with (
@@ -446,6 +447,7 @@ async def test_build_completion_kwargs(
     """Test build_completion_kwargs with various configurations"""
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]
 
     with (
@@ -613,6 +615,7 @@ async def test_build_completion_kwargs_includes_tools(
     """Test build_completion_kwargs includes tools when available_tools has tools"""
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]
 
     with (
@@ -666,6 +669,7 @@ async def test_build_completion_kwargs_raises_error_with_tools_conflict(
     config.run_config_properties.structured_output_mode = structured_output_mode
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]
 
     with (
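Each of these tests now pins `mock_provider.temp_top_p_exclusive = False` explicitly. That is needed because an unset attribute on a `Mock` auto-creates a child `Mock`, which is truthy, so the new exclusivity branch would otherwise fire. A quick stdlib illustration:

```python
from unittest.mock import Mock

provider = Mock()
# Unset attributes on a Mock are auto-created child Mocks, and Mocks are truthy:
assert bool(provider.temp_top_p_exclusive) is True

# So the tests pin the flag to keep the pre-existing code path:
provider.temp_top_p_exclusive = False
assert not provider.temp_top_p_exclusive
```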
@@ -976,3 +980,77 @@ def test_build_extra_body_enable_thinking(config, mock_task, enable_thinking):
     extra_body = adapter.build_extra_body(provider)
 
     assert extra_body["enable_thinking"] == enable_thinking
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "temperature,top_p,should_raise,expected_temp,expected_top_p",
+    [
+        (1.0, 1.0, False, None, None),
+        (0.7, 1.0, False, 0.7, None),
+        (1.0, 0.9, False, None, 0.9),
+        (0.7, 0.9, True, None, None),
+        (0.5, 0.5, True, None, None),
+    ],
+)
+async def test_build_completion_kwargs_temp_top_p_exclusive(
+    config, mock_task, temperature, top_p, should_raise, expected_temp, expected_top_p
+):
+    """Test build_completion_kwargs with temp_top_p_exclusive provider flag"""
+    config.run_config_properties.temperature = temperature
+    config.run_config_properties.top_p = top_p
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = True
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="anthropic/test-model"),
+        patch.object(adapter, "build_extra_body", return_value={}),
+        patch.object(adapter, "response_format_options", return_value={}),
+    ):
+        if should_raise:
+            with pytest.raises(
+                ValueError,
+                match="top_p and temperature can not both have custom values",
+            ):
+                await adapter.build_completion_kwargs(mock_provider, messages, None)
+        else:
+            kwargs = await adapter.build_completion_kwargs(
+                mock_provider, messages, None
+            )
+
+            if expected_temp is None:
+                assert "temperature" not in kwargs
+            else:
+                assert kwargs["temperature"] == expected_temp
+
+            if expected_top_p is None:
+                assert "top_p" not in kwargs
+            else:
+                assert kwargs["top_p"] == expected_top_p
+
+
+@pytest.mark.asyncio
+async def test_build_completion_kwargs_temp_top_p_not_exclusive(config, mock_task):
+    """Test build_completion_kwargs with temp_top_p_exclusive=False allows both params"""
+    config.run_config_properties.temperature = 0.7
+    config.run_config_properties.top_p = 0.9
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
+        patch.object(adapter, "build_extra_body", return_value={}),
+        patch.object(adapter, "response_format_options", return_value={}),
+    ):
+        kwargs = await adapter.build_completion_kwargs(mock_provider, messages, None)
+
+    assert kwargs["temperature"] == 0.7
+    assert kwargs["top_p"] == 0.9
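The parametrize matrix encodes the contract: 1.0 counts as the default and is dropped, and only two genuinely custom values raise. On the caller side that means leaving one knob at 1.0 when targeting an exclusive provider. A hedged sketch, reusing only the `RunConfigProperties` field names and values that appear in these tests:

```python
from kiln_ai.datamodel.datamodel_enums import ModelProviderName, StructuredOutputMode
from kiln_ai.datamodel.task import RunConfigProperties

# Sketch only; values are illustrative and copied from the tests above.
run_config = RunConfigProperties(
    structured_output_mode=StructuredOutputMode.json_schema,
    model_name="gpt_4_1_mini",
    model_provider_name=ModelProviderName.openai,
    prompt_id="simple_prompt_builder",
    temperature=0.7,  # the one customized knob
    top_p=1.0,        # left at default; dropped for temp_top_p_exclusive providers
)
```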
kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py:

@@ -18,12 +18,15 @@ from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
 from kiln_ai.datamodel.datamodel_enums import ModelProviderName, StructuredOutputMode
 from kiln_ai.datamodel.task import RunConfigProperties
+from kiln_ai.datamodel.tool_id import ToolId
+from kiln_ai.tools.base_tool import ToolCallContext
 from kiln_ai.tools.built_in_tools.math_tools import (
     AddTool,
     DivideTool,
     MultiplyTool,
     SubtractTool,
 )
+from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
 from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam
 
 
@@ -91,6 +94,7 @@ async def run_simple_task_with_tools(
     # Verify that AddTool.run was called with correct parameters
     add_spy.run.assert_called()
     add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
     add_kwargs = add_call_args.kwargs
     assert add_kwargs.get("a") == 2
     assert add_kwargs.get("b") == 2
@@ -126,6 +130,9 @@ async def run_simple_task_with_tools(
     # Verify that MultiplyTool.run was called with correct parameters
     multiply_spy.run.assert_called()
     multiply_call_args = multiply_spy.run.call_args
+    assert multiply_call_args.args[
+        0
+    ].allow_saving  # First arg is ToolCallContext
     multiply_kwargs = multiply_call_args.kwargs
     # Check that multiply was called with a=6, b=10 (or vice versa)
     assert (
@@ -137,6 +144,7 @@ async def run_simple_task_with_tools(
     # Verify that AddTool.run was called with correct parameters
     add_spy.run.assert_called()
     add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
     add_kwargs = add_call_args.kwargs
     # Check that add was called with a=60, b=4 (or vice versa)
     assert (add_kwargs.get("a") == 60 and add_kwargs.get("b") == 4) or (
@@ -482,8 +490,16 @@ async def test_run_model_turn_parallel_tools(tmp_path):
     )
 
     # Verify both tools were called in parallel
-
-
+    # The context is passed as the first positional argument, not as a keyword argument
+    multiply_spy.run.assert_called_once()
+    multiply_call_args = multiply_spy.run.call_args
+    assert multiply_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert multiply_call_args.kwargs == {"a": 6, "b": 10}
+
+    add_spy.run.assert_called_once()
+    add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert add_call_args.kwargs == {"a": 2, "b": 3}
 
     # Verify the result structure
     assert isinstance(result, ModelTurnResult)
@@ -596,8 +612,16 @@ async def test_run_model_turn_sequential_tools(tmp_path):
     )
 
     # Verify tools were called sequentially
-
-
+    # The context is passed as the first positional argument, not as a keyword argument
+    multiply_spy.run.assert_called_once()
+    multiply_call_args = multiply_spy.run.call_args
+    assert multiply_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert multiply_call_args.kwargs == {"a": 6, "b": 10}
+
+    add_spy.run.assert_called_once()
+    add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert add_call_args.kwargs == {"a": 60, "b": 4}
 
     # Verify the result structure
     assert isinstance(result, ModelTurnResult)
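The rewritten assertions read the context from `call_args.args[0]` because `run` now receives the `ToolCallContext` as its first positional argument, while the tool arguments still arrive as keywords. The stdlib mechanics, independent of kiln_ai:

```python
import asyncio
from unittest.mock import AsyncMock

spy = AsyncMock(return_value="60")
asyncio.run(spy(object(), a=6, b=10))  # positional context, keyword tool args

call = spy.call_args
assert len(call.args) == 1               # the context landed in .args
assert call.kwargs == {"a": 6, "b": 10}  # the tool arguments landed in .kwargs
```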
@@ -756,11 +780,59 @@ class MockTool:
             }
         }
 
-    async def run(self, **kwargs) -> str:
+    async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
         if self._raise_on_run:
             raise self._raise_on_run
         return self._return_value
 
+    async def id(self) -> ToolId:
+        """Mock implementation of id for testing."""
+        return f"mock_tool_{self._name}"
+
+
+class MockKilnTaskTool:
+    """Mock tool class that returns KilnTaskToolResult for testing"""
+
+    def __init__(
+        self,
+        name: str,
+        raise_on_run: Exception | None = None,
+        output: str = "kiln_task_output",
+        kiln_task_tool_data: str = "project_id:::tool_id:::task_id:::run_id",
+    ):
+        self._name = name
+        self._raise_on_run = raise_on_run
+        self._output = output
+        self._kiln_task_tool_data = kiln_task_tool_data
+
+    async def name(self) -> str:
+        return self._name
+
+    async def toolcall_definition(self) -> dict:
+        return {
+            "function": {
+                "parameters": {
+                    "type": "object",
+                    "properties": {"input": {"type": "string"}},
+                    "required": ["input"],
+                }
+            }
+        }
+
+    async def run(
+        self, context: ToolCallContext | None = None, **kwargs
+    ) -> KilnTaskToolResult:
+        if self._raise_on_run:
+            raise self._raise_on_run
+        return KilnTaskToolResult(
+            output=self._output,
+            kiln_task_tool_data=self._kiln_task_tool_data,
+        )
+
+    async def id(self) -> ToolId:
+        """Mock implementation of id for testing."""
+        return f"mock_kiln_task_tool_{self._name}"
+
 
 async def test_process_tool_calls_none_input(tmp_path):
     """Test process_tool_calls with None input"""
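The mock's default `kiln_task_tool_data` shows the shape of the composite id: four segments joined by `:::`. How kiln_ai itself consumes this string is not shown in this diff; the split below is purely illustrative:

```python
# Illustrative parse of the ':::'-joined composite id used by MockKilnTaskTool.
data = "project_id:::tool_id:::task_id:::run_id"
project_id, tool_id, task_id, run_id = data.split(":::")
assert (project_id, run_id) == ("project_id", "run_id")
```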
@@ -879,6 +951,7 @@ async def test_process_tool_calls_normal_tool_success(tmp_path):
         "role": "tool",
         "tool_call_id": "call_1",
         "content": "5",
+        "kiln_task_tool_data": None,
     }
 
 
@@ -915,8 +988,10 @@ async def test_process_tool_calls_multiple_normal_tools(tmp_path):
     assert len(tool_messages) == 2
     assert tool_messages[0]["tool_call_id"] == "call_1"
     assert tool_messages[0]["content"] == "5"
+    assert tool_messages[0].get("kiln_task_tool_data") is None
     assert tool_messages[1]["tool_call_id"] == "call_2"
     assert tool_messages[1]["content"] == "6"
+    assert tool_messages[1].get("kiln_task_tool_data") is None
 
 
 async def test_process_tool_calls_tool_not_found(tmp_path):
@@ -1072,6 +1147,7 @@ async def test_process_tool_calls_complex_result(tmp_path):
     assert assistant_output is None
     assert len(tool_messages) == 1
     assert tool_messages[0]["content"] == complex_result
+    assert tool_messages[0].get("kiln_task_tool_data") is None
 
 
 async def test_process_tool_calls_task_response_with_normal_tools_error(tmp_path):
@@ -1101,3 +1177,41 @@ async def test_process_tool_calls_task_response_with_normal_tools_error(tmp_path
         match="task_response tool call and other tool calls were both provided",
     ):
         await litellm_adapter.process_tool_calls(tool_calls)  # type: ignore
+
+
+async def test_process_tool_calls_kiln_task_tool_result(tmp_path):
+    """Test process_tool_calls with KilnTaskToolResult - tests the new if statement branch"""
+    task = build_test_task(tmp_path)
+    config = LiteLlmConfig(
+        run_config_properties=RunConfigProperties(
+            structured_output_mode=StructuredOutputMode.json_schema,
+            model_name="gpt_4_1_mini",
+            model_provider_name=ModelProviderName.openai,
+            prompt_id="simple_prompt_builder",
+        )
+    )
+    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)
+
+    mock_kiln_task_tool = MockKilnTaskTool(
+        "kiln_task_tool",
+        output="Task completed successfully",
+        kiln_task_tool_data="proj123:::tool456:::task789:::run101",
+    )
+    tool_calls = [MockToolCall("call_1", "kiln_task_tool", '{"input": "test input"}')]
+
+    with patch.object(
+        litellm_adapter, "cached_available_tools", return_value=[mock_kiln_task_tool]
+    ):
+        assistant_output, tool_messages = await litellm_adapter.process_tool_calls(
+            tool_calls  # type: ignore
+        )
+
+    assert assistant_output is None
+    assert len(tool_messages) == 1
+    assert tool_messages[0]["role"] == "tool"
+    assert tool_messages[0]["tool_call_id"] == "call_1"
+    assert tool_messages[0]["content"] == "Task completed successfully"
+    assert (
+        tool_messages[0].get("kiln_task_tool_data")
+        == "proj123:::tool456:::task789:::run101"
+    )
kiln_ai/adapters/model_adapters/test_saving_adapter_results.py:

@@ -60,7 +60,9 @@ def test_save_run_isolation(test_task, adapter):
     )
 
     task_run = adapter.generate_run(
-        input=input_data,
+        input=input_data,
+        input_source=None,
+        run_output=run_output,
     )
     task_run.save_to_file()
 
@@ -146,7 +148,9 @@ def test_generate_run_non_ascii(test_task, adapter):
     )
 
     task_run = adapter.generate_run(
-        input=input_data,
+        input=input_data,
+        input_source=None,
+        run_output=run_output,
     )
     task_run.save_to_file()
 
@@ -256,7 +260,9 @@ def test_properties_for_task_output_custom_values(test_task):
     run_output = RunOutput(output=output_data, intermediate_outputs=None)
 
     task_run = adapter.generate_run(
-        input=input_data,
+        input=input_data,
+        input_source=None,
+        run_output=run_output,
    )
     task_run.save_to_file()
 
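All three call sites gain an explicit `input_source=None`, which suggests `generate_run` now expects an `input_source` argument (the tests build their inputs inline, so there is no source to record). A hedged sketch of the updated call shape, with the parameter names taken from the tests above:

```python
def save_generated_run(adapter, input_data, run_output) -> None:
    # Hypothetical wrapper showing the new generate_run call shape; the
    # explicit input_source=None mirrors the updated tests.
    task_run = adapter.generate_run(
        input=input_data,
        input_source=None,
        run_output=run_output,
    )
    task_run.save_to_file()
```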
kiln_ai/adapters/model_adapters/test_structured_output.py:

@@ -175,15 +175,12 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
 
     # Check reasoning models
     assert a._model_provider is not None
-    if
-
-
-
-
-
-    else:
-        assert "reasoning" in run.intermediate_outputs
-        assert isinstance(run.intermediate_outputs["reasoning"], str)
+    if (
+        a._model_provider.reasoning_capable
+        and not a._model_provider.reasoning_optional_for_structured_output
+    ):
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)
 
 
 def build_structured_input_test_task(tmp_path: Path):