kiln-ai 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to the supported public registries. It is provided for informational purposes only.
Potentially problematic release: this version of kiln-ai has been flagged as potentially problematic.
- kiln_ai/adapters/adapter_registry.py +12 -13
- kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
- kiln_ai/adapters/ml_model_list.py +141 -29
- kiln_ai/adapters/model_adapters/base_adapter.py +50 -35
- kiln_ai/adapters/model_adapters/langchain_adapters.py +27 -20
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -1
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +93 -50
- kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +7 -14
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +55 -64
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
- kiln_ai/adapters/model_adapters/test_structured_output.py +36 -30
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/prompt_builders.py +80 -42
- kiln_ai/adapters/repair/repair_task.py +9 -21
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +10 -10
- kiln_ai/adapters/test_generate_docs.py +6 -6
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +17 -14
- kiln_ai/adapters/test_prompt_builders.py +91 -31
- kiln_ai/datamodel/__init__.py +50 -952
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +6 -0
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +10 -11
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +32 -8
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +9 -13
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_models.py +2 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +37 -1
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai-0.11.1.dist-info/RECORD +0 -76
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/task_output.py

@@ -0,0 +1,321 @@
+import json
+from enum import Enum
+from typing import TYPE_CHECKING, Dict, List, Type, Union
+
+import jsonschema
+import jsonschema.exceptions
+from pydantic import BaseModel, Field, ValidationInfo, model_validator
+from typing_extensions import Self
+
+from kiln_ai.datamodel.basemodel import ID_TYPE, KilnBaseModel
+from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
+from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.strict_mode import strict_mode
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+if TYPE_CHECKING:
+    from kiln_ai.datamodel.task import Task
+
+
+class RequirementRating(BaseModel):
+    """Rating for a specific requirement within a task output."""
+
+    value: float = Field(
+        description="The rating value. Interpretation depends on rating type"
+    )
+    type: TaskOutputRatingType = Field(description="The type of rating")
+
+
+def normalize_rating(rating: float, rating_type: TaskOutputRatingType) -> float:
+    """Normalize a rating to a 0-1 scale. Simple normalization, not z-score."""
+    match rating_type:
+        case TaskOutputRatingType.five_star:
+            if rating < 1 or rating > 5:
+                raise ValueError("Five star rating must be between 1 and 5")
+            return (rating - 1) / 4
+        case TaskOutputRatingType.pass_fail:
+            if rating < 0 or rating > 1:
+                raise ValueError("Pass fail rating must 0 to 1")
+            return rating
+        case TaskOutputRatingType.pass_fail_critical:
+            if rating < -1 or rating > 1:
+                raise ValueError("Pass fail critical rating must -1 to 1")
+            return (rating + 1) / 2  # -1 to 1
+        case TaskOutputRatingType.custom:
+            raise ValueError("Custom rating type can not be normalized")
+        case _:
+            raise_exhaustive_enum_error(rating_type)
+
+
+class TaskOutputRating(KilnBaseModel):
+    """
+    A rating for a task output, including an overall rating and ratings for each requirement.
+
+    Supports:
+    - five_star: 1-5 star ratings
+    - pass_fail: boolean pass/fail (1.0 = pass, 0.0 = fail)
+    - pass_fail_critical: tri-state (1.0 = pass, 0.0 = fail, -1.0 = critical fail)
+    """
+
+    type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star)
+    value: float | None = Field(
+        description="The rating value. Interpretation depends on rating type:\n- five_star: 1-5 stars\n- pass_fail: 1.0 (pass) or 0.0 (fail)\n- pass_fail_critical: 1.0 (pass), 0.0 (fail), or -1.0 (critical fail)",
+        default=None,
+    )
+    requirement_ratings: Dict[ID_TYPE, RequirementRating] = Field(
+        default={},
+        description="The ratings of the requirements of the task.",
+    )
+
+    # Previously we stored rating values as a dict of floats, but now we store them as RequirementRating objects.
+    @model_validator(mode="before")
+    def upgrade_old_format(cls, data: dict) -> dict:
+        if not isinstance(data, dict):
+            return data
+
+        # Check if we have the old format (dict of floats)
+        req_ratings = data.get("requirement_ratings", {})
+        if req_ratings and all(
+            isinstance(v, (int, float)) for v in req_ratings.values()
+        ):
+            # Convert each float to a RequirementRating object
+            # all ratings are five star at the point we used this format
+            data["requirement_ratings"] = {
+                k: {"value": v, "type": TaskOutputRatingType.five_star}
+                for k, v in req_ratings.items()
+            }
+
+        return data
+
+    # Used to select high quality outputs for example selection (MultiShotPromptBuilder, etc)
+    def is_high_quality(self) -> bool:
+        if self.value is None:
+            return False
+
+        if self.type == TaskOutputRatingType.five_star:
+            return self.value >= 4
+        elif self.type == TaskOutputRatingType.pass_fail:
+            return self.value == 1.0
+        elif self.type == TaskOutputRatingType.pass_fail_critical:
+            return self.value == 1.0
+        return False
+
+    @model_validator(mode="after")
+    def validate_rating(self) -> Self:
+        if self.type not in TaskOutputRatingType:
+            raise ValueError(f"Invalid rating type: {self.type}")
+
+        # Overall rating is optional
+        if self.value is not None:
+            self._validate_rating(self.type, self.value, "overall rating")
+
+        for req_id, req_rating in self.requirement_ratings.items():
+            self._validate_rating(
+                req_rating.type,
+                req_rating.value,
+                f"requirement rating for req ID: {req_id}",
+            )
+
+        return self
+
+    def _validate_rating(
+        self, type: TaskOutputRatingType, rating: float | None, rating_name: str
+    ) -> None:
+        if type == TaskOutputRatingType.five_star:
+            self._validate_five_star(rating, rating_name)
+        elif type == TaskOutputRatingType.pass_fail:
+            self._validate_pass_fail(rating, rating_name)
+        elif type == TaskOutputRatingType.pass_fail_critical:
+            self._validate_pass_fail_critical(rating, rating_name)
+
+    def _validate_five_star(self, rating: float | None, rating_name: str) -> None:
+        if rating is None or not isinstance(rating, float) or not rating.is_integer():
+            raise ValueError(
+                f"{rating_name.capitalize()} of type five_star must be an integer value (1-5)"
+            )
+        if rating < 1 or rating > 5:
+            raise ValueError(
+                f"{rating_name.capitalize()} of type five_star must be between 1 and 5 stars"
+            )
+
+    def _validate_pass_fail(self, rating: float | None, rating_name: str) -> None:
+        if rating is None or not isinstance(rating, float) or not rating.is_integer():
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail must be an integer value (0 or 1)"
+            )
+        if rating not in [0, 1]:
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail must be 0 (fail) or 1 (pass)"
+            )
+
+    def _validate_pass_fail_critical(
+        self, rating: float | None, rating_name: str
+    ) -> None:
+        if rating is None or not isinstance(rating, float) or not rating.is_integer():
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail_critical must be an integer value (-1, 0, or 1)"
+            )
+        if rating not in [-1, 0, 1]:
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail_critical must be -1 (critical fail), 0 (fail), or 1 (pass)"
+            )
+
+
+class DataSourceType(str, Enum):
+    """
+    The source type of a piece of data.
+
+    Human: a human created the data
+    Synthetic: a model created the data
+    """
+
+    human = "human"
+    synthetic = "synthetic"
+
+
+class DataSourceProperty(BaseModel):
+    """
+    Defines a property that can be associated with a data source.
+
+    Includes validation rules for when properties are required or not allowed
+    based on the data source type.
+    """
+
+    name: str
+    type: Type[Union[str, int, float]]
+    required_for: List[DataSourceType] = []
+    not_allowed_for: List[DataSourceType] = []
+
+
+class DataSource(BaseModel):
+    """
+    Represents the origin of data, either human or synthetic, with associated properties.
+
+    Properties vary based on the source type - for synthetic sources this includes
+    model information, for human sources this includes creator information.
+    """
+
+    type: DataSourceType
+    properties: Dict[str, str | int | float] = Field(
+        default={},
+        description="Properties describing the data source. For synthetic things like model. For human, the human's name.",
+    )
+
+    _data_source_properties = [
+        DataSourceProperty(
+            name="created_by",
+            type=str,
+            required_for=[DataSourceType.human],
+            not_allowed_for=[DataSourceType.synthetic],
+        ),
+        DataSourceProperty(
+            name="model_name",
+            type=str,
+            required_for=[DataSourceType.synthetic],
+            not_allowed_for=[DataSourceType.human],
+        ),
+        DataSourceProperty(
+            name="model_provider",
+            type=str,
+            required_for=[DataSourceType.synthetic],
+            not_allowed_for=[DataSourceType.human],
+        ),
+        DataSourceProperty(
+            name="adapter_name",
+            type=str,
+            required_for=[DataSourceType.synthetic],
+            not_allowed_for=[DataSourceType.human],
+        ),
+        DataSourceProperty(
+            # Legacy field -- allow loading from old runs, but we shouldn't be setting it.
+            name="prompt_builder_name",
+            type=str,
+            not_allowed_for=[DataSourceType.human],
+        ),
+        DataSourceProperty(
+            # The PromptId of the prompt. Can be a saved prompt, fine-tune, generator name, etc. See PromptId type for more details.
+            name="prompt_id",
+            type=str,
+            not_allowed_for=[DataSourceType.human],
+        ),
+    ]
+
+    @model_validator(mode="after")
+    def validate_type(self) -> "DataSource":
+        if self.type not in DataSourceType:
+            raise ValueError(f"Invalid data source type: {self.type}")
+        return self
+
+    @model_validator(mode="after")
+    def validate_properties(self) -> "DataSource":
+        for prop in self._data_source_properties:
+            # Check the property type is correct
+            if prop.name in self.properties:
+                if not isinstance(self.properties[prop.name], prop.type):
+                    raise ValueError(
+                        f"'{prop.name}' must be of type {prop.type.__name__} for {self.type} data source"
+                    )
+            # Check the property is required for the data source type
+            if self.type in prop.required_for:
+                if prop.name not in self.properties:
+                    raise ValueError(
+                        f"'{prop.name}' is required for {self.type} data source"
+                    )
+            # Check the property is not allowed for the data source type
+            elif self.type in prop.not_allowed_for and prop.name in self.properties:
+                raise ValueError(
+                    f"'{prop.name}' is not allowed for {self.type} data source"
+                )
+        return self
+
+    @model_validator(mode="after")
+    def validate_no_empty_properties(self) -> Self:
+        for prop, value in self.properties.items():
+            if isinstance(value, str) and value == "":
+                raise ValueError(
+                    f"Property '{prop}' must be a non-empty string for {self.type} data source"
+                )
+        return self
+
+
+class TaskOutput(KilnBaseModel):
+    """
+    An output for a specific task run.
+
+    Contains the actual output content, its source (human or synthetic),
+    and optional rating information.
+    """
+
+    output: str = Field(
+        description="The output of the task. JSON formatted for structured output, plaintext for unstructured output."
+    )
+    source: DataSource | None = Field(
+        description="The source of the output: human or synthetic.",
+        default=None,
+    )
+    rating: TaskOutputRating | None = Field(
+        default=None, description="The rating of the output"
+    )
+
+    def validate_output_format(self, task: "Task") -> Self:
+        # validate output
+        if task.output_json_schema is not None:
+            try:
+                validate_schema(json.loads(self.output), task.output_json_schema)
+            except json.JSONDecodeError:
+                raise ValueError("Output is not a valid JSON object")
+            except jsonschema.exceptions.ValidationError as e:
+                raise ValueError(f"Output does not match task output schema: {e}")
+        return self
+
+    @model_validator(mode="after")
+    def validate_output_source(self, info: ValidationInfo) -> Self:
+        # On strict mode and not loaded from file, we validate output_source is not None.
+        # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data.
+        if not strict_mode():
+            return self
+        if self.loaded_from_file(info):
+            return self
+        if self.source is None:
+            raise ValueError("Output source is required when strict mode is enabled")
+        return self
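
A quick orientation note (not part of the diff): the new task_output.py module centralizes rating semantics. Below is a minimal usage sketch based only on the code above; the concrete values are illustrative.

from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
from kiln_ai.datamodel.task_output import TaskOutputRating, normalize_rating

# five_star ratings map linearly onto 0-1: (rating - 1) / 4
assert normalize_rating(5.0, TaskOutputRatingType.five_star) == 1.0
assert normalize_rating(3.0, TaskOutputRatingType.five_star) == 0.5

# pass_fail_critical maps -1..1 onto 0..1: (rating + 1) / 2
assert normalize_rating(-1.0, TaskOutputRatingType.pass_fail_critical) == 0.0

# Ratings must be integer-valued floats in range; 4+ stars counts as high quality.
rating = TaskOutputRating(value=4.0, type=TaskOutputRatingType.five_star)
assert rating.is_high_quality()
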
kiln_ai/datamodel/task_run.py

@@ -0,0 +1,164 @@
+import json
+from typing import TYPE_CHECKING, Dict, List, Union
+
+import jsonschema
+import jsonschema.exceptions
+from pydantic import Field, ValidationInfo, model_validator
+from typing_extensions import Self
+
+from kiln_ai.datamodel.basemodel import KilnParentedModel
+from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.strict_mode import strict_mode
+from kiln_ai.datamodel.task_output import DataSource, TaskOutput
+
+if TYPE_CHECKING:
+    from kiln_ai.datamodel.task import Task
+
+
+class TaskRun(KilnParentedModel):
+    """
+    Represents a single execution of a Task.
+
+    Contains the input used, its source, the output produced, and optional
+    repair information if the output needed correction.
+    """
+
+    input: str = Field(
+        description="The inputs to the task. JSON formatted for structured input, plaintext for unstructured input."
+    )
+    input_source: DataSource | None = Field(
+        default=None, description="The source of the input: human or synthetic."
+    )
+
+    output: TaskOutput = Field(description="The output of the task run.")
+    repair_instructions: str | None = Field(
+        default=None,
+        description="Instructions for fixing the output. Should define what is wrong, and how to fix it. Will be used by models for both generating a fixed output, and evaluating future models.",
+    )
+    repaired_output: TaskOutput | None = Field(
+        default=None,
+        description="An version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field.",
+    )
+    intermediate_outputs: Dict[str, str] | None = Field(
+        default=None,
+        description="Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.",
+    )
+    tags: List[str] = Field(
+        default=[],
+        description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
+    )
+
+    def has_thinking_training_data(self) -> bool:
+        """
+        Does this run have thinking data that we can use to train a thinking model?
+        """
+        if self.intermediate_outputs is None:
+            return False
+        return (
+            "chain_of_thought" in self.intermediate_outputs
+            or "reasoning" in self.intermediate_outputs
+        )
+
+    # Workaround to return typed parent without importing Task
+    def parent_task(self) -> Union["Task", None]:
+        if self.parent is None or self.parent.__class__.__name__ != "Task":
+            return None
+        return self.parent  # type: ignore
+
+    @model_validator(mode="after")
+    def validate_input_format(self, info: ValidationInfo) -> Self:
+        # Don't validate if loading from file (not new). Too slow.
+        # We don't allow changing task schema, so this is redundant validation.
+        # Note: we still validate if editing a loaded model
+        if self.loading_from_file(info):
+            # Consider loading an existing model as validated.
+            self._last_validated_input = self.input
+            return self
+
+        # Don't validate if input has not changed. Too slow to run this every time.
+        if (
+            hasattr(self, "_last_validated_input")
+            and self.input == self._last_validated_input
+        ):
+            return self
+
+        task = self.parent_task()
+        if task is None:
+            # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
+            return self
+
+        # validate output
+        if task.input_json_schema is not None:
+            try:
+                validate_schema(json.loads(self.input), task.input_json_schema)
+            except json.JSONDecodeError:
+                raise ValueError("Input is not a valid JSON object")
+            except jsonschema.exceptions.ValidationError as e:
+                raise ValueError(f"Input does not match task input schema: {e}")
+        self._last_validated_input = self.input
+        return self
+
+    @model_validator(mode="after")
+    def validate_output_format(self, info: ValidationInfo) -> Self:
+        # Don't validate if loading from file (not new). Too slow.
+        # Note: we still validate if editing a loaded model's output.
+        if self.loading_from_file(info):
+            # Consider loading an existing model as validated.
+            self._last_validated_output = self.output.output if self.output else None
+            return self
+
+        # Don't validate unless output has changed since last validation.
+        # The validator is slow and costly, don't want it running when setting other fields.
+        if (
+            hasattr(self, "_last_validated_output")
+            and self.output is not None
+            and self.output.output == self._last_validated_output
+        ):
+            return self
+
+        task = self.parent_task()
+        if task is None:
+            return self
+
+        self.output.validate_output_format(task)
+        self._last_validated_output = self.output.output if self.output else None
+        return self
+
+    @model_validator(mode="after")
+    def validate_repaired_output(self) -> Self:
+        if self.repaired_output is not None:
+            if self.repaired_output.rating is not None:
+                raise ValueError(
+                    "Repaired output rating must be None. Repaired outputs are assumed to have a perfect rating, as they have been fixed."
+                )
+        if self.repair_instructions is None and self.repaired_output is not None:
+            raise ValueError(
+                "Repair instructions are required if providing a repaired output."
+            )
+        if self.repair_instructions is not None and self.repaired_output is None:
+            raise ValueError(
+                "A repaired output is required if providing repair instructions."
+            )
+        return self
+
+    @model_validator(mode="after")
+    def validate_input_source(self, info: ValidationInfo) -> Self:
+        # On strict mode and not loaded from file, we validate input_source is not None.
+        # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data.
+        if not strict_mode():
+            return self
+        if self.loaded_from_file(info):
+            return self
+        if self.input_source is None:
+            raise ValueError("input_source is required when strict mode is enabled")
+        return self
+
+    @model_validator(mode="after")
+    def validate_tags(self) -> Self:
+        for tag in self.tags:
+            if not tag:
+                raise ValueError("Tags cannot be empty strings")
+            if " " in tag:
+                raise ValueError("Tags cannot contain spaces. Try underscores.")
+
+        return self
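
For orientation (not part of the diff): a hedged sketch of building the new TaskRun model in memory, mirroring the fields and validators above. The model, provider and adapter names are made up, and the run is constructed without a parent Task, which the validators tolerate until schema validation or saving.

from kiln_ai.datamodel.task_output import DataSource, DataSourceType, TaskOutput
from kiln_ai.datamodel.task_run import TaskRun

# Synthetic sources must carry model_name, model_provider and adapter_name.
source = DataSource(
    type=DataSourceType.synthetic,
    properties={
        "model_name": "example_model",        # hypothetical values
        "model_provider": "example_provider",
        "adapter_name": "example_adapter",
    },
)

run = TaskRun(
    input="Summarize the release notes.",
    input_source=source,
    output=TaskOutput(output="A short summary.", source=source),
    intermediate_outputs={"reasoning": "step-by-step notes"},
    tags=["demo_tag"],  # tags must be non-empty and contain no spaces
)

# "reasoning" or "chain_of_thought" intermediate outputs count as thinking training data.
assert run.has_thinking_training_data()
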
kiln_ai/datamodel/test_basemodel.py

@@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
-from kiln_ai.adapters.model_adapters.base_adapter import
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
 from kiln_ai.adapters.run_output import RunOutput
 from kiln_ai.datamodel import Task, TaskRun
 from kiln_ai.datamodel.basemodel import (
@@ -15,6 +15,7 @@ from kiln_ai.datamodel.basemodel import (
     string_to_valid_name,
 )
 from kiln_ai.datamodel.model_cache import ModelCache
+from kiln_ai.datamodel.task import RunConfig
 
 
 @pytest.fixture
@@ -484,13 +485,8 @@ class MockAdapter(BaseAdapter):
     async def _run(self, input):
         return RunOutput(output="test output", intermediate_outputs=None)
 
-    def
-        return
-            adapter_name="test",
-            model_name=self.model_name,
-            model_provider=self.model_provider_name,
-            prompt_builder_name="test",
-        )
+    def adapter_name(self) -> str:
+        return "test"
 
 
 @pytest.fixture
@@ -501,9 +497,12 @@ def base_task():
 @pytest.fixture
 def adapter(base_task):
     return MockAdapter(
-
-
-
+        run_config=RunConfig(
+            task=base_task,
+            model_name="test_model",
+            model_provider_name="test_provider",
+            prompt_id="simple_prompt_builder",
+        ),
     )
 
 
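
Context note (not part of the diff): the updated fixture reflects that adapters are now configured through a single RunConfig from kiln_ai.datamodel.task rather than loose keyword arguments. A hedged sketch follows; the model and provider names are made up, and the minimal Task fields are assumed.

from kiln_ai.datamodel import Task
from kiln_ai.datamodel.task import RunConfig

task = Task(name="Demo task", instruction="Say hello")  # assumed minimal Task fields

config = RunConfig(
    task=task,
    model_name="example_model",              # hypothetical
    model_provider_name="example_provider",  # hypothetical
    prompt_id="simple_prompt_builder",
)
# An adapter is then constructed as SomeAdapter(run_config=config),
# just as the MockAdapter fixture above does.
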
kiln_ai/datamodel/test_dataset_filters.py

@@ -0,0 +1,71 @@
+import pytest
+from pydantic import BaseModel
+
+from kiln_ai.datamodel.dataset_filters import (
+    AllDatasetFilter,
+    DatasetFilterId,
+    HighRatingDatasetFilter,
+    StaticDatasetFilters,
+    TagFilter,
+    ThinkingModelDatasetFilter,
+    ThinkingModelHighRatedFilter,
+    dataset_filter_from_id,
+)
+
+# Note: Many more filter tests in test_dataset_split.py
+
+
+def test_all_dataset_filter_from_id():
+    assert dataset_filter_from_id("all") == AllDatasetFilter
+
+
+def test_high_rating_dataset_filter_from_id():
+    assert dataset_filter_from_id("high_rating") == HighRatingDatasetFilter
+
+
+def test_thinking_model_dataset_filter_from_id():
+    assert dataset_filter_from_id("thinking_model") == ThinkingModelDatasetFilter
+
+
+def test_thinking_model_high_rated_dataset_filter_from_id():
+    assert (
+        dataset_filter_from_id("thinking_model_high_rated")
+        == ThinkingModelHighRatedFilter
+    )
+
+
+def test_all_static_dataset_filters():
+    for filter_id in StaticDatasetFilters:
+        assert dataset_filter_from_id(filter_id) is not None
+
+
+class ModelTester(BaseModel):
+    dsid: DatasetFilterId
+
+
+@pytest.mark.parametrize(
+    "tag,expected_error,expected_tag",
+    [
+        ("tag::test", False, "test"),
+        ("tag::other", False, "other"),
+        ("tag::", True, None),
+        ("tag", True, None),
+        ("", True, None),
+    ],
+)
+def test_tag_filter(tag, expected_error, expected_tag):
+    # Check our model validators
+    if expected_error:
+        with pytest.raises(ValueError):
+            ModelTester(dsid=tag)
+    else:
+        ModelTester(dsid=tag)
+
+    # Check the constructor
+    if expected_tag is None:
+        with pytest.raises(ValueError, match="Invalid dataset filter ID:"):
+            dataset_filter_from_id(tag)
+    else:
+        filter = dataset_filter_from_id(tag)
+        assert isinstance(filter, TagFilter)
+        assert filter.tag == expected_tag
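
For orientation (not part of the diff): DatasetFilterId accepts both static IDs and dynamic "tag::" IDs, as the tests above exercise. A small sketch; the tag name is an arbitrary example.

from kiln_ai.datamodel.dataset_filters import (
    HighRatingDatasetFilter,
    TagFilter,
    dataset_filter_from_id,
)

# Static IDs resolve to the corresponding filter.
assert dataset_filter_from_id("high_rating") == HighRatingDatasetFilter

# "tag::<name>" IDs build a TagFilter for that tag.
tag_filter = dataset_filter_from_id("tag::golden_set")  # "golden_set" is a made-up tag
assert isinstance(tag_filter, TagFilter)
assert tag_filter.tag == "golden_set"

# Malformed IDs such as "tag::" raise ValueError("Invalid dataset filter ID: ...").
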
kiln_ai/datamodel/test_dataset_split.py

@@ -3,24 +3,28 @@ from pydantic import ValidationError
 
 # import datamodel first or we get circular import errors
 from kiln_ai.datamodel import (
-    AllDatasetFilter,
-    AllSplitDefinition,
-    DatasetFilterType,
     DatasetSplit,
     DatasetSplitDefinition,
     DataSource,
     DataSourceType,
-    HighRatingDatasetFilter,
     Task,
     TaskOutput,
    TaskOutputRating,
     TaskOutputRatingType,
     TaskRun,
-
-
+)
+from kiln_ai.datamodel.dataset_split import (
+    AllSplitDefinition,
     Train60Test20Val20SplitDefinition,
     Train80Test20SplitDefinition,
 )
+from kiln_ai.datamodel.test_dataset_filters import (
+    AllDatasetFilter,
+    HighRatingDatasetFilter,
+    TagFilter,
+    ThinkingModelDatasetFilter,
+    ThinkingModelHighRatedFilter,
+)
 
 
 @pytest.fixture
@@ -42,6 +46,7 @@ def sample_task_runs(sample_task):
     task_runs = []
     for i in range(10):
         rating = 5 if i < 6 else 1  # 6 high, 4 low ratings
+        tags = ["tag1"] if i < 6 else []
         task_run = TaskRun(
             parent=sample_task,
             input=f"input_{i}",
@@ -59,6 +64,7 @@ def sample_task_runs(sample_task):
                     value=rating, type=TaskOutputRatingType.five_star
                 ),
             ),
+            tags=tags,
         )
         task_run.save_to_file()
         task_runs.append(task_run)
@@ -199,10 +205,10 @@ def test_dataset_split_with_high_rating_filter(sample_task, sample_task_runs):
         "Split Name",
         sample_task,
         Train80Test20SplitDefinition,
-
+        filter_id="high_rating",
     )
 
-    assert dataset.filter ==
+    assert dataset.filter == "high_rating"
 
     # Check that only high-rated task runs are included
     all_ids = []
@@ -329,3 +335,21 @@ def test_thinking_model_dataset_filter_high_rated(
     )
 
     assert ThinkingModelHighRatedFilter(task_run) is expected_result
+
+
+def test_tag_dataset_filter(sample_task_runs):
+    num_tagged = 0
+    num_untagged = 0
+    filter = TagFilter("tag1")
+    for task_run in sample_task_runs:
+        if "tag1" in task_run.tags:
+            num_tagged += 1
+            assert "tag1" in task_run.tags
+            assert filter(task_run) is True
+        else:
+            num_untagged += 1
+            assert "tag1" not in task_run.tags
+            assert filter(task_run) is False
+
+    assert num_tagged == 6
+    assert num_untagged == 4