edsl 0.1.27.dev2__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +99 -22
- edsl/BaseDiff.py +260 -0
- edsl/__init__.py +4 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +26 -5
- edsl/agents/AgentList.py +62 -7
- edsl/agents/Invigilator.py +4 -9
- edsl/agents/InvigilatorBase.py +5 -5
- edsl/agents/descriptors.py +3 -1
- edsl/conjure/AgentConstructionMixin.py +152 -0
- edsl/conjure/Conjure.py +56 -0
- edsl/conjure/InputData.py +628 -0
- edsl/conjure/InputDataCSV.py +48 -0
- edsl/conjure/InputDataMixinQuestionStats.py +182 -0
- edsl/conjure/InputDataPyRead.py +91 -0
- edsl/conjure/InputDataSPSS.py +8 -0
- edsl/conjure/InputDataStata.py +8 -0
- edsl/conjure/QuestionOptionMixin.py +76 -0
- edsl/conjure/QuestionTypeMixin.py +23 -0
- edsl/conjure/RawQuestion.py +65 -0
- edsl/conjure/SurveyResponses.py +7 -0
- edsl/conjure/__init__.py +9 -4
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/naming_utilities.py +263 -0
- edsl/conjure/utilities.py +165 -28
- edsl/conversation/Conversation.py +238 -0
- edsl/conversation/car_buying.py +58 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/coop.py +191 -12
- edsl/coop/utils.py +20 -2
- edsl/data/Cache.py +55 -17
- edsl/data/CacheHandler.py +10 -9
- edsl/inference_services/AnthropicService.py +1 -0
- edsl/inference_services/DeepInfraService.py +20 -13
- edsl/inference_services/GoogleService.py +7 -1
- edsl/inference_services/InferenceServicesCollection.py +33 -7
- edsl/inference_services/OpenAIService.py +17 -10
- edsl/inference_services/models_available_cache.py +69 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/Jobs.py +240 -36
- edsl/jobs/buckets/BucketCollection.py +9 -3
- edsl/jobs/interviews/Interview.py +4 -1
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +24 -10
- edsl/jobs/interviews/retry_management.py +4 -4
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -45
- edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
- edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
- edsl/language_models/LanguageModel.py +37 -44
- edsl/language_models/ModelList.py +96 -0
- edsl/language_models/registry.py +14 -0
- edsl/language_models/repair.py +95 -24
- edsl/notebooks/Notebook.py +119 -31
- edsl/questions/QuestionBase.py +109 -12
- edsl/questions/descriptors.py +5 -2
- edsl/questions/question_registry.py +7 -0
- edsl/results/Result.py +20 -8
- edsl/results/Results.py +85 -11
- edsl/results/ResultsDBMixin.py +3 -6
- edsl/results/ResultsExportMixin.py +47 -16
- edsl/results/ResultsToolsMixin.py +5 -5
- edsl/scenarios/Scenario.py +59 -5
- edsl/scenarios/ScenarioList.py +97 -40
- edsl/study/ObjectEntry.py +97 -0
- edsl/study/ProofOfWork.py +110 -0
- edsl/study/SnapShot.py +77 -0
- edsl/study/Study.py +491 -0
- edsl/study/__init__.py +2 -0
- edsl/surveys/Survey.py +79 -31
- edsl/surveys/SurveyExportMixin.py +21 -3
- edsl/utilities/__init__.py +1 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/interface.py +24 -28
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/utilities.py +57 -2
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/METADATA +43 -17
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/RECORD +83 -55
- edsl-0.1.28.dist-info/entry_points.txt +3 -0
- edsl/conjure/RawResponseColumn.py +0 -327
- edsl/conjure/SurveyBuilder.py +0 -308
- edsl/conjure/SurveyBuilderCSV.py +0 -78
- edsl/conjure/SurveyBuilderSPSS.py +0 -118
- edsl/data/RemoteDict.py +0 -103
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/LICENSE +0 -0
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/WHEEL +0 -0
edsl/results/Results.py
CHANGED
@@ -5,9 +5,10 @@ It is not typically instantiated directly, but is returned by the run method of
|
|
5
5
|
|
6
6
|
from __future__ import annotations
|
7
7
|
import json
|
8
|
+
import hashlib
|
8
9
|
import random
|
9
10
|
from collections import UserList, defaultdict
|
10
|
-
from typing import Optional, Callable, Any, Type, Union
|
11
|
+
from typing import Optional, Callable, Any, Type, Union, List
|
11
12
|
|
12
13
|
from pygments import highlight
|
13
14
|
from pygments.lexers import JsonLexer
|
@@ -29,7 +30,8 @@ from edsl.results.Dataset import Dataset
|
|
29
30
|
from edsl.results.Result import Result
|
30
31
|
from edsl.results.ResultsExportMixin import ResultsExportMixin
|
31
32
|
from edsl.scenarios import Scenario
|
32
|
-
|
33
|
+
|
34
|
+
# from edsl.scenarios.ScenarioList import ScenarioList
|
33
35
|
from edsl.surveys import Survey
|
34
36
|
from edsl.data.Cache import Cache
|
35
37
|
from edsl.utilities import (
|
@@ -37,7 +39,7 @@ from edsl.utilities import (
|
|
37
39
|
shorten_string,
|
38
40
|
)
|
39
41
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
40
|
-
|
42
|
+
from edsl.utilities.utilities import dict_hash
|
41
43
|
from edsl.results.ResultsToolsMixin import ResultsToolsMixin
|
42
44
|
|
43
45
|
from edsl.results.ResultsDBMixin import ResultsDBMixin
|
@@ -163,7 +165,13 @@ class Results(UserList, Mixins, Base):
|
|
163
165
|
)
|
164
166
|
|
165
167
|
def __repr__(self) -> str:
|
166
|
-
return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
|
168
|
+
# return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
|
169
|
+
return f"""Results object
|
170
|
+
Size: {len(self.data)}.
|
171
|
+
Survey questions: {[q.question_name for q in self.survey.questions]}.
|
172
|
+
Created columns: {self.created_columns}
|
173
|
+
Hash: {hash(self)}
|
174
|
+
"""
|
167
175
|
|
168
176
|
def _repr_html_(self) -> str:
|
169
177
|
json_str = json.dumps(self.to_dict()["data"], indent=4)
|
@@ -174,6 +182,35 @@ class Results(UserList, Mixins, Base):
|
|
174
182
|
)
|
175
183
|
return HTML(formatted_json).data
|
176
184
|
|
185
|
+
def _to_dict(self, sort=False):
|
186
|
+
if sort:
|
187
|
+
data = sorted([result for result in self.data], key=lambda x: hash(x))
|
188
|
+
else:
|
189
|
+
data = [result for result in self.data]
|
190
|
+
return {
|
191
|
+
"data": [result.to_dict() for result in data],
|
192
|
+
"survey": self.survey.to_dict(),
|
193
|
+
"created_columns": self.created_columns,
|
194
|
+
"cache": Cache() if not hasattr(self, "cache") else self.cache.to_dict(),
|
195
|
+
}
|
196
|
+
|
197
|
+
def compare(self, other_results):
|
198
|
+
"""
|
199
|
+
Compare two Results objects and return the differences.
|
200
|
+
"""
|
201
|
+
hashes_0 = [hash(result) for result in self]
|
202
|
+
hashes_1 = [hash(result) for result in other_results]
|
203
|
+
|
204
|
+
in_self_but_not_other = set(hashes_0).difference(set(hashes_1))
|
205
|
+
in_other_but_not_self = set(hashes_1).difference(set(hashes_0))
|
206
|
+
|
207
|
+
indicies_self = [hashes_0.index(h) for h in in_self_but_not_other]
|
208
|
+
indices_other = [hashes_1.index(h) for h in in_other_but_not_self]
|
209
|
+
return {
|
210
|
+
"a_not_b": [self[i] for i in indicies_self],
|
211
|
+
"b_not_a": [other_results[i] for i in indices_other],
|
212
|
+
}
|
213
|
+
|
177
214
|
@add_edsl_version
|
178
215
|
def to_dict(self) -> dict[str, Any]:
|
179
216
|
"""Convert the Results object to a dictionary.
|
@@ -186,12 +223,14 @@ class Results(UserList, Mixins, Base):
|
|
186
223
|
>>> r.to_dict().keys()
|
187
224
|
dict_keys(['data', 'survey', 'created_columns', 'cache', 'edsl_version', 'edsl_class_name'])
|
188
225
|
"""
|
189
|
-
return
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
226
|
+
return self._to_dict()
|
227
|
+
|
228
|
+
def __hash__(self) -> int:
|
229
|
+
return dict_hash(self._to_dict(sort=True))
|
230
|
+
|
231
|
+
@property
|
232
|
+
def hashes(self) -> set:
|
233
|
+
return set(hash(result) for result in self.data)
|
195
234
|
|
196
235
|
@classmethod
|
197
236
|
@remove_edsl_version
|
@@ -318,7 +357,7 @@ class Results(UserList, Mixins, Base):
|
|
318
357
|
return [r.model for r in self.data]
|
319
358
|
|
320
359
|
@property
|
321
|
-
def scenarios(self) -> ScenarioList:
|
360
|
+
def scenarios(self) -> "ScenarioList":
|
322
361
|
"""Return a list of all of the scenarios in the Results.
|
323
362
|
|
324
363
|
Example:
|
@@ -327,6 +366,8 @@ class Results(UserList, Mixins, Base):
|
|
327
366
|
>>> r.scenarios
|
328
367
|
ScenarioList([Scenario({'period': 'morning'}), Scenario({'period': 'afternoon'}), Scenario({'period': 'morning'}), Scenario({'period': 'afternoon'})])
|
329
368
|
"""
|
369
|
+
from edsl import ScenarioList
|
370
|
+
|
330
371
|
return ScenarioList([r.scenario for r in self.data])
|
331
372
|
|
332
373
|
@property
|
@@ -487,6 +528,39 @@ class Results(UserList, Mixins, Base):
|
|
487
528
|
created_columns=self.created_columns + [new_var_name],
|
488
529
|
)
|
489
530
|
|
531
|
+
def add_column(self, column_name: str, values: list) -> Results:
|
532
|
+
"""Adds columns to Results
|
533
|
+
|
534
|
+
>>> r = Results.example()
|
535
|
+
>>> r.add_column('a', [1,2,3, 4]).select('a')
|
536
|
+
Dataset([{'answer.a': [1, 2, 3, 4]}])
|
537
|
+
"""
|
538
|
+
|
539
|
+
assert len(values) == len(
|
540
|
+
self.data
|
541
|
+
), "The number of values must match the number of results."
|
542
|
+
new_results = self.data.copy()
|
543
|
+
for i, result in enumerate(new_results):
|
544
|
+
result["answer"][column_name] = values[i]
|
545
|
+
return Results(
|
546
|
+
survey=self.survey,
|
547
|
+
data=new_results,
|
548
|
+
created_columns=self.created_columns + [column_name],
|
549
|
+
)
|
550
|
+
|
551
|
+
def add_columns_from_dict(self, columns: List[dict]) -> Results:
|
552
|
+
"""Adds columns to Results from a list of dictionaries.
|
553
|
+
|
554
|
+
>>> r = Results.example()
|
555
|
+
>>> r.add_columns_from_dict([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}, {'a':3, 'b':2}, {'a':3, 'b':2}]).select('a', 'b')
|
556
|
+
Dataset([{'answer.a': [1, 3, 3, 3]}, {'answer.b': [2, 4, 2, 2]}])
|
557
|
+
"""
|
558
|
+
keys = list(columns[0].keys())
|
559
|
+
for key in keys:
|
560
|
+
values = [d[key] for d in columns]
|
561
|
+
self = self.add_column(key, values)
|
562
|
+
return self
|
563
|
+
|
490
564
|
def mutate(
|
491
565
|
self, new_var_string: str, functions_dict: Optional[dict] = None
|
492
566
|
) -> Results:
|
edsl/results/ResultsDBMixin.py
CHANGED
@@ -136,12 +136,9 @@ class ResultsDBMixin:
|
|
136
136
|
|
137
137
|
>>> from edsl.results import Results
|
138
138
|
>>> r = Results.example()
|
139
|
-
>>> r.sql("select data_type, key, value from self where data_type = 'answer' limit 3", shape="long")
|
140
|
-
|
141
|
-
|
142
|
-
1 answer how_feeling_comment This is a real survey response from a human.
|
143
|
-
2 answer how_feeling_yesterday Great
|
144
|
-
|
139
|
+
>>> d = r.sql("select data_type, key, value from self where data_type = 'answer' limit 3", shape="long")
|
140
|
+
>>> list(d['value'])
|
141
|
+
['OK', 'This is a real survey response from a human.', 'Great']
|
145
142
|
|
146
143
|
We can also return the data in wide format.
|
147
144
|
Note the use of single quotes to escape the column names, as required by sql.
|
@@ -6,7 +6,7 @@ import io
|
|
6
6
|
import random
|
7
7
|
from functools import wraps
|
8
8
|
|
9
|
-
from typing import Literal, Optional
|
9
|
+
from typing import Literal, Optional, Union
|
10
10
|
|
11
11
|
from edsl.utilities.utilities import is_notebook
|
12
12
|
|
@@ -33,6 +33,8 @@ class ResultsExportMixin:
|
|
33
33
|
return func(self.select(), *args, **kwargs)
|
34
34
|
elif self.__class__.__name__ == "Dataset":
|
35
35
|
return func(self, *args, **kwargs)
|
36
|
+
elif self.__class__.__name__ == "ScenarioList":
|
37
|
+
return func(self.to_dataset(), *args, **kwargs)
|
36
38
|
else:
|
37
39
|
raise Exception(
|
38
40
|
f"Class {self.__class__.__name__} not recognized as a Results or Dataset object."
|
@@ -46,6 +48,7 @@ class ResultsExportMixin:
|
|
46
48
|
) -> list:
|
47
49
|
"""Return the set of keys that are present in the dataset.
|
48
50
|
|
51
|
+
>>> from edsl.results.Dataset import Dataset
|
49
52
|
>>> d = Dataset([{'a.b':[1,2,3,4]}])
|
50
53
|
>>> d.relevant_columns()
|
51
54
|
['a.b']
|
@@ -155,6 +158,9 @@ class ResultsExportMixin:
|
|
155
158
|
max_rows=None,
|
156
159
|
tee=False,
|
157
160
|
iframe=False,
|
161
|
+
iframe_height: int = 200,
|
162
|
+
iframe_width: int = 600,
|
163
|
+
web=False,
|
158
164
|
) -> None:
|
159
165
|
"""Print the results in a pretty format.
|
160
166
|
|
@@ -239,21 +245,26 @@ class ResultsExportMixin:
|
|
239
245
|
elif format == "html":
|
240
246
|
notebook = is_notebook()
|
241
247
|
html_source = print_list_of_dicts_as_html_table(
|
242
|
-
new_data,
|
248
|
+
new_data, interactive=interactive
|
243
249
|
)
|
244
250
|
if iframe:
|
245
251
|
import html
|
246
252
|
|
247
|
-
height =
|
248
|
-
width =
|
253
|
+
height = iframe_height
|
254
|
+
width = iframe_width
|
249
255
|
escaped_output = html.escape(html_source)
|
250
256
|
# escaped_output = html_source
|
251
257
|
iframe = f""""
|
252
258
|
<iframe srcdoc="{ escaped_output }" style="width: {width}px; height: {height}px;"></iframe>
|
253
259
|
"""
|
254
260
|
display(HTML(iframe))
|
255
|
-
|
261
|
+
elif notebook:
|
256
262
|
display(HTML(html_source))
|
263
|
+
else:
|
264
|
+
from edsl.utilities.interface import view_html
|
265
|
+
|
266
|
+
view_html(html_source)
|
267
|
+
|
257
268
|
elif format == "markdown":
|
258
269
|
print_list_of_dicts_as_markdown_table(new_data, filename=filename)
|
259
270
|
elif format == "latex":
|
@@ -474,16 +485,19 @@ class ResultsExportMixin:
|
|
474
485
|
return filename
|
475
486
|
|
476
487
|
@_convert_decorator
|
477
|
-
def tally(
|
488
|
+
def tally(
|
489
|
+
self, *fields: Optional[str], top_n=None, output="dict"
|
490
|
+
) -> Union[dict, "Dataset"]:
|
478
491
|
"""Tally the values of a field or perform a cross-tab of multiple fields.
|
479
492
|
|
480
493
|
:param fields: The field(s) to tally, multiple fields for cross-tabulation.
|
481
494
|
|
495
|
+
>>> from edsl.results import Results
|
482
496
|
>>> r = Results.example()
|
483
497
|
>>> r.select('how_feeling').tally('answer.how_feeling')
|
484
498
|
{'OK': 2, 'Great': 1, 'Terrible': 1}
|
485
|
-
>>> r.tally('
|
486
|
-
{('
|
499
|
+
>>> r.select('how_feeling', 'period').tally('how_feeling', 'period')
|
500
|
+
{('OK', 'morning'): 1, ('Great', 'afternoon'): 1, ('Terrible', 'morning'): 1, ('OK', 'afternoon'): 1}
|
487
501
|
"""
|
488
502
|
from collections import Counter
|
489
503
|
|
@@ -506,19 +520,36 @@ class ResultsExportMixin:
|
|
506
520
|
else:
|
507
521
|
values = list(zip(*(self._key_to_value(field) for field in fields)))
|
508
522
|
|
523
|
+
for value in values:
|
524
|
+
if isinstance(value, list):
|
525
|
+
value = tuple(value)
|
526
|
+
|
509
527
|
tally = dict(Counter(values))
|
510
528
|
sorted_tally = dict(sorted(tally.items(), key=lambda item: -item[1]))
|
511
529
|
if top_n is not None:
|
512
530
|
sorted_tally = dict(list(sorted_tally.items())[:top_n])
|
513
531
|
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
532
|
+
import warnings
|
533
|
+
import textwrap
|
534
|
+
from edsl.results.Dataset import Dataset
|
535
|
+
|
536
|
+
if output == "dict":
|
537
|
+
warnings.warn(
|
538
|
+
textwrap.dedent(
|
539
|
+
"""\
|
540
|
+
The default output from tally will change to Dataset in the future.
|
541
|
+
Use output='Dataset' to get the Dataset object for now.
|
542
|
+
"""
|
543
|
+
)
|
544
|
+
)
|
545
|
+
return sorted_tally
|
546
|
+
elif output == "Dataset":
|
547
|
+
return Dataset(
|
548
|
+
[
|
549
|
+
{"value": list(sorted_tally.keys())},
|
550
|
+
{"count": list(sorted_tally.values())},
|
551
|
+
]
|
552
|
+
)
|
522
553
|
|
523
554
|
|
524
555
|
if __name__ == "__main__":
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from edsl import ScenarioList
|
1
|
+
# from edsl import ScenarioList
|
2
2
|
from edsl.questions import QuestionList, QuestionCheckBox
|
3
3
|
|
4
4
|
|
@@ -14,6 +14,7 @@ class ResultsToolsMixin:
|
|
14
14
|
print_exceptions=False,
|
15
15
|
) -> list:
|
16
16
|
values = self.shuffle(seed=seed).select(field).to_list()[:max_values]
|
17
|
+
from edsl import ScenarioList
|
17
18
|
|
18
19
|
q = QuestionList(
|
19
20
|
question_text=f"""
|
@@ -24,10 +25,7 @@ class ResultsToolsMixin:
|
|
24
25
|
""",
|
25
26
|
question_name="themes",
|
26
27
|
)
|
27
|
-
|
28
|
-
results = q.by(s).run(
|
29
|
-
print_exceptions=print_exceptions, progress_bar=progress_bar
|
30
|
-
)
|
28
|
+
results = q.run(print_exceptions=print_exceptions, progress_bar=progress_bar)
|
31
29
|
return results.select("themes").first()
|
32
30
|
|
33
31
|
def answers_to_themes(
|
@@ -38,6 +36,8 @@ class ResultsToolsMixin:
|
|
38
36
|
progress_bar=False,
|
39
37
|
print_exceptions=False,
|
40
38
|
) -> dict:
|
39
|
+
from edsl import ScenarioList
|
40
|
+
|
41
41
|
values = self.select(field).to_list()
|
42
42
|
scenarios = ScenarioList.from_list("field", values).add_value(
|
43
43
|
"context", context
|
edsl/scenarios/Scenario.py
CHANGED
@@ -5,10 +5,15 @@ from collections import UserDict
|
|
5
5
|
from typing import Union, List, Optional, Generator
|
6
6
|
import base64
|
7
7
|
import hashlib
|
8
|
+
import json
|
9
|
+
|
10
|
+
import fitz # PyMuPDF
|
11
|
+
import os
|
12
|
+
import subprocess
|
13
|
+
|
8
14
|
from rich.table import Table
|
9
15
|
|
10
16
|
from edsl.Base import Base
|
11
|
-
|
12
17
|
from edsl.scenarios.ScenarioImageMixin import ScenarioImageMixin
|
13
18
|
from edsl.scenarios.ScenarioHtmlMixin import ScenarioHtmlMixin
|
14
19
|
|
@@ -19,7 +24,9 @@ from edsl.utilities.decorators import (
|
|
19
24
|
|
20
25
|
|
21
26
|
class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
22
|
-
"""A Scenario is a dictionary of keys/values
|
27
|
+
"""A Scenario is a dictionary of keys/values.
|
28
|
+
|
29
|
+
They can be used parameterize edsl questions."""
|
23
30
|
|
24
31
|
def __init__(self, data: Union[dict, None] = None, name: str = None):
|
25
32
|
"""Initialize a new Scenario.
|
@@ -32,7 +39,7 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
32
39
|
self.name = name
|
33
40
|
|
34
41
|
def replicate(self, n: int) -> "ScenarioList":
|
35
|
-
"""Replicate a scenario n times.
|
42
|
+
"""Replicate a scenario n times to return a ScenarioList.
|
36
43
|
|
37
44
|
:param n: The number of times to replicate the scenario.
|
38
45
|
|
@@ -58,7 +65,7 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
58
65
|
self._has_image = value
|
59
66
|
|
60
67
|
def __add__(self, other_scenario: "Scenario") -> "Scenario":
|
61
|
-
"""Combine two scenarios
|
68
|
+
"""Combine two scenarios by taking the union of their keys
|
62
69
|
|
63
70
|
If the other scenario is None, then just return self.
|
64
71
|
|
@@ -102,6 +109,17 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
102
109
|
new_scenario[key] = value
|
103
110
|
return new_scenario
|
104
111
|
|
112
|
+
def _to_dict(self) -> dict:
|
113
|
+
"""Convert a scenario to a dictionary.
|
114
|
+
|
115
|
+
Example:
|
116
|
+
|
117
|
+
>>> s = Scenario({"food": "wood chips"})
|
118
|
+
>>> s.to_dict()
|
119
|
+
{'food': 'wood chips', 'edsl_version': '...', 'edsl_class_name': 'Scenario'}
|
120
|
+
"""
|
121
|
+
return self.data.copy()
|
122
|
+
|
105
123
|
@add_edsl_version
|
106
124
|
def to_dict(self) -> dict:
|
107
125
|
"""Convert a scenario to a dictionary.
|
@@ -112,7 +130,21 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
112
130
|
>>> s.to_dict()
|
113
131
|
{'food': 'wood chips', 'edsl_version': '...', 'edsl_class_name': 'Scenario'}
|
114
132
|
"""
|
115
|
-
return self.
|
133
|
+
return self._to_dict()
|
134
|
+
|
135
|
+
def __hash__(self) -> int:
|
136
|
+
"""
|
137
|
+
Return a hash of the scenario.
|
138
|
+
|
139
|
+
Example:
|
140
|
+
|
141
|
+
>>> s = Scenario({"food": "wood chips"})
|
142
|
+
>>> hash(s)
|
143
|
+
1153210385458344214
|
144
|
+
"""
|
145
|
+
from edsl.utilities.utilities import dict_hash
|
146
|
+
|
147
|
+
return dict_hash(self._to_dict())
|
116
148
|
|
117
149
|
def print(self):
|
118
150
|
from rich import print_json
|
@@ -183,6 +215,28 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
183
215
|
s.has_image = True
|
184
216
|
return s
|
185
217
|
|
218
|
+
@classmethod
|
219
|
+
def from_pdf(cls, pdf_path):
|
220
|
+
# Ensure the file exists
|
221
|
+
if not os.path.exists(pdf_path):
|
222
|
+
raise FileNotFoundError(f"The file {pdf_path} does not exist.")
|
223
|
+
|
224
|
+
# Open the PDF file
|
225
|
+
document = fitz.open(pdf_path)
|
226
|
+
|
227
|
+
# Get the filename from the path
|
228
|
+
filename = os.path.basename(pdf_path)
|
229
|
+
|
230
|
+
# Iterate through each page and extract text
|
231
|
+
text = ""
|
232
|
+
for page_num in range(len(document)):
|
233
|
+
page = document.load_page(page_num)
|
234
|
+
text = text + page.get_text()
|
235
|
+
|
236
|
+
# Create a dictionary for the combined text
|
237
|
+
page_info = {"filename": filename, "text": text}
|
238
|
+
return Scenario(page_info)
|
239
|
+
|
186
240
|
@classmethod
|
187
241
|
def from_docx(cls, docx_path: str) -> "Scenario":
|
188
242
|
"""Creates a scenario from the text of a docx file.
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -2,9 +2,9 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
import csv
|
5
|
-
|
5
|
+
import random
|
6
|
+
from collections import UserList, Counter
|
6
7
|
from collections.abc import Iterable
|
7
|
-
from collections import Counter
|
8
8
|
|
9
9
|
from typing import Any, Optional, Union, List
|
10
10
|
|
@@ -16,14 +16,14 @@ from edsl.Base import Base
|
|
16
16
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
17
17
|
from edsl.scenarios.ScenarioListPdfMixin import ScenarioListPdfMixin
|
18
18
|
|
19
|
-
import pandas as pd
|
20
|
-
|
21
19
|
from edsl.utilities.interface import print_scenario_list
|
22
20
|
|
23
21
|
from edsl.utilities import is_valid_variable_name
|
24
22
|
|
23
|
+
from edsl.results.ResultsExportMixin import ResultsExportMixin
|
24
|
+
|
25
25
|
|
26
|
-
class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
26
|
+
class ScenarioList(Base, UserList, ScenarioListPdfMixin, ResultsExportMixin):
|
27
27
|
"""Class for creating a list of scenarios to be used in a survey."""
|
28
28
|
|
29
29
|
def __init__(self, data: Optional[list] = None):
|
@@ -33,11 +33,37 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
|
33
33
|
else:
|
34
34
|
super().__init__([])
|
35
35
|
|
36
|
+
@property
|
37
|
+
def parameters(self) -> set:
|
38
|
+
"""Return the set of parameters in the ScenarioList
|
39
|
+
|
40
|
+
Example:
|
41
|
+
|
42
|
+
>>> s = ScenarioList([Scenario({'a': 1}), Scenario({'b': 2})])
|
43
|
+
>>> s.parameters == {'a', 'b'}
|
44
|
+
True
|
45
|
+
"""
|
46
|
+
if len(self) == 0:
|
47
|
+
return set()
|
48
|
+
|
49
|
+
return set.union(*[set(s.keys()) for s in self])
|
50
|
+
|
51
|
+
def __hash__(self) -> int:
|
52
|
+
"""Return the hash of the ScenarioList.
|
53
|
+
|
54
|
+
>>> s = ScenarioList.example()
|
55
|
+
>>> hash(s)
|
56
|
+
1262252885757976162
|
57
|
+
"""
|
58
|
+
from edsl.utilities.utilities import dict_hash
|
59
|
+
|
60
|
+
return dict_hash(self._to_dict(sort=True))
|
61
|
+
|
36
62
|
def __repr__(self):
|
37
63
|
return f"ScenarioList({self.data})"
|
38
64
|
|
39
65
|
def __mul__(self, other: ScenarioList) -> ScenarioList:
|
40
|
-
"""
|
66
|
+
"""Takes the cross product of two ScenarioLists."""
|
41
67
|
from itertools import product
|
42
68
|
|
43
69
|
new_sl = []
|
@@ -45,6 +71,24 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
|
45
71
|
new_sl.append(s1 + s2)
|
46
72
|
return ScenarioList(new_sl)
|
47
73
|
|
74
|
+
def times(self, other: ScenarioList) -> ScenarioList:
|
75
|
+
"""Takes the cross product of two ScenarioLists.
|
76
|
+
|
77
|
+
Example:
|
78
|
+
|
79
|
+
>>> s1 = ScenarioList([Scenario({'a': 1}), Scenario({'a': 2})])
|
80
|
+
>>> s2 = ScenarioList([Scenario({'b': 1}), Scenario({'b': 2})])
|
81
|
+
>>> s1.times(s2)
|
82
|
+
ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2}), Scenario({'a': 2, 'b': 1}), Scenario({'a': 2, 'b': 2})])
|
83
|
+
"""
|
84
|
+
return self.__mul__(other)
|
85
|
+
|
86
|
+
def shuffle(self, seed: Optional[str] = "edsl") -> ScenarioList:
|
87
|
+
"""Shuffle the ScenarioList."""
|
88
|
+
random.seed(seed)
|
89
|
+
random.shuffle(self.data)
|
90
|
+
return self
|
91
|
+
|
48
92
|
def _repr_html_(self) -> str:
|
49
93
|
from edsl.utilities.utilities import data_to_html
|
50
94
|
|
@@ -69,7 +113,6 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
|
69
113
|
|
70
114
|
def sample(self, n: int, seed="edsl") -> ScenarioList:
|
71
115
|
"""Return a random sample from the ScenarioList"""
|
72
|
-
import random
|
73
116
|
|
74
117
|
if seed != "edsl":
|
75
118
|
random.seed(seed)
|
@@ -217,6 +260,13 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
|
217
260
|
"""
|
218
261
|
return cls([Scenario({name: value}) for value in values])
|
219
262
|
|
263
|
+
def to_dataset(self) -> "Dataset":
|
264
|
+
from edsl.results.Dataset import Dataset
|
265
|
+
|
266
|
+
keys = self[0].keys()
|
267
|
+
data = {key: [scenario[key] for scenario in self.data] for key in keys}
|
268
|
+
return Dataset([data])
|
269
|
+
|
220
270
|
def add_list(self, name, values) -> ScenarioList:
|
221
271
|
"""Add a list of values to a ScenarioList.
|
222
272
|
|
@@ -227,7 +277,10 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
|
227
277
|
ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
|
228
278
|
"""
|
229
279
|
for i, value in enumerate(values):
|
230
|
-
|
280
|
+
if i < len(self):
|
281
|
+
self[i][name] = value
|
282
|
+
else:
|
283
|
+
self.append(Scenario({name: value}))
|
231
284
|
return self
|
232
285
|
|
233
286
|
def add_value(self, name, value):
|
@@ -244,6 +297,16 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
|
244
297
|
return self
|
245
298
|
|
246
299
|
def rename(self, replacement_dict: dict) -> ScenarioList:
|
300
|
+
"""Rename the fields in the scenarios.
|
301
|
+
|
302
|
+
Example:
|
303
|
+
|
304
|
+
>>> s = ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
|
305
|
+
>>> s.rename({'name': 'first_name', 'age': 'years'})
|
306
|
+
ScenarioList([Scenario({'first_name': 'Alice', 'years': 30}), Scenario({'first_name': 'Bob', 'years': 25})])
|
307
|
+
|
308
|
+
"""
|
309
|
+
|
247
310
|
new_list = ScenarioList([])
|
248
311
|
for obj in self:
|
249
312
|
new_obj = obj.rename(replacement_dict)
|
@@ -301,6 +364,13 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
|
301
364
|
observations.append(Scenario(dict(zip(header, row))))
|
302
365
|
return cls(observations)
|
303
366
|
|
367
|
+
def _to_dict(self, sort=False) -> dict:
|
368
|
+
if sort:
|
369
|
+
data = sorted(self, key=lambda x: hash(x))
|
370
|
+
else:
|
371
|
+
data = self
|
372
|
+
return {"scenarios": [s._to_dict() for s in data]}
|
373
|
+
|
304
374
|
@add_edsl_version
|
305
375
|
def to_dict(self) -> dict[str, Any]:
|
306
376
|
"""Return the `ScenarioList` as a dictionary.
|
@@ -315,7 +385,14 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
|
315
385
|
|
316
386
|
@classmethod
|
317
387
|
def gen(cls, scenario_dicts_list: List[dict]) -> ScenarioList:
|
318
|
-
"""Create a `ScenarioList` from a list of dictionaries.
|
388
|
+
"""Create a `ScenarioList` from a list of dictionaries.
|
389
|
+
|
390
|
+
Example:
|
391
|
+
|
392
|
+
>>> ScenarioList.gen([{'name': 'Alice'}, {'name': 'Bob'}])
|
393
|
+
ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
|
394
|
+
|
395
|
+
"""
|
319
396
|
return cls([Scenario(s) for s in scenario_dicts_list])
|
320
397
|
|
321
398
|
@classmethod
|
@@ -361,39 +438,19 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin):
|
|
361
438
|
filename: str = None,
|
362
439
|
):
|
363
440
|
print_scenario_list(self)
|
364
|
-
# if format is None:
|
365
|
-
# if is_notebook():
|
366
|
-
# format = "html"
|
367
|
-
# else:
|
368
|
-
# format = "rich"
|
369
|
-
|
370
|
-
# if pretty_labels is None:
|
371
|
-
# pretty_labels = {}
|
372
|
-
|
373
|
-
# if format not in ["rich", "html", "markdown"]:
|
374
|
-
# raise ValueError("format must be one of 'rich', 'html', or 'markdown'.")
|
375
|
-
|
376
|
-
# if max_rows is not None:
|
377
|
-
# new_data = self[:max_rows]
|
378
|
-
# else:
|
379
|
-
# new_data = self
|
380
|
-
|
381
|
-
# if format == "rich":
|
382
|
-
# print_list_of_dicts_with_rich(
|
383
|
-
# new_data, filename=filename, split_at_dot=False
|
384
|
-
# )
|
385
|
-
# elif format == "html":
|
386
|
-
# notebook = is_notebook()
|
387
|
-
# html = print_list_of_dicts_as_html_table(
|
388
|
-
# new_data, filename=None, interactive=False, notebook=notebook
|
389
|
-
# )
|
390
|
-
# # print(html)
|
391
|
-
# display(HTML(html))
|
392
|
-
# elif format == "markdown":
|
393
|
-
# print_list_of_dicts_as_markdown_table(new_data, filename=filename)
|
394
441
|
|
395
442
|
def __getitem__(self, key: Union[int, slice]) -> Any:
|
396
|
-
"""Return the item at the given index.
|
443
|
+
"""Return the item at the given index.
|
444
|
+
|
445
|
+
Example:
|
446
|
+
>>> s = ScenarioList([Scenario({'age': 22, 'hair': 'brown', 'height': 5.5}), Scenario({'age': 22, 'hair': 'brown', 'height': 5.5})])
|
447
|
+
>>> s[0]
|
448
|
+
Scenario({'age': 22, 'hair': 'brown', 'height': 5.5})
|
449
|
+
|
450
|
+
>>> s[:1]
|
451
|
+
ScenarioList([Scenario({'age': 22, 'hair': 'brown', 'height': 5.5})])
|
452
|
+
|
453
|
+
"""
|
397
454
|
if isinstance(key, slice):
|
398
455
|
return ScenarioList(super().__getitem__(key))
|
399
456
|
elif isinstance(key, int):
|