edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. edsl/Base.py +3 -9
  2. edsl/__init__.py +3 -8
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -40
  5. edsl/agents/AgentList.py +0 -43
  6. edsl/agents/Invigilator.py +219 -135
  7. edsl/agents/InvigilatorBase.py +59 -148
  8. edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
  9. edsl/agents/__init__.py +0 -1
  10. edsl/config.py +56 -47
  11. edsl/coop/coop.py +7 -50
  12. edsl/data/Cache.py +1 -35
  13. edsl/data_transfer_models.py +38 -73
  14. edsl/enums.py +0 -4
  15. edsl/exceptions/language_models.py +1 -25
  16. edsl/exceptions/questions.py +5 -62
  17. edsl/exceptions/results.py +0 -4
  18. edsl/inference_services/AnthropicService.py +11 -13
  19. edsl/inference_services/AwsBedrock.py +17 -19
  20. edsl/inference_services/AzureAI.py +20 -37
  21. edsl/inference_services/GoogleService.py +12 -16
  22. edsl/inference_services/GroqService.py +0 -2
  23. edsl/inference_services/InferenceServiceABC.py +3 -58
  24. edsl/inference_services/OpenAIService.py +54 -48
  25. edsl/inference_services/models_available_cache.py +6 -0
  26. edsl/inference_services/registry.py +0 -6
  27. edsl/jobs/Answers.py +12 -10
  28. edsl/jobs/Jobs.py +21 -36
  29. edsl/jobs/buckets/BucketCollection.py +15 -24
  30. edsl/jobs/buckets/TokenBucket.py +14 -93
  31. edsl/jobs/interviews/Interview.py +78 -366
  32. edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
  33. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
  34. edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
  35. edsl/jobs/interviews/retry_management.py +37 -0
  36. edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
  37. edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
  38. edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
  39. edsl/jobs/tasks/TaskHistory.py +213 -148
  40. edsl/language_models/LanguageModel.py +156 -261
  41. edsl/language_models/ModelList.py +2 -2
  42. edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
  43. edsl/language_models/registry.py +6 -23
  44. edsl/language_models/repair.py +19 -0
  45. edsl/prompts/Prompt.py +2 -52
  46. edsl/questions/AnswerValidatorMixin.py +26 -23
  47. edsl/questions/QuestionBase.py +249 -329
  48. edsl/questions/QuestionBudget.py +41 -99
  49. edsl/questions/QuestionCheckBox.py +35 -227
  50. edsl/questions/QuestionExtract.py +27 -98
  51. edsl/questions/QuestionFreeText.py +29 -52
  52. edsl/questions/QuestionFunctional.py +0 -7
  53. edsl/questions/QuestionList.py +22 -141
  54. edsl/questions/QuestionMultipleChoice.py +65 -159
  55. edsl/questions/QuestionNumerical.py +46 -88
  56. edsl/questions/QuestionRank.py +24 -182
  57. edsl/questions/RegisterQuestionsMeta.py +12 -31
  58. edsl/questions/__init__.py +4 -3
  59. edsl/questions/derived/QuestionLikertFive.py +5 -10
  60. edsl/questions/derived/QuestionLinearScale.py +2 -15
  61. edsl/questions/derived/QuestionTopK.py +1 -10
  62. edsl/questions/derived/QuestionYesNo.py +3 -24
  63. edsl/questions/descriptors.py +7 -43
  64. edsl/questions/question_registry.py +2 -6
  65. edsl/results/Dataset.py +0 -20
  66. edsl/results/DatasetExportMixin.py +48 -46
  67. edsl/results/Result.py +5 -32
  68. edsl/results/Results.py +46 -135
  69. edsl/results/ResultsDBMixin.py +3 -3
  70. edsl/scenarios/FileStore.py +10 -71
  71. edsl/scenarios/Scenario.py +25 -96
  72. edsl/scenarios/ScenarioImageMixin.py +2 -2
  73. edsl/scenarios/ScenarioList.py +39 -361
  74. edsl/scenarios/ScenarioListExportMixin.py +0 -9
  75. edsl/scenarios/ScenarioListPdfMixin.py +4 -150
  76. edsl/study/SnapShot.py +1 -8
  77. edsl/study/Study.py +0 -32
  78. edsl/surveys/Rule.py +1 -10
  79. edsl/surveys/RuleCollection.py +5 -21
  80. edsl/surveys/Survey.py +310 -636
  81. edsl/surveys/SurveyExportMixin.py +9 -71
  82. edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
  83. edsl/surveys/SurveyQualtricsImport.py +4 -75
  84. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  85. edsl/utilities/utilities.py +1 -9
  86. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
  87. edsl-0.1.33.dev1.dist-info/RECORD +209 -0
  88. edsl/TemplateLoader.py +0 -24
  89. edsl/auto/AutoStudy.py +0 -117
  90. edsl/auto/StageBase.py +0 -230
  91. edsl/auto/StageGenerateSurvey.py +0 -178
  92. edsl/auto/StageLabelQuestions.py +0 -125
  93. edsl/auto/StagePersona.py +0 -61
  94. edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
  95. edsl/auto/StagePersonaDimensionValues.py +0 -74
  96. edsl/auto/StagePersonaDimensions.py +0 -69
  97. edsl/auto/StageQuestions.py +0 -73
  98. edsl/auto/SurveyCreatorPipeline.py +0 -21
  99. edsl/auto/utilities.py +0 -224
  100. edsl/coop/PriceFetcher.py +0 -58
  101. edsl/inference_services/MistralAIService.py +0 -120
  102. edsl/inference_services/TestService.py +0 -80
  103. edsl/inference_services/TogetherAIService.py +0 -170
  104. edsl/jobs/FailedQuestion.py +0 -78
  105. edsl/jobs/runners/JobsRunnerStatus.py +0 -331
  106. edsl/language_models/fake_openai_call.py +0 -15
  107. edsl/language_models/fake_openai_service.py +0 -61
  108. edsl/language_models/utilities.py +0 -61
  109. edsl/questions/QuestionBaseGenMixin.py +0 -133
  110. edsl/questions/QuestionBasePromptsMixin.py +0 -266
  111. edsl/questions/Quick.py +0 -41
  112. edsl/questions/ResponseValidatorABC.py +0 -170
  113. edsl/questions/decorators.py +0 -21
  114. edsl/questions/prompt_templates/question_budget.jinja +0 -13
  115. edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
  116. edsl/questions/prompt_templates/question_extract.jinja +0 -11
  117. edsl/questions/prompt_templates/question_free_text.jinja +0 -3
  118. edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
  119. edsl/questions/prompt_templates/question_list.jinja +0 -17
  120. edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
  121. edsl/questions/prompt_templates/question_numerical.jinja +0 -37
  122. edsl/questions/templates/__init__.py +0 -0
  123. edsl/questions/templates/budget/__init__.py +0 -0
  124. edsl/questions/templates/budget/answering_instructions.jinja +0 -7
  125. edsl/questions/templates/budget/question_presentation.jinja +0 -7
  126. edsl/questions/templates/checkbox/__init__.py +0 -0
  127. edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
  128. edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
  129. edsl/questions/templates/extract/__init__.py +0 -0
  130. edsl/questions/templates/extract/answering_instructions.jinja +0 -7
  131. edsl/questions/templates/extract/question_presentation.jinja +0 -1
  132. edsl/questions/templates/free_text/__init__.py +0 -0
  133. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  134. edsl/questions/templates/free_text/question_presentation.jinja +0 -1
  135. edsl/questions/templates/likert_five/__init__.py +0 -0
  136. edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
  137. edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
  138. edsl/questions/templates/linear_scale/__init__.py +0 -0
  139. edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
  140. edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
  141. edsl/questions/templates/list/__init__.py +0 -0
  142. edsl/questions/templates/list/answering_instructions.jinja +0 -4
  143. edsl/questions/templates/list/question_presentation.jinja +0 -5
  144. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  145. edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
  146. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  147. edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
  148. edsl/questions/templates/numerical/__init__.py +0 -0
  149. edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
  150. edsl/questions/templates/numerical/question_presentation.jinja +0 -7
  151. edsl/questions/templates/rank/__init__.py +0 -0
  152. edsl/questions/templates/rank/answering_instructions.jinja +0 -11
  153. edsl/questions/templates/rank/question_presentation.jinja +0 -15
  154. edsl/questions/templates/top_k/__init__.py +0 -0
  155. edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
  156. edsl/questions/templates/top_k/question_presentation.jinja +0 -22
  157. edsl/questions/templates/yes_no/__init__.py +0 -0
  158. edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
  159. edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
  160. edsl/results/DatasetTree.py +0 -145
  161. edsl/results/Selector.py +0 -118
  162. edsl/results/tree_explore.py +0 -115
  163. edsl/surveys/instructions/ChangeInstruction.py +0 -47
  164. edsl/surveys/instructions/Instruction.py +0 -34
  165. edsl/surveys/instructions/InstructionCollection.py +0 -77
  166. edsl/surveys/instructions/__init__.py +0 -0
  167. edsl/templates/error_reporting/base.html +0 -24
  168. edsl/templates/error_reporting/exceptions_by_model.html +0 -35
  169. edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
  170. edsl/templates/error_reporting/exceptions_by_type.html +0 -17
  171. edsl/templates/error_reporting/interview_details.html +0 -116
  172. edsl/templates/error_reporting/interviews.html +0 -10
  173. edsl/templates/error_reporting/overview.html +0 -5
  174. edsl/templates/error_reporting/performance_plot.html +0 -2
  175. edsl/templates/error_reporting/report.css +0 -74
  176. edsl/templates/error_reporting/report.html +0 -118
  177. edsl/templates/error_reporting/report.js +0 -25
  178. edsl-0.1.33.dist-info/RECORD +0 -295
  179. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
  180. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
edsl/Base.py CHANGED
@@ -47,27 +47,21 @@ class PersistenceMixin:
47
47
  self,
48
48
  description: Optional[str] = None,
49
49
  visibility: Optional[str] = "unlisted",
50
- expected_parrot_url: Optional[str] = None,
51
50
  ):
52
51
  """Post the object to coop."""
53
52
  from edsl.coop import Coop
54
53
 
55
- c = Coop(url=expected_parrot_url)
54
+ c = Coop()
56
55
  return c.create(self, description, visibility)
57
56
 
58
57
  @classmethod
59
- def pull(
60
- cls,
61
- uuid: Optional[Union[str, UUID]] = None,
62
- url: Optional[str] = None,
63
- expected_parrot_url: Optional[str] = None,
64
- ):
58
+ def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
65
59
  """Pull the object from coop."""
66
60
  from edsl.coop import Coop
67
61
  from edsl.coop.utils import ObjectRegistry
68
62
 
69
63
  object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
70
- coop = Coop(url=expected_parrot_url)
64
+ coop = Coop()
71
65
  return coop.get(uuid, url, object_type)
72
66
 
73
67
  @classmethod
edsl/__init__.py CHANGED
@@ -8,10 +8,9 @@ from edsl.__version__ import __version__
8
8
  from edsl.config import Config, CONFIG
9
9
  from edsl.agents.Agent import Agent
10
10
  from edsl.agents.AgentList import AgentList
11
-
12
11
  from edsl.questions import QuestionBase
13
- from edsl.questions.question_registry import Question
14
12
  from edsl.questions import QuestionMultipleChoice
13
+ from edsl.questions import QuestionBudget
15
14
  from edsl.questions import QuestionCheckBox
16
15
  from edsl.questions import QuestionExtract
17
16
  from edsl.questions import QuestionFreeText
@@ -20,11 +19,10 @@ from edsl.questions import QuestionLikertFive
20
19
  from edsl.questions import QuestionList
21
20
  from edsl.questions import QuestionLinearScale
22
21
  from edsl.questions import QuestionNumerical
23
- from edsl.questions import QuestionYesNo
24
- from edsl.questions import QuestionBudget
25
22
  from edsl.questions import QuestionRank
26
23
  from edsl.questions import QuestionTopK
27
-
24
+ from edsl.questions import QuestionYesNo
25
+ from edsl.questions.question_registry import Question
28
26
  from edsl.scenarios import Scenario
29
27
  from edsl.scenarios import ScenarioList
30
28
 
@@ -42,6 +40,3 @@ from edsl.notebooks.Notebook import Notebook
42
40
  from edsl.study.Study import Study
43
41
  from edsl.conjure.Conjure import Conjure
44
42
  from edsl.coop.coop import Coop
45
-
46
- from edsl.surveys.instructions.Instruction import Instruction
47
- from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.33"
1
+ __version__ = "0.1.33.dev1"
edsl/agents/Agent.py CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
  import copy
5
5
  import inspect
6
6
  import types
7
- from typing import Callable, Optional, Union, Any
7
+ from typing import Callable, Optional, Union
8
8
  from uuid import uuid4
9
9
  from edsl.Base import Base
10
10
 
@@ -228,12 +228,7 @@ class Agent(Base):
228
228
  if hasattr(self, "answer_question_directly"):
229
229
  delattr(self, "answer_question_directly")
230
230
 
231
- def add_direct_question_answering_method(
232
- self,
233
- method: Callable,
234
- validate_response: bool = False,
235
- translate_response: bool = False,
236
- ) -> None:
231
+ def add_direct_question_answering_method(self, method: Callable) -> None:
237
232
  """Add a method to the agent that can answer a particular question type.
238
233
 
239
234
  :param method: A method that can answer a question directly.
@@ -254,9 +249,6 @@ class Agent(Base):
254
249
  )
255
250
  # print("Warning: overwriting existing answer_question_directly method")
256
251
 
257
- self.validate_response = validate_response
258
- self.translate_response = translate_response
259
-
260
252
  signature = inspect.signature(method)
261
253
  for argument in ["question", "scenario", "self"]:
262
254
  if argument not in signature.parameters:
@@ -280,7 +272,6 @@ class Agent(Base):
280
272
  current_answers: Optional[dict] = None,
281
273
  iteration: int = 1,
282
274
  sidecar_model=None,
283
- raise_validation_errors: bool = True,
284
275
  ) -> "InvigilatorBase":
285
276
  """Create an Invigilator.
286
277
 
@@ -312,12 +303,7 @@ class Agent(Base):
312
303
  iteration=iteration,
313
304
  cache=cache,
314
305
  sidecar_model=sidecar_model,
315
- raise_validation_errors=raise_validation_errors,
316
306
  )
317
- if hasattr(self, "validate_response"):
318
- invigilator.validate_response = self.validate_response
319
- if hasattr(self, "translate_response"):
320
- invigilator.translate_response = self.translate_response
321
307
  return invigilator
322
308
 
323
309
  async def async_answer_question(
@@ -348,8 +334,8 @@ class Agent(Base):
348
334
  >>> a.add_direct_question_answering_method(lambda self, question, scenario: "I am a direct answer.")
349
335
  >>> from edsl import QuestionFreeText
350
336
  >>> q = QuestionFreeText.example()
351
- >>> a.answer_question(question = q, cache = False).answer
352
- 'I am a direct answer.'
337
+ >>> a.answer_question(question = q, cache = False)
338
+ {'answer': 'I am a direct answer.', 'comment': 'This is a real survey response from a human.', ...}
353
339
 
354
340
  This is a function where an agent returns an answer to a particular question.
355
341
  However, there are several different ways an agent can answer a question, so the
@@ -383,7 +369,6 @@ class Agent(Base):
383
369
  current_answers: Optional[dict] = None,
384
370
  iteration: int = 0,
385
371
  sidecar_model=None,
386
- raise_validation_errors: bool = True,
387
372
  ) -> "InvigilatorBase":
388
373
  """Create an Invigilator."""
389
374
  from edsl import Model
@@ -393,6 +378,7 @@ class Agent(Base):
393
378
  scenario = scenario or Scenario()
394
379
 
395
380
  from edsl.agents.Invigilator import (
381
+ InvigilatorDebug,
396
382
  InvigilatorHuman,
397
383
  InvigilatorFunctional,
398
384
  InvigilatorAI,
@@ -405,9 +391,8 @@ class Agent(Base):
405
391
  cache = Cache()
406
392
 
407
393
  if debug:
408
- raise NotImplementedError("Debug mode is not yet implemented.")
409
394
  # use the question's _simulate_answer method
410
- # invigilator_class = InvigilatorDebug
395
+ invigilator_class = InvigilatorDebug
411
396
  elif hasattr(question, "answer_question_directly"):
412
397
  # It's a functional question and the answer only depends on the agent's traits & the scenario
413
398
  invigilator_class = InvigilatorFunctional
@@ -437,7 +422,6 @@ class Agent(Base):
437
422
  iteration=iteration,
438
423
  cache=cache,
439
424
  sidecar_model=sidecar_model,
440
- raise_validation_errors=raise_validation_errors,
441
425
  )
442
426
  return invigilator
443
427
 
@@ -513,8 +497,8 @@ class Agent(Base):
513
497
  if name == "has_dynamic_traits_function":
514
498
  return self.has_dynamic_traits_function
515
499
 
516
- if name in self._traits:
517
- return self._traits[name]
500
+ if name in self.traits:
501
+ return self.traits[name]
518
502
  raise AttributeError(
519
503
  f"'{type(self).__name__}' object has no attribute '{name}'"
520
504
  )
@@ -656,22 +640,6 @@ class Agent(Base):
656
640
  column_names = ["Attribute", "Value"]
657
641
  return table_data, column_names
658
642
 
659
- def add_trait(self, trait_name_or_dict: str, value: Optional[Any] = None) -> Agent:
660
- """Adds a trait to an agent and returns that agent"""
661
- if isinstance(trait_name_or_dict, dict) and value is None:
662
- self.traits.update(trait_name_or_dict)
663
- return self
664
-
665
- if isinstance(trait_name_or_dict, dict) and value:
666
- raise ValueError(f"You passed a dict: {trait_name_or_dict}")
667
-
668
- if isinstance(trait_name_or_dict, str):
669
- trait = trait_name_or_dict
670
- self.traits[trait] = value
671
- return self
672
-
673
- raise Exception("Something is not right with adding")
674
-
675
643
  def remove_trait(self, trait: str) -> Agent:
676
644
  """Remove a trait from the agent.
677
645
 
edsl/agents/AgentList.py CHANGED
@@ -21,12 +21,6 @@ from simpleeval import EvalWithCompoundTypes
21
21
  from edsl.Base import Base
22
22
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
23
23
 
24
- from collections.abc import Iterable
25
-
26
-
27
- def is_iterable(obj):
28
- return isinstance(obj, Iterable)
29
-
30
24
 
31
25
  class AgentList(UserList, Base):
32
26
  """A list of Agents."""
@@ -117,13 +111,6 @@ class AgentList(UserList, Base):
117
111
 
118
112
  return AgentList(new_data)
119
113
 
120
- @property
121
- def all_traits(self):
122
- d = {}
123
- for agent in self:
124
- d.update(agent.traits)
125
- return list(d.keys())
126
-
127
114
  @classmethod
128
115
  def from_csv(cls, file_path: str):
129
116
  """Load AgentList from a CSV file.
@@ -172,36 +159,6 @@ class AgentList(UserList, Base):
172
159
  _ = agent.remove_trait(trait)
173
160
  return self
174
161
 
175
- def add_trait(self, trait, values):
176
- """Adds a new trait to every agent, with values taken from values.
177
-
178
- :param trait: The name of the trait.
179
- :param values: The valeues(s) of the trait. If a single value is passed, it is used for all agents.
180
-
181
- >>> al = AgentList.example()
182
- >>> al.add_trait('new_trait', 1)
183
- AgentList([Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1}), Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1})])
184
- >>> al.select('new_trait').to_scenario_list().to_list()
185
- [1, 1]
186
- >>> al.add_trait('new_trait', [1, 2, 3])
187
- Traceback (most recent call last):
188
- ...
189
- ValueError: The passed values have to be the same length as the agent list.
190
- """
191
- if not is_iterable(values):
192
- value = values
193
- for agent in self.data:
194
- agent.add_trait(trait, value)
195
- return self
196
-
197
- if len(values) != len(self):
198
- raise ValueError(
199
- "The passed values have to be the same length as the agent list."
200
- )
201
- for agent, value in zip(self.data, values):
202
- agent.add_trait(trait, value)
203
- return self
204
-
205
162
  @staticmethod
206
163
  def get_codebook(file_path: str):
207
164
  """Return the codebook for a CSV file.
edsl/agents/Invigilator.py CHANGED
@@ -1,169 +1,252 @@
1
1
  """Module for creating Invigilators, which are objects to administer a question to an Agent."""
2
2
 
3
+ import json
3
4
  from typing import Dict, Any, Optional
4
5
 
6
+ from edsl.exceptions import AgentRespondedWithBadJSONError
5
7
  from edsl.prompts.Prompt import Prompt
6
8
  from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
7
9
  from edsl.prompts.registry import get_classes as prompt_lookup
8
- from edsl.exceptions.questions import QuestionAnswerValidationError
9
- from edsl.agents.InvigilatorBase import InvigilatorBase
10
- from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
11
- from edsl.agents.PromptConstructor import PromptConstructor
12
-
10
+ from edsl.data_transfer_models import AgentResponseDict
11
+ from edsl.exceptions.agents import FailedTaskException
12
+ from edsl.agents.PromptConstructionMixin import PromptConstructorMixin
13
13
 
14
- class NotApplicable(str):
15
- def __new__(cls):
16
- instance = super().__new__(cls, "Not Applicable")
17
- instance.literal = "Not Applicable"
18
- return instance
14
+ from edsl.agents.InvigilatorBase import InvigilatorBase
19
15
 
20
16
 
21
- class InvigilatorAI(InvigilatorBase):
17
+ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
22
18
  """An invigilator that uses an AI model to answer questions."""
23
19
 
24
- def get_prompts(self) -> Dict[str, Prompt]:
25
- """Return the prompts used."""
26
- return self.prompt_constructor.get_prompts()
27
-
28
20
  async def async_answer_question(self) -> AgentResponseDict:
29
21
  """Answer a question using the AI model.
30
22
 
31
23
  >>> i = InvigilatorAI.example()
32
24
  >>> i.answer_question()
33
- {'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
25
+ {'message': '{"answer": "SPAM!"}'}
34
26
  """
35
- prompts = self.get_prompts()
36
- params = {
37
- "user_prompt": prompts["user_prompt"].text,
38
- "system_prompt": prompts["system_prompt"].text,
27
+ params = self.get_prompts() | {"iteration": self.iteration}
28
+ raw_response = await self.async_get_response(**params)
29
+ data = {
30
+ "agent": self.agent,
31
+ "question": self.question,
32
+ "scenario": self.scenario,
33
+ "raw_response": raw_response,
34
+ "raw_model_response": raw_response["raw_model_response"],
39
35
  }
40
- if "encoded_image" in prompts:
41
- params["encoded_image"] = prompts["encoded_image"]
42
-
43
- params.update({"iteration": self.iteration, "cache": self.cache})
44
-
45
- agent_response_dict: AgentResponseDict = await self.model.async_get_response(
46
- **params
47
- )
48
- # store to self in case validation failure
49
- self.raw_model_response = agent_response_dict.model_outputs.response
50
- self.generated_tokens = agent_response_dict.edsl_dict.generated_tokens
51
-
52
- return self.extract_edsl_result_entry_and_validate(agent_response_dict)
53
-
54
- def _remove_from_cache(self, cache_key) -> None:
36
+ response = self._format_raw_response(**data)
37
+ # breakpoint()
38
+ return AgentResponseDict(**response)
39
+
40
+ async def async_get_response(
41
+ self,
42
+ user_prompt: Prompt,
43
+ system_prompt: Prompt,
44
+ iteration: int = 0,
45
+ encoded_image=None,
46
+ ) -> dict:
47
+ """Call the LLM and gets a response. Used in the `answer_question` method."""
48
+ try:
49
+ params = {
50
+ "user_prompt": user_prompt.text,
51
+ "system_prompt": system_prompt.text,
52
+ "iteration": iteration,
53
+ "cache": self.cache,
54
+ }
55
+ if encoded_image:
56
+ params["encoded_image"] = encoded_image
57
+ response = await self.model.async_get_response(**params)
58
+
59
+ # TODO: I *don't* think we need to delete the cache key here because I think
60
+ # it will not have been set yet; the exception would have been raised before.
61
+ except json.JSONDecodeError as e:
62
+ raise AgentRespondedWithBadJSONError(
63
+ f"Returned bad JSON: {e}"
64
+ f"Prompt: {user_prompt}"
65
+ f"System Prompt: {system_prompt}"
66
+ )
67
+
68
+ return response
69
+
70
+ def _remove_from_cache(self, raw_response) -> None:
55
71
  """Remove an entry from the cache."""
72
+ cache_key = raw_response.get("cache_key", None)
56
73
  if cache_key:
57
74
  del self.cache.data[cache_key]
58
75
 
59
- def determine_answer(self, raw_answer: str) -> Any:
76
+ def _format_raw_response(
77
+ self, *, agent, question, scenario, raw_response, raw_model_response
78
+ ) -> AgentResponseDict:
79
+ """Return formatted raw response.
80
+
81
+ This cleans up the raw response to make it suitable to pass to AgentResponseDict.
82
+ """
83
+ _ = agent
84
+ try:
85
+ response = question._validate_answer(raw_response)
86
+ except Exception as e:
87
+ """If the response is invalid, remove it from the cache and raise the exception."""
88
+ self._remove_from_cache(raw_response)
89
+ raise e
90
+
60
91
  question_dict = self.survey.question_names_to_questions()
61
- # iterates through the current answers and updates the question_dict (which is all questions)
62
92
  for other_question, answer in self.current_answers.items():
63
93
  if other_question in question_dict:
64
94
  question_dict[other_question].answer = answer
65
95
  else:
66
- # it might be a comment
96
+ # adds a comment to the question
67
97
  if (
68
98
  new_question := other_question.split("_comment")[0]
69
99
  ) in question_dict:
70
100
  question_dict[new_question].comment = answer
71
101
 
72
- combined_dict = {**question_dict, **self.scenario}
73
- # sometimes the answer is a code, so we need to translate it
74
- return self.question._translate_answer_code_to_answer(raw_answer, combined_dict)
75
-
76
- def extract_edsl_result_entry_and_validate(
77
- self, agent_response_dict: AgentResponseDict
78
- ) -> EDSLResultObjectInput:
79
- edsl_dict = agent_response_dict.edsl_dict._asdict()
80
- exception_occurred = None
81
- validated = False
82
- try:
83
- validated_edsl_dict = self.question._validate_answer(edsl_dict)
84
- answer = self.determine_answer(validated_edsl_dict["answer"])
85
- comment = validated_edsl_dict.get("comment", "")
86
- validated = True
87
- except QuestionAnswerValidationError as e:
88
- answer = None
89
- comment = "The response was not valid."
90
- if self.raise_validation_errors:
91
- exception_occurred = e
92
- except Exception as non_validation_error:
93
- answer = None
94
- comment = "Some other error occurred."
95
- exception_occurred = non_validation_error
96
- finally:
97
- # even if validation failes, we still return the result
98
- data = {
99
- "answer": answer,
100
- "comment": comment,
101
- "generated_tokens": agent_response_dict.edsl_dict.generated_tokens,
102
- "question_name": self.question.question_name,
103
- "prompts": self.get_prompts(),
104
- "cached_response": agent_response_dict.model_outputs.cached_response,
105
- "raw_model_response": agent_response_dict.model_outputs.response,
106
- "cache_used": agent_response_dict.model_outputs.cache_used,
107
- "cache_key": agent_response_dict.model_outputs.cache_key,
108
- "validated": validated,
109
- "exception_occurred": exception_occurred,
110
- "cost": agent_response_dict.model_outputs.cost,
111
- }
112
- result = EDSLResultObjectInput(**data)
113
- return result
102
+ combined_dict = {**question_dict, **scenario}
103
+ answer = question._translate_answer_code_to_answer(
104
+ response["answer"], combined_dict
105
+ )
106
+ data = {
107
+ "answer": answer,
108
+ "comment": response.get(
109
+ "comment", ""
110
+ ), # not all question have comment fields,
111
+ "question_name": question.question_name,
112
+ "prompts": self.get_prompts(),
113
+ "cached_response": raw_response.get("cached_response", None),
114
+ "usage": raw_response.get("usage", {}),
115
+ "raw_model_response": raw_model_response,
116
+ "cache_used": raw_response.get("cache_used", False),
117
+ "cache_key": raw_response.get("cache_key", None),
118
+ }
119
+ return AgentResponseDict(**data)
114
120
 
121
+ get_response = sync_wrapper(async_get_response)
115
122
  answer_question = sync_wrapper(async_answer_question)
116
123
 
117
124
 
118
- class InvigilatorHuman(InvigilatorBase):
119
- """An invigilator for when a human is answering the question."""
125
+ class InvigilatorSidecar(InvigilatorAI):
126
+ """An invigilator that presents the 'raw' question to the agent
127
+ & uses a sidecar model to answer questions."""
128
+
129
+ async def async_answer_question(self, failed: bool = False) -> AgentResponseDict:
130
+ """Answer a question using the AI model."""
131
+ from edsl import Model
132
+
133
+ advanced_model = self.sidecar_model
134
+ simple_model = self.model
135
+ question = self.question
136
+ human_readable_question = (
137
+ "Please answer this single question: " + question.human_readable()
138
+ )
139
+ print("Getting the simple model response to: ", human_readable_question)
140
+ raw_simple_response = await simple_model.async_execute_model_call(
141
+ user_prompt=human_readable_question,
142
+ system_prompt="""Pretend you are a human answering a question. Do not break character.""",
143
+ )
144
+ simple_response = simple_model.parse_response(raw_simple_response)
145
+ instructions = question.get_instructions()
146
+
147
+ main_model_prompt = Prompt(
148
+ text="""
149
+ A simpler language model was asked this question:
150
+
151
+ To the simpel model:
152
+ {{ human_readable_question }}
153
+
154
+ The simple model responded:
155
+ <response>
156
+ {{ simple_response }}
157
+ </response>
158
+
159
+ It was suppose to respond according to these instructions:
160
+ <instructions>
161
+ {{ instructions }}
162
+ </instructions>
163
+
164
+ Please format the simple model's response as it should have been formmated, given the instructions.
165
+ Only respond in valid JSON, like so {"answer": "SPAM!"} or {"answer": "SPAM!", "comment": "I am a robot."}
166
+ Do not inlcude the word 'json'
167
+ """
168
+ )
169
+
170
+ d = {
171
+ "human_readable_question": human_readable_question,
172
+ "simple_response": simple_response,
173
+ "instructions": instructions,
174
+ }
175
+
176
+ print("The human-readable question is: ", human_readable_question)
177
+ print("The simple response is: ", simple_response)
178
+
179
+ raw_response_data = await advanced_model.async_execute_model_call(
180
+ user_prompt=main_model_prompt.render(d).text,
181
+ system_prompt="You are a helpful assistant.",
182
+ )
183
+
184
+ raw_response = await advanced_model.async_get_response(
185
+ user_prompt=main_model_prompt.render(d).text,
186
+ system_prompt="You are a helpful assistant.",
187
+ iteration=0,
188
+ cache=self.cache,
189
+ )
190
+
191
+ data = {
192
+ "agent": self.agent,
193
+ "question": self.question,
194
+ "scenario": self.scenario,
195
+ }
196
+ raw_response_data = {
197
+ "raw_response": raw_response,
198
+ "raw_model_response": raw_response["raw_model_response"],
199
+ }
200
+ params = data | raw_response_data
201
+ response = self._format_raw_response(**params)
202
+ response.update({"simple_model_raw_response": simple_response})
203
+ return AgentResponseDict(**response)
204
+
205
+ # get_response = sync_wrapper(async_get_response)
206
+ answer_question = sync_wrapper(async_answer_question)
120
207
 
121
- validate_response: bool = False
122
- translate_response: bool = False
208
+
209
+ class InvigilatorDebug(InvigilatorBase):
210
+ """An invigilator class for debugging purposes."""
123
211
 
124
212
  async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
125
213
  """Return the answer to the question."""
126
- comment = "This is a real survey response from a human."
214
+ results = self.question._simulate_answer(human_readable=True)
215
+ results["prompts"] = self.get_prompts()
216
+ results["question_name"] = self.question.question_name
217
+ results["comment"] = "Debug comment"
218
+ return AgentResponseDict(**results)
219
+
220
+ def get_prompts(self) -> Dict[str, Prompt]:
221
+ """Return the prompts used."""
222
+ return {
223
+ "user_prompt": Prompt("NA"),
224
+ "system_prompt": Prompt("NA"),
225
+ }
127
226
 
128
- def __repr__(self):
129
- return f"{self.literal}"
130
227
 
131
- exception_occurred = None
132
- validated = False
228
+ class InvigilatorHuman(InvigilatorBase):
229
+ """An invigilator for when a human is answering the question."""
230
+
231
+ async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
232
+ """Return the answer to the question."""
233
+ data = {
234
+ "comment": "This is a real survey response from a human.",
235
+ "answer": None,
236
+ "prompts": self.get_prompts(),
237
+ "question_name": self.question.question_name,
238
+ }
133
239
  try:
134
240
  answer = self.agent.answer_question_directly(self.question, self.scenario)
135
- self.raw_model_response = answer
136
-
137
- if self.validate_response:
138
- _ = self.question._validate_answer({"answer": answer})
139
- if self.translate_response:
140
- answer = self.question._translate_answer_code_to_answer(
141
- answer, self.scenario
142
- )
143
- validated = True
144
- except QuestionAnswerValidationError as e:
145
- answer = None
146
- if self.raise_validation_errors:
147
- exception_occurred = e
241
+ return AgentResponseDict(**(data | {"answer": answer}))
148
242
  except Exception as e:
149
- answer = None
150
- if self.raise_validation_errors:
151
- exception_occurred = e
152
- finally:
153
- data = {
154
- "generated_tokens": NotApplicable(),
155
- "question_name": self.question.question_name,
156
- "prompts": self.get_prompts(),
157
- "cached_response": NotApplicable(),
158
- "raw_model_response": NotApplicable(),
159
- "cache_used": NotApplicable(),
160
- "cache_key": NotApplicable(),
161
- "answer": answer,
162
- "comment": comment,
163
- "validated": validated,
164
- "exception_occurred": exception_occurred,
165
- }
166
- return EDSLResultObjectInput(**data)
243
+ agent_response_dict = AgentResponseDict(
244
+ **(data | {"answer": None, "comment": str(e)})
245
+ )
246
+ raise FailedTaskException(
247
+ f"Failed to get response. The exception is {str(e)}",
248
+ agent_response_dict,
249
+ ) from e
167
250
 
168
251
 
169
252
  class InvigilatorFunctional(InvigilatorBase):
@@ -172,21 +255,22 @@ class InvigilatorFunctional(InvigilatorBase):
172
255
  async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
173
256
  """Return the answer to the question."""
174
257
  func = self.question.answer_question_directly
175
- answer = func(scenario=self.scenario, agent_traits=self.agent.traits)
176
-
177
- return EDSLResultObjectInput(
178
- generated_tokens=str(answer),
179
- question_name=self.question.question_name,
180
- prompts=self.get_prompts(),
181
- cached_response=NotApplicable(),
182
- raw_model_response=NotApplicable(),
183
- cache_used=NotApplicable(),
184
- cache_key=NotApplicable(),
185
- answer=answer["answer"],
186
- comment="This is the result of a functional question.",
187
- validated=True,
188
- exception_occurred=None,
189
- )
258
+ data = {
259
+ "comment": "Functional.",
260
+ "prompts": self.get_prompts(),
261
+ "question_name": self.question.question_name,
262
+ }
263
+ try:
264
+ answer = func(scenario=self.scenario, agent_traits=self.agent.traits)
265
+ return AgentResponseDict(**(data | answer))
266
+ except Exception as e:
267
+ agent_response_dict = AgentResponseDict(
268
+ **(data | {"answer": None, "comment": str(e)})
269
+ )
270
+ raise FailedTaskException(
271
+ f"Failed to get response. The exception is {str(e)}",
272
+ agent_response_dict,
273
+ ) from e
190
274
 
191
275
  def get_prompts(self) -> Dict[str, Prompt]:
192
276
  """Return the prompts used."""