kiln-ai 0.8.0__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (57)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +77 -5
  3. kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  7. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  8. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  9. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  10. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
  11. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
  12. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  13. kiln_ai/adapters/ml_model_list.py +323 -94
  14. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  15. kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
  16. kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
  17. kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
  18. kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
  19. kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
  20. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
  21. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
  22. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
  23. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
  24. kiln_ai/adapters/parsers/__init__.py +10 -0
  25. kiln_ai/adapters/parsers/base_parser.py +12 -0
  26. kiln_ai/adapters/parsers/json_parser.py +37 -0
  27. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  28. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  29. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  30. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  31. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  32. kiln_ai/adapters/prompt_builders.py +126 -20
  33. kiln_ai/adapters/provider_tools.py +91 -36
  34. kiln_ai/adapters/repair/repair_task.py +17 -6
  35. kiln_ai/adapters/repair/test_repair_task.py +4 -4
  36. kiln_ai/adapters/run_output.py +8 -0
  37. kiln_ai/adapters/test_adapter_registry.py +177 -0
  38. kiln_ai/adapters/test_generate_docs.py +69 -0
  39. kiln_ai/adapters/test_prompt_adaptors.py +8 -4
  40. kiln_ai/adapters/test_prompt_builders.py +190 -29
  41. kiln_ai/adapters/test_provider_tools.py +268 -46
  42. kiln_ai/datamodel/__init__.py +199 -12
  43. kiln_ai/datamodel/basemodel.py +31 -11
  44. kiln_ai/datamodel/json_schema.py +8 -3
  45. kiln_ai/datamodel/model_cache.py +8 -3
  46. kiln_ai/datamodel/test_basemodel.py +81 -2
  47. kiln_ai/datamodel/test_dataset_split.py +100 -3
  48. kiln_ai/datamodel/test_example_models.py +25 -4
  49. kiln_ai/datamodel/test_model_cache.py +24 -0
  50. kiln_ai/datamodel/test_model_perf.py +125 -0
  51. kiln_ai/datamodel/test_models.py +129 -0
  52. kiln_ai/utils/exhaustive_error.py +6 -0
  53. {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
  54. kiln_ai-0.11.1.dist-info/RECORD +76 -0
  55. kiln_ai-0.8.0.dist-info/RECORD +0 -58
  56. {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
  57. {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/prompt_builders.py
@@ -20,8 +20,32 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         self.task = task

+    def prompt_id(self) -> str | None:
+        """Returns the ID of the prompt, scoped to this builder.
+
+        Returns:
+            str | None: The ID of the prompt, or None if not set.
+        """
+        return None
+
+    def build_prompt(self, include_json_instructions) -> str:
+        """Build and return the complete prompt string.
+
+        Returns:
+            str: The constructed prompt.
+        """
+        prompt = self.build_base_prompt()
+
+        if include_json_instructions and self.task.output_schema():
+            prompt = (
+                prompt
+                + f"\n\n# Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.task.output_schema()}\n```"
+            )
+
+        return prompt
+
     @abstractmethod
-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
         """Build and return the complete prompt string.

         Returns:
@@ -50,7 +74,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
             str: The formatted user message.
         """
         if isinstance(input, Dict):
-            return f"The input is:\n{json.dumps(input, indent=2)}"
+            return f"The input is:\n{json.dumps(input, indent=2, ensure_ascii=False)}"

         return f"The input is:\n{input}"

@@ -70,7 +94,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
         Returns:
             str: The constructed prompt string.
         """
-        base_prompt = self.build_prompt()
+        base_prompt = self.build_prompt(include_json_instructions=False)
         cot_prompt = self.chain_of_thought_prompt()
         if cot_prompt:
             base_prompt += "\n# Thinking Instructions\n\n" + cot_prompt
@@ -80,7 +104,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
 class SimplePromptBuilder(BasePromptBuilder):
     """A basic prompt builder that combines task instruction with requirements."""

-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
         """Build a simple prompt with instruction and requirements.

         Returns:
@@ -95,7 +119,7 @@ class SimplePromptBuilder(BasePromptBuilder):
         )
         # iterate requirements, formatting them in numbereed list like 1) task.instruction\n2)...
         for i, requirement in enumerate(self.task.requirements):
-            base_prompt += f"{i+1}) {requirement.instruction}\n"
+            base_prompt += f"{i + 1}) {requirement.instruction}\n"

         return base_prompt

@@ -112,18 +136,18 @@ class MultiShotPromptBuilder(BasePromptBuilder):
         """
         return 25

-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
         """Build a prompt with instruction, requirements, and multiple examples.

         Returns:
             str: The constructed prompt string with examples.
         """
-        base_prompt = f"# Instruction\n\n{ self.task.instruction }\n\n"
+        base_prompt = f"# Instruction\n\n{self.task.instruction}\n\n"

         if len(self.task.requirements) > 0:
             base_prompt += "# Requirements\n\nYour response should respect the following requirements:\n"
             for i, requirement in enumerate(self.task.requirements):
-                base_prompt += f"{i+1}) {requirement.instruction}\n"
+                base_prompt += f"{i + 1}) {requirement.instruction}\n"
             base_prompt += "\n"

         valid_examples = self.collect_examples()
@@ -140,11 +164,11 @@ class MultiShotPromptBuilder(BasePromptBuilder):
     def prompt_section_for_example(self, index: int, example: TaskRun) -> str:
         # Prefer repaired output if it exists, otherwise use the regular output
         output = example.repaired_output or example.output
-        return f"## Example {index+1}\n\nInput: {example.input}\nOutput: {output.output}\n\n"
+        return f"## Example {index + 1}\n\nInput: {example.input}\nOutput: {output.output}\n\n"

     def collect_examples(self) -> list[TaskRun]:
         valid_examples: list[TaskRun] = []
-        runs = self.task.runs()
+        runs = self.task.runs(readonly=True)

         # first pass, we look for repaired outputs. These are the best examples.
         for run in runs:
@@ -198,7 +222,7 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
         ):
             return super().prompt_section_for_example(index, example)

-        prompt_section = f"## Example {index+1}\n\nInput: {example.input}\n\n"
+        prompt_section = f"## Example {index + 1}\n\nInput: {example.input}\n\n"
         prompt_section += (
             f"Initial Output Which Was Insufficient: {example.output.output}\n\n"
         )
@@ -209,7 +233,7 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
         return prompt_section


-def chain_of_thought_prompt(task: Task) -> str | None:
+def chain_of_thought_prompt(task: Task) -> str:
     """Standard implementation to build and return the chain of thought prompt string.

     Returns:
@@ -244,6 +268,77 @@ class MultiShotChainOfThoughtPromptBuilder(MultiShotPromptBuilder):
         return chain_of_thought_prompt(self.task)


+class SavedPromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a static prompt."""
+
+    def __init__(self, task: Task, prompt_id: str):
+        super().__init__(task)
+        prompt_model = next(
+            (
+                prompt
+                for prompt in task.prompts(readonly=True)
+                if prompt.id == prompt_id
+            ),
+            None,
+        )
+        if not prompt_model:
+            raise ValueError(f"Prompt ID not found: {prompt_id}")
+        self.prompt_model = prompt_model
+
+    def prompt_id(self) -> str | None:
+        return self.prompt_model.id
+
+    def build_base_prompt(self) -> str:
+        """Returns a saved prompt.
+
+        Returns:
+            str: The prompt string.
+        """
+        return self.prompt_model.prompt
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.prompt_model.chain_of_thought_instructions
+
+
+class FineTunePromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a fine-tune prompt."""
+
+    def __init__(self, task: Task, nested_fine_tune_id: str):
+        super().__init__(task)
+
+        # IDs are in project_id::task_id::fine_tune_id format
+        self.full_fine_tune_id = nested_fine_tune_id
+        parts = nested_fine_tune_id.split("::")
+        if len(parts) != 3:
+            raise ValueError(
+                f"Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id', got: {nested_fine_tune_id}"
+            )
+        fine_tune_id = parts[2]
+
+        fine_tune_model = next(
+            (
+                fine_tune
+                for fine_tune in task.finetunes(readonly=True)
+                if fine_tune.id == fine_tune_id
+            ),
+            None,
+        )
+        if not fine_tune_model:
+            raise ValueError(f"Fine-tune ID not found: {fine_tune_id}")
+        self.fine_tune_model = fine_tune_model
+
+    def prompt_id(self) -> str | None:
+        return self.full_fine_tune_id
+
+    def build_base_prompt(self) -> str:
+        return self.fine_tune_model.system_message
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.fine_tune_model.thinking_instructions
+
+
+# TODO P2: we end up with 2 IDs for these: the keys here (ui_name) and the prompt_builder_name from the class
+# We end up maintaining this in _prompt_generators as well.
 prompt_builder_registry = {
     "simple_prompt_builder": SimplePromptBuilder,
     "multi_shot_prompt_builder": MultiShotPromptBuilder,
@@ -256,7 +351,7 @@ prompt_builder_registry = {


 # Our UI has some names that are not the same as the class names, which also hint parameters.
-def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
+def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
     """Convert a name used in the UI to the corresponding prompt builder class.

     Args:
@@ -268,20 +363,31 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
     Raises:
         ValueError: If the UI name is not recognized.
     """
+
+    # Saved prompts are prefixed with "id::"
+    if ui_name.startswith("id::"):
+        prompt_id = ui_name[4:]
+        return SavedPromptBuilder(task, prompt_id)
+
+    # Fine-tune prompts are prefixed with "fine_tune_prompt::"
+    if ui_name.startswith("fine_tune_prompt::"):
+        fine_tune_id = ui_name[18:]
+        return FineTunePromptBuilder(task, fine_tune_id)
+
     match ui_name:
         case "basic":
-            return SimplePromptBuilder
+            return SimplePromptBuilder(task)
         case "few_shot":
-            return FewShotPromptBuilder
+            return FewShotPromptBuilder(task)
         case "many_shot":
-            return MultiShotPromptBuilder
+            return MultiShotPromptBuilder(task)
         case "repairs":
-            return RepairsPromptBuilder
+            return RepairsPromptBuilder(task)
         case "simple_chain_of_thought":
-            return SimpleChainOfThoughtPromptBuilder
+            return SimpleChainOfThoughtPromptBuilder(task)
        case "few_shot_chain_of_thought":
-            return FewShotChainOfThoughtPromptBuilder
+            return FewShotChainOfThoughtPromptBuilder(task)
         case "multi_shot_chain_of_thought":
-            return MultiShotChainOfThoughtPromptBuilder
+            return MultiShotChainOfThoughtPromptBuilder(task)
         case _:
             raise ValueError(f"Unknown prompt builder: {ui_name}")
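
A minimal usage sketch of the reworked prompt-builder API above. The task variable and the prompt/fine-tune IDs are hypothetical placeholders; the point is that prompt_builder_from_ui_name now takes a Task and returns a constructed builder instance rather than a class:

    from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name

    # Plain UI names still resolve, but now return an instance bound to the task
    builder = prompt_builder_from_ui_name("basic", task)

    # Saved prompts are looked up by ID via the "id::" prefix
    saved = prompt_builder_from_ui_name("id::prompt_123", task)

    # Fine-tune prompts carry a nested project_id::task_id::fine_tune_id
    tuned = prompt_builder_from_ui_name("fine_tune_prompt::proj_1::task_1::ft_1", task)

    # include_json_instructions appends the task's output schema to the prompt
    prompt = builder.build_prompt(include_json_instructions=True)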

kiln_ai/adapters/provider_tools.py
@@ -1,20 +1,24 @@
 from dataclasses import dataclass
-from typing import Dict, List, NoReturn
+from typing import Dict, List

 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
     ModelName,
     ModelProviderName,
+    StructuredOutputMode,
     built_in_models,
 )
+from kiln_ai.adapters.model_adapters.openai_compatible_config import (
+    OpenAICompatibleConfig,
+)
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.registry import project_from_id
-
-from ..utils.config import Config
+from kiln_ai.utils.config import Config
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


 async def provider_enabled(provider_name: ModelProviderName) -> bool:
@@ -61,7 +65,7 @@ def check_provider_warnings(provider_name: ModelProviderName):
         raise ValueError(warning_check.message)


-async def builtin_model_from(
+def builtin_model_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider | None:
     """
@@ -102,7 +106,47 @@ async def builtin_model_from(
     return provider


-async def kiln_model_provider_from(
+def core_provider(model_id: str, provider_name: ModelProviderName) -> ModelProviderName:
+    """
+    Get the provider that should be run.
+
+    Some provider IDs are wrappers (fine-tunes, custom models). This maps these to runnable providers (openai, ollama, etc)
+    """
+
+    # Custom models map to the underlying provider
+    if provider_name is ModelProviderName.kiln_custom_registry:
+        provider_name, _ = parse_custom_model_id(model_id)
+        return provider_name
+
+    # Fine-tune provider maps to an underlying provider
+    if provider_name is ModelProviderName.kiln_fine_tune:
+        finetune = finetune_from_id(model_id)
+        if finetune.provider not in ModelProviderName.__members__:
+            raise ValueError(
+                f"Finetune {model_id} has no underlying provider {finetune.provider}"
+            )
+        return ModelProviderName(finetune.provider)
+
+    return provider_name
+
+
+def parse_custom_model_id(
+    model_id: str,
+) -> tuple[ModelProviderName, str]:
+    if "::" not in model_id:
+        raise ValueError(f"Invalid custom model ID: {model_id}")
+
+    # For custom registry, get the provider name and model name from the model id
+    provider_name = model_id.split("::", 1)[0]
+    model_name = model_id.split("::", 1)[1]
+
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+
+    return ModelProviderName(provider_name), model_name
+
+
+def kiln_model_provider_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider:
     if provider_name == ModelProviderName.kiln_fine_tune:
@@ -111,14 +155,13 @@ async def kiln_model_provider_from(
     if provider_name == ModelProviderName.openai_compatible:
         return openai_compatible_provider_model(name)

-    built_in_model = await builtin_model_from(name, provider_name)
+    built_in_model = builtin_model_from(name, provider_name)
     if built_in_model:
         return built_in_model

     # For custom registry, get the provider name and model name from the model id
     if provider_name == ModelProviderName.kiln_custom_registry:
-        provider_name = name.split("::", 1)[0]
-        name = name.split("::", 1)[1]
+        provider_name, name = parse_custom_model_id(name)

     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
@@ -136,12 +179,9 @@ async def kiln_model_provider_from(
     )


-finetune_cache: dict[str, KilnModelProvider] = {}
-
-
-def openai_compatible_provider_model(
+def openai_compatible_config(
     model_id: str,
-) -> KilnModelProvider:
+) -> OpenAICompatibleConfig:
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -165,12 +205,21 @@ def openai_compatible_provider_model(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )

+    return OpenAICompatibleConfig(
+        api_key=api_key,
+        model_name=model_id,
+        provider_name=ModelProviderName.openai_compatible,
+        base_url=base_url,
+    )
+
+
+def openai_compatible_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
     return KilnModelProvider(
         name=ModelProviderName.openai_compatible,
         provider_options={
             "model": model_id,
-            "api_key": api_key,
-            "openai_api_base": base_url,
         },
         supports_structured_output=False,
         supports_data_gen=False,
@@ -178,9 +227,10 @@ def openai_compatible_provider_model(
     )


-def finetune_provider_model(
-    model_id: str,
-) -> KilnModelProvider:
+finetune_cache: dict[str, Finetune] = {}
+
+
+def finetune_from_id(model_id: str) -> Finetune:
     if model_id in finetune_cache:
         return finetune_cache[model_id]

@@ -202,6 +252,15 @@ def finetune_provider_model(
             f"Fine tune {fine_tune_id} not completed. Refresh it's status in the fine-tune tab."
         )

+    finetune_cache[model_id] = fine_tune
+    return fine_tune
+
+
+def finetune_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    fine_tune = finetune_from_id(model_id)
+
     provider = ModelProviderName[fine_tune.provider]
     model_provider = KilnModelProvider(
         name=provider,
@@ -210,18 +269,18 @@ def finetune_provider_model(
         },
     )

-    # TODO: Don't love this abstraction/logic.
-    if fine_tune.provider == ModelProviderName.fireworks_ai:
-        # Fireworks finetunes are trained with json, not tool calling (which is LC default format)
-        model_provider.adapter_options = {
-            "langchain": {
-                "with_structured_output_options": {
-                    "method": "json_mode",
-                }
-            }
-        }
-
-    finetune_cache[model_id] = model_provider
+    if fine_tune.structured_output_mode is not None:
+        # If we know the model was trained with specific output mode, set it
+        model_provider.structured_output_mode = fine_tune.structured_output_mode
+    else:
+        # Some early adopters won't have structured_output_mode set on their fine-tunes.
+        # We know that OpenAI uses json_schema, and Fireworks (only other provider) use json_mode.
+        # This can be removed in the future
+        if provider == ModelProviderName.openai:
+            model_provider.structured_output_mode = StructuredOutputMode.json_schema
+        else:
+            model_provider.structured_output_mode = StructuredOutputMode.json_mode
+
     return model_provider


@@ -274,7 +333,7 @@ def provider_name_from_id(id: str) -> str:
             return "OpenAI Compatible"
         case _:
             # triggers pyright warning if I miss a case
-            raise_exhaustive_error(enum_id)
+            raise_exhaustive_enum_error(enum_id)

     return "Unknown provider: " + id

@@ -316,16 +375,12 @@ def provider_options_for_custom_model(
             )
         case _:
             # triggers pyright warning if I miss a case
-            raise_exhaustive_error(enum_id)
+            raise_exhaustive_enum_error(enum_id)

     # Won't reach this, type checking will catch missed values
     return {"model": model_name}


-def raise_exhaustive_error(value: NoReturn) -> NoReturn:
-    raise ValueError(f"Unhandled enum value: {value}")
-
-
 @dataclass
 class ModelProviderWarning:
     required_config_keys: List[str]
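
A short sketch of how the new provider_tools helpers compose; the "ollama::llama3.1:8b" ID is an illustrative placeholder, not a value from this diff:

    from kiln_ai.adapters.ml_model_list import ModelProviderName
    from kiln_ai.adapters.provider_tools import core_provider, parse_custom_model_id

    # Custom-registry IDs encode the underlying provider before "::"
    provider, model_name = parse_custom_model_id("ollama::llama3.1:8b")
    assert provider is ModelProviderName.ollama
    assert model_name == "llama3.1:8b"

    # core_provider unwraps wrapper providers (custom registry, fine-tunes) to a runnable one
    runnable = core_provider("ollama::llama3.1:8b", ModelProviderName.kiln_custom_registry)
    assert runnable is ModelProviderName.ollama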

kiln_ai/adapters/repair/repair_task.py
@@ -3,7 +3,11 @@ from typing import Type

 from pydantic import BaseModel, Field

-from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_registry
+from kiln_ai.adapters.prompt_builders import (
+    BasePromptBuilder,
+    SavedPromptBuilder,
+    prompt_builder_registry,
+)
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun

@@ -42,11 +46,18 @@ feedback describing what should be improved. Your job is to understand the evalu

     @classmethod
     def _original_prompt(cls, run: TaskRun, task: Task) -> str:
+        if run.output.source is None or run.output.source.properties is None:
+            raise ValueError("No source properties found")
+
+        # Try ID first, then builder name
+        prompt_id = run.output.source.properties.get("prompt_id", None)
+        if prompt_id is not None and isinstance(prompt_id, str):
+            static_prompt_builder = SavedPromptBuilder(task, prompt_id)
+            return static_prompt_builder.build_prompt(include_json_instructions=False)
+
         prompt_builder_class: Type[BasePromptBuilder] | None = None
-        prompt_builder_name = (
-            run.output.source.properties.get("prompt_builder_name", None)
-            if run.output.source
-            else None
+        prompt_builder_name = run.output.source.properties.get(
+            "prompt_builder_name", None
         )
         if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
             prompt_builder_class = prompt_builder_registry.get(
@@ -59,7 +70,7 @@ feedback describing what should be improved. Your job is to understand the evalu
             raise ValueError(
                 f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
             )
-        return prompt_builder.build_prompt()
+        return prompt_builder.build_prompt(include_json_instructions=False)

     @classmethod
     def build_repair_task_input(
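
For context, _original_prompt now tries a saved prompt_id before falling back to the legacy builder name. A hypothetical source-properties payload for each path:

    # Newer runs: resolved via SavedPromptBuilder(task, "prompt_123")
    properties = {"adapter_name": "kiln_langchain_adapter", "prompt_id": "prompt_123"}

    # Older runs: resolved via prompt_builder_registry["simple_prompt_builder"]
    legacy_properties = {"prompt_builder_name": "simple_prompt_builder"}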

kiln_ai/adapters/repair/test_repair_task.py
@@ -6,8 +6,8 @@ import pytest
 from pydantic import ValidationError

 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.base_adapter import RunOutput
-from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import RunOutput
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.repair.repair_task import (
     RepairTaskInput,
     RepairTaskRun,
@@ -223,7 +223,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     )

     adapter = adapter_for_task(
-        repair_task, model_name="llama_3_1_8b", provider="groq"
+        repair_task, model_name="llama_3_1_8b", provider="ollama"
     )

     run = await adapter.invoke(repair_task_input.model_dump())
@@ -237,7 +237,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     assert run.output.source.properties == {
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "llama_3_1_8b",
-        "model_provider": "groq",
+        "model_provider": "ollama",
         "prompt_builder_name": "simple_prompt_builder",
     }
     assert run.input_source.type == DataSourceType.human
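
The updated test targets Ollama instead of Groq. A minimal sketch of that call shape, assuming you already have a Kiln Task and a dict input on hand:

    import asyncio

    from kiln_ai.adapters.adapter_registry import adapter_for_task

    async def run_once(task, task_input: dict):
        # Resolve a model adapter for the model/provider pair, then invoke it
        adapter = adapter_for_task(task, model_name="llama_3_1_8b", provider="ollama")
        return await adapter.invoke(task_input)

    # asyncio.run(run_once(my_task, {"some": "input"}))  # my_task is a hypothetical Task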

kiln_ai/adapters/run_output.py
@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+from typing import Dict
+
+
+@dataclass
+class RunOutput:
+    output: Dict | str
+    intermediate_outputs: Dict[str, str] | None
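
A minimal sketch of the new RunOutput dataclass added in kiln_ai/adapters/run_output.py, here with an assumed structured output and a chain-of-thought intermediate:

    from kiln_ai.adapters.run_output import RunOutput

    # output may be a parsed dict (structured tasks) or a plain string
    run_output = RunOutput(
        output={"answer": "42"},
        intermediate_outputs={"chain_of_thought": "First, consider..."},
    )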