kiln-ai 0.8.0__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +199 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.0.dist-info/RECORD +0 -58
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/model_adapters/__init__.py

@@ -0,0 +1,18 @@
+"""
+# Model Adapters
+
+Model adapters are used to call AI models, like Ollama, OpenAI, etc.
+
+"""
+
+from . import (
+    base_adapter,
+    langchain_adapters,
+    openai_model_adapter,
+)
+
+__all__ = [
+    "base_adapter",
+    "langchain_adapters",
+    "openai_model_adapter",
+]
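Note: the adapters now live under a `model_adapters` subpackage, so downstream imports move with them. A minimal sketch of the import change, with the new path taken from the rename entries in the file list above:

# Before (0.8.0), adapters sat at the package root:
# from kiln_ai.adapters.base_adapter import BaseAdapter

# After (0.11.1), they live in the model_adapters subpackage:
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter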

kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py}

@@ -1,8 +1,13 @@
 import json
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, Literal, Tuple

+from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
+from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder, SimplePromptBuilder
+from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+from kiln_ai.adapters.run_output import RunOutput
 from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,

@@ -13,8 +18,6 @@ from kiln_ai.datamodel import (
 from kiln_ai.datamodel.json_schema import validate_schema
 from kiln_ai.utils.config import Config

-from .prompt_builders import BasePromptBuilder, SimplePromptBuilder
-

 @dataclass
 class AdapterInfo:

@@ -22,12 +25,10 @@
     model_name: str
     model_provider: str
     prompt_builder_name: str
+    prompt_id: str | None = None


-
-class RunOutput:
-    output: Dict | str
-    intermediate_outputs: Dict[str, str] | None
+COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."


 class BaseAdapter(metaclass=ABCMeta):

@@ -47,6 +48,8 @@ class BaseAdapter(metaclass=ABCMeta):
     def __init__(
         self,
         kiln_task: Task,
+        model_name: str,
+        model_provider_name: str,
         prompt_builder: BasePromptBuilder | None = None,
         tags: list[str] | None = None,
     ):

@@ -55,6 +58,26 @@
         self.output_schema = self.kiln_task.output_json_schema
         self.input_schema = self.kiln_task.input_json_schema
         self.default_tags = tags
+        self.model_name = model_name
+        self.model_provider_name = model_provider_name
+        self._model_provider: KilnModelProvider | None = None
+
+    def model_provider(self) -> KilnModelProvider:
+        """
+        Lazy load the model provider for this adapter.
+        """
+        if self._model_provider is not None:
+            return self._model_provider
+        if not self.model_name or not self.model_provider_name:
+            raise ValueError("model_name and model_provider_name must be provided")
+        self._model_provider = kiln_model_provider_from(
+            self.model_name, self.model_provider_name
+        )
+        if not self._model_provider:
+            raise ValueError(
+                f"model_provider_name {self.model_provider_name} not found for model {self.model_name}"
+            )
+        return self._model_provider

     async def invoke_returning_raw(
         self,
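Note: BaseAdapter now requires model_name and model_provider_name up front and resolves the KilnModelProvider lazily, caching it on first access. A usage sketch, assuming a hypothetical BaseAdapter subclass MyAdapter and illustrative model/provider ids:

# kiln_task: an existing kiln_ai Task; MyAdapter and the ids are examples only
adapter = MyAdapter(
    kiln_task,
    model_name="llama_3_1_8b",
    model_provider_name="ollama",
)
provider = adapter.model_provider()          # resolved via kiln_model_provider_from(...)
assert adapter.model_provider() is provider  # second call returns the cached instance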

@@ -81,21 +104,28 @@
         # Run
         run_output = await self._run(input)

+        # Parse
+        provider = self.model_provider()
+        parser = model_parser_from_id(provider.parser)(
+            structured_output=self.has_structured_output()
+        )
+        parsed_output = parser.parse_output(original_output=run_output)
+
         # validate output
         if self.output_schema is not None:
-            if not isinstance(
+            if not isinstance(parsed_output.output, dict):
                 raise RuntimeError(
-                    f"structured response is not a dict: {
+                    f"structured response is not a dict: {parsed_output.output}"
                 )
-            validate_schema(
+            validate_schema(parsed_output.output, self.output_schema)
         else:
-            if not isinstance(
+            if not isinstance(parsed_output.output, str):
                 raise RuntimeError(
-                    f"response is not a string for non-structured task: {
+                    f"response is not a string for non-structured task: {parsed_output.output}"
                 )

         # Generate the run and output
-        run = self.generate_run(input, input_source,
+        run = self.generate_run(input, input_source, parsed_output)

         # Save the run if configured to do so, and we have a path to save to
         if Config.shared().autosave_runs and self.kiln_task.path is not None:
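Note: invoke now routes raw model output through a parser selected by the provider's parser id before schema validation. The call sites imply model_parser_from_id returns a parser class whose instances take structured_output and expose parse_output. A minimal conforming sketch (the parser base class itself is not shown in this diff; PassthroughParser is hypothetical):

from kiln_ai.adapters.run_output import RunOutput

class PassthroughParser:
    # Sketch of the contract implied above; a real parser (e.g. the new
    # r1_parser from the file list) would strip thinking tags or coerce
    # JSON here. This sketch returns the output unchanged.
    def __init__(self, structured_output: bool = False):
        self.structured_output = structured_output

    def parse_output(self, original_output: RunOutput) -> RunOutput:
        return original_output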

@@ -118,16 +148,49 @@
         pass

     def build_prompt(self) -> str:
-
+        # The prompt builder needs to know if we want to inject formatting instructions
+        provider = self.model_provider()
+        add_json_instructions = self.has_structured_output() and (
+            provider.structured_output_mode == StructuredOutputMode.json_instructions
+            or provider.structured_output_mode
+            == StructuredOutputMode.json_instruction_and_object
+        )
+
+        return self.prompt_builder.build_prompt(
+            include_json_instructions=add_json_instructions
+        )
+
+    def run_strategy(
+        self,
+    ) -> Tuple[Literal["cot_as_message", "cot_two_call", "basic"], str | None]:
+        # Determine the run strategy for COT prompting. 3 options:
+        # 1. "Thinking" LLM designed to output thinking in a structured format plus a COT prompt: we make 1 call to the LLM, which outputs thinking in a structured format. We include the thinking instructions as a message.
+        # 2. Normal LLM with COT prompt: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call. It also separates the thinking from the final response.
+        # 3. Non chain of thought: we make 1 call to the LLM, with no COT prompt.
+        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
+        reasoning_capable = self.model_provider().reasoning_capable
+
+        if cot_prompt and reasoning_capable:
+            # 1: "Thinking" LLM designed to output thinking in a structured format
+            # A simple message with the COT prompt appended to the message list is sufficient
+            return "cot_as_message", cot_prompt
+        elif cot_prompt:
+            # 2: Unstructured output with COT
+            # Two calls to separate the thinking from the final response
+            return "cot_two_call", cot_prompt
+        else:
+            return "basic", None

     # create a run and task output
     def generate_run(
         self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
     ) -> TaskRun:
         # Convert input and output to JSON strings if they are dictionaries
-        input_str =
+        input_str = (
+            json.dumps(input, ensure_ascii=False) if isinstance(input, dict) else input
+        )
         output_str = (
-            json.dumps(run_output.output)
+            json.dumps(run_output.output, ensure_ascii=False)
             if isinstance(run_output.output, dict)
             else run_output.output
         )
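Note: run_strategy centralizes the chain-of-thought branching that concrete adapters consume (the langchain adapter's _run below dispatches on it). A small sketch of what the returned tuple implies for call counts; plan_calls is hypothetical, not part of the package:

from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter

def plan_calls(adapter: BaseAdapter) -> int:
    # "cot_two_call" means one thinking call plus one final-answer call;
    # "cot_as_message" and "basic" are single-call strategies.
    strategy, _cot_prompt = adapter.run_strategy()
    return 2 if strategy == "cot_two_call" else 1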

@@ -155,27 +218,6 @@
             tags=self.default_tags or [],
         )

-        exclude_fields = {
-            "id": True,
-            "created_at": True,
-            "updated_at": True,
-            "path": True,
-            "output": {"id": True, "created_at": True, "updated_at": True},
-        }
-        new_run_dump = new_task_run.model_dump(exclude=exclude_fields)
-
-        # Check if the same run already exists
-        existing_task_run = next(
-            (
-                task_run
-                for task_run in self.kiln_task.runs()
-                if task_run.model_dump(exclude=exclude_fields) == new_run_dump
-            ),
-            None,
-        )
-        if existing_task_run:
-            return existing_task_run
-
         return new_task_run

     def _properties_for_task_output(self) -> Dict[str, str | int | float]:

@@ -187,5 +229,7 @@
         props["model_name"] = adapter_info.model_name
         props["model_provider"] = adapter_info.model_provider
         props["prompt_builder_name"] = adapter_info.prompt_builder_name
+        if adapter_info.prompt_id is not None:
+            props["prompt_id"] = adapter_info.prompt_id

         return props
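Note: with prompt_id on AdapterInfo, saved runs can now record exactly which prompt produced them. A hypothetical example of the recorded source properties (keys from the code above; values illustrative only):

{
    "adapter_name": "kiln_langchain_adapter",
    "model_name": "llama_3_1_8b",
    "model_provider": "ollama",
    "prompt_builder_name": "simple_prompt_builder",
    "prompt_id": "id::1234",
}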

kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py}

@@ -1,5 +1,4 @@
 import os
-from os import getenv
 from typing import Any, Dict

 from langchain_aws import ChatBedrockConverse

@@ -11,20 +10,28 @@ from langchain_core.runnables import Runnable
 from langchain_fireworks import ChatFireworks
 from langchain_groq import ChatGroq
 from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
 from pydantic import BaseModel

 import kiln_ai.datamodel as datamodel
+from kiln_ai.adapters.ml_model_list import (
+    KilnModelProvider,
+    ModelProviderName,
+    StructuredOutputMode,
+)
+from kiln_ai.adapters.model_adapters.base_adapter import (
+    COT_FINAL_ANSWER_PROMPT,
+    AdapterInfo,
+    BaseAdapter,
+    BasePromptBuilder,
+    RunOutput,
+)
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
     ollama_base_url,
     ollama_model_installed,
 )
 from kiln_ai.utils.config import Config
-
-from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
-from .ml_model_list import KilnModelProvider, ModelProviderName
-from .provider_tools import kiln_model_provider_from
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

 LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]


@@ -41,39 +48,62 @@ class LangchainAdapter(BaseAdapter):
         prompt_builder: BasePromptBuilder | None = None,
         tags: list[str] | None = None,
     ):
-        super().__init__(kiln_task, prompt_builder=prompt_builder, tags=tags)
         if custom_model is not None:
             self._model = custom_model

             # Attempt to infer model provider and name from custom model
-
-
-
-
-
-
-                custom_model, "model_name"
-            )
-
-
-
-
+            if provider is None:
+                provider = "custom.langchain:" + custom_model.__class__.__name__
+
+            if model_name is None:
+                model_name = "custom.langchain:unknown_model"
+            if hasattr(custom_model, "model_name") and isinstance(
+                getattr(custom_model, "model_name"), str
+            ):
+                model_name = "custom.langchain:" + getattr(
+                    custom_model, "model_name"
+                )
+            if hasattr(custom_model, "model") and isinstance(
+                getattr(custom_model, "model"), str
+            ):
+                model_name = "custom.langchain:" + getattr(custom_model, "model")
         elif model_name is not None:
-
-
+            # default provider name if not provided
+            provider = provider or "custom.langchain.default_provider"
         else:
             raise ValueError(
                 "model_name and provider must be provided if custom_model is not provided"
             )

+        if model_name is None:
+            raise ValueError("model_name must be provided")
+
+        super().__init__(
+            kiln_task,
+            model_name=model_name,
+            model_provider_name=provider,
+            prompt_builder=prompt_builder,
+            tags=tags,
+        )
+
     async def model(self) -> LangChainModelType:
         # cached model
         if self._model:
             return self._model

-        self._model = await langchain_model_from(
+        self._model = await self.langchain_model_from()
+
+        # Decide if we want to use Langchain's structured output:
+        # 1. Only for structured tasks
+        # 2. Only if the provider's mode isn't json_instructions (only mode that doesn't use an API option for structured output capabilities)
+        provider = self.model_provider()
+        use_lc_structured_output = (
+            self.has_structured_output()
+            and provider.structured_output_mode
+            != StructuredOutputMode.json_instructions
+        )

-        if
+        if use_lc_structured_output:
             if not hasattr(self._model, "with_structured_output") or not callable(
                 getattr(self._model, "with_structured_output")
             ):

@@ -88,8 +118,8 @@ class LangchainAdapter(BaseAdapter):
                 )
             output_schema["title"] = "task_response"
             output_schema["description"] = "A response from the task"
-            with_structured_output_options =
-                self.model_name, self.
+            with_structured_output_options = self.get_structured_output_options(
+                self.model_name, self.model_provider_name
             )
             self._model = self._model.with_structured_output(
                 output_schema,

@@ -99,6 +129,7 @@ class LangchainAdapter(BaseAdapter):
         return self._model

     async def _run(self, input: Dict | str) -> RunOutput:
+        provider = self.model_provider()
         model = await self.model()
         chain = model
         intermediate_outputs = {}

@@ -110,58 +141,63 @@ class LangchainAdapter(BaseAdapter):
             HumanMessage(content=user_msg),
         ]

-
-
-        if
-
-
-
-
+        run_strategy, cot_prompt = self.run_strategy()
+
+        if run_strategy == "cot_as_message":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_as_message strategy")
+            messages.append(SystemMessage(content=cot_prompt))
+        elif run_strategy == "cot_two_call":
+            if not cot_prompt:
+                raise ValueError("cot_prompt is required for cot_two_call strategy")
             messages.append(
                 SystemMessage(content=cot_prompt),
             )

+            # Base model (without structured output) used for COT message
+            base_model = await self.langchain_model_from()
+
             cot_messages = [*messages]
             cot_response = await base_model.ainvoke(cot_messages)
             intermediate_outputs["chain_of_thought"] = cot_response.content
             messages.append(AIMessage(content=cot_response.content))
-            messages.append(
-                SystemMessage(content="Considering the above, return a final result.")
-            )
-        elif cot_prompt:
-            messages.append(SystemMessage(content=cot_prompt))
+            messages.append(HumanMessage(content=COT_FINAL_ANSWER_PROMPT))

         response = await chain.ainvoke(messages)

-        if
-
-
-
-
-
-
+        # Langchain may have already parsed the response into structured output, so use that if available.
+        # However, a plain string may still be fixed at the parsing layer, so not being structured isn't a critical failure (yet)
+        if (
+            self.has_structured_output()
+            and isinstance(response, dict)
+            and "parsed" in response
+            and isinstance(response["parsed"], dict)
+        ):
             structured_response = response["parsed"]
             return RunOutput(
                 output=self._munge_response(structured_response),
                 intermediate_outputs=intermediate_outputs,
             )
-
-
-
-
-
-
-
-
-
-
+
+        if not isinstance(response, BaseMessage):
+            raise RuntimeError(f"response is not a BaseMessage: {response}")
+
+        text_content = response.content
+        if not isinstance(text_content, str):
+            raise RuntimeError(f"response is not a string: {text_content}")
+
+        return RunOutput(
+            output=text_content,
+            intermediate_outputs=intermediate_outputs,
+        )

     def adapter_info(self) -> AdapterInfo:
         return AdapterInfo(
             model_name=self.model_name,
-            model_provider=self.
+            model_provider=self.model_provider_name,
             adapter_name="kiln_langchain_adapter",
             prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
+            prompt_id=self.prompt_builder.prompt_id(),
         )

     def _munge_response(self, response: Dict) -> Dict:

@@ -174,34 +210,54 @@ class LangchainAdapter(BaseAdapter):
             return response["arguments"]
         return response

+    def get_structured_output_options(
+        self, model_name: str, model_provider_name: str
+    ) -> Dict[str, Any]:
+        provider = self.model_provider()
+        if not provider:
+            return {}

-
-
-
-
-
-
-
-
-
+        options = {}
+        # We may need to add some provider specific logic here if providers use different names for the same mode, but everyone is copying openai for now
+        match provider.structured_output_mode:
+            case StructuredOutputMode.function_calling:
+                options["method"] = "function_calling"
+            case StructuredOutputMode.json_mode:
+                options["method"] = "json_mode"
+            case StructuredOutputMode.json_instruction_and_object:
+                # We also pass instructions
+                options["method"] = "json_mode"
+            case StructuredOutputMode.json_schema:
+                options["method"] = "json_schema"
+            case StructuredOutputMode.json_instructions:
+                # JSON done via instructions in prompt, not via API
+                pass
+            case StructuredOutputMode.default:
+                if provider.name == ModelProviderName.ollama:
+                    # Ollama has great json_schema support, so use that: https://ollama.com/blog/structured-outputs
+                    options["method"] = "json_schema"
+                else:
+                    # Let langchain decide the default
+                    pass
+            case _:
+                raise_exhaustive_enum_error(provider.structured_output_mode)

+        return options

-async def langchain_model_from(
-
-
-    provider = await kiln_model_provider_from(name, provider_name)
-    return await langchain_model_from_provider(provider, name)
+    async def langchain_model_from(self) -> BaseChatModel:
+        provider = self.model_provider()
+        return await langchain_model_from_provider(provider, self.model_name)


 async def langchain_model_from_provider(
     provider: KilnModelProvider, model_name: str
 ) -> BaseChatModel:
     if provider.name == ModelProviderName.openai:
-
-
+        # We use the OpenAICompatibleAdapter for OpenAI
+        raise ValueError("OpenAI is not supported in Langchain adapter")
     elif provider.name == ModelProviderName.openai_compatible:
-        #
-
+        # We use the OpenAICompatibleAdapter for OpenAI compatible
+        raise ValueError("OpenAI compatible is not supported in Langchain adapter")
     elif provider.name == ModelProviderName.groq:
         api_key = Config.shared().groq_api_key
         if api_key is None:
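Note: the case _ arm calls the new raise_exhaustive_enum_error helper (kiln_ai/utils/exhaustive_error.py, +6 lines in the file list; its body is not shown in this diff). A typical implementation of this NoReturn pattern, offered as an assumption of what those six lines contain:

from typing import NoReturn

def raise_exhaustive_enum_error(value: NoReturn) -> NoReturn:
    # Typing the parameter as NoReturn makes static checkers flag any
    # enum member that an earlier case arm failed to handle.
    raise ValueError(f"Unhandled enum value: {value}")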

@@ -241,16 +297,6 @@ async def langchain_model_from_provider(

         raise ValueError(f"Model {model_name} not installed on Ollama")
     elif provider.name == ModelProviderName.openrouter:
-
-        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
-        return ChatOpenAI(
-            **provider.provider_options,
-            openai_api_key=api_key,  # type: ignore[arg-type]
-            openai_api_base=base_url,  # type: ignore[arg-type]
-            default_headers={
-                "HTTP-Referer": "https://getkiln.ai/openrouter",
-                "X-Title": "KilnAI",
-            },
-        )
+        raise ValueError("OpenRouter is not supported in Langchain adapter")
     else:
         raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")

kiln_ai/adapters/model_adapters/openai_compatible_config.py

@@ -0,0 +1,11 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class OpenAICompatibleConfig:
+    api_key: str
+    model_name: str
+    provider_name: str
+    base_url: str | None = None  # Defaults to OpenAI
+    default_headers: dict[str, str] | None = None
+    openrouter_style_reasoning: bool = False