python-flexeval 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flexeval/data_loader.py CHANGED
@@ -6,7 +6,6 @@ import pathlib
6
6
  import random as rd
7
7
  import sqlite3
8
8
 
9
- from langchain.load.dump import dumps
10
9
  from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
11
10
 
12
11
  from flexeval.classes.dataset import Dataset
@@ -14,10 +13,117 @@ from flexeval.classes.message import Message
14
13
  from flexeval.classes.thread import Thread
15
14
  from flexeval.classes.tool_call import ToolCall
16
15
  from flexeval.classes.turn import Turn
16
+ from flexeval.schema.evalrun_schema import FileDataSource, FileFormatEnum
17
17
 
18
18
  logger = logging.getLogger(__name__)
19
19
 
20
20
 
21
+ def load_thread_to_dataset(
22
+ thread_id: str | int,
23
+ thread: dict,
24
+ dataset: Dataset,
25
+ eval_run_thread_id: str | None = None,
26
+ ) -> Thread:
27
+ if "input" not in thread:
28
+ raise ValueError(
29
+ f"Expected thread format is a dictionary containing at least an 'input' key. Instead, we found: {thread.keys()}"
30
+ )
31
+
32
+ # extract any metadata
33
+ thread_metadata = thread.copy()
34
+ del thread_metadata["input"]
35
+
36
+ context = []
37
+ thread_input = thread["input"]
38
+
39
+ # Get system prompt used in the thread - assuming only 1
40
+ for message in thread_input:
41
+ if message["role"] == "system":
42
+ system_prompt = message["content"]
43
+ break
44
+ else:
45
+ system_prompt = None
46
+ if system_prompt is not None:
47
+ # Add the system prompt as context
48
+ context.append({"role": "system", "content": system_prompt})
49
+
50
+ thread_object: Thread = Thread.create(
51
+ dataset=dataset,
52
+ jsonl_thread_id=thread_id,
53
+ eval_run_thread_id=eval_run_thread_id,
54
+ system_prompt=system_prompt,
55
+ metadata=json.dumps(thread_metadata),
56
+ )
57
+
58
+ # Create messages
59
+ index_in_thread = 0
60
+ for message in thread_input:
61
+ if not isinstance(message, dict):
62
+ raise ValueError(
63
+ f"Can't load unknown object type; expected dict. Check JSONL format: {message}"
64
+ )
65
+ role = message.get("role", None)
66
+ if role != "system":
67
+ # System message shouldn't be added as a separate message
68
+ system_prompt_for_this_message = ""
69
+ if role != "user":
70
+ system_prompt_for_this_message = system_prompt
71
+ message_metadata = message.copy()
72
+ if "content" in message_metadata:
73
+ del message_metadata["content"]
74
+ if "role" in message_metadata:
75
+ del message_metadata["role"]
76
+ Message.create(
77
+ dataset=dataset,
78
+ thread=thread_object,
79
+ index_in_thread=index_in_thread,
80
+ role=role,
81
+ content=message.get("content", None),
82
+ context=json.dumps(context),
83
+ is_flexeval_completion=False,
84
+ system_prompt=system_prompt_for_this_message,
85
+ metadata=json.dumps(message_metadata),
86
+ )
87
+ # Update context
88
+ context.append({"role": role, "content": message.get("content", None)})
89
+ index_in_thread += 1
90
+
91
+ add_turns(thread_object)
92
+ return thread_object
93
+
94
+
95
+ def load_file(
96
+ dataset: Dataset,
97
+ data_source: FileDataSource,
98
+ max_n_conversation_threads: int | None = None,
99
+ nb_evaluations_per_thread: int | None = 1,
100
+ ):
101
+ if data_source.format == FileFormatEnum.jsonl:
102
+ load_jsonl(
103
+ dataset=dataset,
104
+ filename=data_source.path,
105
+ max_n_conversation_threads=max_n_conversation_threads,
106
+ nb_evaluations_per_thread=nb_evaluations_per_thread,
107
+ )
108
+ elif data_source.format == FileFormatEnum.langgraph_sqlite:
109
+ load_langgraph_sqlite(
110
+ dataset=dataset,
111
+ filename=data_source.path,
112
+ max_n_conversation_threads=max_n_conversation_threads,
113
+ nb_evaluations_per_thread=nb_evaluations_per_thread,
114
+ )
115
+ else:
116
+ raise ValueError("Format not yet supported.")
117
+
118
+
119
+ def load_iterable(
120
+ dataset: Dataset,
121
+ iterable,
122
+ ):
123
+ for thread_id, thread in enumerate(iterable):
124
+ load_thread_to_dataset(thread_id, thread, dataset)
125
+
126
+
21
127
  def load_jsonl(
22
128
  dataset: Dataset,
23
129
  filename: str | pathlib.Path,
@@ -50,63 +156,16 @@ def load_jsonl(
50
156
  nb_evaluations_per_thread = 1
51
157
 
52
158
  for thread_id, thread in enumerate(all_lines):
53
- for thread_eval_run_id in range(
54
- max(1, nb_evaluations_per_thread)
55
- ): # duplicate stored threads for averaged evaluation results
56
- if thread_id in selected_thread_ids:
57
- thread_object = Thread.create(
58
- evalsetrun=dataset.evalsetrun,
59
- dataset=dataset,
60
- jsonl_thread_id=thread_id,
61
- eval_run_thread_id=str(thread_id)
62
- + "_"
63
- + str(thread_eval_run_id),
159
+ if thread_id in selected_thread_ids:
160
+ thread_json = json.loads(thread)
161
+ for thread_eval_run_id in range(
162
+ max(1, nb_evaluations_per_thread)
163
+ ): # duplicate stored threads to enable averaged per-object evaluations
164
+ eval_run_thread_id = f"{thread_id}_{thread_eval_run_id}"
165
+ load_thread_to_dataset(
166
+ thread_id, thread_json, dataset, eval_run_thread_id
64
167
  )
65
168
 
66
- # Context
67
- context = []
68
- thread_input = json.loads(thread)["input"]
69
-
70
- # Get system prompt used in the thread - assuming only 1
71
- for message in thread_input:
72
- if message["role"] == "system":
73
- system_prompt = message["content"]
74
- break
75
- else:
76
- system_prompt = None
77
- if system_prompt is not None:
78
- # Add the system prompt as context
79
- context.append({"role": "system", "content": system_prompt})
80
-
81
- # Create messages
82
- index_in_thread = 0
83
- for message in thread_input:
84
- role = message.get("role", None)
85
- if role != "system":
86
- # System message shouldn't be added as a separate message
87
- system_prompt_for_this_message = ""
88
- if role != "user":
89
- system_prompt_for_this_message = system_prompt
90
- Message.create(
91
- evalsetrun=dataset.evalsetrun,
92
- dataset=dataset,
93
- thread=thread_object,
94
- index_in_thread=index_in_thread,
95
- role=role,
96
- content=message.get("content", None),
97
- context=json.dumps(context),
98
- metadata=message.get("metadata", None),
99
- is_flexeval_completion=False,
100
- system_prompt=system_prompt_for_this_message,
101
- )
102
- # Update context
103
- context.append(
104
- {"role": role, "content": message.get("content", None)}
105
- )
106
- index_in_thread += 1
107
-
108
- add_turns(thread_object)
109
-
110
169
  # TODO - should we add ToolCall here? Is there a standard way to represent them in jsonl?
111
170
 
112
171
 
@@ -116,24 +175,22 @@ def load_langgraph_sqlite(
116
175
  max_n_conversation_threads: int | None = None,
117
176
  nb_evaluations_per_thread: int | None = 1,
118
177
  ):
178
+ """Load conversations from a LangGraph SQLite checkpoint database.
179
+
180
+ Reads the final checkpoint for each thread and extracts the cumulative
181
+ message list from channel_values.messages. Compatible with langgraph >= 1.0.
182
+ """
119
183
  serializer = JsonPlusSerializer()
120
184
 
121
185
  with sqlite3.connect(filename) as conn:
122
- # Set the row factory to sqlite3.Row
123
- # allowing us to reference columns by name instead of index
124
186
  conn.row_factory = sqlite3.Row
125
-
126
- # Create a cursor object
127
187
  cursor = conn.cursor()
128
188
  verify_checkpoints_table_exists(cursor)
129
189
 
130
- # Sync database
131
- query = "PRAGMA wal_checkpoint(FULL);"
132
- cursor.execute(query)
190
+ cursor.execute("PRAGMA wal_checkpoint(FULL);")
133
191
 
134
- # Make threads (aka conversations)
135
- query = "select distinct thread_id from checkpoints"
136
- cursor.execute(query)
192
+ # Get distinct thread IDs
193
+ cursor.execute("SELECT DISTINCT thread_id FROM checkpoints")
137
194
  thread_ids = cursor.fetchall()
138
195
 
139
196
  nb_threads = len(thread_ids)
@@ -144,260 +201,125 @@ def load_langgraph_sqlite(
144
201
  selected_thread_ids = rd.sample(thread_ids, max_n_conversation_threads)
145
202
  else:
146
203
  logger.debug(
147
- f"You requested up to '{max_n_conversation_threads}' conversations but only '{nb_threads}' are present in Sqlite dataset at '{filename}'."
204
+ f"You requested up to '{max_n_conversation_threads}' conversations "
205
+ f"but only '{nb_threads}' are present in Sqlite dataset at '{filename}'."
148
206
  )
149
207
  selected_thread_ids = thread_ids
150
208
 
151
- logger.debug(" DEBUG DUPLICATE SELECT THREAD IDS\n", selected_thread_ids[0])
209
+ for thread_eval_run_id in range(max(1, nb_evaluations_per_thread)):
210
+ for thread_id_row in selected_thread_ids:
211
+ lg_thread_id = thread_id_row[0]
212
+
213
+ # Get the final checkpoint (highest step) for this thread
214
+ cursor.execute(
215
+ """
216
+ SELECT *, json_extract(metadata, '$.step') as step
217
+ FROM checkpoints
218
+ WHERE thread_id = ?
219
+ ORDER BY json_extract(metadata, '$.step') DESC
220
+ LIMIT 1
221
+ """,
222
+ (lg_thread_id,),
223
+ )
224
+ final_row = cursor.fetchone()
225
+ if final_row is None:
226
+ logger.warning(f"No checkpoints found for thread '{lg_thread_id}'")
227
+ continue
228
+
229
+ checkpoint = serializer.loads_typed(
230
+ (final_row["type"], final_row["checkpoint"])
231
+ )
232
+ lg_messages = checkpoint.get("channel_values", {}).get("messages", [])
233
+
234
+ if not lg_messages:
235
+ logger.warning(
236
+ f"No messages in final checkpoint for thread '{lg_thread_id}'"
237
+ )
238
+ continue
152
239
 
153
- for thread_eval_run_id in range(
154
- max(1, nb_evaluations_per_thread)
155
- ): # duplicate stored threads for averaged evaluation results
156
- for thread_id in selected_thread_ids:
157
240
  thread = Thread.create(
158
- evalsetrun=dataset.evalsetrun,
159
241
  dataset=dataset,
160
- langgraph_thread_id=thread_id[0],
161
- eval_run_thread_id=str(thread_id[0])
162
- + "_"
163
- + str(thread_eval_run_id),
242
+ langgraph_thread_id=lg_thread_id,
243
+ eval_run_thread_id=f"{lg_thread_id}_{thread_eval_run_id}",
164
244
  )
165
245
 
166
- # Create messages
167
- query = f"select * from checkpoints where thread_id = '{thread.langgraph_thread_id}'"
168
- cursor.execute(query)
169
- completion_list = cursor.fetchall()
170
-
171
- # context has to be reset at the start of every thread
246
+ # Map message types to FlexEval roles
247
+ # Tools are counted as assistant per existing convention
172
248
  context = []
173
- # tool call variables
249
+ system_prompt = None
174
250
  tool_calls_dict = {}
175
251
  tool_responses_dict = {}
176
- tool_addional_kwargs_dict = {}
177
- # system prompt reset for every thread
178
- system_prompt = None
252
+ tool_additional_kwargs_dict = {}
179
253
 
180
- for completion_row in completion_list:
181
- # checkpoint is full state history
182
- checkpoint = serializer.loads_typed(
183
- (completion_row["type"], completion_row["checkpoint"])
254
+ for index_in_thread, msg in enumerate(lg_messages):
255
+ msg_type = msg.type # 'human', 'ai', 'tool'
256
+ role = "user" if msg_type == "human" else "assistant"
257
+ content = msg.content
258
+
259
+ # Extract tool call info
260
+ tool_calls = getattr(msg, "tool_calls", []) or []
261
+ tool_call_ids = [tc["id"] for tc in tool_calls]
262
+ response_meta = getattr(msg, "response_metadata", {}) or {}
263
+ token_usage = response_meta.get("token_usage", {})
264
+ additional_kwargs = getattr(msg, "additional_kwargs", {}) or {}
265
+
266
+ Message.create(
267
+ dataset=dataset,
268
+ thread=thread,
269
+ index_in_thread=index_in_thread,
270
+ role=role,
271
+ content=content,
272
+ context=json.dumps(context),
273
+ is_flexeval_completion=False,
274
+ system_prompt=system_prompt,
275
+ # language model stats
276
+ tool_calls=json.dumps(tool_calls),
277
+ tool_call_ids=tool_call_ids,
278
+ n_tool_calls=len(tool_calls),
279
+ prompt_tokens=token_usage.get("prompt_tokens"),
280
+ completion_tokens=token_usage.get("completion_tokens"),
281
+ model_name=response_meta.get("model_name"),
282
+ # langgraph metadata
283
+ langgraph_ts=checkpoint.get("ts"),
284
+ langgraph_thread_id=lg_thread_id,
285
+ langgraph_checkpoint_id=final_row["checkpoint_id"],
286
+ langgraph_parent_checkpoint_id=final_row[
287
+ "parent_checkpoint_id"
288
+ ],
289
+ langgraph_metadata=final_row["metadata"],
290
+ langgraph_message_type=msg_type,
291
+ langgraph_type=msg_type,
184
292
  )
185
- # metadata is the state update for that row
186
- metadata = json.loads(completion_row["metadata"])
187
- # IDs from langgraph
188
293
 
189
- if metadata.get("writes") is None:
190
- continue
294
+ # Build context for next message
295
+ context.append({"role": role, "content": content})
296
+
297
+ # Track tool calls and responses for ToolCall creation
298
+ if msg_type == "tool":
299
+ tool_call_id = getattr(msg, "tool_call_id", None)
300
+ if tool_call_id:
301
+ tool_responses_dict[tool_call_id] = content
191
302
  else:
192
- # Goal here is to create a data structure for EACH write/update
193
- # that can be used to construct a Message object
194
- # LangGraph stores info in 'writes' in the checkpoints.metadata column
195
- # but the format is a bit different between human and machine input
196
- # The resulting data structure should have
197
- # key (str) -- graph 'node' that produced the message (or 'human')
198
- # value (list) -- list of 'message' data structures with id, kwargs, etc
199
- # {
200
- # 'node_name':{
201
- # "messages":[
202
- # {
203
- # 'id': "XYZ"
204
- # 'kwargs':{
205
- # "content": 'text of the message',
206
- # "additional_kwargs": {}
207
- # },
208
- # }
209
- # ]
210
- #
211
- # }
212
- # }
213
-
214
- # user input condition
215
- if metadata.get("source") == "input":
216
- # NOTE: I think with the updated logging of HumanMessage with langgraph, we don't need this case
217
- update_dict = {}
218
- # this will be a dictionary we can add to
219
- # key is 'input', as in human input
220
- update_dict["input"] = {"messages": []}
221
- # print("metadata keys:", metadata["writes"].keys())
222
- # the very first message in input in a thread seems to include
223
- # the system prompt, not a message that was sent by the user.
224
- # the system promptdoesn't seem to be set anywhere else, so
225
- # using that as the system prompt for the thread.
226
- messagecount = 0
227
- for msg in metadata["writes"]["__start__"]["messages"]:
228
- if messagecount == 0 and metadata["step"] == -1:
229
- system_prompt = msg["kwargs"]["content"]
230
- messagecount += 1
231
- else:
232
- message = {}
233
- message["id"] = [
234
- "HumanMessage"
235
- ] # LangGraph has a list here
236
- message["kwargs"] = {}
237
- message["kwargs"]["content"] = msg
238
- message["kwargs"]["type"] = "human"
239
- update_dict["input"]["messages"].append(message)
240
- # will be used below
241
- role = "user"
242
-
243
- # machine input condition
244
- elif metadata.get("source") == "loop":
245
- # This already has a list of messages with kwargs, etc
246
- update_dict = metadata.get("writes")
247
- # I think 'system_prompt' is empty by default and not stored here unless
248
- # it's included in the LangGraph state
249
- checkpoint_system_prompt = checkpoint.get(
250
- "channel_values", {}
251
- ).get("system_prompt")
252
- if checkpoint_system_prompt is not None:
253
- system_prompt = checkpoint_system_prompt
254
- role = "assistant"
255
- else:
256
- raise Exception(
257
- f"Unhandled input condition! Source not 'loop' or 'input'. Metadata: {metadata}"
258
- )
259
- # Add system prompt as first thing in context if not already present
260
- if len(context) == 0:
261
- context.append({"role": "system", "content": system_prompt})
262
-
263
- # iterate through nodes - there is probably only 1
264
- for node, value in update_dict.items():
265
- # iterate through list of message updates
266
- if "messages" in value:
267
- if isinstance(value["messages"], dict):
268
- # Make this a list to iterate through - 4 Feb 2025 - used to be a list previously
269
- messagelist = [value["messages"]]
270
- else:
271
- messagelist = value["messages"]
272
- index_in_thread = 0
273
- for message in messagelist:
274
- if role == "user":
275
- content = (
276
- message.get("kwargs", {})
277
- .get("content", {})
278
- .get("kwargs", {})
279
- .get("content", None)
280
- )
281
- elif role == "assistant":
282
- content = message.get("kwargs", {}).get(
283
- "content", None
284
- )
285
- else:
286
- raise Exception(
287
- "`role` should be either user or assistant."
288
- )
289
- Message.create(
290
- evalsetrun=dataset.evalsetrun,
291
- dataset=dataset,
292
- thread=thread,
293
- index_in_thread=index_in_thread,
294
- role=role,
295
- content=content,
296
- context=json.dumps(context),
297
- is_flexeval_completion=False,
298
- system_prompt=system_prompt,
299
- # language model stats
300
- tool_calls=json.dumps(
301
- message.get("kwargs", {}).get(
302
- "tool_calls", []
303
- )
304
- ),
305
- tool_call_ids=[
306
- tc["id"]
307
- for tc in message.get("kwargs", {}).get(
308
- "tool_calls", []
309
- )
310
- ],
311
- n_tool_calls=len(
312
- message.get("kwargs", {}).get(
313
- "tool_calls", []
314
- )
315
- ),
316
- prompt_tokens=message.get("kwargs", {})
317
- .get("response_metadata", {})
318
- .get("token_usage", {})
319
- .get("prompt_tokens"),
320
- completion_tokens=message.get("kwargs", {})
321
- .get("response_metadata", {})
322
- .get("token_usage", {})
323
- .get("completion_tokens"),
324
- model_name=message.get("kwargs", {})
325
- .get("response_metadata", {})
326
- .get("model_name"),
327
- # langgraph metadata
328
- langgraph_ts=checkpoint.get("ts"),
329
- langgraph_step=metadata.get("step"),
330
- langgraph_thread_id=completion_row["thread_id"],
331
- langgraph_checkpoint_id=completion_row[
332
- "checkpoint_id"
333
- ],
334
- langgraph_parent_checkpoint_id=completion_row[
335
- "parent_checkpoint_id"
336
- ],
337
- langgraph_checkpoint=dumps(
338
- checkpoint
339
- ), # Have to re-dump this because of the de-serialization#completion_row["checkpoint"],
340
- langgraph_metadata=completion_row["metadata"],
341
- langgraph_node=node,
342
- langgraph_message_type=message["id"][-1],
343
- langgraph_type=message.get("kwargs", {}).get(
344
- "type"
345
- ),
346
- # special property of state
347
- langchain_print=message.get("kwargs", {})
348
- .get("additional_kwargs", {})
349
- .get("print", False),
350
- )
351
-
352
- # update the context for the next Message
353
- context.append(
354
- {
355
- "role": role,
356
- "content": content,
357
- "langgraph_role": message["id"][-1],
358
- }
359
- )
360
-
361
- # record tool call info so we can match them up later
362
- if message.get("kwargs", {}).get("type") == "tool":
363
- # this should have a mapping between tool_call_id and the RESPONSE to to the tool call
364
- tool_responses_dict[
365
- message.get("kwargs", {}).get(
366
- "tool_call_id"
367
- )
368
- ] = message.get("kwargs", {}).get("content", "")
369
- else:
370
- for tool_call in message.get("kwargs", {}).get(
371
- "tool_calls", []
372
- ):
373
- # this should have all the info about the tool calls, including additional_kwargs
374
- # but NOT their responses
375
- tool_calls_dict[tool_call["id"]] = tool_call
376
- tool_addional_kwargs_dict[
377
- tool_call["id"]
378
- ] = message.get("kwargs", {}).get(
379
- "additional_kwargs", {}
380
- )
381
- index_in_thread += 1
382
-
383
- # Add turns to each message
384
- # Need to do this before dealing with tool calls, since we
385
- # associated turns with tool calls via messages during the .create() method
303
+ for tc in tool_calls:
304
+ tool_calls_dict[tc["id"]] = tc
305
+ tool_additional_kwargs_dict[tc["id"]] = additional_kwargs
306
+
307
+ # Create turns from messages
386
308
  add_turns(thread)
387
309
 
388
- ## Match up tool calls and make an object for each match
310
+ # Create ToolCall objects by matching calls to responses
389
311
  for tool_call_id, tool_call_vals in tool_calls_dict.items():
390
312
  if tool_call_id not in tool_responses_dict:
391
313
  raise ValueError(
392
314
  f"Found a tool call without a tool response! id='{tool_call_id}'"
393
315
  )
394
- # get matching message - should now be accessible through thread now?
395
316
  matching_message = [
396
- m for m in thread.messages if tool_call_id in m.tool_call_ids
317
+ m
318
+ for m in thread.messages
319
+ if tool_call_id in (m.tool_call_ids or [])
397
320
  ][0]
398
321
 
399
322
  ToolCall.create(
400
- evalsetrun=dataset.evalsetrun,
401
323
  dataset=dataset,
402
324
  thread=thread,
403
325
  turn=matching_message.turn,
@@ -405,14 +327,12 @@ def load_langgraph_sqlite(
405
327
  function_name=tool_call_vals.get("name"),
406
328
  args=json.dumps(tool_call_vals.get("args")),
407
329
  additional_kwargs=json.dumps(
408
- tool_addional_kwargs_dict.get(tool_call_id)
330
+ tool_additional_kwargs_dict.get(tool_call_id)
409
331
  ),
410
332
  tool_call_id=tool_call_id,
411
333
  response_content=tool_responses_dict.get(tool_call_id),
412
334
  )
413
335
 
414
- ## Add system prompt if available?
415
-
416
336
 
417
337
  def add_turns(thread: Thread):
418
338
  # Add turn labels
@@ -426,7 +346,6 @@ def add_turns(thread: Thread):
426
346
  index_in_thread = 0
427
347
  for placeholder_turn_id, role in turn_dict.items(): # turns.items():
428
348
  t = Turn.create(
429
- evalsetrun=thread.evalsetrun,
430
349
  dataset=thread.dataset,
431
350
  thread=thread,
432
351
  index_in_thread=index_in_thread,
@@ -447,12 +366,10 @@ def add_turns(thread: Thread):
447
366
 
448
367
  def verify_checkpoints_table_exists(cursor):
449
368
  # double check that the 'checkpoints' table exists
450
- cursor.execute(
451
- """
369
+ cursor.execute("""
452
370
  SELECT name FROM sqlite_master
453
371
  WHERE type='table' AND name='checkpoints'
454
- """
455
- )
372
+ """)
456
373
  result = cursor.fetchone()
457
374
  # Assert that the result is not None, meaning the table exists
458
375
  assert result is not None, "Table 'checkpoints' does not exist in the database."
flexeval/db_utils.py CHANGED
@@ -4,14 +4,23 @@ import peewee as pw
4
4
 
5
5
  from flexeval.classes import base as classes_base
6
6
  from flexeval.classes.dataset import Dataset
7
- from flexeval.classes.eval_set_run import EvalSetRun
7
+ from flexeval.classes.eval_set_run import EvalSetRun, EvalSetRunDatasets
8
8
  from flexeval.classes.message import Message
9
9
  from flexeval.classes.metric import Metric
10
10
  from flexeval.classes.thread import Thread
11
11
  from flexeval.classes.tool_call import ToolCall
12
12
  from flexeval.classes.turn import Turn
13
13
 
14
- DATABASE_TABLES = [EvalSetRun, Dataset, Thread, Turn, Message, ToolCall, Metric]
14
+ DATABASE_TABLES = [
15
+ EvalSetRun,
16
+ Dataset,
17
+ EvalSetRunDatasets,
18
+ Thread,
19
+ Turn,
20
+ Message,
21
+ ToolCall,
22
+ Metric,
23
+ ]
15
24
 
16
25
 
17
26
  def ensure_database(database_path: str):
@@ -115,9 +115,9 @@ def get_parent_metrics(all_metrics: dict, child: dict) -> tuple[list, list]:
115
115
  """metrics_graph_ordered_list will be a list of metrics in order in which they should be run
116
116
 
117
117
  This function takes the eval represented by "child" and finds ALL evals in "all_metrics"
118
- that quality as the child's immediate parent
118
+ that qualify as the child's immediate parent
119
119
 
120
- An eval can qualify as a parent by having a matching name, type, context_only
120
+ An eval can qualify as a parent by having a matching name, type, etc.
121
121
  At this point, we won't have enough information to decide whether the child should be run
122
122
  (since the child might have additional requirements on the output of the parent)
123
123
  but this is enough to tell us that the child should be run AFTER the parent.
@@ -145,7 +145,7 @@ def get_parent_metrics(all_metrics: dict, child: dict) -> tuple[list, list]:
145
145
 
146
146
  # if the conditionals are listed in the depends_on entry but don't match...
147
147
  # Only check conditionals that are explicitly specified (not None) in the requirement
148
- conditionals = ["metric_level", "context_only", "name", "kwargs"]
148
+ conditionals = ["metric_level", "name", "kwargs"]
149
149
  for conditional in conditionals:
150
150
  if (
151
151
  conditional in requirement