kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +239 -303
- kiln_ai/adapters/ollama_tools.py +115 -0
- kiln_ai/adapters/provider_tools.py +308 -0
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +229 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +531 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +287 -24
- kiln_ai/datamodel/basemodel.py +122 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +167 -4
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +215 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +14 -1
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- kiln_ai-0.7.1.dist-info/RECORD +58 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/METADATA +0 -88
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/test_model_cache.py
ADDED

@@ -0,0 +1,244 @@
+from pathlib import Path
+from unittest import mock
+
+import pytest
+from pydantic import BaseModel
+
+from libs.core.kiln_ai.datamodel.model_cache import ModelCache
+
+
+# Define a simple Pydantic model for testing
+class ModelTest(BaseModel):
+    name: str
+    value: int
+
+
+@pytest.fixture
+def model_cache():
+    return ModelCache()
+
+
+def should_skip_test(model_cache):
+    return not model_cache._enabled
+
+
+@pytest.fixture
+def test_path(tmp_path):
+    # Create a temporary file path for testing
+    test_file = tmp_path / "test_model.kiln"
+    test_file.touch()  # Create the file
+    return test_file
+
+
+def test_set_and_get_model(model_cache, test_path):
+    if not model_cache._enabled:
+        pytest.skip("Cache is disabled on this fs")
+
+    model = ModelTest(name="test", value=123)
+    mtime_ns = test_path.stat().st_mtime_ns
+
+    model_cache.set_model(test_path, model, mtime_ns)
+    cached_model = model_cache.get_model(test_path, ModelTest)
+
+    assert cached_model is not None
+    assert cached_model.name == "test"
+    assert cached_model.value == 123
+
+
+def test_invalidate_model(model_cache, test_path):
+    model = ModelTest(name="test", value=123)
+    mtime = test_path.stat().st_mtime
+
+    model_cache.set_model(test_path, model, mtime)
+    model_cache.invalidate(test_path)
+    cached_model = model_cache.get_model(test_path, ModelTest)
+
+    assert cached_model is None
+
+
+def test_clear_cache(model_cache, test_path):
+    model = ModelTest(name="test", value=123)
+    mtime = test_path.stat().st_mtime
+
+    model_cache.set_model(test_path, model, mtime)
+    model_cache.clear()
+    cached_model = model_cache.get_model(test_path, ModelTest)
+
+    assert cached_model is None
+
+
+def test_cache_invalid_due_to_mtime_change(model_cache, test_path):
+    model = ModelTest(name="test", value=123)
+    mtime = test_path.stat().st_mtime
+
+    model_cache.set_model(test_path, model, mtime)
+
+    # Simulate a file modification by updating the mtime
+    test_path.touch()
+    cached_model = model_cache.get_model(test_path, ModelTest)
+
+    assert cached_model is None
+
+
+def test_get_model_wrong_type(model_cache, test_path):
+    if not model_cache._enabled:
+        pytest.skip("Cache is disabled on this fs")
+
+    class AnotherModel(BaseModel):
+        other_field: str
+
+    model = ModelTest(name="test", value=123)
+    mtime_ns = test_path.stat().st_mtime_ns
+
+    model_cache.set_model(test_path, model, mtime_ns)
+
+    with pytest.raises(ValueError):
+        model_cache.get_model(test_path, AnotherModel)
+
+    # Test that the cache invalidates
+    cached_model = model_cache.get_model(test_path, ModelTest)
+    assert cached_model is None
+
+
+def test_is_cache_valid_true(model_cache, test_path):
+    mtime_ns = test_path.stat().st_mtime_ns
+    assert model_cache._is_cache_valid(test_path, mtime_ns) is True
+
+
+def test_is_cache_valid_false_due_to_mtime_change(model_cache, test_path):
+    if not model_cache._enabled:
+        pytest.skip("Cache is disabled on this fs")
+
+    mtime_ns = test_path.stat().st_mtime_ns
+    # Simulate a file modification by updating the mtime
+    test_path.touch()
+    assert model_cache._is_cache_valid(test_path, mtime_ns) is False
+
+
+def test_is_cache_valid_false_due_to_missing_file(model_cache):
+    non_existent_path = Path("/non/existent/path")
+    assert model_cache._is_cache_valid(non_existent_path, 0) is False
+
+
+def test_benchmark_get_model(benchmark, model_cache, test_path):
+    model = ModelTest(name="test", value=123)
+    mtime = test_path.stat().st_mtime
+
+    # Set the model in the cache
+    model_cache.set_model(test_path, model, mtime)
+
+    # Benchmark the get_model method
+    def get_model():
+        return model_cache.get_model(test_path, ModelTest)
+
+    benchmark(get_model)
+    stats = benchmark.stats.stats
+
+    # 25k ops per second is the target. Getting 250k on Macbook, but CI will be slower
+    target = 1 / 25000
+    if stats.mean > target:
+        pytest.fail(
+            f"Average time per iteration: {stats.mean}, expected less than {target}"
+        )
+
+
+def test_get_model_returns_copy(model_cache, test_path):
+    if not model_cache._enabled:
+        pytest.skip("Cache is disabled on this fs")
+
+    model = ModelTest(name="test", value=123)
+    mtime_ns = test_path.stat().st_mtime_ns
+
+    # Set the model in the cache
+    model_cache.set_model(test_path, model, mtime_ns)
+
+    # Get a copy of the model from the cache
+    cached_model = model_cache.get_model(test_path, ModelTest)
+
+    # Different instance (is), same data (==)
+    assert cached_model is not model
+    assert cached_model == model
+
+    # Mutate the cached model
+    cached_model.name = "mutated"
+
+    # Get the model again from the cache
+    new_cached_model = model_cache.get_model(test_path, ModelTest)
+
+    # Assert that the new cached model has the original values
+    assert new_cached_model == model
+    assert new_cached_model.name == "test"
+
+    # Save the mutated model back to the cache
+    model_cache.set_model(test_path, cached_model, mtime_ns)
+
+    # Get the model again from the cache
+    updated_cached_model = model_cache.get_model(test_path, ModelTest)
+
+    # Assert that the updated cached model has the mutated values
+    assert updated_cached_model.name == "mutated"
+    assert updated_cached_model.value == 123
+
+
+def test_no_cache_when_no_fine_granularity(model_cache, test_path):
+    model = ModelTest(name="test", value=123)
+    mtime_ns = test_path.stat().st_mtime_ns
+
+    model_cache._enabled = False
+    model_cache.set_model(test_path, model, mtime_ns)
+    cached_model = model_cache.get_model(test_path, ModelTest)
+
+    # Assert that the model is not cached
+    assert cached_model is None
+    assert model_cache.model_cache == {}
+    assert model_cache._enabled is False
+
+
+def test_check_timestamp_granularity_macos():
+    with mock.patch("sys.platform", "darwin"):
+        cache = ModelCache()
+        assert cache._check_timestamp_granularity() is True
+        assert cache._enabled is True
+
+
+def test_check_timestamp_granularity_windows():
+    with mock.patch("sys.platform", "win32"):
+        cache = ModelCache()
+        assert cache._check_timestamp_granularity() is True
+        assert cache._enabled is True
+
+
+def test_check_timestamp_granularity_linux_good():
+    mock_stats = mock.Mock()
+    mock_stats.f_timespec = 9  # nanosecond precision
+
+    with (
+        mock.patch("sys.platform", "linux"),
+        mock.patch("os.statvfs", return_value=mock_stats),
+    ):
+        cache = ModelCache()
+        assert cache._check_timestamp_granularity() is True
+        assert cache._enabled is True
+
+
+def test_check_timestamp_granularity_linux_poor():
+    mock_stats = mock.Mock()
+    mock_stats.f_timespec = 3  # millisecond precision
+
+    with (
+        mock.patch("sys.platform", "linux"),
+        mock.patch("os.statvfs", return_value=mock_stats),
+    ):
+        cache = ModelCache()
+        assert cache._check_timestamp_granularity() is False
+        assert cache._enabled is False
+
+
+def test_check_timestamp_granularity_linux_error():
+    with (
+        mock.patch("sys.platform", "linux"),
+        mock.patch("os.statvfs", side_effect=OSError("Mock filesystem error")),
+    ):
+        cache = ModelCache()
+        assert cache._check_timestamp_granularity() is False
+        assert cache._enabled is False
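The module under test, kiln_ai/datamodel/model_cache.py (+116 in this release), is not shown in the diff. Purely as a reading aid, and not the package's actual code, a minimal cache consistent with the behavior these tests pin down (mtime-keyed entries, deep copies on read, a ValueError when the wrong type is requested, and disabling itself on filesystems without nanosecond mtime granularity) might look roughly like the sketch below; the class name ModelCacheSketch and all internal details are assumptions.

```python
# Hypothetical sketch only; the shipped class lives in kiln_ai/datamodel/model_cache.py.
import os
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple, Type, TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)


class ModelCacheSketch:
    def __init__(self):
        # path -> (cached model, file mtime at the time it was cached)
        self.model_cache: Dict[Path, Tuple[BaseModel, int]] = {}
        self._enabled = self._check_timestamp_granularity()

    def set_model(self, path: Path, model: BaseModel, mtime_ns: int) -> None:
        if not self._enabled:
            return  # skip caching entirely on coarse-grained filesystems
        self.model_cache[path] = (model, mtime_ns)

    def get_model(self, path: Path, model_type: Type[T]) -> Optional[T]:
        entry = self.model_cache.get(path)
        if entry is None:
            return None
        model, cached_mtime_ns = entry
        if not self._is_cache_valid(path, cached_mtime_ns):
            self.invalidate(path)
            return None
        if not isinstance(model, model_type):
            # Wrong type requested: drop the entry and raise, as the tests expect
            self.invalidate(path)
            raise ValueError(f"Cached model at {path} is not a {model_type.__name__}")
        # Hand back a deep copy so callers can't mutate the cached instance
        return model.model_copy(deep=True)

    def invalidate(self, path: Path) -> None:
        self.model_cache.pop(path, None)

    def clear(self) -> None:
        self.model_cache.clear()

    def _is_cache_valid(self, path: Path, cached_mtime_ns: int) -> bool:
        try:
            return path.stat().st_mtime_ns == cached_mtime_ns
        except OSError:
            return False

    def _check_timestamp_granularity(self) -> bool:
        # Only trust mtime comparisons where nanosecond precision is available
        if sys.platform in ("darwin", "win32"):
            return True
        try:
            # f_timespec mirrors what the tests mock on os.statvfs; 9 == nanoseconds
            return os.statvfs("/").f_timespec >= 9
        except (OSError, AttributeError):
            return False
```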
kiln_ai/datamodel/test_models.py
CHANGED
@@ -1,4 +1,6 @@
 import json
+import os
+from unittest.mock import patch
 
 import pytest
 from pydantic import ValidationError

@@ -6,6 +8,7 @@ from pydantic import ValidationError
 from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,
+    Finetune,
     Project,
     Task,
     TaskOutput,

@@ -27,7 +30,7 @@ def test_project_file(tmp_path):
 
 @pytest.fixture
 def test_task_file(tmp_path):
-    test_file_path = tmp_path / "task.
+    test_file_path = tmp_path / "task.kiln"
     data = {
         "v": 1,
         "name": "Test Task",

@@ -81,6 +84,7 @@ def test_task_serialization(test_project_file):
         instruction="Test Base Task Instruction",
         thinking_instruction="Test Thinking Instruction",
     )
+    assert task._loaded_from_file is False
 
     task.save_to_file()
 

@@ -89,6 +93,11 @@ def test_task_serialization(test_project_file):
     assert parsed_task.description == "Test Description"
     assert parsed_task.instruction == "Test Base Task Instruction"
     assert parsed_task.thinking_instruction == "Test Thinking Instruction"
+    assert parsed_task._loaded_from_file is True
+
+    # Confirm the local property is not persisted to disk
+    json_data = json.loads(parsed_task.path.read_text())
+    assert "_loaded_from_file" not in json_data
 
 
 def test_save_to_file_without_path():

@@ -225,3 +234,208 @@ def test_task_run_intermediate_outputs():
         "cot": "chain of thought output",
         "draft": "draft output",
     }
+
+
+def test_finetune_basic():
+    # Test basic initialization
+    finetune = Finetune(
+        name="test-finetune",
+        provider="openai",
+        base_model_id="gpt-3.5-turbo",
+        dataset_split_id="dataset-123",
+        train_split_name="train",
+        system_message="Test system message",
+    )
+    assert finetune.name == "test-finetune"
+    assert finetune.provider == "openai"
+    assert finetune.base_model_id == "gpt-3.5-turbo"
+    assert finetune.dataset_split_id == "dataset-123"
+    assert finetune.train_split_name == "train"
+    assert finetune.provider_id is None
+    assert finetune.parameters == {}
+    assert finetune.description is None
+
+
+def test_finetune_full():
+    # Test with all fields populated
+    finetune = Finetune(
+        name="test-finetune",
+        description="Test description",
+        provider="openai",
+        base_model_id="gpt-3.5-turbo",
+        provider_id="ft-abc123",
+        dataset_split_id="dataset-123",
+        train_split_name="train",
+        system_message="Test system message",
+        parameters={
+            "epochs": 3,
+            "learning_rate": 0.1,
+            "batch_size": 4,
+            "use_fp16": True,
+            "model_suffix": "-v1",
+        },
+    )
+    assert finetune.description == "Test description"
+    assert finetune.provider_id == "ft-abc123"
+    assert finetune.parameters == {
+        "epochs": 3,
+        "learning_rate": 0.1,
+        "batch_size": 4,
+        "use_fp16": True,
+        "model_suffix": "-v1",
+    }
+    assert finetune.system_message == "Test system message"
+
+
+def test_finetune_parent_task():
+    # Test parent_task() method
+    task = Task(name="Test Task", instruction="Test instruction")
+    finetune = Finetune(
+        name="test-finetune",
+        provider="openai",
+        base_model_id="gpt-3.5-turbo",
+        parent=task,
+        dataset_split_id="dataset-123",
+        train_split_name="train",
+        system_message="Test system message",
+    )
+
+    assert finetune.parent_task() == task
+
+    # Test with no parent
+    finetune_no_parent = Finetune(
+        name="test-finetune",
+        provider="openai",
+        base_model_id="gpt-3.5-turbo",
+        dataset_split_id="dataset-123",
+        train_split_name="train",
+        system_message="Test system message",
+    )
+    assert finetune_no_parent.parent_task() is None
+
+
+def test_finetune_parameters_validation():
+    # Test that parameters only accept valid types
+    with pytest.raises(ValidationError):
+        Finetune(
+            name="test-finetune",
+            provider="openai",
+            base_model_id="gpt-3.5-turbo",
+            parameters={"invalid": [1, 2, 3]},  # Lists are not allowed
+        )
+
+
+def test_task_run_input_source_validation(tmp_path):
+    # Setup basic output for TaskRun creation
+    output = TaskOutput(
+        output="test output",
+        source=DataSource(
+            type=DataSourceType.synthetic,
+            properties={
+                "model_name": "test-model",
+                "model_provider": "test-provider",
+                "adapter_name": "test-adapter",
+            },
+        ),
+    )
+
+    project_path = tmp_path / "project.kiln"
+    project = Project(name="Test Project", path=project_path)
+    project.save_to_file()
+    task = Task(name="Test Task", instruction="Test Instruction", parent=project)
+    task.save_to_file()
+
+    # Test 1: Creating without input_source should work when strict mode is off
+    task_run = TaskRun(
+        input="test input",
+        output=output,
+    )
+    task_run.parent = task
+    assert task_run.input_source is None
+
+    # Save for later usage
+    task_run.save_to_file()
+    task_missing_input_source = task_run.path
+
+    # Test 2: Creating with input_source should work when strict mode is off
+    task_run = TaskRun(
+        input="test input 2",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "test-user"},
+        ),
+        output=output,
+    )
+    assert task_run.input_source is not None
+
+    # Test 3: Creating without input_source should fail when strict mode is on
+    with patch("kiln_ai.datamodel.strict_mode", return_value=True):
+        with pytest.raises(ValueError) as exc_info:
+            task_run = TaskRun(
+                input="test input 3",
+                output=output,
+            )
+        assert "input_source is required when strict mode is enabled" in str(
+            exc_info.value
+        )
+
+    # Test 4: Loading from disk should work without input_source, even with strict mode on
+    assert os.path.exists(task_missing_input_source)
+    task_run = TaskRun.load_from_file(task_missing_input_source)
+    assert task_run.input_source is None
+
+
+def test_task_output_source_validation(tmp_path):
+    # Setup basic output source for validation
+    output_source = DataSource(
+        type=DataSourceType.synthetic,
+        properties={
+            "model_name": "test-model",
+            "model_provider": "test-provider",
+            "adapter_name": "test-adapter",
+        },
+    )
+
+    project_path = tmp_path / "project.kiln"
+    project = Project(name="Test Project", path=project_path)
+    project.save_to_file()
+    task = Task(name="Test Task", instruction="Test Instruction", parent=project)
+    task.save_to_file()
+
+    # Test 1: Creating without source should work when strict mode is off
+    task_output = TaskOutput(
+        output="test output",
+    )
+    assert task_output.source is None
+
+    # Save for later usage
+    task_run = TaskRun(
+        input="test input",
+        input_source=output_source,
+        output=task_output,
+    )
+    task_run.parent = task
+    task_run.save_to_file()
+    task_missing_output_source = task_run.path
+
+    # Test 2: Creating with source should work when strict mode is off
+    task_output = TaskOutput(
+        output="test output 2",
+        source=output_source,
+    )
+    assert task_output.source is not None
+
+    # Test 3: Creating without source should fail when strict mode is on
+    with patch("kiln_ai.datamodel.strict_mode", return_value=True):
+        with pytest.raises(ValueError) as exc_info:
+            task_output = TaskOutput(
+                output="test output 3",
+            )
+        assert "Output source is required when strict mode is enabled" in str(
+            exc_info.value
+        )
+
+    # Test 4: Loading from disk should work without source, even with strict mode on
+    assert os.path.exists(task_missing_output_source)
+    task_run = TaskRun.load_from_file(task_missing_output_source)
+    assert task_run.output.source is None
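The two validation tests above patch kiln_ai.datamodel.strict_mode so that input_source/source become mandatory for newly constructed records, while data loaded from disk still parses. As an illustration only (the names and validator body below are assumptions, not the package's code), the pattern they exercise is roughly a Pydantic after-validator gated on that flag:

```python
# Illustrative only: kiln_ai.datamodel defines its own strict_mode() and models.
from typing import Optional

from pydantic import BaseModel, model_validator

_strict = False


def strict_mode() -> bool:
    # The real function reads a process-wide setting; the tests patch it to True.
    return _strict


class OutputSketch(BaseModel):
    output: str
    source: Optional[dict] = None  # stand-in for the real DataSource model

    @model_validator(mode="after")
    def require_source_when_strict(self) -> "OutputSketch":
        # Newly built records must carry a source when strict mode is on.
        # The real models also exempt instances loaded from existing files
        # (see "Test 4" above), which this sketch omits.
        if strict_mode() and self.source is None:
            raise ValueError("Output source is required when strict mode is enabled")
        return self
```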
kiln_ai/datamodel/test_registry.py
ADDED

@@ -0,0 +1,96 @@
+from unittest.mock import Mock, patch
+
+import pytest
+
+from kiln_ai.datamodel import Project
+from kiln_ai.datamodel.registry import all_projects, project_from_id
+
+
+@pytest.fixture
+def mock_config():
+    with patch("kiln_ai.datamodel.registry.Config") as mock:
+        config_instance = Mock()
+        mock.shared.return_value = config_instance
+        yield config_instance
+
+
+@pytest.fixture
+def mock_project():
+    def create_mock_project(project_id: str = "test-id"):
+        project = Mock(spec=Project)
+        project.id = project_id
+        return project
+
+    return create_mock_project
+
+
+def test_all_projects_empty(mock_config):
+    mock_config.projects = None
+    assert all_projects() == []
+
+
+def test_all_projects_success(mock_config, mock_project):
+    mock_config.projects = ["path1", "path2"]
+
+    project1 = mock_project("project1")
+    project2 = mock_project("project2")
+
+    with patch("kiln_ai.datamodel.Project.load_from_file") as mock_load:
+        mock_load.side_effect = [project1, project2]
+
+        result = all_projects()
+
+        assert len(result) == 2
+        assert result[0] == project1
+        assert result[1] == project2
+        mock_load.assert_any_call("path1")
+        mock_load.assert_any_call("path2")
+
+
+def test_all_projects_with_errors(mock_config, mock_project):
+    mock_config.projects = ["path1", "path2", "path3"]
+
+    project1 = mock_project("project1")
+    project3 = mock_project("project3")
+
+    with patch("kiln_ai.datamodel.Project.load_from_file") as mock_load:
+        mock_load.side_effect = [project1, Exception("File not found"), project3]
+
+        result = all_projects()
+
+        assert len(result) == 2
+        assert result[0] == project1
+        assert result[1] == project3
+
+
+def test_project_from_id_not_found(mock_config):
+    mock_config.projects = None
+    assert project_from_id("any-id") is None
+
+
+def test_project_from_id_success(mock_config, mock_project):
+    mock_config.projects = ["path1", "path2"]
+
+    project1 = mock_project("project1")
+    project2 = mock_project("project2")
+
+    with patch("kiln_ai.datamodel.Project.load_from_file") as mock_load:
+        mock_load.side_effect = [project1, project2]
+
+        result = project_from_id("project2")
+
+        assert result == project2
+
+
+def test_project_from_id_with_errors(mock_config, mock_project):
+    mock_config.projects = ["path1", "path2", "path3"]
+
+    project1 = mock_project("project1")
+    project3 = mock_project("target-id")
+
+    with patch("kiln_ai.datamodel.Project.load_from_file") as mock_load:
+        mock_load.side_effect = [project1, Exception("File not found"), project3]
+
+        result = project_from_id("target-id")
+
+        assert result == project3
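These tests pin down the behavior of the new kiln_ai/datamodel/registry.py helpers (+31 in this release): read project paths from Config.shared(), skip entries that fail to load, and look projects up by id. A sketch of functions consistent with the tests, assuming the shipped module imports Config and Project the way the patch targets suggest (not the actual shipped code), might be:

```python
# Sketch only; the shipped helpers live in kiln_ai/datamodel/registry.py.
from typing import List, Optional

from kiln_ai.datamodel import Project
from kiln_ai.utils.config import Config


def all_projects() -> List[Project]:
    project_paths = Config.shared().projects
    if project_paths is None:
        return []
    projects = []
    for path in project_paths:
        try:
            projects.append(Project.load_from_file(path))
        except Exception:
            # A missing or corrupt project file should not break the rest
            continue
    return projects


def project_from_id(project_id: str) -> Optional[Project]:
    for project in all_projects():
        if project.id == project_id:
            return project
    return None
```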
kiln_ai/utils/config.py
CHANGED
@@ -67,10 +67,23 @@ class Config:
                 env_var="OPENROUTER_API_KEY",
                 sensitive=True,
             ),
+            "fireworks_api_key": ConfigProperty(
+                str,
+                env_var="FIREWORKS_API_KEY",
+                sensitive=True,
+            ),
+            "fireworks_account_id": ConfigProperty(
+                str,
+                env_var="FIREWORKS_ACCOUNT_ID",
+            ),
             "projects": ConfigProperty(
                 list,
                 default_lambda=lambda: [],
             ),
+            "custom_models": ConfigProperty(
+                list,
+                default_lambda=lambda: [],
+            ),
         }
         self._settings = self.load_settings()

@@ -136,7 +149,7 @@ class Config:
             settings = yaml.safe_load(f.read()) or {}
         return settings
 
-    def settings(self, hide_sensitive=False):
+    def settings(self, hide_sensitive=False) -> Dict[str, Any]:
         if hide_sensitive:
             return {
                 k: "[hidden]"
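The config.py hunk registers Fireworks credentials (fireworks_api_key, fireworks_account_id) and a custom_models list as ConfigProperty entries, and adds a return type to settings(). For orientation only, the sketch below is an assumed, simplified reading of how such a property is typically resolved (saved settings first, then the declared env_var, then the default factory); it is not the package's ConfigProperty or Config implementation.

```python
# Assumed, simplified lookup order for illustration; see kiln_ai/utils/config.py
# for the real ConfigProperty and Config classes.
import os
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ConfigPropertySketch:
    type_: type
    env_var: Optional[str] = None
    sensitive: bool = False
    default_lambda: Optional[Callable[[], Any]] = None


def resolve(name: str, prop: ConfigPropertySketch, settings: dict) -> Any:
    if name in settings:
        return settings[name]  # 1. value saved in settings.yaml wins
    if prop.env_var and prop.env_var in os.environ:
        return prop.type_(os.environ[prop.env_var])  # 2. fall back to the env var
    if prop.default_lambda is not None:
        return prop.default_lambda()  # 3. default factory (e.g. [] for custom_models)
    return None


# Example: FIREWORKS_API_KEY is picked up from the environment when unset in settings
api_key = resolve(
    "fireworks_api_key",
    ConfigPropertySketch(str, env_var="FIREWORKS_API_KEY", sensitive=True),
    settings={},
)
```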