python-flexeval 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. flexeval/__init__.py +11 -0
  2. flexeval/__main__.py +11 -0
  3. flexeval/classes/__init__.py +15 -0
  4. flexeval/classes/base.py +32 -0
  5. flexeval/classes/dataset.py +82 -0
  6. flexeval/classes/eval_runner.py +158 -0
  7. flexeval/classes/eval_set_run.py +32 -0
  8. flexeval/classes/message.py +183 -0
  9. flexeval/classes/metric.py +55 -0
  10. flexeval/classes/thread.py +79 -0
  11. flexeval/classes/tool_call.py +51 -0
  12. flexeval/classes/turn.py +206 -0
  13. flexeval/cli.py +104 -0
  14. flexeval/completions.py +147 -0
  15. flexeval/compute_metrics.py +788 -0
  16. flexeval/config.yaml +23 -0
  17. flexeval/configuration/__init__.py +1 -0
  18. flexeval/configuration/completion_functions.py +231 -0
  19. flexeval/configuration/evals.yaml +864 -0
  20. flexeval/configuration/function_metrics.py +650 -0
  21. flexeval/configuration/rubric_metrics.yaml +194 -0
  22. flexeval/data_loader.py +513 -0
  23. flexeval/db_utils.py +38 -0
  24. flexeval/dependency_graph.py +234 -0
  25. flexeval/eval_schema.json +256 -0
  26. flexeval/function_types.py +173 -0
  27. flexeval/helpers.py +52 -0
  28. flexeval/io/__init__.py +1 -0
  29. flexeval/io/parsers/yaml_parser.py +69 -0
  30. flexeval/log_utils.py +34 -0
  31. flexeval/metrics/__init__.py +8 -0
  32. flexeval/metrics/access.py +28 -0
  33. flexeval/metrics/save.py +39 -0
  34. flexeval/rubric.py +62 -0
  35. flexeval/run_utils.py +65 -0
  36. flexeval/runner.py +132 -0
  37. flexeval/schema/__init__.py +11 -0
  38. flexeval/schema/config_schema.py +46 -0
  39. flexeval/schema/eval_schema.py +163 -0
  40. flexeval/schema/evalrun_schema.py +97 -0
  41. flexeval/schema/rubric_schema.py +40 -0
  42. flexeval/schema/schema_utils.py +26 -0
  43. python_flexeval-0.1.5.dist-info/METADATA +118 -0
  44. python_flexeval-0.1.5.dist-info/RECORD +47 -0
  45. python_flexeval-0.1.5.dist-info/WHEEL +4 -0
  46. python_flexeval-0.1.5.dist-info/entry_points.txt +2 -0
  47. python_flexeval-0.1.5.dist-info/licenses/LICENSE +21 -0
flexeval/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """FlexEval is a Python package for designing custom metrics, completion functions, and LLM-graded rubrics for evaluating the behavior of LLM-powered systems.
2
+
3
+ This top-level import exposes the :func:`~flexeval.runner.run` method."""
4
+
5
+ from flexeval import metrics
6
+ from flexeval.runner import run
7
+
8
# Public names re-exported by the top-level ``flexeval`` package.
__all__ = ["metrics", "run"]
flexeval/__main__.py ADDED
@@ -0,0 +1,11 @@
1
+ """Pass-through to the CLI."""
2
+
3
+ from flexeval import cli
4
+
5
+
6
def main():
    """Run the FlexEval command-line interface by delegating to ``cli.main()``."""
    cli.main()


# Allows ``python -m flexeval`` to behave like the installed CLI entry point.
if __name__ == "__main__":
    main()
@@ -0,0 +1,15 @@
1
+ """Peewee classes used for saving the results of a FlexEval run.
2
+
3
+ See :mod:`peewee` for more information on the capabilities of these objects."""
4
+
5
+ from . import dataset, eval_set_run, message, metric, thread, tool_call, turn
6
+
7
# Submodules exposed as the public API of ``flexeval.classes``.
__all__ = [
    "dataset",
    "eval_set_run",
    "message",
    "metric",
    "thread",
    "tool_call",
    "turn",
]
@@ -0,0 +1,32 @@
1
+ import peewee as pw
2
+ from playhouse.shortcuts import ThreadSafeDatabaseMetadata
3
+ from playhouse.sqliteq import SqliteQueueDatabase
4
+
5
+
6
def create_sqlite_database(
    database_path: str | None = None, use_queue_db: bool = False
) -> pw.SqliteDatabase:
    """Create a peewee SQLite database handle that uses write-ahead logging.

    Args:
        database_path: Path of the SQLite file. ``None`` creates an
            uninitialized handle (peewee's deferred-initialization pattern).
        use_queue_db: If True, return a ``SqliteQueueDatabase`` created with
            ``autostart=False``, a 5 second results timeout and at most 64
            pending writes; otherwise return a plain ``pw.SqliteDatabase``.

    Returns:
        The configured database object.
    """
    wal_pragmas = {"journal_mode": "wal"}  # use Write-ahead Logging
    if not use_queue_db:
        return pw.SqliteDatabase(database_path, pragmas=wal_pragmas)
    return SqliteQueueDatabase(
        database_path,
        use_gevent=False,
        autostart=False,
        results_timeout=5.0,
        queue_max_size=64,  # Max. # of pending writes that can accumulate
        pragmas=wal_pragmas,
    )
22
+
23
+
24
# Shared handle created without a path; presumably bound to a real file later
# (e.g. via db_utils.initialize_database) — TODO confirm.
database = create_sqlite_database(database_path=None, use_queue_db=False)


class BaseModel(pw.Model):
    """Common peewee base class binding every FlexEval model to ``database``."""

    class Meta:
        # Class-body lookup falls through to the module-level ``database``.
        database = database
        # Thread-safe access to the model's database attribute (playhouse.shortcuts).
        model_metadata_class = ThreadSafeDatabaseMetadata
@@ -0,0 +1,82 @@
1
+ import os.path
2
+
3
+ import peewee as pw
4
+
5
+ from flexeval.classes.base import BaseModel
6
+ from flexeval.classes.eval_set_run import EvalSetRun
7
+
8
+
9
class Dataset(BaseModel):
    """One input data file (e.g. a jsonl file) belonging to an EvalSetRun."""

    id = pw.IntegerField(primary_key=True)
    evalsetrun = pw.ForeignKeyField(EvalSetRun, backref="datasets")
    filename = pw.TextField()
    datatype = pw.TextField(null=True)  # "json" or "sqlite", assigned by load_data()
    contents = pw.TextField(null=True)  # raw contents

    max_n_conversation_threads = pw.IntegerField(null=True)
    nb_evaluations_per_thread = pw.IntegerField(null=True, default=1)

    # Data-model notes:
    # In line with LangGraph expectations, we assume n=1 for all outputs of LLMs,
    # however each node can append a list with length 2+ to the message queue.
    #   Thread   - conversation
    #   Turn     - adjacent messages from the same agent
    #   Message  - role (human/ai, user/assistant); text (possibly empty);
    #              0+ tool calls; post-processing adds a turn_id;
    #              additional_kwargs JSON
    #   ToolUse  - foreign keys to the "invoker" message and the "function
    #              output" message; input parameters; result of the tool call
    #   Metric   - granularity type; foreign key to the evaluated object
    # Each entry from LangGraph is a LIST of completions, usually of length 1:
    #   Completion - one bit of text content and 0+ ToolCalls
    #   ToolCall   - tool call (and response!) associated with the completion,
    #                keyed by completion_id / message_id / turn_id

    def load_data(self):
        """Load this dataset's file through the matching ``data_loader`` function.

        ``.jsonl`` files go to ``data_loader.load_jsonl`` and files with a
        SQLite header go to ``data_loader.load_langgraph_sqlite``; ``datatype``
        is updated on the instance (not saved here).

        Raises:
            ValueError: if the file is neither jsonl nor sqlite.
        """
        # Local import as this needs to happen after the module is fully loaded
        from flexeval import data_loader

        loader_kwargs = dict(
            dataset=self,
            filename=self.filename,
            max_n_conversation_threads=self.max_n_conversation_threads,
            nb_evaluations_per_thread=self.nb_evaluations_per_thread,
        )
        if self.filename.endswith(".jsonl"):
            self.datatype = "json"
            data_loader.load_jsonl(**loader_kwargs)
        elif is_sqlite_file(self.filename):
            self.datatype = "sqlite"
            data_loader.load_langgraph_sqlite(**loader_kwargs)
        else:
            extension = os.path.splitext(self.filename)[-1]
            raise ValueError(
                f"Unsupported format '{extension}'. Each Data File must be either a jsonl or sqlite file. You provided the file: '{self.filename}'"
            )
75
+
76
+
77
def is_sqlite_file(filepath):
    """Return True when ``filepath`` begins with the 16-byte SQLite 3 header.

    Errors from ``open`` (e.g. FileNotFoundError) propagate to the caller.
    """
    sqlite_magic = b"SQLite format 3\x00"
    with open(filepath, "rb") as handle:
        return handle.read(len(sqlite_magic)) == sqlite_magic
@@ -0,0 +1,158 @@
1
+ import logging
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+ import dotenv
6
+ from peewee import SqliteDatabase
7
+
8
+ from flexeval import db_utils, dependency_graph
9
+ from flexeval.schema import EvalRun
10
+
11
logger = logging.getLogger(__name__)


class EvalRunner:
    """Class for maintaining database connection, logs, and run state.

    Does not need to write anything to database itself.
    """

    database: SqliteDatabase

    def __init__(
        self,
        evalrun: EvalRun,
    ):
        self.evalrun: EvalRun = evalrun

        self.initialize_logger()
        # Eval-level config overrides must be applied before the steps below,
        # which read config.logs_path, config.env_filepath and config.clear_tables;
        # otherwise overriding those fields from the eval has no effect.
        # (The override log lines are emitted before the file handler exists.)
        self._apply_config_overrides()
        self.add_file_logger()
        self.load_env()
        self.initialize_database()
        self.load_evaluation_settings()

    def initialize_logger(self, add_stream_handler: bool = False):
        """Configure the "FlexEval" logger used by this class.

        Args:
            add_stream_handler (bool, optional): If True, will add a stream handler at the INFO level. Defaults to False.
        """
        self.logger = logging.getLogger("FlexEval")
        self.logger.setLevel(logging.DEBUG)

        if add_stream_handler:
            # TODO this stream handler logic should probably be removed
            # Create a console handler for lower level messages to output to console
            ch = logging.StreamHandler()
            ch.setLevel(logging.INFO)
            ch.setFormatter(
                logging.Formatter(
                    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
                )
            )
            self.logger.addHandler(ch)

    def add_file_logger(self):
        """Attach a DEBUG-level file handler writing to ``<logs_path>/<date>_<eval name>.log``.

        Does nothing when ``config.logs_path`` is None; creates the directory if needed.

        Raises:
            ValueError: if ``logs_path`` is an existing file.
        """
        if self.evalrun.config.logs_path is None:
            logger.info("No log path specified, so not doing any file logging.")
            return
        # Path() also accepts a plain string, e.g. a value set by an eval-level override.
        logs_path = Path(self.evalrun.config.logs_path)
        if logs_path.is_file():
            raise ValueError(
                f"Config logs_path expects a directory, but was set to existing file '{logs_path}'."
            )
        elif not logs_path.exists():
            if logs_path.suffix != "":
                logger.warning(
                    f"Creating logs_path '{logs_path}' as a directory, despite apparent suffix '{logs_path.suffix}'."
                )
            logs_path.mkdir(parents=True, exist_ok=True)

        # Get the current date to use in the filename
        current_date = datetime.now().strftime("%Y-%m-%d")

        # Create a file handler that logs debug and higher level messages to a date-based file.
        # Explicit encoding so non-ASCII content is written the same way on every platform.
        log_filepath = logs_path / f"{current_date}_{self.evalrun.eval.name}.log"
        fh = logging.FileHandler(log_filepath, encoding="utf-8")
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        self.logger.addHandler(fh)
        self.logger.info(f"Started logging to log file '{log_filepath}'.")

    def load_env(self):
        """Load environment variables from ``config.env_filepath`` if one is configured.

        Raises:
            ValueError: if a path is configured but does not exist.
        """
        env_filepath = self.evalrun.config.env_filepath
        if env_filepath is None or str(env_filepath).strip() == "":
            self.logger.debug(
                f"Skipping .env file loading as config env_filepath is '{env_filepath}'."
            )
            return
        # Path() also accepts a plain string, e.g. a value set by an eval-level override.
        env_path = Path(env_filepath)
        if not env_path.exists():
            raise ValueError(
                f"Environment file not present at configured path '{env_filepath}'."
            )
        dotenv.load_dotenv(env_path, verbose=True)
        self.logger.debug(f"Finished loading .env file from '{env_filepath}'.")

    def get_database_path(self) -> Path:
        """Return the database path stored on the eval run."""
        return self.evalrun.database_path

    def initialize_database(self):
        """Initializes database and tables. If config.clear_tables, then current contents of tables are dropped."""
        db_utils.initialize_database(
            self.evalrun.database_path, clear_tables=self.evalrun.config.clear_tables
        )

    def _apply_config_overrides(self):
        """Copy the eval's extra (non-schema) keys onto ``self.evalrun.config``.

        Keys that are not existing config attributes are logged and ignored.
        Runs at most once per instance, so calling it again is a no-op.
        """
        if getattr(self, "_config_overrides_applied", False):
            return
        self._config_overrides_applied = True

        model_extra = self.evalrun.eval.model_extra
        if not model_extra:
            return
        self.logger.debug(
            f"Extra configuration keys provided in eval: {list(model_extra.keys())}"
        )
        config = self.evalrun.config
        for field_name, new_value in model_extra.items():
            if hasattr(config, field_name):
                old_value = getattr(config, field_name)
                self.logger.info(
                    f"Updating configuration setting: '{field_name}'='{new_value}' (old='{old_value}')"
                )
                setattr(config, field_name, new_value)
            else:
                self.logger.warning(
                    f"Unknown configuration field '{field_name}' was ignored."
                )

    def load_evaluation_settings(self):
        """Parse the eval suite into the run-time structures used later.

        Applies eval-level config overrides (if not already applied) and stores
        the ordered metric dependency graph in ``self.metrics_graph_ordered_list``.
        """
        # if the current eval has a 'config' entry, overwrite configuration options with its entries
        self._apply_config_overrides()

        # TODO verify that applying defaults is done solely by pydantic and this step is no longer necessary
        # apply defaults to the schema
        # self.eval = apply_defaults(schema=target_schema, data=self.eval)

        # convert into graph structure
        self.metrics_graph_ordered_list = dependency_graph.create_metrics_graph(
            self.evalrun.eval.metrics
        )
        # validate: completion function defined
        if self.metrics_graph_ordered_list and self.evalrun.eval.grader_llm is None:
            self.logger.warning(
                f"'{len(self.metrics_graph_ordered_list)}' metrics defined, but no grader LLM defined."
            )

    def shutdown_logging(self):
        """Close and detach every handler currently on the "FlexEval" logger."""
        # remove logging handler so we don't get repeat logs if we call run() twice
        for handler in list(self.logger.handlers):
            handler.close()
            self.logger.removeHandler(handler)
@@ -0,0 +1,32 @@
1
+ import json
2
+ from datetime import datetime
3
+
4
+ import peewee as pw
5
+
6
+ from flexeval.classes.base import BaseModel
7
+
8
+
9
class EvalSetRun(BaseModel):
    """Database record for one run of a set of evaluations."""

    id = pw.IntegerField(primary_key=True)
    name = pw.CharField(null=True)
    notes = pw.TextField(null=True)
    dataset_files = pw.TextField()  # JSON string
    metrics = pw.TextField()
    metrics_graph_ordered_list = pw.TextField()
    do_completion = pw.BooleanField()
    completion_llm = pw.TextField(null=True)  # JSON string
    grader_llm = pw.TextField(null=True)  # JSON string
    model_name = pw.TextField(null=True)  # JSON string
    success = pw.BooleanField(null=True)
    rubrics = pw.TextField(null=True)
    # Defaults to the current (naive, local) date and time when the row is created.
    timestamp = pw.DateTimeField(default=datetime.now)

    def get_datasets(self) -> list[str]:
        """Decode ``dataset_files`` (a JSON-encoded list) and return it.

        Raises:
            ValueError: if the decoded value is not a list.
        """
        # TODO Turn these into DataSource instances instead, returning list[DataSource]
        datasets = json.loads(self.dataset_files)
        # Explicit check rather than ``assert`` so the validation survives ``python -O``.
        if not isinstance(datasets, list):
            raise ValueError("The `data` entry in evals.yaml must be a list.")
        return datasets
@@ -0,0 +1,183 @@
1
+ import copy
2
+ import json
3
+ import logging
4
+
5
+ import peewee as pw
6
+ from playhouse.shortcuts import model_to_dict
7
+
8
+ from flexeval.classes.base import BaseModel
9
+ from flexeval.classes.dataset import Dataset
10
+ from flexeval.classes.eval_set_run import EvalSetRun
11
+ from flexeval.classes.thread import Thread
12
+ from flexeval.classes.turn import Turn
13
+ from flexeval.configuration import completion_functions
14
+
15
logger = logging.getLogger(__name__)


class Message(BaseModel):
    """Holds a single component of a single turn.

    Corresponds to one output of a node in LangGraph, or one Turn in jsonl.
    """

    id = pw.IntegerField(primary_key=True)

    evalsetrun = pw.ForeignKeyField(EvalSetRun, backref="messages")
    dataset = pw.ForeignKeyField(Dataset, backref="messages")
    thread = pw.ForeignKeyField(Thread, backref="messages")
    index_in_thread = pw.IntegerField()
    # must be null=True because we're adding it after create()
    turn = pw.ForeignKeyField(Turn, null=True, backref="messages")

    role = pw.TextField()  # user or assistant - 'tools' are counted as assistants
    content = pw.TextField()
    context = pw.TextField(null=True)  # Previous messages (JSON list of role/content dicts)

    # helpers
    system_prompt = pw.TextField(null=True)
    is_flexeval_completion = pw.BooleanField(null=True)
    is_final_turn_in_input = pw.BooleanField(null=True)
    langgraph_print = pw.TextField(null=True)

    # language model stats
    # NOTE(review): this name looks like an accidental merge of "tool_calls" and
    # "langgraph_print"; renaming it would change the DB column, so it is kept.
    tool_callslanggraph_print = pw.TextField(null=True)
    tool_call_ids = pw.TextField(null=True)
    n_tool_calls = pw.IntegerField(null=True)
    prompt_tokens = pw.IntegerField(null=True)
    completion_tokens = pw.IntegerField(null=True)
    model_name = pw.TextField(null=True)

    # langgraph metadata
    langgraph_ts = pw.TextField(null=True)
    langgraph_step = pw.IntegerField(null=True)
    langgraph_thread_id = pw.TextField(null=True)
    langgraph_checkpoint_id = pw.TextField(null=True)
    langgraph_parent_checkpoint_id = pw.TextField(null=True)
    langgraph_node = pw.TextField(null=True)
    langgraph_message_type = pw.TextField(null=True)
    langgraph_type = pw.TextField(null=True)
    langgraph_invocation_id = pw.TextField(null=True)
    # putting these at the end so the database is easier to browse
    langgraph_checkpoint = pw.TextField(null=True)
    langgraph_metadata = pw.TextField(null=True)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.metrics_to_evaluate = []

    def get_completion(self, include_system_prompt=False):
        """Generate completion(s) continuing the conversation from this message.

        The completion function is looked up by the ``function_name`` entry of
        ``evalsetrun.completion_llm`` (JSON) in
        ``flexeval.configuration.completion_functions``. It is called with the
        formatted conversation plus the configured ``kwargs``.

        Args:
            include_system_prompt: Whether the conversation history passed to
                the completion function starts with the system prompt.

        Returns:
            A list with one dict per returned choice. Returns ``None`` if this
            message is not the final input message or no callable completion
            function with the configured name exists.
        """
        # only get a completion if this is the final turn - we probably don't want to branch from mid-conversation
        if not self.is_final_turn_in_input:
            return None

        # completion_llm is nullable; a missing config means "no function configured".
        completion_config = json.loads(self.evalsetrun.completion_llm or "{}") or {}
        completion_fn_name = completion_config.get("function_name", None)
        # A missing/null "kwargs" entry means "call with no extra arguments".
        completion_function_kwargs = completion_config.get("kwargs", None) or {}

        # getattr() requires a string name; anything else counts as "not found".
        completion_function = (
            getattr(completion_functions, completion_fn_name, None)
            if isinstance(completion_fn_name, str)
            else None
        )
        if not callable(completion_function):
            logger.warning(
                f"In completion_functions.py: No callable function named {completion_fn_name} found."
            )
            # Without a completion there is nothing to post-process.
            return None

        completion = completion_function(
            conversation_history=self.get_formatted_prompt(
                include_system_prompt=include_system_prompt
            ),
            **completion_function_kwargs,
        )

        # "completion" will be the output of an existing completion function
        # which generally means it'll have a structure like this
        # {"choices": [{"message": {"content": "hi", "role": "assistant"}}]}
        # Exclude the primary-key *field* (model_to_dict compares Field objects).
        result = model_to_dict(self, exclude=[Message.id])
        result["evalsetrun"] = self.evalsetrun
        result["dataset"] = self.dataset
        # NOTE(review): `datasetrow` and `turn_number` are not fields of this model;
        # they must be set on the instance by the caller or this raises
        # AttributeError — TODO confirm against callers.
        result["datasetrow"] = self.datasetrow
        result["turn_number"] = self.turn_number + 1
        result["role"] = "assistant"
        result["context"] = self.get_formatted_prompt(include_system_prompt=False)
        result["is_final_turn_in_input"] = False  # b/c it's not in input
        self.is_final_turn_in_input = False
        result["is_completion"] = True
        result["completion"] = completion
        result["model"] = completion.get("model", None)

        # Usage covers the whole request, so split it evenly across the choices.
        # TODO - use tiktoken here instead?? this will just give the average
        usage = completion.get("usage") or {}
        n_choices = len(completion.get("choices", [1])) or 1  # avoid dividing by zero
        prompt_tokens = usage.get("prompt_tokens", None)
        completion_tokens = usage.get("completion_tokens", None)
        result["prompt_tokens"] = (
            None if prompt_tokens is None else prompt_tokens / n_choices
        )
        result["completion_tokens"] = (
            None if completion_tokens is None else completion_tokens / n_choices
        )

        result_list = []
        for completion_number, choice in enumerate(completion["choices"], start=1):
            temp = copy.deepcopy(result)
            temp["tool_used"] = choice["message"].get("tool_calls", None)
            temp["turn"] = [choice["message"]]
            temp["content"] = choice["message"]["content"]
            temp["completion_number"] = completion_number
            result_list.append(temp)

        return result_list

    def get_formatted_prompt(self, include_system_prompt=False) -> list[dict[str, str]]:
        """Return the stored context followed by this message as role/content dicts.

        Args:
            include_system_prompt: If True, prepend ``{"role": "system", ...}``
                built from ``self.system_prompt``.
        """
        formatted_prompt = []
        if include_system_prompt:
            formatted_prompt.append({"role": "system", "content": self.system_prompt})
        # ``context`` is nullable; treat a missing/null context as an empty history.
        context = json.loads(self.context or "[]") or []
        formatted_prompt.extend(context)  # TODO - we might just want a subset of this
        formatted_prompt.append({"role": self.role, "content": self.content})
        return formatted_prompt

    def format_input_for_rubric(self):
        """Render the conversation and tool calls as text for rubric prompts.

        Returns:
            ``(output, output_minus_completion, completion, tool_call_text)``:
            all turns; all turns except the last (as "role: content" lines);
            the content of the last turn; and a text block for every tool call.
        """
        prompt = self.get_formatted_prompt()
        output_minus_completion = "".join(
            f"{entry['role']}: {entry['content']}\n" for entry in prompt[:-1]
        )
        # get_formatted_prompt always appends this message, so prompt is non-empty.
        completion = f"{prompt[-1]['content']}"
        output = output_minus_completion + completion

        tool_call_text = ""
        # NOTE(review): relies on a ``toolcalls`` backref declared on ToolCall (not visible here).
        for tc in self.toolcalls:
            tool_call_text += (
                "\n\n"
                f"Function name: {tc.function_name}\n"
                f"Input arguments: {tc.args}\n"
                f"Function output: {tc.response_content}\n"
            )

        return output, output_minus_completion, completion, tool_call_text

    def get_content(self) -> str:
        """Return this message's text content."""
        return self.content

    def get_context(self, include_system_prompt=False) -> list[dict[str, str]]:
        """Return the stored previous messages, dropping system messages unless requested."""
        # ``context`` is nullable; treat a missing/null context as an empty history.
        context = json.loads(self.context or "[]") or []
        if include_system_prompt:
            return context
        return [entry for entry in context if entry.get("role") != "system"]
@@ -0,0 +1,55 @@
1
+ import peewee as pw
2
+
3
+ from flexeval.classes.base import BaseModel
4
+ from flexeval.classes.dataset import Dataset
5
+ from flexeval.classes.eval_set_run import EvalSetRun
6
+ from flexeval.classes.message import Message
7
+ from flexeval.classes.thread import Thread
8
+ from flexeval.classes.tool_call import ToolCall
9
+ from flexeval.classes.turn import Turn
10
+
11
+
12
class Metric(BaseModel):
    """One metric value computed during an eval set run.

    Every row references its EvalSetRun, Dataset and Thread. ``turn``,
    ``message`` and ``toolcall`` are set only for metrics computed at that level.
    """

    id = pw.IntegerField(primary_key=True)

    # Owning objects (required).
    evalsetrun = pw.ForeignKeyField(EvalSetRun, backref="metrics_list")
    dataset = pw.ForeignKeyField(Dataset, backref="metrics_list")
    thread = pw.ForeignKeyField(Thread, backref="metrics_list")
    # Level-specific targets: only defined for Turn / Message / ToolCall metrics.
    turn = pw.ForeignKeyField(Turn, null=True, backref="metrics_list")
    message = pw.ForeignKeyField(Message, null=True, backref="metrics_list")
    toolcall = pw.ForeignKeyField(ToolCall, null=True, backref="metrics_list")

    evaluation_name = pw.TextField()
    # TODO: some code says "metric_type", other code "evaluation_type" - choose one.
    evaluation_type = pw.TextField()
    metric_name = pw.TextField()
    metric_level = pw.TextField()
    # Nullable: e.g. the rubric result was INVALID, or latency doesn't apply to the
    # very first message. TODO: consider a secondary non-numeric value field to
    # support non-numeric functions and rubrics.
    metric_value = pw.FloatField(null=True)
    kwargs = pw.TextField()
    # A "context_only" column (quantifying something about the *past* of the
    # conversation for downstream use, e.g. "would a plot be pedagogically
    # appropriate here?") is not defined; rubrics only use {context}.
    source = pw.TextField()  # TODO - own table? It also holds filled-in rubrics
    depends_on = pw.TextField()
    rubric_prompt = pw.TextField(null=True)
    rubric_completion = pw.TextField(null=True)
    rubric_model = pw.TextField(null=True)
    rubric_completion_tokens = pw.IntegerField(null=True)
    rubric_prompt_tokens = pw.IntegerField(null=True)
    rubric_score = pw.TextField(null=True)
@@ -0,0 +1,79 @@
1
+ import peewee as pw
2
+
3
+ from flexeval.classes.base import BaseModel
4
+ from flexeval.classes.dataset import Dataset
5
+ from flexeval.classes.eval_set_run import EvalSetRun
6
+
7
+
8
class Thread(BaseModel):
    """A single thread / conversation.

    Corresponds to a single row in a jsonl file, or a single 'thread_id' in a
    LangGraph checkpoint database.
    """

    id = pw.IntegerField(primary_key=True)
    dataset = pw.ForeignKeyField(Dataset, backref="threads")
    evalsetrun = pw.ForeignKeyField(EvalSetRun, backref="threads")

    langgraph_thread_id = pw.TextField(null=True)
    eval_run_thread_id = pw.TextField(null=True)
    jsonl_thread_id = pw.TextField(null=True)

    system_prompt = pw.TextField(null=True)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.metrics_to_evaluate = []

    # TODO - test this!
    def format_input_for_rubric(self):
        """Render the thread as text for rubric prompts.

        Returns:
            ``(output, output_minus_completion, completion, tool_call_text)``:
            every entry as a "role: content" line; all but the last; the last
            one; and a text block for every tool call. For a thread with no
            entries the first three are empty strings.
        """
        prompt = self.get_formatted_prompt()
        lines = [f"{entry['role']}: {entry['content']}\n" for entry in prompt]
        output_minus_completion = "".join(lines[:-1])
        # Guard against an empty thread instead of raising IndexError.
        completion = lines[-1] if lines else ""
        output = output_minus_completion + completion

        tool_call_text = ""
        # NOTE(review): relies on a ``toolcalls`` backref to Thread declared on
        # ToolCall (not visible here) — TODO confirm it exists.
        for tc in self.toolcalls:
            tool_call_text += (
                "\n\n"
                f"Function name: {tc.function_name}\n"
                f"Input arguments: {tc.args}\n"
                f"Function output: {tc.response_content}\n"
            )

        return output, output_minus_completion, completion, tool_call_text

    def get_formatted_prompt(self, include_system_prompt=False):
        """Return the thread content, optionally preceded by the system prompt."""
        formatted_prompt = []
        if include_system_prompt:
            formatted_prompt.append({"role": "system", "content": self.system_prompt})

        formatted_prompt += self.get_content()

        return formatted_prompt

    def get_content(self, include_toolcalls=True):
        """Return the thread's messages (and tool calls) as a list of dicts.

        Each message contributes ``{"role": ..., "content": ...}``. When
        ``include_toolcalls`` is True, each of the message's tool calls follows
        it, using ``toolcall.get_dict_representation()``.
        """
        content = []
        for message in self.messages:
            content.append({"role": message.role, "content": message.content})
            if include_toolcalls:
                for toolcall in message.toolcalls:
                    content.append(toolcall.get_dict_representation())

        return content
+ return content