kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +239 -303
- kiln_ai/adapters/ollama_tools.py +115 -0
- kiln_ai/adapters/provider_tools.py +308 -0
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +229 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +531 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +287 -24
- kiln_ai/datamodel/basemodel.py +122 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +167 -4
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +215 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +14 -1
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- kiln_ai-0.7.1.dist-info/RECORD +58 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/METADATA +0 -88
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/model_cache.py
@@ -0,0 +1,116 @@
+"""
+A simple cache for our datamodel.
+
+Works at the file level, caching the pydantic model based on the file path.
+
+Keeping this really simple. Our goal is to really be "disk-backed" data model, so using disk primitives.
+
+- Use disk mtime to determine if the cached model is stale.
+- Still using glob for iterating over projects, just caching at the file level
+- Use path as the cache key
+- Cache always populated from a disk read, so we know it refects what's on disk. Even if we had a memory-constructed version, we don't cache that.
+- Cache the parsed model, not the raw file contents. Parsing and validating is what's expensive. >99% speedup when measured.
+"""
+
+import os
+import sys
+import warnings
+from pathlib import Path
+from typing import Dict, Optional, Tuple, Type, TypeVar
+
+from pydantic import BaseModel
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class ModelCache:
+    _shared_instance = None
+
+    def __init__(self):
+        # Store both the model and the modified time of the cached file contents
+        self.model_cache: Dict[Path, Tuple[BaseModel, int]] = {}
+        self._enabled = self._check_timestamp_granularity()
+        if not self._enabled:
+            warnings.warn(
+                "File system does not support fine-grained timestamps. "
+                "Model caching has been disabled to ensure consistency."
+            )
+
+    @classmethod
+    def shared(cls):
+        if cls._shared_instance is None:
+            cls._shared_instance = cls()
+        return cls._shared_instance
+
+    def _is_cache_valid(self, path: Path, cached_mtime_ns: int) -> bool:
+        try:
+            current_mtime_ns = path.stat().st_mtime_ns
+        except Exception:
+            return False
+        return cached_mtime_ns == current_mtime_ns
+
+    def _get_model(self, path: Path, model_type: Type[T]) -> Optional[T]:
+        if path not in self.model_cache:
+            return None
+        model, cached_mtime_ns = self.model_cache[path]
+        if not self._is_cache_valid(path, cached_mtime_ns):
+            self.invalidate(path)
+            return None
+
+        if not isinstance(model, model_type):
+            self.invalidate(path)
+            raise ValueError(f"Model at {path} is not of type {model_type.__name__}")
+        return model
+
+    def get_model(self, path: Path, model_type: Type[T]) -> Optional[T]:
+        # We return a copy so in-memory edits don't impact the cache until they are saved
+        # Benchmark shows about 2x slower, but much more foolproof
+        model = self._get_model(path, model_type)
+        if model:
+            return model.model_copy(deep=True)
+        return None
+
+    def get_model_id(self, path: Path, model_type: Type[T]) -> Optional[str]:
+        model = self._get_model(path, model_type)
+        if model and hasattr(model, "id"):
+            id = model.id  # type: ignore
+            if isinstance(id, str):
+                return id
+        return None
+
+    def set_model(self, path: Path, model: BaseModel, mtime_ns: int):
+        # disable caching if the filesystem doesn't support fine-grained timestamps
+        if not self._enabled:
+            return
+        self.model_cache[path] = (model, mtime_ns)
+
+    def invalidate(self, path: Path):
+        if path in self.model_cache:
+            del self.model_cache[path]
+
+    def clear(self):
+        self.model_cache.clear()
+
+    def _check_timestamp_granularity(self) -> bool:
+        """Check if filesystem supports fine-grained timestamps (microseconds or better)."""
+
+        # MacOS and Windows support fine-grained timestamps
+        if sys.platform in ["darwin", "win32"]:
+            return True
+
+        # Linux supports fine-grained timestamps SOMETIMES. ext4 should work.
+        try:
+            # Get filesystem stats for the current directory
+            stats = os.statvfs(Path(__file__).parent)
+
+            # f_timespec was added in Linux 5.6 (2020)
+            # Returns nanoseconds precision as a power of 10
+            # e.g., 1 = decisecond, 2 = centisecond, 3 = millisecond, etc.
+            timespec = getattr(stats, "f_timespec", 0)
+
+            # Consider microsecond precision (6) or better as "fine-grained"
+            return timespec >= 6
+        except (AttributeError, OSError):
+            # If f_timespec isn't available or other errors occur,
+            # assume poor granularity to be safe
+            return False
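The module docstring above describes the design: the cache is keyed by file path, stores the parsed pydantic model together with the file's mtime in nanoseconds, and treats any mtime change as a miss. A minimal sketch of how a caller would drive that API follows; `ExampleModel` and `load_cached` are hypothetical names for illustration, not the library's actual loader.

from pathlib import Path

from pydantic import BaseModel

from kiln_ai.datamodel.model_cache import ModelCache


class ExampleModel(BaseModel):
    # Hypothetical model, for illustration only
    name: str = "example"


def load_cached(path: Path) -> ExampleModel:
    cache = ModelCache.shared()
    cached = cache.get_model(path, ExampleModel)
    if cached is not None:
        # get_model returns a deep copy, so the caller can mutate it safely
        return cached

    # Cache miss (or stale mtime): record the mtime from just before the read,
    # parse from disk, then populate the cache with the parsed model.
    mtime_ns = path.stat().st_mtime_ns
    model = ExampleModel.model_validate_json(path.read_text())
    cache.set_model(path, model, mtime_ns)
    return model

Because get_model hands back a deep copy, in-memory edits never touch the cached entry; the cache only changes on the next set_model after a disk read.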
kiln_ai/datamodel/registry.py
@@ -0,0 +1,31 @@
+from kiln_ai.datamodel import Project
+from kiln_ai.utils.config import Config
+
+
+def all_projects() -> list[Project]:
+    project_paths = Config.shared().projects
+    if project_paths is None:
+        return []
+    projects = []
+    for project_path in project_paths:
+        try:
+            projects.append(Project.load_from_file(project_path))
+        except Exception:
+            # deleted files are possible continue with the rest
+            continue
+    return projects
+
+
+def project_from_id(project_id: str) -> Project | None:
+    project_paths = Config.shared().projects
+    if project_paths is not None:
+        for project_path in project_paths:
+            try:
+                project = Project.load_from_file(project_path)
+                if project.id == project_id:
+                    return project
+            except Exception:
+                # deleted files are possible continue with the rest
+                continue
+
+    return None
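A short usage sketch for the two helpers above, assuming Config.shared().projects already lists at least one saved project file; the ID string is a placeholder.

from kiln_ai.datamodel.registry import all_projects, project_from_id

# Iterate every project registered in the app config; deleted or unreadable files are skipped.
for project in all_projects():
    print(project.id, project.name)

# Look up a single project by ID; returns None when no registered path matches.
project = project_from_id("placeholder-project-id")
if project is None:
    print("no project with that id is registered")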
kiln_ai/datamodel/test_basemodel.py
@@ -2,10 +2,16 @@ import datetime
 import json
 from pathlib import Path
 from typing import Optional
+from unittest.mock import MagicMock, patch
 
 import pytest
 
-from kiln_ai.datamodel.basemodel import
+from kiln_ai.datamodel.basemodel import (
+    KilnBaseModel,
+    KilnParentedModel,
+    string_to_valid_name,
+)
+from kiln_ai.datamodel.model_cache import ModelCache
 
 
 @pytest.fixture
@@ -41,6 +47,17 @@ def test_newer_file(tmp_path) -> Path:
     return test_file_path
 
 
+@pytest.fixture
+def tmp_model_cache():
+    temp_cache = ModelCache()
+    # We're testing integration, not cache functions, in this file
+    temp_cache._enabled = True
+    with (
+        patch("kiln_ai.datamodel.basemodel.ModelCache.shared", return_value=temp_cache),
+    ):
+        yield temp_cache
+
+
 def test_load_from_file(test_base_file):
     model = KilnBaseModel.load_from_file(test_base_file)
     assert model.v == 1
@@ -273,9 +290,8 @@ def test_lazy_load_parent(tmp_path):
     assert loaded_parent.name == "Parent"
     assert loaded_parent.path == parent.path
 
-    # Verify that the
-    assert
-    assert loaded_child._parent is loaded_parent
+    # Verify that the parent is cached
+    assert loaded_child.cached_parent() is loaded_parent
 
 
 def test_delete(tmp_path):
@@ -306,3 +322,150 @@ def test_delete_no_path():
     model = KilnBaseModel()
     with pytest.raises(ValueError, match="Cannot delete model because path is not set"):
         model.delete()
+
+
+def test_string_to_valid_name():
+    # Test basic valid strings remain unchanged
+    assert string_to_valid_name("Hello World") == "Hello World"
+    assert string_to_valid_name("Test-123") == "Test-123"
+    assert string_to_valid_name("my_file_name") == "my_file_name"
+
+    # Test invalid characters are replaced
+    assert string_to_valid_name("Hello@World!") == "Hello_World"
+    assert string_to_valid_name("File.name.txt") == "File_name_txt"
+    assert string_to_valid_name("Special#$%Chars") == "Special_Chars"
+
+    # Test consecutive invalid characters
+    assert string_to_valid_name("multiple!!!symbols") == "multiple_symbols"
+    assert string_to_valid_name("path/to/file") == "path_to_file"
+
+    # Test leading/trailing special characters
+    assert string_to_valid_name("__test__") == "test"
+    assert string_to_valid_name("...test...") == "test"
+
+    # Test empty string and whitespace
+    assert string_to_valid_name("") == ""
+    assert string_to_valid_name(" ") == ""
+
+
+def test_load_from_file_with_cache(test_base_file, tmp_model_cache):
+    tmp_model_cache.get_model = MagicMock(return_value=None)
+    tmp_model_cache.set_model = MagicMock()
+
+    # Load the model
+    model = KilnBaseModel.load_from_file(test_base_file)
+
+    # Check that the cache was checked and set
+    tmp_model_cache.get_model.assert_called_once_with(test_base_file, KilnBaseModel)
+    tmp_model_cache.set_model.assert_called_once()
+
+    # Ensure the model is correctly loaded
+    assert model.v == 1
+    assert model.path == test_base_file
+
+
+def test_save_to_file_invalidates_cache(test_base_file, tmp_model_cache):
+    # Create and save the model
+    model = KilnBaseModel(path=test_base_file)
+
+    # Set mock after to ignore any previous calls, we want to see save calls it
+    tmp_model_cache.invalidate = MagicMock()
+    model.save_to_file()
+
+    # Check that the cache was invalidated. Might be called multiple times for setting props like path. but must be called at least once.
+    tmp_model_cache.invalidate.assert_called_with(test_base_file)
+
+
+def test_delete_invalidates_cache(tmp_path, tmp_model_cache):
+    # Create and save the model
+    file_path = tmp_path / "test.kiln"
+    model = KilnBaseModel(path=file_path)
+    model.save_to_file()
+
+    # populate and check cache
+    model = KilnBaseModel.load_from_file(file_path)
+    cached_model = tmp_model_cache.get_model(file_path, KilnBaseModel)
+    assert cached_model.id == model.id
+
+    tmp_model_cache.invalidate = MagicMock()
+
+    # Delete the model
+    model.delete()
+
+    # Check that the cache was invalidated
+    tmp_model_cache.invalidate.assert_called_with(file_path)
+    assert tmp_model_cache.get_model(file_path, KilnBaseModel) is None
+
+
+def test_load_from_file_with_cached_model(test_base_file, tmp_model_cache):
+    # Set up the mock to return a cached model
+    cached_model = KilnBaseModel(v=1, path=test_base_file)
+    tmp_model_cache.get_model = MagicMock(return_value=cached_model)
+
+    with patch("builtins.open", create=True) as mock_open:
+        # Load the model
+        model = KilnBaseModel.load_from_file(test_base_file)
+
+    # Check that the cache was checked and the cached model was returned
+    tmp_model_cache.get_model.assert_called_once_with(test_base_file, KilnBaseModel)
+    assert model is cached_model
+
+    # Assert that open was not called (we used the cached model, not file)
+    mock_open.assert_not_called()
+
+
+def test_from_id_and_parent_path(test_base_parented_file, tmp_model_cache):
+    # Set up parent and children models
+    parent = BaseParentExample.load_from_file(test_base_parented_file)
+
+    child1 = DefaultParentedModel(parent=parent, name="Child1")
+    child2 = DefaultParentedModel(parent=parent, name="Child2")
+    child3 = DefaultParentedModel(parent=parent, name="Child3")
+
+    # Save all children
+    child1.save_to_file()
+    child2.save_to_file()
+    child3.save_to_file()
+
+    # Test finding existing child by ID
+    found_child = DefaultParentedModel.from_id_and_parent_path(
+        child2.id, test_base_parented_file
+    )
+    assert found_child is not None
+    assert found_child.id == child2.id
+    assert found_child.name == "Child2"
+    assert found_child is not child2  # not same instance (deep copy)
+
+    # Test non-existent ID returns None
+    not_found = DefaultParentedModel.from_id_and_parent_path(
+        "nonexistent", test_base_parented_file
+    )
+    assert not_found is None
+
+
+def test_from_id_and_parent_path_with_cache(test_base_parented_file, tmp_model_cache):
+    # Set up parent and child
+    parent = BaseParentExample.load_from_file(test_base_parented_file)
+    child = DefaultParentedModel(parent=parent, name="Child")
+    child.save_to_file()
+
+    # First load to populate cache
+    _ = DefaultParentedModel.from_id_and_parent_path(child.id, test_base_parented_file)
+
+    # Mock cache to verify it's used
+    tmp_model_cache.get_model_id = MagicMock(return_value=child.id)
+
+    # Load again - should use cache
+    found_child = DefaultParentedModel.from_id_and_parent_path(
+        child.id, test_base_parented_file
+    )
+
+    assert found_child is not None
+    assert found_child.id == child.id
+    tmp_model_cache.get_model_id.assert_called()
+
+
+def test_from_id_and_parent_path_without_parent():
+    # Test with None parent_path
+    not_found = DefaultParentedModel.from_id_and_parent_path("any-id", None)
+    assert not_found is None
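The last few tests above exercise the new from_id_and_parent_path lookup on parented models: it resolves a child by ID from its parent's file path and returns a deep copy, or None. A minimal sketch of calling it directly, assuming TaskRun inherits the helper from KilnParentedModel the way DefaultParentedModel does in these tests; the path and ID below are placeholders.

from pathlib import Path

from kiln_ai.datamodel import TaskRun

# Placeholder path to an existing task file; its saved runs are stored alongside it.
task_path = Path("/path/to/project/tasks/my_task/task.kiln")

# Returns the matching run, or None if no child with that ID exists (or parent_path is None).
run = TaskRun.from_id_and_parent_path("some-run-id", task_path)
if run is None:
    print("no run with that ID under this task")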
kiln_ai/datamodel/test_dataset_split.py
@@ -0,0 +1,234 @@
+import pytest
+from pydantic import ValidationError
+
+# import datamodel first or we get circular import errors
+from kiln_ai.datamodel import (
+    AllDatasetFilter,
+    AllSplitDefinition,
+    DatasetSplit,
+    DatasetSplitDefinition,
+    DataSource,
+    DataSourceType,
+    HighRatingDatasetFilter,
+    Task,
+    TaskOutput,
+    TaskOutputRating,
+    TaskOutputRatingType,
+    TaskRun,
+    Train60Test20Val20SplitDefinition,
+    Train80Test20SplitDefinition,
+)
+
+
+@pytest.fixture
+def sample_task(tmp_path):
+    task_path = tmp_path / "task.kiln"
+    task = Task(
+        name="Test Task",
+        path=task_path,
+        description="Test task for dataset splitting",
+        instruction="Test instruction",
+    )
+    task.save_to_file()
+    return task
+
+
+@pytest.fixture
+def sample_task_runs(sample_task):
+    # Create 10 task runs with different ratings
+    task_runs = []
+    for i in range(10):
+        rating = 5 if i < 6 else 1  # 6 high, 4 low ratings
+        task_run = TaskRun(
+            parent=sample_task,
+            input=f"input_{i}",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "test-user"},
+            ),
+            output=TaskOutput(
+                output=f"output_{i}",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "test-user"},
+                ),
+                rating=TaskOutputRating(
+                    value=rating, type=TaskOutputRatingType.five_star
+                ),
+            ),
+        )
+        task_run.save_to_file()
+        task_runs.append(task_run)
+    return task_runs
+
+
+@pytest.fixture
+def standard_splits():
+    return [
+        DatasetSplitDefinition(name="train", percentage=0.8),
+        DatasetSplitDefinition(name="test", percentage=0.2),
+    ]
+
+
+@pytest.fixture
+def task_run():
+    return TaskRun(
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "test-user"},
+        ),
+        output=TaskOutput(
+            output="test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "test-user"},
+            ),
+            rating=TaskOutputRating(rating=5, type=TaskOutputRatingType.five_star),
+        ),
+    )
+
+
+def test_dataset_split_definition():
+    split = DatasetSplitDefinition(name="train", percentage=0.8)
+    assert split.name == "train"
+    assert split.percentage == 0.8
+    assert split.description is None
+
+    # Test validation
+    with pytest.raises(ValidationError):
+        DatasetSplitDefinition(name="train", percentage=1.5)
+
+
+def test_dataset_split_validation():
+    # Test valid percentages
+    splits = [
+        DatasetSplitDefinition(name="train", percentage=0.8),
+        DatasetSplitDefinition(name="test", percentage=0.2),
+    ]
+    dataset = DatasetSplit(
+        name="test_split",
+        splits=splits,
+        split_contents={"train": [], "test": []},
+    )
+    assert dataset.splits == splits
+
+    # Test invalid percentages
+    invalid_splits = [
+        DatasetSplitDefinition(name="train", percentage=0.8),
+        DatasetSplitDefinition(name="test", percentage=0.3),
+    ]
+    with pytest.raises(ValueError, match="sum of split percentages must be 1.0"):
+        DatasetSplit(
+            name="test_split",
+            splits=invalid_splits,
+            split_contents={"train": [], "test": []},
+        )
+
+
+def test_all_dataset_filter(task_run):
+    assert AllDatasetFilter(task_run) is True
+
+
+def test_high_rating_dataset_filter(sample_task_runs):
+    for task_run in sample_task_runs:
+        assert HighRatingDatasetFilter(task_run) is (
+            task_run.output.rating.is_high_quality()
+        )
+
+
+@pytest.mark.parametrize(
+    "splits,expected_sizes",
+    [
+        (Train80Test20SplitDefinition, {"train": 8, "test": 2}),
+        (AllSplitDefinition, {"all": 10}),
+        (Train60Test20Val20SplitDefinition, {"train": 6, "test": 2, "val": 2}),
+        (
+            [
+                DatasetSplitDefinition(name="train", percentage=0.7),
+                DatasetSplitDefinition(name="validation", percentage=0.2),
+                DatasetSplitDefinition(name="test", percentage=0.1),
+            ],
+            {"train": 7, "validation": 2, "test": 1},
+        ),
+    ],
+)
+def test_dataset_split_from_task(sample_task, sample_task_runs, splits, expected_sizes):
+    assert sample_task_runs is not None
+    dataset = DatasetSplit.from_task("Split Name", sample_task, splits)
+    assert dataset.name == "Split Name"
+
+    # Check split sizes match expected
+    for split_name, expected_size in expected_sizes.items():
+        assert len(dataset.split_contents[split_name]) == expected_size
+
+    # Verify total size matches input size
+    total_size = sum(len(ids) for ids in dataset.split_contents.values())
+    assert total_size == len(sample_task_runs)
+
+
+def test_dataset_split_with_high_rating_filter(sample_task, sample_task_runs):
+    assert len(sample_task_runs) == 10
+    dataset = DatasetSplit.from_task(
+        "Split Name",
+        sample_task,
+        Train80Test20SplitDefinition,
+        filter=HighRatingDatasetFilter,
+    )
+
+    # Check that only high-rated task runs are included
+    all_ids = []
+    for ids in dataset.split_contents.values():
+        all_ids.extend(ids)
+    assert len(all_ids) == 6  # We created 6 high-rated task runs
+
+    # Check split proportions
+    train_size = len(dataset.split_contents["train"])
+    test_size = len(dataset.split_contents["test"])
+    assert train_size == 5  # ~80% of 6
+    assert test_size == 1  # ~20% of 6
+
+
+def test_dataset_split_with_single_split(sample_task, sample_task_runs):
+    splits = [DatasetSplitDefinition(name="all", percentage=1.0)]
+    dataset = DatasetSplit.from_task("Split Name", sample_task, splits)
+
+    assert len(dataset.split_contents["all"]) == len(sample_task_runs)
+
+
+def test_missing_count(sample_task, sample_task_runs):
+    assert sample_task_runs is not None
+    # Create a dataset split with all task runs
+    dataset = DatasetSplit.from_task(
+        "Split Name", sample_task, Train80Test20SplitDefinition
+    )
+
+    # Initially there should be no missing runs
+    assert dataset.missing_count() == 0
+
+    # Add some IDs to the split, that aren't on disk
+    dataset.split_contents["test"].append("1")
+    dataset.split_contents["test"].append("2")
+    dataset.split_contents["test"].append("3")
+    # shouldn't happen, but should not double count if it does
+    dataset.split_contents["train"].append("3")
+
+    # Now we should have 3 missing runs
+    assert dataset.missing_count() == 3
+
+
+def test_smaller_sample(sample_task, sample_task_runs):
+    assert sample_task_runs is not None
+    # Create a dataset split with all task runs
+    dataset = DatasetSplit.from_task(
+        "Split Name", sample_task, Train80Test20SplitDefinition
+    )
+
+    # Initially there should be no missing runs
+    assert dataset.missing_count() == 0
+
+    dataset.split_contents["test"].pop()
+    dataset.split_contents["train"].pop()
+
+    # Now we should have 0 missing runs. It's okay that dataset has newer data.
+    assert dataset.missing_count() == 0
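These tests pin down the DatasetSplit.from_task signature: a split name, a task, a list of DatasetSplitDefinition (or one of the predefined sets), and an optional filter. A short sketch of the same call outside the test suite, assuming an existing saved task file at a placeholder path:

from pathlib import Path

from kiln_ai.datamodel import (
    DatasetSplit,
    HighRatingDatasetFilter,
    Task,
    Train80Test20SplitDefinition,
)

# Placeholder path: a task that already has saved, rated runs.
task = Task.load_from_file(Path("/path/to/project/tasks/my_task/task.kiln"))

# 80/20 split over the task's runs, keeping only runs whose rating is high quality.
dataset = DatasetSplit.from_task(
    "high_rated_v1",
    task,
    Train80Test20SplitDefinition,
    filter=HighRatingDatasetFilter,
)

# Run IDs per split, plus a count of referenced runs that are no longer on disk.
print({name: len(ids) for name, ids in dataset.split_contents.items()})
print(dataset.missing_count())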
@@ -5,8 +5,10 @@ import pytest
 from pydantic import ValidationError
 
 from kiln_ai.datamodel import (
+    DatasetSplit,
     DataSource,
     DataSourceType,
+    Finetune,
     Project,
     Task,
     TaskDeterminism,
@@ -97,6 +99,16 @@ def test_task_run_relationship(valid_task_run):
     assert valid_task_run.__class__.parent_type().__name__ == "Task"
 
 
+def test_dataset_split_relationship():
+    assert DatasetSplit.relationship_name() == "dataset_splits"
+    assert DatasetSplit.parent_type().__name__ == "Task"
+
+
+def test_base_finetune_relationship():
+    assert Finetune.relationship_name() == "finetunes"
+    assert Finetune.parent_type().__name__ == "Task"
+
+
 def test_structured_output_workflow(tmp_path):
     tmp_project_file = (
         tmp_path / "test_structured_output_runs" / Project.base_filename()