kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +163 -39
- kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +270 -0
- kiln_ai/adapters/eval/g_eval.py +368 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +325 -0
- kiln_ai/adapters/eval/test_eval_runner.py +641 -0
- kiln_ai/adapters/eval/test_g_eval.py +498 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +758 -163
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
- kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
- kiln_ai/adapters/ollama_tools.py +3 -3
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/prompt_builders.py +80 -42
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/repair_task.py +9 -21
- kiln_ai/adapters/repair/test_repair_task.py +6 -6
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +26 -29
- kiln_ai/adapters/test_generate_docs.py +4 -4
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +47 -33
- kiln_ai/adapters/test_prompt_builders.py +91 -31
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/__init__.py +50 -952
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +60 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +7 -1
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +328 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +19 -11
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +32 -8
- kiln_ai/datamodel/test_datasource.py +22 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +9 -13
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_models.py +2 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +43 -1
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
- kiln_ai-0.13.0.dist-info/RECORD +103 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
- kiln_ai-0.11.1.dist-info/RECORD +0 -76
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/__init__.py

@@ -7,12 +7,10 @@ Model adapters are used to call AI models, like Ollama, OpenAI, etc.
 
 from . import (
     base_adapter,
-
-    openai_model_adapter,
+    litellm_adapter,
 )
 
 __all__ = [
     "base_adapter",
-    "
-    "openai_model_adapter",
+    "litellm_adapter",
 ]
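At the import level, this amounts to swapping the exported adapter module. A minimal before/after sketch (module names taken from the diff above; the commented-out line is the removed export):

    # 0.11.1 exported the OpenAI-specific adapter module:
    # from kiln_ai.adapters.model_adapters import openai_model_adapter

    # 0.13.0 exports the LiteLLM-backed adapter module instead:
    from kiln_ai.adapters.model_adapters import litellm_adapter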
kiln_ai/adapters/model_adapters/base_adapter.py

@@ -4,8 +4,9 @@ from dataclasses import dataclass
 from typing import Dict, Literal, Tuple
 
 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
+from kiln_ai.adapters.parsers.json_parser import parse_json_string
 from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
-from kiln_ai.adapters.prompt_builders import
+from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
 from kiln_ai.adapters.run_output import RunOutput
 from kiln_ai.datamodel import (
@@ -16,16 +17,21 @@ from kiln_ai.datamodel import (
     TaskRun,
 )
 from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config
 
 
 @dataclass
-class
-
-
-
-
-
+class AdapterConfig:
+    """
+    An adapter config is config options that do NOT impact the output of the model.
+
+    For example: if it's saved, of if we request additional data like logprobs.
+    """
+
+    allow_saving: bool = True
+    top_logprobs: int | None = None
+    default_tags: list[str] | None = None
 
 
 COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."
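The new AdapterConfig dataclass separates bookkeeping options from anything that changes model output. A minimal usage sketch, assuming only the fields shown in the hunk above (the values are illustrative):

    from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig

    # Don't persist runs, request 5 top logprobs per token, and tag anything that is saved.
    config = AdapterConfig(
        allow_saving=False,
        top_logprobs=5,
        default_tags=["eval-run"],
    )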
@@ -47,54 +53,52 @@ class BaseAdapter(metaclass=ABCMeta):
 
     def __init__(
         self,
-
-
-        model_provider_name: str,
-        prompt_builder: BasePromptBuilder | None = None,
-        tags: list[str] | None = None,
+        run_config: RunConfig,
+        config: AdapterConfig | None = None,
     ):
-        self.
-        self.
-
-
-        self.default_tags = tags
-        self.model_name = model_name
-        self.model_provider_name = model_provider_name
+        self.run_config = run_config
+        self.prompt_builder = prompt_builder_from_id(
+            run_config.prompt_id, run_config.task
+        )
         self._model_provider: KilnModelProvider | None = None
 
+        self.output_schema = self.task().output_json_schema
+        self.input_schema = self.task().input_json_schema
+        self.base_adapter_config = config or AdapterConfig()
+
+    def task(self) -> Task:
+        return self.run_config.task
+
     def model_provider(self) -> KilnModelProvider:
         """
         Lazy load the model provider for this adapter.
         """
         if self._model_provider is not None:
             return self._model_provider
-        if not self.model_name or not self.model_provider_name:
+        if not self.run_config.model_name or not self.run_config.model_provider_name:
             raise ValueError("model_name and model_provider_name must be provided")
         self._model_provider = kiln_model_provider_from(
-            self.model_name, self.model_provider_name
+            self.run_config.model_name, self.run_config.model_provider_name
         )
         if not self._model_provider:
             raise ValueError(
-                f"model_provider_name {self.model_provider_name} not found for model {self.model_name}"
+                f"model_provider_name {self.run_config.model_provider_name} not found for model {self.run_config.model_name}"
             )
         return self._model_provider
 
-    async def
+    async def invoke(
         self,
         input: Dict | str,
         input_source: DataSource | None = None,
-    ) ->
-
-
-            return result.output.output
-        else:
-            return json.loads(result.output.output)
+    ) -> TaskRun:
+        run_output, _ = await self.invoke_returning_run_output(input, input_source)
+        return run_output
 
-    async def
+    async def invoke_returning_run_output(
         self,
         input: Dict | str,
         input_source: DataSource | None = None,
-    ) -> TaskRun:
+    ) -> Tuple[TaskRun, RunOutput]:
         # validate input
         if self.input_schema is not None:
             if not isinstance(input, dict):
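BaseAdapter is now constructed from a single RunConfig rather than separate model/prompt arguments, and invoke() is a thin wrapper over invoke_returning_run_output(). A sketch of the new call pattern, assuming an existing Task instance (my_task) and a concrete BaseAdapter subclass (MyAdapter is hypothetical; the test file further down uses a similar MockAdapter):

    from kiln_ai.datamodel.task import RunConfig

    run_config = RunConfig(
        task=my_task,                        # an existing kiln_ai.datamodel.Task
        model_name="gpt_4o_mini",            # illustrative model name
        model_provider_name="openrouter",    # illustrative provider name
        prompt_id="simple_prompt_builder",
    )
    adapter = MyAdapter(run_config=run_config)

    # Inside an async function:
    task_run = await adapter.invoke({"some_input": "value"})
    # Or, when the raw model output is also needed:
    task_run, run_output = await adapter.invoke_returning_run_output({"some_input": "value"})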
@@ -113,6 +117,10 @@ class BaseAdapter(metaclass=ABCMeta):
 
         # validate output
         if self.output_schema is not None:
+            # Parse json to dict if we have structured output
+            if isinstance(parsed_output.output, str):
+                parsed_output.output = parse_json_string(parsed_output.output)
+
             if not isinstance(parsed_output.output, dict):
                 raise RuntimeError(
                     f"structured response is not a dict: {parsed_output.output}"
@@ -124,23 +132,36 @@ class BaseAdapter(metaclass=ABCMeta):
                 f"response is not a string for non-structured task: {parsed_output.output}"
             )
 
+        # Validate reasoning content is present (if reasoning)
+        if provider.reasoning_capable and (
+            not parsed_output.intermediate_outputs
+            or "reasoning" not in parsed_output.intermediate_outputs
+        ):
+            raise RuntimeError(
+                "Reasoning is required for this model, but no reasoning was returned."
+            )
+
         # Generate the run and output
         run = self.generate_run(input, input_source, parsed_output)
 
         # Save the run if configured to do so, and we have a path to save to
-        if
+        if (
+            self.base_adapter_config.allow_saving
+            and Config.shared().autosave_runs
+            and self.task().path is not None
+        ):
             run.save_to_file()
         else:
             # Clear the ID to indicate it's not persisted
             run.id = None
 
-        return run
+        return run, run_output
 
     def has_structured_output(self) -> bool:
         return self.output_schema is not None
 
     @abstractmethod
-    def
+    def adapter_name(self) -> str:
         pass
 
     @abstractmethod
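As the gating condition above shows, a run is persisted only when allow_saving is true, the global autosave setting is on, and the task has a path; otherwise the returned TaskRun has its id cleared. A hedged sketch of opting out per adapter (MyAdapter and run_config are as in the earlier sketch, hypothetical):

    from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig

    adapter = MyAdapter(
        run_config=run_config,
        config=AdapterConfig(allow_saving=False),
    )
    # Inside an async function:
    run = await adapter.invoke({"some_input": "value"})
    assert run.id is None  # not persisted, per the else branch above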
@@ -203,7 +224,7 @@ class BaseAdapter(metaclass=ABCMeta):
         )
 
         new_task_run = TaskRun(
-            parent=self.
+            parent=self.task(),
             input=input_str,
             input_source=input_source,
             output=TaskOutput(
@@ -215,7 +236,7 @@ class BaseAdapter(metaclass=ABCMeta):
                 ),
             ),
             intermediate_outputs=run_output.intermediate_outputs,
-            tags=self.default_tags or [],
+            tags=self.base_adapter_config.default_tags or [],
         )
 
         return new_task_run
@@ -224,12 +245,9 @@ class BaseAdapter(metaclass=ABCMeta):
         props = {}
 
         # adapter info
-
-        props["
-        props["
-        props["
-        props["prompt_builder_name"] = adapter_info.prompt_builder_name
-        if adapter_info.prompt_id is not None:
-            props["prompt_id"] = adapter_info.prompt_id
+        props["adapter_name"] = self.adapter_name()
+        props["model_name"] = self.run_config.model_name
+        props["model_provider"] = self.run_config.model_provider_name
+        props["prompt_id"] = self.run_config.prompt_id
 
         return props
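With the mapping above, the source properties recorded on a generated run now come straight from the run config. An illustrative sketch of the resulting dict (keys from the diff; values depend on the actual RunConfig):

    source_properties = {
        "adapter_name": "kiln_openai_compatible_adapter",
        "model_name": "gpt_4o_mini",
        "model_provider": "openrouter",
        "prompt_id": "simple_prompt_builder",
    }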
kiln_ai/adapters/model_adapters/litellm_adapter.py (new file)

@@ -0,0 +1,391 @@
+from typing import Any, Dict
+
+import litellm
+from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
+
+import kiln_ai.datamodel as datamodel
+from kiln_ai.adapters.ml_model_list import (
+    KilnModelProvider,
+    ModelProviderName,
+    StructuredOutputMode,
+)
+from kiln_ai.adapters.model_adapters.base_adapter import (
+    COT_FINAL_ANSWER_PROMPT,
+    AdapterConfig,
+    BaseAdapter,
+    RunOutput,
+)
+from kiln_ai.adapters.model_adapters.litellm_config import (
+    LiteLlmConfig,
+)
+from kiln_ai.datamodel import PromptGenerators, PromptId
+from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+class LiteLlmAdapter(BaseAdapter):
+    def __init__(
+        self,
+        config: LiteLlmConfig,
+        kiln_task: datamodel.Task,
+        prompt_id: PromptId | None = None,
+        base_adapter_config: AdapterConfig | None = None,
+    ):
+        self.config = config
+        self._additional_body_options = config.additional_body_options
+        self._api_base = config.base_url
+        self._headers = config.default_headers
+        self._litellm_model_id: str | None = None
+
+        run_config = RunConfig(
+            task=kiln_task,
+            model_name=config.model_name,
+            model_provider_name=config.provider_name,
+            prompt_id=prompt_id or PromptGenerators.SIMPLE,
+        )
+
+        super().__init__(
+            run_config=run_config,
+            config=base_adapter_config,
+        )
+
+    async def _run(self, input: Dict | str) -> RunOutput:
+        provider = self.model_provider()
+        if not provider.model_id:
+            raise ValueError("Model ID is required for OpenAI compatible models")
+
+        intermediate_outputs: dict[str, str] = {}
+        prompt = self.build_prompt()
+        user_msg = self.prompt_builder.build_user_message(input)
+        messages = [
+            {"role": "system", "content": prompt},
+            {"role": "user", "content": user_msg},
+        ]
+
+        run_strategy, cot_prompt = self.run_strategy()
+
+        if run_strategy == "cot_as_message":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_as_message strategy")
+            messages.append({"role": "system", "content": cot_prompt})
+        elif run_strategy == "cot_two_call":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_two_call strategy")
+            messages.append({"role": "system", "content": cot_prompt})
+
+            # First call for chain of thought - No logprobs as only needed for final answer
+            completion_kwargs = await self.build_completion_kwargs(
+                provider, messages, None
+            )
+            cot_response = await litellm.acompletion(**completion_kwargs)
+            if (
+                not isinstance(cot_response, ModelResponse)
+                or not cot_response.choices
+                or len(cot_response.choices) == 0
+                or not isinstance(cot_response.choices[0], Choices)
+            ):
+                raise RuntimeError(
+                    f"Expected ModelResponse with Choices, got {type(cot_response)}."
+                )
+            cot_content = cot_response.choices[0].message.content
+            if cot_content is not None:
+                intermediate_outputs["chain_of_thought"] = cot_content
+
+            messages.extend(
+                [
+                    {"role": "assistant", "content": cot_content or ""},
+                    {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
+                ]
+            )
+
+        # Make the API call using litellm
+        completion_kwargs = await self.build_completion_kwargs(
+            provider, messages, self.base_adapter_config.top_logprobs
+        )
+        response = await litellm.acompletion(**completion_kwargs)
+
+        if not isinstance(response, ModelResponse):
+            raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")
+
+        # Maybe remove this? There is no error attribute on the response object.
+        # # Keeping in typesafe way as we added it for a reason, but should investigate what that was and if it still applies.
+        if hasattr(response, "error") and response.__getattribute__("error"):
+            raise RuntimeError(
+                f"LLM API returned an error: {response.__getattribute__('error')}"
+            )
+
+        if (
+            not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                "No message content returned in the response from LLM API"
+            )
+
+        message = response.choices[0].message
+        logprobs = (
+            response.choices[0].logprobs
+            if hasattr(response.choices[0], "logprobs")
+            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
+            else None
+        )
+
+        # Check logprobs worked, if requested
+        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
+            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
+
+        # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
+        if hasattr(message, "reasoning_content") and message.reasoning_content:
+            intermediate_outputs["reasoning"] = message.reasoning_content
+
+        # the string content of the response
+        response_content = message.content
+
+        # Fallback: Use args of first tool call to task_response if it exists
+        if (
+            not response_content
+            and hasattr(message, "tool_calls")
+            and message.tool_calls
+        ):
+            tool_call = next(
+                (
+                    tool_call
+                    for tool_call in message.tool_calls
+                    if tool_call.function.name == "task_response"
+                ),
+                None,
+            )
+            if tool_call:
+                response_content = tool_call.function.arguments
+
+        if not isinstance(response_content, str):
+            raise RuntimeError(f"response is not a string: {response_content}")
+
+        return RunOutput(
+            output=response_content,
+            intermediate_outputs=intermediate_outputs,
+            output_logprobs=logprobs,
+        )
+
+    def adapter_name(self) -> str:
+        return "kiln_openai_compatible_adapter"
+
+    async def response_format_options(self) -> dict[str, Any]:
+        # Unstructured if task isn't structured
+        if not self.has_structured_output():
+            return {}
+
+        provider = self.model_provider()
+        match provider.structured_output_mode:
+            case StructuredOutputMode.json_mode:
+                return {"response_format": {"type": "json_object"}}
+            case StructuredOutputMode.json_schema:
+                return self.json_schema_response_format()
+            case StructuredOutputMode.function_calling_weak:
+                return self.tool_call_params(strict=False)
+            case StructuredOutputMode.function_calling:
+                return self.tool_call_params(strict=True)
+            case StructuredOutputMode.json_instructions:
+                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
+                return {}
+            case StructuredOutputMode.json_custom_instructions:
+                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
+                return {}
+            case StructuredOutputMode.json_instruction_and_object:
+                # We set response_format to json_object and also set json instructions in the prompt
+                return {"response_format": {"type": "json_object"}}
+            case StructuredOutputMode.default:
+                if provider.name == ModelProviderName.ollama:
+                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
+                    return self.json_schema_response_format()
+                else:
+                    # Default to function calling -- it's older than the other modes. Higher compatibility.
+                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
+                    strict = provider.name == ModelProviderName.openai
+                    return self.tool_call_params(strict=strict)
+            case _:
+                raise_exhaustive_enum_error(provider.structured_output_mode)
+
+    def json_schema_response_format(self) -> dict[str, Any]:
+        output_schema = self.task().output_schema()
+        return {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "task_response",
+                    "schema": output_schema,
+                },
+            }
+        }
+
+    def tool_call_params(self, strict: bool) -> dict[str, Any]:
+        # Add additional_properties: false to the schema (OpenAI requires this for some models)
+        output_schema = self.task().output_schema()
+        if not isinstance(output_schema, dict):
+            raise ValueError(
+                "Invalid output schema for this task. Can not use tool calls."
+            )
+        output_schema["additionalProperties"] = False
+
+        function_params = {
+            "name": "task_response",
+            "parameters": output_schema,
+        }
+        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
+        if strict:
+            function_params["strict"] = True
+
+        return {
+            "tools": [
+                {
+                    "type": "function",
+                    "function": function_params,
+                }
+            ],
+            "tool_choice": {
+                "type": "function",
+                "function": {"name": "task_response"},
+            },
+        }
+
+    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
+        # TODO P1: Don't love having this logic here. But it's a usability improvement
+        # so better to keep it than exclude it. Should figure out how I want to isolate
+        # this sort of logic so it's config driven and can be overridden
+
+        extra_body = {}
+        provider_options = {}
+
+        if provider.thinking_level is not None:
+            extra_body["reasoning_effort"] = provider.thinking_level
+
+        if provider.require_openrouter_reasoning:
+            # https://openrouter.ai/docs/use-cases/reasoning-tokens
+            extra_body["reasoning"] = {
+                "exclude": False,
+            }
+
+        if provider.anthropic_extended_thinking:
+            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
+
+        if provider.r1_openrouter_options:
+            # Require providers that support the reasoning parameter
+            provider_options["require_parameters"] = True
+            # Prefer R1 providers with reasonable perf/quants
+            provider_options["order"] = ["Fireworks", "Together"]
+            # R1 providers with unreasonable quants
+            provider_options["ignore"] = ["DeepInfra"]
+
+        # Only set of this request is to get logprobs.
+        if (
+            provider.logprobs_openrouter_options
+            and self.base_adapter_config.top_logprobs is not None
+        ):
+            # Don't let OpenRouter choose a provider that doesn't support logprobs.
+            provider_options["require_parameters"] = True
+            # DeepInfra silently fails to return logprobs consistently.
+            provider_options["ignore"] = ["DeepInfra"]
+
+        if provider.openrouter_skip_required_parameters:
+            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
+            provider_options["require_parameters"] = False
+
+        if len(provider_options) > 0:
+            extra_body["provider"] = provider_options
+
+        return extra_body
+
+    def litellm_model_id(self) -> str:
+        # The model ID is an interesting combination of format and url endpoint.
+        # It specifics the provider URL/host, but this is overridden if you manually set an api url
+
+        if self._litellm_model_id:
+            return self._litellm_model_id
+
+        provider = self.model_provider()
+        if not provider.model_id:
+            raise ValueError("Model ID is required for OpenAI compatible models")
+
+        litellm_provider_name: str | None = None
+        is_custom = False
+        match provider.name:
+            case ModelProviderName.openrouter:
+                litellm_provider_name = "openrouter"
+            case ModelProviderName.openai:
+                litellm_provider_name = "openai"
+            case ModelProviderName.groq:
+                litellm_provider_name = "groq"
+            case ModelProviderName.anthropic:
+                litellm_provider_name = "anthropic"
+            case ModelProviderName.ollama:
+                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
+                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
+                is_custom = True
+            case ModelProviderName.gemini_api:
+                litellm_provider_name = "gemini"
+            case ModelProviderName.fireworks_ai:
+                litellm_provider_name = "fireworks_ai"
+            case ModelProviderName.amazon_bedrock:
+                litellm_provider_name = "bedrock"
+            case ModelProviderName.azure_openai:
+                litellm_provider_name = "azure"
+            case ModelProviderName.huggingface:
+                litellm_provider_name = "huggingface"
+            case ModelProviderName.vertex:
+                litellm_provider_name = "vertex_ai"
+            case ModelProviderName.together_ai:
+                litellm_provider_name = "together_ai"
+            case ModelProviderName.openai_compatible:
+                is_custom = True
+            case ModelProviderName.kiln_custom_registry:
+                is_custom = True
+            case ModelProviderName.kiln_fine_tune:
+                is_custom = True
+            case _:
+                raise_exhaustive_enum_error(provider.name)
+
+        if is_custom:
+            if self._api_base is None:
+                raise ValueError(
+                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
+                )
+            # Use openai as it's only used for format, not url
+            litellm_provider_name = "openai"
+
+        # Sholdn't be possible but keep type checker happy
+        if litellm_provider_name is None:
+            raise ValueError(
+                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
+            )
+
+        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
+        return self._litellm_model_id
+
+    async def build_completion_kwargs(
+        self,
+        provider: KilnModelProvider,
+        messages: list[dict[str, Any]],
+        top_logprobs: int | None,
+    ) -> dict[str, Any]:
+        extra_body = self.build_extra_body(provider)
+
+        # Merge all parameters into a single kwargs dict for litellm
+        completion_kwargs = {
+            "model": self.litellm_model_id(),
+            "messages": messages,
+            "api_base": self._api_base,
+            "headers": self._headers,
+            **extra_body,
+            **self._additional_body_options,
+        }
+
+        # Response format: json_schema, json_instructions, json_mode, function_calling, etc
+        response_format_options = await self.response_format_options()
+        completion_kwargs.update(response_format_options)
+
+        if top_logprobs is not None:
+            completion_kwargs["logprobs"] = True
+            completion_kwargs["top_logprobs"] = top_logprobs
+
+        return completion_kwargs
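For a structured task, tool_call_params() above forces the model to call a single task_response function built from the task's output schema. A sketch of the kwargs it would merge into the litellm.acompletion() call, assuming a simple one-field schema (the schema itself is illustrative):

    tool_kwargs = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "task_response",
                    "parameters": {
                        "type": "object",
                        "properties": {"answer": {"type": "string"}},
                        "required": ["answer"],
                        "additionalProperties": False,
                    },
                    "strict": True,  # only added when strict=True is requested
                },
            }
        ],
        "tool_choice": {"type": "function", "function": {"name": "task_response"}},
    }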
kiln_ai/adapters/model_adapters/litellm_config.py (new file)

@@ -0,0 +1,13 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class LiteLlmConfig:
+    model_name: str
+    provider_name: str
+    # If set, over rides the provider-name based URL from litellm
+    base_url: str | None = None
+    # Headers to send with every request
+    default_headers: dict[str, str] | None = None
+    # Extra body to send with every request
+    additional_body_options: dict[str, str] = field(default_factory=dict)
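Putting the two new pieces together, a minimal sketch of constructing the LiteLLM adapter from its config, based on the constructors shown above. The model/provider values, the api_key option, and my_task are illustrative, and openai_compatible-style providers would also need base_url:

    from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
    from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig

    config = LiteLlmConfig(
        model_name="gpt_4o_mini",                       # illustrative Kiln model name
        provider_name="openrouter",                     # illustrative provider name
        additional_body_options={"api_key": "sk-..."},  # forwarded to litellm.acompletion on every call
    )
    adapter = LiteLlmAdapter(config=config, kiln_task=my_task)  # my_task: an existing datamodel.Task

    # Inside an async function:
    # task_run = await adapter.invoke({"some_input": "value"})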
kiln_ai/adapters/model_adapters/test_base_adapter.py

@@ -3,8 +3,9 @@ from unittest.mock import MagicMock, patch
 import pytest
 
 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
-from kiln_ai.adapters.model_adapters.base_adapter import
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
 from kiln_ai.datamodel import Task
+from kiln_ai.datamodel.task import RunConfig
 
 
 class MockAdapter(BaseAdapter):
@@ -13,13 +14,8 @@ class MockAdapter(BaseAdapter):
     async def _run(self, input):
         return None
 
-    def
-        return
-            adapter_name="test",
-            model_name=self.model_name,
-            model_provider=self.model_provider_name,
-            prompt_builder_name="test",
-        )
+    def adapter_name(self) -> str:
+        return "test"
 
 
 @pytest.fixture
@@ -37,9 +33,12 @@ def base_task():
 @pytest.fixture
 def adapter(base_task):
     return MockAdapter(
-
-
-
+        run_config=RunConfig(
+            task=base_task,
+            model_name="test_model",
+            model_provider_name="test_provider",
+            prompt_id="simple_prompt_builder",
+        ),
     )
 
 
@@ -85,7 +84,12 @@ async def test_model_provider_missing_names(base_task):
     """Test error when model or provider name is missing"""
     # Test with missing model name
     adapter = MockAdapter(
-
+        run_config=RunConfig(
+            task=base_task,
+            model_name="",
+            model_provider_name="",
+            prompt_id="simple_prompt_builder",
+        ),
     )
     with pytest.raises(
         ValueError, match="model_name and model_provider_name must be provided"
@@ -94,7 +98,12 @@ async def test_model_provider_missing_names(base_task):
 
     # Test with missing provider name
     adapter = MockAdapter(
-
+        run_config=RunConfig(
+            task=base_task,
+            model_name="test_model",
+            model_provider_name="",
+            prompt_id="simple_prompt_builder",
+        ),
     )
     with pytest.raises(
         ValueError, match="model_name and model_provider_name must be provided"