kiln-ai 0.20.1-py3-none-any.whl → 0.22.0-py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Note: this version of kiln-ai has been flagged as a potentially problematic release.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/litellm_adapter.py

@@ -11,13 +11,8 @@ from litellm.types.utils import (
     Choices,
     ModelResponse,
 )
-from litellm.types.utils import (
-    Message as LiteLLMMessage,
-)
+from litellm.types.utils import Message as LiteLLMMessage
 from litellm.types.utils import Usage as LiteLlmUsage
-from openai.types.chat import (
-    ChatCompletionToolMessageParam,
-)
 from openai.types.chat.chat_completion_message_tool_call_param import (
     ChatCompletionMessageToolCallParam,
 )
@@ -36,11 +31,14 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
 )
 from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
-from kiln_ai.tools.base_tool import KilnToolInterface
+from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
+from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.litellm import get_litellm_provider_info
 from kiln_ai.utils.open_ai_types import (
     ChatCompletionAssistantMessageParamWrapper,
     ChatCompletionMessageParam,
+    ChatCompletionToolMessageParamWrapper,
 )

 MAX_CALLS_PER_TURN = 10
@@ -447,75 +445,16 @@ class LiteLlmAdapter(BaseAdapter):
     def litellm_model_id(self) -> str:
         # The model ID is an interesting combination of format and url endpoint.
         # It specifics the provider URL/host, but this is overridden if you manually set an api url
-
         if self._litellm_model_id:
             return self._litellm_model_id

-
-        if
-            raise ValueError("Model ID is required for OpenAI compatible models")
-
-        litellm_provider_name: str | None = None
-        is_custom = False
-        match provider.name:
-            case ModelProviderName.openrouter:
-                litellm_provider_name = "openrouter"
-            case ModelProviderName.openai:
-                litellm_provider_name = "openai"
-            case ModelProviderName.groq:
-                litellm_provider_name = "groq"
-            case ModelProviderName.anthropic:
-                litellm_provider_name = "anthropic"
-            case ModelProviderName.ollama:
-                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
-                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
-                is_custom = True
-            case ModelProviderName.docker_model_runner:
-                # Docker Model Runner uses OpenAI-compatible API, similar to Ollama
-                # We want direct control over the requests for features like response_format=json_schema
-                is_custom = True
-            case ModelProviderName.gemini_api:
-                litellm_provider_name = "gemini"
-            case ModelProviderName.fireworks_ai:
-                litellm_provider_name = "fireworks_ai"
-            case ModelProviderName.amazon_bedrock:
-                litellm_provider_name = "bedrock"
-            case ModelProviderName.azure_openai:
-                litellm_provider_name = "azure"
-            case ModelProviderName.huggingface:
-                litellm_provider_name = "huggingface"
-            case ModelProviderName.vertex:
-                litellm_provider_name = "vertex_ai"
-            case ModelProviderName.together_ai:
-                litellm_provider_name = "together_ai"
-            case ModelProviderName.cerebras:
-                litellm_provider_name = "cerebras"
-            case ModelProviderName.siliconflow_cn:
-                is_custom = True
-            case ModelProviderName.openai_compatible:
-                is_custom = True
-            case ModelProviderName.kiln_custom_registry:
-                is_custom = True
-            case ModelProviderName.kiln_fine_tune:
-                is_custom = True
-            case _:
-                raise_exhaustive_enum_error(provider.name)
-
-        if is_custom:
-            if self._api_base is None:
-                raise ValueError(
-                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
-                )
-            # Use openai as it's only used for format, not url
-            litellm_provider_name = "openai"
-
-        # Sholdn't be possible but keep type checker happy
-        if litellm_provider_name is None:
+        litellm_provider_info = get_litellm_provider_info(self.model_provider())
+        if litellm_provider_info.is_custom and self._api_base is None:
             raise ValueError(
-
+                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
             )

-        self._litellm_model_id =
+        self._litellm_model_id = litellm_provider_info.litellm_model_id
         return self._litellm_model_id

     async def build_completion_kwargs(
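The hunk above removes the per-provider `match` statement from `litellm_model_id()` and delegates to the new `get_litellm_provider_info()` helper in `kiln_ai/utils/litellm.py` (added in this release, +94 lines). Only two attributes of its return value are visible in this diff, `is_custom` and `litellm_model_id`; a rough sketch of that shape, with everything beyond those two names an assumption:

```python
# Sketch only, inferred from the attributes this diff accesses; the real
# implementation lives in kiln_ai/utils/litellm.py. Names other than
# is_custom and litellm_model_id are assumptions.
from dataclasses import dataclass


@dataclass
class ProviderInfoSketch:
    is_custom: bool        # OpenAI-compatible endpoint needing an explicit api_base
    litellm_model_id: str  # e.g. "openrouter/<model>" or "openai/<model>"


def provider_info_sketch(provider_name: str, model_id: str) -> ProviderInfoSketch:
    # Hypothetical stand-in for the match statement removed above: known
    # providers map to a litellm prefix, everything else is treated as a
    # custom OpenAI-compatible endpoint using the "openai" format.
    prefixes = {"openrouter": "openrouter", "openai": "openai", "groq": "groq"}
    prefix = prefixes.get(provider_name)
    if prefix is None:
        return ProviderInfoSketch(is_custom=True, litellm_model_id=f"openai/{model_id}")
    return ProviderInfoSketch(is_custom=False, litellm_model_id=f"{prefix}/{model_id}")
```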
@@ -550,6 +489,21 @@ class LiteLlmAdapter(BaseAdapter):
             completion_kwargs["tools"] = tool_calls
             completion_kwargs["tool_choice"] = "auto"

+        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
+        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
+        if provider.temp_top_p_exclusive:
+            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
+                del completion_kwargs["top_p"]
+            if (
+                "temperature" in completion_kwargs
+                and completion_kwargs["temperature"] == 1.0
+            ):
+                del completion_kwargs["temperature"]
+            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
+                raise ValueError(
+                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
+                )
+
         if not skip_response_format:
             # Response format: json_schema, json_instructions, json_mode, function_calling, etc
             response_format_options = await self.response_format_options()
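The block added above treats 1.0 as "not customized" for providers flagged `temp_top_p_exclusive` (the comment names Claude Opus 4.1 and Sonnet 4.5): default values are dropped, and an error is raised only when both `temperature` and `top_p` carry non-default values. A standalone restatement of that rule, not the adapter's actual code path, which lines up with the parametrized test cases added later in this diff:

```python
# Standalone restatement of the exclusivity rule added above (not the
# adapter's real method): a value of 1.0 counts as "unset" and is dropped.
def resolve_temp_top_p(temperature: float, top_p: float) -> dict[str, float]:
    kwargs: dict[str, float] = {"temperature": temperature, "top_p": top_p}
    if kwargs["top_p"] == 1.0:
        del kwargs["top_p"]
    if kwargs.get("temperature") == 1.0:
        del kwargs["temperature"]
    if "top_p" in kwargs and "temperature" in kwargs:
        raise ValueError("top_p and temperature can not both have custom values")
    return kwargs


# Matches the parametrized cases in the tests added below:
assert resolve_temp_top_p(1.0, 1.0) == {}                    # both defaults dropped
assert resolve_temp_top_p(0.7, 1.0) == {"temperature": 0.7}  # only temperature kept
assert resolve_temp_top_p(1.0, 0.9) == {"top_p": 0.9}        # only top_p kept
# resolve_temp_top_p(0.7, 0.9) raises ValueError
```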
@@ -614,12 +568,12 @@

     async def process_tool_calls(
         self, tool_calls: list[ChatCompletionMessageToolCall] | None
-    ) -> tuple[str | None, list[
+    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
         if tool_calls is None:
             return None, []

         assistant_output_from_toolcall: str | None = None
-        tool_call_response_messages: list[
+        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []

         for tool_call in tool_calls:
             # Kiln "task_response" tool is used for returning structured output via tool calls.
@@ -656,13 +610,24 @@
                     f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                 ) from e

-
+                # Create context with the calling task's allow_saving setting
+                context = ToolCallContext(
+                    allow_saving=self.base_adapter_config.allow_saving
+                )
+                result = await tool.run(context, **parsed_args)
+                if isinstance(result, KilnTaskToolResult):
+                    content = result.output
+                    kiln_task_tool_data = result.kiln_task_tool_data
+                else:
+                    content = result
+                    kiln_task_tool_data = None

                 tool_call_response_messages.append(
-
+                    ChatCompletionToolMessageParamWrapper(
                         role="tool",
                         tool_call_id=tool_call.id,
-                        content=
+                        content=content,
+                        kiln_task_tool_data=kiln_task_tool_data,
                     )
                 )

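The `process_tool_calls` changes above establish the new tool-calling convention: a tool's `run()` now receives a `ToolCallContext` (carrying the calling task's `allow_saving` flag) as its first positional argument, and it may return either a plain string or a `KilnTaskToolResult`, whose `kiln_task_tool_data` is copied onto the resulting tool message. A minimal sketch of a tool written against that convention, using only names visible in this diff; the full base-class contract in `kiln_ai.tools.base_tool` is not shown here and may differ:

```python
# Minimal sketch of a tool matching the call pattern above:
# run(context, **kwargs), where context carries allow_saving.
from kiln_ai.tools.base_tool import ToolCallContext


class EchoTool:
    async def name(self) -> str:
        return "echo"

    async def run(self, context: ToolCallContext, **kwargs) -> str:
        # Plain-string results are wrapped by process_tool_calls with
        # kiln_task_tool_data=None; only a KilnTaskToolResult carries data.
        prefix = "" if context.allow_saving else "(not saved) "
        return prefix + str(kwargs.get("text", ""))
```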
kiln_ai/adapters/model_adapters/test_litellm_adapter.py

@@ -351,7 +351,7 @@ def test_litellm_model_id_unknown_provider(config, mock_task):

     with patch.object(adapter, "model_provider", return_value=mock_provider):
         with patch(
-            "kiln_ai.
+            "kiln_ai.utils.litellm.raise_exhaustive_enum_error"
         ) as mock_raise_error:
             mock_raise_error.side_effect = Exception("Test error")

@@ -405,6 +405,7 @@ async def test_build_completion_kwargs_custom_temperature_top_p(config, mock_tas

     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]

     with (
@@ -446,6 +447,7 @@ async def test_build_completion_kwargs(
     """Test build_completion_kwargs with various configurations"""
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]

     with (
@@ -613,6 +615,7 @@ async def test_build_completion_kwargs_includes_tools(
     """Test build_completion_kwargs includes tools when available_tools has tools"""
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]

     with (
@@ -666,6 +669,7 @@ async def test_build_completion_kwargs_raises_error_with_tools_conflict(
     config.run_config_properties.structured_output_mode = structured_output_mode
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]

     with (
@@ -976,3 +980,77 @@ def test_build_extra_body_enable_thinking(config, mock_task, enable_thinking):
     extra_body = adapter.build_extra_body(provider)

     assert extra_body["enable_thinking"] == enable_thinking
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "temperature,top_p,should_raise,expected_temp,expected_top_p",
+    [
+        (1.0, 1.0, False, None, None),
+        (0.7, 1.0, False, 0.7, None),
+        (1.0, 0.9, False, None, 0.9),
+        (0.7, 0.9, True, None, None),
+        (0.5, 0.5, True, None, None),
+    ],
+)
+async def test_build_completion_kwargs_temp_top_p_exclusive(
+    config, mock_task, temperature, top_p, should_raise, expected_temp, expected_top_p
+):
+    """Test build_completion_kwargs with temp_top_p_exclusive provider flag"""
+    config.run_config_properties.temperature = temperature
+    config.run_config_properties.top_p = top_p
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = True
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="anthropic/test-model"),
+        patch.object(adapter, "build_extra_body", return_value={}),
+        patch.object(adapter, "response_format_options", return_value={}),
+    ):
+        if should_raise:
+            with pytest.raises(
+                ValueError,
+                match="top_p and temperature can not both have custom values",
+            ):
+                await adapter.build_completion_kwargs(mock_provider, messages, None)
+        else:
+            kwargs = await adapter.build_completion_kwargs(
+                mock_provider, messages, None
+            )
+
+            if expected_temp is None:
+                assert "temperature" not in kwargs
+            else:
+                assert kwargs["temperature"] == expected_temp
+
+            if expected_top_p is None:
+                assert "top_p" not in kwargs
+            else:
+                assert kwargs["top_p"] == expected_top_p
+
+
+@pytest.mark.asyncio
+async def test_build_completion_kwargs_temp_top_p_not_exclusive(config, mock_task):
+    """Test build_completion_kwargs with temp_top_p_exclusive=False allows both params"""
+    config.run_config_properties.temperature = 0.7
+    config.run_config_properties.top_p = 0.9
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
+        patch.object(adapter, "build_extra_body", return_value={}),
+        patch.object(adapter, "response_format_options", return_value={}),
+    ):
+        kwargs = await adapter.build_completion_kwargs(mock_provider, messages, None)

+        assert kwargs["temperature"] == 0.7
+        assert kwargs["top_p"] == 0.9
kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py

@@ -18,12 +18,15 @@ from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
 from kiln_ai.datamodel.datamodel_enums import ModelProviderName, StructuredOutputMode
 from kiln_ai.datamodel.task import RunConfigProperties
+from kiln_ai.datamodel.tool_id import ToolId
+from kiln_ai.tools.base_tool import ToolCallContext
 from kiln_ai.tools.built_in_tools.math_tools import (
     AddTool,
     DivideTool,
     MultiplyTool,
     SubtractTool,
 )
+from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
 from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam


@@ -91,6 +94,7 @@ async def run_simple_task_with_tools(
     # Verify that AddTool.run was called with correct parameters
     add_spy.run.assert_called()
     add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
     add_kwargs = add_call_args.kwargs
     assert add_kwargs.get("a") == 2
     assert add_kwargs.get("b") == 2
@@ -126,6 +130,9 @@ async def run_simple_task_with_tools(
     # Verify that MultiplyTool.run was called with correct parameters
     multiply_spy.run.assert_called()
     multiply_call_args = multiply_spy.run.call_args
+    assert multiply_call_args.args[
+        0
+    ].allow_saving  # First arg is ToolCallContext
     multiply_kwargs = multiply_call_args.kwargs
     # Check that multiply was called with a=6, b=10 (or vice versa)
     assert (
@@ -137,6 +144,7 @@ async def run_simple_task_with_tools(
     # Verify that AddTool.run was called with correct parameters
     add_spy.run.assert_called()
     add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
     add_kwargs = add_call_args.kwargs
     # Check that add was called with a=60, b=4 (or vice versa)
     assert (add_kwargs.get("a") == 60 and add_kwargs.get("b") == 4) or (
@@ -482,8 +490,16 @@ async def test_run_model_turn_parallel_tools(tmp_path):
     )

     # Verify both tools were called in parallel
-
-
+    # The context is passed as the first positional argument, not as a keyword argument
+    multiply_spy.run.assert_called_once()
+    multiply_call_args = multiply_spy.run.call_args
+    assert multiply_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert multiply_call_args.kwargs == {"a": 6, "b": 10}
+
+    add_spy.run.assert_called_once()
+    add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert add_call_args.kwargs == {"a": 2, "b": 3}

     # Verify the result structure
     assert isinstance(result, ModelTurnResult)
@@ -596,8 +612,16 @@ async def test_run_model_turn_sequential_tools(tmp_path):
     )

     # Verify tools were called sequentially
-
-
+    # The context is passed as the first positional argument, not as a keyword argument
+    multiply_spy.run.assert_called_once()
+    multiply_call_args = multiply_spy.run.call_args
+    assert multiply_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert multiply_call_args.kwargs == {"a": 6, "b": 10}
+
+    add_spy.run.assert_called_once()
+    add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert add_call_args.kwargs == {"a": 60, "b": 4}

     # Verify the result structure
     assert isinstance(result, ModelTurnResult)
@@ -756,11 +780,59 @@ class MockTool:
             }
         }

-    async def run(self, **kwargs) -> str:
+    async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
         if self._raise_on_run:
             raise self._raise_on_run
         return self._return_value

+    async def id(self) -> ToolId:
+        """Mock implementation of id for testing."""
+        return f"mock_tool_{self._name}"
+
+
+class MockKilnTaskTool:
+    """Mock tool class that returns KilnTaskToolResult for testing"""
+
+    def __init__(
+        self,
+        name: str,
+        raise_on_run: Exception | None = None,
+        output: str = "kiln_task_output",
+        kiln_task_tool_data: str = "project_id:::tool_id:::task_id:::run_id",
+    ):
+        self._name = name
+        self._raise_on_run = raise_on_run
+        self._output = output
+        self._kiln_task_tool_data = kiln_task_tool_data
+
+    async def name(self) -> str:
+        return self._name
+
+    async def toolcall_definition(self) -> dict:
+        return {
+            "function": {
+                "parameters": {
+                    "type": "object",
+                    "properties": {"input": {"type": "string"}},
+                    "required": ["input"],
+                }
+            }
+        }
+
+    async def run(
+        self, context: ToolCallContext | None = None, **kwargs
+    ) -> KilnTaskToolResult:
+        if self._raise_on_run:
+            raise self._raise_on_run
+        return KilnTaskToolResult(
+            output=self._output,
+            kiln_task_tool_data=self._kiln_task_tool_data,
+        )
+
+    async def id(self) -> ToolId:
+        """Mock implementation of id for testing."""
+        return f"mock_kiln_task_tool_{self._name}"
+

 async def test_process_tool_calls_none_input(tmp_path):
     """Test process_tool_calls with None input"""
@@ -879,6 +951,7 @@ async def test_process_tool_calls_normal_tool_success(tmp_path):
         "role": "tool",
         "tool_call_id": "call_1",
         "content": "5",
+        "kiln_task_tool_data": None,
     }


@@ -915,8 +988,10 @@ async def test_process_tool_calls_multiple_normal_tools(tmp_path):
     assert len(tool_messages) == 2
     assert tool_messages[0]["tool_call_id"] == "call_1"
     assert tool_messages[0]["content"] == "5"
+    assert tool_messages[0].get("kiln_task_tool_data") is None
     assert tool_messages[1]["tool_call_id"] == "call_2"
     assert tool_messages[1]["content"] == "6"
+    assert tool_messages[1].get("kiln_task_tool_data") is None


 async def test_process_tool_calls_tool_not_found(tmp_path):
@@ -1072,6 +1147,7 @@ async def test_process_tool_calls_complex_result(tmp_path):
     assert assistant_output is None
     assert len(tool_messages) == 1
     assert tool_messages[0]["content"] == complex_result
+    assert tool_messages[0].get("kiln_task_tool_data") is None


 async def test_process_tool_calls_task_response_with_normal_tools_error(tmp_path):
@@ -1101,3 +1177,41 @@ async def test_process_tool_calls_task_response_with_normal_tools_error(tmp_path
         match="task_response tool call and other tool calls were both provided",
     ):
         await litellm_adapter.process_tool_calls(tool_calls)  # type: ignore
+
+
+async def test_process_tool_calls_kiln_task_tool_result(tmp_path):
+    """Test process_tool_calls with KilnTaskToolResult - tests the new if statement branch"""
+    task = build_test_task(tmp_path)
+    config = LiteLlmConfig(
+        run_config_properties=RunConfigProperties(
+            structured_output_mode=StructuredOutputMode.json_schema,
+            model_name="gpt_4_1_mini",
+            model_provider_name=ModelProviderName.openai,
+            prompt_id="simple_prompt_builder",
+        )
+    )
+    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)
+
+    mock_kiln_task_tool = MockKilnTaskTool(
+        "kiln_task_tool",
+        output="Task completed successfully",
+        kiln_task_tool_data="proj123:::tool456:::task789:::run101",
+    )
+    tool_calls = [MockToolCall("call_1", "kiln_task_tool", '{"input": "test input"}')]
+
+    with patch.object(
+        litellm_adapter, "cached_available_tools", return_value=[mock_kiln_task_tool]
+    ):
+        assistant_output, tool_messages = await litellm_adapter.process_tool_calls(
+            tool_calls  # type: ignore
+        )
+
+    assert assistant_output is None
+    assert len(tool_messages) == 1
+    assert tool_messages[0]["role"] == "tool"
+    assert tool_messages[0]["tool_call_id"] == "call_1"
+    assert tool_messages[0]["content"] == "Task completed successfully"
+    assert (
+        tool_messages[0].get("kiln_task_tool_data")
+        == "proj123:::tool456:::task789:::run101"
+    )
kiln_ai/adapters/model_adapters/test_saving_adapter_results.py

@@ -60,7 +60,9 @@ def test_save_run_isolation(test_task, adapter):
     )

     task_run = adapter.generate_run(
-        input=input_data,
+        input=input_data,
+        input_source=None,
+        run_output=run_output,
     )
     task_run.save_to_file()

@@ -146,7 +148,9 @@ def test_generate_run_non_ascii(test_task, adapter):
     )

     task_run = adapter.generate_run(
-        input=input_data,
+        input=input_data,
+        input_source=None,
+        run_output=run_output,
     )
     task_run.save_to_file()

@@ -256,7 +260,9 @@ def test_properties_for_task_output_custom_values(test_task):
     run_output = RunOutput(output=output_data, intermediate_outputs=None)

     task_run = adapter.generate_run(
-        input=input_data,
+        input=input_data,
+        input_source=None,
+        run_output=run_output,
     )
     task_run.save_to_file()

kiln_ai/adapters/model_adapters/test_structured_output.py

@@ -175,15 +175,12 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:

     # Check reasoning models
     assert a._model_provider is not None
-    if
-
-
-
-
-
-    else:
-        assert "reasoning" in run.intermediate_outputs
-        assert isinstance(run.intermediate_outputs["reasoning"], str)
+    if (
+        a._model_provider.reasoning_capable
+        and not a._model_provider.reasoning_optional_for_structured_output
+    ):
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)


 def build_structured_input_test_task(tmp_path: Path):
@@ -344,6 +341,7 @@ async def test_all_built_in_models_structured_input_mocked(tmp_path):
     mock_config = Mock()
     mock_config.open_ai_api_key = "mock_api_key"
     mock_config.user_id = "test_user"
+    mock_config.groq_api_key = "mock_api_key"

     with (
         patch(
@@ -398,6 +396,7 @@ async def test_structured_input_cot_prompt_builder_mocked(tmp_path):
     mock_config = Mock()
     mock_config.open_ai_api_key = "mock_api_key"
     mock_config.user_id = "test_user"
+    mock_config.groq_api_key = "mock_api_key"

     with (
         patch(
@@ -456,7 +455,7 @@ When asked for a final result, this is the format (for an equilateral example):
     """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    response,
+    response, _, _ = await run_structured_input_task_no_validation(
         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
     )
