python-flexeval 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexeval/__init__.py +11 -0
- flexeval/__main__.py +11 -0
- flexeval/classes/__init__.py +15 -0
- flexeval/classes/base.py +32 -0
- flexeval/classes/dataset.py +82 -0
- flexeval/classes/eval_runner.py +158 -0
- flexeval/classes/eval_set_run.py +32 -0
- flexeval/classes/message.py +183 -0
- flexeval/classes/metric.py +55 -0
- flexeval/classes/thread.py +79 -0
- flexeval/classes/tool_call.py +51 -0
- flexeval/classes/turn.py +206 -0
- flexeval/cli.py +104 -0
- flexeval/completions.py +147 -0
- flexeval/compute_metrics.py +788 -0
- flexeval/config.yaml +23 -0
- flexeval/configuration/__init__.py +1 -0
- flexeval/configuration/completion_functions.py +231 -0
- flexeval/configuration/evals.yaml +864 -0
- flexeval/configuration/function_metrics.py +650 -0
- flexeval/configuration/rubric_metrics.yaml +194 -0
- flexeval/data_loader.py +513 -0
- flexeval/db_utils.py +38 -0
- flexeval/dependency_graph.py +234 -0
- flexeval/eval_schema.json +256 -0
- flexeval/function_types.py +173 -0
- flexeval/helpers.py +52 -0
- flexeval/io/__init__.py +1 -0
- flexeval/io/parsers/yaml_parser.py +69 -0
- flexeval/log_utils.py +34 -0
- flexeval/metrics/__init__.py +8 -0
- flexeval/metrics/access.py +28 -0
- flexeval/metrics/save.py +39 -0
- flexeval/rubric.py +62 -0
- flexeval/run_utils.py +65 -0
- flexeval/runner.py +132 -0
- flexeval/schema/__init__.py +11 -0
- flexeval/schema/config_schema.py +46 -0
- flexeval/schema/eval_schema.py +163 -0
- flexeval/schema/evalrun_schema.py +97 -0
- flexeval/schema/rubric_schema.py +40 -0
- flexeval/schema/schema_utils.py +26 -0
- python_flexeval-0.1.5.dist-info/METADATA +118 -0
- python_flexeval-0.1.5.dist-info/RECORD +47 -0
- python_flexeval-0.1.5.dist-info/WHEEL +4 -0
- python_flexeval-0.1.5.dist-info/entry_points.txt +2 -0
- python_flexeval-0.1.5.dist-info/licenses/LICENSE +21 -0
flexeval/data_loader.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
"""Dataset loading functions. Maybe should move to :mod:`~flexeval.io`."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import pathlib
|
|
6
|
+
import random as rd
|
|
7
|
+
import sqlite3
|
|
8
|
+
|
|
9
|
+
from langchain.load.dump import dumps
|
|
10
|
+
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
|
11
|
+
|
|
12
|
+
from flexeval.classes.dataset import Dataset
|
|
13
|
+
from flexeval.classes.message import Message
|
|
14
|
+
from flexeval.classes.thread import Thread
|
|
15
|
+
from flexeval.classes.tool_call import ToolCall
|
|
16
|
+
from flexeval.classes.turn import Turn
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def load_jsonl(
    dataset: Dataset,
    filename: str | pathlib.Path,
    max_n_conversation_threads: int | None = None,
    nb_evaluations_per_thread: int | None = 1,
):
    """Load conversation threads from a JSONL file into the database.

    Every non-blank line of ``filename`` is parsed as a JSON object whose
    ``"input"`` key holds a list of message dicts (``role``, ``content`` and
    optionally ``metadata``). For each selected line, one ``Thread`` is
    created per evaluation run, one ``Message`` per non-system message, and
    turns are attached with :func:`add_turns`.

    Args:
        dataset: Dataset the created rows are attached to; its ``evalsetrun``
            is attached as well.
        filename: Path of the JSONL file.
        max_n_conversation_threads: Number of lines to sample at random
            (via the ``random`` module). ``None`` selects every line. If it
            exceeds the number of available lines, all lines are used.
        nb_evaluations_per_thread: How many times each selected thread is
            duplicated. ``None`` and values below 1 are treated as 1.

    Raises:
        json.JSONDecodeError: If a selected line is not valid JSON.
        KeyError: If a selected line has no ``"input"`` key.
        ValueError: If ``max_n_conversation_threads`` is negative (raised by
            ``random.sample``).
    """
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # the locale default, and split on real newlines only: str.splitlines()
    # also breaks on U+2028/U+2029/\x85, which may appear unescaped inside
    # JSON strings.
    with open(filename, "r", encoding="utf-8") as infile:
        all_lines = [line.rstrip("\n") for line in infile]

    # Line numbers stay the thread ids, but blank lines are never candidates
    # because json.loads() would fail on them.
    candidate_thread_ids = [
        line_number for line_number, line in enumerate(all_lines) if line.strip()
    ]
    n_available = len(candidate_thread_ids)

    if max_n_conversation_threads is None:
        max_n_conversation_threads = n_available

    if max_n_conversation_threads <= n_available:
        selected_thread_ids = rd.sample(
            candidate_thread_ids, max_n_conversation_threads
        )
    else:
        logger.debug(
            f"You requested up to '{max_n_conversation_threads}' conversations but only '{n_available}' are present in Jsonl dataset at '{filename}'."
        )
        selected_thread_ids = candidate_thread_ids

    if nb_evaluations_per_thread is None:
        nb_evaluations_per_thread = 1
    n_duplicates = max(1, nb_evaluations_per_thread)

    # Set membership keeps the scan over the file linear in its length.
    selected_lookup = set(selected_thread_ids)

    # Threads are created in file order, regardless of the sampling order.
    for thread_id, thread in enumerate(all_lines):
        if thread_id not in selected_lookup:
            continue

        # Parse once per line; every duplicate reuses the parsed messages.
        thread_input = json.loads(thread)["input"]

        # Get system prompt used in the thread - assuming only 1
        system_prompt = next(
            (
                message["content"]
                for message in thread_input
                if message["role"] == "system"
            ),
            None,
        )

        # duplicate stored threads for averaged evaluation results
        for thread_eval_run_id in range(n_duplicates):
            thread_object = Thread.create(
                evalsetrun=dataset.evalsetrun,
                dataset=dataset,
                jsonl_thread_id=thread_id,
                eval_run_thread_id=str(thread_id) + "_" + str(thread_eval_run_id),
            )

            # The context is the running list of everything that preceded a
            # message; it starts with the system prompt when there is one.
            context = []
            if system_prompt is not None:
                context.append({"role": "system", "content": system_prompt})

            index_in_thread = 0
            for message in thread_input:
                role = message.get("role", None)
                if role == "system":
                    # System message shouldn't be added as a separate message
                    continue

                # User messages get an empty system prompt; every other role
                # carries the thread's system prompt (possibly None).
                system_prompt_for_this_message = "" if role == "user" else system_prompt

                Message.create(
                    evalsetrun=dataset.evalsetrun,
                    dataset=dataset,
                    thread=thread_object,
                    index_in_thread=index_in_thread,
                    role=role,
                    content=message.get("content", None),
                    context=json.dumps(context),
                    metadata=message.get("metadata", None),
                    is_flexeval_completion=False,
                    system_prompt=system_prompt_for_this_message,
                )
                # Update context for the next message
                context.append({"role": role, "content": message.get("content", None)})
                index_in_thread += 1

            add_turns(thread_object)

    # TODO - should we add ToolCall here? Is there a standard way to represent them in jsonl?
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def load_langgraph_sqlite(
    dataset: Dataset,
    filename: str,
    max_n_conversation_threads: int | None = None,
    nb_evaluations_per_thread: int | None = 1,
):
    """Load LangGraph checkpoint threads from a SQLite file into the database.

    Reads the distinct ``thread_id`` values of the ``checkpoints`` table,
    samples up to ``max_n_conversation_threads`` of them, and for every
    selected thread (once per evaluation run) creates a ``Thread``, one
    ``Message`` per message found in the checkpoint ``writes`` metadata, the
    turns (via :func:`add_turns`) and one ``ToolCall`` per recorded tool call.

    Args:
        dataset: Dataset the created rows are attached to; its ``evalsetrun``
            is attached as well.
        filename: Path of the SQLite database written by LangGraph.
        max_n_conversation_threads: Number of threads to sample at random.
            ``None`` selects every thread. If it exceeds the number of
            threads in the table, all threads are used.
        nb_evaluations_per_thread: How many times each selected thread is
            duplicated. ``None`` and values below 1 are treated as 1.

    Raises:
        AssertionError: If the ``checkpoints`` table is missing (raised by
            :func:`verify_checkpoints_table_exists`).
        ValueError: If a tool call has no matching tool response.
        Exception: If a checkpoint's metadata ``source`` is neither
            ``"input"`` nor ``"loop"``.
    """
    # Imported locally so this function does not rely on a file-level import.
    from contextlib import closing

    serializer = JsonPlusSerializer()

    # Same convention as load_jsonl: None means one evaluation per thread.
    # Without this, max(1, None) below raises TypeError.
    if nb_evaluations_per_thread is None:
        nb_evaluations_per_thread = 1

    # A sqlite3 connection used as a context manager only commits or rolls
    # back; closing() is what actually releases the connection on exit.
    with closing(sqlite3.connect(filename)) as conn, conn:
        # Set the row factory to sqlite3.Row
        # allowing us to reference columns by name instead of index
        conn.row_factory = sqlite3.Row

        cursor = conn.cursor()
        verify_checkpoints_table_exists(cursor)

        # Sync database
        cursor.execute("PRAGMA wal_checkpoint(FULL);")

        # Make threads (aka conversations)
        cursor.execute("select distinct thread_id from checkpoints")
        thread_ids = cursor.fetchall()

        nb_threads = len(thread_ids)
        if max_n_conversation_threads is None:
            max_n_conversation_threads = nb_threads

        if max_n_conversation_threads <= nb_threads:
            selected_thread_ids = rd.sample(thread_ids, max_n_conversation_threads)
        else:
            logger.debug(
                f"You requested up to '{max_n_conversation_threads}' conversations but only '{nb_threads}' are present in Sqlite dataset at '{filename}'."
            )
            selected_thread_ids = thread_ids

        # Guarded so an empty table does not raise IndexError, and given a
        # %s placeholder so the argument is actually formatted into the record.
        if selected_thread_ids:
            logger.debug(
                " DEBUG DUPLICATE SELECT THREAD IDS\n%s", selected_thread_ids[0][0]
            )

        # duplicate stored threads for averaged evaluation results
        for thread_eval_run_id in range(max(1, nb_evaluations_per_thread)):
            for thread_id in selected_thread_ids:
                thread = Thread.create(
                    evalsetrun=dataset.evalsetrun,
                    dataset=dataset,
                    langgraph_thread_id=thread_id[0],
                    eval_run_thread_id=str(thread_id[0])
                    + "_"
                    + str(thread_eval_run_id),
                )

                # Create messages. The id is bound as a parameter so that
                # quotes inside a thread id cannot break or alter the query.
                cursor.execute(
                    "select * from checkpoints where thread_id = ?",
                    (thread.langgraph_thread_id,),
                )
                completion_list = cursor.fetchall()

                # context has to be reset at the start of every thread
                context = []
                # tool call variables
                tool_calls_dict = {}
                tool_responses_dict = {}
                tool_addional_kwargs_dict = {}
                # system prompt reset for every thread
                system_prompt = None

                for completion_row in completion_list:
                    # checkpoint is full state history
                    checkpoint = serializer.loads_typed(
                        (completion_row["type"], completion_row["checkpoint"])
                    )
                    # metadata is the state update for that row
                    metadata = json.loads(completion_row["metadata"])

                    if metadata.get("writes") is None:
                        continue

                    # Goal here is to create a data structure for EACH write/update
                    # that can be used to construct a Message object.
                    # LangGraph stores info in 'writes' in the checkpoints.metadata
                    # column, but the format differs between human and machine input.
                    # The resulting data structure should have
                    #   key (str)    -- graph 'node' that produced the message (or 'input')
                    #   value (dict) -- {"messages": [ {"id": [...], "kwargs": {...}}, ... ]}
                    # e.g.
                    # {
                    #     'node_name': {
                    #         "messages": [
                    #             {
                    #                 'id': "XYZ",
                    #                 'kwargs': {
                    #                     "content": 'text of the message',
                    #                     "additional_kwargs": {}
                    #                 },
                    #             }
                    #         ]
                    #     }
                    # }

                    # user input condition
                    if metadata.get("source") == "input":
                        # NOTE: I think with the updated logging of HumanMessage with langgraph, we don't need this case
                        # key is 'input', as in human input
                        update_dict = {"input": {"messages": []}}
                        # The very first message in input in a thread seems to
                        # include the system prompt, not a message that was sent
                        # by the user. The system prompt doesn't seem to be set
                        # anywhere else, so it is used as the thread's system prompt.
                        messagecount = 0
                        for msg in metadata["writes"]["__start__"]["messages"]:
                            if messagecount == 0 and metadata["step"] == -1:
                                system_prompt = msg["kwargs"]["content"]
                                messagecount += 1
                            else:
                                update_dict["input"]["messages"].append(
                                    {
                                        "id": [
                                            "HumanMessage"
                                        ],  # LangGraph has a list here
                                        "kwargs": {"content": msg, "type": "human"},
                                    }
                                )
                        # will be used below
                        role = "user"

                    # machine input condition
                    elif metadata.get("source") == "loop":
                        # This already has a list of messages with kwargs, etc
                        update_dict = metadata.get("writes")
                        # I think 'system_prompt' is empty by default and not stored
                        # here unless it's included in the LangGraph state
                        checkpoint_system_prompt = checkpoint.get(
                            "channel_values", {}
                        ).get("system_prompt")
                        if checkpoint_system_prompt is not None:
                            system_prompt = checkpoint_system_prompt
                        role = "assistant"
                    else:
                        raise Exception(
                            f"Unhandled input condition! Source not 'loop' or 'input'. Metadata: {metadata}"
                        )

                    # Add system prompt as first thing in context if not already present
                    if len(context) == 0:
                        context.append({"role": "system", "content": system_prompt})

                    # iterate through nodes - there is probably only 1
                    for node, value in update_dict.items():
                        if "messages" not in value:
                            continue

                        # 'messages' used to always be a list; since Feb 2025 it
                        # can be a single dict, so normalise to a list.
                        if isinstance(value["messages"], dict):
                            messagelist = [value["messages"]]
                        else:
                            messagelist = value["messages"]

                        # NOTE(review): this counter restarts for every node of
                        # every checkpoint row, so it is not a thread-wide index
                        # - confirm whether that is intended.
                        index_in_thread = 0
                        for message in messagelist:
                            message_kwargs = message.get("kwargs", {})

                            if role == "user":
                                # Human input was wrapped above, so the original
                                # message sits one 'kwargs' level deeper.
                                content = (
                                    message_kwargs.get("content", {})
                                    .get("kwargs", {})
                                    .get("content", None)
                                )
                            elif role == "assistant":
                                content = message_kwargs.get("content", None)
                            else:
                                raise Exception(
                                    "`role` should be either user or assistant."
                                )

                            tool_calls = message_kwargs.get("tool_calls", [])
                            response_metadata = message_kwargs.get(
                                "response_metadata", {}
                            )
                            token_usage = response_metadata.get("token_usage", {})

                            Message.create(
                                evalsetrun=dataset.evalsetrun,
                                dataset=dataset,
                                thread=thread,
                                index_in_thread=index_in_thread,
                                role=role,
                                content=content,
                                context=json.dumps(context),
                                is_flexeval_completion=False,
                                system_prompt=system_prompt,
                                # language model stats
                                tool_calls=json.dumps(tool_calls),
                                tool_call_ids=[tc["id"] for tc in tool_calls],
                                n_tool_calls=len(tool_calls),
                                prompt_tokens=token_usage.get("prompt_tokens"),
                                completion_tokens=token_usage.get("completion_tokens"),
                                model_name=response_metadata.get("model_name"),
                                # langgraph metadata
                                langgraph_ts=checkpoint.get("ts"),
                                langgraph_step=metadata.get("step"),
                                langgraph_thread_id=completion_row["thread_id"],
                                langgraph_checkpoint_id=completion_row["checkpoint_id"],
                                langgraph_parent_checkpoint_id=completion_row[
                                    "parent_checkpoint_id"
                                ],
                                # Have to re-dump this because of the de-serialization
                                langgraph_checkpoint=dumps(checkpoint),
                                langgraph_metadata=completion_row["metadata"],
                                langgraph_node=node,
                                langgraph_message_type=message["id"][-1],
                                langgraph_type=message_kwargs.get("type"),
                                # special property of state
                                langchain_print=message_kwargs.get(
                                    "additional_kwargs", {}
                                ).get("print", False),
                            )

                            # update the context for the next Message
                            context.append(
                                {
                                    "role": role,
                                    "content": content,
                                    "langgraph_role": message["id"][-1],
                                }
                            )

                            # record tool call info so we can match them up later
                            if message_kwargs.get("type") == "tool":
                                # mapping between tool_call_id and the RESPONSE to the tool call
                                tool_responses_dict[
                                    message_kwargs.get("tool_call_id")
                                ] = message_kwargs.get("content", "")
                            else:
                                for tool_call in tool_calls:
                                    # all the info about the tool calls, including
                                    # additional_kwargs, but NOT their responses
                                    tool_calls_dict[tool_call["id"]] = tool_call
                                    tool_addional_kwargs_dict[tool_call["id"]] = (
                                        message_kwargs.get("additional_kwargs", {})
                                    )
                            index_in_thread += 1

                # Add turns to each message.
                # Need to do this before dealing with tool calls, since turns are
                # associated with tool calls via messages during the .create() method
                add_turns(thread)

                ## Match up tool calls and make an object for each match
                for tool_call_id, tool_call_vals in tool_calls_dict.items():
                    if tool_call_id not in tool_responses_dict:
                        raise ValueError(
                            f"Found a tool call without a tool response! id='{tool_call_id}'"
                        )
                    # First message whose tool_call_ids mention this id.
                    matching_message = [
                        m for m in thread.messages if tool_call_id in m.tool_call_ids
                    ][0]

                    ToolCall.create(
                        evalsetrun=dataset.evalsetrun,
                        dataset=dataset,
                        thread=thread,
                        turn=matching_message.turn,
                        message=matching_message,
                        function_name=tool_call_vals.get("name"),
                        args=json.dumps(tool_call_vals.get("args")),
                        additional_kwargs=json.dumps(
                            tool_addional_kwargs_dict.get(tool_call_id)
                        ),
                        tool_call_id=tool_call_id,
                        response_content=tool_responses_dict.get(tool_call_id),
                    )

                ## Add system prompt if available?
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def add_turns(thread: Thread):
    """Create the turns of ``thread`` and link every message to its turn.

    :func:`get_turns` supplies one placeholder turn id per message plus a
    mapping from placeholder turn id to role. One ``Turn`` is created per
    placeholder id (``index_in_thread`` counts from 0 in mapping order), then
    each message gets its ``turn`` set and is saved.

    Args:
        thread: Thread whose ``messages`` should be grouped into turns.
    """
    # Step 1 - placeholder turn id for each message, and the role of each turn
    message_placeholder_ids, turn_dict = get_turns(thread=thread)

    # Step 2 - create the turns, keyed by their placeholder id
    turns = {}
    for index_in_thread, (placeholder_turn_id, role) in enumerate(turn_dict.items()):
        turns[placeholder_turn_id] = Turn.create(
            evalsetrun=thread.evalsetrun,
            dataset=thread.dataset,
            thread=thread,
            index_in_thread=index_in_thread,
            role=role,
        )

    # Step 3 - attach the created turn to each message. zip pairs the
    # placeholder ids with thread.messages position by position.
    for placeholder_turn_id, message in zip(message_placeholder_ids, thread.messages):
        message.turn = turns[placeholder_turn_id]
        message.save()
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def verify_checkpoints_table_exists(cursor):
    """Check that the database behind ``cursor`` has a ``checkpoints`` table.

    Args:
        cursor: Cursor of an open SQLite connection; it is used to query
            ``sqlite_master``.

    Raises:
        AssertionError: If no table named ``checkpoints`` exists. The
            exception is raised explicitly, so the check still runs when
            Python is started with ``-O`` (which strips ``assert``).
    """
    cursor.execute(
        """
        SELECT name FROM sqlite_master
        WHERE type='table' AND name='checkpoints'
    """
    )
    # fetchone() returns None when the query matched no row.
    if cursor.fetchone() is None:
        raise AssertionError("Table 'checkpoints' does not exist in the database.")
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def get_turns(thread: Thread):
    """Assign a placeholder turn id to every message of ``thread``.

    A turn is a run of 1 or more consecutive messages by the same side,
    where the side is either the machine ('assistant', 'ai' or 'tool') or
    anything else (e.g. 'user'). In other words, we would parse as follows:
        TURN 1 - user
        TURN 2 - assistant
        TURN 3 - user
        TURN 4 - assistant
        TURN 4 - tool
        TURN 4 - assistant
        TURN 5 - user

    The id starts at 1 and the side before the first message counts as
    non-machine, so a thread that opens with a machine message starts at 2.

    Side effect: every message gets ``is_final_turn_in_input`` set (True only
    for messages in the last turn) and is saved.

    Args:
        thread: Thread whose ``messages`` are labelled.

    Returns:
        A tuple ``(message_placeholder_ids, turn_id_roles)``: the placeholder
        turn id of each message, in message order, and a dict mapping each
        placeholder turn id to the role of the last message in that turn.
    """
    # these are all treated as belonging to the same 'turn'
    machine_labels = ("assistant", "ai", "tool")

    # Read the messages once so both passes below see the same objects in the
    # same order.
    messages = list(thread.messages)

    turn_id = 1
    previous_is_machine = False
    message_placeholder_ids = []
    for entry in messages:
        current_is_machine = entry.role in machine_labels
        # A new turn starts whenever the speaking side changes.
        if current_is_machine != previous_is_machine:
            turn_id += 1
        message_placeholder_ids.append(turn_id)
        previous_is_machine = current_is_machine

    # turn_id now holds the id of the final turn; use it to label messages.
    turn_id_roles = {}
    for placeholder_id, entry in zip(message_placeholder_ids, messages):
        turn_id_roles[placeholder_id] = entry.role
        entry.is_final_turn_in_input = placeholder_id == turn_id
        entry.save()

    return message_placeholder_ids, turn_id_roles
|
flexeval/db_utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Peewee database utilities."""
|
|
2
|
+
|
|
3
|
+
import peewee as pw
|
|
4
|
+
|
|
5
|
+
from flexeval.classes import base as classes_base
|
|
6
|
+
from flexeval.classes.dataset import Dataset
|
|
7
|
+
from flexeval.classes.eval_set_run import EvalSetRun
|
|
8
|
+
from flexeval.classes.message import Message
|
|
9
|
+
from flexeval.classes.metric import Metric
|
|
10
|
+
from flexeval.classes.thread import Thread
|
|
11
|
+
from flexeval.classes.tool_call import ToolCall
|
|
12
|
+
from flexeval.classes.turn import Turn
|
|
13
|
+
|
|
14
|
+
# Model classes handed as a group to drop_tables/create_tables/bind in the
# functions of this module.
DATABASE_TABLES = [EvalSetRun, Dataset, Thread, Turn, Message, ToolCall, Metric]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def initialize_database(database_path: str, clear_tables: bool = False):
    """Initialise the shared database object and create the FlexEval tables.

    Args:
        database_path: Passed to ``classes_base.database.init``.
        clear_tables: When true, every table in ``DATABASE_TABLES`` is dropped
            before the tables are created.
    """
    shared_db = classes_base.database
    shared_db.init(database_path)

    if clear_tables:
        shared_db.drop_tables(DATABASE_TABLES)
    shared_db.create_tables(DATABASE_TABLES)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def bind_to_database(database_path: str) -> pw.Database:
    """Bind the FlexEval models to a database so that ORM functionality can be used.

    A database is created with ``classes_base.create_sqlite_database`` and
    every model in ``DATABASE_TABLES`` is bound to it.

    See: https://docs.peewee-orm.com/en/latest/peewee/database.html#setting-the-database-at-run-time

    Args:
        database_path: Passed to ``classes_base.create_sqlite_database``.

    Returns:
        pw.Database: The new database created for the models to bind to.
    """
    bound_database = classes_base.create_sqlite_database(database_path)
    bound_database.bind(DATABASE_TABLES)
    # Sanity check: the base model must now report the database just bound.
    assert classes_base.BaseModel._meta.database == bound_database
    return bound_database
|