kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (158)
  1. kiln_ai/adapters/__init__.py +8 -2
  2. kiln_ai/adapters/adapter_registry.py +43 -208
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/chunkers/__init__.py +13 -0
  6. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  7. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  8. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  9. kiln_ai/adapters/chunkers/helpers.py +23 -0
  10. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  11. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  12. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  13. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  14. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  15. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  16. kiln_ai/adapters/embedding/__init__.py +0 -0
  17. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  18. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  19. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  20. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  21. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  22. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  23. kiln_ai/adapters/eval/base_eval.py +2 -2
  24. kiln_ai/adapters/eval/eval_runner.py +9 -3
  25. kiln_ai/adapters/eval/g_eval.py +2 -2
  26. kiln_ai/adapters/eval/test_base_eval.py +2 -4
  27. kiln_ai/adapters/eval/test_g_eval.py +4 -5
  28. kiln_ai/adapters/extractors/__init__.py +18 -0
  29. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  30. kiln_ai/adapters/extractors/encoding.py +20 -0
  31. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  32. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  33. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  34. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  35. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  36. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  37. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  38. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  39. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  40. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  41. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  42. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  43. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  44. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  45. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  46. kiln_ai/adapters/ml_model_list.py +761 -37
  47. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  48. kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
  49. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  50. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
  51. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  52. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  53. kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
  54. kiln_ai/adapters/ollama_tools.py +69 -12
  55. kiln_ai/adapters/parsers/__init__.py +1 -1
  56. kiln_ai/adapters/provider_tools.py +205 -47
  57. kiln_ai/adapters/rag/deduplication.py +49 -0
  58. kiln_ai/adapters/rag/progress.py +252 -0
  59. kiln_ai/adapters/rag/rag_runners.py +844 -0
  60. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  61. kiln_ai/adapters/rag/test_progress.py +785 -0
  62. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  63. kiln_ai/adapters/remote_config.py +80 -8
  64. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  65. kiln_ai/adapters/run_output.py +3 -0
  66. kiln_ai/adapters/test_adapter_registry.py +657 -85
  67. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  68. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  69. kiln_ai/adapters/test_ml_model_list.py +251 -1
  70. kiln_ai/adapters/test_ollama_tools.py +340 -1
  71. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  72. kiln_ai/adapters/test_prompt_builders.py +1 -1
  73. kiln_ai/adapters/test_provider_tools.py +254 -8
  74. kiln_ai/adapters/test_remote_config.py +651 -58
  75. kiln_ai/adapters/vector_store/__init__.py +1 -0
  76. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  77. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  78. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  79. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  80. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  81. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  82. kiln_ai/datamodel/__init__.py +39 -34
  83. kiln_ai/datamodel/basemodel.py +170 -1
  84. kiln_ai/datamodel/chunk.py +158 -0
  85. kiln_ai/datamodel/datamodel_enums.py +28 -0
  86. kiln_ai/datamodel/embedding.py +64 -0
  87. kiln_ai/datamodel/eval.py +1 -1
  88. kiln_ai/datamodel/external_tool_server.py +298 -0
  89. kiln_ai/datamodel/extraction.py +303 -0
  90. kiln_ai/datamodel/json_schema.py +25 -10
  91. kiln_ai/datamodel/project.py +40 -1
  92. kiln_ai/datamodel/rag.py +79 -0
  93. kiln_ai/datamodel/registry.py +0 -15
  94. kiln_ai/datamodel/run_config.py +62 -0
  95. kiln_ai/datamodel/task.py +2 -77
  96. kiln_ai/datamodel/task_output.py +6 -1
  97. kiln_ai/datamodel/task_run.py +41 -0
  98. kiln_ai/datamodel/test_attachment.py +649 -0
  99. kiln_ai/datamodel/test_basemodel.py +4 -4
  100. kiln_ai/datamodel/test_chunk_models.py +317 -0
  101. kiln_ai/datamodel/test_dataset_split.py +1 -1
  102. kiln_ai/datamodel/test_embedding_models.py +448 -0
  103. kiln_ai/datamodel/test_eval_model.py +6 -6
  104. kiln_ai/datamodel/test_example_models.py +175 -0
  105. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  106. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  107. kiln_ai/datamodel/test_extraction_model.py +470 -0
  108. kiln_ai/datamodel/test_rag.py +641 -0
  109. kiln_ai/datamodel/test_registry.py +8 -3
  110. kiln_ai/datamodel/test_task.py +15 -47
  111. kiln_ai/datamodel/test_tool_id.py +320 -0
  112. kiln_ai/datamodel/test_vector_store.py +320 -0
  113. kiln_ai/datamodel/tool_id.py +105 -0
  114. kiln_ai/datamodel/vector_store.py +141 -0
  115. kiln_ai/tools/__init__.py +8 -0
  116. kiln_ai/tools/base_tool.py +82 -0
  117. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  118. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  119. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  120. kiln_ai/tools/mcp_server_tool.py +95 -0
  121. kiln_ai/tools/mcp_session_manager.py +246 -0
  122. kiln_ai/tools/rag_tools.py +157 -0
  123. kiln_ai/tools/test_base_tools.py +199 -0
  124. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  125. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  126. kiln_ai/tools/test_rag_tools.py +848 -0
  127. kiln_ai/tools/test_tool_registry.py +562 -0
  128. kiln_ai/tools/tool_registry.py +85 -0
  129. kiln_ai/utils/__init__.py +3 -0
  130. kiln_ai/utils/async_job_runner.py +62 -17
  131. kiln_ai/utils/config.py +24 -2
  132. kiln_ai/utils/env.py +15 -0
  133. kiln_ai/utils/filesystem.py +14 -0
  134. kiln_ai/utils/filesystem_cache.py +60 -0
  135. kiln_ai/utils/litellm.py +94 -0
  136. kiln_ai/utils/lock.py +100 -0
  137. kiln_ai/utils/mime_type.py +38 -0
  138. kiln_ai/utils/open_ai_types.py +94 -0
  139. kiln_ai/utils/pdf_utils.py +38 -0
  140. kiln_ai/utils/project_utils.py +17 -0
  141. kiln_ai/utils/test_async_job_runner.py +151 -35
  142. kiln_ai/utils/test_config.py +138 -1
  143. kiln_ai/utils/test_env.py +142 -0
  144. kiln_ai/utils/test_filesystem_cache.py +316 -0
  145. kiln_ai/utils/test_litellm.py +206 -0
  146. kiln_ai/utils/test_lock.py +185 -0
  147. kiln_ai/utils/test_mime_type.py +66 -0
  148. kiln_ai/utils/test_open_ai_types.py +131 -0
  149. kiln_ai/utils/test_pdf_utils.py +73 -0
  150. kiln_ai/utils/test_uuid.py +111 -0
  151. kiln_ai/utils/test_validation.py +524 -0
  152. kiln_ai/utils/uuid.py +9 -0
  153. kiln_ai/utils/validation.py +90 -0
  154. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
  155. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  156. kiln_ai-0.19.0.dist-info/RECORD +0 -115
  157. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  158. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,9 +1,22 @@
+import copy
+import json
 import logging
-from typing import Any, Dict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, TypeAlias, Union
 
 import litellm
-from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
+from litellm.types.utils import (
+    ChatCompletionMessageToolCall,
+    ChoiceLogprobs,
+    Choices,
+    ModelResponse,
+)
+from litellm.types.utils import Message as LiteLLMMessage
 from litellm.types.utils import Usage as LiteLlmUsage
+from openai.types.chat import ChatCompletionToolMessageParam
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    ChatCompletionMessageToolCallParam,
+)
 
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.ml_model_list import (
@@ -18,11 +31,33 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
     Usage,
 )
 from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
-from kiln_ai.datamodel.task import run_config_from_run_config_properties
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
+from kiln_ai.tools.base_tool import KilnToolInterface
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.litellm import get_litellm_provider_info
+from kiln_ai.utils.open_ai_types import (
+    ChatCompletionAssistantMessageParamWrapper,
+    ChatCompletionMessageParam,
+)
+
+MAX_CALLS_PER_TURN = 10
+MAX_TOOL_CALLS_PER_TURN = 30
 
 logger = logging.getLogger(__name__)
 
+ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
+    ChatCompletionMessageParam, LiteLLMMessage
+]
+
+
+@dataclass
+class ModelTurnResult:
+    assistant_message: str
+    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
+    model_response: ModelResponse | None
+    model_choice: Choices | None
+    usage: Usage
+
 
 class LiteLlmAdapter(BaseAdapter):
     def __init__(
@@ -36,117 +71,226 @@ class LiteLlmAdapter(BaseAdapter):
         self._api_base = config.base_url
         self._headers = config.default_headers
         self._litellm_model_id: str | None = None
+        self._cached_available_tools: list[KilnToolInterface] | None = None
 
-        # Create a RunConfig, adding the task to the RunConfigProperties
-        run_config = run_config_from_run_config_properties(
+        super().__init__(
             task=kiln_task,
-            run_config_properties=config.run_config_properties,
+            run_config=config.run_config_properties,
+            config=base_adapter_config,
         )
 
-        super().__init__(
-            run_config=run_config,
-            config=base_adapter_config,
+    async def _run_model_turn(
+        self,
+        provider: KilnModelProvider,
+        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
+        top_logprobs: int | None,
+        skip_response_format: bool,
+    ) -> ModelTurnResult:
+        """
+        Call the model for a single top level turn: from user message to agent message.
+
+        It may handle iterations of tool calls between the user/agent messages if needed.
+        """
+
+        usage = Usage()
+        messages = list(prior_messages)
+        tool_calls_count = 0
+
+        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
+            # Build completion kwargs for tool calls
+            completion_kwargs = await self.build_completion_kwargs(
+                provider,
+                # Pass a copy, as acompletion mutates objects and breaks types.
+                copy.deepcopy(messages),
+                top_logprobs,
+                skip_response_format,
+            )
+
+            # Make the completion call
+            model_response, response_choice = await self.acompletion_checking_response(
+                **completion_kwargs
+            )
+
+            # count the usage
+            usage += self.usage_from_response(model_response)
+
+            # Extract content and tool calls
+            if not hasattr(response_choice, "message"):
+                raise ValueError("Response choice has no message")
+            content = response_choice.message.content
+            tool_calls = response_choice.message.tool_calls
+            if not content and not tool_calls:
+                raise ValueError(
+                    "Model returned an assistant message, but no content or tool calls. This is not supported."
+                )
+
+            # Add message to messages, so it can be used in the next turn
+            messages.append(response_choice.message)
+
+            # Process tool calls if any
+            if tool_calls and len(tool_calls) > 0:
+                (
+                    assistant_message_from_toolcall,
+                    tool_call_messages,
+                ) = await self.process_tool_calls(tool_calls)
+
+                # Add tool call results to messages
+                messages.extend(tool_call_messages)
+
+                # If task_response tool was called, we're done
+                if assistant_message_from_toolcall is not None:
+                    return ModelTurnResult(
+                        assistant_message=assistant_message_from_toolcall,
+                        all_messages=messages,
+                        model_response=model_response,
+                        model_choice=response_choice,
+                        usage=usage,
+                    )
+
+                # If there were tool calls, increment counter and continue
+                if tool_call_messages:
+                    tool_calls_count += 1
+                    continue
+
+            # If no tool calls, return the content as final output
+            if content:
+                return ModelTurnResult(
+                    assistant_message=content,
+                    all_messages=messages,
+                    model_response=model_response,
+                    model_choice=response_choice,
+                    usage=usage,
+                )
+
+            # If we get here with no content and no tool calls, raise
+            raise RuntimeError(
+                "Model returned neither content nor tool calls. It must return at least one of these."
+            )
+
+        raise RuntimeError(
+            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
         )
 
     async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
+        usage = Usage()
+
         provider = self.model_provider()
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")
 
         chat_formatter = self.build_chat_formatter(input)
+        messages: list[ChatCompletionMessageIncludingLiteLLM] = []
 
-        prior_output = None
-        prior_message = None
-        response = None
+        prior_output: str | None = None
+        final_choice: Choices | None = None
         turns = 0
+
         while True:
             turns += 1
-            if turns > 10:
+            if turns > MAX_CALLS_PER_TURN:
                 raise RuntimeError(
-                    "Too many turns. Stopping iteration to avoid using too many tokens."
+                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                 )
 
             turn = chat_formatter.next_turn(prior_output)
             if turn is None:
+                # No next turn, we're done
                 break
 
+            # Add messages from the turn to chat history
+            for message in turn.messages:
+                if message.content is None:
+                    raise ValueError("Empty message content isn't allowed")
+                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
+                messages.append({"role": message.role, "content": message.content})  # type: ignore
+
             skip_response_format = not turn.final_call
-            all_messages = chat_formatter.message_dicts()
-            completion_kwargs = await self.build_completion_kwargs(
+            turn_result = await self._run_model_turn(
                 provider,
-                all_messages,
+                messages,
                 self.base_adapter_config.top_logprobs if turn.final_call else None,
                 skip_response_format,
             )
-            response = await litellm.acompletion(**completion_kwargs)
-            if (
-                not isinstance(response, ModelResponse)
-                or not response.choices
-                or len(response.choices) == 0
-                or not isinstance(response.choices[0], Choices)
-            ):
-                raise RuntimeError(
-                    f"Expected ModelResponse with Choices, got {type(response)}."
-                )
-            prior_message = response.choices[0].message
-            prior_output = prior_message.content
-
-            # Fallback: Use args of first tool call to task_response if it exists
-            if (
-                not prior_output
-                and hasattr(prior_message, "tool_calls")
-                and prior_message.tool_calls
-            ):
-                tool_call = next(
-                    (
-                        tool_call
-                        for tool_call in prior_message.tool_calls
-                        if tool_call.function.name == "task_response"
-                    ),
-                    None,
-                )
-                if tool_call:
-                    prior_output = tool_call.function.arguments
+
+            usage += turn_result.usage
+
+            prior_output = turn_result.assistant_message
+            messages = turn_result.all_messages
+            final_choice = turn_result.model_choice
 
         if not prior_output:
-            raise RuntimeError("No output returned from model")
+            raise RuntimeError("No assistant message/output returned from model")
 
-        if response is None or prior_message is None:
-            raise RuntimeError("No response returned from model")
+        logprobs = self._extract_and_validate_logprobs(final_choice)
 
+        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
         intermediate_outputs = chat_formatter.intermediate_outputs()
+        self._extract_reasoning_to_intermediate_outputs(
+            final_choice, intermediate_outputs
+        )
+
+        if not isinstance(prior_output, str):
+            raise RuntimeError(f"assistant message is not a string: {prior_output}")
 
-        logprobs = (
-            response.choices[0].logprobs
-            if hasattr(response.choices[0], "logprobs")
-            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
-            else None
+        trace = self.all_messages_to_trace(messages)
+        output = RunOutput(
+            output=prior_output,
+            intermediate_outputs=intermediate_outputs,
+            output_logprobs=logprobs,
+            trace=trace,
         )
 
-        # Check logprobs worked, if requested
-        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
-            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
+        return output, usage
 
-        # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
+    def _extract_and_validate_logprobs(
+        self, final_choice: Choices | None
+    ) -> ChoiceLogprobs | None:
+        """
+        Extract logprobs from the final choice and validate they exist if required.
+        """
+        logprobs = None
         if (
-            prior_message is not None
-            and hasattr(prior_message, "reasoning_content")
-            and prior_message.reasoning_content
-            and len(prior_message.reasoning_content.strip()) > 0
+            final_choice is not None
+            and hasattr(final_choice, "logprobs")
+            and isinstance(final_choice.logprobs, ChoiceLogprobs)
         ):
-            intermediate_outputs["reasoning"] = prior_message.reasoning_content.strip()
+            logprobs = final_choice.logprobs
 
-        # the string content of the response
-        response_content = prior_output
+        # Check logprobs worked, if required
+        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
+            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
 
-        if not isinstance(response_content, str):
-            raise RuntimeError(f"response is not a string: {response_content}")
+        return logprobs
 
-        return RunOutput(
-            output=response_content,
-            intermediate_outputs=intermediate_outputs,
-            output_logprobs=logprobs,
-        ), self.usage_from_response(response)
+    def _extract_reasoning_to_intermediate_outputs(
+        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
+    ) -> None:
+        """Extract reasoning content from model choice and add to intermediate outputs if present."""
+        if (
+            final_choice is not None
+            and hasattr(final_choice, "message")
+            and hasattr(final_choice.message, "reasoning_content")
+        ):
+            reasoning_content = final_choice.message.reasoning_content
+            if reasoning_content is not None:
+                stripped_reasoning_content = reasoning_content.strip()
+                if len(stripped_reasoning_content) > 0:
+                    intermediate_outputs["reasoning"] = stripped_reasoning_content
+
+    async def acompletion_checking_response(
+        self, **kwargs
+    ) -> Tuple[ModelResponse, Choices]:
+        response = await litellm.acompletion(**kwargs)
+        if (
+            not isinstance(response, ModelResponse)
+            or not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                f"Expected ModelResponse with Choices, got {type(response)}."
+            )
+        return response, response.choices[0]
 
     def adapter_name(self) -> str:
         return "kiln_openai_compatible_adapter"
@@ -181,6 +325,9 @@ class LiteLlmAdapter(BaseAdapter):
                 if provider_name == ModelProviderName.ollama:
                     # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                     return self.json_schema_response_format()
+                elif provider_name == ModelProviderName.docker_model_runner:
+                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
+                    return self.json_schema_response_format()
                 else:
                     # Default to function calling -- it's older than the other modes. Higher compatibility.
                     # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
@@ -193,7 +340,7 @@
                 raise_exhaustive_enum_error(structured_output_mode)
 
     def json_schema_response_format(self) -> dict[str, Any]:
-        output_schema = self.task().output_schema()
+        output_schema = self.task.output_schema()
         return {
             "response_format": {
                 "type": "json_schema",
@@ -206,7 +353,7 @@
 
     def tool_call_params(self, strict: bool) -> dict[str, Any]:
         # Add additional_properties: false to the schema (OpenAI requires this for some models)
-        output_schema = self.task().output_schema()
+        output_schema = self.task.output_schema()
         if not isinstance(output_schema, dict):
             raise ValueError(
                 "Invalid output schema for this task. Can not use tool calls."
@@ -297,77 +444,22 @@
     def litellm_model_id(self) -> str:
         # The model ID is an interesting combination of format and url endpoint.
         # It specifics the provider URL/host, but this is overridden if you manually set an api url
-
         if self._litellm_model_id:
             return self._litellm_model_id
 
-        provider = self.model_provider()
-        if not provider.model_id:
-            raise ValueError("Model ID is required for OpenAI compatible models")
-
-        litellm_provider_name: str | None = None
-        is_custom = False
-        match provider.name:
-            case ModelProviderName.openrouter:
-                litellm_provider_name = "openrouter"
-            case ModelProviderName.openai:
-                litellm_provider_name = "openai"
-            case ModelProviderName.groq:
-                litellm_provider_name = "groq"
-            case ModelProviderName.anthropic:
-                litellm_provider_name = "anthropic"
-            case ModelProviderName.ollama:
-                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
-                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
-                is_custom = True
-            case ModelProviderName.gemini_api:
-                litellm_provider_name = "gemini"
-            case ModelProviderName.fireworks_ai:
-                litellm_provider_name = "fireworks_ai"
-            case ModelProviderName.amazon_bedrock:
-                litellm_provider_name = "bedrock"
-            case ModelProviderName.azure_openai:
-                litellm_provider_name = "azure"
-            case ModelProviderName.huggingface:
-                litellm_provider_name = "huggingface"
-            case ModelProviderName.vertex:
-                litellm_provider_name = "vertex_ai"
-            case ModelProviderName.together_ai:
-                litellm_provider_name = "together_ai"
-            case ModelProviderName.cerebras:
-                litellm_provider_name = "cerebras"
-            case ModelProviderName.siliconflow_cn:
-                is_custom = True
-            case ModelProviderName.openai_compatible:
-                is_custom = True
-            case ModelProviderName.kiln_custom_registry:
-                is_custom = True
-            case ModelProviderName.kiln_fine_tune:
-                is_custom = True
-            case _:
-                raise_exhaustive_enum_error(provider.name)
-
-        if is_custom:
-            if self._api_base is None:
-                raise ValueError(
-                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
-                )
-            # Use openai as it's only used for format, not url
-            litellm_provider_name = "openai"
-
-        # Sholdn't be possible but keep type checker happy
-        if litellm_provider_name is None:
+        litellm_provider_info = get_litellm_provider_info(self.model_provider())
+        if litellm_provider_info.is_custom and self._api_base is None:
             raise ValueError(
-                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
+                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
             )
 
-        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
+        self._litellm_model_id = litellm_provider_info.litellm_model_id
         return self._litellm_model_id
 
     async def build_completion_kwargs(
         self,
         provider: KilnModelProvider,
-        messages: list[dict[str, Any]],
+        messages: list[ChatCompletionMessageIncludingLiteLLM],
        top_logprobs: int | None,
        skip_response_format: bool = False,
    ) -> dict[str, Any]:
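The per-provider mapping that used to live here (openrouter, openai, groq, bedrock, and so on) now comes from get_litellm_provider_info(), but the identifier handed to LiteLLM still follows the <provider>/<model_id> convention, with custom or OpenAI-compatible endpoints using an openai/ prefix plus an explicit base URL. A rough illustration of the two shapes; the model names and local URL are placeholders, not taken from this release:

    import litellm

    async def illustrate_model_ids() -> None:
        # Hosted provider: the prefix selects LiteLLM's provider route.
        await litellm.acompletion(
            model="openrouter/openai/gpt-4o-mini",
            messages=[{"role": "user", "content": "hi"}],
        )
        # Custom / OpenAI-compatible endpoint (Ollama, fine-tunes, custom registry):
        # "openai/" only selects the request format; api_base selects the host.
        await litellm.acompletion(
            model="openai/my-local-model",
            api_base="http://localhost:11434/v1",
            messages=[{"role": "user", "content": "hi"}],
        )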
@@ -390,9 +482,23 @@
             **self._additional_body_options,
         }
 
+        tool_calls = await self.litellm_tools()
+        has_tools = len(tool_calls) > 0
+        if has_tools:
+            completion_kwargs["tools"] = tool_calls
+            completion_kwargs["tool_choice"] = "auto"
+
         if not skip_response_format:
             # Response format: json_schema, json_instructions, json_mode, function_calling, etc
             response_format_options = await self.response_format_options()
+
+            # Check for a conflict between tools and response format using tools
+            # We could reconsider this. The model could be allowed to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schema, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
+            if has_tools and "tools" in response_format_options:
+                raise ValueError(
+                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
+                )
+
             completion_kwargs.update(response_format_options)
 
         if top_logprobs is not None:
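litellm_tools() (added later in this diff) fills completion_kwargs["tools"] with definitions in the standard OpenAI function-tool format, which LiteLLM forwards as-is. A sketch of what one entry might look like for a hypothetical tool; the name and parameters are invented, not taken from the release:

    # Hypothetical tool definition in the OpenAI-compatible format LiteLLM accepts.
    add_numbers_tool = {
        "type": "function",
        "function": {
            "name": "add_numbers",
            "description": "Add two numbers and return the sum.",
            "parameters": {
                "type": "object",
                "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
                "required": ["a", "b"],
            },
        },
    }
    completion_kwargs = {"tools": [add_numbers_tool], "tool_choice": "auto"}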
@@ -401,7 +507,7 @@
 
         return completion_kwargs
 
-    def usage_from_response(self, response: ModelResponse) -> Usage | None:
+    def usage_from_response(self, response: ModelResponse) -> Usage:
         litellm_usage = response.get("usage", None)
 
         # LiteLLM isn't consistent in how it returns the cost.
@@ -409,11 +515,11 @@
         if cost is None and litellm_usage:
             cost = litellm_usage.get("cost", None)
 
-        if not litellm_usage and not cost:
-            return None
-
         usage = Usage()
 
+        if not litellm_usage and not cost:
+            return usage
+
         if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
             usage.input_tokens = litellm_usage.get("prompt_tokens", None)
             usage.output_tokens = litellm_usage.get("completion_tokens", None)
@@ -432,3 +538,139 @@
         )
 
         return usage
+
+    async def cached_available_tools(self) -> list[KilnToolInterface]:
+        if self._cached_available_tools is None:
+            self._cached_available_tools = await self.available_tools()
+        return self._cached_available_tools
+
+    async def litellm_tools(self) -> list[Dict]:
+        available_tools = await self.cached_available_tools()
+
+        # LiteLLM takes the standard OpenAI-compatible tool call format
+        return [await tool.toolcall_definition() for tool in available_tools]
+
+    async def process_tool_calls(
+        self, tool_calls: list[ChatCompletionMessageToolCall] | None
+    ) -> tuple[str | None, list[ChatCompletionToolMessageParam]]:
+        if tool_calls is None:
+            return None, []
+
+        assistant_output_from_toolcall: str | None = None
+        tool_call_response_messages: list[ChatCompletionToolMessageParam] = []
+
+        for tool_call in tool_calls:
+            # Kiln "task_response" tool is used for returning structured output via tool calls.
+            # Load the output from the tool call.
+            if tool_call.function.name == "task_response":
+                assistant_output_from_toolcall = tool_call.function.arguments
+                continue
+
+            # Process normal tool calls (not the "task_response" tool)
+            tool_name = tool_call.function.name
+            tool = None
+            for tool_option in await self.cached_available_tools():
+                if await tool_option.name() == tool_name:
+                    tool = tool_option
+                    break
+            if not tool:
+                raise RuntimeError(
+                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
+                )
+
+            # Parse the arguments and validate them against the tool's schema
+            try:
+                parsed_args = json.loads(tool_call.function.arguments)
+            except json.JSONDecodeError:
+                raise RuntimeError(
+                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
+                )
+            try:
+                tool_call_definition = await tool.toolcall_definition()
+                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
+                validate_schema_with_value_error(parsed_args, json_schema)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
+                ) from e
+
+            result = await tool.run(**parsed_args)
+
+            tool_call_response_messages.append(
+                ChatCompletionToolMessageParam(
+                    role="tool",
+                    tool_call_id=tool_call.id,
+                    content=result,
+                )
+            )
+
+        if (
+            assistant_output_from_toolcall is not None
+            and len(tool_call_response_messages) > 0
+        ):
+            raise RuntimeError(
+                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tool calls to send back to the model. If the model makes this mistake often, try a different structured output mode like JSON schema, where this is impossible."
+            )
+
+        return assistant_output_from_toolcall, tool_call_response_messages
+
+    def litellm_message_to_trace_message(
+        self, raw_message: LiteLLMMessage
+    ) -> ChatCompletionAssistantMessageParamWrapper:
+        """
+        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
+        """
+        message: ChatCompletionAssistantMessageParamWrapper = {
+            "role": "assistant",
+        }
+        if raw_message.role != "assistant":
+            raise ValueError(
+                "Model returned a message with a role other than assistant. This is not supported."
+            )
+
+        if hasattr(raw_message, "content"):
+            message["content"] = raw_message.content
+        if hasattr(raw_message, "reasoning_content"):
+            message["reasoning_content"] = raw_message.reasoning_content
+        if hasattr(raw_message, "tool_calls"):
+            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
+            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
+            for litellm_tool_call in raw_message.tool_calls or []:
+                # Optional in the SDK for streaming responses, but should never be None at this point.
+                if litellm_tool_call.function.name is None:
+                    raise ValueError(
+                        "The model requested a tool call, without providing a function name (required)."
+                    )
+                open_ai_tool_calls.append(
+                    ChatCompletionMessageToolCallParam(
+                        id=litellm_tool_call.id,
+                        type="function",
+                        function={
+                            "name": litellm_tool_call.function.name,
+                            "arguments": litellm_tool_call.function.arguments,
+                        },
+                    )
+                )
+            if len(open_ai_tool_calls) > 0:
+                message["tool_calls"] = open_ai_tool_calls
+
+        if not message.get("content") and not message.get("tool_calls"):
+            raise ValueError(
+                "Model returned an assistant message, but no content or tool calls. This is not supported."
+            )
+
+        return message
+
+    def all_messages_to_trace(
+        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
+    ) -> list[ChatCompletionMessageParam]:
+        """
+        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
+        """
+        trace: list[ChatCompletionMessageParam] = []
+        for message in messages:
+            if isinstance(message, LiteLLMMessage):
+                trace.append(self.litellm_message_to_trace_message(message))
+            else:
+                trace.append(message)
+        return trace
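For reference, the messages process_tool_calls() appends (role "tool", keyed by tool_call_id) sit next to the assistant message that requested them, in the standard OpenAI chat shape the trace preserves. A hedged example of one round trip; the id, tool name, and values are illustrative:

    # Illustrative trace excerpt; the id, arguments, and result are made up.
    assistant_with_tool_call = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {"name": "add_numbers", "arguments": '{"a": 2, "b": 3}'},
            }
        ],
    }
    tool_result = {"role": "tool", "tool_call_id": "call_123", "content": "5"}
    trace_excerpt = [assistant_with_tool_call, tool_result]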