kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (72)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +234 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
  7. kiln_ai/adapters/eval/base_eval.py +8 -6
  8. kiln_ai/adapters/eval/eval_runner.py +9 -65
  9. kiln_ai/adapters/eval/g_eval.py +26 -8
  10. kiln_ai/adapters/eval/test_base_eval.py +166 -15
  11. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  12. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
  15. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  16. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
  17. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  18. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  19. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  20. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
  21. kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
  22. kiln_ai/adapters/ml_model_list.py +556 -45
  23. kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
  24. kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
  25. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  26. kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
  27. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
  28. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
  29. kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
  30. kiln_ai/adapters/parsers/base_parser.py +0 -3
  31. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  32. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  33. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  34. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  35. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  36. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  37. kiln_ai/adapters/prompt_builders.py +14 -17
  38. kiln_ai/adapters/provider_tools.py +39 -4
  39. kiln_ai/adapters/repair/test_repair_task.py +27 -5
  40. kiln_ai/adapters/test_adapter_registry.py +88 -28
  41. kiln_ai/adapters/test_ml_model_list.py +158 -0
  42. kiln_ai/adapters/test_prompt_adaptors.py +17 -3
  43. kiln_ai/adapters/test_prompt_builders.py +27 -19
  44. kiln_ai/adapters/test_provider_tools.py +130 -12
  45. kiln_ai/datamodel/__init__.py +2 -2
  46. kiln_ai/datamodel/datamodel_enums.py +43 -4
  47. kiln_ai/datamodel/dataset_filters.py +69 -1
  48. kiln_ai/datamodel/dataset_split.py +4 -0
  49. kiln_ai/datamodel/eval.py +8 -0
  50. kiln_ai/datamodel/finetune.py +13 -7
  51. kiln_ai/datamodel/prompt_id.py +1 -0
  52. kiln_ai/datamodel/task.py +68 -7
  53. kiln_ai/datamodel/task_output.py +1 -1
  54. kiln_ai/datamodel/task_run.py +39 -7
  55. kiln_ai/datamodel/test_basemodel.py +5 -8
  56. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  57. kiln_ai/datamodel/test_dataset_split.py +2 -8
  58. kiln_ai/datamodel/test_example_models.py +54 -0
  59. kiln_ai/datamodel/test_models.py +80 -9
  60. kiln_ai/datamodel/test_task.py +168 -2
  61. kiln_ai/utils/async_job_runner.py +106 -0
  62. kiln_ai/utils/config.py +3 -2
  63. kiln_ai/utils/dataset_import.py +81 -19
  64. kiln_ai/utils/logging.py +165 -0
  65. kiln_ai/utils/test_async_job_runner.py +199 -0
  66. kiln_ai/utils/test_config.py +23 -0
  67. kiln_ai/utils/test_dataset_import.py +272 -10
  68. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
  69. kiln_ai-0.17.0.dist-info/RECORD +113 -0
  70. kiln_ai-0.15.0.dist-info/RECORD +0 -104
  71. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
  72. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/base_adapter.py
@@ -1,17 +1,32 @@
 import json
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Dict, Literal, Tuple
-
-import jsonschema
-
-from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
+from typing import Dict, Tuple
+
+from kiln_ai.adapters.chat.chat_formatter import (
+    ChatFormatter,
+    get_chat_formatter,
+)
+from kiln_ai.adapters.ml_model_list import (
+    KilnModelProvider,
+    StructuredOutputMode,
+    default_structured_output_mode_for_model_provider,
+)
 from kiln_ai.adapters.parsers.json_parser import parse_json_string
 from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
+from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id
 from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
 from kiln_ai.adapters.run_output import RunOutput
-from kiln_ai.datamodel import DataSource, DataSourceType, Task, TaskOutput, TaskRun
+from kiln_ai.datamodel import (
+    DataSource,
+    DataSourceType,
+    Task,
+    TaskOutput,
+    TaskRun,
+    Usage,
+)
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config
@@ -30,9 +45,6 @@ class AdapterConfig:
     default_tags: list[str] | None = None


-COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."
-
-
 class BaseAdapter(metaclass=ABCMeta):
     """Base class for AI model adapters that handle task execution.

@@ -53,6 +65,7 @@ class BaseAdapter(metaclass=ABCMeta):
         config: AdapterConfig | None = None,
     ):
         self.run_config = run_config
+        self.update_run_config_unknown_structured_output_mode()
         self.prompt_builder = prompt_builder_from_id(
             run_config.prompt_id, run_config.task
         )
@@ -106,14 +119,19 @@ class BaseAdapter(metaclass=ABCMeta):
                 "This task requires a specific input schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.",
             )

+        # Format model input for model call (we save the original input in the task without formatting)
+        formatted_input = input
+        formatter_id = self.model_provider().formatter
+        if formatter_id is not None:
+            formatter = request_formatter_from_id(formatter_id)
+            formatted_input = formatter.format_input(input)
+
         # Run
-        run_output = await self._run(input)
+        run_output, usage = await self._run(formatted_input)

         # Parse
         provider = self.model_provider()
-        parser = model_parser_from_id(provider.parser)(
-            structured_output=self.has_structured_output()
-        )
+        parser = model_parser_from_id(provider.parser)
         parsed_output = parser.parse_output(original_output=run_output)

         # validate output
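
Note on the request-formatter hook added above: only the copy of the input sent to the model is rewritten; the unformatted input is what gets persisted on the TaskRun. A minimal sketch of the pattern, assuming a formatter interface with a single format_input method (the RequestFormatter protocol and WrapInUserTagsFormatter below are illustrative stand-ins, not the package's actual classes):

from typing import Protocol


class RequestFormatter(Protocol):
    """Rewrites task input into the shape a specific model expects."""

    def format_input(self, input: dict | str) -> dict | str: ...


class WrapInUserTagsFormatter:
    """Hypothetical formatter: wraps plain-string input in tags a model expects."""

    def format_input(self, input: dict | str) -> dict | str:
        if isinstance(input, str):
            return f"<user_input>{input}</user_input>"
        return input


def prepare_input(formatter: RequestFormatter | None, input: dict | str):
    # Mirrors the adapter flow: format only the copy sent to the model,
    # keep the original for saving on the run.
    formatted_input = formatter.format_input(input) if formatter else input
    return input, formatted_input


saved, sent_to_model = prepare_input(WrapInUserTagsFormatter(), "What is 2+2?")
print(saved)          # "What is 2+2?" (persisted unchanged)
print(sent_to_model)  # "<user_input>What is 2+2?</user_input>"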
@@ -147,7 +165,7 @@ class BaseAdapter(metaclass=ABCMeta):
             )

         # Generate the run and output
-        run = self.generate_run(input, input_source, parsed_output)
+        run = self.generate_run(input, input_source, parsed_output, usage)

         # Save the run if configured to do so, and we have a path to save to
         if (
@@ -170,15 +188,15 @@ class BaseAdapter(metaclass=ABCMeta):
         pass

     @abstractmethod
-    async def _run(self, input: Dict | str) -> RunOutput:
+    async def _run(self, input: Dict | str) -> Tuple[RunOutput, Usage | None]:
         pass

     def build_prompt(self) -> str:
         # The prompt builder needs to know if we want to inject formatting instructions
-        provider = self.model_provider()
+        structured_output_mode = self.run_config.structured_output_mode
         add_json_instructions = self.has_structured_output() and (
-            provider.structured_output_mode == StructuredOutputMode.json_instructions
-            or provider.structured_output_mode
+            structured_output_mode == StructuredOutputMode.json_instructions
+            or structured_output_mode
             == StructuredOutputMode.json_instruction_and_object
         )

@@ -186,30 +204,59 @@ class BaseAdapter(metaclass=ABCMeta):
             include_json_instructions=add_json_instructions
         )

-    def run_strategy(
-        self,
-    ) -> Tuple[Literal["cot_as_message", "cot_two_call", "basic"], str | None]:
-        # Determine the run strategy for COT prompting. 3 options:
-        # 1. "Thinking" LLM designed to output thinking in a structured format plus a COT prompt: we make 1 call to the LLM, which outputs thinking in a structured format. We include the thinking instructions as a message.
-        # 2. Normal LLM with COT prompt: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call. It also separates the thinking from the final response.
-        # 3. Non chain of thought: we make 1 call to the LLM, with no COT prompt.
+    def build_chat_formatter(self, input: Dict | str) -> ChatFormatter:
+        # Determine the chat strategy to use based on the prompt the user selected, the model's capabilities, and if the model was finetuned with a specific chat strategy.
+
         cot_prompt = self.prompt_builder.chain_of_thought_prompt()
-        reasoning_capable = self.model_provider().reasoning_capable
+        system_message = self.build_prompt()
+
+        # If no COT prompt, use the single turn strategy. Even when a tuned strategy is set, as the tuned strategy is either already single turn, or won't work without a COT prompt.
+        if not cot_prompt:
+            return get_chat_formatter(
+                strategy=ChatStrategy.single_turn,
+                system_message=system_message,
+                user_input=input,
+            )

-        if cot_prompt and reasoning_capable:
-            # 1: "Thinking" LLM designed to output thinking in a structured format
+        # Some models like finetunes are trained with a specific chat strategy. Use that.
+        # However, don't use that if it is single turn. The user selected a COT prompt, and we give explicit prompt selection priority over the tuned strategy.
+        tuned_chat_strategy = self.model_provider().tuned_chat_strategy
+        if tuned_chat_strategy and tuned_chat_strategy != ChatStrategy.single_turn:
+            return get_chat_formatter(
+                strategy=tuned_chat_strategy,
+                system_message=system_message,
+                user_input=input,
+                thinking_instructions=cot_prompt,
+            )
+
+        # Pick the best chat strategy for the model given it has a cot prompt.
+        reasoning_capable = self.model_provider().reasoning_capable
+        if reasoning_capable:
+            # "Thinking" LLM designed to output thinking in a structured format. We'll use its native format.
             # A simple message with the COT prompt appended to the message list is sufficient
-            return "cot_as_message", cot_prompt
-        elif cot_prompt:
-            # 2: Unstructured output with COT
-            # Two calls to separate the thinking from the final response
-            return "cot_two_call", cot_prompt
+            return get_chat_formatter(
+                strategy=ChatStrategy.single_turn_r1_thinking,
+                system_message=system_message,
+                user_input=input,
+                thinking_instructions=cot_prompt,
+            )
         else:
-            return "basic", None
+            # Unstructured output with COT
+            # Two calls to separate the thinking from the final response
+            return get_chat_formatter(
+                strategy=ChatStrategy.two_message_cot,
+                system_message=system_message,
+                user_input=input,
+                thinking_instructions=cot_prompt,
+            )

     # create a run and task output
     def generate_run(
-        self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
+        self,
+        input: Dict | str,
+        input_source: DataSource | None,
+        run_output: RunOutput,
+        usage: Usage | None = None,
     ) -> TaskRun:
         # Convert input and output to JSON strings if they are dictionaries
         input_str = (
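
The strategy selection in build_chat_formatter above reduces to a small decision tree. A standalone sketch of the same logic (the pick_strategy helper and its signature are illustrative, not part of the package API; the ChatStrategy values mirror the ones referenced in the diff):

from enum import Enum


class ChatStrategy(str, Enum):
    single_turn = "single_turn"
    single_turn_r1_thinking = "single_turn_r1_thinking"
    two_message_cot = "two_message_cot"


def pick_strategy(
    cot_prompt: str | None,
    tuned_chat_strategy: ChatStrategy | None,
    reasoning_capable: bool,
) -> ChatStrategy:
    # No COT prompt: always single turn, even if a tuned strategy exists.
    if not cot_prompt:
        return ChatStrategy.single_turn
    # A finetune's tuned strategy wins, unless it is single turn (the user's
    # explicit COT prompt takes priority over that).
    if tuned_chat_strategy and tuned_chat_strategy != ChatStrategy.single_turn:
        return tuned_chat_strategy
    # Reasoning models emit thinking natively in one call; other models need a
    # separate thinking call followed by a final-answer call.
    if reasoning_capable:
        return ChatStrategy.single_turn_r1_thinking
    return ChatStrategy.two_message_cot


assert pick_strategy(None, ChatStrategy.two_message_cot, True) == ChatStrategy.single_turn
assert pick_strategy("Think step by step.", None, False) == ChatStrategy.two_message_cot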
@@ -242,6 +289,7 @@ class BaseAdapter(metaclass=ABCMeta):
             ),
             intermediate_outputs=run_output.intermediate_outputs,
             tags=self.base_adapter_config.default_tags or [],
+            usage=usage,
         )

         return new_task_run
@@ -254,5 +302,22 @@ class BaseAdapter(metaclass=ABCMeta):
         props["model_name"] = self.run_config.model_name
         props["model_provider"] = self.run_config.model_provider_name
         props["prompt_id"] = self.run_config.prompt_id
+        props["structured_output_mode"] = self.run_config.structured_output_mode
+        props["temperature"] = self.run_config.temperature
+        props["top_p"] = self.run_config.top_p

         return props
+
+    def update_run_config_unknown_structured_output_mode(self) -> None:
+        structured_output_mode = self.run_config.structured_output_mode
+
+        # Old datamodels didn't save the structured output mode. Some clients (tests, end users) might not set it.
+        # Look up our recommended mode from ml_model_list if we have one
+        if structured_output_mode == StructuredOutputMode.unknown:
+            new_run_config = self.run_config.model_copy(deep=True)
+            structured_output_mode = default_structured_output_mode_for_model_provider(
+                self.run_config.model_name,
+                self.run_config.model_provider_name,
+            )
+            new_run_config.structured_output_mode = structured_output_mode
+            self.run_config = new_run_config
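
update_run_config_unknown_structured_output_mode backfills a default when older run configs carry StructuredOutputMode.unknown. A simplified sketch of that fallback, assuming a lookup table in place of default_structured_output_mode_for_model_provider (the table and default below are invented for illustration):

from enum import Enum


class StructuredOutputMode(str, Enum):
    unknown = "unknown"
    json_schema = "json_schema"
    function_calling = "function_calling"


# Invented stand-in for the ml_model_list lookup.
RECOMMENDED_MODES = {
    ("gpt_4o", "openai"): StructuredOutputMode.json_schema,
}


def resolve_mode(
    mode: StructuredOutputMode, model_name: str, provider_name: str
) -> StructuredOutputMode:
    if mode != StructuredOutputMode.unknown:
        return mode
    # Old datamodels didn't persist a mode; fall back to the recommended one.
    return RECOMMENDED_MODES.get(
        (model_name, provider_name), StructuredOutputMode.function_calling
    )


assert resolve_mode(StructuredOutputMode.unknown, "gpt_4o", "openai") is StructuredOutputMode.json_schema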
kiln_ai/adapters/model_adapters/litellm_adapter.py
@@ -1,7 +1,9 @@
+import logging
 from typing import Any, Dict

 import litellm
 from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
+from litellm.types.utils import Usage as LiteLlmUsage

 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.ml_model_list import (
@@ -10,25 +12,23 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import (
-    COT_FINAL_ANSWER_PROMPT,
     AdapterConfig,
     BaseAdapter,
     RunOutput,
+    Usage,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import (
-    LiteLlmConfig,
-)
-from kiln_ai.datamodel import PromptGenerators, PromptId
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
+from kiln_ai.datamodel.task import run_config_from_run_config_properties
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

+logger = logging.getLogger(__name__)
+

 class LiteLlmAdapter(BaseAdapter):
     def __init__(
         self,
         config: LiteLlmConfig,
         kiln_task: datamodel.Task,
-        prompt_id: PromptId | None = None,
         base_adapter_config: AdapterConfig | None = None,
     ):
         self.config = config
@@ -37,11 +37,10 @@ class LiteLlmAdapter(BaseAdapter):
         self._headers = config.default_headers
         self._litellm_model_id: str | None = None

-        run_config = RunConfig(
+        # Create a RunConfig, adding the task to the RunConfigProperties
+        run_config = run_config_from_run_config_properties(
             task=kiln_task,
-            model_name=config.model_name,
-            model_provider_name=config.provider_name,
-            prompt_id=prompt_id or PromptGenerators.SIMPLE,
+            run_config_properties=config.run_config_properties,
         )

         super().__init__(
@@ -49,84 +48,74 @@ class LiteLlmAdapter(BaseAdapter):
             config=base_adapter_config,
         )

-    async def _run(self, input: Dict | str) -> RunOutput:
+    async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
         provider = self.model_provider()
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")

-        intermediate_outputs: dict[str, str] = {}
-        prompt = self.build_prompt()
-        user_msg = self.prompt_builder.build_user_message(input)
-        messages = [
-            {"role": "system", "content": prompt},
-            {"role": "user", "content": user_msg},
-        ]
-
-        run_strategy, cot_prompt = self.run_strategy()
-
-        if run_strategy == "cot_as_message":
-            # Used for reasoning-capable models that can output thinking and structured format
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_as_message strategy")
-            messages.append({"role": "system", "content": cot_prompt})
-        elif run_strategy == "cot_two_call":
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_two_call strategy")
-            messages.append({"role": "system", "content": cot_prompt})
-
-            # First call for chain of thought
-            # No response format as this request is for "thinking" in plain text
-            # No logprobs as only needed for final answer
+        chat_formatter = self.build_chat_formatter(input)
+
+        prior_output = None
+        prior_message = None
+        response = None
+        turns = 0
+        while True:
+            turns += 1
+            if turns > 10:
+                raise RuntimeError(
+                    "Too many turns. Stopping iteration to avoid using too many tokens."
+                )
+
+            turn = chat_formatter.next_turn(prior_output)
+            if turn is None:
+                break
+
+            skip_response_format = not turn.final_call
+            all_messages = chat_formatter.message_dicts()
             completion_kwargs = await self.build_completion_kwargs(
-                provider, messages, None, skip_response_format=True
+                provider,
+                all_messages,
+                self.base_adapter_config.top_logprobs if turn.final_call else None,
+                skip_response_format,
             )
-            cot_response = await litellm.acompletion(**completion_kwargs)
+            response = await litellm.acompletion(**completion_kwargs)
             if (
-                not isinstance(cot_response, ModelResponse)
-                or not cot_response.choices
-                or len(cot_response.choices) == 0
-                or not isinstance(cot_response.choices[0], Choices)
+                not isinstance(response, ModelResponse)
+                or not response.choices
+                or len(response.choices) == 0
+                or not isinstance(response.choices[0], Choices)
             ):
                 raise RuntimeError(
-                    f"Expected ModelResponse with Choices, got {type(cot_response)}."
+                    f"Expected ModelResponse with Choices, got {type(response)}."
                 )
-            cot_content = cot_response.choices[0].message.content
-            if cot_content is not None:
-                intermediate_outputs["chain_of_thought"] = cot_content
-
-            messages.extend(
-                [
-                    {"role": "assistant", "content": cot_content or ""},
-                    {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
-                ]
-            )
+            prior_message = response.choices[0].message
+            prior_output = prior_message.content

-        # Make the API call using litellm
-        completion_kwargs = await self.build_completion_kwargs(
-            provider, messages, self.base_adapter_config.top_logprobs
-        )
-        response = await litellm.acompletion(**completion_kwargs)
+            # Fallback: Use args of first tool call to task_response if it exists
+            if (
+                not prior_output
+                and hasattr(prior_message, "tool_calls")
+                and prior_message.tool_calls
+            ):
+                tool_call = next(
+                    (
+                        tool_call
+                        for tool_call in prior_message.tool_calls
+                        if tool_call.function.name == "task_response"
+                    ),
+                    None,
+                )
+                if tool_call:
+                    prior_output = tool_call.function.arguments

-        if not isinstance(response, ModelResponse):
-            raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")
+            if not prior_output:
+                raise RuntimeError("No output returned from model")

-        # Maybe remove this? There is no error attribute on the response object.
-        # # Keeping in typesafe way as we added it for a reason, but should investigate what that was and if it still applies.
-        if hasattr(response, "error") and response.__getattribute__("error"):
-            raise RuntimeError(
-                f"LLM API returned an error: {response.__getattribute__('error')}"
-            )
+        if response is None or prior_message is None:
+            raise RuntimeError("No response returned from model")

-        if (
-            not response.choices
-            or len(response.choices) == 0
-            or not isinstance(response.choices[0], Choices)
-        ):
-            raise RuntimeError(
-                "No message content returned in the response from LLM API"
-            )
+        intermediate_outputs = chat_formatter.intermediate_outputs()

-        message = response.choices[0].message
         logprobs = (
             response.choices[0].logprobs
             if hasattr(response.choices[0], "logprobs")
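
The rewritten _run drives every strategy through one loop: next_turn(prior_output) either returns another turn to send or None to stop, and the formatter accumulates the message history. A toy two-call formatter showing the contract (the Turn dataclass and TwoCallFormatter are illustrative, not the package's ChatFormatter; the final-answer prompt string is the one removed from base_adapter.py above):

from dataclasses import dataclass, field


@dataclass
class Turn:
    final_call: bool


@dataclass
class TwoCallFormatter:
    """Toy formatter: one thinking turn, then one final-answer turn."""

    system_message: str
    user_input: str
    thinking_instructions: str
    messages: list[dict] = field(default_factory=list)
    turns_issued: int = 0

    def next_turn(self, prior_output: str | None) -> Turn | None:
        # The model's last output becomes an assistant message in the history.
        if prior_output is not None:
            self.messages.append({"role": "assistant", "content": prior_output})
        if self.turns_issued == 0:
            self.messages += [
                {"role": "system", "content": self.system_message},
                {"role": "user", "content": self.user_input},
                {"role": "system", "content": self.thinking_instructions},
            ]
            self.turns_issued += 1
            return Turn(final_call=False)  # thinking call: no response_format, no logprobs
        if self.turns_issued == 1:
            self.messages.append(
                {"role": "user", "content": "Considering the above, return a final result."}
            )
            self.turns_issued += 1
            return Turn(final_call=True)  # final call: response_format + logprobs apply
        return None  # done; the adapter's while-loop breaks here

    def message_dicts(self) -> list[dict]:
        return self.messages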
@@ -139,28 +128,16 @@ class LiteLlmAdapter(BaseAdapter):
             raise RuntimeError("Logprobs were required, but no logprobs were returned.")

         # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
-        if hasattr(message, "reasoning_content") and message.reasoning_content:
-            intermediate_outputs["reasoning"] = message.reasoning_content
-
-        # the string content of the response
-        response_content = message.content
-
-        # Fallback: Use args of first tool call to task_response if it exists
         if (
-            not response_content
-            and hasattr(message, "tool_calls")
-            and message.tool_calls
+            prior_message is not None
+            and hasattr(prior_message, "reasoning_content")
+            and prior_message.reasoning_content
+            and len(prior_message.reasoning_content.strip()) > 0
         ):
-            tool_call = next(
-                (
-                    tool_call
-                    for tool_call in message.tool_calls
-                    if tool_call.function.name == "task_response"
-                ),
-                None,
-            )
-            if tool_call:
-                response_content = tool_call.function.arguments
+            intermediate_outputs["reasoning"] = prior_message.reasoning_content.strip()
+
+        # the string content of the response
+        response_content = prior_output

         if not isinstance(response_content, str):
             raise RuntimeError(f"response is not a string: {response_content}")
@@ -169,7 +146,7 @@ class LiteLlmAdapter(BaseAdapter):
             output=response_content,
             intermediate_outputs=intermediate_outputs,
             output_logprobs=logprobs,
-        )
+        ), self.usage_from_response(response)

     def adapter_name(self) -> str:
         return "kiln_openai_compatible_adapter"
@@ -179,8 +156,9 @@ class LiteLlmAdapter(BaseAdapter):
         if not self.has_structured_output():
             return {}

-        provider = self.model_provider()
-        match provider.structured_output_mode:
+        structured_output_mode = self.run_config.structured_output_mode
+
+        match structured_output_mode:
             case StructuredOutputMode.json_mode:
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.json_schema:
@@ -199,16 +177,20 @@ class LiteLlmAdapter(BaseAdapter):
                 # We set response_format to json_object and also set json instructions in the prompt
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.default:
-                if provider.name == ModelProviderName.ollama:
+                provider_name = self.run_config.model_provider_name
+                if provider_name == ModelProviderName.ollama:
                     # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                     return self.json_schema_response_format()
                 else:
                     # Default to function calling -- it's older than the other modes. Higher compatibility.
                     # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
-                    strict = provider.name == ModelProviderName.openai
+                    strict = provider_name == ModelProviderName.openai
                     return self.tool_call_params(strict=strict)
+            case StructuredOutputMode.unknown:
+                # See above, but this case should never happen.
+                raise ValueError("Structured output mode is unknown.")
             case _:
-                raise_exhaustive_enum_error(provider.structured_output_mode)
+                raise_exhaustive_enum_error(structured_output_mode)

     def json_schema_response_format(self) -> dict[str, Any]:
         output_schema = self.task().output_schema()
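
For reference, the kwargs these cases produce follow litellm's OpenAI-compatible shapes. A sketch of the two main ones; the schema is a placeholder, and the tool_call_params internals are inferred from the task_response tool-call fallback earlier in this diff, so treat the exact fields as assumptions:

output_schema = {"type": "object", "properties": {"answer": {"type": "string"}}}

# StructuredOutputMode.json_schema -> response_format with an explicit schema
json_schema_kwargs = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "task_response", "schema": output_schema},
    }
}

# StructuredOutputMode.default (non-Ollama) -> function calling, forced to the
# task_response tool; strict only when the provider is OpenAI
tool_call_kwargs = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": output_schema,
                "strict": False,
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}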
@@ -380,6 +362,13 @@ class LiteLlmAdapter(BaseAdapter):
             "messages": messages,
             "api_base": self._api_base,
             "headers": self._headers,
+            "temperature": self.run_config.temperature,
+            "top_p": self.run_config.top_p,
+            # This drops params that are not supported by the model. Only openai params like top_p, temperature -- not litellm params like model, etc.
+            # Not all models and providers support all openai params (for example, o3 doesn't support top_p)
+            # Better to ignore them than to fail the model call.
+            # https://docs.litellm.ai/docs/completion/input
+            "drop_params": True,
             **extra_body,
             **self._additional_body_options,
         }
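
drop_params=True asks litellm to silently discard OpenAI-style params a given model rejects, rather than erroring. A minimal sketch (model name and parameter values are arbitrary examples):

import litellm

# o3-style reasoning models reject top_p; with drop_params=True litellm
# removes it from the request instead of raising an error.
response = litellm.completion(
    model="openai/o3-mini",
    messages=[{"role": "user", "content": "Say hi."}],
    temperature=1.0,
    top_p=0.9,  # unsupported for this model; silently dropped
    drop_params=True,
)
print(response.choices[0].message.content)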
@@ -394,3 +383,30 @@ class LiteLlmAdapter(BaseAdapter):
             completion_kwargs["top_logprobs"] = top_logprobs

         return completion_kwargs
+
+    def usage_from_response(self, response: ModelResponse) -> Usage | None:
+        litellm_usage = response.get("usage", None)
+        cost = response._hidden_params.get("response_cost", None)
+        if not litellm_usage and not cost:
+            return None
+
+        usage = Usage()
+
+        if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
+            usage.input_tokens = litellm_usage.get("prompt_tokens", None)
+            usage.output_tokens = litellm_usage.get("completion_tokens", None)
+            usage.total_tokens = litellm_usage.get("total_tokens", None)
+        else:
+            logger.warning(
+                f"Unexpected usage format from litellm: {litellm_usage}. Expected Usage object, got {type(litellm_usage)}"
+            )
+
+        if isinstance(cost, float):
+            usage.cost = cost
+        elif cost is not None:
+            # None is allowed, but no other types are expected
+            logger.warning(
+                f"Unexpected cost format from litellm: {cost}. Expected float, got {type(cost)}"
+            )
+
+        return usage
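
A usage sketch of the fields usage_from_response reads; constructing a ModelResponse by hand is illustrative only (real values come back from litellm.acompletion, and response_cost is populated by litellm, not by us):

from litellm.types.utils import ModelResponse
from litellm.types.utils import Usage as LiteLlmUsage

# Build a response shaped like litellm's, then read the same fields
# usage_from_response() consumes.
response = ModelResponse(
    usage=LiteLlmUsage(prompt_tokens=42, completion_tokens=7, total_tokens=49)
)
litellm_usage = response.get("usage", None)
print(litellm_usage.get("prompt_tokens", None))      # 42 -> Usage.input_tokens
print(litellm_usage.get("completion_tokens", None))  # 7  -> Usage.output_tokens
print(litellm_usage.get("total_tokens", None))       # 49 -> Usage.total_tokens
cost = response._hidden_params.get("response_cost", None)  # float or None
print(cost)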
kiln_ai/adapters/model_adapters/litellm_config.py
@@ -1,10 +1,11 @@
 from dataclasses import dataclass, field

+from kiln_ai.datamodel.task import RunConfigProperties
+

 @dataclass
 class LiteLlmConfig:
-    model_name: str
-    provider_name: str
+    run_config_properties: RunConfigProperties
     # If set, overrides the provider-name based URL from litellm
     base_url: str | None = None
     # Headers to send with every request
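
The net effect on callers: model/provider selection moves off LiteLlmConfig's own fields (and the adapter's prompt_id argument) into a single RunConfigProperties. A before/after sketch; field names are taken from the run-config properties read elsewhere in this diff, but the exact values and defaults are assumptions:

from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel.task import RunConfigProperties

# 0.15.0 (old): model/provider on the config, prompt_id passed to the adapter
# config = LiteLlmConfig(model_name="gpt_4o", provider_name="openai")

# 0.17.0 (new): one RunConfigProperties carries the full run configuration
config = LiteLlmConfig(
    run_config_properties=RunConfigProperties(
        model_name="gpt_4o",
        model_provider_name="openai",
        prompt_id="simple_prompt_builder",
        structured_output_mode="json_schema",
        temperature=1.0,
        top_p=1.0,
    ),
)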