kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (44):
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +239 -303
  16. kiln_ai/adapters/ollama_tools.py +115 -0
  17. kiln_ai/adapters/provider_tools.py +308 -0
  18. kiln_ai/adapters/repair/repair_task.py +4 -2
  19. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  20. kiln_ai/adapters/test_langchain_adapter.py +229 -18
  21. kiln_ai/adapters/test_ollama_tools.py +42 -0
  22. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  23. kiln_ai/adapters/test_provider_tools.py +531 -0
  24. kiln_ai/adapters/test_structured_output.py +22 -43
  25. kiln_ai/datamodel/__init__.py +287 -24
  26. kiln_ai/datamodel/basemodel.py +122 -38
  27. kiln_ai/datamodel/model_cache.py +116 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +167 -4
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_model_cache.py +244 -0
  33. kiln_ai/datamodel/test_models.py +215 -1
  34. kiln_ai/datamodel/test_registry.py +96 -0
  35. kiln_ai/utils/config.py +14 -1
  36. kiln_ai/utils/name_generator.py +125 -0
  37. kiln_ai/utils/test_name_geneator.py +47 -0
  38. kiln_ai-0.7.1.dist-info/METADATA +237 -0
  39. kiln_ai-0.7.1.dist-info/RECORD +58 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
  41. kiln_ai/adapters/test_ml_model_list.py +0 -181
  42. kiln_ai-0.6.1.dist-info/METADATA +0 -88
  43. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  44. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/model_cache.py
@@ -0,0 +1,116 @@
+"""
+A simple cache for our datamodel.
+
+Works at the file level, caching the pydantic model based on the file path.
+
+Keeping this really simple. Our goal is to really be "disk-backed" data model, so using disk primitives.
+
+ - Use disk mtime to determine if the cached model is stale.
+ - Still using glob for iterating over projects, just caching at the file level
+ - Use path as the cache key
+ - Cache always populated from a disk read, so we know it refects what's on disk. Even if we had a memory-constructed version, we don't cache that.
+ - Cache the parsed model, not the raw file contents. Parsing and validating is what's expensive. >99% speedup when measured.
+"""
+
+import os
+import sys
+import warnings
+from pathlib import Path
+from typing import Dict, Optional, Tuple, Type, TypeVar
+
+from pydantic import BaseModel
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class ModelCache:
+    _shared_instance = None
+
+    def __init__(self):
+        # Store both the model and the modified time of the cached file contents
+        self.model_cache: Dict[Path, Tuple[BaseModel, int]] = {}
+        self._enabled = self._check_timestamp_granularity()
+        if not self._enabled:
+            warnings.warn(
+                "File system does not support fine-grained timestamps. "
+                "Model caching has been disabled to ensure consistency."
+            )
+
+    @classmethod
+    def shared(cls):
+        if cls._shared_instance is None:
+            cls._shared_instance = cls()
+        return cls._shared_instance
+
+    def _is_cache_valid(self, path: Path, cached_mtime_ns: int) -> bool:
+        try:
+            current_mtime_ns = path.stat().st_mtime_ns
+        except Exception:
+            return False
+        return cached_mtime_ns == current_mtime_ns
+
+    def _get_model(self, path: Path, model_type: Type[T]) -> Optional[T]:
+        if path not in self.model_cache:
+            return None
+        model, cached_mtime_ns = self.model_cache[path]
+        if not self._is_cache_valid(path, cached_mtime_ns):
+            self.invalidate(path)
+            return None
+
+        if not isinstance(model, model_type):
+            self.invalidate(path)
+            raise ValueError(f"Model at {path} is not of type {model_type.__name__}")
+        return model
+
+    def get_model(self, path: Path, model_type: Type[T]) -> Optional[T]:
+        # We return a copy so in-memory edits don't impact the cache until they are saved
+        # Benchmark shows about 2x slower, but much more foolproof
+        model = self._get_model(path, model_type)
+        if model:
+            return model.model_copy(deep=True)
+        return None
+
+    def get_model_id(self, path: Path, model_type: Type[T]) -> Optional[str]:
+        model = self._get_model(path, model_type)
+        if model and hasattr(model, "id"):
+            id = model.id  # type: ignore
+            if isinstance(id, str):
+                return id
+        return None
+
+    def set_model(self, path: Path, model: BaseModel, mtime_ns: int):
+        # disable caching if the filesystem doesn't support fine-grained timestamps
+        if not self._enabled:
+            return
+        self.model_cache[path] = (model, mtime_ns)
+
+    def invalidate(self, path: Path):
+        if path in self.model_cache:
+            del self.model_cache[path]
+
+    def clear(self):
+        self.model_cache.clear()
+
+    def _check_timestamp_granularity(self) -> bool:
+        """Check if filesystem supports fine-grained timestamps (microseconds or better)."""
+
+        # MacOS and Windows support fine-grained timestamps
+        if sys.platform in ["darwin", "win32"]:
+            return True
+
+        # Linux supports fine-grained timestamps SOMETIMES. ext4 should work.
+        try:
+            # Get filesystem stats for the current directory
+            stats = os.statvfs(Path(__file__).parent)
+
+            # f_timespec was added in Linux 5.6 (2020)
+            # Returns nanoseconds precision as a power of 10
+            # e.g., 1 = decisecond, 2 = centisecond, 3 = millisecond, etc.
+            timespec = getattr(stats, "f_timespec", 0)
+
+            # Consider microsecond precision (6) or better as "fine-grained"
+            return timespec >= 6
+        except (AttributeError, OSError):
+            # If f_timespec isn't available or other errors occur,
+            # assume poor granularity to be safe
+            return False
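
Note: to make the caching contract above concrete, here is a minimal usage sketch (not part of the diff). ExampleModel and load_with_cache are hypothetical names used only for illustration; inside the package this role is played by KilnBaseModel.load_from_file. It assumes a pydantic v2 model stored as JSON at the given path.

from pathlib import Path

from pydantic import BaseModel

from kiln_ai.datamodel.model_cache import ModelCache


class ExampleModel(BaseModel):
    # Hypothetical stand-in for a Kiln datamodel class
    name: str = "example"


def load_with_cache(path: Path) -> ExampleModel:
    # Check the shared cache first; entries are keyed by file path and
    # validated against the file's current mtime.
    cache = ModelCache.shared()
    cached = cache.get_model(path, ExampleModel)
    if cached is not None:
        return cached
    # Cache miss: parse and validate from disk, then cache the parsed model
    # along with the mtime it was read at.
    mtime_ns = path.stat().st_mtime_ns
    model = ExampleModel.model_validate_json(path.read_text())
    cache.set_model(path, model, mtime_ns)
    return model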
kiln_ai/datamodel/registry.py
@@ -0,0 +1,31 @@
+from kiln_ai.datamodel import Project
+from kiln_ai.utils.config import Config
+
+
+def all_projects() -> list[Project]:
+    project_paths = Config.shared().projects
+    if project_paths is None:
+        return []
+    projects = []
+    for project_path in project_paths:
+        try:
+            projects.append(Project.load_from_file(project_path))
+        except Exception:
+            # deleted files are possible continue with the rest
+            continue
+    return projects
+
+
+def project_from_id(project_id: str) -> Project | None:
+    project_paths = Config.shared().projects
+    if project_paths is not None:
+        for project_path in project_paths:
+            try:
+                project = Project.load_from_file(project_path)
+                if project.id == project_id:
+                    return project
+            except Exception:
+                # deleted files are possible continue with the rest
+                continue
+
+    return None
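
As a quick illustration of the new registry helpers (not part of the diff; the project ID below is hypothetical):

from kiln_ai.datamodel.registry import all_projects, project_from_id

# List every project the local Config knows about; deleted project files are
# skipped because the registry swallows load errors.
for project in all_projects():
    print(project.id)

# Look a project up by ID; returns None when no configured project matches.
maybe_project = project_from_id("example-project-id")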
kiln_ai/datamodel/test_basemodel.py
@@ -2,10 +2,16 @@ import datetime
 import json
 from pathlib import Path
 from typing import Optional
+from unittest.mock import MagicMock, patch

 import pytest

-from kiln_ai.datamodel.basemodel import KilnBaseModel, KilnParentedModel
+from kiln_ai.datamodel.basemodel import (
+    KilnBaseModel,
+    KilnParentedModel,
+    string_to_valid_name,
+)
+from kiln_ai.datamodel.model_cache import ModelCache


 @pytest.fixture
@@ -41,6 +47,17 @@ def test_newer_file(tmp_path) -> Path:
     return test_file_path


+@pytest.fixture
+def tmp_model_cache():
+    temp_cache = ModelCache()
+    # We're testing integration, not cache functions, in this file
+    temp_cache._enabled = True
+    with (
+        patch("kiln_ai.datamodel.basemodel.ModelCache.shared", return_value=temp_cache),
+    ):
+        yield temp_cache
+
+
 def test_load_from_file(test_base_file):
     model = KilnBaseModel.load_from_file(test_base_file)
     assert model.v == 1
@@ -273,9 +290,8 @@ def test_lazy_load_parent(tmp_path):
     assert loaded_parent.name == "Parent"
     assert loaded_parent.path == parent.path

-    # Verify that the _parent attribute is now set
-    assert hasattr(loaded_child, "_parent")
-    assert loaded_child._parent is loaded_parent
+    # Verify that the parent is cached
+    assert loaded_child.cached_parent() is loaded_parent


 def test_delete(tmp_path):
@@ -306,3 +322,150 @@ def test_delete_no_path():
     model = KilnBaseModel()
     with pytest.raises(ValueError, match="Cannot delete model because path is not set"):
         model.delete()
+
+
+def test_string_to_valid_name():
+    # Test basic valid strings remain unchanged
+    assert string_to_valid_name("Hello World") == "Hello World"
+    assert string_to_valid_name("Test-123") == "Test-123"
+    assert string_to_valid_name("my_file_name") == "my_file_name"
+
+    # Test invalid characters are replaced
+    assert string_to_valid_name("Hello@World!") == "Hello_World"
+    assert string_to_valid_name("File.name.txt") == "File_name_txt"
+    assert string_to_valid_name("Special#$%Chars") == "Special_Chars"
+
+    # Test consecutive invalid characters
+    assert string_to_valid_name("multiple!!!symbols") == "multiple_symbols"
+    assert string_to_valid_name("path/to/file") == "path_to_file"
+
+    # Test leading/trailing special characters
+    assert string_to_valid_name("__test__") == "test"
+    assert string_to_valid_name("...test...") == "test"
+
+    # Test empty string and whitespace
+    assert string_to_valid_name("") == ""
+    assert string_to_valid_name(" ") == ""
+
+
+def test_load_from_file_with_cache(test_base_file, tmp_model_cache):
+    tmp_model_cache.get_model = MagicMock(return_value=None)
+    tmp_model_cache.set_model = MagicMock()
+
+    # Load the model
+    model = KilnBaseModel.load_from_file(test_base_file)
+
+    # Check that the cache was checked and set
+    tmp_model_cache.get_model.assert_called_once_with(test_base_file, KilnBaseModel)
+    tmp_model_cache.set_model.assert_called_once()
+
+    # Ensure the model is correctly loaded
+    assert model.v == 1
+    assert model.path == test_base_file
+
+
+def test_save_to_file_invalidates_cache(test_base_file, tmp_model_cache):
+    # Create and save the model
+    model = KilnBaseModel(path=test_base_file)
+
+    # Set mock after to ignore any previous calls, we want to see save calls it
+    tmp_model_cache.invalidate = MagicMock()
+    model.save_to_file()
+
+    # Check that the cache was invalidated. Might be called multiple times for setting props like path. but must be called at least once.
+    tmp_model_cache.invalidate.assert_called_with(test_base_file)
+
+
+def test_delete_invalidates_cache(tmp_path, tmp_model_cache):
+    # Create and save the model
+    file_path = tmp_path / "test.kiln"
+    model = KilnBaseModel(path=file_path)
+    model.save_to_file()
+
+    # populate and check cache
+    model = KilnBaseModel.load_from_file(file_path)
+    cached_model = tmp_model_cache.get_model(file_path, KilnBaseModel)
+    assert cached_model.id == model.id
+
+    tmp_model_cache.invalidate = MagicMock()
+
+    # Delete the model
+    model.delete()
+
+    # Check that the cache was invalidated
+    tmp_model_cache.invalidate.assert_called_with(file_path)
+    assert tmp_model_cache.get_model(file_path, KilnBaseModel) is None
+
+
+def test_load_from_file_with_cached_model(test_base_file, tmp_model_cache):
+    # Set up the mock to return a cached model
+    cached_model = KilnBaseModel(v=1, path=test_base_file)
+    tmp_model_cache.get_model = MagicMock(return_value=cached_model)
+
+    with patch("builtins.open", create=True) as mock_open:
+        # Load the model
+        model = KilnBaseModel.load_from_file(test_base_file)
+
+        # Check that the cache was checked and the cached model was returned
+        tmp_model_cache.get_model.assert_called_once_with(test_base_file, KilnBaseModel)
+        assert model is cached_model
+
+        # Assert that open was not called (we used the cached model, not file)
+        mock_open.assert_not_called()
+
+
+def test_from_id_and_parent_path(test_base_parented_file, tmp_model_cache):
+    # Set up parent and children models
+    parent = BaseParentExample.load_from_file(test_base_parented_file)
+
+    child1 = DefaultParentedModel(parent=parent, name="Child1")
+    child2 = DefaultParentedModel(parent=parent, name="Child2")
+    child3 = DefaultParentedModel(parent=parent, name="Child3")
+
+    # Save all children
+    child1.save_to_file()
+    child2.save_to_file()
+    child3.save_to_file()
+
+    # Test finding existing child by ID
+    found_child = DefaultParentedModel.from_id_and_parent_path(
+        child2.id, test_base_parented_file
+    )
+    assert found_child is not None
+    assert found_child.id == child2.id
+    assert found_child.name == "Child2"
+    assert found_child is not child2  # not same instance (deep copy)
+
+    # Test non-existent ID returns None
+    not_found = DefaultParentedModel.from_id_and_parent_path(
+        "nonexistent", test_base_parented_file
+    )
+    assert not_found is None
+
+
+def test_from_id_and_parent_path_with_cache(test_base_parented_file, tmp_model_cache):
+    # Set up parent and child
+    parent = BaseParentExample.load_from_file(test_base_parented_file)
+    child = DefaultParentedModel(parent=parent, name="Child")
+    child.save_to_file()
+
+    # First load to populate cache
+    _ = DefaultParentedModel.from_id_and_parent_path(child.id, test_base_parented_file)
+
+    # Mock cache to verify it's used
+    tmp_model_cache.get_model_id = MagicMock(return_value=child.id)
+
+    # Load again - should use cache
+    found_child = DefaultParentedModel.from_id_and_parent_path(
+        child.id, test_base_parented_file
+    )
+
+    assert found_child is not None
+    assert found_child.id == child.id
+    tmp_model_cache.get_model_id.assert_called()
+
+
+def test_from_id_and_parent_path_without_parent():
+    # Test with None parent_path
+    not_found = DefaultParentedModel.from_id_and_parent_path("any-id", None)
+    assert not_found is None
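
A hedged sketch of what these tests imply for callers (not part of the test suite): assuming caching is enabled, loading the same file twice yields equal but distinct objects, because the cache hands back deep copies so in-memory edits never leak back into the cache.

from pathlib import Path

from kiln_ai.datamodel.basemodel import KilnBaseModel


def load_twice(path: Path) -> None:
    # Hypothetical illustration; requires a valid .kiln file at `path`.
    first = KilnBaseModel.load_from_file(path)
    second = KilnBaseModel.load_from_file(path)
    assert first.id == second.id  # same persisted model
    assert first is not second  # separate copies; edits stay local until saved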
kiln_ai/datamodel/test_dataset_split.py
@@ -0,0 +1,234 @@
+import pytest
+from pydantic import ValidationError
+
+# import datamodel first or we get circular import errors
+from kiln_ai.datamodel import (
+    AllDatasetFilter,
+    AllSplitDefinition,
+    DatasetSplit,
+    DatasetSplitDefinition,
+    DataSource,
+    DataSourceType,
+    HighRatingDatasetFilter,
+    Task,
+    TaskOutput,
+    TaskOutputRating,
+    TaskOutputRatingType,
+    TaskRun,
+    Train60Test20Val20SplitDefinition,
+    Train80Test20SplitDefinition,
+)
+
+
+@pytest.fixture
+def sample_task(tmp_path):
+    task_path = tmp_path / "task.kiln"
+    task = Task(
+        name="Test Task",
+        path=task_path,
+        description="Test task for dataset splitting",
+        instruction="Test instruction",
+    )
+    task.save_to_file()
+    return task
+
+
+@pytest.fixture
+def sample_task_runs(sample_task):
+    # Create 10 task runs with different ratings
+    task_runs = []
+    for i in range(10):
+        rating = 5 if i < 6 else 1  # 6 high, 4 low ratings
+        task_run = TaskRun(
+            parent=sample_task,
+            input=f"input_{i}",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "test-user"},
+            ),
+            output=TaskOutput(
+                output=f"output_{i}",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "test-user"},
+                ),
+                rating=TaskOutputRating(
+                    value=rating, type=TaskOutputRatingType.five_star
+                ),
+            ),
+        )
+        task_run.save_to_file()
+        task_runs.append(task_run)
+    return task_runs
+
+
+@pytest.fixture
+def standard_splits():
+    return [
+        DatasetSplitDefinition(name="train", percentage=0.8),
+        DatasetSplitDefinition(name="test", percentage=0.2),
+    ]
+
+
+@pytest.fixture
+def task_run():
+    return TaskRun(
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "test-user"},
+        ),
+        output=TaskOutput(
+            output="test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "test-user"},
+            ),
+            rating=TaskOutputRating(rating=5, type=TaskOutputRatingType.five_star),
+        ),
+    )
+
+
+def test_dataset_split_definition():
+    split = DatasetSplitDefinition(name="train", percentage=0.8)
+    assert split.name == "train"
+    assert split.percentage == 0.8
+    assert split.description is None
+
+    # Test validation
+    with pytest.raises(ValidationError):
+        DatasetSplitDefinition(name="train", percentage=1.5)
+
+
+def test_dataset_split_validation():
+    # Test valid percentages
+    splits = [
+        DatasetSplitDefinition(name="train", percentage=0.8),
+        DatasetSplitDefinition(name="test", percentage=0.2),
+    ]
+    dataset = DatasetSplit(
+        name="test_split",
+        splits=splits,
+        split_contents={"train": [], "test": []},
+    )
+    assert dataset.splits == splits
+
+    # Test invalid percentages
+    invalid_splits = [
+        DatasetSplitDefinition(name="train", percentage=0.8),
+        DatasetSplitDefinition(name="test", percentage=0.3),
+    ]
+    with pytest.raises(ValueError, match="sum of split percentages must be 1.0"):
+        DatasetSplit(
+            name="test_split",
+            splits=invalid_splits,
+            split_contents={"train": [], "test": []},
+        )
+
+
+def test_all_dataset_filter(task_run):
+    assert AllDatasetFilter(task_run) is True
+
+
+def test_high_rating_dataset_filter(sample_task_runs):
+    for task_run in sample_task_runs:
+        assert HighRatingDatasetFilter(task_run) is (
+            task_run.output.rating.is_high_quality()
+        )
+
+
+@pytest.mark.parametrize(
+    "splits,expected_sizes",
+    [
+        (Train80Test20SplitDefinition, {"train": 8, "test": 2}),
+        (AllSplitDefinition, {"all": 10}),
+        (Train60Test20Val20SplitDefinition, {"train": 6, "test": 2, "val": 2}),
+        (
+            [
+                DatasetSplitDefinition(name="train", percentage=0.7),
+                DatasetSplitDefinition(name="validation", percentage=0.2),
+                DatasetSplitDefinition(name="test", percentage=0.1),
+            ],
+            {"train": 7, "validation": 2, "test": 1},
+        ),
+    ],
+)
+def test_dataset_split_from_task(sample_task, sample_task_runs, splits, expected_sizes):
+    assert sample_task_runs is not None
+    dataset = DatasetSplit.from_task("Split Name", sample_task, splits)
+    assert dataset.name == "Split Name"
+
+    # Check split sizes match expected
+    for split_name, expected_size in expected_sizes.items():
+        assert len(dataset.split_contents[split_name]) == expected_size
+
+    # Verify total size matches input size
+    total_size = sum(len(ids) for ids in dataset.split_contents.values())
+    assert total_size == len(sample_task_runs)
+
+
+def test_dataset_split_with_high_rating_filter(sample_task, sample_task_runs):
+    assert len(sample_task_runs) == 10
+    dataset = DatasetSplit.from_task(
+        "Split Name",
+        sample_task,
+        Train80Test20SplitDefinition,
+        filter=HighRatingDatasetFilter,
+    )
+
+    # Check that only high-rated task runs are included
+    all_ids = []
+    for ids in dataset.split_contents.values():
+        all_ids.extend(ids)
+    assert len(all_ids) == 6  # We created 6 high-rated task runs
+
+    # Check split proportions
+    train_size = len(dataset.split_contents["train"])
+    test_size = len(dataset.split_contents["test"])
+    assert train_size == 5  # ~80% of 6
+    assert test_size == 1  # ~20% of 6
+
+
+def test_dataset_split_with_single_split(sample_task, sample_task_runs):
+    splits = [DatasetSplitDefinition(name="all", percentage=1.0)]
+    dataset = DatasetSplit.from_task("Split Name", sample_task, splits)
+
+    assert len(dataset.split_contents["all"]) == len(sample_task_runs)
+
+
+def test_missing_count(sample_task, sample_task_runs):
+    assert sample_task_runs is not None
+    # Create a dataset split with all task runs
+    dataset = DatasetSplit.from_task(
+        "Split Name", sample_task, Train80Test20SplitDefinition
+    )
+
+    # Initially there should be no missing runs
+    assert dataset.missing_count() == 0
+
+    # Add some IDs to the split, that aren't on disk
+    dataset.split_contents["test"].append("1")
+    dataset.split_contents["test"].append("2")
+    dataset.split_contents["test"].append("3")
+    # shouldn't happen, but should not double count if it does
+    dataset.split_contents["train"].append("3")
+
+    # Now we should have 3 missing runs
+    assert dataset.missing_count() == 3
+
+
+def test_smaller_sample(sample_task, sample_task_runs):
+    assert sample_task_runs is not None
+    # Create a dataset split with all task runs
+    dataset = DatasetSplit.from_task(
+        "Split Name", sample_task, Train80Test20SplitDefinition
+    )
+
+    # Initially there should be no missing runs
+    assert dataset.missing_count() == 0
+
+    dataset.split_contents["test"].pop()
+    dataset.split_contents["train"].pop()
+
+    # Now we should have 0 missing runs. It's okay that dataset has newer data.
+    assert dataset.missing_count() == 0
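
For reference, a minimal sketch of the API these tests exercise (not part of the diff; the split name and helper are hypothetical):

from kiln_ai.datamodel import (
    DatasetSplit,
    HighRatingDatasetFilter,
    Task,
    Train80Test20SplitDefinition,
)


def build_high_rating_split(task: Task) -> DatasetSplit:
    # Bucket only the highly rated runs of an existing Task into an 80/20
    # train/test split, mirroring test_dataset_split_with_high_rating_filter.
    return DatasetSplit.from_task(
        "high-rating-split",
        task,
        Train80Test20SplitDefinition,
        filter=HighRatingDatasetFilter,
    )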
kiln_ai/datamodel/test_models.py
@@ -5,8 +5,10 @@ import pytest
 from pydantic import ValidationError

 from kiln_ai.datamodel import (
+    DatasetSplit,
     DataSource,
     DataSourceType,
+    Finetune,
     Project,
     Task,
     TaskDeterminism,
@@ -97,6 +99,16 @@ def test_task_run_relationship(valid_task_run):
     assert valid_task_run.__class__.parent_type().__name__ == "Task"


+def test_dataset_split_relationship():
+    assert DatasetSplit.relationship_name() == "dataset_splits"
+    assert DatasetSplit.parent_type().__name__ == "Task"
+
+
+def test_base_finetune_relationship():
+    assert Finetune.relationship_name() == "finetunes"
+    assert Finetune.parent_type().__name__ == "Task"
+
+
 def test_structured_output_workflow(tmp_path):
     tmp_project_file = (
         tmp_path / "test_structured_output_runs" / Project.base_filename()