kiln-ai 0.15.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (72)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +22 -44
  3. kiln_ai/adapters/chat/__init__.py +8 -0
  4. kiln_ai/adapters/chat/chat_formatter.py +234 -0
  5. kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
  6. kiln_ai/adapters/data_gen/test_data_gen_task.py +19 -6
  7. kiln_ai/adapters/eval/base_eval.py +8 -6
  8. kiln_ai/adapters/eval/eval_runner.py +9 -65
  9. kiln_ai/adapters/eval/g_eval.py +26 -8
  10. kiln_ai/adapters/eval/test_base_eval.py +166 -15
  11. kiln_ai/adapters/eval/test_eval_runner.py +3 -0
  12. kiln_ai/adapters/eval/test_g_eval.py +1 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +2 -2
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +153 -197
  15. kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
  16. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +402 -211
  17. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
  18. kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
  19. kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
  20. kiln_ai/adapters/fine_tune/test_vertex_finetune.py +4 -4
  21. kiln_ai/adapters/fine_tune/together_finetune.py +12 -1
  22. kiln_ai/adapters/ml_model_list.py +556 -45
  23. kiln_ai/adapters/model_adapters/base_adapter.py +100 -35
  24. kiln_ai/adapters/model_adapters/litellm_adapter.py +116 -100
  25. kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
  26. kiln_ai/adapters/model_adapters/test_base_adapter.py +299 -52
  27. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +121 -22
  28. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +44 -2
  29. kiln_ai/adapters/model_adapters/test_structured_output.py +48 -18
  30. kiln_ai/adapters/parsers/base_parser.py +0 -3
  31. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  32. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  33. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  34. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  35. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  36. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  37. kiln_ai/adapters/prompt_builders.py +14 -17
  38. kiln_ai/adapters/provider_tools.py +39 -4
  39. kiln_ai/adapters/repair/test_repair_task.py +27 -5
  40. kiln_ai/adapters/test_adapter_registry.py +88 -28
  41. kiln_ai/adapters/test_ml_model_list.py +158 -0
  42. kiln_ai/adapters/test_prompt_adaptors.py +17 -3
  43. kiln_ai/adapters/test_prompt_builders.py +27 -19
  44. kiln_ai/adapters/test_provider_tools.py +130 -12
  45. kiln_ai/datamodel/__init__.py +2 -2
  46. kiln_ai/datamodel/datamodel_enums.py +43 -4
  47. kiln_ai/datamodel/dataset_filters.py +69 -1
  48. kiln_ai/datamodel/dataset_split.py +4 -0
  49. kiln_ai/datamodel/eval.py +8 -0
  50. kiln_ai/datamodel/finetune.py +13 -7
  51. kiln_ai/datamodel/prompt_id.py +1 -0
  52. kiln_ai/datamodel/task.py +68 -7
  53. kiln_ai/datamodel/task_output.py +1 -1
  54. kiln_ai/datamodel/task_run.py +39 -7
  55. kiln_ai/datamodel/test_basemodel.py +5 -8
  56. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  57. kiln_ai/datamodel/test_dataset_split.py +2 -8
  58. kiln_ai/datamodel/test_example_models.py +54 -0
  59. kiln_ai/datamodel/test_models.py +80 -9
  60. kiln_ai/datamodel/test_task.py +168 -2
  61. kiln_ai/utils/async_job_runner.py +106 -0
  62. kiln_ai/utils/config.py +3 -2
  63. kiln_ai/utils/dataset_import.py +81 -19
  64. kiln_ai/utils/logging.py +165 -0
  65. kiln_ai/utils/test_async_job_runner.py +199 -0
  66. kiln_ai/utils/test_config.py +23 -0
  67. kiln_ai/utils/test_dataset_import.py +272 -10
  68. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/METADATA +1 -1
  69. kiln_ai-0.17.0.dist-info/RECORD +113 -0
  70. kiln_ai-0.15.0.dist-info/RECORD +0 -104
  71. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/WHEEL +0 -0
  72. {kiln_ai-0.15.0.dist-info → kiln_ai-0.17.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/task_run.py

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict, List, Union
 
 import jsonschema
 import jsonschema.exceptions
-from pydantic import Field, ValidationInfo, model_validator
+from pydantic import BaseModel, Field, ValidationInfo, model_validator
 from typing_extensions import Self
 
 from kiln_ai.datamodel.basemodel import KilnParentedModel
@@ -15,6 +15,29 @@ if TYPE_CHECKING:
     from kiln_ai.datamodel.task import Task
 
 
+class Usage(BaseModel):
+    input_tokens: int | None = Field(
+        default=None,
+        description="The number of input tokens used in the task run.",
+        ge=0,
+    )
+    output_tokens: int | None = Field(
+        default=None,
+        description="The number of output tokens used in the task run.",
+        ge=0,
+    )
+    total_tokens: int | None = Field(
+        default=None,
+        description="The total number of tokens used in the task run.",
+        ge=0,
+    )
+    cost: float | None = Field(
+        default=None,
+        description="The cost of the task run in US dollars, saved at runtime (prices can change over time).",
+        ge=0,
+    )
+
+
 class TaskRun(KilnParentedModel):
     """
     Represents a single execution of a Task.
@@ -47,17 +70,26 @@ class TaskRun(KilnParentedModel):
         default=[],
         description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
     )
+    usage: Usage | None = Field(
+        default=None,
+        description="Usage information for the task run. This includes the number of input tokens, output tokens, and total tokens used.",
+    )
+
+    def thinking_training_data(self) -> str | None:
+        """
+        Get the thinking training data from the task run.
+        """
+        if self.intermediate_outputs is None:
+            return None
+        return self.intermediate_outputs.get(
+            "reasoning"
+        ) or self.intermediate_outputs.get("chain_of_thought")
 
     def has_thinking_training_data(self) -> bool:
         """
         Does this run have thinking data that we can use to train a thinking model?
         """
-        if self.intermediate_outputs is None:
-            return False
-        return (
-            "chain_of_thought" in self.intermediate_outputs
-            or "reasoning" in self.intermediate_outputs
-        )
+        return self.thinking_training_data() is not None
 
     # Workaround to return typed parent without importing Task
     def parent_task(self) -> Union["Task", None]:
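For orientation, a minimal usage sketch (not part of the diff) of the new Usage model and thinking_training_data() helper added above; the constructor arguments mirror those used in the package's own tests:

from kiln_ai.datamodel import TaskOutput, TaskRun, Usage

run = TaskRun(
    input="test input",
    output=TaskOutput(output="test output"),
    intermediate_outputs={"reasoning": "step-by-step reasoning"},
    usage=Usage(input_tokens=100, output_tokens=50, total_tokens=150, cost=0.002),
)

# "reasoning" takes precedence over "chain_of_thought" when both are present.
assert run.thinking_training_data() == "step-by-step reasoning"
assert run.has_thinking_training_data()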
kiln_ai/adapters/model_adapters/test_base_adapter.py

@@ -483,7 +483,7 @@ class MockAdapter(BaseAdapter):
     """Implementation of BaseAdapter for testing"""
 
     async def _run(self, input):
-        return RunOutput(output="test output", intermediate_outputs=None)
+        return RunOutput(output="test output", intermediate_outputs=None), None
 
     def adapter_name(self) -> str:
         return "test"
@@ -500,8 +500,9 @@ def adapter(base_task):
         run_config=RunConfig(
             task=base_task,
             model_name="test_model",
-            model_provider_name="test_provider",
+            model_provider_name="openai",
             prompt_id="simple_prompt_builder",
+            structured_output_mode="json_schema",
         ),
     )
 
@@ -510,6 +511,7 @@ async def test_invoke_parsing_flow(adapter):
     # Mock dependencies
     mock_provider = MagicMock()
     mock_provider.parser = "test_parser"
+    mock_provider.formatter = None
     mock_provider.reasoning_capable = False
 
     mock_parser = MagicMock()
@@ -517,13 +519,11 @@ async def test_invoke_parsing_flow(adapter):
         output="parsed test output", intermediate_outputs={"key": "value"}
     )
 
-    mock_parser_class = MagicMock(return_value=mock_parser)
-
     with (
        patch.object(adapter, "model_provider", return_value=mock_provider),
        patch(
            "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
-            return_value=mock_parser_class,
+            return_value=mock_parser,
        ),
        patch("kiln_ai.adapters.model_adapters.base_adapter.Config") as mock_config,
    ):
@@ -534,9 +534,6 @@ async def test_invoke_parsing_flow(adapter):
        # Execute
        result = await adapter.invoke("test input")
 
-        # Verify parser was created correctly
-        mock_parser_class.assert_called_once_with(structured_output=False)
-
        # Verify parsing occurred
        mock_parser.parse_output.assert_called_once()
        parsed_args = mock_parser.parse_output.call_args[1]
kiln_ai/datamodel/test_dataset_filters.py

@@ -1,3 +1,5 @@
+from unittest.mock import Mock
+
 import pytest
 from pydantic import BaseModel
 
@@ -5,12 +7,14 @@ from kiln_ai.datamodel.dataset_filters import (
     AllDatasetFilter,
     DatasetFilterId,
     HighRatingDatasetFilter,
+    MultiDatasetFilter,
     StaticDatasetFilters,
     TagFilter,
     ThinkingModelDatasetFilter,
     ThinkingModelHighRatedFilter,
     dataset_filter_from_id,
 )
+from kiln_ai.datamodel.task_run import TaskRun
 
 # Note: Many more filter tests in test_dataset_split.py
 
@@ -69,3 +73,81 @@ def test_tag_filter(tag, expected_error, expected_tag):
     filter = dataset_filter_from_id(tag)
     assert isinstance(filter, TagFilter)
     assert filter.tag == expected_tag
+
+
+class TestMultiDatasetFilter:
+    @pytest.mark.parametrize(
+        "filter_string,expected_filters",
+        [
+            ("multi_filter::high_rating", ["high_rating"]),
+            (
+                "multi_filter::high_rating&thinking_model",
+                ["high_rating", "thinking_model"],
+            ),
+            ("multi_filter::tag::test&high_rating", ["tag::test", "high_rating"]),
+            (
+                "multi_filter::high_rating&tag::tag\\&name",
+                ["high_rating", "tag::tag&name"],
+            ),
+        ],
+    )
+    def test_valid_filter_string_parsing(self, filter_string, expected_filters):
+        """Test that valid filter strings are parsed correctly."""
+        assert MultiDatasetFilter.parse_filter_string(filter_string) == expected_filters
+        assert MultiDatasetFilter.is_valid_filter_string(filter_string)
+
+    @pytest.mark.parametrize(
+        "filter_string,expected_error",
+        [
+            (
+                "not_multi_filter::high_rating",
+                "Filter string must start with multi_filter::",
+            ),
+            ("multi_filter::", "No filters specified after prefix"),
+            ("multi_filter::high_rating&", "Invalid dataset filter ID:"),
+            ("multi_filter::invalid_filter", "Invalid dataset filter ID:"),
+        ],
+    )
+    def test_invalid_filter_string_handling(self, filter_string, expected_error):
+        """Test that invalid filter strings raise appropriate errors."""
+        with pytest.raises(ValueError, match=expected_error):
+            MultiDatasetFilter.parse_filter_string(filter_string)
+        assert not MultiDatasetFilter.is_valid_filter_string(filter_string)
+
+    def test_filter_combination_logic(self):
+        """Test that multiple filters are combined with AND logic."""
+        # Create a mock task run
+        task_run = Mock(spec=TaskRun)
+        task_run.output = Mock()
+        task_run.output.rating = Mock()
+        task_run.output.rating.is_high_quality.return_value = True
+        task_run.tags = ["test_tag"]
+        task_run.has_thinking_training_data.return_value = True
+        task_run.repaired_output = None
+
+        # Test combining high_rating and tag filters
+        filter_id = "multi_filter::high_rating&tag::test_tag"
+        multi_filter = dataset_filter_from_id(filter_id)
+        assert multi_filter(task_run)
+
+        # Test that it fails if one filter fails
+        task_run.tags = ["wrong_tag"]
+        assert not multi_filter(task_run)
+        task_run.tags = ["test_tag"]
+        assert multi_filter(task_run)
+        task_run.output.rating.is_high_quality.return_value = False
+        assert not multi_filter(task_run)
+
+        # Verify the mock was called as expected
+        task_run.output.rating.is_high_quality.assert_called()
+
+    def test_filter_creation_from_id(self):
+        """Test that multi filters can be created via dataset_filter_from_id."""
+        filter_id = "multi_filter::high_rating&thinking_model"
+        filter = dataset_filter_from_id(filter_id)
+        assert isinstance(filter, MultiDatasetFilter)
+        assert len(filter.filters) == 2
+        assert any(isinstance(f, type(HighRatingDatasetFilter)) for f in filter.filters)
+        assert any(
+            isinstance(f, type(ThinkingModelDatasetFilter)) for f in filter.filters
+        )
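As a rough illustration (an assumption based on the tests above, not code from the diff): the new multi_filter dataset filter ID joins existing filter IDs with "&", escapes a literal "&" as "\&", and keeps only runs that pass every sub-filter. The task_runs variable below is hypothetical:

from kiln_ai.datamodel.dataset_filters import dataset_filter_from_id

# Combine the built-in high_rating filter with a tag filter; both must match.
combined = dataset_filter_from_id("multi_filter::high_rating&tag::golden_set")
kept = [run for run in task_runs if combined(run)]  # task_runs: a hypothetical list of TaskRun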
kiln_ai/datamodel/test_dataset_split.py

@@ -17,6 +17,7 @@ from kiln_ai.datamodel.dataset_split import (
     AllSplitDefinition,
     Train60Test20Val20SplitDefinition,
     Train80Test20SplitDefinition,
+    Train80Val20SplitDefinition,
 )
 from kiln_ai.datamodel.test_dataset_filters import (
     AllDatasetFilter,
@@ -71,14 +72,6 @@ def sample_task_runs(sample_task):
     return task_runs
 
 
-@pytest.fixture
-def standard_splits():
-    return [
-        DatasetSplitDefinition(name="train", percentage=0.8),
-        DatasetSplitDefinition(name="test", percentage=0.2),
-    ]
-
-
 @pytest.fixture
 def task_run():
     return TaskRun(
@@ -174,6 +167,7 @@ def test_high_rating_dataset_filter(sample_task_runs):
     [
         (Train80Test20SplitDefinition, {"train": 8, "test": 2}),
         (AllSplitDefinition, {"all": 10}),
+        (Train80Val20SplitDefinition, {"train": 8, "val": 2}),
         (Train60Test20Val20SplitDefinition, {"train": 6, "test": 2, "val": 2}),
         (
             [
kiln_ai/datamodel/test_example_models.py

@@ -16,6 +16,7 @@ from kiln_ai.datamodel import (
     TaskOutputRatingType,
     TaskRequirement,
     TaskRun,
+    Usage,
 )
 
 
@@ -743,3 +744,56 @@ def test_task_run_validate_repaired_output_structured(tmp_path):
             ),
         ),
     )
+
+
+@pytest.mark.parametrize(
+    "input_tokens,output_tokens,total_tokens,cost,should_raise",
+    [
+        # Valid cases
+        (100, 50, 150, 0.002, False),  # All fields
+        (None, None, None, None, False),  # All None (defaults)
+        # Invalid cases
+        (-100, 50, 150, 0.002, True),  # Negative input_tokens
+        (100, -50, 150, 0.002, True),  # Negative output_tokens
+        (100, 50, -150, 0.002, True),  # Negative total_tokens
+        (100, 50, 150, -0.002, True),  # Negative cost
+    ],
+)
+def test_usage_model(input_tokens, output_tokens, total_tokens, cost, should_raise):
+    """Test the Usage model with various input combinations."""
+    if should_raise:
+        with pytest.raises(ValidationError):
+            Usage(
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                total_tokens=total_tokens,
+                cost=cost,
+            )
+    else:
+        usage = Usage(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=total_tokens,
+            cost=cost,
+        )
+        assert usage.input_tokens == input_tokens
+        assert usage.output_tokens == output_tokens
+        assert usage.total_tokens == total_tokens
+        assert usage.cost == cost
+
+
+def test_usage_model_in_task_run(valid_task_run):
+    """Test that Usage can be properly set in a TaskRun."""
+    usage = Usage(
+        input_tokens=100,
+        output_tokens=50,
+        total_tokens=150,
+        cost=0.002,
+    )
+    task_run = valid_task_run.model_copy(deep=True)
+    task_run.usage = usage
+    assert task_run.usage == usage
+    assert task_run.usage.input_tokens == 100
+    assert task_run.usage.output_tokens == 50
+    assert task_run.usage.total_tokens == 150
+    assert task_run.usage.cost == 0.002
kiln_ai/datamodel/test_models.py

@@ -9,13 +9,13 @@ from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,
     Finetune,
-    FinetuneDataStrategy,
     Project,
     Prompt,
     Task,
     TaskOutput,
     TaskRun,
 )
+from kiln_ai.datamodel.datamodel_enums import ChatStrategy
 from kiln_ai.datamodel.test_json_schema import json_joke_schema
 
 
@@ -536,30 +536,58 @@ def test_prompt_parent_task():
         # Test 1: Valid case - no thinking instructions with final_only
         (
             None,
-            FinetuneDataStrategy.final_only,
+            ChatStrategy.single_turn,
             False,
             None,
         ),
         # Test 2: Valid case - thinking instructions with final_and_intermediate
         (
             "Think step by step",
-            FinetuneDataStrategy.final_and_intermediate,
+            ChatStrategy.two_message_cot_legacy,
             False,
             None,
         ),
-        # Test 3: Invalid case - thinking instructions with final_only
+        # Test 3: Valid case - no thinking instructions with final_and_intermediate_r1_compatible
+        (
+            None,
+            ChatStrategy.single_turn_r1_thinking,
+            False,
+            None,
+        ),
+        # Test 4: Invalid case - thinking instructions with final_only
+        (
+            "Think step by step",
+            ChatStrategy.single_turn,
+            True,
+            "Thinking instructions can only be used when data_strategy is",
+        ),
+        # Test 5: Invalid case - no thinking instructions with final_and_intermediate
+        (
+            None,
+            ChatStrategy.two_message_cot_legacy,
+            True,
+            "Thinking instructions are required when data_strategy is",
+        ),
+        # Test 6: Invalid case - thinking instructions with final_and_intermediate_r1_compatible
         (
             "Think step by step",
-            FinetuneDataStrategy.final_only,
+            ChatStrategy.single_turn_r1_thinking,
             True,
-            "Thinking instructions can only be used when data_strategy is final_and_intermediate",
+            "Thinking instructions can only be used when data_strategy is",
         ),
-        # Test 4: Invalid case - no thinking instructions with final_and_intermediate
+        # Test 7: new COT format
         (
+            "Think step by step",
+            ChatStrategy.two_message_cot,
+            False,
             None,
-            FinetuneDataStrategy.final_and_intermediate,
+        ),
+        # Test 8: new COT format
+        (
+            None,
+            ChatStrategy.two_message_cot,
             True,
-            "Thinking instructions are required when data_strategy is final_and_intermediate",
+            "Thinking instructions are required when data_strategy is",
         ),
     ],
 )
@@ -617,3 +645,46 @@ def test_task_run_has_thinking_training_data(intermediate_outputs, expected):
         intermediate_outputs=intermediate_outputs,
     )
     assert task_run.has_thinking_training_data() == expected
+
+
+@pytest.mark.parametrize(
+    "intermediate_outputs,expected",
+    [
+        # No intermediate outputs
+        (None, None),
+        # Empty intermediate outputs
+        ({}, None),
+        # Only chain_of_thought
+        ({"chain_of_thought": "thinking process"}, "thinking process"),
+        # Only reasoning
+        ({"reasoning": "reasoning process"}, "reasoning process"),
+        # Both chain_of_thought and reasoning (should return reasoning as it's checked first)
+        (
+            {"chain_of_thought": "thinking process", "reasoning": "reasoning process"},
+            "reasoning process",
+        ),
+        # Other intermediate outputs but no thinking data
+        ({"other_output": "some data"}, None),
+        # Mixed other outputs with thinking data
+        (
+            {"chain_of_thought": "thinking process", "other_output": "some data"},
+            "thinking process",
+        ),
+    ],
+)
+def test_task_run_thinking_training_data(intermediate_outputs, expected):
+    task_run = TaskRun(
+        input="test input",
+        output=TaskOutput(output="test output"),
+        intermediate_outputs=intermediate_outputs,
+    )
+    assert task_run.thinking_training_data() == expected
+
+
+def test_chat_strategy_enum():
+    # This has to align to the old FinetuneDataStrategy enum
+    assert ChatStrategy.single_turn == "final_only"
+    assert ChatStrategy.two_message_cot_legacy == "final_and_intermediate"
+    assert (
+        ChatStrategy.single_turn_r1_thinking == "final_and_intermediate_r1_compatible"
+    )
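The enum assertions above pin the rename from FinetuneDataStrategy to ChatStrategy to the legacy serialized values; restated as a quick reference (same values as the test, no new behavior assumed):

from kiln_ai.datamodel.datamodel_enums import ChatStrategy

# Members keep the old FinetuneDataStrategy string values, so previously
# serialized data strategies still map onto the new enum names.
assert ChatStrategy.single_turn == "final_only"
assert ChatStrategy.two_message_cot_legacy == "final_and_intermediate"
assert ChatStrategy.single_turn_r1_thinking == "final_and_intermediate_r1_compatible"
# ChatStrategy.two_message_cot is the new COT format referenced in the tests above.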
kiln_ai/datamodel/test_task.py

@@ -1,7 +1,7 @@
 import pytest
 from pydantic import ValidationError
 
-from kiln_ai.datamodel.datamodel_enums import TaskOutputRatingType
+from kiln_ai.datamodel.datamodel_enums import StructuredOutputMode, TaskOutputRatingType
 from kiln_ai.datamodel.prompt_id import PromptGenerators
 from kiln_ai.datamodel.task import RunConfig, RunConfigProperties, Task, TaskRunConfig
 from kiln_ai.datamodel.task_output import normalize_rating
@@ -15,6 +15,7 @@ def test_runconfig_valid_creation():
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
+        structured_output_mode="json_schema",
     )
 
     assert config.task == task
@@ -29,12 +30,13 @@ def test_runconfig_missing_required_fields():
 
     errors = exc_info.value.errors()
     assert (
-        len(errors) == 4
+        len(errors) == 5
     )  # task, model_name, model_provider_name, and prompt_id are required
     assert any(error["loc"][0] == "task" for error in errors)
     assert any(error["loc"][0] == "model_name" for error in errors)
     assert any(error["loc"][0] == "model_provider_name" for error in errors)
     assert any(error["loc"][0] == "prompt_id" for error in errors)
+    assert any(error["loc"][0] == "structured_output_mode" for error in errors)
 
 
 def test_runconfig_custom_prompt_id():
@@ -45,6 +47,7 @@ def test_runconfig_custom_prompt_id():
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT,
+        structured_output_mode="json_schema",
     )
 
     assert config.prompt_id == PromptGenerators.SIMPLE_CHAIN_OF_THOUGHT
@@ -61,6 +64,7 @@ def sample_run_config_props(sample_task):
         model_name="gpt-4",
         model_provider_name="openai",
         prompt_id=PromptGenerators.SIMPLE,
+        structured_output_mode="json_schema",
     )
 
 
@@ -157,3 +161,165 @@ def test_normalize_rating(rating_type, rating, expected):
 def test_normalize_rating_errors(rating_type, rating):
     with pytest.raises(ValueError):
         normalize_rating(rating, rating_type)
+
+
+def test_run_config_defaults():
+    """RunConfig should require top_p, temperature, and structured_output_mode to be set."""
+    task = Task(id="task1", name="Test Task", instruction="Do something")
+
+    config = RunConfig(
+        task=task,
+        model_name="gpt-4",
+        model_provider_name="openai",
+        prompt_id=PromptGenerators.SIMPLE,
+        structured_output_mode="json_schema",
+    )
+    assert config.top_p == 1.0
+    assert config.temperature == 1.0
+
+
+def test_run_config_valid_ranges():
+    """RunConfig should accept valid ranges for top_p and temperature."""
+    task = Task(id="task1", name="Test Task", instruction="Do something")
+
+    # Test valid values
+    config = RunConfig(
+        task=task,
+        model_name="gpt-4",
+        model_provider_name="openai",
+        prompt_id=PromptGenerators.SIMPLE,
+        top_p=0.9,
+        temperature=0.7,
+        structured_output_mode=StructuredOutputMode.json_schema,
+    )
+
+    assert config.top_p == 0.9
+    assert config.temperature == 0.7
+    assert config.structured_output_mode == StructuredOutputMode.json_schema
+
+
+@pytest.mark.parametrize("top_p", [0.0, 0.5, 1.0])
+def test_run_config_valid_top_p(top_p):
+    """Test that RunConfig accepts valid top_p values (0-1)."""
+    task = Task(id="task1", name="Test Task", instruction="Do something")
+
+    config = RunConfig(
+        task=task,
+        model_name="gpt-4",
+        model_provider_name="openai",
+        prompt_id=PromptGenerators.SIMPLE,
+        top_p=top_p,
+        temperature=1.0,
+        structured_output_mode=StructuredOutputMode.json_schema,
+    )
+
+    assert config.top_p == top_p
+
+
+@pytest.mark.parametrize("top_p", [-0.1, 1.1, 2.0])
+def test_run_config_invalid_top_p(top_p):
+    """Test that RunConfig rejects invalid top_p values."""
+    task = Task(id="task1", name="Test Task", instruction="Do something")
+
+    with pytest.raises(ValueError, match="top_p must be between 0 and 1"):
+        RunConfig(
+            task=task,
+            model_name="gpt-4",
+            model_provider_name="openai",
+            prompt_id=PromptGenerators.SIMPLE,
+            top_p=top_p,
+            temperature=1.0,
+            structured_output_mode=StructuredOutputMode.json_schema,
+        )
+
+
+@pytest.mark.parametrize("temperature", [0.0, 1.0, 2.0])
+def test_run_config_valid_temperature(temperature):
+    """Test that RunConfig accepts valid temperature values (0-2)."""
+    task = Task(id="task1", name="Test Task", instruction="Do something")
+
+    config = RunConfig(
+        task=task,
+        model_name="gpt-4",
+        model_provider_name="openai",
+        prompt_id=PromptGenerators.SIMPLE,
+        top_p=0.9,
+        temperature=temperature,
+        structured_output_mode=StructuredOutputMode.json_schema,
+    )
+
+    assert config.temperature == temperature
+
+
+@pytest.mark.parametrize("temperature", [-0.1, 2.1, 3.0])
+def test_run_config_invalid_temperature(temperature):
+    """Test that RunConfig rejects invalid temperature values."""
+    task = Task(id="task1", name="Test Task", instruction="Do something")
+
+    with pytest.raises(ValueError, match="temperature must be between 0 and 2"):
+        RunConfig(
+            task=task,
+            model_name="gpt-4",
+            model_provider_name="openai",
+            prompt_id=PromptGenerators.SIMPLE,
+            top_p=0.9,
+            temperature=temperature,
+            structured_output_mode=StructuredOutputMode.json_schema,
+        )
+
+
+def test_run_config_upgrade_old_entries():
+    """Test that TaskRunConfig parses old entries correctly with nested objects, filling in defaults where needed."""
+
+    data = {
+        "v": 1,
+        "name": "test name",
+        "created_at": "2025-06-09T13:33:35.276927",
+        "created_by": "scosman",
+        "run_config_properties": {
+            "model_name": "gpt_4_1_nano",
+            "model_provider_name": "openai",
+            "prompt_id": "task_run_config::189194447826::228174773209::244130257039",
+            "top_p": 0.77,
+            "temperature": 0.77,
+            "structured_output_mode": "json_instruction_and_object",
+        },
+        "prompt": {
+            "name": "Dazzling Unicorn",
+            "description": "Frozen copy of prompt 'simple_prompt_builder', created for evaluations.",
+            "generator_id": "simple_prompt_builder",
+            "prompt": "Generate a joke, given a theme. The theme will be provided as a word or phrase as the input to the model. The assistant should output a joke that is funny and relevant to the theme. If a style is provided, the joke should be in that style. The output should include a setup and punchline.\n\nYour response should respect the following requirements:\n1) Keep the joke on topic. If the user specifies a theme, the joke must be related to that theme.\n2) Avoid any jokes that are offensive or inappropriate. Keep the joke clean and appropriate for all audiences.\n3) Make the joke funny and engaging. It should be something that someone would want to tell to their friends. Something clever, not just a simple pun.\n",
+            "chain_of_thought_instructions": None,
+        },
+        "model_type": "task_run_config",
+    }
+
+    # Parse the data - this should be TaskRunConfig, not RunConfig
+    parsed = TaskRunConfig.model_validate(data)
+    assert parsed.name == "test name"
+    assert parsed.created_by == "scosman"
+    assert (
+        parsed.run_config_properties.structured_output_mode
+        == "json_instruction_and_object"
+    )
+
+    # should still work if loading from file
+    parsed = TaskRunConfig.model_validate(data, context={"loading_from_file": True})
+    assert parsed.name == "test name"
+    assert parsed.created_by == "scosman"
+    assert (
+        parsed.run_config_properties.structured_output_mode
+        == "json_instruction_and_object"
+    )
+
+    # Remove structured_output_mode from run_config_properties and parse again
+    del data["run_config_properties"]["structured_output_mode"]
+
+    with pytest.raises(ValidationError):
+        # should error if not loading from file
+        parsed = TaskRunConfig.model_validate(data)
+
+    parsed = TaskRunConfig.model_validate(data, context={"loading_from_file": True})
+    assert parsed.name == "test name"
+    assert parsed.created_by == "scosman"
+    assert parsed.run_config_properties.structured_output_mode == "unknown"
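Pulling the new RunConfig requirements together, a minimal construction sketch consistent with the tests above (field values are illustrative, not prescribed by the diff):

from kiln_ai.datamodel.datamodel_enums import StructuredOutputMode
from kiln_ai.datamodel.prompt_id import PromptGenerators
from kiln_ai.datamodel.task import RunConfig, Task

config = RunConfig(
    task=Task(id="task1", name="Test Task", instruction="Do something"),
    model_name="gpt-4",
    model_provider_name="openai",
    prompt_id=PromptGenerators.SIMPLE,
    structured_output_mode=StructuredOutputMode.json_schema,  # newly required field
    top_p=0.9,        # validated to the range [0, 1]; defaults to 1.0
    temperature=0.7,  # validated to the range [0, 2]; defaults to 1.0
)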