edsl 0.1.32__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +9 -3
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -3
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +40 -8
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +135 -219
- edsl/agents/InvigilatorBase.py +148 -59
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +138 -89
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +47 -56
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +50 -7
- edsl/data/Cache.py +35 -1
- edsl/data_transfer_models.py +73 -38
- edsl/enums.py +4 -0
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +19 -17
- edsl/inference_services/AzureAI.py +37 -20
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +2 -0
- edsl/inference_services/InferenceServiceABC.py +58 -3
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OpenAIService.py +48 -54
- edsl/inference_services/TestService.py +80 -0
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/models_available_cache.py +0 -6
- edsl/inference_services/registry.py +6 -0
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +37 -22
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +93 -14
- edsl/jobs/interviews/Interview.py +366 -78
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +14 -68
- edsl/jobs/interviews/InterviewExceptionEntry.py +85 -19
- edsl/jobs/runners/JobsRunnerAsyncio.py +146 -175
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskHistory.py +148 -213
- edsl/language_models/LanguageModel.py +261 -156
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +23 -6
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/notebooks/Notebook.py +20 -2
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +330 -249
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +99 -41
- edsl/questions/QuestionCheckBox.py +227 -35
- edsl/questions/QuestionExtract.py +98 -27
- edsl/questions/QuestionFreeText.py +52 -29
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -22
- edsl/questions/QuestionMultipleChoice.py +159 -65
- edsl/questions/QuestionNumerical.py +88 -46
- edsl/questions/QuestionRank.py +182 -24
- edsl/questions/Quick.py +41 -0
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +170 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +15 -2
- edsl/questions/derived/QuestionTopK.py +10 -1
- edsl/questions/derived/QuestionYesNo.py +24 -3
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +46 -48
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +135 -46
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +96 -25
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +361 -39
- edsl/scenarios/ScenarioListExportMixin.py +9 -0
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/SnapShot.py +8 -1
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +10 -1
- edsl/surveys/RuleCollection.py +21 -5
- edsl/surveys/Survey.py +637 -311
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +75 -4
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/METADATA +5 -2
- edsl-0.1.33.dist-info/RECORD +295 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
- edsl/jobs/interviews/retry_management.py +0 -37
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.32.dist-info/RECORD +0 -209
- {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
- {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
@@ -0,0 +1,331 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import time
|
4
|
+
from dataclasses import dataclass, asdict
|
5
|
+
|
6
|
+
from typing import List, DefaultDict, Optional, Type, Literal
|
7
|
+
from collections import UserDict, defaultdict
|
8
|
+
|
9
|
+
from rich.text import Text
|
10
|
+
from rich.box import SIMPLE
|
11
|
+
from rich.table import Table
|
12
|
+
from rich.live import Live
|
13
|
+
from rich.panel import Panel
|
14
|
+
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn
|
15
|
+
from rich.layout import Layout
|
16
|
+
from rich.console import Group
|
17
|
+
from rich import box
|
18
|
+
|
19
|
+
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
20
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
21
|
+
from edsl.jobs.tokens.TokenUsage import TokenUsage
|
22
|
+
from edsl.enums import get_token_pricing
|
23
|
+
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
24
|
+
|
25
|
+
InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
|
26
|
+
|
27
|
+
from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
|
28
|
+
from edsl.jobs.interviews.InterviewStatisticsCollection import (
|
29
|
+
InterviewStatisticsCollection,
|
30
|
+
)
|
31
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
32
|
+
|
33
|
+
|
34
|
+
@dataclass
class ModelInfo:
    """Static snapshot of one model's rate limits and queue state.

    NOTE(review): not constructed anywhere in this file — presumably
    filled in by the jobs runner; confirm against callers.
    """

    # Provider model identifier, e.g. "gpt-4o".
    model_name: str
    # Tokens-per-minute limit, in thousands.
    TPM_limit_k: float
    # Requests-per-minute limit, in thousands.
    RPM_limit_k: float
    # Number of tasks currently waiting on this model.
    num_tasks_waiting: int
    # Free-form token-usage details for display.
    token_usage_info: dict
|
41
|
+
|
42
|
+
|
43
|
+
@dataclass
class ModelTokenUsageStats:
    """Token-usage summary for one usage category of a model.

    NOTE(review): not constructed anywhere in this file — presumably
    produced by the jobs runner for reporting; confirm against callers.
    """

    # Category label, e.g. cached vs. new token usage.
    token_usage_type: str
    # Per-item breakdown rows for display.
    details: List[dict]
    # Pre-formatted cost string.
    cost: str
|
48
|
+
|
49
|
+
|
50
|
+
class Stats:
    """Scratch container for summary statistics.

    NOTE(review): `elapsed_time` references a local name `elapsed_time`
    that is never defined and discards the constructed statistic, so
    calling it raises NameError. This looks like dead scaffolding left
    over from `JobsRunnerStatus._compute_statistic` — confirm before use.
    """

    def elapsed_time(self):
        # BUG?: `elapsed_time` is undefined here and the result is
        # neither returned nor stored.
        InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")
|
53
|
+
|
54
|
+
|
55
|
+
class JobsRunnerStatus:
    """Live rich-console status display for a `JobsRunnerAsyncio` run.

    Tracks completed interviews — overall and per model — and renders a
    layout with a progress bar, a metrics table, and per-model bucket
    queue details, refreshing until every interview has completed.
    """

    def __init__(
        self, jobs_runner: "JobsRunnerAsyncio", n: int, refresh_rate: float = 0.25
    ):
        """Set up counters and display configuration.

        :param jobs_runner: the runner whose progress is being displayed.
        :param n: iterations per interview; total work shown is
            ``n * len(jobs_runner.interviews)``.
        :param refresh_rate: seconds between display refreshes.
        """
        self.jobs_runner = jobs_runner
        self.start_time = time.time()
        # Hashes of completed interviews (see add_completed_interview).
        self.completed_interviews = []
        self.refresh_rate = refresh_rate
        # Names of statistics rendered in the metrics table, in order;
        # each must have an entry in _compute_statistic's dispatch dict.
        self.statistics = [
            "elapsed_time",
            "total_interviews_requested",
            "completed_interviews",
            # "percent_complete",
            "average_time_per_interview",
            # "task_remaining",
            "estimated_time_remaining",
            "exceptions",
            "unfixed_exceptions",
            "throughput",
        ]
        self.num_total_interviews = n * len(self.jobs_runner.interviews)

        # Unique model names across all interviews; one progress bar each.
        self.distinct_models = list(
            set(i.model.model for i in self.jobs_runner.interviews)
        )

        # model name -> list of completed interview hashes for that model.
        self.completed_interview_by_model = defaultdict(list)

    def add_completed_interview(self, result):
        """Record a finished interview, both overall and per model."""
        self.completed_interviews.append(result.interview_hash)

        relevant_model = result.model.model
        self.completed_interview_by_model[relevant_model].append(result.interview_hash)

    def _compute_statistic(self, stat_name: str):
        """Build the `InterviewStatistic` named ``stat_name`` from current state.

        Raises KeyError for names not in the dispatch dict below.
        """
        completed_tasks = self.completed_interviews
        elapsed_time = time.time() - self.start_time
        interviews = self.jobs_runner.total_interviews

        # Lazy lambdas so only the requested statistic is computed.
        stat_definitions = {
            "elapsed_time": lambda: InterviewStatistic(
                "elapsed_time", value=elapsed_time, digits=1, units="sec."
            ),
            "total_interviews_requested": lambda: InterviewStatistic(
                "total_interviews_requested", value=len(interviews), units=""
            ),
            "completed_interviews": lambda: InterviewStatistic(
                "completed_interviews", value=len(completed_tasks), units=""
            ),
            "percent_complete": lambda: InterviewStatistic(
                "percent_complete",
                value=(
                    len(completed_tasks) / len(interviews) * 100
                    if len(interviews) > 0
                    else 0
                ),
                digits=1,
                units="%",
            ),
            "average_time_per_interview": lambda: InterviewStatistic(
                "average_time_per_interview",
                value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
                digits=2,
                units="sec.",
            ),
            "task_remaining": lambda: InterviewStatistic(
                "task_remaining", value=len(interviews) - len(completed_tasks), units=""
            ),
            "estimated_time_remaining": lambda: InterviewStatistic(
                "estimated_time_remaining",
                # Simple linear extrapolation: remaining * avg time so far.
                value=(
                    (len(interviews) - len(completed_tasks))
                    * (elapsed_time / len(completed_tasks))
                    if len(completed_tasks) > 0
                    else 0
                ),
                digits=1,
                units="sec.",
            ),
            "exceptions": lambda: InterviewStatistic(
                "exceptions",
                value=sum(len(i.exceptions) for i in interviews),
                units="",
            ),
            "unfixed_exceptions": lambda: InterviewStatistic(
                "unfixed_exceptions",
                value=sum(i.exceptions.num_unfixed() for i in interviews),
                units="",
            ),
            "throughput": lambda: InterviewStatistic(
                "throughput",
                value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
                digits=2,
                units="interviews/sec.",
            ),
        }
        return stat_definitions[stat_name]()

    def create_progress_bar(self):
        """Return a rich Progress widget: description, bar, %, and counts."""
        return Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("{task.completed}/{task.total}"),
        )

    def generate_model_queues_table(self):
        """Build a table of request/token bucket state per model."""
        table = Table(show_header=False, box=box.SIMPLE)
        table.add_column("Info", style="cyan")
        table.add_column("Value", style="magenta")
        # table.add_row("Bucket collection", str(self.jobs_runner.bucket_collection))
        for model, bucket in self.jobs_runner.bucket_collection.items():
            table.add_row(Text(model.model, style="bold blue"), "")
            bucket_types = ["requests_bucket", "tokens_bucket"]
            for bucket_type in bucket_types:
                table.add_row(Text(" " + bucket_type, style="green"), "")
                # table.add_row(
                #     f" Current level (capacity = {round(getattr(bucket, bucket_type).capacity, 3)})",
                #     str(round(getattr(bucket, bucket_type).tokens, 3)),
                # )
                num_requests = getattr(bucket, bucket_type).num_requests
                num_released = getattr(bucket, bucket_type).num_released
                tokens_returned = getattr(bucket, bucket_type).tokens_returned
                # table.add_row(
                #     f" Requested",
                #     str(num_requests),
                # )
                # table.add_row(
                #     f" Completed",
                #     str(num_released),
                # )
                table.add_row(
                    " Completed vs. Requested", f"{num_released} vs. {num_requests}"
                )
                table.add_row(
                    " Added tokens (from cache)",
                    str(tokens_returned),
                )
                # TPM for the token bucket, RPM for the request bucket.
                if bucket_type == "tokens_bucket":
                    rate_name = "TPM"
                else:
                    rate_name = "RPM"
                target_rate = round(getattr(bucket, bucket_type).target_rate, 1)
                table.add_row(
                    f" Empirical {rate_name} (target = {target_rate})",
                    str(round(getattr(bucket, bucket_type).get_throughput(), 0)),
                )

        return table

    def generate_layout(self):
        """Create the display layout: progress on top, metrics and model
        queues side by side below.

        Returns ``(layout, progress, task_ids)`` where ``task_ids`` pairs
        each model name with its progress-bar task id.
        """
        progress = self.create_progress_bar()
        task_ids = []
        for model in self.distinct_models:
            # NOTE(review): per-model totals assume an even split of
            # interviews across models — confirm that is intended.
            task_id = progress.add_task(
                f"[cyan]{model}...",
                total=int(self.num_total_interviews / len(self.distinct_models)),
            )
            task_ids.append((model, task_id))

        # Cap the progress panel height; 2 rows of chrome + 1 per model.
        progress_height = min(5, 2 + len(self.distinct_models))
        layout = Layout()

        # Create the top row with only the progress panel
        layout.split_column(
            Layout(
                Panel(
                    progress,
                    title="Interview Progress",
                    border_style="cyan",
                    box=box.ROUNDED,
                ),
                name="progress",
                size=progress_height,  # Adjusted size
            ),
            Layout(name="bottom_row"),  # Adjusted size
        )

        # Split the bottom row into two columns for metrics and model queues
        layout["bottom_row"].split_row(
            Layout(
                Panel(
                    self.generate_metrics_table(),
                    title="Metrics",
                    border_style="magenta",
                    box=box.ROUNDED,
                ),
                name="metrics",
            ),
            Layout(
                Panel(
                    self.generate_model_queues_table(),
                    title="Model Queues",
                    border_style="yellow",
                    box=box.ROUNDED,
                ),
                name="model_queues",
            ),
        )

        return layout, progress, task_ids

    def generate_metrics_table(self):
        """Build the metric-name/value table from self.statistics."""
        table = Table(show_header=True, header_style="bold magenta", box=box.SIMPLE)
        table.add_column("Metric", style="cyan", no_wrap=True)
        table.add_column("Value", justify="right")

        for stat_name in self.statistics:
            # Each InterviewStatistic maps a pretty name to a display value.
            pretty_name, value = list(self._compute_statistic(stat_name).items())[0]
            # breakpoint()
            table.add_row(pretty_name, value)
        return table

    def update_progress(self):
        """Blocking display loop: refresh the live layout until all
        interviews complete, then show a final state for one second.

        NOTE(review): runs synchronously with time.sleep — presumably
        executed on its own thread alongside the async runner; confirm.
        """
        layout, progress, task_ids = self.generate_layout()

        with Live(
            layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
        ) as live:
            while len(self.completed_interviews) < len(
                self.jobs_runner.total_interviews
            ):
                completed_tasks = len(self.completed_interviews)
                total_tasks = len(self.jobs_runner.total_interviews)

                for model, task_id in task_ids:
                    completed_tasks = len(self.completed_interview_by_model[model])
                    progress.update(
                        task_id,
                        completed=completed_tasks,
                        description=f"[cyan]Conducting interviews for {model}...",
                    )

                layout["metrics"].update(
                    Panel(
                        self.generate_metrics_table(),
                        title="Metrics",
                        border_style="magenta",
                        box=box.ROUNDED,
                    )
                )
                layout["model_queues"].update(
                    Panel(
                        self.generate_model_queues_table(),
                        title="Final Model Queues",
                        border_style="yellow",
                        box=box.ROUNDED,
                    )
                )

                time.sleep(self.refresh_rate)

            # Final update
            for model, task_id in task_ids:
                completed_tasks = len(self.completed_interview_by_model[model])
                progress.update(
                    task_id,
                    completed=completed_tasks,
                    description=f"[cyan]Conducting interviews for {model}...",
                )

            layout["metrics"].update(
                Panel(
                    self.generate_metrics_table(),
                    title="Final Metrics",
                    border_style="magenta",
                    box=box.ROUNDED,
                )
            )
            live.update(layout)
            time.sleep(1)  # Show final state for 1 second
|
326
|
+
|
327
|
+
|
328
|
+
if __name__ == "__main__":
    # Execute this module's doctests when run as a script.
    from doctest import ELLIPSIS, testmod

    testmod(optionflags=ELLIPSIS)
|
@@ -55,6 +55,7 @@ class QuestionTaskCreator(UserList):
|
|
55
55
|
|
56
56
|
"""
|
57
57
|
super().__init__([])
|
58
|
+
# answer_question_func is the 'interview.answer_question_and_record_task" method
|
58
59
|
self.answer_question_func = answer_question_func
|
59
60
|
self.question = question
|
60
61
|
self.iteration = iteration
|
@@ -87,10 +88,10 @@ class QuestionTaskCreator(UserList):
|
|
87
88
|
"""
|
88
89
|
self.append(task)
|
89
90
|
|
90
|
-
def generate_task(self
|
91
|
+
def generate_task(self) -> asyncio.Task:
|
91
92
|
"""Create a task that depends on the passed-in dependencies."""
|
92
93
|
task = asyncio.create_task(
|
93
|
-
self._run_task_async(
|
94
|
+
self._run_task_async(), name=self.question.question_name
|
94
95
|
)
|
95
96
|
task.depends_on = [t.get_name() for t in self]
|
96
97
|
return task
|
@@ -103,7 +104,7 @@ class QuestionTaskCreator(UserList):
|
|
103
104
|
"""Returns the token usage for the task.
|
104
105
|
|
105
106
|
>>> qt = QuestionTaskCreator.example()
|
106
|
-
>>> answers = asyncio.run(qt._run_focal_task(
|
107
|
+
>>> answers = asyncio.run(qt._run_focal_task())
|
107
108
|
>>> qt.token_usage()
|
108
109
|
{'cached_tokens': TokenUsage(from_cache=True, prompt_tokens=0, completion_tokens=0), 'new_tokens': TokenUsage(from_cache=False, prompt_tokens=0, completion_tokens=0)}
|
109
110
|
"""
|
@@ -111,15 +112,15 @@ class QuestionTaskCreator(UserList):
|
|
111
112
|
cached_tokens=self.cached_token_usage, new_tokens=self.new_token_usage
|
112
113
|
)
|
113
114
|
|
114
|
-
async def _run_focal_task(self
|
115
|
+
async def _run_focal_task(self) -> Answers:
|
115
116
|
"""Run the focal task i.e., the question that we are interested in answering.
|
116
117
|
|
117
118
|
It is only called after all the dependency tasks are completed.
|
118
119
|
|
119
120
|
>>> qt = QuestionTaskCreator.example()
|
120
|
-
>>> answers = asyncio.run(qt._run_focal_task(
|
121
|
-
>>> answers
|
122
|
-
'
|
121
|
+
>>> answers = asyncio.run(qt._run_focal_task())
|
122
|
+
>>> answers.answer
|
123
|
+
'This is an example answer'
|
123
124
|
"""
|
124
125
|
|
125
126
|
requested_tokens = self.estimated_tokens()
|
@@ -132,19 +133,19 @@ class QuestionTaskCreator(UserList):
|
|
132
133
|
self.waiting = True
|
133
134
|
self.task_status = TaskStatus.WAITING_FOR_REQUEST_CAPACITY
|
134
135
|
|
135
|
-
await self.
|
136
|
+
await self.requests_bucket.get_tokens(1, cheat_bucket_capacity=True)
|
136
137
|
|
137
138
|
self.task_status = TaskStatus.API_CALL_IN_PROGRESS
|
138
139
|
try:
|
139
140
|
results = await self.answer_question_func(
|
140
|
-
question=self.question,
|
141
|
+
question=self.question, task=None # self
|
141
142
|
)
|
142
143
|
self.task_status = TaskStatus.SUCCESS
|
143
144
|
except Exception as e:
|
144
145
|
self.task_status = TaskStatus.FAILED
|
145
146
|
raise e
|
146
147
|
|
147
|
-
if results.
|
148
|
+
if results.cache_used:
|
148
149
|
self.tokens_bucket.add_tokens(requested_tokens)
|
149
150
|
self.requests_bucket.add_tokens(1)
|
150
151
|
self.from_cache = True
|
@@ -155,17 +156,18 @@ class QuestionTaskCreator(UserList):
|
|
155
156
|
self.tokens_bucket.turbo_mode_off()
|
156
157
|
self.requests_bucket.turbo_mode_off()
|
157
158
|
|
158
|
-
|
159
|
+
# breakpoint()
|
160
|
+
# _ = results.pop("cached_response", None)
|
159
161
|
|
160
|
-
tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
|
162
|
+
# tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
|
161
163
|
|
162
164
|
# TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
|
163
|
-
usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
|
164
|
-
prompt_tokens = usage.get("prompt_tokens", 0)
|
165
|
-
completion_tokens = usage.get("completion_tokens", 0)
|
166
|
-
tracker.add_tokens(
|
167
|
-
|
168
|
-
)
|
165
|
+
# usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
|
166
|
+
# prompt_tokens = usage.get("prompt_tokens", 0)
|
167
|
+
# completion_tokens = usage.get("completion_tokens", 0)
|
168
|
+
# tracker.add_tokens(
|
169
|
+
# prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
|
170
|
+
# )
|
169
171
|
|
170
172
|
return results
|
171
173
|
|
@@ -177,8 +179,13 @@ class QuestionTaskCreator(UserList):
|
|
177
179
|
|
178
180
|
m = ModelBuckets.infinity_bucket()
|
179
181
|
|
180
|
-
|
181
|
-
|
182
|
+
from collections import namedtuple
|
183
|
+
|
184
|
+
AnswerDict = namedtuple("AnswerDict", ["answer", "cache_used"])
|
185
|
+
answer = AnswerDict(answer="This is an example answer", cache_used=False)
|
186
|
+
|
187
|
+
async def answer_question_func(question, task):
|
188
|
+
return answer
|
182
189
|
|
183
190
|
return cls(
|
184
191
|
question=QuestionFreeText.example(),
|
@@ -188,7 +195,7 @@ class QuestionTaskCreator(UserList):
|
|
188
195
|
iteration=0,
|
189
196
|
)
|
190
197
|
|
191
|
-
async def _run_task_async(self
|
198
|
+
async def _run_task_async(self) -> None:
|
192
199
|
"""Run the task asynchronously, awaiting the tasks that must be completed before this one can be run.
|
193
200
|
|
194
201
|
>>> qt1 = QuestionTaskCreator.example()
|
@@ -231,8 +238,6 @@ class QuestionTaskCreator(UserList):
|
|
231
238
|
if isinstance(result, Exception):
|
232
239
|
raise result
|
233
240
|
|
234
|
-
return await self._run_focal_task(debug)
|
235
|
-
|
236
241
|
except asyncio.CancelledError:
|
237
242
|
self.task_status = TaskStatus.CANCELLED
|
238
243
|
raise
|
@@ -244,6 +249,8 @@ class QuestionTaskCreator(UserList):
|
|
244
249
|
f"Required tasks failed for {self.question.question_name}"
|
245
250
|
) from e
|
246
251
|
|
252
|
+
return await self._run_focal_task()
|
253
|
+
|
247
254
|
|
248
255
|
if __name__ == "__main__":
|
249
256
|
import doctest
|