kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai has been flagged as possibly problematic.

Files changed (88)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +81 -10
  3. kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +267 -0
  7. kiln_ai/adapters/eval/g_eval.py +367 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +324 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +640 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +497 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  15. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  16. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  21. kiln_ai/adapters/ml_model_list.py +434 -93
  22. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  23. kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
  24. kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
  25. kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
  26. kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
  27. kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
  28. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
  29. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
  30. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
  31. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
  32. kiln_ai/adapters/ollama_tools.py +0 -1
  33. kiln_ai/adapters/parsers/__init__.py +10 -0
  34. kiln_ai/adapters/parsers/base_parser.py +12 -0
  35. kiln_ai/adapters/parsers/json_parser.py +37 -0
  36. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  37. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  38. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  39. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  40. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  41. kiln_ai/adapters/prompt_builders.py +193 -49
  42. kiln_ai/adapters/provider_tools.py +91 -36
  43. kiln_ai/adapters/repair/repair_task.py +18 -19
  44. kiln_ai/adapters/repair/test_repair_task.py +7 -7
  45. kiln_ai/adapters/run_output.py +11 -0
  46. kiln_ai/adapters/test_adapter_registry.py +177 -0
  47. kiln_ai/adapters/test_generate_docs.py +69 -0
  48. kiln_ai/adapters/test_ollama_tools.py +0 -1
  49. kiln_ai/adapters/test_prompt_adaptors.py +25 -18
  50. kiln_ai/adapters/test_prompt_builders.py +265 -44
  51. kiln_ai/adapters/test_provider_tools.py +268 -46
  52. kiln_ai/datamodel/__init__.py +51 -772
  53. kiln_ai/datamodel/basemodel.py +31 -11
  54. kiln_ai/datamodel/datamodel_enums.py +58 -0
  55. kiln_ai/datamodel/dataset_filters.py +114 -0
  56. kiln_ai/datamodel/dataset_split.py +170 -0
  57. kiln_ai/datamodel/eval.py +298 -0
  58. kiln_ai/datamodel/finetune.py +105 -0
  59. kiln_ai/datamodel/json_schema.py +14 -3
  60. kiln_ai/datamodel/model_cache.py +8 -3
  61. kiln_ai/datamodel/project.py +23 -0
  62. kiln_ai/datamodel/prompt.py +37 -0
  63. kiln_ai/datamodel/prompt_id.py +83 -0
  64. kiln_ai/datamodel/strict_mode.py +24 -0
  65. kiln_ai/datamodel/task.py +181 -0
  66. kiln_ai/datamodel/task_output.py +321 -0
  67. kiln_ai/datamodel/task_run.py +164 -0
  68. kiln_ai/datamodel/test_basemodel.py +80 -2
  69. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  70. kiln_ai/datamodel/test_dataset_split.py +127 -6
  71. kiln_ai/datamodel/test_datasource.py +3 -2
  72. kiln_ai/datamodel/test_eval_model.py +635 -0
  73. kiln_ai/datamodel/test_example_models.py +34 -17
  74. kiln_ai/datamodel/test_json_schema.py +23 -0
  75. kiln_ai/datamodel/test_model_cache.py +24 -0
  76. kiln_ai/datamodel/test_model_perf.py +125 -0
  77. kiln_ai/datamodel/test_models.py +131 -2
  78. kiln_ai/datamodel/test_prompt_id.py +129 -0
  79. kiln_ai/datamodel/test_task.py +159 -0
  80. kiln_ai/utils/config.py +6 -1
  81. kiln_ai/utils/exhaustive_error.py +6 -0
  82. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
  83. kiln_ai-0.12.0.dist-info/RECORD +100 -0
  84. kiln_ai/adapters/base_adapter.py +0 -191
  85. kiln_ai/adapters/langchain_adapters.py +0 -256
  86. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  87. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
  88. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/model_adapters/__init__.py
@@ -0,0 +1,18 @@
+ """
+ # Model Adapters
+
+ Model adapters are used to call AI models, like Ollama, OpenAI, etc.
+
+ """
+
+ from . import (
+     base_adapter,
+     langchain_adapters,
+     openai_model_adapter,
+ )
+
+ __all__ = [
+     "base_adapter",
+     "langchain_adapters",
+     "openai_model_adapter",
+ ]

kiln_ai/adapters/model_adapters/base_adapter.py
@@ -0,0 +1,250 @@
+ import json
+ from abc import ABCMeta, abstractmethod
+ from dataclasses import dataclass
+ from typing import Dict, Literal, Tuple
+
+ from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
+ from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
+ from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
+ from kiln_ai.adapters.provider_tools import kiln_model_provider_from
+ from kiln_ai.adapters.run_output import RunOutput
+ from kiln_ai.datamodel import (
+     DataSource,
+     DataSourceType,
+     Task,
+     TaskOutput,
+     TaskRun,
+ )
+ from kiln_ai.datamodel.json_schema import validate_schema
+ from kiln_ai.datamodel.task import RunConfig
+ from kiln_ai.utils.config import Config
+
+
+ @dataclass
+ class AdapterConfig:
+     """
+     An adapter config holds config options that do NOT impact the output of the model.
+
+     For example: whether the run is saved, or whether we request additional data like logprobs.
+     """
+
+     allow_saving: bool = True
+     top_logprobs: int | None = None
+     default_tags: list[str] | None = None
+
+
+ COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."
+
+
+ class BaseAdapter(metaclass=ABCMeta):
+     """Base class for AI model adapters that handle task execution.
+
+     This abstract class provides the foundation for implementing model-specific adapters
+     that can process tasks with structured or unstructured inputs/outputs. It handles
+     input/output validation, prompt building, and run tracking.
+
+     Attributes:
+         prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
+         kiln_task (Task): The task configuration and metadata
+         output_schema (dict | None): JSON schema for validating structured outputs
+         input_schema (dict | None): JSON schema for validating structured inputs
+     """
+
+     def __init__(
+         self,
+         run_config: RunConfig,
+         config: AdapterConfig | None = None,
+     ):
+         self.run_config = run_config
+         self.prompt_builder = prompt_builder_from_id(
+             run_config.prompt_id, run_config.task
+         )
+         self._model_provider: KilnModelProvider | None = None
+
+         self.output_schema = self.task().output_json_schema
+         self.input_schema = self.task().input_json_schema
+         self.base_adapter_config = config or AdapterConfig()
+
+     def task(self) -> Task:
+         return self.run_config.task
+
+     def model_provider(self) -> KilnModelProvider:
+         """
+         Lazy load the model provider for this adapter.
+         """
+         if self._model_provider is not None:
+             return self._model_provider
+         if not self.run_config.model_name or not self.run_config.model_provider_name:
+             raise ValueError("model_name and model_provider_name must be provided")
+         self._model_provider = kiln_model_provider_from(
+             self.run_config.model_name, self.run_config.model_provider_name
+         )
+         if not self._model_provider:
+             raise ValueError(
+                 f"model_provider_name {self.run_config.model_provider_name} not found for model {self.run_config.model_name}"
+             )
+         return self._model_provider
+
+     async def invoke_returning_raw(
+         self,
+         input: Dict | str,
+         input_source: DataSource | None = None,
+     ) -> Dict | str:
+         result = await self.invoke(input, input_source)
+         if self.task().output_json_schema is None:
+             return result.output.output
+         else:
+             return json.loads(result.output.output)
+
+     async def invoke(
+         self,
+         input: Dict | str,
+         input_source: DataSource | None = None,
+     ) -> TaskRun:
+         run_output, _ = await self.invoke_returning_run_output(input, input_source)
+         return run_output
+
+     async def invoke_returning_run_output(
+         self,
+         input: Dict | str,
+         input_source: DataSource | None = None,
+     ) -> Tuple[TaskRun, RunOutput]:
+         # validate input
+         if self.input_schema is not None:
+             if not isinstance(input, dict):
+                 raise ValueError(f"structured input is not a dict: {input}")
+             validate_schema(input, self.input_schema)
+
+         # Run
+         run_output = await self._run(input)
+
+         # Parse
+         provider = self.model_provider()
+         parser = model_parser_from_id(provider.parser)(
+             structured_output=self.has_structured_output()
+         )
+         parsed_output = parser.parse_output(original_output=run_output)
+
+         # validate output
+         if self.output_schema is not None:
+             if not isinstance(parsed_output.output, dict):
+                 raise RuntimeError(
+                     f"structured response is not a dict: {parsed_output.output}"
+                 )
+             validate_schema(parsed_output.output, self.output_schema)
+         else:
+             if not isinstance(parsed_output.output, str):
+                 raise RuntimeError(
+                     f"response is not a string for non-structured task: {parsed_output.output}"
+                 )
+
+         # Generate the run and output
+         run = self.generate_run(input, input_source, parsed_output)
+
+         # Save the run if configured to do so, and we have a path to save to
+         if (
+             self.base_adapter_config.allow_saving
+             and Config.shared().autosave_runs
+             and self.task().path is not None
+         ):
+             run.save_to_file()
+         else:
+             # Clear the ID to indicate it's not persisted
+             run.id = None
+
+         return run, run_output
+
+     def has_structured_output(self) -> bool:
+         return self.output_schema is not None
+
+     @abstractmethod
+     def adapter_name(self) -> str:
+         pass
+
+     @abstractmethod
+     async def _run(self, input: Dict | str) -> RunOutput:
+         pass
+
+     def build_prompt(self) -> str:
+         # The prompt builder needs to know if we want to inject formatting instructions
+         provider = self.model_provider()
+         add_json_instructions = self.has_structured_output() and (
+             provider.structured_output_mode == StructuredOutputMode.json_instructions
+             or provider.structured_output_mode
+             == StructuredOutputMode.json_instruction_and_object
+         )
+
+         return self.prompt_builder.build_prompt(
+             include_json_instructions=add_json_instructions
+         )
+
+     def run_strategy(
+         self,
+     ) -> Tuple[Literal["cot_as_message", "cot_two_call", "basic"], str | None]:
+         # Determine the run strategy for COT prompting. 3 options:
+         # 1. "Thinking" LLM designed to output thinking in a structured format plus a COT prompt: we make 1 call to the LLM, which outputs thinking in a structured format. We include the thinking instructions as a message.
+         # 2. Normal LLM with COT prompt: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call. It also separates the thinking from the final response.
+         # 3. Non chain of thought: we make 1 call to the LLM, with no COT prompt.
+         cot_prompt = self.prompt_builder.chain_of_thought_prompt()
+         reasoning_capable = self.model_provider().reasoning_capable
+
+         if cot_prompt and reasoning_capable:
+             # 1: "Thinking" LLM designed to output thinking in a structured format
+             # A simple message with the COT prompt appended to the message list is sufficient
+             return "cot_as_message", cot_prompt
+         elif cot_prompt:
+             # 2: Unstructured output with COT
+             # Two calls to separate the thinking from the final response
+             return "cot_two_call", cot_prompt
+         else:
+             return "basic", None
+
+     # create a run and task output
+     def generate_run(
+         self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
+     ) -> TaskRun:
+         # Convert input and output to JSON strings if they are dictionaries
+         input_str = (
+             json.dumps(input, ensure_ascii=False) if isinstance(input, dict) else input
+         )
+         output_str = (
+             json.dumps(run_output.output, ensure_ascii=False)
+             if isinstance(run_output.output, dict)
+             else run_output.output
+         )
+
+         # If no input source is provided, use the human data source
+         if input_source is None:
+             input_source = DataSource(
+                 type=DataSourceType.human,
+                 properties={"created_by": Config.shared().user_id},
+             )
+
+         new_task_run = TaskRun(
+             parent=self.task(),
+             input=input_str,
+             input_source=input_source,
+             output=TaskOutput(
+                 output=output_str,
+                 # Synthetic since an adapter, not a human, is creating this
+                 source=DataSource(
+                     type=DataSourceType.synthetic,
+                     properties=self._properties_for_task_output(),
+                 ),
+             ),
+             intermediate_outputs=run_output.intermediate_outputs,
+             tags=self.base_adapter_config.default_tags or [],
+         )
+
+         return new_task_run
+
+     def _properties_for_task_output(self) -> Dict[str, str | int | float]:
+         props = {}
+
+         # adapter info
+         props["adapter_name"] = self.adapter_name()
+         props["model_name"] = self.run_config.model_name
+         props["model_provider"] = self.run_config.model_provider_name
+         props["prompt_id"] = self.run_config.prompt_id
+
+         return props
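
BaseAdapter leaves exactly two members abstract, adapter_name() and _run(); everything else (schema validation, parsing, run persistence) comes from invoke_returning_run_output(). A hedged sketch of a subclass, assuming a RunConfig built elsewhere; EchoAdapter is hypothetical, not part of the package:

    from typing import Dict

    from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
    from kiln_ai.adapters.run_output import RunOutput


    class EchoAdapter(BaseAdapter):
        """Hypothetical adapter that echoes its input, for illustration only."""

        def adapter_name(self) -> str:
            return "echo_adapter"

        async def _run(self, input: Dict | str) -> RunOutput:
            # A real adapter calls a model here. The base class then parses this
            # RunOutput, validates it against the task's JSON schema, and saves
            # a TaskRun when allow_saving and autosave_runs permit.
            return RunOutput(output=str(input), intermediate_outputs={})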

kiln_ai/adapters/model_adapters/langchain_adapters.py
@@ -0,0 +1,309 @@
+ import os
+ from typing import Any, Dict
+
+ from langchain_aws import ChatBedrockConverse
+ from langchain_core.language_models import LanguageModelInput
+ from langchain_core.language_models.chat_models import BaseChatModel
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+ from langchain_core.messages.base import BaseMessage
+ from langchain_core.runnables import Runnable
+ from langchain_fireworks import ChatFireworks
+ from langchain_groq import ChatGroq
+ from langchain_ollama import ChatOllama
+ from pydantic import BaseModel
+
+ import kiln_ai.datamodel as datamodel
+ from kiln_ai.adapters.ml_model_list import (
+     KilnModelProvider,
+     ModelProviderName,
+     StructuredOutputMode,
+ )
+ from kiln_ai.adapters.model_adapters.base_adapter import (
+     COT_FINAL_ANSWER_PROMPT,
+     AdapterConfig,
+     BaseAdapter,
+     RunOutput,
+ )
+ from kiln_ai.adapters.ollama_tools import (
+     get_ollama_connection,
+     ollama_base_url,
+     ollama_model_installed,
+ )
+ from kiln_ai.datamodel import PromptId
+ from kiln_ai.datamodel.task import RunConfig
+ from kiln_ai.utils.config import Config
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+ LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
+
+
+ class LangchainAdapter(BaseAdapter):
+     _model: LangChainModelType | None = None
+
+     def __init__(
+         self,
+         kiln_task: datamodel.Task,
+         custom_model: BaseChatModel | None = None,
+         model_name: str | None = None,
+         provider: str | None = None,
+         prompt_id: PromptId | None = None,
+         base_adapter_config: AdapterConfig | None = None,
+     ):
+         if custom_model is not None:
+             self._model = custom_model
+
+             # Attempt to infer model provider and name from custom model
+             if provider is None:
+                 provider = "custom.langchain:" + custom_model.__class__.__name__
+             if model_name is None:
+                 model_name = "custom.langchain:unknown_model"
+                 if hasattr(custom_model, "model_name") and isinstance(
+                     getattr(custom_model, "model_name"), str
+                 ):
+                     model_name = "custom.langchain:" + getattr(
+                         custom_model, "model_name"
+                     )
+                 if hasattr(custom_model, "model") and isinstance(
+                     getattr(custom_model, "model"), str
+                 ):
+                     model_name = "custom.langchain:" + getattr(custom_model, "model")
+         elif model_name is not None:
+             # default provider name if not provided
+             provider = provider or "custom.langchain.default_provider"
+         else:
+             raise ValueError(
+                 "model_name and provider must be provided if custom_model is not provided"
+             )
+
+         if model_name is None:
+             raise ValueError("model_name must be provided")
+
+         run_config = RunConfig(
+             task=kiln_task,
+             model_name=model_name,
+             model_provider_name=provider,
+             prompt_id=prompt_id or datamodel.PromptGenerators.SIMPLE,
+         )
+
+         super().__init__(
+             run_config=run_config,
+             config=base_adapter_config,
+         )
+
+     async def model(self) -> LangChainModelType:
+         # cached model
+         if self._model:
+             return self._model
+
+         self._model = await self.langchain_model_from()
+
+         # Decide if we want to use Langchain's structured output:
+         # 1. Only for structured tasks
+         # 2. Only if the provider's mode isn't json_instructions (the only mode that doesn't use an API option for structured output capabilities)
+         provider = self.model_provider()
+         use_lc_structured_output = (
+             self.has_structured_output()
+             and provider.structured_output_mode
+             != StructuredOutputMode.json_instructions
+         )
+
+         if use_lc_structured_output:
+             if not hasattr(self._model, "with_structured_output") or not callable(
+                 getattr(self._model, "with_structured_output")
+             ):
+                 raise ValueError(
+                     f"model {self._model} does not support structured output, cannot use output_json_schema"
+                 )
+             # Langchain expects title/description to be at top level, on top of json schema
+             output_schema = self.task().output_schema()
+             if output_schema is None:
+                 raise ValueError(
+                     f"output_json_schema is not valid json: {self.task().output_json_schema}"
+                 )
+             output_schema["title"] = "task_response"
+             output_schema["description"] = "A response from the task"
+             with_structured_output_options = self.get_structured_output_options(
+                 self.run_config.model_name, self.run_config.model_provider_name
+             )
+             self._model = self._model.with_structured_output(
+                 output_schema,
+                 include_raw=True,
+                 **with_structured_output_options,
+             )
+         return self._model
+
+     async def _run(self, input: Dict | str) -> RunOutput:
+         if self.base_adapter_config.top_logprobs is not None:
+             raise ValueError(
+                 "Kiln's Langchain adapter does not support logprobs/top_logprobs. Select a model from an OpenAI compatible provider (openai, openrouter, etc) instead."
+             )
+
+         provider = self.model_provider()
+         model = await self.model()
+         chain = model
+         intermediate_outputs = {}
+
+         prompt = self.build_prompt()
+         user_msg = self.prompt_builder.build_user_message(input)
+         messages = [
+             SystemMessage(content=prompt),
+             HumanMessage(content=user_msg),
+         ]
+
+         run_strategy, cot_prompt = self.run_strategy()
+
+         if run_strategy == "cot_as_message":
+             if not cot_prompt:
+                 raise ValueError("cot_prompt is required for cot_as_message strategy")
+             messages.append(SystemMessage(content=cot_prompt))
+         elif run_strategy == "cot_two_call":
+             if not cot_prompt:
+                 raise ValueError("cot_prompt is required for cot_two_call strategy")
+             messages.append(
+                 SystemMessage(content=cot_prompt),
+             )
+
+             # Base model (without structured output) used for COT message
+             base_model = await self.langchain_model_from()
+
+             cot_messages = [*messages]
+             cot_response = await base_model.ainvoke(cot_messages)
+             intermediate_outputs["chain_of_thought"] = cot_response.content
+             messages.append(AIMessage(content=cot_response.content))
+             messages.append(HumanMessage(content=COT_FINAL_ANSWER_PROMPT))
+
+         response = await chain.ainvoke(messages)
+
+         # Langchain may have already parsed the response into structured output, so use that if available.
+         # However, a plain string may still be fixed at the parsing layer, so not being structured isn't a critical failure (yet)
+         if (
+             self.has_structured_output()
+             and isinstance(response, dict)
+             and "parsed" in response
+             and isinstance(response["parsed"], dict)
+         ):
+             structured_response = response["parsed"]
+             return RunOutput(
+                 output=self._munge_response(structured_response),
+                 intermediate_outputs=intermediate_outputs,
+             )
+
+         if not isinstance(response, BaseMessage):
+             raise RuntimeError(f"response is not a BaseMessage: {response}")
+
+         text_content = response.content
+         if not isinstance(text_content, str):
+             raise RuntimeError(f"response is not a string: {text_content}")
+
+         return RunOutput(
+             output=text_content,
+             intermediate_outputs=intermediate_outputs,
+         )
+
+     def adapter_name(self) -> str:
+         return "kiln_langchain_adapter"
+
+     def _munge_response(self, response: Dict) -> Dict:
+         # Mistral Large tool calling format is a bit different. Convert to standard format.
+         if (
+             "name" in response
+             and response["name"] == "task_response"
+             and "arguments" in response
+         ):
+             return response["arguments"]
+         return response
+
+     def get_structured_output_options(
+         self, model_name: str, model_provider_name: str
+     ) -> Dict[str, Any]:
+         provider = self.model_provider()
+         if not provider:
+             return {}
+
+         options = {}
+         # We may need to add some provider specific logic here if providers use different names for the same mode, but everyone is copying openai for now
+         match provider.structured_output_mode:
+             case StructuredOutputMode.function_calling_weak:
+                 # Langchain doesn't handle weak/strict separately
+                 options["method"] = "function_calling"
+             case StructuredOutputMode.function_calling:
+                 options["method"] = "function_calling"
+             case StructuredOutputMode.json_mode:
+                 options["method"] = "json_mode"
+             case StructuredOutputMode.json_instruction_and_object:
+                 # We also pass instructions
+                 options["method"] = "json_mode"
+             case StructuredOutputMode.json_schema:
+                 options["method"] = "json_schema"
+             case StructuredOutputMode.json_instructions:
+                 # JSON done via instructions in prompt, not via API
+                 pass
+             case StructuredOutputMode.default:
+                 if provider.name == ModelProviderName.ollama:
+                     # Ollama has great json_schema support, so use that: https://ollama.com/blog/structured-outputs
+                     options["method"] = "json_schema"
+                 else:
+                     # Let langchain decide the default
+                     pass
+             case _:
+                 raise_exhaustive_enum_error(provider.structured_output_mode)
+
+         return options
+
+     async def langchain_model_from(self) -> BaseChatModel:
+         provider = self.model_provider()
+         return await langchain_model_from_provider(provider, self.run_config.model_name)
+
+
+ async def langchain_model_from_provider(
+     provider: KilnModelProvider, model_name: str
+ ) -> BaseChatModel:
+     if provider.name == ModelProviderName.openai:
+         # We use the OpenAICompatibleAdapter for OpenAI
+         raise ValueError("OpenAI is not supported in Langchain adapter")
+     elif provider.name == ModelProviderName.openai_compatible:
+         # We use the OpenAICompatibleAdapter for OpenAI compatible
+         raise ValueError("OpenAI compatible is not supported in Langchain adapter")
+     elif provider.name == ModelProviderName.groq:
+         api_key = Config.shared().groq_api_key
+         if api_key is None:
+             raise ValueError(
+                 "Attempted to use Groq without an API key set. "
+                 "Get your API key from https://console.groq.com/keys"
+             )
+         return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
+     elif provider.name == ModelProviderName.amazon_bedrock:
+         api_key = Config.shared().bedrock_access_key
+         secret_key = Config.shared().bedrock_secret_key
+         # langchain doesn't allow passing these, so ugly hack to set env vars
+         os.environ["AWS_ACCESS_KEY_ID"] = api_key
+         os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
+         return ChatBedrockConverse(
+             **provider.provider_options,
+         )
+     elif provider.name == ModelProviderName.fireworks_ai:
+         api_key = Config.shared().fireworks_api_key
+         return ChatFireworks(**provider.provider_options, api_key=api_key)
+     elif provider.name == ModelProviderName.ollama:
+         # Ollama model naming is pretty flexible. We try a few versions of the model name
+         potential_model_names = []
+         if "model" in provider.provider_options:
+             potential_model_names.append(provider.provider_options["model"])
+         if "model_aliases" in provider.provider_options:
+             potential_model_names.extend(provider.provider_options["model_aliases"])
+
+         # Get the list of models Ollama supports
+         ollama_connection = await get_ollama_connection()
+         if ollama_connection is None:
+             raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
+
+         for model_name in potential_model_names:
+             if ollama_model_installed(ollama_connection, model_name):
+                 return ChatOllama(model=model_name, base_url=ollama_base_url())
+
+         raise ValueError(f"Model {model_name} not installed on Ollama")
+     elif provider.name == ModelProviderName.openrouter:
+         raise ValueError("OpenRouter is not supported in Langchain adapter")
+     else:
+         raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")

kiln_ai/adapters/model_adapters/openai_compatible_config.py
@@ -0,0 +1,10 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class OpenAICompatibleConfig:
+     api_key: str
+     model_name: str
+     provider_name: str
+     base_url: str | None = None  # Defaults to OpenAI
+     default_headers: dict[str, str] | None = None
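
OpenAICompatibleConfig is a plain dataclass, presumably consumed by the new openai_model_adapter, so pointing an OpenAI-compatible adapter at a self-hosted endpoint is just field assignment. A sketch with made-up endpoint values:

    from kiln_ai.adapters.model_adapters.openai_compatible_config import (
        OpenAICompatibleConfig,
    )

    # Endpoint, key, and header values are invented for illustration.
    config = OpenAICompatibleConfig(
        api_key="sk-example",
        model_name="my-local-model",
        provider_name="my_gateway",
        base_url="http://localhost:8000/v1",  # None would default to OpenAI
        default_headers={"X-Title": "Kiln"},
    )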