kiln-ai 0.0.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/base_adapter.py +168 -0
- kiln_ai/adapters/langchain_adapters.py +113 -0
- kiln_ai/adapters/ml_model_list.py +436 -0
- kiln_ai/adapters/prompt_builders.py +122 -0
- kiln_ai/adapters/repair/repair_task.py +71 -0
- kiln_ai/adapters/repair/test_repair_task.py +248 -0
- kiln_ai/adapters/test_langchain_adapter.py +50 -0
- kiln_ai/adapters/test_ml_model_list.py +99 -0
- kiln_ai/adapters/test_prompt_adaptors.py +167 -0
- kiln_ai/adapters/test_prompt_builders.py +315 -0
- kiln_ai/adapters/test_saving_adapter_results.py +168 -0
- kiln_ai/adapters/test_structured_output.py +218 -0
- kiln_ai/datamodel/__init__.py +362 -2
- kiln_ai/datamodel/basemodel.py +372 -0
- kiln_ai/datamodel/json_schema.py +45 -0
- kiln_ai/datamodel/test_basemodel.py +277 -0
- kiln_ai/datamodel/test_datasource.py +107 -0
- kiln_ai/datamodel/test_example_models.py +644 -0
- kiln_ai/datamodel/test_json_schema.py +124 -0
- kiln_ai/datamodel/test_models.py +190 -0
- kiln_ai/datamodel/test_nested_save.py +205 -0
- kiln_ai/datamodel/test_output_rating.py +88 -0
- kiln_ai/utils/config.py +170 -0
- kiln_ai/utils/formatting.py +5 -0
- kiln_ai/utils/test_config.py +245 -0
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/METADATA +20 -1
- kiln_ai-0.5.0.dist-info/RECORD +29 -0
- kiln_ai/__init.__.py +0 -3
- kiln_ai/coreadd.py +0 -3
- kiln_ai/datamodel/project.py +0 -15
- kiln_ai-0.0.4.dist-info/RECORD +0 -8
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/LICENSE.txt +0 -0
- {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/WHEEL +0 -0
kiln_ai/datamodel/test_example_models.py

@@ -0,0 +1,644 @@

import json

import pytest
from kiln_ai.datamodel import (
    DataSource,
    DataSourceType,
    Project,
    Task,
    TaskDeterminism,
    TaskOutput,
    TaskOutputRating,
    TaskOutputRatingType,
    TaskRequirement,
    TaskRun,
)
from pydantic import ValidationError


@pytest.fixture
def valid_task_run(tmp_path):
    task = Task(
        name="Test Task",
        instruction="test instruction",
        path=tmp_path / Task.base_filename(),
    )
    return TaskRun(
        parent=task,
        input="Test input",
        input_source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "John Doe"},
        ),
        output=TaskOutput(
            output="Test output",
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "John Doe"},
            ),
        ),
    )


def test_task_model_validation(valid_task_run):
    task_run = valid_task_run
    task_run.model_validate(task_run, strict=True)
    task_run.save_to_file()
    assert task_run.input == "Test input"
    assert task_run.input_source.type == DataSourceType.human
    assert task_run.input_source.properties == {"created_by": "John Doe"}
    assert task_run.output.output == "Test output"
    assert task_run.output.source.type == DataSourceType.human
    assert task_run.output.source.properties == {"created_by": "John Doe"}

    # Invalid source
    with pytest.raises(ValidationError, match="Input should be"):
        DataSource(type="invalid")

    with pytest.raises(ValidationError, match="Invalid data source type"):
        task_run = valid_task_run.model_copy(deep=True)
        task_run.input_source.type = "invalid"
        DataSource.model_validate(task_run.input_source, strict=True)

    # Missing required field
    with pytest.raises(ValidationError, match="Input should be a valid string"):
        task_run = valid_task_run.model_copy()
        task_run.input = None

    # Invalid source_properties type
    with pytest.raises(ValidationError):
        task_run = valid_task_run.model_copy()
        task_run.input_source.properties = "invalid"
        DataSource.model_validate(task_run.input_source, strict=True)

    # Test we catch nested validation errors
    with pytest.raises(
        ValidationError, match="'created_by' is required for DataSourceType.human"
    ):
        task_run = TaskRun(
            input="Test input",
            input_source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "John Doe"},
            ),
            output=TaskOutput(
                output="Test output",
                source=DataSource(
                    type=DataSourceType.human,
                    properties={"wrong_key": "John Doe"},
                ),
            ),
        )


def test_task_run_relationship(valid_task_run):
    assert valid_task_run.__class__.relationship_name() == "runs"
    assert valid_task_run.__class__.parent_type().__name__ == "Task"


def test_structured_output_workflow(tmp_path):
    tmp_project_file = (
        tmp_path / "test_structured_output_runs" / Project.base_filename()
    )
    # Create project
    project = Project(name="Test Project", path=str(tmp_project_file))
    project.save_to_file()

    # Create task with requirements
    req1 = TaskRequirement(name="Req1", instruction="Name must be capitalized")
    req2 = TaskRequirement(name="Req2", instruction="Age must be positive")

    task = Task(
        name="Structured Output Task",
        parent=project,
        instruction="Generate a JSON object with name and age",
        determinism=TaskDeterminism.semantic_match,
        output_json_schema=json.dumps(
            {
                "type": "object",
                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                "required": ["name", "age"],
            }
        ),
        requirements=[
            req1,
            req2,
        ],
    )
    task.save_to_file()

    # Create runs
    runs = []
    for source in DataSourceType:
        for _ in range(2):
            task_run = TaskRun(
                input="Generate info for John Doe",
                input_source=DataSource(
                    type=DataSourceType.human,
                    properties={"created_by": "john_doe"},
                )
                if source == DataSourceType.human
                else DataSource(
                    type=DataSourceType.synthetic,
                    properties={
                        "adapter_name": "TestAdapter",
                        "model_name": "GPT-4",
                        "model_provider": "OpenAI",
                        "prompt_builder_name": "TestPromptBuilder",
                    },
                ),
                parent=task,
                output=TaskOutput(
                    output='{"name": "John Doe", "age": 30}',
                    source=DataSource(
                        type=DataSourceType.human,
                        properties={"created_by": "john_doe"},
                    ),
                ),
            )
            task_run.save_to_file()
            runs.append(task_run)

    # make a run with a repaired output
    repaired_run = TaskRun(
        input="Generate info for John Doe",
        input_source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "john_doe"},
        ),
        parent=task,
        output=TaskOutput(
            output='{"name": "John Doe", "age": 31}',
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
        ),
        repair_instructions="The age should be 31 instead of 30",
        repaired_output=TaskOutput(
            output='{"name": "John Doe", "age": 31}',
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
        ),
    )
    repaired_run.save_to_file()
    runs.append(repaired_run)

    # Update outputs with ratings
    for task_run in runs:
        task_run.output.rating = TaskOutputRating(
            value=4,
            requirement_ratings={
                req1.id: 5,
                req2.id: 5,
            },
        )
        task_run.save_to_file()

    # Load from disk and validate
    loaded_project = Project.load_from_file(tmp_project_file)
    loaded_task = loaded_project.tasks()[0]

    assert loaded_task.name == "Structured Output Task"
    assert len(loaded_task.requirements) == 2
    assert len(loaded_task.runs()) == 5

    loaded_runs = loaded_task.runs()
    for task_run in loaded_runs:
        output = task_run.output
        assert output.rating is not None
        assert output.rating.value == 4
        assert len(output.rating.requirement_ratings) == 2

    # Find the run with the fixed output
    run_with_fixed_output = next(
        (task_run for task_run in loaded_runs if task_run.repaired_output is not None),
        None,
    )
    assert run_with_fixed_output is not None, "No run found with fixed output"
    assert (
        run_with_fixed_output.repaired_output.output
        == '{"name": "John Doe", "age": 31}'
    )


def test_task_output_requirement_rating_keys(tmp_path):
    # Create a project, task, and example hierarchy
    project = Project(name="Test Project", path=(tmp_path / "test_project"))
    project.save_to_file()

    # Create task requirements
    req1 = TaskRequirement(
        name="Requirement 1", instruction="Requirement 1 instruction"
    )
    req2 = TaskRequirement(
        name="Requirement 2", instruction="Requirement 2 instruction"
    )
    task = Task(
        name="Test Task",
        parent=project,
        instruction="Task instruction",
        requirements=[req1, req2],
    )
    task.save_to_file()

    # Valid case: all requirement IDs are valid
    task_run = TaskRun(
        input="Test input",
        input_source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "john_doe"},
        ),
        parent=task,
        output=TaskOutput(
            output="Test output",
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
            rating=TaskOutputRating(
                value=4,
                requirement_ratings={
                    req1.id: 5,
                    req2.id: 4,
                },
            ),
        ),
    )
    task_run.save_to_file()
    assert task_run.output.rating.requirement_ratings is not None

    # Invalid case: unknown requirement ID
    with pytest.raises(
        ValueError,
        match="Requirement ID .* is not a valid requirement ID for this task",
    ):
        task_run = TaskRun(
            input="Test input",
            input_source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
            parent=task,
            output=TaskOutput(
                output="Test output",
                source=DataSource(
                    type=DataSourceType.human,
                    properties={"created_by": "john_doe"},
                ),
                rating=TaskOutputRating(
                    value=4,
                    requirement_ratings={
                        "unknown_id": 5,
                    },
                ),
            ),
        )
        task_run.save_to_file()


def test_task_output_schema_validation(tmp_path):
    # Create a project, task, and example hierarchy
    project = Project(name="Test Project", path=(tmp_path / "test_project"))
    project.save_to_file()
    task = Task(
        name="Test Task",
        instruction="test instruction",
        parent=project,
        output_json_schema=json.dumps(
            {
                "type": "object",
                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                "required": ["name", "age"],
            }
        ),
    )
    task.save_to_file()

    # Create a run output with a valid schema
    task_output = TaskRun(
        input="Test input",
        input_source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "john_doe"},
        ),
        parent=task,
        output=TaskOutput(
            output='{"name": "John Doe", "age": 30}',
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
        ),
    )
    task_output.save_to_file()

    # changing to invalid output
    with pytest.raises(ValueError, match="does not match task output schema"):
        task_output.output.output = '{"name": "John Doe", "age": "thirty"}'
        task_output.save_to_file()

    # Invalid case: output does not match task output schema
    with pytest.raises(ValueError, match="does not match task output schema"):
        task_output = TaskRun(
            input="Test input",
            input_source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
            parent=task,
            output=TaskOutput(
                output='{"name": "John Doe", "age": "thirty"}',
                source=DataSource(
                    type=DataSourceType.human,
                    properties={"created_by": "john_doe"},
                ),
            ),
        )
        task_output.save_to_file()


def test_task_input_schema_validation(tmp_path):
    # Create a project and task hierarchy
    project = Project(name="Test Project", path=(tmp_path / "test_project"))
    project.save_to_file()
    task = Task(
        name="Test Task",
        parent=project,
        instruction="test instruction",
        input_json_schema=json.dumps(
            {
                "type": "object",
                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                "required": ["name", "age"],
            }
        ),
    )
    task.save_to_file()

    # Create an example with a valid input schema
    valid_task_output = TaskRun(
        input='{"name": "John Doe", "age": 30}',
        input_source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "john_doe"},
        ),
        parent=task,
        output=TaskOutput(
            output="Test output",
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
        ),
    )
    valid_task_output.save_to_file()

    # Changing to invalid input
    with pytest.raises(ValueError, match="does not match task input schema"):
        valid_task_output.input = '{"name": "John Doe", "age": "thirty"}'
        valid_task_output.save_to_file()

    # Invalid case: input does not match task input schema
    with pytest.raises(ValueError, match="does not match task input schema"):
        task_output = TaskRun(
            input='{"name": "John Doe", "age": "thirty"}',
            input_source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
            parent=task,
            output=TaskOutput(
                output="Test output",
                source=DataSource(
                    type=DataSourceType.human,
                    properties={"created_by": "john_doe"},
                ),
            ),
        )
        task_output.save_to_file()


def test_valid_human_task_output():
    output = TaskOutput(
        output="Test output",
        source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "John Doe"},
        ),
    )
    assert output.source.type == DataSourceType.human
    assert output.source.properties["created_by"] == "John Doe"


def test_invalid_human_task_output_missing_created_by():
    with pytest.raises(
        ValidationError, match="'created_by' is required for DataSourceType.human"
    ):
        TaskOutput(
            output="Test output",
            source=DataSource(
                type=DataSourceType.human,
                properties={},
            ),
        )


def test_invalid_human_task_output_empty_created_by():
    with pytest.raises(
        ValidationError, match="Property 'created_by' must be a non-empty string"
    ):
        TaskOutput(
            output="Test output",
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": ""},
            ),
        )


def test_valid_synthetic_task_output():
    output = TaskOutput(
        output="Test output",
        source=DataSource(
            type=DataSourceType.synthetic,
            properties={
                "adapter_name": "TestAdapter",
                "model_name": "GPT-4",
                "model_provider": "OpenAI",
                "prompt_builder_name": "TestPromptBuilder",
            },
        ),
    )
    assert output.source.type == DataSourceType.synthetic
    assert output.source.properties["adapter_name"] == "TestAdapter"
    assert output.source.properties["model_name"] == "GPT-4"
    assert output.source.properties["model_provider"] == "OpenAI"
    assert output.source.properties["prompt_builder_name"] == "TestPromptBuilder"


def test_invalid_synthetic_task_output_missing_keys():
    with pytest.raises(
        ValidationError,
        match="'model_provider' is required for DataSourceType.synthetic",
    ):
        TaskOutput(
            output="Test output",
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={"adapter_name": "TestAdapter", "model_name": "GPT-4"},
            ),
        )


def test_invalid_synthetic_task_output_empty_values():
    with pytest.raises(
        ValidationError, match="'model_name' must be a non-empty string"
    ):
        TaskOutput(
            output="Test output",
            source=DataSource(
                type=DataSourceType.synthetic,
                properties={
                    "adapter_name": "TestAdapter",
                    "model_name": "",
                    "model_provider": "OpenAI",
                    "prompt_builder_name": "TestPromptBuilder",
                },
            ),
        )


def test_invalid_synthetic_task_output_non_string_values():
    with pytest.raises(
        ValidationError, match="'prompt_builder_name' must be of type str"
    ):
        DataSource(
            type=DataSourceType.synthetic,
            properties={
                "adapter_name": "TestAdapter",
                "model_name": "GPT-4",
                "model_provider": "OpenAI",
                "prompt_builder_name": 123,
            },
        )


def test_task_run_validate_repaired_output():
    # Test case 1: Valid TaskRun with no repaired_output
    valid_task_run = TaskRun(
        input="test input",
        input_source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "john_doe"},
        ),
        output=TaskOutput(
            output="test output",
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
        ),
    )
    assert valid_task_run.repaired_output is None

    # Test case 2: Valid TaskRun with repaired_output and no rating
    valid_task_run_with_repair = TaskRun(
        input="test input",
        input_source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "john_doe"},
        ),
        output=TaskOutput(
            output="test output",
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
        ),
        repair_instructions="Fix the output",
        repaired_output=TaskOutput(
            output="repaired output",
            source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
        ),
    )
    assert valid_task_run_with_repair.repaired_output is not None
    assert valid_task_run_with_repair.repaired_output.rating is None

    # test missing repair_instructions
    with pytest.raises(ValidationError) as exc_info:
        TaskRun(
            input="test input",
            input_source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
            output=TaskOutput(
                output="test output",
                source=DataSource(
                    type=DataSourceType.human,
                    properties={"created_by": "john_doe"},
                ),
            ),
            repaired_output=TaskOutput(
                output="repaired output",
                source=DataSource(
                    type=DataSourceType.human,
                    properties={"created_by": "john_doe"},
                ),
            ),
        )

    assert "Repair instructions are required" in str(exc_info.value)

    # test missing repaired_output
    with pytest.raises(ValidationError) as exc_info:
        TaskRun(
            input="test input",
            input_source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
            output=TaskOutput(
                output="test output",
                source=DataSource(
                    type=DataSourceType.human,
                    properties={"created_by": "john_doe"},
                ),
            ),
            repair_instructions="Fix the output",
        )

    assert "A repaired output is required" in str(exc_info.value)

    # Test case 3: Invalid TaskRun with repaired_output containing a rating
    with pytest.raises(ValidationError) as exc_info:
        TaskRun(
            input="test input",
            input_source=DataSource(
                type=DataSourceType.human,
                properties={"created_by": "john_doe"},
            ),
            output=TaskOutput(
                output="test output",
                source=DataSource(
                    type=DataSourceType.human,
                    properties={"created_by": "john_doe"},
                ),
            ),
            repaired_output=TaskOutput(
                output="repaired output",
                source=DataSource(
                    type=DataSourceType.human,
                    properties={"created_by": "john_doe"},
                ),
                rating=TaskOutputRating(type=TaskOutputRatingType.five_star, value=5.0),
            ),
        )

    assert "Repaired output rating must be None" in str(exc_info.value)