kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff shows the content of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two released versions.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +163 -39
- kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +270 -0
- kiln_ai/adapters/eval/g_eval.py +368 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +325 -0
- kiln_ai/adapters/eval/test_eval_runner.py +641 -0
- kiln_ai/adapters/eval/test_g_eval.py +498 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +758 -163
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
- kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
- kiln_ai/adapters/ollama_tools.py +3 -3
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/prompt_builders.py +80 -42
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/repair_task.py +9 -21
- kiln_ai/adapters/repair/test_repair_task.py +6 -6
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +26 -29
- kiln_ai/adapters/test_generate_docs.py +4 -4
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +47 -33
- kiln_ai/adapters/test_prompt_builders.py +91 -31
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/__init__.py +50 -952
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +60 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +7 -1
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +328 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +19 -11
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +32 -8
- kiln_ai/datamodel/test_datasource.py +22 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +9 -13
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_models.py +2 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +43 -1
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
- kiln_ai-0.13.0.dist-info/RECORD +103 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
- kiln_ai-0.11.1.dist-info/RECORD +0 -76
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
--- a/kiln_ai/adapters/model_adapters/openai_model_adapter.py
+++ /dev/null
@@ -1,246 +0,0 @@
-from typing import Any, Dict
-
-from openai import AsyncOpenAI
-from openai.types.chat import (
-    ChatCompletion,
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionUserMessageParam,
-)
-
-import kiln_ai.datamodel as datamodel
-from kiln_ai.adapters.ml_model_list import StructuredOutputMode
-from kiln_ai.adapters.model_adapters.base_adapter import (
-    COT_FINAL_ANSWER_PROMPT,
-    AdapterInfo,
-    BaseAdapter,
-    BasePromptBuilder,
-    RunOutput,
-)
-from kiln_ai.adapters.model_adapters.openai_compatible_config import (
-    OpenAICompatibleConfig,
-)
-from kiln_ai.adapters.parsers.json_parser import parse_json_string
-from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
-
-
-class OpenAICompatibleAdapter(BaseAdapter):
-    def __init__(
-        self,
-        config: OpenAICompatibleConfig,
-        kiln_task: datamodel.Task,
-        prompt_builder: BasePromptBuilder | None = None,
-        tags: list[str] | None = None,
-    ):
-        self.config = config
-        self.client = AsyncOpenAI(
-            api_key=config.api_key,
-            base_url=config.base_url,
-            default_headers=config.default_headers,
-        )
-
-        super().__init__(
-            kiln_task,
-            model_name=config.model_name,
-            model_provider_name=config.provider_name,
-            prompt_builder=prompt_builder,
-            tags=tags,
-        )
-
-    async def _run(self, input: Dict | str) -> RunOutput:
-        provider = self.model_provider()
-        intermediate_outputs: dict[str, str] = {}
-        prompt = self.build_prompt()
-        user_msg = self.prompt_builder.build_user_message(input)
-        messages = [
-            ChatCompletionSystemMessageParam(role="system", content=prompt),
-            ChatCompletionUserMessageParam(role="user", content=user_msg),
-        ]
-
-        run_strategy, cot_prompt = self.run_strategy()
-
-        if run_strategy == "cot_as_message":
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_as_message strategy")
-            messages.append(
-                ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
-            )
-        elif run_strategy == "cot_two_call":
-            if not cot_prompt:
-                raise ValueError("cot_prompt is required for cot_two_call strategy")
-            messages.append(
-                ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
-            )
-
-            # First call for chain of thought
-            cot_response = await self.client.chat.completions.create(
-                model=provider.provider_options["model"],
-                messages=messages,
-            )
-            cot_content = cot_response.choices[0].message.content
-            if cot_content is not None:
-                intermediate_outputs["chain_of_thought"] = cot_content
-
-            messages.extend(
-                [
-                    ChatCompletionAssistantMessageParam(
-                        role="assistant", content=cot_content
-                    ),
-                    ChatCompletionUserMessageParam(
-                        role="user",
-                        content=COT_FINAL_ANSWER_PROMPT,
-                    ),
-                ]
-            )
-
-        # OpenRouter specific options for reasoning models
-        extra_body = {}
-        require_or_reasoning = (
-            self.config.openrouter_style_reasoning and provider.reasoning_capable
-        )
-        if require_or_reasoning:
-            extra_body["include_reasoning"] = True
-            # Filter to providers that support the reasoning parameter
-            extra_body["provider"] = {
-                "require_parameters": True,
-                # Ugly to have these here, but big range of quality of R1 providers
-                "order": ["Fireworks", "Together"],
-                # fp8 quants are awful
-                "ignore": ["DeepInfra"],
-            }
-
-        # Main completion call
-        response_format_options = await self.response_format_options()
-        response = await self.client.chat.completions.create(
-            model=provider.provider_options["model"],
-            messages=messages,
-            extra_body=extra_body,
-            **response_format_options,
-        )
-
-        if not isinstance(response, ChatCompletion):
-            raise RuntimeError(
-                f"Expected ChatCompletion response, got {type(response)}."
-            )
-
-        if hasattr(response, "error") and response.error:  # pyright: ignore
-            raise RuntimeError(
-                f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}.\nError: {response.error}"  # pyright: ignore
-            )
-        if not response.choices or len(response.choices) == 0:
-            raise RuntimeError(
-                "No message content returned in the response from OpenAI compatible API"
-            )
-
-        message = response.choices[0].message
-
-        # Save reasoning if it exists (OpenRouter specific format)
-        if require_or_reasoning:
-            if (
-                hasattr(message, "reasoning") and message.reasoning  # pyright: ignore
-            ):
-                intermediate_outputs["reasoning"] = message.reasoning  # pyright: ignore
-            else:
-                raise RuntimeError(
-                    "Reasoning is required for this model, but no reasoning was returned from OpenRouter."
-                )
-
-        # the string content of the response
-        response_content = message.content
-
-        # Fallback: Use args of first tool call to task_response if it exists
-        if not response_content and message.tool_calls:
-            tool_call = next(
-                (
-                    tool_call
-                    for tool_call in message.tool_calls
-                    if tool_call.function.name == "task_response"
-                ),
-                None,
-            )
-            if tool_call:
-                response_content = tool_call.function.arguments
-
-        if not isinstance(response_content, str):
-            raise RuntimeError(f"response is not a string: {response_content}")
-
-        if self.has_structured_output():
-            structured_response = parse_json_string(response_content)
-            return RunOutput(
-                output=structured_response,
-                intermediate_outputs=intermediate_outputs,
-            )
-
-        return RunOutput(
-            output=response_content,
-            intermediate_outputs=intermediate_outputs,
-        )
-
-    def adapter_info(self) -> AdapterInfo:
-        return AdapterInfo(
-            model_name=self.model_name,
-            model_provider=self.model_provider_name,
-            adapter_name="kiln_openai_compatible_adapter",
-            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
-            prompt_id=self.prompt_builder.prompt_id(),
-        )
-
-    async def response_format_options(self) -> dict[str, Any]:
-        # Unstructured if task isn't structured
-        if not self.has_structured_output():
-            return {}
-
-        provider = self.model_provider()
-        match provider.structured_output_mode:
-            case StructuredOutputMode.json_mode:
-                return {"response_format": {"type": "json_object"}}
-            case StructuredOutputMode.json_schema:
-                output_schema = self.kiln_task.output_schema()
-                return {
-                    "response_format": {
-                        "type": "json_schema",
-                        "json_schema": {
-                            "name": "task_response",
-                            "schema": output_schema,
-                        },
-                    }
-                }
-            case StructuredOutputMode.function_calling:
-                return self.tool_call_params()
-            case StructuredOutputMode.json_instructions:
-                # JSON done via instructions in prompt, not the API response format. Do not ask for json_object (see option below).
-                return {}
-            case StructuredOutputMode.json_instruction_and_object:
-                # We set response_format to json_object and also set json instructions in the prompt
-                return {"response_format": {"type": "json_object"}}
-            case StructuredOutputMode.default:
-                # Default to function calling -- it's older than the other modes. Higher compatibility.
-                return self.tool_call_params()
-            case _:
-                raise_exhaustive_enum_error(provider.structured_output_mode)
-
-    def tool_call_params(self) -> dict[str, Any]:
-        # Add additional_properties: false to the schema (OpenAI requires this for some models)
-        output_schema = self.kiln_task.output_schema()
-        if not isinstance(output_schema, dict):
-            raise ValueError(
-                "Invalid output schema for this task. Can not use tool calls."
-            )
-        output_schema["additionalProperties"] = False
-
-        return {
-            "tools": [
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "task_response",
-                        "parameters": output_schema,
-                        "strict": True,
-                    },
-                }
-            ],
-            "tool_choice": {
-                "type": "function",
-                "function": {"name": "task_response"},
-            },
-        }
--- a/kiln_ai/adapters/model_adapters/test_langchain_adapter.py
+++ /dev/null
@@ -1,350 +0,0 @@
-import os
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from langchain_aws import ChatBedrockConverse
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
-from langchain_fireworks import ChatFireworks
-from langchain_groq import ChatGroq
-from langchain_ollama import ChatOllama
-
-from kiln_ai.adapters.ml_model_list import (
-    KilnModelProvider,
-    ModelProviderName,
-    StructuredOutputMode,
-)
-from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
-from kiln_ai.adapters.model_adapters.langchain_adapters import (
-    LangchainAdapter,
-    langchain_model_from_provider,
-)
-from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
-from kiln_ai.adapters.test_prompt_adaptors import build_test_task
-
-
-@pytest.fixture
-def mock_adapter(tmp_path):
-    return LangchainAdapter(
-        kiln_task=build_test_task(tmp_path),
-        model_name="llama_3_1_8b",
-        provider="ollama",
-    )
-
-
-def test_langchain_adapter_munge_response(mock_adapter):
-    # Mistral Large tool calling format is a bit different
-    response = {
-        "name": "task_response",
-        "arguments": {
-            "setup": "Why did the cow join a band?",
-            "punchline": "Because she wanted to be a moo-sician!",
-        },
-    }
-    munged = mock_adapter._munge_response(response)
-    assert munged["setup"] == "Why did the cow join a band?"
-    assert munged["punchline"] == "Because she wanted to be a moo-sician!"
-
-    # non mistral format should continue to work
-    munged = mock_adapter._munge_response(response["arguments"])
-    assert munged["setup"] == "Why did the cow join a band?"
-    assert munged["punchline"] == "Because she wanted to be a moo-sician!"
-
-
-def test_langchain_adapter_infer_model_name(tmp_path):
-    task = build_test_task(tmp_path)
-    custom = ChatGroq(model="llama-3.1-8b-instant", groq_api_key="test")
-
-    lca = LangchainAdapter(kiln_task=task, custom_model=custom)
-
-    model_info = lca.adapter_info()
-    assert model_info.model_name == "custom.langchain:llama-3.1-8b-instant"
-    assert model_info.model_provider == "custom.langchain:ChatGroq"
-
-
-def test_langchain_adapter_info(tmp_path):
-    task = build_test_task(tmp_path)
-
-    lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
-
-    model_info = lca.adapter_info()
-    assert model_info.adapter_name == "kiln_langchain_adapter"
-    assert model_info.model_name == "llama_3_1_8b"
-    assert model_info.model_provider == "ollama"
-
-
-async def test_langchain_adapter_with_cot(tmp_path):
-    task = build_test_task(tmp_path)
-    task.output_json_schema = (
-        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
-    )
-    lca = LangchainAdapter(
-        kiln_task=task,
-        model_name="llama_3_1_8b",
-        provider="ollama",
-        prompt_builder=SimpleChainOfThoughtPromptBuilder(task),
-    )
-
-    # Mock the base model and its invoke method
-    mock_base_model = MagicMock()
-    mock_base_model.ainvoke = AsyncMock(
-        return_value=AIMessage(content="Chain of thought reasoning...")
-    )
-
-    # Create a separate mock for self.model()
-    mock_model_instance = MagicMock()
-    mock_model_instance.ainvoke = AsyncMock(return_value={"parsed": {"count": 1}})
-
-    # Mock the langchain_model_from function to return the base model
-    mock_model_from = AsyncMock(return_value=mock_base_model)
-
-    # Patch both the langchain_model_from function and self.model()
-    with (
-        patch.object(LangchainAdapter, "langchain_model_from", mock_model_from),
-        patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
-    ):
-        response = await lca._run("test input")
-
-        # First 3 messages are the same for both calls
-        for invoke_args in [
-            mock_base_model.ainvoke.call_args[0][0],
-            mock_model_instance.ainvoke.call_args[0][0],
-        ]:
-            assert isinstance(
-                invoke_args[0], SystemMessage
-            )  # First message should be system prompt
-            assert (
-                "You are an assistant which performs math tasks provided in plain text."
-                in invoke_args[0].content
-            )
-            assert isinstance(invoke_args[1], HumanMessage)
-            assert "test input" in invoke_args[1].content
-            assert isinstance(invoke_args[2], SystemMessage)
-            assert "step by step" in invoke_args[2].content
-
-        # the COT should only have 3 messages
-        assert len(mock_base_model.ainvoke.call_args[0][0]) == 3
-        assert len(mock_model_instance.ainvoke.call_args[0][0]) == 5
-
-        # the final response should have the COT content and the final instructions
-        invoke_args = mock_model_instance.ainvoke.call_args[0][0]
-        assert isinstance(invoke_args[3], AIMessage)
-        assert "Chain of thought reasoning..." in invoke_args[3].content
-        assert isinstance(invoke_args[4], HumanMessage)
-        assert COT_FINAL_ANSWER_PROMPT in invoke_args[4].content
-
-    assert (
-        response.intermediate_outputs["chain_of_thought"]
-        == "Chain of thought reasoning..."
-    )
-    assert response.output == {"count": 1}
-
-
-@pytest.mark.parametrize(
-    "structured_output_mode,expected_method",
-    [
-        (StructuredOutputMode.function_calling, "function_calling"),
-        (StructuredOutputMode.json_mode, "json_mode"),
-        (StructuredOutputMode.json_schema, "json_schema"),
-        (StructuredOutputMode.json_instruction_and_object, "json_mode"),
-        (StructuredOutputMode.default, None),
-    ],
-)
-async def test_get_structured_output_options(
-    mock_adapter, structured_output_mode, expected_method
-):
-    # Mock the provider response
-    mock_provider = MagicMock()
-    mock_provider.structured_output_mode = structured_output_mode
-
-    # Mock adapter.model_provider()
-    mock_adapter.model_provider = MagicMock(return_value=mock_provider)
-
-    options = mock_adapter.get_structured_output_options("model_name", "provider")
-    assert options.get("method") == expected_method
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_groq():
-    provider = KilnModelProvider(
-        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.groq_api_key = "test_key"
-        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
-        assert isinstance(model, ChatGroq)
-        assert model.model_name == "mixtral-8x7b"
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_bedrock():
-    provider = KilnModelProvider(
-        name=ModelProviderName.amazon_bedrock,
-        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.bedrock_access_key = "test_access"
-        mock_config.return_value.bedrock_secret_key = "test_secret"
-        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
-        assert isinstance(model, ChatBedrockConverse)
-        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
-        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_fireworks():
-    provider = KilnModelProvider(
-        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.fireworks_api_key = "test_key"
-        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
-        assert isinstance(model, ChatFireworks)
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_ollama():
-    provider = KilnModelProvider(
-        name=ModelProviderName.ollama,
-        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
-    )
-
-    mock_connection = MagicMock()
-    with (
-        patch(
-            "kiln_ai.adapters.model_adapters.langchain_adapters.get_ollama_connection",
-            return_value=AsyncMock(return_value=mock_connection),
-        ),
-        patch(
-            "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_model_installed",
-            return_value=True,
-        ),
-        patch(
-            "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_base_url",
-            return_value="http://localhost:11434",
-        ),
-    ):
-        model = await langchain_model_from_provider(provider, "llama2")
-        assert isinstance(model, ChatOllama)
-        assert model.model == "llama2"
-
-
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_invalid():
-    provider = KilnModelProvider.model_construct(
-        name="invalid_provider", provider_options={}
-    )
-
-    with pytest.raises(ValueError, match="Invalid model or provider"):
-        await langchain_model_from_provider(provider, "test_model")
-
-
-@pytest.mark.asyncio
-async def test_langchain_adapter_model_caching(tmp_path):
-    task = build_test_task(tmp_path)
-    custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
-
-    adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
-
-    # First call should return the cached model
-    model1 = await adapter.model()
-    assert model1 is custom_model
-
-    # Second call should return the same cached instance
-    model2 = await adapter.model()
-    assert model2 is model1
-
-
-@pytest.mark.asyncio
-async def test_langchain_adapter_model_structured_output(tmp_path):
-    task = build_test_task(tmp_path)
-    task.output_json_schema = """
-    {
-        "type": "object",
-        "properties": {
-            "count": {"type": "integer"}
-        }
-    }
-    """
-
-    mock_model = MagicMock()
-    mock_model.with_structured_output = MagicMock(return_value="structured_model")
-
-    adapter = LangchainAdapter(
-        kiln_task=task, model_name="test_model", provider="ollama"
-    )
-    adapter.get_structured_output_options = MagicMock(
-        return_value={"option1": "value1"}
-    )
-    adapter.langchain_model_from = AsyncMock(return_value=mock_model)
-
-    model = await adapter.model()
-
-    # Verify the model was configured with structured output
-    mock_model.with_structured_output.assert_called_once_with(
-        {
-            "type": "object",
-            "properties": {"count": {"type": "integer"}},
-            "title": "task_response",
-            "description": "A response from the task",
-        },
-        include_raw=True,
-        option1="value1",
-    )
-    assert model == "structured_model"
-
-
-@pytest.mark.asyncio
-async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
-    task = build_test_task(tmp_path)
-    task.output_json_schema = (
-        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
-    )
-
-    mock_model = MagicMock()
-    # Remove with_structured_output method
-    del mock_model.with_structured_output
-
-    adapter = LangchainAdapter(
-        kiln_task=task, model_name="test_model", provider="ollama"
-    )
-    adapter.langchain_model_from = AsyncMock(return_value=mock_model)
-
-    with pytest.raises(ValueError, match="does not support structured output"):
-        await adapter.model()
-
-
-import pytest
-
-from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
-from kiln_ai.adapters.model_adapters.langchain_adapters import (
-    langchain_model_from_provider,
-)
-
-
-@pytest.mark.parametrize(
-    "provider_name",
-    [
-        (ModelProviderName.openai),
-        (ModelProviderName.openai_compatible),
-        (ModelProviderName.openrouter),
-    ],
-)
-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_unsupported_providers(provider_name):
-    # Arrange
-    provider = KilnModelProvider(
-        name=provider_name, provider_options={}, structured_output_mode="default"
-    )
-
-    # Assert unsupported providers raise an error
-    with pytest.raises(ValueError):
-        await langchain_model_from_provider(provider, "test-model")