edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. edsl/Base.py +197 -116
  2. edsl/__init__.py +15 -7
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +351 -147
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +101 -50
  7. edsl/agents/InvigilatorBase.py +62 -70
  8. edsl/agents/PromptConstructor.py +143 -225
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  11. edsl/agents/__init__.py +0 -1
  12. edsl/agents/prompt_helpers.py +3 -3
  13. edsl/agents/question_option_processor.py +172 -0
  14. edsl/auto/AutoStudy.py +18 -5
  15. edsl/auto/StageBase.py +53 -40
  16. edsl/auto/StageQuestions.py +2 -1
  17. edsl/auto/utilities.py +0 -6
  18. edsl/config.py +22 -2
  19. edsl/conversation/car_buying.py +2 -1
  20. edsl/coop/CoopFunctionsMixin.py +15 -0
  21. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  22. edsl/coop/PriceFetcher.py +1 -1
  23. edsl/coop/coop.py +125 -47
  24. edsl/coop/utils.py +14 -14
  25. edsl/data/Cache.py +45 -27
  26. edsl/data/CacheEntry.py +12 -15
  27. edsl/data/CacheHandler.py +31 -12
  28. edsl/data/RemoteCacheSync.py +154 -46
  29. edsl/data/__init__.py +4 -3
  30. edsl/data_transfer_models.py +2 -1
  31. edsl/enums.py +27 -0
  32. edsl/exceptions/__init__.py +50 -50
  33. edsl/exceptions/agents.py +12 -0
  34. edsl/exceptions/inference_services.py +5 -0
  35. edsl/exceptions/questions.py +24 -6
  36. edsl/exceptions/scenarios.py +7 -0
  37. edsl/inference_services/AnthropicService.py +38 -19
  38. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  39. edsl/inference_services/AvailableModelFetcher.py +215 -0
  40. edsl/inference_services/AwsBedrock.py +0 -2
  41. edsl/inference_services/AzureAI.py +0 -2
  42. edsl/inference_services/GoogleService.py +7 -12
  43. edsl/inference_services/InferenceServiceABC.py +18 -85
  44. edsl/inference_services/InferenceServicesCollection.py +120 -79
  45. edsl/inference_services/MistralAIService.py +0 -3
  46. edsl/inference_services/OpenAIService.py +47 -35
  47. edsl/inference_services/PerplexityService.py +0 -3
  48. edsl/inference_services/ServiceAvailability.py +135 -0
  49. edsl/inference_services/TestService.py +11 -10
  50. edsl/inference_services/TogetherAIService.py +5 -3
  51. edsl/inference_services/data_structures.py +134 -0
  52. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  53. edsl/jobs/Answers.py +1 -14
  54. edsl/jobs/FetchInvigilator.py +47 -0
  55. edsl/jobs/InterviewTaskManager.py +98 -0
  56. edsl/jobs/InterviewsConstructor.py +50 -0
  57. edsl/jobs/Jobs.py +356 -431
  58. edsl/jobs/JobsChecks.py +35 -10
  59. edsl/jobs/JobsComponentConstructor.py +189 -0
  60. edsl/jobs/JobsPrompts.py +6 -4
  61. edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
  62. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  63. edsl/jobs/RequestTokenEstimator.py +30 -0
  64. edsl/jobs/async_interview_runner.py +138 -0
  65. edsl/jobs/buckets/BucketCollection.py +44 -3
  66. edsl/jobs/buckets/TokenBucket.py +53 -21
  67. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  68. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  69. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  70. edsl/jobs/data_structures.py +120 -0
  71. edsl/jobs/decorators.py +35 -0
  72. edsl/jobs/interviews/Interview.py +143 -408
  73. edsl/jobs/jobs_status_enums.py +9 -0
  74. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  75. edsl/jobs/results_exceptions_handler.py +98 -0
  76. edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
  77. edsl/jobs/runners/JobsRunnerStatus.py +133 -165
  78. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  79. edsl/jobs/tasks/TaskHistory.py +38 -18
  80. edsl/jobs/tasks/task_status_enum.py +0 -2
  81. edsl/language_models/ComputeCost.py +63 -0
  82. edsl/language_models/LanguageModel.py +194 -236
  83. edsl/language_models/ModelList.py +28 -19
  84. edsl/language_models/PriceManager.py +127 -0
  85. edsl/language_models/RawResponseHandler.py +106 -0
  86. edsl/language_models/ServiceDataSources.py +0 -0
  87. edsl/language_models/__init__.py +1 -2
  88. edsl/language_models/key_management/KeyLookup.py +63 -0
  89. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  90. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  91. edsl/language_models/key_management/__init__.py +0 -0
  92. edsl/language_models/key_management/models.py +131 -0
  93. edsl/language_models/model.py +256 -0
  94. edsl/language_models/repair.py +2 -2
  95. edsl/language_models/utilities.py +5 -4
  96. edsl/notebooks/Notebook.py +19 -14
  97. edsl/notebooks/NotebookToLaTeX.py +142 -0
  98. edsl/prompts/Prompt.py +29 -39
  99. edsl/questions/ExceptionExplainer.py +77 -0
  100. edsl/questions/HTMLQuestion.py +103 -0
  101. edsl/questions/QuestionBase.py +68 -214
  102. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  103. edsl/questions/QuestionBudget.py +1 -1
  104. edsl/questions/QuestionCheckBox.py +3 -3
  105. edsl/questions/QuestionExtract.py +5 -7
  106. edsl/questions/QuestionFreeText.py +2 -3
  107. edsl/questions/QuestionList.py +10 -18
  108. edsl/questions/QuestionMatrix.py +265 -0
  109. edsl/questions/QuestionMultipleChoice.py +67 -23
  110. edsl/questions/QuestionNumerical.py +2 -4
  111. edsl/questions/QuestionRank.py +7 -17
  112. edsl/questions/SimpleAskMixin.py +4 -3
  113. edsl/questions/__init__.py +2 -1
  114. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
  115. edsl/questions/data_structures.py +20 -0
  116. edsl/questions/derived/QuestionLinearScale.py +6 -3
  117. edsl/questions/derived/QuestionTopK.py +1 -1
  118. edsl/questions/descriptors.py +17 -3
  119. edsl/questions/loop_processor.py +149 -0
  120. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
  121. edsl/questions/question_registry.py +1 -1
  122. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
  123. edsl/questions/response_validator_factory.py +34 -0
  124. edsl/questions/templates/matrix/__init__.py +1 -0
  125. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  126. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  127. edsl/results/CSSParameterizer.py +1 -1
  128. edsl/results/Dataset.py +170 -7
  129. edsl/results/DatasetExportMixin.py +168 -305
  130. edsl/results/DatasetTree.py +28 -8
  131. edsl/results/MarkdownToDocx.py +122 -0
  132. edsl/results/MarkdownToPDF.py +111 -0
  133. edsl/results/Result.py +298 -206
  134. edsl/results/Results.py +149 -131
  135. edsl/results/ResultsExportMixin.py +2 -0
  136. edsl/results/TableDisplay.py +98 -171
  137. edsl/results/TextEditor.py +50 -0
  138. edsl/results/__init__.py +1 -1
  139. edsl/results/file_exports.py +252 -0
  140. edsl/results/{Selector.py → results_selector.py} +23 -13
  141. edsl/results/smart_objects.py +96 -0
  142. edsl/results/table_data_class.py +12 -0
  143. edsl/results/table_renderers.py +118 -0
  144. edsl/scenarios/ConstructDownloadLink.py +109 -0
  145. edsl/scenarios/DocumentChunker.py +102 -0
  146. edsl/scenarios/DocxScenario.py +16 -0
  147. edsl/scenarios/FileStore.py +150 -239
  148. edsl/scenarios/PdfExtractor.py +40 -0
  149. edsl/scenarios/Scenario.py +90 -193
  150. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  151. edsl/scenarios/ScenarioList.py +415 -244
  152. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  153. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  154. edsl/scenarios/__init__.py +1 -2
  155. edsl/scenarios/directory_scanner.py +96 -0
  156. edsl/scenarios/file_methods.py +85 -0
  157. edsl/scenarios/handlers/__init__.py +13 -0
  158. edsl/scenarios/handlers/csv.py +49 -0
  159. edsl/scenarios/handlers/docx.py +76 -0
  160. edsl/scenarios/handlers/html.py +37 -0
  161. edsl/scenarios/handlers/json.py +111 -0
  162. edsl/scenarios/handlers/latex.py +5 -0
  163. edsl/scenarios/handlers/md.py +51 -0
  164. edsl/scenarios/handlers/pdf.py +68 -0
  165. edsl/scenarios/handlers/png.py +39 -0
  166. edsl/scenarios/handlers/pptx.py +105 -0
  167. edsl/scenarios/handlers/py.py +294 -0
  168. edsl/scenarios/handlers/sql.py +313 -0
  169. edsl/scenarios/handlers/sqlite.py +149 -0
  170. edsl/scenarios/handlers/txt.py +33 -0
  171. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
  172. edsl/scenarios/scenario_selector.py +156 -0
  173. edsl/study/ObjectEntry.py +1 -1
  174. edsl/study/SnapShot.py +1 -1
  175. edsl/study/Study.py +5 -12
  176. edsl/surveys/ConstructDAG.py +92 -0
  177. edsl/surveys/EditSurvey.py +221 -0
  178. edsl/surveys/InstructionHandler.py +100 -0
  179. edsl/surveys/MemoryManagement.py +72 -0
  180. edsl/surveys/Rule.py +5 -4
  181. edsl/surveys/RuleCollection.py +25 -27
  182. edsl/surveys/RuleManager.py +172 -0
  183. edsl/surveys/Simulator.py +75 -0
  184. edsl/surveys/Survey.py +270 -791
  185. edsl/surveys/SurveyCSS.py +20 -8
  186. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  187. edsl/surveys/SurveyToApp.py +141 -0
  188. edsl/surveys/__init__.py +4 -2
  189. edsl/surveys/descriptors.py +6 -2
  190. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  191. edsl/surveys/instructions/Instruction.py +4 -13
  192. edsl/surveys/instructions/InstructionCollection.py +11 -6
  193. edsl/templates/error_reporting/interview_details.html +1 -1
  194. edsl/templates/error_reporting/report.html +1 -1
  195. edsl/tools/plotting.py +1 -1
  196. edsl/utilities/PrettyList.py +56 -0
  197. edsl/utilities/is_notebook.py +18 -0
  198. edsl/utilities/is_valid_variable_name.py +11 -0
  199. edsl/utilities/remove_edsl_version.py +24 -0
  200. edsl/utilities/utilities.py +35 -23
  201. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
  202. edsl-0.1.39.dist-info/RECORD +358 -0
  203. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  204. edsl/language_models/KeyLookup.py +0 -30
  205. edsl/language_models/registry.py +0 -190
  206. edsl/language_models/unused/ReplicateBase.py +0 -83
  207. edsl/results/ResultsDBMixin.py +0 -238
  208. edsl-0.1.38.dev4.dist-info/RECORD +0 -277
  209. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  210. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  211. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  212. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
edsl/agents/AgentList.py CHANGED
@@ -1,38 +1,41 @@
1
- """A list of Agent objects.
2
-
3
- Example usage:
4
-
5
- .. code-block:: python
6
-
7
- al = AgentList([Agent.example(), Agent.example()])
8
- len(al)
9
- 2
10
-
1
+ """A list of Agents
11
2
  """
12
3
 
13
4
  from __future__ import annotations
14
5
  import csv
15
- import json
6
+ import sys
16
7
  from collections import UserList
8
+ from collections.abc import Iterable
9
+
17
10
  from typing import Any, List, Optional, Union, TYPE_CHECKING
18
- from rich import print_json
19
- from rich.table import Table
20
- from simpleeval import EvalWithCompoundTypes
21
- from edsl.Base import Base
22
- from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
23
11
 
24
- from collections.abc import Iterable
12
+ from simpleeval import EvalWithCompoundTypes, NameNotDefined
25
13
 
14
+ from edsl.Base import Base
15
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
26
16
  from edsl.exceptions.agents import AgentListError
17
+ from edsl.utilities.is_notebook import is_notebook
18
+ from edsl.results.ResultsExportMixin import ResultsExportMixin
19
+ import logging
20
+
21
+ logger = logging.getLogger(__name__)
27
22
 
28
23
  if TYPE_CHECKING:
29
24
  from edsl.scenarios.ScenarioList import ScenarioList
25
+ from edsl.agents.Agent import Agent
26
+ from pandas import DataFrame
30
27
 
31
28
 
32
29
  def is_iterable(obj):
33
30
  return isinstance(obj, Iterable)
34
31
 
35
32
 
33
+ class EmptyAgentList:
34
+ def __repr__(self):
35
+ return "Empty AgentList"
36
+
37
+
38
+ # ResultsExportMixin,
36
39
  class AgentList(UserList, Base):
37
40
  """A list of Agents."""
38
41
 
@@ -50,14 +53,15 @@ class AgentList(UserList, Base):
50
53
  else:
51
54
  super().__init__()
52
55
 
53
- def shuffle(self, seed: Optional[str] = "edsl") -> AgentList:
56
+ def shuffle(self, seed: Optional[str] = None) -> AgentList:
54
57
  """Shuffle the AgentList.
55
58
 
56
59
  :param seed: The seed for the random number generator.
57
60
  """
58
61
  import random
59
62
 
60
- random.seed(seed)
63
+ if seed is not None:
64
+ random.seed(seed)
61
65
  random.shuffle(self.data)
62
66
  return self
63
67
 
@@ -73,22 +77,60 @@ class AgentList(UserList, Base):
73
77
  random.seed(seed)
74
78
  return AgentList(random.sample(self.data, n))
75
79
 
76
- def to_pandas(self):
77
- """Return a pandas DataFrame."""
80
+ def to_pandas(self) -> "DataFrame":
81
+ """Return a pandas DataFrame.
82
+
83
+ >>> from edsl.agents.Agent import Agent
84
+ >>> al = AgentList([Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5}), Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5})])
85
+ >>> al.to_pandas()
86
+ age hair height
87
+ 0 22 brown 5.5
88
+ 1 22 brown 5.5
89
+ """
78
90
  return self.to_scenario_list().to_pandas()
79
91
 
80
- def tally(self):
81
- return self.to_scenario_list().tally()
92
+ def tally(
93
+ self, *fields: Optional[str], top_n: Optional[int] = None, output="Dataset"
94
+ ) -> Union[dict, "Dataset"]:
95
+ """Tally the values of a field or perform a cross-tab of multiple fields.
96
+
97
+ :param fields: The field(s) to tally, multiple fields for cross-tabulation.
82
98
 
83
- def rename(self, old_name, new_name):
99
+ >>> al = AgentList.example()
100
+ >>> al.tally('age')
101
+ Dataset([{'age': [22]}, {'count': [2]}])
102
+ """
103
+ return self.to_scenario_list().tally(*fields, top_n=top_n, output=output)
104
+
105
+ def duplicate(self):
106
+ """Duplicate the AgentList.
107
+
108
+ >>> al = AgentList.example()
109
+ >>> al2 = al.duplicate()
110
+ >>> al2 == al
111
+ True
112
+ >>> id(al2) == id(al)
113
+ False
114
+ """
115
+ return AgentList([a.duplicate() for a in self.data])
116
+
117
+ def rename(self, old_name, new_name) -> AgentList:
84
118
  """Rename a trait in the AgentList.
85
119
 
86
120
  :param old_name: The old name of the trait.
87
121
  :param new_name: The new name of the trait.
122
+ :param inplace: Whether to rename the trait in place.
123
+
124
+ >>> from edsl.agents.Agent import Agent
125
+ >>> al = AgentList([Agent(traits = {'a': 1, 'b': 1}), Agent(traits = {'a': 1, 'b': 2})])
126
+ >>> al2 = al.rename('a', 'c')
127
+ >>> assert al2 == AgentList([Agent(traits = {'c': 1, 'b': 1}), Agent(traits = {'c': 1, 'b': 2})])
128
+ >>> assert al != al2
88
129
  """
89
- for agent in self.data:
90
- agent.rename(old_name, new_name)
91
- return self
130
+ newagents = []
131
+ for agent in self:
132
+ newagents.append(agent.rename(old_name, new_name))
133
+ return AgentList(newagents)
92
134
 
93
135
  def select(self, *traits) -> AgentList:
94
136
  """Selects agents with only the references traits.
@@ -123,19 +165,36 @@ class AgentList(UserList, Base):
123
165
  """
124
166
  return EvalWithCompoundTypes(names=agent.traits)
125
167
 
126
- try:
127
168
  # iterates through all the results and evaluates the expression
169
+
170
+ try:
128
171
  new_data = [
129
172
  agent for agent in self.data if create_evaluator(agent).eval(expression)
130
173
  ]
131
- except Exception as e:
132
- print(f"Exception:{e}")
133
- raise AgentListError(f"Error in filter. Exception:{e}")
174
+ except NameNotDefined as e:
175
+ e = AgentListError(f"'{expression}' is not a valid expression.")
176
+ if is_notebook():
177
+ print(e, file=sys.stderr)
178
+ else:
179
+ raise e
180
+
181
+ return EmptyAgentList()
182
+
183
+ if len(new_data) == 0:
184
+ return EmptyAgentList()
134
185
 
135
186
  return AgentList(new_data)
136
187
 
137
188
  @property
138
- def all_traits(self):
189
+ def all_traits(self) -> list[str]:
190
+ """Return all traits in the AgentList.
191
+ >>> from edsl.agents.Agent import Agent
192
+ >>> agent_1 = Agent(traits = {'age': 22})
193
+ >>> agent_2 = Agent(traits = {'hair': 'brown'})
194
+ >>> al = AgentList([agent_1, agent_2])
195
+ >>> al.all_traits
196
+ ['age', 'hair']
197
+ """
139
198
  d = {}
140
199
  for agent in self:
141
200
  d.update(agent.traits)
@@ -180,14 +239,20 @@ class AgentList(UserList, Base):
180
239
  agent_list.append(Agent(traits=row))
181
240
  return cls(agent_list)
182
241
 
183
- def translate_traits(self, values_codebook: dict[str, str]):
242
+ def translate_traits(self, codebook: dict[str, str]):
184
243
  """Translate traits to a new codebook.
185
244
 
186
245
  :param codebook: The new codebook.
246
+
247
+ >>> al = AgentList.example()
248
+ >>> codebook = {'hair': {'brown':'Secret word for green'}}
249
+ >>> al.translate_traits(codebook)
250
+ AgentList([Agent(traits = {'age': 22, 'hair': 'Secret word for green', 'height': 5.5}), Agent(traits = {'age': 22, 'hair': 'Secret word for green', 'height': 5.5})])
187
251
  """
252
+ new_agents = []
188
253
  for agent in self.data:
189
- agent.translate_traits(codebook)
190
- return self
254
+ new_agents.append(agent.translate_traits(codebook))
255
+ return AgentList(new_agents)
191
256
 
192
257
  def remove_trait(self, trait: str):
193
258
  """Remove traits from the AgentList.
@@ -198,20 +263,21 @@ class AgentList(UserList, Base):
198
263
  >>> al.remove_trait('age')
199
264
  AgentList([Agent(traits = {'hair': 'brown', 'height': 5.5}), Agent(traits = {'hair': 'brown', 'height': 5.5})])
200
265
  """
201
- for agent in self.data:
202
- _ = agent.remove_trait(trait)
203
- return self
266
+ agents = []
267
+ new_al = self.duplicate()
268
+ for agent in new_al.data:
269
+ agents.append(agent.remove_trait(trait))
270
+ return AgentList(agents)
204
271
 
205
- def add_trait(self, trait, values):
272
+ def add_trait(self, trait: str, values: List[Any]) -> AgentList:
206
273
  """Adds a new trait to every agent, with values taken from values.
207
274
 
208
275
  :param trait: The name of the trait.
209
276
  :param values: The value(s) of the trait. If a single value is passed, it is used for all agents.
210
277
 
211
278
  >>> al = AgentList.example()
212
- >>> al.add_trait('new_trait', 1)
213
- AgentList([Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1}), Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1})])
214
- >>> al.select('new_trait').to_scenario_list().to_list()
279
+ >>> new_al = al.add_trait('new_trait', 1)
280
+ >>> new_al.select('new_trait').to_scenario_list().to_list()
215
281
  [1, 1]
216
282
  >>> al.add_trait('new_trait', [1, 2, 3])
217
283
  Traceback (most recent call last):
@@ -220,18 +286,24 @@ class AgentList(UserList, Base):
220
286
  ...
221
287
  """
222
288
  if not is_iterable(values):
289
+ new_agents = []
223
290
  value = values
224
291
  for agent in self.data:
225
- agent.add_trait(trait, value)
226
- return self
292
+ new_agents.append(agent.add_trait(trait, value))
293
+ return AgentList(new_agents)
227
294
 
228
295
  if len(values) != len(self):
229
- raise AgentListError(
296
+ e = AgentListError(
230
297
  "The passed values have to be the same length as the agent list."
231
298
  )
299
+ if is_notebook():
300
+ print(e, file=sys.stderr)
301
+ else:
302
+ raise e
303
+ new_agents = []
232
304
  for agent, value in zip(self.data, values):
233
- agent.add_trait(trait, value)
234
- return self
305
+ new_agents.append(agent.add_trait(trait, value))
306
+ return AgentList(new_agents)
235
307
 
236
308
  @staticmethod
237
309
  def get_codebook(file_path: str):
@@ -244,12 +316,23 @@ class AgentList(UserList, Base):
244
316
  return {field: None for field in reader.fieldnames}
245
317
 
246
318
  def __hash__(self) -> int:
319
+ """Return the hash of the AgentList.
320
+
321
+ >>> al = AgentList.example()
322
+ >>> hash(al)
323
+ 1681154913465662422
324
+ """
247
325
  from edsl.utilities.utilities import dict_hash
248
326
 
249
327
  return dict_hash(self.to_dict(add_edsl_version=False, sorted=True))
250
328
 
251
329
  def to_dict(self, sorted=False, add_edsl_version=True):
252
- """Serialize the AgentList to a dictionary."""
330
+ """Serialize the AgentList to a dictionary.
331
+
332
+ >>> AgentList.example().to_dict(add_edsl_version=False)
333
+ {'agent_list': [{'traits': {'age': 22, 'hair': 'brown', 'height': 5.5}}, {'traits': {'age': 22, 'hair': 'brown', 'height': 5.5}}]}
334
+
335
+ """
253
336
  if sorted:
254
337
  data = self.data[:]
255
338
  data.sort(key=lambda x: hash(x))
@@ -279,15 +362,26 @@ class AgentList(UserList, Base):
279
362
 
280
363
  def _summary(self):
281
364
  return {
282
- "EDSL Class": "AgentList",
283
- "Number of agents": len(self),
284
- "Agent trait fields": self.all_traits,
365
+ "agents": len(self),
285
366
  }
286
367
 
287
- def _repr_html_(self):
288
- """Return an HTML representation of the AgentList."""
289
- footer = f"<a href={self.__documentation__}>(docs)</a>"
290
- return str(self.summary(format="html")) + footer
368
+ def set_codebook(self, codebook: dict[str, str]) -> AgentList:
369
+ """Set the codebook for the AgentList.
370
+
371
+ >>> from edsl.agents.Agent import Agent
372
+ >>> a = Agent(traits = {'hair': 'brown'})
373
+ >>> al = AgentList([a, a])
374
+ >>> _ = al.set_codebook({'hair': "Color of hair on driver's license"})
375
+ >>> al[0].codebook
376
+ {'hair': "Color of hair on driver's license"}
377
+
378
+
379
+ :param codebook: The codebook.
380
+ """
381
+ for agent in self.data:
382
+ agent.codebook = codebook
383
+
384
+ return self
291
385
 
292
386
  def to_csv(self, file_path: str):
293
387
  """Save the AgentList to a CSV file.
@@ -300,19 +394,33 @@ class AgentList(UserList, Base):
300
394
  """Return a list of tuples."""
301
395
  return self.to_scenario_list(include_agent_name).to_list()
302
396
 
303
- def to_scenario_list(self, include_agent_name=False) -> ScenarioList:
304
- """Return a list of scenarios."""
397
+ def to_scenario_list(
398
+ self, include_agent_name: bool = False, include_instruction: bool = False
399
+ ) -> ScenarioList:
400
+ """Converts the agent to a scenario list."""
305
401
  from edsl.scenarios.ScenarioList import ScenarioList
306
402
  from edsl.scenarios.Scenario import Scenario
307
403
 
308
- if include_agent_name:
309
- return ScenarioList(
310
- [
311
- Scenario(agent.traits | {"agent_name": agent.name})
312
- for agent in self.data
313
- ]
314
- )
315
- return ScenarioList([Scenario(agent.traits) for agent in self.data])
404
+ # raise NotImplementedError("This method is not implemented yet.")
405
+
406
+ scenario_list = ScenarioList()
407
+ for agent in self.data:
408
+ d = agent.traits
409
+ if include_agent_name:
410
+ d["agent_name"] = agent.name
411
+ if include_instruction:
412
+ d["instruction"] = agent.instruction
413
+ scenario_list.append(Scenario(d))
414
+ return scenario_list
415
+
416
+ # if include_agent_name:
417
+ # return ScenarioList(
418
+ # [
419
+ # Scenario(agent.traits | {"agent_name": agent.name} | })
420
+ # for agent in self.data
421
+ # ]
422
+ # )
423
+ # return ScenarioList([Scenario(agent.traits) for agent in self.data])
316
424
 
317
425
  def table(
318
426
  self,
@@ -320,12 +428,50 @@ class AgentList(UserList, Base):
320
428
  tablefmt: Optional[str] = None,
321
429
  pretty_labels: Optional[dict] = None,
322
430
  ) -> Table:
431
+ if len(self) == 0:
432
+ e = AgentListError("Cannot create a table from an empty AgentList.")
433
+ if is_notebook():
434
+ print(e, file=sys.stderr)
435
+ return None
436
+ else:
437
+ raise e
323
438
  return (
324
439
  self.to_scenario_list()
325
440
  .to_dataset()
326
441
  .table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
327
442
  )
328
443
 
444
+ def to_dataset(self, traits_only: bool = True):
445
+ """
446
+ Convert the AgentList to a Dataset.
447
+
448
+ >>> from edsl.agents.AgentList import AgentList
449
+ >>> al = AgentList.example()
450
+ >>> al.to_dataset()
451
+ Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}])
452
+ >>> al.to_dataset(traits_only = False)
453
+ Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}, {'agent_parameters': [{'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}, {'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}]}])
454
+ """
455
+ from edsl.results.Dataset import Dataset
456
+ from collections import defaultdict
457
+
458
+ agent_trait_keys = []
459
+ for agent in self:
460
+ agent_keys = list(agent.traits.keys())
461
+ for key in agent_keys:
462
+ if key not in agent_trait_keys:
463
+ agent_trait_keys.append(key)
464
+
465
+ data = defaultdict(list)
466
+ for agent in self:
467
+ for trait_key in agent_trait_keys:
468
+ data[trait_key].append(agent.traits.get(trait_key, None))
469
+ if not traits_only:
470
+ data["agent_parameters"].append(
471
+ {"instruction": agent.instruction, "agent_name": agent.name}
472
+ )
473
+ return Dataset([{key: entry} for key, entry in data.items()])
474
+
329
475
  def tree(self, node_order: Optional[List[str]] = None):
330
476
  return self.to_scenario_list().tree(node_order)
331
477
 
@@ -398,14 +544,6 @@ class AgentList(UserList, Base):
398
544
  return "\n".join(lines)
399
545
  return lines
400
546
 
401
- def rich_print(self) -> Table:
402
- """Display an object as a rich table."""
403
- table = Table(title="AgentList")
404
- table.add_column("Agents", style="bold")
405
- for agent in self.data:
406
- table.add_row(agent.rich_print())
407
- return table
408
-
409
547
 
410
548
  if __name__ == "__main__":
411
549
  import doctest
@@ -1,38 +1,29 @@
1
1
  """Module for creating Invigilators, which are objects to administer a question to an Agent."""
2
2
 
3
- from typing import Dict, Any, Optional
3
+ from typing import Dict, Any, Optional, TYPE_CHECKING
4
4
 
5
- from edsl.prompts.Prompt import Prompt
6
- from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
7
-
8
- # from edsl.prompts.registry import get_classes as prompt_lookup
5
+ from edsl.utilities.decorators import sync_wrapper
9
6
  from edsl.exceptions.questions import QuestionAnswerValidationError
10
7
  from edsl.agents.InvigilatorBase import InvigilatorBase
11
8
  from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
12
- from edsl.agents.PromptConstructor import PromptConstructor
9
+
10
+ if TYPE_CHECKING:
11
+ from edsl.prompts.Prompt import Prompt
12
+ from edsl.scenarios.Scenario import Scenario
13
+ from edsl.surveys.Survey import Survey
13
14
 
14
15
 
15
- class NotApplicable(str):
16
- def __new__(cls):
17
- instance = super().__new__(cls, "Not Applicable")
18
- instance.literal = "Not Applicable"
19
- return instance
16
+ NA = "Not Applicable"
20
17
 
21
18
 
22
19
  class InvigilatorAI(InvigilatorBase):
23
20
  """An invigilator that uses an AI model to answer questions."""
24
21
 
25
- def get_prompts(self) -> Dict[str, Prompt]:
22
+ def get_prompts(self) -> Dict[str, "Prompt"]:
26
23
  """Return the prompts used."""
27
24
  return self.prompt_constructor.get_prompts()
28
25
 
29
- async def async_answer_question(self) -> AgentResponseDict:
30
- """Answer a question using the AI model.
31
-
32
- >>> i = InvigilatorAI.example()
33
- >>> i.answer_question()
34
- {'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
35
- """
26
+ async def async_get_agent_response(self) -> AgentResponseDict:
36
27
  prompts = self.get_prompts()
37
28
  params = {
38
29
  "user_prompt": prompts["user_prompt"].text,
@@ -40,33 +31,95 @@ class InvigilatorAI(InvigilatorBase):
40
31
  }
41
32
  if "encoded_image" in prompts:
42
33
  params["encoded_image"] = prompts["encoded_image"]
34
+ raise NotImplementedError("encoded_image not implemented")
35
+
43
36
  if "files_list" in prompts:
44
37
  params["files_list"] = prompts["files_list"]
45
38
 
46
39
  params.update({"iteration": self.iteration, "cache": self.cache})
47
-
48
40
  params.update({"invigilator": self})
49
- # if hasattr(self.question, "answer_template"):
50
- # breakpoint()
51
41
 
52
- agent_response_dict: AgentResponseDict = await self.model.async_get_response(
53
- **params
54
- )
55
- # store to self in case validation failure
42
+ if self.key_lookup:
43
+ self.model.set_key_lookup(self.key_lookup)
44
+
45
+ return await self.model.async_get_response(**params)
46
+
47
+ def store_response(self, agent_response_dict: AgentResponseDict) -> None:
48
+ """Store the response in the invigilator, in case it is needed later because of validation failure."""
56
49
  self.raw_model_response = agent_response_dict.model_outputs.response
57
50
  self.generated_tokens = agent_response_dict.edsl_dict.generated_tokens
58
51
 
59
- return self.extract_edsl_result_entry_and_validate(agent_response_dict)
52
+ async def async_answer_question(self) -> AgentResponseDict:
53
+ """Answer a question using the AI model.
54
+
55
+ >>> i = InvigilatorAI.example()
56
+ """
57
+ agent_response_dict = await self.async_get_agent_response()
58
+ self.store_response(agent_response_dict)
59
+ return self._extract_edsl_result_entry_and_validate(agent_response_dict)
60
60
 
61
61
  def _remove_from_cache(self, cache_key) -> None:
62
62
  """Remove an entry from the cache."""
63
63
  if cache_key:
64
64
  del self.cache.data[cache_key]
65
65
 
66
- def determine_answer(self, raw_answer: str) -> Any:
67
- question_dict = self.survey.question_names_to_questions()
66
+ def _determine_answer(self, raw_answer: str) -> Any:
67
+ """Determine the answer from the raw answer.
68
+
69
+ >>> i = InvigilatorAI.example()
70
+ >>> i._determine_answer("SPAM!")
71
+ 'SPAM!'
72
+
73
+ >>> from edsl.questions import QuestionMultipleChoice
74
+ >>> q = QuestionMultipleChoice(question_text = "How are you?", question_name = "how_are_you", question_options = ["Good", "Bad"], use_code = True)
75
+ >>> i = InvigilatorAI.example(question = q)
76
+ >>> i._determine_answer("1")
77
+ 'Bad'
78
+ >>> i._determine_answer("0")
79
+ 'Good'
80
+
81
+ This shows how the answer can depend on scenario details
82
+
83
+ >>> from edsl import Scenario
84
+ >>> s = Scenario({'feeling_options':['Good', 'Bad']})
85
+ >>> q = QuestionMultipleChoice(question_text = "How are you?", question_name = "how_are_you", question_options = "{{ feeling_options }}", use_code = True)
86
+ >>> i = InvigilatorAI.example(question = q, scenario = s)
87
+ >>> i._determine_answer("1")
88
+ 'Bad'
89
+
90
+ >>> from edsl import QuestionList, QuestionMultipleChoice, Survey
91
+ >>> q1 = QuestionList(question_name = "favs", question_text = "What are your top 3 colors?")
92
+ >>> q2 = QuestionMultipleChoice(question_text = "What is your favorite color?", question_name = "best", question_options = "{{ favs.answer }}", use_code = True)
93
+ >>> survey = Survey([q1, q2])
94
+ >>> i = InvigilatorAI.example(question = q2, scenario = s, survey = survey)
95
+ >>> i.current_answers = {"favs": ["Green", "Blue", "Red"]}
96
+ >>> i._determine_answer("2")
97
+ 'Red'
98
+ """
99
+ substitution_dict = self._prepare_substitution_dict(
100
+ self.survey, self.current_answers, self.scenario
101
+ )
102
+ return self.question._translate_answer_code_to_answer(
103
+ raw_answer, substitution_dict
104
+ )
105
+
106
+ @staticmethod
107
+ def _prepare_substitution_dict(
108
+ survey: "Survey", current_answers: dict, scenario: "Scenario"
109
+ ) -> Dict[str, Any]:
110
+ """Prepares a substitution dictionary for the question based on the survey, current answers, and scenario.
111
+
112
+ This is necessary because sometimes the model's answer to a question could depend on details in
113
+ the prompt that were provided by the answer to a previous question or a scenario detail.
114
+
115
+ Note that the question object is getting the answer & a the comment appended to it, as the
116
+ jinja2 template might be referencing these values with a dot notation.
117
+
118
+ """
119
+ question_dict = survey.duplicate().question_names_to_questions()
120
+
68
121
  # iterates through the current answers and updates the question_dict (which is all questions)
69
- for other_question, answer in self.current_answers.items():
122
+ for other_question, answer in current_answers.items():
70
123
  if other_question in question_dict:
71
124
  question_dict[other_question].answer = answer
72
125
  else:
@@ -76,13 +129,12 @@ class InvigilatorAI(InvigilatorBase):
76
129
  ) in question_dict:
77
130
  question_dict[new_question].comment = answer
78
131
 
79
- combined_dict = {**question_dict, **self.scenario}
80
- # sometimes the answer is a code, so we need to translate it
81
- return self.question._translate_answer_code_to_answer(raw_answer, combined_dict)
132
+ return {**question_dict, **scenario}
82
133
 
83
- def extract_edsl_result_entry_and_validate(
134
+ def _extract_edsl_result_entry_and_validate(
84
135
  self, agent_response_dict: AgentResponseDict
85
136
  ) -> EDSLResultObjectInput:
137
+ """Extract the EDSL result entry and validate it."""
86
138
  edsl_dict = agent_response_dict.edsl_dict._asdict()
87
139
  exception_occurred = None
88
140
  validated = False
@@ -94,10 +146,8 @@ class InvigilatorAI(InvigilatorBase):
94
146
  # question options have to be treated differently because of dynamic question
95
147
  # this logic is all in the prompt constructor
96
148
  if "question_options" in self.question.data:
97
- new_question_options = (
98
- self.prompt_constructor._get_question_options(
99
- self.question.data
100
- )
149
+ new_question_options = self.prompt_constructor.get_question_options(
150
+ self.question.data
101
151
  )
102
152
  if new_question_options != self.question.data["question_options"]:
103
153
  # I don't love this direct writing but it seems to work
@@ -110,9 +160,8 @@ class InvigilatorAI(InvigilatorBase):
110
160
  else:
111
161
  question_with_validators = self.question
112
162
 
113
- # breakpoint()
114
163
  validated_edsl_dict = question_with_validators._validate_answer(edsl_dict)
115
- answer = self.determine_answer(validated_edsl_dict["answer"])
164
+ answer = self._determine_answer(validated_edsl_dict["answer"])
116
165
  comment = validated_edsl_dict.get("comment", "")
117
166
  validated = True
118
167
  except QuestionAnswerValidationError as e:
@@ -182,13 +231,13 @@ class InvigilatorHuman(InvigilatorBase):
182
231
  exception_occurred = e
183
232
  finally:
184
233
  data = {
185
- "generated_tokens": NotApplicable(),
234
+ "generated_tokens": NA, # NotApplicable(),
186
235
  "question_name": self.question.question_name,
187
236
  "prompts": self.get_prompts(),
188
- "cached_response": NotApplicable(),
189
- "raw_model_response": NotApplicable(),
190
- "cache_used": NotApplicable(),
191
- "cache_key": NotApplicable(),
237
+ "cached_response": NA,
238
+ "raw_model_response": NA,
239
+ "cache_used": NA,
240
+ "cache_key": NA,
192
241
  "answer": answer,
193
242
  "comment": comment,
194
243
  "validated": validated,
@@ -209,17 +258,19 @@ class InvigilatorFunctional(InvigilatorBase):
209
258
  generated_tokens=str(answer),
210
259
  question_name=self.question.question_name,
211
260
  prompts=self.get_prompts(),
212
- cached_response=NotApplicable(),
213
- raw_model_response=NotApplicable(),
214
- cache_used=NotApplicable(),
215
- cache_key=NotApplicable(),
261
+ cached_response=NA,
262
+ raw_model_response=NA,
263
+ cache_used=NA,
264
+ cache_key=NA,
216
265
  answer=answer["answer"],
217
266
  comment="This is the result of a functional question.",
218
267
  validated=True,
219
268
  exception_occurred=None,
220
269
  )
221
270
 
222
- def get_prompts(self) -> Dict[str, Prompt]:
271
+ def get_prompts(self) -> Dict[str, "Prompt"]:
272
+ from edsl.prompts.Prompt import Prompt
273
+
223
274
  """Return the prompts used."""
224
275
  return {
225
276
  "user_prompt": Prompt("NA"),