edsl 0.1.33.dev1__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. edsl/TemplateLoader.py +24 -0
  2. edsl/__init__.py +8 -4
  3. edsl/agents/Agent.py +46 -14
  4. edsl/agents/AgentList.py +43 -0
  5. edsl/agents/Invigilator.py +125 -212
  6. edsl/agents/InvigilatorBase.py +140 -32
  7. edsl/agents/PromptConstructionMixin.py +43 -66
  8. edsl/agents/__init__.py +1 -0
  9. edsl/auto/AutoStudy.py +117 -0
  10. edsl/auto/StageBase.py +230 -0
  11. edsl/auto/StageGenerateSurvey.py +178 -0
  12. edsl/auto/StageLabelQuestions.py +125 -0
  13. edsl/auto/StagePersona.py +61 -0
  14. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  15. edsl/auto/StagePersonaDimensionValues.py +74 -0
  16. edsl/auto/StagePersonaDimensions.py +69 -0
  17. edsl/auto/StageQuestions.py +73 -0
  18. edsl/auto/SurveyCreatorPipeline.py +21 -0
  19. edsl/auto/utilities.py +224 -0
  20. edsl/config.py +38 -39
  21. edsl/coop/PriceFetcher.py +58 -0
  22. edsl/coop/coop.py +39 -5
  23. edsl/data/Cache.py +35 -1
  24. edsl/data_transfer_models.py +120 -38
  25. edsl/enums.py +2 -0
  26. edsl/exceptions/language_models.py +25 -1
  27. edsl/exceptions/questions.py +62 -5
  28. edsl/exceptions/results.py +4 -0
  29. edsl/inference_services/AnthropicService.py +13 -11
  30. edsl/inference_services/AwsBedrock.py +19 -17
  31. edsl/inference_services/AzureAI.py +37 -20
  32. edsl/inference_services/GoogleService.py +16 -12
  33. edsl/inference_services/GroqService.py +2 -0
  34. edsl/inference_services/InferenceServiceABC.py +24 -0
  35. edsl/inference_services/MistralAIService.py +120 -0
  36. edsl/inference_services/OpenAIService.py +41 -50
  37. edsl/inference_services/TestService.py +71 -0
  38. edsl/inference_services/models_available_cache.py +0 -6
  39. edsl/inference_services/registry.py +4 -0
  40. edsl/jobs/Answers.py +10 -12
  41. edsl/jobs/FailedQuestion.py +78 -0
  42. edsl/jobs/Jobs.py +18 -13
  43. edsl/jobs/buckets/TokenBucket.py +39 -14
  44. edsl/jobs/interviews/Interview.py +297 -77
  45. edsl/jobs/interviews/InterviewExceptionEntry.py +83 -19
  46. edsl/jobs/interviews/interview_exception_tracking.py +0 -70
  47. edsl/jobs/interviews/retry_management.py +3 -1
  48. edsl/jobs/runners/JobsRunnerAsyncio.py +116 -70
  49. edsl/jobs/runners/JobsRunnerStatusMixin.py +1 -1
  50. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  51. edsl/jobs/tasks/TaskHistory.py +131 -213
  52. edsl/language_models/LanguageModel.py +239 -129
  53. edsl/language_models/ModelList.py +2 -2
  54. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  55. edsl/language_models/fake_openai_call.py +15 -0
  56. edsl/language_models/fake_openai_service.py +61 -0
  57. edsl/language_models/registry.py +15 -2
  58. edsl/language_models/repair.py +0 -19
  59. edsl/language_models/utilities.py +61 -0
  60. edsl/prompts/Prompt.py +52 -2
  61. edsl/questions/AnswerValidatorMixin.py +23 -26
  62. edsl/questions/QuestionBase.py +273 -242
  63. edsl/questions/QuestionBaseGenMixin.py +133 -0
  64. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  65. edsl/questions/QuestionBudget.py +6 -0
  66. edsl/questions/QuestionCheckBox.py +227 -35
  67. edsl/questions/QuestionExtract.py +98 -27
  68. edsl/questions/QuestionFreeText.py +46 -29
  69. edsl/questions/QuestionFunctional.py +7 -0
  70. edsl/questions/QuestionList.py +141 -22
  71. edsl/questions/QuestionMultipleChoice.py +173 -64
  72. edsl/questions/QuestionNumerical.py +87 -46
  73. edsl/questions/QuestionRank.py +182 -24
  74. edsl/questions/RegisterQuestionsMeta.py +31 -12
  75. edsl/questions/ResponseValidatorABC.py +169 -0
  76. edsl/questions/__init__.py +3 -4
  77. edsl/questions/decorators.py +21 -0
  78. edsl/questions/derived/QuestionLikertFive.py +10 -5
  79. edsl/questions/derived/QuestionLinearScale.py +11 -1
  80. edsl/questions/derived/QuestionTopK.py +6 -0
  81. edsl/questions/derived/QuestionYesNo.py +16 -1
  82. edsl/questions/descriptors.py +43 -7
  83. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  84. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  85. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  86. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  87. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  88. edsl/questions/prompt_templates/question_list.jinja +17 -0
  89. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  90. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  91. edsl/questions/question_registry.py +6 -2
  92. edsl/questions/templates/__init__.py +0 -0
  93. edsl/questions/templates/checkbox/__init__.py +0 -0
  94. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  95. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  96. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  97. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  98. edsl/questions/templates/free_text/__init__.py +0 -0
  99. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  100. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  101. edsl/questions/templates/likert_five/__init__.py +0 -0
  102. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  103. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  104. edsl/questions/templates/linear_scale/__init__.py +0 -0
  105. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  106. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  107. edsl/questions/templates/list/__init__.py +0 -0
  108. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  109. edsl/questions/templates/list/question_presentation.jinja +5 -0
  110. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  111. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  112. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  113. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  114. edsl/questions/templates/numerical/__init__.py +0 -0
  115. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  116. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  117. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  118. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  119. edsl/questions/templates/top_k/__init__.py +0 -0
  120. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  121. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  122. edsl/questions/templates/yes_no/__init__.py +0 -0
  123. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  124. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  125. edsl/results/Dataset.py +20 -0
  126. edsl/results/DatasetExportMixin.py +41 -47
  127. edsl/results/DatasetTree.py +145 -0
  128. edsl/results/Result.py +32 -5
  129. edsl/results/Results.py +131 -45
  130. edsl/results/ResultsDBMixin.py +3 -3
  131. edsl/results/Selector.py +118 -0
  132. edsl/results/tree_explore.py +115 -0
  133. edsl/scenarios/Scenario.py +10 -4
  134. edsl/scenarios/ScenarioList.py +348 -39
  135. edsl/scenarios/ScenarioListExportMixin.py +9 -0
  136. edsl/study/SnapShot.py +8 -1
  137. edsl/surveys/RuleCollection.py +2 -2
  138. edsl/surveys/Survey.py +634 -315
  139. edsl/surveys/SurveyExportMixin.py +71 -9
  140. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  141. edsl/surveys/SurveyQualtricsImport.py +75 -4
  142. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  143. edsl/surveys/instructions/Instruction.py +34 -0
  144. edsl/surveys/instructions/InstructionCollection.py +77 -0
  145. edsl/surveys/instructions/__init__.py +0 -0
  146. edsl/templates/error_reporting/base.html +24 -0
  147. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  148. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  149. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  150. edsl/templates/error_reporting/interview_details.html +111 -0
  151. edsl/templates/error_reporting/interviews.html +10 -0
  152. edsl/templates/error_reporting/overview.html +5 -0
  153. edsl/templates/error_reporting/performance_plot.html +2 -0
  154. edsl/templates/error_reporting/report.css +74 -0
  155. edsl/templates/error_reporting/report.html +118 -0
  156. edsl/templates/error_reporting/report.js +25 -0
  157. {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +4 -2
  158. edsl-0.1.33.dev2.dist-info/RECORD +289 -0
  159. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
  160. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  161. edsl-0.1.33.dev1.dist-info/RECORD +0 -209
  162. {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
  163. {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
edsl/TemplateLoader.py ADDED
@@ -0,0 +1,24 @@
+ from importlib import resources
+ from jinja2 import BaseLoader, TemplateNotFound
+ import os
+
+
+ class TemplateLoader(BaseLoader):
+     def __init__(self, package_name, templates_dir):
+         self.package_name = package_name
+         self.templates_dir = templates_dir
+
+     def get_source(self, environment, template):
+         try:
+             parts = [self.templates_dir] + template.split("/")
+             template_path = os.path.join(*parts)
+
+             # Use resources.files() to get a Traversable object
+             templates = resources.files(self.package_name).joinpath(self.templates_dir)
+
+             # Use the read_text() method of the Traversable object
+             content = templates.joinpath(template).read_text()
+
+             return content, None, lambda: True
+         except FileNotFoundError:
+             raise TemplateNotFound(template)
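The new loader implements the standard Jinja2 BaseLoader interface, so it can presumably be handed straight to a jinja2.Environment. A minimal usage sketch (the package and template names below are illustrative, taken from the template files added elsewhere in this release):

    from jinja2 import Environment
    from edsl.TemplateLoader import TemplateLoader

    # Assumed layout: question templates shipped under edsl/questions/templates/ (illustrative).
    env = Environment(loader=TemplateLoader("edsl.questions", "templates"))
    template = env.get_template("multiple_choice/answering_instructions.jinja")
    print(template.render())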
edsl/__init__.py CHANGED
@@ -8,9 +8,10 @@ from edsl.__version__ import __version__
  from edsl.config import Config, CONFIG
  from edsl.agents.Agent import Agent
  from edsl.agents.AgentList import AgentList
+
  from edsl.questions import QuestionBase
+ from edsl.questions.question_registry import Question
  from edsl.questions import QuestionMultipleChoice
- from edsl.questions import QuestionBudget
  from edsl.questions import QuestionCheckBox
  from edsl.questions import QuestionExtract
  from edsl.questions import QuestionFreeText
@@ -19,10 +20,10 @@ from edsl.questions import QuestionLikertFive
  from edsl.questions import QuestionList
  from edsl.questions import QuestionLinearScale
  from edsl.questions import QuestionNumerical
- from edsl.questions import QuestionRank
- from edsl.questions import QuestionTopK
  from edsl.questions import QuestionYesNo
- from edsl.questions.question_registry import Question
+ from edsl.questions import QuestionBudget
+ from edsl.questions import QuestionRank
+
  from edsl.scenarios import Scenario
  from edsl.scenarios import ScenarioList

@@ -40,3 +41,6 @@ from edsl.notebooks.Notebook import Notebook
  from edsl.study.Study import Study
  from edsl.conjure.Conjure import Conjure
  from edsl.coop.coop import Coop
+
+ from edsl.surveys.instructions.Instruction import Instruction
+ from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
edsl/agents/Agent.py CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
  import copy
  import inspect
  import types
- from typing import Callable, Optional, Union
+ from typing import Callable, Optional, Union, Any
  from uuid import uuid4
  from edsl.Base import Base

@@ -228,7 +228,12 @@ class Agent(Base):
          if hasattr(self, "answer_question_directly"):
              delattr(self, "answer_question_directly")

-     def add_direct_question_answering_method(self, method: Callable) -> None:
+     def add_direct_question_answering_method(
+         self,
+         method: Callable,
+         validate_response: bool = False,
+         translate_response: bool = False,
+     ) -> None:
          """Add a method to the agent that can answer a particular question type.

          :param method: A method that can answer a question directly.
@@ -249,6 +254,9 @@
              )
          # print("Warning: overwriting existing answer_question_directly method")

+         self.validate_response = validate_response
+         self.translate_response = translate_response
+
          signature = inspect.signature(method)
          for argument in ["question", "scenario", "self"]:
              if argument not in signature.parameters:
@@ -272,6 +280,7 @@
          current_answers: Optional[dict] = None,
          iteration: int = 1,
          sidecar_model=None,
+         raise_validation_errors: bool = True,
      ) -> "InvigilatorBase":
          """Create an Invigilator.

@@ -303,7 +312,12 @@
              iteration=iteration,
              cache=cache,
              sidecar_model=sidecar_model,
+             raise_validation_errors=raise_validation_errors,
          )
+         if hasattr(self, "validate_response"):
+             invigilator.validate_response = self.validate_response
+         if hasattr(self, "translate_response"):
+             invigilator.translate_response = self.translate_response
          return invigilator

      async def async_answer_question(
@@ -334,8 +348,8 @@
          >>> a.add_direct_question_answering_method(lambda self, question, scenario: "I am a direct answer.")
          >>> from edsl import QuestionFreeText
          >>> q = QuestionFreeText.example()
-         >>> a.answer_question(question = q, cache = False)
-         {'answer': 'I am a direct answer.', 'comment': 'This is a real survey response from a human.', ...}
+         >>> a.answer_question(question = q, cache = False).answer
+         'I am a direct answer.'

          This is a function where an agent returns an answer to a particular question.
          However, there are several different ways an agent can answer a question, so the
@@ -369,6 +383,7 @@
          current_answers: Optional[dict] = None,
          iteration: int = 0,
          sidecar_model=None,
+         raise_validation_errors: bool = True,
      ) -> "InvigilatorBase":
          """Create an Invigilator."""
          from edsl import Model
@@ -378,7 +393,6 @@
          scenario = scenario or Scenario()

          from edsl.agents.Invigilator import (
-             InvigilatorDebug,
              InvigilatorHuman,
              InvigilatorFunctional,
              InvigilatorAI,
@@ -391,8 +405,9 @@
              cache = Cache()

          if debug:
+             raise NotImplementedError("Debug mode is not yet implemented.")
              # use the question's _simulate_answer method
-             invigilator_class = InvigilatorDebug
+             # invigilator_class = InvigilatorDebug
          elif hasattr(question, "answer_question_directly"):
              # It's a functional question and the answer only depends on the agent's traits & the scenario
              invigilator_class = InvigilatorFunctional
@@ -422,6 +437,7 @@
              iteration=iteration,
              cache=cache,
              sidecar_model=sidecar_model,
+             raise_validation_errors=raise_validation_errors,
          )
          return invigilator

@@ -497,8 +513,8 @@
          if name == "has_dynamic_traits_function":
              return self.has_dynamic_traits_function

-         if name in self.traits:
-             return self.traits[name]
+         if name in self._traits:
+             return self._traits[name]
          raise AttributeError(
              f"'{type(self).__name__}' object has no attribute '{name}'"
          )
@@ -570,9 +586,9 @@
          if dynamic_traits_func:
              func = inspect.getsource(dynamic_traits_func)
              raw_data["dynamic_traits_function_source_code"] = func
-             raw_data[
-                 "dynamic_traits_function_name"
-             ] = self.dynamic_traits_function_name
+             raw_data["dynamic_traits_function_name"] = (
+                 self.dynamic_traits_function_name
+             )
          if hasattr(self, "answer_question_directly"):
              raw_data.pop(
                  "answer_question_directly", None
@@ -588,9 +604,9 @@
              raw_data["answer_question_directly_source_code"] = inspect.getsource(
                  answer_question_directly_func
              )
-             raw_data[
-                 "answer_question_directly_function_name"
-             ] = self.answer_question_directly_function_name
+             raw_data["answer_question_directly_function_name"] = (
+                 self.answer_question_directly_function_name
+             )

          return raw_data

@@ -640,6 +656,22 @@
          column_names = ["Attribute", "Value"]
          return table_data, column_names

+     def add_trait(self, trait_name_or_dict: str, value: Optional[Any] = None) -> Agent:
+         """Adds a trait to an agent and returns that agent"""
+         if isinstance(trait_name_or_dict, dict) and value is None:
+             self.traits.update(trait_name_or_dict)
+             return self
+
+         if isinstance(trait_name_or_dict, dict) and value:
+             raise ValueError(f"You passed a dict: {trait_name_or_dict}")
+
+         if isinstance(trait_name_or_dict, str):
+             trait = trait_name_or_dict
+             self.traits[trait] = value
+             return self
+
+         raise Exception("Something is not right with adding")
+
      def remove_trait(self, trait: str) -> Agent:
          """Remove a trait from the agent.

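Taken together with the AgentList changes below, the new Agent.add_trait accepts either a single trait name plus a value or a dict of traits. A rough usage sketch (trait names and values are illustrative):

    from edsl import Agent

    a = Agent(traits={"age": 22})
    a.add_trait("hair", "brown")     # single trait name plus value; returns the agent
    a.add_trait({"height": 5.5})     # or a dict of traits, merged into a.traits
    # Passing a dict together with a value raises ValueError, per the checks above.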
edsl/agents/AgentList.py CHANGED
@@ -21,6 +21,12 @@ from simpleeval import EvalWithCompoundTypes
  from edsl.Base import Base
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version

+ from collections.abc import Iterable
+
+
+ def is_iterable(obj):
+     return isinstance(obj, Iterable)
+

  class AgentList(UserList, Base):
      """A list of Agents."""
@@ -111,6 +117,13 @@ class AgentList(UserList, Base):

          return AgentList(new_data)

+     @property
+     def all_traits(self):
+         d = {}
+         for agent in self:
+             d.update(agent.traits)
+         return list(d.keys())
+
      @classmethod
      def from_csv(cls, file_path: str):
          """Load AgentList from a CSV file.
@@ -159,6 +172,36 @@
              _ = agent.remove_trait(trait)
          return self

+     def add_trait(self, trait, values):
+         """Adds a new trait to every agent, with values taken from values.
+
+         :param trait: The name of the trait.
+         :param values: The valeues(s) of the trait. If a single value is passed, it is used for all agents.
+
+         >>> al = AgentList.example()
+         >>> al.add_trait('new_trait', 1)
+         AgentList([Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1}), Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1})])
+         >>> al.select('new_trait').to_scenario_list().to_list()
+         [1, 1]
+         >>> al.add_trait('new_trait', [1, 2, 3])
+         Traceback (most recent call last):
+         ...
+         ValueError: The passed values have to be the same length as the agent list.
+         """
+         if not is_iterable(values):
+             value = values
+             for agent in self.data:
+                 agent.add_trait(trait, value)
+             return self
+
+         if len(values) != len(self):
+             raise ValueError(
+                 "The passed values have to be the same length as the agent list."
+             )
+         for agent, value in zip(self.data, values):
+             agent.add_trait(trait, value)
+         return self
+
      @staticmethod
      def get_codebook(file_path: str):
          """Return the codebook for a CSV file.
edsl/agents/Invigilator.py CHANGED
@@ -1,17 +1,22 @@
  """Module for creating Invigilators, which are objects to administer a question to an Agent."""

- import json
  from typing import Dict, Any, Optional

  from edsl.exceptions import AgentRespondedWithBadJSONError
  from edsl.prompts.Prompt import Prompt
  from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
  from edsl.prompts.registry import get_classes as prompt_lookup
- from edsl.data_transfer_models import AgentResponseDict
- from edsl.exceptions.agents import FailedTaskException
+ from edsl.exceptions.questions import QuestionAnswerValidationError
  from edsl.agents.PromptConstructionMixin import PromptConstructorMixin
-
  from edsl.agents.InvigilatorBase import InvigilatorBase
+ from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
+
+
+ class NotApplicable(str):
+     def __new__(cls):
+         instance = super().__new__(cls, "Not Applicable")
+         instance.literal = "Not Applicable"
+         return instance


  class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
@@ -22,231 +27,140 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):

          >>> i = InvigilatorAI.example()
          >>> i.answer_question()
-         {'message': '{"answer": "SPAM!"}'}
+         {'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
          """
-         params = self.get_prompts() | {"iteration": self.iteration}
-         raw_response = await self.async_get_response(**params)
-         data = {
-             "agent": self.agent,
-             "question": self.question,
-             "scenario": self.scenario,
-             "raw_response": raw_response,
-             "raw_model_response": raw_response["raw_model_response"],
+         prompts = self.get_prompts()
+         params = {
+             "user_prompt": prompts["user_prompt"].text,
+             "system_prompt": prompts["system_prompt"].text,
          }
-         response = self._format_raw_response(**data)
-         # breakpoint()
-         return AgentResponseDict(**response)
+         if "encoded_image" in prompts:
+             params["encoded_image"] = prompts["encoded_image"]

-     async def async_get_response(
-         self,
-         user_prompt: Prompt,
-         system_prompt: Prompt,
-         iteration: int = 0,
-         encoded_image=None,
-     ) -> dict:
-         """Call the LLM and gets a response. Used in the `answer_question` method."""
-         try:
-             params = {
-                 "user_prompt": user_prompt.text,
-                 "system_prompt": system_prompt.text,
-                 "iteration": iteration,
-                 "cache": self.cache,
-             }
-             if encoded_image:
-                 params["encoded_image"] = encoded_image
-             response = await self.model.async_get_response(**params)
+         params.update({"iteration": self.iteration, "cache": self.cache})

-         # TODO: I *don't* think we need to delete the cache key here because I think
-         # it will not have been set yet; the exception would have been raised before.
-         except json.JSONDecodeError as e:
-             raise AgentRespondedWithBadJSONError(
-                 f"Returned bad JSON: {e}"
-                 f"Prompt: {user_prompt}"
-                 f"System Prompt: {system_prompt}"
-             )
+         agent_response_dict: AgentResponseDict = await self.model.async_get_response(
+             **params
+         )
+         # store to self in case validation failure
+         self.raw_model_response = agent_response_dict.model_outputs.response
+         self.generated_tokens = agent_response_dict.edsl_dict.generated_tokens

-         return response
+         return self.extract_edsl_result_entry_and_validate(agent_response_dict)

-     def _remove_from_cache(self, raw_response) -> None:
+     def _remove_from_cache(self, cache_key) -> None:
          """Remove an entry from the cache."""
-         cache_key = raw_response.get("cache_key", None)
          if cache_key:
             del self.cache.data[cache_key]

-     def _format_raw_response(
-         self, *, agent, question, scenario, raw_response, raw_model_response
-     ) -> AgentResponseDict:
-         """Return formatted raw response.
-
-         This cleans up the raw response to make it suitable to pass to AgentResponseDict.
-         """
-         _ = agent
-         try:
-             response = question._validate_answer(raw_response)
-         except Exception as e:
-             """If the response is invalid, remove it from the cache and raise the exception."""
-             self._remove_from_cache(raw_response)
-             raise e
-
+     def determine_answer(self, raw_answer: str) -> Any:
          question_dict = self.survey.question_names_to_questions()
+         # iterates through the current answers and updates the question_dict (which is all questions)
          for other_question, answer in self.current_answers.items():
             if other_question in question_dict:
                 question_dict[other_question].answer = answer
             else:
-                 # adds a comment to the question
+                 # it might be a comment
                 if (
                     new_question := other_question.split("_comment")[0]
                 ) in question_dict:
                     question_dict[new_question].comment = answer

-         combined_dict = {**question_dict, **scenario}
-         answer = question._translate_answer_code_to_answer(
-             response["answer"], combined_dict
-         )
-         data = {
-             "answer": answer,
-             "comment": response.get(
-                 "comment", ""
-             ),  # not all question have comment fields,
-             "question_name": question.question_name,
-             "prompts": self.get_prompts(),
-             "cached_response": raw_response.get("cached_response", None),
-             "usage": raw_response.get("usage", {}),
-             "raw_model_response": raw_model_response,
-             "cache_used": raw_response.get("cache_used", False),
-             "cache_key": raw_response.get("cache_key", None),
-         }
-         return AgentResponseDict(**data)
-
-     get_response = sync_wrapper(async_get_response)
-     answer_question = sync_wrapper(async_answer_question)
-
-
- class InvigilatorSidecar(InvigilatorAI):
-     """An invigilator that presents the 'raw' question to the agent
-     & uses a sidecar model to answer questions."""
-
-     async def async_answer_question(self, failed: bool = False) -> AgentResponseDict:
-         """Answer a question using the AI model."""
-         from edsl import Model
-
-         advanced_model = self.sidecar_model
-         simple_model = self.model
-         question = self.question
-         human_readable_question = (
-             "Please answer this single question: " + question.human_readable()
-         )
-         print("Getting the simple model response to: ", human_readable_question)
-         raw_simple_response = await simple_model.async_execute_model_call(
-             user_prompt=human_readable_question,
-             system_prompt="""Pretend you are a human answering a question. Do not break character.""",
-         )
-         simple_response = simple_model.parse_response(raw_simple_response)
-         instructions = question.get_instructions()
-
-         main_model_prompt = Prompt(
-             text="""
-         A simpler language model was asked this question:
-
-         To the simpel model:
-         {{ human_readable_question }}
-
-         The simple model responded:
-         <response>
-         {{ simple_response }}
-         </response>
-
-         It was suppose to respond according to these instructions:
-         <instructions>
-         {{ instructions }}
-         </instructions>
-
-         Please format the simple model's response as it should have been formmated, given the instructions.
-         Only respond in valid JSON, like so {"answer": "SPAM!"} or {"answer": "SPAM!", "comment": "I am a robot."}
-         Do not inlcude the word 'json'
-         """
-         )
-
-         d = {
-             "human_readable_question": human_readable_question,
-             "simple_response": simple_response,
-             "instructions": instructions,
-         }
-
-         print("The human-readable question is: ", human_readable_question)
-         print("The simple response is: ", simple_response)
+         combined_dict = {**question_dict, **self.scenario}
+         # sometimes the answer is a code, so we need to translate it
+         return self.question._translate_answer_code_to_answer(raw_answer, combined_dict)

-         raw_response_data = await advanced_model.async_execute_model_call(
-             user_prompt=main_model_prompt.render(d).text,
-             system_prompt="You are a helpful assistant.",
-         )
-
-         raw_response = await advanced_model.async_get_response(
-             user_prompt=main_model_prompt.render(d).text,
-             system_prompt="You are a helpful assistant.",
-             iteration=0,
-             cache=self.cache,
-         )
-
-         data = {
-             "agent": self.agent,
-             "question": self.question,
-             "scenario": self.scenario,
-         }
-         raw_response_data = {
-             "raw_response": raw_response,
-             "raw_model_response": raw_response["raw_model_response"],
-         }
-         params = data | raw_response_data
-         response = self._format_raw_response(**params)
-         response.update({"simple_model_raw_response": simple_response})
-         return AgentResponseDict(**response)
+     def extract_edsl_result_entry_and_validate(
+         self, agent_response_dict: AgentResponseDict
+     ) -> EDSLResultObjectInput:
+         edsl_dict = agent_response_dict.edsl_dict._asdict()
+         exception_occurred = None
+         validated = False
+         try:
+             validated_edsl_dict = self.question._validate_answer(edsl_dict)
+             answer = self.determine_answer(validated_edsl_dict["answer"])
+             comment = validated_edsl_dict.get("comment", "")
+             validated = True
+         except QuestionAnswerValidationError as e:
+             answer = None
+             comment = "The response was not valid."
+             if self.raise_validation_errors:
+                 exception_occurred = e
+         except Exception as non_validation_error:
+             answer = None
+             comment = "Some other error occurred."
+             exception_occurred = non_validation_error
+         finally:
+             # even if validation failes, we still return the result
+             data = {
+                 "answer": answer,
+                 "comment": comment,
+                 "generated_tokens": agent_response_dict.edsl_dict.generated_tokens,
+                 "question_name": self.question.question_name,
+                 "prompts": self.get_prompts(),
+                 "cached_response": agent_response_dict.model_outputs.cached_response,
+                 "raw_model_response": agent_response_dict.model_outputs.response,
+                 "cache_used": agent_response_dict.model_outputs.cache_used,
+                 "cache_key": agent_response_dict.model_outputs.cache_key,
+                 "validated": validated,
+                 "exception_occurred": exception_occurred,
+                 "cost": agent_response_dict.model_outputs.cost,
+             }
+             result = EDSLResultObjectInput(**data)
+             return result

-     # get_response = sync_wrapper(async_get_response)
      answer_question = sync_wrapper(async_answer_question)


- class InvigilatorDebug(InvigilatorBase):
-     """An invigilator class for debugging purposes."""
-
-     async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
-         """Return the answer to the question."""
-         results = self.question._simulate_answer(human_readable=True)
-         results["prompts"] = self.get_prompts()
-         results["question_name"] = self.question.question_name
-         results["comment"] = "Debug comment"
-         return AgentResponseDict(**results)
-
-     def get_prompts(self) -> Dict[str, Prompt]:
-         """Return the prompts used."""
-         return {
-             "user_prompt": Prompt("NA"),
-             "system_prompt": Prompt("NA"),
-         }
-
-
  class InvigilatorHuman(InvigilatorBase):
      """An invigilator for when a human is answering the question."""

+     validate_response: bool = False
+     translate_response: bool = False
+
      async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
          """Return the answer to the question."""
-         data = {
-             "comment": "This is a real survey response from a human.",
-             "answer": None,
-             "prompts": self.get_prompts(),
-             "question_name": self.question.question_name,
-         }
+         comment = "This is a real survey response from a human."
+
+         def __repr__(self):
+             return f"{self.literal}"
+
+         exception_occurred = None
+         validated = False
          try:
             answer = self.agent.answer_question_directly(self.question, self.scenario)
-             return AgentResponseDict(**(data | {"answer": answer}))
+             self.raw_model_response = answer
+
+             if self.validate_response:
+                 _ = self.question._validate_answer({"answer": answer})
+             if self.translate_response:
+                 answer = self.question._translate_answer_code_to_answer(
+                     answer, self.scenario
+                 )
+             validated = True
+         except QuestionAnswerValidationError as e:
+             answer = None
+             if self.raise_validation_errors:
+                 exception_occurred = e
          except Exception as e:
-             agent_response_dict = AgentResponseDict(
-                 **(data | {"answer": None, "comment": str(e)})
-             )
-             raise FailedTaskException(
-                 f"Failed to get response. The exception is {str(e)}",
-                 agent_response_dict,
-             ) from e
+             answer = None
+             if self.raise_validation_errors:
+                 exception_occurred = e
+         finally:
+             data = {
+                 "generated_tokens": NotApplicable(),
+                 "question_name": self.question.question_name,
+                 "prompts": self.get_prompts(),
+                 "cached_response": NotApplicable(),
+                 "raw_model_response": NotApplicable(),
+                 "cache_used": NotApplicable(),
+                 "cache_key": NotApplicable(),
+                 "answer": answer,
+                 "comment": comment,
+                 "validated": validated,
+                 "exception_occurred": exception_occurred,
+             }
+             return EDSLResultObjectInput(**data)


  class InvigilatorFunctional(InvigilatorBase):
@@ -255,22 +169,21 @@ class InvigilatorFunctional(InvigilatorBase):
      async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
          """Return the answer to the question."""
          func = self.question.answer_question_directly
-         data = {
-             "comment": "Functional.",
-             "prompts": self.get_prompts(),
-             "question_name": self.question.question_name,
-         }
-         try:
-             answer = func(scenario=self.scenario, agent_traits=self.agent.traits)
-             return AgentResponseDict(**(data | answer))
-         except Exception as e:
-             agent_response_dict = AgentResponseDict(
-                 **(data | {"answer": None, "comment": str(e)})
-             )
-             raise FailedTaskException(
-                 f"Failed to get response. The exception is {str(e)}",
-                 agent_response_dict,
-             ) from e
+         answer = func(scenario=self.scenario, agent_traits=self.agent.traits)
+
+         return EDSLResultObjectInput(
+             generated_tokens=str(answer),
+             question_name=self.question.question_name,
+             prompts=self.get_prompts(),
+             cached_response=NotApplicable(),
+             raw_model_response=NotApplicable(),
+             cache_used=NotApplicable(),
+             cache_key=NotApplicable(),
+             answer=answer["answer"],
+             comment="This is the result of a functional question.",
+             validated=True,
+             exception_occurred=None,
+         )

      def get_prompts(self) -> Dict[str, Prompt]:
          """Return the prompts used."""