kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (80)
  1. kiln_ai/adapters/__init__.py +4 -0
  2. kiln_ai/adapters/adapter_registry.py +163 -39
  3. kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
  4. kiln_ai/adapters/eval/__init__.py +28 -0
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +270 -0
  7. kiln_ai/adapters/eval/g_eval.py +368 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +325 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +641 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +498 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
  14. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  15. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
  16. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
  17. kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
  18. kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
  19. kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
  20. kiln_ai/adapters/ml_model_list.py +758 -163
  21. kiln_ai/adapters/model_adapters/__init__.py +2 -4
  22. kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
  23. kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
  24. kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
  25. kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
  26. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
  27. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
  28. kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
  29. kiln_ai/adapters/ollama_tools.py +3 -3
  30. kiln_ai/adapters/parsers/r1_parser.py +19 -14
  31. kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
  32. kiln_ai/adapters/prompt_builders.py +80 -42
  33. kiln_ai/adapters/provider_tools.py +50 -58
  34. kiln_ai/adapters/repair/repair_task.py +9 -21
  35. kiln_ai/adapters/repair/test_repair_task.py +6 -6
  36. kiln_ai/adapters/run_output.py +3 -0
  37. kiln_ai/adapters/test_adapter_registry.py +26 -29
  38. kiln_ai/adapters/test_generate_docs.py +4 -4
  39. kiln_ai/adapters/test_ollama_tools.py +0 -1
  40. kiln_ai/adapters/test_prompt_adaptors.py +47 -33
  41. kiln_ai/adapters/test_prompt_builders.py +91 -31
  42. kiln_ai/adapters/test_provider_tools.py +26 -81
  43. kiln_ai/datamodel/__init__.py +50 -952
  44. kiln_ai/datamodel/basemodel.py +2 -0
  45. kiln_ai/datamodel/datamodel_enums.py +60 -0
  46. kiln_ai/datamodel/dataset_filters.py +114 -0
  47. kiln_ai/datamodel/dataset_split.py +170 -0
  48. kiln_ai/datamodel/eval.py +298 -0
  49. kiln_ai/datamodel/finetune.py +105 -0
  50. kiln_ai/datamodel/json_schema.py +7 -1
  51. kiln_ai/datamodel/project.py +23 -0
  52. kiln_ai/datamodel/prompt.py +37 -0
  53. kiln_ai/datamodel/prompt_id.py +83 -0
  54. kiln_ai/datamodel/strict_mode.py +24 -0
  55. kiln_ai/datamodel/task.py +181 -0
  56. kiln_ai/datamodel/task_output.py +328 -0
  57. kiln_ai/datamodel/task_run.py +164 -0
  58. kiln_ai/datamodel/test_basemodel.py +19 -11
  59. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  60. kiln_ai/datamodel/test_dataset_split.py +32 -8
  61. kiln_ai/datamodel/test_datasource.py +22 -2
  62. kiln_ai/datamodel/test_eval_model.py +635 -0
  63. kiln_ai/datamodel/test_example_models.py +9 -13
  64. kiln_ai/datamodel/test_json_schema.py +23 -0
  65. kiln_ai/datamodel/test_models.py +2 -2
  66. kiln_ai/datamodel/test_prompt_id.py +129 -0
  67. kiln_ai/datamodel/test_task.py +159 -0
  68. kiln_ai/utils/config.py +43 -1
  69. kiln_ai/utils/dataset_import.py +232 -0
  70. kiln_ai/utils/test_dataset_import.py +596 -0
  71. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
  72. kiln_ai-0.13.0.dist-info/RECORD +103 -0
  73. kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
  74. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
  75. kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
  76. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
  77. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
  78. kiln_ai-0.11.1.dist-info/RECORD +0 -76
  79. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
  80. {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
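
The headline change in this release is visible in the file list: the LangChain adapter and the bespoke OpenAI-compatible adapter (entries 73-77) are removed, and a LiteLLM-based adapter (litellm_adapter.py, litellm_config.py) takes over provider routing. As orientation only, here is a minimal sketch of the kind of provider-agnostic, OpenAI-style call LiteLLM normalizes; it is not Kiln's adapter API, and the model id, key handling, and JSON-mode support are placeholder assumptions.

# Illustrative sketch only -- not Kiln's adapter API.
# Assumes OPENROUTER_API_KEY is set in the environment; the model id is a placeholder.
import asyncio

import litellm


async def main() -> None:
    # litellm.acompletion takes OpenAI-style arguments and routes them to the
    # provider encoded in the model id (OpenAI, OpenRouter, Ollama, ...).
    response = await litellm.acompletion(
        model="openrouter/meta-llama/llama-3.1-8b-instruct",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Reply with a JSON object containing a 'count' field."},
        ],
        response_format={"type": "json_object"},  # OpenAI-style JSON mode, where the provider supports it
    )
    # The returned object mirrors the OpenAI SDK response shape.
    print(response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())
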
kiln_ai/adapters/model_adapters/openai_model_adapter.py (removed)
@@ -1,246 +0,0 @@
-from typing import Any, Dict
-
-from openai import AsyncOpenAI
-from openai.types.chat import (
-    ChatCompletion,
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionUserMessageParam,
-)
-
-import kiln_ai.datamodel as datamodel
-from kiln_ai.adapters.ml_model_list import StructuredOutputMode
-from kiln_ai.adapters.model_adapters.base_adapter import (
-    COT_FINAL_ANSWER_PROMPT,
-    AdapterInfo,
-    BaseAdapter,
-    BasePromptBuilder,
-    RunOutput,
-)
-from kiln_ai.adapters.model_adapters.openai_compatible_config import (
-    OpenAICompatibleConfig,
-)
-from kiln_ai.adapters.parsers.json_parser import parse_json_string
-from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
-
-
-class OpenAICompatibleAdapter(BaseAdapter):
-    def __init__(
-        self,
-        config: OpenAICompatibleConfig,
-        kiln_task: datamodel.Task,
-        prompt_builder: BasePromptBuilder | None = None,
-        tags: list[str] | None = None,
-    ):
-        self.config = config
-        self.client = AsyncOpenAI(
-            api_key=config.api_key,
-            base_url=config.base_url,
-            default_headers=config.default_headers,
-        )
-
-        super().__init__(
-            kiln_task,
-            model_name=config.model_name,
-            model_provider_name=config.provider_name,
-            prompt_builder=prompt_builder,
-            tags=tags,
-        )
-
-    async def _run(self, input: Dict | str) -> RunOutput:
-        provider = self.model_provider()
-        intermediate_outputs: dict[str, str] = {}
-        prompt = self.build_prompt()
-        user_msg = self.prompt_builder.build_user_message(input)
-        messages = [
-            ChatCompletionSystemMessageParam(role="system", content=prompt),
-            ChatCompletionUserMessageParam(role="user", content=user_msg),
-        ]
-
-        run_strategy, cot_prompt = self.run_strategy()
-
-        if run_strategy == "cot_as_message":
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_as_message strategy")
-            messages.append(
-                ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
-            )
-        elif run_strategy == "cot_two_call":
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_two_call strategy")
-            messages.append(
-                ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
-            )
-
-            # First call for chain of thought
-            cot_response = await self.client.chat.completions.create(
-                model=provider.provider_options["model"],
-                messages=messages,
-            )
-            cot_content = cot_response.choices[0].message.content
-            if cot_content is not None:
-                intermediate_outputs["chain_of_thought"] = cot_content
-
-            messages.extend(
-                [
-                    ChatCompletionAssistantMessageParam(
-                        role="assistant", content=cot_content
-                    ),
-                    ChatCompletionUserMessageParam(
-                        role="user",
-                        content=COT_FINAL_ANSWER_PROMPT,
-                    ),
-                ]
-            )
-
-        # OpenRouter specific options for reasoning models
-        extra_body = {}
-        require_or_reasoning = (
-            self.config.openrouter_style_reasoning and provider.reasoning_capable
-        )
-        if require_or_reasoning:
-            extra_body["include_reasoning"] = True
-            # Filter to providers that support the reasoning parameter
-            extra_body["provider"] = {
-                "require_parameters": True,
-                # Ugly to have these here, but big range of quality of R1 providers
-                "order": ["Fireworks", "Together"],
-                # fp8 quants are awful
-                "ignore": ["DeepInfra"],
-            }
-
-        # Main completion call
-        response_format_options = await self.response_format_options()
-        response = await self.client.chat.completions.create(
-            model=provider.provider_options["model"],
-            messages=messages,
-            extra_body=extra_body,
-            **response_format_options,
-        )
-
-        if not isinstance(response, ChatCompletion):
-            raise RuntimeError(
-                f"Expected ChatCompletion response, got {type(response)}."
-            )
-
-        if hasattr(response, "error") and response.error:  # pyright: ignore
-            raise RuntimeError(
-                f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}.\nError: {response.error}"  # pyright: ignore
-            )
-        if not response.choices or len(response.choices) == 0:
-            raise RuntimeError(
-                "No message content returned in the response from OpenAI compatible API"
-            )
-
-        message = response.choices[0].message
-
-        # Save reasoning if it exists (OpenRouter specific format)
-        if require_or_reasoning:
-            if (
-                hasattr(message, "reasoning") and message.reasoning  # pyright: ignore
-            ):
-                intermediate_outputs["reasoning"] = message.reasoning  # pyright: ignore
-            else:
-                raise RuntimeError(
-                    "Reasoning is required for this model, but no reasoning was returned from OpenRouter."
-                )
-
-        # the string content of the response
-        response_content = message.content
-
-        # Fallback: Use args of first tool call to task_response if it exists
-        if not response_content and message.tool_calls:
-            tool_call = next(
-                (
-                    tool_call
-                    for tool_call in message.tool_calls
-                    if tool_call.function.name == "task_response"
-                ),
-                None,
-            )
-            if tool_call:
-                response_content = tool_call.function.arguments
-
-        if not isinstance(response_content, str):
-            raise RuntimeError(f"response is not a string: {response_content}")
-
-        if self.has_structured_output():
-            structured_response = parse_json_string(response_content)
-            return RunOutput(
-                output=structured_response,
-                intermediate_outputs=intermediate_outputs,
-            )
-
-        return RunOutput(
-            output=response_content,
-            intermediate_outputs=intermediate_outputs,
-        )
-
-    def adapter_info(self) -> AdapterInfo:
-        return AdapterInfo(
-            model_name=self.model_name,
-            model_provider=self.model_provider_name,
-            adapter_name="kiln_openai_compatible_adapter",
-            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
-            prompt_id=self.prompt_builder.prompt_id(),
-        )
-
-    async def response_format_options(self) -> dict[str, Any]:
-        # Unstructured if task isn't structured
-        if not self.has_structured_output():
-            return {}
-
-        provider = self.model_provider()
-        match provider.structured_output_mode:
-            case StructuredOutputMode.json_mode:
-                return {"response_format": {"type": "json_object"}}
-            case StructuredOutputMode.json_schema:
-                output_schema = self.kiln_task.output_schema()
-                return {
-                    "response_format": {
-                        "type": "json_schema",
-                        "json_schema": {
-                            "name": "task_response",
-                            "schema": output_schema,
-                        },
-                    }
-                }
-            case StructuredOutputMode.function_calling:
-                return self.tool_call_params()
-            case StructuredOutputMode.json_instructions:
-                # JSON done via instructions in prompt, not the API response format. Do not ask for json_object (see option below).
-                return {}
-            case StructuredOutputMode.json_instruction_and_object:
-                # We set response_format to json_object and also set json instructions in the prompt
-                return {"response_format": {"type": "json_object"}}
-            case StructuredOutputMode.default:
-                # Default to function calling -- it's older than the other modes. Higher compatibility.
-                return self.tool_call_params()
-            case _:
-                raise_exhaustive_enum_error(provider.structured_output_mode)
-
-    def tool_call_params(self) -> dict[str, Any]:
-        # Add additional_properties: false to the schema (OpenAI requires this for some models)
-        output_schema = self.kiln_task.output_schema()
-        if not isinstance(output_schema, dict):
-            raise ValueError(
-                "Invalid output schema for this task. Can not use tool calls."
-            )
-        output_schema["additionalProperties"] = False
-
-        return {
-            "tools": [
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "task_response",
-                        "parameters": output_schema,
-                        "strict": True,
-                    },
-                }
-            ],
-            "tool_choice": {
-                "type": "function",
-                "function": {"name": "task_response"},
-            },
-        }
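
The removed adapter above is mostly request plumbing; its essential logic is the mapping in response_format_options() from Kiln's StructuredOutputMode to OpenAI-style request parameters. The sketch below restates that mapping as a standalone function for readers comparing it against the new LiteLLM adapter; response_params_for_mode is an illustrative helper name, not part of kiln_ai.

# Illustrative restatement of the removed response_format_options() logic.
# `response_params_for_mode` is a made-up helper name, not part of kiln_ai.
from typing import Any


def tool_call_params(output_schema: dict[str, Any]) -> dict[str, Any]:
    # OpenAI requires additionalProperties: false for strict tool schemas
    schema = {**output_schema, "additionalProperties": False}
    return {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "task_response",
                    "parameters": schema,
                    "strict": True,
                },
            }
        ],
        "tool_choice": {"type": "function", "function": {"name": "task_response"}},
    }


def response_params_for_mode(mode: str, output_schema: dict[str, Any]) -> dict[str, Any]:
    # json_instruction_and_object also adds JSON instructions to the prompt,
    # but at the API level it requests the same json_object format as json_mode.
    if mode in ("json_mode", "json_instruction_and_object"):
        return {"response_format": {"type": "json_object"}}
    if mode == "json_schema":
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {"name": "task_response", "schema": output_schema},
            }
        }
    if mode in ("function_calling", "default"):
        return tool_call_params(output_schema)
    if mode == "json_instructions":
        # JSON is requested via prompt instructions only; no API-level format.
        return {}
    raise ValueError(f"Unknown structured output mode: {mode}")
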
kiln_ai/adapters/model_adapters/test_langchain_adapter.py (removed)
@@ -1,350 +0,0 @@
-import os
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from langchain_aws import ChatBedrockConverse
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
-from langchain_fireworks import ChatFireworks
-from langchain_groq import ChatGroq
-from langchain_ollama import ChatOllama
-
-from kiln_ai.adapters.ml_model_list import (
-    KilnModelProvider,
-    ModelProviderName,
-    StructuredOutputMode,
-)
-from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
-from kiln_ai.adapters.model_adapters.langchain_adapters import (
-    LangchainAdapter,
-    langchain_model_from_provider,
-)
-from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
-from kiln_ai.adapters.test_prompt_adaptors import build_test_task
-
-
-@pytest.fixture
-def mock_adapter(tmp_path):
-    return LangchainAdapter(
-        kiln_task=build_test_task(tmp_path),
-        model_name="llama_3_1_8b",
-        provider="ollama",
-    )
-
-
-def test_langchain_adapter_munge_response(mock_adapter):
-    # Mistral Large tool calling format is a bit different
-    response = {
-        "name": "task_response",
-        "arguments": {
-            "setup": "Why did the cow join a band?",
-            "punchline": "Because she wanted to be a moo-sician!",
-        },
-    }
-    munged = mock_adapter._munge_response(response)
-    assert munged["setup"] == "Why did the cow join a band?"
-    assert munged["punchline"] == "Because she wanted to be a moo-sician!"
-
-    # non mistral format should continue to work
-    munged = mock_adapter._munge_response(response["arguments"])
-    assert munged["setup"] == "Why did the cow join a band?"
-    assert munged["punchline"] == "Because she wanted to be a moo-sician!"
-
-
-def test_langchain_adapter_infer_model_name(tmp_path):
-    task = build_test_task(tmp_path)
-    custom = ChatGroq(model="llama-3.1-8b-instant", groq_api_key="test")
-
-    lca = LangchainAdapter(kiln_task=task, custom_model=custom)
-
-    model_info = lca.adapter_info()
-    assert model_info.model_name == "custom.langchain:llama-3.1-8b-instant"
-    assert model_info.model_provider == "custom.langchain:ChatGroq"
-
-
-def test_langchain_adapter_info(tmp_path):
-    task = build_test_task(tmp_path)
-
-    lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
-
-    model_info = lca.adapter_info()
-    assert model_info.adapter_name == "kiln_langchain_adapter"
-    assert model_info.model_name == "llama_3_1_8b"
-    assert model_info.model_provider == "ollama"
-
-
-async def test_langchain_adapter_with_cot(tmp_path):
-    task = build_test_task(tmp_path)
-    task.output_json_schema = (
-        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
-    )
-    lca = LangchainAdapter(
-        kiln_task=task,
-        model_name="llama_3_1_8b",
-        provider="ollama",
-        prompt_builder=SimpleChainOfThoughtPromptBuilder(task),
-    )
-
-    # Mock the base model and its invoke method
-    mock_base_model = MagicMock()
-    mock_base_model.ainvoke = AsyncMock(
-        return_value=AIMessage(content="Chain of thought reasoning...")
-    )
-
-    # Create a separate mock for self.model()
-    mock_model_instance = MagicMock()
-    mock_model_instance.ainvoke = AsyncMock(return_value={"parsed": {"count": 1}})
-
-    # Mock the langchain_model_from function to return the base model
-    mock_model_from = AsyncMock(return_value=mock_base_model)
-
-    # Patch both the langchain_model_from function and self.model()
-    with (
-        patch.object(LangchainAdapter, "langchain_model_from", mock_model_from),
-        patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
-    ):
-        response = await lca._run("test input")
-
-    # First 3 messages are the same for both calls
-    for invoke_args in [
-        mock_base_model.ainvoke.call_args[0][0],
-        mock_model_instance.ainvoke.call_args[0][0],
-    ]:
-        assert isinstance(
-            invoke_args[0], SystemMessage
-        )  # First message should be system prompt
-        assert (
-            "You are an assistant which performs math tasks provided in plain text."
-            in invoke_args[0].content
-        )
-        assert isinstance(invoke_args[1], HumanMessage)
-        assert "test input" in invoke_args[1].content
-        assert isinstance(invoke_args[2], SystemMessage)
-        assert "step by step" in invoke_args[2].content
-
-    # the COT should only have 3 messages
-    assert len(mock_base_model.ainvoke.call_args[0][0]) == 3
-    assert len(mock_model_instance.ainvoke.call_args[0][0]) == 5
-
-    # the final response should have the COT content and the final instructions
-    invoke_args = mock_model_instance.ainvoke.call_args[0][0]
-    assert isinstance(invoke_args[3], AIMessage)
-    assert "Chain of thought reasoning..." in invoke_args[3].content
-    assert isinstance(invoke_args[4], HumanMessage)
-    assert COT_FINAL_ANSWER_PROMPT in invoke_args[4].content
-
-    assert (
-        response.intermediate_outputs["chain_of_thought"]
-        == "Chain of thought reasoning..."
-    )
-    assert response.output == {"count": 1}
-
-
-@pytest.mark.parametrize(
-    "structured_output_mode,expected_method",
-    [
-        (StructuredOutputMode.function_calling, "function_calling"),
-        (StructuredOutputMode.json_mode, "json_mode"),
-        (StructuredOutputMode.json_schema, "json_schema"),
-        (StructuredOutputMode.json_instruction_and_object, "json_mode"),
-        (StructuredOutputMode.default, None),
-    ],
-)
-async def test_get_structured_output_options(
-    mock_adapter, structured_output_mode, expected_method
-):
-    # Mock the provider response
-    mock_provider = MagicMock()
-    mock_provider.structured_output_mode = structured_output_mode
-
-    # Mock adapter.model_provider()
-    mock_adapter.model_provider = MagicMock(return_value=mock_provider)
-
-    options = mock_adapter.get_structured_output_options("model_name", "provider")
-    assert options.get("method") == expected_method
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_groq():
-    provider = KilnModelProvider(
-        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.groq_api_key = "test_key"
-        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
-        assert isinstance(model, ChatGroq)
-        assert model.model_name == "mixtral-8x7b"
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_bedrock():
-    provider = KilnModelProvider(
-        name=ModelProviderName.amazon_bedrock,
-        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.bedrock_access_key = "test_access"
-        mock_config.return_value.bedrock_secret_key = "test_secret"
-        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
-        assert isinstance(model, ChatBedrockConverse)
-        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
-        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_fireworks():
-    provider = KilnModelProvider(
-        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.fireworks_api_key = "test_key"
-        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
-        assert isinstance(model, ChatFireworks)
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_ollama():
-    provider = KilnModelProvider(
-        name=ModelProviderName.ollama,
-        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
-    )
-
-    mock_connection = MagicMock()
-    with (
-        patch(
-            "kiln_ai.adapters.model_adapters.langchain_adapters.get_ollama_connection",
-            return_value=AsyncMock(return_value=mock_connection),
-        ),
-        patch(
-            "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_model_installed",
-            return_value=True,
-        ),
-        patch(
-            "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_base_url",
-            return_value="http://localhost:11434",
-        ),
-    ):
-        model = await langchain_model_from_provider(provider, "llama2")
-        assert isinstance(model, ChatOllama)
-        assert model.model == "llama2"
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_invalid():
-    provider = KilnModelProvider.model_construct(
-        name="invalid_provider", provider_options={}
-    )
-
-    with pytest.raises(ValueError, match="Invalid model or provider"):
-        await langchain_model_from_provider(provider, "test_model")
-
-
-@pytest.mark.asyncio
-async def test_langchain_adapter_model_caching(tmp_path):
-    task = build_test_task(tmp_path)
-    custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
-
-    adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
-
-    # First call should return the cached model
-    model1 = await adapter.model()
-    assert model1 is custom_model
-
-    # Second call should return the same cached instance
-    model2 = await adapter.model()
-    assert model2 is model1
-
-
-@pytest.mark.asyncio
-async def test_langchain_adapter_model_structured_output(tmp_path):
-    task = build_test_task(tmp_path)
-    task.output_json_schema = """
-        {
-            "type": "object",
-            "properties": {
-                "count": {"type": "integer"}
-            }
-        }
-        """
-
-    mock_model = MagicMock()
-    mock_model.with_structured_output = MagicMock(return_value="structured_model")
-
-    adapter = LangchainAdapter(
-        kiln_task=task, model_name="test_model", provider="ollama"
-    )
-    adapter.get_structured_output_options = MagicMock(
-        return_value={"option1": "value1"}
-    )
-    adapter.langchain_model_from = AsyncMock(return_value=mock_model)
-
-    model = await adapter.model()
-
-    # Verify the model was configured with structured output
-    mock_model.with_structured_output.assert_called_once_with(
-        {
-            "type": "object",
-            "properties": {"count": {"type": "integer"}},
-            "title": "task_response",
-            "description": "A response from the task",
-        },
-        include_raw=True,
-        option1="value1",
-    )
-    assert model == "structured_model"
-
-
-@pytest.mark.asyncio
-async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
-    task = build_test_task(tmp_path)
-    task.output_json_schema = (
-        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
-    )
-
-    mock_model = MagicMock()
-    # Remove with_structured_output method
-    del mock_model.with_structured_output
-
-    adapter = LangchainAdapter(
-        kiln_task=task, model_name="test_model", provider="ollama"
-    )
-    adapter.langchain_model_from = AsyncMock(return_value=mock_model)
-
-    with pytest.raises(ValueError, match="does not support structured output"):
-        await adapter.model()
-
-
-import pytest
-
-from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
-from kiln_ai.adapters.model_adapters.langchain_adapters import (
-    langchain_model_from_provider,
-)
-
-
-@pytest.mark.parametrize(
-    "provider_name",
-    [
-        (ModelProviderName.openai),
-        (ModelProviderName.openai_compatible),
-        (ModelProviderName.openrouter),
-    ],
-)
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_unsupported_providers(provider_name):
-    # Arrange
-    provider = KilnModelProvider(
-        name=provider_name, provider_options={}, structured_output_mode="default"
-    )
-
-    # Assert unsupported providers raise an error
-    with pytest.raises(ValueError):
-        await langchain_model_from_provider(provider, "test-model")
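
Although the LangChain test file above is deleted, its central testing pattern carries over to the new adapter tests: async model calls are replaced with unittest.mock.AsyncMock and the forwarded messages are inspected through call_args. A minimal, self-contained version of that pattern (FakeAdapter is a made-up stand-in, not a Kiln class) looks like this:

# Minimal illustration of the AsyncMock pattern used in the removed tests.
# `FakeAdapter` is a made-up stand-in, not a Kiln class.
import asyncio
from unittest.mock import AsyncMock, MagicMock


class FakeAdapter:
    def __init__(self, model):
        self.model = model

    async def run(self, prompt: str):
        # Delegates to an async model call, as the removed adapters did.
        return await self.model.ainvoke([{"role": "user", "content": prompt}])


def test_fake_adapter_calls_model():
    model = MagicMock()
    model.ainvoke = AsyncMock(return_value={"parsed": {"count": 1}})

    adapter = FakeAdapter(model)
    result = asyncio.run(adapter.run("test input"))

    # Inspect the messages that were passed to the mocked model
    messages = model.ainvoke.call_args[0][0]
    assert messages[0]["content"] == "test input"
    assert result == {"parsed": {"count": 1}}
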