kiln-ai 0.12.0__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

Files changed (49)
  1. kiln_ai/adapters/__init__.py +4 -0
  2. kiln_ai/adapters/adapter_registry.py +157 -28
  3. kiln_ai/adapters/eval/__init__.py +28 -0
  4. kiln_ai/adapters/eval/eval_runner.py +4 -1
  5. kiln_ai/adapters/eval/g_eval.py +19 -3
  6. kiln_ai/adapters/eval/test_base_eval.py +1 -0
  7. kiln_ai/adapters/eval/test_eval_runner.py +1 -0
  8. kiln_ai/adapters/eval/test_g_eval.py +13 -7
  9. kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
  10. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  11. kiln_ai/adapters/fine_tune/fireworks_finetune.py +8 -1
  12. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +19 -0
  13. kiln_ai/adapters/fine_tune/test_together_finetune.py +533 -0
  14. kiln_ai/adapters/fine_tune/together_finetune.py +327 -0
  15. kiln_ai/adapters/ml_model_list.py +638 -155
  16. kiln_ai/adapters/model_adapters/__init__.py +2 -4
  17. kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
  18. kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
  19. kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
  20. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
  21. kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
  22. kiln_ai/adapters/ollama_tools.py +3 -2
  23. kiln_ai/adapters/parsers/r1_parser.py +19 -14
  24. kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
  25. kiln_ai/adapters/provider_tools.py +52 -60
  26. kiln_ai/adapters/repair/test_repair_task.py +3 -3
  27. kiln_ai/adapters/run_output.py +1 -1
  28. kiln_ai/adapters/test_adapter_registry.py +17 -20
  29. kiln_ai/adapters/test_generate_docs.py +2 -2
  30. kiln_ai/adapters/test_prompt_adaptors.py +30 -19
  31. kiln_ai/adapters/test_provider_tools.py +27 -82
  32. kiln_ai/datamodel/basemodel.py +2 -0
  33. kiln_ai/datamodel/datamodel_enums.py +2 -0
  34. kiln_ai/datamodel/json_schema.py +1 -1
  35. kiln_ai/datamodel/task_output.py +13 -6
  36. kiln_ai/datamodel/test_basemodel.py +9 -0
  37. kiln_ai/datamodel/test_datasource.py +19 -0
  38. kiln_ai/utils/config.py +46 -0
  39. kiln_ai/utils/dataset_import.py +232 -0
  40. kiln_ai/utils/test_dataset_import.py +596 -0
  41. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/METADATA +51 -7
  42. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/RECORD +44 -41
  43. kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
  44. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
  45. kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
  46. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
  47. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
  48. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/WHEEL +0 -0
  49. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/licenses/LICENSE.txt +0 -0
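
The headline change in the file list above is the adapter layer: the LangChain and OpenAI-compatible adapters (langchain_adapters.py, openai_model_adapter.py and their tests) are removed, and a single LiteLLM-based adapter (litellm_adapter.py, litellm_config.py, test_litellm_adapter.py) replaces them, alongside new Together fine-tuning support and a dataset import utility. As a rough, hedged illustration of the direction only, the sketch below shows the kind of provider-agnostic call the public litellm library exposes; the model id, messages, and environment assumptions are illustrative and are not Kiln's actual adapter API or configuration.

# Minimal sketch, assuming only the public litellm API and an OPENAI_API_KEY
# in the environment; Kiln's LiteLlmAdapter presumably wraps calls of this
# shape behind its own config and run-config objects.
import asyncio

import litellm


async def main() -> None:
    # One call shape covers many providers; the provider is chosen by the
    # model string prefix (e.g. "openrouter/..." or "ollama/...").
    response = await litellm.acompletion(
        model="gpt-4o-mini",  # illustrative model id
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello."},
        ],
    )
    print(response.choices[0].message.content)


asyncio.run(main())
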
--- a/kiln_ai/adapters/model_adapters/openai_model_adapter.py
+++ /dev/null
@@ -1,289 +0,0 @@
-from typing import Any, Dict
-
-from openai import AsyncOpenAI
-from openai.types.chat import (
-    ChatCompletion,
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionUserMessageParam,
-)
-
-import kiln_ai.datamodel as datamodel
-from kiln_ai.adapters.ml_model_list import (
-    KilnModelProvider,
-    ModelProviderName,
-    StructuredOutputMode,
-)
-from kiln_ai.adapters.model_adapters.base_adapter import (
-    COT_FINAL_ANSWER_PROMPT,
-    AdapterConfig,
-    BaseAdapter,
-    RunOutput,
-)
-from kiln_ai.adapters.model_adapters.openai_compatible_config import (
-    OpenAICompatibleConfig,
-)
-from kiln_ai.adapters.parsers.json_parser import parse_json_string
-from kiln_ai.datamodel import PromptGenerators, PromptId
-from kiln_ai.datamodel.task import RunConfig
-from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
-
-
-class OpenAICompatibleAdapter(BaseAdapter):
-    def __init__(
-        self,
-        config: OpenAICompatibleConfig,
-        kiln_task: datamodel.Task,
-        prompt_id: PromptId | None = None,
-        base_adapter_config: AdapterConfig | None = None,
-    ):
-        self.config = config
-        self.client = AsyncOpenAI(
-            api_key=config.api_key,
-            base_url=config.base_url,
-            default_headers=config.default_headers,
-        )
-
-        run_config = RunConfig(
-            task=kiln_task,
-            model_name=config.model_name,
-            model_provider_name=config.provider_name,
-            prompt_id=prompt_id or PromptGenerators.SIMPLE,
-        )
-
-        super().__init__(
-            run_config=run_config,
-            config=base_adapter_config,
-        )
-
-    async def _run(self, input: Dict | str) -> RunOutput:
-        provider = self.model_provider()
-        intermediate_outputs: dict[str, str] = {}
-        prompt = self.build_prompt()
-        user_msg = self.prompt_builder.build_user_message(input)
-        messages = [
-            ChatCompletionSystemMessageParam(role="system", content=prompt),
-            ChatCompletionUserMessageParam(role="user", content=user_msg),
-        ]
-
-        run_strategy, cot_prompt = self.run_strategy()
-
-        if run_strategy == "cot_as_message":
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_as_message strategy")
-            messages.append(
-                ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
-            )
-        elif run_strategy == "cot_two_call":
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_two_call strategy")
-            messages.append(
-                ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
-            )
-
-            # First call for chain of thought
-            cot_response = await self.client.chat.completions.create(
-                model=provider.provider_options["model"],
-                messages=messages,
-            )
-            cot_content = cot_response.choices[0].message.content
-            if cot_content is not None:
-                intermediate_outputs["chain_of_thought"] = cot_content
-
-            messages.extend(
-                [
-                    ChatCompletionAssistantMessageParam(
-                        role="assistant", content=cot_content
-                    ),
-                    ChatCompletionUserMessageParam(
-                        role="user",
-                        content=COT_FINAL_ANSWER_PROMPT,
-                    ),
-                ]
-            )
-
-        # Build custom request params based on model provider
-        extra_body = self.build_extra_body(provider)
-
-        # Main completion call
-        response_format_options = await self.response_format_options()
-        response = await self.client.chat.completions.create(
-            model=provider.provider_options["model"],
-            messages=messages,
-            extra_body=extra_body,
-            logprobs=self.base_adapter_config.top_logprobs is not None,
-            top_logprobs=self.base_adapter_config.top_logprobs,
-            **response_format_options,
-        )
-
-        if not isinstance(response, ChatCompletion):
-            raise RuntimeError(
-                f"Expected ChatCompletion response, got {type(response)}."
-            )
-
-        if hasattr(response, "error") and response.error:  # pyright: ignore
-            raise RuntimeError(
-                f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}.\nError: {response.error}"  # pyright: ignore
-            )
-        if not response.choices or len(response.choices) == 0:
-            raise RuntimeError(
-                "No message content returned in the response from OpenAI compatible API"
-            )
-
-        message = response.choices[0].message
-        logprobs = response.choices[0].logprobs
-
-        # Check logprobs worked, if requested
-        if self.base_adapter_config.top_logprobs is not None and logprobs is None:
-            raise RuntimeError("Logprobs were required, but no logprobs were returned.")
-
-        # Save reasoning if it exists (OpenRouter specific api response field)
-        if provider.require_openrouter_reasoning:
-            if (
-                hasattr(message, "reasoning") and message.reasoning  # pyright: ignore
-            ):
-                intermediate_outputs["reasoning"] = message.reasoning  # pyright: ignore
-            else:
-                raise RuntimeError(
-                    "Reasoning is required for this model, but no reasoning was returned from OpenRouter."
-                )
-
-        # the string content of the response
-        response_content = message.content
-
-        # Fallback: Use args of first tool call to task_response if it exists
-        if not response_content and message.tool_calls:
-            tool_call = next(
-                (
-                    tool_call
-                    for tool_call in message.tool_calls
-                    if tool_call.function.name == "task_response"
-                ),
-                None,
-            )
-            if tool_call:
-                response_content = tool_call.function.arguments
-
-        if not isinstance(response_content, str):
-            raise RuntimeError(f"response is not a string: {response_content}")
-
-        # Parse to dict if we have structured output
-        output: Dict | str = response_content
-        if self.has_structured_output():
-            output = parse_json_string(response_content)
-
-        return RunOutput(
-            output=output,
-            intermediate_outputs=intermediate_outputs,
-            output_logprobs=logprobs,
-        )
-
-    def adapter_name(self) -> str:
-        return "kiln_openai_compatible_adapter"
-
-    async def response_format_options(self) -> dict[str, Any]:
-        # Unstructured if task isn't structured
-        if not self.has_structured_output():
-            return {}
-
-        provider = self.model_provider()
-        match provider.structured_output_mode:
-            case StructuredOutputMode.json_mode:
-                return {"response_format": {"type": "json_object"}}
-            case StructuredOutputMode.json_schema:
-                output_schema = self.task().output_schema()
-                return {
-                    "response_format": {
-                        "type": "json_schema",
-                        "json_schema": {
-                            "name": "task_response",
-                            "schema": output_schema,
-                        },
-                    }
-                }
-            case StructuredOutputMode.function_calling_weak:
-                return self.tool_call_params(strict=False)
-            case StructuredOutputMode.function_calling:
-                return self.tool_call_params(strict=True)
-            case StructuredOutputMode.json_instructions:
-                # JSON done via instructions in prompt, not the API response format. Do not ask for json_object (see option below).
-                return {}
-            case StructuredOutputMode.json_instruction_and_object:
-                # We set response_format to json_object and also set json instructions in the prompt
-                return {"response_format": {"type": "json_object"}}
-            case StructuredOutputMode.default:
-                # Default to function calling -- it's older than the other modes. Higher compatibility.
-                return self.tool_call_params(strict=True)
-            case _:
-                raise_exhaustive_enum_error(provider.structured_output_mode)
-
-    def tool_call_params(self, strict: bool) -> dict[str, Any]:
-        # Add additional_properties: false to the schema (OpenAI requires this for some models)
-        output_schema = self.task().output_schema()
-        if not isinstance(output_schema, dict):
-            raise ValueError(
-                "Invalid output schema for this task. Can not use tool calls."
-            )
-        output_schema["additionalProperties"] = False
-
-        function_params = {
-            "name": "task_response",
-            "parameters": output_schema,
-        }
-        # This should be on, but we allow setting function_calling_weak for APIs that don't support it.
-        if strict:
-            function_params["strict"] = True
-
-        return {
-            "tools": [
-                {
-                    "type": "function",
-                    "function": function_params,
-                }
-            ],
-            "tool_choice": {
-                "type": "function",
-                "function": {"name": "task_response"},
-            },
-        }
-
-    def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
-        # TODO P1: Don't love having this logic here. But it's a usability improvement
-        # so better to keep it than exclude it. Should figure out how I want to isolate
-        # this sort of logic so it's config driven and can be overridden
-
-        extra_body = {}
-        provider_options = {}
-
-        if provider.require_openrouter_reasoning:
-            # https://openrouter.ai/docs/use-cases/reasoning-tokens
-            extra_body["reasoning"] = {
-                "exclude": False,
-            }
-
-        if provider.r1_openrouter_options:
-            # Require providers that support the reasoning parameter
-            provider_options["require_parameters"] = True
-            # Prefer R1 providers with reasonable perf/quants
-            provider_options["order"] = ["Fireworks", "Together"]
-            # R1 providers with unreasonable quants
-            provider_options["ignore"] = ["DeepInfra"]
-
-        # Only set of this request is to get logprobs.
-        if (
-            provider.logprobs_openrouter_options
-            and self.base_adapter_config.top_logprobs is not None
-        ):
-            # Don't let OpenRouter choose a provider that doesn't support logprobs.
-            provider_options["require_parameters"] = True
-            # DeepInfra silently fails to return logprobs consistently.
-            provider_options["ignore"] = ["DeepInfra"]
-
-        if provider.openrouter_skip_required_parameters:
-            # Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
-            provider_options["require_parameters"] = False
-
-        if len(provider_options) > 0:
-            extra_body["provider"] = provider_options
-
-        return extra_body
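
For context on what the deleted adapter was doing: in the function-calling modes, tool_call_params produced plain OpenAI chat-completion keyword arguments that forced the model to answer through a task_response tool, and _run then read the structured answer back out of the tool call's arguments. The standalone sketch below reproduces that request shape against the OpenAI SDK; the schema and model id are placeholders for illustration, not values taken from Kiln.

# Sketch only: mirrors the request shape built by the removed
# tool_call_params(strict=True), with a placeholder schema and model id.
from openai import AsyncOpenAI

client = AsyncOpenAI()

output_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
    "additionalProperties": False,  # the adapter forced this flag on the schema
}

tool_call_options = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "task_response",
                "parameters": output_schema,
                "strict": True,
            },
        }
    ],
    "tool_choice": {"type": "function", "function": {"name": "task_response"}},
}


async def run() -> str:
    response = await client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model id
        messages=[{"role": "user", "content": "Answer with one word: hello?"}],
        **tool_call_options,
    )
    # The structured answer comes back as the tool call's JSON arguments,
    # which the removed adapter parsed with parse_json_string.
    return response.choices[0].message.tool_calls[0].function.arguments
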
--- a/kiln_ai/adapters/model_adapters/test_langchain_adapter.py
+++ /dev/null
@@ -1,343 +0,0 @@
-import os
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from langchain_aws import ChatBedrockConverse
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
-from langchain_fireworks import ChatFireworks
-from langchain_groq import ChatGroq
-from langchain_ollama import ChatOllama
-
-from kiln_ai.adapters.ml_model_list import (
-    KilnModelProvider,
-    ModelProviderName,
-    StructuredOutputMode,
-)
-from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
-from kiln_ai.adapters.model_adapters.langchain_adapters import (
-    LangchainAdapter,
-    langchain_model_from_provider,
-)
-from kiln_ai.adapters.test_prompt_adaptors import build_test_task
-from kiln_ai.datamodel.task import RunConfig
-
-
-@pytest.fixture
-def mock_adapter(tmp_path):
-    return LangchainAdapter(
-        kiln_task=build_test_task(tmp_path),
-        model_name="llama_3_1_8b",
-        provider="ollama",
-    )
-
-
-def test_langchain_adapter_munge_response(mock_adapter):
-    # Mistral Large tool calling format is a bit different
-    response = {
-        "name": "task_response",
-        "arguments": {
-            "setup": "Why did the cow join a band?",
-            "punchline": "Because she wanted to be a moo-sician!",
-        },
-    }
-    munged = mock_adapter._munge_response(response)
-    assert munged["setup"] == "Why did the cow join a band?"
-    assert munged["punchline"] == "Because she wanted to be a moo-sician!"
-
-    # non mistral format should continue to work
-    munged = mock_adapter._munge_response(response["arguments"])
-    assert munged["setup"] == "Why did the cow join a band?"
-    assert munged["punchline"] == "Because she wanted to be a moo-sician!"
-
-
-def test_langchain_adapter_infer_model_name(tmp_path):
-    task = build_test_task(tmp_path)
-    custom = ChatGroq(model="llama-3.1-8b-instant", groq_api_key="test")
-
-    lca = LangchainAdapter(kiln_task=task, custom_model=custom)
-
-    assert lca.run_config.model_name == "custom.langchain:llama-3.1-8b-instant"
-    assert lca.run_config.model_provider_name == "custom.langchain:ChatGroq"
-
-
-def test_langchain_adapter_info(tmp_path):
-    task = build_test_task(tmp_path)
-
-    lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
-
-    assert lca.adapter_name() == "kiln_langchain_adapter"
-    assert lca.run_config.model_name == "llama_3_1_8b"
-    assert lca.run_config.model_provider_name == "ollama"
-
-
-async def test_langchain_adapter_with_cot(tmp_path):
-    task = build_test_task(tmp_path)
-    task.output_json_schema = (
-        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
-    )
-    lca = LangchainAdapter(
-        kiln_task=task,
-        model_name="llama_3_1_8b",
-        provider="ollama",
-        prompt_id="simple_chain_of_thought_prompt_builder",
-    )
-
-    # Mock the base model and its invoke method
-    mock_base_model = MagicMock()
-    mock_base_model.ainvoke = AsyncMock(
-        return_value=AIMessage(content="Chain of thought reasoning...")
-    )
-
-    # Create a separate mock for self.model()
-    mock_model_instance = MagicMock()
-    mock_model_instance.ainvoke = AsyncMock(return_value={"parsed": {"count": 1}})
-
-    # Mock the langchain_model_from function to return the base model
-    mock_model_from = AsyncMock(return_value=mock_base_model)
-
-    # Patch both the langchain_model_from function and self.model()
-    with (
-        patch.object(LangchainAdapter, "langchain_model_from", mock_model_from),
-        patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
-    ):
-        response = await lca._run("test input")
-
-    # First 3 messages are the same for both calls
-    for invoke_args in [
-        mock_base_model.ainvoke.call_args[0][0],
-        mock_model_instance.ainvoke.call_args[0][0],
-    ]:
-        assert isinstance(
-            invoke_args[0], SystemMessage
-        )  # First message should be system prompt
-        assert (
-            "You are an assistant which performs math tasks provided in plain text."
-            in invoke_args[0].content
-        )
-        assert isinstance(invoke_args[1], HumanMessage)
-        assert "test input" in invoke_args[1].content
-        assert isinstance(invoke_args[2], SystemMessage)
-        assert "step by step" in invoke_args[2].content
-
-    # the COT should only have 3 messages
-    assert len(mock_base_model.ainvoke.call_args[0][0]) == 3
-    assert len(mock_model_instance.ainvoke.call_args[0][0]) == 5
-
-    # the final response should have the COT content and the final instructions
-    invoke_args = mock_model_instance.ainvoke.call_args[0][0]
-    assert isinstance(invoke_args[3], AIMessage)
-    assert "Chain of thought reasoning..." in invoke_args[3].content
-    assert isinstance(invoke_args[4], HumanMessage)
-    assert COT_FINAL_ANSWER_PROMPT in invoke_args[4].content
-
-    assert (
-        response.intermediate_outputs["chain_of_thought"]
-        == "Chain of thought reasoning..."
-    )
-    assert response.output == {"count": 1}
-
-
-@pytest.mark.parametrize(
-    "structured_output_mode,expected_method",
-    [
-        (StructuredOutputMode.function_calling, "function_calling"),
-        (StructuredOutputMode.json_mode, "json_mode"),
-        (StructuredOutputMode.json_schema, "json_schema"),
-        (StructuredOutputMode.json_instruction_and_object, "json_mode"),
-        (StructuredOutputMode.default, None),
-    ],
-)
-async def test_get_structured_output_options(
-    mock_adapter, structured_output_mode, expected_method
-):
-    # Mock the provider response
-    mock_provider = MagicMock()
-    mock_provider.structured_output_mode = structured_output_mode
-
-    # Mock adapter.model_provider()
-    mock_adapter.model_provider = MagicMock(return_value=mock_provider)
-
-    options = mock_adapter.get_structured_output_options("model_name", "provider")
-    assert options.get("method") == expected_method
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_groq():
-    provider = KilnModelProvider(
-        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.groq_api_key = "test_key"
-        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
-        assert isinstance(model, ChatGroq)
-        assert model.model_name == "mixtral-8x7b"
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_bedrock():
-    provider = KilnModelProvider(
-        name=ModelProviderName.amazon_bedrock,
-        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.bedrock_access_key = "test_access"
-        mock_config.return_value.bedrock_secret_key = "test_secret"
-        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
-        assert isinstance(model, ChatBedrockConverse)
-        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
-        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_fireworks():
-    provider = KilnModelProvider(
-        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.fireworks_api_key = "test_key"
-        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
-        assert isinstance(model, ChatFireworks)
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_ollama():
-    provider = KilnModelProvider(
-        name=ModelProviderName.ollama,
-        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
-    )
-
-    mock_connection = MagicMock()
-    with (
-        patch(
-            "kiln_ai.adapters.model_adapters.langchain_adapters.get_ollama_connection",
-            return_value=AsyncMock(return_value=mock_connection),
-        ),
-        patch(
-            "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_model_installed",
-            return_value=True,
-        ),
-        patch(
-            "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_base_url",
-            return_value="http://localhost:11434",
-        ),
-    ):
-        model = await langchain_model_from_provider(provider, "llama2")
-        assert isinstance(model, ChatOllama)
-        assert model.model == "llama2"
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_invalid():
-    provider = KilnModelProvider.model_construct(
-        name="invalid_provider", provider_options={}
-    )
-
-    with pytest.raises(ValueError, match="Invalid model or provider"):
-        await langchain_model_from_provider(provider, "test_model")
-
-
-@pytest.mark.asyncio
-async def test_langchain_adapter_model_caching(tmp_path):
-    task = build_test_task(tmp_path)
-    custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
-
-    adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
-
-    # First call should return the cached model
-    model1 = await adapter.model()
-    assert model1 is custom_model
-
-    # Second call should return the same cached instance
-    model2 = await adapter.model()
-    assert model2 is model1
-
-
-@pytest.mark.asyncio
-async def test_langchain_adapter_model_structured_output(tmp_path):
-    task = build_test_task(tmp_path)
-    task.output_json_schema = """
-    {
-        "type": "object",
-        "properties": {
-            "count": {"type": "integer"}
-        }
-    }
-    """
-
-    mock_model = MagicMock()
-    mock_model.with_structured_output = MagicMock(return_value="structured_model")
-
-    adapter = LangchainAdapter(
-        kiln_task=task, model_name="test_model", provider="ollama"
-    )
-    adapter.get_structured_output_options = MagicMock(
-        return_value={"option1": "value1"}
-    )
-    adapter.langchain_model_from = AsyncMock(return_value=mock_model)
-
-    model = await adapter.model()
-
-    # Verify the model was configured with structured output
-    mock_model.with_structured_output.assert_called_once_with(
-        {
-            "type": "object",
-            "properties": {"count": {"type": "integer"}},
-            "title": "task_response",
-            "description": "A response from the task",
-        },
-        include_raw=True,
-        option1="value1",
-    )
-    assert model == "structured_model"
-
-
-@pytest.mark.asyncio
-async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
-    task = build_test_task(tmp_path)
-    task.output_json_schema = (
-        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
-    )
-
-    mock_model = MagicMock()
-    # Remove with_structured_output method
-    del mock_model.with_structured_output
-
-    adapter = LangchainAdapter(
-        kiln_task=task, model_name="test_model", provider="ollama"
-    )
-    adapter.langchain_model_from = AsyncMock(return_value=mock_model)
-
-    with pytest.raises(ValueError, match="does not support structured output"):
-        await adapter.model()
-
-
-import pytest
-
-
-@pytest.mark.parametrize(
-    "provider_name",
-    [
-        (ModelProviderName.openai),
-        (ModelProviderName.openai_compatible),
-        (ModelProviderName.openrouter),
-    ],
-)
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_unsupported_providers(provider_name):
-    # Arrange
-    provider = KilnModelProvider(
-        name=provider_name, provider_options={}, structured_output_mode="default"
-    )
-
-    # Assert unsupported providers raise an error
-    with pytest.raises(ValueError):
-        await langchain_model_from_provider(provider, "test-model")