edsl 0.1.37.dev4__py3-none-any.whl → 0.1.37.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +86 -35
- edsl/agents/AgentList.py +5 -0
- edsl/agents/InvigilatorBase.py +2 -23
- edsl/agents/PromptConstructor.py +147 -106
- edsl/agents/descriptors.py +17 -4
- edsl/config.py +1 -1
- edsl/conjure/AgentConstructionMixin.py +11 -3
- edsl/conversation/Conversation.py +66 -14
- edsl/conversation/chips.py +95 -0
- edsl/coop/coop.py +134 -3
- edsl/data/Cache.py +1 -1
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +7 -3
- edsl/exceptions/agents.py +17 -19
- edsl/exceptions/results.py +11 -8
- edsl/exceptions/scenarios.py +22 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/InferenceServicesCollection.py +32 -9
- edsl/jobs/Jobs.py +265 -53
- edsl/jobs/interviews/InterviewExceptionEntry.py +5 -1
- edsl/jobs/tasks/TaskHistory.py +1 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +47 -59
- edsl/language_models/__init__.py +1 -0
- edsl/prompts/Prompt.py +8 -4
- edsl/questions/QuestionBase.py +53 -13
- edsl/questions/QuestionBasePromptsMixin.py +1 -33
- edsl/questions/QuestionFunctional.py +2 -2
- edsl/questions/descriptors.py +23 -28
- edsl/results/DatasetExportMixin.py +25 -1
- edsl/results/Result.py +16 -1
- edsl/results/Results.py +31 -120
- edsl/results/ResultsDBMixin.py +1 -1
- edsl/results/Selector.py +18 -1
- edsl/scenarios/Scenario.py +48 -12
- edsl/scenarios/ScenarioHtmlMixin.py +7 -2
- edsl/scenarios/ScenarioList.py +12 -1
- edsl/surveys/Rule.py +10 -4
- edsl/surveys/Survey.py +100 -77
- edsl/utilities/utilities.py +18 -0
- {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/METADATA +1 -1
- {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/RECORD +45 -41
- {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/LICENSE +0 -0
- {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/WHEEL +0 -0
edsl/results/Results.py
CHANGED
@@ -7,11 +7,17 @@ from __future__ import annotations
|
|
7
7
|
import json
|
8
8
|
import random
|
9
9
|
from collections import UserList, defaultdict
|
10
|
-
from typing import Optional, Callable, Any, Type, Union, List
|
10
|
+
from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from edsl import Survey, Cache, AgentList, ModelList, ScenarioList
|
14
|
+
from edsl.results.Result import Result
|
15
|
+
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
11
16
|
|
12
17
|
from simpleeval import EvalWithCompoundTypes
|
13
18
|
|
14
19
|
from edsl.exceptions.results import (
|
20
|
+
ResultsError,
|
15
21
|
ResultsBadMutationstringError,
|
16
22
|
ResultsColumnNotFoundError,
|
17
23
|
ResultsInvalidNameError,
|
@@ -40,7 +46,7 @@ class Mixins(
|
|
40
46
|
ResultsGGMixin,
|
41
47
|
ResultsToolsMixin,
|
42
48
|
):
|
43
|
-
def print_long(self, max_rows=None) -> None:
|
49
|
+
def print_long(self, max_rows: int = None) -> None:
|
44
50
|
"""Print the results in long format.
|
45
51
|
|
46
52
|
>>> from edsl.results import Results
|
@@ -84,13 +90,13 @@ class Results(UserList, Mixins, Base):
|
|
84
90
|
|
85
91
|
def __init__(
|
86
92
|
self,
|
87
|
-
survey: Optional[
|
88
|
-
data: Optional[list[
|
93
|
+
survey: Optional[Survey] = None,
|
94
|
+
data: Optional[list[Result]] = None,
|
89
95
|
created_columns: Optional[list[str]] = None,
|
90
|
-
cache: Optional[
|
96
|
+
cache: Optional[Cache] = None,
|
91
97
|
job_uuid: Optional[str] = None,
|
92
98
|
total_results: Optional[int] = None,
|
93
|
-
task_history: Optional[
|
99
|
+
task_history: Optional[TaskHistory] = None,
|
94
100
|
):
|
95
101
|
"""Instantiate a `Results` object with a survey and a list of `Result` objects.
|
96
102
|
|
@@ -235,11 +241,11 @@ class Results(UserList, Mixins, Base):
|
|
235
241
|
>>> r3 = r + r2
|
236
242
|
"""
|
237
243
|
if self.survey != other.survey:
|
238
|
-
raise
|
239
|
-
"The surveys are not the same so
|
244
|
+
raise ResultsError(
|
245
|
+
"The surveys are not the same so the the results cannot be added together."
|
240
246
|
)
|
241
247
|
if self.created_columns != other.created_columns:
|
242
|
-
raise
|
248
|
+
raise ResultsError(
|
243
249
|
"The created columns are not the same so they cannot be added together."
|
244
250
|
)
|
245
251
|
|
@@ -258,17 +264,6 @@ class Results(UserList, Mixins, Base):
|
|
258
264
|
from IPython.display import HTML
|
259
265
|
|
260
266
|
json_str = json.dumps(self.to_dict()["data"], indent=4)
|
261
|
-
# from pygments import highlight
|
262
|
-
# from pygments.lexers import JsonLexer
|
263
|
-
# 3from pygments.formatters import HtmlFormatter
|
264
|
-
|
265
|
-
# formatted_json = highlight(
|
266
|
-
# json_str,
|
267
|
-
# JsonLexer(),
|
268
|
-
# HtmlFormatter(style="default", full=True, noclasses=True),
|
269
|
-
# )
|
270
|
-
# return HTML(formatted_json).data
|
271
|
-
# print(json_str)
|
272
267
|
return f"<pre>{json_str}</pre>"
|
273
268
|
|
274
269
|
def _to_dict(self, sort=False):
|
@@ -328,7 +323,7 @@ class Results(UserList, Mixins, Base):
|
|
328
323
|
def hashes(self) -> set:
|
329
324
|
return set(hash(result) for result in self.data)
|
330
325
|
|
331
|
-
def sample(self, n: int) ->
|
326
|
+
def sample(self, n: int) -> Results:
|
332
327
|
"""Return a random sample of the results.
|
333
328
|
|
334
329
|
:param n: The number of samples to return.
|
@@ -346,7 +341,7 @@ class Results(UserList, Mixins, Base):
|
|
346
341
|
indices = list(range(len(values)))
|
347
342
|
sampled_indices = random.sample(indices, n)
|
348
343
|
if n > len(indices):
|
349
|
-
raise
|
344
|
+
raise ResultsError(
|
350
345
|
f"Cannot sample {n} items from a list of length {len(indices)}."
|
351
346
|
)
|
352
347
|
entry[key] = [values[i] for i in sampled_indices]
|
@@ -399,11 +394,12 @@ class Results(UserList, Mixins, Base):
|
|
399
394
|
- Uses the key_to_data_type property of the Result class.
|
400
395
|
- Includes any columns that the user has created with `mutate`
|
401
396
|
"""
|
402
|
-
d = {}
|
397
|
+
d: dict = {}
|
403
398
|
for result in self.data:
|
404
399
|
d.update(result.key_to_data_type)
|
405
400
|
for column in self.created_columns:
|
406
401
|
d[column] = "answer"
|
402
|
+
|
407
403
|
return d
|
408
404
|
|
409
405
|
@property
|
@@ -453,7 +449,7 @@ class Results(UserList, Mixins, Base):
|
|
453
449
|
from edsl.utilities.utilities import shorten_string
|
454
450
|
|
455
451
|
if not self.survey:
|
456
|
-
raise
|
452
|
+
raise ResultsError("Survey is not defined so no answer keys are available.")
|
457
453
|
|
458
454
|
answer_keys = self._data_type_to_keys["answer"]
|
459
455
|
answer_keys = {k for k in answer_keys if "_comment" not in k}
|
@@ -466,7 +462,7 @@ class Results(UserList, Mixins, Base):
|
|
466
462
|
return sorted_dict
|
467
463
|
|
468
464
|
@property
|
469
|
-
def agents(self) ->
|
465
|
+
def agents(self) -> AgentList:
|
470
466
|
"""Return a list of all of the agents in the Results.
|
471
467
|
|
472
468
|
Example:
|
@@ -480,7 +476,7 @@ class Results(UserList, Mixins, Base):
|
|
480
476
|
return AgentList([r.agent for r in self.data])
|
481
477
|
|
482
478
|
@property
|
483
|
-
def models(self) ->
|
479
|
+
def models(self) -> ModelList:
|
484
480
|
"""Return a list of all of the models in the Results.
|
485
481
|
|
486
482
|
Example:
|
@@ -489,10 +485,12 @@ class Results(UserList, Mixins, Base):
|
|
489
485
|
>>> r.models[0]
|
490
486
|
Model(model_name = ...)
|
491
487
|
"""
|
492
|
-
|
488
|
+
from edsl import ModelList
|
489
|
+
|
490
|
+
return ModelList([r.model for r in self.data])
|
493
491
|
|
494
492
|
@property
|
495
|
-
def scenarios(self) ->
|
493
|
+
def scenarios(self) -> ScenarioList:
|
496
494
|
"""Return a list of all of the scenarios in the Results.
|
497
495
|
|
498
496
|
Example:
|
@@ -569,7 +567,7 @@ class Results(UserList, Mixins, Base):
|
|
569
567
|
)
|
570
568
|
return sorted(list(all_keys))
|
571
569
|
|
572
|
-
def first(self) ->
|
570
|
+
def first(self) -> Result:
|
573
571
|
"""Return the first observation in the results.
|
574
572
|
|
575
573
|
Example:
|
@@ -819,7 +817,7 @@ class Results(UserList, Mixins, Base):
|
|
819
817
|
|
820
818
|
return Results(survey=self.survey, data=new_data, created_columns=None)
|
821
819
|
|
822
|
-
def select(self, *columns: Union[str, list[str]]) ->
|
820
|
+
def select(self, *columns: Union[str, list[str]]) -> Results:
|
823
821
|
"""
|
824
822
|
Select data from the results and format it.
|
825
823
|
|
@@ -832,93 +830,12 @@ class Results(UserList, Mixins, Base):
|
|
832
830
|
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
|
833
831
|
|
834
832
|
>>> results.select('how_feeling', 'model', 'how_feeling')
|
835
|
-
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
|
833
|
+
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
|
836
834
|
|
837
835
|
>>> from edsl import Results; r = Results.example(); r.select('answer.how_feeling_y')
|
838
836
|
Dataset([{'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
|
839
837
|
"""
|
840
838
|
|
841
|
-
# if len(self) == 0:
|
842
|
-
# raise Exception("No data to select from---the Results object is empty.")
|
843
|
-
|
844
|
-
if not columns or columns == ("*",) or columns == (None,):
|
845
|
-
# is the users passes nothing, then we'll return all the columns
|
846
|
-
columns = ("*.*",)
|
847
|
-
|
848
|
-
if isinstance(columns[0], list):
|
849
|
-
columns = tuple(columns[0])
|
850
|
-
|
851
|
-
def get_data_types_to_return(parsed_data_type):
|
852
|
-
if parsed_data_type == "*": # they want all of the columns
|
853
|
-
return self.known_data_types
|
854
|
-
else:
|
855
|
-
if parsed_data_type not in self.known_data_types:
|
856
|
-
raise Exception(
|
857
|
-
f"Data type {parsed_data_type} not found in data. Did you mean one of {self.known_data_types}"
|
858
|
-
)
|
859
|
-
return [parsed_data_type]
|
860
|
-
|
861
|
-
# we're doing to populate this with the data we want to fetch
|
862
|
-
to_fetch = defaultdict(list)
|
863
|
-
|
864
|
-
new_data = []
|
865
|
-
items_in_order = []
|
866
|
-
# iterate through the passed columns
|
867
|
-
for column in columns:
|
868
|
-
# a user could pass 'result.how_feeling' or just 'how_feeling'
|
869
|
-
matches = self._matching_columns(column)
|
870
|
-
if len(matches) > 1:
|
871
|
-
raise Exception(
|
872
|
-
f"Column '{column}' is ambiguous. Did you mean one of {matches}?"
|
873
|
-
)
|
874
|
-
if len(matches) == 0 and ".*" not in column:
|
875
|
-
raise Exception(f"Column '{column}' not found in data.")
|
876
|
-
if len(matches) == 1:
|
877
|
-
column = matches[0]
|
878
|
-
|
879
|
-
parsed_data_type, parsed_key = self._parse_column(column)
|
880
|
-
data_types = get_data_types_to_return(parsed_data_type)
|
881
|
-
found_once = False # we need to track this to make sure we found the key at least once
|
882
|
-
|
883
|
-
for data_type in data_types:
|
884
|
-
# the keys for that data_type e.g.,# if data_type is 'answer', then the keys are 'how_feeling', 'how_feeling_comment', etc.
|
885
|
-
relevant_keys = self._data_type_to_keys[data_type]
|
886
|
-
|
887
|
-
for key in relevant_keys:
|
888
|
-
if key == parsed_key or parsed_key == "*":
|
889
|
-
found_once = True
|
890
|
-
to_fetch[data_type].append(key)
|
891
|
-
items_in_order.append(data_type + "." + key)
|
892
|
-
|
893
|
-
if not found_once:
|
894
|
-
raise Exception(f"Key {parsed_key} not found in data.")
|
895
|
-
|
896
|
-
for data_type in to_fetch:
|
897
|
-
for key in to_fetch[data_type]:
|
898
|
-
entries = self._fetch_list(data_type, key)
|
899
|
-
new_data.append({data_type + "." + key: entries})
|
900
|
-
|
901
|
-
def sort_by_key_order(dictionary):
|
902
|
-
# Extract the single key from the dictionary
|
903
|
-
single_key = next(iter(dictionary))
|
904
|
-
# Return the index of this key in the list_of_keys
|
905
|
-
return items_in_order.index(single_key)
|
906
|
-
|
907
|
-
# sorted(new_data, key=sort_by_key_order)
|
908
|
-
from edsl.results.Dataset import Dataset
|
909
|
-
|
910
|
-
sorted_new_data = []
|
911
|
-
|
912
|
-
# WORKS but slow
|
913
|
-
for key in items_in_order:
|
914
|
-
for d in new_data:
|
915
|
-
if key in d:
|
916
|
-
sorted_new_data.append(d)
|
917
|
-
break
|
918
|
-
|
919
|
-
return Dataset(sorted_new_data)
|
920
|
-
|
921
|
-
def select(self, *columns: Union[str, list[str]]) -> "Results":
|
922
839
|
from edsl.results.Selector import Selector
|
923
840
|
|
924
841
|
if len(self) == 0:
|
@@ -1028,6 +945,7 @@ class Results(UserList, Mixins, Base):
|
|
1028
945
|
Traceback (most recent call last):
|
1029
946
|
...
|
1030
947
|
edsl.exceptions.results.ResultsFilterError: You must use '==' instead of '=' in the filter expression.
|
948
|
+
...
|
1031
949
|
|
1032
950
|
>>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling').print()
|
1033
951
|
┏━━━━━━━━━━━━━━┓
|
@@ -1105,6 +1023,7 @@ class Results(UserList, Mixins, Base):
|
|
1105
1023
|
stop_on_exception=True,
|
1106
1024
|
skip_retry=True,
|
1107
1025
|
raise_validation_errors=True,
|
1026
|
+
disable_remote_cache=True,
|
1108
1027
|
disable_remote_inference=True,
|
1109
1028
|
)
|
1110
1029
|
return results
|
@@ -1112,14 +1031,6 @@ class Results(UserList, Mixins, Base):
|
|
1112
1031
|
def rich_print(self):
|
1113
1032
|
"""Display an object as a table."""
|
1114
1033
|
pass
|
1115
|
-
# with io.StringIO() as buf:
|
1116
|
-
# console = Console(file=buf, record=True)
|
1117
|
-
|
1118
|
-
# for index, result in enumerate(self):
|
1119
|
-
# console.print(f"Result {index}")
|
1120
|
-
# console.print(result.rich_print())
|
1121
|
-
|
1122
|
-
# return console.export_text()
|
1123
1034
|
|
1124
1035
|
def __str__(self):
|
1125
1036
|
data = self.to_dict()["data"]
|
edsl/results/ResultsDBMixin.py
CHANGED
@@ -93,7 +93,7 @@ class ResultsDBMixin:
|
|
93
93
|
from sqlalchemy import create_engine
|
94
94
|
|
95
95
|
engine = create_engine("sqlite:///:memory:")
|
96
|
-
df = self.to_pandas(remove_prefix=remove_prefix)
|
96
|
+
df = self.to_pandas(remove_prefix=remove_prefix, lists_as_strings=True)
|
97
97
|
df.to_sql("self", engine, index=False, if_exists="replace")
|
98
98
|
return engine.connect()
|
99
99
|
else:
|
edsl/results/Selector.py
CHANGED
@@ -12,6 +12,7 @@ class Selector:
|
|
12
12
|
fetch_list_func,
|
13
13
|
columns: List[str],
|
14
14
|
):
|
15
|
+
"""Selects columns from a Results object"""
|
15
16
|
self.known_data_types = known_data_types
|
16
17
|
self._data_type_to_keys = data_type_to_keys
|
17
18
|
self._key_to_data_type = key_to_data_type
|
@@ -21,10 +22,19 @@ class Selector:
|
|
21
22
|
def select(self, *columns: Union[str, List[str]]) -> "Dataset":
|
22
23
|
columns = self._normalize_columns(columns)
|
23
24
|
to_fetch = self._get_columns_to_fetch(columns)
|
25
|
+
# breakpoint()
|
24
26
|
new_data = self._fetch_data(to_fetch)
|
25
27
|
return Dataset(new_data)
|
26
28
|
|
27
29
|
def _normalize_columns(self, columns: Union[str, List[str]]) -> tuple:
|
30
|
+
"""Normalize the columns to a tuple of strings
|
31
|
+
|
32
|
+
>>> s = Selector([], {}, {}, lambda x, y: x, [])
|
33
|
+
>>> s._normalize_columns([["a", "b"], ])
|
34
|
+
('a', 'b')
|
35
|
+
>>> s._normalize_columns(None)
|
36
|
+
('*.*',)
|
37
|
+
"""
|
28
38
|
if not columns or columns == ("*",) or columns == (None,):
|
29
39
|
return ("*.*",)
|
30
40
|
if isinstance(columns[0], list):
|
@@ -37,6 +47,7 @@ class Selector:
|
|
37
47
|
|
38
48
|
for column in columns:
|
39
49
|
matches = self._find_matching_columns(column)
|
50
|
+
# breakpoint()
|
40
51
|
self._validate_matches(column, matches)
|
41
52
|
|
42
53
|
if len(matches) == 1:
|
@@ -52,7 +63,7 @@ class Selector:
|
|
52
63
|
search_in_list = self.columns
|
53
64
|
else:
|
54
65
|
search_in_list = [s.split(".")[1] for s in self.columns]
|
55
|
-
|
66
|
+
# breakpoint()
|
56
67
|
matches = [s for s in search_in_list if s.startswith(partial_name)]
|
57
68
|
return [partial_name] if partial_name in matches else matches
|
58
69
|
|
@@ -116,3 +127,9 @@ class Selector:
|
|
116
127
|
new_data.append({f"{data_type}.{key}": entries})
|
117
128
|
|
118
129
|
return [d for key in self.items_in_order for d in new_data if key in d]
|
130
|
+
|
131
|
+
|
132
|
+
if __name__ == "__main__":
|
133
|
+
import doctest
|
134
|
+
|
135
|
+
doctest.testmod()
|
edsl/scenarios/Scenario.py
CHANGED
@@ -11,18 +11,26 @@ from uuid import uuid4
|
|
11
11
|
from edsl.Base import Base
|
12
12
|
from edsl.scenarios.ScenarioHtmlMixin import ScenarioHtmlMixin
|
13
13
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
14
|
+
from edsl.exceptions.scenarios import ScenarioError
|
14
15
|
|
15
16
|
|
16
17
|
class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
17
18
|
"""A Scenario is a dictionary of keys/values.
|
18
19
|
|
19
|
-
They can be used parameterize
|
20
|
+
They can be used parameterize EDSL questions."""
|
21
|
+
|
22
|
+
__doc__ = "https://docs.expectedparrot.com/en/latest/scenarios.html"
|
20
23
|
|
21
24
|
def __init__(self, data: Union[dict, None] = None, name: str = None):
|
22
25
|
"""Initialize a new Scenario.
|
23
26
|
|
24
|
-
:param data: A dictionary of keys/values for parameterizing questions.
|
25
|
-
"""
|
27
|
+
# :param data: A dictionary of keys/values for parameterizing questions.
|
28
|
+
#"""
|
29
|
+
if not isinstance(data, dict) and data is not None:
|
30
|
+
raise EDSLScenarioError(
|
31
|
+
"You must pass in a dictionary to initialize a Scenario."
|
32
|
+
)
|
33
|
+
|
26
34
|
self.data = data if data is not None else {}
|
27
35
|
self.name = name
|
28
36
|
|
@@ -41,13 +49,6 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
41
49
|
|
42
50
|
return ScenarioList([copy.deepcopy(self) for _ in range(n)])
|
43
51
|
|
44
|
-
# @property
|
45
|
-
# def has_image(self) -> bool:
|
46
|
-
# """Return whether the scenario has an image."""
|
47
|
-
# if not hasattr(self, "_has_image"):
|
48
|
-
# self._has_image = False
|
49
|
-
# return self._has_image
|
50
|
-
|
51
52
|
@property
|
52
53
|
def has_jinja_braces(self) -> bool:
|
53
54
|
"""Return whether the scenario has jinja braces. This matters for rendering.
|
@@ -106,7 +107,9 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
106
107
|
s = Scenario(data1 | data2)
|
107
108
|
return s
|
108
109
|
|
109
|
-
def rename(
|
110
|
+
def rename(
|
111
|
+
self, old_name_or_replacement_dict: dict, new_name: Optional[str] = None
|
112
|
+
) -> "Scenario":
|
110
113
|
"""Rename the keys of a scenario.
|
111
114
|
|
112
115
|
:param replacement_dict: A dictionary of old keys to new keys.
|
@@ -116,7 +119,16 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
116
119
|
>>> s = Scenario({"food": "wood chips"})
|
117
120
|
>>> s.rename({"food": "food_preference"})
|
118
121
|
Scenario({'food_preference': 'wood chips'})
|
122
|
+
|
123
|
+
>>> s = Scenario({"food": "wood chips"})
|
124
|
+
>>> s.rename("food", "snack")
|
125
|
+
Scenario({'snack': 'wood chips'})
|
119
126
|
"""
|
127
|
+
if isinstance(old_name_or_replacement_dict, str) and new_name is not None:
|
128
|
+
replacement_dict = {old_name_or_replacement_dict: new_name}
|
129
|
+
else:
|
130
|
+
replacement_dict = old_name_or_replacement_dict
|
131
|
+
|
120
132
|
new_scenario = Scenario()
|
121
133
|
for key, value in self.items():
|
122
134
|
if key in replacement_dict:
|
@@ -216,6 +228,20 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
216
228
|
new_scenario[key] = self[key]
|
217
229
|
return new_scenario
|
218
230
|
|
231
|
+
def keep(self, list_of_keys: List[str]) -> "Scenario":
|
232
|
+
"""Keep a subset of keys from a scenario.
|
233
|
+
|
234
|
+
:param list_of_keys: The keys to keep.
|
235
|
+
|
236
|
+
Example:
|
237
|
+
|
238
|
+
>>> s = Scenario({"food": "wood chips", "drink": "water"})
|
239
|
+
>>> s.keep(["food"])
|
240
|
+
Scenario({'food': 'wood chips'})
|
241
|
+
"""
|
242
|
+
|
243
|
+
return self.select(list_of_keys)
|
244
|
+
|
219
245
|
@classmethod
|
220
246
|
def from_url(cls, url: str, field_name: Optional[str] = "text") -> "Scenario":
|
221
247
|
"""Creates a scenario from a URL.
|
@@ -231,7 +257,17 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
231
257
|
|
232
258
|
@classmethod
|
233
259
|
def from_file(cls, file_path: str, field_name: str) -> "Scenario":
|
234
|
-
"""Creates a scenario from a file.
|
260
|
+
"""Creates a scenario from a file.
|
261
|
+
|
262
|
+
>>> import tempfile
|
263
|
+
>>> with tempfile.NamedTemporaryFile(suffix=".txt", mode="w") as f:
|
264
|
+
... _ = f.write("This is a test.")
|
265
|
+
... _ = f.flush()
|
266
|
+
... s = Scenario.from_file(f.name, "file")
|
267
|
+
>>> s
|
268
|
+
Scenario({'file': FileStore(path='...')})
|
269
|
+
|
270
|
+
"""
|
235
271
|
from edsl.scenarios.FileStore import FileStore
|
236
272
|
|
237
273
|
fs = FileStore(file_path)
|
@@ -1,19 +1,24 @@
|
|
1
1
|
import requests
|
2
|
+
from typing import Optional
|
2
3
|
from requests.adapters import HTTPAdapter
|
3
4
|
from requests.packages.urllib3.util.retry import Retry
|
4
5
|
|
5
6
|
|
6
7
|
class ScenarioHtmlMixin:
|
7
8
|
@classmethod
|
8
|
-
def from_html(cls, url: str) -> "Scenario":
|
9
|
+
def from_html(cls, url: str, field_name: Optional[str] = None) -> "Scenario":
|
9
10
|
"""Create a scenario from HTML content.
|
10
11
|
|
11
12
|
:param html: The HTML content.
|
13
|
+
:param field_name: The name of the field containing the HTML content.
|
14
|
+
|
12
15
|
|
13
16
|
"""
|
14
17
|
html = cls.fetch_html(url)
|
15
18
|
text = cls.extract_text(html)
|
16
|
-
|
19
|
+
if not field_name:
|
20
|
+
field_name = "text"
|
21
|
+
return cls({"url": url, "html": html, field_name: text})
|
17
22
|
|
18
23
|
def fetch_html(url):
|
19
24
|
# Define the user-agent to mimic a browser
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -538,6 +538,17 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
538
538
|
"""
|
539
539
|
return ScenarioList([scenario.drop(fields) for scenario in self.data])
|
540
540
|
|
541
|
+
def keep(self, *fields) -> ScenarioList:
|
542
|
+
"""Keep only the specified fields in the scenarios.
|
543
|
+
|
544
|
+
Example:
|
545
|
+
|
546
|
+
>>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
547
|
+
>>> s.keep('a')
|
548
|
+
ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
|
549
|
+
"""
|
550
|
+
return ScenarioList([scenario.keep(fields) for scenario in self.data])
|
551
|
+
|
541
552
|
@classmethod
|
542
553
|
def from_list(
|
543
554
|
cls, name: str, values: list, func: Optional[Callable] = None
|
@@ -1050,7 +1061,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1050
1061
|
elif isinstance(key, int):
|
1051
1062
|
return super().__getitem__(key)
|
1052
1063
|
else:
|
1053
|
-
return self.
|
1064
|
+
return self._to_dict()[key]
|
1054
1065
|
|
1055
1066
|
def to_agent_list(self):
|
1056
1067
|
"""Convert the ScenarioList to an AgentList.
|
edsl/surveys/Rule.py
CHANGED
@@ -25,6 +25,8 @@ from jinja2 import Template
|
|
25
25
|
from rich import print
|
26
26
|
from simpleeval import EvalWithCompoundTypes
|
27
27
|
|
28
|
+
from edsl.exceptions.surveys import SurveyError
|
29
|
+
|
28
30
|
from edsl.exceptions import (
|
29
31
|
SurveyRuleCannotEvaluateError,
|
30
32
|
SurveyRuleCollectionHasNoRulesAtNodeError,
|
@@ -47,11 +49,11 @@ class QuestionIndex:
|
|
47
49
|
|
48
50
|
def __set__(self, obj, value):
|
49
51
|
if not isinstance(value, (int, EndOfSurvey.__class__)):
|
50
|
-
raise
|
52
|
+
raise SurveyError(f"{self.name} must be an integer or EndOfSurvey")
|
51
53
|
if self.name == "_next_q" and isinstance(value, int):
|
52
54
|
current_q = getattr(obj, "_current_q")
|
53
55
|
if value <= current_q:
|
54
|
-
raise
|
56
|
+
raise SurveyError("next_q must be greater than current_q")
|
55
57
|
setattr(obj, self.name, value)
|
56
58
|
|
57
59
|
|
@@ -100,13 +102,17 @@ class Rule:
|
|
100
102
|
raise SurveyRuleSendsYouBackwardsError
|
101
103
|
|
102
104
|
if not self.next_q == EndOfSurvey and self.current_q > self.next_q:
|
103
|
-
raise SurveyRuleSendsYouBackwardsError
|
105
|
+
raise SurveyRuleSendsYouBackwardsError(
|
106
|
+
f"current_q: {self.current_q}, next_q: {self.next_q}"
|
107
|
+
)
|
104
108
|
|
105
109
|
# get the AST for the expression - used to extract the variables referenced in the expression
|
106
110
|
try:
|
107
111
|
self.ast_tree = ast.parse(self.expression)
|
108
112
|
except SyntaxError:
|
109
|
-
raise SurveyRuleSkipLogicSyntaxError
|
113
|
+
raise SurveyRuleSkipLogicSyntaxError(
|
114
|
+
f"The expression {self.expression} is not valid Python syntax."
|
115
|
+
)
|
110
116
|
|
111
117
|
# get the names of the variables in the expression
|
112
118
|
# e.g., q1 == 'yes' -> ['q1']
|