kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff compares the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release. This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +81 -10
- kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +434 -93
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +193 -49
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +18 -19
- kiln_ai/adapters/repair/test_repair_task.py +7 -7
- kiln_ai/adapters/run_output.py +11 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +25 -18
- kiln_ai/adapters/test_prompt_builders.py +265 -44
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +51 -772
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +14 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +80 -2
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +127 -6
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +34 -17
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +131 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai/adapters/base_adapter.py +0 -191
- kiln_ai/adapters/langchain_adapters.py +0 -256
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/__init__.py
@@ -0,0 +1,18 @@
+"""
+# Model Adapters
+
+Model adapters are used to call AI models, like Ollama, OpenAI, etc.
+
+"""
+
+from . import (
+    base_adapter,
+    langchain_adapters,
+    openai_model_adapter,
+)
+
+__all__ = [
+    "base_adapter",
+    "langchain_adapters",
+    "openai_model_adapter",
+]
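Note the import-path change implied by this hunk together with the removal of kiln_ai/adapters/base_adapter.py and kiln_ai/adapters/langchain_adapters.py in the file list above. A minimal migration sketch (not part of the diff, and assuming the same class names were importable from the old flat modules in 0.8.1):

# Import-path migration implied by this diff; class names assumed unchanged
# between 0.8.1 and 0.12.0.
#
# 0.8.1 (modules removed in this release):
#   from kiln_ai.adapters.base_adapter import BaseAdapter
#   from kiln_ai.adapters.langchain_adapters import LangchainAdapter
#
# 0.12.0 (new model_adapters subpackage):
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter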
kiln_ai/adapters/model_adapters/base_adapter.py
@@ -0,0 +1,250 @@
+import json
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+from typing import Dict, Literal, Tuple
+
+from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
+from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
+from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
+from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+from kiln_ai.adapters.run_output import RunOutput
+from kiln_ai.datamodel import (
+    DataSource,
+    DataSourceType,
+    Task,
+    TaskOutput,
+    TaskRun,
+)
+from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.utils.config import Config
+
+
+@dataclass
+class AdapterConfig:
+    """
+    An adapter config is config options that do NOT impact the output of the model.
+
+    For example: if it's saved, or if we request additional data like logprobs.
+    """
+
+    allow_saving: bool = True
+    top_logprobs: int | None = None
+    default_tags: list[str] | None = None
+
+
+COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."
+
+
+class BaseAdapter(metaclass=ABCMeta):
+    """Base class for AI model adapters that handle task execution.
+
+    This abstract class provides the foundation for implementing model-specific adapters
+    that can process tasks with structured or unstructured inputs/outputs. It handles
+    input/output validation, prompt building, and run tracking.
+
+    Attributes:
+        prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
+        kiln_task (Task): The task configuration and metadata
+        output_schema (dict | None): JSON schema for validating structured outputs
+        input_schema (dict | None): JSON schema for validating structured inputs
+    """
+
+    def __init__(
+        self,
+        run_config: RunConfig,
+        config: AdapterConfig | None = None,
+    ):
+        self.run_config = run_config
+        self.prompt_builder = prompt_builder_from_id(
+            run_config.prompt_id, run_config.task
+        )
+        self._model_provider: KilnModelProvider | None = None
+
+        self.output_schema = self.task().output_json_schema
+        self.input_schema = self.task().input_json_schema
+        self.base_adapter_config = config or AdapterConfig()
+
+    def task(self) -> Task:
+        return self.run_config.task
+
+    def model_provider(self) -> KilnModelProvider:
+        """
+        Lazy load the model provider for this adapter.
+        """
+        if self._model_provider is not None:
+            return self._model_provider
+        if not self.run_config.model_name or not self.run_config.model_provider_name:
+            raise ValueError("model_name and model_provider_name must be provided")
+        self._model_provider = kiln_model_provider_from(
+            self.run_config.model_name, self.run_config.model_provider_name
+        )
+        if not self._model_provider:
+            raise ValueError(
+                f"model_provider_name {self.run_config.model_provider_name} not found for model {self.run_config.model_name}"
+            )
+        return self._model_provider
+
+    async def invoke_returning_raw(
+        self,
+        input: Dict | str,
+        input_source: DataSource | None = None,
+    ) -> Dict | str:
+        result = await self.invoke(input, input_source)
+        if self.task().output_json_schema is None:
+            return result.output.output
+        else:
+            return json.loads(result.output.output)
+
+    async def invoke(
+        self,
+        input: Dict | str,
+        input_source: DataSource | None = None,
+    ) -> TaskRun:
+        run_output, _ = await self.invoke_returning_run_output(input, input_source)
+        return run_output
+
+    async def invoke_returning_run_output(
+        self,
+        input: Dict | str,
+        input_source: DataSource | None = None,
+    ) -> Tuple[TaskRun, RunOutput]:
+        # validate input
+        if self.input_schema is not None:
+            if not isinstance(input, dict):
+                raise ValueError(f"structured input is not a dict: {input}")
+            validate_schema(input, self.input_schema)
+
+        # Run
+        run_output = await self._run(input)
+
+        # Parse
+        provider = self.model_provider()
+        parser = model_parser_from_id(provider.parser)(
+            structured_output=self.has_structured_output()
+        )
+        parsed_output = parser.parse_output(original_output=run_output)
+
+        # validate output
+        if self.output_schema is not None:
+            if not isinstance(parsed_output.output, dict):
+                raise RuntimeError(
+                    f"structured response is not a dict: {parsed_output.output}"
+                )
+            validate_schema(parsed_output.output, self.output_schema)
+        else:
+            if not isinstance(parsed_output.output, str):
+                raise RuntimeError(
+                    f"response is not a string for non-structured task: {parsed_output.output}"
+                )
+
+        # Generate the run and output
+        run = self.generate_run(input, input_source, parsed_output)
+
+        # Save the run if configured to do so, and we have a path to save to
+        if (
+            self.base_adapter_config.allow_saving
+            and Config.shared().autosave_runs
+            and self.task().path is not None
+        ):
+            run.save_to_file()
+        else:
+            # Clear the ID to indicate it's not persisted
+            run.id = None
+
+        return run, run_output
+
+    def has_structured_output(self) -> bool:
+        return self.output_schema is not None
+
+    @abstractmethod
+    def adapter_name(self) -> str:
+        pass
+
+    @abstractmethod
+    async def _run(self, input: Dict | str) -> RunOutput:
+        pass
+
+    def build_prompt(self) -> str:
+        # The prompt builder needs to know if we want to inject formatting instructions
+        provider = self.model_provider()
+        add_json_instructions = self.has_structured_output() and (
+            provider.structured_output_mode == StructuredOutputMode.json_instructions
+            or provider.structured_output_mode
+            == StructuredOutputMode.json_instruction_and_object
+        )
+
+        return self.prompt_builder.build_prompt(
+            include_json_instructions=add_json_instructions
+        )
+
+    def run_strategy(
+        self,
+    ) -> Tuple[Literal["cot_as_message", "cot_two_call", "basic"], str | None]:
+        # Determine the run strategy for COT prompting. 3 options:
+        # 1. "Thinking" LLM designed to output thinking in a structured format plus a COT prompt: we make 1 call to the LLM, which outputs thinking in a structured format. We include the thinking instructions as a message.
+        # 2. Normal LLM with COT prompt: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call. It also separates the thinking from the final response.
+        # 3. Non chain of thought: we make 1 call to the LLM, with no COT prompt.
+        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
+        reasoning_capable = self.model_provider().reasoning_capable
+
+        if cot_prompt and reasoning_capable:
+            # 1: "Thinking" LLM designed to output thinking in a structured format
+            # A simple message with the COT prompt appended to the message list is sufficient
+            return "cot_as_message", cot_prompt
+        elif cot_prompt:
+            # 2: Unstructured output with COT
+            # Two calls to separate the thinking from the final response
+            return "cot_two_call", cot_prompt
+        else:
+            return "basic", None
+
+    # create a run and task output
+    def generate_run(
+        self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
+    ) -> TaskRun:
+        # Convert input and output to JSON strings if they are dictionaries
+        input_str = (
+            json.dumps(input, ensure_ascii=False) if isinstance(input, dict) else input
+        )
+        output_str = (
+            json.dumps(run_output.output, ensure_ascii=False)
+            if isinstance(run_output.output, dict)
+            else run_output.output
+        )
+
+        # If no input source is provided, use the human data source
+        if input_source is None:
+            input_source = DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": Config.shared().user_id},
+            )
+
+        new_task_run = TaskRun(
+            parent=self.task(),
+            input=input_str,
+            input_source=input_source,
+            output=TaskOutput(
+                output=output_str,
+                # Synthetic since an adapter, not a human, is creating this
+                source=DataSource(
+                    type=DataSourceType.synthetic,
+                    properties=self._properties_for_task_output(),
+                ),
+            ),
+            intermediate_outputs=run_output.intermediate_outputs,
+            tags=self.base_adapter_config.default_tags or [],
+        )
+
+        return new_task_run
+
+    def _properties_for_task_output(self) -> Dict[str, str | int | float]:
+        props = {}
+
+        # adapter info
+        props["adapter_name"] = self.adapter_name()
+        props["model_name"] = self.run_config.model_name
+        props["model_provider"] = self.run_config.model_provider_name
+        props["prompt_id"] = self.run_config.prompt_id
+
+        return props
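To show how this new abstract surface is intended to be filled in, here is a hedged sketch (not from the package) of a minimal subclass; EchoAdapter, its echo behavior, and the commented usage are illustrative assumptions built only from the signatures in this hunk.

# Hedged sketch: a minimal BaseAdapter subclass built only from the signatures
# in this hunk. EchoAdapter and the commented usage are illustrative, not part
# of the package.
from typing import Dict

from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
from kiln_ai.datamodel.task import RunConfig


class EchoAdapter(BaseAdapter):
    """Toy adapter that returns the raw input. BaseAdapter wraps this _run call
    with input/output validation, parsing, and TaskRun creation/saving."""

    def adapter_name(self) -> str:
        return "echo_adapter"

    async def _run(self, input: Dict | str) -> RunOutput:
        # A real adapter would call a model here.
        return RunOutput(output=str(input), intermediate_outputs={})


# Usage, assuming an existing kiln Task `task` and a model/provider pair that
# is registered in ml_model_list:
#   adapter = EchoAdapter(RunConfig(task=task, model_name=..., model_provider_name=...,
#                                   prompt_id=datamodel.PromptGenerators.SIMPLE))
#   task_run = await adapter.invoke("hello")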
kiln_ai/adapters/model_adapters/langchain_adapters.py
@@ -0,0 +1,309 @@
+import os
+from typing import Any, Dict
+
+from langchain_aws import ChatBedrockConverse
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages.base import BaseMessage
+from langchain_core.runnables import Runnable
+from langchain_fireworks import ChatFireworks
+from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
+from pydantic import BaseModel
+
+import kiln_ai.datamodel as datamodel
+from kiln_ai.adapters.ml_model_list import (
+    KilnModelProvider,
+    ModelProviderName,
+    StructuredOutputMode,
+)
+from kiln_ai.adapters.model_adapters.base_adapter import (
+    COT_FINAL_ANSWER_PROMPT,
+    AdapterConfig,
+    BaseAdapter,
+    RunOutput,
+)
+from kiln_ai.adapters.ollama_tools import (
+    get_ollama_connection,
+    ollama_base_url,
+    ollama_model_installed,
+)
+from kiln_ai.datamodel import PromptId
+from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.utils.config import Config
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
+
+
+class LangchainAdapter(BaseAdapter):
+    _model: LangChainModelType | None = None
+
+    def __init__(
+        self,
+        kiln_task: datamodel.Task,
+        custom_model: BaseChatModel | None = None,
+        model_name: str | None = None,
+        provider: str | None = None,
+        prompt_id: PromptId | None = None,
+        base_adapter_config: AdapterConfig | None = None,
+    ):
+        if custom_model is not None:
+            self._model = custom_model
+
+            # Attempt to infer model provider and name from custom model
+            if provider is None:
+                provider = "custom.langchain:" + custom_model.__class__.__name__
+
+            if model_name is None:
+                model_name = "custom.langchain:unknown_model"
+                if hasattr(custom_model, "model_name") and isinstance(
+                    getattr(custom_model, "model_name"), str
+                ):
+                    model_name = "custom.langchain:" + getattr(
+                        custom_model, "model_name"
+                    )
+                if hasattr(custom_model, "model") and isinstance(
+                    getattr(custom_model, "model"), str
+                ):
+                    model_name = "custom.langchain:" + getattr(custom_model, "model")
+        elif model_name is not None:
+            # default provider name if not provided
+            provider = provider or "custom.langchain.default_provider"
+        else:
+            raise ValueError(
+                "model_name and provider must be provided if custom_model is not provided"
+            )
+
+        if model_name is None:
+            raise ValueError("model_name must be provided")
+
+        run_config = RunConfig(
+            task=kiln_task,
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_id=prompt_id or datamodel.PromptGenerators.SIMPLE,
+        )
+
+        super().__init__(
+            run_config=run_config,
+            config=base_adapter_config,
+        )
+
+    async def model(self) -> LangChainModelType:
+        # cached model
+        if self._model:
+            return self._model
+
+        self._model = await self.langchain_model_from()
+
+        # Decide if we want to use Langchain's structured output:
+        # 1. Only for structured tasks
+        # 2. Only if the provider's mode isn't json_instructions (only mode that doesn't use an API option for structured output capabilities)
+        provider = self.model_provider()
+        use_lc_structured_output = (
+            self.has_structured_output()
+            and provider.structured_output_mode
+            != StructuredOutputMode.json_instructions
+        )
+
+        if use_lc_structured_output:
+            if not hasattr(self._model, "with_structured_output") or not callable(
+                getattr(self._model, "with_structured_output")
+            ):
+                raise ValueError(
+                    f"model {self._model} does not support structured output, cannot use output_json_schema"
+                )
+            # Langchain expects title/description to be at top level, on top of json schema
+            output_schema = self.task().output_schema()
+            if output_schema is None:
+                raise ValueError(
+                    f"output_json_schema is not valid json: {self.task().output_json_schema}"
+                )
+            output_schema["title"] = "task_response"
+            output_schema["description"] = "A response from the task"
+            with_structured_output_options = self.get_structured_output_options(
+                self.run_config.model_name, self.run_config.model_provider_name
+            )
+            self._model = self._model.with_structured_output(
+                output_schema,
+                include_raw=True,
+                **with_structured_output_options,
+            )
+        return self._model
+
+    async def _run(self, input: Dict | str) -> RunOutput:
+        if self.base_adapter_config.top_logprobs is not None:
+            raise ValueError(
+                "Kiln's Langchain adapter does not support logprobs/top_logprobs. Select a model from an OpenAI compatible provider (openai, openrouter, etc) instead."
+            )
+
+        provider = self.model_provider()
+        model = await self.model()
+        chain = model
+        intermediate_outputs = {}
+
+        prompt = self.build_prompt()
+        user_msg = self.prompt_builder.build_user_message(input)
+        messages = [
+            SystemMessage(content=prompt),
+            HumanMessage(content=user_msg),
+        ]
+
+        run_strategy, cot_prompt = self.run_strategy()
+
+        if run_strategy == "cot_as_message":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_as_message strategy")
+            messages.append(SystemMessage(content=cot_prompt))
+        elif run_strategy == "cot_two_call":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_two_call strategy")
+            messages.append(
+                SystemMessage(content=cot_prompt),
+            )
+
+            # Base model (without structured output) used for COT message
+            base_model = await self.langchain_model_from()
+
+            cot_messages = [*messages]
+            cot_response = await base_model.ainvoke(cot_messages)
+            intermediate_outputs["chain_of_thought"] = cot_response.content
+            messages.append(AIMessage(content=cot_response.content))
+            messages.append(HumanMessage(content=COT_FINAL_ANSWER_PROMPT))
+
+        response = await chain.ainvoke(messages)
+
+        # Langchain may have already parsed the response into structured output, so use that if available.
+        # However, a plain string may still be fixed at the parsing layer, so not being structured isn't a critical failure (yet)
+        if (
+            self.has_structured_output()
+            and isinstance(response, dict)
+            and "parsed" in response
+            and isinstance(response["parsed"], dict)
+        ):
+            structured_response = response["parsed"]
+            return RunOutput(
+                output=self._munge_response(structured_response),
+                intermediate_outputs=intermediate_outputs,
+            )
+
+        if not isinstance(response, BaseMessage):
+            raise RuntimeError(f"response is not a BaseMessage: {response}")
+
+        text_content = response.content
+        if not isinstance(text_content, str):
+            raise RuntimeError(f"response is not a string: {text_content}")
+
+        return RunOutput(
+            output=text_content,
+            intermediate_outputs=intermediate_outputs,
+        )
+
+    def adapter_name(self) -> str:
+        return "kiln_langchain_adapter"
+
+    def _munge_response(self, response: Dict) -> Dict:
+        # Mistral Large tool calling format is a bit different. Convert to standard format.
+        if (
+            "name" in response
+            and response["name"] == "task_response"
+            and "arguments" in response
+        ):
+            return response["arguments"]
+        return response
+
+    def get_structured_output_options(
+        self, model_name: str, model_provider_name: str
+    ) -> Dict[str, Any]:
+        provider = self.model_provider()
+        if not provider:
+            return {}
+
+        options = {}
+        # We may need to add some provider specific logic here if providers use different names for the same mode, but everyone is copying openai for now
+        match provider.structured_output_mode:
+            case StructuredOutputMode.function_calling_weak:
+                # Langchain doesn't handle weak/strict separately
+                options["method"] = "function_calling"
+            case StructuredOutputMode.function_calling:
+                options["method"] = "function_calling"
+            case StructuredOutputMode.json_mode:
+                options["method"] = "json_mode"
+            case StructuredOutputMode.json_instruction_and_object:
+                # We also pass instructions
+                options["method"] = "json_mode"
+            case StructuredOutputMode.json_schema:
+                options["method"] = "json_schema"
+            case StructuredOutputMode.json_instructions:
+                # JSON done via instructions in prompt, not via API
+                pass
+            case StructuredOutputMode.default:
+                if provider.name == ModelProviderName.ollama:
+                    # Ollama has great json_schema support, so use that: https://ollama.com/blog/structured-outputs
+                    options["method"] = "json_schema"
+                else:
+                    # Let langchain decide the default
+                    pass
+            case _:
+                raise_exhaustive_enum_error(provider.structured_output_mode)
+
+        return options
+
+    async def langchain_model_from(self) -> BaseChatModel:
+        provider = self.model_provider()
+        return await langchain_model_from_provider(provider, self.run_config.model_name)
+
+
+async def langchain_model_from_provider(
+    provider: KilnModelProvider, model_name: str
+) -> BaseChatModel:
+    if provider.name == ModelProviderName.openai:
+        # We use the OpenAICompatibleAdapter for OpenAI
+        raise ValueError("OpenAI is not supported in Langchain adapter")
+    elif provider.name == ModelProviderName.openai_compatible:
+        # We use the OpenAICompatibleAdapter for OpenAI compatible
+        raise ValueError("OpenAI compatible is not supported in Langchain adapter")
+    elif provider.name == ModelProviderName.groq:
+        api_key = Config.shared().groq_api_key
+        if api_key is None:
+            raise ValueError(
+                "Attempted to use Groq without an API key set. "
+                "Get your API key from https://console.groq.com/keys"
+            )
+        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
+    elif provider.name == ModelProviderName.amazon_bedrock:
+        api_key = Config.shared().bedrock_access_key
+        secret_key = Config.shared().bedrock_secret_key
+        # langchain doesn't allow passing these, so ugly hack to set env vars
+        os.environ["AWS_ACCESS_KEY_ID"] = api_key
+        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
+        return ChatBedrockConverse(
+            **provider.provider_options,
+        )
+    elif provider.name == ModelProviderName.fireworks_ai:
+        api_key = Config.shared().fireworks_api_key
+        return ChatFireworks(**provider.provider_options, api_key=api_key)
+    elif provider.name == ModelProviderName.ollama:
+        # Ollama model naming is pretty flexible. We try a few versions of the model name
+        potential_model_names = []
+        if "model" in provider.provider_options:
+            potential_model_names.append(provider.provider_options["model"])
+        if "model_aliases" in provider.provider_options:
+            potential_model_names.extend(provider.provider_options["model_aliases"])
+
+        # Get the list of models Ollama supports
+        ollama_connection = await get_ollama_connection()
+        if ollama_connection is None:
+            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
+
+        for model_name in potential_model_names:
+            if ollama_model_installed(ollama_connection, model_name):
+                return ChatOllama(model=model_name, base_url=ollama_base_url())
+
+        raise ValueError(f"Model {model_name} not installed on Ollama")
+    elif provider.name == ModelProviderName.openrouter:
+        raise ValueError("OpenRouter is not supported in Langchain adapter")
+    else:
+        raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")
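A hedged usage sketch (not part of the diff) of the custom_model path shown in the constructor above; the Ollama model tag, the helper function, and the assumption that the task is non-structured are illustrative only.

# Hedged usage sketch of LangchainAdapter's custom_model path. The model tag
# and helper function are illustrative assumptions, not part of the package.
import kiln_ai.datamodel as datamodel
from langchain_ollama import ChatOllama

from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter


async def run_with_custom_model(task: datamodel.Task) -> str:
    custom = ChatOllama(model="llama3.1")
    # provider/model_name are inferred by the constructor when not supplied
    # ("custom.langchain:ChatOllama" / "custom.langchain:llama3.1").
    adapter = LangchainAdapter(kiln_task=task, custom_model=custom)
    task_run = await adapter.invoke("Hello from the 0.12.0 adapter")
    return task_run.output.output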