kiln-ai 0.12.0__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (49)
  1. kiln_ai/adapters/__init__.py +4 -0
  2. kiln_ai/adapters/adapter_registry.py +157 -28
  3. kiln_ai/adapters/eval/__init__.py +28 -0
  4. kiln_ai/adapters/eval/eval_runner.py +4 -1
  5. kiln_ai/adapters/eval/g_eval.py +19 -3
  6. kiln_ai/adapters/eval/test_base_eval.py +1 -0
  7. kiln_ai/adapters/eval/test_eval_runner.py +1 -0
  8. kiln_ai/adapters/eval/test_g_eval.py +13 -7
  9. kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
  10. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  11. kiln_ai/adapters/fine_tune/fireworks_finetune.py +8 -1
  12. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +19 -0
  13. kiln_ai/adapters/fine_tune/test_together_finetune.py +533 -0
  14. kiln_ai/adapters/fine_tune/together_finetune.py +327 -0
  15. kiln_ai/adapters/ml_model_list.py +638 -155
  16. kiln_ai/adapters/model_adapters/__init__.py +2 -4
  17. kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
  18. kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
  19. kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
  20. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
  21. kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
  22. kiln_ai/adapters/ollama_tools.py +3 -2
  23. kiln_ai/adapters/parsers/r1_parser.py +19 -14
  24. kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
  25. kiln_ai/adapters/provider_tools.py +52 -60
  26. kiln_ai/adapters/repair/test_repair_task.py +3 -3
  27. kiln_ai/adapters/run_output.py +1 -1
  28. kiln_ai/adapters/test_adapter_registry.py +17 -20
  29. kiln_ai/adapters/test_generate_docs.py +2 -2
  30. kiln_ai/adapters/test_prompt_adaptors.py +30 -19
  31. kiln_ai/adapters/test_provider_tools.py +27 -82
  32. kiln_ai/datamodel/basemodel.py +2 -0
  33. kiln_ai/datamodel/datamodel_enums.py +2 -0
  34. kiln_ai/datamodel/json_schema.py +1 -1
  35. kiln_ai/datamodel/task_output.py +13 -6
  36. kiln_ai/datamodel/test_basemodel.py +9 -0
  37. kiln_ai/datamodel/test_datasource.py +19 -0
  38. kiln_ai/utils/config.py +46 -0
  39. kiln_ai/utils/dataset_import.py +232 -0
  40. kiln_ai/utils/test_dataset_import.py +596 -0
  41. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/METADATA +51 -7
  42. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/RECORD +44 -41
  43. kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
  44. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
  45. kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
  46. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
  47. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
  48. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/WHEEL +0 -0
  49. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/__init__.py
@@ -7,12 +7,10 @@ Model adapters are used to call AI models, like Ollama, OpenAI, etc.
 
 from . import (
     base_adapter,
-    langchain_adapters,
-    openai_model_adapter,
+    litellm_adapter,
 )
 
 __all__ = [
     "base_adapter",
-    "langchain_adapters",
-    "openai_model_adapter",
+    "litellm_adapter",
 ]
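
Note on the removed exports: `langchain_adapters` and `openai_model_adapter` are no longer part of `kiln_ai.adapters.model_adapters`, so 0.12.0-style imports of those modules need to target the new LiteLLM adapter instead. A minimal sketch (the class name comes from the new `litellm_adapter.py` shown further below):

# 0.12.0 imports of langchain_adapters / openai_model_adapter no longer resolve in 0.13.2.
from kiln_ai.adapters.model_adapters import litellm_adapter
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter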
kiln_ai/adapters/model_adapters/base_adapter.py
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from typing import Dict, Literal, Tuple
 
 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
+from kiln_ai.adapters.parsers.json_parser import parse_json_string
 from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
 from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
@@ -85,17 +86,6 @@ class BaseAdapter(metaclass=ABCMeta):
         )
         return self._model_provider
 
-    async def invoke_returning_raw(
-        self,
-        input: Dict | str,
-        input_source: DataSource | None = None,
-    ) -> Dict | str:
-        result = await self.invoke(input, input_source)
-        if self.task().output_json_schema is None:
-            return result.output.output
-        else:
-            return json.loads(result.output.output)
-
     async def invoke(
         self,
         input: Dict | str,
@@ -127,6 +117,10 @@ class BaseAdapter(metaclass=ABCMeta):
 
         # validate output
         if self.output_schema is not None:
+            # Parse json to dict if we have structured output
+            if isinstance(parsed_output.output, str):
+                parsed_output.output = parse_json_string(parsed_output.output)
+
             if not isinstance(parsed_output.output, dict):
                 raise RuntimeError(
                     f"structured response is not a dict: {parsed_output.output}"
@@ -138,6 +132,15 @@ class BaseAdapter(metaclass=ABCMeta):
                     f"response is not a string for non-structured task: {parsed_output.output}"
                 )
 
+        # Validate reasoning content is present (if reasoning)
+        if provider.reasoning_capable and (
+            not parsed_output.intermediate_outputs
+            or "reasoning" not in parsed_output.intermediate_outputs
+        ):
+            raise RuntimeError(
+                "Reasoning is required for this model, but no reasoning was returned."
+            )
+
         # Generate the run and output
         run = self.generate_run(input, input_source, parsed_output)
 
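
Migration note on the removed `invoke_returning_raw` helper: callers can reproduce it with `invoke` directly. A minimal sketch based on the removed 0.12.0 body above; the surrounding adapter setup and async context are assumed, not shown in this diff:

# Sketch only: replicates the removed BaseAdapter.invoke_returning_raw at the call site.
# Assumes `adapter` is a configured BaseAdapter subclass and this runs inside an async function.
import json

run = await adapter.invoke(input_data)
raw_output = run.output.output
if adapter.task().output_json_schema is not None:
    raw_output = json.loads(raw_output)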
kiln_ai/adapters/model_adapters/litellm_adapter.py (new file)
@@ -0,0 +1,391 @@
+from typing import Any, Dict
+
+import litellm
+from litellm.types.utils import ChoiceLogprobs, Choices, ModelResponse
+
+import kiln_ai.datamodel as datamodel
+from kiln_ai.adapters.ml_model_list import (
+    KilnModelProvider,
+    ModelProviderName,
+    StructuredOutputMode,
+)
+from kiln_ai.adapters.model_adapters.base_adapter import (
+    COT_FINAL_ANSWER_PROMPT,
+    AdapterConfig,
+    BaseAdapter,
+    RunOutput,
+)
+from kiln_ai.adapters.model_adapters.litellm_config import (
+    LiteLlmConfig,
+)
+from kiln_ai.datamodel import PromptGenerators, PromptId
+from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+class LiteLlmAdapter(BaseAdapter):
+    def __init__(
+        self,
+        config: LiteLlmConfig,
+        kiln_task: datamodel.Task,
+        prompt_id: PromptId | None = None,
+        base_adapter_config: AdapterConfig | None = None,
+    ):
+        self.config = config
+        self._additional_body_options = config.additional_body_options
+        self._api_base = config.base_url
+        self._headers = config.default_headers
+        self._litellm_model_id: str | None = None
+
+        run_config = RunConfig(
+            task=kiln_task,
+            model_name=config.model_name,
+            model_provider_name=config.provider_name,
+            prompt_id=prompt_id or PromptGenerators.SIMPLE,
+        )
+
+        super().__init__(
+            run_config=run_config,
+            config=base_adapter_config,
+        )
+
+    async def _run(self, input: Dict | str) -> RunOutput:
+        provider = self.model_provider()
+        if not provider.model_id:
+            raise ValueError("Model ID is required for OpenAI compatible models")
+
+        intermediate_outputs: dict[str, str] = {}
+        prompt = self.build_prompt()
+        user_msg = self.prompt_builder.build_user_message(input)
+        messages = [
+            {"role": "system", "content": prompt},
+            {"role": "user", "content": user_msg},
+        ]
+
+        run_strategy, cot_prompt = self.run_strategy()
+
+        if run_strategy == "cot_as_message":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_as_message strategy")
+            messages.append({"role": "system", "content": cot_prompt})
+        elif run_strategy == "cot_two_call":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_two_call strategy")
+            messages.append({"role": "system", "content": cot_prompt})
+
+            # First call for chain of thought - No logprobs as only needed for final answer
+            completion_kwargs = await self.build_completion_kwargs(
+                provider, messages, None
+            )
+            cot_response = await litellm.acompletion(**completion_kwargs)
+            if (
+                not isinstance(cot_response, ModelResponse)
+                or not cot_response.choices
+                or len(cot_response.choices) == 0
+                or not isinstance(cot_response.choices[0], Choices)
+            ):
+                raise RuntimeError(
+                    f"Expected ModelResponse with Choices, got {type(cot_response)}."
+                )
+            cot_content = cot_response.choices[0].message.content
+            if cot_content is not None:
+                intermediate_outputs["chain_of_thought"] = cot_content
+
+            messages.extend(
+                [
+                    {"role": "assistant", "content": cot_content or ""},
+                    {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
+                ]
+            )
+
+        # Make the API call using litellm
+        completion_kwargs = await self.build_completion_kwargs(
+            provider, messages, self.base_adapter_config.top_logprobs
+        )
+        response = await litellm.acompletion(**completion_kwargs)
+
+        if not isinstance(response, ModelResponse):
+            raise RuntimeError(f"Expected ModelResponse, got {type(response)}.")
+
+        # Maybe remove this? There is no error attribute on the response object.
+        # # Keeping in typesafe way as we added it for a reason, but should investigate what that was and if it still applies.
+        if hasattr(response, "error") and response.__getattribute__("error"):
+            raise RuntimeError(
+                f"LLM API returned an error: {response.__getattribute__('error')}"
+            )
+
+        if (
+            not response.choices
+            or len(response.choices) == 0
+            or not isinstance(response.choices[0], Choices)
+        ):
+            raise RuntimeError(
+                "No message content returned in the response from LLM API"
+            )
+
+        message = response.choices[0].message
+        logprobs = (
+            response.choices[0].logprobs
+            if hasattr(response.choices[0], "logprobs")
+            and isinstance(response.choices[0].logprobs, ChoiceLogprobs)
+            else None
+        )
+
+        # Check logprobs worked, if requested
+        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
+            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
+
+        # Save reasoning if it exists and was parsed by LiteLLM (or openrouter, or anyone upstream)
+        if hasattr(message, "reasoning_content") and message.reasoning_content:
+            intermediate_outputs["reasoning"] = message.reasoning_content
+
+        # the string content of the response
+        response_content = message.content
+
+        # Fallback: Use args of first tool call to task_response if it exists
+        if (
+            not response_content
+            and hasattr(message, "tool_calls")
+            and message.tool_calls
+        ):
+            tool_call = next(
+                (
+                    tool_call
+                    for tool_call in message.tool_calls
+                    if tool_call.function.name == "task_response"
+                ),
+                None,
+            )
+            if tool_call:
+                response_content = tool_call.function.arguments
+
+        if not isinstance(response_content, str):
+            raise RuntimeError(f"response is not a string: {response_content}")
+
+        return RunOutput(
+            output=response_content,
+            intermediate_outputs=intermediate_outputs,
+            output_logprobs=logprobs,
+        )
+
+    def adapter_name(self) -> str:
+        return "kiln_openai_compatible_adapter"
+
+    async def response_format_options(self) -> dict[str, Any]:
+        # Unstructured if task isn't structured
+        if not self.has_structured_output():
+            return {}
+
+        provider = self.model_provider()
+        match provider.structured_output_mode:
+            case StructuredOutputMode.json_mode:
+                return {"response_format": {"type": "json_object"}}
+            case StructuredOutputMode.json_schema:
+                return self.json_schema_response_format()
+            case StructuredOutputMode.function_calling_weak:
+                return self.tool_call_params(strict=False)
+            case StructuredOutputMode.function_calling:
+                return self.tool_call_params(strict=True)
+            case StructuredOutputMode.json_instructions:
+                # JSON instructions dynamically injected in prompt, not the API response format. Do not ask for json_object (see option below).
+                return {}
+            case StructuredOutputMode.json_custom_instructions:
+                # JSON instructions statically injected in system prompt, not the API response format. Do not ask for json_object (see option above).
+                return {}
+            case StructuredOutputMode.json_instruction_and_object:
+                # We set response_format to json_object and also set json instructions in the prompt
+                return {"response_format": {"type": "json_object"}}
+            case StructuredOutputMode.default:
+                if provider.name == ModelProviderName.ollama:
+                    # Ollama added json_schema to all models: https://ollama.com/blog/structured-outputs
+                    return self.json_schema_response_format()
+                else:
+                    # Default to function calling -- it's older than the other modes. Higher compatibility.
+                    # Strict isn't widely supported yet, so we don't use it by default unless it's OpenAI.
+                    strict = provider.name == ModelProviderName.openai
+                    return self.tool_call_params(strict=strict)
+            case _:
+                raise_exhaustive_enum_error(provider.structured_output_mode)
+
+    def json_schema_response_format(self) -> dict[str, Any]:
+        output_schema = self.task().output_schema()
+        return {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "task_response",
+                    "schema": output_schema,
+                },
+            }
+        }
+
+    def tool_call_params(self, strict: bool) -> dict[str, Any]:
+        # Add additional_properties: false to the schema (OpenAI requires this for some models)
+        output_schema = self.task().output_schema()
+        if not isinstance(output_schema, dict):
+            raise ValueError(
+                "Invalid output schema for this task. Can not use tool calls."
+            )
+        output_schema["additionalProperties"] = False
+
+        function_params = {
+            "name": "task_response",
+            "parameters": output_schema,
+        }
+        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
+        if strict:
+            function_params["strict"] = True
+
+        return {
+            "tools": [
+                {
+                    "type": "function",
+                    "function": function_params,
+                }
+            ],
+            "tool_choice": {
+                "type": "function",
+                "function": {"name": "task_response"},
+            },
+        }
+
+    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
+        # TODO P1: Don't love having this logic here. But it's a usability improvement
+        # so better to keep it than exclude it. Should figure out how I want to isolate
+        # this sort of logic so it's config driven and can be overridden
+
+        extra_body = {}
+        provider_options = {}
+
+        if provider.thinking_level is not None:
+            extra_body["reasoning_effort"] = provider.thinking_level
+
+        if provider.require_openrouter_reasoning:
+            # https://openrouter.ai/docs/use-cases/reasoning-tokens
+            extra_body["reasoning"] = {
+                "exclude": False,
+            }
+
+        if provider.anthropic_extended_thinking:
+            extra_body["thinking"] = {"type": "enabled", "budget_tokens": 4000}
+
+        if provider.r1_openrouter_options:
+            # Require providers that support the reasoning parameter
+            provider_options["require_parameters"] = True
+            # Prefer R1 providers with reasonable perf/quants
+            provider_options["order"] = ["Fireworks", "Together"]
+            # R1 providers with unreasonable quants
+            provider_options["ignore"] = ["DeepInfra"]
+
+        # Only set of this request is to get logprobs.
+        if (
+            provider.logprobs_openrouter_options
+            and self.base_adapter_config.top_logprobs is not None
+        ):
+            # Don't let OpenRouter choose a provider that doesn't support logprobs.
+            provider_options["require_parameters"] = True
+            # DeepInfra silently fails to return logprobs consistently.
+            provider_options["ignore"] = ["DeepInfra"]
+
+        if provider.openrouter_skip_required_parameters:
+            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
+            provider_options["require_parameters"] = False
+
+        if len(provider_options) > 0:
+            extra_body["provider"] = provider_options
+
+        return extra_body
+
+    def litellm_model_id(self) -> str:
+        # The model ID is an interesting combination of format and url endpoint.
+        # It specifics the provider URL/host, but this is overridden if you manually set an api url
+
+        if self._litellm_model_id:
+            return self._litellm_model_id
+
+        provider = self.model_provider()
+        if not provider.model_id:
+            raise ValueError("Model ID is required for OpenAI compatible models")
+
+        litellm_provider_name: str | None = None
+        is_custom = False
+        match provider.name:
+            case ModelProviderName.openrouter:
+                litellm_provider_name = "openrouter"
+            case ModelProviderName.openai:
+                litellm_provider_name = "openai"
+            case ModelProviderName.groq:
+                litellm_provider_name = "groq"
+            case ModelProviderName.anthropic:
+                litellm_provider_name = "anthropic"
+            case ModelProviderName.ollama:
+                # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
+                # This is because we're setting detailed features like response_format=json_schema and want lower level control.
+                is_custom = True
+            case ModelProviderName.gemini_api:
+                litellm_provider_name = "gemini"
+            case ModelProviderName.fireworks_ai:
+                litellm_provider_name = "fireworks_ai"
+            case ModelProviderName.amazon_bedrock:
+                litellm_provider_name = "bedrock"
+            case ModelProviderName.azure_openai:
+                litellm_provider_name = "azure"
+            case ModelProviderName.huggingface:
+                litellm_provider_name = "huggingface"
+            case ModelProviderName.vertex:
+                litellm_provider_name = "vertex_ai"
+            case ModelProviderName.together_ai:
+                litellm_provider_name = "together_ai"
+            case ModelProviderName.openai_compatible:
+                is_custom = True
+            case ModelProviderName.kiln_custom_registry:
+                is_custom = True
+            case ModelProviderName.kiln_fine_tune:
+                is_custom = True
+            case _:
+                raise_exhaustive_enum_error(provider.name)
+
+        if is_custom:
+            if self._api_base is None:
+                raise ValueError(
+                    "Explicit Base URL is required for OpenAI compatible APIs (custom models, ollama, fine tunes, and custom registry models)"
+                )
+            # Use openai as it's only used for format, not url
+            litellm_provider_name = "openai"
+
+        # Sholdn't be possible but keep type checker happy
+        if litellm_provider_name is None:
+            raise ValueError(
+                f"Provider name could not lookup valid litellm provider ID {provider.model_id}"
+            )
+
+        self._litellm_model_id = litellm_provider_name + "/" + provider.model_id
+        return self._litellm_model_id
+
+    async def build_completion_kwargs(
+        self,
+        provider: KilnModelProvider,
+        messages: list[dict[str, Any]],
+        top_logprobs: int | None,
+    ) -> dict[str, Any]:
+        extra_body = self.build_extra_body(provider)
+
+        # Merge all parameters into a single kwargs dict for litellm
+        completion_kwargs = {
+            "model": self.litellm_model_id(),
+            "messages": messages,
+            "api_base": self._api_base,
+            "headers": self._headers,
+            **extra_body,
+            **self._additional_body_options,
+        }
+
+        # Response format: json_schema, json_instructions, json_mode, function_calling, etc
+        response_format_options = await self.response_format_options()
+        completion_kwargs.update(response_format_options)
+
+        if top_logprobs is not None:
+            completion_kwargs["logprobs"] = True
+            completion_kwargs["top_logprobs"] = top_logprobs
+
+        return completion_kwargs
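
For orientation, a hedged usage sketch of the new adapter. The task fields, model name, and the api_key option are illustrative assumptions, not values from this diff; the `LiteLlmConfig` and `LiteLlmAdapter` signatures come from the code above:

# Hedged sketch: model/task values and the api_key option are illustrative assumptions.
import asyncio

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig

config = LiteLlmConfig(
    model_name="gpt_4o_mini",                        # Kiln model name (assumed example)
    provider_name="openai",                          # must match a ModelProviderName value
    additional_body_options={"api_key": "sk-..."},   # merged into every litellm.acompletion call
)
task = datamodel.Task(name="Summarize", instruction="Summarize the input text.")  # assumed fields
adapter = LiteLlmAdapter(config=config, kiln_task=task)

# invoke() builds the prompt, calls litellm.acompletion via build_completion_kwargs, and validates the output.
run = asyncio.run(adapter.invoke("Kiln 0.13.2 swaps the LangChain/OpenAI adapters for LiteLLM."))
print(run.output.output)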
kiln_ai/adapters/model_adapters/litellm_config.py (new file)
@@ -0,0 +1,13 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class LiteLlmConfig:
+    model_name: str
+    provider_name: str
+    # If set, over rides the provider-name based URL from litellm
+    base_url: str | None = None
+    # Headers to send with every request
+    default_headers: dict[str, str] | None = None
+    # Extra body to send with every request
+    additional_body_options: dict[str, str] = field(default_factory=dict)
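
For custom or OpenAI-compatible endpoints (including Ollama and fine-tuned models, per `litellm_model_id()` above), `base_url` must be set explicitly. A hedged sketch; the localhost URL and model name are assumptions for illustration:

# Hedged sketch: the Ollama URL and model name are illustrative assumptions.
from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig

ollama_config = LiteLlmConfig(
    model_name="llama_3_1_8b",              # Kiln model name (assumed example)
    provider_name="ollama",                 # treated as custom: litellm model id becomes "openai/<model_id>"
    base_url="http://localhost:11434/v1",   # Ollama's OpenAI-compatible endpoint (assumed default)
    default_headers=None,
    additional_body_options={},             # forwarded verbatim to litellm.acompletion
)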