edsl 0.1.37__py3-none-any.whl → 0.1.37.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +35 -86
- edsl/agents/AgentList.py +0 -5
- edsl/agents/InvigilatorBase.py +23 -2
- edsl/agents/PromptConstructor.py +105 -148
- edsl/agents/descriptors.py +4 -17
- edsl/conjure/AgentConstructionMixin.py +3 -11
- edsl/conversation/Conversation.py +14 -66
- edsl/coop/coop.py +14 -148
- edsl/data/Cache.py +1 -1
- edsl/exceptions/__init__.py +3 -7
- edsl/exceptions/agents.py +19 -17
- edsl/exceptions/results.py +8 -11
- edsl/exceptions/surveys.py +10 -13
- edsl/inference_services/AwsBedrock.py +2 -7
- edsl/inference_services/InferenceServicesCollection.py +9 -32
- edsl/jobs/Jobs.py +71 -306
- edsl/jobs/interviews/InterviewExceptionEntry.py +1 -5
- edsl/jobs/tasks/TaskHistory.py +0 -1
- edsl/language_models/LanguageModel.py +59 -47
- edsl/language_models/__init__.py +0 -1
- edsl/prompts/Prompt.py +4 -11
- edsl/questions/QuestionBase.py +13 -53
- edsl/questions/QuestionBasePromptsMixin.py +33 -1
- edsl/questions/QuestionFreeText.py +0 -1
- edsl/questions/QuestionFunctional.py +2 -2
- edsl/questions/descriptors.py +28 -23
- edsl/results/DatasetExportMixin.py +1 -25
- edsl/results/Result.py +1 -16
- edsl/results/Results.py +120 -31
- edsl/results/ResultsDBMixin.py +1 -1
- edsl/results/Selector.py +1 -18
- edsl/scenarios/Scenario.py +12 -48
- edsl/scenarios/ScenarioHtmlMixin.py +2 -7
- edsl/scenarios/ScenarioList.py +1 -12
- edsl/surveys/Rule.py +4 -10
- edsl/surveys/Survey.py +77 -100
- edsl/utilities/utilities.py +0 -18
- {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/METADATA +1 -1
- {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/RECORD +42 -46
- edsl/conversation/chips.py +0 -95
- edsl/exceptions/BaseException.py +0 -21
- edsl/exceptions/scenarios.py +0 -22
- edsl/language_models/KeyLookup.py +0 -30
- {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/WHEEL +0 -0
edsl/results/Results.py
CHANGED
@@ -7,17 +7,11 @@ from __future__ import annotations
|
|
7
7
|
import json
|
8
8
|
import random
|
9
9
|
from collections import UserList, defaultdict
|
10
|
-
from typing import Optional, Callable, Any, Type, Union, List
|
11
|
-
|
12
|
-
if TYPE_CHECKING:
|
13
|
-
from edsl import Survey, Cache, AgentList, ModelList, ScenarioList
|
14
|
-
from edsl.results.Result import Result
|
15
|
-
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
10
|
+
from typing import Optional, Callable, Any, Type, Union, List
|
16
11
|
|
17
12
|
from simpleeval import EvalWithCompoundTypes
|
18
13
|
|
19
14
|
from edsl.exceptions.results import (
|
20
|
-
ResultsError,
|
21
15
|
ResultsBadMutationstringError,
|
22
16
|
ResultsColumnNotFoundError,
|
23
17
|
ResultsInvalidNameError,
|
@@ -46,7 +40,7 @@ class Mixins(
|
|
46
40
|
ResultsGGMixin,
|
47
41
|
ResultsToolsMixin,
|
48
42
|
):
|
49
|
-
def print_long(self, max_rows
|
43
|
+
def print_long(self, max_rows=None) -> None:
|
50
44
|
"""Print the results in long format.
|
51
45
|
|
52
46
|
>>> from edsl.results import Results
|
@@ -90,13 +84,13 @@ class Results(UserList, Mixins, Base):
|
|
90
84
|
|
91
85
|
def __init__(
|
92
86
|
self,
|
93
|
-
survey: Optional[Survey] = None,
|
94
|
-
data: Optional[list[Result]] = None,
|
87
|
+
survey: Optional["Survey"] = None,
|
88
|
+
data: Optional[list["Result"]] = None,
|
95
89
|
created_columns: Optional[list[str]] = None,
|
96
|
-
cache: Optional[Cache] = None,
|
90
|
+
cache: Optional["Cache"] = None,
|
97
91
|
job_uuid: Optional[str] = None,
|
98
92
|
total_results: Optional[int] = None,
|
99
|
-
task_history: Optional[TaskHistory] = None,
|
93
|
+
task_history: Optional["TaskHistory"] = None,
|
100
94
|
):
|
101
95
|
"""Instantiate a `Results` object with a survey and a list of `Result` objects.
|
102
96
|
|
@@ -241,11 +235,11 @@ class Results(UserList, Mixins, Base):
|
|
241
235
|
>>> r3 = r + r2
|
242
236
|
"""
|
243
237
|
if self.survey != other.survey:
|
244
|
-
raise
|
245
|
-
"The surveys are not the same so
|
238
|
+
raise Exception(
|
239
|
+
"The surveys are not the same so they cannot be added together."
|
246
240
|
)
|
247
241
|
if self.created_columns != other.created_columns:
|
248
|
-
raise
|
242
|
+
raise Exception(
|
249
243
|
"The created columns are not the same so they cannot be added together."
|
250
244
|
)
|
251
245
|
|
@@ -264,6 +258,17 @@ class Results(UserList, Mixins, Base):
|
|
264
258
|
from IPython.display import HTML
|
265
259
|
|
266
260
|
json_str = json.dumps(self.to_dict()["data"], indent=4)
|
261
|
+
# from pygments import highlight
|
262
|
+
# from pygments.lexers import JsonLexer
|
263
|
+
# 3from pygments.formatters import HtmlFormatter
|
264
|
+
|
265
|
+
# formatted_json = highlight(
|
266
|
+
# json_str,
|
267
|
+
# JsonLexer(),
|
268
|
+
# HtmlFormatter(style="default", full=True, noclasses=True),
|
269
|
+
# )
|
270
|
+
# return HTML(formatted_json).data
|
271
|
+
# print(json_str)
|
267
272
|
return f"<pre>{json_str}</pre>"
|
268
273
|
|
269
274
|
def _to_dict(self, sort=False):
|
@@ -323,7 +328,7 @@ class Results(UserList, Mixins, Base):
|
|
323
328
|
def hashes(self) -> set:
|
324
329
|
return set(hash(result) for result in self.data)
|
325
330
|
|
326
|
-
def sample(self, n: int) -> Results:
|
331
|
+
def sample(self, n: int) -> "Results":
|
327
332
|
"""Return a random sample of the results.
|
328
333
|
|
329
334
|
:param n: The number of samples to return.
|
@@ -341,7 +346,7 @@ class Results(UserList, Mixins, Base):
|
|
341
346
|
indices = list(range(len(values)))
|
342
347
|
sampled_indices = random.sample(indices, n)
|
343
348
|
if n > len(indices):
|
344
|
-
raise
|
349
|
+
raise ValueError(
|
345
350
|
f"Cannot sample {n} items from a list of length {len(indices)}."
|
346
351
|
)
|
347
352
|
entry[key] = [values[i] for i in sampled_indices]
|
@@ -394,12 +399,11 @@ class Results(UserList, Mixins, Base):
|
|
394
399
|
- Uses the key_to_data_type property of the Result class.
|
395
400
|
- Includes any columns that the user has created with `mutate`
|
396
401
|
"""
|
397
|
-
d
|
402
|
+
d = {}
|
398
403
|
for result in self.data:
|
399
404
|
d.update(result.key_to_data_type)
|
400
405
|
for column in self.created_columns:
|
401
406
|
d[column] = "answer"
|
402
|
-
|
403
407
|
return d
|
404
408
|
|
405
409
|
@property
|
@@ -449,7 +453,7 @@ class Results(UserList, Mixins, Base):
|
|
449
453
|
from edsl.utilities.utilities import shorten_string
|
450
454
|
|
451
455
|
if not self.survey:
|
452
|
-
raise
|
456
|
+
raise Exception("Survey is not defined so no answer keys are available.")
|
453
457
|
|
454
458
|
answer_keys = self._data_type_to_keys["answer"]
|
455
459
|
answer_keys = {k for k in answer_keys if "_comment" not in k}
|
@@ -462,7 +466,7 @@ class Results(UserList, Mixins, Base):
|
|
462
466
|
return sorted_dict
|
463
467
|
|
464
468
|
@property
|
465
|
-
def agents(self) -> AgentList:
|
469
|
+
def agents(self) -> "AgentList":
|
466
470
|
"""Return a list of all of the agents in the Results.
|
467
471
|
|
468
472
|
Example:
|
@@ -476,7 +480,7 @@ class Results(UserList, Mixins, Base):
|
|
476
480
|
return AgentList([r.agent for r in self.data])
|
477
481
|
|
478
482
|
@property
|
479
|
-
def models(self) ->
|
483
|
+
def models(self) -> list[Type["LanguageModel"]]:
|
480
484
|
"""Return a list of all of the models in the Results.
|
481
485
|
|
482
486
|
Example:
|
@@ -485,12 +489,10 @@ class Results(UserList, Mixins, Base):
|
|
485
489
|
>>> r.models[0]
|
486
490
|
Model(model_name = ...)
|
487
491
|
"""
|
488
|
-
|
489
|
-
|
490
|
-
return ModelList([r.model for r in self.data])
|
492
|
+
return [r.model for r in self.data]
|
491
493
|
|
492
494
|
@property
|
493
|
-
def scenarios(self) -> ScenarioList:
|
495
|
+
def scenarios(self) -> "ScenarioList":
|
494
496
|
"""Return a list of all of the scenarios in the Results.
|
495
497
|
|
496
498
|
Example:
|
@@ -567,7 +569,7 @@ class Results(UserList, Mixins, Base):
|
|
567
569
|
)
|
568
570
|
return sorted(list(all_keys))
|
569
571
|
|
570
|
-
def first(self) -> Result:
|
572
|
+
def first(self) -> "Result":
|
571
573
|
"""Return the first observation in the results.
|
572
574
|
|
573
575
|
Example:
|
@@ -817,7 +819,7 @@ class Results(UserList, Mixins, Base):
|
|
817
819
|
|
818
820
|
return Results(survey=self.survey, data=new_data, created_columns=None)
|
819
821
|
|
820
|
-
def select(self, *columns: Union[str, list[str]]) ->
|
822
|
+
def select(self, *columns: Union[str, list[str]]) -> "Dataset":
|
821
823
|
"""
|
822
824
|
Select data from the results and format it.
|
823
825
|
|
@@ -830,12 +832,93 @@ class Results(UserList, Mixins, Base):
|
|
830
832
|
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
|
831
833
|
|
832
834
|
>>> results.select('how_feeling', 'model', 'how_feeling')
|
833
|
-
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'
|
835
|
+
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
|
834
836
|
|
835
837
|
>>> from edsl import Results; r = Results.example(); r.select('answer.how_feeling_y')
|
836
838
|
Dataset([{'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
|
837
839
|
"""
|
838
840
|
|
841
|
+
# if len(self) == 0:
|
842
|
+
# raise Exception("No data to select from---the Results object is empty.")
|
843
|
+
|
844
|
+
if not columns or columns == ("*",) or columns == (None,):
|
845
|
+
# is the users passes nothing, then we'll return all the columns
|
846
|
+
columns = ("*.*",)
|
847
|
+
|
848
|
+
if isinstance(columns[0], list):
|
849
|
+
columns = tuple(columns[0])
|
850
|
+
|
851
|
+
def get_data_types_to_return(parsed_data_type):
|
852
|
+
if parsed_data_type == "*": # they want all of the columns
|
853
|
+
return self.known_data_types
|
854
|
+
else:
|
855
|
+
if parsed_data_type not in self.known_data_types:
|
856
|
+
raise Exception(
|
857
|
+
f"Data type {parsed_data_type} not found in data. Did you mean one of {self.known_data_types}"
|
858
|
+
)
|
859
|
+
return [parsed_data_type]
|
860
|
+
|
861
|
+
# we're doing to populate this with the data we want to fetch
|
862
|
+
to_fetch = defaultdict(list)
|
863
|
+
|
864
|
+
new_data = []
|
865
|
+
items_in_order = []
|
866
|
+
# iterate through the passed columns
|
867
|
+
for column in columns:
|
868
|
+
# a user could pass 'result.how_feeling' or just 'how_feeling'
|
869
|
+
matches = self._matching_columns(column)
|
870
|
+
if len(matches) > 1:
|
871
|
+
raise Exception(
|
872
|
+
f"Column '{column}' is ambiguous. Did you mean one of {matches}?"
|
873
|
+
)
|
874
|
+
if len(matches) == 0 and ".*" not in column:
|
875
|
+
raise Exception(f"Column '{column}' not found in data.")
|
876
|
+
if len(matches) == 1:
|
877
|
+
column = matches[0]
|
878
|
+
|
879
|
+
parsed_data_type, parsed_key = self._parse_column(column)
|
880
|
+
data_types = get_data_types_to_return(parsed_data_type)
|
881
|
+
found_once = False # we need to track this to make sure we found the key at least once
|
882
|
+
|
883
|
+
for data_type in data_types:
|
884
|
+
# the keys for that data_type e.g.,# if data_type is 'answer', then the keys are 'how_feeling', 'how_feeling_comment', etc.
|
885
|
+
relevant_keys = self._data_type_to_keys[data_type]
|
886
|
+
|
887
|
+
for key in relevant_keys:
|
888
|
+
if key == parsed_key or parsed_key == "*":
|
889
|
+
found_once = True
|
890
|
+
to_fetch[data_type].append(key)
|
891
|
+
items_in_order.append(data_type + "." + key)
|
892
|
+
|
893
|
+
if not found_once:
|
894
|
+
raise Exception(f"Key {parsed_key} not found in data.")
|
895
|
+
|
896
|
+
for data_type in to_fetch:
|
897
|
+
for key in to_fetch[data_type]:
|
898
|
+
entries = self._fetch_list(data_type, key)
|
899
|
+
new_data.append({data_type + "." + key: entries})
|
900
|
+
|
901
|
+
def sort_by_key_order(dictionary):
|
902
|
+
# Extract the single key from the dictionary
|
903
|
+
single_key = next(iter(dictionary))
|
904
|
+
# Return the index of this key in the list_of_keys
|
905
|
+
return items_in_order.index(single_key)
|
906
|
+
|
907
|
+
# sorted(new_data, key=sort_by_key_order)
|
908
|
+
from edsl.results.Dataset import Dataset
|
909
|
+
|
910
|
+
sorted_new_data = []
|
911
|
+
|
912
|
+
# WORKS but slow
|
913
|
+
for key in items_in_order:
|
914
|
+
for d in new_data:
|
915
|
+
if key in d:
|
916
|
+
sorted_new_data.append(d)
|
917
|
+
break
|
918
|
+
|
919
|
+
return Dataset(sorted_new_data)
|
920
|
+
|
921
|
+
def select(self, *columns: Union[str, list[str]]) -> "Results":
|
839
922
|
from edsl.results.Selector import Selector
|
840
923
|
|
841
924
|
if len(self) == 0:
|
@@ -945,7 +1028,6 @@ class Results(UserList, Mixins, Base):
|
|
945
1028
|
Traceback (most recent call last):
|
946
1029
|
...
|
947
1030
|
edsl.exceptions.results.ResultsFilterError: You must use '==' instead of '=' in the filter expression.
|
948
|
-
...
|
949
1031
|
|
950
1032
|
>>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling').print()
|
951
1033
|
┏━━━━━━━━━━━━━━┓
|
@@ -1023,7 +1105,6 @@ class Results(UserList, Mixins, Base):
|
|
1023
1105
|
stop_on_exception=True,
|
1024
1106
|
skip_retry=True,
|
1025
1107
|
raise_validation_errors=True,
|
1026
|
-
disable_remote_cache=True,
|
1027
1108
|
disable_remote_inference=True,
|
1028
1109
|
)
|
1029
1110
|
return results
|
@@ -1031,6 +1112,14 @@ class Results(UserList, Mixins, Base):
|
|
1031
1112
|
def rich_print(self):
|
1032
1113
|
"""Display an object as a table."""
|
1033
1114
|
pass
|
1115
|
+
# with io.StringIO() as buf:
|
1116
|
+
# console = Console(file=buf, record=True)
|
1117
|
+
|
1118
|
+
# for index, result in enumerate(self):
|
1119
|
+
# console.print(f"Result {index}")
|
1120
|
+
# console.print(result.rich_print())
|
1121
|
+
|
1122
|
+
# return console.export_text()
|
1034
1123
|
|
1035
1124
|
def __str__(self):
|
1036
1125
|
data = self.to_dict()["data"]
|
edsl/results/ResultsDBMixin.py
CHANGED
@@ -93,7 +93,7 @@ class ResultsDBMixin:
|
|
93
93
|
from sqlalchemy import create_engine
|
94
94
|
|
95
95
|
engine = create_engine("sqlite:///:memory:")
|
96
|
-
df = self.to_pandas(remove_prefix=remove_prefix
|
96
|
+
df = self.to_pandas(remove_prefix=remove_prefix)
|
97
97
|
df.to_sql("self", engine, index=False, if_exists="replace")
|
98
98
|
return engine.connect()
|
99
99
|
else:
|
edsl/results/Selector.py
CHANGED
@@ -12,7 +12,6 @@ class Selector:
|
|
12
12
|
fetch_list_func,
|
13
13
|
columns: List[str],
|
14
14
|
):
|
15
|
-
"""Selects columns from a Results object"""
|
16
15
|
self.known_data_types = known_data_types
|
17
16
|
self._data_type_to_keys = data_type_to_keys
|
18
17
|
self._key_to_data_type = key_to_data_type
|
@@ -22,19 +21,10 @@ class Selector:
|
|
22
21
|
def select(self, *columns: Union[str, List[str]]) -> "Dataset":
|
23
22
|
columns = self._normalize_columns(columns)
|
24
23
|
to_fetch = self._get_columns_to_fetch(columns)
|
25
|
-
# breakpoint()
|
26
24
|
new_data = self._fetch_data(to_fetch)
|
27
25
|
return Dataset(new_data)
|
28
26
|
|
29
27
|
def _normalize_columns(self, columns: Union[str, List[str]]) -> tuple:
|
30
|
-
"""Normalize the columns to a tuple of strings
|
31
|
-
|
32
|
-
>>> s = Selector([], {}, {}, lambda x, y: x, [])
|
33
|
-
>>> s._normalize_columns([["a", "b"], ])
|
34
|
-
('a', 'b')
|
35
|
-
>>> s._normalize_columns(None)
|
36
|
-
('*.*',)
|
37
|
-
"""
|
38
28
|
if not columns or columns == ("*",) or columns == (None,):
|
39
29
|
return ("*.*",)
|
40
30
|
if isinstance(columns[0], list):
|
@@ -47,7 +37,6 @@ class Selector:
|
|
47
37
|
|
48
38
|
for column in columns:
|
49
39
|
matches = self._find_matching_columns(column)
|
50
|
-
# breakpoint()
|
51
40
|
self._validate_matches(column, matches)
|
52
41
|
|
53
42
|
if len(matches) == 1:
|
@@ -63,7 +52,7 @@ class Selector:
|
|
63
52
|
search_in_list = self.columns
|
64
53
|
else:
|
65
54
|
search_in_list = [s.split(".")[1] for s in self.columns]
|
66
|
-
|
55
|
+
|
67
56
|
matches = [s for s in search_in_list if s.startswith(partial_name)]
|
68
57
|
return [partial_name] if partial_name in matches else matches
|
69
58
|
|
@@ -127,9 +116,3 @@ class Selector:
|
|
127
116
|
new_data.append({f"{data_type}.{key}": entries})
|
128
117
|
|
129
118
|
return [d for key in self.items_in_order for d in new_data if key in d]
|
130
|
-
|
131
|
-
|
132
|
-
if __name__ == "__main__":
|
133
|
-
import doctest
|
134
|
-
|
135
|
-
doctest.testmod()
|
edsl/scenarios/Scenario.py
CHANGED
@@ -11,26 +11,18 @@ from uuid import uuid4
|
|
11
11
|
from edsl.Base import Base
|
12
12
|
from edsl.scenarios.ScenarioHtmlMixin import ScenarioHtmlMixin
|
13
13
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
14
|
-
from edsl.exceptions.scenarios import ScenarioError
|
15
14
|
|
16
15
|
|
17
16
|
class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
18
17
|
"""A Scenario is a dictionary of keys/values.
|
19
18
|
|
20
|
-
They can be used parameterize
|
21
|
-
|
22
|
-
__doc__ = "https://docs.expectedparrot.com/en/latest/scenarios.html"
|
19
|
+
They can be used parameterize edsl questions."""
|
23
20
|
|
24
21
|
def __init__(self, data: Union[dict, None] = None, name: str = None):
|
25
22
|
"""Initialize a new Scenario.
|
26
23
|
|
27
|
-
|
28
|
-
|
29
|
-
if not isinstance(data, dict) and data is not None:
|
30
|
-
raise EDSLScenarioError(
|
31
|
-
"You must pass in a dictionary to initialize a Scenario."
|
32
|
-
)
|
33
|
-
|
24
|
+
:param data: A dictionary of keys/values for parameterizing questions.
|
25
|
+
"""
|
34
26
|
self.data = data if data is not None else {}
|
35
27
|
self.name = name
|
36
28
|
|
@@ -49,6 +41,13 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
49
41
|
|
50
42
|
return ScenarioList([copy.deepcopy(self) for _ in range(n)])
|
51
43
|
|
44
|
+
# @property
|
45
|
+
# def has_image(self) -> bool:
|
46
|
+
# """Return whether the scenario has an image."""
|
47
|
+
# if not hasattr(self, "_has_image"):
|
48
|
+
# self._has_image = False
|
49
|
+
# return self._has_image
|
50
|
+
|
52
51
|
@property
|
53
52
|
def has_jinja_braces(self) -> bool:
|
54
53
|
"""Return whether the scenario has jinja braces. This matters for rendering.
|
@@ -107,9 +106,7 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
107
106
|
s = Scenario(data1 | data2)
|
108
107
|
return s
|
109
108
|
|
110
|
-
def rename(
|
111
|
-
self, old_name_or_replacement_dict: dict, new_name: Optional[str] = None
|
112
|
-
) -> "Scenario":
|
109
|
+
def rename(self, replacement_dict: dict) -> "Scenario":
|
113
110
|
"""Rename the keys of a scenario.
|
114
111
|
|
115
112
|
:param replacement_dict: A dictionary of old keys to new keys.
|
@@ -119,16 +116,7 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
119
116
|
>>> s = Scenario({"food": "wood chips"})
|
120
117
|
>>> s.rename({"food": "food_preference"})
|
121
118
|
Scenario({'food_preference': 'wood chips'})
|
122
|
-
|
123
|
-
>>> s = Scenario({"food": "wood chips"})
|
124
|
-
>>> s.rename("food", "snack")
|
125
|
-
Scenario({'snack': 'wood chips'})
|
126
119
|
"""
|
127
|
-
if isinstance(old_name_or_replacement_dict, str) and new_name is not None:
|
128
|
-
replacement_dict = {old_name_or_replacement_dict: new_name}
|
129
|
-
else:
|
130
|
-
replacement_dict = old_name_or_replacement_dict
|
131
|
-
|
132
120
|
new_scenario = Scenario()
|
133
121
|
for key, value in self.items():
|
134
122
|
if key in replacement_dict:
|
@@ -228,20 +216,6 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
228
216
|
new_scenario[key] = self[key]
|
229
217
|
return new_scenario
|
230
218
|
|
231
|
-
def keep(self, list_of_keys: List[str]) -> "Scenario":
|
232
|
-
"""Keep a subset of keys from a scenario.
|
233
|
-
|
234
|
-
:param list_of_keys: The keys to keep.
|
235
|
-
|
236
|
-
Example:
|
237
|
-
|
238
|
-
>>> s = Scenario({"food": "wood chips", "drink": "water"})
|
239
|
-
>>> s.keep(["food"])
|
240
|
-
Scenario({'food': 'wood chips'})
|
241
|
-
"""
|
242
|
-
|
243
|
-
return self.select(list_of_keys)
|
244
|
-
|
245
219
|
@classmethod
|
246
220
|
def from_url(cls, url: str, field_name: Optional[str] = "text") -> "Scenario":
|
247
221
|
"""Creates a scenario from a URL.
|
@@ -257,17 +231,7 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
257
231
|
|
258
232
|
@classmethod
|
259
233
|
def from_file(cls, file_path: str, field_name: str) -> "Scenario":
|
260
|
-
"""Creates a scenario from a file.
|
261
|
-
|
262
|
-
>>> import tempfile
|
263
|
-
>>> with tempfile.NamedTemporaryFile(suffix=".txt", mode="w") as f:
|
264
|
-
... _ = f.write("This is a test.")
|
265
|
-
... _ = f.flush()
|
266
|
-
... s = Scenario.from_file(f.name, "file")
|
267
|
-
>>> s
|
268
|
-
Scenario({'file': FileStore(path='...')})
|
269
|
-
|
270
|
-
"""
|
234
|
+
"""Creates a scenario from a file."""
|
271
235
|
from edsl.scenarios.FileStore import FileStore
|
272
236
|
|
273
237
|
fs = FileStore(file_path)
|
@@ -1,24 +1,19 @@
|
|
1
1
|
import requests
|
2
|
-
from typing import Optional
|
3
2
|
from requests.adapters import HTTPAdapter
|
4
3
|
from requests.packages.urllib3.util.retry import Retry
|
5
4
|
|
6
5
|
|
7
6
|
class ScenarioHtmlMixin:
|
8
7
|
@classmethod
|
9
|
-
def from_html(cls, url: str
|
8
|
+
def from_html(cls, url: str) -> "Scenario":
|
10
9
|
"""Create a scenario from HTML content.
|
11
10
|
|
12
11
|
:param html: The HTML content.
|
13
|
-
:param field_name: The name of the field containing the HTML content.
|
14
|
-
|
15
12
|
|
16
13
|
"""
|
17
14
|
html = cls.fetch_html(url)
|
18
15
|
text = cls.extract_text(html)
|
19
|
-
|
20
|
-
field_name = "text"
|
21
|
-
return cls({"url": url, "html": html, field_name: text})
|
16
|
+
return cls({"url": url, "html": html, "text": text})
|
22
17
|
|
23
18
|
def fetch_html(url):
|
24
19
|
# Define the user-agent to mimic a browser
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -538,17 +538,6 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
538
538
|
"""
|
539
539
|
return ScenarioList([scenario.drop(fields) for scenario in self.data])
|
540
540
|
|
541
|
-
def keep(self, *fields) -> ScenarioList:
|
542
|
-
"""Keep only the specified fields in the scenarios.
|
543
|
-
|
544
|
-
Example:
|
545
|
-
|
546
|
-
>>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
|
547
|
-
>>> s.keep('a')
|
548
|
-
ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
|
549
|
-
"""
|
550
|
-
return ScenarioList([scenario.keep(fields) for scenario in self.data])
|
551
|
-
|
552
541
|
@classmethod
|
553
542
|
def from_list(
|
554
543
|
cls, name: str, values: list, func: Optional[Callable] = None
|
@@ -1061,7 +1050,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1061
1050
|
elif isinstance(key, int):
|
1062
1051
|
return super().__getitem__(key)
|
1063
1052
|
else:
|
1064
|
-
return self.
|
1053
|
+
return self.to_dict()[key]
|
1065
1054
|
|
1066
1055
|
def to_agent_list(self):
|
1067
1056
|
"""Convert the ScenarioList to an AgentList.
|
edsl/surveys/Rule.py
CHANGED
@@ -25,8 +25,6 @@ from jinja2 import Template
|
|
25
25
|
from rich import print
|
26
26
|
from simpleeval import EvalWithCompoundTypes
|
27
27
|
|
28
|
-
from edsl.exceptions.surveys import SurveyError
|
29
|
-
|
30
28
|
from edsl.exceptions import (
|
31
29
|
SurveyRuleCannotEvaluateError,
|
32
30
|
SurveyRuleCollectionHasNoRulesAtNodeError,
|
@@ -49,11 +47,11 @@ class QuestionIndex:
|
|
49
47
|
|
50
48
|
def __set__(self, obj, value):
|
51
49
|
if not isinstance(value, (int, EndOfSurvey.__class__)):
|
52
|
-
raise
|
50
|
+
raise ValueError(f"{self.name} must be an integer or EndOfSurvey")
|
53
51
|
if self.name == "_next_q" and isinstance(value, int):
|
54
52
|
current_q = getattr(obj, "_current_q")
|
55
53
|
if value <= current_q:
|
56
|
-
raise
|
54
|
+
raise ValueError("next_q must be greater than current_q")
|
57
55
|
setattr(obj, self.name, value)
|
58
56
|
|
59
57
|
|
@@ -102,17 +100,13 @@ class Rule:
|
|
102
100
|
raise SurveyRuleSendsYouBackwardsError
|
103
101
|
|
104
102
|
if not self.next_q == EndOfSurvey and self.current_q > self.next_q:
|
105
|
-
raise SurveyRuleSendsYouBackwardsError
|
106
|
-
f"current_q: {self.current_q}, next_q: {self.next_q}"
|
107
|
-
)
|
103
|
+
raise SurveyRuleSendsYouBackwardsError
|
108
104
|
|
109
105
|
# get the AST for the expression - used to extract the variables referenced in the expression
|
110
106
|
try:
|
111
107
|
self.ast_tree = ast.parse(self.expression)
|
112
108
|
except SyntaxError:
|
113
|
-
raise SurveyRuleSkipLogicSyntaxError
|
114
|
-
f"The expression {self.expression} is not valid Python syntax."
|
115
|
-
)
|
109
|
+
raise SurveyRuleSkipLogicSyntaxError
|
116
110
|
|
117
111
|
# get the names of the variables in the expression
|
118
112
|
# e.g., q1 == 'yes' -> ['q1']
|