kiln-ai 0.0.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.

This release of kiln-ai has been flagged as potentially problematic.

Files changed (33)
  1. kiln_ai/adapters/base_adapter.py +168 -0
  2. kiln_ai/adapters/langchain_adapters.py +113 -0
  3. kiln_ai/adapters/ml_model_list.py +436 -0
  4. kiln_ai/adapters/prompt_builders.py +122 -0
  5. kiln_ai/adapters/repair/repair_task.py +71 -0
  6. kiln_ai/adapters/repair/test_repair_task.py +248 -0
  7. kiln_ai/adapters/test_langchain_adapter.py +50 -0
  8. kiln_ai/adapters/test_ml_model_list.py +99 -0
  9. kiln_ai/adapters/test_prompt_adaptors.py +167 -0
  10. kiln_ai/adapters/test_prompt_builders.py +315 -0
  11. kiln_ai/adapters/test_saving_adapter_results.py +168 -0
  12. kiln_ai/adapters/test_structured_output.py +218 -0
  13. kiln_ai/datamodel/__init__.py +362 -2
  14. kiln_ai/datamodel/basemodel.py +372 -0
  15. kiln_ai/datamodel/json_schema.py +45 -0
  16. kiln_ai/datamodel/test_basemodel.py +277 -0
  17. kiln_ai/datamodel/test_datasource.py +107 -0
  18. kiln_ai/datamodel/test_example_models.py +644 -0
  19. kiln_ai/datamodel/test_json_schema.py +124 -0
  20. kiln_ai/datamodel/test_models.py +190 -0
  21. kiln_ai/datamodel/test_nested_save.py +205 -0
  22. kiln_ai/datamodel/test_output_rating.py +88 -0
  23. kiln_ai/utils/config.py +170 -0
  24. kiln_ai/utils/formatting.py +5 -0
  25. kiln_ai/utils/test_config.py +245 -0
  26. {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/METADATA +20 -1
  27. kiln_ai-0.5.0.dist-info/RECORD +29 -0
  28. kiln_ai/__init.__.py +0 -3
  29. kiln_ai/coreadd.py +0 -3
  30. kiln_ai/datamodel/project.py +0 -15
  31. kiln_ai-0.0.4.dist-info/RECORD +0 -8
  32. {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/LICENSE.txt +0 -0
  33. {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/WHEEL +0 -0
kiln_ai/adapters/ml_model_list.py (new file)
@@ -0,0 +1,436 @@
+import os
+from dataclasses import dataclass
+from enum import Enum
+from os import getenv
+from typing import Dict, List, NoReturn
+
+import httpx
+from langchain_aws import ChatBedrockConverse
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI
+from pydantic import BaseModel
+
+from ..utils.config import Config
+
+
+class ModelProviderName(str, Enum):
+    openai = "openai"
+    groq = "groq"
+    amazon_bedrock = "amazon_bedrock"
+    ollama = "ollama"
+    openrouter = "openrouter"
+
+
+class ModelFamily(str, Enum):
+    gpt = "gpt"
+    llama = "llama"
+    phi = "phi"
+    mistral = "mistral"
+    gemma = "gemma"
+
+
+# Where models have instruct and raw versions, instruct is default and raw is specified
+class ModelName(str, Enum):
+    llama_3_1_8b = "llama_3_1_8b"
+    llama_3_1_70b = "llama_3_1_70b"
+    llama_3_1_405b = "llama_3_1_405b"
+    gpt_4o_mini = "gpt_4o_mini"
+    gpt_4o = "gpt_4o"
+    phi_3_5 = "phi_3_5"
+    mistral_large = "mistral_large"
+    mistral_nemo = "mistral_nemo"
+    gemma_2_2b = "gemma_2_2b"
+    gemma_2_9b = "gemma_2_9b"
+    gemma_2_27b = "gemma_2_27b"
+
+
+class KilnModelProvider(BaseModel):
+    name: ModelProviderName
+    # Allow overriding the model level setting
+    supports_structured_output: bool = True
+    provider_options: Dict = {}
+
+
+class KilnModel(BaseModel):
+    family: str
+    name: str
+    friendly_name: str
+    providers: List[KilnModelProvider]
+    supports_structured_output: bool = True
+
+
+built_in_models: List[KilnModel] = [
+    # GPT 4o Mini
+    KilnModel(
+        family=ModelFamily.gpt,
+        name=ModelName.gpt_4o_mini,
+        friendly_name="GPT 4o Mini",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o-mini"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "openai/gpt-4o-mini"},
+            ),
+        ],
+    ),
+    # GPT 4o
+    KilnModel(
+        family=ModelFamily.gpt,
+        name=ModelName.gpt_4o,
+        friendly_name="GPT 4o",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openai,
+                provider_options={"model": "gpt-4o"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "openai/gpt-4o-2024-08-06"},
+            ),
+        ],
+    ),
+    # Llama 3.1-8b
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_1_8b,
+        friendly_name="Llama 3.1 8B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-8b-instant"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                provider_options={
+                    "model": "meta.llama3-1-8b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.1"},  # 8b is default
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
+            ),
+        ],
+    ),
+    # Llama 3.1 70b
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_1_70b,
+        friendly_name="Llama 3.1 70B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                provider_options={"model": "llama-3.1-70b-versatile"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                # TODO: this should work but a bug in the bedrock response schema
+                supports_structured_output=False,
+                provider_options={
+                    "model": "meta.llama3-1-70b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
+            ),
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     provider=ModelProviders.ollama,
+            #     provider_options={"model": "llama3.1:70b"},
+            # ),
+        ],
+    ),
+    # Llama 3.1 405b
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_1_405b,
+        friendly_name="Llama 3.1 405B",
+        providers=[
+            # TODO: bring back when groq does: https://console.groq.com/docs/models
+            # KilnModelProvider(
+            #     name=ModelProviderName.groq,
+            #     provider_options={"model": "llama-3.1-405b-instruct-v1:0"},
+            # ),
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                provider_options={
+                    "model": "meta.llama3-1-405b-instruct-v1:0",
+                    "region_name": "us-west-2",  # Llama 3.1 only in west-2
+                },
+            ),
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={"model": "llama3.1:405b"},
+            # ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
+            ),
+        ],
+    ),
+    # Mistral Nemo
+    KilnModel(
+        family=ModelFamily.mistral,
+        name=ModelName.mistral_nemo,
+        friendly_name="Mistral Nemo",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "mistralai/mistral-nemo"},
+            ),
+        ],
+    ),
+    # Mistral Large
+    KilnModel(
+        family=ModelFamily.mistral,
+        name=ModelName.mistral_large,
+        friendly_name="Mistral Large",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.amazon_bedrock,
+                provider_options={
+                    "model": "mistral.mistral-large-2407-v1:0",
+                    "region_name": "us-west-2",  # only in west-2
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "mistralai/mistral-large"},
+            ),
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     provider=ModelProviders.ollama,
+            #     provider_options={"model": "mistral-large"},
+            # ),
+        ],
+    ),
+    # Phi 3.5
+    KilnModel(
+        family=ModelFamily.phi,
+        name=ModelName.phi_3_5,
+        friendly_name="Phi 3.5",
+        supports_structured_output=False,
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "phi3.5"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
+            ),
+        ],
+    ),
+    # Gemma 2 2.6b
+    KilnModel(
+        family=ModelFamily.gemma,
+        name=ModelName.gemma_2_2b,
+        friendly_name="Gemma 2 2B",
+        supports_structured_output=False,
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={
+                    "model": "gemma2:2b",
+                },
+            ),
+        ],
+    ),
+    # Gemma 2 9b
+    KilnModel(
+        family=ModelFamily.gemma,
+        name=ModelName.gemma_2_9b,
+        friendly_name="Gemma 2 9B",
+        supports_structured_output=False,
+        providers=[
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={
+            #         "model": "gemma2:9b",
+            #     },
+            # ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "google/gemma-2-9b-it"},
+            ),
+        ],
+    ),
+    # Gemma 2 27b
+    KilnModel(
+        family=ModelFamily.gemma,
+        name=ModelName.gemma_2_27b,
+        friendly_name="Gemma 2 27B",
+        supports_structured_output=False,
+        providers=[
+            # TODO: enable once tests update to check if model is available
+            # KilnModelProvider(
+            #     name=ModelProviderName.ollama,
+            #     provider_options={
+            #         "model": "gemma2:27b",
+            #     },
+            # ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "google/gemma-2-27b-it"},
+            ),
+        ],
+    ),
+]
+
+
+def provider_name_from_id(id: str) -> str:
+    if id in ModelProviderName.__members__:
+        enum_id = ModelProviderName(id)
+        match enum_id:
+            case ModelProviderName.amazon_bedrock:
+                return "Amazon Bedrock"
+            case ModelProviderName.openrouter:
+                return "OpenRouter"
+            case ModelProviderName.groq:
+                return "Groq"
+            case ModelProviderName.ollama:
+                return "Ollama"
+            case ModelProviderName.openai:
+                return "OpenAI"
+            case _:
+                # triggers pyright warning if I miss a case
+                raise_exhaustive_error(enum_id)
+
+    return "Unknown provider: " + id
+
+
+def raise_exhaustive_error(value: NoReturn) -> NoReturn:
+    raise ValueError(f"Unhandled enum value: {value}")
+
+
+@dataclass
+class ModelProviderWarning:
+    required_config_keys: List[str]
+    message: str
+
+
+provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
+    ModelProviderName.amazon_bedrock: ModelProviderWarning(
+        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
+        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
+    ),
+    ModelProviderName.openrouter: ModelProviderWarning(
+        required_config_keys=["open_router_api_key"],
+        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+    ),
+    ModelProviderName.groq: ModelProviderWarning(
+        required_config_keys=["groq_api_key"],
+        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
+    ),
+    ModelProviderName.openai: ModelProviderWarning(
+        required_config_keys=["open_ai_api_key"],
+        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
+    ),
+}
+
+
+def get_config_value(key: str):
+    try:
+        return Config.shared().__getattr__(key)
+    except AttributeError:
+        return None
+
+
+def check_provider_warnings(provider_name: ModelProviderName):
+    warning_check = provider_warnings.get(provider_name)
+    if warning_check is None:
+        return
+    for key in warning_check.required_config_keys:
+        if get_config_value(key) is None:
+            raise ValueError(warning_check.message)
+
+
+def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
+    if name not in ModelName.__members__:
+        raise ValueError(f"Invalid name: {name}")
+
+    # Select the model from built_in_models using the name
+    model = next(filter(lambda m: m.name == name, built_in_models))
+    if model is None:
+        raise ValueError(f"Model {name} not found")
+
+    # If a provider is provided, select the provider from the model's provider_config
+    provider: KilnModelProvider | None = None
+    if model.providers is None or len(model.providers) == 0:
+        raise ValueError(f"Model {name} has no providers")
+    elif provider_name is None:
+        # TODO: priority order
+        provider = model.providers[0]
+    else:
+        provider = next(
+            filter(lambda p: p.name == provider_name, model.providers), None
+        )
+        if provider is None:
+            raise ValueError(f"Provider {provider_name} not found for model {name}")
+
+    check_provider_warnings(provider.name)
+
+    if provider.name == ModelProviderName.openai:
+        api_key = Config.shared().open_ai_api_key
+        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
+    elif provider.name == ModelProviderName.groq:
+        api_key = Config.shared().groq_api_key
+        if api_key is None:
+            raise ValueError(
+                "Attempted to use Groq without an API key set. "
+                "Get your API key from https://console.groq.com/keys"
+            )
+        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
+    elif provider.name == ModelProviderName.amazon_bedrock:
+        api_key = Config.shared().bedrock_access_key
+        secret_key = Config.shared().bedrock_secret_key
+        # langchain doesn't allow passing these, so ugly hack to set env vars
+        os.environ["AWS_ACCESS_KEY_ID"] = api_key
+        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
+        return ChatBedrockConverse(
+            **provider.provider_options,
+        )
+    elif provider.name == ModelProviderName.ollama:
+        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
+    elif provider.name == ModelProviderName.openrouter:
+        api_key = Config.shared().open_router_api_key
+        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
+        return ChatOpenAI(
+            **provider.provider_options,
+            openai_api_key=api_key,  # type: ignore[arg-type]
+            openai_api_base=base_url,  # type: ignore[arg-type]
+            default_headers={
+                "HTTP-Referer": "https://kiln-ai.com/openrouter",
+                "X-Title": "KilnAI",
+            },
+        )
+    else:
+        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
+
+
+def ollama_base_url():
+    env_base_url = os.getenv("OLLAMA_BASE_URL")
+    if env_base_url is not None:
+        return env_base_url
+    return "http://localhost:11434"
+
+
+async def ollama_online():
+    try:
+        httpx.get(ollama_base_url() + "/api/tags")
+    except httpx.RequestError:
+        return False
+    return True
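
For orientation, a minimal, hypothetical usage sketch of the new ml_model_list module (not part of the diff). It assumes the relevant API key is already set in Kiln's config; every name used below is defined in the hunk above.

from kiln_ai.adapters.ml_model_list import (
    ModelName,
    ModelProviderName,
    langchain_model_from,
    provider_name_from_id,
)

# Default provider: the first entry in the model's provider list.
chat_model = langchain_model_from(ModelName.gpt_4o_mini)

# Explicit provider: raises ValueError if the model/provider pair is unknown,
# or if check_provider_warnings finds a required config key missing.
chat_model = langchain_model_from(
    ModelName.gpt_4o_mini, ModelProviderName.openrouter
)

print(provider_name_from_id("openrouter"))  # -> "OpenRouter"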
kiln_ai/adapters/prompt_builders.py (new file)
@@ -0,0 +1,122 @@
+import json
+from abc import ABCMeta, abstractmethod
+from typing import Dict
+
+from kiln_ai.datamodel import Task
+from kiln_ai.utils.formatting import snake_case
+
+
+class BasePromptBuilder(metaclass=ABCMeta):
+    def __init__(self, task: Task):
+        self.task = task
+
+    @abstractmethod
+    def build_prompt(self) -> str:
+        pass
+
+    # override to change the name of the prompt builder (if changing class names)
+    @classmethod
+    def prompt_builder_name(cls) -> str:
+        return snake_case(cls.__name__)
+
+    # Can be overridden to add more information to the user message
+    def build_user_message(self, input: Dict | str) -> str:
+        if isinstance(input, Dict):
+            return f"The input is:\n{json.dumps(input, indent=2)}"
+
+        return f"The input is:\n{input}"
+
+
+class SimplePromptBuilder(BasePromptBuilder):
+    def build_prompt(self) -> str:
+        base_prompt = self.task.instruction
+
+        # TODO: this is just a quick version. Formatting and best practices TBD
+        if len(self.task.requirements) > 0:
+            base_prompt += (
+                "\n\nYour response should respect the following requirements:\n"
+            )
+            # iterate requirements, formatting them in numbered list like 1) task.instruction\n2)...
+            for i, requirement in enumerate(self.task.requirements):
+                base_prompt += f"{i+1}) {requirement.instruction}\n"
+
+        return base_prompt
+
+
+class MultiShotPromptBuilder(BasePromptBuilder):
+    @classmethod
+    def example_count(cls) -> int:
+        return 25
+
+    def build_prompt(self) -> str:
+        base_prompt = f"# Instruction\n\n{ self.task.instruction }\n\n"
+
+        if len(self.task.requirements) > 0:
+            base_prompt += "# Requirements\n\nYour response should respect the following requirements:\n"
+            for i, requirement in enumerate(self.task.requirements):
+                base_prompt += f"{i+1}) {requirement.instruction}\n"
+            base_prompt += "\n"
+
+        valid_examples: list[tuple[str, str]] = []
+        runs = self.task.runs()
+
+        # first pass, we look for repaired outputs. These are the best examples.
+        for run in runs:
+            if len(valid_examples) >= self.__class__.example_count():
+                break
+            if run.repaired_output is not None:
+                valid_examples.append((run.input, run.repaired_output.output))
+
+        # second pass, we look for high quality outputs (rating based)
+        # Minimum is "high_quality" (4 star in star rating scale), then sort by rating
+        # exclude repaired outputs as they were used above
+        runs_with_rating = [
+            run
+            for run in runs
+            if run.output.rating is not None
+            and run.output.rating.value is not None
+            and run.output.rating.is_high_quality()
+            and run.repaired_output is None
+        ]
+        runs_with_rating.sort(
+            key=lambda x: (x.output.rating and x.output.rating.value) or 0, reverse=True
+        )
+        for run in runs_with_rating:
+            if len(valid_examples) >= self.__class__.example_count():
+                break
+            valid_examples.append((run.input, run.output.output))
+
+        if len(valid_examples) > 0:
+            base_prompt += "# Example Outputs\n\n"
+            for i, example in enumerate(valid_examples):
+                base_prompt += (
+                    f"## Example {i+1}\n\nInput: {example[0]}\nOutput: {example[1]}\n\n"
+                )
+
+        return base_prompt
+
+
+class FewShotPromptBuilder(MultiShotPromptBuilder):
+    @classmethod
+    def example_count(cls) -> int:
+        return 4
+
+
+prompt_builder_registry = {
+    "simple_prompt_builder": SimplePromptBuilder,
+    "multi_shot_prompt_builder": MultiShotPromptBuilder,
+    "few_shot_prompt_builder": FewShotPromptBuilder,
+}
+
+
+# Our UI has some names that are not the same as the class names, which also hint parameters.
+def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
+    match ui_name:
+        case "basic":
+            return SimplePromptBuilder
+        case "few_shot":
+            return FewShotPromptBuilder
+        case "many_shot":
+            return MultiShotPromptBuilder
+        case _:
+            raise ValueError(f"Unknown prompt builder: {ui_name}")
kiln_ai/adapters/repair/repair_task.py (new file)
@@ -0,0 +1,71 @@
+import json
+from typing import Type
+
+from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_registry
+from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun
+from pydantic import BaseModel, Field
+
+
+# TODO add evaluator rating
+class RepairTaskInput(BaseModel):
+    original_prompt: str
+    original_input: str
+    original_output: str
+    evaluator_feedback: str = Field(
+        min_length=1,
+        description="Feedback from an evaluator on how to repair the task run.",
+    )
+
+
+class RepairTaskRun(Task, parent_of={}):
+    def __init__(self, original_task: Task):
+        # Keep the typechecker happy
+        tmp_project = Project(name="Repair")
+        super().__init__(
+            name="Repair",
+            parent=tmp_project,
+            description="Repair a task run, given feedback from an evaluator about how the response can be improved.",
+            instruction="You are an assistant which helps improve output from another assistant (original assistant). You'll be provided a task that the original assistant executed (prompt), \
+the input it was given, and the output it generated. An evaluator has determined that the output it generated did not satisfy the task and should be improved. The evaluator will provide \
+feedback describing what should be improved. Your job is to understand the evaluator's feedback and improve the response.",
+            requirements=[
+                TaskRequirement(
+                    name="Follow Eval Feedback",
+                    instruction="The evaluator's feedback is the most important thing to consider. If it conflicts with the original task instruction or prompt, prioritize the evaluator's feedback.",
+                    priority=Priority.p0,
+                )
+            ],
+            input_json_schema=json.dumps(RepairTaskInput.model_json_schema()),
+            output_json_schema=original_task.output_json_schema,
+        )
+
+    @classmethod
+    def _original_prompt(cls, run: TaskRun, task: Task) -> str:
+        prompt_builder_class: Type[BasePromptBuilder] | None = None
+        prompt_builder_name = run.output.source.properties.get(
+            "prompt_builder_name", None
+        )
+        if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
+            prompt_builder_class = prompt_builder_registry.get(
+                prompt_builder_name, None
+            )
+        if prompt_builder_class is None:
+            raise ValueError(f"No prompt builder found for name: {prompt_builder_name}")
+        prompt_builder = prompt_builder_class(task=task)
+        if not isinstance(prompt_builder, BasePromptBuilder):
+            raise ValueError(
+                f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
+            )
+        return prompt_builder.build_prompt()
+
+    @classmethod
+    def build_repair_task_input(
+        cls, original_task: Task, task_run: TaskRun, evaluator_feedback: str
+    ) -> RepairTaskInput:
+        original_prompt = cls._original_prompt(task_run, original_task)
+        return RepairTaskInput(
+            original_prompt=original_prompt,
+            original_input=task_run.input,
+            original_output=task_run.output.output,
+            evaluator_feedback=evaluator_feedback,
+        )
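
Finally, a hypothetical sketch of the repair flow, assuming `original_task` and `task_run` already exist and that the run's output source recorded a "prompt_builder_name" property (which _original_prompt requires):

from kiln_ai.adapters.repair.repair_task import RepairTaskRun

# Build a one-off repair task mirroring the original task's output schema.
repair_task = RepairTaskRun(original_task)

# Package everything the repair model needs into a RepairTaskInput.
repair_input = RepairTaskRun.build_repair_task_input(
    original_task=original_task,
    task_run=task_run,
    evaluator_feedback="The output ignored requirement 2; address it explicitly.",
)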