edsl 0.1.37.dev4__py3-none-any.whl → 0.1.37.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Agent.py +86 -35
  3. edsl/agents/AgentList.py +5 -0
  4. edsl/agents/InvigilatorBase.py +2 -23
  5. edsl/agents/PromptConstructor.py +147 -106
  6. edsl/agents/descriptors.py +17 -4
  7. edsl/config.py +1 -1
  8. edsl/conjure/AgentConstructionMixin.py +11 -3
  9. edsl/conversation/Conversation.py +66 -14
  10. edsl/conversation/chips.py +95 -0
  11. edsl/coop/coop.py +134 -3
  12. edsl/data/Cache.py +1 -1
  13. edsl/exceptions/BaseException.py +21 -0
  14. edsl/exceptions/__init__.py +7 -3
  15. edsl/exceptions/agents.py +17 -19
  16. edsl/exceptions/results.py +11 -8
  17. edsl/exceptions/scenarios.py +22 -0
  18. edsl/exceptions/surveys.py +13 -10
  19. edsl/inference_services/InferenceServicesCollection.py +32 -9
  20. edsl/jobs/Jobs.py +265 -53
  21. edsl/jobs/interviews/InterviewExceptionEntry.py +5 -1
  22. edsl/jobs/tasks/TaskHistory.py +1 -0
  23. edsl/language_models/KeyLookup.py +30 -0
  24. edsl/language_models/LanguageModel.py +47 -59
  25. edsl/language_models/__init__.py +1 -0
  26. edsl/prompts/Prompt.py +8 -4
  27. edsl/questions/QuestionBase.py +53 -13
  28. edsl/questions/QuestionBasePromptsMixin.py +1 -33
  29. edsl/questions/QuestionFunctional.py +2 -2
  30. edsl/questions/descriptors.py +23 -28
  31. edsl/results/DatasetExportMixin.py +25 -1
  32. edsl/results/Result.py +16 -1
  33. edsl/results/Results.py +31 -120
  34. edsl/results/ResultsDBMixin.py +1 -1
  35. edsl/results/Selector.py +18 -1
  36. edsl/scenarios/Scenario.py +48 -12
  37. edsl/scenarios/ScenarioHtmlMixin.py +7 -2
  38. edsl/scenarios/ScenarioList.py +12 -1
  39. edsl/surveys/Rule.py +10 -4
  40. edsl/surveys/Survey.py +100 -77
  41. edsl/utilities/utilities.py +18 -0
  42. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/METADATA +1 -1
  43. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/RECORD +45 -41
  44. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/LICENSE +0 -0
  45. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/WHEEL +0 -0
edsl/results/Results.py CHANGED
@@ -7,11 +7,17 @@ from __future__ import annotations
7
7
  import json
8
8
  import random
9
9
  from collections import UserList, defaultdict
10
- from typing import Optional, Callable, Any, Type, Union, List
10
+ from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
11
+
12
+ if TYPE_CHECKING:
13
+ from edsl import Survey, Cache, AgentList, ModelList, ScenarioList
14
+ from edsl.results.Result import Result
15
+ from edsl.jobs.tasks.TaskHistory import TaskHistory
11
16
 
12
17
  from simpleeval import EvalWithCompoundTypes
13
18
 
14
19
  from edsl.exceptions.results import (
20
+ ResultsError,
15
21
  ResultsBadMutationstringError,
16
22
  ResultsColumnNotFoundError,
17
23
  ResultsInvalidNameError,
@@ -40,7 +46,7 @@ class Mixins(
40
46
  ResultsGGMixin,
41
47
  ResultsToolsMixin,
42
48
  ):
43
- def print_long(self, max_rows=None) -> None:
49
+ def print_long(self, max_rows: int = None) -> None:
44
50
  """Print the results in long format.
45
51
 
46
52
  >>> from edsl.results import Results
@@ -84,13 +90,13 @@ class Results(UserList, Mixins, Base):
84
90
 
85
91
  def __init__(
86
92
  self,
87
- survey: Optional["Survey"] = None,
88
- data: Optional[list["Result"]] = None,
93
+ survey: Optional[Survey] = None,
94
+ data: Optional[list[Result]] = None,
89
95
  created_columns: Optional[list[str]] = None,
90
- cache: Optional["Cache"] = None,
96
+ cache: Optional[Cache] = None,
91
97
  job_uuid: Optional[str] = None,
92
98
  total_results: Optional[int] = None,
93
- task_history: Optional["TaskHistory"] = None,
99
+ task_history: Optional[TaskHistory] = None,
94
100
  ):
95
101
  """Instantiate a `Results` object with a survey and a list of `Result` objects.
96
102
 
@@ -235,11 +241,11 @@ class Results(UserList, Mixins, Base):
235
241
  >>> r3 = r + r2
236
242
  """
237
243
  if self.survey != other.survey:
238
- raise Exception(
239
- "The surveys are not the same so they cannot be added together."
244
+ raise ResultsError(
245
+ "The surveys are not the same so the the results cannot be added together."
240
246
  )
241
247
  if self.created_columns != other.created_columns:
242
- raise Exception(
248
+ raise ResultsError(
243
249
  "The created columns are not the same so they cannot be added together."
244
250
  )
245
251
 
@@ -258,17 +264,6 @@ class Results(UserList, Mixins, Base):
258
264
  from IPython.display import HTML
259
265
 
260
266
  json_str = json.dumps(self.to_dict()["data"], indent=4)
261
- # from pygments import highlight
262
- # from pygments.lexers import JsonLexer
263
- # 3from pygments.formatters import HtmlFormatter
264
-
265
- # formatted_json = highlight(
266
- # json_str,
267
- # JsonLexer(),
268
- # HtmlFormatter(style="default", full=True, noclasses=True),
269
- # )
270
- # return HTML(formatted_json).data
271
- # print(json_str)
272
267
  return f"<pre>{json_str}</pre>"
273
268
 
274
269
  def _to_dict(self, sort=False):
@@ -328,7 +323,7 @@ class Results(UserList, Mixins, Base):
328
323
  def hashes(self) -> set:
329
324
  return set(hash(result) for result in self.data)
330
325
 
331
- def sample(self, n: int) -> "Results":
326
+ def sample(self, n: int) -> Results:
332
327
  """Return a random sample of the results.
333
328
 
334
329
  :param n: The number of samples to return.
@@ -346,7 +341,7 @@ class Results(UserList, Mixins, Base):
346
341
  indices = list(range(len(values)))
347
342
  sampled_indices = random.sample(indices, n)
348
343
  if n > len(indices):
349
- raise ValueError(
344
+ raise ResultsError(
350
345
  f"Cannot sample {n} items from a list of length {len(indices)}."
351
346
  )
352
347
  entry[key] = [values[i] for i in sampled_indices]
@@ -399,11 +394,12 @@ class Results(UserList, Mixins, Base):
399
394
  - Uses the key_to_data_type property of the Result class.
400
395
  - Includes any columns that the user has created with `mutate`
401
396
  """
402
- d = {}
397
+ d: dict = {}
403
398
  for result in self.data:
404
399
  d.update(result.key_to_data_type)
405
400
  for column in self.created_columns:
406
401
  d[column] = "answer"
402
+
407
403
  return d
408
404
 
409
405
  @property
@@ -453,7 +449,7 @@ class Results(UserList, Mixins, Base):
453
449
  from edsl.utilities.utilities import shorten_string
454
450
 
455
451
  if not self.survey:
456
- raise Exception("Survey is not defined so no answer keys are available.")
452
+ raise ResultsError("Survey is not defined so no answer keys are available.")
457
453
 
458
454
  answer_keys = self._data_type_to_keys["answer"]
459
455
  answer_keys = {k for k in answer_keys if "_comment" not in k}
@@ -466,7 +462,7 @@ class Results(UserList, Mixins, Base):
466
462
  return sorted_dict
467
463
 
468
464
  @property
469
- def agents(self) -> "AgentList":
465
+ def agents(self) -> AgentList:
470
466
  """Return a list of all of the agents in the Results.
471
467
 
472
468
  Example:
@@ -480,7 +476,7 @@ class Results(UserList, Mixins, Base):
480
476
  return AgentList([r.agent for r in self.data])
481
477
 
482
478
  @property
483
- def models(self) -> list[Type["LanguageModel"]]:
479
+ def models(self) -> ModelList:
484
480
  """Return a list of all of the models in the Results.
485
481
 
486
482
  Example:
@@ -489,10 +485,12 @@ class Results(UserList, Mixins, Base):
489
485
  >>> r.models[0]
490
486
  Model(model_name = ...)
491
487
  """
492
- return [r.model for r in self.data]
488
+ from edsl import ModelList
489
+
490
+ return ModelList([r.model for r in self.data])
493
491
 
494
492
  @property
495
- def scenarios(self) -> "ScenarioList":
493
+ def scenarios(self) -> ScenarioList:
496
494
  """Return a list of all of the scenarios in the Results.
497
495
 
498
496
  Example:
@@ -569,7 +567,7 @@ class Results(UserList, Mixins, Base):
569
567
  )
570
568
  return sorted(list(all_keys))
571
569
 
572
- def first(self) -> "Result":
570
+ def first(self) -> Result:
573
571
  """Return the first observation in the results.
574
572
 
575
573
  Example:
@@ -819,7 +817,7 @@ class Results(UserList, Mixins, Base):
819
817
 
820
818
  return Results(survey=self.survey, data=new_data, created_columns=None)
821
819
 
822
- def select(self, *columns: Union[str, list[str]]) -> "Dataset":
820
+ def select(self, *columns: Union[str, list[str]]) -> Results:
823
821
  """
824
822
  Select data from the results and format it.
825
823
 
@@ -832,93 +830,12 @@ class Results(UserList, Mixins, Base):
832
830
  Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
833
831
 
834
832
  >>> results.select('how_feeling', 'model', 'how_feeling')
835
- Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
833
+ Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
836
834
 
837
835
  >>> from edsl import Results; r = Results.example(); r.select('answer.how_feeling_y')
838
836
  Dataset([{'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
839
837
  """
840
838
 
841
- # if len(self) == 0:
842
- # raise Exception("No data to select from---the Results object is empty.")
843
-
844
- if not columns or columns == ("*",) or columns == (None,):
845
- # is the users passes nothing, then we'll return all the columns
846
- columns = ("*.*",)
847
-
848
- if isinstance(columns[0], list):
849
- columns = tuple(columns[0])
850
-
851
- def get_data_types_to_return(parsed_data_type):
852
- if parsed_data_type == "*": # they want all of the columns
853
- return self.known_data_types
854
- else:
855
- if parsed_data_type not in self.known_data_types:
856
- raise Exception(
857
- f"Data type {parsed_data_type} not found in data. Did you mean one of {self.known_data_types}"
858
- )
859
- return [parsed_data_type]
860
-
861
- # we're doing to populate this with the data we want to fetch
862
- to_fetch = defaultdict(list)
863
-
864
- new_data = []
865
- items_in_order = []
866
- # iterate through the passed columns
867
- for column in columns:
868
- # a user could pass 'result.how_feeling' or just 'how_feeling'
869
- matches = self._matching_columns(column)
870
- if len(matches) > 1:
871
- raise Exception(
872
- f"Column '{column}' is ambiguous. Did you mean one of {matches}?"
873
- )
874
- if len(matches) == 0 and ".*" not in column:
875
- raise Exception(f"Column '{column}' not found in data.")
876
- if len(matches) == 1:
877
- column = matches[0]
878
-
879
- parsed_data_type, parsed_key = self._parse_column(column)
880
- data_types = get_data_types_to_return(parsed_data_type)
881
- found_once = False # we need to track this to make sure we found the key at least once
882
-
883
- for data_type in data_types:
884
- # the keys for that data_type e.g.,# if data_type is 'answer', then the keys are 'how_feeling', 'how_feeling_comment', etc.
885
- relevant_keys = self._data_type_to_keys[data_type]
886
-
887
- for key in relevant_keys:
888
- if key == parsed_key or parsed_key == "*":
889
- found_once = True
890
- to_fetch[data_type].append(key)
891
- items_in_order.append(data_type + "." + key)
892
-
893
- if not found_once:
894
- raise Exception(f"Key {parsed_key} not found in data.")
895
-
896
- for data_type in to_fetch:
897
- for key in to_fetch[data_type]:
898
- entries = self._fetch_list(data_type, key)
899
- new_data.append({data_type + "." + key: entries})
900
-
901
- def sort_by_key_order(dictionary):
902
- # Extract the single key from the dictionary
903
- single_key = next(iter(dictionary))
904
- # Return the index of this key in the list_of_keys
905
- return items_in_order.index(single_key)
906
-
907
- # sorted(new_data, key=sort_by_key_order)
908
- from edsl.results.Dataset import Dataset
909
-
910
- sorted_new_data = []
911
-
912
- # WORKS but slow
913
- for key in items_in_order:
914
- for d in new_data:
915
- if key in d:
916
- sorted_new_data.append(d)
917
- break
918
-
919
- return Dataset(sorted_new_data)
920
-
921
- def select(self, *columns: Union[str, list[str]]) -> "Results":
922
839
  from edsl.results.Selector import Selector
923
840
 
924
841
  if len(self) == 0:
@@ -1028,6 +945,7 @@ class Results(UserList, Mixins, Base):
1028
945
  Traceback (most recent call last):
1029
946
  ...
1030
947
  edsl.exceptions.results.ResultsFilterError: You must use '==' instead of '=' in the filter expression.
948
+ ...
1031
949
 
1032
950
  >>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling').print()
1033
951
  ┏━━━━━━━━━━━━━━┓
@@ -1105,6 +1023,7 @@ class Results(UserList, Mixins, Base):
1105
1023
  stop_on_exception=True,
1106
1024
  skip_retry=True,
1107
1025
  raise_validation_errors=True,
1026
+ disable_remote_cache=True,
1108
1027
  disable_remote_inference=True,
1109
1028
  )
1110
1029
  return results
@@ -1112,14 +1031,6 @@ class Results(UserList, Mixins, Base):
1112
1031
  def rich_print(self):
1113
1032
  """Display an object as a table."""
1114
1033
  pass
1115
- # with io.StringIO() as buf:
1116
- # console = Console(file=buf, record=True)
1117
-
1118
- # for index, result in enumerate(self):
1119
- # console.print(f"Result {index}")
1120
- # console.print(result.rich_print())
1121
-
1122
- # return console.export_text()
1123
1034
 
1124
1035
  def __str__(self):
1125
1036
  data = self.to_dict()["data"]
edsl/results/ResultsDBMixin.py CHANGED
@@ -93,7 +93,7 @@ class ResultsDBMixin:
93
93
  from sqlalchemy import create_engine
94
94
 
95
95
  engine = create_engine("sqlite:///:memory:")
96
- df = self.to_pandas(remove_prefix=remove_prefix)
96
+ df = self.to_pandas(remove_prefix=remove_prefix, lists_as_strings=True)
97
97
  df.to_sql("self", engine, index=False, if_exists="replace")
98
98
  return engine.connect()
99
99
  else:
edsl/results/Selector.py CHANGED
@@ -12,6 +12,7 @@ class Selector:
12
12
  fetch_list_func,
13
13
  columns: List[str],
14
14
  ):
15
+ """Selects columns from a Results object"""
15
16
  self.known_data_types = known_data_types
16
17
  self._data_type_to_keys = data_type_to_keys
17
18
  self._key_to_data_type = key_to_data_type
@@ -21,10 +22,19 @@ class Selector:
21
22
  def select(self, *columns: Union[str, List[str]]) -> "Dataset":
22
23
  columns = self._normalize_columns(columns)
23
24
  to_fetch = self._get_columns_to_fetch(columns)
25
+ # breakpoint()
24
26
  new_data = self._fetch_data(to_fetch)
25
27
  return Dataset(new_data)
26
28
 
27
29
  def _normalize_columns(self, columns: Union[str, List[str]]) -> tuple:
30
+ """Normalize the columns to a tuple of strings
31
+
32
+ >>> s = Selector([], {}, {}, lambda x, y: x, [])
33
+ >>> s._normalize_columns([["a", "b"], ])
34
+ ('a', 'b')
35
+ >>> s._normalize_columns(None)
36
+ ('*.*',)
37
+ """
28
38
  if not columns or columns == ("*",) or columns == (None,):
29
39
  return ("*.*",)
30
40
  if isinstance(columns[0], list):
@@ -37,6 +47,7 @@ class Selector:
37
47
 
38
48
  for column in columns:
39
49
  matches = self._find_matching_columns(column)
50
+ # breakpoint()
40
51
  self._validate_matches(column, matches)
41
52
 
42
53
  if len(matches) == 1:
@@ -52,7 +63,7 @@ class Selector:
52
63
  search_in_list = self.columns
53
64
  else:
54
65
  search_in_list = [s.split(".")[1] for s in self.columns]
55
-
66
+ # breakpoint()
56
67
  matches = [s for s in search_in_list if s.startswith(partial_name)]
57
68
  return [partial_name] if partial_name in matches else matches
58
69
 
@@ -116,3 +127,9 @@ class Selector:
116
127
  new_data.append({f"{data_type}.{key}": entries})
117
128
 
118
129
  return [d for key in self.items_in_order for d in new_data if key in d]
130
+
131
+
132
+ if __name__ == "__main__":
133
+ import doctest
134
+
135
+ doctest.testmod()
edsl/scenarios/Scenario.py CHANGED
@@ -11,18 +11,26 @@ from uuid import uuid4
11
11
  from edsl.Base import Base
12
12
  from edsl.scenarios.ScenarioHtmlMixin import ScenarioHtmlMixin
13
13
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
14
+ from edsl.exceptions.scenarios import ScenarioError
14
15
 
15
16
 
16
17
  class Scenario(Base, UserDict, ScenarioHtmlMixin):
17
18
  """A Scenario is a dictionary of keys/values.
18
19
 
19
- They can be used parameterize edsl questions."""
20
+ They can be used parameterize EDSL questions."""
21
+
22
+ __doc__ = "https://docs.expectedparrot.com/en/latest/scenarios.html"
20
23
 
21
24
  def __init__(self, data: Union[dict, None] = None, name: str = None):
22
25
  """Initialize a new Scenario.
23
26
 
24
- :param data: A dictionary of keys/values for parameterizing questions.
25
- """
27
+ # :param data: A dictionary of keys/values for parameterizing questions.
28
+ #"""
29
+ if not isinstance(data, dict) and data is not None:
30
+ raise EDSLScenarioError(
31
+ "You must pass in a dictionary to initialize a Scenario."
32
+ )
33
+
26
34
  self.data = data if data is not None else {}
27
35
  self.name = name
28
36
 
@@ -41,13 +49,6 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
41
49
 
42
50
  return ScenarioList([copy.deepcopy(self) for _ in range(n)])
43
51
 
44
- # @property
45
- # def has_image(self) -> bool:
46
- # """Return whether the scenario has an image."""
47
- # if not hasattr(self, "_has_image"):
48
- # self._has_image = False
49
- # return self._has_image
50
-
51
52
  @property
52
53
  def has_jinja_braces(self) -> bool:
53
54
  """Return whether the scenario has jinja braces. This matters for rendering.
@@ -106,7 +107,9 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
106
107
  s = Scenario(data1 | data2)
107
108
  return s
108
109
 
109
- def rename(self, replacement_dict: dict) -> "Scenario":
110
+ def rename(
111
+ self, old_name_or_replacement_dict: dict, new_name: Optional[str] = None
112
+ ) -> "Scenario":
110
113
  """Rename the keys of a scenario.
111
114
 
112
115
  :param replacement_dict: A dictionary of old keys to new keys.
@@ -116,7 +119,16 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
116
119
  >>> s = Scenario({"food": "wood chips"})
117
120
  >>> s.rename({"food": "food_preference"})
118
121
  Scenario({'food_preference': 'wood chips'})
122
+
123
+ >>> s = Scenario({"food": "wood chips"})
124
+ >>> s.rename("food", "snack")
125
+ Scenario({'snack': 'wood chips'})
119
126
  """
127
+ if isinstance(old_name_or_replacement_dict, str) and new_name is not None:
128
+ replacement_dict = {old_name_or_replacement_dict: new_name}
129
+ else:
130
+ replacement_dict = old_name_or_replacement_dict
131
+
120
132
  new_scenario = Scenario()
121
133
  for key, value in self.items():
122
134
  if key in replacement_dict:
@@ -216,6 +228,20 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
216
228
  new_scenario[key] = self[key]
217
229
  return new_scenario
218
230
 
231
+ def keep(self, list_of_keys: List[str]) -> "Scenario":
232
+ """Keep a subset of keys from a scenario.
233
+
234
+ :param list_of_keys: The keys to keep.
235
+
236
+ Example:
237
+
238
+ >>> s = Scenario({"food": "wood chips", "drink": "water"})
239
+ >>> s.keep(["food"])
240
+ Scenario({'food': 'wood chips'})
241
+ """
242
+
243
+ return self.select(list_of_keys)
244
+
219
245
  @classmethod
220
246
  def from_url(cls, url: str, field_name: Optional[str] = "text") -> "Scenario":
221
247
  """Creates a scenario from a URL.
@@ -231,7 +257,17 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
231
257
 
232
258
  @classmethod
233
259
  def from_file(cls, file_path: str, field_name: str) -> "Scenario":
234
- """Creates a scenario from a file."""
260
+ """Creates a scenario from a file.
261
+
262
+ >>> import tempfile
263
+ >>> with tempfile.NamedTemporaryFile(suffix=".txt", mode="w") as f:
264
+ ... _ = f.write("This is a test.")
265
+ ... _ = f.flush()
266
+ ... s = Scenario.from_file(f.name, "file")
267
+ >>> s
268
+ Scenario({'file': FileStore(path='...')})
269
+
270
+ """
235
271
  from edsl.scenarios.FileStore import FileStore
236
272
 
237
273
  fs = FileStore(file_path)
edsl/scenarios/ScenarioHtmlMixin.py CHANGED
@@ -1,19 +1,24 @@
1
1
  import requests
2
+ from typing import Optional
2
3
  from requests.adapters import HTTPAdapter
3
4
  from requests.packages.urllib3.util.retry import Retry
4
5
 
5
6
 
6
7
  class ScenarioHtmlMixin:
7
8
  @classmethod
8
- def from_html(cls, url: str) -> "Scenario":
9
+ def from_html(cls, url: str, field_name: Optional[str] = None) -> "Scenario":
9
10
  """Create a scenario from HTML content.
10
11
 
11
12
  :param html: The HTML content.
13
+ :param field_name: The name of the field containing the HTML content.
14
+
12
15
 
13
16
  """
14
17
  html = cls.fetch_html(url)
15
18
  text = cls.extract_text(html)
16
- return cls({"url": url, "html": html, "text": text})
19
+ if not field_name:
20
+ field_name = "text"
21
+ return cls({"url": url, "html": html, field_name: text})
17
22
 
18
23
  def fetch_html(url):
19
24
  # Define the user-agent to mimic a browser
edsl/scenarios/ScenarioList.py CHANGED
@@ -538,6 +538,17 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
538
538
  """
539
539
  return ScenarioList([scenario.drop(fields) for scenario in self.data])
540
540
 
541
+ def keep(self, *fields) -> ScenarioList:
542
+ """Keep only the specified fields in the scenarios.
543
+
544
+ Example:
545
+
546
+ >>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
547
+ >>> s.keep('a')
548
+ ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
549
+ """
550
+ return ScenarioList([scenario.keep(fields) for scenario in self.data])
551
+
541
552
  @classmethod
542
553
  def from_list(
543
554
  cls, name: str, values: list, func: Optional[Callable] = None
@@ -1050,7 +1061,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
1050
1061
  elif isinstance(key, int):
1051
1062
  return super().__getitem__(key)
1052
1063
  else:
1053
- return self.to_dict()[key]
1064
+ return self._to_dict()[key]
1054
1065
 
1055
1066
  def to_agent_list(self):
1056
1067
  """Convert the ScenarioList to an AgentList.
edsl/surveys/Rule.py CHANGED
@@ -25,6 +25,8 @@ from jinja2 import Template
25
25
  from rich import print
26
26
  from simpleeval import EvalWithCompoundTypes
27
27
 
28
+ from edsl.exceptions.surveys import SurveyError
29
+
28
30
  from edsl.exceptions import (
29
31
  SurveyRuleCannotEvaluateError,
30
32
  SurveyRuleCollectionHasNoRulesAtNodeError,
@@ -47,11 +49,11 @@ class QuestionIndex:
47
49
 
48
50
  def __set__(self, obj, value):
49
51
  if not isinstance(value, (int, EndOfSurvey.__class__)):
50
- raise ValueError(f"{self.name} must be an integer or EndOfSurvey")
52
+ raise SurveyError(f"{self.name} must be an integer or EndOfSurvey")
51
53
  if self.name == "_next_q" and isinstance(value, int):
52
54
  current_q = getattr(obj, "_current_q")
53
55
  if value <= current_q:
54
- raise ValueError("next_q must be greater than current_q")
56
+ raise SurveyError("next_q must be greater than current_q")
55
57
  setattr(obj, self.name, value)
56
58
 
57
59
 
@@ -100,13 +102,17 @@ class Rule:
100
102
  raise SurveyRuleSendsYouBackwardsError
101
103
 
102
104
  if not self.next_q == EndOfSurvey and self.current_q > self.next_q:
103
- raise SurveyRuleSendsYouBackwardsError
105
+ raise SurveyRuleSendsYouBackwardsError(
106
+ f"current_q: {self.current_q}, next_q: {self.next_q}"
107
+ )
104
108
 
105
109
  # get the AST for the expression - used to extract the variables referenced in the expression
106
110
  try:
107
111
  self.ast_tree = ast.parse(self.expression)
108
112
  except SyntaxError:
109
- raise SurveyRuleSkipLogicSyntaxError
113
+ raise SurveyRuleSkipLogicSyntaxError(
114
+ f"The expression {self.expression} is not valid Python syntax."
115
+ )
110
116
 
111
117
  # get the names of the variables in the expression
112
118
  # e.g., q1 == 'yes' -> ['q1']