kiln-ai 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (40)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +218 -304
  16. kiln_ai/adapters/ollama_tools.py +114 -0
  17. kiln_ai/adapters/provider_tools.py +295 -0
  18. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  19. kiln_ai/adapters/test_langchain_adapter.py +46 -18
  20. kiln_ai/adapters/test_ollama_tools.py +42 -0
  21. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  22. kiln_ai/adapters/test_provider_tools.py +312 -0
  23. kiln_ai/adapters/test_structured_output.py +22 -43
  24. kiln_ai/datamodel/__init__.py +235 -22
  25. kiln_ai/datamodel/basemodel.py +30 -0
  26. kiln_ai/datamodel/registry.py +31 -0
  27. kiln_ai/datamodel/test_basemodel.py +29 -1
  28. kiln_ai/datamodel/test_dataset_split.py +234 -0
  29. kiln_ai/datamodel/test_example_models.py +12 -0
  30. kiln_ai/datamodel/test_models.py +91 -1
  31. kiln_ai/datamodel/test_registry.py +96 -0
  32. kiln_ai/utils/config.py +9 -0
  33. kiln_ai/utils/name_generator.py +125 -0
  34. kiln_ai/utils/test_name_generator.py +47 -0
  35. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
  36. kiln_ai-0.7.0.dist-info/RECORD +56 -0
  37. kiln_ai/adapters/test_ml_model_list.py +0 -181
  38. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  39. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/__init__.py

@@ -1,8 +1,10 @@
  from __future__ import annotations

  import json
+ import math
+ import random
  from enum import Enum, IntEnum
- from typing import TYPE_CHECKING, Dict, List, Type, Union
+ from typing import TYPE_CHECKING, Callable, Dict, List, Type, Union

  import jsonschema
  import jsonschema.exceptions
@@ -14,6 +16,8 @@ from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
  from .basemodel import (
      ID_FIELD,
      ID_TYPE,
+     NAME_FIELD,
+     SHORT_NAME_FIELD,
      KilnBaseModel,
      KilnParentedModel,
      KilnParentModel,
@@ -42,26 +46,6 @@ __all__ = [
  ]


- # Conventions:
- # 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation.
- # 2) Descrptions are for Kiln users to describe/understanding the purpose of this object. They must never be used in prompts/training/validation. Use "instruction/requirements" instead.
-
- # Filename compatible names
- NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
- NAME_FIELD = Field(
-     min_length=1,
-     max_length=120,
-     pattern=NAME_REGEX,
-     description="A name for this entity.",
- )
- SHORT_NAME_FIELD = Field(
-     min_length=1,
-     max_length=32,
-     pattern=NAME_REGEX,
-     description="A name for this entity",
- )
-
-
  class Priority(IntEnum):
      """Defines priority levels for tasks and requirements, where P0 is highest priority."""

@@ -156,6 +140,71 @@ class TaskOutput(KilnBaseModel):
          return self


+ class FineTuneStatusType(str, Enum):
+     """
+     The status type of a fine-tune (running, completed, failed, etc).
+     """
+
+     unknown = "unknown"  # server error
+     pending = "pending"
+     running = "running"
+     completed = "completed"
+     failed = "failed"
+
+
+ class Finetune(KilnParentedModel):
+     name: str = NAME_FIELD
+     description: str | None = Field(
+         default=None,
+         description="A description of the fine-tune for you and your team. Not used in training.",
+     )
+     provider: str = Field(
+         description="The provider to use for the fine-tune (e.g. 'openai')."
+     )
+     base_model_id: str = Field(
+         description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
+     )
+     provider_id: str | None = Field(
+         default=None,
+         description="The ID of the fine-tune job on the provider's side. May not be the same as the fine_tune_model_id.",
+     )
+     fine_tune_model_id: str | None = Field(
+         default=None,
+         description="The ID of the fine-tuned model on the provider's side. May not be the same as the provider_id.",
+     )
+     dataset_split_id: str = Field(
+         description="The ID of the dataset split to use for this fine-tune.",
+     )
+     train_split_name: str = Field(
+         default="train",
+         description="The name of the training split to use for this fine-tune.",
+     )
+     validation_split_name: str | None = Field(
+         default=None,
+         description="The name of the validation split to use for this fine-tune. Optional.",
+     )
+     parameters: dict[str, str | int | float | bool] = Field(
+         default={},
+         description="The parameters to use for this fine-tune. These are provider-specific.",
+     )
+     system_message: str = Field(
+         description="The system message to use for this fine-tune.",
+     )
+     latest_status: FineTuneStatusType = Field(
+         default=FineTuneStatusType.unknown,
+         description="The latest known status of this fine-tune. Not updated in real time.",
+     )
+     properties: Dict[str, str | int | float] = Field(
+         default={},
+         description="Properties of the fine-tune. Different providers may use different properties.",
+     )
+
+     def parent_task(self) -> Task | None:
+         if not isinstance(self.parent, Task):
+             return None
+         return self.parent
+
+
  class DataSourceType(str, Enum):
      """
      The source type of a piece of data.
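For orientation, here is a minimal sketch of how the new Finetune model might be attached to a task. The task path, provider string, base model ID, and dataset split ID are illustrative placeholders, not values taken from this release.

    from kiln_ai.datamodel import Finetune, FineTuneStatusType, Task

    # Hypothetical task file path, for illustration only.
    task = Task.load_from_file("/path/to/project/tasks/my_task/task.kiln")

    finetune = Finetune(
        parent=task,  # stored as a child record of the task
        name="my first finetune",
        provider="openai",  # provider-specific identifier
        base_model_id="example-base-model-id",  # the provider's model ID, not a Kiln ID
        dataset_split_id="example-dataset-split-id",
        system_message="You are a helpful assistant.",
    )
    finetune.save_to_file()

    # latest_status defaults to FineTuneStatusType.unknown until a provider reports progress.
    assert finetune.latest_status == FineTuneStatusType.unknown

The status field is presumably updated by the new adapters under kiln_ai/adapters/fine_tune rather than set by hand.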
@@ -344,6 +393,160 @@ class TaskRun(KilnParentedModel):
          return self


+ # Define the type alias for clarity
+ DatasetFilter = Callable[[TaskRun], bool]
+
+
+ def AllDatasetFilter(_: TaskRun) -> bool:
+     return True
+
+
+ def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
+     if task_run.output is None or task_run.output.rating is None:
+         return False
+     return task_run.output.rating.is_high_quality()
+
+
+ class DatasetSplitDefinition(BaseModel):
+     """
+     A definition of a split in a dataset.
+
+     Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
+     """
+
+     name: str = NAME_FIELD
+     description: str | None = Field(
+         default=None,
+         description="A description of the dataset for you and your team. Not used in training.",
+     )
+     percentage: float = Field(
+         ge=0.0,
+         le=1.0,
+         description="The percentage of the dataset that this split represents (between 0 and 1).",
+     )
+
+
+ AllSplitDefinition: list[DatasetSplitDefinition] = [
+     DatasetSplitDefinition(name="all", percentage=1.0)
+ ]
+ Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
+     DatasetSplitDefinition(name="train", percentage=0.8),
+     DatasetSplitDefinition(name="test", percentage=0.2),
+ ]
+ Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
+     DatasetSplitDefinition(name="train", percentage=0.6),
+     DatasetSplitDefinition(name="test", percentage=0.2),
+     DatasetSplitDefinition(name="val", percentage=0.2),
+ ]
+
+
+ class DatasetSplit(KilnParentedModel):
+     """
+     A collection of task runs, with optional splits (train, test, validation).
+
+     Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
+
+     Maintains a list of IDs for each split, to avoid data duplication.
+     """
+
+     name: str = NAME_FIELD
+     description: str | None = Field(
+         default=None,
+         description="A description of the dataset for you and your team. Not used in training.",
+     )
+     splits: list[DatasetSplitDefinition] = Field(
+         default_factory=list,
+         description="The splits in the dataset.",
+     )
+     split_contents: dict[str, list[str]] = Field(
+         description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
+     )
+
+     @model_validator(mode="after")
+     def validate_split_percentages(self) -> "DatasetSplit":
+         total = sum(split.percentage for split in self.splits)
+         if not math.isclose(total, 1.0, rel_tol=1e-9):
+             raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
+         return self
+
+     @classmethod
+     def from_task(
+         cls,
+         name: str,
+         task: "Task",
+         splits: list[DatasetSplitDefinition],
+         filter: DatasetFilter = AllDatasetFilter,
+         description: str | None = None,
+     ):
+         """
+         Build a dataset split from a task.
+         """
+         split_contents = cls.build_split_contents(task, splits, filter)
+         return cls(
+             parent=task,
+             name=name,
+             description=description,
+             splits=splits,
+             split_contents=split_contents,
+         )
+
+     @classmethod
+     def build_split_contents(
+         cls,
+         task: "Task",
+         splits: list[DatasetSplitDefinition],
+         filter: DatasetFilter,
+     ) -> dict[str, list[str]]:
+         valid_ids = []
+         for task_run in task.runs():
+             if filter(task_run):
+                 valid_ids.append(task_run.id)
+
+         # Shuffle and split by split percentage
+         random.shuffle(valid_ids)
+         split_contents = {}
+         start_idx = 0
+         remaining_items = len(valid_ids)
+
+         # Handle all splits except the last one
+         for split in splits[:-1]:
+             split_size = round(len(valid_ids) * split.percentage)
+             split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
+             start_idx += split_size
+             remaining_items -= split_size
+
+         # Last split gets all remaining items (for rounding)
+         if splits:
+             split_contents[splits[-1].name] = valid_ids[start_idx:]
+
+         return split_contents
+
+     def parent_task(self) -> "Task | None":
+         # inline import to avoid circular import
+         from kiln_ai.datamodel import Task
+
+         if not isinstance(self.parent, Task):
+             return None
+         return self.parent
+
+     def missing_count(self) -> int:
+         """
+         Returns:
+             int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
+         """
+         parent = self.parent_task()
+         if parent is None:
+             raise ValueError("DatasetSplit has no parent task")
+
+         runs = parent.runs()
+         all_ids = set(run.id for run in runs)
+         all_ids_in_splits = set()
+         for ids in self.split_contents.values():
+             all_ids_in_splits.update(ids)
+         missing = all_ids_in_splits - all_ids
+         return len(missing)
+
+
  class TaskRequirement(BaseModel):
      """
      Defines a specific requirement that should be met by task outputs.
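Similarly, a hedged sketch of freezing a dataset split with the new DatasetSplit.from_task API and the HighRatingDatasetFilter; the task path and split name are placeholders.

    from kiln_ai.datamodel import (
        DatasetSplit,
        HighRatingDatasetFilter,
        Task,
        Train80Test20SplitDefinition,
    )

    # Hypothetical task file path, for illustration only.
    task = Task.load_from_file("/path/to/project/tasks/my_task/task.kiln")

    # Freeze an 80/20 train/test split over only the highly rated runs.
    split = DatasetSplit.from_task(
        "high rating 80-20",
        task,
        Train80Test20SplitDefinition,
        filter=HighRatingDatasetFilter,
    )
    split.save_to_file()

    # split_contents maps each split name to a list of task run IDs.
    print({name: len(ids) for name, ids in split.split_contents.items()})

Note that build_split_contents shuffles the eligible run IDs with random.shuffle, so split membership varies between invocations unless the caller seeds Python's random module first.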
@@ -376,7 +579,11 @@ class TaskDeterminism(str, Enum):
  class Task(
      KilnParentedModel,
      KilnParentModel,
-     parent_of={"runs": TaskRun},
+     parent_of={
+         "runs": TaskRun,
+         "dataset_splits": DatasetSplit,
+         "finetunes": Finetune,
+     },
  ):
      """
      Represents a specific task to be performed, with associated requirements and validation rules.
@@ -416,6 +623,12 @@ class Task(
      def runs(self) -> list[TaskRun]:
          return super().runs()  # type: ignore

+     def dataset_splits(self) -> list[DatasetSplit]:
+         return super().dataset_splits()  # type: ignore
+
+     def finetunes(self) -> list[Finetune]:
+         return super().finetunes()  # type: ignore
+

  class Project(KilnParentModel, parent_of={"tasks": Task}):
      """
kiln_ai/datamodel/basemodel.py

@@ -1,4 +1,5 @@
  import json
+ import re
  import shutil
  import uuid
  from abc import ABCMeta
@@ -38,6 +39,34 @@ ID_TYPE = Optional[str]
  T = TypeVar("T", bound="KilnBaseModel")
  PT = TypeVar("PT", bound="KilnParentedModel")

+ # Naming conventions:
+ # 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation.
+ # 2) Descrptions are for Kiln users to describe/understanding the purpose of this object. They must never be used in prompts/training/validation. Use "instruction/requirements" instead.
+
+ # Filename compatible names
+ NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
+ NAME_FIELD = Field(
+     min_length=1,
+     max_length=120,
+     pattern=NAME_REGEX,
+     description="A name for this entity.",
+ )
+ SHORT_NAME_FIELD = Field(
+     min_length=1,
+     max_length=32,
+     pattern=NAME_REGEX,
+     description="A name for this entity",
+ )
+
+
+ def string_to_valid_name(name: str) -> str:
+     # Replace any character not allowed by NAME_REGEX with an underscore
+     valid_name = re.sub(r"[^A-Za-z0-9 _-]", "_", name)
+     # Replace consecutive underscores with a single underscore
+     valid_name = re.sub(r"_+", "_", valid_name)
+     # Remove leading and trailing underscores or whitespace
+     return valid_name.strip("_").strip()
+

  class KilnBaseModel(BaseModel):
      """Base model for all Kiln data models with common functionality for persistence and versioning.
@@ -97,6 +126,7 @@ class KilnBaseModel(BaseModel):

          Raises:
              ValueError: If the loaded model is not of the expected type or version
+             FileNotFoundError: If the file does not exist
          """
          with open(path, "r") as file:
              file_data = file.read()
kiln_ai/datamodel/registry.py

@@ -0,0 +1,31 @@
+ from kiln_ai.datamodel import Project
+ from kiln_ai.utils.config import Config
+
+
+ def all_projects() -> list[Project]:
+     project_paths = Config.shared().projects
+     if project_paths is None:
+         return []
+     projects = []
+     for project_path in project_paths:
+         try:
+             projects.append(Project.load_from_file(project_path))
+         except Exception:
+             # deleted files are possible continue with the rest
+             continue
+     return projects
+
+
+ def project_from_id(project_id: str) -> Project | None:
+     project_paths = Config.shared().projects
+     if project_paths is not None:
+         for project_path in project_paths:
+             try:
+                 project = Project.load_from_file(project_path)
+                 if project.id == project_id:
+                     return project
+             except Exception:
+                 # deleted files are possible continue with the rest
+                 continue
+
+     return None
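A short usage sketch for the new registry helpers; the project ID below is a placeholder.

    from kiln_ai.datamodel.registry import all_projects, project_from_id

    # Iterate every project registered in the local Kiln config.
    for project in all_projects():
        print(project.id, project.name)

    # Look up a single project by ID; returns None when no registered project matches.
    project = project_from_id("example-project-id")
    if project is None:
        print("No project with that ID is registered")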
kiln_ai/datamodel/test_basemodel.py

@@ -5,7 +5,11 @@ from typing import Optional

  import pytest

- from kiln_ai.datamodel.basemodel import KilnBaseModel, KilnParentedModel
+ from kiln_ai.datamodel.basemodel import (
+     KilnBaseModel,
+     KilnParentedModel,
+     string_to_valid_name,
+ )


  @pytest.fixture
@@ -306,3 +310,27 @@ def test_delete_no_path():
      model = KilnBaseModel()
      with pytest.raises(ValueError, match="Cannot delete model because path is not set"):
          model.delete()
+
+
+ def test_string_to_valid_name():
+     # Test basic valid strings remain unchanged
+     assert string_to_valid_name("Hello World") == "Hello World"
+     assert string_to_valid_name("Test-123") == "Test-123"
+     assert string_to_valid_name("my_file_name") == "my_file_name"
+
+     # Test invalid characters are replaced
+     assert string_to_valid_name("Hello@World!") == "Hello_World"
+     assert string_to_valid_name("File.name.txt") == "File_name_txt"
+     assert string_to_valid_name("Special#$%Chars") == "Special_Chars"
+
+     # Test consecutive invalid characters
+     assert string_to_valid_name("multiple!!!symbols") == "multiple_symbols"
+     assert string_to_valid_name("path/to/file") == "path_to_file"
+
+     # Test leading/trailing special characters
+     assert string_to_valid_name("__test__") == "test"
+     assert string_to_valid_name("...test...") == "test"
+
+     # Test empty string and whitespace
+     assert string_to_valid_name("") == ""
+     assert string_to_valid_name(" ") == ""
kiln_ai/datamodel/test_dataset_split.py

@@ -0,0 +1,234 @@
+ import pytest
+ from pydantic import ValidationError
+
+ # import datamodel first or we get circular import errors
+ from kiln_ai.datamodel import (
+     AllDatasetFilter,
+     AllSplitDefinition,
+     DatasetSplit,
+     DatasetSplitDefinition,
+     DataSource,
+     DataSourceType,
+     HighRatingDatasetFilter,
+     Task,
+     TaskOutput,
+     TaskOutputRating,
+     TaskOutputRatingType,
+     TaskRun,
+     Train60Test20Val20SplitDefinition,
+     Train80Test20SplitDefinition,
+ )
+
+
+ @pytest.fixture
+ def sample_task(tmp_path):
+     task_path = tmp_path / "task.kiln"
+     task = Task(
+         name="Test Task",
+         path=task_path,
+         description="Test task for dataset splitting",
+         instruction="Test instruction",
+     )
+     task.save_to_file()
+     return task
+
+
+ @pytest.fixture
+ def sample_task_runs(sample_task):
+     # Create 10 task runs with different ratings
+     task_runs = []
+     for i in range(10):
+         rating = 5 if i < 6 else 1  # 6 high, 4 low ratings
+         task_run = TaskRun(
+             parent=sample_task,
+             input=f"input_{i}",
+             input_source=DataSource(
+                 type=DataSourceType.human,
+                 properties={"created_by": "test-user"},
+             ),
+             output=TaskOutput(
+                 output=f"output_{i}",
+                 source=DataSource(
+                     type=DataSourceType.human,
+                     properties={"created_by": "test-user"},
+                 ),
+                 rating=TaskOutputRating(
+                     value=rating, type=TaskOutputRatingType.five_star
+                 ),
+             ),
+         )
+         task_run.save_to_file()
+         task_runs.append(task_run)
+     return task_runs
+
+
+ @pytest.fixture
+ def standard_splits():
+     return [
+         DatasetSplitDefinition(name="train", percentage=0.8),
+         DatasetSplitDefinition(name="test", percentage=0.2),
+     ]
+
+
+ @pytest.fixture
+ def task_run():
+     return TaskRun(
+         input="test input",
+         input_source=DataSource(
+             type=DataSourceType.human,
+             properties={"created_by": "test-user"},
+         ),
+         output=TaskOutput(
+             output="test output",
+             source=DataSource(
+                 type=DataSourceType.human,
+                 properties={"created_by": "test-user"},
+             ),
+             rating=TaskOutputRating(rating=5, type=TaskOutputRatingType.five_star),
+         ),
+     )
+
+
+ def test_dataset_split_definition():
+     split = DatasetSplitDefinition(name="train", percentage=0.8)
+     assert split.name == "train"
+     assert split.percentage == 0.8
+     assert split.description is None
+
+     # Test validation
+     with pytest.raises(ValidationError):
+         DatasetSplitDefinition(name="train", percentage=1.5)
+
+
+ def test_dataset_split_validation():
+     # Test valid percentages
+     splits = [
+         DatasetSplitDefinition(name="train", percentage=0.8),
+         DatasetSplitDefinition(name="test", percentage=0.2),
+     ]
+     dataset = DatasetSplit(
+         name="test_split",
+         splits=splits,
+         split_contents={"train": [], "test": []},
+     )
+     assert dataset.splits == splits
+
+     # Test invalid percentages
+     invalid_splits = [
+         DatasetSplitDefinition(name="train", percentage=0.8),
+         DatasetSplitDefinition(name="test", percentage=0.3),
+     ]
+     with pytest.raises(ValueError, match="sum of split percentages must be 1.0"):
+         DatasetSplit(
+             name="test_split",
+             splits=invalid_splits,
+             split_contents={"train": [], "test": []},
+         )
+
+
+ def test_all_dataset_filter(task_run):
+     assert AllDatasetFilter(task_run) is True
+
+
+ def test_high_rating_dataset_filter(sample_task_runs):
+     for task_run in sample_task_runs:
+         assert HighRatingDatasetFilter(task_run) is (
+             task_run.output.rating.is_high_quality()
+         )
+
+
+ @pytest.mark.parametrize(
+     "splits,expected_sizes",
+     [
+         (Train80Test20SplitDefinition, {"train": 8, "test": 2}),
+         (AllSplitDefinition, {"all": 10}),
+         (Train60Test20Val20SplitDefinition, {"train": 6, "test": 2, "val": 2}),
+         (
+             [
+                 DatasetSplitDefinition(name="train", percentage=0.7),
+                 DatasetSplitDefinition(name="validation", percentage=0.2),
+                 DatasetSplitDefinition(name="test", percentage=0.1),
+             ],
+             {"train": 7, "validation": 2, "test": 1},
+         ),
+     ],
+ )
+ def test_dataset_split_from_task(sample_task, sample_task_runs, splits, expected_sizes):
+     assert sample_task_runs is not None
+     dataset = DatasetSplit.from_task("Split Name", sample_task, splits)
+     assert dataset.name == "Split Name"
+
+     # Check split sizes match expected
+     for split_name, expected_size in expected_sizes.items():
+         assert len(dataset.split_contents[split_name]) == expected_size
+
+     # Verify total size matches input size
+     total_size = sum(len(ids) for ids in dataset.split_contents.values())
+     assert total_size == len(sample_task_runs)
+
+
+ def test_dataset_split_with_high_rating_filter(sample_task, sample_task_runs):
+     assert len(sample_task_runs) == 10
+     dataset = DatasetSplit.from_task(
+         "Split Name",
+         sample_task,
+         Train80Test20SplitDefinition,
+         filter=HighRatingDatasetFilter,
+     )
+
+     # Check that only high-rated task runs are included
+     all_ids = []
+     for ids in dataset.split_contents.values():
+         all_ids.extend(ids)
+     assert len(all_ids) == 6  # We created 6 high-rated task runs
+
+     # Check split proportions
+     train_size = len(dataset.split_contents["train"])
+     test_size = len(dataset.split_contents["test"])
+     assert train_size == 5  # ~80% of 6
+     assert test_size == 1  # ~20% of 6
+
+
+ def test_dataset_split_with_single_split(sample_task, sample_task_runs):
+     splits = [DatasetSplitDefinition(name="all", percentage=1.0)]
+     dataset = DatasetSplit.from_task("Split Name", sample_task, splits)
+
+     assert len(dataset.split_contents["all"]) == len(sample_task_runs)
+
+
+ def test_missing_count(sample_task, sample_task_runs):
+     assert sample_task_runs is not None
+     # Create a dataset split with all task runs
+     dataset = DatasetSplit.from_task(
+         "Split Name", sample_task, Train80Test20SplitDefinition
+     )
+
+     # Initially there should be no missing runs
+     assert dataset.missing_count() == 0
+
+     # Add some IDs to the split, that aren't on disk
+     dataset.split_contents["test"].append("1")
+     dataset.split_contents["test"].append("2")
+     dataset.split_contents["test"].append("3")
+     # shouldn't happen, but should not double count if it does
+     dataset.split_contents["train"].append("3")
+
+     # Now we should have 3 missing runs
+     assert dataset.missing_count() == 3
+
+
+ def test_smaller_sample(sample_task, sample_task_runs):
+     assert sample_task_runs is not None
+     # Create a dataset split with all task runs
+     dataset = DatasetSplit.from_task(
+         "Split Name", sample_task, Train80Test20SplitDefinition
+     )
+
+     # Initially there should be no missing runs
+     assert dataset.missing_count() == 0
+
+     dataset.split_contents["test"].pop()
+     dataset.split_contents["train"].pop()
+
+     # Now we should have 0 missing runs. It's okay that dataset has newer data.
+     assert dataset.missing_count() == 0
kiln_ai/datamodel/test_models.py

@@ -5,8 +5,10 @@ import pytest
  from pydantic import ValidationError

  from kiln_ai.datamodel import (
+     DatasetSplit,
      DataSource,
      DataSourceType,
+     Finetune,
      Project,
      Task,
      TaskDeterminism,
@@ -97,6 +99,16 @@ def test_task_run_relationship(valid_task_run):
      assert valid_task_run.__class__.parent_type().__name__ == "Task"


+ def test_dataset_split_relationship():
+     assert DatasetSplit.relationship_name() == "dataset_splits"
+     assert DatasetSplit.parent_type().__name__ == "Task"
+
+
+ def test_base_finetune_relationship():
+     assert Finetune.relationship_name() == "finetunes"
+     assert Finetune.parent_type().__name__ == "Task"
+
+
  def test_structured_output_workflow(tmp_path):
      tmp_project_file = (
          tmp_path / "test_structured_output_runs" / Project.base_filename()