python-flexeval 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flexeval/data_loader.py CHANGED
@@ -6,7 +6,6 @@ import pathlib
6
6
  import random as rd
7
7
  import sqlite3
8
8
 
9
- from langchain.load.dump import dumps
10
9
  from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
11
10
 
12
11
  from flexeval.classes.dataset import Dataset
@@ -14,10 +13,117 @@ from flexeval.classes.message import Message
14
13
  from flexeval.classes.thread import Thread
15
14
  from flexeval.classes.tool_call import ToolCall
16
15
  from flexeval.classes.turn import Turn
16
+ from flexeval.schema.evalrun_schema import FileDataSource, FileFormatEnum
17
17
 
18
18
  logger = logging.getLogger(__name__)
19
19
 
20
20
 
21
+ def load_thread_to_dataset(
22
+ thread_id: str | int,
23
+ thread: dict,
24
+ dataset: Dataset,
25
+ eval_run_thread_id: str | None = None,
26
+ ) -> Thread:
27
+ if "input" not in thread:
28
+ raise ValueError(
29
+ f"Expected thread format is a dictionary containing at least an 'input' key. Instead, we found: {thread.keys()}"
30
+ )
31
+
32
+ # extract any metadata
33
+ thread_metadata = thread.copy()
34
+ del thread_metadata["input"]
35
+
36
+ context = []
37
+ thread_input = thread["input"]
38
+
39
+ # Get system prompt used in the thread - assuming only 1
40
+ for message in thread_input:
41
+ if message["role"] == "system":
42
+ system_prompt = message["content"]
43
+ break
44
+ else:
45
+ system_prompt = None
46
+ if system_prompt is not None:
47
+ # Add the system prompt as context
48
+ context.append({"role": "system", "content": system_prompt})
49
+
50
+ thread_object: Thread = Thread.create(
51
+ dataset=dataset,
52
+ jsonl_thread_id=thread_id,
53
+ eval_run_thread_id=eval_run_thread_id,
54
+ system_prompt=system_prompt,
55
+ metadata=json.dumps(thread_metadata),
56
+ )
57
+
58
+ # Create messages
59
+ index_in_thread = 0
60
+ for message in thread_input:
61
+ if not isinstance(message, dict):
62
+ raise ValueError(
63
+ f"Can't load unknown object type; expected dict. Check JSONL format: {message}"
64
+ )
65
+ role = message.get("role", None)
66
+ if role != "system":
67
+ # System message shouldn't be added as a separate message
68
+ system_prompt_for_this_message = ""
69
+ if role != "user":
70
+ system_prompt_for_this_message = system_prompt
71
+ message_metadata = message.copy()
72
+ if "content" in message_metadata:
73
+ del message_metadata["content"]
74
+ if "role" in message_metadata:
75
+ del message_metadata["role"]
76
+ Message.create(
77
+ dataset=dataset,
78
+ thread=thread_object,
79
+ index_in_thread=index_in_thread,
80
+ role=role,
81
+ content=message.get("content", None),
82
+ context=json.dumps(context),
83
+ is_flexeval_completion=False,
84
+ system_prompt=system_prompt_for_this_message,
85
+ metadata=json.dumps(message_metadata),
86
+ )
87
+ # Update context
88
+ context.append({"role": role, "content": message.get("content", None)})
89
+ index_in_thread += 1
90
+
91
+ add_turns(thread_object)
92
+ return thread_object
93
+
94
+
95
+ def load_file(
96
+ dataset: Dataset,
97
+ data_source: FileDataSource,
98
+ max_n_conversation_threads: int | None = None,
99
+ nb_evaluations_per_thread: int | None = 1,
100
+ ):
101
+ if data_source.format == FileFormatEnum.jsonl:
102
+ load_jsonl(
103
+ dataset=dataset,
104
+ filename=data_source.path,
105
+ max_n_conversation_threads=max_n_conversation_threads,
106
+ nb_evaluations_per_thread=nb_evaluations_per_thread,
107
+ )
108
+ elif data_source.format == FileFormatEnum.langgraph_sqlite:
109
+ load_langgraph_sqlite(
110
+ dataset=dataset,
111
+ filename=data_source.path,
112
+ max_n_conversation_threads=max_n_conversation_threads,
113
+ nb_evaluations_per_thread=nb_evaluations_per_thread,
114
+ )
115
+ else:
116
+ raise ValueError("Format not yet supported.")
117
+
118
+
119
+ def load_iterable(
120
+ dataset: Dataset,
121
+ iterable,
122
+ ):
123
+ for thread_id, thread in enumerate(iterable):
124
+ load_thread_to_dataset(thread_id, thread, dataset)
125
+
126
+
21
127
  def load_jsonl(
22
128
  dataset: Dataset,
23
129
  filename: str | pathlib.Path,
@@ -50,78 +156,16 @@ def load_jsonl(
50
156
  nb_evaluations_per_thread = 1
51
157
 
52
158
  for thread_id, thread in enumerate(all_lines):
53
- for thread_eval_run_id in range(
54
- max(1, nb_evaluations_per_thread)
55
- ): # duplicate stored threads for averaged evaluation results
56
- if thread_id in selected_thread_ids:
57
- thread_json = json.loads(thread)
58
- # extract any metadata
59
- thread_metadata = thread_json.copy()
60
- del thread_metadata["input"]
61
-
62
- context = []
63
- thread_input = thread_json["input"]
64
-
65
- # Get system prompt used in the thread - assuming only 1
66
- for message in thread_input:
67
- if message["role"] == "system":
68
- system_prompt = message["content"]
69
- break
70
- else:
71
- system_prompt = None
72
- if system_prompt is not None:
73
- # Add the system prompt as context
74
- context.append({"role": "system", "content": system_prompt})
75
-
76
- thread_object: Thread = Thread.create(
77
- evalsetrun=dataset.evalsetrun,
78
- dataset=dataset,
79
- jsonl_thread_id=thread_id,
80
- eval_run_thread_id=str(thread_id)
81
- + "_"
82
- + str(thread_eval_run_id),
83
- system_prompt=system_prompt,
84
- metadata=json.dumps(thread_metadata),
159
+ if thread_id in selected_thread_ids:
160
+ thread_json = json.loads(thread)
161
+ for thread_eval_run_id in range(
162
+ max(1, nb_evaluations_per_thread)
163
+ ): # duplicate stored threads to enable averaged per-object evaluations
164
+ eval_run_thread_id = f"{thread_id}_{thread_eval_run_id}"
165
+ load_thread_to_dataset(
166
+ thread_id, thread_json, dataset, eval_run_thread_id
85
167
  )
86
168
 
87
- # Create messages
88
- index_in_thread = 0
89
- for message in thread_input:
90
- if not isinstance(message, dict):
91
- raise ValueError(
92
- f"Can't load unknown object type; expected dict. Check JSONL format: {message}"
93
- )
94
- role = message.get("role", None)
95
- if role != "system":
96
- # System message shouldn't be added as a separate message
97
- system_prompt_for_this_message = ""
98
- if role != "user":
99
- system_prompt_for_this_message = system_prompt
100
- message_metadata = message.copy()
101
- if "content" in message_metadata:
102
- del message_metadata["content"]
103
- if "role" in message_metadata:
104
- del message_metadata["role"]
105
- Message.create(
106
- evalsetrun=dataset.evalsetrun,
107
- dataset=dataset,
108
- thread=thread_object,
109
- index_in_thread=index_in_thread,
110
- role=role,
111
- content=message.get("content", None),
112
- context=json.dumps(context),
113
- is_flexeval_completion=False,
114
- system_prompt=system_prompt_for_this_message,
115
- metadata=json.dumps(message_metadata),
116
- )
117
- # Update context
118
- context.append(
119
- {"role": role, "content": message.get("content", None)}
120
- )
121
- index_in_thread += 1
122
-
123
- add_turns(thread_object)
124
-
125
169
  # TODO - should we add ToolCall here? Is there a standard way to represent them in jsonl?
126
170
 
127
171
 
@@ -131,24 +175,22 @@ def load_langgraph_sqlite(
131
175
  max_n_conversation_threads: int | None = None,
132
176
  nb_evaluations_per_thread: int | None = 1,
133
177
  ):
178
+ """Load conversations from a LangGraph SQLite checkpoint database.
179
+
180
+ Reads the final checkpoint for each thread and extracts the cumulative
181
+ message list from channel_values.messages. Compatible with langgraph >= 1.0.
182
+ """
134
183
  serializer = JsonPlusSerializer()
135
184
 
136
185
  with sqlite3.connect(filename) as conn:
137
- # Set the row factory to sqlite3.Row
138
- # allowing us to reference columns by name instead of index
139
186
  conn.row_factory = sqlite3.Row
140
-
141
- # Create a cursor object
142
187
  cursor = conn.cursor()
143
188
  verify_checkpoints_table_exists(cursor)
144
189
 
145
- # Sync database
146
- query = "PRAGMA wal_checkpoint(FULL);"
147
- cursor.execute(query)
190
+ cursor.execute("PRAGMA wal_checkpoint(FULL);")
148
191
 
149
- # Make threads (aka conversations)
150
- query = "select distinct thread_id from checkpoints"
151
- cursor.execute(query)
192
+ # Get distinct thread IDs
193
+ cursor.execute("SELECT DISTINCT thread_id FROM checkpoints")
152
194
  thread_ids = cursor.fetchall()
153
195
 
154
196
  nb_threads = len(thread_ids)
@@ -159,260 +201,125 @@ def load_langgraph_sqlite(
159
201
  selected_thread_ids = rd.sample(thread_ids, max_n_conversation_threads)
160
202
  else:
161
203
  logger.debug(
162
- f"You requested up to '{max_n_conversation_threads}' conversations but only '{nb_threads}' are present in Sqlite dataset at '{filename}'."
204
+ f"You requested up to '{max_n_conversation_threads}' conversations "
205
+ f"but only '{nb_threads}' are present in Sqlite dataset at '{filename}'."
163
206
  )
164
207
  selected_thread_ids = thread_ids
165
208
 
166
- logger.debug(" DEBUG DUPLICATE SELECT THREAD IDS\n", selected_thread_ids[0])
209
+ for thread_eval_run_id in range(max(1, nb_evaluations_per_thread)):
210
+ for thread_id_row in selected_thread_ids:
211
+ lg_thread_id = thread_id_row[0]
212
+
213
+ # Get the final checkpoint (highest step) for this thread
214
+ cursor.execute(
215
+ """
216
+ SELECT *, json_extract(metadata, '$.step') as step
217
+ FROM checkpoints
218
+ WHERE thread_id = ?
219
+ ORDER BY json_extract(metadata, '$.step') DESC
220
+ LIMIT 1
221
+ """,
222
+ (lg_thread_id,),
223
+ )
224
+ final_row = cursor.fetchone()
225
+ if final_row is None:
226
+ logger.warning(f"No checkpoints found for thread '{lg_thread_id}'")
227
+ continue
228
+
229
+ checkpoint = serializer.loads_typed(
230
+ (final_row["type"], final_row["checkpoint"])
231
+ )
232
+ lg_messages = checkpoint.get("channel_values", {}).get("messages", [])
233
+
234
+ if not lg_messages:
235
+ logger.warning(
236
+ f"No messages in final checkpoint for thread '{lg_thread_id}'"
237
+ )
238
+ continue
167
239
 
168
- for thread_eval_run_id in range(
169
- max(1, nb_evaluations_per_thread)
170
- ): # duplicate stored threads for averaged evaluation results
171
- for thread_id in selected_thread_ids:
172
240
  thread = Thread.create(
173
- evalsetrun=dataset.evalsetrun,
174
241
  dataset=dataset,
175
- langgraph_thread_id=thread_id[0],
176
- eval_run_thread_id=str(thread_id[0])
177
- + "_"
178
- + str(thread_eval_run_id),
242
+ langgraph_thread_id=lg_thread_id,
243
+ eval_run_thread_id=f"{lg_thread_id}_{thread_eval_run_id}",
179
244
  )
180
245
 
181
- # Create messages
182
- query = f"select * from checkpoints where thread_id = '{thread.langgraph_thread_id}'"
183
- cursor.execute(query)
184
- completion_list = cursor.fetchall()
185
-
186
- # context has to be reset at the start of every thread
246
+ # Map message types to FlexEval roles
247
+ # Tools are counted as assistant per existing convention
187
248
  context = []
188
- # tool call variables
249
+ system_prompt = None
189
250
  tool_calls_dict = {}
190
251
  tool_responses_dict = {}
191
- tool_addional_kwargs_dict = {}
192
- # system prompt reset for every thread
193
- system_prompt = None
252
+ tool_additional_kwargs_dict = {}
194
253
 
195
- for completion_row in completion_list:
196
- # checkpoint is full state history
197
- checkpoint = serializer.loads_typed(
198
- (completion_row["type"], completion_row["checkpoint"])
254
+ for index_in_thread, msg in enumerate(lg_messages):
255
+ msg_type = msg.type # 'human', 'ai', 'tool'
256
+ role = "user" if msg_type == "human" else "assistant"
257
+ content = msg.content
258
+
259
+ # Extract tool call info
260
+ tool_calls = getattr(msg, "tool_calls", []) or []
261
+ tool_call_ids = [tc["id"] for tc in tool_calls]
262
+ response_meta = getattr(msg, "response_metadata", {}) or {}
263
+ token_usage = response_meta.get("token_usage", {})
264
+ additional_kwargs = getattr(msg, "additional_kwargs", {}) or {}
265
+
266
+ Message.create(
267
+ dataset=dataset,
268
+ thread=thread,
269
+ index_in_thread=index_in_thread,
270
+ role=role,
271
+ content=content,
272
+ context=json.dumps(context),
273
+ is_flexeval_completion=False,
274
+ system_prompt=system_prompt,
275
+ # language model stats
276
+ tool_calls=json.dumps(tool_calls),
277
+ tool_call_ids=tool_call_ids,
278
+ n_tool_calls=len(tool_calls),
279
+ prompt_tokens=token_usage.get("prompt_tokens"),
280
+ completion_tokens=token_usage.get("completion_tokens"),
281
+ model_name=response_meta.get("model_name"),
282
+ # langgraph metadata
283
+ langgraph_ts=checkpoint.get("ts"),
284
+ langgraph_thread_id=lg_thread_id,
285
+ langgraph_checkpoint_id=final_row["checkpoint_id"],
286
+ langgraph_parent_checkpoint_id=final_row[
287
+ "parent_checkpoint_id"
288
+ ],
289
+ langgraph_metadata=final_row["metadata"],
290
+ langgraph_message_type=msg_type,
291
+ langgraph_type=msg_type,
199
292
  )
200
- # metadata is the state update for that row
201
- metadata = json.loads(completion_row["metadata"])
202
- # IDs from langgraph
203
293
 
204
- if metadata.get("writes") is None:
205
- continue
294
+ # Build context for next message
295
+ context.append({"role": role, "content": content})
296
+
297
+ # Track tool calls and responses for ToolCall creation
298
+ if msg_type == "tool":
299
+ tool_call_id = getattr(msg, "tool_call_id", None)
300
+ if tool_call_id:
301
+ tool_responses_dict[tool_call_id] = content
206
302
  else:
207
- # Goal here is to create a data structure for EACH write/update
208
- # that can be used to construct a Message object
209
- # LangGraph stores info in 'writes' in the checkpoints.metadata column
210
- # but the format is a bit different between human and machine input
211
- # The resulting data structure should have
212
- # key (str) -- graph 'node' that produced the message (or 'human')
213
- # value (list) -- list of 'message' data structures with id, kwargs, etc
214
- # {
215
- # 'node_name':{
216
- # "messages":[
217
- # {
218
- # 'id': "XYZ"
219
- # 'kwargs':{
220
- # "content": 'text of the message',
221
- # "additional_kwargs": {}
222
- # },
223
- # }
224
- # ]
225
- #
226
- # }
227
- # }
228
-
229
- # user input condition
230
- if metadata.get("source") == "input":
231
- # NOTE: I think with the updated logging of HumanMessage with langgraph, we don't need this case
232
- update_dict = {}
233
- # this will be a dictionary we can add to
234
- # key is 'input', as in human input
235
- update_dict["input"] = {"messages": []}
236
- # print("metadata keys:", metadata["writes"].keys())
237
- # the very first message in input in a thread seems to include
238
- # the system prompt, not a message that was sent by the user.
239
- # the system promptdoesn't seem to be set anywhere else, so
240
- # using that as the system prompt for the thread.
241
- messagecount = 0
242
- for msg in metadata["writes"]["__start__"]["messages"]:
243
- if messagecount == 0 and metadata["step"] == -1:
244
- system_prompt = msg["kwargs"]["content"]
245
- messagecount += 1
246
- else:
247
- message = {}
248
- message["id"] = [
249
- "HumanMessage"
250
- ] # LangGraph has a list here
251
- message["kwargs"] = {}
252
- message["kwargs"]["content"] = msg
253
- message["kwargs"]["type"] = "human"
254
- update_dict["input"]["messages"].append(message)
255
- # will be used below
256
- role = "user"
257
-
258
- # machine input condition
259
- elif metadata.get("source") == "loop":
260
- # This already has a list of messages with kwargs, etc
261
- update_dict = metadata.get("writes")
262
- # I think 'system_prompt' is empty by default and not stored here unless
263
- # it's included in the LangGraph state
264
- checkpoint_system_prompt = checkpoint.get(
265
- "channel_values", {}
266
- ).get("system_prompt")
267
- if checkpoint_system_prompt is not None:
268
- system_prompt = checkpoint_system_prompt
269
- role = "assistant"
270
- else:
271
- raise Exception(
272
- f"Unhandled input condition! Source not 'loop' or 'input'. Metadata: {metadata}"
273
- )
274
- # Add system prompt as first thing in context if not already present
275
- if len(context) == 0:
276
- context.append({"role": "system", "content": system_prompt})
277
-
278
- # iterate through nodes - there is probably only 1
279
- for node, value in update_dict.items():
280
- # iterate through list of message updates
281
- if "messages" in value:
282
- if isinstance(value["messages"], dict):
283
- # Make this a list to iterate through - 4 Feb 2025 - used to be a list previously
284
- messagelist = [value["messages"]]
285
- else:
286
- messagelist = value["messages"]
287
- index_in_thread = 0
288
- for message in messagelist:
289
- if role == "user":
290
- content = (
291
- message.get("kwargs", {})
292
- .get("content", {})
293
- .get("kwargs", {})
294
- .get("content", None)
295
- )
296
- elif role == "assistant":
297
- content = message.get("kwargs", {}).get(
298
- "content", None
299
- )
300
- else:
301
- raise Exception(
302
- "`role` should be either user or assistant."
303
- )
304
- Message.create(
305
- evalsetrun=dataset.evalsetrun,
306
- dataset=dataset,
307
- thread=thread,
308
- index_in_thread=index_in_thread,
309
- role=role,
310
- content=content,
311
- context=json.dumps(context),
312
- is_flexeval_completion=False,
313
- system_prompt=system_prompt,
314
- # language model stats
315
- tool_calls=json.dumps(
316
- message.get("kwargs", {}).get(
317
- "tool_calls", []
318
- )
319
- ),
320
- tool_call_ids=[
321
- tc["id"]
322
- for tc in message.get("kwargs", {}).get(
323
- "tool_calls", []
324
- )
325
- ],
326
- n_tool_calls=len(
327
- message.get("kwargs", {}).get(
328
- "tool_calls", []
329
- )
330
- ),
331
- prompt_tokens=message.get("kwargs", {})
332
- .get("response_metadata", {})
333
- .get("token_usage", {})
334
- .get("prompt_tokens"),
335
- completion_tokens=message.get("kwargs", {})
336
- .get("response_metadata", {})
337
- .get("token_usage", {})
338
- .get("completion_tokens"),
339
- model_name=message.get("kwargs", {})
340
- .get("response_metadata", {})
341
- .get("model_name"),
342
- # langgraph metadata
343
- langgraph_ts=checkpoint.get("ts"),
344
- langgraph_step=metadata.get("step"),
345
- langgraph_thread_id=completion_row["thread_id"],
346
- langgraph_checkpoint_id=completion_row[
347
- "checkpoint_id"
348
- ],
349
- langgraph_parent_checkpoint_id=completion_row[
350
- "parent_checkpoint_id"
351
- ],
352
- langgraph_checkpoint=dumps(
353
- checkpoint
354
- ), # Have to re-dump this because of the de-serialization#completion_row["checkpoint"],
355
- langgraph_metadata=completion_row["metadata"],
356
- langgraph_node=node,
357
- langgraph_message_type=message["id"][-1],
358
- langgraph_type=message.get("kwargs", {}).get(
359
- "type"
360
- ),
361
- # special property of state
362
- langchain_print=message.get("kwargs", {})
363
- .get("additional_kwargs", {})
364
- .get("print", False),
365
- )
366
-
367
- # update the context for the next Message
368
- context.append(
369
- {
370
- "role": role,
371
- "content": content,
372
- "langgraph_role": message["id"][-1],
373
- }
374
- )
375
-
376
- # record tool call info so we can match them up later
377
- if message.get("kwargs", {}).get("type") == "tool":
378
- # this should have a mapping between tool_call_id and the RESPONSE to to the tool call
379
- tool_responses_dict[
380
- message.get("kwargs", {}).get(
381
- "tool_call_id"
382
- )
383
- ] = message.get("kwargs", {}).get("content", "")
384
- else:
385
- for tool_call in message.get("kwargs", {}).get(
386
- "tool_calls", []
387
- ):
388
- # this should have all the info about the tool calls, including additional_kwargs
389
- # but NOT their responses
390
- tool_calls_dict[tool_call["id"]] = tool_call
391
- tool_addional_kwargs_dict[
392
- tool_call["id"]
393
- ] = message.get("kwargs", {}).get(
394
- "additional_kwargs", {}
395
- )
396
- index_in_thread += 1
397
-
398
- # Add turns to each message
399
- # Need to do this before dealing with tool calls, since we
400
- # associated turns with tool calls via messages during the .create() method
303
+ for tc in tool_calls:
304
+ tool_calls_dict[tc["id"]] = tc
305
+ tool_additional_kwargs_dict[tc["id"]] = additional_kwargs
306
+
307
+ # Create turns from messages
401
308
  add_turns(thread)
402
309
 
403
- ## Match up tool calls and make an object for each match
310
+ # Create ToolCall objects by matching calls to responses
404
311
  for tool_call_id, tool_call_vals in tool_calls_dict.items():
405
312
  if tool_call_id not in tool_responses_dict:
406
313
  raise ValueError(
407
314
  f"Found a tool call without a tool response! id='{tool_call_id}'"
408
315
  )
409
- # get matching message - should now be accessible through thread now?
410
316
  matching_message = [
411
- m for m in thread.messages if tool_call_id in m.tool_call_ids
317
+ m
318
+ for m in thread.messages
319
+ if tool_call_id in (m.tool_call_ids or [])
412
320
  ][0]
413
321
 
414
322
  ToolCall.create(
415
- evalsetrun=dataset.evalsetrun,
416
323
  dataset=dataset,
417
324
  thread=thread,
418
325
  turn=matching_message.turn,
@@ -420,14 +327,12 @@ def load_langgraph_sqlite(
420
327
  function_name=tool_call_vals.get("name"),
421
328
  args=json.dumps(tool_call_vals.get("args")),
422
329
  additional_kwargs=json.dumps(
423
- tool_addional_kwargs_dict.get(tool_call_id)
330
+ tool_additional_kwargs_dict.get(tool_call_id)
424
331
  ),
425
332
  tool_call_id=tool_call_id,
426
333
  response_content=tool_responses_dict.get(tool_call_id),
427
334
  )
428
335
 
429
- ## Add system prompt if available?
430
-
431
336
 
432
337
  def add_turns(thread: Thread):
433
338
  # Add turn labels
@@ -441,7 +346,6 @@ def add_turns(thread: Thread):
441
346
  index_in_thread = 0
442
347
  for placeholder_turn_id, role in turn_dict.items(): # turns.items():
443
348
  t = Turn.create(
444
- evalsetrun=thread.evalsetrun,
445
349
  dataset=thread.dataset,
446
350
  thread=thread,
447
351
  index_in_thread=index_in_thread,
@@ -462,12 +366,10 @@ def add_turns(thread: Thread):
462
366
 
463
367
  def verify_checkpoints_table_exists(cursor):
464
368
  # double check that the 'checkpoints' table exists
465
- cursor.execute(
466
- """
369
+ cursor.execute("""
467
370
  SELECT name FROM sqlite_master
468
371
  WHERE type='table' AND name='checkpoints'
469
- """
470
- )
372
+ """)
471
373
  result = cursor.fetchone()
472
374
  # Assert that the result is not None, meaning the table exists
473
375
  assert result is not None, "Table 'checkpoints' does not exist in the database."
flexeval/db_utils.py CHANGED
@@ -4,14 +4,23 @@ import peewee as pw
4
4
 
5
5
  from flexeval.classes import base as classes_base
6
6
  from flexeval.classes.dataset import Dataset
7
- from flexeval.classes.eval_set_run import EvalSetRun
7
+ from flexeval.classes.eval_set_run import EvalSetRun, EvalSetRunDatasets
8
8
  from flexeval.classes.message import Message
9
9
  from flexeval.classes.metric import Metric
10
10
  from flexeval.classes.thread import Thread
11
11
  from flexeval.classes.tool_call import ToolCall
12
12
  from flexeval.classes.turn import Turn
13
13
 
14
- DATABASE_TABLES = [EvalSetRun, Dataset, Thread, Turn, Message, ToolCall, Metric]
14
+ DATABASE_TABLES = [
15
+ EvalSetRun,
16
+ Dataset,
17
+ EvalSetRunDatasets,
18
+ Thread,
19
+ Turn,
20
+ Message,
21
+ ToolCall,
22
+ Metric,
23
+ ]
15
24
 
16
25
 
17
26
  def ensure_database(database_path: str):