edsl 0.1.46__py3-none-any.whl → 0.1.47__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +86 -19
- edsl/__version__.py +1 -1
- edsl/coop/coop.py +134 -53
- edsl/data/Cache.py +2 -0
- edsl/data/CacheEntry.py +10 -2
- edsl/inference_services/PerplexityService.py +9 -5
- edsl/jobs/Jobs.py +20 -0
- edsl/jobs/JobsComponentConstructor.py +2 -1
- edsl/language_models/LanguageModel.py +6 -6
- edsl/questions/QuestionBase.py +5 -0
- edsl/questions/question_registry.py +6 -7
- edsl/results/DatasetExportMixin.py +99 -2
- edsl/results/Results.py +59 -0
- edsl/scenarios/FileStore.py +112 -7
- edsl/scenarios/ScenarioList.py +130 -0
- edsl/study/Study.py +2 -2
- edsl/surveys/Survey.py +15 -20
- {edsl-0.1.46.dist-info → edsl-0.1.47.dist-info}/METADATA +3 -2
- {edsl-0.1.46.dist-info → edsl-0.1.47.dist-info}/RECORD +21 -33
- edsl/auto/AutoStudy.py +0 -130
- edsl/auto/StageBase.py +0 -243
- edsl/auto/StageGenerateSurvey.py +0 -178
- edsl/auto/StageLabelQuestions.py +0 -125
- edsl/auto/StagePersona.py +0 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
- edsl/auto/StagePersonaDimensionValues.py +0 -74
- edsl/auto/StagePersonaDimensions.py +0 -69
- edsl/auto/StageQuestions.py +0 -74
- edsl/auto/SurveyCreatorPipeline.py +0 -21
- edsl/auto/utilities.py +0 -218
- edsl/base/Base.py +0 -279
- {edsl-0.1.46.dist-info → edsl-0.1.47.dist-info}/LICENSE +0 -0
- {edsl-0.1.46.dist-info → edsl-0.1.47.dist-info}/WHEEL +0 -0
@@ -379,10 +379,10 @@ class LanguageModel(
|
|
379
379
|
cached_response, cache_key = cache.fetch(**cache_call_params)
|
380
380
|
|
381
381
|
if cache_used := cached_response is not None:
|
382
|
-
|
382
|
+
# print("cache used")
|
383
383
|
response = json.loads(cached_response)
|
384
384
|
else:
|
385
|
-
# print("cache not used")
|
385
|
+
# print("cache not used")
|
386
386
|
f = (
|
387
387
|
self.remote_async_execute_model_call
|
388
388
|
if hasattr(self, "remote") and self.remote
|
@@ -398,14 +398,14 @@ class LanguageModel(
|
|
398
398
|
TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
|
399
399
|
response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
|
400
400
|
new_cache_key = cache.store(
|
401
|
-
**cache_call_params, response=response
|
401
|
+
**cache_call_params, response=response, service=self._inference_service_
|
402
402
|
) # store the response in the cache
|
403
403
|
assert new_cache_key == cache_key # should be the same
|
404
404
|
|
405
|
-
#breakpoint()
|
405
|
+
# breakpoint()
|
406
406
|
|
407
407
|
cost = self.cost(response)
|
408
|
-
#breakpoint()
|
408
|
+
# breakpoint()
|
409
409
|
return ModelResponse(
|
410
410
|
response=response,
|
411
411
|
cache_used=cache_used,
|
@@ -470,7 +470,7 @@ class LanguageModel(
|
|
470
470
|
model_outputs=model_outputs,
|
471
471
|
edsl_dict=edsl_dict,
|
472
472
|
)
|
473
|
-
#breakpoint()
|
473
|
+
# breakpoint()
|
474
474
|
return agent_response_dict
|
475
475
|
|
476
476
|
get_response = sync_wrapper(async_get_response)
|
edsl/questions/QuestionBase.py
CHANGED
@@ -18,6 +18,7 @@ from edsl.questions.SimpleAskMixin import SimpleAskMixin
|
|
18
18
|
from edsl.questions.QuestionBasePromptsMixin import QuestionBasePromptsMixin
|
19
19
|
from edsl.questions.question_base_gen_mixin import QuestionBaseGenMixin
|
20
20
|
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
21
|
+
from edsl.utilities.utilities import is_valid_variable_name
|
21
22
|
|
22
23
|
if TYPE_CHECKING:
|
23
24
|
from edsl.questions.response_validator_abc import ResponseValidatorABC
|
@@ -56,6 +57,10 @@ class QuestionBase(
|
|
56
57
|
_answering_instructions = None
|
57
58
|
_question_presentation = None
|
58
59
|
|
60
|
+
def is_valid_question_name(self) -> bool:
|
61
|
+
"""Check if the question name is valid."""
|
62
|
+
return is_valid_variable_name(self.question_name)
|
63
|
+
|
59
64
|
@property
|
60
65
|
def response_validator(self) -> "ResponseValidatorABC":
|
61
66
|
"""Return the response validator."""
|
@@ -60,26 +60,25 @@ class Question(metaclass=Meta):
|
|
60
60
|
return q.example()
|
61
61
|
|
62
62
|
@classmethod
|
63
|
-
def pull(cls,
|
63
|
+
def pull(cls, url_or_uuid: Union[str, UUID]):
|
64
64
|
"""Pull the object from coop."""
|
65
65
|
from edsl.coop import Coop
|
66
66
|
|
67
67
|
coop = Coop()
|
68
|
-
return coop.get(
|
68
|
+
return coop.get(url_or_uuid, "question")
|
69
69
|
|
70
70
|
@classmethod
|
71
|
-
def delete(cls,
|
71
|
+
def delete(cls, url_or_uuid: Union[str, UUID]):
|
72
72
|
"""Delete the object from coop."""
|
73
73
|
from edsl.coop import Coop
|
74
74
|
|
75
75
|
coop = Coop()
|
76
|
-
return coop.delete(
|
76
|
+
return coop.delete(url_or_uuid)
|
77
77
|
|
78
78
|
@classmethod
|
79
79
|
def patch(
|
80
80
|
cls,
|
81
|
-
|
82
|
-
url: Optional[str] = None,
|
81
|
+
url_or_uuid: Union[str, UUID],
|
83
82
|
description: Optional[str] = None,
|
84
83
|
value: Optional[Any] = None,
|
85
84
|
visibility: Optional[str] = None,
|
@@ -88,7 +87,7 @@ class Question(metaclass=Meta):
|
|
88
87
|
from edsl.coop import Coop
|
89
88
|
|
90
89
|
coop = Coop()
|
91
|
-
return coop.patch(
|
90
|
+
return coop.patch(url_or_uuid, description, value, visibility)
|
92
91
|
|
93
92
|
@classmethod
|
94
93
|
def list_question_types(cls):
|
@@ -505,8 +505,9 @@ class DatasetExportMixin:
|
|
505
505
|
|
506
506
|
from edsl.utilities.PrettyList import PrettyList
|
507
507
|
|
508
|
-
return PrettyList(list_to_return)
|
509
|
-
|
508
|
+
#return PrettyList(list_to_return)
|
509
|
+
return list_to_return
|
510
|
+
|
510
511
|
def html(
|
511
512
|
self,
|
512
513
|
filename: Optional[str] = None,
|
@@ -903,6 +904,102 @@ class DatasetExportMixin:
|
|
903
904
|
result.data.pop(field_index)
|
904
905
|
|
905
906
|
return result
|
907
|
+
|
908
|
+
def drop(self, field_name):
|
909
|
+
"""
|
910
|
+
Returns a new Dataset with the specified field removed.
|
911
|
+
|
912
|
+
Args:
|
913
|
+
field_name (str): The name of the field to remove.
|
914
|
+
|
915
|
+
Returns:
|
916
|
+
Dataset: A new Dataset instance without the specified field.
|
917
|
+
|
918
|
+
Raises:
|
919
|
+
KeyError: If the field_name doesn't exist in the dataset.
|
920
|
+
|
921
|
+
Examples:
|
922
|
+
>>> from edsl.results.Dataset import Dataset
|
923
|
+
>>> d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
|
924
|
+
>>> d.drop('a')
|
925
|
+
Dataset([{'b': [4, 5, 6]}])
|
926
|
+
|
927
|
+
>>> d.drop('c')
|
928
|
+
Traceback (most recent call last):
|
929
|
+
...
|
930
|
+
KeyError: "Field 'c' not found in dataset"
|
931
|
+
"""
|
932
|
+
from edsl.results.Dataset import Dataset
|
933
|
+
|
934
|
+
# Check if field exists in the dataset
|
935
|
+
if field_name not in self.relevant_columns():
|
936
|
+
raise KeyError(f"Field '{field_name}' not found in dataset")
|
937
|
+
|
938
|
+
# Create a new dataset without the specified field
|
939
|
+
new_data = [entry for entry in self.data if field_name not in entry]
|
940
|
+
return Dataset(new_data)
|
941
|
+
|
942
|
+
def remove_prefix(self):
|
943
|
+
"""Returns a new Dataset with the prefix removed from all column names.
|
944
|
+
|
945
|
+
The prefix is defined as everything before the first dot (.) in the column name.
|
946
|
+
If removing prefixes would result in duplicate column names, an exception is raised.
|
947
|
+
|
948
|
+
Returns:
|
949
|
+
Dataset: A new Dataset with prefixes removed from column names
|
950
|
+
|
951
|
+
Raises:
|
952
|
+
ValueError: If removing prefixes would result in duplicate column names
|
953
|
+
|
954
|
+
Examples:
|
955
|
+
>>> from edsl.results import Results
|
956
|
+
>>> r = Results.example()
|
957
|
+
>>> r.select('how_feeling', 'how_feeling_yesterday').relevant_columns()
|
958
|
+
['answer.how_feeling', 'answer.how_feeling_yesterday']
|
959
|
+
>>> r.select('how_feeling', 'how_feeling_yesterday').remove_prefix().relevant_columns()
|
960
|
+
['how_feeling', 'how_feeling_yesterday']
|
961
|
+
|
962
|
+
>>> from edsl.results.Dataset import Dataset
|
963
|
+
>>> d = Dataset([{'a.x': [1, 2, 3]}, {'b.x': [4, 5, 6]}])
|
964
|
+
>>> d.remove_prefix()
|
965
|
+
Traceback (most recent call last):
|
966
|
+
...
|
967
|
+
ValueError: Removing prefixes would result in duplicate column names: ['x']
|
968
|
+
"""
|
969
|
+
from edsl.results.Dataset import Dataset
|
970
|
+
|
971
|
+
# Get all column names
|
972
|
+
columns = self.relevant_columns()
|
973
|
+
|
974
|
+
# Extract the unprefixed names
|
975
|
+
unprefixed = {}
|
976
|
+
duplicates = set()
|
977
|
+
|
978
|
+
for col in columns:
|
979
|
+
if '.' in col:
|
980
|
+
unprefixed_name = col.split('.', 1)[1]
|
981
|
+
if unprefixed_name in unprefixed:
|
982
|
+
duplicates.add(unprefixed_name)
|
983
|
+
unprefixed[unprefixed_name] = col
|
984
|
+
else:
|
985
|
+
# For columns without a prefix, keep them as is
|
986
|
+
unprefixed[col] = col
|
987
|
+
|
988
|
+
# Check for duplicates
|
989
|
+
if duplicates:
|
990
|
+
raise ValueError(f"Removing prefixes would result in duplicate column names: {sorted(list(duplicates))}")
|
991
|
+
|
992
|
+
# Create a new dataset with unprefixed column names
|
993
|
+
new_data = []
|
994
|
+
for entry in self.data:
|
995
|
+
key, values = list(entry.items())[0]
|
996
|
+
if '.' in key:
|
997
|
+
new_key = key.split('.', 1)[1]
|
998
|
+
else:
|
999
|
+
new_key = key
|
1000
|
+
new_data.append({new_key: values})
|
1001
|
+
|
1002
|
+
return Dataset(new_data)
|
906
1003
|
|
907
1004
|
|
908
1005
|
if __name__ == "__main__":
|
edsl/results/Results.py
CHANGED
@@ -1379,6 +1379,65 @@ class Results(UserList, Mixins, Base):
|
|
1379
1379
|
raise ResultsError(f"Failed to fetch remote results: {str(e)}")
|
1380
1380
|
|
1381
1381
|
|
1382
|
+
def spot_issues(self, models: Optional[ModelList] = None) -> Results:
|
1383
|
+
"""Run a survey to spot issues and suggest improvements for prompts that had no model response, returning a new Results object.
|
1384
|
+
Future version: Allow user to optionally pass a list of questions to review, regardless of whether they had a null model response.
|
1385
|
+
"""
|
1386
|
+
from edsl.questions import QuestionFreeText, QuestionDict
|
1387
|
+
from edsl.surveys import Survey
|
1388
|
+
from edsl.scenarios import Scenario, ScenarioList
|
1389
|
+
from edsl.language_models import Model, ModelList
|
1390
|
+
import pandas as pd
|
1391
|
+
|
1392
|
+
df = self.select("agent.*", "scenario.*", "answer.*", "raw_model_response.*", "prompt.*").to_pandas()
|
1393
|
+
scenario_list = []
|
1394
|
+
|
1395
|
+
for _, row in df.iterrows():
|
1396
|
+
for col in df.columns:
|
1397
|
+
if col.endswith("_raw_model_response") and pd.isna(row[col]):
|
1398
|
+
q = col.split("_raw_model_response")[0].replace("raw_model_response.", "")
|
1399
|
+
|
1400
|
+
s = Scenario({
|
1401
|
+
"original_question": q,
|
1402
|
+
"original_agent_index": row["agent.agent_index"],
|
1403
|
+
"original_scenario_index": row["scenario.scenario_index"],
|
1404
|
+
"original_prompts": f"User prompt: {row[f'prompt.{q}_user_prompt']}\nSystem prompt: {row[f'prompt.{q}_system_prompt']}"
|
1405
|
+
})
|
1406
|
+
|
1407
|
+
scenario_list.append(s)
|
1408
|
+
|
1409
|
+
sl = ScenarioList(set(scenario_list))
|
1410
|
+
|
1411
|
+
q1 = QuestionFreeText(
|
1412
|
+
question_name = "issues",
|
1413
|
+
question_text = """
|
1414
|
+
The following prompts generated a bad or null response: '{{ original_prompts }}'
|
1415
|
+
What do you think was the likely issue(s)?
|
1416
|
+
"""
|
1417
|
+
)
|
1418
|
+
|
1419
|
+
q2 = QuestionDict(
|
1420
|
+
question_name = "revised",
|
1421
|
+
question_text = """
|
1422
|
+
The following prompts generated a bad or null response: '{{ original_prompts }}'
|
1423
|
+
You identified the issue(s) as '{{ issues.answer }}'.
|
1424
|
+
Please revise the prompts to address the issue(s).
|
1425
|
+
""",
|
1426
|
+
answer_keys = ["revised_user_prompt", "revised_system_prompt"]
|
1427
|
+
)
|
1428
|
+
|
1429
|
+
survey = Survey(questions = [q1, q2])
|
1430
|
+
|
1431
|
+
if models is not None:
|
1432
|
+
if not isinstance(models, ModelList):
|
1433
|
+
raise ResultsError("models must be a ModelList")
|
1434
|
+
results = survey.by(sl).by(models).run()
|
1435
|
+
else:
|
1436
|
+
results = survey.by(sl).run() # use the default model
|
1437
|
+
|
1438
|
+
return results
|
1439
|
+
|
1440
|
+
|
1382
1441
|
def main(): # pragma: no cover
|
1383
1442
|
"""Call the OpenAI API credits."""
|
1384
1443
|
from edsl.results.Results import Results
|
edsl/scenarios/FileStore.py
CHANGED
@@ -11,6 +11,10 @@ from edsl.utilities.remove_edsl_version import remove_edsl_version
|
|
11
11
|
from edsl.scenarios.file_methods import FileMethods
|
12
12
|
from typing import Union
|
13
13
|
from uuid import UUID
|
14
|
+
import time
|
15
|
+
from typing import Dict, Any, IO, Optional, List, Union, Literal
|
16
|
+
|
17
|
+
|
14
18
|
|
15
19
|
class FileStore(Scenario):
|
16
20
|
__documentation__ = "https://docs.expectedparrot.com/en/latest/filestore.html"
|
@@ -30,7 +34,7 @@ class FileStore(Scenario):
|
|
30
34
|
path = kwargs["filename"]
|
31
35
|
|
32
36
|
# Check if path is a URL and handle download
|
33
|
-
if path and (path.startswith(
|
37
|
+
if path and (path.startswith("http://") or path.startswith("https://")):
|
34
38
|
temp_filestore = self.from_url(path, mime_type=mime_type)
|
35
39
|
path = temp_filestore._path
|
36
40
|
mime_type = temp_filestore.mime_type
|
@@ -91,6 +95,102 @@ class FileStore(Scenario):
|
|
91
95
|
else:
|
92
96
|
print(f"Example for {example_type} is not supported.")
|
93
97
|
|
98
|
+
@classmethod
|
99
|
+
async def _async_screenshot(
|
100
|
+
cls,
|
101
|
+
url: str,
|
102
|
+
full_page: bool = True,
|
103
|
+
wait_until: Literal[
|
104
|
+
"load", "domcontentloaded", "networkidle", "commit"
|
105
|
+
] = "networkidle",
|
106
|
+
download_path: Optional[str] = None,
|
107
|
+
) -> "FileStore":
|
108
|
+
"""Async version of screenshot functionality"""
|
109
|
+
try:
|
110
|
+
from playwright.async_api import async_playwright
|
111
|
+
except ImportError:
|
112
|
+
raise ImportError(
|
113
|
+
"Screenshot functionality requires additional dependencies.\n"
|
114
|
+
"Install them with: pip install 'edsl[screenshot]'"
|
115
|
+
)
|
116
|
+
|
117
|
+
if download_path is None:
|
118
|
+
download_path = os.path.join(
|
119
|
+
os.getcwd(), f"screenshot_{int(time.time())}.png"
|
120
|
+
)
|
121
|
+
|
122
|
+
async with async_playwright() as p:
|
123
|
+
browser = await p.chromium.launch()
|
124
|
+
page = await browser.new_page()
|
125
|
+
await page.goto(url, wait_until=wait_until)
|
126
|
+
await page.screenshot(path=download_path, full_page=full_page)
|
127
|
+
await browser.close()
|
128
|
+
|
129
|
+
return cls(download_path, mime_type="image/png")
|
130
|
+
|
131
|
+
@classmethod
|
132
|
+
def from_url_screenshot(cls, url: str, **kwargs) -> "FileStore":
|
133
|
+
"""Synchronous wrapper for screenshot functionality"""
|
134
|
+
import asyncio
|
135
|
+
|
136
|
+
try:
|
137
|
+
# Try using get_event_loop first (works in regular Python)
|
138
|
+
loop = asyncio.get_event_loop()
|
139
|
+
except RuntimeError:
|
140
|
+
# If we're in IPython/Jupyter, create a new loop
|
141
|
+
loop = asyncio.new_event_loop()
|
142
|
+
asyncio.set_event_loop(loop)
|
143
|
+
|
144
|
+
try:
|
145
|
+
return loop.run_until_complete(cls._async_screenshot(url, **kwargs))
|
146
|
+
finally:
|
147
|
+
if not loop.is_running():
|
148
|
+
loop.close()
|
149
|
+
|
150
|
+
@classmethod
|
151
|
+
def batch_screenshots(cls, urls: List[str], **kwargs) -> "ScenarioList":
|
152
|
+
"""
|
153
|
+
Take screenshots of multiple URLs concurrently.
|
154
|
+
Args:
|
155
|
+
urls: List of URLs to screenshot
|
156
|
+
**kwargs: Additional arguments passed to screenshot function (full_page, wait_until, etc.)
|
157
|
+
Returns:
|
158
|
+
ScenarioList containing FileStore objects with their corresponding URLs
|
159
|
+
"""
|
160
|
+
from edsl import ScenarioList
|
161
|
+
|
162
|
+
try:
|
163
|
+
# Try using get_event_loop first (works in regular Python)
|
164
|
+
loop = asyncio.get_event_loop()
|
165
|
+
except RuntimeError:
|
166
|
+
# If we're in IPython/Jupyter, create a new loop
|
167
|
+
loop = asyncio.new_event_loop()
|
168
|
+
asyncio.set_event_loop(loop)
|
169
|
+
|
170
|
+
# Create tasks for all screenshots
|
171
|
+
tasks = [cls._async_screenshot(url, **kwargs) for url in urls]
|
172
|
+
|
173
|
+
try:
|
174
|
+
# Run all screenshots concurrently
|
175
|
+
results = loop.run_until_complete(
|
176
|
+
asyncio.gather(*tasks, return_exceptions=True)
|
177
|
+
)
|
178
|
+
|
179
|
+
# Filter out any errors and log them
|
180
|
+
successful_results = []
|
181
|
+
for url, result in zip(urls, results):
|
182
|
+
if isinstance(result, Exception):
|
183
|
+
print(f"Failed to screenshot {url}: {result}")
|
184
|
+
else:
|
185
|
+
successful_results.append(
|
186
|
+
Scenario({"url": url, "screenshot": result})
|
187
|
+
)
|
188
|
+
|
189
|
+
return ScenarioList(successful_results)
|
190
|
+
finally:
|
191
|
+
if not loop.is_running():
|
192
|
+
loop.close()
|
193
|
+
|
94
194
|
@property
|
95
195
|
def size(self) -> int:
|
96
196
|
if self.base64_string != None:
|
@@ -273,12 +373,11 @@ class FileStore(Scenario):
|
|
273
373
|
# raise TypeError("No text method found for this file type.")
|
274
374
|
|
275
375
|
def push(
|
276
|
-
self,
|
277
|
-
description: Optional[str] = None,
|
376
|
+
self,
|
377
|
+
description: Optional[str] = None,
|
278
378
|
alias: Optional[str] = None,
|
279
379
|
visibility: Optional[str] = "unlisted",
|
280
380
|
expected_parrot_url: Optional[str] = None,
|
281
|
-
|
282
381
|
) -> dict:
|
283
382
|
"""
|
284
383
|
Push the object to Coop.
|
@@ -286,20 +385,26 @@ class FileStore(Scenario):
|
|
286
385
|
:param visibility: The visibility of the object to push.
|
287
386
|
"""
|
288
387
|
scenario_version = Scenario.from_dict(self.to_dict())
|
388
|
+
|
289
389
|
if description is None:
|
290
390
|
description = "File: " + self.path
|
291
|
-
info = scenario_version.push(
|
391
|
+
info = scenario_version.push(
|
392
|
+
description=description,
|
393
|
+
visibility=visibility,
|
394
|
+
expected_parrot_url=expected_parrot_url,
|
395
|
+
alias=alias,
|
396
|
+
)
|
292
397
|
return info
|
293
398
|
|
294
399
|
@classmethod
|
295
400
|
def pull(cls, url_or_uuid: Union[str, UUID]) -> "FileStore":
|
296
401
|
"""
|
297
402
|
Pull a FileStore object from Coop.
|
298
|
-
|
403
|
+
|
299
404
|
Args:
|
300
405
|
url_or_uuid: Either a UUID string or a URL pointing to the object
|
301
406
|
expected_parrot_url: Optional URL for the Parrot server
|
302
|
-
|
407
|
+
|
303
408
|
Returns:
|
304
409
|
FileStore: The pulled FileStore object
|
305
410
|
"""
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -1156,6 +1156,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1156
1156
|
|
1157
1157
|
return scenario_list
|
1158
1158
|
|
1159
|
+
@classmethod
|
1159
1160
|
def from_wikipedia(cls, url: str, table_index: int = 0):
|
1160
1161
|
"""
|
1161
1162
|
Extracts a table from a Wikipedia page.
|
@@ -1672,6 +1673,135 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1672
1673
|
|
1673
1674
|
return ScenarioList(result)
|
1674
1675
|
|
1676
|
+
def create_comparisons(
|
1677
|
+
self,
|
1678
|
+
bidirectional: bool = False,
|
1679
|
+
num_options: int = 2,
|
1680
|
+
option_prefix: str = "option_",
|
1681
|
+
use_alphabet: bool = False
|
1682
|
+
) -> ScenarioList:
|
1683
|
+
"""Create a new ScenarioList with comparisons between scenarios.
|
1684
|
+
|
1685
|
+
Each scenario in the result contains multiple original scenarios as dictionaries,
|
1686
|
+
allowing for side-by-side comparison.
|
1687
|
+
|
1688
|
+
Args:
|
1689
|
+
bidirectional (bool): If True, include both (A,B) and (B,A) comparisons.
|
1690
|
+
If False, only include (A,B) where A comes before B in the original list.
|
1691
|
+
num_options (int): Number of scenarios to include in each comparison.
|
1692
|
+
Default is 2 for pairwise comparisons.
|
1693
|
+
option_prefix (str): Prefix for the keys in the resulting scenarios.
|
1694
|
+
Default is "option_", resulting in keys like "option_1", "option_2", etc.
|
1695
|
+
Ignored if use_alphabet is True.
|
1696
|
+
use_alphabet (bool): If True, use letters as keys (A, B, C, etc.) instead of
|
1697
|
+
the option_prefix with numbers.
|
1698
|
+
|
1699
|
+
Returns:
|
1700
|
+
ScenarioList: A new ScenarioList where each scenario contains multiple original
|
1701
|
+
scenarios as dictionaries.
|
1702
|
+
|
1703
|
+
Example:
|
1704
|
+
>>> s = ScenarioList([
|
1705
|
+
... Scenario({'id': 1, 'text': 'Option A'}),
|
1706
|
+
... Scenario({'id': 2, 'text': 'Option B'}),
|
1707
|
+
... Scenario({'id': 3, 'text': 'Option C'})
|
1708
|
+
... ])
|
1709
|
+
>>> s.create_comparisons(use_alphabet=True)
|
1710
|
+
ScenarioList([Scenario({'A': {'id': 1, 'text': 'Option A'}, 'B': {'id': 2, 'text': 'Option B'}}), Scenario({'A': {'id': 1, 'text': 'Option A'}, 'B': {'id': 3, 'text': 'Option C'}}), Scenario({'A': {'id': 2, 'text': 'Option B'}, 'B': {'id': 3, 'text': 'Option C'}})])
|
1711
|
+
>>> s.create_comparisons(num_options=3, use_alphabet=True)
|
1712
|
+
ScenarioList([Scenario({'A': {'id': 1, 'text': 'Option A'}, 'B': {'id': 2, 'text': 'Option B'}, 'C': {'id': 3, 'text': 'Option C'}})])
|
1713
|
+
"""
|
1714
|
+
from itertools import combinations, permutations
|
1715
|
+
import string
|
1716
|
+
|
1717
|
+
if num_options < 2:
|
1718
|
+
raise ValueError("num_options must be at least 2")
|
1719
|
+
|
1720
|
+
if num_options > len(self):
|
1721
|
+
raise ValueError(f"num_options ({num_options}) cannot exceed the number of scenarios ({len(self)})")
|
1722
|
+
|
1723
|
+
if use_alphabet and num_options > 26:
|
1724
|
+
raise ValueError("When using alphabet labels, num_options cannot exceed 26 (the number of letters in the English alphabet)")
|
1725
|
+
|
1726
|
+
# Convert each scenario to a dictionary
|
1727
|
+
scenario_dicts = [scenario.to_dict(add_edsl_version=False) for scenario in self]
|
1728
|
+
|
1729
|
+
# Generate combinations or permutations based on bidirectional flag
|
1730
|
+
if bidirectional:
|
1731
|
+
# For bidirectional, use permutations to get all ordered arrangements
|
1732
|
+
if num_options == 2:
|
1733
|
+
# For pairwise, we can use permutations with r=2
|
1734
|
+
scenario_groups = permutations(scenario_dicts, 2)
|
1735
|
+
else:
|
1736
|
+
# For more than 2 options with bidirectional=True,
|
1737
|
+
# we need all permutations of the specified size
|
1738
|
+
scenario_groups = permutations(scenario_dicts, num_options)
|
1739
|
+
else:
|
1740
|
+
# For unidirectional, use combinations to get unordered groups
|
1741
|
+
scenario_groups = combinations(scenario_dicts, num_options)
|
1742
|
+
|
1743
|
+
# Create new scenarios with the combinations
|
1744
|
+
result = []
|
1745
|
+
for group in scenario_groups:
|
1746
|
+
new_scenario = {}
|
1747
|
+
for i, scenario_dict in enumerate(group):
|
1748
|
+
if use_alphabet:
|
1749
|
+
# Use uppercase letters (A, B, C, etc.)
|
1750
|
+
key = string.ascii_uppercase[i]
|
1751
|
+
else:
|
1752
|
+
# Use the option prefix with numbers (option_1, option_2, etc.)
|
1753
|
+
key = f"{option_prefix}{i+1}"
|
1754
|
+
new_scenario[key] = scenario_dict
|
1755
|
+
result.append(Scenario(new_scenario))
|
1756
|
+
|
1757
|
+
return ScenarioList(result)
|
1758
|
+
|
1759
|
+
@classmethod
|
1760
|
+
def from_parquet(cls, filepath: str) -> ScenarioList:
|
1761
|
+
"""Create a ScenarioList from a Parquet file.
|
1762
|
+
|
1763
|
+
Args:
|
1764
|
+
filepath (str): Path to the Parquet file
|
1765
|
+
|
1766
|
+
Returns:
|
1767
|
+
ScenarioList: A ScenarioList containing the data from the Parquet file
|
1768
|
+
|
1769
|
+
Example:
|
1770
|
+
>>> import pandas as pd
|
1771
|
+
>>> import tempfile
|
1772
|
+
>>> df = pd.DataFrame({'name': ['Alice', 'Bob'], 'age': [30, 25]})
|
1773
|
+
>>> # The following would create and read a parquet file if dependencies are installed:
|
1774
|
+
>>> # with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as f:
|
1775
|
+
>>> # df.to_parquet(f.name)
|
1776
|
+
>>> # scenario_list = ScenarioList.from_parquet(f.name)
|
1777
|
+
>>> # Instead, we'll demonstrate the equivalent result:
|
1778
|
+
>>> scenario_list = ScenarioList.from_pandas(df)
|
1779
|
+
>>> len(scenario_list)
|
1780
|
+
2
|
1781
|
+
>>> scenario_list[0]['name']
|
1782
|
+
'Alice'
|
1783
|
+
"""
|
1784
|
+
import pandas as pd
|
1785
|
+
|
1786
|
+
try:
|
1787
|
+
# Try to read the Parquet file with pandas
|
1788
|
+
df = pd.read_parquet(filepath)
|
1789
|
+
except ImportError as e:
|
1790
|
+
# Handle missing dependencies with a helpful error message
|
1791
|
+
if "pyarrow" in str(e) or "fastparquet" in str(e):
|
1792
|
+
raise ImportError(
|
1793
|
+
"Missing dependencies for Parquet support. Please install either pyarrow or fastparquet:\n"
|
1794
|
+
" pip install pyarrow\n"
|
1795
|
+
" or\n"
|
1796
|
+
" pip install fastparquet"
|
1797
|
+
) from e
|
1798
|
+
else:
|
1799
|
+
raise
|
1800
|
+
|
1801
|
+
# Convert the DataFrame to a ScenarioList
|
1802
|
+
return cls.from_pandas(df)
|
1803
|
+
|
1804
|
+
|
1675
1805
|
|
1676
1806
|
if __name__ == "__main__":
|
1677
1807
|
import doctest
|
edsl/study/Study.py
CHANGED
@@ -504,12 +504,12 @@ class Study:
|
|
504
504
|
)
|
505
505
|
|
506
506
|
@classmethod
|
507
|
-
def pull(cls,
|
507
|
+
def pull(cls, url_or_uuid: Union[str, UUID]):
|
508
508
|
"""Pull the object from coop."""
|
509
509
|
from edsl.coop import Coop
|
510
510
|
|
511
511
|
coop = Coop()
|
512
|
-
return coop.get(
|
512
|
+
return coop.get(url_or_uuid, "study")
|
513
513
|
|
514
514
|
def __repr__(self):
|
515
515
|
return f"""Study(name = "{self.name}", description = "{self.description}", objects = {self.objects}, cache = {self.cache}, filename = "{self.filename}", coop = {self.coop}, use_study_cache = {self.use_study_cache}, overwrite_on_change = {self.overwrite_on_change})"""
|
edsl/surveys/Survey.py
CHANGED
@@ -172,6 +172,13 @@ class Survey(SurveyExportMixin, Base):
|
|
172
172
|
|
173
173
|
self._seed = None
|
174
174
|
|
175
|
+
# Cache the InstructionCollection
|
176
|
+
self._cached_instruction_collection = None
|
177
|
+
|
178
|
+
def question_names_valid(self) -> bool:
|
179
|
+
"""Check if the question names are valid."""
|
180
|
+
return all(q.is_valid_question_name() for q in self.questions)
|
181
|
+
|
175
182
|
def draw(self) -> "Survey":
|
176
183
|
"""Return a new survey with a randomly selected permutation of the options."""
|
177
184
|
if self._seed is None: # only set once
|
@@ -205,28 +212,16 @@ class Survey(SurveyExportMixin, Base):
|
|
205
212
|
# region: Survey instruction handling
|
206
213
|
@property
|
207
214
|
def _relevant_instructions_dict(self) -> InstructionCollection:
|
208
|
-
"""Return a dictionary with keys as question names and values as instructions that are relevant to the question.
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
"""
|
215
|
-
return InstructionCollection(
|
216
|
-
self._instruction_names_to_instructions, self.questions
|
217
|
-
)
|
215
|
+
"""Return a dictionary with keys as question names and values as instructions that are relevant to the question."""
|
216
|
+
if self._cached_instruction_collection is None:
|
217
|
+
self._cached_instruction_collection = InstructionCollection(
|
218
|
+
self._instruction_names_to_instructions, self.questions
|
219
|
+
)
|
220
|
+
return self._cached_instruction_collection
|
218
221
|
|
219
222
|
def _relevant_instructions(self, question: QuestionBase) -> dict:
|
220
|
-
"""
|
221
|
-
|
222
|
-
:param question: The question to get the relevant instructions for.
|
223
|
-
|
224
|
-
# Did the instruction come before the question and was it not modified by a change instruction?
|
225
|
-
|
226
|
-
"""
|
227
|
-
return InstructionCollection(
|
228
|
-
self._instruction_names_to_instructions, self.questions
|
229
|
-
)[question]
|
223
|
+
"""Return instructions that are relevant to the question."""
|
224
|
+
return self._relevant_instructions_dict[question]
|
230
225
|
|
231
226
|
def show_flow(self, filename: Optional[str] = None) -> None:
|
232
227
|
"""Show the flow of the survey."""
|