edsl 0.1.31.dev3__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +7 -2
- edsl/agents/PromptConstructionMixin.py +35 -15
- edsl/config.py +15 -1
- edsl/conjure/Conjure.py +6 -0
- edsl/coop/coop.py +4 -0
- edsl/data/CacheHandler.py +3 -4
- edsl/enums.py +5 -0
- edsl/exceptions/general.py +10 -8
- edsl/inference_services/AwsBedrock.py +110 -0
- edsl/inference_services/AzureAI.py +197 -0
- edsl/inference_services/DeepInfraService.py +6 -91
- edsl/inference_services/GroqService.py +18 -0
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +68 -21
- edsl/inference_services/models_available_cache.py +31 -0
- edsl/inference_services/registry.py +14 -1
- edsl/jobs/Jobs.py +103 -21
- edsl/jobs/buckets/TokenBucket.py +12 -4
- edsl/jobs/interviews/Interview.py +31 -9
- edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -33
- edsl/jobs/interviews/interview_exception_tracking.py +68 -10
- edsl/jobs/runners/JobsRunnerAsyncio.py +112 -81
- edsl/jobs/runners/JobsRunnerStatusData.py +0 -237
- edsl/jobs/runners/JobsRunnerStatusMixin.py +291 -35
- edsl/jobs/tasks/TaskCreators.py +8 -2
- edsl/jobs/tasks/TaskHistory.py +145 -1
- edsl/language_models/LanguageModel.py +62 -41
- edsl/language_models/registry.py +4 -0
- edsl/questions/QuestionBudget.py +0 -1
- edsl/questions/QuestionCheckBox.py +0 -1
- edsl/questions/QuestionExtract.py +0 -1
- edsl/questions/QuestionFreeText.py +2 -9
- edsl/questions/QuestionList.py +0 -1
- edsl/questions/QuestionMultipleChoice.py +1 -2
- edsl/questions/QuestionNumerical.py +0 -1
- edsl/questions/QuestionRank.py +0 -1
- edsl/results/DatasetExportMixin.py +33 -3
- edsl/scenarios/Scenario.py +14 -0
- edsl/scenarios/ScenarioList.py +216 -13
- edsl/scenarios/ScenarioListExportMixin.py +15 -4
- edsl/scenarios/ScenarioListPdfMixin.py +3 -0
- edsl/surveys/Rule.py +5 -2
- edsl/surveys/Survey.py +84 -1
- edsl/surveys/SurveyQualtricsImport.py +213 -0
- edsl/utilities/utilities.py +31 -0
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/METADATA +5 -1
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/RECORD +52 -46
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
--- a/edsl/jobs/interviews/Interview.py
+++ b/edsl/jobs/interviews/Interview.py
@@ -14,8 +14,8 @@ from edsl.jobs.tasks.TaskCreators import TaskCreators
 from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
 from edsl.jobs.interviews.interview_exception_tracking import (
     InterviewExceptionCollection,
-    InterviewExceptionEntry,
 )
+from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
 from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
 from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
@@ -44,6 +44,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         iteration: int = 0,
         cache: Optional["Cache"] = None,
         sidecar_model: Optional["LanguageModel"] = None,
+        skip_retry=False,
     ):
         """Initialize the Interview instance.

@@ -87,6 +88,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         self.task_creators = TaskCreators()  # tracks the task creators
         self.exceptions = InterviewExceptionCollection()
         self._task_status_log_dict = InterviewStatusLog()
+        self.skip_retry = skip_retry

         # dictionary mapping question names to their index in the survey.
         self.to_index = {
@@ -94,6 +96,30 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
             for index, question_name in enumerate(self.survey.question_names)
         }

+    def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
+        """Return a dictionary representation of the Interview instance.
+        This is just for hashing purposes.
+
+        >>> i = Interview.example()
+        >>> hash(i)
+        1646262796627658719
+        """
+        d = {
+            "agent": self.agent._to_dict(),
+            "survey": self.survey._to_dict(),
+            "scenario": self.scenario._to_dict(),
+            "model": self.model._to_dict(),
+            "iteration": self.iteration,
+            "exceptions": {},
+        }
+        if include_exceptions:
+            d["exceptions"] = self.exceptions.to_dict()
+
+    def __hash__(self) -> int:
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self._to_dict())
+
     async def async_conduct_interview(
         self,
         *,
@@ -134,8 +160,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         <BLANKLINE>

         >>> i.exceptions
-        {'q0':
-
+        {'q0': ...
         >>> i = Interview.example()
         >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
         Traceback (most recent call last):
@@ -204,13 +229,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         {}
         >>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
         >>> i.exceptions
-        {'q0':
+        {'q0': ...
         """
-        exception_entry = InterviewExceptionEntry(
-            exception=repr(exception),
-            time=time.time(),
-            traceback=traceback.format_exc(),
-        )
+        exception_entry = InterviewExceptionEntry(exception)
         self.exceptions.add(task.get_name(), exception_entry)

     @property
@@ -251,6 +272,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
             model=self.model,
             iteration=iteration,
             cache=cache,
+            skip_retry=self.skip_retry,
         )

     @classmethod
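The new `_to_dict` / `__hash__` pair keys an interview by its serialized parts via `dict_hash` from `edsl.utilities.utilities`, whose implementation is not shown in this diff. A minimal sketch of how such a helper could work, assuming it canonicalizes the dictionary to JSON before hashing (the real edsl helper may differ):

```python
import hashlib
import json


def dict_hash(d: dict) -> int:
    """Hypothetical stand-in for edsl.utilities.utilities.dict_hash:
    serialize the dict to canonical JSON, then hash the bytes."""
    canonical = json.dumps(d, sort_keys=True, default=str)
    digest = hashlib.sha256(canonical.encode("utf-8")).digest()
    # Truncate to a signed 64-bit int so the value is usable from __hash__.
    return int.from_bytes(digest[:8], "big", signed=True)


# Structurally equal dicts hash equally, which is all the runner needs
# when it tags results with hash(interview) and re-orders them later.
a = {"agent": {"traits": {}}, "iteration": 0}
b = {"iteration": 0, "agent": {"traits": {}}}
assert dict_hash(a) == dict_hash(b)
```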
--- /dev/null
+++ b/edsl/jobs/interviews/InterviewExceptionEntry.py
@@ -0,0 +1,101 @@
+import traceback
+import datetime
+import time
+from collections import UserDict
+
+# traceback=traceback.format_exc(),
+# traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
+# traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
+
+
+class InterviewExceptionEntry:
+    """Class to record an exception that occurred during the interview.
+
+    >>> entry = InterviewExceptionEntry.example()
+    >>> entry.to_dict()['exception']
+    "ValueError('An error occurred.')"
+    """
+
+    def __init__(self, exception: Exception, traceback_format="html"):
+        self.time = datetime.datetime.now().isoformat()
+        self.exception = exception
+        self.traceback_format = traceback_format
+
+    def __getitem__(self, key):
+        # Support dict-like access obj['a']
+        return str(getattr(self, key))
+
+    @classmethod
+    def example(cls):
+        try:
+            raise ValueError("An error occurred.")
+        except Exception as e:
+            entry = InterviewExceptionEntry(e)
+        return entry
+
+    @property
+    def traceback(self):
+        """Return the exception as HTML."""
+        if self.traceback_format == "html":
+            return self.html_traceback
+        else:
+            return self.text_traceback
+
+    @property
+    def text_traceback(self):
+        """
+        >>> entry = InterviewExceptionEntry.example()
+        >>> entry.text_traceback
+        'Traceback (most recent call last):...'
+        """
+        e = self.exception
+        tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        return tb_str
+
+    @property
+    def html_traceback(self):
+        from rich.console import Console
+        from rich.table import Table
+        from rich.traceback import Traceback
+
+        from io import StringIO
+
+        html_output = StringIO()
+
+        console = Console(file=html_output, record=True)
+
+        tb = Traceback.from_exception(
+            type(self.exception),
+            self.exception,
+            self.exception.__traceback__,
+            show_locals=True,
+        )
+        console.print(tb)
+        return html_output.getvalue()
+
+    def to_dict(self) -> dict:
+        """Return the exception as a dictionary.
+
+        >>> entry = InterviewExceptionEntry.example()
+        >>> entry.to_dict()['exception']
+        "ValueError('An error occurred.')"
+
+        """
+        return {
+            "exception": repr(self.exception),
+            "time": self.time,
+            "traceback": self.traceback,
+        }
+
+    def push(self):
+        from edsl import Coop
+
+        coop = Coop()
+        results = coop.error_create(self.to_dict())
+        return results
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
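`html_traceback` above renders the exception with `rich` and captures whatever the recording console wrote to a `StringIO`. A standalone sketch of the same capture pattern, using rich's documented `Console(record=True)` together with `export_html()` to obtain an actual HTML string (an assumption about how one might consume it, not a copy of the class above):

```python
import traceback
from io import StringIO

from rich.console import Console
from rich.traceback import Traceback

try:
    raise ValueError("An error occurred.")
except ValueError as e:
    # Plain-text form, equivalent to the text_traceback property above.
    text_tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))

    # Rich-rendered form: record the console output, then export it as HTML.
    console = Console(file=StringIO(), record=True)
    console.print(Traceback.from_exception(type(e), e, e.__traceback__, show_locals=True))
    html_tb = console.export_html()

print(text_tb.splitlines()[0])   # Traceback (most recent call last):
print(len(html_tb) > 0)          # True: an HTML document was produced
```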
--- a/edsl/jobs/interviews/InterviewTaskBuildingMixin.py
+++ b/edsl/jobs/interviews/InterviewTaskBuildingMixin.py
@@ -12,16 +12,34 @@ from edsl.exceptions import InterviewTimeoutError
 # from edsl.questions.QuestionBase import QuestionBase
 from edsl.surveys.base import EndOfSurvey
 from edsl.jobs.buckets.ModelBuckets import ModelBuckets
-from edsl.jobs.interviews.
+from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
 from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.tasks.task_status_enum import TaskStatus
 from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator

 # from edsl.agents.InvigilatorBase import InvigilatorBase

+from rich.console import Console
+from rich.traceback import Traceback
+
 TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))


+def frame_summary_to_dict(frame):
+    """
+    Convert a FrameSummary object to a dictionary.
+
+    :param frame: A traceback FrameSummary object
+    :return: A dictionary containing the frame's details
+    """
+    return {
+        "filename": frame.filename,
+        "lineno": frame.lineno,
+        "name": frame.name,
+        "line": frame.line,
+    }
+
+
 class InterviewTaskBuildingMixin:
     def _build_invigilators(
         self, debug: bool
@@ -148,7 +166,6 @@ class InterviewTaskBuildingMixin:
             raise ValueError(f"Prompt is of type {type(prompt)}")
         return len(combined_text) / 4.0

-    @retry_strategy
     async def _answer_question_and_record_task(
         self,
         *,
@@ -163,22 +180,29 @@
         """
         from edsl.data_transfer_models import AgentResponseDict

-
-
+        async def _inner():
+            try:
+                invigilator = self._get_invigilator(question, debug=debug)

-
-
+                if self._skip_this_question(question):
+                    return invigilator.get_failed_task_result()

-
-
-
+                response: AgentResponseDict = await self._attempt_to_answer_question(
+                    invigilator, task
+                )

-
+                self._add_answer(response=response, question=question)

-
-
-
-
+                self._cancel_skipped_questions(question)
+                return AgentResponseDict(**response)
+            except Exception as e:
+                raise e
+
+        skip_rety = getattr(self, "skip_retry", False)
+        if not skip_rety:
+            _inner = retry_strategy(_inner)
+
+        return await _inner()

     def _add_answer(
         self, response: "AgentResponseDict", question: "QuestionBase"
@@ -203,38 +227,30 @@
         )
         return skip

+    def _handle_exception(self, e, question_name: str, task=None):
+        exception_entry = InterviewExceptionEntry(e)
+        if task:
+            task.task_status = TaskStatus.FAILED
+        self.exceptions.add(question_name, exception_entry)
+
     async def _attempt_to_answer_question(
-        self, invigilator: InvigilatorBase, task: asyncio.Task
-    ) -> AgentResponseDict:
+        self, invigilator: "InvigilatorBase", task: asyncio.Task
+    ) -> "AgentResponseDict":
         """Attempt to answer the question, and handle exceptions.

         :param invigilator: the invigilator that will answer the question.
         :param task: the task that is being run.
+
         """
         try:
             return await asyncio.wait_for(
                 invigilator.async_answer_question(), timeout=TIMEOUT
             )
         except asyncio.TimeoutError as e:
-
-                exception=repr(e),
-                time=time.time(),
-                traceback=traceback.format_exc(),
-            )
-            if task:
-                task.task_status = TaskStatus.FAILED
-            self.exceptions.add(invigilator.question.question_name, exception_entry)
-
+            self._handle_exception(e, invigilator.question.question_name, task)
             raise InterviewTimeoutError(f"Task timed out after {TIMEOUT} seconds.")
         except Exception as e:
-
-                exception=repr(e),
-                time=time.time(),
-                traceback=traceback.format_exc(),
-            )
-            if task:
-                task.task_status = TaskStatus.FAILED
-            self.exceptions.add(invigilator.question.question_name, exception_entry)
+            self._handle_exception(e, invigilator.question.question_name, task)
             raise e

     def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
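The refactor above removes the unconditional `@retry_strategy` decorator from `_answer_question_and_record_task` and instead wraps an inner coroutine only when `skip_retry` is false. A minimal sketch of that decorate-at-call-time pattern; the `retry_strategy` below is a hypothetical stand-in, not edsl's actual decorator from `retry_management`:

```python
import asyncio
import functools


def retry_strategy(func, attempts=3, delay=0.1):
    """Hypothetical retry decorator: re-run a coroutine a few times before giving up."""

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        for attempt in range(1, attempts + 1):
            try:
                return await func(*args, **kwargs)
            except Exception:
                if attempt == attempts:
                    raise
                await asyncio.sleep(delay)

    return wrapper


async def answer_question(skip_retry: bool = False):
    async def _inner():
        # ... the actual work would happen here ...
        return "answer"

    # Decorate at call time, only when retries are wanted.
    if not skip_retry:
        _inner = retry_strategy(_inner)
    return await _inner()


print(asyncio.run(answer_question(skip_retry=True)))  # answer
```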
--- a/edsl/jobs/interviews/interview_exception_tracking.py
+++ b/edsl/jobs/interviews/interview_exception_tracking.py
@@ -1,18 +1,70 @@
-
-
+import traceback
+import datetime
+import time
 from collections import UserDict

+from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry

-
-
+# #traceback=traceback.format_exc(),
+# #traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
+# #traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]

-
-
-        super().__init__(data)
+# class InterviewExceptionEntry:
+#     """Class to record an exception that occurred during the interview.

-
-
-
+#     >>> entry = InterviewExceptionEntry.example()
+#     >>> entry.to_dict()['exception']
+#     "ValueError('An error occurred.')"
+#     """
+
+#     def __init__(self, exception: Exception):
+#         self.time = datetime.datetime.now().isoformat()
+#         self.exception = exception
+
+#     def __getitem__(self, key):
+#         # Support dict-like access obj['a']
+#         return str(getattr(self, key))
+
+#     @classmethod
+#     def example(cls):
+#         try:
+#             raise ValueError("An error occurred.")
+#         except Exception as e:
+#             entry = InterviewExceptionEntry(e)
+#         return entry
+
+#     @property
+#     def traceback(self):
+#         """Return the exception as HTML."""
+#         e = self.exception
+#         tb_str = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
+#         return tb_str
+
+
+#     @property
+#     def html(self):
+#         from rich.console import Console
+#         from rich.table import Table
+#         from rich.traceback import Traceback
+
+#         from io import StringIO
+#         html_output = StringIO()
+
+#         console = Console(file=html_output, record=True)
+#         tb = Traceback(show_locals=True)
+#         console.print(tb)
+
+#         tb = Traceback.from_exception(type(self.exception), self.exception, self.exception.__traceback__, show_locals=True)
+#         console.print(tb)
+#         return html_output.getvalue()
+
+#     def to_dict(self) -> dict:
+#         """Return the exception as a dictionary."""
+#         return {
+#             'exception': repr(self.exception),
+#             'time': self.time,
+#             'traceback': self.traceback
+#         }


 class InterviewExceptionCollection(UserDict):
@@ -84,3 +136,9 @@ class InterviewExceptionCollection(UserDict):
         )

         console.print(table)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
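Both this module and `InterviewExceptionEntry.py` now end with a `doctest.testmod(optionflags=doctest.ELLIPSIS)` guard, and the Interview doctests earlier in this diff were shortened to `{'q0': ...`. A small self-contained reminder of how the `ELLIPSIS` flag lets such abbreviated expected output match:

```python
import doctest


def exceptions_example():
    """
    >>> exceptions_example()
    {'q0': ...
    """
    return {"q0": ["a long exception entry that is not worth spelling out"]}


# With ELLIPSIS, '...' in the expected output matches any text, so the
# abbreviated doctest above passes even though the real repr is much longer.
doctest.testmod(optionflags=doctest.ELLIPSIS)
```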
--- a/edsl/jobs/runners/JobsRunnerAsyncio.py
+++ b/edsl/jobs/runners/JobsRunnerAsyncio.py
@@ -13,6 +13,40 @@ from edsl.jobs.tasks.TaskHistory import TaskHistory
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.utilities.decorators import jupyter_nb_handler

+import time
+import functools
+
+
+def cache_with_timeout(timeout):
+    def decorator(func):
+        cached_result = {}
+        last_computation_time = [0]  # Using list to store mutable value
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            current_time = time.time()
+            if (current_time - last_computation_time[0]) >= timeout:
+                cached_result["value"] = func(*args, **kwargs)
+                last_computation_time[0] = current_time
+            return cached_result["value"]
+
+        return wrapper
+
+    return decorator
+
+
+# from queue import Queue
+from collections import UserList
+
+
+class StatusTracker(UserList):
+    def __init__(self, total_tasks: int):
+        self.total_tasks = total_tasks
+        super().__init__()
+
+    def current_status(self):
+        return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
+

 class JobsRunnerAsyncio(JobsRunnerStatusMixin):
     """A class for running a collection of interviews asynchronously.
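`cache_with_timeout` memoizes the most recent return value and only recomputes once the timeout has elapsed; it ignores the arguments, which is fine for the zero-argument `generate_table` it decorates in a later hunk. A small usage sketch, with the decorator copied from the hunk above so the example runs standalone:

```python
import functools
import time


def cache_with_timeout(timeout):
    # Same decorator as added above: keep the last result and only
    # recompute after `timeout` seconds have passed.
    def decorator(func):
        cached_result = {}
        last_computation_time = [0]

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            current_time = time.time()
            if (current_time - last_computation_time[0]) >= timeout:
                cached_result["value"] = func(*args, **kwargs)
                last_computation_time[0] = current_time
            return cached_result["value"]

        return wrapper

    return decorator


calls = 0


@cache_with_timeout(1)
def status_table():
    global calls
    calls += 1
    return f"refreshed {calls} time(s)"


print(status_table())  # refreshed 1 time(s)
print(status_table())  # still 'refreshed 1 time(s)': cached within the 1s window
time.sleep(1.1)
print(status_table())  # refreshed 2 time(s)
```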
@@ -43,7 +77,9 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):

         :param n: how many times to run each interview
         :param debug:
-        :param stop_on_exception:
+        :param stop_on_exception: Whether to stop the interview if an exception is raised
+        :param sidecar_model: a language model to use in addition to the interview's model
+        :param total_interviews: A list of interviews to run can be provided instead.
         """
         tasks = []
         if total_interviews:
@@ -87,15 +123,18 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
             )  # set the cache for the first interview
             self.total_interviews.append(interview)

-    async def run_async(self, cache=None) -> Results:
+    async def run_async(self, cache=None, n=1) -> Results:
         from edsl.results.Results import Results

+        # breakpoint()
+        # tracker = StatusTracker(total_tasks=len(self.interviews))
+
         if cache is None:
             self.cache = Cache()
         else:
             self.cache = cache
         data = []
-        async for result in self.run_async_generator(cache=self.cache):
+        async for result in self.run_async_generator(cache=self.cache, n=n):
             data.append(result)
         return Results(survey=self.jobs.survey, data=data)

@@ -173,6 +212,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
             raw_model_response=raw_model_results_dictionary,
             survey=interview.survey,
         )
+        result.interview_hash = hash(interview)
+
         return result

     @property
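Stamping `interview_hash` onto each result lets the runner restore submission order once the interviews finish asynchronously; the sorting itself appears in the next hunk. The underlying pattern, sketched with plain strings and dicts:

```python
# Restore submission order after out-of-order async completion,
# keyed by a stable per-interview hash.
interviews = ["interview-a", "interview-b", "interview-c"]
interview_hashes = [hash(i) for i in interviews]

# Results arrive out of order, each tagged with the hash of its interview.
results = [
    {"answer": 3, "interview_hash": hash("interview-c")},
    {"answer": 1, "interview_hash": hash("interview-a")},
    {"answer": 2, "interview_hash": hash("interview-b")},
]

results = sorted(results, key=lambda r: interview_hashes.index(r["interview_hash"]))
print([r["answer"] for r in results])  # [1, 2, 3]
```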
@@ -201,97 +242,86 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
         self.sidecar_model = sidecar_model

         from edsl.results.Results import Results
+        from rich.live import Live
+        from rich.console import Console

-
-
-
-
-            async def process_results():
-                """Processes results from interviews."""
-                async for result in self.run_async_generator(
-                    n=n,
-                    debug=debug,
-                    stop_on_exception=stop_on_exception,
-                    cache=c,
-                    sidecar_model=sidecar_model,
-                ):
-                    self.results.append(result)
-                self.completed = True
-
-            await asyncio.gather(process_results())
-
-            results = Results(survey=self.jobs.survey, data=self.results)
-        else:
-            # print("Running with progress bar")
-            from rich.live import Live
-            from rich.console import Console
-
-            def generate_table():
-                return self.status_table(self.results, self.elapsed_time)
+        @cache_with_timeout(1)
+        def generate_table():
+            return self.status_table(self.results, self.elapsed_time)

-
-
-
-
+        async def process_results(cache, progress_bar_context=None):
+            """Processes results from interviews."""
+            async for result in self.run_async_generator(
+                n=n,
+                debug=debug,
+                stop_on_exception=stop_on_exception,
+                cache=cache,
+                sidecar_model=sidecar_model,
+            ):
+                self.results.append(result)
+                if progress_bar_context:
+                    progress_bar_context.update(generate_table())
+            self.completed = True
+
+        async def update_progress_bar(progress_bar_context):
+            """Updates the progress bar at fixed intervals."""
+            if progress_bar_context is None:
+                return
+
+            while True:
+                progress_bar_context.update(generate_table())
+                await asyncio.sleep(0.1)  # Update interval
+                if self.completed:
+                    break
+
+        @contextmanager
+        def conditional_context(condition, context_manager):
+            if condition:
+                with context_manager as cm:
+                    yield cm
+            else:
+                yield
+
+        with conditional_context(
+            progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
+        ) as progress_bar_context:
+            with cache as c:
+                progress_task = asyncio.create_task(
+                    update_progress_bar(progress_bar_context)
+                )

-
-
-
+                try:
+                    await asyncio.gather(
+                        progress_task,
+                        process_results(
+                            cache=c, progress_bar_context=progress_bar_context
+                        ),
+                    )
+                except asyncio.CancelledError:
                     pass
+                finally:
+                    progress_task.cancel()  # Cancel the progress_task when process_results is done
+                    await progress_task

-
-                Live(generate_table(), console=console, refresh_per_second=5)
-                if progress_bar
-                else no_op_cm()
-            )
+                await asyncio.sleep(1)  # short delay to show the final status

-
-
-
-
-
-
-
-
-                        if self.completed:
-                            break
-
-            async def process_results():
-                """Processes results from interviews."""
-                async for result in self.run_async_generator(
-                    n=n,
-                    debug=debug,
-                    stop_on_exception=stop_on_exception,
-                    cache=c,
-                    sidecar_model=sidecar_model,
-                ):
-                    self.results.append(result)
-                    live.update(generate_table())
-                self.completed = True
-
-            progress_task = asyncio.create_task(update_progress_bar())
-
-            try:
-                await asyncio.gather(process_results(), progress_task)
-            except asyncio.CancelledError:
-                pass
-            finally:
-                progress_task.cancel()  # Cancel the progress_task when process_results is done
-                await progress_task
-
-            await asyncio.sleep(1)  # short delay to show the final status
-
-            # one more update
-            live.update(generate_table())
-
-            results = Results(survey=self.jobs.survey, data=self.results)
+                if progress_bar_context:
+                    progress_bar_context.update(generate_table())
+
+        # puts results in the same order as the total interviews
+        interview_hashes = [hash(interview) for interview in self.total_interviews]
+        self.results = sorted(
+            self.results, key=lambda x: interview_hashes.index(x.interview_hash)
+        )

+        results = Results(survey=self.jobs.survey, data=self.results)
         task_history = TaskHistory(self.total_interviews, include_traceback=False)
         results.task_history = task_history

         results.has_exceptions = task_history.has_exceptions

         if results.has_exceptions:
+            # put the failed interviews in the results object as a list
             failed_interviews = [
                 interview.duplicate(
                     iteration=interview.iteration, cache=interview.cache
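The rewritten `run` body above collapses the old duplicated progress-bar and no-progress-bar branches into a single `conditional_context` helper that only enters the `rich.live.Live` display when `progress_bar` is true. A standalone sketch of the pattern, with a dummy context manager standing in for `Live`:

```python
from contextlib import contextmanager


@contextmanager
def conditional_context(condition, context_manager):
    # Enter the wrapped context manager only when `condition` is true;
    # otherwise yield None so the body still runs.
    if condition:
        with context_manager as cm:
            yield cm
    else:
        yield


@contextmanager
def live_display():  # stand-in for rich.live.Live
    print("display started")
    yield "display-handle"
    print("display stopped")


for use_progress_bar in (True, False):
    with conditional_context(use_progress_bar, live_display()) as ctx:
        print(f"running with ctx={ctx!r}")
```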
@@ -312,6 +342,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):

             shared_globals["edsl_runner_exceptions"] = task_history
             print(msg)
+            # this is where exceptions are opening up
             task_history.html(cta="Open report to see details.")
             print(
                 "Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"