kiln-ai 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of kiln-ai might be problematic.
- kiln_ai/adapters/adapter_registry.py +2 -0
- kiln_ai/adapters/base_adapter.py +6 -1
- kiln_ai/adapters/langchain_adapters.py +5 -1
- kiln_ai/adapters/ml_model_list.py +43 -12
- kiln_ai/adapters/ollama_tools.py +4 -3
- kiln_ai/adapters/provider_tools.py +63 -2
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/test_langchain_adapter.py +183 -0
- kiln_ai/adapters/test_provider_tools.py +315 -1
- kiln_ai/datamodel/__init__.py +162 -19
- kiln_ai/datamodel/basemodel.py +90 -42
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/test_basemodel.py +138 -3
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +173 -0
- kiln_ai/datamodel/test_output_rating.py +377 -10
- kiln_ai/utils/config.py +33 -10
- kiln_ai/utils/test_config.py +48 -0
- kiln_ai-0.8.0.dist-info/METADATA +237 -0
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/RECORD +23 -21
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/WHEEL +1 -1
- kiln_ai-0.7.0.dist-info/METADATA +0 -90
- {kiln_ai-0.7.0.dist-info → kiln_ai-0.8.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/test_output_rating.py
CHANGED
@@ -1,14 +1,34 @@
+import json
+
 import pytest
 from pydantic import ValidationError
 
-from kiln_ai.datamodel import TaskOutputRating, TaskOutputRatingType
+from kiln_ai.datamodel import RequirementRating, TaskOutputRating, TaskOutputRatingType
 
 
 def test_valid_task_output_rating():
     rating = TaskOutputRating(value=4.0, requirement_ratings={"req1": 5.0, "req2": 3.0})
     assert rating.type == TaskOutputRatingType.five_star
     assert rating.value == 4.0
-
+    dumped = json.loads(rating.model_dump_json())
+    assert dumped["requirement_ratings"] == {
+        "req1": {"type": TaskOutputRatingType.five_star, "value": 5.0},
+        "req2": {"type": TaskOutputRatingType.five_star, "value": 3.0},
+    }
+
+    # new format
+    rating = TaskOutputRating(
+        value=4.0,
+        requirement_ratings={
+            "req1": {"type": TaskOutputRatingType.five_star, "value": 5.0},
+            "req2": {"type": TaskOutputRatingType.five_star, "value": 3.0},
+        },
+    )
+    dumped = json.loads(rating.model_dump_json())
+    assert dumped["requirement_ratings"] == {
+        "req1": {"type": TaskOutputRatingType.five_star, "value": 5.0},
+        "req2": {"type": TaskOutputRatingType.five_star, "value": 3.0},
+    }
 
 
 def test_invalid_rating_type():
@@ -40,34 +60,92 @@ def test_rating_below_range():
         TaskOutputRating(value=0.0)
 
 
-def
-    rating = TaskOutputRating(
-        value
+def test_valid_requirement_ratings_old_format():
+    rating = TaskOutputRating.model_validate(
+        {"value": 4.0, "requirement_ratings": {"req1": 5.0, "req2": 3.0, "req3": 1.0}}
+    )
+    dumped = json.loads(rating.model_dump_json())
+    assert dumped["requirement_ratings"] == {
+        "req1": {"type": TaskOutputRatingType.five_star, "value": 5.0},
+        "req2": {"type": TaskOutputRatingType.five_star, "value": 3.0},
+        "req3": {"type": TaskOutputRatingType.five_star, "value": 1.0},
+    }
+
+
+def test_valid_requirement_ratings_new_format():
+    rating = TaskOutputRating.model_validate(
+        {
+            "value": 4.0,
+            "requirement_ratings": {
+                "req1": {"type": TaskOutputRatingType.five_star, "value": 5.0},
+                "req2": {"type": TaskOutputRatingType.five_star, "value": 3.0},
+                "req3": {"type": TaskOutputRatingType.five_star, "value": 1.0},
+            },
+        }
     )
-
+    dumped = json.loads(rating.model_dump_json())
+    assert dumped["requirement_ratings"] == {
+        "req1": {"type": TaskOutputRatingType.five_star, "value": 5.0},
+        "req2": {"type": TaskOutputRatingType.five_star, "value": 3.0},
+        "req3": {"type": TaskOutputRatingType.five_star, "value": 1.0},
+    }
 
 
 def test_invalid_requirement_rating_value():
     with pytest.raises(
         ValidationError,
-        match="Requirement rating for req1 of type five_star must be an integer value",
+        match="Requirement rating for req id: req1 of type five_star must be an integer value",
     ):
         TaskOutputRating(value=4.0, requirement_ratings={"req1": 3.5})
 
+    # new format
+    with pytest.raises(
+        ValidationError,
+        match="Requirement rating for req id: req1 of type five_star must be an integer value",
+    ):
+        TaskOutputRating(
+            value=4.0,
+            requirement_ratings={
+                "req1": {"type": TaskOutputRatingType.five_star, "value": 3.5}
+            },
+        )
+
 
 def test_requirement_rating_out_of_range():
     with pytest.raises(
         ValidationError,
-        match="Requirement rating for req1 of type five_star must be between 1 and 5 stars",
+        match="Requirement rating for req id: req1 of type five_star must be between 1 and 5 stars",
    ):
         TaskOutputRating(value=4.0, requirement_ratings={"req1": 6.0})
 
+    # new format
+    with pytest.raises(
+        ValidationError,
+        match="Requirement rating for req id: req1 of type five_star must be between 1 and 5 stars",
+    ):
+        TaskOutputRating(
+            value=4.0,
+            requirement_ratings={
+                "req1": {"type": TaskOutputRatingType.five_star, "value": 6.0}
+            },
+        )
+
 
 def test_empty_requirement_ratings():
     rating = TaskOutputRating(value=4.0)
     assert rating.requirement_ratings == {}
 
 
+def test_empty_requirement_ratings_integer():
+    rating = TaskOutputRating(
+        value=4,
+        requirement_ratings={
+            "req1": RequirementRating(type=TaskOutputRatingType.five_star, value=5),
+        },
+    )
+    assert rating.requirement_ratings["req1"].value == 5.0
+
+
 def test_invalid_id_type():
     with pytest.raises(ValidationError):
         TaskOutputRating(
@@ -77,13 +155,302 @@ def test_invalid_id_type():
             },
         )
 
+    # new format
+    with pytest.raises(ValidationError):
+        TaskOutputRating(
+            value=4.0,
+            requirement_ratings={
+                123: {"type": TaskOutputRatingType.five_star, "value": 4.0}
+            },
+        )
+
 
 def test_valid_custom_rating():
     rating = TaskOutputRating(
         type=TaskOutputRatingType.custom,
         value=31.459,
-        requirement_ratings={
+        requirement_ratings={
+            "req1": {"type": TaskOutputRatingType.custom, "value": 42.0},
+            "req2": {"type": TaskOutputRatingType.custom, "value": 3.14},
+        },
     )
     assert rating.type == TaskOutputRatingType.custom
     assert rating.value == 31.459
-
+    dumped = json.loads(rating.model_dump_json())
+    assert dumped["requirement_ratings"] == {
+        "req1": {"type": TaskOutputRatingType.custom, "value": 42.0},
+        "req2": {"type": TaskOutputRatingType.custom, "value": 3.14},
+    }
+
+
+# We upgraded the format of requirement_ratings to be a dict of RequirementRating objects from a dict of floats
+def test_task_output_rating_format_upgrade():
+    # Test old format (dict of floats)
+    old_format = {
+        "type": "five_star",
+        "value": 4.0,
+        "requirement_ratings": {"req1": 5.0, "req2": 3.0},
+    }
+
+    rating = TaskOutputRating.model_validate(old_format)
+
+    # Verify the upgrade worked
+    assert isinstance(rating.requirement_ratings["req1"], RequirementRating)
+    assert rating.requirement_ratings["req1"].value == 5.0
+    assert rating.requirement_ratings["req1"].type == TaskOutputRatingType.five_star
+    assert rating.requirement_ratings["req2"].value == 3.0
+    assert rating.requirement_ratings["req2"].type == TaskOutputRatingType.five_star
+
+    # Verify the json dump is new format
+    json_dump = json.loads(rating.model_dump_json())
+    assert json_dump["requirement_ratings"]["req1"]["type"] == "five_star"
+    assert json_dump["requirement_ratings"]["req1"]["value"] == 5.0
+    assert json_dump["requirement_ratings"]["req2"]["type"] == "five_star"
+    assert json_dump["requirement_ratings"]["req2"]["value"] == 3.0
+
+    # Test new format (dict of RequirementRating)
+    new_format = {
+        "type": "five_star",
+        "value": 4.0,
+        "requirement_ratings": {
+            "req1": {"value": 5.0, "type": "five_star"},
+            "req2": {"value": 3.0, "type": "five_star"},
+        },
+    }
+
+    rating = TaskOutputRating.model_validate(new_format)
+
+    # Verify new format works as expected
+    assert isinstance(rating.requirement_ratings["req1"], RequirementRating)
+    assert rating.requirement_ratings["req1"].value == 5.0
+    assert rating.requirement_ratings["req1"].type == TaskOutputRatingType.five_star
+
+    # Verify the json dump is new format
+    json_dump = json.loads(rating.model_dump_json())
+    assert json_dump["requirement_ratings"]["req1"]["type"] == "five_star"
+    assert json_dump["requirement_ratings"]["req1"]["value"] == 5.0
+    assert json_dump["requirement_ratings"]["req2"]["type"] == "five_star"
+    assert json_dump["requirement_ratings"]["req2"]["value"] == 3.0
+
+    # Test mixed format (should fail)
+    mixed_format = {
+        "type": "five_star",
+        "value": 4.0,
+        "requirement_ratings": {
+            "req1": 5.0,
+            "req2": {"value": 3.0, "type": "five_star"},
+        },
+    }
+
+    with pytest.raises(ValidationError):
+        TaskOutputRating.model_validate(mixed_format)
+
+    # Test empty requirement_ratings
+    empty_format = {"type": "five_star", "value": 4.0, "requirement_ratings": {}}
+
+    rating = TaskOutputRating.model_validate(empty_format)
+    assert rating.requirement_ratings == {}
+
+
+def test_valid_pass_fail_rating():
+    rating = TaskOutputRating(
+        type=TaskOutputRatingType.pass_fail,
+        value=1.0,
+        requirement_ratings={
+            "req1": {"type": TaskOutputRatingType.pass_fail, "value": 1.0},
+            "req2": {"type": TaskOutputRatingType.pass_fail, "value": 0.0},
+        },
+    )
+    assert rating.type == TaskOutputRatingType.pass_fail
+    assert rating.value == 1.0
+    dumped = json.loads(rating.model_dump_json())
+    assert dumped["requirement_ratings"] == {
+        "req1": {"type": TaskOutputRatingType.pass_fail, "value": 1.0},
+        "req2": {"type": TaskOutputRatingType.pass_fail, "value": 0.0},
+    }
+
+
+def test_invalid_pass_fail_rating_value():
+    with pytest.raises(
+        ValidationError,
+        match="Overall rating of type pass_fail must be an integer value",
+    ):
+        TaskOutputRating(type=TaskOutputRatingType.pass_fail, value=0.5)
+
+    with pytest.raises(
+        ValidationError,
+        match="Requirement rating for req id: req1 of type pass_fail must be an integer value",
+    ):
+        TaskOutputRating(
+            type=TaskOutputRatingType.pass_fail,
+            value=1.0,
+            requirement_ratings={
+                "req1": {"type": TaskOutputRatingType.pass_fail, "value": 0.5}
+            },
+        )
+
+
+def test_pass_fail_rating_out_of_range():
+    with pytest.raises(
+        ValidationError,
+        match="Overall rating of type pass_fail must be 0 \\(fail\\) or 1 \\(pass\\)",
+    ):
+        TaskOutputRating(type=TaskOutputRatingType.pass_fail, value=2.0)
+
+    with pytest.raises(
+        ValidationError,
+        match="Requirement rating for req id: req1 of type pass_fail must be 0 \\(fail\\) or 1 \\(pass\\)",
+    ):
+        TaskOutputRating(
+            type=TaskOutputRatingType.pass_fail,
+            value=1.0,
+            requirement_ratings={
+                "req1": {"type": TaskOutputRatingType.pass_fail, "value": 2.0}
+            },
+        )
+
+
+def test_valid_pass_fail_critical_rating():
+    rating = TaskOutputRating(
+        type=TaskOutputRatingType.pass_fail_critical,
+        value=1.0,
+        requirement_ratings={
+            "req1": {"type": TaskOutputRatingType.pass_fail_critical, "value": 1.0},
+            "req2": {"type": TaskOutputRatingType.pass_fail_critical, "value": 0.0},
+            "req3": {"type": TaskOutputRatingType.pass_fail_critical, "value": -1.0},
+        },
+    )
+    assert rating.type == TaskOutputRatingType.pass_fail_critical
+    assert rating.value == 1.0
+    dumped = json.loads(rating.model_dump_json())
+    assert dumped["requirement_ratings"] == {
+        "req1": {"type": TaskOutputRatingType.pass_fail_critical, "value": 1.0},
+        "req2": {"type": TaskOutputRatingType.pass_fail_critical, "value": 0.0},
+        "req3": {"type": TaskOutputRatingType.pass_fail_critical, "value": -1.0},
+    }
+
+
+def test_invalid_pass_fail_critical_rating_value():
+    with pytest.raises(
+        ValidationError,
+        match="Overall rating of type pass_fail_critical must be an integer value",
+    ):
+        TaskOutputRating(type=TaskOutputRatingType.pass_fail_critical, value=0.5)
+
+    with pytest.raises(
+        ValidationError,
+        match="Requirement rating for req id: req1 of type pass_fail_critical must be an integer value",
+    ):
+        TaskOutputRating(
+            type=TaskOutputRatingType.pass_fail_critical,
+            value=1.0,
+            requirement_ratings={
+                "req1": {"type": TaskOutputRatingType.pass_fail_critical, "value": 0.5}
+            },
+        )
+
+
+def test_pass_fail_critical_rating_out_of_range():
+    with pytest.raises(
+        ValidationError,
+        match="Overall rating of type pass_fail_critical must be -1 \\(critical fail\\), 0 \\(fail\\), or 1 \\(pass\\)",
+    ):
+        TaskOutputRating(type=TaskOutputRatingType.pass_fail_critical, value=2.0)
+
+    with pytest.raises(
+        ValidationError,
+        match="Requirement rating for req id: req1 of type pass_fail_critical must be -1 \\(critical fail\\), 0 \\(fail\\), or 1 \\(pass\\)",
+    ):
+        TaskOutputRating(
+            type=TaskOutputRatingType.pass_fail_critical,
+            value=1.0,
+            requirement_ratings={
+                "req1": {"type": TaskOutputRatingType.pass_fail_critical, "value": 2.0}
+            },
+        )
+
+
+def test_is_high_quality():
+    # Test five_star ratings
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.five_star, value=5.0
+        ).is_high_quality()
+        is True
+    )
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.five_star, value=4.0
+        ).is_high_quality()
+        is True
+    )
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.five_star, value=3.0
+        ).is_high_quality()
+        is False
+    )
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.five_star, value=2.0
+        ).is_high_quality()
+        is False
+    )
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.five_star, value=1.0
+        ).is_high_quality()
+        is False
+    )
+
+    # Test pass_fail ratings
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.pass_fail, value=1.0
+        ).is_high_quality()
+        is True
+    )
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.pass_fail, value=0.0
+        ).is_high_quality()
+        is False
+    )
+
+    # Test pass_fail_critical ratings
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.pass_fail_critical, value=1.0
+        ).is_high_quality()
+        is True
+    )
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.pass_fail_critical, value=0.0
+        ).is_high_quality()
+        is False
+    )
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.pass_fail_critical, value=-1.0
+        ).is_high_quality()
+        is False
+    )
+
+    # Test custom ratings (should always return False)
+    assert (
+        TaskOutputRating(
+            type=TaskOutputRatingType.custom, value=100.0
+        ).is_high_quality()
+        is False
+    )
+    assert (
+        TaskOutputRating(type=TaskOutputRatingType.custom, value=0.0).is_high_quality()
+        is False
+    )
+
+    # Test None value
+    assert (
+        TaskOutputRating(type=TaskOutputRatingType.custom, value=None).is_high_quality()
+        is False
+    )
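The test additions above center on one schema change: requirement_ratings moves from a plain mapping of requirement id to float to a mapping of id to RequirementRating objects, while the legacy float form is still accepted on load. The datamodel code itself (kiln_ai/datamodel/__init__.py, +162 -19 in the file list) is not shown in this diff, so the following is only a rough sketch of how such a backward-compatible upgrade can be done with a Pydantic "before" validator; the simplified models and the upgrade_legacy_floats helper are illustrative assumptions, not the package's actual implementation.

# Hypothetical sketch only: names mirror the tests above, but the real
# kiln_ai datamodel is not part of this excerpt.
from enum import Enum
from typing import Dict, Optional

from pydantic import BaseModel, field_validator


class TaskOutputRatingType(str, Enum):
    five_star = "five_star"
    pass_fail = "pass_fail"
    pass_fail_critical = "pass_fail_critical"
    custom = "custom"


class RequirementRating(BaseModel):
    type: TaskOutputRatingType = TaskOutputRatingType.five_star
    value: float


class TaskOutputRating(BaseModel):
    type: TaskOutputRatingType = TaskOutputRatingType.five_star
    value: Optional[float] = None
    requirement_ratings: Dict[str, RequirementRating] = {}

    @field_validator("requirement_ratings", mode="before")
    @classmethod
    def upgrade_legacy_floats(cls, value):
        # Old format: a plain float per requirement id. Wrap each value in the
        # new RequirementRating shape before normal validation runs. A mixed
        # dict is left untouched, so the bare float then fails validation,
        # matching the mixed-format case in the tests above.
        if isinstance(value, dict) and value and all(
            isinstance(v, (int, float)) and not isinstance(v, bool)
            for v in value.values()
        ):
            return {
                k: {"type": TaskOutputRatingType.five_star, "value": v}
                for k, v in value.items()
            }
        return value

Under this sketch, records saved in the old format load into RequirementRating objects and re-serialize in the new nested format, which is the round trip test_task_output_rating_format_upgrade asserts.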
kiln_ai/utils/config.py
CHANGED
@@ -2,7 +2,7 @@ import getpass
 import os
 import threading
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import yaml
 
@@ -15,12 +15,14 @@ class ConfigProperty:
         env_var: Optional[str] = None,
         default_lambda: Optional[Callable[[], Any]] = None,
         sensitive: bool = False,
+        sensitive_keys: Optional[List[str]] = None,
     ):
         self.type = type_
         self.default = default
         self.env_var = env_var
         self.default_lambda = default_lambda
         self.sensitive = sensitive
+        self.sensitive_keys = sensitive_keys
 
 
 class Config:
@@ -80,6 +82,15 @@
                 list,
                 default_lambda=lambda: [],
             ),
+            "custom_models": ConfigProperty(
+                list,
+                default_lambda=lambda: [],
+            ),
+            "openai_compatible_providers": ConfigProperty(
+                list,
+                default_lambda=lambda: [],
+                sensitive_keys=["api_key"],
+            ),
         }
         self._settings = self.load_settings()
 
@@ -145,15 +156,27 @@
             settings = yaml.safe_load(f.read()) or {}
         return settings
 
-    def settings(self, hide_sensitive=False):
-        if hide_sensitive:
-            return
-
-
-
-
-
-
+    def settings(self, hide_sensitive=False) -> Dict[str, Any]:
+        if not hide_sensitive:
+            return self._settings
+
+        settings = {
+            k: "[hidden]"
+            if k in self._properties and self._properties[k].sensitive
+            else v
+            for k, v in self._settings.items()
+        }
+        # Hide sensitive keys in lists. Could generalize this if we every have more types, but right not it's only needed for root elements of lists
+        for key, value in settings.items():
+            if key in self._properties and self._properties[key].sensitive_keys:
+                sensitive_keys = self._properties[key].sensitive_keys or []
+                for sensitive_key in sensitive_keys:
+                    if isinstance(value, list):
+                        for item in value:
+                            if sensitive_key in item:
+                                item[sensitive_key] = "[hidden]"
+
+        return settings
 
     def save_setting(self, name: str, value: Any):
         self.update_settings({name: value})
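The new sensitive_keys option complements the existing sensitive flag: sensitive masks a whole property when settings are exported, while sensitive_keys masks individual keys inside the items of a list-valued property such as openai_compatible_providers. A minimal usage sketch based on the API shown in this diff; the provider values are invented for illustration:

# Illustrative only: provider name, URL, and key are made-up example values.
from kiln_ai.utils.config import Config

config = Config.shared()
config.save_setting(
    "openai_compatible_providers",
    [
        {
            "name": "example",
            "url": "https://openai-compatible.example.com/v1",
            "api_key": "secret",
        }
    ],
)

# With hide_sensitive=True, properties marked sensitive are replaced wholesale
# with "[hidden]"; for this list property only the "api_key" entries are masked.
masked = config.settings(hide_sensitive=True)["openai_compatible_providers"]
# masked[0]["api_key"] == "[hidden]", name and url remain visible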
kiln_ai/utils/test_config.py
CHANGED
@@ -27,6 +27,7 @@ def config_with_yaml(mock_yaml_file):
             ),
             "int_property": ConfigProperty(int, default=0),
             "empty_property": ConfigProperty(str),
+            "list_of_objects": ConfigProperty(list, default=[]),
         }
     )
 
@@ -251,3 +252,50 @@ def test_stale_values_bug(config_with_yaml):
     # Simulate updating the settings file with set_settings
     config_with_yaml.update_settings({"example_property": "third_value"})
     assert config_with_yaml.example_property == "third_value"
+
+
+async def test_openai_compatible_providers():
+    config = Config.shared()
+    assert config.openai_compatible_providers == []
+
+    new_settings = [
+        {
+            "name": "provider1",
+            "url": "https://provider1.com",
+            "api_key": "password1",
+        },
+        {
+            "name": "provider2",
+            "url": "https://provider2.com",
+        },
+    ]
+    config.save_setting("openai_compatible_providers", new_settings)
+    assert config.openai_compatible_providers == new_settings
+
+    # Test that sensitive keys are hidden
+    settings = config.settings(hide_sensitive=True)
+    assert settings["openai_compatible_providers"] == [
+        {"name": "provider1", "url": "https://provider1.com", "api_key": "[hidden]"},
+        {"name": "provider2", "url": "https://provider2.com"},
+    ]
+
+
+def test_yaml_persistence_structured_data(config_with_yaml, mock_yaml_file):
+    # Set a value
+    new_settings = [
+        {
+            "name": "provider1",
+            "url": "https://provider1.com",
+            "api_key": "password1",
+        },
+        {
+            "name": "provider2",
+            "url": "https://provider2.com",
+        },
+    ]
+    config_with_yaml.list_of_objects = new_settings
+
+    # Check that the value was saved to the YAML file
+    with open(mock_yaml_file, "r") as f:
+        saved_settings = yaml.safe_load(f)
+    assert saved_settings["list_of_objects"] == new_settings