python-flexeval 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. flexeval/__init__.py +11 -0
  2. flexeval/__main__.py +11 -0
  3. flexeval/classes/__init__.py +15 -0
  4. flexeval/classes/base.py +32 -0
  5. flexeval/classes/dataset.py +82 -0
  6. flexeval/classes/eval_runner.py +158 -0
  7. flexeval/classes/eval_set_run.py +32 -0
  8. flexeval/classes/message.py +183 -0
  9. flexeval/classes/metric.py +55 -0
  10. flexeval/classes/thread.py +79 -0
  11. flexeval/classes/tool_call.py +51 -0
  12. flexeval/classes/turn.py +206 -0
  13. flexeval/cli.py +104 -0
  14. flexeval/completions.py +147 -0
  15. flexeval/compute_metrics.py +788 -0
  16. flexeval/config.yaml +23 -0
  17. flexeval/configuration/__init__.py +1 -0
  18. flexeval/configuration/completion_functions.py +231 -0
  19. flexeval/configuration/evals.yaml +864 -0
  20. flexeval/configuration/function_metrics.py +650 -0
  21. flexeval/configuration/rubric_metrics.yaml +194 -0
  22. flexeval/data_loader.py +513 -0
  23. flexeval/db_utils.py +38 -0
  24. flexeval/dependency_graph.py +234 -0
  25. flexeval/eval_schema.json +256 -0
  26. flexeval/function_types.py +173 -0
  27. flexeval/helpers.py +52 -0
  28. flexeval/io/__init__.py +1 -0
  29. flexeval/io/parsers/yaml_parser.py +69 -0
  30. flexeval/log_utils.py +34 -0
  31. flexeval/metrics/__init__.py +8 -0
  32. flexeval/metrics/access.py +28 -0
  33. flexeval/metrics/save.py +39 -0
  34. flexeval/rubric.py +62 -0
  35. flexeval/run_utils.py +65 -0
  36. flexeval/runner.py +132 -0
  37. flexeval/schema/__init__.py +11 -0
  38. flexeval/schema/config_schema.py +46 -0
  39. flexeval/schema/eval_schema.py +163 -0
  40. flexeval/schema/evalrun_schema.py +97 -0
  41. flexeval/schema/rubric_schema.py +40 -0
  42. flexeval/schema/schema_utils.py +26 -0
  43. python_flexeval-0.1.5.dist-info/METADATA +118 -0
  44. python_flexeval-0.1.5.dist-info/RECORD +47 -0
  45. python_flexeval-0.1.5.dist-info/WHEEL +4 -0
  46. python_flexeval-0.1.5.dist-info/entry_points.txt +2 -0
  47. python_flexeval-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,51 @@
1
+ import peewee as pw
2
+
3
+ from flexeval.classes.base import BaseModel
4
+ from flexeval.classes.dataset import Dataset
5
+ from flexeval.classes.eval_set_run import EvalSetRun
6
+ from flexeval.classes.message import Message
7
+ from flexeval.classes.thread import Thread
8
+ from flexeval.classes.turn import Turn
9
+
10
+
11
class ToolCall(BaseModel):
    """Record of one tool call attached to a message within a turn.

    Stores the called function's name, its arguments, the tool call id,
    the response content, and a free-form ``additional_kwargs`` text value.
    """

    id = pw.IntegerField(primary_key=True)

    # Parent records; every foreign key exposes a ``toolcalls`` backref.
    evalsetrun = pw.ForeignKeyField(EvalSetRun, backref="toolcalls")
    dataset = pw.ForeignKeyField(Dataset, backref="toolcalls")
    thread = pw.ForeignKeyField(Thread, backref="toolcalls")
    message = pw.ForeignKeyField(Message, backref="toolcalls")
    turn = pw.ForeignKeyField(Turn, backref="toolcalls")

    function_name = pw.TextField()
    args = pw.TextField()
    # Holds any additional info we want to save with a tool call.
    additional_kwargs = pw.TextField()
    tool_call_id = pw.TextField()
    response_content = pw.TextField()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Plain per-instance list (not a peewee field).
        self.metrics_to_evaluate = []

    def get_dict_representation(self) -> dict:
        """Return this tool call as a plain dictionary.

        The result is meant for function metrics that need a standard
        Python data structure. Its keys are ``role`` (always
        ``"toolcall"``), ``content`` (the response content), ``args`` and
        ``function_name``.
        """
        representation = {"role": "toolcall"}
        representation["content"] = self.response_content
        representation["args"] = self.args
        representation["function_name"] = self.function_name
        return representation
@@ -0,0 +1,206 @@
1
+ import copy
2
+ import json
3
+ import logging
4
+
5
+ import peewee as pw
6
+ from playhouse.shortcuts import model_to_dict
7
+
8
+ from flexeval.classes.base import BaseModel
9
+ from flexeval.classes.dataset import Dataset
10
+ from flexeval.classes.eval_set_run import EvalSetRun
11
+ from flexeval.classes.thread import Thread
12
+ from flexeval.configuration import completion_functions
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class Turn(BaseModel):
    """Holds a single turn.

    In a conversational exchange, each 'Turn' holds information
    from 1 or more outputs from the same source or role in sequence.
    """

    id = pw.IntegerField(primary_key=True)

    evalsetrun = pw.ForeignKeyField(EvalSetRun, backref="turns")
    dataset = pw.ForeignKeyField(Dataset, backref="turns")
    thread = pw.ForeignKeyField(Thread, backref="turns")
    index_in_thread = pw.IntegerField()
    role = pw.TextField()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Plain per-instance list (not a peewee field).
        self.metrics_to_evaluate = []

    @staticmethod
    def _average_per_choice(total, n_choices):
        """Return ``total`` divided by the number of choices, or None if ``total`` is None."""
        if total is None:
            return None
        # Guard against an empty choices list so we never divide by zero.
        return total / max(n_choices, 1)

    def _resolve_system_prompt(self):
        """Return the system prompt set on this turn if present, else the thread's."""
        if hasattr(self, "system_prompt"):
            # TODO this is a bit hacky; it allows for an override of the
            # system prompt by setting it on the Turn object
            return self.system_prompt
        return self.thread.system_prompt

    def get_completion(self):
        """Generate completions for this turn with the configured completion function.

        A completion is only requested when ``self.is_final_turn_in_input``
        is truthy, since we probably don't want to branch from
        mid-conversation. The completion function is looked up by name in
        ``completion_functions`` using the JSON stored in
        ``self.evalsetrun.completion_llm``.

        NOTE(review): ``is_final_turn_in_input``, ``datasetrow`` and
        ``turn_number`` are not declared as fields on this model in the
        visible code; presumably callers set them - confirm.

        Returns:
            A list with one result dictionary per entry in the completion's
            ``choices``, or None when this is not the final turn, when no
            callable completion function is found, or when the completion
            function returns None.
        """
        if not self.is_final_turn_in_input:
            return None

        completion_config = json.loads(self.evalsetrun.completion_llm)
        completion_fn_name = completion_config.get("function_name", None)
        # A missing "kwargs" entry means "no extra arguments".
        completion_function_kwargs = completion_config.get("kwargs", None) or {}

        completion_function = (
            getattr(completion_functions, completion_fn_name, None)
            if isinstance(completion_fn_name, str)
            else None
        )
        if not callable(completion_function):
            logger.warning(
                "In completion_functions.py: No callable function named %s found.",
                completion_fn_name,
            )
            return None

        completion = completion_function(
            conversation_history=self.get_formatted_prompt(
                include_system_prompt=False
            ),
            **completion_function_kwargs,
        )
        if completion is None:
            return None

        # "completion" will be the output of an existing completion function
        # We need to make the message object
        # and probably also a turn object

        # which means it'll have a structure like this
        # TODO - make this a requirement of the completion functions?
        # - make the completion function just return content?
        # {"choices": [{"message": {"content": "hi", "role": "assistant"}}]}

        # ``exclude`` expects field instances, so pass the field rather than
        # this row's id value.
        result = model_to_dict(self, exclude=[Turn.id])
        result["evalsetrun"] = self.evalsetrun
        result["dataset"] = self.dataset
        result["datasetrow"] = self.datasetrow
        result["turn_number"] = self.turn_number + 1
        result["role"] = "assistant"
        result["context"] = self.get_formatted_prompt(include_system_prompt=False)
        result["is_final_turn_in_input"] = False  # b/c it's not in input
        self.is_final_turn_in_input = False
        result["is_completion"] = True
        result["completion"] = completion
        result["model"] = completion.get("model", None)

        choices = completion.get("choices") or []
        usage = completion.get("usage") or {}
        # TODO - use tiktoken here instead?? this will just give the average
        result["prompt_tokens"] = self._average_per_choice(
            usage.get("prompt_tokens", None), len(choices)
        )
        result["completion_tokens"] = self._average_per_choice(
            usage.get("completion_tokens", None), len(choices)
        )

        result_list = []
        for ix, choice in enumerate(choices):
            temp = copy.deepcopy(result)
            temp["tool_used"] = choice["message"].get("tool_calls", None)
            temp["turn"] = [choice["message"]]
            temp["content"] = choice["message"]["content"]
            temp["completion_number"] = ix + 1
            result_list.append(temp)

        return result_list

    def get_context(self, include_system_prompt=False) -> list[dict[str, str]]:
        """Return the context of the first message in the turn.

        The context is stored on the message as JSON and decoded here.

        Args:
            include_system_prompt: When False, entries whose ``role`` is
                ``"system"`` are dropped.

        Returns:
            The decoded context list, or an empty list when the turn has no
            messages.
        """
        first_message = next(iter(self.messages), None)
        if first_message is None:
            # No messages means no context; avoid decoding an empty string.
            return []
        context = json.loads(first_message.context)
        if not include_system_prompt:
            context = [
                cur_dict for cur_dict in context if cur_dict.get("role") != "system"
            ]
        return context

    def get_formatted_prompt(self, include_system_prompt=False):
        """Return the conversation up to and including this turn as a list of dicts.

        Args:
            include_system_prompt: When True, a ``system`` entry is placed
                first, if a system prompt is available on this turn or its
                thread.

        Returns:
            A list made of the optional system entry, the turn's context
            (without system entries), and the turn's own content.
        """
        formatted_prompt = []
        if include_system_prompt:
            system_prompt = self._resolve_system_prompt()
            # if system prompt not available in this thread, we have nothing to include
            if system_prompt is not None:
                formatted_prompt.append({"role": "system", "content": system_prompt})

        # TODO - we might just want a subset of this
        formatted_prompt += self.get_context()
        formatted_prompt += self.get_content()
        return formatted_prompt

    def get_content(self, include_toolcalls=True, include_tool_messages=True):
        """Return the messages and tool calls of this turn as a list of dicts.

        Each dictionary contains the role and content of a message or tool
        call in the turn. Each tool call appears after the message it's
        associated with.

        Args:
            include_toolcalls: Pass False to leave tool calls out.
            include_tool_messages: Pass False to leave out messages whose
                ``langgraph_message_type`` is ``"ToolMessage"``.
        """
        content = []
        for message in self.messages:
            if include_tool_messages or message.langgraph_message_type != "ToolMessage":
                content.append({"role": message.role, "content": message.content})
            if include_toolcalls:
                content.extend(
                    toolcall.get_dict_representation()
                    for toolcall in message.toolcalls
                )
        return content

    def format_input_for_rubric(
        self, include_system_prompt: bool = False, include_tool_messages: bool = False
    ):
        """This is the 'public' method that returns the info for this Turn.

        Args:
            include_system_prompt: When True, a ``system: ...`` line is
                placed first, if a system prompt is available.
            include_tool_messages: When False, context entries whose
                ``langgraph_role`` is ``"tool"`` and ToolMessage messages
                in this turn are left out.

        Returns:
            A 4-tuple of strings:
            ``output`` (all turns), ``output_minus_completion`` (all turns
            except the last), ``completion`` (last turn) and
            ``tool_call_text`` (all tool calls).
        """
        history_lines = []
        if include_system_prompt:
            system_prompt = self._resolve_system_prompt()
            if system_prompt is not None:
                history_lines.append(f"system: {system_prompt}\n")

        for msg in self.get_context():
            # this outputs user: XYZ, or assistant: 123
            if msg.get("content") and (
                include_tool_messages or msg.get("langgraph_role") != "tool"
            ):
                history_lines.append(f"{msg['role']}: {msg['content']}\n")
        output_minus_completion = "".join(history_lines)

        # Including role as prefix to account for both tool and assistant
        completion = "".join(
            f"{msg['role']}: {msg['content']}\n"
            for msg in self.get_content(include_tool_messages=include_tool_messages)
            if msg.get("content")
        )
        output = output_minus_completion + completion

        tool_call_template = (
            "\n\nFunction name: {function_name}\n"
            "Input arguments: {args}\n"
            "Function output: {response_content}\n"
        )
        tool_call_parts = []
        for tc in self.toolcalls:
            # A tool call that has additional_kwargs is printed only when
            # that JSON contains a truthy "print" entry.
            if hasattr(tc, "additional_kwargs") and not json.loads(
                tc.additional_kwargs
            ).get("print", False):
                continue
            tool_call_parts.append(
                tool_call_template.format(
                    function_name=tc.function_name,
                    args=tc.args,
                    response_content=tc.response_content,
                )
            )
        tool_call_text = "".join(tool_call_parts)

        return output, output_minus_completion, completion, tool_call_text
flexeval/cli.py ADDED
@@ -0,0 +1,104 @@
1
+ """CLI commands."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Annotated
6
+
7
+ import typer
8
+
9
+ from flexeval import db_utils, log_utils, runner
10
+ from flexeval.io.parsers import yaml_parser
11
+ from flexeval.metrics import access
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def global_callback(
    ctx: typer.Context,
    log_level: Annotated[
        log_utils.LogLevel, typer.Option(help="Log level to use.", case_sensitive=False)
    ] = log_utils.LogLevel.INFO.value,
):
    """FlexEval offers a number of CLI commands for convenience."""
    # The docstring above is kept as-is: Typer displays it as CLI help text.
    # Translate the chosen level into the constant handed to set_up_logging.
    logging_constant = log_utils.LogLevel.get_logging_constant(log_level.value)
    log_utils.set_up_logging(logging_constant)
24
+
25
+
26
+ app = typer.Typer(callback=global_callback)
27
+
28
+
29
@app.command(no_args_is_help=True)
def run(
    eval_run_yaml_path: Annotated[
        Path, typer.Argument(help="YAML file specifying the Eval Run.")
    ],
):
    """Run FlexEval using the given YAML Eval Run configuration."""
    # Parse the YAML into an Eval Run and hand it straight to the runner.
    runner.run(yaml_parser.load_eval_run_from_yaml(eval_run_yaml_path))
38
+
39
+
40
def _format_metric_value(value, display_limit: int = 50) -> str:
    """Return ``value`` as a single-line string of at most ``display_limit`` characters."""
    # Escape newlines before truncating so every value stays on one line
    # and the escaped text still respects the display limit.
    text = str(value).replace("\n", "\\n")
    if len(text) > display_limit:
        text = text[: display_limit - 3] + "..."
    return text


@app.command(no_args_is_help=True)
def summarize_metrics(
    eval_run_yaml_path: Annotated[
        Path | None,
        typer.Argument(
            help="YAML file specifying the Eval Run.",
            exists=True,
            dir_okay=False,
        ),
    ] = None,
    database_path: Annotated[
        Path | None,
        typer.Option(help="Database path.", exists=True, dir_okay=False),
    ] = None,
):
    """Print a summary of computed metrics."""
    # The docstring above is kept short because Typer shows it as help text.
    # When an Eval Run YAML is given, its database path takes precedence
    # over an explicitly passed database_path.
    if eval_run_yaml_path is not None:
        if database_path is not None:
            logger.warning(
                "Ignoring database_path since eval_run_yaml_path is provided."
            )
        eval_run = yaml_parser.load_eval_run_from_yaml(eval_run_yaml_path)
        database_path = eval_run.database_path

    if database_path is None:
        raise ValueError("Must provide an Eval Run or a database path.")
    db_utils.initialize_database(database_path)
    counts = access.count_dict_values(access.get_all_metrics())
    print("Summary of metric value counts:")
    for key, counter in counts.items():
        print(" " + key)
        # Only the five most common values are shown for each key.
        for value, count in counter.most_common(5):
            print(f" {_format_metric_value(value)}: {count}")
77
+
78
+
79
@app.command(no_args_is_help=True)
def run_eval_by_name(
    input_data: Annotated[list[Path], typer.Option(help="Input data filepaths.")],
    database_path: Annotated[Path, typer.Option(help="Output database path.")],
    eval_name: str,
    evals_path: Path,
    config_path: Path,
    clear_tables: bool = False,
):
    """Run an eval by name."""
    # NOTE(review): config_path is passed before evals_path below, unlike the
    # parameter order above - confirm this matches runner.run_from_name_args.
    positional_args = (
        input_data,
        database_path,
        eval_name,
        config_path,
        evals_path,
    )
    runner.run_from_name_args(*positional_args, clear_tables=clear_tables)
97
+
98
+
99
def main():
    """Invoke the Typer application; used as the script entry point."""
    app()
101
+
102
+
103
+ if __name__ == "__main__":
104
+ main()
@@ -0,0 +1,147 @@
1
+ """Completing conversations using LLMs."""
2
+
3
+ import json
4
+ import logging
5
+ from collections.abc import Callable
6
+ from concurrent.futures import Future, ThreadPoolExecutor, as_completed
7
+
8
+ from flexeval import classes
9
+ from flexeval.configuration import completion_functions
10
+ from flexeval.schema.eval_schema import CompletionLlm
11
+ from flexeval.schema.evalrun_schema import EvalRun
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def get_completion_function(completion_llm: CompletionLlm) -> Callable:
    """Identify a completion function given the completion LLM configuration.

    The name is first looked up in
    :mod:`~flexeval.configuration.completion_functions`, then in this
    module's globals. In both places only callable objects are accepted, so
    a non-callable attribute that happens to share the name (for example an
    imported module) is never returned.

    Args:
        completion_llm (CompletionLlm): The description of the function to retrieve.

    Raises:
        ValueError: If not found in :mod:`~flexeval.configuration.completion_functions` or in globals by function_name.

    Returns:
        Callable: The completion function.
    """
    function_name = completion_llm.function_name

    configured_function = getattr(completion_functions, function_name, None)
    if callable(configured_function):
        return configured_function

    global_function = globals().get(function_name)
    if callable(global_function):
        # TODO probably don't allow this by default, and also offer a way to specify other places where completion_functions can live.
        logger.debug("Found function in globals(), which could be trouble.")
        return global_function

    raise ValueError(f"No completion function named {function_name} found.")
44
+
45
+
46
def get_completion(turn: classes.turn.Turn, completion_llm: CompletionLlm):
    """Call the configured completion function on the turn's formatted prompt.

    The function is resolved with ``get_completion_function`` and receives
    the prompt as ``conversation_history`` together with
    ``completion_llm.kwargs``. Its return value is passed back unchanged.
    """
    fn = get_completion_function(completion_llm)
    history = turn.get_formatted_prompt(
        include_system_prompt=completion_llm.include_system_prompt
    )
    return fn(conversation_history=history, **completion_llm.kwargs)
56
+
57
+
58
def _final_turns(evalsetrun: classes.eval_set_run.EvalSetRun):
    """Yield the turn with the highest ``index_in_thread`` for each thread that has turns."""
    for thread in evalsetrun.threads:
        # A single query both finds the last turn and detects empty threads.
        turn = (
            thread.turns.select()
            .order_by(classes.turn.Turn.index_in_thread.desc())
            .first()
        )
        if turn is not None:
            yield turn


def get_completions(eval_run: EvalRun, evalsetrun: classes.eval_set_run.EvalSetRun):
    """Generate and save a completion for the last turn of every thread.

    Threads without turns are skipped, as are completions that come back as
    None. With ``eval_run.config.max_workers == 1`` the completions are
    requested one after another. Otherwise they are requested through a
    ``ThreadPoolExecutor``, and each result is saved by the caller of this
    function as it completes.

    Args:
        eval_run (EvalRun): Supplies ``config.max_workers`` and
            ``eval.completion_llm``.
        evalsetrun: The eval set run whose ``threads`` are processed.
    """
    completion_llm = eval_run.eval.completion_llm
    n_workers = eval_run.config.max_workers

    if n_workers == 1:
        for turn in _final_turns(evalsetrun):
            # TODO handle exceptions appropriately (#58)
            completion = get_completion(turn, completion_llm)
            if completion is not None:
                save_completion(completion, turn, evalsetrun, eval_run)
        return

    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        futures: dict[Future, classes.turn.Turn] = {
            executor.submit(get_completion, turn, completion_llm): turn
            for turn in _final_turns(evalsetrun)
        }
        for future in as_completed(futures):
            turn = futures[future]
            completion = future.result()
            if completion is not None:
                save_completion(completion, turn, evalsetrun, eval_run)
97
+
98
+
99
def save_completion(
    completion: dict,
    turn: classes.turn.Turn,
    evalsetrun: classes.eval_set_run.EvalSetRun,
    eval_run: EvalRun,
):
    """Store the first choice of a completion as a new message.

    If ``turn`` already has the ``assistant`` role, the message is added to
    that turn; otherwise a new turn is created directly after it. The new
    message's context is the previous message's context plus the previous
    message itself.

    All inputs are validated before anything is written, so a failure does
    not leave a newly created turn without a message.

    Args:
        completion (dict): Completion output with a ``choices`` list and an
            optional ``usage`` dictionary.
        turn: The turn the completion continues from.
        evalsetrun: The eval set run the new rows belong to.
        eval_run (EvalRun): Supplies ``eval.completion_llm.function_name``,
            which is stored as the message's ``model_name``.

    Raises:
        ValueError: If ``turn`` has no messages to continue from.
    """
    new_message_completions = completion.get("choices") or []
    if not new_message_completions:
        logger.warning("Completion contained no choices; nothing to save.")
        return
    if len(new_message_completions) > 1:
        logger.warning(
            "We don't yet support multiple completions, using just the first one."
        )
    new_message_completion = new_message_completions[0]["message"]

    # Look up the latest message before creating any rows.
    prev_message = (
        turn.messages.select()
        .order_by(classes.message.Message.index_in_thread.desc())
        .first()
    )
    if prev_message is None:
        raise ValueError("Cannot save a completion for a turn that has no messages.")
    new_message_context = prev_message.get_context()
    new_message_context.append(
        {"role": prev_message.role, "content": prev_message.content}
    )

    if turn.role == "assistant":
        # don't create a new Turn, because this completion is a continuation of an existing assistant turn
        new_turn = turn
    else:
        new_turn = classes.turn.Turn.create(
            evalsetrun=evalsetrun,
            dataset=turn.dataset,
            thread=turn.thread,
            index_in_thread=turn.index_in_thread + 1,
            role=new_message_completion["role"],
        )

    usage = completion.get("usage") or {}
    classes.message.Message.create(
        evalsetrun=evalsetrun,
        dataset=turn.dataset,
        thread=turn.thread,
        turn=new_turn,
        index_in_thread=prev_message.index_in_thread + 1,
        role=new_message_completion["role"],
        content=new_message_completion["content"],
        context=json.dumps(new_message_context),
        system_prompt=prev_message.system_prompt,
        is_flexeval_completion=True,
        # TODO I have no idea what model_name is supposed to be, and the completion function name doesn't seem that useful
        model_name=eval_run.eval.completion_llm.function_name,
        prompt_tokens=usage.get("prompt_tokens", None),
        completion_tokens=usage.get("completion_tokens", None),
    )
    # TODO also save any toolcalls