python-flexeval 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexeval/__init__.py +11 -0
- flexeval/__main__.py +11 -0
- flexeval/classes/__init__.py +15 -0
- flexeval/classes/base.py +32 -0
- flexeval/classes/dataset.py +82 -0
- flexeval/classes/eval_runner.py +158 -0
- flexeval/classes/eval_set_run.py +32 -0
- flexeval/classes/message.py +183 -0
- flexeval/classes/metric.py +55 -0
- flexeval/classes/thread.py +79 -0
- flexeval/classes/tool_call.py +51 -0
- flexeval/classes/turn.py +206 -0
- flexeval/cli.py +104 -0
- flexeval/completions.py +147 -0
- flexeval/compute_metrics.py +788 -0
- flexeval/config.yaml +23 -0
- flexeval/configuration/__init__.py +1 -0
- flexeval/configuration/completion_functions.py +231 -0
- flexeval/configuration/evals.yaml +864 -0
- flexeval/configuration/function_metrics.py +650 -0
- flexeval/configuration/rubric_metrics.yaml +194 -0
- flexeval/data_loader.py +513 -0
- flexeval/db_utils.py +38 -0
- flexeval/dependency_graph.py +234 -0
- flexeval/eval_schema.json +256 -0
- flexeval/function_types.py +173 -0
- flexeval/helpers.py +52 -0
- flexeval/io/__init__.py +1 -0
- flexeval/io/parsers/yaml_parser.py +69 -0
- flexeval/log_utils.py +34 -0
- flexeval/metrics/__init__.py +8 -0
- flexeval/metrics/access.py +28 -0
- flexeval/metrics/save.py +39 -0
- flexeval/rubric.py +62 -0
- flexeval/run_utils.py +65 -0
- flexeval/runner.py +132 -0
- flexeval/schema/__init__.py +11 -0
- flexeval/schema/config_schema.py +46 -0
- flexeval/schema/eval_schema.py +163 -0
- flexeval/schema/evalrun_schema.py +97 -0
- flexeval/schema/rubric_schema.py +40 -0
- flexeval/schema/schema_utils.py +26 -0
- python_flexeval-0.1.5.dist-info/METADATA +118 -0
- python_flexeval-0.1.5.dist-info/RECORD +47 -0
- python_flexeval-0.1.5.dist-info/WHEEL +4 -0
- python_flexeval-0.1.5.dist-info/entry_points.txt +2 -0
- python_flexeval-0.1.5.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,788 @@
|
|
|
1
|
+
"""Utilities for computing needed metric computations and actually invoking those computations."""
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import importlib
|
|
5
|
+
import importlib.util
|
|
6
|
+
import inspect
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import string
|
|
10
|
+
import types
|
|
11
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
12
|
+
from typing import Iterable, Union
|
|
13
|
+
|
|
14
|
+
import networkx as nx
|
|
15
|
+
|
|
16
|
+
from flexeval import function_types
|
|
17
|
+
from flexeval.classes.eval_set_run import EvalSetRun
|
|
18
|
+
from flexeval.classes.message import Message
|
|
19
|
+
from flexeval.classes.thread import Thread
|
|
20
|
+
from flexeval.classes.tool_call import ToolCall
|
|
21
|
+
from flexeval.classes.turn import Turn
|
|
22
|
+
from flexeval.configuration import completion_functions, function_metrics
|
|
23
|
+
from flexeval.schema import EvalRun, FunctionsCollection, eval_schema
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ObjectMetric:
    """Book-keeping record for a single (object, metric) pair.

    Instances rely on the default identity-based equality and hashing, so two
    ObjectMetric instances are only "the same" if they are the same Python object.
    """

    def __init__(self, object: Message | Turn | ToolCall | Thread, metric: dict):
        """Remember the object and metric; results start out unset.

        Args:
            object (Message | Turn | ToolCall | Thread): The object the metric applies to.
            metric (dict): The metric definition to track.
        """
        self.object: Message | Turn | ToolCall | Thread = object
        self.metric: dict = metric
        # None means "not computed yet"; a list of result dicts is assigned later.
        self.metric_results: list[dict] | None = None

    def __repr__(self) -> str:
        object_label = f"{self.object.__class__.__name__} {self.object.id}"
        return (
            f"ObjectMetric(object={object_label}, "
            f"metric={self.metric}, metric_results={self.metric_results})"
        )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class MetricGraphBuilder:
    r"""Builds :class:`networkx.DiGraph`\s of :class:`~flexeval.compute_metrics.ObjectMetric` instances that reflect any computational dependencies between them."""

    def __init__(self):
        # key: tuple(metric_level, metric_id, object_id)
        # value: ObjectMetric
        self.id_to_object_metric_map = {}
        # metric id -> metric dict; filled by build_metric_structures()
        self.metric_id_map = {}
        # metric level -> metric dicts at that level (in evalsetrun order);
        # filled by build_metric_structures()
        self.metrics_by_level = {}
        # metric level -> objects of the thread currently being processed;
        # (re)assigned by build_thread_task_graph()
        self.objects_by_level = {}

    def build_metric_structures(self, evalsetrun: EvalSetRun):
        """Index the metrics of ``evalsetrun`` by id and by metric level.

        ``evalsetrun.metrics_graph_ordered_list`` is parsed as JSON; every entry is
        expected to have ``"id"`` and ``"metric_level"`` keys. Per-level lists keep the
        order of that list.
        """
        metric_id_map = {}
        metrics_by_level = {}
        for metric_instance in json.loads(evalsetrun.metrics_graph_ordered_list):
            metrics_by_level.setdefault(metric_instance["metric_level"], []).append(
                metric_instance
            )
            metric_id_map[metric_instance["id"]] = metric_instance
        self.metric_id_map = metric_id_map
        self.metrics_by_level = metrics_by_level

    def get_or_create_object_metric(
        self,
        metric_level: eval_schema.MetricLevel,
        object: Message | Turn | ToolCall | Thread,
        metric: dict,
    ) -> ObjectMetric:
        """Return the unique ObjectMetric for ``(metric_level, metric["id"], object.id)``.

        The same instance is returned on every call with the same key, so it can be
        used as a single graph node.
        """
        key = (metric_level, metric["id"], object.id)
        if key not in self.id_to_object_metric_map:
            self.id_to_object_metric_map[key] = ObjectMetric(object, metric)
        return self.id_to_object_metric_map[key]

    def get_index(
        self, target_id: int, objects: list[Message | Turn | ToolCall | Thread]
    ):
        """Return the position in ``objects`` of the object whose ``id`` is ``target_id``.

        Raises:
            ValueError: if no object in ``objects`` has that id.
        """
        for i, candidate in enumerate(objects):
            if target_id == candidate.id:
                return i
        raise ValueError(
            f"Failed to find object with id '{target_id}' in '{len(objects)}' objects."
        )

    def find_object_metric_from_depends_on(
        self,
        current_object: Message | Turn | ToolCall | Thread,
        current_metric_level: eval_schema.MetricLevel,
        current_index: int,
        depends_on: dict,
    ) -> ObjectMetric | None:
        """Resolve one ``depends_on`` entry of a metric on ``current_object``.

        ``depends_on["parent_id"]`` identifies the metric depended on. Its level comes
        from ``depends_on["metric_level"]`` or, when that is absent/None, from the
        metric's own definition. ``current_index`` (the position of ``current_object``
        at its level) is first translated to the dependency's level: for example, a
        Message metric depending on a Turn metric uses the index of that Message's Turn.
        The result is then offset by ``depends_on["relative_object_position"]``.

        Supported combinations: same level; ToolCall -> Message/Turn/Thread;
        Message -> Turn/Thread; Turn -> Thread. A Turn metric depending on a Message
        metric would have to depend on ALL or ANY of the Turn's messages, which is
        not handled, so that case (and the other unsupported ones) raises.

        Returns:
            The dependency's ObjectMetric, or ``None`` if the offset points outside the
            available objects (the dependency can never be satisfied for this object).

        Raises:
            ValueError: for an unsupported level combination or a metric lacking a level.
        """
        metric_id = depends_on["parent_id"]
        dependency_metric_level = depends_on.get("metric_level")
        if dependency_metric_level is None:
            # if not specified in the dependency already, look up the metric level
            depends_on_metric = self.metric_id_map[metric_id]
            dependency_metric_level = depends_on_metric["metric_level"]
            if dependency_metric_level is None:
                raise ValueError(
                    f"Metric lacks a metric level: {depends_on_metric} (matched via dependency_info: {depends_on})"
                )

        if dependency_metric_level == current_metric_level:
            pass  # just use current_index, no lookup needed
        elif current_metric_level == "ToolCall":
            if dependency_metric_level == "Message":
                current_index = self.get_index(
                    current_object.message_id, self.objects_by_level["Message"]
                )
            elif dependency_metric_level == "Turn":
                current_index = self.get_index(
                    current_object.turn_id, self.objects_by_level["Turn"]
                )
            elif dependency_metric_level == "Thread":
                current_index = 0  # only a single thread, by definition
        elif current_metric_level == "Message":
            if dependency_metric_level == "Turn":
                current_index = self.get_index(
                    current_object.turn_id, self.objects_by_level["Turn"]
                )
            elif dependency_metric_level == "Thread":
                current_index = 0  # only a single thread, by definition
            elif dependency_metric_level == "ToolCall":
                raise ValueError(
                    f"Can't depend on a '{dependency_metric_level}' metric from a '{current_metric_level}' metric."
                )
        elif current_metric_level == "Turn":
            if dependency_metric_level == "Thread":
                current_index = 0  # only a single thread, by definition
            else:
                raise ValueError(
                    f"Can't depend on a '{dependency_metric_level}' metric from a '{current_metric_level}' metric."
                )
        elif current_metric_level == "Thread":
            raise ValueError(
                f"Can't depend on a '{dependency_metric_level}' metric from a '{current_metric_level}' metric."
            )
        else:
            raise ValueError(f"Unsupported level: {current_metric_level=}")

        relative_object_position = depends_on["relative_object_position"]
        target_object_index = current_index + relative_object_position
        # Out of range on either side: there is no such object, so the dependency can
        # never be met. (Checking the upper bound avoids an IndexError for offsets that
        # point past the last object.)
        if target_object_index < 0 or target_object_index >= len(
            self.objects_by_level[dependency_metric_level]
        ):
            logger.debug(
                f"Object at position '{current_index}' object cannot in principle satisfy this dependency, so skipping it."
            )
            return None
        object = self.objects_by_level[dependency_metric_level][target_object_index]
        metric = self.metric_id_map[metric_id]
        return self.get_or_create_object_metric(dependency_metric_level, object, metric)

    def build_thread_task_graphs(self, evalsetrun: EvalSetRun) -> Iterable[nx.DiGraph]:
        """Lazily yield one dependency graph per thread of ``evalsetrun``.

        Each graph is built on demand, and building one overwrites
        ``self.objects_by_level``.
        """
        for thread in evalsetrun.threads:
            yield self.build_thread_task_graph(thread)

    def build_thread_task_graph(self, thread: Thread) -> nx.DiGraph:
        """Build the ObjectMetric dependency graph for a single thread.

        Every (object, metric) pair at every level with metrics becomes a node. Each
        ``depends_on`` entry becomes an edge from the dependency's ObjectMetric to the
        dependent one, with the entry stored in the edge's ``depends_on`` attribute.

        An ObjectMetric with a dependency that can never be satisfied is removed, and
        so is everything downstream of it. This matches what happens at evaluation
        time, where a dependency without results leaves its dependents unmet.
        Dependents discovered after the removal are dropped as well, so a pruned node
        is never re-added without its dependencies.
        """
        self.objects_by_level = {
            "Thread": [thread],
            "Turn": list(thread.turns),
            "Message": list(thread.messages),
            "ToolCall": list(thread.toolcalls),
        }

        g = nx.DiGraph()
        # ObjectMetrics known to be unsatisfiable in this thread.
        unsatisfiable: set = set()
        for level, metrics_at_level in self.metrics_by_level.items():
            if not metrics_at_level:
                continue
            objects = self.objects_by_level[level]
            for i, object in enumerate(objects):
                for metric in metrics_at_level:
                    # register metric on object
                    object_metric = self.get_or_create_object_metric(
                        level, object, metric
                    )
                    if object_metric in unsatisfiable:
                        continue
                    g.add_node(object_metric)
                    for dependency in metric.get("depends_on") or []:
                        # register dependency metric on the relevant object
                        dependency_object_metric = (
                            self.find_object_metric_from_depends_on(
                                object, level, i, dependency
                            )
                        )
                        if (
                            dependency_object_metric is None
                            or dependency_object_metric in unsatisfiable
                        ):
                            logger.debug(
                                "This object cannot in principle satisfy this dependency, so skipping it."
                            )
                            self._prune_unsatisfiable(g, object_metric, unsatisfiable)
                            # stop here: adding further edges would re-insert the pruned node
                            break
                        g.add_node(dependency_object_metric)
                        g.add_edge(
                            dependency_object_metric,
                            object_metric,
                            depends_on=dependency,
                        )
        return g

    @staticmethod
    def _prune_unsatisfiable(
        g: nx.DiGraph, object_metric: ObjectMetric, unsatisfiable: set
    ) -> None:
        """Remove ``object_metric`` and all its descendants from ``g``; record them in ``unsatisfiable``."""
        doomed = {object_metric}
        if object_metric in g:
            doomed |= nx.descendants(g, object_metric)
        unsatisfiable.update(doomed)
        g.remove_nodes_from(doomed)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def compute_metrics(evalrun: EvalRun, evalsetrun: EvalSetRun) -> list[dict]:
    """Compute all configured metrics for every thread of ``evalsetrun``.

    One dependency graph is built per thread, and each graph is processed by a
    MetricComputer configured from ``evalrun``. When
    ``evalrun.config.max_workers == 1`` the graphs are processed sequentially.
    Otherwise they are processed in a ThreadPoolExecutor with that many workers.
    In both modes ``evalrun.config.raise_on_metric_error`` decides whether a
    metric error is re-raised or only logged.

    Returns:
        list[dict]: all metric results, concatenated in thread order.
    """
    n_workers = evalrun.config.max_workers
    raise_on_error = evalrun.config.raise_on_metric_error
    mgb = MetricGraphBuilder()
    mgb.build_metric_structures(evalsetrun)
    graphs = mgb.build_thread_task_graphs(evalsetrun)
    mc = MetricComputer.from_evalrun(evalrun, evalsetrun)
    metrics = []
    if n_workers == 1:
        for graph in graphs:
            graph_metrics = mc.process_thread_dependency_graph(graph, raise_on_error)
            metrics.extend(graph_metrics)
    else:
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            # raise_on_error must be forwarded here too; otherwise the callee's default
            # (True) silently overrides the configured behavior in parallel mode.
            futures = [
                executor.submit(
                    mc.process_thread_dependency_graph, graph, raise_on_error
                )
                for graph in graphs
            ]
            n_futures = len(futures)
            # results are collected in submission order to keep thread order stable
            for i, future in enumerate(futures):
                metrics.extend(future.result())
                if i % 100 == 0:
                    logger.info(f"Metrics futures resulted: {i + 1} / {n_futures}")
    return metrics
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class MetricComputer:
|
|
237
|
+
@classmethod
def from_evalrun(
    cls, evalrun: EvalRun, evalsetrun: EvalSetRun | None = None
) -> "MetricComputer":
    """Build a MetricComputer for ``evalrun`` and validate the metrics it configures.

    Each entry of ``evalrun.function_modules`` may be an already-imported module, or a
    value whose ``str()`` is an importable module name or a path to a Python file.
    The package's default ``function_metrics`` module is appended when
    ``evalrun.add_default_functions`` is set and it is not already present.

    Raises:
        ValueError: if a function module cannot be loaded, a FunctionsCollection is
            given (not yet implemented), or any configured function/rubric metric
            cannot be found.
    """
    # convert from string module names or filepaths to Python modules
    actual_modules = []
    for i, function_module in enumerate(evalrun.function_modules):
        if isinstance(function_module, types.ModuleType):
            # already a module
            actual_modules.append(function_module)
            continue
        if isinstance(function_module, FunctionsCollection):
            raise ValueError("FunctionsCollection not yet implemented!")

        # either an importable module name or a filepath
        module_name = str(function_module)
        try:
            if module_name.startswith("."):
                # import_module() treats a leading "." as a relative import and raises
                # TypeError without a package; such values (e.g. "../metrics.py") are
                # file paths, so route them to the file-location fallback below.
                raise ModuleNotFoundError(
                    f"'{module_name}' is not an absolute module name"
                )
            module = importlib.import_module(module_name)
        except ModuleNotFoundError as module_not_found:
            try:
                spec = importlib.util.spec_from_file_location(
                    f"function_module_{i}", function_module
                )
                if spec is None or spec.loader is None:
                    raise ImportError(
                        f"could not create a module spec for '{function_module}'"
                    )
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
            except Exception as module_not_loaded:
                raise ValueError(
                    f"Failed to load function module specified by '{function_module}.' (module not found: {module_not_found}, and failed to load from file location: {module_not_loaded})"
                ) from module_not_loaded
        actual_modules.append(module)
    if evalrun.add_default_functions and function_metrics not in actual_modules:
        actual_modules.append(function_metrics)
    mc = cls(actual_modules, evalsetrun)

    # validation step: verify that all functions are present
    missing_functions = set()
    if evalrun.eval.metrics.function is not None:
        for function_item in evalrun.eval.metrics.function:
            try:
                mc.find_function(function_item.name)
            except ValueError:
                missing_functions.add(function_item.name)
    if missing_functions:
        raise ValueError(
            f"Failed to find '{len(missing_functions)}' functions in the provided function module. Missing function names: {', '.join(sorted(missing_functions))}"
        )

    # validation step: verify that all rubrics are present
    # (only possible when rubrics were loaded, i.e. an evalsetrun was given)
    missing_rubrics = set()
    if mc.rubrics is not None and evalrun.eval.metrics.rubric is not None:
        for rubric_item in evalrun.eval.metrics.rubric:
            if rubric_item.name not in mc.rubrics:
                missing_rubrics.add(rubric_item.name)
    if missing_rubrics:
        raise ValueError(
            f"Failed to find '{len(missing_rubrics)}' rubrics in the provided rubric set. Missing rubric names: {', '.join(sorted(missing_rubrics))}"
        )
    return mc
|
|
293
|
+
|
|
294
|
+
def __init__(self, function_modules: list, evalsetrun: EvalSetRun | None = None):
|
|
295
|
+
self.function_modules: list = function_modules
|
|
296
|
+
self.rubrics: dict | None = (
|
|
297
|
+
self.load_rubrics(evalsetrun) if evalsetrun is not None else None
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
def load_rubrics(self, evalsetrun: EvalSetRun):
|
|
301
|
+
"""Set the rubrics to be used by this MetricComputer from the given EvalSetRun."""
|
|
302
|
+
self.rubrics = json.loads(evalsetrun.rubrics)
|
|
303
|
+
|
|
304
|
+
def process_thread_dependency_graphs(
    self, graph_list: Iterable[nx.DiGraph], raise_on_error: bool = True
) -> list[dict]:
    """Process several thread dependency graphs sequentially.

    Args:
        graph_list: the graphs to process, in order.
        raise_on_error: forwarded to :meth:`process_thread_dependency_graph`; when
            False, errors are logged and the partial results of that graph are kept.
            Defaults to True, matching that method's default.

    Returns:
        list[dict]: the metric results of all graphs, concatenated in order.
    """
    evaluated_metrics = []
    for g in graph_list:
        evaluated_metrics.extend(
            self.process_thread_dependency_graph(g, raise_on_error)
        )
    return evaluated_metrics
|
|
311
|
+
|
|
312
|
+
def process_thread_dependency_graph(
    self, g: nx.DiGraph, raise_on_error: bool = True
) -> list[dict]:
    """Evaluate every ObjectMetric of one thread's graph in topological order.

    For each node, every incoming edge's ``depends_on`` info is checked against the
    predecessor's ``metric_results``. The check works as follows:

    * If the info names a ``metric_name`` that differs from the predecessor's
      ``evaluation_name``, the first result with that name must have a value in
      ``[metric_min_value, metric_max_value]``.
    * Otherwise the predecessor must have exactly one result, and its value must be
      in range. No results counts as unmet. More than one result raises.

    Nodes whose dependencies are all met are computed with :meth:`compute_metric`.
    Nodes with an unmet dependency get ``metric_results = []``.

    Args:
        g: the dependency graph of ObjectMetric nodes.
        raise_on_error: if True, any exception is logged and re-raised; otherwise it is
            logged and the results gathered so far are returned.

    Returns:
        list[dict]: the results of all computed metrics.
    """
    computed = []

    def _within_bounds(result: dict, info: dict) -> bool:
        value = result["metric_value"]
        return (value >= info["metric_min_value"]) and (
            value <= info["metric_max_value"]
        )

    try:
        for node in nx.topological_sort(g):
            satisfied = True
            for parent in g.predecessors(node):
                if parent.metric_results is None:
                    raise ValueError(
                        f"FlexEval error: expected metric_result for dependency '{parent.metric['evaluation_name']}' to be computed before processing metric '{node.metric['evaluation_name']}'."
                    )
                info = g.get_edge_data(parent, node)["depends_on"]
                met = False
                wanted = info.get("metric_name")
                parent_results = parent.metric_results
                if wanted is not None and wanted != parent.metric["evaluation_name"]:
                    # a specific result key was requested: use the first result with that name
                    match = next(
                        (r for r in parent_results if wanted == r["metric_name"]),
                        None,
                    )
                    if match is None:
                        logger.debug(
                            f"Key '{wanted}' not found in results for dependency '{parent.metric['evaluation_name']}'."
                        )
                    else:
                        met = _within_bounds(match, info)
                elif len(parent_results) == 1:
                    met = _within_bounds(parent_results[0], info)
                elif len(parent_results) == 0:
                    logger.debug(
                        f"Skipping metric because dependency '{parent.metric['evaluation_name']}' has no results."
                    )
                else:
                    raise ValueError(
                        f"Not sure how to evaluate dependency '{parent.metric['evaluation_name']}' for metric '{node.metric['evaluation_name']}', as it has {len(parent_results)} results but no specified key."
                    )
                if not met:
                    satisfied = False
                    logger.debug(
                        f"Value for metric '{parent.metric['evaluation_name']}' not in range for dependency {info}."
                    )
                    break

            if not satisfied:
                # no results for this metric, as dependencies were unmet
                node.metric_results = []
                continue
            # TODO in the future, we could pass some metric_results as kwargs to the metric function
            # or as a special formatting key to the rubric
            node_results = self.compute_metric(node.object, **node.metric)
            node.metric_results = node_results
            computed.extend(node_results)
        self._validate_metrics(computed)
    except Exception as ex:
        logger.exception(f"An error occurred during metric processing: {ex}")
        if raise_on_error:
            raise
    return computed
|
|
392
|
+
|
|
393
|
+
def compute_metrics(self, object: Union[Thread, Turn, Message, ToolCall]):
    """Evaluate the metrics listed in ``object.metrics_to_evaluate``, in order.

    Each entry is a metric dict that is expanded into :meth:`compute_metric` as keyword
    arguments, so it looks like::

        {
            'evaluation_name': 'string_length',
            'evaluation_type': 'function',
            'metric_level': ...,
            'kwargs': {},
            'depends_on': [],
            'id': ...,
        }

    A metric with a non-empty ``depends_on`` list is only computed when ALL of its
    dependencies are met. A dependency is met when some result computed earlier in this
    call has ``id == dependency["parent_id"]`` and a ``metric_value`` within
    ``[dependency["metric_min_value"], dependency["metric_max_value"]]``. The
    dependency's ``metric_name`` is not checked here. Only earlier results are searched,
    so dependencies must precede their dependents in ``metrics_to_evaluate``.

    Returns:
        list[dict]: the concatenated results of every metric that was computed.
    """
    evaluated_metrics = []
    # METRICS IN ORDER
    for metric_to_evaluate in object.metrics_to_evaluate:
        # a missing, None or empty depends_on means "no dependencies"
        dependencies = metric_to_evaluate.get("depends_on") or []
        dependencies_are_all_met = True
        for dependency in dependencies:
            # assume unmet until a matching, in-range result is found
            dependency_is_met = False
            for em in evaluated_metrics:
                if em["id"] != dependency["parent_id"]:
                    continue
                if (
                    em["metric_value"] >= dependency["metric_min_value"]
                    and em["metric_value"] <= dependency["metric_max_value"]
                ):
                    # this specific dependency was met - can quit looking
                    dependency_is_met = True
                    break
                logger.debug(
                    f"Metric value '{em['metric_value']}' not in range for dependency id='{dependency['parent_id']}'."
                )
            if not dependency_is_met:
                # if even one dependency is not met - don't do the evaluation
                dependencies_are_all_met = False
                break
        if dependencies_are_all_met:
            # TODO - maybe in the future we'll want to add the computed value from
            # the dependency through as an argument here
            metric_results = self.compute_metric(object, **metric_to_evaluate)
            evaluated_metrics.extend(metric_results)
        else:
            # describe the skipped metric itself (not a previously evaluated result)
            logger.debug(
                f"Skipping metric '{metric_to_evaluate.get('evaluation_name')}' (id='{metric_to_evaluate.get('id')}') due to unmet dependencies."
            )
    return evaluated_metrics
|
|
462
|
+
|
|
463
|
+
def compute_metric(
|
|
464
|
+
self,
|
|
465
|
+
object: Union[Thread, Turn, Message, ToolCall],
|
|
466
|
+
evaluation_name: str,
|
|
467
|
+
evaluation_type: str,
|
|
468
|
+
metric_level: str,
|
|
469
|
+
kwargs: dict,
|
|
470
|
+
context_only: bool = None,
|
|
471
|
+
depends_on: list = None,
|
|
472
|
+
id: int = None,
|
|
473
|
+
notes: str = None, # just a placeholder
|
|
474
|
+
) -> list[dict]:
|
|
475
|
+
if evaluation_type == "function":
|
|
476
|
+
metrics = self.compute_function_metric(
|
|
477
|
+
function_name=evaluation_name,
|
|
478
|
+
metric_kwargs=kwargs,
|
|
479
|
+
metric_level=metric_level,
|
|
480
|
+
context_only=context_only,
|
|
481
|
+
input_object=object,
|
|
482
|
+
depends_on=depends_on,
|
|
483
|
+
id=id,
|
|
484
|
+
)
|
|
485
|
+
elif evaluation_type == "rubric":
|
|
486
|
+
metrics = self.compute_rubric_metric(
|
|
487
|
+
rubric_name=evaluation_name,
|
|
488
|
+
metric_kwargs=kwargs,
|
|
489
|
+
metric_level=metric_level,
|
|
490
|
+
object=object,
|
|
491
|
+
depends_on=depends_on,
|
|
492
|
+
id=id,
|
|
493
|
+
)
|
|
494
|
+
else:
|
|
495
|
+
raise ValueError(
|
|
496
|
+
f"The argument evaluation_type provided to compute_metric is invalid. Must be one of 'function' or 'rubric'. You passed '{type}'."
|
|
497
|
+
)
|
|
498
|
+
self._validate_metrics(metrics)
|
|
499
|
+
return metrics
|
|
500
|
+
|
|
501
|
+
def _validate_metrics(self, metrics: list[dict]):
|
|
502
|
+
for m in metrics:
|
|
503
|
+
if m.get("evaluation_type", None) is None:
|
|
504
|
+
raise ValueError(
|
|
505
|
+
f"Metric '{m}' does not have a value for the key `type`."
|
|
506
|
+
)
|
|
507
|
+
if m.get("metric_value", None) is None:
|
|
508
|
+
raise ValueError(
|
|
509
|
+
f"Metric '{m}' does not have a value for the key `metric_value`."
|
|
510
|
+
)
|
|
511
|
+
|
|
512
|
+
def invoke_function(
    self,
    metric_function: callable,
    metric_level: eval_schema.MetricLevel,
    input_object: function_types.AnyFunctionObjectInput,
    metric_kwargs: dict,
    context_only: bool,
):
    """Call ``metric_function`` with its prepared input plus ``metric_kwargs``.

    The positional argument comes from ``function_types.get_function_input``. That
    helper receives the function, the metric level, the object and ``context_only``.
    Whatever the metric function returns is passed back unchanged.
    """
    prepared_input = function_types.get_function_input(
        metric_function, metric_level, input_object, context_only
    )
    return metric_function(prepared_input, **metric_kwargs)
|
|
525
|
+
|
|
526
|
+
def find_function(self, function_name: str):
|
|
527
|
+
for function_module in self.function_modules:
|
|
528
|
+
if hasattr(function_module, function_name) and callable(
|
|
529
|
+
getattr(function_module, function_name)
|
|
530
|
+
):
|
|
531
|
+
metric_function = getattr(function_module, function_name)
|
|
532
|
+
metric_source = inspect.getsource(metric_function)
|
|
533
|
+
return metric_function, metric_source
|
|
534
|
+
raise ValueError(
|
|
535
|
+
f"Metric function with name '{function_name}' was not found in any of the '{len(self.function_modules)}' provided function modules."
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
def compute_function_metric(
    self,
    function_name: str,
    metric_kwargs: dict,
    input_object: Union[Thread, Turn, Message, ToolCall],
    metric_level: eval_schema.MetricLevel,
    context_only: bool,
    depends_on: list,
    id: int,
):
    """Run metric function ``function_name`` on ``input_object`` and normalize its output.

    The metric function may return:

    * an ``int``/``float``: one result named ``function_name`` with that value;
    * a ``dict``: one result per ``(name, value)`` item, values coerced with ``float``;
    * a ``list`` of ``{"name": ..., "value": ...}`` dicts: one result per entry,
      values coerced with ``float``.

    Every result also carries several shared fields:

    * the input object, under the lower-cased ``metric_level`` key;
    * the function name and the ``"function"`` evaluation type;
    * the metric level and kwargs;
    * the function's source;
    * ``context_only``, ``depends_on`` and the metric ``id``.

    Returns:
        list[dict]: one dict per reported metric value.

    Raises:
        ValueError: if the function cannot be found, returns an unsupported type, or a
            returned list entry has no ``"value"``.
    """
    # Check if the function exists in any of the function namespaces
    metric_function, metric_source = self.find_function(function_name)
    metrics_result = self.invoke_function(
        metric_function, metric_level, input_object, metric_kwargs, context_only
    )

    base_result = {
        metric_level.lower(): input_object,
        "evaluation_name": function_name,
        "evaluation_type": "function",
        "metric_level": metric_level,
        "kwargs": metric_kwargs,
        "source": metric_source,  # TODO - put this back?
        "context_only": context_only,
        "depends_on": depends_on,
        "id": id,
    }

    def new_result(metric_name, metric_value) -> dict:
        # Give each result its own copies of the mutable kwargs/depends_on containers,
        # but keep a reference to input_object itself: deep-copying the evaluated
        # object for every single result is costly and not needed.
        result = dict(base_result)
        result["kwargs"] = copy.deepcopy(metric_kwargs)
        result["depends_on"] = copy.deepcopy(depends_on)
        result["metric_name"] = metric_name
        result["metric_value"] = metric_value
        return result

    # now deal with output
    if isinstance(metrics_result, (int, float)):
        return [new_result(function_name, metrics_result)]
    if isinstance(metrics_result, dict):
        # TODO rethink this behavior
        return [new_result(k, float(v)) for k, v in metrics_result.items()]
    if isinstance(metrics_result, list):
        result_list = []
        for entry in metrics_result:
            value = entry.get("value", None)
            if value is None:
                raise ValueError(
                    f"Entry '{entry}' returned by metric function '{function_name}' has no 'value'."
                )
            result_list.append(new_result(entry.get("name", None), float(value)))
        return result_list
    raise ValueError(
        f"The metric type returned from '{metric_function}' is not a supported type. It must be one of `list`, `int`, `float`, or `dict`. You supplied '{type(metrics_result)}'."
    )
|
|
601
|
+
|
|
602
|
+
def compute_rubric_metric(
    self,
    rubric_name: str,
    metric_kwargs: dict,
    object: Union[Thread, Turn, Message],
    metric_level: str,
    depends_on: list,
    id: int,
):
    """
    Grade an evaluable object with an LLM-judged rubric.

    The rubric is taken from ``self.rubrics`` when that is set, otherwise from
    the JSON-encoded ``object.evalsetrun.rubrics``. Its ``prompt`` template is
    filled from ``object.format_input_for_rubric()`` and sent to the grader
    completion function named in ``object.evalsetrun.grader_llm``. The
    grader's reply is then matched against the rubric's ``choice_scores`` keys
    to produce the metric value.

    Supported prompt placeholders:
      {conversation} -- the whole conversation
      {context}      -- the conversation without the current entry
      {content}      -- only the current turn / message / tool call
      {tool_calls}   -- tool calls, as returned by format_input_for_rubric()
      {completion}   -- the current entry; requires do_completion to be true

    :param rubric_name: name of the rubric to apply
    :param metric_kwargs: recorded unchanged under "kwargs" in the result
    :param object: the Thread, Turn, or Message being graded
    :param metric_level: level name; its lower-cased form is the result key
        that holds ``object``
    :param depends_on: recorded unchanged under "depends_on" in the result
    :param id: recorded unchanged under "id" in the result
    :returns: a one-element list containing the result dict
    :raises ValueError: if the rubric is unknown or defines no choice_scores,
        if no grader LLM is configured, or if the grader completion function
        cannot be found in ``completion_functions``
    :raises Exception: if the prompt combines placeholders that may not be
        used together, or uses {completion} while do_completion is not true
    """
    if self.rubrics is not None:
        rubrics = self.rubrics
    else:
        rubrics = json.loads(object.evalsetrun.rubrics)
    if rubric_name not in rubrics:
        raise ValueError(
            f"You requested a rubric called '{rubric_name}', but only these were found: {rubrics.keys()}."
        )
    rubric = rubrics.get(rubric_name)
    prompt = rubric.get("prompt", "")

    # Work on a copy of choice_scores. The "__invalid__" sentinel is added
    # below, and `rubric` may live inside self.rubrics, which is reused across
    # calls. Mutating it in place would make "__invalid__" show up as a
    # choice in every later grader prompt built from this rubric.
    choice_scores = dict(rubric.get("choice_scores") or {})
    if not choice_scores:
        raise ValueError(
            f"The rubric '{rubric_name}' does not define any choice_scores."
        )

    # conversation: all entries; context: all entries except the current one;
    # content: only the current entry (turn / message / tool call, depending
    # on metric_level).
    conversation, context, content, tool_calls = object.format_input_for_rubric()

    # Verify placeholder combinations before populating the rubric:
    #   case 1: {conversation} and {context} must not appear together
    #   case 2: {completion} and {content} must not appear together
    #   case 3: {completion} requires do_completion to be true
    if "{conversation}" in prompt and "{context}" in prompt:
        raise Exception(
            "Your rubric should not have both {conversation} and {context}. Please check the README file for more information about how to write FlexEval rubrics."
        )
    if "{completion}" in prompt and "{content}" in prompt:
        raise Exception(
            "Your rubric should not have both {content} and {completion}. Please check the README file for more information about how to write FlexEval rubrics."
        )
    if "{completion}" in prompt and not object.evalsetrun.do_completion:
        raise Exception(
            "Your rubric has {completion}, but in your test specification for this rubric evaluation, do_completion is not True. Please check the README file for more information about how to write FlexEval rubrics."
        )

    # str.format raises KeyError for any placeholder that has no value and
    # ignores unused keyword arguments, so every supported placeholder is
    # filled in one call. Omitting `completion` would crash every prompt that
    # uses {completion}. {completion} only gets here when do_completion is true
    # (case 3 above); it is filled with the current entry's content, with or
    # without {context}/{conversation} alongside it.
    # TODO: object.is_completion only works for Message rubrics. In principle a
    # Turn or Thread could be checked for a message with
    # is_flexeval_completion set.
    populated_prompt = prompt.format(
        conversation=conversation,
        context=context,
        content=content,
        tool_calls=tool_calls,
        completion=content,
    )

    # Resolve the grader completion function from the eval set run config.
    grader_llm = object.evalsetrun.grader_llm
    if grader_llm is None or grader_llm == "":
        raise ValueError(
            "Attempting to evaluate a rubric metric, but no grader LLM defined."
        )
    grader_completion_function = json.loads(grader_llm)
    if grader_completion_function is None or len(grader_completion_function) == 0:
        raise ValueError(
            "Attempting to evaluate a rubric metric, but no grader LLM defined."
        )
    grader_completion_fn_name = grader_completion_function.get("function_name", None)
    grader_completion_fn_kwargs = grader_completion_function.get("kwargs") or {}
    completion_function = None
    if isinstance(grader_completion_fn_name, str):
        completion_function = getattr(
            completion_functions, grader_completion_fn_name, None
        )
    if completion_function is None:
        # Without this check a missing function surfaced later as an
        # UnboundLocalError on `completion_function`.
        raise ValueError(
            f"Grader completion function '{grader_completion_fn_name}' was not found in completion_functions."
        )

    DEFAULT_COT_TEMPLATE = "\nBefore answering, I will reason in a step-by-step manner as to get the right answer, then conclude with the answer in the format requested."
    # Built before "__invalid__" is added, so only real choices are offered.
    ANSWER_PROMPT = (
        "First, write out in a step by step manner your reasoning to be sure "
        "that your conclusion is correct. Avoid simply stating the correct "
        "answer at the outset. Then print only a single choice from "
        f"{list(choice_scores.keys())} (without quotes or punctuation) on its "
        "own line corresponding to the correct answer. At the end, repeat just "
        "the answer by itself on a new line.\n\nReasoning:"
    )
    # This is the call to the grader completion function.
    completion = completion_function(
        conversation_history=[
            {"role": "system", "content": "You are a helpful assistant."},
            # NOTE(review): this assistant turn precedes the user prompt and
            # duplicates the final one. It is kept so grader inputs stay the
            # same; confirm whether it is intentional.
            {"role": "assistant", "content": DEFAULT_COT_TEMPLATE},
            # A newline separates the rubric prompt from the answer instructions.
            {"role": "user", "content": populated_prompt + "\n" + ANSWER_PROMPT},
            {"role": "assistant", "content": DEFAULT_COT_TEMPLATE},
        ],
        **grader_completion_fn_kwargs,
    )
    completion_text = completion["choices"][0]["message"]["content"]

    # Matching logic follows OpenAI Evals:
    # https://github.com/openai/evals/blob/d3dc89042ddee879a68a326fdb37716ee518640c/evals/elsuite/modelgraded/classify_utils.py#L29
    choice_scores["__invalid__"] = None

    def get_match(completion_text: str, choice_scores: dict) -> str:
        # Scan lines from the end of the reply, because the grader is asked to
        # repeat its answer last. Punctuation is stripped, and the first choice
        # that a non-empty line starts or ends with wins.
        lines = (completion_text or "").strip().split("\n")
        for line in reversed(lines):
            line = "".join(c for c in line.strip() if c not in string.punctuation)
            if not line:
                continue
            for choice in choice_scores:
                if line.startswith(choice) or line.endswith(choice):
                    return choice
        return "__invalid__"

    score = get_match(completion_text=completion_text, choice_scores=choice_scores)
    usage = completion.get("usage") or {}
    result = {
        metric_level.lower(): object,
        "metric_name": rubric_name,
        "evaluation_name": rubric_name,
        "evaluation_type": "rubric",
        "id": id,
        "kwargs": metric_kwargs,
        "depends_on": depends_on,
        "source": populated_prompt,
        "metric_level": metric_level,
        "metric_value": choice_scores[score],
        "rubric_prompt": populated_prompt,
        "rubric_completion": completion_text,
        "rubric_model": completion.get("model", None),
        "rubric_completion_tokens": usage.get("completion_tokens", 0),
        "rubric_prompt_tokens": usage.get("prompt_tokens", 0),
        "rubric_score": score,
    }
    return [result]
|
|
753
|
+
|
|
754
|
+
|
|
755
|
+
def add_all_metrics_to_objects(iterable_of_objects, metrics):
    """
    Append every metric instance in ``metrics`` to the ``metrics_to_evaluate``
    list of each evaluable object (e.g., Turn, Thread, Message, or ToolCall).

    Each object gets a newly concatenated list assigned back to its
    ``metrics_to_evaluate`` attribute; the existing list is not extended in
    place. A list object shared by several evaluables is therefore never
    mutated here.

    :param iterable_of_objects: objects that expose a ``metrics_to_evaluate`` list attribute
    :param metrics: list of metric instances to append to each object
    :returns: None; the objects themselves are updated
    """
    # Metric dependencies are resolved within a single object (e.g. a turn),
    # not across objects. Each object therefore keeps its own ordered copy of
    # the metric sequence, which (per the original design notes) lets turns
    # execute their metrics in parallel.
    for evaluable in iterable_of_objects:
        already_queued = evaluable.metrics_to_evaluate
        evaluable.metrics_to_evaluate = already_queued + metrics
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
def count_rubric_metrics(iterable_of_objects):
    """
    Count the rubric-type metric instances queued across evaluable objects.

    A metric instance counts when its ``evaluation_type`` entry equals
    ``"rubric"``; instances lacking that key are not counted.

    :param iterable_of_objects: objects that expose a ``metrics_to_evaluate``
        list of dict-like metric instances
    :returns: total number of rubric metric instances (0 when there are none)
    """
    return sum(
        1
        for evaluable in iterable_of_objects
        for metric_instance in evaluable.metrics_to_evaluate
        if metric_instance.get("evaluation_type") == "rubric"
    )
|