edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +3 -9
- edsl/__init__.py +3 -8
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -40
- edsl/agents/AgentList.py +0 -43
- edsl/agents/Invigilator.py +219 -135
- edsl/agents/InvigilatorBase.py +59 -148
- edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
- edsl/agents/__init__.py +0 -1
- edsl/config.py +56 -47
- edsl/coop/coop.py +7 -50
- edsl/data/Cache.py +1 -35
- edsl/data_transfer_models.py +38 -73
- edsl/enums.py +0 -4
- edsl/exceptions/language_models.py +1 -25
- edsl/exceptions/questions.py +5 -62
- edsl/exceptions/results.py +0 -4
- edsl/inference_services/AnthropicService.py +11 -13
- edsl/inference_services/AwsBedrock.py +17 -19
- edsl/inference_services/AzureAI.py +20 -37
- edsl/inference_services/GoogleService.py +12 -16
- edsl/inference_services/GroqService.py +0 -2
- edsl/inference_services/InferenceServiceABC.py +3 -58
- edsl/inference_services/OpenAIService.py +54 -48
- edsl/inference_services/models_available_cache.py +6 -0
- edsl/inference_services/registry.py +0 -6
- edsl/jobs/Answers.py +12 -10
- edsl/jobs/Jobs.py +21 -36
- edsl/jobs/buckets/BucketCollection.py +15 -24
- edsl/jobs/buckets/TokenBucket.py +14 -93
- edsl/jobs/interviews/Interview.py +78 -366
- edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
- edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
- edsl/jobs/interviews/retry_management.py +37 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
- edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
- edsl/jobs/tasks/TaskHistory.py +213 -148
- edsl/language_models/LanguageModel.py +156 -261
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
- edsl/language_models/registry.py +6 -23
- edsl/language_models/repair.py +19 -0
- edsl/prompts/Prompt.py +2 -52
- edsl/questions/AnswerValidatorMixin.py +26 -23
- edsl/questions/QuestionBase.py +249 -329
- edsl/questions/QuestionBudget.py +41 -99
- edsl/questions/QuestionCheckBox.py +35 -227
- edsl/questions/QuestionExtract.py +27 -98
- edsl/questions/QuestionFreeText.py +29 -52
- edsl/questions/QuestionFunctional.py +0 -7
- edsl/questions/QuestionList.py +22 -141
- edsl/questions/QuestionMultipleChoice.py +65 -159
- edsl/questions/QuestionNumerical.py +46 -88
- edsl/questions/QuestionRank.py +24 -182
- edsl/questions/RegisterQuestionsMeta.py +12 -31
- edsl/questions/__init__.py +4 -3
- edsl/questions/derived/QuestionLikertFive.py +5 -10
- edsl/questions/derived/QuestionLinearScale.py +2 -15
- edsl/questions/derived/QuestionTopK.py +1 -10
- edsl/questions/derived/QuestionYesNo.py +3 -24
- edsl/questions/descriptors.py +7 -43
- edsl/questions/question_registry.py +2 -6
- edsl/results/Dataset.py +0 -20
- edsl/results/DatasetExportMixin.py +48 -46
- edsl/results/Result.py +5 -32
- edsl/results/Results.py +46 -135
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/scenarios/FileStore.py +10 -71
- edsl/scenarios/Scenario.py +25 -96
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +39 -361
- edsl/scenarios/ScenarioListExportMixin.py +0 -9
- edsl/scenarios/ScenarioListPdfMixin.py +4 -150
- edsl/study/SnapShot.py +1 -8
- edsl/study/Study.py +0 -32
- edsl/surveys/Rule.py +1 -10
- edsl/surveys/RuleCollection.py +5 -21
- edsl/surveys/Survey.py +310 -636
- edsl/surveys/SurveyExportMixin.py +9 -71
- edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
- edsl/surveys/SurveyQualtricsImport.py +4 -75
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/utilities.py +1 -9
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
- edsl-0.1.33.dev1.dist-info/RECORD +209 -0
- edsl/TemplateLoader.py +0 -24
- edsl/auto/AutoStudy.py +0 -117
- edsl/auto/StageBase.py +0 -230
- edsl/auto/StageGenerateSurvey.py +0 -178
- edsl/auto/StageLabelQuestions.py +0 -125
- edsl/auto/StagePersona.py +0 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
- edsl/auto/StagePersonaDimensionValues.py +0 -74
- edsl/auto/StagePersonaDimensions.py +0 -69
- edsl/auto/StageQuestions.py +0 -73
- edsl/auto/SurveyCreatorPipeline.py +0 -21
- edsl/auto/utilities.py +0 -224
- edsl/coop/PriceFetcher.py +0 -58
- edsl/inference_services/MistralAIService.py +0 -120
- edsl/inference_services/TestService.py +0 -80
- edsl/inference_services/TogetherAIService.py +0 -170
- edsl/jobs/FailedQuestion.py +0 -78
- edsl/jobs/runners/JobsRunnerStatus.py +0 -331
- edsl/language_models/fake_openai_call.py +0 -15
- edsl/language_models/fake_openai_service.py +0 -61
- edsl/language_models/utilities.py +0 -61
- edsl/questions/QuestionBaseGenMixin.py +0 -133
- edsl/questions/QuestionBasePromptsMixin.py +0 -266
- edsl/questions/Quick.py +0 -41
- edsl/questions/ResponseValidatorABC.py +0 -170
- edsl/questions/decorators.py +0 -21
- edsl/questions/prompt_templates/question_budget.jinja +0 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
- edsl/questions/prompt_templates/question_extract.jinja +0 -11
- edsl/questions/prompt_templates/question_free_text.jinja +0 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
- edsl/questions/prompt_templates/question_list.jinja +0 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
- edsl/questions/prompt_templates/question_numerical.jinja +0 -37
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +0 -7
- edsl/questions/templates/budget/question_presentation.jinja +0 -7
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +0 -7
- edsl/questions/templates/extract/question_presentation.jinja +0 -1
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +0 -1
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +0 -4
- edsl/questions/templates/list/question_presentation.jinja +0 -5
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
- edsl/questions/templates/numerical/question_presentation.jinja +0 -7
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +0 -11
- edsl/questions/templates/rank/question_presentation.jinja +0 -15
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
- edsl/questions/templates/top_k/question_presentation.jinja +0 -22
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
- edsl/results/DatasetTree.py +0 -145
- edsl/results/Selector.py +0 -118
- edsl/results/tree_explore.py +0 -115
- edsl/surveys/instructions/ChangeInstruction.py +0 -47
- edsl/surveys/instructions/Instruction.py +0 -34
- edsl/surveys/instructions/InstructionCollection.py +0 -77
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +0 -24
- edsl/templates/error_reporting/exceptions_by_model.html +0 -35
- edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
- edsl/templates/error_reporting/exceptions_by_type.html +0 -17
- edsl/templates/error_reporting/interview_details.html +0 -116
- edsl/templates/error_reporting/interviews.html +0 -10
- edsl/templates/error_reporting/overview.html +0 -5
- edsl/templates/error_reporting/performance_plot.html +0 -2
- edsl/templates/error_reporting/report.css +0 -74
- edsl/templates/error_reporting/report.html +0 -118
- edsl/templates/error_reporting/report.js +0 -25
- edsl-0.1.33.dist-info/RECORD +0 -295
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
@@ -1,76 +1,50 @@
|
|
1
1
|
"""This module contains the Interview class, which is responsible for conducting an interview asynchronously."""
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
|
+
import traceback
|
4
5
|
import asyncio
|
5
|
-
|
6
|
-
|
7
|
-
from tenacity import (
|
8
|
-
retry,
|
9
|
-
stop_after_attempt,
|
10
|
-
wait_exponential,
|
11
|
-
retry_if_exception_type,
|
12
|
-
RetryError,
|
13
|
-
)
|
6
|
+
import time
|
7
|
+
from typing import Any, Type, List, Generator, Optional
|
14
8
|
|
15
|
-
from edsl import
|
9
|
+
from edsl.jobs.Answers import Answers
|
16
10
|
from edsl.surveys.base import EndOfSurvey
|
17
|
-
from edsl.exceptions import QuestionAnswerValidationError
|
18
|
-
from edsl.exceptions import QuestionAnswerValidationError
|
19
|
-
from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
20
|
-
|
21
11
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
22
|
-
from edsl.jobs.Answers import Answers
|
23
|
-
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
24
12
|
from edsl.jobs.tasks.TaskCreators import TaskCreators
|
13
|
+
|
25
14
|
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
26
|
-
from edsl.jobs.interviews.
|
15
|
+
from edsl.jobs.interviews.interview_exception_tracking import (
|
27
16
|
InterviewExceptionCollection,
|
28
17
|
)
|
29
|
-
|
30
|
-
from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
31
|
-
|
32
|
-
from edsl.surveys.base import EndOfSurvey
|
33
|
-
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
34
18
|
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
35
|
-
from edsl.jobs.
|
36
|
-
from edsl.jobs.
|
37
|
-
|
38
|
-
|
39
|
-
from edsl import Agent, Survey, Scenario, Cache
|
40
|
-
from edsl.language_models import LanguageModel
|
41
|
-
from edsl.questions import QuestionBase
|
42
|
-
from edsl.agents.InvigilatorBase import InvigilatorBase
|
43
|
-
|
44
|
-
from edsl.exceptions.language_models import LanguageModelNoResponseError
|
19
|
+
from edsl.jobs.interviews.retry_management import retry_strategy
|
20
|
+
from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
|
21
|
+
from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
45
22
|
|
23
|
+
import asyncio
|
46
24
|
|
47
|
-
from edsl import CONFIG
|
48
25
|
|
49
|
-
|
50
|
-
|
51
|
-
EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
|
26
|
+
def run_async(coro):
|
27
|
+
return asyncio.run(coro)
|
52
28
|
|
53
29
|
|
54
|
-
class Interview(InterviewStatusMixin):
|
30
|
+
class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
55
31
|
"""
|
56
32
|
An 'interview' is one agent answering one survey, with one language model, for a given scenario.
|
57
33
|
|
58
34
|
The main method is `async_conduct_interview`, which conducts the interview asynchronously.
|
59
|
-
Most of the class is dedicated to creating the tasks for each question in the survey, and then running them.
|
60
35
|
"""
|
61
36
|
|
62
37
|
def __init__(
|
63
38
|
self,
|
64
|
-
agent: Agent,
|
65
|
-
survey: Survey,
|
66
|
-
scenario: Scenario,
|
39
|
+
agent: "Agent",
|
40
|
+
survey: "Survey",
|
41
|
+
scenario: "Scenario",
|
67
42
|
model: Type["LanguageModel"],
|
68
43
|
debug: Optional[bool] = False,
|
69
44
|
iteration: int = 0,
|
70
45
|
cache: Optional["Cache"] = None,
|
71
46
|
sidecar_model: Optional["LanguageModel"] = None,
|
72
|
-
skip_retry
|
73
|
-
raise_validation_errors: bool = True,
|
47
|
+
skip_retry=False,
|
74
48
|
):
|
75
49
|
"""Initialize the Interview instance.
|
76
50
|
|
@@ -110,15 +84,11 @@ class Interview(InterviewStatusMixin):
|
|
110
84
|
] = Answers() # will get filled in as interview progresses
|
111
85
|
self.sidecar_model = sidecar_model
|
112
86
|
|
113
|
-
# self.stop_on_exception = False
|
114
|
-
|
115
87
|
# Trackers
|
116
88
|
self.task_creators = TaskCreators() # tracks the task creators
|
117
89
|
self.exceptions = InterviewExceptionCollection()
|
118
|
-
|
119
90
|
self._task_status_log_dict = InterviewStatusLog()
|
120
91
|
self.skip_retry = skip_retry
|
121
|
-
self.raise_validation_errors = raise_validation_errors
|
122
92
|
|
123
93
|
# dictionary mapping question names to their index in the survey.
|
124
94
|
self.to_index = {
|
@@ -126,9 +96,6 @@ class Interview(InterviewStatusMixin):
|
|
126
96
|
for index, question_name in enumerate(self.survey.question_names)
|
127
97
|
}
|
128
98
|
|
129
|
-
self.failed_questions = []
|
130
|
-
|
131
|
-
# region: Serialization
|
132
99
|
def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
|
133
100
|
"""Return a dictionary representation of the Interview instance.
|
134
101
|
This is just for hashing purposes.
|
@@ -153,301 +120,13 @@ class Interview(InterviewStatusMixin):
|
|
153
120
|
|
154
121
|
return dict_hash(self._to_dict())
|
155
122
|
|
156
|
-
# endregion
|
157
|
-
|
158
|
-
# region: Creating tasks
|
159
|
-
@property
|
160
|
-
def dag(self) -> "DAG":
|
161
|
-
"""Return the directed acyclic graph for the survey.
|
162
|
-
|
163
|
-
The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
|
164
|
-
It is used to determine the order in which questions should be answered.
|
165
|
-
This reflects both agent 'memory' considerations and 'skip' logic.
|
166
|
-
The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
|
167
|
-
|
168
|
-
>>> i = Interview.example()
|
169
|
-
>>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
|
170
|
-
True
|
171
|
-
"""
|
172
|
-
return self.survey.dag(textify=True)
|
173
|
-
|
174
|
-
def _build_question_tasks(
|
175
|
-
self,
|
176
|
-
model_buckets: ModelBuckets,
|
177
|
-
) -> list[asyncio.Task]:
|
178
|
-
"""Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
|
179
|
-
|
180
|
-
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
181
|
-
:param model_buckets: the model buckets used to track and control usage rates.
|
182
|
-
"""
|
183
|
-
tasks = []
|
184
|
-
for question in self.survey.questions:
|
185
|
-
tasks_that_must_be_completed_before = list(
|
186
|
-
self._get_tasks_that_must_be_completed_before(
|
187
|
-
tasks=tasks, question=question
|
188
|
-
)
|
189
|
-
)
|
190
|
-
question_task = self._create_question_task(
|
191
|
-
question=question,
|
192
|
-
tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
|
193
|
-
model_buckets=model_buckets,
|
194
|
-
iteration=self.iteration,
|
195
|
-
)
|
196
|
-
tasks.append(question_task)
|
197
|
-
return tuple(tasks)
|
198
|
-
|
199
|
-
def _get_tasks_that_must_be_completed_before(
|
200
|
-
self, *, tasks: list[asyncio.Task], question: "QuestionBase"
|
201
|
-
) -> Generator[asyncio.Task, None, None]:
|
202
|
-
"""Return the tasks that must be completed before the given question can be answered.
|
203
|
-
|
204
|
-
:param tasks: a list of tasks that have been created so far.
|
205
|
-
:param question: the question for which we are determining dependencies.
|
206
|
-
|
207
|
-
If a question has no dependencies, this will be an empty list, [].
|
208
|
-
"""
|
209
|
-
parents_of_focal_question = self.dag.get(question.question_name, [])
|
210
|
-
for parent_question_name in parents_of_focal_question:
|
211
|
-
yield tasks[self.to_index[parent_question_name]]
|
212
|
-
|
213
|
-
def _create_question_task(
|
214
|
-
self,
|
215
|
-
*,
|
216
|
-
question: QuestionBase,
|
217
|
-
tasks_that_must_be_completed_before: list[asyncio.Task],
|
218
|
-
model_buckets: ModelBuckets,
|
219
|
-
iteration: int = 0,
|
220
|
-
) -> asyncio.Task:
|
221
|
-
"""Create a task that depends on the passed-in dependencies that are awaited before the task is run.
|
222
|
-
|
223
|
-
:param question: the question to be answered. This is the question we are creating a task for.
|
224
|
-
:param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
|
225
|
-
:param model_buckets: the model buckets used to track and control usage rates.
|
226
|
-
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
227
|
-
:param iteration: the iteration number for the interview.
|
228
|
-
|
229
|
-
The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
|
230
|
-
It is passed a reference to the function that will be called to answer the question.
|
231
|
-
It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
|
232
|
-
These are added as a dependency to the focal task.
|
233
|
-
"""
|
234
|
-
task_creator = QuestionTaskCreator(
|
235
|
-
question=question,
|
236
|
-
answer_question_func=self._answer_question_and_record_task,
|
237
|
-
token_estimator=self._get_estimated_request_tokens,
|
238
|
-
model_buckets=model_buckets,
|
239
|
-
iteration=iteration,
|
240
|
-
)
|
241
|
-
for task in tasks_that_must_be_completed_before:
|
242
|
-
task_creator.add_dependency(task)
|
243
|
-
|
244
|
-
self.task_creators.update(
|
245
|
-
{question.question_name: task_creator}
|
246
|
-
) # track this task creator
|
247
|
-
return task_creator.generate_task()
|
248
|
-
|
249
|
-
def _get_estimated_request_tokens(self, question) -> float:
|
250
|
-
"""Estimate the number of tokens that will be required to run the focal task."""
|
251
|
-
invigilator = self._get_invigilator(question=question)
|
252
|
-
# TODO: There should be a way to get a more accurate estimate.
|
253
|
-
combined_text = ""
|
254
|
-
for prompt in invigilator.get_prompts().values():
|
255
|
-
if hasattr(prompt, "text"):
|
256
|
-
combined_text += prompt.text
|
257
|
-
elif isinstance(prompt, str):
|
258
|
-
combined_text += prompt
|
259
|
-
else:
|
260
|
-
raise ValueError(f"Prompt is of type {type(prompt)}")
|
261
|
-
return len(combined_text) / 4.0
|
262
|
-
|
263
|
-
async def _answer_question_and_record_task(
|
264
|
-
self,
|
265
|
-
*,
|
266
|
-
question: "QuestionBase",
|
267
|
-
task=None,
|
268
|
-
) -> "AgentResponseDict":
|
269
|
-
"""Answer a question and records the task."""
|
270
|
-
|
271
|
-
had_language_model_no_response_error = False
|
272
|
-
|
273
|
-
@retry(
|
274
|
-
stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
|
275
|
-
wait=wait_exponential(
|
276
|
-
multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
|
277
|
-
),
|
278
|
-
retry=retry_if_exception_type(LanguageModelNoResponseError),
|
279
|
-
reraise=True,
|
280
|
-
)
|
281
|
-
async def attempt_answer():
|
282
|
-
nonlocal had_language_model_no_response_error
|
283
|
-
|
284
|
-
invigilator = self._get_invigilator(question)
|
285
|
-
|
286
|
-
if self._skip_this_question(question):
|
287
|
-
return invigilator.get_failed_task_result(
|
288
|
-
failure_reason="Question skipped."
|
289
|
-
)
|
290
|
-
|
291
|
-
try:
|
292
|
-
response: EDSLResultObjectInput = (
|
293
|
-
await invigilator.async_answer_question()
|
294
|
-
)
|
295
|
-
if response.validated:
|
296
|
-
self.answers.add_answer(response=response, question=question)
|
297
|
-
self._cancel_skipped_questions(question)
|
298
|
-
else:
|
299
|
-
if (
|
300
|
-
hasattr(response, "exception_occurred")
|
301
|
-
and response.exception_occurred
|
302
|
-
):
|
303
|
-
raise response.exception_occurred
|
304
|
-
|
305
|
-
except QuestionAnswerValidationError as e:
|
306
|
-
self._handle_exception(e, invigilator, task)
|
307
|
-
return invigilator.get_failed_task_result(
|
308
|
-
failure_reason="Question answer validation failed."
|
309
|
-
)
|
310
|
-
|
311
|
-
except asyncio.TimeoutError as e:
|
312
|
-
self._handle_exception(e, invigilator, task)
|
313
|
-
had_language_model_no_response_error = True
|
314
|
-
raise LanguageModelNoResponseError(
|
315
|
-
f"Language model timed out for question '{question.question_name}.'"
|
316
|
-
)
|
317
|
-
|
318
|
-
except Exception as e:
|
319
|
-
self._handle_exception(e, invigilator, task)
|
320
|
-
|
321
|
-
if "response" not in locals():
|
322
|
-
had_language_model_no_response_error = True
|
323
|
-
raise LanguageModelNoResponseError(
|
324
|
-
f"Language model did not return a response for question '{question.question_name}.'"
|
325
|
-
)
|
326
|
-
|
327
|
-
# if it gets here, it means the no response error was fixed
|
328
|
-
if (
|
329
|
-
question.question_name in self.exceptions
|
330
|
-
and had_language_model_no_response_error
|
331
|
-
):
|
332
|
-
self.exceptions.record_fixed_question(question.question_name)
|
333
|
-
|
334
|
-
return response
|
335
|
-
|
336
|
-
try:
|
337
|
-
return await attempt_answer()
|
338
|
-
except RetryError as retry_error:
|
339
|
-
# All retries have failed for LanguageModelNoResponseError
|
340
|
-
original_error = retry_error.last_attempt.exception()
|
341
|
-
self._handle_exception(
|
342
|
-
original_error, self._get_invigilator(question), task
|
343
|
-
)
|
344
|
-
raise original_error # Re-raise the original error after handling
|
345
|
-
|
346
|
-
def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
|
347
|
-
"""Return an invigilator for the given question.
|
348
|
-
|
349
|
-
:param question: the question to be answered
|
350
|
-
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
351
|
-
"""
|
352
|
-
invigilator = self.agent.create_invigilator(
|
353
|
-
question=question,
|
354
|
-
scenario=self.scenario,
|
355
|
-
model=self.model,
|
356
|
-
debug=False,
|
357
|
-
survey=self.survey,
|
358
|
-
memory_plan=self.survey.memory_plan,
|
359
|
-
current_answers=self.answers,
|
360
|
-
iteration=self.iteration,
|
361
|
-
cache=self.cache,
|
362
|
-
sidecar_model=self.sidecar_model,
|
363
|
-
raise_validation_errors=self.raise_validation_errors,
|
364
|
-
)
|
365
|
-
"""Return an invigilator for the given question."""
|
366
|
-
return invigilator
|
367
|
-
|
368
|
-
def _skip_this_question(self, current_question: "QuestionBase") -> bool:
|
369
|
-
"""Determine if the current question should be skipped.
|
370
|
-
|
371
|
-
:param current_question: the question to be answered.
|
372
|
-
"""
|
373
|
-
current_question_index = self.to_index[current_question.question_name]
|
374
|
-
|
375
|
-
answers = self.answers | self.scenario | self.agent["traits"]
|
376
|
-
skip = self.survey.rule_collection.skip_question_before_running(
|
377
|
-
current_question_index, answers
|
378
|
-
)
|
379
|
-
return skip
|
380
|
-
|
381
|
-
def _handle_exception(
|
382
|
-
self, e: Exception, invigilator: "InvigilatorBase", task=None
|
383
|
-
):
|
384
|
-
import copy
|
385
|
-
|
386
|
-
# breakpoint()
|
387
|
-
|
388
|
-
answers = copy.copy(self.answers)
|
389
|
-
exception_entry = InterviewExceptionEntry(
|
390
|
-
exception=e,
|
391
|
-
invigilator=invigilator,
|
392
|
-
answers=answers,
|
393
|
-
)
|
394
|
-
if task:
|
395
|
-
task.task_status = TaskStatus.FAILED
|
396
|
-
self.exceptions.add(invigilator.question.question_name, exception_entry)
|
397
|
-
|
398
|
-
if self.raise_validation_errors:
|
399
|
-
if isinstance(e, QuestionAnswerValidationError):
|
400
|
-
raise e
|
401
|
-
|
402
|
-
if hasattr(self, "stop_on_exception"):
|
403
|
-
stop_on_exception = self.stop_on_exception
|
404
|
-
else:
|
405
|
-
stop_on_exception = False
|
406
|
-
|
407
|
-
if stop_on_exception:
|
408
|
-
raise e
|
409
|
-
|
410
|
-
def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
|
411
|
-
"""Cancel the tasks for questions that are skipped.
|
412
|
-
|
413
|
-
:param current_question: the question that was just answered.
|
414
|
-
|
415
|
-
It first determines the next question, given the current question and the current answers.
|
416
|
-
If the next question is the end of the survey, it cancels all remaining tasks.
|
417
|
-
If the next question is after the current question, it cancels all tasks between the current question and the next question.
|
418
|
-
"""
|
419
|
-
current_question_index: int = self.to_index[current_question.question_name]
|
420
|
-
|
421
|
-
next_question: Union[
|
422
|
-
int, EndOfSurvey
|
423
|
-
] = self.survey.rule_collection.next_question(
|
424
|
-
q_now=current_question_index,
|
425
|
-
answers=self.answers | self.scenario | self.agent["traits"],
|
426
|
-
)
|
427
|
-
|
428
|
-
next_question_index = next_question.next_q
|
429
|
-
|
430
|
-
def cancel_between(start, end):
|
431
|
-
"""Cancel the tasks between the start and end indices."""
|
432
|
-
for i in range(start, end):
|
433
|
-
self.tasks[i].cancel()
|
434
|
-
|
435
|
-
if next_question_index == EndOfSurvey:
|
436
|
-
cancel_between(current_question_index + 1, len(self.survey.questions))
|
437
|
-
return
|
438
|
-
|
439
|
-
if next_question_index > (current_question_index + 1):
|
440
|
-
cancel_between(current_question_index + 1, next_question_index)
|
441
|
-
|
442
|
-
# endregion
|
443
|
-
|
444
|
-
# region: Conducting the interview
|
445
123
|
async def async_conduct_interview(
|
446
124
|
self,
|
447
|
-
|
125
|
+
*,
|
126
|
+
model_buckets: ModelBuckets = None,
|
127
|
+
debug: bool = False,
|
448
128
|
stop_on_exception: bool = False,
|
449
129
|
sidecar_model: Optional["LanguageModel"] = None,
|
450
|
-
raise_validation_errors: bool = True,
|
451
130
|
) -> tuple["Answers", List[dict[str, Any]]]:
|
452
131
|
"""
|
453
132
|
Conduct an Interview asynchronously.
|
@@ -467,6 +146,19 @@ class Interview(InterviewStatusMixin):
|
|
467
146
|
|
468
147
|
>>> i = Interview.example(throw_exception = True)
|
469
148
|
>>> result, _ = asyncio.run(i.async_conduct_interview())
|
149
|
+
Attempt 1 failed with exception:This is a test error now waiting 1.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
|
150
|
+
<BLANKLINE>
|
151
|
+
<BLANKLINE>
|
152
|
+
Attempt 2 failed with exception:This is a test error now waiting 2.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
|
153
|
+
<BLANKLINE>
|
154
|
+
<BLANKLINE>
|
155
|
+
Attempt 3 failed with exception:This is a test error now waiting 4.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
|
156
|
+
<BLANKLINE>
|
157
|
+
<BLANKLINE>
|
158
|
+
Attempt 4 failed with exception:This is a test error now waiting 8.00 seconds before retrying.Parameters: start=1.0, max=60.0, max_attempts=5.
|
159
|
+
<BLANKLINE>
|
160
|
+
<BLANKLINE>
|
161
|
+
|
470
162
|
>>> i.exceptions
|
471
163
|
{'q0': ...
|
472
164
|
>>> i = Interview.example()
|
@@ -476,30 +168,26 @@ class Interview(InterviewStatusMixin):
|
|
476
168
|
asyncio.exceptions.CancelledError
|
477
169
|
"""
|
478
170
|
self.sidecar_model = sidecar_model
|
479
|
-
self.stop_on_exception = stop_on_exception
|
480
171
|
|
481
172
|
# if no model bucket is passed, create an 'infinity' bucket with no rate limits
|
482
173
|
if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
|
483
174
|
model_buckets = ModelBuckets.infinity_bucket()
|
484
175
|
|
176
|
+
## build the tasks using the InterviewTaskBuildingMixin
|
485
177
|
## This is the key part---it creates a task for each question,
|
486
178
|
## with dependencies on the questions that must be answered before this one can be answered.
|
487
|
-
self.tasks = self._build_question_tasks(
|
179
|
+
self.tasks = self._build_question_tasks(
|
180
|
+
debug=debug, model_buckets=model_buckets
|
181
|
+
)
|
488
182
|
|
489
183
|
## 'Invigilators' are used to administer the survey
|
490
|
-
self.invigilators =
|
491
|
-
|
492
|
-
|
493
|
-
await asyncio.gather(
|
494
|
-
*self.tasks, return_exceptions=not stop_on_exception
|
495
|
-
) # not stop_on_exception)
|
184
|
+
self.invigilators = list(self._build_invigilators(debug=debug))
|
185
|
+
# await the tasks being conducted
|
186
|
+
await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
|
496
187
|
self.answers.replace_missing_answers_with_none(self.survey)
|
497
188
|
valid_results = list(self._extract_valid_results())
|
498
189
|
return self.answers, valid_results
|
499
190
|
|
500
|
-
# endregion
|
501
|
-
|
502
|
-
# region: Extracting results and recording errors
|
503
191
|
def _extract_valid_results(self) -> Generator["Answers", None, None]:
|
504
192
|
"""Extract the valid results from the list of results.
|
505
193
|
|
@@ -512,6 +200,8 @@ class Interview(InterviewStatusMixin):
|
|
512
200
|
>>> results = list(i._extract_valid_results())
|
513
201
|
>>> len(results) == len(i.survey)
|
514
202
|
True
|
203
|
+
>>> type(results[0])
|
204
|
+
<class 'edsl.data_transfer_models.AgentResponseDict'>
|
515
205
|
"""
|
516
206
|
assert len(self.tasks) == len(self.invigilators)
|
517
207
|
|
@@ -522,24 +212,46 @@ class Interview(InterviewStatusMixin):
|
|
522
212
|
try:
|
523
213
|
result = task.result()
|
524
214
|
except asyncio.CancelledError as e: # task was cancelled
|
525
|
-
result = invigilator.get_failed_task_result(
|
526
|
-
failure_reason="Task was cancelled."
|
527
|
-
)
|
215
|
+
result = invigilator.get_failed_task_result()
|
528
216
|
except Exception as e: # any other kind of exception in the task
|
529
|
-
result = invigilator.get_failed_task_result(
|
530
|
-
|
531
|
-
)
|
532
|
-
exception_entry = InterviewExceptionEntry(
|
533
|
-
exception=e,
|
534
|
-
invigilator=invigilator,
|
535
|
-
)
|
536
|
-
self.exceptions.add(task.get_name(), exception_entry)
|
537
|
-
|
217
|
+
result = invigilator.get_failed_task_result()
|
218
|
+
self._record_exception(task, e)
|
538
219
|
yield result
|
539
220
|
|
540
|
-
|
221
|
+
def _record_exception(self, task, exception: Exception) -> None:
|
222
|
+
"""Record an exception in the Interview instance.
|
223
|
+
|
224
|
+
It records the exception in the Interview instance, with the task name and the exception entry.
|
225
|
+
|
226
|
+
>>> i = Interview.example()
|
227
|
+
>>> result, _ = asyncio.run(i.async_conduct_interview())
|
228
|
+
>>> i.exceptions
|
229
|
+
{}
|
230
|
+
>>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
|
231
|
+
>>> i.exceptions
|
232
|
+
{'q0': ...
|
233
|
+
"""
|
234
|
+
exception_entry = InterviewExceptionEntry(exception)
|
235
|
+
self.exceptions.add(task.get_name(), exception_entry)
|
236
|
+
|
237
|
+
@property
|
238
|
+
def dag(self) -> "DAG":
|
239
|
+
"""Return the directed acyclic graph for the survey.
|
240
|
+
|
241
|
+
The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
|
242
|
+
It is used to determine the order in which questions should be answered.
|
243
|
+
This reflects both agent 'memory' considerations and 'skip' logic.
|
244
|
+
The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
|
245
|
+
|
246
|
+
>>> i = Interview.example()
|
247
|
+
>>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
|
248
|
+
True
|
249
|
+
"""
|
250
|
+
return self.survey.dag(textify=True)
|
541
251
|
|
542
|
-
|
252
|
+
#######################
|
253
|
+
# Dunder methods
|
254
|
+
#######################
|
543
255
|
def __repr__(self) -> str:
|
544
256
|
"""Return a string representation of the Interview instance."""
|
545
257
|
return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
|
@@ -2,62 +2,24 @@ import traceback
|
|
2
2
|
import datetime
|
3
3
|
import time
|
4
4
|
from collections import UserDict
|
5
|
-
|
5
|
+
|
6
|
+
# traceback=traceback.format_exc(),
|
7
|
+
# traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
|
8
|
+
# traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
|
6
9
|
|
7
10
|
|
8
11
|
class InterviewExceptionEntry:
|
9
|
-
"""Class to record an exception that occurred during the interview.
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
traceback_format="text",
|
18
|
-
answers=None,
|
19
|
-
):
|
12
|
+
"""Class to record an exception that occurred during the interview.
|
13
|
+
|
14
|
+
>>> entry = InterviewExceptionEntry.example()
|
15
|
+
>>> entry.to_dict()['exception']
|
16
|
+
"ValueError('An error occurred.')"
|
17
|
+
"""
|
18
|
+
|
19
|
+
def __init__(self, exception: Exception, traceback_format="html"):
|
20
20
|
self.time = datetime.datetime.now().isoformat()
|
21
21
|
self.exception = exception
|
22
|
-
# self.failed_question = failed_question
|
23
|
-
self.invigilator = invigilator
|
24
22
|
self.traceback_format = traceback_format
|
25
|
-
self.answers = answers
|
26
|
-
|
27
|
-
@property
|
28
|
-
def question_type(self):
|
29
|
-
# return self.failed_question.question.question_type
|
30
|
-
return self.invigilator.question.question_type
|
31
|
-
|
32
|
-
@property
|
33
|
-
def name(self):
|
34
|
-
return repr(self.exception)
|
35
|
-
|
36
|
-
@property
|
37
|
-
def rendered_prompts(self):
|
38
|
-
return self.invigilator.get_prompts()
|
39
|
-
|
40
|
-
@property
|
41
|
-
def key_sequence(self):
|
42
|
-
return self.invigilator.model.key_sequence
|
43
|
-
|
44
|
-
@property
|
45
|
-
def generated_token_string(self):
|
46
|
-
# return "POO"
|
47
|
-
if self.invigilator.raw_model_response is None:
|
48
|
-
return "No raw model response available."
|
49
|
-
else:
|
50
|
-
return self.invigilator.model.get_generated_token_string(
|
51
|
-
self.invigilator.raw_model_response
|
52
|
-
)
|
53
|
-
|
54
|
-
@property
|
55
|
-
def raw_model_response(self):
|
56
|
-
import json
|
57
|
-
|
58
|
-
if self.invigilator.raw_model_response is None:
|
59
|
-
return "No raw model response available."
|
60
|
-
return json.dumps(self.invigilator.raw_model_response, indent=2)
|
61
23
|
|
62
24
|
def __getitem__(self, key):
|
63
25
|
# Support dict-like access obj['a']
|
@@ -65,37 +27,11 @@ class InterviewExceptionEntry:
|
|
65
27
|
|
66
28
|
@classmethod
|
67
29
|
def example(cls):
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
results = q.by(m).run(
|
74
|
-
skip_retry=True, print_exceptions=False, raise_validation_errors=True
|
75
|
-
)
|
76
|
-
return results.task_history.exceptions[0]["how_are_you"][0]
|
77
|
-
|
78
|
-
@property
|
79
|
-
def code_to_reproduce(self):
|
80
|
-
return self.code(run=False)
|
81
|
-
|
82
|
-
def code(self, run=True):
|
83
|
-
lines = []
|
84
|
-
lines.append("from edsl import Question, Model, Scenario, Agent")
|
85
|
-
|
86
|
-
lines.append(f"q = {repr(self.invigilator.question)}")
|
87
|
-
lines.append(f"scenario = {repr(self.invigilator.scenario)}")
|
88
|
-
lines.append(f"agent = {repr(self.invigilator.agent)}")
|
89
|
-
lines.append(f"m = Model('{self.invigilator.model.model}')")
|
90
|
-
lines.append("results = q.by(m).by(agent).by(scenario).run()")
|
91
|
-
code_str = "\n".join(lines)
|
92
|
-
|
93
|
-
if run:
|
94
|
-
# Create a new namespace to avoid polluting the global namespace
|
95
|
-
namespace = {}
|
96
|
-
exec(code_str, namespace)
|
97
|
-
return namespace["results"]
|
98
|
-
return code_str
|
30
|
+
try:
|
31
|
+
raise ValueError("An error occurred.")
|
32
|
+
except Exception as e:
|
33
|
+
entry = InterviewExceptionEntry(e)
|
34
|
+
return entry
|
99
35
|
|
100
36
|
@property
|
101
37
|
def traceback(self):
|
@@ -142,15 +78,13 @@ class InterviewExceptionEntry:
|
|
142
78
|
|
143
79
|
>>> entry = InterviewExceptionEntry.example()
|
144
80
|
>>> entry.to_dict()['exception']
|
145
|
-
ValueError()
|
81
|
+
"ValueError('An error occurred.')"
|
146
82
|
|
147
83
|
"""
|
148
84
|
return {
|
149
|
-
"exception": self.exception,
|
85
|
+
"exception": repr(self.exception),
|
150
86
|
"time": self.time,
|
151
87
|
"traceback": self.traceback,
|
152
|
-
# "failed_question": self.failed_question.to_dict(),
|
153
|
-
"invigilator": self.invigilator.to_dict(),
|
154
88
|
}
|
155
89
|
|
156
90
|
def push(self):
|