kiln-ai 0.18.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (89)
  1. kiln_ai/adapters/__init__.py +2 -2
  2. kiln_ai/adapters/adapter_registry.py +46 -0
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/data_gen/data_gen_task.py +2 -2
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +7 -3
  7. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  8. kiln_ai/adapters/eval/base_eval.py +2 -2
  9. kiln_ai/adapters/eval/eval_runner.py +3 -1
  10. kiln_ai/adapters/eval/g_eval.py +2 -2
  11. kiln_ai/adapters/eval/test_base_eval.py +1 -1
  12. kiln_ai/adapters/eval/test_eval_runner.py +6 -12
  13. kiln_ai/adapters/eval/test_g_eval.py +3 -4
  14. kiln_ai/adapters/eval/test_g_eval_data.py +1 -1
  15. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  16. kiln_ai/adapters/fine_tune/base_finetune.py +1 -0
  17. kiln_ai/adapters/fine_tune/fireworks_finetune.py +32 -20
  18. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +30 -21
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  21. kiln_ai/adapters/ml_model_list.py +1009 -111
  22. kiln_ai/adapters/model_adapters/base_adapter.py +62 -28
  23. kiln_ai/adapters/model_adapters/litellm_adapter.py +397 -80
  24. kiln_ai/adapters/model_adapters/test_base_adapter.py +194 -18
  25. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +428 -4
  26. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  27. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  28. kiln_ai/adapters/model_adapters/test_structured_output.py +120 -14
  29. kiln_ai/adapters/parsers/__init__.py +1 -1
  30. kiln_ai/adapters/parsers/test_r1_parser.py +1 -1
  31. kiln_ai/adapters/provider_tools.py +35 -20
  32. kiln_ai/adapters/remote_config.py +57 -10
  33. kiln_ai/adapters/repair/repair_task.py +1 -1
  34. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  35. kiln_ai/adapters/run_output.py +3 -0
  36. kiln_ai/adapters/test_adapter_registry.py +109 -2
  37. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  38. kiln_ai/adapters/test_ml_model_list.py +51 -1
  39. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  40. kiln_ai/adapters/test_provider_tools.py +73 -12
  41. kiln_ai/adapters/test_remote_config.py +470 -16
  42. kiln_ai/datamodel/__init__.py +23 -21
  43. kiln_ai/datamodel/basemodel.py +54 -28
  44. kiln_ai/datamodel/datamodel_enums.py +3 -0
  45. kiln_ai/datamodel/dataset_split.py +5 -3
  46. kiln_ai/datamodel/eval.py +4 -4
  47. kiln_ai/datamodel/external_tool_server.py +298 -0
  48. kiln_ai/datamodel/finetune.py +2 -2
  49. kiln_ai/datamodel/json_schema.py +25 -10
  50. kiln_ai/datamodel/project.py +11 -4
  51. kiln_ai/datamodel/prompt.py +2 -2
  52. kiln_ai/datamodel/prompt_id.py +4 -4
  53. kiln_ai/datamodel/registry.py +0 -15
  54. kiln_ai/datamodel/run_config.py +62 -0
  55. kiln_ai/datamodel/task.py +8 -83
  56. kiln_ai/datamodel/task_output.py +7 -2
  57. kiln_ai/datamodel/task_run.py +41 -0
  58. kiln_ai/datamodel/test_basemodel.py +213 -21
  59. kiln_ai/datamodel/test_eval_model.py +6 -6
  60. kiln_ai/datamodel/test_example_models.py +175 -0
  61. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  62. kiln_ai/datamodel/test_model_perf.py +1 -1
  63. kiln_ai/datamodel/test_prompt_id.py +5 -1
  64. kiln_ai/datamodel/test_registry.py +8 -3
  65. kiln_ai/datamodel/test_task.py +20 -47
  66. kiln_ai/datamodel/test_tool_id.py +239 -0
  67. kiln_ai/datamodel/tool_id.py +83 -0
  68. kiln_ai/tools/__init__.py +8 -0
  69. kiln_ai/tools/base_tool.py +82 -0
  70. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  71. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  72. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  73. kiln_ai/tools/mcp_server_tool.py +95 -0
  74. kiln_ai/tools/mcp_session_manager.py +243 -0
  75. kiln_ai/tools/test_base_tools.py +199 -0
  76. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  77. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  78. kiln_ai/tools/test_tool_registry.py +473 -0
  79. kiln_ai/tools/tool_registry.py +64 -0
  80. kiln_ai/utils/config.py +32 -0
  81. kiln_ai/utils/open_ai_types.py +94 -0
  82. kiln_ai/utils/project_utils.py +17 -0
  83. kiln_ai/utils/test_config.py +138 -1
  84. kiln_ai/utils/test_open_ai_types.py +131 -0
  85. {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/METADATA +37 -6
  86. kiln_ai-0.20.1.dist-info/RECORD +138 -0
  87. kiln_ai-0.18.0.dist-info/RECORD +0 -115
  88. {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/WHEEL +0 -0
  89. {kiln_ai-0.18.0.dist-info → kiln_ai-0.20.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/litellm_adapter.py
@@ -1,9 +1,26 @@
+import copy
+import json
 import logging
-from typing import Any, Dict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, TypeAlias, Union
 
 import litellm
-from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
+from litellm.types.utils import (
+    ChatCompletionMessageToolCall,
+    ChoiceLogprobs,
+    Choices,
+    ModelResponse,
+)
+from litellm.types.utils import (
+    Message as LiteLLMMessage,
+)
 from litellm.types.utils import Usage as LiteLlmUsage
+from openai.types.chat import (
+    ChatCompletionToolMessageParam,
+)
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    ChatCompletionMessageToolCallParam,
+)
 
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.ml_model_list import (
@@ -18,11 +35,32 @@ from kiln_ai.adapters.model_adapters.base_adapter import (
     Usage,
 )
 from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
-from kiln_ai.datamodel.task import run_config_from_run_config_properties
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
+from kiln_ai.tools.base_tool import KilnToolInterface
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.open_ai_types import (
+    ChatCompletionAssistantMessageParamWrapper,
+    ChatCompletionMessageParam,
+)
+
+MAX_CALLS_PER_TURN = 10
+MAX_TOOL_CALLS_PER_TURN = 30
 
 logger = logging.getLogger(__name__)
 
+ChatCompletionMessageIncludingLiteLLM: TypeAlias = Union[
+    ChatCompletionMessageParam, LiteLLMMessage
+]
+
+
+@dataclass
+class ModelTurnResult:
+    assistant_message: str
+    all_messages: list[ChatCompletionMessageIncludingLiteLLM]
+    model_response: ModelResponse | None
+    model_choice: Choices | None
+    usage: Usage
+
 
 class LiteLlmAdapter(BaseAdapter):
     def __init__(
@@ -36,117 +74,226 @@ class LiteLlmAdapter(BaseAdapter):
         self._api_base = config.base_url
         self._headers = config.default_headers
         self._litellm_model_id: str | None = None
+        self._cached_available_tools: list[KilnToolInterface] | None = None
 
-        # Create a RunConfig, adding the task to the RunConfigProperties
-        run_config = run_config_from_run_config_properties(
+        super().__init__(
             task=kiln_task,
-            run_config_properties=config.run_config_properties,
+            run_config=config.run_config_properties,
+            config=base_adapter_config,
         )
 
-        super().__init__(
-            run_config=run_config,
-            config=base_adapter_config,
+    async def _run_model_turn(
+        self,
+        provider: KilnModelProvider,
+        prior_messages: list[ChatCompletionMessageIncludingLiteLLM],
+        top_logprobs: int | None,
+        skip_response_format: bool,
+    ) -> ModelTurnResult:
+        """
+        Call the model for a single top level turn: from user message to agent message.
+
+        It may make handle iterations of tool calls between the user/agent message if needed.
+        """
+
+        usage = Usage()
+        messages = list(prior_messages)
+        tool_calls_count = 0
+
+        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
+            # Build completion kwargs for tool calls
+            completion_kwargs = await self.build_completion_kwargs(
+                provider,
+                # Pass a copy, as acompletion mutates objects and breaks types.
+                copy.deepcopy(messages),
+                top_logprobs,
+                skip_response_format,
+            )
+
+            # Make the completion call
+            model_response, response_choice = await self.acompletion_checking_response(
+                **completion_kwargs
+            )
+
+            # count the usage
+            usage += self.usage_from_response(model_response)
+
+            # Extract content and tool calls
+            if not hasattr(response_choice, "message"):
+                raise ValueError("Response choice has no message")
+            content = response_choice.message.content
+            tool_calls = response_choice.message.tool_calls
+            if not content and not tool_calls:
+                raise ValueError(
+                    "Model returned an assistant message, but no content or tool calls. This is not supported."
+                )
+
+            # Add message to messages, so it can be used in the next turn
+            messages.append(response_choice.message)
+
+            # Process tool calls if any
+            if tool_calls and len(tool_calls) > 0:
+                (
+                    assistant_message_from_toolcall,
+                    tool_call_messages,
+                ) = await self.process_tool_calls(tool_calls)
+
+                # Add tool call results to messages
+                messages.extend(tool_call_messages)
+
+                # If task_response tool was called, we're done
+                if assistant_message_from_toolcall is not None:
+                    return ModelTurnResult(
+                        assistant_message=assistant_message_from_toolcall,
+                        all_messages=messages,
+                        model_response=model_response,
+                        model_choice=response_choice,
+                        usage=usage,
+                    )
+
+                # If there were tool calls, increment counter and continue
+                if tool_call_messages:
+                    tool_calls_count += 1
+                    continue
+
+            # If no tool calls, return the content as final output
+            if content:
+                return ModelTurnResult(
+                    assistant_message=content,
+                    all_messages=messages,
+                    model_response=model_response,
+                    model_choice=response_choice,
+                    usage=usage,
+                )
+
+            # If we get here with no content and no tool calls, break
+            raise RuntimeError(
+                "Model returned neither content nor tool calls. It must return at least one of these."
+            )
+
+        raise RuntimeError(
+            f"Too many tool calls ({tool_calls_count}). Stopping iteration to avoid using too many tokens."
         )
 
     async def _run(self, input: Dict | str) -> tuple[RunOutput, Usage | None]:
+        usage = Usage()
+
         provider = self.model_provider()
         if not provider.model_id:
             raise ValueError("Model ID is required for OpenAI compatible models")
 
         chat_formatter = self.build_chat_formatter(input)
+        messages: list[ChatCompletionMessageIncludingLiteLLM] = []
 
-        prior_output = None
-        prior_message = None
-        response = None
+        prior_output: str | None = None
+        final_choice: Choices | None = None
         turns = 0
+
         while True:
             turns += 1
-            if turns > 10:
+            if turns > MAX_CALLS_PER_TURN:
                 raise RuntimeError(
-                    "Too many turns. Stopping iteration to avoid using too many tokens."
+                    f"Too many turns ({turns}). Stopping iteration to avoid using too many tokens."
                 )
 
             turn = chat_formatter.next_turn(prior_output)
             if turn is None:
+                # No next turn, we're done
                 break
 
+            # Add messages from the turn to chat history
+            for message in turn.messages:
+                if message.content is None:
+                    raise ValueError("Empty message content isn't allowed")
+                # pyright incorrectly warns about this, but it's valid so we can ignore. It can't handle the multi-value role.
+                messages.append({"role": message.role, "content": message.content})  # type: ignore
+
             skip_response_format = not turn.final_call
-            all_messages = chat_formatter.message_dicts()
-            completion_kwargs = await self.build_completion_kwargs(
+            turn_result = await self._run_model_turn(
                 provider,
-                all_messages,
+                messages,
                 self.base_adapter_config.top_logprobs if turn.final_call else None,
                 skip_response_format,
             )
-            response = await litellm.acompletion(**completion_kwargs)
-            if (
-                not isinstance(response, ModelResponse)
-                or not response.choices
-                or len(response.choices) == 0
-                or not isinstance(response.choices[0], Choices)
-            ):
-                raise RuntimeError(
-                    f"Expected ModelResponse with Choices, got {type(response)}."
-                )
-            prior_message = response.choices[0].message
-            prior_output = prior_message.content
-
-            # Fallback: Use args of first tool call to task_response if it exists
-            if (
-                not prior_output
-                and hasattr(prior_message, "tool_calls")
-                and prior_message.tool_calls
-            ):
-                tool_call = next(
-                    (
-                        tool_call
-                        for tool_call in prior_message.tool_calls
-                        if tool_call.function.name == "task_response"
-                    ),
-                    None,
-                )
-                if tool_call:
-                    prior_output = tool_call.function.arguments
+
+            usage += turn_result.usage
+
+            prior_output = turn_result.assistant_message
+            messages = turn_result.all_messages
+            final_choice = turn_result.model_choice
 
         if not prior_output:
-            raise RuntimeError("No output returned from model")
+            raise RuntimeError("No assistant message/output returned from model")
 
-        if response is None or prior_message is None:
-            raise RuntimeError("No response returned from model")
+        logprobs = self._extract_and_validate_logprobs(final_choice)
 
+        # Save COT/reasoning if it exists. May be a message, or may be parsed by LiteLLM (or openrouter, or anyone upstream)
         intermediate_outputs = chat_formatter.intermediate_outputs()
+        self._extract_reasoning_to_intermediate_outputs(
+            final_choice, intermediate_outputs
+        )
+
+        if not isinstance(prior_output, str):
+            raise RuntimeError(f"assistant message is not a string: {prior_output}")
 
-        logprobs = (
-            response.choices[0].logprobs
-            if hasattr(response.choices[0], "logprobs")
-            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
-            else None
+        trace = self.all_messages_to_trace(messages)
+        output = RunOutput(
+            output=prior_output,
+            intermediate_outputs=intermediate_outputs,
+            output_logprobs=logprobs,
+            trace=trace,
         )
 
-        # Check logprobs worked, if requested
-        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
-            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
+        return output, usage
 
-        # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
+    def _extract_and_validate_logprobs(
+        self, final_choice: Choices | None
+    ) -> ChoiceLogprobs | None:
+        """
+        Extract logprobs from the final choice and validate they exist if required.
+        """
+        logprobs = None
         if (
-            prior_message is not None
-            and hasattr(prior_message, "reasoning_content")
-            and prior_message.reasoning_content
-            and len(prior_message.reasoning_content.strip()) > 0
+            final_choice is not None
+            and hasattr(final_choice, "logprobs")
+            and isinstance(final_choice.logprobs, ChoiceLogprobs)
         ):
-            intermediate_outputs["reasoning"] = prior_message.reasoning_content.strip()
+            logprobs = final_choice.logprobs
 
-        # the string content of the response
-        response_content = prior_output
+        # Check logprobs worked, if required
+        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
+            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
 
-        if not isinstance(response_content, str):
-            raise RuntimeError(f"response is not a string: {response_content}")
+        return logprobs
 
-        return RunOutput(
-            output=response_content,
-            intermediate_outputs=intermediate_outputs,
-            output_logprobs=logprobs,
-        ), self.usage_from_response(response)
+    def _extract_reasoning_to_intermediate_outputs(
+        self, final_choice: Choices | None, intermediate_outputs: Dict[str, Any]
+    ) -> None:
+        """Extract reasoning content from model choice and add to intermediate outputs if present."""
+        if (
+            final_choice is not None
+            and hasattr(final_choice, "message")
+            and hasattr(final_choice.message, "reasoning_content")
+        ):
+            reasoning_content = final_choice.message.reasoning_content
+            if reasoning_content is not None:
+                stripped_reasoning_content = reasoning_content.strip()
+                if len(stripped_reasoning_content) > 0:
+                    intermediate_outputs["reasoning"] = stripped_reasoning_content
+
+    async def acompletion_checking_response(
+        self, **kwargs
+    ) -> Tuple[ModelResponse, Choices]:
+        response = await litellm.acompletion(**kwargs)
+        if (
+            not isinstance(response, ModelResponse)
+            or not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                f"Expected ModelResponse with Choices, got {type(response)}."
+            )
+        return response, response.choices[0]
 
     def adapter_name(self) -> str:
         return "kiln_openai_compatible_adapter"
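The rewritten _run/_run_model_turn boils down to a per-turn loop: call the model, execute any requested tools, append the tool results, and repeat until the model returns a final assistant message (or the task_response tool) or a cap is hit. The sketch below is a minimal, self-contained illustration of that control flow only; FakeReply, call_model and run_tool are hypothetical stand-ins, not Kiln or LiteLLM APIs.

    # Minimal sketch of the turn loop added in this release (illustrative only).
    from dataclasses import dataclass, field
    from typing import Callable

    MAX_TOOL_CALLS_PER_TURN = 30  # same cap the adapter now enforces


    @dataclass
    class FakeReply:
        content: str | None = None
        tool_calls: list[dict] = field(default_factory=list)


    def run_turn(
        call_model: Callable[[list[dict]], FakeReply],
        run_tool: Callable[[dict], str],
        messages: list[dict],
    ) -> str:
        tool_calls_count = 0
        while tool_calls_count < MAX_TOOL_CALLS_PER_TURN:
            reply = call_model(messages)
            if reply.tool_calls:
                for call in reply.tool_calls:
                    # "task_response" is the sentinel tool used for structured output
                    if call["name"] == "task_response":
                        return call["arguments"]
                    messages.append(
                        {"role": "tool", "tool_call_id": call["id"], "content": run_tool(call)}
                    )
                tool_calls_count += 1
                continue
            if reply.content:
                return reply.content
            raise RuntimeError("Model returned neither content nor tool calls.")
        raise RuntimeError("Too many tool calls in one turn.")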
@@ -181,6 +328,9 @@ class LiteLlmAdapter(BaseAdapter):
                 if provider_name == ModelProviderName.ollama:
                     # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
                     return self.json_schema_response_format()
+                elif provider_name == ModelProviderName.docker_model_runner:
+                    # Docker Model Runner uses OpenAI-compatible API with JSON schema support
+                    return self.json_schema_response_format()
                 else:
                     # Default to function calling -- it's older than the other modes. Higher compatibility.
                     # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
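Both Ollama and the new docker_model_runner provider now take the JSON-schema structured-output path. For reference, the response_format payload that path produces follows the standard OpenAI json_schema shape, roughly as sketched below; the schema and name values are illustrative assumptions, and the adapter builds the real schema from the task's output schema.

    # Rough, illustrative shape of a json_schema response_format payload.
    output_schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }

    response_format_options = {
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "task_response",  # assumption: any identifier works here
                "schema": output_schema,
            },
        }
    }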
@@ -193,7 +343,7 @@
                 raise_exhaustive_enum_error(structured_output_mode)
 
     def json_schema_response_format(self) -> dict[str, Any]:
-        output_schema = self.task().output_schema()
+        output_schema = self.task.output_schema()
         return {
             "response_format": {
                 "type": "json_schema",
@@ -206,7 +356,7 @@
 
     def tool_call_params(self, strict: bool) -> dict[str, Any]:
         # Add additional_properties: false to the schema (OpenAI requires this for some models)
-        output_schema = self.task().output_schema()
+        output_schema = self.task.output_schema()
         if not isinstance(output_schema, dict):
             raise ValueError(
                 "Invalid output schema for this task. Can not use tool calls."
@@ -235,7 +385,7 @@
         }
 
     def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
-        # TODO P1: Don't love having this logic here. But it's a usability improvement
+        # Don't love having this logic here. But it's worth the usability improvement
         # so better to keep it than exclude it. Should figure out how I want to isolate
         # this sort of logic so it's config driven and can be overridden
 
@@ -251,6 +401,11 @@
                 "exclude": False,
             }
 
+        if provider.gemini_reasoning_enabled:
+            extra_body["reasoning"] = {
+                "enabled": True,
+            }
+
         if provider.name == ModelProviderName.openrouter:
             # Ask OpenRouter to include usage in the response (cost)
             extra_body["usage"] = {"include": True}
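build_extra_body now also turns on provider-side reasoning when the model entry sets gemini_reasoning_enabled, alongside the existing OpenRouter usage accounting. A sketch of the assembled extra_body for such a provider (the flag values are assumptions for illustration):

    # Sketch of an assembled extra_body dict (values are illustrative assumptions).
    extra_body: dict = {}

    gemini_reasoning_enabled = True  # assumed provider flag
    is_openrouter = True             # assumed provider

    if gemini_reasoning_enabled:
        extra_body["reasoning"] = {"enabled": True}

    if is_openrouter:
        # Ask OpenRouter to include usage (cost) in the response
        extra_body["usage"] = {"include": True}

    print(extra_body)
    # {'reasoning': {'enabled': True}, 'usage': {'include': True}}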
@@ -280,6 +435,10 @@
             # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
             provider_options["require_parameters"] = False
 
+        # Siliconflow uses a bool flag for thinking, for some models
+        if provider.siliconflow_enable_thinking is not None:
+            extra_body["enable_thinking"] = provider.siliconflow_enable_thinking
+
         if len(provider_options) > 0:
             extra_body["provider"] = provider_options
 
@@ -311,6 +470,10 @@
                 # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
                 # This is because we're setting detailed features like response_format=json_schema and want lower level control.
                 is_custom = True
+            case ModelProviderName.docker_model_runner:
+                # Docker Model Runner uses OpenAI-compatible API, similar to Ollama
+                # We want direct control over the requests for features like response_format=json_schema
+                is_custom = True
             case ModelProviderName.gemini_api:
                 litellm_provider_name = "gemini"
             case ModelProviderName.fireworks_ai:
@@ -325,6 +488,10 @@
                 litellm_provider_name = "vertex_ai"
             case ModelProviderName.together_ai:
                 litellm_provider_name = "together_ai"
+            case ModelProviderName.cerebras:
+                litellm_provider_name = "cerebras"
+            case ModelProviderName.siliconflow_cn:
+                is_custom = True
             case ModelProviderName.openai_compatible:
                 is_custom = True
             case ModelProviderName.kiln_custom_registry:
@@ -354,7 +521,7 @@
     async def build_completion_kwargs(
         self,
         provider: KilnModelProvider,
-        messages: list[dict[str, Any]],
+        messages: list[ChatCompletionMessageIncludingLiteLLM],
         top_logprobs: int | None,
         skip_response_format: bool = False,
     ) -> dict[str, Any]:
@@ -377,9 +544,23 @@
             **self._additional_body_options,
         }
 
+        tool_calls = await self.litellm_tools()
+        has_tools = len(tool_calls) > 0
+        if has_tools:
+            completion_kwargs["tools"] = tool_calls
+            completion_kwargs["tool_choice"] = "auto"
+
         if not skip_response_format:
             # Response format: json_schema, json_instructions, json_mode, function_calling, etc
             response_format_options = await self.response_format_options()
+
+            # Check for a conflict between tools and response format using tools
+            # We could reconsider this. Model could be able to choose between a final answer or a tool call on any turn. However, good models for tools tend to also support json_schea, so do we need to support both? If we do, merge them, and consider auto vs forced when merging (only forced for final, auto for merged).
+            if has_tools and "tools" in response_format_options:
+                raise ValueError(
+                    "Function calling/tools can't be used as the JSON response format if you're also using tools. Please select a different structured output mode."
+                )
+
             completion_kwargs.update(response_format_options)
 
         if top_logprobs is not None:
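With tools attached, the kwargs handed to litellm.acompletion now include the standard OpenAI-style tools array and tool_choice="auto" next to model, messages, and any response_format options. An illustrative shape is sketched below; the model id, message, and tool definition are made up, not values from the package.

    # Illustrative shape of completion kwargs once tools are attached.
    completion_kwargs = {
        "model": "openai/gpt-4o-mini",  # assumed litellm model id
        "messages": [{"role": "user", "content": "Add 2 and 3"}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "add",
                    "description": "Add two numbers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "number"},
                            "b": {"type": "number"},
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ],
        "tool_choice": "auto",
    }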
@@ -388,7 +569,7 @@
         return completion_kwargs
 
 
-    def usage_from_response(self, response: ModelResponse) -> Usage | None:
+    def usage_from_response(self, response: ModelResponse) -> Usage:
         litellm_usage = response.get("usage", None)
 
         # LiteLLM isn't consistent in how it returns the cost.
@@ -396,11 +577,11 @@
         if cost is None and litellm_usage:
             cost = litellm_usage.get("cost", None)
 
-        if not litellm_usage and not cost:
-            return None
-
         usage = Usage()
 
+        if not litellm_usage and not cost:
+            return usage
+
         if litellm_usage and isinstance(litellm_usage, LiteLlmUsage):
             usage.input_tokens = litellm_usage.get("prompt_tokens", None)
             usage.output_tokens = litellm_usage.get("completion_tokens", None)
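usage_from_response now always returns a Usage, possibly empty, instead of Usage | None; that is what lets the new turn loop accumulate with usage += self.usage_from_response(...) with no None checks. A minimal sketch of that accumulation pattern with a stand-in Usage type (the real kiln_ai Usage may differ):

    # Stand-in Usage to show why "return empty, never None" composes with "+=".
    from dataclasses import dataclass


    @dataclass
    class UsageSketch:
        input_tokens: int = 0
        output_tokens: int = 0
        cost: float = 0.0

        def __add__(self, other: "UsageSketch") -> "UsageSketch":
            return UsageSketch(
                self.input_tokens + other.input_tokens,
                self.output_tokens + other.output_tokens,
                self.cost + other.cost,
            )


    total = UsageSketch()
    # An empty UsageSketch() stands for "provider reported no usage".
    for per_call in [UsageSketch(120, 30, 0.0004), UsageSketch()]:
        total += per_call  # never None, so no special-casing per call
    print(total)  # UsageSketch(input_tokens=120, output_tokens=30, cost=0.0004)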
@@ -419,3 +600,139 @@
             )
 
         return usage
+
+    async def cached_available_tools(self) -> list[KilnToolInterface]:
+        if self._cached_available_tools is None:
+            self._cached_available_tools = await self.available_tools()
+        return self._cached_available_tools
+
+    async def litellm_tools(self) -> list[Dict]:
+        available_tools = await self.cached_available_tools()
+
+        # LiteLLM takes the standard OpenAI-compatible tool call format
+        return [await tool.toolcall_definition() for tool in available_tools]
+
+    async def process_tool_calls(
+        self, tool_calls: list[ChatCompletionMessageToolCall] | None
+    ) -> tuple[str | None, list[ChatCompletionToolMessageParam]]:
+        if tool_calls is None:
+            return None, []
+
+        assistant_output_from_toolcall: str | None = None
+        tool_call_response_messages: list[ChatCompletionToolMessageParam] = []
+
+        for tool_call in tool_calls:
+            # Kiln "task_response" tool is used for returning structured output via tool calls.
+            # Load the output from the tool call. Also
+            if tool_call.function.name == "task_response":
+                assistant_output_from_toolcall = tool_call.function.arguments
+                continue
+
+            # Process normal tool calls (not the "task_response" tool)
+            tool_name = tool_call.function.name
+            tool = None
+            for tool_option in await self.cached_available_tools():
+                if await tool_option.name() == tool_name:
+                    tool = tool_option
+                    break
+            if not tool:
+                raise RuntimeError(
+                    f"A tool named '{tool_name}' was invoked by a model, but was not available."
+                )
+
+            # Parse the arguments and validate them against the tool's schema
+            try:
+                parsed_args = json.loads(tool_call.function.arguments)
+            except json.JSONDecodeError:
+                raise RuntimeError(
+                    f"Failed to parse arguments for tool '{tool_name}' (should be JSON): {tool_call.function.arguments}"
+                )
+            try:
+                tool_call_definition = await tool.toolcall_definition()
+                json_schema = json.dumps(tool_call_definition["function"]["parameters"])
+                validate_schema_with_value_error(parsed_args, json_schema)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to validate arguments for tool '{tool_name}'. The arguments didn't match the tool's schema. The arguments were: {parsed_args}\n The error was: {e}"
+                ) from e
+
+            result = await tool.run(**parsed_args)
+
+            tool_call_response_messages.append(
+                ChatCompletionToolMessageParam(
+                    role="tool",
+                    tool_call_id=tool_call.id,
+                    content=result,
+                )
+            )
+
+        if (
+            assistant_output_from_toolcall is not None
+            and len(tool_call_response_messages) > 0
+        ):
+            raise RuntimeError(
+                "Model asked for impossible combination: task_response tool call and other tool calls were both provided in the same turn. This is not supported as it means the model asked us to both return task_response results (ending the turn) and run new tools calls to send back to the model. If the model makes this mistake often, try a difference structured data model like JSON schema, where this is impossible."
+            )
+
+        return assistant_output_from_toolcall, tool_call_response_messages
+
+    def litellm_message_to_trace_message(
+        self, raw_message: LiteLLMMessage
+    ) -> ChatCompletionAssistantMessageParamWrapper:
+        """
+        Convert a LiteLLM Message object to an OpenAI compatible message, our ChatCompletionAssistantMessageParamWrapper
+        """
+        message: ChatCompletionAssistantMessageParamWrapper = {
+            "role": "assistant",
+        }
+        if raw_message.role != "assistant":
+            raise ValueError(
+                "Model returned a message with a role other than assistant. This is not supported."
+            )
+
+        if hasattr(raw_message, "content"):
+            message["content"] = raw_message.content
+        if hasattr(raw_message, "reasoning_content"):
+            message["reasoning_content"] = raw_message.reasoning_content
+        if hasattr(raw_message, "tool_calls"):
+            # Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallParam
+            open_ai_tool_calls: List[ChatCompletionMessageToolCallParam] = []
+            for litellm_tool_call in raw_message.tool_calls or []:
+                # Optional in the SDK for streaming responses, but should never be None at this point.
+                if litellm_tool_call.function.name is None:
+                    raise ValueError(
+                        "The model requested a tool call, without providing a function name (required)."
+                    )
+                open_ai_tool_calls.append(
+                    ChatCompletionMessageToolCallParam(
+                        id=litellm_tool_call.id,
+                        type="function",
+                        function={
+                            "name": litellm_tool_call.function.name,
+                            "arguments": litellm_tool_call.function.arguments,
+                        },
+                    )
+                )
+            if len(open_ai_tool_calls) > 0:
+                message["tool_calls"] = open_ai_tool_calls
+
+        if not message.get("content") and not message.get("tool_calls"):
+            raise ValueError(
+                "Model returned an assistant message, but no content or tool calls. This is not supported."
+            )
+
+        return message
+
+    def all_messages_to_trace(
+        self, messages: list[ChatCompletionMessageIncludingLiteLLM]
+    ) -> list[ChatCompletionMessageParam]:
+        """
+        Internally we allow LiteLLM Message objects, but for trace we need OpenAI compatible types. Replace LiteLLM Message objects with OpenAI compatible types.
+        """
+        trace: list[ChatCompletionMessageParam] = []
+        for message in messages:
+            if isinstance(message, LiteLLMMessage):
+                trace.append(self.litellm_message_to_trace_message(message))
+            else:
+                trace.append(message)
+        return trace
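process_tool_calls only relies on three awaitable members of each available tool: name(), toolcall_definition() in the OpenAI function-call format, and run(**parsed_args), whose return value becomes the string content of the role="tool" message. A hypothetical minimal tool satisfying that contract is sketched below; the real KilnToolInterface in kiln_ai/tools/base_tool.py may require more than this.

    # Hypothetical minimal tool matching the calls made in process_tool_calls above.
    class AddTool:
        async def name(self) -> str:
            return "add"

        async def toolcall_definition(self) -> dict:
            return {
                "type": "function",
                "function": {
                    "name": "add",
                    "description": "Add two numbers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "number"},
                            "b": {"type": "number"},
                        },
                        "required": ["a", "b"],
                    },
                },
            }

        async def run(self, a: float, b: float) -> str:
            # Tool results are sent back to the model as message content, so return a string.
            return str(a + b)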