edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +28 -0
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -16
- edsl/agents/Invigilator.py +13 -14
- edsl/agents/InvigilatorBase.py +4 -1
- edsl/agents/PromptConstructor.py +42 -22
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/coop/coop.py +21 -5
- edsl/data/Cache.py +29 -18
- edsl/data/CacheHandler.py +0 -2
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/hack.py +10 -0
- edsl/enums.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -16
- edsl/inference_services/AvailableModelFetcher.py +7 -1
- edsl/inference_services/GoogleService.py +5 -1
- edsl/inference_services/InferenceServicesCollection.py +18 -2
- edsl/inference_services/OpenAIService.py +46 -31
- edsl/inference_services/TestService.py +1 -3
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +74 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
- edsl/jobs/FetchInvigilator.py +10 -3
- edsl/jobs/InterviewsConstructor.py +6 -4
- edsl/jobs/Jobs.py +299 -233
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/interviews/Interview.py +80 -42
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
- edsl/jobs/runners/JobsRunnerStatus.py +131 -164
- edsl/jobs/tasks/TaskHistory.py +24 -3
- edsl/language_models/LanguageModel.py +59 -4
- edsl/language_models/ModelList.py +19 -8
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +35 -26
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +9 -15
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/questions/data_structures.py +20 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
- edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
- edsl/results/DatasetExportMixin.py +60 -119
- edsl/results/Result.py +109 -3
- edsl/results/Results.py +50 -39
- edsl/results/file_exports.py +252 -0
- edsl/scenarios/ScenarioList.py +35 -7
- edsl/surveys/Survey.py +71 -20
- edsl/test_h +1 -0
- edsl/utilities/gcp_bucket/example.py +50 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
- edsl/language_models/registry.py +0 -180
- /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
- /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
- /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{Selector.py → results_selector.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
- /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
- /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
@@ -1,11 +1,16 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import copy
|
3
3
|
import itertools
|
4
|
-
from typing import Optional, List, Callable, Type
|
5
|
-
|
4
|
+
from typing import Optional, List, Callable, Type, TYPE_CHECKING
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from edsl.questions.QuestionBase import QuestionBase
|
8
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
6
9
|
|
7
10
|
|
8
11
|
class QuestionBaseGenMixin:
|
12
|
+
"""Mixin for QuestionBase."""
|
13
|
+
|
9
14
|
def copy(self) -> QuestionBase:
|
10
15
|
"""Return a deep copy of the question.
|
11
16
|
|
@@ -39,6 +44,31 @@ class QuestionBaseGenMixin:
|
|
39
44
|
questions.append(question)
|
40
45
|
return questions
|
41
46
|
|
47
|
+
def draw(self) -> "QuestionBase":
|
48
|
+
"""Return a new question with a randomly selected permutation of the options.
|
49
|
+
|
50
|
+
If the question has no options, returns a copy of the original question.
|
51
|
+
|
52
|
+
>>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice as Q
|
53
|
+
>>> q = Q.example()
|
54
|
+
>>> drawn = q.draw()
|
55
|
+
>>> len(drawn.question_options) == len(q.question_options)
|
56
|
+
True
|
57
|
+
>>> q is drawn
|
58
|
+
False
|
59
|
+
"""
|
60
|
+
|
61
|
+
if not hasattr(self, "question_options"):
|
62
|
+
return copy.deepcopy(self)
|
63
|
+
|
64
|
+
import random
|
65
|
+
|
66
|
+
question = copy.deepcopy(self)
|
67
|
+
question.question_options = list(
|
68
|
+
random.sample(self.question_options, len(self.question_options))
|
69
|
+
)
|
70
|
+
return question
|
71
|
+
|
42
72
|
def loop(self, scenario_list: ScenarioList) -> List[QuestionBase]:
|
43
73
|
"""Return a list of questions with the question name modified for each scenario.
|
44
74
|
|
@@ -50,57 +80,22 @@ class QuestionBaseGenMixin:
|
|
50
80
|
>>> len(q.loop(ScenarioList.from_list("subject", ["Math", "Economics", "Chemistry"])))
|
51
81
|
3
|
52
82
|
"""
|
53
|
-
from edsl.questions.
|
83
|
+
from edsl.questions.loop_processor import LoopProcessor
|
54
84
|
|
55
85
|
lp = LoopProcessor(self)
|
56
86
|
return lp.process_templates(scenario_list)
|
57
87
|
|
58
|
-
# from jinja2 import Environment
|
59
|
-
# from edsl.questions.QuestionBase import QuestionBase
|
60
|
-
|
61
|
-
# starting_name = self.question_name
|
62
|
-
# questions = []
|
63
|
-
# for index, scenario in enumerate(scenario_list):
|
64
|
-
# env = Environment()
|
65
|
-
# new_data = self.to_dict().copy()
|
66
|
-
# for key, value in [(k, v) for k, v in new_data.items() if v is not None]:
|
67
|
-
# if (
|
68
|
-
# isinstance(value, str) or isinstance(value, int)
|
69
|
-
# ) and key != "question_options":
|
70
|
-
# new_data[key] = env.from_string(value).render(scenario)
|
71
|
-
# elif isinstance(value, list):
|
72
|
-
# new_data[key] = [
|
73
|
-
# env.from_string(v).render(scenario) if isinstance(v, str) else v
|
74
|
-
# for v in value
|
75
|
-
# ]
|
76
|
-
# elif isinstance(value, dict):
|
77
|
-
# new_data[key] = {
|
78
|
-
# (
|
79
|
-
# env.from_string(k).render(scenario)
|
80
|
-
# if isinstance(k, str)
|
81
|
-
# else k
|
82
|
-
# ): (
|
83
|
-
# env.from_string(v).render(scenario)
|
84
|
-
# if isinstance(v, str)
|
85
|
-
# else v
|
86
|
-
# )
|
87
|
-
# for k, v in value.items()
|
88
|
-
# }
|
89
|
-
# elif key == "question_options" and isinstance(value, str):
|
90
|
-
# new_data[key] = value
|
91
|
-
# else:
|
92
|
-
# raise ValueError(
|
93
|
-
# f"Unexpected value type: {type(value)} for key '{key}'"
|
94
|
-
# )
|
95
|
-
|
96
|
-
# if new_data["question_name"] == starting_name:
|
97
|
-
# new_data["question_name"] = new_data["question_name"] + f"_{index}"
|
98
|
-
|
99
|
-
# questions.append(QuestionBase.from_dict(new_data))
|
100
|
-
# return questions
|
101
|
-
|
102
88
|
def render(self, replacement_dict: dict) -> "QuestionBase":
|
103
|
-
"""Render the question components as jinja2 templates with the replacement dictionary.
|
89
|
+
"""Render the question components as jinja2 templates with the replacement dictionary.
|
90
|
+
|
91
|
+
:param replacement_dict: The dictionary of values to replace in the question components.
|
92
|
+
|
93
|
+
>>> from edsl.questions.QuestionFreeText import QuestionFreeText
|
94
|
+
>>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite {{ thing }}?")
|
95
|
+
>>> q.render({"thing": "color"})
|
96
|
+
Question('free_text', question_name = \"""color\""", question_text = \"""What is your favorite color?\""")
|
97
|
+
|
98
|
+
"""
|
104
99
|
from jinja2 import Environment
|
105
100
|
from edsl.scenarios.Scenario import Scenario
|
106
101
|
|
@@ -127,15 +122,23 @@ class QuestionBaseGenMixin:
|
|
127
122
|
|
128
123
|
return self.apply_function(render_string)
|
129
124
|
|
130
|
-
def apply_function(
|
125
|
+
def apply_function(
|
126
|
+
self, func: Callable, exclude_components: List[str] = None
|
127
|
+
) -> QuestionBase:
|
131
128
|
"""Apply a function to the question parts
|
132
129
|
|
130
|
+
:param func: The function to apply to the question parts.
|
131
|
+
:param exclude_components: The components to exclude from the function application.
|
132
|
+
|
133
133
|
>>> from edsl.questions import QuestionFreeText
|
134
134
|
>>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite color?")
|
135
135
|
>>> shouting = lambda x: x.upper()
|
136
136
|
>>> q.apply_function(shouting)
|
137
137
|
Question('free_text', question_name = \"""color\""", question_text = \"""WHAT IS YOUR FAVORITE COLOR?\""")
|
138
138
|
|
139
|
+
>>> q.apply_function(shouting, exclude_components = ["question_type"])
|
140
|
+
Question('free_text', question_name = \"""COLOR\""", question_text = \"""WHAT IS YOUR FAVORITE COLOR?\""")
|
141
|
+
|
139
142
|
"""
|
140
143
|
from edsl.questions.QuestionBase import QuestionBase
|
141
144
|
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import logging
|
2
1
|
from abc import ABC, abstractmethod
|
3
2
|
from typing import Optional, Any, List, TypedDict
|
4
3
|
|
@@ -7,17 +6,17 @@ from pydantic import BaseModel, Field, field_validator, ValidationError
|
|
7
6
|
from edsl.exceptions.questions import QuestionAnswerValidationError
|
8
7
|
from edsl.questions.ExceptionExplainer import ExceptionExplainer
|
9
8
|
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
generated_tokens: Optional[str] = None
|
9
|
+
from edsl.questions.data_structures import (
|
10
|
+
RawEdslAnswerDict,
|
11
|
+
EdslAnswerDict,
|
12
|
+
)
|
15
13
|
|
16
14
|
|
17
15
|
class ResponseValidatorABC(ABC):
|
18
16
|
required_params: List[str] = []
|
19
17
|
|
20
18
|
def __init_subclass__(cls, **kwargs):
|
19
|
+
"""This is a metaclass that ensures that all subclasses of ResponseValidatorABC have the required class variables."""
|
21
20
|
super().__init_subclass__(**kwargs)
|
22
21
|
required_class_vars = ["required_params", "valid_examples", "invalid_examples"]
|
23
22
|
for var in required_class_vars:
|
@@ -52,12 +51,7 @@ class ResponseValidatorABC(ABC):
|
|
52
51
|
if not hasattr(self, "permissive"):
|
53
52
|
self.permissive = False
|
54
53
|
|
55
|
-
self.fixes_tried = 0
|
56
|
-
|
57
|
-
class RawEdslAnswerDict(TypedDict):
|
58
|
-
answer: Any
|
59
|
-
comment: Optional[str]
|
60
|
-
generated_tokens: Optional[str]
|
54
|
+
self.fixes_tried = 0 # how many times we've tried to fix the answer
|
61
55
|
|
62
56
|
def _preprocess(self, data: RawEdslAnswerDict) -> RawEdslAnswerDict:
|
63
57
|
"""This is for testing purposes. A question can be given an exception to throw or an answer to always return.
|
@@ -89,11 +83,6 @@ class ResponseValidatorABC(ABC):
|
|
89
83
|
def post_validation_answer_convert(self, data):
|
90
84
|
return data
|
91
85
|
|
92
|
-
class EdslAnswerDict(TypedDict):
|
93
|
-
answer: Any
|
94
|
-
comment: Optional[str]
|
95
|
-
generated_tokens: Optional[str]
|
96
|
-
|
97
86
|
def validate(
|
98
87
|
self,
|
99
88
|
raw_edsl_answer_dict: RawEdslAnswerDict,
|
@@ -136,7 +125,6 @@ class ResponseValidatorABC(ABC):
|
|
136
125
|
def human_explanation(self, e: QuestionAnswerValidationError):
|
137
126
|
explanation = ExceptionExplainer(e, model_response=e.data).explain()
|
138
127
|
return explanation
|
139
|
-
# return e
|
140
128
|
|
141
129
|
def _handle_exception(self, e: Exception, raw_edsl_answer_dict) -> EdslAnswerDict:
|
142
130
|
if self.fixes_tried == 0:
|
@@ -1,4 +1,10 @@
|
|
1
|
+
from edsl.questions.data_structures import BaseModel
|
2
|
+
from edsl.questions.response_validator_abc import ResponseValidatorABC
|
3
|
+
|
4
|
+
|
1
5
|
class ResponseValidatorFactory:
|
6
|
+
"""Factory class to create a response validator for a question."""
|
7
|
+
|
2
8
|
def __init__(self, question):
|
3
9
|
self.question = question
|
4
10
|
|
@@ -10,7 +16,7 @@ class ResponseValidatorFactory:
|
|
10
16
|
return self.question.create_response_model()
|
11
17
|
|
12
18
|
@property
|
13
|
-
def response_validator(self) -> "
|
19
|
+
def response_validator(self) -> "ResponseValidatorABC":
|
14
20
|
"""Return the response validator."""
|
15
21
|
params = (
|
16
22
|
{
|
@@ -1,12 +1,12 @@
|
|
1
1
|
"""Mixin class for exporting results."""
|
2
2
|
|
3
|
-
import base64
|
4
|
-
import csv
|
5
3
|
import io
|
6
4
|
import warnings
|
7
5
|
import textwrap
|
8
6
|
from typing import Optional, Tuple, Union, List
|
9
7
|
|
8
|
+
from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
|
9
|
+
|
10
10
|
|
11
11
|
class DatasetExportMixin:
|
12
12
|
"""Mixin class for exporting Dataset objects."""
|
@@ -164,79 +164,44 @@ class DatasetExportMixin:
|
|
164
164
|
remove_prefix=remove_prefix, pretty_labels=pretty_labels
|
165
165
|
)
|
166
166
|
|
167
|
-
def to_jsonl(self, filename: Optional[str] = None) -> "FileStore":
|
168
|
-
"""Export the results to a FileStore instance containing JSONL data.
|
169
|
-
|
170
|
-
|
171
|
-
filename: Optional filename for the JSONL file (defaults to "results.jsonl")
|
172
|
-
|
173
|
-
Returns:
|
174
|
-
FileStore: Instance containing the JSONL data
|
175
|
-
"""
|
176
|
-
if filename is None:
|
177
|
-
filename = "results.jsonl"
|
167
|
+
def to_jsonl(self, filename: Optional[str] = None) -> Optional["FileStore"]:
|
168
|
+
"""Export the results to a FileStore instance containing JSONL data."""
|
169
|
+
exporter = JSONLExport(data=self, filename=filename)
|
170
|
+
return exporter.export()
|
178
171
|
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
base64_string=base64_string,
|
172
|
+
def to_sqlite(
|
173
|
+
self,
|
174
|
+
filename: Optional[str] = None,
|
175
|
+
remove_prefix: bool = False,
|
176
|
+
pretty_labels: Optional[dict] = None,
|
177
|
+
table_name: str = "results",
|
178
|
+
if_exists: str = "replace",
|
179
|
+
) -> Optional["FileStore"]:
|
180
|
+
"""Export the results to a SQLite database file."""
|
181
|
+
exporter = SQLiteExport(
|
182
|
+
data=self,
|
183
|
+
filename=filename,
|
184
|
+
remove_prefix=remove_prefix,
|
185
|
+
pretty_labels=pretty_labels,
|
186
|
+
table_name=table_name,
|
187
|
+
if_exists=if_exists,
|
196
188
|
)
|
189
|
+
return exporter.export()
|
197
190
|
|
198
191
|
def to_csv(
|
199
192
|
self,
|
200
193
|
filename: Optional[str] = None,
|
201
194
|
remove_prefix: bool = False,
|
202
195
|
pretty_labels: Optional[dict] = None,
|
203
|
-
) -> "FileStore":
|
204
|
-
"""Export the results to a FileStore instance containing CSV data.
|
205
|
-
|
206
|
-
|
207
|
-
filename
|
208
|
-
remove_prefix
|
209
|
-
pretty_labels
|
210
|
-
|
211
|
-
Returns:
|
212
|
-
FileStore: Instance containing the CSV data
|
213
|
-
"""
|
214
|
-
if filename is None:
|
215
|
-
filename = "results.csv"
|
216
|
-
|
217
|
-
# Get the tabular data
|
218
|
-
header, rows = self._get_tabular_data(
|
219
|
-
remove_prefix=remove_prefix, pretty_labels=pretty_labels
|
220
|
-
)
|
221
|
-
|
222
|
-
# Write to string buffer
|
223
|
-
output = io.StringIO()
|
224
|
-
writer = csv.writer(output)
|
225
|
-
writer.writerow(header)
|
226
|
-
writer.writerows(rows)
|
227
|
-
|
228
|
-
# Get the CSV string and encode to base64
|
229
|
-
csv_string = output.getvalue()
|
230
|
-
base64_string = base64.b64encode(csv_string.encode()).decode()
|
231
|
-
from edsl.scenarios.FileStore import FileStore
|
232
|
-
|
233
|
-
return FileStore(
|
234
|
-
path=filename,
|
235
|
-
mime_type="text/csv",
|
236
|
-
binary=False,
|
237
|
-
suffix="csv",
|
238
|
-
base64_string=base64_string,
|
196
|
+
) -> Optional["FileStore"]:
|
197
|
+
"""Export the results to a FileStore instance containing CSV data."""
|
198
|
+
exporter = CSVExport(
|
199
|
+
data=self,
|
200
|
+
filename=filename,
|
201
|
+
remove_prefix=remove_prefix,
|
202
|
+
pretty_labels=pretty_labels,
|
239
203
|
)
|
204
|
+
return exporter.export()
|
240
205
|
|
241
206
|
def to_excel(
|
242
207
|
self,
|
@@ -244,60 +209,16 @@ class DatasetExportMixin:
|
|
244
209
|
remove_prefix: bool = False,
|
245
210
|
pretty_labels: Optional[dict] = None,
|
246
211
|
sheet_name: Optional[str] = None,
|
247
|
-
) -> "FileStore":
|
248
|
-
"""Export the results to a FileStore instance containing Excel data.
|
249
|
-
|
250
|
-
|
251
|
-
filename
|
252
|
-
remove_prefix
|
253
|
-
pretty_labels
|
254
|
-
sheet_name
|
255
|
-
|
256
|
-
Returns:
|
257
|
-
FileStore: Instance containing the Excel data
|
258
|
-
"""
|
259
|
-
from openpyxl import Workbook
|
260
|
-
|
261
|
-
if filename is None:
|
262
|
-
filename = "results.xlsx"
|
263
|
-
if sheet_name is None:
|
264
|
-
sheet_name = "Results"
|
265
|
-
|
266
|
-
# Get the tabular data
|
267
|
-
header, rows = self._get_tabular_data(
|
268
|
-
remove_prefix=remove_prefix, pretty_labels=pretty_labels
|
269
|
-
)
|
270
|
-
|
271
|
-
# Create Excel workbook in memory
|
272
|
-
wb = Workbook()
|
273
|
-
ws = wb.active
|
274
|
-
ws.title = sheet_name
|
275
|
-
|
276
|
-
# Write header
|
277
|
-
for col, value in enumerate(header, 1):
|
278
|
-
ws.cell(row=1, column=col, value=value)
|
279
|
-
|
280
|
-
# Write data rows
|
281
|
-
for row_idx, row_data in enumerate(rows, 2):
|
282
|
-
for col, value in enumerate(row_data, 1):
|
283
|
-
ws.cell(row=row_idx, column=col, value=value)
|
284
|
-
|
285
|
-
# Save to bytes buffer
|
286
|
-
buffer = io.BytesIO()
|
287
|
-
wb.save(buffer)
|
288
|
-
buffer.seek(0)
|
289
|
-
|
290
|
-
# Convert to base64
|
291
|
-
base64_string = base64.b64encode(buffer.getvalue()).decode()
|
292
|
-
from edsl.scenarios.FileStore import FileStore
|
293
|
-
|
294
|
-
return FileStore(
|
295
|
-
path=filename,
|
296
|
-
mime_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
297
|
-
binary=True,
|
298
|
-
suffix="xlsx",
|
299
|
-
base64_string=base64_string,
|
212
|
+
) -> Optional["FileStore"]:
|
213
|
+
"""Export the results to a FileStore instance containing Excel data."""
|
214
|
+
exporter = ExcelExport(
|
215
|
+
data=self,
|
216
|
+
filename=filename,
|
217
|
+
remove_prefix=remove_prefix,
|
218
|
+
pretty_labels=pretty_labels,
|
219
|
+
sheet_name=sheet_name,
|
300
220
|
)
|
221
|
+
return exporter.export()
|
301
222
|
|
302
223
|
def _db(self, remove_prefix: bool = True):
|
303
224
|
"""Create a SQLite database in memory and return the connection.
|
@@ -398,6 +319,26 @@ class DatasetExportMixin:
|
|
398
319
|
# df_sorted = df.sort_index(axis=1) # Sort columns alphabetically
|
399
320
|
return df
|
400
321
|
|
322
|
+
def to_polars(
|
323
|
+
self, remove_prefix: bool = False, lists_as_strings=False
|
324
|
+
) -> "pl.DataFrame":
|
325
|
+
"""Convert the results to a Polars DataFrame.
|
326
|
+
|
327
|
+
:param remove_prefix: Whether to remove the prefix from the column names.
|
328
|
+
"""
|
329
|
+
return self._to_polars_strings(remove_prefix)
|
330
|
+
|
331
|
+
def _to_polars_strings(self, remove_prefix: bool = False) -> "pl.DataFrame":
|
332
|
+
"""Convert the results to a Polars DataFrame.
|
333
|
+
|
334
|
+
:param remove_prefix: Whether to remove the prefix from the column names.
|
335
|
+
"""
|
336
|
+
import polars as pl
|
337
|
+
|
338
|
+
csv_string = self.to_csv(remove_prefix=remove_prefix).text
|
339
|
+
df = pl.read_csv(io.StringIO(csv_string))
|
340
|
+
return df
|
341
|
+
|
401
342
|
def to_scenario_list(self, remove_prefix: bool = True) -> list[dict]:
|
402
343
|
"""Convert the results to a list of dictionaries, one per scenario.
|
403
344
|
|
edsl/results/Result.py
CHANGED
@@ -173,9 +173,9 @@ class Result(Base, UserDict):
|
|
173
173
|
if question_name in self.question_to_attributes:
|
174
174
|
for dictionary_name in sub_dicts_needing_new_keys:
|
175
175
|
new_key = question_name + "_" + dictionary_name
|
176
|
-
sub_dicts_needing_new_keys[dictionary_name][
|
177
|
-
|
178
|
-
|
176
|
+
sub_dicts_needing_new_keys[dictionary_name][new_key] = (
|
177
|
+
self.question_to_attributes[question_name][dictionary_name]
|
178
|
+
)
|
179
179
|
|
180
180
|
new_cache_dict = {
|
181
181
|
f"{k}_cache_used": v for k, v in self.data["cache_used_dict"].items()
|
@@ -444,6 +444,112 @@ class Result(Base, UserDict):
|
|
444
444
|
raise ValueError(f"Parameter {k} not found in Result object")
|
445
445
|
return scoring_function(**params)
|
446
446
|
|
447
|
+
@classmethod
|
448
|
+
def from_interview(
|
449
|
+
cls, interview, extracted_answers, model_response_objects
|
450
|
+
) -> Result:
|
451
|
+
"""Return a Result object from an interview dictionary."""
|
452
|
+
|
453
|
+
def get_question_results(
|
454
|
+
model_response_objects,
|
455
|
+
) -> dict[str, "EDSLResultObjectInput"]:
|
456
|
+
"""Maps the question name to the EDSLResultObjectInput."""
|
457
|
+
question_results = {}
|
458
|
+
for result in model_response_objects:
|
459
|
+
question_results[result.question_name] = result
|
460
|
+
return question_results
|
461
|
+
|
462
|
+
def get_generated_tokens_dict(answer_key_names) -> dict[str, str]:
|
463
|
+
generated_tokens_dict = {
|
464
|
+
k + "_generated_tokens": question_results[k].generated_tokens
|
465
|
+
for k in answer_key_names
|
466
|
+
}
|
467
|
+
return generated_tokens_dict
|
468
|
+
|
469
|
+
def get_comments_dict(answer_key_names) -> dict[str, str]:
|
470
|
+
comments_dict = {
|
471
|
+
k + "_comment": question_results[k].comment for k in answer_key_names
|
472
|
+
}
|
473
|
+
return comments_dict
|
474
|
+
|
475
|
+
def get_question_name_to_prompts(
|
476
|
+
model_response_objects,
|
477
|
+
) -> dict[str, dict[str, str]]:
|
478
|
+
question_name_to_prompts = dict({})
|
479
|
+
for result in model_response_objects:
|
480
|
+
question_name = result.question_name
|
481
|
+
question_name_to_prompts[question_name] = {
|
482
|
+
"user_prompt": result.prompts["user_prompt"],
|
483
|
+
"system_prompt": result.prompts["system_prompt"],
|
484
|
+
}
|
485
|
+
return question_name_to_prompts
|
486
|
+
|
487
|
+
def get_prompt_dictionary(answer_key_names, question_name_to_prompts):
|
488
|
+
prompt_dictionary = {}
|
489
|
+
for answer_key_name in answer_key_names:
|
490
|
+
prompt_dictionary[answer_key_name + "_user_prompt"] = (
|
491
|
+
question_name_to_prompts[answer_key_name]["user_prompt"]
|
492
|
+
)
|
493
|
+
prompt_dictionary[answer_key_name + "_system_prompt"] = (
|
494
|
+
question_name_to_prompts[answer_key_name]["system_prompt"]
|
495
|
+
)
|
496
|
+
return prompt_dictionary
|
497
|
+
|
498
|
+
def get_raw_model_results_and_cache_used_dictionary(model_response_objects):
|
499
|
+
raw_model_results_dictionary = {}
|
500
|
+
cache_used_dictionary = {}
|
501
|
+
for result in model_response_objects:
|
502
|
+
question_name = result.question_name
|
503
|
+
raw_model_results_dictionary[question_name + "_raw_model_response"] = (
|
504
|
+
result.raw_model_response
|
505
|
+
)
|
506
|
+
raw_model_results_dictionary[question_name + "_cost"] = result.cost
|
507
|
+
one_use_buys = (
|
508
|
+
"NA"
|
509
|
+
if isinstance(result.cost, str)
|
510
|
+
or result.cost == 0
|
511
|
+
or result.cost is None
|
512
|
+
else 1.0 / result.cost
|
513
|
+
)
|
514
|
+
raw_model_results_dictionary[question_name + "_one_usd_buys"] = (
|
515
|
+
one_use_buys
|
516
|
+
)
|
517
|
+
cache_used_dictionary[question_name] = result.cache_used
|
518
|
+
|
519
|
+
return raw_model_results_dictionary, cache_used_dictionary
|
520
|
+
|
521
|
+
question_results = get_question_results(model_response_objects)
|
522
|
+
answer_key_names = list(question_results.keys())
|
523
|
+
generated_tokens_dict = get_generated_tokens_dict(answer_key_names)
|
524
|
+
comments_dict = get_comments_dict(answer_key_names)
|
525
|
+
answer_dict = {k: extracted_answers[k] for k in answer_key_names}
|
526
|
+
|
527
|
+
question_name_to_prompts = get_question_name_to_prompts(model_response_objects)
|
528
|
+
prompt_dictionary = get_prompt_dictionary(
|
529
|
+
answer_key_names, question_name_to_prompts
|
530
|
+
)
|
531
|
+
raw_model_results_dictionary, cache_used_dictionary = (
|
532
|
+
get_raw_model_results_and_cache_used_dictionary(model_response_objects)
|
533
|
+
)
|
534
|
+
|
535
|
+
result = cls(
|
536
|
+
agent=interview.agent,
|
537
|
+
scenario=interview.scenario,
|
538
|
+
model=interview.model,
|
539
|
+
iteration=interview.iteration,
|
540
|
+
# Computed objects
|
541
|
+
answer=answer_dict,
|
542
|
+
prompt=prompt_dictionary,
|
543
|
+
raw_model_response=raw_model_results_dictionary,
|
544
|
+
survey=interview.survey,
|
545
|
+
generated_tokens=generated_tokens_dict,
|
546
|
+
comments_dict=comments_dict,
|
547
|
+
cache_used_dict=cache_used_dictionary,
|
548
|
+
indices=interview.indices,
|
549
|
+
)
|
550
|
+
result.interview_hash = interview.initial_hash
|
551
|
+
return result
|
552
|
+
|
447
553
|
|
448
554
|
if __name__ == "__main__":
|
449
555
|
import doctest
|