python-flexeval 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. flexeval/__init__.py +11 -0
  2. flexeval/__main__.py +11 -0
  3. flexeval/classes/__init__.py +15 -0
  4. flexeval/classes/base.py +32 -0
  5. flexeval/classes/dataset.py +82 -0
  6. flexeval/classes/eval_runner.py +158 -0
  7. flexeval/classes/eval_set_run.py +32 -0
  8. flexeval/classes/message.py +183 -0
  9. flexeval/classes/metric.py +55 -0
  10. flexeval/classes/thread.py +79 -0
  11. flexeval/classes/tool_call.py +51 -0
  12. flexeval/classes/turn.py +206 -0
  13. flexeval/cli.py +104 -0
  14. flexeval/completions.py +147 -0
  15. flexeval/compute_metrics.py +788 -0
  16. flexeval/config.yaml +23 -0
  17. flexeval/configuration/__init__.py +1 -0
  18. flexeval/configuration/completion_functions.py +231 -0
  19. flexeval/configuration/evals.yaml +864 -0
  20. flexeval/configuration/function_metrics.py +650 -0
  21. flexeval/configuration/rubric_metrics.yaml +194 -0
  22. flexeval/data_loader.py +513 -0
  23. flexeval/db_utils.py +38 -0
  24. flexeval/dependency_graph.py +234 -0
  25. flexeval/eval_schema.json +256 -0
  26. flexeval/function_types.py +173 -0
  27. flexeval/helpers.py +52 -0
  28. flexeval/io/__init__.py +1 -0
  29. flexeval/io/parsers/yaml_parser.py +69 -0
  30. flexeval/log_utils.py +34 -0
  31. flexeval/metrics/__init__.py +8 -0
  32. flexeval/metrics/access.py +28 -0
  33. flexeval/metrics/save.py +39 -0
  34. flexeval/rubric.py +62 -0
  35. flexeval/run_utils.py +65 -0
  36. flexeval/runner.py +132 -0
  37. flexeval/schema/__init__.py +11 -0
  38. flexeval/schema/config_schema.py +46 -0
  39. flexeval/schema/eval_schema.py +163 -0
  40. flexeval/schema/evalrun_schema.py +97 -0
  41. flexeval/schema/rubric_schema.py +40 -0
  42. flexeval/schema/schema_utils.py +26 -0
  43. python_flexeval-0.1.5.dist-info/METADATA +118 -0
  44. python_flexeval-0.1.5.dist-info/RECORD +47 -0
  45. python_flexeval-0.1.5.dist-info/WHEEL +4 -0
  46. python_flexeval-0.1.5.dist-info/entry_points.txt +2 -0
  47. python_flexeval-0.1.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,513 @@
1
+ """Dataset loading functions. Maybe should move to :mod:`~flexeval.io`."""
2
+
3
+ import json
4
+ import logging
5
+ import pathlib
6
+ import random as rd
7
+ import sqlite3
8
+
9
+ from langchain.load.dump import dumps
10
+ from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
11
+
12
+ from flexeval.classes.dataset import Dataset
13
+ from flexeval.classes.message import Message
14
+ from flexeval.classes.thread import Thread
15
+ from flexeval.classes.tool_call import ToolCall
16
+ from flexeval.classes.turn import Turn
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
def load_jsonl(
    dataset: Dataset,
    filename: str | pathlib.Path,
    max_n_conversation_threads: int | None = None,
    nb_evaluations_per_thread: int | None = 1,
):
    """Load conversation threads from a JSONL file into ``dataset``.

    Each non-blank line must be a JSON object whose ``"input"`` key holds a
    list of message dicts; ``role``, ``content`` and ``metadata`` are read from
    each message (missing keys become ``None``).

    The first message with role ``"system"`` (if any) supplies the thread's
    system prompt: it seeds every message's ``context`` and is stored as
    ``system_prompt`` on non-user messages (user messages get ``""``). System
    messages themselves are not stored as Message rows.

    Args:
        dataset: Dataset the created Thread/Message rows belong to; its
            ``evalsetrun`` is attached to them as well.
        filename: Path of the JSONL file, read as UTF-8.
        max_n_conversation_threads: If None, or not larger than the number of
            non-blank lines, that many lines are drawn with ``random.sample``.
            Otherwise all lines are used and a debug message is logged.
        nb_evaluations_per_thread: Number of copies of each selected thread to
            create (e.g. for averaging metrics). None or values < 1 count as 1.
    """
    # Iterate the file's "\n"-delimited lines. str.splitlines() would also
    # split on U+0085/U+2028/U+2029, which may legally appear raw inside JSON
    # strings and would corrupt those records.
    with open(filename, "r", encoding="utf-8") as infile:
        all_lines = infile.readlines()

    # Blank lines carry no record (json.loads("") raises), so they are not
    # candidates. They still count for line indices, so jsonl_thread_id
    # always equals the 0-based line number in the file.
    candidate_ids = [i for i, line in enumerate(all_lines) if line.strip()]

    if max_n_conversation_threads is None:
        max_n_conversation_threads = len(candidate_ids)

    if max_n_conversation_threads <= len(candidate_ids):
        selected_thread_ids = set(rd.sample(candidate_ids, max_n_conversation_threads))
    else:
        logger.debug(
            f"You requested up to '{max_n_conversation_threads}' conversations but only '{len(candidate_ids)}' are present in Jsonl dataset at '{filename}'."
        )
        selected_thread_ids = set(candidate_ids)

    if nb_evaluations_per_thread is None:
        nb_evaluations_per_thread = 1
    n_copies = max(1, nb_evaluations_per_thread)

    # Walk lines in file order; threads are created line by line, with each
    # line's copies created back to back.
    for thread_id, line in enumerate(all_lines):
        if thread_id not in selected_thread_ids:
            continue

        # Parse once per line, before any rows are written for it.
        thread_input = json.loads(line)["input"]

        # System prompt used in the thread - assuming only 1.
        system_prompt = next(
            (m.get("content") for m in thread_input if m.get("role") == "system"),
            None,
        )

        # Duplicate stored threads for averaged evaluation results.
        for thread_eval_run_id in range(n_copies):
            thread_object = Thread.create(
                evalsetrun=dataset.evalsetrun,
                dataset=dataset,
                jsonl_thread_id=thread_id,
                eval_run_thread_id=f"{thread_id}_{thread_eval_run_id}",
            )

            # Context seen by each message: the system prompt (if any) plus
            # every earlier non-system message of this thread.
            context = []
            if system_prompt is not None:
                context.append({"role": "system", "content": system_prompt})

            index_in_thread = 0
            for message in thread_input:
                role = message.get("role", None)
                if role == "system":
                    # System message shouldn't be added as a separate message.
                    continue
                content = message.get("content", None)
                Message.create(
                    evalsetrun=dataset.evalsetrun,
                    dataset=dataset,
                    thread=thread_object,
                    index_in_thread=index_in_thread,
                    role=role,
                    content=content,
                    context=json.dumps(context),
                    metadata=message.get("metadata", None),
                    is_flexeval_completion=False,
                    system_prompt="" if role == "user" else system_prompt,
                )
                context.append({"role": role, "content": content})
                index_in_thread += 1

            add_turns(thread_object)

    # TODO - should we add ToolCall here? Is there a standard way to represent them in jsonl?
111
+
112
+
113
def load_langgraph_sqlite(
    dataset: Dataset,
    filename: str,
    max_n_conversation_threads: int | None = None,
    nb_evaluations_per_thread: int | None = 1,
):
    """Load LangGraph checkpoint history from a SQLite database into ``dataset``.

    Reads the ``checkpoints`` table (columns ``thread_id``, ``type``,
    ``checkpoint``, ``metadata``, ``checkpoint_id``, ``parent_checkpoint_id``).
    For each selected ``thread_id``, one Thread is created per evaluation
    copy. Every message found in a row's ``metadata["writes"]`` becomes a
    Message; the messages are grouped into Turns via ``add_turns``; and every
    tool call is paired with its tool response to create a ToolCall.

    Args:
        dataset: Dataset the created rows belong to; its ``evalsetrun`` is
            attached to them as well.
        filename: Path of the SQLite database file.
        max_n_conversation_threads: If None, or not larger than the number of
            distinct thread ids, that many thread ids are drawn with
            ``random.sample``. Otherwise all threads are used.
        nb_evaluations_per_thread: Number of copies of each selected thread to
            create. None or values < 1 count as 1, as in ``load_jsonl``.

    Raises:
        AssertionError: From ``verify_checkpoints_table_exists`` when the
            ``checkpoints`` table is missing.
        Exception: If a checkpoint's metadata ``source`` is neither ``"input"``
            nor ``"loop"``.
        ValueError: If a tool call has no matching tool response.
    """
    if nb_evaluations_per_thread is None:
        # max(1, None) would raise TypeError.
        nb_evaluations_per_thread = 1
    serializer = JsonPlusSerializer()

    conn = sqlite3.connect(filename)
    try:
        # Reference columns by name instead of index.
        conn.row_factory = sqlite3.Row
        # "with conn" only commits/rolls back; closing happens in "finally".
        with conn:
            cursor = conn.cursor()
            verify_checkpoints_table_exists(cursor)

            # Sync database (checkpoint the write-ahead log, if any).
            cursor.execute("PRAGMA wal_checkpoint(FULL);")

            selected_thread_ids = _select_langgraph_threads(
                cursor, filename, max_n_conversation_threads
            )
            logger.debug(
                "Selected %d LangGraph thread(s); first thread_id: %s",
                len(selected_thread_ids),
                selected_thread_ids[0][0] if selected_thread_ids else None,
            )

            # Duplicate stored threads for averaged evaluation results.
            for thread_eval_run_id in range(max(1, nb_evaluations_per_thread)):
                for thread_row in selected_thread_ids:
                    _load_langgraph_thread(
                        cursor, serializer, dataset, thread_row[0], thread_eval_run_id
                    )
    finally:
        conn.close()


def _select_langgraph_threads(cursor, filename, max_n_conversation_threads):
    """Return the ``thread_id`` rows to load: a random sample, or all of them."""
    cursor.execute("select distinct thread_id from checkpoints")
    thread_ids = cursor.fetchall()

    nb_threads = len(thread_ids)
    if max_n_conversation_threads is None:
        max_n_conversation_threads = nb_threads

    if max_n_conversation_threads <= nb_threads:
        return rd.sample(thread_ids, max_n_conversation_threads)
    logger.debug(
        f"You requested up to '{max_n_conversation_threads}' conversations but only '{nb_threads}' are present in Sqlite dataset at '{filename}'."
    )
    return thread_ids


def _load_langgraph_thread(
    cursor, serializer, dataset, langgraph_thread_id, thread_eval_run_id
):
    """Create one Thread, plus its Messages, Turns and ToolCalls, for one LangGraph thread id."""
    thread = Thread.create(
        evalsetrun=dataset.evalsetrun,
        dataset=dataset,
        langgraph_thread_id=langgraph_thread_id,
        eval_run_thread_id=f"{langgraph_thread_id}_{thread_eval_run_id}",
    )

    # Parameterized: an id containing a quote would break (or inject into) an
    # f-string query.
    cursor.execute(
        "select * from checkpoints where thread_id = ?", (langgraph_thread_id,)
    )
    completion_list = cursor.fetchall()

    # Per-thread state: must not leak between threads.
    context = []
    tool_calls_dict = {}
    tool_responses_dict = {}
    tool_additional_kwargs_dict = {}
    system_prompt = None
    # Position among this thread's messages (consistent with load_jsonl).
    index_in_thread = 0

    for completion_row in completion_list:
        # checkpoint is the full state history.
        checkpoint = serializer.loads_typed(
            (completion_row["type"], completion_row["checkpoint"])
        )
        # metadata is the state update for that row.
        metadata = json.loads(completion_row["metadata"])
        if metadata.get("writes") is None:
            continue

        update_dict, role, system_prompt = _langgraph_update_from_row(
            metadata, checkpoint, system_prompt
        )

        # Add system prompt as first thing in context if not already present.
        if not context:
            context.append({"role": "system", "content": system_prompt})

        # Iterate through nodes - there is probably only 1.
        for node, value in update_dict.items():
            # A node may write None or a non-dict value; only message updates matter.
            if not isinstance(value, dict) or "messages" not in value:
                continue
            messagelist = value["messages"]
            if isinstance(messagelist, dict):
                # Single message instead of a list (format seen since Feb 2025).
                messagelist = [messagelist]

            for message in messagelist:
                kwargs = message.get("kwargs", {})
                content = _langgraph_message_content(message, role)
                tool_calls = kwargs.get("tool_calls", [])
                response_metadata = kwargs.get("response_metadata", {})
                token_usage = response_metadata.get("token_usage", {})

                Message.create(
                    evalsetrun=dataset.evalsetrun,
                    dataset=dataset,
                    thread=thread,
                    index_in_thread=index_in_thread,
                    role=role,
                    content=content,
                    context=json.dumps(context),
                    is_flexeval_completion=False,
                    system_prompt=system_prompt,
                    # language model stats
                    tool_calls=json.dumps(tool_calls),
                    tool_call_ids=[tc["id"] for tc in tool_calls],
                    n_tool_calls=len(tool_calls),
                    prompt_tokens=token_usage.get("prompt_tokens"),
                    completion_tokens=token_usage.get("completion_tokens"),
                    model_name=response_metadata.get("model_name"),
                    # langgraph metadata
                    langgraph_ts=checkpoint.get("ts"),
                    langgraph_step=metadata.get("step"),
                    langgraph_thread_id=completion_row["thread_id"],
                    langgraph_checkpoint_id=completion_row["checkpoint_id"],
                    langgraph_parent_checkpoint_id=completion_row["parent_checkpoint_id"],
                    # Re-dumped because the raw column was de-serialized above.
                    langgraph_checkpoint=dumps(checkpoint),
                    langgraph_metadata=completion_row["metadata"],
                    langgraph_node=node,
                    langgraph_message_type=message["id"][-1],
                    langgraph_type=kwargs.get("type"),
                    # special property of state
                    langchain_print=kwargs.get("additional_kwargs", {}).get("print", False),
                )

                # Update the context for the next Message.
                context.append(
                    {
                        "role": role,
                        "content": content,
                        "langgraph_role": message["id"][-1],
                    }
                )

                # Record tool call info so calls and responses can be matched later.
                if kwargs.get("type") == "tool":
                    # tool_call_id -> RESPONSE to the tool call
                    tool_responses_dict[kwargs.get("tool_call_id")] = kwargs.get("content", "")
                else:
                    for tool_call in tool_calls:
                        # Call info (incl. additional_kwargs) but NOT the response.
                        tool_calls_dict[tool_call["id"]] = tool_call
                        tool_additional_kwargs_dict[tool_call["id"]] = kwargs.get(
                            "additional_kwargs", {}
                        )
                index_in_thread += 1

    # Turns must exist before tool calls, since a ToolCall is linked to the
    # turn of its message.
    add_turns(thread)
    _create_langgraph_tool_calls(
        dataset, thread, tool_calls_dict, tool_responses_dict, tool_additional_kwargs_dict
    )


def _langgraph_update_from_row(metadata, checkpoint, system_prompt):
    """Normalize one checkpoint row's writes into ``(update_dict, role, system_prompt)``.

    ``update_dict`` maps a graph node name (or ``"input"`` for human input) to
    a dict whose ``"messages"`` entry holds message structures with ``id``
    and ``kwargs`` (``content``, ``additional_kwargs``, ...). The returned
    ``system_prompt`` is the updated value, or the one passed in if unchanged.
    """
    source = metadata.get("source")
    if source == "input":
        # The very first input message of a thread (step -1) appears to be the
        # system prompt, not a user message. The system prompt doesn't seem to
        # be set anywhere else, so it is used as the thread's system prompt.
        messages = []
        for position, msg in enumerate(metadata["writes"]["__start__"]["messages"]):
            if position == 0 and metadata["step"] == -1:
                system_prompt = msg["kwargs"]["content"]
            else:
                messages.append(
                    {
                        "id": ["HumanMessage"],  # LangGraph has a list here
                        "kwargs": {"content": msg, "type": "human"},
                    }
                )
        return {"input": {"messages": messages}}, "user", system_prompt

    if source == "loop":
        # Already a node -> {"messages": [...]} mapping. 'system_prompt' is only
        # present if it is part of the LangGraph state.
        checkpoint_system_prompt = checkpoint.get("channel_values", {}).get(
            "system_prompt"
        )
        if checkpoint_system_prompt is not None:
            system_prompt = checkpoint_system_prompt
        return metadata.get("writes"), "assistant", system_prompt

    raise Exception(
        f"Unhandled input condition! Source not 'loop' or 'input'. Metadata: {metadata}"
    )


def _langgraph_message_content(message, role):
    """Extract the text content of a normalized message for ``role``."""
    kwargs = message.get("kwargs", {})
    if role == "user":
        # User messages wrap the original LangGraph message in kwargs.content.
        return kwargs.get("content", {}).get("kwargs", {}).get("content", None)
    if role == "assistant":
        return kwargs.get("content", None)
    raise Exception("`role` should be either user or assistant.")


def _create_langgraph_tool_calls(
    dataset, thread, tool_calls_dict, tool_responses_dict, tool_additional_kwargs_dict
):
    """Create one ToolCall per recorded tool call, paired with its response.

    Raises:
        ValueError: If a tool call has no matching tool response.
    """
    if not tool_calls_dict:
        return
    # Query the thread's messages once (after add_turns assigned their turns).
    messages = list(thread.messages)
    for tool_call_id, tool_call_vals in tool_calls_dict.items():
        if tool_call_id not in tool_responses_dict:
            raise ValueError(
                f"Found a tool call without a tool response! id='{tool_call_id}'"
            )
        matching_message = [m for m in messages if tool_call_id in m.tool_call_ids][0]

        ToolCall.create(
            evalsetrun=dataset.evalsetrun,
            dataset=dataset,
            thread=thread,
            turn=matching_message.turn,
            message=matching_message,
            function_name=tool_call_vals.get("name"),
            args=json.dumps(tool_call_vals.get("args")),
            additional_kwargs=json.dumps(tool_additional_kwargs_dict.get(tool_call_id)),
            tool_call_id=tool_call_id,
            response_content=tool_responses_dict.get(tool_call_id),
        )
415
+
416
+
417
def add_turns(thread: Thread):
    """Create Turn rows for ``thread`` and link every message to its turn.

    Turn boundaries come from ``get_turns``, which also sets and saves each
    message's ``is_final_turn_in_input``. One Turn is created per placeholder
    turn id, in the order get_turns first encountered it, with
    ``index_in_thread`` 0, 1, 2, ... and the role get_turns reports for it.
    Each message's ``turn`` is then set and the message saved.
    """
    message_placeholder_ids, turn_dict = get_turns(thread=thread)

    # Create the turns and map each placeholder id to its Turn row.
    turns = {}
    for index_in_thread, (placeholder_turn_id, role) in enumerate(turn_dict.items()):
        turns[placeholder_turn_id] = Turn.create(
            evalsetrun=thread.evalsetrun,
            dataset=thread.dataset,
            thread=thread,
            index_in_thread=index_in_thread,
            role=role,
        )

    # Link messages to their turns. Entries of message_placeholder_ids
    # correspond positionally to thread.messages.
    # NOTE(review): this assumes thread.messages yields messages in the same
    # order as it did inside get_turns - confirm the relation has a stable order.
    for placeholder_turn_id, message in zip(message_placeholder_ids, thread.messages):
        message.turn = turns[placeholder_turn_id]
        message.save()
446
+
447
+
448
def verify_checkpoints_table_exists(cursor):
    """Check that the SQLite database behind ``cursor`` has a ``checkpoints`` table.

    Args:
        cursor: A ``sqlite3`` cursor; used to query ``sqlite_master``.

    Raises:
        AssertionError: If the table does not exist. It is raised explicitly
            rather than via ``assert``, so the check still runs under
            ``python -O``.
    """
    cursor.execute(
        """
        SELECT name FROM sqlite_master
        WHERE type='table' AND name='checkpoints'
        """
    )
    if cursor.fetchone() is None:
        raise AssertionError("Table 'checkpoints' does not exist in the database.")
459
+
460
+
461
def get_turns(thread: "Thread"):
    """Assign placeholder turn ids to the messages of ``thread``.

    A turn is one or more consecutive messages on the same side of the
    conversation: "machine" messages (role ``assistant``, ``ai`` or ``tool``)
    versus all other roles (e.g. ``user``). For example::

        TURN 1 - user
        TURN 2 - assistant
        TURN 3 - user
        TURN 4 - assistant
        TURN 4 - tool
        TURN 4 - assistant
        TURN 5 - user

    Placeholder ids start at 1 and increase by one at every side change. A
    thread whose first message is a machine message therefore starts at 2.

    As a side effect, each message's ``is_final_turn_in_input`` is set to
    whether it belongs to the last turn, and the message is saved.

    Args:
        thread: Object whose ``messages`` iterates the thread's messages; each
            message needs ``role``, ``is_final_turn_in_input`` and ``save()``.

    Returns:
        tuple: ``(message_placeholder_ids, turn_id_roles)``. The first is a
        list with one placeholder id per message, in iteration order. The
        second is a dict mapping each placeholder id to the role of the LAST
        message in that turn.
    """
    # These are all treated as belonging to the same 'turn'.
    machine_labels = {"assistant", "ai", "tool"}

    # Materialize once: iterating thread.messages twice issues two queries
    # and relies on both returning rows in the same order.
    messages = list(thread.messages)

    turn_id = 1
    # The conversation "starts" on the non-machine side.
    previous_is_machine = False
    message_placeholder_ids = []
    for entry in messages:
        is_machine = entry.role in machine_labels
        if is_machine != previous_is_machine:
            turn_id += 1
        message_placeholder_ids.append(turn_id)
        previous_is_machine = is_machine

    # Record each turn's role and flag the messages of the final turn.
    turn_id_roles = {}
    for placeholder_id, entry in zip(message_placeholder_ids, messages):
        turn_id_roles[placeholder_id] = entry.role
        entry.is_final_turn_in_input = placeholder_id == turn_id
        entry.save()

    return message_placeholder_ids, turn_id_roles
flexeval/db_utils.py ADDED
@@ -0,0 +1,38 @@
1
+ """Peewee database utilities."""
2
+
3
+ import peewee as pw
4
+
5
+ from flexeval.classes import base as classes_base
6
+ from flexeval.classes.dataset import Dataset
7
+ from flexeval.classes.eval_set_run import EvalSetRun
8
+ from flexeval.classes.message import Message
9
+ from flexeval.classes.metric import Metric
10
+ from flexeval.classes.thread import Thread
11
+ from flexeval.classes.tool_call import ToolCall
12
+ from flexeval.classes.turn import Turn
13
+
14
# All ORM models FlexEval persists. This module passes them as one group to
# create_tables / drop_tables (initialize_database) and bind (bind_to_database).
DATABASE_TABLES = [
    EvalSetRun,
    Dataset,
    Thread,
    Turn,
    Message,
    ToolCall,
    Metric,
]
15
+
16
+
17
def initialize_database(database_path: str, clear_tables: bool = False):
    """Point the shared FlexEval database at ``database_path`` and create its tables.

    Args:
        database_path: Location handed to ``classes_base.database.init``.
        clear_tables: When True, every model table in ``DATABASE_TABLES`` is
            dropped before the tables are created again.
    """
    db = classes_base.database
    db.init(database_path)

    if clear_tables:
        db.drop_tables(DATABASE_TABLES)
    db.create_tables(DATABASE_TABLES)
24
+
25
+
26
def bind_to_database(database_path: str) -> pw.Database:
    """Utility function for binding to a FlexEval database so that ORM functionality can be used.

    See: https://docs.peewee-orm.com/en/latest/peewee/database.html#setting-the-database-at-run-time

    Args:
        database_path: Location handed to ``classes_base.create_sqlite_database``.

    Returns:
        pw.Database: The new database created for the models to bind to.

    Raises:
        AssertionError: If any model in ``DATABASE_TABLES`` is not bound to the
            new database afterwards (raised explicitly, so the check also runs
            under ``python -O``).
    """
    new_database = classes_base.create_sqlite_database(database_path)
    new_database.bind(DATABASE_TABLES)

    # Verify the models that were actually passed to bind(). BaseModel is not
    # in DATABASE_TABLES, so checking BaseModel._meta.database would not
    # reflect this binding.
    unbound = [
        model.__name__
        for model in DATABASE_TABLES
        if model._meta.database is not new_database
    ]
    if unbound:
        raise AssertionError(f"Models not bound to the new database: {unbound}")
    return new_database