edsl 0.1.38.dev3__py3-none-any.whl → 0.1.38.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +332 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +49 -49
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +867 -858
- edsl/agents/AgentList.py +413 -362
- edsl/agents/Invigilator.py +233 -222
- edsl/agents/InvigilatorBase.py +265 -284
- edsl/agents/PromptConstructor.py +354 -353
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +99 -99
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +279 -279
- edsl/config.py +157 -149
- edsl/conversation/Conversation.py +290 -290
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/chips.py +95 -95
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +1028 -961
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +555 -530
- edsl/data/CacheEntry.py +233 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +78 -97
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +175 -173
- edsl/exceptions/BaseException.py +21 -21
- edsl/exceptions/__init__.py +54 -54
- edsl/exceptions/agents.py +42 -42
- edsl/exceptions/cache.py +5 -5
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +29 -29
- edsl/exceptions/scenarios.py +22 -22
- edsl/exceptions/surveys.py +37 -37
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +120 -120
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +148 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +97 -97
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/PerplexityService.py +163 -0
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +41 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +898 -1358
- edsl/jobs/JobsChecks.py +147 -0
- edsl/jobs/JobsPrompts.py +268 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +251 -251
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +466 -361
- edsl/jobs/runners/JobsRunnerStatus.py +330 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +450 -451
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/KeyLookup.py +30 -30
- edsl/language_models/LanguageModel.py +668 -708
- edsl/language_models/ModelList.py +155 -109
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +3 -3
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +190 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +258 -258
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +362 -357
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +664 -660
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +217 -217
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +182 -183
- edsl/questions/QuestionFreeText.py +114 -114
- edsl/questions/QuestionFunctional.py +166 -166
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +93 -93
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +413 -413
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +177 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +424 -293
- edsl/results/DatasetExportMixin.py +731 -717
- edsl/results/DatasetTree.py +275 -145
- edsl/results/Result.py +465 -456
- edsl/results/Results.py +1165 -1071
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +135 -135
- edsl/results/TableDisplay.py +198 -0
- edsl/results/__init__.py +2 -2
- edsl/results/table_display.css +78 -0
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +632 -458
- edsl/scenarios/Scenario.py +601 -544
- edsl/scenarios/ScenarioHtmlMixin.py +64 -64
- edsl/scenarios/ScenarioJoin.py +127 -0
- edsl/scenarios/ScenarioList.py +1287 -1112
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +326 -326
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1801 -1787
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +179 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +49 -49
- edsl/surveys/instructions/Instruction.py +65 -53
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +19 -10
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/naming_utilities.py +263 -263
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +424 -409
- {edsl-0.1.38.dev3.dist-info → edsl-0.1.38.dev4.dist-info}/LICENSE +21 -21
- {edsl-0.1.38.dev3.dist-info → edsl-0.1.38.dev4.dist-info}/METADATA +2 -1
- edsl-0.1.38.dev4.dist-info/RECORD +277 -0
- edsl-0.1.38.dev3.dist-info/RECORD +0 -269
- {edsl-0.1.38.dev3.dist-info → edsl-0.1.38.dev4.dist-info}/WHEEL +0 -0
@@ -1,332 +1,330 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import os
|
4
|
-
import time
|
5
|
-
import requests
|
6
|
-
import warnings
|
7
|
-
from abc import ABC, abstractmethod
|
8
|
-
from dataclasses import dataclass
|
9
|
-
|
10
|
-
from typing import Any, List, DefaultDict, Optional, Dict
|
11
|
-
from collections import defaultdict
|
12
|
-
from uuid import UUID
|
13
|
-
|
14
|
-
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
15
|
-
|
16
|
-
InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
|
17
|
-
|
18
|
-
from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
|
19
|
-
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
20
|
-
|
21
|
-
|
22
|
-
@dataclass()
class ModelInfo:
    """Plain record of per-model status fields.

    Groups a model name, two limit values, a waiting-task count and a
    token-usage mapping. Fields are positional in the order declared.
    """

    # Name of the language model this record describes.
    model_name: str
    # Limit values; the "_k" suffix suggests thousands — TODO confirm units.
    TPM_limit_k: float
    RPM_limit_k: float
    # Number of tasks waiting (semantics defined by whoever fills this in).
    num_tasks_waiting: int
    # Token-usage details; the dict's schema is not visible here.
    token_usage_info: dict
|
29
|
-
|
30
|
-
|
31
|
-
@dataclass()
class ModelTokenUsageStats:
    """Plain record of token-usage statistics for one usage category.

    Fields are positional in the order declared.
    """

    # Label identifying which kind of token usage this record covers.
    token_usage_type: str
    # One dict per detail entry; schema not visible here — TODO confirm.
    details: List[dict]
    # Cost kept as a string (presumably preformatted for display).
    cost: str
|
36
|
-
|
37
|
-
|
38
|
-
class JobsRunnerStatusBase(ABC):
|
39
|
-
def __init__(
|
40
|
-
self,
|
41
|
-
jobs_runner: "JobsRunnerAsyncio",
|
42
|
-
n: int,
|
43
|
-
refresh_rate: float = 1,
|
44
|
-
endpoint_url: Optional[str] = "http://localhost:8000",
|
45
|
-
job_uuid: Optional[UUID] = None,
|
46
|
-
api_key: str = None,
|
47
|
-
):
|
48
|
-
self.jobs_runner = jobs_runner
|
49
|
-
|
50
|
-
# The uuid of the job on Coop
|
51
|
-
self.job_uuid = job_uuid
|
52
|
-
|
53
|
-
self.base_url = f"{endpoint_url}"
|
54
|
-
|
55
|
-
self.start_time = time.time()
|
56
|
-
self.completed_interviews = []
|
57
|
-
self.refresh_rate = refresh_rate
|
58
|
-
self.statistics = [
|
59
|
-
"elapsed_time",
|
60
|
-
"total_interviews_requested",
|
61
|
-
"completed_interviews",
|
62
|
-
# "percent_complete",
|
63
|
-
"average_time_per_interview",
|
64
|
-
# "task_remaining",
|
65
|
-
"estimated_time_remaining",
|
66
|
-
"exceptions",
|
67
|
-
"unfixed_exceptions",
|
68
|
-
"throughput",
|
69
|
-
]
|
70
|
-
self.num_total_interviews = n * len(self.jobs_runner.interviews)
|
71
|
-
|
72
|
-
self.distinct_models = list(
|
73
|
-
set(i.model.model for i in self.jobs_runner.interviews)
|
74
|
-
)
|
75
|
-
|
76
|
-
self.completed_interview_by_model = defaultdict(list)
|
77
|
-
|
78
|
-
self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
|
79
|
-
|
80
|
-
@abstractmethod
|
81
|
-
def has_ep_api_key(self):
|
82
|
-
"""
|
83
|
-
Checks if the user has an Expected Parrot API key.
|
84
|
-
"""
|
85
|
-
pass
|
86
|
-
|
87
|
-
def get_status_dict(self) -> Dict[str, Any]:
|
88
|
-
"""
|
89
|
-
Converts current status into a JSON-serializable dictionary.
|
90
|
-
"""
|
91
|
-
# Get all statistics
|
92
|
-
stats = {}
|
93
|
-
for stat_name in self.statistics:
|
94
|
-
stat = self._compute_statistic(stat_name)
|
95
|
-
name, value = list(stat.items())[0]
|
96
|
-
stats[name] = value
|
97
|
-
|
98
|
-
# Calculate overall progress
|
99
|
-
total_interviews = len(self.jobs_runner.total_interviews)
|
100
|
-
completed = len(self.completed_interviews)
|
101
|
-
|
102
|
-
# Get model-specific progress
|
103
|
-
model_progress = {}
|
104
|
-
for model in self.distinct_models:
|
105
|
-
completed_for_model = len(self.completed_interview_by_model[model])
|
106
|
-
target_for_model = int(
|
107
|
-
self.num_total_interviews / len(self.distinct_models)
|
108
|
-
)
|
109
|
-
model_progress[model] = {
|
110
|
-
"completed": completed_for_model,
|
111
|
-
"total": target_for_model,
|
112
|
-
"percent": (
|
113
|
-
(completed_for_model / target_for_model * 100)
|
114
|
-
if target_for_model > 0
|
115
|
-
else 0
|
116
|
-
),
|
117
|
-
}
|
118
|
-
|
119
|
-
status_dict = {
|
120
|
-
"overall_progress": {
|
121
|
-
"completed": completed,
|
122
|
-
"total": total_interviews,
|
123
|
-
"percent": (
|
124
|
-
(completed / total_interviews * 100) if total_interviews > 0 else 0
|
125
|
-
),
|
126
|
-
},
|
127
|
-
"language_model_progress": model_progress,
|
128
|
-
"statistics": stats,
|
129
|
-
"status": "completed" if completed >= total_interviews else "running",
|
130
|
-
}
|
131
|
-
|
132
|
-
model_queues = {}
|
133
|
-
for model, bucket in self.jobs_runner.bucket_collection.items():
|
134
|
-
model_name = model.model
|
135
|
-
model_queues[model_name] = {
|
136
|
-
"language_model_name": model_name,
|
137
|
-
"requests_bucket": {
|
138
|
-
"completed": bucket.requests_bucket.num_released,
|
139
|
-
"requested": bucket.requests_bucket.num_requests,
|
140
|
-
"tokens_returned": bucket.requests_bucket.tokens_returned,
|
141
|
-
"target_rate": round(bucket.requests_bucket.target_rate, 1),
|
142
|
-
"current_rate": round(bucket.requests_bucket.get_throughput(), 1),
|
143
|
-
},
|
144
|
-
"tokens_bucket": {
|
145
|
-
"completed": bucket.tokens_bucket.num_released,
|
146
|
-
"requested": bucket.tokens_bucket.num_requests,
|
147
|
-
"tokens_returned": bucket.tokens_bucket.tokens_returned,
|
148
|
-
"target_rate": round(bucket.tokens_bucket.target_rate, 1),
|
149
|
-
"current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
|
150
|
-
},
|
151
|
-
}
|
152
|
-
status_dict["language_model_queues"] = model_queues
|
153
|
-
return status_dict
|
154
|
-
|
155
|
-
@abstractmethod
|
156
|
-
def setup(self):
|
157
|
-
"""
|
158
|
-
Conducts any setup that needs to happen prior to sending status updates.
|
159
|
-
|
160
|
-
Ex. For a local job, creates a job in the Coop database.
|
161
|
-
"""
|
162
|
-
pass
|
163
|
-
|
164
|
-
@abstractmethod
|
165
|
-
def send_status_update(self):
|
166
|
-
"""
|
167
|
-
Updates the current status of the job.
|
168
|
-
"""
|
169
|
-
pass
|
170
|
-
|
171
|
-
def add_completed_interview(self, result):
|
172
|
-
self.completed_interviews.append(result.interview_hash)
|
173
|
-
|
174
|
-
relevant_model = result.model.model
|
175
|
-
self.completed_interview_by_model[relevant_model].append(result.interview_hash)
|
176
|
-
|
177
|
-
def _compute_statistic(self, stat_name: str):
|
178
|
-
completed_tasks = self.completed_interviews
|
179
|
-
elapsed_time = time.time() - self.start_time
|
180
|
-
interviews = self.jobs_runner.total_interviews
|
181
|
-
|
182
|
-
stat_definitions = {
|
183
|
-
"elapsed_time": lambda: InterviewStatistic(
|
184
|
-
"elapsed_time", value=elapsed_time, digits=1, units="sec."
|
185
|
-
),
|
186
|
-
"total_interviews_requested": lambda: InterviewStatistic(
|
187
|
-
"total_interviews_requested", value=len(interviews), units=""
|
188
|
-
),
|
189
|
-
"completed_interviews": lambda: InterviewStatistic(
|
190
|
-
"completed_interviews", value=len(completed_tasks), units=""
|
191
|
-
),
|
192
|
-
"percent_complete": lambda: InterviewStatistic(
|
193
|
-
"percent_complete",
|
194
|
-
value=(
|
195
|
-
len(completed_tasks) / len(interviews) * 100
|
196
|
-
if len(interviews) > 0
|
197
|
-
else 0
|
198
|
-
),
|
199
|
-
digits=1,
|
200
|
-
units="%",
|
201
|
-
),
|
202
|
-
"average_time_per_interview": lambda: InterviewStatistic(
|
203
|
-
"average_time_per_interview",
|
204
|
-
value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
|
205
|
-
digits=2,
|
206
|
-
units="sec.",
|
207
|
-
),
|
208
|
-
"task_remaining": lambda: InterviewStatistic(
|
209
|
-
"task_remaining", value=len(interviews) - len(completed_tasks), units=""
|
210
|
-
),
|
211
|
-
"estimated_time_remaining": lambda: InterviewStatistic(
|
212
|
-
"estimated_time_remaining",
|
213
|
-
value=(
|
214
|
-
(len(interviews) - len(completed_tasks))
|
215
|
-
* (elapsed_time / len(completed_tasks))
|
216
|
-
if len(completed_tasks) > 0
|
217
|
-
else 0
|
218
|
-
),
|
219
|
-
digits=1,
|
220
|
-
units="sec.",
|
221
|
-
),
|
222
|
-
"exceptions": lambda: InterviewStatistic(
|
223
|
-
"exceptions",
|
224
|
-
value=sum(len(i.exceptions) for i in interviews),
|
225
|
-
units="",
|
226
|
-
),
|
227
|
-
"unfixed_exceptions": lambda: InterviewStatistic(
|
228
|
-
"unfixed_exceptions",
|
229
|
-
value=sum(i.exceptions.num_unfixed() for i in interviews),
|
230
|
-
units="",
|
231
|
-
),
|
232
|
-
"throughput": lambda: InterviewStatistic(
|
233
|
-
"throughput",
|
234
|
-
value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
|
235
|
-
digits=2,
|
236
|
-
units="interviews/sec.",
|
237
|
-
),
|
238
|
-
}
|
239
|
-
return stat_definitions[stat_name]()
|
240
|
-
|
241
|
-
def update_progress(self, stop_event):
|
242
|
-
|
243
|
-
|
244
|
-
self.
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
"""
|
266
|
-
|
267
|
-
"""
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
headers["Authorization"] = f"Bearer
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
"""
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
headers["Authorization"] = f"Bearer
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
"""
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
return
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
doctest.testmod(optionflags=doctest.ELLIPSIS)
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import os
|
4
|
+
import time
|
5
|
+
import requests
|
6
|
+
import warnings
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from dataclasses import dataclass
|
9
|
+
|
10
|
+
from typing import Any, List, DefaultDict, Optional, Dict
|
11
|
+
from collections import defaultdict
|
12
|
+
from uuid import UUID
|
13
|
+
|
14
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
15
|
+
|
16
|
+
InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
|
17
|
+
|
18
|
+
from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
|
19
|
+
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass()
class ModelInfo:
    """Plain record of per-model status fields.

    Groups a model name, two limit values, a waiting-task count and a
    token-usage mapping. Fields are positional in the order declared.
    """

    # Name of the language model this record describes.
    model_name: str
    # Limit values; the "_k" suffix suggests thousands — TODO confirm units.
    TPM_limit_k: float
    RPM_limit_k: float
    # Number of tasks waiting (semantics defined by whoever fills this in).
    num_tasks_waiting: int
    # Token-usage details; the dict's schema is not visible here.
    token_usage_info: dict
|
29
|
+
|
30
|
+
|
31
|
+
@dataclass()
class ModelTokenUsageStats:
    """Plain record of token-usage statistics for one usage category.

    Fields are positional in the order declared.
    """

    # Label identifying which kind of token usage this record covers.
    token_usage_type: str
    # One dict per detail entry; schema not visible here — TODO confirm.
    details: List[dict]
    # Cost kept as a string (presumably preformatted for display).
    cost: str
|
36
|
+
|
37
|
+
|
38
|
+
class JobsRunnerStatusBase(ABC):
    """Shared progress-tracking logic for a running job.

    Records completed interviews, derives summary statistics and builds a
    JSON-serializable status payload. Subclasses decide how the payload is
    delivered (``setup`` / ``send_status_update``) and how the presence of
    an Expected Parrot API key is detected (``has_ep_api_key``).
    """

    def __init__(
        self,
        jobs_runner: "JobsRunnerAsyncio",
        n: int,
        refresh_rate: float = 1,
        endpoint_url: Optional[str] = "http://localhost:8000",
        job_uuid: Optional[UUID] = None,
        api_key: Optional[str] = None,
    ):
        """
        :param jobs_runner: runner whose ``interviews``, ``total_interviews``
            and ``bucket_collection`` are reported on.
        :param n: multiplier applied to ``len(jobs_runner.interviews)`` to
            get the expected number of interviews (and per-model targets).
        :param refresh_rate: seconds to sleep between updates in
            ``update_progress``.
        :param endpoint_url: base URL stored (as a string) in ``base_url``.
        :param job_uuid: uuid of the job on Coop, if already known.
        :param api_key: API key; falls back to the ``EXPECTED_PARROT_API_KEY``
            environment variable when not given (or empty).
        """
        self.jobs_runner = jobs_runner

        # The uuid of the job on Coop
        self.job_uuid = job_uuid

        self.base_url = f"{endpoint_url}"

        self.start_time = time.time()
        self.completed_interviews = []
        self.refresh_rate = refresh_rate
        self.statistics = [
            "elapsed_time",
            "total_interviews_requested",
            "completed_interviews",
            # "percent_complete",
            "average_time_per_interview",
            # "task_remaining",
            "estimated_time_remaining",
            "exceptions",
            "unfixed_exceptions",
            "throughput",
        ]
        self.num_total_interviews = n * len(self.jobs_runner.interviews)

        # Exact expected interview count per model: every interview is
        # counted n times. Dividing the grand total evenly by the number of
        # models (the previous approach) is wrong when models do not have
        # the same number of interviews.
        self._target_by_model = defaultdict(int)
        for interview in self.jobs_runner.interviews:
            self._target_by_model[interview.model.model] += n
        self.distinct_models = list(self._target_by_model)

        self.completed_interview_by_model = defaultdict(list)

        self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")

    @abstractmethod
    def has_ep_api_key(self):
        """
        Checks if the user has an Expected Parrot API key.
        """
        pass

    @staticmethod
    def _bucket_summary(bucket) -> Dict[str, Any]:
        """Summarize one of a model's buckets (requests or tokens) as a dict."""
        return {
            "completed": bucket.num_released,
            "requested": bucket.num_requests,
            "tokens_returned": bucket.tokens_returned,
            "target_rate": round(bucket.target_rate, 1),
            "current_rate": round(bucket.get_throughput(), 1),
        }

    def get_status_dict(self) -> Dict[str, Any]:
        """
        Converts current status into a JSON-serializable dictionary.

        Keys: ``overall_progress``, ``language_model_progress``,
        ``statistics``, ``status`` ("completed" once every interview in
        ``jobs_runner.total_interviews`` is done, else "running") and
        ``language_model_queues`` (per-model bucket summaries).
        """
        # Each computed statistic maps its own name to a single value.
        stats = {}
        for stat_name in self.statistics:
            stat = self._compute_statistic(stat_name)
            name, value = next(iter(stat.items()))
            stats[name] = value

        # Calculate overall progress
        total_interviews = len(self.jobs_runner.total_interviews)
        completed = len(self.completed_interviews)

        # Get model-specific progress
        model_progress = {}
        for model in self.distinct_models:
            completed_for_model = len(self.completed_interview_by_model[model])
            target_for_model = self._target_by_model.get(model, 0)
            model_progress[model] = {
                "completed": completed_for_model,
                "total": target_for_model,
                "percent": (
                    (completed_for_model / target_for_model * 100)
                    if target_for_model > 0
                    else 0
                ),
            }

        status_dict = {
            "overall_progress": {
                "completed": completed,
                "total": total_interviews,
                "percent": (
                    (completed / total_interviews * 100) if total_interviews > 0 else 0
                ),
            },
            "language_model_progress": model_progress,
            "statistics": stats,
            "status": "completed" if completed >= total_interviews else "running",
        }

        model_queues = {}
        for model, bucket in self.jobs_runner.bucket_collection.items():
            model_name = model.model
            model_queues[model_name] = {
                "language_model_name": model_name,
                "requests_bucket": self._bucket_summary(bucket.requests_bucket),
                "tokens_bucket": self._bucket_summary(bucket.tokens_bucket),
            }
        status_dict["language_model_queues"] = model_queues
        return status_dict

    @abstractmethod
    def setup(self):
        """
        Conducts any setup that needs to happen prior to sending status updates.

        Ex. For a local job, creates a job in the Coop database.
        """
        pass

    @abstractmethod
    def send_status_update(self):
        """
        Updates the current status of the job.
        """
        pass

    def add_completed_interview(self, result):
        """Record ``result.interview_hash`` overall and under ``result.model.model``."""
        self.completed_interviews.append(result.interview_hash)

        relevant_model = result.model.model
        self.completed_interview_by_model[relevant_model].append(result.interview_hash)

    def _compute_statistic(self, stat_name: str):
        """Build the ``InterviewStatistic`` named ``stat_name``.

        Raises ``KeyError`` if ``stat_name`` is not one of the definitions
        below. Only the requested statistic is evaluated (each entry is a
        lambda).
        """
        completed_tasks = self.completed_interviews
        elapsed_time = time.time() - self.start_time
        interviews = self.jobs_runner.total_interviews

        stat_definitions = {
            "elapsed_time": lambda: InterviewStatistic(
                "elapsed_time", value=elapsed_time, digits=1, units="sec."
            ),
            "total_interviews_requested": lambda: InterviewStatistic(
                "total_interviews_requested", value=len(interviews), units=""
            ),
            "completed_interviews": lambda: InterviewStatistic(
                "completed_interviews", value=len(completed_tasks), units=""
            ),
            "percent_complete": lambda: InterviewStatistic(
                "percent_complete",
                value=(
                    len(completed_tasks) / len(interviews) * 100
                    if len(interviews) > 0
                    else 0
                ),
                digits=1,
                units="%",
            ),
            "average_time_per_interview": lambda: InterviewStatistic(
                "average_time_per_interview",
                value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
                digits=2,
                units="sec.",
            ),
            "task_remaining": lambda: InterviewStatistic(
                "task_remaining", value=len(interviews) - len(completed_tasks), units=""
            ),
            "estimated_time_remaining": lambda: InterviewStatistic(
                "estimated_time_remaining",
                value=(
                    (len(interviews) - len(completed_tasks))
                    * (elapsed_time / len(completed_tasks))
                    if len(completed_tasks) > 0
                    else 0
                ),
                digits=1,
                units="sec.",
            ),
            "exceptions": lambda: InterviewStatistic(
                "exceptions",
                value=sum(len(i.exceptions) for i in interviews),
                units="",
            ),
            "unfixed_exceptions": lambda: InterviewStatistic(
                "unfixed_exceptions",
                value=sum(i.exceptions.num_unfixed() for i in interviews),
                units="",
            ),
            "throughput": lambda: InterviewStatistic(
                "throughput",
                value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
                digits=2,
                units="interviews/sec.",
            ),
        }
        return stat_definitions[stat_name]()

    def update_progress(self, stop_event):
        """Send a status update every ``refresh_rate`` seconds until
        ``stop_event.is_set()`` is true, then send one final update.

        NOTE(review): the sleep is not interrupted when the event is set, so
        shutdown can lag by up to ``refresh_rate`` seconds.
        """
        while not stop_event.is_set():
            self.send_status_update()
            time.sleep(self.refresh_rate)

        self.send_status_update()
|
247
|
+
|
248
|
+
|
249
|
+
class JobsRunnerStatus(JobsRunnerStatusBase):
    """Reports job progress to a Coop "local job" over HTTP using ``requests``."""

    @property
    def create_url(self) -> str:
        """Endpoint used (via POST) to create a new local job."""
        return f"{self.base_url}/api/v0/local-job"

    @property
    def viewing_url(self) -> str:
        """Page where the progress of this job can be viewed."""
        return f"{self.base_url}/home/local-job-progress/{self.job_uuid}"

    @property
    def update_url(self) -> str:
        """Endpoint used (via PATCH) to update this job's status."""
        return f"{self.base_url}/api/v0/local-job/{self.job_uuid}"

    def _get_headers(self) -> Dict[str, str]:
        """Build JSON request headers with a bearer token.

        When no API key is set (or it is empty), the literal token ``None``
        is sent, as before; the server is presumably responsible for
        rejecting or accepting that.
        """
        token = self.api_key if self.api_key else "None"
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        }

    def setup(self) -> None:
        """
        Creates a local job on Coop if one does not already exist.

        Raises ``requests`` exceptions if the create request fails; either
        way (when it succeeds) prints the URL where progress can be viewed.
        """
        if self.job_uuid is None:
            # Create a new local job
            response = requests.post(
                self.create_url,
                headers=self._get_headers(),
                timeout=1,
            )
            response.raise_for_status()
            data = response.json()
            self.job_uuid = data.get("job_uuid")

        print(f"Running with progress bar. View progress at {self.viewing_url}")

    def send_status_update(self) -> None:
        """
        Sends current status to the web endpoint using the instance's job_uuid.

        Network/HTTP failures are reported with ``print`` rather than raised,
        so a failed update does not interrupt the job.
        """
        try:
            # Get the status dictionary and add the job_id
            status_dict = self.get_status_dict()

            # Make the UUID JSON serializable
            status_dict["job_id"] = str(self.job_uuid)

            # Send the update
            response = requests.patch(
                self.update_url,
                json=status_dict,
                headers=self._get_headers(),
                timeout=1,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Failed to send status update for job {self.job_uuid}: {e}")

    def has_ep_api_key(self) -> bool:
        """
        Returns True if the user has an Expected Parrot API key. Otherwise, returns False.
        """
        return self.api_key is not None
|
325
|
+
|
326
|
+
|
327
|
+
if __name__ == "__main__":
    # Run this module's doctests; "..." in expected output matches anything.
    from doctest import ELLIPSIS, testmod

    testmod(optionflags=ELLIPSIS)
|