edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +28 -0
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -16
- edsl/agents/Invigilator.py +13 -14
- edsl/agents/InvigilatorBase.py +4 -1
- edsl/agents/PromptConstructor.py +42 -22
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/coop/coop.py +21 -5
- edsl/data/Cache.py +29 -18
- edsl/data/CacheHandler.py +0 -2
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/hack.py +10 -0
- edsl/enums.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -16
- edsl/inference_services/AvailableModelFetcher.py +7 -1
- edsl/inference_services/GoogleService.py +5 -1
- edsl/inference_services/InferenceServicesCollection.py +18 -2
- edsl/inference_services/OpenAIService.py +46 -31
- edsl/inference_services/TestService.py +1 -3
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +74 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
- edsl/jobs/FetchInvigilator.py +10 -3
- edsl/jobs/InterviewsConstructor.py +6 -4
- edsl/jobs/Jobs.py +299 -233
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/interviews/Interview.py +80 -42
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
- edsl/jobs/runners/JobsRunnerStatus.py +131 -164
- edsl/jobs/tasks/TaskHistory.py +24 -3
- edsl/language_models/LanguageModel.py +59 -4
- edsl/language_models/ModelList.py +19 -8
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +35 -26
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +9 -15
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/questions/data_structures.py +20 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
- edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
- edsl/results/DatasetExportMixin.py +60 -119
- edsl/results/Result.py +109 -3
- edsl/results/Results.py +50 -39
- edsl/results/file_exports.py +252 -0
- edsl/scenarios/ScenarioList.py +35 -7
- edsl/surveys/Survey.py +71 -20
- edsl/test_h +1 -0
- edsl/utilities/gcp_bucket/example.py +50 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
- edsl/language_models/registry.py +0 -180
- /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
- /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
- /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{Selector.py → results_selector.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
- /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
- /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/jobs/JobsPrompts.py
CHANGED
@@ -51,7 +51,7 @@ class JobsPrompts:
|
|
51
51
|
for interview_index, interview in enumerate(interviews):
|
52
52
|
invigilators = [
|
53
53
|
FetchInvigilator(interview)(question)
|
54
|
-
for question in
|
54
|
+
for question in interview.survey.questions
|
55
55
|
]
|
56
56
|
for _, invigilator in enumerate(invigilators):
|
57
57
|
prompts = invigilator.get_prompts()
|
edsl/jobs/JobsRemoteInferenceHandler.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1
|
-
from typing import Optional, Union, Literal, TYPE_CHECKING, NewType
|
1
|
+
from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
2
4
|
|
3
5
|
|
4
6
|
Seconds = NewType("Seconds", float)
|
@@ -16,26 +18,52 @@ from edsl.coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
|
|
16
18
|
|
17
19
|
from edsl.jobs.jobs_status_enums import JobsStatus
|
18
20
|
from edsl.coop.utils import VisibilityType
|
21
|
+
from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
|
22
|
+
|
23
|
+
|
24
|
+
class RemoteJobConstants:
    """Default settings and support links used when handling remote jobs."""

    # Default number of seconds to sleep between two job-status polls.
    REMOTE_JOB_POLL_INTERVAL = 1
    # Default verbosity handed to the remote-job logger.
    REMOTE_JOB_VERBOSE = False
    # Support link shown to the user when a remote job fails.
    DISCORD_URL = "https://discord.com/invite/mxAYkjfy9m"
|
30
|
+
|
31
|
+
|
32
|
+
@dataclass
class RemoteJobInfo:
    """State describing one remote inference job after it has been submitted.

    Returned by ``create_remote_inference_job`` and passed to the polling
    helpers of ``JobsRemoteInferenceHandler``.
    """

    # Creation response for the job; its ``"uuid"`` entry supplies ``job_uuid``.
    creation_data: RemoteInferenceCreationInfo
    # Server-side identifier used when polling the job's status.
    job_uuid: JobUUID
    # Logger that receives progress and status updates for this job.
    logger: JobLogger
|
19
37
|
|
20
38
|
|
21
39
|
class JobsRemoteInferenceHandler:
|
22
|
-
def __init__(
|
23
|
-
|
40
|
+
def __init__(
    self,
    jobs: "Jobs",
    verbose: bool = RemoteJobConstants.REMOTE_JOB_VERBOSE,
    poll_interval: Seconds = RemoteJobConstants.REMOTE_JOB_POLL_INTERVAL,
):
    """Handles the creation and running of a remote inference job.

    :param jobs: the ``Jobs`` object to run remotely.
    :param verbose: verbosity passed to the job logger.
    :param poll_interval: seconds to sleep between job-status polls.
    """
    self.jobs, self.verbose, self.poll_interval = jobs, verbose, poll_interval

    from edsl.config import CONFIG

    base_url = CONFIG.get("EXPECTED_PARROT_URL")
    self.expected_parrot_url = base_url
    # Coop page listing remote inference jobs; quoted to the user in log messages.
    self.remote_inference_url = f"{base_url}/home/remote-inference"
|
35
55
|
|
36
|
-
|
37
|
-
|
38
|
-
|
56
|
+
def _create_logger(self) -> JobLogger:
    """Create the logger that reports progress for a remote job.

    :return: an ``HTMLTableJobLogger`` when running inside a notebook,
        otherwise a ``StdOutJobLogger``; either one is constructed with
        ``verbose=self.verbose``.
    """
    from edsl.utilities.is_notebook import is_notebook
    from edsl.jobs.JobsRemoteInferenceLogger import StdOutJobLogger
    from edsl.jobs.loggers.HTMLTableJobLogger import HTMLTableJobLogger

    logger_class = HTMLTableJobLogger if is_notebook() else StdOutJobLogger
    return logger_class(verbose=self.verbose)
|
39
67
|
|
40
68
|
def use_remote_inference(self, disable_remote_inference: bool) -> bool:
|
41
69
|
import requests
|
@@ -60,23 +88,15 @@ class JobsRemoteInferenceHandler:
|
|
60
88
|
iterations: int = 1,
|
61
89
|
remote_inference_description: Optional[str] = None,
|
62
90
|
remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
|
63
|
-
) ->
|
91
|
+
) -> RemoteJobInfo:
|
92
|
+
|
64
93
|
from edsl.config import CONFIG
|
65
94
|
from edsl.coop.coop import Coop
|
66
95
|
|
67
|
-
|
68
|
-
from edsl.utilities.is_notebook import is_notebook
|
69
|
-
from edsl.jobs.JobsRemoteInferenceLogger import JupyterJobLogger
|
70
|
-
from edsl.jobs.JobsRemoteInferenceLogger import StdOutJobLogger
|
71
|
-
from edsl.jobs.loggers.HTMLTableJobLogger import HTMLTableJobLogger
|
72
|
-
|
73
|
-
if is_notebook():
|
74
|
-
self.logger = HTMLTableJobLogger(verbose=self.verbose)
|
75
|
-
else:
|
76
|
-
self.logger = StdOutJobLogger(verbose=self.verbose)
|
96
|
+
logger = self._create_logger()
|
77
97
|
|
78
98
|
coop = Coop()
|
79
|
-
|
99
|
+
logger.update(
|
80
100
|
"Remote inference activated. Sending job to server...",
|
81
101
|
status=JobsStatus.QUEUED,
|
82
102
|
)
|
@@ -87,33 +107,34 @@ class JobsRemoteInferenceHandler:
|
|
87
107
|
iterations=iterations,
|
88
108
|
initial_results_visibility=remote_inference_results_visibility,
|
89
109
|
)
|
90
|
-
|
110
|
+
logger.update(
|
91
111
|
"Your survey is running at the Expected Parrot server...",
|
92
112
|
status=JobsStatus.RUNNING,
|
93
113
|
)
|
94
|
-
|
95
114
|
job_uuid = remote_job_creation_data.get("uuid")
|
96
|
-
|
115
|
+
logger.update(
|
97
116
|
message=f"Job sent to server. (Job uuid={job_uuid}).",
|
98
117
|
status=JobsStatus.RUNNING,
|
99
118
|
)
|
100
|
-
|
119
|
+
logger.add_info("job_uuid", job_uuid)
|
101
120
|
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
self.logger.update(
|
106
|
-
f"Job details are available at your Coop account {remote_inference_url}{remote_inference_url}",
|
121
|
+
logger.update(
|
122
|
+
f"Job details are available at your Coop account {self.remote_inference_url}",
|
107
123
|
status=JobsStatus.RUNNING,
|
108
124
|
)
|
109
|
-
progress_bar_url =
|
110
|
-
|
111
|
-
|
125
|
+
progress_bar_url = (
|
126
|
+
f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
|
127
|
+
)
|
128
|
+
logger.add_info("progress_bar_url", progress_bar_url)
|
129
|
+
logger.update(
|
112
130
|
f"View job progress here: {progress_bar_url}", status=JobsStatus.RUNNING
|
113
131
|
)
|
114
132
|
|
115
|
-
|
116
|
-
|
133
|
+
return RemoteJobInfo(
|
134
|
+
creation_data=remote_job_creation_data,
|
135
|
+
job_uuid=job_uuid,
|
136
|
+
logger=logger,
|
137
|
+
)
|
117
138
|
|
118
139
|
@staticmethod
|
119
140
|
def check_status(
|
@@ -124,126 +145,127 @@ class JobsRemoteInferenceHandler:
|
|
124
145
|
coop = Coop()
|
125
146
|
return coop.remote_inference_get(job_uuid)
|
126
147
|
|
127
|
-
def
|
128
|
-
|
129
|
-
|
148
|
+
def _construct_remote_job_fetcher(
|
149
|
+
self, testing_simulated_response: Optional[Any] = None
|
150
|
+
) -> Callable:
|
151
|
+
if testing_simulated_response is not None:
|
152
|
+
return lambda job_uuid: testing_simulated_response
|
153
|
+
else:
|
154
|
+
from edsl.coop.coop import Coop
|
155
|
+
|
156
|
+
coop = Coop()
|
157
|
+
return coop.remote_inference_get
|
158
|
+
|
159
|
+
def _construct_object_fetcher(
    self, testing_simulated_response: Optional[Any] = None
) -> Callable:
    """Construct a function to fetch the results object from Coop.

    The returned callable is invoked as
    ``fetch(results_uuid, expected_object_type="results")``.

    :param testing_simulated_response: when not ``None``, return a stub that
        ignores its arguments and returns ``Results.example()`` instead of
        contacting Coop.
    """
    if testing_simulated_response is not None:
        # Bind ``Results`` explicitly: elsewhere in this module it only appears
        # in string annotations, so without this import the stub would raise
        # NameError when called. NOTE(review): harmless if a module-level
        # runtime import also exists.
        from edsl.results.Results import Results

        return lambda results_uuid, expected_object_type: Results.example()

    from edsl.coop.coop import Coop

    return Coop().get
|
170
|
+
|
171
|
+
def _handle_cancelled_job(self, job_info: RemoteJobInfo) -> None:
    """Report a user-cancelled job through the job's logger.

    Emits two updates with status ``CANCELLED``: the cancellation notice and
    a pointer to the remote-inference page.
    """
    log = job_info.logger
    cancelled = JobsStatus.CANCELLED
    log.update(message="Job cancelled by the user.", status=cancelled)
    log.update(
        f"See {self.expected_parrot_url}/home/remote-inference for more details.",
        status=cancelled,
    )
|
131
181
|
|
132
|
-
def
|
133
|
-
self,
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
182
|
+
def _handle_failed_job(
    self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
) -> None:
    """Report a failed job through the job's logger.

    Records the server's ``latest_error_report_url`` (when present) as logger
    info, then emits three ``FAILED`` updates: the failure notice, a pointer
    to the remote-inference page, and the Discord support link.
    """
    log = job_info.logger

    error_report = remote_job_data.get("latest_error_report_url")
    if error_report:
        log.add_info("error_report_url", error_report)

    messages = (
        "Job failed.",
        f"See {self.expected_parrot_url}/home/remote-inference for more details.",
        f"Need support? Visit Discord: {RemoteJobConstants.DISCORD_URL}",
    )
    for text in messages:
        log.update(text, status=JobsStatus.FAILED)
|
199
|
+
|
200
|
+
def _sleep_for_a_bit(self, job_info: RemoteJobInfo, status: str) -> None:
    """Log the most recently polled status, then wait ``self.poll_interval`` seconds.

    :param job_info: the job being polled; its logger receives a ``RUNNING`` update.
    :param status: the status value last reported for the job.
    """
    from datetime import datetime
    from time import sleep

    last_update = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
    job_info.logger.update(
        f"Job status: {status} - last update: {last_update}",
        status=JobsStatus.RUNNING,
    )
    sleep(self.poll_interval)
|
146
210
|
|
147
|
-
|
148
|
-
|
211
|
+
def _fetch_results_and_log(
    self,
    job_info: RemoteJobInfo,
    results_uuid: str,
    remote_job_data: RemoteInferenceResponse,
    object_fetcher: Callable,
) -> "Results":
    """Fetch the job's results object and log where it is stored.

    :param job_info: the finished job; its logger receives the results UUID
        as info and a ``COMPLETED`` update quoting the results URL.
    :param results_uuid: identifier of the stored results object.
    :param remote_job_data: latest status payload; supplies ``results_url``.
    :param object_fetcher: callable used to download the object.
    :return: the fetched results, tagged with ``job_uuid`` and ``results_uuid``.
    """
    log = job_info.logger
    log.add_info("results_uuid", results_uuid)

    fetched = object_fetcher(results_uuid, expected_object_type="results")
    log.update(
        f"Job completed and Results stored on Coop: {remote_job_data.get('results_url')}",
        status=JobsStatus.COMPLETED,
    )

    fetched.job_uuid = job_info.job_uuid
    fetched.results_uuid = results_uuid
    return fetched
|
149
229
|
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
from edsl.coop.coop import Coop
|
230
|
+
def poll_remote_inference_job(
    self,
    job_info: "RemoteJobInfo",
    testing_simulated_response=None,
) -> Union[None, "Results"]:
    """Polls a remote inference job for completion and returns the results.

    The job's status is fetched repeatedly until it is ``"cancelled"``,
    ``"failed"`` or ``"completed"``; for any other status the latest value is
    logged and the handler sleeps before polling again.

    :param job_info: the job to poll, as returned by job creation.
    :param testing_simulated_response: when not ``None``, used instead of
        contacting Coop for both the status payload and the results object.
    :return: the job's results when the server reports a ``results_uuid``,
        otherwise ``None`` (always ``None`` for cancelled jobs).
    """
    fetch_job_data = self._construct_remote_job_fetcher(testing_simulated_response)
    object_fetcher = self._construct_object_fetcher(testing_simulated_response)

    while True:
        remote_job_data = fetch_job_data(job_info.job_uuid)
        status = remote_job_data.get("status")

        if status == "cancelled":
            self._handle_cancelled_job(job_info)
            return None

        if status == "completed":
            results_uuid = remote_job_data.get("results_uuid")
            if not results_uuid:
                return None
            return self._fetch_results_and_log(
                job_info=job_info,
                results_uuid=results_uuid,
                remote_job_data=remote_job_data,
                object_fetcher=object_fetcher,
            )

        if status == "failed":
            self._handle_failed_job(job_info, remote_job_data)
            results_uuid = remote_job_data.get("results_uuid")
            if not results_uuid:
                return None
            # A failed job can still report a results_uuid. Fetch it directly:
            # _fetch_results_and_log would log a COMPLETED update right after
            # the FAILED updates above, misreporting the job's outcome.
            job_info.logger.add_info("results_uuid", results_uuid)
            results = object_fetcher(results_uuid, expected_object_type="results")
            results.job_uuid = job_info.job_uuid
            results.results_uuid = results_uuid
            return results

        # Any other status: log it and wait before polling again.
        self._sleep_for_a_bit(job_info, status)
|
247
269
|
|
248
270
|
async def create_and_poll_remote_job(
|
249
271
|
self,
|
@@ -265,7 +287,7 @@ class JobsRemoteInferenceHandler:
|
|
265
287
|
|
266
288
|
# Create job using existing method
|
267
289
|
loop = asyncio.get_event_loop()
|
268
|
-
|
290
|
+
job_info = await loop.run_in_executor(
|
269
291
|
None,
|
270
292
|
partial(
|
271
293
|
self.create_remote_inference_job,
|
@@ -274,10 +296,12 @@ class JobsRemoteInferenceHandler:
|
|
274
296
|
remote_inference_results_visibility=remote_inference_results_visibility,
|
275
297
|
),
|
276
298
|
)
|
299
|
+
if job_info is None:
|
300
|
+
raise ValueError("Remote job creation failed.")
|
277
301
|
|
278
|
-
# Poll using existing method but with async sleep
|
279
302
|
return await loop.run_in_executor(
|
280
|
-
None,
|
303
|
+
None,
|
304
|
+
partial(self.poll_remote_inference_job, job_info),
|
281
305
|
)
|
282
306
|
|
283
307
|
|
edsl/jobs/async_interview_runner.py
@@ -0,0 +1,138 @@
|
|
1
|
+
from collections.abc import AsyncGenerator
|
2
|
+
from typing import List, TypeVar, Generator, Tuple, TYPE_CHECKING
|
3
|
+
from dataclasses import dataclass
|
4
|
+
import asyncio
|
5
|
+
from contextlib import asynccontextmanager
|
6
|
+
from edsl.data_transfer_models import EDSLResultObjectInput
|
7
|
+
|
8
|
+
from edsl.results.Result import Result
|
9
|
+
from edsl.jobs.interviews.Interview import Interview
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from edsl.jobs.Jobs import Jobs
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
class InterviewResult:
    """A finished interview's ``Result`` paired with the interview and its index."""

    # Result object built from the completed interview.
    result: Result
    # The interview that produced ``result``.
    interview: Interview
    # Position of the interview in the expanded list processed by the runner.
    order: int
|
20
|
+
|
21
|
+
|
22
|
+
from edsl.jobs.data_structures import RunConfig
|
23
|
+
|
24
|
+
|
25
|
+
class AsyncInterviewRunner:
    """Conduct a job's interviews concurrently, yielding results batch by batch.

    Interviews are run in consecutive batches of at most ``MAX_CONCURRENT``.
    All interviews in a batch run concurrently; once the whole batch has
    finished, its results are yielded in interview order.
    """

    # Batch size: the maximum number of interviews running at the same time.
    MAX_CONCURRENT = 5

    def __init__(self, jobs: "Jobs", run_config: RunConfig):
        self.jobs = jobs
        self.run_config = run_config
        # Set by ``run`` once the full interview list has been built.
        self._initialized = asyncio.Event()

    def _expand_interviews(self) -> Generator["Interview", None, None]:
        """Yield ``run_config.parameters.n`` copies of each of the job's interviews.

        For every interview from ``jobs.generate_interviews()``, iteration 0 is
        the interview itself (its ``cache`` is set to the run's cache) and each
        later iteration is ``interview.duplicate(iteration=..., cache=...)``
        given that same cache.
        """
        for interview in self.jobs.generate_interviews():
            for iteration in range(self.run_config.parameters.n):
                if iteration > 0:
                    yield interview.duplicate(
                        iteration=iteration, cache=self.run_config.environment.cache
                    )
                else:
                    interview.cache = self.run_config.environment.cache
                    yield interview

    async def _conduct_interview(
        self, interview: "Interview"
    ) -> Tuple["Result", "Interview"]:
        """Conduct one interview and build its ``Result``.

        :param interview: the interview to conduct.
        :return: ``(result, interview)``. The interview is returned alongside
            the result so callers can inspect it after it has run, presumably
            for any exceptions recorded on it (NOTE(review): confirm).

        ``extracted_answers`` is a dictionary of the answers to the questions
        in the interview (e.g. ``{'how_are_you': "Good",
        'how_are_you_generated_tokens': "Good"}``). It is not the same as the
        generated tokens: it can include substantial cleaning, processing and
        validation.
        """
        extracted_answers: dict[str, str]
        model_response_objects: List[EDSLResultObjectInput]

        extracted_answers, model_response_objects = (
            await interview.async_conduct_interview(self.run_config)
        )
        result = Result.from_interview(
            interview=interview,
            extracted_answers=extracted_answers,
            model_response_objects=model_response_objects,
        )
        return result, interview

    async def run(
        self,
    ) -> AsyncGenerator[tuple[Result, Interview], None]:
        """Conduct every interview and yield ``(result, interview)`` pairs.

        Interviews are expanded up front, then run in batches of
        ``MAX_CONCURRENT`` with ``asyncio.gather``. Each batch's successful
        results are yielded in interview order after the whole batch has
        finished, and every completed interview is reported to
        ``run_config.environment.jobs_runner_status``.

        If ``run_config.parameters.stop_on_exception`` is true, the first
        exception raised while conducting an interview propagates and any
        still-running tasks of that batch are cancelled. Otherwise failed
        interviews are skipped and the rest of the batch is still yielded.
        """
        interviews = list(self._expand_interviews())
        self._initialized.set()
        stop_on_exception = self.run_config.parameters.stop_on_exception

        async def _process_single_interview(
            interview: Interview, idx: int
        ) -> "InterviewResult | None":
            try:
                result, interview = await self._conduct_interview(interview)
                self.run_config.environment.jobs_runner_status.add_completed_interview(
                    result
                )
                result.order = idx
                return InterviewResult(result, interview, idx)
            except Exception:
                if stop_on_exception:
                    raise
                # Best-effort mode: drop this interview and keep going.
                return None

        for start in range(0, len(interviews), self.MAX_CONCURRENT):
            batch = interviews[start : start + self.MAX_CONCURRENT]
            tasks = [
                asyncio.create_task(_process_single_interview(interview, idx))
                for idx, interview in enumerate(batch, start=start)
            ]

            try:
                outcomes = await asyncio.gather(
                    *tasks, return_exceptions=not stop_on_exception
                )
            except Exception:
                if stop_on_exception:
                    raise
                continue
            finally:
                # Cancel tasks still running when gather raised early.
                for task in tasks:
                    if not task.done():
                        task.cancel()

            # With return_exceptions=True, gather also returns exceptions that
            # escaped _process_single_interview (it only catches Exception, so
            # e.g. a cancelled task's CancelledError lands here). Yield only
            # genuine results, skipping those and the None placeholders.
            # Yielding outside the try also stops exceptions thrown into this
            # generator by the consumer from being silently swallowed.
            for outcome in outcomes:
                if isinstance(outcome, InterviewResult):
                    yield outcome.result, outcome.interview
|
edsl/jobs/check_survey_scenario_compatibility.py
@@ -0,0 +1,85 @@
|
|
1
|
+
import warnings
|
2
|
+
from typing import TYPE_CHECKING
|
3
|
+
|
4
|
+
if TYPE_CHECKING:
|
5
|
+
from edsl.surveys.Survey import Survey
|
6
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
7
|
+
|
8
|
+
|
9
|
+
class CheckSurveyScenarioCompatibility:
    """Check that a survey's parameters agree with the keys of its scenarios."""

    def __init__(self, survey: "Survey", scenarios: "ScenarioList"):
        self.survey = survey
        self.scenarios = scenarios

    def check(self, strict: bool = False, warn: bool = False) -> None:
        """Check if the parameters in the survey and scenarios are consistent.

        :param strict: if true, raise ``ValueError`` when some parameters appear
            only in the survey or only in the scenarios.
        :param warn: if ``strict`` is false, emit a ``UserWarning`` for such
            mismatches instead; with both false, mismatches are ignored.
        :raises ValueError: whenever a scenario key is also a survey question
            name, and on a parameter mismatch when ``strict`` is true.

        When the checks pass and the scenarios contain Jinja braces, a warning
        is emitted and ``self.scenarios`` is replaced by the converted scenarios.

        >>> from edsl.jobs.Jobs import Jobs
        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
        >>> from edsl.surveys.Survey import Survey
        >>> from edsl.scenarios.Scenario import Scenario
        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
        >>> j = Jobs(survey = Survey(questions=[q]))
        >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
        >>> with warnings.catch_warnings(record=True) as w:
        ...     warnings.simplefilter("always", UserWarning)
        ...     cs.check(warn = True)
        ...     assert len(w) == 1
        ...     assert issubclass(w[-1].category, UserWarning)
        ...     assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)

        >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
        >>> s = Scenario({'plop': "A", 'poo': "B"})
        >>> j = Jobs(survey = Survey(questions=[q])).by(s)
        >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
        >>> cs.check(strict = True)
        Traceback (most recent call last):
        ...
        ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}

        >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
        >>> s = Scenario({'ugly_question': "B"})
        >>> from edsl.scenarios.ScenarioList import ScenarioList
        >>> cs = CheckSurveyScenarioCompatibility(Survey(questions=[q]), ScenarioList([s]))
        >>> cs.check()
        Traceback (most recent call last):
        ...
        ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
        """
        # Normalise both sides to sets once so the set algebra below works
        # even if ``.parameters`` is returned as another iterable.
        survey_parameters = set(self.survey.parameters)
        scenario_parameters = set(self.scenarios.parameters)

        # A name used both as a question name and as a scenario key is an
        # error regardless of ``strict``.
        if intersection := scenario_parameters & set(self.survey.question_names):
            raise ValueError(
                f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
            )

        mismatches = []
        if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
            mismatches.append(
                f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
            )
        if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
            mismatches.append(
                f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"
            )

        if mismatches:
            message = "\n".join(mismatches)
            if strict:
                raise ValueError(message)
            if warn:
                warnings.warn(message)

        if self.scenarios.has_jinja_braces:
            warnings.warn(
                "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
            )
            self.scenarios = self.scenarios._convert_jinja_braces()
|
80
|
+
|
81
|
+
|
82
|
+
if __name__ == "__main__":
    # Run this module's doctests when executed as a script.
    from doctest import testmod

    testmod()
|