edsl 0.1.37.dev5__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +63 -34
- edsl/BaseDiff.py +7 -7
- edsl/__init__.py +2 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +23 -11
- edsl/agents/AgentList.py +86 -23
- edsl/agents/Invigilator.py +18 -7
- edsl/agents/InvigilatorBase.py +0 -19
- edsl/agents/PromptConstructor.py +5 -4
- edsl/auto/SurveyCreatorPipeline.py +1 -1
- edsl/auto/utilities.py +1 -1
- edsl/base/Base.py +3 -13
- edsl/config.py +8 -0
- edsl/coop/coop.py +89 -19
- edsl/data/Cache.py +45 -17
- edsl/data/CacheEntry.py +8 -3
- edsl/data/RemoteCacheSync.py +0 -19
- edsl/enums.py +2 -0
- edsl/exceptions/agents.py +4 -0
- edsl/exceptions/cache.py +5 -0
- edsl/inference_services/GoogleService.py +7 -15
- edsl/inference_services/PerplexityService.py +163 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +110 -559
- edsl/jobs/JobsChecks.py +147 -0
- edsl/jobs/JobsPrompts.py +268 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
- edsl/jobs/buckets/TokenBucket.py +3 -0
- edsl/jobs/interviews/Interview.py +7 -7
- edsl/jobs/runners/JobsRunnerAsyncio.py +156 -28
- edsl/jobs/runners/JobsRunnerStatus.py +194 -196
- edsl/jobs/tasks/TaskHistory.py +27 -19
- edsl/language_models/LanguageModel.py +52 -90
- edsl/language_models/ModelList.py +67 -14
- edsl/language_models/registry.py +57 -4
- edsl/notebooks/Notebook.py +7 -8
- edsl/prompts/Prompt.py +8 -3
- edsl/questions/QuestionBase.py +38 -30
- edsl/questions/QuestionBaseGenMixin.py +1 -1
- edsl/questions/QuestionBasePromptsMixin.py +0 -17
- edsl/questions/QuestionExtract.py +3 -4
- edsl/questions/QuestionFunctional.py +10 -3
- edsl/questions/derived/QuestionTopK.py +2 -0
- edsl/questions/question_registry.py +36 -6
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +146 -15
- edsl/results/DatasetExportMixin.py +231 -217
- edsl/results/DatasetTree.py +134 -4
- edsl/results/Result.py +31 -16
- edsl/results/Results.py +159 -65
- edsl/results/TableDisplay.py +198 -0
- edsl/results/table_display.css +78 -0
- edsl/scenarios/FileStore.py +187 -13
- edsl/scenarios/Scenario.py +73 -18
- edsl/scenarios/ScenarioJoin.py +127 -0
- edsl/scenarios/ScenarioList.py +251 -76
- edsl/surveys/MemoryPlan.py +1 -1
- edsl/surveys/Rule.py +1 -5
- edsl/surveys/RuleCollection.py +1 -1
- edsl/surveys/Survey.py +25 -19
- edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
- edsl/surveys/instructions/ChangeInstruction.py +9 -7
- edsl/surveys/instructions/Instruction.py +21 -7
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +18 -9
- edsl/{conjure → utilities}/naming_utilities.py +1 -1
- edsl/utilities/utilities.py +15 -0
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/METADATA +2 -1
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/RECORD +71 -77
- edsl/conjure/AgentConstructionMixin.py +0 -160
- edsl/conjure/Conjure.py +0 -62
- edsl/conjure/InputData.py +0 -659
- edsl/conjure/InputDataCSV.py +0 -48
- edsl/conjure/InputDataMixinQuestionStats.py +0 -182
- edsl/conjure/InputDataPyRead.py +0 -91
- edsl/conjure/InputDataSPSS.py +0 -8
- edsl/conjure/InputDataStata.py +0 -8
- edsl/conjure/QuestionOptionMixin.py +0 -76
- edsl/conjure/QuestionTypeMixin.py +0 -23
- edsl/conjure/RawQuestion.py +0 -65
- edsl/conjure/SurveyResponses.py +0 -7
- edsl/conjure/__init__.py +0 -9
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/utilities.py +0 -201
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/LICENSE +0 -0
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/WHEEL +0 -0
@@ -0,0 +1,127 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Union, TYPE_CHECKING
|
3
|
+
|
4
|
+
# if TYPE_CHECKING:
|
5
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
6
|
+
from edsl.scenarios.Scenario import Scenario
|
7
|
+
|
8
|
+
|
9
|
+
class ScenarioJoin:
    """Handles join operations between two ScenarioLists.

    This class encapsulates all join-related logic, making it easier to maintain
    and extend with other join types (inner, right, full) in the future.
    """

    def __init__(self, left: "ScenarioList", right: "ScenarioList"):
        """Initialize join operation with two ScenarioLists.

        Args:
            left: The left ScenarioList; every one of its scenarios appears in
                the result of a left join.
            right: The right ScenarioList; its scenarios contribute extra
                fields to left scenarios whose join-key values match.
        """
        self.left = left
        self.right = right

    def left_join(self, by: Union[str, list[str]]) -> ScenarioList:
        """Perform a left join between the two ScenarioLists.

        Every left scenario appears in the output. When several right
        scenarios share a left scenario's join-key values, one joined scenario
        is produced per match (SQL semantics); when none match, the left
        scenario is emitted once with the right-only fields set to None.

        Args:
            by: String or list of strings representing the key(s) to join on. Cannot be empty.

        Returns:
            A new ScenarioList containing the joined scenarios. Fields appear in
            first-seen order: left fields first, then right-only fields.

        Raises:
            ValueError: If by is empty or if any join key is missing from a
                scenario in either ScenarioList.
        """
        self._validate_join_keys(by)
        by_keys = [by] if isinstance(by, str) else list(by)

        other_dict = self._create_lookup_dict(self.right, by_keys)
        all_keys = self._get_all_keys()

        return ScenarioList(
            self._create_joined_scenarios(by_keys, other_dict, all_keys)
        )

    def _validate_join_keys(self, by: Union[str, list[str]]) -> None:
        """Validate that every join key is present in both ScenarioLists.

        Every scenario is checked, not just the first, so a heterogeneous list
        fails here with a clear ValueError instead of a KeyError mid-join. An
        empty ScenarioList has no scenario that could lack a key, so joining
        with (or from) an empty list is allowed.

        Raises:
            ValueError: If ``by`` is empty or a join key is missing from some
                scenario in either ScenarioList.
        """
        if not by:
            raise ValueError(
                "Join keys cannot be empty. Please specify at least one key to join on."
            )

        required = set([by] if isinstance(by, str) else by)
        missing = self._missing_keys(self.left, required) | self._missing_keys(
            self.right, required
        )
        if missing:
            raise ValueError(f"Join key(s) {missing} not found in both ScenarioLists")

    @staticmethod
    def _missing_keys(scenarios: ScenarioList, required: set) -> set:
        """Return the keys in ``required`` that are absent from at least one scenario."""
        missing = set()
        for scenario in scenarios:
            missing |= required - set(scenario.keys())
        return missing

    @staticmethod
    def _get_key_tuple(scenario: Scenario, keys: list[str]) -> tuple:
        """Create a tuple of values for the join keys."""
        return tuple(scenario[k] for k in keys)

    def _create_lookup_dict(self, scenarios: ScenarioList, by_keys: list[str]) -> dict:
        """Group the right scenarios by their join-key tuple.

        Each key tuple maps to the list of all scenarios sharing it, in their
        original order, so duplicate keys yield one joined row per match
        rather than keeping only the last one.
        """
        lookup = {}
        for scenario in scenarios:
            lookup.setdefault(self._get_key_tuple(scenario, by_keys), []).append(
                scenario
            )
        return lookup

    def _get_all_keys(self) -> list:
        """Return every key found in either ScenarioList, in first-seen order.

        Left keys come first, followed by keys seen only on the right. An
        insertion-ordered dict is used instead of a set so the field order of
        the joined scenarios does not vary between interpreter runs (string
        hashing, and therefore set iteration order, is randomized per process).
        """
        ordered = {}
        for scenarios in (self.left, self.right):
            for scenario in scenarios:
                for key in scenario.keys():
                    ordered.setdefault(key, None)
        return list(ordered)

    def _create_joined_scenarios(
        self, by_keys: list[str], other_dict: dict, all_keys: list
    ) -> list[Scenario]:
        """Create the joined scenarios (one per left scenario per right match)."""
        new_scenarios = []

        for scenario in self.left:
            key_tuple = self._get_key_tuple(scenario, by_keys)
            # [None] stands for "no match": the left scenario is still emitted once.
            matches = other_dict.get(key_tuple) or [None]

            for matching_scenario in matches:
                # Start from all-None so every row has the same fields (like SQL NULLs).
                new_scenario = {key: None for key in all_keys}
                new_scenario.update(scenario)
                if matching_scenario is not None:
                    self._handle_matching_scenario(
                        new_scenario, scenario, matching_scenario, by_keys
                    )
                new_scenarios.append(Scenario(new_scenario))

        return new_scenarios

    def _handle_matching_scenario(
        self,
        new_scenario: dict,
        left_scenario: Scenario,
        right_scenario: Scenario,
        by_keys: list[str],
    ) -> None:
        """Merge a matching right scenario into ``new_scenario`` in place.

        Left values win for overlapping non-join keys; each such conflict with
        a differing value is reported through ``warnings.warn``.
        """
        import warnings

        overlapping_keys = set(left_scenario.keys()) & set(right_scenario.keys())

        for key in overlapping_keys:
            if key not in by_keys and left_scenario[key] != right_scenario[key]:
                join_conditions = [f"{k}='{left_scenario[k]}'" for k in by_keys]
                warnings.warn(
                    f"Warning: Conflicting values for key '{key}' where "
                    f"{' AND '.join(join_conditions)}. "
                    f"Keeping left value: {left_scenario[key]} "
                    f"(discarding: {right_scenario[key]})"
                )

        # Only update with non-overlapping keys from the matching scenario.
        new_keys = set(right_scenario.keys()) - set(left_scenario.keys())
        new_scenario.update({k: right_scenario[k] for k in new_keys})
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -20,7 +20,7 @@ from edsl.scenarios.Scenario import Scenario
|
|
20
20
|
from edsl.scenarios.ScenarioListPdfMixin import ScenarioListPdfMixin
|
21
21
|
from edsl.scenarios.ScenarioListExportMixin import ScenarioListExportMixin
|
22
22
|
|
23
|
-
from edsl.
|
23
|
+
from edsl.utilities.naming_utilities import sanitize_string
|
24
24
|
from edsl.utilities.utilities import is_valid_variable_name
|
25
25
|
|
26
26
|
|
@@ -31,6 +31,10 @@ class ScenarioListMixin(ScenarioListPdfMixin, ScenarioListExportMixin):
|
|
31
31
|
class ScenarioList(Base, UserList, ScenarioListMixin):
|
32
32
|
"""Class for creating a list of scenarios to be used in a survey."""
|
33
33
|
|
34
|
+
__documentation__ = (
|
35
|
+
"https://docs.expectedparrot.com/en/latest/scenarios.html#scenariolist"
|
36
|
+
)
|
37
|
+
|
34
38
|
def __init__(self, data: Optional[list] = None, codebook: Optional[dict] = None):
|
35
39
|
"""Initialize the ScenarioList class."""
|
36
40
|
if data is not None:
|
@@ -239,7 +243,10 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
239
243
|
"""
|
240
244
|
from edsl.utilities.utilities import dict_hash
|
241
245
|
|
242
|
-
return dict_hash(self.
|
246
|
+
return dict_hash(self.to_dict(sort=True, add_edsl_version=False))
|
247
|
+
|
248
|
+
def __eq__(self, other: Any) -> bool:
    """Return True when ``other`` is a ScenarioList whose hash equals this one's.

    Non-ScenarioList operands return ``NotImplemented`` so Python falls back to
    its default comparison. Previously, calling ``hash(other)`` on an
    unhashable operand (e.g. a plain list) raised TypeError.
    """
    if not isinstance(other, ScenarioList):
        return NotImplemented
    return hash(self) == hash(other)
|
243
250
|
|
244
251
|
def __repr__(self):
    """Return ``ScenarioList(<list of scenarios>)`` built from ``self.data``."""
    template = "ScenarioList({})"
    return template.format(self.data)
|
@@ -282,41 +289,49 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
282
289
|
random.shuffle(self.data)
|
283
290
|
return self
|
284
291
|
|
285
|
-
def _repr_html_(self)
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
292
|
+
def _repr_html_(self):
    """Return an HTML representation of the ScenarioList (IPython rich display hook).

    The output is the HTML summary followed by a link to the documentation
    page given by ``__documentation__``.
    """
    # Quote the attribute value so the link stays well-formed for any URL.
    footer = f'<a href="{self.__documentation__}">(docs)</a>'
    return str(self.summary(format="html")) + footer
|
299
|
+
|
300
|
+
# def _repr_html_(self) -> str:
|
301
|
+
# from edsl.utilities.utilities import data_to_html
|
302
|
+
|
303
|
+
# data = self.to_dict()
|
304
|
+
# _ = data.pop("edsl_version")
|
305
|
+
# _ = data.pop("edsl_class_name")
|
306
|
+
# for s in data["scenarios"]:
|
307
|
+
# _ = s.pop("edsl_version")
|
308
|
+
# _ = s.pop("edsl_class_name")
|
309
|
+
# for scenario in data["scenarios"]:
|
310
|
+
# for key, value in scenario.items():
|
311
|
+
# if hasattr(value, "to_dict"):
|
312
|
+
# data[key] = value.to_dict()
|
313
|
+
# return data_to_html(data)
|
314
|
+
|
315
|
+
# def tally(self, field) -> dict:
|
316
|
+
# """Return a tally of the values in the field.
|
317
|
+
|
318
|
+
# Example:
|
319
|
+
|
320
|
+
# >>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
321
|
+
# >>> s.tally('b')
|
322
|
+
# {1: 1, 2: 1}
|
323
|
+
# """
|
324
|
+
# return dict(Counter([scenario[field] for scenario in self]))
|
325
|
+
|
326
|
+
def sample(self, n: int, seed: Optional[str] = None) -> ScenarioList:
    """Return a random sample from the ScenarioList

    Args:
        n: Number of scenarios to draw, without replacement.
        seed: Optional seed. When given (including falsy values such as 0 or
            ""), sampling uses a private ``random.Random(seed)``, so the result
            is reproducible without reseeding the module-level generator that
            other code shares. When None, the module-level generator is used.

    >>> s = ScenarioList.from_list("a", [1,2,3,4,5,6])
    >>> s.sample(3, seed = "edsl")
    ScenarioList([Scenario({'a': 2}), Scenario({'a': 1}), Scenario({'a': 3})])
    """
    # random.Random(seed) is seeded exactly like random.seed(seed), so seeded
    # results match the previous implementation.
    rng = random if seed is None else random.Random(seed)
    return ScenarioList(rng.sample(self.data, n))
|
322
337
|
|
@@ -564,6 +579,47 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
564
579
|
func = lambda x: x
|
565
580
|
return cls([Scenario({name: func(value)}) for value in values])
|
566
581
|
|
582
|
+
def table(self, *fields, tablefmt=None, pretty_labels=None) -> str:
    """Return the ScenarioList as a table.

    Args:
        *fields: Optional field names, forwarded to ``Dataset.table``.
        tablefmt: Optional ``tabulate`` format name; it must appear in
            ``tabulate.tabulate_formats``.
        pretty_labels: Forwarded unchanged to ``Dataset.table``.

    Raises:
        ValueError: If ``tablefmt`` is not a known tabulate format.
    """
    from tabulate import tabulate_formats

    if tablefmt is not None and tablefmt not in tabulate_formats:
        # A single message string: passing two positional arguments to
        # ValueError would make the rendered message a tuple.
        raise ValueError(
            f"Invalid table format: {tablefmt}. "
            f"Valid formats are: {tabulate_formats}"
        )
    return self.to_dataset().table(
        *fields, tablefmt=tablefmt, pretty_labels=pretty_labels
    )
|
595
|
+
|
596
|
+
def tree(self, node_list: Optional[List[str]] = None) -> str:
    """Render the ScenarioList as a tree through its Dataset form.

    Args:
        node_list: Optional list of names, passed unchanged to ``Dataset.tree``.
    """
    dataset = self.to_dataset()
    return dataset.tree(node_list)
|
599
|
+
|
600
|
+
def _summary(self):
    """Return a dict with the class name, scenario count, and scenario keys."""
    return {
        "EDSL Class name": "ScenarioList",
        "# Scenarios": len(self),
        "Scenario Keys": list(self.parameters),
    }
|
607
|
+
|
608
|
+
def reorder_keys(self, new_order):
    """Reorder the keys in the scenarios.

    Each scenario is rebuilt with exactly the keys listed in ``new_order``, in
    that order. Keys not listed are dropped, and a listed key that a scenario
    lacks raises on lookup.

    Example:

    >>> s = ScenarioList([Scenario({'a': 1, 'b': 2}), Scenario({'a': 3, 'b': 4})])
    >>> s.reorder_keys(['b', 'a'])
    ScenarioList([Scenario({'b': 2, 'a': 1}), Scenario({'b': 4, 'a': 3})])
    """
    return ScenarioList(
        [Scenario({k: sc[k] for k in new_order}) for sc in self]
    )
|
622
|
+
|
567
623
|
def to_dataset(self) -> "Dataset":
|
568
624
|
"""
|
569
625
|
>>> s = ScenarioList.from_list("a", [1,2,3])
|
@@ -579,16 +635,32 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
579
635
|
data = [{key: [scenario[key] for scenario in self.data]} for key in keys]
|
580
636
|
return Dataset(data)
|
581
637
|
|
582
|
-
def
|
583
|
-
self, field: str,
|
638
|
+
def unpack(
    self, field: str, new_names: Optional[List[str]] = None, keep_original=True
) -> ScenarioList:
    """Unpack a field into multiple fields.

    Args:
        field: Name of the field whose indexable value is unpacked.
        new_names: Names for the unpacked values. Defaults to
            ``<field>_0, <field>_1, ...``, sized from the first scenario's
            value. If exactly one name results, the whole value (not its
            first element) is copied under that name.
        keep_original: Whether ``field`` itself is kept in the output.

    Returns:
        A new ScenarioList. An empty ScenarioList yields an empty result.

    Example:

    >>> s = ScenarioList([Scenario({'a': 1, 'b': [2, True]}), Scenario({'a': 3, 'b': [3, False]})])
    >>> s.unpack('b')
    ScenarioList([Scenario({'a': 1, 'b': [2, True], 'b_0': 2, 'b_1': True}), Scenario({'a': 3, 'b': [3, False], 'b_0': 3, 'b_1': False})])
    >>> s.unpack('b', new_names=['c', 'd'], keep_original=False)
    ScenarioList([Scenario({'a': 1, 'c': 2, 'd': True}), Scenario({'a': 3, 'c': 3, 'd': False})])

    """
    if not new_names:
        if not self:
            # There is no first scenario to infer the number of parts from;
            # indexing self[0] here used to raise IndexError.
            return ScenarioList([])
        new_names = [f"{field}_{i}" for i in range(len(self[0][field]))]

    new_scenarios = []
    for scenario in self:
        new_scenario = scenario.copy()
        if len(new_names) == 1:
            new_scenario[new_names[0]] = scenario[field]
        else:
            for i, new_name in enumerate(new_names):
                new_scenario[new_name] = scenario[field][i]

        if not keep_original:
            del new_scenario[field]
        new_scenarios.append(new_scenario)
    return ScenarioList(new_scenarios)
|
594
666
|
|
@@ -901,33 +973,32 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
901
973
|
return cls.from_excel(temp_filename, sheet_name=sheet_name)
|
902
974
|
|
903
975
|
@classmethod
|
904
|
-
def
|
905
|
-
|
976
|
+
def from_delimited_file(
|
977
|
+
cls, source: Union[str, urllib.parse.ParseResult], delimiter: str = ","
|
978
|
+
) -> ScenarioList:
|
979
|
+
"""Create a ScenarioList from a delimited file (CSV/TSV) or URL.
|
906
980
|
|
907
981
|
Args:
|
908
|
-
source: A string representing either a local file path or a URL to a
|
982
|
+
source: A string representing either a local file path or a URL to a delimited file,
|
909
983
|
or a urllib.parse.ParseResult object for a URL.
|
984
|
+
delimiter: The delimiter used in the file. Defaults to ',' for CSV files.
|
985
|
+
Use '\t' for TSV files.
|
910
986
|
|
911
987
|
Returns:
|
912
|
-
ScenarioList: A ScenarioList object containing the data from the
|
988
|
+
ScenarioList: A ScenarioList object containing the data from the file.
|
913
989
|
|
914
990
|
Example:
|
991
|
+
# For CSV files
|
915
992
|
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
>>> scenario_list[0]['name']
|
925
|
-
'Alice'
|
926
|
-
>>> scenario_list[1]['age']
|
927
|
-
'25'
|
993
|
+
>>> with open('data.csv', 'w') as f:
|
994
|
+
... _ = f.write('name,age\\nAlice,30\\nBob,25\\n')
|
995
|
+
>>> scenario_list = ScenarioList.from_delimited_file('data.csv')
|
996
|
+
|
997
|
+
# For TSV files
|
998
|
+
>>> with open('data.tsv', 'w') as f:
|
999
|
+
... _ = f.write('name\\tage\\nAlice\t30\\nBob\t25\\n')
|
1000
|
+
>>> scenario_list = ScenarioList.from_delimited_file('data.tsv', delimiter='\\t')
|
928
1001
|
|
929
|
-
>>> url = "https://example.com/data.csv"
|
930
|
-
>>> ## scenario_list_from_url = ScenarioList.from_csv(url)
|
931
1002
|
"""
|
932
1003
|
from edsl.scenarios.Scenario import Scenario
|
933
1004
|
|
@@ -940,42 +1011,129 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
940
1011
|
|
941
1012
|
if isinstance(source, str) and is_url(source):
|
942
1013
|
with urllib.request.urlopen(source) as response:
|
943
|
-
|
944
|
-
|
1014
|
+
file_content = response.read().decode("utf-8")
|
1015
|
+
file_obj = StringIO(file_content)
|
945
1016
|
elif isinstance(source, urllib.parse.ParseResult):
|
946
1017
|
with urllib.request.urlopen(source.geturl()) as response:
|
947
|
-
|
948
|
-
|
1018
|
+
file_content = response.read().decode("utf-8")
|
1019
|
+
file_obj = StringIO(file_content)
|
949
1020
|
else:
|
950
|
-
|
1021
|
+
file_obj = open(source, "r")
|
951
1022
|
|
952
1023
|
try:
|
953
|
-
reader = csv.reader(
|
1024
|
+
reader = csv.reader(file_obj, delimiter=delimiter)
|
954
1025
|
header = next(reader)
|
955
1026
|
observations = [Scenario(dict(zip(header, row))) for row in reader]
|
956
1027
|
finally:
|
957
|
-
|
1028
|
+
file_obj.close()
|
958
1029
|
|
959
1030
|
return cls(observations)
|
960
1031
|
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
return {"scenarios": [s._to_dict() for s in data]}
|
1032
|
+
# Convenience methods for specific file types
|
1033
|
+
@classmethod
def from_csv(cls, source: Union[str, urllib.parse.ParseResult]) -> ScenarioList:
    """Create a ScenarioList from a CSV file or URL.

    Delegates to ``from_delimited_file`` with a comma delimiter.
    """
    comma = ","
    return cls.from_delimited_file(source, delimiter=comma)
|
967
1037
|
|
968
|
-
|
969
|
-
|
970
|
-
"""Return the `ScenarioList` as a dictionary.
|
1038
|
+
def left_join(self, other: ScenarioList, by: Union[str, list[str]]) -> ScenarioList:
    """Perform a left join with another ScenarioList, following SQL join semantics.

    All join logic lives in ``ScenarioJoin``. When an overlapping non-join
    field has different values, the left value is kept and the conflict is
    reported.

    Args:
        other: The ScenarioList to join with
        by: String or list of strings representing the key(s) to join on. Cannot be empty.

    Returns:
        A new ScenarioList containing the joined scenarios.

    Raises:
        ValueError: If ``by`` is empty or a join key is missing.

    >>> s1 = ScenarioList([Scenario({'name': 'Alice', 'age': 30}), Scenario({'name': 'Bob', 'age': 25})])
    >>> s2 = ScenarioList([Scenario({'name': 'Alice', 'location': 'New York'}), Scenario({'name': 'Charlie', 'location': 'Los Angeles'})])
    >>> s3 = s1.left_join(s2, 'name')
    >>> s3 == ScenarioList([Scenario({'age': 30, 'location': 'New York', 'name': 'Alice'}), Scenario({'age': 25, 'location': None, 'name': 'Bob'})])
    True
    """
    # Local import avoids a circular import: ScenarioJoin imports ScenarioList.
    from edsl.scenarios.ScenarioJoin import ScenarioJoin

    sj = ScenarioJoin(self, other)
    return sj.left_join(by)
|
1113
|
+
|
1114
|
+
@classmethod
def from_tsv(cls, source: Union[str, urllib.parse.ParseResult]) -> ScenarioList:
    """Create a ScenarioList from a TSV file or URL.

    Delegates to ``from_delimited_file`` with a tab delimiter.
    """
    tab = "\t"
    return cls.from_delimited_file(source, delimiter=tab)
|
1118
|
+
|
1119
|
+
def to_dict(self, sort=False, add_edsl_version=True) -> dict:
    """Serialize the ScenarioList to a dictionary.

    Args:
        sort: If True, scenarios are ordered by their hash, so the output does
            not depend on the list's order.
        add_edsl_version: If True, the EDSL version and class name are added
            both to each scenario and to the top-level dict.

    >>> s = ScenarioList([Scenario({'food': 'wood chips'}), Scenario({'food': 'wood-fired pizza'})])
    >>> s.to_dict()
    {'scenarios': [{'food': 'wood chips', 'edsl_version': '...', 'edsl_class_name': 'Scenario'}, {'food': 'wood-fired pizza', 'edsl_version': '...', 'edsl_class_name': 'Scenario'}], 'edsl_version': '...', 'edsl_class_name': 'ScenarioList'}

    """
    ordered = sorted(self, key=hash) if sort else self
    serialized = [
        item.to_dict(add_edsl_version=add_edsl_version) for item in ordered
    ]
    result = {"scenarios": serialized}
    if not add_edsl_version:
        return result

    from edsl import __version__

    result["edsl_version"] = __version__
    result["edsl_class_name"] = self.__class__.__name__
    return result
|
979
1137
|
|
980
1138
|
@classmethod
|
981
1139
|
def gen(cls, scenario_dicts_list: List[dict]) -> ScenarioList:
|
@@ -1061,7 +1219,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1061
1219
|
elif isinstance(key, int):
|
1062
1220
|
return super().__getitem__(key)
|
1063
1221
|
else:
|
1064
|
-
return self.
|
1222
|
+
return self.to_dict(add_edsl_version=False)[key]
|
1065
1223
|
|
1066
1224
|
def to_agent_list(self):
    """Convert the ScenarioList to an AgentList.

    Each scenario's key/value pairs become the traits of one Agent, in order.
    A ``name`` field is handled specially:
    - it is removed from the traits and passed as the Agent's ``name``;
    - its value is also stored under an ``agent_name`` trait, with ``_``
      appended until the trait name is not already used;
    - a warning is emitted.

    Returns:
        An AgentList with one Agent per scenario.
    """
    from edsl.agents.AgentList import AgentList
    from edsl.agents.Agent import Agent
    import warnings

    agents = []
    for scenario in self:
        new_scenario = scenario.copy().data
        if "name" in new_scenario:
            name = new_scenario.pop("name")
            # Pick a trait name that does not clobber an existing trait. The
            # loop must continue while the candidate is already taken; the
            # previous `not in` test never terminated for a typical scenario
            # and overwrote an existing `agent_name` trait otherwise.
            proposed_agent_name = "agent_name"
            while proposed_agent_name in new_scenario:
                proposed_agent_name += "_"
            warnings.warn(
                f"The 'name' field is reserved for the agent's name---putting this value in {proposed_agent_name}"
            )
            new_scenario[proposed_agent_name] = name
            agents.append(Agent(traits=new_scenario, name=name))
        else:
            agents.append(Agent(traits=new_scenario))

    return AgentList(agents)
|
1079
1254
|
|
1080
1255
|
def chunk(
|
1081
1256
|
self,
|
edsl/surveys/MemoryPlan.py
CHANGED
@@ -143,7 +143,7 @@ class MemoryPlan(UserDict):
|
|
143
143
|
for question in prior_questions:
|
144
144
|
self.add_single_memory(focal_question, question)
|
145
145
|
|
146
|
-
def to_dict(self) -> dict:
|
146
|
+
def to_dict(self, add_edsl_version=True) -> dict:
|
147
147
|
"""Serialize the memory plan to a dictionary.
|
148
148
|
|
149
149
|
>>> mp = MemoryPlan.example()
|
edsl/surveys/Rule.py
CHANGED
@@ -148,10 +148,7 @@ class Rule:
|
|
148
148
|
def _checks(self):
|
149
149
|
pass
|
150
150
|
|
151
|
-
|
152
|
-
|
153
|
-
# @add_edsl_version
|
154
|
-
def to_dict(self):
|
151
|
+
def to_dict(self, add_edsl_version=True):
|
155
152
|
"""Convert the rule to a dictionary for serialization.
|
156
153
|
|
157
154
|
>>> r = Rule.example()
|
@@ -166,7 +163,6 @@ class Rule:
|
|
166
163
|
"question_name_to_index": self.question_name_to_index,
|
167
164
|
"before_rule": self.before_rule,
|
168
165
|
}
|
169
|
-
# return self._to_dict()
|
170
166
|
|
171
167
|
@classmethod
|
172
168
|
@remove_edsl_version
|
edsl/surveys/RuleCollection.py
CHANGED
@@ -46,7 +46,7 @@ class RuleCollection(UserList):
|
|
46
46
|
"""
|
47
47
|
return f"RuleCollection(rules={self.data}, num_questions={self.num_questions})"
|
48
48
|
|
49
|
-
def to_dict(self):
|
49
|
+
def to_dict(self, add_edsl_version=True):
|
50
50
|
"""Create a dictionary representation of the RuleCollection object."""
|
51
51
|
return {
|
52
52
|
"rules": [rule.to_dict() for rule in self],
|