python-flexeval 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexeval/__about__.py +1 -1
- flexeval/classes/dataset.py +12 -72
- flexeval/classes/eval_set_run.py +18 -7
- flexeval/classes/jsonview.py +112 -0
- flexeval/classes/message.py +16 -5
- flexeval/classes/metric.py +0 -8
- flexeval/classes/thread.py +4 -2
- flexeval/classes/tool_call.py +0 -2
- flexeval/classes/turn.py +7 -5
- flexeval/completions.py +8 -5
- flexeval/compute_metrics.py +45 -32
- flexeval/configuration/evals.yaml +2 -25
- flexeval/data_loader.py +219 -302
- flexeval/db_utils.py +11 -2
- flexeval/dependency_graph.py +3 -3
- flexeval/eval_schema.json +0 -18
- flexeval/function_types.py +2 -13
- flexeval/metrics/save.py +12 -8
- flexeval/run_utils.py +163 -17
- flexeval/runner.py +6 -14
- flexeval/schema/config_schema.py +12 -0
- flexeval/schema/eval_schema.py +3 -0
- flexeval/schema/evalrun_schema.py +41 -10
- {python_flexeval-0.2.0.dist-info → python_flexeval-0.4.0.dist-info}/METADATA +3 -3
- python_flexeval-0.4.0.dist-info/RECORD +49 -0
- {python_flexeval-0.2.0.dist-info → python_flexeval-0.4.0.dist-info}/WHEEL +1 -1
- python_flexeval-0.2.0.dist-info/RECORD +0 -48
- {python_flexeval-0.2.0.dist-info → python_flexeval-0.4.0.dist-info}/entry_points.txt +0 -0
- {python_flexeval-0.2.0.dist-info → python_flexeval-0.4.0.dist-info}/licenses/LICENSE +0 -0
flexeval/data_loader.py
CHANGED
|
@@ -6,7 +6,6 @@ import pathlib
|
|
|
6
6
|
import random as rd
|
|
7
7
|
import sqlite3
|
|
8
8
|
|
|
9
|
-
from langchain.load.dump import dumps
|
|
10
9
|
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
|
11
10
|
|
|
12
11
|
from flexeval.classes.dataset import Dataset
|
|
@@ -14,10 +13,117 @@ from flexeval.classes.message import Message
|
|
|
14
13
|
from flexeval.classes.thread import Thread
|
|
15
14
|
from flexeval.classes.tool_call import ToolCall
|
|
16
15
|
from flexeval.classes.turn import Turn
|
|
16
|
+
from flexeval.schema.evalrun_schema import FileDataSource, FileFormatEnum
|
|
17
17
|
|
|
18
18
|
logger = logging.getLogger(__name__)
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
def load_thread_to_dataset(
|
|
22
|
+
thread_id: str | int,
|
|
23
|
+
thread: dict,
|
|
24
|
+
dataset: Dataset,
|
|
25
|
+
eval_run_thread_id: str | None = None,
|
|
26
|
+
) -> Thread:
|
|
27
|
+
if "input" not in thread:
|
|
28
|
+
raise ValueError(
|
|
29
|
+
f"Expected thread format is a dictionary containing at least an 'input' key. Instead, we found: {thread.keys()}"
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# extract any metadata
|
|
33
|
+
thread_metadata = thread.copy()
|
|
34
|
+
del thread_metadata["input"]
|
|
35
|
+
|
|
36
|
+
context = []
|
|
37
|
+
thread_input = thread["input"]
|
|
38
|
+
|
|
39
|
+
# Get system prompt used in the thread - assuming only 1
|
|
40
|
+
for message in thread_input:
|
|
41
|
+
if message["role"] == "system":
|
|
42
|
+
system_prompt = message["content"]
|
|
43
|
+
break
|
|
44
|
+
else:
|
|
45
|
+
system_prompt = None
|
|
46
|
+
if system_prompt is not None:
|
|
47
|
+
# Add the system prompt as context
|
|
48
|
+
context.append({"role": "system", "content": system_prompt})
|
|
49
|
+
|
|
50
|
+
thread_object: Thread = Thread.create(
|
|
51
|
+
dataset=dataset,
|
|
52
|
+
jsonl_thread_id=thread_id,
|
|
53
|
+
eval_run_thread_id=eval_run_thread_id,
|
|
54
|
+
system_prompt=system_prompt,
|
|
55
|
+
metadata=json.dumps(thread_metadata),
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# Create messages
|
|
59
|
+
index_in_thread = 0
|
|
60
|
+
for message in thread_input:
|
|
61
|
+
if not isinstance(message, dict):
|
|
62
|
+
raise ValueError(
|
|
63
|
+
f"Can't load unknown object type; expected dict. Check JSONL format: {message}"
|
|
64
|
+
)
|
|
65
|
+
role = message.get("role", None)
|
|
66
|
+
if role != "system":
|
|
67
|
+
# System message shouldn't be added as a separate message
|
|
68
|
+
system_prompt_for_this_message = ""
|
|
69
|
+
if role != "user":
|
|
70
|
+
system_prompt_for_this_message = system_prompt
|
|
71
|
+
message_metadata = message.copy()
|
|
72
|
+
if "content" in message_metadata:
|
|
73
|
+
del message_metadata["content"]
|
|
74
|
+
if "role" in message_metadata:
|
|
75
|
+
del message_metadata["role"]
|
|
76
|
+
Message.create(
|
|
77
|
+
dataset=dataset,
|
|
78
|
+
thread=thread_object,
|
|
79
|
+
index_in_thread=index_in_thread,
|
|
80
|
+
role=role,
|
|
81
|
+
content=message.get("content", None),
|
|
82
|
+
context=json.dumps(context),
|
|
83
|
+
is_flexeval_completion=False,
|
|
84
|
+
system_prompt=system_prompt_for_this_message,
|
|
85
|
+
metadata=json.dumps(message_metadata),
|
|
86
|
+
)
|
|
87
|
+
# Update context
|
|
88
|
+
context.append({"role": role, "content": message.get("content", None)})
|
|
89
|
+
index_in_thread += 1
|
|
90
|
+
|
|
91
|
+
add_turns(thread_object)
|
|
92
|
+
return thread_object
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def load_file(
|
|
96
|
+
dataset: Dataset,
|
|
97
|
+
data_source: FileDataSource,
|
|
98
|
+
max_n_conversation_threads: int | None = None,
|
|
99
|
+
nb_evaluations_per_thread: int | None = 1,
|
|
100
|
+
):
|
|
101
|
+
if data_source.format == FileFormatEnum.jsonl:
|
|
102
|
+
load_jsonl(
|
|
103
|
+
dataset=dataset,
|
|
104
|
+
filename=data_source.path,
|
|
105
|
+
max_n_conversation_threads=max_n_conversation_threads,
|
|
106
|
+
nb_evaluations_per_thread=nb_evaluations_per_thread,
|
|
107
|
+
)
|
|
108
|
+
elif data_source.format == FileFormatEnum.langgraph_sqlite:
|
|
109
|
+
load_langgraph_sqlite(
|
|
110
|
+
dataset=dataset,
|
|
111
|
+
filename=data_source.path,
|
|
112
|
+
max_n_conversation_threads=max_n_conversation_threads,
|
|
113
|
+
nb_evaluations_per_thread=nb_evaluations_per_thread,
|
|
114
|
+
)
|
|
115
|
+
else:
|
|
116
|
+
raise ValueError("Format not yet supported.")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def load_iterable(
|
|
120
|
+
dataset: Dataset,
|
|
121
|
+
iterable,
|
|
122
|
+
):
|
|
123
|
+
for thread_id, thread in enumerate(iterable):
|
|
124
|
+
load_thread_to_dataset(thread_id, thread, dataset)
|
|
125
|
+
|
|
126
|
+
|
|
21
127
|
def load_jsonl(
|
|
22
128
|
dataset: Dataset,
|
|
23
129
|
filename: str | pathlib.Path,
|
|
@@ -50,63 +156,16 @@ def load_jsonl(
|
|
|
50
156
|
nb_evaluations_per_thread = 1
|
|
51
157
|
|
|
52
158
|
for thread_id, thread in enumerate(all_lines):
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
eval_run_thread_id=str(thread_id)
|
|
62
|
-
+ "_"
|
|
63
|
-
+ str(thread_eval_run_id),
|
|
159
|
+
if thread_id in selected_thread_ids:
|
|
160
|
+
thread_json = json.loads(thread)
|
|
161
|
+
for thread_eval_run_id in range(
|
|
162
|
+
max(1, nb_evaluations_per_thread)
|
|
163
|
+
): # duplicate stored threads to enable averaged per-object evaluations
|
|
164
|
+
eval_run_thread_id = f"{thread_id}_{thread_eval_run_id}"
|
|
165
|
+
load_thread_to_dataset(
|
|
166
|
+
thread_id, thread_json, dataset, eval_run_thread_id
|
|
64
167
|
)
|
|
65
168
|
|
|
66
|
-
# Context
|
|
67
|
-
context = []
|
|
68
|
-
thread_input = json.loads(thread)["input"]
|
|
69
|
-
|
|
70
|
-
# Get system prompt used in the thread - assuming only 1
|
|
71
|
-
for message in thread_input:
|
|
72
|
-
if message["role"] == "system":
|
|
73
|
-
system_prompt = message["content"]
|
|
74
|
-
break
|
|
75
|
-
else:
|
|
76
|
-
system_prompt = None
|
|
77
|
-
if system_prompt is not None:
|
|
78
|
-
# Add the system prompt as context
|
|
79
|
-
context.append({"role": "system", "content": system_prompt})
|
|
80
|
-
|
|
81
|
-
# Create messages
|
|
82
|
-
index_in_thread = 0
|
|
83
|
-
for message in thread_input:
|
|
84
|
-
role = message.get("role", None)
|
|
85
|
-
if role != "system":
|
|
86
|
-
# System message shouldn't be added as a separate message
|
|
87
|
-
system_prompt_for_this_message = ""
|
|
88
|
-
if role != "user":
|
|
89
|
-
system_prompt_for_this_message = system_prompt
|
|
90
|
-
Message.create(
|
|
91
|
-
evalsetrun=dataset.evalsetrun,
|
|
92
|
-
dataset=dataset,
|
|
93
|
-
thread=thread_object,
|
|
94
|
-
index_in_thread=index_in_thread,
|
|
95
|
-
role=role,
|
|
96
|
-
content=message.get("content", None),
|
|
97
|
-
context=json.dumps(context),
|
|
98
|
-
metadata=message.get("metadata", None),
|
|
99
|
-
is_flexeval_completion=False,
|
|
100
|
-
system_prompt=system_prompt_for_this_message,
|
|
101
|
-
)
|
|
102
|
-
# Update context
|
|
103
|
-
context.append(
|
|
104
|
-
{"role": role, "content": message.get("content", None)}
|
|
105
|
-
)
|
|
106
|
-
index_in_thread += 1
|
|
107
|
-
|
|
108
|
-
add_turns(thread_object)
|
|
109
|
-
|
|
110
169
|
# TODO - should we add ToolCall here? Is there a standard way to represent them in jsonl?
|
|
111
170
|
|
|
112
171
|
|
|
@@ -116,24 +175,22 @@ def load_langgraph_sqlite(
|
|
|
116
175
|
max_n_conversation_threads: int | None = None,
|
|
117
176
|
nb_evaluations_per_thread: int | None = 1,
|
|
118
177
|
):
|
|
178
|
+
"""Load conversations from a LangGraph SQLite checkpoint database.
|
|
179
|
+
|
|
180
|
+
Reads the final checkpoint for each thread and extracts the cumulative
|
|
181
|
+
message list from channel_values.messages. Compatible with langgraph >= 1.0.
|
|
182
|
+
"""
|
|
119
183
|
serializer = JsonPlusSerializer()
|
|
120
184
|
|
|
121
185
|
with sqlite3.connect(filename) as conn:
|
|
122
|
-
# Set the row factory to sqlite3.Row
|
|
123
|
-
# allowing us to reference columns by name instead of index
|
|
124
186
|
conn.row_factory = sqlite3.Row
|
|
125
|
-
|
|
126
|
-
# Create a cursor object
|
|
127
187
|
cursor = conn.cursor()
|
|
128
188
|
verify_checkpoints_table_exists(cursor)
|
|
129
189
|
|
|
130
|
-
|
|
131
|
-
query = "PRAGMA wal_checkpoint(FULL);"
|
|
132
|
-
cursor.execute(query)
|
|
190
|
+
cursor.execute("PRAGMA wal_checkpoint(FULL);")
|
|
133
191
|
|
|
134
|
-
#
|
|
135
|
-
|
|
136
|
-
cursor.execute(query)
|
|
192
|
+
# Get distinct thread IDs
|
|
193
|
+
cursor.execute("SELECT DISTINCT thread_id FROM checkpoints")
|
|
137
194
|
thread_ids = cursor.fetchall()
|
|
138
195
|
|
|
139
196
|
nb_threads = len(thread_ids)
|
|
@@ -144,260 +201,125 @@ def load_langgraph_sqlite(
|
|
|
144
201
|
selected_thread_ids = rd.sample(thread_ids, max_n_conversation_threads)
|
|
145
202
|
else:
|
|
146
203
|
logger.debug(
|
|
147
|
-
f"You requested up to '{max_n_conversation_threads}' conversations
|
|
204
|
+
f"You requested up to '{max_n_conversation_threads}' conversations "
|
|
205
|
+
f"but only '{nb_threads}' are present in Sqlite dataset at '{filename}'."
|
|
148
206
|
)
|
|
149
207
|
selected_thread_ids = thread_ids
|
|
150
208
|
|
|
151
|
-
|
|
209
|
+
for thread_eval_run_id in range(max(1, nb_evaluations_per_thread)):
|
|
210
|
+
for thread_id_row in selected_thread_ids:
|
|
211
|
+
lg_thread_id = thread_id_row[0]
|
|
212
|
+
|
|
213
|
+
# Get the final checkpoint (highest step) for this thread
|
|
214
|
+
cursor.execute(
|
|
215
|
+
"""
|
|
216
|
+
SELECT *, json_extract(metadata, '$.step') as step
|
|
217
|
+
FROM checkpoints
|
|
218
|
+
WHERE thread_id = ?
|
|
219
|
+
ORDER BY json_extract(metadata, '$.step') DESC
|
|
220
|
+
LIMIT 1
|
|
221
|
+
""",
|
|
222
|
+
(lg_thread_id,),
|
|
223
|
+
)
|
|
224
|
+
final_row = cursor.fetchone()
|
|
225
|
+
if final_row is None:
|
|
226
|
+
logger.warning(f"No checkpoints found for thread '{lg_thread_id}'")
|
|
227
|
+
continue
|
|
228
|
+
|
|
229
|
+
checkpoint = serializer.loads_typed(
|
|
230
|
+
(final_row["type"], final_row["checkpoint"])
|
|
231
|
+
)
|
|
232
|
+
lg_messages = checkpoint.get("channel_values", {}).get("messages", [])
|
|
233
|
+
|
|
234
|
+
if not lg_messages:
|
|
235
|
+
logger.warning(
|
|
236
|
+
f"No messages in final checkpoint for thread '{lg_thread_id}'"
|
|
237
|
+
)
|
|
238
|
+
continue
|
|
152
239
|
|
|
153
|
-
for thread_eval_run_id in range(
|
|
154
|
-
max(1, nb_evaluations_per_thread)
|
|
155
|
-
): # duplicate stored threads for averaged evaluation results
|
|
156
|
-
for thread_id in selected_thread_ids:
|
|
157
240
|
thread = Thread.create(
|
|
158
|
-
evalsetrun=dataset.evalsetrun,
|
|
159
241
|
dataset=dataset,
|
|
160
|
-
langgraph_thread_id=
|
|
161
|
-
eval_run_thread_id=
|
|
162
|
-
+ "_"
|
|
163
|
-
+ str(thread_eval_run_id),
|
|
242
|
+
langgraph_thread_id=lg_thread_id,
|
|
243
|
+
eval_run_thread_id=f"{lg_thread_id}_{thread_eval_run_id}",
|
|
164
244
|
)
|
|
165
245
|
|
|
166
|
-
#
|
|
167
|
-
|
|
168
|
-
cursor.execute(query)
|
|
169
|
-
completion_list = cursor.fetchall()
|
|
170
|
-
|
|
171
|
-
# context has to be reset at the start of every thread
|
|
246
|
+
# Map message types to FlexEval roles
|
|
247
|
+
# Tools are counted as assistant per existing convention
|
|
172
248
|
context = []
|
|
173
|
-
|
|
249
|
+
system_prompt = None
|
|
174
250
|
tool_calls_dict = {}
|
|
175
251
|
tool_responses_dict = {}
|
|
176
|
-
|
|
177
|
-
# system prompt reset for every thread
|
|
178
|
-
system_prompt = None
|
|
252
|
+
tool_additional_kwargs_dict = {}
|
|
179
253
|
|
|
180
|
-
for
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
254
|
+
for index_in_thread, msg in enumerate(lg_messages):
|
|
255
|
+
msg_type = msg.type # 'human', 'ai', 'tool'
|
|
256
|
+
role = "user" if msg_type == "human" else "assistant"
|
|
257
|
+
content = msg.content
|
|
258
|
+
|
|
259
|
+
# Extract tool call info
|
|
260
|
+
tool_calls = getattr(msg, "tool_calls", []) or []
|
|
261
|
+
tool_call_ids = [tc["id"] for tc in tool_calls]
|
|
262
|
+
response_meta = getattr(msg, "response_metadata", {}) or {}
|
|
263
|
+
token_usage = response_meta.get("token_usage", {})
|
|
264
|
+
additional_kwargs = getattr(msg, "additional_kwargs", {}) or {}
|
|
265
|
+
|
|
266
|
+
Message.create(
|
|
267
|
+
dataset=dataset,
|
|
268
|
+
thread=thread,
|
|
269
|
+
index_in_thread=index_in_thread,
|
|
270
|
+
role=role,
|
|
271
|
+
content=content,
|
|
272
|
+
context=json.dumps(context),
|
|
273
|
+
is_flexeval_completion=False,
|
|
274
|
+
system_prompt=system_prompt,
|
|
275
|
+
# language model stats
|
|
276
|
+
tool_calls=json.dumps(tool_calls),
|
|
277
|
+
tool_call_ids=tool_call_ids,
|
|
278
|
+
n_tool_calls=len(tool_calls),
|
|
279
|
+
prompt_tokens=token_usage.get("prompt_tokens"),
|
|
280
|
+
completion_tokens=token_usage.get("completion_tokens"),
|
|
281
|
+
model_name=response_meta.get("model_name"),
|
|
282
|
+
# langgraph metadata
|
|
283
|
+
langgraph_ts=checkpoint.get("ts"),
|
|
284
|
+
langgraph_thread_id=lg_thread_id,
|
|
285
|
+
langgraph_checkpoint_id=final_row["checkpoint_id"],
|
|
286
|
+
langgraph_parent_checkpoint_id=final_row[
|
|
287
|
+
"parent_checkpoint_id"
|
|
288
|
+
],
|
|
289
|
+
langgraph_metadata=final_row["metadata"],
|
|
290
|
+
langgraph_message_type=msg_type,
|
|
291
|
+
langgraph_type=msg_type,
|
|
184
292
|
)
|
|
185
|
-
# metadata is the state update for that row
|
|
186
|
-
metadata = json.loads(completion_row["metadata"])
|
|
187
|
-
# IDs from langgraph
|
|
188
293
|
|
|
189
|
-
|
|
190
|
-
|
|
294
|
+
# Build context for next message
|
|
295
|
+
context.append({"role": role, "content": content})
|
|
296
|
+
|
|
297
|
+
# Track tool calls and responses for ToolCall creation
|
|
298
|
+
if msg_type == "tool":
|
|
299
|
+
tool_call_id = getattr(msg, "tool_call_id", None)
|
|
300
|
+
if tool_call_id:
|
|
301
|
+
tool_responses_dict[tool_call_id] = content
|
|
191
302
|
else:
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
# key (str) -- graph 'node' that produced the message (or 'human')
|
|
198
|
-
# value (list) -- list of 'message' data structures with id, kwargs, etc
|
|
199
|
-
# {
|
|
200
|
-
# 'node_name':{
|
|
201
|
-
# "messages":[
|
|
202
|
-
# {
|
|
203
|
-
# 'id': "XYZ"
|
|
204
|
-
# 'kwargs':{
|
|
205
|
-
# "content": 'text of the message',
|
|
206
|
-
# "additional_kwargs": {}
|
|
207
|
-
# },
|
|
208
|
-
# }
|
|
209
|
-
# ]
|
|
210
|
-
#
|
|
211
|
-
# }
|
|
212
|
-
# }
|
|
213
|
-
|
|
214
|
-
# user input condition
|
|
215
|
-
if metadata.get("source") == "input":
|
|
216
|
-
# NOTE: I think with the updated logging of HumanMessage with langgraph, we don't need this case
|
|
217
|
-
update_dict = {}
|
|
218
|
-
# this will be a dictionary we can add to
|
|
219
|
-
# key is 'input', as in human input
|
|
220
|
-
update_dict["input"] = {"messages": []}
|
|
221
|
-
# print("metadata keys:", metadata["writes"].keys())
|
|
222
|
-
# the very first message in input in a thread seems to include
|
|
223
|
-
# the system prompt, not a message that was sent by the user.
|
|
224
|
-
# the system promptdoesn't seem to be set anywhere else, so
|
|
225
|
-
# using that as the system prompt for the thread.
|
|
226
|
-
messagecount = 0
|
|
227
|
-
for msg in metadata["writes"]["__start__"]["messages"]:
|
|
228
|
-
if messagecount == 0 and metadata["step"] == -1:
|
|
229
|
-
system_prompt = msg["kwargs"]["content"]
|
|
230
|
-
messagecount += 1
|
|
231
|
-
else:
|
|
232
|
-
message = {}
|
|
233
|
-
message["id"] = [
|
|
234
|
-
"HumanMessage"
|
|
235
|
-
] # LangGraph has a list here
|
|
236
|
-
message["kwargs"] = {}
|
|
237
|
-
message["kwargs"]["content"] = msg
|
|
238
|
-
message["kwargs"]["type"] = "human"
|
|
239
|
-
update_dict["input"]["messages"].append(message)
|
|
240
|
-
# will be used below
|
|
241
|
-
role = "user"
|
|
242
|
-
|
|
243
|
-
# machine input condition
|
|
244
|
-
elif metadata.get("source") == "loop":
|
|
245
|
-
# This already has a list of messages with kwargs, etc
|
|
246
|
-
update_dict = metadata.get("writes")
|
|
247
|
-
# I think 'system_prompt' is empty by default and not stored here unless
|
|
248
|
-
# it's included in the LangGraph state
|
|
249
|
-
checkpoint_system_prompt = checkpoint.get(
|
|
250
|
-
"channel_values", {}
|
|
251
|
-
).get("system_prompt")
|
|
252
|
-
if checkpoint_system_prompt is not None:
|
|
253
|
-
system_prompt = checkpoint_system_prompt
|
|
254
|
-
role = "assistant"
|
|
255
|
-
else:
|
|
256
|
-
raise Exception(
|
|
257
|
-
f"Unhandled input condition! Source not 'loop' or 'input'. Metadata: {metadata}"
|
|
258
|
-
)
|
|
259
|
-
# Add system prompt as first thing in context if not already present
|
|
260
|
-
if len(context) == 0:
|
|
261
|
-
context.append({"role": "system", "content": system_prompt})
|
|
262
|
-
|
|
263
|
-
# iterate through nodes - there is probably only 1
|
|
264
|
-
for node, value in update_dict.items():
|
|
265
|
-
# iterate through list of message updates
|
|
266
|
-
if "messages" in value:
|
|
267
|
-
if isinstance(value["messages"], dict):
|
|
268
|
-
# Make this a list to iterate through - 4 Feb 2025 - used to be a list previously
|
|
269
|
-
messagelist = [value["messages"]]
|
|
270
|
-
else:
|
|
271
|
-
messagelist = value["messages"]
|
|
272
|
-
index_in_thread = 0
|
|
273
|
-
for message in messagelist:
|
|
274
|
-
if role == "user":
|
|
275
|
-
content = (
|
|
276
|
-
message.get("kwargs", {})
|
|
277
|
-
.get("content", {})
|
|
278
|
-
.get("kwargs", {})
|
|
279
|
-
.get("content", None)
|
|
280
|
-
)
|
|
281
|
-
elif role == "assistant":
|
|
282
|
-
content = message.get("kwargs", {}).get(
|
|
283
|
-
"content", None
|
|
284
|
-
)
|
|
285
|
-
else:
|
|
286
|
-
raise Exception(
|
|
287
|
-
"`role` should be either user or assistant."
|
|
288
|
-
)
|
|
289
|
-
Message.create(
|
|
290
|
-
evalsetrun=dataset.evalsetrun,
|
|
291
|
-
dataset=dataset,
|
|
292
|
-
thread=thread,
|
|
293
|
-
index_in_thread=index_in_thread,
|
|
294
|
-
role=role,
|
|
295
|
-
content=content,
|
|
296
|
-
context=json.dumps(context),
|
|
297
|
-
is_flexeval_completion=False,
|
|
298
|
-
system_prompt=system_prompt,
|
|
299
|
-
# language model stats
|
|
300
|
-
tool_calls=json.dumps(
|
|
301
|
-
message.get("kwargs", {}).get(
|
|
302
|
-
"tool_calls", []
|
|
303
|
-
)
|
|
304
|
-
),
|
|
305
|
-
tool_call_ids=[
|
|
306
|
-
tc["id"]
|
|
307
|
-
for tc in message.get("kwargs", {}).get(
|
|
308
|
-
"tool_calls", []
|
|
309
|
-
)
|
|
310
|
-
],
|
|
311
|
-
n_tool_calls=len(
|
|
312
|
-
message.get("kwargs", {}).get(
|
|
313
|
-
"tool_calls", []
|
|
314
|
-
)
|
|
315
|
-
),
|
|
316
|
-
prompt_tokens=message.get("kwargs", {})
|
|
317
|
-
.get("response_metadata", {})
|
|
318
|
-
.get("token_usage", {})
|
|
319
|
-
.get("prompt_tokens"),
|
|
320
|
-
completion_tokens=message.get("kwargs", {})
|
|
321
|
-
.get("response_metadata", {})
|
|
322
|
-
.get("token_usage", {})
|
|
323
|
-
.get("completion_tokens"),
|
|
324
|
-
model_name=message.get("kwargs", {})
|
|
325
|
-
.get("response_metadata", {})
|
|
326
|
-
.get("model_name"),
|
|
327
|
-
# langgraph metadata
|
|
328
|
-
langgraph_ts=checkpoint.get("ts"),
|
|
329
|
-
langgraph_step=metadata.get("step"),
|
|
330
|
-
langgraph_thread_id=completion_row["thread_id"],
|
|
331
|
-
langgraph_checkpoint_id=completion_row[
|
|
332
|
-
"checkpoint_id"
|
|
333
|
-
],
|
|
334
|
-
langgraph_parent_checkpoint_id=completion_row[
|
|
335
|
-
"parent_checkpoint_id"
|
|
336
|
-
],
|
|
337
|
-
langgraph_checkpoint=dumps(
|
|
338
|
-
checkpoint
|
|
339
|
-
), # Have to re-dump this because of the de-serialization#completion_row["checkpoint"],
|
|
340
|
-
langgraph_metadata=completion_row["metadata"],
|
|
341
|
-
langgraph_node=node,
|
|
342
|
-
langgraph_message_type=message["id"][-1],
|
|
343
|
-
langgraph_type=message.get("kwargs", {}).get(
|
|
344
|
-
"type"
|
|
345
|
-
),
|
|
346
|
-
# special property of state
|
|
347
|
-
langchain_print=message.get("kwargs", {})
|
|
348
|
-
.get("additional_kwargs", {})
|
|
349
|
-
.get("print", False),
|
|
350
|
-
)
|
|
351
|
-
|
|
352
|
-
# update the context for the next Message
|
|
353
|
-
context.append(
|
|
354
|
-
{
|
|
355
|
-
"role": role,
|
|
356
|
-
"content": content,
|
|
357
|
-
"langgraph_role": message["id"][-1],
|
|
358
|
-
}
|
|
359
|
-
)
|
|
360
|
-
|
|
361
|
-
# record tool call info so we can match them up later
|
|
362
|
-
if message.get("kwargs", {}).get("type") == "tool":
|
|
363
|
-
# this should have a mapping between tool_call_id and the RESPONSE to to the tool call
|
|
364
|
-
tool_responses_dict[
|
|
365
|
-
message.get("kwargs", {}).get(
|
|
366
|
-
"tool_call_id"
|
|
367
|
-
)
|
|
368
|
-
] = message.get("kwargs", {}).get("content", "")
|
|
369
|
-
else:
|
|
370
|
-
for tool_call in message.get("kwargs", {}).get(
|
|
371
|
-
"tool_calls", []
|
|
372
|
-
):
|
|
373
|
-
# this should have all the info about the tool calls, including additional_kwargs
|
|
374
|
-
# but NOT their responses
|
|
375
|
-
tool_calls_dict[tool_call["id"]] = tool_call
|
|
376
|
-
tool_addional_kwargs_dict[
|
|
377
|
-
tool_call["id"]
|
|
378
|
-
] = message.get("kwargs", {}).get(
|
|
379
|
-
"additional_kwargs", {}
|
|
380
|
-
)
|
|
381
|
-
index_in_thread += 1
|
|
382
|
-
|
|
383
|
-
# Add turns to each message
|
|
384
|
-
# Need to do this before dealing with tool calls, since we
|
|
385
|
-
# associated turns with tool calls via messages during the .create() method
|
|
303
|
+
for tc in tool_calls:
|
|
304
|
+
tool_calls_dict[tc["id"]] = tc
|
|
305
|
+
tool_additional_kwargs_dict[tc["id"]] = additional_kwargs
|
|
306
|
+
|
|
307
|
+
# Create turns from messages
|
|
386
308
|
add_turns(thread)
|
|
387
309
|
|
|
388
|
-
|
|
310
|
+
# Create ToolCall objects by matching calls to responses
|
|
389
311
|
for tool_call_id, tool_call_vals in tool_calls_dict.items():
|
|
390
312
|
if tool_call_id not in tool_responses_dict:
|
|
391
313
|
raise ValueError(
|
|
392
314
|
f"Found a tool call without a tool response! id='{tool_call_id}'"
|
|
393
315
|
)
|
|
394
|
-
# get matching message - should now be accessible through thread now?
|
|
395
316
|
matching_message = [
|
|
396
|
-
m
|
|
317
|
+
m
|
|
318
|
+
for m in thread.messages
|
|
319
|
+
if tool_call_id in (m.tool_call_ids or [])
|
|
397
320
|
][0]
|
|
398
321
|
|
|
399
322
|
ToolCall.create(
|
|
400
|
-
evalsetrun=dataset.evalsetrun,
|
|
401
323
|
dataset=dataset,
|
|
402
324
|
thread=thread,
|
|
403
325
|
turn=matching_message.turn,
|
|
@@ -405,14 +327,12 @@ def load_langgraph_sqlite(
|
|
|
405
327
|
function_name=tool_call_vals.get("name"),
|
|
406
328
|
args=json.dumps(tool_call_vals.get("args")),
|
|
407
329
|
additional_kwargs=json.dumps(
|
|
408
|
-
|
|
330
|
+
tool_additional_kwargs_dict.get(tool_call_id)
|
|
409
331
|
),
|
|
410
332
|
tool_call_id=tool_call_id,
|
|
411
333
|
response_content=tool_responses_dict.get(tool_call_id),
|
|
412
334
|
)
|
|
413
335
|
|
|
414
|
-
## Add system prompt if available?
|
|
415
|
-
|
|
416
336
|
|
|
417
337
|
def add_turns(thread: Thread):
|
|
418
338
|
# Add turn labels
|
|
@@ -426,7 +346,6 @@ def add_turns(thread: Thread):
|
|
|
426
346
|
index_in_thread = 0
|
|
427
347
|
for placeholder_turn_id, role in turn_dict.items(): # turns.items():
|
|
428
348
|
t = Turn.create(
|
|
429
|
-
evalsetrun=thread.evalsetrun,
|
|
430
349
|
dataset=thread.dataset,
|
|
431
350
|
thread=thread,
|
|
432
351
|
index_in_thread=index_in_thread,
|
|
@@ -447,12 +366,10 @@ def add_turns(thread: Thread):
|
|
|
447
366
|
|
|
448
367
|
def verify_checkpoints_table_exists(cursor):
|
|
449
368
|
# double check that the 'checkpoints' table exists
|
|
450
|
-
cursor.execute(
|
|
451
|
-
"""
|
|
369
|
+
cursor.execute("""
|
|
452
370
|
SELECT name FROM sqlite_master
|
|
453
371
|
WHERE type='table' AND name='checkpoints'
|
|
454
|
-
"""
|
|
455
|
-
)
|
|
372
|
+
""")
|
|
456
373
|
result = cursor.fetchone()
|
|
457
374
|
# Assert that the result is not None, meaning the table exists
|
|
458
375
|
assert result is not None, "Table 'checkpoints' does not exist in the database."
|
flexeval/db_utils.py
CHANGED
|
@@ -4,14 +4,23 @@ import peewee as pw
|
|
|
4
4
|
|
|
5
5
|
from flexeval.classes import base as classes_base
|
|
6
6
|
from flexeval.classes.dataset import Dataset
|
|
7
|
-
from flexeval.classes.eval_set_run import EvalSetRun
|
|
7
|
+
from flexeval.classes.eval_set_run import EvalSetRun, EvalSetRunDatasets
|
|
8
8
|
from flexeval.classes.message import Message
|
|
9
9
|
from flexeval.classes.metric import Metric
|
|
10
10
|
from flexeval.classes.thread import Thread
|
|
11
11
|
from flexeval.classes.tool_call import ToolCall
|
|
12
12
|
from flexeval.classes.turn import Turn
|
|
13
13
|
|
|
14
|
-
DATABASE_TABLES = [
|
|
14
|
+
DATABASE_TABLES = [
|
|
15
|
+
EvalSetRun,
|
|
16
|
+
Dataset,
|
|
17
|
+
EvalSetRunDatasets,
|
|
18
|
+
Thread,
|
|
19
|
+
Turn,
|
|
20
|
+
Message,
|
|
21
|
+
ToolCall,
|
|
22
|
+
Metric,
|
|
23
|
+
]
|
|
15
24
|
|
|
16
25
|
|
|
17
26
|
def ensure_database(database_path: str):
|
flexeval/dependency_graph.py
CHANGED
|
@@ -115,9 +115,9 @@ def get_parent_metrics(all_metrics: dict, child: dict) -> tuple[list, list]:
|
|
|
115
115
|
"""metrics_graph_ordered_list will be a list of metrics in order in which they should be run
|
|
116
116
|
|
|
117
117
|
This function takes the eval represented by "child" and finds ALL evals in "all_metrics"
|
|
118
|
-
that
|
|
118
|
+
that qualify as the child's immediate parent
|
|
119
119
|
|
|
120
|
-
An eval can qualify as a parent by having a matching name, type,
|
|
120
|
+
An eval can qualify as a parent by having a matching name, type, etc.
|
|
121
121
|
At this point, we won't have enough information to decide whether the child should be run
|
|
122
122
|
(since the child might have additional requirements on the output of the parent)
|
|
123
123
|
but this is enough to tell us that the child should be run AFTER the parent.
|
|
@@ -145,7 +145,7 @@ def get_parent_metrics(all_metrics: dict, child: dict) -> tuple[list, list]:
|
|
|
145
145
|
|
|
146
146
|
# if the conditionals are listed in the depends_on entry but don't match...
|
|
147
147
|
# Only check conditionals that are explicitly specified (not None) in the requirement
|
|
148
|
-
conditionals = ["metric_level", "
|
|
148
|
+
conditionals = ["metric_level", "name", "kwargs"]
|
|
149
149
|
for conditional in conditionals:
|
|
150
150
|
if (
|
|
151
151
|
conditional in requirement
|