kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +81 -10
- kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +434 -93
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +193 -49
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +18 -19
- kiln_ai/adapters/repair/test_repair_task.py +7 -7
- kiln_ai/adapters/run_output.py +11 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +25 -18
- kiln_ai/adapters/test_prompt_builders.py +265 -44
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +51 -772
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +14 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +80 -2
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +127 -6
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +34 -17
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +131 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai/adapters/base_adapter.py +0 -191
- kiln_ai/adapters/langchain_adapters.py +0 -256
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -11,7 +11,6 @@ from kiln_ai.datamodel import (
     Finetune,
     Project,
     Task,
-    TaskDeterminism,
     TaskOutput,
     TaskOutputRating,
     TaskOutputRatingType,
@@ -125,7 +124,6 @@ def test_structured_output_workflow(tmp_path):
         name="Structured Output Task",
         parent=project,
         instruction="Generate a JSON object with name and age",
-        determinism=TaskDeterminism.semantic_match,
         output_json_schema=json.dumps(
             {
                 "type": "object",
@@ -142,7 +140,7 @@ def test_structured_output_workflow(tmp_path):
 
     # Create runs
     runs = []
-    for source in DataSourceType:
+    for source in [DataSourceType.human, DataSourceType.synthetic]:
         for _ in range(2):
             task_run = TaskRun(
                 input="Generate info for John Doe",
@@ -157,7 +155,7 @@ def test_structured_output_workflow(tmp_path):
                         "adapter_name": "TestAdapter",
                         "model_name": "GPT-4",
                         "model_provider": "OpenAI",
-                        "
+                        "prompt_id": "simple_prompt_builder",
                     },
                 ),
                 parent=task,
@@ -216,9 +214,9 @@ def test_structured_output_workflow(tmp_path):
 
     assert loaded_task.name == "Structured Output Task"
     assert len(loaded_task.requirements) == 2
-    assert len(loaded_task.runs()) == 5
-
     loaded_runs = loaded_task.runs()
+    assert len(loaded_runs) == 5
+
     for task_run in loaded_runs:
         output = task_run.output
         assert output.rating is not None
@@ -284,6 +282,9 @@ def test_task_output_requirement_rating_keys(tmp_path):
     assert task_run.output.rating.requirement_ratings is not None
 
 
+_schema_match = "This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema."
+
+
 def test_task_output_schema_validation(tmp_path):
     # Create a project, task, and example hierarchy
     project = Project(name="Test Project", path=(tmp_path / "test_project"))
@@ -321,12 +322,24 @@ def test_task_output_schema_validation(tmp_path):
     task_output.save_to_file()
 
     # changing to invalid output
-    with pytest.raises(
+    with pytest.raises(
+        ValueError,
+        match=_schema_match,
+    ):
         task_output.output.output = '{"name": "John Doe", "age": "thirty"}'
         task_output.save_to_file()
 
+    # changing to invalid output from loaded model
+    loaded_task_output = TaskRun.load_from_file(task_output.path)
+    with pytest.raises(
+        ValueError,
+        match=_schema_match,
+    ):
+        loaded_task_output.output.output = '{"name": "John Doe", "age": "forty"}'
+        loaded_task_output.save_to_file()
+
     # Invalid case: output does not match task output schema
-    with pytest.raises(ValueError, match=
+    with pytest.raises(ValueError, match=_schema_match):
         task_output = TaskRun(
             input="Test input",
             input_source=DataSource(
@@ -382,12 +395,18 @@ def test_task_input_schema_validation(tmp_path):
     valid_task_output.save_to_file()
 
     # Changing to invalid input
-    with pytest.raises(ValueError, match=
+    with pytest.raises(ValueError, match=_schema_match):
         valid_task_output.input = '{"name": "John Doe", "age": "thirty"}'
         valid_task_output.save_to_file()
 
+    # loading from file, then changing to invalid input
+    loaded_task_output = TaskRun.load_from_file(valid_task_output.path)
+    with pytest.raises(ValueError, match=_schema_match):
+        loaded_task_output.input = '{"name": "John Doe", "age": "thirty"}'
+        loaded_task_output.save_to_file()
+
     # Invalid case: input does not match task input schema
-    with pytest.raises(ValueError, match=
+    with pytest.raises(ValueError, match=_schema_match):
         task_output = TaskRun(
             input='{"name": "John Doe", "age": "thirty"}',
             input_source=DataSource(
@@ -451,7 +470,7 @@ def test_valid_synthetic_task_output():
                 "adapter_name": "TestAdapter",
                 "model_name": "GPT-4",
                 "model_provider": "OpenAI",
-                "
+                "prompt_id": "simple_prompt_builder",
             },
         ),
     )
@@ -459,7 +478,7 @@ def test_valid_synthetic_task_output():
     assert output.source.properties["adapter_name"] == "TestAdapter"
     assert output.source.properties["model_name"] == "GPT-4"
     assert output.source.properties["model_provider"] == "OpenAI"
-    assert output.source.properties["
+    assert output.source.properties["prompt_id"] == "simple_prompt_builder"
 
 
 def test_invalid_synthetic_task_output_missing_keys():
@@ -488,23 +507,21 @@ def test_invalid_synthetic_task_output_empty_values():
                 "adapter_name": "TestAdapter",
                 "model_name": "",
                 "model_provider": "OpenAI",
-                "
+                "prompt_id": "simple_prompt_builder",
             },
         ),
     )
 
 
 def test_invalid_synthetic_task_output_non_string_values():
-    with pytest.raises(
-        ValidationError, match="'prompt_builder_name' must be of type str"
-    ):
+    with pytest.raises(ValidationError, match="'prompt_id' must be of type str"):
         DataSource(
             type=DataSourceType.synthetic,
             properties={
                 "adapter_name": "TestAdapter",
                 "model_name": "GPT-4",
                 "model_provider": "OpenAI",
-                "
+                "prompt_id": 123,
             },
         )
 
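Taken together, these hunks rename the synthetic DataSource property key from prompt_builder_name to prompt_id. A minimal construction sketch mirroring the tests above (the values are the same placeholders the tests use):

from kiln_ai.datamodel import DataSource, DataSourceType

source = DataSource(
    type=DataSourceType.synthetic,
    properties={
        "adapter_name": "TestAdapter",
        "model_name": "GPT-4",
        "model_provider": "OpenAI",
        "prompt_id": "simple_prompt_builder",  # 0.8.1 used "prompt_builder_name" here
    },
)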
@@ -4,6 +4,7 @@ from pydantic import BaseModel
 from kiln_ai.datamodel.json_schema import (
     JsonObjectSchema,
     schema_from_json_str,
+    string_to_json_key,
     validate_schema,
 )
 
@@ -123,3 +124,25 @@ def test_triangle_schema():
     validate_schema({"a": 1, "b": 2, "c": 3}, json_triangle_schema)
     with pytest.raises(Exception):
         validate_schema({"a": 1, "b": 2, "c": "3"}, json_triangle_schema)
+
+
+@pytest.mark.parametrize(
+    "input_str,expected",
+    [
+        ("hello world", "hello_world"),
+        ("Hello World", "hello_world"),
+        ("hello_world", "hello_world"),
+        ("HELLO WORLD", "hello_world"),
+        ("hello123", "hello123"),
+        ("hello-world", "helloworld"),
+        ("hello!@#$%^&*()world", "helloworld"),
+        (" hello world ", "hello__world"),
+        ("hello__world", "hello__world"),
+        ("", ""),
+        ("!@#$%", ""),
+        ("snake_case_string", "snake_case_string"),
+        ("camelCaseString", "camelcasestring"),
+    ],
+)
+def test_string_to_json_key(input_str: str, expected: str):
+    assert string_to_json_key(input_str) == expected
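The json_schema module now exports string_to_json_key, and the parametrized cases above pin its behavior: lowercase the input, map spaces to underscores, and drop other punctuation. A short usage sketch, reusing assertions taken directly from the test table:

from kiln_ai.datamodel.json_schema import string_to_json_key

assert string_to_json_key("Hello World") == "hello_world"
assert string_to_json_key("hello!@#$%^&*()world") == "helloworld"
assert string_to_json_key("snake_case_string") == "snake_case_string"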
@@ -242,3 +242,27 @@ def test_check_timestamp_granularity_linux_error():
         cache = ModelCache()
         assert cache._check_timestamp_granularity() is False
         assert cache._enabled is False
+
+
+def test_get_model_readonly(model_cache, test_path):
+    if not model_cache._enabled:
+        pytest.skip("Cache is disabled on this fs")
+
+    model = ModelTest(name="test", value=123)
+    mtime_ns = test_path.stat().st_mtime_ns
+
+    # Set the model in the cache
+    model_cache.set_model(test_path, model, mtime_ns)
+
+    # Get the model in readonly mode
+    readonly_model = model_cache.get_model(test_path, ModelTest, readonly=True)
+    # Get a regular (copied) model
+    copied_model = model_cache.get_model(test_path, ModelTest)
+
+    # The readonly model should be the exact same instance as the cached model
+    assert readonly_model is model_cache.model_cache[test_path][0]
+    # While the regular get should be a different instance
+    assert copied_model is not model_cache.model_cache[test_path][0]
+
+    # Both should have the same data
+    assert readonly_model == copied_model == model
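The new test covers a readonly flag on ModelCache.get_model: the readonly path hands back the cached instance itself, while the default path returns a copy that is safe to mutate. A self-contained sketch of that usage, assuming ModelCache lives in kiln_ai.datamodel.model_cache (the calls and argument order are the ones the test exercises; ExampleModel is a stand-in for the test's ModelTest):

import tempfile
from pathlib import Path

from pydantic import BaseModel

from kiln_ai.datamodel.model_cache import ModelCache  # assumed import path


class ExampleModel(BaseModel):  # stand-in for the test's ModelTest
    name: str
    value: int


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "example.kiln"
    path.write_text("{}")  # only the file's mtime is needed for the cache entry here

    cache = ModelCache()  # may be disabled on filesystems with coarse mtime granularity
    cache.set_model(path, ExampleModel(name="demo", value=1), path.stat().st_mtime_ns)

    # readonly=True returns the cached instance (treat it as immutable);
    # the default call returns a copy that is safe to mutate.
    shared = cache.get_model(path, ExampleModel, readonly=True)
    copied = cache.get_model(path, ExampleModel)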
@@ -0,0 +1,125 @@
+import shutil
+import uuid
+
+import pytest
+
+from kiln_ai.datamodel import (
+    DataSource,
+    DataSourceType,
+    Project,
+    Task,
+    TaskOutput,
+    TaskRun,
+)
+
+test_json_schema = """{
+  "type": "object",
+  "properties": {
+    "setup": {
+      "description": "The setup of the joke",
+      "title": "Setup",
+      "type": "string"
+    },
+    "punchline": {
+      "description": "The punchline to the joke",
+      "title": "Punchline",
+      "type": "string"
+    },
+    "rating": {
+      "anyOf": [
+        {
+          "type": "integer"
+        },
+        {
+          "type": "null"
+        }
+      ],
+      "default": null,
+      "description": "How funny the joke is, from 1 to 10",
+      "title": "Rating"
+    }
+  },
+  "required": [
+    "setup",
+    "punchline"
+  ]
+}
+"""
+
+
+@pytest.fixture
+def task_run(tmp_path):
+    # setup a valid project/task/task_run for testing
+    output_source = DataSource(
+        type=DataSourceType.synthetic,
+        properties={
+            "model_name": "test-model",
+            "model_provider": "test-provider",
+            "adapter_name": "test-adapter",
+        },
+    )
+
+    project_path = tmp_path / "project.kiln"
+    project = Project(name="Test Project", path=project_path)
+    project.save_to_file()
+    task = Task(
+        name="Test Task",
+        instruction="Test Instruction",
+        parent=project,
+        output_json_schema=test_json_schema,
+        input_json_schema=test_json_schema,
+    )
+
+    task.save_to_file()
+
+    task_output = TaskOutput(
+        output='{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side"}',
+        source=DataSource(
+            type=DataSourceType.synthetic,
+            properties={
+                "model_name": "test-model",
+                "model_provider": "test-provider",
+                "adapter_name": "test-adapter",
+            },
+        ),
+    )
+
+    # Save for later usage
+    task_run = TaskRun(
+        input='{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side"}',
+        input_source=output_source,
+        output=task_output,
+    )
+    task_run.parent = task
+    task_run.save_to_file()
+
+    return task_run
+
+
+@pytest.mark.benchmark
+def test_benchmark_load_from_file(benchmark, task_run):
+    task_run_path = task_run.path
+
+    iterations = 500
+    total_time = 0
+
+    for _ in range(iterations):
+        # Copy the task to a new temp path, so we don't get warm loads/cached loads
+        temp_path = task_run.path.parent / f"temp_task_run_{uuid.uuid4()}.json"
+        shutil.copy(str(task_run_path), str(temp_path))
+
+        # only time loading the model (and one accessor for delayed validation)
+        start_time = benchmark._timer()
+        loaded = TaskRun.load_from_file(temp_path)
+        assert loaded.id == task_run.id
+        end_time = benchmark._timer()
+
+        total_time += end_time - start_time
+
+    avg_time_per_iteration = total_time / iterations
+    ops_per_second = 1.0 / avg_time_per_iteration
+
+    # I get 8k ops per second on my MBP. Lower value here for CI.
+    # Prior to optimization was 290 ops per second.
+    if ops_per_second < 1000:
+        pytest.fail(f"Ops per second: {ops_per_second:.6f}, expected more than 1k ops")
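The pass/fail criterion in the benchmark above is plain throughput arithmetic; as a standalone sketch with a hypothetical timing:

iterations = 500
total_time = 0.05  # seconds spent inside the timed region, hypothetical
avg_time_per_iteration = total_time / iterations
ops_per_second = 1.0 / avg_time_per_iteration  # 10,000 ops/s for these numbers
assert ops_per_second >= 1000  # the test fails below 1k ops/s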
kiln_ai/datamodel/test_models.py  CHANGED
@@ -9,7 +9,9 @@ from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,
     Finetune,
+    FinetuneDataStrategy,
     Project,
+    Prompt,
     Task,
     TaskOutput,
     TaskRun,
@@ -70,6 +72,20 @@ def test_save_to_file(test_project_file):
     assert data["description"] == "Test Description"
 
 
+def test_save_to_file_non_ascii(test_project_file):
+    project = Project(
+        name="Test Project", description="Chúc mừng!", path=test_project_file
+    )
+    project.save_to_file()
+
+    with open(test_project_file, "r", encoding="utf-8") as file:
+        data = json.load(file)
+
+    assert data["v"] == 1
+    assert data["name"] == "Test Project"
+    assert data["description"] == "Chúc mừng!"
+
+
 def test_task_defaults():
     task = Task(name="Test Task", instruction="Test Instruction")
     assert task.description is None
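The new test_save_to_file_non_ascii test pins down that non-ASCII fields survive a save/load round trip when the file is read back as UTF-8. A compact sketch of the same round trip (the temporary path is illustrative):

import json
import tempfile
from pathlib import Path

from kiln_ai.datamodel import Project

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "project.kiln"
    Project(name="Test Project", description="Chúc mừng!", path=path).save_to_file()

    with open(path, "r", encoding="utf-8") as f:
        assert json.load(f)["description"] == "Chúc mừng!"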
@@ -369,7 +385,7 @@ def test_task_run_input_source_validation(tmp_path):
     assert task_run.input_source is not None
 
     # Test 3: Creating without input_source should fail when strict mode is on
-    with patch("kiln_ai.datamodel.strict_mode", return_value=True):
+    with patch("kiln_ai.datamodel.task_run.strict_mode", return_value=True):
         with pytest.raises(ValueError) as exc_info:
             task_run = TaskRun(
                 input="test input 3",
@@ -426,7 +442,7 @@ def test_task_output_source_validation(tmp_path):
     assert task_output.source is not None
 
     # Test 3: Creating without source should fail when strict mode is on
-    with patch("kiln_ai.datamodel.strict_mode", return_value=True):
+    with patch("kiln_ai.datamodel.task_output.strict_mode", return_value=True):
         with pytest.raises(ValueError) as exc_info:
             task_output = TaskOutput(
                 output="test output 3",
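Both strict-mode tests now patch strict_mode on the module that consumes it (kiln_ai.datamodel.task_run and kiln_ai.datamodel.task_output) rather than on the kiln_ai.datamodel package, following the usual unittest.mock rule of patching a name where it is looked up. A sketch of the pattern, with the body elided:

from unittest.mock import patch

# Patch the reference the consuming module actually uses (target taken from the diff).
with patch("kiln_ai.datamodel.task_run.strict_mode", return_value=True):
    ...  # constructing a TaskRun without an input_source is expected to raise ValueError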
@@ -488,3 +504,116 @@ def test_task_run_tags_validation():
             tags=["valid_tag", "invalid tag"],
         )
     assert "Tags cannot contain spaces. Try underscores." in str(exc_info.value)
+
+
+def test_prompt_validation():
+    prompt = Prompt(name="Test Prompt Name", prompt="Test Prompt")
+    assert prompt.name == "Test Prompt Name"
+    assert prompt.prompt == "Test Prompt"
+
+    with pytest.raises(ValidationError):
+        Prompt(name="Test Prompt")
+
+    with pytest.raises(ValidationError):
+        Prompt(name="Test Prompt", prompt=None)
+
+    with pytest.raises(ValidationError):
+        Prompt(name="Test Prompt", prompt="")
+
+    with pytest.raises(ValidationError):
+        Prompt(prompt="Test Prompt")
+
+
+def test_prompt_parent_task():
+    task = Task(name="Test Task", instruction="Test Instruction")
+    prompt = Prompt(name="Test Prompt", prompt="Test Prompt", parent=task)
+    assert prompt.parent == task
+
+
+@pytest.mark.parametrize(
+    "thinking_instructions,data_strategy,should_raise,expected_message",
+    [
+        # Test 1: Valid case - no thinking instructions with final_only
+        (
+            None,
+            FinetuneDataStrategy.final_only,
+            False,
+            None,
+        ),
+        # Test 2: Valid case - thinking instructions with final_and_intermediate
+        (
+            "Think step by step",
+            FinetuneDataStrategy.final_and_intermediate,
+            False,
+            None,
+        ),
+        # Test 3: Invalid case - thinking instructions with final_only
+        (
+            "Think step by step",
+            FinetuneDataStrategy.final_only,
+            True,
+            "Thinking instructions can only be used when data_strategy is final_and_intermediate",
+        ),
+        # Test 4: Invalid case - no thinking instructions with final_and_intermediate
+        (
+            None,
+            FinetuneDataStrategy.final_and_intermediate,
+            True,
+            "Thinking instructions are required when data_strategy is final_and_intermediate",
+        ),
+    ],
+)
+def test_finetune_thinking_instructions_validation(
+    thinking_instructions, data_strategy, should_raise, expected_message
+):
+    base_params = {
+        "name": "test-finetune",
+        "provider": "openai",
+        "base_model_id": "gpt-3.5-turbo",
+        "dataset_split_id": "split1",
+        "system_message": "test message",
+        "data_strategy": data_strategy,
+    }
+
+    if thinking_instructions is not None:
+        base_params["thinking_instructions"] = thinking_instructions
+
+    if should_raise:
+        with pytest.raises(ValueError) as exc_info:
+            Finetune(**base_params)
+        assert expected_message in str(exc_info.value)
+    else:
+        finetune = Finetune(**base_params)
+        assert finetune.thinking_instructions == thinking_instructions
+        assert finetune.data_strategy == data_strategy
+
+
+@pytest.mark.parametrize(
+    "intermediate_outputs,expected",
+    [
+        # No intermediate outputs
+        (None, False),
+        # Empty intermediate outputs
+        ({}, False),
+        # Only chain_of_thought
+        ({"chain_of_thought": "thinking process"}, True),
+        # Only reasoning
+        ({"reasoning": "reasoning process"}, True),
+        # Both chain_of_thought and reasoning
+        (
+            {"chain_of_thought": "thinking process", "reasoning": "reasoning process"},
+            True,
+        ),
+        # Other intermediate outputs but no thinking data
+        ({"other_output": "some data"}, False),
+        # Mixed other outputs with thinking data
+        ({"chain_of_thought": "thinking process", "other_output": "some data"}, True),
+    ],
+)
+def test_task_run_has_thinking_training_data(intermediate_outputs, expected):
+    task_run = TaskRun(
+        input="test input",
+        output=TaskOutput(output="test output"),
+        intermediate_outputs=intermediate_outputs,
+    )
+    assert task_run.has_thinking_training_data() == expected
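The new Finetune validation ties thinking_instructions to FinetuneDataStrategy: the field is required with final_and_intermediate and rejected with final_only. A minimal construction sketch using the same field names and placeholder values as the test's base_params:

from kiln_ai.datamodel import Finetune, FinetuneDataStrategy

finetune = Finetune(
    name="test-finetune",
    provider="openai",
    base_model_id="gpt-3.5-turbo",
    dataset_split_id="split1",
    system_message="test message",
    data_strategy=FinetuneDataStrategy.final_and_intermediate,
    # Required for final_and_intermediate; passing it with final_only raises ValueError.
    thinking_instructions="Think step by step",
)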
@@ -0,0 +1,129 @@
+import pytest
+from pydantic import BaseModel, ValidationError
+
+from kiln_ai.datamodel import (
+    PromptGenerators,
+    PromptId,
+)
+from kiln_ai.datamodel.prompt_id import is_frozen_prompt
+
+
+# Test model to validate the PromptId type
+class ModelTester(BaseModel):
+    prompt_id: PromptId
+
+
+def test_valid_prompt_generator_names():
+    """Test that valid prompt generator names are accepted"""
+    for generator in PromptGenerators:
+        model = ModelTester(prompt_id=generator.value)
+        assert model.prompt_id == generator.value
+
+
+def test_valid_saved_prompt_id():
+    """Test that valid saved prompt IDs are accepted"""
+    valid_id = "id::prompt_789"
+    model = ModelTester(prompt_id=valid_id)
+    assert model.prompt_id == valid_id
+
+
+def test_valid_fine_tune_prompt_id():
+    """Test that valid fine-tune prompt IDs are accepted"""
+    valid_id = "fine_tune_prompt::ft_123456"
+    model = ModelTester(prompt_id=valid_id)
+    assert model.prompt_id == valid_id
+
+
+@pytest.mark.parametrize(
+    "invalid_id",
+    [
+        pytest.param("id::project_123::task_456", id="missing_prompt_id"),
+        pytest.param("id::task_456::prompt_789", id="too_many_parts"),
+        pytest.param("id::", id="empty_parts"),
+    ],
+)
+def test_invalid_saved_prompt_id_format(invalid_id):
+    """Test that invalid saved prompt ID formats are rejected"""
+    with pytest.raises(ValidationError, match="Invalid saved prompt ID"):
+        ModelTester(prompt_id=invalid_id)
+
+
+@pytest.mark.parametrize(
+    "invalid_id,expected_error",
+    [
+        ("fine_tune_prompt::", "Invalid fine-tune prompt ID: fine_tune_prompt::"),
+        ("fine_tune_prompt", "Invalid prompt ID: fine_tune_prompt"),
+    ],
+)
+def test_invalid_fine_tune_prompt_id_format(invalid_id, expected_error):
+    """Test that invalid fine-tune prompt ID formats are rejected"""
+    with pytest.raises(ValidationError, match=expected_error):
+        ModelTester(prompt_id=invalid_id)
+
+
+def test_completely_invalid_formats():
+    """Test that completely invalid formats are rejected"""
+    invalid_ids = [
+        "",  # Empty string
+        "invalid_format",  # Random string
+        "id:wrong_format",  # Almost correct but wrong separator
+        "fine_tune:wrong_format",  # Almost correct but wrong prefix
+        ":::",  # Just separators
+    ]
+
+    for invalid_id in invalid_ids:
+        with pytest.raises(ValidationError, match="Invalid prompt ID"):
+            ModelTester(prompt_id=invalid_id)
+
+
+def test_prompt_generator_case_sensitivity():
+    """Test that prompt generator names are case sensitive"""
+    # Take first generator and modify its case
+    first_generator = next(iter(PromptGenerators)).value
+    wrong_case = first_generator.upper()
+    if wrong_case == first_generator:
+        wrong_case = first_generator.lower()
+
+    with pytest.raises(ValidationError):
+        ModelTester(prompt_id=wrong_case)
+
+
+@pytest.mark.parametrize(
+    "valid_id",
+    [
+        "task_run_config::project_123::task_456::config_123",  # Valid task run config prompt ID
+    ],
+)
+def test_valid_task_run_config_prompt_id(valid_id):
+    """Test that valid eval prompt IDs are accepted"""
+    model = ModelTester(prompt_id=valid_id)
+    assert model.prompt_id == valid_id
+
+
+@pytest.mark.parametrize(
+    "invalid_id,expected_error",
+    [
+        ("task_run_config::", "Invalid task run config prompt ID"),
+        ("task_run_config::p1", "Invalid task run config prompt ID"),
+        ("task_run_config::p1::t1", "Invalid task run config prompt ID"),
+        ("task_run_config::p1::t1::c1::extra", "Invalid task run config prompt ID"),
+    ],
+)
+def test_invalid_eval_prompt_id_format(invalid_id, expected_error):
+    """Test that invalid eval prompt ID formats are rejected"""
+    with pytest.raises(ValidationError, match=expected_error):
+        ModelTester(prompt_id=invalid_id)
+
+
+@pytest.mark.parametrize(
+    "id,should_be_frozen",
+    [
+        ("simple_prompt_builder", False),
+        ("id::prompt_123", True),
+        ("task_run_config::p1::t1", True),
+        ("fine_tune_prompt::ft_123", True),
+    ],
+)
+def test_is_frozen_prompt(id, should_be_frozen):
+    """Test that the is_frozen_prompt function works"""
+    assert is_frozen_prompt(id) == should_be_frozen