edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +3 -9
- edsl/__init__.py +3 -8
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -40
- edsl/agents/AgentList.py +0 -43
- edsl/agents/Invigilator.py +219 -135
- edsl/agents/InvigilatorBase.py +59 -148
- edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
- edsl/agents/__init__.py +0 -1
- edsl/config.py +56 -47
- edsl/coop/coop.py +7 -50
- edsl/data/Cache.py +1 -35
- edsl/data_transfer_models.py +38 -73
- edsl/enums.py +0 -4
- edsl/exceptions/language_models.py +1 -25
- edsl/exceptions/questions.py +5 -62
- edsl/exceptions/results.py +0 -4
- edsl/inference_services/AnthropicService.py +11 -13
- edsl/inference_services/AwsBedrock.py +17 -19
- edsl/inference_services/AzureAI.py +20 -37
- edsl/inference_services/GoogleService.py +12 -16
- edsl/inference_services/GroqService.py +0 -2
- edsl/inference_services/InferenceServiceABC.py +3 -58
- edsl/inference_services/OpenAIService.py +54 -48
- edsl/inference_services/models_available_cache.py +6 -0
- edsl/inference_services/registry.py +0 -6
- edsl/jobs/Answers.py +12 -10
- edsl/jobs/Jobs.py +21 -36
- edsl/jobs/buckets/BucketCollection.py +15 -24
- edsl/jobs/buckets/TokenBucket.py +14 -93
- edsl/jobs/interviews/Interview.py +78 -366
- edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
- edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
- edsl/jobs/interviews/retry_management.py +37 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
- edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
- edsl/jobs/tasks/TaskHistory.py +213 -148
- edsl/language_models/LanguageModel.py +156 -261
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
- edsl/language_models/registry.py +6 -23
- edsl/language_models/repair.py +19 -0
- edsl/prompts/Prompt.py +2 -52
- edsl/questions/AnswerValidatorMixin.py +26 -23
- edsl/questions/QuestionBase.py +249 -329
- edsl/questions/QuestionBudget.py +41 -99
- edsl/questions/QuestionCheckBox.py +35 -227
- edsl/questions/QuestionExtract.py +27 -98
- edsl/questions/QuestionFreeText.py +29 -52
- edsl/questions/QuestionFunctional.py +0 -7
- edsl/questions/QuestionList.py +22 -141
- edsl/questions/QuestionMultipleChoice.py +65 -159
- edsl/questions/QuestionNumerical.py +46 -88
- edsl/questions/QuestionRank.py +24 -182
- edsl/questions/RegisterQuestionsMeta.py +12 -31
- edsl/questions/__init__.py +4 -3
- edsl/questions/derived/QuestionLikertFive.py +5 -10
- edsl/questions/derived/QuestionLinearScale.py +2 -15
- edsl/questions/derived/QuestionTopK.py +1 -10
- edsl/questions/derived/QuestionYesNo.py +3 -24
- edsl/questions/descriptors.py +7 -43
- edsl/questions/question_registry.py +2 -6
- edsl/results/Dataset.py +0 -20
- edsl/results/DatasetExportMixin.py +48 -46
- edsl/results/Result.py +5 -32
- edsl/results/Results.py +46 -135
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/scenarios/FileStore.py +10 -71
- edsl/scenarios/Scenario.py +25 -96
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +39 -361
- edsl/scenarios/ScenarioListExportMixin.py +0 -9
- edsl/scenarios/ScenarioListPdfMixin.py +4 -150
- edsl/study/SnapShot.py +1 -8
- edsl/study/Study.py +0 -32
- edsl/surveys/Rule.py +1 -10
- edsl/surveys/RuleCollection.py +5 -21
- edsl/surveys/Survey.py +310 -636
- edsl/surveys/SurveyExportMixin.py +9 -71
- edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
- edsl/surveys/SurveyQualtricsImport.py +4 -75
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/utilities.py +1 -9
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
- edsl-0.1.33.dev1.dist-info/RECORD +209 -0
- edsl/TemplateLoader.py +0 -24
- edsl/auto/AutoStudy.py +0 -117
- edsl/auto/StageBase.py +0 -230
- edsl/auto/StageGenerateSurvey.py +0 -178
- edsl/auto/StageLabelQuestions.py +0 -125
- edsl/auto/StagePersona.py +0 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
- edsl/auto/StagePersonaDimensionValues.py +0 -74
- edsl/auto/StagePersonaDimensions.py +0 -69
- edsl/auto/StageQuestions.py +0 -73
- edsl/auto/SurveyCreatorPipeline.py +0 -21
- edsl/auto/utilities.py +0 -224
- edsl/coop/PriceFetcher.py +0 -58
- edsl/inference_services/MistralAIService.py +0 -120
- edsl/inference_services/TestService.py +0 -80
- edsl/inference_services/TogetherAIService.py +0 -170
- edsl/jobs/FailedQuestion.py +0 -78
- edsl/jobs/runners/JobsRunnerStatus.py +0 -331
- edsl/language_models/fake_openai_call.py +0 -15
- edsl/language_models/fake_openai_service.py +0 -61
- edsl/language_models/utilities.py +0 -61
- edsl/questions/QuestionBaseGenMixin.py +0 -133
- edsl/questions/QuestionBasePromptsMixin.py +0 -266
- edsl/questions/Quick.py +0 -41
- edsl/questions/ResponseValidatorABC.py +0 -170
- edsl/questions/decorators.py +0 -21
- edsl/questions/prompt_templates/question_budget.jinja +0 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
- edsl/questions/prompt_templates/question_extract.jinja +0 -11
- edsl/questions/prompt_templates/question_free_text.jinja +0 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
- edsl/questions/prompt_templates/question_list.jinja +0 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
- edsl/questions/prompt_templates/question_numerical.jinja +0 -37
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +0 -7
- edsl/questions/templates/budget/question_presentation.jinja +0 -7
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +0 -7
- edsl/questions/templates/extract/question_presentation.jinja +0 -1
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +0 -1
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +0 -4
- edsl/questions/templates/list/question_presentation.jinja +0 -5
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
- edsl/questions/templates/numerical/question_presentation.jinja +0 -7
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +0 -11
- edsl/questions/templates/rank/question_presentation.jinja +0 -15
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
- edsl/questions/templates/top_k/question_presentation.jinja +0 -22
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
- edsl/results/DatasetTree.py +0 -145
- edsl/results/Selector.py +0 -118
- edsl/results/tree_explore.py +0 -115
- edsl/surveys/instructions/ChangeInstruction.py +0 -47
- edsl/surveys/instructions/Instruction.py +0 -34
- edsl/surveys/instructions/InstructionCollection.py +0 -77
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +0 -24
- edsl/templates/error_reporting/exceptions_by_model.html +0 -35
- edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
- edsl/templates/error_reporting/exceptions_by_type.html +0 -17
- edsl/templates/error_reporting/interview_details.html +0 -116
- edsl/templates/error_reporting/interviews.html +0 -10
- edsl/templates/error_reporting/overview.html +0 -5
- edsl/templates/error_reporting/performance_plot.html +0 -2
- edsl/templates/error_reporting/report.css +0 -74
- edsl/templates/error_reporting/report.html +0 -118
- edsl/templates/error_reporting/report.js +0 -25
- edsl-0.1.33.dist-info/RECORD +0 -295
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,333 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import List, DefaultDict
|
3
|
+
import asyncio
|
4
|
+
from typing import Type
|
5
|
+
from collections import defaultdict
|
6
|
+
|
7
|
+
from typing import Literal, List, Type, DefaultDict
|
8
|
+
from collections import UserDict, defaultdict
|
9
|
+
|
10
|
+
from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
|
11
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
12
|
+
from edsl.jobs.tokens.TokenUsage import TokenUsage
|
13
|
+
from edsl.enums import get_token_pricing
|
14
|
+
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
15
|
+
|
16
|
+
# Accumulated token usage per model. Keys are the ``interview.model`` objects
# used in ``JobsRunnerStatusMixin._get_model_queues_info`` (their ``.model``,
# ``.TPM`` and ``.RPM`` attributes are read downstream), not model-name strings.
# NOTE(review): presumably ``LanguageModel`` instances — confirm.
InterviewTokenUsageMapping = DefaultDict["LanguageModel", InterviewTokenUsage]
|
17
|
+
|
18
|
+
from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
|
19
|
+
from edsl.jobs.interviews.InterviewStatisticsCollection import (
|
20
|
+
InterviewStatisticsCollection,
|
21
|
+
)
|
22
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
23
|
+
|
24
|
+
|
25
|
+
# return {"cache_status": token_usage_type, "details": details, "cost": f"${token_usage.cost(prices):.5f}"}
|
26
|
+
|
27
|
+
from dataclasses import dataclass, asdict
|
28
|
+
|
29
|
+
from rich.text import Text
|
30
|
+
from rich.box import SIMPLE
|
31
|
+
from rich.table import Table
|
32
|
+
|
33
|
+
|
34
|
+
@dataclass
class ModelTokenUsageStats:
    """Token counts and formatted cost for one cache category of a model."""

    # "new_token_usage" or "cached_token_usage".
    token_usage_type: str
    # One dict per token type, e.g. {"type": "prompt_tokens", "tokens": 0}.
    details: List[dict]
    # Dollar cost pre-formatted to five decimals, e.g. "$0.00000".
    cost: str


@dataclass
class ModelInfo:
    """Per-model snapshot of rate limits, queue depth and token usage."""

    model_name: str
    # Tokens-per-minute limit divided by 1000.
    TPM_limit_k: float
    # Requests-per-minute limit divided by 1000.
    RPM_limit_k: float
    # Number of question tasks waiting for capacity on this model.
    num_tasks_waiting: int
    # One entry per cache category (new tokens, then cached tokens).
    token_usage_info: List[ModelTokenUsageStats]
|
48
|
+
|
49
|
+
|
50
|
+
class Stats:
    """Helpers that build individual job statistics.

    NOTE(review): not referenced elsewhere in this module; the equivalent
    logic used at runtime is ``JobsRunnerStatusMixin._compute_statistic``.
    """

    def elapsed_time(self, elapsed_time: float = 0.0) -> "InterviewStatistic":
        """Return the ``elapsed_time`` statistic (seconds, one decimal digit).

        :param elapsed_time: seconds elapsed since the job started; the caller
            supplies it because this class holds no timing state of its own.
        :returns: the constructed ``InterviewStatistic``.
        """
        return InterviewStatistic(
            "elapsed_time", value=elapsed_time, digits=1, units="sec."
        )
|
53
|
+
|
54
|
+
|
55
|
+
class JobsRunnerStatusMixin:
    """Build and render progress summaries for a running job.

    Every helper is a ``@staticmethod``. Only :meth:`status_table` reads
    instance state (``self.total_interviews``), so a host class using this
    mixin must provide that attribute.
    """

    @staticmethod
    def _compute_statistic(
        stat_name: str,
        completed_tasks: List[asyncio.Task],
        elapsed_time: float,
        interviews: List["Interview"],
    ) -> "InterviewStatistic":
        """Compute one named job-level statistic.

        Declared ``@staticmethod`` so it behaves identically whether called on
        the class (as :meth:`_job_level_info` does) or on an instance.

        :param stat_name: name of the statistic; one of the keys of
            ``stat_definitions`` below.
        :param completed_tasks: tasks that have finished so far.
        :param elapsed_time: seconds elapsed since the job started.
        :param interviews: all interviews requested for the job.
        :raises ValueError: if ``stat_name`` is not a known statistic.
        """
        # Zero-argument lambdas so that only the requested statistic is built.
        # Ratios fall back to the string "NA" when their denominator is zero.
        stat_definitions = {
            "elapsed_time": lambda: InterviewStatistic(
                "elapsed_time", value=elapsed_time, digits=1, units="sec."
            ),
            "total_interviews_requested": lambda: InterviewStatistic(
                "total_interviews_requested", value=len(interviews), units=""
            ),
            "completed_interviews": lambda: InterviewStatistic(
                "completed_interviews", value=len(completed_tasks), units=""
            ),
            "percent_complete": lambda: InterviewStatistic(
                "percent_complete",
                value=(
                    len(completed_tasks) / len(interviews) * 100
                    if len(interviews) > 0
                    else "NA"
                ),
                digits=0,
                units="%",
            ),
            "average_time_per_interview": lambda: InterviewStatistic(
                "average_time_per_interview",
                value=elapsed_time / len(completed_tasks) if completed_tasks else "NA",
                digits=1,
                units="sec.",
            ),
            "task_remaining": lambda: InterviewStatistic(
                "task_remaining", value=len(interviews) - len(completed_tasks), units=""
            ),
            "estimated_time_remaining": lambda: InterviewStatistic(
                "estimated_time_remaining",
                value=(
                    (len(interviews) - len(completed_tasks))
                    * (elapsed_time / len(completed_tasks))
                    if len(completed_tasks) > 0
                    else "NA"
                ),
                digits=1,
                units="sec.",
            ),
        }
        if stat_name not in stat_definitions:
            raise ValueError(
                f"Invalid stat_name: {stat_name}. The valid stat_names are: {list(stat_definitions.keys())}"
            )
        return stat_definitions[stat_name]()

    @staticmethod
    def _job_level_info(
        completed_tasks: List[asyncio.Task],
        elapsed_time: float,
        interviews: List["Interview"],
    ) -> "InterviewStatisticsCollection":
        """Collect every default job-level statistic into one collection.

        :param completed_tasks: tasks that have finished so far.
        :param elapsed_time: seconds elapsed since the job started.
        :param interviews: all interviews requested for the job.
        """
        interview_statistics = InterviewStatisticsCollection()

        default_statistics = [
            "elapsed_time",
            "total_interviews_requested",
            "completed_interviews",
            "percent_complete",
            "average_time_per_interview",
            "task_remaining",
            "estimated_time_remaining",
        ]
        for stat_name in default_statistics:
            interview_statistics.add_stat(
                JobsRunnerStatusMixin._compute_statistic(
                    stat_name, completed_tasks, elapsed_time, interviews
                )
            )

        return interview_statistics

    @staticmethod
    def _get_model_queues_info(interviews):
        """Yield a :class:`ModelInfo` for each distinct model used by ``interviews``.

        Token usage and the number of waiting tasks are summed per model.
        Models are yielded in the order they are first seen.
        """
        models_to_tokens = defaultdict(InterviewTokenUsage)
        waiting_dict = defaultdict(int)

        for interview in interviews:
            models_to_tokens[interview.model] += interview.token_usage
            waiting_dict[interview.model] += interview.interview_status.waiting

        for model, num_waiting in waiting_dict.items():
            yield JobsRunnerStatusMixin._get_model_info(
                model, num_waiting, models_to_tokens
            )

    @staticmethod
    def generate_status_summary(
        completed_tasks: List[asyncio.Task],
        elapsed_time: float,
        interviews: List["Interview"],
        include_model_queues=False,
    ) -> "InterviewStatisticsCollection":
        """Generate a summary of the status of the job runner.

        :param completed_tasks: list of completed tasks
        :param elapsed_time: time elapsed since the start of the job
        :param interviews: list of interviews to be conducted
        :param include_model_queues: when true, attach per-model queue info as
            ``model_queues``; otherwise ``model_queues`` is set to ``None``

        >>> from edsl.jobs.interviews.Interview import Interview
        >>> interviews = [Interview.example()]
        >>> completed_tasks = []
        >>> elapsed_time = 0
        >>> JobsRunnerStatusMixin().generate_status_summary(completed_tasks, elapsed_time, interviews)
        {'Elapsed time': '0.0 sec.', 'Total interviews requested': '1 ', 'Completed interviews': '0 ', 'Percent complete': '0 %', 'Average time per interview': 'NA', 'Task remaining': '1 ', 'Estimated time remaining': 'NA'}
        """
        interview_status_summary: "InterviewStatisticsCollection" = (
            JobsRunnerStatusMixin._job_level_info(
                completed_tasks=completed_tasks,
                elapsed_time=elapsed_time,
                interviews=interviews,
            )
        )
        if include_model_queues:
            interview_status_summary.model_queues = list(
                JobsRunnerStatusMixin._get_model_queues_info(interviews)
            )
        else:
            interview_status_summary.model_queues = None

        return interview_status_summary

    @staticmethod
    def _get_model_info(
        model: "LanguageModel",
        num_waiting: int,
        models_to_tokens: "InterviewTokenUsageMapping",
    ) -> "ModelInfo":
        """Get the status of a model.

        :param model: the model object (its ``.model`` name, ``.TPM`` and
            ``.RPM`` attributes are read); also the key into ``models_to_tokens``
        :param num_waiting: the number of tasks waiting for capacity
        :param models_to_tokens: a mapping of models to token usage

        >>> from edsl.jobs.interviews.Interview import Interview
        >>> interviews = [Interview.example()]
        >>> models_to_tokens = defaultdict(InterviewTokenUsage)
        >>> model = interviews[0].model
        >>> num_waiting = 0
        >>> JobsRunnerStatusMixin()._get_model_info(model, num_waiting, models_to_tokens)
        ModelInfo(model_name='gpt-4-1106-preview', TPM_limit_k=480.0, RPM_limit_k=4.0, num_tasks_waiting=0, token_usage_info=[ModelTokenUsageStats(token_usage_type='new_token_usage', details=[{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], cost='$0.00000'), ModelTokenUsageStats(token_usage_type='cached_token_usage', details=[{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], cost='$0.00000')])
        """
        ## TODO: This should probably be a coop method
        prices = get_token_pricing(model.model)

        token_usage_info = [
            JobsRunnerStatusMixin._get_token_usage_info(
                token_usage_type, models_to_tokens, model, prices
            )
            for token_usage_type in ["new_token_usage", "cached_token_usage"]
        ]

        return ModelInfo(
            model_name=model.model,
            TPM_limit_k=model.TPM / 1000,
            RPM_limit_k=model.RPM / 1000,
            num_tasks_waiting=num_waiting,
            token_usage_info=token_usage_info,
        )

    @staticmethod
    def _get_token_usage_info(
        token_usage_type: Literal["new_token_usage", "cached_token_usage"],
        models_to_tokens: "InterviewTokenUsageMapping",
        model: "LanguageModel",
        prices: "TokenPricing",
    ) -> "ModelTokenUsageStats":
        """Get the token usage info for a model.

        :param token_usage_type: which attribute of the model's
            ``InterviewTokenUsage`` to report
        :param models_to_tokens: a mapping of models to token usage
        :param model: the key into ``models_to_tokens``
        :param prices: pricing passed to ``TokenUsage.cost``

        >>> from edsl.jobs.interviews.Interview import Interview
        >>> interviews = [Interview.example()]
        >>> models_to_tokens = defaultdict(InterviewTokenUsage)
        >>> model = interviews[0].model
        >>> prices = get_token_pricing(model.model)
        >>> cache_status = "new_token_usage"
        >>> JobsRunnerStatusMixin()._get_token_usage_info(cache_status, models_to_tokens, model, prices)
        ModelTokenUsageStats(token_usage_type='new_token_usage', details=[{'type': 'prompt_tokens', 'tokens': 0}, {'type': 'completion_tokens', 'tokens': 0}], cost='$0.00000')

        """
        all_token_usage: "InterviewTokenUsage" = models_to_tokens[model]
        token_usage: "TokenUsage" = getattr(all_token_usage, token_usage_type)

        details = [
            {"type": token_type, "tokens": getattr(token_usage, token_type)}
            for token_type in ["prompt_tokens", "completion_tokens"]
        ]

        return ModelTokenUsageStats(
            token_usage_type=token_usage_type,
            details=details,
            cost=f"${token_usage.cost(prices):.5f}",
        )

    @staticmethod
    def _add_statistics_to_table(table, status_summary):
        """Add the two columns and one row per job-level statistic to ``table``.

        The ``model_queues`` entry is skipped; it is rendered separately by
        :meth:`display_status_table`.
        """
        table.add_column("Statistic", style="dim", no_wrap=True, width=50)
        table.add_column("Value", width=10)

        for key, value in status_summary.items():
            if key != "model_queues":
                table.add_row(key, value)

    @staticmethod
    def display_status_table(status_summary: "InterviewStatisticsCollection") -> "Table":
        """Render ``status_summary`` as a ``rich`` table.

        Job-level statistics come first. When ``status_summary.model_queues`` is
        not ``None``, a per-model section follows with limits, waiting-task count
        and prompt/completion token counts for new and cached usage.
        """
        table = Table(
            title="Job Status",
            show_header=True,
            header_style="bold magenta",
            box=SIMPLE,
        )

        ### Job-level statistics
        JobsRunnerStatusMixin._add_statistics_to_table(table, status_summary)

        ## Model-level statistics
        spacing = " "

        if status_summary.model_queues is not None:
            table.add_row(Text("Model Queues", style="bold red"), "")
            for model_info in status_summary.model_queues:
                model_name = model_info.model_name
                tpm = f"TPM (k)={model_info.TPM_limit_k}"
                rpm = f"RPM (k)= {model_info.RPM_limit_k}"
                pretty_model_name = model_name + ";" + tpm + ";" + rpm
                table.add_row(Text(pretty_model_name, style="blue"), "")
                table.add_row(
                    "Number question tasks waiting for capacity",
                    str(model_info.num_tasks_waiting),
                )
                # Token usage per cache category (cost is computed upstream
                # but not displayed here).
                for token_usage_info in model_info.token_usage_info:
                    token_usage_type = token_usage_info.token_usage_type
                    table.add_row(
                        Text(
                            spacing + token_usage_type.replace("_", " "), style="bold"
                        ),
                        "",
                    )
                    for detail in token_usage_info.details:
                        token_type = detail["type"]
                        tokens = detail["tokens"]
                        table.add_row(spacing + f"{token_type}", f"{tokens:,}")

        return table

    def status_table(self, completed_tasks: List[asyncio.Task], elapsed_time: float):
        """Summarize ``self.total_interviews`` and return the rendered table.

        :param completed_tasks: tasks that have finished so far.
        :param elapsed_time: seconds elapsed since the job started.
        """
        summary_data = JobsRunnerStatusMixin.generate_status_summary(
            completed_tasks=completed_tasks,
            elapsed_time=elapsed_time,
            interviews=self.total_interviews,
        )
        return self.display_status_table(summary_data)
|
328
|
+
|
329
|
+
|
330
|
+
if __name__ == "__main__":
    # Run this module's doctests; ELLIPSIS lets expected output elide with "...".
    from doctest import ELLIPSIS, testmod

    testmod(optionflags=ELLIPSIS)
|
@@ -55,7 +55,6 @@ class QuestionTaskCreator(UserList):
|
|
55
55
|
|
56
56
|
"""
|
57
57
|
super().__init__([])
|
58
|
-
# answer_question_func is the 'interview.answer_question_and_record_task" method
|
59
58
|
self.answer_question_func = answer_question_func
|
60
59
|
self.question = question
|
61
60
|
self.iteration = iteration
|
@@ -88,10 +87,10 @@ class QuestionTaskCreator(UserList):
|
|
88
87
|
"""
|
89
88
|
self.append(task)
|
90
89
|
|
91
|
-
def generate_task(self) -> asyncio.Task:
|
90
|
+
def generate_task(self, debug: bool) -> asyncio.Task:
|
92
91
|
"""Create a task that depends on the passed-in dependencies."""
|
93
92
|
task = asyncio.create_task(
|
94
|
-
self._run_task_async(), name=self.question.question_name
|
93
|
+
self._run_task_async(debug), name=self.question.question_name
|
95
94
|
)
|
96
95
|
task.depends_on = [t.get_name() for t in self]
|
97
96
|
return task
|
@@ -104,7 +103,7 @@ class QuestionTaskCreator(UserList):
|
|
104
103
|
"""Returns the token usage for the task.
|
105
104
|
|
106
105
|
>>> qt = QuestionTaskCreator.example()
|
107
|
-
>>> answers = asyncio.run(qt._run_focal_task())
|
106
|
+
>>> answers = asyncio.run(qt._run_focal_task(debug=False))
|
108
107
|
>>> qt.token_usage()
|
109
108
|
{'cached_tokens': TokenUsage(from_cache=True, prompt_tokens=0, completion_tokens=0), 'new_tokens': TokenUsage(from_cache=False, prompt_tokens=0, completion_tokens=0)}
|
110
109
|
"""
|
@@ -112,15 +111,15 @@ class QuestionTaskCreator(UserList):
|
|
112
111
|
cached_tokens=self.cached_token_usage, new_tokens=self.new_token_usage
|
113
112
|
)
|
114
113
|
|
115
|
-
async def _run_focal_task(self) -> Answers:
|
114
|
+
async def _run_focal_task(self, debug: bool) -> Answers:
|
116
115
|
"""Run the focal task i.e., the question that we are interested in answering.
|
117
116
|
|
118
117
|
It is only called after all the dependency tasks are completed.
|
119
118
|
|
120
119
|
>>> qt = QuestionTaskCreator.example()
|
121
|
-
>>> answers = asyncio.run(qt._run_focal_task())
|
122
|
-
>>> answers
|
123
|
-
'
|
120
|
+
>>> answers = asyncio.run(qt._run_focal_task(debug=False))
|
121
|
+
>>> answers["answer"]
|
122
|
+
'Yo!'
|
124
123
|
"""
|
125
124
|
|
126
125
|
requested_tokens = self.estimated_tokens()
|
@@ -133,19 +132,19 @@ class QuestionTaskCreator(UserList):
|
|
133
132
|
self.waiting = True
|
134
133
|
self.task_status = TaskStatus.WAITING_FOR_REQUEST_CAPACITY
|
135
134
|
|
136
|
-
await self.
|
135
|
+
await self.tokens_bucket.get_tokens(1)
|
137
136
|
|
138
137
|
self.task_status = TaskStatus.API_CALL_IN_PROGRESS
|
139
138
|
try:
|
140
139
|
results = await self.answer_question_func(
|
141
|
-
question=self.question, task=None # self
|
140
|
+
question=self.question, debug=debug, task=None # self
|
142
141
|
)
|
143
142
|
self.task_status = TaskStatus.SUCCESS
|
144
143
|
except Exception as e:
|
145
144
|
self.task_status = TaskStatus.FAILED
|
146
145
|
raise e
|
147
146
|
|
148
|
-
if results.cache_used:
|
147
|
+
if results.get("cache_used", False):
|
149
148
|
self.tokens_bucket.add_tokens(requested_tokens)
|
150
149
|
self.requests_bucket.add_tokens(1)
|
151
150
|
self.from_cache = True
|
@@ -156,18 +155,17 @@ class QuestionTaskCreator(UserList):
|
|
156
155
|
self.tokens_bucket.turbo_mode_off()
|
157
156
|
self.requests_bucket.turbo_mode_off()
|
158
157
|
|
159
|
-
|
160
|
-
# _ = results.pop("cached_response", None)
|
158
|
+
_ = results.pop("cached_response", None)
|
161
159
|
|
162
|
-
|
160
|
+
tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
|
163
161
|
|
164
162
|
# TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
163
|
+
usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
|
164
|
+
prompt_tokens = usage.get("prompt_tokens", 0)
|
165
|
+
completion_tokens = usage.get("completion_tokens", 0)
|
166
|
+
tracker.add_tokens(
|
167
|
+
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
|
168
|
+
)
|
171
169
|
|
172
170
|
return results
|
173
171
|
|
@@ -179,13 +177,8 @@ class QuestionTaskCreator(UserList):
|
|
179
177
|
|
180
178
|
m = ModelBuckets.infinity_bucket()
|
181
179
|
|
182
|
-
|
183
|
-
|
184
|
-
AnswerDict = namedtuple("AnswerDict", ["answer", "cache_used"])
|
185
|
-
answer = AnswerDict(answer="This is an example answer", cache_used=False)
|
186
|
-
|
187
|
-
async def answer_question_func(question, task):
|
188
|
-
return answer
|
180
|
+
async def answer_question_func(question, debug, task):
|
181
|
+
return {"answer": "Yo!"}
|
189
182
|
|
190
183
|
return cls(
|
191
184
|
question=QuestionFreeText.example(),
|
@@ -195,7 +188,7 @@ class QuestionTaskCreator(UserList):
|
|
195
188
|
iteration=0,
|
196
189
|
)
|
197
190
|
|
198
|
-
async def _run_task_async(self) -> None:
|
191
|
+
async def _run_task_async(self, debug) -> None:
|
199
192
|
"""Run the task asynchronously, awaiting the tasks that must be completed before this one can be run.
|
200
193
|
|
201
194
|
>>> qt1 = QuestionTaskCreator.example()
|
@@ -238,6 +231,8 @@ class QuestionTaskCreator(UserList):
|
|
238
231
|
if isinstance(result, Exception):
|
239
232
|
raise result
|
240
233
|
|
234
|
+
return await self._run_focal_task(debug)
|
235
|
+
|
241
236
|
except asyncio.CancelledError:
|
242
237
|
self.task_status = TaskStatus.CANCELLED
|
243
238
|
raise
|
@@ -249,8 +244,6 @@ class QuestionTaskCreator(UserList):
|
|
249
244
|
f"Required tasks failed for {self.question.question_name}"
|
250
245
|
) from e
|
251
246
|
|
252
|
-
return await self._run_focal_task()
|
253
|
-
|
254
247
|
|
255
248
|
if __name__ == "__main__":
|
256
249
|
import doctest
|