edsl 0.1.39__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +0 -28
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +17 -9
- edsl/agents/Invigilator.py +14 -13
- edsl/agents/InvigilatorBase.py +1 -4
- edsl/agents/PromptConstructor.py +22 -42
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/coop/coop.py +5 -21
- edsl/data/Cache.py +18 -29
- edsl/data/CacheHandler.py +2 -0
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/enums.py +0 -7
- edsl/inference_services/AnthropicService.py +16 -38
- edsl/inference_services/AvailableModelFetcher.py +1 -7
- edsl/inference_services/GoogleService.py +1 -5
- edsl/inference_services/InferenceServicesCollection.py +2 -18
- edsl/inference_services/OpenAIService.py +31 -46
- edsl/inference_services/TestService.py +3 -1
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/inference_services/data_structures.py +2 -74
- edsl/jobs/AnswerQuestionFunctionConstructor.py +113 -148
- edsl/jobs/FetchInvigilator.py +3 -10
- edsl/jobs/InterviewsConstructor.py +4 -6
- edsl/jobs/Jobs.py +233 -299
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +136 -160
- edsl/jobs/interviews/Interview.py +42 -80
- edsl/jobs/runners/JobsRunnerAsyncio.py +358 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/TaskHistory.py +3 -24
- edsl/language_models/LanguageModel.py +4 -59
- edsl/language_models/ModelList.py +8 -19
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/registry.py +180 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +26 -35
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +49 -52
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +15 -9
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +18 -6
- edsl/questions/{response_validator_factory.py → ResponseValidatorFactory.py} +1 -7
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/results/DatasetExportMixin.py +119 -60
- edsl/results/Result.py +3 -109
- edsl/results/Results.py +39 -50
- edsl/scenarios/FileStore.py +0 -32
- edsl/scenarios/ScenarioList.py +7 -35
- edsl/scenarios/handlers/csv.py +0 -11
- edsl/surveys/Survey.py +20 -71
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +1 -1
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/RECORD +78 -84
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +1 -1
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/model.py +0 -256
- edsl/questions/data_structures.py +0 -20
- edsl/results/file_exports.py +0 -252
- /edsl/agents/{question_option_processor.py → QuestionOptionProcessor.py} +0 -0
- /edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +0 -0
- /edsl/questions/{loop_processor.py → LoopProcessor.py} +0 -0
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- /edsl/results/{results_selector.py → Selector.py} +0 -0
- /edsl/scenarios/{directory_scanner.py → DirectoryScanner.py} +0 -0
- /edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +0 -0
- /edsl/scenarios/{scenario_selector.py → ScenarioSelector.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
edsl/questions/QuestionRank.py
CHANGED
@@ -8,7 +8,7 @@ from edsl.questions.descriptors import (
|
|
8
8
|
QuestionOptionsDescriptor,
|
9
9
|
NumSelectionsDescriptor,
|
10
10
|
)
|
11
|
-
from edsl.questions.
|
11
|
+
from edsl.questions.ResponseValidatorABC import ResponseValidatorABC
|
12
12
|
|
13
13
|
|
14
14
|
def create_response_model(
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import logging
|
1
2
|
from abc import ABC, abstractmethod
|
2
3
|
from typing import Optional, Any, List, TypedDict
|
3
4
|
|
@@ -6,17 +7,17 @@ from pydantic import BaseModel, Field, field_validator, ValidationError
|
|
6
7
|
from edsl.exceptions.questions import QuestionAnswerValidationError
|
7
8
|
from edsl.questions.ExceptionExplainer import ExceptionExplainer
|
8
9
|
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
10
|
+
|
11
|
+
class BaseResponse(BaseModel):
|
12
|
+
answer: Any
|
13
|
+
comment: Optional[str] = None
|
14
|
+
generated_tokens: Optional[str] = None
|
13
15
|
|
14
16
|
|
15
17
|
class ResponseValidatorABC(ABC):
|
16
18
|
required_params: List[str] = []
|
17
19
|
|
18
20
|
def __init_subclass__(cls, **kwargs):
|
19
|
-
"""This is a metaclass that ensures that all subclasses of ResponseValidatorABC have the required class variables."""
|
20
21
|
super().__init_subclass__(**kwargs)
|
21
22
|
required_class_vars = ["required_params", "valid_examples", "invalid_examples"]
|
22
23
|
for var in required_class_vars:
|
@@ -51,7 +52,12 @@ class ResponseValidatorABC(ABC):
|
|
51
52
|
if not hasattr(self, "permissive"):
|
52
53
|
self.permissive = False
|
53
54
|
|
54
|
-
self.fixes_tried = 0
|
55
|
+
self.fixes_tried = 0
|
56
|
+
|
57
|
+
class RawEdslAnswerDict(TypedDict):
|
58
|
+
answer: Any
|
59
|
+
comment: Optional[str]
|
60
|
+
generated_tokens: Optional[str]
|
55
61
|
|
56
62
|
def _preprocess(self, data: RawEdslAnswerDict) -> RawEdslAnswerDict:
|
57
63
|
"""This is for testing purposes. A question can be given an exception to throw or an answer to always return.
|
@@ -83,6 +89,11 @@ class ResponseValidatorABC(ABC):
|
|
83
89
|
def post_validation_answer_convert(self, data):
|
84
90
|
return data
|
85
91
|
|
92
|
+
class EdslAnswerDict(TypedDict):
|
93
|
+
answer: Any
|
94
|
+
comment: Optional[str]
|
95
|
+
generated_tokens: Optional[str]
|
96
|
+
|
86
97
|
def validate(
|
87
98
|
self,
|
88
99
|
raw_edsl_answer_dict: RawEdslAnswerDict,
|
@@ -125,6 +136,7 @@ class ResponseValidatorABC(ABC):
|
|
125
136
|
def human_explanation(self, e: QuestionAnswerValidationError):
|
126
137
|
explanation = ExceptionExplainer(e, model_response=e.data).explain()
|
127
138
|
return explanation
|
139
|
+
# return e
|
128
140
|
|
129
141
|
def _handle_exception(self, e: Exception, raw_edsl_answer_dict) -> EdslAnswerDict:
|
130
142
|
if self.fixes_tried == 0:
|
@@ -1,10 +1,4 @@
|
|
1
|
-
from edsl.questions.data_structures import BaseModel
|
2
|
-
from edsl.questions.response_validator_abc import ResponseValidatorABC
|
3
|
-
|
4
|
-
|
5
1
|
class ResponseValidatorFactory:
|
6
|
-
"""Factory class to create a response validator for a question."""
|
7
|
-
|
8
2
|
def __init__(self, question):
|
9
3
|
self.question = question
|
10
4
|
|
@@ -16,7 +10,7 @@ class ResponseValidatorFactory:
|
|
16
10
|
return self.question.create_response_model()
|
17
11
|
|
18
12
|
@property
|
19
|
-
def response_validator(self) -> "
|
13
|
+
def response_validator(self) -> "ResponseValidatorBase":
|
20
14
|
"""Return the response validator."""
|
21
15
|
params = (
|
22
16
|
{
|
edsl/questions/SimpleAskMixin.py
CHANGED
@@ -66,7 +66,7 @@ class SimpleAskMixin:
|
|
66
66
|
system_prompt="You are a helpful agent pretending to be a human. Do not break character",
|
67
67
|
top_logprobs=4,
|
68
68
|
):
|
69
|
-
from edsl.language_models.
|
69
|
+
from edsl.language_models.registry import Model
|
70
70
|
|
71
71
|
if model is None:
|
72
72
|
model = Model()
|
edsl/questions/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
# Schemas
|
2
2
|
from edsl.questions.settings import Settings
|
3
|
-
from edsl.questions.
|
3
|
+
from edsl.questions.RegisterQuestionsMeta import RegisterQuestionsMeta
|
4
4
|
|
5
5
|
# Base Class
|
6
6
|
from edsl.questions.QuestionBase import QuestionBase
|
@@ -1,12 +1,12 @@
|
|
1
1
|
"""Mixin class for exporting results."""
|
2
2
|
|
3
|
+
import base64
|
4
|
+
import csv
|
3
5
|
import io
|
4
6
|
import warnings
|
5
7
|
import textwrap
|
6
8
|
from typing import Optional, Tuple, Union, List
|
7
9
|
|
8
|
-
from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
|
9
|
-
|
10
10
|
|
11
11
|
class DatasetExportMixin:
|
12
12
|
"""Mixin class for exporting Dataset objects."""
|
@@ -164,44 +164,79 @@ class DatasetExportMixin:
|
|
164
164
|
remove_prefix=remove_prefix, pretty_labels=pretty_labels
|
165
165
|
)
|
166
166
|
|
167
|
-
def to_jsonl(self, filename: Optional[str] = None) ->
|
168
|
-
"""Export the results to a FileStore instance containing JSONL data.
|
169
|
-
exporter = JSONLExport(data=self, filename=filename)
|
170
|
-
return exporter.export()
|
167
|
+
def to_jsonl(self, filename: Optional[str] = None) -> "FileStore":
|
168
|
+
"""Export the results to a FileStore instance containing JSONL data.
|
171
169
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
170
|
+
Args:
|
171
|
+
filename: Optional filename for the JSONL file (defaults to "results.jsonl")
|
172
|
+
|
173
|
+
Returns:
|
174
|
+
FileStore: Instance containing the JSONL data
|
175
|
+
"""
|
176
|
+
if filename is None:
|
177
|
+
filename = "results.jsonl"
|
178
|
+
|
179
|
+
# Write to string buffer
|
180
|
+
output = io.StringIO()
|
181
|
+
for entry in self:
|
182
|
+
key, values = list(entry.items())[0]
|
183
|
+
output.write(f'{{"{key}": {values}}}\n')
|
184
|
+
|
185
|
+
# Get the CSV string and encode to base64
|
186
|
+
jsonl_string = output.getvalue()
|
187
|
+
base64_string = base64.b64encode(jsonl_string.encode()).decode()
|
188
|
+
from edsl.scenarios.FileStore import FileStore
|
189
|
+
|
190
|
+
return FileStore(
|
191
|
+
path=filename,
|
192
|
+
mime_type="application/jsonl",
|
193
|
+
binary=False,
|
194
|
+
suffix="jsonl",
|
195
|
+
base64_string=base64_string,
|
188
196
|
)
|
189
|
-
return exporter.export()
|
190
197
|
|
191
198
|
def to_csv(
|
192
199
|
self,
|
193
200
|
filename: Optional[str] = None,
|
194
201
|
remove_prefix: bool = False,
|
195
202
|
pretty_labels: Optional[dict] = None,
|
196
|
-
) ->
|
197
|
-
"""Export the results to a FileStore instance containing CSV data.
|
198
|
-
|
199
|
-
|
200
|
-
filename
|
201
|
-
remove_prefix
|
202
|
-
pretty_labels
|
203
|
+
) -> "FileStore":
|
204
|
+
"""Export the results to a FileStore instance containing CSV data.
|
205
|
+
|
206
|
+
Args:
|
207
|
+
filename: Optional filename for the CSV (defaults to "results.csv")
|
208
|
+
remove_prefix: Whether to remove the prefix from column names
|
209
|
+
pretty_labels: Dictionary mapping original column names to pretty labels
|
210
|
+
|
211
|
+
Returns:
|
212
|
+
FileStore: Instance containing the CSV data
|
213
|
+
"""
|
214
|
+
if filename is None:
|
215
|
+
filename = "results.csv"
|
216
|
+
|
217
|
+
# Get the tabular data
|
218
|
+
header, rows = self._get_tabular_data(
|
219
|
+
remove_prefix=remove_prefix, pretty_labels=pretty_labels
|
220
|
+
)
|
221
|
+
|
222
|
+
# Write to string buffer
|
223
|
+
output = io.StringIO()
|
224
|
+
writer = csv.writer(output)
|
225
|
+
writer.writerow(header)
|
226
|
+
writer.writerows(rows)
|
227
|
+
|
228
|
+
# Get the CSV string and encode to base64
|
229
|
+
csv_string = output.getvalue()
|
230
|
+
base64_string = base64.b64encode(csv_string.encode()).decode()
|
231
|
+
from edsl.scenarios.FileStore import FileStore
|
232
|
+
|
233
|
+
return FileStore(
|
234
|
+
path=filename,
|
235
|
+
mime_type="text/csv",
|
236
|
+
binary=False,
|
237
|
+
suffix="csv",
|
238
|
+
base64_string=base64_string,
|
203
239
|
)
|
204
|
-
return exporter.export()
|
205
240
|
|
206
241
|
def to_excel(
|
207
242
|
self,
|
@@ -209,16 +244,60 @@ class DatasetExportMixin:
|
|
209
244
|
remove_prefix: bool = False,
|
210
245
|
pretty_labels: Optional[dict] = None,
|
211
246
|
sheet_name: Optional[str] = None,
|
212
|
-
) ->
|
213
|
-
"""Export the results to a FileStore instance containing Excel data.
|
214
|
-
|
215
|
-
|
216
|
-
filename
|
217
|
-
remove_prefix
|
218
|
-
pretty_labels
|
219
|
-
sheet_name
|
247
|
+
) -> "FileStore":
|
248
|
+
"""Export the results to a FileStore instance containing Excel data.
|
249
|
+
|
250
|
+
Args:
|
251
|
+
filename: Optional filename for the Excel file (defaults to "results.xlsx")
|
252
|
+
remove_prefix: Whether to remove the prefix from column names
|
253
|
+
pretty_labels: Dictionary mapping original column names to pretty labels
|
254
|
+
sheet_name: Name of the worksheet (defaults to "Results")
|
255
|
+
|
256
|
+
Returns:
|
257
|
+
FileStore: Instance containing the Excel data
|
258
|
+
"""
|
259
|
+
from openpyxl import Workbook
|
260
|
+
|
261
|
+
if filename is None:
|
262
|
+
filename = "results.xlsx"
|
263
|
+
if sheet_name is None:
|
264
|
+
sheet_name = "Results"
|
265
|
+
|
266
|
+
# Get the tabular data
|
267
|
+
header, rows = self._get_tabular_data(
|
268
|
+
remove_prefix=remove_prefix, pretty_labels=pretty_labels
|
269
|
+
)
|
270
|
+
|
271
|
+
# Create Excel workbook in memory
|
272
|
+
wb = Workbook()
|
273
|
+
ws = wb.active
|
274
|
+
ws.title = sheet_name
|
275
|
+
|
276
|
+
# Write header
|
277
|
+
for col, value in enumerate(header, 1):
|
278
|
+
ws.cell(row=1, column=col, value=value)
|
279
|
+
|
280
|
+
# Write data rows
|
281
|
+
for row_idx, row_data in enumerate(rows, 2):
|
282
|
+
for col, value in enumerate(row_data, 1):
|
283
|
+
ws.cell(row=row_idx, column=col, value=value)
|
284
|
+
|
285
|
+
# Save to bytes buffer
|
286
|
+
buffer = io.BytesIO()
|
287
|
+
wb.save(buffer)
|
288
|
+
buffer.seek(0)
|
289
|
+
|
290
|
+
# Convert to base64
|
291
|
+
base64_string = base64.b64encode(buffer.getvalue()).decode()
|
292
|
+
from edsl.scenarios.FileStore import FileStore
|
293
|
+
|
294
|
+
return FileStore(
|
295
|
+
path=filename,
|
296
|
+
mime_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
297
|
+
binary=True,
|
298
|
+
suffix="xlsx",
|
299
|
+
base64_string=base64_string,
|
220
300
|
)
|
221
|
-
return exporter.export()
|
222
301
|
|
223
302
|
def _db(self, remove_prefix: bool = True):
|
224
303
|
"""Create a SQLite database in memory and return the connection.
|
@@ -319,26 +398,6 @@ class DatasetExportMixin:
|
|
319
398
|
# df_sorted = df.sort_index(axis=1) # Sort columns alphabetically
|
320
399
|
return df
|
321
400
|
|
322
|
-
def to_polars(
|
323
|
-
self, remove_prefix: bool = False, lists_as_strings=False
|
324
|
-
) -> "pl.DataFrame":
|
325
|
-
"""Convert the results to a Polars DataFrame.
|
326
|
-
|
327
|
-
:param remove_prefix: Whether to remove the prefix from the column names.
|
328
|
-
"""
|
329
|
-
return self._to_polars_strings(remove_prefix)
|
330
|
-
|
331
|
-
def _to_polars_strings(self, remove_prefix: bool = False) -> "pl.DataFrame":
|
332
|
-
"""Convert the results to a Polars DataFrame.
|
333
|
-
|
334
|
-
:param remove_prefix: Whether to remove the prefix from the column names.
|
335
|
-
"""
|
336
|
-
import polars as pl
|
337
|
-
|
338
|
-
csv_string = self.to_csv(remove_prefix=remove_prefix).text
|
339
|
-
df = pl.read_csv(io.StringIO(csv_string))
|
340
|
-
return df
|
341
|
-
|
342
401
|
def to_scenario_list(self, remove_prefix: bool = True) -> list[dict]:
|
343
402
|
"""Convert the results to a list of dictionaries, one per scenario.
|
344
403
|
|
edsl/results/Result.py
CHANGED
@@ -173,9 +173,9 @@ class Result(Base, UserDict):
|
|
173
173
|
if question_name in self.question_to_attributes:
|
174
174
|
for dictionary_name in sub_dicts_needing_new_keys:
|
175
175
|
new_key = question_name + "_" + dictionary_name
|
176
|
-
sub_dicts_needing_new_keys[dictionary_name][
|
177
|
-
|
178
|
-
|
176
|
+
sub_dicts_needing_new_keys[dictionary_name][
|
177
|
+
new_key
|
178
|
+
] = self.question_to_attributes[question_name][dictionary_name]
|
179
179
|
|
180
180
|
new_cache_dict = {
|
181
181
|
f"{k}_cache_used": v for k, v in self.data["cache_used_dict"].items()
|
@@ -444,112 +444,6 @@ class Result(Base, UserDict):
|
|
444
444
|
raise ValueError(f"Parameter {k} not found in Result object")
|
445
445
|
return scoring_function(**params)
|
446
446
|
|
447
|
-
@classmethod
|
448
|
-
def from_interview(
|
449
|
-
cls, interview, extracted_answers, model_response_objects
|
450
|
-
) -> Result:
|
451
|
-
"""Return a Result object from an interview dictionary."""
|
452
|
-
|
453
|
-
def get_question_results(
|
454
|
-
model_response_objects,
|
455
|
-
) -> dict[str, "EDSLResultObjectInput"]:
|
456
|
-
"""Maps the question name to the EDSLResultObjectInput."""
|
457
|
-
question_results = {}
|
458
|
-
for result in model_response_objects:
|
459
|
-
question_results[result.question_name] = result
|
460
|
-
return question_results
|
461
|
-
|
462
|
-
def get_generated_tokens_dict(answer_key_names) -> dict[str, str]:
|
463
|
-
generated_tokens_dict = {
|
464
|
-
k + "_generated_tokens": question_results[k].generated_tokens
|
465
|
-
for k in answer_key_names
|
466
|
-
}
|
467
|
-
return generated_tokens_dict
|
468
|
-
|
469
|
-
def get_comments_dict(answer_key_names) -> dict[str, str]:
|
470
|
-
comments_dict = {
|
471
|
-
k + "_comment": question_results[k].comment for k in answer_key_names
|
472
|
-
}
|
473
|
-
return comments_dict
|
474
|
-
|
475
|
-
def get_question_name_to_prompts(
|
476
|
-
model_response_objects,
|
477
|
-
) -> dict[str, dict[str, str]]:
|
478
|
-
question_name_to_prompts = dict({})
|
479
|
-
for result in model_response_objects:
|
480
|
-
question_name = result.question_name
|
481
|
-
question_name_to_prompts[question_name] = {
|
482
|
-
"user_prompt": result.prompts["user_prompt"],
|
483
|
-
"system_prompt": result.prompts["system_prompt"],
|
484
|
-
}
|
485
|
-
return question_name_to_prompts
|
486
|
-
|
487
|
-
def get_prompt_dictionary(answer_key_names, question_name_to_prompts):
|
488
|
-
prompt_dictionary = {}
|
489
|
-
for answer_key_name in answer_key_names:
|
490
|
-
prompt_dictionary[answer_key_name + "_user_prompt"] = (
|
491
|
-
question_name_to_prompts[answer_key_name]["user_prompt"]
|
492
|
-
)
|
493
|
-
prompt_dictionary[answer_key_name + "_system_prompt"] = (
|
494
|
-
question_name_to_prompts[answer_key_name]["system_prompt"]
|
495
|
-
)
|
496
|
-
return prompt_dictionary
|
497
|
-
|
498
|
-
def get_raw_model_results_and_cache_used_dictionary(model_response_objects):
|
499
|
-
raw_model_results_dictionary = {}
|
500
|
-
cache_used_dictionary = {}
|
501
|
-
for result in model_response_objects:
|
502
|
-
question_name = result.question_name
|
503
|
-
raw_model_results_dictionary[question_name + "_raw_model_response"] = (
|
504
|
-
result.raw_model_response
|
505
|
-
)
|
506
|
-
raw_model_results_dictionary[question_name + "_cost"] = result.cost
|
507
|
-
one_use_buys = (
|
508
|
-
"NA"
|
509
|
-
if isinstance(result.cost, str)
|
510
|
-
or result.cost == 0
|
511
|
-
or result.cost is None
|
512
|
-
else 1.0 / result.cost
|
513
|
-
)
|
514
|
-
raw_model_results_dictionary[question_name + "_one_usd_buys"] = (
|
515
|
-
one_use_buys
|
516
|
-
)
|
517
|
-
cache_used_dictionary[question_name] = result.cache_used
|
518
|
-
|
519
|
-
return raw_model_results_dictionary, cache_used_dictionary
|
520
|
-
|
521
|
-
question_results = get_question_results(model_response_objects)
|
522
|
-
answer_key_names = list(question_results.keys())
|
523
|
-
generated_tokens_dict = get_generated_tokens_dict(answer_key_names)
|
524
|
-
comments_dict = get_comments_dict(answer_key_names)
|
525
|
-
answer_dict = {k: extracted_answers[k] for k in answer_key_names}
|
526
|
-
|
527
|
-
question_name_to_prompts = get_question_name_to_prompts(model_response_objects)
|
528
|
-
prompt_dictionary = get_prompt_dictionary(
|
529
|
-
answer_key_names, question_name_to_prompts
|
530
|
-
)
|
531
|
-
raw_model_results_dictionary, cache_used_dictionary = (
|
532
|
-
get_raw_model_results_and_cache_used_dictionary(model_response_objects)
|
533
|
-
)
|
534
|
-
|
535
|
-
result = cls(
|
536
|
-
agent=interview.agent,
|
537
|
-
scenario=interview.scenario,
|
538
|
-
model=interview.model,
|
539
|
-
iteration=interview.iteration,
|
540
|
-
# Computed objects
|
541
|
-
answer=answer_dict,
|
542
|
-
prompt=prompt_dictionary,
|
543
|
-
raw_model_response=raw_model_results_dictionary,
|
544
|
-
survey=interview.survey,
|
545
|
-
generated_tokens=generated_tokens_dict,
|
546
|
-
comments_dict=comments_dict,
|
547
|
-
cache_used_dict=cache_used_dictionary,
|
548
|
-
indices=interview.indices,
|
549
|
-
)
|
550
|
-
result.interview_hash = interview.initial_hash
|
551
|
-
return result
|
552
|
-
|
553
447
|
|
554
448
|
if __name__ == "__main__":
|
555
449
|
import doctest
|
edsl/results/Results.py
CHANGED
@@ -9,8 +9,6 @@ import random
|
|
9
9
|
from collections import UserList, defaultdict
|
10
10
|
from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
|
11
11
|
|
12
|
-
from bisect import bisect_left
|
13
|
-
|
14
12
|
from edsl.Base import Base
|
15
13
|
from edsl.exceptions.results import (
|
16
14
|
ResultsError,
|
@@ -26,7 +24,7 @@ if TYPE_CHECKING:
|
|
26
24
|
from edsl.surveys.Survey import Survey
|
27
25
|
from edsl.data.Cache import Cache
|
28
26
|
from edsl.agents.AgentList import AgentList
|
29
|
-
from edsl.language_models.
|
27
|
+
from edsl.language_models.registry import Model
|
30
28
|
from edsl.scenarios.ScenarioList import ScenarioList
|
31
29
|
from edsl.results.Result import Result
|
32
30
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
@@ -35,7 +33,7 @@ if TYPE_CHECKING:
|
|
35
33
|
|
36
34
|
from edsl.results.ResultsExportMixin import ResultsExportMixin
|
37
35
|
from edsl.results.ResultsGGMixin import ResultsGGMixin
|
38
|
-
from edsl.results.
|
36
|
+
from edsl.results.ResultsFetchMixin import ResultsFetchMixin
|
39
37
|
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
40
38
|
|
41
39
|
|
@@ -138,33 +136,7 @@ class Results(UserList, Mixins, Base):
|
|
138
136
|
}
|
139
137
|
return d
|
140
138
|
|
141
|
-
def
|
142
|
-
item_order = getattr(item, "order", None)
|
143
|
-
if item_order is not None:
|
144
|
-
# Get list of orders, putting None at the end
|
145
|
-
orders = [getattr(x, "order", None) for x in self]
|
146
|
-
# Filter to just the non-None orders for bisect
|
147
|
-
sorted_orders = [x for x in orders if x is not None]
|
148
|
-
if sorted_orders:
|
149
|
-
index = bisect_left(sorted_orders, item_order)
|
150
|
-
# Account for any None values before this position
|
151
|
-
index += orders[:index].count(None)
|
152
|
-
else:
|
153
|
-
# If no sorted items yet, insert before any unordered items
|
154
|
-
index = 0
|
155
|
-
self.data.insert(index, item)
|
156
|
-
else:
|
157
|
-
# No order - append to end
|
158
|
-
self.data.append(item)
|
159
|
-
|
160
|
-
def append(self, item):
|
161
|
-
self.insert(item)
|
162
|
-
|
163
|
-
def extend(self, other):
|
164
|
-
for item in other:
|
165
|
-
self.insert(item)
|
166
|
-
|
167
|
-
def compute_job_cost(self, include_cached_responses_in_cost: bool = False) -> float:
|
139
|
+
def compute_job_cost(self, include_cached_responses_in_cost=False) -> float:
|
168
140
|
"""
|
169
141
|
Computes the cost of a completed job in USD.
|
170
142
|
"""
|
@@ -278,6 +250,24 @@ class Results(UserList, Mixins, Base):
|
|
278
250
|
|
279
251
|
raise TypeError("Invalid argument type")
|
280
252
|
|
253
|
+
# def _update_results(self) -> None:
|
254
|
+
# from edsl import Agent, Scenario
|
255
|
+
# from edsl.language_models import LanguageModel
|
256
|
+
# from edsl.results import Result
|
257
|
+
|
258
|
+
# if self._job_uuid and len(self.data) < self._total_results:
|
259
|
+
# results = [
|
260
|
+
# Result(
|
261
|
+
# agent=Agent.from_dict(json.loads(r.agent)),
|
262
|
+
# scenario=Scenario.from_dict(json.loads(r.scenario)),
|
263
|
+
# model=LanguageModel.from_dict(json.loads(r.model)),
|
264
|
+
# iteration=1,
|
265
|
+
# answer=json.loads(r.answer),
|
266
|
+
# )
|
267
|
+
# for r in CRUD.read_results(self._job_uuid)
|
268
|
+
# ]
|
269
|
+
# self.data = results
|
270
|
+
|
281
271
|
def __add__(self, other: Results) -> Results:
|
282
272
|
"""Add two Results objects together.
|
283
273
|
They must have the same survey and created columns.
|
@@ -305,10 +295,13 @@ class Results(UserList, Mixins, Base):
|
|
305
295
|
)
|
306
296
|
|
307
297
|
def __repr__(self) -> str:
|
298
|
+
# import reprlib
|
299
|
+
|
308
300
|
return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
|
309
301
|
|
310
302
|
def table(
|
311
303
|
self,
|
304
|
+
# selector_string: Optional[str] = "*.*",
|
312
305
|
*fields,
|
313
306
|
tablefmt: Optional[str] = None,
|
314
307
|
pretty_labels: Optional[dict] = None,
|
@@ -347,11 +340,11 @@ class Results(UserList, Mixins, Base):
|
|
347
340
|
|
348
341
|
def to_dict(
|
349
342
|
self,
|
350
|
-
sort
|
351
|
-
add_edsl_version
|
352
|
-
include_cache
|
353
|
-
include_task_history
|
354
|
-
include_cache_info
|
343
|
+
sort=False,
|
344
|
+
add_edsl_version=False,
|
345
|
+
include_cache=False,
|
346
|
+
include_task_history=False,
|
347
|
+
include_cache_info=True,
|
355
348
|
) -> dict[str, Any]:
|
356
349
|
from edsl.data.Cache import Cache
|
357
350
|
|
@@ -393,7 +386,7 @@ class Results(UserList, Mixins, Base):
|
|
393
386
|
|
394
387
|
return d
|
395
388
|
|
396
|
-
def compare(self, other_results
|
389
|
+
def compare(self, other_results):
|
397
390
|
"""
|
398
391
|
Compare two Results objects and return the differences.
|
399
392
|
"""
|
@@ -411,7 +404,7 @@ class Results(UserList, Mixins, Base):
|
|
411
404
|
}
|
412
405
|
|
413
406
|
@property
|
414
|
-
def has_unfixed_exceptions(self)
|
407
|
+
def has_unfixed_exceptions(self):
|
415
408
|
return self.task_history.has_unfixed_exceptions
|
416
409
|
|
417
410
|
def __hash__(self) -> int:
|
@@ -494,6 +487,10 @@ class Results(UserList, Mixins, Base):
|
|
494
487
|
raise ResultsDeserializationError(f"Error in Results.from_dict: {e}")
|
495
488
|
return results
|
496
489
|
|
490
|
+
######################
|
491
|
+
## Convenience methods
|
492
|
+
## & Report methods
|
493
|
+
######################
|
497
494
|
@property
|
498
495
|
def _key_to_data_type(self) -> dict[str, str]:
|
499
496
|
"""
|
@@ -692,19 +689,13 @@ class Results(UserList, Mixins, Base):
|
|
692
689
|
"""
|
693
690
|
return self.data[0]
|
694
691
|
|
695
|
-
def answer_truncate(
|
696
|
-
self, column: str, top_n: int = 5, new_var_name: str = None
|
697
|
-
) -> Results:
|
692
|
+
def answer_truncate(self, column: str, top_n=5, new_var_name=None) -> Results:
|
698
693
|
"""Create a new variable that truncates the answers to the top_n.
|
699
694
|
|
700
695
|
:param column: The column to truncate.
|
701
696
|
:param top_n: The number of top answers to keep.
|
702
697
|
:param new_var_name: The name of the new variable. If None, it is the original name + '_truncated'.
|
703
698
|
|
704
|
-
Example:
|
705
|
-
>>> r = Results.example()
|
706
|
-
>>> r.answer_truncate('how_feeling', top_n = 2).select('how_feeling', 'how_feeling_truncated')
|
707
|
-
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling_truncated': ['Other', 'Other', 'Other', 'Other']}])
|
708
699
|
|
709
700
|
|
710
701
|
"""
|
@@ -925,7 +916,7 @@ class Results(UserList, Mixins, Base):
|
|
925
916
|
n: Optional[int] = None,
|
926
917
|
frac: Optional[float] = None,
|
927
918
|
with_replacement: bool = True,
|
928
|
-
seed: Optional[str] =
|
919
|
+
seed: Optional[str] = "edsl",
|
929
920
|
) -> Results:
|
930
921
|
"""Sample the results.
|
931
922
|
|
@@ -940,7 +931,7 @@ class Results(UserList, Mixins, Base):
|
|
940
931
|
>>> len(r.sample(2))
|
941
932
|
2
|
942
933
|
"""
|
943
|
-
if seed:
|
934
|
+
if seed != "edsl":
|
944
935
|
random.seed(seed)
|
945
936
|
|
946
937
|
if n is None and frac is None:
|
@@ -978,7 +969,7 @@ class Results(UserList, Mixins, Base):
|
|
978
969
|
Dataset([{'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
|
979
970
|
"""
|
980
971
|
|
981
|
-
from edsl.results.
|
972
|
+
from edsl.results.Selector import Selector
|
982
973
|
|
983
974
|
if len(self) == 0:
|
984
975
|
raise Exception("No data to select from---the Results object is empty.")
|
@@ -993,7 +984,6 @@ class Results(UserList, Mixins, Base):
|
|
993
984
|
return selector.select(*columns)
|
994
985
|
|
995
986
|
def sort_by(self, *columns: str, reverse: bool = False) -> Results:
|
996
|
-
"""Sort the results by one or more columns."""
|
997
987
|
import warnings
|
998
988
|
|
999
989
|
warnings.warn(
|
@@ -1002,7 +992,6 @@ class Results(UserList, Mixins, Base):
|
|
1002
992
|
return self.order_by(*columns, reverse=reverse)
|
1003
993
|
|
1004
994
|
def _parse_column(self, column: str) -> tuple[str, str]:
|
1005
|
-
"""Parse a column name into a data type and key."""
|
1006
995
|
if "." in column:
|
1007
996
|
return column.split(".")
|
1008
997
|
return self._key_to_data_type[column], column
|