kiln-ai 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +153 -28
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +2 -1
- kiln_ai/adapters/eval/test_base_eval.py +1 -0
- kiln_ai/adapters/eval/test_eval_runner.py +1 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +638 -155
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
- kiln_ai/adapters/ollama_tools.py +3 -2
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +1 -1
- kiln_ai/adapters/test_adapter_registry.py +17 -20
- kiln_ai/adapters/test_generate_docs.py +2 -2
- kiln_ai/adapters/test_prompt_adaptors.py +30 -19
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/json_schema.py +1 -1
- kiln_ai/datamodel/task_output.py +13 -6
- kiln_ai/datamodel/test_basemodel.py +9 -0
- kiln_ai/datamodel/test_datasource.py +19 -0
- kiln_ai/utils/config.py +37 -0
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +51 -7
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/RECORD +42 -39
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -1,289 +0,0 @@
|
|
|
1
|
-
from typing import Any, Dict
|
|
2
|
-
|
|
3
|
-
from openai import AsyncOpenAI
|
|
4
|
-
from openai.types.chat import (
|
|
5
|
-
ChatCompletion,
|
|
6
|
-
ChatCompletionAssistantMessageParam,
|
|
7
|
-
ChatCompletionSystemMessageParam,
|
|
8
|
-
ChatCompletionUserMessageParam,
|
|
9
|
-
)
|
|
10
|
-
|
|
11
|
-
import kiln_ai.datamodel as datamodel
|
|
12
|
-
from kiln_ai.adapters.ml_model_list import (
|
|
13
|
-
KilnModelProvider,
|
|
14
|
-
ModelProviderName,
|
|
15
|
-
StructuredOutputMode,
|
|
16
|
-
)
|
|
17
|
-
from kiln_ai.adapters.model_adapters.base_adapter import (
|
|
18
|
-
COT_FINAL_ANSWER_PROMPT,
|
|
19
|
-
AdapterConfig,
|
|
20
|
-
BaseAdapter,
|
|
21
|
-
RunOutput,
|
|
22
|
-
)
|
|
23
|
-
from kiln_ai.adapters.model_adapters.openai_compatible_config import (
|
|
24
|
-
OpenAICompatibleConfig,
|
|
25
|
-
)
|
|
26
|
-
from kiln_ai.adapters.parsers.json_parser import parse_json_string
|
|
27
|
-
from kiln_ai.datamodel import PromptGenerators, PromptId
|
|
28
|
-
from kiln_ai.datamodel.task import RunConfig
|
|
29
|
-
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class OpenAICompatibleAdapter(BaseAdapter):
|
|
33
|
-
def __init__(
|
|
34
|
-
self,
|
|
35
|
-
config: OpenAICompatibleConfig,
|
|
36
|
-
kiln_task: datamodel.Task,
|
|
37
|
-
prompt_id: PromptId | None = None,
|
|
38
|
-
base_adapter_config: AdapterConfig | None = None,
|
|
39
|
-
):
|
|
40
|
-
self.config = config
|
|
41
|
-
self.client = AsyncOpenAI(
|
|
42
|
-
api_key=config.api_key,
|
|
43
|
-
base_url=config.base_url,
|
|
44
|
-
default_headers=config.default_headers,
|
|
45
|
-
)
|
|
46
|
-
|
|
47
|
-
run_config = RunConfig(
|
|
48
|
-
task=kiln_task,
|
|
49
|
-
model_name=config.model_name,
|
|
50
|
-
model_provider_name=config.provider_name,
|
|
51
|
-
prompt_id=prompt_id or PromptGenerators.SIMPLE,
|
|
52
|
-
)
|
|
53
|
-
|
|
54
|
-
super().__init__(
|
|
55
|
-
run_config=run_config,
|
|
56
|
-
config=base_adapter_config,
|
|
57
|
-
)
|
|
58
|
-
|
|
59
|
-
async def _run(self, input: Dict | str) -> RunOutput:
|
|
60
|
-
provider = self.model_provider()
|
|
61
|
-
intermediate_outputs: dict[str, str] = {}
|
|
62
|
-
prompt = self.build_prompt()
|
|
63
|
-
user_msg = self.prompt_builder.build_user_message(input)
|
|
64
|
-
messages = [
|
|
65
|
-
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
|
66
|
-
ChatCompletionUserMessageParam(role="user", content=user_msg),
|
|
67
|
-
]
|
|
68
|
-
|
|
69
|
-
run_strategy, cot_prompt = self.run_strategy()
|
|
70
|
-
|
|
71
|
-
if run_strategy == "cot_as_message":
|
|
72
|
-
if not cot_prompt:
|
|
73
|
-
raise ValueError("cot_prompt is required for cot_as_message strategy")
|
|
74
|
-
messages.append(
|
|
75
|
-
ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
|
|
76
|
-
)
|
|
77
|
-
elif run_strategy == "cot_two_call":
|
|
78
|
-
if not cot_prompt:
|
|
79
|
-
raise ValueError("cot_prompt is required for cot_two_call strategy")
|
|
80
|
-
messages.append(
|
|
81
|
-
ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
# First call for chain of thought
|
|
85
|
-
cot_response = await self.client.chat.completions.create(
|
|
86
|
-
model=provider.provider_options["model"],
|
|
87
|
-
messages=messages,
|
|
88
|
-
)
|
|
89
|
-
cot_content = cot_response.choices[0].message.content
|
|
90
|
-
if cot_content is not None:
|
|
91
|
-
intermediate_outputs["chain_of_thought"] = cot_content
|
|
92
|
-
|
|
93
|
-
messages.extend(
|
|
94
|
-
[
|
|
95
|
-
ChatCompletionAssistantMessageParam(
|
|
96
|
-
role="assistant", content=cot_content
|
|
97
|
-
),
|
|
98
|
-
ChatCompletionUserMessageParam(
|
|
99
|
-
role="user",
|
|
100
|
-
content=COT_FINAL_ANSWER_PROMPT,
|
|
101
|
-
),
|
|
102
|
-
]
|
|
103
|
-
)
|
|
104
|
-
|
|
105
|
-
# Build custom request params based on model provider
|
|
106
|
-
extra_body = self.build_extra_body(provider)
|
|
107
|
-
|
|
108
|
-
# Main completion call
|
|
109
|
-
response_format_options = await self.response_format_options()
|
|
110
|
-
response = await self.client.chat.completions.create(
|
|
111
|
-
model=provider.provider_options["model"],
|
|
112
|
-
messages=messages,
|
|
113
|
-
extra_body=extra_body,
|
|
114
|
-
logprobs=self.base_adapter_config.top_logprobs is not None,
|
|
115
|
-
top_logprobs=self.base_adapter_config.top_logprobs,
|
|
116
|
-
**response_format_options,
|
|
117
|
-
)
|
|
118
|
-
|
|
119
|
-
if not isinstance(response, ChatCompletion):
|
|
120
|
-
raise RuntimeError(
|
|
121
|
-
f"Expected ChatCompletion response, got {type(response)}."
|
|
122
|
-
)
|
|
123
|
-
|
|
124
|
-
if hasattr(response, "error") and response.error: # pyright: ignore
|
|
125
|
-
raise RuntimeError(
|
|
126
|
-
f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}.\nError: {response.error}" # pyright: ignore
|
|
127
|
-
)
|
|
128
|
-
if not response.choices or len(response.choices) == 0:
|
|
129
|
-
raise RuntimeError(
|
|
130
|
-
"No message content returned in the response from OpenAI compatible API"
|
|
131
|
-
)
|
|
132
|
-
|
|
133
|
-
message = response.choices[0].message
|
|
134
|
-
logprobs = response.choices[0].logprobs
|
|
135
|
-
|
|
136
|
-
# Check logprobs worked, if requested
|
|
137
|
-
if self.base_adapter_config.top_logprobs is not None and logprobs is None:
|
|
138
|
-
raise RuntimeError("Logprobs were required, but no logprobs were returned.")
|
|
139
|
-
|
|
140
|
-
# Save reasoning if it exists (OpenRouter specific api response field)
|
|
141
|
-
if provider.require_openrouter_reasoning:
|
|
142
|
-
if (
|
|
143
|
-
hasattr(message, "reasoning") and message.reasoning # pyright: ignore
|
|
144
|
-
):
|
|
145
|
-
intermediate_outputs["reasoning"] = message.reasoning # pyright: ignore
|
|
146
|
-
else:
|
|
147
|
-
raise RuntimeError(
|
|
148
|
-
"Reasoning is required for this model, but no reasoning was returned from OpenRouter."
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
# the string content of the response
|
|
152
|
-
response_content = message.content
|
|
153
|
-
|
|
154
|
-
# Fallback: Use args of first tool call to task_response if it exists
|
|
155
|
-
if not response_content and message.tool_calls:
|
|
156
|
-
tool_call = next(
|
|
157
|
-
(
|
|
158
|
-
tool_call
|
|
159
|
-
for tool_call in message.tool_calls
|
|
160
|
-
if tool_call.function.name == "task_response"
|
|
161
|
-
),
|
|
162
|
-
None,
|
|
163
|
-
)
|
|
164
|
-
if tool_call:
|
|
165
|
-
response_content = tool_call.function.arguments
|
|
166
|
-
|
|
167
|
-
if not isinstance(response_content, str):
|
|
168
|
-
raise RuntimeError(f"response is not a string: {response_content}")
|
|
169
|
-
|
|
170
|
-
# Parse to dict if we have structured output
|
|
171
|
-
output: Dict | str = response_content
|
|
172
|
-
if self.has_structured_output():
|
|
173
|
-
output = parse_json_string(response_content)
|
|
174
|
-
|
|
175
|
-
return RunOutput(
|
|
176
|
-
output=output,
|
|
177
|
-
intermediate_outputs=intermediate_outputs,
|
|
178
|
-
output_logprobs=logprobs,
|
|
179
|
-
)
|
|
180
|
-
|
|
181
|
-
def adapter_name(self) -> str:
|
|
182
|
-
return "kiln_openai_compatible_adapter"
|
|
183
|
-
|
|
184
|
-
async def response_format_options(self) -> dict[str, Any]:
|
|
185
|
-
# Unstructured if task isn't structured
|
|
186
|
-
if not self.has_structured_output():
|
|
187
|
-
return {}
|
|
188
|
-
|
|
189
|
-
provider = self.model_provider()
|
|
190
|
-
match provider.structured_output_mode:
|
|
191
|
-
case StructuredOutputMode.json_mode:
|
|
192
|
-
return {"response_format": {"type": "json_object"}}
|
|
193
|
-
case StructuredOutputMode.json_schema:
|
|
194
|
-
output_schema = self.task().output_schema()
|
|
195
|
-
return {
|
|
196
|
-
"response_format": {
|
|
197
|
-
"type": "json_schema",
|
|
198
|
-
"json_schema": {
|
|
199
|
-
"name": "task_response",
|
|
200
|
-
"schema": output_schema,
|
|
201
|
-
},
|
|
202
|
-
}
|
|
203
|
-
}
|
|
204
|
-
case StructuredOutputMode.function_calling_weak:
|
|
205
|
-
return self.tool_call_params(strict=False)
|
|
206
|
-
case StructuredOutputMode.function_calling:
|
|
207
|
-
return self.tool_call_params(strict=True)
|
|
208
|
-
case StructuredOutputMode.json_instructions:
|
|
209
|
-
# JSON done via instructions in prompt, not the API response format. Do not ask for json_object (see option below).
|
|
210
|
-
return {}
|
|
211
|
-
case StructuredOutputMode.json_instruction_and_object:
|
|
212
|
-
# We set response_format to json_object and also set json instructions in the prompt
|
|
213
|
-
return {"response_format": {"type": "json_object"}}
|
|
214
|
-
case StructuredOutputMode.default:
|
|
215
|
-
# Default to function calling -- it's older than the other modes. Higher compatibility.
|
|
216
|
-
return self.tool_call_params(strict=True)
|
|
217
|
-
case _:
|
|
218
|
-
raise_exhaustive_enum_error(provider.structured_output_mode)
|
|
219
|
-
|
|
220
|
-
def tool_call_params(self, strict: bool) -> dict[str, Any]:
|
|
221
|
-
# Add additional_properties: false to the schema (OpenAI requires this for some models)
|
|
222
|
-
output_schema = self.task().output_schema()
|
|
223
|
-
if not isinstance(output_schema, dict):
|
|
224
|
-
raise ValueError(
|
|
225
|
-
"Invalid output schema for this task. Can not use tool calls."
|
|
226
|
-
)
|
|
227
|
-
output_schema["additionalProperties"] = False
|
|
228
|
-
|
|
229
|
-
function_params = {
|
|
230
|
-
"name": "task_response",
|
|
231
|
-
"parameters": output_schema,
|
|
232
|
-
}
|
|
233
|
-
# This should be on, but we allow setting function_calling_weak for APIs that don't support it.
|
|
234
|
-
if strict:
|
|
235
|
-
function_params["strict"] = True
|
|
236
|
-
|
|
237
|
-
return {
|
|
238
|
-
"tools": [
|
|
239
|
-
{
|
|
240
|
-
"type": "function",
|
|
241
|
-
"function": function_params,
|
|
242
|
-
}
|
|
243
|
-
],
|
|
244
|
-
"tool_choice": {
|
|
245
|
-
"type": "function",
|
|
246
|
-
"function": {"name": "task_response"},
|
|
247
|
-
},
|
|
248
|
-
}
|
|
249
|
-
|
|
250
|
-
def build_extra_body(self, provider: KilnModelProvider) -> dict[str, Any]:
|
|
251
|
-
# TODO P1: Don't love having this logic here. But it's a usability improvement
|
|
252
|
-
# so better to keep it than exclude it. Should figure out how I want to isolate
|
|
253
|
-
# this sort of logic so it's config driven and can be overridden
|
|
254
|
-
|
|
255
|
-
extra_body = {}
|
|
256
|
-
provider_options = {}
|
|
257
|
-
|
|
258
|
-
if provider.require_openrouter_reasoning:
|
|
259
|
-
# https://openrouter.ai/docs/use-cases/reasoning-tokens
|
|
260
|
-
extra_body["reasoning"] = {
|
|
261
|
-
"exclude": False,
|
|
262
|
-
}
|
|
263
|
-
|
|
264
|
-
if provider.r1_openrouter_options:
|
|
265
|
-
# Require providers that support the reasoning parameter
|
|
266
|
-
provider_options["require_parameters"] = True
|
|
267
|
-
# Prefer R1 providers with reasonable perf/quants
|
|
268
|
-
provider_options["order"] = ["Fireworks", "Together"]
|
|
269
|
-
# R1 providers with unreasonable quants
|
|
270
|
-
provider_options["ignore"] = ["DeepInfra"]
|
|
271
|
-
|
|
272
|
-
# Only set of this request is to get logprobs.
|
|
273
|
-
if (
|
|
274
|
-
provider.logprobs_openrouter_options
|
|
275
|
-
and self.base_adapter_config.top_logprobs is not None
|
|
276
|
-
):
|
|
277
|
-
# Don't let OpenRouter choose a provider that doesn't support logprobs.
|
|
278
|
-
provider_options["require_parameters"] = True
|
|
279
|
-
# DeepInfra silently fails to return logprobs consistently.
|
|
280
|
-
provider_options["ignore"] = ["DeepInfra"]
|
|
281
|
-
|
|
282
|
-
if provider.openrouter_skip_required_parameters:
|
|
283
|
-
# Oddball case, R1 14/8/1.5B fail with this param, even though they support thinking params.
|
|
284
|
-
provider_options["require_parameters"] = False
|
|
285
|
-
|
|
286
|
-
if len(provider_options) > 0:
|
|
287
|
-
extra_body["provider"] = provider_options
|
|
288
|
-
|
|
289
|
-
return extra_body
|
|
@@ -1,343 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from unittest.mock import AsyncMock, MagicMock, patch
|
|
3
|
-
|
|
4
|
-
import pytest
|
|
5
|
-
from langchain_aws import ChatBedrockConverse
|
|
6
|
-
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
7
|
-
from langchain_fireworks import ChatFireworks
|
|
8
|
-
from langchain_groq import ChatGroq
|
|
9
|
-
from langchain_ollama import ChatOllama
|
|
10
|
-
|
|
11
|
-
from kiln_ai.adapters.ml_model_list import (
|
|
12
|
-
KilnModelProvider,
|
|
13
|
-
ModelProviderName,
|
|
14
|
-
StructuredOutputMode,
|
|
15
|
-
)
|
|
16
|
-
from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
|
|
17
|
-
from kiln_ai.adapters.model_adapters.langchain_adapters import (
|
|
18
|
-
LangchainAdapter,
|
|
19
|
-
langchain_model_from_provider,
|
|
20
|
-
)
|
|
21
|
-
from kiln_ai.adapters.test_prompt_adaptors import build_test_task
|
|
22
|
-
from kiln_ai.datamodel.task import RunConfig
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
@pytest.fixture
def mock_adapter(tmp_path):
    """Provide a LangchainAdapter for the test task using llama_3_1_8b on ollama."""
    task = build_test_task(tmp_path)
    adapter = LangchainAdapter(
        kiln_task=task,
        model_name="llama_3_1_8b",
        provider="ollama",
    )
    return adapter
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def test_langchain_adapter_munge_response(mock_adapter):
    """_munge_response yields the joke fields for both wrapped and plain payloads."""
    joke = {
        "setup": "Why did the cow join a band?",
        "punchline": "Because she wanted to be a moo-sician!",
    }
    # Mistral Large tool calling format is a bit different: the fields sit
    # under "arguments" next to the tool name.
    wrapped = {"name": "task_response", "arguments": joke}

    # The wrapped (Mistral) form first, then the plain form, which should
    # continue to work.
    for payload in (wrapped, wrapped["arguments"]):
        result = mock_adapter._munge_response(payload)
        assert result["setup"] == "Why did the cow join a band?"
        assert result["punchline"] == "Because she wanted to be a moo-sician!"
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def test_langchain_adapter_infer_model_name(tmp_path):
    """A custom model's name and class are reflected in the adapter's run config."""
    kiln_task = build_test_task(tmp_path)
    groq_model = ChatGroq(model="llama-3.1-8b-instant", groq_api_key="test")

    adapter = LangchainAdapter(kiln_task=kiln_task, custom_model=groq_model)

    expected_model_name = "custom.langchain:llama-3.1-8b-instant"
    expected_provider_name = "custom.langchain:ChatGroq"
    assert adapter.run_config.model_name == expected_model_name
    assert adapter.run_config.model_provider_name == expected_provider_name
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def test_langchain_adapter_info(tmp_path):
    """The adapter reports its name and the model/provider it was built with."""
    adapter = LangchainAdapter(
        kiln_task=build_test_task(tmp_path),
        model_name="llama_3_1_8b",
        provider="ollama",
    )

    assert adapter.adapter_name() == "kiln_langchain_adapter"
    assert adapter.run_config.model_name == "llama_3_1_8b"
    assert adapter.run_config.model_provider_name == "ollama"
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
@pytest.mark.asyncio
async def test_langchain_adapter_with_cot(tmp_path):
    """A chain-of-thought prompt produces a CoT call followed by a final-answer call.

    Marked with ``pytest.mark.asyncio`` like the other async tests in this
    module, so the coroutine is awaited regardless of the configured
    pytest-asyncio mode.
    """
    task = build_test_task(tmp_path)
    task.output_json_schema = (
        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
    )
    lca = LangchainAdapter(
        kiln_task=task,
        model_name="llama_3_1_8b",
        provider="ollama",
        prompt_id="simple_chain_of_thought_prompt_builder",
    )

    # Mock the base model and its invoke method
    mock_base_model = MagicMock()
    mock_base_model.ainvoke = AsyncMock(
        return_value=AIMessage(content="Chain of thought reasoning...")
    )

    # Create a separate mock for self.model()
    mock_model_instance = MagicMock()
    mock_model_instance.ainvoke = AsyncMock(return_value={"parsed": {"count": 1}})

    # Mock the langchain_model_from function to return the base model
    mock_model_from = AsyncMock(return_value=mock_base_model)

    # Patch both the langchain_model_from function and self.model()
    with (
        patch.object(LangchainAdapter, "langchain_model_from", mock_model_from),
        patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
    ):
        response = await lca._run("test input")

    # First 3 messages are the same for both calls
    for invoke_args in [
        mock_base_model.ainvoke.call_args[0][0],
        mock_model_instance.ainvoke.call_args[0][0],
    ]:
        assert isinstance(
            invoke_args[0], SystemMessage
        )  # First message should be system prompt
        assert (
            "You are an assistant which performs math tasks provided in plain text."
            in invoke_args[0].content
        )
        assert isinstance(invoke_args[1], HumanMessage)
        assert "test input" in invoke_args[1].content
        assert isinstance(invoke_args[2], SystemMessage)
        assert "step by step" in invoke_args[2].content

    # the COT call should only have 3 messages; the final call has 5
    assert len(mock_base_model.ainvoke.call_args[0][0]) == 3
    assert len(mock_model_instance.ainvoke.call_args[0][0]) == 5

    # the final call should carry the COT content and the final instructions
    invoke_args = mock_model_instance.ainvoke.call_args[0][0]
    assert isinstance(invoke_args[3], AIMessage)
    assert "Chain of thought reasoning..." in invoke_args[3].content
    assert isinstance(invoke_args[4], HumanMessage)
    assert COT_FINAL_ANSWER_PROMPT in invoke_args[4].content

    assert (
        response.intermediate_outputs["chain_of_thought"]
        == "Chain of thought reasoning..."
    )
    assert response.output == {"count": 1}
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
@pytest.mark.parametrize(
    "structured_output_mode,expected_method",
    [
        (StructuredOutputMode.function_calling, "function_calling"),
        (StructuredOutputMode.json_mode, "json_mode"),
        (StructuredOutputMode.json_schema, "json_schema"),
        (StructuredOutputMode.json_instruction_and_object, "json_mode"),
        (StructuredOutputMode.default, None),
    ],
)
@pytest.mark.asyncio
async def test_get_structured_output_options(
    mock_adapter, structured_output_mode, expected_method
):
    """Each structured output mode maps to the expected ``method`` option.

    Marked with ``pytest.mark.asyncio`` like the other async tests in this
    module, so the coroutine is awaited regardless of the configured
    pytest-asyncio mode.
    """
    # Mock the provider response
    mock_provider = MagicMock()
    mock_provider.structured_output_mode = structured_output_mode

    # Mock adapter.model_provider()
    mock_adapter.model_provider = MagicMock(return_value=mock_provider)

    options = mock_adapter.get_structured_output_options("model_name", "provider")
    assert options.get("method") == expected_method
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
@pytest.mark.asyncio
async def test_langchain_model_from_provider_groq():
    """A groq provider yields a ChatGroq model carrying the requested model name."""
    groq_provider = KilnModelProvider(
        name=ModelProviderName.groq,
        provider_options={"model": "mixtral-8x7b"},
    )

    with patch(
        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
    ) as shared_config:
        # Supply the API key through the patched shared config.
        shared_config.return_value.groq_api_key = "test_key"
        built = await langchain_model_from_provider(groq_provider, "mixtral-8x7b")
        assert isinstance(built, ChatGroq)
        assert built.model_name == "mixtral-8x7b"
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
@pytest.mark.asyncio
async def test_langchain_model_from_provider_bedrock():
    """A bedrock provider yields ChatBedrockConverse and exports the AWS credentials.

    The test asserts on credentials placed in ``os.environ``, so the
    environment is wrapped in ``patch.dict`` to restore it afterwards and
    keep the fake keys from leaking into other tests.
    """
    provider = KilnModelProvider(
        name=ModelProviderName.amazon_bedrock,
        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
    )

    with (
        patch(
            "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
        ) as mock_config,
        # Snapshot os.environ; any keys set inside the block are rolled back on exit.
        patch.dict(os.environ),
    ):
        mock_config.return_value.bedrock_access_key = "test_access"
        mock_config.return_value.bedrock_secret_key = "test_secret"
        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
        assert isinstance(model, ChatBedrockConverse)
        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
@pytest.mark.asyncio
async def test_langchain_model_from_provider_fireworks():
    """A fireworks_ai provider yields a ChatFireworks model."""
    fireworks_provider = KilnModelProvider(
        name=ModelProviderName.fireworks_ai,
        provider_options={"model": "mixtral-8x7b"},
    )

    with patch(
        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
    ) as shared_config:
        # Supply the API key through the patched shared config.
        shared_config.return_value.fireworks_api_key = "test_key"
        built = await langchain_model_from_provider(fireworks_provider, "mixtral-8x7b")
        assert isinstance(built, ChatFireworks)
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
@pytest.mark.asyncio
async def test_langchain_model_from_provider_ollama():
    """An ollama provider yields a ChatOllama model when the model is installed."""
    ollama_provider = KilnModelProvider(
        name=ModelProviderName.ollama,
        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
    )

    fake_connection = MagicMock()
    # Stub out the connection lookup, the installed-model check and the base URL.
    connection_patch = patch(
        "kiln_ai.adapters.model_adapters.langchain_adapters.get_ollama_connection",
        return_value=AsyncMock(return_value=fake_connection),
    )
    installed_patch = patch(
        "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_model_installed",
        return_value=True,
    )
    base_url_patch = patch(
        "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_base_url",
        return_value="http://localhost:11434",
    )

    with connection_patch, installed_patch, base_url_patch:
        built = await langchain_model_from_provider(ollama_provider, "llama2")
        assert isinstance(built, ChatOllama)
        assert built.model == "llama2"
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
@pytest.mark.asyncio
async def test_langchain_model_from_provider_invalid():
    """An unrecognised provider name raises ValueError."""
    # model_construct is used to build the provider with a name that is not
    # a ModelProviderName member.
    bogus_provider = KilnModelProvider.model_construct(
        name="invalid_provider",
        provider_options={},
    )

    with pytest.raises(ValueError, match="Invalid model or provider"):
        await langchain_model_from_provider(bogus_provider, "test_model")
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
@pytest.mark.asyncio
async def test_langchain_adapter_model_caching(tmp_path):
    """model() hands back the supplied custom model, and the same object on repeat calls."""
    kiln_task = build_test_task(tmp_path)
    groq_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
    lca = LangchainAdapter(kiln_task=kiln_task, custom_model=groq_model)

    # The first lookup returns the custom model itself.
    first = await lca.model()
    assert first is groq_model

    # A second lookup returns that very same instance.
    second = await lca.model()
    assert second is first
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
@pytest.mark.asyncio
async def test_langchain_adapter_model_structured_output(tmp_path):
    """model() wraps the base model via with_structured_output using the task schema."""
    kiln_task = build_test_task(tmp_path)
    kiln_task.output_json_schema = """
    {
        "type": "object",
        "properties": {
            "count": {"type": "integer"}
        }
    }
    """

    base_model = MagicMock()
    base_model.with_structured_output = MagicMock(return_value="structured_model")

    lca = LangchainAdapter(
        kiln_task=kiln_task,
        model_name="test_model",
        provider="ollama",
    )
    lca.get_structured_output_options = MagicMock(return_value={"option1": "value1"})
    lca.langchain_model_from = AsyncMock(return_value=base_model)

    result = await lca.model()

    # The schema passed along gains a title and description, and the extra
    # structured-output options are forwarded as keyword arguments.
    expected_schema = {
        "type": "object",
        "properties": {"count": {"type": "integer"}},
        "title": "task_response",
        "description": "A response from the task",
    }
    base_model.with_structured_output.assert_called_once_with(
        expected_schema,
        include_raw=True,
        option1="value1",
    )
    assert result == "structured_model"
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
@pytest.mark.asyncio
async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
    """model() raises ValueError when the base model lacks with_structured_output."""
    kiln_task = build_test_task(tmp_path)
    kiln_task.output_json_schema = (
        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
    )

    # Deleting the attribute makes the MagicMock raise AttributeError for it,
    # so the model looks like one without structured output support.
    base_model = MagicMock()
    del base_model.with_structured_output

    lca = LangchainAdapter(
        kiln_task=kiln_task,
        model_name="test_model",
        provider="ollama",
    )
    lca.langchain_model_from = AsyncMock(return_value=base_model)

    with pytest.raises(ValueError, match="does not support structured output"):
        await lca.model()
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
import pytest
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
@pytest.mark.parametrize(
    "provider_name",
    [
        ModelProviderName.openai,
        ModelProviderName.openai_compatible,
        ModelProviderName.openrouter,
    ],
)
@pytest.mark.asyncio
async def test_langchain_model_from_provider_unsupported_providers(provider_name):
    """langchain_model_from_provider raises ValueError for each listed provider."""
    unsupported = KilnModelProvider(
        name=provider_name,
        provider_options={},
        structured_output_mode="default",
    )

    # Unsupported providers should raise an error.
    with pytest.raises(ValueError):
        await langchain_model_from_provider(unsupported, "test-model")
|