kiln-ai 0.12.0__py3-none-any.whl → 0.13.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +157 -28
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +19 -3
- kiln_ai/adapters/eval/test_base_eval.py +1 -0
- kiln_ai/adapters/eval/test_eval_runner.py +1 -0
- kiln_ai/adapters/eval/test_g_eval.py +13 -7
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +8 -1
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_together_finetune.py +533 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +327 -0
- kiln_ai/adapters/ml_model_list.py +638 -155
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
- kiln_ai/adapters/ollama_tools.py +3 -2
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/provider_tools.py +52 -60
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +1 -1
- kiln_ai/adapters/test_adapter_registry.py +17 -20
- kiln_ai/adapters/test_generate_docs.py +2 -2
- kiln_ai/adapters/test_prompt_adaptors.py +30 -19
- kiln_ai/adapters/test_provider_tools.py +27 -82
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/json_schema.py +1 -1
- kiln_ai/datamodel/task_output.py +13 -6
- kiln_ai/datamodel/test_basemodel.py +9 -0
- kiln_ai/datamodel/test_datasource.py +19 -0
- kiln_ai/utils/config.py +46 -0
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/METADATA +51 -7
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/RECORD +44 -41
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/WHEEL +0 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/utils/test_dataset_import.py (new file)

@@ -0,0 +1,596 @@
+import csv
+import json
+import logging
+from io import StringIO
+from pathlib import Path
+
+import pytest
+from pydantic import BaseModel, ValidationError
+
+from kiln_ai.datamodel import Project, Task
+from kiln_ai.utils.dataset_import import (
+    DatasetFileImporter,
+    DatasetImportFormat,
+    ImportConfig,
+    KilnInvalidImportFormat,
+    deserialize_tags,
+    format_validation_error,
+    generate_import_tags,
+    without_none_values,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def base_task(tmp_path) -> Task:
+    project_path = tmp_path / "project.kiln"
+
+    project = Project(name="TestProject", path=str(project_path))
+    project.save_to_file()
+
+    task = Task(
+        name="Sentiment Classifier",
+        parent=project,
+        description="Classify the sentiment of a sentence",
+        instruction="Classify the sentiment of a sentence",
+        requirements=[],
+    )
+    task.save_to_file()
+    return task
+
+
+@pytest.fixture
+def task_with_structured_output(base_task: Task):
+    base_task.output_json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "sentiment": {"type": "string"},
+                "confidence": {"type": "number"},
+            },
+            "required": ["sentiment", "confidence"],
+        }
+    )
+    base_task.save_to_file()
+    return base_task
+
+
+@pytest.fixture
+def task_with_structured_input(base_task: Task):
+    base_task.input_json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "example_id": {"type": "integer"},
+                "text": {"type": "string"},
+            },
+            "required": ["example_id", "text"],
+        }
+    )
+    base_task.save_to_file()
+    return base_task
+
+
+@pytest.fixture
+def task_with_intermediate_outputs(base_task: Task):
+    for run in base_task.runs():
+        run.intermediate_outputs = {"reasoning": "thinking output"}
+    base_task.thinking_instruction = "thinking instructions"
+    return base_task
+
+
+def dict_to_csv_row(row: dict) -> str:
+    """Convert a dictionary to a CSV row with proper escaping."""
+    output = StringIO()
+    writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
+    writer.writerow(row.values())
+    return output.getvalue().rstrip("\n")
+
+
+def dicts_to_file_as_csv(items: list[dict], file_name: str, tmp_path: Path) -> str:
+    """Write a list of dictionaries to a CSV file with escaping and a header.
+
+    Returns the path to the file.
+    """
+    rows = [dict_to_csv_row(item) for item in items]
+    header = ",".join(f'"{key}"' for key in items[0].keys())
+    csv_data = header + "\n" + "\n".join(rows)
+
+    file_path = tmp_path / file_name
+    with open(file_path, "w", encoding="utf-8") as f:
+        f.write(csv_data)
+
+    return file_path
+
+
+def compare_tags(actual_tags: list[str], expected_tags: list[str]):
+    """Compare the tags of a run to a list of tags.
+
+    Returns True if the run.tags contains all the tags in the list.
+    """
+    # the run.tags contain some extra default tags
+    if expected_tags:
+        tags_expected = expected_tags.split(",")
+    else:
+        tags_expected = []
+
+    assert all(tag in actual_tags for tag in tags_expected)
+
+
+def test_import_csv_plain_text(base_task: Task, tmp_path):
+    row_data = [
+        {
+            "input": "This is my input",
+            "output": "This is my output 啊",
+            "tags": "t1,t2",
+        },
+        {
+            "input": "This is my input 2",
+            "output": "This is my output 2 啊",
+            "tags": "t3,t4",
+        },
+        {
+            "input": "This is my input 3",
+            "output": "This is my output 3 啊",
+            "tags": "t5",
+        },
+        {
+            "input": "This is my input 4",
+            "output": "This is my output 4 啊",
+            "tags": "",
+        },
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        base_task,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    importer.create_runs_from_file()
+
+    assert len(base_task.runs()) == 4
+
+    for run in base_task.runs():
+        # identify the row data with same input as the run
+        match = next(
+            (row for row in row_data if row["input"] == run.input),
+            None,
+        )
+        assert match is not None
+        assert run.input == match["input"]
+        assert run.output.output == match["output"]
+
+        compare_tags(run.tags, match["tags"])
+
+
+def test_import_csv_default_tags(base_task: Task, tmp_path):
+    row_data = [
+        {
+            "input": "This is my input",
+            "output": "This is my output 啊",
+            "tags": "t1,t2",
+        },
+        {
+            "input": "This is my input 4",
+            "output": "This is my output 4 啊",
+            "tags": "",
+        },
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        base_task,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    importer.create_runs_from_file()
+
+    assert len(base_task.runs()) == 2
+
+    default_tags = 2
+
+    for run in base_task.runs():
+        # identify the row data with same input as the run
+        match = next(
+            (row for row in row_data if row["input"] == run.input),
+            None,
+        )
+
+        assert match is not None
+
+        if match["tags"]:
+            expected_tags = match["tags"].split(",")
+            assert len(run.tags) == len(expected_tags) + default_tags
+        else:
+            assert len(run.tags) == default_tags
+
+        # these are the default tags
+        assert "imported" in run.tags
+        assert any(tag.startswith("imported_") for tag in run.tags)
+
+
+def test_import_csv_plain_text_missing_output(base_task: Task, tmp_path):
+    row_data = [
+        {"input": "This is my input", "tags": "t1,t2"},
+        {"input": "This is my input 2", "tags": "t3,t4"},
+        {"input": "This is my input 3", "tags": "t5,t6"},
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        base_task,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    # check that the import raises an exception
+    with pytest.raises(KilnInvalidImportFormat) as e:
+        importer.create_runs_from_file()
+
+    # no row number because the whole structure is invalid
+    assert e.value.row_number is None
+    assert "Missing required headers" in str(e.value)
+
+
+def test_import_csv_structured_output(task_with_structured_output: Task, tmp_path):
+    row_data = [
+        {
+            "input": "This is my input",
+            "output": json.dumps({"sentiment": "高兴", "confidence": 0.95}),
+            "tags": "t1,t2",
+        },
+        {
+            "input": "This is my input 2",
+            "output": json.dumps({"sentiment": "negative", "confidence": 0.05}),
+            "tags": "t3,t4",
+        },
+        {
+            "input": "This is my input 3",
+            "output": json.dumps({"sentiment": "neutral", "confidence": 0.5}),
+            "tags": "",
+        },
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        task_with_structured_output,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    importer.create_runs_from_file()
+
+    assert len(task_with_structured_output.runs()) == 3
+
+    for run in task_with_structured_output.runs():
+        # identify the row data with same input as the run
+        match = next(
+            (row for row in row_data if row["input"] == run.input),
+            None,
+        )
+        assert match is not None
+        assert run.input == match["input"]
+        assert json.loads(run.output.output) == json.loads(match["output"])
+
+        compare_tags(run.tags, match["tags"])
+
+
+def test_import_csv_structured_output_wrong_schema(
+    task_with_structured_output: Task, tmp_path
+):
+    row_data = [
+        {
+            "input": "This is my input",
+            "output": json.dumps({"sentiment": "positive", "confidence": 0.95}),
+            "tags": "t1,t2",
+        },
+        {
+            "input": "This is my input 2",
+            # the output is wrong because sentiment is not a string
+            "output": json.dumps({"sentiment": 100, "confidence": 0.05}),
+            "tags": "t3,t4",
+        },
+        {
+            "input": "This is my input 3",
+            "output": json.dumps({"sentiment": "positive", "confidence": 0.5}),
+            "tags": "",
+        },
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        task_with_structured_output,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    # check that the import raises an exception
+    with pytest.raises(KilnInvalidImportFormat) as e:
+        importer.create_runs_from_file()
+
+    # the row number is +1 because of the header
+    assert e.value.row_number == 3
+    assert "Error in row 3: Validation failed" in str(e.value)
+
+
+def test_import_csv_structured_input_wrong_schema(
+    task_with_structured_input: Task, tmp_path
+):
+    row_data = [
+        {
+            # this one is missing example_id
+            "input": json.dumps({"example_id": 1, "text": "This is my input"}),
+            "output": "This is my output",
+            "tags": "t1,t2",
+        },
+        {
+            "input": json.dumps({"text": "This is my input 2"}),
+            "output": "This is my output 2",
+            "tags": "t3,t4",
+        },
+        {
+            "input": json.dumps({"example_id": 3, "text": "This is my input 3"}),
+            "output": "This is my output 3",
+            "tags": "",
+        },
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        task_with_structured_input,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    # check that the import raises an exception
+    with pytest.raises(KilnInvalidImportFormat) as e:
+        importer.create_runs_from_file()
+
+    # the row number is +1 because of the header
+    assert e.value.row_number == 3
+    assert "Error in row 3: Validation failed" in str(e.value)
+
+
+def test_import_csv_intermediate_outputs_reasoning(
+    task_with_intermediate_outputs: Task,
+    tmp_path,
+):
+    row_data = [
+        {
+            "input": "This is my input",
+            "output": "This is my output",
+            "reasoning": "我觉得这个输出是正确的",
+            "tags": "t1,t2",
+        },
+        {
+            "input": "This is my input 2",
+            "output": "This is my output 2",
+            "reasoning": "thinking output 2",
+            "tags": "t3,t4",
+        },
+        {
+            "input": "This is my input 3",
+            "output": "This is my output 3",
+            "reasoning": "thinking output 3",
+            "tags": "",
+        },
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        task_with_intermediate_outputs,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    importer.create_runs_from_file()
+
+    assert len(task_with_intermediate_outputs.runs()) == 3
+
+    for run in task_with_intermediate_outputs.runs():
+        # identify the row data with same input as the run
+        match = next(
+            (row for row in row_data if row["input"] == run.input),
+            None,
+        )
+        assert match is not None
+        assert run.input == match["input"]
+        assert run.output.output == match["output"]
+        assert run.intermediate_outputs["reasoning"] == match["reasoning"]
+        compare_tags(run.tags, match["tags"])
+
+
+def test_import_csv_intermediate_outputs_cot(
+    task_with_intermediate_outputs: Task, tmp_path
+):
+    row_data = [
+        {
+            "input": "This is my input",
+            "output": "This is my output",
+            "chain_of_thought": "我觉得这个输出是正确的",
+            "tags": "t1,t2",
+        },
+        {
+            "input": "This is my input 2",
+            "output": "This is my output 2",
+            "chain_of_thought": "thinking output 2",
+            "tags": "t3,t4",
+        },
+        {
+            "input": "This is my input 3",
+            "output": "This is my output 3",
+            "chain_of_thought": "thinking output 3",
+            "tags": "",
+        },
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        task_with_intermediate_outputs,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    importer.create_runs_from_file()
+
+    assert len(task_with_intermediate_outputs.runs()) == 3
+
+    for run in task_with_intermediate_outputs.runs():
+        # identify the row data with same input as the run
+        match = next(
+            (row for row in row_data if row["input"] == run.input),
+            None,
+        )
+        assert match is not None
+        assert run.input == match["input"]
+        assert run.output.output == match["output"]
+        assert run.intermediate_outputs["chain_of_thought"] == match["chain_of_thought"]
+        compare_tags(run.tags, match["tags"])
+
+
+def test_import_csv_intermediate_outputs_reasoning_and_cot(
+    task_with_intermediate_outputs: Task,
+    tmp_path,
+):
+    row_data = [
+        {
+            "input": "This is my input",
+            "output": "This is my output",
+            "reasoning": "thinking output 1",
+            "chain_of_thought": "thinking output 1",
+            "tags": "t1,t2",
+        },
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        task_with_intermediate_outputs,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    importer.create_runs_from_file()
+
+    assert len(task_with_intermediate_outputs.runs()) == 1
+
+    for run in task_with_intermediate_outputs.runs():
+        # identify the row data with same input as the run
+        match = next(
+            (row for row in row_data if row["input"] == run.input),
+            None,
+        )
+        assert match is not None
+        assert run.input == match["input"]
+        assert run.output.output == match["output"]
+        assert run.intermediate_outputs["chain_of_thought"] == match["chain_of_thought"]
+        assert run.intermediate_outputs["reasoning"] == match["reasoning"]
+        compare_tags(run.tags, match["tags"])
+
+
+def test_import_csv_invalid_tags(base_task: Task, tmp_path):
+    row_data = [
+        {
+            "input": "This is my input",
+            "output": "This is my output",
+            "tags": "tag with space,valid-tag",
+        },
+        {
+            "input": "This is my input 2",
+            "output": "This is my output 2",
+            "tags": "another invalid tag",
+        },
+    ]
+
+    file_path = dicts_to_file_as_csv(row_data, "test.csv", tmp_path)
+
+    importer = DatasetFileImporter(
+        base_task,
+        ImportConfig(
+            dataset_type=DatasetImportFormat.CSV,
+            dataset_path=file_path,
+            dataset_name="test.csv",
+        ),
+    )
+
+    # check that the import raises an exception
+    with pytest.raises(KilnInvalidImportFormat) as e:
+        importer.create_runs_from_file()
+
+    # the row number is +1 because of the header
+    assert e.value.row_number == 2
+    assert "Tags cannot contain spaces" in str(e.value)
+
+
+def test_without_none_values():
+    assert without_none_values({"a": 1, "b": None}) == {"a": 1}
+    assert without_none_values({"a": None, "b": 2}) == {"b": 2}
+    assert without_none_values({"a": None, "b": None}) == {}
+
+
+def test_deserialize_tags():
+    assert deserialize_tags("t1,t2") == ["t1", "t2"]
+    assert deserialize_tags(None) == []
+    assert deserialize_tags("") == []
+    assert deserialize_tags(" ") == []
+    assert deserialize_tags("t1, t2") == ["t1", "t2"]
+
+
+def test_format_validation_error():
+    class TestModel(BaseModel):
+        a: int
+        b: int
+
+    try:
+        TestModel.model_validate({"a": "not an int"})
+    except ValidationError as e:
+        human_readable = format_validation_error(e)
+        assert human_readable.startswith("Validation failed:")
+        assert (
+            "a: Input should be a valid integer, unable to parse string as an integer"
+            in human_readable
+        )
+        assert "b: Field required" in human_readable
+
+
+def test_generate_import_tags():
+    assert generate_import_tags("123") == ["imported", "imported_123"]
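Taken together, the tests above show how the new importer in kiln_ai/utils/dataset_import.py is driven. A minimal usage sketch follows, based only on the classes and calls visible in these tests; the task path and CSV file names are illustrative placeholders, not part of the release.

```python
from pathlib import Path

from kiln_ai.datamodel import Task
from kiln_ai.utils.dataset_import import (
    DatasetFileImporter,
    DatasetImportFormat,
    ImportConfig,
    KilnInvalidImportFormat,
)

# Load an existing task; the CSV needs "input" and "output" columns, plus
# optional "tags", "reasoning", and "chain_of_thought" columns (per the tests).
task = Task.load_from_file(Path("my_project/sentiment_task.kiln"))  # placeholder path

importer = DatasetFileImporter(
    task,
    ImportConfig(
        dataset_type=DatasetImportFormat.CSV,
        dataset_path=Path("examples.csv"),  # placeholder file
        dataset_name="examples.csv",
    ),
)

try:
    importer.create_runs_from_file()  # creates one saved run per CSV row
except KilnInvalidImportFormat as err:
    # Bad headers or rows are rejected; err.row_number points at the failing row.
    print(f"Import failed: {err}")
```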
{kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: kiln-ai
-Version: 0.12.0
+Version: 0.13.2
 Summary: Kiln AI
 Project-URL: Homepage, https://getkiln.ai
 Project-URL: Repository, https://github.com/Kiln-AI/kiln
@@ -13,20 +13,20 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Requires-Python: >=3.10
+Requires-Dist: boto3>=1.37.10
 Requires-Dist: coverage>=7.6.4
+Requires-Dist: google-cloud-aiplatform>=1.84.0
 Requires-Dist: jsonschema>=4.23.0
-Requires-Dist:
-Requires-Dist: langchain-fireworks>=0.2.5
-Requires-Dist: langchain-groq>=0.2.0
-Requires-Dist: langchain-ollama>=0.2.2
-Requires-Dist: langchain>=0.3.5
+Requires-Dist: litellm>=1.63.5
 Requires-Dist: openai>=1.53.0
 Requires-Dist: pdoc>=15.0.0
 Requires-Dist: pydantic>=2.9.2
 Requires-Dist: pytest-benchmark>=5.1.0
 Requires-Dist: pytest-cov>=6.0.0
 Requires-Dist: pyyaml>=6.0.2
+Requires-Dist: together
 Requires-Dist: typing-extensions>=4.12.2
 Description-Content-Type: text/markdown

@@ -72,6 +72,7 @@ The library has a [comprehensive set of docs](https://kiln-ai.github.io/Kiln/kil
 - [Using your Kiln Dataset in a Notebook or Project](#using-your-kiln-dataset-in-a-notebook-or-project)
 - [Using Kiln Dataset in Pandas](#using-kiln-dataset-in-pandas)
 - [Building and Running a Kiln Task from Code](#building-and-running-a-kiln-task-from-code)
+- [Adding Custom Model or AI Provider from Code](#adding-custom-model-or-ai-provider-from-code)
 - [Full API Reference](#full-api-reference)

 ## Installation
@@ -259,7 +260,7 @@ response = await adapter.invoke(task_input)
 print(f"Output: {response.output.output}")

 # Step 4 (optional): Load the task from disk and print the results.
-# This will only work if the task was loaded from disk, or you called task.save_to_file() before invoking the adapter (
+# This will only work if the task was loaded from disk, or you called task.save_to_file() before invoking the adapter (ephemeral tasks don't save their result to disk)
 task = datamodel.Task.load_from_file(tmp_path / "test_task.kiln")
 for run in task.runs():
     print(f"Run: {run.id}")
@@ -268,6 +269,49 @@ for run in task.runs():

 ```

+### Adding Custom Model or AI Provider from Code
+
+You can add additional AI models and providers to Kiln.
+
+See our docs for more information, including how to add these from the UI:
+
+- [Custom Models From Existing Providers](https://docs.getkiln.ai/docs/models-and-ai-providers#custom-models-from-existing-providers)
+- [Custom OpenAI Compatible Servers](https://docs.getkiln.ai/docs/models-and-ai-providers#custom-openai-compatible-servers)
+
+You can also add these from code. The kiln_ai.utils.Config class helps you manage the Kiln config file (stored at `~/.kiln_settings/config.yaml`):
+
+```python
+# Addding an OpenAI compatible provider
+name = "CustomOllama"
+base_url = "http://localhost:1234/api/v1"
+api_key = "12345"
+providers = Config.shared().openai_compatible_providers or []
+existing_provider = next((p for p in providers if p["name"] == name), None)
+if existing_provider:
+    # skip since this already exists
+    return
+providers.append(
+    {
+        "name": name,
+        "base_url": base_url,
+        "api_key": api_key,
+    }
+)
+Config.shared().openai_compatible_providers = providers
+```
+
+```python
+# Add a custom model ID to an existing provider.
+new_model = "openai::gpt-3.5-turbo"
+custom_model_ids = Config.shared().custom_models
+existing_model = next((m for m in custom_model_ids if m == new_model), None)
+if existing_model:
+    # skip since this already exists
+    return
+custom_model_ids.append(new_model)
+Config.shared().custom_models = custom_model_ids
+```
+
 ## Full API Reference

 The library can do a lot more than the examples we've shown here.
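The README snippets added in the "Adding Custom Model or AI Provider from Code" section use a bare `return`, so as written they only run inside a function. A self-contained variant is sketched below; it assumes nothing beyond the `Config.shared()` attributes those snippets already use, and the `kiln_ai.utils.config` import path follows the `kiln_ai/utils/config.py` entry in the file list above.

```python
from kiln_ai.utils.config import Config


def add_openai_compatible_provider(name: str, base_url: str, api_key: str) -> None:
    """Register an OpenAI-compatible server unless one with this name already exists."""
    providers = Config.shared().openai_compatible_providers or []
    if any(p["name"] == name for p in providers):
        return  # already registered
    providers.append({"name": name, "base_url": base_url, "api_key": api_key})
    Config.shared().openai_compatible_providers = providers


def add_custom_model(model_id: str) -> None:
    """Register a "provider::model" ID unless it is already present (the `or []` default is an assumption)."""
    custom_model_ids = Config.shared().custom_models or []
    if model_id in custom_model_ids:
        return  # already registered
    custom_model_ids.append(model_id)
    Config.shared().custom_models = custom_model_ids


add_openai_compatible_provider("CustomOllama", "http://localhost:1234/api/v1", "12345")
add_custom_model("openai::gpt-3.5-turbo")
```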