kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of kiln-ai is flagged as potentially problematic by the diff tooling.

Files changed (80)
  1. kiln_ai/adapters/__init__.py +4 -0
  2. kiln_ai/adapters/adapter_registry.py +163 -39
  3. kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
  4. kiln_ai/adapters/eval/__init__.py +28 -0
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +270 -0
  7. kiln_ai/adapters/eval/g_eval.py +368 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +325 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +641 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +498 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
  14. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  15. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
  16. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
  17. kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
  18. kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
  19. kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
  20. kiln_ai/adapters/ml_model_list.py +758 -163
  21. kiln_ai/adapters/model_adapters/__init__.py +2 -4
  22. kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
  23. kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
  24. kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
  25. kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
  26. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
  27. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
  28. kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
  29. kiln_ai/adapters/ollama_tools.py +3 -3
  30. kiln_ai/adapters/parsers/r1_parser.py +19 -14
  31. kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
  32. kiln_ai/adapters/prompt_builders.py +80 -42
  33. kiln_ai/adapters/provider_tools.py +50 -58
  34. kiln_ai/adapters/repair/repair_task.py +9 -21
  35. kiln_ai/adapters/repair/test_repair_task.py +6 -6
  36. kiln_ai/adapters/run_output.py +3 -0
  37. kiln_ai/adapters/test_adapter_registry.py +26 -29
  38. kiln_ai/adapters/test_generate_docs.py +4 -4
  39. kiln_ai/adapters/test_ollama_tools.py +0 -1
  40. kiln_ai/adapters/test_prompt_adaptors.py +47 -33
  41. kiln_ai/adapters/test_prompt_builders.py +91 -31
  42. kiln_ai/adapters/test_provider_tools.py +26 -81
  43. kiln_ai/datamodel/__init__.py +50 -952
  44. kiln_ai/datamodel/basemodel.py +2 -0
  45. kiln_ai/datamodel/datamodel_enums.py +60 -0
  46. kiln_ai/datamodel/dataset_filters.py +114 -0
  47. kiln_ai/datamodel/dataset_split.py +170 -0
  48. kiln_ai/datamodel/eval.py +298 -0
  49. kiln_ai/datamodel/finetune.py +105 -0
  50. kiln_ai/datamodel/json_schema.py +7 -1
  51. kiln_ai/datamodel/project.py +23 -0
  52. kiln_ai/datamodel/prompt.py +37 -0
  53. kiln_ai/datamodel/prompt_id.py +83 -0
  54. kiln_ai/datamodel/strict_mode.py +24 -0
  55. kiln_ai/datamodel/task.py +181 -0
  56. kiln_ai/datamodel/task_output.py +328 -0
  57. kiln_ai/datamodel/task_run.py +164 -0
  58. kiln_ai/datamodel/test_basemodel.py +19 -11
  59. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  60. kiln_ai/datamodel/test_dataset_split.py +32 -8
  61. kiln_ai/datamodel/test_datasource.py +22 -2
  62. kiln_ai/datamodel/test_eval_model.py +635 -0
  63. kiln_ai/datamodel/test_example_models.py +9 -13
  64. kiln_ai/datamodel/test_json_schema.py +23 -0
  65. kiln_ai/datamodel/test_models.py +2 -2
  66. kiln_ai/datamodel/test_prompt_id.py +129 -0
  67. kiln_ai/datamodel/test_task.py +159 -0
  68. kiln_ai/utils/config.py +43 -1
  69. kiln_ai/utils/dataset_import.py +232 -0
  70. kiln_ai/utils/test_dataset_import.py +596 -0
  71. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
  72. kiln_ai-0.13.0.dist-info/RECORD +103 -0
  73. kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
  74. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
  75. kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
  76. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
  77. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
  78. kiln_ai-0.11.1.dist-info/RECORD +0 -76
  79. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
  80. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/__init__.py
@@ -7,12 +7,10 @@ Model adapters are used to call AI models, like Ollama, OpenAI, etc.
 
 from . import (
     base_adapter,
-    langchain_adapters,
-    openai_model_adapter,
+    litellm_adapter,
 )
 
 __all__ = [
     "base_adapter",
-    "langchain_adapters",
-    "openai_model_adapter",
+    "litellm_adapter",
 ]
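
For downstream code that imported the removed submodules, the package-level change amounts to a one-line import swap. A minimal sketch, based only on the `__all__` change above; your own call sites will differ:

```python
# 0.11.1 (removed in this release):
# from kiln_ai.adapters.model_adapters import langchain_adapters, openai_model_adapter

# 0.13.0:
from kiln_ai.adapters.model_adapters import litellm_adapter
```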
kiln_ai/adapters/model_adapters/base_adapter.py
@@ -4,8 +4,9 @@ from dataclasses import dataclass
 from typing import Dict, Literal, Tuple
 
 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
+from kiln_ai.adapters.parsers.json_parser import parse_json_string
 from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
-from kiln_ai.adapters.prompt_builders import BasePromptBuilder, SimplePromptBuilder
+from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
 from kiln_ai.adapters.run_output import RunOutput
 from kiln_ai.datamodel import (
@@ -16,16 +17,21 @@ from kiln_ai.datamodel import (
     TaskRun,
 )
 from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config
 
 
 @dataclass
-class AdapterInfo:
-    adapter_name: str
-    model_name: str
-    model_provider: str
-    prompt_builder_name: str
-    prompt_id: str | None = None
+class AdapterConfig:
+    """
+    An adapter config is config options that do NOT impact the output of the model.
+
+    For example: if it's saved, of if we request additional data like logprobs.
+    """
+
+    allow_saving: bool = True
+    top_logprobs: int | None = None
+    default_tags: list[str] | None = None
 
 
 COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."
@@ -47,54 +53,52 @@ class BaseAdapter(metaclass=ABCMeta):
 
     def __init__(
         self,
-        kiln_task: Task,
-        model_name: str,
-        model_provider_name: str,
-        prompt_builder: BasePromptBuilder | None = None,
-        tags: list[str] | None = None,
+        run_config: RunConfig,
+        config: AdapterConfig | None = None,
     ):
-        self.prompt_builder = prompt_builder or SimplePromptBuilder(kiln_task)
-        self.kiln_task = kiln_task
-        self.output_schema = self.kiln_task.output_json_schema
-        self.input_schema = self.kiln_task.input_json_schema
-        self.default_tags = tags
-        self.model_name = model_name
-        self.model_provider_name = model_provider_name
+        self.run_config = run_config
+        self.prompt_builder = prompt_builder_from_id(
+            run_config.prompt_id, run_config.task
+        )
         self._model_provider: KilnModelProvider | None = None
 
+        self.output_schema = self.task().output_json_schema
+        self.input_schema = self.task().input_json_schema
+        self.base_adapter_config = config or AdapterConfig()
+
+    def task(self) -> Task:
+        return self.run_config.task
+
     def model_provider(self) -> KilnModelProvider:
         """
         Lazy load the model provider for this adapter.
         """
         if self._model_provider is not None:
             return self._model_provider
-        if not self.model_name or not self.model_provider_name:
+        if not self.run_config.model_name or not self.run_config.model_provider_name:
             raise ValueError("model_name and model_provider_name must be provided")
         self._model_provider = kiln_model_provider_from(
-            self.model_name, self.model_provider_name
+            self.run_config.model_name, self.run_config.model_provider_name
         )
         if not self._model_provider:
             raise ValueError(
-                f"model_provider_name {self.model_provider_name} not found for model {self.model_name}"
+                f"model_provider_name {self.run_config.model_provider_name} not found for model {self.run_config.model_name}"
             )
         return self._model_provider
 
-    async def invoke_returning_raw(
+    async def invoke(
         self,
         input: Dict | str,
         input_source: DataSource | None = None,
-    ) -> Dict | str:
-        result = await self.invoke(input, input_source)
-        if self.kiln_task.output_json_schema is None:
-            return result.output.output
-        else:
-            return json.loads(result.output.output)
+    ) -> TaskRun:
+        run_output, _ = await self.invoke_returning_run_output(input, input_source)
+        return run_output
 
-    async def invoke(
+    async def invoke_returning_run_output(
         self,
         input: Dict | str,
         input_source: DataSource | None = None,
-    ) -> TaskRun:
+    ) -> Tuple[TaskRun, RunOutput]:
         # validate input
         if self.input_schema is not None:
             if not isinstance(input, dict):
@@ -113,6 +117,10 @@ class BaseAdapter(metaclass=ABCMeta):
 
         # validate output
         if self.output_schema is not None:
+            # Parse json to dict if we have structured output
+            if isinstance(parsed_output.output, str):
+                parsed_output.output = parse_json_string(parsed_output.output)
+
             if not isinstance(parsed_output.output, dict):
                 raise RuntimeError(
                     f"structured response is not a dict: {parsed_output.output}"
@@ -124,23 +132,36 @@ class BaseAdapter(metaclass=ABCMeta):
                     f"response is not a string for non-structured task: {parsed_output.output}"
                 )
 
+        # Validate reasoning content is present (if reasoning)
+        if provider.reasoning_capable and (
+            not parsed_output.intermediate_outputs
+            or "reasoning" not in parsed_output.intermediate_outputs
+        ):
+            raise RuntimeError(
+                "Reasoning is required for this model, but no reasoning was returned."
+            )
+
         # Generate the run and output
         run = self.generate_run(input, input_source, parsed_output)
 
         # Save the run if configured to do so, and we have a path to save to
-        if Config.shared().autosave_runs and self.kiln_task.path is not None:
+        if (
+            self.base_adapter_config.allow_saving
+            and Config.shared().autosave_runs
+            and self.task().path is not None
+        ):
             run.save_to_file()
         else:
             # Clear the ID to indicate it's not persisted
             run.id = None
 
-        return run
+        return run, run_output
 
     def has_structured_output(self) -> bool:
         return self.output_schema is not None
 
     @abstractmethod
-    def adapter_info(self) -> AdapterInfo:
+    def adapter_name(self) -> str:
         pass
 
     @abstractmethod
@@ -203,7 +224,7 @@ class BaseAdapter(metaclass=ABCMeta):
         )
 
         new_task_run = TaskRun(
-            parent=self.kiln_task,
+            parent=self.task(),
             input=input_str,
             input_source=input_source,
             output=TaskOutput(
@@ -215,7 +236,7 @@ class BaseAdapter(metaclass=ABCMeta):
                 ),
             ),
             intermediate_outputs=run_output.intermediate_outputs,
-            tags=self.default_tags or [],
+            tags=self.base_adapter_config.default_tags or [],
         )
 
         return new_task_run
@@ -224,12 +245,9 @@ class BaseAdapter(metaclass=ABCMeta):
         props = {}
 
         # adapter info
-        adapter_info = self.adapter_info()
-        props["adapter_name"] = adapter_info.adapter_name
-        props["model_name"] = adapter_info.model_name
-        props["model_provider"] = adapter_info.model_provider
-        props["prompt_builder_name"] = adapter_info.prompt_builder_name
-        if adapter_info.prompt_id is not None:
-            props["prompt_id"] = adapter_info.prompt_id
+        props["adapter_name"] = self.adapter_name()
+        props["model_name"] = self.run_config.model_name
+        props["model_provider"] = self.run_config.model_provider_name
+        props["prompt_id"] = self.run_config.prompt_id
 
         return props
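
Taken together, the base adapter hunks above replace the old `kiln_task`/`model_name`/`model_provider_name` constructor arguments and the `AdapterInfo` metadata class with a single `RunConfig` plus an optional `AdapterConfig`, and split `invoke` into `invoke` (returns a `TaskRun`) and `invoke_returning_run_output` (returns the `TaskRun` plus the raw `RunOutput`). The sketch below illustrates the new calling convention; `EchoAdapter` is a hypothetical subclass used only for illustration (mirroring the `MockAdapter` in the test diff further down), and the model/provider names are placeholders that would need to resolve via `kiln_model_provider_from` for a real call.

```python
from kiln_ai.adapters.model_adapters.base_adapter import (
    AdapterConfig,
    BaseAdapter,
    RunOutput,
)
from kiln_ai.datamodel import Task
from kiln_ai.datamodel.task import RunConfig


class EchoAdapter(BaseAdapter):
    """Hypothetical concrete adapter, only to show the new construction shape."""

    async def _run(self, input):
        # A real adapter would call a model here.
        return RunOutput(output="hello", intermediate_outputs={}, output_logprobs=None)

    def adapter_name(self) -> str:
        return "echo_adapter"


async def run_once(task: Task):
    adapter = EchoAdapter(
        run_config=RunConfig(
            task=task,
            model_name="test_model",            # placeholder
            model_provider_name="test_provider",  # placeholder
            prompt_id="simple_prompt_builder",
        ),
        # AdapterConfig only controls behavior that does not affect model output.
        config=AdapterConfig(allow_saving=False, default_tags=["example"]),
    )
    # With a resolvable model/provider pair from ml_model_list:
    run = await adapter.invoke({"some_input": "value"})                  # TaskRun
    run2, raw = await adapter.invoke_returning_run_output("plain input")  # (TaskRun, RunOutput)
    return run, run2, raw
```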
kiln_ai/adapters/model_adapters/litellm_adapter.py
@@ -0,0 +1,391 @@
+from typing import Any, Dict
+
+import litellm
+from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
+
+import kiln_ai.datamodel as datamodel
+from kiln_ai.adapters.ml_model_list import (
+    KilnModelProvider,
+    ModelProviderName,
+    StructuredOutputMode,
+)
+from kiln_ai.adapters.model_adapters.base_adapter import (
+    COT_FINAL_ANSWER_PROMPT,
+    AdapterConfig,
+    BaseAdapter,
+    RunOutput,
+)
+from kiln_ai.adapters.model_adapters.litellm_config import (
+    LiteLlmConfig,
+)
+from kiln_ai.datamodel import PromptGenerators, PromptId
+from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+class LiteLlmAdapter(BaseAdapter):
+    def __init__(
+        self,
+        config: LiteLlmConfig,
+        kiln_task: datamodel.Task,
+        prompt_id: PromptId | None = None,
+        base_adapter_config: AdapterConfig | None = None,
+    ):
+        self.config = config
+        self._additional_body_options = config.additional_body_options
+        self._api_base = config.base_url
+        self._headers = config.default_headers
+        self._litellm_model_id: str | None = None
+
+        run_config = RunConfig(
+            task=kiln_task,
+            model_name=config.model_name,
+            model_provider_name=config.provider_name,
+            prompt_id=prompt_id or PromptGenerators.SIMPLE,
+        )
+
+        super().__init__(
+            run_config=run_config,
+            config=base_adapter_config,
+        )
+
+    async def _run(self, input: Dict | str) -> RunOutput:
+        provider = self.model_provider()
+        if not provider.model_id:
+            raise ValueError("Model ID is required for OpenAI compatible models")
+
+        intermediate_outputs: dict[str, str] = {}
+        prompt = self.build_prompt()
+        user_msg = self.prompt_builder.build_user_message(input)
+        messages = [
+            {"role": "system", "content": prompt},
+            {"role": "user", "content": user_msg},
+        ]
+
+        run_strategy, cot_prompt = self.run_strategy()
+
+        if run_strategy == "cot_as_message":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_as_message strategy")
+            messages.append({"role": "system", "content": cot_prompt})
+        elif run_strategy == "cot_two_call":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_two_call strategy")
+            messages.append({"role": "system", "content": cot_prompt})
+
+            # First call for chain of thought - No logprobs as only needed for final answer
+            completion_kwargs = await self.build_completion_kwargs(
+                provider, messages, None
+            )
+            cot_response = await litellm.acompletion(**completion_kwargs)
+            if (
+                not isinstance(cot_response, ModelResponse)
+                or not cot_response.choices
+                or len(cot_response.choices) == 0
+                or not isinstance(cot_response.choices[0], Choices)
+            ):
+                raise RuntimeError(
+                    f"Expected ModelResponse with Choices, got {type(cot_response)}."
+                )
+            cot_content = cot_response.choices[0].message.content
+            if cot_content is not None:
+                intermediate_outputs["chain_of_thought"] = cot_content
+
+            messages.extend(
+                [
+                    {"role": "assistant", "content": cot_content or ""},
+                    {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
+                ]
+            )
+
+        # Make the API call using litellm
+        completion_kwargs = await self.build_completion_kwargs(
+            provider, messages, self.base_adapter_config.top_logprobs
+        )
+        response = await litellm.acompletion(**completion_kwargs)
+
+        if not isinstance(response, ModelResponse):
+            raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")
+
+        # Maybe remove this? There is no error attribute on the response object.
+        # # Keeping in typesafe way as we added it for a reason, but should investigate what that was and if it still applies.
+        if hasattr(response, "error") and response.__getattribute__("error"):
+            raise RuntimeError(
+                f"LLM API returned an error: {response.__getattribute__('error')}"
+            )
+
+        if (
+            not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                "No message content returned in the response from LLM API"
+            )
+
+        message = response.choices[0].message
+        logprobs = (
+            response.choices[0].logprobs
+            if hasattr(response.choices[0], "logprobs")
+            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
+            else None
+        )
+
+        # Check logprobs worked, if requested
+        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
+            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
+
+        # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
+        if hasattr(message, "reasoning_content") and message.reasoning_content:
+            intermediate_outputs["reasoning"] = message.reasoning_content
+
+        # the string content of the response
+        response_content = message.content
+
+        # Fallback: Use args of first tool call to task_response if it exists
+        if (
+            not response_content
+            and hasattr(message, "tool_calls")
+            and message.tool_calls
+        ):
+            tool_call = next(
+                (
+                    tool_call
+                    for tool_call in message.tool_calls
+                    if tool_call.function.name == "task_response"
+                ),
+                None,
+            )
+            if tool_call:
+                response_content = tool_call.function.arguments
+
+        if not isinstance(response_content, str):
+            raise RuntimeError(f"response is not a string: {response_content}")
+
+        return RunOutput(
+            output=response_content,
+            intermediate_outputs=intermediate_outputs,
+            output_logprobs=logprobs,
+        )
+
+    def adapter_name(self) -> str:
+        return "kiln_openai_compatible_adapter"
+
+    async def response_format_options(self) -> dict[str, Any]:
+        # Unstructured if task isn't structured
+        if not self.has_structured_output():
+            return {}
+
+        provider = self.model_provider()
+        match provider.structured_output_mode:
+            case StructuredOutputMode.json_mode:
+                return {"response_format": {"type": "json_object"}}
+            case StructuredOutputMode.json_schema:
+                return self.json_schema_response_format()
+            case StructuredOutputMode.function_calling_weak:
+                return self.tool_call_params(strict=False)
+            case StructuredOutputMode.function_calling:
+                return self.tool_call_params(strict=True)
+            case StructuredOutputMode.json_instructions:
+                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
+                return {}
+            case StructuredOutputMode.json_custom_instructions:
+                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
+                return {}
+            case StructuredOutputMode.json_instruction_and_object:
+                # We set response_format to json_object and also set json instructions in the prompt
+                return {"response_format": {"type": "json_object"}}
+            case StructuredOutputMode.default:
+                if provider.name == ModelProviderName.ollama:
+                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
+                    return self.json_schema_response_format()
+                else:
+                    # Default to function calling -- it's older than the other modes. Higher compatibility.
+                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
+                    strict = provider.name == ModelProviderName.openai
+                    return self.tool_call_params(strict=strict)
+            case _:
+                raise_exhaustive_enum_error(provider.structured_output_mode)
+
+    def json_schema_response_format(self) -> dict[str, Any]:
+        output_schema = self.task().output_schema()
+        return {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "task_response",
+                    "schema": output_schema,
+                },
+            }
+        }
+
+    def tool_call_params(self, strict: bool) -> dict[str, Any]:
+        # Add additional_properties: false to the schema (OpenAI requires this for some models)
+        output_schema = self.task().output_schema()
+        if not isinstance(output_schema, dict):
+            raise ValueError(
+                "Invalid output schema for this task. Can not use tool calls."
+            )
+        output_schema["additionalProperties"] = False
+
+        function_params = {
+            "name": "task_response",
+            "parameters": output_schema,
+        }
+        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
+        if strict:
+            function_params["strict"] = True
+
+        return {
+            "tools": [
+                {
+                    "type": "function",
+                    "function": function_params,
+                }
+            ],
+            "tool_choice": {
+                "type": "function",
+                "function": {"name": "task_response"},
+            },
+        }
+
+    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
+        # TODO P1: Don't love having this logic here. But it's a usability improvement
+        # so better to keep it than exclude it. Should figure out how I want to isolate
+        # this sort of logic so it's config driven and can be overridden
+
+        extra_body = {}
+        provider_options = {}
+
+        if provider.thinking_level is not None:
+            extra_body["reasoning_effort"] = provider.thinking_level
+
+        if provider.require_openrouter_reasoning:
+            # https://openrouter.ai/docs/use-cases/reasoning-tokens
+            extra_body["reasoning"] = {
+                "exclude": False,
+            }
+
+        if provider.anthropic_extended_thinking:
+            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
+
+        if provider.r1_openrouter_options:
+            # Require providers that support the reasoning parameter
+            provider_options["require_parameters"] = True
+            # Prefer R1 providers with reasonable perf/quants
+            provider_options["order"] = ["Fireworks", "Together"]
+            # R1 providers with unreasonable quants
+            provider_options["ignore"] = ["DeepInfra"]
+
+        # Only set of this request is to get logprobs.
+        if (
+            provider.logprobs_openrouter_options
+            and self.base_adapter_config.top_logprobs is not None
+        ):
+            # Don't let OpenRouter choose a provider that doesn't support logprobs.
+            provider_options["require_parameters"] = True
+            # DeepInfra silently fails to return logprobs consistently.
+            provider_options["ignore"] = ["DeepInfra"]
+
+        if provider.openrouter_skip_required_parameters:
+            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
+            provider_options["require_parameters"] = False
+
+        if len(provider_options) > 0:
+            extra_body["provider"] = provider_options
+
+        return extra_body
+
+    def litellm_model_id(self) -> str:
+        # The model ID is an interesting combination of format and url endpoint.
+        # It specifics the provider URL/host, but this is overridden if you manually set an api url
+
+        if self._litellm_model_id:
+            return self._litellm_model_id
+
+        provider = self.model_provider()
+        if not provider.model_id:
+            raise ValueError("Model ID is required for OpenAI compatible models")
+
+        litellm_provider_name: str | None = None
+        is_custom = False
+        match provider.name:
+            case ModelProviderName.openrouter:
+                litellm_provider_name = "openrouter"
+            case ModelProviderName.openai:
+                litellm_provider_name = "openai"
+            case ModelProviderName.groq:
+                litellm_provider_name = "groq"
+            case ModelProviderName.anthropic:
+                litellm_provider_name = "anthropic"
+            case ModelProviderName.ollama:
+                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
+                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
+                is_custom = True
+            case ModelProviderName.gemini_api:
+                litellm_provider_name = "gemini"
+            case ModelProviderName.fireworks_ai:
+                litellm_provider_name = "fireworks_ai"
+            case ModelProviderName.amazon_bedrock:
+                litellm_provider_name = "bedrock"
+            case ModelProviderName.azure_openai:
+                litellm_provider_name = "azure"
+            case ModelProviderName.huggingface:
+                litellm_provider_name = "huggingface"
+            case ModelProviderName.vertex:
+                litellm_provider_name = "vertex_ai"
+            case ModelProviderName.together_ai:
+                litellm_provider_name = "together_ai"
+            case ModelProviderName.openai_compatible:
+                is_custom = True
+            case ModelProviderName.kiln_custom_registry:
+                is_custom = True
+            case ModelProviderName.kiln_fine_tune:
+                is_custom = True
+            case _:
+                raise_exhaustive_enum_error(provider.name)
+
+        if is_custom:
+            if self._api_base is None:
+                raise ValueError(
+                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
+                )
+            # Use openai as it's only used for format, not url
+            litellm_provider_name = "openai"
+
+        # Sholdn't be possible but keep type checker happy
+        if litellm_provider_name is None:
+            raise ValueError(
+                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
+            )
+
+        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
+        return self._litellm_model_id
+
+    async def build_completion_kwargs(
+        self,
+        provider: KilnModelProvider,
+        messages: list[dict[str, Any]],
+        top_logprobs: int | None,
+    ) -> dict[str, Any]:
+        extra_body = self.build_extra_body(provider)
+
+        # Merge all parameters into a single kwargs dict for litellm
+        completion_kwargs = {
+            "model": self.litellm_model_id(),
+            "messages": messages,
+            "api_base": self._api_base,
+            "headers": self._headers,
+            **extra_body,
+            **self._additional_body_options,
+        }
+
+        # Response format: json_schema, json_instructions, json_mode, function_calling, etc
+        response_format_options = await self.response_format_options()
+        completion_kwargs.update(response_format_options)
+
+        if top_logprobs is not None:
+            completion_kwargs["logprobs"] = True
+            completion_kwargs["top_logprobs"] = top_logprobs
+
+        return completion_kwargs
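
To make the assembled request concrete, here is a hypothetical snapshot of what `build_completion_kwargs()` would return for an OpenRouter-hosted model whose provider entry uses `StructuredOutputMode.json_schema`, with `top_logprobs=5` requested via `AdapterConfig`. The model ID, prompt text, and schema are placeholders, not values from this diff; only the key names follow the code above.

```python
# Hypothetical kwargs handed to litellm.acompletion() for an OpenRouter model
# with json_schema structured output and top_logprobs=5.
completion_kwargs = {
    # litellm_model_id(): "<litellm provider>/<provider.model_id>"
    "model": "openrouter/meta-llama/llama-3.1-70b-instruct",
    "messages": [
        {"role": "system", "content": "<built system prompt>"},
        {"role": "user", "content": "<built user message>"},
    ],
    "api_base": None,   # only set for custom / OpenAI-compatible endpoints
    "headers": None,
    # response_format_options() for StructuredOutputMode.json_schema:
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "task_response",
            "schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
        },
    },
    # added because top_logprobs was not None:
    "logprobs": True,
    "top_logprobs": 5,
}
```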
kiln_ai/adapters/model_adapters/litellm_config.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class LiteLlmConfig:
+    model_name: str
+    provider_name: str
+    # If set, over rides the provider-name based URL from litellm
+    base_url: str | None = None
+    # Headers to send with every request
+    default_headers: dict[str, str] | None = None
+    # Extra body to send with every request
+    additional_body_options: dict[str, str] = field(default_factory=dict)
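
A minimal usage sketch for the new config, based on the `LiteLlmAdapter.__init__` signature shown earlier. The model name, provider name, and `api_key` entry are illustrative placeholders, and in practice the adapter registry (`adapter_registry.py`, also changed in this release) constructs these for you.

```python
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
from kiln_ai.datamodel import Task


def build_adapter(task: Task) -> LiteLlmAdapter:
    # Placeholder values throughout; not taken from this diff.
    config = LiteLlmConfig(
        model_name="llama_3_1_70b",
        provider_name="openrouter",
        # base_url stays None for hosted providers; litellm_model_id() requires it
        # for OpenAI-compatible, Ollama, fine-tune, and custom-registry providers.
        additional_body_options={"api_key": "sk-..."},
    )
    return LiteLlmAdapter(config=config, kiln_task=task)

# Usage (inside an async context):
#   run = await build_adapter(task).invoke({"some_input": "value"})
```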
kiln_ai/adapters/model_adapters/test_base_adapter.py
@@ -3,8 +3,9 @@ from unittest.mock import MagicMock, patch
 import pytest
 
 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
-from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
 from kiln_ai.datamodel import Task
+from kiln_ai.datamodel.task import RunConfig
 
 
 class MockAdapter(BaseAdapter):
@@ -13,13 +14,8 @@ class MockAdapter(BaseAdapter):
     async def _run(self, input):
         return None
 
-    def adapter_info(self) -> AdapterInfo:
-        return AdapterInfo(
-            adapter_name="test",
-            model_name=self.model_name,
-            model_provider=self.model_provider_name,
-            prompt_builder_name="test",
-        )
+    def adapter_name(self) -> str:
+        return "test"
 
 
 @pytest.fixture
@@ -37,9 +33,12 @@ def base_task():
 @pytest.fixture
 def adapter(base_task):
     return MockAdapter(
-        kiln_task=base_task,
-        model_name="test_model",
-        model_provider_name="test_provider",
+        run_config=RunConfig(
+            task=base_task,
+            model_name="test_model",
+            model_provider_name="test_provider",
+            prompt_id="simple_prompt_builder",
+        ),
     )
 
 
@@ -85,7 +84,12 @@ async def test_model_provider_missing_names(base_task):
     """Test error when model or provider name is missing"""
     # Test with missing model name
     adapter = MockAdapter(
-        kiln_task=base_task, model_name="", model_provider_name="test_provider"
+        run_config=RunConfig(
+            task=base_task,
+            model_name="",
+            model_provider_name="",
+            prompt_id="simple_prompt_builder",
+        ),
     )
     with pytest.raises(
         ValueError, match="model_name and model_provider_name must be provided"
@@ -94,7 +98,12 @@ async def test_model_provider_missing_names(base_task):
 
     # Test with missing provider name
     adapter = MockAdapter(
-        kiln_task=base_task, model_name="test_model", model_provider_name=""
+        run_config=RunConfig(
+            task=base_task,
+            model_name="test_model",
+            model_provider_name="",
+            prompt_id="simple_prompt_builder",
+        ),
     )
     with pytest.raises(
         ValueError, match="model_name and model_provider_name must be provided"