kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of kiln-ai has been flagged as a potentially problematic release.
- kiln_ai/adapters/__init__.py +11 -1
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/__init__.py +11 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
- kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +218 -304
- kiln_ai/adapters/ollama_tools.py +114 -0
- kiln_ai/adapters/provider_tools.py +295 -0
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +46 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +312 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +235 -22
- kiln_ai/datamodel/basemodel.py +30 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +29 -1
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_models.py +91 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +9 -0
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_generator.py +47 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
- kiln_ai-0.7.0.dist-info/RECORD +56 -0
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.0.dist-info/RECORD +0 -36
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/__init__.py
CHANGED
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
 import json
+import math
+import random
 from enum import Enum, IntEnum
-from typing import TYPE_CHECKING, Dict, List, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Type, Union
 
 import jsonschema
 import jsonschema.exceptions
@@ -14,6 +16,8 @@ from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
 from .basemodel import (
     ID_FIELD,
     ID_TYPE,
+    NAME_FIELD,
+    SHORT_NAME_FIELD,
     KilnBaseModel,
     KilnParentedModel,
     KilnParentModel,
@@ -42,26 +46,6 @@ __all__ = [
 ]
 
 
-# Conventions:
-# 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation.
-# 2) Descrptions are for Kiln users to describe/understanding the purpose of this object. They must never be used in prompts/training/validation. Use "instruction/requirements" instead.
-
-# Filename compatible names
-NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
-NAME_FIELD = Field(
-    min_length=1,
-    max_length=120,
-    pattern=NAME_REGEX,
-    description="A name for this entity.",
-)
-SHORT_NAME_FIELD = Field(
-    min_length=1,
-    max_length=32,
-    pattern=NAME_REGEX,
-    description="A name for this entity",
-)
-
-
 class Priority(IntEnum):
     """Defines priority levels for tasks and requirements, where P0 is highest priority."""
 
@@ -156,6 +140,71 @@ class TaskOutput(KilnBaseModel):
         return self
 
 
+class FineTuneStatusType(str, Enum):
+    """
+    The status type of a fine-tune (running, completed, failed, etc).
+    """
+
+    unknown = "unknown"  # server error
+    pending = "pending"
+    running = "running"
+    completed = "completed"
+    failed = "failed"
+
+
+class Finetune(KilnParentedModel):
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the fine-tune for you and your team. Not used in training.",
+    )
+    provider: str = Field(
+        description="The provider to use for the fine-tune (e.g. 'openai')."
+    )
+    base_model_id: str = Field(
+        description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
+    )
+    provider_id: str | None = Field(
+        default=None,
+        description="The ID of the fine-tune job on the provider's side. May not be the same as the fine_tune_model_id.",
+    )
+    fine_tune_model_id: str | None = Field(
+        default=None,
+        description="The ID of the fine-tuned model on the provider's side. May not be the same as the provider_id.",
+    )
+    dataset_split_id: str = Field(
+        description="The ID of the dataset split to use for this fine-tune.",
+    )
+    train_split_name: str = Field(
+        default="train",
+        description="The name of the training split to use for this fine-tune.",
+    )
+    validation_split_name: str | None = Field(
+        default=None,
+        description="The name of the validation split to use for this fine-tune. Optional.",
+    )
+    parameters: dict[str, str | int | float | bool] = Field(
+        default={},
+        description="The parameters to use for this fine-tune. These are provider-specific.",
+    )
+    system_message: str = Field(
+        description="The system message to use for this fine-tune.",
+    )
+    latest_status: FineTuneStatusType = Field(
+        default=FineTuneStatusType.unknown,
+        description="The latest known status of this fine-tune. Not updated in real time.",
+    )
+    properties: Dict[str, str | int | float] = Field(
+        default={},
+        description="Properties of the fine-tune. Different providers may use different properties.",
+    )
+
+    def parent_task(self) -> Task | None:
+        if not isinstance(self.parent, Task):
+            return None
+        return self.parent
+
+
 class DataSourceType(str, Enum):
     """
     The source type of a piece of data.
@@ -344,6 +393,160 @@ class TaskRun(KilnParentedModel):
         return self
 
 
+# Define the type alias for clarity
+DatasetFilter = Callable[[TaskRun], bool]
+
+
+def AllDatasetFilter(_: TaskRun) -> bool:
+    return True
+
+
+def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
+    if task_run.output is None or task_run.output.rating is None:
+        return False
+    return task_run.output.rating.is_high_quality()
+
+
+class DatasetSplitDefinition(BaseModel):
+    """
+    A definition of a split in a dataset.
+
+    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    percentage: float = Field(
+        ge=0.0,
+        le=1.0,
+        description="The percentage of the dataset that this split represents (between 0 and 1).",
+    )
+
+
+AllSplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="all", percentage=1.0)
+]
+Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.8),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+]
+Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.6),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+    DatasetSplitDefinition(name="val", percentage=0.2),
+]
+
+
+class DatasetSplit(KilnParentedModel):
+    """
+    A collection of task runs, with optional splits (train, test, validation).
+
+    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
+
+    Maintains a list of IDs for each split, to avoid data duplication.
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    splits: list[DatasetSplitDefinition] = Field(
+        default_factory=list,
+        description="The splits in the dataset.",
+    )
+    split_contents: dict[str, list[str]] = Field(
+        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
+    )
+
+    @model_validator(mode="after")
+    def validate_split_percentages(self) -> "DatasetSplit":
+        total = sum(split.percentage for split in self.splits)
+        if not math.isclose(total, 1.0, rel_tol=1e-9):
+            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
+        return self
+
+    @classmethod
+    def from_task(
+        cls,
+        name: str,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter: DatasetFilter = AllDatasetFilter,
+        description: str | None = None,
+    ):
+        """
+        Build a dataset split from a task.
+        """
+        split_contents = cls.build_split_contents(task, splits, filter)
+        return cls(
+            parent=task,
+            name=name,
+            description=description,
+            splits=splits,
+            split_contents=split_contents,
+        )
+
+    @classmethod
+    def build_split_contents(
+        cls,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter: DatasetFilter,
+    ) -> dict[str, list[str]]:
+        valid_ids = []
+        for task_run in task.runs():
+            if filter(task_run):
+                valid_ids.append(task_run.id)
+
+        # Shuffle and split by split percentage
+        random.shuffle(valid_ids)
+        split_contents = {}
+        start_idx = 0
+        remaining_items = len(valid_ids)
+
+        # Handle all splits except the last one
+        for split in splits[:-1]:
+            split_size = round(len(valid_ids) * split.percentage)
+            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
+            start_idx += split_size
+            remaining_items -= split_size
+
+        # Last split gets all remaining items (for rounding)
+        if splits:
+            split_contents[splits[-1].name] = valid_ids[start_idx:]
+
+        return split_contents
+
+    def parent_task(self) -> "Task | None":
+        # inline import to avoid circular import
+        from kiln_ai.datamodel import Task
+
+        if not isinstance(self.parent, Task):
+            return None
+        return self.parent
+
+    def missing_count(self) -> int:
+        """
+        Returns:
+            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
+        """
+        parent = self.parent_task()
+        if parent is None:
+            raise ValueError("DatasetSplit has no parent task")
+
+        runs = parent.runs()
+        all_ids = set(run.id for run in runs)
+        all_ids_in_splits = set()
+        for ids in self.split_contents.values():
+            all_ids_in_splits.update(ids)
+        missing = all_ids_in_splits - all_ids
+        return len(missing)
+
+
 class TaskRequirement(BaseModel):
     """
     Defines a specific requirement that should be met by task outputs.
@@ -376,7 +579,11 @@ class TaskDeterminism(str, Enum):
 class Task(
     KilnParentedModel,
     KilnParentModel,
-    parent_of={"runs": TaskRun},
+    parent_of={
+        "runs": TaskRun,
+        "dataset_splits": DatasetSplit,
+        "finetunes": Finetune,
+    },
 ):
     """
     Represents a specific task to be performed, with associated requirements and validation rules.
@@ -416,6 +623,12 @@ class Task(
     def runs(self) -> list[TaskRun]:
         return super().runs()  # type: ignore
 
+    def dataset_splits(self) -> list[DatasetSplit]:
+        return super().dataset_splits()  # type: ignore
+
+    def finetunes(self) -> list[Finetune]:
+        return super().finetunes()  # type: ignore
+
 
 class Project(KilnParentModel, parent_of={"tasks": Task}):
     """
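
For orientation, here is a minimal sketch of how the new datamodel pieces compose. This is not from the package docs; the API names come from the diff above, and the literal values (names, model ID, system message) are illustrative only:

from kiln_ai.datamodel import (
    DatasetSplit,
    Finetune,
    HighRatingDatasetFilter,
    Train80Test20SplitDefinition,
)

# Assumes `task` is a saved kiln_ai.datamodel.Task with runs on disk.
split = DatasetSplit.from_task(
    "my_split",                      # hypothetical name
    task,
    Train80Test20SplitDefinition,    # 80% "train", 20% "test"
    filter=HighRatingDatasetFilter,  # keep only runs rated high quality
)

# A Finetune record then references the frozen split by ID.
finetune = Finetune(
    parent=task,
    name="my_finetune",                      # hypothetical name
    provider="openai",
    base_model_id="gpt-4o-mini-2024-07-18",  # provider-side model ID; illustrative
    dataset_split_id=split.id,
    system_message="You are a helpful assistant.",
)
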
kiln_ai/datamodel/basemodel.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import re
 import shutil
 import uuid
 from abc import ABCMeta
@@ -38,6 +39,34 @@ ID_TYPE = Optional[str]
 T = TypeVar("T", bound="KilnBaseModel")
 PT = TypeVar("PT", bound="KilnParentedModel")
 
+# Naming conventions:
+# 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation.
+# 2) Descrptions are for Kiln users to describe/understanding the purpose of this object. They must never be used in prompts/training/validation. Use "instruction/requirements" instead.
+
+# Filename compatible names
+NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
+NAME_FIELD = Field(
+    min_length=1,
+    max_length=120,
+    pattern=NAME_REGEX,
+    description="A name for this entity.",
+)
+SHORT_NAME_FIELD = Field(
+    min_length=1,
+    max_length=32,
+    pattern=NAME_REGEX,
+    description="A name for this entity",
+)
+
+
+def string_to_valid_name(name: str) -> str:
+    # Replace any character not allowed by NAME_REGEX with an underscore
+    valid_name = re.sub(r"[^A-Za-z0-9 _-]", "_", name)
+    # Replace consecutive underscores with a single underscore
+    valid_name = re.sub(r"_+", "_", valid_name)
+    # Remove leading and trailing underscores or whitespace
+    return valid_name.strip("_").strip()
+
 
 class KilnBaseModel(BaseModel):
     """Base model for all Kiln data models with common functionality for persistence and versioning.
@@ -97,6 +126,7 @@ class KilnBaseModel(BaseModel):
 
         Raises:
             ValueError: If the loaded model is not of the expected type or version
+            FileNotFoundError: If the file does not exist
         """
         with open(path, "r") as file:
             file_data = file.read()
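
A quick sketch of what the new string_to_valid_name helper does, with expected values taken from the new tests in test_basemodel.py further down:

from kiln_ai.datamodel.basemodel import string_to_valid_name

assert string_to_valid_name("Hello@World!") == "Hello_World"            # disallowed chars become "_"
assert string_to_valid_name("multiple!!!symbols") == "multiple_symbols" # runs of "_" collapse to one
assert string_to_valid_name("__test__") == "test"                       # leading/trailing "_" stripped
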
kiln_ai/datamodel/registry.py
ADDED
@@ -0,0 +1,31 @@
+from kiln_ai.datamodel import Project
+from kiln_ai.utils.config import Config
+
+
+def all_projects() -> list[Project]:
+    project_paths = Config.shared().projects
+    if project_paths is None:
+        return []
+    projects = []
+    for project_path in project_paths:
+        try:
+            projects.append(Project.load_from_file(project_path))
+        except Exception:
+            # deleted files are possible continue with the rest
+            continue
+    return projects
+
+
+def project_from_id(project_id: str) -> Project | None:
+    project_paths = Config.shared().projects
+    if project_paths is not None:
+        for project_path in project_paths:
+            try:
+                project = Project.load_from_file(project_path)
+                if project.id == project_id:
+                    return project
+            except Exception:
+                # deleted files are possible continue with the rest
+                continue
+
+    return None
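
The new registry module resolves projects from the paths stored in Config, skipping files that can no longer be loaded. A minimal usage sketch (the project ID below is hypothetical):

from kiln_ai.datamodel.registry import all_projects, project_from_id

for project in all_projects():
    print(project.id, project.name)

# Returns None when no loadable project matches the ID.
project = project_from_id("example-project-id")
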
kiln_ai/datamodel/test_basemodel.py
CHANGED
@@ -5,7 +5,11 @@ from typing import Optional
 
 import pytest
 
-from kiln_ai.datamodel.basemodel import KilnBaseModel, KilnParentedModel
+from kiln_ai.datamodel.basemodel import (
+    KilnBaseModel,
+    KilnParentedModel,
+    string_to_valid_name,
+)
 
 
 @pytest.fixture
@@ -306,3 +310,27 @@ def test_delete_no_path():
     model = KilnBaseModel()
     with pytest.raises(ValueError, match="Cannot delete model because path is not set"):
         model.delete()
+
+
+def test_string_to_valid_name():
+    # Test basic valid strings remain unchanged
+    assert string_to_valid_name("Hello World") == "Hello World"
+    assert string_to_valid_name("Test-123") == "Test-123"
+    assert string_to_valid_name("my_file_name") == "my_file_name"
+
+    # Test invalid characters are replaced
+    assert string_to_valid_name("Hello@World!") == "Hello_World"
+    assert string_to_valid_name("File.name.txt") == "File_name_txt"
+    assert string_to_valid_name("Special#$%Chars") == "Special_Chars"
+
+    # Test consecutive invalid characters
+    assert string_to_valid_name("multiple!!!symbols") == "multiple_symbols"
+    assert string_to_valid_name("path/to/file") == "path_to_file"
+
+    # Test leading/trailing special characters
+    assert string_to_valid_name("__test__") == "test"
+    assert string_to_valid_name("...test...") == "test"
+
+    # Test empty string and whitespace
+    assert string_to_valid_name("") == ""
+    assert string_to_valid_name(" ") == ""
kiln_ai/datamodel/test_dataset_split.py
ADDED
@@ -0,0 +1,234 @@
+import pytest
+from pydantic import ValidationError
+
+# import datamodel first or we get circular import errors
+from kiln_ai.datamodel import (
+    AllDatasetFilter,
+    AllSplitDefinition,
+    DatasetSplit,
+    DatasetSplitDefinition,
+    DataSource,
+    DataSourceType,
+    HighRatingDatasetFilter,
+    Task,
+    TaskOutput,
+    TaskOutputRating,
+    TaskOutputRatingType,
+    TaskRun,
+    Train60Test20Val20SplitDefinition,
+    Train80Test20SplitDefinition,
+)
+
+
+@pytest.fixture
+def sample_task(tmp_path):
+    task_path = tmp_path / "task.kiln"
+    task = Task(
+        name="Test Task",
+        path=task_path,
+        description="Test task for dataset splitting",
+        instruction="Test instruction",
+    )
+    task.save_to_file()
+    return task
+
+
+@pytest.fixture
+def sample_task_runs(sample_task):
+    # Create 10 task runs with different ratings
+    task_runs = []
+    for i in range(10):
+        rating = 5 if i < 6 else 1  # 6 high, 4 low ratings
+        task_run = TaskRun(
+            parent=sample_task,
+            input=f"input_{i}",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "test-user"},
+            ),
+            output=TaskOutput(
+                output=f"output_{i}",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "test-user"},
+                ),
+                rating=TaskOutputRating(
+                    value=rating, type=TaskOutputRatingType.five_star
+                ),
+            ),
+        )
+        task_run.save_to_file()
+        task_runs.append(task_run)
+    return task_runs
+
+
+@pytest.fixture
+def standard_splits():
+    return [
+        DatasetSplitDefinition(name="train", percentage=0.8),
+        DatasetSplitDefinition(name="test", percentage=0.2),
+    ]
+
+
+@pytest.fixture
+def task_run():
+    return TaskRun(
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "test-user"},
+        ),
+        output=TaskOutput(
+            output="test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "test-user"},
+            ),
+            rating=TaskOutputRating(rating=5, type=TaskOutputRatingType.five_star),
+        ),
+    )
+
+
+def test_dataset_split_definition():
+    split = DatasetSplitDefinition(name="train", percentage=0.8)
+    assert split.name == "train"
+    assert split.percentage == 0.8
+    assert split.description is None
+
+    # Test validation
+    with pytest.raises(ValidationError):
+        DatasetSplitDefinition(name="train", percentage=1.5)
+
+
+def test_dataset_split_validation():
+    # Test valid percentages
+    splits = [
+        DatasetSplitDefinition(name="train", percentage=0.8),
+        DatasetSplitDefinition(name="test", percentage=0.2),
+    ]
+    dataset = DatasetSplit(
+        name="test_split",
+        splits=splits,
+        split_contents={"train": [], "test": []},
+    )
+    assert dataset.splits == splits
+
+    # Test invalid percentages
+    invalid_splits = [
+        DatasetSplitDefinition(name="train", percentage=0.8),
+        DatasetSplitDefinition(name="test", percentage=0.3),
+    ]
+    with pytest.raises(ValueError, match="sum of split percentages must be 1.0"):
+        DatasetSplit(
+            name="test_split",
+            splits=invalid_splits,
+            split_contents={"train": [], "test": []},
+        )
+
+
+def test_all_dataset_filter(task_run):
+    assert AllDatasetFilter(task_run) is True
+
+
+def test_high_rating_dataset_filter(sample_task_runs):
+    for task_run in sample_task_runs:
+        assert HighRatingDatasetFilter(task_run) is (
+            task_run.output.rating.is_high_quality()
+        )
+
+
+@pytest.mark.parametrize(
+    "splits,expected_sizes",
+    [
+        (Train80Test20SplitDefinition, {"train": 8, "test": 2}),
+        (AllSplitDefinition, {"all": 10}),
+        (Train60Test20Val20SplitDefinition, {"train": 6, "test": 2, "val": 2}),
+        (
+            [
+                DatasetSplitDefinition(name="train", percentage=0.7),
+                DatasetSplitDefinition(name="validation", percentage=0.2),
+                DatasetSplitDefinition(name="test", percentage=0.1),
+            ],
+            {"train": 7, "validation": 2, "test": 1},
+        ),
+    ],
+)
+def test_dataset_split_from_task(sample_task, sample_task_runs, splits, expected_sizes):
+    assert sample_task_runs is not None
+    dataset = DatasetSplit.from_task("Split Name", sample_task, splits)
+    assert dataset.name == "Split Name"
+
+    # Check split sizes match expected
+    for split_name, expected_size in expected_sizes.items():
+        assert len(dataset.split_contents[split_name]) == expected_size
+
+    # Verify total size matches input size
+    total_size = sum(len(ids) for ids in dataset.split_contents.values())
+    assert total_size == len(sample_task_runs)
+
+
+def test_dataset_split_with_high_rating_filter(sample_task, sample_task_runs):
+    assert len(sample_task_runs) == 10
+    dataset = DatasetSplit.from_task(
+        "Split Name",
+        sample_task,
+        Train80Test20SplitDefinition,
+        filter=HighRatingDatasetFilter,
+    )
+
+    # Check that only high-rated task runs are included
+    all_ids = []
+    for ids in dataset.split_contents.values():
+        all_ids.extend(ids)
+    assert len(all_ids) == 6  # We created 6 high-rated task runs
+
+    # Check split proportions
+    train_size = len(dataset.split_contents["train"])
+    test_size = len(dataset.split_contents["test"])
+    assert train_size == 5  # ~80% of 6
+    assert test_size == 1  # ~20% of 6
+
+
+def test_dataset_split_with_single_split(sample_task, sample_task_runs):
+    splits = [DatasetSplitDefinition(name="all", percentage=1.0)]
+    dataset = DatasetSplit.from_task("Split Name", sample_task, splits)
+
+    assert len(dataset.split_contents["all"]) == len(sample_task_runs)
+
+
+def test_missing_count(sample_task, sample_task_runs):
+    assert sample_task_runs is not None
+    # Create a dataset split with all task runs
+    dataset = DatasetSplit.from_task(
+        "Split Name", sample_task, Train80Test20SplitDefinition
+    )
+
+    # Initially there should be no missing runs
+    assert dataset.missing_count() == 0
+
+    # Add some IDs to the split, that aren't on disk
+    dataset.split_contents["test"].append("1")
+    dataset.split_contents["test"].append("2")
+    dataset.split_contents["test"].append("3")
+    # shouldn't happen, but should not double count if it does
+    dataset.split_contents["train"].append("3")
+
+    # Now we should have 3 missing runs
+    assert dataset.missing_count() == 3
+
+
+def test_smaller_sample(sample_task, sample_task_runs):
+    assert sample_task_runs is not None
+    # Create a dataset split with all task runs
+    dataset = DatasetSplit.from_task(
+        "Split Name", sample_task, Train80Test20SplitDefinition
+    )
+
+    # Initially there should be no missing runs
+    assert dataset.missing_count() == 0
+
+    dataset.split_contents["test"].pop()
+    dataset.split_contents["train"].pop()
+
+    # Now we should have 0 missing runs. It's okay that dataset has newer data.
+    assert dataset.missing_count() == 0
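
The 5/1 expectation in test_dataset_split_with_high_rating_filter above follows from build_split_contents' rounding rule: each split except the last takes round(n * percentage) items, and the last split takes whatever remains. A worked check:

n = 6                   # high-rated runs left after filtering
train = round(n * 0.8)  # round(4.8) == 5
test = n - train        # the last split absorbs the remainder: 1
assert (train, test) == (5, 1)
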
kiln_ai/datamodel/test_models.py
CHANGED
@@ -5,8 +5,10 @@ import pytest
 from pydantic import ValidationError
 
 from kiln_ai.datamodel import (
+    DatasetSplit,
     DataSource,
     DataSourceType,
+    Finetune,
     Project,
     Task,
     TaskDeterminism,
@@ -97,6 +99,16 @@ def test_task_run_relationship(valid_task_run):
     assert valid_task_run.__class__.parent_type().__name__ == "Task"
 
 
+def test_dataset_split_relationship():
+    assert DatasetSplit.relationship_name() == "dataset_splits"
+    assert DatasetSplit.parent_type().__name__ == "Task"
+
+
+def test_base_finetune_relationship():
+    assert Finetune.relationship_name() == "finetunes"
+    assert Finetune.parent_type().__name__ == "Task"
+
+
 def test_structured_output_workflow(tmp_path):
     tmp_project_file = (
         tmp_path / "test_structured_output_runs" / Project.base_filename()