kiln-ai 0.0.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (33)
  1. kiln_ai/adapters/base_adapter.py +168 -0
  2. kiln_ai/adapters/langchain_adapters.py +113 -0
  3. kiln_ai/adapters/ml_model_list.py +436 -0
  4. kiln_ai/adapters/prompt_builders.py +122 -0
  5. kiln_ai/adapters/repair/repair_task.py +71 -0
  6. kiln_ai/adapters/repair/test_repair_task.py +248 -0
  7. kiln_ai/adapters/test_langchain_adapter.py +50 -0
  8. kiln_ai/adapters/test_ml_model_list.py +99 -0
  9. kiln_ai/adapters/test_prompt_adaptors.py +167 -0
  10. kiln_ai/adapters/test_prompt_builders.py +315 -0
  11. kiln_ai/adapters/test_saving_adapter_results.py +168 -0
  12. kiln_ai/adapters/test_structured_output.py +218 -0
  13. kiln_ai/datamodel/__init__.py +362 -2
  14. kiln_ai/datamodel/basemodel.py +372 -0
  15. kiln_ai/datamodel/json_schema.py +45 -0
  16. kiln_ai/datamodel/test_basemodel.py +277 -0
  17. kiln_ai/datamodel/test_datasource.py +107 -0
  18. kiln_ai/datamodel/test_example_models.py +644 -0
  19. kiln_ai/datamodel/test_json_schema.py +124 -0
  20. kiln_ai/datamodel/test_models.py +190 -0
  21. kiln_ai/datamodel/test_nested_save.py +205 -0
  22. kiln_ai/datamodel/test_output_rating.py +88 -0
  23. kiln_ai/utils/config.py +170 -0
  24. kiln_ai/utils/formatting.py +5 -0
  25. kiln_ai/utils/test_config.py +245 -0
  26. {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/METADATA +20 -1
  27. kiln_ai-0.5.0.dist-info/RECORD +29 -0
  28. kiln_ai/__init.__.py +0 -3
  29. kiln_ai/coreadd.py +0 -3
  30. kiln_ai/datamodel/project.py +0 -15
  31. kiln_ai-0.0.4.dist-info/RECORD +0 -8
  32. {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/LICENSE.txt +0 -0
  33. {kiln_ai-0.0.4.dist-info → kiln_ai-0.5.0.dist-info}/WHEEL +0 -0
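The largest addition is a new datamodel test suite, kiln_ai/datamodel/test_example_models.py, shown in full below. As orientation, here is a minimal sketch of the Project → Task → TaskRun hierarchy those tests exercise; it is inferred from the test fixtures, so the names and the example_project directory are illustrative rather than documented package API:

# Minimal usage sketch (inferred from the tests below; names and paths are illustrative).
from pathlib import Path

from kiln_ai.datamodel import (
    DataSource,
    DataSourceType,
    Project,
    Task,
    TaskOutput,
    TaskRun,
)

# Create and persist a project at an illustrative location.
project = Project(
    name="Example Project",
    path=Path("example_project") / Project.base_filename(),
)
project.save_to_file()

# Create a task under the project.
task = Task(name="Example Task", instruction="example instruction", parent=project)
task.save_to_file()

# Record a run of the task, tagging where the input and output came from.
run = TaskRun(
    parent=task,
    input="Example input",
    input_source=DataSource(
        type=DataSourceType.human,
        properties={"created_by": "example_user"},
    ),
    output=TaskOutput(
        output="Example output",
        source=DataSource(
            type=DataSourceType.human,
            properties={"created_by": "example_user"},
        ),
    ),
)
run.save_to_file()

Each object is persisted with save_to_file(); the tests below also reload the hierarchy via Project.load_from_file and task.runs().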
kiln_ai/datamodel/test_example_models.py (new file)
@@ -0,0 +1,644 @@
+import json
+
+import pytest
+from kiln_ai.datamodel import (
+    DataSource,
+    DataSourceType,
+    Project,
+    Task,
+    TaskDeterminism,
+    TaskOutput,
+    TaskOutputRating,
+    TaskOutputRatingType,
+    TaskRequirement,
+    TaskRun,
+)
+from pydantic import ValidationError
+
+
+@pytest.fixture
+def valid_task_run(tmp_path):
+    task = Task(
+        name="Test Task",
+        instruction="test instruction",
+        path=tmp_path / Task.base_filename(),
+    )
+    return TaskRun(
+        parent=task,
+        input="Test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "John Doe"},
+        ),
+        output=TaskOutput(
+            output="Test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "John Doe"},
+            ),
+        ),
+    )
+
+
+def test_task_model_validation(valid_task_run):
+    task_run = valid_task_run
+    task_run.model_validate(task_run, strict=True)
+    task_run.save_to_file()
+    assert task_run.input == "Test input"
+    assert task_run.input_source.type == DataSourceType.human
+    assert task_run.input_source.properties == {"created_by": "John Doe"}
+    assert task_run.output.output == "Test output"
+    assert task_run.output.source.type == DataSourceType.human
+    assert task_run.output.source.properties == {"created_by": "John Doe"}
+
+    # Invalid source
+    with pytest.raises(ValidationError, match="Input should be"):
+        DataSource(type="invalid")
+
+    with pytest.raises(ValidationError, match="Invalid data source type"):
+        task_run = valid_task_run.model_copy(deep=True)
+        task_run.input_source.type = "invalid"
+        DataSource.model_validate(task_run.input_source, strict=True)
+
+    # Missing required field
+    with pytest.raises(ValidationError, match="Input should be a valid string"):
+        task_run = valid_task_run.model_copy()
+        task_run.input = None
+
+    # Invalid source_properties type
+    with pytest.raises(ValidationError):
+        task_run = valid_task_run.model_copy()
+        task_run.input_source.properties = "invalid"
+        DataSource.model_validate(task_run.input_source, strict=True)
+
+    # Test we catch nested validation errors
+    with pytest.raises(
+        ValidationError, match="'created_by' is required for DataSourceType.human"
+    ):
+        task_run = TaskRun(
+            input="Test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "John Doe"},
+            ),
+            output=TaskOutput(
+                output="Test output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"wrong_key": "John Doe"},
+                ),
+            ),
+        )
+
+
+def test_task_run_relationship(valid_task_run):
+    assert valid_task_run.__class__.relationship_name() == "runs"
+    assert valid_task_run.__class__.parent_type().__name__ == "Task"
+
+
+def test_structured_output_workflow(tmp_path):
+    tmp_project_file = (
+        tmp_path / "test_structured_output_runs" / Project.base_filename()
+    )
+    # Create project
+    project = Project(name="Test Project", path=str(tmp_project_file))
+    project.save_to_file()
+
+    # Create task with requirements
+    req1 = TaskRequirement(name="Req1", instruction="Name must be capitalized")
+    req2 = TaskRequirement(name="Req2", instruction="Age must be positive")
+
+    task = Task(
+        name="Structured Output Task",
+        parent=project,
+        instruction="Generate a JSON object with name and age",
+        determinism=TaskDeterminism.semantic_match,
+        output_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+                "required": ["name", "age"],
+            }
+        ),
+        requirements=[
+            req1,
+            req2,
+        ],
+    )
+    task.save_to_file()
+
+    # Create runs
+    runs = []
+    for source in DataSourceType:
+        for _ in range(2):
+            task_run = TaskRun(
+                input="Generate info for John Doe",
+                input_source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                )
+                if source == DataSourceType.human
+                else DataSource(
+                    type=DataSourceType.synthetic,
+                    properties={
+                        "adapter_name": "TestAdapter",
+                        "model_name": "GPT-4",
+                        "model_provider": "OpenAI",
+                        "prompt_builder_name": "TestPromptBuilder",
+                    },
+                ),
+                parent=task,
+                output=TaskOutput(
+                    output='{"name": "John Doe", "age": 30}',
+                    source=DataSource(
+                        type=DataSourceType.human,
+                        properties={"created_by": "john_doe"},
+                    ),
+                ),
+            )
+            task_run.save_to_file()
+            runs.append(task_run)
+
+    # make a run with a repaired output
+    repaired_run = TaskRun(
+        input="Generate info for John Doe",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        parent=task,
+        output=TaskOutput(
+            output='{"name": "John Doe", "age": 31}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+        repair_instructions="The age should be 31 instead of 30",
+        repaired_output=TaskOutput(
+            output='{"name": "John Doe", "age": 31}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+    )
+    repaired_run.save_to_file()
+    runs.append(repaired_run)
+
+    # Update outputs with ratings
+    for task_run in runs:
+        task_run.output.rating = TaskOutputRating(
+            value=4,
+            requirement_ratings={
+                req1.id: 5,
+                req2.id: 5,
+            },
+        )
+        task_run.save_to_file()
+
+    # Load from disk and validate
+    loaded_project = Project.load_from_file(tmp_project_file)
+    loaded_task = loaded_project.tasks()[0]
+
+    assert loaded_task.name == "Structured Output Task"
+    assert len(loaded_task.requirements) == 2
+    assert len(loaded_task.runs()) == 5
+
+    loaded_runs = loaded_task.runs()
+    for task_run in loaded_runs:
+        output = task_run.output
+        assert output.rating is not None
+        assert output.rating.value == 4
+        assert len(output.rating.requirement_ratings) == 2
+
+    # Find the run with the fixed output
+    run_with_fixed_output = next(
+        (task_run for task_run in loaded_runs if task_run.repaired_output is not None),
+        None,
+    )
+    assert run_with_fixed_output is not None, "No run found with fixed output"
+    assert (
+        run_with_fixed_output.repaired_output.output
+        == '{"name": "John Doe", "age": 31}'
+    )
+
+
+def test_task_output_requirement_rating_keys(tmp_path):
+    # Create a project, task, and example hierarchy
+    project = Project(name="Test Project", path=(tmp_path / "test_project"))
+    project.save_to_file()
+
+    # Create task requirements
+    req1 = TaskRequirement(
+        name="Requirement 1", instruction="Requirement 1 instruction"
+    )
+    req2 = TaskRequirement(
+        name="Requirement 2", instruction="Requirement 2 instruction"
+    )
+    task = Task(
+        name="Test Task",
+        parent=project,
+        instruction="Task instruction",
+        requirements=[req1, req2],
+    )
+    task.save_to_file()
+
+    # Valid case: all requirement IDs are valid
+    task_run = TaskRun(
+        input="Test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        parent=task,
+        output=TaskOutput(
+            output="Test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            rating=TaskOutputRating(
+                value=4,
+                requirement_ratings={
+                    req1.id: 5,
+                    req2.id: 4,
+                },
+            ),
+        ),
+    )
+    task_run.save_to_file()
+    assert task_run.output.rating.requirement_ratings is not None
+
+    # Invalid case: unknown requirement ID
+    with pytest.raises(
+        ValueError,
+        match="Requirement ID .* is not a valid requirement ID for this task",
+    ):
+        task_run = TaskRun(
+            input="Test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            parent=task,
+            output=TaskOutput(
+                output="Test output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+                rating=TaskOutputRating(
+                    value=4,
+                    requirement_ratings={
+                        "unknown_id": 5,
+                    },
+                ),
+            ),
+        )
+        task_run.save_to_file()
+
+
+def test_task_output_schema_validation(tmp_path):
+    # Create a project, task, and example hierarchy
+    project = Project(name="Test Project", path=(tmp_path / "test_project"))
+    project.save_to_file()
+    task = Task(
+        name="Test Task",
+        instruction="test instruction",
+        parent=project,
+        output_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+                "required": ["name", "age"],
+            }
+        ),
+    )
+    task.save_to_file()
+
+    # Create an run output with a valid schema
+    task_output = TaskRun(
+        input="Test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        parent=task,
+        output=TaskOutput(
+            output='{"name": "John Doe", "age": 30}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+    )
+    task_output.save_to_file()
+
+    # changing to invalid output
+    with pytest.raises(ValueError, match="does not match task output schema"):
+        task_output.output.output = '{"name": "John Doe", "age": "thirty"}'
+        task_output.save_to_file()
+
+    # Invalid case: output does not match task output schema
+    with pytest.raises(ValueError, match="does not match task output schema"):
+        task_output = TaskRun(
+            input="Test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            parent=task,
+            output=TaskOutput(
+                output='{"name": "John Doe", "age": "thirty"}',
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+        )
+        task_output.save_to_file()
+
+
+def test_task_input_schema_validation(tmp_path):
+    # Create a project and task hierarchy
+    project = Project(name="Test Project", path=(tmp_path / "test_project"))
+    project.save_to_file()
+    task = Task(
+        name="Test Task",
+        parent=project,
+        instruction="test instruction",
+        input_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+                "required": ["name", "age"],
+            }
+        ),
+    )
+    task.save_to_file()
+
+    # Create an example with a valid input schema
+    valid_task_output = TaskRun(
+        input='{"name": "John Doe", "age": 30}',
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        parent=task,
+        output=TaskOutput(
+            output="Test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+    )
+    valid_task_output.save_to_file()
+
+    # Changing to invalid input
+    with pytest.raises(ValueError, match="does not match task input schema"):
+        valid_task_output.input = '{"name": "John Doe", "age": "thirty"}'
+        valid_task_output.save_to_file()
+
+    # Invalid case: input does not match task input schema
+    with pytest.raises(ValueError, match="does not match task input schema"):
+        task_output = TaskRun(
+            input='{"name": "John Doe", "age": "thirty"}',
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            parent=task,
+            output=TaskOutput(
+                output="Test output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+        )
+        task_output.save_to_file()
+
+
+def test_valid_human_task_output():
+    output = TaskOutput(
+        output="Test output",
+        source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "John Doe"},
+        ),
+    )
+    assert output.source.type == DataSourceType.human
+    assert output.source.properties["created_by"] == "John Doe"
+
+
+def test_invalid_human_task_output_missing_created_by():
+    with pytest.raises(
+        ValidationError, match="'created_by' is required for DataSourceType.human"
+    ):
+        TaskOutput(
+            output="Test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={},
+            ),
+        )
+
+
+def test_invalid_human_task_output_empty_created_by():
+    with pytest.raises(
+        ValidationError, match="Property 'created_by' must be a non-empty string"
+    ):
+        TaskOutput(
+            output="Test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": ""},
+            ),
+        )
+
+
+def test_valid_synthetic_task_output():
+    output = TaskOutput(
+        output="Test output",
+        source=DataSource(
+            type=DataSourceType.synthetic,
+            properties={
+                "adapter_name": "TestAdapter",
+                "model_name": "GPT-4",
+                "model_provider": "OpenAI",
+                "prompt_builder_name": "TestPromptBuilder",
+            },
+        ),
+    )
+    assert output.source.type == DataSourceType.synthetic
+    assert output.source.properties["adapter_name"] == "TestAdapter"
+    assert output.source.properties["model_name"] == "GPT-4"
+    assert output.source.properties["model_provider"] == "OpenAI"
+    assert output.source.properties["prompt_builder_name"] == "TestPromptBuilder"
+
+
+def test_invalid_synthetic_task_output_missing_keys():
+    with pytest.raises(
+        ValidationError,
+        match="'model_provider' is required for DataSourceType.synthetic",
+    ):
+        TaskOutput(
+            output="Test output",
+            source=DataSource(
+                type=DataSourceType.synthetic,
+                properties={"adapter_name": "TestAdapter", "model_name": "GPT-4"},
+            ),
+        )
+
+
+def test_invalid_synthetic_task_output_empty_values():
+    with pytest.raises(
+        ValidationError, match="'model_name' must be a non-empty string"
+    ):
+        TaskOutput(
+            output="Test output",
+            source=DataSource(
+                type=DataSourceType.synthetic,
+                properties={
+                    "adapter_name": "TestAdapter",
+                    "model_name": "",
+                    "model_provider": "OpenAI",
+                    "prompt_builder_name": "TestPromptBuilder",
+                },
+            ),
+        )
+
+
+def test_invalid_synthetic_task_output_non_string_values():
+    with pytest.raises(
+        ValidationError, match="'prompt_builder_name' must be of type str"
+    ):
+        DataSource(
+            type=DataSourceType.synthetic,
+            properties={
+                "adapter_name": "TestAdapter",
+                "model_name": "GPT-4",
+                "model_provider": "OpenAI",
+                "prompt_builder_name": 123,
+            },
+        )
+
+
+def test_task_run_validate_repaired_output():
+    # Test case 1: Valid TaskRun with no repaired_output
+    valid_task_run = TaskRun(
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        output=TaskOutput(
+            output="test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+    )
+    assert valid_task_run.repaired_output is None
+
+    # Test case 2: Valid TaskRun with repaired_output and no rating
+    valid_task_run_with_repair = TaskRun(
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        output=TaskOutput(
+            output="test output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+        repair_instructions="Fix the output",
+        repaired_output=TaskOutput(
+            output="repaired output",
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+    )
+    assert valid_task_run_with_repair.repaired_output is not None
+    assert valid_task_run_with_repair.repaired_output.rating is None
+
+    # test missing repair_instructions
+    with pytest.raises(ValidationError) as exc_info:
+        TaskRun(
+            input="test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            output=TaskOutput(
+                output="test output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+            repaired_output=TaskOutput(
+                output="repaired output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+        )
+
+    assert "Repair instructions are required" in str(exc_info.value)
+
+    # test missing repaired_output
+    with pytest.raises(ValidationError) as exc_info:
+        TaskRun(
+            input="test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            output=TaskOutput(
+                output="test output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+            repair_instructions="Fix the output",
+        )
+
+    assert "A repaired output is required" in str(exc_info.value)
+
+    # Test case 3: Invalid TaskRun with repaired_output containing a rating
+    with pytest.raises(ValidationError) as exc_info:
+        TaskRun(
+            input="test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            output=TaskOutput(
+                output="test output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+            repaired_output=TaskOutput(
+                output="repaired output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+                rating=TaskOutputRating(type=TaskOutputRatingType.five_star, value=5.0),
+            ),
+        )
+
+    assert "Repaired output rating must be None" in str(exc_info.value)