kiln-ai 0.15.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kiln_ai/adapters/eval/eval_runner.py +5 -64
  2. kiln_ai/adapters/eval/g_eval.py +3 -3
  3. kiln_ai/adapters/fine_tune/dataset_formatter.py +124 -34
  4. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +264 -7
  5. kiln_ai/adapters/ml_model_list.py +478 -4
  6. kiln_ai/adapters/model_adapters/base_adapter.py +26 -8
  7. kiln_ai/adapters/model_adapters/litellm_adapter.py +41 -7
  8. kiln_ai/adapters/model_adapters/test_base_adapter.py +74 -2
  9. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +65 -1
  10. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +3 -2
  11. kiln_ai/adapters/model_adapters/test_structured_output.py +4 -6
  12. kiln_ai/adapters/parsers/base_parser.py +0 -3
  13. kiln_ai/adapters/parsers/parser_registry.py +5 -3
  14. kiln_ai/adapters/parsers/r1_parser.py +17 -2
  15. kiln_ai/adapters/parsers/request_formatters.py +40 -0
  16. kiln_ai/adapters/parsers/test_parser_registry.py +2 -2
  17. kiln_ai/adapters/parsers/test_r1_parser.py +44 -1
  18. kiln_ai/adapters/parsers/test_request_formatters.py +76 -0
  19. kiln_ai/adapters/prompt_builders.py +14 -1
  20. kiln_ai/adapters/provider_tools.py +18 -1
  21. kiln_ai/adapters/repair/test_repair_task.py +3 -2
  22. kiln_ai/adapters/test_prompt_builders.py +24 -3
  23. kiln_ai/adapters/test_provider_tools.py +70 -1
  24. kiln_ai/datamodel/__init__.py +2 -0
  25. kiln_ai/datamodel/datamodel_enums.py +14 -0
  26. kiln_ai/datamodel/dataset_filters.py +69 -1
  27. kiln_ai/datamodel/dataset_split.py +4 -0
  28. kiln_ai/datamodel/eval.py +8 -0
  29. kiln_ai/datamodel/finetune.py +1 -0
  30. kiln_ai/datamodel/prompt_id.py +1 -0
  31. kiln_ai/datamodel/task_output.py +1 -1
  32. kiln_ai/datamodel/task_run.py +39 -7
  33. kiln_ai/datamodel/test_basemodel.py +3 -7
  34. kiln_ai/datamodel/test_dataset_filters.py +82 -0
  35. kiln_ai/datamodel/test_dataset_split.py +2 -0
  36. kiln_ai/datamodel/test_example_models.py +54 -0
  37. kiln_ai/datamodel/test_models.py +50 -2
  38. kiln_ai/utils/async_job_runner.py +106 -0
  39. kiln_ai/utils/dataset_import.py +80 -18
  40. kiln_ai/utils/test_async_job_runner.py +199 -0
  41. kiln_ai/utils/test_dataset_import.py +242 -10
  42. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/METADATA +1 -1
  43. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/RECORD +45 -41
  44. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/WHEEL +0 -0
  45. {kiln_ai-0.15.0.dist-info → kiln_ai-0.16.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -218,8 +218,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     }
 
     with patch.object(LiteLlmAdapter, "_run", new_callable=AsyncMock) as mock_run:
-        mock_run.return_value = RunOutput(
-            output=mocked_output, intermediate_outputs=None
+        mock_run.return_value = (
+            RunOutput(output=mocked_output, intermediate_outputs=None),
+            None,
         )
 
         adapter = adapter_for_task(
@@ -3,7 +3,7 @@ import logging
 
 import pytest
 
-from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
 from kiln_ai.adapters.model_adapters.test_structured_output import (
     build_structured_output_test_task,
 )
@@ -15,6 +15,7 @@ from kiln_ai.adapters.prompt_builders import (
     MultiShotPromptBuilder,
     RepairsPromptBuilder,
     SavedPromptBuilder,
+    ShortPromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
     SimplePromptBuilder,
     TaskRunConfigPromptBuilder,
@@ -33,6 +34,7 @@ from kiln_ai.datamodel import (
     TaskOutput,
     TaskOutputRating,
     TaskRun,
+    Usage,
 )
 from kiln_ai.datamodel.task import RunConfigProperties, TaskRunConfig
 
@@ -58,9 +60,28 @@ def test_simple_prompt_builder(tmp_path):
     assert input not in prompt
 
 
+def test_short_prompt_builder(tmp_path):
+    task = build_test_task(tmp_path)
+    builder = ShortPromptBuilder(task=task)
+    prompt = builder.build_prompt(include_json_instructions=False)
+
+    # Should only include the instruction, not requirements
+    assert task.instruction == prompt
+    assert task.requirements[0].instruction not in prompt
+    assert task.requirements[1].instruction not in prompt
+    assert task.requirements[2].instruction not in prompt
+
+    # Should handle JSON instructions correctly
+    prompt_with_json = builder.build_prompt(include_json_instructions=True)
+    assert task.instruction in prompt_with_json
+    if task.output_schema():
+        assert "# Format Instructions" in prompt_with_json
+        assert task.output_schema() in prompt_with_json
+
+
 class MockAdapter(BaseAdapter):
-    def _run(self, input: str) -> str:
-        return "mock response"
+    async def _run(self, input: str) -> tuple[RunOutput, Usage | None]:
+        return RunOutput(output="mock response", intermediate_outputs=None), None
 
     def adapter_name(self) -> str:
         return "mock_adapter"
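
Note: the MockAdapter change above reflects the adapter contract introduced in 0.16.0, where _run is async and returns a (RunOutput, Usage | None) pair instead of a bare RunOutput. A minimal sketch of a conforming subclass, assuming only what the test diffs show (EchoAdapter itself is illustrative and not part of the package):

from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput
from kiln_ai.datamodel import Usage


class EchoAdapter(BaseAdapter):
    # Illustrative adapter: echoes its input and reports no usage.
    async def _run(self, input: str) -> tuple[RunOutput, Usage | None]:
        return RunOutput(output=input, intermediate_outputs=None), None

    def adapter_name(self) -> str:
        return "echo_adapter"
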
@@ -5,6 +5,7 @@ import pytest
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     ModelName,
+    ModelParserID,
     ModelProviderName,
 )
 from kiln_ai.adapters.ollama_tools import OllamaConnection
@@ -24,7 +25,12 @@ from kiln_ai.adapters.provider_tools import (
     provider_name_from_id,
     provider_warnings,
 )
-from kiln_ai.datamodel import Finetune, StructuredOutputMode, Task
+from kiln_ai.datamodel import (
+    Finetune,
+    FinetuneDataStrategy,
+    StructuredOutputMode,
+    Task,
+)
 
 
 @pytest.fixture(autouse=True)
@@ -65,6 +71,33 @@ def mock_finetune():
         finetune.provider = ModelProviderName.openai
         finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
         finetune.structured_output_mode = StructuredOutputMode.json_schema
+        finetune.data_strategy = FinetuneDataStrategy.final_only
+        mock.return_value = finetune
+        yield mock
+
+
+@pytest.fixture
+def mock_finetune_final_and_intermediate():
+    with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
+        finetune = Mock(spec=Finetune)
+        finetune.provider = ModelProviderName.openai
+        finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
+        finetune.structured_output_mode = StructuredOutputMode.json_schema
+        finetune.data_strategy = FinetuneDataStrategy.final_and_intermediate
+        mock.return_value = finetune
+        yield mock
+
+
+@pytest.fixture
+def mock_finetune_r1_compatible():
+    with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
+        finetune = Mock(spec=Finetune)
+        finetune.provider = ModelProviderName.ollama
+        finetune.fine_tune_model_id = "ft:deepseek-r1:671b:custom:model-123"
+        finetune.structured_output_mode = StructuredOutputMode.json_schema
+        finetune.data_strategy = (
+            FinetuneDataStrategy.final_and_intermediate_r1_compatible
+        )
         mock.return_value = finetune
         yield mock
 
@@ -426,6 +459,38 @@ def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune)
     assert provider.name == ModelProviderName.openai
     assert provider.model_id == "ft:gpt-3.5-turbo:custom:model-123"
     assert provider.structured_output_mode == StructuredOutputMode.json_schema
+    assert provider.reasoning_capable is False
+    assert provider.parser == None
+
+
+def test_finetune_provider_model_success_final_and_intermediate(
+    mock_project, mock_task, mock_finetune_final_and_intermediate
+):
+    """Test successful creation of a fine-tuned model provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    provider = finetune_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.model_id == "ft:gpt-3.5-turbo:custom:model-123"
+    assert provider.structured_output_mode == StructuredOutputMode.json_schema
+    assert provider.reasoning_capable is True
+    assert provider.parser == None
+
+
+def test_finetune_provider_model_success_r1_compatible(
+    mock_project, mock_task, mock_finetune_r1_compatible
+):
+    """Test successful creation of a fine-tuned model provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    provider = finetune_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.ollama
+    assert provider.model_id == "ft:deepseek-r1:671b:custom:model-123"
+    assert provider.structured_output_mode == StructuredOutputMode.json_schema
+    assert provider.reasoning_capable is True
+    assert provider.parser == ModelParserID.r1_thinking
 
 
 def test_finetune_provider_model_invalid_id():
@@ -515,6 +580,7 @@ def test_finetune_provider_model_structured_mode(
     finetune.provider = provider_name
     finetune.fine_tune_model_id = "fireworks-model-123"
     finetune.structured_output_mode = structured_output_mode
+    finetune.data_strategy = FinetuneDataStrategy.final_only
    mock_finetune.return_value = finetune
 
     provider = finetune_provider_model("project-123::task-456::finetune-789")
@@ -522,6 +588,8 @@
     assert provider.name == provider_name
     assert provider.model_id == "fireworks-model-123"
     assert provider.structured_output_mode == expected_mode
+    assert provider.reasoning_capable is False
+    assert provider.parser == None
 
 
 def test_openai_compatible_provider_config(mock_shared_config):
@@ -799,6 +867,7 @@ def test_finetune_provider_model_vertex_ai(mock_project, mock_task, mock_finetun
     finetune.provider = ModelProviderName.vertex
     finetune.fine_tune_model_id = "projects/123/locations/us-central1/endpoints/456"
     finetune.structured_output_mode = StructuredOutputMode.json_mode
+    finetune.data_strategy = FinetuneDataStrategy.final_only
     mock_finetune.return_value = finetune
 
     provider = finetune_provider_model("project-123::task-456::finetune-789")
@@ -44,6 +44,7 @@ from kiln_ai.datamodel.task_output import (
 )
 from kiln_ai.datamodel.task_run import (
     TaskRun,
+    Usage,
 )
 
 __all__ = [
@@ -74,4 +75,5 @@ __all__ = [
     "PromptId",
     "PromptGenerators",
     "prompt_generator_values",
+    "Usage",
 ]
@@ -56,5 +56,19 @@ class FineTuneStatusType(str, Enum):
 
 
 class FinetuneDataStrategy(str, Enum):
+    """Strategy for what data to include when fine-tuning a model."""
+
+    # Only train on the final response, ignoring any intermediate steps or chain of thought
     final_only = "final_only"
+
+    # Train on both the final response and any intermediate steps/chain of thought
     final_and_intermediate = "final_and_intermediate"
+
+    # Train using R1-style thinking format, which includes the reasoning in <think> tags in the message
+    final_and_intermediate_r1_compatible = "final_and_intermediate_r1_compatible"
+
+
+THINKING_DATA_STRATEGIES: list[FinetuneDataStrategy] = [
+    FinetuneDataStrategy.final_and_intermediate,
+    FinetuneDataStrategy.final_and_intermediate_r1_compatible,
+]
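
The new THINKING_DATA_STRATEGIES list gives callers a single membership check for strategies that carry reasoning data. A small sketch of how it can be used, assuming nothing beyond the enum shown above (the includes_thinking helper is illustrative, not part of the package):

from kiln_ai.datamodel.datamodel_enums import (
    THINKING_DATA_STRATEGIES,
    FinetuneDataStrategy,
)


def includes_thinking(strategy: FinetuneDataStrategy) -> bool:
    # Both "final_and_intermediate" variants train on reasoning traces;
    # final_only does not.
    return strategy in THINKING_DATA_STRATEGIES


assert includes_thinking(FinetuneDataStrategy.final_and_intermediate_r1_compatible)
assert not includes_thinking(FinetuneDataStrategy.final_only)
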
@@ -1,5 +1,6 @@
+import re
 from enum import Enum
-from typing import Annotated, Protocol
+from typing import Annotated, ClassVar, List, Protocol
 
 from pydantic import AfterValidator
 
@@ -59,6 +60,65 @@ class TagFilter:
         return self.tag in task_run.tags
 
 
+class MultiDatasetFilter:
+    """
+    A filter that combines multiple filters using AND logic.
+    The filters are specified in a query string format after 'multi_filter::'
+    Example: multi_filter::high_rating&thinking_model&tag::tag_name
+
+    Ampersands in filter IDs can be escaped with a backslash.
+    """
+
+    PREFIX: ClassVar[str] = "multi_filter::"
+    ESCAPED_AMPERSAND: ClassVar[str] = r"\&"
+    UNESCAPED_AMPERSAND: ClassVar[str] = "&"
+
+    @classmethod
+    def parse_filter_string(cls, filter_string: str) -> List[str]:
+        """
+        Parse a filter string into individual filter IDs, handling escaped ampersands.
+        """
+        if not filter_string.startswith(cls.PREFIX):
+            raise ValueError(f"Filter string must start with {cls.PREFIX}")
+
+        # Remove the prefix
+        content = filter_string[len(cls.PREFIX) :]
+        if not content:
+            raise ValueError("No filters specified after prefix")
+
+        # Split on unescaped ampersands
+        # This regex matches & that are not preceded by a backslash
+        parts = re.split(r"(?<!\\)&", content)
+
+        # Unescape ampersands in each part
+        filter_ids = [
+            part.replace(cls.ESCAPED_AMPERSAND, cls.UNESCAPED_AMPERSAND)
+            for part in parts
+        ]
+
+        # Validate each filter ID using the existing validation
+        for fid in filter_ids:
+            _check_dataset_filter_id(fid)
+
+        return filter_ids
+
+    @classmethod
+    def is_valid_filter_string(cls, filter_string: str) -> bool:
+        """Check if a filter string is valid."""
+        try:
+            cls.parse_filter_string(filter_string)
+            return True
+        except ValueError:
+            return False
+
+    def __init__(self, filter_id: str):
+        filter_ids = MultiDatasetFilter.parse_filter_string(filter_id)
+        self.filters = [dataset_filter_from_id(fid) for fid in filter_ids]
+
+    def __call__(self, task_run: TaskRun) -> bool:
+        return all(f(task_run) for f in self.filters)
+
+
 class StaticDatasetFilters(str, Enum):
     """Dataset filter names."""
 
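
The filter-ID grammar that MultiDatasetFilter parses is easiest to see by example. The strings below are taken directly from the class docstring and the new tests later in this diff; running them through parse_filter_string is a sketch of the expected behaviour:

from kiln_ai.datamodel.dataset_filters import MultiDatasetFilter

# Two static filters combined with AND
assert MultiDatasetFilter.parse_filter_string(
    "multi_filter::high_rating&thinking_model"
) == ["high_rating", "thinking_model"]

# A tag whose name contains "&" escapes it with a backslash
assert MultiDatasetFilter.parse_filter_string(
    "multi_filter::high_rating&tag::tag\\&name"
) == ["high_rating", "tag::tag&name"]
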
@@ -98,6 +158,11 @@ def _check_dataset_filter_id(id: str) -> str:
     if id.startswith("tag::") and len(id) > 5:
         return id
 
+    if id.startswith(MultiDatasetFilter.PREFIX):
+        if not MultiDatasetFilter.is_valid_filter_string(id):
+            raise ValueError(f"Invalid multi-filter string: {id}")
+        return id
+
     raise ValueError(f"Invalid dataset filter ID: {id}")
 
 
@@ -108,6 +173,9 @@ def dataset_filter_from_id(id: DatasetFilterId) -> DatasetFilter:
     if id.startswith("tag::") and len(id) > 5:
         return TagFilter(id[5:])
 
+    if id.startswith(MultiDatasetFilter.PREFIX):
+        return MultiDatasetFilter(id)
+
     if id in static_dataset_filters:
         return static_dataset_filters[id]
 
@@ -45,6 +45,10 @@ Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
     DatasetSplitDefinition(name="train", percentage=0.8),
     DatasetSplitDefinition(name="test", percentage=0.2),
 ]
+Train80Val20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.8),
+    DatasetSplitDefinition(name="val", percentage=0.2),
+]
 Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
     DatasetSplitDefinition(name="train", percentage=0.6),
     DatasetSplitDefinition(name="test", percentage=0.2),
kiln_ai/datamodel/eval.py CHANGED
@@ -263,6 +263,10 @@ class Eval(KilnParentedModel, KilnParentModel, parent_of={"configs": EvalConfig}
         default=None,
         description="The id of the current config to use for this eval. This can be changed over time to run the same eval with different configs.",
     )
+    current_run_config_id: ID_TYPE = Field(
+        default=None,
+        description="The id of the a run config which was selected as the best run config for this eval. The run config must belong to the parent Task.",
+    )
     eval_set_filter_id: DatasetFilterId = Field(
         description="The id of the dataset filter which defines which dataset items are included when running this eval. Should be mutually exclusive with eval_configs_filter_id."
     )
@@ -272,6 +276,10 @@ class Eval(KilnParentedModel, KilnParentModel, parent_of={"configs": EvalConfig}
     output_scores: List[EvalOutputScore] = Field(
         description="The scores this evaluator should produce."
     )
+    favourite: bool = Field(
+        default=False,
+        description="Whether this eval is a favourite of the user. Rendered as a star icon in the UI.",
+    )
 
     # Workaround to return typed parent without importing Task
     def parent_task(self) -> Union["Task", None]:
@@ -5,6 +5,7 @@ from typing_extensions import Self
 
 from kiln_ai.datamodel.basemodel import NAME_FIELD, KilnParentedModel
 from kiln_ai.datamodel.datamodel_enums import (
+    THINKING_DATA_STRATEGIES,
     FinetuneDataStrategy,
     FineTuneStatusType,
     StructuredOutputMode,
@@ -13,6 +13,7 @@ class PromptGenerators(str, Enum):
     SIMPLE_CHAIN_OF_THOUGHT = "simple_chain_of_thought_prompt_builder"
     FEW_SHOT_CHAIN_OF_THOUGHT = "few_shot_chain_of_thought_prompt_builder"
     MULTI_SHOT_CHAIN_OF_THOUGHT = "multi_shot_chain_of_thought_prompt_builder"
+    SHORT = "short_prompt_builder"
 
 
 prompt_generator_values = [pg.value for pg in PromptGenerators]
@@ -64,7 +64,7 @@ class TaskOutputRating(KilnBaseModel):
     )
     requirement_ratings: Dict[ID_TYPE, RequirementRating] = Field(
         default={},
-        description="The ratings of the requirements of the task.",
+        description="The ratings of the requirements of the task. The ID can be either a task_requirement_id or a named rating for an eval_output_score name (in format 'named::<name>').",
     )
 
     # Previously we stored rating values as a dict of floats, but now we store them as RequirementRating objects.
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict, List, Union
 
 import jsonschema
 import jsonschema.exceptions
-from pydantic import Field, ValidationInfo, model_validator
+from pydantic import BaseModel, Field, ValidationInfo, model_validator
 from typing_extensions import Self
 
 from kiln_ai.datamodel.basemodel import KilnParentedModel
@@ -15,6 +15,29 @@ if TYPE_CHECKING:
     from kiln_ai.datamodel.task import Task
 
 
+class Usage(BaseModel):
+    input_tokens: int | None = Field(
+        default=None,
+        description="The number of input tokens used in the task run.",
+        ge=0,
+    )
+    output_tokens: int | None = Field(
+        default=None,
+        description="The number of output tokens used in the task run.",
+        ge=0,
+    )
+    total_tokens: int | None = Field(
+        default=None,
+        description="The total number of tokens used in the task run.",
+        ge=0,
+    )
+    cost: float | None = Field(
+        default=None,
+        description="The cost of the task run in US dollars, saved at runtime (prices can change over time).",
+        ge=0,
+    )
+
+
 class TaskRun(KilnParentedModel):
     """
     Represents a single execution of a Task.
@@ -47,17 +70,26 @@ class TaskRun(KilnParentedModel):
         default=[],
         description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
     )
+    usage: Usage | None = Field(
+        default=None,
+        description="Usage information for the task run. This includes the number of input tokens, output tokens, and total tokens used.",
+    )
+
+    def thinking_training_data(self) -> str | None:
+        """
+        Get the thinking training data from the task run.
+        """
+        if self.intermediate_outputs is None:
+            return None
+        return self.intermediate_outputs.get(
+            "reasoning"
+        ) or self.intermediate_outputs.get("chain_of_thought")
 
     def has_thinking_training_data(self) -> bool:
         """
         Does this run have thinking data that we can use to train a thinking model?
         """
-        if self.intermediate_outputs is None:
-            return False
-        return (
-            "chain_of_thought" in self.intermediate_outputs
-            or "reasoning" in self.intermediate_outputs
-        )
+        return self.thinking_training_data() is not None
 
     # Workaround to return typed parent without importing Task
     def parent_task(self) -> Union["Task", None]:
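
The refactor above routes has_thinking_training_data() through the new thinking_training_data() accessor, which prefers the "reasoning" intermediate output and falls back to "chain_of_thought". A minimal sketch of that lookup, written against a bare dict rather than a full TaskRun (the thinking_from helper is illustrative only):

def thinking_from(intermediate_outputs: dict | None) -> str | None:
    # Same lookup TaskRun.thinking_training_data() performs.
    if intermediate_outputs is None:
        return None
    return intermediate_outputs.get("reasoning") or intermediate_outputs.get(
        "chain_of_thought"
    )


assert thinking_from({"reasoning": "step 1", "chain_of_thought": "cot"}) == "step 1"
assert thinking_from({"chain_of_thought": "cot"}) == "cot"
assert thinking_from(None) is None
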
@@ -483,7 +483,7 @@ class MockAdapter(BaseAdapter):
     """Implementation of BaseAdapter for testing"""
 
     async def _run(self, input):
-        return RunOutput(output="test output", intermediate_outputs=None)
+        return RunOutput(output="test output", intermediate_outputs=None), None
 
     def adapter_name(self) -> str:
         return "test"
@@ -510,6 +510,7 @@ async def test_invoke_parsing_flow(adapter):
     # Mock dependencies
     mock_provider = MagicMock()
     mock_provider.parser = "test_parser"
+    mock_provider.formatter = None
     mock_provider.reasoning_capable = False
 
     mock_parser = MagicMock()
@@ -517,13 +518,11 @@ async def test_invoke_parsing_flow(adapter):
         output="parsed test output", intermediate_outputs={"key": "value"}
     )
 
-    mock_parser_class = MagicMock(return_value=mock_parser)
-
     with (
         patch.object(adapter, "model_provider", return_value=mock_provider),
         patch(
             "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
-            return_value=mock_parser_class,
+            return_value=mock_parser,
         ),
         patch("kiln_ai.adapters.model_adapters.base_adapter.Config") as mock_config,
     ):
@@ -534,9 +533,6 @@ async def test_invoke_parsing_flow(adapter):
         # Execute
         result = await adapter.invoke("test input")
 
-        # Verify parser was created correctly
-        mock_parser_class.assert_called_once_with(structured_output=False)
-
         # Verify parsing occurred
         mock_parser.parse_output.assert_called_once()
         parsed_args = mock_parser.parse_output.call_args[1]
@@ -1,3 +1,5 @@
+from unittest.mock import Mock
+
 import pytest
 from pydantic import BaseModel
 
@@ -5,12 +7,14 @@ from kiln_ai.datamodel.dataset_filters import (
     AllDatasetFilter,
     DatasetFilterId,
     HighRatingDatasetFilter,
+    MultiDatasetFilter,
     StaticDatasetFilters,
     TagFilter,
     ThinkingModelDatasetFilter,
     ThinkingModelHighRatedFilter,
     dataset_filter_from_id,
 )
+from kiln_ai.datamodel.task_run import TaskRun
 
 # Note: Many more filter tests in test_dataset_split.py
 
@@ -69,3 +73,81 @@ def test_tag_filter(tag, expected_error, expected_tag):
         filter = dataset_filter_from_id(tag)
         assert isinstance(filter, TagFilter)
         assert filter.tag == expected_tag
+
+
+class TestMultiDatasetFilter:
+    @pytest.mark.parametrize(
+        "filter_string,expected_filters",
+        [
+            ("multi_filter::high_rating", ["high_rating"]),
+            (
+                "multi_filter::high_rating&thinking_model",
+                ["high_rating", "thinking_model"],
+            ),
+            ("multi_filter::tag::test&high_rating", ["tag::test", "high_rating"]),
+            (
+                "multi_filter::high_rating&tag::tag\\&name",
+                ["high_rating", "tag::tag&name"],
+            ),
+        ],
+    )
+    def test_valid_filter_string_parsing(self, filter_string, expected_filters):
+        """Test that valid filter strings are parsed correctly."""
+        assert MultiDatasetFilter.parse_filter_string(filter_string) == expected_filters
+        assert MultiDatasetFilter.is_valid_filter_string(filter_string)
+
+    @pytest.mark.parametrize(
+        "filter_string,expected_error",
+        [
+            (
+                "not_multi_filter::high_rating",
+                "Filter string must start with multi_filter::",
+            ),
+            ("multi_filter::", "No filters specified after prefix"),
+            ("multi_filter::high_rating&", "Invalid dataset filter ID:"),
+            ("multi_filter::invalid_filter", "Invalid dataset filter ID:"),
+        ],
+    )
+    def test_invalid_filter_string_handling(self, filter_string, expected_error):
+        """Test that invalid filter strings raise appropriate errors."""
+        with pytest.raises(ValueError, match=expected_error):
+            MultiDatasetFilter.parse_filter_string(filter_string)
+        assert not MultiDatasetFilter.is_valid_filter_string(filter_string)
+
+    def test_filter_combination_logic(self):
+        """Test that multiple filters are combined with AND logic."""
+        # Create a mock task run
+        task_run = Mock(spec=TaskRun)
+        task_run.output = Mock()
+        task_run.output.rating = Mock()
+        task_run.output.rating.is_high_quality.return_value = True
+        task_run.tags = ["test_tag"]
+        task_run.has_thinking_training_data.return_value = True
+        task_run.repaired_output = None
+
+        # Test combining high_rating and tag filters
+        filter_id = "multi_filter::high_rating&tag::test_tag"
+        multi_filter = dataset_filter_from_id(filter_id)
+        assert multi_filter(task_run)
+
+        # Test that it fails if one filter fails
+        task_run.tags = ["wrong_tag"]
+        assert not multi_filter(task_run)
+        task_run.tags = ["test_tag"]
+        assert multi_filter(task_run)
+        task_run.output.rating.is_high_quality.return_value = False
+        assert not multi_filter(task_run)
+
+        # Verify the mock was called as expected
+        task_run.output.rating.is_high_quality.assert_called()
+
+    def test_filter_creation_from_id(self):
+        """Test that multi filters can be created via dataset_filter_from_id."""
+        filter_id = "multi_filter::high_rating&thinking_model"
+        filter = dataset_filter_from_id(filter_id)
+        assert isinstance(filter, MultiDatasetFilter)
+        assert len(filter.filters) == 2
+        assert any(isinstance(f, type(HighRatingDatasetFilter)) for f in filter.filters)
+        assert any(
+            isinstance(f, type(ThinkingModelDatasetFilter)) for f in filter.filters
+        )
@@ -17,6 +17,7 @@ from kiln_ai.datamodel.dataset_split import (
     AllSplitDefinition,
     Train60Test20Val20SplitDefinition,
     Train80Test20SplitDefinition,
+    Train80Val20SplitDefinition,
 )
 from kiln_ai.datamodel.test_dataset_filters import (
     AllDatasetFilter,
@@ -174,6 +175,7 @@ def test_high_rating_dataset_filter(sample_task_runs):
     [
         (Train80Test20SplitDefinition, {"train": 8, "test": 2}),
         (AllSplitDefinition, {"all": 10}),
+        (Train80Val20SplitDefinition, {"train": 8, "val": 2}),
         (Train60Test20Val20SplitDefinition, {"train": 6, "test": 2, "val": 2}),
         (
             [
@@ -16,6 +16,7 @@ from kiln_ai.datamodel import (
     TaskOutputRatingType,
     TaskRequirement,
     TaskRun,
+    Usage,
 )
 
 
@@ -743,3 +744,56 @@ def test_task_run_validate_repaired_output_structured(tmp_path):
             ),
         ),
     )
+
+
+@pytest.mark.parametrize(
+    "input_tokens,output_tokens,total_tokens,cost,should_raise",
+    [
+        # Valid cases
+        (100, 50, 150, 0.002, False),  # All fields
+        (None, None, None, None, False),  # All None (defaults)
+        # Invalid cases
+        (-100, 50, 150, 0.002, True),  # Negative input_tokens
+        (100, -50, 150, 0.002, True),  # Negative output_tokens
+        (100, 50, -150, 0.002, True),  # Negative total_tokens
+        (100, 50, 150, -0.002, True),  # Negative cost
+    ],
+)
+def test_usage_model(input_tokens, output_tokens, total_tokens, cost, should_raise):
+    """Test the Usage model with various input combinations."""
+    if should_raise:
+        with pytest.raises(ValidationError):
+            Usage(
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                total_tokens=total_tokens,
+                cost=cost,
+            )
+    else:
+        usage = Usage(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=total_tokens,
+            cost=cost,
+        )
+        assert usage.input_tokens == input_tokens
+        assert usage.output_tokens == output_tokens
+        assert usage.total_tokens == total_tokens
+        assert usage.cost == cost
+
+
+def test_usage_model_in_task_run(valid_task_run):
+    """Test that Usage can be properly set in a TaskRun."""
+    usage = Usage(
+        input_tokens=100,
+        output_tokens=50,
+        total_tokens=150,
+        cost=0.002,
+    )
+    task_run = valid_task_run.model_copy(deep=True)
+    task_run.usage = usage
+    assert task_run.usage == usage
+    assert task_run.usage.input_tokens == 100
+    assert task_run.usage.output_tokens == 50
+    assert task_run.usage.total_tokens == 150
+    assert task_run.usage.cost == 0.002