kiln-ai 0.19.0__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- kiln_ai/adapters/__init__.py +2 -2
- kiln_ai/adapters/adapter_registry.py +19 -1
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +3 -1
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -1
- kiln_ai/adapters/eval/test_g_eval.py +3 -4
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/ml_model_list.py +380 -34
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +383 -79
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +406 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +110 -4
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +15 -1
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +80 -1
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_model_list.py +39 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_provider_tools.py +55 -0
- kiln_ai/adapters/test_remote_config.py +98 -0
- kiln_ai/datamodel/__init__.py +23 -21
- kiln_ai/datamodel/datamodel_enums.py +1 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +8 -1
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_basemodel.py +3 -3
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +239 -0
- kiln_ai/datamodel/tool_id.py +83 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +243 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_tool_registry.py +473 -0
- kiln_ai/tools/tool_registry.py +64 -0
- kiln_ai/utils/config.py +22 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_open_ai_types.py +131 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/METADATA +6 -5
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/RECORD +70 -47
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.20.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/litellm_adapter.py

@@ -1,9 +1,26 @@
+import copy
+import json
 import logging
-from
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, TypeAlias, Union

 import litellm
-from litellm.types.utils import
+from litellm.types.utils import (
+    ChatCompletionMessageToolCall,
+    ChoiceLogprobs,
+    Choices,
+    ModelResponse,
+)
+from litellm.types.utils import (
+    Message as LiteLLMMessage,
+)
 from litellm.types.utils import Usage as LiteLlmUsage
+from openai.types.chat import (
+    ChatCompletionToolMessageParam,
+)
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    ChatCompletionMessageToolCallParam,
+)

 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.ml_model_list import (
@@ -18,11 +35,32 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
     Usage,
 )
 from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
-from kiln_ai.datamodel.
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
+from kiln_ai.tools.base_tool import KilnToolInterface
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.open_ai_types import (
+    ChatCompletionAssistantMessageParamWrapper,
+    ChatCompletionMessageParam,
+)
+
+MAX_CALLS_PER_TURN = 10
+MAX_TOOL_CALLS_PER_TURN = 30

 logger = logging.getLogger(__name__)

+ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
+    ChatCompletionMessageParam, LiteLLMMessage
+]
+
+
+@dataclass
+class ModelTurnResult:
+    assistant_message: str
+    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
+    model_response: ModelResponse | None
+    model_choice: Choices | None
+    usage: Usage
+

 class LiteLlmAdapter(BaseAdapter):
     def __init__(
@@ -36,117 +74,226 @@ class LiteLlmAdapter(BaseAdapter):
         self._api_base = config.base_url
         self._headers = config.default_headers
         self._litellm_model_id: str | None = None
+        self._cached_available_tools: list[KilnToolInterface] | None = None

-
-        run_config = run_config_from_run_config_properties(
+        super().__init__(
             task=kiln_task,
-
+            run_config=config.run_config_properties,
+            config=base_adapter_config,
         )

-
-
-
+    async def _run_model_turn(
+        self,
+        provider: KilnModelProvider,
+        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
+        top_logprobs: int | None,
+        skip_response_format: bool,
+    ) -> ModelTurnResult:
+        """
+        Call the model for a single top level turn: from user message to agent message.
+
+        It may make handle iterations of tool calls between the user/agent message if needed.
+        """
+
+        usage = Usage()
+        messages = list(prior_messages)
+        tool_calls_count = 0
+
+        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
+            # Build completion kwargs for tool calls
+            completion_kwargs = await self.build_completion_kwargs(
+                provider,
+                # Pass a copy, as acompletion mutates objects and breaks types.
+                copy.deepcopy(messages),
+                top_logprobs,
+                skip_response_format,
+            )
+
+            # Make the completion call
+            model_response, response_choice = await self.acompletion_checking_response(
+                **completion_kwargs
+            )
+
+            # count the usage
+            usage += self.usage_from_response(model_response)
+
+            # Extract content and tool calls
+            if not hasattr(response_choice, "message"):
+                raise ValueError("Response choice has no message")
+            content = response_choice.message.content
+            tool_calls = response_choice.message.tool_calls
+            if not content and not tool_calls:
+                raise ValueError(
+                    "Model returned an assistant message, but no content or tool calls. This is not supported."
+                )
+
+            # Add message to messages, so it can be used in the next turn
+            messages.append(response_choice.message)
+
+            # Process tool calls if any
+            if tool_calls and len(tool_calls) > 0:
+                (
+                    assistant_message_from_toolcall,
+                    tool_call_messages,
+                ) = await self.process_tool_calls(tool_calls)
+
+                # Add tool call results to messages
+                messages.extend(tool_call_messages)
+
+                # If task_response tool was called, we're done
+                if assistant_message_from_toolcall is not None:
+                    return ModelTurnResult(
+                        assistant_message=assistant_message_from_toolcall,
+                        all_messages=messages,
+                        model_response=model_response,
+                        model_choice=response_choice,
+                        usage=usage,
+                    )
+
+                # If there were tool calls, increment counter and continue
+                if tool_call_messages:
+                    tool_calls_count += 1
+                    continue
+
+            # If no tool calls, return the content as final output
+            if content:
+                return ModelTurnResult(
+                    assistant_message=content,
+                    all_messages=messages,
+                    model_response=model_response,
+                    model_choice=response_choice,
+                    usage=usage,
+                )
+
+            # If we get here with no content and no tool calls, break
+            raise RuntimeError(
+                "Model returned neither content nor tool calls. It must return at least one of these."
+            )
+
+        raise RuntimeError(
+            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
         )

     async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
+        usage = Usage()
+
         provider = self.model_provider()
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")

         chat_formatter = self.build_chat_formatter(input)
+        messages: list[ChatCompletionMessageIncludingLiteLLM] = []

-        prior_output = None
-
-        response = None
+        prior_output: str | None = None
+        final_choice: Choices | None = None
         turns = 0
+
         while True:
             turns += 1
-            if turns >
+            if turns > MAX_CALLS_PER_TURN:
                 raise RuntimeError(
-                    "Too many turns. Stopping iteration to avoid using too many tokens."
+                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                 )

             turn = chat_formatter.next_turn(prior_output)
             if turn is None:
+                # No next turn, we're done
                 break

+            # Add messages from the turn to chat history
+            for message in turn.messages:
+                if message.content is None:
+                    raise ValueError("Empty message content isn't allowed")
+                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
+                messages.append({"role": message.role, "content": message.content})  # type: ignore
+
             skip_response_format = not turn.final_call
-
-            completion_kwargs = await self.build_completion_kwargs(
+            turn_result = await self._run_model_turn(
                 provider,
-
+                messages,
                 self.base_adapter_config.top_logprobs if turn.final_call else None,
                 skip_response_format,
             )
-
-
-
-
-
-
-            ):
-                raise RuntimeError(
-                    f"Expected ModelResponse with Choices, got {type(response)}."
-                )
-            prior_message = response.choices[0].message
-            prior_output = prior_message.content
-
-            # Fallback: Use args of first tool call to task_response if it exists
-            if (
-                not prior_output
-                and hasattr(prior_message, "tool_calls")
-                and prior_message.tool_calls
-            ):
-                tool_call = next(
-                    (
-                        tool_call
-                        for tool_call in prior_message.tool_calls
-                        if tool_call.function.name == "task_response"
-                    ),
-                    None,
-                )
-                if tool_call:
-                    prior_output = tool_call.function.arguments
+
+            usage += turn_result.usage
+
+            prior_output = turn_result.assistant_message
+            messages = turn_result.all_messages
+            final_choice = turn_result.model_choice

         if not prior_output:
-            raise RuntimeError("No output returned from model")
+            raise RuntimeError("No assistant message/output returned from model")

-
-            raise RuntimeError("No response returned from model")
+        logprobs = self._extract_and_validate_logprobs(final_choice)

+        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
         intermediate_outputs = chat_formatter.intermediate_outputs()
+        self._extract_reasoning_to_intermediate_outputs(
+            final_choice, intermediate_outputs
+        )

-
-
-
-
-
+        if not isinstance(prior_output, str):
+            raise RuntimeError(f"assistant message is not a string: {prior_output}")
+
+        trace = self.all_messages_to_trace(messages)
+        output = RunOutput(
+            output=prior_output,
+            intermediate_outputs=intermediate_outputs,
+            output_logprobs=logprobs,
+            trace=trace,
         )

-
-        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
-            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
+        return output, usage

-
+    def _extract_and_validate_logprobs(
+        self, final_choice: Choices | None
+    ) -> ChoiceLogprobs | None:
+        """
+        Extract logprobs from the final choice and validate they exist if required.
+        """
+        logprobs = None
         if (
-
-            and hasattr(
-            and
-            and len(prior_message.reasoning_content.strip()) > 0
+            final_choice is not None
+            and hasattr(final_choice, "logprobs")
+            and isinstance(final_choice.logprobs, ChoiceLogprobs)
         ):
-
+            logprobs = final_choice.logprobs

-        #
-
+        # Check logprobs worked, if required
+        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
+            raise RuntimeError("Logprobs were required, but no logprobs were returned.")

-
-        raise RuntimeError(f"response is not a string: {response_content}")
+        return logprobs

-
-
-
-
-
+    def _extract_reasoning_to_intermediate_outputs(
+        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
+    ) -> None:
+        """Extract reasoning content from model choice and add to intermediate outputs if present."""
+        if (
+            final_choice is not None
+            and hasattr(final_choice, "message")
+            and hasattr(final_choice.message, "reasoning_content")
+        ):
+            reasoning_content = final_choice.message.reasoning_content
+            if reasoning_content is not None:
+                stripped_reasoning_content = reasoning_content.strip()
+                if len(stripped_reasoning_content) > 0:
+                    intermediate_outputs["reasoning"] = stripped_reasoning_content
+
+    async def acompletion_checking_response(
+        self, **kwargs
+    ) -> Tuple[ModelResponse, Choices]:
+        response = await litellm.acompletion(**kwargs)
+        if (
+            not isinstance(response, ModelResponse)
+            or not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                f"Expected ModelResponse with Choices, got {type(response)}."
+            )
+        return response, response.choices[0]

     def adapter_name(self) -> str:
         return "kiln_openai_compatible_adapter"
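Note on the control flow above: the rewritten _run delegates each user-to-assistant exchange to _run_model_turn, with two separate guards. MAX_CALLS_PER_TURN caps formatter turns in _run, while MAX_TOOL_CALLS_PER_TURN caps tool-call round trips inside a single turn. A minimal standalone sketch of that nesting follows; it is illustrative only, and next_turn, call_model, and run_tools are hypothetical stand-ins rather than adapter APIs.

MAX_CALLS_PER_TURN = 10        # formatter turns per run (outer loop)
MAX_TOOL_CALLS_PER_TURN = 30   # tool round trips within one turn (inner loop)

def run(next_turn, call_model, run_tools):
    prior_output = None
    turns = 0
    while True:
        turns += 1
        if turns > MAX_CALLS_PER_TURN:
            raise RuntimeError("Too many turns")
        turn = next_turn(prior_output)
        if turn is None:
            return prior_output  # formatter has no more turns; we're done
        tool_call_rounds = 0
        while tool_call_rounds < MAX_TOOL_CALLS_PER_TURN:
            content, tool_calls = call_model()
            if tool_calls:
                run_tools(tool_calls)  # tool results feed the next model call
                tool_call_rounds += 1
                continue
            prior_output = content
            break
        else:
            raise RuntimeError("Too many tool calls")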
@@ -181,6 +328,9 @@ class LiteLlmAdapter(BaseAdapter):
         if provider_name == ModelProviderName.ollama:
             # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
             return self.json_schema_response_format()
+        elif provider_name == ModelProviderName.docker_model_runner:
+            # Docker Model Runner uses OpenAI-compatible API with JSON schema support
+            return self.json_schema_response_format()
         else:
             # Default to function calling -- it's older than the other modes. Higher compatibility.
             # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
@@ -193,7 +343,7 @@ class LiteLlmAdapter(BaseAdapter):
                 raise_exhaustive_enum_error(structured_output_mode)

     def json_schema_response_format(self) -> dict[str, Any]:
-        output_schema = self.task
+        output_schema = self.task.output_schema()
         return {
             "response_format": {
                 "type": "json_schema",
@@ -206,7 +356,7 @@ class LiteLlmAdapter(BaseAdapter):

     def tool_call_params(self, strict: bool) -> dict[str, Any]:
         # Add additional_properties: false to the schema (OpenAI requires this for some models)
-        output_schema = self.task
+        output_schema = self.task.output_schema()
         if not isinstance(output_schema, dict):
             raise ValueError(
                 "Invalid output schema for this task. Can not use tool calls."
@@ -320,6 +470,10 @@ class LiteLlmAdapter(BaseAdapter):
                 # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
                 # This is because we're setting detailed features like response_format=json_schema and want lower level control.
                 is_custom = True
+            case ModelProviderName.docker_model_runner:
+                # Docker Model Runner uses OpenAI-compatible API, similar to Ollama
+                # We want direct control over the requests for features like response_format=json_schema
+                is_custom = True
             case ModelProviderName.gemini_api:
                 litellm_provider_name = "gemini"
             case ModelProviderName.fireworks_ai:
@@ -367,7 +521,7 @@ class LiteLlmAdapter(BaseAdapter):
     async def build_completion_kwargs(
         self,
         provider: KilnModelProvider,
-        messages: list[
+        messages: list[ChatCompletionMessageIncludingLiteLLM],
         top_logprobs: int | None,
         skip_response_format: bool = False,
     ) -> dict[str, Any]:
@@ -390,9 +544,23 @@ class LiteLlmAdapter(BaseAdapter):
             **self._additional_body_options,
         }

+        tool_calls = await self.litellm_tools()
+        has_tools = len(tool_calls) > 0
+        if has_tools:
+            completion_kwargs["tools"] = tool_calls
+            completion_kwargs["tool_choice"] = "auto"
+
         if not skip_response_format:
             # Response format: json_schema, json_instructions, json_mode, function_calling, etc
             response_format_options = await self.response_format_options()
+
+            # Check for a conflict between tools and response format using tools
+            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
+            if has_tools and "tools" in response_format_options:
+                raise ValueError(
+                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
+                )
+
             completion_kwargs.update(response_format_options)

         if top_logprobs is not None:
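For reference, with tools attached the kwargs handed to litellm.acompletion end up shaped roughly like the dict below. This is an illustrative sketch only: the model id, messages, credentials, and schema are placeholders, and the real values come from the run config, chat formatter, and task rather than from this example.

completion_kwargs = {
    "model": "openai/gpt-4o",  # placeholder litellm model id
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Add 2 and 3."},
    ],
    # Added by build_completion_kwargs when cached_available_tools() is non-empty:
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "add_numbers",
                "description": "Add two numbers.",
                "parameters": {
                    "type": "object",
                    "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
                    "required": ["a", "b"],
                },
            },
        }
    ],
    "tool_choice": "auto",
    # Added when skip_response_format is False, e.g. json_schema structured output
    # (shape here is a guess at the payload; the adapter builds it from the task schema):
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": {"type": "object"}},
    },
}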
@@ -401,7 +569,7 @@ class LiteLlmAdapter(BaseAdapter):

         return completion_kwargs

-    def usage_from_response(self, response: ModelResponse) -> Usage
+    def usage_from_response(self, response: ModelResponse) -> Usage:
         litellm_usage = response.get("usage", None)

         # LiteLLM isn't consistent in how it returns the cost.
@@ -409,11 +577,11 @@ class LiteLlmAdapter(BaseAdapter):
         if cost is None and litellm_usage:
             cost = litellm_usage.get("cost", None)

-        if not litellm_usage and not cost:
-            return None
-
         usage = Usage()

+        if not litellm_usage and not cost:
+            return usage
+
         if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
             usage.input_tokens = litellm_usage.get("prompt_tokens", None)
             usage.output_tokens = litellm_usage.get("completion_tokens", None)
@@ -432,3 +600,139 @@ class LiteLlmAdapter(BaseAdapter):
         )

         return usage
+
+    async def cached_available_tools(self) -> list[KilnToolInterface]:
+        if self._cached_available_tools is None:
+            self._cached_available_tools = await self.available_tools()
+        return self._cached_available_tools
+
+    async def litellm_tools(self) -> list[Dict]:
+        available_tools = await self.cached_available_tools()
+
+        # LiteLLM takes the standard OpenAI-compatible tool call format
+        return [await tool.toolcall_definition() for tool in available_tools]
+
+    async def process_tool_calls(
+        self, tool_calls: list[ChatCompletionMessageToolCall] | None
+    ) -> tuple[str | None, list[ChatCompletionToolMessageParam]]:
+        if tool_calls is None:
+            return None, []
+
+        assistant_output_from_toolcall: str | None = None
+        tool_call_response_messages: list[ChatCompletionToolMessageParam] = []
+
+        for tool_call in tool_calls:
+            # Kiln "task_response" tool is used for returning structured output via tool calls.
+            # Load the output from the tool call. Also
+            if tool_call.function.name == "task_response":
+                assistant_output_from_toolcall = tool_call.function.arguments
+                continue
+
+            # Process normal tool calls (not the "task_response" tool)
+            tool_name = tool_call.function.name
+            tool = None
+            for tool_option in await self.cached_available_tools():
+                if await tool_option.name() == tool_name:
+                    tool = tool_option
+                    break
+            if not tool:
+                raise RuntimeError(
+                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
+                )
+
+            # Parse the arguments and validate them against the tool's schema
+            try:
+                parsed_args = json.loads(tool_call.function.arguments)
+            except json.JSONDecodeError:
+                raise RuntimeError(
+                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
+                )
+            try:
+                tool_call_definition = await tool.toolcall_definition()
+                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
+                validate_schema_with_value_error(parsed_args, json_schema)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
+                ) from e
+
+            result = await tool.run(**parsed_args)
+
+            tool_call_response_messages.append(
+                ChatCompletionToolMessageParam(
+                    role="tool",
+                    tool_call_id=tool_call.id,
+                    content=result,
+                )
+            )
+
+        if (
+            assistant_output_from_toolcall is not None
+            and len(tool_call_response_messages) > 0
+        ):
+            raise RuntimeError(
+                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
+            )
+
+        return assistant_output_from_toolcall, tool_call_response_messages
+
+    def litellm_message_to_trace_message(
+        self, raw_message: LiteLLMMessage
+    ) -> ChatCompletionAssistantMessageParamWrapper:
+        """
+        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
+        """
+        message: ChatCompletionAssistantMessageParamWrapper = {
+            "role": "assistant",
+        }
+        if raw_message.role != "assistant":
+            raise ValueError(
+                "Model returned a message with a role other than assistant. This is not supported."
+            )
+
+        if hasattr(raw_message, "content"):
+            message["content"] = raw_message.content
+        if hasattr(raw_message, "reasoning_content"):
+            message["reasoning_content"] = raw_message.reasoning_content
+        if hasattr(raw_message, "tool_calls"):
+            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
+            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
+            for litellm_tool_call in raw_message.tool_calls or []:
+                # Optional in the SDK for streaming responses, but should never be None at this point.
+                if litellm_tool_call.function.name is None:
+                    raise ValueError(
+                        "The model requested a tool call, without providing a function name (required)."
+                    )
+                open_ai_tool_calls.append(
+                    ChatCompletionMessageToolCallParam(
+                        id=litellm_tool_call.id,
+                        type="function",
+                        function={
+                            "name": litellm_tool_call.function.name,
+                            "arguments": litellm_tool_call.function.arguments,
+                        },
+                    )
+                )
+            if len(open_ai_tool_calls) > 0:
+                message["tool_calls"] = open_ai_tool_calls
+
+        if not message.get("content") and not message.get("tool_calls"):
+            raise ValueError(
+                "Model returned an assistant message, but no content or tool calls. This is not supported."
+            )
+
+        return message
+
+    def all_messages_to_trace(
+        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
+    ) -> list[ChatCompletionMessageParam]:
+        """
+        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
+        """
+        trace: list[ChatCompletionMessageParam] = []
+        for message in messages:
+            if isinstance(message, LiteLLMMessage):
+                trace.append(self.litellm_message_to_trace_message(message))
+            else:
+                trace.append(message)
+        return trace
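For context on the new tool plumbing above: litellm_tools() passes each tool's toolcall_definition() straight through to LiteLLM, and process_tool_calls() resolves a call by name(), validates the JSON arguments against the definition's parameters schema, then awaits run(**parsed_args). A minimal tool satisfying just those three methods might look like the sketch below. The real KilnToolInterface lives in kiln_ai/tools/base_tool.py and is not shown in this diff, so treat the method names and signatures here as assumptions inferred from the adapter code rather than the library's contract.

# Sketch of a tool exposing the three async methods the adapter calls.
# Method names/signatures are inferred from the diff, not from base_tool.py.
class AddNumbersTool:
    async def name(self) -> str:
        return "add_numbers"

    async def toolcall_definition(self) -> dict:
        # Standard OpenAI-compatible tool definition, as expected by litellm_tools().
        return {
            "type": "function",
            "function": {
                "name": "add_numbers",
                "description": "Add two numbers and return the sum as a string.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "a": {"type": "number"},
                        "b": {"type": "number"},
                    },
                    "required": ["a", "b"],
                    "additionalProperties": False,
                },
            },
        }

    async def run(self, a: float, b: float) -> str:
        # process_tool_calls() puts the return value into a role="tool" message,
        # so a string result is the simplest choice here.
        return str(a + b)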