kiln-ai 0.8.1__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (57)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +77 -5
  3. kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  7. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  8. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  9. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  10. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
  11. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
  12. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  13. kiln_ai/adapters/ml_model_list.py +323 -94
  14. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  15. kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
  16. kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
  17. kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
  18. kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
  19. kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
  20. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
  21. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
  22. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
  23. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
  24. kiln_ai/adapters/parsers/__init__.py +10 -0
  25. kiln_ai/adapters/parsers/base_parser.py +12 -0
  26. kiln_ai/adapters/parsers/json_parser.py +37 -0
  27. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  28. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  29. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  30. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  31. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  32. kiln_ai/adapters/prompt_builders.py +126 -20
  33. kiln_ai/adapters/provider_tools.py +91 -36
  34. kiln_ai/adapters/repair/repair_task.py +17 -6
  35. kiln_ai/adapters/repair/test_repair_task.py +4 -4
  36. kiln_ai/adapters/run_output.py +8 -0
  37. kiln_ai/adapters/test_adapter_registry.py +177 -0
  38. kiln_ai/adapters/test_generate_docs.py +69 -0
  39. kiln_ai/adapters/test_prompt_adaptors.py +8 -4
  40. kiln_ai/adapters/test_prompt_builders.py +190 -29
  41. kiln_ai/adapters/test_provider_tools.py +268 -46
  42. kiln_ai/datamodel/__init__.py +193 -12
  43. kiln_ai/datamodel/basemodel.py +31 -11
  44. kiln_ai/datamodel/json_schema.py +8 -3
  45. kiln_ai/datamodel/model_cache.py +8 -3
  46. kiln_ai/datamodel/test_basemodel.py +81 -2
  47. kiln_ai/datamodel/test_dataset_split.py +100 -3
  48. kiln_ai/datamodel/test_example_models.py +25 -4
  49. kiln_ai/datamodel/test_model_cache.py +24 -0
  50. kiln_ai/datamodel/test_model_perf.py +125 -0
  51. kiln_ai/datamodel/test_models.py +129 -0
  52. kiln_ai/utils/exhaustive_error.py +6 -0
  53. {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
  54. kiln_ai-0.11.1.dist-info/RECORD +76 -0
  55. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  56. {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
  57. {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
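The moves above relocate the adapter implementations into a new kiln_ai.adapters.model_adapters package and add a kiln_ai.adapters.parsers package. As a quick orientation, the import paths used by the new modules look like this (module paths and names are taken from the diffs below):

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
from kiln_ai.adapters.model_adapters.openai_compatible_config import OpenAICompatibleConfig
from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
from kiln_ai.adapters.parsers.json_parser import parse_json_string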
kiln_ai/adapters/model_adapters/openai_model_adapter.py
@@ -0,0 +1,246 @@
+ from typing import Any, Dict
+
+ from openai import AsyncOpenAI
+ from openai.types.chat import (
+     ChatCompletion,
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionSystemMessageParam,
+     ChatCompletionUserMessageParam,
+ )
+
+ import kiln_ai.datamodel as datamodel
+ from kiln_ai.adapters.ml_model_list import StructuredOutputMode
+ from kiln_ai.adapters.model_adapters.base_adapter import (
+     COT_FINAL_ANSWER_PROMPT,
+     AdapterInfo,
+     BaseAdapter,
+     BasePromptBuilder,
+     RunOutput,
+ )
+ from kiln_ai.adapters.model_adapters.openai_compatible_config import (
+     OpenAICompatibleConfig,
+ )
+ from kiln_ai.adapters.parsers.json_parser import parse_json_string
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+ class OpenAICompatibleAdapter(BaseAdapter):
+     def __init__(
+         self,
+         config: OpenAICompatibleConfig,
+         kiln_task: datamodel.Task,
+         prompt_builder: BasePromptBuilder | None = None,
+         tags: list[str] | None = None,
+     ):
+         self.config = config
+         self.client = AsyncOpenAI(
+             api_key=config.api_key,
+             base_url=config.base_url,
+             default_headers=config.default_headers,
+         )
+
+         super().__init__(
+             kiln_task,
+             model_name=config.model_name,
+             model_provider_name=config.provider_name,
+             prompt_builder=prompt_builder,
+             tags=tags,
+         )
+
+     async def _run(self, input: Dict | str) -> RunOutput:
+         provider = self.model_provider()
+         intermediate_outputs: dict[str, str] = {}
+         prompt = self.build_prompt()
+         user_msg = self.prompt_builder.build_user_message(input)
+         messages = [
+             ChatCompletionSystemMessageParam(role="system", content=prompt),
+             ChatCompletionUserMessageParam(role="user", content=user_msg),
+         ]
+
+         run_strategy, cot_prompt = self.run_strategy()
+
+         if run_strategy == "cot_as_message":
+             if not cot_prompt:
+                 raise ValueError("cot_prompt is required for cot_as_message strategy")
+             messages.append(
+                 ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
+             )
+         elif run_strategy == "cot_two_call":
+             if not cot_prompt:
+                 raise ValueError("cot_prompt is required for cot_two_call strategy")
+             messages.append(
+                 ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
+             )
+
+             # First call for chain of thought
+             cot_response = await self.client.chat.completions.create(
+                 model=provider.provider_options["model"],
+                 messages=messages,
+             )
+             cot_content = cot_response.choices[0].message.content
+             if cot_content is not None:
+                 intermediate_outputs["chain_of_thought"] = cot_content
+
+             messages.extend(
+                 [
+                     ChatCompletionAssistantMessageParam(
+                         role="assistant", content=cot_content
+                     ),
+                     ChatCompletionUserMessageParam(
+                         role="user",
+                         content=COT_FINAL_ANSWER_PROMPT,
+                     ),
+                 ]
+             )
+
+         # OpenRouter specific options for reasoning models
+         extra_body = {}
+         require_or_reasoning = (
+             self.config.openrouter_style_reasoning and provider.reasoning_capable
+         )
+         if require_or_reasoning:
+             extra_body["include_reasoning"] = True
+             # Filter to providers that support the reasoning parameter
+             extra_body["provider"] = {
+                 "require_parameters": True,
+                 # Ugly to have these here, but big range of quality of R1 providers
+                 "order": ["Fireworks", "Together"],
+                 # fp8 quants are awful
+                 "ignore": ["DeepInfra"],
+             }
+
+         # Main completion call
+         response_format_options = await self.response_format_options()
+         response = await self.client.chat.completions.create(
+             model=provider.provider_options["model"],
+             messages=messages,
+             extra_body=extra_body,
+             **response_format_options,
+         )
+
+         if not isinstance(response, ChatCompletion):
+             raise RuntimeError(
+                 f"Expected ChatCompletion response, got {type(response)}."
+             )
+
+         if hasattr(response, "error") and response.error:  # pyright: ignore
+             raise RuntimeError(
+                 f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}.\nError: {response.error}"  # pyright: ignore
+             )
+         if not response.choices or len(response.choices) == 0:
+             raise RuntimeError(
+                 "No message content returned in the response from OpenAI compatible API"
+             )
+
+         message = response.choices[0].message
+
+         # Save reasoning if it exists (OpenRouter specific format)
+         if require_or_reasoning:
+             if (
+                 hasattr(message, "reasoning") and message.reasoning  # pyright: ignore
+             ):
+                 intermediate_outputs["reasoning"] = message.reasoning  # pyright: ignore
+             else:
+                 raise RuntimeError(
+                     "Reasoning is required for this model, but no reasoning was returned from OpenRouter."
+                 )
+
+         # the string content of the response
+         response_content = message.content
+
+         # Fallback: Use args of first tool call to task_response if it exists
+         if not response_content and message.tool_calls:
+             tool_call = next(
+                 (
+                     tool_call
+                     for tool_call in message.tool_calls
+                     if tool_call.function.name == "task_response"
+                 ),
+                 None,
+             )
+             if tool_call:
+                 response_content = tool_call.function.arguments
+
+         if not isinstance(response_content, str):
+             raise RuntimeError(f"response is not a string: {response_content}")
+
+         if self.has_structured_output():
+             structured_response = parse_json_string(response_content)
+             return RunOutput(
+                 output=structured_response,
+                 intermediate_outputs=intermediate_outputs,
+             )
+
+         return RunOutput(
+             output=response_content,
+             intermediate_outputs=intermediate_outputs,
+         )
+
+     def adapter_info(self) -> AdapterInfo:
+         return AdapterInfo(
+             model_name=self.model_name,
+             model_provider=self.model_provider_name,
+             adapter_name="kiln_openai_compatible_adapter",
+             prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
+             prompt_id=self.prompt_builder.prompt_id(),
+         )
+
+     async def response_format_options(self) -> dict[str, Any]:
+         # Unstructured if task isn't structured
+         if not self.has_structured_output():
+             return {}
+
+         provider = self.model_provider()
+         match provider.structured_output_mode:
+             case StructuredOutputMode.json_mode:
+                 return {"response_format": {"type": "json_object"}}
+             case StructuredOutputMode.json_schema:
+                 output_schema = self.kiln_task.output_schema()
+                 return {
+                     "response_format": {
+                         "type": "json_schema",
+                         "json_schema": {
+                             "name": "task_response",
+                             "schema": output_schema,
+                         },
+                     }
+                 }
+             case StructuredOutputMode.function_calling:
+                 return self.tool_call_params()
+             case StructuredOutputMode.json_instructions:
+                 # JSON done via instructions in prompt, not the API response format. Do not ask for json_object (see option below).
+                 return {}
+             case StructuredOutputMode.json_instruction_and_object:
+                 # We set response_format to json_object and also set json instructions in the prompt
+                 return {"response_format": {"type": "json_object"}}
+             case StructuredOutputMode.default:
+                 # Default to function calling -- it's older than the other modes. Higher compatibility.
+                 return self.tool_call_params()
+             case _:
+                 raise_exhaustive_enum_error(provider.structured_output_mode)
+
+     def tool_call_params(self) -> dict[str, Any]:
+         # Add additional_properties: false to the schema (OpenAI requires this for some models)
+         output_schema = self.kiln_task.output_schema()
+         if not isinstance(output_schema, dict):
+             raise ValueError(
+                 "Invalid output schema for this task. Can not use tool calls."
+             )
+         output_schema["additionalProperties"] = False
+
+         return {
+             "tools": [
+                 {
+                     "type": "function",
+                     "function": {
+                         "name": "task_response",
+                         "parameters": output_schema,
+                         "strict": True,
+                     },
+                 }
+             ],
+             "tool_choice": {
+                 "type": "function",
+                 "function": {"name": "task_response"},
+             },
+         }
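For context, a minimal usage sketch of the new adapter follows. It is illustrative only: the OpenAICompatibleConfig field names are inferred from the attributes the adapter reads above (api_key, base_url, default_headers, model_name, provider_name), the task and credential values are hypothetical, and the public method that drives _run lives on BaseAdapter and is not part of this diff.

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.model_adapters.openai_compatible_config import OpenAICompatibleConfig
from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter

# Hypothetical task and config values, for illustration only
task = datamodel.Task(name="summarize", instruction="Summarize the input text.")
config = OpenAICompatibleConfig(
    api_key="sk-...",  # assumed field; read as config.api_key above
    base_url="https://openrouter.ai/api/v1",  # assumed field; read as config.base_url above
    model_name="example-provider/example-model",  # assumed field; read as config.model_name above
    provider_name="openrouter",  # assumed field; read as config.provider_name above
)

adapter = OpenAICompatibleAdapter(config=config, kiln_task=task)
# adapter._run(...) is the coroutine shown above; in normal use it is invoked
# through BaseAdapter's public entry point, which is outside this diff.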
kiln_ai/adapters/model_adapters/test_base_adapter.py
@@ -0,0 +1,190 @@
+ from unittest.mock import MagicMock, patch
+
+ import pytest
+
+ from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
+ from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
+ from kiln_ai.datamodel import Task
+
+
+ class MockAdapter(BaseAdapter):
+     """Concrete implementation of BaseAdapter for testing"""
+
+     async def _run(self, input):
+         return None
+
+     def adapter_info(self) -> AdapterInfo:
+         return AdapterInfo(
+             adapter_name="test",
+             model_name=self.model_name,
+             model_provider=self.model_provider_name,
+             prompt_builder_name="test",
+         )
+
+
+ @pytest.fixture
+ def mock_provider():
+     return KilnModelProvider(
+         name="openai",
+     )
+
+
+ @pytest.fixture
+ def base_task():
+     return Task(name="test_task", instruction="test_instruction")
+
+
+ @pytest.fixture
+ def adapter(base_task):
+     return MockAdapter(
+         kiln_task=base_task,
+         model_name="test_model",
+         model_provider_name="test_provider",
+     )
+
+
+ async def test_model_provider_uses_cache(adapter, mock_provider):
+     """Test that cached provider is returned if it exists"""
+     # Set up cached provider
+     adapter._model_provider = mock_provider
+
+     # Mock the provider loader to ensure it's not called
+     with patch(
+         "kiln_ai.adapters.model_adapters.base_adapter.kiln_model_provider_from"
+     ) as mock_loader:
+         provider = adapter.model_provider()
+
+         assert provider == mock_provider
+         mock_loader.assert_not_called()
+
+
+ async def test_model_provider_loads_and_caches(adapter, mock_provider):
+     """Test that provider is loaded and cached if not present"""
+     # Ensure no cached provider
+     adapter._model_provider = None
+
+     # Mock the provider loader
+     with patch(
+         "kiln_ai.adapters.model_adapters.base_adapter.kiln_model_provider_from"
+     ) as mock_loader:
+         mock_loader.return_value = mock_provider
+
+         # First call should load and cache
+         provider1 = adapter.model_provider()
+         assert provider1 == mock_provider
+         mock_loader.assert_called_once_with("test_model", "test_provider")
+
+         # Second call should use cache
+         mock_loader.reset_mock()
+         provider2 = adapter.model_provider()
+         assert provider2 == mock_provider
+         mock_loader.assert_not_called()
+
+
+ async def test_model_provider_missing_names(base_task):
+     """Test error when model or provider name is missing"""
+     # Test with missing model name
+     adapter = MockAdapter(
+         kiln_task=base_task, model_name="", model_provider_name="test_provider"
+     )
+     with pytest.raises(
+         ValueError, match="model_name and model_provider_name must be provided"
+     ):
+         await adapter.model_provider()
+
+     # Test with missing provider name
+     adapter = MockAdapter(
+         kiln_task=base_task, model_name="test_model", model_provider_name=""
+     )
+     with pytest.raises(
+         ValueError, match="model_name and model_provider_name must be provided"
+     ):
+         await adapter.model_provider()
+
+
+ async def test_model_provider_not_found(adapter):
+     """Test error when provider loader returns None"""
+     # Mock the provider loader to return None
+     with patch(
+         "kiln_ai.adapters.model_adapters.base_adapter.kiln_model_provider_from"
+     ) as mock_loader:
+         mock_loader.return_value = None
+
+         with pytest.raises(
+             ValueError,
+             match="model_provider_name test_provider not found for model test_model",
+         ):
+             await adapter.model_provider()
+
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+     "output_schema,structured_output_mode,expected_json_instructions",
+     [
+         (False, StructuredOutputMode.json_instructions, False),
+         (True, StructuredOutputMode.json_instructions, True),
+         (False, StructuredOutputMode.json_instruction_and_object, False),
+         (True, StructuredOutputMode.json_instruction_and_object, True),
+         (True, StructuredOutputMode.json_mode, False),
+         (False, StructuredOutputMode.json_mode, False),
+     ],
+ )
+ async def test_prompt_builder_json_instructions(
+     base_task,
+     adapter,
+     output_schema,
+     structured_output_mode,
+     expected_json_instructions,
+ ):
+     """Test that prompt builder is called with correct include_json_instructions value"""
+     # Mock the prompt builder and has_structured_output method
+     mock_prompt_builder = MagicMock()
+     adapter.prompt_builder = mock_prompt_builder
+     adapter.model_provider_name = "openai"
+     adapter.has_structured_output = MagicMock(return_value=output_schema)
+
+     # provider mock
+     provider = MagicMock()
+     provider.structured_output_mode = structured_output_mode
+     adapter.model_provider = MagicMock(return_value=provider)
+
+     # Test
+     adapter.build_prompt()
+     mock_prompt_builder.build_prompt.assert_called_with(
+         include_json_instructions=expected_json_instructions
+     )
+
+
+ @pytest.mark.parametrize(
+     "cot_prompt,has_structured_output,reasoning_capable,expected",
+     [
+         # COT and normal LLM
+         ("think carefully", False, False, ("cot_two_call", "think carefully")),
+         # Structured output with thinking-capable LLM
+         ("think carefully", True, True, ("cot_as_message", "think carefully")),
+         # Structured output with normal LLM
+         ("think carefully", True, False, ("cot_two_call", "think carefully")),
+         # Basic cases - no COT
+         (None, True, True, ("basic", None)),
+         (None, False, False, ("basic", None)),
+         (None, True, False, ("basic", None)),
+         (None, False, True, ("basic", None)),
+         # Edge case - COT prompt exists but structured output is False and reasoning_capable is True
+         ("think carefully", False, True, ("cot_as_message", "think carefully")),
+     ],
+ )
+ async def test_run_strategy(
+     adapter, cot_prompt, has_structured_output, reasoning_capable, expected
+ ):
+     """Test that run_strategy returns correct strategy based on conditions"""
+     # Mock dependencies
+     adapter.prompt_builder.chain_of_thought_prompt = MagicMock(return_value=cot_prompt)
+     adapter.has_structured_output = MagicMock(return_value=has_structured_output)
+
+     provider = MagicMock()
+     provider.reasoning_capable = reasoning_capable
+     adapter.model_provider = MagicMock(return_value=provider)
+
+     # Test
+     result = adapter.run_strategy()
+     assert result == expected