edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +116 -197
- edsl/__init__.py +7 -15
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +147 -351
- edsl/agents/AgentList.py +73 -211
- edsl/agents/Invigilator.py +50 -101
- edsl/agents/InvigilatorBase.py +70 -62
- edsl/agents/PromptConstructor.py +225 -143
- edsl/agents/__init__.py +1 -0
- edsl/agents/prompt_helpers.py +3 -3
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/config.py +2 -22
- edsl/conversation/car_buying.py +1 -2
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +47 -125
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +27 -45
- edsl/data/CacheEntry.py +15 -12
- edsl/data/CacheHandler.py +12 -31
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/data/__init__.py +3 -4
- edsl/data_transfer_models.py +1 -2
- edsl/enums.py +0 -27
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +0 -12
- edsl/exceptions/questions.py +6 -24
- edsl/exceptions/scenarios.py +0 -7
- edsl/inference_services/AnthropicService.py +19 -38
- edsl/inference_services/AwsBedrock.py +2 -0
- edsl/inference_services/AzureAI.py +2 -0
- edsl/inference_services/GoogleService.py +12 -7
- edsl/inference_services/InferenceServiceABC.py +85 -18
- edsl/inference_services/InferenceServicesCollection.py +79 -120
- edsl/inference_services/MistralAIService.py +3 -0
- edsl/inference_services/OpenAIService.py +35 -47
- edsl/inference_services/PerplexityService.py +3 -0
- edsl/inference_services/TestService.py +10 -11
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/jobs/Answers.py +14 -1
- edsl/jobs/Jobs.py +431 -356
- edsl/jobs/JobsChecks.py +10 -35
- edsl/jobs/JobsPrompts.py +4 -6
- edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
- edsl/jobs/buckets/BucketCollection.py +3 -44
- edsl/jobs/buckets/TokenBucket.py +21 -53
- edsl/jobs/interviews/Interview.py +408 -143
- edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
- edsl/jobs/tasks/TaskHistory.py +18 -38
- edsl/jobs/tasks/task_status_enum.py +2 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +236 -194
- edsl/language_models/ModelList.py +19 -28
- edsl/language_models/__init__.py +2 -1
- edsl/language_models/registry.py +190 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/unused/ReplicateBase.py +83 -0
- edsl/language_models/utilities.py +4 -5
- edsl/notebooks/Notebook.py +14 -19
- edsl/prompts/Prompt.py +39 -29
- edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
- edsl/questions/QuestionBase.py +214 -68
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
- edsl/questions/QuestionBasePromptsMixin.py +3 -7
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +3 -2
- edsl/questions/QuestionList.py +18 -10
- edsl/questions/QuestionMultipleChoice.py +23 -67
- edsl/questions/QuestionNumerical.py +4 -2
- edsl/questions/QuestionRank.py +17 -7
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
- edsl/questions/SimpleAskMixin.py +3 -4
- edsl/questions/__init__.py +1 -2
- edsl/questions/derived/QuestionLinearScale.py +3 -6
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +3 -17
- edsl/questions/question_registry.py +1 -1
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +7 -170
- edsl/results/DatasetExportMixin.py +305 -168
- edsl/results/DatasetTree.py +8 -28
- edsl/results/Result.py +206 -298
- edsl/results/Results.py +131 -149
- edsl/results/ResultsDBMixin.py +238 -0
- edsl/results/ResultsExportMixin.py +0 -2
- edsl/results/{results_selector.py → Selector.py} +13 -23
- edsl/results/TableDisplay.py +171 -98
- edsl/results/__init__.py +1 -1
- edsl/scenarios/FileStore.py +239 -150
- edsl/scenarios/Scenario.py +193 -90
- edsl/scenarios/ScenarioHtmlMixin.py +3 -4
- edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
- edsl/scenarios/ScenarioList.py +244 -415
- edsl/scenarios/ScenarioListExportMixin.py +7 -0
- edsl/scenarios/ScenarioListPdfMixin.py +37 -15
- edsl/scenarios/__init__.py +2 -1
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +12 -5
- edsl/surveys/Rule.py +4 -5
- edsl/surveys/RuleCollection.py +27 -25
- edsl/surveys/Survey.py +791 -270
- edsl/surveys/SurveyCSS.py +8 -20
- edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
- edsl/surveys/__init__.py +2 -4
- edsl/surveys/descriptors.py +2 -6
- edsl/surveys/instructions/ChangeInstruction.py +2 -1
- edsl/surveys/instructions/Instruction.py +13 -4
- edsl/surveys/instructions/InstructionCollection.py +6 -11
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/utilities.py +23 -35
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
- edsl-0.1.39.dev1.dist-info/RECORD +277 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
- edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
- edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
- edsl/agents/question_option_processor.py +0 -172
- edsl/coop/CoopFunctionsMixin.py +0 -15
- edsl/coop/ExpectedParrotKeyHandler.py +0 -125
- edsl/exceptions/inference_services.py +0 -5
- edsl/inference_services/AvailableModelCacheHandler.py +0 -184
- edsl/inference_services/AvailableModelFetcher.py +0 -215
- edsl/inference_services/ServiceAvailability.py +0 -135
- edsl/inference_services/data_structures.py +0 -134
- edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
- edsl/jobs/FetchInvigilator.py +0 -47
- edsl/jobs/InterviewTaskManager.py +0 -98
- edsl/jobs/InterviewsConstructor.py +0 -50
- edsl/jobs/JobsComponentConstructor.py +0 -189
- edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
- edsl/jobs/RequestTokenEstimator.py +0 -30
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/buckets/TokenBucketAPI.py +0 -211
- edsl/jobs/buckets/TokenBucketClient.py +0 -191
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/decorators.py +0 -35
- edsl/jobs/jobs_status_enums.py +0 -9
- edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/ComputeCost.py +0 -63
- edsl/language_models/PriceManager.py +0 -127
- edsl/language_models/RawResponseHandler.py +0 -106
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/key_management/KeyLookup.py +0 -63
- edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
- edsl/language_models/key_management/KeyLookupCollection.py +0 -38
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +0 -131
- edsl/language_models/model.py +0 -256
- edsl/notebooks/NotebookToLaTeX.py +0 -142
- edsl/questions/ExceptionExplainer.py +0 -77
- edsl/questions/HTMLQuestion.py +0 -103
- edsl/questions/QuestionMatrix.py +0 -265
- edsl/questions/data_structures.py +0 -20
- edsl/questions/loop_processor.py +0 -149
- edsl/questions/response_validator_factory.py +0 -34
- edsl/questions/templates/matrix/__init__.py +0 -1
- edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
- edsl/questions/templates/matrix/question_presentation.jinja +0 -20
- edsl/results/MarkdownToDocx.py +0 -122
- edsl/results/MarkdownToPDF.py +0 -111
- edsl/results/TextEditor.py +0 -50
- edsl/results/file_exports.py +0 -252
- edsl/results/smart_objects.py +0 -96
- edsl/results/table_data_class.py +0 -12
- edsl/results/table_renderers.py +0 -118
- edsl/scenarios/ConstructDownloadLink.py +0 -109
- edsl/scenarios/DocumentChunker.py +0 -102
- edsl/scenarios/DocxScenario.py +0 -16
- edsl/scenarios/PdfExtractor.py +0 -40
- edsl/scenarios/directory_scanner.py +0 -96
- edsl/scenarios/file_methods.py +0 -85
- edsl/scenarios/handlers/__init__.py +0 -13
- edsl/scenarios/handlers/csv.py +0 -49
- edsl/scenarios/handlers/docx.py +0 -76
- edsl/scenarios/handlers/html.py +0 -37
- edsl/scenarios/handlers/json.py +0 -111
- edsl/scenarios/handlers/latex.py +0 -5
- edsl/scenarios/handlers/md.py +0 -51
- edsl/scenarios/handlers/pdf.py +0 -68
- edsl/scenarios/handlers/png.py +0 -39
- edsl/scenarios/handlers/pptx.py +0 -105
- edsl/scenarios/handlers/py.py +0 -294
- edsl/scenarios/handlers/sql.py +0 -313
- edsl/scenarios/handlers/sqlite.py +0 -149
- edsl/scenarios/handlers/txt.py +0 -33
- edsl/scenarios/scenario_selector.py +0 -156
- edsl/surveys/ConstructDAG.py +0 -92
- edsl/surveys/EditSurvey.py +0 -221
- edsl/surveys/InstructionHandler.py +0 -100
- edsl/surveys/MemoryManagement.py +0 -72
- edsl/surveys/RuleManager.py +0 -172
- edsl/surveys/Simulator.py +0 -75
- edsl/surveys/SurveyToApp.py +0 -141
- edsl/utilities/PrettyList.py +0 -56
- edsl/utilities/is_notebook.py +0 -18
- edsl/utilities/is_valid_variable_name.py +0 -11
- edsl/utilities/remove_edsl_version.py +0 -24
- edsl-0.1.39.dist-info/RECORD +0 -358
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
edsl/scenarios/ScenarioList.py
CHANGED
@@ -1,78 +1,33 @@
|
|
1
1
|
"""A list of Scenarios to be used in a survey."""
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
|
-
from typing import
|
5
|
-
Any,
|
6
|
-
Optional,
|
7
|
-
Union,
|
8
|
-
List,
|
9
|
-
Callable,
|
10
|
-
Literal,
|
11
|
-
TYPE_CHECKING,
|
12
|
-
)
|
13
|
-
|
14
|
-
try:
|
15
|
-
from typing import TypeAlias
|
16
|
-
except ImportError:
|
17
|
-
from typing_extensions import TypeAlias
|
18
|
-
|
4
|
+
from typing import Any, Optional, Union, List, Callable
|
19
5
|
import csv
|
20
6
|
import random
|
7
|
+
from collections import UserList, Counter
|
8
|
+
from collections.abc import Iterable
|
9
|
+
import urllib.parse
|
10
|
+
import urllib.request
|
21
11
|
from io import StringIO
|
12
|
+
from collections import defaultdict
|
22
13
|
import inspect
|
23
|
-
from collections import UserList, defaultdict
|
24
|
-
from collections.abc import Iterable
|
25
|
-
|
26
|
-
if TYPE_CHECKING:
|
27
|
-
from urllib.parse import ParseResult
|
28
|
-
from edsl.results.Dataset import Dataset
|
29
|
-
from edsl.jobs.Jobs import Jobs
|
30
|
-
from edsl.surveys.Survey import Survey
|
31
|
-
from edsl.questions.QuestionBase import QuestionBase
|
32
|
-
|
33
14
|
|
34
|
-
from simpleeval import EvalWithCompoundTypes
|
35
|
-
|
36
|
-
from tabulate import tabulate_formats
|
15
|
+
from simpleeval import EvalWithCompoundTypes
|
37
16
|
|
38
17
|
from edsl.Base import Base
|
39
|
-
from edsl.utilities.
|
40
|
-
|
18
|
+
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
41
19
|
from edsl.scenarios.Scenario import Scenario
|
42
20
|
from edsl.scenarios.ScenarioListPdfMixin import ScenarioListPdfMixin
|
43
21
|
from edsl.scenarios.ScenarioListExportMixin import ScenarioListExportMixin
|
44
|
-
from edsl.utilities.naming_utilities import sanitize_string
|
45
|
-
from edsl.utilities.is_valid_variable_name import is_valid_variable_name
|
46
|
-
from edsl.exceptions.scenarios import ScenarioError
|
47
22
|
|
48
|
-
from edsl.
|
23
|
+
from edsl.utilities.naming_utilities import sanitize_string
|
24
|
+
from edsl.utilities.utilities import is_valid_variable_name
|
49
25
|
|
50
26
|
|
51
27
|
class ScenarioListMixin(ScenarioListPdfMixin, ScenarioListExportMixin):
|
52
28
|
pass
|
53
29
|
|
54
30
|
|
55
|
-
if TYPE_CHECKING:
|
56
|
-
from edsl.results.Dataset import Dataset
|
57
|
-
|
58
|
-
TableFormat: TypeAlias = Literal[
|
59
|
-
"plain",
|
60
|
-
"simple",
|
61
|
-
"github",
|
62
|
-
"grid",
|
63
|
-
"fancy_grid",
|
64
|
-
"pipe",
|
65
|
-
"orgtbl",
|
66
|
-
"rst",
|
67
|
-
"mediawiki",
|
68
|
-
"html",
|
69
|
-
"latex",
|
70
|
-
"latex_raw",
|
71
|
-
"latex_booktabs",
|
72
|
-
"tsv",
|
73
|
-
]
|
74
|
-
|
75
|
-
|
76
31
|
class ScenarioList(Base, UserList, ScenarioListMixin):
|
77
32
|
"""Class for creating a list of scenarios to be used in a survey."""
|
78
33
|
|
@@ -80,9 +35,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
80
35
|
"https://docs.expectedparrot.com/en/latest/scenarios.html#scenariolist"
|
81
36
|
)
|
82
37
|
|
83
|
-
def __init__(
|
84
|
-
self, data: Optional[list] = None, codebook: Optional[dict[str, str]] = None
|
85
|
-
):
|
38
|
+
def __init__(self, data: Optional[list] = None, codebook: Optional[dict] = None):
|
86
39
|
"""Initialize the ScenarioList class."""
|
87
40
|
if data is not None:
|
88
41
|
super().__init__(data)
|
@@ -104,19 +57,12 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
104
57
|
"""Check if the ScenarioList has Jinja braces."""
|
105
58
|
return any([scenario.has_jinja_braces for scenario in self])
|
106
59
|
|
107
|
-
def
|
60
|
+
def convert_jinja_braces(self) -> ScenarioList:
|
108
61
|
"""Convert Jinja braces to Python braces."""
|
109
|
-
return ScenarioList([scenario.
|
110
|
-
|
111
|
-
def give_valid_names(self, existing_codebook: dict = None) -> ScenarioList:
|
112
|
-
"""Give valid names to the scenario keys, using an existing codebook if provided.
|
62
|
+
return ScenarioList([scenario.convert_jinja_braces() for scenario in self])
|
113
63
|
|
114
|
-
|
115
|
-
|
116
|
-
Defaults to None.
|
117
|
-
|
118
|
-
Returns:
|
119
|
-
ScenarioList: A new ScenarioList with valid variable names and updated codebook.
|
64
|
+
def give_valid_names(self) -> ScenarioList:
|
65
|
+
"""Give valid names to the scenario keys.
|
120
66
|
|
121
67
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
122
68
|
>>> s.give_valid_names()
|
@@ -124,38 +70,27 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
124
70
|
>>> s = ScenarioList([Scenario({'are you there John?': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
125
71
|
>>> s.give_valid_names()
|
126
72
|
ScenarioList([Scenario({'john': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
127
|
-
>>> s.give_valid_names({'are you there John?': 'custom_name'})
|
128
|
-
ScenarioList([Scenario({'custom_name': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
129
73
|
"""
|
130
|
-
codebook =
|
131
|
-
|
132
|
-
|
74
|
+
codebook = {}
|
75
|
+
new_scenaerios = []
|
133
76
|
for scenario in self:
|
134
77
|
new_scenario = {}
|
135
78
|
for key in scenario:
|
136
|
-
if is_valid_variable_name(key):
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
79
|
+
if not is_valid_variable_name(key):
|
80
|
+
if key in codebook:
|
81
|
+
new_key = codebook[key]
|
82
|
+
else:
|
83
|
+
new_key = sanitize_string(key)
|
84
|
+
if not is_valid_variable_name(new_key):
|
85
|
+
new_key = f"var_{len(codebook)}"
|
86
|
+
codebook[key] = new_key
|
87
|
+
new_scenario[new_key] = scenario[key]
|
142
88
|
else:
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
codebook[key] = new_key
|
147
|
-
|
148
|
-
new_scenario[new_key] = scenario[key]
|
149
|
-
|
150
|
-
new_scenarios.append(Scenario(new_scenario))
|
151
|
-
|
152
|
-
return ScenarioList(new_scenarios, codebook)
|
89
|
+
new_scenario[key] = scenario[key]
|
90
|
+
new_scenaerios.append(Scenario(new_scenario))
|
91
|
+
return ScenarioList(new_scenaerios, codebook)
|
153
92
|
|
154
|
-
def unpivot(
|
155
|
-
self,
|
156
|
-
id_vars: Optional[List[str]] = None,
|
157
|
-
value_vars: Optional[List[str]] = None,
|
158
|
-
) -> ScenarioList:
|
93
|
+
def unpivot(self, id_vars=None, value_vars=None):
|
159
94
|
"""
|
160
95
|
Unpivot the ScenarioList, allowing for id variables to be specified.
|
161
96
|
|
@@ -186,40 +121,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
186
121
|
|
187
122
|
return ScenarioList(new_scenarios)
|
188
123
|
|
189
|
-
def
|
190
|
-
"""Filter the ScenarioList based on a language predicate.
|
191
|
-
|
192
|
-
:param language_predicate: The language predicate to use.
|
193
|
-
|
194
|
-
Inspired by:
|
195
|
-
@misc{patel2024semanticoperators,
|
196
|
-
title={Semantic Operators: A Declarative Model for Rich, AI-based Analytics Over Text Data},
|
197
|
-
author={Liana Patel and Siddharth Jha and Parth Asawa and Melissa Pan and Carlos Guestrin and Matei Zaharia},
|
198
|
-
year={2024},
|
199
|
-
eprint={2407.11418},
|
200
|
-
archivePrefix={arXiv},
|
201
|
-
primaryClass={cs.DB},
|
202
|
-
url={https://arxiv.org/abs/2407.11418},
|
203
|
-
}
|
204
|
-
"""
|
205
|
-
from edsl import QuestionYesNo
|
206
|
-
|
207
|
-
new_scenario_list = self.duplicate()
|
208
|
-
q = QuestionYesNo(
|
209
|
-
question_text=language_predicate, question_name="binary_outcome"
|
210
|
-
)
|
211
|
-
results = q.by(new_scenario_list).run(verbose=False)
|
212
|
-
new_scenario_list = new_scenario_list.add_list(
|
213
|
-
"criteria", results.select("binary_outcome").to_list()
|
214
|
-
)
|
215
|
-
return new_scenario_list.filter("criteria == 'Yes'").drop("criteria")
|
216
|
-
|
217
|
-
def pivot(
|
218
|
-
self,
|
219
|
-
id_vars: List[str] = None,
|
220
|
-
var_name="variable",
|
221
|
-
value_name="value",
|
222
|
-
) -> ScenarioList:
|
124
|
+
def pivot(self, id_vars, var_name="variable", value_name="value"):
|
223
125
|
"""
|
224
126
|
Pivot the ScenarioList from long to wide format.
|
225
127
|
|
@@ -261,15 +163,15 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
261
163
|
|
262
164
|
return ScenarioList(pivoted_scenarios)
|
263
165
|
|
264
|
-
def group_by(
|
265
|
-
self, id_vars: List[str], variables: List[str], func: Callable
|
266
|
-
) -> ScenarioList:
|
166
|
+
def group_by(self, id_vars, variables, func):
|
267
167
|
"""
|
268
168
|
Group the ScenarioList by id_vars and apply a function to the specified variables.
|
269
169
|
|
270
|
-
:
|
271
|
-
|
272
|
-
|
170
|
+
Parameters:
|
171
|
+
id_vars (list): Fields to use as identifier variables for grouping
|
172
|
+
variables (list): Fields to pass to the aggregation function
|
173
|
+
func (callable): Function to apply to the grouped variables.
|
174
|
+
Should accept lists of values for each variable.
|
273
175
|
|
274
176
|
Returns:
|
275
177
|
ScenarioList: A new ScenarioList with the grouped and aggregated results
|
@@ -289,12 +191,12 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
289
191
|
# Check if the function is compatible with the specified variables
|
290
192
|
func_params = inspect.signature(func).parameters
|
291
193
|
if len(func_params) != len(variables):
|
292
|
-
raise
|
194
|
+
raise ValueError(
|
293
195
|
f"Function {func.__name__} expects {len(func_params)} arguments, but {len(variables)} variables were provided"
|
294
196
|
)
|
295
197
|
|
296
198
|
# Group the scenarios
|
297
|
-
grouped
|
199
|
+
grouped = defaultdict(lambda: defaultdict(list))
|
298
200
|
for scenario in self:
|
299
201
|
key = tuple(scenario[id_var] for id_var in id_vars)
|
300
202
|
for var in variables:
|
@@ -306,12 +208,10 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
306
208
|
try:
|
307
209
|
aggregated = func(*[group[var] for var in variables])
|
308
210
|
except Exception as e:
|
309
|
-
raise
|
211
|
+
raise ValueError(f"Error applying function to group {key}: {str(e)}")
|
310
212
|
|
311
213
|
if not isinstance(aggregated, dict):
|
312
|
-
raise
|
313
|
-
f"Function {func.__name__} must return a dictionary"
|
314
|
-
)
|
214
|
+
raise ValueError(f"Function {func.__name__} must return a dictionary")
|
315
215
|
|
316
216
|
new_scenario = dict(zip(id_vars, key))
|
317
217
|
new_scenario.update(aggregated)
|
@@ -378,18 +278,50 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
378
278
|
"""
|
379
279
|
return self.__mul__(other)
|
380
280
|
|
381
|
-
def shuffle(self, seed: Optional[str] =
|
281
|
+
def shuffle(self, seed: Optional[str] = "edsl") -> ScenarioList:
|
382
282
|
"""Shuffle the ScenarioList.
|
383
283
|
|
384
284
|
>>> s = ScenarioList.from_list("a", [1,2,3,4])
|
385
|
-
>>> s.shuffle(
|
386
|
-
ScenarioList([Scenario({'a':
|
285
|
+
>>> s.shuffle()
|
286
|
+
ScenarioList([Scenario({'a': 3}), Scenario({'a': 4}), Scenario({'a': 1}), Scenario({'a': 2})])
|
387
287
|
"""
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
288
|
+
random.seed(seed)
|
289
|
+
random.shuffle(self.data)
|
290
|
+
return self
|
291
|
+
|
292
|
+
def _repr_html_(self):
|
293
|
+
"""Return an HTML representation of the AgentList."""
|
294
|
+
# return (
|
295
|
+
# str(self.summary(format="html")) + "<br>" + str(self.table(tablefmt="html"))
|
296
|
+
# )
|
297
|
+
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
298
|
+
return str(self.summary(format="html")) + footer
|
299
|
+
|
300
|
+
# def _repr_html_(self) -> str:
|
301
|
+
# from edsl.utilities.utilities import data_to_html
|
302
|
+
|
303
|
+
# data = self.to_dict()
|
304
|
+
# _ = data.pop("edsl_version")
|
305
|
+
# _ = data.pop("edsl_class_name")
|
306
|
+
# for s in data["scenarios"]:
|
307
|
+
# _ = s.pop("edsl_version")
|
308
|
+
# _ = s.pop("edsl_class_name")
|
309
|
+
# for scenario in data["scenarios"]:
|
310
|
+
# for key, value in scenario.items():
|
311
|
+
# if hasattr(value, "to_dict"):
|
312
|
+
# data[key] = value.to_dict()
|
313
|
+
# return data_to_html(data)
|
314
|
+
|
315
|
+
# def tally(self, field) -> dict:
|
316
|
+
# """Return a tally of the values in the field.
|
317
|
+
|
318
|
+
# Example:
|
319
|
+
|
320
|
+
# >>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
321
|
+
# >>> s.tally('b')
|
322
|
+
# {1: 1, 2: 1}
|
323
|
+
# """
|
324
|
+
# return dict(Counter([scenario[field] for scenario in self]))
|
393
325
|
|
394
326
|
def sample(self, n: int, seed: Optional[str] = None) -> ScenarioList:
|
395
327
|
"""Return a random sample from the ScenarioList
|
@@ -401,22 +333,16 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
401
333
|
if seed:
|
402
334
|
random.seed(seed)
|
403
335
|
|
404
|
-
|
405
|
-
return ScenarioList(random.sample(sl.data, n))
|
336
|
+
return ScenarioList(random.sample(self.data, n))
|
406
337
|
|
407
|
-
def expand(self, expand_field: str, number_field
|
338
|
+
def expand(self, expand_field: str, number_field=False) -> ScenarioList:
|
408
339
|
"""Expand the ScenarioList by a field.
|
409
340
|
|
410
|
-
:param expand_field: The field to expand.
|
411
|
-
:param number_field: Whether to add a field with the index of the value
|
412
|
-
|
413
341
|
Example:
|
414
342
|
|
415
343
|
>>> s = ScenarioList( [ Scenario({'a':1, 'b':[1,2]}) ] )
|
416
344
|
>>> s.expand('b')
|
417
345
|
ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
418
|
-
>>> s.expand('b', number_field=True)
|
419
|
-
ScenarioList([Scenario({'a': 1, 'b': 1, 'b_number': 1}), Scenario({'a': 1, 'b': 2, 'b_number': 2})])
|
420
346
|
"""
|
421
347
|
new_scenarios = []
|
422
348
|
for scenario in self:
|
@@ -431,11 +357,12 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
431
357
|
new_scenarios.append(new_scenario)
|
432
358
|
return ScenarioList(new_scenarios)
|
433
359
|
|
434
|
-
def concatenate(self, fields: List[str], separator: str = ";") -> ScenarioList:
|
360
|
+
def concatenate(self, fields: List[str], separator: str = ";") -> "ScenarioList":
|
435
361
|
"""Concatenate specified fields into a single field.
|
436
362
|
|
437
|
-
:
|
438
|
-
|
363
|
+
Args:
|
364
|
+
fields (List[str]): List of field names to concatenate.
|
365
|
+
separator (str, optional): Separator to use between field values. Defaults to ";".
|
439
366
|
|
440
367
|
Returns:
|
441
368
|
ScenarioList: A new ScenarioList with concatenated fields.
|
@@ -465,17 +392,11 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
465
392
|
) -> ScenarioList:
|
466
393
|
"""Unpack a dictionary field into separate fields.
|
467
394
|
|
468
|
-
:param field: The field to unpack.
|
469
|
-
:param prefix: An optional prefix to add to the new fields.
|
470
|
-
:param drop_field: Whether to drop the original field.
|
471
|
-
|
472
395
|
Example:
|
473
396
|
|
474
397
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': {'c': 2, 'd': 3}})])
|
475
398
|
>>> s.unpack_dict('b')
|
476
399
|
ScenarioList([Scenario({'a': 1, 'b': {'c': 2, 'd': 3}, 'c': 2, 'd': 3})])
|
477
|
-
>>> s.unpack_dict('b', prefix='new_')
|
478
|
-
ScenarioList([Scenario({'a': 1, 'b': {'c': 2, 'd': 3}, 'new_c': 2, 'new_d': 3})])
|
479
400
|
"""
|
480
401
|
new_scenarios = []
|
481
402
|
for scenario in self:
|
@@ -493,17 +414,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
493
414
|
def transform(
|
494
415
|
self, field: str, func: Callable, new_name: Optional[str] = None
|
495
416
|
) -> ScenarioList:
|
496
|
-
"""Transform a field using a function.
|
497
|
-
|
498
|
-
:param field: The field to transform.
|
499
|
-
:param func: The function to apply to the field.
|
500
|
-
:param new_name: An optional new name for the transformed field.
|
501
|
-
|
502
|
-
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
503
|
-
>>> s.transform('b', lambda x: x + 1)
|
504
|
-
ScenarioList([Scenario({'a': 1, 'b': 3}), Scenario({'a': 1, 'b': 2})])
|
505
|
-
|
506
|
-
"""
|
417
|
+
"""Transform a field using a function."""
|
507
418
|
new_scenarios = []
|
508
419
|
for scenario in self:
|
509
420
|
new_scenario = scenario.copy()
|
@@ -517,9 +428,6 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
517
428
|
"""
|
518
429
|
Return a new ScenarioList with a new variable added.
|
519
430
|
|
520
|
-
:param new_var_string: A string with the new variable assignment.
|
521
|
-
:param functions_dict: A dictionary of functions to use in the assignment.
|
522
|
-
|
523
431
|
Example:
|
524
432
|
|
525
433
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
@@ -528,7 +436,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
528
436
|
|
529
437
|
"""
|
530
438
|
if "=" not in new_var_string:
|
531
|
-
raise
|
439
|
+
raise Exception(
|
532
440
|
f"Mutate requires an '=' in the string, but '{new_var_string}' doesn't have one."
|
533
441
|
)
|
534
442
|
raw_var_name, expression = new_var_string.split("=", 1)
|
@@ -536,7 +444,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
536
444
|
from edsl.utilities.utilities import is_valid_variable_name
|
537
445
|
|
538
446
|
if not is_valid_variable_name(var_name):
|
539
|
-
raise
|
447
|
+
raise Exception(f"{var_name} is not a valid variable name.")
|
540
448
|
|
541
449
|
# create the evaluator
|
542
450
|
functions_dict = functions_dict or {}
|
@@ -554,15 +462,13 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
554
462
|
try:
|
555
463
|
new_data = [new_scenario(s, var_name) for s in self]
|
556
464
|
except Exception as e:
|
557
|
-
raise
|
465
|
+
raise Exception(f"Error in mutate. Exception:{e}")
|
558
466
|
|
559
467
|
return ScenarioList(new_data)
|
560
468
|
|
561
469
|
def order_by(self, *fields: str, reverse: bool = False) -> ScenarioList:
|
562
470
|
"""Order the scenarios by one or more fields.
|
563
471
|
|
564
|
-
:param fields: The fields to order by.
|
565
|
-
:param reverse: Whether to reverse the order.
|
566
472
|
Example:
|
567
473
|
|
568
474
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 1, 'b': 1})])
|
@@ -575,41 +481,16 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
575
481
|
|
576
482
|
return ScenarioList(sorted(self, key=get_sort_key, reverse=reverse))
|
577
483
|
|
578
|
-
def duplicate(self) -> ScenarioList:
|
579
|
-
"""Return a copy of the ScenarioList.
|
580
|
-
|
581
|
-
>>> sl = ScenarioList.example()
|
582
|
-
>>> sl_copy = sl.duplicate()
|
583
|
-
>>> sl == sl_copy
|
584
|
-
True
|
585
|
-
>>> sl is sl_copy
|
586
|
-
False
|
587
|
-
"""
|
588
|
-
return ScenarioList([scenario.copy() for scenario in self])
|
589
|
-
|
590
484
|
def filter(self, expression: str) -> ScenarioList:
|
591
485
|
"""
|
592
486
|
Filter a list of scenarios based on an expression.
|
593
487
|
|
594
|
-
:param expression: The expression to filter by.
|
595
|
-
|
596
488
|
Example:
|
597
489
|
|
598
490
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
599
491
|
>>> s.filter("b == 2")
|
600
492
|
ScenarioList([Scenario({'a': 1, 'b': 2})])
|
601
493
|
"""
|
602
|
-
sl = self.duplicate()
|
603
|
-
base_keys = set(self[0].keys())
|
604
|
-
keys = set()
|
605
|
-
for scenario in sl:
|
606
|
-
keys.update(scenario.keys())
|
607
|
-
if keys != base_keys:
|
608
|
-
import warnings
|
609
|
-
|
610
|
-
warnings.warn(
|
611
|
-
"Ragged ScenarioList detected (different keys for different scenario entries). This may cause unexpected behavior."
|
612
|
-
)
|
613
494
|
|
614
495
|
def create_evaluator(scenario: Scenario):
|
615
496
|
"""Create an evaluator for the given result.
|
@@ -619,22 +500,14 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
619
500
|
|
620
501
|
try:
|
621
502
|
# iterates through all the results and evaluates the expression
|
622
|
-
new_data = [
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
available_fields = ", ".join(self.data[0].keys() if self.data else [])
|
628
|
-
raise ScenarioError(
|
629
|
-
f"Error in filter: '{e}'\n"
|
630
|
-
f"The expression '{expression}' refers to a field that does not exist.\n"
|
631
|
-
f"Scenario: {scenario}\n"
|
632
|
-
f"Available fields: {available_fields}\n"
|
633
|
-
"Check your filter expression or consult the documentation: "
|
634
|
-
"https://docs.expectedparrot.com/en/latest/scenarios.html#module-edsl.scenarios.Scenario"
|
635
|
-
) from None
|
503
|
+
new_data = [
|
504
|
+
scenario
|
505
|
+
for scenario in self.data
|
506
|
+
if create_evaluator(scenario).eval(expression)
|
507
|
+
]
|
636
508
|
except Exception as e:
|
637
|
-
|
509
|
+
print(f"Exception:{e}")
|
510
|
+
raise Exception(f"Error in filter. Exception:{e}")
|
638
511
|
|
639
512
|
return ScenarioList(new_data)
|
640
513
|
|
@@ -646,26 +519,30 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
646
519
|
:param urls: A list of URLs.
|
647
520
|
:param field_name: The name of the field to store the text from the URLs.
|
648
521
|
|
522
|
+
|
649
523
|
"""
|
650
524
|
return ScenarioList([Scenario.from_url(url, field_name) for url in urls])
|
651
525
|
|
652
|
-
def select(self, *fields
|
526
|
+
def select(self, *fields) -> ScenarioList:
|
653
527
|
"""
|
654
528
|
Selects scenarios with only the references fields.
|
655
529
|
|
656
|
-
:param fields: The fields to select.
|
657
|
-
|
658
530
|
Example:
|
659
531
|
|
660
532
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
661
533
|
>>> s.select('a')
|
662
534
|
ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
|
663
535
|
"""
|
664
|
-
|
536
|
+
if len(fields) == 1:
|
537
|
+
fields_to_select = [list(fields)[0]]
|
538
|
+
else:
|
539
|
+
fields_to_select = list(fields)
|
665
540
|
|
666
|
-
return
|
541
|
+
return ScenarioList(
|
542
|
+
[scenario.select(fields_to_select) for scenario in self.data]
|
543
|
+
)
|
667
544
|
|
668
|
-
def drop(self, *fields
|
545
|
+
def drop(self, *fields) -> ScenarioList:
|
669
546
|
"""Drop fields from the scenarios.
|
670
547
|
|
671
548
|
Example:
|
@@ -674,22 +551,18 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
674
551
|
>>> s.drop('a')
|
675
552
|
ScenarioList([Scenario({'b': 1}), Scenario({'b': 2})])
|
676
553
|
"""
|
677
|
-
|
678
|
-
return ScenarioList([scenario.drop(fields) for scenario in sl])
|
554
|
+
return ScenarioList([scenario.drop(fields) for scenario in self.data])
|
679
555
|
|
680
|
-
def keep(self, *fields
|
556
|
+
def keep(self, *fields) -> ScenarioList:
|
681
557
|
"""Keep only the specified fields in the scenarios.
|
682
558
|
|
683
|
-
:param fields: The fields to keep.
|
684
|
-
|
685
559
|
Example:
|
686
560
|
|
687
561
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
688
562
|
>>> s.keep('a')
|
689
563
|
ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
|
690
564
|
"""
|
691
|
-
|
692
|
-
return ScenarioList([scenario.keep(fields) for scenario in sl])
|
565
|
+
return ScenarioList([scenario.keep(fields) for scenario in self.data])
|
693
566
|
|
694
567
|
@classmethod
|
695
568
|
def from_list(
|
@@ -697,10 +570,6 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
697
570
|
) -> ScenarioList:
|
698
571
|
"""Create a ScenarioList from a list of values.
|
699
572
|
|
700
|
-
:param name: The name of the field.
|
701
|
-
:param values: The list of values.
|
702
|
-
:param func: An optional function to apply to the values.
|
703
|
-
|
704
573
|
Example:
|
705
574
|
|
706
575
|
>>> ScenarioList.from_list('name', ['Alice', 'Bob'])
|
@@ -710,12 +579,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
710
579
|
func = lambda x: x
|
711
580
|
return cls([Scenario({name: func(value)}) for value in values])
|
712
581
|
|
713
|
-
def table(
|
714
|
-
self,
|
715
|
-
*fields: str,
|
716
|
-
tablefmt: Optional[TableFormat] = None,
|
717
|
-
pretty_labels: Optional[dict[str, str]] = None,
|
718
|
-
) -> str:
|
582
|
+
def table(self, *fields, tablefmt=None, pretty_labels=None) -> str:
|
719
583
|
"""Return the ScenarioList as a table."""
|
720
584
|
|
721
585
|
from tabulate import tabulate_formats
|
@@ -730,41 +594,26 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
730
594
|
)
|
731
595
|
|
732
596
|
def tree(self, node_list: Optional[List[str]] = None) -> str:
|
733
|
-
"""Return the ScenarioList as a tree.
|
734
|
-
|
735
|
-
:param node_list: The list of nodes to include in the tree.
|
736
|
-
"""
|
597
|
+
"""Return the ScenarioList as a tree."""
|
737
598
|
return self.to_dataset().tree(node_list)
|
738
599
|
|
739
|
-
def _summary(self)
|
740
|
-
"""Return a summary of the ScenarioList.
|
741
|
-
|
742
|
-
>>> ScenarioList.example()._summary()
|
743
|
-
{'scenarios': 2, 'keys': ['persona']}
|
744
|
-
"""
|
600
|
+
def _summary(self):
|
745
601
|
d = {
|
746
|
-
"
|
747
|
-
"
|
602
|
+
"EDSL Class name": "ScenarioList",
|
603
|
+
"# Scenarios": len(self),
|
604
|
+
"Scenario Keys": list(self.parameters),
|
748
605
|
}
|
749
606
|
return d
|
750
607
|
|
751
|
-
def reorder_keys(self, new_order
|
608
|
+
def reorder_keys(self, new_order):
|
752
609
|
"""Reorder the keys in the scenarios.
|
753
610
|
|
754
|
-
:param new_order: The new order of the keys.
|
755
|
-
|
756
611
|
Example:
|
757
612
|
|
758
613
|
>>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 3, 'b': 4})])
|
759
614
|
>>> s.reorder_keys(['b', 'a'])
|
760
615
|
ScenarioList([Scenario({'b': 2, 'a': 1}), Scenario({'b': 4, 'a': 3})])
|
761
|
-
>>> s.reorder_keys(['a', 'b', 'c'])
|
762
|
-
Traceback (most recent call last):
|
763
|
-
...
|
764
|
-
AssertionError
|
765
616
|
"""
|
766
|
-
assert set(new_order) == set(self.parameters)
|
767
|
-
|
768
617
|
new_scenarios = []
|
769
618
|
for scenario in self:
|
770
619
|
new_scenario = Scenario({key: scenario[key] for key in new_order})
|
@@ -773,8 +622,6 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
773
622
|
|
774
623
|
def to_dataset(self) -> "Dataset":
|
775
624
|
"""
|
776
|
-
Convert the ScenarioList to a Dataset.
|
777
|
-
|
778
625
|
>>> s = ScenarioList.from_list("a", [1,2,3])
|
779
626
|
>>> s.to_dataset()
|
780
627
|
Dataset([{'a': [1, 2, 3]}])
|
@@ -784,14 +631,8 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
784
631
|
"""
|
785
632
|
from edsl.results.Dataset import Dataset
|
786
633
|
|
787
|
-
keys =
|
788
|
-
for scenario in self
|
789
|
-
new_keys = list(scenario.keys())
|
790
|
-
if new_keys != keys:
|
791
|
-
keys = list(set(keys + new_keys))
|
792
|
-
data = [
|
793
|
-
{key: [scenario.get(key, None) for scenario in self.data]} for key in keys
|
794
|
-
]
|
634
|
+
keys = self[0].keys()
|
635
|
+
data = [{key: [scenario[key] for scenario in self.data]} for key in keys]
|
795
636
|
return Dataset(data)
|
796
637
|
|
797
638
|
def unpack(
|
@@ -823,14 +664,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
823
664
|
new_scenarios.append(new_scenario)
|
824
665
|
return ScenarioList(new_scenarios)
|
825
666
|
|
826
|
-
|
827
|
-
def from_list_of_tuples(self, *names: str, values: List[Tuple]) -> ScenarioList:
|
828
|
-
sl = ScenarioList.from_list(names[0], [value[0] for value in values])
|
829
|
-
for index, name in enumerate(names[1:]):
|
830
|
-
sl = sl.add_list(name, [value[index + 1] for value in values])
|
831
|
-
return sl
|
832
|
-
|
833
|
-
def add_list(self, name: str, values: List[Any]) -> ScenarioList:
|
667
|
+
def add_list(self, name, values) -> ScenarioList:
|
834
668
|
"""Add a list of values to a ScenarioList.
|
835
669
|
|
836
670
|
Example:
|
@@ -839,25 +673,12 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
839
673
|
>>> s.add_list('age', [30, 25])
|
840
674
|
ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
|
841
675
|
"""
|
842
|
-
sl = self.duplicate()
|
843
|
-
if len(values) != len(sl):
|
844
|
-
raise ScenarioError(
|
845
|
-
f"Length of values ({len(values)}) does not match length of ScenarioList ({len(sl)})"
|
846
|
-
)
|
847
676
|
for i, value in enumerate(values):
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
"""Create an empty ScenarioList with n scenarios.
|
854
|
-
|
855
|
-
Example:
|
856
|
-
|
857
|
-
>>> ScenarioList.create_empty_scenario_list(3)
|
858
|
-
ScenarioList([Scenario({}), Scenario({}), Scenario({})])
|
859
|
-
"""
|
860
|
-
return ScenarioList([Scenario({}) for _ in range(n)])
|
677
|
+
if i < len(self):
|
678
|
+
self[i][name] = value
|
679
|
+
else:
|
680
|
+
self.append(Scenario({name: value}))
|
681
|
+
return self
|
861
682
|
|
862
683
|
def add_value(self, name: str, value: Any) -> ScenarioList:
|
863
684
|
"""Add a value to all scenarios in a ScenarioList.
|
@@ -868,16 +689,13 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
868
689
|
>>> s.add_value('age', 30)
|
869
690
|
ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 30})])
|
870
691
|
"""
|
871
|
-
|
872
|
-
for scenario in sl:
|
692
|
+
for scenario in self:
|
873
693
|
scenario[name] = value
|
874
|
-
return
|
694
|
+
return self
|
875
695
|
|
876
696
|
def rename(self, replacement_dict: dict) -> ScenarioList:
|
877
697
|
"""Rename the fields in the scenarios.
|
878
698
|
|
879
|
-
:param replacement_dict: A dictionary with the old names as keys and the new names as values.
|
880
|
-
|
881
699
|
Example:
|
882
700
|
|
883
701
|
>>> s = ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
|
@@ -892,26 +710,8 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
892
710
|
new_list.append(new_obj)
|
893
711
|
return new_list
|
894
712
|
|
895
|
-
## NEEDS TO BE FIXED
|
896
|
-
# def new_column_names(self, new_names: List[str]) -> ScenarioList:
|
897
|
-
# """Rename the fields in the scenarios.
|
898
|
-
|
899
|
-
# Example:
|
900
|
-
|
901
|
-
# >>> s = ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
|
902
|
-
# >>> s.new_column_names(['first_name', 'years'])
|
903
|
-
# ScenarioList([Scenario({'first_name': 'Alice', 'years': 30}), Scenario({'first_name': 'Bob', 'years': 25})])
|
904
|
-
|
905
|
-
# """
|
906
|
-
# new_list = ScenarioList([])
|
907
|
-
# for obj in self:
|
908
|
-
# new_obj = obj.new_column_names(new_names)
|
909
|
-
# new_list.append(new_obj)
|
910
|
-
# return new_list
|
911
|
-
|
912
713
|
@classmethod
|
913
714
|
def from_sqlite(cls, filepath: str, table: str):
|
914
|
-
"""Create a ScenarioList from a SQLite database."""
|
915
715
|
import sqlite3
|
916
716
|
|
917
717
|
with sqlite3.connect(filepath) as conn:
|
@@ -1055,9 +855,6 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1055
855
|
def to_key_value(self, field: str, value=None) -> Union[dict, set]:
|
1056
856
|
"""Return the set of values in the field.
|
1057
857
|
|
1058
|
-
:param field: The field to extract values from.
|
1059
|
-
:param value: An optional field to use as the value in the key-value pair.
|
1060
|
-
|
1061
858
|
Example:
|
1062
859
|
|
1063
860
|
>>> s = ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
|
@@ -1177,42 +974,56 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1177
974
|
|
1178
975
|
@classmethod
|
1179
976
|
def from_delimited_file(
|
1180
|
-
cls, source: Union[str,
|
977
|
+
cls, source: Union[str, urllib.parse.ParseResult], delimiter: str = ","
|
1181
978
|
) -> ScenarioList:
|
1182
|
-
"""Create a ScenarioList from a delimited file (CSV/TSV) or URL.
|
1183
|
-
import requests
|
1184
|
-
from edsl.scenarios.Scenario import Scenario
|
1185
|
-
from urllib.parse import urlparse
|
1186
|
-
from urllib.parse import ParseResult
|
979
|
+
"""Create a ScenarioList from a delimited file (CSV/TSV) or URL.
|
1187
980
|
|
1188
|
-
|
1189
|
-
|
1190
|
-
|
1191
|
-
|
981
|
+
Args:
|
982
|
+
source: A string representing either a local file path or a URL to a delimited file,
|
983
|
+
or a urllib.parse.ParseResult object for a URL.
|
984
|
+
delimiter: The delimiter used in the file. Defaults to ',' for CSV files.
|
985
|
+
Use '\t' for TSV files.
|
986
|
+
|
987
|
+
Returns:
|
988
|
+
ScenarioList: A ScenarioList object containing the data from the file.
|
989
|
+
|
990
|
+
Example:
|
991
|
+
# For CSV files
|
992
|
+
|
993
|
+
>>> with open('data.csv', 'w') as f:
|
994
|
+
... _ = f.write('name,age\\nAlice,30\\nBob,25\\n')
|
995
|
+
>>> scenario_list = ScenarioList.from_delimited_file('data.csv')
|
996
|
+
|
997
|
+
# For TSV files
|
998
|
+
>>> with open('data.tsv', 'w') as f:
|
999
|
+
... _ = f.write('name\\tage\\nAlice\t30\\nBob\t25\\n')
|
1000
|
+
>>> scenario_list = ScenarioList.from_delimited_file('data.tsv', delimiter='\\t')
|
1001
|
+
|
1002
|
+
"""
|
1003
|
+
from edsl.scenarios.Scenario import Scenario
|
1192
1004
|
|
1193
1005
|
def is_url(source):
|
1194
1006
|
try:
|
1195
|
-
result = urlparse(source)
|
1007
|
+
result = urllib.parse.urlparse(source)
|
1196
1008
|
return all([result.scheme, result.netloc])
|
1197
1009
|
except ValueError:
|
1198
1010
|
return False
|
1199
1011
|
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1204
|
-
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1210
|
-
file_obj = open(source, "r")
|
1012
|
+
if isinstance(source, str) and is_url(source):
|
1013
|
+
with urllib.request.urlopen(source) as response:
|
1014
|
+
file_content = response.read().decode("utf-8")
|
1015
|
+
file_obj = StringIO(file_content)
|
1016
|
+
elif isinstance(source, urllib.parse.ParseResult):
|
1017
|
+
with urllib.request.urlopen(source.geturl()) as response:
|
1018
|
+
file_content = response.read().decode("utf-8")
|
1019
|
+
file_obj = StringIO(file_content)
|
1020
|
+
else:
|
1021
|
+
file_obj = open(source, "r")
|
1211
1022
|
|
1023
|
+
try:
|
1212
1024
|
reader = csv.reader(file_obj, delimiter=delimiter)
|
1213
1025
|
header = next(reader)
|
1214
1026
|
observations = [Scenario(dict(zip(header, row))) for row in reader]
|
1215
|
-
|
1216
1027
|
finally:
|
1217
1028
|
file_obj.close()
|
1218
1029
|
|
@@ -1220,7 +1031,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1220
1031
|
|
1221
1032
|
# Convenience methods for specific file types
|
1222
1033
|
@classmethod
|
1223
|
-
def from_csv(cls, source: Union[str,
|
1034
|
+
def from_csv(cls, source: Union[str, urllib.parse.ParseResult]) -> ScenarioList:
|
1224
1035
|
"""Create a ScenarioList from a CSV file or URL."""
|
1225
1036
|
return cls.from_delimited_file(source, delimiter=",")
|
1226
1037
|
|
@@ -1237,17 +1048,75 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1237
1048
|
>>> s3 == ScenarioList([Scenario({'age': 30, 'location': 'New York', 'name': 'Alice'}), Scenario({'age': 25, 'location': None, 'name': 'Bob'})])
|
1238
1049
|
True
|
1239
1050
|
"""
|
1240
|
-
from edsl.scenarios.
|
1051
|
+
from edsl.scenarios.ScenarioJoin import ScenarioJoin
|
1241
1052
|
|
1242
1053
|
sj = ScenarioJoin(self, other)
|
1243
1054
|
return sj.left_join(by)
|
1055
|
+
# # Validate join keys
|
1056
|
+
# if not by:
|
1057
|
+
# raise ValueError(
|
1058
|
+
# "Join keys cannot be empty. Please specify at least one key to join on."
|
1059
|
+
# )
|
1060
|
+
|
1061
|
+
# # Convert single string to list for consistent handling
|
1062
|
+
# by_keys = [by] if isinstance(by, str) else by
|
1063
|
+
|
1064
|
+
# # Verify all join keys exist in both ScenarioLists
|
1065
|
+
# left_keys = set(next(iter(self)).keys()) if self else set()
|
1066
|
+
# right_keys = set(next(iter(other)).keys()) if other else set()
|
1067
|
+
|
1068
|
+
# missing_left = set(by_keys) - left_keys
|
1069
|
+
# missing_right = set(by_keys) - right_keys
|
1070
|
+
# if missing_left or missing_right:
|
1071
|
+
# missing = missing_left | missing_right
|
1072
|
+
# raise ValueError(f"Join key(s) {missing} not found in both ScenarioLists")
|
1073
|
+
|
1074
|
+
# # Create lookup dictionary from the other ScenarioList
|
1075
|
+
# def get_key_tuple(scenario: Scenario, keys: list[str]) -> tuple:
|
1076
|
+
# return tuple(scenario[k] for k in keys)
|
1077
|
+
|
1078
|
+
# other_dict = {get_key_tuple(scenario, by_keys): scenario for scenario in other}
|
1079
|
+
|
1080
|
+
# # Collect all possible keys (like SQL combining all columns)
|
1081
|
+
# all_keys = set()
|
1082
|
+
# for scenario in self:
|
1083
|
+
# all_keys.update(scenario.keys())
|
1084
|
+
# for scenario in other:
|
1085
|
+
# all_keys.update(scenario.keys())
|
1086
|
+
|
1087
|
+
# new_scenarios = []
|
1088
|
+
# for scenario in self:
|
1089
|
+
# new_scenario = {
|
1090
|
+
# key: None for key in all_keys
|
1091
|
+
# } # Start with nulls (like SQL)
|
1092
|
+
# new_scenario.update(scenario) # Add all left values
|
1093
|
+
|
1094
|
+
# key_tuple = get_key_tuple(scenario, by_keys)
|
1095
|
+
# if matching_scenario := other_dict.get(key_tuple):
|
1096
|
+
# # Check for overlapping keys with different values
|
1097
|
+
# overlapping_keys = set(scenario.keys()) & set(matching_scenario.keys())
|
1098
|
+
# for key in overlapping_keys:
|
1099
|
+
# if key not in by_keys and scenario[key] != matching_scenario[key]:
|
1100
|
+
# join_conditions = [f"{k}='{scenario[k]}'" for k in by_keys]
|
1101
|
+
# print(
|
1102
|
+
# f"Warning: Conflicting values for key '{key}' where {' AND '.join(join_conditions)}. "
|
1103
|
+
# f"Keeping left value: {scenario[key]} (discarding: {matching_scenario[key]})"
|
1104
|
+
# )
|
1105
|
+
|
1106
|
+
# # Only update with non-overlapping keys from matching scenario
|
1107
|
+
# new_keys = set(matching_scenario.keys()) - set(scenario.keys())
|
1108
|
+
# new_scenario.update({k: matching_scenario[k] for k in new_keys})
|
1109
|
+
|
1110
|
+
# new_scenarios.append(Scenario(new_scenario))
|
1111
|
+
|
1112
|
+
# return ScenarioList(new_scenarios)
|
1244
1113
|
|
1245
1114
|
@classmethod
|
1246
|
-
def from_tsv(cls, source: Union[str,
|
1115
|
+
def from_tsv(cls, source: Union[str, urllib.parse.ParseResult]) -> ScenarioList:
|
1247
1116
|
"""Create a ScenarioList from a TSV file or URL."""
|
1248
1117
|
return cls.from_delimited_file(source, delimiter="\t")
|
1249
1118
|
|
1250
|
-
def to_dict(self, sort
|
1119
|
+
def to_dict(self, sort=False, add_edsl_version=True) -> dict:
|
1251
1120
|
"""
|
1252
1121
|
>>> s = ScenarioList([Scenario({'food': 'wood chips'}), Scenario({'food': 'wood-fired pizza'})])
|
1253
1122
|
>>> s.to_dict()
|
@@ -1259,7 +1128,6 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1259
1128
|
else:
|
1260
1129
|
data = self
|
1261
1130
|
d = {"scenarios": [s.to_dict(add_edsl_version=add_edsl_version) for s in data]}
|
1262
|
-
|
1263
1131
|
if add_edsl_version:
|
1264
1132
|
from edsl import __version__
|
1265
1133
|
|
@@ -1267,27 +1135,6 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1267
1135
|
d["edsl_class_name"] = self.__class__.__name__
|
1268
1136
|
return d
|
1269
1137
|
|
1270
|
-
def to(self, survey: Union["Survey", "QuestionBase"]) -> "Jobs":
|
1271
|
-
"""Create a Jobs object from a ScenarioList and a Survey object.
|
1272
|
-
|
1273
|
-
:param survey: The Survey object to use for the Jobs object.
|
1274
|
-
|
1275
|
-
Example:
|
1276
|
-
>>> from edsl import Survey
|
1277
|
-
>>> from edsl.jobs.Jobs import Jobs
|
1278
|
-
>>> from edsl import ScenarioList
|
1279
|
-
>>> isinstance(ScenarioList.example().to(Survey.example()), Jobs)
|
1280
|
-
True
|
1281
|
-
"""
|
1282
|
-
from edsl.surveys.Survey import Survey
|
1283
|
-
from edsl.questions.QuestionBase import QuestionBase
|
1284
|
-
from edsl.jobs.Jobs import Jobs
|
1285
|
-
|
1286
|
-
if isinstance(survey, QuestionBase):
|
1287
|
-
return Survey([survey]).by(self)
|
1288
|
-
else:
|
1289
|
-
return survey.by(self)
|
1290
|
-
|
1291
1138
|
@classmethod
|
1292
1139
|
def gen(cls, scenario_dicts_list: List[dict]) -> ScenarioList:
|
1293
1140
|
"""Create a `ScenarioList` from a list of dictionaries.
|
@@ -1312,25 +1159,16 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1312
1159
|
|
1313
1160
|
@classmethod
|
1314
1161
|
def from_nested_dict(cls, data: dict) -> ScenarioList:
|
1315
|
-
"""Create a `ScenarioList` from a nested dictionary.
|
1316
|
-
|
1317
|
-
>>> data = {"headline": ["Armistice Signed, War Over: Celebrations Erupt Across City"], "date": ["1918-11-11"], "author": ["Jane Smith"]}
|
1318
|
-
>>> ScenarioList.from_nested_dict(data)
|
1319
|
-
ScenarioList([Scenario({'headline': 'Armistice Signed, War Over: Celebrations Erupt Across City', 'date': '1918-11-11', 'author': 'Jane Smith'})])
|
1320
|
-
|
1321
|
-
"""
|
1322
|
-
length_of_first_list = len(next(iter(data.values())))
|
1323
|
-
s = ScenarioList.create_empty_scenario_list(n=length_of_first_list)
|
1162
|
+
"""Create a `ScenarioList` from a nested dictionary."""
|
1163
|
+
from edsl.scenarios.Scenario import Scenario
|
1324
1164
|
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
)
|
1329
|
-
for key, list_of_values in data.items():
|
1330
|
-
s = s.add_list(key, list_of_values)
|
1165
|
+
s = ScenarioList()
|
1166
|
+
for key, value in data.items():
|
1167
|
+
s.add_list(key, value)
|
1331
1168
|
return s
|
1332
1169
|
|
1333
1170
|
def code(self) -> str:
|
1171
|
+
## TODO: Refactor to only use the questions actually in the survey
|
1334
1172
|
"""Create the Python code representation of a survey."""
|
1335
1173
|
header_lines = [
|
1336
1174
|
"from edsl.scenarios.Scenario import Scenario",
|
@@ -1353,16 +1191,16 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1353
1191
|
"""
|
1354
1192
|
return cls([Scenario.example(randomize), Scenario.example(randomize)])
|
1355
1193
|
|
1356
|
-
|
1357
|
-
|
1358
|
-
|
1194
|
+
def rich_print(self) -> None:
|
1195
|
+
"""Display an object as a table."""
|
1196
|
+
from rich.table import Table
|
1359
1197
|
|
1360
|
-
|
1361
|
-
|
1362
|
-
|
1363
|
-
|
1364
|
-
|
1365
|
-
|
1198
|
+
table = Table(title="ScenarioList")
|
1199
|
+
table.add_column("Index", style="bold")
|
1200
|
+
table.add_column("Scenario")
|
1201
|
+
for i, s in enumerate(self):
|
1202
|
+
table.add_row(str(i), s.rich_print())
|
1203
|
+
return table
|
1366
1204
|
|
1367
1205
|
def __getitem__(self, key: Union[int, slice]) -> Any:
|
1368
1206
|
"""Return the item at the given index.
|
@@ -1408,18 +1246,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1408
1246
|
f"The 'name' field is reserved for the agent's name---putting this value in {proposed_agent_name}"
|
1409
1247
|
)
|
1410
1248
|
new_scenario[proposed_agent_name] = name
|
1411
|
-
|
1412
|
-
if "agent_parameters" in new_scenario:
|
1413
|
-
agent_parameters = new_scenario.pop("agent_parameters")
|
1414
|
-
instruction = agent_parameters.get("instruction", None)
|
1415
|
-
name = agent_parameters.get("name", None)
|
1416
|
-
new_agent = Agent(
|
1417
|
-
traits=new_scenario, name=name, instruction=instruction
|
1418
|
-
)
|
1249
|
+
agents.append(Agent(traits=new_scenario, name=name))
|
1419
1250
|
else:
|
1420
|
-
|
1421
|
-
|
1422
|
-
agents.append(new_agent)
|
1251
|
+
agents.append(Agent(traits=new_scenario))
|
1423
1252
|
|
1424
1253
|
return AgentList(agents)
|
1425
1254
|
|