edsl 0.1.36.dev5__py3-none-any.whl → 0.1.36.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257) hide show
  1. edsl/Base.py +303 -303
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +47 -47
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +804 -804
  7. edsl/agents/AgentList.py +337 -337
  8. edsl/agents/Invigilator.py +222 -222
  9. edsl/agents/InvigilatorBase.py +294 -294
  10. edsl/agents/PromptConstructor.py +312 -312
  11. edsl/agents/__init__.py +3 -3
  12. edsl/agents/descriptors.py +86 -86
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +289 -289
  26. edsl/config.py +149 -149
  27. edsl/conjure/AgentConstructionMixin.py +152 -152
  28. edsl/conjure/Conjure.py +62 -62
  29. edsl/conjure/InputData.py +659 -659
  30. edsl/conjure/InputDataCSV.py +48 -48
  31. edsl/conjure/InputDataMixinQuestionStats.py +182 -182
  32. edsl/conjure/InputDataPyRead.py +91 -91
  33. edsl/conjure/InputDataSPSS.py +8 -8
  34. edsl/conjure/InputDataStata.py +8 -8
  35. edsl/conjure/QuestionOptionMixin.py +76 -76
  36. edsl/conjure/QuestionTypeMixin.py +23 -23
  37. edsl/conjure/RawQuestion.py +65 -65
  38. edsl/conjure/SurveyResponses.py +7 -7
  39. edsl/conjure/__init__.py +9 -9
  40. edsl/conjure/naming_utilities.py +263 -263
  41. edsl/conjure/utilities.py +201 -201
  42. edsl/conversation/Conversation.py +238 -238
  43. edsl/conversation/car_buying.py +58 -58
  44. edsl/conversation/mug_negotiation.py +81 -81
  45. edsl/conversation/next_speaker_utilities.py +93 -93
  46. edsl/coop/PriceFetcher.py +54 -54
  47. edsl/coop/__init__.py +2 -2
  48. edsl/coop/coop.py +849 -849
  49. edsl/coop/utils.py +131 -131
  50. edsl/data/Cache.py +527 -527
  51. edsl/data/CacheEntry.py +228 -228
  52. edsl/data/CacheHandler.py +149 -149
  53. edsl/data/RemoteCacheSync.py +83 -83
  54. edsl/data/SQLiteDict.py +292 -292
  55. edsl/data/__init__.py +4 -4
  56. edsl/data/orm.py +10 -10
  57. edsl/data_transfer_models.py +73 -73
  58. edsl/enums.py +173 -173
  59. edsl/exceptions/__init__.py +50 -50
  60. edsl/exceptions/agents.py +40 -40
  61. edsl/exceptions/configuration.py +16 -16
  62. edsl/exceptions/coop.py +10 -10
  63. edsl/exceptions/data.py +14 -14
  64. edsl/exceptions/general.py +34 -34
  65. edsl/exceptions/jobs.py +33 -33
  66. edsl/exceptions/language_models.py +63 -63
  67. edsl/exceptions/prompts.py +15 -15
  68. edsl/exceptions/questions.py +91 -91
  69. edsl/exceptions/results.py +26 -26
  70. edsl/exceptions/surveys.py +34 -34
  71. edsl/inference_services/AnthropicService.py +87 -87
  72. edsl/inference_services/AwsBedrock.py +115 -115
  73. edsl/inference_services/AzureAI.py +217 -217
  74. edsl/inference_services/DeepInfraService.py +18 -18
  75. edsl/inference_services/GoogleService.py +156 -156
  76. edsl/inference_services/GroqService.py +20 -20
  77. edsl/inference_services/InferenceServiceABC.py +147 -147
  78. edsl/inference_services/InferenceServicesCollection.py +72 -68
  79. edsl/inference_services/MistralAIService.py +123 -123
  80. edsl/inference_services/OllamaService.py +18 -18
  81. edsl/inference_services/OpenAIService.py +224 -224
  82. edsl/inference_services/TestService.py +89 -89
  83. edsl/inference_services/TogetherAIService.py +170 -170
  84. edsl/inference_services/models_available_cache.py +118 -94
  85. edsl/inference_services/rate_limits_cache.py +25 -25
  86. edsl/inference_services/registry.py +39 -39
  87. edsl/inference_services/write_available.py +10 -10
  88. edsl/jobs/Answers.py +56 -56
  89. edsl/jobs/Jobs.py +1112 -1112
  90. edsl/jobs/__init__.py +1 -1
  91. edsl/jobs/buckets/BucketCollection.py +63 -63
  92. edsl/jobs/buckets/ModelBuckets.py +65 -65
  93. edsl/jobs/buckets/TokenBucket.py +248 -248
  94. edsl/jobs/interviews/Interview.py +651 -651
  95. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  96. edsl/jobs/interviews/InterviewExceptionEntry.py +182 -182
  97. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  98. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  99. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  100. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  101. edsl/jobs/interviews/ReportErrors.py +66 -66
  102. edsl/jobs/interviews/interview_status_enum.py +9 -9
  103. edsl/jobs/runners/JobsRunnerAsyncio.py +337 -337
  104. edsl/jobs/runners/JobsRunnerStatus.py +332 -332
  105. edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
  106. edsl/jobs/tasks/TaskCreators.py +64 -64
  107. edsl/jobs/tasks/TaskHistory.py +441 -441
  108. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  109. edsl/jobs/tasks/task_status_enum.py +163 -163
  110. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  111. edsl/jobs/tokens/TokenUsage.py +34 -34
  112. edsl/language_models/LanguageModel.py +718 -718
  113. edsl/language_models/ModelList.py +102 -102
  114. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  115. edsl/language_models/__init__.py +2 -2
  116. edsl/language_models/fake_openai_call.py +15 -15
  117. edsl/language_models/fake_openai_service.py +61 -61
  118. edsl/language_models/registry.py +137 -137
  119. edsl/language_models/repair.py +156 -156
  120. edsl/language_models/unused/ReplicateBase.py +83 -83
  121. edsl/language_models/utilities.py +64 -64
  122. edsl/notebooks/Notebook.py +259 -259
  123. edsl/notebooks/__init__.py +1 -1
  124. edsl/prompts/Prompt.py +358 -358
  125. edsl/prompts/__init__.py +2 -2
  126. edsl/questions/AnswerValidatorMixin.py +289 -289
  127. edsl/questions/QuestionBase.py +616 -616
  128. edsl/questions/QuestionBaseGenMixin.py +161 -161
  129. edsl/questions/QuestionBasePromptsMixin.py +266 -266
  130. edsl/questions/QuestionBudget.py +227 -227
  131. edsl/questions/QuestionCheckBox.py +359 -359
  132. edsl/questions/QuestionExtract.py +183 -183
  133. edsl/questions/QuestionFreeText.py +113 -113
  134. edsl/questions/QuestionFunctional.py +159 -159
  135. edsl/questions/QuestionList.py +231 -231
  136. edsl/questions/QuestionMultipleChoice.py +286 -286
  137. edsl/questions/QuestionNumerical.py +153 -153
  138. edsl/questions/QuestionRank.py +324 -324
  139. edsl/questions/Quick.py +41 -41
  140. edsl/questions/RegisterQuestionsMeta.py +71 -71
  141. edsl/questions/ResponseValidatorABC.py +174 -174
  142. edsl/questions/SimpleAskMixin.py +73 -73
  143. edsl/questions/__init__.py +26 -26
  144. edsl/questions/compose_questions.py +98 -98
  145. edsl/questions/decorators.py +21 -21
  146. edsl/questions/derived/QuestionLikertFive.py +76 -76
  147. edsl/questions/derived/QuestionLinearScale.py +87 -87
  148. edsl/questions/derived/QuestionTopK.py +91 -91
  149. edsl/questions/derived/QuestionYesNo.py +82 -82
  150. edsl/questions/descriptors.py +418 -418
  151. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  152. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  153. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  154. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  155. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  156. edsl/questions/prompt_templates/question_list.jinja +17 -17
  157. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  158. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  159. edsl/questions/question_registry.py +147 -147
  160. edsl/questions/settings.py +12 -12
  161. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  162. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  163. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  164. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  165. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  166. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  167. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  168. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  169. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  170. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  171. edsl/questions/templates/list/question_presentation.jinja +5 -5
  172. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  173. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  174. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  175. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  176. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  177. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  178. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  179. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  180. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  181. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  182. edsl/results/Dataset.py +293 -293
  183. edsl/results/DatasetExportMixin.py +693 -693
  184. edsl/results/DatasetTree.py +145 -145
  185. edsl/results/Result.py +433 -433
  186. edsl/results/Results.py +1158 -1158
  187. edsl/results/ResultsDBMixin.py +238 -238
  188. edsl/results/ResultsExportMixin.py +43 -43
  189. edsl/results/ResultsFetchMixin.py +33 -33
  190. edsl/results/ResultsGGMixin.py +121 -121
  191. edsl/results/ResultsToolsMixin.py +98 -98
  192. edsl/results/Selector.py +118 -118
  193. edsl/results/__init__.py +2 -2
  194. edsl/results/tree_explore.py +115 -115
  195. edsl/scenarios/FileStore.py +443 -443
  196. edsl/scenarios/Scenario.py +507 -507
  197. edsl/scenarios/ScenarioHtmlMixin.py +59 -59
  198. edsl/scenarios/ScenarioList.py +1101 -1101
  199. edsl/scenarios/ScenarioListExportMixin.py +52 -52
  200. edsl/scenarios/ScenarioListPdfMixin.py +261 -261
  201. edsl/scenarios/__init__.py +2 -2
  202. edsl/shared.py +1 -1
  203. edsl/study/ObjectEntry.py +173 -173
  204. edsl/study/ProofOfWork.py +113 -113
  205. edsl/study/SnapShot.py +80 -80
  206. edsl/study/Study.py +528 -528
  207. edsl/study/__init__.py +4 -4
  208. edsl/surveys/DAG.py +148 -148
  209. edsl/surveys/Memory.py +31 -31
  210. edsl/surveys/MemoryPlan.py +244 -244
  211. edsl/surveys/Rule.py +324 -324
  212. edsl/surveys/RuleCollection.py +387 -387
  213. edsl/surveys/Survey.py +1772 -1772
  214. edsl/surveys/SurveyCSS.py +261 -261
  215. edsl/surveys/SurveyExportMixin.py +259 -259
  216. edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
  217. edsl/surveys/SurveyQualtricsImport.py +284 -284
  218. edsl/surveys/__init__.py +3 -3
  219. edsl/surveys/base.py +53 -53
  220. edsl/surveys/descriptors.py +56 -56
  221. edsl/surveys/instructions/ChangeInstruction.py +47 -47
  222. edsl/surveys/instructions/Instruction.py +51 -51
  223. edsl/surveys/instructions/InstructionCollection.py +77 -77
  224. edsl/templates/error_reporting/base.html +23 -23
  225. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  226. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  227. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  228. edsl/templates/error_reporting/interview_details.html +115 -115
  229. edsl/templates/error_reporting/interviews.html +9 -9
  230. edsl/templates/error_reporting/overview.html +4 -4
  231. edsl/templates/error_reporting/performance_plot.html +1 -1
  232. edsl/templates/error_reporting/report.css +73 -73
  233. edsl/templates/error_reporting/report.html +117 -117
  234. edsl/templates/error_reporting/report.js +25 -25
  235. edsl/tools/__init__.py +1 -1
  236. edsl/tools/clusters.py +192 -192
  237. edsl/tools/embeddings.py +27 -27
  238. edsl/tools/embeddings_plotting.py +118 -118
  239. edsl/tools/plotting.py +112 -112
  240. edsl/tools/summarize.py +18 -18
  241. edsl/utilities/SystemInfo.py +28 -28
  242. edsl/utilities/__init__.py +22 -22
  243. edsl/utilities/ast_utilities.py +25 -25
  244. edsl/utilities/data/Registry.py +6 -6
  245. edsl/utilities/data/__init__.py +1 -1
  246. edsl/utilities/data/scooter_results.json +1 -1
  247. edsl/utilities/decorators.py +77 -77
  248. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  249. edsl/utilities/interface.py +627 -627
  250. edsl/utilities/repair_functions.py +28 -28
  251. edsl/utilities/restricted_python.py +70 -70
  252. edsl/utilities/utilities.py +391 -391
  253. {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev6.dist-info}/LICENSE +21 -21
  254. {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev6.dist-info}/METADATA +1 -1
  255. edsl-0.1.36.dev6.dist-info/RECORD +279 -0
  256. edsl-0.1.36.dev5.dist-info/RECORD +0 -279
  257. {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev6.dist-info}/WHEEL +0 -0
@@ -1,651 +1,651 @@
1
- """This module contains the Interview class, which is responsible for conducting an interview asynchronously."""
2
-
3
- from __future__ import annotations
4
- import asyncio
5
- from typing import Any, Type, List, Generator, Optional, Union
6
- import copy
7
-
8
- from tenacity import (
9
- retry,
10
- stop_after_attempt,
11
- wait_exponential,
12
- retry_if_exception_type,
13
- RetryError,
14
- )
15
-
16
- from edsl import CONFIG
17
- from edsl.surveys.base import EndOfSurvey
18
- from edsl.exceptions import QuestionAnswerValidationError
19
- from edsl.exceptions import QuestionAnswerValidationError
20
- from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
21
-
22
- from edsl.jobs.buckets.ModelBuckets import ModelBuckets
23
- from edsl.jobs.Answers import Answers
24
- from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
25
- from edsl.jobs.tasks.TaskCreators import TaskCreators
26
- from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
27
- from edsl.jobs.interviews.InterviewExceptionCollection import (
28
- InterviewExceptionCollection,
29
- )
30
-
31
- # from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
32
-
33
- from edsl.surveys.base import EndOfSurvey
34
- from edsl.jobs.buckets.ModelBuckets import ModelBuckets
35
- from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
36
- from edsl.jobs.tasks.task_status_enum import TaskStatus
37
- from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
38
-
39
-
40
- from edsl import Agent, Survey, Scenario, Cache
41
- from edsl.language_models import LanguageModel
42
- from edsl.questions import QuestionBase
43
- from edsl.agents.InvigilatorBase import InvigilatorBase
44
-
45
- from edsl.exceptions.language_models import LanguageModelNoResponseError
46
-
47
- from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
48
- from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
49
- from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
50
-
51
-
52
- from edsl import CONFIG
53
-
54
- EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
55
- EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
56
- EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
57
-
58
-
59
- class Interview:
60
- """
61
- An 'interview' is one agent answering one survey, with one language model, for a given scenario.
62
-
63
- The main method is `async_conduct_interview`, which conducts the interview asynchronously.
64
- Most of the class is dedicated to creating the tasks for each question in the survey, and then running them.
65
- """
66
-
67
- def __init__(
68
- self,
69
- agent: Agent,
70
- survey: Survey,
71
- scenario: Scenario,
72
- model: Type["LanguageModel"],
73
- debug: Optional[bool] = False,
74
- iteration: int = 0,
75
- cache: Optional["Cache"] = None,
76
- sidecar_model: Optional["LanguageModel"] = None,
77
- skip_retry: bool = False,
78
- raise_validation_errors: bool = True,
79
- ):
80
- """Initialize the Interview instance.
81
-
82
- :param agent: the agent being interviewed.
83
- :param survey: the survey being administered to the agent.
84
- :param scenario: the scenario that populates the survey questions.
85
- :param model: the language model used to answer the questions.
86
- :param debug: if True, run without calls to the language model.
87
- :param iteration: the iteration number of the interview.
88
- :param cache: the cache used to store the answers.
89
- :param sidecar_model: a sidecar model used to answer questions.
90
-
91
- >>> i = Interview.example()
92
- >>> i.task_creators
93
- {}
94
-
95
- >>> i.exceptions
96
- {}
97
-
98
- >>> _ = asyncio.run(i.async_conduct_interview())
99
- >>> i.task_status_logs['q0']
100
- [{'log_time': ..., 'value': <TaskStatus.NOT_STARTED: 1>}, {'log_time': ..., 'value': <TaskStatus.WAITING_FOR_DEPENDENCIES: 2>}, {'log_time': ..., 'value': <TaskStatus.API_CALL_IN_PROGRESS: 7>}, {'log_time': ..., 'value': <TaskStatus.SUCCESS: 8>}]
101
-
102
- >>> i.to_index
103
- {'q0': 0, 'q1': 1, 'q2': 2}
104
-
105
- """
106
- self.agent = agent
107
- self.survey = copy.deepcopy(survey)
108
- self.scenario = scenario
109
- self.model = model
110
- self.debug = debug
111
- self.iteration = iteration
112
- self.cache = cache
113
- self.answers: dict[str, str] = (
114
- Answers()
115
- ) # will get filled in as interview progresses
116
- self.sidecar_model = sidecar_model
117
-
118
- # Trackers
119
- self.task_creators = TaskCreators() # tracks the task creators
120
- self.exceptions = InterviewExceptionCollection()
121
-
122
- self._task_status_log_dict = InterviewStatusLog()
123
- self.skip_retry = skip_retry
124
- self.raise_validation_errors = raise_validation_errors
125
-
126
- # dictionary mapping question names to their index in the survey.
127
- self.to_index = {
128
- question_name: index
129
- for index, question_name in enumerate(self.survey.question_names)
130
- }
131
-
132
- self.failed_questions = []
133
-
134
- @property
135
- def has_exceptions(self) -> bool:
136
- """Return True if there are exceptions."""
137
- return len(self.exceptions) > 0
138
-
139
- @property
140
- def task_status_logs(self) -> InterviewStatusLog:
141
- """Return the task status logs for the interview.
142
-
143
- The keys are the question names; the values are the lists of status log changes for each task.
144
- """
145
- for task_creator in self.task_creators.values():
146
- self._task_status_log_dict[task_creator.question.question_name] = (
147
- task_creator.status_log
148
- )
149
- return self._task_status_log_dict
150
-
151
- @property
152
- def token_usage(self) -> InterviewTokenUsage:
153
- """Determine how many tokens were used for the interview."""
154
- return self.task_creators.token_usage
155
-
156
- @property
157
- def interview_status(self) -> InterviewStatusDictionary:
158
- """Return a dictionary mapping task status codes to counts."""
159
- return self.task_creators.interview_status
160
-
161
- # region: Serialization
162
- def _to_dict(self, include_exceptions=True) -> dict[str, Any]:
163
- """Return a dictionary representation of the Interview instance.
164
- This is just for hashing purposes.
165
-
166
- >>> i = Interview.example()
167
- >>> hash(i)
168
- 1217840301076717434
169
- """
170
- d = {
171
- "agent": self.agent._to_dict(),
172
- "survey": self.survey._to_dict(),
173
- "scenario": self.scenario._to_dict(),
174
- "model": self.model._to_dict(),
175
- "iteration": self.iteration,
176
- "exceptions": {},
177
- }
178
- if include_exceptions:
179
- d["exceptions"] = self.exceptions.to_dict()
180
- return d
181
-
182
- @classmethod
183
- def from_dict(cls, d: dict[str, Any]) -> "Interview":
184
- """Return an Interview instance from a dictionary."""
185
- agent = Agent.from_dict(d["agent"])
186
- survey = Survey.from_dict(d["survey"])
187
- scenario = Scenario.from_dict(d["scenario"])
188
- model = LanguageModel.from_dict(d["model"])
189
- iteration = d["iteration"]
190
- return cls(agent=agent, survey=survey, scenario=scenario, model=model, iteration=iteration)
191
-
192
- def __hash__(self) -> int:
193
- from edsl.utilities.utilities import dict_hash
194
-
195
- return dict_hash(self._to_dict(include_exceptions=False))
196
-
197
- def __eq__(self, other: "Interview") -> bool:
198
- """
199
- >>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i._to_dict(); i2 = Interview.from_dict(d); i == i2
200
- True
201
- """
202
- return hash(self) == hash(other)
203
-
204
- # endregion
205
-
206
- # region: Creating tasks
207
- @property
208
- def dag(self) -> "DAG":
209
- """Return the directed acyclic graph for the survey.
210
-
211
- The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
212
- It is used to determine the order in which questions should be answered.
213
- This reflects both agent 'memory' considerations and 'skip' logic.
214
- The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
215
-
216
- >>> i = Interview.example()
217
- >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
218
- True
219
- """
220
- return self.survey.dag(textify=True)
221
-
222
- def _build_question_tasks(
223
- self,
224
- model_buckets: ModelBuckets,
225
- ) -> list[asyncio.Task]:
226
- """Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
227
-
228
- :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
229
- :param model_buckets: the model buckets used to track and control usage rates.
230
- """
231
- tasks = []
232
- for question in self.survey.questions:
233
- tasks_that_must_be_completed_before = list(
234
- self._get_tasks_that_must_be_completed_before(
235
- tasks=tasks, question=question
236
- )
237
- )
238
- question_task = self._create_question_task(
239
- question=question,
240
- tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
241
- model_buckets=model_buckets,
242
- iteration=self.iteration,
243
- )
244
- tasks.append(question_task)
245
- return tuple(tasks)
246
-
247
- def _get_tasks_that_must_be_completed_before(
248
- self, *, tasks: list[asyncio.Task], question: "QuestionBase"
249
- ) -> Generator[asyncio.Task, None, None]:
250
- """Return the tasks that must be completed before the given question can be answered.
251
-
252
- :param tasks: a list of tasks that have been created so far.
253
- :param question: the question for which we are determining dependencies.
254
-
255
- If a question has no dependencies, this will be an empty list, [].
256
- """
257
- parents_of_focal_question = self.dag.get(question.question_name, [])
258
- for parent_question_name in parents_of_focal_question:
259
- yield tasks[self.to_index[parent_question_name]]
260
-
261
- def _create_question_task(
262
- self,
263
- *,
264
- question: QuestionBase,
265
- tasks_that_must_be_completed_before: list[asyncio.Task],
266
- model_buckets: ModelBuckets,
267
- iteration: int = 0,
268
- ) -> asyncio.Task:
269
- """Create a task that depends on the passed-in dependencies that are awaited before the task is run.
270
-
271
- :param question: the question to be answered. This is the question we are creating a task for.
272
- :param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
273
- :param model_buckets: the model buckets used to track and control usage rates.
274
- :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
275
- :param iteration: the iteration number for the interview.
276
-
277
- The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
278
- It is passed a reference to the function that will be called to answer the question.
279
- It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
280
- These are added as a dependency to the focal task.
281
- """
282
- task_creator = QuestionTaskCreator(
283
- question=question,
284
- answer_question_func=self._answer_question_and_record_task,
285
- token_estimator=self._get_estimated_request_tokens,
286
- model_buckets=model_buckets,
287
- iteration=iteration,
288
- )
289
- for task in tasks_that_must_be_completed_before:
290
- task_creator.add_dependency(task)
291
-
292
- self.task_creators.update(
293
- {question.question_name: task_creator}
294
- ) # track this task creator
295
- return task_creator.generate_task()
296
-
297
- def _get_estimated_request_tokens(self, question) -> float:
298
- """Estimate the number of tokens that will be required to run the focal task."""
299
- from edsl.scenarios.FileStore import FileStore
300
-
301
- invigilator = self._get_invigilator(question=question)
302
- # TODO: There should be a way to get a more accurate estimate.
303
- combined_text = ""
304
- file_tokens = 0
305
- for prompt in invigilator.get_prompts().values():
306
- if hasattr(prompt, "text"):
307
- combined_text += prompt.text
308
- elif isinstance(prompt, str):
309
- combined_text += prompt
310
- elif isinstance(prompt, list):
311
- for file in prompt:
312
- if isinstance(file, FileStore):
313
- file_tokens += file.size * 0.25
314
- else:
315
- raise ValueError(f"Prompt is of type {type(prompt)}")
316
- return len(combined_text) / 4.0 + file_tokens
317
-
318
- async def _answer_question_and_record_task(
319
- self,
320
- *,
321
- question: "QuestionBase",
322
- task=None,
323
- ) -> "AgentResponseDict":
324
- """Answer a question and records the task."""
325
-
326
- had_language_model_no_response_error = False
327
-
328
- @retry(
329
- stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
330
- wait=wait_exponential(
331
- multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
332
- ),
333
- retry=retry_if_exception_type(LanguageModelNoResponseError),
334
- reraise=True,
335
- )
336
- async def attempt_answer():
337
- nonlocal had_language_model_no_response_error
338
-
339
- invigilator = self._get_invigilator(question)
340
-
341
- if self._skip_this_question(question):
342
- return invigilator.get_failed_task_result(
343
- failure_reason="Question skipped."
344
- )
345
-
346
- try:
347
- response: EDSLResultObjectInput = (
348
- await invigilator.async_answer_question()
349
- )
350
- if response.validated:
351
- self.answers.add_answer(response=response, question=question)
352
- self._cancel_skipped_questions(question)
353
- else:
354
- # When a question is not validated, it is not added to the answers.
355
- # this should also cancel and dependent children questions.
356
- # Is that happening now?
357
- if (
358
- hasattr(response, "exception_occurred")
359
- and response.exception_occurred
360
- ):
361
- raise response.exception_occurred
362
-
363
- except QuestionAnswerValidationError as e:
364
- self._handle_exception(e, invigilator, task)
365
- return invigilator.get_failed_task_result(
366
- failure_reason="Question answer validation failed."
367
- )
368
-
369
- except asyncio.TimeoutError as e:
370
- self._handle_exception(e, invigilator, task)
371
- had_language_model_no_response_error = True
372
- raise LanguageModelNoResponseError(
373
- f"Language model timed out for question '{question.question_name}.'"
374
- )
375
-
376
- except Exception as e:
377
- self._handle_exception(e, invigilator, task)
378
-
379
- if "response" not in locals():
380
- had_language_model_no_response_error = True
381
- raise LanguageModelNoResponseError(
382
- f"Language model did not return a response for question '{question.question_name}.'"
383
- )
384
-
385
- # if it gets here, it means the no response error was fixed
386
- if (
387
- question.question_name in self.exceptions
388
- and had_language_model_no_response_error
389
- ):
390
- self.exceptions.record_fixed_question(question.question_name)
391
-
392
- return response
393
-
394
- try:
395
- return await attempt_answer()
396
- except RetryError as retry_error:
397
- # All retries have failed for LanguageModelNoResponseError
398
- original_error = retry_error.last_attempt.exception()
399
- self._handle_exception(
400
- original_error, self._get_invigilator(question), task
401
- )
402
- raise original_error # Re-raise the original error after handling
403
-
404
- def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
405
- """Return an invigilator for the given question.
406
-
407
- :param question: the question to be answered
408
- :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
409
- """
410
- invigilator = self.agent.create_invigilator(
411
- question=question,
412
- scenario=self.scenario,
413
- model=self.model,
414
- debug=False,
415
- survey=self.survey,
416
- memory_plan=self.survey.memory_plan,
417
- current_answers=self.answers,
418
- iteration=self.iteration,
419
- cache=self.cache,
420
- sidecar_model=self.sidecar_model,
421
- raise_validation_errors=self.raise_validation_errors,
422
- )
423
- """Return an invigilator for the given question."""
424
- return invigilator
425
-
426
- def _skip_this_question(self, current_question: "QuestionBase") -> bool:
427
- """Determine if the current question should be skipped.
428
-
429
- :param current_question: the question to be answered.
430
- """
431
- current_question_index = self.to_index[current_question.question_name]
432
-
433
- answers = self.answers | self.scenario | self.agent["traits"]
434
- skip = self.survey.rule_collection.skip_question_before_running(
435
- current_question_index, answers
436
- )
437
- return skip
438
-
439
- def _handle_exception(
440
- self, e: Exception, invigilator: "InvigilatorBase", task=None
441
- ):
442
- import copy
443
-
444
- # breakpoint()
445
-
446
- answers = copy.copy(self.answers)
447
- exception_entry = InterviewExceptionEntry(
448
- exception=e,
449
- invigilator=invigilator,
450
- answers=answers,
451
- )
452
- if task:
453
- task.task_status = TaskStatus.FAILED
454
- self.exceptions.add(invigilator.question.question_name, exception_entry)
455
-
456
- if self.raise_validation_errors:
457
- if isinstance(e, QuestionAnswerValidationError):
458
- raise e
459
-
460
- if hasattr(self, "stop_on_exception"):
461
- stop_on_exception = self.stop_on_exception
462
- else:
463
- stop_on_exception = False
464
-
465
- if stop_on_exception:
466
- raise e
467
-
468
- def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
469
- """Cancel the tasks for questions that are skipped.
470
-
471
- :param current_question: the question that was just answered.
472
-
473
- It first determines the next question, given the current question and the current answers.
474
- If the next question is the end of the survey, it cancels all remaining tasks.
475
- If the next question is after the current question, it cancels all tasks between the current question and the next question.
476
- """
477
- current_question_index: int = self.to_index[current_question.question_name]
478
-
479
- next_question: Union[int, EndOfSurvey] = (
480
- self.survey.rule_collection.next_question(
481
- q_now=current_question_index,
482
- answers=self.answers | self.scenario | self.agent["traits"],
483
- )
484
- )
485
-
486
- next_question_index = next_question.next_q
487
-
488
- def cancel_between(start, end):
489
- """Cancel the tasks between the start and end indices."""
490
- for i in range(start, end):
491
- self.tasks[i].cancel()
492
-
493
- if next_question_index == EndOfSurvey:
494
- cancel_between(current_question_index + 1, len(self.survey.questions))
495
- return
496
-
497
- if next_question_index > (current_question_index + 1):
498
- cancel_between(current_question_index + 1, next_question_index)
499
-
500
- # endregion
501
-
502
- # region: Conducting the interview
503
- async def async_conduct_interview(
504
- self,
505
- model_buckets: Optional[ModelBuckets] = None,
506
- stop_on_exception: bool = False,
507
- sidecar_model: Optional["LanguageModel"] = None,
508
- raise_validation_errors: bool = True,
509
- ) -> tuple["Answers", List[dict[str, Any]]]:
510
- """
511
- Conduct an Interview asynchronously.
512
- It returns a tuple with the answers and a list of valid results.
513
-
514
- :param model_buckets: a dictionary of token buckets for the model.
515
- :param debug: run without calls to LLM.
516
- :param stop_on_exception: if True, stops the interview if an exception is raised.
517
- :param sidecar_model: a sidecar model used to answer questions.
518
-
519
- Example usage:
520
-
521
- >>> i = Interview.example()
522
- >>> result, _ = asyncio.run(i.async_conduct_interview())
523
- >>> result['q0']
524
- 'yes'
525
-
526
- >>> i = Interview.example(throw_exception = True)
527
- >>> result, _ = asyncio.run(i.async_conduct_interview())
528
- >>> i.exceptions
529
- {'q0': ...
530
- >>> i = Interview.example()
531
- >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
532
- Traceback (most recent call last):
533
- ...
534
- asyncio.exceptions.CancelledError
535
- """
536
- self.sidecar_model = sidecar_model
537
- self.stop_on_exception = stop_on_exception
538
-
539
- # if no model bucket is passed, create an 'infinity' bucket with no rate limits
540
- if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
541
- model_buckets = ModelBuckets.infinity_bucket()
542
-
543
- ## This is the key part---it creates a task for each question,
544
- ## with dependencies on the questions that must be answered before this one can be answered.
545
- self.tasks = self._build_question_tasks(model_buckets=model_buckets)
546
-
547
- ## 'Invigilators' are used to administer the survey
548
- self.invigilators = [
549
- self._get_invigilator(question) for question in self.survey.questions
550
- ]
551
- await asyncio.gather(
552
- *self.tasks, return_exceptions=not stop_on_exception
553
- ) # not stop_on_exception)
554
- self.answers.replace_missing_answers_with_none(self.survey)
555
- valid_results = list(self._extract_valid_results())
556
- return self.answers, valid_results
557
-
558
- # endregion
559
-
560
- # region: Extracting results and recording errors
561
- def _extract_valid_results(self) -> Generator["Answers", None, None]:
562
- """Extract the valid results from the list of results.
563
-
564
- It iterates through the tasks and invigilators, and yields the results of the tasks that are done.
565
- If a task is not done, it raises a ValueError.
566
- If an exception is raised in the task, it records the exception in the Interview instance except if the task was cancelled, which is expected behavior.
567
-
568
- >>> i = Interview.example()
569
- >>> result, _ = asyncio.run(i.async_conduct_interview())
570
- >>> results = list(i._extract_valid_results())
571
- >>> len(results) == len(i.survey)
572
- True
573
- """
574
- assert len(self.tasks) == len(self.invigilators)
575
-
576
- for task, invigilator in zip(self.tasks, self.invigilators):
577
- if not task.done():
578
- raise ValueError(f"Task {task.get_name()} is not done.")
579
-
580
- try:
581
- result = task.result()
582
- except asyncio.CancelledError as e: # task was cancelled
583
- result = invigilator.get_failed_task_result(
584
- failure_reason="Task was cancelled."
585
- )
586
- except Exception as e: # any other kind of exception in the task
587
- result = invigilator.get_failed_task_result(
588
- failure_reason=f"Task failed with exception: {str(e)}."
589
- )
590
- exception_entry = InterviewExceptionEntry(
591
- exception=e,
592
- invigilator=invigilator,
593
- )
594
- self.exceptions.add(task.get_name(), exception_entry)
595
-
596
- yield result
597
-
598
- # endregion
599
-
600
- # region: Magic methods
601
- def __repr__(self) -> str:
602
- """Return a string representation of the Interview instance."""
603
- return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
604
-
605
- def duplicate(self, iteration: int, cache: "Cache") -> Interview:
606
- """Duplicate the interview, but with a new iteration number and cache.
607
-
608
- >>> i = Interview.example()
609
- >>> i2 = i.duplicate(1, None)
610
- >>> i.iteration + 1 == i2.iteration
611
- True
612
-
613
- """
614
- return Interview(
615
- agent=self.agent,
616
- survey=self.survey,
617
- scenario=self.scenario,
618
- model=self.model,
619
- iteration=iteration,
620
- cache=cache,
621
- skip_retry=self.skip_retry,
622
- )
623
-
624
- @classmethod
625
- def example(self, throw_exception: bool = False) -> Interview:
626
- """Return an example Interview instance."""
627
- from edsl.agents import Agent
628
- from edsl.surveys import Survey
629
- from edsl.scenarios import Scenario
630
- from edsl.language_models import LanguageModel
631
-
632
- def f(self, question, scenario):
633
- return "yes"
634
-
635
- agent = Agent.example()
636
- agent.add_direct_question_answering_method(f)
637
- survey = Survey.example()
638
- scenario = Scenario.example()
639
- model = LanguageModel.example()
640
- if throw_exception:
641
- model = LanguageModel.example(test_model=True, throw_exception=True)
642
- agent = Agent.example()
643
- return Interview(agent=agent, survey=survey, scenario=scenario, model=model)
644
- return Interview(agent=agent, survey=survey, scenario=scenario, model=model)
645
-
646
-
647
- if __name__ == "__main__":
648
- import doctest
649
-
650
- # add ellipsis
651
- doctest.testmod(optionflags=doctest.ELLIPSIS)
1
+ """This module contains the Interview class, which is responsible for conducting an interview asynchronously."""
2
+
3
+ from __future__ import annotations
4
+ import asyncio
5
+ from typing import Any, Type, List, Generator, Optional, Union
6
+ import copy
7
+
8
+ from tenacity import (
9
+ retry,
10
+ stop_after_attempt,
11
+ wait_exponential,
12
+ retry_if_exception_type,
13
+ RetryError,
14
+ )
15
+
16
+ from edsl import CONFIG
17
+ from edsl.surveys.base import EndOfSurvey
18
+ from edsl.exceptions import QuestionAnswerValidationError
19
+ from edsl.exceptions import QuestionAnswerValidationError
20
+ from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
21
+
22
+ from edsl.jobs.buckets.ModelBuckets import ModelBuckets
23
+ from edsl.jobs.Answers import Answers
24
+ from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
25
+ from edsl.jobs.tasks.TaskCreators import TaskCreators
26
+ from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
27
+ from edsl.jobs.interviews.InterviewExceptionCollection import (
28
+ InterviewExceptionCollection,
29
+ )
30
+
31
+ # from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
32
+
33
+ from edsl.surveys.base import EndOfSurvey
34
+ from edsl.jobs.buckets.ModelBuckets import ModelBuckets
35
+ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
36
+ from edsl.jobs.tasks.task_status_enum import TaskStatus
37
+ from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
38
+
39
+
40
+ from edsl import Agent, Survey, Scenario, Cache
41
+ from edsl.language_models import LanguageModel
42
+ from edsl.questions import QuestionBase
43
+ from edsl.agents.InvigilatorBase import InvigilatorBase
44
+
45
+ from edsl.exceptions.language_models import LanguageModelNoResponseError
46
+
47
+ from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
48
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
49
+ from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
50
+
51
+
52
+ from edsl import CONFIG
53
+
54
+ EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
55
+ EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
56
+ EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
57
+
58
+
59
class Interview:
    """
    An 'interview' is a single agent answering a single survey, with one language model, for a given scenario.

    The entry point is `async_conduct_interview`, which runs the interview asynchronously.
    Most of the class builds the per-question tasks (with their dependencies) and then executes them.
    """

    def __init__(
        self,
        agent: Agent,
        survey: Survey,
        scenario: Scenario,
        model: Type["LanguageModel"],
        debug: Optional[bool] = False,
        iteration: int = 0,
        cache: Optional["Cache"] = None,
        sidecar_model: Optional["LanguageModel"] = None,
        skip_retry: bool = False,
        raise_validation_errors: bool = True,
    ):
        """Initialize the Interview instance.

        :param agent: the agent being interviewed.
        :param survey: the survey being administered to the agent.
        :param scenario: the scenario that populates the survey questions.
        :param model: the language model used to answer the questions.
        :param debug: if True, run without calls to the language model.
        :param iteration: the iteration number of the interview.
        :param cache: the cache used to store the answers.
        :param sidecar_model: a sidecar model used to answer questions.
        :param skip_retry: if True, do not retry failed questions.
        :param raise_validation_errors: if True, re-raise answer validation errors.

        >>> i = Interview.example()
        >>> i.task_creators
        {}

        >>> i.exceptions
        {}

        >>> _ = asyncio.run(i.async_conduct_interview())
        >>> i.task_status_logs['q0']
        [{'log_time': ..., 'value': <TaskStatus.NOT_STARTED: 1>}, {'log_time': ..., 'value': <TaskStatus.WAITING_FOR_DEPENDENCIES: 2>}, {'log_time': ..., 'value': <TaskStatus.API_CALL_IN_PROGRESS: 7>}, {'log_time': ..., 'value': <TaskStatus.SUCCESS: 8>}]

        >>> i.to_index
        {'q0': 0, 'q1': 1, 'q2': 2}

        """
        self.agent = agent
        # deep-copy the survey so per-interview mutations never leak back to the caller
        self.survey = copy.deepcopy(survey)
        self.scenario = scenario
        self.model = model
        self.debug = debug
        self.iteration = iteration
        self.cache = cache
        self.sidecar_model = sidecar_model
        # filled in incrementally as the interview progresses
        self.answers: dict[str, str] = Answers()

        # trackers
        self.task_creators = TaskCreators()
        self.exceptions = InterviewExceptionCollection()
        self._task_status_log_dict = InterviewStatusLog()

        self.skip_retry = skip_retry
        self.raise_validation_errors = raise_validation_errors

        # maps each question name to its position in the survey
        self.to_index = {
            name: position
            for position, name in enumerate(self.survey.question_names)
        }

        self.failed_questions = []
134
+ @property
135
+ def has_exceptions(self) -> bool:
136
+ """Return True if there are exceptions."""
137
+ return len(self.exceptions) > 0
138
+
139
+ @property
140
+ def task_status_logs(self) -> InterviewStatusLog:
141
+ """Return the task status logs for the interview.
142
+
143
+ The keys are the question names; the values are the lists of status log changes for each task.
144
+ """
145
+ for task_creator in self.task_creators.values():
146
+ self._task_status_log_dict[task_creator.question.question_name] = (
147
+ task_creator.status_log
148
+ )
149
+ return self._task_status_log_dict
150
+
151
+ @property
152
+ def token_usage(self) -> InterviewTokenUsage:
153
+ """Determine how many tokens were used for the interview."""
154
+ return self.task_creators.token_usage
155
+
156
+ @property
157
+ def interview_status(self) -> InterviewStatusDictionary:
158
+ """Return a dictionary mapping task status codes to counts."""
159
+ return self.task_creators.interview_status
160
+
161
+ # region: Serialization
162
+ def _to_dict(self, include_exceptions=True) -> dict[str, Any]:
163
+ """Return a dictionary representation of the Interview instance.
164
+ This is just for hashing purposes.
165
+
166
+ >>> i = Interview.example()
167
+ >>> hash(i)
168
+ 1217840301076717434
169
+ """
170
+ d = {
171
+ "agent": self.agent._to_dict(),
172
+ "survey": self.survey._to_dict(),
173
+ "scenario": self.scenario._to_dict(),
174
+ "model": self.model._to_dict(),
175
+ "iteration": self.iteration,
176
+ "exceptions": {},
177
+ }
178
+ if include_exceptions:
179
+ d["exceptions"] = self.exceptions.to_dict()
180
+ return d
181
+
182
+ @classmethod
183
+ def from_dict(cls, d: dict[str, Any]) -> "Interview":
184
+ """Return an Interview instance from a dictionary."""
185
+ agent = Agent.from_dict(d["agent"])
186
+ survey = Survey.from_dict(d["survey"])
187
+ scenario = Scenario.from_dict(d["scenario"])
188
+ model = LanguageModel.from_dict(d["model"])
189
+ iteration = d["iteration"]
190
+ return cls(agent=agent, survey=survey, scenario=scenario, model=model, iteration=iteration)
191
+
192
+ def __hash__(self) -> int:
193
+ from edsl.utilities.utilities import dict_hash
194
+
195
+ return dict_hash(self._to_dict(include_exceptions=False))
196
+
197
+ def __eq__(self, other: "Interview") -> bool:
198
+ """
199
+ >>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i._to_dict(); i2 = Interview.from_dict(d); i == i2
200
+ True
201
+ """
202
+ return hash(self) == hash(other)
203
+
204
+ # endregion
205
+
206
+ # region: Creating tasks
207
+ @property
208
+ def dag(self) -> "DAG":
209
+ """Return the directed acyclic graph for the survey.
210
+
211
+ The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
212
+ It is used to determine the order in which questions should be answered.
213
+ This reflects both agent 'memory' considerations and 'skip' logic.
214
+ The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
215
+
216
+ >>> i = Interview.example()
217
+ >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
218
+ True
219
+ """
220
+ return self.survey.dag(textify=True)
221
+
222
+ def _build_question_tasks(
223
+ self,
224
+ model_buckets: ModelBuckets,
225
+ ) -> list[asyncio.Task]:
226
+ """Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
227
+
228
+ :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
229
+ :param model_buckets: the model buckets used to track and control usage rates.
230
+ """
231
+ tasks = []
232
+ for question in self.survey.questions:
233
+ tasks_that_must_be_completed_before = list(
234
+ self._get_tasks_that_must_be_completed_before(
235
+ tasks=tasks, question=question
236
+ )
237
+ )
238
+ question_task = self._create_question_task(
239
+ question=question,
240
+ tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
241
+ model_buckets=model_buckets,
242
+ iteration=self.iteration,
243
+ )
244
+ tasks.append(question_task)
245
+ return tuple(tasks)
246
+
247
+ def _get_tasks_that_must_be_completed_before(
248
+ self, *, tasks: list[asyncio.Task], question: "QuestionBase"
249
+ ) -> Generator[asyncio.Task, None, None]:
250
+ """Return the tasks that must be completed before the given question can be answered.
251
+
252
+ :param tasks: a list of tasks that have been created so far.
253
+ :param question: the question for which we are determining dependencies.
254
+
255
+ If a question has no dependencies, this will be an empty list, [].
256
+ """
257
+ parents_of_focal_question = self.dag.get(question.question_name, [])
258
+ for parent_question_name in parents_of_focal_question:
259
+ yield tasks[self.to_index[parent_question_name]]
260
+
261
+ def _create_question_task(
262
+ self,
263
+ *,
264
+ question: QuestionBase,
265
+ tasks_that_must_be_completed_before: list[asyncio.Task],
266
+ model_buckets: ModelBuckets,
267
+ iteration: int = 0,
268
+ ) -> asyncio.Task:
269
+ """Create a task that depends on the passed-in dependencies that are awaited before the task is run.
270
+
271
+ :param question: the question to be answered. This is the question we are creating a task for.
272
+ :param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
273
+ :param model_buckets: the model buckets used to track and control usage rates.
274
+ :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
275
+ :param iteration: the iteration number for the interview.
276
+
277
+ The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
278
+ It is passed a reference to the function that will be called to answer the question.
279
+ It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
280
+ These are added as a dependency to the focal task.
281
+ """
282
+ task_creator = QuestionTaskCreator(
283
+ question=question,
284
+ answer_question_func=self._answer_question_and_record_task,
285
+ token_estimator=self._get_estimated_request_tokens,
286
+ model_buckets=model_buckets,
287
+ iteration=iteration,
288
+ )
289
+ for task in tasks_that_must_be_completed_before:
290
+ task_creator.add_dependency(task)
291
+
292
+ self.task_creators.update(
293
+ {question.question_name: task_creator}
294
+ ) # track this task creator
295
+ return task_creator.generate_task()
296
+
297
+ def _get_estimated_request_tokens(self, question) -> float:
298
+ """Estimate the number of tokens that will be required to run the focal task."""
299
+ from edsl.scenarios.FileStore import FileStore
300
+
301
+ invigilator = self._get_invigilator(question=question)
302
+ # TODO: There should be a way to get a more accurate estimate.
303
+ combined_text = ""
304
+ file_tokens = 0
305
+ for prompt in invigilator.get_prompts().values():
306
+ if hasattr(prompt, "text"):
307
+ combined_text += prompt.text
308
+ elif isinstance(prompt, str):
309
+ combined_text += prompt
310
+ elif isinstance(prompt, list):
311
+ for file in prompt:
312
+ if isinstance(file, FileStore):
313
+ file_tokens += file.size * 0.25
314
+ else:
315
+ raise ValueError(f"Prompt is of type {type(prompt)}")
316
+ return len(combined_text) / 4.0 + file_tokens
317
+
318
+ async def _answer_question_and_record_task(
319
+ self,
320
+ *,
321
+ question: "QuestionBase",
322
+ task=None,
323
+ ) -> "AgentResponseDict":
324
+ """Answer a question and records the task."""
325
+
326
+ had_language_model_no_response_error = False
327
+
328
+ @retry(
329
+ stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
330
+ wait=wait_exponential(
331
+ multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
332
+ ),
333
+ retry=retry_if_exception_type(LanguageModelNoResponseError),
334
+ reraise=True,
335
+ )
336
+ async def attempt_answer():
337
+ nonlocal had_language_model_no_response_error
338
+
339
+ invigilator = self._get_invigilator(question)
340
+
341
+ if self._skip_this_question(question):
342
+ return invigilator.get_failed_task_result(
343
+ failure_reason="Question skipped."
344
+ )
345
+
346
+ try:
347
+ response: EDSLResultObjectInput = (
348
+ await invigilator.async_answer_question()
349
+ )
350
+ if response.validated:
351
+ self.answers.add_answer(response=response, question=question)
352
+ self._cancel_skipped_questions(question)
353
+ else:
354
+ # When a question is not validated, it is not added to the answers.
355
+ # this should also cancel and dependent children questions.
356
+ # Is that happening now?
357
+ if (
358
+ hasattr(response, "exception_occurred")
359
+ and response.exception_occurred
360
+ ):
361
+ raise response.exception_occurred
362
+
363
+ except QuestionAnswerValidationError as e:
364
+ self._handle_exception(e, invigilator, task)
365
+ return invigilator.get_failed_task_result(
366
+ failure_reason="Question answer validation failed."
367
+ )
368
+
369
+ except asyncio.TimeoutError as e:
370
+ self._handle_exception(e, invigilator, task)
371
+ had_language_model_no_response_error = True
372
+ raise LanguageModelNoResponseError(
373
+ f"Language model timed out for question '{question.question_name}.'"
374
+ )
375
+
376
+ except Exception as e:
377
+ self._handle_exception(e, invigilator, task)
378
+
379
+ if "response" not in locals():
380
+ had_language_model_no_response_error = True
381
+ raise LanguageModelNoResponseError(
382
+ f"Language model did not return a response for question '{question.question_name}.'"
383
+ )
384
+
385
+ # if it gets here, it means the no response error was fixed
386
+ if (
387
+ question.question_name in self.exceptions
388
+ and had_language_model_no_response_error
389
+ ):
390
+ self.exceptions.record_fixed_question(question.question_name)
391
+
392
+ return response
393
+
394
+ try:
395
+ return await attempt_answer()
396
+ except RetryError as retry_error:
397
+ # All retries have failed for LanguageModelNoResponseError
398
+ original_error = retry_error.last_attempt.exception()
399
+ self._handle_exception(
400
+ original_error, self._get_invigilator(question), task
401
+ )
402
+ raise original_error # Re-raise the original error after handling
403
+
404
+ def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
405
+ """Return an invigilator for the given question.
406
+
407
+ :param question: the question to be answered
408
+ :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
409
+ """
410
+ invigilator = self.agent.create_invigilator(
411
+ question=question,
412
+ scenario=self.scenario,
413
+ model=self.model,
414
+ debug=False,
415
+ survey=self.survey,
416
+ memory_plan=self.survey.memory_plan,
417
+ current_answers=self.answers,
418
+ iteration=self.iteration,
419
+ cache=self.cache,
420
+ sidecar_model=self.sidecar_model,
421
+ raise_validation_errors=self.raise_validation_errors,
422
+ )
423
+ """Return an invigilator for the given question."""
424
+ return invigilator
425
+
426
+ def _skip_this_question(self, current_question: "QuestionBase") -> bool:
427
+ """Determine if the current question should be skipped.
428
+
429
+ :param current_question: the question to be answered.
430
+ """
431
+ current_question_index = self.to_index[current_question.question_name]
432
+
433
+ answers = self.answers | self.scenario | self.agent["traits"]
434
+ skip = self.survey.rule_collection.skip_question_before_running(
435
+ current_question_index, answers
436
+ )
437
+ return skip
438
+
439
+ def _handle_exception(
440
+ self, e: Exception, invigilator: "InvigilatorBase", task=None
441
+ ):
442
+ import copy
443
+
444
+ # breakpoint()
445
+
446
+ answers = copy.copy(self.answers)
447
+ exception_entry = InterviewExceptionEntry(
448
+ exception=e,
449
+ invigilator=invigilator,
450
+ answers=answers,
451
+ )
452
+ if task:
453
+ task.task_status = TaskStatus.FAILED
454
+ self.exceptions.add(invigilator.question.question_name, exception_entry)
455
+
456
+ if self.raise_validation_errors:
457
+ if isinstance(e, QuestionAnswerValidationError):
458
+ raise e
459
+
460
+ if hasattr(self, "stop_on_exception"):
461
+ stop_on_exception = self.stop_on_exception
462
+ else:
463
+ stop_on_exception = False
464
+
465
+ if stop_on_exception:
466
+ raise e
467
+
468
+ def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
469
+ """Cancel the tasks for questions that are skipped.
470
+
471
+ :param current_question: the question that was just answered.
472
+
473
+ It first determines the next question, given the current question and the current answers.
474
+ If the next question is the end of the survey, it cancels all remaining tasks.
475
+ If the next question is after the current question, it cancels all tasks between the current question and the next question.
476
+ """
477
+ current_question_index: int = self.to_index[current_question.question_name]
478
+
479
+ next_question: Union[int, EndOfSurvey] = (
480
+ self.survey.rule_collection.next_question(
481
+ q_now=current_question_index,
482
+ answers=self.answers | self.scenario | self.agent["traits"],
483
+ )
484
+ )
485
+
486
+ next_question_index = next_question.next_q
487
+
488
+ def cancel_between(start, end):
489
+ """Cancel the tasks between the start and end indices."""
490
+ for i in range(start, end):
491
+ self.tasks[i].cancel()
492
+
493
+ if next_question_index == EndOfSurvey:
494
+ cancel_between(current_question_index + 1, len(self.survey.questions))
495
+ return
496
+
497
+ if next_question_index > (current_question_index + 1):
498
+ cancel_between(current_question_index + 1, next_question_index)
499
+
500
+ # endregion
501
+
502
+ # region: Conducting the interview
503
+ async def async_conduct_interview(
504
+ self,
505
+ model_buckets: Optional[ModelBuckets] = None,
506
+ stop_on_exception: bool = False,
507
+ sidecar_model: Optional["LanguageModel"] = None,
508
+ raise_validation_errors: bool = True,
509
+ ) -> tuple["Answers", List[dict[str, Any]]]:
510
+ """
511
+ Conduct an Interview asynchronously.
512
+ It returns a tuple with the answers and a list of valid results.
513
+
514
+ :param model_buckets: a dictionary of token buckets for the model.
515
+ :param debug: run without calls to LLM.
516
+ :param stop_on_exception: if True, stops the interview if an exception is raised.
517
+ :param sidecar_model: a sidecar model used to answer questions.
518
+
519
+ Example usage:
520
+
521
+ >>> i = Interview.example()
522
+ >>> result, _ = asyncio.run(i.async_conduct_interview())
523
+ >>> result['q0']
524
+ 'yes'
525
+
526
+ >>> i = Interview.example(throw_exception = True)
527
+ >>> result, _ = asyncio.run(i.async_conduct_interview())
528
+ >>> i.exceptions
529
+ {'q0': ...
530
+ >>> i = Interview.example()
531
+ >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
532
+ Traceback (most recent call last):
533
+ ...
534
+ asyncio.exceptions.CancelledError
535
+ """
536
+ self.sidecar_model = sidecar_model
537
+ self.stop_on_exception = stop_on_exception
538
+
539
+ # if no model bucket is passed, create an 'infinity' bucket with no rate limits
540
+ if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
541
+ model_buckets = ModelBuckets.infinity_bucket()
542
+
543
+ ## This is the key part---it creates a task for each question,
544
+ ## with dependencies on the questions that must be answered before this one can be answered.
545
+ self.tasks = self._build_question_tasks(model_buckets=model_buckets)
546
+
547
+ ## 'Invigilators' are used to administer the survey
548
+ self.invigilators = [
549
+ self._get_invigilator(question) for question in self.survey.questions
550
+ ]
551
+ await asyncio.gather(
552
+ *self.tasks, return_exceptions=not stop_on_exception
553
+ ) # not stop_on_exception)
554
+ self.answers.replace_missing_answers_with_none(self.survey)
555
+ valid_results = list(self._extract_valid_results())
556
+ return self.answers, valid_results
557
+
558
+ # endregion
559
+
560
+ # region: Extracting results and recording errors
561
+ def _extract_valid_results(self) -> Generator["Answers", None, None]:
562
+ """Extract the valid results from the list of results.
563
+
564
+ It iterates through the tasks and invigilators, and yields the results of the tasks that are done.
565
+ If a task is not done, it raises a ValueError.
566
+ If an exception is raised in the task, it records the exception in the Interview instance except if the task was cancelled, which is expected behavior.
567
+
568
+ >>> i = Interview.example()
569
+ >>> result, _ = asyncio.run(i.async_conduct_interview())
570
+ >>> results = list(i._extract_valid_results())
571
+ >>> len(results) == len(i.survey)
572
+ True
573
+ """
574
+ assert len(self.tasks) == len(self.invigilators)
575
+
576
+ for task, invigilator in zip(self.tasks, self.invigilators):
577
+ if not task.done():
578
+ raise ValueError(f"Task {task.get_name()} is not done.")
579
+
580
+ try:
581
+ result = task.result()
582
+ except asyncio.CancelledError as e: # task was cancelled
583
+ result = invigilator.get_failed_task_result(
584
+ failure_reason="Task was cancelled."
585
+ )
586
+ except Exception as e: # any other kind of exception in the task
587
+ result = invigilator.get_failed_task_result(
588
+ failure_reason=f"Task failed with exception: {str(e)}."
589
+ )
590
+ exception_entry = InterviewExceptionEntry(
591
+ exception=e,
592
+ invigilator=invigilator,
593
+ )
594
+ self.exceptions.add(task.get_name(), exception_entry)
595
+
596
+ yield result
597
+
598
+ # endregion
599
+
600
+ # region: Magic methods
601
+ def __repr__(self) -> str:
602
+ """Return a string representation of the Interview instance."""
603
+ return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
604
+
605
+ def duplicate(self, iteration: int, cache: "Cache") -> Interview:
606
+ """Duplicate the interview, but with a new iteration number and cache.
607
+
608
+ >>> i = Interview.example()
609
+ >>> i2 = i.duplicate(1, None)
610
+ >>> i.iteration + 1 == i2.iteration
611
+ True
612
+
613
+ """
614
+ return Interview(
615
+ agent=self.agent,
616
+ survey=self.survey,
617
+ scenario=self.scenario,
618
+ model=self.model,
619
+ iteration=iteration,
620
+ cache=cache,
621
+ skip_retry=self.skip_retry,
622
+ )
623
+
624
+ @classmethod
625
+ def example(self, throw_exception: bool = False) -> Interview:
626
+ """Return an example Interview instance."""
627
+ from edsl.agents import Agent
628
+ from edsl.surveys import Survey
629
+ from edsl.scenarios import Scenario
630
+ from edsl.language_models import LanguageModel
631
+
632
+ def f(self, question, scenario):
633
+ return "yes"
634
+
635
+ agent = Agent.example()
636
+ agent.add_direct_question_answering_method(f)
637
+ survey = Survey.example()
638
+ scenario = Scenario.example()
639
+ model = LanguageModel.example()
640
+ if throw_exception:
641
+ model = LanguageModel.example(test_model=True, throw_exception=True)
642
+ agent = Agent.example()
643
+ return Interview(agent=agent, survey=survey, scenario=scenario, model=model)
644
+ return Interview(agent=agent, survey=survey, scenario=scenario, model=model)
645
+
646
+
647
if __name__ == "__main__":
    import doctest

    # ELLIPSIS lets the doctests elide volatile output (timestamps, long reprs)
    doctest.testmod(optionflags=doctest.ELLIPSIS)