kiln-ai 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +163 -39
- kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +270 -0
- kiln_ai/adapters/eval/g_eval.py +368 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +325 -0
- kiln_ai/adapters/eval/test_eval_runner.py +641 -0
- kiln_ai/adapters/eval/test_g_eval.py +498 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +758 -163
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +61 -43
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
- kiln_ai/adapters/model_adapters/test_structured_output.py +59 -35
- kiln_ai/adapters/ollama_tools.py +3 -3
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/prompt_builders.py +80 -42
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/repair_task.py +9 -21
- kiln_ai/adapters/repair/test_repair_task.py +6 -6
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +26 -29
- kiln_ai/adapters/test_generate_docs.py +4 -4
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +47 -33
- kiln_ai/adapters/test_prompt_builders.py +91 -31
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/__init__.py +50 -952
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +60 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +7 -1
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +328 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +19 -11
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +32 -8
- kiln_ai/datamodel/test_datasource.py +22 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +9 -13
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_models.py +2 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +43 -1
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +86 -6
- kiln_ai-0.13.0.dist-info/RECORD +103 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -302
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -11
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -246
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -350
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -225
- kiln_ai-0.11.1.dist-info/RECORD +0 -76
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,596 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from io import StringIO
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
from pydantic import BaseModel, ValidationError
|
|
9
|
+
|
|
10
|
+
from kiln_ai.datamodel import Project, Task
|
|
11
|
+
from kiln_ai.utils.dataset_import import (
|
|
12
|
+
DatasetFileImporter,
|
|
13
|
+
DatasetImportFormat,
|
|
14
|
+
ImportConfig,
|
|
15
|
+
KilnInvalidImportFormat,
|
|
16
|
+
deserialize_tags,
|
|
17
|
+
format_validation_error,
|
|
18
|
+
generate_import_tags,
|
|
19
|
+
without_none_values,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Module-level logger; no test in this module currently logs through it.
logger = logging.getLogger(name=__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def base_task(tmp_path) -> Task:
    """Create and save a project holding a single sentiment-classifier task.

    The project file is placed under ``tmp_path`` so every test works in its
    own temporary directory. The task is saved before being returned.
    """
    project = Project(name="TestProject", path=str(tmp_path / "project.kiln"))
    project.save_to_file()

    # The same sentence serves as both the description and the instruction.
    sentiment_prompt = "Classify the sentiment of a sentence"
    task = Task(
        name="Sentiment Classifier",
        parent=project,
        description=sentiment_prompt,
        instruction=sentiment_prompt,
        requirements=[],
    )
    task.save_to_file()
    return task
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.fixture
def task_with_structured_output(base_task: Task):
    """Return ``base_task`` configured with a JSON-schema structured output.

    The schema requires an object with a string ``sentiment`` and a numeric
    ``confidence``. The task is saved again after the schema is attached.
    """
    output_schema = {
        "type": "object",
        "properties": {
            "sentiment": {"type": "string"},
            "confidence": {"type": "number"},
        },
        "required": ["sentiment", "confidence"],
    }
    base_task.output_json_schema = json.dumps(output_schema)
    base_task.save_to_file()
    return base_task
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@pytest.fixture
def task_with_structured_input(base_task: Task):
    """Return ``base_task`` configured with a JSON-schema structured input.

    The schema requires an object with an integer ``example_id`` and a string
    ``text``. The task is saved again after the schema is attached.
    """
    input_schema = {
        "type": "object",
        "properties": {
            "example_id": {"type": "integer"},
            "text": {"type": "string"},
        },
        "required": ["example_id", "text"],
    }
    base_task.input_json_schema = json.dumps(input_schema)
    base_task.save_to_file()
    return base_task
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@pytest.fixture
def task_with_intermediate_outputs(base_task: Task):
    """Return ``base_task`` with a thinking instruction set and saved.

    ``base_task`` is freshly created and has no runs, so there is nothing to
    seed with intermediate outputs here. The import tests that use this
    fixture provide ``reasoning`` / ``chain_of_thought`` values through their
    CSV columns instead.
    """
    base_task.thinking_instruction = "thinking instructions"
    # Persist like the other task fixtures so the on-disk task matches memory.
    base_task.save_to_file()
    return base_task
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def dict_to_csv_row(row: dict) -> str:
    """Serialize the values of ``row`` as one CSV line, without a line ending.

    Values are written in the dict's iteration order using
    ``csv.QUOTE_MINIMAL``, so fields containing commas, quotes or newlines are
    quoted and escaped.

    Args:
        row: Mapping whose values form the CSV fields.

    Returns:
        The CSV-encoded row with the writer's line terminator removed.
    """
    output = StringIO()
    writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
    writer.writerow(row.values())
    # csv.writer terminates rows with CRLF by default. Removing only "\n"
    # would leave a stray carriage return at the end of every row, so drop
    # the writer's full terminator instead.
    return output.getvalue().removesuffix(writer.dialect.lineterminator)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def dicts_to_file_as_csv(items: list[dict], file_name: str, tmp_path: Path) -> Path:
    """Write a list of dictionaries to a CSV file with a fully quoted header row.

    The header comes from the keys of the first item. Every row is emitted in
    that same column order. Keys missing from a later item become empty
    cells, and keys not present in the first item are ignored.

    Args:
        items: Rows to write. Must be non-empty.
        file_name: Name of the file to create inside ``tmp_path``.
        tmp_path: Directory to write the file into.

    Returns:
        The path to the written file.

    Raises:
        ValueError: If ``items`` is empty, since there is no header to derive.
    """
    if not items:
        raise ValueError("items must contain at least one row")

    keys = list(items[0])
    # Quote every header cell, doubling embedded quote characters as CSV requires.
    header = ",".join('"' + str(key).replace('"', '""') + '"' for key in keys)
    # Re-key each item in header order so values always land in the right column.
    rows = [dict_to_csv_row({key: item.get(key, "") for key in keys}) for item in items]
    csv_data = header + "\n" + "\n".join(rows)

    file_path = tmp_path / file_name
    # newline="" writes the text verbatim. Without it, "\n" (including inside
    # quoted values) is translated to the platform line separator.
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        f.write(csv_data)

    return file_path
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def compare_tags(actual_tags: list[str], expected_tags: str):
    """Assert that every tag in a comma-separated string is present on a run.

    Imported runs also carry default tags (such as ``"imported"``), so this
    checks containment only, not equality.

    Args:
        actual_tags: Tags found on the imported run.
        expected_tags: Comma-separated expected tags, as written in the CSV
            ``tags`` column. An empty string (or None) means no tags are expected.

    Raises:
        AssertionError: If any expected tag is missing from ``actual_tags``.
    """
    # Ignore whitespace around each tag and skip blank entries, in line with
    # the normalization deserialize_tags is tested for ("t1, t2" -> ["t1", "t2"]).
    tags_expected = [
        tag.strip() for tag in (expected_tags or "").split(",") if tag.strip()
    ]
    missing = [tag for tag in tags_expected if tag not in actual_tags]
    assert not missing, f"missing tags {missing} in {actual_tags}"
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def test_import_csv_plain_text(base_task: Task, tmp_path):
    """Plain-text CSV rows become runs with matching input, output and tags."""
    rows = [
        {
            "input": "This is my input",
            "output": "This is my output 啊",
            "tags": "t1,t2",
        },
        {
            "input": "This is my input 2",
            "output": "This is my output 2 啊",
            "tags": "t3,t4",
        },
        {
            "input": "This is my input 3",
            "output": "This is my output 3 啊",
            "tags": "t5",
        },
        {
            "input": "This is my input 4",
            "output": "This is my output 4 啊",
            "tags": "",
        },
    ]
    csv_path = dicts_to_file_as_csv(rows, "test.csv", tmp_path)

    config = ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        dataset_path=csv_path,
        dataset_name="test.csv",
    )
    DatasetFileImporter(base_task, config).create_runs_from_file()

    runs = base_task.runs()
    assert len(runs) == 4

    # Inputs are unique, so they identify the source row of each run.
    rows_by_input = {row["input"]: row for row in rows}
    for run in runs:
        expected = rows_by_input.get(run.input)
        assert expected is not None
        assert run.input == expected["input"]
        assert run.output.output == expected["output"]

        compare_tags(run.tags, expected["tags"])
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def test_import_csv_default_tags(base_task: Task, tmp_path):
    """Each imported run has exactly two default tags on top of its CSV tags."""
    rows = [
        {
            "input": "This is my input",
            "output": "This is my output 啊",
            "tags": "t1,t2",
        },
        {
            "input": "This is my input 4",
            "output": "This is my output 4 啊",
            "tags": "",
        },
    ]
    csv_path = dicts_to_file_as_csv(rows, "test.csv", tmp_path)

    config = ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        dataset_path=csv_path,
        dataset_name="test.csv",
    )
    DatasetFileImporter(base_task, config).create_runs_from_file()

    runs = base_task.runs()
    assert len(runs) == 2

    default_tag_count = 2
    rows_by_input = {row["input"]: row for row in rows}

    for run in runs:
        expected = rows_by_input.get(run.input)
        assert expected is not None

        csv_tags = expected["tags"].split(",") if expected["tags"] else []
        assert len(run.tags) == len(csv_tags) + default_tag_count

        # The two default tags added by the importer.
        assert "imported" in run.tags
        assert any(tag.startswith("imported_") for tag in run.tags)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def test_import_csv_plain_text_missing_output(base_task: Task, tmp_path):
    """A CSV without an ``output`` column is rejected before any row is read."""
    rows = [
        {"input": "This is my input", "tags": "t1,t2"},
        {"input": "This is my input 2", "tags": "t3,t4"},
        {"input": "This is my input 3", "tags": "t5,t6"},
    ]

    importer = DatasetFileImporter(
        base_task,
        ImportConfig(
            dataset_type=DatasetImportFormat.CSV,
            dataset_path=dicts_to_file_as_csv(rows, "test.csv", tmp_path),
            dataset_name="test.csv",
        ),
    )

    with pytest.raises(KilnInvalidImportFormat) as exc_info:
        importer.create_runs_from_file()

    error = exc_info.value
    # A header problem invalidates the whole file, so no row number is reported.
    assert error.row_number is None
    assert "Missing required headers" in str(error)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def test_import_csv_structured_output(task_with_structured_output: Task, tmp_path):
    """JSON outputs matching the task's output schema are imported intact."""
    rows = [
        {
            "input": "This is my input",
            "output": json.dumps({"sentiment": "高兴", "confidence": 0.95}),
            "tags": "t1,t2",
        },
        {
            "input": "This is my input 2",
            "output": json.dumps({"sentiment": "negative", "confidence": 0.05}),
            "tags": "t3,t4",
        },
        {
            "input": "This is my input 3",
            "output": json.dumps({"sentiment": "neutral", "confidence": 0.5}),
            "tags": "",
        },
    ]
    csv_path = dicts_to_file_as_csv(rows, "test.csv", tmp_path)

    config = ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        dataset_path=csv_path,
        dataset_name="test.csv",
    )
    DatasetFileImporter(task_with_structured_output, config).create_runs_from_file()

    runs = task_with_structured_output.runs()
    assert len(runs) == 3

    rows_by_input = {row["input"]: row for row in rows}
    for run in runs:
        expected = rows_by_input.get(run.input)
        assert expected is not None
        assert run.input == expected["input"]
        # Compare parsed JSON so key order or whitespace differences don't matter.
        assert json.loads(run.output.output) == json.loads(expected["output"])

        compare_tags(run.tags, expected["tags"])
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def test_import_csv_structured_output_wrong_schema(
    task_with_structured_output: Task, tmp_path
):
    """An output violating the output schema aborts the import, citing its row."""
    rows = [
        {
            "input": "This is my input",
            "output": json.dumps({"sentiment": "positive", "confidence": 0.95}),
            "tags": "t1,t2",
        },
        {
            "input": "This is my input 2",
            # invalid: sentiment must be a string, not a number
            "output": json.dumps({"sentiment": 100, "confidence": 0.05}),
            "tags": "t3,t4",
        },
        {
            "input": "This is my input 3",
            "output": json.dumps({"sentiment": "positive", "confidence": 0.5}),
            "tags": "",
        },
    ]

    importer = DatasetFileImporter(
        task_with_structured_output,
        ImportConfig(
            dataset_type=DatasetImportFormat.CSV,
            dataset_path=dicts_to_file_as_csv(rows, "test.csv", tmp_path),
            dataset_name="test.csv",
        ),
    )

    with pytest.raises(KilnInvalidImportFormat) as exc_info:
        importer.create_runs_from_file()

    # Row 3 is the second data row, because row 1 is the header.
    assert exc_info.value.row_number == 3
    assert "Error in row 3: Validation failed" in str(exc_info.value)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def test_import_csv_structured_input_wrong_schema(
    task_with_structured_input: Task, tmp_path
):
    """An input violating the input schema aborts the import, citing its row."""
    row_data = [
        {
            "input": json.dumps({"example_id": 1, "text": "This is my input"}),
            "output": "This is my output",
            "tags": "t1,t2",
        },
        {
            # this one is missing example_id, which the input schema requires
            "input": json.dumps({"text": "This is my input 2"}),
            "output": "This is my output 2",
            "tags": "t3,t4",
        },
        {
            "input": json.dumps({"example_id": 3, "text": "This is my input 3"}),
            "output": "This is my output 3",
            "tags": "",
        },
    ]

    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)

    importer = DatasetFileImporter(
        task_with_structured_input,
        ImportConfig(
            dataset_type=DatasetImportFormat.CSV,
            dataset_path=file_path,
            dataset_name="test.csv",
        ),
    )

    # check that the import raises an exception
    with pytest.raises(KilnInvalidImportFormat) as e:
        importer.create_runs_from_file()

    # the invalid row is the second data row; +1 for the header gives row 3
    assert e.value.row_number == 3
    assert "Error in row 3: Validation failed" in str(e.value)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def test_import_csv_intermediate_outputs_reasoning(
    task_with_intermediate_outputs: Task,
    tmp_path,
):
    """A ``reasoning`` column is stored in each run's intermediate outputs."""
    rows = [
        {
            "input": "This is my input",
            "output": "This is my output",
            "reasoning": "我觉得这个输出是正确的",
            "tags": "t1,t2",
        },
        {
            "input": "This is my input 2",
            "output": "This is my output 2",
            "reasoning": "thinking output 2",
            "tags": "t3,t4",
        },
        {
            "input": "This is my input 3",
            "output": "This is my output 3",
            "reasoning": "thinking output 3",
            "tags": "",
        },
    ]
    csv_path = dicts_to_file_as_csv(rows, "test.csv", tmp_path)

    config = ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        dataset_path=csv_path,
        dataset_name="test.csv",
    )
    DatasetFileImporter(task_with_intermediate_outputs, config).create_runs_from_file()

    runs = task_with_intermediate_outputs.runs()
    assert len(runs) == 3

    rows_by_input = {row["input"]: row for row in rows}
    for run in runs:
        expected = rows_by_input.get(run.input)
        assert expected is not None
        assert run.input == expected["input"]
        assert run.output.output == expected["output"]
        assert run.intermediate_outputs["reasoning"] == expected["reasoning"]
        compare_tags(run.tags, expected["tags"])
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def test_import_csv_intermediate_outputs_cot(
    task_with_intermediate_outputs: Task, tmp_path
):
    """A ``chain_of_thought`` column is stored in each run's intermediate outputs."""
    rows = [
        {
            "input": "This is my input",
            "output": "This is my output",
            "chain_of_thought": "我觉得这个输出是正确的",
            "tags": "t1,t2",
        },
        {
            "input": "This is my input 2",
            "output": "This is my output 2",
            "chain_of_thought": "thinking output 2",
            "tags": "t3,t4",
        },
        {
            "input": "This is my input 3",
            "output": "This is my output 3",
            "chain_of_thought": "thinking output 3",
            "tags": "",
        },
    ]
    csv_path = dicts_to_file_as_csv(rows, "test.csv", tmp_path)

    config = ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        dataset_path=csv_path,
        dataset_name="test.csv",
    )
    DatasetFileImporter(task_with_intermediate_outputs, config).create_runs_from_file()

    runs = task_with_intermediate_outputs.runs()
    assert len(runs) == 3

    rows_by_input = {row["input"]: row for row in rows}
    for run in runs:
        expected = rows_by_input.get(run.input)
        assert expected is not None
        assert run.input == expected["input"]
        assert run.output.output == expected["output"]
        assert (
            run.intermediate_outputs["chain_of_thought"]
            == expected["chain_of_thought"]
        )
        compare_tags(run.tags, expected["tags"])
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def test_import_csv_intermediate_outputs_reasoning_and_cot(
    task_with_intermediate_outputs: Task,
    tmp_path,
):
    """``reasoning`` and ``chain_of_thought`` columns can be imported together."""
    rows = [
        {
            "input": "This is my input",
            "output": "This is my output",
            "reasoning": "thinking output 1",
            "chain_of_thought": "thinking output 1",
            "tags": "t1,t2",
        },
    ]
    csv_path = dicts_to_file_as_csv(rows, "test.csv", tmp_path)

    config = ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        dataset_path=csv_path,
        dataset_name="test.csv",
    )
    DatasetFileImporter(task_with_intermediate_outputs, config).create_runs_from_file()

    runs = task_with_intermediate_outputs.runs()
    assert len(runs) == 1

    # Exactly one run and one row, so pair them directly.
    [run] = runs
    expected = rows[0]
    assert run.input == expected["input"]
    assert run.output.output == expected["output"]
    assert run.intermediate_outputs["chain_of_thought"] == expected["chain_of_thought"]
    assert run.intermediate_outputs["reasoning"] == expected["reasoning"]
    compare_tags(run.tags, expected["tags"])
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def test_import_csv_invalid_tags(base_task: Task, tmp_path):
    """Tags containing spaces are rejected, and the offending row is reported."""
    rows = [
        {
            "input": "This is my input",
            "output": "This is my output",
            "tags": "tag with space,valid-tag",
        },
        {
            "input": "This is my input 2",
            "output": "This is my output 2",
            "tags": "another invalid tag",
        },
    ]

    importer = DatasetFileImporter(
        base_task,
        ImportConfig(
            dataset_type=DatasetImportFormat.CSV,
            dataset_path=dicts_to_file_as_csv(rows, "test.csv", tmp_path),
            dataset_name="test.csv",
        ),
    )

    with pytest.raises(KilnInvalidImportFormat) as exc_info:
        importer.create_runs_from_file()

    # Row 2 is the first data row, because row 1 is the header.
    assert exc_info.value.row_number == 2
    assert "Tags cannot contain spaces" in str(exc_info.value)
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def test_without_none_values():
    """without_none_values drops keys whose value is None and keeps the rest."""
    cases = [
        ({"a": 1, "b": None}, {"a": 1}),
        ({"a": None, "b": 2}, {"b": 2}),
        ({"a": None, "b": None}, {}),
    ]
    for given, expected in cases:
        assert without_none_values(given) == expected
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def test_deserialize_tags():
    """deserialize_tags splits on commas, trims whitespace, and maps blanks to []."""
    cases = [
        ("t1,t2", ["t1", "t2"]),
        (None, []),
        ("", []),
        (" ", []),
        ("t1, t2", ["t1", "t2"]),
    ]
    for raw, expected in cases:
        assert deserialize_tags(raw) == expected
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
def test_format_validation_error():
    """format_validation_error lists every field error in a readable summary."""

    class TestModel(BaseModel):
        a: int
        b: int

    # pytest.raises fails the test if validation unexpectedly succeeds. A bare
    # try/except would skip the assertions below and let the test pass silently.
    with pytest.raises(ValidationError) as exc_info:
        TestModel.model_validate({"a": "not an int"})

    human_readable = format_validation_error(exc_info.value)
    assert human_readable.startswith("Validation failed:")
    assert (
        "a: Input should be a valid integer, unable to parse string as an integer"
        in human_readable
    )
    assert "b: Field required" in human_readable
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def test_generate_import_tags():
    """generate_import_tags returns a generic tag plus one suffixed with the id."""
    tags = generate_import_tags("123")
    assert tags == ["imported", "imported_123"]
|