kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +163 -39
- kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +270 -0
- kiln_ai/adapters/eval/g_eval.py +368 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +325 -0
- kiln_ai/adapters/eval/test_eval_runner.py +641 -0
- kiln_ai/adapters/eval/test_g_eval.py +498 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +758 -163
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
- kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
- kiln_ai/adapters/ollama_tools.py +3 -3
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/prompt_builders.py +80 -42
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/repair_task.py +9 -21
- kiln_ai/adapters/repair/test_repair_task.py +6 -6
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +26 -29
- kiln_ai/adapters/test_generate_docs.py +4 -4
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +47 -33
- kiln_ai/adapters/test_prompt_builders.py +91 -31
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/__init__.py +50 -952
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +60 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +7 -1
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +328 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +19 -11
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +32 -8
- kiln_ai/datamodel/test_datasource.py +22 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +9 -13
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_models.py +2 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +43 -1
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
- kiln_ai-0.13.0.dist-info/RECORD +103 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
- kiln_ai-0.11.1.dist-info/RECORD +0 -76
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/model_adapters/test_structured_output.py CHANGED
@@ -2,8 +2,6 @@ import json
 from pathlib import Path
 from typing import Dict

-import jsonschema
-import jsonschema.exceptions
 import pytest

 import kiln_ai.datamodel as datamodel
@@ -12,16 +10,13 @@ from kiln_ai.adapters.ml_model_list import (
     built_in_models,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import (
-    AdapterInfo,
     BaseAdapter,
     RunOutput,
 )
 from kiln_ai.adapters.ollama_tools import ollama_online
-from kiln_ai.adapters.prompt_builders import (
-    BasePromptBuilder,
-    SimpleChainOfThoughtPromptBuilder,
-)
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
+from kiln_ai.datamodel import PromptId
+from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema

@@ -39,9 +34,9 @@ async def test_structured_output_gpt_4o_mini(tmp_path):
     await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")


-@pytest.mark.parametrize("model_name", ["llama_3_1_8b"])
+@pytest.mark.parametrize("model_name", ["llama_3_1_8b", "gemma_2_2b"])
 @pytest.mark.ollama
-async def test_structured_output_ollama_llama(tmp_path, model_name):
+async def test_structured_output_ollama(tmp_path, model_name):
     if not await ollama_online():
         pytest.skip("Ollama API not running. Expect it running on localhost:11434")
     await run_structured_output_test(tmp_path, model_name, "ollama")
@@ -49,19 +44,21 @@ async def test_structured_output_ollama_llama(tmp_path, model_name):

 class MockAdapter(BaseAdapter):
     def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
-        super().__init__(
+        super().__init__(
+            run_config=RunConfig(
+                task=kiln_task,
+                model_name="phi_3_5",
+                model_provider_name="ollama",
+                prompt_id="simple_chain_of_thought_prompt_builder",
+            ),
+        )
         self.response = response

     async def _run(self, input: str) -> RunOutput:
         return RunOutput(output=self.response, intermediate_outputs=None)

-    def adapter_info(self) -> AdapterInfo:
-        return AdapterInfo(
-            adapter_name="mock_adapter",
-            model_name="mock_model",
-            model_provider="mock_provider",
-            prompt_builder_name="mock_prompt_builder",
-        )
+    def adapter_name(self) -> str:
+        return "mock_adapter"


 async def test_mock_unstructred_response(tmp_path):
@@ -69,7 +66,8 @@ async def test_mock_unstructred_response(tmp_path):

     # don't error on valid response
     adapter = MockAdapter(task, response={"setup": "asdf", "punchline": "asdf"})
-
+    run = await adapter.invoke("You are a mock, send me the response!")
+    answer = json.loads(run.output.output)
     assert answer["setup"] == "asdf"
     assert answer["punchline"] == "asdf"

@@ -79,9 +77,12 @@ async def test_mock_unstructred_response(tmp_path):
     answer = await adapter.invoke("You are a mock, send me the response!")

     adapter = MockAdapter(task, response="string instead of dict")
-    with pytest.raises(
+    with pytest.raises(
+        ValueError,
+        match="This task requires JSON output but the model didn't return valid JSON",
+    ):
         # Not a structed response so should error
-
+        run = await adapter.invoke("You are a mock, send me the response!")

     # Should error, expecting a string, not a dict
     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
@@ -146,7 +147,8 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     task = build_structured_output_test_task(tmp_path)
     a = adapter_for_task(task, model_name=model_name, provider=provider)
     try:
-
+        run = await a.invoke("Cows") # a joke about cows
+        parsed = json.loads(run.output.output)
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -165,6 +167,12 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     assert rating >= 0
     assert rating <= 10

+    # Check reasoning models
+    assert a._model_provider is not None
+    if a._model_provider.reasoning_capable:
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)
+

 def build_structured_input_test_task(tmp_path: Path):
     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
@@ -204,20 +212,27 @@ async def run_structured_input_task(
     task: datamodel.Task,
     model_name: str,
     provider: str,
-
+    prompt_id: PromptId | None = None,
 ):
     a = adapter_for_task(
-        task,
+        task,
+        model_name=model_name,
+        provider=provider,
+        prompt_id=prompt_id,
     )
     with pytest.raises(ValueError):
         # not structured input in dictionary
         await a.invoke("a=1, b=2, c=3")
-    with pytest.raises(
+    with pytest.raises(
+        ValueError,
+        match="This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema.",
+    ):
         # invalid structured input
         await a.invoke({"a": 1, "b": 2, "d": 3})

     try:
-
+        run = await a.invoke({"a": 2, "b": 2, "c": 2})
+        response = run.output.output
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -229,13 +244,20 @@ async def run_structured_input_task(
         assert "[[equilateral]]" in response
     else:
         assert response["is_equilateral"] is True
-
+
     expected_pb_name = "simple_prompt_builder"
-    if
-        expected_pb_name =
-    assert
-
-    assert
+    if prompt_id is not None:
+        expected_pb_name = prompt_id
+    assert a.run_config.prompt_id == expected_pb_name
+
+    assert a.run_config.model_name == model_name
+    assert a.run_config.model_provider_name == provider
+
+    # Check reasoning models
+    assert a._model_provider is not None
+    if a._model_provider.reasoning_capable:
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)


 @pytest.mark.paid
@@ -257,8 +279,9 @@ async def test_all_built_in_models_structured_input(
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_structured_input_cot_prompt_builder(tmp_path, model_name, provider_name):
     task = build_structured_input_test_task(tmp_path)
-
-
+    await run_structured_input_task(
+        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+    )


 @pytest.mark.paid
@@ -302,5 +325,6 @@ When asked for a final result, this is the format (for an equilateral example):
 """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-
-
+    await run_structured_input_task(
+        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+    )
kiln_ai/adapters/ollama_tools.py CHANGED
@@ -1,4 +1,3 @@
-import os
 from typing import Any, List

 import httpx
@@ -39,6 +38,7 @@ async def ollama_online() -> bool:

 class OllamaConnection(BaseModel):
     message: str
+    version: str | None = None
     supported_models: List[str]
     untested_models: List[str] = Field(default_factory=list)

@@ -50,7 +50,7 @@ class OllamaConnection(BaseModel):
 def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
     # Build a list of models we support for Ollama from the built-in model list
     supported_ollama_models = [
-        provider.
+        provider.model_id
         for model in built_in_models
         for provider in model.providers
         if provider.name == ModelProviderName.ollama
@@ -61,7 +61,7 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
             alias
             for model in built_in_models
             for provider in model.providers
-            for alias in provider.
+            for alias in provider.ollama_model_aliases or []
         ]
     )

kiln_ai/adapters/parsers/r1_parser.py CHANGED
@@ -20,21 +20,33 @@ class R1ThinkingParser(BaseParser):
         Raises:
             ValueError: If response format is invalid (missing tags, multiple tags, or no content after closing tag)
         """
+
+        # The upstream providers (litellm, openrouter, fireworks) all keep changing their response formats, sometimes adding reasoning parsing where it didn't previously exist.
+        # If they do it already, great just return. If not we parse it ourselves. Not ideal, but better than upstream changes breaking the app.
+        if (
+            original_output.intermediate_outputs is not None
+            and "reasoning" in original_output.intermediate_outputs
+        ):
+            return original_output
+
         # This parser only works for strings
         if not isinstance(original_output.output, str):
             raise ValueError("Response must be a string for R1 parser")

         # Strip whitespace and validate basic structure
         cleaned_response = original_output.output.strip()
-        if not cleaned_response.startswith(self.START_TAG):
-            raise ValueError("Response must start with <think> tag")

         # Find the thinking tags
-        think_start = cleaned_response.find(self.START_TAG)
         think_end = cleaned_response.find(self.END_TAG)
+        if think_end == -1:
+            raise ValueError("Missing </think> tag")

-
-
+        think_tag_start = cleaned_response.find(self.START_TAG)
+        if think_tag_start == -1:
+            # We allow no start <think>, thinking starts on first char. QwQ does this.
+            think_start = 0
+        else:
+            think_start = think_tag_start + len(self.START_TAG)

         # Check for multiple tags
         if (
@@ -44,9 +56,7 @@ class R1ThinkingParser(BaseParser):
             raise ValueError("Multiple thinking tags found")

         # Extract thinking content
-        thinking_content = cleaned_response[
-            think_start + len(self.START_TAG) : think_end
-        ].strip()
+        thinking_content = cleaned_response[think_start:think_end].strip()

         # Extract result (everything after </think>)
         result = cleaned_response[think_end + len(self.END_TAG) :].strip()
@@ -54,16 +64,11 @@ class R1ThinkingParser(BaseParser):
         if not result or len(result) == 0:
             raise ValueError("No content found after </think> tag")

-        # Parse JSON if needed
-        output = result
-        if self.structured_output:
-            output = parse_json_string(result)
-
         # Add thinking content to intermediate outputs if it exists
         intermediate_outputs = original_output.intermediate_outputs or {}
         intermediate_outputs["reasoning"] = thinking_content

         return RunOutput(
-            output=output,
+            output=result,
             intermediate_outputs=intermediate_outputs,
         )
kiln_ai/adapters/parsers/test_r1_parser.py CHANGED
@@ -19,6 +19,16 @@ def test_valid_response(parser):
     assert parsed.output == "This is the result"


+def test_already_parsed_response(parser):
+    response = RunOutput(
+        output="This is the result",
+        intermediate_outputs={"reasoning": "This is thinking content"},
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "This is thinking content"
+    assert parsed.output == "This is the result"
+
+
 def test_response_with_whitespace(parser):
     response = RunOutput(
         output="""
@@ -37,14 +47,16 @@ def test_response_with_whitespace(parser):


 def test_missing_start_tag(parser):
-
-
-
-
+    parsed = parser.parse_output(
+        RunOutput(output="Some content</think>result", intermediate_outputs=None)
+    )
+
+    assert parsed.intermediate_outputs["reasoning"] == "Some content"
+    assert parsed.output == "result"


 def test_missing_end_tag(parser):
-    with pytest.raises(ValueError, match="Missing
+    with pytest.raises(ValueError, match="Missing </think> tag"):
         parser.parse_output(
             RunOutput(output="<think>Some content", intermediate_outputs=None)
         )
kiln_ai/adapters/prompt_builders.py CHANGED
@@ -2,8 +2,8 @@ import json
 from abc import ABCMeta, abstractmethod
 from typing import Dict

-from kiln_ai.datamodel import Task, TaskRun
-from kiln_ai.utils.
+from kiln_ai.datamodel import PromptGenerators, PromptId, Task, TaskRun
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


 class BasePromptBuilder(metaclass=ABCMeta):
@@ -53,17 +53,6 @@ class BasePromptBuilder(metaclass=ABCMeta):
         """
         pass

-    @classmethod
-    def prompt_builder_name(cls) -> str:
-        """Returns the name of the prompt builder, to be used for persisting into the datastore.
-
-        Default implementation gets the name of the prompt builder in snake case. If you change the class name, you should override this so prior saved data is compatible.
-
-        Returns:
-            str: The prompt builder name in snake_case format.
-        """
-        return snake_case(cls.__name__)
-
     def build_user_message(self, input: Dict | str) -> str:
         """Build a user message from the input.

@@ -300,6 +289,57 @@ class SavedPromptBuilder(BasePromptBuilder):
         return self.prompt_model.chain_of_thought_instructions


+class TaskRunConfigPromptBuilder(BasePromptBuilder):
+    """A prompt builder that looks up a static prompt in a task run config."""
+
+    def __init__(self, task: Task, run_config_prompt_id: str):
+        parts = run_config_prompt_id.split("::")
+        if len(parts) != 4:
+            raise ValueError(
+                f"Invalid task run config prompt ID: {run_config_prompt_id}. Expected format: 'task_run_config::[project_id]::[task_id]::[run_config_id]'."
+            )
+
+        task_id = parts[2]
+        if task_id != task.id:
+            raise ValueError(
+                f"Task run config prompt ID: {run_config_prompt_id}. Task ID mismatch. Expected: {task.id}, got: {task_id}."
+            )
+
+        run_config_id = parts[3]
+        run_config = next(
+            (
+                run_config
+                for run_config in task.run_configs(readonly=True)
+                if run_config.id == run_config_id
+            ),
+            None,
+        )
+        if not run_config:
+            raise ValueError(
+                f"Task run config ID not found: {run_config_id} for prompt id {run_config_prompt_id}"
+            )
+        if run_config.prompt is None:
+            raise ValueError(
+                f"Task run config ID {run_config_id} does not have a stored prompt. Used as prompt id {run_config_prompt_id}"
+            )
+
+        # Load the prompt from the model
+        self.prompt = run_config.prompt.prompt
+        self.cot_prompt = run_config.prompt.chain_of_thought_instructions
+        self.id = run_config_prompt_id
+
+        super().__init__(task)
+
+    def prompt_id(self) -> str | None:
+        return self.id
+
+    def build_base_prompt(self) -> str:
+        return self.prompt
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return self.cot_prompt
+
+
 class FineTunePromptBuilder(BasePromptBuilder):
     """A prompt builder that looks up a fine-tune prompt."""

@@ -337,25 +377,12 @@ class FineTunePromptBuilder(BasePromptBuilder):
         return self.fine_tune_model.thinking_instructions


-# TODO P2: we end up with 2 IDs for these: the keys here (ui_name) and the prompt_builder_name from the class
-# We end up maintaining this in _prompt_generators as well.
-prompt_builder_registry = {
-    "simple_prompt_builder": SimplePromptBuilder,
-    "multi_shot_prompt_builder": MultiShotPromptBuilder,
-    "few_shot_prompt_builder": FewShotPromptBuilder,
-    "repairs_prompt_builder": RepairsPromptBuilder,
-    "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
-    "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
-    "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
-}
-
-
 # Our UI has some names that are not the same as the class names, which also hint parameters.
-def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
+def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder:
     """Convert a name used in the UI to the corresponding prompt builder class.

     Args:
-
+        prompt_id (PromptId): The prompt ID.

     Returns:
         type[BasePromptBuilder]: The corresponding prompt builder class.
@@ -365,29 +392,40 @@ def prompt_builder_from_ui_name(ui_name: str, task: Task) -> BasePromptBuilder:
     """

     # Saved prompts are prefixed with "id::"
-    if
-        prompt_id =
+    if prompt_id.startswith("id::"):
+        prompt_id = prompt_id[4:]
         return SavedPromptBuilder(task, prompt_id)

+    # Task run config prompts are prefixed with "task_run_config::"
+    # task_run_config::[project_id]::[task_id]::[run_config_id]
+    if prompt_id.startswith("task_run_config::"):
+        return TaskRunConfigPromptBuilder(task, prompt_id)
+
     # Fine-tune prompts are prefixed with "fine_tune_prompt::"
-    if
-
-    return FineTunePromptBuilder(task,
+    if prompt_id.startswith("fine_tune_prompt::"):
+        prompt_id = prompt_id[18:]
+        return FineTunePromptBuilder(task, prompt_id)
+
+    # Check if the prompt_id matches any enum value
+    if prompt_id not in [member.value for member in PromptGenerators]:
+        raise ValueError(f"Unknown prompt generator: {prompt_id}")
+    typed_prompt_generator = PromptGenerators(prompt_id)

-    match
-        case
+    match typed_prompt_generator:
+        case PromptGenerators.SIMPLE:
             return SimplePromptBuilder(task)
-        case
+        case PromptGenerators.FEW_SHOT:
             return FewShotPromptBuilder(task)
-        case
+        case PromptGenerators.MULTI_SHOT:
             return MultiShotPromptBuilder(task)
-        case
+        case PromptGenerators.REPAIRS:
             return RepairsPromptBuilder(task)
-        case
+        case PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT:
             return SimpleChainOfThoughtPromptBuilder(task)
-        case
+        case PromptGenerators.FEW_SHOT_CHAIN_OF_THOUGHT:
             return FewShotChainOfThoughtPromptBuilder(task)
-        case
+        case PromptGenerators.MULTI_SHOT_CHAIN_OF_THOUGHT:
             return MultiShotChainOfThoughtPromptBuilder(task)
         case _:
-
+            # Type checking will find missing cases
+            raise_exhaustive_enum_error(typed_prompt_generator)
kiln_ai/adapters/provider_tools.py CHANGED
@@ -9,8 +9,8 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.
-
+from kiln_ai.adapters.model_adapters.litellm_config import (
+    LiteLlmConfig,
 )
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
@@ -153,7 +153,7 @@ def kiln_model_provider_from(
         return finetune_provider_model(name)

     if provider_name == ModelProviderName.openai_compatible:
-        return
+        return lite_llm_provider_model(name)

     built_in_model = builtin_model_from(name, provider_name)
     if built_in_model:
@@ -175,13 +175,13 @@ def kiln_model_provider_from(
         supports_structured_output=False,
         supports_data_gen=False,
         untested_model=True,
-
+        model_id=name,
     )


-def openai_compatible_config(
+def lite_llm_config(
     model_id: str,
-) ->
+) -> LiteLlmConfig:
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -205,22 +205,23 @@ def openai_compatible_config(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )

-    return
-
+    return LiteLlmConfig(
+        # OpenAI compatible, with a custom base URL
         model_name=model_id,
         provider_name=ModelProviderName.openai_compatible,
         base_url=base_url,
+        additional_body_options={
+            "api_key": api_key,
+        },
     )


-def
+def lite_llm_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
     return KilnModelProvider(
         name=ModelProviderName.openai_compatible,
-
-        "model": model_id,
-        },
+        model_id=model_id,
         supports_structured_output=False,
         supports_data_gen=False,
         untested_model=True,
@@ -264,9 +265,7 @@ def finetune_provider_model(
     provider = ModelProviderName[fine_tune.provider]
     model_provider = KilnModelProvider(
         name=provider,
-
-        "model": fine_tune.fine_tune_model_id,
-        },
+        model_id=fine_tune.fine_tune_model_id,
     )

     if fine_tune.structured_output_mode is not None:
@@ -331,6 +330,18 @@ def provider_name_from_id(id: str) -> str:
             return "Custom Models"
         case ModelProviderName.openai_compatible:
             return "OpenAI Compatible"
+        case ModelProviderName.azure_openai:
+            return "Azure OpenAI"
+        case ModelProviderName.gemini_api:
+            return "Gemini API"
+        case ModelProviderName.anthropic:
+            return "Anthropic"
+        case ModelProviderName.huggingface:
+            return "Hugging Face"
+        case ModelProviderName.vertex:
+            return "Google Vertex AI"
+        case ModelProviderName.together_ai:
+            return "Together AI"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_enum_error(enum_id)
@@ -338,49 +349,6 @@ def provider_name_from_id(id: str) -> str:
     return "Unknown provider: " + id


-def provider_options_for_custom_model(
-    model_name: str, provider_name: str
-) -> Dict[str, str]:
-    """
-    Generated model provider options for a custom model. Each has their own format/options.
-    """
-
-    if provider_name not in ModelProviderName.__members__:
-        raise ValueError(f"Invalid provider name: {provider_name}")
-
-    enum_id = ModelProviderName(provider_name)
-    match enum_id:
-        case ModelProviderName.amazon_bedrock:
-            # us-west-2 is the only region consistently supported by Bedrock
-            return {"model": model_name, "region_name": "us-west-2"}
-        case (
-            ModelProviderName.openai
-            | ModelProviderName.ollama
-            | ModelProviderName.fireworks_ai
-            | ModelProviderName.openrouter
-            | ModelProviderName.groq
-        ):
-            return {"model": model_name}
-        case ModelProviderName.kiln_custom_registry:
-            raise ValueError(
-                "Custom models from registry should be parsed into provider/model before calling this."
-            )
-        case ModelProviderName.kiln_fine_tune:
-            raise ValueError(
-                "Fine tuned models should populate provider options via another path"
-            )
-        case ModelProviderName.openai_compatible:
-            raise ValueError(
-                "OpenAI compatible models should populate provider options via another path"
-            )
-        case _:
-            # triggers pyright warning if I miss a case
-            raise_exhaustive_enum_error(enum_id)
-
-    # Won't reach this, type checking will catch missed values
-    return {"model": model_name}
-
-
 @dataclass
 class ModelProviderWarning:
     required_config_keys: List[str]
@@ -408,4 +376,28 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         required_config_keys=["fireworks_api_key", "fireworks_account_id"],
         message="Attempted to use Fireworks without an API key and account ID set. \nGet your API key from https://fireworks.ai/account/api-keys and your account ID from https://fireworks.ai/account/profile",
     ),
+    ModelProviderName.anthropic: ModelProviderWarning(
+        required_config_keys=["anthropic_api_key"],
+        message="Attempted to use Anthropic without an API key set. \nGet your API key from https://console.anthropic.com/settings/keys",
+    ),
+    ModelProviderName.gemini_api: ModelProviderWarning(
+        required_config_keys=["gemini_api_key"],
+        message="Attempted to use Gemini without an API key set. \nGet your API key from https://aistudio.google.com/app/apikey",
+    ),
+    ModelProviderName.azure_openai: ModelProviderWarning(
+        required_config_keys=["azure_openai_api_key", "azure_openai_endpoint"],
+        message="Attempted to use Azure OpenAI without an API key and endpoint set. Configure these in settings.",
+    ),
+    ModelProviderName.huggingface: ModelProviderWarning(
+        required_config_keys=["huggingface_api_key"],
+        message="Attempted to use Hugging Face without an API key set. \nGet your API key from https://huggingface.co/settings/tokens",
+    ),
+    ModelProviderName.vertex: ModelProviderWarning(
+        required_config_keys=["vertex_project_id"],
+        message="Attempted to use Vertex without a project ID set. \nGet your project ID from the Vertex AI console.",
+    ),
+    ModelProviderName.together_ai: ModelProviderWarning(
+        required_config_keys=["together_api_key"],
+        message="Attempted to use Together without an API key set. \nGet your API key from https://together.ai/settings/keys",
+    ),
 }