kiln-ai 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/adapter_registry.py +2 -0
- kiln_ai/adapters/base_adapter.py +6 -1
- kiln_ai/adapters/langchain_adapters.py +5 -1
- kiln_ai/adapters/ml_model_list.py +43 -12
- kiln_ai/adapters/ollama_tools.py +4 -3
- kiln_ai/adapters/provider_tools.py +63 -2
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/test_langchain_adapter.py +183 -0
- kiln_ai/adapters/test_provider_tools.py +315 -1
- kiln_ai/datamodel/__init__.py +162 -19
- kiln_ai/datamodel/basemodel.py +90 -42
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/test_basemodel.py +138 -3
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +173 -0
- kiln_ai/datamodel/test_output_rating.py +377 -10
- kiln_ai/utils/config.py +33 -10
- kiln_ai/utils/test_config.py +48 -0
- kiln_ai-0.8.0.dist-info/METADATA +237 -0
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/RECORD +23 -21
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/WHEEL +1 -1
- kiln_ai-0.7.0.dist-info/METADATA +0 -90
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from unittest import mock
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from libs.core.kiln_ai.datamodel.model_cache import ModelCache
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Define a simple Pydantic model for testing
|
|
11
|
+
class ModelTest(BaseModel):
    """Minimal Pydantic model used as the cached payload throughout these tests."""

    name: str
    value: int
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@pytest.fixture
def model_cache():
    """Provide a fresh, isolated ModelCache instance per test."""
    cache = ModelCache()
    return cache
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def should_skip_test(model_cache):
    """Helper: True when the cache is disabled on this filesystem.

    NOTE(review): not referenced by any visible test — kept for API
    compatibility in case other modules import it.
    """
    enabled = model_cache._enabled
    return not enabled
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def test_path(tmp_path):
    """Create an empty on-disk file to cache against and return its path."""
    file_path = tmp_path / "test_model.kiln"
    file_path.touch()  # the file must exist so stat()/mtime checks work
    return file_path
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def test_set_and_get_model(model_cache, test_path):
    """Round-trip: a model stored for a path is returned on lookup."""
    if not model_cache._enabled:
        pytest.skip("Cache is disabled on this fs")

    stored = ModelTest(name="test", value=123)
    stamp_ns = test_path.stat().st_mtime_ns

    model_cache.set_model(test_path, stored, stamp_ns)
    fetched = model_cache.get_model(test_path, ModelTest)

    assert fetched is not None
    assert fetched.name == "test"
    assert fetched.value == 123
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_invalidate_model(model_cache, test_path):
    """invalidate() must evict the entry cached for a path."""
    model = ModelTest(name="test", value=123)
    # Use st_mtime_ns (int nanoseconds) as the sibling tests and
    # _is_cache_valid do. The float-second st_mtime previously passed here
    # never matched the nanosecond stamp, so the entry was stale from the
    # start and the assertion below passed without exercising invalidate().
    mtime_ns = test_path.stat().st_mtime_ns

    model_cache.set_model(test_path, model, mtime_ns)
    model_cache.invalidate(test_path)
    cached_model = model_cache.get_model(test_path, ModelTest)

    assert cached_model is None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_clear_cache(model_cache, test_path):
    """clear() must drop every cached entry."""
    model = ModelTest(name="test", value=123)
    # Stamp with st_mtime_ns to match the cache's validity check (see
    # test_is_cache_valid_true); the float-second st_mtime used before left
    # the entry always stale, so this test passed even without clear().
    mtime_ns = test_path.stat().st_mtime_ns

    model_cache.set_model(test_path, model, mtime_ns)
    model_cache.clear()
    cached_model = model_cache.get_model(test_path, ModelTest)

    assert cached_model is None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def test_cache_invalid_due_to_mtime_change(model_cache, test_path):
    """A cached entry is dropped once the file's mtime moves on."""
    model = ModelTest(name="test", value=123)
    # Stamp with st_mtime_ns so the entry starts out valid; the float-second
    # st_mtime previously used never matched the nanosecond stamp the cache
    # compares, so the miss below happened for the wrong reason.
    mtime_ns = test_path.stat().st_mtime_ns

    model_cache.set_model(test_path, model, mtime_ns)

    # Simulate a file modification by updating the mtime
    test_path.touch()
    cached_model = model_cache.get_model(test_path, ModelTest)

    assert cached_model is None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def test_get_model_wrong_type(model_cache, test_path):
    """Requesting a cached entry as a different model type raises and evicts it."""
    if not model_cache._enabled:
        pytest.skip("Cache is disabled on this fs")

    class AnotherModel(BaseModel):
        other_field: str

    stored = ModelTest(name="test", value=123)
    stamp_ns = test_path.stat().st_mtime_ns
    model_cache.set_model(test_path, stored, stamp_ns)

    with pytest.raises(ValueError):
        model_cache.get_model(test_path, AnotherModel)

    # The failed typed lookup should have invalidated the entry.
    assert model_cache.get_model(test_path, ModelTest) is None
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_is_cache_valid_true(model_cache, test_path):
    """An entry stamped with the file's current st_mtime_ns is considered valid."""
    mtime_ns = test_path.stat().st_mtime_ns
    assert model_cache._is_cache_valid(test_path, mtime_ns) is True
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def test_is_cache_valid_false_due_to_mtime_change(model_cache, test_path):
    """Touching the file after recording its mtime makes the stamp invalid."""
    if not model_cache._enabled:
        pytest.skip("Cache is disabled on this fs")

    recorded_ns = test_path.stat().st_mtime_ns
    test_path.touch()  # simulate a modification by bumping the mtime
    assert model_cache._is_cache_valid(test_path, recorded_ns) is False
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def test_is_cache_valid_false_due_to_missing_file(model_cache):
    """The validity check is False when the backing file no longer exists."""
    missing = Path("/non/existent/path")
    assert model_cache._is_cache_valid(missing, 0) is False
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def test_benchmark_get_model(benchmark, model_cache, test_path):
    """get_model must stay fast enough for hot-path use (>= 25k ops/sec)."""
    model = ModelTest(name="test", value=123)
    # Use st_mtime_ns: with the float-second st_mtime previously passed the
    # entry was never valid, so the benchmark timed the cache-miss path
    # instead of cache hits.
    mtime_ns = test_path.stat().st_mtime_ns

    # Set the model in the cache
    model_cache.set_model(test_path, model, mtime_ns)

    # Benchmark the get_model method
    def get_model():
        return model_cache.get_model(test_path, ModelTest)

    benchmark(get_model)
    stats = benchmark.stats.stats

    # 25k ops per second is the target. Getting 250k on Macbook, but CI will be slower
    target = 1 / 25000
    if stats.mean > target:
        pytest.fail(
            f"Average time per iteration: {stats.mean}, expected less than {target}"
        )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_get_model_returns_copy(model_cache, test_path):
    """get_model hands out copies, so callers cannot mutate the cached state."""
    if not model_cache._enabled:
        pytest.skip("Cache is disabled on this fs")

    original = ModelTest(name="test", value=123)
    stamp_ns = test_path.stat().st_mtime_ns
    model_cache.set_model(test_path, original, stamp_ns)

    first_copy = model_cache.get_model(test_path, ModelTest)

    # Distinct instance, equal data.
    assert first_copy is not original
    assert first_copy == original

    # Mutating the returned copy must not leak back into the cache.
    first_copy.name = "mutated"
    second_copy = model_cache.get_model(test_path, ModelTest)
    assert second_copy == original
    assert second_copy.name == "test"

    # Explicitly storing the mutated copy does update the cache.
    model_cache.set_model(test_path, first_copy, stamp_ns)
    third_copy = model_cache.get_model(test_path, ModelTest)
    assert third_copy.name == "mutated"
    assert third_copy.value == 123
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_no_cache_when_no_fine_granularity(model_cache, test_path):
    """With the cache disabled, set_model is a no-op and lookups always miss."""
    payload = ModelTest(name="test", value=123)
    stamp_ns = test_path.stat().st_mtime_ns

    model_cache._enabled = False
    model_cache.set_model(test_path, payload, stamp_ns)

    # Nothing was stored and nothing comes back.
    assert model_cache.get_model(test_path, ModelTest) is None
    assert model_cache.model_cache == {}
    assert model_cache._enabled is False
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def test_check_timestamp_granularity_macos():
    """On macOS the cache assumes fine-grained timestamps and stays enabled."""
    with mock.patch("sys.platform", "darwin"):
        mc = ModelCache()
        assert mc._check_timestamp_granularity() is True
        assert mc._enabled is True
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def test_check_timestamp_granularity_windows():
    """On Windows the cache assumes fine-grained timestamps and stays enabled."""
    with mock.patch("sys.platform", "win32"):
        mc = ModelCache()
        assert mc._check_timestamp_granularity() is True
        assert mc._enabled is True
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def test_check_timestamp_granularity_linux_good():
    """Linux reporting nanosecond statvfs timespec keeps the cache enabled."""
    fake_stats = mock.Mock()
    fake_stats.f_timespec = 9  # nanosecond precision

    with mock.patch("sys.platform", "linux"), mock.patch(
        "os.statvfs", return_value=fake_stats
    ):
        mc = ModelCache()
        assert mc._check_timestamp_granularity() is True
        assert mc._enabled is True
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def test_check_timestamp_granularity_linux_poor():
    """Linux reporting coarse statvfs timespec disables the cache."""
    fake_stats = mock.Mock()
    fake_stats.f_timespec = 3  # millisecond precision

    with mock.patch("sys.platform", "linux"), mock.patch(
        "os.statvfs", return_value=fake_stats
    ):
        mc = ModelCache()
        assert mc._check_timestamp_granularity() is False
        assert mc._enabled is False
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def test_check_timestamp_granularity_linux_error():
    """A failing statvfs call disables the cache defensively."""
    with mock.patch("sys.platform", "linux"), mock.patch(
        "os.statvfs", side_effect=OSError("Mock filesystem error")
    ):
        mc = ModelCache()
        assert mc._check_timestamp_granularity() is False
        assert mc._enabled is False
|
kiln_ai/datamodel/test_models.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import os
|
|
3
|
+
from unittest.mock import patch
|
|
2
4
|
|
|
3
5
|
import pytest
|
|
4
6
|
from pydantic import ValidationError
|
|
@@ -82,6 +84,7 @@ def test_task_serialization(test_project_file):
|
|
|
82
84
|
instruction="Test Base Task Instruction",
|
|
83
85
|
thinking_instruction="Test Thinking Instruction",
|
|
84
86
|
)
|
|
87
|
+
assert task._loaded_from_file is False
|
|
85
88
|
|
|
86
89
|
task.save_to_file()
|
|
87
90
|
|
|
@@ -90,6 +93,11 @@ def test_task_serialization(test_project_file):
|
|
|
90
93
|
assert parsed_task.description == "Test Description"
|
|
91
94
|
assert parsed_task.instruction == "Test Base Task Instruction"
|
|
92
95
|
assert parsed_task.thinking_instruction == "Test Thinking Instruction"
|
|
96
|
+
assert parsed_task._loaded_from_file is True
|
|
97
|
+
|
|
98
|
+
# Confirm the local property is not persisted to disk
|
|
99
|
+
json_data = json.loads(parsed_task.path.read_text())
|
|
100
|
+
assert "_loaded_from_file" not in json_data
|
|
93
101
|
|
|
94
102
|
|
|
95
103
|
def test_save_to_file_without_path():
|
|
@@ -315,3 +323,168 @@ def test_finetune_parameters_validation():
|
|
|
315
323
|
base_model_id="gpt-3.5-turbo",
|
|
316
324
|
parameters={"invalid": [1, 2, 3]}, # Lists are not allowed
|
|
317
325
|
)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def test_task_run_input_source_validation(tmp_path):
    """input_source is optional normally, required in strict mode, and
    tolerated as missing when loading pre-existing data from disk."""
    completed_output = TaskOutput(
        output="test output",
        source=DataSource(
            type=DataSourceType.synthetic,
            properties={
                "model_name": "test-model",
                "model_provider": "test-provider",
                "adapter_name": "test-adapter",
            },
        ),
    )

    project = Project(name="Test Project", path=tmp_path / "project.kiln")
    project.save_to_file()
    task = Task(name="Test Task", instruction="Test Instruction", parent=project)
    task.save_to_file()

    # 1) No input_source is accepted while strict mode is off.
    run = TaskRun(
        input="test input",
        output=completed_output,
    )
    run.parent = task
    assert run.input_source is None

    # Persist for the load-from-disk check below.
    run.save_to_file()
    saved_without_source = run.path

    # 2) An explicit input_source is also accepted.
    run = TaskRun(
        input="test input 2",
        input_source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "test-user"},
        ),
        output=completed_output,
    )
    assert run.input_source is not None

    # 3) Strict mode rejects a missing input_source at creation time.
    with patch("kiln_ai.datamodel.strict_mode", return_value=True):
        with pytest.raises(ValueError) as exc_info:
            run = TaskRun(
                input="test input 3",
                output=completed_output,
            )
        assert "input_source is required when strict mode is enabled" in str(
            exc_info.value
        )

        # 4) Loading pre-existing data bypasses the strict requirement.
        assert os.path.exists(saved_without_source)
        run = TaskRun.load_from_file(saved_without_source)
        assert run.input_source is None
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def test_task_output_source_validation(tmp_path):
    """TaskOutput.source is optional normally, required in strict mode, and
    tolerated as missing when loading pre-existing data from disk."""
    synthetic_source = DataSource(
        type=DataSourceType.synthetic,
        properties={
            "model_name": "test-model",
            "model_provider": "test-provider",
            "adapter_name": "test-adapter",
        },
    )

    project = Project(name="Test Project", path=tmp_path / "project.kiln")
    project.save_to_file()
    task = Task(name="Test Task", instruction="Test Instruction", parent=project)
    task.save_to_file()

    # 1) No source is accepted while strict mode is off.
    sourceless_output = TaskOutput(
        output="test output",
    )
    assert sourceless_output.source is None

    # Persist it inside a run for the load-from-disk check below.
    run = TaskRun(
        input="test input",
        input_source=synthetic_source,
        output=sourceless_output,
    )
    run.parent = task
    run.save_to_file()
    saved_without_source = run.path

    # 2) An explicit source is also accepted.
    sourced_output = TaskOutput(
        output="test output 2",
        source=synthetic_source,
    )
    assert sourced_output.source is not None

    # 3) Strict mode rejects a missing source at creation time.
    with patch("kiln_ai.datamodel.strict_mode", return_value=True):
        with pytest.raises(ValueError) as exc_info:
            TaskOutput(
                output="test output 3",
            )
        assert "Output source is required when strict mode is enabled" in str(
            exc_info.value
        )

        # 4) Loading pre-existing data bypasses the strict requirement.
        assert os.path.exists(saved_without_source)
        run = TaskRun.load_from_file(saved_without_source)
        assert run.output.source is None
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def test_task_run_tags_validation():
    """Tags must be non-empty strings containing no spaces."""
    completed_output = TaskOutput(
        output="test output",
        source=DataSource(
            type=DataSourceType.synthetic,
            properties={
                "model_name": "test-model",
                "model_provider": "test-provider",
                "adapter_name": "test-adapter",
            },
        ),
    )

    # Valid tags are stored as given.
    run = TaskRun(
        input="test input",
        output=completed_output,
        tags=["test_tag", "another_tag", "tag123"],
    )
    assert run.tags == ["test_tag", "another_tag", "tag123"]

    # An empty tag list is allowed.
    run = TaskRun(
        input="test input",
        output=completed_output,
        tags=[],
    )
    assert run.tags == []

    # Empty-string tags are rejected.
    with pytest.raises(ValueError) as exc_info:
        TaskRun(
            input="test input",
            output=completed_output,
            tags=["valid_tag", ""],
        )
    assert "Tags cannot be empty strings" in str(exc_info.value)

    # Tags containing spaces are rejected.
    with pytest.raises(ValueError) as exc_info:
        TaskRun(
            input="test input",
            output=completed_output,
            tags=["valid_tag", "invalid tag"],
        )
    assert "Tags cannot contain spaces. Try underscores." in str(exc_info.value)
|