edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +60 -31
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +18 -9
- edsl/agents/AgentList.py +59 -8
- edsl/agents/Invigilator.py +18 -7
- edsl/agents/InvigilatorBase.py +0 -19
- edsl/agents/PromptConstructor.py +5 -4
- edsl/config.py +8 -0
- edsl/coop/coop.py +74 -7
- edsl/data/Cache.py +27 -2
- edsl/data/CacheEntry.py +8 -3
- edsl/data/RemoteCacheSync.py +0 -19
- edsl/enums.py +2 -0
- edsl/inference_services/GoogleService.py +7 -15
- edsl/inference_services/PerplexityService.py +163 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +88 -548
- edsl/jobs/JobsChecks.py +147 -0
- edsl/jobs/JobsPrompts.py +268 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
- edsl/jobs/interviews/Interview.py +11 -11
- edsl/jobs/runners/JobsRunnerAsyncio.py +140 -35
- edsl/jobs/runners/JobsRunnerStatus.py +0 -2
- edsl/jobs/tasks/TaskHistory.py +15 -16
- edsl/language_models/LanguageModel.py +44 -84
- edsl/language_models/ModelList.py +47 -1
- edsl/language_models/registry.py +57 -4
- edsl/prompts/Prompt.py +8 -3
- edsl/questions/QuestionBase.py +20 -16
- edsl/questions/QuestionExtract.py +3 -4
- edsl/questions/question_registry.py +36 -6
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +146 -15
- edsl/results/DatasetExportMixin.py +231 -217
- edsl/results/DatasetTree.py +134 -4
- edsl/results/Result.py +18 -9
- edsl/results/Results.py +145 -51
- edsl/results/TableDisplay.py +198 -0
- edsl/results/table_display.css +78 -0
- edsl/scenarios/FileStore.py +187 -13
- edsl/scenarios/Scenario.py +61 -4
- edsl/scenarios/ScenarioJoin.py +127 -0
- edsl/scenarios/ScenarioList.py +237 -62
- edsl/surveys/Survey.py +16 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
- edsl/surveys/instructions/Instruction.py +12 -0
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +18 -9
- edsl/utilities/utilities.py +15 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/METADATA +2 -1
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/RECORD +53 -45
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/WHEEL +0 -0
@@ -0,0 +1,127 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Union, TYPE_CHECKING
|
3
|
+
|
4
|
+
# if TYPE_CHECKING:
|
5
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
6
|
+
from edsl.scenarios.Scenario import Scenario
|
7
|
+
|
8
|
+
|
9
|
+
class ScenarioJoin:
    """Handles join operations between two ScenarioLists.

    This class encapsulates all join-related logic, making it easier to maintain
    and extend with other join types (inner, right, full) in the future.
    """

    def __init__(self, left: "ScenarioList", right: "ScenarioList"):
        """Initialize join operation with two ScenarioLists.

        Args:
            left: The left ScenarioList
            right: The right ScenarioList
        """
        self.left = left
        self.right = right

    def left_join(self, by: Union[str, list[str]]) -> ScenarioList:
        """Perform a left join between the two ScenarioLists.

        Every scenario of the left list yields exactly one output scenario.
        Keys that only exist on the right side are filled from the matching
        right scenario, or with ``None`` when there is no match. When several
        right scenarios share the same join-key values, the last one is used.

        Args:
            by: String or list of strings representing the key(s) to join on. Cannot be empty.

        Returns:
            A new ScenarioList containing the joined scenarios

        Raises:
            ValueError: If by is empty or if any join keys don't exist in both ScenarioLists
        """
        self._validate_join_keys(by)
        by_keys = [by] if isinstance(by, str) else by

        other_dict = self._create_lookup_dict(self.right, by_keys)
        # Use first-seen order rather than a set so that the key order of the
        # joined scenarios is reproducible from run to run.
        all_keys = self._get_ordered_keys()

        return ScenarioList(
            self._create_joined_scenarios(by_keys, other_dict, all_keys)
        )

    def _validate_join_keys(self, by: Union[str, list[str]]) -> None:
        """Validate join keys exist in both ScenarioLists.

        Only the first scenario of each list is inspected; an empty list is
        treated as having no keys at all.

        Raises:
            ValueError: If ``by`` is empty or a key is missing on either side.
        """
        if not by:
            raise ValueError(
                "Join keys cannot be empty. Please specify at least one key to join on."
            )

        by_keys = [by] if isinstance(by, str) else by
        left_keys = set(next(iter(self.left)).keys()) if self.left else set()
        right_keys = set(next(iter(self.right)).keys()) if self.right else set()

        missing_left = set(by_keys) - left_keys
        missing_right = set(by_keys) - right_keys
        if missing_left or missing_right:
            missing = missing_left | missing_right
            raise ValueError(f"Join key(s) {missing} not found in both ScenarioLists")

    @staticmethod
    def _get_key_tuple(scenario: Scenario, keys: list[str]) -> tuple:
        """Create a tuple of values for the join keys."""
        return tuple(scenario[k] for k in keys)

    def _create_lookup_dict(self, scenarios: ScenarioList, by_keys: list[str]) -> dict:
        """Create a lookup dictionary for the right scenarios.

        Maps the tuple of join-key values to the scenario carrying them. Later
        scenarios overwrite earlier ones that have the same key tuple.
        """
        return {
            self._get_key_tuple(scenario, by_keys): scenario for scenario in scenarios
        }

    def _get_ordered_keys(self) -> list:
        """Get all unique keys from both ScenarioLists in first-seen order.

        Left scenarios are visited before right scenarios.
        """
        # A dict is used as an insertion-ordered set.
        ordered: dict = {}
        for scenarios in (self.left, self.right):
            for scenario in scenarios:
                ordered.update(dict.fromkeys(scenario.keys()))
        return list(ordered)

    def _get_all_keys(self) -> set:
        """Get all unique keys from both ScenarioLists."""
        return set(self._get_ordered_keys())

    def _create_joined_scenarios(
        self, by_keys: list[str], other_dict: dict, all_keys: Union[set, list]
    ) -> list[Scenario]:
        """Create the joined scenarios.

        Args:
            by_keys: The join keys.
            other_dict: Lookup of join-key tuples to right scenarios.
            all_keys: Every key the output scenarios should carry; iteration
                order determines the initial key order of each output scenario.

        Returns:
            One Scenario per left scenario.
        """
        new_scenarios = []

        for scenario in self.left:
            # Start from "all nulls", then overlay the left values.
            new_scenario = dict.fromkeys(all_keys)
            new_scenario.update(scenario)

            key_tuple = self._get_key_tuple(scenario, by_keys)
            matching_scenario = other_dict.get(key_tuple)
            if matching_scenario is not None:
                self._handle_matching_scenario(
                    new_scenario, scenario, matching_scenario, by_keys
                )

            new_scenarios.append(Scenario(new_scenario))

        return new_scenarios

    def _handle_matching_scenario(
        self,
        new_scenario: dict,
        left_scenario: Scenario,
        right_scenario: Scenario,
        by_keys: list[str],
    ) -> None:
        """Handle merging of matching scenarios and conflict warnings.

        Left values always win. For every non-join key present on both sides
        with differing values a warning is printed. ``new_scenario`` is updated
        in place with the keys that exist only on the right side.
        """
        left_keys = set(left_scenario.keys())
        right_keys = set(right_scenario.keys())
        join_conditions = None  # built lazily, only if a conflict is reported

        # Walk the left scenario's own key order so warnings come out in a
        # stable order instead of set-iteration order.
        for key in left_scenario.keys():
            if key not in right_keys or key in by_keys:
                continue
            if left_scenario[key] != right_scenario[key]:
                if join_conditions is None:
                    join_conditions = [f"{k}='{left_scenario[k]}'" for k in by_keys]
                print(
                    f"Warning: Conflicting values for key '{key}' where "
                    f"{' AND '.join(join_conditions)}. "
                    f"Keeping left value: {left_scenario[key]} "
                    f"(discarding: {right_scenario[key]})"
                )

        # Only update with non-overlapping keys from matching scenario
        new_scenario.update(
            {k: right_scenario[k] for k in right_scenario.keys() if k not in left_keys}
        )
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -31,6 +31,10 @@ class ScenarioListMixin(ScenarioListPdfMixin, ScenarioListExportMixin):
|
|
31
31
|
class ScenarioList(Base, UserList, ScenarioListMixin):
|
32
32
|
"""Class for creating a list of scenarios to be used in a survey."""
|
33
33
|
|
34
|
+
__documentation__ = (
|
35
|
+
"https://docs.expectedparrot.com/en/latest/scenarios.html#scenariolist"
|
36
|
+
)
|
37
|
+
|
34
38
|
def __init__(self, data: Optional[list] = None, codebook: Optional[dict] = None):
|
35
39
|
"""Initialize the ScenarioList class."""
|
36
40
|
if data is not None:
|
@@ -241,6 +245,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
241
245
|
|
242
246
|
return dict_hash(self.to_dict(sort=True, add_edsl_version=False))
|
243
247
|
|
248
|
+
def __eq__(self, other: Any) -> bool:
    """Compare by hash: two objects are equal when their hashes are equal.

    An unhashable ``other`` (for example a plain ``list``) cannot be equal
    under this definition, so it compares as ``False`` instead of letting
    ``hash(other)`` raise ``TypeError`` out of an ``==`` expression.
    """
    own_hash = hash(self)
    try:
        other_hash = hash(other)
    except TypeError:
        return False
    return own_hash == other_hash
|
250
|
+
|
244
251
|
def __repr__(self):
|
245
252
|
return f"ScenarioList({self.data})"
|
246
253
|
|
@@ -282,41 +289,49 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
282
289
|
random.shuffle(self.data)
|
283
290
|
return self
|
284
291
|
|
285
|
-
def _repr_html_(self)
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
292
|
+
def _repr_html_(self):
    """Return an HTML representation of the ScenarioList.

    The result is the HTML-formatted summary followed by a "(docs)" link
    pointing at ``__documentation__``.
    """
    footer = f"<a href={self.__documentation__}>(docs)</a>"
    return str(self.summary(format="html")) + footer
|
299
|
+
|
300
|
+
# def _repr_html_(self) -> str:
|
301
|
+
# from edsl.utilities.utilities import data_to_html
|
302
|
+
|
303
|
+
# data = self.to_dict()
|
304
|
+
# _ = data.pop("edsl_version")
|
305
|
+
# _ = data.pop("edsl_class_name")
|
306
|
+
# for s in data["scenarios"]:
|
307
|
+
# _ = s.pop("edsl_version")
|
308
|
+
# _ = s.pop("edsl_class_name")
|
309
|
+
# for scenario in data["scenarios"]:
|
310
|
+
# for key, value in scenario.items():
|
311
|
+
# if hasattr(value, "to_dict"):
|
312
|
+
# data[key] = value.to_dict()
|
313
|
+
# return data_to_html(data)
|
314
|
+
|
315
|
+
# def tally(self, field) -> dict:
|
316
|
+
# """Return a tally of the values in the field.
|
317
|
+
|
318
|
+
# Example:
|
319
|
+
|
320
|
+
# >>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
321
|
+
# >>> s.tally('b')
|
322
|
+
# {1: 1, 2: 1}
|
323
|
+
# """
|
324
|
+
# return dict(Counter([scenario[field] for scenario in self]))
|
325
|
+
|
326
|
+
def sample(self, n: int, seed: Optional[str] = None) -> ScenarioList:
    """Return a random sample from the ScenarioList

    Args:
        n: Number of scenarios to draw (without replacement).
        seed: Optional seed. When given, the module-level random generator
            is re-seeded with it before sampling so the draw is reproducible.

    >>> s = ScenarioList.from_list("a", [1,2,3,4,5,6])
    >>> s.sample(3, seed = "edsl")
    ScenarioList([Scenario({'a': 2}), Scenario({'a': 1}), Scenario({'a': 3})])
    """
    # Compare against None: a falsy seed such as 0 or "" is still a seed.
    if seed is not None:
        random.seed(seed)

    return ScenarioList(random.sample(self.data, n))
|
322
337
|
|
@@ -564,6 +579,47 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
564
579
|
func = lambda x: x
|
565
580
|
return cls([Scenario({name: func(value)}) for value in values])
|
566
581
|
|
582
|
+
def table(self, *fields, tablefmt=None, pretty_labels=None) -> str:
    """Return the ScenarioList as a table.

    Args:
        *fields: Forwarded to the dataset's ``table`` method.
        tablefmt: Optional format name; must be one of ``tabulate``'s
            ``tabulate_formats`` when given.
        pretty_labels: Forwarded to the dataset's ``table`` method.

    Returns:
        Whatever ``self.to_dataset().table(...)`` returns.

    Raises:
        ValueError: If ``tablefmt`` is not a known tabulate format.
    """

    from tabulate import tabulate_formats

    if tablefmt is not None and tablefmt not in tabulate_formats:
        # One message string: passing two positional strings would make the
        # exception render as a tuple.
        raise ValueError(
            f"Invalid table format: {tablefmt}. "
            f"Valid formats are: {tabulate_formats}"
        )
    return self.to_dataset().table(
        *fields, tablefmt=tablefmt, pretty_labels=pretty_labels
    )
|
595
|
+
|
596
|
+
def tree(self, node_list: Optional[List[str]] = None) -> str:
    """Return the ScenarioList as a tree.

    Converts the list to a dataset and delegates to that dataset's ``tree``
    method, passing ``node_list`` through unchanged.
    """
    as_dataset = self.to_dataset()
    return as_dataset.tree(node_list)
|
599
|
+
|
600
|
+
def _summary(self):
    """Return a small dict describing this ScenarioList.

    Contains the class name, the number of scenarios and the list of
    scenario keys taken from ``self.parameters``.
    """
    return {
        "EDSL Class name": "ScenarioList",
        "# Scenarios": len(self),
        "Scenario Keys": list(self.parameters),
    }
|
607
|
+
|
608
|
+
def reorder_keys(self, new_order):
    """Reorder the keys in the scenarios.

    Each scenario is rebuilt with exactly the keys listed in ``new_order``,
    in that order.

    Example:

    >>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 3, 'b': 4})])
    >>> s.reorder_keys(['b', 'a'])
    ScenarioList([Scenario({'b': 2, 'a': 1}), Scenario({'b': 4, 'a': 3})])
    """
    reordered = [
        Scenario({name: item[name] for name in new_order}) for item in self
    ]
    return ScenarioList(reordered)
|
622
|
+
|
567
623
|
def to_dataset(self) -> "Dataset":
|
568
624
|
"""
|
569
625
|
>>> s = ScenarioList.from_list("a", [1,2,3])
|
@@ -579,16 +635,32 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
579
635
|
data = [{key: [scenario[key] for scenario in self.data]} for key in keys]
|
580
636
|
return Dataset(data)
|
581
637
|
|
582
|
-
def
|
583
|
-
self, field: str,
|
638
|
+
def unpack(
    self, field: str, new_names: Optional[List[str]] = None, keep_original=True
) -> ScenarioList:
    """Unpack a field into multiple fields.

    Args:
        field: Name of the field whose value is indexed positionally.
        new_names: Names for the unpacked fields. Defaults to
            ``f"{field}_{i}"`` for each position of the first scenario's
            value. If exactly one name is given, the whole field value is
            stored under that name rather than its first element.
        keep_original: When false, ``field`` is removed from each result.

    Returns:
        A new ScenarioList; the scenarios of this list are copied, not mutated.

    Example:

    >>> s = ScenarioList([Scenario({'a': 1, 'b': [2, True]}), Scenario({'a': 3, 'b': [3, False]})])
    >>> s.unpack('b')
    ScenarioList([Scenario({'a': 1, 'b': [2, True], 'b_0': 2, 'b_1': True}), Scenario({'a': 3, 'b': [3, False], 'b_0': 3, 'b_1': False})])
    >>> s.unpack('b', new_names=['c', 'd'], keep_original=False)
    ScenarioList([Scenario({'a': 1, 'c': 2, 'd': True}), Scenario({'a': 3, 'c': 3, 'd': False})])

    """
    # Nothing to unpack; also avoids indexing self[0] on an empty list when
    # the default names have to be derived from the first scenario.
    if len(self) == 0:
        return ScenarioList([])

    new_names = new_names or [f"{field}_{i}" for i in range(len(self[0][field]))]
    new_scenarios = []
    for scenario in self:
        new_scenario = scenario.copy()
        if len(new_names) == 1:
            new_scenario[new_names[0]] = scenario[field]
        else:
            for i, new_name in enumerate(new_names):
                new_scenario[new_name] = scenario[field][i]

        if not keep_original:
            del new_scenario[field]
        new_scenarios.append(new_scenario)
    return ScenarioList(new_scenarios)
|
594
666
|
|
@@ -901,33 +973,32 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
901
973
|
return cls.from_excel(temp_filename, sheet_name=sheet_name)
|
902
974
|
|
903
975
|
@classmethod
|
904
|
-
def
|
905
|
-
|
976
|
+
def from_delimited_file(
|
977
|
+
cls, source: Union[str, urllib.parse.ParseResult], delimiter: str = ","
|
978
|
+
) -> ScenarioList:
|
979
|
+
"""Create a ScenarioList from a delimited file (CSV/TSV) or URL.
|
906
980
|
|
907
981
|
Args:
|
908
|
-
source: A string representing either a local file path or a URL to a
|
982
|
+
source: A string representing either a local file path or a URL to a delimited file,
|
909
983
|
or a urllib.parse.ParseResult object for a URL.
|
984
|
+
delimiter: The delimiter used in the file. Defaults to ',' for CSV files.
|
985
|
+
Use '\t' for TSV files.
|
910
986
|
|
911
987
|
Returns:
|
912
|
-
ScenarioList: A ScenarioList object containing the data from the
|
988
|
+
ScenarioList: A ScenarioList object containing the data from the file.
|
913
989
|
|
914
990
|
Example:
|
991
|
+
# For CSV files
|
915
992
|
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
>>> scenario_list[0]['name']
|
925
|
-
'Alice'
|
926
|
-
>>> scenario_list[1]['age']
|
927
|
-
'25'
|
993
|
+
>>> with open('data.csv', 'w') as f:
|
994
|
+
... _ = f.write('name,age\\nAlice,30\\nBob,25\\n')
|
995
|
+
>>> scenario_list = ScenarioList.from_delimited_file('data.csv')
|
996
|
+
|
997
|
+
# For TSV files
|
998
|
+
>>> with open('data.tsv', 'w') as f:
|
999
|
+
... _ = f.write('name\\tage\\nAlice\t30\\nBob\t25\\n')
|
1000
|
+
>>> scenario_list = ScenarioList.from_delimited_file('data.tsv', delimiter='\\t')
|
928
1001
|
|
929
|
-
>>> url = "https://example.com/data.csv"
|
930
|
-
>>> ## scenario_list_from_url = ScenarioList.from_csv(url)
|
931
1002
|
"""
|
932
1003
|
from edsl.scenarios.Scenario import Scenario
|
933
1004
|
|
@@ -940,24 +1011,111 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
940
1011
|
|
941
1012
|
if isinstance(source, str) and is_url(source):
|
942
1013
|
with urllib.request.urlopen(source) as response:
|
943
|
-
|
944
|
-
|
1014
|
+
file_content = response.read().decode("utf-8")
|
1015
|
+
file_obj = StringIO(file_content)
|
945
1016
|
elif isinstance(source, urllib.parse.ParseResult):
|
946
1017
|
with urllib.request.urlopen(source.geturl()) as response:
|
947
|
-
|
948
|
-
|
1018
|
+
file_content = response.read().decode("utf-8")
|
1019
|
+
file_obj = StringIO(file_content)
|
949
1020
|
else:
|
950
|
-
|
1021
|
+
file_obj = open(source, "r")
|
951
1022
|
|
952
1023
|
try:
|
953
|
-
reader = csv.reader(
|
1024
|
+
reader = csv.reader(file_obj, delimiter=delimiter)
|
954
1025
|
header = next(reader)
|
955
1026
|
observations = [Scenario(dict(zip(header, row))) for row in reader]
|
956
1027
|
finally:
|
957
|
-
|
1028
|
+
file_obj.close()
|
958
1029
|
|
959
1030
|
return cls(observations)
|
960
1031
|
|
1032
|
+
# Convenience methods for specific file types
|
1033
|
+
@classmethod
def from_csv(cls, source: Union[str, urllib.parse.ParseResult]) -> ScenarioList:
    """Create a ScenarioList from a CSV file or URL.

    Thin wrapper around ``from_delimited_file`` with a comma delimiter.
    """
    comma = ","
    return cls.from_delimited_file(source, delimiter=comma)
|
1037
|
+
|
1038
|
+
def left_join(self, other: ScenarioList, by: Union[str, list[str]]) -> ScenarioList:
    """Perform a left join with another ScenarioList, following SQL join semantics.

    The work is delegated to ``ScenarioJoin``; this list is the left side.

    Args:
        other: The ScenarioList to join with
        by: String or list of strings representing the key(s) to join on. Cannot be empty.

    Returns:
        A new ScenarioList with one scenario per scenario of this list.

    Raises:
        ValueError: If ``by`` is empty or a join key is missing from either list.

    >>> s1 = ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
    >>> s2 = ScenarioList([Scenario({'name': 'Alice', 'location': 'New York'}), Scenario({'name': 'Charlie', 'location': 'Los Angeles'})])
    >>> s3 = s1.left_join(s2, 'name')
    >>> s3 == ScenarioList([Scenario({'age': 30, 'location': 'New York', 'name': 'Alice'}), Scenario({'age': 25, 'location': None, 'name': 'Bob'})])
    True
    """
    # Imported here rather than at module level: ScenarioJoin itself imports
    # ScenarioList, so a top-level import would be circular.
    from edsl.scenarios.ScenarioJoin import ScenarioJoin

    sj = ScenarioJoin(self, other)
    return sj.left_join(by)
|
1113
|
+
|
1114
|
+
@classmethod
def from_tsv(cls, source: Union[str, urllib.parse.ParseResult]) -> ScenarioList:
    """Create a ScenarioList from a TSV file or URL.

    Thin wrapper around ``from_delimited_file`` with a tab delimiter.
    """
    tab = "\t"
    return cls.from_delimited_file(source, delimiter=tab)
|
1118
|
+
|
961
1119
|
def to_dict(self, sort=False, add_edsl_version=True) -> dict:
|
962
1120
|
"""
|
963
1121
|
>>> s = ScenarioList([Scenario({'food': 'wood chips'}), Scenario({'food': 'wood-fired pizza'})])
|
@@ -1074,8 +1232,25 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1074
1232
|
"""
|
1075
1233
|
from edsl.agents.AgentList import AgentList
|
1076
1234
|
from edsl.agents.Agent import Agent
|
1235
|
+
import warnings
|
1236
|
+
|
1237
|
+
agents = []
|
1238
|
+
for scenario in self:
|
1239
|
+
new_scenario = scenario.copy().data
|
1240
|
+
if "name" in new_scenario:
|
1241
|
+
name = new_scenario.pop("name")
|
1242
|
+
proposed_agent_name = "agent_name"
|
1243
|
+
while proposed_agent_name not in new_scenario:
|
1244
|
+
proposed_agent_name += "_"
|
1245
|
+
warnings.warn(
|
1246
|
+
f"The 'name' field is reserved for the agent's name---putting this value in {proposed_agent_name}"
|
1247
|
+
)
|
1248
|
+
new_scenario[proposed_agent_name] = name
|
1249
|
+
agents.append(Agent(traits=new_scenario, name=name))
|
1250
|
+
else:
|
1251
|
+
agents.append(Agent(traits=new_scenario))
|
1077
1252
|
|
1078
|
-
return AgentList(
|
1253
|
+
return AgentList(agents)
|
1079
1254
|
|
1080
1255
|
def chunk(
|
1081
1256
|
self,
|
edsl/surveys/Survey.py
CHANGED
@@ -41,6 +41,8 @@ class ValidatedString(str):
|
|
41
41
|
class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
|
42
42
|
"""A collection of questions that supports skip logic."""
|
43
43
|
|
44
|
+
__documentation__ = """https://docs.expectedparrot.com/en/latest/surveys.html"""
|
45
|
+
|
44
46
|
questions = QuestionsDescriptor()
|
45
47
|
"""
|
46
48
|
A collection of questions that supports skip logic.
|
@@ -1587,10 +1589,22 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
|
|
1587
1589
|
# question_names_string = ", ".join([repr(name) for name in self.question_names])
|
1588
1590
|
return f"Survey(questions=[{questions_string}], memory_plan={self.memory_plan}, rule_collection={self.rule_collection}, question_groups={self.question_groups})"
|
1589
1591
|
|
1592
|
+
def _summary(self) -> dict:
    """Return a small dict describing this Survey.

    Contains the class name, the number of questions (``len(self)``) and
    the question names.
    """
    summary = {"EDSL Class": "Survey"}
    summary["Number of Questions"] = len(self)
    summary["Question Names"] = self.question_names
    return summary
|
1598
|
+
|
1590
1599
|
def _repr_html_(self) -> str:
    """Return the HTML summary of the survey followed by a "(docs)" link.

    The link target is taken from ``__documentation__``.
    """
    docs_link = f"<a href={self.__documentation__}>(docs)</a>"
    summary_html = str(self.summary(format="html"))
    return summary_html + docs_link
|
1602
|
+
|
1603
|
+
def tree(self, node_list: Optional[List[str]] = None):
    """Return the survey as a tree.

    Converts the survey to a scenario list and delegates to its ``tree``
    method, passing ``node_list`` by keyword.
    """
    scenario_list = self.to_scenario_list()
    return scenario_list.tree(node_list=node_list)
|
1592
1605
|
|
1593
|
-
|
1606
|
+
def table(self, *fields, tablefmt=None) -> Table:
    """Return the survey as a table.

    Converts the survey to a scenario list, then to a dataset, and delegates
    to the dataset's ``table`` method with the given fields and format.
    """
    dataset = self.to_scenario_list().to_dataset()
    return dataset.table(*fields, tablefmt=tablefmt)
|
1594
1608
|
|
1595
1609
|
def rich_print(self) -> Table:
|
1596
1610
|
"""Print the survey in a rich format.
|
@@ -1,27 +1,85 @@
|
|
1
|
-
"""A mixin for visualizing the flow of a survey."""
|
1
|
+
"""A mixin for visualizing the flow of a survey with parameter nodes."""
|
2
2
|
|
3
3
|
from typing import Optional
|
4
4
|
from edsl.surveys.base import RulePriority, EndOfSurvey
|
5
5
|
import tempfile
|
6
|
+
import os
|
6
7
|
|
7
8
|
|
8
9
|
class SurveyFlowVisualizationMixin:
|
9
|
-
"""A mixin for visualizing the flow of a survey."""
|
10
|
+
"""A mixin for visualizing the flow of a survey with parameter visualization."""
|
10
11
|
|
11
12
|
def show_flow(self, filename: Optional[str] = None):
|
12
|
-
"""Create an image showing the flow of users through the survey."""
|
13
|
+
"""Create an image showing the flow of users through the survey and question parameters."""
|
13
14
|
# Create a graph object
|
14
15
|
import pydot
|
15
16
|
|
16
17
|
graph = pydot.Dot(graph_type="digraph")
|
17
18
|
|
18
|
-
#
|
19
|
+
# First collect all unique parameters and answer references
|
20
|
+
params_and_refs = set()
|
21
|
+
param_to_questions = {} # Keep track of which questions use each parameter
|
22
|
+
answer_refs = set() # Track answer references between questions
|
23
|
+
|
24
|
+
# First pass: collect parameters and their question associations
|
19
25
|
for index, question in enumerate(self.questions):
|
20
|
-
|
21
|
-
|
22
|
-
|
26
|
+
# Add the main question node
|
27
|
+
question_node = pydot.Node(
|
28
|
+
f"Q{index}", label=f"{question.question_name}", shape="ellipse"
|
29
|
+
)
|
30
|
+
graph.add_node(question_node)
|
31
|
+
|
32
|
+
if hasattr(question, "parameters"):
|
33
|
+
for param in question.parameters:
|
34
|
+
# Check if this is an answer reference (contains '.answer')
|
35
|
+
if ".answer" in param:
|
36
|
+
answer_refs.add((param.split(".")[0], index))
|
37
|
+
else:
|
38
|
+
params_and_refs.add(param)
|
39
|
+
if param not in param_to_questions:
|
40
|
+
param_to_questions[param] = []
|
41
|
+
param_to_questions[param].append(index)
|
42
|
+
|
43
|
+
# Create parameter nodes and connect them to questions
|
44
|
+
for param in params_and_refs:
|
45
|
+
param_node_name = f"param_{param}"
|
46
|
+
param_node = pydot.Node(
|
47
|
+
param_node_name,
|
48
|
+
label=f"{{{{ {param} }}}}",
|
49
|
+
shape="box",
|
50
|
+
style="filled",
|
51
|
+
fillcolor="lightgrey",
|
52
|
+
fontsize="10",
|
53
|
+
)
|
54
|
+
graph.add_node(param_node)
|
55
|
+
|
56
|
+
# Connect this parameter to all questions that use it
|
57
|
+
for q_index in param_to_questions[param]:
|
58
|
+
param_edge = pydot.Edge(
|
59
|
+
param_node_name,
|
60
|
+
f"Q{q_index}",
|
61
|
+
style="dotted",
|
62
|
+
color="grey",
|
63
|
+
arrowsize="0.5",
|
23
64
|
)
|
65
|
+
graph.add_edge(param_edge)
|
66
|
+
|
67
|
+
# Add edges for answer references
|
68
|
+
for source_q_name, target_q_index in answer_refs:
|
69
|
+
# Find the source question index by name
|
70
|
+
source_q_index = next(
|
71
|
+
i
|
72
|
+
for i, q in enumerate(self.questions)
|
73
|
+
if q.question_name == source_q_name
|
74
|
+
)
|
75
|
+
ref_edge = pydot.Edge(
|
76
|
+
f"Q{source_q_index}",
|
77
|
+
f"Q{target_q_index}",
|
78
|
+
style="dashed",
|
79
|
+
color="purple",
|
80
|
+
label="answer reference",
|
24
81
|
)
|
82
|
+
graph.add_edge(ref_edge)
|
25
83
|
|
26
84
|
# Add an "EndOfSurvey" node
|
27
85
|
graph.add_node(
|
@@ -30,7 +88,7 @@ class SurveyFlowVisualizationMixin:
|
|
30
88
|
|
31
89
|
# Add edges for normal flow through the survey
|
32
90
|
num_questions = len(self.questions)
|
33
|
-
for index in range(num_questions - 1):
|
91
|
+
for index in range(num_questions - 1):
|
34
92
|
graph.add_edge(pydot.Edge(f"Q{index}", f"Q{index+1}"))
|
35
93
|
|
36
94
|
graph.add_edge(pydot.Edge(f"Q{num_questions-1}", "EndOfSurvey"))
|
@@ -64,7 +122,7 @@ class SurveyFlowVisualizationMixin:
|
|
64
122
|
if rule.next_q != EndOfSurvey and rule.next_q < num_questions
|
65
123
|
else "EndOfSurvey"
|
66
124
|
)
|
67
|
-
if rule.before_rule:
|
125
|
+
if rule.before_rule:
|
68
126
|
edge = pydot.Edge(
|
69
127
|
source_node,
|
70
128
|
target_node,
|