kiln-ai 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/eval/base_eval.py +7 -2
- kiln_ai/adapters/eval/eval_runner.py +5 -64
- kiln_ai/adapters/eval/g_eval.py +3 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +6 -3
- kiln_ai/adapters/fine_tune/dataset_formatter.py +128 -38
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +2 -1
- kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +267 -10
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
- kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
- kiln_ai/adapters/ml_model_list.py +817 -62
- kiln_ai/adapters/model_adapters/base_adapter.py +33 -10
- kiln_ai/adapters/model_adapters/litellm_adapter.py +51 -12
- kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
- kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
- kiln_ai/adapters/parsers/base_parser.py +0 -3
- kiln_ai/adapters/parsers/parser_registry.py +5 -3
- kiln_ai/adapters/parsers/r1_parser.py +17 -2
- kiln_ai/adapters/parsers/request_formatters.py +40 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
- kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
- kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
- kiln_ai/adapters/prompt_builders.py +14 -1
- kiln_ai/adapters/provider_tools.py +25 -1
- kiln_ai/adapters/repair/test_repair_task.py +3 -2
- kiln_ai/adapters/test_prompt_builders.py +24 -3
- kiln_ai/adapters/test_provider_tools.py +86 -1
- kiln_ai/datamodel/__init__.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +14 -0
- kiln_ai/datamodel/dataset_filters.py +69 -1
- kiln_ai/datamodel/dataset_split.py +4 -0
- kiln_ai/datamodel/eval.py +8 -0
- kiln_ai/datamodel/finetune.py +1 -0
- kiln_ai/datamodel/json_schema.py +24 -7
- kiln_ai/datamodel/prompt_id.py +1 -0
- kiln_ai/datamodel/task_output.py +10 -6
- kiln_ai/datamodel/task_run.py +68 -12
- kiln_ai/datamodel/test_basemodel.py +3 -7
- kiln_ai/datamodel/test_dataset_filters.py +82 -0
- kiln_ai/datamodel/test_dataset_split.py +2 -0
- kiln_ai/datamodel/test_example_models.py +158 -3
- kiln_ai/datamodel/test_json_schema.py +22 -3
- kiln_ai/datamodel/test_model_perf.py +3 -2
- kiln_ai/datamodel/test_models.py +50 -2
- kiln_ai/utils/async_job_runner.py +106 -0
- kiln_ai/utils/dataset_import.py +80 -18
- kiln_ai/utils/test_async_job_runner.py +199 -0
- kiln_ai/utils/test_dataset_import.py +242 -10
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +3 -2
- kiln_ai-0.16.0.dist-info/RECORD +108 -0
- kiln_ai/adapters/test_generate_docs.py +0 -69
- kiln_ai-0.14.0.dist-info/RECORD +0 -103
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/task_run.py
CHANGED
@@ -3,11 +3,11 @@ from typing import TYPE_CHECKING, Dict, List, Union

 import jsonschema
 import jsonschema.exceptions
-from pydantic import Field, ValidationInfo, model_validator
+from pydantic import BaseModel, Field, ValidationInfo, model_validator
 from typing_extensions import Self

 from kiln_ai.datamodel.basemodel import KilnParentedModel
-from kiln_ai.datamodel.json_schema import
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.strict_mode import strict_mode
 from kiln_ai.datamodel.task_output import DataSource, TaskOutput

@@ -15,6 +15,29 @@ if TYPE_CHECKING:
     from kiln_ai.datamodel.task import Task


+class Usage(BaseModel):
+    input_tokens: int | None = Field(
+        default=None,
+        description="The number of input tokens used in the task run.",
+        ge=0,
+    )
+    output_tokens: int | None = Field(
+        default=None,
+        description="The number of output tokens used in the task run.",
+        ge=0,
+    )
+    total_tokens: int | None = Field(
+        default=None,
+        description="The total number of tokens used in the task run.",
+        ge=0,
+    )
+    cost: float | None = Field(
+        default=None,
+        description="The cost of the task run in US dollars, saved at runtime (prices can change over time).",
+        ge=0,
+    )
+
+
 class TaskRun(KilnParentedModel):
     """
     Represents a single execution of a Task.
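Together with the `usage` field added to `TaskRun` in the next hunk, this gives every run an optional usage record. A minimal sketch of attaching it to a run (the token counts and cost below are invented for illustration; in practice adapters fill these in at runtime):

    from kiln_ai.datamodel import TaskOutput, TaskRun, Usage

    # Illustrative values only -- real numbers come from the model provider.
    run = TaskRun(
        input="What is 2 + 2?",
        output=TaskOutput(output="4"),
        usage=Usage(input_tokens=12, output_tokens=3, total_tokens=15, cost=0.00004),
    )
    assert run.usage is not None and run.usage.total_tokens == 15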
@@ -47,17 +70,26 @@ class TaskRun(KilnParentedModel):
         default=[],
         description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
     )
+    usage: Usage | None = Field(
+        default=None,
+        description="Usage information for the task run. This includes the number of input tokens, output tokens, and total tokens used.",
+    )
+
+    def thinking_training_data(self) -> str | None:
+        """
+        Get the thinking training data from the task run.
+        """
+        if self.intermediate_outputs is None:
+            return None
+        return self.intermediate_outputs.get(
+            "reasoning"
+        ) or self.intermediate_outputs.get("chain_of_thought")

     def has_thinking_training_data(self) -> bool:
         """
         Does this run have thinking data that we can use to train a thinking model?
         """
-
-            return False
-        return (
-            "chain_of_thought" in self.intermediate_outputs
-            or "reasoning" in self.intermediate_outputs
-        )
+        return self.thinking_training_data() is not None

     # Workaround to return typed parent without importing Task
     def parent_task(self) -> Union["Task", None]:
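The new `thinking_training_data` helper centralizes where reasoning traces are read from, preferring the `reasoning` key over `chain_of_thought`, and `has_thinking_training_data` now simply checks that it returned something. A small sketch of the behaviour, mirroring the tests added to `test_models.py` later in this diff:

    from kiln_ai.datamodel import TaskOutput, TaskRun

    run = TaskRun(
        input="test input",
        output=TaskOutput(output="test output"),
        intermediate_outputs={"chain_of_thought": "step by step...", "reasoning": "r1 trace"},
    )
    # "reasoning" wins when both keys are present.
    assert run.thinking_training_data() == "r1 trace"
    assert run.has_thinking_training_data()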
@@ -87,14 +119,19 @@ class TaskRun(KilnParentedModel):
             # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
             return self

-        # validate
+        # validate input
         if task.input_json_schema is not None:
             try:
-
+                input_parsed = json.loads(self.input)
             except json.JSONDecodeError:
                 raise ValueError("Input is not a valid JSON object")
-
-
+
+            validate_schema_with_value_error(
+                input_parsed,
+                task.input_json_schema,
+                "Input does not match task input schema.",
+            )
+
         self._last_validated_input = self.input
         return self

@@ -131,6 +168,24 @@ class TaskRun(KilnParentedModel):
             raise ValueError(
                 "Repaired output rating must be None. Repaired outputs are assumed to have a perfect rating, as they have been fixed."
             )
+
+        task = self.parent_task()
+        if (
+            task is not None
+            and self.repaired_output.output is not None
+            and task.output_json_schema is not None
+        ):
+            try:
+                output_parsed = json.loads(self.repaired_output.output)
+            except json.JSONDecodeError:
+                raise ValueError("Repaired output is not a valid JSON object")
+
+            validate_schema_with_value_error(
+                output_parsed,
+                task.output_json_schema,
+                "Repaired output does not match task output schema.",
+            )
+
         if self.repair_instructions is None and self.repaired_output is not None:
             raise ValueError(
                 "Repair instructions are required if providing a repaired output."
@@ -139,6 +194,7 @@ class TaskRun(KilnParentedModel):
             raise ValueError(
                 "A repaired output is required if providing repair instructions."
             )
+
         return self

     @model_validator(mode="after")

kiln_ai/adapters/model_adapters/test_base_adapter.py
CHANGED
@@ -483,7 +483,7 @@ class MockAdapter(BaseAdapter):
     """Implementation of BaseAdapter for testing"""

     async def _run(self, input):
-        return RunOutput(output="test output", intermediate_outputs=None)
+        return RunOutput(output="test output", intermediate_outputs=None), None

     def adapter_name(self) -> str:
         return "test"
@@ -510,6 +510,7 @@ async def test_invoke_parsing_flow(adapter):
     # Mock dependencies
     mock_provider = MagicMock()
     mock_provider.parser = "test_parser"
+    mock_provider.formatter = None
     mock_provider.reasoning_capable = False

     mock_parser = MagicMock()
@@ -517,13 +518,11 @@ async def test_invoke_parsing_flow(adapter):
         output="parsed test output", intermediate_outputs={"key": "value"}
     )

-    mock_parser_class = MagicMock(return_value=mock_parser)
-
     with (
         patch.object(adapter, "model_provider", return_value=mock_provider),
         patch(
             "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
-            return_value=
+            return_value=mock_parser,
         ),
         patch("kiln_ai.adapters.model_adapters.base_adapter.Config") as mock_config,
     ):
@@ -534,9 +533,6 @@ async def test_invoke_parsing_flow(adapter):
         # Execute
         result = await adapter.invoke("test input")

-        # Verify parser was created correctly
-        mock_parser_class.assert_called_once_with(structured_output=False)
-
         # Verify parsing occurred
         mock_parser.parse_output.assert_called_once()
         parsed_args = mock_parser.parse_output.call_args[1]

kiln_ai/datamodel/test_dataset_filters.py
CHANGED
@@ -1,3 +1,5 @@
+from unittest.mock import Mock
+
 import pytest
 from pydantic import BaseModel

@@ -5,12 +7,14 @@ from kiln_ai.datamodel.dataset_filters import (
     AllDatasetFilter,
     DatasetFilterId,
     HighRatingDatasetFilter,
+    MultiDatasetFilter,
     StaticDatasetFilters,
     TagFilter,
     ThinkingModelDatasetFilter,
     ThinkingModelHighRatedFilter,
     dataset_filter_from_id,
 )
+from kiln_ai.datamodel.task_run import TaskRun

 # Note: Many more filter tests in test_dataset_split.py

@@ -69,3 +73,81 @@ def test_tag_filter(tag, expected_error, expected_tag):
     filter = dataset_filter_from_id(tag)
     assert isinstance(filter, TagFilter)
     assert filter.tag == expected_tag
+
+
+class TestMultiDatasetFilter:
+    @pytest.mark.parametrize(
+        "filter_string,expected_filters",
+        [
+            ("multi_filter::high_rating", ["high_rating"]),
+            (
+                "multi_filter::high_rating&thinking_model",
+                ["high_rating", "thinking_model"],
+            ),
+            ("multi_filter::tag::test&high_rating", ["tag::test", "high_rating"]),
+            (
+                "multi_filter::high_rating&tag::tag\\&name",
+                ["high_rating", "tag::tag&name"],
+            ),
+        ],
+    )
+    def test_valid_filter_string_parsing(self, filter_string, expected_filters):
+        """Test that valid filter strings are parsed correctly."""
+        assert MultiDatasetFilter.parse_filter_string(filter_string) == expected_filters
+        assert MultiDatasetFilter.is_valid_filter_string(filter_string)
+
+    @pytest.mark.parametrize(
+        "filter_string,expected_error",
+        [
+            (
+                "not_multi_filter::high_rating",
+                "Filter string must start with multi_filter::",
+            ),
+            ("multi_filter::", "No filters specified after prefix"),
+            ("multi_filter::high_rating&", "Invalid dataset filter ID:"),
+            ("multi_filter::invalid_filter", "Invalid dataset filter ID:"),
+        ],
+    )
+    def test_invalid_filter_string_handling(self, filter_string, expected_error):
+        """Test that invalid filter strings raise appropriate errors."""
+        with pytest.raises(ValueError, match=expected_error):
+            MultiDatasetFilter.parse_filter_string(filter_string)
+        assert not MultiDatasetFilter.is_valid_filter_string(filter_string)
+
+    def test_filter_combination_logic(self):
+        """Test that multiple filters are combined with AND logic."""
+        # Create a mock task run
+        task_run = Mock(spec=TaskRun)
+        task_run.output = Mock()
+        task_run.output.rating = Mock()
+        task_run.output.rating.is_high_quality.return_value = True
+        task_run.tags = ["test_tag"]
+        task_run.has_thinking_training_data.return_value = True
+        task_run.repaired_output = None
+
+        # Test combining high_rating and tag filters
+        filter_id = "multi_filter::high_rating&tag::test_tag"
+        multi_filter = dataset_filter_from_id(filter_id)
+        assert multi_filter(task_run)
+
+        # Test that it fails if one filter fails
+        task_run.tags = ["wrong_tag"]
+        assert not multi_filter(task_run)
+        task_run.tags = ["test_tag"]
+        assert multi_filter(task_run)
+        task_run.output.rating.is_high_quality.return_value = False
+        assert not multi_filter(task_run)
+
+        # Verify the mock was called as expected
+        task_run.output.rating.is_high_quality.assert_called()
+
+    def test_filter_creation_from_id(self):
+        """Test that multi filters can be created via dataset_filter_from_id."""
+        filter_id = "multi_filter::high_rating&thinking_model"
+        filter = dataset_filter_from_id(filter_id)
+        assert isinstance(filter, MultiDatasetFilter)
+        assert len(filter.filters) == 2
+        assert any(isinstance(f, type(HighRatingDatasetFilter)) for f in filter.filters)
+        assert any(
+            isinstance(f, type(ThinkingModelDatasetFilter)) for f in filter.filters
+        )
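The tests above also document the ID format for the new filter: the `multi_filter::` prefix followed by `&`-separated filter IDs, with `\&` escaping a literal ampersand inside a tag name. A short sketch of composing one through the same `dataset_filter_from_id` entry point used by the other filters (the tag name and the `task_runs` list are placeholders for illustration):

    from kiln_ai.datamodel.dataset_filters import dataset_filter_from_id

    # AND of the high_rating filter and a tag filter; "golden" is just an example tag.
    combined = dataset_filter_from_id("multi_filter::high_rating&tag::golden")

    # Like the static filters, the result is callable on a TaskRun.
    kept = [run for run in task_runs if combined(run)]  # task_runs: your own list of TaskRun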

kiln_ai/datamodel/test_dataset_split.py
CHANGED
@@ -17,6 +17,7 @@ from kiln_ai.datamodel.dataset_split import (
     AllSplitDefinition,
     Train60Test20Val20SplitDefinition,
     Train80Test20SplitDefinition,
+    Train80Val20SplitDefinition,
 )
 from kiln_ai.datamodel.test_dataset_filters import (
     AllDatasetFilter,
@@ -174,6 +175,7 @@ def test_high_rating_dataset_filter(sample_task_runs):
     [
         (Train80Test20SplitDefinition, {"train": 8, "test": 2}),
         (AllSplitDefinition, {"all": 10}),
+        (Train80Val20SplitDefinition, {"train": 8, "val": 2}),
         (Train60Test20Val20SplitDefinition, {"train": 6, "test": 2, "val": 2}),
         (
             [

kiln_ai/datamodel/test_example_models.py
CHANGED
@@ -16,6 +16,7 @@ from kiln_ai.datamodel import (
     TaskOutputRatingType,
     TaskRequirement,
     TaskRun,
+    Usage,
 )


@@ -358,6 +359,9 @@ def test_task_output_schema_validation(tmp_path):
         task_output.save_to_file()


+_input_schema_match = "Input does not match task input schema"
+
+
 def test_task_input_schema_validation(tmp_path):
     # Create a project and task hierarchy
     project = Project(name="Test Project", path=(tmp_path / "test_project"))
@@ -395,18 +399,18 @@ def test_task_input_schema_validation(tmp_path):
     valid_task_output.save_to_file()

     # Changing to invalid input
-    with pytest.raises(ValueError, match=
+    with pytest.raises(ValueError, match=_input_schema_match):
         valid_task_output.input = '{"name": "John Doe", "age": "thirty"}'
         valid_task_output.save_to_file()

     # loading from file, then changing to invalid input
     loaded_task_output = TaskRun.load_from_file(valid_task_output.path)
-    with pytest.raises(ValueError, match=
+    with pytest.raises(ValueError, match=_input_schema_match):
         loaded_task_output.input = '{"name": "John Doe", "age": "thirty"}'
         loaded_task_output.save_to_file()

     # Invalid case: input does not match task input schema
-    with pytest.raises(ValueError, match=
+    with pytest.raises(ValueError, match=_input_schema_match):
         task_output = TaskRun(
             input='{"name": "John Doe", "age": "thirty"}',
             input_source=DataSource(
@@ -642,3 +646,154 @@ def test_task_run_validate_repaired_output():
     )

     assert "Repaired output rating must be None" in str(exc_info.value)
+
+
+def test_task_run_validate_repaired_output_structured(tmp_path):
+    # Create a project, task, and example hierarchy
+    project = Project(name="Test Project", path=(tmp_path / "test_project"))
+    project.save_to_file()
+    task = Task(
+        name="Test Task",
+        instruction="test instruction",
+        parent=project,
+        output_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+                "required": ["name", "age"],
+            }
+        ),
+    )
+    task.save_to_file()
+
+    # test valid repaired output schema
+    task_run = TaskRun(
+        parent=task,
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        output=TaskOutput(
+            output='{"name": "John Doe", "age": 30}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+        repair_instructions="Fix the output",
+        repaired_output=TaskOutput(
+            output='{"name": "John Doe", "age": 30}',
+            source=DataSource(
+                type=DataSourceType.human, properties={"created_by": "john_doe"}
+            ),
+        ),
+    )
+
+    assert task_run.repaired_output is not None
+    assert task_run.repaired_output.rating is None
+
+    # test invalid JSON
+    with pytest.raises(ValueError):
+        TaskRun(
+            parent=task,
+            input="test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            output=TaskOutput(
+                output='{"name": "John Doe", "age": 30}',
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+            repair_instructions="Fix the output",
+            repaired_output=TaskOutput(
+                output='{"name": "John Doe", "age": 30',  # missing closing brace
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+        )
+
+    # test invalid repaired output schema
+    with pytest.raises(ValueError):
+        TaskRun(
+            parent=task,
+            input="test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            output=TaskOutput(
+                output='{"name": "John Doe", "age": 30}',
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+            repair_instructions="Fix the output",
+            repaired_output=TaskOutput(
+                output='{"name": "John Doe", "age": "thirty"}',  # invalid schema
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+        )
+
+
+@pytest.mark.parametrize(
+    "input_tokens,output_tokens,total_tokens,cost,should_raise",
+    [
+        # Valid cases
+        (100, 50, 150, 0.002, False),  # All fields
+        (None, None, None, None, False),  # All None (defaults)
+        # Invalid cases
+        (-100, 50, 150, 0.002, True),  # Negative input_tokens
+        (100, -50, 150, 0.002, True),  # Negative output_tokens
+        (100, 50, -150, 0.002, True),  # Negative total_tokens
+        (100, 50, 150, -0.002, True),  # Negative cost
+    ],
+)
+def test_usage_model(input_tokens, output_tokens, total_tokens, cost, should_raise):
+    """Test the Usage model with various input combinations."""
+    if should_raise:
+        with pytest.raises(ValidationError):
+            Usage(
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                total_tokens=total_tokens,
+                cost=cost,
+            )
+    else:
+        usage = Usage(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=total_tokens,
+            cost=cost,
+        )
+        assert usage.input_tokens == input_tokens
+        assert usage.output_tokens == output_tokens
+        assert usage.total_tokens == total_tokens
+        assert usage.cost == cost
+
+
+def test_usage_model_in_task_run(valid_task_run):
+    """Test that Usage can be properly set in a TaskRun."""
+    usage = Usage(
+        input_tokens=100,
+        output_tokens=50,
+        total_tokens=150,
+        cost=0.002,
+    )
+    task_run = valid_task_run.model_copy(deep=True)
+    task_run.usage = usage
+    assert task_run.usage == usage
+    assert task_run.usage.input_tokens == 100
+    assert task_run.usage.output_tokens == 50
+    assert task_run.usage.total_tokens == 150
+    assert task_run.usage.cost == 0.002

kiln_ai/datamodel/test_json_schema.py
CHANGED
@@ -1,3 +1,4 @@
+import jsonschema
 import pytest
 from pydantic import BaseModel

@@ -6,6 +7,7 @@ from kiln_ai.datamodel.json_schema import (
     schema_from_json_str,
     string_to_json_key,
     validate_schema,
+    validate_schema_with_value_error,
 )


@@ -71,15 +73,32 @@ def test_validate_schema_content():
     o = {"setup": "asdf", "punchline": "asdf", "rating": 1}
     validate_schema(o, json_joke_schema)
     o = {"setup": "asdf"}
-    with pytest.raises(
+    with pytest.raises(jsonschema.exceptions.ValidationError):
         validate_schema(0, json_joke_schema)
     o = {"setup": "asdf", "punchline": "asdf"}
     validate_schema(o, json_joke_schema)
     o = {"setup": "asdf", "punchline": "asdf", "rating": "1"}
-    with pytest.raises(
+    with pytest.raises(jsonschema.exceptions.ValidationError):
         validate_schema(o, json_joke_schema)


+def test_validate_schema_content_with_value_error():
+    o = {"setup": "asdf", "punchline": "asdf", "rating": 1}
+    validate_schema_with_value_error(o, json_joke_schema, "PREFIX")
+    o = {"setup": "asdf"}
+    with pytest.raises(
+        ValueError, match="PREFIX The error from the schema check was: "
+    ):
+        validate_schema_with_value_error(0, json_joke_schema, "PREFIX")
+    o = {"setup": "asdf", "punchline": "asdf"}
+    validate_schema_with_value_error(o, json_joke_schema, "PREFIX")
+    o = {"setup": "asdf", "punchline": "asdf", "rating": "1"}
+    with pytest.raises(
+        ValueError, match="PREFIX The error from the schema check was: "
+    ):
+        validate_schema_with_value_error(o, json_joke_schema, "PREFIX")
+
+
 json_triangle_schema = """{
     "type": "object",
     "properties": {
@@ -122,7 +141,7 @@ def test_triangle_schema():
     assert schema["properties"]["c"]["type"] == "integer"
     assert schema["required"] == ["a", "b", "c"]
     validate_schema({"a": 1, "b": 2, "c": 3}, json_triangle_schema)
-    with pytest.raises(
+    with pytest.raises(jsonschema.exceptions.ValidationError):
         validate_schema({"a": 1, "b": 2, "c": "3"}, json_triangle_schema)

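As the new test shows, `validate_schema_with_value_error` checks an object against a JSON schema (passed as a JSON string) and reports failures as a `ValueError` whose message starts with a caller-supplied prefix followed by "The error from the schema check was: ...". A minimal sketch of calling it directly (the schema and prefix below are invented for illustration):

    import json

    from kiln_ai.datamodel.json_schema import validate_schema_with_value_error

    schema = json.dumps(
        {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
    )

    # Passes silently when the object matches the schema.
    validate_schema_with_value_error({"name": "Ada"}, schema, "Demo prefix.")

    # On a mismatch, the schema error is re-raised as a ValueError with the prefix.
    try:
        validate_schema_with_value_error({"name": 1}, schema, "Demo prefix.")
    except ValueError as err:
        print(err)  # "Demo prefix. The error from the schema check was: ..."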

kiln_ai/datamodel/test_model_perf.py
CHANGED
@@ -119,7 +119,8 @@ def test_benchmark_load_from_file(benchmark, task_run):
     avg_time_per_iteration = total_time / iterations
     ops_per_second = 1.0 / avg_time_per_iteration

-    # I get 8k ops per second on my MBP. Lower value here for CI.
+    # I get 8k ops per second on my MBP. Lower value here for CI and parallel testing.
     # Prior to optimization was 290 ops per second.
-
+    print(f"Ops per second: {ops_per_second:.6f}")
+    if ops_per_second < 500:
         pytest.fail(f"Ops per second: {ops_per_second:.6f}, expected more than 1k ops")

kiln_ai/datamodel/test_models.py
CHANGED
@@ -547,20 +547,34 @@ def test_prompt_parent_task():
             False,
             None,
         ),
-        # Test 3:
+        # Test 3: Valid case - no thinking instructions with final_and_intermediate_r1_compatible
+        (
+            None,
+            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            False,
+            None,
+        ),
+        # Test 4: Invalid case - thinking instructions with final_only
         (
             "Think step by step",
             FinetuneDataStrategy.final_only,
             True,
             "Thinking instructions can only be used when data_strategy is final_and_intermediate",
         ),
-        # Test
+        # Test 5: Invalid case - no thinking instructions with final_and_intermediate
         (
             None,
             FinetuneDataStrategy.final_and_intermediate,
             True,
             "Thinking instructions are required when data_strategy is final_and_intermediate",
         ),
+        # Test 6: Invalid case - thinking instructions with final_and_intermediate_r1_compatible
+        (
+            "Think step by step",
+            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            True,
+            "Thinking instructions can only be used when data_strategy is final_and_intermediate",
+        ),
     ],
 )
 def test_finetune_thinking_instructions_validation(
@@ -617,3 +631,37 @@ def test_task_run_has_thinking_training_data(intermediate_outputs, expected):
         intermediate_outputs=intermediate_outputs,
     )
     assert task_run.has_thinking_training_data() == expected
+
+
+@pytest.mark.parametrize(
+    "intermediate_outputs,expected",
+    [
+        # No intermediate outputs
+        (None, None),
+        # Empty intermediate outputs
+        ({}, None),
+        # Only chain_of_thought
+        ({"chain_of_thought": "thinking process"}, "thinking process"),
+        # Only reasoning
+        ({"reasoning": "reasoning process"}, "reasoning process"),
+        # Both chain_of_thought and reasoning (should return reasoning as it's checked first)
+        (
+            {"chain_of_thought": "thinking process", "reasoning": "reasoning process"},
+            "reasoning process",
+        ),
+        # Other intermediate outputs but no thinking data
+        ({"other_output": "some data"}, None),
+        # Mixed other outputs with thinking data
+        (
+            {"chain_of_thought": "thinking process", "other_output": "some data"},
+            "thinking process",
+        ),
+    ],
+)
+def test_task_run_thinking_training_data(intermediate_outputs, expected):
+    task_run = TaskRun(
+        input="test input",
+        output=TaskOutput(output="test output"),
+        intermediate_outputs=intermediate_outputs,
+    )
+    assert task_run.thinking_training_data() == expected