edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. edsl/Base.py +60 -31
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Agent.py +18 -9
  4. edsl/agents/AgentList.py +59 -8
  5. edsl/agents/Invigilator.py +18 -7
  6. edsl/agents/InvigilatorBase.py +0 -19
  7. edsl/agents/PromptConstructor.py +5 -4
  8. edsl/config.py +8 -0
  9. edsl/coop/coop.py +74 -7
  10. edsl/data/Cache.py +27 -2
  11. edsl/data/CacheEntry.py +8 -3
  12. edsl/data/RemoteCacheSync.py +0 -19
  13. edsl/enums.py +2 -0
  14. edsl/inference_services/GoogleService.py +7 -15
  15. edsl/inference_services/PerplexityService.py +163 -0
  16. edsl/inference_services/registry.py +2 -0
  17. edsl/jobs/Jobs.py +88 -548
  18. edsl/jobs/JobsChecks.py +147 -0
  19. edsl/jobs/JobsPrompts.py +268 -0
  20. edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
  21. edsl/jobs/interviews/Interview.py +11 -11
  22. edsl/jobs/runners/JobsRunnerAsyncio.py +140 -35
  23. edsl/jobs/runners/JobsRunnerStatus.py +0 -2
  24. edsl/jobs/tasks/TaskHistory.py +15 -16
  25. edsl/language_models/LanguageModel.py +44 -84
  26. edsl/language_models/ModelList.py +47 -1
  27. edsl/language_models/registry.py +57 -4
  28. edsl/prompts/Prompt.py +8 -3
  29. edsl/questions/QuestionBase.py +20 -16
  30. edsl/questions/QuestionExtract.py +3 -4
  31. edsl/questions/question_registry.py +36 -6
  32. edsl/results/CSSParameterizer.py +108 -0
  33. edsl/results/Dataset.py +146 -15
  34. edsl/results/DatasetExportMixin.py +231 -217
  35. edsl/results/DatasetTree.py +134 -4
  36. edsl/results/Result.py +18 -9
  37. edsl/results/Results.py +145 -51
  38. edsl/results/TableDisplay.py +198 -0
  39. edsl/results/table_display.css +78 -0
  40. edsl/scenarios/FileStore.py +187 -13
  41. edsl/scenarios/Scenario.py +61 -4
  42. edsl/scenarios/ScenarioJoin.py +127 -0
  43. edsl/scenarios/ScenarioList.py +237 -62
  44. edsl/surveys/Survey.py +16 -2
  45. edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
  46. edsl/surveys/instructions/Instruction.py +12 -0
  47. edsl/templates/error_reporting/interview_details.html +3 -3
  48. edsl/templates/error_reporting/interviews.html +18 -9
  49. edsl/utilities/utilities.py +15 -0
  50. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/METADATA +2 -1
  51. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/RECORD +53 -45
  52. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/LICENSE +0 -0
  53. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/WHEEL +0 -0
edsl/scenarios/ScenarioJoin.py ADDED
@@ -0,0 +1,127 @@
1
+ from __future__ import annotations
2
+ from typing import Union, TYPE_CHECKING
3
+
4
+ # if TYPE_CHECKING:
5
+ from edsl.scenarios.ScenarioList import ScenarioList
6
+ from edsl.scenarios.Scenario import Scenario
7
+
8
+
9
class ScenarioJoin:
    """Handles join operations between two ScenarioLists.

    This class encapsulates all join-related logic, making it easier to maintain
    and extend with other join types (inner, right, full) in the future.
    """

    def __init__(self, left: "ScenarioList", right: "ScenarioList"):
        """Initialize join operation with two ScenarioLists.

        Args:
            left: The left ScenarioList.
            right: The right ScenarioList.
        """
        self.left = left
        self.right = right

    def left_join(self, by: Union[str, list[str]]) -> "ScenarioList":
        """Perform a left join between the two ScenarioLists.

        Every left scenario appears exactly once in the result, in its original
        order. Keys that exist only on the right are copied from the matching
        right scenario, or set to ``None`` when there is no match. If several
        right scenarios share the same join-key values, the last one wins.
        When a non-join key exists on both sides with different values, the
        left value is kept and a ``UserWarning`` is emitted.

        Output keys are ordered deterministically: keys seen on the left first
        (in first-seen order), followed by right-only keys.

        Args:
            by: String or list of strings representing the key(s) to join on.
                Cannot be empty.

        Returns:
            A new ScenarioList containing the joined scenarios.

        Raises:
            ValueError: If ``by`` is empty, or if a join key is missing from any
                scenario of either ScenarioList.
            TypeError: If a join-key value is unhashable (e.g. a list), because
                key values are used as dictionary keys for the lookup.
        """
        self._validate_join_keys(by)
        by_keys = self._normalize_keys(by)

        other_dict = self._create_lookup_dict(self.right, by_keys)
        all_keys = self._get_all_keys()

        return ScenarioList(
            self._create_joined_scenarios(by_keys, other_dict, all_keys)
        )

    @staticmethod
    def _normalize_keys(by: Union[str, list[str]]) -> list[str]:
        """Return the join key(s) as a list, wrapping a single string key."""
        return [by] if isinstance(by, str) else list(by)

    def _validate_join_keys(self, by: Union[str, list[str]]) -> None:
        """Validate the join keys against both ScenarioLists.

        Every scenario on both sides is checked, so a key missing from a later
        scenario is reported here instead of failing with a lookup error
        halfway through the join. An empty ScenarioList has no scenarios to
        check, so joining with (or from) an empty side is allowed.

        Raises:
            ValueError: If ``by`` is empty or any join key is missing.
        """
        if not by:
            raise ValueError(
                "Join keys cannot be empty. Please specify at least one key to join on."
            )

        required = set(self._normalize_keys(by))
        missing = set()
        for side in (self.left, self.right):
            for scenario in side:
                missing |= required - set(scenario.keys())
        if missing:
            raise ValueError(f"Join key(s) {missing} not found in both ScenarioLists")

    @staticmethod
    def _get_key_tuple(scenario: "Scenario", keys: list[str]) -> tuple:
        """Create a tuple of the scenario's values for the join keys, in order."""
        return tuple(scenario[k] for k in keys)

    def _create_lookup_dict(self, scenarios: "ScenarioList", by_keys: list[str]) -> dict:
        """Map join-key tuples to scenarios; later duplicates overwrite earlier ones."""
        return {
            self._get_key_tuple(scenario, by_keys): scenario for scenario in scenarios
        }

    def _get_all_keys(self) -> list:
        """Return all unique keys of both ScenarioLists in first-seen order.

        A dict is used as an insertion-ordered set so the column order of the
        joined scenarios is the same on every run (a plain ``set`` is not).
        """
        ordered = {}
        for side in (self.left, self.right):
            for scenario in side:
                for key in scenario.keys():
                    ordered.setdefault(key, None)
        return list(ordered)

    def _create_joined_scenarios(
        self, by_keys: list[str], other_dict: dict, all_keys: list
    ) -> list["Scenario"]:
        """Build one joined Scenario per left scenario."""
        new_scenarios = []

        for scenario in self.left:
            # Start with every column set to None (like SQL NULLs), then
            # overlay the left values; positions follow ``all_keys``.
            new_scenario = dict.fromkeys(all_keys)
            new_scenario.update(scenario)

            matching_scenario = other_dict.get(self._get_key_tuple(scenario, by_keys))
            if matching_scenario is not None:
                self._handle_matching_scenario(
                    new_scenario, scenario, matching_scenario, by_keys
                )

            new_scenarios.append(Scenario(new_scenario))

        return new_scenarios

    def _handle_matching_scenario(
        self,
        new_scenario: dict,
        left_scenario: "Scenario",
        right_scenario: "Scenario",
        by_keys: list[str],
    ) -> None:
        """Merge right-only keys into ``new_scenario`` and warn on conflicts.

        Overlapping non-join keys keep the left value; a ``UserWarning`` is
        emitted when the two values differ. ``new_scenario`` is mutated in place.
        """
        import warnings

        left_keys = set(left_scenario.keys())
        right_keys = set(right_scenario.keys())

        for key in left_keys & right_keys:
            if key not in by_keys and left_scenario[key] != right_scenario[key]:
                join_conditions = [f"{k}='{left_scenario[k]}'" for k in by_keys]
                warnings.warn(
                    f"Conflicting values for key '{key}' where "
                    f"{' AND '.join(join_conditions)}. "
                    f"Keeping left value: {left_scenario[key]} "
                    f"(discarding: {right_scenario[key]})"
                )

        # Only take keys the left scenario does not already define.
        new_scenario.update({k: right_scenario[k] for k in right_keys - left_keys})
edsl/scenarios/ScenarioList.py CHANGED
@@ -31,6 +31,10 @@ class ScenarioListMixin(ScenarioListPdfMixin, ScenarioListExportMixin):
31
31
  class ScenarioList(Base, UserList, ScenarioListMixin):
32
32
  """Class for creating a list of scenarios to be used in a survey."""
33
33
 
34
+ __documentation__ = (
35
+ "https://docs.expectedparrot.com/en/latest/scenarios.html#scenariolist"
36
+ )
37
+
34
38
  def __init__(self, data: Optional[list] = None, codebook: Optional[dict] = None):
35
39
  """Initialize the ScenarioList class."""
36
40
  if data is not None:
@@ -241,6 +245,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
241
245
 
242
246
  return dict_hash(self.to_dict(sort=True, add_edsl_version=False))
243
247
 
248
def __eq__(self, other: Any) -> bool:
    """Return True when ``other`` is a ScenarioList whose hash matches this one's.

    Equality compares ``hash(self)`` with ``hash(other)``. Objects of any other
    type get ``NotImplemented`` so Python falls back to its default handling.
    Hashing an arbitrary ``other`` would raise for unhashable values such as
    plain lists, and would wrongly report equality with an int that happens
    to equal this list's hash.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return hash(self) == hash(other)
250
+
244
251
  def __repr__(self):
245
252
  return f"ScenarioList({self.data})"
246
253
 
@@ -282,41 +289,49 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
282
289
  random.shuffle(self.data)
283
290
  return self
284
291
 
285
- def _repr_html_(self) -> str:
286
- from edsl.utilities.utilities import data_to_html
287
-
288
- data = self.to_dict()
289
- _ = data.pop("edsl_version")
290
- _ = data.pop("edsl_class_name")
291
- for s in data["scenarios"]:
292
- _ = s.pop("edsl_version")
293
- _ = s.pop("edsl_class_name")
294
- for scenario in data["scenarios"]:
295
- for key, value in scenario.items():
296
- if hasattr(value, "to_dict"):
297
- data[key] = value.to_dict()
298
- return data_to_html(data)
299
-
300
- def tally(self, field) -> dict:
301
- """Return a tally of the values in the field.
302
-
303
- Example:
304
-
305
- >>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
306
- >>> s.tally('b')
307
- {1: 1, 2: 1}
308
- """
309
- return dict(Counter([scenario[field] for scenario in self]))
310
-
311
- def sample(self, n: int, seed="edsl") -> ScenarioList:
292
def _repr_html_(self) -> str:
    """Return an HTML representation of the ScenarioList (used by Jupyter).

    The output is ``self.summary(format="html")`` followed by a link to the
    ScenarioList documentation. The href value is quoted so the anchor is
    valid HTML regardless of the characters in the URL.
    """
    footer = f'<a href="{self.__documentation__}">(docs)</a>'
    return str(self.summary(format="html")) + footer
299
+
300
+ # def _repr_html_(self) -> str:
301
+ # from edsl.utilities.utilities import data_to_html
302
+
303
+ # data = self.to_dict()
304
+ # _ = data.pop("edsl_version")
305
+ # _ = data.pop("edsl_class_name")
306
+ # for s in data["scenarios"]:
307
+ # _ = s.pop("edsl_version")
308
+ # _ = s.pop("edsl_class_name")
309
+ # for scenario in data["scenarios"]:
310
+ # for key, value in scenario.items():
311
+ # if hasattr(value, "to_dict"):
312
+ # data[key] = value.to_dict()
313
+ # return data_to_html(data)
314
+
315
+ # def tally(self, field) -> dict:
316
+ # """Return a tally of the values in the field.
317
+
318
+ # Example:
319
+
320
+ # >>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
321
+ # >>> s.tally('b')
322
+ # {1: 1, 2: 1}
323
+ # """
324
+ # return dict(Counter([scenario[field] for scenario in self]))
325
+
326
def sample(self, n: int, seed: Optional[str] = None) -> "ScenarioList":
    """Return a random sample of ``n`` scenarios from the ScenarioList.

    Args:
        n: Number of scenarios to draw, without replacement.
        seed: Optional seed for a reproducible sample. Any value other than
            ``None`` (including ``0`` or ``""``) is honored. A private
            ``random.Random`` instance is used, so the global ``random``
            module state is not reseeded as a side effect.

    Returns:
        A new ScenarioList holding the sampled scenarios.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``n`` is negative
            or larger than the number of scenarios.

    >>> s = ScenarioList.from_list("a", [1,2,3,4,5,6])
    >>> s.sample(3, seed = "edsl")
    ScenarioList([Scenario({'a': 2}), Scenario({'a': 1}), Scenario({'a': 3})])
    """
    # Without a seed, draw from the shared module-level generator.
    rng = random if seed is None else random.Random(seed)
    return ScenarioList(rng.sample(self.data, n))
322
337
 
@@ -564,6 +579,47 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
564
579
  func = lambda x: x
565
580
  return cls([Scenario({name: func(value)}) for value in values])
566
581
 
582
def table(self, *fields, tablefmt=None, pretty_labels=None) -> str:
    """Return the ScenarioList rendered as a table via its Dataset form.

    Args:
        *fields: Forwarded unchanged to ``Dataset.table``.
        tablefmt: A ``tabulate`` table-format name, or None for the default.
        pretty_labels: Forwarded unchanged to ``Dataset.table``.

    Raises:
        ValueError: If ``tablefmt`` is not one of ``tabulate_formats``.
    """
    from tabulate import tabulate_formats

    if tablefmt is not None and tablefmt not in tabulate_formats:
        # A single message string; passing two arguments to ValueError would
        # render the message as a tuple.
        raise ValueError(
            f"Invalid table format: {tablefmt}. "
            f"Valid formats are: {tabulate_formats}"
        )
    return self.to_dataset().table(
        *fields, tablefmt=tablefmt, pretty_labels=pretty_labels
    )
595
+
596
def tree(self, node_list: Optional[List[str]] = None) -> str:
    """Render the ScenarioList as a tree using its Dataset conversion.

    ``node_list`` is forwarded positionally to ``Dataset.tree``.
    """
    dataset = self.to_dataset()
    return dataset.tree(node_list)
599
+
600
def _summary(self):
    """Return a dict of basic facts: class name, scenario count, and keys."""
    summary = {}
    summary["EDSL Class name"] = "ScenarioList"
    summary["# Scenarios"] = len(self)
    summary["Scenario Keys"] = list(self.parameters)
    return summary
607
+
608
def reorder_keys(self, new_order):
    """Return a new ScenarioList whose scenarios list their keys in ``new_order``.

    Each scenario is rebuilt from exactly the keys named in ``new_order``, so
    keys not listed there are dropped. Naming a key that a scenario lacks
    fails on the lookup.

    Example:

    >>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 3, 'b': 4})])
    >>> s.reorder_keys(['b', 'a'])
    ScenarioList([Scenario({'b': 2, 'a': 1}), Scenario({'b': 4, 'a': 3})])
    """
    return ScenarioList(
        [Scenario({key: scenario[key] for key in new_order}) for scenario in self]
    )
622
+
567
623
  def to_dataset(self) -> "Dataset":
568
624
  """
569
625
  >>> s = ScenarioList.from_list("a", [1,2,3])
@@ -579,16 +635,32 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
579
635
  data = [{key: [scenario[key] for scenario in self.data]} for key in keys]
580
636
  return Dataset(data)
581
637
 
582
- def split(
583
- self, field: str, split_on: str, index: int, new_name: Optional[str] = None
638
def unpack(
    self, field: str, new_names: Optional[List[str]] = None, keep_original=True
) -> "ScenarioList":
    """Unpack a sequence-valued field into multiple fields.

    Args:
        field: Name of the field to unpack.
        new_names: Names for the new fields. When omitted (or empty), the
            names ``f"{field}_{i}"`` are generated, one per element of the
            field's value in the first scenario. When exactly one name is
            given, the whole value is copied under that name instead of
            being indexed.
        keep_original: If False, ``field`` is removed from each new scenario.

    Returns:
        A new ScenarioList. Each scenario is copied before it is modified.
        An empty ScenarioList unpacks to an empty ScenarioList.

    Example:

    >>> s = ScenarioList([Scenario({'a': 1, 'b': [2, True]}), Scenario({'a': 3, 'b': [3, False]})])
    >>> s.unpack('b')
    ScenarioList([Scenario({'a': 1, 'b': [2, True], 'b_0': 2, 'b_1': True}), Scenario({'a': 3, 'b': [3, False], 'b_0': 3, 'b_1': False})])
    >>> s.unpack('b', new_names=['c', 'd'], keep_original=False)
    ScenarioList([Scenario({'a': 1, 'c': 2, 'd': True}), Scenario({'a': 3, 'c': 3, 'd': False})])
    >>> ScenarioList([]).unpack('b')
    ScenarioList([])
    """
    if not new_names:
        if not self:
            # No first scenario to infer default names from; nothing to unpack.
            return ScenarioList([])
        new_names = [f"{field}_{i}" for i in range(len(self[0][field]))]

    new_scenarios = []
    for scenario in self:
        new_scenario = scenario.copy()
        if len(new_names) == 1:
            new_scenario[new_names[0]] = scenario[field]
        else:
            for i, new_name in enumerate(new_names):
                new_scenario[new_name] = scenario[field][i]

        if not keep_original:
            del new_scenario[field]
        new_scenarios.append(new_scenario)
    return ScenarioList(new_scenarios)
594
666
 
@@ -901,33 +973,32 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
901
973
  return cls.from_excel(temp_filename, sheet_name=sheet_name)
902
974
 
903
975
  @classmethod
904
- def from_csv(cls, source: Union[str, urllib.parse.ParseResult]) -> ScenarioList:
905
- """Create a ScenarioList from a CSV file or URL.
976
+ def from_delimited_file(
977
+ cls, source: Union[str, urllib.parse.ParseResult], delimiter: str = ","
978
+ ) -> ScenarioList:
979
+ """Create a ScenarioList from a delimited file (CSV/TSV) or URL.
906
980
 
907
981
  Args:
908
- source: A string representing either a local file path or a URL to a CSV file,
982
+ source: A string representing either a local file path or a URL to a delimited file,
909
983
  or a urllib.parse.ParseResult object for a URL.
984
+ delimiter: The delimiter used in the file. Defaults to ',' for CSV files.
985
+ Use '\t' for TSV files.
910
986
 
911
987
  Returns:
912
- ScenarioList: A ScenarioList object containing the data from the CSV.
988
+ ScenarioList: A ScenarioList object containing the data from the file.
913
989
 
914
990
  Example:
991
+ # For CSV files
915
992
 
916
- >>> import tempfile
917
- >>> import os
918
- >>> with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.csv') as f:
919
- ... _ = f.write("name,age,location\\nAlice,30,New York\\nBob,25,Los Angeles\\n")
920
- ... temp_filename = f.name
921
- >>> scenario_list = ScenarioList.from_csv(temp_filename)
922
- >>> len(scenario_list)
923
- 2
924
- >>> scenario_list[0]['name']
925
- 'Alice'
926
- >>> scenario_list[1]['age']
927
- '25'
993
+ >>> with open('data.csv', 'w') as f:
994
+ ... _ = f.write('name,age\\nAlice,30\\nBob,25\\n')
995
+ >>> scenario_list = ScenarioList.from_delimited_file('data.csv')
996
+
997
+ # For TSV files
998
+ >>> with open('data.tsv', 'w') as f:
999
+ ... _ = f.write('name\\tage\\nAlice\t30\\nBob\t25\\n')
1000
+ >>> scenario_list = ScenarioList.from_delimited_file('data.tsv', delimiter='\\t')
928
1001
 
929
- >>> url = "https://example.com/data.csv"
930
- >>> ## scenario_list_from_url = ScenarioList.from_csv(url)
931
1002
  """
932
1003
  from edsl.scenarios.Scenario import Scenario
933
1004
 
@@ -940,24 +1011,111 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
940
1011
 
941
1012
  if isinstance(source, str) and is_url(source):
942
1013
  with urllib.request.urlopen(source) as response:
943
- csv_content = response.read().decode("utf-8")
944
- csv_file = StringIO(csv_content)
1014
+ file_content = response.read().decode("utf-8")
1015
+ file_obj = StringIO(file_content)
945
1016
  elif isinstance(source, urllib.parse.ParseResult):
946
1017
  with urllib.request.urlopen(source.geturl()) as response:
947
- csv_content = response.read().decode("utf-8")
948
- csv_file = StringIO(csv_content)
1018
+ file_content = response.read().decode("utf-8")
1019
+ file_obj = StringIO(file_content)
949
1020
  else:
950
- csv_file = open(source, "r")
1021
+ file_obj = open(source, "r")
951
1022
 
952
1023
  try:
953
- reader = csv.reader(csv_file)
1024
+ reader = csv.reader(file_obj, delimiter=delimiter)
954
1025
  header = next(reader)
955
1026
  observations = [Scenario(dict(zip(header, row))) for row in reader]
956
1027
  finally:
957
- csv_file.close()
1028
+ file_obj.close()
958
1029
 
959
1030
  return cls(observations)
960
1031
 
1032
+ # Convenience methods for specific file types
1033
@classmethod
def from_csv(cls, source: Union[str, urllib.parse.ParseResult]) -> "ScenarioList":
    """Build a ScenarioList from a comma-separated file path or URL.

    Thin wrapper around ``from_delimited_file`` with a comma delimiter.
    """
    comma = ","
    return cls.from_delimited_file(source, delimiter=comma)
1037
+
1038
def left_join(self, other: "ScenarioList", by: Union[str, list[str]]) -> "ScenarioList":
    """Perform a left join with another ScenarioList.

    Every scenario of ``self`` appears in the result, extended with the keys
    of the scenario in ``other`` whose join-key values match; keys with no
    match are set to ``None``. If several scenarios in ``other`` share the
    same join-key values, only the last of them is used (unlike SQL, which
    would emit one row per match). The work is delegated to
    ``ScenarioJoin.left_join``.

    Args:
        other: The ScenarioList to join with.
        by: String or list of strings representing the key(s) to join on. Cannot be empty.

    Returns:
        A new ScenarioList with the joined scenarios.

    >>> s1 = ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
    >>> s2 = ScenarioList([Scenario({'name': 'Alice', 'location': 'New York'}), Scenario({'name': 'Charlie', 'location': 'Los Angeles'})])
    >>> s3 = s1.left_join(s2, 'name')
    >>> s3 == ScenarioList([Scenario({'age': 30, 'location': 'New York', 'name': 'Alice'}), Scenario({'age': 25, 'location': None, 'name': 'Bob'})])
    True
    """
    # Imported lazily; ScenarioJoin imports ScenarioList at module level.
    from edsl.scenarios.ScenarioJoin import ScenarioJoin

    return ScenarioJoin(self, other).left_join(by)
1113
+
1114
@classmethod
def from_tsv(cls, source: Union[str, urllib.parse.ParseResult]) -> "ScenarioList":
    """Build a ScenarioList from a tab-separated file path or URL.

    Thin wrapper around ``from_delimited_file`` with a tab delimiter.
    """
    tab = "\t"
    return cls.from_delimited_file(source, delimiter=tab)
1118
+
961
1119
  def to_dict(self, sort=False, add_edsl_version=True) -> dict:
962
1120
  """
963
1121
  >>> s = ScenarioList([Scenario({'food': 'wood chips'}), Scenario({'food': 'wood-fired pizza'})])
@@ -1074,8 +1232,25 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
1074
1232
  """
1075
1233
  from edsl.agents.AgentList import AgentList
1076
1234
  from edsl.agents.Agent import Agent
1235
+ import warnings
1236
+
1237
+ agents = []
1238
+ for scenario in self:
1239
+ new_scenario = scenario.copy().data
1240
+ if "name" in new_scenario:
1241
+ name = new_scenario.pop("name")
1242
+ proposed_agent_name = "agent_name"
1243
+ while proposed_agent_name not in new_scenario:
1244
+ proposed_agent_name += "_"
1245
+ warnings.warn(
1246
+ f"The 'name' field is reserved for the agent's name---putting this value in {proposed_agent_name}"
1247
+ )
1248
+ new_scenario[proposed_agent_name] = name
1249
+ agents.append(Agent(traits=new_scenario, name=name))
1250
+ else:
1251
+ agents.append(Agent(traits=new_scenario))
1077
1252
 
1078
- return AgentList([Agent(traits=s.data) for s in self])
1253
+ return AgentList(agents)
1079
1254
 
1080
1255
  def chunk(
1081
1256
  self,
edsl/surveys/Survey.py CHANGED
@@ -41,6 +41,8 @@ class ValidatedString(str):
41
41
  class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
42
42
  """A collection of questions that supports skip logic."""
43
43
 
44
+ __documentation__ = """https://docs.expectedparrot.com/en/latest/surveys.html"""
45
+
44
46
  questions = QuestionsDescriptor()
45
47
  """
46
48
  A collection of questions that supports skip logic.
@@ -1587,10 +1589,22 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1587
1589
  # question_names_string = ", ".join([repr(name) for name in self.question_names])
1588
1590
  return f"Survey(questions=[{questions_string}], memory_plan={self.memory_plan}, rule_collection={self.rule_collection}, question_groups={self.question_groups})"
1589
1591
 
1592
def _summary(self) -> dict:
    """Return a dict of basic facts: class name, question count, question names."""
    summary = {}
    summary["EDSL Class"] = "Survey"
    summary["Number of Questions"] = len(self)
    summary["Question Names"] = self.question_names
    return summary
1598
+
1590
1599
  def _repr_html_(self) -> str:
1591
- from edsl.utilities.utilities import data_to_html
1600
+ footer = f"<a href={self.__documentation__}>(docs)</a>"
1601
+ return str(self.summary(format="html")) + footer
1602
+
1603
def tree(self, node_list: Optional[List[str]] = None):
    """Return a tree view of the Survey built from its ScenarioList form.

    ``node_list`` is forwarded to ``ScenarioList.tree`` as a keyword argument.
    """
    scenarios = self.to_scenario_list()
    return scenarios.tree(node_list=node_list)
1592
1605
 
1593
- return data_to_html(self.to_dict())
1606
def table(self, *fields, tablefmt=None) -> "Table":
    """Return the Survey rendered as a table via its ScenarioList/Dataset form.

    ``fields`` and ``tablefmt`` are forwarded unchanged to ``Dataset.table``.
    """
    scenario_list = self.to_scenario_list()
    dataset = scenario_list.to_dataset()
    return dataset.table(*fields, tablefmt=tablefmt)
1594
1608
 
1595
1609
  def rich_print(self) -> Table:
1596
1610
  """Print the survey in a rich format.
edsl/surveys/SurveyFlowVisualizationMixin.py CHANGED
@@ -1,27 +1,85 @@
1
- """A mixin for visualizing the flow of a survey."""
1
+ """A mixin for visualizing the flow of a survey with parameter nodes."""
2
2
 
3
3
  from typing import Optional
4
4
  from edsl.surveys.base import RulePriority, EndOfSurvey
5
5
  import tempfile
6
+ import os
6
7
 
7
8
 
8
9
  class SurveyFlowVisualizationMixin:
9
- """A mixin for visualizing the flow of a survey."""
10
+ """A mixin for visualizing the flow of a survey with parameter visualization."""
10
11
 
11
12
  def show_flow(self, filename: Optional[str] = None):
12
- """Create an image showing the flow of users through the survey."""
13
+ """Create an image showing the flow of users through the survey and question parameters."""
13
14
  # Create a graph object
14
15
  import pydot
15
16
 
16
17
  graph = pydot.Dot(graph_type="digraph")
17
18
 
18
- # Add nodes for each question
19
+ # First collect all unique parameters and answer references
20
+ params_and_refs = set()
21
+ param_to_questions = {} # Keep track of which questions use each parameter
22
+ answer_refs = set() # Track answer references between questions
23
+
24
+ # First pass: collect parameters and their question associations
19
25
  for index, question in enumerate(self.questions):
20
- graph.add_node(
21
- pydot.Node(
22
- f"Q{index}", label=f"{question.question_name}", shape="ellipse"
26
+ # Add the main question node
27
+ question_node = pydot.Node(
28
+ f"Q{index}", label=f"{question.question_name}", shape="ellipse"
29
+ )
30
+ graph.add_node(question_node)
31
+
32
+ if hasattr(question, "parameters"):
33
+ for param in question.parameters:
34
+ # Check if this is an answer reference (contains '.answer')
35
+ if ".answer" in param:
36
+ answer_refs.add((param.split(".")[0], index))
37
+ else:
38
+ params_and_refs.add(param)
39
+ if param not in param_to_questions:
40
+ param_to_questions[param] = []
41
+ param_to_questions[param].append(index)
42
+
43
+ # Create parameter nodes and connect them to questions
44
+ for param in params_and_refs:
45
+ param_node_name = f"param_{param}"
46
+ param_node = pydot.Node(
47
+ param_node_name,
48
+ label=f"{{{{ {param} }}}}",
49
+ shape="box",
50
+ style="filled",
51
+ fillcolor="lightgrey",
52
+ fontsize="10",
53
+ )
54
+ graph.add_node(param_node)
55
+
56
+ # Connect this parameter to all questions that use it
57
+ for q_index in param_to_questions[param]:
58
+ param_edge = pydot.Edge(
59
+ param_node_name,
60
+ f"Q{q_index}",
61
+ style="dotted",
62
+ color="grey",
63
+ arrowsize="0.5",
23
64
  )
65
+ graph.add_edge(param_edge)
66
+
67
+ # Add edges for answer references
68
+ for source_q_name, target_q_index in answer_refs:
69
+ # Find the source question index by name
70
+ source_q_index = next(
71
+ i
72
+ for i, q in enumerate(self.questions)
73
+ if q.question_name == source_q_name
74
+ )
75
+ ref_edge = pydot.Edge(
76
+ f"Q{source_q_index}",
77
+ f"Q{target_q_index}",
78
+ style="dashed",
79
+ color="purple",
80
+ label="answer reference",
24
81
  )
82
+ graph.add_edge(ref_edge)
25
83
 
26
84
  # Add an "EndOfSurvey" node
27
85
  graph.add_node(
@@ -30,7 +88,7 @@ class SurveyFlowVisualizationMixin:
30
88
 
31
89
  # Add edges for normal flow through the survey
32
90
  num_questions = len(self.questions)
33
- for index in range(num_questions - 1): # From Q1 to Q3
91
+ for index in range(num_questions - 1):
34
92
  graph.add_edge(pydot.Edge(f"Q{index}", f"Q{index+1}"))
35
93
 
36
94
  graph.add_edge(pydot.Edge(f"Q{num_questions-1}", "EndOfSurvey"))
@@ -64,7 +122,7 @@ class SurveyFlowVisualizationMixin:
64
122
  if rule.next_q != EndOfSurvey and rule.next_q < num_questions
65
123
  else "EndOfSurvey"
66
124
  )
67
- if rule.before_rule: # Assume skip rules have an attribute `is_skip`
125
+ if rule.before_rule:
68
126
  edge = pydot.Edge(
69
127
  source_node,
70
128
  target_node,