edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. edsl/Base.py +169 -116
  2. edsl/__init__.py +14 -6
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +358 -146
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +88 -36
  7. edsl/agents/InvigilatorBase.py +59 -70
  8. edsl/agents/PromptConstructor.py +117 -219
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionOptionProcessor.py +172 -0
  11. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  12. edsl/agents/__init__.py +0 -1
  13. edsl/agents/prompt_helpers.py +3 -3
  14. edsl/config.py +22 -2
  15. edsl/conversation/car_buying.py +2 -1
  16. edsl/coop/CoopFunctionsMixin.py +15 -0
  17. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  18. edsl/coop/PriceFetcher.py +1 -1
  19. edsl/coop/coop.py +104 -42
  20. edsl/coop/utils.py +14 -14
  21. edsl/data/Cache.py +21 -14
  22. edsl/data/CacheEntry.py +12 -15
  23. edsl/data/CacheHandler.py +33 -12
  24. edsl/data/__init__.py +4 -3
  25. edsl/data_transfer_models.py +2 -1
  26. edsl/enums.py +20 -0
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +12 -0
  29. edsl/exceptions/inference_services.py +5 -0
  30. edsl/exceptions/questions.py +24 -6
  31. edsl/exceptions/scenarios.py +7 -0
  32. edsl/inference_services/AnthropicService.py +0 -3
  33. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  34. edsl/inference_services/AvailableModelFetcher.py +209 -0
  35. edsl/inference_services/AwsBedrock.py +0 -2
  36. edsl/inference_services/AzureAI.py +0 -2
  37. edsl/inference_services/GoogleService.py +2 -11
  38. edsl/inference_services/InferenceServiceABC.py +18 -85
  39. edsl/inference_services/InferenceServicesCollection.py +105 -80
  40. edsl/inference_services/MistralAIService.py +0 -3
  41. edsl/inference_services/OpenAIService.py +1 -4
  42. edsl/inference_services/PerplexityService.py +0 -3
  43. edsl/inference_services/ServiceAvailability.py +135 -0
  44. edsl/inference_services/TestService.py +11 -8
  45. edsl/inference_services/data_structures.py +62 -0
  46. edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
  47. edsl/jobs/Answers.py +1 -14
  48. edsl/jobs/FetchInvigilator.py +40 -0
  49. edsl/jobs/InterviewTaskManager.py +98 -0
  50. edsl/jobs/InterviewsConstructor.py +48 -0
  51. edsl/jobs/Jobs.py +102 -243
  52. edsl/jobs/JobsChecks.py +35 -10
  53. edsl/jobs/JobsComponentConstructor.py +189 -0
  54. edsl/jobs/JobsPrompts.py +5 -3
  55. edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
  56. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  57. edsl/jobs/RequestTokenEstimator.py +30 -0
  58. edsl/jobs/buckets/BucketCollection.py +44 -3
  59. edsl/jobs/buckets/TokenBucket.py +53 -21
  60. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  61. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  62. edsl/jobs/decorators.py +35 -0
  63. edsl/jobs/interviews/Interview.py +77 -380
  64. edsl/jobs/jobs_status_enums.py +9 -0
  65. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  66. edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
  67. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  68. edsl/jobs/tasks/TaskHistory.py +14 -15
  69. edsl/jobs/tasks/task_status_enum.py +0 -2
  70. edsl/language_models/ComputeCost.py +63 -0
  71. edsl/language_models/LanguageModel.py +137 -234
  72. edsl/language_models/ModelList.py +11 -13
  73. edsl/language_models/PriceManager.py +127 -0
  74. edsl/language_models/RawResponseHandler.py +106 -0
  75. edsl/language_models/ServiceDataSources.py +0 -0
  76. edsl/language_models/__init__.py +0 -1
  77. edsl/language_models/key_management/KeyLookup.py +63 -0
  78. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  79. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  80. edsl/language_models/key_management/__init__.py +0 -0
  81. edsl/language_models/key_management/models.py +131 -0
  82. edsl/language_models/registry.py +49 -59
  83. edsl/language_models/repair.py +2 -2
  84. edsl/language_models/utilities.py +5 -4
  85. edsl/notebooks/Notebook.py +19 -14
  86. edsl/notebooks/NotebookToLaTeX.py +142 -0
  87. edsl/prompts/Prompt.py +29 -39
  88. edsl/questions/AnswerValidatorMixin.py +47 -2
  89. edsl/questions/ExceptionExplainer.py +77 -0
  90. edsl/questions/HTMLQuestion.py +103 -0
  91. edsl/questions/LoopProcessor.py +149 -0
  92. edsl/questions/QuestionBase.py +37 -192
  93. edsl/questions/QuestionBaseGenMixin.py +52 -48
  94. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  95. edsl/questions/QuestionCheckBox.py +1 -1
  96. edsl/questions/QuestionExtract.py +1 -1
  97. edsl/questions/QuestionFreeText.py +1 -2
  98. edsl/questions/QuestionList.py +3 -5
  99. edsl/questions/QuestionMatrix.py +265 -0
  100. edsl/questions/QuestionMultipleChoice.py +66 -22
  101. edsl/questions/QuestionNumerical.py +1 -3
  102. edsl/questions/QuestionRank.py +6 -16
  103. edsl/questions/ResponseValidatorABC.py +37 -11
  104. edsl/questions/ResponseValidatorFactory.py +28 -0
  105. edsl/questions/SimpleAskMixin.py +4 -3
  106. edsl/questions/__init__.py +1 -0
  107. edsl/questions/derived/QuestionLinearScale.py +6 -3
  108. edsl/questions/derived/QuestionTopK.py +1 -1
  109. edsl/questions/descriptors.py +17 -3
  110. edsl/questions/question_registry.py +1 -1
  111. edsl/questions/templates/matrix/__init__.py +1 -0
  112. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  113. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  114. edsl/results/CSSParameterizer.py +1 -1
  115. edsl/results/Dataset.py +170 -7
  116. edsl/results/DatasetExportMixin.py +224 -302
  117. edsl/results/DatasetTree.py +28 -8
  118. edsl/results/MarkdownToDocx.py +122 -0
  119. edsl/results/MarkdownToPDF.py +111 -0
  120. edsl/results/Result.py +192 -206
  121. edsl/results/Results.py +120 -113
  122. edsl/results/ResultsExportMixin.py +2 -0
  123. edsl/results/Selector.py +23 -13
  124. edsl/results/TableDisplay.py +98 -171
  125. edsl/results/TextEditor.py +50 -0
  126. edsl/results/__init__.py +1 -1
  127. edsl/results/smart_objects.py +96 -0
  128. edsl/results/table_data_class.py +12 -0
  129. edsl/results/table_renderers.py +118 -0
  130. edsl/scenarios/ConstructDownloadLink.py +109 -0
  131. edsl/scenarios/DirectoryScanner.py +96 -0
  132. edsl/scenarios/DocumentChunker.py +102 -0
  133. edsl/scenarios/DocxScenario.py +16 -0
  134. edsl/scenarios/FileStore.py +118 -239
  135. edsl/scenarios/PdfExtractor.py +40 -0
  136. edsl/scenarios/Scenario.py +90 -193
  137. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  138. edsl/scenarios/ScenarioJoin.py +10 -6
  139. edsl/scenarios/ScenarioList.py +383 -240
  140. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  141. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  142. edsl/scenarios/ScenarioSelector.py +156 -0
  143. edsl/scenarios/__init__.py +1 -2
  144. edsl/scenarios/file_methods.py +85 -0
  145. edsl/scenarios/handlers/__init__.py +13 -0
  146. edsl/scenarios/handlers/csv.py +38 -0
  147. edsl/scenarios/handlers/docx.py +76 -0
  148. edsl/scenarios/handlers/html.py +37 -0
  149. edsl/scenarios/handlers/json.py +111 -0
  150. edsl/scenarios/handlers/latex.py +5 -0
  151. edsl/scenarios/handlers/md.py +51 -0
  152. edsl/scenarios/handlers/pdf.py +68 -0
  153. edsl/scenarios/handlers/png.py +39 -0
  154. edsl/scenarios/handlers/pptx.py +105 -0
  155. edsl/scenarios/handlers/py.py +294 -0
  156. edsl/scenarios/handlers/sql.py +313 -0
  157. edsl/scenarios/handlers/sqlite.py +149 -0
  158. edsl/scenarios/handlers/txt.py +33 -0
  159. edsl/study/ObjectEntry.py +1 -1
  160. edsl/study/SnapShot.py +1 -1
  161. edsl/study/Study.py +5 -12
  162. edsl/surveys/ConstructDAG.py +92 -0
  163. edsl/surveys/EditSurvey.py +221 -0
  164. edsl/surveys/InstructionHandler.py +100 -0
  165. edsl/surveys/MemoryManagement.py +72 -0
  166. edsl/surveys/Rule.py +5 -4
  167. edsl/surveys/RuleCollection.py +25 -27
  168. edsl/surveys/RuleManager.py +172 -0
  169. edsl/surveys/Simulator.py +75 -0
  170. edsl/surveys/Survey.py +199 -771
  171. edsl/surveys/SurveyCSS.py +20 -8
  172. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  173. edsl/surveys/SurveyToApp.py +141 -0
  174. edsl/surveys/__init__.py +4 -2
  175. edsl/surveys/descriptors.py +6 -2
  176. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  177. edsl/surveys/instructions/Instruction.py +4 -13
  178. edsl/surveys/instructions/InstructionCollection.py +11 -6
  179. edsl/templates/error_reporting/interview_details.html +1 -1
  180. edsl/templates/error_reporting/report.html +1 -1
  181. edsl/tools/plotting.py +1 -1
  182. edsl/utilities/PrettyList.py +56 -0
  183. edsl/utilities/is_notebook.py +18 -0
  184. edsl/utilities/is_valid_variable_name.py +11 -0
  185. edsl/utilities/remove_edsl_version.py +24 -0
  186. edsl/utilities/utilities.py +35 -23
  187. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
  188. edsl-0.1.39.dev2.dist-info/RECORD +352 -0
  189. edsl/language_models/KeyLookup.py +0 -30
  190. edsl/language_models/unused/ReplicateBase.py +0 -83
  191. edsl/results/ResultsDBMixin.py +0 -238
  192. edsl-0.1.39.dev1.dist-info/RECORD +0 -277
  193. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
  194. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
edsl/agents/Agent.py CHANGED
@@ -4,20 +4,46 @@ from __future__ import annotations
4
4
  import copy
5
5
  import inspect
6
6
  import types
7
- from typing import Callable, Optional, Union, Any, TYPE_CHECKING
7
+ from typing import (
8
+ Callable,
9
+ Optional,
10
+ Union,
11
+ Any,
12
+ TYPE_CHECKING,
13
+ Protocol,
14
+ runtime_checkable,
15
+ TypeVar,
16
+ )
17
+ from contextlib import contextmanager
18
+ from dataclasses import dataclass
19
+
20
+ # Type variable for the Agent class
21
+ A = TypeVar("A", bound="Agent")
8
22
 
9
23
  if TYPE_CHECKING:
10
- from edsl import Cache, Survey, Scenario
24
+ from edsl.data.Cache import Cache
25
+ from edsl.surveys.Survey import Survey
26
+ from edsl.scenarios.Scenario import Scenario
11
27
  from edsl.language_models import LanguageModel
12
28
  from edsl.surveys.MemoryPlan import MemoryPlan
13
29
  from edsl.questions import QuestionBase
14
30
  from edsl.agents.Invigilator import InvigilatorBase
31
+ from edsl.prompts import Prompt
32
+ from edsl.questions.QuestionBase import QuestionBase
33
+ from edsl.scenarios.Scenario import Scenario
34
+
35
+
36
+ @runtime_checkable
37
+ class DirectAnswerMethod(Protocol):
38
+ """Protocol defining the required signature for direct answer methods."""
39
+
40
+ def __call__(self, self_: A, question: QuestionBase, scenario: Scenario) -> Any: ...
41
+
15
42
 
16
43
  from uuid import uuid4
17
44
 
18
45
  from edsl.Base import Base
19
- from edsl.prompts import Prompt
20
- from edsl.exceptions import QuestionScenarioRenderError
46
+ from edsl.exceptions.questions import QuestionScenarioRenderError
21
47
 
22
48
  from edsl.exceptions.agents import (
23
49
  AgentErrors,
@@ -34,17 +60,25 @@ from edsl.agents.descriptors import (
34
60
  )
35
61
  from edsl.utilities.decorators import (
36
62
  sync_wrapper,
37
- add_edsl_version,
38
- remove_edsl_version,
39
63
  )
64
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
40
65
  from edsl.data_transfer_models import AgentResponseDict
41
66
  from edsl.utilities.restricted_python import create_restricted_function
42
67
 
68
+ from edsl.scenarios.Scenario import Scenario
69
+
70
+
71
+ class AgentTraits(Scenario):
72
+ """A class representing the traits of an agent."""
73
+
74
+ def __repr__(self):
75
+ return f"{self.data}"
76
+
43
77
 
44
78
  class Agent(Base):
45
79
  """An class representing an agent that can answer questions."""
46
80
 
47
- __doc__ = "https://docs.expectedparrot.com/en/latest/agents.html"
81
+ __documentation__ = "https://docs.expectedparrot.com/en/latest/agents.html"
48
82
 
49
83
  default_instruction = """You are answering questions as if you were a human. Do not break character."""
50
84
 
@@ -77,6 +111,8 @@ class Agent(Base):
77
111
  :param instruction: Instructions for the agent in how to answer questions.
78
112
  :param trait_presentation_template: A template for how to present the agent's traits.
79
113
  :param dynamic_traits_function: A function that returns a dictionary of traits.
114
+ :param dynamic_traits_function_source_code: The source code for the dynamic traits function.
115
+ :param dynamic_traits_function_name: The name of the dynamic traits function.
80
116
 
81
117
  The `traits` parameter is a dictionary of traits that the agent has.
82
118
  These traits are used to construct a prompt that is presented to the LLM.
@@ -119,17 +155,59 @@ class Agent(Base):
119
155
  See see how these are used to actually construct the prompt that is presented to the LLM, see :py:class:`edsl.agents.Invigilator.InvigilatorBase`.
120
156
 
121
157
  """
158
+ self._initialize_basic_attributes(traits, name, codebook)
159
+ self._initialize_instruction(instruction)
160
+ self._initialize_dynamic_traits_function(
161
+ dynamic_traits_function,
162
+ dynamic_traits_function_source_code,
163
+ dynamic_traits_function_name,
164
+ )
165
+ self._initialize_answer_question_directly(
166
+ answer_question_directly_source_code, answer_question_directly_function_name
167
+ )
168
+ self._check_dynamic_traits_function()
169
+ self._initialize_traits_presentation_template(traits_presentation_template)
170
+ self.current_question = None
171
+
172
+ def _initialize_basic_attributes(self, traits, name, codebook) -> None:
173
+ """Initialize the basic attributes of the agent."""
122
174
  self.name = name
123
- self._traits = traits or dict()
175
+ self._traits = AgentTraits(traits or dict())
124
176
  self.codebook = codebook or dict()
177
+
178
+ def _initialize_instruction(self, instruction) -> None:
179
+ """Initialize the instruction for the agent i.e., how the agent should answer questions."""
125
180
  if instruction is None:
126
181
  self.instruction = self.default_instruction
182
+ self._instruction = self.default_instruction
183
+ self.set_instructions = False
127
184
  else:
128
185
  self.instruction = instruction
129
- # self.instruction = instruction or self.default_instruction
130
- self.dynamic_traits_function = dynamic_traits_function
186
+ self._instruction = instruction
187
+ self.set_instructions = True
131
188
 
189
+ def _initialize_traits_presentation_template(
190
+ self, traits_presentation_template
191
+ ) -> None:
192
+ """Initialize the traits presentation template. How the agent's traits are presented to the LLM."""
193
+ if traits_presentation_template is not None:
194
+ self._traits_presentation_template = traits_presentation_template
195
+ self.traits_presentation_template = traits_presentation_template
196
+ self.set_traits_presentation_template = True
197
+ else:
198
+ self.traits_presentation_template = "Your traits: {{traits}}"
199
+ self.set_traits_presentation_template = False
200
+
201
+ def _initialize_dynamic_traits_function(
202
+ self,
203
+ dynamic_traits_function,
204
+ dynamic_traits_function_source_code,
205
+ dynamic_traits_function_name,
206
+ ) -> None:
207
+ """Initialize the dynamic traits function i.e., a function that returns a dictionary of traits based on the question."""
132
208
  # Deal with dynamic traits function
209
+ self.dynamic_traits_function = dynamic_traits_function
210
+
133
211
  if self.dynamic_traits_function:
134
212
  self.dynamic_traits_function_name = self.dynamic_traits_function.__name__
135
213
  self.has_dynamic_traits_function = True
@@ -142,7 +220,11 @@ class Agent(Base):
142
220
  dynamic_traits_function_name, dynamic_traits_function
143
221
  )
144
222
 
145
- # Deal with direct answer function
223
+ def _initialize_answer_question_directly(
224
+ self,
225
+ answer_question_directly_source_code,
226
+ answer_question_directly_function_name,
227
+ ) -> None:
146
228
  if answer_question_directly_source_code:
147
229
  self.answer_question_directly_function_name = (
148
230
  answer_question_directly_function_name
@@ -154,18 +236,56 @@ class Agent(Base):
154
236
  bound_method = types.MethodType(protected_method, self)
155
237
  setattr(self, "answer_question_directly", bound_method)
156
238
 
157
- self._check_dynamic_traits_function()
158
-
159
- self.current_question = None
160
-
239
+ def _initialize_traits_presentation_template(
240
+ self, traits_presentation_template
241
+ ) -> None:
161
242
  if traits_presentation_template is not None:
162
243
  self._traits_presentation_template = traits_presentation_template
163
244
  self.traits_presentation_template = traits_presentation_template
245
+ self.set_traits_presentation_template = True
164
246
  else:
165
247
  self.traits_presentation_template = "Your traits: {{traits}}"
248
+ self.set_traits_presentation_template = False
249
+
250
+ def duplicate(self) -> Agent:
251
+ """Return a duplicate of the agent.
252
+
253
+ >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5}, codebook = {'age': 'Their age is'})
254
+ >>> a2 = a.duplicate()
255
+ >>> a2 == a
256
+ True
257
+ >>> id(a) == id(a2)
258
+ False
259
+ >>> def f(self, question, scenario): return "I am a direct answer."
260
+ >>> a.add_direct_question_answering_method(f)
261
+ >>> hasattr(a, "answer_question_directly")
262
+ True
263
+ >>> a2 = a.duplicate()
264
+ >>> a2.answer_question_directly(None, None)
265
+ 'I am a direct answer.'
266
+
267
+ >>> a = Agent(traits = {'age': 10}, instruction = "Have fun!")
268
+ >>> a2 = a.duplicate()
269
+ >>> a2.instruction
270
+ 'Have fun!'
271
+ """
272
+ new_agent = Agent.from_dict(self.to_dict())
273
+ if hasattr(self, "answer_question_directly"):
274
+ answer_question_directly = self.answer_question_directly
275
+ newf = lambda self, question, scenario: answer_question_directly(
276
+ question, scenario
277
+ )
278
+ new_agent.add_direct_question_answering_method(newf)
279
+ if hasattr(self, "dynamic_traits_function"):
280
+ dynamic_traits_function = self.dynamic_traits_function
281
+ new_agent.dynamic_traits_function = dynamic_traits_function
282
+ return new_agent
166
283
 
167
284
  @property
168
285
  def agent_persona(self) -> Prompt:
286
+ """Return the agent persona template."""
287
+ from edsl.prompts.Prompt import Prompt
288
+
169
289
  return Prompt(text=self.traits_presentation_template)
170
290
 
171
291
  def prompt(self) -> str:
@@ -241,59 +361,111 @@ class Agent(Base):
241
361
  else:
242
362
  return self.dynamic_traits_function()
243
363
  else:
244
- return self._traits
245
-
246
- def _repr_html_(self):
247
- # d = self.to_dict(add_edsl_version=False)
248
- d = self.traits
249
- data = [[k, v] for k, v in d.items()]
250
- from tabulate import tabulate
364
+ return dict(self._traits)
365
+
366
+ @contextmanager
367
+ def modify_traits_context(self):
368
+ self._check_before_modifying_traits()
369
+ try:
370
+ yield
371
+ finally:
372
+ self._traits = AgentTraits(self._traits)
373
+
374
+ def _check_before_modifying_traits(self):
375
+ """Check before modifying traits."""
376
+ if self.has_dynamic_traits_function:
377
+ raise AgentErrors(
378
+ "You cannot modify the traits of an agent that has a dynamic traits function.",
379
+ "If you want to modify the traits, you should remove the dynamic traits function.",
380
+ )
251
381
 
252
- table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
253
- return f"<pre>{table}</pre>"
382
+ @traits.setter
383
+ def traits(self, traits: dict[str, str]):
384
+ with self.modify_traits_context():
385
+ self._traits = traits
386
+ # self._check_before_modifying_traits()
387
+ # self._traits = AgentTraits(traits)
254
388
 
255
389
  def rename(
256
- self, old_name_or_dict: Union[str, dict], new_name: Optional[str] = None
390
+ self,
391
+ old_name_or_dict: Union[str, dict[str, str]],
392
+ new_name: Optional[str] = None,
257
393
  ) -> Agent:
258
394
  """Rename a trait.
259
395
 
396
+ :param old_name_or_dict: The old name of the trait or a dictionary of old names and new names.
397
+ :param new_name: The new name of the trait.
398
+
260
399
  Example usage:
261
400
 
262
401
  >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
263
- >>> a.rename("age", "years") == Agent(traits = {'years': 10, 'hair': 'brown', 'height': 5.5})
402
+ >>> newa = a.rename("age", "years")
403
+ >>> newa == Agent(traits = {'years': 10, 'hair': 'brown', 'height': 5.5})
264
404
  True
265
405
 
266
- >>> a.rename({'years': 'smage'})
267
- Agent(traits = {'hair': 'brown', 'height': 5.5, 'smage': 10})
406
+ >>> newa.rename({'years': 'smage'}) == Agent(traits = {'smage': 10, 'hair': 'brown', 'height': 5.5})
407
+ True
268
408
 
269
409
  """
270
- if isinstance(old_name_or_dict, dict) and new_name is None:
271
- for old_name, new_name in old_name_or_dict.items():
272
- self = self._rename(old_name, new_name)
273
- return self
274
-
410
+ self._check_before_modifying_traits()
275
411
  if isinstance(old_name_or_dict, dict) and new_name:
276
412
  raise AgentErrors(
277
413
  f"You passed a dict: {old_name_or_dict} and a new name: {new_name}. You should pass only a dict."
278
414
  )
279
415
 
416
+ if isinstance(old_name_or_dict, dict) and new_name is None:
417
+ return self._rename_dict(old_name_or_dict)
418
+
280
419
  if isinstance(old_name_or_dict, str):
281
- self._rename(old_name_or_dict, new_name)
282
- return self
420
+ return self._rename(old_name_or_dict, new_name)
283
421
 
284
422
  raise AgentErrors("Something is not right with Agent renaming")
285
423
 
424
+ def _rename_dict(self, renaming_dict: dict[str, str]) -> Agent:
425
+ """
426
+ Internal method to rename traits using a dictionary.
427
+ The keys should all be old names and the values should all be new names.
428
+
429
+ Example usage:
430
+ >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
431
+ >>> a._rename_dict({"age": "years", "height": "feet"})
432
+ Agent(traits = {'years': 10, 'hair': 'brown', 'feet': 5.5})
433
+
434
+ """
435
+ try:
436
+ assert all(k in self.traits for k in renaming_dict.keys())
437
+ except AssertionError:
438
+ raise AgentErrors(
439
+ f"The trait(s) {set(renaming_dict.keys()) - set(self.traits.keys())} do not exist in the agent's traits, which are {self.traits}."
440
+ )
441
+ new_agent = self.duplicate()
442
+ new_agent.traits = {renaming_dict.get(k, k): v for k, v in self.traits.items()}
443
+ return new_agent
444
+
286
445
  def _rename(self, old_name: str, new_name: str) -> Agent:
287
446
  """Rename a trait.
288
447
 
289
448
  Example usage:
290
449
 
291
450
  >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
292
- >>> a.rename("age", "years") == Agent(traits = {'years': 10, 'hair': 'brown', 'height': 5.5})
293
- True
451
+ >>> a._rename(old_name="age", new_name="years")
452
+ Agent(traits = {'years': 10, 'hair': 'brown', 'height': 5.5})
453
+
294
454
  """
295
- self.traits[new_name] = self.traits.pop(old_name)
296
- return self
455
+ try:
456
+ assert old_name in self.traits
457
+ except AssertionError:
458
+ raise AgentErrors(
459
+ f"The trait '{old_name}' does not exist in the agent's traits, which are {self.traits}."
460
+ )
461
+ newagent = self.duplicate()
462
+ newagent.traits = {
463
+ new_name if k == old_name else k: v for k, v in self.traits.items()
464
+ }
465
+ newagent.codebook = {
466
+ new_name if k == old_name else k: v for k, v in self.codebook.items()
467
+ }
468
+ return newagent
297
469
 
298
470
  def __getitem__(self, key):
299
471
  """Allow for accessing traits using the bracket notation.
@@ -324,7 +496,7 @@ class Agent(Base):
324
496
 
325
497
  def add_direct_question_answering_method(
326
498
  self,
327
- method: Callable,
499
+ method: DirectAnswerMethod,
328
500
  validate_response: bool = False,
329
501
  translate_response: bool = False,
330
502
  ) -> None:
@@ -353,6 +525,12 @@ class Agent(Base):
353
525
  self.validate_response = validate_response
354
526
  self.translate_response = translate_response
355
527
 
528
+ # if not isinstance(method, DirectAnswerMethod):
529
+ # raise AgentDirectAnswerFunctionError(
530
+ # f"Method {method} does not match required signature. "
531
+ # "Must take (self, question, scenario) parameters."
532
+ # )
533
+
356
534
  signature = inspect.signature(method)
357
535
  for argument in ["question", "scenario", "self"]:
358
536
  if argument not in signature.parameters:
@@ -371,11 +549,11 @@ class Agent(Base):
371
549
  survey: Optional["Survey"] = None,
372
550
  scenario: Optional["Scenario"] = None,
373
551
  model: Optional["LanguageModel"] = None,
374
- debug: bool = False,
552
+ # debug: bool = False,
375
553
  memory_plan: Optional["MemoryPlan"] = None,
376
554
  current_answers: Optional[dict] = None,
377
555
  iteration: int = 1,
378
- sidecar_model=None,
556
+ # sidecar_model=None,
379
557
  raise_validation_errors: bool = True,
380
558
  ) -> "InvigilatorBase":
381
559
  """Create an Invigilator.
@@ -391,7 +569,9 @@ class Agent(Base):
391
569
  An invigator is an object that is responsible for administering a question to an agent and
392
570
  recording the responses.
393
571
  """
394
- from edsl import Model, Scenario
572
+ from edsl.language_models.registry import Model
573
+
574
+ from edsl.scenarios.Scenario import Scenario
395
575
 
396
576
  cache = cache
397
577
  self.current_question = question
@@ -402,12 +582,12 @@ class Agent(Base):
402
582
  scenario=scenario,
403
583
  survey=survey,
404
584
  model=model,
405
- debug=debug,
585
+ # debug=debug,
406
586
  memory_plan=memory_plan,
407
587
  current_answers=current_answers,
408
588
  iteration=iteration,
409
589
  cache=cache,
410
- sidecar_model=sidecar_model,
590
+ # sidecar_model=sidecar_model,
411
591
  raise_validation_errors=raise_validation_errors,
412
592
  )
413
593
  if hasattr(self, "validate_response"):
@@ -442,7 +622,7 @@ class Agent(Base):
442
622
 
443
623
  >>> a = Agent(traits = {})
444
624
  >>> a.add_direct_question_answering_method(lambda self, question, scenario: "I am a direct answer.")
445
- >>> from edsl import QuestionFreeText
625
+ >>> from edsl.questions.QuestionFreeText import QuestionFreeText
446
626
  >>> q = QuestionFreeText.example()
447
627
  >>> a.answer_question(question = q, cache = False).answer
448
628
  'I am a direct answer.'
@@ -457,7 +637,7 @@ class Agent(Base):
457
637
  scenario=scenario,
458
638
  survey=survey,
459
639
  model=model,
460
- debug=debug,
640
+ # debug=debug,
461
641
  memory_plan=memory_plan,
462
642
  current_answers=current_answers,
463
643
  iteration=iteration,
@@ -467,6 +647,25 @@ class Agent(Base):
467
647
 
468
648
  answer_question = sync_wrapper(async_answer_question)
469
649
 
650
+ def _get_invigilator_class(self, question: QuestionBase) -> Type[InvigilatorBase]:
651
+ """Get the invigilator class for a question.
652
+
653
+ This method returns the invigilator class that should be used to answer a question.
654
+ The invigilator class is determined by the type of question and the type of agent.
655
+ """
656
+ from edsl.agents.Invigilator import (
657
+ InvigilatorHuman,
658
+ InvigilatorFunctional,
659
+ InvigilatorAI,
660
+ )
661
+
662
+ if hasattr(question, "answer_question_directly"):
663
+ return InvigilatorFunctional
664
+ elif hasattr(self, "answer_question_directly"):
665
+ return InvigilatorHuman
666
+ else:
667
+ return InvigilatorAI
668
+
470
669
  def _create_invigilator(
471
670
  self,
472
671
  question: QuestionBase,
@@ -474,53 +673,32 @@ class Agent(Base):
474
673
  scenario: Optional[Scenario] = None,
475
674
  model: Optional[LanguageModel] = None,
476
675
  survey: Optional[Survey] = None,
477
- debug: bool = False,
676
+ # debug: bool = False,
478
677
  memory_plan: Optional[MemoryPlan] = None,
479
678
  current_answers: Optional[dict] = None,
480
679
  iteration: int = 0,
481
- sidecar_model=None,
680
+ # sidecar_model=None,
482
681
  raise_validation_errors: bool = True,
483
682
  ) -> "InvigilatorBase":
484
683
  """Create an Invigilator."""
485
- from edsl import Model
486
- from edsl import Scenario
684
+ from edsl.language_models.registry import Model
685
+ from edsl.scenarios.Scenario import Scenario
487
686
 
488
687
  model = model or Model()
489
688
  scenario = scenario or Scenario()
490
689
 
491
- from edsl.agents.Invigilator import (
492
- InvigilatorHuman,
493
- InvigilatorFunctional,
494
- InvigilatorAI,
495
- InvigilatorBase,
496
- )
497
-
498
690
  if cache is None:
499
691
  from edsl.data.Cache import Cache
500
692
 
501
693
  cache = Cache()
502
694
 
503
- if debug:
504
- raise NotImplementedError("Debug mode is not yet implemented.")
505
- # use the question's _simulate_answer method
506
- # invigilator_class = InvigilatorDebug
507
- elif hasattr(question, "answer_question_directly"):
508
- # It's a functional question and the answer only depends on the agent's traits & the scenario
509
- invigilator_class = InvigilatorFunctional
510
- elif hasattr(self, "answer_question_directly"):
511
- # this of the case where the agent has a method that can answer the question directly
512
- # this occurrs when 'answer_question_directly' has been given to the
513
- # which happens when the agent is created from an existing survey
514
- invigilator_class = InvigilatorHuman
515
- else:
516
- # this means an LLM agent will be used. This is the standard case.
517
- invigilator_class = InvigilatorAI
518
-
519
- if sidecar_model is not None:
520
- # this is the case when a 'simple' model is being used
521
- from edsl.agents.Invigilator import InvigilatorSidecar
695
+ invigilator_class = self._get_invigilator_class(question)
522
696
 
523
- invigilator_class = InvigilatorSidecar
697
+ # if sidecar_model is not None:
698
+ # # this is the case when a 'simple' model is being used
699
+ # # from edsl.agents.Invigilator import InvigilatorSidecar
700
+ # # invigilator_class = InvigilatorSidecar
701
+ # raise DeprecationWarning("Sidecar models are deprecated.")
524
702
 
525
703
  invigilator = invigilator_class(
526
704
  self,
@@ -532,7 +710,7 @@ class Agent(Base):
532
710
  current_answers=current_answers,
533
711
  iteration=iteration,
534
712
  cache=cache,
535
- sidecar_model=sidecar_model,
713
+ # sidecar_model=sidecar_model,
536
714
  raise_validation_errors=raise_validation_errors,
537
715
  )
538
716
  return invigilator
@@ -540,14 +718,16 @@ class Agent(Base):
540
718
  def select(self, *traits: str) -> Agent:
541
719
  """Selects agents with only the references traits
542
720
 
543
- >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
721
+ >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5}, codebook = {'age': 'Their age is'})
722
+ >>> a
723
+ Agent(traits = {'age': 10, 'hair': 'brown', 'height': 5.5}, codebook = {'age': 'Their age is'})
544
724
 
545
725
 
546
726
  >>> a.select("age", "height")
547
- Agent(traits = {'age': 10, 'height': 5.5})
727
+ Agent(traits = {'age': 10, 'height': 5.5}, codebook = {'age': 'Their age is'})
548
728
 
549
- >>> a.select("age")
550
- Agent(traits = {'age': 10})
729
+ >>> a.select("height")
730
+ Agent(traits = {'height': 5.5})
551
731
 
552
732
  """
553
733
 
@@ -556,7 +736,17 @@ class Agent(Base):
556
736
  else:
557
737
  traits_to_select = list(traits)
558
738
 
559
- return Agent(traits={trait: self.traits[trait] for trait in traits_to_select})
739
+ def _remove_none(d):
740
+ return {k: v for k, v in d.items() if v is not None}
741
+
742
+ newagent = self.duplicate()
743
+ newagent.traits = {
744
+ trait: self.traits.get(trait, None) for trait in traits_to_select
745
+ }
746
+ newagent.codebook = _remove_none(
747
+ {trait: self.codebook.get(trait, None) for trait in traits_to_select}
748
+ )
749
+ return newagent
560
750
 
561
751
  def __add__(self, other_agent: Optional[Agent] = None) -> Agent:
562
752
  """
@@ -575,6 +765,10 @@ class Agent(Base):
575
765
  ...
576
766
  edsl.exceptions.agents.AgentCombinationError: The agents have overlapping traits: {'age'}.
577
767
  ...
768
+ >>> a1 = Agent(traits = {"age": 10}, codebook = {"age": "Their age is"})
769
+ >>> a2 = Agent(traits = {"height": 5.5}, codebook = {"height": "Their height is"})
770
+ >>> a1 + a2
771
+ Agent(traits = {'age': 10, 'height': 5.5}, codebook = {'age': 'Their age is', 'height': 'Their height is'})
578
772
  """
579
773
  if other_agent is None:
580
774
  return self
@@ -583,9 +777,14 @@ class Agent(Base):
583
777
  f"The agents have overlapping traits: {common_traits}."
584
778
  )
585
779
  else:
586
- new_agent = Agent(traits=copy.deepcopy(self.traits))
587
- new_agent.traits.update(other_agent.traits)
588
- return new_agent
780
+ new_codebook = copy.deepcopy(self.codebook) | copy.deepcopy(
781
+ other_agent.codebook
782
+ )
783
+ d = self.traits | other_agent.traits
784
+ newagent = self.duplicate()
785
+ newagent.traits = d
786
+ newagent.codebook = new_codebook
787
+ return newagent
589
788
 
590
789
  def __eq__(self, other: Agent) -> bool:
591
790
  """Check if two agents are equal.
@@ -602,7 +801,11 @@ class Agent(Base):
602
801
  return self.data == other.data
603
802
 
604
803
  def __getattr__(self, name):
605
- # This will be called only if 'name' is not found in the usual places
804
+ """
805
+ >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
806
+ >>> a.age
807
+ 10
808
+ """
606
809
  if name == "has_dynamic_traits_function":
607
810
  return self.has_dynamic_traits_function
608
811
 
@@ -624,12 +827,6 @@ class Agent(Base):
624
827
  if "_traits" not in self.__dict__:
625
828
  self._traits = {}
626
829
 
627
- def print(self) -> None:
628
- from rich import print_json
629
- import json
630
-
631
- print_json(json.dumps(self.to_dict()))
632
-
633
830
  def __repr__(self) -> str:
634
831
  """Return representation of Agent."""
635
832
  class_name = self.__class__.__name__
@@ -640,14 +837,6 @@ class Agent(Base):
640
837
  ]
641
838
  return f"{class_name}({', '.join(items)})"
642
839
 
643
- # def _repr_html_(self):
644
- # from edsl.utilities.utilities import data_to_html
645
-
646
- # return data_to_html(self.to_dict())
647
-
648
- #######################
649
- # SERIALIZATION METHODS
650
- #######################
651
840
  @property
652
841
  def data(self) -> dict:
653
842
  """Format the data for serialization.
@@ -678,9 +867,9 @@ class Agent(Base):
678
867
  if dynamic_traits_func:
679
868
  func = inspect.getsource(dynamic_traits_func)
680
869
  raw_data["dynamic_traits_function_source_code"] = func
681
- raw_data[
682
- "dynamic_traits_function_name"
683
- ] = self.dynamic_traits_function_name
870
+ raw_data["dynamic_traits_function_name"] = (
871
+ self.dynamic_traits_function_name
872
+ )
684
873
  if hasattr(self, "answer_question_directly"):
685
874
  raw_data.pop(
686
875
  "answer_question_directly", None
@@ -694,18 +883,23 @@ class Agent(Base):
694
883
  raw_data["answer_question_directly_source_code"] = inspect.getsource(
695
884
  answer_question_directly_func
696
885
  )
697
- raw_data[
698
- "answer_question_directly_function_name"
699
- ] = self.answer_question_directly_function_name
886
+ raw_data["answer_question_directly_function_name"] = (
887
+ self.answer_question_directly_function_name
888
+ )
889
+ raw_data["traits"] = dict(raw_data["traits"])
700
890
 
701
891
  return raw_data
702
892
 
703
893
  def __hash__(self) -> int:
894
+ """Return a hash of the agent.
895
+
896
+ >>> hash(Agent.example())
897
+ 2067581884874391607
898
+ """
704
899
  from edsl.utilities.utilities import dict_hash
705
900
 
706
901
  return dict_hash(self.to_dict(add_edsl_version=False))
707
902
 
708
- # @add_edsl_version
709
903
  def to_dict(self, add_edsl_version=True) -> dict[str, Union[dict, bool]]:
710
904
  """Serialize to a dictionary with EDSL info.
711
905
 
@@ -713,9 +907,22 @@ class Agent(Base):
713
907
 
714
908
  >>> a = Agent(name = "Steve", traits = {"age": 10, "hair": "brown", "height": 5.5})
715
909
  >>> a.to_dict()
716
- {'name': 'Steve', 'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}, 'edsl_version': '...', 'edsl_class_name': 'Agent'}
910
+ {'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}, 'name': 'Steve', 'edsl_version': '...', 'edsl_class_name': 'Agent'}
911
+
912
+ >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5}, instruction = "Have fun.")
913
+ >>> a.to_dict()
914
+ {'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}, 'instruction': 'Have fun.', 'edsl_version': '...', 'edsl_class_name': 'Agent'}
717
915
  """
718
- d = copy.deepcopy(self.data)
916
+ d = {}
917
+ d["traits"] = copy.deepcopy(self.traits)
918
+ if self.name:
919
+ d["name"] = self.name
920
+ if self.set_instructions:
921
+ d["instruction"] = self.instruction
922
+ if self.set_traits_presentation_template:
923
+ d["traits_presentation_template"] = self.traits_presentation_template
924
+ if self.codebook:
925
+ d["codebook"] = self.codebook
719
926
  if add_edsl_version:
720
927
  from edsl import __version__
721
928
 
@@ -735,7 +942,18 @@ class Agent(Base):
735
942
  Agent(name = \"""Steve\""", traits = {'age': 10, 'hair': 'brown', 'height': 5.5})
736
943
 
737
944
  """
738
- return cls(**agent_dict)
945
+ if "traits" in agent_dict:
946
+ return cls(
947
+ traits=agent_dict["traits"],
948
+ name=agent_dict.get("name", None),
949
+ instruction=agent_dict.get("instruction", None),
950
+ traits_presentation_template=agent_dict.get(
951
+ "traits_presentation_template", None
952
+ ),
953
+ codebook=agent_dict.get("codebook", None),
954
+ )
955
+ else: # old-style agent - we used to only store the traits
956
+ return cls(**agent_dict)
739
957
 
740
958
  def _table(self) -> tuple[dict, list]:
741
959
  """Prepare generic table data."""
@@ -746,10 +964,15 @@ class Agent(Base):
746
964
  return table_data, column_names
747
965
 
748
966
  def add_trait(self, trait_name_or_dict: str, value: Optional[Any] = None) -> Agent:
749
- """Adds a trait to an agent and returns that agent"""
967
+ """Adds a trait to an agent and returns that agent
968
+ >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
969
+ >>> a.add_trait("weight", 150)
970
+ Agent(traits = {'age': 10, 'hair': 'brown', 'height': 5.5, 'weight': 150})
971
+ """
750
972
  if isinstance(trait_name_or_dict, dict) and value is None:
751
- self.traits.update(trait_name_or_dict)
752
- return self
973
+ newagent = self.duplicate()
974
+ newagent.traits = {**self.traits, **trait_name_or_dict}
975
+ return newagent
753
976
 
754
977
  if isinstance(trait_name_or_dict, dict) and value:
755
978
  raise AgentErrors(
@@ -757,9 +980,9 @@ class Agent(Base):
757
980
  )
758
981
 
759
982
  if isinstance(trait_name_or_dict, str):
760
- trait = trait_name_or_dict
761
- self.traits[trait] = value
762
- return self
983
+ newagent = self.duplicate()
984
+ newagent.traits = {**self.traits, **{trait_name_or_dict: value}}
985
+ return newagent
763
986
 
764
987
  raise AgentErrors("Something is not right with adding a trait to an Agent")
765
988
 
@@ -772,8 +995,9 @@ class Agent(Base):
772
995
  >>> a.remove_trait("age")
773
996
  Agent(traits = {'hair': 'brown', 'height': 5.5})
774
997
  """
775
- _ = self.traits.pop(trait)
776
- return self
998
+ newagent = self.duplicate()
999
+ newagent.traits = {k: v for k, v in self.traits.items() if k != trait}
1000
+ return newagent
777
1001
 
778
1002
  def translate_traits(self, values_codebook: dict) -> Agent:
779
1003
  """Translate traits to a new codebook.
@@ -784,32 +1008,15 @@ class Agent(Base):
784
1008
 
785
1009
  :param values_codebook: The new codebook.
786
1010
  """
1011
+ new_traits = {}
787
1012
  for key, value in self.traits.items():
788
1013
  if key in values_codebook:
789
- self.traits[key] = values_codebook[key][value]
790
- return self
791
-
792
- def rich_print(self):
793
- """Display an object as a rich table.
794
-
795
- Example usage:
796
-
797
- >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
798
- >>> a.rich_print()
799
- <rich.table.Table object at ...>
800
- """
801
- from rich.table import Table
802
-
803
- table_data, column_names = self._table()
804
- table = Table(title=f"{self.__class__.__name__} Attributes")
805
- for column in column_names:
806
- table.add_column(column, style="bold")
807
-
808
- for row in table_data:
809
- row_data = [row[column] for column in column_names]
810
- table.add_row(*row_data)
811
-
812
- return table
1014
+ new_traits[key] = values_codebook[key].get(value, value)
1015
+ else:
1016
+ new_traits[key] = value
1017
+ newagent = self.duplicate()
1018
+ newagent.traits = new_traits
1019
+ return newagent
813
1020
 
814
1021
  @classmethod
815
1022
  def example(cls, randomize: bool = False) -> Agent:
@@ -817,6 +1024,9 @@ class Agent(Base):
817
1024
  Returns an example Agent instance.
818
1025
 
819
1026
  :param randomize: If True, adds a random string to the value of an example key.
1027
+
1028
+ >>> Agent.example()
1029
+ Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5})
820
1030
  """
821
1031
  addition = "" if not randomize else str(uuid4())
822
1032
  return cls(traits={"age": 22, "hair": f"brown{addition}", "height": 5.5})
@@ -828,10 +1038,12 @@ class Agent(Base):
828
1038
 
829
1039
  >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
830
1040
  >>> print(a.code())
831
- from edsl import Agent
1041
+ from edsl.agents.Agent import Agent
832
1042
  agent = Agent(traits={'age': 10, 'hair': 'brown', 'height': 5.5})
833
1043
  """
834
- return f"from edsl import Agent\nagent = Agent(traits={self.traits})"
1044
+ return (
1045
+ f"from edsl.agents.Agent import Agent\nagent = Agent(traits={self.traits})"
1046
+ )
835
1047
 
836
1048
 
837
1049
  def main():