python-flexeval 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexeval/__init__.py +11 -0
- flexeval/__main__.py +11 -0
- flexeval/classes/__init__.py +15 -0
- flexeval/classes/base.py +32 -0
- flexeval/classes/dataset.py +82 -0
- flexeval/classes/eval_runner.py +158 -0
- flexeval/classes/eval_set_run.py +32 -0
- flexeval/classes/message.py +183 -0
- flexeval/classes/metric.py +55 -0
- flexeval/classes/thread.py +79 -0
- flexeval/classes/tool_call.py +51 -0
- flexeval/classes/turn.py +206 -0
- flexeval/cli.py +104 -0
- flexeval/completions.py +147 -0
- flexeval/compute_metrics.py +788 -0
- flexeval/config.yaml +23 -0
- flexeval/configuration/__init__.py +1 -0
- flexeval/configuration/completion_functions.py +231 -0
- flexeval/configuration/evals.yaml +864 -0
- flexeval/configuration/function_metrics.py +650 -0
- flexeval/configuration/rubric_metrics.yaml +194 -0
- flexeval/data_loader.py +513 -0
- flexeval/db_utils.py +38 -0
- flexeval/dependency_graph.py +234 -0
- flexeval/eval_schema.json +256 -0
- flexeval/function_types.py +173 -0
- flexeval/helpers.py +52 -0
- flexeval/io/__init__.py +1 -0
- flexeval/io/parsers/yaml_parser.py +69 -0
- flexeval/log_utils.py +34 -0
- flexeval/metrics/__init__.py +8 -0
- flexeval/metrics/access.py +28 -0
- flexeval/metrics/save.py +39 -0
- flexeval/rubric.py +62 -0
- flexeval/run_utils.py +65 -0
- flexeval/runner.py +132 -0
- flexeval/schema/__init__.py +11 -0
- flexeval/schema/config_schema.py +46 -0
- flexeval/schema/eval_schema.py +163 -0
- flexeval/schema/evalrun_schema.py +97 -0
- flexeval/schema/rubric_schema.py +40 -0
- flexeval/schema/schema_utils.py +26 -0
- python_flexeval-0.1.5.dist-info/METADATA +118 -0
- python_flexeval-0.1.5.dist-info/RECORD +47 -0
- python_flexeval-0.1.5.dist-info/WHEEL +4 -0
- python_flexeval-0.1.5.dist-info/entry_points.txt +2 -0
- python_flexeval-0.1.5.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import peewee as pw
|
|
2
|
+
|
|
3
|
+
from flexeval.classes.base import BaseModel
|
|
4
|
+
from flexeval.classes.dataset import Dataset
|
|
5
|
+
from flexeval.classes.eval_set_run import EvalSetRun
|
|
6
|
+
from flexeval.classes.message import Message
|
|
7
|
+
from flexeval.classes.thread import Thread
|
|
8
|
+
from flexeval.classes.turn import Turn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ToolCall(BaseModel):
    """Record of one tool call, linked to the turn and message it belongs to.

    Besides the links to its eval set run, dataset, thread, message and turn,
    each record stores the name of the called function, the arguments it was
    given, the tool call id, and the content of the response.
    """

    id = pw.IntegerField(primary_key=True)

    # Records this tool call belongs to; each exposes a ``toolcalls`` backref.
    evalsetrun = pw.ForeignKeyField(EvalSetRun, backref="toolcalls")
    dataset = pw.ForeignKeyField(Dataset, backref="toolcalls")
    thread = pw.ForeignKeyField(Thread, backref="toolcalls")
    message = pw.ForeignKeyField(Message, backref="toolcalls")
    turn = pw.ForeignKeyField(Turn, backref="toolcalls")

    function_name = pw.TextField()
    args = pw.TextField()
    # holds any additional info we want to save with a tool call
    additional_kwargs = pw.TextField()
    tool_call_id = pw.TextField()
    response_content = pw.TextField()

    def __init__(self, **kwargs):
        """Pass ``kwargs`` to the base model and start with no metrics queued."""
        super().__init__(**kwargs)
        self.metrics_to_evaluate = []

    def get_dict_representation(self) -> dict:
        """Return this tool call as a plain dictionary.

        Intended for function metrics that need a standard Python data
        structure rather than a model instance.

        Returns:
            dict: ``role`` (always ``"toolcall"``), ``content`` (the response
            content), ``args`` and ``function_name``.
        """
        representation = {"role": "toolcall"}
        representation["content"] = self.response_content
        representation["args"] = self.args
        representation["function_name"] = self.function_name
        return representation
|
flexeval/classes/turn.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import peewee as pw
|
|
6
|
+
from playhouse.shortcuts import model_to_dict
|
|
7
|
+
|
|
8
|
+
from flexeval.classes.base import BaseModel
|
|
9
|
+
from flexeval.classes.dataset import Dataset
|
|
10
|
+
from flexeval.classes.eval_set_run import EvalSetRun
|
|
11
|
+
from flexeval.classes.thread import Thread
|
|
12
|
+
from flexeval.configuration import completion_functions
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Turn(BaseModel):
    """Holds a single turn.

    In a conversational exchange, each 'Turn' holds information
    from 1 or more outputs from the same source or role in sequence.
    """

    id = pw.IntegerField(primary_key=True)

    evalsetrun = pw.ForeignKeyField(EvalSetRun, backref="turns")
    dataset = pw.ForeignKeyField(Dataset, backref="turns")
    thread = pw.ForeignKeyField(Thread, backref="turns")
    index_in_thread = pw.IntegerField()
    role = pw.TextField()

    def __init__(self, **kwargs):
        """Pass ``kwargs`` to the base model and start with no metrics queued."""
        super().__init__(**kwargs)
        self.metrics_to_evaluate = []

    @staticmethod
    def _average_per_choice(token_count, n_choices):
        """Return ``token_count / n_choices``, or None when it cannot be computed."""
        if token_count is None or n_choices == 0:
            return None
        return token_count / n_choices

    def _resolve_system_prompt(self):
        """Return the system prompt set on this turn, else the thread's."""
        # TODO this is a bit hacky; it allows for an override of the system prompt by setting it on the Turn object
        if hasattr(self, "system_prompt"):
            return self.system_prompt
        return self.thread.system_prompt

    def get_completion(self):
        """Request a completion that continues the conversation from this turn.

        The completion function name and its keyword arguments are read from
        the JSON stored in ``self.evalsetrun.completion_llm``; the function is
        looked up by name on the ``completion_functions`` module.

        Returns:
            list[dict] | None: One result dictionary per entry in
            ``completion["choices"]``. None when this is not the final turn
            in the input, when no callable completion function with the
            configured name exists, or when the function returns None.
        """
        # only get a completion if this is the final turn - we probably don't want to branch from mid-conversation
        # NOTE(review): is_final_turn_in_input, datasetrow and turn_number are
        # not fields declared on this model; presumably set by the caller.
        if not self.is_final_turn_in_input:
            return None

        completion_config = json.loads(self.evalsetrun.completion_llm)
        completion_fn_name = completion_config.get("function_name", None)
        # A missing or null "kwargs" entry means "no extra keyword arguments".
        completion_function_kwargs = completion_config.get("kwargs", None) or {}

        completion_function = None
        if isinstance(completion_fn_name, str):
            completion_function = getattr(
                completion_functions, completion_fn_name, None
            )
        if not callable(completion_function):
            logger.warning(
                "In completion_functions.py: No callable function named %s found.",
                completion_fn_name,
            )
            return None

        completion = completion_function(
            conversation_history=self.get_formatted_prompt(
                include_system_prompt=False
            ),
            **completion_function_kwargs,
        )
        if completion is None:
            return None

        # "completion" will be the output of an existing completion function
        # We need to make the message object
        # and probably also a turn object

        # which means it'll have a structure like this
        # TODO - make this a requirement of the completion functions?
        # - make the completion function just return content?
        # {"choices": [{"message": {"content": "hi", "role": "assistant"}}]}
        # NOTE(review): exclude receives the id *value*, not the field object
        # that model_to_dict expects - confirm whether "id" should be dropped.
        result = model_to_dict(self, exclude=[self.id])
        result["evalsetrun"] = self.evalsetrun
        result["dataset"] = self.dataset
        result["datasetrow"] = self.datasetrow
        result["turn_number"] = self.turn_number + 1
        result["role"] = "assistant"
        result["context"] = self.get_formatted_prompt(include_system_prompt=False)
        result["is_final_turn_in_input"] = False  # b/c it's not in input
        self.is_final_turn_in_input = False
        result["is_completion"] = True
        result["completion"] = completion
        result["model"] = completion.get("model", None)

        # Token usage is reported for the whole response, so store the
        # per-choice average; None when the usage information is absent.
        # TODO - use tiktoken here instead?? this will just give the average
        usage = completion.get("usage") or {}
        n_choices = len(completion.get("choices", [1]))
        result["prompt_tokens"] = self._average_per_choice(
            usage.get("prompt_tokens", None), n_choices
        )
        result["completion_tokens"] = self._average_per_choice(
            usage.get("completion_tokens", None), n_choices
        )

        result_list = []
        for completion_number, choice in enumerate(completion["choices"], start=1):
            temp = copy.deepcopy(result)
            temp["tool_used"] = choice["message"].get("tool_calls", None)
            temp["turn"] = [choice["message"]]
            temp["content"] = choice["message"]["content"]
            temp["completion_number"] = completion_number
            result_list.append(temp)

        return result_list

    def get_context(self, include_system_prompt=False) -> list[dict[str, str]]:
        """Return the context of the first message in the turn.

        The context is stored on the message as JSON and is decoded here.

        Args:
            include_system_prompt: When False, entries whose ``role`` is
                ``"system"`` are dropped from the result.

        Returns:
            The decoded context, or an empty list when the turn has no
            messages or the first message has an empty context.
        """
        # Only the first message yielded by self.messages is consulted.
        first_message = next(iter(self.messages), None)
        if first_message is None or not first_message.context:
            return []
        context = json.loads(first_message.context)
        if not include_system_prompt:
            context = [
                cur_dict for cur_dict in context if cur_dict.get("role") != "system"
            ]
        return context

    def get_formatted_prompt(self, include_system_prompt=False):
        """Build the conversation up to and including this turn.

        Args:
            include_system_prompt: When True, a leading ``system`` entry is
                added if a system prompt is available on this turn or its
                thread.

        Returns:
            list[dict]: Optional system entry, then the (system-free) context,
            then this turn's own content from :meth:`get_content`.
        """
        formatted_prompt = []
        if include_system_prompt:
            system_prompt = self._resolve_system_prompt()
            # if system prompt not available in this thread, we have nothing to include
            if system_prompt is not None:
                formatted_prompt.append({"role": "system", "content": system_prompt})

        # TODO - we might just want a subset of this
        formatted_prompt += self.get_context()
        formatted_prompt += self.get_content()
        return formatted_prompt

    def get_content(self, include_toolcalls=True, include_tool_messages=True):
        """Return the messages and tool calls of this turn as dictionaries.

        Each message contributes ``{"role", "content"}``; each tool call
        contributes its ``get_dict_representation()`` and appears after the
        message it is associated with.

        Args:
            include_toolcalls: Pass False to leave tool calls out.
            include_tool_messages: Pass False to leave out messages whose
                ``langgraph_message_type`` is ``"ToolMessage"``.

        Returns:
            list[dict]: The content entries in iteration order.
        """
        content = []
        for message in self.messages:
            if include_tool_messages or message.langgraph_message_type != "ToolMessage":
                content.append({"role": message.role, "content": message.content})
            if include_toolcalls:
                for toolcall in message.toolcalls:
                    content.append(toolcall.get_dict_representation())

        return content

    def format_input_for_rubric(
        self, include_system_prompt: bool = False, include_tool_messages: bool = False
    ):
        """This is the 'public' method that returns the info for this Turn.

        Args:
            include_system_prompt: When True, a ``system: ...`` line is placed
                first if a system prompt is available on this turn or its
                thread.
            include_tool_messages: When False, context entries whose
                ``langgraph_role`` is ``"tool"`` and messages of type
                ``"ToolMessage"`` are left out.

        Returns:
            tuple[str, str, str, str]: ``(output, output_minus_completion,
            completion, tool_call_text)`` where ``output`` is all turns,
            ``output_minus_completion`` is all turns except the last,
            ``completion`` is the last turn, and ``tool_call_text`` describes
            the tool calls selected for printing.
        """
        output_minus_completion = ""
        if include_system_prompt:
            # The output is text, so the system prompt is rendered in the same
            # "role: content" form as every other line.
            system_prompt = self._resolve_system_prompt()
            if system_prompt is not None:
                output_minus_completion += f"system: {system_prompt}\n"

        for msg in self.get_context():
            # this outputs user: XYZ, or assistant: 123
            if msg["content"] and (
                include_tool_messages or msg.get("langgraph_role") != "tool"
            ):
                output_minus_completion += f"{msg['role']}: {msg['content']}\n"

        # Including role as prefix to account for both tool and assistant
        completion = ""
        for msg in self.get_content(include_tool_messages=include_tool_messages):
            if msg["content"]:
                completion += f"{msg['role']}: {msg['content']}\n"
        output = output_minus_completion + completion

        tool_call_text = ""
        for tc in self.toolcalls:
            printme = True
            # A tool call is only printed when its additional_kwargs JSON has
            # a truthy "print" entry.
            if hasattr(tc, "additional_kwargs"):
                if not json.loads(tc.additional_kwargs).get("print", False):
                    printme = False
            if printme:
                tool_call_text += (
                    "\n\n"
                    f"Function name: {tc.function_name}\n"
                    f"Input arguments: {tc.args}\n"
                    f"Function output: {tc.response_content}\n"
                )

        # output - all turns
        # output_minus_completion - all turns except the last
        # completion - last turn
        # tool_call_text - all tool calls
        return output, output_minus_completion, completion, tool_call_text
|
flexeval/cli.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""CLI commands."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Annotated
|
|
6
|
+
|
|
7
|
+
import typer
|
|
8
|
+
|
|
9
|
+
from flexeval import db_utils, log_utils, runner
|
|
10
|
+
from flexeval.io.parsers import yaml_parser
|
|
11
|
+
from flexeval.metrics import access
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def global_callback(
    ctx: typer.Context,
    log_level: Annotated[
        log_utils.LogLevel, typer.Option(help="Log level to use.", case_sensitive=False)
    ] = log_utils.LogLevel.INFO.value,
):
    """FlexEval offers a number of CLI commands for convenience."""
    # The docstring above is kept verbatim: Typer shows it as the CLI help.
    # Translate the chosen LogLevel into the constant set_up_logging expects.
    logging_constant = log_utils.LogLevel.get_logging_constant(log_level.value)
    log_utils.set_up_logging(logging_constant)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Typer application; global_callback is registered as its callback and the
# commands below are attached with @app.command.
app = typer.Typer(callback=global_callback)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@app.command(no_args_is_help=True)
def run(
    eval_run_yaml_path: Annotated[
        Path, typer.Argument(help="YAML file specifying the Eval Run.")
    ],
):
    """Run FlexEval using the given YAML Eval Run configuration."""
    # Parse the YAML into an Eval Run and hand it straight to the runner.
    runner.run(yaml_parser.load_eval_run_from_yaml(eval_run_yaml_path))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@app.command(no_args_is_help=True)
def summarize_metrics(
    eval_run_yaml_path: Annotated[
        Path | None,
        typer.Argument(
            help="YAML file specifying the Eval Run.",
            exists=True,
            dir_okay=False,
        ),
    ] = None,
    database_path: Annotated[
        Path | None,
        typer.Option(help="Database path.", exists=True, dir_okay=False),
    ] = None,
):
    """Print a summary of computed metrics."""
    # For each metric key, the five most common values are printed with their
    # counts. The database is taken from the Eval Run YAML when one is given,
    # otherwise from database_path; ValueError is raised when neither is set.
    if eval_run_yaml_path is not None:
        if database_path is not None:
            logger.warning(
                "Ignoring database_path since eval_run_yaml_path is provided."
            )
        eval_run = yaml_parser.load_eval_run_from_yaml(eval_run_yaml_path)
        database_path = eval_run.database_path

    if database_path is None:
        raise ValueError("Must provide an Eval Run or a database path.")
    db_utils.initialize_database(database_path)
    counts = access.count_dict_values(access.get_all_metrics())

    # Longest value text shown before it is cut and suffixed with "...".
    display_limit = 50
    print("Summary of metric value counts:")
    for key, counter in counts.items():
        print(" " + key)
        for value, count in counter.most_common(5):
            value = str(value)
            if len(value) > display_limit:
                value = value[: display_limit - 3] + "..."
            # Escape newlines for every value, truncated or not, so each
            # value stays on a single output line.
            value = value.replace("\n", "\\n")
            print(f" {value}: {count}")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@app.command(no_args_is_help=True)
def run_eval_by_name(
    input_data: Annotated[list[Path], typer.Option(help="Input data filepaths.")],
    database_path: Annotated[Path, typer.Option(help="Output database path.")],
    eval_name: str,
    evals_path: Path,
    config_path: Path,
    clear_tables: bool = False,
):
    """Run an eval by name."""
    # NOTE(review): config_path is passed before evals_path, the reverse of
    # their order in this signature - confirm against run_from_name_args.
    positional_args = (
        input_data,
        database_path,
        eval_name,
        config_path,
        evals_path,
    )
    runner.run_from_name_args(*positional_args, clear_tables=clear_tables)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def main():
    """Invoke the Typer application defined in this module."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
flexeval/completions.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""Completing conversations using LLMs."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
|
7
|
+
|
|
8
|
+
from flexeval import classes
|
|
9
|
+
from flexeval.configuration import completion_functions
|
|
10
|
+
from flexeval.schema.eval_schema import CompletionLlm
|
|
11
|
+
from flexeval.schema.evalrun_schema import EvalRun
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_completion_function(completion_llm: CompletionLlm) -> Callable:
    """Identify a completion function given the completion LLM configuration.

    The name is first looked up in
    :mod:`~flexeval.configuration.completion_functions` and then in this
    module's globals. Only callable objects are accepted from either place.

    Args:
        completion_llm (CompletionLlm): The description of the function to retrieve.

    Raises:
        ValueError: If no callable named ``completion_llm.function_name`` is
            found in :mod:`~flexeval.configuration.completion_functions` or
            in globals.

    Returns:
        Callable: The completion function.
    """
    function_name = completion_llm.function_name

    # Preferred location: the completion_functions module. A non-callable
    # attribute that happens to share the name is not a completion function.
    candidate = getattr(completion_functions, function_name, None)
    if callable(candidate):
        return candidate

    candidate = globals().get(function_name)
    if callable(candidate):
        # TODO probably don't allow this by default, and also offer a way to specify other places where completion_functions can live.
        logger.debug("Found function in globals(), which could be trouble.")
        return candidate

    raise ValueError("No completion function named " + function_name + " found.")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_completion(turn: classes.turn.Turn, completion_llm: CompletionLlm):
    """Call the configured completion function on the conversation up to ``turn``.

    The function is resolved with :func:`get_completion_function` and receives
    the turn's formatted prompt as ``conversation_history`` together with
    ``completion_llm.kwargs``.

    Returns:
        Whatever the completion function returns.
    """
    complete = get_completion_function(completion_llm)
    history = turn.get_formatted_prompt(
        include_system_prompt=completion_llm.include_system_prompt
    )
    return complete(conversation_history=history, **completion_llm.kwargs)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _iter_final_turns(evalsetrun):
    """Yield, for each thread that has turns, its turn with the highest ``index_in_thread``."""
    for thread in evalsetrun.threads:
        if len(thread.turns) == 0:
            continue
        # select last turn in thread
        yield (
            thread.turns.select()
            .order_by(classes.turn.Turn.index_in_thread.desc())
            .first()
        )


def get_completions(eval_run: EvalRun, evalsetrun: classes.eval_set_run.EvalSetRun):
    """Generate and save a completion for the last turn of every thread.

    With ``eval_run.config.max_workers == 1`` threads are handled one after
    another; otherwise the completion calls are submitted to a thread pool
    and each result is saved as it finishes. Threads without turns and
    completions that come back as None are skipped.
    """
    n_workers = eval_run.config.max_workers

    if n_workers == 1:
        for turn in _iter_final_turns(evalsetrun):
            # TODO handle exceptions appropriately (#58)
            completion = get_completion(turn, eval_run.eval.completion_llm)
            if completion is not None:
                save_completion(completion, turn, evalsetrun, eval_run)
        return

    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        # Map every submitted future back to the turn it completes.
        futures: dict[Future, classes.turn.Turn] = {
            executor.submit(
                get_completion, turn, eval_run.eval.completion_llm
            ): turn
            for turn in _iter_final_turns(evalsetrun)
        }

        for future in as_completed(futures):
            completion = future.result()
            if completion is not None:
                save_completion(completion, futures[future], evalsetrun, eval_run)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def save_completion(
    completion: dict,
    turn: classes.turn.Turn,
    evalsetrun: classes.eval_set_run.EvalSetRun,
    eval_run: EvalRun,
):
    """Store the first choice of ``completion`` as a new message.

    When ``turn`` already has the ``assistant`` role the message is attached
    to that turn; otherwise a new turn is created directly after it, with the
    role of the completion message. The new message's context is the context
    of the turn's last message plus that last message itself.

    Args:
        completion: Completion response; ``completion["choices"][0]["message"]``
            supplies the new message's ``role`` and ``content``, and the
            optional ``usage`` entry supplies the token counts.
        turn: The turn the completion continues.
        evalsetrun: The eval set run the new records belong to.
        eval_run: Supplies the completion function name stored as ``model_name``.
    """
    choices = completion["choices"]
    if len(choices) > 1:
        logger.warning(
            "We don't yet support multiple completions, using just the first one."
        )
    completed_message = choices[0]["message"]

    if turn.role != "assistant":
        target_turn = classes.turn.Turn.create(
            evalsetrun=evalsetrun,
            dataset=turn.dataset,
            thread=turn.thread,
            index_in_thread=turn.index_in_thread + 1,
            role=completed_message["role"],
        )
    else:
        # don't create a new Turn, because this completion is a continuation of an existing assistant turn
        target_turn = turn

    # The message with the highest index_in_thread in the completed turn.
    last_message = (
        turn.messages.select()
        .order_by(classes.message.Message.index_in_thread.desc())
        .first()
    )
    context_for_new_message = last_message.get_context()
    context_for_new_message.append(
        {"role": last_message.role, "content": last_message.content}
    )

    usage = completion.get("usage", {})
    classes.message.Message.create(
        evalsetrun=evalsetrun,
        dataset=turn.dataset,
        thread=turn.thread,
        turn=target_turn,
        index_in_thread=last_message.index_in_thread + 1,
        role=completed_message["role"],
        content=completed_message["content"],
        context=json.dumps(context_for_new_message),
        system_prompt=last_message.system_prompt,
        is_flexeval_completion=True,
        # TODO I have no idea what model_name is supposed to be, and the completion function name doesn't seem that useful
        model_name=eval_run.eval.completion_llm.function_name,
        prompt_tokens=usage.get("prompt_tokens", None),
        completion_tokens=usage.get("completion_tokens", None),
    )
    # TODO also save any toolcalls
|