edsl 0.1.37__py3-none-any.whl → 0.1.37.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Agent.py +35 -86
  3. edsl/agents/AgentList.py +0 -5
  4. edsl/agents/InvigilatorBase.py +23 -2
  5. edsl/agents/PromptConstructor.py +105 -148
  6. edsl/agents/descriptors.py +4 -17
  7. edsl/conjure/AgentConstructionMixin.py +3 -11
  8. edsl/conversation/Conversation.py +14 -66
  9. edsl/coop/coop.py +14 -148
  10. edsl/data/Cache.py +1 -1
  11. edsl/exceptions/__init__.py +3 -7
  12. edsl/exceptions/agents.py +19 -17
  13. edsl/exceptions/results.py +8 -11
  14. edsl/exceptions/surveys.py +10 -13
  15. edsl/inference_services/AwsBedrock.py +2 -7
  16. edsl/inference_services/InferenceServicesCollection.py +9 -32
  17. edsl/jobs/Jobs.py +71 -306
  18. edsl/jobs/interviews/InterviewExceptionEntry.py +1 -5
  19. edsl/jobs/tasks/TaskHistory.py +0 -1
  20. edsl/language_models/LanguageModel.py +59 -47
  21. edsl/language_models/__init__.py +0 -1
  22. edsl/prompts/Prompt.py +4 -11
  23. edsl/questions/QuestionBase.py +13 -53
  24. edsl/questions/QuestionBasePromptsMixin.py +33 -1
  25. edsl/questions/QuestionFreeText.py +0 -1
  26. edsl/questions/QuestionFunctional.py +2 -2
  27. edsl/questions/descriptors.py +28 -23
  28. edsl/results/DatasetExportMixin.py +1 -25
  29. edsl/results/Result.py +1 -16
  30. edsl/results/Results.py +120 -31
  31. edsl/results/ResultsDBMixin.py +1 -1
  32. edsl/results/Selector.py +1 -18
  33. edsl/scenarios/Scenario.py +12 -48
  34. edsl/scenarios/ScenarioHtmlMixin.py +2 -7
  35. edsl/scenarios/ScenarioList.py +1 -12
  36. edsl/surveys/Rule.py +4 -10
  37. edsl/surveys/Survey.py +77 -100
  38. edsl/utilities/utilities.py +0 -18
  39. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/METADATA +1 -1
  40. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/RECORD +42 -46
  41. edsl/conversation/chips.py +0 -95
  42. edsl/exceptions/BaseException.py +0 -21
  43. edsl/exceptions/scenarios.py +0 -22
  44. edsl/language_models/KeyLookup.py +0 -30
  45. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/LICENSE +0 -0
  46. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/WHEEL +0 -0
edsl/results/Results.py CHANGED
@@ -7,17 +7,11 @@ from __future__ import annotations
7
7
  import json
8
8
  import random
9
9
  from collections import UserList, defaultdict
10
- from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
11
-
12
- if TYPE_CHECKING:
13
- from edsl import Survey, Cache, AgentList, ModelList, ScenarioList
14
- from edsl.results.Result import Result
15
- from edsl.jobs.tasks.TaskHistory import TaskHistory
10
+ from typing import Optional, Callable, Any, Type, Union, List
16
11
 
17
12
  from simpleeval import EvalWithCompoundTypes
18
13
 
19
14
  from edsl.exceptions.results import (
20
- ResultsError,
21
15
  ResultsBadMutationstringError,
22
16
  ResultsColumnNotFoundError,
23
17
  ResultsInvalidNameError,
@@ -46,7 +40,7 @@ class Mixins(
46
40
  ResultsGGMixin,
47
41
  ResultsToolsMixin,
48
42
  ):
49
- def print_long(self, max_rows: int = None) -> None:
43
+ def print_long(self, max_rows=None) -> None:
50
44
  """Print the results in long format.
51
45
 
52
46
  >>> from edsl.results import Results
@@ -90,13 +84,13 @@ class Results(UserList, Mixins, Base):
90
84
 
91
85
  def __init__(
92
86
  self,
93
- survey: Optional[Survey] = None,
94
- data: Optional[list[Result]] = None,
87
+ survey: Optional["Survey"] = None,
88
+ data: Optional[list["Result"]] = None,
95
89
  created_columns: Optional[list[str]] = None,
96
- cache: Optional[Cache] = None,
90
+ cache: Optional["Cache"] = None,
97
91
  job_uuid: Optional[str] = None,
98
92
  total_results: Optional[int] = None,
99
- task_history: Optional[TaskHistory] = None,
93
+ task_history: Optional["TaskHistory"] = None,
100
94
  ):
101
95
  """Instantiate a `Results` object with a survey and a list of `Result` objects.
102
96
 
@@ -241,11 +235,11 @@ class Results(UserList, Mixins, Base):
241
235
  >>> r3 = r + r2
242
236
  """
243
237
  if self.survey != other.survey:
244
- raise ResultsError(
245
- "The surveys are not the same so the the results cannot be added together."
238
+ raise Exception(
239
+ "The surveys are not the same so they cannot be added together."
246
240
  )
247
241
  if self.created_columns != other.created_columns:
248
- raise ResultsError(
242
+ raise Exception(
249
243
  "The created columns are not the same so they cannot be added together."
250
244
  )
251
245
 
@@ -264,6 +258,17 @@ class Results(UserList, Mixins, Base):
264
258
  from IPython.display import HTML
265
259
 
266
260
  json_str = json.dumps(self.to_dict()["data"], indent=4)
261
+ # from pygments import highlight
262
+ # from pygments.lexers import JsonLexer
263
+ # 3from pygments.formatters import HtmlFormatter
264
+
265
+ # formatted_json = highlight(
266
+ # json_str,
267
+ # JsonLexer(),
268
+ # HtmlFormatter(style="default", full=True, noclasses=True),
269
+ # )
270
+ # return HTML(formatted_json).data
271
+ # print(json_str)
267
272
  return f"<pre>{json_str}</pre>"
268
273
 
269
274
  def _to_dict(self, sort=False):
@@ -323,7 +328,7 @@ class Results(UserList, Mixins, Base):
323
328
  def hashes(self) -> set:
324
329
  return set(hash(result) for result in self.data)
325
330
 
326
- def sample(self, n: int) -> Results:
331
+ def sample(self, n: int) -> "Results":
327
332
  """Return a random sample of the results.
328
333
 
329
334
  :param n: The number of samples to return.
@@ -341,7 +346,7 @@ class Results(UserList, Mixins, Base):
341
346
  indices = list(range(len(values)))
342
347
  sampled_indices = random.sample(indices, n)
343
348
  if n > len(indices):
344
- raise ResultsError(
349
+ raise ValueError(
345
350
  f"Cannot sample {n} items from a list of length {len(indices)}."
346
351
  )
347
352
  entry[key] = [values[i] for i in sampled_indices]
@@ -394,12 +399,11 @@ class Results(UserList, Mixins, Base):
394
399
  - Uses the key_to_data_type property of the Result class.
395
400
  - Includes any columns that the user has created with `mutate`
396
401
  """
397
- d: dict = {}
402
+ d = {}
398
403
  for result in self.data:
399
404
  d.update(result.key_to_data_type)
400
405
  for column in self.created_columns:
401
406
  d[column] = "answer"
402
-
403
407
  return d
404
408
 
405
409
  @property
@@ -449,7 +453,7 @@ class Results(UserList, Mixins, Base):
449
453
  from edsl.utilities.utilities import shorten_string
450
454
 
451
455
  if not self.survey:
452
- raise ResultsError("Survey is not defined so no answer keys are available.")
456
+ raise Exception("Survey is not defined so no answer keys are available.")
453
457
 
454
458
  answer_keys = self._data_type_to_keys["answer"]
455
459
  answer_keys = {k for k in answer_keys if "_comment" not in k}
@@ -462,7 +466,7 @@ class Results(UserList, Mixins, Base):
462
466
  return sorted_dict
463
467
 
464
468
  @property
465
- def agents(self) -> AgentList:
469
+ def agents(self) -> "AgentList":
466
470
  """Return a list of all of the agents in the Results.
467
471
 
468
472
  Example:
@@ -476,7 +480,7 @@ class Results(UserList, Mixins, Base):
476
480
  return AgentList([r.agent for r in self.data])
477
481
 
478
482
  @property
479
- def models(self) -> ModelList:
483
+ def models(self) -> list[Type["LanguageModel"]]:
480
484
  """Return a list of all of the models in the Results.
481
485
 
482
486
  Example:
@@ -485,12 +489,10 @@ class Results(UserList, Mixins, Base):
485
489
  >>> r.models[0]
486
490
  Model(model_name = ...)
487
491
  """
488
- from edsl import ModelList
489
-
490
- return ModelList([r.model for r in self.data])
492
+ return [r.model for r in self.data]
491
493
 
492
494
  @property
493
- def scenarios(self) -> ScenarioList:
495
+ def scenarios(self) -> "ScenarioList":
494
496
  """Return a list of all of the scenarios in the Results.
495
497
 
496
498
  Example:
@@ -567,7 +569,7 @@ class Results(UserList, Mixins, Base):
567
569
  )
568
570
  return sorted(list(all_keys))
569
571
 
570
- def first(self) -> Result:
572
+ def first(self) -> "Result":
571
573
  """Return the first observation in the results.
572
574
 
573
575
  Example:
@@ -817,7 +819,7 @@ class Results(UserList, Mixins, Base):
817
819
 
818
820
  return Results(survey=self.survey, data=new_data, created_columns=None)
819
821
 
820
- def select(self, *columns: Union[str, list[str]]) -> Results:
822
+ def select(self, *columns: Union[str, list[str]]) -> "Dataset":
821
823
  """
822
824
  Select data from the results and format it.
823
825
 
@@ -830,12 +832,93 @@ class Results(UserList, Mixins, Base):
830
832
  Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
831
833
 
832
834
  >>> results.select('how_feeling', 'model', 'how_feeling')
833
- Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
835
+ Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
834
836
 
835
837
  >>> from edsl import Results; r = Results.example(); r.select('answer.how_feeling_y')
836
838
  Dataset([{'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
837
839
  """
838
840
 
841
+ # if len(self) == 0:
842
+ # raise Exception("No data to select from---the Results object is empty.")
843
+
844
+ if not columns or columns == ("*",) or columns == (None,):
845
+ # is the users passes nothing, then we'll return all the columns
846
+ columns = ("*.*",)
847
+
848
+ if isinstance(columns[0], list):
849
+ columns = tuple(columns[0])
850
+
851
+ def get_data_types_to_return(parsed_data_type):
852
+ if parsed_data_type == "*": # they want all of the columns
853
+ return self.known_data_types
854
+ else:
855
+ if parsed_data_type not in self.known_data_types:
856
+ raise Exception(
857
+ f"Data type {parsed_data_type} not found in data. Did you mean one of {self.known_data_types}"
858
+ )
859
+ return [parsed_data_type]
860
+
861
+ # we're doing to populate this with the data we want to fetch
862
+ to_fetch = defaultdict(list)
863
+
864
+ new_data = []
865
+ items_in_order = []
866
+ # iterate through the passed columns
867
+ for column in columns:
868
+ # a user could pass 'result.how_feeling' or just 'how_feeling'
869
+ matches = self._matching_columns(column)
870
+ if len(matches) > 1:
871
+ raise Exception(
872
+ f"Column '{column}' is ambiguous. Did you mean one of {matches}?"
873
+ )
874
+ if len(matches) == 0 and ".*" not in column:
875
+ raise Exception(f"Column '{column}' not found in data.")
876
+ if len(matches) == 1:
877
+ column = matches[0]
878
+
879
+ parsed_data_type, parsed_key = self._parse_column(column)
880
+ data_types = get_data_types_to_return(parsed_data_type)
881
+ found_once = False # we need to track this to make sure we found the key at least once
882
+
883
+ for data_type in data_types:
884
+ # the keys for that data_type e.g.,# if data_type is 'answer', then the keys are 'how_feeling', 'how_feeling_comment', etc.
885
+ relevant_keys = self._data_type_to_keys[data_type]
886
+
887
+ for key in relevant_keys:
888
+ if key == parsed_key or parsed_key == "*":
889
+ found_once = True
890
+ to_fetch[data_type].append(key)
891
+ items_in_order.append(data_type + "." + key)
892
+
893
+ if not found_once:
894
+ raise Exception(f"Key {parsed_key} not found in data.")
895
+
896
+ for data_type in to_fetch:
897
+ for key in to_fetch[data_type]:
898
+ entries = self._fetch_list(data_type, key)
899
+ new_data.append({data_type + "." + key: entries})
900
+
901
+ def sort_by_key_order(dictionary):
902
+ # Extract the single key from the dictionary
903
+ single_key = next(iter(dictionary))
904
+ # Return the index of this key in the list_of_keys
905
+ return items_in_order.index(single_key)
906
+
907
+ # sorted(new_data, key=sort_by_key_order)
908
+ from edsl.results.Dataset import Dataset
909
+
910
+ sorted_new_data = []
911
+
912
+ # WORKS but slow
913
+ for key in items_in_order:
914
+ for d in new_data:
915
+ if key in d:
916
+ sorted_new_data.append(d)
917
+ break
918
+
919
+ return Dataset(sorted_new_data)
920
+
921
+ def select(self, *columns: Union[str, list[str]]) -> "Results":
839
922
  from edsl.results.Selector import Selector
840
923
 
841
924
  if len(self) == 0:
@@ -945,7 +1028,6 @@ class Results(UserList, Mixins, Base):
945
1028
  Traceback (most recent call last):
946
1029
  ...
947
1030
  edsl.exceptions.results.ResultsFilterError: You must use '==' instead of '=' in the filter expression.
948
- ...
949
1031
 
950
1032
  >>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling').print()
951
1033
  ┏━━━━━━━━━━━━━━┓
@@ -1023,7 +1105,6 @@ class Results(UserList, Mixins, Base):
1023
1105
  stop_on_exception=True,
1024
1106
  skip_retry=True,
1025
1107
  raise_validation_errors=True,
1026
- disable_remote_cache=True,
1027
1108
  disable_remote_inference=True,
1028
1109
  )
1029
1110
  return results
@@ -1031,6 +1112,14 @@ class Results(UserList, Mixins, Base):
1031
1112
  def rich_print(self):
1032
1113
  """Display an object as a table."""
1033
1114
  pass
1115
+ # with io.StringIO() as buf:
1116
+ # console = Console(file=buf, record=True)
1117
+
1118
+ # for index, result in enumerate(self):
1119
+ # console.print(f"Result {index}")
1120
+ # console.print(result.rich_print())
1121
+
1122
+ # return console.export_text()
1034
1123
 
1035
1124
  def __str__(self):
1036
1125
  data = self.to_dict()["data"]
@@ -93,7 +93,7 @@ class ResultsDBMixin:
93
93
  from sqlalchemy import create_engine
94
94
 
95
95
  engine = create_engine("sqlite:///:memory:")
96
- df = self.to_pandas(remove_prefix=remove_prefix, lists_as_strings=True)
96
+ df = self.to_pandas(remove_prefix=remove_prefix)
97
97
  df.to_sql("self", engine, index=False, if_exists="replace")
98
98
  return engine.connect()
99
99
  else:
edsl/results/Selector.py CHANGED
@@ -12,7 +12,6 @@ class Selector:
12
12
  fetch_list_func,
13
13
  columns: List[str],
14
14
  ):
15
- """Selects columns from a Results object"""
16
15
  self.known_data_types = known_data_types
17
16
  self._data_type_to_keys = data_type_to_keys
18
17
  self._key_to_data_type = key_to_data_type
@@ -22,19 +21,10 @@ class Selector:
22
21
  def select(self, *columns: Union[str, List[str]]) -> "Dataset":
23
22
  columns = self._normalize_columns(columns)
24
23
  to_fetch = self._get_columns_to_fetch(columns)
25
- # breakpoint()
26
24
  new_data = self._fetch_data(to_fetch)
27
25
  return Dataset(new_data)
28
26
 
29
27
  def _normalize_columns(self, columns: Union[str, List[str]]) -> tuple:
30
- """Normalize the columns to a tuple of strings
31
-
32
- >>> s = Selector([], {}, {}, lambda x, y: x, [])
33
- >>> s._normalize_columns([["a", "b"], ])
34
- ('a', 'b')
35
- >>> s._normalize_columns(None)
36
- ('*.*',)
37
- """
38
28
  if not columns or columns == ("*",) or columns == (None,):
39
29
  return ("*.*",)
40
30
  if isinstance(columns[0], list):
@@ -47,7 +37,6 @@ class Selector:
47
37
 
48
38
  for column in columns:
49
39
  matches = self._find_matching_columns(column)
50
- # breakpoint()
51
40
  self._validate_matches(column, matches)
52
41
 
53
42
  if len(matches) == 1:
@@ -63,7 +52,7 @@ class Selector:
63
52
  search_in_list = self.columns
64
53
  else:
65
54
  search_in_list = [s.split(".")[1] for s in self.columns]
66
- # breakpoint()
55
+
67
56
  matches = [s for s in search_in_list if s.startswith(partial_name)]
68
57
  return [partial_name] if partial_name in matches else matches
69
58
 
@@ -127,9 +116,3 @@ class Selector:
127
116
  new_data.append({f"{data_type}.{key}": entries})
128
117
 
129
118
  return [d for key in self.items_in_order for d in new_data if key in d]
130
-
131
-
132
- if __name__ == "__main__":
133
- import doctest
134
-
135
- doctest.testmod()
@@ -11,26 +11,18 @@ from uuid import uuid4
11
11
  from edsl.Base import Base
12
12
  from edsl.scenarios.ScenarioHtmlMixin import ScenarioHtmlMixin
13
13
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
14
- from edsl.exceptions.scenarios import ScenarioError
15
14
 
16
15
 
17
16
  class Scenario(Base, UserDict, ScenarioHtmlMixin):
18
17
  """A Scenario is a dictionary of keys/values.
19
18
 
20
- They can be used parameterize EDSL questions."""
21
-
22
- __doc__ = "https://docs.expectedparrot.com/en/latest/scenarios.html"
19
+ They can be used parameterize edsl questions."""
23
20
 
24
21
  def __init__(self, data: Union[dict, None] = None, name: str = None):
25
22
  """Initialize a new Scenario.
26
23
 
27
- # :param data: A dictionary of keys/values for parameterizing questions.
28
- #"""
29
- if not isinstance(data, dict) and data is not None:
30
- raise EDSLScenarioError(
31
- "You must pass in a dictionary to initialize a Scenario."
32
- )
33
-
24
+ :param data: A dictionary of keys/values for parameterizing questions.
25
+ """
34
26
  self.data = data if data is not None else {}
35
27
  self.name = name
36
28
 
@@ -49,6 +41,13 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
49
41
 
50
42
  return ScenarioList([copy.deepcopy(self) for _ in range(n)])
51
43
 
44
+ # @property
45
+ # def has_image(self) -> bool:
46
+ # """Return whether the scenario has an image."""
47
+ # if not hasattr(self, "_has_image"):
48
+ # self._has_image = False
49
+ # return self._has_image
50
+
52
51
  @property
53
52
  def has_jinja_braces(self) -> bool:
54
53
  """Return whether the scenario has jinja braces. This matters for rendering.
@@ -107,9 +106,7 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
107
106
  s = Scenario(data1 | data2)
108
107
  return s
109
108
 
110
- def rename(
111
- self, old_name_or_replacement_dict: dict, new_name: Optional[str] = None
112
- ) -> "Scenario":
109
+ def rename(self, replacement_dict: dict) -> "Scenario":
113
110
  """Rename the keys of a scenario.
114
111
 
115
112
  :param replacement_dict: A dictionary of old keys to new keys.
@@ -119,16 +116,7 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
119
116
  >>> s = Scenario({"food": "wood chips"})
120
117
  >>> s.rename({"food": "food_preference"})
121
118
  Scenario({'food_preference': 'wood chips'})
122
-
123
- >>> s = Scenario({"food": "wood chips"})
124
- >>> s.rename("food", "snack")
125
- Scenario({'snack': 'wood chips'})
126
119
  """
127
- if isinstance(old_name_or_replacement_dict, str) and new_name is not None:
128
- replacement_dict = {old_name_or_replacement_dict: new_name}
129
- else:
130
- replacement_dict = old_name_or_replacement_dict
131
-
132
120
  new_scenario = Scenario()
133
121
  for key, value in self.items():
134
122
  if key in replacement_dict:
@@ -228,20 +216,6 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
228
216
  new_scenario[key] = self[key]
229
217
  return new_scenario
230
218
 
231
- def keep(self, list_of_keys: List[str]) -> "Scenario":
232
- """Keep a subset of keys from a scenario.
233
-
234
- :param list_of_keys: The keys to keep.
235
-
236
- Example:
237
-
238
- >>> s = Scenario({"food": "wood chips", "drink": "water"})
239
- >>> s.keep(["food"])
240
- Scenario({'food': 'wood chips'})
241
- """
242
-
243
- return self.select(list_of_keys)
244
-
245
219
  @classmethod
246
220
  def from_url(cls, url: str, field_name: Optional[str] = "text") -> "Scenario":
247
221
  """Creates a scenario from a URL.
@@ -257,17 +231,7 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
257
231
 
258
232
  @classmethod
259
233
  def from_file(cls, file_path: str, field_name: str) -> "Scenario":
260
- """Creates a scenario from a file.
261
-
262
- >>> import tempfile
263
- >>> with tempfile.NamedTemporaryFile(suffix=".txt", mode="w") as f:
264
- ... _ = f.write("This is a test.")
265
- ... _ = f.flush()
266
- ... s = Scenario.from_file(f.name, "file")
267
- >>> s
268
- Scenario({'file': FileStore(path='...')})
269
-
270
- """
234
+ """Creates a scenario from a file."""
271
235
  from edsl.scenarios.FileStore import FileStore
272
236
 
273
237
  fs = FileStore(file_path)
@@ -1,24 +1,19 @@
1
1
  import requests
2
- from typing import Optional
3
2
  from requests.adapters import HTTPAdapter
4
3
  from requests.packages.urllib3.util.retry import Retry
5
4
 
6
5
 
7
6
  class ScenarioHtmlMixin:
8
7
  @classmethod
9
- def from_html(cls, url: str, field_name: Optional[str] = None) -> "Scenario":
8
+ def from_html(cls, url: str) -> "Scenario":
10
9
  """Create a scenario from HTML content.
11
10
 
12
11
  :param html: The HTML content.
13
- :param field_name: The name of the field containing the HTML content.
14
-
15
12
 
16
13
  """
17
14
  html = cls.fetch_html(url)
18
15
  text = cls.extract_text(html)
19
- if not field_name:
20
- field_name = "text"
21
- return cls({"url": url, "html": html, field_name: text})
16
+ return cls({"url": url, "html": html, "text": text})
22
17
 
23
18
  def fetch_html(url):
24
19
  # Define the user-agent to mimic a browser
@@ -538,17 +538,6 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
538
538
  """
539
539
  return ScenarioList([scenario.drop(fields) for scenario in self.data])
540
540
 
541
- def keep(self, *fields) -> ScenarioList:
542
- """Keep only the specified fields in the scenarios.
543
-
544
- Example:
545
-
546
- >>> s = ScenarioList([Scenario({'a': 1, 'b': 1}), Scenario({'a': 1, 'b': 2})])
547
- >>> s.keep('a')
548
- ScenarioList([Scenario({'a': 1}), Scenario({'a': 1})])
549
- """
550
- return ScenarioList([scenario.keep(fields) for scenario in self.data])
551
-
552
541
  @classmethod
553
542
  def from_list(
554
543
  cls, name: str, values: list, func: Optional[Callable] = None
@@ -1061,7 +1050,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
1061
1050
  elif isinstance(key, int):
1062
1051
  return super().__getitem__(key)
1063
1052
  else:
1064
- return self._to_dict()[key]
1053
+ return self.to_dict()[key]
1065
1054
 
1066
1055
  def to_agent_list(self):
1067
1056
  """Convert the ScenarioList to an AgentList.
edsl/surveys/Rule.py CHANGED
@@ -25,8 +25,6 @@ from jinja2 import Template
25
25
  from rich import print
26
26
  from simpleeval import EvalWithCompoundTypes
27
27
 
28
- from edsl.exceptions.surveys import SurveyError
29
-
30
28
  from edsl.exceptions import (
31
29
  SurveyRuleCannotEvaluateError,
32
30
  SurveyRuleCollectionHasNoRulesAtNodeError,
@@ -49,11 +47,11 @@ class QuestionIndex:
49
47
 
50
48
  def __set__(self, obj, value):
51
49
  if not isinstance(value, (int, EndOfSurvey.__class__)):
52
- raise SurveyError(f"{self.name} must be an integer or EndOfSurvey")
50
+ raise ValueError(f"{self.name} must be an integer or EndOfSurvey")
53
51
  if self.name == "_next_q" and isinstance(value, int):
54
52
  current_q = getattr(obj, "_current_q")
55
53
  if value <= current_q:
56
- raise SurveyError("next_q must be greater than current_q")
54
+ raise ValueError("next_q must be greater than current_q")
57
55
  setattr(obj, self.name, value)
58
56
 
59
57
 
@@ -102,17 +100,13 @@ class Rule:
102
100
  raise SurveyRuleSendsYouBackwardsError
103
101
 
104
102
  if not self.next_q == EndOfSurvey and self.current_q > self.next_q:
105
- raise SurveyRuleSendsYouBackwardsError(
106
- f"current_q: {self.current_q}, next_q: {self.next_q}"
107
- )
103
+ raise SurveyRuleSendsYouBackwardsError
108
104
 
109
105
  # get the AST for the expression - used to extract the variables referenced in the expression
110
106
  try:
111
107
  self.ast_tree = ast.parse(self.expression)
112
108
  except SyntaxError:
113
- raise SurveyRuleSkipLogicSyntaxError(
114
- f"The expression {self.expression} is not valid Python syntax."
115
- )
109
+ raise SurveyRuleSkipLogicSyntaxError
116
110
 
117
111
  # get the names of the variables in the expression
118
112
  # e.g., q1 == 'yes' -> ['q1']