edsl 0.1.38.dev1__py3-none-any.whl → 0.1.38.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +49 -48
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +858 -855
- edsl/agents/AgentList.py +362 -350
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +284 -284
- edsl/agents/PromptConstructor.py +353 -353
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +99 -99
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +279 -289
- edsl/config.py +149 -149
- edsl/conversation/Conversation.py +290 -290
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/chips.py +95 -95
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +961 -958
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +530 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +97 -97
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/BaseException.py +21 -21
- edsl/exceptions/__init__.py +54 -54
- edsl/exceptions/agents.py +42 -38
- edsl/exceptions/cache.py +5 -0
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +29 -29
- edsl/exceptions/scenarios.py +22 -22
- edsl/exceptions/surveys.py +37 -37
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +120 -120
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +97 -97
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1358 -1347
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +251 -248
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +361 -338
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +451 -442
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/KeyLookup.py +30 -30
- edsl/language_models/LanguageModel.py +708 -706
- edsl/language_models/ModelList.py +109 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +3 -3
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +258 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +357 -357
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +660 -656
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +217 -234
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +114 -114
- edsl/questions/QuestionFunctional.py +166 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +93 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +413 -413
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +717 -717
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +456 -450
- edsl/results/Results.py +1071 -1071
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +135 -135
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -458
- edsl/scenarios/Scenario.py +544 -546
- edsl/scenarios/ScenarioHtmlMixin.py +64 -64
- edsl/scenarios/ScenarioList.py +1112 -1112
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +326 -330
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1787 -1795
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +49 -47
- edsl/surveys/instructions/Instruction.py +53 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/{conjure → utilities}/naming_utilities.py +263 -263
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +409 -409
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev3.dist-info}/LICENSE +21 -21
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev3.dist-info}/METADATA +1 -1
- edsl-0.1.38.dev3.dist-info/RECORD +269 -0
- edsl/conjure/AgentConstructionMixin.py +0 -160
- edsl/conjure/Conjure.py +0 -62
- edsl/conjure/InputData.py +0 -659
- edsl/conjure/InputDataCSV.py +0 -48
- edsl/conjure/InputDataMixinQuestionStats.py +0 -182
- edsl/conjure/InputDataPyRead.py +0 -91
- edsl/conjure/InputDataSPSS.py +0 -8
- edsl/conjure/InputDataStata.py +0 -8
- edsl/conjure/QuestionOptionMixin.py +0 -76
- edsl/conjure/QuestionTypeMixin.py +0 -23
- edsl/conjure/RawQuestion.py +0 -65
- edsl/conjure/SurveyResponses.py +0 -7
- edsl/conjure/__init__.py +0 -9
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/utilities.py +0 -201
- edsl-0.1.38.dev1.dist-info/RECORD +0 -283
- {edsl-0.1.38.dev1.dist-info → edsl-0.1.38.dev3.dist-info}/WHEEL +0 -0
@@ -1,661 +1,661 @@
|
|
1
|
-
"""This module contains the Interview class, which is responsible for conducting an interview asynchronously."""
|
2
|
-
|
3
|
-
from __future__ import annotations
|
4
|
-
import asyncio
|
5
|
-
from typing import Any, Type, List, Generator, Optional, Union
|
6
|
-
import copy
|
7
|
-
|
8
|
-
from tenacity import (
|
9
|
-
retry,
|
10
|
-
stop_after_attempt,
|
11
|
-
wait_exponential,
|
12
|
-
retry_if_exception_type,
|
13
|
-
RetryError,
|
14
|
-
)
|
15
|
-
|
16
|
-
from edsl import CONFIG
|
17
|
-
from edsl.surveys.base import EndOfSurvey
|
18
|
-
from edsl.exceptions import QuestionAnswerValidationError
|
19
|
-
from edsl.exceptions import QuestionAnswerValidationError
|
20
|
-
from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
21
|
-
|
22
|
-
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
23
|
-
from edsl.jobs.Answers import Answers
|
24
|
-
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
25
|
-
from edsl.jobs.tasks.TaskCreators import TaskCreators
|
26
|
-
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
27
|
-
from edsl.jobs.interviews.InterviewExceptionCollection import (
|
28
|
-
InterviewExceptionCollection,
|
29
|
-
)
|
30
|
-
|
31
|
-
# from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
32
|
-
|
33
|
-
from edsl.surveys.base import EndOfSurvey
|
34
|
-
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
35
|
-
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
36
|
-
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
37
|
-
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
38
|
-
|
39
|
-
|
40
|
-
from edsl import Agent, Survey, Scenario, Cache
|
41
|
-
from edsl.language_models import LanguageModel
|
42
|
-
from edsl.questions import QuestionBase
|
43
|
-
from edsl.agents.InvigilatorBase import InvigilatorBase
|
44
|
-
|
45
|
-
from edsl.exceptions.language_models import LanguageModelNoResponseError
|
46
|
-
|
47
|
-
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
48
|
-
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
49
|
-
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
50
|
-
|
51
|
-
|
52
|
-
from edsl import CONFIG
|
53
|
-
|
54
|
-
# Retry/backoff settings for language-model calls, read once at import time
# from CONFIG. They parameterize the tenacity ``@retry`` decorator used in
# ``Interview._answer_question_and_record_task``.
# Initial backoff multiplier (seconds) for ``wait_exponential``.
EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
# Upper bound (seconds) on a single backoff wait.
EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
# Maximum number of attempts before giving up.
EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
|
57
|
-
|
58
|
-
|
59
|
-
class Interview:
|
60
|
-
"""
|
61
|
-
An 'interview' is one agent answering one survey, with one language model, for a given scenario.
|
62
|
-
|
63
|
-
The main method is `async_conduct_interview`, which conducts the interview asynchronously.
|
64
|
-
Most of the class is dedicated to creating the tasks for each question in the survey, and then running them.
|
65
|
-
"""
|
66
|
-
|
67
|
-
def __init__(
|
68
|
-
self,
|
69
|
-
agent: Agent,
|
70
|
-
survey: Survey,
|
71
|
-
scenario: Scenario,
|
72
|
-
model: Type["LanguageModel"],
|
73
|
-
debug: Optional[bool] = False,
|
74
|
-
iteration: int = 0,
|
75
|
-
cache: Optional["Cache"] = None,
|
76
|
-
sidecar_model: Optional["LanguageModel"] = None,
|
77
|
-
skip_retry: bool = False,
|
78
|
-
raise_validation_errors: bool = True,
|
79
|
-
):
|
80
|
-
"""Initialize the Interview instance.
|
81
|
-
|
82
|
-
:param agent: the agent being interviewed.
|
83
|
-
:param survey: the survey being administered to the agent.
|
84
|
-
:param scenario: the scenario that populates the survey questions.
|
85
|
-
:param model: the language model used to answer the questions.
|
86
|
-
:param debug: if True, run without calls to the language model.
|
87
|
-
:param iteration: the iteration number of the interview.
|
88
|
-
:param cache: the cache used to store the answers.
|
89
|
-
:param sidecar_model: a sidecar model used to answer questions.
|
90
|
-
|
91
|
-
>>> i = Interview.example()
|
92
|
-
>>> i.task_creators
|
93
|
-
{}
|
94
|
-
|
95
|
-
>>> i.exceptions
|
96
|
-
{}
|
97
|
-
|
98
|
-
>>> _ = asyncio.run(i.async_conduct_interview())
|
99
|
-
>>> i.task_status_logs['q0']
|
100
|
-
[{'log_time': ..., 'value': <TaskStatus.NOT_STARTED: 1>}, {'log_time': ..., 'value': <TaskStatus.WAITING_FOR_DEPENDENCIES: 2>}, {'log_time': ..., 'value': <TaskStatus.API_CALL_IN_PROGRESS: 7>}, {'log_time': ..., 'value': <TaskStatus.SUCCESS: 8>}]
|
101
|
-
|
102
|
-
>>> i.to_index
|
103
|
-
{'q0': 0, 'q1': 1, 'q2': 2}
|
104
|
-
|
105
|
-
"""
|
106
|
-
self.agent = agent
|
107
|
-
self.survey = copy.deepcopy(survey)
|
108
|
-
self.scenario = scenario
|
109
|
-
self.model = model
|
110
|
-
self.debug = debug
|
111
|
-
self.iteration = iteration
|
112
|
-
self.cache = cache
|
113
|
-
self.answers: dict[
|
114
|
-
|
115
|
-
|
116
|
-
self.sidecar_model = sidecar_model
|
117
|
-
|
118
|
-
# Trackers
|
119
|
-
self.task_creators = TaskCreators() # tracks the task creators
|
120
|
-
self.exceptions = InterviewExceptionCollection()
|
121
|
-
|
122
|
-
self._task_status_log_dict = InterviewStatusLog()
|
123
|
-
self.skip_retry = skip_retry
|
124
|
-
self.raise_validation_errors = raise_validation_errors
|
125
|
-
|
126
|
-
# dictionary mapping question names to their index in the survey.
|
127
|
-
self.to_index = {
|
128
|
-
question_name: index
|
129
|
-
for index, question_name in enumerate(self.survey.question_names)
|
130
|
-
}
|
131
|
-
|
132
|
-
self.failed_questions = []
|
133
|
-
|
134
|
-
@property
def has_exceptions(self) -> bool:
    """Report whether at least one exception has been recorded for this interview."""
    recorded = len(self.exceptions)
    return recorded != 0
|
138
|
-
|
139
|
-
@property
|
140
|
-
def task_status_logs(self) -> InterviewStatusLog:
|
141
|
-
"""Return the task status logs for the interview.
|
142
|
-
|
143
|
-
The keys are the question names; the values are the lists of status log changes for each task.
|
144
|
-
"""
|
145
|
-
for task_creator in self.task_creators.values():
|
146
|
-
self._task_status_log_dict[
|
147
|
-
task_creator.
|
148
|
-
|
149
|
-
return self._task_status_log_dict
|
150
|
-
|
151
|
-
@property
def token_usage(self) -> InterviewTokenUsage:
    """Token usage for the whole interview, as aggregated by ``self.task_creators``."""
    creators = self.task_creators
    return creators.token_usage
|
155
|
-
|
156
|
-
@property
def interview_status(self) -> InterviewStatusDictionary:
    """Mapping of task status codes to task counts, as reported by ``self.task_creators``."""
    creators = self.task_creators
    return creators.interview_status
|
160
|
-
|
161
|
-
# region: Serialization
|
162
|
-
def
|
163
|
-
"""Return a dictionary representation of the Interview instance.
|
164
|
-
This is just for hashing purposes.
|
165
|
-
|
166
|
-
>>> i = Interview.example()
|
167
|
-
>>> hash(i)
|
168
|
-
1217840301076717434
|
169
|
-
"""
|
170
|
-
d = {
|
171
|
-
"agent": self.agent.
|
172
|
-
"survey": self.survey.
|
173
|
-
"scenario": self.scenario.
|
174
|
-
"model": self.model.
|
175
|
-
"iteration": self.iteration,
|
176
|
-
"exceptions": {},
|
177
|
-
}
|
178
|
-
if include_exceptions:
|
179
|
-
d["exceptions"] = self.exceptions.to_dict()
|
180
|
-
return d
|
181
|
-
|
182
|
-
@classmethod
def from_dict(cls, d: dict[str, Any]) -> "Interview":
    """Build an Interview from its dictionary form.

    Expects the keys ``agent``, ``survey``, ``scenario``, ``model`` and
    ``iteration``. Each component is rebuilt with its own ``from_dict``.
    When an ``exceptions`` entry is present, it is restored into
    ``interview.exceptions``. All other constructor arguments keep their
    defaults.

    :param d: the serialized interview.
    :returns: a new Interview instance.
    """
    interview = cls(
        agent=Agent.from_dict(d["agent"]),
        survey=Survey.from_dict(d["survey"]),
        scenario=Scenario.from_dict(d["scenario"]),
        model=LanguageModel.from_dict(d["model"]),
        iteration=d["iteration"],
    )
    if "exceptions" not in d:
        return interview
    interview.exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
    return interview
|
201
|
-
|
202
|
-
def __hash__(self) -> int:
|
203
|
-
from edsl.utilities.utilities import dict_hash
|
204
|
-
|
205
|
-
return dict_hash(self.
|
206
|
-
|
207
|
-
def __eq__(self, other: "Interview") -> bool:
|
208
|
-
"""
|
209
|
-
>>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i.
|
210
|
-
True
|
211
|
-
"""
|
212
|
-
return hash(self) == hash(other)
|
213
|
-
|
214
|
-
# endregion
|
215
|
-
|
216
|
-
# region: Creating tasks
|
217
|
-
@property
def dag(self) -> "DAG":
    """Dependency graph of the survey's questions, keyed by question name.

    The graph maps each question name to the names of the questions it
    depends on. It is computed by ``self.survey.dag(textify=True)``;
    ``textify=True`` makes it use question names rather than integer
    indices. The graph reflects both agent 'memory' and 'skip' logic, and it
    determines the order in which questions are answered.

    >>> i = Interview.example()
    >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
    True
    """
    graph = self.survey.dag(textify=True)
    return graph
|
231
|
-
|
232
|
-
def _build_question_tasks(
|
233
|
-
self,
|
234
|
-
model_buckets: ModelBuckets,
|
235
|
-
) -> list[asyncio.Task]:
|
236
|
-
"""Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
|
237
|
-
|
238
|
-
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
239
|
-
:param model_buckets: the model buckets used to track and control usage rates.
|
240
|
-
"""
|
241
|
-
tasks = []
|
242
|
-
for question in self.survey.questions:
|
243
|
-
tasks_that_must_be_completed_before = list(
|
244
|
-
self._get_tasks_that_must_be_completed_before(
|
245
|
-
tasks=tasks, question=question
|
246
|
-
)
|
247
|
-
)
|
248
|
-
question_task = self._create_question_task(
|
249
|
-
question=question,
|
250
|
-
tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
|
251
|
-
model_buckets=model_buckets,
|
252
|
-
iteration=self.iteration,
|
253
|
-
)
|
254
|
-
tasks.append(question_task)
|
255
|
-
return tuple(tasks)
|
256
|
-
|
257
|
-
def _get_tasks_that_must_be_completed_before(
|
258
|
-
self, *, tasks: list[asyncio.Task], question: "QuestionBase"
|
259
|
-
) -> Generator[asyncio.Task, None, None]:
|
260
|
-
"""Return the tasks that must be completed before the given question can be answered.
|
261
|
-
|
262
|
-
:param tasks: a list of tasks that have been created so far.
|
263
|
-
:param question: the question for which we are determining dependencies.
|
264
|
-
|
265
|
-
If a question has no dependencies, this will be an empty list, [].
|
266
|
-
"""
|
267
|
-
parents_of_focal_question = self.dag.get(question.question_name, [])
|
268
|
-
for parent_question_name in parents_of_focal_question:
|
269
|
-
yield tasks[self.to_index[parent_question_name]]
|
270
|
-
|
271
|
-
def _create_question_task(
    self,
    *,
    question: QuestionBase,
    tasks_that_must_be_completed_before: list[asyncio.Task],
    model_buckets: ModelBuckets,
    iteration: int = 0,
) -> asyncio.Task:
    """Create the task that answers ``question`` once its dependencies have finished.

    A ``QuestionTaskCreator`` is built around two callables:
    ``self._answer_question_and_record_task`` (the answering coroutine) and
    ``self._get_estimated_request_tokens`` (the token estimator). Every task
    in ``tasks_that_must_be_completed_before`` is registered on it as a
    dependency. The creator is recorded in ``self.task_creators`` under the
    question's name before its task is generated.

    :param question: the question this task answers.
    :param tasks_that_must_be_completed_before: tasks to await before this one runs.
    :param model_buckets: the model buckets used to track and control usage rates.
    :param iteration: the iteration number of the interview.
    :returns: the task produced by ``QuestionTaskCreator.generate_task()``.
    """
    creator = QuestionTaskCreator(
        question=question,
        answer_question_func=self._answer_question_and_record_task,
        token_estimator=self._get_estimated_request_tokens,
        model_buckets=model_buckets,
        iteration=iteration,
    )
    for dependency in tasks_that_must_be_completed_before:
        creator.add_dependency(dependency)

    # Track the creator; token usage and status reporting read from task_creators.
    self.task_creators.update({question.question_name: creator})
    return creator.generate_task()
|
306
|
-
|
307
|
-
def _get_estimated_request_tokens(self, question) -> float:
    """Roughly estimate how many tokens answering ``question`` will require.

    Text prompts (objects with a ``text`` attribute, or plain strings) are
    counted at four characters per token. ``FileStore`` items inside
    list-valued prompts contribute ``size * 0.25`` tokens each; other list
    items are ignored. Any other prompt type raises ``ValueError``.
    """
    from edsl.scenarios.FileStore import FileStore

    prompts = self._get_invigilator(question=question).get_prompts()
    # TODO: There should be a way to get a more accurate estimate.
    text = ""
    file_token_estimate = 0
    for value in prompts.values():
        if hasattr(value, "text"):
            text += value.text
            continue
        if isinstance(value, str):
            text += value
            continue
        if not isinstance(value, list):
            raise ValueError(f"Prompt is of type {type(value)}")
        for attachment in value:
            if isinstance(attachment, FileStore):
                file_token_estimate += attachment.size * 0.25
    return len(text) / 4.0 + file_token_estimate
|
327
|
-
|
328
|
-
async def _answer_question_and_record_task(
    self,
    *,
    question: "QuestionBase",
    task=None,
) -> "AgentResponseDict":
    """Answer ``question`` (retrying on missing model responses) and record any exceptions.

    The work happens in the nested ``attempt_answer`` coroutine. It is wrapped
    in a tenacity ``@retry`` that retries only on
    ``LanguageModelNoResponseError``, up to ``EDSL_MAX_ATTEMPTS`` attempts,
    with exponential backoff scaled by ``EDSL_BACKOFF_START_SEC`` and capped
    at ``EDSL_BACKOFF_MAX_SEC``.

    :param question: the question to answer.
    :param task: optional task object, forwarded to ``_handle_exception``
        whenever an exception is recorded.
    :returns: the invigilator's response, or the result of
        ``invigilator.get_failed_task_result(...)`` when the question is
        skipped or fails answer validation.
    :raises LanguageModelNoResponseError: when every attempt times out or
        fails before producing a response.

    Exceptions re-raised by ``_handle_exception`` also propagate.
    """

    # Set inside attempt_answer (via nonlocal) when an attempt fails without a
    # usable response; the value persists across retry attempts.
    had_language_model_no_response_error = False

    # reraise=True: once attempts are exhausted, tenacity re-raises the last
    # exception itself rather than wrapping it in a RetryError.
    @retry(
        stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
        wait=wait_exponential(
            multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
        ),
        retry=retry_if_exception_type(LanguageModelNoResponseError),
        reraise=True,
    )
    async def attempt_answer():
        nonlocal had_language_model_no_response_error

        # A new invigilator per attempt, built from the current answers.
        invigilator = self._get_invigilator(question)

        # Skip rules are checked before the model is called.
        if self._skip_this_question(question):
            return invigilator.get_failed_task_result(
                failure_reason="Question skipped."
            )

        try:
            response: EDSLResultObjectInput = (
                await invigilator.async_answer_question()
            )
            if response.validated:
                # Record the answer, then cancel questions that skip logic now bypasses.
                self.answers.add_answer(response=response, question=question)
                self._cancel_skipped_questions(question)
            else:
                # Unvalidated responses are not added to the answers. If the
                # invigilator attached an exception, re-raise it so that the
                # except clauses below record it.
                # NOTE(review): dependent child questions are not cancelled
                # here; confirm whether they should be.
                if (
                    hasattr(response, "exception_occurred")
                    and response.exception_occurred
                ):
                    raise response.exception_occurred

        except QuestionAnswerValidationError as e:
            self._handle_exception(e, invigilator, task)
            return invigilator.get_failed_task_result(
                failure_reason="Question answer validation failed."
            )

        except asyncio.TimeoutError as e:
            # A timeout counts as "no response", so tenacity will retry it.
            self._handle_exception(e, invigilator, task)
            had_language_model_no_response_error = True
            raise LanguageModelNoResponseError(
                f"Language model timed out for question '{question.question_name}.'"
            )

        except Exception as e:
            self._handle_exception(e, invigilator, task)

            # No response was produced at all: convert to the retryable error.
            # Otherwise fall through and return the response obtained below.
            if "response" not in locals():
                had_language_model_no_response_error = True
                raise LanguageModelNoResponseError(
                    f"Language model did not return a response for question '{question.question_name}.'"
                )

        # Reaching here means a response was obtained. If an earlier attempt
        # failed without a response and an exception is recorded for this
        # question, mark it as fixed.
        if (
            question.question_name in self.exceptions
            and had_language_model_no_response_error
        ):
            self.exceptions.record_fixed_question(question.question_name)

        return response

    try:
        return await attempt_answer()
    except RetryError as retry_error:
        # Because reraise=True, exhausted retries re-raise the final
        # LanguageModelNoResponseError directly and do NOT land here.
        # NOTE(review): this branch is reached only if a RetryError escapes
        # attempt_answer some other way (presumably re-raised through
        # _handle_exception); confirm whether it is still needed.
        original_error = retry_error.last_attempt.exception()
        self._handle_exception(
            original_error, self._get_invigilator(question), task
        )
        raise original_error  # Re-raise the original error after handling
|
413
|
-
|
414
|
-
def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
|
415
|
-
"""Return an invigilator for the given question.
|
416
|
-
|
417
|
-
:param question: the question to be answered
|
418
|
-
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
419
|
-
"""
|
420
|
-
invigilator = self.agent.create_invigilator(
|
421
|
-
question=question,
|
422
|
-
scenario=self.scenario,
|
423
|
-
model=self.model,
|
424
|
-
debug=False,
|
425
|
-
survey=self.survey,
|
426
|
-
memory_plan=self.survey.memory_plan,
|
427
|
-
current_answers=self.answers,
|
428
|
-
iteration=self.iteration,
|
429
|
-
cache=self.cache,
|
430
|
-
sidecar_model=self.sidecar_model,
|
431
|
-
raise_validation_errors=self.raise_validation_errors,
|
432
|
-
)
|
433
|
-
"""Return an invigilator for the given question."""
|
434
|
-
return invigilator
|
435
|
-
|
436
|
-
def _skip_this_question(self, current_question: "QuestionBase") -> bool:
|
437
|
-
"""Determine if the current question should be skipped.
|
438
|
-
|
439
|
-
:param current_question: the question to be answered.
|
440
|
-
"""
|
441
|
-
current_question_index = self.to_index[current_question.question_name]
|
442
|
-
|
443
|
-
answers = self.answers | self.scenario | self.agent["traits"]
|
444
|
-
skip = self.survey.rule_collection.skip_question_before_running(
|
445
|
-
current_question_index, answers
|
446
|
-
)
|
447
|
-
return skip
|
448
|
-
|
449
|
-
def _handle_exception(
    self, e: Exception, invigilator: "InvigilatorBase", task=None
):
    """Record an exception raised while answering a question, then re-raise it if configured.

    An ``InterviewExceptionEntry`` is added to ``self.exceptions`` under the
    invigilator's question name. The entry holds the exception, the
    invigilator and a shallow copy of the current answers. If ``task`` is
    given, its ``task_status`` is set to ``TaskStatus.FAILED``.

    :param e: the exception that was raised.
    :param invigilator: the invigilator that was answering the question.
    :param task: optional task object to mark as failed.
    :raises QuestionAnswerValidationError: re-raises ``e`` when it is a
        validation error and ``self.raise_validation_errors`` is true.
    :raises Exception: re-raises ``e`` when ``self.stop_on_exception`` is
        present and truthy.
    """
    # Snapshot (shallow copy) of the answers at the time of the failure.
    answers = copy.copy(self.answers)
    exception_entry = InterviewExceptionEntry(
        exception=e,
        invigilator=invigilator,
        answers=answers,
    )
    if task:
        task.task_status = TaskStatus.FAILED
    self.exceptions.add(invigilator.question.question_name, exception_entry)

    if self.raise_validation_errors and isinstance(
        e, QuestionAnswerValidationError
    ):
        raise e

    # ``stop_on_exception`` may be absent on the instance; treat that as False.
    if getattr(self, "stop_on_exception", False):
        raise e
|
477
|
-
|
478
|
-
def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
    """Cancel the tasks for questions that are skipped.

    :param current_question: the question that was just answered.

    The next question is resolved from the survey's rule collection. The
    lookup uses the current answers merged with the scenario and the agent's
    traits. If the next question is the end of the survey, every later task
    is cancelled. If it lies more than one position after the current
    question, the tasks strictly between the two are cancelled.
    """
    current_question_index: int = self.to_index[current_question.question_name]

    # ``next_q`` on the returned object is compared both to the EndOfSurvey
    # sentinel and to integer indices below.
    next_question = self.survey.rule_collection.next_question(
        q_now=current_question_index,
        answers=self.answers | self.scenario | self.agent["traits"],
    )

    next_question_index = next_question.next_q

    def cancel_between(start, end):
        """Cancel the tasks with indices in ``range(start, end)``."""
        for i in range(start, end):
            self.tasks[i].cancel()

    if next_question_index == EndOfSurvey:
        cancel_between(current_question_index + 1, len(self.survey.questions))
        return

    if next_question_index > (current_question_index + 1):
        cancel_between(current_question_index + 1, next_question_index)
|
509
|
-
|
510
|
-
# endregion
|
511
|
-
|
512
|
-
# region: Conducting the interview
|
513
|
-
async def async_conduct_interview(
    self,
    model_buckets: Optional[ModelBuckets] = None,
    stop_on_exception: bool = False,
    sidecar_model: Optional["LanguageModel"] = None,
    raise_validation_errors: bool = True,
) -> tuple["Answers", List[dict[str, Any]]]:
    """
    Conduct an Interview asynchronously.
    It returns a tuple with the answers and a list of valid results.

    :param model_buckets: the token buckets for the model. An unlimited
        ("infinity") bucket is used instead when this is omitted, or when
        the agent has an ``answer_question_directly`` attribute.
    :param stop_on_exception: if True, stops the interview if an exception is raised.
    :param sidecar_model: a sidecar model used to answer questions. It
        replaces ``self.sidecar_model``, even when None is passed.
    :param raise_validation_errors: not read by this method. Whether
        validation errors are re-raised is governed by
        ``self.raise_validation_errors``, which the constructor sets.

    Example usage:

    >>> i = Interview.example()
    >>> result, _ = asyncio.run(i.async_conduct_interview())
    >>> result['q0']
    'yes'

    >>> i = Interview.example(throw_exception = True)
    >>> result, _ = asyncio.run(i.async_conduct_interview())
    >>> i.exceptions
    {'q0': ...
    >>> i = Interview.example()
    >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
    Traceback (most recent call last):
    ...
    asyncio.exceptions.CancelledError
    """
    self.sidecar_model = sidecar_model
    self.stop_on_exception = stop_on_exception

    # if no model bucket is passed, create an 'infinity' bucket with no rate limits
    if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
        model_buckets = ModelBuckets.infinity_bucket()

    ## This is the key part---it creates a task for each question,
    ## with dependencies on the questions that must be answered before this one can be answered.
    self.tasks = self._build_question_tasks(model_buckets=model_buckets)

    ## 'Invigilators' are used to administer the survey
    self.invigilators = [
        self._get_invigilator(question) for question in self.survey.questions
    ]
    # With return_exceptions=True, task exceptions are collected rather than
    # raised; _extract_valid_results inspects each task afterwards. When
    # stop_on_exception is set, the first exception propagates out of gather.
    await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
    self.answers.replace_missing_answers_with_none(self.survey)
    valid_results = list(self._extract_valid_results())
    return self.answers, valid_results
|
567
|
-
|
568
|
-
# endregion
|
569
|
-
|
570
|
-
# region: Extracting results and recording errors
|
571
|
-
def _extract_valid_results(self) -> Generator["Answers", None, None]:
    """Yield one result per task, pairing tasks and invigilators by position.

    A task that has not finished raises ``ValueError``. A cancelled task
    yields the invigilator's failed-task result and no exception is
    recorded. Tasks are cancelled deliberately, for example by
    ``_cancel_skipped_questions``. Any other exception also yields a
    failed-task result, and the exception is recorded under the task's name.

    >>> i = Interview.example()
    >>> result, _ = asyncio.run(i.async_conduct_interview())
    >>> results = list(i._extract_valid_results())
    >>> len(results) == len(i.survey)
    True
    """
    assert len(self.tasks) == len(self.invigilators)

    for finished, invig in zip(self.tasks, self.invigilators):
        if not finished.done():
            raise ValueError(f"Task {finished.get_name()} is not done.")

        try:
            outcome = finished.result()
        except asyncio.CancelledError:
            # Expected cancellation: report it, but don't log it as an error.
            outcome = invig.get_failed_task_result(
                failure_reason="Task was cancelled."
            )
        except Exception as exc:
            outcome = invig.get_failed_task_result(
                failure_reason=f"Task failed with exception: {str(exc)}."
            )
            entry = InterviewExceptionEntry(
                exception=exc,
                invigilator=invig,
            )
            self.exceptions.add(finished.get_name(), entry)

        yield outcome
|
607
|
-
|
608
|
-
# endregion
|
609
|
-
|
610
|
-
# region: Magic methods
|
611
|
-
def __repr__(self) -> str:
    """Describe the interview by the reprs of its agent, survey, scenario and model."""
    text = f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
    return text
|
614
|
-
|
615
|
-
def duplicate(self, iteration: int, cache: "Cache") -> Interview:
    """Duplicate the interview, but with a new iteration number and cache.

    The agent, survey, scenario and model are passed to the new instance.
    The ``skip_retry`` and ``raise_validation_errors`` settings are carried
    over so that the copy handles errors the same way the original does.

    >>> i = Interview.example()
    >>> i2 = i.duplicate(1, None)
    >>> i.iteration + 1 == i2.iteration
    True

    """
    return Interview(
        agent=self.agent,
        survey=self.survey,
        scenario=self.scenario,
        model=self.model,
        iteration=iteration,
        cache=cache,
        skip_retry=self.skip_retry,
        # Previously dropped, which silently reset the copy to the default (True).
        raise_validation_errors=self.raise_validation_errors,
    )
|
633
|
-
|
634
|
-
@classmethod
def example(cls, throw_exception: bool = False) -> Interview:
    """Return an example Interview instance.

    By default the example agent gets a direct-answer method that always
    returns ``"yes"``. With ``throw_exception=True`` a plain example agent is
    paired with ``LanguageModel.example(test_model=True, throw_exception=True)``,
    which the ``async_conduct_interview`` doctest uses to produce recorded
    exceptions.

    :param throw_exception: build the exception-throwing variant.
    """
    from edsl.agents import Agent
    from edsl.surveys import Survey
    from edsl.scenarios import Scenario
    from edsl.language_models import LanguageModel

    survey = Survey.example()
    scenario = Scenario.example()

    if throw_exception:
        # Presumably kept free of a direct-answer method so the failing
        # model is actually consulted.
        agent = Agent.example()
        model = LanguageModel.example(test_model=True, throw_exception=True)
    else:

        def f(self, question, scenario):
            return "yes"

        agent = Agent.example()
        agent.add_direct_question_answering_method(f)
        model = LanguageModel.example()

    return cls(agent=agent, survey=survey, scenario=scenario, model=model)
|
655
|
-
|
656
|
-
|
657
|
-
if __name__ == "__main__":
    import doctest

    # ELLIPSIS lets expected doctest output use "..." as a wildcard
    # (e.g. the log_time values and exception reprs above).
    doctest.testmod(optionflags=doctest.ELLIPSIS)
|
1
|
+
"""This module contains the Interview class, which is responsible for conducting an interview asynchronously."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
import asyncio
|
5
|
+
from typing import Any, Type, List, Generator, Optional, Union
|
6
|
+
import copy
|
7
|
+
|
8
|
+
from tenacity import (
|
9
|
+
retry,
|
10
|
+
stop_after_attempt,
|
11
|
+
wait_exponential,
|
12
|
+
retry_if_exception_type,
|
13
|
+
RetryError,
|
14
|
+
)
|
15
|
+
|
16
|
+
from edsl import CONFIG
|
17
|
+
from edsl.surveys.base import EndOfSurvey
|
18
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
19
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
20
|
+
from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
21
|
+
|
22
|
+
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
23
|
+
from edsl.jobs.Answers import Answers
|
24
|
+
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
25
|
+
from edsl.jobs.tasks.TaskCreators import TaskCreators
|
26
|
+
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
27
|
+
from edsl.jobs.interviews.InterviewExceptionCollection import (
|
28
|
+
InterviewExceptionCollection,
|
29
|
+
)
|
30
|
+
|
31
|
+
# from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
32
|
+
|
33
|
+
from edsl.surveys.base import EndOfSurvey
|
34
|
+
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
35
|
+
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
36
|
+
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
37
|
+
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
38
|
+
|
39
|
+
|
40
|
+
from edsl import Agent, Survey, Scenario, Cache
|
41
|
+
from edsl.language_models import LanguageModel
|
42
|
+
from edsl.questions import QuestionBase
|
43
|
+
from edsl.agents.InvigilatorBase import InvigilatorBase
|
44
|
+
|
45
|
+
from edsl.exceptions.language_models import LanguageModelNoResponseError
|
46
|
+
|
47
|
+
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
48
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
49
|
+
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
50
|
+
|
51
|
+
|
52
|
+
from edsl import CONFIG
|
53
|
+
|
54
|
+
# Retry/backoff settings, read from CONFIG once at import time. They feed the
# tenacity ``@retry`` decorator in ``_answer_question_and_record_task``.
EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))  # backoff multiplier
EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))  # cap on a single wait
EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))  # stop_after_attempt limit
|
57
|
+
|
58
|
+
|
59
|
+
class Interview:
|
60
|
+
"""
|
61
|
+
An 'interview' is one agent answering one survey, with one language model, for a given scenario.
|
62
|
+
|
63
|
+
The main method is `async_conduct_interview`, which conducts the interview asynchronously.
|
64
|
+
Most of the class is dedicated to creating the tasks for each question in the survey, and then running them.
|
65
|
+
"""
|
66
|
+
|
67
|
+
def __init__(
    self,
    agent: Agent,
    survey: Survey,
    scenario: Scenario,
    model: "LanguageModel",
    debug: Optional[bool] = False,
    iteration: int = 0,
    cache: Optional["Cache"] = None,
    sidecar_model: Optional["LanguageModel"] = None,
    skip_retry: bool = False,
    raise_validation_errors: bool = True,
):
    """Initialize the Interview instance.

    :param agent: the agent being interviewed.
    :param survey: the survey being administered to the agent. A deep copy
        is stored, so changes to ``self.survey`` do not affect the caller's
        object.
    :param scenario: the scenario that populates the survey questions.
    :param model: the language model (an instance) used to answer the questions.
    :param debug: stored as ``self.debug``. Invigilators are currently
        created with ``debug=False`` regardless of this flag.
    :param iteration: the iteration number of the interview.
    :param cache: the cache used to store the answers.
    :param sidecar_model: a sidecar model used to answer questions.
    :param skip_retry: stored as ``self.skip_retry``.
    :param raise_validation_errors: if True, a ``QuestionAnswerValidationError``
        is re-raised after being recorded. The flag is also passed to each
        invigilator.

    >>> i = Interview.example()
    >>> i.task_creators
    {}

    >>> i.exceptions
    {}

    >>> _ = asyncio.run(i.async_conduct_interview())
    >>> i.task_status_logs['q0']
    [{'log_time': ..., 'value': <TaskStatus.NOT_STARTED: 1>}, {'log_time': ..., 'value': <TaskStatus.WAITING_FOR_DEPENDENCIES: 2>}, {'log_time': ..., 'value': <TaskStatus.API_CALL_IN_PROGRESS: 7>}, {'log_time': ..., 'value': <TaskStatus.SUCCESS: 8>}]

    >>> i.to_index
    {'q0': 0, 'q1': 1, 'q2': 2}

    """
    # Inputs and configuration
    self.agent = agent
    self.survey = copy.deepcopy(survey)
    self.scenario = scenario
    self.model = model
    self.debug = debug
    self.iteration = iteration
    self.cache = cache
    self.sidecar_model = sidecar_model
    self.skip_retry = skip_retry
    self.raise_validation_errors = raise_validation_errors

    # Question name -> position in the survey, e.g. {'q0': 0, 'q1': 1, ...}.
    self.to_index = {
        question_name: index
        for index, question_name in enumerate(self.survey.question_names)
    }

    # State filled in as the interview progresses
    self.answers = Answers()
    self.task_creators = TaskCreators()  # question name -> task creator
    self.exceptions = InterviewExceptionCollection()
    self._task_status_log_dict = InterviewStatusLog()
    self.failed_questions = []
|
133
|
+
|
134
|
+
@property
def has_exceptions(self) -> bool:
    """True once at least one exception has been recorded for this interview."""
    return len(self.exceptions) != 0
|
138
|
+
|
139
|
+
@property
def task_status_logs(self) -> InterviewStatusLog:
    """Map each question name to its task creator's status-log history.

    The stored log dictionary is refreshed from the current task creators
    every time this property is read, and then returned.
    """
    log_dict = self._task_status_log_dict
    for creator in self.task_creators.values():
        name = creator.question.question_name
        log_dict[name] = creator.status_log
    return log_dict
|
150
|
+
|
151
|
+
@property
def token_usage(self) -> InterviewTokenUsage:
    """Token usage for the interview, as aggregated by the task creators."""
    creators = self.task_creators
    return creators.token_usage
|
155
|
+
|
156
|
+
@property
def interview_status(self) -> InterviewStatusDictionary:
    """Counts of tasks per status code, as reported by the task creators."""
    creators = self.task_creators
    return creators.interview_status
|
160
|
+
|
161
|
+
# region: Serialization
|
162
|
+
def to_dict(self, include_exceptions=True, add_edsl_version=True) -> dict[str, Any]:
    """Serialize the interview's components into a plain dictionary.

    ``__hash__`` uses this with exceptions and version info excluded.

    :param include_exceptions: when False, ``"exceptions"`` is an empty dict.
    :param add_edsl_version: forwarded to each component's ``to_dict``.

    >>> i = Interview.example()
    >>> hash(i)
    1217840301076717434
    """
    serialized = {
        key: component.to_dict(add_edsl_version=add_edsl_version)
        for key, component in (
            ("agent", self.agent),
            ("survey", self.survey),
            ("scenario", self.scenario),
            ("model", self.model),
        )
    }
    serialized["iteration"] = self.iteration
    serialized["exceptions"] = (
        self.exceptions.to_dict() if include_exceptions else {}
    )
    return serialized
|
181
|
+
|
182
|
+
@classmethod
def from_dict(cls, d: dict[str, Any]) -> "Interview":
    """Rebuild an Interview from the output of ``to_dict``.

    Recorded exceptions are restored when an ``"exceptions"`` key is present.
    """
    interview = cls(
        agent=Agent.from_dict(d["agent"]),
        survey=Survey.from_dict(d["survey"]),
        scenario=Scenario.from_dict(d["scenario"]),
        model=LanguageModel.from_dict(d["model"]),
        iteration=d["iteration"],
    )
    if "exceptions" in d:
        interview.exceptions = InterviewExceptionCollection.from_dict(
            d["exceptions"]
        )
    return interview
|
201
|
+
|
202
|
+
def __hash__(self) -> int:
    """Hash the serialized interview, ignoring exceptions and version info."""
    from edsl.utilities.utilities import dict_hash

    payload = self.to_dict(include_exceptions=False, add_edsl_version=False)
    return dict_hash(payload)
|
206
|
+
|
207
|
+
def __eq__(self, other: "Interview") -> bool:
    """Treat two interviews as equal when their hashes match.

    >>> from edsl.jobs.interviews.Interview import Interview; i = Interview.example(); d = i.to_dict(); i2 = Interview.from_dict(d); i == i2
    True
    """
    mine, theirs = hash(self), hash(other)
    return mine == theirs
|
213
|
+
|
214
|
+
# endregion
|
215
|
+
|
216
|
+
# region: Creating tasks
|
217
|
+
@property
def dag(self) -> "DAG":
    """The survey's dependency graph, keyed by question name.

    Each question name maps to the names it depends on, reflecting both
    agent 'memory' and 'skip' logic. ``textify=True`` asks the survey for
    question names rather than integer indices.

    >>> i = Interview.example()
    >>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
    True
    """
    graph = self.survey.dag(textify=True)
    return graph
|
231
|
+
|
232
|
+
def _build_question_tasks(
    self,
    model_buckets: ModelBuckets,
) -> tuple[asyncio.Task, ...]:
    """Create one task per survey question, wired to its prerequisite tasks.

    Questions are processed in survey order. Prerequisites are looked up by
    index in the list of tasks built so far, which assumes the DAG only
    points to earlier questions.

    :param model_buckets: the model buckets used to track and control usage rates.
    :returns: a tuple of tasks, one per question, in survey order.
    """
    tasks = []
    for question in self.survey.questions:
        tasks_that_must_be_completed_before = list(
            self._get_tasks_that_must_be_completed_before(
                tasks=tasks, question=question
            )
        )
        question_task = self._create_question_task(
            question=question,
            tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
            model_buckets=model_buckets,
            iteration=self.iteration,
        )
        tasks.append(question_task)
    return tuple(tasks)
|
256
|
+
|
257
|
+
def _get_tasks_that_must_be_completed_before(
    self, *, tasks: list[asyncio.Task], question: "QuestionBase"
) -> Generator[asyncio.Task, None, None]:
    """Yield the already-created tasks that ``question`` depends on.

    :param tasks: tasks created so far, indexed by survey position.
    :param question: the question whose prerequisites are wanted.

    Nothing is yielded when the question has no entry in the DAG.
    """
    prerequisites = self.dag.get(question.question_name, [])
    yield from (tasks[self.to_index[name]] for name in prerequisites)
|
270
|
+
|
271
|
+
def _create_question_task(
    self,
    *,
    question: QuestionBase,
    tasks_that_must_be_completed_before: list[asyncio.Task],
    model_buckets: ModelBuckets,
    iteration: int = 0,
) -> asyncio.Task:
    """Build the task for ``question`` and register its task creator.

    A ``QuestionTaskCreator`` receives ``_answer_question_and_record_task``
    as the answering function and ``_get_estimated_request_tokens`` as the
    token estimator. Each prerequisite task is added as a dependency. The
    creator is then stored in ``self.task_creators`` under the question's
    name, and the task it generates is returned.

    :param question: the question the task will answer.
    :param tasks_that_must_be_completed_before: prerequisite tasks.
    :param model_buckets: the model buckets used to track and control usage rates.
    :param iteration: the iteration number for the interview.
    """
    creator = QuestionTaskCreator(
        question=question,
        answer_question_func=self._answer_question_and_record_task,
        token_estimator=self._get_estimated_request_tokens,
        model_buckets=model_buckets,
        iteration=iteration,
    )
    for prerequisite in tasks_that_must_be_completed_before:
        creator.add_dependency(prerequisite)

    self.task_creators.update({question.question_name: creator})
    return creator.generate_task()
|
306
|
+
|
307
|
+
def _get_estimated_request_tokens(self, question) -> float:
    """Roughly estimate the tokens needed to answer ``question``.

    Text prompts count as one token per four characters. Each ``FileStore``
    inside a list-valued prompt counts as a quarter of its ``size``; other
    list items are ignored.

    :raises ValueError: if a prompt has no ``text`` attribute and is neither
        a ``str`` nor a list.
    """
    from edsl.scenarios.FileStore import FileStore

    invigilator = self._get_invigilator(question=question)
    # TODO: There should be a way to get a more accurate estimate.
    text_so_far = ""
    file_estimate = 0
    for prompt in invigilator.get_prompts().values():
        if hasattr(prompt, "text"):
            text_so_far += prompt.text
            continue
        if isinstance(prompt, str):
            text_so_far += prompt
            continue
        if not isinstance(prompt, list):
            raise ValueError(f"Prompt is of type {type(prompt)}")
        for item in prompt:
            if isinstance(item, FileStore):
                file_estimate += item.size * 0.25
    return len(text_so_far) / 4.0 + file_estimate
|
327
|
+
|
328
|
+
async def _answer_question_and_record_task(
    self,
    *,
    question: "QuestionBase",
    task=None,
) -> "AgentResponseDict":
    """Answer a question and record the outcome.

    Builds an invigilator for ``question`` and awaits its answer. The attempt
    is retried with exponential backoff, via tenacity, whenever it raises
    ``LanguageModelNoResponseError``. Timeouts and response-less failures are
    converted to that error. A failed-task result is returned when survey
    rules skip the question. One is also returned when the answer fails
    validation and ``_handle_exception`` does not re-raise. Otherwise the
    invigilator's response is returned.

    :param question: the question to answer.
    :param task: optional task object, forwarded to ``_handle_exception``,
        which marks it FAILED when an exception is recorded.
    """

    # Set when an attempt ended in a no-response/timeout error, so that a
    # later successful attempt can mark the question as fixed.
    had_language_model_no_response_error = False

    @retry(
        stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
        wait=wait_exponential(
            multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
        ),
        retry=retry_if_exception_type(LanguageModelNoResponseError),
        reraise=True,
    )
    async def attempt_answer():
        nonlocal had_language_model_no_response_error

        invigilator = self._get_invigilator(question)

        if self._skip_this_question(question):
            return invigilator.get_failed_task_result(
                failure_reason="Question skipped."
            )

        try:
            response: EDSLResultObjectInput = (
                await invigilator.async_answer_question()
            )
            if response.validated:
                # Record the answer, then cancel tasks the rules now jump over.
                self.answers.add_answer(response=response, question=question)
                self._cancel_skipped_questions(question)
            else:
                # When a question is not validated, it is not added to the answers.
                # This should also cancel any dependent child questions.
                # Is that happening now?
                if (
                    hasattr(response, "exception_occurred")
                    and response.exception_occurred
                ):
                    raise response.exception_occurred

        except QuestionAnswerValidationError as e:
            # _handle_exception re-raises here when raise_validation_errors is set.
            self._handle_exception(e, invigilator, task)
            return invigilator.get_failed_task_result(
                failure_reason="Question answer validation failed."
            )

        except asyncio.TimeoutError as e:
            self._handle_exception(e, invigilator, task)
            had_language_model_no_response_error = True
            # Converted to the exception type the @retry decorator retries on.
            raise LanguageModelNoResponseError(
                f"Language model timed out for question '{question.question_name}.'"
            )

        except Exception as e:
            self._handle_exception(e, invigilator, task)

            # No ``response`` local means the answer call itself failed, so
            # convert to a retryable error. If a response exists, fall through
            # and return it below.
            if "response" not in locals():
                had_language_model_no_response_error = True
                raise LanguageModelNoResponseError(
                    f"Language model did not return a response for question '{question.question_name}.'"
                )

        # if it gets here, it means the no response error was fixed
        if (
            question.question_name in self.exceptions
            and had_language_model_no_response_error
        ):
            self.exceptions.record_fixed_question(question.question_name)

        return response

    try:
        return await attempt_answer()
    except RetryError as retry_error:
        # NOTE(review): with ``reraise=True`` tenacity re-raises the last
        # underlying exception rather than RetryError, so this branch looks
        # unreachable. Confirm before relying on it.
        # All retries have failed for LanguageModelNoResponseError
        original_error = retry_error.last_attempt.exception()
        self._handle_exception(
            original_error, self._get_invigilator(question), task
        )
        raise original_error  # Re-raise the original error after handling
|
413
|
+
|
414
|
+
def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
    """Return an invigilator for the given question.

    The agent builds the invigilator from this interview's scenario, model,
    survey and its memory plan, current answers, iteration, cache, sidecar
    model and ``raise_validation_errors`` setting. Debug mode is always
    passed as False here, regardless of ``self.debug``.

    :param question: the question to be answered.
    """
    return self.agent.create_invigilator(
        question=question,
        scenario=self.scenario,
        model=self.model,
        debug=False,
        survey=self.survey,
        memory_plan=self.survey.memory_plan,
        current_answers=self.answers,
        iteration=self.iteration,
        cache=self.cache,
        sidecar_model=self.sidecar_model,
        raise_validation_errors=self.raise_validation_errors,
    )
|
435
|
+
|
436
|
+
def _skip_this_question(self, current_question: "QuestionBase") -> bool:
    """Decide whether the survey rules say to skip ``current_question``.

    The rule collection is consulted *before* the question runs. It receives
    the question's survey position and the answers gathered so far, merged
    with the scenario and the agent's traits.

    :param current_question: the question about to be answered.
    :returns: the value reported by ``skip_question_before_running``.
    """
    position = self.to_index[current_question.question_name]
    context = self.answers | self.scenario | self.agent["traits"]
    return self.survey.rule_collection.skip_question_before_running(
        position, context
    )
|
448
|
+
|
449
|
+
def _handle_exception(
    self, e: Exception, invigilator: "InvigilatorBase", task=None
):
    """Record ``e`` against the invigilator's question and optionally re-raise it.

    A shallow copy of the answers collected so far is stored with the
    exception entry. When ``task`` is given, its status is set to FAILED.

    Re-raises ``e`` when it is a ``QuestionAnswerValidationError`` and
    ``self.raise_validation_errors`` is true. It also re-raises when
    ``self.stop_on_exception`` is true; that attribute may not exist yet,
    in which case it is treated as False.
    """
    import copy

    snapshot = copy.copy(self.answers)
    entry = InterviewExceptionEntry(
        exception=e,
        invigilator=invigilator,
        answers=snapshot,
    )
    if task:
        task.task_status = TaskStatus.FAILED
    self.exceptions.add(invigilator.question.question_name, entry)

    if self.raise_validation_errors and isinstance(e, QuestionAnswerValidationError):
        raise e

    # ``stop_on_exception`` is assigned in async_conduct_interview; before
    # that it is absent, so default to not stopping.
    if getattr(self, "stop_on_exception", False):
        raise e
|
477
|
+
|
478
|
+
def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
    """Cancel the tasks of questions that the survey rules jump over.

    :param current_question: the question that was just answered.

    The rule collection resolves the next question from the current answers,
    merged with the scenario and the agent's traits. If the next question is
    the end of the survey, every later task is cancelled. If it lies more
    than one position ahead, only the tasks strictly in between are
    cancelled. Otherwise nothing is cancelled.
    """
    here = self.to_index[current_question.question_name]
    lookup_context = self.answers | self.scenario | self.agent["traits"]
    target = self.survey.rule_collection.next_question(
        q_now=here,
        answers=lookup_context,
    ).next_q

    first_skipped = here + 1
    if target == EndOfSurvey:
        stop = len(self.survey.questions)
    elif target > first_skipped:
        stop = target
    else:
        return

    for idx in range(first_skipped, stop):
        self.tasks[idx].cancel()
|
509
|
+
|
510
|
+
# endregion
|
511
|
+
|
512
|
+
# region: Conducting the interview
|
513
|
+
async def async_conduct_interview(
    self,
    model_buckets: Optional[ModelBuckets] = None,
    stop_on_exception: bool = False,
    sidecar_model: Optional["LanguageModel"] = None,
    raise_validation_errors: bool = True,
) -> tuple["Answers", List[dict[str, Any]]]:
    """
    Conduct an Interview asynchronously.
    It returns a tuple with the answers and a list of valid results.

    :param model_buckets: the token buckets for the model. An unlimited
        ("infinity") bucket is used instead when this is omitted, or when
        the agent has an ``answer_question_directly`` attribute.
    :param stop_on_exception: if True, stops the interview if an exception is raised.
    :param sidecar_model: a sidecar model used to answer questions. It
        replaces ``self.sidecar_model``, even when None is passed.
    :param raise_validation_errors: not read by this method. Whether
        validation errors are re-raised is governed by
        ``self.raise_validation_errors``, which the constructor sets.

    Example usage:

    >>> i = Interview.example()
    >>> result, _ = asyncio.run(i.async_conduct_interview())
    >>> result['q0']
    'yes'

    >>> i = Interview.example(throw_exception = True)
    >>> result, _ = asyncio.run(i.async_conduct_interview())
    >>> i.exceptions
    {'q0': ...
    >>> i = Interview.example()
    >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
    Traceback (most recent call last):
    ...
    asyncio.exceptions.CancelledError
    """
    self.sidecar_model = sidecar_model
    self.stop_on_exception = stop_on_exception

    # if no model bucket is passed, create an 'infinity' bucket with no rate limits
    if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
        model_buckets = ModelBuckets.infinity_bucket()

    ## This is the key part---it creates a task for each question,
    ## with dependencies on the questions that must be answered before this one can be answered.
    self.tasks = self._build_question_tasks(model_buckets=model_buckets)

    ## 'Invigilators' are used to administer the survey
    self.invigilators = [
        self._get_invigilator(question) for question in self.survey.questions
    ]
    # With return_exceptions=True, task exceptions are collected rather than
    # raised; _extract_valid_results inspects each task afterwards. When
    # stop_on_exception is set, the first exception propagates out of gather.
    await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
    self.answers.replace_missing_answers_with_none(self.survey)
    valid_results = list(self._extract_valid_results())
    return self.answers, valid_results
|
567
|
+
|
568
|
+
# endregion
|
569
|
+
|
570
|
+
# region: Extracting results and recording errors
|
571
|
+
def _extract_valid_results(self) -> Generator["Answers", None, None]:
    """Yield one result per question by walking tasks and invigilators in lockstep.

    Every task must already be finished; an unfinished task raises ``ValueError``.
    For each finished task:

    * a normal completion yields the task's own result;
    * a cancelled task yields the invigilator's failed-task result and nothing is
      recorded, because cancellation is expected behavior;
    * any other exception yields the invigilator's failed-task result, and the
      exception is wrapped in an ``InterviewExceptionEntry`` and recorded in
      ``self.exceptions`` under the task's name.

    >>> i = Interview.example()
    >>> result, _ = asyncio.run(i.async_conduct_interview())
    >>> results = list(i._extract_valid_results())
    >>> len(results) == len(i.survey)
    True
    """
    # One invigilator is built per question, mirroring one task per question.
    assert len(self.tasks) == len(self.invigilators)

    for question_task, question_invigilator in zip(self.tasks, self.invigilators):
        if not question_task.done():
            raise ValueError(f"Task {question_task.get_name()} is not done.")

        try:
            outcome = question_task.result()
        except asyncio.CancelledError:
            # Cancellation is expected, so it is not recorded as an exception.
            outcome = question_invigilator.get_failed_task_result(
                failure_reason="Task was cancelled."
            )
        except Exception as exc:
            outcome = question_invigilator.get_failed_task_result(
                failure_reason=f"Task failed with exception: {str(exc)}."
            )
            entry = InterviewExceptionEntry(
                exception=exc,
                invigilator=question_invigilator,
            )
            self.exceptions.add(question_task.get_name(), entry)

        yield outcome
|
607
|
+
|
608
|
+
# endregion
|
609
|
+
|
610
|
+
# region: Magic methods
|
611
|
+
def __repr__(self) -> str:
    """Describe the interview by the reprs of its agent, survey, scenario and model."""
    return (
        f"Interview(agent = {self.agent!r}, survey = {self.survey!r}, "
        f"scenario = {self.scenario!r}, model = {self.model!r})"
    )
|
614
|
+
|
615
|
+
def duplicate(self, iteration: int, cache: "Cache") -> "Interview":
    """Return a new Interview with the given iteration number and cache.

    The new interview receives the same agent, survey, scenario and model
    objects as this one, plus this interview's ``skip_retry`` setting.

    >>> i = Interview.example()
    >>> i2 = i.duplicate(1, None)
    >>> i.iteration + 1 == i2.iteration
    True

    """
    # Constructor arguments carried over unchanged from this interview.
    carried_over = dict(
        agent=self.agent,
        survey=self.survey,
        scenario=self.scenario,
        model=self.model,
        skip_retry=self.skip_retry,
    )
    return Interview(iteration=iteration, cache=cache, **carried_over)
|
633
|
+
|
634
|
+
@classmethod
def example(cls, throw_exception: bool = False) -> "Interview":
    """Return an example Interview instance.

    :param throw_exception: if False (default), the agent is ``Agent.example()``
        with a direct-answer method that always returns ``"yes"``, paired with
        ``LanguageModel.example()``. If True, the agent is a plain
        ``Agent.example()`` (no direct-answer method) and the model is
        ``LanguageModel.example(test_model=True, throw_exception=True)``; this
        is the configuration the ``async_conduct_interview`` doctest uses to
        provoke recorded exceptions.

    Survey and scenario are always ``Survey.example()`` and ``Scenario.example()``.
    The instance is built with ``cls`` so subclasses get an instance of their
    own type.
    """
    # Local imports, as in the original, presumably to avoid import cycles.
    from edsl.agents import Agent
    from edsl.surveys import Survey
    from edsl.scenarios import Scenario
    from edsl.language_models import LanguageModel

    def f(self, question, scenario):
        # Direct-answer hook: answers every question with "yes".
        # NOTE(review): signature presumably (agent, question, scenario) as
        # expected by add_direct_question_answering_method; confirm.
        return "yes"

    agent = Agent.example()
    if not throw_exception:
        # Only attach the direct-answer method when the model is not meant
        # to fail; the failing configuration uses a plain example agent.
        agent.add_direct_question_answering_method(f)

    survey = Survey.example()
    scenario = Scenario.example()

    if throw_exception:
        model = LanguageModel.example(test_model=True, throw_exception=True)
    else:
        model = LanguageModel.example()

    return cls(agent=agent, survey=survey, scenario=scenario, model=model)
|
655
|
+
|
656
|
+
|
657
|
+
if __name__ == "__main__":
    import doctest

    # Run this module's doctests. ELLIPSIS lets "..." in expected output match
    # arbitrary text, which docstrings here rely on (for example the
    # "{'q0': ..." exceptions output and the truncated traceback example).
    doctest.testmod(optionflags=doctest.ELLIPSIS)
|