kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of kiln-ai has been flagged as potentially problematic.

Files changed (133)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/model_adapters/litellm_adapter.py

@@ -11,13 +11,8 @@ from litellm.types.utils import (
     Choices,
     ModelResponse,
 )
-from litellm.types.utils import (
-    Message as LiteLLMMessage,
-)
+from litellm.types.utils import Message as LiteLLMMessage
 from litellm.types.utils import Usage as LiteLlmUsage
-from openai.types.chat import (
-    ChatCompletionToolMessageParam,
-)
 from openai.types.chat.chat_completion_message_tool_call_param import (
     ChatCompletionMessageToolCallParam,
 )
@@ -36,11 +31,14 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
 )
 from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
-from kiln_ai.tools.base_tool import KilnToolInterface
+from kiln_ai.tools.base_tool import KilnToolInterface, ToolCallContext
+from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.litellm import get_litellm_provider_info
 from kiln_ai.utils.open_ai_types import (
     ChatCompletionAssistantMessageParamWrapper,
     ChatCompletionMessageParam,
+    ChatCompletionToolMessageParamWrapper,
 )
 
 MAX_CALLS_PER_TURN = 10
@@ -447,75 +445,16 @@ class LiteLlmAdapter(BaseAdapter):
     def litellm_model_id(self) -> str:
         # The model ID is an interesting combination of format and url endpoint.
         # It specifics the provider URL/host, but this is overridden if you manually set an api url
-
         if self._litellm_model_id:
             return self._litellm_model_id
 
-        provider = self.model_provider()
-        if not provider.model_id:
-            raise ValueError("Model ID is required for OpenAI compatible models")
-
-        litellm_provider_name: str | None = None
-        is_custom = False
-        match provider.name:
-            case ModelProviderName.openrouter:
-                litellm_provider_name = "openrouter"
-            case ModelProviderName.openai:
-                litellm_provider_name = "openai"
-            case ModelProviderName.groq:
-                litellm_provider_name = "groq"
-            case ModelProviderName.anthropic:
-                litellm_provider_name = "anthropic"
-            case ModelProviderName.ollama:
-                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
-                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
-                is_custom = True
-            case ModelProviderName.docker_model_runner:
-                # Docker Model Runner uses OpenAI-compatible API, similar to Ollama
-                # We want direct control over the requests for features like response_format=json_schema
-                is_custom = True
-            case ModelProviderName.gemini_api:
-                litellm_provider_name = "gemini"
-            case ModelProviderName.fireworks_ai:
-                litellm_provider_name = "fireworks_ai"
-            case ModelProviderName.amazon_bedrock:
-                litellm_provider_name = "bedrock"
-            case ModelProviderName.azure_openai:
-                litellm_provider_name = "azure"
-            case ModelProviderName.huggingface:
-                litellm_provider_name = "huggingface"
-            case ModelProviderName.vertex:
-                litellm_provider_name = "vertex_ai"
-            case ModelProviderName.together_ai:
-                litellm_provider_name = "together_ai"
-            case ModelProviderName.cerebras:
-                litellm_provider_name = "cerebras"
-            case ModelProviderName.siliconflow_cn:
-                is_custom = True
-            case ModelProviderName.openai_compatible:
-                is_custom = True
-            case ModelProviderName.kiln_custom_registry:
-                is_custom = True
-            case ModelProviderName.kiln_fine_tune:
-                is_custom = True
-            case _:
-                raise_exhaustive_enum_error(provider.name)
-
-        if is_custom:
-            if self._api_base is None:
-                raise ValueError(
-                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
-                )
-            # Use openai as it's only used for format, not url
-            litellm_provider_name = "openai"
-
-        # Sholdn't be possible but keep type checker happy
-        if litellm_provider_name is None:
+        litellm_provider_info = get_litellm_provider_info(self.model_provider())
+        if litellm_provider_info.is_custom and self._api_base is None:
             raise ValueError(
-                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
+                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
             )
 
-        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
+        self._litellm_model_id = litellm_provider_info.litellm_model_id
         return self._litellm_model_id
 
     async def build_completion_kwargs(
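
Note: the hunk above replaces the adapter's inline provider-to-litellm mapping with a call to get_litellm_provider_info from the new kiln_ai/utils/litellm.py (added in this release, +94 lines, not shown here). The adapter only relies on two attributes of the returned object: is_custom and litellm_model_id. The sketch below reconstructs that contract from what the adapter reads; apart from get_litellm_provider_info, is_custom, and litellm_model_id, the names and structure are assumptions rather than the actual implementation.

# Hypothetical sketch of the contract behind get_litellm_provider_info; the real
# implementation lives in kiln_ai/utils/litellm.py and is not part of this hunk.
from dataclasses import dataclass


@dataclass
class LitellmProviderInfo:
    is_custom: bool        # OpenAI-compatible endpoint that requires an explicit base URL
    litellm_model_id: str  # "<litellm prefix>/<model_id>", the string litellm expects


def get_litellm_provider_info(provider) -> LitellmProviderInfo:
    """Map a Kiln model provider to the prefix litellm uses in its model string."""
    prefixes = {
        "openrouter": "openrouter",
        "openai": "openai",
        "groq": "groq",
        "anthropic": "anthropic",
        "gemini_api": "gemini",
        "fireworks_ai": "fireworks_ai",
        "amazon_bedrock": "bedrock",
        "azure_openai": "azure",
        "huggingface": "huggingface",
        "vertex": "vertex_ai",
        "together_ai": "together_ai",
        "cerebras": "cerebras",
    }
    prefix = prefixes.get(provider.name)
    if prefix is None:
        # ollama, docker_model_runner, siliconflow_cn, openai_compatible,
        # kiln_custom_registry and kiln_fine_tune use the OpenAI request format only;
        # their URL comes from the adapter's explicit api_base (the real helper handles
        # truly unknown providers via raise_exhaustive_enum_error).
        return LitellmProviderInfo(is_custom=True, litellm_model_id=f"openai/{provider.model_id}")
    return LitellmProviderInfo(is_custom=False, litellm_model_id=f"{prefix}/{provider.model_id}")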
@@ -550,6 +489,21 @@ class LiteLlmAdapter(BaseAdapter):
             completion_kwargs["tools"] = tool_calls
             completion_kwargs["tool_choice"] = "auto"
 
+        # Special condition for Claude Opus 4.1 and Sonnet 4.5, where we can only specify top_p or temp, not both.
+        # Remove default values (1.0) prioritizing anything the user customized, then error with helpful message if they are both custom.
+        if provider.temp_top_p_exclusive:
+            if "top_p" in completion_kwargs and completion_kwargs["top_p"] == 1.0:
+                del completion_kwargs["top_p"]
+            if (
+                "temperature" in completion_kwargs
+                and completion_kwargs["temperature"] == 1.0
+            ):
+                del completion_kwargs["temperature"]
+            if "top_p" in completion_kwargs and "temperature" in completion_kwargs:
+                raise ValueError(
+                    "top_p and temperature can not both have custom values for this model. This is a restriction from the model provider. Please set only one of them to a custom value (not 1.0)."
+                )
+
         if not skip_response_format:
             # Response format: json_schema, json_instructions, json_mode, function_calling, etc
             response_format_options = await self.response_format_options()
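
Note: for providers flagged temp_top_p_exclusive (per the comment above, Claude Opus 4.1 and Sonnet 4.5), the added block treats 1.0 as the uncustomized default, drops whichever of the two values is still at that default, and raises if both were customized. A standalone sketch of the same rule follows; the function name and dict shape are illustrative only, while the defaults and the error match the code above.

# Minimal sketch of the temp/top_p exclusivity rule, assuming 1.0 means "not customized"
# as in the adapter code above. resolve_sampling_params is an illustrative name.
def resolve_sampling_params(temperature: float, top_p: float, temp_top_p_exclusive: bool) -> dict:
    params = {"temperature": temperature, "top_p": top_p}
    if temp_top_p_exclusive:
        if params["top_p"] == 1.0:
            del params["top_p"]
        if params.get("temperature") == 1.0:
            del params["temperature"]
        if "temperature" in params and "top_p" in params:
            raise ValueError(
                "top_p and temperature can not both have custom values for this model."
            )
    return params


# A provider with the flag set accepts one customized knob but not both:
assert resolve_sampling_params(0.7, 1.0, temp_top_p_exclusive=True) == {"temperature": 0.7}
assert resolve_sampling_params(1.0, 0.9, temp_top_p_exclusive=True) == {"top_p": 0.9}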
@@ -614,12 +568,12 @@
 
     async def process_tool_calls(
         self, tool_calls: list[ChatCompletionMessageToolCall] | None
-    ) -> tuple[str | None, list[ChatCompletionToolMessageParam]]:
+    ) -> tuple[str | None, list[ChatCompletionToolMessageParamWrapper]]:
         if tool_calls is None:
             return None, []
 
         assistant_output_from_toolcall: str | None = None
-        tool_call_response_messages: list[ChatCompletionToolMessageParam] = []
+        tool_call_response_messages: list[ChatCompletionToolMessageParamWrapper] = []
 
         for tool_call in tool_calls:
             # Kiln "task_response" tool is used for returning structured output via tool calls.
@@ -656,13 +610,24 @@
                     f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
                 ) from e
 
-            result = await tool.run(**parsed_args)
+            # Create context with the calling task's allow_saving setting
+            context = ToolCallContext(
+                allow_saving=self.base_adapter_config.allow_saving
+            )
+            result = await tool.run(context, **parsed_args)
+            if isinstance(result, KilnTaskToolResult):
+                content = result.output
+                kiln_task_tool_data = result.kiln_task_tool_data
+            else:
+                content = result
+                kiln_task_tool_data = None
 
             tool_call_response_messages.append(
-                ChatCompletionToolMessageParam(
+                ChatCompletionToolMessageParamWrapper(
                     role="tool",
                     tool_call_id=tool_call.id,
-                    content=result,
+                    content=content,
+                    kiln_task_tool_data=kiln_task_tool_data,
                 )
             )
 
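
Note: with the change above, every tool now receives a ToolCallContext as its first positional argument, and a tool's run() may return either a plain string or a KilnTaskToolResult whose kiln_task_tool_data is carried on the tool message via ChatCompletionToolMessageParamWrapper. A condensed sketch of that flow follows; call_tool and the returned dict are illustrative, while the context and result types and their fields come from this diff.

# Condensed sketch of the new tool-call flow. ToolCallContext, KilnTaskToolResult and
# their fields are taken from this diff; call_tool itself is illustrative, not package API.
from kiln_ai.tools.base_tool import ToolCallContext
from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult


async def call_tool(tool, tool_call_id: str, allow_saving: bool, **parsed_args) -> dict:
    # The context is the first positional argument, followed by the validated tool arguments.
    context = ToolCallContext(allow_saving=allow_saving)
    result = await tool.run(context, **parsed_args)
    if isinstance(result, KilnTaskToolResult):
        # Kiln task tools return their text output plus data linking back to the task run.
        content, task_data = result.output, result.kiln_task_tool_data
    else:
        content, task_data = result, None
    return {
        "role": "tool",
        "tool_call_id": tool_call_id,
        "content": content,
        "kiln_task_tool_data": task_data,
    }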

kiln_ai/adapters/model_adapters/test_litellm_adapter.py

@@ -351,7 +351,7 @@ def test_litellm_model_id_unknown_provider(config, mock_task):
 
     with patch.object(adapter, "model_provider", return_value=mock_provider):
         with patch(
-            "kiln_ai.adapters.model_adapters.litellm_adapter.raise_exhaustive_enum_error"
+            "kiln_ai.utils.litellm.raise_exhaustive_enum_error"
         ) as mock_raise_error:
             mock_raise_error.side_effect = Exception("Test error")
 
@@ -405,6 +405,7 @@ async def test_build_completion_kwargs_custom_temperature_top_p(config, mock_tas
 
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]
 
     with (
@@ -446,6 +447,7 @@ async def test_build_completion_kwargs(
     """Test build_completion_kwargs with various configurations"""
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]
 
     with (
@@ -613,6 +615,7 @@ async def test_build_completion_kwargs_includes_tools(
     """Test build_completion_kwargs includes tools when available_tools has tools"""
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]
 
     with (
@@ -666,6 +669,7 @@ async def test_build_completion_kwargs_raises_error_with_tools_conflict(
     config.run_config_properties.structured_output_mode = structured_output_mode
     adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
     mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
     messages = [{"role": "user", "content": "Hello"}]
 
     with (
@@ -976,3 +980,77 @@ def test_build_extra_body_enable_thinking(config, mock_task, enable_thinking):
     extra_body = adapter.build_extra_body(provider)
 
     assert extra_body["enable_thinking"] == enable_thinking
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "temperature,top_p,should_raise,expected_temp,expected_top_p",
+    [
+        (1.0, 1.0, False, None, None),
+        (0.7, 1.0, False, 0.7, None),
+        (1.0, 0.9, False, None, 0.9),
+        (0.7, 0.9, True, None, None),
+        (0.5, 0.5, True, None, None),
+    ],
+)
+async def test_build_completion_kwargs_temp_top_p_exclusive(
+    config, mock_task, temperature, top_p, should_raise, expected_temp, expected_top_p
+):
+    """Test build_completion_kwargs with temp_top_p_exclusive provider flag"""
+    config.run_config_properties.temperature = temperature
+    config.run_config_properties.top_p = top_p
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = True
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="anthropic/test-model"),
+        patch.object(adapter, "build_extra_body", return_value={}),
+        patch.object(adapter, "response_format_options", return_value={}),
+    ):
+        if should_raise:
+            with pytest.raises(
+                ValueError,
+                match="top_p and temperature can not both have custom values",
+            ):
+                await adapter.build_completion_kwargs(mock_provider, messages, None)
+        else:
+            kwargs = await adapter.build_completion_kwargs(
+                mock_provider, messages, None
+            )
+
+            if expected_temp is None:
+                assert "temperature" not in kwargs
+            else:
+                assert kwargs["temperature"] == expected_temp
+
+            if expected_top_p is None:
+                assert "top_p" not in kwargs
+            else:
+                assert kwargs["top_p"] == expected_top_p
+
+
+@pytest.mark.asyncio
+async def test_build_completion_kwargs_temp_top_p_not_exclusive(config, mock_task):
+    """Test build_completion_kwargs with temp_top_p_exclusive=False allows both params"""
+    config.run_config_properties.temperature = 0.7
+    config.run_config_properties.top_p = 0.9
+
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    mock_provider.temp_top_p_exclusive = False
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
+        patch.object(adapter, "build_extra_body", return_value={}),
+        patch.object(adapter, "response_format_options", return_value={}),
+    ):
+        kwargs = await adapter.build_completion_kwargs(mock_provider, messages, None)
+
+        assert kwargs["temperature"] == 0.7
+        assert kwargs["top_p"] == 0.9

kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py

@@ -18,12 +18,15 @@ from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
 from kiln_ai.datamodel.datamodel_enums import ModelProviderName, StructuredOutputMode
 from kiln_ai.datamodel.task import RunConfigProperties
+from kiln_ai.datamodel.tool_id import ToolId
+from kiln_ai.tools.base_tool import ToolCallContext
 from kiln_ai.tools.built_in_tools.math_tools import (
     AddTool,
     DivideTool,
     MultiplyTool,
     SubtractTool,
 )
+from kiln_ai.tools.kiln_task_tool import KilnTaskToolResult
 from kiln_ai.utils.open_ai_types import ChatCompletionMessageParam
 
 
@@ -91,6 +94,7 @@ async def run_simple_task_with_tools(
     # Verify that AddTool.run was called with correct parameters
     add_spy.run.assert_called()
     add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
     add_kwargs = add_call_args.kwargs
     assert add_kwargs.get("a") == 2
     assert add_kwargs.get("b") == 2
@@ -126,6 +130,9 @@ async def run_simple_task_with_tools(
     # Verify that MultiplyTool.run was called with correct parameters
     multiply_spy.run.assert_called()
     multiply_call_args = multiply_spy.run.call_args
+    assert multiply_call_args.args[
+        0
+    ].allow_saving  # First arg is ToolCallContext
     multiply_kwargs = multiply_call_args.kwargs
     # Check that multiply was called with a=6, b=10 (or vice versa)
     assert (
@@ -137,6 +144,7 @@ async def run_simple_task_with_tools(
     # Verify that AddTool.run was called with correct parameters
     add_spy.run.assert_called()
     add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
     add_kwargs = add_call_args.kwargs
     # Check that add was called with a=60, b=4 (or vice versa)
     assert (add_kwargs.get("a") == 60 and add_kwargs.get("b") == 4) or (
@@ -482,8 +490,16 @@ async def test_run_model_turn_parallel_tools(tmp_path):
     )
 
     # Verify both tools were called in parallel
-    multiply_spy.run.assert_called_once_with(a=6, b=10)
-    add_spy.run.assert_called_once_with(a=2, b=3)
+    # The context is passed as the first positional argument, not as a keyword argument
+    multiply_spy.run.assert_called_once()
+    multiply_call_args = multiply_spy.run.call_args
+    assert multiply_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert multiply_call_args.kwargs == {"a": 6, "b": 10}
+
+    add_spy.run.assert_called_once()
+    add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert add_call_args.kwargs == {"a": 2, "b": 3}
 
     # Verify the result structure
     assert isinstance(result, ModelTurnResult)
@@ -596,8 +612,16 @@ async def test_run_model_turn_sequential_tools(tmp_path):
     )
 
     # Verify tools were called sequentially
-    multiply_spy.run.assert_called_once_with(a=6, b=10)
-    add_spy.run.assert_called_once_with(a=60, b=4)
+    # The context is passed as the first positional argument, not as a keyword argument
+    multiply_spy.run.assert_called_once()
+    multiply_call_args = multiply_spy.run.call_args
+    assert multiply_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert multiply_call_args.kwargs == {"a": 6, "b": 10}
+
+    add_spy.run.assert_called_once()
+    add_call_args = add_spy.run.call_args
+    assert add_call_args.args[0].allow_saving  # First arg is ToolCallContext
+    assert add_call_args.kwargs == {"a": 60, "b": 4}
 
     # Verify the result structure
     assert isinstance(result, ModelTurnResult)
@@ -756,11 +780,59 @@ class MockTool:
             }
         }
 
-    async def run(self, **kwargs) -> str:
+    async def run(self, context: ToolCallContext | None = None, **kwargs) -> str:
         if self._raise_on_run:
             raise self._raise_on_run
         return self._return_value
 
+    async def id(self) -> ToolId:
+        """Mock implementation of id for testing."""
+        return f"mock_tool_{self._name}"
+
+
+class MockKilnTaskTool:
+    """Mock tool class that returns KilnTaskToolResult for testing"""
+
+    def __init__(
+        self,
+        name: str,
+        raise_on_run: Exception | None = None,
+        output: str = "kiln_task_output",
+        kiln_task_tool_data: str = "project_id:::tool_id:::task_id:::run_id",
+    ):
+        self._name = name
+        self._raise_on_run = raise_on_run
+        self._output = output
+        self._kiln_task_tool_data = kiln_task_tool_data
+
+    async def name(self) -> str:
+        return self._name
+
+    async def toolcall_definition(self) -> dict:
+        return {
+            "function": {
+                "parameters": {
+                    "type": "object",
+                    "properties": {"input": {"type": "string"}},
+                    "required": ["input"],
+                }
+            }
+        }
+
+    async def run(
+        self, context: ToolCallContext | None = None, **kwargs
+    ) -> KilnTaskToolResult:
+        if self._raise_on_run:
+            raise self._raise_on_run
+        return KilnTaskToolResult(
+            output=self._output,
+            kiln_task_tool_data=self._kiln_task_tool_data,
+        )
+
+    async def id(self) -> ToolId:
+        """Mock implementation of id for testing."""
+        return f"mock_kiln_task_tool_{self._name}"
+
 
 async def test_process_tool_calls_none_input(tmp_path):
     """Test process_tool_calls with None input"""
@@ -879,6 +951,7 @@ async def test_process_tool_calls_normal_tool_success(tmp_path):
         "role": "tool",
         "tool_call_id": "call_1",
         "content": "5",
+        "kiln_task_tool_data": None,
     }
 
 
@@ -915,8 +988,10 @@ async def test_process_tool_calls_multiple_normal_tools(tmp_path):
     assert len(tool_messages) == 2
     assert tool_messages[0]["tool_call_id"] == "call_1"
     assert tool_messages[0]["content"] == "5"
+    assert tool_messages[0].get("kiln_task_tool_data") is None
     assert tool_messages[1]["tool_call_id"] == "call_2"
    assert tool_messages[1]["content"] == "6"
+    assert tool_messages[1].get("kiln_task_tool_data") is None
 
 
 async def test_process_tool_calls_tool_not_found(tmp_path):
@@ -1072,6 +1147,7 @@ async def test_process_tool_calls_complex_result(tmp_path):
     assert assistant_output is None
     assert len(tool_messages) == 1
     assert tool_messages[0]["content"] == complex_result
+    assert tool_messages[0].get("kiln_task_tool_data") is None
 
 
 async def test_process_tool_calls_task_response_with_normal_tools_error(tmp_path):
@@ -1101,3 +1177,41 @@ async def test_process_tool_calls_task_response_with_normal_tools_error(tmp_path
         match="task_response tool call and other tool calls were both provided",
     ):
         await litellm_adapter.process_tool_calls(tool_calls)  # type: ignore
+
+
+async def test_process_tool_calls_kiln_task_tool_result(tmp_path):
+    """Test process_tool_calls with KilnTaskToolResult - tests the new if statement branch"""
+    task = build_test_task(tmp_path)
+    config = LiteLlmConfig(
+        run_config_properties=RunConfigProperties(
+            structured_output_mode=StructuredOutputMode.json_schema,
+            model_name="gpt_4_1_mini",
+            model_provider_name=ModelProviderName.openai,
+            prompt_id="simple_prompt_builder",
+        )
+    )
+    litellm_adapter = LiteLlmAdapter(config=config, kiln_task=task)
+
+    mock_kiln_task_tool = MockKilnTaskTool(
+        "kiln_task_tool",
+        output="Task completed successfully",
+        kiln_task_tool_data="proj123:::tool456:::task789:::run101",
+    )
+    tool_calls = [MockToolCall("call_1", "kiln_task_tool", '{"input": "test input"}')]
+
+    with patch.object(
+        litellm_adapter, "cached_available_tools", return_value=[mock_kiln_task_tool]
+    ):
+        assistant_output, tool_messages = await litellm_adapter.process_tool_calls(
+            tool_calls  # type: ignore
+        )
+
+    assert assistant_output is None
+    assert len(tool_messages) == 1
+    assert tool_messages[0]["role"] == "tool"
+    assert tool_messages[0]["tool_call_id"] == "call_1"
+    assert tool_messages[0]["content"] == "Task completed successfully"
+    assert (
+        tool_messages[0].get("kiln_task_tool_data")
+        == "proj123:::tool456:::task789:::run101"
+    )

kiln_ai/adapters/model_adapters/test_saving_adapter_results.py

@@ -60,7 +60,9 @@ def test_save_run_isolation(test_task, adapter):
     )
 
     task_run = adapter.generate_run(
-        input=input_data, input_source=None, run_output=run_output
+        input=input_data,
+        input_source=None,
+        run_output=run_output,
     )
     task_run.save_to_file()
 
@@ -146,7 +148,9 @@ def test_generate_run_non_ascii(test_task, adapter):
     )
 
     task_run = adapter.generate_run(
-        input=input_data, input_source=None, run_output=run_output
+        input=input_data,
+        input_source=None,
+        run_output=run_output,
     )
     task_run.save_to_file()
 
@@ -256,7 +260,9 @@ def test_properties_for_task_output_custom_values(test_task):
     run_output = RunOutput(output=output_data, intermediate_outputs=None)
 
     task_run = adapter.generate_run(
-        input=input_data, input_source=None, run_output=run_output
+        input=input_data,
+        input_source=None,
+        run_output=run_output,
     )
     task_run.save_to_file()
 

kiln_ai/adapters/model_adapters/test_structured_output.py

@@ -175,15 +175,12 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
 
     # Check reasoning models
     assert a._model_provider is not None
-    if a._model_provider.reasoning_capable:
-        # some providers have reasoning_capable models that do not return the reasoning
-        # for structured output responses (they provide it only for non-structured output)
-        if a._model_provider.reasoning_optional_for_structured_output:
-            # models may be updated to include the reasoning in the future
-            assert "reasoning" not in run.intermediate_outputs
-        else:
-            assert "reasoning" in run.intermediate_outputs
-            assert isinstance(run.intermediate_outputs["reasoning"], str)
+    if (
+        a._model_provider.reasoning_capable
+        and not a._model_provider.reasoning_optional_for_structured_output
+    ):
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)
 
 
 def build_structured_input_test_task(tmp_path: Path):
@@ -344,6 +341,7 @@ async def test_all_built_in_models_structured_input_mocked(tmp_path):
     mock_config = Mock()
     mock_config.open_ai_api_key = "mock_api_key"
     mock_config.user_id = "test_user"
+    mock_config.groq_api_key = "mock_api_key"
 
     with (
         patch(
@@ -398,6 +396,7 @@ async def test_structured_input_cot_prompt_builder_mocked(tmp_path):
     mock_config = Mock()
     mock_config.open_ai_api_key = "mock_api_key"
     mock_config.user_id = "test_user"
+    mock_config.groq_api_key = "mock_api_key"
 
     with (
         patch(
@@ -456,7 +455,7 @@ When asked for a final result, this is the format (for an equilateral example):
     """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    response, adapter, _ = await run_structured_input_task_no_validation(
+    response, _, _ = await run_structured_input_task_no_validation(
         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
     )