edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +169 -116
- edsl/__init__.py +14 -6
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +358 -146
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +88 -36
- edsl/agents/InvigilatorBase.py +59 -70
- edsl/agents/PromptConstructor.py +117 -219
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionOptionProcessor.py +172 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +104 -42
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +21 -14
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +33 -12
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +20 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +0 -3
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +209 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +2 -11
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +105 -80
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +1 -4
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -8
- edsl/inference_services/data_structures.py +62 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +40 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +48 -0
- edsl/jobs/Jobs.py +102 -243
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +5 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +77 -380
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +14 -15
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +137 -234
- edsl/language_models/ModelList.py +11 -13
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +0 -1
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/registry.py +49 -59
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/AnswerValidatorMixin.py +47 -2
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/LoopProcessor.py +149 -0
- edsl/questions/QuestionBase.py +37 -192
- edsl/questions/QuestionBaseGenMixin.py +52 -48
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionExtract.py +1 -1
- edsl/questions/QuestionFreeText.py +1 -2
- edsl/questions/QuestionList.py +3 -5
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +66 -22
- edsl/questions/QuestionNumerical.py +1 -3
- edsl/questions/QuestionRank.py +6 -16
- edsl/questions/ResponseValidatorABC.py +37 -11
- edsl/questions/ResponseValidatorFactory.py +28 -0
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +1 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/question_registry.py +1 -1
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +224 -302
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +192 -206
- edsl/results/Results.py +120 -113
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/Selector.py +23 -13
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DirectoryScanner.py +96 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +118 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioJoin.py +10 -6
- edsl/scenarios/ScenarioList.py +383 -240
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/ScenarioSelector.py +156 -0
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +38 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +199 -771
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
- edsl-0.1.39.dev2.dist-info/RECORD +352 -0
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.39.dev1.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
edsl/scenarios/ScenarioList.py
CHANGED
@@ -1,33 +1,78 @@
|
|
1
1
|
"""A list of Scenarios to be used in a survey."""
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
|
-
from typing import
|
4
|
+
from typing import (
|
5
|
+
Any,
|
6
|
+
Optional,
|
7
|
+
Union,
|
8
|
+
List,
|
9
|
+
Callable,
|
10
|
+
Literal,
|
11
|
+
TYPE_CHECKING,
|
12
|
+
)
|
13
|
+
|
14
|
+
try:
|
15
|
+
from typing import TypeAlias
|
16
|
+
except ImportError:
|
17
|
+
from typing_extensions import TypeAlias
|
18
|
+
|
5
19
|
import csv
|
6
20
|
import random
|
7
|
-
from collections import UserList, Counter
|
8
|
-
from collections.abc import Iterable
|
9
|
-
import urllib.parse
|
10
|
-
import urllib.request
|
11
21
|
from io import StringIO
|
12
|
-
from collections import defaultdict
|
13
22
|
import inspect
|
23
|
+
from collections import UserList, defaultdict
|
24
|
+
from collections.abc import Iterable
|
25
|
+
|
26
|
+
if TYPE_CHECKING:
|
27
|
+
from urllib.parse import ParseResult
|
28
|
+
from edsl.results.Dataset import Dataset
|
29
|
+
from edsl.jobs.Jobs import Jobs
|
30
|
+
from edsl.surveys.Survey import Survey
|
31
|
+
from edsl.questions.QuestionBase import QuestionBase
|
14
32
|
|
15
|
-
|
33
|
+
|
34
|
+
from simpleeval import EvalWithCompoundTypes, NameNotDefined # type: ignore
|
35
|
+
|
36
|
+
from tabulate import tabulate_formats
|
16
37
|
|
17
38
|
from edsl.Base import Base
|
18
|
-
from edsl.utilities.
|
39
|
+
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
40
|
+
|
19
41
|
from edsl.scenarios.Scenario import Scenario
|
20
42
|
from edsl.scenarios.ScenarioListPdfMixin import ScenarioListPdfMixin
|
21
43
|
from edsl.scenarios.ScenarioListExportMixin import ScenarioListExportMixin
|
22
|
-
|
23
44
|
from edsl.utilities.naming_utilities import sanitize_string
|
24
|
-
from edsl.utilities.
|
45
|
+
from edsl.utilities.is_valid_variable_name import is_valid_variable_name
|
46
|
+
from edsl.exceptions.scenarios import ScenarioError
|
47
|
+
|
48
|
+
from edsl.scenarios.DirectoryScanner import DirectoryScanner
|
25
49
|
|
26
50
|
|
27
51
|
class ScenarioListMixin(ScenarioListPdfMixin, ScenarioListExportMixin):
|
28
52
|
pass
|
29
53
|
|
30
54
|
|
55
|
+
if TYPE_CHECKING:
|
56
|
+
from edsl.results.Dataset import Dataset
|
57
|
+
|
58
|
+
TableFormat: TypeAlias = Literal[
|
59
|
+
"plain",
|
60
|
+
"simple",
|
61
|
+
"github",
|
62
|
+
"grid",
|
63
|
+
"fancy_grid",
|
64
|
+
"pipe",
|
65
|
+
"orgtbl",
|
66
|
+
"rst",
|
67
|
+
"mediawiki",
|
68
|
+
"html",
|
69
|
+
"latex",
|
70
|
+
"latex_raw",
|
71
|
+
"latex_booktabs",
|
72
|
+
"tsv",
|
73
|
+
]
|
74
|
+
|
75
|
+
|
31
76
|
class ScenarioList(Base, UserList, ScenarioListMixin):
|
32
77
|
"""Class for creating a list of scenarios to be used in a survey."""
|
33
78
|
|
@@ -35,7 +80,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
35
80
|
"https://docs.expectedparrot.com/en/latest/scenarios.html#scenariolist"
|
36
81
|
)
|
37
82
|
|
38
|
-
def __init__(
|
83
|
+
def __init__(
|
84
|
+
self, data: Optional[list] = None, codebook: Optional[dict[str, str]] = None
|
85
|
+
):
|
39
86
|
"""Initialize the ScenarioList class."""
|
40
87
|
if data is not None:
|
41
88
|
super().__init__(data)
|
@@ -57,12 +104,19 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
57
104
|
"""Check if the ScenarioList has Jinja braces."""
|
58
105
|
return any([scenario.has_jinja_braces for scenario in self])
|
59
106
|
|
60
|
-
def
|
107
|
+
def _convert_jinja_braces(self) -> ScenarioList:
|
61
108
|
"""Convert Jinja braces to Python braces."""
|
62
|
-
return ScenarioList([scenario.
|
109
|
+
return ScenarioList([scenario._convert_jinja_braces() for scenario in self])
|
63
110
|
|
64
|
-
def give_valid_names(self) -> ScenarioList:
|
65
|
-
"""Give valid names to the scenario keys.
|
111
|
+
def give_valid_names(self, existing_codebook: dict = None) -> ScenarioList:
|
112
|
+
"""Give valid names to the scenario keys, using an existing codebook if provided.
|
113
|
+
|
114
|
+
Args:
|
115
|
+
existing_codebook (dict, optional): Existing mapping of original keys to valid names.
|
116
|
+
Defaults to None.
|
117
|
+
|
118
|
+
Returns:
|
119
|
+
ScenarioList: A new ScenarioList with valid variable names and updated codebook.
|
66
120
|
|
67
121
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
68
122
|
>>> s.give_valid_names()
|
@@ -70,27 +124,38 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
70
124
|
>>> s = ScenarioList([Scenario({'are you there John?': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
71
125
|
>>> s.give_valid_names()
|
72
126
|
ScenarioList([Scenario({'john': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
127
|
+
>>> s.give_valid_names({'are you there John?': 'custom_name'})
|
128
|
+
ScenarioList([Scenario({'custom_name': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
73
129
|
"""
|
74
|
-
codebook = {}
|
75
|
-
|
130
|
+
codebook = existing_codebook.copy() if existing_codebook else {}
|
131
|
+
new_scenarios = []
|
132
|
+
|
76
133
|
for scenario in self:
|
77
134
|
new_scenario = {}
|
78
135
|
for key in scenario:
|
79
|
-
if
|
80
|
-
if key in codebook:
|
81
|
-
new_key = codebook[key]
|
82
|
-
else:
|
83
|
-
new_key = sanitize_string(key)
|
84
|
-
if not is_valid_variable_name(new_key):
|
85
|
-
new_key = f"var_{len(codebook)}"
|
86
|
-
codebook[key] = new_key
|
87
|
-
new_scenario[new_key] = scenario[key]
|
88
|
-
else:
|
136
|
+
if is_valid_variable_name(key):
|
89
137
|
new_scenario[key] = scenario[key]
|
90
|
-
|
91
|
-
|
138
|
+
continue
|
139
|
+
|
140
|
+
if key in codebook:
|
141
|
+
new_key = codebook[key]
|
142
|
+
else:
|
143
|
+
new_key = sanitize_string(key)
|
144
|
+
if not is_valid_variable_name(new_key):
|
145
|
+
new_key = f"var_{len(codebook)}"
|
146
|
+
codebook[key] = new_key
|
147
|
+
|
148
|
+
new_scenario[new_key] = scenario[key]
|
149
|
+
|
150
|
+
new_scenarios.append(Scenario(new_scenario))
|
151
|
+
|
152
|
+
return ScenarioList(new_scenarios, codebook)
|
92
153
|
|
93
|
-
def unpivot(
|
154
|
+
def unpivot(
|
155
|
+
self,
|
156
|
+
id_vars: Optional[List[str]] = None,
|
157
|
+
value_vars: Optional[List[str]] = None,
|
158
|
+
) -> ScenarioList:
|
94
159
|
"""
|
95
160
|
Unpivot the ScenarioList, allowing for id variables to be specified.
|
96
161
|
|
@@ -121,7 +186,40 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
121
186
|
|
122
187
|
return ScenarioList(new_scenarios)
|
123
188
|
|
124
|
-
def
|
189
|
+
def sem_filter(self, language_predicate: str) -> ScenarioList:
|
190
|
+
"""Filter the ScenarioList based on a language predicate.
|
191
|
+
|
192
|
+
:param language_predicate: The language predicate to use.
|
193
|
+
|
194
|
+
Inspired by:
|
195
|
+
@misc{patel2024semanticoperators,
|
196
|
+
title={Semantic Operators: A Declarative Model for Rich, AI-based Analytics Over Text Data},
|
197
|
+
author={Liana Patel and Siddharth Jha and Parth Asawa and Melissa Pan and Carlos Guestrin and Matei Zaharia},
|
198
|
+
year={2024},
|
199
|
+
eprint={2407.11418},
|
200
|
+
archivePrefix={arXiv},
|
201
|
+
primaryClass={cs.DB},
|
202
|
+
url={https://arxiv.org/abs/2407.11418},
|
203
|
+
}
|
204
|
+
"""
|
205
|
+
from edsl import QuestionYesNo
|
206
|
+
|
207
|
+
new_scenario_list = self.duplicate()
|
208
|
+
q = QuestionYesNo(
|
209
|
+
question_text=language_predicate, question_name="binary_outcome"
|
210
|
+
)
|
211
|
+
results = q.by(new_scenario_list).run(verbose=False)
|
212
|
+
new_scenario_list = new_scenario_list.add_list(
|
213
|
+
"criteria", results.select("binary_outcome").to_list()
|
214
|
+
)
|
215
|
+
return new_scenario_list.filter("criteria == 'Yes'").drop("criteria")
|
216
|
+
|
217
|
+
def pivot(
|
218
|
+
self,
|
219
|
+
id_vars: List[str] = None,
|
220
|
+
var_name="variable",
|
221
|
+
value_name="value",
|
222
|
+
) -> ScenarioList:
|
125
223
|
"""
|
126
224
|
Pivot the ScenarioList from long to wide format.
|
127
225
|
|
@@ -163,15 +261,15 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
163
261
|
|
164
262
|
return ScenarioList(pivoted_scenarios)
|
165
263
|
|
166
|
-
def group_by(
|
264
|
+
def group_by(
|
265
|
+
self, id_vars: List[str], variables: List[str], func: Callable
|
266
|
+
) -> ScenarioList:
|
167
267
|
"""
|
168
268
|
Group the ScenarioList by id_vars and apply a function to the specified variables.
|
169
269
|
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
func (callable): Function to apply to the grouped variables.
|
174
|
-
Should accept lists of values for each variable.
|
270
|
+
:param id_vars: Fields to use as identifier variables
|
271
|
+
:param variables: Fields to group and aggregate
|
272
|
+
:param func: Function to apply to the grouped variables
|
175
273
|
|
176
274
|
Returns:
|
177
275
|
ScenarioList: A new ScenarioList with the grouped and aggregated results
|
@@ -191,12 +289,12 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
191
289
|
# Check if the function is compatible with the specified variables
|
192
290
|
func_params = inspect.signature(func).parameters
|
193
291
|
if len(func_params) != len(variables):
|
194
|
-
raise
|
292
|
+
raise ScenarioError(
|
195
293
|
f"Function {func.__name__} expects {len(func_params)} arguments, but {len(variables)} variables were provided"
|
196
294
|
)
|
197
295
|
|
198
296
|
# Group the scenarios
|
199
|
-
grouped = defaultdict(lambda: defaultdict(list))
|
297
|
+
grouped: dict[str, list] = defaultdict(lambda: defaultdict(list))
|
200
298
|
for scenario in self:
|
201
299
|
key = tuple(scenario[id_var] for id_var in id_vars)
|
202
300
|
for var in variables:
|
@@ -208,10 +306,12 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
208
306
|
try:
|
209
307
|
aggregated = func(*[group[var] for var in variables])
|
210
308
|
except Exception as e:
|
211
|
-
raise
|
309
|
+
raise ScenarioError(f"Error applying function to group {key}: {str(e)}")
|
212
310
|
|
213
311
|
if not isinstance(aggregated, dict):
|
214
|
-
raise
|
312
|
+
raise ScenarioError(
|
313
|
+
f"Function {func.__name__} must return a dictionary"
|
314
|
+
)
|
215
315
|
|
216
316
|
new_scenario = dict(zip(id_vars, key))
|
217
317
|
new_scenario.update(aggregated)
|
@@ -278,50 +378,18 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
278
378
|
"""
|
279
379
|
return self.__mul__(other)
|
280
380
|
|
281
|
-
def shuffle(self, seed: Optional[str] =
|
381
|
+
def shuffle(self, seed: Optional[str] = None) -> ScenarioList:
|
282
382
|
"""Shuffle the ScenarioList.
|
283
383
|
|
284
384
|
>>> s = ScenarioList.from_list("a", [1,2,3,4])
|
285
|
-
>>> s.shuffle()
|
286
|
-
ScenarioList([Scenario({'a':
|
385
|
+
>>> s.shuffle(seed = "1234")
|
386
|
+
ScenarioList([Scenario({'a': 1}), Scenario({'a': 4}), Scenario({'a': 3}), Scenario({'a': 2})])
|
287
387
|
"""
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
"""Return an HTML representation of the AgentList."""
|
294
|
-
# return (
|
295
|
-
# str(self.summary(format="html")) + "<br>" + str(self.table(tablefmt="html"))
|
296
|
-
# )
|
297
|
-
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
298
|
-
return str(self.summary(format="html")) + footer
|
299
|
-
|
300
|
-
# def _repr_html_(self) -> str:
|
301
|
-
# from edsl.utilities.utilities import data_to_html
|
302
|
-
|
303
|
-
# data = self.to_dict()
|
304
|
-
# _ = data.pop("edsl_version")
|
305
|
-
# _ = data.pop("edsl_class_name")
|
306
|
-
# for s in data["scenarios"]:
|
307
|
-
# _ = s.pop("edsl_version")
|
308
|
-
# _ = s.pop("edsl_class_name")
|
309
|
-
# for scenario in data["scenarios"]:
|
310
|
-
# for key, value in scenario.items():
|
311
|
-
# if hasattr(value, "to_dict"):
|
312
|
-
# data[key] = value.to_dict()
|
313
|
-
# return data_to_html(data)
|
314
|
-
|
315
|
-
# def tally(self, field) -> dict:
|
316
|
-
# """Return a tally of the values in the field.
|
317
|
-
|
318
|
-
# Example:
|
319
|
-
|
320
|
-
# >>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
321
|
-
# >>> s.tally('b')
|
322
|
-
# {1: 1, 2: 1}
|
323
|
-
# """
|
324
|
-
# return dict(Counter([scenario[field] for scenario in self]))
|
388
|
+
sl = self.duplicate()
|
389
|
+
if seed:
|
390
|
+
random.seed(seed)
|
391
|
+
random.shuffle(sl.data)
|
392
|
+
return sl
|
325
393
|
|
326
394
|
def sample(self, n: int, seed: Optional[str] = None) -> ScenarioList:
|
327
395
|
"""Return a random sample from the ScenarioList
|
@@ -333,16 +401,22 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
333
401
|
if seed:
|
334
402
|
random.seed(seed)
|
335
403
|
|
336
|
-
|
404
|
+
sl = self.duplicate()
|
405
|
+
return ScenarioList(random.sample(sl.data, n))
|
337
406
|
|
338
|
-
def expand(self, expand_field: str, number_field=False) -> ScenarioList:
|
407
|
+
def expand(self, expand_field: str, number_field: bool = False) -> ScenarioList:
|
339
408
|
"""Expand the ScenarioList by a field.
|
340
409
|
|
410
|
+
:param expand_field: The field to expand.
|
411
|
+
:param number_field: Whether to add a field with the index of the value
|
412
|
+
|
341
413
|
Example:
|
342
414
|
|
343
415
|
>>> s = ScenarioList( [ Scenario({'a':1, 'b':[1,2]}) ] )
|
344
416
|
>>> s.expand('b')
|
345
417
|
ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
418
|
+
>>> s.expand('b', number_field=True)
|
419
|
+
ScenarioList([Scenario({'a': 1, 'b': 1, 'b_number': 1}), Scenario({'a': 1, 'b': 2, 'b_number': 2})])
|
346
420
|
"""
|
347
421
|
new_scenarios = []
|
348
422
|
for scenario in self:
|
@@ -357,12 +431,11 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
357
431
|
new_scenarios.append(new_scenario)
|
358
432
|
return ScenarioList(new_scenarios)
|
359
433
|
|
360
|
-
def concatenate(self, fields: List[str], separator: str = ";") ->
|
434
|
+
def concatenate(self, fields: List[str], separator: str = ";") -> ScenarioList:
|
361
435
|
"""Concatenate specified fields into a single field.
|
362
436
|
|
363
|
-
|
364
|
-
|
365
|
-
separator (str, optional): Separator to use between field values. Defaults to ";".
|
437
|
+
:param fields: The fields to concatenate.
|
438
|
+
:param separator: The separator to use.
|
366
439
|
|
367
440
|
Returns:
|
368
441
|
ScenarioList: A new ScenarioList with concatenated fields.
|
@@ -392,11 +465,17 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
392
465
|
) -> ScenarioList:
|
393
466
|
"""Unpack a dictionary field into separate fields.
|
394
467
|
|
468
|
+
:param field: The field to unpack.
|
469
|
+
:param prefix: An optional prefix to add to the new fields.
|
470
|
+
:param drop_field: Whether to drop the original field.
|
471
|
+
|
395
472
|
Example:
|
396
473
|
|
397
474
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': {'c': 2, 'd': 3}})])
|
398
475
|
>>> s.unpack_dict('b')
|
399
476
|
ScenarioList([Scenario({'a': 1, 'b': {'c': 2, 'd': 3}, 'c': 2, 'd': 3})])
|
477
|
+
>>> s.unpack_dict('b', prefix='new_')
|
478
|
+
ScenarioList([Scenario({'a': 1, 'b': {'c': 2, 'd': 3}, 'new_c': 2, 'new_d': 3})])
|
400
479
|
"""
|
401
480
|
new_scenarios = []
|
402
481
|
for scenario in self:
|
@@ -414,7 +493,17 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
414
493
|
def transform(
|
415
494
|
self, field: str, func: Callable, new_name: Optional[str] = None
|
416
495
|
) -> ScenarioList:
|
417
|
-
"""Transform a field using a function.
|
496
|
+
"""Transform a field using a function.
|
497
|
+
|
498
|
+
:param field: The field to transform.
|
499
|
+
:param func: The function to apply to the field.
|
500
|
+
:param new_name: An optional new name for the transformed field.
|
501
|
+
|
502
|
+
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
503
|
+
>>> s.transform('b', lambda x: x + 1)
|
504
|
+
ScenarioList([Scenario({'a': 1, 'b': 3}), Scenario({'a': 1, 'b': 2})])
|
505
|
+
|
506
|
+
"""
|
418
507
|
new_scenarios = []
|
419
508
|
for scenario in self:
|
420
509
|
new_scenario = scenario.copy()
|
@@ -428,6 +517,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
428
517
|
"""
|
429
518
|
Return a new ScenarioList with a new variable added.
|
430
519
|
|
520
|
+
:param new_var_string: A string with the new variable assignment.
|
521
|
+
:param functions_dict: A dictionary of functions to use in the assignment.
|
522
|
+
|
431
523
|
Example:
|
432
524
|
|
433
525
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
@@ -436,7 +528,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
436
528
|
|
437
529
|
"""
|
438
530
|
if "=" not in new_var_string:
|
439
|
-
raise
|
531
|
+
raise ScenarioError(
|
440
532
|
f"Mutate requires an '=' in the string, but '{new_var_string}' doesn't have one."
|
441
533
|
)
|
442
534
|
raw_var_name, expression = new_var_string.split("=", 1)
|
@@ -444,7 +536,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
444
536
|
from edsl.utilities.utilities import is_valid_variable_name
|
445
537
|
|
446
538
|
if not is_valid_variable_name(var_name):
|
447
|
-
raise
|
539
|
+
raise ScenarioError(f"{var_name} is not a valid variable name.")
|
448
540
|
|
449
541
|
# create the evaluator
|
450
542
|
functions_dict = functions_dict or {}
|
@@ -462,13 +554,15 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
462
554
|
try:
|
463
555
|
new_data = [new_scenario(s, var_name) for s in self]
|
464
556
|
except Exception as e:
|
465
|
-
raise
|
557
|
+
raise ScenarioError(f"Error in mutate. Exception:{e}")
|
466
558
|
|
467
559
|
return ScenarioList(new_data)
|
468
560
|
|
469
561
|
def order_by(self, *fields: str, reverse: bool = False) -> ScenarioList:
|
470
562
|
"""Order the scenarios by one or more fields.
|
471
563
|
|
564
|
+
:param fields: The fields to order by.
|
565
|
+
:param reverse: Whether to reverse the order.
|
472
566
|
Example:
|
473
567
|
|
474
568
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
@@ -481,16 +575,41 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
481
575
|
|
482
576
|
return ScenarioList(sorted(self, key=get_sort_key, reverse=reverse))
|
483
577
|
|
578
|
+
def duplicate(self) -> ScenarioList:
|
579
|
+
"""Return a copy of the ScenarioList.
|
580
|
+
|
581
|
+
>>> sl = ScenarioList.example()
|
582
|
+
>>> sl_copy = sl.duplicate()
|
583
|
+
>>> sl == sl_copy
|
584
|
+
True
|
585
|
+
>>> sl is sl_copy
|
586
|
+
False
|
587
|
+
"""
|
588
|
+
return ScenarioList([scenario.copy() for scenario in self])
|
589
|
+
|
484
590
|
def filter(self, expression: str) -> ScenarioList:
|
485
591
|
"""
|
486
592
|
Filter a list of scenarios based on an expression.
|
487
593
|
|
594
|
+
:param expression: The expression to filter by.
|
595
|
+
|
488
596
|
Example:
|
489
597
|
|
490
598
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
491
599
|
>>> s.filter("b == 2")
|
492
600
|
ScenarioList([Scenario({'a': 1, 'b': 2})])
|
493
601
|
"""
|
602
|
+
sl = self.duplicate()
|
603
|
+
base_keys = set(self[0].keys())
|
604
|
+
keys = set()
|
605
|
+
for scenario in sl:
|
606
|
+
keys.update(scenario.keys())
|
607
|
+
if keys != base_keys:
|
608
|
+
import warnings
|
609
|
+
|
610
|
+
warnings.warn(
|
611
|
+
"Ragged ScenarioList detected (different keys for different scenario entries). This may cause unexpected behavior."
|
612
|
+
)
|
494
613
|
|
495
614
|
def create_evaluator(scenario: Scenario):
|
496
615
|
"""Create an evaluator for the given result.
|
@@ -500,14 +619,22 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
500
619
|
|
501
620
|
try:
|
502
621
|
# iterates through all the results and evaluates the expression
|
503
|
-
new_data = [
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
622
|
+
new_data = []
|
623
|
+
for scenario in sl:
|
624
|
+
if create_evaluator(scenario).eval(expression):
|
625
|
+
new_data.append(scenario)
|
626
|
+
except NameNotDefined as e:
|
627
|
+
available_fields = ", ".join(self.data[0].keys() if self.data else [])
|
628
|
+
raise ScenarioError(
|
629
|
+
f"Error in filter: '{e}'\n"
|
630
|
+
f"The expression '{expression}' refers to a field that does not exist.\n"
|
631
|
+
f"Scenario: {scenario}\n"
|
632
|
+
f"Available fields: {available_fields}\n"
|
633
|
+
"Check your filter expression or consult the documentation: "
|
634
|
+
"https://docs.expectedparrot.com/en/latest/scenarios.html#module-edsl.scenarios.Scenario"
|
635
|
+
) from None
|
508
636
|
except Exception as e:
|
509
|
-
|
510
|
-
raise Exception(f"Error in filter. Exception:{e}")
|
637
|
+
raise ScenarioError(f"Error in filter. Exception:{e}")
|
511
638
|
|
512
639
|
return ScenarioList(new_data)
|
513
640
|
|
@@ -519,30 +646,26 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
519
646
|
:param urls: A list of URLs.
|
520
647
|
:param field_name: The name of the field to store the text from the URLs.
|
521
648
|
|
522
|
-
|
523
649
|
"""
|
524
650
|
return ScenarioList([Scenario.from_url(url, field_name) for url in urls])
|
525
651
|
|
526
|
-
def select(self, *fields) -> ScenarioList:
|
652
|
+
def select(self, *fields: str) -> ScenarioList:
|
527
653
|
"""
|
528
654
|
Selects scenarios with only the references fields.
|
529
655
|
|
656
|
+
:param fields: The fields to select.
|
657
|
+
|
530
658
|
Example:
|
531
659
|
|
532
660
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
533
661
|
>>> s.select('a')
|
534
662
|
ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
|
535
663
|
"""
|
536
|
-
|
537
|
-
fields_to_select = [list(fields)[0]]
|
538
|
-
else:
|
539
|
-
fields_to_select = list(fields)
|
664
|
+
from edsl.scenarios.ScenarioSelector import ScenarioSelector
|
540
665
|
|
541
|
-
return
|
542
|
-
[scenario.select(fields_to_select) for scenario in self.data]
|
543
|
-
)
|
666
|
+
return ScenarioSelector(self).select(*fields)
|
544
667
|
|
545
|
-
def drop(self, *fields) -> ScenarioList:
|
668
|
+
def drop(self, *fields: str) -> ScenarioList:
|
546
669
|
"""Drop fields from the scenarios.
|
547
670
|
|
548
671
|
Example:
|
@@ -551,18 +674,22 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
551
674
|
>>> s.drop('a')
|
552
675
|
ScenarioList([Scenario({'b': 1}), Scenario({'b': 2})])
|
553
676
|
"""
|
554
|
-
|
677
|
+
sl = self.duplicate()
|
678
|
+
return ScenarioList([scenario.drop(fields) for scenario in sl])
|
555
679
|
|
556
|
-
def keep(self, *fields) -> ScenarioList:
|
680
|
+
def keep(self, *fields: str) -> ScenarioList:
|
557
681
|
"""Keep only the specified fields in the scenarios.
|
558
682
|
|
683
|
+
:param fields: The fields to keep.
|
684
|
+
|
559
685
|
Example:
|
560
686
|
|
561
687
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
562
688
|
>>> s.keep('a')
|
563
689
|
ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
|
564
690
|
"""
|
565
|
-
|
691
|
+
sl = self.duplicate()
|
692
|
+
return ScenarioList([scenario.keep(fields) for scenario in sl])
|
566
693
|
|
567
694
|
@classmethod
|
568
695
|
def from_list(
|
@@ -570,6 +697,10 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
570
697
|
) -> ScenarioList:
|
571
698
|
"""Create a ScenarioList from a list of values.
|
572
699
|
|
700
|
+
:param name: The name of the field.
|
701
|
+
:param values: The list of values.
|
702
|
+
:param func: An optional function to apply to the values.
|
703
|
+
|
573
704
|
Example:
|
574
705
|
|
575
706
|
>>> ScenarioList.from_list('name', ['Alice', 'Bob'])
|
@@ -579,7 +710,12 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
579
710
|
func = lambda x: x
|
580
711
|
return cls([Scenario({name: func(value)}) for value in values])
|
581
712
|
|
582
|
-
def table(
|
713
|
+
def table(
|
714
|
+
self,
|
715
|
+
*fields: str,
|
716
|
+
tablefmt: Optional[TableFormat] = None,
|
717
|
+
pretty_labels: Optional[dict[str, str]] = None,
|
718
|
+
) -> str:
|
583
719
|
"""Return the ScenarioList as a table."""
|
584
720
|
|
585
721
|
from tabulate import tabulate_formats
|
@@ -594,26 +730,41 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
594
730
|
)
|
595
731
|
|
596
732
|
def tree(self, node_list: Optional[List[str]] = None) -> str:
|
597
|
-
"""Return the ScenarioList as a tree.
|
733
|
+
"""Return the ScenarioList as a tree.
|
734
|
+
|
735
|
+
:param node_list: The list of nodes to include in the tree.
|
736
|
+
"""
|
598
737
|
return self.to_dataset().tree(node_list)
|
599
738
|
|
600
|
-
def _summary(self):
|
739
|
+
def _summary(self) -> dict:
|
740
|
+
"""Return a summary of the ScenarioList.
|
741
|
+
|
742
|
+
>>> ScenarioList.example()._summary()
|
743
|
+
{'scenarios': 2, 'keys': ['persona']}
|
744
|
+
"""
|
601
745
|
d = {
|
602
|
-
"
|
603
|
-
"
|
604
|
-
"Scenario Keys": list(self.parameters),
|
746
|
+
"scenarios": len(self),
|
747
|
+
"keys": list(self.parameters),
|
605
748
|
}
|
606
749
|
return d
|
607
750
|
|
608
|
-
def reorder_keys(self, new_order):
|
751
|
+
def reorder_keys(self, new_order: List[str]) -> ScenarioList:
|
609
752
|
"""Reorder the keys in the scenarios.
|
610
753
|
|
754
|
+
:param new_order: The new order of the keys.
|
755
|
+
|
611
756
|
Example:
|
612
757
|
|
613
758
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 3, 'b': 4})])
|
614
759
|
>>> s.reorder_keys(['b', 'a'])
|
615
760
|
ScenarioList([Scenario({'b': 2, 'a': 1}), Scenario({'b': 4, 'a': 3})])
|
761
|
+
>>> s.reorder_keys(['a', 'b', 'c'])
|
762
|
+
Traceback (most recent call last):
|
763
|
+
...
|
764
|
+
AssertionError
|
616
765
|
"""
|
766
|
+
assert set(new_order) == set(self.parameters)
|
767
|
+
|
617
768
|
new_scenarios = []
|
618
769
|
for scenario in self:
|
619
770
|
new_scenario = Scenario({key: scenario[key] for key in new_order})
|
@@ -622,6 +773,8 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
622
773
|
|
623
774
|
def to_dataset(self) -> "Dataset":
|
624
775
|
"""
|
776
|
+
Convert the ScenarioList to a Dataset.
|
777
|
+
|
625
778
|
>>> s = ScenarioList.from_list("a", [1,2,3])
|
626
779
|
>>> s.to_dataset()
|
627
780
|
Dataset([{'a': [1, 2, 3]}])
|
@@ -631,8 +784,14 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
631
784
|
"""
|
632
785
|
from edsl.results.Dataset import Dataset
|
633
786
|
|
634
|
-
keys = self[0].keys()
|
635
|
-
|
787
|
+
keys = list(self[0].keys())
|
788
|
+
for scenario in self:
|
789
|
+
new_keys = list(scenario.keys())
|
790
|
+
if new_keys != keys:
|
791
|
+
keys = list(set(keys + new_keys))
|
792
|
+
data = [
|
793
|
+
{key: [scenario.get(key, None) for scenario in self.data]} for key in keys
|
794
|
+
]
|
636
795
|
return Dataset(data)
|
637
796
|
|
638
797
|
def unpack(
|
@@ -664,7 +823,14 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
664
823
|
new_scenarios.append(new_scenario)
|
665
824
|
return ScenarioList(new_scenarios)
|
666
825
|
|
667
|
-
|
826
|
+
@classmethod
|
827
|
+
def from_list_of_tuples(self, *names: str, values: List[Tuple]) -> ScenarioList:
|
828
|
+
sl = ScenarioList.from_list(names[0], [value[0] for value in values])
|
829
|
+
for index, name in enumerate(names[1:]):
|
830
|
+
sl = sl.add_list(name, [value[index + 1] for value in values])
|
831
|
+
return sl
|
832
|
+
|
833
|
+
def add_list(self, name: str, values: List[Any]) -> ScenarioList:
|
668
834
|
"""Add a list of values to a ScenarioList.
|
669
835
|
|
670
836
|
Example:
|
@@ -673,12 +839,10 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
673
839
|
>>> s.add_list('age', [30, 25])
|
674
840
|
ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
|
675
841
|
"""
|
842
|
+
sl = self.duplicate()
|
676
843
|
for i, value in enumerate(values):
|
677
|
-
|
678
|
-
|
679
|
-
else:
|
680
|
-
self.append(Scenario({name: value}))
|
681
|
-
return self
|
844
|
+
sl[i][name] = value
|
845
|
+
return sl
|
682
846
|
|
683
847
|
def add_value(self, name: str, value: Any) -> ScenarioList:
|
684
848
|
"""Add a value to all scenarios in a ScenarioList.
|
@@ -689,13 +853,16 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
689
853
|
>>> s.add_value('age', 30)
|
690
854
|
ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 30})])
|
691
855
|
"""
|
692
|
-
|
856
|
+
sl = self.duplicate()
|
857
|
+
for scenario in sl:
|
693
858
|
scenario[name] = value
|
694
|
-
return
|
859
|
+
return sl
|
695
860
|
|
696
861
|
def rename(self, replacement_dict: dict) -> ScenarioList:
|
697
862
|
"""Rename the fields in the scenarios.
|
698
863
|
|
864
|
+
:param replacement_dict: A dictionary with the old names as keys and the new names as values.
|
865
|
+
|
699
866
|
Example:
|
700
867
|
|
701
868
|
>>> s = ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
|
@@ -710,8 +877,26 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
710
877
|
new_list.append(new_obj)
|
711
878
|
return new_list
|
712
879
|
|
880
|
+
## NEEDS TO BE FIXED: the new_column_names method below is commented out (disabled).
|
881
|
+
# def new_column_names(self, new_names: List[str]) -> ScenarioList:
|
882
|
+
# """Rename the fields in the scenarios.
|
883
|
+
|
884
|
+
# Example:
|
885
|
+
|
886
|
+
# >>> s = ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
|
887
|
+
# >>> s.new_column_names(['first_name', 'years'])
|
888
|
+
# ScenarioList([Scenario({'first_name': 'Alice', 'years': 30}), Scenario({'first_name': 'Bob', 'years': 25})])
|
889
|
+
|
890
|
+
# """
|
891
|
+
# new_list = ScenarioList([])
|
892
|
+
# for obj in self:
|
893
|
+
# new_obj = obj.new_column_names(new_names)
|
894
|
+
# new_list.append(new_obj)
|
895
|
+
# return new_list
|
896
|
+
|
713
897
|
@classmethod
|
714
898
|
def from_sqlite(cls, filepath: str, table: str):
|
899
|
+
"""Create a ScenarioList from a SQLite database."""
|
715
900
|
import sqlite3
|
716
901
|
|
717
902
|
with sqlite3.connect(filepath) as conn:
|
@@ -855,6 +1040,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
855
1040
|
def to_key_value(self, field: str, value=None) -> Union[dict, set]:
|
856
1041
|
"""Return the set of values in the field.
|
857
1042
|
|
1043
|
+
:param field: The field to extract values from.
|
1044
|
+
:param value: An optional field to use as the value in the key-value pair.
|
1045
|
+
|
858
1046
|
Example:
|
859
1047
|
|
860
1048
|
>>> s = ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
|
@@ -974,56 +1162,42 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
974
1162
|
|
975
1163
|
@classmethod
|
976
1164
|
def from_delimited_file(
|
977
|
-
cls, source: Union[str,
|
1165
|
+
cls, source: Union[str, "ParseResult"], delimiter: str = ","
|
978
1166
|
) -> ScenarioList:
|
979
|
-
"""Create a ScenarioList from a delimited file (CSV/TSV) or URL.
|
980
|
-
|
981
|
-
Args:
|
982
|
-
source: A string representing either a local file path or a URL to a delimited file,
|
983
|
-
or a urllib.parse.ParseResult object for a URL.
|
984
|
-
delimiter: The delimiter used in the file. Defaults to ',' for CSV files.
|
985
|
-
Use '\t' for TSV files.
|
986
|
-
|
987
|
-
Returns:
|
988
|
-
ScenarioList: A ScenarioList object containing the data from the file.
|
989
|
-
|
990
|
-
Example:
|
991
|
-
# For CSV files
|
992
|
-
|
993
|
-
>>> with open('data.csv', 'w') as f:
|
994
|
-
... _ = f.write('name,age\\nAlice,30\\nBob,25\\n')
|
995
|
-
>>> scenario_list = ScenarioList.from_delimited_file('data.csv')
|
996
|
-
|
997
|
-
# For TSV files
|
998
|
-
>>> with open('data.tsv', 'w') as f:
|
999
|
-
... _ = f.write('name\\tage\\nAlice\t30\\nBob\t25\\n')
|
1000
|
-
>>> scenario_list = ScenarioList.from_delimited_file('data.tsv', delimiter='\\t')
|
1001
|
-
|
1002
|
-
"""
|
1167
|
+
"""Create a ScenarioList from a delimited file (CSV/TSV) or URL."""
|
1168
|
+
import requests
|
1003
1169
|
from edsl.scenarios.Scenario import Scenario
|
1170
|
+
from urllib.parse import urlparse
|
1171
|
+
from urllib.parse import ParseResult
|
1172
|
+
|
1173
|
+
headers = {
|
1174
|
+
"Accept": "text/csv,application/csv,text/plain",
|
1175
|
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
1176
|
+
}
|
1004
1177
|
|
1005
1178
|
def is_url(source):
|
1006
1179
|
try:
|
1007
|
-
result =
|
1180
|
+
result = urlparse(source)
|
1008
1181
|
return all([result.scheme, result.netloc])
|
1009
1182
|
except ValueError:
|
1010
1183
|
return False
|
1011
1184
|
|
1012
|
-
if isinstance(source, str) and is_url(source):
|
1013
|
-
with urllib.request.urlopen(source) as response:
|
1014
|
-
file_content = response.read().decode("utf-8")
|
1015
|
-
file_obj = StringIO(file_content)
|
1016
|
-
elif isinstance(source, urllib.parse.ParseResult):
|
1017
|
-
with urllib.request.urlopen(source.geturl()) as response:
|
1018
|
-
file_content = response.read().decode("utf-8")
|
1019
|
-
file_obj = StringIO(file_content)
|
1020
|
-
else:
|
1021
|
-
file_obj = open(source, "r")
|
1022
|
-
|
1023
1185
|
try:
|
1186
|
+
if isinstance(source, str) and is_url(source):
|
1187
|
+
response = requests.get(source, headers=headers)
|
1188
|
+
response.raise_for_status()
|
1189
|
+
file_obj = StringIO(response.text)
|
1190
|
+
elif isinstance(source, ParseResult):
|
1191
|
+
response = requests.get(source.geturl(), headers=headers)
|
1192
|
+
response.raise_for_status()
|
1193
|
+
file_obj = StringIO(response.text)
|
1194
|
+
else:
|
1195
|
+
file_obj = open(source, "r")
|
1196
|
+
|
1024
1197
|
reader = csv.reader(file_obj, delimiter=delimiter)
|
1025
1198
|
header = next(reader)
|
1026
1199
|
observations = [Scenario(dict(zip(header, row))) for row in reader]
|
1200
|
+
|
1027
1201
|
finally:
|
1028
1202
|
file_obj.close()
|
1029
1203
|
|
@@ -1031,7 +1205,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1031
1205
|
|
1032
1206
|
# Convenience methods for specific file types
|
1033
1207
|
@classmethod
|
1034
|
-
def from_csv(cls, source: Union[str,
|
1208
|
+
def from_csv(cls, source: Union[str, "ParseResult"]) -> ScenarioList:
|
1035
1209
|
"""Create a ScenarioList from a CSV file or URL."""
|
1036
1210
|
return cls.from_delimited_file(source, delimiter=",")
|
1037
1211
|
|
@@ -1052,71 +1226,13 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1052
1226
|
|
1053
1227
|
sj = ScenarioJoin(self, other)
|
1054
1228
|
return sj.left_join(by)
|
1055
|
-
# # Validate join keys
|
1056
|
-
# if not by:
|
1057
|
-
# raise ValueError(
|
1058
|
-
# "Join keys cannot be empty. Please specify at least one key to join on."
|
1059
|
-
# )
|
1060
|
-
|
1061
|
-
# # Convert single string to list for consistent handling
|
1062
|
-
# by_keys = [by] if isinstance(by, str) else by
|
1063
|
-
|
1064
|
-
# # Verify all join keys exist in both ScenarioLists
|
1065
|
-
# left_keys = set(next(iter(self)).keys()) if self else set()
|
1066
|
-
# right_keys = set(next(iter(other)).keys()) if other else set()
|
1067
|
-
|
1068
|
-
# missing_left = set(by_keys) - left_keys
|
1069
|
-
# missing_right = set(by_keys) - right_keys
|
1070
|
-
# if missing_left or missing_right:
|
1071
|
-
# missing = missing_left | missing_right
|
1072
|
-
# raise ValueError(f"Join key(s) {missing} not found in both ScenarioLists")
|
1073
|
-
|
1074
|
-
# # Create lookup dictionary from the other ScenarioList
|
1075
|
-
# def get_key_tuple(scenario: Scenario, keys: list[str]) -> tuple:
|
1076
|
-
# return tuple(scenario[k] for k in keys)
|
1077
|
-
|
1078
|
-
# other_dict = {get_key_tuple(scenario, by_keys): scenario for scenario in other}
|
1079
|
-
|
1080
|
-
# # Collect all possible keys (like SQL combining all columns)
|
1081
|
-
# all_keys = set()
|
1082
|
-
# for scenario in self:
|
1083
|
-
# all_keys.update(scenario.keys())
|
1084
|
-
# for scenario in other:
|
1085
|
-
# all_keys.update(scenario.keys())
|
1086
|
-
|
1087
|
-
# new_scenarios = []
|
1088
|
-
# for scenario in self:
|
1089
|
-
# new_scenario = {
|
1090
|
-
# key: None for key in all_keys
|
1091
|
-
# } # Start with nulls (like SQL)
|
1092
|
-
# new_scenario.update(scenario) # Add all left values
|
1093
|
-
|
1094
|
-
# key_tuple = get_key_tuple(scenario, by_keys)
|
1095
|
-
# if matching_scenario := other_dict.get(key_tuple):
|
1096
|
-
# # Check for overlapping keys with different values
|
1097
|
-
# overlapping_keys = set(scenario.keys()) & set(matching_scenario.keys())
|
1098
|
-
# for key in overlapping_keys:
|
1099
|
-
# if key not in by_keys and scenario[key] != matching_scenario[key]:
|
1100
|
-
# join_conditions = [f"{k}='{scenario[k]}'" for k in by_keys]
|
1101
|
-
# print(
|
1102
|
-
# f"Warning: Conflicting values for key '{key}' where {' AND '.join(join_conditions)}. "
|
1103
|
-
# f"Keeping left value: {scenario[key]} (discarding: {matching_scenario[key]})"
|
1104
|
-
# )
|
1105
|
-
|
1106
|
-
# # Only update with non-overlapping keys from matching scenario
|
1107
|
-
# new_keys = set(matching_scenario.keys()) - set(scenario.keys())
|
1108
|
-
# new_scenario.update({k: matching_scenario[k] for k in new_keys})
|
1109
|
-
|
1110
|
-
# new_scenarios.append(Scenario(new_scenario))
|
1111
|
-
|
1112
|
-
# return ScenarioList(new_scenarios)
|
1113
1229
|
|
1114
1230
|
@classmethod
|
1115
|
-
def from_tsv(cls, source: Union[str,
|
1231
|
+
def from_tsv(cls, source: Union[str, "ParseResult"]) -> ScenarioList:
|
1116
1232
|
"""Create a ScenarioList from a TSV file or URL."""
|
1117
1233
|
return cls.from_delimited_file(source, delimiter="\t")
|
1118
1234
|
|
1119
|
-
def to_dict(self, sort=False, add_edsl_version=True) -> dict:
|
1235
|
+
def to_dict(self, sort: bool = False, add_edsl_version: bool = True) -> dict:
|
1120
1236
|
"""
|
1121
1237
|
>>> s = ScenarioList([Scenario({'food': 'wood chips'}), Scenario({'food': 'wood-fired pizza'})])
|
1122
1238
|
>>> s.to_dict()
|
@@ -1135,6 +1251,27 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1135
1251
|
d["edsl_class_name"] = self.__class__.__name__
|
1136
1252
|
return d
|
1137
1253
|
|
1254
|
+
def to(self, survey: Union["Survey", "QuestionBase"]) -> "Jobs":
|
1255
|
+
"""Create a Jobs object from a ScenarioList and a Survey object.
|
1256
|
+
|
1257
|
+
:param survey: The Survey to use for the Jobs object; a single QuestionBase is wrapped in a one-question Survey first.
|
1258
|
+
|
1259
|
+
Example:
|
1260
|
+
>>> from edsl import Survey
|
1261
|
+
>>> from edsl.jobs.Jobs import Jobs
|
1262
|
+
>>> from edsl import ScenarioList
|
1263
|
+
>>> isinstance(ScenarioList.example().to(Survey.example()), Jobs)
|
1264
|
+
True
|
1265
|
+
"""
|
1266
|
+
from edsl.surveys.Survey import Survey
|
1267
|
+
from edsl.questions.QuestionBase import QuestionBase
|
1268
|
+
from edsl.jobs.Jobs import Jobs
|
1269
|
+
|
1270
|
+
if isinstance(survey, QuestionBase):
|
1271
|
+
return Survey([survey]).by(self)
|
1272
|
+
else:
|
1273
|
+
return survey.by(self)
|
1274
|
+
|
1138
1275
|
@classmethod
|
1139
1276
|
def gen(cls, scenario_dicts_list: List[dict]) -> ScenarioList:
|
1140
1277
|
"""Create a `ScenarioList` from a list of dictionaries.
|
@@ -1160,15 +1297,12 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1160
1297
|
@classmethod
|
1161
1298
|
def from_nested_dict(cls, data: dict) -> ScenarioList:
|
1162
1299
|
"""Create a `ScenarioList` from a nested dictionary."""
|
1163
|
-
from edsl.scenarios.Scenario import Scenario
|
1164
|
-
|
1165
1300
|
s = ScenarioList()
|
1166
1301
|
for key, value in data.items():
|
1167
1302
|
s.add_list(key, value)
|
1168
1303
|
return s
|
1169
1304
|
|
1170
1305
|
def code(self) -> str:
|
1171
|
-
## TODO: Refactor to only use the questions actually in the survey
|
1172
1306
|
"""Create the Python code representation of a survey."""
|
1173
1307
|
header_lines = [
|
1174
1308
|
"from edsl.scenarios.Scenario import Scenario",
|
@@ -1191,16 +1325,16 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1191
1325
|
"""
|
1192
1326
|
return cls([Scenario.example(randomize), Scenario.example(randomize)])
|
1193
1327
|
|
1194
|
-
def rich_print(self) -> None:
|
1195
|
-
|
1196
|
-
|
1328
|
+
# def rich_print(self) -> None:
|
1329
|
+
# """Display an object as a table."""
|
1330
|
+
# from rich.table import Table
|
1197
1331
|
|
1198
|
-
|
1199
|
-
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1332
|
+
# table = Table(title="ScenarioList")
|
1333
|
+
# table.add_column("Index", style="bold")
|
1334
|
+
# table.add_column("Scenario")
|
1335
|
+
# for i, s in enumerate(self):
|
1336
|
+
# table.add_row(str(i), s.rich_print())
|
1337
|
+
# return table
|
1204
1338
|
|
1205
1339
|
def __getitem__(self, key: Union[int, slice]) -> Any:
|
1206
1340
|
"""Return the item at the given index.
|
@@ -1246,9 +1380,18 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1246
1380
|
f"The 'name' field is reserved for the agent's name---putting this value in {proposed_agent_name}"
|
1247
1381
|
)
|
1248
1382
|
new_scenario[proposed_agent_name] = name
|
1249
|
-
|
1383
|
+
new_agent = Agent(traits=new_scenario, name=name)
|
1384
|
+
if "agent_parameters" in new_scenario:
|
1385
|
+
agent_parameters = new_scenario.pop("agent_parameters")
|
1386
|
+
instruction = agent_parameters.get("instruction", None)
|
1387
|
+
name = agent_parameters.get("name", None)
|
1388
|
+
new_agent = Agent(
|
1389
|
+
traits=new_scenario, name=name, instruction=instruction
|
1390
|
+
)
|
1250
1391
|
else:
|
1251
|
-
|
1392
|
+
new_agent = Agent(traits=new_scenario)
|
1393
|
+
|
1394
|
+
agents.append(new_agent)
|
1252
1395
|
|
1253
1396
|
return AgentList(agents)
|
1254
1397
|
|