kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of kiln-ai might be problematic.
Files changed (88)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +81 -10
  3. kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +267 -0
  7. kiln_ai/adapters/eval/g_eval.py +367 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +324 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +640 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +497 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  15. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  16. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  21. kiln_ai/adapters/ml_model_list.py +434 -93
  22. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  23. kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
  24. kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
  25. kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
  26. kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
  27. kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
  28. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
  29. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
  30. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
  31. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
  32. kiln_ai/adapters/ollama_tools.py +0 -1
  33. kiln_ai/adapters/parsers/__init__.py +10 -0
  34. kiln_ai/adapters/parsers/base_parser.py +12 -0
  35. kiln_ai/adapters/parsers/json_parser.py +37 -0
  36. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  37. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  38. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  39. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  40. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  41. kiln_ai/adapters/prompt_builders.py +193 -49
  42. kiln_ai/adapters/provider_tools.py +91 -36
  43. kiln_ai/adapters/repair/repair_task.py +18 -19
  44. kiln_ai/adapters/repair/test_repair_task.py +7 -7
  45. kiln_ai/adapters/run_output.py +11 -0
  46. kiln_ai/adapters/test_adapter_registry.py +177 -0
  47. kiln_ai/adapters/test_generate_docs.py +69 -0
  48. kiln_ai/adapters/test_ollama_tools.py +0 -1
  49. kiln_ai/adapters/test_prompt_adaptors.py +25 -18
  50. kiln_ai/adapters/test_prompt_builders.py +265 -44
  51. kiln_ai/adapters/test_provider_tools.py +268 -46
  52. kiln_ai/datamodel/__init__.py +51 -772
  53. kiln_ai/datamodel/basemodel.py +31 -11
  54. kiln_ai/datamodel/datamodel_enums.py +58 -0
  55. kiln_ai/datamodel/dataset_filters.py +114 -0
  56. kiln_ai/datamodel/dataset_split.py +170 -0
  57. kiln_ai/datamodel/eval.py +298 -0
  58. kiln_ai/datamodel/finetune.py +105 -0
  59. kiln_ai/datamodel/json_schema.py +14 -3
  60. kiln_ai/datamodel/model_cache.py +8 -3
  61. kiln_ai/datamodel/project.py +23 -0
  62. kiln_ai/datamodel/prompt.py +37 -0
  63. kiln_ai/datamodel/prompt_id.py +83 -0
  64. kiln_ai/datamodel/strict_mode.py +24 -0
  65. kiln_ai/datamodel/task.py +181 -0
  66. kiln_ai/datamodel/task_output.py +321 -0
  67. kiln_ai/datamodel/task_run.py +164 -0
  68. kiln_ai/datamodel/test_basemodel.py +80 -2
  69. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  70. kiln_ai/datamodel/test_dataset_split.py +127 -6
  71. kiln_ai/datamodel/test_datasource.py +3 -2
  72. kiln_ai/datamodel/test_eval_model.py +635 -0
  73. kiln_ai/datamodel/test_example_models.py +34 -17
  74. kiln_ai/datamodel/test_json_schema.py +23 -0
  75. kiln_ai/datamodel/test_model_cache.py +24 -0
  76. kiln_ai/datamodel/test_model_perf.py +125 -0
  77. kiln_ai/datamodel/test_models.py +131 -2
  78. kiln_ai/datamodel/test_prompt_id.py +129 -0
  79. kiln_ai/datamodel/test_task.py +159 -0
  80. kiln_ai/utils/config.py +6 -1
  81. kiln_ai/utils/exhaustive_error.py +6 -0
  82. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
  83. kiln_ai-0.12.0.dist-info/RECORD +100 -0
  84. kiln_ai/adapters/base_adapter.py +0 -191
  85. kiln_ai/adapters/langchain_adapters.py +0 -256
  86. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  87. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
  88. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/__init__.py
@@ -3,31 +3,31 @@
 
 Adapters are used to connect Kiln to external systems, or to add new functionality to Kiln.
 
-BaseAdapter is extensible, and used for adding adapters that provide AI functionality. There's currently a LangChain adapter which provides a bridge to LangChain.
+Model adapters are used to call AI models, like Ollama, OpenAI, etc.
 
 The ml_model_list submodule contains a list of models that can be used for machine learning tasks. More can easily be added, but we keep a list here of models that are known to work well with Kiln's structured data and tool calling systems.
 
 The prompt_builders submodule contains classes that build prompts for use with the AI agents.
 
 The repair submodule contains an adapter for the repair task.
+
+The parser submodule contains parsers for the output of the AI models.
 """
 
 from . import (
-    base_adapter,
     data_gen,
     fine_tune,
-    langchain_adapters,
     ml_model_list,
+    model_adapters,
     prompt_builders,
     repair,
 )
 
 __all__ = [
-    "base_adapter",
-    "langchain_adapters",
+    "model_adapters",
+    "data_gen",
+    "fine_tune",
     "ml_model_list",
     "prompt_builders",
     "repair",
-    "data_gen",
-    "fine_tune",
 ]
kiln_ai/adapters/adapter_registry.py
@@ -1,21 +1,92 @@
+from os import getenv
+
 from kiln_ai import datamodel
-from kiln_ai.adapters.base_adapter import BaseAdapter
-from kiln_ai.adapters.langchain_adapters import LangchainAdapter
-from kiln_ai.adapters.prompt_builders import BasePromptBuilder
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig, BaseAdapter
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.model_adapters.openai_model_adapter import (
+    OpenAICompatibleAdapter,
+    OpenAICompatibleConfig,
+)
+from kiln_ai.adapters.provider_tools import core_provider, openai_compatible_config
+from kiln_ai.datamodel import PromptId
+from kiln_ai.utils.config import Config
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
 
 def adapter_for_task(
     kiln_task: datamodel.Task,
-    model_name: str | None = None,
-    provider: str | None = None,
-    prompt_builder: BasePromptBuilder | None = None,
-    tags: list[str] | None = None,
+    model_name: str,
+    provider: ModelProviderName,
+    prompt_id: PromptId | None = None,
+    base_adapter_config: AdapterConfig | None = None,
 ) -> BaseAdapter:
-    # We use langchain for everything right now, but can add any others here
+    # Get the provider to run. For things like the fine-tune provider, we want to run the underlying provider
+    core_provider_name = core_provider(model_name, provider)
+
+    match core_provider_name:
+        case ModelProviderName.openrouter:
+            return OpenAICompatibleAdapter(
+                kiln_task=kiln_task,
+                config=OpenAICompatibleConfig(
+                    base_url=getenv("OPENROUTER_BASE_URL")
+                    or "https://openrouter.ai/api/v1",
+                    api_key=Config.shared().open_router_api_key,
+                    model_name=model_name,
+                    provider_name=provider,
+                    default_headers={
+                        "HTTP-Referer": "https://getkiln.ai/openrouter",
+                        "X-Title": "KilnAI",
+                    },
+                ),
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+            )
+        case ModelProviderName.openai:
+            return OpenAICompatibleAdapter(
+                kiln_task=kiln_task,
+                config=OpenAICompatibleConfig(
+                    api_key=Config.shared().open_ai_api_key,
+                    model_name=model_name,
+                    provider_name=provider,
+                ),
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+            )
+        case ModelProviderName.openai_compatible:
+            config = openai_compatible_config(model_name)
+            return OpenAICompatibleAdapter(
+                kiln_task=kiln_task,
+                config=config,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+            )
+        # Use LangchainAdapter for the rest
+        case ModelProviderName.groq:
+            pass
+        case ModelProviderName.amazon_bedrock:
+            pass
+        case ModelProviderName.ollama:
+            pass
+        case ModelProviderName.fireworks_ai:
+            pass
+        # These are virtual providers that should have mapped to an actual provider in core_provider
+        case ModelProviderName.kiln_fine_tune:
+            raise ValueError(
+                "Fine tune is not a supported core provider. It should map to an actual provider."
+            )
+        case ModelProviderName.kiln_custom_registry:
+            raise ValueError(
+                "Custom openai compatible provider is not a supported core provider. It should map to an actual provider."
+            )
+        case _:
+            raise_exhaustive_enum_error(core_provider_name)
+
+    # We use langchain for all others right now, but moving off it as we touch anything.
     return LangchainAdapter(
         kiln_task,
         model_name=model_name,
        provider=provider,
-        prompt_builder=prompt_builder,
-        tags=tags,
+        prompt_id=prompt_id,
+        base_adapter_config=base_adapter_config,
     )
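For orientation, the new registry entry point is typically driven as sketched below. This is an illustrative usage sketch based on the signature above, not code from the package; the task object, the model name, and a configured API key are assumed.

from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.ml_model_list import ModelProviderName


async def run_once(task, prompt_id=None):
    # provider is now a ModelProviderName enum value rather than a free-form string
    adapter = adapter_for_task(
        task,
        model_name="gpt-4o-mini",  # assumed model ID, for illustration only
        provider=ModelProviderName.openai,
        prompt_id=prompt_id,
    )
    # invoke() runs the task against the selected adapter and returns the resulting TaskRun
    return await adapter.invoke("example input")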
kiln_ai/adapters/data_gen/data_gen_task.py
@@ -55,7 +55,7 @@ class DataGenCategoriesTaskInput(BaseModel):
             num_subtopics=num_subtopics,
             human_guidance=human_guidance,
             existing_topics=existing_topics,
-            system_prompt=prompt_builder.build_prompt(),
+            system_prompt=prompt_builder.build_prompt(include_json_instructions=False),
         )
 
 
@@ -132,7 +132,7 @@ class DataGenSampleTaskInput(BaseModel):
             topic=topic,
             num_samples=num_samples,
             human_guidance=human_guidance,
-            system_prompt=prompt_builder.build_prompt(),
+            system_prompt=prompt_builder.build_prompt(include_json_instructions=False),
         )
 
 
@@ -163,7 +163,7 @@ def list_json_schema_for_task(task: Task) -> str:
         "required": ["generated_samples"],
     }
 
-    return json.dumps(top_level_schema)
+    return json.dumps(top_level_schema, ensure_ascii=False)
 
 
 class DataGenSampleTask(Task, parent_of={}):
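The ensure_ascii=False change keeps non-ASCII schema keys human-readable rather than \uXXXX-escaped, which matters when the schema is embedded in a prompt. A quick illustration of the standard-library behavior (not from the diff):

import json

schema = {"properties": {"名字": {"type": "string"}}}
print(json.dumps(schema))                      # {"properties": {"\u540d\u5b57": {"type": "string"}}}
print(json.dumps(schema, ensure_ascii=False))  # {"properties": {"名字": {"type": "string"}}}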
@@ -183,3 +183,21 @@ class DataGenSampleTask(Task, parent_of={}):
             input_json_schema=json.dumps(DataGenSampleTaskInput.model_json_schema()),
             output_json_schema=list_json_schema_for_task(target_task),
         )
+
+
+def wrap_task_with_guidance(original_instruction: str, guidance: str) -> str:
+    """Wrap the original instruction with human guidance.
+
+    Args:
+        original_instruction: The original instruction to wrap
+        guidance: The human guidance to wrap the instruction with
+    """
+    return f"""{original_instruction}
+
+# Special Instructions
+
+The above instructions are the original instructions for this task. For this execution, we've been given additional instructions. Follow both, but prioritize the additional instructions when they conflict. The additional instructions are:
+<additional_instructions>
+{guidance}
+</additional_instructions>
+"""
kiln_ai/adapters/data_gen/test_data_gen_task.py
@@ -180,7 +180,7 @@ def test_data_gen_sample_task_initialization(base_task):
     }
 
 
-def test_list_json_schema_for_task_with_output_schema(base_task):
+def test_list_json_schema_for_task_with_input_schema(base_task):
     # Arrange
     base_task.input_json_schema = json.dumps(
         {
@@ -202,9 +202,29 @@ def test_list_json_schema_for_task_with_output_schema(base_task):
     assert generated_samples_schema["items"]["properties"]["age"]["type"] == "integer"
 
 
-def test_list_json_schema_for_task_without_output_schema(base_task):
+def test_list_json_schema_for_task_with_input_schema_non_ascii(base_task):
     # Arrange
-    base_task.output_json_schema = None
+    base_task.input_json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "名字": {"type": "string"},
+                "年齢": {"type": "integer"},
+            },
+        }
+    )
+
+    # Act
+    schema = list_json_schema_for_task(base_task)
+
+    # Assert
+    assert "名字" in schema
+    assert "年齢" in schema
+
+
+def test_list_json_schema_for_task_without_input_schema(base_task):
+    # Arrange
+    base_task.input_json_schema = None
 
     # Act
     schema = list_json_schema_for_task(base_task)
kiln_ai/adapters/eval/base_eval.py
@@ -0,0 +1,164 @@
+import json
+from abc import abstractmethod
+from typing import Dict
+
+from kiln_ai.adapters.adapter_registry import adapter_for_task
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
+from kiln_ai.datamodel.eval import Eval, EvalConfig, EvalScores
+from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.task import RunConfig, TaskOutputRatingType, TaskRun
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+class BaseEval:
+    """
+    Base class for all evals/evaluators.
+
+    Should be subclassed, and the run_eval method implemented.
+    """
+
+    def __init__(self, eval_config: EvalConfig, run_config: RunConfig | None):
+        self.eval_config = eval_config
+        eval = eval_config.parent_eval()
+        if not eval:
+            raise ValueError("Eval config must have a parent eval")
+        self.eval = eval
+        task = self.eval.parent_task()
+        if not task:
+            raise ValueError("Eval must have a parent task")
+        self.target_task = task
+        self.score_schema = BaseEval.build_score_schema(eval, allow_float_scores=True)
+        self.run_config = run_config
+
+    def model_and_provider(self) -> tuple[str, ModelProviderName]:
+        model_name = self.eval_config.model_name
+        provider = self.eval_config.model_provider
+        if (
+            not model_name
+            or not provider
+            or not isinstance(model_name, str)
+            or not isinstance(provider, str)
+            or provider not in ModelProviderName.__members__
+        ):
+            raise ValueError(
+                "Model name and provider must be set in the eval config model properties"
+            )
+
+        return model_name, ModelProviderName(provider)
+
+    async def run_task_and_eval(
+        self, input: str
+    ) -> tuple[TaskRun, EvalScores, Dict[str, str] | None]:
+        """
+        Runs the task on the provided run_config to generate fresh output, then runs the eval on that output.
+        """
+        if self.run_config is None:
+            raise ValueError("Run config is required for run_task_and_eval")
+
+        run_adapter = adapter_for_task(
+            self.target_task,
+            self.run_config.model_name,
+            ModelProviderName(self.run_config.model_provider_name),
+            base_adapter_config=AdapterConfig(allow_saving=False),
+        )
+
+        # Parse structured input if needed
+        parsed_input = input
+        if self.target_task.output_json_schema is not None:
+            parsed_input = json.loads(input)
+
+        # we don't save by default here. We'll save manually after validating the output
+        run_output = await run_adapter.invoke(parsed_input)
+
+        eval_output, intermediate_outputs = await self.run_eval(run_output)
+        validate_schema(eval_output, self.score_schema)
+
+        return run_output, eval_output, intermediate_outputs
+
+    @abstractmethod
+    async def run_eval(
+        self, task_run: TaskRun
+    ) -> tuple[EvalScores, Dict[str, str] | None]:
+        """
+        Runs the eval on the given task run.
+
+        Returns a dictionary of scores which should conform to the score schema, and a dictionary of intermediate outputs (eval thinking).
+        """
+        pass
+
+    @classmethod
+    def build_score_schema(cls, eval: Eval, allow_float_scores: bool = False) -> str:
+        """
+        Build a JSON schema for the scoring output of the task requirements.
+
+        We allow 2 modes: allow_float_scores=True and allow_float_scores=False.
+
+        allow_float_scores=False is used for the call to the model, and forces the model into selecting discrete rating options (int 1-5, pass-fail, etc).
+        allow_float_scores=True is used for the final score output (for example, after we take a g-eval weighting of the model's logprobs). A pass/fail rating might return 0.75 for likely pass (as opposed to 0.99 for near certain pass), or a 1-5 score might return 3.75.
+        """
+
+        # Note: Python dicts maintain insertion order, which is good as we want the user-defined order, with overall last
+        properties = {}
+        for output_score in eval.output_scores:
+            output_score_json_key = output_score.json_key()
+
+            if len(output_score_json_key) == 0:
+                raise ValueError(
+                    f"Invalid output score name: {output_score.name}. Can not be used as JSON schema key."
+                )
+            property: dict[str, str | int | float | list[str] | list[int]] = {
+                "title": output_score.name,
+            }
+            match output_score.type:
+                case TaskOutputRatingType.five_star:
+                    if allow_float_scores:
+                        property["type"] = "number"
+                        property["minimum"] = 1
+                        property["maximum"] = 5
+                    else:
+                        property["enum"] = [1, 2, 3, 4, 5]
+
+                    property["description"] = (
+                        f"{output_score.instruction}\n\nThe rating should be between 1 and 5, with 1 being the worst and 5 being the best."
+                    )
+                case TaskOutputRatingType.pass_fail:
+                    if allow_float_scores:
+                        property["type"] = "number"
+                        property["minimum"] = 0
+                        property["maximum"] = 1
+                        property["description"] = (
+                            f"{output_score.instruction}\n\nThe rating should be between 0 and 1, with 0 being a failure and 1 being a pass."
+                        )
+                    else:
+                        property["enum"] = ["pass", "fail"]
+                        property["description"] = (
+                            f"{output_score.instruction}\n\nThe rating should be either 'pass' or 'fail'."
+                        )
+                case TaskOutputRatingType.pass_fail_critical:
+                    if allow_float_scores:
+                        property["type"] = "number"
+                        property["minimum"] = -1
+                        property["maximum"] = 1
+                        property["description"] = (
+                            f"{output_score.instruction}\n\nThe rating should be between -1 and 1, with 1 being a pass, 0 being a failure, and -1 being a critical failure (very severe failure)."
+                        )
+                    else:
+                        property["enum"] = ["pass", "fail", "critical"]
+                        property["description"] = (
+                            f"{output_score.instruction}\n\nThe rating should be either 'pass', 'fail', or 'critical', where critical is a very severe failure."
+                        )
+                case TaskOutputRatingType.custom:
+                    # Skip custom rating types in evals
+                    continue
+                case _:
+                    raise_exhaustive_enum_error(output_score.type)
+
+            properties[output_score_json_key] = property
+
+        schema = {
+            "type": "object",
+            "properties": properties,
+            "required": list(properties.keys()),
+        }
+        return json.dumps(schema, ensure_ascii=False)
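To make the two schema modes concrete: for an eval with a single five-star score named "Overall Rating" and instruction "Rate the overall quality", build_score_schema would emit roughly the following. This is illustrative, derived by reading the code above rather than running it, and it assumes json_key() snake-cases the name to overall_rating.

# allow_float_scores=True: used for final scores (e.g. after a G-Eval weighting of logprobs)
{
    "type": "object",
    "properties": {
        "overall_rating": {
            "title": "Overall Rating",
            "type": "number",
            "minimum": 1,
            "maximum": 5,
            "description": "Rate the overall quality\n\nThe rating should be between 1 and 5, ...",
        },
    },
    "required": ["overall_rating"],
}
# allow_float_scores=False: the judge model is instead constrained with "enum": [1, 2, 3, 4, 5]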
kiln_ai/adapters/eval/eval_runner.py
@@ -0,0 +1,267 @@
+import asyncio
+import logging
+from dataclasses import dataclass
+from typing import AsyncGenerator, Dict, List, Literal, Set
+
+from kiln_ai.adapters.eval.base_eval import BaseEval
+from kiln_ai.adapters.eval.registry import eval_adapter_from_type
+from kiln_ai.datamodel.basemodel import ID_TYPE
+from kiln_ai.datamodel.dataset_filters import dataset_filter_from_id
+from kiln_ai.datamodel.eval import EvalConfig, EvalRun, EvalScores
+from kiln_ai.datamodel.task import TaskRunConfig
+from kiln_ai.datamodel.task_run import TaskRun
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EvalJob:
+    item: TaskRun
+    type: Literal["task_run_eval", "eval_config_eval"]
+    # If type == "task_run_eval", both of these should be set. If type == "eval_config_eval", only eval_config should be set.
+    eval_config: EvalConfig
+    task_run_config: TaskRunConfig | None = None
+
+
+@dataclass
+class EvalProgress:
+    complete: int | None = None
+    total: int | None = None
+    errors: int | None = None
+
+
+class EvalRunner:
+    """
+    Runs an eval. Async execution is supported to make it faster when using remote/fast model providers.
+
+    Can run an eval in 2 modes:
+    1) eval_config_eval: evaluate an eval config using existing dataset items.
+    2) task_run_eval: evaluate a range of task run configs, generating new run output using existing dataset item input.
+    """
+
+    def __init__(
+        self,
+        eval_configs: List[EvalConfig],
+        run_configs: List[TaskRunConfig] | None,
+        eval_run_type: Literal["eval_config_eval", "task_run_eval"],
+    ):
+        if len(eval_configs) == 0:
+            raise ValueError("Eval runner requires at least one eval config")
+        target_eval = eval_configs[0].parent_eval()
+        if target_eval is None:
+            raise ValueError("Eval config requires a parent eval")
+        for eval_config in eval_configs:
+            parent_eval = eval_config.parent_eval()
+            if parent_eval is None:
+                raise ValueError("Eval config requires a parent eval")
+            if parent_eval.id != target_eval.id:
+                raise ValueError("All eval configs must have the same parent eval")
+
+        target_task = target_eval.parent_task()
+        if target_task is None:
+            raise ValueError("Eval config requires a (grand)parent task")
+
+        # Check that run_configs is compatible
+        if eval_run_type == "task_run_eval":
+            if run_configs is None or len(run_configs) == 0:
+                raise ValueError("Task run eval requires run configs")
+            for run_config in run_configs:
+                parent_task = run_config.parent_task()
+                if parent_task is None:
+                    raise ValueError("All run configs must have a parent task")
+                if parent_task.id != target_task.id:
+                    raise ValueError(
+                        "Run config is not for the same task as the eval configs"
+                    )
+        else:
+            if run_configs is not None:
+                raise ValueError("Mode 'eval_config_eval' does not support run configs")
+
+        self.eval_run_type = eval_run_type
+        self.eval_configs = eval_configs
+        self.run_configs = run_configs
+        self.task = target_task
+        self.eval = target_eval
+
+    def collect_tasks(self) -> List[EvalJob]:
+        if self.eval_run_type == "eval_config_eval":
+            return self.collect_tasks_for_eval_config_eval()
+        else:
+            return self.collect_tasks_for_task_run_eval()
+
+    def collect_tasks_for_eval_config_eval(self) -> List[EvalJob]:
+        """
+        Collect all jobs for this run, excluding any that have already been run.
+
+        This variant is used for mode "eval_config_eval", using existing dataset run data (input/output).
+
+        The tasks:
+        - should be in the eval config set filter
+        - should not have already been run for this eval config + dataset item pair
+        """
+        filter = dataset_filter_from_id(self.eval.eval_configs_filter_id)
+
+        # already_run[eval_config_id][dataset_id]
+        already_run: Dict[ID_TYPE, Set[ID_TYPE]] = {}
+        for eval_config in self.eval_configs:
+            already_run[eval_config.id] = set()
+            for run in eval_config.runs(readonly=True):
+                already_run[eval_config.id].add(run.dataset_id)
+
+        return [
+            EvalJob(
+                item=task_run,
+                eval_config=eval_config,
+                type="eval_config_eval",
+            )
+            for task_run in self.task.runs(readonly=True)
+            if filter(task_run)
+            for eval_config in self.eval_configs
+            if task_run.id not in already_run[eval_config.id]
+        ]
+
+    def collect_tasks_for_task_run_eval(self) -> List[EvalJob]:
+        """
+        Collect all jobs for this run, excluding any that have already been run.
+
+        This variant is used for mode "task_run_eval", generating new run output using existing dataset item input.
+
+        The tasks:
+        - should be in the eval set filter
+        - should not have already been run for this eval config + run config + dataset item
+        """
+        filter = dataset_filter_from_id(self.eval.eval_set_filter_id)
+
+        # already_run[eval_config_id][run_config_id][dataset_id]
+        already_run: Dict[ID_TYPE, Dict[ID_TYPE, Set[ID_TYPE]]] = {}
+        for eval_config in self.eval_configs:
+            already_run[eval_config.id] = {}
+            for run_config in self.run_configs or []:
+                already_run[eval_config.id][run_config.id] = set()
+            for run in eval_config.runs(readonly=True):
+                if run.task_run_config_id is not None:
+                    already_run[eval_config.id][run.task_run_config_id].add(
+                        run.dataset_id
+                    )
+
+        return [
+            EvalJob(
+                item=task_run,
+                task_run_config=run_config,
+                type="task_run_eval",
+                eval_config=eval_config,
+            )
+            for task_run in self.task.runs(readonly=True)
+            if filter(task_run)
+            for eval_config in self.eval_configs
+            for run_config in self.run_configs or []
+            if task_run.id not in already_run[eval_config.id][run_config.id]
+        ]
+
+    async def run(self, concurrency: int = 25) -> AsyncGenerator[EvalProgress, None]:
+        """
+        Runs the configured eval run with parallel workers and yields progress updates.
+        """
+        jobs = self.collect_tasks()
+
+        complete = 0
+        errors = 0
+        total = len(jobs)
+
+        # Send initial status
+        yield EvalProgress(complete=complete, total=total, errors=errors)
+
+        worker_queue: asyncio.Queue[EvalJob] = asyncio.Queue()
+        for job in jobs:
+            worker_queue.put_nowait(job)
+
+        # simple status queue to return progress. True=success, False=error
+        status_queue: asyncio.Queue[bool] = asyncio.Queue()
+
+        workers = []
+        for i in range(concurrency):
+            task = asyncio.create_task(self.run_worker(worker_queue, status_queue))
+            workers.append(task)
+
+        # Send status updates until workers are done, and they are all sent
+        while not status_queue.empty() or not all(worker.done() for worker in workers):
+            try:
+                # Use timeout to prevent hanging if all workers complete
+                # between our while condition check and get()
+                success = await asyncio.wait_for(status_queue.get(), timeout=0.1)
+                if success:
+                    complete += 1
+                else:
+                    errors += 1
+
+                yield EvalProgress(complete=complete, total=total, errors=errors)
+            except asyncio.TimeoutError:
+                # Timeout is expected, just continue to recheck worker status
+                # Don't love this but beats sentinels for reliability
+                continue
+
+        # These are redundant, but keeping them will catch async errors
+        await asyncio.gather(*workers)
+        await worker_queue.join()
+
+    async def run_worker(
+        self, worker_queue: asyncio.Queue[EvalJob], status_queue: asyncio.Queue[bool]
+    ):
+        while True:
+            try:
+                job = worker_queue.get_nowait()
+            except asyncio.QueueEmpty:
+                # worker can end when the queue is empty
+                break
+            try:
+                success = await self.run_job(job)
+                await status_queue.put(success)
+            finally:
+                # Always mark the dequeued task as done, even on exceptions
+                worker_queue.task_done()
+
+    async def run_job(self, job: EvalJob) -> bool:
+        try:
+            # Create the evaluator for this eval config/run config pair
+            evaluator = eval_adapter_from_type(job.eval_config.config_type)(
+                job.eval_config,
+                job.task_run_config.run_config() if job.task_run_config else None,
+            )
+            if not isinstance(evaluator, BaseEval):
+                raise ValueError("Not able to create evaluator from eval config")
+
+            task_output: str | None = None
+            scores: EvalScores | None = None
+            intermediate_outputs: Dict[str, str] | None = None
+            if job.type == "eval_config_eval":
+                # Eval config eval, we use the saved input from the task run, not invoking the task again
+                scores, intermediate_outputs = await evaluator.run_eval(job.item)
+                task_output = job.item.output.output
+            else:
+                # Task run eval, we invoke the task again to get a fresh output
+                (
+                    result_task_run,
+                    scores,
+                    intermediate_outputs,
+                ) = await evaluator.run_task_and_eval(job.item.input)
+                task_output = result_task_run.output.output
+
+            # Save the job result
+            eval_run = EvalRun(
+                parent=job.eval_config,
+                task_run_config_id=job.task_run_config.id
+                if job.task_run_config
+                else None,
+                dataset_id=job.item.id,
+                eval_config_eval=job.type == "eval_config_eval",
+                scores=scores,
+                input=job.item.input,
+                output=task_output,
+                intermediate_outputs=intermediate_outputs,
+            )
+            eval_run.save_to_file()
+
+            return True
+        except Exception as e:
+            logger.error(f"Error running eval job for dataset item {job.item.id}: {e}")
+            return False
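A minimal driver sketch for the new runner in "task_run_eval" mode, assuming the eval configs and run configs have already been loaded from a project; run() is an async generator, so progress is consumed with async for:

from kiln_ai.adapters.eval.eval_runner import EvalRunner


async def run_evals(eval_configs, run_configs):
    runner = EvalRunner(
        eval_configs=eval_configs,
        run_configs=run_configs,
        eval_run_type="task_run_eval",
    )
    # Each yielded EvalProgress reflects jobs completed so far and any errors
    async for progress in runner.run(concurrency=10):
        print(f"{progress.complete}/{progress.total} complete, {progress.errors} errors")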