kiln-ai 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/eval/base_eval.py +7 -2
- kiln_ai/adapters/eval/eval_runner.py +5 -64
- kiln_ai/adapters/eval/g_eval.py +3 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +6 -3
- kiln_ai/adapters/fine_tune/dataset_formatter.py +128 -38
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +2 -1
- kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +267 -10
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
- kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
- kiln_ai/adapters/ml_model_list.py +817 -62
- kiln_ai/adapters/model_adapters/base_adapter.py +33 -10
- kiln_ai/adapters/model_adapters/litellm_adapter.py +51 -12
- kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -1
- kiln_ai/adapters/provider_tools.py +25 -1
- kiln_ai/adapters/repair/test_repair_task.py +3 -2
- kiln_ai/adapters/test_prompt_builders.py +24 -3
- kiln_ai/adapters/test_provider_tools.py +86 -1
- kiln_ai/datamodel/__init__.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +14 -0
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +1 -0
- kiln_ai/datamodel/json_schema.py +24 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task_output.py +10 -6
- kiln_ai/datamodel/task_run.py +68 -12
- kiln_ai/datamodel/test_basemodel.py +3 -7
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -0
- kiln_ai/datamodel/test_example_models.py +158 -3
- kiln_ai/datamodel/test_json_schema.py +22 -3
- kiln_ai/datamodel/test_model_perf.py +3 -2
- kiln_ai/datamodel/test_models.py +50 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/dataset_import.py +80 -18
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_dataset_import.py +242 -10
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +3 -2
- kiln_ai-0.16.0.dist-info/RECORD +108 -0
- kiln_ai/adapters/test_generate_docs.py +0 -69
- kiln_ai-0.14.0.dist-info/RECORD +0 -103
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/parsers/test_request_formatters.py ADDED

@@ -0,0 +1,76 @@
+import pytest
+
+from kiln_ai.adapters.ml_model_list import ModelFormatterID
+from kiln_ai.adapters.parsers.request_formatters import (
+    Qwen3StyleNoThinkFormatter,
+    request_formatter_from_id,
+)
+
+
+@pytest.fixture
+def qwen_formatter():
+    return Qwen3StyleNoThinkFormatter()
+
+
+def test_qwen_formatter_string_input(qwen_formatter):
+    input_text = "Hello world"
+    formatted = qwen_formatter.format_input(input_text)
+    assert formatted == "Hello world\n\n/no_think"
+
+
+def test_qwen_formatter_dict_input(qwen_formatter):
+    input_dict = {"key": "value", "nested": {"inner": "data"}}
+    formatted = qwen_formatter.format_input(input_dict)
+    expected = """{
+  "key": "value",
+  "nested": {
+    "inner": "data"
+  }
+}
+
+/no_think"""
+    assert formatted == expected
+
+
+def test_qwen_formatter_empty_input(qwen_formatter):
+    # Test empty string
+    assert qwen_formatter.format_input("") == "\n\n/no_think"
+
+    # Test empty dict
+    assert qwen_formatter.format_input({}) == "{}\n\n/no_think"
+
+
+def test_qwen_formatter_special_characters(qwen_formatter):
+    input_text = "Special chars: !@#$%^&*()_+思"
+    formatted = qwen_formatter.format_input(input_text)
+    assert formatted == "Special chars: !@#$%^&*()_+思\n\n/no_think"
+
+
+def test_qwen_formatter_multiline_string(qwen_formatter):
+    input_text = """Line 1
+Line 2
+Line 3"""
+    formatted = qwen_formatter.format_input(input_text)
+    assert (
+        formatted
+        == """Line 1
+Line 2
+Line 3
+
+/no_think"""
+    )
+
+
+def test_request_formatter_factory():
+    # Test valid formatter ID
+    formatter = request_formatter_from_id(ModelFormatterID.qwen3_style_no_think)
+    assert isinstance(formatter, Qwen3StyleNoThinkFormatter)
+
+    # Test that the formatter works
+    assert formatter.format_input("test") == "test\n\n/no_think"
+
+
+def test_request_formatter_factory_invalid_id():
+    # Test with an invalid enum value by using a string that doesn't exist in the enum
+    with pytest.raises(ValueError, match="Unhandled enum value"):
+        request_formatter_from_id("invalid_formatter_id") # type: ignore
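These tests pin down the formatter's contract: strings pass through unchanged, dicts are serialized as indented JSON, and "\n\n/no_think" is appended either way. A minimal sketch consistent with those assertions follows; the class name here is hypothetical and the real implementation lives in kiln_ai/adapters/parsers/request_formatters.py and may differ in details.

import json
from typing import Dict


class Qwen3StyleNoThinkFormatterSketch:
    """Hypothetical reconstruction of the behaviour the tests above assert."""

    def format_input(self, input: Dict | str) -> str:
        # Dicts are rendered as 2-space-indented JSON; strings pass through as-is.
        body = json.dumps(input, indent=2) if isinstance(input, dict) else input
        # A blank line plus the /no_think directive tells Qwen3-style models to skip thinking.
        return f"{body}\n\n/no_think"


assert Qwen3StyleNoThinkFormatterSketch().format_input("Hello world") == "Hello world\n\n/no_think"
assert Qwen3StyleNoThinkFormatterSketch().format_input({}) == "{}\n\n/no_think"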
kiln_ai/adapters/prompt_builders.py CHANGED

@@ -101,7 +101,6 @@ class SimplePromptBuilder(BasePromptBuilder):
         """
         base_prompt = self.task.instruction
 
-        # TODO: this is just a quick version. Formatting and best practices TBD
         if len(self.task.requirements) > 0:
             base_prompt += (
                 "\n\nYour response should respect the following requirements:\n"

@@ -113,6 +112,18 @@ class SimplePromptBuilder(BasePromptBuilder):
         return base_prompt
 
 
+class ShortPromptBuilder(BasePromptBuilder):
+    """A prompt builder that includes a the base prompt but excludes the requirements."""
+
+    def build_base_prompt(self) -> str:
+        """Build a short prompt with just the base prompt, no requirements.
+
+        Returns:
+            str: The constructed prompt string.
+        """
+        return self.task.instruction
+
+
 class MultiShotPromptBuilder(BasePromptBuilder):
     """A prompt builder that includes multiple examples in the prompt."""
 

@@ -414,6 +425,8 @@ def prompt_builder_from_id(prompt_id: PromptId, task: Task) -> BasePromptBuilder
     match typed_prompt_generator:
         case PromptGenerators.SIMPLE:
             return SimplePromptBuilder(task)
+        case PromptGenerators.SHORT:
+            return ShortPromptBuilder(task)
         case PromptGenerators.FEW_SHOT:
             return FewShotPromptBuilder(task)
         case PromptGenerators.MULTI_SHOT:
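For context, the new builder is reachable through the existing generator registry. A hypothetical usage sketch, assuming a Task instance named task already exists and that the generator value "short_prompt_builder" (added to PromptGenerators below) can be passed directly as the PromptId:

from kiln_ai.adapters.prompt_builders import prompt_builder_from_id

builder = prompt_builder_from_id("short_prompt_builder", task)  # PromptGenerators.SHORT
prompt = builder.build_prompt(include_json_instructions=False)
# With ShortPromptBuilder, `prompt` is exactly task.instruction: no requirements section is appended.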
kiln_ai/adapters/provider_tools.py CHANGED

@@ -5,6 +5,7 @@ from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
     ModelName,
+    ModelParserID,
     ModelProviderName,
     StructuredOutputMode,
     built_in_models,

@@ -15,7 +16,7 @@ from kiln_ai.adapters.model_adapters.litellm_config import (
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
-from kiln_ai.datamodel import Finetune, Task
+from kiln_ai.datamodel import Finetune, FinetuneDataStrategy, Task
 from kiln_ai.datamodel.registry import project_from_id
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

@@ -257,6 +258,14 @@ def finetune_from_id(model_id: str) -> Finetune:
     return fine_tune
 
 
+def parser_from_data_strategy(
+    data_strategy: FinetuneDataStrategy,
+) -> ModelParserID | None:
+    if data_strategy == FinetuneDataStrategy.final_and_intermediate_r1_compatible:
+        return ModelParserID.r1_thinking
+    return None
+
+
 def finetune_provider_model(
     model_id: str,
 ) -> KilnModelProvider:

@@ -266,8 +275,23 @@ def finetune_provider_model(
     model_provider = KilnModelProvider(
         name=provider,
         model_id=fine_tune.fine_tune_model_id,
+        parser=parser_from_data_strategy(fine_tune.data_strategy),
+        reasoning_capable=(
+            fine_tune.data_strategy
+            in [
+                FinetuneDataStrategy.final_and_intermediate,
+                FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            ]
+        ),
     )
 
+    if provider == ModelProviderName.vertex and fine_tune.fine_tune_model_id:
+        # Vertex AI trick: use the model_id "openai/endpoint_id". OpenAI calls the openai compatible API, which supports endpoint.
+        # Context: vertex has at least 3 APIS: vertex, openai compatible, and gemini. LiteLLM tries to infer which to use. This works
+        # on current LiteLLM version. Could also set base_model to gemini to tell it which to use, but same result.
+        endpoint_id = fine_tune.fine_tune_model_id.split("/")[-1]
+        model_provider.model_id = f"openai/{endpoint_id}"
+
     if fine_tune.structured_output_mode is not None:
         # If we know the model was trained with specific output mode, set it
         model_provider.structured_output_mode = fine_tune.structured_output_mode
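Two behaviours introduced here are worth spelling out. First, parser_from_data_strategy assigns the R1 thinking parser only for the R1-compatible strategy, while reasoning_capable is set for both thinking strategies. Second, Vertex AI fine-tune endpoints are rewritten so LiteLLM routes through the OpenAI-compatible API. A small illustration of that rewrite, using the example endpoint ID from the Vertex test further below:

# Example endpoint ID taken from test_finetune_provider_model_vertex_ai below.
fine_tune_model_id = "projects/123/locations/us-central1/endpoints/456"
endpoint_id = fine_tune_model_id.split("/")[-1]
assert f"openai/{endpoint_id}" == "openai/456"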
kiln_ai/adapters/repair/test_repair_task.py CHANGED

@@ -218,8 +218,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     }
 
     with patch.object(LiteLlmAdapter, "_run", new_callable=AsyncMock) as mock_run:
-        mock_run.return_value =
-            output=mocked_output, intermediate_outputs=None
+        mock_run.return_value = (
+            RunOutput(output=mocked_output, intermediate_outputs=None),
+            None,
         )
 
         adapter = adapter_for_task(
kiln_ai/adapters/test_prompt_builders.py CHANGED

@@ -3,7 +3,7 @@ import logging
 
 import pytest
 
-from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
 from kiln_ai.adapters.model_adapters.test_structured_output import (
     build_structured_output_test_task,
 )

@@ -15,6 +15,7 @@ from kiln_ai.adapters.prompt_builders import (
     MultiShotPromptBuilder,
     RepairsPromptBuilder,
     SavedPromptBuilder,
+    ShortPromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
     SimplePromptBuilder,
     TaskRunConfigPromptBuilder,

@@ -33,6 +34,7 @@ from kiln_ai.datamodel import (
     TaskOutput,
     TaskOutputRating,
     TaskRun,
+    Usage,
 )
 from kiln_ai.datamodel.task import RunConfigProperties, TaskRunConfig
 

@@ -58,9 +60,28 @@ def test_simple_prompt_builder(tmp_path):
     assert input not in prompt
 
 
+def test_short_prompt_builder(tmp_path):
+    task = build_test_task(tmp_path)
+    builder = ShortPromptBuilder(task=task)
+    prompt = builder.build_prompt(include_json_instructions=False)
+
+    # Should only include the instruction, not requirements
+    assert task.instruction == prompt
+    assert task.requirements[0].instruction not in prompt
+    assert task.requirements[1].instruction not in prompt
+    assert task.requirements[2].instruction not in prompt
+
+    # Should handle JSON instructions correctly
+    prompt_with_json = builder.build_prompt(include_json_instructions=True)
+    assert task.instruction in prompt_with_json
+    if task.output_schema():
+        assert "# Format Instructions" in prompt_with_json
+        assert task.output_schema() in prompt_with_json
+
+
 class MockAdapter(BaseAdapter):
-    def _run(self, input: str) ->
-        return "mock response"
+    async def _run(self, input: str) -> tuple[RunOutput, Usage | None]:
+        return RunOutput(output="mock response", intermediate_outputs=None), None
 
     def adapter_name(self) -> str:
         return "mock_adapter"
kiln_ai/adapters/test_provider_tools.py CHANGED

@@ -5,6 +5,7 @@ import pytest
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     ModelName,
+    ModelParserID,
     ModelProviderName,
 )
 from kiln_ai.adapters.ollama_tools import OllamaConnection

@@ -24,7 +25,12 @@ from kiln_ai.adapters.provider_tools import (
     provider_name_from_id,
     provider_warnings,
 )
-from kiln_ai.datamodel import
+from kiln_ai.datamodel import (
+    Finetune,
+    FinetuneDataStrategy,
+    StructuredOutputMode,
+    Task,
+)
 
 
 @pytest.fixture(autouse=True)

@@ -65,6 +71,33 @@ def mock_finetune():
         finetune.provider = ModelProviderName.openai
         finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
         finetune.structured_output_mode = StructuredOutputMode.json_schema
+        finetune.data_strategy = FinetuneDataStrategy.final_only
+        mock.return_value = finetune
+        yield mock
+
+
+@pytest.fixture
+def mock_finetune_final_and_intermediate():
+    with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
+        finetune = Mock(spec=Finetune)
+        finetune.provider = ModelProviderName.openai
+        finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
+        finetune.structured_output_mode = StructuredOutputMode.json_schema
+        finetune.data_strategy = FinetuneDataStrategy.final_and_intermediate
+        mock.return_value = finetune
+        yield mock
+
+
+@pytest.fixture
+def mock_finetune_r1_compatible():
+    with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
+        finetune = Mock(spec=Finetune)
+        finetune.provider = ModelProviderName.ollama
+        finetune.fine_tune_model_id = "ft:deepseek-r1:671b:custom:model-123"
+        finetune.structured_output_mode = StructuredOutputMode.json_schema
+        finetune.data_strategy = (
+            FinetuneDataStrategy.final_and_intermediate_r1_compatible
+        )
         mock.return_value = finetune
         yield mock
 

@@ -426,6 +459,38 @@ def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune)
     assert provider.name == ModelProviderName.openai
     assert provider.model_id == "ft:gpt-3.5-turbo:custom:model-123"
     assert provider.structured_output_mode == StructuredOutputMode.json_schema
+    assert provider.reasoning_capable is False
+    assert provider.parser == None
+
+
+def test_finetune_provider_model_success_final_and_intermediate(
+    mock_project, mock_task, mock_finetune_final_and_intermediate
+):
+    """Test successful creation of a fine-tuned model provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    provider = finetune_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.model_id == "ft:gpt-3.5-turbo:custom:model-123"
+    assert provider.structured_output_mode == StructuredOutputMode.json_schema
+    assert provider.reasoning_capable is True
+    assert provider.parser == None
+
+
+def test_finetune_provider_model_success_r1_compatible(
+    mock_project, mock_task, mock_finetune_r1_compatible
+):
+    """Test successful creation of a fine-tuned model provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    provider = finetune_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.ollama
+    assert provider.model_id == "ft:deepseek-r1:671b:custom:model-123"
+    assert provider.structured_output_mode == StructuredOutputMode.json_schema
+    assert provider.reasoning_capable is True
+    assert provider.parser == ModelParserID.r1_thinking
 
 
 def test_finetune_provider_model_invalid_id():

@@ -515,6 +580,7 @@ def test_finetune_provider_model_structured_mode(
     finetune.provider = provider_name
     finetune.fine_tune_model_id = "fireworks-model-123"
     finetune.structured_output_mode = structured_output_mode
+    finetune.data_strategy = FinetuneDataStrategy.final_only
     mock_finetune.return_value = finetune
 
     provider = finetune_provider_model("project-123::task-456::finetune-789")

@@ -522,6 +588,8 @@ def test_finetune_provider_model_structured_mode(
     assert provider.name == provider_name
     assert provider.model_id == "fireworks-model-123"
     assert provider.structured_output_mode == expected_mode
+    assert provider.reasoning_capable is False
+    assert provider.parser == None
 
 
 def test_openai_compatible_provider_config(mock_shared_config):

@@ -791,3 +859,20 @@ def test_finetune_from_id_cache_hit(mock_project, mock_task, mock_finetune):
     mock_project.assert_not_called()
     mock_task.assert_not_called()
     mock_finetune.assert_not_called()
+
+
+def test_finetune_provider_model_vertex_ai(mock_project, mock_task, mock_finetune):
+    """Test creation of provider for Vertex AI with endpoint ID transformation"""
+    finetune = Mock(spec=Finetune)
+    finetune.provider = ModelProviderName.vertex
+    finetune.fine_tune_model_id = "projects/123/locations/us-central1/endpoints/456"
+    finetune.structured_output_mode = StructuredOutputMode.json_mode
+    finetune.data_strategy = FinetuneDataStrategy.final_only
+    mock_finetune.return_value = finetune
+
+    provider = finetune_provider_model("project-123::task-456::finetune-789")
+
+    assert provider.name == ModelProviderName.vertex
+    # Verify the model_id is transformed into openai/endpoint_id format
+    assert provider.model_id == "openai/456"
+    assert provider.structured_output_mode == StructuredOutputMode.json_mode
kiln_ai/datamodel/__init__.py CHANGED

@@ -44,6 +44,7 @@ from kiln_ai.datamodel.task_output import (
 )
 from kiln_ai.datamodel.task_run import (
     TaskRun,
+    Usage,
 )
 
 __all__ = [

@@ -74,4 +75,5 @@ __all__ = [
     "PromptId",
     "PromptGenerators",
     "prompt_generator_values",
+    "Usage",
 ]
kiln_ai/datamodel/datamodel_enums.py CHANGED

@@ -56,5 +56,19 @@ class FineTuneStatusType(str, Enum):
 
 
 class FinetuneDataStrategy(str, Enum):
+    """Strategy for what data to include when fine-tuning a model."""
+
+    # Only train on the final response, ignoring any intermediate steps or chain of thought
     final_only = "final_only"
+
+    # Train on both the final response and any intermediate steps/chain of thought
     final_and_intermediate = "final_and_intermediate"
+
+    # Train using R1-style thinking format, which includes the reasoning in <think> tags in the message
+    final_and_intermediate_r1_compatible = "final_and_intermediate_r1_compatible"
+
+
+THINKING_DATA_STRATEGIES: list[FinetuneDataStrategy] = [
+    FinetuneDataStrategy.final_and_intermediate,
+    FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+]
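A quick membership check against the new constant; this mirrors the inline list used in provider_tools.py above to decide whether a fine-tune is reasoning-capable, though that code does not itself import THINKING_DATA_STRATEGIES:

from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES, FinetuneDataStrategy

assert FinetuneDataStrategy.final_only not in THINKING_DATA_STRATEGIES
assert FinetuneDataStrategy.final_and_intermediate in THINKING_DATA_STRATEGIES
assert FinetuneDataStrategy.final_and_intermediate_r1_compatible in THINKING_DATA_STRATEGIES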
kiln_ai/datamodel/dataset_filters.py CHANGED

@@ -1,5 +1,6 @@
+import re
 from enum import Enum
-from typing import Annotated, Protocol
+from typing import Annotated, ClassVar, List, Protocol
 
 from pydantic import AfterValidator
 

@@ -59,6 +60,65 @@ class TagFilter:
         return self.tag in task_run.tags
 
 
+class MultiDatasetFilter:
+    """
+    A filter that combines multiple filters using AND logic.
+    The filters are specified in a query string format after 'multi_filter::'
+    Example: multi_filter::high_rating&thinking_model&tag::tag_name
+
+    Ampersands in filter IDs can be escaped with a backslash.
+    """
+
+    PREFIX: ClassVar[str] = "multi_filter::"
+    ESCAPED_AMPERSAND: ClassVar[str] = r"\&"
+    UNESCAPED_AMPERSAND: ClassVar[str] = "&"
+
+    @classmethod
+    def parse_filter_string(cls, filter_string: str) -> List[str]:
+        """
+        Parse a filter string into individual filter IDs, handling escaped ampersands.
+        """
+        if not filter_string.startswith(cls.PREFIX):
+            raise ValueError(f"Filter string must start with {cls.PREFIX}")
+
+        # Remove the prefix
+        content = filter_string[len(cls.PREFIX) :]
+        if not content:
+            raise ValueError("No filters specified after prefix")
+
+        # Split on unescaped ampersands
+        # This regex matches & that are not preceded by a backslash
+        parts = re.split(r"(?<!\\)&", content)
+
+        # Unescape ampersands in each part
+        filter_ids = [
+            part.replace(cls.ESCAPED_AMPERSAND, cls.UNESCAPED_AMPERSAND)
+            for part in parts
+        ]
+
+        # Validate each filter ID using the existing validation
+        for fid in filter_ids:
+            _check_dataset_filter_id(fid)
+
+        return filter_ids
+
+    @classmethod
+    def is_valid_filter_string(cls, filter_string: str) -> bool:
+        """Check if a filter string is valid."""
+        try:
+            cls.parse_filter_string(filter_string)
+            return True
+        except ValueError:
+            return False
+
+    def __init__(self, filter_id: str):
+        filter_ids = MultiDatasetFilter.parse_filter_string(filter_id)
+        self.filters = [dataset_filter_from_id(fid) for fid in filter_ids]
+
+    def __call__(self, task_run: TaskRun) -> bool:
+        return all(f(task_run) for f in self.filters)
+
+
 class StaticDatasetFilters(str, Enum):
     """Dataset filter names."""
 

@@ -98,6 +158,11 @@ def _check_dataset_filter_id(id: str) -> str:
     if id.startswith("tag::") and len(id) > 5:
         return id
 
+    if id.startswith(MultiDatasetFilter.PREFIX):
+        if not MultiDatasetFilter.is_valid_filter_string(id):
+            raise ValueError(f"Invalid multi-filter string: {id}")
+        return id
+
     raise ValueError(f"Invalid dataset filter ID: {id}")
 
 

@@ -108,6 +173,9 @@ def dataset_filter_from_id(id: DatasetFilterId) -> DatasetFilter:
     if id.startswith("tag::") and len(id) > 5:
         return TagFilter(id[5:])
 
+    if id.startswith(MultiDatasetFilter.PREFIX):
+        return MultiDatasetFilter(id)
+
     if id in static_dataset_filters:
         return static_dataset_filters[id]
 
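A hypothetical usage sketch of the new multi-filter syntax. The tag name and the task_runs list are placeholders, and "high_rating" is assumed to be a valid static filter ID, as in the class docstring above:

from kiln_ai.datamodel.dataset_filters import MultiDatasetFilter, dataset_filter_from_id

task_runs = []  # assume a list of TaskRun objects loaded elsewhere

# AND-combine the built-in high_rating filter with a tag filter.
combined = dataset_filter_from_id("multi_filter::high_rating&tag::fine_tune_data")
selected = [run for run in task_runs if combined(run)]

# A literal ampersand inside a filter ID is escaped with a backslash.
assert MultiDatasetFilter.parse_filter_string(r"multi_filter::tag::a\&b") == ["tag::a&b"]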
kiln_ai/datamodel/dataset_split.py CHANGED

@@ -45,6 +45,10 @@ Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
     DatasetSplitDefinition(name="train", percentage=0.8),
     DatasetSplitDefinition(name="test", percentage=0.2),
 ]
+Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.8),
+    DatasetSplitDefinition(name="val", percentage=0.2),
+]
 Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
     DatasetSplitDefinition(name="train", percentage=0.6),
     DatasetSplitDefinition(name="test", percentage=0.2),
|
kiln_ai/datamodel/eval.py
CHANGED
|
@@ -263,6 +263,10 @@ class Eval(KilnParentedModel, KilnParentModel, parent_of={"configs": EvalConfig}
|
|
|
263
263
|
default=None,
|
|
264
264
|
description="The id of the current config to use for this eval. This can be changed over time to run the same eval with different configs.",
|
|
265
265
|
)
|
|
266
|
+
current_run_config_id: ID_TYPE = Field(
|
|
267
|
+
default=None,
|
|
268
|
+
description="The id of the a run config which was selected as the best run config for this eval. The run config must belong to the parent Task.",
|
|
269
|
+
)
|
|
266
270
|
eval_set_filter_id: DatasetFilterId = Field(
|
|
267
271
|
description="The id of the dataset filter which defines which dataset items are included when running this eval. Should be mutually exclusive with eval_configs_filter_id."
|
|
268
272
|
)
|
|
@@ -272,6 +276,10 @@ class Eval(KilnParentedModel, KilnParentModel, parent_of={"configs": EvalConfig}
|
|
|
272
276
|
output_scores: List[EvalOutputScore] = Field(
|
|
273
277
|
description="The scores this evaluator should produce."
|
|
274
278
|
)
|
|
279
|
+
favourite: bool = Field(
|
|
280
|
+
default=False,
|
|
281
|
+
description="Whether this eval is a favourite of the user. Rendered as a star icon in the UI.",
|
|
282
|
+
)
|
|
275
283
|
|
|
276
284
|
# Workaround to return typed parent without importing Task
|
|
277
285
|
def parent_task(self) -> Union["Task", None]:
|
kiln_ai/datamodel/finetune.py CHANGED

kiln_ai/datamodel/json_schema.py CHANGED

@@ -41,16 +41,33 @@ def validate_schema(instance: Dict, schema_str: str) -> None:
 
     Raises:
         jsonschema.exceptions.ValidationError: If validation fails
-
+    """
+    schema = schema_from_json_str(schema_str)
+    v = jsonschema.Draft202012Validator(schema)
+    v.validate(instance)
+
+
+def validate_schema_with_value_error(
+    instance: Dict, schema_str: str, error_prefix: str | None = None
+) -> None:
+    """Validate a dictionary against a JSON schema and raise a ValueError if the schema is invalid.
+
+    Args:
+        instance: Dictionary to validate
+        schema_str: JSON schema string to validate against
+        error_prefix: Error message prefix to include in the ValueError
+
+    Raises:
+        ValueError: If the instance does not match the schema
     """
     try:
-
-        v = jsonschema.Draft202012Validator(schema)
-        v.validate(instance)
+        validate_schema(instance, schema_str)
     except jsonschema.exceptions.ValidationError as e:
-
-
-
+        msg = f"The error from the schema check was: {e.message}. The JSON was: \n```json\n{instance}\n```"
+        if error_prefix:
+            msg = f"{error_prefix} {msg}"
+
+        raise ValueError(msg) from e
 
 
 def schema_from_json_str(v: str) -> Dict:
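A small usage sketch of the new helper with a hypothetical schema; on failure the raised ValueError carries the optional prefix, the jsonschema message, and the offending JSON:

import json

from kiln_ai.datamodel.json_schema import validate_schema_with_value_error

schema_str = json.dumps(
    {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }
)

validate_schema_with_value_error({"answer": "42"}, schema_str)  # valid: returns None

try:
    validate_schema_with_value_error({}, schema_str, error_prefix="Model output failed validation.")
except ValueError as e:
    print(e)  # "Model output failed validation. The error from the schema check was: ..."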
kiln_ai/datamodel/prompt_id.py CHANGED

@@ -13,6 +13,7 @@ class PromptGenerators(str, Enum):
     SIMPLE_CHAIN_OF_THOUGHT = "simple_chain_of_thought_prompt_builder"
     FEW_SHOT_CHAIN_OF_THOUGHT = "few_shot_chain_of_thought_prompt_builder"
     MULTI_SHOT_CHAIN_OF_THOUGHT = "multi_shot_chain_of_thought_prompt_builder"
+    SHORT = "short_prompt_builder"
 
 
 prompt_generator_values = [pg.value for pg in PromptGenerators]
kiln_ai/datamodel/task_output.py CHANGED

@@ -9,7 +9,7 @@ from typing_extensions import Self
 
 from kiln_ai.datamodel.basemodel import ID_TYPE, KilnBaseModel
 from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
-from kiln_ai.datamodel.json_schema import
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.strict_mode import strict_mode
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 

@@ -64,7 +64,7 @@ class TaskOutputRating(KilnBaseModel):
     )
     requirement_ratings: Dict[ID_TYPE, RequirementRating] = Field(
         default={},
-        description="The ratings of the requirements of the task.",
+        description="The ratings of the requirements of the task. The ID can be either a task_requirement_id or a named rating for an eval_output_score name (in format 'named::<name>').",
     )
 
     # Previously we stored rating values as a dict of floats, but now we store them as RequirementRating objects.

@@ -308,11 +308,15 @@ class TaskOutput(KilnBaseModel):
         # validate output
         if task.output_json_schema is not None:
             try:
-
-            except json.JSONDecodeError:
+                output_parsed = json.loads(self.output)
+            except json.JSONDecodeError as e:
                 raise ValueError("Output is not a valid JSON object")
-
-
+
+            validate_schema_with_value_error(
+                output_parsed,
+                task.output_json_schema,
+                "This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information.",
+            )
         return self
 
     @model_validator(mode="after")