kiln-ai 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (58)
  1. kiln_ai/adapters/eval/base_eval.py +7 -2
  2. kiln_ai/adapters/eval/eval_runner.py +5 -64
  3. kiln_ai/adapters/eval/g_eval.py +3 -3
  4. kiln_ai/adapters/fine_tune/base_finetune.py +6 -3
  5. kiln_ai/adapters/fine_tune/dataset_formatter.py +128 -38
  6. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  7. kiln_ai/adapters/fine_tune/fireworks_finetune.py +2 -1
  8. kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -0
  9. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +267 -10
  10. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
  11. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
  12. kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
  13. kiln_ai/adapters/ml_model_list.py +817 -62
  14. kiln_ai/adapters/model_adapters/base_adapter.py +33 -10
  15. kiln_ai/adapters/model_adapters/litellm_adapter.py +51 -12
  16. kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
  17. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
  18. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
  19. kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
  20. kiln_ai/adapters/parsers/base_parser.py +0 -3
  21. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  22. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  23. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  24. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  25. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  26. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  27. kiln_ai/adapters/prompt_builders.py +14 -1
  28. kiln_ai/adapters/provider_tools.py +25 -1
  29. kiln_ai/adapters/repair/test_repair_task.py +3 -2
  30. kiln_ai/adapters/test_prompt_builders.py +24 -3
  31. kiln_ai/adapters/test_provider_tools.py +86 -1
  32. kiln_ai/datamodel/__init__.py +2 -0
  33. kiln_ai/datamodel/datamodel_enums.py +14 -0
  34. kiln_ai/datamodel/dataset_filters.py +69 -1
  35. kiln_ai/datamodel/dataset_split.py +4 -0
  36. kiln_ai/datamodel/eval.py +8 -0
  37. kiln_ai/datamodel/finetune.py +1 -0
  38. kiln_ai/datamodel/json_schema.py +24 -7
  39. kiln_ai/datamodel/prompt_id.py +1 -0
  40. kiln_ai/datamodel/task_output.py +10 -6
  41. kiln_ai/datamodel/task_run.py +68 -12
  42. kiln_ai/datamodel/test_basemodel.py +3 -7
  43. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  44. kiln_ai/datamodel/test_dataset_split.py +2 -0
  45. kiln_ai/datamodel/test_example_models.py +158 -3
  46. kiln_ai/datamodel/test_json_schema.py +22 -3
  47. kiln_ai/datamodel/test_model_perf.py +3 -2
  48. kiln_ai/datamodel/test_models.py +50 -2
  49. kiln_ai/utils/async_job_runner.py +106 -0
  50. kiln_ai/utils/dataset_import.py +80 -18
  51. kiln_ai/utils/test_async_job_runner.py +199 -0
  52. kiln_ai/utils/test_dataset_import.py +242 -10
  53. {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +3 -2
  54. kiln_ai-0.16.0.dist-info/RECORD +108 -0
  55. kiln_ai/adapters/test_generate_docs.py +0 -69
  56. kiln_ai-0.14.0.dist-info/RECORD +0 -103
  57. {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
  58. {kiln_ai-0.14.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/datamodel/task_run.py

@@ -3,11 +3,11 @@ from typing import TYPE_CHECKING, Dict, List, Union
 
 import jsonschema
 import jsonschema.exceptions
-from pydantic import Field, ValidationInfo, model_validator
+from pydantic import BaseModel, Field, ValidationInfo, model_validator
 from typing_extensions import Self
 
 from kiln_ai.datamodel.basemodel import KilnParentedModel
-from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.strict_mode import strict_mode
 from kiln_ai.datamodel.task_output import DataSource, TaskOutput
 
@@ -15,6 +15,29 @@ if TYPE_CHECKING:
     from kiln_ai.datamodel.task import Task
 
 
+class Usage(BaseModel):
+    input_tokens: int | None = Field(
+        default=None,
+        description="The number of input tokens used in the task run.",
+        ge=0,
+    )
+    output_tokens: int | None = Field(
+        default=None,
+        description="The number of output tokens used in the task run.",
+        ge=0,
+    )
+    total_tokens: int | None = Field(
+        default=None,
+        description="The total number of tokens used in the task run.",
+        ge=0,
+    )
+    cost: float | None = Field(
+        default=None,
+        description="The cost of the task run in US dollars, saved at runtime (prices can change over time).",
+        ge=0,
+    )
+
+
 class TaskRun(KilnParentedModel):
     """
     Represents a single execution of a Task.
@@ -47,17 +70,26 @@ class TaskRun(KilnParentedModel):
         default=[],
         description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
     )
+    usage: Usage | None = Field(
+        default=None,
+        description="Usage information for the task run. This includes the number of input tokens, output tokens, and total tokens used.",
+    )
+
+    def thinking_training_data(self) -> str | None:
+        """
+        Get the thinking training data from the task run.
+        """
+        if self.intermediate_outputs is None:
+            return None
+        return self.intermediate_outputs.get(
+            "reasoning"
+        ) or self.intermediate_outputs.get("chain_of_thought")
 
     def has_thinking_training_data(self) -> bool:
         """
         Does this run have thinking data that we can use to train a thinking model?
         """
-        if self.intermediate_outputs is None:
-            return False
-        return (
-            "chain_of_thought" in self.intermediate_outputs
-            or "reasoning" in self.intermediate_outputs
-        )
+        return self.thinking_training_data() is not None
 
     # Workaround to return typed parent without importing Task
     def parent_task(self) -> Union["Task", None]:
@@ -87,14 +119,19 @@ class TaskRun(KilnParentedModel):
            # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
            return self
 
-        # validate output
+        # validate input
         if task.input_json_schema is not None:
             try:
-                validate_schema(json.loads(self.input), task.input_json_schema)
+                input_parsed = json.loads(self.input)
             except json.JSONDecodeError:
                 raise ValueError("Input is not a valid JSON object")
-            except jsonschema.exceptions.ValidationError as e:
-                raise ValueError(f"Input does not match task input schema: {e}")
+
+            validate_schema_with_value_error(
+                input_parsed,
+                task.input_json_schema,
+                "Input does not match task input schema.",
+            )
+
         self._last_validated_input = self.input
         return self
 
@@ -131,6 +168,24 @@ class TaskRun(KilnParentedModel):
             raise ValueError(
                 "Repaired output rating must be None. Repaired outputs are assumed to have a perfect rating, as they have been fixed."
             )
+
+        task = self.parent_task()
+        if (
+            task is not None
+            and self.repaired_output.output is not None
+            and task.output_json_schema is not None
+        ):
+            try:
+                output_parsed = json.loads(self.repaired_output.output)
+            except json.JSONDecodeError:
+                raise ValueError("Repaired output is not a valid JSON object")
+
+            validate_schema_with_value_error(
+                output_parsed,
+                task.output_json_schema,
+                "Repaired output does not match task output schema.",
+            )
+
         if self.repair_instructions is None and self.repaired_output is not None:
             raise ValueError(
                 "Repair instructions are required if providing a repaired output."
@@ -139,6 +194,7 @@ class TaskRun(KilnParentedModel):
             raise ValueError(
                 "A repaired output is required if providing repair instructions."
             )
+
         return self
 
     @model_validator(mode="after")
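
Taken together, these task_run.py changes attach optional token/cost accounting to a run and expose its reasoning trace for training. A minimal usage sketch follows; it mirrors the constructor arguments used in the tests included in this diff, and the literal values are purely illustrative:

    from kiln_ai.datamodel import TaskOutput, TaskRun, Usage

    # A run with a reasoning trace and optional token/cost accounting attached.
    run = TaskRun(
        input="test input",
        output=TaskOutput(output="test output"),
        intermediate_outputs={"reasoning": "step-by-step reasoning"},
        usage=Usage(input_tokens=100, output_tokens=50, total_tokens=150, cost=0.002),
    )

    # "reasoning" takes precedence over "chain_of_thought" when both are present.
    assert run.thinking_training_data() == "step-by-step reasoning"
    assert run.has_thinking_training_data()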

kiln_ai/adapters/model_adapters/test_base_adapter.py

@@ -483,7 +483,7 @@ class MockAdapter(BaseAdapter):
     """Implementation of BaseAdapter for testing"""
 
     async def _run(self, input):
-        return RunOutput(output="test output", intermediate_outputs=None)
+        return RunOutput(output="test output", intermediate_outputs=None), None
 
     def adapter_name(self) -> str:
         return "test"
@@ -510,6 +510,7 @@ async def test_invoke_parsing_flow(adapter):
     # Mock dependencies
     mock_provider = MagicMock()
     mock_provider.parser = "test_parser"
+    mock_provider.formatter = None
     mock_provider.reasoning_capable = False
 
     mock_parser = MagicMock()
@@ -517,13 +518,11 @@ async def test_invoke_parsing_flow(adapter):
         output="parsed test output", intermediate_outputs={"key": "value"}
     )
 
-    mock_parser_class = MagicMock(return_value=mock_parser)
-
     with (
         patch.object(adapter, "model_provider", return_value=mock_provider),
         patch(
            "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
-            return_value=mock_parser_class,
+            return_value=mock_parser,
        ),
        patch("kiln_ai.adapters.model_adapters.base_adapter.Config") as mock_config,
    ):
@@ -534,9 +533,6 @@ async def test_invoke_parsing_flow(adapter):
         # Execute
         result = await adapter.invoke("test input")
 
-        # Verify parser was created correctly
-        mock_parser_class.assert_called_once_with(structured_output=False)
-
         # Verify parsing occurred
         mock_parser.parse_output.assert_called_once()
         parsed_args = mock_parser.parse_output.call_args[1]
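
The MockAdapter change above suggests the adapter contract now returns a (RunOutput, usage) pair from _run rather than a bare RunOutput, with None allowed when a provider reports no usage. A rough sketch of that assumed shape; run_once and mock_adapter are illustrative names, and the real BaseAdapter signature is not shown in this diff:

    import asyncio

    async def run_once(adapter) -> None:
        # Assumed contract after this change: _run() returns the run output plus
        # optional provider-reported usage (None when the provider reports nothing).
        run_output, usage = await adapter._run("test input")
        print(run_output.output)
        if usage is not None:
            print(usage.total_tokens, usage.cost)

    # asyncio.run(run_once(mock_adapter))  # 'mock_adapter' is a hypothetical adapter instance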

kiln_ai/datamodel/test_dataset_filters.py

@@ -1,3 +1,5 @@
+from unittest.mock import Mock
+
 import pytest
 from pydantic import BaseModel
 
@@ -5,12 +7,14 @@ from kiln_ai.datamodel.dataset_filters import (
     AllDatasetFilter,
     DatasetFilterId,
     HighRatingDatasetFilter,
+    MultiDatasetFilter,
     StaticDatasetFilters,
     TagFilter,
     ThinkingModelDatasetFilter,
     ThinkingModelHighRatedFilter,
     dataset_filter_from_id,
 )
+from kiln_ai.datamodel.task_run import TaskRun
 
 # Note: Many more filter tests in test_dataset_split.py
 
@@ -69,3 +73,81 @@ def test_tag_filter(tag, expected_error, expected_tag):
         filter = dataset_filter_from_id(tag)
         assert isinstance(filter, TagFilter)
         assert filter.tag == expected_tag
+
+
+class TestMultiDatasetFilter:
+    @pytest.mark.parametrize(
+        "filter_string,expected_filters",
+        [
+            ("multi_filter::high_rating", ["high_rating"]),
+            (
+                "multi_filter::high_rating&thinking_model",
+                ["high_rating", "thinking_model"],
+            ),
+            ("multi_filter::tag::test&high_rating", ["tag::test", "high_rating"]),
+            (
+                "multi_filter::high_rating&tag::tag\\&name",
+                ["high_rating", "tag::tag&name"],
+            ),
+        ],
+    )
+    def test_valid_filter_string_parsing(self, filter_string, expected_filters):
+        """Test that valid filter strings are parsed correctly."""
+        assert MultiDatasetFilter.parse_filter_string(filter_string) == expected_filters
+        assert MultiDatasetFilter.is_valid_filter_string(filter_string)
+
+    @pytest.mark.parametrize(
+        "filter_string,expected_error",
+        [
+            (
+                "not_multi_filter::high_rating",
+                "Filter string must start with multi_filter::",
+            ),
+            ("multi_filter::", "No filters specified after prefix"),
+            ("multi_filter::high_rating&", "Invalid dataset filter ID:"),
+            ("multi_filter::invalid_filter", "Invalid dataset filter ID:"),
+        ],
+    )
+    def test_invalid_filter_string_handling(self, filter_string, expected_error):
+        """Test that invalid filter strings raise appropriate errors."""
+        with pytest.raises(ValueError, match=expected_error):
+            MultiDatasetFilter.parse_filter_string(filter_string)
+        assert not MultiDatasetFilter.is_valid_filter_string(filter_string)
+
+    def test_filter_combination_logic(self):
+        """Test that multiple filters are combined with AND logic."""
+        # Create a mock task run
+        task_run = Mock(spec=TaskRun)
+        task_run.output = Mock()
+        task_run.output.rating = Mock()
+        task_run.output.rating.is_high_quality.return_value = True
+        task_run.tags = ["test_tag"]
+        task_run.has_thinking_training_data.return_value = True
+        task_run.repaired_output = None
+
+        # Test combining high_rating and tag filters
+        filter_id = "multi_filter::high_rating&tag::test_tag"
+        multi_filter = dataset_filter_from_id(filter_id)
+        assert multi_filter(task_run)
+
+        # Test that it fails if one filter fails
+        task_run.tags = ["wrong_tag"]
+        assert not multi_filter(task_run)
+        task_run.tags = ["test_tag"]
+        assert multi_filter(task_run)
+        task_run.output.rating.is_high_quality.return_value = False
+        assert not multi_filter(task_run)
+
+        # Verify the mock was called as expected
+        task_run.output.rating.is_high_quality.assert_called()
+
+    def test_filter_creation_from_id(self):
+        """Test that multi filters can be created via dataset_filter_from_id."""
+        filter_id = "multi_filter::high_rating&thinking_model"
+        filter = dataset_filter_from_id(filter_id)
+        assert isinstance(filter, MultiDatasetFilter)
+        assert len(filter.filters) == 2
+        assert any(isinstance(f, type(HighRatingDatasetFilter)) for f in filter.filters)
+        assert any(
+            isinstance(f, type(ThinkingModelDatasetFilter)) for f in filter.filters
+        )

kiln_ai/datamodel/test_dataset_split.py

@@ -17,6 +17,7 @@ from kiln_ai.datamodel.dataset_split import (
     AllSplitDefinition,
     Train60Test20Val20SplitDefinition,
     Train80Test20SplitDefinition,
+    Train80Val20SplitDefinition,
 )
 from kiln_ai.datamodel.test_dataset_filters import (
     AllDatasetFilter,
@@ -174,6 +175,7 @@ def test_high_rating_dataset_filter(sample_task_runs):
     [
         (Train80Test20SplitDefinition, {"train": 8, "test": 2}),
         (AllSplitDefinition, {"all": 10}),
+        (Train80Val20SplitDefinition, {"train": 8, "val": 2}),
         (Train60Test20Val20SplitDefinition, {"train": 6, "test": 2, "val": 2}),
         (
             [

kiln_ai/datamodel/test_example_models.py

@@ -16,6 +16,7 @@ from kiln_ai.datamodel import (
     TaskOutputRatingType,
     TaskRequirement,
     TaskRun,
+    Usage,
 )
 
 
@@ -358,6 +359,9 @@ def test_task_output_schema_validation(tmp_path):
         task_output.save_to_file()
 
 
+_input_schema_match = "Input does not match task input schema"
+
+
 def test_task_input_schema_validation(tmp_path):
     # Create a project and task hierarchy
     project = Project(name="Test Project", path=(tmp_path / "test_project"))
@@ -395,18 +399,18 @@ def test_task_input_schema_validation(tmp_path):
     valid_task_output.save_to_file()
 
     # Changing to invalid input
-    with pytest.raises(ValueError, match=_schema_match):
+    with pytest.raises(ValueError, match=_input_schema_match):
         valid_task_output.input = '{"name": "John Doe", "age": "thirty"}'
         valid_task_output.save_to_file()
 
     # loading from file, then changing to invalid input
     loaded_task_output = TaskRun.load_from_file(valid_task_output.path)
-    with pytest.raises(ValueError, match=_schema_match):
+    with pytest.raises(ValueError, match=_input_schema_match):
         loaded_task_output.input = '{"name": "John Doe", "age": "thirty"}'
         loaded_task_output.save_to_file()
 
     # Invalid case: input does not match task input schema
-    with pytest.raises(ValueError, match=_schema_match):
+    with pytest.raises(ValueError, match=_input_schema_match):
         task_output = TaskRun(
             input='{"name": "John Doe", "age": "thirty"}',
             input_source=DataSource(
@@ -642,3 +646,154 @@ def test_task_run_validate_repaired_output():
     )
 
     assert "Repaired output rating must be None" in str(exc_info.value)
+
+
+def test_task_run_validate_repaired_output_structured(tmp_path):
+    # Create a project, task, and example hierarchy
+    project = Project(name="Test Project", path=(tmp_path / "test_project"))
+    project.save_to_file()
+    task = Task(
+        name="Test Task",
+        instruction="test instruction",
+        parent=project,
+        output_json_schema=json.dumps(
+            {
+                "type": "object",
+                "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+                "required": ["name", "age"],
+            }
+        ),
+    )
+    task.save_to_file()
+
+    # test valid repaired output schema
+    task_run = TaskRun(
+        parent=task,
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "john_doe"},
+        ),
+        output=TaskOutput(
+            output='{"name": "John Doe", "age": 30}',
+            source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+        ),
+        repair_instructions="Fix the output",
+        repaired_output=TaskOutput(
+            output='{"name": "John Doe", "age": 30}',
+            source=DataSource(
+                type=DataSourceType.human, properties={"created_by": "john_doe"}
+            ),
+        ),
+    )
+
+    assert task_run.repaired_output is not None
+    assert task_run.repaired_output.rating is None
+
+    # test invalid JSON
+    with pytest.raises(ValueError):
+        TaskRun(
+            parent=task,
+            input="test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            output=TaskOutput(
+                output='{"name": "John Doe", "age": 30}',
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+            repair_instructions="Fix the output",
+            repaired_output=TaskOutput(
+                output='{"name": "John Doe", "age": 30',  # missing closing brace
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+        )
+
+    # test invalid repaired output schema
+    with pytest.raises(ValueError):
+        TaskRun(
+            parent=task,
+            input="test input",
+            input_source=DataSource(
+                type=DataSourceType.human,
+                properties={"created_by": "john_doe"},
+            ),
+            output=TaskOutput(
+                output='{"name": "John Doe", "age": 30}',
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+            repair_instructions="Fix the output",
+            repaired_output=TaskOutput(
+                output='{"name": "John Doe", "age": "thirty"}',  # invalid schema
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "john_doe"},
+                ),
+            ),
+        )
+
+
+@pytest.mark.parametrize(
+    "input_tokens,output_tokens,total_tokens,cost,should_raise",
+    [
+        # Valid cases
+        (100, 50, 150, 0.002, False),  # All fields
+        (None, None, None, None, False),  # All None (defaults)
+        # Invalid cases
+        (-100, 50, 150, 0.002, True),  # Negative input_tokens
+        (100, -50, 150, 0.002, True),  # Negative output_tokens
+        (100, 50, -150, 0.002, True),  # Negative total_tokens
+        (100, 50, 150, -0.002, True),  # Negative cost
+    ],
+)
+def test_usage_model(input_tokens, output_tokens, total_tokens, cost, should_raise):
+    """Test the Usage model with various input combinations."""
+    if should_raise:
+        with pytest.raises(ValidationError):
+            Usage(
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                total_tokens=total_tokens,
+                cost=cost,
+            )
+    else:
+        usage = Usage(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=total_tokens,
+            cost=cost,
+        )
+        assert usage.input_tokens == input_tokens
+        assert usage.output_tokens == output_tokens
+        assert usage.total_tokens == total_tokens
+        assert usage.cost == cost
+
+
+def test_usage_model_in_task_run(valid_task_run):
+    """Test that Usage can be properly set in a TaskRun."""
+    usage = Usage(
+        input_tokens=100,
+        output_tokens=50,
+        total_tokens=150,
+        cost=0.002,
+    )
+    task_run = valid_task_run.model_copy(deep=True)
+    task_run.usage = usage
+    assert task_run.usage == usage
+    assert task_run.usage.input_tokens == 100
+    assert task_run.usage.output_tokens == 50
+    assert task_run.usage.total_tokens == 150
+    assert task_run.usage.cost == 0.002

kiln_ai/datamodel/test_json_schema.py

@@ -1,3 +1,4 @@
+import jsonschema
 import pytest
 from pydantic import BaseModel
 
@@ -6,6 +7,7 @@ from kiln_ai.datamodel.json_schema import (
     schema_from_json_str,
     string_to_json_key,
     validate_schema,
+    validate_schema_with_value_error,
 )
 
 
@@ -71,15 +73,32 @@ def test_validate_schema_content():
     o = {"setup": "asdf", "punchline": "asdf", "rating": 1}
     validate_schema(o, json_joke_schema)
     o = {"setup": "asdf"}
-    with pytest.raises(Exception):
+    with pytest.raises(jsonschema.exceptions.ValidationError):
         validate_schema(0, json_joke_schema)
     o = {"setup": "asdf", "punchline": "asdf"}
     validate_schema(o, json_joke_schema)
     o = {"setup": "asdf", "punchline": "asdf", "rating": "1"}
-    with pytest.raises(Exception):
+    with pytest.raises(jsonschema.exceptions.ValidationError):
         validate_schema(o, json_joke_schema)
 
 
+def test_validate_schema_content_with_value_error():
+    o = {"setup": "asdf", "punchline": "asdf", "rating": 1}
+    validate_schema_with_value_error(o, json_joke_schema, "PREFIX")
+    o = {"setup": "asdf"}
+    with pytest.raises(
+        ValueError, match="PREFIX The error from the schema check was: "
+    ):
+        validate_schema_with_value_error(0, json_joke_schema, "PREFIX")
+    o = {"setup": "asdf", "punchline": "asdf"}
+    validate_schema_with_value_error(o, json_joke_schema, "PREFIX")
+    o = {"setup": "asdf", "punchline": "asdf", "rating": "1"}
+    with pytest.raises(
+        ValueError, match="PREFIX The error from the schema check was: "
+    ):
+        validate_schema_with_value_error(o, json_joke_schema, "PREFIX")
+
+
 json_triangle_schema = """{
     "type": "object",
     "properties": {
@@ -122,7 +141,7 @@ def test_triangle_schema():
     assert schema["properties"]["c"]["type"] == "integer"
     assert schema["required"] == ["a", "b", "c"]
     validate_schema({"a": 1, "b": 2, "c": 3}, json_triangle_schema)
-    with pytest.raises(Exception):
+    with pytest.raises(jsonschema.exceptions.ValidationError):
         validate_schema({"a": 1, "b": 2, "c": "3"}, json_triangle_schema)
 
 
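
Per the tests above, validate_schema_with_value_error runs the same jsonschema check as validate_schema but re-raises failures as ValueError, placing a caller-supplied prefix before "The error from the schema check was: ...". A minimal sketch of the call pattern; the schema and data here are illustrative:

    import json

    from kiln_ai.datamodel.json_schema import validate_schema_with_value_error

    schema = json.dumps(
        {
            "type": "object",
            "properties": {"setup": {"type": "string"}, "rating": {"type": "integer"}},
            "required": ["setup"],
        }
    )

    try:
        # "rating" is a string here, so the schema check fails and a ValueError is raised.
        validate_schema_with_value_error({"setup": "a pun", "rating": "1"}, schema, "Output invalid.")
    except ValueError as e:
        print(e)  # e.g. "Output invalid. The error from the schema check was: ..."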

kiln_ai/datamodel/test_model_perf.py

@@ -119,7 +119,8 @@ def test_benchmark_load_from_file(benchmark, task_run):
     avg_time_per_iteration = total_time / iterations
     ops_per_second = 1.0 / avg_time_per_iteration
 
-    # I get 8k ops per second on my MBP. Lower value here for CI.
+    # I get 8k ops per second on my MBP. Lower value here for CI and parallel testing.
     # Prior to optimization was 290 ops per second.
-    if ops_per_second < 1000:
+    print(f"Ops per second: {ops_per_second:.6f}")
+    if ops_per_second < 500:
         pytest.fail(f"Ops per second: {ops_per_second:.6f}, expected more than 1k ops")

kiln_ai/datamodel/test_models.py

@@ -547,20 +547,34 @@ def test_prompt_parent_task():
             False,
             None,
         ),
-        # Test 3: Invalid case - thinking instructions with final_only
+        # Test 3: Valid case - no thinking instructions with final_and_intermediate_r1_compatible
+        (
+            None,
+            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            False,
+            None,
+        ),
+        # Test 4: Invalid case - thinking instructions with final_only
         (
             "Think step by step",
             FinetuneDataStrategy.final_only,
             True,
             "Thinking instructions can only be used when data_strategy is final_and_intermediate",
         ),
-        # Test 4: Invalid case - no thinking instructions with final_and_intermediate
+        # Test 5: Invalid case - no thinking instructions with final_and_intermediate
         (
             None,
             FinetuneDataStrategy.final_and_intermediate,
             True,
             "Thinking instructions are required when data_strategy is final_and_intermediate",
         ),
+        # Test 6: Invalid case - thinking instructions with final_and_intermediate_r1_compatible
+        (
+            "Think step by step",
+            FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+            True,
+            "Thinking instructions can only be used when data_strategy is final_and_intermediate",
+        ),
     ],
 )
 def test_finetune_thinking_instructions_validation(
@@ -617,3 +631,37 @@ def test_task_run_has_thinking_training_data(intermediate_outputs, expected):
         intermediate_outputs=intermediate_outputs,
     )
     assert task_run.has_thinking_training_data() == expected
+
+
+@pytest.mark.parametrize(
+    "intermediate_outputs,expected",
+    [
+        # No intermediate outputs
+        (None, None),
+        # Empty intermediate outputs
+        ({}, None),
+        # Only chain_of_thought
+        ({"chain_of_thought": "thinking process"}, "thinking process"),
+        # Only reasoning
+        ({"reasoning": "reasoning process"}, "reasoning process"),
+        # Both chain_of_thought and reasoning (should return reasoning as it's checked first)
+        (
+            {"chain_of_thought": "thinking process", "reasoning": "reasoning process"},
+            "reasoning process",
+        ),
+        # Other intermediate outputs but no thinking data
+        ({"other_output": "some data"}, None),
+        # Mixed other outputs with thinking data
+        (
+            {"chain_of_thought": "thinking process", "other_output": "some data"},
+            "thinking process",
+        ),
+    ],
+)
+def test_task_run_thinking_training_data(intermediate_outputs, expected):
+    task_run = TaskRun(
+        input="test input",
+        output=TaskOutput(output="test output"),
+        intermediate_outputs=intermediate_outputs,
+    )
+    assert task_run.thinking_training_data() == expected