python-flexeval 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexeval/__about__.py +1 -1
- flexeval/classes/dataset.py +12 -72
- flexeval/classes/eval_set_run.py +18 -7
- flexeval/classes/jsonview.py +10 -5
- flexeval/classes/message.py +11 -5
- flexeval/classes/metric.py +0 -8
- flexeval/classes/thread.py +0 -2
- flexeval/classes/tool_call.py +0 -2
- flexeval/classes/turn.py +7 -5
- flexeval/completions.py +8 -5
- flexeval/compute_metrics.py +45 -32
- flexeval/configuration/evals.yaml +2 -25
- flexeval/data_loader.py +219 -317
- flexeval/db_utils.py +11 -2
- flexeval/dependency_graph.py +3 -3
- flexeval/eval_schema.json +0 -18
- flexeval/function_types.py +2 -13
- flexeval/metrics/save.py +12 -8
- flexeval/run_utils.py +163 -17
- flexeval/runner.py +6 -14
- flexeval/schema/config_schema.py +12 -0
- flexeval/schema/eval_schema.py +3 -0
- flexeval/schema/evalrun_schema.py +41 -10
- {python_flexeval-0.3.0.dist-info → python_flexeval-0.4.1.dist-info}/METADATA +3 -3
- python_flexeval-0.4.1.dist-info/RECORD +49 -0
- {python_flexeval-0.3.0.dist-info → python_flexeval-0.4.1.dist-info}/WHEEL +1 -1
- python_flexeval-0.3.0.dist-info/RECORD +0 -49
- {python_flexeval-0.3.0.dist-info → python_flexeval-0.4.1.dist-info}/entry_points.txt +0 -0
- {python_flexeval-0.3.0.dist-info → python_flexeval-0.4.1.dist-info}/licenses/LICENSE +0 -0
flexeval/data_loader.py
CHANGED
|
@@ -6,7 +6,6 @@ import pathlib
|
|
|
6
6
|
import random as rd
|
|
7
7
|
import sqlite3
|
|
8
8
|
|
|
9
|
-
from langchain.load.dump import dumps
|
|
10
9
|
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
|
11
10
|
|
|
12
11
|
from flexeval.classes.dataset import Dataset
|
|
@@ -14,10 +13,117 @@ from flexeval.classes.message import Message
|
|
|
14
13
|
from flexeval.classes.thread import Thread
|
|
15
14
|
from flexeval.classes.tool_call import ToolCall
|
|
16
15
|
from flexeval.classes.turn import Turn
|
|
16
|
+
from flexeval.schema.evalrun_schema import FileDataSource, FileFormatEnum
|
|
17
17
|
|
|
18
18
|
logger = logging.getLogger(__name__)
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
def load_thread_to_dataset(
|
|
22
|
+
thread_id: str | int,
|
|
23
|
+
thread: dict,
|
|
24
|
+
dataset: Dataset,
|
|
25
|
+
eval_run_thread_id: str | None = None,
|
|
26
|
+
) -> Thread:
|
|
27
|
+
if "input" not in thread:
|
|
28
|
+
raise ValueError(
|
|
29
|
+
f"Expected thread format is a dictionary containing at least an 'input' key. Instead, we found: {thread.keys()}"
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# extract any metadata
|
|
33
|
+
thread_metadata = thread.copy()
|
|
34
|
+
del thread_metadata["input"]
|
|
35
|
+
|
|
36
|
+
context = []
|
|
37
|
+
thread_input = thread["input"]
|
|
38
|
+
|
|
39
|
+
# Get system prompt used in the thread - assuming only 1
|
|
40
|
+
for message in thread_input:
|
|
41
|
+
if message["role"] == "system":
|
|
42
|
+
system_prompt = message["content"]
|
|
43
|
+
break
|
|
44
|
+
else:
|
|
45
|
+
system_prompt = None
|
|
46
|
+
if system_prompt is not None:
|
|
47
|
+
# Add the system prompt as context
|
|
48
|
+
context.append({"role": "system", "content": system_prompt})
|
|
49
|
+
|
|
50
|
+
thread_object: Thread = Thread.create(
|
|
51
|
+
dataset=dataset,
|
|
52
|
+
jsonl_thread_id=thread_id,
|
|
53
|
+
eval_run_thread_id=eval_run_thread_id,
|
|
54
|
+
system_prompt=system_prompt,
|
|
55
|
+
metadata=json.dumps(thread_metadata),
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# Create messages
|
|
59
|
+
index_in_thread = 0
|
|
60
|
+
for message in thread_input:
|
|
61
|
+
if not isinstance(message, dict):
|
|
62
|
+
raise ValueError(
|
|
63
|
+
f"Can't load unknown object type; expected dict. Check JSONL format: {message}"
|
|
64
|
+
)
|
|
65
|
+
role = message.get("role", None)
|
|
66
|
+
if role != "system":
|
|
67
|
+
# System message shouldn't be added as a separate message
|
|
68
|
+
system_prompt_for_this_message = ""
|
|
69
|
+
if role != "user":
|
|
70
|
+
system_prompt_for_this_message = system_prompt
|
|
71
|
+
message_metadata = message.copy()
|
|
72
|
+
if "content" in message_metadata:
|
|
73
|
+
del message_metadata["content"]
|
|
74
|
+
if "role" in message_metadata:
|
|
75
|
+
del message_metadata["role"]
|
|
76
|
+
Message.create(
|
|
77
|
+
dataset=dataset,
|
|
78
|
+
thread=thread_object,
|
|
79
|
+
index_in_thread=index_in_thread,
|
|
80
|
+
role=role,
|
|
81
|
+
content=message.get("content", None),
|
|
82
|
+
context=json.dumps(context),
|
|
83
|
+
is_flexeval_completion=False,
|
|
84
|
+
system_prompt=system_prompt_for_this_message,
|
|
85
|
+
metadata=json.dumps(message_metadata),
|
|
86
|
+
)
|
|
87
|
+
# Update context
|
|
88
|
+
context.append({"role": role, "content": message.get("content", None)})
|
|
89
|
+
index_in_thread += 1
|
|
90
|
+
|
|
91
|
+
add_turns(thread_object)
|
|
92
|
+
return thread_object
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def load_file(
|
|
96
|
+
dataset: Dataset,
|
|
97
|
+
data_source: FileDataSource,
|
|
98
|
+
max_n_conversation_threads: int | None = None,
|
|
99
|
+
nb_evaluations_per_thread: int | None = 1,
|
|
100
|
+
):
|
|
101
|
+
if data_source.format == FileFormatEnum.jsonl:
|
|
102
|
+
load_jsonl(
|
|
103
|
+
dataset=dataset,
|
|
104
|
+
filename=data_source.path,
|
|
105
|
+
max_n_conversation_threads=max_n_conversation_threads,
|
|
106
|
+
nb_evaluations_per_thread=nb_evaluations_per_thread,
|
|
107
|
+
)
|
|
108
|
+
elif data_source.format == FileFormatEnum.langgraph_sqlite:
|
|
109
|
+
load_langgraph_sqlite(
|
|
110
|
+
dataset=dataset,
|
|
111
|
+
filename=data_source.path,
|
|
112
|
+
max_n_conversation_threads=max_n_conversation_threads,
|
|
113
|
+
nb_evaluations_per_thread=nb_evaluations_per_thread,
|
|
114
|
+
)
|
|
115
|
+
else:
|
|
116
|
+
raise ValueError("Format not yet supported.")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def load_iterable(
|
|
120
|
+
dataset: Dataset,
|
|
121
|
+
iterable,
|
|
122
|
+
):
|
|
123
|
+
for thread_id, thread in enumerate(iterable):
|
|
124
|
+
load_thread_to_dataset(thread_id, thread, dataset)
|
|
125
|
+
|
|
126
|
+
|
|
21
127
|
def load_jsonl(
|
|
22
128
|
dataset: Dataset,
|
|
23
129
|
filename: str | pathlib.Path,
|
|
@@ -50,78 +156,16 @@ def load_jsonl(
|
|
|
50
156
|
nb_evaluations_per_thread = 1
|
|
51
157
|
|
|
52
158
|
for thread_id, thread in enumerate(all_lines):
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
context = []
|
|
63
|
-
thread_input = thread_json["input"]
|
|
64
|
-
|
|
65
|
-
# Get system prompt used in the thread - assuming only 1
|
|
66
|
-
for message in thread_input:
|
|
67
|
-
if message["role"] == "system":
|
|
68
|
-
system_prompt = message["content"]
|
|
69
|
-
break
|
|
70
|
-
else:
|
|
71
|
-
system_prompt = None
|
|
72
|
-
if system_prompt is not None:
|
|
73
|
-
# Add the system prompt as context
|
|
74
|
-
context.append({"role": "system", "content": system_prompt})
|
|
75
|
-
|
|
76
|
-
thread_object: Thread = Thread.create(
|
|
77
|
-
evalsetrun=dataset.evalsetrun,
|
|
78
|
-
dataset=dataset,
|
|
79
|
-
jsonl_thread_id=thread_id,
|
|
80
|
-
eval_run_thread_id=str(thread_id)
|
|
81
|
-
+ "_"
|
|
82
|
-
+ str(thread_eval_run_id),
|
|
83
|
-
system_prompt=system_prompt,
|
|
84
|
-
metadata=json.dumps(thread_metadata),
|
|
159
|
+
if thread_id in selected_thread_ids:
|
|
160
|
+
thread_json = json.loads(thread)
|
|
161
|
+
for thread_eval_run_id in range(
|
|
162
|
+
max(1, nb_evaluations_per_thread)
|
|
163
|
+
): # duplicate stored threads to enable averaged per-object evaluations
|
|
164
|
+
eval_run_thread_id = f"{thread_id}_{thread_eval_run_id}"
|
|
165
|
+
load_thread_to_dataset(
|
|
166
|
+
thread_id, thread_json, dataset, eval_run_thread_id
|
|
85
167
|
)
|
|
86
168
|
|
|
87
|
-
# Create messages
|
|
88
|
-
index_in_thread = 0
|
|
89
|
-
for message in thread_input:
|
|
90
|
-
if not isinstance(message, dict):
|
|
91
|
-
raise ValueError(
|
|
92
|
-
f"Can't load unknown object type; expected dict. Check JSONL format: {message}"
|
|
93
|
-
)
|
|
94
|
-
role = message.get("role", None)
|
|
95
|
-
if role != "system":
|
|
96
|
-
# System message shouldn't be added as a separate message
|
|
97
|
-
system_prompt_for_this_message = ""
|
|
98
|
-
if role != "user":
|
|
99
|
-
system_prompt_for_this_message = system_prompt
|
|
100
|
-
message_metadata = message.copy()
|
|
101
|
-
if "content" in message_metadata:
|
|
102
|
-
del message_metadata["content"]
|
|
103
|
-
if "role" in message_metadata:
|
|
104
|
-
del message_metadata["role"]
|
|
105
|
-
Message.create(
|
|
106
|
-
evalsetrun=dataset.evalsetrun,
|
|
107
|
-
dataset=dataset,
|
|
108
|
-
thread=thread_object,
|
|
109
|
-
index_in_thread=index_in_thread,
|
|
110
|
-
role=role,
|
|
111
|
-
content=message.get("content", None),
|
|
112
|
-
context=json.dumps(context),
|
|
113
|
-
is_flexeval_completion=False,
|
|
114
|
-
system_prompt=system_prompt_for_this_message,
|
|
115
|
-
metadata=json.dumps(message_metadata),
|
|
116
|
-
)
|
|
117
|
-
# Update context
|
|
118
|
-
context.append(
|
|
119
|
-
{"role": role, "content": message.get("content", None)}
|
|
120
|
-
)
|
|
121
|
-
index_in_thread += 1
|
|
122
|
-
|
|
123
|
-
add_turns(thread_object)
|
|
124
|
-
|
|
125
169
|
# TODO - should we add ToolCall here? Is there a standard way to represent them in jsonl?
|
|
126
170
|
|
|
127
171
|
|
|
@@ -131,24 +175,22 @@ def load_langgraph_sqlite(
|
|
|
131
175
|
max_n_conversation_threads: int | None = None,
|
|
132
176
|
nb_evaluations_per_thread: int | None = 1,
|
|
133
177
|
):
|
|
178
|
+
"""Load conversations from a LangGraph SQLite checkpoint database.
|
|
179
|
+
|
|
180
|
+
Reads the final checkpoint for each thread and extracts the cumulative
|
|
181
|
+
message list from channel_values.messages. Compatible with langgraph >= 1.0.
|
|
182
|
+
"""
|
|
134
183
|
serializer = JsonPlusSerializer()
|
|
135
184
|
|
|
136
185
|
with sqlite3.connect(filename) as conn:
|
|
137
|
-
# Set the row factory to sqlite3.Row
|
|
138
|
-
# allowing us to reference columns by name instead of index
|
|
139
186
|
conn.row_factory = sqlite3.Row
|
|
140
|
-
|
|
141
|
-
# Create a cursor object
|
|
142
187
|
cursor = conn.cursor()
|
|
143
188
|
verify_checkpoints_table_exists(cursor)
|
|
144
189
|
|
|
145
|
-
|
|
146
|
-
query = "PRAGMA wal_checkpoint(FULL);"
|
|
147
|
-
cursor.execute(query)
|
|
190
|
+
cursor.execute("PRAGMA wal_checkpoint(FULL);")
|
|
148
191
|
|
|
149
|
-
#
|
|
150
|
-
|
|
151
|
-
cursor.execute(query)
|
|
192
|
+
# Get distinct thread IDs
|
|
193
|
+
cursor.execute("SELECT DISTINCT thread_id FROM checkpoints")
|
|
152
194
|
thread_ids = cursor.fetchall()
|
|
153
195
|
|
|
154
196
|
nb_threads = len(thread_ids)
|
|
@@ -159,260 +201,125 @@ def load_langgraph_sqlite(
|
|
|
159
201
|
selected_thread_ids = rd.sample(thread_ids, max_n_conversation_threads)
|
|
160
202
|
else:
|
|
161
203
|
logger.debug(
|
|
162
|
-
f"You requested up to '{max_n_conversation_threads}' conversations
|
|
204
|
+
f"You requested up to '{max_n_conversation_threads}' conversations "
|
|
205
|
+
f"but only '{nb_threads}' are present in Sqlite dataset at '{filename}'."
|
|
163
206
|
)
|
|
164
207
|
selected_thread_ids = thread_ids
|
|
165
208
|
|
|
166
|
-
|
|
209
|
+
for thread_eval_run_id in range(max(1, nb_evaluations_per_thread)):
|
|
210
|
+
for thread_id_row in selected_thread_ids:
|
|
211
|
+
lg_thread_id = thread_id_row[0]
|
|
212
|
+
|
|
213
|
+
# Get the final checkpoint (highest step) for this thread
|
|
214
|
+
cursor.execute(
|
|
215
|
+
"""
|
|
216
|
+
SELECT *, json_extract(metadata, '$.step') as step
|
|
217
|
+
FROM checkpoints
|
|
218
|
+
WHERE thread_id = ?
|
|
219
|
+
ORDER BY json_extract(metadata, '$.step') DESC
|
|
220
|
+
LIMIT 1
|
|
221
|
+
""",
|
|
222
|
+
(lg_thread_id,),
|
|
223
|
+
)
|
|
224
|
+
final_row = cursor.fetchone()
|
|
225
|
+
if final_row is None:
|
|
226
|
+
logger.warning(f"No checkpoints found for thread '{lg_thread_id}'")
|
|
227
|
+
continue
|
|
228
|
+
|
|
229
|
+
checkpoint = serializer.loads_typed(
|
|
230
|
+
(final_row["type"], final_row["checkpoint"])
|
|
231
|
+
)
|
|
232
|
+
lg_messages = checkpoint.get("channel_values", {}).get("messages", [])
|
|
233
|
+
|
|
234
|
+
if not lg_messages:
|
|
235
|
+
logger.warning(
|
|
236
|
+
f"No messages in final checkpoint for thread '{lg_thread_id}'"
|
|
237
|
+
)
|
|
238
|
+
continue
|
|
167
239
|
|
|
168
|
-
for thread_eval_run_id in range(
|
|
169
|
-
max(1, nb_evaluations_per_thread)
|
|
170
|
-
): # duplicate stored threads for averaged evaluation results
|
|
171
|
-
for thread_id in selected_thread_ids:
|
|
172
240
|
thread = Thread.create(
|
|
173
|
-
evalsetrun=dataset.evalsetrun,
|
|
174
241
|
dataset=dataset,
|
|
175
|
-
langgraph_thread_id=
|
|
176
|
-
eval_run_thread_id=
|
|
177
|
-
+ "_"
|
|
178
|
-
+ str(thread_eval_run_id),
|
|
242
|
+
langgraph_thread_id=lg_thread_id,
|
|
243
|
+
eval_run_thread_id=f"{lg_thread_id}_{thread_eval_run_id}",
|
|
179
244
|
)
|
|
180
245
|
|
|
181
|
-
#
|
|
182
|
-
|
|
183
|
-
cursor.execute(query)
|
|
184
|
-
completion_list = cursor.fetchall()
|
|
185
|
-
|
|
186
|
-
# context has to be reset at the start of every thread
|
|
246
|
+
# Map message types to FlexEval roles
|
|
247
|
+
# Tools are counted as assistant per existing convention
|
|
187
248
|
context = []
|
|
188
|
-
|
|
249
|
+
system_prompt = None
|
|
189
250
|
tool_calls_dict = {}
|
|
190
251
|
tool_responses_dict = {}
|
|
191
|
-
|
|
192
|
-
# system prompt reset for every thread
|
|
193
|
-
system_prompt = None
|
|
252
|
+
tool_additional_kwargs_dict = {}
|
|
194
253
|
|
|
195
|
-
for
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
254
|
+
for index_in_thread, msg in enumerate(lg_messages):
|
|
255
|
+
msg_type = msg.type # 'human', 'ai', 'tool'
|
|
256
|
+
role = "user" if msg_type == "human" else "assistant"
|
|
257
|
+
content = msg.content
|
|
258
|
+
|
|
259
|
+
# Extract tool call info
|
|
260
|
+
tool_calls = getattr(msg, "tool_calls", []) or []
|
|
261
|
+
tool_call_ids = [tc["id"] for tc in tool_calls]
|
|
262
|
+
response_meta = getattr(msg, "response_metadata", {}) or {}
|
|
263
|
+
token_usage = response_meta.get("token_usage", {})
|
|
264
|
+
additional_kwargs = getattr(msg, "additional_kwargs", {}) or {}
|
|
265
|
+
|
|
266
|
+
Message.create(
|
|
267
|
+
dataset=dataset,
|
|
268
|
+
thread=thread,
|
|
269
|
+
index_in_thread=index_in_thread,
|
|
270
|
+
role=role,
|
|
271
|
+
content=content,
|
|
272
|
+
context=json.dumps(context),
|
|
273
|
+
is_flexeval_completion=False,
|
|
274
|
+
system_prompt=system_prompt,
|
|
275
|
+
# language model stats
|
|
276
|
+
tool_calls=json.dumps(tool_calls),
|
|
277
|
+
tool_call_ids=tool_call_ids,
|
|
278
|
+
n_tool_calls=len(tool_calls),
|
|
279
|
+
prompt_tokens=token_usage.get("prompt_tokens"),
|
|
280
|
+
completion_tokens=token_usage.get("completion_tokens"),
|
|
281
|
+
model_name=response_meta.get("model_name"),
|
|
282
|
+
# langgraph metadata
|
|
283
|
+
langgraph_ts=checkpoint.get("ts"),
|
|
284
|
+
langgraph_thread_id=lg_thread_id,
|
|
285
|
+
langgraph_checkpoint_id=final_row["checkpoint_id"],
|
|
286
|
+
langgraph_parent_checkpoint_id=final_row[
|
|
287
|
+
"parent_checkpoint_id"
|
|
288
|
+
],
|
|
289
|
+
langgraph_metadata=final_row["metadata"],
|
|
290
|
+
langgraph_message_type=msg_type,
|
|
291
|
+
langgraph_type=msg_type,
|
|
199
292
|
)
|
|
200
|
-
# metadata is the state update for that row
|
|
201
|
-
metadata = json.loads(completion_row["metadata"])
|
|
202
|
-
# IDs from langgraph
|
|
203
293
|
|
|
204
|
-
|
|
205
|
-
|
|
294
|
+
# Build context for next message
|
|
295
|
+
context.append({"role": role, "content": content})
|
|
296
|
+
|
|
297
|
+
# Track tool calls and responses for ToolCall creation
|
|
298
|
+
if msg_type == "tool":
|
|
299
|
+
tool_call_id = getattr(msg, "tool_call_id", None)
|
|
300
|
+
if tool_call_id:
|
|
301
|
+
tool_responses_dict[tool_call_id] = content
|
|
206
302
|
else:
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
# key (str) -- graph 'node' that produced the message (or 'human')
|
|
213
|
-
# value (list) -- list of 'message' data structures with id, kwargs, etc
|
|
214
|
-
# {
|
|
215
|
-
# 'node_name':{
|
|
216
|
-
# "messages":[
|
|
217
|
-
# {
|
|
218
|
-
# 'id': "XYZ"
|
|
219
|
-
# 'kwargs':{
|
|
220
|
-
# "content": 'text of the message',
|
|
221
|
-
# "additional_kwargs": {}
|
|
222
|
-
# },
|
|
223
|
-
# }
|
|
224
|
-
# ]
|
|
225
|
-
#
|
|
226
|
-
# }
|
|
227
|
-
# }
|
|
228
|
-
|
|
229
|
-
# user input condition
|
|
230
|
-
if metadata.get("source") == "input":
|
|
231
|
-
# NOTE: I think with the updated logging of HumanMessage with langgraph, we don't need this case
|
|
232
|
-
update_dict = {}
|
|
233
|
-
# this will be a dictionary we can add to
|
|
234
|
-
# key is 'input', as in human input
|
|
235
|
-
update_dict["input"] = {"messages": []}
|
|
236
|
-
# print("metadata keys:", metadata["writes"].keys())
|
|
237
|
-
# the very first message in input in a thread seems to include
|
|
238
|
-
# the system prompt, not a message that was sent by the user.
|
|
239
|
-
# the system promptdoesn't seem to be set anywhere else, so
|
|
240
|
-
# using that as the system prompt for the thread.
|
|
241
|
-
messagecount = 0
|
|
242
|
-
for msg in metadata["writes"]["__start__"]["messages"]:
|
|
243
|
-
if messagecount == 0 and metadata["step"] == -1:
|
|
244
|
-
system_prompt = msg["kwargs"]["content"]
|
|
245
|
-
messagecount += 1
|
|
246
|
-
else:
|
|
247
|
-
message = {}
|
|
248
|
-
message["id"] = [
|
|
249
|
-
"HumanMessage"
|
|
250
|
-
] # LangGraph has a list here
|
|
251
|
-
message["kwargs"] = {}
|
|
252
|
-
message["kwargs"]["content"] = msg
|
|
253
|
-
message["kwargs"]["type"] = "human"
|
|
254
|
-
update_dict["input"]["messages"].append(message)
|
|
255
|
-
# will be used below
|
|
256
|
-
role = "user"
|
|
257
|
-
|
|
258
|
-
# machine input condition
|
|
259
|
-
elif metadata.get("source") == "loop":
|
|
260
|
-
# This already has a list of messages with kwargs, etc
|
|
261
|
-
update_dict = metadata.get("writes")
|
|
262
|
-
# I think 'system_prompt' is empty by default and not stored here unless
|
|
263
|
-
# it's included in the LangGraph state
|
|
264
|
-
checkpoint_system_prompt = checkpoint.get(
|
|
265
|
-
"channel_values", {}
|
|
266
|
-
).get("system_prompt")
|
|
267
|
-
if checkpoint_system_prompt is not None:
|
|
268
|
-
system_prompt = checkpoint_system_prompt
|
|
269
|
-
role = "assistant"
|
|
270
|
-
else:
|
|
271
|
-
raise Exception(
|
|
272
|
-
f"Unhandled input condition! Source not 'loop' or 'input'. Metadata: {metadata}"
|
|
273
|
-
)
|
|
274
|
-
# Add system prompt as first thing in context if not already present
|
|
275
|
-
if len(context) == 0:
|
|
276
|
-
context.append({"role": "system", "content": system_prompt})
|
|
277
|
-
|
|
278
|
-
# iterate through nodes - there is probably only 1
|
|
279
|
-
for node, value in update_dict.items():
|
|
280
|
-
# iterate through list of message updates
|
|
281
|
-
if "messages" in value:
|
|
282
|
-
if isinstance(value["messages"], dict):
|
|
283
|
-
# Make this a list to iterate through - 4 Feb 2025 - used to be a list previously
|
|
284
|
-
messagelist = [value["messages"]]
|
|
285
|
-
else:
|
|
286
|
-
messagelist = value["messages"]
|
|
287
|
-
index_in_thread = 0
|
|
288
|
-
for message in messagelist:
|
|
289
|
-
if role == "user":
|
|
290
|
-
content = (
|
|
291
|
-
message.get("kwargs", {})
|
|
292
|
-
.get("content", {})
|
|
293
|
-
.get("kwargs", {})
|
|
294
|
-
.get("content", None)
|
|
295
|
-
)
|
|
296
|
-
elif role == "assistant":
|
|
297
|
-
content = message.get("kwargs", {}).get(
|
|
298
|
-
"content", None
|
|
299
|
-
)
|
|
300
|
-
else:
|
|
301
|
-
raise Exception(
|
|
302
|
-
"`role` should be either user or assistant."
|
|
303
|
-
)
|
|
304
|
-
Message.create(
|
|
305
|
-
evalsetrun=dataset.evalsetrun,
|
|
306
|
-
dataset=dataset,
|
|
307
|
-
thread=thread,
|
|
308
|
-
index_in_thread=index_in_thread,
|
|
309
|
-
role=role,
|
|
310
|
-
content=content,
|
|
311
|
-
context=json.dumps(context),
|
|
312
|
-
is_flexeval_completion=False,
|
|
313
|
-
system_prompt=system_prompt,
|
|
314
|
-
# language model stats
|
|
315
|
-
tool_calls=json.dumps(
|
|
316
|
-
message.get("kwargs", {}).get(
|
|
317
|
-
"tool_calls", []
|
|
318
|
-
)
|
|
319
|
-
),
|
|
320
|
-
tool_call_ids=[
|
|
321
|
-
tc["id"]
|
|
322
|
-
for tc in message.get("kwargs", {}).get(
|
|
323
|
-
"tool_calls", []
|
|
324
|
-
)
|
|
325
|
-
],
|
|
326
|
-
n_tool_calls=len(
|
|
327
|
-
message.get("kwargs", {}).get(
|
|
328
|
-
"tool_calls", []
|
|
329
|
-
)
|
|
330
|
-
),
|
|
331
|
-
prompt_tokens=message.get("kwargs", {})
|
|
332
|
-
.get("response_metadata", {})
|
|
333
|
-
.get("token_usage", {})
|
|
334
|
-
.get("prompt_tokens"),
|
|
335
|
-
completion_tokens=message.get("kwargs", {})
|
|
336
|
-
.get("response_metadata", {})
|
|
337
|
-
.get("token_usage", {})
|
|
338
|
-
.get("completion_tokens"),
|
|
339
|
-
model_name=message.get("kwargs", {})
|
|
340
|
-
.get("response_metadata", {})
|
|
341
|
-
.get("model_name"),
|
|
342
|
-
# langgraph metadata
|
|
343
|
-
langgraph_ts=checkpoint.get("ts"),
|
|
344
|
-
langgraph_step=metadata.get("step"),
|
|
345
|
-
langgraph_thread_id=completion_row["thread_id"],
|
|
346
|
-
langgraph_checkpoint_id=completion_row[
|
|
347
|
-
"checkpoint_id"
|
|
348
|
-
],
|
|
349
|
-
langgraph_parent_checkpoint_id=completion_row[
|
|
350
|
-
"parent_checkpoint_id"
|
|
351
|
-
],
|
|
352
|
-
langgraph_checkpoint=dumps(
|
|
353
|
-
checkpoint
|
|
354
|
-
), # Have to re-dump this because of the de-serialization#completion_row["checkpoint"],
|
|
355
|
-
langgraph_metadata=completion_row["metadata"],
|
|
356
|
-
langgraph_node=node,
|
|
357
|
-
langgraph_message_type=message["id"][-1],
|
|
358
|
-
langgraph_type=message.get("kwargs", {}).get(
|
|
359
|
-
"type"
|
|
360
|
-
),
|
|
361
|
-
# special property of state
|
|
362
|
-
langchain_print=message.get("kwargs", {})
|
|
363
|
-
.get("additional_kwargs", {})
|
|
364
|
-
.get("print", False),
|
|
365
|
-
)
|
|
366
|
-
|
|
367
|
-
# update the context for the next Message
|
|
368
|
-
context.append(
|
|
369
|
-
{
|
|
370
|
-
"role": role,
|
|
371
|
-
"content": content,
|
|
372
|
-
"langgraph_role": message["id"][-1],
|
|
373
|
-
}
|
|
374
|
-
)
|
|
375
|
-
|
|
376
|
-
# record tool call info so we can match them up later
|
|
377
|
-
if message.get("kwargs", {}).get("type") == "tool":
|
|
378
|
-
# this should have a mapping between tool_call_id and the RESPONSE to to the tool call
|
|
379
|
-
tool_responses_dict[
|
|
380
|
-
message.get("kwargs", {}).get(
|
|
381
|
-
"tool_call_id"
|
|
382
|
-
)
|
|
383
|
-
] = message.get("kwargs", {}).get("content", "")
|
|
384
|
-
else:
|
|
385
|
-
for tool_call in message.get("kwargs", {}).get(
|
|
386
|
-
"tool_calls", []
|
|
387
|
-
):
|
|
388
|
-
# this should have all the info about the tool calls, including additional_kwargs
|
|
389
|
-
# but NOT their responses
|
|
390
|
-
tool_calls_dict[tool_call["id"]] = tool_call
|
|
391
|
-
tool_addional_kwargs_dict[
|
|
392
|
-
tool_call["id"]
|
|
393
|
-
] = message.get("kwargs", {}).get(
|
|
394
|
-
"additional_kwargs", {}
|
|
395
|
-
)
|
|
396
|
-
index_in_thread += 1
|
|
397
|
-
|
|
398
|
-
# Add turns to each message
|
|
399
|
-
# Need to do this before dealing with tool calls, since we
|
|
400
|
-
# associated turns with tool calls via messages during the .create() method
|
|
303
|
+
for tc in tool_calls:
|
|
304
|
+
tool_calls_dict[tc["id"]] = tc
|
|
305
|
+
tool_additional_kwargs_dict[tc["id"]] = additional_kwargs
|
|
306
|
+
|
|
307
|
+
# Create turns from messages
|
|
401
308
|
add_turns(thread)
|
|
402
309
|
|
|
403
|
-
|
|
310
|
+
# Create ToolCall objects by matching calls to responses
|
|
404
311
|
for tool_call_id, tool_call_vals in tool_calls_dict.items():
|
|
405
312
|
if tool_call_id not in tool_responses_dict:
|
|
406
313
|
raise ValueError(
|
|
407
314
|
f"Found a tool call without a tool response! id='{tool_call_id}'"
|
|
408
315
|
)
|
|
409
|
-
# get matching message - should now be accessible through thread now?
|
|
410
316
|
matching_message = [
|
|
411
|
-
m
|
|
317
|
+
m
|
|
318
|
+
for m in thread.messages
|
|
319
|
+
if tool_call_id in (m.tool_call_ids or [])
|
|
412
320
|
][0]
|
|
413
321
|
|
|
414
322
|
ToolCall.create(
|
|
415
|
-
evalsetrun=dataset.evalsetrun,
|
|
416
323
|
dataset=dataset,
|
|
417
324
|
thread=thread,
|
|
418
325
|
turn=matching_message.turn,
|
|
@@ -420,14 +327,12 @@ def load_langgraph_sqlite(
|
|
|
420
327
|
function_name=tool_call_vals.get("name"),
|
|
421
328
|
args=json.dumps(tool_call_vals.get("args")),
|
|
422
329
|
additional_kwargs=json.dumps(
|
|
423
|
-
|
|
330
|
+
tool_additional_kwargs_dict.get(tool_call_id)
|
|
424
331
|
),
|
|
425
332
|
tool_call_id=tool_call_id,
|
|
426
333
|
response_content=tool_responses_dict.get(tool_call_id),
|
|
427
334
|
)
|
|
428
335
|
|
|
429
|
-
## Add system prompt if available?
|
|
430
|
-
|
|
431
336
|
|
|
432
337
|
def add_turns(thread: Thread):
|
|
433
338
|
# Add turn labels
|
|
@@ -441,7 +346,6 @@ def add_turns(thread: Thread):
|
|
|
441
346
|
index_in_thread = 0
|
|
442
347
|
for placeholder_turn_id, role in turn_dict.items(): # turns.items():
|
|
443
348
|
t = Turn.create(
|
|
444
|
-
evalsetrun=thread.evalsetrun,
|
|
445
349
|
dataset=thread.dataset,
|
|
446
350
|
thread=thread,
|
|
447
351
|
index_in_thread=index_in_thread,
|
|
@@ -462,12 +366,10 @@ def add_turns(thread: Thread):
|
|
|
462
366
|
|
|
463
367
|
def verify_checkpoints_table_exists(cursor):
|
|
464
368
|
# double check that the 'checkpoints' table exists
|
|
465
|
-
cursor.execute(
|
|
466
|
-
"""
|
|
369
|
+
cursor.execute("""
|
|
467
370
|
SELECT name FROM sqlite_master
|
|
468
371
|
WHERE type='table' AND name='checkpoints'
|
|
469
|
-
"""
|
|
470
|
-
)
|
|
372
|
+
""")
|
|
471
373
|
result = cursor.fetchone()
|
|
472
374
|
# Assert that the result is not None, meaning the table exists
|
|
473
375
|
assert result is not None, "Table 'checkpoints' does not exist in the database."
|
flexeval/db_utils.py
CHANGED
|
@@ -4,14 +4,23 @@ import peewee as pw
|
|
|
4
4
|
|
|
5
5
|
from flexeval.classes import base as classes_base
|
|
6
6
|
from flexeval.classes.dataset import Dataset
|
|
7
|
-
from flexeval.classes.eval_set_run import EvalSetRun
|
|
7
|
+
from flexeval.classes.eval_set_run import EvalSetRun, EvalSetRunDatasets
|
|
8
8
|
from flexeval.classes.message import Message
|
|
9
9
|
from flexeval.classes.metric import Metric
|
|
10
10
|
from flexeval.classes.thread import Thread
|
|
11
11
|
from flexeval.classes.tool_call import ToolCall
|
|
12
12
|
from flexeval.classes.turn import Turn
|
|
13
13
|
|
|
14
|
-
DATABASE_TABLES = [
|
|
14
|
+
DATABASE_TABLES = [
|
|
15
|
+
EvalSetRun,
|
|
16
|
+
Dataset,
|
|
17
|
+
EvalSetRunDatasets,
|
|
18
|
+
Thread,
|
|
19
|
+
Turn,
|
|
20
|
+
Message,
|
|
21
|
+
ToolCall,
|
|
22
|
+
Metric,
|
|
23
|
+
]
|
|
15
24
|
|
|
16
25
|
|
|
17
26
|
def ensure_database(database_path: str):
|