kiln-ai 0.0.4__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/base_adapter.py +168 -0
- kiln_ai/adapters/langchain_adapters.py +113 -0
- kiln_ai/adapters/ml_model_list.py +436 -0
- kiln_ai/adapters/prompt_builders.py +122 -0
- kiln_ai/adapters/repair/repair_task.py +71 -0
- kiln_ai/adapters/repair/test_repair_task.py +248 -0
- kiln_ai/adapters/test_langchain_adapter.py +50 -0
- kiln_ai/adapters/test_ml_model_list.py +99 -0
- kiln_ai/adapters/test_prompt_adaptors.py +167 -0
- kiln_ai/adapters/test_prompt_builders.py +315 -0
- kiln_ai/adapters/test_saving_adapter_results.py +168 -0
- kiln_ai/adapters/test_structured_output.py +218 -0
- kiln_ai/datamodel/__init__.py +362 -2
- kiln_ai/datamodel/basemodel.py +372 -0
- kiln_ai/datamodel/json_schema.py +45 -0
- kiln_ai/datamodel/test_basemodel.py +277 -0
- kiln_ai/datamodel/test_datasource.py +107 -0
- kiln_ai/datamodel/test_example_models.py +644 -0
- kiln_ai/datamodel/test_json_schema.py +124 -0
- kiln_ai/datamodel/test_models.py +190 -0
- kiln_ai/datamodel/test_nested_save.py +205 -0
- kiln_ai/datamodel/test_output_rating.py +88 -0
- kiln_ai/utils/config.py +170 -0
- kiln_ai/utils/formatting.py +5 -0
- kiln_ai/utils/test_config.py +245 -0
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.1.dist-info}/METADATA +22 -1
- kiln_ai-0.5.1.dist-info/RECORD +29 -0
- kiln_ai/__init.__.py +0 -3
- kiln_ai/coreadd.py +0 -3
- kiln_ai/datamodel/project.py +0 -15
- kiln_ai-0.0.4.dist-info/RECORD +0 -8
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.1.dist-info}/LICENSE.txt +0 -0
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.1.dist-info}/WHEEL +0 -0
kiln_ai/adapters/ml_model_list.py
@@ -0,0 +1,436 @@
+import os
+from dataclasses import dataclass
+from enum import Enum
+from os import getenv
+from typing import Dict, List, NoReturn
+
+import httpx
+from langchain_aws import ChatBedrockConverse
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI
+from pydantic import BaseModel
+
+from ..utils.config import Config
+
+
+class ModelProviderName(str, Enum):
+    openai = "openai"
+    groq = "groq"
+    amazon_bedrock = "amazon_bedrock"
+    ollama = "ollama"
+    openrouter = "openrouter"
+
+
+class ModelFamily(str, Enum):
+    gpt = "gpt"
+    llama = "llama"
+    phi = "phi"
+    mistral = "mistral"
+    gemma = "gemma"
+
+
+# Where models have instruct and raw versions, instruct is default and raw is specified
+class ModelName(str, Enum):
+    llama_3_1_8b = "llama_3_1_8b"
+    llama_3_1_70b = "llama_3_1_70b"
+    llama_3_1_405b = "llama_3_1_405b"
+    gpt_4o_mini = "gpt_4o_mini"
+    gpt_4o = "gpt_4o"
+    phi_3_5 = "phi_3_5"
+    mistral_large = "mistral_large"
+    mistral_nemo = "mistral_nemo"
+    gemma_2_2b = "gemma_2_2b"
+    gemma_2_9b = "gemma_2_9b"
+    gemma_2_27b = "gemma_2_27b"
+
+
+class KilnModelProvider(BaseModel):
+    name: ModelProviderName
+    # Allow overriding the model level setting
+    supports_structured_output: bool = True
+    provider_options: Dict = {}
+
+
+class KilnModel(BaseModel):
+    family: str
+    name: str
+    friendly_name: str
+    providers: List[KilnModelProvider]
+    supports_structured_output: bool = True
+
+
+built_in_models: List[KilnModel] = [
+    # GPT 4o Mini
+    KilnModel(
+        family=ModelFamily.gpt,
+        name=ModelName.gpt_4o_mini,
+        friendly_name="GPT 4o Mini",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o-mini"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "openai/gpt-4o-mini"},
+            ),
+        ],
+    ),
+    # GPT 4o
+    KilnModel(
+        family=ModelFamily.gpt,
+        name=ModelName.gpt_4o,
+        friendly_name="GPT 4o",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "openai/gpt-4o-2024-08-06"},
+            ),
+        ],
+    ),
+    # Llama 3.1-8b
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_1_8b,
+        friendly_name="Llama 3.1 8B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-8b-instant"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                provider_options={
+                    "model": "meta.llama3-1-8b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.1"},  # 8b is default
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
+            ),
+        ],
+    ),
+    # Llama 3.1 70b
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_1_70b,
+        friendly_name="Llama 3.1 70B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-70b-versatile"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                # TODO: this should work but a bug in the bedrock response schema
+                supports_structured_output=False,
+                provider_options={
+                    "model": "meta.llama3-1-70b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
+            ),
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     provider=ModelProviders.ollama,
+            #     provider_options={"model": "llama3.1:70b"},
+            # ),
+        ],
+    ),
+    # Llama 3.1 405b
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_1_405b,
+        friendly_name="Llama 3.1 405B",
+        providers=[
+            # TODO: bring back when groq does: https://console.groq.com/docs/models
+            # KilnModelProvider(
+            #     name=ModelProviderName.groq,
+            #     provider_options={"model": "llama-3.1-405b-instruct-v1:0"},
+            # ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                provider_options={
+                    "model": "meta.llama3-1-405b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={"model": "llama3.1:405b"},
+            # ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
+            ),
+        ],
+    ),
+    # Mistral Nemo
+    KilnModel(
+        family=ModelFamily.mistral,
+        name=ModelName.mistral_nemo,
+        friendly_name="Mistral Nemo",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "mistralai/mistral-nemo"},
+            ),
+        ],
+    ),
+    # Mistral Large
+    KilnModel(
+        family=ModelFamily.mistral,
+        name=ModelName.mistral_large,
+        friendly_name="Mistral Large",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                provider_options={
+                    "model": "mistral.mistral-large-2407-v1:0",
+                    "region_name": "us-west-2",  # only in west-2
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "mistralai/mistral-large"},
+            ),
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     provider=ModelProviders.ollama,
+            #     provider_options={"model": "mistral-large"},
+            # ),
+        ],
+    ),
+    # Phi 3.5
+    KilnModel(
+        family=ModelFamily.phi,
+        name=ModelName.phi_3_5,
+        friendly_name="Phi 3.5",
+        supports_structured_output=False,
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "phi3.5"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
+            ),
+        ],
+    ),
+    # Gemma 2 2.6b
+    KilnModel(
+        family=ModelFamily.gemma,
+        name=ModelName.gemma_2_2b,
+        friendly_name="Gemma 2 2B",
+        supports_structured_output=False,
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={
+                    "model": "gemma2:2b",
+                },
+            ),
+        ],
+    ),
+    # Gemma 2 9b
+    KilnModel(
+        family=ModelFamily.gemma,
+        name=ModelName.gemma_2_9b,
+        friendly_name="Gemma 2 9B",
+        supports_structured_output=False,
+        providers=[
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={
+            #         "model": "gemma2:9b",
+            #     },
+            # ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "google/gemma-2-9b-it"},
+            ),
+        ],
+    ),
+    # Gemma 2 27b
+    KilnModel(
+        family=ModelFamily.gemma,
+        name=ModelName.gemma_2_27b,
+        friendly_name="Gemma 2 27B",
+        supports_structured_output=False,
+        providers=[
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={
+            #         "model": "gemma2:27b",
+            #     },
+            # ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "google/gemma-2-27b-it"},
+            ),
+        ],
+    ),
+]
+
+
+def provider_name_from_id(id: str) -> str:
+    if id in ModelProviderName.__members__:
+        enum_id = ModelProviderName(id)
+        match enum_id:
+            case ModelProviderName.amazon_bedrock:
+                return "Amazon Bedrock"
+            case ModelProviderName.openrouter:
+                return "OpenRouter"
+            case ModelProviderName.groq:
+                return "Groq"
+            case ModelProviderName.ollama:
+                return "Ollama"
+            case ModelProviderName.openai:
+                return "OpenAI"
+            case _:
+                # triggers pyright warning if I miss a case
+                raise_exhaustive_error(enum_id)
+
+    return "Unknown provider: " + id
+
+
+def raise_exhaustive_error(value: NoReturn) -> NoReturn:
+    raise ValueError(f"Unhandled enum value: {value}")
+
+
+@dataclass
+class ModelProviderWarning:
+    required_config_keys: List[str]
+    message: str
+
+
+provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
+    ModelProviderName.amazon_bedrock: ModelProviderWarning(
+        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
+        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
+    ),
+    ModelProviderName.openrouter: ModelProviderWarning(
+        required_config_keys=["open_router_api_key"],
+        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+    ),
+    ModelProviderName.groq: ModelProviderWarning(
+        required_config_keys=["groq_api_key"],
+        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
+    ),
+    ModelProviderName.openai: ModelProviderWarning(
+        required_config_keys=["open_ai_api_key"],
+        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
+    ),
+}
+
+
+def get_config_value(key: str):
+    try:
+        return Config.shared().__getattr__(key)
+    except AttributeError:
+        return None
+
+
+def check_provider_warnings(provider_name: ModelProviderName):
+    warning_check = provider_warnings.get(provider_name)
+    if warning_check is None:
+        return
+    for key in warning_check.required_config_keys:
+        if get_config_value(key) is None:
+            raise ValueError(warning_check.message)
+
+
+def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
+    if name not in ModelName.__members__:
+        raise ValueError(f"Invalid name: {name}")
+
+    # Select the model from built_in_models using the name
+    model = next(filter(lambda m: m.name == name, built_in_models))
+    if model is None:
+        raise ValueError(f"Model {name} not found")
+
+    # If a provider is provided, select the provider from the model's provider_config
+    provider: KilnModelProvider | None = None
+    if model.providers is None or len(model.providers) == 0:
+        raise ValueError(f"Model {name} has no providers")
+    elif provider_name is None:
+        # TODO: priority order
+        provider = model.providers[0]
+    else:
+        provider = next(
+            filter(lambda p: p.name == provider_name, model.providers), None
+        )
+        if provider is None:
+            raise ValueError(f"Provider {provider_name} not found for model {name}")
+
+    check_provider_warnings(provider.name)
+
+    if provider.name == ModelProviderName.openai:
+        api_key = Config.shared().open_ai_api_key
+        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
+    elif provider.name == ModelProviderName.groq:
+        api_key = Config.shared().groq_api_key
+        if api_key is None:
+            raise ValueError(
+                "Attempted to use Groq without an API key set. "
+                "Get your API key from https://console.groq.com/keys"
+            )
+        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
+    elif provider.name == ModelProviderName.amazon_bedrock:
+        api_key = Config.shared().bedrock_access_key
+        secret_key = Config.shared().bedrock_secret_key
+        # langchain doesn't allow passing these, so ugly hack to set env vars
+        os.environ["AWS_ACCESS_KEY_ID"] = api_key
+        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
+        return ChatBedrockConverse(
+            **provider.provider_options,
+        )
+    elif provider.name == ModelProviderName.ollama:
+        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
+    elif provider.name == ModelProviderName.openrouter:
+        api_key = Config.shared().open_router_api_key
+        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
+        return ChatOpenAI(
+            **provider.provider_options,
+            openai_api_key=api_key,  # type: ignore[arg-type]
+            openai_api_base=base_url,  # type: ignore[arg-type]
+            default_headers={
+                "HTTP-Referer": "https://kiln-ai.com/openrouter",
+                "X-Title": "KilnAI",
+            },
+        )
+    else:
+        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
+
+
+def ollama_base_url():
+    env_base_url = os.getenv("OLLAMA_BASE_URL")
+    if env_base_url is not None:
+        return env_base_url
+    return "http://localhost:11434"
+
+
+async def ollama_online():
+    try:
+        httpx.get(ollama_base_url() + "/api/tags")
+    except httpx.RequestError:
+        return False
+    return True
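
For orientation, a minimal usage sketch (not part of the release diff): listing the built-in models defined above and resolving one of them to a LangChain chat model. It assumes this version of kiln-ai is installed and, for the Ollama branch, a local Ollama server at the default base URL; the model and provider identifiers are the ones registered in the file above.

from kiln_ai.adapters.ml_model_list import (
    built_in_models,
    langchain_model_from,
    provider_name_from_id,
)

# Print each registered model with its providers (data defined in the file above).
for model in built_in_models:
    providers = ", ".join(provider_name_from_id(p.name.value) for p in model.providers)
    print(f"{model.friendly_name}: {providers}")

# Resolve a model/provider pair to a LangChain chat model.
# The Ollama provider needs no API key, only a reachable local server.
chat_model = langchain_model_from("llama_3_1_8b", provider_name="ollama")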
kiln_ai/adapters/prompt_builders.py
@@ -0,0 +1,122 @@
+import json
+from abc import ABCMeta, abstractmethod
+from typing import Dict
+
+from kiln_ai.datamodel import Task
+from kiln_ai.utils.formatting import snake_case
+
+
+class BasePromptBuilder(metaclass=ABCMeta):
+    def __init__(self, task: Task):
+        self.task = task
+
+    @abstractmethod
+    def build_prompt(self) -> str:
+        pass
+
+    # override to change the name of the prompt builder (if changing class names)
+    @classmethod
+    def prompt_builder_name(cls) -> str:
+        return snake_case(cls.__name__)
+
+    # Can be overridden to add more information to the user message
+    def build_user_message(self, input: Dict | str) -> str:
+        if isinstance(input, Dict):
+            return f"The input is:\n{json.dumps(input, indent=2)}"
+
+        return f"The input is:\n{input}"
+
+
+class SimplePromptBuilder(BasePromptBuilder):
+    def build_prompt(self) -> str:
+        base_prompt = self.task.instruction
+
+        # TODO: this is just a quick version. Formatting and best practices TBD
+        if len(self.task.requirements) > 0:
+            base_prompt += (
+                "\n\nYour response should respect the following requirements:\n"
+            )
+            # iterate requirements, formatting them in numbered list like 1) task.instruction\n2)...
+            for i, requirement in enumerate(self.task.requirements):
+                base_prompt += f"{i+1}) {requirement.instruction}\n"
+
+        return base_prompt
+
+
+class MultiShotPromptBuilder(BasePromptBuilder):
+    @classmethod
+    def example_count(cls) -> int:
+        return 25
+
+    def build_prompt(self) -> str:
+        base_prompt = f"# Instruction\n\n{ self.task.instruction }\n\n"
+
+        if len(self.task.requirements) > 0:
+            base_prompt += "# Requirements\n\nYour response should respect the following requirements:\n"
+            for i, requirement in enumerate(self.task.requirements):
+                base_prompt += f"{i+1}) {requirement.instruction}\n"
+            base_prompt += "\n"
+
+        valid_examples: list[tuple[str, str]] = []
+        runs = self.task.runs()
+
+        # first pass, we look for repaired outputs. These are the best examples.
+        for run in runs:
+            if len(valid_examples) >= self.__class__.example_count():
+                break
+            if run.repaired_output is not None:
+                valid_examples.append((run.input, run.repaired_output.output))
+
+        # second pass, we look for high quality outputs (rating based)
+        # Minimum is "high_quality" (4 star in star rating scale), then sort by rating
+        # exclude repaired outputs as they were used above
+        runs_with_rating = [
+            run
+            for run in runs
+            if run.output.rating is not None
+            and run.output.rating.value is not None
+            and run.output.rating.is_high_quality()
+            and run.repaired_output is None
+        ]
+        runs_with_rating.sort(
+            key=lambda x: (x.output.rating and x.output.rating.value) or 0, reverse=True
+        )
+        for run in runs_with_rating:
+            if len(valid_examples) >= self.__class__.example_count():
+                break
+            valid_examples.append((run.input, run.output.output))
+
+        if len(valid_examples) > 0:
+            base_prompt += "# Example Outputs\n\n"
+            for i, example in enumerate(valid_examples):
+                base_prompt += (
+                    f"## Example {i+1}\n\nInput: {example[0]}\nOutput: {example[1]}\n\n"
+                )
+
+        return base_prompt
+
+
+class FewShotPromptBuilder(MultiShotPromptBuilder):
+    @classmethod
+    def example_count(cls) -> int:
+        return 4
+
+
+prompt_builder_registry = {
+    "simple_prompt_builder": SimplePromptBuilder,
+    "multi_shot_prompt_builder": MultiShotPromptBuilder,
+    "few_shot_prompt_builder": FewShotPromptBuilder,
+}
+
+
+# Our UI has some names that are not the same as the class names, which also hint parameters.
+def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
+    match ui_name:
+        case "basic":
+            return SimplePromptBuilder
+        case "few_shot":
+            return FewShotPromptBuilder
+        case "many_shot":
+            return MultiShotPromptBuilder
+        case _:
+            raise ValueError(f"Unknown prompt builder: {ui_name}")
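
A rough sketch of how these builders might be used (not part of the diff). The Task, Project and TaskRequirement fields below mirror how repair_task.py, later in this diff, builds an in-memory task; whether every field is strictly required is not shown in this release.

from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement

# An in-memory task, constructed the same way repair_task.py constructs one.
project = Project(name="Demo")
task = Task(
    name="Summarize",
    parent=project,
    description="Summarize an article.",
    instruction="Summarize the input in one sentence.",
    requirements=[
        TaskRequirement(
            name="Neutral tone",
            instruction="Keep a neutral tone.",
            priority=Priority.p0,
        )
    ],
)

# "basic" maps to SimplePromptBuilder in prompt_builder_from_ui_name above.
builder = prompt_builder_from_ui_name("basic")(task=task)
print(builder.build_prompt())                       # instruction plus numbered requirements
print(builder.build_user_message({"text": "..."}))  # dict inputs are JSON formatted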
kiln_ai/adapters/repair/repair_task.py
@@ -0,0 +1,71 @@
+import json
+from typing import Type
+
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_registry
+from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun
+from pydantic import BaseModel, Field
+
+
+# TODO add evaluator rating
+class RepairTaskInput(BaseModel):
+    original_prompt: str
+    original_input: str
+    original_output: str
+    evaluator_feedback: str = Field(
+        min_length=1,
+        description="Feedback from an evaluator on how to repair the task run.",
+    )
+
+
+class RepairTaskRun(Task, parent_of={}):
+    def __init__(self, original_task: Task):
+        # Keep the typechecker happy
+        tmp_project = Project(name="Repair")
+        super().__init__(
+            name="Repair",
+            parent=tmp_project,
+            description="Repair a task run, given feedback from an evaluator about how the response can be improved.",
+            instruction="You are an assistant which helps improve output from another assistant (original assistant). You'll be provided a task that the original assistant executed (prompt), \
+the input it was given, and the output it generated. An evaluator has determined that the output it generated did not satisfy the task and should be improved. The evaluator will provide \
+feedback describing what should be improved. Your job is to understand the evaluator's feedback and improve the response.",
+            requirements=[
+                TaskRequirement(
+                    name="Follow Eval Feedback",
+                    instruction="The evaluator's feedback is the most important thing to consider. If it conflicts with the original task instruction or prompt, prioritize the evaluator's feedback.",
+                    priority=Priority.p0,
+                )
+            ],
+            input_json_schema=json.dumps(RepairTaskInput.model_json_schema()),
+            output_json_schema=original_task.output_json_schema,
+        )
+
+    @classmethod
+    def _original_prompt(cls, run: TaskRun, task: Task) -> str:
+        prompt_builder_class: Type[BasePromptBuilder] | None = None
+        prompt_builder_name = run.output.source.properties.get(
+            "prompt_builder_name", None
+        )
+        if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
+            prompt_builder_class = prompt_builder_registry.get(
+                prompt_builder_name, None
+            )
+        if prompt_builder_class is None:
+            raise ValueError(f"No prompt builder found for name: {prompt_builder_name}")
+        prompt_builder = prompt_builder_class(task=task)
+        if not isinstance(prompt_builder, BasePromptBuilder):
+            raise ValueError(
+                f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
+            )
+        return prompt_builder.build_prompt()
+
+    @classmethod
+    def build_repair_task_input(
+        cls, original_task: Task, task_run: TaskRun, evaluator_feedback: str
+    ) -> RepairTaskInput:
+        original_prompt = cls._original_prompt(task_run, original_task)
+        return RepairTaskInput(
+            original_prompt=original_prompt,
+            original_input=task_run.input,
+            original_output=task_run.output.output,
+            evaluator_feedback=evaluator_feedback,
+        )
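
A hedged sketch of the intended repair flow (not part of the diff). `original_task` and `task_run` are assumed to be existing kiln_ai datamodel objects, and the run's output source must have recorded a "prompt_builder_name" property, as `_original_prompt` above requires.

from kiln_ai.adapters.repair.repair_task import RepairTaskRun

# original_task and task_run are assumed, pre-existing kiln_ai objects (see note above).
repair_task = RepairTaskRun(original_task)
repair_input = RepairTaskRun.build_repair_task_input(
    original_task=original_task,
    task_run=task_run,
    evaluator_feedback="The summary omits the article's main conclusion; include it.",
)
# The resulting RepairTaskInput could then be run against repair_task with an adapter.
print(repair_input.model_dump_json(indent=2))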