kiln-ai 0.8.0__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +199 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.0.dist-info/RECORD +0 -58
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/prompt_builders.py

```diff
@@ -20,8 +20,32 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         self.task = task
 
+    def prompt_id(self) -> str | None:
+        """Returns the ID of the prompt, scoped to this builder.
+
+        Returns:
+            str | None: The ID of the prompt, or None if not set.
+        """
+        return None
+
+    def build_prompt(self, include_json_instructions) -> str:
+        """Build and return the complete prompt string.
+
+        Returns:
+            str: The constructed prompt.
+        """
+        prompt = self.build_base_prompt()
+
+        if include_json_instructions and self.task.output_schema():
+            prompt = (
+                prompt
+                + f"\n\n# Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.task.output_schema()}\n```"
+            )
+
+        return prompt
+
     @abstractmethod
-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
        """Build and return the complete prompt string.
 
         Returns:
```
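The base class now follows the template-method pattern: the concrete `build_prompt` appends JSON format instructions in one place, while subclasses override the abstract `build_base_prompt`. A minimal standalone sketch of that split (hypothetical classes, not Kiln's actual API):

```python
from abc import ABC, abstractmethod


class PromptBuilderSketch(ABC):
    """Concrete build_prompt wraps the abstract build_base_prompt, so
    format instructions are appended in exactly one place."""

    def __init__(self, output_schema: str | None):
        self.output_schema = output_schema

    def build_prompt(self, include_json_instructions: bool) -> str:
        prompt = self.build_base_prompt()
        if include_json_instructions and self.output_schema:
            # Centralized: subclasses can't forget the format section.
            prompt += (
                "\n\n# Format Instructions\n\n"
                f"Return a JSON object conforming to:\n{self.output_schema}"
            )
        return prompt

    @abstractmethod
    def build_base_prompt(self) -> str: ...


class StaticBuilder(PromptBuilderSketch):
    def build_base_prompt(self) -> str:
        return "You are a helpful assistant."


print(StaticBuilder('{"type": "object"}').build_prompt(include_json_instructions=True))
```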
```diff
@@ -50,7 +74,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
             str: The formatted user message.
         """
         if isinstance(input, Dict):
-            return f"The input is:\n{json.dumps(input, indent=2)}"
+            return f"The input is:\n{json.dumps(input, indent=2, ensure_ascii=False)}"
 
         return f"The input is:\n{input}"
 
@@ -70,7 +94,7 @@ class BasePromptBuilder(metaclass=ABCMeta):
         Returns:
             str: The constructed prompt string.
         """
-        base_prompt = self.build_prompt()
+        base_prompt = self.build_prompt(include_json_instructions=False)
         cot_prompt = self.chain_of_thought_prompt()
         if cot_prompt:
             base_prompt += "\n# Thinking Instructions\n\n" + cot_prompt
@@ -80,7 +104,7 @@
 class SimplePromptBuilder(BasePromptBuilder):
     """A basic prompt builder that combines task instruction with requirements."""
 
-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
         """Build a simple prompt with instruction and requirements.
 
         Returns:
@@ -95,7 +119,7 @@ class SimplePromptBuilder(BasePromptBuilder):
         )
         # iterate requirements, formatting them in numbereed list like 1) task.instruction\n2)...
         for i, requirement in enumerate(self.task.requirements):
-            base_prompt += f"{i+1}) {requirement.instruction}\n"
+            base_prompt += f"{i + 1}) {requirement.instruction}\n"
 
         return base_prompt
 
```
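The `ensure_ascii=False` change keeps non-ASCII task input readable in the formatted user message instead of `\uXXXX`-escaping it. A quick standard-library illustration:

```python
import json

record = {"greeting": "こんにちは", "city": "Zürich"}

# Default behavior escapes non-ASCII, bloating the prompt and hiding the text:
print(json.dumps(record, indent=2))  # "\u3053\u3093\u306b\u3061\u306f" ...

# ensure_ascii=False keeps the characters the model should actually see:
print(json.dumps(record, indent=2, ensure_ascii=False))  # "こんにちは" ...
```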
```diff
@@ -112,18 +136,18 @@ class MultiShotPromptBuilder(BasePromptBuilder):
         """
         return 25
 
-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
         """Build a prompt with instruction, requirements, and multiple examples.
 
         Returns:
             str: The constructed prompt string with examples.
         """
-        base_prompt = f"# Instruction\n\n{
+        base_prompt = f"# Instruction\n\n{self.task.instruction}\n\n"
 
         if len(self.task.requirements) > 0:
             base_prompt += "# Requirements\n\nYour response should respect the following requirements:\n"
             for i, requirement in enumerate(self.task.requirements):
-                base_prompt += f"{i+1}) {requirement.instruction}\n"
+                base_prompt += f"{i + 1}) {requirement.instruction}\n"
             base_prompt += "\n"
 
         valid_examples = self.collect_examples()
@@ -140,11 +164,11 @@ class MultiShotPromptBuilder(BasePromptBuilder):
     def prompt_section_for_example(self, index: int, example: TaskRun) -> str:
         # Prefer repaired output if it exists, otherwise use the regular output
         output = example.repaired_output or example.output
-        return f"## Example {index+1}\n\nInput: {example.input}\nOutput: {output.output}\n\n"
+        return f"## Example {index + 1}\n\nInput: {example.input}\nOutput: {output.output}\n\n"
 
     def collect_examples(self) -> list[TaskRun]:
         valid_examples: list[TaskRun] = []
-        runs = self.task.runs()
+        runs = self.task.runs(readonly=True)
 
         # first pass, we look for repaired outputs. These are the best examples.
         for run in runs:
@@ -198,7 +222,7 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
         ):
             return super().prompt_section_for_example(index, example)
 
-        prompt_section = f"## Example {index+1}\n\nInput: {example.input}\n\n"
+        prompt_section = f"## Example {index + 1}\n\nInput: {example.input}\n\n"
         prompt_section += (
             f"Initial Output Which Was Insufficient: {example.output.output}\n\n"
         )
@@ -209,7 +233,7 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
         return prompt_section
 
 
-def chain_of_thought_prompt(task: Task) -> str
+def chain_of_thought_prompt(task: Task) -> str:
     """Standard implementation to build and return the chain of thought prompt string.
 
     Returns:
```
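`collect_examples` now reads runs with `readonly=True` and, as before, prefers human-repaired outputs as few-shot examples. A small sketch of that preference (a hypothetical stand-in type, not Kiln's `TaskRun`):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunSketch:
    # Hypothetical stand-in with just the fields the selection logic needs.
    output: str
    repaired_output: Optional[str] = None


def best_example_output(run: RunSketch) -> str:
    # Same preference as the diff: a repaired output is the strongest
    # few-shot example; otherwise fall back to the raw output.
    return run.repaired_output or run.output


runs = [RunSketch("draft"), RunSketch("draft", repaired_output="fixed")]
print([best_example_output(r) for r in runs])  # ['draft', 'fixed']
```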
```diff
@@ -244,6 +268,77 @@ class MultiShotChainOfThoughtPromptBuilder(MultiShotPromptBuilder):
         return chain_of_thought_prompt(self.task)
 
 
+class SavedPromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a static prompt."""
+
+    def __init__(self, task: Task, prompt_id: str):
+        super().__init__(task)
+        prompt_model = next(
+            (
+                prompt
+                for prompt in task.prompts(readonly=True)
+                if prompt.id == prompt_id
+            ),
+            None,
+        )
+        if not prompt_model:
+            raise ValueError(f"Prompt ID not found: {prompt_id}")
+        self.prompt_model = prompt_model
+
+    def prompt_id(self) -> str | None:
+        return self.prompt_model.id
+
+    def build_base_prompt(self) -> str:
+        """Returns a saved prompt.
+
+        Returns:
+            str: The prompt string.
+        """
+        return self.prompt_model.prompt
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.prompt_model.chain_of_thought_instructions
+
+
+class FineTunePromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a fine-tune prompt."""
+
+    def __init__(self, task: Task, nested_fine_tune_id: str):
+        super().__init__(task)
+
+        # IDs are in project_id::task_id::fine_tune_id format
+        self.full_fine_tune_id = nested_fine_tune_id
+        parts = nested_fine_tune_id.split("::")
+        if len(parts) != 3:
+            raise ValueError(
+                f"Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id', got: {nested_fine_tune_id}"
+            )
+        fine_tune_id = parts[2]
+
+        fine_tune_model = next(
+            (
+                fine_tune
+                for fine_tune in task.finetunes(readonly=True)
+                if fine_tune.id == fine_tune_id
+            ),
+            None,
+        )
+        if not fine_tune_model:
+            raise ValueError(f"Fine-tune ID not found: {fine_tune_id}")
+        self.fine_tune_model = fine_tune_model
+
+    def prompt_id(self) -> str | None:
+        return self.full_fine_tune_id
+
+    def build_base_prompt(self) -> str:
+        return self.fine_tune_model.system_message
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.fine_tune_model.thinking_instructions
+
+
+# TODO P2: we end up with 2 IDs for these: the keys here (ui_name) and the prompt_builder_name from the class
+# We end up maintaining this in _prompt_generators as well.
 prompt_builder_registry = {
     "simple_prompt_builder": SimplePromptBuilder,
     "multi_shot_prompt_builder": MultiShotPromptBuilder,
```
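`FineTunePromptBuilder` expects nested IDs of the form `project_id::task_id::fine_tune_id`. The same parse-and-validate step in isolation:

```python
def parse_nested_fine_tune_id(nested_id: str) -> tuple[str, str, str]:
    # Mirrors the validation in FineTunePromptBuilder.__init__: exactly
    # three '::'-separated parts; the last one is the fine-tune ID.
    parts = nested_id.split("::")
    if len(parts) != 3:
        raise ValueError(
            "Invalid fine-tune ID format. Expected "
            f"'project_id::task_id::fine_tune_id', got: {nested_id}"
        )
    project_id, task_id, fine_tune_id = parts
    return project_id, task_id, fine_tune_id


print(parse_nested_fine_tune_id("proj_1::task_2::ft_3"))  # ('proj_1', 'task_2', 'ft_3')
```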
```diff
@@ -256,7 +351,7 @@ prompt_builder_registry = {
 
 
 # Our UI has some names that are not the same as the class names, which also hint parameters.
-def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
+def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
     """Convert a name used in the UI to the corresponding prompt builder class.
 
     Args:
@@ -268,20 +363,31 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
     Raises:
         ValueError: If the UI name is not recognized.
     """
+
+    # Saved prompts are prefixed with "id::"
+    if ui_name.startswith("id::"):
+        prompt_id = ui_name[4:]
+        return SavedPromptBuilder(task, prompt_id)
+
+    # Fine-tune prompts are prefixed with "fine_tune_prompt::"
+    if ui_name.startswith("fine_tune_prompt::"):
+        fine_tune_id = ui_name[18:]
+        return FineTunePromptBuilder(task, fine_tune_id)
+
     match ui_name:
         case "basic":
-            return SimplePromptBuilder
+            return SimplePromptBuilder(task)
         case "few_shot":
-            return FewShotPromptBuilder
+            return FewShotPromptBuilder(task)
         case "many_shot":
-            return MultiShotPromptBuilder
+            return MultiShotPromptBuilder(task)
         case "repairs":
-            return RepairsPromptBuilder
+            return RepairsPromptBuilder(task)
         case "simple_chain_of_thought":
-            return SimpleChainOfThoughtPromptBuilder
+            return SimpleChainOfThoughtPromptBuilder(task)
         case "few_shot_chain_of_thought":
-            return FewShotChainOfThoughtPromptBuilder
+            return FewShotChainOfThoughtPromptBuilder(task)
         case "multi_shot_chain_of_thought":
-            return MultiShotChainOfThoughtPromptBuilder
+            return MultiShotChainOfThoughtPromptBuilder(task)
         case _:
             raise ValueError(f"Unknown prompt builder: {ui_name}")
```
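`prompt_builder_from_ui_name` now takes the task and returns a constructed builder instance, routing prefixed IDs before falling through to the `match` on plain names. A toy sketch of the dispatch order (string results stand in for the real builder instances):

```python
def route_ui_name(ui_name: str) -> str:
    # Prefixed IDs are checked first, exactly as in the diff.
    if ui_name.startswith("id::"):
        return f"saved prompt {ui_name[4:]}"       # -> SavedPromptBuilder(task, ...)
    if ui_name.startswith("fine_tune_prompt::"):
        return f"fine-tune prompt {ui_name[18:]}"  # -> FineTunePromptBuilder(task, ...)
    match ui_name:
        case "basic":
            return "SimplePromptBuilder"           # -> SimplePromptBuilder(task)
        case _:
            raise ValueError(f"Unknown prompt builder: {ui_name}")


print(route_ui_name("id::prompt_123"))
print(route_ui_name("fine_tune_prompt::proj::task::ft_9"))
print(route_ui_name("basic"))
```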
kiln_ai/adapters/provider_tools.py

```diff
@@ -1,20 +1,24 @@
 from dataclasses import dataclass
-from typing import Dict, List, NoReturn
+from typing import Dict, List
 
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
     ModelName,
     ModelProviderName,
+    StructuredOutputMode,
     built_in_models,
 )
+from kiln_ai.adapters.model_adapters.openai_compatible_config import (
+    OpenAICompatibleConfig,
+)
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.registry import project_from_id
-
-from kiln_ai.utils.config import Config
+from kiln_ai.utils.config import Config
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
 async def provider_enabled(provider_name: ModelProviderName) -> bool:
@@ -61,7 +65,7 @@ def check_provider_warnings(provider_name: ModelProviderName):
         raise ValueError(warning_check.message)
 
 
-async def builtin_model_from(
+def builtin_model_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider | None:
     """
```
```diff
@@ -102,7 +106,47 @@ async def builtin_model_from(
             return provider
 
 
-async def kiln_model_provider_from(
+def core_provider(model_id: str, provider_name: ModelProviderName) -> ModelProviderName:
+    """
+    Get the provider that should be run.
+
+    Some provider IDs are wrappers (fine-tunes, custom models). This maps these to runnable providers (openai, ollama, etc)
+    """
+
+    # Custom models map to the underlying provider
+    if provider_name is ModelProviderName.kiln_custom_registry:
+        provider_name, _ = parse_custom_model_id(model_id)
+        return provider_name
+
+    # Fine-tune provider maps to an underlying provider
+    if provider_name is ModelProviderName.kiln_fine_tune:
+        finetune = finetune_from_id(model_id)
+        if finetune.provider not in ModelProviderName.__members__:
+            raise ValueError(
+                f"Finetune {model_id} has no underlying provider {finetune.provider}"
+            )
+        return ModelProviderName(finetune.provider)
+
+    return provider_name
+
+
+def parse_custom_model_id(
+    model_id: str,
+) -> tuple[ModelProviderName, str]:
+    if "::" not in model_id:
+        raise ValueError(f"Invalid custom model ID: {model_id}")
+
+    # For custom registry, get the provider name and model name from the model id
+    provider_name = model_id.split("::", 1)[0]
+    model_name = model_id.split("::", 1)[1]
+
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+
+    return ModelProviderName(provider_name), model_name
+
+
+def kiln_model_provider_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider:
     if provider_name == ModelProviderName.kiln_fine_tune:
```
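`parse_custom_model_id` splits `provider::model` once and validates the provider against the enum's members; `core_provider` uses it to map wrapper providers onto runnable ones. A self-contained usage sketch (a hypothetical two-member enum stands in for `ModelProviderName`):

```python
from enum import Enum


class ProviderSketch(str, Enum):
    # Hypothetical subset of ModelProviderName.
    openai = "openai"
    ollama = "ollama"


def parse_custom_model_id(model_id: str) -> tuple[ProviderSketch, str]:
    if "::" not in model_id:
        raise ValueError(f"Invalid custom model ID: {model_id}")
    provider_name, model_name = model_id.split("::", 1)
    if provider_name not in ProviderSketch.__members__:
        raise ValueError(f"Invalid provider name: {provider_name}")
    return ProviderSketch(provider_name), model_name


print(parse_custom_model_id("openai::gpt-4o"))
# split("::", 1) keeps any later separators inside the model name:
print(parse_custom_model_id("ollama::llama3.1:8b-instruct"))
```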
```diff
@@ -111,14 +155,13 @@ async def kiln_model_provider_from(
     if provider_name == ModelProviderName.openai_compatible:
         return openai_compatible_provider_model(name)
 
-    built_in_model = await builtin_model_from(name, provider_name)
+    built_in_model = builtin_model_from(name, provider_name)
     if built_in_model:
         return built_in_model
 
     # For custom registry, get the provider name and model name from the model id
     if provider_name == ModelProviderName.kiln_custom_registry:
-        provider_name = name.split("::", 1)[0]
-        name = name.split("::", 1)[1]
+        provider_name, name = parse_custom_model_id(name)
 
     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
@@ -136,12 +179,9 @@ async def kiln_model_provider_from(
     )
 
 
-
-
-
-def openai_compatible_provider_model(
+def openai_compatible_config(
     model_id: str,
-) -> KilnModelProvider:
+) -> OpenAICompatibleConfig:
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
```
```diff
@@ -165,12 +205,21 @@ def openai_compatible_provider_model(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )
 
+    return OpenAICompatibleConfig(
+        api_key=api_key,
+        model_name=model_id,
+        provider_name=ModelProviderName.openai_compatible,
+        base_url=base_url,
+    )
+
+
+def openai_compatible_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
     return KilnModelProvider(
         name=ModelProviderName.openai_compatible,
         provider_options={
             "model": model_id,
-            "api_key": api_key,
-            "openai_api_base": base_url,
         },
         supports_structured_output=False,
         supports_data_gen=False,
@@ -178,9 +227,10 @@ def openai_compatible_provider_model(
     )
 
 
-
-
-
+finetune_cache: dict[str, Finetune] = {}
+
+
+def finetune_from_id(model_id: str) -> Finetune:
     if model_id in finetune_cache:
         return finetune_cache[model_id]
 
```
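`finetune_from_id` is now memoized in a module-level dict, so repeated lookups of the same fine-tune skip the reload. The pattern in isolation (`expensive_load` is a hypothetical stand-in for the real project/task traversal):

```python
_cache: dict[str, dict] = {}


def expensive_load(model_id: str) -> dict:
    print(f"loading {model_id} ...")  # imagine disk or network access here
    return {"id": model_id, "status": "completed"}


def finetune_from_id_sketch(model_id: str) -> dict:
    if model_id in _cache:
        return _cache[model_id]   # hit: no reload
    record = expensive_load(model_id)
    _cache[model_id] = record     # cache only after a successful load
    return record


finetune_from_id_sketch("proj::task::ft_1")  # prints "loading ..."
finetune_from_id_sketch("proj::task::ft_1")  # silent: served from the cache
```

Note the diff populates the cache only after the completed-status check passes, so failed lookups are retried rather than cached.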
```diff
@@ -202,6 +252,15 @@ def finetune_provider_model(
             f"Fine tune {fine_tune_id} not completed. Refresh it's status in the fine-tune tab."
         )
 
+    finetune_cache[model_id] = fine_tune
+    return fine_tune
+
+
+def finetune_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    fine_tune = finetune_from_id(model_id)
+
     provider = ModelProviderName[fine_tune.provider]
     model_provider = KilnModelProvider(
         name=provider,
@@ -210,18 +269,18 @@ def finetune_provider_model(
         },
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
+    if fine_tune.structured_output_mode is not None:
+        # If we know the model was trained with specific output mode, set it
+        model_provider.structured_output_mode = fine_tune.structured_output_mode
+    else:
+        # Some early adopters won't have structured_output_mode set on their fine-tunes.
+        # We know that OpenAI uses json_schema, and Fireworks (only other provider) use json_mode.
+        # This can be removed in the future
+        if provider == ModelProviderName.openai:
+            model_provider.structured_output_mode = StructuredOutputMode.json_schema
+        else:
+            model_provider.structured_output_mode = StructuredOutputMode.json_mode
+
     return model_provider
 
 
```
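When a fine-tune predates the `structured_output_mode` field, the provider falls back by vendor: `json_schema` for OpenAI, `json_mode` otherwise (Fireworks being the only other fine-tune provider at this point). The decision in isolation, with stand-in enum values:

```python
from enum import Enum


class OutputMode(str, Enum):
    # Stand-ins for kiln_ai's StructuredOutputMode values used in the diff.
    json_schema = "json_schema"
    json_mode = "json_mode"


def resolve_output_mode(stored: OutputMode | None, provider: str) -> OutputMode:
    # Trust the mode recorded at training time when present; otherwise
    # infer it from the provider, matching the diff's fallback.
    if stored is not None:
        return stored
    return OutputMode.json_schema if provider == "openai" else OutputMode.json_mode


print(resolve_output_mode(None, "openai"))                  # json_schema
print(resolve_output_mode(None, "fireworks"))               # json_mode
print(resolve_output_mode(OutputMode.json_mode, "openai"))  # recorded value wins
```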
```diff
@@ -274,7 +333,7 @@ def provider_name_from_id(id: str) -> str:
             return "OpenAI Compatible"
         case _:
             # triggers pyright warning if I miss a case
-            raise_exhaustive_error(enum_id)
+            raise_exhaustive_enum_error(enum_id)
 
     return "Unknown provider: " + id
 
@@ -316,16 +375,12 @@ def provider_options_for_custom_model(
             )
         case _:
             # triggers pyright warning if I miss a case
-            raise_exhaustive_error(enum_id)
+            raise_exhaustive_enum_error(enum_id)
 
     # Won't reach this, type checking will catch missed values
     return {"model": model_name}
 
 
-def raise_exhaustive_error(value: NoReturn) -> NoReturn:
-    raise ValueError(f"Unhandled enum value: {value}")
-
-
 @dataclass
 class ModelProviderWarning:
     required_config_keys: List[str]
```
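The ad-hoc local helper is replaced by the shared `kiln_ai.utils.exhaustive_error.raise_exhaustive_enum_error` (a new 6-line module in this release). Judging by the removed code, the trick is typing the parameter as `NoReturn`: if a `match` over an enum misses a case, the fall-through value is not `NoReturn`, and pyright flags the call at type-check time. A sketch of the pattern:

```python
from enum import Enum
from typing import NoReturn


def raise_exhaustive_enum_error(value: NoReturn) -> NoReturn:
    # Same shape as the removed local helper; the real module may differ.
    raise ValueError(f"Unhandled enum value: {value}")


class Color(Enum):
    RED = 1
    BLUE = 2


def describe(color: Color) -> str:
    match color:
        case Color.RED:
            return "warm"
        case Color.BLUE:
            return "cool"
        case _:
            # Add a Color member without a case above, and a type checker
            # reports this call before the ValueError can ever fire.
            raise_exhaustive_enum_error(color)


print(describe(Color.RED))
```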
kiln_ai/adapters/repair/repair_task.py

```diff
@@ -3,7 +3,11 @@ from typing import Type
 
 from pydantic import BaseModel, Field
 
-from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_registry
+from kiln_ai.adapters.prompt_builders import (
+    BasePromptBuilder,
+    SavedPromptBuilder,
+    prompt_builder_registry,
+)
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun
 
 
@@ -42,11 +46,18 @@ feedback describing what should be improved. Your job is to understand the evalu
 
     @classmethod
     def _original_prompt(cls, run: TaskRun, task: Task) -> str:
+        if run.output.source is None or run.output.source.properties is None:
+            raise ValueError("No source properties found")
+
+        # Try ID first, then builder name
+        prompt_id = run.output.source.properties.get("prompt_id", None)
+        if prompt_id is not None and isinstance(prompt_id, str):
+            static_prompt_builder = SavedPromptBuilder(task, prompt_id)
+            return static_prompt_builder.build_prompt(include_json_instructions=False)
+
         prompt_builder_class: Type[BasePromptBuilder] | None = None
-        prompt_builder_name = (
-            run.output.source.properties.get("prompt_builder_name", None)
-            if run.output.source
-            else None
+        prompt_builder_name = run.output.source.properties.get(
+            "prompt_builder_name", None
         )
         if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
             prompt_builder_class = prompt_builder_registry.get(
@@ -59,7 +70,7 @@ feedback describing what should be improved. Your job is to understand the evalu
             raise ValueError(
                 f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
             )
-        return prompt_builder.build_prompt()
+        return prompt_builder.build_prompt(include_json_instructions=False)
 
     @classmethod
     def build_repair_task_input(
```
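`_original_prompt` now fails fast when a run has no source properties, then resolves provenance in order: an exact saved-prompt ID first, the builder name second. The lookup order in isolation (plain strings stand in for the real builders):

```python
def original_prompt_source(properties: dict | None) -> str:
    if properties is None:
        raise ValueError("No source properties found")
    # Try ID first, then builder name -- same order as the diff.
    prompt_id = properties.get("prompt_id")
    if isinstance(prompt_id, str):
        return f"saved prompt {prompt_id}"    # SavedPromptBuilder path
    builder_name = properties.get("prompt_builder_name")
    if isinstance(builder_name, str):
        return f"rebuilt via {builder_name}"  # registry path
    raise ValueError("No prompt provenance recorded")


print(original_prompt_source({"prompt_id": "p1"}))
print(original_prompt_source({"prompt_builder_name": "simple_prompt_builder"}))
```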
kiln_ai/adapters/repair/test_repair_task.py

```diff
@@ -6,8 +6,8 @@ import pytest
 from pydantic import ValidationError
 
 from kiln_ai.adapters.adapter_registry import adapter_for_task
-from kiln_ai.adapters.base_adapter import RunOutput
-from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import RunOutput
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.repair.repair_task import (
     RepairTaskInput,
     RepairTaskRun,
@@ -223,7 +223,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     )
 
     adapter = adapter_for_task(
-        repair_task, model_name="llama_3_1_8b", provider="
+        repair_task, model_name="llama_3_1_8b", provider="ollama"
     )
 
     run = await adapter.invoke(repair_task_input.model_dump())
@@ -237,7 +237,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     assert run.output.source.properties == {
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "llama_3_1_8b",
-        "model_provider": "
+        "model_provider": "ollama",
         "prompt_builder_name": "simple_prompt_builder",
     }
     assert run.input_source.type == DataSourceType.human
```