edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +197 -116
- edsl/__init__.py +15 -7
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +351 -147
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +101 -50
- edsl/agents/InvigilatorBase.py +62 -70
- edsl/agents/PromptConstructor.py +143 -225
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +125 -47
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +45 -27
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +31 -12
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +27 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -19
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +7 -12
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +120 -79
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +47 -35
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -10
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +134 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +356 -431
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +6 -4
- edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +143 -408
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
- edsl/jobs/runners/JobsRunnerStatus.py +133 -165
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +38 -18
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +194 -236
- edsl/language_models/ModelList.py +28 -19
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +1 -2
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +68 -214
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +2 -3
- edsl/questions/QuestionList.py +10 -18
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +67 -23
- edsl/questions/QuestionNumerical.py +2 -4
- edsl/questions/QuestionRank.py +7 -17
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +2 -1
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
- edsl/questions/data_structures.py +20 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
- edsl/questions/question_registry.py +1 -1
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +168 -305
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +298 -206
- edsl/results/Results.py +149 -131
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/{Selector.py → results_selector.py} +23 -13
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +150 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioList.py +415 -244
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +270 -791
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
- edsl-0.1.39.dist-info/RECORD +358 -0
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/registry.py +0 -190
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.38.dev4.dist-info/RECORD +0 -277
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
@@ -3,31 +3,17 @@ import time
|
|
3
3
|
import asyncio
|
4
4
|
import threading
|
5
5
|
import warnings
|
6
|
-
from typing import
|
7
|
-
from uuid import UUID
|
8
|
-
from collections import UserList
|
6
|
+
from typing import TYPE_CHECKING
|
9
7
|
|
10
8
|
from edsl.results.Results import Results
|
11
|
-
from edsl.jobs.
|
12
|
-
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus, JobsRunnerStatusBase
|
13
|
-
|
9
|
+
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
|
14
10
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
15
|
-
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
16
11
|
from edsl.utilities.decorators import jupyter_nb_handler
|
17
|
-
from edsl.
|
18
|
-
from edsl.
|
19
|
-
from edsl.results.Results import Results
|
20
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
21
|
-
from edsl.data.Cache import Cache
|
22
|
-
|
12
|
+
from edsl.jobs.async_interview_runner import AsyncInterviewRunner
|
13
|
+
from edsl.jobs.data_structures import RunEnvironment, RunParameters, RunConfig
|
23
14
|
|
24
|
-
|
25
|
-
|
26
|
-
self.total_tasks = total_tasks
|
27
|
-
super().__init__()
|
28
|
-
|
29
|
-
def current_status(self):
|
30
|
-
return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from edsl.jobs.Jobs import Jobs
|
31
17
|
|
32
18
|
|
33
19
|
class JobsRunnerAsyncio:
|
@@ -37,430 +23,129 @@ class JobsRunnerAsyncio:
|
|
37
23
|
The Jobs object is a collection of interviews that are to be run.
|
38
24
|
"""
|
39
25
|
|
40
|
-
|
41
|
-
|
42
|
-
def __init__(self, jobs: "Jobs"):
|
26
|
+
def __init__(self, jobs: "Jobs", environment: RunEnvironment):
|
43
27
|
self.jobs = jobs
|
44
|
-
self.
|
45
|
-
self.bucket_collection: "BucketCollection" = jobs.bucket_collection
|
46
|
-
self.total_interviews: List["Interview"] = []
|
47
|
-
self._initialized = threading.Event()
|
48
|
-
|
49
|
-
from edsl.config import CONFIG
|
50
|
-
|
51
|
-
self.MAX_CONCURRENT = int(CONFIG.get("EDSL_MAX_CONCURRENT_TASKS"))
|
52
|
-
# print(f"MAX_CONCURRENT: {self.MAX_CONCURRENT}")
|
53
|
-
|
54
|
-
# async def run_async_generator(
|
55
|
-
# self,
|
56
|
-
# cache: Cache,
|
57
|
-
# n: int = 1,
|
58
|
-
# stop_on_exception: bool = False,
|
59
|
-
# sidecar_model: Optional[LanguageModel] = None,
|
60
|
-
# total_interviews: Optional[List["Interview"]] = None,
|
61
|
-
# raise_validation_errors: bool = False,
|
62
|
-
# ) -> AsyncGenerator["Result", None]:
|
63
|
-
# """Creates the tasks, runs them asynchronously, and returns the results as a Results object.
|
64
|
-
|
65
|
-
# Completed tasks are yielded as they are completed.
|
66
|
-
|
67
|
-
# :param n: how many times to run each interview
|
68
|
-
# :param stop_on_exception: Whether to stop the interview if an exception is raised
|
69
|
-
# :param sidecar_model: a language model to use in addition to the interview's model
|
70
|
-
# :param total_interviews: A list of interviews to run can be provided instead.
|
71
|
-
# :param raise_validation_errors: Whether to raise validation errors
|
72
|
-
# """
|
73
|
-
# tasks = []
|
74
|
-
# if total_interviews: # was already passed in total interviews
|
75
|
-
# self.total_interviews = total_interviews
|
76
|
-
# else:
|
77
|
-
# self.total_interviews = list(
|
78
|
-
# self._populate_total_interviews(n=n)
|
79
|
-
# ) # Populate self.total_interviews before creating tasks
|
80
|
-
# self._initialized.set() # Signal that we're ready
|
81
|
-
|
82
|
-
# for interview in self.total_interviews:
|
83
|
-
# interviewing_task = self._build_interview_task(
|
84
|
-
# interview=interview,
|
85
|
-
# stop_on_exception=stop_on_exception,
|
86
|
-
# sidecar_model=sidecar_model,
|
87
|
-
# raise_validation_errors=raise_validation_errors,
|
88
|
-
# )
|
89
|
-
# tasks.append(asyncio.create_task(interviewing_task))
|
90
|
-
|
91
|
-
# for task in asyncio.as_completed(tasks):
|
92
|
-
# result = await task
|
93
|
-
# self.jobs_runner_status.add_completed_interview(result)
|
94
|
-
# yield result
|
95
|
-
|
96
|
-
async def run_async_generator(
|
97
|
-
self,
|
98
|
-
cache: Cache,
|
99
|
-
n: int = 1,
|
100
|
-
stop_on_exception: bool = False,
|
101
|
-
sidecar_model: Optional[LanguageModel] = None,
|
102
|
-
total_interviews: Optional[List["Interview"]] = None,
|
103
|
-
raise_validation_errors: bool = False,
|
104
|
-
) -> AsyncGenerator["Result", None]:
|
105
|
-
"""Creates and processes tasks asynchronously, yielding results as they complete.
|
106
|
-
|
107
|
-
Tasks are created and processed in a streaming fashion rather than building the full list upfront.
|
108
|
-
Results are yielded as soon as they are available.
|
109
|
-
|
110
|
-
:param n: how many times to run each interview
|
111
|
-
:param stop_on_exception: Whether to stop the interview if an exception is raised
|
112
|
-
:param sidecar_model: a language model to use in addition to the interview's model
|
113
|
-
:param total_interviews: A list of interviews to run can be provided instead.
|
114
|
-
:param raise_validation_errors: Whether to raise validation errors
|
115
|
-
"""
|
116
|
-
# Initialize interviews iterator
|
117
|
-
if total_interviews:
|
118
|
-
interviews_iter = iter(total_interviews)
|
119
|
-
self.total_interviews = total_interviews
|
120
|
-
else:
|
121
|
-
interviews_iter = self._populate_total_interviews(n=n)
|
122
|
-
self.total_interviews = list(interviews_iter)
|
123
|
-
interviews_iter = iter(self.total_interviews) # Create fresh iterator
|
124
|
-
|
125
|
-
self._initialized.set() # Signal that we're ready
|
126
|
-
|
127
|
-
# Keep track of active tasks
|
128
|
-
active_tasks = set()
|
129
|
-
|
130
|
-
try:
|
131
|
-
while True:
|
132
|
-
# Add new tasks if we're below max_concurrent and there are more interviews
|
133
|
-
while len(active_tasks) < self.MAX_CONCURRENT:
|
134
|
-
try:
|
135
|
-
interview = next(interviews_iter)
|
136
|
-
task = asyncio.create_task(
|
137
|
-
self._build_interview_task(
|
138
|
-
interview=interview,
|
139
|
-
stop_on_exception=stop_on_exception,
|
140
|
-
sidecar_model=sidecar_model,
|
141
|
-
raise_validation_errors=raise_validation_errors,
|
142
|
-
)
|
143
|
-
)
|
144
|
-
active_tasks.add(task)
|
145
|
-
# Add callback to remove task from set when done
|
146
|
-
task.add_done_callback(active_tasks.discard)
|
147
|
-
except StopIteration:
|
148
|
-
break
|
149
|
-
|
150
|
-
if not active_tasks:
|
151
|
-
break
|
28
|
+
self.environment = environment
|
152
29
|
|
153
|
-
|
154
|
-
|
155
|
-
active_tasks, return_when=asyncio.FIRST_COMPLETED
|
156
|
-
)
|
30
|
+
def __len__(self):
|
31
|
+
return len(self.jobs)
|
157
32
|
|
158
|
-
|
159
|
-
|
160
|
-
try:
|
161
|
-
result = await task
|
162
|
-
self.jobs_runner_status.add_completed_interview(result)
|
163
|
-
yield result
|
164
|
-
except Exception as e:
|
165
|
-
if stop_on_exception:
|
166
|
-
# Cancel remaining tasks
|
167
|
-
for t in active_tasks:
|
168
|
-
if not t.done():
|
169
|
-
t.cancel()
|
170
|
-
raise
|
171
|
-
else:
|
172
|
-
# Log error and continue
|
173
|
-
# logger.error(f"Task failed with error: {e}")
|
174
|
-
continue
|
175
|
-
finally:
|
176
|
-
# Ensure we cancel any remaining tasks if we exit early
|
177
|
-
for task in active_tasks:
|
178
|
-
if not task.done():
|
179
|
-
task.cancel()
|
33
|
+
async def run_async(self, parameters: RunParameters) -> Results:
|
34
|
+
"""Used for some other modules that have a non-standard way of running interviews."""
|
180
35
|
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
"""Populates self.total_interviews with n copies of each interview.
|
36
|
+
self.environment.jobs_runner_status = JobsRunnerStatus(self, n=parameters.n)
|
37
|
+
data = []
|
38
|
+
task_history = TaskHistory(include_traceback=False)
|
185
39
|
|
186
|
-
|
187
|
-
|
188
|
-
for interview in self.interviews:
|
189
|
-
for iteration in range(n):
|
190
|
-
if iteration > 0:
|
191
|
-
yield interview.duplicate(iteration=iteration, cache=self.cache)
|
192
|
-
else:
|
193
|
-
interview.cache = self.cache
|
194
|
-
yield interview
|
40
|
+
run_config = RunConfig(parameters=parameters, environment=self.environment)
|
41
|
+
result_generator = AsyncInterviewRunner(self.jobs, run_config)
|
195
42
|
|
196
|
-
|
197
|
-
"""Used for some other modules that have a non-standard way of running interviews."""
|
198
|
-
self.jobs_runner_status = JobsRunnerStatus(self, n=n)
|
199
|
-
self.cache = Cache() if cache is None else cache
|
200
|
-
data = []
|
201
|
-
async for result in self.run_async_generator(cache=self.cache, n=n):
|
43
|
+
async for result, interview in result_generator.run():
|
202
44
|
data.append(result)
|
203
|
-
|
45
|
+
task_history.add_interview(interview)
|
46
|
+
|
47
|
+
return Results(survey=self.jobs.survey, task_history=task_history, data=data)
|
204
48
|
|
205
49
|
def simple_run(self):
|
206
50
|
data = asyncio.run(self.run_async())
|
207
51
|
return Results(survey=self.jobs.survey, data=data)
|
208
52
|
|
209
|
-
async def _build_interview_task(
|
210
|
-
self,
|
211
|
-
*,
|
212
|
-
interview: Interview,
|
213
|
-
stop_on_exception: bool = False,
|
214
|
-
sidecar_model: Optional["LanguageModel"] = None,
|
215
|
-
raise_validation_errors: bool = False,
|
216
|
-
) -> "Result":
|
217
|
-
"""Conducts an interview and returns the result.
|
218
|
-
|
219
|
-
:param interview: the interview to conduct
|
220
|
-
:param stop_on_exception: stops the interview if an exception is raised
|
221
|
-
:param sidecar_model: a language model to use in addition to the interview's model
|
222
|
-
"""
|
223
|
-
# the model buckets are used to track usage rates
|
224
|
-
model_buckets = self.bucket_collection[interview.model]
|
225
|
-
|
226
|
-
# get the results of the interview
|
227
|
-
answer, valid_results = await interview.async_conduct_interview(
|
228
|
-
model_buckets=model_buckets,
|
229
|
-
stop_on_exception=stop_on_exception,
|
230
|
-
sidecar_model=sidecar_model,
|
231
|
-
raise_validation_errors=raise_validation_errors,
|
232
|
-
)
|
233
|
-
|
234
|
-
question_results = {}
|
235
|
-
for result in valid_results:
|
236
|
-
question_results[result.question_name] = result
|
237
|
-
|
238
|
-
answer_key_names = list(question_results.keys())
|
239
|
-
|
240
|
-
generated_tokens_dict = {
|
241
|
-
k + "_generated_tokens": question_results[k].generated_tokens
|
242
|
-
for k in answer_key_names
|
243
|
-
}
|
244
|
-
comments_dict = {
|
245
|
-
k + "_comment": question_results[k].comment for k in answer_key_names
|
246
|
-
}
|
247
|
-
|
248
|
-
# we should have a valid result for each question
|
249
|
-
answer_dict = {k: answer[k] for k in answer_key_names}
|
250
|
-
assert len(valid_results) == len(answer_key_names)
|
251
|
-
|
252
|
-
# TODO: move this down into Interview
|
253
|
-
question_name_to_prompts = dict({})
|
254
|
-
for result in valid_results:
|
255
|
-
question_name = result.question_name
|
256
|
-
question_name_to_prompts[question_name] = {
|
257
|
-
"user_prompt": result.prompts["user_prompt"],
|
258
|
-
"system_prompt": result.prompts["system_prompt"],
|
259
|
-
}
|
260
|
-
|
261
|
-
prompt_dictionary = {}
|
262
|
-
for answer_key_name in answer_key_names:
|
263
|
-
prompt_dictionary[
|
264
|
-
answer_key_name + "_user_prompt"
|
265
|
-
] = question_name_to_prompts[answer_key_name]["user_prompt"]
|
266
|
-
prompt_dictionary[
|
267
|
-
answer_key_name + "_system_prompt"
|
268
|
-
] = question_name_to_prompts[answer_key_name]["system_prompt"]
|
269
|
-
|
270
|
-
raw_model_results_dictionary = {}
|
271
|
-
cache_used_dictionary = {}
|
272
|
-
for result in valid_results:
|
273
|
-
question_name = result.question_name
|
274
|
-
raw_model_results_dictionary[
|
275
|
-
question_name + "_raw_model_response"
|
276
|
-
] = result.raw_model_response
|
277
|
-
raw_model_results_dictionary[question_name + "_cost"] = result.cost
|
278
|
-
one_use_buys = (
|
279
|
-
"NA"
|
280
|
-
if isinstance(result.cost, str)
|
281
|
-
or result.cost == 0
|
282
|
-
or result.cost is None
|
283
|
-
else 1.0 / result.cost
|
284
|
-
)
|
285
|
-
raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
|
286
|
-
cache_used_dictionary[question_name] = result.cache_used
|
287
|
-
|
288
|
-
result = Result(
|
289
|
-
agent=interview.agent,
|
290
|
-
scenario=interview.scenario,
|
291
|
-
model=interview.model,
|
292
|
-
iteration=interview.iteration,
|
293
|
-
answer=answer_dict,
|
294
|
-
prompt=prompt_dictionary,
|
295
|
-
raw_model_response=raw_model_results_dictionary,
|
296
|
-
survey=interview.survey,
|
297
|
-
generated_tokens=generated_tokens_dict,
|
298
|
-
comments_dict=comments_dict,
|
299
|
-
cache_used_dict=cache_used_dictionary,
|
300
|
-
)
|
301
|
-
result.interview_hash = hash(interview)
|
302
|
-
|
303
|
-
return result
|
304
|
-
|
305
|
-
@property
|
306
|
-
def elapsed_time(self):
|
307
|
-
return time.monotonic() - self.start_time
|
308
|
-
|
309
|
-
def process_results(
|
310
|
-
self, raw_results: Results, cache: Cache, print_exceptions: bool
|
311
|
-
):
|
312
|
-
interview_lookup = {
|
313
|
-
hash(interview): index
|
314
|
-
for index, interview in enumerate(self.total_interviews)
|
315
|
-
}
|
316
|
-
interview_hashes = list(interview_lookup.keys())
|
317
|
-
|
318
|
-
task_history = TaskHistory(self.total_interviews, include_traceback=False)
|
319
|
-
|
320
|
-
results = Results(
|
321
|
-
survey=self.jobs.survey,
|
322
|
-
data=sorted(
|
323
|
-
raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
|
324
|
-
),
|
325
|
-
task_history=task_history,
|
326
|
-
cache=cache,
|
327
|
-
)
|
328
|
-
results.bucket_collection = self.bucket_collection
|
329
|
-
|
330
|
-
if results.has_unfixed_exceptions and print_exceptions:
|
331
|
-
from edsl.scenarios.FileStore import HTMLFileStore
|
332
|
-
from edsl.config import CONFIG
|
333
|
-
from edsl.coop.coop import Coop
|
334
|
-
|
335
|
-
msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
|
336
|
-
|
337
|
-
if len(results.task_history.indices) > 5:
|
338
|
-
msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
|
339
|
-
|
340
|
-
import sys
|
341
|
-
|
342
|
-
print(msg, file=sys.stderr)
|
343
|
-
from edsl.config import CONFIG
|
344
|
-
|
345
|
-
if CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL") == "True":
|
346
|
-
open_in_browser = True
|
347
|
-
elif CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL") == "False":
|
348
|
-
open_in_browser = False
|
349
|
-
else:
|
350
|
-
raise Exception(
|
351
|
-
"EDSL_OPEN_EXCEPTION_REPORT_URL", "must be either True or False"
|
352
|
-
)
|
353
|
-
|
354
|
-
# print("open_in_browser", open_in_browser)
|
355
|
-
|
356
|
-
filepath = results.task_history.html(
|
357
|
-
cta="Open report to see details.",
|
358
|
-
open_in_browser=open_in_browser,
|
359
|
-
return_link=True,
|
360
|
-
)
|
361
|
-
|
362
|
-
try:
|
363
|
-
coop = Coop()
|
364
|
-
user_edsl_settings = coop.edsl_settings
|
365
|
-
remote_logging = user_edsl_settings["remote_logging"]
|
366
|
-
except Exception as e:
|
367
|
-
print(e)
|
368
|
-
remote_logging = False
|
369
|
-
|
370
|
-
if remote_logging:
|
371
|
-
filestore = HTMLFileStore(filepath)
|
372
|
-
coop_details = filestore.push(description="Error report")
|
373
|
-
print(coop_details)
|
374
|
-
|
375
|
-
print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
|
376
|
-
|
377
|
-
return results
|
378
|
-
|
379
53
|
@jupyter_nb_handler
|
380
|
-
async def run(
|
381
|
-
self,
|
382
|
-
cache: Union[Cache, False, None],
|
383
|
-
n: int = 1,
|
384
|
-
stop_on_exception: bool = False,
|
385
|
-
progress_bar: bool = False,
|
386
|
-
sidecar_model: Optional[LanguageModel] = None,
|
387
|
-
jobs_runner_status: Optional[Type[JobsRunnerStatusBase]] = None,
|
388
|
-
job_uuid: Optional[UUID] = None,
|
389
|
-
print_exceptions: bool = True,
|
390
|
-
raise_validation_errors: bool = False,
|
391
|
-
) -> "Coroutine":
|
54
|
+
async def run(self, parameters: RunParameters) -> Results:
|
392
55
|
"""Runs a collection of interviews, handling both async and sync contexts."""
|
393
56
|
|
394
|
-
|
57
|
+
run_config = RunConfig(parameters=parameters, environment=self.environment)
|
58
|
+
|
395
59
|
self.start_time = time.monotonic()
|
396
60
|
self.completed = False
|
397
|
-
self.cache = cache
|
398
|
-
self.sidecar_model = sidecar_model
|
399
61
|
|
400
62
|
from edsl.coop import Coop
|
401
63
|
|
402
64
|
coop = Coop()
|
403
65
|
endpoint_url = coop.get_progress_bar_url()
|
404
66
|
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
67
|
+
def set_up_jobs_runner_status(jobs_runner_status):
|
68
|
+
if jobs_runner_status is not None:
|
69
|
+
return jobs_runner_status(
|
70
|
+
self,
|
71
|
+
n=parameters.n,
|
72
|
+
endpoint_url=endpoint_url,
|
73
|
+
job_uuid=parameters.job_uuid,
|
74
|
+
)
|
75
|
+
else:
|
76
|
+
return JobsRunnerStatus(
|
77
|
+
self,
|
78
|
+
n=parameters.n,
|
79
|
+
endpoint_url=endpoint_url,
|
80
|
+
job_uuid=parameters.job_uuid,
|
81
|
+
)
|
82
|
+
|
83
|
+
run_config.environment.jobs_runner_status = set_up_jobs_runner_status(
|
84
|
+
self.environment.jobs_runner_status
|
85
|
+
)
|
413
86
|
|
414
|
-
|
87
|
+
async def get_results(results) -> None:
|
88
|
+
"""Conducted the interviews and append to the results list."""
|
89
|
+
result_generator = AsyncInterviewRunner(self.jobs, run_config)
|
90
|
+
async for result, interview in result_generator.run():
|
91
|
+
results.append(result)
|
92
|
+
results.task_history.add_interview(interview)
|
415
93
|
|
416
|
-
async def process_results(cache):
|
417
|
-
"""Processes results from interviews."""
|
418
|
-
async for result in self.run_async_generator(
|
419
|
-
n=n,
|
420
|
-
stop_on_exception=stop_on_exception,
|
421
|
-
cache=cache,
|
422
|
-
sidecar_model=sidecar_model,
|
423
|
-
raise_validation_errors=raise_validation_errors,
|
424
|
-
):
|
425
|
-
self.results.append(result)
|
426
94
|
self.completed = True
|
427
95
|
|
428
|
-
def run_progress_bar(stop_event):
|
96
|
+
def run_progress_bar(stop_event, jobs_runner_status) -> None:
|
429
97
|
"""Runs the progress bar in a separate thread."""
|
430
|
-
|
98
|
+
jobs_runner_status.update_progress(stop_event)
|
99
|
+
|
100
|
+
def set_up_progress_bar(progress_bar: bool, jobs_runner_status):
|
101
|
+
progress_thread = None
|
102
|
+
if progress_bar and jobs_runner_status.has_ep_api_key():
|
103
|
+
jobs_runner_status.setup()
|
104
|
+
progress_thread = threading.Thread(
|
105
|
+
target=run_progress_bar, args=(stop_event, jobs_runner_status)
|
106
|
+
)
|
107
|
+
progress_thread.start()
|
108
|
+
elif progress_bar:
|
109
|
+
warnings.warn(
|
110
|
+
"You need an Expected Parrot API key to view job progress bars."
|
111
|
+
)
|
112
|
+
return progress_thread
|
431
113
|
|
432
|
-
|
433
|
-
self.
|
434
|
-
|
435
|
-
|
436
|
-
)
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
114
|
+
results = Results(
|
115
|
+
survey=self.jobs.survey,
|
116
|
+
data=[],
|
117
|
+
task_history=TaskHistory(),
|
118
|
+
cache=self.environment.cache.new_entries_cache(),
|
119
|
+
)
|
120
|
+
stop_event = threading.Event()
|
121
|
+
progress_thread = set_up_progress_bar(
|
122
|
+
parameters.progress_bar, run_config.environment.jobs_runner_status
|
123
|
+
)
|
442
124
|
|
443
125
|
exception_to_raise = None
|
444
126
|
try:
|
445
|
-
|
446
|
-
await process_results(cache=c)
|
127
|
+
await get_results(results)
|
447
128
|
except KeyboardInterrupt:
|
448
129
|
print("Keyboard interrupt received. Stopping gracefully...")
|
449
130
|
stop_event.set()
|
450
131
|
except Exception as e:
|
451
|
-
if stop_on_exception:
|
132
|
+
if parameters.stop_on_exception:
|
452
133
|
exception_to_raise = e
|
453
134
|
stop_event.set()
|
454
135
|
finally:
|
455
136
|
stop_event.set()
|
456
|
-
if
|
457
|
-
|
458
|
-
if progress_thread:
|
459
|
-
progress_thread.join()
|
137
|
+
if progress_thread is not None:
|
138
|
+
progress_thread.join()
|
460
139
|
|
461
140
|
if exception_to_raise:
|
462
141
|
raise exception_to_raise
|
463
142
|
|
464
|
-
|
465
|
-
|
466
|
-
|
143
|
+
results.cache = self.environment.cache.new_entries_cache()
|
144
|
+
results.bucket_collection = self.environment.bucket_collection
|
145
|
+
|
146
|
+
from edsl.jobs.results_exceptions_handler import ResultsExceptionsHandler
|
147
|
+
|
148
|
+
results_exceptions_handler = ResultsExceptionsHandler(results, parameters)
|
149
|
+
|
150
|
+
results_exceptions_handler.handle_exceptions()
|
151
|
+
return results
|