edsl 0.1.27.dev2__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. edsl/Base.py +99 -22
  2. edsl/BaseDiff.py +260 -0
  3. edsl/__init__.py +4 -0
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +26 -5
  6. edsl/agents/AgentList.py +62 -7
  7. edsl/agents/Invigilator.py +4 -9
  8. edsl/agents/InvigilatorBase.py +5 -5
  9. edsl/agents/descriptors.py +3 -1
  10. edsl/conjure/AgentConstructionMixin.py +152 -0
  11. edsl/conjure/Conjure.py +56 -0
  12. edsl/conjure/InputData.py +628 -0
  13. edsl/conjure/InputDataCSV.py +48 -0
  14. edsl/conjure/InputDataMixinQuestionStats.py +182 -0
  15. edsl/conjure/InputDataPyRead.py +91 -0
  16. edsl/conjure/InputDataSPSS.py +8 -0
  17. edsl/conjure/InputDataStata.py +8 -0
  18. edsl/conjure/QuestionOptionMixin.py +76 -0
  19. edsl/conjure/QuestionTypeMixin.py +23 -0
  20. edsl/conjure/RawQuestion.py +65 -0
  21. edsl/conjure/SurveyResponses.py +7 -0
  22. edsl/conjure/__init__.py +9 -4
  23. edsl/conjure/examples/placeholder.txt +0 -0
  24. edsl/conjure/naming_utilities.py +263 -0
  25. edsl/conjure/utilities.py +165 -28
  26. edsl/conversation/Conversation.py +238 -0
  27. edsl/conversation/car_buying.py +58 -0
  28. edsl/conversation/mug_negotiation.py +81 -0
  29. edsl/conversation/next_speaker_utilities.py +93 -0
  30. edsl/coop/coop.py +191 -12
  31. edsl/coop/utils.py +20 -2
  32. edsl/data/Cache.py +55 -17
  33. edsl/data/CacheHandler.py +10 -9
  34. edsl/inference_services/AnthropicService.py +1 -0
  35. edsl/inference_services/DeepInfraService.py +20 -13
  36. edsl/inference_services/GoogleService.py +7 -1
  37. edsl/inference_services/InferenceServicesCollection.py +33 -7
  38. edsl/inference_services/OpenAIService.py +17 -10
  39. edsl/inference_services/models_available_cache.py +69 -0
  40. edsl/inference_services/rate_limits_cache.py +25 -0
  41. edsl/inference_services/write_available.py +10 -0
  42. edsl/jobs/Jobs.py +240 -36
  43. edsl/jobs/buckets/BucketCollection.py +9 -3
  44. edsl/jobs/interviews/Interview.py +4 -1
  45. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +24 -10
  46. edsl/jobs/interviews/retry_management.py +4 -4
  47. edsl/jobs/runners/JobsRunnerAsyncio.py +87 -45
  48. edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
  49. edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
  50. edsl/language_models/LanguageModel.py +37 -44
  51. edsl/language_models/ModelList.py +96 -0
  52. edsl/language_models/registry.py +14 -0
  53. edsl/language_models/repair.py +95 -24
  54. edsl/notebooks/Notebook.py +119 -31
  55. edsl/questions/QuestionBase.py +109 -12
  56. edsl/questions/descriptors.py +5 -2
  57. edsl/questions/question_registry.py +7 -0
  58. edsl/results/Result.py +20 -8
  59. edsl/results/Results.py +85 -11
  60. edsl/results/ResultsDBMixin.py +3 -6
  61. edsl/results/ResultsExportMixin.py +47 -16
  62. edsl/results/ResultsToolsMixin.py +5 -5
  63. edsl/scenarios/Scenario.py +59 -5
  64. edsl/scenarios/ScenarioList.py +97 -40
  65. edsl/study/ObjectEntry.py +97 -0
  66. edsl/study/ProofOfWork.py +110 -0
  67. edsl/study/SnapShot.py +77 -0
  68. edsl/study/Study.py +491 -0
  69. edsl/study/__init__.py +2 -0
  70. edsl/surveys/Survey.py +79 -31
  71. edsl/surveys/SurveyExportMixin.py +21 -3
  72. edsl/utilities/__init__.py +1 -0
  73. edsl/utilities/gcp_bucket/__init__.py +0 -0
  74. edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
  75. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  76. edsl/utilities/interface.py +24 -28
  77. edsl/utilities/repair_functions.py +28 -0
  78. edsl/utilities/utilities.py +57 -2
  79. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/METADATA +43 -17
  80. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/RECORD +83 -55
  81. edsl-0.1.28.dist-info/entry_points.txt +3 -0
  82. edsl/conjure/RawResponseColumn.py +0 -327
  83. edsl/conjure/SurveyBuilder.py +0 -308
  84. edsl/conjure/SurveyBuilderCSV.py +0 -78
  85. edsl/conjure/SurveyBuilderSPSS.py +0 -118
  86. edsl/data/RemoteDict.py +0 -103
  87. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/LICENSE +0 -0
  88. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/WHEEL +0 -0
edsl/Base.py CHANGED
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod, ABCMeta
4
4
  import gzip
5
5
  import io
6
6
  import json
7
- from typing import Union
7
+ from typing import Any, Optional, Union
8
8
  from uuid import UUID
9
9
  from IPython.display import display
10
10
  from rich.console import Console
@@ -40,25 +40,30 @@ class RichPrintingMixin:
40
40
  class PersistenceMixin:
41
41
  """Mixin for saving and loading objects to and from files."""
42
42
 
43
- def push(self, visibility="unlisted"):
43
+ def push(
44
+ self,
45
+ description: Optional[str] = None,
46
+ visibility: Optional[str] = "unlisted",
47
+ ):
44
48
  """Post the object to coop."""
45
49
  from edsl.coop import Coop
46
50
 
47
51
  c = Coop()
48
- return c.create(self, visibility)
52
+ return c.create(self, description, visibility)
49
53
 
50
54
  @classmethod
51
- def pull(cls, id_or_url: Union[str, UUID]):
55
+ def pull(cls, id_or_url: Union[str, UUID], exec_profile=None):
52
56
  """Pull the object from coop."""
53
57
  from edsl.coop import Coop
54
58
 
59
+ if id_or_url.startswith("http"):
60
+ uuid_value = id_or_url.split("/")[-1]
61
+ else:
62
+ uuid_value = id_or_url
63
+
55
64
  c = Coop()
56
- return c._get_base(cls, id_or_url)
57
- # if isinstance(id_or_url, str) and c.url in id_or_url:
58
- # return c.get(url=id_or_url)
59
- # else:
60
- # _, object_type = c._resolve_edsl_object(cls)
61
- # return c.get(object_type, id_or_url)
65
+
66
+ return c._get_base(cls, uuid_value, exec_profile=exec_profile)
62
67
 
63
68
  @classmethod
64
69
  def delete(cls, id_or_url: Union[str, UUID]):
@@ -69,15 +74,23 @@ class PersistenceMixin:
69
74
  return c._delete_base(cls, id_or_url)
70
75
 
71
76
  @classmethod
72
- def patch(cls, id_or_url: Union[str, UUID], visibility: str):
77
+ def patch(
78
+ cls,
79
+ id_or_url: Union[str, UUID],
80
+ description: Optional[str] = None,
81
+ value: Optional[Any] = None,
82
+ visibility: Optional[str] = None,
83
+ ):
73
84
  """
74
85
  Patch an uploaded objects attributes.
75
- - Only supports changing visibility for now.
86
+ - `description` changes the description of the object on Coop
87
+ - `value` changes the value of the object on Coop. **has to be an EDSL object**
88
+ - `visibility` changes the visibility of the object on Coop
76
89
  """
77
90
  from edsl.coop import Coop
78
91
 
79
92
  c = Coop()
80
- return c._patch_base(cls, id_or_url, visibility)
93
+ return c._patch_base(cls, id_or_url, description, value, visibility)
81
94
 
82
95
  @classmethod
83
96
  def search(cls, query):
@@ -87,14 +100,45 @@ class PersistenceMixin:
87
100
  c = Coop()
88
101
  return c.search(cls, query)
89
102
 
90
- def save(self, filename):
103
+ def save(self, filename, compress=True):
91
104
  """Save the object to a file as zippped JSON.
92
105
 
93
106
  >>> obj.save("obj.json.gz")
94
107
 
95
108
  """
96
- with gzip.open(filename, "wb") as f:
97
- f.write(json.dumps(self.to_dict()).encode("utf-8"))
109
+ if filename.endswith("json.gz"):
110
+ import warnings
111
+
112
+ warnings.warn(
113
+ "Do not apply the file extensions. The filename should not end with 'json.gz'."
114
+ )
115
+ filename = filename[:-7]
116
+ if filename.endswith("json"):
117
+ filename = filename[:-4]
118
+ warnings.warn(
119
+ "Do not apply the file extensions. The filename should not end with 'json'."
120
+ )
121
+
122
+ if compress:
123
+ with gzip.open(filename + ".json.gz", "wb") as f:
124
+ f.write(json.dumps(self.to_dict()).encode("utf-8"))
125
+ else:
126
+ with open(filename + ".json", "w") as f:
127
+ f.write(json.dumps(self.to_dict()))
128
+
129
+ @staticmethod
130
+ def open_compressed_file(filename):
131
+ with gzip.open(filename, "rb") as f:
132
+ file_contents = f.read()
133
+ file_contents_decoded = file_contents.decode("utf-8")
134
+ d = json.loads(file_contents_decoded)
135
+ return d
136
+
137
+ @staticmethod
138
+ def open_regular_file(filename):
139
+ with open(filename, "r") as f:
140
+ d = json.loads(f.read())
141
+ return d
98
142
 
99
143
  @classmethod
100
144
  def load(cls, filename):
@@ -103,11 +147,19 @@ class PersistenceMixin:
103
147
  >>> obj = cls.load("obj.json.gz")
104
148
 
105
149
  """
106
- with gzip.open(filename, "rb") as f:
107
- file_contents = f.read()
108
- file_contents_decoded = file_contents.decode("utf-8")
109
- d = json.loads(file_contents_decoded)
110
- # d = json.loads(f.read().decode("utf-8"))
150
+
151
+ if filename.endswith("json.gz"):
152
+ d = cls.open_compressed_file(filename)
153
+ elif filename.endswith("json"):
154
+ d = cls.open_regular_file(filename)
155
+ else:
156
+ try:
157
+ d = cls.open_compressed_file(filename)
158
+ except:
159
+ d = cls.open_regular_file(filename)
160
+ finally:
161
+ raise ValueError("File must be a json or json.gz file")
162
+
111
163
  return cls.from_dict(d)
112
164
 
113
165
 
@@ -128,7 +180,21 @@ class RegisterSubclassesMeta(ABCMeta):
128
180
  return dict(RegisterSubclassesMeta._registry)
129
181
 
130
182
 
131
- class Base(RichPrintingMixin, PersistenceMixin, ABC, metaclass=RegisterSubclassesMeta):
183
+ class DiffMethodsMixin:
184
+ def __sub__(self, other):
185
+ """Return the difference between two objects."""
186
+ from edsl.BaseDiff import BaseDiff
187
+
188
+ return BaseDiff(self, other)
189
+
190
+
191
+ class Base(
192
+ RichPrintingMixin,
193
+ PersistenceMixin,
194
+ DiffMethodsMixin,
195
+ ABC,
196
+ metaclass=RegisterSubclassesMeta,
197
+ ):
132
198
  """Base class for all classes in the package."""
133
199
 
134
200
  # def __getitem__(self, key):
@@ -172,6 +238,17 @@ class Base(RichPrintingMixin, PersistenceMixin, ABC, metaclass=RegisterSubclasse
172
238
  # f.write(html_string)
173
239
  # webbrowser.open(f.name)
174
240
 
241
+ def __eq__(self, other):
242
+ """Return whether two objects are equal."""
243
+ import inspect
244
+
245
+ if not isinstance(other, self.__class__):
246
+ return False
247
+ if "sort" in inspect.signature(self._to_dict).parameters:
248
+ return self._to_dict(sort=True) == other._to_dict(sort=True)
249
+ else:
250
+ return self._to_dict() == other._to_dict()
251
+
175
252
  @abstractmethod
176
253
  def example():
177
254
  """This method should be implemented by subclasses."""
edsl/BaseDiff.py ADDED
@@ -0,0 +1,260 @@
1
+ import difflib
2
+ import json
3
+ from typing import Any, Dict, Tuple
4
+ from collections import UserList
5
+ import inspect
6
+
7
+
8
+ class BaseDiffCollection(UserList):
9
+ def __init__(self, diffs=None):
10
+ if diffs is None:
11
+ diffs = []
12
+ super().__init__(diffs)
13
+
14
+ def apply(self, obj: Any):
15
+ for diff in self:
16
+ obj = diff.apply(obj)
17
+ return obj
18
+
19
+ def add_diff(self, diff) -> "BaseDiffCollection":
20
+ self.append(diff)
21
+ return self
22
+
23
+
24
+ class DummyObject:
25
+ def __init__(self, object_dict):
26
+ self.object_dict = object_dict
27
+
28
+ def _to_dict(self):
29
+ return self.object_dict
30
+
31
+
32
+ class BaseDiff:
33
+ def __init__(
34
+ self, obj1: Any, obj2: Any, added=None, removed=None, modified=None, level=0
35
+ ):
36
+ self.level = level
37
+
38
+ self.obj1 = obj1
39
+ self.obj2 = obj2
40
+
41
+ if "sort" in inspect.signature(obj1._to_dict).parameters:
42
+ self._dict1 = obj1._to_dict(sort=True)
43
+ self._dict2 = obj2._to_dict(sort=True)
44
+ else:
45
+ self._dict1 = obj1._to_dict()
46
+ self._dict2 = obj2._to_dict()
47
+ self._obj_class = type(obj1)
48
+
49
+ self.added = added
50
+ self.removed = removed
51
+ self.modified = modified
52
+
53
+ def __bool__(self):
54
+ return bool(self.added or self.removed or self.modified)
55
+
56
+ @property
57
+ def added(self):
58
+ if self._added is None:
59
+ self._added = self._find_added()
60
+ return self._added
61
+
62
+ def __add__(self, other):
63
+ return self.apply(other)
64
+
65
+ @added.setter
66
+ def added(self, value):
67
+ self._added = value if value is not None else self._find_added()
68
+
69
+ @property
70
+ def removed(self):
71
+ if self._removed is None:
72
+ self._removed = self._find_removed()
73
+ return self._removed
74
+
75
+ @removed.setter
76
+ def removed(self, value):
77
+ self._removed = value if value is not None else self._find_removed()
78
+
79
+ @property
80
+ def modified(self):
81
+ if self._modified is None:
82
+ self._modified = self._find_modified()
83
+ return self._modified
84
+
85
+ @modified.setter
86
+ def modified(self, value):
87
+ self._modified = value if value is not None else self._find_modified()
88
+
89
+ def _find_added(self) -> Dict[Any, Any]:
90
+ return {k: self._dict2[k] for k in self._dict2 if k not in self._dict1}
91
+
92
+ def _find_removed(self) -> Dict[Any, Any]:
93
+ return {k: self._dict1[k] for k in self._dict1 if k not in self._dict2}
94
+
95
+ def _find_modified(self) -> Dict[Any, Tuple[Any, Any, str]]:
96
+ modified = {}
97
+ for k in self._dict1:
98
+ if k in self._dict2 and self._dict1[k] != self._dict2[k]:
99
+ if isinstance(self._dict1[k], str) and isinstance(self._dict2[k], str):
100
+ diff = self._diff_strings(self._dict1[k], self._dict2[k])
101
+ modified[k] = (self._dict1[k], self._dict2[k], diff)
102
+ elif isinstance(self._dict1[k], dict) and isinstance(
103
+ self._dict2[k], dict
104
+ ):
105
+ diff = self._diff_dicts(self._dict1[k], self._dict2[k])
106
+ modified[k] = (self._dict1[k], self._dict2[k], diff)
107
+ elif isinstance(self._dict1[k], list) and isinstance(
108
+ self._dict2[k], list
109
+ ):
110
+ d1 = dict(zip(range(len(self._dict1[k])), self._dict1[k]))
111
+ d2 = dict(zip(range(len(self._dict2[k])), self._dict2[k]))
112
+ diff = BaseDiff(
113
+ DummyObject(d1), DummyObject(d2), level=self.level + 1
114
+ )
115
+ modified[k] = (self._dict1[k], self._dict2[k], diff)
116
+ else:
117
+ modified[k] = (self._dict1[k], self._dict2[k], "")
118
+ return modified
119
+
120
+ @staticmethod
121
+ def is_json(string_that_could_be_json: str) -> bool:
122
+ try:
123
+ json.loads(string_that_could_be_json)
124
+ return True
125
+ except json.JSONDecodeError:
126
+ return False
127
+
128
+ def _diff_dicts(self, dict1: Dict[str, Any], dict2: Dict[str, Any]) -> str:
129
+ diff = BaseDiff(DummyObject(dict1), DummyObject(dict2), level=self.level + 1)
130
+ return diff
131
+
132
+ def _diff_strings(self, str1: str, str2: str) -> str:
133
+ if self.is_json(str1) and self.is_json(str2):
134
+ diff = self._diff_dicts(json.loads(str1), json.loads(str2))
135
+ return diff
136
+ diff = difflib.ndiff(str1.splitlines(), str2.splitlines())
137
+ return diff
138
+
139
+ def apply(self, obj: Any):
140
+ """Apply the diff to the object."""
141
+
142
+ new_obj_dict = obj._to_dict()
143
+ for k, v in self.added.items():
144
+ new_obj_dict[k] = v
145
+ for k in self.removed.keys():
146
+ del new_obj_dict[k]
147
+ for k, (v1, v2, diff) in self.modified.items():
148
+ new_obj_dict[k] = v2
149
+
150
+ return obj.from_dict(new_obj_dict)
151
+
152
+ def to_dict(self) -> Dict[str, Any]:
153
+ return {
154
+ "added": self.added,
155
+ "removed": self.removed,
156
+ "modified": self.modified,
157
+ "obj1": self._dict1,
158
+ "obj2": self._dict2,
159
+ "obj_class": self._obj_class.__name__,
160
+ "level": self.level,
161
+ }
162
+
163
+ @classmethod
164
+ def from_dict(cls, diff_dict: Dict[str, Any], obj1: Any, obj2: Any):
165
+ return cls(
166
+ obj1=obj1,
167
+ obj2=obj2,
168
+ added=diff_dict["added"],
169
+ removed=diff_dict["removed"],
170
+ modified=diff_dict["modified"],
171
+ level=diff_dict["level"],
172
+ )
173
+
174
+ class Results(UserList):
175
+ def __init__(self, prepend=" ", level=0):
176
+ super().__init__()
177
+ self.prepend = prepend
178
+ self.level = level
179
+
180
+ def append(self, item):
181
+ super().append(self.prepend * self.level + item)
182
+
183
+ def __str__(self):
184
+ prepend = " "
185
+ result = self.Results(level=self.level, prepend="\t")
186
+ if self.added:
187
+ result.append("Added keys and values:")
188
+ for k, v in self.added.items():
189
+ result.append(prepend + f" {k}: {v}")
190
+ if self.removed:
191
+ result.append("Removed keys and values:")
192
+ for k, v in self.removed.items():
193
+ result.append(f" {k}: {v}")
194
+ if self.modified:
195
+ result.append("Modified keys and values:")
196
+ for k, (v1, v2, diff) in self.modified.items():
197
+ result.append(f"Key: {k}:")
198
+ result.append(f" Old value: {v1}")
199
+ result.append(f" New value: {v2}")
200
+ if diff:
201
+ result.append(f" Diff:")
202
+ try:
203
+ for line in diff:
204
+ result.append(f" {line}")
205
+ except:
206
+ result.append(f" {diff}")
207
+ return "\n".join(result)
208
+
209
+ def __repr__(self):
210
+ return (
211
+ f"BaseDiff(obj1={self.obj1!r}, obj2={self.obj2!r}, added={self.added!r}, "
212
+ f"removed={self.removed!r}, modified={self.modified!r})"
213
+ )
214
+
215
+ def add_diff(self, diff) -> "BaseDiffCollection":
216
+ return BaseDiffCollection([self, diff])
217
+
218
+
219
+ if __name__ == "__main__":
220
+ from edsl import Question
221
+
222
+ q_ft = Question.example("free_text")
223
+ q_mc = Question.example("multiple_choice")
224
+
225
+ diff1 = q_ft - q_mc
226
+ assert q_ft == q_mc + diff1
227
+ assert q_ft == diff1.apply(q_mc)
228
+ # new_q_mc = diff1.apply(q_ft)
229
+ # assert new_q_mc == q_mc
230
+
231
+ # new_q_mc = q_ft + diff1
232
+ # assert new_q_mc == q_mc
233
+
234
+ # new_q_mc = diff1 + q_ft
235
+ # assert new_q_mc == q_mc
236
+
237
+ # ## Test chain of diffs
238
+ q0 = Question.example("free_text")
239
+ q1 = q0.copy()
240
+ q1.question_text = "Why is Buzzard's Bay so named?"
241
+ diff1 = q1 - q0
242
+ q2 = q1.copy()
243
+ q2.question_name = "buzzard_bay"
244
+ diff2 = q2 - q1
245
+
246
+ diff_chain = diff1.add_diff(diff2)
247
+
248
+ new_q2 = diff_chain.apply(q0)
249
+ assert new_q2 == q2
250
+
251
+ new_q2 = diff_chain + q0
252
+ assert new_q2 == q2
253
+
254
+ # new_diffs = diff1.add_diff(diff1).add_diff(diff1)
255
+ # assert len(new_diffs) == 3
256
+
257
+ # q0 = Question.example("free_text")
258
+ # q1 = Question.example("free_text")
259
+ # q1.question_text = "Why is Buzzard's Bay so named?"
260
+ # q2 = q1.copy()
edsl/__init__.py CHANGED
@@ -35,4 +35,8 @@ from edsl.data.CacheEntry import CacheEntry
35
35
  from edsl.data.CacheHandler import set_session_cache, unset_session_cache
36
36
  from edsl.shared import shared_globals
37
37
  from edsl.jobs import Jobs
38
+ from edsl.notebooks import Notebook
39
+ from edsl.study.Study import Study
38
40
  from edsl.coop.coop import Coop
41
+ from edsl.conjure.Conjure import Conjure
42
+ from edsl.language_models.ModelList import ModelList
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.27.dev2"
1
+ __version__ = "0.1.28"
edsl/agents/Agent.py CHANGED
@@ -98,7 +98,7 @@ class Agent(Base):
98
98
 
99
99
  >>> a = Agent(traits = {"age": 10}, traits_presentation_template = "I am a {{age}} year old.")
100
100
  >>> repr(a.agent_persona)
101
- "Prompt(text='I am a {{age}} year old.')"
101
+ 'Prompt(text=\"""I am a {{age}} year old.\""")'
102
102
 
103
103
  When this is rendered for presentation to the LLM, it will replace the `{{age}}` with the actual age.
104
104
  it is also possible to use the `codebook` to provide a more human-readable description of the trait.
@@ -109,7 +109,7 @@ class Agent(Base):
109
109
  >>> a = Agent(traits = traits, codebook = codebook, traits_presentation_template = "This agent is Dave. {{codebook['age']}} {{age}}")
110
110
  >>> d = a.traits | {'codebook': a.codebook}
111
111
  >>> a.agent_persona.render(d)
112
- Prompt(text='This agent is Dave. Their age is 10')
112
+ Prompt(text=\"""This agent is Dave. Their age is 10\""")
113
113
 
114
114
  Instructions
115
115
  ------------
@@ -198,6 +198,18 @@ class Agent(Base):
198
198
  else:
199
199
  return self._traits
200
200
 
201
+ def rename(self, old_name: str, new_name: str) -> Agent:
202
+ """Rename a trait.
203
+
204
+ Example usage:
205
+
206
+ >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
207
+ >>> a.rename("age", "years") == Agent(traits = {'years': 10, 'hair': 'brown', 'height': 5.5})
208
+ True
209
+ """
210
+ self.traits[new_name] = self.traits.pop(old_name)
211
+ return self
212
+
201
213
  def __getitem__(self, key):
202
214
  """Allow for accessing traits using the bracket notation.
203
215
 
@@ -327,7 +339,7 @@ class Agent(Base):
327
339
  >>> from edsl import QuestionFreeText
328
340
  >>> q = QuestionFreeText.example()
329
341
  >>> a.answer_question(question = q, cache = False)
330
- {'answer': 'I am a direct answer.', 'comment': 'This is a real survey response from a human.', 'question_name': 'how_are_you', 'prompts': {'user_prompt': Prompt(text='NA'), 'system_prompt': Prompt(text='NA')}, 'usage': {'prompt_tokens': 0, 'completion_tokens': 0}, 'cached_response': None, 'raw_model_response': None, 'simple_model_raw_response': None}
342
+ {'answer': 'I am a direct answer.', 'comment': 'This is a real survey response from a human.', ...}
331
343
 
332
344
  This is a function where an agent returns an answer to a particular question.
333
345
  However, there are several different ways an agent can answer a question, so the
@@ -547,6 +559,15 @@ class Agent(Base):
547
559
 
548
560
  return raw_data
549
561
 
562
+ def __hash__(self) -> int:
563
+ from edsl.utilities.utilities import dict_hash
564
+
565
+ return dict_hash(self._to_dict())
566
+
567
+ def _to_dict(self) -> dict[str, Union[dict, bool]]:
568
+ """Serialize to a dictionary."""
569
+ return self.data
570
+
550
571
  @add_edsl_version
551
572
  def to_dict(self) -> dict[str, Union[dict, bool]]:
552
573
  """Serialize to a dictionary.
@@ -557,7 +578,7 @@ class Agent(Base):
557
578
  >>> a.to_dict()
558
579
  {'name': 'Steve', 'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}, 'edsl_version': '...', 'edsl_class_name': 'Agent'}
559
580
  """
560
- return self.data
581
+ return self._to_dict()
561
582
 
562
583
  @classmethod
563
584
  @remove_edsl_version
@@ -567,7 +588,7 @@ class Agent(Base):
567
588
  Example usage:
568
589
 
569
590
  >>> Agent.from_dict({'name': "Steve", 'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}})
570
- Agent(name = 'Steve', traits = {'age': 10, 'hair': 'brown', 'height': 5.5})
591
+ Agent(name = \"""Steve\""", traits = {'age': 10, 'hair': 'brown', 'height': 5.5})
571
592
 
572
593
  """
573
594
  return cls(**agent_dict)
edsl/agents/AgentList.py CHANGED
@@ -18,6 +18,7 @@ from rich.table import Table
18
18
  import json
19
19
  import csv
20
20
 
21
+
21
22
  from simpleeval import EvalWithCompoundTypes
22
23
 
23
24
  from edsl.Base import Base
@@ -41,6 +42,38 @@ class AgentList(UserList, Base):
41
42
  else:
42
43
  super().__init__()
43
44
 
45
+ def shuffle(self, seed: Optional[str] = "edsl") -> AgentList:
46
+ """Shuffle the AgentList.
47
+
48
+ :param seed: The seed for the random number generator.
49
+ """
50
+ import random
51
+
52
+ random.seed(seed)
53
+ random.shuffle(self.data)
54
+ return self
55
+
56
+ def sample(self, n: int, seed="edsl") -> AgentList:
57
+ """Return a random sample of agents.
58
+
59
+ :param n: The number of agents to sample.
60
+ :param seed: The seed for the random number generator.
61
+ """
62
+ import random
63
+
64
+ random.seed(seed)
65
+ return AgentList(random.sample(self.data, n))
66
+
67
+ def rename(self, old_name, new_name):
68
+ """Rename a trait in the AgentList.
69
+
70
+ :param old_name: The old name of the trait.
71
+ :param new_name: The new name of the trait.
72
+ """
73
+ for agent in self.data:
74
+ agent.rename(old_name, new_name)
75
+ return self
76
+
44
77
  def select(self, *traits) -> AgentList:
45
78
  """Selects agents with only the references traits.
46
79
 
@@ -139,21 +172,36 @@ class AgentList(UserList, Base):
139
172
  reader = csv.DictReader(f)
140
173
  return {field: None for field in reader.fieldnames}
141
174
 
175
+ def __hash__(self) -> int:
176
+ from edsl.utilities.utilities import dict_hash
177
+
178
+ data = self.to_dict()
179
+ # data['agent_list'] = sorted(data['agent_list'], key=lambda x: dict_hash(x)
180
+ return dict_hash(self._to_dict(sorted=True))
181
+
182
+ def _to_dict(self, sorted=False):
183
+ if sorted:
184
+ data = self.data[:]
185
+ data.sort(key=lambda x: hash(x))
186
+ else:
187
+ data = self.data
188
+
189
+ return {"agent_list": [agent.to_dict() for agent in data]}
190
+
191
+ def __eq__(self, other: AgentList) -> bool:
192
+ return self._to_dict(sorted=True) == other._to_dict(sorted=True)
193
+
142
194
  @add_edsl_version
143
195
  def to_dict(self):
144
- """Return dictionary of AgentList to serialization.
145
-
146
- >>> AgentList.example().to_dict()
147
- {'agent_list': [{'traits': {'age': 22, 'hair': 'brown', 'height': 5.5}, 'edsl_version': '...', 'edsl_class_name': 'Agent'}, {'traits': {'age': 22, 'hair': 'brown', 'height': 5.5}, 'edsl_version': '...', 'edsl_class_name': 'Agent'}], 'edsl_version': '...', 'edsl_class_name': 'AgentList'}
148
- """
149
- return {"agent_list": [agent.to_dict() for agent in self.data]}
196
+ """Return dictionary of AgentList to serialization."""
197
+ return self._to_dict()
150
198
 
151
199
  def __repr__(self):
152
200
  return f"AgentList({self.data})"
153
201
 
154
202
  def print(self, format: Optional[str] = None):
155
203
  """Print the AgentList."""
156
- print_json(json.dumps(self.to_dict()))
204
+ print_json(json.dumps(self._to_dict()))
157
205
 
158
206
  def _repr_html_(self):
159
207
  """Return an HTML representation of the AgentList."""
@@ -161,6 +209,13 @@ class AgentList(UserList, Base):
161
209
 
162
210
  return data_to_html(self.to_dict()["agent_list"])
163
211
 
212
+ def to_scenario_list(self) -> "ScenarioList":
213
+ """Return a list of scenarios."""
214
+ from edsl.scenarios.ScenarioList import ScenarioList
215
+ from edsl.scenarios.Scenario import Scenario
216
+
217
+ return ScenarioList([Scenario(agent.traits) for agent in self.data])
218
+
164
219
  @classmethod
165
220
  @remove_edsl_version
166
221
  def from_dict(cls, data: dict) -> "AgentList":
edsl/agents/Invigilator.py CHANGED
@@ -63,14 +63,9 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
63
63
 
64
64
  def _remove_from_cache(self, raw_response) -> None:
65
65
  """Remove an entry from the cache."""
66
- if (
67
- "raw_model_response" in raw_response
68
- and "cache_key" in raw_response["raw_model_response"]
69
- ):
70
- cache_key = raw_response["raw_model_response"]["cache_key"]
71
- else:
72
- cache_key = None
73
- del self.cache.data[cache_key]
66
+ cache_key = raw_response.get("cache_key", None)
67
+ if cache_key:
68
+ del self.cache.data[cache_key]
74
69
 
75
70
  def _format_raw_response(
76
71
  self, *, agent, question, scenario, raw_response, raw_model_response
@@ -95,7 +90,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
95
90
  ), # not all question have comment fields,
96
91
  "question_name": question.question_name,
97
92
  "prompts": self.get_prompts(),
98
- "cached_response": raw_response["cached_response"],
93
+ "cached_response": raw_response.get("cached_response", None),
99
94
  "usage": raw_response.get("usage", {}),
100
95
  "raw_model_response": raw_model_response,
101
96
  }