edsl 0.1.31.dev4__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Invigilator.py +3 -4
  3. edsl/agents/PromptConstructionMixin.py +35 -15
  4. edsl/config.py +11 -1
  5. edsl/conjure/Conjure.py +6 -0
  6. edsl/data/CacheHandler.py +3 -4
  7. edsl/enums.py +4 -0
  8. edsl/exceptions/general.py +10 -8
  9. edsl/inference_services/AwsBedrock.py +110 -0
  10. edsl/inference_services/AzureAI.py +197 -0
  11. edsl/inference_services/DeepInfraService.py +4 -3
  12. edsl/inference_services/GroqService.py +3 -4
  13. edsl/inference_services/InferenceServicesCollection.py +13 -8
  14. edsl/inference_services/OllamaService.py +18 -0
  15. edsl/inference_services/OpenAIService.py +23 -18
  16. edsl/inference_services/models_available_cache.py +31 -0
  17. edsl/inference_services/registry.py +13 -1
  18. edsl/jobs/Jobs.py +100 -19
  19. edsl/jobs/buckets/TokenBucket.py +12 -4
  20. edsl/jobs/interviews/Interview.py +31 -9
  21. edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
  22. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -34
  23. edsl/jobs/interviews/interview_exception_tracking.py +68 -10
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +36 -15
  25. edsl/jobs/runners/JobsRunnerStatusMixin.py +81 -51
  26. edsl/jobs/tasks/TaskCreators.py +1 -1
  27. edsl/jobs/tasks/TaskHistory.py +145 -1
  28. edsl/language_models/LanguageModel.py +58 -43
  29. edsl/language_models/registry.py +2 -2
  30. edsl/questions/QuestionBudget.py +0 -1
  31. edsl/questions/QuestionCheckBox.py +0 -1
  32. edsl/questions/QuestionExtract.py +0 -1
  33. edsl/questions/QuestionFreeText.py +2 -9
  34. edsl/questions/QuestionList.py +0 -1
  35. edsl/questions/QuestionMultipleChoice.py +1 -2
  36. edsl/questions/QuestionNumerical.py +0 -1
  37. edsl/questions/QuestionRank.py +0 -1
  38. edsl/results/DatasetExportMixin.py +33 -3
  39. edsl/scenarios/Scenario.py +14 -0
  40. edsl/scenarios/ScenarioList.py +216 -13
  41. edsl/scenarios/ScenarioListExportMixin.py +15 -4
  42. edsl/scenarios/ScenarioListPdfMixin.py +3 -0
  43. edsl/surveys/Rule.py +5 -2
  44. edsl/surveys/Survey.py +84 -1
  45. edsl/surveys/SurveyQualtricsImport.py +213 -0
  46. edsl/utilities/utilities.py +31 -0
  47. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/METADATA +4 -1
  48. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/RECORD +50 -45
  49. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
  50. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
edsl/jobs/interviews/InterviewTaskBuildingMixin.py
@@ -12,16 +12,34 @@ from edsl.exceptions import InterviewTimeoutError
  # from edsl.questions.QuestionBase import QuestionBase
  from edsl.surveys.base import EndOfSurvey
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
- from edsl.jobs.interviews.interview_exception_tracking import InterviewExceptionEntry
+ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
  from edsl.jobs.interviews.retry_management import retry_strategy
  from edsl.jobs.tasks.task_status_enum import TaskStatus
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator

  # from edsl.agents.InvigilatorBase import InvigilatorBase

+ from rich.console import Console
+ from rich.traceback import Traceback
+
  TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))


+ def frame_summary_to_dict(frame):
+     """
+     Convert a FrameSummary object to a dictionary.
+
+     :param frame: A traceback FrameSummary object
+     :return: A dictionary containing the frame's details
+     """
+     return {
+         "filename": frame.filename,
+         "lineno": frame.lineno,
+         "name": frame.name,
+         "line": frame.line,
+     }
+
+
  class InterviewTaskBuildingMixin:
      def _build_invigilators(
          self, debug: bool
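
Note: the new module-level frame_summary_to_dict helper turns one frame of a traceback into a JSON-serializable dict. A minimal standalone sketch of how such a helper pairs with the standard library's traceback.extract_tb (the 1 / 0 trigger is illustrative only):

import traceback

def frame_summary_to_dict(frame):
    # Each FrameSummary exposes exactly the four attributes captured here.
    return {
        "filename": frame.filename,
        "lineno": frame.lineno,
        "name": frame.name,
        "line": frame.line,
    }

try:
    1 / 0
except ZeroDivisionError as e:
    # traceback.extract_tb yields one FrameSummary per stack frame
    frames = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
    print(frames[-1]["line"])  # -> "1 / 0"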
@@ -148,7 +166,6 @@ class InterviewTaskBuildingMixin:
              raise ValueError(f"Prompt is of type {type(prompt)}")
          return len(combined_text) / 4.0

-     @retry_strategy
      async def _answer_question_and_record_task(
          self,
          *,
@@ -163,22 +180,29 @@ class InterviewTaskBuildingMixin:
          """
          from edsl.data_transfer_models import AgentResponseDict

-         try:
-             invigilator = self._get_invigilator(question, debug=debug)
+         async def _inner():
+             try:
+                 invigilator = self._get_invigilator(question, debug=debug)

-             if self._skip_this_question(question):
-                 return invigilator.get_failed_task_result()
+                 if self._skip_this_question(question):
+                     return invigilator.get_failed_task_result()

-             response: AgentResponseDict = await self._attempt_to_answer_question(
-                 invigilator, task
-             )
+                 response: AgentResponseDict = await self._attempt_to_answer_question(
+                     invigilator, task
+                 )

-             self._add_answer(response=response, question=question)
+                 self._add_answer(response=response, question=question)

-             self._cancel_skipped_questions(question)
-             return AgentResponseDict(**response)
-         except Exception as e:
-             raise e
+                 self._cancel_skipped_questions(question)
+                 return AgentResponseDict(**response)
+             except Exception as e:
+                 raise e
+
+         skip_rety = getattr(self, "skip_retry", False)
+         if not skip_rety:
+             _inner = retry_strategy(_inner)
+
+         return await _inner()

      def _add_answer(
          self, response: "AgentResponseDict", question: "QuestionBase"
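
Note: this hunk replaces the unconditional @retry_strategy decorator (removed in the previous hunk) with runtime wrapping of an inner coroutine, so a per-interview skip_retry flag can disable retries; the released code spells the local variable skip_rety while reading the skip_retry attribute. A minimal sketch of the pattern, with a toy stand-in for the real edsl retry_strategy:

import asyncio
import functools

def retry_strategy(func):
    # Toy stand-in: retry once on failure. The real edsl decorator
    # (edsl.jobs.interviews.retry_management) uses retry limits and backoff.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception:
            return await func(*args, **kwargs)
    return wrapper

class Interview:
    skip_retry = False  # flip to True to run each question exactly once

    async def answer_question(self):
        async def _inner():
            return "answer"

        # Same shape as the diff: decorate only when retries are wanted.
        if not getattr(self, "skip_retry", False):
            _inner = retry_strategy(_inner)
        return await _inner()

print(asyncio.run(Interview().answer_question()))  # -> answer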
@@ -203,39 +227,30 @@ class InterviewTaskBuildingMixin:
          )
          return skip

+     def _handle_exception(self, e, question_name: str, task=None):
+         exception_entry = InterviewExceptionEntry(e)
+         if task:
+             task.task_status = TaskStatus.FAILED
+         self.exceptions.add(question_name, exception_entry)
+
      async def _attempt_to_answer_question(
-         self, invigilator: 'InvigilatorBase', task: asyncio.Task
-     ) -> 'AgentResponseDict':
+         self, invigilator: "InvigilatorBase", task: asyncio.Task
+     ) -> "AgentResponseDict":
          """Attempt to answer the question, and handle exceptions.

          :param invigilator: the invigilator that will answer the question.
          :param task: the task that is being run.
-
+
          """
          try:
              return await asyncio.wait_for(
                  invigilator.async_answer_question(), timeout=TIMEOUT
              )
          except asyncio.TimeoutError as e:
-             exception_entry = InterviewExceptionEntry(
-                 exception=repr(e),
-                 time=time.time(),
-                 traceback=traceback.format_exc(),
-             )
-             if task:
-                 task.task_status = TaskStatus.FAILED
-             self.exceptions.add(invigilator.question.question_name, exception_entry)
-
+             self._handle_exception(e, invigilator.question.question_name, task)
              raise InterviewTimeoutError(f"Task timed out after {TIMEOUT} seconds.")
          except Exception as e:
-             exception_entry = InterviewExceptionEntry(
-                 exception=repr(e),
-                 time=time.time(),
-                 traceback=traceback.format_exc(),
-             )
-             if task:
-                 task.task_status = TaskStatus.FAILED
-             self.exceptions.add(invigilator.question.question_name, exception_entry)
+             self._handle_exception(e, invigilator.question.question_name, task)
              raise e

      def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
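
Note: both except branches now delegate to a single _handle_exception helper instead of duplicating the entry-building code. For reference, a self-contained sketch of the asyncio.wait_for timeout pattern the method is built around (TIMEOUT here is a local stand-in for the value edsl reads from EDSL_API_TIMEOUT):

import asyncio

TIMEOUT = 0.5  # stand-in; edsl loads this from CONFIG["EDSL_API_TIMEOUT"]

async def slow_api_call():
    await asyncio.sleep(5)
    return "never reached"

async def attempt():
    try:
        return await asyncio.wait_for(slow_api_call(), timeout=TIMEOUT)
    except asyncio.TimeoutError:
        # the branch where the diff records the failure via _handle_exception
        # before re-raising it as InterviewTimeoutError
        print(f"Task timed out after {TIMEOUT} seconds.")

asyncio.run(attempt())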
edsl/jobs/interviews/interview_exception_tracking.py
@@ -1,18 +1,70 @@
- from rich.console import Console
- from rich.table import Table
+ import traceback
+ import datetime
+ import time

  from collections import UserDict

+ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry

- class InterviewExceptionEntry(UserDict):
-     """Class to record an exception that occurred during the interview."""
+ # #traceback=traceback.format_exc(),
+ # #traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
+ # #traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]

-     def __init__(self, exception, time, traceback):
-         data = {"exception": exception, "time": time, "traceback": traceback}
-         super().__init__(data)
+ # class InterviewExceptionEntry:
+ #     """Class to record an exception that occurred during the interview.

-     def to_dict(self) -> dict:
-         """Return the exception as a dictionary."""
-         return self.data
+ #     >>> entry = InterviewExceptionEntry.example()
+ #     >>> entry.to_dict()['exception']
+ #     "ValueError('An error occurred.')"
+ #     """
+
+ #     def __init__(self, exception: Exception):
+ #         self.time = datetime.datetime.now().isoformat()
+ #         self.exception = exception
+
+ #     def __getitem__(self, key):
+ #         # Support dict-like access obj['a']
+ #         return str(getattr(self, key))
+
+ #     @classmethod
+ #     def example(cls):
+ #         try:
+ #             raise ValueError("An error occurred.")
+ #         except Exception as e:
+ #             entry = InterviewExceptionEntry(e)
+ #             return entry
+
+ #     @property
+ #     def traceback(self):
+ #         """Return the exception as HTML."""
+ #         e = self.exception
+ #         tb_str = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
+ #         return tb_str
+
+
+ #     @property
+ #     def html(self):
+ #         from rich.console import Console
+ #         from rich.table import Table
+ #         from rich.traceback import Traceback
+
+ #         from io import StringIO
+ #         html_output = StringIO()
+
+ #         console = Console(file=html_output, record=True)
+ #         tb = Traceback(show_locals=True)
+ #         console.print(tb)
+
+ #         tb = Traceback.from_exception(type(self.exception), self.exception, self.exception.__traceback__, show_locals=True)
+ #         console.print(tb)
+ #         return html_output.getvalue()
+
+ #     def to_dict(self) -> dict:
+ #         """Return the exception as a dictionary."""
+ #         return {
+ #             'exception': repr(self.exception),
+ #             'time': self.time,
+ #             'traceback': self.traceback
+ #         }


  class InterviewExceptionCollection(UserDict):
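
Note: the block of comments above preserves the class that moved into the new edsl/jobs/interviews/InterviewExceptionEntry.py (file 21 in the list). Condensed from those comments, the design keeps the live exception object and formats the traceback lazily, instead of capturing strings at construction time as the deleted UserDict version did; a sketch, where details of the shipped module may differ:

import datetime
import traceback

class InterviewExceptionEntry:
    """Record an exception; format its traceback only when asked."""

    def __init__(self, exception: Exception):
        self.time = datetime.datetime.now().isoformat()
        self.exception = exception  # keep the object, not a repr

    @property
    def traceback(self) -> str:
        e = self.exception
        return "".join(traceback.format_exception(type(e), e, e.__traceback__))

    def to_dict(self) -> dict:
        return {
            "exception": repr(self.exception),
            "time": self.time,
            "traceback": self.traceback,
        }

try:
    raise ValueError("An error occurred.")
except ValueError as err:
    entry = InterviewExceptionEntry(err)
    print(entry.to_dict()["exception"])  # "ValueError('An error occurred.')"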
@@ -84,3 +136,9 @@ class InterviewExceptionCollection(UserDict):
          )

          console.print(table)
+
+
+ if __name__ == "__main__":
+     import doctest
+
+     doctest.testmod(optionflags=doctest.ELLIPSIS)
edsl/jobs/runners/JobsRunnerAsyncio.py
@@ -16,6 +16,7 @@ from edsl.utilities.decorators import jupyter_nb_handler
  import time
  import functools

+
  def cache_with_timeout(timeout):
      def decorator(func):
          cached_result = {}
@@ -25,23 +26,27 @@ def cache_with_timeout(timeout):
          def wrapper(*args, **kwargs):
              current_time = time.time()
              if (current_time - last_computation_time[0]) >= timeout:
-                 cached_result['value'] = func(*args, **kwargs)
+                 cached_result["value"] = func(*args, **kwargs)
                  last_computation_time[0] = current_time
-             return cached_result['value']
-
+             return cached_result["value"]
+
          return wrapper
+
      return decorator

- #from queue import Queue
+
+ # from queue import Queue
  from collections import UserList

+
  class StatusTracker(UserList):
      def __init__(self, total_tasks: int):
          self.total_tasks = total_tasks
          super().__init__()
-
+
      def current_status(self):
-         return print(f"Completed: {len(self.data)} of {self.total_tasks}", end = "\r")
+         return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
+

  class JobsRunnerAsyncio(JobsRunnerStatusMixin):
      """A class for running a collection of interviews asynchronously.
@@ -121,8 +126,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
      async def run_async(self, cache=None, n=1) -> Results:
          from edsl.results.Results import Results

-         #breakpoint()
-         #tracker = StatusTracker(total_tasks=len(self.interviews))
+         # breakpoint()
+         # tracker = StatusTracker(total_tasks=len(self.interviews))

          if cache is None:
              self.cache = Cache()
@@ -207,6 +212,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
              raw_model_response=raw_model_results_dictionary,
              survey=interview.survey,
          )
+         result.interview_hash = hash(interview)
+
          return result

      @property
@@ -242,7 +249,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
          def generate_table():
              return self.status_table(self.results, self.elapsed_time)

-         async def process_results(cache, progress_bar_context = None):
+         async def process_results(cache, progress_bar_context=None):
              """Processes results from interviews."""
              async for result in self.run_async_generator(
                  n=n,
@@ -275,16 +282,23 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
              else:
                  yield

-         with conditional_context(progress_bar, Live(generate_table(), console=console, refresh_per_second=1)) as progress_bar_context:
-
+         with conditional_context(
+             progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
+         ) as progress_bar_context:
              with cache as c:
-
-                 progress_task = asyncio.create_task(update_progress_bar(progress_bar_context))
+                 progress_task = asyncio.create_task(
+                     update_progress_bar(progress_bar_context)
+                 )

                  try:
-                     await asyncio.gather(progress_task, process_results(cache = c, progress_bar_context = progress_bar_context))
+                     await asyncio.gather(
+                         progress_task,
+                         process_results(
+                             cache=c, progress_bar_context=progress_bar_context
+                         ),
+                     )
                  except asyncio.CancelledError:
-                     pass
+                     pass
                  finally:
                      progress_task.cancel()  # Cancel the progress_task when process_results is done
                      await progress_task
@@ -294,6 +308,11 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
          if progress_bar_context:
              progress_bar_context.update(generate_table())

+         # puts results in the same order as the total interviews
+         interview_hashes = [hash(interview) for interview in self.total_interviews]
+         self.results = sorted(
+             self.results, key=lambda x: interview_hashes.index(x.interview_hash)
+         )

          results = Results(survey=self.jobs.survey, data=self.results)
          task_history = TaskHistory(self.total_interviews, include_traceback=False)
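
Note: interviews complete in arbitrary order under asyncio, so each Result was tagged with hash(interview) at creation (the result.interview_hash hunk earlier) and is sorted back into the original interview order here. The idea with stand-in objects; note that list.index inside the sort key makes this O(n²), so a hash-to-position dict would scale better:

class Result:  # stand-in for edsl.results.Result
    def __init__(self, interview_hash):
        self.interview_hash = interview_hash

total_interviews = ["interview_a", "interview_b", "interview_c"]  # stand-ins
results = [Result(hash(iv)) for iv in reversed(total_interviews)]  # arrived c, b, a

interview_hashes = [hash(iv) for iv in total_interviews]
results = sorted(results, key=lambda x: interview_hashes.index(x.interview_hash))

assert [r.interview_hash for r in results] == interview_hashes  # back to a, b, c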
@@ -302,6 +321,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
          results.has_exceptions = task_history.has_exceptions

          if results.has_exceptions:
+             # put the failed interviews in the results object as a list
              failed_interviews = [
                  interview.duplicate(
                      iteration=interview.iteration, cache=interview.cache
@@ -322,6 +342,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):

          shared_globals["edsl_runner_exceptions"] = task_history
          print(msg)
+         # this is where exceptions are opening up
          task_history.html(cta="Open report to see details.")
          print(
              "Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"
edsl/jobs/runners/JobsRunnerStatusMixin.py
@@ -22,7 +22,7 @@ from edsl.jobs.interviews.InterviewStatisticsCollection import (
  from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage


- #return {"cache_status": token_usage_type, "details": details, "cost": f"${token_usage.cost(prices):.5f}"}
+ # return {"cache_status": token_usage_type, "details": details, "cost": f"${token_usage.cost(prices):.5f}"}

  from dataclasses import dataclass, asdict
@@ -30,6 +30,7 @@ from rich.text import Text
  from rich.box import SIMPLE
  from rich.table import Table

+
  @dataclass
  class ModelInfo:
      model_name: str
@@ -45,18 +46,13 @@ class ModelTokenUsageStats:
      details: List[dict]
      cost: str

- class Stats:

+ class Stats:
      def elapsed_time(self):
-         InterviewStatistic(
-             "elapsed_time", value=elapsed_time, digits=1, units="sec."
-         )
-
-
+         InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")


  class JobsRunnerStatusMixin:
-
      # @staticmethod
      # def status_dict(interviews: List[Type["Interview"]]) -> List[Type[InterviewStatusDictionary]]:
      #     """
@@ -68,7 +64,6 @@ class JobsRunnerStatusMixin:
      #     return [interview.interview_status for interview in interviews]

      def _compute_statistic(stat_name: str, completed_tasks, elapsed_time, interviews):
-
          stat_definitions = {
              "elapsed_time": lambda: InterviewStatistic(
                  "elapsed_time", value=elapsed_time, digits=1, units="sec."
@@ -101,36 +96,49 @@ class JobsRunnerStatusMixin:
              "estimated_time_remaining": lambda: InterviewStatistic(
                  "estimated_time_remaining",
                  value=(
-                     (len(interviews) - len(completed_tasks)) * (elapsed_time / len(completed_tasks))
+                     (len(interviews) - len(completed_tasks))
+                     * (elapsed_time / len(completed_tasks))
                      if len(completed_tasks) > 0
                      else "NA"
                  ),
                  digits=1,
                  units="sec.",
-             )
-         }
+             ),
+         }
          if stat_name not in stat_definitions:
-             raise ValueError(f"Invalid stat_name: {stat_name}. The valid stat_names are: {list(stat_definitions.keys())}")
+             raise ValueError(
+                 f"Invalid stat_name: {stat_name}. The valid stat_names are: {list(stat_definitions.keys())}"
+             )
          return stat_definitions[stat_name]()

-
      @staticmethod
-     def _job_level_info(completed_tasks: List[Type[asyncio.Task]],
-                         elapsed_time: float,
-                         interviews: List[Type["Interview"]]
-                         ) -> InterviewStatisticsCollection:
-
+     def _job_level_info(
+         completed_tasks: List[Type[asyncio.Task]],
+         elapsed_time: float,
+         interviews: List[Type["Interview"]],
+     ) -> InterviewStatisticsCollection:
          interview_statistics = InterviewStatisticsCollection()

-         default_statistics = ["elapsed_time", "total_interviews_requested", "completed_interviews", "percent_complete", "average_time_per_interview", "task_remaining", "estimated_time_remaining"]
+         default_statistics = [
+             "elapsed_time",
+             "total_interviews_requested",
+             "completed_interviews",
+             "percent_complete",
+             "average_time_per_interview",
+             "task_remaining",
+             "estimated_time_remaining",
+         ]
          for stat_name in default_statistics:
-             interview_statistics.add_stat(JobsRunnerStatusMixin._compute_statistic(stat_name, completed_tasks, elapsed_time, interviews))
+             interview_statistics.add_stat(
+                 JobsRunnerStatusMixin._compute_statistic(
+                     stat_name, completed_tasks, elapsed_time, interviews
+                 )
+             )

          return interview_statistics

      @staticmethod
      def _get_model_queues_info(interviews):
-
          models_to_tokens = defaultdict(InterviewTokenUsage)
          model_to_status = defaultdict(InterviewStatusDictionary)
          waiting_dict = defaultdict(int)
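
Note: _compute_statistic is a lambda dispatch table: each statistic is defined as a thunk so nothing is computed until a name is looked up, and unknown names fail loudly with the list of valid keys. The pattern in miniature (the two stats here are illustrative):

def compute_statistic(stat_name, completed_tasks, elapsed_time):
    stat_definitions = {
        # lambdas defer evaluation until the statistic is actually requested
        "elapsed_time": lambda: round(elapsed_time, 1),
        "completed_interviews": lambda: len(completed_tasks),
    }
    if stat_name not in stat_definitions:
        raise ValueError(
            f"Invalid stat_name: {stat_name}. "
            f"The valid stat_names are: {list(stat_definitions.keys())}"
        )
    return stat_definitions[stat_name]()

print(compute_statistic("completed_interviews", completed_tasks=[1, 2], elapsed_time=3.14))  # 2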
@@ -141,14 +149,16 @@ class JobsRunnerStatusMixin:
              waiting_dict[interview.model] += interview.interview_status.waiting

          for model, num_waiting in waiting_dict.items():
-             yield JobsRunnerStatusMixin._get_model_info(model, num_waiting, models_to_tokens)
+             yield JobsRunnerStatusMixin._get_model_info(
+                 model, num_waiting, models_to_tokens
+             )

      @staticmethod
      def generate_status_summary(
          completed_tasks: List[Type[asyncio.Task]],
          elapsed_time: float,
          interviews: List[Type["Interview"]],
-         include_model_queues = False
+         include_model_queues=False,
      ) -> InterviewStatisticsCollection:
          """Generate a summary of the status of the job runner.
@@ -164,13 +174,17 @@ class JobsRunnerStatusMixin:
          {'Elapsed time': '0.0 sec.', 'Total interviews requested': '1 ', 'Completed interviews': '0 ', 'Percent complete': '0 %', 'Average time per interview': 'NA', 'Task remaining': '1 ', 'Estimated time remaining': 'NA'}
          """

-         interview_status_summary: InterviewStatisticsCollection = JobsRunnerStatusMixin._job_level_info(
-             completed_tasks=completed_tasks,
-             elapsed_time=elapsed_time,
-             interviews=interviews
+         interview_status_summary: InterviewStatisticsCollection = (
+             JobsRunnerStatusMixin._job_level_info(
+                 completed_tasks=completed_tasks,
+                 elapsed_time=elapsed_time,
+                 interviews=interviews,
+             )
          )
          if include_model_queues:
-             interview_status_summary.model_queues = list(JobsRunnerStatusMixin._get_model_queues_info(interviews))
+             interview_status_summary.model_queues = list(
+                 JobsRunnerStatusMixin._get_model_queues_info(interviews)
+             )
          else:
              interview_status_summary.model_queues = None
@@ -202,15 +216,21 @@ class JobsRunnerStatusMixin:

          token_usage_info = []
          for token_usage_type in ["new_token_usage", "cached_token_usage"]:
-             token_usage_info.append(JobsRunnerStatusMixin._get_token_usage_info(token_usage_type, models_to_tokens, model, prices))
+             token_usage_info.append(
+                 JobsRunnerStatusMixin._get_token_usage_info(
+                     token_usage_type, models_to_tokens, model, prices
+                 )
+             )

-         return ModelInfo(**{
-             "model_name": model.model,
-             "TPM_limit_k": model.TPM / 1000,
-             "RPM_limit_k": model.RPM / 1000,
-             "num_tasks_waiting": num_waiting,
-             "token_usage_info": token_usage_info,
-         })
+         return ModelInfo(
+             **{
+                 "model_name": model.model,
+                 "TPM_limit_k": model.TPM / 1000,
+                 "RPM_limit_k": model.RPM / 1000,
+                 "num_tasks_waiting": num_waiting,
+                 "token_usage_info": token_usage_info,
+             }
+         )

      @staticmethod
      def _get_token_usage_info(
@@ -232,13 +252,19 @@ class JobsRunnerStatusMixin:

          """
          all_token_usage: InterviewTokenUsage = models_to_tokens[model]
-         token_usage: TokenUsage = getattr(all_token_usage, token_usage_type)
-
-         details = [{"type": token_type, "tokens": getattr(token_usage, token_type)}
-                    for token_type in ["prompt_tokens", "completion_tokens"]]
-
-         return ModelTokenUsageStats(token_usage_type = token_usage_type, details = details, cost = f"${token_usage.cost(prices):.5f}")
-
+         token_usage: TokenUsage = getattr(all_token_usage, token_usage_type)
+
+         details = [
+             {"type": token_type, "tokens": getattr(token_usage, token_type)}
+             for token_type in ["prompt_tokens", "completion_tokens"]
+         ]
+
+         return ModelTokenUsageStats(
+             token_usage_type=token_usage_type,
+             details=details,
+             cost=f"${token_usage.cost(prices):.5f}",
+         )
+
      @staticmethod
      def _add_statistics_to_table(table, status_summary):
          table.add_column("Statistic", style="dim", no_wrap=True, width=50)
@@ -249,9 +275,7 @@ class JobsRunnerStatusMixin:
          table.add_row(key, value)

      @staticmethod
-     def display_status_table(status_summary: InterviewStatisticsCollection) -> 'Table':
-
-
+     def display_status_table(status_summary: InterviewStatisticsCollection) -> "Table":
          table = Table(
              title="Job Status",
              show_header=True,
@@ -268,24 +292,29 @@ class JobsRunnerStatusMixin:
          if status_summary.model_queues is not None:
              table.add_row(Text("Model Queues", style="bold red"), "")
              for model_info in status_summary.model_queues:
-
                  model_name = model_info.model_name
                  tpm = f"TPM (k)={model_info.TPM_limit_k}"
                  rpm = f"RPM (k)= {model_info.RPM_limit_k}"
                  pretty_model_name = model_name + ";" + tpm + ";" + rpm
                  table.add_row(Text(pretty_model_name, style="blue"), "")
-                 table.add_row("Number question tasks waiting for capacity", str(model_info.num_tasks_waiting))
+                 table.add_row(
+                     "Number question tasks waiting for capacity",
+                     str(model_info.num_tasks_waiting),
+                 )
                  # Token usage and cost info
                  for token_usage_info in model_info.token_usage_info:
                      token_usage_type = token_usage_info.token_usage_type
                      table.add_row(
-                         Text(spacing + token_usage_type.replace("_", " "), style="bold"), ""
+                         Text(
+                             spacing + token_usage_type.replace("_", " "), style="bold"
+                         ),
+                         "",
                      )
                      for detail in token_usage_info.details:
                          token_type = detail["type"]
                          tokens = detail["tokens"]
                          table.add_row(spacing + f"{token_type}", f"{tokens:,}")
-                         #table.add_row(spacing + "cost", cache_info["cost"])
+                         # table.add_row(spacing + "cost", cache_info["cost"])

          return table
@@ -297,6 +326,7 @@ class JobsRunnerStatusMixin:
          )
          return self.display_status_table(summary_data)

+
  if __name__ == "__main__":
      import doctest
edsl/jobs/tasks/TaskCreators.py
@@ -22,7 +22,7 @@ class TaskCreators(UserDict):
          This is iterates through all tasks that make up an interview.
          For each task, it determines how many tokens were used and whether they were cached or new.
          It then sums the total number of cached and new tokens used for the interview.
-
+
          """
          cached_tokens = TokenUsage(from_cache=True)
          new_tokens = TokenUsage(from_cache=False)