edsl 0.1.33__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +3 -9
- edsl/__init__.py +0 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +3 -6
- edsl/agents/InvigilatorBase.py +27 -8
- edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +29 -101
- edsl/config.py +34 -26
- edsl/coop/coop.py +2 -11
- edsl/data_transfer_models.py +73 -26
- edsl/enums.py +0 -2
- edsl/inference_services/GoogleService.py +1 -1
- edsl/inference_services/InferenceServiceABC.py +13 -44
- edsl/inference_services/OpenAIService.py +4 -7
- edsl/inference_services/TestService.py +15 -24
- edsl/inference_services/registry.py +0 -2
- edsl/jobs/Jobs.py +8 -18
- edsl/jobs/buckets/BucketCollection.py +15 -24
- edsl/jobs/buckets/TokenBucket.py +10 -64
- edsl/jobs/interviews/Interview.py +47 -115
- edsl/jobs/interviews/InterviewExceptionEntry.py +0 -2
- edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +0 -16
- edsl/jobs/interviews/retry_management.py +39 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +170 -95
- edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
- edsl/jobs/tasks/TaskHistory.py +0 -17
- edsl/language_models/LanguageModel.py +31 -26
- edsl/language_models/registry.py +9 -13
- edsl/questions/QuestionBase.py +14 -63
- edsl/questions/QuestionBudget.py +41 -93
- edsl/questions/QuestionFreeText.py +0 -6
- edsl/questions/QuestionMultipleChoice.py +23 -8
- edsl/questions/QuestionNumerical.py +4 -5
- edsl/questions/ResponseValidatorABC.py +5 -6
- edsl/questions/derived/QuestionLinearScale.py +1 -4
- edsl/questions/derived/QuestionTopK.py +1 -4
- edsl/questions/derived/QuestionYesNo.py +2 -8
- edsl/results/DatasetExportMixin.py +1 -5
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +1 -4
- edsl/scenarios/FileStore.py +10 -71
- edsl/scenarios/Scenario.py +21 -86
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +0 -13
- edsl/scenarios/ScenarioListPdfMixin.py +4 -150
- edsl/study/Study.py +0 -32
- edsl/surveys/Rule.py +1 -10
- edsl/surveys/RuleCollection.py +3 -19
- edsl/surveys/Survey.py +0 -7
- edsl/templates/error_reporting/interview_details.html +1 -6
- edsl/utilities/utilities.py +1 -9
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +1 -2
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/RECORD +55 -61
- edsl/inference_services/TogetherAIService.py +0 -170
- edsl/jobs/runners/JobsRunnerStatus.py +0 -331
- edsl/questions/Quick.py +0 -41
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +0 -7
- edsl/questions/templates/budget/question_presentation.jinja +0 -7
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
@@ -3,25 +3,40 @@ import time
|
|
3
3
|
import math
|
4
4
|
import asyncio
|
5
5
|
import functools
|
6
|
-
import threading
|
7
6
|
from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
|
8
7
|
from contextlib import contextmanager
|
9
8
|
from collections import UserList
|
10
9
|
|
11
|
-
from edsl.results.Results import Results
|
12
|
-
from rich.live import Live
|
13
|
-
from rich.console import Console
|
14
|
-
|
15
10
|
from edsl import shared_globals
|
16
11
|
from edsl.jobs.interviews.Interview import Interview
|
17
|
-
from edsl.jobs.runners.
|
18
|
-
|
12
|
+
from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
|
19
13
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
20
14
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
21
15
|
from edsl.utilities.decorators import jupyter_nb_handler
|
22
16
|
from edsl.data.Cache import Cache
|
23
17
|
from edsl.results.Result import Result
|
24
18
|
from edsl.results.Results import Results
|
19
|
+
from edsl.jobs.FailedQuestion import FailedQuestion
|
20
|
+
|
21
|
+
|
22
|
+
def cache_with_timeout(timeout):
|
23
|
+
""" "Used to keep the generate table from being run too frequetly."""
|
24
|
+
|
25
|
+
def decorator(func):
|
26
|
+
cached_result = {}
|
27
|
+
last_computation_time = [0] # Using list to store mutable value
|
28
|
+
|
29
|
+
@functools.wraps(func)
|
30
|
+
def wrapper(*args, **kwargs):
|
31
|
+
current_time = time.time()
|
32
|
+
if (current_time - last_computation_time[0]) >= timeout:
|
33
|
+
cached_result["value"] = func(*args, **kwargs)
|
34
|
+
last_computation_time[0] = current_time
|
35
|
+
return cached_result["value"]
|
36
|
+
|
37
|
+
return wrapper
|
38
|
+
|
39
|
+
return decorator
|
25
40
|
|
26
41
|
|
27
42
|
class StatusTracker(UserList):
|
@@ -33,7 +48,7 @@ class StatusTracker(UserList):
|
|
33
48
|
return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
|
34
49
|
|
35
50
|
|
36
|
-
class JobsRunnerAsyncio:
|
51
|
+
class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
37
52
|
"""A class for running a collection of interviews asynchronously.
|
38
53
|
|
39
54
|
It gets instaniated from a Jobs object.
|
@@ -42,12 +57,11 @@ class JobsRunnerAsyncio:
|
|
42
57
|
|
43
58
|
def __init__(self, jobs: "Jobs"):
|
44
59
|
self.jobs = jobs
|
60
|
+
# this creates the interviews, which can take a while
|
45
61
|
self.interviews: List["Interview"] = jobs.interviews()
|
46
62
|
self.bucket_collection: "BucketCollection" = jobs.bucket_collection
|
47
63
|
self.total_interviews: List["Interview"] = []
|
48
64
|
|
49
|
-
# self.jobs_runner_status = JobsRunnerStatus(self, n=1)
|
50
|
-
|
51
65
|
async def run_async_generator(
|
52
66
|
self,
|
53
67
|
cache: "Cache",
|
@@ -65,7 +79,6 @@ class JobsRunnerAsyncio:
|
|
65
79
|
:param stop_on_exception: Whether to stop the interview if an exception is raised
|
66
80
|
:param sidecar_model: a language model to use in addition to the interview's model
|
67
81
|
:param total_interviews: A list of interviews to run can be provided instead.
|
68
|
-
:param raise_validation_errors: Whether to raise validation errors
|
69
82
|
"""
|
70
83
|
tasks = []
|
71
84
|
if total_interviews: # was already passed in total interviews
|
@@ -75,6 +88,8 @@ class JobsRunnerAsyncio:
|
|
75
88
|
self._populate_total_interviews(n=n)
|
76
89
|
) # Populate self.total_interviews before creating tasks
|
77
90
|
|
91
|
+
# print("Interviews created")
|
92
|
+
|
78
93
|
for interview in self.total_interviews:
|
79
94
|
interviewing_task = self._build_interview_task(
|
80
95
|
interview=interview,
|
@@ -84,9 +99,11 @@ class JobsRunnerAsyncio:
|
|
84
99
|
)
|
85
100
|
tasks.append(asyncio.create_task(interviewing_task))
|
86
101
|
|
102
|
+
# print("Tasks created")
|
103
|
+
|
87
104
|
for task in asyncio.as_completed(tasks):
|
105
|
+
# print(f"Task {task} completed")
|
88
106
|
result = await task
|
89
|
-
self.jobs_runner_status.add_completed_interview(result)
|
90
107
|
yield result
|
91
108
|
|
92
109
|
def _populate_total_interviews(
|
@@ -105,8 +122,6 @@ class JobsRunnerAsyncio:
|
|
105
122
|
yield interview
|
106
123
|
|
107
124
|
async def run_async(self, cache: Optional["Cache"] = None, n: int = 1) -> Results:
|
108
|
-
"""Used for some other modules that have a non-standard way of running interviews."""
|
109
|
-
self.jobs_runner_status = JobsRunnerStatus(self, n=n)
|
110
125
|
self.cache = Cache() if cache is None else cache
|
111
126
|
data = []
|
112
127
|
async for result in self.run_async_generator(cache=self.cache, n=n):
|
@@ -142,6 +157,12 @@ class JobsRunnerAsyncio:
|
|
142
157
|
raise_validation_errors=raise_validation_errors,
|
143
158
|
)
|
144
159
|
|
160
|
+
# answer_key_names = {
|
161
|
+
# k
|
162
|
+
# for k in set(answer.keys())
|
163
|
+
# if not k.endswith("_comment") and not k.endswith("_generated_tokens")
|
164
|
+
# }
|
165
|
+
|
145
166
|
question_results = {}
|
146
167
|
for result in valid_results:
|
147
168
|
question_results[result.question_name] = result
|
@@ -153,13 +174,24 @@ class JobsRunnerAsyncio:
|
|
153
174
|
for k in answer_key_names
|
154
175
|
}
|
155
176
|
comments_dict = {
|
156
|
-
k + "_comment": question_results[k].comment for k in answer_key_names
|
177
|
+
"k" + "_comment": question_results[k].comment for k in answer_key_names
|
157
178
|
}
|
158
179
|
|
159
180
|
# we should have a valid result for each question
|
160
181
|
answer_dict = {k: answer[k] for k in answer_key_names}
|
161
182
|
assert len(valid_results) == len(answer_key_names)
|
162
183
|
|
184
|
+
# breakpoint()
|
185
|
+
# generated_tokens_dict = {
|
186
|
+
# k + "_generated_tokens": v.generated_tokens
|
187
|
+
# for k, v in zip(answer_key_names, valid_results)
|
188
|
+
# }
|
189
|
+
|
190
|
+
# comments_dict = {
|
191
|
+
# k + "_comment": v.comment for k, v in zip(answer_key_names, valid_results)
|
192
|
+
# }
|
193
|
+
# breakpoint()
|
194
|
+
|
163
195
|
# TODO: move this down into Interview
|
164
196
|
question_name_to_prompts = dict({})
|
165
197
|
for result in valid_results:
|
@@ -171,19 +203,19 @@ class JobsRunnerAsyncio:
|
|
171
203
|
|
172
204
|
prompt_dictionary = {}
|
173
205
|
for answer_key_name in answer_key_names:
|
174
|
-
prompt_dictionary[
|
175
|
-
answer_key_name
|
176
|
-
|
177
|
-
prompt_dictionary[
|
178
|
-
answer_key_name
|
179
|
-
|
206
|
+
prompt_dictionary[answer_key_name + "_user_prompt"] = (
|
207
|
+
question_name_to_prompts[answer_key_name]["user_prompt"]
|
208
|
+
)
|
209
|
+
prompt_dictionary[answer_key_name + "_system_prompt"] = (
|
210
|
+
question_name_to_prompts[answer_key_name]["system_prompt"]
|
211
|
+
)
|
180
212
|
|
181
213
|
raw_model_results_dictionary = {}
|
182
214
|
for result in valid_results:
|
183
215
|
question_name = result.question_name
|
184
|
-
raw_model_results_dictionary[
|
185
|
-
|
186
|
-
|
216
|
+
raw_model_results_dictionary[question_name + "_raw_model_response"] = (
|
217
|
+
result.raw_model_response
|
218
|
+
)
|
187
219
|
raw_model_results_dictionary[question_name + "_cost"] = result.cost
|
188
220
|
one_use_buys = (
|
189
221
|
"NA"
|
@@ -194,6 +226,7 @@ class JobsRunnerAsyncio:
|
|
194
226
|
)
|
195
227
|
raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
|
196
228
|
|
229
|
+
# breakpoint()
|
197
230
|
result = Result(
|
198
231
|
agent=interview.agent,
|
199
232
|
scenario=interview.scenario,
|
@@ -214,62 +247,6 @@ class JobsRunnerAsyncio:
|
|
214
247
|
def elapsed_time(self):
|
215
248
|
return time.monotonic() - self.start_time
|
216
249
|
|
217
|
-
def process_results(
|
218
|
-
self, raw_results: Results, cache: Cache, print_exceptions: bool
|
219
|
-
):
|
220
|
-
interview_lookup = {
|
221
|
-
hash(interview): index
|
222
|
-
for index, interview in enumerate(self.total_interviews)
|
223
|
-
}
|
224
|
-
interview_hashes = list(interview_lookup.keys())
|
225
|
-
|
226
|
-
results = Results(
|
227
|
-
survey=self.jobs.survey,
|
228
|
-
data=sorted(
|
229
|
-
raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
|
230
|
-
),
|
231
|
-
)
|
232
|
-
results.cache = cache
|
233
|
-
results.task_history = TaskHistory(
|
234
|
-
self.total_interviews, include_traceback=False
|
235
|
-
)
|
236
|
-
results.has_unfixed_exceptions = results.task_history.has_unfixed_exceptions
|
237
|
-
results.bucket_collection = self.bucket_collection
|
238
|
-
|
239
|
-
if results.has_unfixed_exceptions and print_exceptions:
|
240
|
-
from edsl.scenarios.FileStore import HTMLFileStore
|
241
|
-
from edsl.config import CONFIG
|
242
|
-
from edsl.coop.coop import Coop
|
243
|
-
|
244
|
-
msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
|
245
|
-
|
246
|
-
if len(results.task_history.indices) > 5:
|
247
|
-
msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
|
248
|
-
|
249
|
-
print(msg)
|
250
|
-
# this is where exceptions are opening up
|
251
|
-
filepath = results.task_history.html(
|
252
|
-
cta="Open report to see details.",
|
253
|
-
open_in_browser=True,
|
254
|
-
return_link=True,
|
255
|
-
)
|
256
|
-
|
257
|
-
try:
|
258
|
-
coop = Coop()
|
259
|
-
user_edsl_settings = coop.edsl_settings
|
260
|
-
remote_logging = user_edsl_settings["remote_logging"]
|
261
|
-
except Exception as e:
|
262
|
-
print(e)
|
263
|
-
remote_logging = False
|
264
|
-
if remote_logging:
|
265
|
-
filestore = HTMLFileStore(filepath)
|
266
|
-
coop_details = filestore.push(description="Error report")
|
267
|
-
print(coop_details)
|
268
|
-
|
269
|
-
print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
|
270
|
-
|
271
|
-
return results
|
272
|
-
|
273
250
|
@jupyter_nb_handler
|
274
251
|
async def run(
|
275
252
|
self,
|
@@ -282,16 +259,24 @@ class JobsRunnerAsyncio:
|
|
282
259
|
raise_validation_errors: bool = False,
|
283
260
|
) -> "Coroutine":
|
284
261
|
"""Runs a collection of interviews, handling both async and sync contexts."""
|
262
|
+
from rich.console import Console
|
285
263
|
|
264
|
+
console = Console()
|
286
265
|
self.results = []
|
287
266
|
self.start_time = time.monotonic()
|
288
267
|
self.completed = False
|
289
268
|
self.cache = cache
|
290
269
|
self.sidecar_model = sidecar_model
|
291
270
|
|
292
|
-
|
271
|
+
from edsl.results.Results import Results
|
272
|
+
from rich.live import Live
|
273
|
+
from rich.console import Console
|
293
274
|
|
294
|
-
|
275
|
+
@cache_with_timeout(1)
|
276
|
+
def generate_table():
|
277
|
+
return self.status_table(self.results, self.elapsed_time)
|
278
|
+
|
279
|
+
async def process_results(cache, progress_bar_context=None):
|
295
280
|
"""Processes results from interviews."""
|
296
281
|
async for result in self.run_async_generator(
|
297
282
|
n=n,
|
@@ -301,22 +286,112 @@ class JobsRunnerAsyncio:
|
|
301
286
|
raise_validation_errors=raise_validation_errors,
|
302
287
|
):
|
303
288
|
self.results.append(result)
|
289
|
+
if progress_bar_context:
|
290
|
+
progress_bar_context.update(generate_table())
|
304
291
|
self.completed = True
|
305
292
|
|
306
|
-
def
|
307
|
-
"""
|
308
|
-
|
293
|
+
async def update_progress_bar(progress_bar_context):
|
294
|
+
"""Updates the progress bar at fixed intervals."""
|
295
|
+
if progress_bar_context is None:
|
296
|
+
return
|
297
|
+
|
298
|
+
while True:
|
299
|
+
progress_bar_context.update(generate_table())
|
300
|
+
await asyncio.sleep(0.1) # Update interval
|
301
|
+
if self.completed:
|
302
|
+
break
|
303
|
+
|
304
|
+
@contextmanager
|
305
|
+
def conditional_context(condition, context_manager):
|
306
|
+
if condition:
|
307
|
+
with context_manager as cm:
|
308
|
+
yield cm
|
309
|
+
else:
|
310
|
+
yield
|
311
|
+
|
312
|
+
with conditional_context(
|
313
|
+
progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
|
314
|
+
) as progress_bar_context:
|
315
|
+
with cache as c:
|
316
|
+
progress_task = asyncio.create_task(
|
317
|
+
update_progress_bar(progress_bar_context)
|
318
|
+
)
|
319
|
+
|
320
|
+
try:
|
321
|
+
await asyncio.gather(
|
322
|
+
progress_task,
|
323
|
+
process_results(
|
324
|
+
cache=c, progress_bar_context=progress_bar_context
|
325
|
+
),
|
326
|
+
)
|
327
|
+
except asyncio.CancelledError:
|
328
|
+
pass
|
329
|
+
finally:
|
330
|
+
progress_task.cancel() # Cancel the progress_task when process_results is done
|
331
|
+
await progress_task
|
332
|
+
|
333
|
+
await asyncio.sleep(1) # short delay to show the final status
|
334
|
+
|
335
|
+
if progress_bar_context:
|
336
|
+
progress_bar_context.update(generate_table())
|
337
|
+
|
338
|
+
# puts results in the same order as the total interviews
|
339
|
+
interview_lookup = {
|
340
|
+
hash(interview): index
|
341
|
+
for index, interview in enumerate(self.total_interviews)
|
342
|
+
}
|
343
|
+
interview_hashes = list(interview_lookup.keys())
|
344
|
+
self.results = sorted(
|
345
|
+
self.results, key=lambda x: interview_hashes.index(x.interview_hash)
|
346
|
+
)
|
347
|
+
|
348
|
+
results = Results(survey=self.jobs.survey, data=self.results)
|
349
|
+
task_history = TaskHistory(self.total_interviews, include_traceback=False)
|
350
|
+
results.task_history = task_history
|
309
351
|
|
310
|
-
|
311
|
-
|
312
|
-
progress_thread.start()
|
352
|
+
results.failed_questions = {}
|
353
|
+
results.has_exceptions = task_history.has_exceptions
|
313
354
|
|
314
|
-
|
315
|
-
|
355
|
+
# breakpoint()
|
356
|
+
results.bucket_collection = self.bucket_collection
|
316
357
|
|
317
|
-
if
|
318
|
-
|
358
|
+
if results.has_exceptions:
|
359
|
+
# put the failed interviews in the results object as a list
|
360
|
+
failed_interviews = [
|
361
|
+
interview.duplicate(
|
362
|
+
iteration=interview.iteration, cache=interview.cache
|
363
|
+
)
|
364
|
+
for interview in self.total_interviews
|
365
|
+
if interview.has_exceptions
|
366
|
+
]
|
319
367
|
|
320
|
-
|
321
|
-
|
322
|
-
|
368
|
+
failed_questions = {}
|
369
|
+
for interview in self.total_interviews:
|
370
|
+
if interview.has_exceptions:
|
371
|
+
index = interview_lookup[hash(interview)]
|
372
|
+
failed_questions[index] = interview.failed_questions
|
373
|
+
|
374
|
+
results.failed_questions = failed_questions
|
375
|
+
|
376
|
+
from edsl.jobs.Jobs import Jobs
|
377
|
+
|
378
|
+
results.failed_jobs = Jobs.from_interviews(
|
379
|
+
[interview for interview in failed_interviews]
|
380
|
+
)
|
381
|
+
if print_exceptions:
|
382
|
+
msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
|
383
|
+
|
384
|
+
if len(results.task_history.indices) > 5:
|
385
|
+
msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
|
386
|
+
|
387
|
+
shared_globals["edsl_runner_exceptions"] = task_history
|
388
|
+
print(msg)
|
389
|
+
# this is where exceptions are opening up
|
390
|
+
task_history.html(
|
391
|
+
cta="Open report to see details.", open_in_browser=True
|
392
|
+
)
|
393
|
+
print(
|
394
|
+
"Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"
|
395
|
+
)
|
396
|
+
|
397
|
+
return results
|