edsl 0.1.31.dev4__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +3 -4
- edsl/agents/PromptConstructionMixin.py +35 -15
- edsl/config.py +11 -1
- edsl/conjure/Conjure.py +6 -0
- edsl/data/CacheHandler.py +3 -4
- edsl/enums.py +4 -0
- edsl/exceptions/general.py +10 -8
- edsl/inference_services/AwsBedrock.py +110 -0
- edsl/inference_services/AzureAI.py +197 -0
- edsl/inference_services/DeepInfraService.py +4 -3
- edsl/inference_services/GroqService.py +3 -4
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +23 -18
- edsl/inference_services/models_available_cache.py +31 -0
- edsl/inference_services/registry.py +13 -1
- edsl/jobs/Jobs.py +100 -19
- edsl/jobs/buckets/TokenBucket.py +12 -4
- edsl/jobs/interviews/Interview.py +31 -9
- edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -34
- edsl/jobs/interviews/interview_exception_tracking.py +68 -10
- edsl/jobs/runners/JobsRunnerAsyncio.py +36 -15
- edsl/jobs/runners/JobsRunnerStatusMixin.py +81 -51
- edsl/jobs/tasks/TaskCreators.py +1 -1
- edsl/jobs/tasks/TaskHistory.py +145 -1
- edsl/language_models/LanguageModel.py +58 -43
- edsl/language_models/registry.py +2 -2
- edsl/questions/QuestionBudget.py +0 -1
- edsl/questions/QuestionCheckBox.py +0 -1
- edsl/questions/QuestionExtract.py +0 -1
- edsl/questions/QuestionFreeText.py +2 -9
- edsl/questions/QuestionList.py +0 -1
- edsl/questions/QuestionMultipleChoice.py +1 -2
- edsl/questions/QuestionNumerical.py +0 -1
- edsl/questions/QuestionRank.py +0 -1
- edsl/results/DatasetExportMixin.py +33 -3
- edsl/scenarios/Scenario.py +14 -0
- edsl/scenarios/ScenarioList.py +216 -13
- edsl/scenarios/ScenarioListExportMixin.py +15 -4
- edsl/scenarios/ScenarioListPdfMixin.py +3 -0
- edsl/surveys/Rule.py +5 -2
- edsl/surveys/Survey.py +84 -1
- edsl/surveys/SurveyQualtricsImport.py +213 -0
- edsl/utilities/utilities.py +31 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/METADATA +4 -1
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/RECORD +50 -45
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
edsl/jobs/interviews/InterviewTaskBuildingMixin.py
CHANGED
@@ -12,16 +12,34 @@ from edsl.exceptions import InterviewTimeoutError
 # from edsl.questions.QuestionBase import QuestionBase
 from edsl.surveys.base import EndOfSurvey
 from edsl.jobs.buckets.ModelBuckets import ModelBuckets
-from edsl.jobs.interviews.
+from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
 from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.tasks.task_status_enum import TaskStatus
 from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
 
 # from edsl.agents.InvigilatorBase import InvigilatorBase
 
+from rich.console import Console
+from rich.traceback import Traceback
+
 TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
 
 
+def frame_summary_to_dict(frame):
+    """
+    Convert a FrameSummary object to a dictionary.
+
+    :param frame: A traceback FrameSummary object
+    :return: A dictionary containing the frame's details
+    """
+    return {
+        "filename": frame.filename,
+        "lineno": frame.lineno,
+        "name": frame.name,
+        "line": frame.line,
+    }
+
+
 class InterviewTaskBuildingMixin:
     def _build_invigilators(
         self, debug: bool
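The new frame_summary_to_dict helper operates on the FrameSummary objects produced by the standard traceback module. A minimal sketch of how such a helper can be driven from a caught exception; the deliberate failure and variable names below are illustrative, not taken from the diff:

import traceback


def frame_summary_to_dict(frame):
    """Convert a traceback.FrameSummary into a plain dictionary."""
    return {
        "filename": frame.filename,
        "lineno": frame.lineno,
        "name": frame.name,
        "line": frame.line,
    }


try:
    1 / 0  # illustrative failure
except ZeroDivisionError as e:
    # extract_tb returns one FrameSummary per stack frame of the exception
    frames = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
    print(frames)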
@@ -148,7 +166,6 @@ class InterviewTaskBuildingMixin:
                 raise ValueError(f"Prompt is of type {type(prompt)}")
         return len(combined_text) / 4.0
 
-    @retry_strategy
     async def _answer_question_and_record_task(
         self,
         *,
@@ -163,22 +180,29 @@
         """
         from edsl.data_transfer_models import AgentResponseDict
 
-
-
+        async def _inner():
+            try:
+                invigilator = self._get_invigilator(question, debug=debug)
 
-
-
+                if self._skip_this_question(question):
+                    return invigilator.get_failed_task_result()
 
-
-
-
+                response: AgentResponseDict = await self._attempt_to_answer_question(
+                    invigilator, task
+                )
 
-
+                self._add_answer(response=response, question=question)
 
-
-
-
-
+                self._cancel_skipped_questions(question)
+                return AgentResponseDict(**response)
+            except Exception as e:
+                raise e
+
+        skip_rety = getattr(self, "skip_retry", False)
+        if not skip_rety:
+            _inner = retry_strategy(_inner)
+
+        return await _inner()
 
     def _add_answer(
         self, response: "AgentResponseDict", question: "QuestionBase"
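The rewritten _answer_question_and_record_task above moves its body into a local coroutine and applies retry_strategy only when a skip_retry flag is unset. A self-contained sketch of that conditional-decoration pattern; the stand-in retry decorator, the attempt count, and the deliberately flaky call are assumptions for illustration, not edsl's actual retry_strategy:

import asyncio


def retry_strategy(func, attempts=3, delay=0.1):
    """Illustrative stand-in for a retry decorator: re-runs a coroutine a few times."""

    async def wrapper(*args, **kwargs):
        for attempt in range(1, attempts + 1):
            try:
                return await func(*args, **kwargs)
            except Exception:
                if attempt == attempts:
                    raise
                await asyncio.sleep(delay)

    return wrapper


async def answer_question(skip_retry: bool = False):
    calls = {"n": 0}

    async def _inner():
        # stands in for the model call: fails once, then succeeds
        calls["n"] += 1
        if calls["n"] == 1:
            raise RuntimeError("transient failure")
        return {"answer": 42}

    # decorate the inner coroutine only when retries are wanted
    if not skip_retry:
        _inner = retry_strategy(_inner)

    return await _inner()


print(asyncio.run(answer_question()))  # {'answer': 42} after one retry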
@@ -203,39 +227,30 @@
         )
         return skip
 
+    def _handle_exception(self, e, question_name: str, task=None):
+        exception_entry = InterviewExceptionEntry(e)
+        if task:
+            task.task_status = TaskStatus.FAILED
+        self.exceptions.add(question_name, exception_entry)
+
     async def _attempt_to_answer_question(
-        self, invigilator:
-    ) ->
+        self, invigilator: "InvigilatorBase", task: asyncio.Task
+    ) -> "AgentResponseDict":
         """Attempt to answer the question, and handle exceptions.
 
         :param invigilator: the invigilator that will answer the question.
         :param task: the task that is being run.
-
+
         """
         try:
             return await asyncio.wait_for(
                 invigilator.async_answer_question(), timeout=TIMEOUT
             )
         except asyncio.TimeoutError as e:
-
-                exception=repr(e),
-                time=time.time(),
-                traceback=traceback.format_exc(),
-            )
-            if task:
-                task.task_status = TaskStatus.FAILED
-            self.exceptions.add(invigilator.question.question_name, exception_entry)
-
+            self._handle_exception(e, invigilator.question.question_name, task)
             raise InterviewTimeoutError(f"Task timed out after {TIMEOUT} seconds.")
         except Exception as e:
-
-                exception=repr(e),
-                time=time.time(),
-                traceback=traceback.format_exc(),
-            )
-            if task:
-                task.task_status = TaskStatus.FAILED
-            self.exceptions.add(invigilator.question.question_name, exception_entry)
+            self._handle_exception(e, invigilator.question.question_name, task)
             raise e
 
     def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
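Both except branches above now funnel into a single _handle_exception helper. A minimal standalone sketch of the same shape, pairing asyncio.wait_for with one recording helper; the ExceptionLog class, the timeout value, and the slow_answer coroutine are illustrative assumptions, not edsl classes:

import asyncio
import datetime


class ExceptionLog(dict):
    """Illustrative stand-in for an exception collection keyed by question name."""

    def add(self, question_name, entry):
        self.setdefault(question_name, []).append(entry)


async def slow_answer():
    await asyncio.sleep(1.0)  # stands in for a language-model call
    return "ok"


async def attempt_to_answer(question_name, exceptions, timeout=0.1):
    def handle_exception(e):
        # one place to record the failure, whatever the exception type
        entry = {"time": datetime.datetime.now().isoformat(), "exception": repr(e)}
        exceptions.add(question_name, entry)

    try:
        return await asyncio.wait_for(slow_answer(), timeout=timeout)
    except asyncio.TimeoutError as e:
        handle_exception(e)
        raise
    except Exception as e:
        handle_exception(e)
        raise


exceptions = ExceptionLog()
try:
    asyncio.run(attempt_to_answer("how_are_you", exceptions))
except asyncio.TimeoutError:
    print(exceptions)  # the timeout is now recorded under the question name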
edsl/jobs/interviews/interview_exception_tracking.py
CHANGED
@@ -1,18 +1,70 @@
-
-
+import traceback
+import datetime
+import time
 from collections import UserDict
 
+from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
 
-
-
+# #traceback=traceback.format_exc(),
+# #traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
+# #traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
 
-
-
-        super().__init__(data)
+# class InterviewExceptionEntry:
+#     """Class to record an exception that occurred during the interview.
 
-
-
-
+#     >>> entry = InterviewExceptionEntry.example()
+#     >>> entry.to_dict()['exception']
+#     "ValueError('An error occurred.')"
+#     """
+
+#     def __init__(self, exception: Exception):
+#         self.time = datetime.datetime.now().isoformat()
+#         self.exception = exception
+
+#     def __getitem__(self, key):
+#         # Support dict-like access obj['a']
+#         return str(getattr(self, key))
+
+#     @classmethod
+#     def example(cls):
+#         try:
+#             raise ValueError("An error occurred.")
+#         except Exception as e:
+#             entry = InterviewExceptionEntry(e)
+#         return entry
+
+#     @property
+#     def traceback(self):
+#         """Return the exception as HTML."""
+#         e = self.exception
+#         tb_str = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
+#         return tb_str
+
+
+#     @property
+#     def html(self):
+#         from rich.console import Console
+#         from rich.table import Table
+#         from rich.traceback import Traceback
+
+#         from io import StringIO
+#         html_output = StringIO()
+
+#         console = Console(file=html_output, record=True)
+#         tb = Traceback(show_locals=True)
+#         console.print(tb)
+
+#         tb = Traceback.from_exception(type(self.exception), self.exception, self.exception.__traceback__, show_locals=True)
+#         console.print(tb)
+#         return html_output.getvalue()
+
+#     def to_dict(self) -> dict:
+#         """Return the exception as a dictionary."""
+#         return {
+#             'exception': repr(self.exception),
+#             'time': self.time,
+#             'traceback': self.traceback
+#         }
 
 
 class InterviewExceptionCollection(UserDict):
@@ -84,3 +136,9 @@ class InterviewExceptionCollection(UserDict):
         )
 
         console.print(table)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
edsl/jobs/runners/JobsRunnerAsyncio.py
CHANGED
@@ -16,6 +16,7 @@ from edsl.utilities.decorators import jupyter_nb_handler
 import time
 import functools
 
+
 def cache_with_timeout(timeout):
     def decorator(func):
         cached_result = {}
@@ -25,23 +26,27 @@ def cache_with_timeout(timeout):
         def wrapper(*args, **kwargs):
             current_time = time.time()
             if (current_time - last_computation_time[0]) >= timeout:
-                cached_result[
+                cached_result["value"] = func(*args, **kwargs)
                 last_computation_time[0] = current_time
-            return cached_result[
-
+            return cached_result["value"]
+
         return wrapper
+
     return decorator
 
-
+
+# from queue import Queue
 from collections import UserList
 
+
 class StatusTracker(UserList):
     def __init__(self, total_tasks: int):
         self.total_tasks = total_tasks
         super().__init__()
-
+
     def current_status(self):
-        return print(f"Completed: {len(self.data)} of {self.total_tasks}", end
+        return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
+
 
 class JobsRunnerAsyncio(JobsRunnerStatusMixin):
     """A class for running a collection of interviews asynchronously.
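The hunk above completes the cached_result["value"] accesses inside cache_with_timeout. A standalone sketch of the same time-based memoisation pattern, with the parts not shown in the hunk (the last_computation_time initialisation) filled in as assumptions, plus an illustrative use:

import time
import functools


def cache_with_timeout(timeout):
    """Cache a function's result and recompute it once `timeout` seconds have passed."""

    def decorator(func):
        cached_result = {}
        last_computation_time = [0.0]  # a list so the closure can mutate it

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            current_time = time.time()
            if (current_time - last_computation_time[0]) >= timeout:
                cached_result["value"] = func(*args, **kwargs)
                last_computation_time[0] = current_time
            return cached_result["value"]

        return wrapper

    return decorator


@cache_with_timeout(2)
def expensive_call():
    # stands in for something slow, e.g. listing available models
    return time.time()


first = expensive_call()
second = expensive_call()  # served from the cache within the 2-second window
assert first == second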
@@ -121,8 +126,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
     async def run_async(self, cache=None, n=1) -> Results:
         from edsl.results.Results import Results
 
-        #breakpoint()
-        #tracker = StatusTracker(total_tasks=len(self.interviews))
+        # breakpoint()
+        # tracker = StatusTracker(total_tasks=len(self.interviews))
 
         if cache is None:
             self.cache = Cache()
@@ -207,6 +212,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
             raw_model_response=raw_model_results_dictionary,
             survey=interview.survey,
         )
+        result.interview_hash = hash(interview)
+
         return result
 
     @property
@@ -242,7 +249,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
         def generate_table():
             return self.status_table(self.results, self.elapsed_time)
 
-        async def process_results(cache, progress_bar_context
+        async def process_results(cache, progress_bar_context=None):
             """Processes results from interviews."""
             async for result in self.run_async_generator(
                 n=n,
@@ -275,16 +282,23 @@
             else:
                 yield
 
-        with conditional_context(
-
+        with conditional_context(
+            progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
+        ) as progress_bar_context:
             with cache as c:
-
-
+                progress_task = asyncio.create_task(
+                    update_progress_bar(progress_bar_context)
+                )
 
                 try:
-                    await asyncio.gather(
+                    await asyncio.gather(
+                        progress_task,
+                        process_results(
+                            cache=c, progress_bar_context=progress_bar_context
+                        ),
+                    )
                 except asyncio.CancelledError:
-
+                    pass
                 finally:
                     progress_task.cancel()  # Cancel the progress_task when process_results is done
                     await progress_task
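The block above gathers the progress-bar updater and the result processing together and cancels the updater once processing ends. A minimal sketch of that supervision pattern with generic coroutines; the task count, sleep intervals, and function names are illustrative:

import asyncio


async def update_progress(state, total):
    """Refresh a status line until every task has finished."""
    while state["done"] < total:
        print(f"completed {state['done']} of {total}", end="\r")
        await asyncio.sleep(0.05)


async def process_results(state, total):
    for _ in range(total):
        await asyncio.sleep(0.05)  # stands in for answering one interview
        state["done"] += 1


async def main():
    total = 5
    state = {"done": 0}
    progress_task = asyncio.create_task(update_progress(state, total))
    try:
        await asyncio.gather(progress_task, process_results(state, total))
    except asyncio.CancelledError:
        pass
    finally:
        progress_task.cancel()  # harmless if the updater already finished
        try:
            await progress_task
        except asyncio.CancelledError:
            pass


asyncio.run(main())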
@@ -294,6 +308,11 @@
         if progress_bar_context:
             progress_bar_context.update(generate_table())
 
+        # puts results in the same order as the total interviews
+        interview_hashes = [hash(interview) for interview in self.total_interviews]
+        self.results = sorted(
+            self.results, key=lambda x: interview_hashes.index(x.interview_hash)
+        )
 
         results = Results(survey=self.jobs.survey, data=self.results)
         task_history = TaskHistory(self.total_interviews, include_traceback=False)
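Combined with the result.interview_hash assignment earlier in this file, the lines above restore submission order after asynchronous completion. The same idea in isolation, using plain stand-in dataclasses rather than the edsl Interview and Result classes:

from dataclasses import dataclass


@dataclass(frozen=True)
class Interview:
    name: str


@dataclass
class Result:
    answer: str
    interview_hash: int


interviews = [Interview("a"), Interview("b"), Interview("c")]

# results arrive in completion order, not submission order
results = [
    Result("third", hash(interviews[2])),
    Result("first", hash(interviews[0])),
    Result("second", hash(interviews[1])),
]

# put the results back in the same order as the original interviews
interview_hashes = [hash(interview) for interview in interviews]
results = sorted(results, key=lambda r: interview_hashes.index(r.interview_hash))

print([r.answer for r in results])  # ['first', 'second', 'third']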
@@ -302,6 +321,7 @@
         results.has_exceptions = task_history.has_exceptions
 
         if results.has_exceptions:
+            # put the failed interviews in the results object as a list
             failed_interviews = [
                 interview.duplicate(
                     iteration=interview.iteration, cache=interview.cache
@@ -322,6 +342,7 @@
 
             shared_globals["edsl_runner_exceptions"] = task_history
             print(msg)
+            # this is where exceptions are opening up
             task_history.html(cta="Open report to see details.")
             print(
                 "Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"
edsl/jobs/runners/JobsRunnerStatusMixin.py
CHANGED
@@ -22,7 +22,7 @@ from edsl.jobs.interviews.InterviewStatisticsCollection import (
 from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
 
 
-#return {"cache_status": token_usage_type, "details": details, "cost": f"${token_usage.cost(prices):.5f}"}
+# return {"cache_status": token_usage_type, "details": details, "cost": f"${token_usage.cost(prices):.5f}"}
 
 from dataclasses import dataclass, asdict
 
@@ -30,6 +30,7 @@ from rich.text import Text
 from rich.box import SIMPLE
 from rich.table import Table
 
+
 @dataclass
 class ModelInfo:
     model_name: str
@@ -45,18 +46,13 @@ class ModelTokenUsageStats:
     details: List[dict]
     cost: str
 
-class Stats:
 
+class Stats:
     def elapsed_time(self):
-        InterviewStatistic(
-            "elapsed_time", value=elapsed_time, digits=1, units="sec."
-        )
-
-
+        InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")
 
 
 class JobsRunnerStatusMixin:
-
     # @staticmethod
     # def status_dict(interviews: List[Type["Interview"]]) -> List[Type[InterviewStatusDictionary]]:
     #     """
@@ -68,7 +64,6 @@ class JobsRunnerStatusMixin:
    #     return [interview.interview_status for interview in interviews]
 
    def _compute_statistic(stat_name: str, completed_tasks, elapsed_time, interviews):
-
        stat_definitions = {
            "elapsed_time": lambda: InterviewStatistic(
                "elapsed_time", value=elapsed_time, digits=1, units="sec."
@@ -101,36 +96,49 @@ class JobsRunnerStatusMixin:
             "estimated_time_remaining": lambda: InterviewStatistic(
                 "estimated_time_remaining",
                 value=(
-                    (len(interviews) - len(completed_tasks))
+                    (len(interviews) - len(completed_tasks))
+                    * (elapsed_time / len(completed_tasks))
                     if len(completed_tasks) > 0
                     else "NA"
                 ),
                 digits=1,
                 units="sec.",
-            )
-
+            ),
+        }
         if stat_name not in stat_definitions:
-            raise ValueError(
+            raise ValueError(
+                f"Invalid stat_name: {stat_name}. The valid stat_names are: {list(stat_definitions.keys())}"
+            )
         return stat_definitions[stat_name]()
 
-
     @staticmethod
-    def _job_level_info(
-
-
-
-
+    def _job_level_info(
+        completed_tasks: List[Type[asyncio.Task]],
+        elapsed_time: float,
+        interviews: List[Type["Interview"]],
+    ) -> InterviewStatisticsCollection:
         interview_statistics = InterviewStatisticsCollection()
 
-        default_statistics = [
+        default_statistics = [
+            "elapsed_time",
+            "total_interviews_requested",
+            "completed_interviews",
+            "percent_complete",
+            "average_time_per_interview",
+            "task_remaining",
+            "estimated_time_remaining",
+        ]
         for stat_name in default_statistics:
-            interview_statistics.add_stat(
+            interview_statistics.add_stat(
+                JobsRunnerStatusMixin._compute_statistic(
+                    stat_name, completed_tasks, elapsed_time, interviews
+                )
+            )
 
         return interview_statistics
 
     @staticmethod
     def _get_model_queues_info(interviews):
-
         models_to_tokens = defaultdict(InterviewTokenUsage)
         model_to_status = defaultdict(InterviewStatusDictionary)
         waiting_dict = defaultdict(int)
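The reformatted estimated_time_remaining statistic now reads as: remaining interviews multiplied by the observed average time per completed interview, or "NA" before anything has completed. A small worked sketch with made-up numbers:

def estimated_time_remaining(num_interviews, num_completed, elapsed_time):
    """Remaining work scaled by the observed average time per interview."""
    return (
        (num_interviews - num_completed) * (elapsed_time / num_completed)
        if num_completed > 0
        else "NA"
    )


# 100 interviews, 25 done in 50 seconds -> 2 s each -> 150 s left
print(estimated_time_remaining(100, 25, 50.0))  # 150.0
print(estimated_time_remaining(100, 0, 0.0))    # 'NA'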
@@ -141,14 +149,16 @@ class JobsRunnerStatusMixin:
             waiting_dict[interview.model] += interview.interview_status.waiting
 
         for model, num_waiting in waiting_dict.items():
-            yield JobsRunnerStatusMixin._get_model_info(
+            yield JobsRunnerStatusMixin._get_model_info(
+                model, num_waiting, models_to_tokens
+            )
 
     @staticmethod
     def generate_status_summary(
         completed_tasks: List[Type[asyncio.Task]],
         elapsed_time: float,
         interviews: List[Type["Interview"]],
-        include_model_queues
+        include_model_queues=False,
     ) -> InterviewStatisticsCollection:
         """Generate a summary of the status of the job runner.
 
@@ -164,13 +174,17 @@
         {'Elapsed time': '0.0 sec.', 'Total interviews requested': '1 ', 'Completed interviews': '0 ', 'Percent complete': '0 %', 'Average time per interview': 'NA', 'Task remaining': '1 ', 'Estimated time remaining': 'NA'}
         """
 
-        interview_status_summary: InterviewStatisticsCollection =
-
-
-
+        interview_status_summary: InterviewStatisticsCollection = (
+            JobsRunnerStatusMixin._job_level_info(
+                completed_tasks=completed_tasks,
+                elapsed_time=elapsed_time,
+                interviews=interviews,
+            )
         )
         if include_model_queues:
-            interview_status_summary.model_queues = list(
+            interview_status_summary.model_queues = list(
+                JobsRunnerStatusMixin._get_model_queues_info(interviews)
+            )
         else:
             interview_status_summary.model_queues = None
 
@@ -202,15 +216,21 @@
 
         token_usage_info = []
         for token_usage_type in ["new_token_usage", "cached_token_usage"]:
-            token_usage_info.append(
+            token_usage_info.append(
+                JobsRunnerStatusMixin._get_token_usage_info(
+                    token_usage_type, models_to_tokens, model, prices
+                )
+            )
 
-        return ModelInfo(
-
-
-
-
-
-
+        return ModelInfo(
+            **{
+                "model_name": model.model,
+                "TPM_limit_k": model.TPM / 1000,
+                "RPM_limit_k": model.RPM / 1000,
+                "num_tasks_waiting": num_waiting,
+                "token_usage_info": token_usage_info,
+            }
+        )
 
     @staticmethod
     def _get_token_usage_info(
@@ -232,13 +252,19 @@
 
         """
         all_token_usage: InterviewTokenUsage = models_to_tokens[model]
-        token_usage: TokenUsage = getattr(all_token_usage, token_usage_type)
-
-        details = [
-
-
-
-
+        token_usage: TokenUsage = getattr(all_token_usage, token_usage_type)
+
+        details = [
+            {"type": token_type, "tokens": getattr(token_usage, token_type)}
+            for token_type in ["prompt_tokens", "completion_tokens"]
+        ]
+
+        return ModelTokenUsageStats(
+            token_usage_type=token_usage_type,
+            details=details,
+            cost=f"${token_usage.cost(prices):.5f}",
+        )
+
     @staticmethod
     def _add_statistics_to_table(table, status_summary):
         table.add_column("Statistic", style="dim", no_wrap=True, width=50)
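The rebuilt _get_token_usage_info packs prompt and completion token counts into a details list and formats the cost as a dollar string. An illustrative standalone version; the TokenUsage stand-in, its cost method, and the per-token prices are assumptions, not edsl's real classes or pricing:

from dataclasses import dataclass
from typing import List


@dataclass
class TokenUsage:
    prompt_tokens: int
    completion_tokens: int

    def cost(self, prices):
        return (
            self.prompt_tokens * prices["prompt"]
            + self.completion_tokens * prices["completion"]
        )


@dataclass
class ModelTokenUsageStats:
    token_usage_type: str
    details: List[dict]
    cost: str


def get_token_usage_info(token_usage: TokenUsage, token_usage_type: str, prices: dict):
    details = [
        {"type": token_type, "tokens": getattr(token_usage, token_type)}
        for token_type in ["prompt_tokens", "completion_tokens"]
    ]
    return ModelTokenUsageStats(
        token_usage_type=token_usage_type,
        details=details,
        cost=f"${token_usage.cost(prices):.5f}",
    )


usage = TokenUsage(prompt_tokens=1200, completion_tokens=300)
prices = {"prompt": 0.00001, "completion": 0.00003}  # made-up per-token prices
print(get_token_usage_info(usage, "new_token_usage", prices))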
@@ -249,9 +275,7 @@
             table.add_row(key, value)
 
     @staticmethod
-    def display_status_table(status_summary: InterviewStatisticsCollection) ->
-
-
+    def display_status_table(status_summary: InterviewStatisticsCollection) -> "Table":
         table = Table(
             title="Job Status",
             show_header=True,
@@ -268,24 +292,29 @@
         if status_summary.model_queues is not None:
             table.add_row(Text("Model Queues", style="bold red"), "")
             for model_info in status_summary.model_queues:
-
                 model_name = model_info.model_name
                 tpm = f"TPM (k)={model_info.TPM_limit_k}"
                 rpm = f"RPM (k)= {model_info.RPM_limit_k}"
                 pretty_model_name = model_name + ";" + tpm + ";" + rpm
                 table.add_row(Text(pretty_model_name, style="blue"), "")
-                table.add_row(
+                table.add_row(
+                    "Number question tasks waiting for capacity",
+                    str(model_info.num_tasks_waiting),
+                )
                 # Token usage and cost info
                 for token_usage_info in model_info.token_usage_info:
                     token_usage_type = token_usage_info.token_usage_type
                     table.add_row(
-                        Text(
+                        Text(
+                            spacing + token_usage_type.replace("_", " "), style="bold"
+                        ),
+                        "",
                     )
                     for detail in token_usage_info.details:
                         token_type = detail["type"]
                         tokens = detail["tokens"]
                         table.add_row(spacing + f"{token_type}", f"{tokens:,}")
-                        #table.add_row(spacing + "cost", cache_info["cost"])
+                        # table.add_row(spacing + "cost", cache_info["cost"])
 
         return table
 
@@ -297,6 +326,7 @@
         )
         return self.display_status_table(summary_data)
 
+
 if __name__ == "__main__":
     import doctest
 
edsl/jobs/tasks/TaskCreators.py
CHANGED
@@ -22,7 +22,7 @@ class TaskCreators(UserDict):
         This is iterates through all tasks that make up an interview.
         For each task, it determines how many tokens were used and whether they were cached or new.
         It then sums the total number of cached and new tokens used for the interview.
-
+
         """
         cached_tokens = TokenUsage(from_cache=True)
         new_tokens = TokenUsage(from_cache=False)