kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/litellm_adapter.py
@@ -1,9 +1,22 @@
+import copy
+import json
 import logging
-from
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, TypeAlias, Union
 
 import litellm
-from litellm.types.utils import
+from litellm.types.utils import (
+    ChatCompletionMessageToolCall,
+    ChoiceLogprobs,
+    Choices,
+    ModelResponse,
+)
+from litellm.types.utils import Message as LiteLLMMessage
 from litellm.types.utils import Usage as LiteLlmUsage
+from openai.types.chat import ChatCompletionToolMessageParam
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    ChatCompletionMessageToolCallParam,
+)
 
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.ml_model_list import (
@@ -18,11 +31,33 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
     Usage,
 )
 from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
-from kiln_ai.datamodel.
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
+from kiln_ai.tools.base_tool import KilnToolInterface
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.litellm import get_litellm_provider_info
+from kiln_ai.utils.open_ai_types import (
+    ChatCompletionAssistantMessageParamWrapper,
+    ChatCompletionMessageParam,
+)
+
+MAX_CALLS_PER_TURN = 10
+MAX_TOOL_CALLS_PER_TURN = 30
 
 logger = logging.getLogger(__name__)
 
+ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
+    ChatCompletionMessageParam, LiteLLMMessage
+]
+
+
+@dataclass
+class ModelTurnResult:
+    assistant_message: str
+    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
+    model_response: ModelResponse | None
+    model_choice: Choices | None
+    usage: Usage
+
 
 class LiteLlmAdapter(BaseAdapter):
     def __init__(
@@ -36,117 +71,226 @@ class LiteLlmAdapter(BaseAdapter):
         self._api_base = config.base_url
         self._headers = config.default_headers
         self._litellm_model_id: str | None = None
+        self._cached_available_tools: list[KilnToolInterface] | None = None
 
-
-        run_config = run_config_from_run_config_properties(
+        super().__init__(
             task=kiln_task,
-
+            run_config=config.run_config_properties,
+            config=base_adapter_config,
         )
 
-
-
-
+    async def _run_model_turn(
+        self,
+        provider: KilnModelProvider,
+        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
+        top_logprobs: int | None,
+        skip_response_format: bool,
+    ) -> ModelTurnResult:
+        """
+        Call the model for a single top level turn: from user message to agent message.
+
+        It may make handle iterations of tool calls between the user/agent message if needed.
+        """
+
+        usage = Usage()
+        messages = list(prior_messages)
+        tool_calls_count = 0
+
+        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
+            # Build completion kwargs for tool calls
+            completion_kwargs = await self.build_completion_kwargs(
+                provider,
+                # Pass a copy, as acompletion mutates objects and breaks types.
+                copy.deepcopy(messages),
+                top_logprobs,
+                skip_response_format,
+            )
+
+            # Make the completion call
+            model_response, response_choice = await self.acompletion_checking_response(
+                **completion_kwargs
+            )
+
+            # count the usage
+            usage += self.usage_from_response(model_response)
+
+            # Extract content and tool calls
+            if not hasattr(response_choice, "message"):
+                raise ValueError("Response choice has no message")
+            content = response_choice.message.content
+            tool_calls = response_choice.message.tool_calls
+            if not content and not tool_calls:
+                raise ValueError(
+                    "Model returned an assistant message, but no content or tool calls. This is not supported."
+                )
+
+            # Add message to messages, so it can be used in the next turn
+            messages.append(response_choice.message)
+
+            # Process tool calls if any
+            if tool_calls and len(tool_calls) > 0:
+                (
+                    assistant_message_from_toolcall,
+                    tool_call_messages,
+                ) = await self.process_tool_calls(tool_calls)
+
+                # Add tool call results to messages
+                messages.extend(tool_call_messages)
+
+                # If task_response tool was called, we're done
+                if assistant_message_from_toolcall is not None:
+                    return ModelTurnResult(
+                        assistant_message=assistant_message_from_toolcall,
+                        all_messages=messages,
+                        model_response=model_response,
+                        model_choice=response_choice,
+                        usage=usage,
+                    )
+
+                # If there were tool calls, increment counter and continue
+                if tool_call_messages:
+                    tool_calls_count += 1
+                    continue
+
+            # If no tool calls, return the content as final output
+            if content:
+                return ModelTurnResult(
+                    assistant_message=content,
+                    all_messages=messages,
+                    model_response=model_response,
+                    model_choice=response_choice,
+                    usage=usage,
+                )
+
+            # If we get here with no content and no tool calls, break
+            raise RuntimeError(
+                "Model returned neither content nor tool calls. It must return at least one of these."
+            )
+
+        raise RuntimeError(
+            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
         )
 
     async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
+        usage = Usage()
+
         provider = self.model_provider()
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")
 
         chat_formatter = self.build_chat_formatter(input)
+        messages: list[ChatCompletionMessageIncludingLiteLLM] = []
 
-        prior_output = None
-
-        response = None
+        prior_output: str | None = None
+        final_choice: Choices | None = None
         turns = 0
+
         while True:
             turns += 1
-            if turns >
+            if turns > MAX_CALLS_PER_TURN:
                 raise RuntimeError(
-                    "Too many turns. Stopping iteration to avoid using too many tokens."
+                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                 )
 
             turn = chat_formatter.next_turn(prior_output)
             if turn is None:
+                # No next turn, we're done
                 break
 
+            # Add messages from the turn to chat history
+            for message in turn.messages:
+                if message.content is None:
+                    raise ValueError("Empty message content isn't allowed")
+                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
+                messages.append({"role": message.role, "content": message.content}) # type: ignore
+
             skip_response_format = not turn.final_call
-
-            completion_kwargs = await self.build_completion_kwargs(
+            turn_result = await self._run_model_turn(
                 provider,
-
+                messages,
                 self.base_adapter_config.top_logprobs if turn.final_call else None,
                 skip_response_format,
             )
-
-
-
-
-
-
-            ):
-                raise RuntimeError(
-                    f"Expected ModelResponse with Choices, got {type(response)}."
-                )
-            prior_message = response.choices[0].message
-            prior_output = prior_message.content
-
-            # Fallback: Use args of first tool call to task_response if it exists
-            if (
-                not prior_output
-                and hasattr(prior_message, "tool_calls")
-                and prior_message.tool_calls
-            ):
-                tool_call = next(
-                    (
-                        tool_call
-                        for tool_call in prior_message.tool_calls
-                        if tool_call.function.name == "task_response"
-                    ),
-                    None,
-                )
-                if tool_call:
-                    prior_output = tool_call.function.arguments
+
+            usage += turn_result.usage
+
+            prior_output = turn_result.assistant_message
+            messages = turn_result.all_messages
+            final_choice = turn_result.model_choice
 
         if not prior_output:
-            raise RuntimeError("No output returned from model")
+            raise RuntimeError("No assistant message/output returned from model")
 
-
-            raise RuntimeError("No response returned from model")
+        logprobs = self._extract_and_validate_logprobs(final_choice)
 
+        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
         intermediate_outputs = chat_formatter.intermediate_outputs()
+        self._extract_reasoning_to_intermediate_outputs(
+            final_choice, intermediate_outputs
+        )
+
+        if not isinstance(prior_output, str):
+            raise RuntimeError(f"assistant message is not a string: {prior_output}")
 
-
-
-
-
-
+        trace = self.all_messages_to_trace(messages)
+        output = RunOutput(
+            output=prior_output,
+            intermediate_outputs=intermediate_outputs,
+            output_logprobs=logprobs,
+            trace=trace,
         )
 
-
-        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
-            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
+        return output, usage
 
-
+    def _extract_and_validate_logprobs(
+        self, final_choice: Choices | None
+    ) -> ChoiceLogprobs | None:
+        """
+        Extract logprobs from the final choice and validate they exist if required.
+        """
+        logprobs = None
         if (
-
-            and hasattr(
-            and
-            and len(prior_message.reasoning_content.strip()) > 0
+            final_choice is not None
+            and hasattr(final_choice, "logprobs")
+            and isinstance(final_choice.logprobs, ChoiceLogprobs)
         ):
-
+            logprobs = final_choice.logprobs
 
-        #
-
+        # Check logprobs worked, if required
+        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
+            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
 
-
-            raise RuntimeError(f"response is not a string: {response_content}")
+        return logprobs
 
-
-
-
-
-
+    def _extract_reasoning_to_intermediate_outputs(
+        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
+    ) -> None:
+        """Extract reasoning content from model choice and add to intermediate outputs if present."""
+        if (
+            final_choice is not None
+            and hasattr(final_choice, "message")
+            and hasattr(final_choice.message, "reasoning_content")
+        ):
+            reasoning_content = final_choice.message.reasoning_content
+            if reasoning_content is not None:
+                stripped_reasoning_content = reasoning_content.strip()
+                if len(stripped_reasoning_content) > 0:
+                    intermediate_outputs["reasoning"] = stripped_reasoning_content
+
+    async def acompletion_checking_response(
+        self, **kwargs
+    ) -> Tuple[ModelResponse, Choices]:
+        response = await litellm.acompletion(**kwargs)
+        if (
+            not isinstance(response, ModelResponse)
+            or not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                f"Expected ModelResponse with Choices, got {type(response)}."
+            )
+        return response, response.choices[0]
 
     def adapter_name(self) -> str:
         return "kiln_openai_compatible_adapter"
@@ -181,6 +325,9 @@ class LiteLlmAdapter(BaseAdapter):
                 if provider_name == ModelProviderName.ollama:
                     # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                     return self.json_schema_response_format()
+                elif provider_name == ModelProviderName.docker_model_runner:
+                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
+                    return self.json_schema_response_format()
                 else:
                     # Default to function calling -- it's older than the other modes. Higher compatibility.
                     # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
@@ -193,7 +340,7 @@ class LiteLlmAdapter(BaseAdapter):
                 raise_exhaustive_enum_error(structured_output_mode)
 
     def json_schema_response_format(self) -> dict[str, Any]:
-        output_schema = self.task
+        output_schema = self.task.output_schema()
         return {
             "response_format": {
                 "type": "json_schema",
@@ -206,7 +353,7 @@ class LiteLlmAdapter(BaseAdapter):
 
     def tool_call_params(self, strict: bool) -> dict[str, Any]:
         # Add additional_properties: false to the schema (OpenAI requires this for some models)
-        output_schema = self.task
+        output_schema = self.task.output_schema()
         if not isinstance(output_schema, dict):
             raise ValueError(
                 "Invalid output schema for this task. Can not use tool calls."
@@ -297,77 +444,22 @@ class LiteLlmAdapter(BaseAdapter):
     def litellm_model_id(self) -> str:
         # The model ID is an interesting combination of format and url endpoint.
         # It specifics the provider URL/host, but this is overridden if you manually set an api url
-
         if self._litellm_model_id:
             return self._litellm_model_id
 
-
-        if
-            raise ValueError("Model ID is required for OpenAI compatible models")
-
-        litellm_provider_name: str | None = None
-        is_custom = False
-        match provider.name:
-            case ModelProviderName.openrouter:
-                litellm_provider_name = "openrouter"
-            case ModelProviderName.openai:
-                litellm_provider_name = "openai"
-            case ModelProviderName.groq:
-                litellm_provider_name = "groq"
-            case ModelProviderName.anthropic:
-                litellm_provider_name = "anthropic"
-            case ModelProviderName.ollama:
-                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
-                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
-                is_custom = True
-            case ModelProviderName.gemini_api:
-                litellm_provider_name = "gemini"
-            case ModelProviderName.fireworks_ai:
-                litellm_provider_name = "fireworks_ai"
-            case ModelProviderName.amazon_bedrock:
-                litellm_provider_name = "bedrock"
-            case ModelProviderName.azure_openai:
-                litellm_provider_name = "azure"
-            case ModelProviderName.huggingface:
-                litellm_provider_name = "huggingface"
-            case ModelProviderName.vertex:
-                litellm_provider_name = "vertex_ai"
-            case ModelProviderName.together_ai:
-                litellm_provider_name = "together_ai"
-            case ModelProviderName.cerebras:
-                litellm_provider_name = "cerebras"
-            case ModelProviderName.siliconflow_cn:
-                is_custom = True
-            case ModelProviderName.openai_compatible:
-                is_custom = True
-            case ModelProviderName.kiln_custom_registry:
-                is_custom = True
-            case ModelProviderName.kiln_fine_tune:
-                is_custom = True
-            case _:
-                raise_exhaustive_enum_error(provider.name)
-
-        if is_custom:
-            if self._api_base is None:
-                raise ValueError(
-                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
-                )
-            # Use openai as it's only used for format, not url
-            litellm_provider_name = "openai"
-
-        # Sholdn't be possible but keep type checker happy
-        if litellm_provider_name is None:
+        litellm_provider_info = get_litellm_provider_info(self.model_provider())
+        if litellm_provider_info.is_custom and self._api_base is None:
             raise ValueError(
-
+                "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
             )
 
-        self._litellm_model_id =
+        self._litellm_model_id = litellm_provider_info.litellm_model_id
         return self._litellm_model_id
 
     async def build_completion_kwargs(
         self,
         provider: KilnModelProvider,
-        messages: list[
+        messages: list[ChatCompletionMessageIncludingLiteLLM],
         top_logprobs: int | None,
         skip_response_format: bool = False,
     ) -> dict[str, Any]:
@@ -390,9 +482,23 @@ class LiteLlmAdapter(BaseAdapter):
             **self._additional_body_options,
         }
 
+        tool_calls = await self.litellm_tools()
+        has_tools = len(tool_calls) > 0
+        if has_tools:
+            completion_kwargs["tools"] = tool_calls
+            completion_kwargs["tool_choice"] = "auto"
+
         if not skip_response_format:
             # Response format: json_schema, json_instructions, json_mode, function_calling, etc
             response_format_options = await self.response_format_options()
+
+            # Check for a conflict between tools and response format using tools
+            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
+            if has_tools and "tools" in response_format_options:
+                raise ValueError(
+                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
+                )
+
             completion_kwargs.update(response_format_options)
 
         if top_logprobs is not None:
@@ -401,7 +507,7 @@ class LiteLlmAdapter(BaseAdapter):
 
         return completion_kwargs
 
-    def usage_from_response(self, response: ModelResponse) -> Usage
+    def usage_from_response(self, response: ModelResponse) -> Usage:
         litellm_usage = response.get("usage", None)
 
         # LiteLLM isn't consistent in how it returns the cost.
@@ -409,11 +515,11 @@ class LiteLlmAdapter(BaseAdapter):
         if cost is None and litellm_usage:
             cost = litellm_usage.get("cost", None)
 
-        if not litellm_usage and not cost:
-            return None
-
         usage = Usage()
 
+        if not litellm_usage and not cost:
+            return usage
+
         if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
             usage.input_tokens = litellm_usage.get("prompt_tokens", None)
             usage.output_tokens = litellm_usage.get("completion_tokens", None)
@@ -432,3 +538,139 @@ class LiteLlmAdapter(BaseAdapter):
         )
 
         return usage
+
+    async def cached_available_tools(self) -> list[KilnToolInterface]:
+        if self._cached_available_tools is None:
+            self._cached_available_tools = await self.available_tools()
+        return self._cached_available_tools
+
+    async def litellm_tools(self) -> list[Dict]:
+        available_tools = await self.cached_available_tools()
+
+        # LiteLLM takes the standard OpenAI-compatible tool call format
+        return [await tool.toolcall_definition() for tool in available_tools]
+
+    async def process_tool_calls(
+        self, tool_calls: list[ChatCompletionMessageToolCall] | None
+    ) -> tuple[str | None, list[ChatCompletionToolMessageParam]]:
+        if tool_calls is None:
+            return None, []
+
+        assistant_output_from_toolcall: str | None = None
+        tool_call_response_messages: list[ChatCompletionToolMessageParam] = []
+
+        for tool_call in tool_calls:
+            # Kiln "task_response" tool is used for returning structured output via tool calls.
+            # Load the output from the tool call. Also
+            if tool_call.function.name == "task_response":
+                assistant_output_from_toolcall = tool_call.function.arguments
+                continue
+
+            # Process normal tool calls (not the "task_response" tool)
+            tool_name = tool_call.function.name
+            tool = None
+            for tool_option in await self.cached_available_tools():
+                if await tool_option.name() == tool_name:
+                    tool = tool_option
+                    break
+            if not tool:
+                raise RuntimeError(
+                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
+                )
+
+            # Parse the arguments and validate them against the tool's schema
+            try:
+                parsed_args = json.loads(tool_call.function.arguments)
+            except json.JSONDecodeError:
+                raise RuntimeError(
+                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
+                )
+            try:
+                tool_call_definition = await tool.toolcall_definition()
+                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
+                validate_schema_with_value_error(parsed_args, json_schema)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
+                ) from e
+
+            result = await tool.run(**parsed_args)
+
+            tool_call_response_messages.append(
+                ChatCompletionToolMessageParam(
+                    role="tool",
+                    tool_call_id=tool_call.id,
+                    content=result,
+                )
+            )
+
+        if (
+            assistant_output_from_toolcall is not None
+            and len(tool_call_response_messages) > 0
+        ):
+            raise RuntimeError(
+                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
+            )
+
+        return assistant_output_from_toolcall, tool_call_response_messages
+
+    def litellm_message_to_trace_message(
+        self, raw_message: LiteLLMMessage
+    ) -> ChatCompletionAssistantMessageParamWrapper:
+        """
+        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
+        """
+        message: ChatCompletionAssistantMessageParamWrapper = {
+            "role": "assistant",
+        }
+        if raw_message.role != "assistant":
+            raise ValueError(
+                "Model returned a message with a role other than assistant. This is not supported."
+            )
+
+        if hasattr(raw_message, "content"):
+            message["content"] = raw_message.content
+        if hasattr(raw_message, "reasoning_content"):
+            message["reasoning_content"] = raw_message.reasoning_content
+        if hasattr(raw_message, "tool_calls"):
+            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
+            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
+            for litellm_tool_call in raw_message.tool_calls or []:
+                # Optional in the SDK for streaming responses, but should never be None at this point.
+                if litellm_tool_call.function.name is None:
+                    raise ValueError(
+                        "The model requested a tool call, without providing a function name (required)."
+                    )
+                open_ai_tool_calls.append(
+                    ChatCompletionMessageToolCallParam(
+                        id=litellm_tool_call.id,
+                        type="function",
+                        function={
+                            "name": litellm_tool_call.function.name,
+                            "arguments": litellm_tool_call.function.arguments,
+                        },
+                    )
+                )
+            if len(open_ai_tool_calls) > 0:
+                message["tool_calls"] = open_ai_tool_calls
+
+        if not message.get("content") and not message.get("tool_calls"):
+            raise ValueError(
+                "Model returned an assistant message, but no content or tool calls. This is not supported."
+            )
+
+        return message
+
+    def all_messages_to_trace(
+        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
+    ) -> list[ChatCompletionMessageParam]:
+        """
+        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
+        """
+        trace: list[ChatCompletionMessageParam] = []
+        for message in messages:
+            if isinstance(message, LiteLLMMessage):
+                trace.append(self.litellm_message_to_trace_message(message))
+            else:
+                trace.append(message)
+        return trace
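
The bulk of the diff above replaces the single-shot completion call in LiteLlmAdapter._run with a bounded, multi-turn tool-calling loop (_run_model_turn, process_tool_calls, and the reserved task_response tool). The following is a minimal sketch of that control flow under stated assumptions: call_model, ToolCall, AssistantReply, and the tools dict are hypothetical stand-ins for illustration, not the package's or LiteLLM's actual API.

# Sketch only: hypothetical stand-ins, not kiln_ai or LiteLLM APIs.
import json
from dataclasses import dataclass, field
from typing import Any, Callable

MAX_TOOL_CALLS_PER_TURN = 30  # mirrors the cap the diff introduces


@dataclass
class ToolCall:
    id: str
    name: str
    arguments: str  # JSON-encoded arguments, as in the OpenAI tool-call format


@dataclass
class AssistantReply:
    content: str | None = None
    tool_calls: list[ToolCall] = field(default_factory=list)


def run_turn(
    call_model: Callable[[list[dict[str, Any]]], AssistantReply],
    tools: dict[str, Callable[..., str]],
    messages: list[dict[str, Any]],
) -> str:
    """Drive one user->assistant turn, executing tool calls until the model answers."""
    for _ in range(MAX_TOOL_CALLS_PER_TURN):
        reply = call_model(messages)

        # Record the assistant message (content and/or tool calls) in the history.
        messages.append(
            {"role": "assistant", "content": reply.content, "tool_calls": reply.tool_calls}
        )

        if reply.tool_calls:
            # The reserved "task_response" tool carries the structured final answer:
            # its JSON arguments end the turn immediately.
            for call in reply.tool_calls:
                if call.name == "task_response":
                    return call.arguments

            # Run each requested tool and append a "tool" message with its result,
            # then loop so the model can use the results on the next iteration.
            for call in reply.tool_calls:
                result = tools[call.name](**json.loads(call.arguments))
                messages.append(
                    {"role": "tool", "tool_call_id": call.id, "content": result}
                )
            continue

        if reply.content:
            return reply.content  # a plain assistant message also ends the turn

        raise RuntimeError("Model returned neither content nor tool calls.")

    raise RuntimeError("Too many tool calls in a single turn.")

Capping the iterations per turn (MAX_TOOL_CALLS_PER_TURN here, with MAX_CALLS_PER_TURN bounding chat turns in the diff) keeps a misbehaving model from looping indefinitely and consuming tokens.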