edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. edsl/Base.py +28 -0
  2. edsl/__init__.py +1 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +8 -16
  5. edsl/agents/Invigilator.py +13 -14
  6. edsl/agents/InvigilatorBase.py +4 -1
  7. edsl/agents/PromptConstructor.py +42 -22
  8. edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
  9. edsl/auto/AutoStudy.py +18 -5
  10. edsl/auto/StageBase.py +53 -40
  11. edsl/auto/StageQuestions.py +2 -1
  12. edsl/auto/utilities.py +0 -6
  13. edsl/coop/coop.py +21 -5
  14. edsl/data/Cache.py +29 -18
  15. edsl/data/CacheHandler.py +0 -2
  16. edsl/data/RemoteCacheSync.py +154 -46
  17. edsl/data/hack.py +10 -0
  18. edsl/enums.py +7 -0
  19. edsl/inference_services/AnthropicService.py +38 -16
  20. edsl/inference_services/AvailableModelFetcher.py +7 -1
  21. edsl/inference_services/GoogleService.py +5 -1
  22. edsl/inference_services/InferenceServicesCollection.py +18 -2
  23. edsl/inference_services/OpenAIService.py +46 -31
  24. edsl/inference_services/TestService.py +1 -3
  25. edsl/inference_services/TogetherAIService.py +5 -3
  26. edsl/inference_services/data_structures.py +74 -2
  27. edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
  28. edsl/jobs/FetchInvigilator.py +10 -3
  29. edsl/jobs/InterviewsConstructor.py +6 -4
  30. edsl/jobs/Jobs.py +299 -233
  31. edsl/jobs/JobsChecks.py +2 -2
  32. edsl/jobs/JobsPrompts.py +1 -1
  33. edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
  34. edsl/jobs/async_interview_runner.py +138 -0
  35. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  36. edsl/jobs/data_structures.py +120 -0
  37. edsl/jobs/interviews/Interview.py +80 -42
  38. edsl/jobs/results_exceptions_handler.py +98 -0
  39. edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
  40. edsl/jobs/runners/JobsRunnerStatus.py +131 -164
  41. edsl/jobs/tasks/TaskHistory.py +24 -3
  42. edsl/language_models/LanguageModel.py +59 -4
  43. edsl/language_models/ModelList.py +19 -8
  44. edsl/language_models/__init__.py +1 -1
  45. edsl/language_models/model.py +256 -0
  46. edsl/language_models/repair.py +1 -1
  47. edsl/questions/QuestionBase.py +35 -26
  48. edsl/questions/QuestionBasePromptsMixin.py +1 -1
  49. edsl/questions/QuestionBudget.py +1 -1
  50. edsl/questions/QuestionCheckBox.py +2 -2
  51. edsl/questions/QuestionExtract.py +5 -7
  52. edsl/questions/QuestionFreeText.py +1 -1
  53. edsl/questions/QuestionList.py +9 -15
  54. edsl/questions/QuestionMatrix.py +1 -1
  55. edsl/questions/QuestionMultipleChoice.py +1 -1
  56. edsl/questions/QuestionNumerical.py +1 -1
  57. edsl/questions/QuestionRank.py +1 -1
  58. edsl/questions/SimpleAskMixin.py +1 -1
  59. edsl/questions/__init__.py +1 -1
  60. edsl/questions/data_structures.py +20 -0
  61. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
  62. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
  63. edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
  64. edsl/results/DatasetExportMixin.py +60 -119
  65. edsl/results/Result.py +109 -3
  66. edsl/results/Results.py +50 -39
  67. edsl/results/file_exports.py +252 -0
  68. edsl/scenarios/ScenarioList.py +35 -7
  69. edsl/surveys/Survey.py +71 -20
  70. edsl/test_h +1 -0
  71. edsl/utilities/gcp_bucket/example.py +50 -0
  72. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
  73. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
  74. edsl/language_models/registry.py +0 -180
  75. /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
  76. /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
  77. /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
  78. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  79. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  80. /edsl/results/{Selector.py → results_selector.py} +0 -0
  81. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  82. /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
  83. /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
  84. /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
  85. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
  86. {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/jobs/JobsPrompts.py CHANGED
@@ -51,7 +51,7 @@ class JobsPrompts:
51
51
  for interview_index, interview in enumerate(interviews):
52
52
  invigilators = [
53
53
  FetchInvigilator(interview)(question)
54
- for question in self.survey.questions
54
+ for question in interview.survey.questions
55
55
  ]
56
56
  for _, invigilator in enumerate(invigilators):
57
57
  prompts = invigilator.get_prompts()
edsl/jobs/JobsRemoteInferenceHandler.py CHANGED
@@ -1,4 +1,6 @@
1
- from typing import Optional, Union, Literal, TYPE_CHECKING, NewType
1
+ from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
2
+
3
+ from dataclasses import dataclass
2
4
 
3
5
 
4
6
  Seconds = NewType("Seconds", float)
@@ -16,26 +18,52 @@ from edsl.coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
16
18
 
17
19
  from edsl.jobs.jobs_status_enums import JobsStatus
18
20
  from edsl.coop.utils import VisibilityType
21
+ from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
22
+
23
+
24
+ class RemoteJobConstants:
25
+ """Constants for remote job handling."""
26
+
27
+ REMOTE_JOB_POLL_INTERVAL = 1
28
+ REMOTE_JOB_VERBOSE = False
29
+ DISCORD_URL = "https://discord.com/invite/mxAYkjfy9m"
30
+
31
+
32
+ @dataclass
33
+ class RemoteJobInfo:
34
+ creation_data: RemoteInferenceCreationInfo
35
+ job_uuid: JobUUID
36
+ logger: JobLogger
19
37
 
20
38
 
21
39
  class JobsRemoteInferenceHandler:
22
- def __init__(self, jobs: "Jobs", verbose: bool = False, poll_interval: Seconds = 1):
23
- """ """
40
+ def __init__(
41
+ self,
42
+ jobs: "Jobs",
43
+ verbose: bool = RemoteJobConstants.REMOTE_JOB_VERBOSE,
44
+ poll_interval: Seconds = RemoteJobConstants.REMOTE_JOB_POLL_INTERVAL,
45
+ ):
46
+ """Handles the creation and running of a remote inference job."""
24
47
  self.jobs = jobs
25
48
  self.verbose = verbose
26
49
  self.poll_interval = poll_interval
27
50
 
28
- self._remote_job_creation_data: Union[None, RemoteInferenceCreationInfo] = None
29
- self._job_uuid: Union[None, JobUUID] = None # Will be set when job is created
30
- self.logger: Union[None, JobLogger] = None # Will be initialized when needed
51
+ from edsl.config import CONFIG
31
52
 
32
- @property
33
- def remote_job_creation_data(self) -> RemoteInferenceCreationInfo:
34
- return self._remote_job_creation_data
53
+ self.expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
54
+ self.remote_inference_url = f"{self.expected_parrot_url}/home/remote-inference"
35
55
 
36
- @property
37
- def job_uuid(self) -> JobUUID:
38
- return self._job_uuid
56
+ def _create_logger(self) -> JobLogger:
57
+ from edsl.utilities.is_notebook import is_notebook
58
+ from edsl.jobs.JobsRemoteInferenceLogger import (
59
+ JupyterJobLogger,
60
+ StdOutJobLogger,
61
+ )
62
+ from edsl.jobs.loggers.HTMLTableJobLogger import HTMLTableJobLogger
63
+
64
+ if is_notebook():
65
+ return HTMLTableJobLogger(verbose=self.verbose)
66
+ return StdOutJobLogger(verbose=self.verbose)
39
67
 
40
68
  def use_remote_inference(self, disable_remote_inference: bool) -> bool:
41
69
  import requests
@@ -60,23 +88,15 @@ class JobsRemoteInferenceHandler:
60
88
  iterations: int = 1,
61
89
  remote_inference_description: Optional[str] = None,
62
90
  remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
63
- ) -> None:
91
+ ) -> RemoteJobInfo:
92
+
64
93
  from edsl.config import CONFIG
65
94
  from edsl.coop.coop import Coop
66
95
 
67
- # Initialize logger
68
- from edsl.utilities.is_notebook import is_notebook
69
- from edsl.jobs.JobsRemoteInferenceLogger import JupyterJobLogger
70
- from edsl.jobs.JobsRemoteInferenceLogger import StdOutJobLogger
71
- from edsl.jobs.loggers.HTMLTableJobLogger import HTMLTableJobLogger
72
-
73
- if is_notebook():
74
- self.logger = HTMLTableJobLogger(verbose=self.verbose)
75
- else:
76
- self.logger = StdOutJobLogger(verbose=self.verbose)
96
+ logger = self._create_logger()
77
97
 
78
98
  coop = Coop()
79
- self.logger.update(
99
+ logger.update(
80
100
  "Remote inference activated. Sending job to server...",
81
101
  status=JobsStatus.QUEUED,
82
102
  )
@@ -87,33 +107,34 @@ class JobsRemoteInferenceHandler:
87
107
  iterations=iterations,
88
108
  initial_results_visibility=remote_inference_results_visibility,
89
109
  )
90
- self.logger.update(
110
+ logger.update(
91
111
  "Your survey is running at the Expected Parrot server...",
92
112
  status=JobsStatus.RUNNING,
93
113
  )
94
-
95
114
  job_uuid = remote_job_creation_data.get("uuid")
96
- self.logger.update(
115
+ logger.update(
97
116
  message=f"Job sent to server. (Job uuid={job_uuid}).",
98
117
  status=JobsStatus.RUNNING,
99
118
  )
100
- self.logger.add_info("job_uuid", job_uuid)
119
+ logger.add_info("job_uuid", job_uuid)
101
120
 
102
- expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
103
- remote_inference_url = f"{expected_parrot_url}/home/remote-inference"
104
-
105
- self.logger.update(
106
- f"Job details are available at your Coop account {remote_inference_url}{remote_inference_url}",
121
+ logger.update(
122
+ f"Job details are available at your Coop account {self.remote_inference_url}",
107
123
  status=JobsStatus.RUNNING,
108
124
  )
109
- progress_bar_url = f"{expected_parrot_url}/home/remote-job-progress/{job_uuid}"
110
- self.logger.add_info("progress_bar_url", progress_bar_url)
111
- self.logger.update(
125
+ progress_bar_url = (
126
+ f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
127
+ )
128
+ logger.add_info("progress_bar_url", progress_bar_url)
129
+ logger.update(
112
130
  f"View job progress here: {progress_bar_url}", status=JobsStatus.RUNNING
113
131
  )
114
132
 
115
- self._remote_job_creation_data = remote_job_creation_data
116
- self._job_uuid = job_uuid
133
+ return RemoteJobInfo(
134
+ creation_data=remote_job_creation_data,
135
+ job_uuid=job_uuid,
136
+ logger=logger,
137
+ )
117
138
 
118
139
  @staticmethod
119
140
  def check_status(
@@ -124,126 +145,127 @@ class JobsRemoteInferenceHandler:
124
145
  coop = Coop()
125
146
  return coop.remote_inference_get(job_uuid)
126
147
 
127
- def poll_remote_inference_job(self) -> Union[None, "Results"]:
128
- return self._poll_remote_inference_job(
129
- self.remote_job_creation_data, verbose=self.verbose
148
+ def _construct_remote_job_fetcher(
149
+ self, testing_simulated_response: Optional[Any] = None
150
+ ) -> Callable:
151
+ if testing_simulated_response is not None:
152
+ return lambda job_uuid: testing_simulated_response
153
+ else:
154
+ from edsl.coop.coop import Coop
155
+
156
+ coop = Coop()
157
+ return coop.remote_inference_get
158
+
159
+ def _construct_object_fetcher(
160
+ self, testing_simulated_response: Optional[Any] = None
161
+ ) -> Callable:
162
+ "Constructs a function to fetch the results object from Coop."
163
+ if testing_simulated_response is not None:
164
+ return lambda results_uuid, expected_object_type: Results.example()
165
+ else:
166
+ from edsl.coop.coop import Coop
167
+
168
+ coop = Coop()
169
+ return coop.get
170
+
171
+ def _handle_cancelled_job(self, job_info: RemoteJobInfo) -> None:
172
+ "Handles a cancelled job by logging the cancellation and updating the job status."
173
+
174
+ job_info.logger.update(
175
+ message="Job cancelled by the user.", status=JobsStatus.CANCELLED
176
+ )
177
+ job_info.logger.update(
178
+ f"See {self.expected_parrot_url}/home/remote-inference for more details.",
179
+ status=JobsStatus.CANCELLED,
130
180
  )
131
181
 
132
- def _poll_remote_inference_job(
133
- self,
134
- remote_job_creation_data: RemoteInferenceCreationInfo,
135
- verbose: bool = False,
136
- poll_interval: Optional[Seconds] = None,
137
- testing_simulated_response=None,
138
- ) -> Union[None, "Results"]:
182
+ def _handle_failed_job(
183
+ self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
184
+ ) -> None:
185
+ "Handles a failed job by logging the error and updating the job status."
186
+ latest_error_report_url = remote_job_data.get("latest_error_report_url")
187
+ if latest_error_report_url:
188
+ job_info.logger.add_info("error_report_url", latest_error_report_url)
189
+
190
+ job_info.logger.update("Job failed.", status=JobsStatus.FAILED)
191
+ job_info.logger.update(
192
+ f"See {self.expected_parrot_url}/home/remote-inference for more details.",
193
+ status=JobsStatus.FAILED,
194
+ )
195
+ job_info.logger.update(
196
+ f"Need support? Visit Discord: {RemoteJobConstants.DISCORD_URL}",
197
+ status=JobsStatus.FAILED,
198
+ )
199
+
200
+ def _sleep_for_a_bit(self, job_info: RemoteJobInfo, status: str) -> None:
139
201
  import time
140
202
  from datetime import datetime
141
- from edsl.config import CONFIG
142
- from edsl.results.Results import Results
143
203
 
144
- if poll_interval is None:
145
- poll_interval = self.poll_interval
204
+ time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
205
+ job_info.logger.update(
206
+ f"Job status: {status} - last update: {time_checked}",
207
+ status=JobsStatus.RUNNING,
208
+ )
209
+ time.sleep(self.poll_interval)
146
210
 
147
- job_uuid = remote_job_creation_data.get("uuid")
148
- expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
211
+ def _fetch_results_and_log(
212
+ self,
213
+ job_info: RemoteJobInfo,
214
+ results_uuid: str,
215
+ remote_job_data: RemoteInferenceResponse,
216
+ object_fetcher: Callable,
217
+ ) -> "Results":
218
+ "Fetches the results object and logs the results URL."
219
+ job_info.logger.add_info("results_uuid", results_uuid)
220
+ results = object_fetcher(results_uuid, expected_object_type="results")
221
+ results_url = remote_job_data.get("results_url")
222
+ job_info.logger.update(
223
+ f"Job completed and Results stored on Coop: {results_url}",
224
+ status=JobsStatus.COMPLETED,
225
+ )
226
+ results.job_uuid = job_info.job_uuid
227
+ results.results_uuid = results_uuid
228
+ return results
149
229
 
150
- if testing_simulated_response is not None:
151
- remote_job_data_fetcher = lambda job_uuid: testing_simulated_response
152
- object_fetcher = (
153
- lambda results_uuid, expected_object_type: Results.example()
154
- )
155
- else:
156
- from edsl.coop.coop import Coop
230
+ def poll_remote_inference_job(
231
+ self,
232
+ job_info: RemoteJobInfo,
233
+ testing_simulated_response=None,
234
+ ) -> Union[None, "Results"]:
235
+ """Polls a remote inference job for completion and returns the results."""
157
236
 
158
- coop = Coop()
159
- remote_job_data_fetcher = coop.remote_inference_get
160
- object_fetcher = coop.get
237
+ remote_job_data_fetcher = self._construct_remote_job_fetcher(
238
+ testing_simulated_response
239
+ )
240
+ object_fetcher = self._construct_object_fetcher(testing_simulated_response)
161
241
 
162
242
  job_in_queue = True
163
243
  while job_in_queue:
164
- remote_job_data: RemoteInferenceResponse = remote_job_data_fetcher(job_uuid)
244
+ remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
165
245
  status = remote_job_data.get("status")
166
246
 
167
247
  if status == "cancelled":
168
- self.logger.update(
169
- messaged="Job cancelled by the user.", status=JobsStatus.CANCELLED
170
- )
171
- self.logger.update(
172
- f"See {expected_parrot_url}/home/remote-inference for more details.",
173
- status=JobsStatus.CANCELLED,
174
- )
248
+ self._handle_cancelled_job(job_info)
175
249
  return None
176
250
 
177
- elif status == "failed":
178
- latest_error_report_url = remote_job_data.get("latest_error_report_url")
179
- if latest_error_report_url:
180
- self.logger.update("Job failed.", status=JobsStatus.FAILED)
181
- self.logger.update(
182
- f"Error report: {latest_error_report_url}", "failed"
183
- )
184
- self.logger.add_info("error_report_url", latest_error_report_url)
185
- self.logger.update(
186
- "Need support? Visit Discord: https://discord.com/invite/mxAYkjfy9m",
187
- status=JobsStatus.FAILED,
188
- )
189
- else:
190
- self.logger.update("Job failed.", "failed")
191
- self.logger.update(
192
- f"See {expected_parrot_url}/home/remote-inference for details.",
193
- status=JobsStatus.FAILED,
194
- )
251
+ elif status == "failed" or status == "completed":
252
+ if status == "failed":
253
+ self._handle_failed_job(job_info, remote_job_data)
195
254
 
196
255
  results_uuid = remote_job_data.get("results_uuid")
197
256
  if results_uuid:
198
- self.logger.add_info("results_uuid", results_uuid)
199
- results = object_fetcher(
200
- results_uuid, expected_object_type="results"
257
+ results = self._fetch_results_and_log(
258
+ job_info=job_info,
259
+ results_uuid=results_uuid,
260
+ remote_job_data=remote_job_data,
261
+ object_fetcher=object_fetcher,
201
262
  )
202
- results.job_uuid = job_uuid
203
- results.results_uuid = results_uuid
204
263
  return results
205
264
  else:
206
265
  return None
207
266
 
208
- elif status == "completed":
209
- results_uuid = remote_job_data.get("results_uuid")
210
- self.logger.add_info("results_uuid", results_uuid)
211
- results_url = remote_job_data.get("results_url")
212
- self.logger.add_info("results_url", results_url)
213
- results = object_fetcher(results_uuid, expected_object_type="results")
214
- self.logger.update(
215
- f"Job completed and Results stored on Coop: {results_url}",
216
- status=JobsStatus.COMPLETED,
217
- )
218
- results.job_uuid = job_uuid
219
- results.results_uuid = results_uuid
220
- return results
221
-
222
267
  else:
223
- time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
224
- self.logger.update(
225
- f"Job status: {status} - last update: {time_checked}",
226
- status=JobsStatus.RUNNING,
227
- )
228
- time.sleep(poll_interval)
229
-
230
- def use_remote_inference(self, disable_remote_inference: bool) -> bool:
231
- import requests
232
-
233
- if disable_remote_inference:
234
- return False
235
- if not disable_remote_inference:
236
- try:
237
- from edsl.coop.coop import Coop
238
-
239
- user_edsl_settings = Coop().edsl_settings
240
- return user_edsl_settings.get("remote_inference", False)
241
- except requests.ConnectionError:
242
- pass
243
- except CoopServerResponseError as e:
244
- pass
245
-
246
- return False
268
+ self._sleep_for_a_bit(job_info, status)
247
269
 
248
270
  async def create_and_poll_remote_job(
249
271
  self,
@@ -265,7 +287,7 @@ class JobsRemoteInferenceHandler:
265
287
 
266
288
  # Create job using existing method
267
289
  loop = asyncio.get_event_loop()
268
- remote_job_creation_data = await loop.run_in_executor(
290
+ job_info = await loop.run_in_executor(
269
291
  None,
270
292
  partial(
271
293
  self.create_remote_inference_job,
@@ -274,10 +296,12 @@ class JobsRemoteInferenceHandler:
274
296
  remote_inference_results_visibility=remote_inference_results_visibility,
275
297
  ),
276
298
  )
299
+ if job_info is None:
300
+ raise ValueError("Remote job creation failed.")
277
301
 
278
- # Poll using existing method but with async sleep
279
302
  return await loop.run_in_executor(
280
- None, partial(self.poll_remote_inference_job, remote_job_creation_data)
303
+ None,
304
+ partial(self.poll_remote_inference_job, job_info),
281
305
  )
282
306
 
283
307
 
edsl/jobs/async_interview_runner.py ADDED
@@ -0,0 +1,138 @@
1
+ from collections.abc import AsyncGenerator
2
+ from typing import List, TypeVar, Generator, Tuple, TYPE_CHECKING
3
+ from dataclasses import dataclass
4
+ import asyncio
5
+ from contextlib import asynccontextmanager
6
+ from edsl.data_transfer_models import EDSLResultObjectInput
7
+
8
+ from edsl.results.Result import Result
9
+ from edsl.jobs.interviews.Interview import Interview
10
+
11
+ if TYPE_CHECKING:
12
+ from edsl.jobs.Jobs import Jobs
13
+
14
+
15
+ @dataclass
16
+ class InterviewResult:
17
+ result: Result
18
+ interview: Interview
19
+ order: int
20
+
21
+
22
+ from edsl.jobs.data_structures import RunConfig
23
+
24
+
25
+ class AsyncInterviewRunner:
26
+ MAX_CONCURRENT = 5
27
+
28
+ def __init__(self, jobs: "Jobs", run_config: RunConfig):
29
+ self.jobs = jobs
30
+ self.run_config = run_config
31
+ self._initialized = asyncio.Event()
32
+
33
+ def _expand_interviews(self) -> Generator["Interview", None, None]:
34
+ """Populates self.total_interviews with n copies of each interview.
35
+
36
+ It also has to set the cache for each interview.
37
+
38
+ :param n: how many times to run each interview.
39
+ """
40
+ for interview in self.jobs.generate_interviews():
41
+ for iteration in range(self.run_config.parameters.n):
42
+ if iteration > 0:
43
+ yield interview.duplicate(
44
+ iteration=iteration, cache=self.run_config.environment.cache
45
+ )
46
+ else:
47
+ interview.cache = self.run_config.environment.cache
48
+ yield interview
49
+
50
+ async def _conduct_interview(
51
+ self, interview: "Interview"
52
+ ) -> Tuple["Result", "Interview"]:
53
+ """Conducts an interview and returns the result object, along with the associated interview.
54
+
55
+ We return the interview because it is not populated with exceptions, if any.
56
+
57
+ :param interview: the interview to conduct
58
+ :return: the result of the interview
59
+
60
+ 'extracted_answers' is a dictionary of the answers to the questions in the interview.
61
+ This is not the same as the generated_tokens---it can include substantial cleaning and processing / validation.
62
+ """
63
+ # the model buckets are used to track usage rates
64
+ # model_buckets = self.bucket_collection[interview.model]
65
+ # model_buckets = self.run_config.environment.bucket_collection[interview.model]
66
+
67
+ # get the results of the interview e.g., {'how_are_you':"Good" 'how_are_you_generated_tokens': "Good"}
68
+ extracted_answers: dict[str, str]
69
+ model_response_objects: List[EDSLResultObjectInput]
70
+
71
+ extracted_answers, model_response_objects = (
72
+ await interview.async_conduct_interview(self.run_config)
73
+ )
74
+ result = Result.from_interview(
75
+ interview=interview,
76
+ extracted_answers=extracted_answers,
77
+ model_response_objects=model_response_objects,
78
+ )
79
+ return result, interview
80
+
81
+ async def run(
82
+ self,
83
+ ) -> AsyncGenerator[tuple[Result, Interview], None]:
84
+ """Creates and processes tasks asynchronously, yielding results as they complete.
85
+
86
+ Uses TaskGroup for structured concurrency and automated cleanup.
87
+ Results are yielded as they become available while maintaining controlled concurrency.
88
+ """
89
+ interviews = list(self._expand_interviews())
90
+ self._initialized.set()
91
+
92
+ async def _process_single_interview(
93
+ interview: Interview, idx: int
94
+ ) -> InterviewResult:
95
+ try:
96
+ result, interview = await self._conduct_interview(interview)
97
+ self.run_config.environment.jobs_runner_status.add_completed_interview(
98
+ result
99
+ )
100
+ result.order = idx
101
+ return InterviewResult(result, interview, idx)
102
+ except Exception as e:
103
+ # breakpoint()
104
+ if self.run_config.parameters.stop_on_exception:
105
+ raise
106
+ # logger.error(f"Task failed with error: {e}")
107
+ return None
108
+
109
+ # Process interviews in chunks
110
+ for i in range(0, len(interviews), self.MAX_CONCURRENT):
111
+ chunk = interviews[i : i + self.MAX_CONCURRENT]
112
+ tasks = [
113
+ asyncio.create_task(_process_single_interview(interview, idx))
114
+ for idx, interview in enumerate(chunk, start=i)
115
+ ]
116
+
117
+ try:
118
+ # Wait for all tasks in the chunk to complete
119
+ results = await asyncio.gather(
120
+ *tasks,
121
+ return_exceptions=not self.run_config.parameters.stop_on_exception
122
+ )
123
+
124
+ # Process successful results
125
+ for result in (r for r in results if r is not None):
126
+ yield result.result, result.interview
127
+
128
+ except Exception as e:
129
+ if self.run_config.parameters.stop_on_exception:
130
+ raise
131
+ # logger.error(f"Chunk processing failed with error: {e}")
132
+ continue
133
+
134
+ finally:
135
+ # Clean up any remaining tasks
136
+ for task in tasks:
137
+ if not task.done():
138
+ task.cancel()
edsl/jobs/check_survey_scenario_compatibility.py ADDED
@@ -0,0 +1,85 @@
1
+ import warnings
2
+ from typing import TYPE_CHECKING
3
+
4
+ if TYPE_CHECKING:
5
+ from edsl.surveys.Survey import Survey
6
+ from edsl.scenarios.ScenarioList import ScenarioList
7
+
8
+
9
+ class CheckSurveyScenarioCompatibility:
10
+
11
+ def __init__(self, survey: "Survey", scenarios: "ScenarioList"):
12
+ self.survey = survey
13
+ self.scenarios = scenarios
14
+
15
+ def check(self, strict: bool = False, warn: bool = False) -> None:
16
+ """Check if the parameters in the survey and scenarios are consistent.
17
+
18
+ >>> from edsl.jobs.Jobs import Jobs
19
+ >>> from edsl.questions.QuestionFreeText import QuestionFreeText
20
+ >>> from edsl.surveys.Survey import Survey
21
+ >>> from edsl.scenarios.Scenario import Scenario
22
+ >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
23
+ >>> j = Jobs(survey = Survey(questions=[q]))
24
+ >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
25
+ >>> with warnings.catch_warnings(record=True) as w:
26
+ ... cs.check(warn = True)
27
+ ... assert len(w) == 1
28
+ ... assert issubclass(w[-1].category, UserWarning)
29
+ ... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
30
+
31
+ >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
32
+ >>> s = Scenario({'plop': "A", 'poo': "B"})
33
+ >>> j = Jobs(survey = Survey(questions=[q])).by(s)
34
+ >>> cs = CheckSurveyScenarioCompatibility(j.survey, j.scenarios)
35
+ >>> cs.check(strict = True)
36
+ Traceback (most recent call last):
37
+ ...
38
+ ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
39
+
40
+ >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
41
+ >>> s = Scenario({'ugly_question': "B"})
42
+ >>> from edsl.scenarios.ScenarioList import ScenarioList
43
+ >>> cs = CheckSurveyScenarioCompatibility(Survey(questions=[q]), ScenarioList([s]))
44
+ >>> cs.check()
45
+ Traceback (most recent call last):
46
+ ...
47
+ ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
48
+ """
49
+ survey_parameters: set = self.survey.parameters
50
+ scenario_parameters: set = self.scenarios.parameters
51
+
52
+ msg0, msg1, msg2 = None, None, None
53
+
54
+ # look for key issues
55
+ if intersection := set(self.scenarios.parameters) & set(
56
+ self.survey.question_names
57
+ ):
58
+ msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
59
+
60
+ raise ValueError(msg0)
61
+
62
+ if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
63
+ msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
64
+ if in_scenarios_but_not_in_survey := scenario_parameters - survey_parameters:
65
+ msg2 = f"The following parameters are in the scenarios but not in the survey: {in_scenarios_but_not_in_survey}"
66
+
67
+ if msg1 or msg2:
68
+ message = "\n".join(filter(None, [msg1, msg2]))
69
+ if strict:
70
+ raise ValueError(message)
71
+ else:
72
+ if warn:
73
+ warnings.warn(message)
74
+
75
+ if self.scenarios.has_jinja_braces:
76
+ warnings.warn(
77
+ "The scenarios have Jinja braces ({{ and }}). Converting to '<<' and '>>'. If you want a different conversion, use the convert_jinja_braces method first to modify the scenario."
78
+ )
79
+ self.scenarios = self.scenarios._convert_jinja_braces()
80
+
81
+
82
+ if __name__ == "__main__":
83
+ import doctest
84
+
85
+ doctest.testmod()