edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. edsl/Base.py +3 -9
  2. edsl/__init__.py +3 -8
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -40
  5. edsl/agents/AgentList.py +0 -43
  6. edsl/agents/Invigilator.py +219 -135
  7. edsl/agents/InvigilatorBase.py +59 -148
  8. edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
  9. edsl/agents/__init__.py +0 -1
  10. edsl/config.py +56 -47
  11. edsl/coop/coop.py +7 -50
  12. edsl/data/Cache.py +1 -35
  13. edsl/data_transfer_models.py +38 -73
  14. edsl/enums.py +0 -4
  15. edsl/exceptions/language_models.py +1 -25
  16. edsl/exceptions/questions.py +5 -62
  17. edsl/exceptions/results.py +0 -4
  18. edsl/inference_services/AnthropicService.py +11 -13
  19. edsl/inference_services/AwsBedrock.py +17 -19
  20. edsl/inference_services/AzureAI.py +20 -37
  21. edsl/inference_services/GoogleService.py +12 -16
  22. edsl/inference_services/GroqService.py +0 -2
  23. edsl/inference_services/InferenceServiceABC.py +3 -58
  24. edsl/inference_services/OpenAIService.py +54 -48
  25. edsl/inference_services/models_available_cache.py +6 -0
  26. edsl/inference_services/registry.py +0 -6
  27. edsl/jobs/Answers.py +12 -10
  28. edsl/jobs/Jobs.py +21 -36
  29. edsl/jobs/buckets/BucketCollection.py +15 -24
  30. edsl/jobs/buckets/TokenBucket.py +14 -93
  31. edsl/jobs/interviews/Interview.py +78 -366
  32. edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
  33. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
  34. edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
  35. edsl/jobs/interviews/retry_management.py +37 -0
  36. edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
  37. edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
  38. edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
  39. edsl/jobs/tasks/TaskHistory.py +213 -148
  40. edsl/language_models/LanguageModel.py +156 -261
  41. edsl/language_models/ModelList.py +2 -2
  42. edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
  43. edsl/language_models/registry.py +6 -23
  44. edsl/language_models/repair.py +19 -0
  45. edsl/prompts/Prompt.py +2 -52
  46. edsl/questions/AnswerValidatorMixin.py +26 -23
  47. edsl/questions/QuestionBase.py +249 -329
  48. edsl/questions/QuestionBudget.py +41 -99
  49. edsl/questions/QuestionCheckBox.py +35 -227
  50. edsl/questions/QuestionExtract.py +27 -98
  51. edsl/questions/QuestionFreeText.py +29 -52
  52. edsl/questions/QuestionFunctional.py +0 -7
  53. edsl/questions/QuestionList.py +22 -141
  54. edsl/questions/QuestionMultipleChoice.py +65 -159
  55. edsl/questions/QuestionNumerical.py +46 -88
  56. edsl/questions/QuestionRank.py +24 -182
  57. edsl/questions/RegisterQuestionsMeta.py +12 -31
  58. edsl/questions/__init__.py +4 -3
  59. edsl/questions/derived/QuestionLikertFive.py +5 -10
  60. edsl/questions/derived/QuestionLinearScale.py +2 -15
  61. edsl/questions/derived/QuestionTopK.py +1 -10
  62. edsl/questions/derived/QuestionYesNo.py +3 -24
  63. edsl/questions/descriptors.py +7 -43
  64. edsl/questions/question_registry.py +2 -6
  65. edsl/results/Dataset.py +0 -20
  66. edsl/results/DatasetExportMixin.py +48 -46
  67. edsl/results/Result.py +5 -32
  68. edsl/results/Results.py +46 -135
  69. edsl/results/ResultsDBMixin.py +3 -3
  70. edsl/scenarios/FileStore.py +10 -71
  71. edsl/scenarios/Scenario.py +25 -96
  72. edsl/scenarios/ScenarioImageMixin.py +2 -2
  73. edsl/scenarios/ScenarioList.py +39 -361
  74. edsl/scenarios/ScenarioListExportMixin.py +0 -9
  75. edsl/scenarios/ScenarioListPdfMixin.py +4 -150
  76. edsl/study/SnapShot.py +1 -8
  77. edsl/study/Study.py +0 -32
  78. edsl/surveys/Rule.py +1 -10
  79. edsl/surveys/RuleCollection.py +5 -21
  80. edsl/surveys/Survey.py +310 -636
  81. edsl/surveys/SurveyExportMixin.py +9 -71
  82. edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
  83. edsl/surveys/SurveyQualtricsImport.py +4 -75
  84. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  85. edsl/utilities/utilities.py +1 -9
  86. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
  87. edsl-0.1.33.dev1.dist-info/RECORD +209 -0
  88. edsl/TemplateLoader.py +0 -24
  89. edsl/auto/AutoStudy.py +0 -117
  90. edsl/auto/StageBase.py +0 -230
  91. edsl/auto/StageGenerateSurvey.py +0 -178
  92. edsl/auto/StageLabelQuestions.py +0 -125
  93. edsl/auto/StagePersona.py +0 -61
  94. edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
  95. edsl/auto/StagePersonaDimensionValues.py +0 -74
  96. edsl/auto/StagePersonaDimensions.py +0 -69
  97. edsl/auto/StageQuestions.py +0 -73
  98. edsl/auto/SurveyCreatorPipeline.py +0 -21
  99. edsl/auto/utilities.py +0 -224
  100. edsl/coop/PriceFetcher.py +0 -58
  101. edsl/inference_services/MistralAIService.py +0 -120
  102. edsl/inference_services/TestService.py +0 -80
  103. edsl/inference_services/TogetherAIService.py +0 -170
  104. edsl/jobs/FailedQuestion.py +0 -78
  105. edsl/jobs/runners/JobsRunnerStatus.py +0 -331
  106. edsl/language_models/fake_openai_call.py +0 -15
  107. edsl/language_models/fake_openai_service.py +0 -61
  108. edsl/language_models/utilities.py +0 -61
  109. edsl/questions/QuestionBaseGenMixin.py +0 -133
  110. edsl/questions/QuestionBasePromptsMixin.py +0 -266
  111. edsl/questions/Quick.py +0 -41
  112. edsl/questions/ResponseValidatorABC.py +0 -170
  113. edsl/questions/decorators.py +0 -21
  114. edsl/questions/prompt_templates/question_budget.jinja +0 -13
  115. edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
  116. edsl/questions/prompt_templates/question_extract.jinja +0 -11
  117. edsl/questions/prompt_templates/question_free_text.jinja +0 -3
  118. edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
  119. edsl/questions/prompt_templates/question_list.jinja +0 -17
  120. edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
  121. edsl/questions/prompt_templates/question_numerical.jinja +0 -37
  122. edsl/questions/templates/__init__.py +0 -0
  123. edsl/questions/templates/budget/__init__.py +0 -0
  124. edsl/questions/templates/budget/answering_instructions.jinja +0 -7
  125. edsl/questions/templates/budget/question_presentation.jinja +0 -7
  126. edsl/questions/templates/checkbox/__init__.py +0 -0
  127. edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
  128. edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
  129. edsl/questions/templates/extract/__init__.py +0 -0
  130. edsl/questions/templates/extract/answering_instructions.jinja +0 -7
  131. edsl/questions/templates/extract/question_presentation.jinja +0 -1
  132. edsl/questions/templates/free_text/__init__.py +0 -0
  133. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  134. edsl/questions/templates/free_text/question_presentation.jinja +0 -1
  135. edsl/questions/templates/likert_five/__init__.py +0 -0
  136. edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
  137. edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
  138. edsl/questions/templates/linear_scale/__init__.py +0 -0
  139. edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
  140. edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
  141. edsl/questions/templates/list/__init__.py +0 -0
  142. edsl/questions/templates/list/answering_instructions.jinja +0 -4
  143. edsl/questions/templates/list/question_presentation.jinja +0 -5
  144. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  145. edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
  146. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  147. edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
  148. edsl/questions/templates/numerical/__init__.py +0 -0
  149. edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
  150. edsl/questions/templates/numerical/question_presentation.jinja +0 -7
  151. edsl/questions/templates/rank/__init__.py +0 -0
  152. edsl/questions/templates/rank/answering_instructions.jinja +0 -11
  153. edsl/questions/templates/rank/question_presentation.jinja +0 -15
  154. edsl/questions/templates/top_k/__init__.py +0 -0
  155. edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
  156. edsl/questions/templates/top_k/question_presentation.jinja +0 -22
  157. edsl/questions/templates/yes_no/__init__.py +0 -0
  158. edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
  159. edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
  160. edsl/results/DatasetTree.py +0 -145
  161. edsl/results/Selector.py +0 -118
  162. edsl/results/tree_explore.py +0 -115
  163. edsl/surveys/instructions/ChangeInstruction.py +0 -47
  164. edsl/surveys/instructions/Instruction.py +0 -34
  165. edsl/surveys/instructions/InstructionCollection.py +0 -77
  166. edsl/surveys/instructions/__init__.py +0 -0
  167. edsl/templates/error_reporting/base.html +0 -24
  168. edsl/templates/error_reporting/exceptions_by_model.html +0 -35
  169. edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
  170. edsl/templates/error_reporting/exceptions_by_type.html +0 -17
  171. edsl/templates/error_reporting/interview_details.html +0 -116
  172. edsl/templates/error_reporting/interviews.html +0 -10
  173. edsl/templates/error_reporting/overview.html +0 -5
  174. edsl/templates/error_reporting/performance_plot.html +0 -2
  175. edsl/templates/error_reporting/report.css +0 -74
  176. edsl/templates/error_reporting/report.html +0 -118
  177. edsl/templates/error_reporting/report.js +0 -25
  178. edsl-0.1.33.dist-info/RECORD +0 -295
  179. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
  180. {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
@@ -1,76 +1,50 @@
1
1
  """This module contains the Interview class, which is responsible for conducting an interview asynchronously."""
2
2
 
3
3
  from __future__ import annotations
4
+ import traceback
4
5
  import asyncio
5
- from typing import Any, Type, List, Generator, Optional, Union
6
-
7
- from tenacity import (
8
- retry,
9
- stop_after_attempt,
10
- wait_exponential,
11
- retry_if_exception_type,
12
- RetryError,
13
- )
6
+ import time
7
+ from typing import Any, Type, List, Generator, Optional
14
8
 
15
- from edsl import CONFIG
9
+ from edsl.jobs.Answers import Answers
16
10
  from edsl.surveys.base import EndOfSurvey
17
- from edsl.exceptions import QuestionAnswerValidationError
18
- from edsl.exceptions import QuestionAnswerValidationError
19
- from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
20
-
21
11
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
22
- from edsl.jobs.Answers import Answers
23
- from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
24
12
  from edsl.jobs.tasks.TaskCreators import TaskCreators
13
+
25
14
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
26
- from edsl.jobs.interviews.InterviewExceptionCollection import (
15
+ from edsl.jobs.interviews.interview_exception_tracking import (
27
16
  InterviewExceptionCollection,
28
17
  )
29
-
30
- from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
31
-
32
- from edsl.surveys.base import EndOfSurvey
33
- from edsl.jobs.buckets.ModelBuckets import ModelBuckets
34
18
  from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
35
- from edsl.jobs.tasks.task_status_enum import TaskStatus
36
- from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
37
-
38
-
39
- from edsl import Agent, Survey, Scenario, Cache
40
- from edsl.language_models import LanguageModel
41
- from edsl.questions import QuestionBase
42
- from edsl.agents.InvigilatorBase import InvigilatorBase
43
-
44
- from edsl.exceptions.language_models import LanguageModelNoResponseError
19
+ from edsl.jobs.interviews.retry_management import retry_strategy
20
+ from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
21
+ from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
45
22
 
23
+ import asyncio
46
24
 
47
- from edsl import CONFIG
48
25
 
49
- EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
50
- EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
51
- EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
26
+ def run_async(coro):
27
+ return asyncio.run(coro)
52
28
 
53
29
 
54
- class Interview(InterviewStatusMixin):
30
+ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
55
31
  """
56
32
  An 'interview' is one agent answering one survey, with one language model, for a given scenario.
57
33
 
58
34
  The main method is `async_conduct_interview`, which conducts the interview asynchronously.
59
- Most of the class is dedicated to creating the tasks for each question in the survey, and then running them.
60
35
  """
61
36
 
62
37
  def __init__(
63
38
  self,
64
- agent: Agent,
65
- survey: Survey,
66
- scenario: Scenario,
39
+ agent: "Agent",
40
+ survey: "Survey",
41
+ scenario: "Scenario",
67
42
  model: Type["LanguageModel"],
68
43
  debug: Optional[bool] = False,
69
44
  iteration: int = 0,
70
45
  cache: Optional["Cache"] = None,
71
46
  sidecar_model: Optional["LanguageModel"] = None,
72
- skip_retry: bool = False,
73
- raise_validation_errors: bool = True,
47
+ skip_retry=False,
74
48
  ):
75
49
  """Initialize the Interview instance.
76
50
 
@@ -110,15 +84,11 @@ class Interview(InterviewStatusMixin):
110
84
  ] = Answers() # will get filled in as interview progresses
111
85
  self.sidecar_model = sidecar_model
112
86
 
113
- # self.stop_on_exception = False
114
-
115
87
  # Trackers
116
88
  self.task_creators = TaskCreators() # tracks the task creators
117
89
  self.exceptions = InterviewExceptionCollection()
118
-
119
90
  self._task_status_log_dict = InterviewStatusLog()
120
91
  self.skip_retry = skip_retry
121
- self.raise_validation_errors = raise_validation_errors
122
92
 
123
93
  # dictionary mapping question names to their index in the survey.
124
94
  self.to_index = {
@@ -126,9 +96,6 @@ class Interview(InterviewStatusMixin):
126
96
  for index, question_name in enumerate(self.survey.question_names)
127
97
  }
128
98
 
129
- self.failed_questions = []
130
-
131
- # region: Serialization
132
99
  def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
133
100
  """Return a dictionary representation of the Interview instance.
134
101
  This is just for hashing purposes.
@@ -153,301 +120,13 @@ class Interview(InterviewStatusMixin):
153
120
 
154
121
  return dict_hash(self._to_dict())
155
122
 
156
- # endregion
157
-
158
- # region: Creating tasks
159
- @property
160
- def dag(self) -> "DAG":
161
- """Return the directed acyclic graph for the survey.
162
-
163
- The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
164
- It is used to determine the order in which questions should be answered.
165
- This reflects both agent 'memory' considerations and 'skip' logic.
166
- The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
167
-
168
- >>> i = Interview.example()
169
- >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
170
- True
171
- """
172
- return self.survey.dag(textify=True)
173
-
174
- def _build_question_tasks(
175
- self,
176
- model_buckets: ModelBuckets,
177
- ) -> list[asyncio.Task]:
178
- """Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
179
-
180
- :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
181
- :param model_buckets: the model buckets used to track and control usage rates.
182
- """
183
- tasks = []
184
- for question in self.survey.questions:
185
- tasks_that_must_be_completed_before = list(
186
- self._get_tasks_that_must_be_completed_before(
187
- tasks=tasks, question=question
188
- )
189
- )
190
- question_task = self._create_question_task(
191
- question=question,
192
- tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
193
- model_buckets=model_buckets,
194
- iteration=self.iteration,
195
- )
196
- tasks.append(question_task)
197
- return tuple(tasks)
198
-
199
- def _get_tasks_that_must_be_completed_before(
200
- self, *, tasks: list[asyncio.Task], question: "QuestionBase"
201
- ) -> Generator[asyncio.Task, None, None]:
202
- """Return the tasks that must be completed before the given question can be answered.
203
-
204
- :param tasks: a list of tasks that have been created so far.
205
- :param question: the question for which we are determining dependencies.
206
-
207
- If a question has no dependencies, this will be an empty list, [].
208
- """
209
- parents_of_focal_question = self.dag.get(question.question_name, [])
210
- for parent_question_name in parents_of_focal_question:
211
- yield tasks[self.to_index[parent_question_name]]
212
-
213
- def _create_question_task(
214
- self,
215
- *,
216
- question: QuestionBase,
217
- tasks_that_must_be_completed_before: list[asyncio.Task],
218
- model_buckets: ModelBuckets,
219
- iteration: int = 0,
220
- ) -> asyncio.Task:
221
- """Create a task that depends on the passed-in dependencies that are awaited before the task is run.
222
-
223
- :param question: the question to be answered. This is the question we are creating a task for.
224
- :param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
225
- :param model_buckets: the model buckets used to track and control usage rates.
226
- :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
227
- :param iteration: the iteration number for the interview.
228
-
229
- The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
230
- It is passed a reference to the function that will be called to answer the question.
231
- It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
232
- These are added as a dependency to the focal task.
233
- """
234
- task_creator = QuestionTaskCreator(
235
- question=question,
236
- answer_question_func=self._answer_question_and_record_task,
237
- token_estimator=self._get_estimated_request_tokens,
238
- model_buckets=model_buckets,
239
- iteration=iteration,
240
- )
241
- for task in tasks_that_must_be_completed_before:
242
- task_creator.add_dependency(task)
243
-
244
- self.task_creators.update(
245
- {question.question_name: task_creator}
246
- ) # track this task creator
247
- return task_creator.generate_task()
248
-
249
- def _get_estimated_request_tokens(self, question) -> float:
250
- """Estimate the number of tokens that will be required to run the focal task."""
251
- invigilator = self._get_invigilator(question=question)
252
- # TODO: There should be a way to get a more accurate estimate.
253
- combined_text = ""
254
- for prompt in invigilator.get_prompts().values():
255
- if hasattr(prompt, "text"):
256
- combined_text += prompt.text
257
- elif isinstance(prompt, str):
258
- combined_text += prompt
259
- else:
260
- raise ValueError(f"Prompt is of type {type(prompt)}")
261
- return len(combined_text) / 4.0
262
-
263
- async def _answer_question_and_record_task(
264
- self,
265
- *,
266
- question: "QuestionBase",
267
- task=None,
268
- ) -> "AgentResponseDict":
269
- """Answer a question and records the task."""
270
-
271
- had_language_model_no_response_error = False
272
-
273
- @retry(
274
- stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
275
- wait=wait_exponential(
276
- multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
277
- ),
278
- retry=retry_if_exception_type(LanguageModelNoResponseError),
279
- reraise=True,
280
- )
281
- async def attempt_answer():
282
- nonlocal had_language_model_no_response_error
283
-
284
- invigilator = self._get_invigilator(question)
285
-
286
- if self._skip_this_question(question):
287
- return invigilator.get_failed_task_result(
288
- failure_reason="Question skipped."
289
- )
290
-
291
- try:
292
- response: EDSLResultObjectInput = (
293
- await invigilator.async_answer_question()
294
- )
295
- if response.validated:
296
- self.answers.add_answer(response=response, question=question)
297
- self._cancel_skipped_questions(question)
298
- else:
299
- if (
300
- hasattr(response, "exception_occurred")
301
- and response.exception_occurred
302
- ):
303
- raise response.exception_occurred
304
-
305
- except QuestionAnswerValidationError as e:
306
- self._handle_exception(e, invigilator, task)
307
- return invigilator.get_failed_task_result(
308
- failure_reason="Question answer validation failed."
309
- )
310
-
311
- except asyncio.TimeoutError as e:
312
- self._handle_exception(e, invigilator, task)
313
- had_language_model_no_response_error = True
314
- raise LanguageModelNoResponseError(
315
- f"Language model timed out for question '{question.question_name}.'"
316
- )
317
-
318
- except Exception as e:
319
- self._handle_exception(e, invigilator, task)
320
-
321
- if "response" not in locals():
322
- had_language_model_no_response_error = True
323
- raise LanguageModelNoResponseError(
324
- f"Language model did not return a response for question '{question.question_name}.'"
325
- )
326
-
327
- # if it gets here, it means the no response error was fixed
328
- if (
329
- question.question_name in self.exceptions
330
- and had_language_model_no_response_error
331
- ):
332
- self.exceptions.record_fixed_question(question.question_name)
333
-
334
- return response
335
-
336
- try:
337
- return await attempt_answer()
338
- except RetryError as retry_error:
339
- # All retries have failed for LanguageModelNoResponseError
340
- original_error = retry_error.last_attempt.exception()
341
- self._handle_exception(
342
- original_error, self._get_invigilator(question), task
343
- )
344
- raise original_error # Re-raise the original error after handling
345
-
346
- def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
347
- """Return an invigilator for the given question.
348
-
349
- :param question: the question to be answered
350
- :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
351
- """
352
- invigilator = self.agent.create_invigilator(
353
- question=question,
354
- scenario=self.scenario,
355
- model=self.model,
356
- debug=False,
357
- survey=self.survey,
358
- memory_plan=self.survey.memory_plan,
359
- current_answers=self.answers,
360
- iteration=self.iteration,
361
- cache=self.cache,
362
- sidecar_model=self.sidecar_model,
363
- raise_validation_errors=self.raise_validation_errors,
364
- )
365
- """Return an invigilator for the given question."""
366
- return invigilator
367
-
368
- def _skip_this_question(self, current_question: "QuestionBase") -> bool:
369
- """Determine if the current question should be skipped.
370
-
371
- :param current_question: the question to be answered.
372
- """
373
- current_question_index = self.to_index[current_question.question_name]
374
-
375
- answers = self.answers | self.scenario | self.agent["traits"]
376
- skip = self.survey.rule_collection.skip_question_before_running(
377
- current_question_index, answers
378
- )
379
- return skip
380
-
381
- def _handle_exception(
382
- self, e: Exception, invigilator: "InvigilatorBase", task=None
383
- ):
384
- import copy
385
-
386
- # breakpoint()
387
-
388
- answers = copy.copy(self.answers)
389
- exception_entry = InterviewExceptionEntry(
390
- exception=e,
391
- invigilator=invigilator,
392
- answers=answers,
393
- )
394
- if task:
395
- task.task_status = TaskStatus.FAILED
396
- self.exceptions.add(invigilator.question.question_name, exception_entry)
397
-
398
- if self.raise_validation_errors:
399
- if isinstance(e, QuestionAnswerValidationError):
400
- raise e
401
-
402
- if hasattr(self, "stop_on_exception"):
403
- stop_on_exception = self.stop_on_exception
404
- else:
405
- stop_on_exception = False
406
-
407
- if stop_on_exception:
408
- raise e
409
-
410
- def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
411
- """Cancel the tasks for questions that are skipped.
412
-
413
- :param current_question: the question that was just answered.
414
-
415
- It first determines the next question, given the current question and the current answers.
416
- If the next question is the end of the survey, it cancels all remaining tasks.
417
- If the next question is after the current question, it cancels all tasks between the current question and the next question.
418
- """
419
- current_question_index: int = self.to_index[current_question.question_name]
420
-
421
- next_question: Union[
422
- int, EndOfSurvey
423
- ] = self.survey.rule_collection.next_question(
424
- q_now=current_question_index,
425
- answers=self.answers | self.scenario | self.agent["traits"],
426
- )
427
-
428
- next_question_index = next_question.next_q
429
-
430
- def cancel_between(start, end):
431
- """Cancel the tasks between the start and end indices."""
432
- for i in range(start, end):
433
- self.tasks[i].cancel()
434
-
435
- if next_question_index == EndOfSurvey:
436
- cancel_between(current_question_index + 1, len(self.survey.questions))
437
- return
438
-
439
- if next_question_index > (current_question_index + 1):
440
- cancel_between(current_question_index + 1, next_question_index)
441
-
442
- # endregion
443
-
444
- # region: Conducting the interview
445
123
  async def async_conduct_interview(
446
124
  self,
447
- model_buckets: Optional[ModelBuckets] = None,
125
+ *,
126
+ model_buckets: ModelBuckets = None,
127
+ debug: bool = False,
448
128
  stop_on_exception: bool = False,
449
129
  sidecar_model: Optional["LanguageModel"] = None,
450
- raise_validation_errors: bool = True,
451
130
  ) -> tuple["Answers", List[dict[str, Any]]]:
452
131
  """
453
132
  Conduct an Interview asynchronously.
@@ -467,6 +146,19 @@ class Interview(InterviewStatusMixin):
467
146
 
468
147
  >>> i = Interview.example(throw_exception = True)
469
148
  >>> result, _ = asyncio.run(i.async_conduct_interview())
149
+ Attempt 1 failed with exception:This is a test error now waiting 1.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
150
+ <BLANKLINE>
151
+ <BLANKLINE>
152
+ Attempt 2 failed with exception:This is a test error now waiting 2.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
153
+ <BLANKLINE>
154
+ <BLANKLINE>
155
+ Attempt 3 failed with exception:This is a test error now waiting 4.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
156
+ <BLANKLINE>
157
+ <BLANKLINE>
158
+ Attempt 4 failed with exception:This is a test error now waiting 8.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
159
+ <BLANKLINE>
160
+ <BLANKLINE>
161
+
470
162
  >>> i.exceptions
471
163
  {'q0': ...
472
164
  >>> i = Interview.example()
@@ -476,30 +168,26 @@ class Interview(InterviewStatusMixin):
476
168
  asyncio.exceptions.CancelledError
477
169
  """
478
170
  self.sidecar_model = sidecar_model
479
- self.stop_on_exception = stop_on_exception
480
171
 
481
172
  # if no model bucket is passed, create an 'infinity' bucket with no rate limits
482
173
  if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
483
174
  model_buckets = ModelBuckets.infinity_bucket()
484
175
 
176
+ ## build the tasks using the InterviewTaskBuildingMixin
485
177
  ## This is the key part---it creates a task for each question,
486
178
  ## with dependencies on the questions that must be answered before this one can be answered.
487
- self.tasks = self._build_question_tasks(model_buckets=model_buckets)
179
+ self.tasks = self._build_question_tasks(
180
+ debug=debug, model_buckets=model_buckets
181
+ )
488
182
 
489
183
  ## 'Invigilators' are used to administer the survey
490
- self.invigilators = [
491
- self._get_invigilator(question) for question in self.survey.questions
492
- ]
493
- await asyncio.gather(
494
- *self.tasks, return_exceptions=not stop_on_exception
495
- ) # not stop_on_exception)
184
+ self.invigilators = list(self._build_invigilators(debug=debug))
185
+ # await the tasks being conducted
186
+ await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
496
187
  self.answers.replace_missing_answers_with_none(self.survey)
497
188
  valid_results = list(self._extract_valid_results())
498
189
  return self.answers, valid_results
499
190
 
500
- # endregion
501
-
502
- # region: Extracting results and recording errors
503
191
  def _extract_valid_results(self) -> Generator["Answers", None, None]:
504
192
  """Extract the valid results from the list of results.
505
193
 
@@ -512,6 +200,8 @@ class Interview(InterviewStatusMixin):
512
200
  >>> results = list(i._extract_valid_results())
513
201
  >>> len(results) == len(i.survey)
514
202
  True
203
+ >>> type(results[0])
204
+ <class 'edsl.data_transfer_models.AgentResponseDict'>
515
205
  """
516
206
  assert len(self.tasks) == len(self.invigilators)
517
207
 
@@ -522,24 +212,46 @@ class Interview(InterviewStatusMixin):
522
212
  try:
523
213
  result = task.result()
524
214
  except asyncio.CancelledError as e: # task was cancelled
525
- result = invigilator.get_failed_task_result(
526
- failure_reason="Task was cancelled."
527
- )
215
+ result = invigilator.get_failed_task_result()
528
216
  except Exception as e: # any other kind of exception in the task
529
- result = invigilator.get_failed_task_result(
530
- failure_reason=f"Task failed with exception: {str(e)}."
531
- )
532
- exception_entry = InterviewExceptionEntry(
533
- exception=e,
534
- invigilator=invigilator,
535
- )
536
- self.exceptions.add(task.get_name(), exception_entry)
537
-
217
+ result = invigilator.get_failed_task_result()
218
+ self._record_exception(task, e)
538
219
  yield result
539
220
 
540
- # endregion
221
+ def _record_exception(self, task, exception: Exception) -> None:
222
+ """Record an exception in the Interview instance.
223
+
224
+ It records the exception in the Interview instance, with the task name and the exception entry.
225
+
226
+ >>> i = Interview.example()
227
+ >>> result, _ = asyncio.run(i.async_conduct_interview())
228
+ >>> i.exceptions
229
+ {}
230
+ >>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
231
+ >>> i.exceptions
232
+ {'q0': ...
233
+ """
234
+ exception_entry = InterviewExceptionEntry(exception)
235
+ self.exceptions.add(task.get_name(), exception_entry)
236
+
237
+ @property
238
+ def dag(self) -> "DAG":
239
+ """Return the directed acyclic graph for the survey.
240
+
241
+ The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
242
+ It is used to determine the order in which questions should be answered.
243
+ This reflects both agent 'memory' considerations and 'skip' logic.
244
+ The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
245
+
246
+ >>> i = Interview.example()
247
+ >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
248
+ True
249
+ """
250
+ return self.survey.dag(textify=True)
541
251
 
542
- # region: Magic methods
252
+ #######################
253
+ # Dunder methods
254
+ #######################
543
255
  def __repr__(self) -> str:
544
256
  """Return a string representation of the Interview instance."""
545
257
  return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
@@ -2,62 +2,24 @@ import traceback
2
2
  import datetime
3
3
  import time
4
4
  from collections import UserDict
5
- from edsl.jobs.FailedQuestion import FailedQuestion
5
+
6
+ # traceback=traceback.format_exc(),
7
+ # traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
8
+ # traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
6
9
 
7
10
 
8
11
  class InterviewExceptionEntry:
9
- """Class to record an exception that occurred during the interview."""
10
-
11
- def __init__(
12
- self,
13
- *,
14
- exception: Exception,
15
- # failed_question: FailedQuestion,
16
- invigilator: "Invigilator",
17
- traceback_format="text",
18
- answers=None,
19
- ):
12
+ """Class to record an exception that occurred during the interview.
13
+
14
+ >>> entry = InterviewExceptionEntry.example()
15
+ >>> entry.to_dict()['exception']
16
+ "ValueError('An error occurred.')"
17
+ """
18
+
19
+ def __init__(self, exception: Exception, traceback_format="html"):
20
20
  self.time = datetime.datetime.now().isoformat()
21
21
  self.exception = exception
22
- # self.failed_question = failed_question
23
- self.invigilator = invigilator
24
22
  self.traceback_format = traceback_format
25
- self.answers = answers
26
-
27
- @property
28
- def question_type(self):
29
- # return self.failed_question.question.question_type
30
- return self.invigilator.question.question_type
31
-
32
- @property
33
- def name(self):
34
- return repr(self.exception)
35
-
36
- @property
37
- def rendered_prompts(self):
38
- return self.invigilator.get_prompts()
39
-
40
- @property
41
- def key_sequence(self):
42
- return self.invigilator.model.key_sequence
43
-
44
- @property
45
- def generated_token_string(self):
46
- # return "POO"
47
- if self.invigilator.raw_model_response is None:
48
- return "No raw model response available."
49
- else:
50
- return self.invigilator.model.get_generated_token_string(
51
- self.invigilator.raw_model_response
52
- )
53
-
54
- @property
55
- def raw_model_response(self):
56
- import json
57
-
58
- if self.invigilator.raw_model_response is None:
59
- return "No raw model response available."
60
- return json.dumps(self.invigilator.raw_model_response, indent=2)
61
23
 
62
24
  def __getitem__(self, key):
63
25
  # Support dict-like access obj['a']
@@ -65,37 +27,11 @@ class InterviewExceptionEntry:
65
27
 
66
28
  @classmethod
67
29
  def example(cls):
68
- from edsl import QuestionFreeText
69
- from edsl.language_models import LanguageModel
70
-
71
- m = LanguageModel.example(test_model=True)
72
- q = QuestionFreeText.example(exception_to_throw=ValueError)
73
- results = q.by(m).run(
74
- skip_retry=True, print_exceptions=False, raise_validation_errors=True
75
- )
76
- return results.task_history.exceptions[0]["how_are_you"][0]
77
-
78
- @property
79
- def code_to_reproduce(self):
80
- return self.code(run=False)
81
-
82
- def code(self, run=True):
83
- lines = []
84
- lines.append("from edsl import Question, Model, Scenario, Agent")
85
-
86
- lines.append(f"q = {repr(self.invigilator.question)}")
87
- lines.append(f"scenario = {repr(self.invigilator.scenario)}")
88
- lines.append(f"agent = {repr(self.invigilator.agent)}")
89
- lines.append(f"m = Model('{self.invigilator.model.model}')")
90
- lines.append("results = q.by(m).by(agent).by(scenario).run()")
91
- code_str = "\n".join(lines)
92
-
93
- if run:
94
- # Create a new namespace to avoid polluting the global namespace
95
- namespace = {}
96
- exec(code_str, namespace)
97
- return namespace["results"]
98
- return code_str
30
+ try:
31
+ raise ValueError("An error occurred.")
32
+ except Exception as e:
33
+ entry = InterviewExceptionEntry(e)
34
+ return entry
99
35
 
100
36
  @property
101
37
  def traceback(self):
@@ -142,15 +78,13 @@ class InterviewExceptionEntry:
142
78
 
143
79
  >>> entry = InterviewExceptionEntry.example()
144
80
  >>> entry.to_dict()['exception']
145
- ValueError()
81
+ "ValueError('An error occurred.')"
146
82
 
147
83
  """
148
84
  return {
149
- "exception": self.exception,
85
+ "exception": repr(self.exception),
150
86
  "time": self.time,
151
87
  "traceback": self.traceback,
152
- # "failed_question": self.failed_question.to_dict(),
153
- "invigilator": self.invigilator.to_dict(),
154
88
  }
155
89
 
156
90
  def push(self):