kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +234 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
- kiln_ai/adapters/eval/base_eval.py +8 -6
- kiln_ai/adapters/eval/eval_runner.py +9 -65
- kiln_ai/adapters/eval/g_eval.py +26 -8
- kiln_ai/adapters/eval/test_base_eval.py +166 -15
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
- kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
- kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
- kiln_ai/adapters/ml_model_list.py +556 -45
- kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
- kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -17
- kiln_ai/adapters/provider_tools.py +39 -4
- kiln_ai/adapters/repair/test_repair_task.py +27 -5
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +158 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -3
- kiln_ai/adapters/test_prompt_builders.py +27 -19
- kiln_ai/adapters/test_provider_tools.py +130 -12
- kiln_ai/datamodel/__init__.py +2 -2
- kiln_ai/datamodel/datamodel_enums.py +43 -4
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +13 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/task_output.py +1 -1
- kiln_ai/datamodel/task_run.py +39 -7
- kiln_ai/datamodel/test_basemodel.py +5 -8
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -8
- kiln_ai/datamodel/test_example_models.py +54 -0
- kiln_ai/datamodel/test_models.py +80 -9
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +81 -19
- kiln_ai/utils/logging.py +165 -0
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +272 -10
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
- kiln_ai-0.17.0.dist-info/RECORD +113 -0
- kiln_ai-0.15.0.dist-info/RECORD +0 -104
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/base_adapter.py

@@ -1,17 +1,32 @@
 import json
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Dict,
-
-import
-
-
+from typing import Dict, Tuple
+
+from kiln_ai.adapters.chat.chat_formatter import (
+    ChatFormatter,
+    get_chat_formatter,
+)
+from kiln_ai.adapters.ml_model_list import (
+    KilnModelProvider,
+    StructuredOutputMode,
+    default_structured_output_mode_for_model_provider,
+)
 from kiln_ai.adapters.parsers.json_parser import parse_json_string
 from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
+from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
 from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
 from kiln_ai.adapters.run_output import RunOutput
-from kiln_ai.datamodel import
+from kiln_ai.datamodel import (
+    DataSource,
+    DataSourceType,
+    Task,
+    TaskOutput,
+    TaskRun,
+    Usage,
+)
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config
@@ -30,9 +45,6 @@ class AdapterConfig:
     default_tags: list[str] | None = None


-COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."
-
-
 class BaseAdapter(metaclass=ABCMeta):
     """Base class for AI model adapters that handle task execution.

@@ -53,6 +65,7 @@ class BaseAdapter(metaclass=ABCMeta):
         config: AdapterConfig | None = None,
     ):
         self.run_config = run_config
+        self.update_run_config_unknown_structured_output_mode()
         self.prompt_builder = prompt_builder_from_id(
             run_config.prompt_id, run_config.task
         )
@@ -106,14 +119,19 @@ class BaseAdapter(metaclass=ABCMeta):
                     "This task requires a specific input schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.",
                 )

+        # Format model input for model call (we save the original input in the task without formatting)
+        formatted_input = input
+        formatter_id = self.model_provider().formatter
+        if formatter_id is not None:
+            formatter = request_formatter_from_id(formatter_id)
+            formatted_input = formatter.format_input(input)
+
         # Run
-        run_output = await self._run(input)
+        run_output, usage = await self._run(formatted_input)

         # Parse
         provider = self.model_provider()
-        parser = model_parser_from_id(provider.parser)(
-            structured_output=self.has_structured_output()
-        )
+        parser = model_parser_from_id(provider.parser)
         parsed_output = parser.parse_output(original_output=run_output)

         # validate output
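The hunk above runs the input through an optional request formatter just before the model call, while the unformatted input is what gets saved on the task run. Below is a minimal, self-contained sketch of that pattern; format_input mirrors the name used in the diff, but the registry and the example formatter are hypothetical stand-ins, not Kiln's actual request_formatters module.

    from typing import Callable, Dict, Union

    Input = Union[Dict, str]

    # Hypothetical registry mapping a formatter id to a callable that rewrites the input.
    FORMATTERS: dict[str, Callable[[Input], Input]] = {
        "wrap_in_task_tags": lambda value: (
            f"<task>{value}</task>" if isinstance(value, str) else value
        ),
    }


    def format_for_model(raw_input: Input, formatter_id: str | None) -> Input:
        """Return the payload to send to the model; callers keep raw_input for persistence."""
        if formatter_id is None:
            return raw_input
        formatter = FORMATTERS.get(formatter_id)
        if formatter is None:
            raise ValueError(f"Unknown formatter id: {formatter_id}")
        return formatter(raw_input)


    if __name__ == "__main__":
        original = "Summarize the attached article."
        formatted = format_for_model(original, "wrap_in_task_tags")
        print(original)   # stored on the task run, unchanged
        print(formatted)  # sent to the model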
@@ -147,7 +165,7 @@ class BaseAdapter(metaclass=ABCMeta):
             )

         # Generate the run and output
-        run = self.generate_run(input, input_source, parsed_output)
+        run = self.generate_run(input, input_source, parsed_output, usage)

         # Save the run if configured to do so, and we have a path to save to
         if (
@@ -170,15 +188,15 @@ class BaseAdapter(metaclass=ABCMeta):
         pass

     @abstractmethod
-    async def _run(self, input: Dict | str) -> RunOutput:
+    async def _run(self, input: Dict | str) -> Tuple[RunOutput, Usage | None]:
         pass

     def build_prompt(self) -> str:
         # The prompt builder needs to know if we want to inject formatting instructions
-
+        structured_output_mode = self.run_config.structured_output_mode
         add_json_instructions = self.has_structured_output() and (
-
-            or
+            structured_output_mode == StructuredOutputMode.json_instructions
+            or structured_output_mode
             == StructuredOutputMode.json_instruction_and_object
         )

@@ -186,30 +204,59 @@ class BaseAdapter(metaclass=ABCMeta):
             include_json_instructions=add_json_instructions
         )

-    def
-
-
-        # Determine the run strategy for COT prompting. 3 options:
-        # 1. "Thinking" LLM designed to output thinking in a structured format plus a COT prompt: we make 1 call to the LLM, which outputs thinking in a structured format. We include the thinking instuctions as a message.
-        # 2. Normal LLM with COT prompt: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call. It also separates the thinking from the final response.
-        # 3. Non chain of thought: we make 1 call to the LLM, with no COT prompt.
+    def build_chat_formatter(self, input: Dict | str) -> ChatFormatter:
+        # Determine the chat strategy to use based on the prompt the user selected, the model's capabilities, and if the model was finetuned with a specific chat strategy.
+
         cot_prompt = self.prompt_builder.chain_of_thought_prompt()
-
+        system_message = self.build_prompt()
+
+        # If no COT prompt, use the single turn strategy. Even when a tuned strategy is set, as the tuned strategy is either already single turn, or won't work without a COT prompt.
+        if not cot_prompt:
+            return get_chat_formatter(
+                strategy=ChatStrategy.single_turn,
+                system_message=system_message,
+                user_input=input,
+            )

-
-
+        # Some models like finetunes are trained with a specific chat strategy. Use that.
+        # However, don't use that if it is single turn. The user selected a COT prompt, and we give explicit prompt selection priority over the tuned strategy.
+        tuned_chat_strategy = self.model_provider().tuned_chat_strategy
+        if tuned_chat_strategy and tuned_chat_strategy != ChatStrategy.single_turn:
+            return get_chat_formatter(
+                strategy=tuned_chat_strategy,
+                system_message=system_message,
+                user_input=input,
+                thinking_instructions=cot_prompt,
+            )
+
+        # Pick the best chat strategy for the model given it has a cot prompt.
+        reasoning_capable = self.model_provider().reasoning_capable
+        if reasoning_capable:
+            # "Thinking" LLM designed to output thinking in a structured format. We'll use it's native format.
             # A simple message with the COT prompt appended to the message list is sufficient
-            return
-
-
-
-
+            return get_chat_formatter(
+                strategy=ChatStrategy.single_turn_r1_thinking,
+                system_message=system_message,
+                user_input=input,
+                thinking_instructions=cot_prompt,
+            )
         else:
-
+            # Unstructured output with COT
+            # Two calls to separate the thinking from the final response
+            return get_chat_formatter(
+                strategy=ChatStrategy.two_message_cot,
+                system_message=system_message,
+                user_input=input,
+                thinking_instructions=cot_prompt,
+            )

     # create a run and task output
     def generate_run(
-        self,
+        self,
+        input: Dict | str,
+        input_source: DataSource | None,
+        run_output: RunOutput,
+        usage: Usage | None = None,
     ) -> TaskRun:
         # Convert input and output to JSON strings if they are dictionaries
         input_str = (
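build_chat_formatter above chooses between a plain single turn, a finetune's tuned strategy, a native "thinking" format, and a two-message chain of thought. A simplified, self-contained sketch of that decision order follows; the strategy names echo ChatStrategy values from the diff, while pick_strategy itself is illustrative rather than Kiln's implementation.

    from enum import Enum


    class ChatStrategy(str, Enum):
        single_turn = "single_turn"
        single_turn_r1_thinking = "single_turn_r1_thinking"
        two_message_cot = "two_message_cot"


    def pick_strategy(
        cot_prompt: str | None,
        tuned_strategy: ChatStrategy | None,
        reasoning_capable: bool,
    ) -> ChatStrategy:
        # No chain-of-thought prompt: always a plain single turn.
        if not cot_prompt:
            return ChatStrategy.single_turn
        # A finetune trained on a specific multi-turn strategy keeps it,
        # but an explicit COT prompt overrides a tuned single-turn strategy.
        if tuned_strategy and tuned_strategy != ChatStrategy.single_turn:
            return tuned_strategy
        # Reasoning models emit thinking natively, so one call is enough.
        if reasoning_capable:
            return ChatStrategy.single_turn_r1_thinking
        # Otherwise split thinking and the final answer across two messages.
        return ChatStrategy.two_message_cot


    assert pick_strategy(None, ChatStrategy.two_message_cot, True) == ChatStrategy.single_turn
    assert pick_strategy("Think step by step.", None, False) == ChatStrategy.two_message_cot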
@@ -242,6 +289,7 @@ class BaseAdapter(metaclass=ABCMeta):
             ),
             intermediate_outputs=run_output.intermediate_outputs,
             tags=self.base_adapter_config.default_tags or [],
+            usage=usage,
         )

         return new_task_run
@@ -254,5 +302,22 @@ class BaseAdapter(metaclass=ABCMeta):
         props["model_name"] = self.run_config.model_name
         props["model_provider"] = self.run_config.model_provider_name
         props["prompt_id"] = self.run_config.prompt_id
+        props["structured_output_mode"] = self.run_config.structured_output_mode
+        props["temperature"] = self.run_config.temperature
+        props["top_p"] = self.run_config.top_p

         return props
+
+    def update_run_config_unknown_structured_output_mode(self) -> None:
+        structured_output_mode = self.run_config.structured_output_mode
+
+        # Old datamodels didn't save the structured output mode. Some clients (tests, end users) might not set it.
+        # Look up our recommended mode from ml_model_list if we have one
+        if structured_output_mode == StructuredOutputMode.unknown:
+            new_run_config = self.run_config.model_copy(deep=True)
+            structured_output_mode = default_structured_output_mode_for_model_provider(
+                self.run_config.model_name,
+                self.run_config.model_provider_name,
+            )
+            new_run_config.structured_output_mode = structured_output_mode
+            self.run_config = new_run_config
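update_run_config_unknown_structured_output_mode above swaps a legacy "unknown" mode for the recommended default on a copy of the run config. A minimal sketch of the same idea, assuming a frozen dataclass and a small in-memory defaults table in place of Kiln's pydantic model and ml_model_list lookup:

    from dataclasses import dataclass, replace
    from enum import Enum


    class StructuredOutputMode(str, Enum):
        unknown = "unknown"
        json_schema = "json_schema"
        function_calling = "function_calling"


    # Hypothetical defaults table standing in for the ml_model_list lookup.
    DEFAULT_MODES = {("gpt_4o", "openai"): StructuredOutputMode.json_schema}


    @dataclass(frozen=True)
    class RunConfig:
        model_name: str
        model_provider_name: str
        structured_output_mode: StructuredOutputMode = StructuredOutputMode.unknown


    def resolve_mode(config: RunConfig) -> RunConfig:
        """Replace an 'unknown' mode with the recommended default, on a copy."""
        if config.structured_output_mode != StructuredOutputMode.unknown:
            return config
        default = DEFAULT_MODES.get(
            (config.model_name, config.model_provider_name),
            StructuredOutputMode.function_calling,
        )
        return replace(config, structured_output_mode=default)


    print(resolve_mode(RunConfig("gpt_4o", "openai")).structured_output_mode)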
kiln_ai/adapters/model_adapters/litellm_adapter.py

@@ -1,7 +1,9 @@
+import logging
 from typing import Any, Dict

 import litellm
 from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
+from litellm.types.utils import Usage as LiteLlmUsage

 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.ml_model_list import (
@@ -10,25 +12,23 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import (
-    COT_FINAL_ANSWER_PROMPT,
     AdapterConfig,
     BaseAdapter,
     RunOutput,
+    Usage,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
-from kiln_ai.datamodel import PromptGenerators, PromptId
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
+from kiln_ai.datamodel.task import run_config_from_run_config_properties
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

+logger = logging.getLogger(__name__)
+

 class LiteLlmAdapter(BaseAdapter):
     def __init__(
         self,
         config: LiteLlmConfig,
         kiln_task: datamodel.Task,
-        prompt_id: PromptId | None = None,
         base_adapter_config: AdapterConfig | None = None,
     ):
         self.config = config
@@ -37,11 +37,10 @@ class LiteLlmAdapter(BaseAdapter):
         self._headers = config.default_headers
         self._litellm_model_id: str | None = None

-        run_config = RunConfig(
+        # Create a RunConfig, adding the task to the RunConfigProperties
+        run_config = run_config_from_run_config_properties(
             task=kiln_task,
-            model_name=config.model_name,
-            model_provider_name=config.provider_name,
-            prompt_id=prompt_id or PromptGenerators.SIMPLE,
+            run_config_properties=config.run_config_properties,
         )

         super().__init__(
@@ -49,84 +48,74 @@ class LiteLlmAdapter(BaseAdapter):
             config=base_adapter_config,
         )

-    async def _run(self, input: Dict | str) -> RunOutput:
+    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
         provider = self.model_provider()
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # First call for chain of thought
-        # No response format as this request is for "thinking" in plain text
-        # No logprobs as only needed for final answer
+        chat_formatter = self.build_chat_formatter(input)
+
+        prior_output = None
+        prior_message = None
+        response = None
+        turns = 0
+        while True:
+            turns += 1
+            if turns > 10:
+                raise RuntimeError(
+                    "Too many turns. Stopping iteration to avoid using too many tokens."
+                )
+
+            turn = chat_formatter.next_turn(prior_output)
+            if turn is None:
+                break
+
+            skip_response_format = not turn.final_call
+            all_messages = chat_formatter.message_dicts()
             completion_kwargs = await self.build_completion_kwargs(
-            provider,
+                provider,
+                all_messages,
+                self.base_adapter_config.top_logprobs if turn.final_call else None,
+                skip_response_format,
             )
-
+            response = await litellm.acompletion(**completion_kwargs)
             if (
-                not isinstance(
-                or not
-                or len(
-                or not isinstance(
+                not isinstance(response, ModelResponse)
+                or not response.choices
+                or len(response.choices) == 0
+                or not isinstance(response.choices[0], Choices)
             ):
                 raise RuntimeError(
-                    f"Expected ModelResponse with Choices, got {type(
+                    f"Expected ModelResponse with Choices, got {type(response)}."
                 )
-
-
-        intermediate_outputs["chain_of_thought"] = cot_content
-
-        messages.extend(
-            [
-                {"role": "assistant", "content": cot_content or ""},
-                {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
-            ]
-        )
+            prior_message = response.choices[0].message
+            prior_output = prior_message.content

-
-
-
-
-
+            # Fallback: Use args of first tool call to task_response if it exists
+            if (
+                not prior_output
+                and hasattr(prior_message, "tool_calls")
+                and prior_message.tool_calls
+            ):
+                tool_call = next(
+                    (
+                        tool_call
+                        for tool_call in prior_message.tool_calls
+                        if tool_call.function.name == "task_response"
+                    ),
+                    None,
+                )
+                if tool_call:
+                    prior_output = tool_call.function.arguments

-
-
+        if not prior_output:
+            raise RuntimeError("No output returned from model")

-
-
-        if hasattr(response, "error") and response.__getattribute__("error"):
-            raise RuntimeError(
-                f"LLM API returned an error: {response.__getattribute__('error')}"
-            )
+        if response is None or prior_message is None:
+            raise RuntimeError("No response returned from model")

-
-            not response.choices
-            or len(response.choices) == 0
-            or not isinstance(response.choices[0], Choices)
-        ):
-            raise RuntimeError(
-                "No message content returned in the response from LLM API"
-            )
+        intermediate_outputs = chat_formatter.intermediate_outputs()

-        message = response.choices[0].message
         logprobs = (
             response.choices[0].logprobs
             if hasattr(response.choices[0], "logprobs")
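The rewritten _run above asks the chat formatter for the next turn, calls the model once per turn, and aborts after 10 turns. Below is a stripped-down sketch of that loop shape with a toy two-call formatter and a fake completion coroutine; it does not call litellm, and only the final-answer prompt string is taken from this diff (the COT_FINAL_ANSWER_PROMPT constant removed from base_adapter).

    import asyncio
    from dataclasses import dataclass


    @dataclass
    class Turn:
        messages: list[dict]
        final_call: bool


    class TwoCallFormatter:
        """Toy formatter: one thinking turn, then one final-answer turn."""

        def __init__(self, system_message: str, user_input: str):
            self._messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_input},
            ]
            self._step = 0

        def next_turn(self, prior_output: str | None) -> Turn | None:
            if prior_output is not None:
                self._messages.append({"role": "assistant", "content": prior_output})
            self._step += 1
            if self._step == 1:
                self._messages.append({"role": "user", "content": "Think step by step."})
                return Turn(self._messages, final_call=False)
            if self._step == 2:
                self._messages.append(
                    {"role": "user", "content": "Considering the above, return a final result."}
                )
                return Turn(self._messages, final_call=True)
            return None


    async def fake_completion(messages: list[dict]) -> str:
        # Stand-in for the real model call.
        return f"reply #{len(messages)}"


    async def run(formatter: TwoCallFormatter, max_turns: int = 10) -> str:
        prior_output = None
        turns = 0
        while True:
            turns += 1
            if turns > max_turns:
                raise RuntimeError("Too many turns")
            turn = formatter.next_turn(prior_output)
            if turn is None:
                break
            prior_output = await fake_completion(turn.messages)
        if not prior_output:
            raise RuntimeError("No output returned from model")
        return prior_output


    print(asyncio.run(run(TwoCallFormatter("You are helpful.", "What is 2+2?"))))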
@@ -139,28 +128,16 @@ class LiteLlmAdapter(BaseAdapter):
             raise RuntimeError("Logprobs were required, but no logprobs were returned.")

         # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
-        if hasattr(message, "reasoning_content") and message.reasoning_content:
-            intermediate_outputs["reasoning"] = message.reasoning_content
-
-        # the string content of the response
-        response_content = message.content
-
-        # Fallback: Use args of first tool call to task_response if it exists
         if (
-            not
-            and hasattr(
-            and
+            prior_message is not None
+            and hasattr(prior_message, "reasoning_content")
+            and prior_message.reasoning_content
+            and len(prior_message.reasoning_content.strip()) > 0
         ):
-
-
-
-
-                if tool_call.function.name == "task_response"
-            ),
-            None,
-        )
-        if tool_call:
-            response_content = tool_call.function.arguments
+            intermediate_outputs["reasoning"] = prior_message.reasoning_content.strip()
+
+        # the string content of the response
+        response_content = prior_output

         if not isinstance(response_content, str):
             raise RuntimeError(f"response is not a string: {response_content}")
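The reasoning block above only stores reasoning_content when the attribute exists and is non-blank. A tiny standalone sketch of that defensive check (the message object here is a stand-in, not litellm's type):

    from types import SimpleNamespace


    def extract_reasoning(message: object) -> str | None:
        reasoning = getattr(message, "reasoning_content", None)
        if reasoning and reasoning.strip():
            return reasoning.strip()
        return None


    intermediate_outputs: dict[str, str] = {}
    msg = SimpleNamespace(content="4", reasoning_content="  2 + 2 = 4  ")
    reasoning = extract_reasoning(msg)
    if reasoning is not None:
        intermediate_outputs["reasoning"] = reasoning
    print(intermediate_outputs)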
@@ -169,7 +146,7 @@ class LiteLlmAdapter(BaseAdapter):
             output=response_content,
             intermediate_outputs=intermediate_outputs,
             output_logprobs=logprobs,
-        )
+        ), self.usage_from_response(response)

     def adapter_name(self) -> str:
         return "kiln_openai_compatible_adapter"
@@ -179,8 +156,9 @@ class LiteLlmAdapter(BaseAdapter):
         if not self.has_structured_output():
             return {}

-
-
+        structured_output_mode = self.run_config.structured_output_mode
+
+        match structured_output_mode:
             case StructuredOutputMode.json_mode:
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.json_schema:
@@ -199,16 +177,20 @@ class LiteLlmAdapter(BaseAdapter):
                 # We set response_format to json_object and also set json instructions in the prompt
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.default:
-
+                provider_name = self.run_config.model_provider_name
+                if provider_name == ModelProviderName.ollama:
                     # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                     return self.json_schema_response_format()
                 else:
                     # Default to function calling -- it's older than the other modes. Higher compatibility.
                     # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
-                    strict =
+                    strict = provider_name == ModelProviderName.openai
                     return self.tool_call_params(strict=strict)
+            case StructuredOutputMode.unknown:
+                # See above, but this case should never happen.
+                raise ValueError("Structured output mode is unknown.")
             case _:
-                raise_exhaustive_enum_error(
+                raise_exhaustive_enum_error(structured_output_mode)

     def json_schema_response_format(self) -> dict[str, Any]:
         output_schema = self.task().output_schema()
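The match above turns the selected structured output mode into request parameters. Below is a condensed, standalone sketch of such a mapping. The mode names echo the diff and the payload shapes follow OpenAI-style parameters that litellm accepts, but the exact payloads Kiln sends are built by json_schema_response_format and tool_call_params, not this function.

    from enum import Enum
    from typing import Any


    class StructuredOutputMode(str, Enum):
        json_mode = "json_mode"
        json_schema = "json_schema"
        json_instruction_and_object = "json_instruction_and_object"
        function_calling = "function_calling"


    def response_format_options(mode: StructuredOutputMode, output_schema: dict[str, Any]) -> dict[str, Any]:
        match mode:
            case StructuredOutputMode.json_mode | StructuredOutputMode.json_instruction_and_object:
                # JSON object mode; instructions may additionally be injected into the prompt.
                return {"response_format": {"type": "json_object"}}
            case StructuredOutputMode.json_schema:
                return {
                    "response_format": {
                        "type": "json_schema",
                        "json_schema": {"name": "task_response", "schema": output_schema},
                    }
                }
            case StructuredOutputMode.function_calling:
                # Function/tool calling: widest compatibility across providers.
                return {
                    "tools": [
                        {
                            "type": "function",
                            "function": {"name": "task_response", "parameters": output_schema},
                        }
                    ],
                    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
                }


    schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
    print(response_format_options(StructuredOutputMode.json_schema, schema))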
@@ -380,6 +362,13 @@ class LiteLlmAdapter(BaseAdapter):
             "messages": messages,
             "api_base": self._api_base,
             "headers": self._headers,
+            "temperature": self.run_config.temperature,
+            "top_p": self.run_config.top_p,
+            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
+            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
+            # Better to ignore them than to fail the model call.
+            # https://docs.litellm.ai/docs/completion/input
+            "drop_params": True,
             **extra_body,
             **self._additional_body_options,
         }
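The added kwargs above rely on litellm's documented drop_params flag (linked in the diff comment) so unsupported sampling parameters are dropped instead of failing the call. A small sketch of assembling such a kwargs dict; the model id and messages are placeholders and the actual litellm call is left commented out.

    def build_completion_kwargs(
        model: str,
        messages: list[dict],
        temperature: float,
        top_p: float,
        extra_body: dict | None = None,
    ) -> dict:
        return {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            # Ask litellm to drop OpenAI params the target model doesn't support
            # (e.g. reasoning models that reject top_p) instead of erroring.
            "drop_params": True,
            **(extra_body or {}),
        }


    kwargs = build_completion_kwargs(
        model="openai/gpt-4o-mini",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
        temperature=0.7,
        top_p=0.95,
    )
    # response = litellm.completion(**kwargs)  # requires litellm and API credentials
    print(kwargs["drop_params"])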
@@ -394,3 +383,30 @@ class LiteLlmAdapter(BaseAdapter):
             completion_kwargs["top_logprobs"] = top_logprobs

         return completion_kwargs
+
+    def usage_from_response(self, response: ModelResponse) -> Usage | None:
+        litellm_usage = response.get("usage", None)
+        cost = response._hidden_params.get("response_cost", None)
+        if not litellm_usage and not cost:
+            return None
+
+        usage = Usage()
+
+        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
+            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
+            usage.output_tokens = litellm_usage.get("completion_tokens", None)
+            usage.total_tokens = litellm_usage.get("total_tokens", None)
+        else:
+            logger.warning(
+                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
+            )
+
+        if isinstance(cost, float):
+            usage.cost = cost
+        elif cost is not None:
+            # None is allowed, but no other types are expected
+            logger.warning(
+                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
+            )
+
+        return usage
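usage_from_response tolerates missing or unexpectedly-typed usage and cost data rather than failing the run. The sketch below applies the same defensive pattern to a plain dict payload; the real method reads litellm's Usage object and response._hidden_params.

    import logging
    from dataclasses import dataclass

    logger = logging.getLogger(__name__)


    @dataclass
    class Usage:
        input_tokens: int | None = None
        output_tokens: int | None = None
        total_tokens: int | None = None
        cost: float | None = None


    def usage_from_payload(usage_data: dict | None, cost: object) -> Usage | None:
        if not usage_data and not cost:
            return None
        usage = Usage()
        if isinstance(usage_data, dict):
            usage.input_tokens = usage_data.get("prompt_tokens")
            usage.output_tokens = usage_data.get("completion_tokens")
            usage.total_tokens = usage_data.get("total_tokens")
        else:
            logger.warning("Unexpected usage format: %r", usage_data)
        if isinstance(cost, float):
            usage.cost = cost
        elif cost is not None:
            logger.warning("Unexpected cost format: %r", cost)
        return usage


    print(usage_from_payload({"prompt_tokens": 12, "completion_tokens": 5, "total_tokens": 17}, 0.00031))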
kiln_ai/adapters/model_adapters/litellm_config.py

@@ -1,10 +1,11 @@
 from dataclasses import dataclass, field

+from kiln_ai.datamodel.task import RunConfigProperties
+

 @dataclass
 class LiteLlmConfig:
-    model_name: str
-    provider_name: str
+    run_config_properties: RunConfigProperties
     # If set, over rides the provider-name based URL from litellm
     base_url: str | None = None
     # Headers to send with every request