python-flexeval 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. flexeval/__init__.py +11 -0
  2. flexeval/__main__.py +11 -0
  3. flexeval/classes/__init__.py +15 -0
  4. flexeval/classes/base.py +32 -0
  5. flexeval/classes/dataset.py +82 -0
  6. flexeval/classes/eval_runner.py +158 -0
  7. flexeval/classes/eval_set_run.py +32 -0
  8. flexeval/classes/message.py +183 -0
  9. flexeval/classes/metric.py +55 -0
  10. flexeval/classes/thread.py +79 -0
  11. flexeval/classes/tool_call.py +51 -0
  12. flexeval/classes/turn.py +206 -0
  13. flexeval/cli.py +104 -0
  14. flexeval/completions.py +147 -0
  15. flexeval/compute_metrics.py +788 -0
  16. flexeval/config.yaml +23 -0
  17. flexeval/configuration/__init__.py +1 -0
  18. flexeval/configuration/completion_functions.py +231 -0
  19. flexeval/configuration/evals.yaml +864 -0
  20. flexeval/configuration/function_metrics.py +650 -0
  21. flexeval/configuration/rubric_metrics.yaml +194 -0
  22. flexeval/data_loader.py +513 -0
  23. flexeval/db_utils.py +38 -0
  24. flexeval/dependency_graph.py +234 -0
  25. flexeval/eval_schema.json +256 -0
  26. flexeval/function_types.py +173 -0
  27. flexeval/helpers.py +52 -0
  28. flexeval/io/__init__.py +1 -0
  29. flexeval/io/parsers/yaml_parser.py +69 -0
  30. flexeval/log_utils.py +34 -0
  31. flexeval/metrics/__init__.py +8 -0
  32. flexeval/metrics/access.py +28 -0
  33. flexeval/metrics/save.py +39 -0
  34. flexeval/rubric.py +62 -0
  35. flexeval/run_utils.py +65 -0
  36. flexeval/runner.py +132 -0
  37. flexeval/schema/__init__.py +11 -0
  38. flexeval/schema/config_schema.py +46 -0
  39. flexeval/schema/eval_schema.py +163 -0
  40. flexeval/schema/evalrun_schema.py +97 -0
  41. flexeval/schema/rubric_schema.py +40 -0
  42. flexeval/schema/schema_utils.py +26 -0
  43. python_flexeval-0.1.5.dist-info/METADATA +118 -0
  44. python_flexeval-0.1.5.dist-info/RECORD +47 -0
  45. python_flexeval-0.1.5.dist-info/WHEEL +4 -0
  46. python_flexeval-0.1.5.dist-info/entry_points.txt +2 -0
  47. python_flexeval-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,788 @@
1
+ """Utilities for computing needed metric computations and actually invoking those computations."""
2
+
3
+ import copy
4
+ import importlib
5
+ import importlib.util
6
+ import inspect
7
+ import json
8
+ import logging
9
+ import string
10
+ import types
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from typing import Iterable, Union
13
+
14
+ import networkx as nx
15
+
16
+ from flexeval import function_types
17
+ from flexeval.classes.eval_set_run import EvalSetRun
18
+ from flexeval.classes.message import Message
19
+ from flexeval.classes.thread import Thread
20
+ from flexeval.classes.tool_call import ToolCall
21
+ from flexeval.classes.turn import Turn
22
+ from flexeval.configuration import completion_functions, function_metrics
23
+ from flexeval.schema import EvalRun, FunctionsCollection, eval_schema
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class ObjectMetric:
    """Pairs one evaluable object with one metric definition.

    Instances also carry the results computed for that pairing once they
    exist; until then ``metric_results`` is ``None``.
    """

    def __init__(self, object: Message | Turn | ToolCall | Thread, metric: dict):
        """Tracks a unique (object, metric) combination and any results computed for that metric.

        Args:
            object (Message | Turn | ToolCall | Thread): The object to track.
            metric (dict): The metric to track.
        """
        # None means "not computed yet"; a list (possibly empty) means "processed".
        self.metric_results: list[dict] | None = None
        self.metric: dict = metric
        self.object: Message | Turn | ToolCall | Thread = object

    def __repr__(self) -> str:
        # Identify the tracked object by its class name and id.
        tracked = f"{self.object.__class__.__name__} {self.object.id}"
        return f"ObjectMetric(object={tracked}, metric={self.metric}, metric_results={self.metric_results})"
42
+
43
+
44
class MetricGraphBuilder:
    r"""Builds :class:`networkx.DiGraph`\s of :class:`~flexeval.compute_metrics.ObjectMetric` instances that reflect any computational dependencies between them."""

    def __init__(self):
        # key: tuple(metric_level, metric_id, object_id)
        # value: ObjectMetric
        self.id_to_object_metric_map = {}

    def build_metric_structures(self, evalsetrun: EvalSetRun):
        """Index the metrics listed in ``evalsetrun.metrics_graph_ordered_list``.

        Populates ``self.metric_id_map`` (metric id -> metric dict) and
        ``self.metrics_by_level`` (metric level -> list of metric dicts, in
        the order they appear in the JSON list).

        Args:
            evalsetrun (EvalSetRun): Source of the JSON-encoded metric list.
        """
        metric_id_map = {}
        metrics_by_level = {}
        for metric_instance in json.loads(evalsetrun.metrics_graph_ordered_list):
            metric_level = metric_instance["metric_level"]
            metrics_by_level.setdefault(metric_level, []).append(metric_instance)
            metric_id_map[metric_instance["id"]] = metric_instance
        self.metric_id_map = metric_id_map
        self.metrics_by_level = metrics_by_level

    def get_or_create_object_metric(
        self,
        metric_level: eval_schema.MetricLevel,
        object: Message | Turn | ToolCall | Thread,
        metric: dict,
    ) -> ObjectMetric:
        """Return the ObjectMetric for (level, metric id, object id), creating it on first use."""
        key = (metric_level, metric["id"], object.id)
        if key not in self.id_to_object_metric_map:
            self.id_to_object_metric_map[key] = ObjectMetric(object, metric)
        return self.id_to_object_metric_map[key]

    def get_index(
        self, target_id: int, objects: list[Message | Turn | ToolCall | Thread]
    ):
        """Return the position of the object whose ``id`` equals ``target_id``.

        Raises:
            ValueError: If no object in ``objects`` has that id.
        """
        for i, candidate in enumerate(objects):
            if target_id == candidate.id:
                return i
        raise ValueError(
            f"Failed to find object with id '{target_id}' in '{len(objects)}' objects."
        )

    def find_object_metric_from_depends_on(
        self,
        current_object: Message | Turn | ToolCall | Thread,
        current_metric_level: eval_schema.MetricLevel,
        current_index: int,
        depends_on: dict,
    ) -> ObjectMetric | None:
        """
        If you're a Turn metric that depends on a Message metric,
        then we create a dependency on ALL or ANY Message meeting the criteria.
        We don't know how to handle that...

        In contrast, if you're a Message metric that depends on a Turn metric,
        then we have a dependency on only a single object: that Message's Turn.

        Returns:
            The ObjectMetric this dependency points at, or ``None`` when the
            position ``current_index + relative_object_position`` falls
            outside the objects available at the dependency's level.

        Raises:
            ValueError: If the dependency's metric level is missing, or if the
                current level cannot depend on the dependency's level.
        """
        metric_id = depends_on["parent_id"]
        dependency_metric_level = depends_on.get("metric_level")
        if dependency_metric_level is None:
            # if not specified in the dependency already, look up the metric level
            depends_on_metric = self.metric_id_map[metric_id]
            dependency_metric_level = depends_on_metric["metric_level"]
            if dependency_metric_level is None:
                raise ValueError(
                    f"Metric lacks a metric level: {depends_on_metric} (matched via dependency_info: {depends_on})"
                )

        if dependency_metric_level == current_metric_level:
            pass  # just use current_index, no lookup needed
        elif current_metric_level == "ToolCall":
            if dependency_metric_level == "Message":
                current_index = self.get_index(
                    current_object.message_id, self.objects_by_level["Message"]
                )
            elif dependency_metric_level == "Turn":
                current_index = self.get_index(
                    current_object.turn_id, self.objects_by_level["Turn"]
                )
            elif dependency_metric_level == "Thread":
                current_index = 0  # only a single thread, by definition
        elif current_metric_level == "Message":
            if dependency_metric_level == "Turn":
                current_index = self.get_index(
                    current_object.turn_id, self.objects_by_level["Turn"]
                )
            elif dependency_metric_level == "Thread":
                current_index = 0  # only a single thread, by definition
            elif dependency_metric_level == "ToolCall":
                raise ValueError(
                    f"Can't depend on a '{dependency_metric_level}' metric from a '{current_metric_level}' metric."
                )
        elif current_metric_level == "Turn":
            if dependency_metric_level == "Thread":
                current_index = 0  # only a single thread, by definition
            else:
                raise ValueError(
                    f"Can't depend on a '{dependency_metric_level}' metric from a '{current_metric_level}' metric."
                )
        elif current_metric_level == "Thread":
            raise ValueError(
                f"Can't depend on a '{dependency_metric_level}' metric from a '{current_metric_level}' metric."
            )
        else:
            raise ValueError(f"Unsupported level: {current_metric_level=}")

        relative_object_position = depends_on["relative_object_position"]
        target_object_index = current_index + relative_object_position
        if target_object_index < 0:
            logger.debug(
                f"Object at position '{current_index}' object cannot in principle satisfy this dependency, so skipping it."
            )
            return None
        candidates = self.objects_by_level[dependency_metric_level]
        if target_object_index >= len(candidates):
            # A positive relative position can point past the last object;
            # treat it like the negative case instead of raising IndexError.
            logger.debug(
                f"Object at position '{current_index}' object cannot in principle satisfy this dependency, so skipping it."
            )
            return None
        target_object = candidates[target_object_index]
        metric = self.metric_id_map[metric_id]
        return self.get_or_create_object_metric(
            dependency_metric_level, target_object, metric
        )

    def build_thread_task_graphs(self, evalsetrun: EvalSetRun) -> Iterable[nx.DiGraph]:
        """Lazily yield one dependency graph per thread in ``evalsetrun.threads``."""
        for thread in evalsetrun.threads:
            yield self.build_thread_task_graph(thread)

    def build_thread_task_graph(self, thread: Thread) -> nx.DiGraph:
        """Build the (object, metric) dependency graph for a single thread.

        Nodes are ObjectMetric instances; an edge ``A -> B`` carries the
        ``depends_on`` dict and means B depends on A. Requires
        :meth:`build_metric_structures` to have been called first.

        Note that this rebinds ``self.objects_by_level`` to the given thread.
        """
        self.objects_by_level = {
            "Thread": [thread],
            "Turn": list(thread.turns),
            "Message": list(thread.messages),
            "ToolCall": list(thread.toolcalls),
        }

        g = nx.DiGraph()
        for level, metrics_at_level in self.metrics_by_level.items():
            if len(metrics_at_level) == 0:
                continue
            objects = self.objects_by_level[level]
            for i, current_object in enumerate(objects):
                for metric in metrics_at_level:
                    # register metric on object
                    object_metric = self.get_or_create_object_metric(
                        level, current_object, metric
                    )
                    g.add_node(object_metric)
                    for dependency in metric.get("depends_on", ()):
                        # register dependency metric on the relevant object
                        dependency_object_metric = (
                            self.find_object_metric_from_depends_on(
                                current_object, level, i, dependency
                            )
                        )
                        if dependency_object_metric is None:
                            logger.debug(
                                "This object cannot in principle satisfy this dependency, so skipping it."
                            )
                            # TODO verify that this is the expected behavior in chained dependencies X -> Y -> Z
                            g.remove_node(object_metric)
                            # Stop here: processing further dependencies would
                            # call add_edge(), which implicitly re-adds the
                            # node we just removed with only part of its
                            # dependencies attached.
                            break
                        g.add_node(dependency_object_metric)
                        g.add_edge(
                            dependency_object_metric,
                            object_metric,
                            depends_on=dependency,
                        )
        return g
209
+
210
+
211
def compute_metrics(evalrun: EvalRun, evalsetrun: EvalSetRun) -> list[dict]:
    """Compute every configured metric for every thread of an EvalSetRun.

    One dependency graph is built per thread and processed either serially
    (``max_workers == 1``) or on a thread pool.

    Args:
        evalrun (EvalRun): Supplies ``config.max_workers``,
            ``config.raise_on_metric_error`` and the function/rubric setup.
        evalsetrun (EvalSetRun): Supplies the threads and the ordered metric list.

    Returns:
        list[dict]: All metric result dicts, in thread order.
    """
    n_workers = evalrun.config.max_workers
    raise_on_error = evalrun.config.raise_on_metric_error
    mgb = MetricGraphBuilder()
    mgb.build_metric_structures(evalsetrun)
    graphs = mgb.build_thread_task_graphs(evalsetrun)
    mc = MetricComputer.from_evalrun(evalrun, evalsetrun)
    metrics = []
    if n_workers == 1:
        for graph in graphs:
            graph_metrics = mc.process_thread_dependency_graph(graph, raise_on_error)
            metrics.extend(graph_metrics)
    else:
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            # Graphs are still built one at a time here, in the calling thread;
            # only their processing is handed to the pool.
            for graph in graphs:
                # Forward raise_on_error so the pooled path honours
                # config.raise_on_metric_error exactly like the serial path.
                future = executor.submit(
                    mc.process_thread_dependency_graph, graph, raise_on_error
                )
                futures.append(future)
            for i, future in enumerate(futures):
                metrics.extend(future.result())
                if i % 100 == 0:
                    logger.info(f"Metrics futures resulted: {i + 1} / {len(futures)}")
    return metrics
234
+
235
+
236
class MetricComputer:
    """Resolves metric functions and rubrics and computes metric results for evaluable objects."""

    @classmethod
    def from_evalrun(
        cls, evalrun: EvalRun, evalsetrun: EvalSetRun | None = None
    ) -> "MetricComputer":
        """Build a MetricComputer from an EvalRun and validate its configuration.

        Each entry of ``evalrun.function_modules`` is turned into a Python
        module (already-imported modules are used as-is; anything else is
        tried as an importable module name, then as a file path).

        Raises:
            ValueError: If a function module cannot be loaded, if a configured
                function metric is not found in any module, or if a configured
                rubric metric is not among the loaded rubrics.
        """
        function_modules = evalrun.function_modules
        # convert from string module names or filepaths to Python modules
        actual_modules = []
        for i, function_module in enumerate(function_modules):
            if isinstance(function_module, types.ModuleType):
                # already a module
                actual_modules.append(function_module)
            elif isinstance(function_module, FunctionsCollection):
                raise ValueError("FunctionsCollection not yet implemented!")
            else:  # it's a filepath
                try:
                    # TODO I think this is not necessary given the pydantic schema; this should always fail for filepaths
                    # alternately, we might call import_module() on the ModuleType modules, but I think that's unnecessary
                    module = importlib.import_module(str(function_module))
                except ModuleNotFoundError as module_not_found:
                    try:
                        spec = importlib.util.spec_from_file_location(
                            f"function_module_{i}", function_module
                        )
                        module = importlib.util.module_from_spec(spec)
                        spec.loader.exec_module(module)
                    except Exception as module_not_loaded:
                        raise ValueError(
                            f"Failed to load function module specified by '{function_module}.' (module not found: {module_not_found}, and failed to load from file location: {module_not_loaded})"
                        ) from module_not_loaded
                actual_modules.append(module)
        if evalrun.add_default_functions and function_metrics not in actual_modules:
            actual_modules.append(function_metrics)
        mc = cls(actual_modules, evalsetrun)
        # validation step: verify that all functions are present
        missing_functions = set()
        if evalrun.eval.metrics.function is not None:
            for function_item in evalrun.eval.metrics.function:
                try:
                    mc.find_function(function_item.name)
                except ValueError:
                    missing_functions.add(function_item.name)
        if len(missing_functions) > 0:
            raise ValueError(
                f"Failed to find '{len(missing_functions)}' functions in the provided function module. Missing function names: {', '.join(sorted(missing_functions))}"
            )
        # validation step: verify that all rubrics are present
        missing_rubrics = set()
        if mc.rubrics is not None and evalrun.eval.metrics.rubric is not None:
            for rubric_item in evalrun.eval.metrics.rubric:
                if rubric_item.name not in mc.rubrics:
                    missing_rubrics.add(rubric_item.name)
        if len(missing_rubrics) > 0:
            raise ValueError(
                f"Failed to find '{len(missing_rubrics)}' rubrics in the provided rubric set. Missing rubric names: {', '.join(sorted(missing_rubrics))}"
            )
        return mc

    def __init__(self, function_modules: list, evalsetrun: EvalSetRun | None = None):
        """Store the function modules and, when an EvalSetRun is given, load its rubrics.

        Args:
            function_modules (list): Modules searched, in order, for metric functions.
            evalsetrun (EvalSetRun | None): Optional source of JSON-encoded rubrics.
        """
        self.function_modules: list = function_modules
        self.rubrics: dict | None = None
        if evalsetrun is not None:
            # load_rubrics() assigns self.rubrics itself and returns None, so
            # its return value must not be assigned back onto self.rubrics.
            self.load_rubrics(evalsetrun)

    def load_rubrics(self, evalsetrun: EvalSetRun):
        """Set the rubrics to be used by this MetricComputer from the given EvalSetRun."""
        self.rubrics = json.loads(evalsetrun.rubrics)

    def process_thread_dependency_graphs(
        self, graph_list: Iterable[nx.DiGraph]
    ) -> list[dict]:
        """Process each graph in turn and return all metric results concatenated."""
        evaluated_metrics = []
        for g in graph_list:
            evaluated_metrics.extend(self.process_thread_dependency_graph(g))
        return evaluated_metrics

    @staticmethod
    def _value_in_range(metric_result: dict, dependency_info: dict) -> bool:
        """Return True when the result's value lies within the dependency's inclusive [min, max] range."""
        value = metric_result["metric_value"]
        return (value >= dependency_info["metric_min_value"]) and (
            value <= dependency_info["metric_max_value"]
        )

    def process_thread_dependency_graph(
        self, g: nx.DiGraph, raise_on_error: bool = True
    ) -> list[dict]:
        """Compute all metrics in one thread's dependency graph, in topological order.

        A metric is computed only when every incoming dependency is satisfied,
        i.e. the relevant result of the dependency lies within the range given
        on the edge. Metrics with unmet dependencies get an empty result list.

        Args:
            g (nx.DiGraph): Graph of ObjectMetric nodes; each edge carries a
                ``depends_on`` dict.
            raise_on_error (bool): When True, re-raise any exception after
                logging it; when False, return the results gathered so far.

        Returns:
            list[dict]: The metric results computed for this graph.
        """
        evaluated_metrics = []
        try:
            for object_metric in nx.topological_sort(g):
                all_dependencies_met = True
                for dependency in g.predecessors(object_metric):
                    if dependency.metric_results is None:
                        raise ValueError(
                            f"FlexEval error: expected metric_result for dependency '{dependency.metric['evaluation_name']}' to be computed before processing metric '{object_metric.metric['evaluation_name']}'."
                        )
                    dependency_info = g.get_edge_data(dependency, object_metric)[
                        "depends_on"
                    ]
                    dependency_met = False
                    if (
                        "metric_name" in dependency_info
                        and dependency_info["metric_name"] is not None
                        and dependency_info["metric_name"]
                        != dependency.metric["evaluation_name"]
                    ):
                        for metric_result in dependency.metric_results:
                            # expected key must be present and in the expected range
                            if (
                                dependency_info["metric_name"]
                                == metric_result["metric_name"]
                            ):
                                dependency_met = self._value_in_range(
                                    metric_result, dependency_info
                                )
                                break
                        else:
                            logger.debug(
                                f"Key '{dependency_info['metric_name']}' not found in results for dependency '{dependency.metric['evaluation_name']}'."
                            )
                    elif len(dependency.metric_results) == 1:
                        dependency_met = self._value_in_range(
                            dependency.metric_results[0], dependency_info
                        )
                    elif len(dependency.metric_results) == 0:
                        logger.debug(
                            f"Skipping metric because dependency '{dependency.metric['evaluation_name']}' has no results."
                        )
                    else:
                        raise ValueError(
                            f"Not sure how to evaluate dependency '{dependency.metric['evaluation_name']}' for metric '{object_metric.metric['evaluation_name']}', as it has {len(dependency.metric_results)} results but no specified key."
                        )
                    if not dependency_met:
                        all_dependencies_met = False
                        logger.debug(
                            f"Value for metric '{dependency.metric['evaluation_name']}' not in range for dependency {dependency_info}."
                        )
                        break
                if all_dependencies_met:
                    # TODO in the future, we could pass some metric_results as kwargs to the metric function
                    # or as a special formatting key to the rubric
                    metric_results = self.compute_metric(
                        object_metric.object, **object_metric.metric
                    )
                    object_metric.metric_results = metric_results
                    evaluated_metrics.extend(metric_results)
                else:
                    # no results for this metric, as dependencies were unmet
                    object_metric.metric_results = []
            self._validate_metrics(evaluated_metrics)
        except Exception as ex:
            logger.exception(f"An error occurred during metric processing: {ex}")
            if raise_on_error:
                raise
        return evaluated_metrics

    def compute_metrics(self, object: Union[Thread, Turn, Message, ToolCall]):
        """Evaluate, in order, every metric in ``object.metrics_to_evaluate``.

        Each entry is a metric dict that is passed to :meth:`compute_metric`
        as keyword arguments. An entry with a non-empty ``depends_on`` list is
        evaluated only if, for every dependency, an already-computed result
        with ``id == dependency["parent_id"]`` has a ``metric_value`` within
        ``[metric_min_value, metric_max_value]``.

        Returns:
            list[dict]: Results of all metrics that were evaluated.
        """
        # we'll keep the results in a list
        # for each new metric, if it has dependencies, we'll need to make sure they're met - otherwise we won't run it
        evaluated_metrics = []
        # METRICS IN ORDER
        for metric_to_evaluate in object.metrics_to_evaluate:
            dependencies_are_all_met = True
            # If there are no dependencies, this loop won't execute
            # and the metric will be evaluated
            for dependency in metric_to_evaluate.get("depends_on") or []:
                # for each dependency, assume it's not met
                # if it's in the list AND its values meet the criteria, it's met
                dependency_is_met = False
                for em in evaluated_metrics:
                    # 'depends_on' will have all fields populated at this point
                    if em["id"] == dependency["parent_id"]:
                        if (
                            em["metric_value"] >= dependency["metric_min_value"]
                            and em["metric_value"] <= dependency["metric_max_value"]
                        ):
                            # this specific dependency was met - can quit looking
                            dependency_is_met = True
                            break
                        else:
                            logger.debug(
                                f"Metric value '{em['metric_value']}' not in range for dependency id='{dependency['parent_id']}'."
                            )
                if not dependency_is_met:
                    dependencies_are_all_met = False
                    # if even one dependency is not met - don't do the evaluation
                    break
            if dependencies_are_all_met:
                # ONLY call if dependencies are ALL met
                # TODO - maybe in the future we'll want to add the computed value from
                # the dependency through as an argument here
                metric_results = self.compute_metric(object, **metric_to_evaluate)
                evaluated_metrics.extend(metric_results)
            else:
                # Report the metric being skipped. The loop variable of the
                # result scan above is unbound when nothing has been evaluated
                # yet, and otherwise names an unrelated earlier result.
                logger.debug(
                    f"Skipping metric '{metric_to_evaluate.get('evaluation_name')}' (id='{metric_to_evaluate.get('id')}') due to unmet dependencies."
                )
        return evaluated_metrics

    def compute_metric(
        self,
        object: Union[Thread, Turn, Message, ToolCall],
        evaluation_name: str,
        evaluation_type: str,
        metric_level: str,
        kwargs: dict,
        context_only: bool = None,
        depends_on: list = None,
        id: int = None,
        notes: str = None,  # just a placeholder
    ) -> list[dict]:
        """Compute a single function or rubric metric for ``object``.

        Args:
            object: The object the metric is computed on.
            evaluation_name (str): Function name or rubric name.
            evaluation_type (str): Either ``"function"`` or ``"rubric"``.
            metric_level (str): Level of the metric (e.g. ``"Turn"``).
            kwargs (dict): Extra keyword arguments for the metric.
            context_only (bool): Forwarded to function metrics only.
            depends_on (list): Dependency specs, copied into the results.
            id (int): Metric id, copied into the results.
            notes (str): Accepted but unused.

        Returns:
            list[dict]: Validated metric result dicts.

        Raises:
            ValueError: If ``evaluation_type`` is unsupported or a result
                fails :meth:`_validate_metrics`.
        """
        if evaluation_type == "function":
            metrics = self.compute_function_metric(
                function_name=evaluation_name,
                metric_kwargs=kwargs,
                metric_level=metric_level,
                context_only=context_only,
                input_object=object,
                depends_on=depends_on,
                id=id,
            )
        elif evaluation_type == "rubric":
            metrics = self.compute_rubric_metric(
                rubric_name=evaluation_name,
                metric_kwargs=kwargs,
                metric_level=metric_level,
                object=object,
                depends_on=depends_on,
                id=id,
            )
        else:
            raise ValueError(
                f"The argument evaluation_type provided to compute_metric is invalid. Must be one of 'function' or 'rubric'. You passed '{evaluation_type}'."
            )
        self._validate_metrics(metrics)
        return metrics

    def _validate_metrics(self, metrics: list[dict]):
        """Raise ValueError if any result lacks a non-None ``evaluation_type`` or ``metric_value``."""
        for m in metrics:
            if m.get("evaluation_type", None) is None:
                raise ValueError(
                    f"Metric '{m}' does not have a value for the key `evaluation_type`."
                )
            if m.get("metric_value", None) is None:
                raise ValueError(
                    f"Metric '{m}' does not have a value for the key `metric_value`."
                )

    def invoke_function(
        self,
        metric_function: callable,
        metric_level: eval_schema.MetricLevel,
        input_object: function_types.AnyFunctionObjectInput,
        metric_kwargs: dict,
        context_only: bool,
    ):
        """Build the function's input via ``function_types.get_function_input`` and call it with ``metric_kwargs``."""
        function_input = function_types.get_function_input(
            metric_function, metric_level, input_object, context_only
        )
        return metric_function(function_input, **metric_kwargs)

    def find_function(self, function_name: str):
        """Return ``(function, source_code)`` for the first module exposing a callable ``function_name``.

        Raises:
            ValueError: If no provided module has a callable with that name.
        """
        for function_module in self.function_modules:
            metric_function = getattr(function_module, function_name, None)
            if callable(metric_function):
                return metric_function, inspect.getsource(metric_function)
        raise ValueError(
            f"Metric function with name '{function_name}' was not found in any of the '{len(self.function_modules)}' provided function modules."
        )

    def compute_function_metric(
        self,
        function_name: str,
        metric_kwargs: dict,
        input_object: Union[Thread, Turn, Message, ToolCall],
        metric_level: eval_schema.MetricLevel,
        context_only: bool,
        depends_on: list,
        id: int,
    ):
        """Run a function metric and normalise its output to a list of result dicts.

        The function may return a number (one result named after the
        function), a dict (one result per key, values cast to float), or a
        list of dicts with ``name`` and ``value`` entries (one result each,
        values cast to float).

        Raises:
            ValueError: If the function is not found or returns an
                unsupported type.
        """
        # Check if the function exists in any of the function namespaces
        metric_function, metric_source = self.find_function(function_name)
        metrics_result = self.invoke_function(
            metric_function, metric_level, input_object, metric_kwargs, context_only
        )

        # fields shared by every result produced by this call
        base_result = {
            metric_level.lower(): input_object,
            "evaluation_name": function_name,
            "evaluation_type": "function",
            "metric_level": metric_level,
            "kwargs": metric_kwargs,
            "source": metric_source,  # TODO - put this back?
            "context_only": context_only,
            "depends_on": depends_on,
            "id": id,
        }
        # now deal with output
        if isinstance(metrics_result, (float, int)):
            result = copy.deepcopy(base_result)
            result["metric_name"] = function_name
            result["metric_value"] = metrics_result
            return [result]
        elif isinstance(metrics_result, dict):
            result_list = []
            # TODO rethink this behavior
            for k, v in metrics_result.items():
                result = copy.deepcopy(base_result)
                result["metric_name"] = k
                result["metric_value"] = float(v)
                result_list.append(result)
            return result_list
        elif isinstance(metrics_result, list):
            result_list = []
            for entry in metrics_result:
                result = copy.deepcopy(base_result)
                result["metric_name"] = entry.get("name", None)
                result["metric_value"] = float(entry.get("value", None))
                result_list.append(result)
            return result_list
        else:
            raise ValueError(
                f"The metric type returned from '{metric_function}' is not a supported type. It must be one of `list`, `int`, `float`, or `dict`. You supplied '{type(metrics_result)}'."
            )

    @staticmethod
    def _match_choice(completion_text: str, choices: Iterable[str]) -> str:
        """Return the first choice that starts or ends a line of the completion, scanning from the last line up.

        Punctuation is removed from each line before matching. Returns
        ``"__invalid__"`` when no line matches any choice.
        """
        # logic adapted from OpenAI Evals:
        # https://github.com/openai/evals/blob/d3dc89042ddee879a68a326fdb37716ee518640c/evals/elsuite/modelgraded/classify_utils.py#L29
        choices = list(choices)
        strip_punctuation = str.maketrans("", "", string.punctuation)
        for line in reversed(completion_text.strip().split("\n")):
            line = line.strip().translate(strip_punctuation)
            if not line:
                continue
            for choice in choices:
                if line.startswith(choice) or line.endswith(choice):
                    return choice
        return "__invalid__"

    def compute_rubric_metric(
        self,
        rubric_name: str,
        metric_kwargs: dict,
        object: Union[Thread, Turn, Message],
        metric_level: str,
        depends_on: list,
        id: int,
    ):
        """Grade ``object`` against a rubric using the configured grader LLM.

        The rubric prompt may use the placeholders ``{conversation}``,
        ``{context}``, ``{content}``, ``{tool_calls}`` and, when
        ``do_completion`` is enabled on the EvalSetRun, ``{completion}``.

        Returns:
            list[dict]: A single result dict. ``metric_value`` is the score
            mapped from the grader's choice (``None`` when no choice could be
            parsed from the grader's answer).

        Raises:
            ValueError: If the rubric is unknown, defines no choice scores, or
                no usable grader completion function is configured.
            Exception: If the rubric prompt combines incompatible placeholders.
        """
        if self.rubrics is not None:
            rubrics = self.rubrics
        else:
            rubrics = json.loads(object.evalsetrun.rubrics)
        if rubric_name not in rubrics:
            raise ValueError(
                f"You requested a rubric called '{rubric_name}', but only these were found: {rubrics.keys()}."
            )
        rubric = rubrics[rubric_name]
        prompt = rubric.get("prompt", "")

        # format input for rubric
        conversation, context, content, tool_calls = object.format_input_for_rubric()
        # {conversation} -- The whole conversation
        # {context} -- The previous turns without the current entry
        # {content} -- Only the current turn / message / toolcall depending on the metric_level
        # {completion} -- only valid when do_completion == True

        # Verification steps before populating the rubric
        # case 1: {conversation} and {context} should not go together
        # case 2: {completion} and {content} should not go together
        # case 3: if there is a {completion}, do_completion should be true
        uses_completion = "{completion}" in prompt

        if "{conversation}" in prompt and "{context}" in prompt:
            raise Exception(
                "Your rubric should not have both {conversation} and {context}. Please check the README file for more information about how to write FlexEval rubrics."
            )

        if uses_completion and "{content}" in prompt:
            raise Exception(
                "Your rubric should not have both {content} and {completion}. Please check the README file for more information about how to write FlexEval rubrics."
            )

        if uses_completion and not object.evalsetrun.do_completion:
            raise Exception(
                "Your rubric has {completion}, but in your test specification for this rubric evaluation, do_completion is not True. Please check the README file for more information about how to write FlexEval rubrics."
            )

        # Populate the prompt in a single pass. Formatting without a
        # `completion` value raises KeyError for any rubric that uses
        # {completion}, so supply it whenever the placeholder is present
        # (do_completion was verified above).
        # TODO revisit this logic
        # also included object.is_completion, which only works for Message rubrics
        # but we can in principle check for a message in either a turn or a thread with is_flexeval_completion true
        format_values = {
            "conversation": conversation,
            "context": context,
            "content": content,
            "tool_calls": tool_calls,
        }
        if uses_completion:
            format_values["completion"] = content
        populated_prompt = prompt.format(**format_values)

        choice_scores = rubric.get("choice_scores")
        if not choice_scores:
            raise ValueError(
                f"Rubric '{rubric_name}' does not define any choice_scores."
            )
        # Work on a copy so the "__invalid__" sentinel added below never leaks
        # into the shared rubric definition (and from there into later prompts).
        choice_scores = dict(choice_scores)

        # get rubric grader
        if object.evalsetrun.grader_llm is None or object.evalsetrun.grader_llm == "":
            raise ValueError(
                "Attempting to evaluate a rubric metric, but no grader LLM defined."
            )
        grader_completion_function = json.loads(object.evalsetrun.grader_llm)
        if grader_completion_function is None or len(grader_completion_function) == 0:
            raise ValueError(
                "Attempting to evaluate a rubric metric, but no grader LLM defined."
            )
        grader_completion_fn_name = grader_completion_function.get(
            "function_name", None
        )
        grader_completion_fn_kwargs = grader_completion_function.get("kwargs", {})
        completion_function = (
            getattr(completion_functions, grader_completion_fn_name, None)
            if isinstance(grader_completion_fn_name, str)
            else None
        )
        if not callable(completion_function):
            raise ValueError(
                f"Grader completion function '{grader_completion_fn_name}' was not found among the available completion functions."
            )

        DEFAULT_COT_TEMPLATE = "\nBefore answering, I will reason in a step-by-step manner as to get the right answer, then conclude with the answer in the format requested."
        # NOTE: the leading newline separates these instructions from the
        # rubric prompt they are appended to. No stray quote character is
        # emitted before it.
        ANSWER_PROMPT = (
            "\nFirst, write out in a step by step manner your reasoning to be sure that your conclusion is correct. "
            "Avoid simply stating the correct answer at the outset. "
            f"Then print only a single choice from {list(choice_scores.keys())} (without quotes or punctuation) on its own line corresponding to the correct answer. "
            "At the end, repeat just the answer by itself on a new line."
            "\n\nReasoning:"
        )
        # This is the call to the grader completion function
        completion = completion_function(
            conversation_history=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "assistant", "content": DEFAULT_COT_TEMPLATE},
                {"role": "user", "content": populated_prompt + ANSWER_PROMPT},
                {"role": "assistant", "content": DEFAULT_COT_TEMPLATE},
            ],
            **grader_completion_fn_kwargs,
        )
        completion_text = completion["choices"][0]["message"]["content"]

        # an unparseable grader answer maps to a None score
        choice_scores["__invalid__"] = None
        score = self._match_choice(completion_text, choice_scores.keys())

        result = {
            metric_level.lower(): object,
            "metric_name": rubric_name,
            "evaluation_name": rubric_name,
            "evaluation_type": "rubric",
            "id": id,
            "kwargs": metric_kwargs,
            "depends_on": depends_on,
            "source": populated_prompt,
            "metric_level": metric_level,
            "metric_value": choice_scores[score],
            "rubric_prompt": populated_prompt,
            "rubric_completion": completion_text,
            "rubric_model": completion.get("model", None),
            "rubric_completion_tokens": completion.get("usage", {}).get(
                "completion_tokens", 0
            ),
            "rubric_prompt_tokens": completion.get("usage", {}).get(
                "prompt_tokens", 0
            ),
            "rubric_score": score,
        }
        return [result]
753
+
754
+
755
def add_all_metrics_to_objects(iterable_of_objects, metrics):
    """
    Extends the `metrics_to_evaluate` field of every evaluable object
    (e.g., Turn, Thread, Message, or ToolCall) in iterable_of_objects with
    all metric instances in metrics. Each object is given a newly built
    sequence: its current `metrics_to_evaluate` followed by `metrics`.

    :param iterable_of_objects: list of objects that have a metrics_to_evaluate field
    :param metrics: list of metric instances to add to each object
    """
    # Metric dependencies are resolved within a single object rather than
    # across objects, so every object carries its own sequence of metrics.
    for evaluable in iterable_of_objects:
        extended = evaluable.metrics_to_evaluate + metrics
        evaluable.metrics_to_evaluate = extended
774
+
775
+
776
def count_rubric_metrics(iterable_of_objects):
    """
    Counts, across all given objects, the entries of `metrics_to_evaluate`
    whose "evaluation_type" is "rubric".

    :param iterable_of_objects: list of objects that have a metrics_to_evaluate field
    :return: total number of rubric type metrics
    """
    return sum(
        1
        for evaluable in iterable_of_objects
        for metric_instance in evaluable.metrics_to_evaluate
        if metric_instance.get("evaluation_type") == "rubric"
    )