kiln-ai 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai has been flagged as potentially problematic; see the package registry's advisory page for details.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +153 -28
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +2 -1
- kiln_ai/adapters/eval/test_base_eval.py +1 -0
- kiln_ai/adapters/eval/test_eval_runner.py +1 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +638 -155
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
- kiln_ai/adapters/ollama_tools.py +3 -2
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +1 -1
- kiln_ai/adapters/test_adapter_registry.py +17 -20
- kiln_ai/adapters/test_generate_docs.py +2 -2
- kiln_ai/adapters/test_prompt_adaptors.py +30 -19
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/json_schema.py +1 -1
- kiln_ai/datamodel/task_output.py +13 -6
- kiln_ai/datamodel/test_basemodel.py +9 -0
- kiln_ai/datamodel/test_datasource.py +19 -0
- kiln_ai/utils/config.py +37 -0
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +51 -7
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/RECORD +42 -39
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -7,12 +7,10 @@ Model adapters are used to call AI models, like Ollama, OpenAI, etc.
|
|
|
7
7
|
|
|
8
8
|
# Public submodules of kiln_ai.adapters.model_adapters (0.13.0): the
# LangChain and OpenAI-specific adapters were replaced by the LiteLLM adapter.
from . import (
    base_adapter,
    litellm_adapter,
)

__all__ = [
    "base_adapter",
    "litellm_adapter",
]
|
|
@@ -4,6 +4,7 @@ from dataclasses import dataclass
|
|
|
4
4
|
from typing import Dict, Literal, Tuple
|
|
5
5
|
|
|
6
6
|
from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
|
|
7
|
+
from kiln_ai.adapters.parsers.json_parser import parse_json_string
|
|
7
8
|
from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
|
|
8
9
|
from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
|
|
9
10
|
from kiln_ai.adapters.provider_tools import kiln_model_provider_from
|
|
@@ -85,17 +86,6 @@ class BaseAdapter(metaclass=ABCMeta):
|
|
|
85
86
|
)
|
|
86
87
|
return self._model_provider
|
|
87
88
|
|
|
88
|
-
async def invoke_returning_raw(
|
|
89
|
-
self,
|
|
90
|
-
input: Dict | str,
|
|
91
|
-
input_source: DataSource | None = None,
|
|
92
|
-
) -> Dict | str:
|
|
93
|
-
result = await self.invoke(input, input_source)
|
|
94
|
-
if self.task().output_json_schema is None:
|
|
95
|
-
return result.output.output
|
|
96
|
-
else:
|
|
97
|
-
return json.loads(result.output.output)
|
|
98
|
-
|
|
99
89
|
async def invoke(
|
|
100
90
|
self,
|
|
101
91
|
input: Dict | str,
|
|
@@ -127,6 +117,10 @@ class BaseAdapter(metaclass=ABCMeta):
|
|
|
127
117
|
|
|
128
118
|
# validate output
|
|
129
119
|
if self.output_schema is not None:
|
|
120
|
+
# Parse json to dict if we have structured output
|
|
121
|
+
if isinstance(parsed_output.output, str):
|
|
122
|
+
parsed_output.output = parse_json_string(parsed_output.output)
|
|
123
|
+
|
|
130
124
|
if not isinstance(parsed_output.output, dict):
|
|
131
125
|
raise RuntimeError(
|
|
132
126
|
f"structured response is not a dict: {parsed_output.output}"
|
|
@@ -138,6 +132,15 @@ class BaseAdapter(metaclass=ABCMeta):
|
|
|
138
132
|
f"response is not a string for non-structured task: {parsed_output.output}"
|
|
139
133
|
)
|
|
140
134
|
|
|
135
|
+
# Validate reasoning content is present (if reasoning)
|
|
136
|
+
if provider.reasoning_capable and (
|
|
137
|
+
not parsed_output.intermediate_outputs
|
|
138
|
+
or "reasoning" not in parsed_output.intermediate_outputs
|
|
139
|
+
):
|
|
140
|
+
raise RuntimeError(
|
|
141
|
+
"Reasoning is required for this model, but no reasoning was returned."
|
|
142
|
+
)
|
|
143
|
+
|
|
141
144
|
# Generate the run and output
|
|
142
145
|
run = self.generate_run(input, input_source, parsed_output)
|
|
143
146
|
|
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
from typing import Any, Dict
|
|
2
|
+
|
|
3
|
+
import litellm
|
|
4
|
+
from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
|
|
5
|
+
|
|
6
|
+
import kiln_ai.datamodel as datamodel
|
|
7
|
+
from kiln_ai.adapters.ml_model_list import (
|
|
8
|
+
KilnModelProvider,
|
|
9
|
+
ModelProviderName,
|
|
10
|
+
StructuredOutputMode,
|
|
11
|
+
)
|
|
12
|
+
from kiln_ai.adapters.model_adapters.base_adapter import (
|
|
13
|
+
COT_FINAL_ANSWER_PROMPT,
|
|
14
|
+
AdapterConfig,
|
|
15
|
+
BaseAdapter,
|
|
16
|
+
RunOutput,
|
|
17
|
+
)
|
|
18
|
+
from kiln_ai.adapters.model_adapters.litellm_config import (
|
|
19
|
+
LiteLlmConfig,
|
|
20
|
+
)
|
|
21
|
+
from kiln_ai.datamodel import PromptGenerators, PromptId
|
|
22
|
+
from kiln_ai.datamodel.task import RunConfig
|
|
23
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LiteLlmAdapter(BaseAdapter):
|
|
27
|
+
def __init__(
|
|
28
|
+
self,
|
|
29
|
+
config: LiteLlmConfig,
|
|
30
|
+
kiln_task: datamodel.Task,
|
|
31
|
+
prompt_id: PromptId | None = None,
|
|
32
|
+
base_adapter_config: AdapterConfig | None = None,
|
|
33
|
+
):
|
|
34
|
+
self.config = config
|
|
35
|
+
self._additional_body_options = config.additional_body_options
|
|
36
|
+
self._api_base = config.base_url
|
|
37
|
+
self._headers = config.default_headers
|
|
38
|
+
self._litellm_model_id: str | None = None
|
|
39
|
+
|
|
40
|
+
run_config = RunConfig(
|
|
41
|
+
task=kiln_task,
|
|
42
|
+
model_name=config.model_name,
|
|
43
|
+
model_provider_name=config.provider_name,
|
|
44
|
+
prompt_id=prompt_id or PromptGenerators.SIMPLE,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
super().__init__(
|
|
48
|
+
run_config=run_config,
|
|
49
|
+
config=base_adapter_config,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
async def _run(self, input: Dict | str) -> RunOutput:
|
|
53
|
+
provider = self.model_provider()
|
|
54
|
+
if not provider.model_id:
|
|
55
|
+
raise ValueError("Model ID is required for OpenAI compatible models")
|
|
56
|
+
|
|
57
|
+
intermediate_outputs: dict[str, str] = {}
|
|
58
|
+
prompt = self.build_prompt()
|
|
59
|
+
user_msg = self.prompt_builder.build_user_message(input)
|
|
60
|
+
messages = [
|
|
61
|
+
{"role": "system", "content": prompt},
|
|
62
|
+
{"role": "user", "content": user_msg},
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
run_strategy, cot_prompt = self.run_strategy()
|
|
66
|
+
|
|
67
|
+
if run_strategy == "cot_as_message":
|
|
68
|
+
if not cot_prompt:
|
|
69
|
+
raise ValueError("cot_prompt is required for cot_as_message strategy")
|
|
70
|
+
messages.append({"role": "system", "content": cot_prompt})
|
|
71
|
+
elif run_strategy == "cot_two_call":
|
|
72
|
+
if not cot_prompt:
|
|
73
|
+
raise ValueError("cot_prompt is required for cot_two_call strategy")
|
|
74
|
+
messages.append({"role": "system", "content": cot_prompt})
|
|
75
|
+
|
|
76
|
+
# First call for chain of thought - No logprobs as only needed for final answer
|
|
77
|
+
completion_kwargs = await self.build_completion_kwargs(
|
|
78
|
+
provider, messages, None
|
|
79
|
+
)
|
|
80
|
+
cot_response = await litellm.acompletion(**completion_kwargs)
|
|
81
|
+
if (
|
|
82
|
+
not isinstance(cot_response, ModelResponse)
|
|
83
|
+
or not cot_response.choices
|
|
84
|
+
or len(cot_response.choices) == 0
|
|
85
|
+
or not isinstance(cot_response.choices[0], Choices)
|
|
86
|
+
):
|
|
87
|
+
raise RuntimeError(
|
|
88
|
+
f"Expected ModelResponse with Choices, got {type(cot_response)}."
|
|
89
|
+
)
|
|
90
|
+
cot_content = cot_response.choices[0].message.content
|
|
91
|
+
if cot_content is not None:
|
|
92
|
+
intermediate_outputs["chain_of_thought"] = cot_content
|
|
93
|
+
|
|
94
|
+
messages.extend(
|
|
95
|
+
[
|
|
96
|
+
{"role": "assistant", "content": cot_content or ""},
|
|
97
|
+
{"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
|
|
98
|
+
]
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
# Make the API call using litellm
|
|
102
|
+
completion_kwargs = await self.build_completion_kwargs(
|
|
103
|
+
provider, messages, self.base_adapter_config.top_logprobs
|
|
104
|
+
)
|
|
105
|
+
response = await litellm.acompletion(**completion_kwargs)
|
|
106
|
+
|
|
107
|
+
if not isinstance(response, ModelResponse):
|
|
108
|
+
raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")
|
|
109
|
+
|
|
110
|
+
# Maybe remove this? There is no error attribute on the response object.
|
|
111
|
+
# # Keeping in typesafe way as we added it for a reason, but should investigate what that was and if it still applies.
|
|
112
|
+
if hasattr(response, "error") and response.__getattribute__("error"):
|
|
113
|
+
raise RuntimeError(
|
|
114
|
+
f"LLM API returned an error: {response.__getattribute__('error')}"
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
if (
|
|
118
|
+
not response.choices
|
|
119
|
+
or len(response.choices) == 0
|
|
120
|
+
or not isinstance(response.choices[0], Choices)
|
|
121
|
+
):
|
|
122
|
+
raise RuntimeError(
|
|
123
|
+
"No message content returned in the response from LLM API"
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
message = response.choices[0].message
|
|
127
|
+
logprobs = (
|
|
128
|
+
response.choices[0].logprobs
|
|
129
|
+
if hasattr(response.choices[0], "logprobs")
|
|
130
|
+
and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
|
|
131
|
+
else None
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
# Check logprobs worked, if requested
|
|
135
|
+
if self.base_adapter_config.top_logprobs is not None and logprobs is None:
|
|
136
|
+
raise RuntimeError("Logprobs were required, but no logprobs were returned.")
|
|
137
|
+
|
|
138
|
+
# Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
|
|
139
|
+
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
|
140
|
+
intermediate_outputs["reasoning"] = message.reasoning_content
|
|
141
|
+
|
|
142
|
+
# the string content of the response
|
|
143
|
+
response_content = message.content
|
|
144
|
+
|
|
145
|
+
# Fallback: Use args of first tool call to task_response if it exists
|
|
146
|
+
if (
|
|
147
|
+
not response_content
|
|
148
|
+
and hasattr(message, "tool_calls")
|
|
149
|
+
and message.tool_calls
|
|
150
|
+
):
|
|
151
|
+
tool_call = next(
|
|
152
|
+
(
|
|
153
|
+
tool_call
|
|
154
|
+
for tool_call in message.tool_calls
|
|
155
|
+
if tool_call.function.name == "task_response"
|
|
156
|
+
),
|
|
157
|
+
None,
|
|
158
|
+
)
|
|
159
|
+
if tool_call:
|
|
160
|
+
response_content = tool_call.function.arguments
|
|
161
|
+
|
|
162
|
+
if not isinstance(response_content, str):
|
|
163
|
+
raise RuntimeError(f"response is not a string: {response_content}")
|
|
164
|
+
|
|
165
|
+
return RunOutput(
|
|
166
|
+
output=response_content,
|
|
167
|
+
intermediate_outputs=intermediate_outputs,
|
|
168
|
+
output_logprobs=logprobs,
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
def adapter_name(self) -> str:
|
|
172
|
+
return "kiln_openai_compatible_adapter"
|
|
173
|
+
|
|
174
|
+
async def response_format_options(self) -> dict[str, Any]:
|
|
175
|
+
# Unstructured if task isn't structured
|
|
176
|
+
if not self.has_structured_output():
|
|
177
|
+
return {}
|
|
178
|
+
|
|
179
|
+
provider = self.model_provider()
|
|
180
|
+
match provider.structured_output_mode:
|
|
181
|
+
case StructuredOutputMode.json_mode:
|
|
182
|
+
return {"response_format": {"type": "json_object"}}
|
|
183
|
+
case StructuredOutputMode.json_schema:
|
|
184
|
+
return self.json_schema_response_format()
|
|
185
|
+
case StructuredOutputMode.function_calling_weak:
|
|
186
|
+
return self.tool_call_params(strict=False)
|
|
187
|
+
case StructuredOutputMode.function_calling:
|
|
188
|
+
return self.tool_call_params(strict=True)
|
|
189
|
+
case StructuredOutputMode.json_instructions:
|
|
190
|
+
# JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
|
|
191
|
+
return {}
|
|
192
|
+
case StructuredOutputMode.json_custom_instructions:
|
|
193
|
+
# JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
|
|
194
|
+
return {}
|
|
195
|
+
case StructuredOutputMode.json_instruction_and_object:
|
|
196
|
+
# We set response_format to json_object and also set json instructions in the prompt
|
|
197
|
+
return {"response_format": {"type": "json_object"}}
|
|
198
|
+
case StructuredOutputMode.default:
|
|
199
|
+
if provider.name == ModelProviderName.ollama:
|
|
200
|
+
# Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
|
|
201
|
+
return self.json_schema_response_format()
|
|
202
|
+
else:
|
|
203
|
+
# Default to function calling -- it's older than the other modes. Higher compatibility.
|
|
204
|
+
# Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
|
|
205
|
+
strict = provider.name == ModelProviderName.openai
|
|
206
|
+
return self.tool_call_params(strict=strict)
|
|
207
|
+
case _:
|
|
208
|
+
raise_exhaustive_enum_error(provider.structured_output_mode)
|
|
209
|
+
|
|
210
|
+
def json_schema_response_format(self) -> dict[str, Any]:
|
|
211
|
+
output_schema = self.task().output_schema()
|
|
212
|
+
return {
|
|
213
|
+
"response_format": {
|
|
214
|
+
"type": "json_schema",
|
|
215
|
+
"json_schema": {
|
|
216
|
+
"name": "task_response",
|
|
217
|
+
"schema": output_schema,
|
|
218
|
+
},
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
def tool_call_params(self, strict: bool) -> dict[str, Any]:
|
|
223
|
+
# Add additional_properties: false to the schema (OpenAI requires this for some models)
|
|
224
|
+
output_schema = self.task().output_schema()
|
|
225
|
+
if not isinstance(output_schema, dict):
|
|
226
|
+
raise ValueError(
|
|
227
|
+
"Invalid output schema for this task. Can not use tool calls."
|
|
228
|
+
)
|
|
229
|
+
output_schema["additionalProperties"] = False
|
|
230
|
+
|
|
231
|
+
function_params = {
|
|
232
|
+
"name": "task_response",
|
|
233
|
+
"parameters": output_schema,
|
|
234
|
+
}
|
|
235
|
+
# This should be on, but we allow setting function_calling_weak for APIs that don't support it.
|
|
236
|
+
if strict:
|
|
237
|
+
function_params["strict"] = True
|
|
238
|
+
|
|
239
|
+
return {
|
|
240
|
+
"tools": [
|
|
241
|
+
{
|
|
242
|
+
"type": "function",
|
|
243
|
+
"function": function_params,
|
|
244
|
+
}
|
|
245
|
+
],
|
|
246
|
+
"tool_choice": {
|
|
247
|
+
"type": "function",
|
|
248
|
+
"function": {"name": "task_response"},
|
|
249
|
+
},
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
|
|
253
|
+
# TODO P1: Don't love having this logic here. But it's a usability improvement
|
|
254
|
+
# so better to keep it than exclude it. Should figure out how I want to isolate
|
|
255
|
+
# this sort of logic so it's config driven and can be overridden
|
|
256
|
+
|
|
257
|
+
extra_body = {}
|
|
258
|
+
provider_options = {}
|
|
259
|
+
|
|
260
|
+
if provider.thinking_level is not None:
|
|
261
|
+
extra_body["reasoning_effort"] = provider.thinking_level
|
|
262
|
+
|
|
263
|
+
if provider.require_openrouter_reasoning:
|
|
264
|
+
# https://openrouter.ai/docs/use-cases/reasoning-tokens
|
|
265
|
+
extra_body["reasoning"] = {
|
|
266
|
+
"exclude": False,
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
if provider.anthropic_extended_thinking:
|
|
270
|
+
extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
|
|
271
|
+
|
|
272
|
+
if provider.r1_openrouter_options:
|
|
273
|
+
# Require providers that support the reasoning parameter
|
|
274
|
+
provider_options["require_parameters"] = True
|
|
275
|
+
# Prefer R1 providers with reasonable perf/quants
|
|
276
|
+
provider_options["order"] = ["Fireworks", "Together"]
|
|
277
|
+
# R1 providers with unreasonable quants
|
|
278
|
+
provider_options["ignore"] = ["DeepInfra"]
|
|
279
|
+
|
|
280
|
+
# Only set of this request is to get logprobs.
|
|
281
|
+
if (
|
|
282
|
+
provider.logprobs_openrouter_options
|
|
283
|
+
and self.base_adapter_config.top_logprobs is not None
|
|
284
|
+
):
|
|
285
|
+
# Don't let OpenRouter choose a provider that doesn't support logprobs.
|
|
286
|
+
provider_options["require_parameters"] = True
|
|
287
|
+
# DeepInfra silently fails to return logprobs consistently.
|
|
288
|
+
provider_options["ignore"] = ["DeepInfra"]
|
|
289
|
+
|
|
290
|
+
if provider.openrouter_skip_required_parameters:
|
|
291
|
+
# Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
|
|
292
|
+
provider_options["require_parameters"] = False
|
|
293
|
+
|
|
294
|
+
if len(provider_options) > 0:
|
|
295
|
+
extra_body["provider"] = provider_options
|
|
296
|
+
|
|
297
|
+
return extra_body
|
|
298
|
+
|
|
299
|
+
def litellm_model_id(self) -> str:
|
|
300
|
+
# The model ID is an interesting combination of format and url endpoint.
|
|
301
|
+
# It specifics the provider URL/host, but this is overridden if you manually set an api url
|
|
302
|
+
|
|
303
|
+
if self._litellm_model_id:
|
|
304
|
+
return self._litellm_model_id
|
|
305
|
+
|
|
306
|
+
provider = self.model_provider()
|
|
307
|
+
if not provider.model_id:
|
|
308
|
+
raise ValueError("Model ID is required for OpenAI compatible models")
|
|
309
|
+
|
|
310
|
+
litellm_provider_name: str | None = None
|
|
311
|
+
is_custom = False
|
|
312
|
+
match provider.name:
|
|
313
|
+
case ModelProviderName.openrouter:
|
|
314
|
+
litellm_provider_name = "openrouter"
|
|
315
|
+
case ModelProviderName.openai:
|
|
316
|
+
litellm_provider_name = "openai"
|
|
317
|
+
case ModelProviderName.groq:
|
|
318
|
+
litellm_provider_name = "groq"
|
|
319
|
+
case ModelProviderName.anthropic:
|
|
320
|
+
litellm_provider_name = "anthropic"
|
|
321
|
+
case ModelProviderName.ollama:
|
|
322
|
+
# We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
|
|
323
|
+
# This is because we're setting detailed features like response_format=json_schema and want lower level control.
|
|
324
|
+
is_custom = True
|
|
325
|
+
case ModelProviderName.gemini_api:
|
|
326
|
+
litellm_provider_name = "gemini"
|
|
327
|
+
case ModelProviderName.fireworks_ai:
|
|
328
|
+
litellm_provider_name = "fireworks_ai"
|
|
329
|
+
case ModelProviderName.amazon_bedrock:
|
|
330
|
+
litellm_provider_name = "bedrock"
|
|
331
|
+
case ModelProviderName.azure_openai:
|
|
332
|
+
litellm_provider_name = "azure"
|
|
333
|
+
case ModelProviderName.huggingface:
|
|
334
|
+
litellm_provider_name = "huggingface"
|
|
335
|
+
case ModelProviderName.vertex:
|
|
336
|
+
litellm_provider_name = "vertex_ai"
|
|
337
|
+
case ModelProviderName.together_ai:
|
|
338
|
+
litellm_provider_name = "together_ai"
|
|
339
|
+
case ModelProviderName.openai_compatible:
|
|
340
|
+
is_custom = True
|
|
341
|
+
case ModelProviderName.kiln_custom_registry:
|
|
342
|
+
is_custom = True
|
|
343
|
+
case ModelProviderName.kiln_fine_tune:
|
|
344
|
+
is_custom = True
|
|
345
|
+
case _:
|
|
346
|
+
raise_exhaustive_enum_error(provider.name)
|
|
347
|
+
|
|
348
|
+
if is_custom:
|
|
349
|
+
if self._api_base is None:
|
|
350
|
+
raise ValueError(
|
|
351
|
+
"Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
|
|
352
|
+
)
|
|
353
|
+
# Use openai as it's only used for format, not url
|
|
354
|
+
litellm_provider_name = "openai"
|
|
355
|
+
|
|
356
|
+
# Sholdn't be possible but keep type checker happy
|
|
357
|
+
if litellm_provider_name is None:
|
|
358
|
+
raise ValueError(
|
|
359
|
+
f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
|
|
360
|
+
)
|
|
361
|
+
|
|
362
|
+
self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
|
|
363
|
+
return self._litellm_model_id
|
|
364
|
+
|
|
365
|
+
async def build_completion_kwargs(
|
|
366
|
+
self,
|
|
367
|
+
provider: KilnModelProvider,
|
|
368
|
+
messages: list[dict[str, Any]],
|
|
369
|
+
top_logprobs: int | None,
|
|
370
|
+
) -> dict[str, Any]:
|
|
371
|
+
extra_body = self.build_extra_body(provider)
|
|
372
|
+
|
|
373
|
+
# Merge all parameters into a single kwargs dict for litellm
|
|
374
|
+
completion_kwargs = {
|
|
375
|
+
"model": self.litellm_model_id(),
|
|
376
|
+
"messages": messages,
|
|
377
|
+
"api_base": self._api_base,
|
|
378
|
+
"headers": self._headers,
|
|
379
|
+
**extra_body,
|
|
380
|
+
**self._additional_body_options,
|
|
381
|
+
}
|
|
382
|
+
|
|
383
|
+
# Response format: json_schema, json_instructions, json_mode, function_calling, etc
|
|
384
|
+
response_format_options = await self.response_format_options()
|
|
385
|
+
completion_kwargs.update(response_format_options)
|
|
386
|
+
|
|
387
|
+
if top_logprobs is not None:
|
|
388
|
+
completion_kwargs["logprobs"] = True
|
|
389
|
+
completion_kwargs["top_logprobs"] = top_logprobs
|
|
390
|
+
|
|
391
|
+
return completion_kwargs
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@dataclass
|
|
5
|
+
class LiteLlmConfig:
|
|
6
|
+
model_name: str
|
|
7
|
+
provider_name: str
|
|
8
|
+
# If set, over rides the provider-name based URL from litellm
|
|
9
|
+
base_url: str | None = None
|
|
10
|
+
# Headers to send with every request
|
|
11
|
+
default_headers: dict[str, str] | None = None
|
|
12
|
+
# Extra body to send with every request
|
|
13
|
+
additional_body_options: dict[str, str] = field(default_factory=dict)
|