edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +116 -197
- edsl/__init__.py +7 -15
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +147 -351
- edsl/agents/AgentList.py +73 -211
- edsl/agents/Invigilator.py +50 -101
- edsl/agents/InvigilatorBase.py +70 -62
- edsl/agents/PromptConstructor.py +225 -143
- edsl/agents/__init__.py +1 -0
- edsl/agents/prompt_helpers.py +3 -3
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/config.py +2 -22
- edsl/conversation/car_buying.py +1 -2
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +47 -125
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +27 -45
- edsl/data/CacheEntry.py +15 -12
- edsl/data/CacheHandler.py +12 -31
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/data/__init__.py +3 -4
- edsl/data_transfer_models.py +1 -2
- edsl/enums.py +0 -27
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +0 -12
- edsl/exceptions/questions.py +6 -24
- edsl/exceptions/scenarios.py +0 -7
- edsl/inference_services/AnthropicService.py +19 -38
- edsl/inference_services/AwsBedrock.py +2 -0
- edsl/inference_services/AzureAI.py +2 -0
- edsl/inference_services/GoogleService.py +12 -7
- edsl/inference_services/InferenceServiceABC.py +85 -18
- edsl/inference_services/InferenceServicesCollection.py +79 -120
- edsl/inference_services/MistralAIService.py +3 -0
- edsl/inference_services/OpenAIService.py +35 -47
- edsl/inference_services/PerplexityService.py +3 -0
- edsl/inference_services/TestService.py +10 -11
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/jobs/Answers.py +14 -1
- edsl/jobs/Jobs.py +431 -356
- edsl/jobs/JobsChecks.py +10 -35
- edsl/jobs/JobsPrompts.py +4 -6
- edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
- edsl/jobs/buckets/BucketCollection.py +3 -44
- edsl/jobs/buckets/TokenBucket.py +21 -53
- edsl/jobs/interviews/Interview.py +408 -143
- edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
- edsl/jobs/tasks/TaskHistory.py +18 -38
- edsl/jobs/tasks/task_status_enum.py +2 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +236 -194
- edsl/language_models/ModelList.py +19 -28
- edsl/language_models/__init__.py +2 -1
- edsl/language_models/registry.py +190 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/unused/ReplicateBase.py +83 -0
- edsl/language_models/utilities.py +4 -5
- edsl/notebooks/Notebook.py +14 -19
- edsl/prompts/Prompt.py +39 -29
- edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
- edsl/questions/QuestionBase.py +214 -68
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
- edsl/questions/QuestionBasePromptsMixin.py +3 -7
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +3 -2
- edsl/questions/QuestionList.py +18 -10
- edsl/questions/QuestionMultipleChoice.py +23 -67
- edsl/questions/QuestionNumerical.py +4 -2
- edsl/questions/QuestionRank.py +17 -7
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
- edsl/questions/SimpleAskMixin.py +3 -4
- edsl/questions/__init__.py +1 -2
- edsl/questions/derived/QuestionLinearScale.py +3 -6
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +3 -17
- edsl/questions/question_registry.py +1 -1
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +7 -170
- edsl/results/DatasetExportMixin.py +305 -168
- edsl/results/DatasetTree.py +8 -28
- edsl/results/Result.py +206 -298
- edsl/results/Results.py +131 -149
- edsl/results/ResultsDBMixin.py +238 -0
- edsl/results/ResultsExportMixin.py +0 -2
- edsl/results/{results_selector.py → Selector.py} +13 -23
- edsl/results/TableDisplay.py +171 -98
- edsl/results/__init__.py +1 -1
- edsl/scenarios/FileStore.py +239 -150
- edsl/scenarios/Scenario.py +193 -90
- edsl/scenarios/ScenarioHtmlMixin.py +3 -4
- edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
- edsl/scenarios/ScenarioList.py +244 -415
- edsl/scenarios/ScenarioListExportMixin.py +7 -0
- edsl/scenarios/ScenarioListPdfMixin.py +37 -15
- edsl/scenarios/__init__.py +2 -1
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +12 -5
- edsl/surveys/Rule.py +4 -5
- edsl/surveys/RuleCollection.py +27 -25
- edsl/surveys/Survey.py +791 -270
- edsl/surveys/SurveyCSS.py +8 -20
- edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
- edsl/surveys/__init__.py +2 -4
- edsl/surveys/descriptors.py +2 -6
- edsl/surveys/instructions/ChangeInstruction.py +2 -1
- edsl/surveys/instructions/Instruction.py +13 -4
- edsl/surveys/instructions/InstructionCollection.py +6 -11
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/utilities.py +23 -35
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
- edsl-0.1.39.dev1.dist-info/RECORD +277 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
- edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
- edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
- edsl/agents/question_option_processor.py +0 -172
- edsl/coop/CoopFunctionsMixin.py +0 -15
- edsl/coop/ExpectedParrotKeyHandler.py +0 -125
- edsl/exceptions/inference_services.py +0 -5
- edsl/inference_services/AvailableModelCacheHandler.py +0 -184
- edsl/inference_services/AvailableModelFetcher.py +0 -215
- edsl/inference_services/ServiceAvailability.py +0 -135
- edsl/inference_services/data_structures.py +0 -134
- edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
- edsl/jobs/FetchInvigilator.py +0 -47
- edsl/jobs/InterviewTaskManager.py +0 -98
- edsl/jobs/InterviewsConstructor.py +0 -50
- edsl/jobs/JobsComponentConstructor.py +0 -189
- edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
- edsl/jobs/RequestTokenEstimator.py +0 -30
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/buckets/TokenBucketAPI.py +0 -211
- edsl/jobs/buckets/TokenBucketClient.py +0 -191
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/decorators.py +0 -35
- edsl/jobs/jobs_status_enums.py +0 -9
- edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/ComputeCost.py +0 -63
- edsl/language_models/PriceManager.py +0 -127
- edsl/language_models/RawResponseHandler.py +0 -106
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/key_management/KeyLookup.py +0 -63
- edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
- edsl/language_models/key_management/KeyLookupCollection.py +0 -38
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +0 -131
- edsl/language_models/model.py +0 -256
- edsl/notebooks/NotebookToLaTeX.py +0 -142
- edsl/questions/ExceptionExplainer.py +0 -77
- edsl/questions/HTMLQuestion.py +0 -103
- edsl/questions/QuestionMatrix.py +0 -265
- edsl/questions/data_structures.py +0 -20
- edsl/questions/loop_processor.py +0 -149
- edsl/questions/response_validator_factory.py +0 -34
- edsl/questions/templates/matrix/__init__.py +0 -1
- edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
- edsl/questions/templates/matrix/question_presentation.jinja +0 -20
- edsl/results/MarkdownToDocx.py +0 -122
- edsl/results/MarkdownToPDF.py +0 -111
- edsl/results/TextEditor.py +0 -50
- edsl/results/file_exports.py +0 -252
- edsl/results/smart_objects.py +0 -96
- edsl/results/table_data_class.py +0 -12
- edsl/results/table_renderers.py +0 -118
- edsl/scenarios/ConstructDownloadLink.py +0 -109
- edsl/scenarios/DocumentChunker.py +0 -102
- edsl/scenarios/DocxScenario.py +0 -16
- edsl/scenarios/PdfExtractor.py +0 -40
- edsl/scenarios/directory_scanner.py +0 -96
- edsl/scenarios/file_methods.py +0 -85
- edsl/scenarios/handlers/__init__.py +0 -13
- edsl/scenarios/handlers/csv.py +0 -49
- edsl/scenarios/handlers/docx.py +0 -76
- edsl/scenarios/handlers/html.py +0 -37
- edsl/scenarios/handlers/json.py +0 -111
- edsl/scenarios/handlers/latex.py +0 -5
- edsl/scenarios/handlers/md.py +0 -51
- edsl/scenarios/handlers/pdf.py +0 -68
- edsl/scenarios/handlers/png.py +0 -39
- edsl/scenarios/handlers/pptx.py +0 -105
- edsl/scenarios/handlers/py.py +0 -294
- edsl/scenarios/handlers/sql.py +0 -313
- edsl/scenarios/handlers/sqlite.py +0 -149
- edsl/scenarios/handlers/txt.py +0 -33
- edsl/scenarios/scenario_selector.py +0 -156
- edsl/surveys/ConstructDAG.py +0 -92
- edsl/surveys/EditSurvey.py +0 -221
- edsl/surveys/InstructionHandler.py +0 -100
- edsl/surveys/MemoryManagement.py +0 -72
- edsl/surveys/RuleManager.py +0 -172
- edsl/surveys/Simulator.py +0 -75
- edsl/surveys/SurveyToApp.py +0 -141
- edsl/utilities/PrettyList.py +0 -56
- edsl/utilities/is_notebook.py +0 -18
- edsl/utilities/is_valid_variable_name.py +0 -11
- edsl/utilities/remove_edsl_version.py +0 -24
- edsl-0.1.39.dist-info/RECORD +0 -358
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
@@ -2,44 +2,58 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
import asyncio
|
5
|
-
from typing import Any, Type, List, Generator, Optional, Union
|
5
|
+
from typing import Any, Type, List, Generator, Optional, Union
|
6
6
|
import copy
|
7
|
-
from dataclasses import dataclass
|
8
7
|
|
9
|
-
|
10
|
-
|
8
|
+
from tenacity import (
|
9
|
+
retry,
|
10
|
+
stop_after_attempt,
|
11
|
+
wait_exponential,
|
12
|
+
retry_if_exception_type,
|
13
|
+
RetryError,
|
14
|
+
)
|
15
|
+
|
16
|
+
from edsl import CONFIG
|
17
|
+
from edsl.surveys.base import EndOfSurvey
|
18
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
19
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
20
|
+
from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
21
|
+
|
22
|
+
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
23
|
+
from edsl.jobs.Answers import Answers
|
24
|
+
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
25
|
+
from edsl.jobs.tasks.TaskCreators import TaskCreators
|
11
26
|
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
12
|
-
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
13
27
|
from edsl.jobs.interviews.InterviewExceptionCollection import (
|
14
28
|
InterviewExceptionCollection,
|
15
29
|
)
|
16
|
-
|
30
|
+
|
31
|
+
# from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
32
|
+
|
33
|
+
from edsl.surveys.base import EndOfSurvey
|
17
34
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
18
|
-
from edsl.jobs.
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
from edsl
|
35
|
+
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
36
|
+
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
37
|
+
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
38
|
+
|
39
|
+
|
40
|
+
from edsl import Agent, Survey, Scenario, Cache
|
41
|
+
from edsl.language_models import LanguageModel
|
42
|
+
from edsl.questions import QuestionBase
|
43
|
+
from edsl.agents.InvigilatorBase import InvigilatorBase
|
44
|
+
|
45
|
+
from edsl.exceptions.language_models import LanguageModelNoResponseError
|
24
46
|
|
47
|
+
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
48
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
49
|
+
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
25
50
|
|
26
|
-
if TYPE_CHECKING:
|
27
|
-
from edsl.agents.Agent import Agent
|
28
|
-
from edsl.surveys.Survey import Survey
|
29
|
-
from edsl.scenarios.Scenario import Scenario
|
30
|
-
from edsl.data.Cache import Cache
|
31
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
32
|
-
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
33
|
-
from edsl.agents.InvigilatorBase import InvigilatorBase
|
34
|
-
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
35
51
|
|
52
|
+
from edsl import CONFIG
|
36
53
|
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
skip_retry: bool = (False,) # COULD BE SET WITH CONFIG
|
41
|
-
raise_validation_errors: bool = (True,)
|
42
|
-
stop_on_exception: bool = (False,)
|
54
|
+
EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
|
55
|
+
EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
|
56
|
+
EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
|
43
57
|
|
44
58
|
|
45
59
|
class Interview:
|
@@ -56,10 +70,11 @@ class Interview:
|
|
56
70
|
survey: Survey,
|
57
71
|
scenario: Scenario,
|
58
72
|
model: Type["LanguageModel"],
|
73
|
+
debug: Optional[bool] = False,
|
59
74
|
iteration: int = 0,
|
60
|
-
indices: dict = None, # explain?
|
61
75
|
cache: Optional["Cache"] = None,
|
62
|
-
|
76
|
+
sidecar_model: Optional["LanguageModel"] = None,
|
77
|
+
skip_retry: bool = False,
|
63
78
|
raise_validation_errors: bool = True,
|
64
79
|
):
|
65
80
|
"""Initialize the Interview instance.
|
@@ -68,12 +83,13 @@ class Interview:
|
|
68
83
|
:param survey: the survey being administered to the agent.
|
69
84
|
:param scenario: the scenario that populates the survey questions.
|
70
85
|
:param model: the language model used to answer the questions.
|
71
|
-
|
86
|
+
:param debug: if True, run without calls to the language model.
|
72
87
|
:param iteration: the iteration number of the interview.
|
73
88
|
:param cache: the cache used to store the answers.
|
89
|
+
:param sidecar_model: a sidecar model used to answer questions.
|
74
90
|
|
75
91
|
>>> i = Interview.example()
|
76
|
-
>>> i.
|
92
|
+
>>> i.task_creators
|
77
93
|
{}
|
78
94
|
|
79
95
|
>>> i.exceptions
|
@@ -88,27 +104,22 @@ class Interview:
|
|
88
104
|
|
89
105
|
"""
|
90
106
|
self.agent = agent
|
91
|
-
self.survey = copy.deepcopy(survey)
|
107
|
+
self.survey = copy.deepcopy(survey)
|
92
108
|
self.scenario = scenario
|
93
109
|
self.model = model
|
110
|
+
self.debug = debug
|
94
111
|
self.iteration = iteration
|
112
|
+
self.cache = cache
|
113
|
+
self.answers: dict[
|
114
|
+
str, str
|
115
|
+
] = Answers() # will get filled in as interview progresses
|
116
|
+
self.sidecar_model = sidecar_model
|
95
117
|
|
96
|
-
|
97
|
-
|
98
|
-
self.task_manager = InterviewTaskManager(
|
99
|
-
survey=self.survey,
|
100
|
-
iteration=iteration,
|
101
|
-
)
|
102
|
-
|
118
|
+
# Trackers
|
119
|
+
self.task_creators = TaskCreators() # tracks the task creators
|
103
120
|
self.exceptions = InterviewExceptionCollection()
|
104
121
|
|
105
|
-
self.
|
106
|
-
cache=cache,
|
107
|
-
skip_retry=skip_retry,
|
108
|
-
raise_validation_errors=raise_validation_errors,
|
109
|
-
)
|
110
|
-
|
111
|
-
self.cache = cache
|
122
|
+
self._task_status_log_dict = InterviewStatusLog()
|
112
123
|
self.skip_retry = skip_retry
|
113
124
|
self.raise_validation_errors = raise_validation_errors
|
114
125
|
|
@@ -120,9 +131,6 @@ class Interview:
|
|
120
131
|
|
121
132
|
self.failed_questions = []
|
122
133
|
|
123
|
-
self.indices = indices
|
124
|
-
self.initial_hash = hash(self)
|
125
|
-
|
126
134
|
@property
|
127
135
|
def has_exceptions(self) -> bool:
|
128
136
|
"""Return True if there are exceptions."""
|
@@ -134,26 +142,30 @@ class Interview:
|
|
134
142
|
|
135
143
|
The keys are the question names; the values are the lists of status log changes for each task.
|
136
144
|
"""
|
137
|
-
|
145
|
+
for task_creator in self.task_creators.values():
|
146
|
+
self._task_status_log_dict[
|
147
|
+
task_creator.question.question_name
|
148
|
+
] = task_creator.status_log
|
149
|
+
return self._task_status_log_dict
|
138
150
|
|
139
151
|
@property
|
140
152
|
def token_usage(self) -> InterviewTokenUsage:
|
141
153
|
"""Determine how many tokens were used for the interview."""
|
142
|
-
return self.
|
154
|
+
return self.task_creators.token_usage
|
143
155
|
|
144
156
|
@property
|
145
157
|
def interview_status(self) -> InterviewStatusDictionary:
|
146
158
|
"""Return a dictionary mapping task status codes to counts."""
|
147
|
-
|
148
|
-
return self.task_manager.interview_status
|
159
|
+
return self.task_creators.interview_status
|
149
160
|
|
161
|
+
# region: Serialization
|
150
162
|
def to_dict(self, include_exceptions=True, add_edsl_version=True) -> dict[str, Any]:
|
151
163
|
"""Return a dictionary representation of the Interview instance.
|
152
164
|
This is just for hashing purposes.
|
153
165
|
|
154
166
|
>>> i = Interview.example()
|
155
167
|
>>> hash(i)
|
156
|
-
|
168
|
+
1217840301076717434
|
157
169
|
"""
|
158
170
|
d = {
|
159
171
|
"agent": self.agent.to_dict(add_edsl_version=add_edsl_version),
|
@@ -165,34 +177,23 @@ class Interview:
|
|
165
177
|
}
|
166
178
|
if include_exceptions:
|
167
179
|
d["exceptions"] = self.exceptions.to_dict()
|
168
|
-
if hasattr(self, "indices"):
|
169
|
-
d["indices"] = self.indices
|
170
180
|
return d
|
171
181
|
|
172
182
|
@classmethod
|
173
183
|
def from_dict(cls, d: dict[str, Any]) -> "Interview":
|
174
184
|
"""Return an Interview instance from a dictionary."""
|
175
|
-
|
176
|
-
from edsl.agents.Agent import Agent
|
177
|
-
from edsl.surveys.Survey import Survey
|
178
|
-
from edsl.scenarios.Scenario import Scenario
|
179
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
180
|
-
|
181
185
|
agent = Agent.from_dict(d["agent"])
|
182
186
|
survey = Survey.from_dict(d["survey"])
|
183
187
|
scenario = Scenario.from_dict(d["scenario"])
|
184
188
|
model = LanguageModel.from_dict(d["model"])
|
185
189
|
iteration = d["iteration"]
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
if "indices" in d:
|
194
|
-
params["indices"] = d["indices"]
|
195
|
-
interview = cls(**params)
|
190
|
+
interview = cls(
|
191
|
+
agent=agent,
|
192
|
+
survey=survey,
|
193
|
+
scenario=scenario,
|
194
|
+
model=model,
|
195
|
+
iteration=iteration,
|
196
|
+
)
|
196
197
|
if "exceptions" in d:
|
197
198
|
exceptions = InterviewExceptionCollection.from_dict(d["exceptions"])
|
198
199
|
interview.exceptions = exceptions
|
@@ -210,13 +211,311 @@ class Interview:
|
|
210
211
|
"""
|
211
212
|
return hash(self) == hash(other)
|
212
213
|
|
214
|
+
# endregion
|
215
|
+
|
216
|
+
# region: Creating tasks
|
217
|
+
@property
|
218
|
+
def dag(self) -> "DAG":
|
219
|
+
"""Return the directed acyclic graph for the survey.
|
220
|
+
|
221
|
+
The DAG, or directed acyclic graph, is a dictionary that maps question names to their dependencies.
|
222
|
+
It is used to determine the order in which questions should be answered.
|
223
|
+
This reflects both agent 'memory' considerations and 'skip' logic.
|
224
|
+
The 'textify' parameter is set to True, so that the question names are returned as strings rather than integer indices.
|
225
|
+
|
226
|
+
>>> i = Interview.example()
|
227
|
+
>>> i.dag == {'q2': {'q0'}, 'q1': {'q0'}}
|
228
|
+
True
|
229
|
+
"""
|
230
|
+
return self.survey.dag(textify=True)
|
231
|
+
|
232
|
+
def _build_question_tasks(
|
233
|
+
self,
|
234
|
+
model_buckets: ModelBuckets,
|
235
|
+
) -> list[asyncio.Task]:
|
236
|
+
"""Create a task for each question, with dependencies on the questions that must be answered before this one can be answered.
|
237
|
+
|
238
|
+
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
239
|
+
:param model_buckets: the model buckets used to track and control usage rates.
|
240
|
+
"""
|
241
|
+
tasks = []
|
242
|
+
for question in self.survey.questions:
|
243
|
+
tasks_that_must_be_completed_before = list(
|
244
|
+
self._get_tasks_that_must_be_completed_before(
|
245
|
+
tasks=tasks, question=question
|
246
|
+
)
|
247
|
+
)
|
248
|
+
question_task = self._create_question_task(
|
249
|
+
question=question,
|
250
|
+
tasks_that_must_be_completed_before=tasks_that_must_be_completed_before,
|
251
|
+
model_buckets=model_buckets,
|
252
|
+
iteration=self.iteration,
|
253
|
+
)
|
254
|
+
tasks.append(question_task)
|
255
|
+
return tuple(tasks)
|
256
|
+
|
257
|
+
def _get_tasks_that_must_be_completed_before(
|
258
|
+
self, *, tasks: list[asyncio.Task], question: "QuestionBase"
|
259
|
+
) -> Generator[asyncio.Task, None, None]:
|
260
|
+
"""Return the tasks that must be completed before the given question can be answered.
|
261
|
+
|
262
|
+
:param tasks: a list of tasks that have been created so far.
|
263
|
+
:param question: the question for which we are determining dependencies.
|
264
|
+
|
265
|
+
If a question has no dependencies, this will be an empty list, [].
|
266
|
+
"""
|
267
|
+
parents_of_focal_question = self.dag.get(question.question_name, [])
|
268
|
+
for parent_question_name in parents_of_focal_question:
|
269
|
+
yield tasks[self.to_index[parent_question_name]]
|
270
|
+
|
271
|
+
def _create_question_task(
|
272
|
+
self,
|
273
|
+
*,
|
274
|
+
question: QuestionBase,
|
275
|
+
tasks_that_must_be_completed_before: list[asyncio.Task],
|
276
|
+
model_buckets: ModelBuckets,
|
277
|
+
iteration: int = 0,
|
278
|
+
) -> asyncio.Task:
|
279
|
+
"""Create a task that depends on the passed-in dependencies that are awaited before the task is run.
|
280
|
+
|
281
|
+
:param question: the question to be answered. This is the question we are creating a task for.
|
282
|
+
:param tasks_that_must_be_completed_before: the tasks that must be completed before the focal task is run.
|
283
|
+
:param model_buckets: the model buckets used to track and control usage rates.
|
284
|
+
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
285
|
+
:param iteration: the iteration number for the interview.
|
286
|
+
|
287
|
+
The task is created by a `QuestionTaskCreator`, which is responsible for creating the task and managing its dependencies.
|
288
|
+
It is passed a reference to the function that will be called to answer the question.
|
289
|
+
It is passed a list "tasks_that_must_be_completed_before" that are awaited before the task is run.
|
290
|
+
These are added as a dependency to the focal task.
|
291
|
+
"""
|
292
|
+
task_creator = QuestionTaskCreator(
|
293
|
+
question=question,
|
294
|
+
answer_question_func=self._answer_question_and_record_task,
|
295
|
+
token_estimator=self._get_estimated_request_tokens,
|
296
|
+
model_buckets=model_buckets,
|
297
|
+
iteration=iteration,
|
298
|
+
)
|
299
|
+
for task in tasks_that_must_be_completed_before:
|
300
|
+
task_creator.add_dependency(task)
|
301
|
+
|
302
|
+
self.task_creators.update(
|
303
|
+
{question.question_name: task_creator}
|
304
|
+
) # track this task creator
|
305
|
+
return task_creator.generate_task()
|
306
|
+
|
307
|
+
def _get_estimated_request_tokens(self, question) -> float:
|
308
|
+
"""Estimate the number of tokens that will be required to run the focal task."""
|
309
|
+
from edsl.scenarios.FileStore import FileStore
|
310
|
+
|
311
|
+
invigilator = self._get_invigilator(question=question)
|
312
|
+
# TODO: There should be a way to get a more accurate estimate.
|
313
|
+
combined_text = ""
|
314
|
+
file_tokens = 0
|
315
|
+
for prompt in invigilator.get_prompts().values():
|
316
|
+
if hasattr(prompt, "text"):
|
317
|
+
combined_text += prompt.text
|
318
|
+
elif isinstance(prompt, str):
|
319
|
+
combined_text += prompt
|
320
|
+
elif isinstance(prompt, list):
|
321
|
+
for file in prompt:
|
322
|
+
if isinstance(file, FileStore):
|
323
|
+
file_tokens += file.size * 0.25
|
324
|
+
else:
|
325
|
+
raise ValueError(f"Prompt is of type {type(prompt)}")
|
326
|
+
return len(combined_text) / 4.0 + file_tokens
|
327
|
+
|
328
|
+
async def _answer_question_and_record_task(
|
329
|
+
self,
|
330
|
+
*,
|
331
|
+
question: "QuestionBase",
|
332
|
+
task=None,
|
333
|
+
) -> "AgentResponseDict":
|
334
|
+
"""Answer a question and records the task."""
|
335
|
+
|
336
|
+
had_language_model_no_response_error = False
|
337
|
+
|
338
|
+
@retry(
|
339
|
+
stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
|
340
|
+
wait=wait_exponential(
|
341
|
+
multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
|
342
|
+
),
|
343
|
+
retry=retry_if_exception_type(LanguageModelNoResponseError),
|
344
|
+
reraise=True,
|
345
|
+
)
|
346
|
+
async def attempt_answer():
|
347
|
+
nonlocal had_language_model_no_response_error
|
348
|
+
|
349
|
+
invigilator = self._get_invigilator(question)
|
350
|
+
|
351
|
+
if self._skip_this_question(question):
|
352
|
+
return invigilator.get_failed_task_result(
|
353
|
+
failure_reason="Question skipped."
|
354
|
+
)
|
355
|
+
|
356
|
+
try:
|
357
|
+
response: EDSLResultObjectInput = (
|
358
|
+
await invigilator.async_answer_question()
|
359
|
+
)
|
360
|
+
if response.validated:
|
361
|
+
self.answers.add_answer(response=response, question=question)
|
362
|
+
self._cancel_skipped_questions(question)
|
363
|
+
else:
|
364
|
+
# When a question is not validated, it is not added to the answers.
|
365
|
+
# this should also cancel and dependent children questions.
|
366
|
+
# Is that happening now?
|
367
|
+
if (
|
368
|
+
hasattr(response, "exception_occurred")
|
369
|
+
and response.exception_occurred
|
370
|
+
):
|
371
|
+
raise response.exception_occurred
|
372
|
+
|
373
|
+
except QuestionAnswerValidationError as e:
|
374
|
+
self._handle_exception(e, invigilator, task)
|
375
|
+
return invigilator.get_failed_task_result(
|
376
|
+
failure_reason="Question answer validation failed."
|
377
|
+
)
|
378
|
+
|
379
|
+
except asyncio.TimeoutError as e:
|
380
|
+
self._handle_exception(e, invigilator, task)
|
381
|
+
had_language_model_no_response_error = True
|
382
|
+
raise LanguageModelNoResponseError(
|
383
|
+
f"Language model timed out for question '{question.question_name}.'"
|
384
|
+
)
|
385
|
+
|
386
|
+
except Exception as e:
|
387
|
+
self._handle_exception(e, invigilator, task)
|
388
|
+
|
389
|
+
if "response" not in locals():
|
390
|
+
had_language_model_no_response_error = True
|
391
|
+
raise LanguageModelNoResponseError(
|
392
|
+
f"Language model did not return a response for question '{question.question_name}.'"
|
393
|
+
)
|
394
|
+
|
395
|
+
# if it gets here, it means the no response error was fixed
|
396
|
+
if (
|
397
|
+
question.question_name in self.exceptions
|
398
|
+
and had_language_model_no_response_error
|
399
|
+
):
|
400
|
+
self.exceptions.record_fixed_question(question.question_name)
|
401
|
+
|
402
|
+
return response
|
403
|
+
|
404
|
+
try:
|
405
|
+
return await attempt_answer()
|
406
|
+
except RetryError as retry_error:
|
407
|
+
# All retries have failed for LanguageModelNoResponseError
|
408
|
+
original_error = retry_error.last_attempt.exception()
|
409
|
+
self._handle_exception(
|
410
|
+
original_error, self._get_invigilator(question), task
|
411
|
+
)
|
412
|
+
raise original_error # Re-raise the original error after handling
|
413
|
+
|
414
|
+
def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
|
415
|
+
"""Return an invigilator for the given question.
|
416
|
+
|
417
|
+
:param question: the question to be answered
|
418
|
+
:param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
|
419
|
+
"""
|
420
|
+
invigilator = self.agent.create_invigilator(
|
421
|
+
question=question,
|
422
|
+
scenario=self.scenario,
|
423
|
+
model=self.model,
|
424
|
+
debug=False,
|
425
|
+
survey=self.survey,
|
426
|
+
memory_plan=self.survey.memory_plan,
|
427
|
+
current_answers=self.answers,
|
428
|
+
iteration=self.iteration,
|
429
|
+
cache=self.cache,
|
430
|
+
sidecar_model=self.sidecar_model,
|
431
|
+
raise_validation_errors=self.raise_validation_errors,
|
432
|
+
)
|
433
|
+
"""Return an invigilator for the given question."""
|
434
|
+
return invigilator
|
435
|
+
|
436
|
+
def _skip_this_question(self, current_question: "QuestionBase") -> bool:
|
437
|
+
"""Determine if the current question should be skipped.
|
438
|
+
|
439
|
+
:param current_question: the question to be answered.
|
440
|
+
"""
|
441
|
+
current_question_index = self.to_index[current_question.question_name]
|
442
|
+
|
443
|
+
answers = self.answers | self.scenario | self.agent["traits"]
|
444
|
+
skip = self.survey.rule_collection.skip_question_before_running(
|
445
|
+
current_question_index, answers
|
446
|
+
)
|
447
|
+
return skip
|
448
|
+
|
449
|
+
def _handle_exception(
|
450
|
+
self, e: Exception, invigilator: "InvigilatorBase", task=None
|
451
|
+
):
|
452
|
+
import copy
|
453
|
+
|
454
|
+
# breakpoint()
|
455
|
+
|
456
|
+
answers = copy.copy(self.answers)
|
457
|
+
exception_entry = InterviewExceptionEntry(
|
458
|
+
exception=e,
|
459
|
+
invigilator=invigilator,
|
460
|
+
answers=answers,
|
461
|
+
)
|
462
|
+
if task:
|
463
|
+
task.task_status = TaskStatus.FAILED
|
464
|
+
self.exceptions.add(invigilator.question.question_name, exception_entry)
|
465
|
+
|
466
|
+
if self.raise_validation_errors:
|
467
|
+
if isinstance(e, QuestionAnswerValidationError):
|
468
|
+
raise e
|
469
|
+
|
470
|
+
if hasattr(self, "stop_on_exception"):
|
471
|
+
stop_on_exception = self.stop_on_exception
|
472
|
+
else:
|
473
|
+
stop_on_exception = False
|
474
|
+
|
475
|
+
if stop_on_exception:
|
476
|
+
raise e
|
477
|
+
|
478
|
+
def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
|
479
|
+
"""Cancel the tasks for questions that are skipped.
|
480
|
+
|
481
|
+
:param current_question: the question that was just answered.
|
482
|
+
|
483
|
+
It first determines the next question, given the current question and the current answers.
|
484
|
+
If the next question is the end of the survey, it cancels all remaining tasks.
|
485
|
+
If the next question is after the current question, it cancels all tasks between the current question and the next question.
|
486
|
+
"""
|
487
|
+
current_question_index: int = self.to_index[current_question.question_name]
|
488
|
+
|
489
|
+
next_question: Union[
|
490
|
+
int, EndOfSurvey
|
491
|
+
] = self.survey.rule_collection.next_question(
|
492
|
+
q_now=current_question_index,
|
493
|
+
answers=self.answers | self.scenario | self.agent["traits"],
|
494
|
+
)
|
495
|
+
|
496
|
+
next_question_index = next_question.next_q
|
497
|
+
|
498
|
+
def cancel_between(start, end):
|
499
|
+
"""Cancel the tasks between the start and end indices."""
|
500
|
+
for i in range(start, end):
|
501
|
+
self.tasks[i].cancel()
|
502
|
+
|
503
|
+
if next_question_index == EndOfSurvey:
|
504
|
+
cancel_between(current_question_index + 1, len(self.survey.questions))
|
505
|
+
return
|
506
|
+
|
507
|
+
if next_question_index > (current_question_index + 1):
|
508
|
+
cancel_between(current_question_index + 1, next_question_index)
|
509
|
+
|
510
|
+
# endregion
|
511
|
+
|
512
|
+
# region: Conducting the interview
|
213
513
|
async def async_conduct_interview(
|
214
514
|
self,
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
# key_lookup: Optional[KeyLookup] = None,
|
515
|
+
model_buckets: Optional[ModelBuckets] = None,
|
516
|
+
stop_on_exception: bool = False,
|
517
|
+
sidecar_model: Optional["LanguageModel"] = None,
|
518
|
+
raise_validation_errors: bool = True,
|
220
519
|
) -> tuple["Answers", List[dict[str, Any]]]:
|
221
520
|
"""
|
222
521
|
Conduct an Interview asynchronously.
|
@@ -225,6 +524,7 @@ class Interview:
|
|
225
524
|
:param model_buckets: a dictionary of token buckets for the model.
|
226
525
|
:param debug: run without calls to LLM.
|
227
526
|
:param stop_on_exception: if True, stops the interview if an exception is raised.
|
527
|
+
:param sidecar_model: a sidecar model used to answer questions.
|
228
528
|
|
229
529
|
Example usage:
|
230
530
|
|
@@ -238,68 +538,37 @@ class Interview:
|
|
238
538
|
>>> i.exceptions
|
239
539
|
{'q0': ...
|
240
540
|
>>> i = Interview.example()
|
241
|
-
>>>
|
242
|
-
>>> run_config = RunConfig(parameters = RunParameters(), environment = RunEnvironment())
|
243
|
-
>>> run_config.parameters.stop_on_exception = True
|
244
|
-
>>> result, _ = asyncio.run(i.async_conduct_interview(run_config))
|
541
|
+
>>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
|
245
542
|
Traceback (most recent call last):
|
246
543
|
...
|
247
544
|
asyncio.exceptions.CancelledError
|
248
545
|
"""
|
249
|
-
|
250
|
-
|
251
|
-
if run_config is None:
|
252
|
-
run_config = RunConfig(
|
253
|
-
parameters=RunParameters(),
|
254
|
-
environment=RunEnvironment(),
|
255
|
-
)
|
256
|
-
self.stop_on_exception = run_config.parameters.stop_on_exception
|
546
|
+
self.sidecar_model = sidecar_model
|
547
|
+
self.stop_on_exception = stop_on_exception
|
257
548
|
|
258
549
|
# if no model bucket is passed, create an 'infinity' bucket with no rate limits
|
259
|
-
bucket_collection = run_config.environment.bucket_collection
|
260
|
-
|
261
|
-
if bucket_collection:
|
262
|
-
model_buckets = bucket_collection.get(self.model)
|
263
|
-
else:
|
264
|
-
model_buckets = None
|
265
|
-
|
266
550
|
if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
|
267
551
|
model_buckets = ModelBuckets.infinity_bucket()
|
268
552
|
|
269
|
-
# was "self.tasks" - is that necessary?
|
270
|
-
self.tasks = self.task_manager.build_question_tasks(
|
271
|
-
answer_func=AnswerQuestionFunctionConstructor(
|
272
|
-
self, key_lookup=run_config.environment.key_lookup
|
273
|
-
)(),
|
274
|
-
token_estimator=RequestTokenEstimator(self),
|
275
|
-
model_buckets=model_buckets,
|
276
|
-
)
|
277
|
-
|
278
553
|
## This is the key part---it creates a task for each question,
|
279
554
|
## with dependencies on the questions that must be answered before this one can be answered.
|
555
|
+
self.tasks = self._build_question_tasks(model_buckets=model_buckets)
|
280
556
|
|
281
|
-
## 'Invigilators' are used to administer the survey
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
key_lookup=run_config.environment.key_lookup,
|
286
|
-
)
|
287
|
-
self.invigilators = [fetcher(question) for question in self.survey.questions]
|
557
|
+
## 'Invigilators' are used to administer the survey
|
558
|
+
self.invigilators = [
|
559
|
+
self._get_invigilator(question) for question in self.survey.questions
|
560
|
+
]
|
288
561
|
await asyncio.gather(
|
289
|
-
*self.tasks, return_exceptions=not
|
290
|
-
)
|
562
|
+
*self.tasks, return_exceptions=not stop_on_exception
|
563
|
+
) # not stop_on_exception)
|
291
564
|
self.answers.replace_missing_answers_with_none(self.survey)
|
292
|
-
valid_results = list(
|
293
|
-
self._extract_valid_results(self.tasks, self.invigilators, self.exceptions)
|
294
|
-
)
|
565
|
+
valid_results = list(self._extract_valid_results())
|
295
566
|
return self.answers, valid_results
|
296
567
|
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
exceptions: InterviewExceptionCollection,
|
302
|
-
) -> Generator["Answers", None, None]:
|
568
|
+
# endregion
|
569
|
+
|
570
|
+
# region: Extracting results and recording errors
|
571
|
+
def _extract_valid_results(self) -> Generator["Answers", None, None]:
|
303
572
|
"""Extract the valid results from the list of results.
|
304
573
|
|
305
574
|
It iterates through the tasks and invigilators, and yields the results of the tasks that are done.
|
@@ -308,10 +577,16 @@ class Interview:
|
|
308
577
|
|
309
578
|
>>> i = Interview.example()
|
310
579
|
>>> result, _ = asyncio.run(i.async_conduct_interview())
|
580
|
+
>>> results = list(i._extract_valid_results())
|
581
|
+
>>> len(results) == len(i.survey)
|
582
|
+
True
|
311
583
|
"""
|
312
|
-
assert len(tasks) == len(invigilators)
|
584
|
+
assert len(self.tasks) == len(self.invigilators)
|
585
|
+
|
586
|
+
for task, invigilator in zip(self.tasks, self.invigilators):
|
587
|
+
if not task.done():
|
588
|
+
raise ValueError(f"Task {task.get_name()} is not done.")
|
313
589
|
|
314
|
-
def handle_task(task, invigilator):
|
315
590
|
try:
|
316
591
|
result = task.result()
|
317
592
|
except asyncio.CancelledError as e: # task was cancelled
|
@@ -326,22 +601,18 @@ class Interview:
|
|
326
601
|
exception=e,
|
327
602
|
invigilator=invigilator,
|
328
603
|
)
|
329
|
-
exceptions.add(task.get_name(), exception_entry)
|
330
|
-
return result
|
604
|
+
self.exceptions.add(task.get_name(), exception_entry)
|
331
605
|
|
332
|
-
|
333
|
-
if not task.done():
|
334
|
-
raise ValueError(f"Task {task.get_name()} is not done.")
|
606
|
+
yield result
|
335
607
|
|
336
|
-
|
608
|
+
# endregion
|
337
609
|
|
610
|
+
# region: Magic methods
|
338
611
|
def __repr__(self) -> str:
|
339
612
|
"""Return a string representation of the Interview instance."""
|
340
613
|
return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
|
341
614
|
|
342
|
-
def duplicate(
|
343
|
-
self, iteration: int, cache: "Cache", randomize_survey: Optional[bool] = True
|
344
|
-
) -> Interview:
|
615
|
+
def duplicate(self, iteration: int, cache: "Cache") -> Interview:
|
345
616
|
"""Duplicate the interview, but with a new iteration number and cache.
|
346
617
|
|
347
618
|
>>> i = Interview.example()
|
@@ -350,20 +621,14 @@ class Interview:
|
|
350
621
|
True
|
351
622
|
|
352
623
|
"""
|
353
|
-
if randomize_survey:
|
354
|
-
new_survey = self.survey.draw()
|
355
|
-
else:
|
356
|
-
new_survey = self.survey
|
357
|
-
|
358
624
|
return Interview(
|
359
625
|
agent=self.agent,
|
360
|
-
survey=
|
626
|
+
survey=self.survey,
|
361
627
|
scenario=self.scenario,
|
362
628
|
model=self.model,
|
363
629
|
iteration=iteration,
|
364
|
-
cache=
|
365
|
-
skip_retry=self.
|
366
|
-
indices=self.indices,
|
630
|
+
cache=cache,
|
631
|
+
skip_retry=self.skip_retry,
|
367
632
|
)
|
368
633
|
|
369
634
|
@classmethod
|