edsl 0.1.32__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. edsl/Base.py +9 -3
  2. edsl/TemplateLoader.py +24 -0
  3. edsl/__init__.py +8 -3
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +40 -8
  6. edsl/agents/AgentList.py +43 -0
  7. edsl/agents/Invigilator.py +135 -219
  8. edsl/agents/InvigilatorBase.py +148 -59
  9. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +138 -89
  10. edsl/agents/__init__.py +1 -0
  11. edsl/auto/AutoStudy.py +117 -0
  12. edsl/auto/StageBase.py +230 -0
  13. edsl/auto/StageGenerateSurvey.py +178 -0
  14. edsl/auto/StageLabelQuestions.py +125 -0
  15. edsl/auto/StagePersona.py +61 -0
  16. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  17. edsl/auto/StagePersonaDimensionValues.py +74 -0
  18. edsl/auto/StagePersonaDimensions.py +69 -0
  19. edsl/auto/StageQuestions.py +73 -0
  20. edsl/auto/SurveyCreatorPipeline.py +21 -0
  21. edsl/auto/utilities.py +224 -0
  22. edsl/config.py +47 -56
  23. edsl/coop/PriceFetcher.py +58 -0
  24. edsl/coop/coop.py +50 -7
  25. edsl/data/Cache.py +35 -1
  26. edsl/data_transfer_models.py +73 -38
  27. edsl/enums.py +4 -0
  28. edsl/exceptions/language_models.py +25 -1
  29. edsl/exceptions/questions.py +62 -5
  30. edsl/exceptions/results.py +4 -0
  31. edsl/inference_services/AnthropicService.py +13 -11
  32. edsl/inference_services/AwsBedrock.py +19 -17
  33. edsl/inference_services/AzureAI.py +37 -20
  34. edsl/inference_services/GoogleService.py +16 -12
  35. edsl/inference_services/GroqService.py +2 -0
  36. edsl/inference_services/InferenceServiceABC.py +58 -3
  37. edsl/inference_services/MistralAIService.py +120 -0
  38. edsl/inference_services/OpenAIService.py +48 -54
  39. edsl/inference_services/TestService.py +80 -0
  40. edsl/inference_services/TogetherAIService.py +170 -0
  41. edsl/inference_services/models_available_cache.py +0 -6
  42. edsl/inference_services/registry.py +6 -0
  43. edsl/jobs/Answers.py +10 -12
  44. edsl/jobs/FailedQuestion.py +78 -0
  45. edsl/jobs/Jobs.py +37 -22
  46. edsl/jobs/buckets/BucketCollection.py +24 -15
  47. edsl/jobs/buckets/TokenBucket.py +93 -14
  48. edsl/jobs/interviews/Interview.py +366 -78
  49. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +14 -68
  50. edsl/jobs/interviews/InterviewExceptionEntry.py +85 -19
  51. edsl/jobs/runners/JobsRunnerAsyncio.py +146 -175
  52. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  53. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  54. edsl/jobs/tasks/TaskHistory.py +148 -213
  55. edsl/language_models/LanguageModel.py +261 -156
  56. edsl/language_models/ModelList.py +2 -2
  57. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  58. edsl/language_models/fake_openai_call.py +15 -0
  59. edsl/language_models/fake_openai_service.py +61 -0
  60. edsl/language_models/registry.py +23 -6
  61. edsl/language_models/repair.py +0 -19
  62. edsl/language_models/utilities.py +61 -0
  63. edsl/notebooks/Notebook.py +20 -2
  64. edsl/prompts/Prompt.py +52 -2
  65. edsl/questions/AnswerValidatorMixin.py +23 -26
  66. edsl/questions/QuestionBase.py +330 -249
  67. edsl/questions/QuestionBaseGenMixin.py +133 -0
  68. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  69. edsl/questions/QuestionBudget.py +99 -41
  70. edsl/questions/QuestionCheckBox.py +227 -35
  71. edsl/questions/QuestionExtract.py +98 -27
  72. edsl/questions/QuestionFreeText.py +52 -29
  73. edsl/questions/QuestionFunctional.py +7 -0
  74. edsl/questions/QuestionList.py +141 -22
  75. edsl/questions/QuestionMultipleChoice.py +159 -65
  76. edsl/questions/QuestionNumerical.py +88 -46
  77. edsl/questions/QuestionRank.py +182 -24
  78. edsl/questions/Quick.py +41 -0
  79. edsl/questions/RegisterQuestionsMeta.py +31 -12
  80. edsl/questions/ResponseValidatorABC.py +170 -0
  81. edsl/questions/__init__.py +3 -4
  82. edsl/questions/decorators.py +21 -0
  83. edsl/questions/derived/QuestionLikertFive.py +10 -5
  84. edsl/questions/derived/QuestionLinearScale.py +15 -2
  85. edsl/questions/derived/QuestionTopK.py +10 -1
  86. edsl/questions/derived/QuestionYesNo.py +24 -3
  87. edsl/questions/descriptors.py +43 -7
  88. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  89. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  90. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  91. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  92. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  93. edsl/questions/prompt_templates/question_list.jinja +17 -0
  94. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  95. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  96. edsl/questions/question_registry.py +6 -2
  97. edsl/questions/templates/__init__.py +0 -0
  98. edsl/questions/templates/budget/__init__.py +0 -0
  99. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  100. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  101. edsl/questions/templates/checkbox/__init__.py +0 -0
  102. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  103. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  104. edsl/questions/templates/extract/__init__.py +0 -0
  105. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  106. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  107. edsl/questions/templates/free_text/__init__.py +0 -0
  108. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  109. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  110. edsl/questions/templates/likert_five/__init__.py +0 -0
  111. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  112. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  113. edsl/questions/templates/linear_scale/__init__.py +0 -0
  114. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  115. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  116. edsl/questions/templates/list/__init__.py +0 -0
  117. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  118. edsl/questions/templates/list/question_presentation.jinja +5 -0
  119. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  120. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  121. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  122. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  123. edsl/questions/templates/numerical/__init__.py +0 -0
  124. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  125. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  126. edsl/questions/templates/rank/__init__.py +0 -0
  127. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  128. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  129. edsl/questions/templates/top_k/__init__.py +0 -0
  130. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  131. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  132. edsl/questions/templates/yes_no/__init__.py +0 -0
  133. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  134. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  135. edsl/results/Dataset.py +20 -0
  136. edsl/results/DatasetExportMixin.py +46 -48
  137. edsl/results/DatasetTree.py +145 -0
  138. edsl/results/Result.py +32 -5
  139. edsl/results/Results.py +135 -46
  140. edsl/results/ResultsDBMixin.py +3 -3
  141. edsl/results/Selector.py +118 -0
  142. edsl/results/tree_explore.py +115 -0
  143. edsl/scenarios/FileStore.py +71 -10
  144. edsl/scenarios/Scenario.py +96 -25
  145. edsl/scenarios/ScenarioImageMixin.py +2 -2
  146. edsl/scenarios/ScenarioList.py +361 -39
  147. edsl/scenarios/ScenarioListExportMixin.py +9 -0
  148. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  149. edsl/study/SnapShot.py +8 -1
  150. edsl/study/Study.py +32 -0
  151. edsl/surveys/Rule.py +10 -1
  152. edsl/surveys/RuleCollection.py +21 -5
  153. edsl/surveys/Survey.py +637 -311
  154. edsl/surveys/SurveyExportMixin.py +71 -9
  155. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  156. edsl/surveys/SurveyQualtricsImport.py +75 -4
  157. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  158. edsl/surveys/instructions/Instruction.py +34 -0
  159. edsl/surveys/instructions/InstructionCollection.py +77 -0
  160. edsl/surveys/instructions/__init__.py +0 -0
  161. edsl/templates/error_reporting/base.html +24 -0
  162. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  163. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  164. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  165. edsl/templates/error_reporting/interview_details.html +116 -0
  166. edsl/templates/error_reporting/interviews.html +10 -0
  167. edsl/templates/error_reporting/overview.html +5 -0
  168. edsl/templates/error_reporting/performance_plot.html +2 -0
  169. edsl/templates/error_reporting/report.css +74 -0
  170. edsl/templates/error_reporting/report.html +118 -0
  171. edsl/templates/error_reporting/report.js +25 -0
  172. edsl/utilities/utilities.py +9 -1
  173. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/METADATA +5 -2
  174. edsl-0.1.33.dist-info/RECORD +295 -0
  175. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
  176. edsl/jobs/interviews/retry_management.py +0 -37
  177. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  178. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  179. edsl-0.1.32.dist-info/RECORD +0 -209
  180. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
  181. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
@@ -0,0 +1,331 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from dataclasses import dataclass, asdict
5
+
6
+ from typing import List, DefaultDict, Optional, Type, Literal
7
+ from collections import UserDict, defaultdict
8
+
9
+ from rich.text import Text
10
+ from rich.box import SIMPLE
11
+ from rich.table import Table
12
+ from rich.live import Live
13
+ from rich.panel import Panel
14
+ from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn
15
+ from rich.layout import Layout
16
+ from rich.console import Group
17
+ from rich import box
18
+
19
+ from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
20
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
21
+ from edsl.jobs.tokens.TokenUsage import TokenUsage
22
+ from edsl.enums import get_token_pricing
23
+ from edsl.jobs.tasks.task_status_enum import TaskStatus
24
+
25
+ InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
26
+
27
+ from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
28
+ from edsl.jobs.interviews.InterviewStatisticsCollection import (
29
+ InterviewStatisticsCollection,
30
+ )
31
+ from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
32
+
33
+
34
+ @dataclass
35
+ class ModelInfo:
36
+ model_name: str
37
+ TPM_limit_k: float
38
+ RPM_limit_k: float
39
+ num_tasks_waiting: int
40
+ token_usage_info: dict
41
+
42
+
43
+ @dataclass
44
+ class ModelTokenUsageStats:
45
+ token_usage_type: str
46
+ details: List[dict]
47
+ cost: str
48
+
49
+
50
+ class Stats:
51
+ def elapsed_time(self):
52
+ InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")
53
+
54
+
55
+ class JobsRunnerStatus:
56
+ def __init__(
57
+ self, jobs_runner: "JobsRunnerAsyncio", n: int, refresh_rate: float = 0.25
58
+ ):
59
+ self.jobs_runner = jobs_runner
60
+ self.start_time = time.time()
61
+ self.completed_interviews = []
62
+ self.refresh_rate = refresh_rate
63
+ self.statistics = [
64
+ "elapsed_time",
65
+ "total_interviews_requested",
66
+ "completed_interviews",
67
+ # "percent_complete",
68
+ "average_time_per_interview",
69
+ # "task_remaining",
70
+ "estimated_time_remaining",
71
+ "exceptions",
72
+ "unfixed_exceptions",
73
+ "throughput",
74
+ ]
75
+ self.num_total_interviews = n * len(self.jobs_runner.interviews)
76
+
77
+ self.distinct_models = list(
78
+ set(i.model.model for i in self.jobs_runner.interviews)
79
+ )
80
+
81
+ self.completed_interview_by_model = defaultdict(list)
82
+
83
+ def add_completed_interview(self, result):
84
+ self.completed_interviews.append(result.interview_hash)
85
+
86
+ relevant_model = result.model.model
87
+ self.completed_interview_by_model[relevant_model].append(result.interview_hash)
88
+
89
+ def _compute_statistic(self, stat_name: str):
90
+ completed_tasks = self.completed_interviews
91
+ elapsed_time = time.time() - self.start_time
92
+ interviews = self.jobs_runner.total_interviews
93
+
94
+ stat_definitions = {
95
+ "elapsed_time": lambda: InterviewStatistic(
96
+ "elapsed_time", value=elapsed_time, digits=1, units="sec."
97
+ ),
98
+ "total_interviews_requested": lambda: InterviewStatistic(
99
+ "total_interviews_requested", value=len(interviews), units=""
100
+ ),
101
+ "completed_interviews": lambda: InterviewStatistic(
102
+ "completed_interviews", value=len(completed_tasks), units=""
103
+ ),
104
+ "percent_complete": lambda: InterviewStatistic(
105
+ "percent_complete",
106
+ value=(
107
+ len(completed_tasks) / len(interviews) * 100
108
+ if len(interviews) > 0
109
+ else 0
110
+ ),
111
+ digits=1,
112
+ units="%",
113
+ ),
114
+ "average_time_per_interview": lambda: InterviewStatistic(
115
+ "average_time_per_interview",
116
+ value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
117
+ digits=2,
118
+ units="sec.",
119
+ ),
120
+ "task_remaining": lambda: InterviewStatistic(
121
+ "task_remaining", value=len(interviews) - len(completed_tasks), units=""
122
+ ),
123
+ "estimated_time_remaining": lambda: InterviewStatistic(
124
+ "estimated_time_remaining",
125
+ value=(
126
+ (len(interviews) - len(completed_tasks))
127
+ * (elapsed_time / len(completed_tasks))
128
+ if len(completed_tasks) > 0
129
+ else 0
130
+ ),
131
+ digits=1,
132
+ units="sec.",
133
+ ),
134
+ "exceptions": lambda: InterviewStatistic(
135
+ "exceptions",
136
+ value=sum(len(i.exceptions) for i in interviews),
137
+ units="",
138
+ ),
139
+ "unfixed_exceptions": lambda: InterviewStatistic(
140
+ "unfixed_exceptions",
141
+ value=sum(i.exceptions.num_unfixed() for i in interviews),
142
+ units="",
143
+ ),
144
+ "throughput": lambda: InterviewStatistic(
145
+ "throughput",
146
+ value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
147
+ digits=2,
148
+ units="interviews/sec.",
149
+ ),
150
+ }
151
+ return stat_definitions[stat_name]()
152
+
153
+ def create_progress_bar(self):
154
+ return Progress(
155
+ TextColumn("[progress.description]{task.description}"),
156
+ BarColumn(),
157
+ TaskProgressColumn(),
158
+ TextColumn("{task.completed}/{task.total}"),
159
+ )
160
+
161
+ def generate_model_queues_table(self):
162
+ table = Table(show_header=False, box=box.SIMPLE)
163
+ table.add_column("Info", style="cyan")
164
+ table.add_column("Value", style="magenta")
165
+ # table.add_row("Bucket collection", str(self.jobs_runner.bucket_collection))
166
+ for model, bucket in self.jobs_runner.bucket_collection.items():
167
+ table.add_row(Text(model.model, style="bold blue"), "")
168
+ bucket_types = ["requests_bucket", "tokens_bucket"]
169
+ for bucket_type in bucket_types:
170
+ table.add_row(Text(" " + bucket_type, style="green"), "")
171
+ # table.add_row(
172
+ # f" Current level (capacity = {round(getattr(bucket, bucket_type).capacity, 3)})",
173
+ # str(round(getattr(bucket, bucket_type).tokens, 3)),
174
+ # )
175
+ num_requests = getattr(bucket, bucket_type).num_requests
176
+ num_released = getattr(bucket, bucket_type).num_released
177
+ tokens_returned = getattr(bucket, bucket_type).tokens_returned
178
+ # table.add_row(
179
+ # f" Requested",
180
+ # str(num_requests),
181
+ # )
182
+ # table.add_row(
183
+ # f" Completed",
184
+ # str(num_released),
185
+ # )
186
+ table.add_row(
187
+ " Completed vs. Requested", f"{num_released} vs. {num_requests}"
188
+ )
189
+ table.add_row(
190
+ " Added tokens (from cache)",
191
+ str(tokens_returned),
192
+ )
193
+ if bucket_type == "tokens_bucket":
194
+ rate_name = "TPM"
195
+ else:
196
+ rate_name = "RPM"
197
+ target_rate = round(getattr(bucket, bucket_type).target_rate, 1)
198
+ table.add_row(
199
+ f" Empirical {rate_name} (target = {target_rate})",
200
+ str(round(getattr(bucket, bucket_type).get_throughput(), 0)),
201
+ )
202
+
203
+ return table
204
+
205
+ def generate_layout(self):
206
+ progress = self.create_progress_bar()
207
+ task_ids = []
208
+ for model in self.distinct_models:
209
+ task_id = progress.add_task(
210
+ f"[cyan]{model}...",
211
+ total=int(self.num_total_interviews / len(self.distinct_models)),
212
+ )
213
+ task_ids.append((model, task_id))
214
+
215
+ progress_height = min(5, 2 + len(self.distinct_models))
216
+ layout = Layout()
217
+
218
+ # Create the top row with only the progress panel
219
+ layout.split_column(
220
+ Layout(
221
+ Panel(
222
+ progress,
223
+ title="Interview Progress",
224
+ border_style="cyan",
225
+ box=box.ROUNDED,
226
+ ),
227
+ name="progress",
228
+ size=progress_height, # Adjusted size
229
+ ),
230
+ Layout(name="bottom_row"), # Adjusted size
231
+ )
232
+
233
+ # Split the bottom row into two columns for metrics and model queues
234
+ layout["bottom_row"].split_row(
235
+ Layout(
236
+ Panel(
237
+ self.generate_metrics_table(),
238
+ title="Metrics",
239
+ border_style="magenta",
240
+ box=box.ROUNDED,
241
+ ),
242
+ name="metrics",
243
+ ),
244
+ Layout(
245
+ Panel(
246
+ self.generate_model_queues_table(),
247
+ title="Model Queues",
248
+ border_style="yellow",
249
+ box=box.ROUNDED,
250
+ ),
251
+ name="model_queues",
252
+ ),
253
+ )
254
+
255
+ return layout, progress, task_ids
256
+
257
+ def generate_metrics_table(self):
258
+ table = Table(show_header=True, header_style="bold magenta", box=box.SIMPLE)
259
+ table.add_column("Metric", style="cyan", no_wrap=True)
260
+ table.add_column("Value", justify="right")
261
+
262
+ for stat_name in self.statistics:
263
+ pretty_name, value = list(self._compute_statistic(stat_name).items())[0]
264
+ # breakpoint()
265
+ table.add_row(pretty_name, value)
266
+ return table
267
+
268
+ def update_progress(self):
269
+ layout, progress, task_ids = self.generate_layout()
270
+
271
+ with Live(
272
+ layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
273
+ ) as live:
274
+ while len(self.completed_interviews) < len(
275
+ self.jobs_runner.total_interviews
276
+ ):
277
+ completed_tasks = len(self.completed_interviews)
278
+ total_tasks = len(self.jobs_runner.total_interviews)
279
+
280
+ for model, task_id in task_ids:
281
+ completed_tasks = len(self.completed_interview_by_model[model])
282
+ progress.update(
283
+ task_id,
284
+ completed=completed_tasks,
285
+ description=f"[cyan]Conducting interviews for {model}...",
286
+ )
287
+
288
+ layout["metrics"].update(
289
+ Panel(
290
+ self.generate_metrics_table(),
291
+ title="Metrics",
292
+ border_style="magenta",
293
+ box=box.ROUNDED,
294
+ )
295
+ )
296
+ layout["model_queues"].update(
297
+ Panel(
298
+ self.generate_model_queues_table(),
299
+ title="Final Model Queues",
300
+ border_style="yellow",
301
+ box=box.ROUNDED,
302
+ )
303
+ )
304
+
305
+ time.sleep(self.refresh_rate)
306
+
307
+ # Final update
308
+ for model, task_id in task_ids:
309
+ completed_tasks = len(self.completed_interview_by_model[model])
310
+ progress.update(
311
+ task_id,
312
+ completed=completed_tasks,
313
+ description=f"[cyan]Conducting interviews for {model}...",
314
+ )
315
+
316
+ layout["metrics"].update(
317
+ Panel(
318
+ self.generate_metrics_table(),
319
+ title="Final Metrics",
320
+ border_style="magenta",
321
+ box=box.ROUNDED,
322
+ )
323
+ )
324
+ live.update(layout)
325
+ time.sleep(1) # Show final state for 1 second
326
+
327
+
328
+ if __name__ == "__main__":
329
+ import doctest
330
+
331
+ doctest.testmod(optionflags=doctest.ELLIPSIS)
@@ -55,6 +55,7 @@ class QuestionTaskCreator(UserList):
55
55
 
56
56
  """
57
57
  super().__init__([])
58
+ # answer_question_func is the 'interview.answer_question_and_record_task" method
58
59
  self.answer_question_func = answer_question_func
59
60
  self.question = question
60
61
  self.iteration = iteration
@@ -87,10 +88,10 @@ class QuestionTaskCreator(UserList):
87
88
  """
88
89
  self.append(task)
89
90
 
90
- def generate_task(self, debug: bool) -> asyncio.Task:
91
+ def generate_task(self) -> asyncio.Task:
91
92
  """Create a task that depends on the passed-in dependencies."""
92
93
  task = asyncio.create_task(
93
- self._run_task_async(debug), name=self.question.question_name
94
+ self._run_task_async(), name=self.question.question_name
94
95
  )
95
96
  task.depends_on = [t.get_name() for t in self]
96
97
  return task
@@ -103,7 +104,7 @@ class QuestionTaskCreator(UserList):
103
104
  """Returns the token usage for the task.
104
105
 
105
106
  >>> qt = QuestionTaskCreator.example()
106
- >>> answers = asyncio.run(qt._run_focal_task(debug=False))
107
+ >>> answers = asyncio.run(qt._run_focal_task())
107
108
  >>> qt.token_usage()
108
109
  {'cached_tokens': TokenUsage(from_cache=True, prompt_tokens=0, completion_tokens=0), 'new_tokens': TokenUsage(from_cache=False, prompt_tokens=0, completion_tokens=0)}
109
110
  """
@@ -111,15 +112,15 @@ class QuestionTaskCreator(UserList):
111
112
  cached_tokens=self.cached_token_usage, new_tokens=self.new_token_usage
112
113
  )
113
114
 
114
- async def _run_focal_task(self, debug: bool) -> Answers:
115
+ async def _run_focal_task(self) -> Answers:
115
116
  """Run the focal task i.e., the question that we are interested in answering.
116
117
 
117
118
  It is only called after all the dependency tasks are completed.
118
119
 
119
120
  >>> qt = QuestionTaskCreator.example()
120
- >>> answers = asyncio.run(qt._run_focal_task(debug=False))
121
- >>> answers["answer"]
122
- 'Yo!'
121
+ >>> answers = asyncio.run(qt._run_focal_task())
122
+ >>> answers.answer
123
+ 'This is an example answer'
123
124
  """
124
125
 
125
126
  requested_tokens = self.estimated_tokens()
@@ -132,19 +133,19 @@ class QuestionTaskCreator(UserList):
132
133
  self.waiting = True
133
134
  self.task_status = TaskStatus.WAITING_FOR_REQUEST_CAPACITY
134
135
 
135
- await self.tokens_bucket.get_tokens(1)
136
+ await self.requests_bucket.get_tokens(1, cheat_bucket_capacity=True)
136
137
 
137
138
  self.task_status = TaskStatus.API_CALL_IN_PROGRESS
138
139
  try:
139
140
  results = await self.answer_question_func(
140
- question=self.question, debug=debug, task=None # self
141
+ question=self.question, task=None # self
141
142
  )
142
143
  self.task_status = TaskStatus.SUCCESS
143
144
  except Exception as e:
144
145
  self.task_status = TaskStatus.FAILED
145
146
  raise e
146
147
 
147
- if results.get("cache_used", False):
148
+ if results.cache_used:
148
149
  self.tokens_bucket.add_tokens(requested_tokens)
149
150
  self.requests_bucket.add_tokens(1)
150
151
  self.from_cache = True
@@ -155,17 +156,18 @@ class QuestionTaskCreator(UserList):
155
156
  self.tokens_bucket.turbo_mode_off()
156
157
  self.requests_bucket.turbo_mode_off()
157
158
 
158
- _ = results.pop("cached_response", None)
159
+ # breakpoint()
160
+ # _ = results.pop("cached_response", None)
159
161
 
160
- tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
162
+ # tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
161
163
 
162
164
  # TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
163
- usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
164
- prompt_tokens = usage.get("prompt_tokens", 0)
165
- completion_tokens = usage.get("completion_tokens", 0)
166
- tracker.add_tokens(
167
- prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
168
- )
165
+ # usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
166
+ # prompt_tokens = usage.get("prompt_tokens", 0)
167
+ # completion_tokens = usage.get("completion_tokens", 0)
168
+ # tracker.add_tokens(
169
+ # prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
170
+ # )
169
171
 
170
172
  return results
171
173
 
@@ -177,8 +179,13 @@ class QuestionTaskCreator(UserList):
177
179
 
178
180
  m = ModelBuckets.infinity_bucket()
179
181
 
180
- async def answer_question_func(question, debug, task):
181
- return {"answer": "Yo!"}
182
+ from collections import namedtuple
183
+
184
+ AnswerDict = namedtuple("AnswerDict", ["answer", "cache_used"])
185
+ answer = AnswerDict(answer="This is an example answer", cache_used=False)
186
+
187
+ async def answer_question_func(question, task):
188
+ return answer
182
189
 
183
190
  return cls(
184
191
  question=QuestionFreeText.example(),
@@ -188,7 +195,7 @@ class QuestionTaskCreator(UserList):
188
195
  iteration=0,
189
196
  )
190
197
 
191
- async def _run_task_async(self, debug) -> None:
198
+ async def _run_task_async(self) -> None:
192
199
  """Run the task asynchronously, awaiting the tasks that must be completed before this one can be run.
193
200
 
194
201
  >>> qt1 = QuestionTaskCreator.example()
@@ -231,8 +238,6 @@ class QuestionTaskCreator(UserList):
231
238
  if isinstance(result, Exception):
232
239
  raise result
233
240
 
234
- return await self._run_focal_task(debug)
235
-
236
241
  except asyncio.CancelledError:
237
242
  self.task_status = TaskStatus.CANCELLED
238
243
  raise
@@ -244,6 +249,8 @@ class QuestionTaskCreator(UserList):
244
249
  f"Required tasks failed for {self.question.question_name}"
245
250
  ) from e
246
251
 
252
+ return await self._run_focal_task()
253
+
247
254
 
248
255
  if __name__ == "__main__":
249
256
  import doctest