kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +81 -10
- kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +434 -93
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +193 -49
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +18 -19
- kiln_ai/adapters/repair/test_repair_task.py +7 -7
- kiln_ai/adapters/run_output.py +11 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +25 -18
- kiln_ai/adapters/test_prompt_builders.py +265 -44
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +51 -772
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +14 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +80 -2
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +127 -6
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +34 -17
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +131 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai/adapters/base_adapter.py +0 -191
- kiln_ai/adapters/langchain_adapters.py +0 -256
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/basemodel.py
CHANGED
@@ -120,11 +120,12 @@ class KilnBaseModel(BaseModel):
         return cls.load_from_file(path)
 
     @classmethod
-    def load_from_file(cls: Type[T], path: Path | str) -> T:
+    def load_from_file(cls: Type[T], path: Path | str, readonly: bool = False) -> T:
         """Load a model instance from a specific file path.
 
         Args:
             path (Path): Path to the model file
+            readonly (bool): If True, the model will be returned in readonly mode (cached instance, not a copy, not safe to mutate)
 
         Returns:
             T: Instance of the model
@@ -135,10 +136,10 @@ class KilnBaseModel(BaseModel):
         """
         if isinstance(path, str):
             path = Path(path)
-        cached_model = ModelCache.shared().get_model(path, cls)
+        cached_model = ModelCache.shared().get_model(path, cls, readonly=readonly)
         if cached_model is not None:
             return cached_model
-        with open(path, "r") as file:
+        with open(path, "r", encoding="utf-8") as file:
             # modified time of file for cache invalidation. From file descriptor so it's atomic w read.
             mtime_ns = os.fstat(file.fileno()).st_mtime_ns
             file_data = file.read()
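The new readonly flag threads through to ModelCache: a readonly load returns the shared cached instance rather than a defensive copy. A minimal usage sketch (the model class import and file path are illustrative):

from kiln_ai.datamodel import Task

# Default load: a fresh copy, safe to mutate and save.
task = Task.load_from_file("tasks/123/task.kiln")  # hypothetical path

# Readonly load: the cached instance itself. Cheaper, but must not be mutated.
task_ro = Task.load_from_file("tasks/123/task.kiln", readonly=True)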
@@ -168,13 +169,20 @@ class KilnBaseModel(BaseModel):
         # Two methods of indicated it's loaded from file:
         # 1) info.context.get("loading_from_file") -> During actual loading, before we can set _loaded_from_file
         # 2) self._loaded_from_file -> After loading, set by the loader
+        if self.loading_from_file(info):
+            return True
+        return self._loaded_from_file
+
+    # indicates the model is currently being loaded from file (not mutating it after)
+    def loading_from_file(self, info: ValidationInfo | None = None) -> bool:
+        # info.context.get("loading_from_file") -> During actual loading, before we can set _loaded_from_file
         if (
             info is not None
             and info.context is not None
             and info.context.get("loading_from_file", False)
         ):
             return True
-        return
+        return False
 
     def save_to_file(self) -> None:
         """Save the model instance to a file.
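The loading_from_file helper relies on pydantic's validation context, which callers pass via model_validate(..., context=...). A small generic sketch of that mechanism (generic pydantic, not Kiln's actual loader):

from pydantic import BaseModel, ValidationInfo, field_validator


class Doc(BaseModel):
    body: str

    @field_validator("body")
    @classmethod
    def check_body(cls, v: str, info: ValidationInfo) -> str:
        # Relax strict checks while loading already-persisted data from disk.
        if info.context is not None and info.context.get("loading_from_file", False):
            return v
        if not v.strip():
            raise ValueError("body required for new documents")
        return v


Doc.model_validate({"body": ""}, context={"loading_from_file": True})  # accepted during load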
@@ -190,7 +198,7 @@ class KilnBaseModel(BaseModel):
         )
         path.parent.mkdir(parents=True, exist_ok=True)
         json_data = self.model_dump_json(indent=2, exclude={"path"})
-        with open(path, "w") as file:
+        with open(path, "w", encoding="utf-8") as file:
             file.write(json_data)
         # save the path so even if something like name changes, the file doesn't move
         self.path = path
@@ -342,16 +350,28 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
             return []
 
         # Collect all /relationship/{id}/{base_filename.kiln} files in the relationship folder
-        for
-
+        # manual code instead of glob for performance (5x speedup over glob)
+
+        base_filename = cls.base_filename()
+        # Iterate through immediate subdirectories using scandir for better performance
+        # Benchmark: scandir is 10x faster than glob, so worth the extra code
+        with os.scandir(relationship_folder) as entries:
+            for entry in entries:
+                if not entry.is_dir():
+                    continue
+
+                child_file = Path(entry.path) / base_filename
+                if child_file.is_file():
+                    yield child_file
 
     @classmethod
     def all_children_of_parent_path(
-        cls: Type[PT], parent_path: Path | None
+        cls: Type[PT], parent_path: Path | None, readonly: bool = False
     ) -> list[PT]:
         children = []
         for child_path in cls.iterate_children_paths_of_parent_path(parent_path):
-
+            item = cls.load_from_file(child_path, readonly=readonly)
+            children.append(item)
         return children
 
     @classmethod
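The 5x/10x figures in the new comments come from the authors' benchmarks; a micro-benchmark along these lines can reproduce the comparison (the directory layout and filename are assumptions):

import os
import timeit
from pathlib import Path

folder = Path("relationship_folder")  # hypothetical: many {id}/ subdirs, each holding task_run.kiln

def with_glob() -> list[Path]:
    return list(folder.glob("*/task_run.kiln"))

def with_scandir() -> list[Path]:
    found = []
    with os.scandir(folder) as entries:
        for entry in entries:
            # Only look in immediate subdirectories, mirroring the glob pattern above.
            if entry.is_dir():
                child = Path(entry.path) / "task_run.kiln"
                if child.is_file():
                    found.append(child)
    return found

print("glob:   ", timeit.timeit(with_glob, number=100))
print("scandir:", timeit.timeit(with_scandir, number=100))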
@@ -394,8 +414,8 @@ class KilnParentModel(KilnBaseModel, metaclass=ABCMeta):
     def _create_child_method(
         cls, relationship_name: str, child_class: Type[KilnParentedModel]
    ):
-        def child_method(self) -> list[child_class]:
-            return child_class.all_children_of_parent_path(self.path)
+        def child_method(self, readonly: bool = False) -> list[child_class]:
+            return child_class.all_children_of_parent_path(self.path, readonly=readonly)
 
         child_method.__name__ = relationship_name
         child_method.__annotations__ = {"return": List[child_class]}
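_create_child_method generates the parent-to-child accessors (for example Task.runs(), which DatasetSplit.missing_count below calls with readonly=True); after this change they accept and forward the flag. A hypothetical usage sketch:

from kiln_ai.datamodel import Task

task = Task.load_from_file("tasks/123/task.kiln")  # hypothetical path
runs = task.runs()                  # fresh copies, safe to mutate and save
runs_ro = task.runs(readonly=True)  # cached instances; treat as immutable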
kiln_ai/datamodel/datamodel_enums.py
ADDED
@@ -0,0 +1,58 @@
+from enum import Enum, IntEnum
+
+
+class Priority(IntEnum):
+    """Defines priority levels for tasks and requirements, where P0 is highest priority."""
+
+    p0 = 0
+    p1 = 1
+    p2 = 2
+    p3 = 3
+
+
+# Only one rating type for now, but this allows for extensibility if we want to add more in the future
+class TaskOutputRatingType(str, Enum):
+    """Defines the types of rating systems available for task outputs."""
+
+    five_star = "five_star"
+    pass_fail = "pass_fail"
+    pass_fail_critical = "pass_fail_critical"
+    custom = "custom"
+
+
+class StructuredOutputMode(str, Enum):
+    """
+    Enumeration of supported structured output modes.
+
+    - default: let the adapter decide
+    - json_schema: request json using API capabilities for json_schema
+    - function_calling: request json using API capabilities for function calling
+    - json_mode: request json using API's JSON mode, which should return valid JSON, but isn't checking/passing the schema
+    - json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. You should have a custom parser on these models as they will be returning strings.
+    - json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities (returning dictionaries).
+    """
+
+    default = "default"
+    json_schema = "json_schema"
+    function_calling_weak = "function_calling_weak"
+    function_calling = "function_calling"
+    json_mode = "json_mode"
+    json_instructions = "json_instructions"
+    json_instruction_and_object = "json_instruction_and_object"
+
+
+class FineTuneStatusType(str, Enum):
+    """
+    The status type of a fine-tune (running, completed, failed, etc).
+    """
+
+    unknown = "unknown"  # server error
+    pending = "pending"
+    running = "running"
+    completed = "completed"
+    failed = "failed"
+
+
+class FinetuneDataStrategy(str, Enum):
+    final_only = "final_only"
+    final_and_intermediate = "final_and_intermediate"
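One way an adapter could translate these modes into OpenAI-style request options; this mapping is illustrative only (the schema name label is arbitrary), not the package's actual adapter code:

from kiln_ai.datamodel.datamodel_enums import StructuredOutputMode


def response_format_options(mode: StructuredOutputMode, schema: dict) -> dict:
    # Illustrative: map a structured output mode to request kwargs.
    if mode == StructuredOutputMode.json_schema:
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {"name": "task_response", "schema": schema},  # "task_response" is a hypothetical label
            }
        }
    if mode in (StructuredOutputMode.json_mode, StructuredOutputMode.json_instruction_and_object):
        return {"response_format": {"type": "json_object"}}
    # default, function_calling*, and json_instructions would be handled elsewhere
    # (tool definitions, or prompt-appended instructions plus a custom parser).
    return {}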
kiln_ai/datamodel/dataset_filters.py
ADDED
@@ -0,0 +1,114 @@
+from enum import Enum
+from typing import Annotated, Protocol
+
+from pydantic import AfterValidator
+
+from kiln_ai.datamodel.task_run import TaskRun
+
+
+class DatasetFilter(Protocol):
+    """A protocol defining the interface for dataset filters.
+
+    This allows both stateless function-based filters and stateful class-based filters
+    to be used interchangeably, as long as they implement the __call__ method.
+    """
+
+    def __call__(self, task_run: TaskRun) -> bool:
+        """Return True if the task run should be included in the dataset."""
+        ...
+
+
+def AllDatasetFilter(_: TaskRun) -> bool:
+    return True
+
+
+def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
+    if task_run.output is None:
+        return False
+    if task_run.repaired_output is not None:
+        # Repairs always considered high quality
+        return True
+    if task_run.output.rating is None:
+        return False
+    return task_run.output.rating.is_high_quality()
+
+
+def ThinkingModelDatasetFilter(task_run: TaskRun) -> bool:
+    """
+    A filter that returns True if the task has intermediate outputs we can training a 'thinking' model on (reasoning or chain of thought)
+    """
+    return task_run.has_thinking_training_data()
+
+
+def ThinkingModelHighRatedFilter(task_run: TaskRun) -> bool:
+    """
+    A filter that returns True if the task has thinking data and the output is high quality
+    """
+    return ThinkingModelDatasetFilter(task_run) and HighRatingDatasetFilter(task_run)
+
+
+class TagFilter:
+    """
+    A filter that returns True if the task has a tag matching the given tag.
+    """
+
+    def __init__(self, tag: str):
+        self.tag = tag
+
+    def __call__(self, task_run: TaskRun) -> bool:
+        return self.tag in task_run.tags
+
+
+class StaticDatasetFilters(str, Enum):
+    """Dataset filter names."""
+
+    ALL = "all"
+    HIGH_RATING = "high_rating"
+    THINKING_MODEL = "thinking_model"
+    THINKING_MODEL_HIGH_RATED = "thinking_model_high_rated"
+
+
+static_dataset_filters = {
+    StaticDatasetFilters.ALL: AllDatasetFilter,
+    StaticDatasetFilters.HIGH_RATING: HighRatingDatasetFilter,
+    StaticDatasetFilters.THINKING_MODEL: ThinkingModelDatasetFilter,
+    StaticDatasetFilters.THINKING_MODEL_HIGH_RATED: ThinkingModelHighRatedFilter,
+}
+
+DatasetFilterId = Annotated[
+    str,
+    AfterValidator(lambda v: _check_dataset_filter_id(v)),
+]
+"""
+A pydantic type that validates strings containing a valid dataset filter ID.
+
+Dataset filter IDs can be one of:
+- A built-in dataset filter name
+- A tag::<tag> filter, where <tag> is a string
+"""
+
+
+def _check_dataset_filter_id(id: str) -> str:
+    """
+    Check that the dataset filter ID is valid.
+    """
+    if id in static_dataset_filters:
+        return id
+
+    if id.startswith("tag::") and len(id) > 5:
+        return id
+
+    raise ValueError(f"Invalid dataset filter ID: {id}")
+
+
+def dataset_filter_from_id(id: DatasetFilterId) -> DatasetFilter:
+    """
+    Get a dataset filter from an ID.
+    """
+    if id.startswith("tag::") and len(id) > 5:
+        return TagFilter(id[5:])
+
+    if id in static_dataset_filters:
+        return static_dataset_filters[id]
+
+    raise ValueError(f"Invalid dataset filter ID: {id}")
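Usage sketch: both built-in IDs and tag:: IDs resolve to callables satisfying the DatasetFilter protocol (the tag name here is illustrative):

from kiln_ai.datamodel.dataset_filters import dataset_filter_from_id

high_rating = dataset_filter_from_id("high_rating")  # built-in HighRatingDatasetFilter
golden = dataset_filter_from_id("tag::golden")       # TagFilter("golden")

# Both are plain callables, so they compose over any iterable of task runs, e.g.:
# kept = [run for run in task.runs(readonly=True) if golden(run) and high_rating(run)]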
kiln_ai/datamodel/dataset_split.py
ADDED
@@ -0,0 +1,170 @@
+"""
+Tools for splitting datasets into train/test/validation splits. Includes filters for selecting which task runs to include in each split.
+"""
+
+import math
+import random
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel, Field, model_validator
+
+from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel
+from kiln_ai.datamodel.dataset_filters import (
+    DatasetFilter,
+    DatasetFilterId,
+    dataset_filter_from_id,
+)
+
+if TYPE_CHECKING:
+    from kiln_ai.datamodel.task import Task
+
+
+class DatasetSplitDefinition(BaseModel):
+    """
+    A definition of a split in a dataset.
+
+    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    percentage: float = Field(
+        ge=0.0,
+        le=1.0,
+        description="The percentage of the dataset that this split represents (between 0 and 1).",
+    )
+
+
+AllSplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="all", percentage=1.0)
+]
+Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.8),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+]
+Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.6),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+    DatasetSplitDefinition(name="val", percentage=0.2),
+]
+Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.8),
+    DatasetSplitDefinition(name="test", percentage=0.1),
+    DatasetSplitDefinition(name="val", percentage=0.1),
+]
+
+
+class DatasetSplit(KilnParentedModel):
+    """
+    A collection of task runs, with optional splits (train, test, validation).
+
+    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
+
+    Maintains a list of IDs for each split, to avoid data duplication.
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    splits: list[DatasetSplitDefinition] = Field(
+        default_factory=list,
+        description="The splits in the dataset.",
+    )
+    split_contents: dict[str, list[str]] = Field(
+        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
+    )
+    filter: DatasetFilterId | None = Field(
+        default=None,
+        description="The filter used to build the dataset.",
+    )
+
+    @model_validator(mode="after")
+    def validate_split_percentages(self) -> "DatasetSplit":
+        total = sum(split.percentage for split in self.splits)
+        if not math.isclose(total, 1.0, rel_tol=1e-9):
+            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
+        return self
+
+    @classmethod
+    def from_task(
+        cls,
+        name: str,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter_id: DatasetFilterId = "all",
+        description: str | None = None,
+    ):
+        """
+        Build a dataset split from a task.
+        """
+        filter = dataset_filter_from_id(filter_id)
+        split_contents = cls.build_split_contents(task, splits, filter)
+        return cls(
+            parent=task,
+            name=name,
+            description=description,
+            splits=splits,
+            split_contents=split_contents,
+            filter=filter_id,
+        )
+
+    @classmethod
+    def build_split_contents(
+        cls,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter: DatasetFilter,
+    ) -> dict[str, list[str]]:
+        valid_ids = []
+        for task_run in task.runs():
+            if filter(task_run):
+                valid_ids.append(task_run.id)
+
+        # Shuffle and split by split percentage
+        random.shuffle(valid_ids)
+        split_contents = {}
+        start_idx = 0
+        remaining_items = len(valid_ids)
+
+        # Handle all splits except the last one
+        for split in splits[:-1]:
+            split_size = round(len(valid_ids) * split.percentage)
+            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
+            start_idx += split_size
+            remaining_items -= split_size
+
+        # Last split gets all remaining items (for rounding)
+        if splits:
+            split_contents[splits[-1].name] = valid_ids[start_idx:]
+
+        return split_contents
+
+    def parent_task(self) -> "Task | None":
+        # inline import to avoid circular import
+        from kiln_ai.datamodel import Task
+
+        if not isinstance(self.parent, Task):
+            return None
+        return self.parent
+
+    def missing_count(self) -> int:
+        """
+        Returns:
+            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
+        """
+        parent = self.parent_task()
+        if parent is None:
+            raise ValueError("DatasetSplit has no parent task")
+
+        runs = parent.runs(readonly=True)
+        all_ids = set(run.id for run in runs)
+        all_ids_in_splits = set()
+        for ids in self.split_contents.values():
+            all_ids_in_splits.update(ids)
+        missing = all_ids_in_splits - all_ids
+        return len(missing)
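End-to-end sketch of freezing a split from a task, under the assumption of a task file on disk (the path and names are hypothetical):

from kiln_ai.datamodel import Task
from kiln_ai.datamodel.dataset_split import DatasetSplit, Train80Test20SplitDefinition

task = Task.load_from_file("tasks/123/task.kiln")  # hypothetical path
split = DatasetSplit.from_task(
    name="ft_v1",
    task=task,
    splits=Train80Test20SplitDefinition,
    filter_id="high_rating",  # only highly rated (or repaired) runs
)
split.save_to_file()  # persists under the parent task, per KilnParentedModel
print(len(split.split_contents["train"]), split.missing_count())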