edsl 0.1.33.dev2__py3-none-any.whl → 0.1.33.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +9 -3
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +6 -3
- edsl/agents/InvigilatorBase.py +8 -27
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +101 -29
- edsl/config.py +26 -34
- edsl/coop/coop.py +11 -2
- edsl/data_transfer_models.py +27 -73
- edsl/enums.py +2 -0
- edsl/inference_services/GoogleService.py +1 -1
- edsl/inference_services/InferenceServiceABC.py +44 -13
- edsl/inference_services/OpenAIService.py +7 -4
- edsl/inference_services/TestService.py +24 -15
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +18 -8
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +64 -10
- edsl/jobs/interviews/Interview.py +115 -47
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +86 -161
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/TaskHistory.py +17 -0
- edsl/language_models/LanguageModel.py +26 -31
- edsl/language_models/registry.py +13 -9
- edsl/questions/QuestionBase.py +64 -16
- edsl/questions/QuestionBudget.py +93 -41
- edsl/questions/QuestionFreeText.py +6 -0
- edsl/questions/QuestionMultipleChoice.py +11 -26
- edsl/questions/QuestionNumerical.py +5 -4
- edsl/questions/Quick.py +41 -0
- edsl/questions/ResponseValidatorABC.py +6 -5
- edsl/questions/derived/QuestionLinearScale.py +4 -1
- edsl/questions/derived/QuestionTopK.py +4 -1
- edsl/questions/derived/QuestionYesNo.py +8 -2
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/results/DatasetExportMixin.py +5 -1
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +4 -1
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +86 -21
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +13 -0
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +10 -1
- edsl/surveys/RuleCollection.py +19 -3
- edsl/surveys/Survey.py +7 -0
- edsl/templates/error_reporting/interview_details.html +6 -1
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/METADATA +2 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/RECORD +61 -55
- edsl/jobs/interviews/retry_management.py +0 -39
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/WHEEL +0 -0
@@ -4,10 +4,18 @@ from __future__ import annotations
|
|
4
4
|
import asyncio
|
5
5
|
from typing import Any, Type, List, Generator, Optional, Union
|
6
6
|
|
7
|
+
from tenacity import (
|
8
|
+
retry,
|
9
|
+
stop_after_attempt,
|
10
|
+
wait_exponential,
|
11
|
+
retry_if_exception_type,
|
12
|
+
RetryError,
|
13
|
+
)
|
14
|
+
|
7
15
|
from edsl import CONFIG
|
8
16
|
from edsl.surveys.base import EndOfSurvey
|
9
17
|
from edsl.exceptions import QuestionAnswerValidationError
|
10
|
-
from edsl.exceptions import
|
18
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
11
19
|
from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
12
20
|
|
13
21
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
@@ -15,21 +23,18 @@ from edsl.jobs.Answers import Answers
|
|
15
23
|
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
16
24
|
from edsl.jobs.tasks.TaskCreators import TaskCreators
|
17
25
|
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
18
|
-
from edsl.jobs.interviews.
|
26
|
+
from edsl.jobs.interviews.InterviewExceptionCollection import (
|
19
27
|
InterviewExceptionCollection,
|
20
28
|
)
|
21
|
-
|
22
|
-
from edsl.jobs.interviews.retry_management import retry_strategy
|
29
|
+
|
23
30
|
from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
24
31
|
|
25
32
|
from edsl.surveys.base import EndOfSurvey
|
26
33
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
27
34
|
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
28
|
-
from edsl.jobs.interviews.retry_management import retry_strategy
|
29
35
|
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
30
36
|
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
31
37
|
|
32
|
-
from edsl.exceptions import QuestionAnswerValidationError
|
33
38
|
|
34
39
|
from edsl import Agent, Survey, Scenario, Cache
|
35
40
|
from edsl.language_models import LanguageModel
|
@@ -39,8 +44,11 @@ from edsl.agents.InvigilatorBase import InvigilatorBase
|
|
39
44
|
from edsl.exceptions.language_models import LanguageModelNoResponseError
|
40
45
|
|
41
46
|
|
42
|
-
|
43
|
-
|
47
|
+
from edsl import CONFIG
|
48
|
+
|
49
|
+
EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
|
50
|
+
EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
|
51
|
+
EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
|
44
52
|
|
45
53
|
|
46
54
|
class Interview(InterviewStatusMixin):
|
@@ -97,14 +105,17 @@ class Interview(InterviewStatusMixin):
|
|
97
105
|
self.debug = debug
|
98
106
|
self.iteration = iteration
|
99
107
|
self.cache = cache
|
100
|
-
self.answers: dict[
|
101
|
-
|
102
|
-
) # will get filled in as interview progresses
|
108
|
+
self.answers: dict[
|
109
|
+
str, str
|
110
|
+
] = Answers() # will get filled in as interview progresses
|
103
111
|
self.sidecar_model = sidecar_model
|
104
112
|
|
113
|
+
# self.stop_on_exception = False
|
114
|
+
|
105
115
|
# Trackers
|
106
116
|
self.task_creators = TaskCreators() # tracks the task creators
|
107
117
|
self.exceptions = InterviewExceptionCollection()
|
118
|
+
|
108
119
|
self._task_status_log_dict = InterviewStatusLog()
|
109
120
|
self.skip_retry = skip_retry
|
110
121
|
self.raise_validation_errors = raise_validation_errors
|
@@ -257,44 +268,80 @@ class Interview(InterviewStatusMixin):
|
|
257
268
|
) -> "AgentResponseDict":
|
258
269
|
"""Answer a question and records the task."""
|
259
270
|
|
260
|
-
|
271
|
+
had_language_model_no_response_error = False
|
261
272
|
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
273
|
+
@retry(
|
274
|
+
stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
|
275
|
+
wait=wait_exponential(
|
276
|
+
multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
|
277
|
+
),
|
278
|
+
retry=retry_if_exception_type(LanguageModelNoResponseError),
|
279
|
+
reraise=True,
|
280
|
+
)
|
281
|
+
async def attempt_answer():
|
282
|
+
nonlocal had_language_model_no_response_error
|
266
283
|
|
267
|
-
|
268
|
-
response: EDSLResultObjectInput = await invigilator.async_answer_question()
|
269
|
-
if response.validated:
|
270
|
-
self.answers.add_answer(response=response, question=question)
|
271
|
-
self._cancel_skipped_questions(question)
|
272
|
-
else:
|
273
|
-
if (
|
274
|
-
hasattr(response, "exception_occurred")
|
275
|
-
and response.exception_occurred
|
276
|
-
):
|
277
|
-
raise response.exception_occurred
|
284
|
+
invigilator = self._get_invigilator(question)
|
278
285
|
|
279
|
-
|
280
|
-
|
281
|
-
|
286
|
+
if self._skip_this_question(question):
|
287
|
+
return invigilator.get_failed_task_result(
|
288
|
+
failure_reason="Question skipped."
|
289
|
+
)
|
282
290
|
|
283
|
-
|
284
|
-
|
285
|
-
|
291
|
+
try:
|
292
|
+
response: EDSLResultObjectInput = (
|
293
|
+
await invigilator.async_answer_question()
|
294
|
+
)
|
295
|
+
if response.validated:
|
296
|
+
self.answers.add_answer(response=response, question=question)
|
297
|
+
self._cancel_skipped_questions(question)
|
298
|
+
else:
|
299
|
+
if (
|
300
|
+
hasattr(response, "exception_occurred")
|
301
|
+
and response.exception_occurred
|
302
|
+
):
|
303
|
+
raise response.exception_occurred
|
304
|
+
|
305
|
+
except QuestionAnswerValidationError as e:
|
306
|
+
self._handle_exception(e, invigilator, task)
|
307
|
+
return invigilator.get_failed_task_result(
|
308
|
+
failure_reason="Question answer validation failed."
|
309
|
+
)
|
286
310
|
|
287
|
-
|
288
|
-
|
289
|
-
|
311
|
+
except asyncio.TimeoutError as e:
|
312
|
+
self._handle_exception(e, invigilator, task)
|
313
|
+
had_language_model_no_response_error = True
|
314
|
+
raise LanguageModelNoResponseError(
|
315
|
+
f"Language model timed out for question '{question.question_name}.'"
|
316
|
+
)
|
290
317
|
|
291
|
-
|
318
|
+
except Exception as e:
|
319
|
+
self._handle_exception(e, invigilator, task)
|
292
320
|
|
293
|
-
|
294
|
-
|
295
|
-
|
321
|
+
if "response" not in locals():
|
322
|
+
had_language_model_no_response_error = True
|
323
|
+
raise LanguageModelNoResponseError(
|
324
|
+
f"Language model did not return a response for question '{question.question_name}.'"
|
325
|
+
)
|
326
|
+
|
327
|
+
# if it gets here, it means the no response error was fixed
|
328
|
+
if (
|
329
|
+
question.question_name in self.exceptions
|
330
|
+
and had_language_model_no_response_error
|
331
|
+
):
|
332
|
+
self.exceptions.record_fixed_question(question.question_name)
|
296
333
|
|
297
|
-
|
334
|
+
return response
|
335
|
+
|
336
|
+
try:
|
337
|
+
return await attempt_answer()
|
338
|
+
except RetryError as retry_error:
|
339
|
+
# All retries have failed for LanguageModelNoResponseError
|
340
|
+
original_error = retry_error.last_attempt.exception()
|
341
|
+
self._handle_exception(
|
342
|
+
original_error, self._get_invigilator(question), task
|
343
|
+
)
|
344
|
+
raise original_error # Re-raise the original error after handling
|
298
345
|
|
299
346
|
def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
|
300
347
|
"""Return an invigilator for the given question.
|
@@ -334,14 +381,32 @@ class Interview(InterviewStatusMixin):
|
|
334
381
|
def _handle_exception(
|
335
382
|
self, e: Exception, invigilator: "InvigilatorBase", task=None
|
336
383
|
):
|
384
|
+
import copy
|
385
|
+
|
386
|
+
# breakpoint()
|
387
|
+
|
388
|
+
answers = copy.copy(self.answers)
|
337
389
|
exception_entry = InterviewExceptionEntry(
|
338
390
|
exception=e,
|
339
391
|
invigilator=invigilator,
|
392
|
+
answers=answers,
|
340
393
|
)
|
341
394
|
if task:
|
342
395
|
task.task_status = TaskStatus.FAILED
|
343
396
|
self.exceptions.add(invigilator.question.question_name, exception_entry)
|
344
397
|
|
398
|
+
if self.raise_validation_errors:
|
399
|
+
if isinstance(e, QuestionAnswerValidationError):
|
400
|
+
raise e
|
401
|
+
|
402
|
+
if hasattr(self, "stop_on_exception"):
|
403
|
+
stop_on_exception = self.stop_on_exception
|
404
|
+
else:
|
405
|
+
stop_on_exception = False
|
406
|
+
|
407
|
+
if stop_on_exception:
|
408
|
+
raise e
|
409
|
+
|
345
410
|
def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
|
346
411
|
"""Cancel the tasks for questions that are skipped.
|
347
412
|
|
@@ -353,11 +418,11 @@ class Interview(InterviewStatusMixin):
|
|
353
418
|
"""
|
354
419
|
current_question_index: int = self.to_index[current_question.question_name]
|
355
420
|
|
356
|
-
next_question: Union[
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
421
|
+
next_question: Union[
|
422
|
+
int, EndOfSurvey
|
423
|
+
] = self.survey.rule_collection.next_question(
|
424
|
+
q_now=current_question_index,
|
425
|
+
answers=self.answers | self.scenario | self.agent["traits"],
|
361
426
|
)
|
362
427
|
|
363
428
|
next_question_index = next_question.next_q
|
@@ -411,6 +476,7 @@ class Interview(InterviewStatusMixin):
|
|
411
476
|
asyncio.exceptions.CancelledError
|
412
477
|
"""
|
413
478
|
self.sidecar_model = sidecar_model
|
479
|
+
self.stop_on_exception = stop_on_exception
|
414
480
|
|
415
481
|
# if no model bucket is passed, create an 'infinity' bucket with no rate limits
|
416
482
|
if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
|
@@ -424,7 +490,9 @@ class Interview(InterviewStatusMixin):
|
|
424
490
|
self.invigilators = [
|
425
491
|
self._get_invigilator(question) for question in self.survey.questions
|
426
492
|
]
|
427
|
-
await asyncio.gather(
|
493
|
+
await asyncio.gather(
|
494
|
+
*self.tasks, return_exceptions=not stop_on_exception
|
495
|
+
) # not stop_on_exception)
|
428
496
|
self.answers.replace_missing_answers_with_none(self.survey)
|
429
497
|
valid_results = list(self._extract_valid_results())
|
430
498
|
return self.answers, valid_results
|
@@ -6,6 +6,22 @@ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
|
6
6
|
class InterviewExceptionCollection(UserDict):
|
7
7
|
"""A collection of exceptions that occurred during the interview."""
|
8
8
|
|
9
|
+
def __init__(self):
|
10
|
+
super().__init__()
|
11
|
+
self.fixed = set()
|
12
|
+
|
13
|
+
def unfixed_exceptions(self) -> list:
|
14
|
+
"""Return a list of unfixed exceptions."""
|
15
|
+
return {k: v for k, v in self.data.items() if k not in self.fixed}
|
16
|
+
|
17
|
+
def num_unfixed(self) -> list:
|
18
|
+
"""Return a list of unfixed questions."""
|
19
|
+
return len([k for k in self.data.keys() if k not in self.fixed])
|
20
|
+
|
21
|
+
def record_fixed_question(self, question_name: str) -> None:
|
22
|
+
"""Record that a question has been fixed."""
|
23
|
+
self.fixed.add(question_name)
|
24
|
+
|
9
25
|
def add(self, question_name: str, entry: InterviewExceptionEntry) -> None:
|
10
26
|
"""Add an exception entry to the collection."""
|
11
27
|
question_name = question_name
|
@@ -15,12 +15,14 @@ class InterviewExceptionEntry:
|
|
15
15
|
# failed_question: FailedQuestion,
|
16
16
|
invigilator: "Invigilator",
|
17
17
|
traceback_format="text",
|
18
|
+
answers=None,
|
18
19
|
):
|
19
20
|
self.time = datetime.datetime.now().isoformat()
|
20
21
|
self.exception = exception
|
21
22
|
# self.failed_question = failed_question
|
22
23
|
self.invigilator = invigilator
|
23
24
|
self.traceback_format = traceback_format
|
25
|
+
self.answers = answers
|
24
26
|
|
25
27
|
@property
|
26
28
|
def question_type(self):
|
@@ -3,40 +3,25 @@ import time
|
|
3
3
|
import math
|
4
4
|
import asyncio
|
5
5
|
import functools
|
6
|
+
import threading
|
6
7
|
from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
|
7
8
|
from contextlib import contextmanager
|
8
9
|
from collections import UserList
|
9
10
|
|
11
|
+
from edsl.results.Results import Results
|
12
|
+
from rich.live import Live
|
13
|
+
from rich.console import Console
|
14
|
+
|
10
15
|
from edsl import shared_globals
|
11
16
|
from edsl.jobs.interviews.Interview import Interview
|
12
|
-
from edsl.jobs.runners.
|
17
|
+
from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
|
18
|
+
|
13
19
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
14
20
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
15
21
|
from edsl.utilities.decorators import jupyter_nb_handler
|
16
22
|
from edsl.data.Cache import Cache
|
17
23
|
from edsl.results.Result import Result
|
18
24
|
from edsl.results.Results import Results
|
19
|
-
from edsl.jobs.FailedQuestion import FailedQuestion
|
20
|
-
|
21
|
-
|
22
|
-
def cache_with_timeout(timeout):
|
23
|
-
""" "Used to keep the generate table from being run too frequetly."""
|
24
|
-
|
25
|
-
def decorator(func):
|
26
|
-
cached_result = {}
|
27
|
-
last_computation_time = [0] # Using list to store mutable value
|
28
|
-
|
29
|
-
@functools.wraps(func)
|
30
|
-
def wrapper(*args, **kwargs):
|
31
|
-
current_time = time.time()
|
32
|
-
if (current_time - last_computation_time[0]) >= timeout:
|
33
|
-
cached_result["value"] = func(*args, **kwargs)
|
34
|
-
last_computation_time[0] = current_time
|
35
|
-
return cached_result["value"]
|
36
|
-
|
37
|
-
return wrapper
|
38
|
-
|
39
|
-
return decorator
|
40
25
|
|
41
26
|
|
42
27
|
class StatusTracker(UserList):
|
@@ -48,7 +33,7 @@ class StatusTracker(UserList):
|
|
48
33
|
return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
|
49
34
|
|
50
35
|
|
51
|
-
class JobsRunnerAsyncio
|
36
|
+
class JobsRunnerAsyncio:
|
52
37
|
"""A class for running a collection of interviews asynchronously.
|
53
38
|
|
54
39
|
It gets instaniated from a Jobs object.
|
@@ -57,11 +42,12 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
57
42
|
|
58
43
|
def __init__(self, jobs: "Jobs"):
|
59
44
|
self.jobs = jobs
|
60
|
-
# this creates the interviews, which can take a while
|
61
45
|
self.interviews: List["Interview"] = jobs.interviews()
|
62
46
|
self.bucket_collection: "BucketCollection" = jobs.bucket_collection
|
63
47
|
self.total_interviews: List["Interview"] = []
|
64
48
|
|
49
|
+
# self.jobs_runner_status = JobsRunnerStatus(self, n=1)
|
50
|
+
|
65
51
|
async def run_async_generator(
|
66
52
|
self,
|
67
53
|
cache: "Cache",
|
@@ -79,6 +65,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
79
65
|
:param stop_on_exception: Whether to stop the interview if an exception is raised
|
80
66
|
:param sidecar_model: a language model to use in addition to the interview's model
|
81
67
|
:param total_interviews: A list of interviews to run can be provided instead.
|
68
|
+
:param raise_validation_errors: Whether to raise validation errors
|
82
69
|
"""
|
83
70
|
tasks = []
|
84
71
|
if total_interviews: # was already passed in total interviews
|
@@ -88,8 +75,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
88
75
|
self._populate_total_interviews(n=n)
|
89
76
|
) # Populate self.total_interviews before creating tasks
|
90
77
|
|
91
|
-
# print("Interviews created")
|
92
|
-
|
93
78
|
for interview in self.total_interviews:
|
94
79
|
interviewing_task = self._build_interview_task(
|
95
80
|
interview=interview,
|
@@ -99,11 +84,9 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
99
84
|
)
|
100
85
|
tasks.append(asyncio.create_task(interviewing_task))
|
101
86
|
|
102
|
-
# print("Tasks created")
|
103
|
-
|
104
87
|
for task in asyncio.as_completed(tasks):
|
105
|
-
# print(f"Task {task} completed")
|
106
88
|
result = await task
|
89
|
+
self.jobs_runner_status.add_completed_interview(result)
|
107
90
|
yield result
|
108
91
|
|
109
92
|
def _populate_total_interviews(
|
@@ -122,6 +105,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
122
105
|
yield interview
|
123
106
|
|
124
107
|
async def run_async(self, cache: Optional["Cache"] = None, n: int = 1) -> Results:
|
108
|
+
"""Used for some other modules that have a non-standard way of running interviews."""
|
109
|
+
self.jobs_runner_status = JobsRunnerStatus(self, n=n)
|
125
110
|
self.cache = Cache() if cache is None else cache
|
126
111
|
data = []
|
127
112
|
async for result in self.run_async_generator(cache=self.cache, n=n):
|
@@ -157,12 +142,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
157
142
|
raise_validation_errors=raise_validation_errors,
|
158
143
|
)
|
159
144
|
|
160
|
-
# answer_key_names = {
|
161
|
-
# k
|
162
|
-
# for k in set(answer.keys())
|
163
|
-
# if not k.endswith("_comment") and not k.endswith("_generated_tokens")
|
164
|
-
# }
|
165
|
-
|
166
145
|
question_results = {}
|
167
146
|
for result in valid_results:
|
168
147
|
question_results[result.question_name] = result
|
@@ -174,24 +153,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
174
153
|
for k in answer_key_names
|
175
154
|
}
|
176
155
|
comments_dict = {
|
177
|
-
|
156
|
+
k + "_comment": question_results[k].comment for k in answer_key_names
|
178
157
|
}
|
179
158
|
|
180
159
|
# we should have a valid result for each question
|
181
160
|
answer_dict = {k: answer[k] for k in answer_key_names}
|
182
161
|
assert len(valid_results) == len(answer_key_names)
|
183
162
|
|
184
|
-
# breakpoint()
|
185
|
-
# generated_tokens_dict = {
|
186
|
-
# k + "_generated_tokens": v.generated_tokens
|
187
|
-
# for k, v in zip(answer_key_names, valid_results)
|
188
|
-
# }
|
189
|
-
|
190
|
-
# comments_dict = {
|
191
|
-
# k + "_comment": v.comment for k, v in zip(answer_key_names, valid_results)
|
192
|
-
# }
|
193
|
-
# breakpoint()
|
194
|
-
|
195
163
|
# TODO: move this down into Interview
|
196
164
|
question_name_to_prompts = dict({})
|
197
165
|
for result in valid_results:
|
@@ -226,7 +194,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
226
194
|
)
|
227
195
|
raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
|
228
196
|
|
229
|
-
# breakpoint()
|
230
197
|
result = Result(
|
231
198
|
agent=interview.agent,
|
232
199
|
scenario=interview.scenario,
|
@@ -247,6 +214,62 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
247
214
|
def elapsed_time(self):
|
248
215
|
return time.monotonic() - self.start_time
|
249
216
|
|
217
|
+
def process_results(
|
218
|
+
self, raw_results: Results, cache: Cache, print_exceptions: bool
|
219
|
+
):
|
220
|
+
interview_lookup = {
|
221
|
+
hash(interview): index
|
222
|
+
for index, interview in enumerate(self.total_interviews)
|
223
|
+
}
|
224
|
+
interview_hashes = list(interview_lookup.keys())
|
225
|
+
|
226
|
+
results = Results(
|
227
|
+
survey=self.jobs.survey,
|
228
|
+
data=sorted(
|
229
|
+
raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
|
230
|
+
),
|
231
|
+
)
|
232
|
+
results.cache = cache
|
233
|
+
results.task_history = TaskHistory(
|
234
|
+
self.total_interviews, include_traceback=False
|
235
|
+
)
|
236
|
+
results.has_unfixed_exceptions = results.task_history.has_unfixed_exceptions
|
237
|
+
results.bucket_collection = self.bucket_collection
|
238
|
+
|
239
|
+
if results.has_unfixed_exceptions and print_exceptions:
|
240
|
+
from edsl.scenarios.FileStore import HTMLFileStore
|
241
|
+
from edsl.config import CONFIG
|
242
|
+
from edsl.coop.coop import Coop
|
243
|
+
|
244
|
+
msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
|
245
|
+
|
246
|
+
if len(results.task_history.indices) > 5:
|
247
|
+
msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
|
248
|
+
|
249
|
+
print(msg)
|
250
|
+
# this is where exceptions are opening up
|
251
|
+
filepath = results.task_history.html(
|
252
|
+
cta="Open report to see details.",
|
253
|
+
open_in_browser=True,
|
254
|
+
return_link=True,
|
255
|
+
)
|
256
|
+
|
257
|
+
try:
|
258
|
+
coop = Coop()
|
259
|
+
user_edsl_settings = coop.edsl_settings
|
260
|
+
remote_logging = user_edsl_settings["remote_logging"]
|
261
|
+
except Exception as e:
|
262
|
+
print(e)
|
263
|
+
remote_logging = False
|
264
|
+
if remote_logging:
|
265
|
+
filestore = HTMLFileStore(filepath)
|
266
|
+
coop_details = filestore.push(description="Error report")
|
267
|
+
print(coop_details)
|
268
|
+
|
269
|
+
print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
|
270
|
+
|
271
|
+
return results
|
272
|
+
|
250
273
|
@jupyter_nb_handler
|
251
274
|
async def run(
|
252
275
|
self,
|
@@ -259,24 +282,16 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
259
282
|
raise_validation_errors: bool = False,
|
260
283
|
) -> "Coroutine":
|
261
284
|
"""Runs a collection of interviews, handling both async and sync contexts."""
|
262
|
-
from rich.console import Console
|
263
285
|
|
264
|
-
console = Console()
|
265
286
|
self.results = []
|
266
287
|
self.start_time = time.monotonic()
|
267
288
|
self.completed = False
|
268
289
|
self.cache = cache
|
269
290
|
self.sidecar_model = sidecar_model
|
270
291
|
|
271
|
-
|
272
|
-
from rich.live import Live
|
273
|
-
from rich.console import Console
|
274
|
-
|
275
|
-
@cache_with_timeout(1)
|
276
|
-
def generate_table():
|
277
|
-
return self.status_table(self.results, self.elapsed_time)
|
292
|
+
self.jobs_runner_status = JobsRunnerStatus(self, n=n)
|
278
293
|
|
279
|
-
async def process_results(cache
|
294
|
+
async def process_results(cache):
|
280
295
|
"""Processes results from interviews."""
|
281
296
|
async for result in self.run_async_generator(
|
282
297
|
n=n,
|
@@ -286,112 +301,22 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
286
301
|
raise_validation_errors=raise_validation_errors,
|
287
302
|
):
|
288
303
|
self.results.append(result)
|
289
|
-
if progress_bar_context:
|
290
|
-
progress_bar_context.update(generate_table())
|
291
304
|
self.completed = True
|
292
305
|
|
293
|
-
|
294
|
-
"""
|
295
|
-
|
296
|
-
return
|
297
|
-
|
298
|
-
while True:
|
299
|
-
progress_bar_context.update(generate_table())
|
300
|
-
await asyncio.sleep(0.1) # Update interval
|
301
|
-
if self.completed:
|
302
|
-
break
|
303
|
-
|
304
|
-
@contextmanager
|
305
|
-
def conditional_context(condition, context_manager):
|
306
|
-
if condition:
|
307
|
-
with context_manager as cm:
|
308
|
-
yield cm
|
309
|
-
else:
|
310
|
-
yield
|
311
|
-
|
312
|
-
with conditional_context(
|
313
|
-
progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
|
314
|
-
) as progress_bar_context:
|
315
|
-
with cache as c:
|
316
|
-
progress_task = asyncio.create_task(
|
317
|
-
update_progress_bar(progress_bar_context)
|
318
|
-
)
|
319
|
-
|
320
|
-
try:
|
321
|
-
await asyncio.gather(
|
322
|
-
progress_task,
|
323
|
-
process_results(
|
324
|
-
cache=c, progress_bar_context=progress_bar_context
|
325
|
-
),
|
326
|
-
)
|
327
|
-
except asyncio.CancelledError:
|
328
|
-
pass
|
329
|
-
finally:
|
330
|
-
progress_task.cancel() # Cancel the progress_task when process_results is done
|
331
|
-
await progress_task
|
332
|
-
|
333
|
-
await asyncio.sleep(1) # short delay to show the final status
|
334
|
-
|
335
|
-
if progress_bar_context:
|
336
|
-
progress_bar_context.update(generate_table())
|
337
|
-
|
338
|
-
# puts results in the same order as the total interviews
|
339
|
-
interview_lookup = {
|
340
|
-
hash(interview): index
|
341
|
-
for index, interview in enumerate(self.total_interviews)
|
342
|
-
}
|
343
|
-
interview_hashes = list(interview_lookup.keys())
|
344
|
-
self.results = sorted(
|
345
|
-
self.results, key=lambda x: interview_hashes.index(x.interview_hash)
|
346
|
-
)
|
347
|
-
|
348
|
-
results = Results(survey=self.jobs.survey, data=self.results)
|
349
|
-
task_history = TaskHistory(self.total_interviews, include_traceback=False)
|
350
|
-
results.task_history = task_history
|
351
|
-
|
352
|
-
results.failed_questions = {}
|
353
|
-
results.has_exceptions = task_history.has_exceptions
|
354
|
-
|
355
|
-
# breakpoint()
|
356
|
-
results.bucket_collection = self.bucket_collection
|
357
|
-
|
358
|
-
if results.has_exceptions:
|
359
|
-
# put the failed interviews in the results object as a list
|
360
|
-
failed_interviews = [
|
361
|
-
interview.duplicate(
|
362
|
-
iteration=interview.iteration, cache=interview.cache
|
363
|
-
)
|
364
|
-
for interview in self.total_interviews
|
365
|
-
if interview.has_exceptions
|
366
|
-
]
|
367
|
-
|
368
|
-
failed_questions = {}
|
369
|
-
for interview in self.total_interviews:
|
370
|
-
if interview.has_exceptions:
|
371
|
-
index = interview_lookup[hash(interview)]
|
372
|
-
failed_questions[index] = interview.failed_questions
|
306
|
+
def run_progress_bar():
|
307
|
+
"""Runs the progress bar in a separate thread."""
|
308
|
+
self.jobs_runner_status.update_progress()
|
373
309
|
|
374
|
-
|
310
|
+
if progress_bar:
|
311
|
+
progress_thread = threading.Thread(target=run_progress_bar)
|
312
|
+
progress_thread.start()
|
375
313
|
|
376
|
-
|
314
|
+
with cache as c:
|
315
|
+
await process_results(cache=c)
|
377
316
|
|
378
|
-
|
379
|
-
|
380
|
-
)
|
381
|
-
if print_exceptions:
|
382
|
-
msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
|
383
|
-
|
384
|
-
if len(results.task_history.indices) > 5:
|
385
|
-
msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
|
386
|
-
|
387
|
-
shared_globals["edsl_runner_exceptions"] = task_history
|
388
|
-
print(msg)
|
389
|
-
# this is where exceptions are opening up
|
390
|
-
task_history.html(
|
391
|
-
cta="Open report to see details.", open_in_browser=True
|
392
|
-
)
|
393
|
-
print(
|
394
|
-
"Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"
|
395
|
-
)
|
317
|
+
if progress_bar:
|
318
|
+
progress_thread.join()
|
396
319
|
|
397
|
-
return
|
320
|
+
return self.process_results(
|
321
|
+
raw_results=self.results, cache=cache, print_exceptions=print_exceptions
|
322
|
+
)
|